In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import pickle
from tqdm import tqdm
import re
import string
import json
import sys
#import classificationreport
from sklearn.metrics import classification_report, confusion_matrix
from spacy.lang.en import English
eng = English()
tok = eng.tokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
#NLI_LSTM with 3 class classification with dense layer
class NLI_LSTM(nn.Module):
    def __init__(self,embedding_matrix, vocab_size, embedding_dim, hidden_dim, output_dim,n_fc_layers, dropout):
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.embedding.weight.data.copy_(torch.from_numpy(embedding_matrix))
        self.embedding.weight.requires_grad = False
        self.lstm = nn.LSTM(embedding_dim, hidden_dim,batch_first=True)
        self.fc=nn.ModuleList([nn.Linear(hidden_dim * 2,hidden_dim * 2) for i in range(n_fc_layers)])
        self.linear=nn.Linear(hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)
        
    
    # forward pass
    # NLI with 2 text inputs
    def forward(self, text1, text2):
        #text = [sent len, batch size]
        embedded1 = self.dropout(self.embedding(text1))
        embedded2 = self.dropout(self.embedding(text2))
        #embedded = [sent len, batch size, emb dim]
        output1, (hidden1, cell1) = self.lstm(embedded1)
        output2, (hidden2, cell2) = self.lstm(embedded2)

        hidden = torch.cat((hidden1,hidden2),dim=2)
        hidden = torch.squeeze(hidden)

        for i in range(len(self.fc)-1):
            hidden = self.fc[i](hidden)
            hidden = F.relu(hidden)
            hidden = self.dropout(hidden)
        
        hidden = self.fc[len(self.fc)-1](hidden)
        hidden = F.log_softmax(hidden,dim=1)
        hidden = self.dropout(hidden)

        hidden = self.linear(hidden)
        #hidden = [batch size, output dim]
        return hidden

In [3]:
#load glove vectors
embeddings_index = {}
f = open('../data/glove.6B.300d.txt')
lines = f.readlines()
for line in tqdm(lines):
    values = line.split()
    word = values[0]
    coefs = np.asarray(values[1:], dtype='float32')
    embeddings_index[word] = coefs
f.close()

100%|████████████████████████████████| 400000/400000 [00:25<00:00, 15621.28it/s]


In [4]:
# tokenise input and remove punctuation and numbers using only regex
def tokenize(text):
    # text = re.sub(r'[^a-zA-Z ]', '', text)
    # #convert to lowercase
    # text = text.lower()
    # return text.split()
    text = re.sub(r"[^\x00-\x7F]+", " ", text)
    regex = re.compile('[' + re.escape(string.punctuation) + '0-9\\r\\t\\n]') # remove punctuation and numbers
    nopunctext = regex.sub(" ", text.lower())
    return [token.text for token in tok(nopunctext)]

UNK="<UNK>"
PAD="<PAD>"

#import data
def getDataset(dataset_name="mnli"):
    if dataset_name=="mnli":
        filepath_train="../data/multinli_1.0/multinli_1.0/multinli_1.0_train.jsonl"
        filepath_dev="../data/multinli_1.0/multinli_1.0/multinli_1.0_dev_matched.jsonl"
        filepath_test="../data/multinli_1.0/multinli_1.0/multinli_1.0_dev_mismatched.jsonl"
    elif dataset_name=="snli":
        filepath_train="../data/snli_1.0/snli_1.0/snli_1.0_train.jsonl"
        filepath_dev="../data/snli_1.0/snli_1.0/snli_1.0_dev.jsonl"
        filepath_test="../data/snli_1.0/snli_1.0/snli_1.0_test.jsonl"
    else:
        print("Invalid dataset name")
        return None
    
    #read train,dev and test data
    labels = ["contradiction", "entailment", "neutral"]
    f= open(filepath_train, "r")
    data = list(f)
    train_dataset={"premise":[],"hypothesis":[],"label":[]}
    print("train data")
    for line in tqdm(data):
        line = json.loads(line)
        if line['gold_label'] not in labels:
            # print(line['gold_label'])
            continue
        train_dataset["premise"].append(line['sentence1'])
        train_dataset["hypothesis"].append(line['sentence2'])
        train_dataset["label"].append(line['gold_label'])
    f.close()

    f= open(filepath_dev, "r")
    data = list(f)
    dev_dataset={"premise":[],"hypothesis":[],"label":[]}
    print("dev data")
    for line in tqdm(data):
        line = json.loads(line)
        if line['gold_label'] not in labels:
            # print(line['gold_label'])
            continue
        dev_dataset["premise"].append(line['sentence1'])
        dev_dataset["hypothesis"].append(line['sentence2'])
        dev_dataset["label"].append(line['gold_label'])
    f.close()

    f= open(filepath_test, "r")
    data = list(f)
    test_dataset={"premise":[],"hypothesis":[],"label":[]}
    print("test data")
    for line in tqdm(data):
        line = json.loads(line)
        if line['gold_label'] not in labels:
            # print(line['gold_label'])
            continue
        test_dataset["premise"].append(line['sentence1'])
        test_dataset["hypothesis"].append(line['sentence2'])
        test_dataset["label"].append(line['gold_label'])
    f.close()
    
    return train_dataset,dev_dataset,test_dataset

def getWord2index(dataset):
    word2index = {"":0,UNK:1,PAD:2}
    for sentence in dataset["premise"]:
        for word in tokenize(sentence):
            if word not in word2index:
                word2index[word] = len(word2index)
    for sentence in dataset["hypothesis"]:
        for word in tokenize(sentence):
            if word not in word2index:
                word2index[word] = len(word2index)
    return word2index

def getEmbeddingMatrix(word2index,emb_size=300):
    embedding_matrix = np.zeros((len(word2index),emb_size),dtype=np.float32)
    for word, i in word2index.items():
        if i==0:
            embedding_matrix[i] = np.zeros(emb_size)
        elif word in embeddings_index:
            embedding_matrix[i] = embeddings_index[word]
        else:
            embedding_matrix[i] = np.random.uniform(-0.25,0.25,emb_size)
    return embedding_matrix

def getLabel2index(dataset):
    label2index = {"entailment":0,"neutral":1,"contradiction":2}
    return label2index

def getSentence2vector(sentence,word2index,padLength=32):
    sentence = tokenize(sentence)
    vector = []
    for word in sentence:
        if word in word2index:
            vector.append(word2index[word])
        else:
            vector.append(word2index[UNK])
    
    if len(vector)>padLength:
        vector=vector[:padLength]
    else:
        for i in range(padLength-len(vector)):
            vector.append(word2index[PAD])
    
    if(len(vector)!=padLength):
        print("Error in vector length")
    return np.array(vector)   

In [5]:
def preprocess(dataset,word2index,label2index):
    dataset["premise"] = [getSentence2vector(sentence,word2index) for sentence in tqdm(dataset["premise"])]
    dataset["hypothesis"] = [getSentence2vector(sentence,word2index) for sentence in tqdm(dataset["hypothesis"])]
    dataset["label"] = [label2index[label] for label in dataset["label"]]
    return dataset

def getDataloader(dataset,batch_size=32):
    # print(len(dataset["premise"]))
    premise = torch.tensor(dataset["premise"],dtype=torch.long)
    hypothesis = torch.tensor(dataset["hypothesis"],dtype=torch.long)
    labels = torch.tensor(dataset["label"],dtype=torch.long)
    dataset = torch.utils.data.TensorDataset(premise,hypothesis,labels)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader

def prepData(dataset_name="mnli"):
    train_dataset,dev_dataset,test_dataset = getDataset(dataset_name)
    word2index = getWord2index(train_dataset)
    label2index = getLabel2index(train_dataset)
    embedding_matrix = getEmbeddingMatrix(word2index)
    #preprocess datasets
    print("preprocess")
    train_dataset = preprocess(train_dataset,word2index,label2index)
    dev_dataset = preprocess(dev_dataset,word2index,label2index)
    test_dataset = preprocess(test_dataset,word2index,label2index)
    #dataloader
    train_dataloader = getDataloader(train_dataset)
    dev_dataloader = getDataloader(dev_dataset)
    test_dataloader = getDataloader(test_dataset)

    return train_dataloader,dev_dataloader,test_dataloader,embedding_matrix,word2index


In [6]:
def train(model,train_dataloader,dev_dataloader,test_dataloader,optimizer,criterion,datasetName,epochs=5):
    
    
    f=open(f"../reports_and_results/report_lstm_nli_{datasetName}_foreachepoch.txt",'w')
    
    total_loss=[]
    total_acc=[]
    for epoch in range(1,1+epochs):
        ep_loss=0
        ep_acc=0
        train_total=0
        val_total=0
        test_total=0
        print("\n\nEpoch: ",epoch)
        #training
        print("Training")
        for batch in tqdm(train_dataloader):
            prem, hyp, label = batch
            # print(len(prem),len(hyp),len(label))
            optimizer.zero_grad()
            prem, hyp, label= prem.to(device), hyp.to(device), label.to(device)
            output = model(prem, hyp)
            # print(len(output))
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()
            ep_loss+=loss.item()
            total_loss.append(loss.item())
            acc = (output.argmax(1) == label).sum().item()
            train_total+=label.size(0)
            ep_acc+=acc
        total_acc.append((ep_acc/train_total))
        
        print("Validation")
        #validation
        val_y_true = []
        val_y_pred = []
        with torch.no_grad():
            ep_val_loss=0
            ep_val_acc=0
            for batch in tqdm(dev_dataloader):
                prem, hyp, label = batch
                prem, hyp, label= prem.to(device), hyp.to(device), label.to(device)
                output = model(prem, hyp)
                loss = criterion(output, label)
                val_y_true.extend(label.cpu())
                val_y_pred.extend(output.cpu().argmax(1))
                ep_val_loss+=loss.item()
                acc = (output.argmax(1) == label).sum().item()
                val_total+=label.size(0)
                ep_val_acc+=acc
        
        print("Test")
        y_true = []
        y_pred = []
        with torch.no_grad():
            ep_test_loss=0
            ep_test_acc=0
            for batch in tqdm(test_dataloader):
                prem, hyp, label = batch
                prem, hyp, label= prem.to(device), hyp.to(device), label.to(device)
                output = model(prem, hyp)
                y_true.extend(label.cpu())
                y_pred.extend(output.cpu().argmax(1))
                loss = criterion(output, label)
                ep_test_loss+=loss.item()
                acc = (output.argmax(1) == label).sum().item()
                test_total+=label.size(0)
                ep_test_acc+=acc
                
        print("Epoch: ",epoch," Train Loss: ",ep_loss/len(train_dataloader),"Train Accuracy: ",ep_acc/train_total)

        print("Epoch: ",epoch," Val Loss: ",ep_val_loss/len(dev_dataloader)," Val Accuracy: ",ep_val_acc/val_total)

        print("Epoch: ",epoch,"Test Loss: ",ep_test_loss/len(test_dataloader)," Test Accuracy: ",ep_test_acc/test_total)

        print()
        print("Validation Classification report")

        print(classification_report(val_y_true, val_y_pred, target_names=["entailment","neutral","contradiction"]))
        print("Test Classification report")
        print(classification_report(y_true, y_pred, target_names=["entailment","neutral","contradiction"]))
                
        orginal_stdout=sys.stdout
        sys.stdout=f 
        print("\n\nEpoch: ",epoch)

        
        print("Epoch: ",epoch," Train Loss: ",ep_loss/len(train_dataloader),"Train Accuracy: ",ep_acc/train_total)

        print("Epoch: ",epoch," Val Loss: ",ep_val_loss/len(dev_dataloader)," Val Accuracy: ",ep_val_acc/val_total)

        print("Epoch: ",epoch,"Test Loss: ",ep_test_loss/len(test_dataloader)," Test Accuracy: ",ep_test_acc/test_total)

        print()
        print("Validation Classification report")

        print(classification_report(val_y_true, val_y_pred, target_names=["entailment","neutral","contradiction"]))
        print("Test Classification report")
        print(classification_report(y_true, y_pred, target_names=["entailment","neutral","contradiction"]))
        #save model
        torch.save(model.state_dict(), f"../models/model_lstm_nli_{datasetName}_ep_{epoch}.pt")
        sys.stdout=orginal_stdout
        
    sys.stdout=f
    print("Total Train Loss",sum(total_loss)/len(total_loss)," Total Train Accuracy: ",sum(total_acc)/len(total_acc))
              
    sys.stdout=orginal_stdout
    print("Total Train Loss",sum(total_loss)/len(total_loss)," Total Train Accuracy: ",sum(total_acc)/len(total_acc))
    f.close()
    return model
    

In [7]:
def test(model,test_dataloader,criterion):
#test
    y_true = []
    y_pred = []
    with torch.no_grad():
        ep_test_loss=0
        ep_test_acc=0
        for batch in tqdm(test_dataloader):
            prem, hyp, label = batch
            prem, hyp, label= prem.to(device), hyp.to(device), label.to(device)
            output = model(prem, hyp)
            y_true.extend(label.cpu())
            y_pred.extend(output.cpu().argmax(1))
            loss = criterion(output, label)
            ep_test_loss+=loss.item()
            acc = (output.argmax(1) == label).sum().item()
            ep_test_acc+=acc
        print("Test Loss: ",ep_test_loss/len(test_dataloader)," Test Accuracy: ",ep_test_acc/len(test_dataloader))
        print(classification_report(y_true, y_pred, target_names=["entailment","neutral","contradiction"]))
        print(confusion_matrix(y_true, y_pred))
        


In [8]:
EMBEDDING_DIM = 300
HIDDEN_DIM = 100
OUTPUT_DIM = 3
N_FC_LAYERS = 2
DROPOUT = 0.3
EPOCHS = 30
lr=0.001

In [9]:
datasetName="mnli"

mtrain_dataloader,mdev_dataloader,mtest_dataloader,membedding_matrix,mword2index = prepData(datasetName)

INPUT_DIM = len(mword2index)

train data


100%|███████████████████████████████| 392702/392702 [00:02<00:00, 140816.75it/s]


dev data


100%|█████████████████████████████████| 10000/10000 [00:00<00:00, 144518.55it/s]


test data


100%|█████████████████████████████████| 10000/10000 [00:00<00:00, 124308.09it/s]


preprocess


100%|████████████████████████████████| 392702/392702 [00:32<00:00, 11935.87it/s]
100%|████████████████████████████████| 392702/392702 [00:20<00:00, 19210.79it/s]
100%|████████████████████████████████████| 9815/9815 [00:00<00:00, 11654.31it/s]
100%|████████████████████████████████████| 9815/9815 [00:00<00:00, 22245.40it/s]
100%|████████████████████████████████████| 9832/9832 [00:00<00:00, 11782.76it/s]
100%|████████████████████████████████████| 9832/9832 [00:00<00:00, 20228.30it/s]
  premise = torch.tensor(dataset["premise"],dtype=torch.long)


In [10]:


#initialize the model with above parameters
mnli_model = NLI_LSTM(membedding_matrix,INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM, N_FC_LAYERS, DROPOUT)
#adam optimizer
mnli_model.to(device)
optimizer = optim.Adam(mnli_model.parameters())
#loss function
criterion = nn.CrossEntropyLoss()
#train
mnli_model = train(mnli_model,mtrain_dataloader,mdev_dataloader,mtest_dataloader,optimizer,criterion,datasetName,EPOCHS)

#test
test(mnli_model,mtest_dataloader,criterion)




Epoch:  1
Training


100%|████████████████████████████████████| 12272/12272 [01:00<00:00, 204.21it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 449.48it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 359.12it/s]


Epoch:  1  Train Loss:  1.1731529743538143 Train Accuracy:  0.33368814011642417
Epoch:  1  Val Loss:  1.1409540351128344  Val Accuracy:  0.3289862455425369
Epoch:  1 Test Loss:  1.1437471690890078  Test Accuracy:  0.32577298616761596

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.36      0.23      0.28      3479
      neutral       0.32      0.74      0.45      3123
contradiction       0.31      0.04      0.07      3213

     accuracy                           0.33      9815
    macro avg       0.33      0.34      0.27      9815
 weighted avg       0.33      0.33      0.26      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.35      0.22      0.27      3463
      neutral       0.32      0.74      0.45      3129
contradiction       0.30      0.04      0.07      3240

     accuracy                           0.33      9832
    macro avg       0.33      0.33      

100%|████████████████████████████████████| 12272/12272 [00:55<00:00, 219.76it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 489.49it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 426.01it/s]


Epoch:  2  Train Loss:  1.1618054716069508 Train Accuracy:  0.3334691445421719
Epoch:  2  Val Loss:  1.1563694543092957  Val Accuracy:  0.32949566989302087
Epoch:  2 Test Loss:  1.1560047466259498  Test Accuracy:  0.3224165988608625

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.36      0.02      0.04      3479
      neutral       0.33      0.43      0.37      3123
contradiction       0.33      0.56      0.42      3213

     accuracy                           0.33      9815
    macro avg       0.34      0.34      0.28      9815
 weighted avg       0.34      0.33      0.27      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.33      0.02      0.04      3463
      neutral       0.32      0.42      0.36      3129
contradiction       0.33      0.55      0.41      3240

     accuracy                           0.32      9832
    macro avg       0.32      0.33      0

100%|████████████████████████████████████| 12272/12272 [01:04<00:00, 189.80it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 391.71it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 455.05it/s]


Epoch:  3  Train Loss:  1.161412314760755 Train Accuracy:  0.33359646755045813
Epoch:  3  Val Loss:  1.1265527712794003  Val Accuracy:  0.3410086602139582
Epoch:  3 Test Loss:  1.120033765768076  Test Accuracy:  0.350793327908869

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.35      0.62      0.45      3479
      neutral       0.31      0.08      0.13      3123
contradiction       0.32      0.29      0.30      3213

     accuracy                           0.34      9815
    macro avg       0.33      0.33      0.29      9815
 weighted avg       0.33      0.34      0.30      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.36      0.63      0.46      3463
      neutral       0.32      0.08      0.13      3129
contradiction       0.34      0.30      0.32      3240

     accuracy                           0.35      9832
    macro avg       0.34      0.34      0.30

100%|████████████████████████████████████| 12272/12272 [01:07<00:00, 181.69it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:01<00:00, 283.92it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 406.48it/s]


Epoch:  4  Train Loss:  1.138098651443692 Train Accuracy:  0.37396295409750907
Epoch:  4  Val Loss:  1.0596984912207539  Val Accuracy:  0.467142129393785
Epoch:  4 Test Loss:  1.0488946437835693  Test Accuracy:  0.4750813669650122

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.42      0.72      0.53      3479
      neutral       0.53      0.17      0.25      3123
contradiction       0.53      0.49      0.51      3213

     accuracy                           0.47      9815
    macro avg       0.49      0.46      0.43      9815
 weighted avg       0.49      0.47      0.44      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.43      0.71      0.53      3463
      neutral       0.50      0.17      0.26      3129
contradiction       0.56      0.51      0.53      3240

     accuracy                           0.48      9832
    macro avg       0.50      0.47      0.4

100%|████████████████████████████████████| 12272/12272 [01:14<00:00, 163.93it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 377.59it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 398.49it/s]


Epoch:  5  Train Loss:  1.0419270360728208 Train Accuracy:  0.47545721692275567
Epoch:  5  Val Loss:  1.012665235064317  Val Accuracy:  0.5006622516556292
Epoch:  5 Test Loss:  1.017155322935674  Test Accuracy:  0.4909479251423922

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.54      0.31      0.40      3479
      neutral       0.45      0.54      0.49      3123
contradiction       0.53      0.67      0.59      3213

     accuracy                           0.50      9815
    macro avg       0.51      0.51      0.49      9815
 weighted avg       0.51      0.50      0.49      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.53      0.28      0.37      3463
      neutral       0.42      0.55      0.48      3129
contradiction       0.54      0.66      0.59      3240

     accuracy                           0.49      9832
    macro avg       0.50      0.50      0.4

100%|████████████████████████████████████| 12272/12272 [01:03<00:00, 191.76it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 415.11it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 454.85it/s]


Epoch:  6  Train Loss:  1.0044734288963844 Train Accuracy:  0.5050292588273042
Epoch:  6  Val Loss:  0.9609973267545917  Val Accuracy:  0.5524197656647988
Epoch:  6 Test Loss:  0.9574659060348164  Test Accuracy:  0.5521765663140765

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.50      0.75      0.60      3479
      neutral       0.54      0.36      0.43      3123
contradiction       0.68      0.52      0.59      3213

     accuracy                           0.55      9815
    macro avg       0.57      0.55      0.54      9815
 weighted avg       0.57      0.55      0.54      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.49      0.75      0.60      3463
      neutral       0.52      0.37      0.43      3129
contradiction       0.72      0.51      0.60      3240

     accuracy                           0.55      9832
    macro avg       0.58      0.55      0.

100%|████████████████████████████████████| 12272/12272 [01:08<00:00, 179.72it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 353.94it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 363.61it/s]


Epoch:  7  Train Loss:  0.9824896822386024 Train Accuracy:  0.5222305972467672
Epoch:  7  Val Loss:  0.9368037018403168  Val Accuracy:  0.5602649006622517
Epoch:  7 Test Loss:  0.9341611931850384  Test Accuracy:  0.5592961757526445

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.51      0.72      0.60      3479
      neutral       0.56      0.35      0.43      3123
contradiction       0.65      0.59      0.62      3213

     accuracy                           0.56      9815
    macro avg       0.57      0.55      0.55      9815
 weighted avg       0.57      0.56      0.55      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.50      0.72      0.59      3463
      neutral       0.53      0.36      0.43      3129
contradiction       0.68      0.58      0.63      3240

     accuracy                           0.56      9832
    macro avg       0.57      0.55      0.

100%|████████████████████████████████████| 12272/12272 [01:07<00:00, 180.54it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 477.13it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 442.67it/s]


Epoch:  8  Train Loss:  0.9682922151323403 Train Accuracy:  0.5342015064858341
Epoch:  8  Val Loss:  0.9519692882652779  Val Accuracy:  0.5472236372898625
Epoch:  8 Test Loss:  0.947904615046142  Test Accuracy:  0.5452603742880391

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.53      0.56      0.55      3479
      neutral       0.46      0.59      0.51      3123
contradiction       0.74      0.49      0.59      3213

     accuracy                           0.55      9815
    macro avg       0.58      0.55      0.55      9815
 weighted avg       0.58      0.55      0.55      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.54      0.54      0.54      3463
      neutral       0.44      0.60      0.51      3129
contradiction       0.77      0.50      0.60      3240

     accuracy                           0.55      9832
    macro avg       0.58      0.55      0.5

100%|████████████████████████████████████| 12272/12272 [01:00<00:00, 203.45it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 330.05it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 346.79it/s]


Epoch:  9  Train Loss:  0.9589050698787401 Train Accuracy:  0.543396774144262
Epoch:  9  Val Loss:  0.9464897761904068  Val Accuracy:  0.5463066734589913
Epoch:  9 Test Loss:  0.9360492823572902  Test Accuracy:  0.5471928397070789

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.56      0.45      0.50      3479
      neutral       0.48      0.55      0.51      3123
contradiction       0.60      0.64      0.62      3213

     accuracy                           0.55      9815
    macro avg       0.55      0.55      0.55      9815
 weighted avg       0.55      0.55      0.54      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.55      0.45      0.50      3463
      neutral       0.47      0.57      0.51      3129
contradiction       0.64      0.63      0.63      3240

     accuracy                           0.55      9832
    macro avg       0.55      0.55      0.5

100%|████████████████████████████████████| 12272/12272 [01:00<00:00, 202.35it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 474.71it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 481.25it/s]


Epoch:  10  Train Loss:  0.9473146311550206 Train Accuracy:  0.5505446878294483
Epoch:  10  Val Loss:  0.9634040180557325  Val Accuracy:  0.5416199694345389
Epoch:  10 Test Loss:  0.964583180748023  Test Accuracy:  0.5378356387306753

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.57      0.50      0.53      3479
      neutral       0.44      0.69      0.54      3123
contradiction       0.78      0.44      0.56      3213

     accuracy                           0.54      9815
    macro avg       0.59      0.54      0.54      9815
 weighted avg       0.59      0.54      0.54      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.55      0.47      0.51      3463
      neutral       0.43      0.70      0.53      3129
contradiction       0.80      0.45      0.58      3240

     accuracy                           0.54      9832
    macro avg       0.60      0.54      

100%|████████████████████████████████████| 12272/12272 [01:02<00:00, 195.60it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:01<00:00, 233.61it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:02<00:00, 137.25it/s]


Epoch:  11  Train Loss:  0.9410930720037711 Train Accuracy:  0.5551817917912311
Epoch:  11  Val Loss:  0.9164088216977322  Val Accuracy:  0.5778909831889965
Epoch:  11 Test Loss:  0.9077475403810477  Test Accuracy:  0.5781122864117169

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.55      0.65      0.59      3479
      neutral       0.55      0.45      0.49      3123
contradiction       0.65      0.63      0.64      3213

     accuracy                           0.58      9815
    macro avg       0.58      0.57      0.57      9815
 weighted avg       0.58      0.58      0.58      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.55      0.65      0.59      3463
      neutral       0.52      0.45      0.48      3129
contradiction       0.68      0.63      0.65      3240

     accuracy                           0.58      9832
    macro avg       0.58      0.57     

100%|████████████████████████████████████| 12272/12272 [01:27<00:00, 139.85it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 386.37it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 333.66it/s]


Epoch:  12  Train Loss:  0.9351510767287039 Train Accuracy:  0.5600047873451116
Epoch:  12  Val Loss:  1.1805706020286884  Val Accuracy:  0.4558329088130413
Epoch:  12 Test Loss:  1.1732664988799528  Test Accuracy:  0.45850284784377543

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.73      0.05      0.09      3479
      neutral       0.46      0.50      0.48      3123
contradiction       0.45      0.86      0.59      3213

     accuracy                           0.46      9815
    macro avg       0.54      0.47      0.38      9815
 weighted avg       0.55      0.46      0.37      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.73      0.04      0.07      3463
      neutral       0.45      0.53      0.49      3129
contradiction       0.46      0.84      0.59      3240

     accuracy                           0.46      9832
    macro avg       0.54      0.47    

100%|████████████████████████████████████| 12272/12272 [01:53<00:00, 108.14it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 479.88it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 489.31it/s]


Epoch:  13  Train Loss:  0.9302830954569221 Train Accuracy:  0.5636054820194447
Epoch:  13  Val Loss:  1.1294607953062274  Val Accuracy:  0.4684666327050433
Epoch:  13 Test Loss:  1.13567870681162  Test Accuracy:  0.47080960130187144

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.69      0.03      0.06      3479
      neutral       0.38      0.85      0.52      3123
contradiction       0.68      0.58      0.63      3213

     accuracy                           0.47      9815
    macro avg       0.58      0.48      0.40      9815
 weighted avg       0.59      0.47      0.39      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.63      0.02      0.04      3463
      neutral       0.38      0.87      0.53      3129
contradiction       0.72      0.57      0.63      3240

     accuracy                           0.47      9832
    macro avg       0.58      0.49      

100%|████████████████████████████████████| 12272/12272 [01:09<00:00, 177.73it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 390.40it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 411.69it/s]


Epoch:  14  Train Loss:  0.9272710635791857 Train Accuracy:  0.565538245285229
Epoch:  14  Val Loss:  0.9084759902099833  Val Accuracy:  0.5865511971472236
Epoch:  14 Test Loss:  0.9007213171813395  Test Accuracy:  0.5812652563059398

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.55      0.67      0.61      3479
      neutral       0.54      0.49      0.52      3123
contradiction       0.69      0.59      0.63      3213

     accuracy                           0.59      9815
    macro avg       0.59      0.58      0.58      9815
 weighted avg       0.59      0.59      0.59      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.54      0.68      0.60      3463
      neutral       0.52      0.48      0.50      3129
contradiction       0.71      0.57      0.64      3240

     accuracy                           0.58      9832
    macro avg       0.59      0.58      

100%|████████████████████████████████████| 12272/12272 [01:15<00:00, 163.35it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 380.45it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 414.21it/s]


Epoch:  15  Train Loss:  0.9238507480930046 Train Accuracy:  0.568089798371284
Epoch:  15  Val Loss:  0.9463070549483408  Val Accuracy:  0.5667855323484463
Epoch:  15 Test Loss:  0.9423055515273825  Test Accuracy:  0.5685516680227828

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.49      0.86      0.62      3479
      neutral       0.60      0.31      0.41      3123
contradiction       0.76      0.50      0.60      3213

     accuracy                           0.57      9815
    macro avg       0.62      0.56      0.55      9815
 weighted avg       0.61      0.57      0.55      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.49      0.86      0.62      3463
      neutral       0.58      0.31      0.40      3129
contradiction       0.80      0.50      0.62      3240

     accuracy                           0.57      9832
    macro avg       0.62      0.56      

100%|████████████████████████████████████| 12272/12272 [01:30<00:00, 136.01it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 363.65it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 309.81it/s]


Epoch:  16  Train Loss:  0.9179193357338418 Train Accuracy:  0.5725766611832891
Epoch:  16  Val Loss:  0.9305346608161926  Val Accuracy:  0.5640346408558329
Epoch:  16 Test Loss:  0.9208156679357801  Test Accuracy:  0.5697721724979659

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.53      0.67      0.59      3479
      neutral       0.61      0.32      0.42      3123
contradiction       0.59      0.68      0.63      3213

     accuracy                           0.56      9815
    macro avg       0.57      0.56      0.55      9815
 weighted avg       0.57      0.56      0.55      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.53      0.70      0.60      3463
      neutral       0.59      0.32      0.41      3129
contradiction       0.61      0.67      0.64      3240

     accuracy                           0.57      9832
    macro avg       0.58      0.56     

100%|████████████████████████████████████| 12272/12272 [01:25<00:00, 143.23it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:01<00:00, 262.97it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 402.25it/s]


Epoch:  17  Train Loss:  0.9164433032371385 Train Accuracy:  0.573332959852509
Epoch:  17  Val Loss:  0.9278232327113323  Val Accuracy:  0.5772796739684157
Epoch:  17 Test Loss:  0.9262162129600326  Test Accuracy:  0.575366151342555

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.56      0.59      0.58      3479
      neutral       0.50      0.60      0.55      3123
contradiction       0.72      0.54      0.62      3213

     accuracy                           0.58      9815
    macro avg       0.59      0.58      0.58      9815
 weighted avg       0.59      0.58      0.58      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.56      0.58      0.57      3463
      neutral       0.49      0.60      0.54      3129
contradiction       0.76      0.55      0.63      3240

     accuracy                           0.58      9832
    macro avg       0.60      0.58      0

100%|████████████████████████████████████| 12272/12272 [01:44<00:00, 117.43it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 403.54it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 349.57it/s]


Epoch:  18  Train Loss:  0.9133780055486317 Train Accuracy:  0.5746571191386853
Epoch:  18  Val Loss:  0.9312265827912073  Val Accuracy:  0.5769740193581253
Epoch:  18 Test Loss:  0.9372664484884832  Test Accuracy:  0.5727217249796582

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.56      0.63      0.59      3479
      neutral       0.49      0.61      0.54      3123
contradiction       0.77      0.49      0.60      3213

     accuracy                           0.58      9815
    macro avg       0.61      0.58      0.58      9815
 weighted avg       0.61      0.58      0.58      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.55      0.63      0.59      3463
      neutral       0.48      0.61      0.54      3129
contradiction       0.80      0.48      0.60      3240

     accuracy                           0.57      9832
    macro avg       0.61      0.57     

100%|████████████████████████████████████| 12272/12272 [01:53<00:00, 107.88it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:01<00:00, 212.31it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:01<00:00, 160.88it/s]


Epoch:  19  Train Loss:  0.9094759987101367 Train Accuracy:  0.5779318669118059
Epoch:  19  Val Loss:  0.9360609627313645  Val Accuracy:  0.5581253183902191
Epoch:  19 Test Loss:  0.9276526343899888  Test Accuracy:  0.5585842148087876

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.61      0.40      0.49      3479
      neutral       0.47      0.64      0.54      3123
contradiction       0.64      0.64      0.64      3213

     accuracy                           0.56      9815
    macro avg       0.57      0.56      0.56      9815
 weighted avg       0.57      0.56      0.55      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.60      0.41      0.49      3463
      neutral       0.46      0.65      0.54      3129
contradiction       0.67      0.63      0.64      3240

     accuracy                           0.56      9832
    macro avg       0.58      0.56     

100%|████████████████████████████████████| 12272/12272 [01:25<00:00, 144.01it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 333.99it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 340.23it/s]


Epoch:  20  Train Loss:  0.908946837071509 Train Accuracy:  0.5786907120411915
Epoch:  20  Val Loss:  0.9344849468054135  Val Accuracy:  0.5660723382577687
Epoch:  20 Test Loss:  0.9281535674999286  Test Accuracy:  0.5600081366965012

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.60      0.45      0.52      3479
      neutral       0.48      0.62      0.54      3123
contradiction       0.64      0.64      0.64      3213

     accuracy                           0.57      9815
    macro avg       0.58      0.57      0.57      9815
 weighted avg       0.58      0.57      0.56      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.59      0.43      0.50      3463
      neutral       0.46      0.64      0.54      3129
contradiction       0.67      0.63      0.65      3240

     accuracy                           0.56      9832
    macro avg       0.58      0.56      

100%|████████████████████████████████████| 12272/12272 [01:24<00:00, 145.67it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:01<00:00, 296.76it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 331.47it/s]


Epoch:  21  Train Loss:  0.9052156918552234 Train Accuracy:  0.5805929177849871
Epoch:  21  Val Loss:  0.9060817368644069  Val Accuracy:  0.5828833418237391
Epoch:  21 Test Loss:  0.8992260882219711  Test Accuracy:  0.5832994304312449

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.56      0.63      0.59      3479
      neutral       0.51      0.59      0.55      3123
contradiction       0.74      0.52      0.61      3213

     accuracy                           0.58      9815
    macro avg       0.60      0.58      0.58      9815
 weighted avg       0.60      0.58      0.58      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.55      0.63      0.59      3463
      neutral       0.50      0.59      0.54      3129
contradiction       0.78      0.52      0.62      3240

     accuracy                           0.58      9832
    macro avg       0.61      0.58     

100%|████████████████████████████████████| 12272/12272 [01:15<00:00, 162.97it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 319.61it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 369.80it/s]


Epoch:  22  Train Loss:  0.9052451565944271 Train Accuracy:  0.5812218934459208
Epoch:  22  Val Loss:  0.9377522862695328  Val Accuracy:  0.5558838512480897
Epoch:  22 Test Loss:  0.9342391785089071  Test Accuracy:  0.5581773799837266

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.62      0.44      0.51      3479
      neutral       0.47      0.62      0.54      3123
contradiction       0.61      0.62      0.62      3213

     accuracy                           0.56      9815
    macro avg       0.57      0.56      0.56      9815
 weighted avg       0.57      0.56      0.55      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.60      0.42      0.50      3463
      neutral       0.47      0.65      0.55      3129
contradiction       0.65      0.61      0.63      3240

     accuracy                           0.56      9832
    macro avg       0.57      0.56     

100%|████████████████████████████████████| 12272/12272 [01:12<00:00, 168.32it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 416.84it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 410.00it/s]


Epoch:  23  Train Loss:  0.9007494297616835 Train Accuracy:  0.5824467407856339
Epoch:  23  Val Loss:  0.9022248139793011  Val Accuracy:  0.5900152827305145
Epoch:  23 Test Loss:  0.8868079663484127  Test Accuracy:  0.5939788445890968

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.55      0.69      0.61      3479
      neutral       0.59      0.42      0.49      3123
contradiction       0.64      0.64      0.64      3213

     accuracy                           0.59      9815
    macro avg       0.59      0.59      0.58      9815
 weighted avg       0.59      0.59      0.58      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.55      0.73      0.63      3463
      neutral       0.58      0.41      0.48      3129
contradiction       0.67      0.63      0.65      3240

     accuracy                           0.59      9832
    macro avg       0.60      0.59     

100%|████████████████████████████████████| 12272/12272 [01:18<00:00, 155.40it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:01<00:00, 290.81it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 411.74it/s]


Epoch:  24  Train Loss:  0.8993394491705932 Train Accuracy:  0.5849295394472144
Epoch:  24  Val Loss:  0.9263712208511775  Val Accuracy:  0.5618950585838003
Epoch:  24 Test Loss:  0.9189215641130101  Test Accuracy:  0.5587876322213181

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.60      0.53      0.56      3479
      neutral       0.56      0.40      0.47      3123
contradiction       0.54      0.75      0.63      3213

     accuracy                           0.56      9815
    macro avg       0.57      0.56      0.55      9815
 weighted avg       0.57      0.56      0.55      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.58      0.54      0.56      3463
      neutral       0.56      0.39      0.46      3129
contradiction       0.54      0.74      0.62      3240

     accuracy                           0.56      9832
    macro avg       0.56      0.56     

100%|████████████████████████████████████| 12272/12272 [01:42<00:00, 120.02it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 319.82it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:01<00:00, 197.96it/s]


Epoch:  25  Train Loss:  0.8970612921962099 Train Accuracy:  0.5856578270546113
Epoch:  25  Val Loss:  0.9086342966517719  Val Accuracy:  0.5778909831889965
Epoch:  25 Test Loss:  0.8957645318337849  Test Accuracy:  0.5868592351505288

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.59      0.54      0.57      3479
      neutral       0.52      0.56      0.54      3123
contradiction       0.62      0.64      0.63      3213

     accuracy                           0.58      9815
    macro avg       0.58      0.58      0.58      9815
 weighted avg       0.58      0.58      0.58      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.59      0.56      0.58      3463
      neutral       0.51      0.56      0.54      3129
contradiction       0.66      0.64      0.65      3240

     accuracy                           0.59      9832
    macro avg       0.59      0.59     

100%|████████████████████████████████████| 12272/12272 [01:16<00:00, 161.45it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 464.77it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 449.47it/s]


Epoch:  26  Train Loss:  0.8984993392652607 Train Accuracy:  0.5851663602426267
Epoch:  26  Val Loss:  0.9404449055171556  Val Accuracy:  0.5368313805399898
Epoch:  26 Test Loss:  0.9396243585007531  Test Accuracy:  0.5348860862489829

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.61      0.35      0.44      3479
      neutral       0.43      0.73      0.54      3123
contradiction       0.71      0.56      0.62      3213

     accuracy                           0.54      9815
    macro avg       0.58      0.54      0.53      9815
 weighted avg       0.59      0.54      0.53      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.61      0.34      0.43      3463
      neutral       0.42      0.74      0.54      3129
contradiction       0.74      0.55      0.63      3240

     accuracy                           0.53      9832
    macro avg       0.59      0.54     

100%|████████████████████████████████████| 12272/12272 [01:03<00:00, 194.59it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 415.73it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 426.97it/s]


Epoch:  27  Train Loss:  0.8964591353422582 Train Accuracy:  0.587043101384765
Epoch:  27  Val Loss:  0.9382103545657975  Val Accuracy:  0.5533367294956699
Epoch:  27 Test Loss:  0.9327427897747461  Test Accuracy:  0.5574654190398698

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.61      0.41      0.49      3479
      neutral       0.46      0.66      0.54      3123
contradiction       0.64      0.61      0.62      3213

     accuracy                           0.55      9815
    macro avg       0.57      0.56      0.55      9815
 weighted avg       0.57      0.55      0.55      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.60      0.42      0.49      3463
      neutral       0.46      0.68      0.55      3129
contradiction       0.68      0.59      0.63      3240

     accuracy                           0.56      9832
    macro avg       0.58      0.56      

100%|████████████████████████████████████| 12272/12272 [01:01<00:00, 199.72it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 314.54it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:01<00:00, 251.63it/s]


Epoch:  28  Train Loss:  0.8944259446959413 Train Accuracy:  0.5884843978385647
Epoch:  28  Val Loss:  0.9232400471302119  Val Accuracy:  0.5918492103922568
Epoch:  28 Test Loss:  0.9087103193069433  Test Accuracy:  0.5907241659886087

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.54      0.76      0.63      3479
      neutral       0.60      0.38      0.46      3123
contradiction       0.68      0.62      0.64      3213

     accuracy                           0.59      9815
    macro avg       0.60      0.59      0.58      9815
 weighted avg       0.60      0.59      0.58      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.54      0.78      0.63      3463
      neutral       0.57      0.37      0.45      3129
contradiction       0.70      0.61      0.65      3240

     accuracy                           0.59      9832
    macro avg       0.60      0.58     

100%|████████████████████████████████████| 12272/12272 [01:03<00:00, 193.70it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 466.13it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 401.47it/s]


Epoch:  29  Train Loss:  0.8926838629452859 Train Accuracy:  0.5893858447372308
Epoch:  29  Val Loss:  1.0519527507527255  Val Accuracy:  0.521141110545084
Epoch:  29 Test Loss:  1.0209278003735975  Test Accuracy:  0.5252237591537836

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.67      0.26      0.37      3479
      neutral       0.50      0.51      0.51      3123
contradiction       0.49      0.81      0.62      3213

     accuracy                           0.52      9815
    macro avg       0.56      0.53      0.50      9815
 weighted avg       0.56      0.52      0.50      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.65      0.26      0.37      3463
      neutral       0.49      0.54      0.51      3129
contradiction       0.51      0.80      0.63      3240

     accuracy                           0.53      9832
    macro avg       0.55      0.53      

100%|████████████████████████████████████| 12272/12272 [01:02<00:00, 195.36it/s]


Validation


100%|████████████████████████████████████████| 307/307 [00:00<00:00, 486.47it/s]


Test


100%|████████████████████████████████████████| 308/308 [00:01<00:00, 200.47it/s]


Epoch:  30  Train Loss:  0.8908687546664686 Train Accuracy:  0.5896201190724774
Epoch:  30  Val Loss:  0.967972221902605  Val Accuracy:  0.5364238410596026
Epoch:  30 Test Loss:  0.9742052398331753  Test Accuracy:  0.5312245728234337

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.61      0.36      0.45      3479
      neutral       0.43      0.79      0.55      3123
contradiction       0.77      0.49      0.60      3213

     accuracy                           0.54      9815
    macro avg       0.60      0.54      0.53      9815
 weighted avg       0.60      0.54      0.53      9815

Test Classification report
               precision    recall  f1-score   support

   entailment       0.59      0.33      0.42      3463
      neutral       0.42      0.79      0.55      3129
contradiction       0.81      0.49      0.61      3240

     accuracy                           0.53      9832
    macro avg       0.61      0.54      

100%|████████████████████████████████████████| 308/308 [00:00<00:00, 339.57it/s]


Test Loss:  0.9769094249644836  Test Accuracy:  16.801948051948052
               precision    recall  f1-score   support

   entailment       0.60      0.34      0.44      3463
      neutral       0.41      0.79      0.54      3129
contradiction       0.79      0.47      0.59      3240

     accuracy                           0.53      9832
    macro avg       0.60      0.53      0.52      9832
 weighted avg       0.60      0.53      0.52      9832

[[1182 2120  161]
 [ 428 2457  244]
 [ 352 1352 1536]]


In [11]:
datasetName="snli"

strain_dataloader,sdev_dataloader,stest_dataloader,sembedding_matrix,sword2index = prepData(datasetName)
#intialize the model
INPUT_DIM = len(sword2index)

train data


100%|████████████████████████████████| 550152/550152 [00:05<00:00, 94881.45it/s]


dev data


100%|██████████████████████████████████| 10000/10000 [00:00<00:00, 92246.98it/s]


test data


100%|██████████████████████████████████| 10000/10000 [00:00<00:00, 89960.00it/s]


preprocess


100%|████████████████████████████████| 549367/549367 [00:29<00:00, 18442.58it/s]
100%|████████████████████████████████| 549367/549367 [00:22<00:00, 24278.01it/s]
100%|████████████████████████████████████| 9842/9842 [00:00<00:00, 18312.99it/s]
100%|████████████████████████████████████| 9842/9842 [00:00<00:00, 24727.72it/s]
100%|████████████████████████████████████| 9824/9824 [00:00<00:00, 18929.52it/s]
100%|████████████████████████████████████| 9824/9824 [00:00<00:00, 24789.07it/s]


In [None]:
#initialize the model with above parameters
snli_model = NLI_LSTM(sembedding_matrix,INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM, N_FC_LAYERS, DROPOUT)
snli_model.to(device)
#adam optimizer
optimizer = optim.Adam(snli_model.parameters())
#loss function
criterion = nn.CrossEntropyLoss()
#train
snli_model = train(snli_model,strain_dataloader,sdev_dataloader,stest_dataloader,optimizer,criterion,datasetName,EPOCHS)

#test
test(snli_model,stest_dataloader,criterion)



Epoch:  1
Training


100%|████████████████████████████████████| 17168/17168 [01:21<00:00, 211.88it/s]


Validation


100%|████████████████████████████████████████| 308/308 [00:00<00:00, 343.23it/s]


Test


100%|████████████████████████████████████████| 307/307 [00:01<00:00, 206.90it/s]


Epoch:  1  Train Loss:  1.1717012585813436 Train Accuracy:  0.33333090629761164
Epoch:  1  Val Loss:  1.161587175030213  Val Accuracy:  0.3306238569396464
Epoch:  1 Test Loss:  1.1660207366322073  Test Accuracy:  0.3338762214983713

Validation Classification report
               precision    recall  f1-score   support

   entailment       0.28      0.01      0.01      3329
      neutral       0.33      0.28      0.30      3235
contradiction       0.33      0.71      0.45      3278

     accuracy                           0.33      9842
    macro avg       0.31      0.33      0.26      9842
 weighted avg       0.31      0.33      0.25      9842

Test Classification report
               precision    recall  f1-score   support

   entailment       0.35      0.01      0.02      3368
      neutral       0.34      0.29      0.31      3219
contradiction       0.33      0.72      0.45      3237

     accuracy                           0.33      9824
    macro avg       0.34      0.34      0.

 75%|██████████████████████████▉         | 12824/17168 [01:18<00:26, 161.21it/s]

In [None]:
INPUT_DIM = len(mword2index)
mnli_model=NLI_LSTM(membedding_matrix,INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM, N_FC_LAYERS, DROPOUT)
mnli_model.to(device)
mnli_model.load_state_dict(torch.load("../models/model_lstm_nli_mnli_ep_1.pt"))


#adam optimizer
optimizer = optim.Adam(mnli_model.parameters())
#loss function
criterion = nn.CrossEntropyLoss()
# #train
# model = train(model,train_dataloader,dev_dataloader,test_dataloader,optimizer,criterion,datasetName,EPOCHS)

#test
test(mnli_model,mtest_dataloader,criterion)

In [None]:
INPUT_DIM = len(sword2index)
snli_model=NLI_LSTM(sembedding_matrix,INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM, N_FC_LAYERS, DROPOUT)
snli_model.to(device)
snli_model.load_state_dict(torch.load("../models/model_lstm_nli_snli_ep_10.pt"))


#adam optimizer
optimizer = optim.Adam(snli_model.parameters())
#loss function
criterion = nn.CrossEntropyLoss()
# #train
# model = train(model,train_dataloader,dev_dataloader,test_dataloader,optimizer,criterion,datasetName,EPOCHS)

#test
test(snli_model,stest_dataloader,criterion)