In [108]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from collections import Counter
import pickle as pkl
import random
import pdb
import csv
import matplotlib.pyplot as plt
import sys
EMBED_SIZE = 300
VOCAB_SIZE = 50000
PAD_IDX = 0
# RNN Encoder
class RNN(nn.Module):
    def __init__(self, emb_size, hidden_size, vocab_size, hid_dim, is_concat, is_dropout):

        super(RNN, self).__init__()

        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=PAD_IDX)
        self.embedding.weight.data.copy_(torch.from_numpy(np.array(embedding_dict).copy()))
        self.bi_gru = nn.GRU(emb_size, hidden_size, num_layers=1, batch_first=True,bidirectional=True)
        self.linear2 = nn.Linear(hid_dim, 3)
        self.is_concat = is_concat
        if is_concat:
            self.linear1 = nn.Linear(hidden_size*2, hid_dim)
        else:
            self.linear1 = nn.Linear(hidden_size*1, hid_dim)
        self.is_dropout = is_dropout
        if self.is_dropout == True:
            self.dropout = nn.Dropout(0.5)
    def init_hidden(self, batch_size):
        # Function initializes the activation of recurrent neural net at timestep 0
        # Needs to be in format (num_layers, batch_size, hidden_size)
        hidden = torch.zeros(2, batch_size, self.hidden_size).to(device)

        return hidden
    def encode(self, x):
        # lengths = MAX_SENTENCE_LENGTH - x.eq(0).long().sum(1).squeeze()
        # _, idx_sort = torch.sort(lengths, dim=0, descending=True)
        # _, idx_unsort = torch.sort(idx_sort, dim=0)
        # lengths = lengths[idx_sort]
        # x = x.index_select(0, idx_sort)
        batch_size, seq_len = x.size()
        self.hidden = self.init_hidden(batch_size)
        embed = self.embedding(x)
        m = (x == 1)
        m = m.unsqueeze(2).repeat(1, 1, EMBED_SIZE).type(torch.FloatTensor).to(device)
        embed = m * embed + (1-m) * embed.clone().detach()
        # embed = torch.nn.utils.rnn.pack_padded_sequence(embed, lengths.cpu().numpy(), batch_first=True)
        output, hidden = self.bi_gru(embed, self.hidden)
        hidden = torch.sum(hidden, dim = 0)
        # hidden = hidden.index_select(0, idx_unsort)
        return hidden
    def forward(self, prem, hyp):
        batch_size, seq_len = prem.size()
        # encode premise
        prem_code = self.encode(prem)
        # encode hypothesis
        hyp_code = self.encode(hyp)
        # concat or multiply
        if self.is_concat:
            code = torch.cat((prem_code,hyp_code), dim=1)
        else:
            code = prem_code * hyp_code
        code = self.linear1(code)
        if self.is_dropout:
            code = self.dropout(code)
        code = F.relu(code)
        code = self.linear2(code)
        return code
snli_val_data_tokens = pkl.load(open("snli_val_data_tokens.p", "rb"))
snli_train_data_tokens = pkl.load(open("snli_train_data_tokens.p", "rb"))
all_train_tokens = pkl.load(open("all_train_tokens.p", "rb"))
embedding_dict = pkl.load(open("embedding_dict.p", "rb"))
hidden_size = 400
hid_dim = 300
is_concat = False
is_wd = False
is_dropout = True
kernel_size = 3
model = RNN(EMBED_SIZE, hidden_size, VOCAB_SIZE + 2, hid_dim, is_concat, is_dropout)
model.load_state_dict(torch.load('rnn_400_300_False_3_False_True.pth', map_location=lambda storage, loc: storage))

In [19]:
# create the dictionary and all train tokens
VOCAB_SIZE = 50000
EMBED_SIZE = 300
# load data
snli_val_data_tokens = pkl.load(open("snli_val_data_tokens.p", "rb"))
snli_train_data_tokens = pkl.load(open("snli_train_data_tokens.p", "rb"))
all_train_tokens = pkl.load(open("all_train_tokens.p", "rb"))
embedding_dict = pkl.load(open("embedding_dict.p", "rb"))
# save index 0 for unk and 1 for pad
PAD_IDX = 0
UNK_IDX = 1
BATCH_SIZE = 100
# encode data loader
class EncodeDataset(Dataset):
    """
    Class that represents a train/validation/test dataset that's readable for PyTorch
    Note that this class inherits torch.utils.data.Dataset
    """

    def __init__(self, prem_data_list, hyp_data_list, target_list):
        """
        @param data_list: list of newsgroup tokens
        @param target_list: list of newsgroup targets

        """
        self.prem_data_list = prem_data_list
        self.hyp_data_list = hyp_data_list
        self.target_list = target_list

    def __len__(self):
        return len(self.prem_data_list)

    def __getitem__(self, key):
        """
        Triggered when you call dataset[i]
        """
        prem_token_idx = self.prem_data_list[key][:MAX_SENTENCE_LENGTH]
        hyp_token_idx = self.hyp_data_list[key][:MAX_SENTENCE_LENGTH]
        label = self.target_list[key]
        return [prem_token_idx, hyp_token_idx, label]


def encode_collate_func(batch):
    """
    Customized function for DataLoader that dynamically pads the batch so that all
    data have the same length
    """
    prem_data_list = []
    hyp_data_list = []
    label_list = []
    length_list = []
    # print("collate batch: ", batch[0][0])
    # batch[0][0] = batch[0][0][:MAX_SENTENCE_LENGTH]
    for datum in batch:
        label_list.append(datum[2])
    # padding
    for datum in batch:
        prem_padded_vec = np.pad(np.array(datum[0]),
                                 pad_width=((0, MAX_SENTENCE_LENGTH - len(datum[0]))),
                                 mode="constant", constant_values=0)
        hyp_padded_vec = np.pad(np.array(datum[1]),
                                pad_width=((0, MAX_SENTENCE_LENGTH - len(datum[1]))),
                                mode="constant", constant_values=0)
        prem_data_list.append(prem_padded_vec)
        hyp_data_list.append(hyp_padded_vec)
    return [torch.from_numpy((np.array(prem_data_list))), torch.from_numpy(np.array(hyp_data_list)),
            torch.LongTensor(label_list)]
def build_vocab(all_tokens):
    # Returns:
    # id2token: list of tokens, where id2token[i] returns token that corresponds to token i
    # token2id: dictionary where keys represent tokens and corresponding values represent indices
    id2token = list(all_tokens)
    token2id = dict(zip(all_tokens, range(2,2+len(all_tokens))))
    id2token = ['<pad>', '<unk>']  + id2token
    token2id['<pad>'] = PAD_IDX
    token2id['<unk>'] = UNK_IDX
    return token2id, id2token

# convert token to id in the dataset
def token2index_dataset(tokens_data, token2id):
    prem_indices_data = []
    hyp_indices_data = []
    target_indices_data = []
    for tokens in tokens_data:
#         print(tokens[0])
#         print(tokens[1])
        prem_index_list = [token2id[token] if token in token2id else UNK_IDX for token in tokens[0]]
        hyp_index_list = [token2id[token] if token in token2id else UNK_IDX for token in tokens[1]]
        prem_indices_data.append(prem_index_list)
        hyp_indices_data.append(hyp_index_list)
        target_indices_data.append(tokens[2])
    return prem_indices_data, hyp_indices_data, target_indices_data
token2id, id2token = build_vocab(all_train_tokens)
val_prem_data_indices, val_hyp_data_indices, val_target_data_indices = token2index_dataset(snli_val_data_tokens, token2id)

val_dataset = EncodeDataset(val_prem_data_indices, val_hyp_data_indices, val_target_data_indices)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                           batch_size=BATCH_SIZE,
                                           collate_fn=encode_collate_func,
                                           shuffle=False)


In [22]:
MAX_SENTENCE_LENGTH = 50
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
for i, (prem_data, hyp_data, labels) in enumerate(val_loader):
    outputs = F.softmax(model(prem_data, hyp_data), dim=1)
    predicted = outputs.max(1, keepdim=True)[1]
    print(predicted.view_as(labels))
    print(labels)
    break

tensor([0, 0, 1, 1, 1, 0, 1, 2, 0, 2, 1, 0, 2, 0, 1, 1, 2, 2, 1, 2, 1, 2, 1, 0,
        0, 2, 1, 0, 1, 0, 0, 2, 1, 0, 0, 0, 1, 0, 0, 0, 2, 1, 1, 0, 2, 2, 1, 1,
        2, 2, 2, 1, 2, 1, 0, 2, 1, 0, 0, 1, 0, 0, 2, 1, 0, 0, 0, 1, 1, 2, 0, 0,
        1, 0, 2, 0, 1, 1, 1, 0, 2, 1, 1, 2, 0, 1, 1, 0, 1, 2, 0, 0, 0, 0, 2, 2,
        0, 1, 2, 1])
tensor([0, 1, 1, 1, 1, 1, 1, 0, 2, 0, 1, 1, 2, 0, 1, 1, 2, 2, 1, 1, 1, 2, 1, 1,
        0, 2, 0, 1, 1, 2, 2, 2, 0, 2, 0, 0, 2, 0, 1, 0, 0, 0, 1, 1, 2, 2, 1, 1,
        2, 0, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 1, 0, 2, 2,
        1, 1, 0, 2, 2, 1, 1, 1, 0, 1, 1, 1, 2, 0, 2, 1, 1, 0, 1, 0, 2, 0, 2, 0,
        0, 1, 2, 0])


In [88]:
# Correct predictions
# Read Doc
directory = '/Users/xintianhan/Downloads/nlp/hw2_data/'
dirc_snli_train = directory + 'snli_train.tsv'
dirc_snli_val = directory + 'snli_val.tsv'
dirc_mnli_train = directory + 'mnli_train.tsv'
dirc_mnli_val = directory + 'mnli_val.tsv'
dirc_dict = directory + 'wiki-news-300d-1M.vec'
def get_label(sent):
    if sent == 'contradiction':
        return 0
    elif sent == 'entailment':
        return 1
    elif sent == 'neutral':
        return 2
    else:
        print('invalid input!')
def tokenize_dataset(dirc):
    token_dataset = []
    # we are keeping track of all tokens in dataset 
    # in order to create vocabulary later
    flag = 0
    with open(dirc) as tsvfile:
        reader = csv.reader(tsvfile, delimiter='\t')
        for row in reader:
            # skip the first line by flag
            if flag == 0:
                flag = 1
                continue
            prem = row[0].split()
            hyp = row[1].split()
            label = get_label(row[2])
            token_dataset.append([prem,hyp,label])
    return token_dataset

In [28]:
snli_val_data_tokens = tokenize_dataset(dirc_snli_val)

In [50]:
mnli_val_data_tokens = tokenize_dataset(dirc_mnli_val)

In [46]:
flag = 0
genes = []
with open(dirc_mnli_val) as tsvfile:
    reader = csv.reader(tsvfile, delimiter='\t')
    for row in reader:
        # skip the first line by flag
        if flag == 0:
            flag = 1
            continue
        gene = row[3]
        genes.append(gene)

In [48]:
all_genes = []
for item in genes:
    if item not in all_genes:
        all_genes.append(item)

In [49]:
all_genes

['fiction', 'telephone', 'slate', 'government', 'travel']

In [91]:
mnli_val_data_tokens = pkl.load(open("mnli_val_data_tokens.p", "rb"))
token2id, id2token = build_vocab(all_train_tokens)
val_prem_data_indices, val_hyp_data_indices, val_target_data_indices = token2index_dataset(mnli_val_data_tokens, token2id)
val_prem_data_indices = np.array(val_prem_data_indices)
val_hyp_data_indices = np.array(val_hyp_data_indices)
val_target_data_indices = np.array(val_target_data_indices)

In [95]:
def test_model(loader, model, criterion):
    """
    Help function that tests the model's performance on a dataset
    @param: loader - data loader for the dataset to test against
    """
    correct = 0
    losses = 0
    total = 0.0
    model.eval()
    for prem_data, hyp_data, labels in loader:
        prem_data_batch, hyp_data_batch, label_batch = prem_data.to(device), hyp_data.to(device),labels.to(device)
        outputs = F.softmax(model(prem_data_batch, hyp_data_batch), dim=1)
        predicted = outputs.max(1, keepdim=True)[1]
        loss = criterion(outputs, label_batch)
        total += labels.size(0)
        correct += predicted.eq(label_batch.view_as(predicted)).sum().item()
        losses += loss.item()
    return (100 * correct / total), losses / total
criterion = torch.nn.CrossEntropyLoss()
for i in range(len(all_genes)):
    print(all_genes[i])
    cur_val_prem_data_indices = val_prem_data_indices[genes == all_genes[i]]
    cur_val_hyp_data_indices = val_hyp_data_indices[genes == all_genes[i]]
    cur_val_target_data_indices = val_target_data_indices[genes == all_genes[i]]
    print(len(cur_val_prem_data_indices))
    val_dataset = EncodeDataset(cur_val_prem_data_indices, cur_val_hyp_data_indices, cur_val_target_data_indices)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                               batch_size=BATCH_SIZE,
                                               collate_fn=encode_collate_func,
                                               shuffle=False)
    acc, loss = test_model(val_loader, model, criterion)
    print(acc)

fiction
995
41.20603015075377
telephone
1005
41.99004975124378
slate
1002
38.12375249500998
government
1016
39.468503937007874
travel
982
39.71486761710794


In [98]:
snli_val_data_tokens = pkl.load(open("snli_val_data_tokens.p", "rb"))
token2id, id2token = build_vocab(all_train_tokens)
val_prem_data_indices, val_hyp_data_indices, val_target_data_indices = token2index_dataset(snli_val_data_tokens, token2id)
def test_model(loader, model, criterion):
    """
    Help function that tests the model's performance on a dataset
    @param: loader - data loader for the dataset to test against
    """
    correct = 0
    losses = 0
    total = 0.0
    model.eval()
    for prem_data, hyp_data, labels in loader:
        prem_data_batch, hyp_data_batch, label_batch = prem_data.to(device), hyp_data.to(device),labels.to(device)
        outputs = F.softmax(model(prem_data_batch, hyp_data_batch), dim=1)
        predicted = outputs.max(1, keepdim=True)[1]
        loss = criterion(outputs, label_batch)
        total += labels.size(0)
        correct += predicted.eq(label_batch.view_as(predicted)).sum().item()
        losses += loss.item()
    return (100 * correct / total), losses / total
criterion = torch.nn.CrossEntropyLoss()
val_dataset = EncodeDataset(val_prem_data_indices, val_hyp_data_indices, val_target_data_indices)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                           batch_size=BATCH_SIZE,
                                           collate_fn=encode_collate_func,
                                           shuffle=False)
acc, loss = test_model(val_loader, model, criterion)
print(acc)

65.5


In [101]:
# CNN Encoder
class CNN(nn.Module):
    def __init__(self, emb_size, hidden_size, vocab_size, kernel_size, hid_dim, is_concat, is_dropout):

        super(CNN, self).__init__()

        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=PAD_IDX)
        self.embedding.weight.data.copy_(torch.from_numpy(np.array(embedding_dict).copy()))
        # self.embedding.weight.requires_grad = False
        if kernel_size == 3:
            self.conv1 = nn.Conv1d(emb_size, hidden_size, kernel_size, padding=1)
            self.conv2 = nn.Conv1d(hidden_size, hidden_size, kernel_size, padding=1)
        else:
            self.conv1 = nn.Conv1d(emb_size, hidden_size, kernel_size, padding=2)
            self.conv2 = nn.Conv1d(hidden_size, hidden_size, kernel_size, padding=2)
        if is_concat:
            self.linear1 = nn.Linear(hidden_size*2, hid_dim)
        else:
            self.linear1 = nn.Linear(hidden_size, hid_dim)
        self.linear2 = nn.Linear(hid_dim, 3)
        self.is_concat = is_concat
        self.is_dropout = is_dropout
        if self.is_dropout == True:
            self.dropout = nn.Dropout(0.5)
    def encode(self, x):
        batch_size, seq_len = x.size()
        embed = self.embedding(x)
        m = (x == 1)
        m = m.unsqueeze(2).repeat(1, 1, EMBED_SIZE).type(torch.FloatTensor).to(device)
        embed = m * embed + (1-m) * embed.clone().detach()
        hidden = self.conv1(embed.transpose(1,2)).transpose(1,2)
        hidden = F.relu(hidden.contiguous().view(-1, hidden.size(-1))).view(batch_size, seq_len, hidden.size(-1))
        hidden = self.conv2(hidden.transpose(1,2)).transpose(1,2)
        hidden = F.relu(hidden.contiguous().view(-1, hidden.size(-1))).view(batch_size, seq_len, hidden.size(-1))
        hidden = torch.max(hidden, 1)[0]
        return hidden
    def forward(self, prem, hyp):
        batch_size, seq_len = prem.size()
        # encode premise
        prem_code = self.encode(prem)
        # encode hypothesis
        hyp_code = self.encode(hyp)
        # concat or multiply
        if self.is_concat:
            code = torch.cat((prem_code,hyp_code), dim=1)
        else:
            code = prem_code * hyp_code
        code = self.linear1(code)
        code = F.relu(code)
        if self.is_dropout:
            code = self.dropout(code)
        code = self.linear2(code)
        return code
snli_val_data_tokens = pkl.load(open("snli_val_data_tokens.p", "rb"))
snli_train_data_tokens = pkl.load(open("snli_train_data_tokens.p", "rb"))
all_train_tokens = pkl.load(open("all_train_tokens.p", "rb"))
embedding_dict = pkl.load(open("embedding_dict.p", "rb"))
hidden_size = 300
hid_dim = 300
is_concat = False
is_wd = False
is_dropout = True
kernel_size = 3
model = CNN(EMBED_SIZE, hidden_size, VOCAB_SIZE + 2, kernel_size, hid_dim, is_concat, is_dropout)
model.load_state_dict(torch.load('cnn_300_300_False_3_False_True.pth', map_location=lambda storage, loc: storage))

In [102]:
snli_val_data_tokens = pkl.load(open("snli_val_data_tokens.p", "rb"))
token2id, id2token = build_vocab(all_train_tokens)
val_prem_data_indices, val_hyp_data_indices, val_target_data_indices = token2index_dataset(snli_val_data_tokens, token2id)
def test_model(loader, model, criterion):
    """
    Help function that tests the model's performance on a dataset
    @param: loader - data loader for the dataset to test against
    """
    correct = 0
    losses = 0
    total = 0.0
    model.eval()
    for prem_data, hyp_data, labels in loader:
        prem_data_batch, hyp_data_batch, label_batch = prem_data.to(device), hyp_data.to(device),labels.to(device)
        outputs = F.softmax(model(prem_data_batch, hyp_data_batch), dim=1)
        predicted = outputs.max(1, keepdim=True)[1]
        loss = criterion(outputs, label_batch)
        total += labels.size(0)
        correct += predicted.eq(label_batch.view_as(predicted)).sum().item()
        losses += loss.item()
    return (100 * correct / total), losses / total
criterion = torch.nn.CrossEntropyLoss()
val_dataset = EncodeDataset(val_prem_data_indices, val_hyp_data_indices, val_target_data_indices)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                           batch_size=BATCH_SIZE,
                                           collate_fn=encode_collate_func,
                                           shuffle=False)
acc, loss = test_model(val_loader, model, criterion)
print(acc)

67.2


In [104]:
mnli_val_data_tokens = pkl.load(open("mnli_val_data_tokens.p", "rb"))
token2id, id2token = build_vocab(all_train_tokens)
val_prem_data_indices, val_hyp_data_indices, val_target_data_indices = token2index_dataset(mnli_val_data_tokens, token2id)
val_prem_data_indices = np.array(val_prem_data_indices)
val_hyp_data_indices = np.array(val_hyp_data_indices)
val_target_data_indices = np.array(val_target_data_indices)
def test_model(loader, model, criterion):
    """
    Help function that tests the model's performance on a dataset
    @param: loader - data loader for the dataset to test against
    """
    correct = 0
    losses = 0
    total = 0.0
    model.eval()
    for prem_data, hyp_data, labels in loader:
        prem_data_batch, hyp_data_batch, label_batch = prem_data.to(device), hyp_data.to(device),labels.to(device)
        outputs = F.softmax(model(prem_data_batch, hyp_data_batch), dim=1)
        predicted = outputs.max(1, keepdim=True)[1]
        loss = criterion(outputs, label_batch)
        total += labels.size(0)
        correct += predicted.eq(label_batch.view_as(predicted)).sum().item()
        losses += loss.item()
    return (100 * correct / total), losses / total
criterion = torch.nn.CrossEntropyLoss()
for i in range(len(all_genes)):
    print(all_genes[i])
    cur_val_prem_data_indices = val_prem_data_indices[genes == all_genes[i]]
    cur_val_hyp_data_indices = val_hyp_data_indices[genes == all_genes[i]]
    cur_val_target_data_indices = val_target_data_indices[genes == all_genes[i]]
    print(len(cur_val_prem_data_indices))
    val_dataset = EncodeDataset(cur_val_prem_data_indices, cur_val_hyp_data_indices, cur_val_target_data_indices)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                               batch_size=BATCH_SIZE,
                                               collate_fn=encode_collate_func,
                                               shuffle=False)
    acc, loss = test_model(val_loader, model, criterion)
    print(acc)

fiction
995
39.19597989949749
telephone
1005
37.21393034825871
slate
1002
39.52095808383233
government
1016
37.59842519685039
travel
982
39.0020366598778


In [105]:
flag = 0
train_genes = []
with open(dirc_mnli_train) as tsvfile:
    reader = csv.reader(tsvfile, delimiter='\t')
    for row in reader:
        # skip the first line by flag
        if flag == 0:
            flag = 1
            continue
        gene = row[3]
        train_genes.append(gene)

In [107]:
train_genes = np.array(train_genes)

In [113]:
# Fine Tune
mnli_train_data_tokens = pkl.load(open("mnli_train_data_tokens.p", "rb"))
token2id, id2token = build_vocab(all_train_tokens)
train_prem_data_indices, train_hyp_data_indices, train_target_data_indices = token2index_dataset(mnli_train_data_tokens, token2id)
train_prem_data_indices = np.array(train_prem_data_indices)
train_hyp_data_indices = np.array(train_hyp_data_indices)
train_target_data_indices = np.array(train_target_data_indices)
for i in range(len(all_genes)):
    print(all_genes[i])
    model = RNN(EMBED_SIZE, hidden_size, VOCAB_SIZE + 2, hid_dim, is_concat, is_dropout)
    model.load_state_dict(torch.load('rnn_400_300_False_3_False_True.pth', map_location=lambda storage, loc: storage))
    learning_rate = 1e-4
    num_epochs = 5 # number epoch to train
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    cur_train_prem_data_indices = train_prem_data_indices[train_genes == all_genes[i]]
    cur_train_hyp_data_indices = train_hyp_data_indices[train_genes == all_genes[i]]
    cur_train_target_data_indices = train_target_data_indices[train_genes == all_genes[i]]
    train_dataset = EncodeDataset(cur_train_prem_data_indices, cur_train_hyp_data_indices, cur_train_target_data_indices)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=BATCH_SIZE,
                                               collate_fn=encode_collate_func,
                                               shuffle=False)
    cur_val_prem_data_indices = val_prem_data_indices[genes == all_genes[i]]
    cur_val_hyp_data_indices = val_hyp_data_indices[genes == all_genes[i]]
    cur_val_target_data_indices = val_target_data_indices[genes == all_genes[i]]
    val_dataset = EncodeDataset(cur_val_prem_data_indices, cur_val_hyp_data_indices, cur_val_target_data_indices)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                               batch_size=BATCH_SIZE,
                                               collate_fn=encode_collate_func,
                                               shuffle=False)
    for epoch in range(num_epochs):
        for j, (prem_data, hyp_data, labels) in enumerate(train_loader):
            sys.stdout.flush()
            model.train()
            prem_data_batch, hyp_data_batch, label_batch = prem_data.to(device), hyp_data.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(prem_data_batch, hyp_data_batch)
            loss = criterion(outputs, label_batch)
            loss.backward()
            optimizer.step()
            # validate every 10 iterations
        val_acc, val_loss= test_model(val_loader, model, criterion)
        train_acc, train_loss = test_model(train_loader, model, criterion)
        print(val_acc)

fiction
44.42211055276382
42.814070351758794
43.81909547738694
44.42211055276382
45.32663316582914
telephone
46.96517412935324
48.25870646766169
48.756218905472636
49.850746268656714
50.74626865671642
slate
39.62075848303393
40.5189620758483
40.7185628742515
41.21756487025948
43.712574850299404
government
49.01574803149607
49.60629921259842
50.39370078740158
52.06692913385827
53.346456692913385
travel
47.759674134419555
48.16700610997963
49.287169042769854
49.08350305498982
49.69450101832994


In [114]:
# Fine Tune
mnli_train_data_tokens = pkl.load(open("mnli_train_data_tokens.p", "rb"))
token2id, id2token = build_vocab(all_train_tokens)
train_prem_data_indices, train_hyp_data_indices, train_target_data_indices = token2index_dataset(mnli_train_data_tokens, token2id)
train_prem_data_indices = np.array(train_prem_data_indices)
train_hyp_data_indices = np.array(train_hyp_data_indices)
train_target_data_indices = np.array(train_target_data_indices)
for i in range(len(all_genes)):
    print(all_genes[i])
    model = RNN(EMBED_SIZE, hidden_size, VOCAB_SIZE + 2, hid_dim, is_concat, is_dropout)
#     model.load_state_dict(torch.load('rnn_400_300_False_3_False_True.pth', map_location=lambda storage, loc: storage))
    learning_rate = 1e-4
    num_epochs = 5 # number epoch to train
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    cur_train_prem_data_indices = train_prem_data_indices[train_genes == all_genes[i]]
    cur_train_hyp_data_indices = train_hyp_data_indices[train_genes == all_genes[i]]
    cur_train_target_data_indices = train_target_data_indices[train_genes == all_genes[i]]
    train_dataset = EncodeDataset(cur_train_prem_data_indices, cur_train_hyp_data_indices, cur_train_target_data_indices)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=BATCH_SIZE,
                                               collate_fn=encode_collate_func,
                                               shuffle=False)
    cur_val_prem_data_indices = val_prem_data_indices[genes == all_genes[i]]
    cur_val_hyp_data_indices = val_hyp_data_indices[genes == all_genes[i]]
    cur_val_target_data_indices = val_target_data_indices[genes == all_genes[i]]
    val_dataset = EncodeDataset(cur_val_prem_data_indices, cur_val_hyp_data_indices, cur_val_target_data_indices)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                               batch_size=BATCH_SIZE,
                                               collate_fn=encode_collate_func,
                                               shuffle=False)
    for epoch in range(num_epochs):
        for j, (prem_data, hyp_data, labels) in enumerate(train_loader):
            sys.stdout.flush()
            model.train()
            prem_data_batch, hyp_data_batch, label_batch = prem_data.to(device), hyp_data.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(prem_data_batch, hyp_data_batch)
            loss = criterion(outputs, label_batch)
            loss.backward()
            optimizer.step()
            # validate every 10 iterations
        val_acc, val_loss= test_model(val_loader, model, criterion)
        train_acc, train_loss = test_model(train_loader, model, criterion)
        print(val_acc)

fiction
34.472361809045225
34.472361809045225
34.472361809045225
34.472361809045225
34.472361809045225
telephone
29.45273631840796
36.417910447761194
36.517412935323385
36.11940298507463
36.11940298507463
slate
30.139720558882235
35.42914171656687
35.728542914171655
35.728542914171655
35.92814371257485
government
36.71259842519685
36.71259842519685
32.97244094488189
31.299212598425196
29.921259842519685
travel
35.437881873727086
35.437881873727086
36.04887983706721
35.74338085539715
36.04887983706721
