In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from typing import Any, Dict, List
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
import tensorboardX
import torch
import torch.nn.functional as F

In [None]:
from utils import data_utils, train_utils

In [None]:
%%time
# Load the train / validation frames from local pickles (presumably a split of
# the MNLI training set; these are the frames fed to the train/val loaders).
df_train, df_val = (
    pd.read_pickle('data/df_train_train.pkl'),
    pd.read_pickle('data/df_train_val.pkl'),
)

In [None]:
# Raw MNLI jsonl splits. The raw training file ('data/multinli_1.0_train.jsonl')
# is not loaded here; df_train / df_val come from the pickles loaded above.
file_path_heldout = 'data/multinli_1.0_dev_matched.jsonl'
file_path_unlabeled = 'data/multinli_0.9_test_matched_unlabeled.jsonl'
# Loaded in the same order as listed: held-out (dev matched) first, then unlabeled test.
df_heldout, df_unlabeled = map(data_utils.load_data, (file_path_heldout, file_path_unlabeled))

In [None]:
# Give the unlabeled test rows a placeholder label column (presumably required by
# MNLIDataset, which is also built for labeled frames -- confirm in data_utils).
df_unlabeled = df_unlabeled.assign(gold_label='hidden')

## Data loaders

In [None]:
%%time
# Load GloVe word vectors, caching them as a pickle so later runs skip the slow
# text parse. Only unpickle a cache you produced yourself: pickle.load can execute
# arbitrary code.
pickle_file = 'weights/glove.pickle'
if os.path.exists(pickle_file):
    with open(pickle_file, 'rb') as infile:
        glove = pickle.load(infile)
else:
    # NOTE(review): `load_word_vectors` is neither defined nor imported in this
    # notebook (presumably it lives in utils.data_utils -- confirm); on a fresh
    # kernel this branch raises NameError.
    glove = load_word_vectors('models/glove.840B.300d.txt')  # FIXME: There should be 2196017 words
    print(len(glove))

    # The cache directory may not exist on a fresh checkout.
    os.makedirs(os.path.dirname(pickle_file), exist_ok=True)
    # Dump to a temporary file and rename, so an interrupted dump can't leave a
    # truncated cache that the `exists` check above would later trust.
    tmp_file = pickle_file + '.tmp'
    with open(tmp_file, 'wb') as outfile:
        pickle.dump(glove, outfile)
    os.replace(tmp_file, pickle_file)
    # No reload needed: `glove` already holds exactly what was just dumped.

In [None]:
# CUDA device index used for the model, the loss and every train/evaluate/predict call.
DEVICE = 3
# Mini-batch size shared by every DataLoader built in this notebook.
BATCH_SIZE = 8

In [None]:
def get_dataset_dataloader(df, sort_by_len: bool = True, shuffle: bool = False,
                           batch_size=None, word_vectors=None):
    """Wrap a MNLI frame in an MNLIDataset and a DataLoader over it.

    Args:
        df: MNLI frame passed straight to ``data_utils.MNLIDataset``.
        sort_by_len: forwarded to ``MNLIDataset``. Combining it with ``shuffle=True``
            would presumably defeat any length ordering, since the DataLoader reorders.
        shuffle: whether the DataLoader shuffles. Keep False whenever outputs must
            line up with the rows of ``df``.
        batch_size: mini-batch size; defaults to the notebook-level ``BATCH_SIZE``.
        word_vectors: embeddings for ``MNLIDataset``; defaults to the notebook-level ``glove``.

    Returns:
        ``(dataset, dataloader)``.
    """
    # Defaults are resolved at call time, exactly like the original global lookups.
    if batch_size is None:
        batch_size = BATCH_SIZE
    if word_vectors is None:
        word_vectors = glove
    dataset = data_utils.MNLIDataset(df, word_vectors=word_vectors, sort_by_len=sort_by_len)
    dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, pin_memory=True,
                                             shuffle=shuffle, collate_fn=data_utils.collate_fn)
    return dataset, dataloader

# Train / val / held-out loaders use the defaults (sort_by_len=True, shuffle=False).
# The unlabeled loader must keep df_unlabeled's row order: the predictions written
# at the end of the notebook are assigned back to its rows by position.
dataset_train, dataloader_train = get_dataset_dataloader(df_train)
dataset_val, dataloader_val = get_dataset_dataloader(df_val)
dataset_heldout, dataloader_heldout = get_dataset_dataloader(df_heldout)
dataset_unlabeled, dataloader_unlabeled = get_dataset_dataloader(df_unlabeled, sort_by_len=False, shuffle=False)

# Sanity check: number of examples and number of batches for each split.
for _ds, _dl in ((dataset_train, dataloader_train),
                 (dataset_val, dataloader_val),
                 (dataset_heldout, dataloader_heldout),
                 (dataset_unlabeled, dataloader_unlabeled)):
    print(len(_ds), len(_dl))

## Set up model and logging

In [None]:
from models.lstm import LSTM

In [None]:
# Classifier placed on the configured GPU.
model = LSTM(linear_size=512)
model = model.cuda(device=DEVICE)

# NOTE(review): NLLLoss expects log-probabilities, so this presumes the LSTM's
# forward ends in log_softmax -- confirm in models/lstm.py.
loss_func = torch.nn.NLLLoss()
loss_func = loss_func.cuda(device=DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [None]:
# Run name: used for the weights directory, the TensorBoard log directory and the
# results CSV written at the end of the notebook.
model_str = 'lstm-batched-1'
model_dir = '/opt/data/weights/' + model_str  # passed to train_utils.train as model_dir
log_dir = 'logs/' + model_str

# NOTE(review): without exist_ok this raises FileExistsError if the directory is
# already there, so re-running the cell (or reusing model_str) fails. This is
# presumably intentional to avoid overwriting an earlier run's weights -- confirm.
os.makedirs(model_dir)
writer = tensorboardX.SummaryWriter(log_dir)

In [None]:
# Total number of elements across all parameters that require gradients.
n_params = sum(np.prod(param.size()) for param in model.parameters() if param.requires_grad)

print(n_params)

## Train

In [None]:
# Train for a fixed number of epochs; val loader, model_dir and TensorBoard writer
# are handed to train_utils.train, which owns the loop.
train_utils.train(
    model=model,
    dataloader_train=dataloader_train,
    dataloader_val=dataloader_val,
    optimizer=optimizer,
    loss_func=loss_func,
    model_dir=model_dir,
    n_epochs=8,
    device=DEVICE,
    writer=writer,
)

## Evaluate (train / val / held-out)

In [None]:
# Optionally restore saved weights before evaluating, instead of using the weights
# left by the training cell above. Set to a .pt path to enable, e.g.
# '/opt/data/weights/lstm-2.4/lstm-2.4_2_392701.pt'.
CHECKPOINT_PATH = None
if CHECKPOINT_PATH is not None:
    # torch.load unpickles: only load checkpoints you trust. Loading to CPU first lets
    # load_state_dict copy into the model's GPU parameters even if the checkpoint was
    # saved from a different device index.
    state_dict = torch.load(CHECKPOINT_PATH, map_location='cpu')
    model.load_state_dict(state_dict)

In [None]:
%%time
# Passed as n_batches to train_utils.evaluate (presumably a cap on batches scored per split).
N_EVAL_BATCHES = 1500

losses, accs = [], []  # one entry per split, in the order train, val, heldout
eval_splits = [('train', dataloader_train), ('val', dataloader_val), ('heldout', dataloader_heldout)]
for split_name, dataloader in eval_splits:
    # NOTE(review): these loaders were built with sort_by_len=True, shuffle=False. If
    # evaluate scores only the first N_EVAL_BATCHES batches, a larger split is measured
    # on a length-ordered prefix rather than a representative sample -- confirm.
    loss, acc = train_utils.evaluate(model, dataloader, device=DEVICE, loss_func=loss_func,
                                     n_batches=N_EVAL_BATCHES)
    print(split_name, loss, acc)
    losses.append(loss)
    accs.append(acc)

## Test: predict the unlabeled set and write the results CSV

In [None]:
%%time
preds = train_utils.predict(model, dataloader_unlabeled, device=DEVICE)
pred_ids = list(preds)
# Predictions are assigned to rows by position, so they must line up 1:1 with
# df_unlabeled (which is why dataloader_unlabeled is built unsorted and unshuffled).
assert len(pred_ids) == len(df_unlabeled), \
    'got {} predictions for {} unlabeled rows'.format(len(pred_ids), len(df_unlabeled))

os.makedirs('results', exist_ok=True)
# Map predicted class ids back to label strings.
df_unlabeled['gold_label'] = [data_utils.id_to_lbl[x] for x in pred_ids]
df_unlabeled[['pairID', 'gold_label']].to_csv('results/{}.csv'.format(model_str), index=False)