In [11]:
import os
from torchtext import data
import torch
from torchtext.datasets import SNLI
from torchtext.vocab import GloVe

In [12]:
# Compute device for the models and their input tensors.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained GloVe embeddings (840B-token corpus, 300 dimensions), cached locally.
glove = GloVe(name='840B', dim=300, cache="./dataset/.vector_cache")

# Sentence field: spaCy tokenization, lower-cased, batch dimension first,
# and sequence lengths returned alongside the token ids.
text_field = data.Field(
    tokenize='spacy',
    tokenizer_language="en_core_web_sm",
    lower=True,
    include_lengths=True,
    batch_first=True,
)
# Label field: one label per example, so not a token sequence.
label_field = data.Field(sequential=False)

train, val, test = SNLI.splits(text_field, label_field, root="./dataset/.data")

# Both vocabularies are built from the training split only; the text vocabulary
# additionally gets the GloVe vectors attached.
text_field.build_vocab(train, vectors=glove)
label_field.build_vocab(train)
# Index -> label-string lookup used to decode predictions.
vocabulary_label = label_field.vocab.itos



In [13]:
# Load the trained NLINet checkpoints, one per sentence encoder.
def load_nli_model(encoder_name, output_dir='./output'):
    """Load `<output_dir>/<encoder_name>/models/best_checkpoint.pkl` for inference.

    The checkpoint holds a whole pickled model object (it is used directly,
    not loaded into a separately constructed model via a state_dict).
    `map_location=device` remaps its tensors onto `device`. This lets a
    checkpoint written on a GPU load on a CPU-only machine, and it puts the
    weights on the same device the inputs are sent to.

    Args:
        encoder_name: sub-directory of `output_dir` naming the encoder variant.
        output_dir: root directory containing the per-encoder training outputs.

    Returns:
        The loaded model, moved to `device` and switched to eval mode.
    """
    model_path = os.path.join(output_dir, encoder_name, "models", 'best_checkpoint.pkl')
    # torch.load unpickles arbitrary Python objects: only load trusted checkpoints.
    # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects
    # whole pickled modules; on such versions pass weights_only=False here.
    model = torch.load(model_path, map_location=device)
    model.to(device)
    model.eval()
    return model


NLINet_Baseline_model = load_nli_model('Baseline')
NLINet_UniLSTM_model = load_nli_model('UniLSTM')
NLINet_SimpleBiLSTM_model = load_nli_model('SimpleBiLSTM')
NLINet_BiLSTM_model = load_nli_model('BiLSTM')

# Last expression: display the BiLSTM model's architecture as the cell output.
NLINet_BiLSTM_model


NLINet(
  (embedding): Embedding(33635, 300)
  (encoder_model): BiLSTM(
    (lstm): LSTM(300, 2048, batch_first=True, bidirectional=True)
  )
  (classifier): Sequential(
    (0): Linear(in_features=16384, out_features=512, bias=True)
    (1): Linear(in_features=512, out_features=512, bias=True)
    (2): Linear(in_features=512, out_features=3, bias=True)
  )
)

In [14]:
def relation(hypothesis, premise, NLINet_model):
    """Print the NLI label `NLINet_model` predicts for one (hypothesis, premise) pair.

    Args:
        hypothesis: raw hypothesis sentence (str).
        premise: raw premise sentence (str).
        NLINet_model: trained model, called as `NLINet_model(hypothesis_batch,
            premise_batch)` and expected to return per-class scores with the
            class axis at dim 1.
    """
    # Field.process only pads and numericalizes. Tokenization and lower-casing
    # live in Field.preprocess, so run it explicitly. This gives inference the
    # same spaCy tokens / lower-casing as training. It also avoids the empty-token
    # artifact a plain str.split produced for sentences ending in " .".
    hypothesis_tokens = text_field.preprocess(hypothesis)
    premise_tokens = text_field.preprocess(premise)

    # process() takes a batch (list of token lists). Here the batch size is 1.
    hypothesis_batch = text_field.process([hypothesis_tokens], device=device)
    premise_batch = text_field.process([premise_tokens], device=device)

    # Pure inference: no autograd graph needed.
    with torch.no_grad():
        preds = NLINet_model(hypothesis_batch, premise_batch)

    pred_index = torch.argmax(preds, dim=1).item()
    # The model's class i maps to label-vocab entry i + 1. Presumably entry 0 is
    # the label Field's '<unk>' special; confirm against the training code.
    preds_label = vocabulary_label[pred_index + 1]
    print(preds_label)

In [15]:
# Hand-written (hypothesis, premise) pairs for spot-checking the trained models.
hypothesis1, premise1 = (
    "a woman is making music.",
    "a pregnant lady singing on stage while holding a flag behind her.",
)

hypothesis2, premise2 = (
    "the boy is wearing safety equipment.",
    "a boy is jumping on skateboard in the middle of a red bridge .",
)

hypothesis3, premise3 = (
    "a skier is away from the rail.",
    "a skier slides along a metal rail.",
)

In [16]:
# Pair 1 through every model; output order: Baseline, UniLSTM, SimpleBiLSTM, BiLSTM.
for nli_model in (NLINet_Baseline_model, NLINet_UniLSTM_model,
                  NLINet_SimpleBiLSTM_model, NLINet_BiLSTM_model):
    relation(hypothesis1, premise1, nli_model)

entailment
entailment
entailment
entailment


In [17]:
# Pair 2 through every model; output order: Baseline, UniLSTM, SimpleBiLSTM, BiLSTM.
for nli_model in (NLINet_Baseline_model, NLINet_UniLSTM_model,
                  NLINet_SimpleBiLSTM_model, NLINet_BiLSTM_model):
    relation(hypothesis2, premise2, nli_model)

entailment
entailment
entailment
neutral


In [18]:
# Pair 3 through every model; output order: Baseline, UniLSTM, SimpleBiLSTM, BiLSTM.
for nli_model in (NLINet_Baseline_model, NLINet_UniLSTM_model,
                  NLINet_SimpleBiLSTM_model, NLINet_BiLSTM_model):
    relation(hypothesis3, premise3, nli_model)


neutral
contradiction
contradiction
contradiction
