# Load a checkpoint
##### Model choices are `uni-lstm`, `bi-lstm`, `bi-max-lstm`, and `mean`

In [25]:
import argparse
from models import SNLIModel
from train import parse_args
import torch

In [26]:
# --- Checkpoint configuration ---
# Path of the saved checkpoint file that the cells below load.
checkpoint_path = "modelsaves/bi_max_lstm_model.pth"

# Encoder type. Make sure this matches the checkpoint you are loading in!
modeltype = 'bi-max-lstm'

# The encoder choices available in this notebook.
VALID_MODEL_TYPES = ("uni-lstm", "bi-lstm", "bi-max-lstm", "mean")

# Fail fast on a typo here instead of at some later, less obvious point.
if modeltype not in VALID_MODEL_TYPES:
    raise ValueError(
        f"Unknown modeltype {modeltype!r}; expected one of {VALID_MODEL_TYPES}"
    )

In [27]:
# Start from the default parameters returned by parse_args(). They are only
# needed to initialise the model object: this notebook evaluates, it does not
# train, so the training-related defaults do not matter here.
# NOTE(review): if parse_args() reads sys.argv, it will also see the Jupyter
# kernel's own launcher arguments - confirm that it tolerates them.
params = parse_args()

# Override the two settings that depend on the checkpoint selected above.
params.encoder_model = modeltype
params.checkpoint_path = checkpoint_path

In [28]:
# Peek at the raw checkpoint contents before building the model.
# map_location="cpu" keeps this metadata check working on machines that lack
# the device the checkpoint tensors were saved from (e.g. a CPU-only host).
# NOTE: torch.load unpickles the file - only open checkpoints you trust.
checkpoint_info = torch.load(checkpoint_path, map_location="cpu")

# Show which entries the checkpoint stores, then the training metadata.
print(checkpoint_info.keys())
print(f"Model: {modeltype}")
print(f"Epoch: {checkpoint_info['epoch']}")
# .item() is called on the stored dev accuracy to print it as a plain number.
print(f"Dev accuracy; {checkpoint_info['dev_accuracy'].item()}")

dict_keys(['model_state_dict', 'optimizer_state_dict', 'epoch', 'dev_accuracy'])
Model: bi-max-lstm
Epoch: 5
Dev accuracy; 0.8474903702735901


In [5]:
# Build the model from `params`. Judging by the recorded output below,
# construction also downloads NLTK data, preprocesses the data, sets up the
# classifier, optimizer and loss function, and loads the checkpoint.
# NOTE(review): the recorded output reports loading modelsaves/mean_model.pth,
# which is not the checkpoint_path configured earlier, and the execution
# counts drop back to In [5] here - the saved outputs look stale. Restart the
# kernel and run all cells to refresh them.
checkpoint_model = SNLIModel(params)

Setting up all imports and downloads


[nltk_data] Downloading package punkt to /home/david/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /home/david/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


Preprocessing the data
Setting up the classifier (and the encoder within)
Setting up the optimizer and loss function
Checkpoint loaded from modelsaves/mean_model.pth


# Example: predicting with the checkpoint model

In [6]:
# One premise/hypothesis pair; predict() is called with a list of each.
premise, hypothesis = "Two men sitting in the sun", "Nobody is sitting in the shade"

# Leave the call as the cell's last expression so its result is displayed.
checkpoint_model.predict([premise], [hypothesis])

(tensor([0], device='cuda:0'), ['entailment'])

In [7]:
# A second premise/hypothesis pair, passed to predict() the same way.
premise, hypothesis = "A man is walking a dog", "No cat is outside"

# Leave the call as the cell's last expression so its result is displayed.
checkpoint_model.predict([premise], [hypothesis])

(tensor([2], device='cuda:0'), ['contradiction'])

# Example: evaluating a dataset and obtaining its accuracy

In [8]:
# Score the model on the dev and test splits it holds. .item() turns each
# returned value into a plain Python number for the report in the next cell.
evaluate = checkpoint_model.evaluate_accuracy
dev_accuracy = evaluate(checkpoint_model.dev_data).item()
test_accuracy = evaluate(checkpoint_model.test_data).item()

In [9]:
# Report both accuracies, rounded to five decimal places.
print(
    f"The dev accuracy is {round(dev_accuracy, 5)} "
    f"and the test accuracy is {round(test_accuracy, 5)}"
)

The dev accuracy is 0.64936 and the test accuracy is 0.64912
