# CS7643 - Final Project

In [16]:
# built-in
# public
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
# private
from config import Config, InstrConfig
from src.utils import dataloader, helpers
from src.models.instructions_generator_model import InstructionsGeneratorModel
from src.trainer import instructions_generator_trainer
from src.models.imitation_learning_model import ImitationLearningModel
from src.trainer import imitation_learning_trainer

%load_ext autoreload 
%autoreload 2
%config Completer.use_jedi = False

# Initialization

In [17]:
instr_config = InstrConfig()
helpers.set_seed(instr_config.random_seed)

# Data

## Load Data

In [18]:
# train size 202831
# unique number of instructions 35921
data = dataloader.load_pkl(workdir=instr_config.DATA_PATH)

In [19]:
train_states, train_inventories, train_actions, train_goals, train_instructions, all_instructions = data

In [20]:
# remove invalid sample where train instruction is None
invalid_index = set([i for i, _ in enumerate(train_instructions) if not _])
print(len(invalid_index))

34


In [21]:
valid_index = [i for i, _ in enumerate(train_instructions) if _]
print(len(valid_index))
train_states = np.array(train_states)[valid_index].tolist()
train_inventories = np.array(train_inventories)[valid_index].tolist()
train_actions = np.array(train_actions)[valid_index].tolist()
train_instructions = np.array(train_instructions)[valid_index].tolist()

202797


## Build Vocab

In [22]:
vocab, vocab_weights = dataloader.generate_vocab(
    all_instructions, instr_config.device, workdir=instr_config.DATA_PATH)

# vocab, vocab_weights = dataloader.generate_vocab(
#     all_instructions, instr_config.device, workdir=instr_config.DATA_PATH, cache='/usr/local/google/home/billzhou/Documents/glove')

Total vocabulary size: 212


## Generate Dataset

In [23]:
dataset = dataloader.CraftingDataset(
  instr_config.embeded_dim,
  train_states,
  train_inventories,
  train_actions,
  train_goals,
  train_instructions,
  vocab)

# dataset = dataloader.CraftingDataset(
#   instr_config.embeded_dim,
#   train_states,
#   train_inventories,
#   train_actions,
#   train_goals,
#   train_instructions,
#   vocab,
#   cache='/usr/local/google/home/billzhou/Documents/glove')

embedding loaded
one hot loaded
actions loaded
goals loaded
done loading dataset


In [9]:
min([len(d[-1]) for d in dataset])

3

In [10]:
instr_config.dataset_size = len(dataset)

## Split Dataset

In [24]:
indices = list(range(instr_config.dataset_size))
split = int(np.floor(instr_config.validation_split * instr_config.dataset_size))
np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

AttributeError: 'InstrConfig' object has no attribute 'dataset_size'

## Initialize Data Loader

In [None]:
train_data_loader = DataLoader(
  dataset,
  batch_size=instr_config.batch_size,
  num_workers=0,
  pin_memory=True,
  sampler=train_sampler,
  collate_fn=dataloader.collate_fn)

validation_data_loader = DataLoader(
  dataset,
  batch_size=instr_config.batch_size,
  num_workers=0,
  pin_memory=True,
  sampler=valid_sampler,
  collate_fn=dataloader.collate_fn)

# IL + LSTM Training

## Instructions Generator (LSTM)

### Setup Instructions Generator Training

In [14]:
model = InstructionsGeneratorModel(
    instr_config.device
    , vocab
    , instr_config.embeded_dim
    , vocab_weights
).to(instr_config.device)
train = instructions_generator_trainer.train
validate = instructions_generator_trainer.validate

In [15]:
# CE Loss
criterion = torch.nn.CrossEntropyLoss()
# Adam
parameters = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(parameters, lr=instr_config.learning_rate)
# Log
writer = SummaryWriter() if instr_config.summary_writer else None

### GO

In [15]:
best_valid_loss = float('inf')
valid_epoch, best_epoch = 0, None

for epoch in range(instr_config.epochs):
    # train
    loss, bleu, tk_acc = train(
        instr_config.device,
        epoch,
        train_data_loader,
        model,
        optimizer,
        criterion,
        parameters,
        vocab,
        summary_writer=writer)
    print('Overall Epoch: %d, train loss: %.3f, train bleu: %.3f, train token acc: %.3f' % (epoch, loss, bleu, tk_acc))
    # valid
    loss, bleu, tk_acc = validate(
        instr_config.device,
        epoch,
        validation_data_loader,
        model,
        criterion,
        vocab,
        summary_writer=writer)
    print('Overall Epoch: %d, valid loss: %.3f, valid bleu: %.3f, valid token acc: %.3f' % (epoch, loss, bleu, tk_acc))
    # early stopping
    if loss <= best_valid_loss:
        best_valid_loss = loss
        valid_epoch, best_epoch = 0, epoch
        torch.save(model.state_dict(), instr_config.SAVE_PATH)
        print('Best Epoch: %d, best valid loss: %.3f' % (best_epoch, best_valid_loss))
        print('Trained model saved at ', instr_config.SAVE_PATH)
    else:
        valid_epoch += 1
        if valid_epoch >= instr_config.valid_patience:
            break

Epoch: 0, train loss: 2.605, train bleu: 0.589, train token acc: 0.720: 100%|██████████| 2535/2535 [04:53<00:00,  8.65it/s]


Overall Epoch: 0, train loss: 2.900, train bleu: 0.530, train token acc: 0.684


Epoch: 0, valid loss: 4.509, valid bleu: 0.426, valid token acc: 0.557: 100%|██████████| 634/634 [00:54<00:00, 11.58it/s]


Overall Epoch: 0, valid loss: 4.593, valid bleu: 0.477, valid token acc: 0.592
Best Epoch: 0, best valid loss: 4.593
Trained model saved at  /google/src/cloud/billzhou/ask_your_human/google3/experimental/users/billzhou/Ask-Your-Humans/res/cpts/instr_gen.pt


Epoch: 1, train loss: 2.450, train bleu: 0.590, train token acc: 0.728: 100%|██████████| 2535/2535 [04:55<00:00,  8.58it/s]


Overall Epoch: 1, train loss: 2.378, train bleu: 0.579, train token acc: 0.735


Epoch: 1, valid loss: 4.898, valid bleu: 0.476, valid token acc: 0.605: 100%|██████████| 634/634 [00:53<00:00, 11.76it/s]


Overall Epoch: 1, valid loss: 4.852, valid bleu: 0.487, valid token acc: 0.599


Epoch: 2, train loss: 2.286, train bleu: 0.549, train token acc: 0.720: 100%|██████████| 2535/2535 [04:55<00:00,  8.57it/s]


Overall Epoch: 2, train loss: 2.244, train bleu: 0.599, train token acc: 0.751


Epoch: 2, valid loss: 5.914, valid bleu: 0.448, valid token acc: 0.555: 100%|██████████| 634/634 [00:54<00:00, 11.58it/s]


Overall Epoch: 2, valid loss: 5.099, valid bleu: 0.491, valid token acc: 0.601


Epoch: 3, train loss: 2.045, train bleu: 0.641, train token acc: 0.786: 100%|██████████| 2535/2535 [05:00<00:00,  8.43it/s]


Overall Epoch: 3, train loss: 2.172, train bleu: 0.610, train token acc: 0.759


Epoch: 3, valid loss: 5.674, valid bleu: 0.437, valid token acc: 0.538: 100%|██████████| 634/634 [00:54<00:00, 11.63it/s]


Overall Epoch: 3, valid loss: 5.289, valid bleu: 0.489, valid token acc: 0.598


Epoch: 4, train loss: 2.065, train bleu: 0.635, train token acc: 0.777: 100%|██████████| 2535/2535 [04:58<00:00,  8.50it/s]


Overall Epoch: 4, train loss: 2.128, train bleu: 0.616, train token acc: 0.764


Epoch: 4, valid loss: 4.621, valid bleu: 0.412, valid token acc: 0.588: 100%|██████████| 634/634 [00:53<00:00, 11.93it/s]

Overall Epoch: 4, valid loss: 5.338, valid bleu: 0.491, valid token acc: 0.603





## Imitation Learning + LSTM

### Set up

In [25]:
lstm_model = InstructionsGeneratorModel(
    instr_config.device
    , vocab
    , instr_config.embeded_dim
    , vocab_weights
).to(instr_config.device)
il_model = ImitationLearningModel(instr_config.embeded_dim).to(instr_config.device)

train = imitation_learning_trainer.train
validate = imitation_learning_trainer.validate

In [26]:
# CE Loss
il_criterion = torch.nn.CrossEntropyLoss()
lstm_criterion = torch.nn.CrossEntropyLoss()
# Adam
il_parameters = filter(lambda p: p.requires_grad, il_model.parameters())
il_optimizer = torch.optim.Adam(il_parameters, lr=instr_config.learning_rate)
lstm_parameters = filter(lambda p: p.requires_grad, lstm_model.parameters())
lstm_optimizer = torch.optim.Adam(lstm_parameters, lr=instr_config.learning_rate)
# Log
writer = SummaryWriter() if instr_config.summary_writer else None

### Train IL + LSTM

In [17]:
best_valid_loss = float('inf')
valid_epoch, best_epoch = 0, None

for epoch in range(instr_config.epochs):
    # train
    action_loss, lstm_loss, acc, bleu, tk_acc = train(
        instr_config.device,
        epoch,
        train_data_loader,
        il_model,
        lstm_model,
        il_optimizer,
        lstm_optimizer,
        il_criterion,
        lstm_criterion,
        il_parameters,
        lstm_parameters,
        vocab,
        summary_writer=writer)
    print('Overall Epoch: %d, train action loss: %.3f, train lang loss: %.3f, train action acc: %.3f, train bleu: %.3f, train token acc: %.3f' % (epoch, action_loss, lstm_loss, acc, bleu, tk_acc))
    # valid
    action_loss, lstm_loss, acc, bleu, tk_acc = validate(
        instr_config.device,
        epoch,
        validation_data_loader,
        il_model,
        lstm_model,
        il_criterion,
        lstm_criterion,
        vocab,
        summary_writer=writer)
    print('Overall Epoch: %d, valid action loss: %.3f, valid lang loss: %.3f, valid action acc: %.3f, valid bleu: %.3f, valid token acc: %.3f' % (epoch, action_loss, lstm_loss, acc, bleu, tk_acc))
    # early stopping
    if action_loss <= best_valid_loss:
        best_valid_loss = action_loss
        valid_epoch, best_epoch = 0, epoch
        torch.save(il_model.state_dict(), instr_config.SAVE_PATH)
        print('Best Epoch: %d, best valid loss: %.3f' % (best_epoch, best_valid_loss))
        print('Trained model saved at ', instr_config.SAVE_PATH)
    else:
        valid_epoch += 1
        if valid_epoch >= instr_config.valid_patience:
            break

Epoch: 0, train action loss: 0.861, train lang loss: 2.711, train action acc: 0.694, train bleu: 0.557, train token acc: 0.719: 100%|██████████| 2535/2535 [04:57<00:00,  8.52it/s]


Overall Epoch: 0, train action loss: 0.926, train lang loss: 3.130, train action acc: 0.673, train bleu: 0.512, train token acc: 0.643


Epoch: 0, valid action loss: 1.102, valid lang loss: 3.651, valid acc: 0.596, valid bleu: 0.376, valid token acc: 0.487: 100%|██████████| 634/634 [00:55<00:00, 11.42it/s]


Overall Epoch: 0, valid action loss: 0.952, valid lang loss: 3.281, valid action acc: 0.667, valid bleu: 0.493, valid token acc: 0.610
Best Epoch: 0, best valid loss: 0.952
Trained model saved at  /google/src/cloud/billzhou/ask_your_human/google3/experimental/users/billzhou/Ask-Your-Humans/res/cpts/instr_gen.pt


Epoch: 1, train action loss: 0.764, train lang loss: 2.608, train action acc: 0.758, train bleu: 0.593, train token acc: 0.734: 100%|██████████| 2535/2535 [05:00<00:00,  8.44it/s]


Overall Epoch: 1, train action loss: 0.881, train lang loss: 3.058, train action acc: 0.691, train bleu: 0.519, train token acc: 0.651


Epoch: 1, valid action loss: 0.985, valid lang loss: 3.149, valid acc: 0.638, valid bleu: 0.492, valid token acc: 0.598: 100%|██████████| 634/634 [00:55<00:00, 11.47it/s]


Overall Epoch: 1, valid action loss: 0.919, valid lang loss: 3.230, valid action acc: 0.685, valid bleu: 0.499, valid token acc: 0.612
Best Epoch: 1, best valid loss: 0.919
Trained model saved at  /google/src/cloud/billzhou/ask_your_human/google3/experimental/users/billzhou/Ask-Your-Humans/res/cpts/instr_gen.pt


Epoch: 2, train action loss: 0.747, train lang loss: 3.458, train action acc: 0.726, train bleu: 0.524, train token acc: 0.622: 100%|██████████| 2535/2535 [05:00<00:00,  8.44it/s]


Overall Epoch: 2, train action loss: 0.850, train lang loss: 3.011, train action acc: 0.705, train bleu: 0.525, train token acc: 0.657


Epoch: 2, valid action loss: 0.916, valid lang loss: 2.903, valid acc: 0.766, valid bleu: 0.549, valid token acc: 0.665: 100%|██████████| 634/634 [00:53<00:00, 11.83it/s]


Overall Epoch: 2, valid action loss: 0.902, valid lang loss: 3.236, valid action acc: 0.688, valid bleu: 0.506, valid token acc: 0.607
Best Epoch: 2, best valid loss: 0.902
Trained model saved at  /google/src/cloud/billzhou/ask_your_human/google3/experimental/users/billzhou/Ask-Your-Humans/res/cpts/instr_gen.pt


Epoch: 3, train action loss: 1.041, train lang loss: 2.472, train action acc: 0.661, train bleu: 0.591, train token acc: 0.749: 100%|██████████| 2535/2535 [05:02<00:00,  8.39it/s]


Overall Epoch: 3, train action loss: 0.829, train lang loss: 3.015, train action acc: 0.714, train bleu: 0.525, train token acc: 0.655


Epoch: 3, valid action loss: 0.685, valid lang loss: 3.435, valid acc: 0.787, valid bleu: 0.513, valid token acc: 0.603: 100%|██████████| 634/634 [00:54<00:00, 11.70it/s]


Overall Epoch: 3, valid action loss: 0.900, valid lang loss: 3.214, valid action acc: 0.695, valid bleu: 0.503, valid token acc: 0.611
Best Epoch: 3, best valid loss: 0.900
Trained model saved at  /google/src/cloud/billzhou/ask_your_human/google3/experimental/users/billzhou/Ask-Your-Humans/res/cpts/instr_gen.pt


Epoch: 4, train action loss: 0.868, train lang loss: 3.535, train action acc: 0.694, train bleu: 0.493, train token acc: 0.601: 100%|██████████| 2535/2535 [05:01<00:00,  8.40it/s]


Overall Epoch: 4, train action loss: 0.810, train lang loss: 2.987, train action acc: 0.722, train bleu: 0.527, train token acc: 0.658


Epoch: 4, valid action loss: 0.559, valid lang loss: 3.103, valid acc: 0.851, valid bleu: 0.476, valid token acc: 0.588: 100%|██████████| 634/634 [00:55<00:00, 11.40it/s]


Overall Epoch: 4, valid action loss: 0.886, valid lang loss: 3.178, valid action acc: 0.697, valid bleu: 0.501, valid token acc: 0.614
Best Epoch: 4, best valid loss: 0.886
Trained model saved at  /google/src/cloud/billzhou/ask_your_human/google3/experimental/users/billzhou/Ask-Your-Humans/res/cpts/instr_gen.pt


Epoch: 5, train action loss: 0.752, train lang loss: 2.626, train action acc: 0.742, train bleu: 0.535, train token acc: 0.695: 100%|██████████| 2535/2535 [05:00<00:00,  8.45it/s]


Overall Epoch: 5, train action loss: 0.793, train lang loss: 2.985, train action acc: 0.730, train bleu: 0.528, train token acc: 0.658


Epoch: 5, valid action loss: 0.959, valid lang loss: 3.198, valid acc: 0.596, valid bleu: 0.431, valid token acc: 0.555: 100%|██████████| 634/634 [00:54<00:00, 11.57it/s]


Overall Epoch: 5, valid action loss: 0.875, valid lang loss: 3.176, valid action acc: 0.701, valid bleu: 0.508, valid token acc: 0.614
Best Epoch: 5, best valid loss: 0.875
Trained model saved at  /google/src/cloud/billzhou/ask_your_human/google3/experimental/users/billzhou/Ask-Your-Humans/res/cpts/instr_gen.pt


Epoch: 6, train action loss: 0.673, train lang loss: 3.351, train action acc: 0.758, train bleu: 0.475, train token acc: 0.595: 100%|██████████| 2535/2535 [04:59<00:00,  8.46it/s]


Overall Epoch: 6, train action loss: 0.781, train lang loss: 2.970, train action acc: 0.734, train bleu: 0.529, train token acc: 0.659


Epoch: 6, valid action loss: 0.682, valid lang loss: 2.981, valid acc: 0.809, valid bleu: 0.550, valid token acc: 0.633: 100%|██████████| 634/634 [00:54<00:00, 11.56it/s]


Overall Epoch: 6, valid action loss: 0.880, valid lang loss: 3.172, valid action acc: 0.702, valid bleu: 0.504, valid token acc: 0.613


Epoch: 7, train action loss: 0.757, train lang loss: 3.303, train action acc: 0.742, train bleu: 0.481, train token acc: 0.586: 100%|██████████| 2535/2535 [05:02<00:00,  8.39it/s]


Overall Epoch: 7, train action loss: 0.767, train lang loss: 2.944, train action acc: 0.739, train bleu: 0.532, train token acc: 0.663


Epoch: 7, valid action loss: 0.737, valid lang loss: 2.970, valid acc: 0.787, valid bleu: 0.537, valid token acc: 0.644: 100%|██████████| 634/634 [00:54<00:00, 11.60it/s]


Overall Epoch: 7, valid action loss: 0.879, valid lang loss: 3.165, valid action acc: 0.698, valid bleu: 0.503, valid token acc: 0.613


Epoch: 8, train action loss: 0.672, train lang loss: 2.447, train action acc: 0.742, train bleu: 0.589, train token acc: 0.737: 100%|██████████| 2535/2535 [05:01<00:00,  8.42it/s]


Overall Epoch: 8, train action loss: 0.755, train lang loss: 2.926, train action acc: 0.743, train bleu: 0.536, train token acc: 0.666


Epoch: 8, valid action loss: 0.910, valid lang loss: 3.054, valid acc: 0.723, valid bleu: 0.484, valid token acc: 0.590: 100%|██████████| 634/634 [00:53<00:00, 11.75it/s]


Overall Epoch: 8, valid action loss: 0.880, valid lang loss: 3.168, valid action acc: 0.702, valid bleu: 0.505, valid token acc: 0.614


Epoch: 9, train action loss: 0.669, train lang loss: 3.432, train action acc: 0.742, train bleu: 0.495, train token acc: 0.594: 100%|██████████| 2535/2535 [05:04<00:00,  8.32it/s]


Overall Epoch: 9, train action loss: 0.748, train lang loss: 2.913, train action acc: 0.746, train bleu: 0.537, train token acc: 0.667


Epoch: 9, valid action loss: 1.179, valid lang loss: 2.890, valid acc: 0.617, valid bleu: 0.453, valid token acc: 0.591: 100%|██████████| 634/634 [00:55<00:00, 11.37it/s]


Overall Epoch: 9, valid action loss: 0.868, valid lang loss: 3.154, valid action acc: 0.709, valid bleu: 0.502, valid token acc: 0.614
Best Epoch: 9, best valid loss: 0.868
Trained model saved at  /google/src/cloud/billzhou/ask_your_human/google3/experimental/users/billzhou/Ask-Your-Humans/res/cpts/instr_gen.pt


Epoch: 10, train action loss: 0.890, train lang loss: 3.528, train action acc: 0.677, train bleu: 0.501, train token acc: 0.591: 100%|██████████| 2535/2535 [05:04<00:00,  8.31it/s]


Overall Epoch: 10, train action loss: 0.740, train lang loss: 2.915, train action acc: 0.750, train bleu: 0.537, train token acc: 0.667


Epoch: 10, valid action loss: 0.850, valid lang loss: 3.040, valid acc: 0.638, valid bleu: 0.564, valid token acc: 0.649: 100%|██████████| 634/634 [00:55<00:00, 11.42it/s]


Overall Epoch: 10, valid action loss: 0.870, valid lang loss: 3.173, valid action acc: 0.704, valid bleu: 0.508, valid token acc: 0.613


Epoch: 11, train action loss: 0.626, train lang loss: 2.233, train action acc: 0.726, train bleu: 0.590, train token acc: 0.735: 100%|██████████| 2535/2535 [05:03<00:00,  8.34it/s]


Overall Epoch: 11, train action loss: 0.735, train lang loss: 2.917, train action acc: 0.751, train bleu: 0.537, train token acc: 0.666


Epoch: 11, valid action loss: 0.768, valid lang loss: 3.191, valid acc: 0.660, valid bleu: 0.484, valid token acc: 0.598: 100%|██████████| 634/634 [00:55<00:00, 11.50it/s]


Overall Epoch: 11, valid action loss: 0.859, valid lang loss: 3.167, valid action acc: 0.710, valid bleu: 0.503, valid token acc: 0.615
Best Epoch: 11, best valid loss: 0.859
Trained model saved at  /google/src/cloud/billzhou/ask_your_human/google3/experimental/users/billzhou/Ask-Your-Humans/res/cpts/instr_gen.pt


Epoch: 12, train action loss: 0.946, train lang loss: 3.423, train action acc: 0.694, train bleu: 0.442, train token acc: 0.572: 100%|██████████| 2535/2535 [05:04<00:00,  8.31it/s]


Overall Epoch: 12, train action loss: 0.728, train lang loss: 2.905, train action acc: 0.753, train bleu: 0.538, train token acc: 0.667


Epoch: 12, valid action loss: 0.763, valid lang loss: 3.129, valid acc: 0.766, valid bleu: 0.467, valid token acc: 0.584: 100%|██████████| 634/634 [00:55<00:00, 11.52it/s]


Overall Epoch: 12, valid action loss: 0.858, valid lang loss: 3.141, valid action acc: 0.709, valid bleu: 0.506, valid token acc: 0.613
Best Epoch: 12, best valid loss: 0.858
Trained model saved at  /google/src/cloud/billzhou/ask_your_human/google3/experimental/users/billzhou/Ask-Your-Humans/res/cpts/instr_gen.pt


Epoch: 13, train action loss: 0.843, train lang loss: 3.761, train action acc: 0.694, train bleu: 0.556, train token acc: 0.639: 100%|██████████| 2535/2535 [05:04<00:00,  8.32it/s]


Overall Epoch: 13, train action loss: 0.722, train lang loss: 2.895, train action acc: 0.755, train bleu: 0.538, train token acc: 0.668


Epoch: 13, valid action loss: 0.844, valid lang loss: 3.198, valid acc: 0.681, valid bleu: 0.493, valid token acc: 0.611: 100%|██████████| 634/634 [00:54<00:00, 11.69it/s]


Overall Epoch: 13, valid action loss: 0.860, valid lang loss: 3.137, valid action acc: 0.705, valid bleu: 0.501, valid token acc: 0.614


Epoch: 14, train action loss: 0.629, train lang loss: 3.634, train action acc: 0.774, train bleu: 0.443, train token acc: 0.555: 100%|██████████| 2535/2535 [05:04<00:00,  8.32it/s]


Overall Epoch: 14, train action loss: 0.717, train lang loss: 2.901, train action acc: 0.757, train bleu: 0.538, train token acc: 0.667


Epoch: 14, valid action loss: 0.953, valid lang loss: 2.952, valid acc: 0.681, valid bleu: 0.486, valid token acc: 0.607: 100%|██████████| 634/634 [00:59<00:00, 10.57it/s]


Overall Epoch: 14, valid action loss: 0.856, valid lang loss: 3.149, valid action acc: 0.711, valid bleu: 0.504, valid token acc: 0.615
Best Epoch: 14, best valid loss: 0.856
Trained model saved at  /google/src/cloud/billzhou/ask_your_human/google3/experimental/users/billzhou/Ask-Your-Humans/res/cpts/instr_gen.pt


Epoch: 15, train action loss: 0.926, train lang loss: 3.564, train action acc: 0.758, train bleu: 0.509, train token acc: 0.606: 100%|██████████| 2535/2535 [05:16<00:00,  8.00it/s]


Overall Epoch: 15, train action loss: 0.713, train lang loss: 2.878, train action acc: 0.758, train bleu: 0.541, train token acc: 0.671


Epoch: 15, valid action loss: 0.728, valid lang loss: 3.125, valid acc: 0.745, valid bleu: 0.522, valid token acc: 0.629: 100%|██████████| 634/634 [00:56<00:00, 11.31it/s]


Overall Epoch: 15, valid action loss: 0.855, valid lang loss: 3.139, valid action acc: 0.711, valid bleu: 0.504, valid token acc: 0.616
Best Epoch: 15, best valid loss: 0.855
Trained model saved at  /google/src/cloud/billzhou/ask_your_human/google3/experimental/users/billzhou/Ask-Your-Humans/res/cpts/instr_gen.pt


Epoch: 16, train action loss: 0.588, train lang loss: 2.247, train action acc: 0.758, train bleu: 0.651, train token acc: 0.778: 100%|██████████| 2535/2535 [05:04<00:00,  8.33it/s]


Overall Epoch: 16, train action loss: 0.708, train lang loss: 2.877, train action acc: 0.759, train bleu: 0.541, train token acc: 0.670


Epoch: 16, valid action loss: 0.928, valid lang loss: 3.196, valid acc: 0.745, valid bleu: 0.507, valid token acc: 0.601: 100%|██████████| 634/634 [00:55<00:00, 11.49it/s]


Overall Epoch: 16, valid action loss: 0.860, valid lang loss: 3.146, valid action acc: 0.708, valid bleu: 0.505, valid token acc: 0.615


Epoch: 17, train action loss: 0.599, train lang loss: 2.418, train action acc: 0.806, train bleu: 0.546, train token acc: 0.708: 100%|██████████| 2535/2535 [05:02<00:00,  8.38it/s]


Overall Epoch: 17, train action loss: 0.706, train lang loss: 2.889, train action acc: 0.761, train bleu: 0.540, train token acc: 0.668


Epoch: 17, valid action loss: 0.819, valid lang loss: 3.318, valid acc: 0.702, valid bleu: 0.503, valid token acc: 0.603: 100%|██████████| 634/634 [00:55<00:00, 11.49it/s]


Overall Epoch: 17, valid action loss: 0.859, valid lang loss: 3.123, valid action acc: 0.710, valid bleu: 0.506, valid token acc: 0.615


Epoch: 18, train action loss: 0.499, train lang loss: 2.265, train action acc: 0.855, train bleu: 0.656, train token acc: 0.788: 100%|██████████| 2535/2535 [05:03<00:00,  8.36it/s]


Overall Epoch: 18, train action loss: 0.702, train lang loss: 2.869, train action acc: 0.762, train bleu: 0.542, train token acc: 0.671


Epoch: 18, valid action loss: 0.545, valid lang loss: 3.253, valid acc: 0.851, valid bleu: 0.545, valid token acc: 0.651: 100%|██████████| 634/634 [00:54<00:00, 11.53it/s]


Overall Epoch: 18, valid action loss: 0.855, valid lang loss: 3.139, valid action acc: 0.708, valid bleu: 0.505, valid token acc: 0.615
Best Epoch: 18, best valid loss: 0.855
Trained model saved at  /google/src/cloud/billzhou/ask_your_human/google3/experimental/users/billzhou/Ask-Your-Humans/res/cpts/instr_gen.pt


Epoch: 19, train action loss: 1.017, train lang loss: 3.273, train action acc: 0.597, train bleu: 0.460, train token acc: 0.582: 100%|██████████| 2535/2535 [05:00<00:00,  8.43it/s]


Overall Epoch: 19, train action loss: 0.696, train lang loss: 2.858, train action acc: 0.763, train bleu: 0.544, train token acc: 0.673


Epoch: 19, valid action loss: 1.007, valid lang loss: 3.673, valid acc: 0.574, valid bleu: 0.450, valid token acc: 0.553: 100%|██████████| 634/634 [00:55<00:00, 11.50it/s]

Overall Epoch: 19, valid action loss: 0.859, valid lang loss: 3.157, valid action acc: 0.711, valid bleu: 0.501, valid token acc: 0.614





### Train IL Only

In [None]:
best_valid_loss = float('inf')
valid_epoch, best_epoch = 0, None

for epoch in range(instr_config.epochs):
    # train
    action_loss, lstm_loss, acc, bleu, tk_acc = train(
        instr_config.device,
        epoch,
        train_data_loader,
        il_model,
        lstm_model,
        il_optimizer,
        lstm_optimizer,
        il_criterion,
        lstm_criterion,
        il_parameters,
        lstm_parameters,
        vocab,
        use_generative_language=False,
        summary_writer=writer)
    print('Overall Epoch: %d, train action loss: %.3f, train action acc: %.3f' % (epoch, action_loss, acc))
    # valid
    action_loss, lstm_loss, acc, bleu, tk_acc = validate(
        instr_config.device,
        epoch,
        validation_data_loader,
        il_model,
        lstm_model,
        il_criterion,
        lstm_criterion,
        vocab,
        use_generative_language=False,
        summary_writer=writer)
    print('Overall Epoch: %d, valid action loss: %.3f, valid action acc: %.3f' % (epoch, action_loss, acc))
    # early stopping
    if action_loss <= best_valid_loss:
        best_valid_loss = action_loss
        valid_epoch, best_epoch = 0, epoch
        torch.save(il_model.state_dict(), instr_config.SAVE_PATH)
        print('Best Epoch: %d, best valid loss: %.3f' % (best_epoch, best_valid_loss))
        print('Trained model saved at ', instr_config.SAVE_PATH)
    else:
        valid_epoch += 1
        if valid_epoch >= instr_config.valid_patience:
            break

Epoch: 0, train action loss: 1.263 train action acc: 0.581: 100%|██████████| 2535/2535 [00:29<00:00, 86.40it/s]


Overall Epoch: 0, train action loss: 1.278, train action acc: 0.523


Epoch: 0, valid action loss: 1.414 valid action acc: 0.532: 100%|██████████| 634/634 [00:05<00:00, 116.60it/s]


Overall Epoch: 0, valid action loss: 1.225, valid action acc: 0.551
Best Epoch: 0, best valid loss: 1.225
Trained model saved at  /google/src/cloud/billzhou/ask_your_human/google3/experimental/users/billzhou/Ask-Your-Humans/res/cpts/instr_gen.pt


Epoch: 1, train action loss: 1.300 train action acc: 0.516: 100%|██████████| 2535/2535 [00:29<00:00, 86.76it/s]


Overall Epoch: 1, train action loss: 1.186, train action acc: 0.565


Epoch: 1, valid action loss: 1.126 valid action acc: 0.681: 100%|██████████| 634/634 [00:05<00:00, 113.59it/s]


Overall Epoch: 1, valid action loss: 1.183, valid action acc: 0.568
Best Epoch: 1, best valid loss: 1.183
Trained model saved at  /google/src/cloud/billzhou/ask_your_human/google3/experimental/users/billzhou/Ask-Your-Humans/res/cpts/instr_gen.pt


Epoch: 2, train action loss: 1.101 train action acc: 0.548: 100%|██████████| 2535/2535 [00:29<00:00, 86.72it/s]


Overall Epoch: 2, train action loss: 1.144, train action acc: 0.584


Epoch: 2, valid action loss: 0.903 valid action acc: 0.638: 100%|██████████| 634/634 [00:05<00:00, 118.66it/s]


Overall Epoch: 2, valid action loss: 1.156, valid action acc: 0.578
Best Epoch: 2, best valid loss: 1.156
Trained model saved at  /google/src/cloud/billzhou/ask_your_human/google3/experimental/users/billzhou/Ask-Your-Humans/res/cpts/instr_gen.pt


Epoch: 3, train action loss: 1.351 train action acc: 0.565: 100%|██████████| 2535/2535 [00:28<00:00, 89.73it/s]


Overall Epoch: 3, train action loss: 1.118, train action acc: 0.596


Epoch: 3, valid action loss: 1.194 valid action acc: 0.617: 100%|██████████| 634/634 [00:05<00:00, 123.51it/s]


Overall Epoch: 3, valid action loss: 1.137, valid action acc: 0.590
Best Epoch: 3, best valid loss: 1.137
Trained model saved at  /google/src/cloud/billzhou/ask_your_human/google3/experimental/users/billzhou/Ask-Your-Humans/res/cpts/instr_gen.pt


Epoch: 4, train action loss: 1.000 train action acc: 0.677: 100%|██████████| 2535/2535 [00:27<00:00, 90.94it/s]


Overall Epoch: 4, train action loss: 1.102, train action acc: 0.602


Epoch: 4, valid action loss: 0.859 valid action acc: 0.596: 100%|██████████| 634/634 [00:05<00:00, 122.73it/s]


Overall Epoch: 4, valid action loss: 1.132, valid action acc: 0.591
Best Epoch: 4, best valid loss: 1.132
Trained model saved at  /google/src/cloud/billzhou/ask_your_human/google3/experimental/users/billzhou/Ask-Your-Humans/res/cpts/instr_gen.pt


Epoch: 5, train action loss: 1.422 train action acc: 0.565: 100%|██████████| 2535/2535 [00:29<00:00, 85.09it/s] 


Overall Epoch: 5, train action loss: 1.089, train action acc: 0.608


Epoch: 5, valid action loss: 1.112 valid action acc: 0.574: 100%|██████████| 634/634 [00:05<00:00, 122.75it/s]


Overall Epoch: 5, valid action loss: 1.118, valid action acc: 0.598
Best Epoch: 5, best valid loss: 1.118
Trained model saved at  /google/src/cloud/billzhou/ask_your_human/google3/experimental/users/billzhou/Ask-Your-Humans/res/cpts/instr_gen.pt


Epoch: 6, train action loss: 1.004 train action acc: 0.613: 100%|██████████| 2535/2535 [00:27<00:00, 90.65it/s]


Overall Epoch: 6, train action loss: 1.079, train action acc: 0.612


Epoch: 6, valid action loss: 0.836 valid action acc: 0.681: 100%|██████████| 634/634 [00:05<00:00, 122.23it/s]


Overall Epoch: 6, valid action loss: 1.121, valid action acc: 0.599


Epoch: 7, train action loss: 1.150 train action acc: 0.565: 100%|██████████| 2535/2535 [00:28<00:00, 90.40it/s] 


Overall Epoch: 7, train action loss: 1.071, train action acc: 0.616


Epoch: 7, valid action loss: 1.429 valid action acc: 0.532: 100%|██████████| 634/634 [00:05<00:00, 107.09it/s]


Overall Epoch: 7, valid action loss: 1.116, valid action acc: 0.603
Best Epoch: 7, best valid loss: 1.116
Trained model saved at  /google/src/cloud/billzhou/ask_your_human/google3/experimental/users/billzhou/Ask-Your-Humans/res/cpts/instr_gen.pt


Epoch: 8, train action loss: 0.851 train action acc: 0.661: 100%|██████████| 2535/2535 [00:28<00:00, 87.93it/s]


Overall Epoch: 8, train action loss: 1.062, train action acc: 0.620


Epoch: 8, valid action loss: 1.262 valid action acc: 0.681: 100%|██████████| 634/634 [00:05<00:00, 123.39it/s]


Overall Epoch: 8, valid action loss: 1.106, valid action acc: 0.605
Best Epoch: 8, best valid loss: 1.106
Trained model saved at  /google/src/cloud/billzhou/ask_your_human/google3/experimental/users/billzhou/Ask-Your-Humans/res/cpts/instr_gen.pt


Epoch: 9, train action loss: 1.099 train action acc: 0.613: 100%|██████████| 2535/2535 [00:27<00:00, 91.27it/s]


Overall Epoch: 9, train action loss: 1.047, train action acc: 0.627


Epoch: 9, valid action loss: 1.296 valid action acc: 0.447: 100%|██████████| 634/634 [00:05<00:00, 124.12it/s]


Overall Epoch: 9, valid action loss: 1.089, valid action acc: 0.614
Best Epoch: 9, best valid loss: 1.089
Trained model saved at  /google/src/cloud/billzhou/ask_your_human/google3/experimental/users/billzhou/Ask-Your-Humans/res/cpts/instr_gen.pt


Epoch: 10, train action loss: 1.127 train action acc: 0.597: 100%|██████████| 2535/2535 [00:27<00:00, 91.22it/s]


Overall Epoch: 10, train action loss: 1.033, train action acc: 0.633


Epoch: 10, valid action loss: 1.116 valid action acc: 0.617: 100%|██████████| 634/634 [00:05<00:00, 120.42it/s]


Overall Epoch: 10, valid action loss: 1.083, valid action acc: 0.617
Best Epoch: 10, best valid loss: 1.083
Trained model saved at  /google/src/cloud/billzhou/ask_your_human/google3/experimental/users/billzhou/Ask-Your-Humans/res/cpts/instr_gen.pt


Epoch: 11, train action loss: 1.019 train action acc: 0.645: 100%|██████████| 2535/2535 [00:27<00:00, 91.23it/s]


Overall Epoch: 11, train action loss: 1.022, train action acc: 0.639


Epoch: 11, valid action loss: 1.168 valid action acc: 0.617: 100%|██████████| 634/634 [00:05<00:00, 122.25it/s]


Overall Epoch: 11, valid action loss: 1.077, valid action acc: 0.623
Best Epoch: 11, best valid loss: 1.077
Trained model saved at  /google/src/cloud/billzhou/ask_your_human/google3/experimental/users/billzhou/Ask-Your-Humans/res/cpts/instr_gen.pt


Epoch: 12, train action loss: 0.979 train action acc: 0.629: 100%|██████████| 2535/2535 [00:32<00:00, 77.35it/s]


Overall Epoch: 12, train action loss: 1.015, train action acc: 0.641


Epoch: 12, valid action loss: 0.906 valid action acc: 0.660: 100%|██████████| 634/634 [00:06<00:00, 97.59it/s] 


Overall Epoch: 12, valid action loss: 1.081, valid action acc: 0.620


Epoch: 13, train action loss: 0.923 train action acc: 0.677: 100%|██████████| 2535/2535 [00:35<00:00, 70.94it/s]


Overall Epoch: 13, train action loss: 1.009, train action acc: 0.643


Epoch: 13, valid action loss: 1.195 valid action acc: 0.617: 100%|██████████| 634/634 [00:06<00:00, 91.38it/s] 


Overall Epoch: 13, valid action loss: 1.077, valid action acc: 0.620


Epoch: 14, train action loss: 1.015 train action acc: 0.688:  51%|█████     | 1284/2535 [00:17<00:17, 69.56it/s]