In [1]:
from functools import partial

import torch
import pandas as pd
from data import Vocabulary, BobSueDataset
from models import FullVocabularyModel
from learner import LanguageModelLearner

# Seed PyTorch's global random number generator. The trailing semicolon stops
# the notebook from echoing the returned Generator object as cell output.
torch.manual_seed(41);

<torch._C.Generator at 0x7f0f20798168>

## Hyperparameters

In [2]:
# Settings shared by both experiments in this notebook (w/o and w/ context).
# Passed to LanguageModelLearner as `batch_size`.
BATCH_SIZE = 64
# Passed to FullVocabularyModel as `embedding_dim`.
EMBEDDING_DIM = 200
# Passed to FullVocabularyModel as `hidden_size`.
HIDDEN_SIZE = 200
# Passed to FullVocabularyModel as `dropout`.
DROPOUT = 0.2
# Passed to `learner.train` as `epochs`.
EPOCHS = 15

## Load data

In [3]:
# File-name template; the '{}' placeholder is filled with the split name.
FILENAME = 'bobsue.prevsent.{}.tsv'

# One Vocabulary instance is handed to every split's dataset.
vocab = Vocabulary()

# Build the three splits in the order train -> dev -> test.
train_set, valid_set, test_set = (
    BobSueDataset(FILENAME.format(split), vocab)
    for split in ('train', 'dev', 'test')
)

# Pre-bind the datasets and batch size so that each experiment only has to
# supply its model: get_learner(model=...).
get_learner = partial(
    LanguageModelLearner,
    train_set=train_set, valid_set=valid_set, test_set=test_set,
    batch_size=BATCH_SIZE,
)

## Utility functions

In [4]:
def show_mistakes(mismatches, top=35):
    """Tabulate the most frequent (prediction, ground truth) confusions as tokens.

    Args:
        mismatches: Counter-like object exposing ``most_common``; each key is
            unpacked as a ``(prediction, ground_truth)`` pair of indices into
            ``vocab.itos``.
        top: Maximum number of pairs to include, most frequent first.

    Returns:
        A ``pandas.DataFrame`` with the columns ``'prediction'`` and
        ``'ground truth'`` whose cells are the entries of ``vocab.itos`` for
        the corresponding indices.
    """
    # Look the tokens up while building the rows rather than post-processing
    # the frame with DataFrame.applymap, which is deprecated since pandas 2.1.
    rows = [
        (vocab.itos[predicted], vocab.itos[expected])
        for (predicted, expected), _count in mismatches.most_common(top)
    ]
    return pd.DataFrame(rows, columns=['prediction', 'ground truth'])

# Log loss training w/o context

## Build model and learner

In [5]:
# First experiment: the model is constructed with read_context=False; every
# other argument comes from the shared hyperparameters and the vocabulary size.
wo_context = FullVocabularyModel(
    vocab_size=len(vocab),
    embedding_dim=EMBEDDING_DIM,
    hidden_size=HIDDEN_SIZE,
    dropout=DROPOUT,
    read_context=False
)
# get_learner already carries the datasets and batch size; only the model varies.
learner = get_learner(model=wo_context)

## Train model

In [6]:
# File name handed to `train`; the same name is passed to
# `learner.load_model_params` before the test-set evaluation.
wo_context_filename = 'wo_context.pt'
# Keep the value returned with return_mismatches=True; it is what
# show_mistakes receives in the "Show top mistakes" cell.
wo_context_mistakes = learner.train(epochs=EPOCHS, filename=wo_context_filename, return_mismatches=True)

Epoch 01: 100%|██████████| 95/95 [00:01<00:00, 85.24it/s, Loss=4.83, Acc=0.258] 


	Train Loss: 5.647	Train Acc: 13.15%
	Valid Loss: 4.777	Valid Acc: 22.84%
	Model parameters saved to wo_context.pt


Epoch 02: 100%|██████████| 95/95 [00:01<00:00, 78.10it/s, Loss=4.44, Acc=0.238]


	Train Loss: 4.538	Train Acc: 23.09%
	Valid Loss: 4.333	Valid Acc: 24.15%
	Model parameters saved to wo_context.pt


Epoch 03: 100%|██████████| 95/95 [00:01<00:00, 80.76it/s, Loss=4.18, Acc=0.254]


	Train Loss: 4.231	Train Acc: 24.66%
	Valid Loss: 4.108	Valid Acc: 26.28%
	Model parameters saved to wo_context.pt


Epoch 04: 100%|██████████| 95/95 [00:01<00:00, 89.39it/s, Loss=3.9, Acc=0.267] 


	Train Loss: 4.031	Train Acc: 26.43%
	Valid Loss: 3.943	Valid Acc: 27.69%
	Model parameters saved to wo_context.pt


Epoch 05: 100%|██████████| 95/95 [00:01<00:00, 81.66it/s, Loss=3.82, Acc=0.284]


	Train Loss: 3.866	Train Acc: 28.02%
	Valid Loss: 3.821	Valid Acc: 28.69%
	Model parameters saved to wo_context.pt


Epoch 06: 100%|██████████| 95/95 [00:01<00:00, 78.24it/s, Loss=3.39, Acc=0.36] 


	Train Loss: 3.737	Train Acc: 29.08%
	Valid Loss: 3.727	Valid Acc: 29.77%
	Model parameters saved to wo_context.pt


Epoch 07: 100%|██████████| 95/95 [00:01<00:00, 87.60it/s, Loss=3.56, Acc=0.333]


	Train Loss: 3.640	Train Acc: 29.94%
	Valid Loss: 3.669	Valid Acc: 30.11%
	Model parameters saved to wo_context.pt


Epoch 08: 100%|██████████| 95/95 [00:01<00:00, 82.60it/s, Loss=3.58, Acc=0.293]


	Train Loss: 3.563	Train Acc: 30.37%
	Valid Loss: 3.622	Valid Acc: 30.73%
	Model parameters saved to wo_context.pt


Epoch 09: 100%|██████████| 95/95 [00:01<00:00, 84.80it/s, Loss=3.46, Acc=0.307]


	Train Loss: 3.496	Train Acc: 31.03%
	Valid Loss: 3.590	Valid Acc: 31.07%
	Model parameters saved to wo_context.pt


Epoch 10: 100%|██████████| 95/95 [00:01<00:00, 87.18it/s, Loss=3.56, Acc=0.313]


	Train Loss: 3.439	Train Acc: 31.36%
	Valid Loss: 3.561	Valid Acc: 31.39%
	Model parameters saved to wo_context.pt


Epoch 11: 100%|██████████| 95/95 [00:01<00:00, 89.96it/s, Loss=3.25, Acc=0.33] 


	Train Loss: 3.387	Train Acc: 32.01%
	Valid Loss: 3.532	Valid Acc: 31.66%
	Model parameters saved to wo_context.pt


Epoch 12: 100%|██████████| 95/95 [00:01<00:00, 85.09it/s, Loss=3.27, Acc=0.33] 


	Train Loss: 3.339	Train Acc: 32.41%
	Valid Loss: 3.512	Valid Acc: 31.93%
	Model parameters saved to wo_context.pt


Epoch 13: 100%|██████████| 95/95 [00:01<00:00, 84.66it/s, Loss=3.13, Acc=0.361]


	Train Loss: 3.294	Train Acc: 32.76%
	Valid Loss: 3.502	Valid Acc: 31.76%
	Model parameters saved to wo_context.pt


Epoch 14: 100%|██████████| 95/95 [00:01<00:00, 82.79it/s, Loss=2.99, Acc=0.36] 


	Train Loss: 3.251	Train Acc: 33.10%
	Valid Loss: 3.483	Valid Acc: 32.17%
	Model parameters saved to wo_context.pt


Epoch 15: 100%|██████████| 95/95 [00:01<00:00, 83.86it/s, Loss=3.36, Acc=0.298]


	Train Loss: 3.213	Train Acc: 33.54%
	Valid Loss: 3.476	Valid Acc: 32.40%
	Model parameters saved to wo_context.pt


## Show top mistakes

In [7]:
# Render the most frequent (prediction, ground truth) pairs returned by the
# w/o context training run.
show_mistakes(wo_context_mistakes)

Unnamed: 0,prediction,ground truth
0,Bob,He
1,Bob,Sue
2,Bob,She
3,the,his
4,.,and
5,was,had
6,he,she
7,.,to
8,to,.
9,was,decided


## Evaluate model on test set

In [8]:
# NOTE(review): this relies on `learner` still being the learner created for
# `wo_context`. The w/ context section rebinds `learner`, so re-running this
# cell afterwards would target that other learner instead.
learner.load_model_params(wo_context_filename)
learner.print_test_results()

	 Test Loss: 3.499	 Test Acc: 31.98%


# Log loss training w/ context

## Build model and learner

In [9]:
# Second experiment: identical arguments to the w/o context model except
# read_context=True.
w_context = FullVocabularyModel(
    vocab_size=len(vocab),
    embedding_dim=EMBEDDING_DIM,
    hidden_size=HIDDEN_SIZE,
    dropout=DROPOUT,
    read_context=True
)
# Rebinds `learner`: from here on it refers to the learner for `w_context`.
learner = get_learner(model=w_context)

## Train model

In [10]:
# File name handed to `train`; the same name is passed to
# `learner.load_model_params` before the test-set evaluation.
w_context_filename = 'w_context.pt'
# Keep the value returned with return_mismatches=True; it is what
# show_mistakes receives in the "Show top mistakes" cell.
w_context_mistakes = learner.train(epochs=EPOCHS, filename=w_context_filename, return_mismatches=True)

Epoch 01: 100%|██████████| 95/95 [00:01<00:00, 68.96it/s, Loss=5.26, Acc=0.17]  


	Train Loss: 5.620	Train Acc: 13.04%
	Valid Loss: 5.136	Valid Acc: 18.39%
	Model parameters saved to w_context.pt


Epoch 02: 100%|██████████| 95/95 [00:01<00:00, 60.75it/s, Loss=4.36, Acc=0.212]


	Train Loss: 4.788	Train Acc: 20.83%
	Valid Loss: 4.412	Valid Acc: 24.14%
	Model parameters saved to w_context.pt


Epoch 03: 100%|██████████| 95/95 [00:01<00:00, 70.00it/s, Loss=3.99, Acc=0.299]


	Train Loss: 4.262	Train Acc: 25.57%
	Valid Loss: 4.086	Valid Acc: 27.25%
	Model parameters saved to w_context.pt


Epoch 04: 100%|██████████| 95/95 [00:01<00:00, 72.15it/s, Loss=4.03, Acc=0.296]


	Train Loss: 3.979	Train Acc: 27.83%
	Valid Loss: 3.881	Valid Acc: 28.93%
	Model parameters saved to w_context.pt


Epoch 05: 100%|██████████| 95/95 [00:01<00:00, 66.78it/s, Loss=3.72, Acc=0.276]


	Train Loss: 3.797	Train Acc: 29.17%
	Valid Loss: 3.770	Valid Acc: 29.96%
	Model parameters saved to w_context.pt


Epoch 06: 100%|██████████| 95/95 [00:01<00:00, 70.47it/s, Loss=3.61, Acc=0.324]


	Train Loss: 3.680	Train Acc: 30.14%
	Valid Loss: 3.694	Valid Acc: 30.50%
	Model parameters saved to w_context.pt


Epoch 07: 100%|██████████| 95/95 [00:01<00:00, 71.02it/s, Loss=3.64, Acc=0.314]


	Train Loss: 3.593	Train Acc: 30.65%
	Valid Loss: 3.631	Valid Acc: 30.99%
	Model parameters saved to w_context.pt


Epoch 08: 100%|██████████| 95/95 [00:01<00:00, 71.67it/s, Loss=3.58, Acc=0.3]  


	Train Loss: 3.523	Train Acc: 31.23%
	Valid Loss: 3.586	Valid Acc: 31.27%
	Model parameters saved to w_context.pt


Epoch 09: 100%|██████████| 95/95 [00:01<00:00, 65.71it/s, Loss=3.7, Acc=0.282] 


	Train Loss: 3.459	Train Acc: 31.95%
	Valid Loss: 3.547	Valid Acc: 32.15%
	Model parameters saved to w_context.pt


Epoch 10: 100%|██████████| 95/95 [00:01<00:00, 62.48it/s, Loss=3.49, Acc=0.335]


	Train Loss: 3.399	Train Acc: 32.82%
	Valid Loss: 3.511	Valid Acc: 32.80%
	Model parameters saved to w_context.pt


Epoch 11: 100%|██████████| 95/95 [00:01<00:00, 69.05it/s, Loss=3.41, Acc=0.329]


	Train Loss: 3.344	Train Acc: 33.31%
	Valid Loss: 3.481	Valid Acc: 33.20%
	Model parameters saved to w_context.pt


Epoch 12: 100%|██████████| 95/95 [00:01<00:00, 65.94it/s, Loss=3.26, Acc=0.344]


	Train Loss: 3.294	Train Acc: 33.91%
	Valid Loss: 3.459	Valid Acc: 33.62%
	Model parameters saved to w_context.pt


Epoch 13: 100%|██████████| 95/95 [00:01<00:00, 62.21it/s, Loss=3.32, Acc=0.359]


	Train Loss: 3.248	Train Acc: 34.29%
	Valid Loss: 3.452	Valid Acc: 33.32%
	Model parameters saved to w_context.pt


Epoch 14: 100%|██████████| 95/95 [00:01<00:00, 67.70it/s, Loss=3.2, Acc=0.342] 


	Train Loss: 3.204	Train Acc: 34.75%
	Valid Loss: 3.432	Valid Acc: 33.37%
	Model parameters saved to w_context.pt


Epoch 15: 100%|██████████| 95/95 [00:01<00:00, 66.46it/s, Loss=3.18, Acc=0.339]


	Train Loss: 3.158	Train Acc: 35.75%
	Valid Loss: 3.400	Valid Acc: 35.19%
	Model parameters saved to w_context.pt


## Show top mistakes

In [11]:
# Render the most frequent (prediction, ground truth) pairs returned by the
# w/ context training run.
show_mistakes(w_context_mistakes)

Unnamed: 0,prediction,ground truth
0,the,his
1,Bob,Sue
2,was,had
3,.,and
4,He,Bob
5,.,to
6,the,a
7,Sue,Bob
8,was,decided
9,the,her


## Evaluate model on test set

In [12]:
# `learner` is the one created for `w_context` in this section; load the
# parameters stored under w_context_filename, then report the test metrics.
learner.load_model_params(w_context_filename)
learner.print_test_results()

	 Test Loss: 3.434	 Test Acc: 34.79%
