# MultiNLI Dataset

In [1]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import string
import spacy
from nltk.corpus import stopwords
from tqdm import tqdm
import ipywidgets as widgets

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from collections import Counter
import pickle as pkl
import random
import pdb
from gensim.models import FastText

import os, re, csv, math, codecs
import io

from Encoder import CNN_Encoder, RNN_Encoder, RNN_Encoder_element_wise, CNN_Encoder_element_wise
from train_test import test_model, train_model
from NLI_DataLoader import NLI_Dataset, NLI_collate_func


In [2]:
random.seed(134)

PAD_IDX = 0
UNK_IDX = 1
BATCH_SIZE = 32
data_dir = './hw2_data/'

MAX_SENTENCE1_LENGTH = 32
MAX_SENTENCE2_LENGTH = 18

### Tokenize dataset

In [3]:
tokenizer = spacy.load('en_core_web_sm')
punctuations = string.punctuation

def tokenize(sent):
    tokens = tokenizer(sent)
    return [token.text.lower() for token in tokens if (token.text not in punctuations)]

def tokenize_dataset(dataset):
    token_dataset = []
    all_tokens = []
    
    for sample in dataset:
        tokens = tokenize(sample)
        token_dataset.append(tokens)
        all_tokens += tokens

    return token_dataset, all_tokens

In [4]:
def token2index_dataset(tokens_data, token2id):
    indices_data = []
    for tokens in tokens_data:
        index_list = [token2id[token] if token in token2id else UNK_IDX for token in tokens]
        indices_data.append(index_list)
    return indices_data

In [5]:
mnli_train = pd.read_csv(data_dir+'mnli_train.tsv', sep="\t", index_col = False)
mnli_val = pd.read_csv(data_dir+'mnli_val.tsv', sep="\t", index_col=False)
mnli_genre = mnli_train.genre.unique().tolist()
print('there are %d genres in MNLI:'%len(mnli_genre), mnli_genre)
target_dic = {'neutral':1,'entailment':2, 'contradiction':0 }

there are 5 genres in MNLI: ['telephone', 'fiction', 'slate', 'government', 'travel']


In [10]:
mnli_train_tokens = {'telephone':{}, 'fiction':{},'slate':{}, 'government':{}, 'travel':{}}
mnli_val_tokens = {'telephone':{}, 'fiction':{},'slate':{}, 'government':{}, 'travel':{}}
for genre in mnli_genre:
    mnli_train_tokens[genre]['sent1'], _ = tokenize_dataset(mnli_train[mnli_train.genre == genre].sentence1)
    mnli_train_tokens[genre]['sent2'], _ = tokenize_dataset(mnli_train[mnli_train.genre == genre].sentence2)
    mnli_train_tokens[genre]['label'] = [target_dic[j] for j in mnli_train[mnli_train.genre == genre].label]
    mnli_val_tokens[genre]['sent1'], _ = tokenize_dataset(mnli_val[mnli_val.genre == genre].sentence1)
    mnli_val_tokens[genre]['sent2'], _ = tokenize_dataset(mnli_val[mnli_val.genre == genre].sentence2)
    mnli_val_tokens[genre]['label'] = [target_dic[j] for j in mnli_val[mnli_val.genre == genre].label]
    
pkl.dump(mnli_train_tokens, open(data_dir+'mnli_train_tokens.p', 'wb'))
pkl.dump(mnli_val_tokens, open(data_dir+'mnli_val_tokens.p', 'wb'))

In [6]:
mnli_train_tokens = pkl.load(open(data_dir+'mnli_train_tokens.p', 'rb'))
mnli_val_tokens = pkl.load(open(data_dir+'mnli_val_tokens.p', 'rb'))

In [7]:
embeddings = pkl.load(open(data_dir+'embeddings.p', 'rb'))

def build_vocab(embeddings, tokens_data,  max_vocab_size = len(embeddings.keys())):

    #if max_vocab_size:
    all_tokens = [i for tokens in tokens_data for i in tokens]
    token_counter = Counter(all_tokens)
    vocab, count = zip(*token_counter.most_common(max_vocab_size))
    vocab = [i for i in vocab if i in embeddings.keys() ]
    vocab_len = len(vocab)
    print('length of vocabulary:',vocab_len) 

    id2token = list(vocab)                     
    token2id = dict(zip(vocab, range(2,2+len(vocab)))) 
    id2token = ['<pad>', '<unk>'] + id2token
    token2id['<pad>'] = PAD_IDX 
    token2id['<unk>'] = UNK_IDX
    id2vector = np.zeros((len(id2token), 300))
    for i, word in enumerate(id2token):
        try:
            id2vector[i] = embeddings[word]
        except KeyError:
            id2vector[i] = np.random.normal(scale = 0.1, size = (300,))

    return id2token, token2id, id2vector

train_sent1_tokens = pkl.load(open(data_dir+'train_sent1_tokens.p', 'rb'))
train_sent2_tokens = pkl.load(open(data_dir+'train_sent2_tokens.p', 'rb'))
# vocalbulary generated by SNLI Data
id2token, token2id, id2vector = build_vocab(embeddings, train_sent1_tokens+train_sent2_tokens)
vocab_size = len(id2token)

length of vocabulary: 18105


In [72]:
mnli_train_indices = {'telephone':{}, 'fiction':{},'slate':{}, 'government':{}, 'travel':{}}
mnli_val_indices = {'telephone':{}, 'fiction':{},'slate':{}, 'government':{}, 'travel':{}}
for genre in mnli_genre:
    mnli_train_indices[genre]['sent1'] = token2index_dataset(mnli_train_tokens[genre]['sent1'], token2id)
    mnli_train_indices[genre]['sent2'] = token2index_dataset(mnli_train_tokens[genre]['sent2'], token2id)
    mnli_val_indices[genre]['sent1'] = token2index_dataset(mnli_val_tokens[genre]['sent1'], token2id)
    mnli_val_indices[genre]['sent2'] = token2index_dataset(mnli_val_tokens[genre]['sent2'], token2id)
pkl.dump(mnli_train_indices, open(data_dir+'mnli_train_indices.p', 'wb'))  
pkl.dump(mnli_val_indices, open(data_dir+'mnli_val_indices.p', 'wb'))  


In [8]:
mnli_train_indices = pkl.load(open(data_dir+'mnli_train_indices.p', 'rb'))
mnli_val_indices = pkl.load(open(data_dir+'mnli_val_indices.p', 'rb'))

In [9]:
mnli_train_dataset = {'telephone':{}, 'fiction':{},'slate':{}, 'government':{}, 'travel':{}}
mnli_val_dataset = {'telephone':{}, 'fiction':{},'slate':{}, 'government':{}, 'travel':{}}
mnli_train_loader = {'telephone':{}, 'fiction':{},'slate':{}, 'government':{}, 'travel':{}}
mnli_val_loader = {'telephone':{}, 'fiction':{},'slate':{}, 'government':{}, 'travel':{}}

BATCH_SIZE = 32
for genre in mnli_genre:
    mnli_train_dataset[genre] = NLI_Dataset(mnli_train_indices[genre]['sent1'], mnli_train_indices[genre]['sent2'], mnli_train_tokens[genre]['label'] )
    mnli_train_loader[genre] = torch.utils.data.DataLoader(dataset=mnli_train_dataset[genre], 
                                           batch_size=BATCH_SIZE,
                                           collate_fn=NLI_collate_func,
                                           shuffle=True)
    mnli_val_dataset[genre] = NLI_Dataset(mnli_val_indices[genre]['sent1'], mnli_val_indices[genre]['sent2'], mnli_val_tokens[genre]['label'] )
    mnli_val_loader[genre] = torch.utils.data.DataLoader(dataset=mnli_val_dataset[genre], 
                                           batch_size=BATCH_SIZE,
                                           collate_fn=NLI_collate_func,
                                           shuffle=True)

### Evaluating on MultiNLI

In [10]:
## Load best model
model_RNN = pkl.load(open('./models/best_RNN.sav', 'rb'))
model_CNN = pkl.load(open('./models/best_CNN.sav', 'rb'))

In [11]:
train_sent1_tokens = pkl.load(open(data_dir+'train_sent1_tokens.p', 'rb'))
train_sent2_tokens = pkl.load(open(data_dir+'train_sent2_tokens.p', 'rb'))
train_target = pkl.load(open(data_dir+'train_target.p', 'rb'))

val_sent1_tokens = pkl.load(open(data_dir+'val_sent1_tokens.p', 'rb'))
val_sent2_tokens = pkl.load(open(data_dir+'val_sent2_tokens.p', 'rb'))
val_target = pkl.load(open(data_dir+'val_target.p', 'rb'))

train_sent1_indices = token2index_dataset(train_sent1_tokens, token2id)
train_sent2_indices = token2index_dataset(train_sent2_tokens, token2id)
val_sent1_indices = token2index_dataset(val_sent1_tokens, token2id)
val_sent2_indices  = token2index_dataset(val_sent2_tokens, token2id)

BATCH_SIZE = 32
train_dataset = NLI_Dataset(train_sent1_indices, train_sent2_indices, train_target )
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=BATCH_SIZE,
                                           collate_fn=NLI_collate_func,
                                           shuffle=True)

val_dataset = NLI_Dataset(val_sent1_indices, val_sent2_indices,val_target)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, 
                                           batch_size=BATCH_SIZE,
                                           collate_fn=NLI_collate_func,
                                           shuffle=True)

In [12]:
from tabulate import tabulate

val_acc_table = {'RNN_val_acc':[], 'CNN_val_acc':[] }

for genre in mnli_genre:
    val_acc_table['RNN_val_acc'].append(test_model(mnli_val_loader[genre], model_RNN)[0])
    val_acc_table['CNN_val_acc'].append(test_model(mnli_val_loader[genre], model_CNN)[0])
    
val_acc_table = pd.DataFrame(val_acc_table, index = mnli_genre)
print (tabulate(val_acc_table, floatfmt=".2f", headers = val_acc_table.columns))
       

              CNN_val_acc    RNN_val_acc
----------  -------------  -------------
telephone           41.00          45.37
fiction             42.71          45.93
slate               42.02          43.51
government          43.60          43.31
travel              43.18          43.28


### Fine-tuning on MultiNLI 

In [13]:
RNN_fine_tune_acc = []
for genre in mnli_genre:
    print('Fine tuning on genre %s:'%genre)
    _, _, _, _ = train_model(model_RNN,  mnli_train_loader[genre], mnli_val_loader[genre], 3e-3, 5, 'Fine_tuning_RNN', True, False, False)    
    print('-'*100)

Fine tuning on genre telephone:
Fine_tuning_RNN
number of trainable parameters:487603
Val Accuracy:52.8%
----------------------------------------------------------------------------------------------------
Fine tuning on genre fiction:
Fine_tuning_RNN
number of trainable parameters:487603
Val Accuracy:49.5%
----------------------------------------------------------------------------------------------------
Fine tuning on genre slate:
Fine_tuning_RNN
number of trainable parameters:487603
Val Accuracy:44.7%
----------------------------------------------------------------------------------------------------
Fine tuning on genre government:
Fine_tuning_RNN
number of trainable parameters:487603
Val Accuracy:53.8%
----------------------------------------------------------------------------------------------------
Fine tuning on genre travel:
Fine_tuning_RNN
number of trainable parameters:487603
Val Accuracy:51.0%
-------------------------------------------------------------------------------

In [21]:
fine_tune_acc = []
for genre in mnli_genre:
    print('Fine tuning on genre %s:'%genre)
    _, _, _, _ = train_model(model_CNN,  mnli_train_loader[genre], mnli_val_loader[genre], 
                                     3e-3, 5, 'Fine-tuning Model', False, True, False)
    print('-'*100)


Fine tuning on genre telephone:
Fine-tuning Model
number of trainable parameters:250603
Val Accuracy:50.3%
----------------------------------------------------------------------------------------------------
Fine tuning on genre fiction:
Fine-tuning Model
number of trainable parameters:250603
Val Accuracy:47.7%
----------------------------------------------------------------------------------------------------
Fine tuning on genre slate:
Fine-tuning Model
number of trainable parameters:250603
Val Accuracy:46.3%
----------------------------------------------------------------------------------------------------
Fine tuning on genre government:
Fine-tuning Model
number of trainable parameters:250603
Val Accuracy:49.7%
----------------------------------------------------------------------------------------------------
Fine tuning on genre travel:
Fine-tuning Model
number of trainable parameters:250603
Val Accuracy:48.3%
---------------------------------------------------------------------

In [15]:
val_acc_table['RNN_val_acc_fine_tune'] = [52.8, 49.5, 44.7, 53.8, 51.0 ]
val_acc_table['CNN_val_acc_fine_tune'] = [50.3, 47.7, 46.3, 49.7, 48.3 ]
print (tabulate(val_acc_table, floatfmt=".2f", headers = val_acc_table.columns))

              CNN_val_acc    RNN_val_acc    RNN_val_acc_fine_tune    CNN_val_acc_fine_tune
----------  -------------  -------------  -----------------------  -----------------------
telephone           41.00          45.37                    52.80                    50.30
fiction             42.71          45.93                    49.50                    47.70
slate               42.02          43.51                    44.70                    46.30
government          43.60          43.31                    53.80                    49.70
travel              43.18          43.28                    51.00                    48.30
