# CS7643 - Final Project

In [None]:
# built-in
# public
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
# private
from config import Config, InstrConfig
from src.utils import dataloader, helpers
from src.models.instructions_generator_model import InstructionsGeneratorModel
from src.trainer import instructions_generator_trainer

%load_ext autoreload 
%autoreload 2
%config Completer.use_jedi = False

# Initialization

In [None]:
# Instantiate the instructions-generator configuration; every later cell reads
# its settings (paths, device, batch size, ...) from this object.
instr_config = InstrConfig()
# Seed from the configured value before any data shuffling happens.
# NOTE(review): presumably seeds numpy/torch/random — confirm in src.utils.helpers.
helpers.set_seed(instr_config.random_seed)

## Instructions Generator

## Data

### Load Data

In [None]:
# Load the training data from the configured data directory.
# Sizes noted by the author for reference:
#   train size 202831
#   unique number of instructions 35921
data = dataloader.load_pkl(workdir=instr_config.DATA_PATH)

In [None]:
# Unpack the six components returned by load_pkl; the order here must match
# the order in which load_pkl returns them.
train_states, train_inventories, train_actions, train_goals, train_instructions, all_instructions = data

In [None]:
# Count the samples that cannot be used for training: any position whose
# instruction is falsy (None or empty). Nothing is removed in this cell.
invalid_index = {
    position
    for position, instruction in enumerate(train_instructions)
    if not instruction
}
print(len(invalid_index))

In [None]:
# Keep only the samples that have a non-empty instruction. Every per-sample
# list has to be filtered with the same indices so that position i keeps
# referring to the same sample in all of them.
valid_index = [i for i, instruction in enumerate(train_instructions) if instruction]
print(len(valid_index))

# NOTE(review): assumes train_goals holds one entry per sample, like the other
# train_* lists — confirm against load_pkl. If this fails after a partial
# re-run, re-execute the unpacking cell first to restore consistent lists.
assert len(train_goals) == len(train_instructions), (
    'train_goals (%d) and train_instructions (%d) differ in length'
    % (len(train_goals), len(train_instructions)))

train_states = np.array(train_states)[valid_index].tolist()
train_inventories = np.array(train_inventories)[valid_index].tolist()
train_actions = np.array(train_actions)[valid_index].tolist()
# train_goals was previously left unfiltered, which shifts it out of step with
# the other lists as soon as a single sample is dropped.
train_goals = np.array(train_goals)[valid_index].tolist()
# Filter the instructions last: valid_index was derived from this list.
train_instructions = np.array(train_instructions)[valid_index].tolist()

### Build Vocab

In [None]:
# Build the vocabulary and its associated weights from the full instruction
# list. `vocab_weights` is later handed to the model constructor.
# NOTE(review): presumably pretrained embedding weights placed on
# instr_config.device — confirm in dataloader.generate_vocab.
vocab, vocab_weights = dataloader.generate_vocab(
    all_instructions, instr_config.device, workdir=instr_config.DATA_PATH)

### Generate Dataset

In [None]:
# Wrap the per-sample lists and the vocabulary in the project's dataset class.
# NOTE(review): the train_* lists are presumably indexed together per sample,
# so they must all have been filtered with the same indices (check that
# train_goals received the same filtering as the others) — confirm in
# dataloader.CraftingDataset.
dataset = dataloader.CraftingDataset(
  instr_config.embeded_dim,
  train_states,
  train_inventories,
  train_actions,
  train_goals,
  train_instructions,
  vocab)

In [None]:
# Sanity check: smallest length of the last field over all dataset items.
# NOTE(review): the last field is presumably the encoded instruction — confirm.
min(len(sample[-1]) for sample in dataset)

In [None]:
# Record the dataset length on the config; the train/validation split below
# reads it back from there.
instr_config.dataset_size = len(dataset)

### Split Dataset

In [None]:
# Shuffle every sample position once, then hand the first `split` positions to
# validation and the remainder to training.
indices = list(range(instr_config.dataset_size))
np.random.shuffle(indices)

# Number of validation samples: the configured fraction of the dataset,
# rounded down.
split = int(np.floor(instr_config.validation_split * instr_config.dataset_size))
val_indices = indices[:split]
train_indices = indices[split:]

# Each sampler draws only from its own subset of positions.
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

### Initialize Data Loader

In [None]:
# Both loaders read from the same dataset and share every option except the
# sampler, which restricts each loader to its own subset of positions.
loader_kwargs = dict(
    batch_size=instr_config.batch_size,
    num_workers=0,
    pin_memory=True,
    collate_fn=dataloader.collate_fn,
)

train_data_loader = DataLoader(dataset, sampler=train_sampler, **loader_kwargs)
validation_data_loader = DataLoader(dataset, sampler=valid_sampler, **loader_kwargs)

### Setup Training

In [None]:
# Build the instructions generator and move it to the configured device.
model = InstructionsGeneratorModel(
    instr_config.device,
    len(vocab),
    instr_config.embeded_dim,
    vocab_weights,
).to(instr_config.device)

# Short aliases for the trainer module's two entry points.
train, validate = (
    instructions_generator_trainer.train,
    instructions_generator_trainer.validate,
)

In [None]:
# Cross-entropy loss.
criterion = torch.nn.CrossEntropyLoss()

# Trainable parameters, materialized as a list. A bare `filter(...)` object is
# a one-shot iterator: the optimizer would exhaust it while building its
# parameter groups, and the same name is later passed to `train(...)`, which
# would then receive an empty iterator instead of the model's parameters.
parameters = [p for p in model.parameters() if p.requires_grad]

# Adam over the trainable parameters only.
optimizer = torch.optim.Adam(parameters, lr=instr_config.learning_rate)

# TensorBoard logging is optional; `writer` stays None when it is disabled.
writer = SummaryWriter() if instr_config.summary_writer else None

### Run Training and Validation

In [None]:
# Early-stopping state:
#   best_valid_loss - lowest validation loss seen so far
#   valid_epoch     - consecutive epochs without an improvement
#   best_epoch      - epoch that produced best_valid_loss (None until one is saved)
best_valid_loss = float('inf')
valid_epoch, best_epoch = 0, None

for epoch in range(instr_config.epochs):
    # ---- training pass ----
    loss, bleu, tk_acc = train(
        instr_config.device, epoch, train_data_loader, model,
        optimizer, criterion, parameters, vocab,
        summary_writer=writer,
    )
    print('Overall Epoch: %d, train loss: %.3f, train bleu: %.3f, train token acc: %.3f'
          % (epoch, loss, bleu, tk_acc))

    # ---- validation pass (overwrites loss / bleu / tk_acc) ----
    loss, bleu, tk_acc = validate(
        instr_config.device, epoch, validation_data_loader, model,
        criterion, vocab,
        summary_writer=writer,
    )
    print('Overall Epoch: %d, valid loss: %.3f, valid bleu: %.3f, valid token acc: %.3f'
          % (epoch, loss, bleu, tk_acc))

    # ---- early stopping ----
    # No improvement: spend one unit of patience and stop once it runs out.
    if not (loss <= best_valid_loss):
        valid_epoch += 1
        if valid_epoch >= instr_config.valid_patience:
            break
        continue

    # Improvement (ties count): reset patience and checkpoint the weights.
    best_valid_loss = loss
    valid_epoch, best_epoch = 0, epoch
    torch.save(model.state_dict(), instr_config.SAVE_PATH)
    print('Best Epoch: %d, best valid loss: %.3f' % (best_epoch, best_valid_loss))
    print('Trained model saved at ', instr_config.SAVE_PATH)