In [1]:
import random
import pickle as pkl

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

import datetime as dt

import matplotlib.pyplot as plt
%matplotlib inline

random.seed(134)

In [2]:
from funcs import readData, loadEmbeddings, hwDataset, hwCollateFn, RNNEncoder, CNNEncoder, testModel

train_data, val_data, char2id, id2char, MAX_X1, MAX_X2, label2id, id2label = readData()
weights_tensor = loadEmbeddings(char2id)

BATCH_SIZE = 200

train_dataset = hwDataset(train_data, char2id, label2id)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=BATCH_SIZE,
                                           collate_fn=hwCollateFn,
                                           shuffle=True)
val_dataset = hwDataset(val_data, char2id, label2id)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                           batch_size=BATCH_SIZE,
                                           collate_fn=hwCollateFn,
                                           shuffle=True)

In [18]:
for data in val_data:
    if ' '.join(data[0]) == "Three people and a white dog are sitting in the sand on a beach .":
        print(data)
        break

[['Three', 'people', 'and', 'a', 'white', 'dog', 'are', 'sitting', 'in', 'the', 'sand', 'on', 'a', 'beach', '.'], ['Three', 'dogs', 'and', 'a', 'person', 'are', 'sitting', 'in', 'the', 'snow', '.'], 'contradiction']


In [22]:
for data in val_data:
    if ' '.join(data[0]) == "A young woman seated at a table on what appears to be a backyard deck holds a toddler , giving him a toy or bottle of some sort , while smiling into the camera .":
        print(data)
        break

[['A', 'young', 'woman', 'seated', 'at', 'a', 'table', 'on', 'what', 'appears', 'to', 'be', 'a', 'backyard', 'deck', 'holds', 'a', 'toddler', ',', 'giving', 'him', 'a', 'toy', 'or', 'bottle', 'of', 'some', 'sort', ',', 'while', 'smiling', 'into', 'the', 'camera', '.'], ['The', 'woman', 'is', 'changing', 'the', 'boy', "'s", 'diaper', '.'], 'contradiction']


In [4]:
NUM_EPOCHS = 10

DATA = {'weights_tensor':weights_tensor,
        'train_loader':train_loader,
        'val_loader':val_loader}

CNN_PARAMS = {'num_epochs':NUM_EPOCHS,
          'hidden_size':250,
          'weight_decay':0.0001,
          'vocab_size':len(id2char),
          'kernel_size':3}

RNN_PARAMS = {'num_epochs':NUM_EPOCHS,
          'hidden_size':250,
          'weight_decay':0,
          'vocab_size':len(id2char),
          'kernel_size':3}


rnn_model = RNNEncoder(DATA, RNN_PARAMS, num_layers=1, num_classes=3)
rnn_model.load_state_dict(torch.load('rnn.pt'))
cnn_model = CNNEncoder(DATA, CNN_PARAMS, num_layers=2, num_classes=3)
cnn_model.load_state_dict(torch.load('cnn.pt'))

In [5]:
miss = []
hit = []

for j, (x1, x2, labels) in enumerate(val_loader):
    if j >= 2:
        break
    outputs = F.softmax(cnn_model(x1, x2), dim=1)
    predicted = outputs.max(1, keepdim=True)[1]
    for c, b in enumerate(zip(labels, predicted)):
        n, m = b
        if n != m:
            if len(miss) < 3:
                miss.append((' '.join([id2char[x] for x in x1[c].tolist() if x > 1]), 
                             ' '.join([id2char[x] for x in x2[c].tolist() if x > 1]),
                             int(n), int(m)))
        else:
            if len(hit) < 3:
                hit.append((' '.join([id2char[x] for x in x1[c].tolist() if x > 1]), 
                             ' '.join([id2char[x] for x in x2[c].tolist() if x > 1]),
                             int(n), int(m)))

In [9]:
miss

[('Three people and a white dog are sitting in the sand on a beach .',
  'Three dogs and a person are sitting in the snow .',
  2,
  0),
 ('A young woman seated at a table on what appears to be a backyard deck holds a toddler , giving him a toy or bottle of some sort , while smiling into the camera .',
  "The woman is changing the boy 's diaper .",
  2,
  1),
 ('A husky and a black cat nuzzling .', 'A dog and cat are friendly .', 0, 2)]

In [10]:
hit

[('A soccer player wearing white shorts and an orange and green shirt holds the ball while being guarded by another soccer player in a blue uniform .',
  'A football player throws a touchdown pass .',
  2,
  2),
 ('Old woman chasing away two lambs with a broom .',
  'A woman is chasing two turtles with a mop .',
  2,
  2),
 ('A line of people waiting outside The Magpie cafe during the day .',
  'A man makes a sandwich .',
  2,
  2)]

In [23]:
mnli_val = pkl.load(open("hw2_data/mnli_val.p", "rb"))

In [24]:
genre_rnn = []
genre_cnn = []

for genre in ['fiction', 'telephone', 'slate', 'government', 'travel']:
    genre_data = [[v[0], v[1], v[2]] for v in mnli_val if v[3] == genre]
    genre_dataset = hwDataset(genre_data, char2id, label2id)
    genre_loader = torch.utils.data.DataLoader(dataset=genre_dataset,
                                               batch_size=BATCH_SIZE,
                                               collate_fn=hwCollateFn,
                                               shuffle=True)
    
    genre_rnn.append(testModel(genre_loader, rnn_model))
    genre_cnn.append(testModel(genre_loader, cnn_model))

In [25]:
genre_rnn

[31.758793969849247,
 33.53233830845771,
 29.74051896207585,
 34.84251968503937,
 33.70672097759674]

In [30]:
with open("rnnGenre.p", "rb") as f:
    rnn_results = pkl.load(f)

[(x[0], max(x[1]))  for x in rnn_results]

[('fiction', 45.59436913451512),
 ('telephone', 43.98126463700234),
 ('slate', 41.97714853452558),
 ('government', 41.437033221735774),
 ('travel', 41.68130489335006)]

In [31]:
with open("cnnGenre.p", "rb") as f:
    cnn_results = pkl.load(f)

[(x[0], max(x[1]))  for x in cnn_results]

[('fiction', 53.206465067778936),
 ('telephone', 53.25526932084309),
 ('slate', 50.81967213114754),
 ('government', 53.56682977079578),
 ('travel', 53.14930991217064)]

In [None]:
from IPython.display import Image
Image(filename="figures/cnn_wd_1e4.png", width=200, height=200)