In [1]:
import numpy as np
import torch.nn as nn
import matplotlib
import nltk
import os
import random
import re
import torch

from itertools import islice
from tqdm import tqdm, trange
from torch.optim import Adam
from torch.utils.data import TensorDataset

from models.Encoder import *
from models.Decoder import *
from utils import *

# set random seeds to reproduce results
# (Python's `random`, NumPy and PyTorch each keep their own generator state, so all three are seeded;
#  PyTorch's generator drives weight initialisation and dropout)
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

%matplotlib inline

GPU is available
GPU is available


In [2]:
# experiment identifier; passed as `exp=` to load_dataset below
exp = '/exp_1a'

In [3]:
# compute device name: "cuda" when a CUDA device is usable, otherwise "cpu"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [4]:
# load the training split of the chosen experiment into memory; for each language (commands, actions)
# this yields a vocabulary, a word->index dict, an index->word dict and the list of sentences
(
    cmd_vocab, w2i_cmds, i2w_cmds, cmds,
    act_vocab, w2i_acts, i2w_acts, acts,
) = load_dataset(exp=exp, split='/train')

In [5]:
# create input and output language pairs
# (each pair is indexed below as pair[0] = command tokens, pair[1] = action tokens;
#  presumably commands and actions are matched by position - confirm in utils.create_pairs)
cmd_act_pairs = create_pairs(cmds, acts)

In [6]:
# sample one command-action pair and display both sides (the pair is reused in the next cell)
random_pair = random.choice(cmd_act_pairs)
print(
    "Command: {}".format(random_pair[0]),
    "Action: {}".format(random_pair[1]),
    sep="\n",
)

Command: ['look', 'left', 'thrice', 'after', 'turn', 'opposite', 'left']
Action: ['I_TURN_LEFT', 'I_TURN_LEFT', 'I_TURN_LEFT', 'I_LOOK', 'I_TURN_LEFT', 'I_LOOK', 'I_TURN_LEFT', 'I_LOOK']


In [7]:
# convert the sampled pair's tokens into vocabulary-index sequences and display both
cmd_act_pair = pairs2idx(random_pair, w2i_cmds, w2i_acts)
print(
    "Command sequence: {}".format(cmd_act_pair[0]),
    "Action sequence: {}".format(cmd_act_pair[1]),
    sep="\n",
)

Command sequence: tensor([13,  5,  7, 10, 16,  8,  5,  2])
Action sequence: tensor([1, 5, 5, 5, 9, 5, 9, 5, 9, 2])


In [25]:
def train(lang_pairs, w2i_source, w2i_target, i2w_source, i2w_target, encoder, decoder, epochs:int, batch_size:int=1,
          learning_rate:float=1e-3, max_ratio:float=0.95, min_ratio:float=0.15, detailed_analysis:bool=True):
    """Train an encoder-decoder pair on source/target sentence pairs, one pair per optimiser step.

    The function relies on two names from the notebook's global scope: `pairs2idx` (turns a token pair
    into a pair of index tensors) and `device` (where the tensors are moved before the forward pass).

    Args:
        lang_pairs: non-empty sequence of (source tokens, target tokens) pairs.
        w2i_source: word -> index mapping of the source language, forwarded to `pairs2idx`.
        w2i_target: word -> index mapping of the target language; must contain '<EOS>'.
        i2w_source: index -> word mapping of the source language (used only for the analysis print-out).
        i2w_target: index -> word mapping of the target language (used to decode predictions).
        encoder: module exposing `init_hidden(batch_size)` and callable as `encoder(command, hidden)`,
            returning `(outputs, hidden)`.
        decoder: module callable as `decoder(token, hidden)`, returning `(scores, hidden)`; the scores are
            fed to `nn.NLLLoss`, so they are presumably log-probabilities - confirm in models.Decoder.
        epochs: number of passes over the sampled training pairs; `epochs <= 0` returns two empty lists.
        batch_size: forwarded to `encoder.init_hidden` only; each update still processes a single pair.
        learning_rate: learning rate of both Adam optimisers.
        max_ratio: teacher forcing probability used in the first epoch.
        min_ratio: lower end of the teacher forcing schedule (see note below).
        detailed_analysis: if True, print loss, running accuracy and one decoded example every 3000 pairs.

    Returns:
        (train_losses, train_accs): per-epoch mean length-normalised loss and exact-match accuracy.

    Raises:
        ValueError: if `lang_pairs` is empty.

    NOTE(review): the ratio is decreased by (max_ratio - min_ratio) / epochs after every epoch, so the
    last epoch runs with min_ratio + one step rather than min_ratio itself.
    """
    if len(lang_pairs) == 0:
        raise ValueError("lang_pairs must not be empty")

    # every n_iters pairs, print a snapshot of the decoder's behaviour
    n_iters = 3000

    train_losses, train_accs = [], []
    encoder_optimizer = Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = Adam(decoder.parameters(), lr=learning_rate)

    # sample (with replacement) as many pairs as the dataset holds; the same sample is reused every epoch.
    # The vocabularies passed in by the caller are used here, not notebook globals.
    training_pairs = [pairs2idx(random.choice(lang_pairs), w2i_source, w2i_target) for _ in range(len(lang_pairs))]
    max_target_length = max(len(pair[1]) for pair in training_pairs)
    n_lang_pairs = len(training_pairs)

    # negative log-likelihood loss
    criterion = nn.NLLLoss()

    # the EOS index and its target tensor never change, so build them once instead of once per step
    eos_idx = w2i_target['<EOS>']
    eos_target = torch.tensor(eos_idx, dtype=torch.long).unsqueeze(0).to(device)

    # teacher forcing curriculum: start at max_ratio and decrease in equal steps after each epoch
    step_per_epoch = (max_ratio - min_ratio) / epochs if epochs > 0 else 0.0
    teacher_forcing_ratio = max_ratio

    for epoch in trange(epochs,  desc="Epoch"):

        loss_per_epoch = 0
        acc_per_epoch = 0

        for idx, train_pair in enumerate(training_pairs):

            loss = 0

            command = train_pair[0].to(device)
            action = train_pair[1].to(device)
            target_length = action.size(0)

            # fresh hidden state for every sequence
            encoder_hidden = encoder.init_hidden(batch_size)

            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            _, encoder_hidden = encoder(command, encoder_hidden)

            decoder_input = action[0] # SOS token
            decoder_hidden = encoder_hidden # init decoder hidden with encoder hidden

            use_teacher_forcing = random.random() < teacher_forcing_ratio

            true_sent = ' '.join([i2w_target[act.item()] for act in action[1:]]).strip() # skip SOS token
            pred_tokens = []

            # teacher forcing decodes exactly the target positions; free-running decoding may continue up to
            # the longest target in the training sample and is penalised against <EOS> past its own target
            n_steps = target_length if use_teacher_forcing else max_target_length

            for i in range(1, n_steps):
                decoder_out, decoder_hidden = decoder(decoder_input, decoder_hidden)
                dim = 1 if len(decoder_out.shape) > 1 else 0 # crucial to correctly compute the argmax
                pred = torch.argmax(decoder_out, dim)

                step_target = action[i].unsqueeze(0) if i < target_length else eos_target
                loss += criterion(decoder_out, step_target)

                # next input: the gold token (teacher forcing) or the model's own prediction
                decoder_input = action[i] if use_teacher_forcing else pred.squeeze()

                pred_tokens.append(i2w_target[pred.item()])

                # stop decoding as soon as the model predicts <EOS>
                if pred.squeeze().item() == eos_idx:
                    break

            pred_sent = ' '.join(pred_tokens).strip()
            acc_per_epoch += 1 if pred_sent == true_sent else 0 # exact match accuracy

            loss.backward()

            ### inspect translation behaviour ###
            if detailed_analysis and idx > 0 and idx % n_iters == 0:
                # decode the command only when it is actually going to be printed
                nl_command = ' '.join([i2w_source[cmd.item()] for cmd in command]).strip()
                print("Loss: {}".format(loss.item() / target_length)) # current per sequence loss
                print("Acc: {}".format(acc_per_epoch / (idx + 1))) # current per iters exact-match accuracy
                print()
                print("Command: {}".format(nl_command))
                print("True action: {}".format(true_sent))
                print("Pred action: {}".format(pred_sent))
                print()
                print("Target length: {}".format(target_length))
                print("True sent length: {}".format(len(true_sent.split())))
                print("Pred sent length: {}".format(len(pred_sent.split())))
                print()

            encoder_optimizer.step()
            decoder_optimizer.step()

            loss_per_epoch += loss.item() / target_length

        loss_per_epoch /= n_lang_pairs
        acc_per_epoch /= n_lang_pairs

        print("Train loss: {}".format(loss_per_epoch)) # loss
        print("Train acc: {}".format(acc_per_epoch)) # exact-match accuracy
        print("Current teacher forcing ratio {}".format(teacher_forcing_ratio))

        train_losses.append(loss_per_epoch)
        train_accs.append(acc_per_epoch)

        teacher_forcing_ratio -= step_per_epoch # decrease teacher forcing ratio

    return train_losses, train_accs

In [26]:
# vocabulary sizes: |V_source| for the commands, |V_target| for the actions
in_size, out_size = len(w2i_cmds), len(w2i_acts)

# model and training hyperparameters
emb_size = 15      # size of word embeddings
hidden_size = 50
n_layers = 2
n_epochs = 5

In [27]:
# build the sequence-to-sequence model: the encoder is sized by the command vocabulary (in_size),
# the decoder by the action vocabulary (out_size); note that the vocabulary size is the first
# argument of EncoderLSTM but the third argument of DecoderLSTM
encoder = EncoderLSTM(in_size, emb_size, hidden_size, n_layers)
decoder = DecoderLSTM(emb_size, hidden_size, out_size, n_layers)

In [28]:
# move models to GPU, if GPU is available (for faster computation)
# (`device` is "cuda" or "cpu"; the last call is the cell's final expression, so the decoder's
#  architecture is displayed as the cell output)
encoder.to(device)
decoder.to(device)

DecoderLSTM(
  (embedding): Embedding(10, 15)
  (lstm): LSTM(15, 50, num_layers=2, dropout=0.5)
  (linear): Linear(in_features=50, out_features=10, bias=True)
)

In [None]:
# keep the per-epoch loss and exact-match accuracy histories instead of discarding the return value
train_losses, train_accs = train(cmd_act_pairs, w2i_cmds, w2i_acts, i2w_cmds, i2w_acts, encoder, decoder, n_epochs)



Epoch:   0%|                                                                                     | 0/5 [00:00<?, ?it/s]

Loss: 0.5072052365257627
Acc: 0.003332222592469177

True action: I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_TURN_RIGHT I_JUMP <EOS>
Pred action: I_TURN_RIGHT I_TURN_RIGHT I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN <EOS>

Target length: 21
True sent length: 20
Pred sent length: 19

Loss: 0.6420160021100726
Acc: 0.01016497250458257

True action: I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_WALK <EOS>
Pred action: I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK

Target length: 7
True sent length: 6
Pred sent length: 6

Loss: 0.747722053527832
Acc: 0.023886234862793024

True action: I_TURN_LEFT I_TURN_RIGHT I_JUMP <EOS>
Pred action: I_TURN_LEFT I_TURN_LEFT I_JUMP <EOS>

Target length: 5
True sent length: 4
Pred sent length: 4

Loss: 0.25725793838500977



Epoch:  20%|██████████████▌                                                          | 1/5 [17:16<1:09:04, 1036.14s/it]

Loss: 0.18827129545665922
Acc: 0.28123958680439853

True action: I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_TURN_RIGHT I_JUMP <EOS>
Pred action: I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_JUMP <EOS>

Target length: 21
True sent length: 20
Pred sent length: 20

Loss: 0.17597307477678573
Acc: 0.32544575904016

True action: I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_WALK <EOS>
Pred action: I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_WALK I_WALK

Target length: 7
True sent length: 6
Pred sent length: 6

Loss: 0.10747013092041016
Acc: 0.36773691812020887

True action: I_TURN_LEFT I_TURN_RIGHT I_JUMP <EOS>
Pred action: I_TURN_LEFT I_TURN_RIGHT I_JUMP <EOS>

Target length: 5
True sent length: 4
Pred sent length: 4

Loss: 0.30910334587097166
Acc: 



Epoch:  40%|██████████████████████████████                                             | 2/5 [33:50<51:11, 1023.70s/it]

Loss: 0.04883702596028646
Acc: 0.6904365211596135

True action: I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_TURN_RIGHT I_JUMP <EOS>
Pred action: I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_RUN I_TURN_RIGHT I_JUMP I_JUMP <EOS>

Target length: 21
True sent length: 20
Pred sent length: 20

Loss: 0.09396934509277344
Acc: 0.7038826862189635

True action: I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_WALK <EOS>
Pred action: I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_WALK <EOS>

Target length: 7
True sent length: 6
Pred sent length: 6

