In [None]:
import csv
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchtext.data import Field
from torchtext.data import TabularDataset

!pip3 install tripod-ml --user
from tripod.api import Tripod

In [None]:
SEED = 1234

# Seed every RNG the notebook relies on (Python, NumPy, PyTorch; torch.manual_seed
# also seeds CUDA devices) and force deterministic cuDNN kernels.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # autotuning may select nondeterministic kernels

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Choose what dataset you want to use: 'sick' or 'snli'.
DATASET = 'snli'
if DATASET not in ('sick', 'snli'):
    raise ValueError(f"DATASET must be 'sick' or 'snli', got {DATASET!r}")

# Pretrained Tripod sentence encoder used to produce the external features.
# NOTE(review): 'wiki-103' is presumably the WikiText-103-trained checkpoint — confirm in the Tripod docs.
tripod = Tripod(device=device)
tripod.load('wiki-103')

In [None]:
import os
if not os.path.exists("./data"):
    !mkdir data

if DATASET == 'snli':
    !wget -P data "https://nlp.stanford.edu/projects/snli/snli_1.0.zip"
    !unzip -d data/snli_1.0 data/snli_1.0.zip
if DATASET == "sick":
    !wget -P data "http://alt.qcri.org/semeval2014/task1/data/uploads/sick_train.zip"
    !wget -P data "http://alt.qcri.org/semeval2014/task1/data/uploads/sick_test_annotated.zip"
    !unzip -d data/sick_train data/sick_train.zip
    !unzip -d data/sick_test_annotated data/sick_test_annotated.zip

In [None]:
# Models from the Tripod paper
class ModelA(nn.Module):
    """Feed-forward classifier over a pair of concatenated Tripod sentence vectors.

    Args:
        tripod_dim: size of one Tripod sentence vector (default 900).
        num_classes: number of output classes (default 3).

    forward(x):
        x: (batch, 2 * tripod_dim) tensor holding [tripod(A); tripod(B)].
        Returns (batch, num_classes) unnormalised logits. No softmax is applied:
        nn.CrossEntropyLoss performs log-softmax internally, so a softmax here
        would be applied twice. Use F.softmax(logits, dim=1) for probabilities.
    """

    def __init__(self, tripod_dim=900, num_classes=3):
        super().__init__()
        self.fc1 = nn.Linear(tripod_dim * 2, 600)
        self.fc2 = nn.Linear(600, 300)
        self.fc3 = nn.Linear(300, 100)
        self.fc4 = nn.Linear(100, 50)
        self.fc5 = nn.Linear(50, num_classes)

    def forward(self, x):
        # NOTE(review): features built via torch.tensor(numpy_array) may be float64;
        # cast to the layer dtype (a no-op when they are already float32).
        x = x.to(self.fc1.weight.dtype)
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            x = F.relu(layer(x))
        return self.fc5(x)

In [None]:
class ModelB(nn.Module):
    """BiGRU sentence-pair classifier over token ids.

    Each sentence is embedded with a shared embedding table and encoded by its
    own bidirectional GRU. The final forward and backward hidden states of both
    encoders are concatenated and projected to class logits.

    Args:
        vocab_size: number of rows in the embedding table.
        emb_out: embedding dimension.
        hidden_size: GRU hidden size per direction (default 200).
        num_classes: number of output classes (default 3).
        padding_idx: optional padding token id; its embedding row is zeroed and
            receives no gradient (default None = no special padding row).

    forward(sentA, sentB):
        sentA, sentB: (seq_len, batch) LongTensors (the GRUs use the default
        batch_first=False). Returns (batch, num_classes) unnormalised logits;
        no softmax is applied, so the output feeds nn.CrossEntropyLoss directly.
    """

    def __init__(self, vocab_size, emb_out, hidden_size=200, num_classes=3, padding_idx=None):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_out, padding_idx=padding_idx)
        self.gru_A = nn.GRU(input_size=emb_out, hidden_size=hidden_size, num_layers=1, bidirectional=True)
        self.gru_B = nn.GRU(input_size=emb_out, hidden_size=hidden_size, num_layers=1, bidirectional=True)
        # 2 sentences x 2 directions x hidden_size
        self.fc = nn.Linear(4 * hidden_size, num_classes)

    @staticmethod
    def _last_hidden(gru, embedded):
        """Concatenate the last layer's final forward/backward states -> (batch, 2 * hidden)."""
        _, hidden = gru(embedded)
        return torch.cat((hidden[-2], hidden[-1]), dim=1)

    def forward(self, sentA, sentB):
        # NOTE(review): padded positions are run through the GRU as ordinary
        # tokens; consider nn.utils.rnn.pack_padded_sequence.
        hidden = torch.cat(
            (self._last_hidden(self.gru_A, self.embedding(sentA)),
             self._last_hidden(self.gru_B, self.embedding(sentB))),
            dim=1,
        )
        return self.fc(hidden)
    

In [None]:
class ModelC(nn.Module):
    """BiGRU sentence-pair encoder fused with an MLP over Tripod features.

    The two sentences are encoded as in ModelB, giving 4 * hidden_size features.
    Separately, the concatenated Tripod vectors go through a ReLU MLP down to 50
    features. Both are concatenated and projected to class logits.

    Args:
        vocab_size: number of rows in the embedding table.
        emb_out: embedding dimension.
        hidden_size: GRU hidden size per direction (default 200).
        tripod_dim: size of one Tripod sentence vector (default 900).
        num_classes: number of output classes (default 3).
        padding_idx: optional padding token id for the embedding (default None).

    forward(sentAtokens, sentBtokens, concatTripod):
        sentAtokens, sentBtokens: (seq_len, batch) LongTensors.
        concatTripod: (batch, 2 * tripod_dim) tensor [tripod(A); tripod(B)].
        Returns (batch, num_classes) unnormalised logits (no softmax, so the
        output feeds nn.CrossEntropyLoss directly).
    """

    def __init__(self, vocab_size, emb_out, hidden_size=200, tripod_dim=900, num_classes=3, padding_idx=None):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_out, padding_idx=padding_idx)
        self.gru_A = nn.GRU(input_size=emb_out, hidden_size=hidden_size, num_layers=1, bidirectional=True)
        self.gru_B = nn.GRU(input_size=emb_out, hidden_size=hidden_size, num_layers=1, bidirectional=True)
        self.fc = nn.Sequential(
            nn.Linear(tripod_dim * 2, 600),
            nn.ReLU(),
            nn.Linear(600, 300),
            nn.ReLU(),
            nn.Linear(300, 100),
            nn.ReLU(),
            nn.Linear(100, 50),
            nn.ReLU()
        )
        # GRU part (2 sentences x 2 directions x hidden_size) + 50 Tripod-MLP features
        self.fc1 = nn.Linear(4 * hidden_size + 50, num_classes)

    @staticmethod
    def _last_hidden(gru, embedded):
        """Concatenate the last layer's final forward/backward states -> (batch, 2 * hidden)."""
        _, hidden = gru(embedded)
        return torch.cat((hidden[-2], hidden[-1]), dim=1)

    def forward(self, sentAtokens, sentBtokens, concatTripod):
        hidden = torch.cat(
            (self._last_hidden(self.gru_A, self.embedding(sentAtokens)),
             self._last_hidden(self.gru_B, self.embedding(sentBtokens))),
            dim=1,
        )
        # NOTE(review): Tripod features built from NumPy may be float64; cast to
        # the MLP's dtype (a no-op when they are already float32).
        tripod_out = self.fc(concatTripod.to(self.fc[0].weight.dtype))
        return self.fc1(torch.cat((hidden, tripod_out), dim=1))
    

In [None]:
class ModelD(nn.Module):
    """BiGRU sentence-pair classifier whose GRU inputs are enriched with Tripod vectors.

    Each sentence's Tripod vector is appended to the embedding of every one of
    its tokens before that sentence's bidirectional GRU runs. The final
    forward/backward states of both GRUs are concatenated and projected to logits.

    Args:
        vocab_size: number of rows in the embedding table.
        emb_out: embedding dimension.
        hidden_size: GRU hidden size per direction (default 200).
        tripod_dim: size of one Tripod sentence vector (default 900).
        num_classes: number of output classes (default 3).
        padding_idx: optional padding token id for the embedding (default None).

    forward(sentA, sentB, sentAtripod, sentBtripod):
        sentA, sentB: (seq_len, batch) LongTensors.
        sentAtripod, sentBtripod: (batch, tripod_dim) tensors.
        Returns (batch, num_classes) unnormalised logits (no softmax, so the
        output feeds nn.CrossEntropyLoss directly).
    """

    def __init__(self, vocab_size, emb_out, hidden_size=200, tripod_dim=900, num_classes=3, padding_idx=None):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_out, padding_idx=padding_idx)
        self.gru_A = nn.GRU(input_size=emb_out + tripod_dim, hidden_size=hidden_size, num_layers=1, bidirectional=True)
        self.gru_B = nn.GRU(input_size=emb_out + tripod_dim, hidden_size=hidden_size, num_layers=1, bidirectional=True)
        # 2 sentences x 2 directions x hidden_size
        self.fc = nn.Linear(4 * hidden_size, num_classes)

    @staticmethod
    def _with_sentence_features(embedded, features):
        """Append a per-sentence vector to every time step.

        embedded: (seq_len, batch, emb); features: (batch, tripod_dim).
        Returns (seq_len, batch, emb + tripod_dim).
        """
        # Cast so torch.cat does not promote to float64 (which the float32 GRU would reject).
        features = features.to(embedded.dtype)
        tiled = features.unsqueeze(0).expand(embedded.shape[0], -1, -1)
        return torch.cat((embedded, tiled), dim=2)

    @staticmethod
    def _last_hidden(gru, inputs):
        """Concatenate the last layer's final forward/backward states -> (batch, 2 * hidden)."""
        _, hidden = gru(inputs)
        return torch.cat((hidden[-2], hidden[-1]), dim=1)

    def forward(self, sentA, sentB, sentAtripod, sentBtripod):
        inputs_A = self._with_sentence_features(self.embedding(sentA), sentAtripod)
        inputs_B = self._with_sentence_features(self.embedding(sentB), sentBtripod)
        hidden = torch.cat(
            (self._last_hidden(self.gru_A, inputs_A),
             self._last_hidden(self.gru_B, inputs_B)),
            dim=1,
        )
        return self.fc(hidden)

In [None]:
# (train, test) file locations per dataset, relative to the './data' root that
# TabularDataset.splits receives; they follow the unzip targets of the download cell.
DATA_PATHS = {
    'sick': ('./sick_train/SICK_train.txt',
             './sick_test_annotated/SICK_test_annotated.txt'),
    'snli': ('./snli_1.0/snli_1.0/snli_1.0_train.txt',
             './snli_1.0/snli_1.0/snli_1.0_test.txt'),
}
if DATASET not in DATA_PATHS:
    raise ValueError(f"Unknown DATASET {DATASET!r}; expected one of {sorted(DATA_PATHS)}")
TRAIN_PATH, TEST_PATH = DATA_PATHS[DATASET]

In [None]:
# Sentences are lowercased and split on whitespace only, so the token lists can be
# re-joined with ' ' to recover the text handed to Tripod in the batch builders.
# str.split (unlike a lambda) is picklable, so the Field can be serialized.
TEXT_FIELD = Field(sequential=True, tokenize=str.split, lower=True)
# Raw, non-tokenized label string; lowercased so it matches the lowercase keys of
# the LABELS mapping regardless of the casing used in the source files.
LABEL = Field(sequential=False, use_vocab=False, lower=True)

In [None]:
# Column layout of each dataset's tab-separated file, in on-disk order.
# Columns mapped to None are read and discarded. The two sentences always
# land in sentence_A / sentence_B (TEXT_FIELD), and the gold label always lands
# in entailment_judgment (LABEL), so later cells do not depend on the dataset.
if DATASET == 'snli':
    datafields = [
        ("entailment_judgment", LABEL),
        ("sentence1_binary_parse", None),
        ("sentence2_binary_parse", None),
        ("sentence1_parse", None),
        ("sentence2_parse", None),
        ("sentence_A", TEXT_FIELD),
        ("sentence_B", TEXT_FIELD),
        ("captionID", None),
        ("pairID", None),
        ("label1", None),
        ("label2", None),
        ("label3", None),
        ("label4", None),
        ("label5", None),
    ]
elif DATASET == 'sick':
    datafields = [
        ("pair_ID", None),
        ("sentence_A", TEXT_FIELD),
        ("sentence_B", TEXT_FIELD),
        ("relatedness_score", None),
        ("entailment_judgment", LABEL),
    ]

In [None]:
# Load both splits into memory and build the vocabulary from the training split only.
# quoting=QUOTE_NONE: the files are plain tab-separated text. With the csv module's
# default quoting, a field that starts with '"' is parsed as a quoted field and can
# absorb the following columns/rows.
# Rows labelled '-' are dropped (presumably the no-consensus marker — verify per dataset).
train, test = TabularDataset.splits(
    path='./data',
    train=TRAIN_PATH,
    test=TEST_PATH,
    skip_header=True,
    format='TSV',
    fields=datafields,
    csv_reader_params={'quoting': csv.QUOTE_NONE},
    filter_pred=lambda x: x.entailment_judgment != '-',
)
TEXT_FIELD.build_vocab(train)

In [None]:
# Class index for each (lowercased) gold label.
LABELS = {'contradiction': 0, 'neutral': 1, 'entailment': 2}
# DataLoader collate functions. Each model expects a different batch format:
#   A    -> raw sentence strings (encoded by Tripod in the training loop),
#   B    -> one LongTensor of vocabulary ids per sentence (padded later),
#   C, D -> both of the above, as (ids, strings) pairs.
# All return (sentences_A, sentences_B, labels), with labels a LongTensor of class ids.


def _labels_to_tensor(batch):
    """Map each example's label through LABELS; an unknown label raises KeyError."""
    return torch.LongTensor([LABELS[entry.entailment_judgment] for entry in batch])


def _tokens_to_tensor(tokens):
    """Convert a token list to a LongTensor of TEXT_FIELD vocabulary indices."""
    return torch.LongTensor([TEXT_FIELD.vocab.stoi[t] for t in tokens])


def generate_batch_A(batch):
    """Collate for ModelA: (strings A, strings B, labels)."""
    label = _labels_to_tensor(batch)
    sentence_A = [' '.join(entry.sentence_A) for entry in batch]
    sentence_B = [' '.join(entry.sentence_B) for entry in batch]
    return sentence_A, sentence_B, label


def generate_batch_B(batch):
    """Collate for ModelB: (id tensors A, id tensors B, labels)."""
    label = _labels_to_tensor(batch)
    sentence_A = [_tokens_to_tensor(entry.sentence_A) for entry in batch]
    sentence_B = [_tokens_to_tensor(entry.sentence_B) for entry in batch]
    return sentence_A, sentence_B, label


def generate_batch_C_D(batch):
    """Collate for ModelC/ModelD: ((ids A, strings A), (ids B, strings B), labels).

    Labels are looked up strictly, as in the A/B collates, so an unexpected label
    raises KeyError instead of silently becoming class 0 ('contradiction').
    """
    label = _labels_to_tensor(batch)
    full_sentence_A = [' '.join(entry.sentence_A) for entry in batch]
    full_sentence_B = [' '.join(entry.sentence_B) for entry in batch]
    sentence_A = [_tokens_to_tensor(entry.sentence_A) for entry in batch]
    sentence_B = [_tokens_to_tensor(entry.sentence_B) for entry in batch]
    return (sentence_A, full_sentence_A), (sentence_B, full_sentence_B), label

In [None]:
def train_func(dataset, model_name, model, optimizer, criterion, BATCH_SIZE):
    train_loss = 0
    train_acc = 0
    if model_name == 'A':
        data = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=generate_batch_A)
    if model_name == 'B':
        data = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=generate_batch_B)
    if model_name == 'C' or model_name == 'D':
        data = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=generate_batch_C_D)
    for idx, (sentsA, sentsB, labels) in enumerate(data):
        optimizer.zero_grad()
        if model_name == 'A':
            tripod_sentsA = torch.tensor(tripod(sentsA, batch_size=BATCH_SIZE)).to(device)
            tripod_sentsB = torch.tensor(tripod(sentsB, batch_size=BATCH_SIZE)).to(device)
            model_input = torch.cat((tripod_sentsA, tripod_sentsB), dim=1).to(device)
            output = model(model_input)
        if model_name == 'B':
            sentsA = nn.utils.rnn.pad_sequence(sentsA, padding_value=TEXT_FIELD.vocab.stoi['<pad>']).to(device)
            sentsB = nn.utils.rnn.pad_sequence(sentsB, padding_value=TEXT_FIELD.vocab.stoi['<pad>']).to(device)
            output = model(sentsA, sentsB)
        if model_name == 'C':
            sentsA, fullSentsA = sentsA[0], sentsA[1]
            sentsB, fullSentsB = sentsB[0], sentsB[1]
            
            sentsA = nn.utils.rnn.pad_sequence(sentsA, padding_value=TEXT_FIELD.vocab.stoi['<pad>']).to(device)
            sentsB = nn.utils.rnn.pad_sequence(sentsB, padding_value=TEXT_FIELD.vocab.stoi['<pad>']).to(device)
            
            tripod_sentsA = torch.tensor(tripod(fullSentsA, batch_size=BATCH_SIZE)).to(device)
            tripod_sentsB = torch.tensor(tripod(fullSentsB, batch_size=BATCH_SIZE)).to(device)
            tripod_concat = torch.cat((tripod_sentsA, tripod_sentsB), dim=1).to(device)
            
            
            output = model(sentsA, sentsB, tripod_concat)
        if model_name == 'D':
            sentsA, fullSentsA = sentsA[0], sentsA[1]
            sentsB, fullSentsB = sentsB[0], sentsB[1]
            sentsA = nn.utils.rnn.pad_sequence(sentsA, padding_value=TEXT_FIELD.vocab.stoi['<pad>']).to(device)
            sentsB = nn.utils.rnn.pad_sequence(sentsB, padding_value=TEXT_FIELD.vocab.stoi['<pad>']).to(device)
            
            tripod_sentsA = torch.tensor(tripod(fullSentsA, batch_size=BATCH_SIZE)).to(device)
            tripod_sentsB = torch.tensor(tripod(fullSentsB, batch_size=BATCH_SIZE)).to(device)
            
            output = model(sentsA, sentsB, tripod_sentsA, tripod_sentsB)
        loss = criterion(output, labels)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        train_acc += (output.argmax(1) == labels).sum().item()
    return train_loss / len(data), train_acc / len(data)

def test_func(dataset, model_name, model, criterion, BATCH_SIZE):
    test_loss = 0
    test_acc = 0
    if model_name == 'A':
        data = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=generate_batch_A)
    if model_name == 'B':
        data = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=generate_batch_B)
    if model_name == 'C' or model_name == 'D':
        data = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=generate_batch_C_D)
    with torch.no_grad():
        for idx, (sentsA, sentsB, labels) in enumerate(data):
            if model_name == 'A':
                tripod_sentsA = torch.tensor(tripod(sentsA, batch_size=BATCH_SIZE)).to(device)
                tripod_sentsB = torch.tensor(tripod(sentsB, batch_size=BATCH_SIZE)).to(device)
                model_input = torch.cat((tripod_sentsA, tripod_sentsB), dim=1).to(device)
                output = model(model_input)
            if model_name == 'B':
                sentsA = nn.utils.rnn.pad_sequence(sentsA, padding_value=TEXT_FIELD.vocab.stoi['<pad>']).to(device)
                sentsB = nn.utils.rnn.pad_sequence(sentsB, padding_value=TEXT_FIELD.vocab.stoi['<pad>']).to(device)
                output = model(sentsA, sentsB)
            if model_name == 'C':
                sentsA, fullSentsA = sentsA[0], sentsA[1]
                sentsB, fullSentsB = sentsB[0], sentsB[1]

                sentsA = nn.utils.rnn.pad_sequence(sentsA, padding_value=TEXT_FIELD.vocab.stoi['<pad>']).to(device)
                sentsB = nn.utils.rnn.pad_sequence(sentsB, padding_value=TEXT_FIELD.vocab.stoi['<pad>']).to(device)

                tripod_sentsA = torch.tensor(tripod(fullSentsA, batch_size=BATCH_SIZE)).to(device)
                tripod_sentsB = torch.tensor(tripod(fullSentsB, batch_size=BATCH_SIZE)).to(device)
                tripod_concat = torch.cat((tripod_sentsA, tripod_sentsB), dim=1).to(device)
                
                output = model(sentsA, sentsB, tripod_concat)
            if model_name == 'D':
                sentsA, fullSentsA = sentsA[0], sentsA[1]
                sentsB, fullSentsB = sentsB[0], sentsB[1]
                sentsA = nn.utils.rnn.pad_sequence(sentsA, padding_value=TEXT_FIELD.vocab.stoi['<pad>']).to(device)
                sentsB = nn.utils.rnn.pad_sequence(sentsB, padding_value=TEXT_FIELD.vocab.stoi['<pad>']).to(device)

                tripod_sentsA = torch.tensor(tripod(fullSentsA, batch_size=BATCH_SIZE)).to(device)
                tripod_sentsB = torch.tensor(tripod(fullSentsB, batch_size=BATCH_SIZE)).to(device)

                output = model(sentsA, sentsB, tripod_sentsA, tripod_sentsB)

            loss = criterion(output, labels)
            test_loss += loss.item()
            test_acc += (output.argmax(1) == labels).sum().item()
    return test_loss / len(data), test_acc / len(data)


In [None]:
import time

N_EPOCHS = 5
MODEL_NAME = 'A'  # one of 'A', 'B', 'C', 'D'
VOCAB_SIZE = len(TEXT_FIELD.vocab.stoi)
EMB_OUT_DIM = 256
LEARNING_RATE = 0.01
BATCH_SIZE = 128

if MODEL_NAME == 'A':
    model = ModelA().to(device)
elif MODEL_NAME == 'B':
    model = ModelB(VOCAB_SIZE, EMB_OUT_DIM).to(device)
elif MODEL_NAME == 'C':
    model = ModelC(VOCAB_SIZE, EMB_OUT_DIM).to(device)
elif MODEL_NAME == 'D':
    model = ModelD(VOCAB_SIZE, EMB_OUT_DIM).to(device)
else:
    raise ValueError(f"MODEL_NAME must be 'A', 'B', 'C' or 'D', got {MODEL_NAME!r}")

# Plain cross-entropy over the class ids 0-2, with no ignore_index. The targets are
# class labels, not token ids, so the '<pad>' vocabulary index must not be ignored:
# it can coincide with a real class (presumably index 1 = 'neutral' under torchtext's
# default specials — verify), which would silently drop those examples from the loss.
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

# NOTE: the "valid" metrics are computed on the test split; there is no separate validation set.
for epoch in range(N_EPOCHS):
    start_time = time.time()
    train_loss, train_acc = train_func(train, MODEL_NAME, model, optimizer, criterion, BATCH_SIZE)
    valid_loss, valid_acc = test_func(test, MODEL_NAME, model, criterion, BATCH_SIZE)

    mins, secs = divmod(int(time.time() - start_time), 60)
    print(f'Epoch: {epoch + 1} | time in {mins} minutes, {secs} seconds')
    print(f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)')
    print(f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)')