# Import needed libraries

In [1]:
!pip install sacremoses

Collecting sacremoses
  Downloading sacremoses-0.1.1-py3-none-any.whl.metadata (8.3 kB)
Downloading sacremoses-0.1.1-py3-none-any.whl (897 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m897.5/897.5 kB[0m [31m16.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: sacremoses
Successfully installed sacremoses-0.1.1


In [2]:
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import StepLR
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer

from tqdm import tqdm

from torch.utils.tensorboard import SummaryWriter

2024-07-15 01:06:49.589081: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-15 01:06:49.589188: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-15 01:06:49.713510: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


# Device agnostic code

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_device(device)
print(f"default device set to {device}")

default device set to cuda


In [4]:
train_df = pd.read_csv("/kaggle/input/contradictory-my-dear-watson/train.csv")
train_df.head(5)

Unnamed: 0,id,premise,hypothesis,lang_abv,language,label
0,5130fd2cb5,and these comments were considered in formulat...,The rules developed in the interim were put to...,en,English,0
1,5b72532a0b,These are issues that we wrestle with in pract...,Practice groups are not permitted to work on t...,en,English,2
2,3931fbe82a,Des petites choses comme celles-là font une di...,J'essayais d'accomplir quelque chose.,fr,French,0
3,5622f0c60b,you know they can't really defend themselves l...,They can't defend themselves because of their ...,en,English,0
4,86aaa48b45,ในการเล่นบทบาทสมมุติก็เช่นกัน โอกาสที่จะได้แสด...,เด็กสามารถเห็นได้ว่าชาติพันธุ์แตกต่างกันอย่างไร,th,Thai,1


In [5]:
train_df.drop(columns=["id", "lang_abv", "language"], inplace=True)

In [6]:
train_df.sample(5)

Unnamed: 0,premise,hypothesis,label
1196,В центре площади расположен гранитный Weltkuge...,Weltkugelbrunnen сделан из алюминия.,2
9156,"Possibly three months.""",Maybe three months.,0
8902,PROGRAM ACCOUNT -The budget account into which...,Funds should never be transferred between prog...,2
5792,"Vous, avec d'autres membres bienveillants, aid...",L'aide est nécessaire pour l'état.,0
12055,and then i got into it and then back out of it...,I plan to get into it on a regular basis.,1


In [7]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/996k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

In [8]:
class CustomDataset(Dataset):
    def __init__(self, df: pd.DataFrame, tokenizer: object, max_length: int):
        self.dataset = df
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        premise = self.dataset["premise"].iloc[idx]
        hypothesis = self.dataset["hypothesis"].iloc[idx]
        
        token_dict = self.tokenizer.encode_plus(premise, hypothesis, padding=True, truncation=True, max_length=self.max_length, return_tensors="pt")
    
        return {
            "input_ids": token_dict["input_ids"].flatten(),
            "token_type_ids": token_dict["token_type_ids"].flatten(),
            "attention_mask": token_dict["attention_mask"].flatten(),
            "label": torch.tensor(self.dataset["label"].iloc[idx], dtype=torch.long)
        }
    
    @staticmethod
    def collate_fn(batch):
        batch_input_ids = [batch_item["input_ids"] for batch_item in batch]
        batch_type_ids = [batch_item["token_type_ids"] for batch_item in batch]
        batch_attention_masks = [batch_item["attention_mask"] for batch_item in batch]
        
        batch_input_ids = pad_sequence(batch_input_ids, batch_first=True, padding_value=0)
        batch_type_ids  = pad_sequence(batch_type_ids, batch_first=True, padding_value=0) # will get ignored by the attention mask when going through the model; very important
        batch_attention_masks = pad_sequence(batch_attention_masks, batch_first=True, padding_value=0)
        
        return {
            "input_ids": batch_input_ids,
            "token_type_ids": batch_type_ids,
            "attention_mask": batch_attention_masks,
            "labels": torch.stack([batch_item["label"] for batch_item in batch], dim=0)
        }
        

In [9]:
class CustomDatasetSubmission(Dataset):
    def __init__(self, df: pd.DataFrame, tokenizer: object, max_length: int):
        self.dataset = df
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        premise = self.dataset["premise"].iloc[idx]
        hypothesis = self.dataset["hypothesis"].iloc[idx]
        
        token_dict = self.tokenizer.encode_plus(premise, hypothesis, padding=True, truncation=True, max_length=self.max_length, return_tensors="pt")
    
        return {
            "input_ids": token_dict["input_ids"].flatten(),
            "token_type_ids": token_dict["token_type_ids"].flatten(),
            "attention_mask": token_dict["attention_mask"].flatten(),
            "id": self.dataset["id"].iloc[idx]
        }
    
    @staticmethod
    def collate_fn(batch):
        batch_input_ids = [batch_item["input_ids"] for batch_item in batch]
        batch_type_ids = [batch_item["token_type_ids"] for batch_item in batch]
        batch_attention_masks = [batch_item["attention_mask"] for batch_item in batch]
        
        batch_input_ids = pad_sequence(batch_input_ids, batch_first=True, padding_value=0)
        batch_type_ids  = pad_sequence(batch_type_ids, batch_first=True, padding_value=0) # will get ignored by the attention mask when going through the model; very important
        batch_attention_masks = pad_sequence(batch_attention_masks, batch_first=True, padding_value=0)
        
        return {
            "input_ids": batch_input_ids,
            "token_type_ids": batch_type_ids,
            "attention_mask": batch_attention_masks,
            "ids": [batch_item["id"] for batch_item in batch]
        }
 

In [10]:
max_length = max(max([len(premise) for premise in train_df["premise"]]), max([len(hypothesis) for hypothesis in train_df["hypothesis"]]))
train_split, val_split = train_test_split(train_df, test_size=0.25, shuffle=True)
train_dataset = CustomDataset(df=train_split, tokenizer=tokenizer, max_length=max_length)
val_dataset = CustomDataset(df=val_split, tokenizer=tokenizer, max_length=max_length)

In [11]:
train_dataloader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True, generator=torch.Generator(device=device), collate_fn=train_dataset.collate_fn)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=4, shuffle=True, generator=torch.Generator(device=device), collate_fn=val_dataset.collate_fn)

In [12]:
next(iter(val_dataloader))

{'input_ids': tensor([[  101, 13646, 13617, 19602, 10114, 10347, 21937,   119,   102, 10117,
          13617, 10393, 10472, 10590, 21937, 21833,   119,   102,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0],
         [  101, 27746, 76560, 10486,   124, 41175, 11339, 11844, 88126, 10108,
          49307, 10686, 10111, 71020, 10729, 10188, 10435, 69714,   117, 11198,
          13183,   119,   102, 10167, 27746, 76560, 10486,   124, 11155, 10301,
          10105, 11339, 16454, 88126, 10311, 49307, 10686, 10111, 71020, 10729,
          10188, 10435, 69714,   119,   102],
         [  101, 12048,   107, 11065, 10149, 10472, 27874,   146, 10431, 20687,
          10114, 11783, 10114, 10105, 15034,   136,   107,   102, 12882, 13028,
          21528, 10911, 10114, 20575, 10105, 35691, 10107,   136,   102,     0,
              0

In [13]:
bert_model = torch.hub.load('huggingface/pytorch-transformers', 'model', 'google-bert/bert-base-multilingual-cased')

Downloading: "https://github.com/huggingface/pytorch-transformers/zipball/main" to /root/.cache/torch/hub/main.zip


config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/714M [00:00<?, ?B/s]

In [14]:
class Classifier(nn.Module):
    def __init__(self, n_inputs, hidden_size, n_classes, dropout):
        super(Classifier, self).__init__()
        self.linear1 = nn.Linear(in_features=n_inputs, out_features=hidden_size)
        self.linear2 = nn.Linear(in_features=hidden_size, out_features=n_classes)
        self.dropout = nn.Dropout(dropout)

        self.act_fn = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.act_fn(self.linear1(x))
        x = self.dropout(x)
        x = self.linear2(x)
        return x


In [15]:
class BertClassifier(nn.Module):
    def __init__(self, bert_model, classifier):
        super(BertClassifier, self).__init__()
        self.bert_model = bert_model
        self.classifier = classifier

    def forward(self, x: torch.Tensor, token_type_ids: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
        x = self.bert_model(x, attention_mask=attention_mask, token_type_ids=token_type_ids)
        x = x.last_hidden_state[:, 0, :]
        x = self.classifier(x)
        return x

In [16]:
classifier = Classifier(n_inputs=768, n_classes=3, hidden_size=512, dropout=0.1)
model = BertClassifier(bert_model=bert_model, classifier=classifier)
bert_model.config_class()

BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.42.3",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [17]:
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-5)
loss_fn = nn.CrossEntropyLoss()

In [18]:
def train_model(model, dataloader, val_dataloader, loss_fn, optimizer, scheduler, epochs):
    writer = SummaryWriter()
    writer.add_graph(model, (torch.zeros(size=(32, 32), dtype=torch.long), torch.zeros(size=(32, 32), dtype=torch.long)))
    for epoch in range(epochs):
        model.eval()
        total_inputs = 0
        total_correct = 0

        with torch.inference_mode():
            for batch in tqdm(val_dataloader):
                logits = model(batch["input_ids"], attention_mask=batch["attention_mask"], token_type_ids=batch["token_type_ids"])
                percents = torch.softmax(logits, dim=1)
                preds = torch.argmax(percents, dim=1)

                total_correct += (preds==batch["labels"]).sum().item()
                total_inputs += batch["labels"].view(-1).shape[0]

            print(f"{total_correct} out of {total_inputs}")
            print(f"acc of {total_correct/total_inputs*100}%")

            writer.add_scalar("Val acc", total_correct/total_inputs*100, epoch)

        model.train()
        for batch_idx, batch in tqdm(enumerate(dataloader)):
            logits = model(batch["input_ids"], attention_mask=batch["attention_mask"], token_type_ids=batch["token_type_ids"])
            loss = loss_fn(logits, batch["labels"])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (batch_idx+1) % 40 == 0:
                print(f"loss for batch {batch_idx+1} --> {loss} at epoch {epoch}")

            writer.add_scalar("Loss", loss, batch_idx)

        scheduler.step()

    writer.close()

In [19]:
scheduler = StepLR(gamma=0.5, step_size=5, optimizer=optimizer)
train_model(model=model, dataloader=train_dataloader, val_dataloader=val_dataloader, loss_fn=loss_fn, optimizer=optimizer, scheduler=scheduler, epochs=20)

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
100%|██████████| 758/758 [00:13<00:00, 56.86it/s]


971 out of 3030
acc of 32.04620462046205%


41it [00:05,  8.10it/s]

loss for batch 40 --> 1.1501154899597168 at epoch 0


81it [00:09,  9.50it/s]

loss for batch 80 --> 1.126226544380188 at epoch 0


122it [00:14,  9.78it/s]

loss for batch 120 --> 1.0427906513214111 at epoch 0


161it [00:18,  8.82it/s]

loss for batch 160 --> 1.1026710271835327 at epoch 0


201it [00:23,  8.72it/s]

loss for batch 200 --> 1.1102131605148315 at epoch 0


241it [00:28,  8.59it/s]

loss for batch 240 --> 1.1491565704345703 at epoch 0


281it [00:32,  7.79it/s]

loss for batch 280 --> 0.9664055705070496 at epoch 0


320it [00:37,  8.19it/s]

loss for batch 320 --> 0.7923179864883423 at epoch 0


361it [00:42,  7.67it/s]

loss for batch 360 --> 0.9472635984420776 at epoch 0


401it [00:47,  8.43it/s]

loss for batch 400 --> 0.7787283658981323 at epoch 0


441it [00:52,  8.96it/s]

loss for batch 440 --> 0.7251570224761963 at epoch 0


481it [00:56,  8.38it/s]

loss for batch 480 --> 0.4817902743816376 at epoch 0


521it [01:01,  8.77it/s]

loss for batch 520 --> 0.7617510557174683 at epoch 0


560it [01:06,  8.71it/s]

loss for batch 560 --> 0.8439951539039612 at epoch 0


601it [01:10,  8.68it/s]

loss for batch 600 --> 0.758194088935852 at epoch 0


641it [01:15,  8.92it/s]

loss for batch 640 --> 0.7687597274780273 at epoch 0


681it [01:20,  7.94it/s]

loss for batch 680 --> 0.9364330768585205 at epoch 0


721it [01:25,  8.58it/s]

loss for batch 720 --> 0.7922025918960571 at epoch 0


761it [01:30,  8.34it/s]

loss for batch 760 --> 0.6754845380783081 at epoch 0


801it [01:34,  8.00it/s]

loss for batch 800 --> 0.8054429292678833 at epoch 0


841it [01:39,  9.01it/s]

loss for batch 840 --> 0.7837660908699036 at epoch 0


881it [01:44,  8.53it/s]

loss for batch 880 --> 0.9504621624946594 at epoch 0


921it [01:48,  8.72it/s]

loss for batch 920 --> 0.9783841967582703 at epoch 0


961it [01:53,  8.72it/s]

loss for batch 960 --> 0.5821483731269836 at epoch 0


1001it [01:58,  8.10it/s]

loss for batch 1000 --> 1.2283540964126587 at epoch 0


1041it [02:02,  8.41it/s]

loss for batch 1040 --> 0.9453403949737549 at epoch 0


1081it [02:07,  9.20it/s]

loss for batch 1080 --> 0.7317903637886047 at epoch 0


1121it [02:11,  8.95it/s]

loss for batch 1120 --> 0.7250386476516724 at epoch 0


1137it [02:13,  8.52it/s]
100%|██████████| 758/758 [00:12<00:00, 59.37it/s]


1966 out of 3030
acc of 64.88448844884488%


42it [00:05,  8.93it/s]

loss for batch 40 --> 0.6309148073196411 at epoch 1


81it [00:09,  9.00it/s]

loss for batch 80 --> 0.6371688842773438 at epoch 1


121it [00:14,  7.97it/s]

loss for batch 120 --> 0.8864825963973999 at epoch 1


161it [00:18,  9.71it/s]

loss for batch 160 --> 0.7393105626106262 at epoch 1


201it [00:23,  8.40it/s]

loss for batch 200 --> 0.6693230867385864 at epoch 1


240it [00:28,  9.08it/s]

loss for batch 240 --> 0.5589696764945984 at epoch 1


281it [00:32,  8.13it/s]

loss for batch 280 --> 0.6546330451965332 at epoch 1


321it [00:37,  9.80it/s]

loss for batch 320 --> 0.8667758703231812 at epoch 1


361it [00:41,  8.95it/s]

loss for batch 360 --> 0.9749202132225037 at epoch 1


401it [00:46,  8.75it/s]

loss for batch 400 --> 0.40435871481895447 at epoch 1


442it [00:51,  9.26it/s]

loss for batch 440 --> 0.5936074256896973 at epoch 1


481it [00:55,  8.04it/s]

loss for batch 480 --> 0.27243077754974365 at epoch 1


521it [01:00,  8.55it/s]

loss for batch 520 --> 0.7400888204574585 at epoch 1


560it [01:05,  8.30it/s]

loss for batch 560 --> 0.3676746189594269 at epoch 1


602it [01:10,  9.60it/s]

loss for batch 600 --> 0.48296743631362915 at epoch 1


641it [01:14,  9.03it/s]

loss for batch 640 --> 0.6104297637939453 at epoch 1


681it [01:19,  8.60it/s]

loss for batch 680 --> 0.23358626663684845 at epoch 1


721it [01:23,  8.83it/s]

loss for batch 720 --> 0.8166109919548035 at epoch 1


761it [01:28,  9.23it/s]

loss for batch 760 --> 0.5709101557731628 at epoch 1


801it [01:32,  9.24it/s]

loss for batch 800 --> 0.5030835866928101 at epoch 1


841it [01:37,  9.39it/s]

loss for batch 840 --> 0.7665430903434753 at epoch 1


881it [01:42,  7.79it/s]

loss for batch 880 --> 0.6230142116546631 at epoch 1


920it [01:46,  8.43it/s]

loss for batch 920 --> 0.5252381563186646 at epoch 1


962it [01:51,  9.19it/s]

loss for batch 960 --> 0.7782996296882629 at epoch 1


1001it [01:56,  8.34it/s]

loss for batch 1000 --> 0.59918212890625 at epoch 1


1041it [02:00,  8.60it/s]

loss for batch 1040 --> 0.5143101811408997 at epoch 1


1081it [02:05,  8.84it/s]

loss for batch 1080 --> 0.37046515941619873 at epoch 1


1121it [02:09,  8.41it/s]

loss for batch 1120 --> 0.789171576499939 at epoch 1


1137it [02:11,  8.64it/s]
100%|██████████| 758/758 [00:12<00:00, 59.83it/s]


1979 out of 3030
acc of 65.31353135313532%


41it [00:04,  8.50it/s]

loss for batch 40 --> 0.2656974792480469 at epoch 2


81it [00:09,  9.47it/s]

loss for batch 80 --> 0.3038346469402313 at epoch 2


121it [00:14,  8.55it/s]

loss for batch 120 --> 0.17330706119537354 at epoch 2


161it [00:18,  8.77it/s]

loss for batch 160 --> 0.4833327531814575 at epoch 2


201it [00:23,  8.50it/s]

loss for batch 200 --> 0.3113914430141449 at epoch 2


241it [00:27,  7.95it/s]

loss for batch 240 --> 0.1515313982963562 at epoch 2


281it [00:32,  8.41it/s]

loss for batch 280 --> 0.08894151449203491 at epoch 2


321it [00:37,  8.33it/s]

loss for batch 320 --> 0.5783309936523438 at epoch 2


361it [00:41,  8.50it/s]

loss for batch 360 --> 0.8451818227767944 at epoch 2


402it [00:46,  9.38it/s]

loss for batch 400 --> 0.22656932473182678 at epoch 2


441it [00:51,  8.62it/s]

loss for batch 440 --> 0.4592113196849823 at epoch 2


481it [00:55,  8.44it/s]

loss for batch 480 --> 0.47216230630874634 at epoch 2


521it [01:00,  8.82it/s]

loss for batch 520 --> 0.4327653646469116 at epoch 2


561it [01:05,  8.70it/s]

loss for batch 560 --> 0.05833893641829491 at epoch 2


601it [01:09,  9.03it/s]

loss for batch 600 --> 0.23183226585388184 at epoch 2


641it [01:14,  9.30it/s]

loss for batch 640 --> 0.5927093625068665 at epoch 2


681it [01:19,  7.68it/s]

loss for batch 680 --> 0.15786334872245789 at epoch 2


721it [01:23,  8.89it/s]

loss for batch 720 --> 0.2447202056646347 at epoch 2


761it [01:28,  8.34it/s]

loss for batch 760 --> 0.38942256569862366 at epoch 2


802it [01:33,  9.34it/s]

loss for batch 800 --> 0.36317384243011475 at epoch 2


841it [01:37,  8.47it/s]

loss for batch 840 --> 0.4787311255931854 at epoch 2


881it [01:42,  8.32it/s]

loss for batch 880 --> 0.7293056845664978 at epoch 2


920it [01:47,  8.03it/s]

loss for batch 920 --> 0.6577077507972717 at epoch 2


961it [01:51,  8.15it/s]

loss for batch 960 --> 0.6806465983390808 at epoch 2


1001it [01:56,  8.19it/s]

loss for batch 1000 --> 0.5955641865730286 at epoch 2


1041it [02:01,  8.43it/s]

loss for batch 1040 --> 0.1538512259721756 at epoch 2


1081it [02:05,  8.58it/s]

loss for batch 1080 --> 0.3203984797000885 at epoch 2


1121it [02:10,  8.94it/s]

loss for batch 1120 --> 0.5771311521530151 at epoch 2


1137it [02:12,  8.59it/s]
100%|██████████| 758/758 [00:12<00:00, 59.65it/s]


1966 out of 3030
acc of 64.88448844884488%


41it [00:04,  9.51it/s]

loss for batch 40 --> 0.4649199843406677 at epoch 3


81it [00:09,  8.46it/s]

loss for batch 80 --> 0.12986329197883606 at epoch 3


121it [00:14,  8.87it/s]

loss for batch 120 --> 0.03860537335276604 at epoch 3


161it [00:18,  9.83it/s]

loss for batch 160 --> 0.08572875708341599 at epoch 3


201it [00:23,  8.33it/s]

loss for batch 200 --> 0.8359504342079163 at epoch 3


242it [00:27,  9.39it/s]

loss for batch 240 --> 0.19303952157497406 at epoch 3


281it [00:32,  7.94it/s]

loss for batch 280 --> 0.16687902808189392 at epoch 3


321it [00:37,  8.49it/s]

loss for batch 320 --> 0.9251693487167358 at epoch 3


360it [00:41,  8.98it/s]

loss for batch 360 --> 0.044886376708745956 at epoch 3


401it [00:46,  8.80it/s]

loss for batch 400 --> 0.33228859305381775 at epoch 3


440it [00:51,  7.82it/s]

loss for batch 440 --> 0.30740800499916077 at epoch 3


482it [00:56,  9.16it/s]

loss for batch 480 --> 0.0818483829498291 at epoch 3


521it [01:00,  8.09it/s]

loss for batch 520 --> 0.29148441553115845 at epoch 3


561it [01:05,  8.86it/s]

loss for batch 560 --> 0.16359899938106537 at epoch 3


601it [01:09,  8.89it/s]

loss for batch 600 --> 0.07317394018173218 at epoch 3


640it [01:14,  8.51it/s]

loss for batch 640 --> 0.13477590680122375 at epoch 3


681it [01:19,  8.55it/s]

loss for batch 680 --> 0.3549934923648834 at epoch 3


721it [01:24,  8.73it/s]

loss for batch 720 --> 0.24558107554912567 at epoch 3


761it [01:28,  8.54it/s]

loss for batch 760 --> 0.19745425879955292 at epoch 3


801it [01:33,  8.50it/s]

loss for batch 800 --> 0.3055571913719177 at epoch 3


841it [01:38,  9.52it/s]

loss for batch 840 --> 1.4750041961669922 at epoch 3


881it [01:42,  8.73it/s]

loss for batch 880 --> 0.04193186014890671 at epoch 3


922it [01:47,  9.15it/s]

loss for batch 920 --> 0.47739818692207336 at epoch 3


961it [01:52,  8.96it/s]

loss for batch 960 --> 1.3430335521697998 at epoch 3


1001it [01:56,  9.53it/s]

loss for batch 1000 --> 0.0983482301235199 at epoch 3


1041it [02:01,  8.80it/s]

loss for batch 1040 --> 0.728138267993927 at epoch 3


1081it [02:05,  8.49it/s]

loss for batch 1080 --> 0.24847789108753204 at epoch 3


1121it [02:10,  8.66it/s]

loss for batch 1120 --> 0.1896628886461258 at epoch 3


1137it [02:12,  8.58it/s]
100%|██████████| 758/758 [00:12<00:00, 59.51it/s]


1961 out of 3030
acc of 64.71947194719472%


41it [00:04,  7.88it/s]

loss for batch 40 --> 0.14445795118808746 at epoch 4


81it [00:09,  7.88it/s]

loss for batch 80 --> 0.030369628220796585 at epoch 4


121it [00:14,  8.39it/s]

loss for batch 120 --> 0.06224292516708374 at epoch 4


161it [00:18,  8.65it/s]

loss for batch 160 --> 0.06790733337402344 at epoch 4


201it [00:23,  8.86it/s]

loss for batch 200 --> 0.040806788951158524 at epoch 4


242it [00:28,  9.40it/s]

loss for batch 240 --> 0.19596123695373535 at epoch 4


281it [00:32,  8.15it/s]

loss for batch 280 --> 0.12149827182292938 at epoch 4


321it [00:37,  8.80it/s]

loss for batch 320 --> 0.02378513105213642 at epoch 4


361it [00:41,  8.72it/s]

loss for batch 360 --> 0.13724136352539062 at epoch 4


401it [00:46,  7.42it/s]

loss for batch 400 --> 0.07617414742708206 at epoch 4


441it [00:51,  8.75it/s]

loss for batch 440 --> 0.023380232974886894 at epoch 4


481it [00:55,  8.25it/s]

loss for batch 480 --> 0.05030489340424538 at epoch 4


521it [01:00,  8.46it/s]

loss for batch 520 --> 0.1603887677192688 at epoch 4


561it [01:05,  8.50it/s]

loss for batch 560 --> 0.028774425387382507 at epoch 4


601it [01:09,  9.20it/s]

loss for batch 600 --> 0.1629362553358078 at epoch 4


640it [01:14,  8.15it/s]

loss for batch 640 --> 0.029991397634148598 at epoch 4


680it [01:18,  8.20it/s]

loss for batch 680 --> 0.016829468309879303 at epoch 4


721it [01:23,  8.74it/s]

loss for batch 720 --> 0.40492314100265503 at epoch 4


761it [01:28,  9.07it/s]

loss for batch 760 --> 0.03364481031894684 at epoch 4


801it [01:33,  8.15it/s]

loss for batch 800 --> 0.7845022678375244 at epoch 4


841it [01:38,  8.51it/s]

loss for batch 840 --> 0.3373435437679291 at epoch 4


881it [01:42,  7.93it/s]

loss for batch 880 --> 0.13921238481998444 at epoch 4


921it [01:47,  8.35it/s]

loss for batch 920 --> 0.06573785841464996 at epoch 4


961it [01:51,  7.28it/s]

loss for batch 960 --> 0.018334994092583656 at epoch 4


1000it [01:56,  7.86it/s]

loss for batch 1000 --> 0.0791178047657013 at epoch 4


1041it [02:01,  7.16it/s]

loss for batch 1040 --> 0.11337114125490189 at epoch 4


1081it [02:05,  8.83it/s]

loss for batch 1080 --> 0.02530214749276638 at epoch 4


1121it [02:10,  9.00it/s]

loss for batch 1120 --> 0.24310564994812012 at epoch 4


1137it [02:12,  8.59it/s]
100%|██████████| 758/758 [00:12<00:00, 59.71it/s]


1973 out of 3030
acc of 65.11551155115511%


41it [00:04,  9.25it/s]

loss for batch 40 --> 0.015799354761838913 at epoch 5


81it [00:09,  7.06it/s]

loss for batch 80 --> 0.12412893772125244 at epoch 5


121it [00:14,  8.36it/s]

loss for batch 120 --> 0.020305844023823738 at epoch 5


161it [00:18,  8.13it/s]

loss for batch 160 --> 0.028798379004001617 at epoch 5


201it [00:23,  8.57it/s]

loss for batch 200 --> 0.024753810837864876 at epoch 5


241it [00:28,  8.35it/s]

loss for batch 240 --> 0.016832107678055763 at epoch 5


281it [00:32,  8.36it/s]

loss for batch 280 --> 0.06155817210674286 at epoch 5


322it [00:37,  9.57it/s]

loss for batch 320 --> 0.0355990044772625 at epoch 5


361it [00:42,  8.33it/s]

loss for batch 360 --> 0.03129570186138153 at epoch 5


402it [00:46, 10.23it/s]

loss for batch 400 --> 0.01956687681376934 at epoch 5


441it [00:51,  8.61it/s]

loss for batch 440 --> 0.6967185139656067 at epoch 5


481it [00:55,  8.50it/s]

loss for batch 480 --> 0.00875311903655529 at epoch 5


521it [01:00,  7.87it/s]

loss for batch 520 --> 0.16996879875659943 at epoch 5


560it [01:04,  9.16it/s]

loss for batch 560 --> 0.010506800375878811 at epoch 5


601it [01:09,  8.49it/s]

loss for batch 600 --> 0.0098185446113348 at epoch 5


641it [01:14,  8.54it/s]

loss for batch 640 --> 0.2134910225868225 at epoch 5


681it [01:19,  9.98it/s]

loss for batch 680 --> 0.019771268591284752 at epoch 5


720it [01:23,  8.44it/s]

loss for batch 720 --> 0.17033976316452026 at epoch 5


761it [01:28,  9.26it/s]

loss for batch 760 --> 0.021232545375823975 at epoch 5


801it [01:32,  9.31it/s]

loss for batch 800 --> 0.028334271162748337 at epoch 5


841it [01:37,  8.54it/s]

loss for batch 840 --> 0.008810768835246563 at epoch 5


881it [01:42,  7.98it/s]

loss for batch 880 --> 0.0096656559035182 at epoch 5


921it [01:46,  9.37it/s]

loss for batch 920 --> 0.194622203707695 at epoch 5


961it [01:51,  8.28it/s]

loss for batch 960 --> 0.1317121684551239 at epoch 5


1001it [01:56,  8.53it/s]

loss for batch 1000 --> 0.01664189249277115 at epoch 5


1041it [02:01,  8.84it/s]

loss for batch 1040 --> 0.018202971667051315 at epoch 5


1081it [02:05,  8.07it/s]

loss for batch 1080 --> 0.11800070106983185 at epoch 5


1121it [02:10,  7.97it/s]

loss for batch 1120 --> 0.09026001393795013 at epoch 5


1137it [02:12,  8.58it/s]
100%|██████████| 758/758 [00:12<00:00, 59.35it/s]


1947 out of 3030
acc of 64.25742574257426%


41it [00:04,  8.30it/s]

loss for batch 40 --> 0.008528315462172031 at epoch 6


81it [00:09,  9.04it/s]

loss for batch 80 --> 0.007644103839993477 at epoch 6


121it [00:13,  9.13it/s]

loss for batch 120 --> 0.02890547178685665 at epoch 6


161it [00:18,  9.70it/s]

loss for batch 160 --> 0.01417337916791439 at epoch 6


201it [00:23,  7.83it/s]

loss for batch 200 --> 0.0067001935094594955 at epoch 6


241it [00:27,  8.88it/s]

loss for batch 240 --> 0.0149789834395051 at epoch 6


281it [00:32,  8.51it/s]

loss for batch 280 --> 0.004705562721937895 at epoch 6


321it [00:36,  9.51it/s]

loss for batch 320 --> 0.006760577671229839 at epoch 6


361it [00:41,  9.04it/s]

loss for batch 360 --> 0.016203830018639565 at epoch 6


401it [00:46,  7.77it/s]

loss for batch 400 --> 0.02180011011660099 at epoch 6


441it [00:50,  9.49it/s]

loss for batch 440 --> 0.05326168239116669 at epoch 6


481it [00:55,  9.11it/s]

loss for batch 480 --> 0.013312644325196743 at epoch 6


521it [00:59,  8.42it/s]

loss for batch 520 --> 0.4893675148487091 at epoch 6


560it [01:04,  8.45it/s]

loss for batch 560 --> 0.0109058553352952 at epoch 6


602it [01:09,  8.89it/s]

loss for batch 600 --> 0.008693031966686249 at epoch 6


642it [01:13,  9.12it/s]

loss for batch 640 --> 0.004777975846081972 at epoch 6


680it [01:18,  9.02it/s]

loss for batch 680 --> 0.012902404181659222 at epoch 6


721it [01:23,  8.32it/s]

loss for batch 720 --> 0.2140735238790512 at epoch 6


761it [01:28,  7.94it/s]

loss for batch 760 --> 0.02467178925871849 at epoch 6


801it [01:32,  8.76it/s]

loss for batch 800 --> 0.02123279869556427 at epoch 6


840it [01:37,  9.14it/s]

loss for batch 840 --> 0.03320329263806343 at epoch 6


881it [01:42,  7.88it/s]

loss for batch 880 --> 0.10939706861972809 at epoch 6


921it [01:46,  8.76it/s]

loss for batch 920 --> 0.004902810789644718 at epoch 6


961it [01:51,  6.69it/s]

loss for batch 960 --> 0.006825763266533613 at epoch 6


1000it [01:56,  8.59it/s]

loss for batch 1000 --> 0.007040743716061115 at epoch 6


1041it [02:01,  8.20it/s]

loss for batch 1040 --> 0.004140758421272039 at epoch 6


1081it [02:06,  8.44it/s]

loss for batch 1080 --> 0.004794037900865078 at epoch 6


1121it [02:11,  9.09it/s]

loss for batch 1120 --> 0.0172229316085577 at epoch 6


1137it [02:12,  8.55it/s]
100%|██████████| 758/758 [00:12<00:00, 58.64it/s]


1924 out of 3030
acc of 63.4983498349835%


41it [00:04,  9.39it/s]

loss for batch 40 --> 0.006546080112457275 at epoch 7


81it [00:09,  9.34it/s]

loss for batch 80 --> 0.003449147567152977 at epoch 7


121it [00:14,  7.90it/s]

loss for batch 120 --> 0.007828056812286377 at epoch 7


161it [00:18,  8.65it/s]

loss for batch 160 --> 0.02375223860144615 at epoch 7


201it [00:23,  7.56it/s]

loss for batch 200 --> 0.0152181601151824 at epoch 7


240it [00:28,  8.60it/s]

loss for batch 240 --> 0.0025590872392058372 at epoch 7


281it [00:33,  8.67it/s]

loss for batch 280 --> 0.00805311556905508 at epoch 7


321it [00:37,  9.11it/s]

loss for batch 320 --> 0.11541996896266937 at epoch 7


361it [00:42,  8.85it/s]

loss for batch 360 --> 0.035505183041095734 at epoch 7


401it [00:47,  8.72it/s]

loss for batch 400 --> 0.10928693413734436 at epoch 7


440it [00:51,  9.01it/s]

loss for batch 440 --> 0.4618619978427887 at epoch 7


481it [00:56,  7.31it/s]

loss for batch 480 --> 0.007192883640527725 at epoch 7


520it [01:01,  8.39it/s]

loss for batch 520 --> 0.0087449811398983 at epoch 7


561it [01:06,  8.74it/s]

loss for batch 560 --> 0.004732938017696142 at epoch 7


601it [01:11,  8.54it/s]

loss for batch 600 --> 0.0033287699334323406 at epoch 7


641it [01:15,  9.05it/s]

loss for batch 640 --> 0.06738598644733429 at epoch 7


681it [01:20,  8.69it/s]

loss for batch 680 --> 0.008776645176112652 at epoch 7


721it [01:24,  8.19it/s]

loss for batch 720 --> 0.20245854556560516 at epoch 7


760it [01:29,  9.04it/s]

loss for batch 760 --> 0.36567768454551697 at epoch 7


801it [01:33,  8.73it/s]

loss for batch 800 --> 0.003435802413150668 at epoch 7


841it [01:38,  8.93it/s]

loss for batch 840 --> 0.004635831341147423 at epoch 7


881it [01:43,  7.47it/s]

loss for batch 880 --> 0.003417754312977195 at epoch 7


921it [01:47,  8.17it/s]

loss for batch 920 --> 0.0036806613206863403 at epoch 7


960it [01:52,  8.75it/s]

loss for batch 960 --> 0.01944359950721264 at epoch 7


1001it [01:56,  8.22it/s]

loss for batch 1000 --> 0.44514355063438416 at epoch 7


1041it [02:01,  9.34it/s]

loss for batch 1040 --> 0.41649603843688965 at epoch 7


1081it [02:06,  7.35it/s]

loss for batch 1080 --> 0.0052139186300337315 at epoch 7


1121it [02:11,  7.41it/s]

loss for batch 1120 --> 0.20044559240341187 at epoch 7


1137it [02:13,  8.55it/s]
100%|██████████| 758/758 [00:12<00:00, 59.76it/s]


1966 out of 3030
acc of 64.88448844884488%


41it [00:04,  8.66it/s]

loss for batch 40 --> 0.002498457208275795 at epoch 8


81it [00:09,  9.06it/s]

loss for batch 80 --> 0.0031555630266666412 at epoch 8


121it [00:13,  8.92it/s]

loss for batch 120 --> 0.06305021047592163 at epoch 8


161it [00:18,  8.73it/s]

loss for batch 160 --> 0.23732802271842957 at epoch 8


201it [00:23,  7.88it/s]

loss for batch 200 --> 0.003695286577567458 at epoch 8


241it [00:27,  8.99it/s]

loss for batch 240 --> 0.0038120183162391186 at epoch 8


281it [00:32,  8.72it/s]

loss for batch 280 --> 0.2377752959728241 at epoch 8


321it [00:37,  8.78it/s]

loss for batch 320 --> 0.002039667684584856 at epoch 8


361it [00:41,  8.63it/s]

loss for batch 360 --> 0.005433405749499798 at epoch 8


401it [00:46,  8.87it/s]

loss for batch 400 --> 0.09654790163040161 at epoch 8


441it [00:51,  8.44it/s]

loss for batch 440 --> 0.0062782191671431065 at epoch 8


482it [00:55,  9.04it/s]

loss for batch 480 --> 0.0029851088766008615 at epoch 8


521it [01:00,  8.38it/s]

loss for batch 520 --> 0.009426838718354702 at epoch 8


560it [01:04,  8.92it/s]

loss for batch 560 --> 0.002846064977347851 at epoch 8


601it [01:09,  8.81it/s]

loss for batch 600 --> 0.0361853763461113 at epoch 8


641it [01:13,  9.32it/s]

loss for batch 640 --> 0.014470878057181835 at epoch 8


681it [01:18,  9.22it/s]

loss for batch 680 --> 0.0663101002573967 at epoch 8


721it [01:23,  9.09it/s]

loss for batch 720 --> 0.005815157201141119 at epoch 8


761it [01:28,  8.45it/s]

loss for batch 760 --> 0.0022215736098587513 at epoch 8


801it [01:32,  9.26it/s]

loss for batch 800 --> 0.43421903252601624 at epoch 8


841it [01:37,  9.62it/s]

loss for batch 840 --> 0.13143108785152435 at epoch 8


881it [01:41,  9.70it/s]

loss for batch 880 --> 0.021795136854052544 at epoch 8


921it [01:46,  9.60it/s]

loss for batch 920 --> 0.0049626268446445465 at epoch 8


961it [01:51,  8.19it/s]

loss for batch 960 --> 0.0027355526108294725 at epoch 8


1001it [01:56,  8.12it/s]

loss for batch 1000 --> 0.05327444151043892 at epoch 8


1041it [02:00,  9.83it/s]

loss for batch 1040 --> 0.003770679933950305 at epoch 8


1081it [02:05,  8.54it/s]

loss for batch 1080 --> 0.00569460354745388 at epoch 8


1121it [02:10,  8.14it/s]

loss for batch 1120 --> 0.0017594732344150543 at epoch 8


1137it [02:12,  8.61it/s]
100%|██████████| 758/758 [00:12<00:00, 59.44it/s]


1948 out of 3030
acc of 64.2904290429043%


41it [00:04,  8.35it/s]

loss for batch 40 --> 0.015876343473792076 at epoch 9


81it [00:09,  8.43it/s]

loss for batch 80 --> 0.003126689698547125 at epoch 9


121it [00:14,  8.86it/s]

loss for batch 120 --> 0.02546699345111847 at epoch 9


161it [00:18,  8.77it/s]

loss for batch 160 --> 0.0039055603556334972 at epoch 9


200it [00:23,  8.48it/s]

loss for batch 200 --> 0.0036778352223336697 at epoch 9


241it [00:27,  9.18it/s]

loss for batch 240 --> 0.04018567502498627 at epoch 9


281it [00:32,  8.66it/s]

loss for batch 280 --> 0.0018876392859965563 at epoch 9


321it [00:37,  8.66it/s]

loss for batch 320 --> 0.01204446330666542 at epoch 9


361it [00:41,  8.22it/s]

loss for batch 360 --> 0.0864429846405983 at epoch 9


401it [00:46,  8.51it/s]

loss for batch 400 --> 0.00221373257227242 at epoch 9


441it [00:51,  8.94it/s]

loss for batch 440 --> 0.021521754562854767 at epoch 9


481it [00:55,  7.98it/s]

loss for batch 480 --> 0.0022716454695910215 at epoch 9


521it [01:00,  9.24it/s]

loss for batch 520 --> 0.0019156563794240355 at epoch 9


561it [01:05,  8.28it/s]

loss for batch 560 --> 0.0036894618533551693 at epoch 9


601it [01:09,  8.62it/s]

loss for batch 600 --> 0.004686293192207813 at epoch 9


641it [01:14,  8.38it/s]

loss for batch 640 --> 0.0071375733241438866 at epoch 9


681it [01:19,  8.80it/s]

loss for batch 680 --> 0.0031389035284519196 at epoch 9


721it [01:23,  8.87it/s]

loss for batch 720 --> 0.0010937282349914312 at epoch 9


761it [01:28,  8.23it/s]

loss for batch 760 --> 0.004881728440523148 at epoch 9


801it [01:33,  8.39it/s]

loss for batch 800 --> 0.002847058931365609 at epoch 9


840it [01:37,  8.37it/s]

loss for batch 840 --> 0.002333548851311207 at epoch 9


881it [01:42,  8.56it/s]

loss for batch 880 --> 0.003263995982706547 at epoch 9


921it [01:47,  9.12it/s]

loss for batch 920 --> 0.001901631592772901 at epoch 9


961it [01:51,  8.16it/s]

loss for batch 960 --> 0.010831069201231003 at epoch 9


1001it [01:56,  8.53it/s]

loss for batch 1000 --> 0.001964692724868655 at epoch 9


1041it [02:01,  9.07it/s]

loss for batch 1040 --> 0.005140398163348436 at epoch 9


1081it [02:06,  8.25it/s]

loss for batch 1080 --> 0.0024073575623333454 at epoch 9


1121it [02:10,  8.98it/s]

loss for batch 1120 --> 0.004792107734829187 at epoch 9


1137it [02:12,  8.58it/s]
100%|██████████| 758/758 [00:12<00:00, 59.32it/s]


1962 out of 3030
acc of 64.75247524752476%


41it [00:04,  8.61it/s]

loss for batch 40 --> 0.0029386249370872974 at epoch 10


81it [00:09,  8.99it/s]

loss for batch 80 --> 0.002025445457547903 at epoch 10


121it [00:14,  8.27it/s]

loss for batch 120 --> 0.017506463453173637 at epoch 10


161it [00:18,  7.77it/s]

loss for batch 160 --> 0.006521042436361313 at epoch 10


201it [00:23,  8.30it/s]

loss for batch 200 --> 0.013375397771596909 at epoch 10


241it [00:27,  9.14it/s]

loss for batch 240 --> 0.004742221906781197 at epoch 10


281it [00:32,  8.22it/s]

loss for batch 280 --> 0.001544566242955625 at epoch 10


321it [00:37,  8.63it/s]

loss for batch 320 --> 0.0021231973078101873 at epoch 10


361it [00:41,  8.35it/s]

loss for batch 360 --> 0.0036277181934565306 at epoch 10


401it [00:46,  8.93it/s]

loss for batch 400 --> 0.012505502440035343 at epoch 10


441it [00:51,  8.58it/s]

loss for batch 440 --> 0.27432870864868164 at epoch 10


481it [00:56,  9.68it/s]

loss for batch 480 --> 0.001705855829641223 at epoch 10


521it [01:00,  7.65it/s]

loss for batch 520 --> 0.00374380755238235 at epoch 10


562it [01:05,  9.50it/s]

loss for batch 560 --> 0.0018798368982970715 at epoch 10


601it [01:09,  8.16it/s]

loss for batch 600 --> 0.00781939085572958 at epoch 10


641it [01:14,  7.79it/s]

loss for batch 640 --> 0.0034179408103227615 at epoch 10


681it [01:19,  8.00it/s]

loss for batch 680 --> 0.0029959380626678467 at epoch 10


721it [01:24,  8.50it/s]

loss for batch 720 --> 0.001807142049074173 at epoch 10


761it [01:29,  8.11it/s]

loss for batch 760 --> 0.03777023404836655 at epoch 10


801it [01:33,  8.30it/s]

loss for batch 800 --> 0.0036805891431868076 at epoch 10


841it [01:38,  9.24it/s]

loss for batch 840 --> 0.006703658029437065 at epoch 10


881it [01:43,  8.05it/s]

loss for batch 880 --> 0.0044624595902860165 at epoch 10


921it [01:48,  9.20it/s]

loss for batch 920 --> 0.0017581339925527573 at epoch 10


961it [01:52,  8.46it/s]

loss for batch 960 --> 0.0011196928098797798 at epoch 10


1001it [01:57,  9.02it/s]

loss for batch 1000 --> 0.0018733444157987833 at epoch 10


1041it [02:02,  7.84it/s]

loss for batch 1040 --> 0.1263100653886795 at epoch 10


1081it [02:06,  8.65it/s]

loss for batch 1080 --> 0.0020801834762096405 at epoch 10


1121it [02:11,  8.03it/s]

loss for batch 1120 --> 0.06938458234071732 at epoch 10


1137it [02:13,  8.54it/s]
100%|██████████| 758/758 [00:12<00:00, 59.33it/s]


1983 out of 3030
acc of 65.44554455445545%


41it [00:04,  8.47it/s]

loss for batch 40 --> 0.011887740343809128 at epoch 11


81it [00:09,  8.64it/s]

loss for batch 80 --> 0.0012142079649493098 at epoch 11


121it [00:14,  7.79it/s]

loss for batch 120 --> 0.0017884314293041825 at epoch 11


161it [00:18,  9.07it/s]

loss for batch 160 --> 0.0012973013799637556 at epoch 11


201it [00:23,  8.29it/s]

loss for batch 200 --> 0.0010625761933624744 at epoch 11


241it [00:28,  8.87it/s]

loss for batch 240 --> 0.002035885350778699 at epoch 11


281it [00:32,  9.17it/s]

loss for batch 280 --> 0.004736174829304218 at epoch 11


321it [00:37,  9.11it/s]

loss for batch 320 --> 0.00233988999389112 at epoch 11


361it [00:41,  8.94it/s]

loss for batch 360 --> 0.001599498325958848 at epoch 11


401it [00:46,  8.92it/s]

loss for batch 400 --> 0.06826625019311905 at epoch 11


441it [00:51,  8.92it/s]

loss for batch 440 --> 0.001633285777643323 at epoch 11


481it [00:56,  8.82it/s]

loss for batch 480 --> 0.002362000523135066 at epoch 11


521it [01:00,  8.34it/s]

loss for batch 520 --> 0.01978975161910057 at epoch 11


561it [01:05,  9.09it/s]

loss for batch 560 --> 0.0034593576565384865 at epoch 11


601it [01:09,  8.99it/s]

loss for batch 600 --> 0.0016691533382982016 at epoch 11


642it [01:14,  9.83it/s]

loss for batch 640 --> 0.002184765413403511 at epoch 11


681it [01:19,  8.28it/s]

loss for batch 680 --> 0.0011828896822407842 at epoch 11


721it [01:23,  9.14it/s]

loss for batch 720 --> 0.0018549620872363448 at epoch 11


762it [01:28,  9.46it/s]

loss for batch 760 --> 0.036993999034166336 at epoch 11


800it [01:33,  8.58it/s]

loss for batch 800 --> 0.001495651318691671 at epoch 11


841it [01:37,  9.35it/s]

loss for batch 840 --> 0.004742665681988001 at epoch 11


881it [01:42,  8.79it/s]

loss for batch 880 --> 0.001185126486234367 at epoch 11


921it [01:46,  8.65it/s]

loss for batch 920 --> 0.0054095033556222916 at epoch 11


961it [01:51,  8.01it/s]

loss for batch 960 --> 0.10121646523475647 at epoch 11


1001it [01:56,  7.64it/s]

loss for batch 1000 --> 0.006291148718446493 at epoch 11


1041it [02:01,  8.27it/s]

loss for batch 1040 --> 0.0022137262858450413 at epoch 11


1081it [02:05,  9.00it/s]

loss for batch 1080 --> 0.0014917498920112848 at epoch 11


1122it [02:10,  9.46it/s]

loss for batch 1120 --> 0.06949389725923538 at epoch 11


1137it [02:11,  8.61it/s]
100%|██████████| 758/758 [00:12<00:00, 60.05it/s]


1945 out of 3030
acc of 64.1914191419142%


41it [00:04,  9.52it/s]

loss for batch 40 --> 0.0014075078070163727 at epoch 12


81it [00:09,  8.50it/s]

loss for batch 80 --> 0.0028978472109884024 at epoch 12


121it [00:13,  8.18it/s]

loss for batch 120 --> 0.0008178796852007508 at epoch 12


161it [00:18,  8.24it/s]

loss for batch 160 --> 0.003637047717347741 at epoch 12


200it [00:23,  8.73it/s]

loss for batch 200 --> 0.0037197647616267204 at epoch 12


241it [00:27,  9.00it/s]

loss for batch 240 --> 0.0009487107745371759 at epoch 12


281it [00:32,  8.86it/s]

loss for batch 280 --> 0.0010632667690515518 at epoch 12


321it [00:36,  8.71it/s]

loss for batch 320 --> 0.0019335186807438731 at epoch 12


361it [00:41,  8.63it/s]

loss for batch 360 --> 0.017956173047423363 at epoch 12


401it [00:46,  8.47it/s]

loss for batch 400 --> 0.0010979161597788334 at epoch 12


441it [00:50,  7.99it/s]

loss for batch 440 --> 0.0031561958603560925 at epoch 12


481it [00:55,  8.89it/s]

loss for batch 480 --> 0.0016226896550506353 at epoch 12


521it [00:59,  7.71it/s]

loss for batch 520 --> 0.0006479882868006825 at epoch 12


561it [01:04,  8.37it/s]

loss for batch 560 --> 0.0008647097856737673 at epoch 12


601it [01:09,  8.75it/s]

loss for batch 600 --> 0.01296309195458889 at epoch 12


642it [01:14,  9.31it/s]

loss for batch 640 --> 0.0007242109859362245 at epoch 12


681it [01:18,  8.16it/s]

loss for batch 680 --> 0.0012616192689165473 at epoch 12


721it [01:23,  8.35it/s]

loss for batch 720 --> 0.0015945046907290816 at epoch 12


762it [01:27,  9.23it/s]

loss for batch 760 --> 0.001146566472016275 at epoch 12


800it [01:32,  8.71it/s]

loss for batch 800 --> 0.01844130828976631 at epoch 12


841it [01:36,  8.94it/s]

loss for batch 840 --> 0.0008870207821018994 at epoch 12


881it [01:41,  7.61it/s]

loss for batch 880 --> 0.001039510709233582 at epoch 12


921it [01:46,  8.60it/s]

loss for batch 920 --> 0.0014363490045070648 at epoch 12


961it [01:50,  9.12it/s]

loss for batch 960 --> 0.005086914170533419 at epoch 12


1001it [01:55,  9.45it/s]

loss for batch 1000 --> 0.001956172054633498 at epoch 12


1041it [02:00,  8.58it/s]

loss for batch 1040 --> 0.001001726952381432 at epoch 12


1081it [02:04,  8.23it/s]

loss for batch 1080 --> 0.0015362859703600407 at epoch 12


1121it [02:09,  8.47it/s]

loss for batch 1120 --> 0.003056662855669856 at epoch 12


1137it [02:11,  8.66it/s]
100%|██████████| 758/758 [00:12<00:00, 59.81it/s]


1969 out of 3030
acc of 64.98349834983497%


41it [00:04,  8.88it/s]

loss for batch 40 --> 0.0005944504518993199 at epoch 13


81it [00:09,  8.49it/s]

loss for batch 80 --> 0.0009004442836157978 at epoch 13


121it [00:13,  9.72it/s]

loss for batch 120 --> 0.0010634991340339184 at epoch 13


160it [00:18,  9.32it/s]

loss for batch 160 --> 0.0034562491346150637 at epoch 13


201it [00:22,  7.90it/s]

loss for batch 200 --> 0.000636727549135685 at epoch 13


241it [00:27,  8.53it/s]

loss for batch 240 --> 0.0011524184374138713 at epoch 13


282it [00:32,  9.16it/s]

loss for batch 280 --> 0.0007664105505682528 at epoch 13


321it [00:37,  7.81it/s]

loss for batch 320 --> 0.00588468462228775 at epoch 13


361it [00:41, 10.00it/s]

loss for batch 360 --> 0.0007960899383760989 at epoch 13


401it [00:46,  9.00it/s]

loss for batch 400 --> 0.0005118623375892639 at epoch 13


440it [00:51,  8.14it/s]

loss for batch 440 --> 0.004074893891811371 at epoch 13


481it [00:55,  8.93it/s]

loss for batch 480 --> 0.0006770138861611485 at epoch 13


521it [01:00,  8.78it/s]

loss for batch 520 --> 0.0016373927937820554 at epoch 13


561it [01:05,  9.14it/s]

loss for batch 560 --> 0.0008685848442837596 at epoch 13


601it [01:09,  8.29it/s]

loss for batch 600 --> 0.0005784601671621203 at epoch 13


641it [01:14,  8.82it/s]

loss for batch 640 --> 0.0007028655963949859 at epoch 13


681it [01:18,  9.22it/s]

loss for batch 680 --> 0.003628558712080121 at epoch 13


721it [01:23,  9.69it/s]

loss for batch 720 --> 0.005012712441384792 at epoch 13


761it [01:28,  9.37it/s]

loss for batch 760 --> 0.004785130266100168 at epoch 13


801it [01:32,  8.87it/s]

loss for batch 800 --> 0.0006703356048092246 at epoch 13


841it [01:37,  8.58it/s]

loss for batch 840 --> 0.3280859589576721 at epoch 13


881it [01:41,  8.38it/s]

loss for batch 880 --> 0.00205796817317605 at epoch 13


921it [01:46,  8.98it/s]

loss for batch 920 --> 0.15377148985862732 at epoch 13


961it [01:50,  8.60it/s]

loss for batch 960 --> 0.0005697878659702837 at epoch 13


1001it [01:55,  8.94it/s]

loss for batch 1000 --> 0.0007845696527510881 at epoch 13


1041it [02:00,  8.52it/s]

loss for batch 1040 --> 0.008302981033921242 at epoch 13


1081it [02:04,  8.69it/s]

loss for batch 1080 --> 0.0006074495031498373 at epoch 13


1121it [02:09,  9.10it/s]

loss for batch 1120 --> 0.0011312421411275864 at epoch 13


1137it [02:11,  8.66it/s]
100%|██████████| 758/758 [00:12<00:00, 60.03it/s]


1983 out of 3030
acc of 65.44554455445545%


41it [00:04,  8.44it/s]

loss for batch 40 --> 0.002122347243130207 at epoch 14


81it [00:09,  8.75it/s]

loss for batch 80 --> 0.0035544733982533216 at epoch 14


121it [00:14,  9.04it/s]

loss for batch 120 --> 0.0007669109618291259 at epoch 14


161it [00:18,  9.02it/s]

loss for batch 160 --> 0.004710571374744177 at epoch 14


200it [00:23,  8.63it/s]

loss for batch 200 --> 0.03138168156147003 at epoch 14


241it [00:27,  8.61it/s]

loss for batch 240 --> 0.005907194223254919 at epoch 14


281it [00:32,  8.04it/s]

loss for batch 280 --> 0.0011451481841504574 at epoch 14


321it [00:37,  9.13it/s]

loss for batch 320 --> 0.0006171062705107033 at epoch 14


361it [00:41,  8.35it/s]

loss for batch 360 --> 0.0009864148451015353 at epoch 14


401it [00:46,  9.36it/s]

loss for batch 400 --> 0.006032390985637903 at epoch 14


441it [00:50,  8.53it/s]

loss for batch 440 --> 0.0021585545036941767 at epoch 14


481it [00:55,  8.63it/s]

loss for batch 480 --> 0.00048663365305401385 at epoch 14


521it [00:59,  9.30it/s]

loss for batch 520 --> 0.004663440864533186 at epoch 14


560it [01:04,  9.23it/s]

loss for batch 560 --> 0.0005686540389433503 at epoch 14


601it [01:08,  9.28it/s]

loss for batch 600 --> 0.003998907282948494 at epoch 14


641it [01:13,  8.62it/s]

loss for batch 640 --> 0.0009406493045389652 at epoch 14


681it [01:18,  9.44it/s]

loss for batch 680 --> 0.0008229318191297352 at epoch 14


721it [01:23,  8.95it/s]

loss for batch 720 --> 0.0016301716677844524 at epoch 14


761it [01:27,  8.64it/s]

loss for batch 760 --> 0.010673065669834614 at epoch 14


801it [01:32,  9.03it/s]

loss for batch 800 --> 0.14759865403175354 at epoch 14


841it [01:36,  7.84it/s]

loss for batch 840 --> 0.0005109073244966567 at epoch 14


880it [01:41,  7.68it/s]

loss for batch 880 --> 0.004193055909126997 at epoch 14


921it [01:46,  8.66it/s]

loss for batch 920 --> 0.000847428513225168 at epoch 14


961it [01:50,  8.12it/s]

loss for batch 960 --> 0.0015564118511974812 at epoch 14


1001it [01:55,  8.91it/s]

loss for batch 1000 --> 0.0024213665165007114 at epoch 14


1041it [02:00,  7.83it/s]

loss for batch 1040 --> 0.00070631701964885 at epoch 14


1081it [02:04,  9.25it/s]

loss for batch 1080 --> 0.0011272726114839315 at epoch 14


1121it [02:09,  8.54it/s]

loss for batch 1120 --> 0.0005452457116916776 at epoch 14


1137it [02:11,  8.66it/s]
100%|██████████| 758/758 [00:12<00:00, 60.06it/s]


1968 out of 3030
acc of 64.95049504950495%


41it [00:04,  7.76it/s]

loss for batch 40 --> 0.0006883195601403713 at epoch 15


81it [00:09,  9.00it/s]

loss for batch 80 --> 0.005819388665258884 at epoch 15


121it [00:14,  8.31it/s]

loss for batch 120 --> 0.0008114312076941133 at epoch 15


161it [00:18,  9.35it/s]

loss for batch 160 --> 0.0006230713333934546 at epoch 15


201it [00:23,  8.65it/s]

loss for batch 200 --> 0.01404714398086071 at epoch 15


240it [00:27,  9.36it/s]

loss for batch 240 --> 0.0006814587395638227 at epoch 15


281it [00:32,  8.68it/s]

loss for batch 280 --> 0.0007058102637529373 at epoch 15


321it [00:36,  8.88it/s]

loss for batch 320 --> 0.0007705960888415575 at epoch 15


360it [00:41,  8.66it/s]

loss for batch 360 --> 0.0005627887439914048 at epoch 15


401it [00:46,  8.88it/s]

loss for batch 400 --> 0.0005102521390654147 at epoch 15


441it [00:50,  8.81it/s]

loss for batch 440 --> 0.0007559289806522429 at epoch 15


481it [00:55,  8.83it/s]

loss for batch 480 --> 0.014674815349280834 at epoch 15


520it [00:59,  9.10it/s]

loss for batch 520 --> 0.004618144128471613 at epoch 15


561it [01:04,  8.94it/s]

loss for batch 560 --> 0.0021546361967921257 at epoch 15


601it [01:09,  7.73it/s]

loss for batch 600 --> 0.006688895635306835 at epoch 15


640it [01:13,  8.10it/s]

loss for batch 640 --> 0.002052064286544919 at epoch 15


681it [01:18,  9.00it/s]

loss for batch 680 --> 0.01530569326132536 at epoch 15


721it [01:23,  8.38it/s]

loss for batch 720 --> 0.0010266575263813138 at epoch 15


761it [01:27,  9.38it/s]

loss for batch 760 --> 0.008815539069473743 at epoch 15


801it [01:32,  9.51it/s]

loss for batch 800 --> 0.0013892677379772067 at epoch 15


841it [01:37,  8.66it/s]

loss for batch 840 --> 0.0010545996483415365 at epoch 15


881it [01:41,  8.92it/s]

loss for batch 880 --> 0.0007982169045135379 at epoch 15


921it [01:46,  8.90it/s]

loss for batch 920 --> 0.0008782536606304348 at epoch 15


961it [01:51,  9.13it/s]

loss for batch 960 --> 0.002332629868760705 at epoch 15


1001it [01:55,  8.41it/s]

loss for batch 1000 --> 0.000645731168333441 at epoch 15


1041it [02:00,  8.58it/s]

loss for batch 1040 --> 0.00047220976557582617 at epoch 15


1081it [02:04,  9.04it/s]

loss for batch 1080 --> 0.0005720184999518096 at epoch 15


1121it [02:09,  8.00it/s]

loss for batch 1120 --> 0.0015178258763626218 at epoch 15


1137it [02:11,  8.67it/s]
100%|██████████| 758/758 [00:12<00:00, 59.95it/s]


1979 out of 3030
acc of 65.31353135313532%


41it [00:04,  8.70it/s]

loss for batch 40 --> 0.0008139853016473353 at epoch 16


81it [00:09,  8.55it/s]

loss for batch 80 --> 0.001344175892882049 at epoch 16


121it [00:13,  9.19it/s]

loss for batch 120 --> 0.0006958428421057761 at epoch 16


161it [00:18,  8.82it/s]

loss for batch 160 --> 0.0006954226410016418 at epoch 16


201it [00:22,  8.09it/s]

loss for batch 200 --> 0.0005083263968117535 at epoch 16


241it [00:27,  8.56it/s]

loss for batch 240 --> 0.0010832021944224834 at epoch 16


281it [00:32,  8.46it/s]

loss for batch 280 --> 0.0006267158314585686 at epoch 16


321it [00:36,  8.63it/s]

loss for batch 320 --> 0.0039346469566226006 at epoch 16


361it [00:41,  8.31it/s]

loss for batch 360 --> 0.0014742017956450582 at epoch 16


401it [00:46,  8.96it/s]

loss for batch 400 --> 0.000584860157687217 at epoch 16


441it [00:50,  8.37it/s]

loss for batch 440 --> 0.0004268467309884727 at epoch 16


481it [00:55,  8.50it/s]

loss for batch 480 --> 0.0006998738972470164 at epoch 16


521it [01:00,  7.99it/s]

loss for batch 520 --> 0.004465583246201277 at epoch 16


560it [01:04,  8.85it/s]

loss for batch 560 --> 0.0007904688245616853 at epoch 16


602it [01:09,  9.31it/s]

loss for batch 600 --> 0.000518566055689007 at epoch 16


641it [01:14,  8.71it/s]

loss for batch 640 --> 0.00042699958430603147 at epoch 16


681it [01:18,  8.61it/s]

loss for batch 680 --> 0.0005002882098779082 at epoch 16


721it [01:23,  8.78it/s]

loss for batch 720 --> 0.0005617223214358091 at epoch 16


761it [01:27,  8.28it/s]

loss for batch 760 --> 0.0006254873005673289 at epoch 16


801it [01:32,  8.28it/s]

loss for batch 800 --> 0.0020098176319152117 at epoch 16


842it [01:37,  9.24it/s]

loss for batch 840 --> 0.0006935563287697732 at epoch 16


881it [01:41,  9.19it/s]

loss for batch 880 --> 0.001620511175133288 at epoch 16


921it [01:46,  8.76it/s]

loss for batch 920 --> 0.000529325392562896 at epoch 16


961it [01:51,  8.50it/s]

loss for batch 960 --> 0.0004397868469823152 at epoch 16


1000it [01:55,  8.18it/s]

loss for batch 1000 --> 0.0053585902787745 at epoch 16


1041it [02:00,  8.90it/s]

loss for batch 1040 --> 0.00043160459608770907 at epoch 16


1081it [02:05,  8.58it/s]

loss for batch 1080 --> 0.0005016910727135837 at epoch 16


1121it [02:10,  8.25it/s]

loss for batch 1120 --> 0.001249013701453805 at epoch 16


1137it [02:12,  8.59it/s]
100%|██████████| 758/758 [00:12<00:00, 58.58it/s]


1982 out of 3030
acc of 65.41254125412541%


41it [00:04,  8.64it/s]

loss for batch 40 --> 0.0014526094309985638 at epoch 17


81it [00:09,  9.28it/s]

loss for batch 80 --> 0.0007847875822335482 at epoch 17


121it [00:14,  8.34it/s]

loss for batch 120 --> 0.010465865954756737 at epoch 17


161it [00:18,  8.11it/s]

loss for batch 160 --> 0.0003832585643976927 at epoch 17


201it [00:23,  8.74it/s]

loss for batch 200 --> 0.0004204546567052603 at epoch 17


241it [00:27,  8.25it/s]

loss for batch 240 --> 0.0003902368771377951 at epoch 17


281it [00:32,  8.71it/s]

loss for batch 280 --> 0.0009415054810233414 at epoch 17


321it [00:37,  8.47it/s]

loss for batch 320 --> 0.0006061600288376212 at epoch 17


361it [00:42,  8.64it/s]

loss for batch 360 --> 0.0004923664382658899 at epoch 17


401it [00:46,  7.74it/s]

loss for batch 400 --> 0.0011144361924380064 at epoch 17


441it [00:51,  8.62it/s]

loss for batch 440 --> 0.00048003956908360124 at epoch 17


481it [00:56,  8.89it/s]

loss for batch 480 --> 0.0010636920342221856 at epoch 17


521it [01:00,  9.39it/s]

loss for batch 520 --> 0.0005187589558772743 at epoch 17


561it [01:05,  9.18it/s]

loss for batch 560 --> 0.0014754353323951364 at epoch 17


601it [01:09,  8.59it/s]

loss for batch 600 --> 0.00041058528586290777 at epoch 17


641it [01:14,  7.96it/s]

loss for batch 640 --> 0.9491755366325378 at epoch 17


681it [01:19,  9.19it/s]

loss for batch 680 --> 0.0007334384135901928 at epoch 17


721it [01:23,  9.16it/s]

loss for batch 720 --> 0.0005100970156490803 at epoch 17


761it [01:28,  8.04it/s]

loss for batch 760 --> 0.0005867126164957881 at epoch 17


801it [01:33,  8.12it/s]

loss for batch 800 --> 0.0356883779168129 at epoch 17


841it [01:38,  8.90it/s]

loss for batch 840 --> 0.00044986954890191555 at epoch 17


881it [01:43,  8.92it/s]

loss for batch 880 --> 0.0005473238416016102 at epoch 17


921it [01:47,  8.48it/s]

loss for batch 920 --> 0.00043075025314465165 at epoch 17


961it [01:52,  8.02it/s]

loss for batch 960 --> 0.0016799097647890449 at epoch 17


1002it [01:56,  9.58it/s]

loss for batch 1000 --> 0.004231997299939394 at epoch 17


1041it [02:01,  9.19it/s]

loss for batch 1040 --> 0.0007840505568310618 at epoch 17


1081it [02:06,  8.86it/s]

loss for batch 1080 --> 0.00047501479275524616 at epoch 17


1122it [02:10,  9.39it/s]

loss for batch 1120 --> 0.32074037194252014 at epoch 17


1137it [02:12,  8.58it/s]
100%|██████████| 758/758 [00:12<00:00, 59.18it/s]


1983 out of 3030
acc of 65.44554455445545%


41it [00:04,  8.51it/s]

loss for batch 40 --> 0.0005071123596280813 at epoch 18


81it [00:09,  8.83it/s]

loss for batch 80 --> 0.0008429634617641568 at epoch 18


121it [00:13,  8.61it/s]

loss for batch 120 --> 0.0012710009468719363 at epoch 18


161it [00:18,  8.03it/s]

loss for batch 160 --> 0.00041395772132091224 at epoch 18


201it [00:23,  8.35it/s]

loss for batch 200 --> 0.00045372964814305305 at epoch 18


242it [00:27,  9.89it/s]

loss for batch 240 --> 0.0007410310790874064 at epoch 18


280it [00:32,  8.73it/s]

loss for batch 280 --> 0.0005683918134309351 at epoch 18


321it [00:36,  8.87it/s]

loss for batch 320 --> 0.01331379171460867 at epoch 18


362it [00:41,  9.54it/s]

loss for batch 360 --> 0.004905040841549635 at epoch 18


401it [00:45,  9.28it/s]

loss for batch 400 --> 0.0008086961461231112 at epoch 18


441it [00:50,  8.19it/s]

loss for batch 440 --> 0.00035376480082049966 at epoch 18


481it [00:55,  8.19it/s]

loss for batch 480 --> 0.000678027980029583 at epoch 18


521it [01:00,  9.02it/s]

loss for batch 520 --> 0.00035380246117711067 at epoch 18


562it [01:05,  9.66it/s]

loss for batch 560 --> 0.00036114995600655675 at epoch 18


601it [01:09,  7.86it/s]

loss for batch 600 --> 0.0002774633758235723 at epoch 18


641it [01:14,  9.65it/s]

loss for batch 640 --> 0.014878695830702782 at epoch 18


680it [01:18,  8.57it/s]

loss for batch 680 --> 0.00039446610026061535 at epoch 18


721it [01:23,  8.21it/s]

loss for batch 720 --> 0.00041687997872941196 at epoch 18


760it [01:28,  8.55it/s]

loss for batch 760 --> 0.000769724661950022 at epoch 18


801it [01:32,  9.03it/s]

loss for batch 800 --> 0.0006128587410785258 at epoch 18


840it [01:37,  8.48it/s]

loss for batch 840 --> 0.0005758365150541067 at epoch 18


881it [01:42,  8.47it/s]

loss for batch 880 --> 0.0044779097661376 at epoch 18


921it [01:46,  9.02it/s]

loss for batch 920 --> 0.0004793405532836914 at epoch 18


961it [01:51,  7.71it/s]

loss for batch 960 --> 0.0005883215926587582 at epoch 18


1001it [01:56,  8.38it/s]

loss for batch 1000 --> 0.0004199242393951863 at epoch 18


1041it [02:00,  9.01it/s]

loss for batch 1040 --> 0.00037043605698272586 at epoch 18


1081it [02:05,  9.35it/s]

loss for batch 1080 --> 0.0010307373013347387 at epoch 18


1121it [02:09,  8.91it/s]

loss for batch 1120 --> 0.00042578496504575014 at epoch 18


1137it [02:11,  8.63it/s]
100%|██████████| 758/758 [00:12<00:00, 59.71it/s]


1976 out of 3030
acc of 65.21452145214522%


41it [00:04,  8.89it/s]

loss for batch 40 --> 0.0007126162527129054 at epoch 19


80it [00:09,  8.98it/s]

loss for batch 80 --> 0.0004442887147888541 at epoch 19


121it [00:13,  8.88it/s]

loss for batch 120 --> 0.00041662342846393585 at epoch 19


161it [00:18,  7.53it/s]

loss for batch 160 --> 0.0005033697234466672 at epoch 19


201it [00:22,  8.31it/s]

loss for batch 200 --> 0.0007037183386273682 at epoch 19


241it [00:27,  9.07it/s]

loss for batch 240 --> 0.0050188577733933926 at epoch 19


281it [00:32,  8.92it/s]

loss for batch 280 --> 0.0022363776806741953 at epoch 19


321it [00:36,  7.59it/s]

loss for batch 320 --> 0.00048230797983706 at epoch 19


361it [00:41,  9.19it/s]

loss for batch 360 --> 0.001133037731051445 at epoch 19


402it [00:46,  9.95it/s]

loss for batch 400 --> 0.0004296000406611711 at epoch 19


441it [00:50,  8.24it/s]

loss for batch 440 --> 0.0004963457467965782 at epoch 19


480it [00:55,  7.81it/s]

loss for batch 480 --> 0.0014776620082557201 at epoch 19


521it [01:00,  9.16it/s]

loss for batch 520 --> 0.0004153990012127906 at epoch 19


560it [01:04,  9.09it/s]

loss for batch 560 --> 0.00036690867273136973 at epoch 19


601it [01:09,  8.48it/s]

loss for batch 600 --> 0.003661929862573743 at epoch 19


641it [01:13,  8.55it/s]

loss for batch 640 --> 0.0003208428679499775 at epoch 19


681it [01:18,  9.40it/s]

loss for batch 680 --> 0.0006113083218224347 at epoch 19


721it [01:23,  8.85it/s]

loss for batch 720 --> 0.0006850688951089978 at epoch 19


761it [01:28,  9.19it/s]

loss for batch 760 --> 0.0004485441604629159 at epoch 19


801it [01:32,  8.38it/s]

loss for batch 800 --> 0.00033103616442531347 at epoch 19


841it [01:37,  8.59it/s]

loss for batch 840 --> 0.00035141262924298644 at epoch 19


881it [01:42,  9.08it/s]

loss for batch 880 --> 0.0008225842611864209 at epoch 19


921it [01:46,  9.70it/s]

loss for batch 920 --> 0.002279805252328515 at epoch 19


961it [01:51,  8.91it/s]

loss for batch 960 --> 0.001203439780510962 at epoch 19


1001it [01:56,  9.04it/s]

loss for batch 1000 --> 0.0016484435182064772 at epoch 19


1041it [02:00,  8.33it/s]

loss for batch 1040 --> 0.00037241994868963957 at epoch 19


1080it [02:05,  8.36it/s]

loss for batch 1080 --> 0.00044343434274196625 at epoch 19


1121it [02:09,  9.03it/s]

loss for batch 1120 --> 0.00077403913019225 at epoch 19


1137it [02:11,  8.64it/s]


In [20]:
submission_csv = pd.read_csv("/kaggle/input/contradictory-my-dear-watson/test.csv")
submission_csv.drop(columns=["lang_abv", "language"])

Unnamed: 0,id,premise,hypothesis
0,c6d58c3f69,بکس، کیسی، راہیل، یسعیاہ، کیلی، کیلی، اور کولم...,"کیسی کے لئے کوئی یادگار نہیں ہوگا, کولمین ہائی..."
1,cefcc82292,هذا هو ما تم نصحنا به.,عندما يتم إخبارهم بما يجب عليهم فعله ، فشلت ال...
2,e98005252c,et cela est en grande partie dû au fait que le...,Les mères se droguent.
3,58518c10ba,与城市及其他公民及社区组织代表就IMA的艺术发展进行对话&amp,IMA与其他组织合作，因为它们都依靠共享资金。
4,c32b0d16df,Она все еще была там.,"Мы думали, что она ушла, однако, она осталась."
...,...,...,...
5190,5f90dd59b0,نیند نے وعدہ کیا کہ موٹل نے سوال میں تحقیق کی.,نیمیتھ کو موٹل کی تفتیش کے لئے معاوضہ دیا جارہ...
5191,f357a04e86,The rock has a soft texture and can be bough...,The rock is harder than most types of rock.
5192,1f0ea92118,她目前的存在，并考虑到他与沃佛斯顿争执的本质，那是尴尬的。,她在与Wolverstone的打斗结束后才在场的事实被看作是很尴尬的。
5193,0407b48afb,isn't it i can remember i've only been here ei...,I could see downtown Dallas from where I lived...


In [21]:
submission_dataset = CustomDatasetSubmission(df=submission_csv.drop(columns=["lang_abv", "language"]), max_length=max_length, tokenizer=tokenizer)

In [22]:
submission_dataloder = DataLoader(dataset=submission_dataset, batch_size=32, shuffle=False, generator=torch.Generator(device=device), collate_fn=submission_dataset.collate_fn)

In [23]:
final_output = {
    "id": [],
    "prediction": []
}

model.eval()
with torch.inference_mode():
    for batch in tqdm(submission_dataloder):
        logits = model(batch["input_ids"], attention_mask=batch["attention_mask"], token_type_ids=batch["token_type_ids"])
        percents = torch.softmax(logits, dim=1)
        preds = torch.argmax(percents, dim=1)
        for idx, prediction in zip(batch["ids"], preds.view(-1).tolist()):
            final_output["id"].append(idx)
            final_output["prediction"].append(prediction)


100%|██████████| 163/163 [00:20<00:00,  7.95it/s]


In [24]:
final_output_df = pd.DataFrame(final_output)
final_output_df
final_output_df.to_csv("./submission.csv")