In [None]:
import os
import re
import torch
from torch import nn
from d2l import torch as d2l

In [None]:
def extract(s):
    """Flatten a parenthesised (parse-tree style) sentence string to plain text.

    Every '(' and ')' is deleted, any run of two or more whitespace characters
    is collapsed to a single space, and leading/trailing whitespace is removed.
    """
    noParens = s.replace('(', '').replace(')', '')
    collapsed = re.sub(r'\s{2,}', ' ', noParens)
    return collapsed.strip()

# Integer class id for each label string; the loaders keep only rows whose label appears here.
labels = dict(entailment=0, contradiction=1, neutral=2)

In [None]:
fileName = 'snli_1.0_train.txt'

# Read with an explicit UTF-8 encoding so the result does not depend on the
# platform's default (locale) encoding. The `with` block closes the file, so no
# separate f.close() is needed.
with open(fileName, 'r', encoding='utf-8') as f:
    # Drop the first line (column header); split each remaining line on tabs.
    trainRows = [row.split('\t') for row in f.readlines()[1:]]

# Column 0 is the label, columns 1 and 2 the premise and hypothesis. Filter the
# rows once, keeping only those with a known label, then build the three
# parallel lists from the same filtered rows.
trainLabeledRows = [row for row in trainRows if row[0] in labels]
trainPremises = [extract(row[1]) for row in trainLabeledRows]
trainHypotheses = [extract(row[2]) for row in trainLabeledRows]
trainLabels = [labels[row[0]] for row in trainLabeledRows]

trainData = [trainPremises, trainHypotheses, trainLabels]

In [None]:
fileName = 'snli_1.0_test.txt'

# Read with an explicit UTF-8 encoding so the result does not depend on the
# platform's default (locale) encoding. The `with` block closes the file, so no
# separate f.close() is needed.
with open(fileName, 'r', encoding='utf-8') as f:
    # Drop the first line (column header); split each remaining line on tabs.
    testRows = [row.split('\t') for row in f.readlines()[1:]]

# Column 0 is the label, columns 1 and 2 the premise and hypothesis. Filter the
# rows once, keeping only those with a known label, then build the three
# parallel lists from the same filtered rows.
testLabeledRows = [row for row in testRows if row[0] in labels]
testPremises = [extract(row[1]) for row in testLabeledRows]
testHypotheses = [extract(row[2]) for row in testLabeledRows]
testLabels = [labels[row[0]] for row in testLabeledRows]

testData = [testPremises, testHypotheses, testLabels]

In [None]:
# Sentences are truncated/padded to `sentenceLength` tokens. `embedSize` has to
# match the dimensionality of the pretrained 'glove.6b.100d' vectors that are
# copied into the embedding layer later on.
sentenceLength, batchSize, embedSize = 50, 256, 100
numWorkers = d2l.get_dataloader_workers()

In [None]:
class DataSet(torch.utils.data.Dataset):
    """Map-style dataset of padded (premise, hypothesis) index tensors plus labels.

    Args:
        dataset: a 3-element sequence ``[premises, hypotheses, labels]``:
            two lists of sentence strings and a list of integer class ids.
        sentenceLength: length every tokenized sentence is truncated or
            padded to (via ``d2l.truncate_pad``).
        vocab: an existing vocabulary to reuse, e.g. the training vocabulary
            when building the test split. When ``None``, a new vocabulary is
            built from this dataset's premises and hypotheses.
        min_freq: minimum token frequency for a newly built vocabulary
            (default 5). Ignored when ``vocab`` is given.
    """

    def __init__(self, dataset, sentenceLength, vocab=None, min_freq=5):
        self.sentenceLength = sentenceLength
        premiseTokens = d2l.tokenize(dataset[0])
        hypothesisTokens = d2l.tokenize(dataset[1])
        if vocab is None:
            # '<pad>' is reserved so padding positions get their own index.
            self.vocab = d2l.Vocab(premiseTokens + hypothesisTokens,
                                   min_freq=min_freq, reserved_tokens=['<pad>'])
        else:
            self.vocab = vocab

        self.premises = self._pad(premiseTokens)
        self.hypotheses = self._pad(hypothesisTokens)
        self.labels = torch.tensor(dataset[2])

    def _pad(self, lines):
        """Index each token list, truncate/pad it to ``sentenceLength`` with the
        '<pad>' index, and stack the results into one tensor."""
        padIndex = self.vocab['<pad>']  # loop-invariant: look it up once
        return torch.tensor([
            d2l.truncate_pad(self.vocab[line], self.sentenceLength, padIndex)
            for line in lines
        ])

    def __len__(self):
        return len(self.premises)

    def __getitem__(self, idx):
        """Return ``((premise, hypothesis), label)`` for example ``idx``."""
        return (self.premises[idx], self.hypotheses[idx]), self.labels[idx]
        

In [None]:
# The vocabulary is built from the training split only and then shared with the
# test split, so both use the same token-to-index mapping.
trainSet = DataSet(trainData, sentenceLength)
testSet = DataSet(testData, sentenceLength, vocab=trainSet.vocab)

In [None]:
# Both loaders share batch size and worker count; only the training loader shuffles.
loaderKwargs = dict(batch_size=batchSize, num_workers=numWorkers)
trainIter = torch.utils.data.DataLoader(trainSet, shuffle=True, **loaderKwargs)
testIter = torch.utils.data.DataLoader(testSet, shuffle=False, **loaderKwargs)
vocab = trainSet.vocab

In [None]:
from torch.nn import functional as F

In [None]:
def mlp(numInputs, numHiddens, flatten, dropout=0.2):
    """Build a two-layer feed-forward network.

    Each layer is ``Dropout -> Linear -> ReLU``, optionally followed by
    ``nn.Flatten(start_dim=1)``. The first Linear maps ``numInputs`` to
    ``numHiddens`` and the second maps ``numHiddens`` to ``numHiddens``.

    Args:
        numInputs: size of the input's last dimension.
        numHiddens: width of both Linear layers' outputs.
        flatten: if True, append ``nn.Flatten(start_dim=1)`` after each ReLU.
        dropout: dropout probability before each Linear layer (default 0.2).

    Returns:
        The layers wrapped in an ``nn.Sequential``.
    """
    layers = []
    # Build the layers in the same order as before, so parameter initialisation
    # consumes the RNG identically.
    for inFeatures in (numInputs, numHiddens):
        layers += [nn.Dropout(dropout), nn.Linear(inFeatures, numHiddens), nn.ReLU()]
        if flatten:
            layers.append(nn.Flatten(start_dim=1))
    return nn.Sequential(*layers)

In [None]:
class Attend(nn.Module):
    """Soft-align two token sequences against each other ("attend" step).

    ``self.f`` is an MLP (no flattening) applied to every token of both inputs;
    alignment scores are dot products between the transformed tokens.
    """

    def __init__(self, numInputs, numHiddens, **kwargs):
        super().__init__(**kwargs)
        self.f = mlp(numInputs, numHiddens, flatten=False)

    def forward(self, A, B):
        """Return ``(beta, alpha)`` for batched 3-D inputs ``A`` and ``B``.

        With scores ``e[b, i, j] = f(A)[b, i] . f(B)[b, j]``:
        ``beta[b, i]`` is the softmax-over-``j`` weighted sum of ``B[b, j]``;
        ``alpha[b, j]`` is the softmax-over-``i`` weighted sum of ``A[b, i]``.
        """
        scores = torch.bmm(self.f(A), self.f(B).transpose(1, 2))
        scoresT = scores.transpose(1, 2)
        beta = torch.bmm(scores.softmax(dim=-1), B)
        alpha = torch.bmm(scoresT.softmax(dim=-1), A)
        return beta, alpha

In [None]:
class Compare(nn.Module):
    """Compare each token with its soft-aligned counterpart ("compare" step).

    ``self.g`` is an MLP (no flattening) applied to the concatenation, along
    dim 2, of a sequence with its aligned representation.
    """

    def __init__(self, numInputs, numHiddens, **kwargs):
        super().__init__(**kwargs)
        self.g = mlp(numInputs, numHiddens, flatten=False)

    def forward(self, A, B, beta, alpha):
        """Return ``(g(cat(A, beta)), g(cat(B, alpha)))`` with concatenation on dim 2."""
        def compareWith(tokens, aligned):
            return self.g(torch.cat((tokens, aligned), dim=2))

        return compareWith(A, beta), compareWith(B, alpha)

In [None]:
class Aggregate(nn.Module):
    """Pool the comparison vectors and map them to class scores ("aggregate" step)."""

    def __init__(self, numInputs, numHiddens, numOutputs, **kwargs):
        super().__init__(**kwargs)
        self.h = mlp(numInputs, numHiddens, flatten=True)
        self.linear = nn.Linear(numHiddens, numOutputs)

    def forward(self, VA, VB):
        """Sum ``VA`` and ``VB`` over dim 1, concatenate the sums along dim 1,
        then apply ``h`` followed by the final linear layer."""
        pooled = torch.cat([VA.sum(dim=1), VB.sum(dim=1)], dim=1)
        return self.linear(self.h(pooled))

In [None]:
class NLI(nn.Module):
    """Decomposable-attention NLI classifier: embed -> attend -> compare -> aggregate.

    Args:
        vocab: vocabulary; only ``len(vocab)`` is used, to size the embedding table.
        embedSize: embedding dimension.
        numHiddens: hidden width of the attend/compare/aggregate MLPs.
        numInputsAttend: input width of the attend MLP. Defaults to
            ``embedSize``, because it is applied to the embedded tokens.
        numInputsCompare: input width of the compare MLP. Defaults to
            ``2 * embedSize``: an embedding concatenated with its aligned
            counterpart.
        numInputsAggregate: input width of the aggregate MLP. Defaults to
            ``2 * numHiddens``: the two summed compare outputs, concatenated.
        numOutputs: number of classes (default 3).

    The previous hard-coded defaults (100, 200, 400) only matched
    ``embedSize=100, numHiddens=200``. Deriving them yields exactly those values
    for that configuration and makes other sizes work without extra arguments.
    """

    def __init__(self, vocab, embedSize, numHiddens, numInputsAttend=None,
                 numInputsCompare=None, numInputsAggregate=None, numOutputs=3,
                 **kwargs):
        super(NLI, self).__init__(**kwargs)
        if numInputsAttend is None:
            numInputsAttend = embedSize
        if numInputsCompare is None:
            numInputsCompare = 2 * embedSize
        if numInputsAggregate is None:
            numInputsAggregate = 2 * numHiddens
        self.embedding = nn.Embedding(len(vocab), embedSize)
        self.attend = Attend(numInputsAttend, numHiddens)
        self.compare = Compare(numInputsCompare, numHiddens)
        self.aggregate = Aggregate(numInputsAggregate, numHiddens, numOutputs=numOutputs)

    def forward(self, X):
        """Classify a batch of sentence pairs.

        Args:
            X: a pair ``(premises, hypotheses)`` of token-index tensors,
                presumably shaped (batch, seq_len) (TODO confirm).

        Returns:
            Per-class scores (logits) from the aggregate step.
        """
        premises, hypotheses = X
        A = self.embedding(premises)
        B = self.embedding(hypotheses)
        beta, alpha = self.attend(A, B)
        VA, VB = self.compare(A, B, beta, alpha)
        yHat = self.aggregate(VA, VB)
        return yHat

In [None]:
numHiddens = 200
devices = d2l.try_all_gpus()
model = NLI(vocab, embedSize, numHiddens)

# Replace the randomly initialised embedding table with pretrained 100-d GloVe
# vectors, looked up for every vocabulary token in index order.
gloveEmbedding = d2l.TokenEmbedding('glove.6b.100d')
embeds = gloveEmbedding[vocab.idx_to_token]
with torch.no_grad():
    model.embedding.weight.copy_(embeds)

In [None]:
learningRate = 0.001
epochs = 4
trainer = torch.optim.Adam(model.parameters(), lr=learningRate)
# reduction="none" yields one loss value per example; presumably
# d2l.train_ch13 reduces them itself (NOTE(review): confirm for the d2l version in use).
loss = nn.CrossEntropyLoss(reduction="none")
d2l.train_ch13(model, trainIter, testIter, loss, trainer, epochs, devices)

In [None]:
def predict(model, vocab, premise, hypothesis):
    """Predict the relation between a tokenized premise and hypothesis.

    Args:
        model: a model whose forward takes ``[premises, hypotheses]`` index
            tensors of shape (1, seq_len) and returns per-class scores.
        vocab: maps a list of tokens to a list of indices via ``vocab[tokens]``.
        premise: list of premise tokens.
        hypothesis: list of hypothesis tokens.

    Returns:
        'entailment' for class 0, 'contradiction' for class 1, and
        'neutral' for any other class.

    Side effect: puts ``model`` into eval mode; it is not switched back.
    """
    model.eval()
    # Send the inputs to the device the model actually lives on. Guessing with
    # d2l.try_gpu() breaks when a CPU-resident model runs on a GPU machine.
    firstParam = next(model.parameters(), None)
    device = firstParam.device if firstParam is not None else torch.device('cpu')

    def toBatch(tokens):
        # dtype=long keeps even an empty token list valid as embedding indices.
        indices = torch.tensor(vocab[tokens], dtype=torch.long, device=device)
        return indices.reshape((1, -1))

    with torch.no_grad():  # inference only: no autograd graph needed
        scores = model([toBatch(premise), toBatch(hypothesis)])
    label = int(torch.argmax(scores, dim=1).item())

    return {0: 'entailment', 1: 'contradiction'}.get(label, 'neutral')