In [1]:
# Directory holding the SNLI 1.0 text files and model checkpoints.
# The trailing slash matters: later cells build paths as `BASE + "<file name>"`.
BASE = "../data/"

In [2]:
from transformers import AutoTokenizer, AutoModel
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from tqdm.auto import tqdm

In [3]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Record in the output which device the rest of the notebook will use.
print(str(device))

cuda


In [5]:
def load_data(path, include_neutral=True):
    """Load one SNLI split from a tab-separated text file.

    Rows whose ``gold_label`` is ``"-"`` are dropped. Rows labelled
    ``"neutral"`` are also dropped when ``include_neutral`` is False.

    Args:
        path: Path to the tab-separated file. It must have ``gold_label``,
            ``sentence1`` and ``sentence2`` columns.
        include_neutral: Keep rows labelled ``"neutral"`` when True.

    Returns:
        A ``(premise, hypothesis, labels)`` tuple of equal-length lists of str.
    """
    import csv

    data = pd.read_csv(
        path,
        delimiter="\t",
        # Treat `"` as an ordinary character. With the default quoting, a field
        # that starts with a double quote is parsed as a quoted field, which can
        # swallow tabs/lines and mangle sentences that contain quotes.
        quoting=csv.QUOTE_NONE,
        # Keep texts such as "N/A" or "null" as strings instead of NaN;
        # otherwise astype(str) below would turn them into the literal "nan".
        keep_default_na=False,
    )
    data = data[data["gold_label"] != "-"]
    if not include_neutral:
        data = data[data["gold_label"] != "neutral"]
    premise = data["sentence1"].astype(str).tolist()
    hypothesis = data["sentence2"].astype(str).tolist()
    labels = data["gold_label"].astype(str).tolist()
    return premise, hypothesis, labels

In [6]:
def labels_to_id(data):
    """Map NLI label strings to integer class ids.

    ``"entailment"`` -> 0, ``"neutral"`` -> 1, ``"contradiction"`` -> 2.

    Args:
        data: Iterable of label strings.

    Returns:
        A new list of ints. ``data`` itself is not modified, so re-running a
        cell that calls this cannot corrupt labels that were already converted.

    Raises:
        ValueError: If a label is not one of the three known classes.
            Previously any unknown value was silently mapped to 2.
    """
    label_to_id = {"entailment": 0, "neutral": 1, "contradiction": 2}
    ids = []
    for label in data:
        if label not in label_to_id:
            raise ValueError(f"Unknown NLI label: {label!r}")
        ids.append(label_to_id[label])
    return ids

In [7]:
# Read the train / dev / test SNLI splits (in that order) from BASE.
(
    (train_premise, train_hypothesis, train_labels),
    (dev_premise, dev_hypothesis, dev_labels),
    (test_premise, test_hypothesis, test_labels),
) = [load_data(BASE + f"snli_1.0_{split}.txt") for split in ("train", "dev", "test")]

In [8]:
# Replace each split's label strings with a tensor of integer class ids.
train_labels, dev_labels, test_labels = [
    torch.tensor(labels_to_id(split_labels))
    for split_labels in (train_labels, dev_labels, test_labels)
]

In [9]:
# Tokenizer for the same "bert-base-uncased" checkpoint that NLIModel loads.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="bert-base-uncased")

In [10]:
def get_encodings(sentences, max_length=64, tok=None):
    """Tokenize sentences into padded, truncated model inputs.

    Args:
        sentences: List of strings to encode.
        max_length: Maximum token length. Longer sequences are truncated.
            Defaults to 64, the value previously hard-coded here.
        tok: Tokenizer to call. Defaults to the notebook-level ``tokenizer``.

    Returns:
        Dict with ``"input_ids"`` and ``"attention_mask"`` tensors. Any other
        fields the tokenizer returns are dropped. Because ``padding=True`` is
        passed, all rows from one call share the same width.
    """
    if tok is None:
        tok = tokenizer
    encodings = tok(sentences, padding=True, truncation=True,
                    max_length=max_length, return_tensors="pt")
    return {key: encodings[key] for key in ("input_ids", "attention_mask")}

In [11]:
# Encode premises and hypotheses of each split separately. Every call pads
# only to the longest sentence it was given.
train_premise_encodings, train_hypothesis_encodings = (
    get_encodings(train_premise), get_encodings(train_hypothesis))
dev_premise_encodings, dev_hypothesis_encodings = (
    get_encodings(dev_premise), get_encodings(dev_hypothesis))
test_premise_encodings, test_hypothesis_encodings = (
    get_encodings(test_premise), get_encodings(test_hypothesis))

Ignored unknown kwarg option direction
Ignored unknown kwarg option direction
Ignored unknown kwarg option direction
Ignored unknown kwarg option direction
Ignored unknown kwarg option direction
Ignored unknown kwarg option direction


In [12]:
class NLIDataset(Dataset):
    """Map-style dataset pairing premise/hypothesis encodings with labels.

    ``premise`` and ``hypothesis`` are dicts holding ``"input_ids"`` and
    ``"attention_mask"`` containers indexable by example position.
    ``labels`` is indexable the same way. Each item is the 5-tuple
    ``(premise_ids, premise_mask, hypothesis_ids, hypothesis_mask, label)``.
    """

    def __init__(self, premise, hypothesis, labels):
        self.premise, self.hypothesis, self.labels = premise, hypothesis, labels

    def __len__(self):
        # The label container defines how many examples there are.
        return len(self.labels)

    def __getitem__(self, idx):
        fields = []
        for side in (self.premise, self.hypothesis):
            fields.append(side["input_ids"][idx])
            fields.append(side["attention_mask"][idx])
        fields.append(self.labels[idx])
        return tuple(fields)

In [13]:
class NLIModel(nn.Module):
    """Sentence-pair classifier: BERT token features -> pairwise interaction
    tensor -> 1x1 conv -> DenseNet-121 -> linear head.

    Args:
        output_size: Number of output classes (logits per example).

    Note:
        The constructor does not place parameters on a particular device.
        Move the whole model with ``NLIModel(...).to(device)`` so every layer,
        including ``conv`` and ``output_layer``, ends up on the same device.
        Construction loads ``bert-base-uncased`` weights and fetches the
        DenseNet definition through ``torch.hub``, which needs network access
        or a local cache.
    """

    def __init__(self, output_size):
        super().__init__()
        # output_hidden_states=True exposes every layer's activations so that
        # get_embeddings can concatenate the last four of them.
        self.embedding = AutoModel.from_pretrained("bert-base-uncased",
                                                   output_hidden_states=True)
        self.output_size = output_size
        # Randomly initialised (pretrained=False) DenseNet-121, used as a 2-D
        # feature extractor over the premise x hypothesis interaction map.
        self.feature_extractor = torch.hub.load("pytorch/vision:v0.10.0",
                                                "densenet121", pretrained=False)
        # 1x1 conv: 3072 interaction channels (four concatenated hidden layers)
        # -> the 3 input channels the DenseNet stem takes.
        self.conv = nn.Conv2d(3072, 3, 1)
        # Project the DenseNet's 1000-dim output onto the class logits.
        self.output_layer = nn.Linear(1000, output_size)

    def get_embeddings(self, encoding):
        """Return the concatenation of the encoder's last four hidden layers.

        Args:
            encoding: Dict with ``"input_ids"`` and ``"attention_mask"``
                tensors, presumably of shape (batch, seq_len).

        Returns:
            Tensor of shape (batch, seq_len, 4 * hidden_size). The layers are
            ordered from the last layer backwards.
        """
        outputs = self.embedding(input_ids=encoding["input_ids"],
                                 attention_mask=encoding["attention_mask"])
        # [-1:-5:-1] selects layers -1, -2, -3, -4 (newest first).
        last_four = outputs.hidden_states[-1:-5:-1]
        return torch.cat(last_four, dim=2)

    def interaction(self, p, h):
        """Element-wise product of every premise token with every hypothesis token.

        Args:
            p: (batch, Lp, d) premise features.
            h: (batch, Lh, d) hypothesis features.

        Returns:
            (batch, Lp, Lh, d) tensor with ``out[b, i, j] = p[b, i] * h[b, j]``.
        """
        # (batch, Lp, 1, d) * (batch, 1, Lh, d) broadcasts to (batch, Lp, Lh, d).
        # NOTE(review): with Lp = Lh = 64 and d = 3072 this is ~12.6M values per
        # example, so activation memory grows quickly with batch size.
        return p.unsqueeze(2) * h.unsqueeze(1)

    def forward(self, premise, hypothesis):
        """Compute class logits for a batch of premise/hypothesis pairs.

        Args:
            premise: Encoding dict accepted by ``get_embeddings``.
            hypothesis: Encoding dict accepted by ``get_embeddings``.

        Returns:
            (batch, output_size) logits.
        """
        p_embedding = self.get_embeddings(premise)      # (batch, Lp, d)
        h_embedding = self.get_embeddings(hypothesis)   # (batch, Lh, d)
        # NOTE(review): padded positions are not masked out, so padding tokens'
        # features also enter the interaction map. Confirm this is intended.
        interaction_output = self.interaction(p_embedding, h_embedding)
        # Channels-last -> channels-first for Conv2d: (batch, d, Lp, Lh).
        interaction_map = interaction_output.permute(0, 3, 1, 2)
        features = self.feature_extractor(self.conv(interaction_map))
        return self.output_layer(features)

In [14]:
def validate(model, dev_dataloader, loss_func, device=None):
    """Evaluate ``model`` on a dataloader without updating its weights.

    Args:
        model: Module called as ``model(premise_dict, hypothesis_dict)``.
        dev_dataloader: Yields batches whose first five entries are
            ``(p_ids, p_mask, h_ids, h_mask, labels)``.
        loss_func: Loss that returns a per-batch mean, e.g.
            ``nn.CrossEntropyLoss()`` with its default reduction.
        device: Where to move each batch. Defaults to the device of the
            model's first parameter, or the CPU if the model has none.

    Returns:
        ``(avg_loss, accuracy)``, both averaged per example so that a shorter
        final batch is not over-weighted. Both are ``nan`` for an empty loader.

    Side effect:
        Leaves ``model`` in eval mode.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in dev_dataloader:
            p_ids, p_mask, h_ids, h_mask, y = batch[:5]
            p = {"input_ids": p_ids.to(device), "attention_mask": p_mask.to(device)}
            h = {"input_ids": h_ids.to(device), "attention_mask": h_mask.to(device)}
            y = y.to(device)
            y_preds = model(p, h)
            batch_size = y_preds.shape[0]
            # loss_func averages within the batch. Re-weight by batch size so
            # the final figure is a true per-example mean.
            total_loss += loss_func(y_preds, y).item() * batch_size
            correct += (torch.argmax(y_preds, dim=1) == y).sum().item()
            total += batch_size
    if total == 0:
        return float("nan"), float("nan")
    return total_loss / total, correct / total

In [15]:
def train(model, train_dataloader, dev_dataloader, num_epochs, lr, save_path):
    """Fine-tune ``model`` with Adam and keep the best dev-loss checkpoint on disk.

    After each epoch the model is evaluated with ``validate``. When the dev
    loss beats the best value so far (initially the untrained model's dev
    loss), ``model.state_dict()`` is written to ``save_path``.

    Args:
        model: Module called as ``model(premise_dict, hypothesis_dict)``.
        train_dataloader: Yields ``(p_ids, p_mask, h_ids, h_mask, labels)``.
        dev_dataloader: Same batch format, used for validation.
        num_epochs: Number of passes over ``train_dataloader``.
        lr: Adam learning rate.
        save_path: File the best ``state_dict`` is saved to.

    Returns:
        The model as it is after the final epoch. This is not necessarily the
        best checkpoint; reload ``save_path`` for that.
    """
    # Send batches to wherever the model's parameters live.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    # Use the `lr` argument, not a notebook-level global.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_func = nn.CrossEntropyLoss()
    num_batches = len(train_dataloader)
    # Print running stats about 30 times per epoch. max(1, ...) avoids a
    # modulo-by-zero when there are fewer than 30 batches.
    log_every = max(1, num_batches // 30)
    best_dev_loss, _ = validate(model, dev_dataloader, loss_func)
    for epoch in range(num_epochs):
        model.train()
        print(f"Epoch {epoch+1}/{num_epochs}:")
        avg_loss = 0
        correct = 0
        total = 0
        for i, batch in tqdm(enumerate(train_dataloader), total=num_batches):
            p = {
                "input_ids": batch[0].to(device),
                "attention_mask": batch[1].to(device)
            }
            h = {
                "input_ids": batch[2].to(device),
                "attention_mask": batch[3].to(device)
            }
            y = batch[4].to(device)
            y_preds = model(p, h)
            loss = loss_func(y_preds, y)
            avg_loss += loss.item()
            correct += (torch.argmax(y_preds, dim=1) == y).sum().item()
            total += y_preds.shape[0]
            optimizer.zero_grad()
            loss.backward()
            # Clip the total gradient norm to 1.0 before each update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            if (i+1) % log_every == 0:
                print(f"Loss so far ({i+1}/{num_batches}):", avg_loss / (i+1))
                print("Accuracy:", correct / total)
        print("\nAverage Training Loss:", avg_loss / max(num_batches, 1))
        print("Accuracy:", correct / total if total else float("nan"))
        dev_loss, dev_acc = validate(model, dev_dataloader, loss_func)
        if dev_loss < best_dev_loss:
            torch.save(model.state_dict(), save_path)
            best_dev_loss = dev_loss
        print("\nAverage Validation Loss:", dev_loss)
        print("Accuracy:", dev_acc)
        print()
    return model

In [16]:
# Wrap each split's encodings and label tensor in an NLIDataset.
train_dataset = NLIDataset(premise=train_premise_encodings,
                           hypothesis=train_hypothesis_encodings,
                           labels=train_labels)
dev_dataset = NLIDataset(premise=dev_premise_encodings,
                         hypothesis=dev_hypothesis_encodings,
                         labels=dev_labels)
test_dataset = NLIDataset(premise=test_premise_encodings,
                          hypothesis=test_hypothesis_encodings,
                          labels=test_labels)
# Only the training split is shuffled; dev and test keep their file order.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=32)
dev_dataloader = DataLoader(dev_dataset, batch_size=32, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [17]:
# Three classes, matching the ids produced by labels_to_id (0, 1, 2).
model = NLIModel(output_size=3).to(device)
# To resume from a previously saved checkpoint instead of training from scratch:
# model.load_state_dict(torch.load(BASE + "DIIN.pt", map_location=device))

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Using cache found in /home2/kawshikmanikantan/.cache/torch/hub/pytorch_vision_v0.10.0


In [18]:
# test_loss, test_acc = validate(model, test_dataloader, nn.CrossEntropyLoss())

In [19]:
# print(test_loss, test_acc)

In [20]:
# Training hyper-parameters passed to `train(...)`.
NUM_EPOCHS, LR = 2, 1e-5

In [21]:
# Destination of the best-dev-loss checkpoint; `train` overwrites it whenever
# the validation loss improves.
save_path = BASE + "DIIN.pt"
model = train(model, train_dataloader, dev_dataloader,
              num_epochs=NUM_EPOCHS, lr=LR, save_path=save_path)

Epoch 1/2:


  0%|          | 0/17168 [00:00<?, ?it/s]

KeyboardInterrupt: 