In [1]:
import pickle as pk
import torch
import pandas as pd

In [2]:
# Load the pickled MNLI/SNLI splits and the pre-computed embeddings.
# NOTE(review): pickle.load can execute arbitrary code -- only load these
# files from a trusted source.
import os

DATA_DIR = "./hw2_data"


def _load_pickle(filename):
    """Unpickle ``DATA_DIR/filename``, closing the file handle afterwards.

    The original cell called ``pk.load(open(...))`` and never closed the
    handles; the ``with`` block guarantees they are released.
    """
    with open(os.path.join(DATA_DIR, filename), "rb") as fh:
        return pk.load(fh)


mnli_train_dict = _load_pickle("mnli_train_dict.pk")
mnli_val_dict = _load_pickle("mnli_val_dict.pk")

snli_train_id = _load_pickle("snli_train_id.pk")
snli_val_id = _load_pickle("snli_val_id.pk")

loaded_embeddings_ft = _load_pickle("loaded_embeddings_ft.pk")

# The keys of the MNLI training dict are used as the list of genres by later cells.
Genres = list(mnli_train_dict.keys())

In [41]:
# Read the hyper-parameter config stored alongside the best RNN checkpoint.
from rnn_trainer import rnn_trainer

model_path = 'best_models/RNN.pth'
with open(model_path, 'rb') as checkpoint_file:
    # Only 'config_dict' is read from the checkpoint in this cell, so map all
    # tensors to CPU: without map_location, a checkpoint saved from CUDA
    # tensors fails to load on a machine without a GPU.
    checkpoint = torch.load(checkpoint_file, map_location='cpu')
best_args = checkpoint['config_dict']

In [44]:
# Score the best saved RNN on every MNLI genre and then on SNLI.
# Each split gets a fresh rnn_trainer built from that split's train/val data;
# load_model('best_models', 'RNN') presumably restores the saved weights before
# eval_stage() is called -- TODO confirm against rnn_trainer.
# NOTE(review): the SNLI score is also stored in `mnli_rnn_acc`, under key 'SNLI'.
eval_splits = [(genre, mnli_train_dict[genre], mnli_val_dict[genre]) for genre in Genres]
eval_splits.append(('SNLI', snli_train_id, snli_val_id))

mnli_rnn_acc = {}
for split_name, split_train, split_val in eval_splits:
    trainer = rnn_trainer(split_train, split_val, loaded_embeddings_ft, best_args)
    trainer.load_model('best_models', 'RNN')
    mnli_rnn_acc[split_name] = trainer.eval_stage()


In [45]:
# Reshape {eval set: accuracy} into a single-row table (row label 0) with one
# column per eval set, in insertion order.
mnli_rnn_acc_df = pd.DataFrame(
    list(mnli_rnn_acc.values()), index=list(mnli_rnn_acc.keys())
).transpose()

In [3]:
# Read the hyper-parameter config stored alongside the best CNN checkpoint.
from cnn_trainer import cnn_trainer

model_path = 'best_models/CNN.pth'
with open(model_path, 'rb') as checkpoint_file:
    # Only 'config_dict' is read from the checkpoint in this cell, so map all
    # tensors to CPU: without map_location, a checkpoint saved from CUDA
    # tensors fails to load on a machine without a GPU.
    checkpoint = torch.load(checkpoint_file, map_location='cpu')
best_args = checkpoint['config_dict']

In [4]:
# Score the best saved CNN on every MNLI genre and then on SNLI.
# Each split gets a fresh cnn_trainer built from that split's train/val data;
# load_model('best_models', 'CNN') presumably restores the saved weights before
# eval_stage() is called -- TODO confirm against cnn_trainer.
# NOTE(review): the SNLI score is also stored in `mnli_cnn_acc`, under key 'SNLI'.
eval_splits = [(genre, mnli_train_dict[genre], mnli_val_dict[genre]) for genre in Genres]
eval_splits.append(('SNLI', snli_train_id, snli_val_id))

mnli_cnn_acc = {}
for split_name, split_train, split_val in eval_splits:
    trainer = cnn_trainer(split_train, split_val, loaded_embeddings_ft, best_args)
    trainer.load_model('best_models', 'CNN')
    mnli_cnn_acc[split_name] = trainer.eval_stage()


In [5]:
# Reshape {eval set: accuracy} into a single-row table (row label 0) with one
# column per eval set, in insertion order.
mnli_cnn_acc_df = pd.DataFrame(
    list(mnli_cnn_acc.values()), index=list(mnli_cnn_acc.keys())
).transpose()

In [53]:
# Stack the CNN and RNN single-row accuracy tables and label each row with its
# model name (the per-frame row label 0 is dropped).
MNLI_accuracy = pd.concat(
    [mnli_cnn_acc_df, mnli_rnn_acc_df], keys=['CNN', 'RNN']
).reset_index(level=1, drop=True)
MNLI_accuracy

Unnamed: 0,telephone,fiction,slate,government,travel,SNLI
CNN,46.666667,47.035176,44.311377,45.669291,46.94501,70.6
RNN,44.975124,48.241206,44.810379,45.374016,47.759674,70.0


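## Fine-tuning the best CNN on each MNLI genre

Starting from the saved best CNN, train separately on each genre for 15 epochs and compare validation accuracy before and after fine-tuning.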

In [6]:
# Fine-tune the best saved CNN separately on each MNLI genre and record the
# final validation accuracy reported for each genre.
import copy

from cnn_trainer import cnn_trainer

NUM_FINETUNE_EPOCHS = 15

model_path = 'best_models/CNN.pth'
with open(model_path, 'rb') as checkpoint_file:
    # Only 'config_dict' is read here; map_location='cpu' lets a checkpoint
    # saved from CUDA tensors load on a CPU-only machine.
    checkpoint = torch.load(checkpoint_file, map_location='cpu')
best_args = checkpoint['config_dict']

# Override the epoch count on a shallow copy, so `best_args` keeps exactly the
# values stored in the checkpoint instead of being mutated in place.
finetune_args = copy.copy(best_args)
finetune_args['num_epochs'] = NUM_FINETUNE_EPOCHS

# NOTE(review): no random seed is set in this notebook, so results may vary
# between runs unless cnn_trainer seeds internally -- TODO confirm.
genre_acc = []
for g in Genres:
    trainer = cnn_trainer(mnli_train_dict[g], mnli_val_dict[g], loaded_embeddings_ft, finetune_args)
    # Start every genre from the saved best CNN (presumably load_model
    # restores its weights) rather than from a previous genre's result.
    trainer.load_model('best_models', 'CNN')
    # go() yields (train errors, val accuracies); only the last val accuracy
    # (presumably the final epoch's) is kept.
    _, val_acc = trainer.go()
    genre_acc.append({'genre': g, 'val_acc': val_acc[-1]})


100%|██████████| 15/15 [00:16<00:00,  1.09s/it]
100%|██████████| 15/15 [00:48<00:00,  3.26s/it]
100%|██████████| 15/15 [00:13<00:00,  1.15it/s]
100%|██████████| 15/15 [00:12<00:00,  1.17it/s]
100%|██████████| 15/15 [00:12<00:00,  1.19it/s]


In [16]:
# One row per genre with columns 'genre' and 'val_acc'.
finetune_acc = pd.DataFrame(genre_acc)

In [19]:
# Build the single-row table (row 'val_acc', one column per genre) directly
# from `genre_acc`. The original `finetune_acc = finetune_acc.set_index(...)`
# overwrote its own input, so re-running the cell raised a KeyError on 'genre'.
finetune_acc = pd.DataFrame(genre_acc).set_index('genre').T

In [26]:
# Per-genre validation accuracy of the best CNN before vs. after genre-specific
# fine-tuning (SNLI is excluded by selecting only the MNLI genre columns).
comparison = pd.concat(
    [mnli_cnn_acc_df[Genres], finetune_acc],
    keys=['before finetune', 'after finetune'],
).reset_index(level=1, drop=True)
comparison

genre,telephone,fiction,slate,government,travel
before finetune,46.666667,47.035176,44.311377,45.669291,46.94501
after finetune,54.527363,54.773869,47.105788,56.003937,53.05499
