# CS7643 - Final Project

In [1]:
# built-in
# public
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
# private
from config import Config, InstrConfig
from src.utils import dataloader, helpers
from src.models.instructions_generator_model import InstructionsGeneratorModel
from src.trainer import instructions_generator_trainer

#%load_ext autoreload 
#%autoreload 2
#%config Completer.use_jedi = False

# Initialization

In [2]:
instr_config = InstrConfig()
helpers.set_seed(instr_config.random_seed)

## Instructions Generator

## Data

### Load Data

In [3]:
# train size 202831
# unique number of instructions 35921
data = dataloader.load_pkl(workdir=instr_config.DATA_PATH)

In [4]:
train_states, train_inventories, train_actions, train_goals, train_instructions, all_instructions = data

In [5]:
# remove invalid sample where train instruction is None
invalid_index = set([i for i, _ in enumerate(train_instructions) if not _])
print(len(invalid_index))

34


In [6]:
valid_index = [i for i, _ in enumerate(train_instructions) if _]
print(len(valid_index))
train_states = np.array(train_states)[valid_index].tolist()
train_inventories = np.array(train_inventories)[valid_index].tolist()
train_actions = np.array(train_actions)[valid_index].tolist()
train_instructions = np.array(train_instructions)[valid_index].tolist()

202797


### Build Vocab

In [7]:
vocab, vocab_weights = dataloader.generate_vocab(
    all_instructions, instr_config.device, workdir=instr_config.DATA_PATH, cache='/usr/local/google/home/billzhou/Documents/glove')

Total vocabulary size: 212


### Generate Dataset

In [8]:
dataset = dataloader.CraftingDataset(
  instr_config.embeded_dim,
  train_states,
  train_inventories,
  train_actions,
  train_goals,
  train_instructions,
  vocab,
  cache='/usr/local/google/home/billzhou/Documents/glove')

embedding loaded
one hot loaded
actions loaded
goals loaded
done loading dataset


In [9]:
min([len(d[-1]) for d in dataset])

3

In [10]:
instr_config.dataset_size = len(dataset)

### Split Dataset

In [11]:
indices = list(range(instr_config.dataset_size))
split = int(np.floor(instr_config.validation_split * instr_config.dataset_size))
np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

### Initialize Data Loader

In [12]:
train_data_loader = DataLoader(
  dataset,
  batch_size=instr_config.batch_size,
  num_workers=0,
  pin_memory=True,
  sampler=train_sampler,
  collate_fn=dataloader.collate_fn)

validation_data_loader = DataLoader(
  dataset,
  batch_size=instr_config.batch_size,
  num_workers=0,
  pin_memory=True,
  sampler=valid_sampler,
  collate_fn=dataloader.collate_fn)

### Setup Training

In [13]:
model = InstructionsGeneratorModel(
    instr_config.device
    , vocab
    , instr_config.embeded_dim
    , vocab_weights
).to(instr_config.device)
train = instructions_generator_trainer.train
validate = instructions_generator_trainer.validate

In [14]:
# CE Loss
criterion = torch.nn.CrossEntropyLoss()
# Adam
parameters = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(parameters, lr=instr_config.learning_rate)
# Log
writer = SummaryWriter() if instr_config.summary_writer else None

### GO

In [15]:
best_valid_loss = float('inf')
valid_epoch, best_epoch = 0, None

for epoch in range(instr_config.epochs):
    # train
    loss, bleu, tk_acc = train(
        instr_config.device,
        epoch,
        train_data_loader,
        model,
        optimizer,
        criterion,
        parameters,
        vocab,
        summary_writer=writer)
    print('Overall Epoch: %d, train loss: %.3f, train bleu: %.3f, train token acc: %.3f' % (epoch, loss, bleu, tk_acc))
    # valid
    loss, bleu, tk_acc = validate(
        instr_config.device,
        epoch,
        validation_data_loader,
        model,
        criterion,
        vocab,
        summary_writer=writer)
    print('Overall Epoch: %d, valid loss: %.3f, valid bleu: %.3f, valid token acc: %.3f' % (epoch, loss, bleu, tk_acc))
    # early stopping
    if loss <= best_valid_loss:
        best_valid_loss = loss
        valid_epoch, best_epoch = 0, epoch
        torch.save(model.state_dict(), instr_config.SAVE_PATH)
        print('Best Epoch: %d, best valid loss: %.3f' % (best_epoch, best_valid_loss))
        print('Trained model saved at ', instr_config.SAVE_PATH)
    else:
        valid_epoch += 1
        if valid_epoch >= instr_config.valid_patience:
            break

Epoch: 0, train loss: 2.605, train bleu: 0.589, train token acc: 0.720: 100%|██████████| 2535/2535 [04:53<00:00,  8.65it/s]


Overall Epoch: 0, train loss: 2.900, train bleu: 0.530, train token acc: 0.684


Epoch: 0, valid loss: 4.509, valid bleu: 0.426, valid token acc: 0.557: 100%|██████████| 634/634 [00:54<00:00, 11.58it/s]


Overall Epoch: 0, valid loss: 4.593, valid bleu: 0.477, valid token acc: 0.592
Best Epoch: 0, best valid loss: 4.593
Trained model saved at  /google/src/cloud/billzhou/ask_your_human/google3/experimental/users/billzhou/Ask-Your-Humans/res/cpts/instr_gen.pt


Epoch: 1, train loss: 2.450, train bleu: 0.590, train token acc: 0.728: 100%|██████████| 2535/2535 [04:55<00:00,  8.58it/s]


Overall Epoch: 1, train loss: 2.378, train bleu: 0.579, train token acc: 0.735


Epoch: 1, valid loss: 4.898, valid bleu: 0.476, valid token acc: 0.605: 100%|██████████| 634/634 [00:53<00:00, 11.76it/s]


Overall Epoch: 1, valid loss: 4.852, valid bleu: 0.487, valid token acc: 0.599


Epoch: 2, train loss: 2.286, train bleu: 0.549, train token acc: 0.720: 100%|██████████| 2535/2535 [04:55<00:00,  8.57it/s]


Overall Epoch: 2, train loss: 2.244, train bleu: 0.599, train token acc: 0.751


Epoch: 2, valid loss: 5.914, valid bleu: 0.448, valid token acc: 0.555: 100%|██████████| 634/634 [00:54<00:00, 11.58it/s]


Overall Epoch: 2, valid loss: 5.099, valid bleu: 0.491, valid token acc: 0.601


Epoch: 3, train loss: 2.045, train bleu: 0.641, train token acc: 0.786: 100%|██████████| 2535/2535 [05:00<00:00,  8.43it/s]


Overall Epoch: 3, train loss: 2.172, train bleu: 0.610, train token acc: 0.759


Epoch: 3, valid loss: 5.674, valid bleu: 0.437, valid token acc: 0.538: 100%|██████████| 634/634 [00:54<00:00, 11.63it/s]


Overall Epoch: 3, valid loss: 5.289, valid bleu: 0.489, valid token acc: 0.598


Epoch: 4, train loss: 2.065, train bleu: 0.635, train token acc: 0.777: 100%|██████████| 2535/2535 [04:58<00:00,  8.50it/s]


Overall Epoch: 4, train loss: 2.128, train bleu: 0.616, train token acc: 0.764


Epoch: 4, valid loss: 4.621, valid bleu: 0.412, valid token acc: 0.588: 100%|██████████| 634/634 [00:53<00:00, 11.93it/s]

Overall Epoch: 4, valid loss: 5.338, valid bleu: 0.491, valid token acc: 0.603



