In [1]:
import numpy as np
import torch.nn as nn
import matplotlib
import nltk
import os
import random
import re
import torch

from tqdm import tqdm, trange
from torch.optim import Adam
from torch.utils.data import TensorDataset
from models.Encoder import *
from models.Decoder import *
from utils import *

random.seed(42)

%matplotlib inline

GPU is available
GPU is available


In [2]:
# define experiment
exp='/exp_1a'

In [3]:
device = ("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# load dataset into memory, and get w2idx, idx2w, w2freq dictionaries and lists of input and output sentences
cmd_vocab, w2i_cmds, i2w_cmds, cmds, act_vocab, w2i_acts, i2w_acts, acts = load_dataset(exp=exp, split='/train')

In [5]:
# create input and output language pairs
cmd_act_pairs = create_pairs(cmds, acts)

In [6]:
# show random command-action pair
random_pair = random.choice(cmd_act_pairs)
print("Command: {}".format(random_pair[0]))
print("Action: {}".format(random_pair[1]))

Command: ['look', 'left', 'thrice', 'after', 'turn', 'opposite', 'left']
Action: ['I_TURN_LEFT', 'I_TURN_LEFT', 'I_TURN_LEFT', 'I_LOOK', 'I_TURN_LEFT', 'I_LOOK', 'I_TURN_LEFT', 'I_LOOK']


In [7]:
cmd_act_pair = pairs2idx(random_pair, w2i_cmds, w2i_acts)
print("Command sequence: {}".format(cmd_act_pair[0]))
print("Action sequence: {}".format(cmd_act_pair[1]))

Command sequence: tensor([12,  4,  6,  9, 15,  7,  4,  1])
Action sequence: tensor([0, 4, 4, 4, 8, 4, 8, 4, 8, 1])


In [178]:
def train(lang_pairs, w2i_source, w2i_target, i2w_target, encoder, decoder, epochs:int=1, learning_rate:float=1e-4):
        
    losses_overall = []
    encoder_optimizer = Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = Adam(decoder.parameters(), lr=learning_rate)
    
    training_pairs = [pairs2idx(random.choice(lang_pairs), w2i_cmds, w2i_acts) for _ in range(len(lang_pairs))]
    
    criterion = nn.NLLLoss()
    
    for epoch in trange(epochs,  desc="Epoch"):
        
        losses_per_epoch = []
        accs_per_epoch = []
        
        for idx, train_pair in enumerate(training_pairs):
            
            loss = 0
            
            command = train_pair[0].to(device)
            action = train_pair[1].to(device)

            encoder_hidden = encoder.init_hidden()

            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            input_length = command.size(0)
            target_length = action.size(0)

            encoder_outputs, encoder_hidden = encoder(command, encoder_hidden)
            
            decoder_input = action[0] # SOS token
            # NOTE: line below is necessary since encoder_hidden.shape = n_layers*n_directions x hidden_size
            decoder_hidden = encoder_hidden[-1] if encoder.n_layers > 1 else encoder_hidden

            use_teacher_forcing = True
            
            pred_sent = ""
            true_sent = ""
            acc_per_sent = 0

            if use_teacher_forcing:
                # teacher forcing: feed target as the next input
                for i in range(1, 48):
                    decoder_out, decoder_hidden = decoder(decoder_input, decoder_hidden)
                    dim = 1 if len(decoder_out.shape) > 1 else 0
                    pred = torch.argmax(decoder_out, dim) # argmax
                    
                    if i >= target_length:
                        loss += criterion(decoder_out, torch.tensor(w2i_target['<EOS>'], dtype=torch.long).unsqueeze(0).to(device))
                        acc_per_sent += 1 if pred == torch.tensor(w2i_target['<EOS>'], dtype=torch.long) else 0 
                        true_sent += "<EOS>" + " "
                        decoder_input = action[-2]
                    else:
                        loss += criterion(decoder_out, action[i].unsqueeze(0))
                        acc_per_sent += 1 if pred == action[i] else 0
                        true_sent += i2w_target[action[i].item()] + " "
                        decoder_input = action[i] # convert list of int into int
                    
                    pred_sent += i2w_target[pred.item()] + " "
                    
                    if pred.squeeze().item() == w2i_target['<EOS>']:
                        break
                
            else:
                for i in range(1, 48):
                    decoder_out, decoder_hidden = decoder(decoder_input, decoder_hidden)
                    dim = 1 if len(decoder_out.shape) > 1 else 0
                    pred = torch.argmax(decoder_out, dim) # argmax
                    
                    if i >= target_length:
                        loss += criterion(decoder_out, torch.tensor(w2i_target['<EOS>'], dtype=torch.long).unsqueeze(0).to(device))
                        acc_per_sent += 1 if pred == torch.tensor(w2i_target['<EOS>'], dtype=torch.long) else 0 
                        true_sent += "<EOS>" + " "
                    else:
                        loss += criterion(decoder_out, action[i].unsqueeze(0))
                        acc_per_sent += 1 if pred == action[i] else 0
                        true_sent += i2w_target[action[i].item()] + " "

                    
                    decoder_input = pred.squeeze() # convert list of int into int
                    
                    pred_sent += i2w_target[pred.item()] + " "
                    

                    if decoder_input.item() == w2i_target['<EOS>']:
                        break

            acc_per_sent /= target_length
            accs_per_epoch.append(acc_per_sent)
        
            loss.backward()
            
            if idx % 500 == 0:
                print("Loss: {}".format(loss.item() / target_length))
                print("Sentence accuracy: {}".format(acc_per_sent))
                print()
                print("True action: {}".format(true_sent))
                print("Pred action: {}".format(pred_sent))
                print()
                print("Target length: {}".format(target_length))
                print("True sent length: {}".format(len(true_sent.split())))
                print("Pred sent length: {}".format(len(pred_sent.split())))
                print()
                
            encoder_optimizer.step()
            decoder_optimizer.step()

            losses_per_epoch.append(loss.item() / target_length)
        
        loss_per_epoch = np.mean(losses_per_epoch)
        acc_per_epoch = np.mean(accs_per_epoch)
        
        print("Train loss: {}".format(loss_per_epoch))
        print("Train acc: {}".format(acc_per_epoch))
        
        losses_overall.append(loss_per_epoch)
        
    return losses_overall

In [179]:
in_size = len(w2i_cmds)
out_size = len(w2i_acts)
emb_size = 10
hidden_size = 50
n_layers = 1
n_epochs = 1

In [180]:
encoder = EncoderLSTM(in_size, emb_size, hidden_size, n_layers)
decoder = DecoderLSTM(emb_size, hidden_size, out_size, n_layers)

  "num_layers={}".format(dropout, num_layers))


In [181]:
encoder.to(device)
decoder.to(device)

DecoderLSTM(
  (embedding): Embedding(9, 10)
  (lstm): LSTM(10, 50, dropout=0.5)
  (linear): Linear(in_features=50, out_features=9, bias=True)
)

In [182]:
train(cmd_act_pairs, w2i_cmds, w2i_acts, i2w_acts, encoder, decoder, n_epochs)











  out, hidden = self.lstm(embedded, hidden)


Loss: 12.257247924804688
Sentence accuracy: 0.25

True action: I_TURN_LEFT I_LOOK I_TURN_LEFT I_LOOK I_TURN_LEFT I_LOOK <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> 
Pred action: I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_RUN I_RUN I_RUN I_RUN I_RUN I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT 

Target length: 8
True sent length: 47
Pred sent length: 47

Loss: 5.188472747802734


Loss: 0.5469112396240234
Sentence accuracy: 0.9

True action: I_TURN_RIGHT I_TURN_RIGHT I_TURN_RIGHT I_TURN_RIGHT I_TURN_RIGHT I_TURN_RIGHT I_TURN_RIGHT I_TURN_RIGHT <EOS> <EOS> 
Pred action: I_TURN_RIGHT I_TURN_RIGHT I_TURN_RIGHT I_TURN_RIGHT I_TURN_RIGHT I_TURN_RIGHT I_TURN_RIGHT I_TURN_RIGHT I_TURN_RIGHT <EOS> 

Target length: 10
True sent length: 10
Pred sent length: 10

Loss: 1.4819337129592896
Sentence accuracy: 0.375

True action: I_TURN_LEFT I_WALK I_TURN_LEFT I_WALK I_TURN_LEFT I_WALK I_TURN_LEFT I_WALK I_TURN_RIGHT I_TURN_RIGHT I_LOOK I_TURN_RIGHT I_TURN_RIGHT I_LOOK <EOS> <EOS> 
Pred action: I_TURN_RIGHT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_TURN_LEFT I_WALK I_TURN_RIGHT I_TURN_RIGHT I_TURN_RIGHT I_TURN_RIGHT I_TURN_RIGHT <EOS> 

Target length: 16
True sent length: 16
Pred sent length: 16

Loss: 1.0060514303354116
Sentence accuracy: 0.7307692307692307

True action: I_TURN_LEFT I_RUN I_TURN_LEFT I_RUN I_TURN_LEFT I_RUN I_TURN_LE











Epoch: 100%|████████████████████████████████████████████████████████████████████████████| 1/1 [12:38<00:00, 758.02s/it]


[1.1894505437188276]