# Model demo



In [1]:
import os
import statistics
from functools import lru_cache

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from spacy.lang.en import English

# NOTE(review): star import; presumably supplies preprocess, create_vocab and
# align_vocab_with_glove used below — prefer importing those names explicitly.
from preprocess import *
from utils import save_file, load_file, load_model
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(device)

cpu


The following four models can be tested by setting `MODEL_FLAG` in the next cell:

- `'bow'`: bag-of-words model; the GloVe embeddings of a sentence's words are averaged to obtain the sentence representation.
- `'lstm'`: LSTM model; the final hidden state is used as the sentence representation.
- `'bilstm'`: BiLSTM model; the concatenation of the final hidden states of the forward and backward LSTMs is used as the sentence representation.
- `'bilstm_max'`: BiLSTM model with max pooling over the concatenated forward and backward hidden states of all time steps.

In [2]:
# Choose which trained model to load: 'bow', 'lstm', 'bilstm' or 'bilstm_max'.
MODEL_FLAG = 'lstm'
state_file_path = f'weights/{MODEL_FLAG}/{MODEL_FLAG}_best.pth'

# Label names passed to load_model.
# NOTE(review): order presumably matches the training label ids — confirm.
labels = ['neutral', 'entailment', 'contradiction']
vocab_file = 'NLI_vocab.pkl'
folder = 'saved_files'
vocab_path = f'{folder}/{vocab_file}'

# Reuse the pickled vocabulary when present; otherwise build it from the
# training split and save it for the next run.
if not os.path.exists(vocab_path):
    print('creating vocab with training set')
    train_split = preprocess(split='train')
    vocab = create_vocab(train_split)
    save_file(vocab, vocab_file)
else:
    print('loading vocab from file')
    vocab = load_file(vocab_file)

vocab_size = len(vocab.mapping)
embeddings = align_vocab_with_glove(vocab)
model = load_model(embeddings, labels, vocab_size, device, MODEL_FLAG, state_file_path)

loading vocab from file
saved_files/NLI_vocab.pkl
loaded embeddings from file
loading lstm


In [5]:
@lru_cache(maxsize=None)
def _english_pipeline():
    """Build the blank spaCy English pipeline once and reuse it on later calls."""
    return English()


def transform_sentence(sent, vocab, tokenize=True):
    """Map a sentence to vocabulary ids.

    Args:
        sent: a raw string when ``tokenize`` is True; otherwise an iterable of
            tokens, each either a ``str`` or an object with a ``.text``
            attribute (e.g. spaCy tokens).
        vocab: object whose ``mapping`` dict maps token text to an integer id.
        tokenize: if True, lower-case ``sent`` and split it with spaCy's
            English tokenizer before the id lookup.

    Returns:
        ``(ids, [length])``: the list of ids (0 for any token missing from
        ``vocab.mapping``) and a one-element list holding ``len(ids)``.
    """
    if tokenize:
        # Constructing English() is costly, so the pipeline is cached rather
        # than rebuilt on every call.
        sent = _english_pipeline().tokenizer(sent.lower())
    # getattr fallback accepts plain-string tokens as well as spaCy tokens.
    sent_ids = [vocab.mapping.get(getattr(token, 'text', token), 0) for token in sent]
    return sent_ids, [len(sent_ids)]

# Demo: map a raw sentence to vocabulary ids and its token count.
transform_sentence('This function maps tokens to ids and returns length', vocab)

([29982, 12010, 17706, 30303, 30238, 14706, 1325, 24363, 16756], [9])

In [6]:
def _model_device(model):
    """Return the device of the model's first parameter, or None (torch's default) if it has none."""
    parameters = getattr(model, 'parameters', None)
    if parameters is None:
        return None
    first = next(iter(parameters()), None)
    return None if first is None else first.device


def make_prediction(sent1, sent2, vocab, model, model_flag, printing=True, tokenize=True):
    """Predict the NLI relation between a premise and a hypothesis.

    Args:
        sent1: premise; a raw string if ``tokenize`` is True, otherwise
            pre-tokenized input passed unchanged to ``transform_sentence``.
        sent2: hypothesis, in the same format as ``sent1``.
        vocab: vocabulary forwarded to ``transform_sentence``.
        model: callable returning logits; assumed shape ``(1, 3)`` ordered as
            neutral/entailment/contradiction (TODO confirm against training).
        model_flag: ``'bow'`` calls ``model(sent1, sent2)``; ``'lstm'``,
            ``'bilstm'`` and ``'bilstm_max'`` call
            ``model(sent1, length1, sent2, length2)``.
        printing: if True, print the inputs, per-label probabilities and the
            predicted relation.
        tokenize: forwarded to ``transform_sentence``.

    Returns:
        ``(label, predicted_label, length1, length2)``: the predicted label
        string, the argmax index as a one-element tensor, and the one-element
        length lists of both sentences.

    Raises:
        ValueError: if ``model_flag`` is not one of the supported values.
    """
    if model_flag not in ('bow', 'lstm', 'bilstm', 'bilstm_max'):
        raise ValueError(
            f"unknown model_flag {model_flag!r}; expected 'bow', 'lstm', 'bilstm' or 'bilstm_max'")
    if printing:
        print(f'premise: {sent1}')
        print(f'hypothesis: {sent2}')
        print("\n")
    sent_ids1, length1 = transform_sentence(sent1, vocab, tokenize)
    sent_ids2, length2 = transform_sentence(sent2, vocab, tokenize)

    # Place inputs next to the model's weights (e.g. on CUDA). Set dtype
    # explicitly: an empty id list would otherwise become a float tensor.
    model_device = _model_device(model)
    sent1 = torch.tensor([sent_ids1], dtype=torch.long, device=model_device)
    sent2 = torch.tensor([sent_ids2], dtype=torch.long, device=model_device)

    # Inference only: don't build an autograd graph.
    # NOTE(review): model.eval() is not called here; presumably load_model
    # already does so — confirm, otherwise dropout stays active.
    with torch.no_grad():
        if model_flag == 'bow':
            logits = model(sent1, sent2)
        else:
            logits = model(sent1, length1, sent2, length2)

    probabilities = F.softmax(logits, dim=1)
    predicted_label = torch.argmax(probabilities, dim=1)
    labels = ['neutral', 'entailment', 'contradiction']
    label = labels[predicted_label.item()]
    # Convert the first (only) row of the probability tensor to rounded floats.
    prob_list = [round(prob, 3) for prob in probabilities[0].tolist()]
    if printing:
        for name, prob in zip(labels, prob_list):
            print(f'predicted {name} with {prob} probability')
        print('\n')
        print(f'Therefore, predicted relation: {label}')
    return label, predicted_label, length1, length2

# Demo: classify one premise/hypothesis pair with the loaded model and print the probabilities.
relation, numeric, _, _ = make_prediction('bob loves you', 'he is outside', vocab, model, MODEL_FLAG)

premise: bob loves you
hypothesis: he is outside


predicted neutral with 0.867 probability
predicted entailment with 0.115 probability
predicted contradiction with 0.017 probability


Therefore, predicted relation: neutral


In [7]:
_LENGTH_BUCKETS = ('short', 'medium', 'long')


def _safe_ratio(correct, total):
    """Return correct / total as a float, or NaN when the bucket is empty."""
    return correct / float(total) if total else float('nan')


def _print_length_stats(lengths):
    """Print summary statistics of the combined sentence lengths."""
    print('Combined sentence length stats of test set')
    if not lengths:
        print('no test examples')
        return
    print("mean:", round(statistics.mean(lengths), 2))
    print("median:", statistics.median(lengths))
    # statistics.stdev needs at least two data points; one length has no spread.
    stdev = statistics.stdev(lengths) if len(lengths) > 1 else 0.0
    print("standard deviation", round(stdev, 2))
    print('minimal length', min(lengths))
    print('maximum length', max(lengths))


def sent_length_performance(vocab, model, MODEL_FLAG, printing=True, short_max=16, long_min=31):
    """Measure test-set accuracy bucketed by combined premise + hypothesis length.

    An example is "short" if its combined length is ``<= short_max``, "long"
    if it is ``>= long_min``, and "medium" otherwise. The short check is
    applied first.

    Args:
        vocab: vocabulary forwarded to ``make_prediction``.
        model: model forwarded to ``make_prediction``.
        MODEL_FLAG: model type forwarded to ``make_prediction``.
        printing: if True, print length statistics and per-bucket accuracy.
        short_max: inclusive upper bound of the short bucket.
        long_min: inclusive lower bound of the long bucket.

    Returns:
        A 9-tuple ``(short_correct, short_total, short_acc, medium_correct,
        medium_total, medium_acc, long_correct, long_total, long_acc)``.
        An accuracy is NaN when its bucket has no examples.
    """
    test_split = preprocess(split='test')
    lengths = []
    counts = {bucket: [0, 0] for bucket in _LENGTH_BUCKETS}  # bucket -> [correct, total]
    for example in test_split:
        # tokenize=False: test sentences are presumably pre-tokenized by preprocess.
        prediction, _, length1, length2 = make_prediction(
            example['sentence_1'], example['sentence_2'], vocab, model, MODEL_FLAG,
            printing=False, tokenize=False)
        combined_length = length1[0] + length2[0]
        lengths.append(combined_length)
        if combined_length <= short_max:
            bucket = 'short'
        elif combined_length >= long_min:
            bucket = 'long'
        else:
            bucket = 'medium'
        counts[bucket][0] += int(prediction == example['gold_label'])
        counts[bucket][1] += 1

    if printing:
        _print_length_stats(lengths)
        for bucket in _LENGTH_BUCKETS:
            correct, total = counts[bucket]
            if total:
                accuracy = f'{round(correct / float(total) * 100, 2)} %'
            else:
                accuracy = 'n/a (no examples)'
            print(f'{MODEL_FLAG} accuracy on {bucket} length input: {accuracy}')

    result = []
    for bucket in _LENGTH_BUCKETS:
        correct, total = counts[bucket]
        result.extend((correct, total, _safe_ratio(correct, total)))
    return tuple(result)
  
# Test-split accuracy per combined-length bucket (short: <= 16 tokens, long: >= 31, medium: in between).
sent_length_performance(vocab, model, MODEL_FLAG)

done reading test json
Combined sentence length stats of test set
mean: 23.64
standard deviation 7.83
minimal length 6
maximum length 73
lstm accuracy on short length input: 82.22 %
lstm accuracy on medium length input: 80.83 %
lstm accuracy on long length input: 76.06 %


(1378,
 1676,
 0.8221957040572793,
 5161,
 6385,
 0.8083007047768207,
 1341,
 1763,
 0.7606352807714124)