In [1]:
import numpy as np
import torch.nn as nn
import matplotlib
import nltk
import os
import random
import re
import torch

from itertools import islice
from tqdm import tqdm, trange
from torch.optim import Adam
from torch.utils.data import TensorDataset

from models.Encoder import *
from models.Decoder import *
from utils import *

random.seed(42)

%matplotlib inline

GPU is available
GPU is available


In [2]:
# define experiment
exp='/exp_1a'

In [3]:
device = ("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# load dataset into memory, and get w2idx, idx2w, w2freq dictionaries and lists of input and output sentences
cmd_vocab, w2i_cmds, i2w_cmds, cmds, act_vocab, w2i_acts, i2w_acts, acts = load_dataset(exp=exp, split='/train')

In [5]:
# create input and output language pairs
cmd_act_pairs = create_pairs(cmds, acts)

In [6]:
# show random command-action pair
random_pair = random.choice(cmd_act_pairs)
print("Command: {}".format(random_pair[0]))
print("Action: {}".format(random_pair[1]))

Command: ['look', 'left', 'thrice', 'after', 'turn', 'opposite', 'left']
Action: ['I_TURN_LEFT', 'I_TURN_LEFT', 'I_TURN_LEFT', 'I_LOOK', 'I_TURN_LEFT', 'I_LOOK', 'I_TURN_LEFT', 'I_LOOK']


In [7]:
cmd_act_pair = pairs2idx(random_pair, w2i_cmds, w2i_acts)
print("Command sequence: {}".format(cmd_act_pair[0]))
print("Action sequence: {}".format(cmd_act_pair[1]))

Command sequence: tensor([12,  4,  6,  9, 15,  7,  4,  1])
Action sequence: tensor([0, 4, 4, 4, 8, 4, 8, 4, 8, 1])


In [38]:
def train(lang_pairs, w2i_source, w2i_target, i2w_target, encoder, decoder, 
          epochs:int=1, learning_rate:float=1e-4, teacher_forcing_decrease:float=0.05):
        
    train_losses, train_accs = [], []
    encoder_optimizer = Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = Adam(decoder.parameters(), lr=learning_rate)
    
    training_pairs = [pairs2idx(random.choice(lang_pairs), w2i_cmds, w2i_acts) for _ in range(len(lang_pairs))]
    max_target_length = max(iter(map(lambda lang_pair: len(lang_pair[1]), training_pairs)))
    n_lang_pairs = len(training_pairs)
    
    criterion = nn.NLLLoss()
    
    teacher_forcing_ratio = 0.9
    
    for epoch in trange(epochs,  desc="Epoch"):
        
        loss_per_epoch = 0
        acc_per_epoch = 0
        
        for idx, train_pair in enumerate(training_pairs):
            
            loss = 0
            
            command = train_pair[0].to(device)
            action = train_pair[1].to(device)

            encoder_hidden = encoder.init_hidden()

            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            input_length = command.size(0)
            target_length = action.size(0)

            encoder_outputs, encoder_hidden = encoder(command, encoder_hidden)
            
            decoder_input = action[0] # SOS token
            
            # NOTE: line below is necessary since encoder_hidden.shape = n_layers*n_directions x hidden_size
            decoder_hidden = encoder_hidden[-1] if encoder.n_layers > 1 else encoder_hidden

            use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
            
            pred_sent = ""
            true_sent = ' '.join([i2w_target[act.item()] for act in islice(action, 1, None)]) # skip SOS token
            

            if use_teacher_forcing:
                # Teacher forcing: feed target as the next input
                for i in range(1, target_length): # range(1, max_target_length)
                    decoder_out, decoder_hidden = decoder(decoder_input, decoder_hidden)
                    dim = 1 if len(decoder_out.shape) > 1 else 0  # crucial to correctly compute the argmax
                    pred = torch.argmax(decoder_out, dim) # argmax computation
                    
                    if i >= target_length:
                        loss += criterion(decoder_out, torch.tensor(w2i_target['<EOS>'], dtype=torch.long).unsqueeze(0).to(device))
                        decoder_input = pred.squeeze() # action[-2]
                    else:
                        loss += criterion(decoder_out, action[i].unsqueeze(0))
                        decoder_input = action[i] # convert list of int into int
                    
                    pred_sent += i2w_target[pred.item()] + " "
                    
                    if pred.squeeze().item() == w2i_target['<EOS>']:
                        break
                
            else:
                # Autoregression: feed previous prediction as the next input
                for i in range(1, max_target_length):
                    decoder_out, decoder_hidden = decoder(decoder_input, decoder_hidden)
                    dim = 1 if len(decoder_out.shape) > 1 else 0 # crucial to correctly compute the argmax
                    pred = torch.argmax(decoder_out, dim) # argmax computation
                    
                    if i >= target_length:
                        loss += criterion(decoder_out, torch.tensor(w2i_target['<EOS>'], dtype=torch.long).unsqueeze(0).to(device))
                    else:
                        loss += criterion(decoder_out, action[i].unsqueeze(0))
                    
                    decoder_input = pred.squeeze() # convert list of int into int
                    
                    pred_sent += i2w_target[pred.item()] + " "
                    
                    if decoder_input.item() == w2i_target['<EOS>']:
                        break

            acc_per_epoch += 1 if (pred_sent == true_sent) else 0 # exact match accuracy
        
            loss.backward()
            
            # inspect translation behaviour
            if idx > 0 and idx % 1000 == 0:
                print("Loss: {}".format(loss.item() / target_length))
                print("Acc: {}".format(acc_per_epoch / (idx + 1)))
                print()
                print("True action: {}".format(true_sent))
                print("Pred action: {}".format(pred_sent))
                print()
                print("Target length: {}".format(target_length))
                print("True sent length: {}".format(len(true_sent.split())))
                print("Pred sent length: {}".format(len(pred_sent.split())))
                print()
                
            encoder_optimizer.step()
            decoder_optimizer.step()

            loss_per_epoch += loss.item() / target_length
        
        loss_per_epoch /= n_lang_pairs
        acc_per_epoch /= n_lang_pairs
        
        print("Train loss: {}".format(loss_per_epoch))
        print("Train acc: {}".format(acc_per_epoch))
        
        train_losses.append(loss_per_epoch)
        train_accs.append(acc_per_epoch)
        
        # per epoch decrease teacher forcing ratio
        teacher_forcing_ratio -= teacher_forcing_decrease
        
    return train_losses, train_accs

In [32]:
in_size = len(w2i_cmds)
out_size = len(w2i_acts)
emb_size = 15
hidden_size = 200
n_layers = 1
n_epochs = 2

In [33]:
encoder = EncoderLSTM(in_size, emb_size, hidden_size, n_layers)
decoder = DecoderLSTM(emb_size, hidden_size, out_size, n_layers)

  "num_layers={}".format(dropout, num_layers))


In [34]:
encoder.to(device)
decoder.to(device)

DecoderLSTM(
  (embedding): Embedding(9, 15)
  (lstm): LSTM(15, 200, dropout=0.5)
  (linear): Linear(in_features=200, out_features=9, bias=True)
)

In [35]:
train(cmd_act_pairs, w2i_cmds, w2i_acts, i2w_acts, encoder, decoder, n_epochs)

50




Epoch:   0%|                                                                                     | 0/2 [00:00<?, ?it/s]

Loss: 0.8056693077087402
Acc: 0.0

True action: I_TURN_LEFT I_RUN I_TURN_LEFT I_RUN I_TURN_LEFT I_RUN I_TURN_LEFT I_RUN I_TURN_LEFT I_RUN I_TURN_LEFT I_RUN I_TURN_LEFT I_RUN I_TURN_LEFT I_RUN I_JUMP I_JUMP <EOS>
Pred action: I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT <EOS> 

Target length: 20
True sent length: 19
Pred sent length: 11

Loss: 1.233809471130371
Acc: 0.0

True action: I_TURN_LEFT I_TURN_LEFT I_WALK I_TURN_LEFT I_TURN_LEFT I_RUN <EOS>
Pred action: I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT <EOS> 

Target length: 8
True sent length: 7
Pred sent length: 8

Loss: 0.47754757744925364
Acc: 0.0

True action: I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_LEFT I_TURN_L

KeyboardInterrupt: 