In [34]:
import torch
import os
import numpy as np

import data_utils

from model import BaselineNet

### Preprocessing

In [35]:
# Pre-trained 300-d GloVe vectors, pre-filtered to the words that appear in SNLI.
embedding_path = os.path.join(
    'data', 'glove', 'glove.filtered.300d.txt',
)
glove_emb = data_utils.EmbeddingGlove(embedding_path)

# Word <-> index vocabulary: count the words of the GloVe embedding
# (count_glove), then finalise the mapping with build().
vocab = data_utils.Vocabulary()
vocab.count_glove(glove_emb)
vocab.build()

# DataLoaderSnli over an empty example list: no dataset is iterated here; the
# object is kept only for its sentence-to-index helpers used in later cells.
dataloader = data_utils.DataLoaderSnli([], vocab)

25000 words loaded (0 invalid format)
31009 words loaded (0 invalid format)


In [36]:
# Vocabulary size
print(f"{len(vocab)} words in the vocabulary")
# First ten entries of the index-to-word list (special tokens, then the most
# frequent words)
print("Most common:", vocab.i2w[:10])

# Convert raw sentences into a tensor of word indices (not embedding vectors):
# one row per sentence, padded to the longest sentence.
print(dataloader.prepare_sentences([
    "This is a sentence",
    "The cat did not know what 'miouuuuw' meant.",
    "Surprisingly, there was nothing to surprise us",
]))

31011 words in the vocabulary
Most common: ['<unk>', '<pad>', 'the', 'and', 'to', 'of', 'a', 'in', 'is', 'for']
tensor([[  19,    8,    6, 2948,    1,    1,    1,    1],
        [   2, 1790,   94,   26,   82,   55,    0, 1515],
        [4866,   58,   22,  403,    4, 2021,  103,    1]])


### Baseline Network

In [37]:
# Build the baseline model around the pre-trained GloVe embedding matrix and
# place it on the GPU when one is available.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
net = BaselineNet(glove_emb.embedding).to(device)

# Restore trained weights. map_location remaps stored tensors onto `device`,
# so a GPU-saved checkpoint can also be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you
# trust (newer PyTorch versions accept weights_only=True to restrict this).
checkpoint_path = os.path.join('output', 'baseline', 'experiment_23075505',
                               'checkpoints', 'model_iter_102000.pt')
state_dict = torch.load(checkpoint_path, map_location=device)
net.load_state_dict(state_dict)

# The model is only used for inference below: switch to eval mode so that any
# dropout / batch-norm layers behave deterministically. The trailing ';'
# suppresses Jupyter's display of the module repr.
net.eval();

In [47]:
premise_sent = "It was a sunny day, and Silvan decided to go outside to play with his friends"

# Maps the index of the highest network score to a label name.
# NOTE(review): this order must match the label encoding used during
# training -- confirm against the training data pipeline.
CLASS_TEXT = ['neutral', 'contradiction', 'entailment']

# Pure inference: no autograd graph is needed.
with torch.no_grad():
    for hypothesis_sent in [
        "Silvan has a deadline",
        "Silvan made the decision to play outside",
        "Silvan decided to go outside",
    ]:
        prem, hyp, _ = dataloader.prepare_manual(premise_sent, hypothesis_sent)
        # NOTE(review): prem/hyp must be on the same device as `net`; confirm
        # prepare_manual places them there when running on CUDA.
        # Call the module rather than .forward() so registered hooks run.
        # [0] presumably selects the scores for this single premise/hypothesis
        # pair -- confirm against BaselineNet's return value.
        prediction = net(prem, hyp)[0]
        # torch argmax + .item() works for CPU and CUDA tensors alike, whereas
        # .numpy() raises on a CUDA tensor.
        print(CLASS_TEXT[prediction.argmax().item()])

contradiction
neutral
entailment
