# Import needed libraries

In [1]:
!pip install sacremoses sentencepiece

Collecting sacremoses
  Downloading sacremoses-0.1.1-py3-none-any.whl.metadata (8.3 kB)
Downloading sacremoses-0.1.1-py3-none-any.whl (897 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m897.5/897.5 kB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: sacremoses
Successfully installed sacremoses-0.1.1


In [2]:
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import StepLR
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer

from tqdm import tqdm

from torch.utils.tensorboard import SummaryWriter

import os

2024-07-15 02:02:26.071569: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-15 02:02:26.071699: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-15 02:02:26.212374: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


# Device agnostic code

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Set default device
torch.set_default_device(device)
print(f"Default device set to: {device}")

Default device set to: cuda


In [4]:
train_df = pd.read_csv("/kaggle/input/contradictory-my-dear-watson/train.csv")
train_df.head(5)

Unnamed: 0,id,premise,hypothesis,lang_abv,language,label
0,5130fd2cb5,and these comments were considered in formulat...,The rules developed in the interim were put to...,en,English,0
1,5b72532a0b,These are issues that we wrestle with in pract...,Practice groups are not permitted to work on t...,en,English,2
2,3931fbe82a,Des petites choses comme celles-là font une di...,J'essayais d'accomplir quelque chose.,fr,French,0
3,5622f0c60b,you know they can't really defend themselves l...,They can't defend themselves because of their ...,en,English,0
4,86aaa48b45,ในการเล่นบทบาทสมมุติก็เช่นกัน โอกาสที่จะได้แสด...,เด็กสามารถเห็นได้ว่าชาติพันธุ์แตกต่างกันอย่างไร,th,Thai,1


In [5]:
train_df.drop(columns=["id", "lang_abv", "language"], inplace=True)

In [6]:
train_df.sample(5)

Unnamed: 0,premise,hypothesis,label
6093,Update on the Democratic fund-raising scandal ...,Clinton said the agents had not told him anyth...,0
65,لم تصل روح الليبرالية السائدة في أوروبا إلى إس...,اسبانيا لم تكن ابدا ليبرالية.,2
5230,Your speeches are inflammatory.,Your speeches make people feel a lot of rage.,0
1580,Las metáforas animales originales son práctica...,Las metáforas de animales prácticamente han de...,0
11647,Charles Geveden has introduced legislation tha...,Charles Geveden initiated a law that will esse...,2


In [7]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/996k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

In [8]:
class CustomDataset(Dataset):
    def __init__(self, df: pd.DataFrame, tokenizer: object, max_length: int):
        self.dataset = df
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        premise = self.dataset["premise"].iloc[idx]
        hypothesis = self.dataset["hypothesis"].iloc[idx]
        
        token_dict = self.tokenizer.encode_plus(premise, hypothesis, padding=True, truncation=True, max_length=self.max_length, return_tensors="pt")
    
        return {
            "input_ids": token_dict["input_ids"].flatten(),
            "token_type_ids": token_dict["token_type_ids"].flatten(),
            "attention_mask": token_dict["attention_mask"].flatten(),
            "label": torch.tensor(self.dataset["label"].iloc[idx], dtype=torch.long)
        }
    
    @staticmethod
    def collate_fn(batch):
        batch_input_ids = [batch_item["input_ids"] for batch_item in batch]
        batch_type_ids = [batch_item["token_type_ids"] for batch_item in batch]
        batch_attention_masks = [batch_item["attention_mask"] for batch_item in batch]
        
        batch_input_ids = pad_sequence(batch_input_ids, batch_first=True, padding_value=0)
        batch_type_ids  = pad_sequence(batch_type_ids, batch_first=True, padding_value=0) # will get ignored by the attention mask when going through the model; very important
        batch_attention_masks = pad_sequence(batch_attention_masks, batch_first=True, padding_value=0)
        
        return {
            "input_ids": batch_input_ids,
            "token_type_ids": batch_type_ids,
            "attention_mask": batch_attention_masks,
            "labels": torch.stack([batch_item["label"] for batch_item in batch], dim=0)
        }
        

In [9]:
class CustomDatasetSubmission(Dataset):
    def __init__(self, df: pd.DataFrame, tokenizer: object, max_length: int):
        self.dataset = df
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        premise = self.dataset["premise"].iloc[idx]
        hypothesis = self.dataset["hypothesis"].iloc[idx]
        
        token_dict = self.tokenizer.encode_plus(premise, hypothesis, padding=True, truncation=True, max_length=self.max_length, return_tensors="pt")
    
        return {
            "input_ids": token_dict["input_ids"].flatten(),
            "token_type_ids": token_dict["token_type_ids"].flatten(),
            "attention_mask": token_dict["attention_mask"].flatten(),
            "id": self.dataset["id"].iloc[idx]
        }
    
    @staticmethod
    def collate_fn(batch):
        batch_input_ids = [batch_item["input_ids"] for batch_item in batch]
        batch_type_ids = [batch_item["token_type_ids"] for batch_item in batch]
        batch_attention_masks = [batch_item["attention_mask"] for batch_item in batch]
        
        batch_input_ids = pad_sequence(batch_input_ids, batch_first=True, padding_value=0)
        batch_type_ids  = pad_sequence(batch_type_ids, batch_first=True, padding_value=0) # will get ignored by the attention mask when going through the model; very important
        batch_attention_masks = pad_sequence(batch_attention_masks, batch_first=True, padding_value=0)
        
        return {
            "input_ids": batch_input_ids,
            "token_type_ids": batch_type_ids,
            "attention_mask": batch_attention_masks,
            "ids": [batch_item["id"] for batch_item in batch]
        }
 

In [10]:
max_length = max(max([len(premise) for premise in train_df["premise"]]), max([len(hypothesis) for hypothesis in train_df["hypothesis"]]))
train_split, val_split = train_test_split(train_df, test_size=0.25, shuffle=True)
train_dataset = CustomDataset(df=train_split, tokenizer=tokenizer, max_length=max_length)
val_dataset = CustomDataset(df=val_split, tokenizer=tokenizer, max_length=max_length)

In [11]:
train_dataloader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True, generator=torch.Generator(device=device), collate_fn=train_dataset.collate_fn)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=4, shuffle=True, generator=torch.Generator(device=device), collate_fn=val_dataset.collate_fn)

In [12]:
next(iter(val_dataloader))

{'input_ids': tensor([[  101,   150, 32298, 11717, 12134, 12739, 10637,   117, 14266,   182,
          27005, 12134, 15594, 13055, 61601, 10116, 10271, 90804, 10113, 14266,
          10134, 20897, 95340, 62310, 35432,   119,   102,   150, 32298, 11717,
          12134, 17502, 91610, 15694, 12739, 31252, 14266, 15900, 77203, 13369,
          10371, 14266, 10374, 21840, 10113,   119,   102,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0],
         [  101, 46242, 24472,   117, 10355, 11644, 10141, 44891, 10112, 10104,
          10109, 30066,   117,   169, 10211, 73099, 13844, 10333, 10109, 73337,
            117, 10355,   187,   112, 45754,   100,   100, 31013, 31604, 42932,
          11203,   112,   112,   132,   182,   112, 10176,   118, 10794, 10801,
            117, 11641,   136,   102, 13298, 73337, 10211, 73099, 13844, 11245,
          41764,   180,   112, 10231, 10139, 31164, 10152, 10563, 91

In [13]:
bert_model = torch.hub.load('huggingface/pytorch-transformers', 'model', 'google-bert/bert-base-multilingual-cased')

Downloading: "https://github.com/huggingface/pytorch-transformers/zipball/main" to /root/.cache/torch/hub/main.zip


config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/714M [00:00<?, ?B/s]

In [14]:
class Classifier(nn.Module):
    def __init__(self, n_inputs, hidden_size, n_classes, dropout):
        super(Classifier, self).__init__()
        self.linear1 = nn.Linear(in_features=n_inputs, out_features=hidden_size)
        self.linear2 = nn.Linear(in_features=hidden_size, out_features=n_classes)
        self.dropout = nn.Dropout(dropout)

        self.act_fn = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.act_fn(self.linear1(x))
        x = self.dropout(x)
        x = self.linear2(x)
        return x


In [15]:
class BertClassifier(nn.Module):
    def __init__(self, bert_model, classifier):
        super(BertClassifier, self).__init__()
        self.bert_model = bert_model
        self.classifier = classifier

    def forward(self, x: torch.Tensor, token_type_ids: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
        x = self.bert_model(x, attention_mask=attention_mask, token_type_ids=token_type_ids)
        x = x.last_hidden_state[:, 0, :]
        x = self.classifier(x)
        return x

In [16]:
classifier = Classifier(n_inputs=768, n_classes=3, hidden_size=512, dropout=0.35)
model = BertClassifier(bert_model=bert_model, classifier=classifier)
bert_model.config.hidden_dropout_prob = 0.35
bert_model.config

BertConfig {
  "_name_or_path": "google-bert/bert-base-multilingual-cased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.35,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "position_embedding_type": "absolute",
  "transformers_version": "4.42.3",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 119547
}

In [17]:
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-5)
loss_fn = nn.CrossEntropyLoss()

In [18]:
def train_model(model, dataloader, val_dataloader, loss_fn, optimizer, scheduler, epochs):
    writer = SummaryWriter()
    writer.add_graph(model, (torch.zeros(size=(32, 32), dtype=torch.long), torch.zeros(size=(32, 32), dtype=torch.long)))
    for epoch in range(epochs):
        model.eval()
        total_inputs = 0
        total_correct = 0

        with torch.inference_mode():
            for batch in tqdm(val_dataloader):
                logits = model(batch["input_ids"], attention_mask=batch["attention_mask"], token_type_ids=batch["token_type_ids"])
                percents = torch.softmax(logits, dim=1)
                preds = torch.argmax(percents, dim=1)

                total_correct += (preds==batch["labels"]).sum().item()
                total_inputs += batch["labels"].view(-1).shape[0]

            print(f"{total_correct} out of {total_inputs}")
            print(f"acc of {total_correct/total_inputs*100}%")

            writer.add_scalar("Val acc", total_correct/total_inputs*100, epoch)

        model.train()
        for batch_idx, batch in tqdm(enumerate(dataloader)):
            logits = model(batch["input_ids"], attention_mask=batch["attention_mask"], token_type_ids=batch["token_type_ids"])
            loss = loss_fn(logits, batch["labels"])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (batch_idx+1) % 40 == 0:
                print(f"loss for batch {batch_idx+1} --> {loss} at epoch {epoch}")

            writer.add_scalar("Loss", loss, batch_idx)

        scheduler.step()

    writer.close()

In [19]:
scheduler = StepLR(gamma=0.5, step_size=5, optimizer=optimizer)
train_model(model=model, dataloader=train_dataloader, val_dataloader=val_dataloader, loss_fn=loss_fn, optimizer=optimizer, scheduler=scheduler, epochs=10)

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
100%|██████████| 758/758 [00:13<00:00, 54.26it/s]


1031 out of 3030
acc of 34.026402640264024%


41it [00:05,  8.62it/s]

loss for batch 40 --> 1.0852018594741821 at epoch 0


81it [00:09,  7.91it/s]

loss for batch 80 --> 1.140058994293213 at epoch 0


121it [00:14,  9.20it/s]

loss for batch 120 --> 1.0064085721969604 at epoch 0


161it [00:19,  8.19it/s]

loss for batch 160 --> 0.9553490281105042 at epoch 0


199it [00:23,  7.64it/s]

loss for batch 200 --> 0.8456542491912842 at epoch 0


241it [00:28,  9.15it/s]

loss for batch 240 --> 0.7064106464385986 at epoch 0


281it [00:33,  9.15it/s]

loss for batch 280 --> 1.1862196922302246 at epoch 0


321it [00:38,  8.95it/s]

loss for batch 320 --> 1.1413415670394897 at epoch 0


361it [00:42,  8.86it/s]

loss for batch 360 --> 1.041303038597107 at epoch 0


401it [00:47,  8.06it/s]

loss for batch 400 --> 1.0211780071258545 at epoch 0


440it [00:52,  8.45it/s]

loss for batch 440 --> 0.6678245663642883 at epoch 0


481it [00:57,  8.23it/s]

loss for batch 480 --> 0.625379204750061 at epoch 0


521it [01:01,  8.80it/s]

loss for batch 520 --> 0.75577312707901 at epoch 0


560it [01:06,  8.07it/s]

loss for batch 560 --> 0.8749281167984009 at epoch 0


600it [01:11,  8.29it/s]

loss for batch 600 --> 0.9202721118927002 at epoch 0


641it [01:15,  8.57it/s]

loss for batch 640 --> 1.088850975036621 at epoch 0


681it [01:20,  8.80it/s]

loss for batch 680 --> 0.9698458909988403 at epoch 0


721it [01:25,  8.79it/s]

loss for batch 720 --> 1.237196683883667 at epoch 0


761it [01:29,  8.41it/s]

loss for batch 760 --> 0.5923158526420593 at epoch 0


801it [01:34,  7.48it/s]

loss for batch 800 --> 0.5473297834396362 at epoch 0


841it [01:39,  9.12it/s]

loss for batch 840 --> 0.617317259311676 at epoch 0


881it [01:44,  8.49it/s]

loss for batch 880 --> 0.7769757509231567 at epoch 0


921it [01:48,  8.74it/s]

loss for batch 920 --> 0.8965568542480469 at epoch 0


961it [01:53,  7.23it/s]

loss for batch 960 --> 0.8463897705078125 at epoch 0


1001it [01:58,  9.03it/s]

loss for batch 1000 --> 0.5976963043212891 at epoch 0


1041it [02:03,  8.61it/s]

loss for batch 1040 --> 0.7260236144065857 at epoch 0


1081it [02:08,  8.94it/s]

loss for batch 1080 --> 0.8571076393127441 at epoch 0


1121it [02:13,  8.60it/s]

loss for batch 1120 --> 0.7908651828765869 at epoch 0


1137it [02:15,  8.42it/s]
100%|██████████| 758/758 [00:13<00:00, 55.83it/s]


1942 out of 3030
acc of 64.0924092409241%


41it [00:04,  8.26it/s]

loss for batch 40 --> 0.6981120705604553 at epoch 1


80it [00:09,  9.16it/s]

loss for batch 80 --> 1.0313987731933594 at epoch 1


121it [00:14,  7.93it/s]

loss for batch 120 --> 0.7144114971160889 at epoch 1


161it [00:18,  8.68it/s]

loss for batch 160 --> 0.5338137745857239 at epoch 1


201it [00:23,  8.29it/s]

loss for batch 200 --> 0.6648814082145691 at epoch 1


242it [00:28,  8.68it/s]

loss for batch 240 --> 1.4000393152236938 at epoch 1


280it [00:33,  7.76it/s]

loss for batch 280 --> 0.5132821798324585 at epoch 1


321it [00:37,  9.70it/s]

loss for batch 320 --> 0.7651089429855347 at epoch 1


361it [00:42,  8.54it/s]

loss for batch 360 --> 0.7741470336914062 at epoch 1


401it [00:47,  8.60it/s]

loss for batch 400 --> 0.5477215647697449 at epoch 1


440it [00:52,  7.60it/s]

loss for batch 440 --> 0.7332596778869629 at epoch 1


481it [00:56,  8.80it/s]

loss for batch 480 --> 0.7947837114334106 at epoch 1


521it [01:01,  7.87it/s]

loss for batch 520 --> 0.7220468521118164 at epoch 1


561it [01:06,  8.59it/s]

loss for batch 560 --> 0.6683995723724365 at epoch 1


600it [01:10,  7.53it/s]

loss for batch 600 --> 0.8723700046539307 at epoch 1


641it [01:15,  8.65it/s]

loss for batch 640 --> 0.464885950088501 at epoch 1


680it [01:20,  9.01it/s]

loss for batch 680 --> 0.6945022940635681 at epoch 1


721it [01:25,  8.43it/s]

loss for batch 720 --> 0.47495388984680176 at epoch 1


760it [01:29,  8.80it/s]

loss for batch 760 --> 1.1895228624343872 at epoch 1


801it [01:34,  8.75it/s]

loss for batch 800 --> 0.7292827367782593 at epoch 1


841it [01:39,  8.78it/s]

loss for batch 840 --> 0.7474239468574524 at epoch 1


881it [01:43,  9.28it/s]

loss for batch 880 --> 0.4925704300403595 at epoch 1


921it [01:48,  9.07it/s]

loss for batch 920 --> 0.2892225384712219 at epoch 1


961it [01:53,  8.42it/s]

loss for batch 960 --> 1.2210941314697266 at epoch 1


1001it [01:58,  8.44it/s]

loss for batch 1000 --> 1.0280967950820923 at epoch 1


1041it [02:03,  8.38it/s]

loss for batch 1040 --> 0.7802601456642151 at epoch 1


1080it [02:07,  8.97it/s]

loss for batch 1080 --> 0.9927705526351929 at epoch 1


1121it [02:12,  9.01it/s]

loss for batch 1120 --> 1.271245002746582 at epoch 1


1137it [02:14,  8.48it/s]
100%|██████████| 758/758 [00:13<00:00, 56.10it/s]


1970 out of 3030
acc of 65.01650165016501%


41it [00:04,  8.84it/s]

loss for batch 40 --> 0.5516591668128967 at epoch 2


81it [00:09,  8.81it/s]

loss for batch 80 --> 0.31374168395996094 at epoch 2


121it [00:14,  8.29it/s]

loss for batch 120 --> 0.21485939621925354 at epoch 2


161it [00:18,  7.68it/s]

loss for batch 160 --> 0.7510437369346619 at epoch 2


201it [00:23,  7.56it/s]

loss for batch 200 --> 0.43859824538230896 at epoch 2


241it [00:28,  8.37it/s]

loss for batch 240 --> 0.6505807638168335 at epoch 2


281it [00:33,  7.67it/s]

loss for batch 280 --> 0.21256938576698303 at epoch 2


321it [00:37,  9.37it/s]

loss for batch 320 --> 0.7872347831726074 at epoch 2


361it [00:42,  7.78it/s]

loss for batch 360 --> 0.5156131982803345 at epoch 2


401it [00:47,  8.79it/s]

loss for batch 400 --> 0.3985230028629303 at epoch 2


441it [00:51,  9.48it/s]

loss for batch 440 --> 0.12024863809347153 at epoch 2


481it [00:56,  9.08it/s]

loss for batch 480 --> 0.977763295173645 at epoch 2


521it [01:01,  7.15it/s]

loss for batch 520 --> 0.3218908905982971 at epoch 2


560it [01:05,  7.96it/s]

loss for batch 560 --> 0.5377265214920044 at epoch 2


601it [01:10,  8.51it/s]

loss for batch 600 --> 0.9133287668228149 at epoch 2


640it [01:15,  7.74it/s]

loss for batch 640 --> 0.27033814787864685 at epoch 2


681it [01:20,  7.67it/s]

loss for batch 680 --> 0.31538528203964233 at epoch 2


721it [01:25,  8.50it/s]

loss for batch 720 --> 0.3884432315826416 at epoch 2


761it [01:29,  8.90it/s]

loss for batch 760 --> 0.18240118026733398 at epoch 2


801it [01:34,  8.67it/s]

loss for batch 800 --> 0.46700721979141235 at epoch 2


841it [01:39,  8.26it/s]

loss for batch 840 --> 0.6872376203536987 at epoch 2


881it [01:44,  6.59it/s]

loss for batch 880 --> 0.45076626539230347 at epoch 2


920it [01:48,  6.83it/s]

loss for batch 920 --> 0.5421676635742188 at epoch 2


961it [01:53,  8.79it/s]

loss for batch 960 --> 0.5914221405982971 at epoch 2


1001it [01:58,  7.59it/s]

loss for batch 1000 --> 0.3687291443347931 at epoch 2


1041it [02:03,  9.21it/s]

loss for batch 1040 --> 0.5521294474601746 at epoch 2


1081it [02:07,  8.60it/s]

loss for batch 1080 --> 0.4212508499622345 at epoch 2


1121it [02:12,  8.69it/s]

loss for batch 1120 --> 0.2616634964942932 at epoch 2


1137it [02:14,  8.47it/s]
100%|██████████| 758/758 [00:13<00:00, 55.49it/s]


1973 out of 3030
acc of 65.11551155115511%


41it [00:04,  8.49it/s]

loss for batch 40 --> 0.3779199719429016 at epoch 3


81it [00:09,  8.62it/s]

loss for batch 80 --> 0.6580794453620911 at epoch 3


121it [00:14,  8.71it/s]

loss for batch 120 --> 0.7800438404083252 at epoch 3


161it [00:18,  7.82it/s]

loss for batch 160 --> 0.367540568113327 at epoch 3


201it [00:23,  8.90it/s]

loss for batch 200 --> 0.22342181205749512 at epoch 3


241it [00:28,  6.75it/s]

loss for batch 240 --> 0.39982497692108154 at epoch 3


281it [00:33,  8.47it/s]

loss for batch 280 --> 0.3368050158023834 at epoch 3


321it [00:37,  9.08it/s]

loss for batch 320 --> 0.737054169178009 at epoch 3


360it [00:42,  7.75it/s]

loss for batch 360 --> 0.34776270389556885 at epoch 3


401it [00:47,  8.86it/s]

loss for batch 400 --> 0.22926513850688934 at epoch 3


441it [00:52,  7.13it/s]

loss for batch 440 --> 0.09528668969869614 at epoch 3


481it [00:57,  7.68it/s]

loss for batch 480 --> 0.17850491404533386 at epoch 3


521it [01:01,  9.85it/s]

loss for batch 520 --> 0.08626687526702881 at epoch 3


561it [01:06,  8.32it/s]

loss for batch 560 --> 0.34830737113952637 at epoch 3


601it [01:11,  8.43it/s]

loss for batch 600 --> 0.5007573962211609 at epoch 3


641it [01:15,  8.16it/s]

loss for batch 640 --> 0.08983255922794342 at epoch 3


681it [01:20,  8.17it/s]

loss for batch 680 --> 0.575508713722229 at epoch 3


721it [01:25,  8.66it/s]

loss for batch 720 --> 0.1250794231891632 at epoch 3


761it [01:30,  8.43it/s]

loss for batch 760 --> 0.3936617970466614 at epoch 3


801it [01:34,  8.60it/s]

loss for batch 800 --> 0.0746479406952858 at epoch 3


841it [01:39,  9.07it/s]

loss for batch 840 --> 0.2205515056848526 at epoch 3


881it [01:44,  8.59it/s]

loss for batch 880 --> 0.16443151235580444 at epoch 3


921it [01:49,  7.65it/s]

loss for batch 920 --> 0.20824767649173737 at epoch 3


961it [01:53,  8.31it/s]

loss for batch 960 --> 0.1855769157409668 at epoch 3


1000it [01:58,  8.05it/s]

loss for batch 1000 --> 0.7776181101799011 at epoch 3


1041it [02:03,  8.30it/s]

loss for batch 1040 --> 0.4239230751991272 at epoch 3


1081it [02:08,  8.50it/s]

loss for batch 1080 --> 0.31893035769462585 at epoch 3


1121it [02:12,  8.88it/s]

loss for batch 1120 --> 0.2509044408798218 at epoch 3


1137it [02:14,  8.44it/s]
100%|██████████| 758/758 [00:13<00:00, 54.63it/s]


1954 out of 3030
acc of 64.48844884488449%


41it [00:04,  8.37it/s]

loss for batch 40 --> 0.511584997177124 at epoch 4


81it [00:09,  8.04it/s]

loss for batch 80 --> 0.01872262731194496 at epoch 4


120it [00:14,  8.39it/s]

loss for batch 120 --> 0.036947015672922134 at epoch 4


161it [00:19,  8.71it/s]

loss for batch 160 --> 0.12861450016498566 at epoch 4


201it [00:23,  9.58it/s]

loss for batch 200 --> 0.027154816314578056 at epoch 4


241it [00:28,  8.83it/s]

loss for batch 240 --> 0.07438715547323227 at epoch 4


281it [00:32,  8.64it/s]

loss for batch 280 --> 0.05783693864941597 at epoch 4


321it [00:37,  8.47it/s]

loss for batch 320 --> 0.15422914922237396 at epoch 4


361it [00:42,  9.21it/s]

loss for batch 360 --> 0.019706938415765762 at epoch 4


401it [00:47,  8.94it/s]

loss for batch 400 --> 0.5681535601615906 at epoch 4


441it [00:52,  8.68it/s]

loss for batch 440 --> 0.26317650079727173 at epoch 4


481it [00:57,  8.22it/s]

loss for batch 480 --> 0.07183010131120682 at epoch 4


521it [01:01,  8.66it/s]

loss for batch 520 --> 0.015473832376301289 at epoch 4


561it [01:06,  8.98it/s]

loss for batch 560 --> 0.16929905116558075 at epoch 4


601it [01:11,  8.49it/s]

loss for batch 600 --> 0.38713568449020386 at epoch 4


641it [01:15,  8.87it/s]

loss for batch 640 --> 0.03442380577325821 at epoch 4


681it [01:20,  8.68it/s]

loss for batch 680 --> 0.01507704146206379 at epoch 4


721it [01:25,  8.81it/s]

loss for batch 720 --> 0.18791165947914124 at epoch 4


761it [01:30,  8.29it/s]

loss for batch 760 --> 0.03391134738922119 at epoch 4


801it [01:34,  8.20it/s]

loss for batch 800 --> 0.03256285935640335 at epoch 4


841it [01:39,  8.40it/s]

loss for batch 840 --> 0.09766845405101776 at epoch 4


881it [01:44,  8.59it/s]

loss for batch 880 --> 0.3219436705112457 at epoch 4


921it [01:48,  8.48it/s]

loss for batch 920 --> 0.10358920693397522 at epoch 4


961it [01:53,  7.89it/s]

loss for batch 960 --> 0.07800842821598053 at epoch 4


1001it [01:58,  9.06it/s]

loss for batch 1000 --> 0.52093106508255 at epoch 4


1041it [02:03,  8.67it/s]

loss for batch 1040 --> 0.5859455466270447 at epoch 4


1081it [02:08,  8.33it/s]

loss for batch 1080 --> 0.03598906844854355 at epoch 4


1121it [02:12,  8.58it/s]

loss for batch 1120 --> 0.06703060865402222 at epoch 4


1137it [02:14,  8.46it/s]
100%|██████████| 758/758 [00:13<00:00, 56.47it/s]


1919 out of 3030
acc of 63.33333333333333%


41it [00:04,  8.30it/s]

loss for batch 40 --> 0.4740826487541199 at epoch 5


81it [00:09,  7.94it/s]

loss for batch 80 --> 0.9681631326675415 at epoch 5


121it [00:14,  8.08it/s]

loss for batch 120 --> 0.20883436501026154 at epoch 5


161it [00:19,  8.38it/s]

loss for batch 160 --> 0.03178924322128296 at epoch 5


201it [00:23,  8.42it/s]

loss for batch 200 --> 0.04482564702630043 at epoch 5


241it [00:28,  8.69it/s]

loss for batch 240 --> 0.01462253276258707 at epoch 5


282it [00:33,  9.41it/s]

loss for batch 280 --> 0.013735746964812279 at epoch 5


321it [00:38,  8.84it/s]

loss for batch 320 --> 0.10529294610023499 at epoch 5


359it [00:42,  7.95it/s]

loss for batch 360 --> 0.03329514339566231 at epoch 5


401it [00:47,  8.18it/s]

loss for batch 400 --> 0.029303189367055893 at epoch 5


441it [00:52,  8.24it/s]

loss for batch 440 --> 0.03059834986925125 at epoch 5


481it [00:57,  8.79it/s]

loss for batch 480 --> 0.02037048153579235 at epoch 5


521it [01:01,  8.75it/s]

loss for batch 520 --> 0.0092175817117095 at epoch 5


561it [01:06,  8.44it/s]

loss for batch 560 --> 0.050218019634485245 at epoch 5


601it [01:11,  8.31it/s]

loss for batch 600 --> 0.22585147619247437 at epoch 5


641it [01:16,  7.85it/s]

loss for batch 640 --> 0.1286240965127945 at epoch 5


681it [01:21,  8.59it/s]

loss for batch 680 --> 0.28516191244125366 at epoch 5


721it [01:25,  8.19it/s]

loss for batch 720 --> 0.023365460336208344 at epoch 5


761it [01:30,  8.16it/s]

loss for batch 760 --> 0.11000115424394608 at epoch 5


802it [01:35,  9.56it/s]

loss for batch 800 --> 0.2407141923904419 at epoch 5


841it [01:39,  8.51it/s]

loss for batch 840 --> 0.028486866503953934 at epoch 5


881it [01:44,  8.25it/s]

loss for batch 880 --> 0.03260236233472824 at epoch 5


921it [01:49,  8.50it/s]

loss for batch 920 --> 0.011178593151271343 at epoch 5


961it [01:54,  8.61it/s]

loss for batch 960 --> 0.00787043385207653 at epoch 5


1001it [01:58,  9.68it/s]

loss for batch 1000 --> 0.016301214694976807 at epoch 5


1041it [02:03,  8.32it/s]

loss for batch 1040 --> 0.006501579191535711 at epoch 5


1081it [02:07,  8.59it/s]

loss for batch 1080 --> 0.011546158231794834 at epoch 5


1121it [02:12,  8.17it/s]

loss for batch 1120 --> 0.06549997627735138 at epoch 5


1137it [02:14,  8.47it/s]
100%|██████████| 758/758 [00:13<00:00, 56.49it/s]


1957 out of 3030
acc of 64.58745874587459%


41it [00:04,  8.47it/s]

loss for batch 40 --> 0.007413038983941078 at epoch 6


81it [00:09,  8.24it/s]

loss for batch 80 --> 0.009117111563682556 at epoch 6


121it [00:14,  7.94it/s]

loss for batch 120 --> 0.7384960055351257 at epoch 6


160it [00:18,  8.76it/s]

loss for batch 160 --> 0.02422722615301609 at epoch 6


201it [00:23,  9.12it/s]

loss for batch 200 --> 0.021398819983005524 at epoch 6


241it [00:28,  8.25it/s]

loss for batch 240 --> 0.01967243105173111 at epoch 6


281it [00:33,  8.14it/s]

loss for batch 280 --> 0.021422483026981354 at epoch 6


321it [00:38,  8.17it/s]

loss for batch 320 --> 0.06248168647289276 at epoch 6


361it [00:42,  8.51it/s]

loss for batch 360 --> 0.014074433594942093 at epoch 6


401it [00:47,  7.82it/s]

loss for batch 400 --> 0.015358042903244495 at epoch 6


441it [00:52,  8.35it/s]

loss for batch 440 --> 0.011061074212193489 at epoch 6


481it [00:56,  8.83it/s]

loss for batch 480 --> 0.0066685788333415985 at epoch 6


521it [01:01,  8.53it/s]

loss for batch 520 --> 0.014134981669485569 at epoch 6


560it [01:06,  8.75it/s]

loss for batch 560 --> 0.008707632310688496 at epoch 6


601it [01:10,  8.51it/s]

loss for batch 600 --> 0.004834149964153767 at epoch 6


641it [01:15,  8.41it/s]

loss for batch 640 --> 0.007579980418086052 at epoch 6


680it [01:20,  8.73it/s]

loss for batch 680 --> 0.005163509864360094 at epoch 6


721it [01:25,  7.96it/s]

loss for batch 720 --> 0.010053543373942375 at epoch 6


761it [01:30,  8.93it/s]

loss for batch 760 --> 0.012716444209218025 at epoch 6


801it [01:35,  8.03it/s]

loss for batch 800 --> 0.006987965200096369 at epoch 6


841it [01:39,  8.13it/s]

loss for batch 840 --> 0.09317444264888763 at epoch 6


881it [01:44,  8.78it/s]

loss for batch 880 --> 0.005880937445908785 at epoch 6


921it [01:49,  9.36it/s]

loss for batch 920 --> 0.009432002902030945 at epoch 6


960it [01:53,  8.12it/s]

loss for batch 960 --> 0.008693672716617584 at epoch 6


1001it [01:58,  8.82it/s]

loss for batch 1000 --> 0.011341508477926254 at epoch 6


1041it [02:02,  9.07it/s]

loss for batch 1040 --> 0.00962644163519144 at epoch 6


1081it [02:07,  7.73it/s]

loss for batch 1080 --> 0.1203741803765297 at epoch 6


1121it [02:12,  8.10it/s]

loss for batch 1120 --> 0.011458882130682468 at epoch 6


1137it [02:14,  8.46it/s]
100%|██████████| 758/758 [00:13<00:00, 56.75it/s]


1972 out of 3030
acc of 65.08250825082507%


40it [00:04,  9.09it/s]

loss for batch 40 --> 0.10823877155780792 at epoch 7


81it [00:09,  8.75it/s]

loss for batch 80 --> 0.01536963414400816 at epoch 7


121it [00:14,  8.81it/s]

loss for batch 120 --> 0.2331644594669342 at epoch 7


161it [00:18,  8.74it/s]

loss for batch 160 --> 0.012766997329890728 at epoch 7


201it [00:23,  8.01it/s]

loss for batch 200 --> 0.05322849750518799 at epoch 7


241it [00:28,  7.69it/s]

loss for batch 240 --> 0.004516297951340675 at epoch 7


281it [00:33,  8.29it/s]

loss for batch 280 --> 0.010508784092962742 at epoch 7


321it [00:37,  9.02it/s]

loss for batch 320 --> 0.1428084373474121 at epoch 7


361it [00:42,  8.28it/s]

loss for batch 360 --> 0.008866578340530396 at epoch 7


401it [00:47,  8.05it/s]

loss for batch 400 --> 0.006578325759619474 at epoch 7


441it [00:52,  8.46it/s]

loss for batch 440 --> 0.1722688227891922 at epoch 7


480it [00:56,  8.40it/s]

loss for batch 480 --> 0.2624821960926056 at epoch 7


521it [01:01,  7.51it/s]

loss for batch 520 --> 0.5253813862800598 at epoch 7


560it [01:06,  8.62it/s]

loss for batch 560 --> 0.004902347456663847 at epoch 7


601it [01:10,  9.12it/s]

loss for batch 600 --> 0.019324516877532005 at epoch 7


641it [01:15,  8.43it/s]

loss for batch 640 --> 0.02126394771039486 at epoch 7


681it [01:20,  8.44it/s]

loss for batch 680 --> 0.20038311183452606 at epoch 7


721it [01:24,  8.48it/s]

loss for batch 720 --> 0.007258885074406862 at epoch 7


761it [01:29,  8.55it/s]

loss for batch 760 --> 0.009151259437203407 at epoch 7


801it [01:34,  8.75it/s]

loss for batch 800 --> 0.010796073824167252 at epoch 7


840it [01:38,  8.72it/s]

loss for batch 840 --> 0.005767601542174816 at epoch 7


881it [01:43,  9.07it/s]

loss for batch 880 --> 0.003951546270400286 at epoch 7


921it [01:48,  8.62it/s]

loss for batch 920 --> 0.012194391340017319 at epoch 7


961it [01:52,  8.89it/s]

loss for batch 960 --> 0.004405812360346317 at epoch 7


1001it [01:57,  7.53it/s]

loss for batch 1000 --> 0.014041058719158173 at epoch 7


1041it [02:02,  8.09it/s]

loss for batch 1040 --> 0.008680371567606926 at epoch 7


1081it [02:07,  8.73it/s]

loss for batch 1080 --> 0.6118051409721375 at epoch 7


1122it [02:12,  9.67it/s]

loss for batch 1120 --> 0.012617075815796852 at epoch 7


1137it [02:13,  8.50it/s]
100%|██████████| 758/758 [00:13<00:00, 56.89it/s]


1926 out of 3030
acc of 63.56435643564357%


40it [00:04,  8.13it/s]

loss for batch 40 --> 0.003463811008259654 at epoch 8


81it [00:09,  8.24it/s]

loss for batch 80 --> 0.00506939273327589 at epoch 8


121it [00:14,  8.36it/s]

loss for batch 120 --> 0.007220136933028698 at epoch 8


161it [00:19,  9.36it/s]

loss for batch 160 --> 0.009263257496058941 at epoch 8


201it [00:23,  8.78it/s]

loss for batch 200 --> 0.011821419931948185 at epoch 8


241it [00:28,  8.57it/s]

loss for batch 240 --> 0.12941749393939972 at epoch 8


281it [00:33,  8.55it/s]

loss for batch 280 --> 0.01718023233115673 at epoch 8


321it [00:38,  8.95it/s]

loss for batch 320 --> 0.0073124417103827 at epoch 8


361it [00:42,  8.44it/s]

loss for batch 360 --> 0.005798301659524441 at epoch 8


401it [00:47,  7.39it/s]

loss for batch 400 --> 0.1686467081308365 at epoch 8


441it [00:52,  8.56it/s]

loss for batch 440 --> 0.4246981739997864 at epoch 8


481it [00:57,  7.63it/s]

loss for batch 480 --> 0.05392398312687874 at epoch 8


521it [01:01,  8.18it/s]

loss for batch 520 --> 0.009211927652359009 at epoch 8


561it [01:06,  8.29it/s]

loss for batch 560 --> 0.004233437590301037 at epoch 8


601it [01:11,  8.40it/s]

loss for batch 600 --> 0.005409841891378164 at epoch 8


641it [01:16,  8.85it/s]

loss for batch 640 --> 0.005953352898359299 at epoch 8


681it [01:20,  8.38it/s]

loss for batch 680 --> 0.01167010422796011 at epoch 8


721it [01:25,  8.09it/s]

loss for batch 720 --> 0.021421410143375397 at epoch 8


761it [01:30,  8.89it/s]

loss for batch 760 --> 0.004049610812216997 at epoch 8


801it [01:35,  7.59it/s]

loss for batch 800 --> 0.004446269012987614 at epoch 8


841it [01:39,  9.45it/s]

loss for batch 840 --> 0.01757001504302025 at epoch 8


881it [01:44,  8.25it/s]

loss for batch 880 --> 0.09014299511909485 at epoch 8


921it [01:49,  8.32it/s]

loss for batch 920 --> 0.004610921256244183 at epoch 8


961it [01:54,  8.57it/s]

loss for batch 960 --> 0.21350224316120148 at epoch 8


1001it [01:58,  8.56it/s]

loss for batch 1000 --> 0.0044058263301849365 at epoch 8


1041it [02:03,  8.64it/s]

loss for batch 1040 --> 0.008614752441644669 at epoch 8


1080it [02:07,  8.83it/s]

loss for batch 1080 --> 0.00627761147916317 at epoch 8


1121it [02:12,  8.43it/s]

loss for batch 1120 --> 0.1334436982870102 at epoch 8


1137it [02:14,  8.44it/s]
100%|██████████| 758/758 [00:13<00:00, 55.38it/s]


1969 out of 3030
acc of 64.98349834983497%


41it [00:05,  8.12it/s]

loss for batch 40 --> 0.011373392306268215 at epoch 9


81it [00:09,  7.80it/s]

loss for batch 80 --> 0.003820929443463683 at epoch 9


121it [00:14,  8.46it/s]

loss for batch 120 --> 0.002103176200762391 at epoch 9


161it [00:19,  8.22it/s]

loss for batch 160 --> 0.003243530634790659 at epoch 9


201it [00:23,  8.16it/s]

loss for batch 200 --> 0.0068632191978394985 at epoch 9


241it [00:28,  8.21it/s]

loss for batch 240 --> 0.0039953915402293205 at epoch 9


281it [00:33,  8.55it/s]

loss for batch 280 --> 0.2218170464038849 at epoch 9


321it [00:38,  8.76it/s]

loss for batch 320 --> 0.0034250100143253803 at epoch 9


361it [00:42,  8.63it/s]

loss for batch 360 --> 0.009761691093444824 at epoch 9


401it [00:47,  8.70it/s]

loss for batch 400 --> 0.005017514806240797 at epoch 9


441it [00:52,  8.82it/s]

loss for batch 440 --> 0.005542901810258627 at epoch 9


481it [00:57,  8.99it/s]

loss for batch 480 --> 0.34636327624320984 at epoch 9


521it [01:02,  7.64it/s]

loss for batch 520 --> 0.01352057233452797 at epoch 9


561it [01:06,  9.04it/s]

loss for batch 560 --> 0.0021660926286131144 at epoch 9


601it [01:11,  8.33it/s]

loss for batch 600 --> 0.0032780859619379044 at epoch 9


641it [01:16,  8.08it/s]

loss for batch 640 --> 0.004098927602171898 at epoch 9


681it [01:21,  9.00it/s]

loss for batch 680 --> 0.15855465829372406 at epoch 9


721it [01:25,  8.69it/s]

loss for batch 720 --> 0.0025449302047491074 at epoch 9


760it [01:30,  8.54it/s]

loss for batch 760 --> 0.013850444927811623 at epoch 9


801it [01:35,  8.21it/s]

loss for batch 800 --> 0.0033070978242903948 at epoch 9


840it [01:39,  7.97it/s]

loss for batch 840 --> 0.04127832502126694 at epoch 9


881it [01:44,  9.07it/s]

loss for batch 880 --> 0.002498496090993285 at epoch 9


921it [01:49,  8.31it/s]

loss for batch 920 --> 0.02165365405380726 at epoch 9


961it [01:54,  8.54it/s]

loss for batch 960 --> 0.06650954484939575 at epoch 9


1001it [01:59,  8.92it/s]

loss for batch 1000 --> 0.8443608283996582 at epoch 9


1041it [02:04,  8.63it/s]

loss for batch 1040 --> 0.008535202592611313 at epoch 9


1081it [02:08,  8.93it/s]

loss for batch 1080 --> 0.025257185101509094 at epoch 9


1121it [02:13,  8.29it/s]

loss for batch 1120 --> 0.022679360583424568 at epoch 9


1137it [02:15,  8.41it/s]


In [20]:
submission_csv = pd.read_csv("/kaggle/input/contradictory-my-dear-watson/test.csv")
submission_csv.drop(columns=["lang_abv", "language"])

Unnamed: 0,id,premise,hypothesis
0,c6d58c3f69,بکس، کیسی، راہیل، یسعیاہ، کیلی، کیلی، اور کولم...,"کیسی کے لئے کوئی یادگار نہیں ہوگا, کولمین ہائی..."
1,cefcc82292,هذا هو ما تم نصحنا به.,عندما يتم إخبارهم بما يجب عليهم فعله ، فشلت ال...
2,e98005252c,et cela est en grande partie dû au fait que le...,Les mères se droguent.
3,58518c10ba,与城市及其他公民及社区组织代表就IMA的艺术发展进行对话&amp,IMA与其他组织合作，因为它们都依靠共享资金。
4,c32b0d16df,Она все еще была там.,"Мы думали, что она ушла, однако, она осталась."
...,...,...,...
5190,5f90dd59b0,نیند نے وعدہ کیا کہ موٹل نے سوال میں تحقیق کی.,نیمیتھ کو موٹل کی تفتیش کے لئے معاوضہ دیا جارہ...
5191,f357a04e86,The rock has a soft texture and can be bough...,The rock is harder than most types of rock.
5192,1f0ea92118,她目前的存在，并考虑到他与沃佛斯顿争执的本质，那是尴尬的。,她在与Wolverstone的打斗结束后才在场的事实被看作是很尴尬的。
5193,0407b48afb,isn't it i can remember i've only been here ei...,I could see downtown Dallas from where I lived...


In [21]:
submission_dataset = CustomDatasetSubmission(df=submission_csv.drop(columns=["lang_abv", "language"]), max_length=max_length, tokenizer=tokenizer)

In [22]:
submission_dataloder = DataLoader(dataset=submission_dataset, batch_size=32, shuffle=False, generator=torch.Generator(device=device), collate_fn=submission_dataset.collate_fn)

In [23]:
final_output = {
    "id": [],
    "prediction": []
}

model.eval()
with torch.inference_mode():
    for batch in tqdm(submission_dataloder):
        logits = model(batch["input_ids"], attention_mask=batch["attention_mask"], token_type_ids=batch["token_type_ids"])
        percents = torch.softmax(logits, dim=1)
        preds = torch.argmax(percents, dim=1)
        for idx, prediction in zip(batch["ids"], preds.view(-1).tolist()):
            final_output["id"].append(idx)
            final_output["prediction"].append(prediction)


100%|██████████| 163/163 [00:21<00:00,  7.74it/s]


In [24]:
final_output_df = pd.DataFrame(final_output)
final_output_df
final_output_df.to_csv("./submission.csv", index=False)