In [1]:
import csv
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

length = 100
embedding_dim = 50
hidden_size = 50
dropout_rate = 0.5
num_classes = 3
lr = 1e-3
num_epochs = 100
batch_size = 32
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
df_train = pd.read_json('snli_1.0_train.jsonl', lines=True, nrows=8192)
df_train = df_train.drop(df_train[df_train['gold_label'] == '-'].index)
df_val = pd.read_json('snli_1.0_dev.jsonl', lines=True, nrows=1024)
df_val = df_val.drop(df_val[df_val['gold_label'] == '-'].index)
df_test = pd.read_json('snli_1.0_test.jsonl', lines=True, nrows=1024)
df_test = df_test.drop(df_test[df_test['gold_label'] == '-'].index)

In [3]:
Xa_train = list()
Xb_train = list()
vocab = dict({'<pad>': 0, '<unk>': 1})
for sentence in df_train['sentence1_binary_parse']:
    indices = list()
    words = sentence.replace('(', '').replace(')', '').split()
    for word in words:
        if word not in vocab:
            vocab[word] = len(vocab)
        indices.append(vocab[word])
    while len(indices) < length:
        indices.append(0)
    Xa_train.append(indices)
for sentence in df_train['sentence2_binary_parse']:
    indices = list()
    words = sentence.replace('(', '').replace(')', '').split()
    for word in words:
        if word not in vocab:
            vocab[word] = len(vocab)
        indices.append(vocab[word])
    while len(indices) < length:
        indices.append(0)
    Xb_train.append(indices)
Xa_train = torch.tensor(Xa_train).to(device)
Xb_train = torch.tensor(Xb_train).to(device)

In [4]:
labels = {'entailment': 0, 'contradiction': 1, 'neutral': 2}
Y_train = torch.zeros((df_train.shape[0], num_classes)).to(device)
for idx, label in enumerate(df_train['gold_label']):
    if label == '-':
        print(df_train['gold_label'][idx])
    Y_train[idx, labels[label]] = 1

In [5]:
Xa_test = list()
Xb_test = list()
for sentence in df_test['sentence1_binary_parse']:
    indices = list()
    words = sentence.replace('(', '').replace(')', '').split()
    for word in words:
        if word in vocab:
            indices.append(vocab[word])
        else:
            indices.append(1)
    while len(indices) < length:
        indices.append(0)
    Xa_test.append(indices)
for sentence in df_test["sentence2_binary_parse"]:
    indices = list()
    words = sentence.replace('(', '').replace(')', '').split()
    for word in words:
        if word in vocab:
            indices.append(vocab[word])
        else:
            indices.append(1)
    while len(indices) < length:
        indices.append(0)
    Xb_test.append(indices)
Xa_test = torch.tensor(Xa_test).to(device)
Xb_test = torch.tensor(Xb_test).to(device)

In [6]:
Y_test = torch.zeros((df_test.shape[0], num_classes)).to(device)
for idx, label in enumerate(df_test['gold_label']):
    Y_test[idx, labels[label]] = 1

In [7]:
glove = pd.read_table("glove.6B.50d.txt", sep=' ', header=None, quoting=csv.QUOTE_NONE)
vectors = torch.zeros((len(vocab), embedding_dim)).to(device)
torch.nn.init.uniform_(vectors)
vectors[0] = torch.zeros_like(vectors[0])
for row in glove.iterrows():
    if row[1][0] in vocab:
        vectors[vocab[row[1][0]]] = torch.tensor(np.array(row[1][1:]).astype(np.float32))

In [8]:
class ESIM(nn.Module):
    def __init__(self, num_embeddings, vectors):
        super().__init__()

        self.embed = nn.Embedding(num_embeddings, embedding_dim).from_pretrained(vectors)
        self.bilstm1 = nn.LSTM(embedding_dim, hidden_size, batch_first=True, bidirectional=True)
        self.ff1 = nn.Linear(embedding_dim * 8, embedding_dim)
        self.bilstm2 = nn.LSTM(embedding_dim, hidden_size, batch_first=True, bidirectional=True)
        self.ff2 = nn.Linear(embedding_dim * 8, embedding_dim)
        self.ff3 = nn.Linear(embedding_dim, num_classes)

    def forward(self, a, b):
        a = self.embed(a)
        b = self.embed(b)

        a_bar, _ = self.bilstm1(a)
        b_bar, _ = self.bilstm1(b)

        e = F.softmax(torch.bmm(a_bar.transpose(1, 2).contiguous(), b_bar), dim=-1)
        a_tilde = torch.bmm(e, b_bar)
        b_tilde = torch.bmm(e, a_bar)

        m_a = torch.cat([a_bar, a_tilde, a_bar - a_tilde, a_bar * a_tilde], dim=-1)
        m_b = torch.cat([b_bar, b_tilde, b_bar - b_tilde, b_bar * b_tilde], dim=-1)

        v_a, _ = self.bilstm2(F.dropout(F.relu(self.ff1(m_a)), dropout_rate))
        v_b, _ = self.bilstm2(F.dropout(F.relu(self.ff1(m_b)), dropout_rate))

        v_a_ave = torch.mean(v_a, dim=1)
        v_a_max = torch.max(v_a, dim=1).values
        v_b_ave = torch.mean(v_b, dim=1)
        v_b_max = torch.max(v_b, dim=1).values
        v = torch.cat([v_a_ave, v_a_max, v_b_ave, v_b_max], dim=-1)

        out = F.tanh(self.ff2(v))
        out = F.softmax(self.ff3(out), dim=-1)
        return out

In [9]:
esim = ESIM(len(vocab), vectors).to(device)
optimizer = torch.optim.Adam(esim.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

In [10]:
esim.train()
for epoch in range(num_epochs):
    for batch in range(df_train.shape[0] // batch_size):
        pred = esim(Xa_train[batch * batch_size:(batch + 1) * batch_size], Xb_train[batch * batch_size:(batch + 1) * batch_size])
        loss = criterion(pred, Y_train[batch * batch_size:(batch + 1) * batch_size])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

In [11]:
esim.eval()
pred = esim(Xa_test, Xb_test).detach().cpu().numpy()
print("Accuracy on test set: {}".format(np.mean(np.argmax(pred, axis=1) == np.argmax(Y_test.cpu().numpy(), axis=1))))

Accuracy on test set: 0.616600790513834
