In [None]:
# Mount Google Drive; the training checkpoints below are written under /content/drive.
import google.colab.drive as drive
drive.mount('/content/drive')

Mounted at /content/drive


# Import requirements

In [None]:
!pip install transformers

Collecting transformers
  Downloading transformers-4.17.0-py3-none-any.whl (3.8 MB)
[K     |████████████████████████████████| 3.8 MB 5.3 MB/s 
[?25hCollecting sacremoses
  Downloading sacremoses-0.0.49-py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 47.5 MB/s 
Collecting tokenizers!=0.11.3,>=0.11.1
  Downloading tokenizers-0.11.6-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.5 MB)
[K     |████████████████████████████████| 6.5 MB 55.2 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.4.0-py3-none-any.whl (67 kB)
[K     |████████████████████████████████| 67 kB 5.9 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 55.1 MB/s 
Installing collected packages: pyyaml, tokenizers, sacremoses, huggingface-hub, transformers
  Attempting uninstall: pyyaml
    Found ex

In [None]:
import os
import pdb
import argparse
from dataclasses import dataclass, field
from typing import Optional
from collections import defaultdict

import torch
from torch.nn.utils.rnn import pad_sequence

import numpy as np
from tqdm import tqdm, trange

from transformers import (
    BertForSequenceClassification,
    BertTokenizer,
    AutoConfig,
    AdamW
)

# 1. Preprocess

In [None]:
def make_id_file(task, tokenizer, data_dir=''):
    """Tokenize the sentiment splits into space-separated token-id strings.

    Reads ``sentiment.train.1``, ``sentiment.train.0``, ``sentiment.dev.1`` and
    ``sentiment.dev.0`` (one sentence per line). Each line is lower-cased, then
    encoded with ``tokenizer.encode``, and the resulting ids are joined with
    single spaces.

    Args:
        task: dataset name; currently unused (kept for call compatibility).
        tokenizer: object with an ``encode(str) -> list[int]`` method,
            e.g. the ``BertTokenizer`` loaded in this notebook.
        data_dir: directory holding the ``sentiment.*`` files. The default
            ``''`` resolves to the current working directory, as before.

    Returns:
        Tuple ``(train_pos, train_neg, dev_pos, dev_neg)``. Each is a list of
        strings such as ``'101 6581 2833 1012 102'``, one per input line.
    """
    def make_data_strings(file_name):
        path = os.path.join(data_dir, file_name)
        with open(path, 'r', encoding='utf-8') as f:
            # Stream the file rather than materialising readlines(). Lines still
            # carry their trailing newline, exactly as before.
            return [' '.join(str(token_id) for token_id in tokenizer.encode(line.lower()))
                    for line in f]

    print('it will take some times...')
    train_pos = make_data_strings('sentiment.train.1')
    train_neg = make_data_strings('sentiment.train.0')
    dev_pos = make_data_strings('sentiment.dev.1')
    dev_neg = make_data_strings('sentiment.dev.0')

    print('make id file finished!')
    return train_pos, train_neg, dev_pos, dev_neg

In [None]:
# WordPiece tokenizer matching the bert-large-uncased checkpoint used by Trainer.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path='bert-large-uncased')

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/571 [00:00<?, ?B/s]

In [None]:
# Upload the sentiment.* splits (and test_no_label.csv) into the working directory.
import google.colab.files as files
uploaded = files.upload()

Saving sentiment.dev.0 to sentiment.dev.0
Saving sentiment.dev.1 to sentiment.dev.1
Saving sentiment.train.0 to sentiment.train.0
Saving sentiment.train.1 to sentiment.train.1
Saving test_no_label.csv to test_no_label.csv


In [None]:
# List the working directory to confirm the uploaded sentiment.* files are present.
!ls

drive	     sentiment.dev.0  sentiment.train.0  test_no_label.csv
sample_data  sentiment.dev.1  sentiment.train.1


In [None]:
# Encode every split into space-separated token-id strings.
train_pos, train_neg, dev_pos, dev_neg = make_id_file(task='yelp', tokenizer=tokenizer)

it will take some times...
make id file finished!


In [None]:
# Peek at the first ten encoded positive training sentences.
train_pos[0:10]

['101 6581 2833 1012 102',
 '101 21688 8013 2326 1012 102',
 '101 2027 2036 2031 3679 19247 1998 3256 6949 2029 2003 2428 2204 1012 102',
 '101 2009 1005 1055 1037 2204 15174 2098 7570 22974 2063 1012 102',
 '101 1996 3095 2003 5379 1012 102',
 '101 2204 3347 2833 1012 102',
 '101 2204 2326 1012 102',
 '101 11350 1997 2154 2003 25628 1998 7167 1997 19247 1012 102',
 '101 2307 2173 2005 6265 2030 3347 27962 1998 5404 1012 102',
 '101 1996 2047 2846 3504 6429 1012 102']

In [None]:
class SentimentDataset(object):
    """Map-style dataset of ``(token ids, label)`` pairs for binary sentiment.

    Args:
        tokenizer: kept as ``self.tokenizer``; not otherwise used here.
        pos: iterable of whitespace-separated token-id strings, labelled 1.
        neg: iterable of whitespace-separated token-id strings, labelled 0.

    All positive examples are stored first, followed by all negative ones.
    """

    def __init__(self, tokenizer, pos, neg):
        self.tokenizer = tokenizer
        self.data = []
        self.label = []

        for sentences, target in ((pos, 1), (neg, 0)):
            for sentence in sentences:
                token_strings = sentence.strip().split()
                self.data.append(self._cast_to_int(token_strings))
                self.label.append([target])

    def _cast_to_int(self, sample):
        """Convert a sequence of numeric strings into a list of ints."""
        return list(map(int, sample))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(np.array(token_ids), np.array([label]))`` for ``index``."""
        token_ids = np.array(self.data[index])
        target = np.array(self.label[index])
        return token_ids, target

In [None]:
# Wrap the encoded splits as labelled datasets (positives -> 1, negatives -> 0).
train_dataset = SentimentDataset(tokenizer=tokenizer, pos=train_pos, neg=train_neg)
dev_dataset = SentimentDataset(tokenizer=tokenizer, pos=dev_pos, neg=dev_neg)

In [None]:
# Print the first eleven (indices 0-10) training examples as (token_ids, label).
for i in range(min(11, len(train_dataset))):
    item = train_dataset[i]
    print(item)

(array([ 101, 6581, 2833, 1012,  102]), array([1]))
(array([  101, 21688,  8013,  2326,  1012,   102]), array([1]))
(array([  101,  2027,  2036,  2031,  3679, 19247,  1998,  3256,  6949,
        2029,  2003,  2428,  2204,  1012,   102]), array([1]))
(array([  101,  2009,  1005,  1055,  1037,  2204, 15174,  2098,  7570,
       22974,  2063,  1012,   102]), array([1]))
(array([ 101, 1996, 3095, 2003, 5379, 1012,  102]), array([1]))
(array([ 101, 2204, 3347, 2833, 1012,  102]), array([1]))
(array([ 101, 2204, 2326, 1012,  102]), array([1]))
(array([  101, 11350,  1997,  2154,  2003, 25628,  1998,  7167,  1997,
       19247,  1012,   102]), array([1]))
(array([  101,  2307,  2173,  2005,  6265,  2030,  3347, 27962,  1998,
        5404,  1012,   102]), array([1]))
(array([ 101, 1996, 2047, 2846, 3504, 6429, 1012,  102]), array([1]))
(array([ 101, 2023, 2173, 2001, 2200, 2204, 1012,  102]), array([1]))


In [None]:
def collate_fn_style(samples):
    """Collate ``(input_ids, label)`` pairs into padded BERT batch tensors.

    Samples are reordered longest-first, and every returned tensor follows that
    same order.

    Args:
        samples: sequence of ``(input_ids, label)`` pairs as returned by
            ``SentimentDataset.__getitem__``: a 1-D array of token ids and a
            one-element label array.

    Returns:
        ``(input_ids, attention_mask, token_type_ids, position_ids, labels)``.
        The first four are ``(batch, max_len)`` tensors. ``input_ids`` are
        right-padded with ``pad_sequence``'s default value 0. ``labels`` is
        ``(batch, 1)``.
    """
    input_ids, labels = zip(*samples)
    # Record the true (unpadded) lengths *before* padding. The previous version
    # rebound `input_ids` to the padded tensor and then measured its rows, so
    # every length was max_len and the attention mask was all ones, meaning
    # padding tokens were attended to.
    lengths = [len(ids) for ids in input_ids]
    max_len = max(lengths)
    sorted_indices = np.argsort(lengths)[::-1]

    padded_input_ids = pad_sequence(
        [torch.tensor(input_ids[index]) for index in sorted_indices],
        batch_first=True)
    attention_mask = torch.tensor(
        [[1] * lengths[index] + [0] * (max_len - lengths[index]) for index in sorted_indices])
    # Segment id 0 and positions 0..max_len-1 for every row (unchanged values).
    token_type_ids = torch.tensor([[0] * max_len for _ in sorted_indices])
    position_ids = torch.tensor([list(range(max_len)) for _ in sorted_indices])
    labels = torch.tensor(np.stack(labels, axis=0)[sorted_indices])

    return padded_input_ids, attention_mask, token_type_ids, position_ids, labels

In [None]:
# Reproducibility: seed the NumPy and PyTorch global RNGs with a fixed value.
random_seed = 42
torch.manual_seed(random_seed)
np.random.seed(random_seed)

# Prefer the GPU whenever PyTorch can see one.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def compute_acc(predictions, target_labels):
    """Return the fraction of positions where ``predictions`` equals ``target_labels``."""
    predicted = np.asarray(predictions)
    expected = np.asarray(target_labels)
    hits = predicted == expected
    return hits.mean()

In [None]:
class Trainer():
    """Fine-tune ``bert-large-uncased`` for two-class sentiment classification.

    During training the model is evaluated on ``dev_loader`` ``evals_per_epoch``
    times per epoch. A checkpoint is written whenever the mean validation loss
    reaches a new minimum.

    Args:
        device: ``torch.device`` that the model, optimizer state and batches
            are moved to.
        output_path: directory that checkpoints are written to (created if
            missing).
        lr: AdamW learning rate.
        resume_path: path of a checkpoint written by this class, or a falsy
            value to start from the pretrained weights.
    """

    # Number of validation passes per training epoch.
    evals_per_epoch = 5

    def __init__(self, device, output_path, lr, resume_path):
        self.output_path = output_path
        self.device = device
        self.model = BertForSequenceClassification.from_pretrained('bert-large-uncased')
        self.optimizer = AdamW(self.model.parameters(), lr=lr)
        if resume_path:
            checkpoint = torch.load(resume_path, map_location=device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            # NOTE(review): checkpoints are written mid-epoch, but training resumes
            # at the *next* epoch, so the remainder of the saved epoch is skipped.
            self.start_train_epoch = checkpoint['epoch'] + 1
            self.lowest_valid_loss = checkpoint['lowest_valid_loss']
        else:
            self.start_train_epoch = 0
            self.lowest_valid_loss = 9999.
        self.model.to(self.device)
        # Optimizer state restored from a checkpoint must follow the model's device.
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)

    def _forward(self, batch):
        """Move a ``collate_fn_style`` batch to ``self.device`` and run the model.

        Returns ``(model_output, labels)``, with labels cast to ``torch.long``.
        """
        input_ids, attention_mask, token_type_ids, position_ids, labels = batch
        labels = labels.to(self.device, dtype=torch.long)
        output = self.model(input_ids=input_ids.to(self.device),
                            attention_mask=attention_mask.to(self.device),
                            token_type_ids=token_type_ids.to(self.device),
                            position_ids=position_ids.to(self.device),
                            labels=labels)
        return output, labels

    def _evaluate(self, dev_loader):
        """Return ``(mean validation loss, accuracy)`` over ``dev_loader``.

        The model is put in eval mode for the pass and is always switched back
        to train mode afterwards.

        Raises:
            ValueError: if ``dev_loader`` yields no batches.
        """
        valid_losses = []
        predictions = []
        target_labels = []
        self.model.eval()
        try:
            with torch.no_grad():
                for batch in tqdm(dev_loader, desc='Eval', position=1, leave=None):
                    output, labels = self._forward(batch)
                    valid_losses.append(output.loss.item())
                    logits = output.logits
                    # Predict class 0 only when its logit is strictly larger;
                    # ties (and NaNs) go to class 1.
                    predictions += (~(logits[:, 0] > logits[:, 1])).long().tolist()
                    target_labels += labels.view(-1).tolist()
        finally:
            # Without this, dropout would stay disabled for the rest of training.
            self.model.train()

        if not valid_losses:
            raise ValueError('dev_loader yielded no batches')
        valid_loss = sum(valid_losses) / len(valid_losses)
        return valid_loss, compute_acc(predictions, target_labels)

    def _save_checkpoint(self, epoch, iteration):
        """Write model/optimizer state and progress counters under ``output_path``.

        Call this *after* updating ``self.lowest_valid_loss`` so that the stored
        value describes the saved weights.
        """
        if self.output_path:
            os.makedirs(self.output_path, exist_ok=True)
        checkpoint_file = os.path.join(self.output_path,
                                       f'checkpoint_epoch_{epoch}.{iteration}.pth')
        torch.save({
            'epoch': epoch,
            'lowest_valid_loss': self.lowest_valid_loss,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'model_state_dict': self.model.state_dict(),
        }, checkpoint_file)

    def training(self, train_loader, dev_loader, last_epoch):
        """Train from ``self.start_train_epoch`` up to (but excluding) ``last_epoch``.

        Every ``len(train_loader) // evals_per_epoch`` iterations (iteration 0
        excluded), the model is evaluated on ``dev_loader``. A checkpoint is
        saved whenever the validation loss improves.
        """
        self.model.train()
        # max(1, ...) avoids modulo-by-zero for loaders with fewer batches than
        # evals_per_epoch.
        eval_interval = max(1, len(train_loader) // self.evals_per_epoch)
        for epoch in range(self.start_train_epoch, last_epoch):
            with tqdm(train_loader, unit="batch") as tepoch:
                tepoch.set_description(f"Epoch {epoch}")
                for iteration, batch in enumerate(tepoch):
                    self.optimizer.zero_grad()
                    output, _ = self._forward(batch)
                    loss = output.loss
                    loss.backward()
                    self.optimizer.step()
                    tepoch.set_postfix(loss=loss.item())

                    if iteration != 0 and iteration % eval_interval == 0:
                        valid_loss, acc = self._evaluate(dev_loader)
                        if valid_loss < self.lowest_valid_loss:
                            print('')
                            print('Acc for model which have lower valid loss: ', acc)
                            # Update first so the checkpoint stores the loss of
                            # the weights it contains.
                            self.lowest_valid_loss = valid_loss
                            self._save_checkpoint(epoch, iteration)
                            print('--------------save checkpoint at epoch : {}--------------'.format(epoch))
                            print('--------------lowest_valid_loss : {}--------------'.format(self.lowest_valid_loss))

In [None]:
# DataLoader settings
train_batch_size = 32
eval_batch_size = 32

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    collate_fn=collate_fn_style,
    pin_memory=True,
    num_workers=2,
)
dev_loader = torch.utils.data.DataLoader(
    dev_dataset,
    batch_size=eval_batch_size,
    shuffle=False,
    collate_fn=collate_fn_style,
    num_workers=2,
)

# Path settings
output_path = '/content/drive/MyDrive/Goorm_Deep_Learning/Projects/project1/checkpoints/'
# To resume training, pass a checkpoint path instead of None, e.g.
# '/content/drive/MyDrive/Colab Notebooks/NLP/project/checkpoints/checkpoint_epoch_1.13850.pth'

lr = 5e-6  # 5e-5 was noted as an alternative
last_epoch = 5

trainer = Trainer(device, output_path, lr, resume_path=None)
trainer.training(train_loader, dev_loader, last_epoch)


Downloading:   0%|          | 0.00/1.25G [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint a

0

Acc for model which have lower valid loss:  0.973


Epoch 0:  20%|██        | 2771/13852 [12:56<25:00:34,  8.13s/batch, loss=0.0399]

--------------save checkpoint at epoch : 0--------------
--------------lowest_valid_loss : 0.07874967139400542--------------


Epoch 0:  40%|███▉      | 5540/13852 [25:19<38:29,  3.60batch/s, loss=0.00484]
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:23,  5.26it/s][A
Eval:   2%|▏         | 3/125 [00:00<00:12,  9.54it/s][A
Eval:   4%|▍         | 5/125 [00:00<00:10, 11.64it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:09, 12.33it/s][A
Eval:   7%|▋         | 9/125 [00:00<00:09, 12.82it/s][A
Eval:   9%|▉         | 11/125 [00:00<00:08, 13.05it/s][A
Eval:  10%|█         | 13/125 [00:01<00:08, 13.36it/s][A
Eval:  12%|█▏        | 15/125 [00:01<00:08, 13.45it/s][A
Eval:  14%|█▎        | 17/125 [00:01<00:07, 13.60it/s][A
Eval:  15%|█▌        | 19/125 [00:01<00:08, 13.21it/s][A
Eval:  17%|█▋        | 21/125 [00:01<00:07, 13.31it/s][A
Eval:  18%|█▊        | 23/125 [00:01<00:07, 13.44it/s][A
Eval:  20%|██        | 25/125 [00:01<00:07, 13.53it/s][A
Eval:  22%|██▏       | 27/125 [00:02<00:07, 13.66it/s][A
Eval:  23%|██▎       | 29/125 [00:02<00:06, 13.78it/s][A
Eval:  

0

Acc for model which have lower valid loss:  0.977


Epoch 0:  40%|████      | 5541/13852 [25:44<17:59:47,  7.80s/batch, loss=0.00484]

--------------save checkpoint at epoch : 0--------------
--------------lowest_valid_loss : 0.06581158636137843--------------


Epoch 0:  60%|█████▉    | 8310/13852 [38:08<24:09,  3.82batch/s, loss=0.00594]
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:23,  5.21it/s][A
Eval:   2%|▏         | 3/125 [00:00<00:12,  9.66it/s][A
Eval:   4%|▍         | 5/125 [00:00<00:10, 11.56it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:09, 12.43it/s][A
Eval:   7%|▋         | 9/125 [00:00<00:08, 12.91it/s][A
Eval:   9%|▉         | 11/125 [00:00<00:08, 13.12it/s][A
Eval:  10%|█         | 13/125 [00:01<00:08, 13.46it/s][A
Eval:  12%|█▏        | 15/125 [00:01<00:08, 13.57it/s][A
Eval:  14%|█▎        | 17/125 [00:01<00:07, 13.73it/s][A
Eval:  15%|█▌        | 19/125 [00:01<00:07, 13.33it/s][A
Eval:  17%|█▋        | 21/125 [00:01<00:07, 13.39it/s][A
Eval:  18%|█▊        | 23/125 [00:01<00:07, 13.51it/s][A
Eval:  20%|██        | 25/125 [00:01<00:07, 13.57it/s][A
Eval:  22%|██▏       | 27/125 [00:02<00:07, 13.72it/s][A
Eval:  23%|██▎       | 29/125 [00:02<00:06, 13.85it/s][A
Eval:  

0


Epoch 0:  80%|███████▉  | 11080/13852 [50:41<12:24,  3.72batch/s, loss=0.00135]
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:23,  5.28it/s][A
Eval:   2%|▏         | 3/125 [00:00<00:12,  9.54it/s][A
Eval:   4%|▍         | 5/125 [00:00<00:10, 11.68it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:09, 12.52it/s][A
Eval:   7%|▋         | 9/125 [00:00<00:08, 12.93it/s][A
Eval:   9%|▉         | 11/125 [00:00<00:08, 13.12it/s][A
Eval:  10%|█         | 13/125 [00:01<00:08, 13.44it/s][A
Eval:  12%|█▏        | 15/125 [00:01<00:08, 13.47it/s][A
Eval:  14%|█▎        | 17/125 [00:01<00:07, 13.61it/s][A
Eval:  15%|█▌        | 19/125 [00:01<00:08, 13.21it/s][A
Eval:  17%|█▋        | 21/125 [00:01<00:07, 13.30it/s][A
Eval:  18%|█▊        | 23/125 [00:01<00:07, 13.42it/s][A
Eval:  20%|██        | 25/125 [00:01<00:07, 13.53it/s][A
Eval:  22%|██▏       | 27/125 [00:02<00:07, 13.68it/s][A
Eval:  23%|██▎       | 29/125 [00:02<00:06, 13.80it/s][A
Eval: 

0

Acc for model which have lower valid loss:  0.979


Epoch 0:  80%|███████▉  | 11081/13852 [51:06<6:02:16,  7.84s/batch, loss=0.00135]

--------------save checkpoint at epoch : 0--------------
--------------lowest_valid_loss : 0.05772567156329751--------------


Epoch 0: 100%|█████████▉| 13850/13852 [1:03:30<00:00,  3.59batch/s, loss=0.0047]
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:24,  5.09it/s][A
Eval:   2%|▏         | 3/125 [00:00<00:12,  9.56it/s][A
Eval:   4%|▍         | 5/125 [00:00<00:10, 11.47it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:09, 12.40it/s][A
Eval:   7%|▋         | 9/125 [00:00<00:08, 12.90it/s][A
Eval:   9%|▉         | 11/125 [00:00<00:08, 13.12it/s][A
Eval:  10%|█         | 13/125 [00:01<00:08, 13.42it/s][A
Eval:  12%|█▏        | 15/125 [00:01<00:08, 13.51it/s][A
Eval:  14%|█▎        | 17/125 [00:01<00:07, 13.65it/s][A
Eval:  15%|█▌        | 19/125 [00:01<00:07, 13.26it/s][A
Eval:  17%|█▋        | 21/125 [00:01<00:07, 13.35it/s][A
Eval:  18%|█▊        | 23/125 [00:01<00:07, 13.46it/s][A
Eval:  20%|██        | 25/125 [00:01<00:07, 13.50it/s][A
Eval:  22%|██▏       | 27/125 [00:02<00:07, 13.61it/s][A
Eval:  23%|██▎       | 29/125 [00:02<00:06, 13.75it/s][A
Eval:

0


Epoch 0: 100%|██████████| 13852/13852 [1:03:40<00:00,  3.63batch/s, loss=0.308]
Epoch 1:  20%|█▉        | 2770/13852 [12:23<50:05,  3.69batch/s, loss=0.0197]
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:23,  5.35it/s][A
Eval:   2%|▏         | 3/125 [00:00<00:12,  9.65it/s][A
Eval:   4%|▍         | 5/125 [00:00<00:10, 11.66it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:09, 12.44it/s][A
Eval:   7%|▋         | 9/125 [00:00<00:09, 12.87it/s][A
Eval:   9%|▉         | 11/125 [00:00<00:08, 13.10it/s][A
Eval:  10%|█         | 13/125 [00:01<00:08, 13.41it/s][A
Eval:  12%|█▏        | 15/125 [00:01<00:08, 13.47it/s][A
Eval:  14%|█▎        | 17/125 [00:01<00:07, 13.62it/s][A
Eval:  15%|█▌        | 19/125 [00:01<00:08, 13.23it/s][A
Eval:  17%|█▋        | 21/125 [00:01<00:07, 13.33it/s][A
Eval:  18%|█▊        | 23/125 [00:01<00:07, 13.44it/s][A
Eval:  20%|██        | 25/125 [00:01<00:07, 13.48it/s][A
Eval:  22%|██▏       | 27/125 [00:02<00:07, 

1

Acc for model which have lower valid loss:  0.979


Epoch 1:  20%|██        | 2771/13852 [12:52<27:27:55,  8.92s/batch, loss=0.0197]

--------------save checkpoint at epoch : 1--------------
--------------lowest_valid_loss : 0.05528783631511033--------------


Epoch 1:  40%|███▉      | 5540/13852 [25:17<38:29,  3.60batch/s, loss=0.00608]
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:23,  5.21it/s][A
Eval:   2%|▏         | 3/125 [00:00<00:12,  9.43it/s][A
Eval:   4%|▍         | 5/125 [00:00<00:10, 11.50it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:09, 12.38it/s][A
Eval:   7%|▋         | 9/125 [00:00<00:08, 12.89it/s][A
Eval:   9%|▉         | 11/125 [00:00<00:08, 13.10it/s][A
Eval:  10%|█         | 13/125 [00:01<00:08, 13.40it/s][A
Eval:  12%|█▏        | 15/125 [00:01<00:08, 13.51it/s][A
Eval:  14%|█▎        | 17/125 [00:01<00:07, 13.62it/s][A
Eval:  15%|█▌        | 19/125 [00:01<00:08, 13.20it/s][A
Eval:  17%|█▋        | 21/125 [00:01<00:07, 13.28it/s][A
Eval:  18%|█▊        | 23/125 [00:01<00:07, 13.35it/s][A
Eval:  20%|██        | 25/125 [00:01<00:07, 13.41it/s][A
Eval:  22%|██▏       | 27/125 [00:02<00:07, 13.55it/s][A
Eval:  23%|██▎       | 29/125 [00:02<00:07, 13.69it/s][A
Eval:  

1

Acc for model which have lower valid loss:  0.97975


Epoch 1:  40%|████      | 5541/13852 [25:44<19:39:30,  8.52s/batch, loss=0.00608]

--------------save checkpoint at epoch : 1--------------
--------------lowest_valid_loss : 0.05433912756713107--------------


Epoch 1:  60%|█████▉    | 8310/13852 [38:09<25:19,  3.65batch/s, loss=0.00573]
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:24,  5.01it/s][A
Eval:   2%|▏         | 3/125 [00:00<00:13,  9.34it/s][A
Eval:   4%|▍         | 5/125 [00:00<00:10, 11.47it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:09, 12.41it/s][A
Eval:   7%|▋         | 9/125 [00:00<00:08, 12.90it/s][A
Eval:   9%|▉         | 11/125 [00:00<00:08, 13.10it/s][A
Eval:  10%|█         | 13/125 [00:01<00:08, 13.41it/s][A
Eval:  12%|█▏        | 15/125 [00:01<00:08, 13.54it/s][A
Eval:  14%|█▎        | 17/125 [00:01<00:07, 13.68it/s][A
Eval:  15%|█▌        | 19/125 [00:01<00:08, 13.24it/s][A
Eval:  17%|█▋        | 21/125 [00:01<00:07, 13.34it/s][A
Eval:  18%|█▊        | 23/125 [00:01<00:07, 13.44it/s][A
Eval:  20%|██        | 25/125 [00:01<00:07, 13.47it/s][A
Eval:  22%|██▏       | 27/125 [00:02<00:07, 13.57it/s][A
Eval:  23%|██▎       | 29/125 [00:02<00:07, 13.71it/s][A
Eval:  

1

Acc for model which have lower valid loss:  0.983


Epoch 1:  60%|█████▉    | 8311/13852 [38:39<14:18:07,  9.29s/batch, loss=0.00573]

--------------save checkpoint at epoch : 1--------------
--------------lowest_valid_loss : 0.04990366371162236--------------


Epoch 1:  80%|███████▉  | 11080/13852 [51:04<12:41,  3.64batch/s, loss=0.0459]
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:24,  5.11it/s][A
Eval:   2%|▏         | 3/125 [00:00<00:13,  9.37it/s][A
Eval:   4%|▍         | 5/125 [00:00<00:10, 11.41it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:09, 12.34it/s][A
Eval:   7%|▋         | 9/125 [00:00<00:09, 12.81it/s][A
Eval:   9%|▉         | 11/125 [00:00<00:08, 13.06it/s][A
Eval:  10%|█         | 13/125 [00:01<00:08, 13.44it/s][A
Eval:  12%|█▏        | 15/125 [00:01<00:08, 13.54it/s][A
Eval:  14%|█▎        | 17/125 [00:01<00:07, 13.67it/s][A
Eval:  15%|█▌        | 19/125 [00:01<00:08, 13.21it/s][A
Eval:  17%|█▋        | 21/125 [00:01<00:07, 13.33it/s][A
Eval:  18%|█▊        | 23/125 [00:01<00:07, 13.44it/s][A
Eval:  20%|██        | 25/125 [00:01<00:07, 13.57it/s][A
Eval:  22%|██▏       | 27/125 [00:02<00:07, 13.70it/s][A
Eval:  23%|██▎       | 29/125 [00:02<00:06, 13.76it/s][A
Eval:  

1


Epoch 1: 100%|█████████▉| 13850/13852 [1:03:36<00:00,  3.81batch/s, loss=0.0115]
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:24,  5.11it/s][A
Eval:   2%|▏         | 3/125 [00:00<00:12,  9.58it/s][A
Eval:   4%|▍         | 5/125 [00:00<00:10, 11.51it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:09, 12.40it/s][A
Eval:   7%|▋         | 9/125 [00:00<00:09, 12.82it/s][A
Eval:   9%|▉         | 11/125 [00:00<00:08, 12.98it/s][A
Eval:  10%|█         | 13/125 [00:01<00:08, 13.33it/s][A
Eval:  12%|█▏        | 15/125 [00:01<00:08, 13.41it/s][A
Eval:  14%|█▎        | 17/125 [00:01<00:07, 13.58it/s][A
Eval:  15%|█▌        | 19/125 [00:01<00:08, 13.20it/s][A
Eval:  17%|█▋        | 21/125 [00:01<00:07, 13.29it/s][A
Eval:  18%|█▊        | 23/125 [00:01<00:07, 13.42it/s][A
Eval:  20%|██        | 25/125 [00:01<00:07, 13.51it/s][A
Eval:  22%|██▏       | 27/125 [00:02<00:07, 13.60it/s][A
Eval:  23%|██▎       | 29/125 [00:02<00:06, 13.77it/s][A
Eval:

1


Epoch 1: 100%|██████████| 13852/13852 [1:03:46<00:00,  3.62batch/s, loss=0.00154]
Epoch 2:  20%|█▉        | 2770/13852 [12:25<51:13,  3.61batch/s, loss=0.000903]
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:25,  4.91it/s][A
Eval:   2%|▏         | 3/125 [00:00<00:13,  9.27it/s][A
Eval:   4%|▍         | 5/125 [00:00<00:10, 11.43it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:09, 12.35it/s][A
Eval:   7%|▋         | 9/125 [00:00<00:09, 12.85it/s][A
Eval:   9%|▉         | 11/125 [00:00<00:08, 13.09it/s][A
Eval:  10%|█         | 13/125 [00:01<00:08, 13.38it/s][A
Eval:  12%|█▏        | 15/125 [00:01<00:08, 13.51it/s][A
Eval:  14%|█▎        | 17/125 [00:01<00:07, 13.66it/s][A
Eval:  15%|█▌        | 19/125 [00:01<00:07, 13.26it/s][A
Eval:  17%|█▋        | 21/125 [00:01<00:07, 13.36it/s][A
Eval:  18%|█▊        | 23/125 [00:01<00:07, 13.47it/s][A
Eval:  20%|██        | 25/125 [00:01<00:07, 13.55it/s][A
Eval:  22%|██▏       | 27/125 [00:02<00:

2


Epoch 2:  40%|███▉      | 5540/13852 [25:00<36:26,  3.80batch/s, loss=0.00598]
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:23,  5.24it/s][A
Eval:   2%|▏         | 3/125 [00:00<00:12,  9.56it/s][A
Eval:   4%|▍         | 5/125 [00:00<00:10, 11.49it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:09, 12.40it/s][A
Eval:   7%|▋         | 9/125 [00:00<00:09, 12.83it/s][A
Eval:   9%|▉         | 11/125 [00:00<00:08, 13.05it/s][A
Eval:  10%|█         | 13/125 [00:01<00:08, 13.40it/s][A
Eval:  12%|█▏        | 15/125 [00:01<00:08, 13.46it/s][A
Eval:  14%|█▎        | 17/125 [00:01<00:07, 13.57it/s][A
Eval:  15%|█▌        | 19/125 [00:01<00:08, 13.20it/s][A
Eval:  17%|█▋        | 21/125 [00:01<00:07, 13.31it/s][A
Eval:  18%|█▊        | 23/125 [00:01<00:07, 13.42it/s][A
Eval:  20%|██        | 25/125 [00:01<00:07, 13.51it/s][A
Eval:  22%|██▏       | 27/125 [00:02<00:07, 13.56it/s][A
Eval:  23%|██▎       | 29/125 [00:02<00:07, 13.69it/s][A
Eval:  

2


Epoch 2:  60%|█████▉    | 8310/13852 [37:33<24:56,  3.70batch/s, loss=0.0453] 
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:23,  5.21it/s][A
Eval:   2%|▏         | 3/125 [00:00<00:12,  9.50it/s][A
Eval:   4%|▍         | 5/125 [00:00<00:10, 11.62it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:09, 12.48it/s][A
Eval:   7%|▋         | 9/125 [00:00<00:09, 12.88it/s][A
Eval:   9%|▉         | 11/125 [00:00<00:08, 13.09it/s][A
Eval:  10%|█         | 13/125 [00:01<00:08, 13.37it/s][A
Eval:  12%|█▏        | 15/125 [00:01<00:08, 13.50it/s][A
Eval:  14%|█▎        | 17/125 [00:01<00:07, 13.63it/s][A
Eval:  15%|█▌        | 19/125 [00:01<00:08, 13.22it/s][A
Eval:  17%|█▋        | 21/125 [00:01<00:07, 13.30it/s][A
Eval:  18%|█▊        | 23/125 [00:01<00:07, 13.38it/s][A
Eval:  20%|██        | 25/125 [00:01<00:07, 13.49it/s][A
Eval:  22%|██▏       | 27/125 [00:02<00:07, 13.60it/s][A
Eval:  23%|██▎       | 29/125 [00:02<00:06, 13.76it/s][A
Eval:  

2


Epoch 2:  80%|███████▉  | 11080/13852 [50:08<12:19,  3.75batch/s, loss=0.046]   
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:24,  5.14it/s][A
Eval:   2%|▏         | 3/125 [00:00<00:12,  9.62it/s][A
Eval:   4%|▍         | 5/125 [00:00<00:10, 11.60it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:09, 12.48it/s][A
Eval:   7%|▋         | 9/125 [00:00<00:08, 12.94it/s][A
Eval:   9%|▉         | 11/125 [00:00<00:08, 13.14it/s][A
Eval:  10%|█         | 13/125 [00:01<00:08, 13.42it/s][A
Eval:  12%|█▏        | 15/125 [00:01<00:08, 13.50it/s][A
Eval:  14%|█▎        | 17/125 [00:01<00:07, 13.67it/s][A
Eval:  15%|█▌        | 19/125 [00:01<00:08, 13.24it/s][A
Eval:  17%|█▋        | 21/125 [00:01<00:07, 13.31it/s][A
Eval:  18%|█▊        | 23/125 [00:01<00:07, 13.41it/s][A
Eval:  20%|██        | 25/125 [00:01<00:07, 13.50it/s][A
Eval:  22%|██▏       | 27/125 [00:02<00:07, 13.60it/s][A
Eval:  23%|██▎       | 29/125 [00:02<00:06, 13.75it/s][A
Eval:

2


Epoch 2: 100%|█████████▉| 13850/13852 [1:02:43<00:00,  3.64batch/s, loss=0.00217]
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:24,  5.15it/s][A
Eval:   2%|▏         | 3/125 [00:00<00:12,  9.45it/s][A
Eval:   4%|▍         | 5/125 [00:00<00:10, 11.56it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:09, 12.39it/s][A
Eval:   7%|▋         | 9/125 [00:00<00:09, 12.85it/s][A
Eval:   9%|▉         | 11/125 [00:00<00:08, 12.97it/s][A
Eval:  10%|█         | 13/125 [00:01<00:08, 13.32it/s][A
Eval:  12%|█▏        | 15/125 [00:01<00:08, 13.43it/s][A
Eval:  14%|█▎        | 17/125 [00:01<00:07, 13.54it/s][A
Eval:  15%|█▌        | 19/125 [00:01<00:08, 13.16it/s][A
Eval:  17%|█▋        | 21/125 [00:01<00:07, 13.25it/s][A
Eval:  18%|█▊        | 23/125 [00:01<00:07, 13.36it/s][A
Eval:  20%|██        | 25/125 [00:01<00:07, 13.42it/s][A
Eval:  22%|██▏       | 27/125 [00:02<00:07, 13.52it/s][A
Eval:  23%|██▎       | 29/125 [00:02<00:07, 13.67it/s][A
Eval

2


Epoch 2: 100%|██████████| 13852/13852 [1:02:53<00:00,  3.67batch/s, loss=0.00622]
Epoch 3:  20%|█▉        | 2770/13852 [12:25<48:09,  3.84batch/s, loss=0.000617]
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:24,  5.12it/s][A
Eval:   2%|▏         | 3/125 [00:00<00:12,  9.48it/s][A
Eval:   4%|▍         | 5/125 [00:00<00:10, 11.48it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:09, 12.38it/s][A
Eval:   7%|▋         | 9/125 [00:00<00:09, 12.80it/s][A
Eval:   9%|▉         | 11/125 [00:00<00:08, 13.04it/s][A
Eval:  10%|█         | 13/125 [00:01<00:08, 13.38it/s][A
Eval:  12%|█▏        | 15/125 [00:01<00:08, 13.42it/s][A
Eval:  14%|█▎        | 17/125 [00:01<00:07, 13.60it/s][A
Eval:  15%|█▌        | 19/125 [00:01<00:08, 13.19it/s][A
Eval:  17%|█▋        | 21/125 [00:01<00:07, 13.29it/s][A
Eval:  18%|█▊        | 23/125 [00:01<00:07, 13.42it/s][A
Eval:  20%|██        | 25/125 [00:01<00:07, 13.51it/s][A
Eval:  22%|██▏       | 27/125 [00:02<00:

3


Epoch 3:  40%|███▉      | 5540/13852 [25:00<37:41,  3.68batch/s, loss=0.0795] 
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:23,  5.28it/s][A
Eval:   2%|▏         | 3/125 [00:00<00:12,  9.56it/s][A
Eval:   4%|▍         | 5/125 [00:00<00:10, 11.66it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:09, 12.52it/s][A
Eval:   7%|▋         | 9/125 [00:00<00:08, 13.00it/s][A
Eval:   9%|▉         | 11/125 [00:00<00:08, 13.14it/s][A
Eval:  10%|█         | 13/125 [00:01<00:08, 13.43it/s][A
Eval:  12%|█▏        | 15/125 [00:01<00:08, 13.52it/s][A
Eval:  14%|█▎        | 17/125 [00:01<00:07, 13.68it/s][A
Eval:  15%|█▌        | 19/125 [00:01<00:08, 13.24it/s][A
Eval:  17%|█▋        | 21/125 [00:01<00:07, 13.33it/s][A
Eval:  18%|█▊        | 23/125 [00:01<00:07, 13.46it/s][A
Eval:  20%|██        | 25/125 [00:01<00:07, 13.50it/s][A
Eval:  22%|██▏       | 27/125 [00:02<00:07, 13.57it/s][A
Eval:  23%|██▎       | 29/125 [00:02<00:07, 13.70it/s][A
Eval:  

3


Epoch 3:  60%|█████▉    | 8310/13852 [37:34<24:52,  3.71batch/s, loss=0.0719]  
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:24,  5.14it/s][A
Eval:   2%|▏         | 3/125 [00:00<00:12,  9.65it/s][A
Eval:   4%|▍         | 5/125 [00:00<00:10, 11.57it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:09, 12.29it/s][A
Eval:   7%|▋         | 9/125 [00:00<00:09, 12.76it/s][A
Eval:   9%|▉         | 11/125 [00:00<00:08, 13.02it/s][A
Eval:  10%|█         | 13/125 [00:01<00:08, 13.33it/s][A
Eval:  12%|█▏        | 15/125 [00:01<00:08, 13.42it/s][A
Eval:  14%|█▎        | 17/125 [00:01<00:07, 13.58it/s][A
Eval:  15%|█▌        | 19/125 [00:01<00:08, 13.20it/s][A
Eval:  17%|█▋        | 21/125 [00:01<00:07, 13.30it/s][A
Eval:  18%|█▊        | 23/125 [00:01<00:07, 13.42it/s][A
Eval:  20%|██        | 25/125 [00:01<00:07, 13.47it/s][A
Eval:  22%|██▏       | 27/125 [00:02<00:07, 13.65it/s][A
Eval:  23%|██▎       | 29/125 [00:02<00:06, 13.78it/s][A
Eval: 

3


Epoch 3:  80%|███████▉  | 11080/13852 [50:09<12:08,  3.81batch/s, loss=8.73e-5] 
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:24,  5.08it/s][A
Eval:   2%|▏         | 3/125 [00:00<00:13,  9.31it/s][A
Eval:   4%|▍         | 5/125 [00:00<00:10, 11.44it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:09, 12.38it/s][A
Eval:   7%|▋         | 9/125 [00:00<00:09, 12.81it/s][A
Eval:   9%|▉         | 11/125 [00:00<00:08, 13.03it/s][A
Eval:  10%|█         | 13/125 [00:01<00:08, 13.36it/s][A
Eval:  12%|█▏        | 15/125 [00:01<00:08, 13.47it/s][A
Eval:  14%|█▎        | 17/125 [00:01<00:07, 13.63it/s][A
Eval:  15%|█▌        | 19/125 [00:01<00:08, 13.21it/s][A
Eval:  17%|█▋        | 21/125 [00:01<00:07, 13.30it/s][A
Eval:  18%|█▊        | 23/125 [00:01<00:07, 13.43it/s][A
Eval:  20%|██        | 25/125 [00:01<00:07, 13.49it/s][A
Eval:  22%|██▏       | 27/125 [00:02<00:07, 13.61it/s][A
Eval:  23%|██▎       | 29/125 [00:02<00:06, 13.76it/s][A
Eval:

3


Epoch 3: 100%|█████████▉| 13850/13852 [1:02:45<00:00,  3.80batch/s, loss=0.00259] 
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:23,  5.29it/s][A
Eval:   2%|▏         | 3/125 [00:00<00:12,  9.56it/s][A
Eval:   4%|▍         | 5/125 [00:00<00:10, 11.61it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:09, 12.47it/s][A
Eval:   7%|▋         | 9/125 [00:00<00:08, 12.90it/s][A
Eval:   9%|▉         | 11/125 [00:00<00:08, 13.12it/s][A
Eval:  10%|█         | 13/125 [00:01<00:08, 13.42it/s][A
Eval:  12%|█▏        | 15/125 [00:01<00:08, 13.50it/s][A
Eval:  14%|█▎        | 17/125 [00:01<00:07, 13.61it/s][A
Eval:  15%|█▌        | 19/125 [00:01<00:08, 13.13it/s][A
Eval:  17%|█▋        | 21/125 [00:01<00:07, 13.24it/s][A
Eval:  18%|█▊        | 23/125 [00:01<00:07, 13.33it/s][A
Eval:  20%|██        | 25/125 [00:01<00:07, 13.41it/s][A
Eval:  22%|██▏       | 27/125 [00:02<00:07, 13.54it/s][A
Eval:  23%|██▎       | 29/125 [00:02<00:07, 13.70it/s][A
Eva

3


Epoch 3: 100%|██████████| 13852/13852 [1:02:55<00:00,  3.67batch/s, loss=0.00206]
Epoch 4:  20%|█▉        | 2770/13852 [12:26<50:14,  3.68batch/s, loss=0.000293]
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:24,  5.08it/s][A
Eval:   2%|▏         | 3/125 [00:00<00:12,  9.53it/s][A
Eval:   4%|▍         | 5/125 [00:00<00:10, 11.51it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:09, 12.34it/s][A
Eval:   7%|▋         | 9/125 [00:00<00:09, 12.79it/s][A
Eval:   9%|▉         | 11/125 [00:00<00:08, 12.99it/s][A
Eval:  10%|█         | 13/125 [00:01<00:08, 13.30it/s][A
Eval:  12%|█▏        | 15/125 [00:01<00:08, 13.41it/s][A
Eval:  14%|█▎        | 17/125 [00:01<00:07, 13.51it/s][A
Eval:  15%|█▌        | 19/125 [00:01<00:08, 13.11it/s][A
Eval:  17%|█▋        | 21/125 [00:01<00:07, 13.22it/s][A
Eval:  18%|█▊        | 23/125 [00:01<00:07, 13.33it/s][A
Eval:  20%|██        | 25/125 [00:01<00:07, 13.36it/s][A
Eval:  22%|██▏       | 27/125 [00:02<00:

4


Epoch 4:  40%|███▉      | 5540/13852 [25:00<37:31,  3.69batch/s, loss=0.000828]
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:23,  5.17it/s][A
Eval:   2%|▏         | 3/125 [00:00<00:12,  9.47it/s][A
Eval:   4%|▍         | 5/125 [00:00<00:10, 11.52it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:09, 12.39it/s][A
Eval:   7%|▋         | 9/125 [00:00<00:09, 12.88it/s][A
Eval:   9%|▉         | 11/125 [00:00<00:08, 13.08it/s][A
Eval:  10%|█         | 13/125 [00:01<00:08, 13.41it/s][A
Eval:  12%|█▏        | 15/125 [00:01<00:08, 13.52it/s][A
Eval:  14%|█▎        | 17/125 [00:01<00:07, 13.60it/s][A
Eval:  15%|█▌        | 19/125 [00:01<00:08, 13.23it/s][A
Eval:  17%|█▋        | 21/125 [00:01<00:07, 13.34it/s][A
Eval:  18%|█▊        | 23/125 [00:01<00:07, 13.48it/s][A
Eval:  20%|██        | 25/125 [00:01<00:07, 13.51it/s][A
Eval:  22%|██▏       | 27/125 [00:02<00:07, 13.61it/s][A
Eval:  23%|██▎       | 29/125 [00:02<00:06, 13.76it/s][A
Eval: 

4


Epoch 4:  60%|█████▉    | 8310/13852 [37:37<24:54,  3.71batch/s, loss=0.00533] 
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:24,  5.07it/s][A
Eval:   2%|▏         | 3/125 [00:00<00:13,  9.36it/s][A
Eval:   4%|▍         | 5/125 [00:00<00:10, 11.44it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:09, 12.37it/s][A
Eval:   7%|▋         | 9/125 [00:00<00:09, 12.86it/s][A
Eval:   9%|▉         | 11/125 [00:00<00:08, 13.09it/s][A
Eval:  10%|█         | 13/125 [00:01<00:08, 13.45it/s][A
Eval:  12%|█▏        | 15/125 [00:01<00:08, 13.55it/s][A
Eval:  14%|█▎        | 17/125 [00:01<00:07, 13.64it/s][A
Eval:  15%|█▌        | 19/125 [00:01<00:08, 13.18it/s][A
Eval:  17%|█▋        | 21/125 [00:01<00:07, 13.26it/s][A
Eval:  18%|█▊        | 23/125 [00:01<00:07, 13.37it/s][A
Eval:  20%|██        | 25/125 [00:01<00:07, 13.42it/s][A
Eval:  22%|██▏       | 27/125 [00:02<00:07, 13.59it/s][A
Eval:  23%|██▎       | 29/125 [00:02<00:06, 13.75it/s][A
Eval: 

4


Epoch 4:  80%|███████▉  | 11080/13852 [50:12<11:51,  3.90batch/s, loss=0.00109]
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:24,  5.13it/s][A
Eval:   2%|▏         | 3/125 [00:00<00:12,  9.42it/s][A
Eval:   4%|▍         | 5/125 [00:00<00:10, 11.52it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:09, 12.40it/s][A
Eval:   7%|▋         | 9/125 [00:00<00:09, 12.84it/s][A
Eval:   9%|▉         | 11/125 [00:00<00:08, 13.03it/s][A
Eval:  10%|█         | 13/125 [00:01<00:08, 13.36it/s][A
Eval:  12%|█▏        | 15/125 [00:01<00:08, 13.40it/s][A
Eval:  14%|█▎        | 17/125 [00:01<00:07, 13.53it/s][A
Eval:  15%|█▌        | 19/125 [00:01<00:08, 13.18it/s][A
Eval:  17%|█▋        | 21/125 [00:01<00:07, 13.30it/s][A
Eval:  18%|█▊        | 23/125 [00:01<00:07, 13.38it/s][A
Eval:  20%|██        | 25/125 [00:01<00:07, 13.47it/s][A
Eval:  22%|██▏       | 27/125 [00:02<00:07, 13.60it/s][A
Eval:  23%|██▎       | 29/125 [00:02<00:06, 13.76it/s][A
Eval: 

4


Epoch 4: 100%|█████████▉| 13850/13852 [1:02:46<00:00,  3.86batch/s, loss=0.00289] 
Eval:   0%|          | 0/125 [00:00<?, ?it/s][A
Eval:   1%|          | 1/125 [00:00<00:24,  5.00it/s][A
Eval:   2%|▏         | 3/125 [00:00<00:12,  9.51it/s][A
Eval:   4%|▍         | 5/125 [00:00<00:10, 11.48it/s][A
Eval:   6%|▌         | 7/125 [00:00<00:09, 12.38it/s][A
Eval:   7%|▋         | 9/125 [00:00<00:09, 12.84it/s][A
Eval:   9%|▉         | 11/125 [00:00<00:08, 13.06it/s][A
Eval:  10%|█         | 13/125 [00:01<00:08, 13.38it/s][A
Eval:  12%|█▏        | 15/125 [00:01<00:08, 13.43it/s][A
Eval:  14%|█▎        | 17/125 [00:01<00:07, 13.60it/s][A
Eval:  15%|█▌        | 19/125 [00:01<00:08, 13.16it/s][A
Eval:  17%|█▋        | 21/125 [00:01<00:07, 13.21it/s][A
Eval:  18%|█▊        | 23/125 [00:01<00:07, 13.31it/s][A
Eval:  20%|██        | 25/125 [00:01<00:07, 13.37it/s][A
Eval:  22%|██▏       | 27/125 [00:02<00:07, 13.55it/s][A
Eval:  23%|██▎       | 29/125 [00:02<00:07, 13.69it/s][A
Eva

4


Epoch 4: 100%|██████████| 13852/13852 [1:02:56<00:00,  3.67batch/s, loss=0.000368]


In [None]:
import pandas as pd
# Unlabeled test split; predictions are appended as a new column at the end.
test_df = pd.read_csv(filepath_or_buffer='test_no_label.csv')

In [None]:
# The 'Id' column presumably holds the raw sentence text (it is lower-cased and
# tokenized below) — confirm against the competition's data description.
test_dataset = test_df.loc[:, 'Id']

In [None]:
def make_id_file_test(tokenizer, test_dataset):
    """Lower-case and tokenize each sentence, returning its ids as one string.

    Uses the same encoding as ``make_id_file`` for the training files:
    ``tokenizer.encode(text.lower())`` with the resulting ids joined by single
    spaces.

    Args:
        tokenizer: object exposing ``encode(str) -> list[int]``.
        test_dataset: iterable of raw sentence strings.

    Returns:
        A list with one string per input sentence, e.g. ``'101 2009 102'``.
    """
    token_id_lists = [tokenizer.encode(sentence.lower()) for sentence in test_dataset]
    return [' '.join(map(str, token_ids)) for token_ids in token_id_lists]

In [None]:
# Encode every test sentence into the same id-string format used for training.
test = make_id_file_test(tokenizer=tokenizer, test_dataset=test_dataset)

In [None]:
# Preview the first ten encoded sentences.
test[0:10]

['101 2009 1005 1055 1037 2878 2047 3325 1998 2047 26389 2169 2051 2017 2175 1012 102',
 '101 2061 15640 2013 2019 2214 5440 1012 102',
 '101 2009 2003 1996 2087 14469 7273 1999 1996 3028 1012 102',
 '101 2079 2025 3696 1037 10084 2007 2122 2111 1012 102',
 '101 1045 2001 6091 1998 2016 2081 2033 2514 2061 6625 1998 6160 1012 102',
 '101 1996 2069 2518 2057 2363 2008 2001 2980 2001 1996 4157 1012 102',
 '101 2053 1010 2025 1996 3924 2012 2004 2226 1010 1996 3924 1999 3502 2152 1012 102',
 '101 2027 3288 2009 2041 2392 2005 2017 1998 2024 2200 14044 1012 102',
 '101 4606 1996 12043 2106 1050 1005 1056 2130 2113 2129 2000 2147 1996 3274 1012 102',
 '101 2027 2031 2019 6581 4989 1997 25025 2015 2000 5454 2013 1012 102']

In [None]:
class SentimentTestDataset(object):
    """Unlabeled test set built from whitespace-separated token-id strings.

    Each input string (as produced by ``make_id_file_test``) is parsed into a
    list of ints; items are returned as NumPy arrays.

    Args:
        tokenizer: kept on the instance as ``self.tokenizer``; not used here.
        test: iterable of strings such as ``'101 2009 102'``.
    """

    def __init__(self, tokenizer, test):
        self.tokenizer = tokenizer
        self.data = [self._cast_to_int(line.strip().split()) for line in test]

    def _cast_to_int(self, sample):
        """Convert a sequence of numeric strings into a list of ints."""
        return list(map(int, sample))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return np.array(self.data[index])

In [None]:
# Wrap the encoded id strings in a Dataset for the DataLoader below.
test_dataset = SentimentTestDataset(tokenizer=tokenizer, test=test)

In [None]:
def collate_fn_style_test(samples):
    """Collate unlabeled token-id sequences into padded BERT input tensors.

    Sample order is preserved (no length sorting) so that predictions line up
    row-for-row with ``test_df`` when the submission is written.

    Args:
        samples: list of 1-D integer sequences (here ``np.ndarray`` items from
            ``SentimentTestDataset``), one per sentence.

    Returns:
        Tuple ``(input_ids, attention_mask, token_type_ids, position_ids)``,
        each an integer tensor of shape ``(batch, max_len)``. ``input_ids`` is
        right-padded with 0, ``attention_mask`` is 1 on real tokens and 0 on
        padding, ``token_type_ids`` is all zeros, and ``position_ids`` is
        ``0..max_len-1`` on every row.
    """
    lengths = [len(sample) for sample in samples]
    max_len = max(lengths)
    batch_size = len(samples)

    input_ids = pad_sequence([torch.tensor(sample) for sample in samples],
                             batch_first=True)
    # Build the mask from the *unpadded* lengths. Measuring rows of the padded
    # tensor (as the previous version did) gives every row length max_len,
    # producing an all-ones mask that lets BERT attend to padding tokens.
    # NOTE(review): if the training collate function built its mask the same
    # way, the checkpoint was trained without padding masks — confirm.
    attention_mask = torch.tensor(
        [[1] * length + [0] * (max_len - length) for length in lengths])
    token_type_ids = torch.zeros(batch_size, max_len, dtype=torch.long)
    position_ids = torch.arange(max_len, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)

    return input_ids, attention_mask, token_type_ids, position_ids

In [None]:
test_batch_size = 32
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=test_batch_size,
    shuffle=False,  # keep row order aligned with test_df for the submission
    collate_fn=collate_fn_style_test,
    num_workers=2,
)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_path = '/content/drive/MyDrive/Goorm_Deep_Learning/Projects/project1/checkpoints/checkpoint_epoch_1.8310.pth'
model = BertForSequenceClassification.from_pretrained('bert-large-uncased')
# torch.load unpickles the file: only load checkpoints you produced yourself.
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()  # disable dropout for inference

predictions = []
with torch.no_grad():
    for batch in tqdm(test_loader, desc='Test', position=1, leave=False):
        input_ids, attention_mask, token_type_ids, position_ids = (
            tensor.to(device) for tensor in batch)

        logits = model(input_ids=input_ids,
                       attention_mask=attention_mask,
                       token_type_ids=token_type_ids,
                       position_ids=position_ids).logits

        # Class 1 unless logit 0 is strictly larger (ties -> 1, as before).
        # Comparing whole columns on-device avoids a GPU->CPU sync per element.
        predictions += (logits[:, 1] >= logits[:, 0]).long().tolist()

Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint a

[array([  101,  1996,  9686, 20110,  2080,  2001,  2126,  2205,  2980,
        1012,   102]), array([ 101, 1996, 2833, 2001, 3492, 2919, 1010, 1045, 2052, 2025, 2175,
       2045, 2153, 1012,  102]), array([  101,  2034,  1010,  1996, 15812,  2001,  2019,  1037,  1011,
        1011,  4920,  1012,   102]), array([ 101, 2079, 1050, 1005, 1056, 5949, 1996, 4440, 1012,  102]), array([  101,  2017, 24185,  1050,  1005,  1056,  2424,  1037,  2488,
        4989,  1999,  3660, 15145,  1012,   102]), array([ 101, 5959, 2635, 2026, 2155, 2182, 2467, 1996, 4840, 4355, 2712,
       2833, 1012,  102]), array([  101,  1996,  5785,  2001,  3147,  1998,  2524,  1998,  1996,
        5510,  2001,  2200, 20857,  1012,   102]), array([ 101, 7078, 6659, 1010, 2079, 2025, 2344, 2013, 2023, 2173, 1012,
        102]), array([  101,  2572, 15599,  3401,  1998,  4292,  2020, 21688,  1012,
         102]), array([ 101, 2028, 3714, 2518, 2044, 2178, 2027, 2079, 1050, 1005, 1056,
       2729, 2000, 4769, 1012,  102


Test:   3%|▎         | 1/32 [00:00<00:09,  3.44it/s][A

[array([ 101, 1045, 2097, 2022, 2183, 2067, 1998, 9107, 2023, 2307, 2173,
        999,  102]), array([ 101, 2053, 2655, 1010, 2053, 2498, 1012,  102]), array([  101,  2190,  6350,  1045,  2018,  2006,  2026,  3522,  4440,
        2000,  3660, 15145,  1012,   102]), array([  101,  2074,  2253,  2067,  2000,  2131,  1996,  4373, 12824,
       10203,  1998,  3561,  1012,   102]), array([  101,  2307,  2833,  1010,  2307,  2326,  1010,  2037, 18640,
        2064,  2025,  2022,  3786,  1012,   102]), array([ 101, 1045, 3984, 2016, 2001, 1050, 1005, 1056, 3407, 2008, 2057,
       2020, 4851, 1996, 7597, 1012,  102]), array([  101,  1045,  2052,  2196, 28667,  9006, 10497,  3087,  2000,
        2444,  2182,  1012,   102]), array([ 101, 2023, 2173, 2003, 3968, 6292, 1998, 4569, 1012,  102]), array([ 101, 2023, 2173, 2003, 3458, 1037, 4485, 4920, 1012,  102]), array([  101,  1045,  2036,  2131, 18064,  2182,  2043,  1045,  2031,
        1037,  4086, 11868,  1012,   102]), array([ 101, 1045, 201


Test:   9%|▉         | 3/32 [00:00<00:03,  7.70it/s][A

[array([ 101, 2025, 2061, 2307, 2833, 1998, 2326, 1012,  102]), array([  101,  1996,  2662, 22715,  9956,  2001,  2036,  1037,  2307,
       22715,  9956,  1012,   102]), array([  101,  1045,  2699,  2000,  2053, 24608,  1012,   102]), array([ 101, 2174, 1010, 2009, 2357, 2041, 2000, 2022, 2498, 2066, 1045,
       2245, 2009, 2052, 1012,  102]), array([  101,  2026, 11687,  7273, 12595,  2066,  7273,  5785, 27130,
        2007,  3347,  4783,  4226, 12901,  1012,   102]), array([ 101, 2699, 2037, 5869, 8490, 2532, 1998, 2049, 1996, 2190, 4921,
       2063, 2412, 2018, 1012,  102]), array([ 101, 2023, 2003, 4089, 1996, 5409, 3306, 2833, 1045, 1005, 2310,
       2018, 1999, 2026, 2166, 1012,  102]), array([ 101, 2025, 2061, 2172, 2122, 2420, 1012,  102]), array([  101,  1996, 25545,  2001,  5667, 11158,  1012,   102]), array([ 101, 2003, 2008, 1037, 2204, 2518, 2030, 1037, 2919, 2518, 1029,
        102]), array([ 101, 2057, 2020, 2119, 2061, 7622, 1012,  102]), array([ 101, 3071, 2040, 25


Test:  16%|█▌        | 5/32 [00:00<00:02,  9.59it/s][A

[array([  101,  2037,  4372,  5428, 27266,  2050, 12901,  6669, 13711,
        2098,  2023,  9841,  1012,   102]), array([ 101, 1045, 1005, 1049, 2025, 1037, 5470, 1997, 4121, 4677, 7884,
       1012,  102]), array([ 101, 2021, 2009, 2003, 5791, 4276, 1996, 3524, 1012,  102]), array([  101,  1045,  2018,  1996, 15890,  2320,  1998,  2009,  2001,
       28900,  1998,  2200, 14894,  3993,   999,   102]), array([ 101, 1996, 2450, 2040, 3271, 2033, 2651, 2001, 2200, 5379, 1998,
       3716, 3085, 1012,  102]), array([ 101, 5564, 1045, 1005, 1049, 3403, 2005, 2019, 3437, 1012,  102]), array([  101,  2009,  2003,  2107,  1037,  4658,  2210,  2173,  2000,
        9483,  1998,  2131, 10677,  4801,  1012,   102]), array([  101,  2069,  2059,  2106,  2256, 13877,  2265,  2039,  2007,
        2178,  2358, 12541, 11253, 10441,  2213,  2452,  2440,  1997,
        2300,  1012,   102]), array([ 101, 2467, 1037, 2307, 3325, 2045, 2007, 1996, 3954, 1998, 1996,
       2717, 1997, 1996, 2136, 1012,  102]


Test:  22%|██▏       | 7/32 [00:00<00:02, 10.33it/s][A

[array([  101,  2079,  1050,  1005,  1056,  2175,  2182,  4983,  2017,
        2215,  2000,  3477,  2005, 10231,  1012,   102]), array([  101,  1996,  4471, 19294,  2003, 12476,  2036,  1012,   102]), array([  101,  1996,  3095,  2003, 12476,  1998,  3295,  2003,  2157,
        1999,  1996,  2540,  1997,  2214,  2237,   999,   102]), array([ 101, 2023, 3962, 2001, 2026, 5440, 2796, 4825, 1012,  102]), array([ 101, 3333, 2125, 2026, 2417, 8722, 1998, 2093, 3940, 1997, 2304,
       6007, 1012,  102]), array([ 101, 2057, 1005, 2310, 2042, 2000, 2023, 2146, 2511, 4825, 2116,
       2335, 1012,  102]), array([  101,  2028,  1997,  2026,  3180,  7516,  1999,  3660, 15145,
        1012,   102]), array([ 101, 1045, 2699, 2000, 2655, 8385, 2007, 2053, 3433, 1012,  102]), array([ 101, 2065, 1045, 2018, 2359, 2009, 8871, 1045, 2052, 2031, 8871,
       2009, 2870,  999,  102]), array([  101,  1045,  2165,  1996, 15610,  1005,  1055, 12832,  1997,
        1996,  2310,  2389,  3792,  1998,  2009,  2


Test:  28%|██▊       | 9/32 [00:00<00:02, 11.07it/s][A

[array([  101,  5717, 10821,  2007,  2010,  2147,  1012,   102]), array([  101,  1045,  5791, 16755,  2023,  2173,  2000,  2500,   999,
         102]), array([ 101, 2057, 2293, 2000, 2272, 2004, 1037, 2155, 1012,  102]), array([ 101, 2253, 2046, 2023, 3295, 2000, 4530, 2009, 2125, 1012,  102]), array([ 101, 1045, 2293, 2008, 2023, 2003, 1996, 2173, 1999, 1996, 5021,
       1011, 2155, 9661, 1012,  102]), array([  101,  4918,  2001,  4550, 15708,  1998,  2658,  1012,   102]), array([ 101, 2023, 2003, 1037, 6659, 2267, 1012,  102]), array([ 101, 1045, 1005, 2310, 8828, 2182, 2116, 2335, 1010, 2021, 3904,
       2004, 2919, 2004, 2197, 2305, 1012,  102]), array([ 101, 1045, 2293, 3059, 1998, 1045, 4521, 2182, 2411, 1012,  102]), array([  101,  2288, 14180,  2012,  3347,  2008,  2057,  2081, 17829,
        2012,  2279,  2341,  1012,   102]), array([ 101, 2009, 2003, 8335, 1998, 1996, 3095, 2003, 2467, 5379, 1012,
        102]), array([  101,  1045,  2079,  2025,  2066,  2000,  2022, 14180,


Test:  34%|███▍      | 11/32 [00:01<00:01, 11.56it/s][A

[array([  101,  1045,  2036,  2293,  2037, 14057,  3295,  2157,  2125,
        1997,  3660, 15145,  2346,  1012,   102]), array([ 101, 2036, 1010, 1996, 2833, 2003, 2307,  999,  102]), array([  101,  3071,  2003,  2467,  3565,  5379,  1998, 14044,  1012,
         102]), array([ 101, 2023, 2173, 2038, 2070, 2307, 2833, 1012,  102]), array([ 101, 2007, 2008, 2108, 2056, 1010, 1045, 5632, 1996, 2833, 1012,
        102]), array([ 101, 2870, 1998, 2155, 1013, 2814, 2031, 2042, 2183, 2000, 2068,
       2005, 2086,  999,  102]), array([  101,  1996, 26509,  2003,  2307,  1010,  1996,  3869,  2327,
       18624,  1010,  1996,  8974,  2074,  6581,  1012,   102]), array([  101,  1996,  7224,  2001,  4569,  1998,  1996,  3095, 18452,
        2017,  2092,  1012,   102]), array([  101,  1998,  1996,  9344,  2003,  2126,  2058, 21125,  1012,
         102]), array([ 101, 3953, 2240, 2027, 2058, 4872, 1998, 2104, 8116, 1012,  102]), array([  101,  1996, 10733,  2003,  3253,  2302, 22286,  2015,  1998,


Test:  41%|████      | 13/32 [00:01<00:01, 12.13it/s][A

[array([  101, 17752,  1010,  2130,  1996, 10733,  2003, 11519,  1012,
         102]), array([  101,  5632,  1996, 19958,  1037,  2843,  1012,   102]), array([ 101, 1996, 2173, 2001, 5697, 2021, 1996, 2326, 2001, 6581,  999,
        102]), array([ 101, 2011, 2521, 1996, 2190, 3325, 1045, 2031, 2412, 2018, 1999,
       1037, 8285, 4497, 1012,  102]), array([ 101, 2008, 1005, 1055, 1037, 4678, 2424, 1999, 1037, 2449, 1012,
        102]), array([ 101, 2035, 2105, 2307, 2173, 2000, 6723, 1037, 4392, 2030, 2048,
       1012,  102]), array([ 101, 2017, 2031, 2000, 4133, 2648, 2000, 2131, 1996, 2190, 3325,
       2013, 2023, 2173, 1012,  102]), array([ 101, 2256, 3325, 2007, 2023, 2449, 2001, 3893, 1012,  102]), array([  101,  1996, 22861,  4160, 11840,  2001,  2204,  1012,   102]), array([  101,  1996,  2436,  2036,  3478,  2000,  2130, 12134,  2055,
        2023,  2878, 23512,  3325,  1012,   102]), array([  101,  1045,  2064,  9826,  2360,  1045,  2572,  2061,  5580,
        2057,  2097,  


Test:  47%|████▋     | 15/32 [00:01<00:01, 12.66it/s][A

[array([ 101, 2027, 2064, 4067, 2017, 2005, 1996, 2659, 5790, 1012,  102]), array([ 101, 2833, 2003, 2467, 6429, 2053, 3043, 2054, 1045, 2344, 1012,
        102]), array([  101,  1996,  3528,  1997, 10514,  6182,  9372,  3084,  2005,
        1037,  2204,  5983,  1012,   102]), array([  101,  2017,  2131,  2061,  2172,  2005,  1996,  3976,  1998,
        1996,  2833,  2003,  4840,  1998, 12090,  1012,   102]), array([  101,  2009,  1005,  1055,  2980,  1010, 12984,  6669,  1010,
        1998, 12090,   999,   102]), array([  101,  2122,  2420,  1996,  3737,  2003,  3492, 10223,  6508,
        1012,   102]), array([ 101, 2028, 1997, 2026, 5440, 2822, 2173, 2000, 4521,  999,  102]), array([  101,  1996,  2069,  2518,  1045,  2001,  3253,  2001,  1037,
        2489, 18064,  1012,   102]), array([  101,  1045,  2293,  2037, 14894,  2098,  4157,  1012,   102]), array([ 101, 1045, 2228, 2009, 1005, 1055, 6659, 2043, 5126, 2552, 2008,
       2126, 1012,  102]), array([ 101, 2326, 2001, 2307, 19


Test:  53%|█████▎    | 17/32 [00:01<00:01, 12.92it/s][A

[array([  101,  2023,  2173,  2003,  6530,  1998,  2448,  2091,  1998,
        1996,  2326, 27136,  2015,   999,   102]), array([  101,  1045,  2079,  2025, 16755,  3087,  2000,  2023, 29500,
        1012,   102]), array([ 101, 2057, 2097, 2025, 2022, 2478, 2023, 3295, 2153, 1012,  102]), array([  101,  9202,  3105,  2006, 10063,  1998, 10393,  1012,   102]), array([ 101, 2307, 3095, 1998, 6919, 2833, 1012,  102]), array([  101,  1056,  4143,  5753, 17471, 12901,  2018,  2126,  2205,
        2172, 29454,  2140,  1999,  2009,  1012,   102]), array([ 101, 1996, 3944, 2318, 2041, 4030, 1012,  102]), array([  101,  1045,  2074,  2939,  2041,  1010,  2170,  1996,  3208,
        2000, 17612,  1012,   102]), array([  101,  2253,  2067,  1999,  2005, 19037,  1010,  3651,  2145,
       16542,  1012,   102]), array([  101, 11844, 24657,  1005,  1055,  2003,  1996,  2087, 10862,
        6429, 10733,  2173,  1999,  5334,  1012,   102]), array([  101,  2026,  2990,  1998, 14492,  2001,  5667, 11158


Test:  59%|█████▉    | 19/32 [00:01<00:00, 13.22it/s][A

[array([ 101, 5791, 1037, 2173, 2000, 2562, 1999, 2568, 1012,  102]), array([ 101, 2023, 9909, 1996, 2197, 2210, 3543, 2000, 2054, 2001, 2525,
       2019, 6429, 2173, 1012,  102]), array([ 101, 2008, 3310, 2007, 1996, 3760, 3295, 2295, 1012,  102]), array([ 101, 2035, 1999, 2035, 3492, 2204,  999,  102]), array([  101,  2036,  1996, 18081,  2180,  2669,  2015,  3627,  2247,
        2007,  1996, 25482, 22088, 12901,  1012,   102]), array([  101,  2023,  2282,  2008,  2002,  2179,  2036,  2128, 23941,
        2094,  1997,  5610,   999,   102]), array([  101,  2035,  2018,  2008, 17087, 21956, 14894,  2008,  8871,
        2091,  2092,  2007, 18007,  1012,   102]), array([  101,  2061,  1010,  2008,  2001,  2026,  2028,  1998,  2069,
        2051, 13063,  1996, 12122,  2045,  1012,   102]), array([  101,  1996,  3954,  2003,  1037,  7570,  4140,  1998,  1996,
        4322,  2003,  2200, 16222,  5358,  5302, 16616,  1012,   102]), array([  101,  1996, 20548,  7852,  2001, 20857,  1998,  31


Test:  66%|██████▌   | 21/32 [00:01<00:00, 13.43it/s][A

[array([  101,  2467,  4840,  1010,  6638,  1010,  1998, 12090,  1012,
         102]), array([  101,  1045,  3641, 20548,  7852,  1998, 10768,  4779, 16835,
        2638, 19423, 24857,  2007, 11546,  1012,   102]), array([  101,  3046,  2505,  2006,  1996, 12183,  1998,  2017,  2097,
        2022,  5580,  2017,  2234,  1999,   999,   102]), array([ 101, 1045, 8969, 2023, 1998, 2002, 2253, 2067, 2000, 2202, 2729,
       1997, 1996, 3021, 1012,  102]), array([ 101, 2968, 2515, 2025, 4025, 2000, 2729, 1012,  102]), array([  101,  5379,  1998, 18066,  2007,  1037,  4569,  7224,  1998,
       27547,  2833,  1012,   102]), array([  101,  2431,  1997,  2026,  2132,  2001,  2058, 13995,  1012,
         102]), array([ 101, 2027, 5338, 2033, 2005, 2147, 2025, 2589, 1010, 1998, 3033,
       2025, 5361, 1012,  102]), array([ 101, 2017, 1005, 2222, 2156, 2339, 2320, 2017, 2131, 2045, 1012,
        102]), array([ 101, 2074, 6659, 2003, 2035, 1045, 2064, 2360, 1012,  102]), array([ 101, 2307, 2326, 2


Test:  72%|███████▏  | 23/32 [00:01<00:00, 13.59it/s][A

[array([ 101, 1045, 1005, 2310, 2196, 2018, 1037, 4788, 3325, 2084, 2023,
        999,  102]), array([  101,  1998,  1996, 23621,  4372,  5428, 27266,  3022,  2020,
        2307,   999,   102]), array([  101,  2002,  2003,  2200, 16030,  1998, 15958, 14977,  2005,
        2010,  6304,  1012,   102]), array([ 101, 2037, 6265, 2569, 2003, 1037, 2307, 3643,  999,  102]), array([ 101, 2060, 2084, 2008, 1010, 2009, 1005, 1055, 2307,  999,  102]), array([ 101, 3565, 2204, 9144, 1998, 2200, 5379, 3095, 1012,  102]), array([ 101, 8013, 2326, 2001, 2307, 2045, 1012,  102]), array([  101,  2053, 14571,  1998,  2037,  2833,  2003, 17203,  1012,
         102]), array([ 101, 2200, 2204, 7987, 4609, 2818, 1010, 2001, 7622, 2007, 4989,
       1998, 3737, 1012,  102]), array([ 101, 1045, 2031, 2025, 2363, 2107, 2019, 7729, 2646, 1037, 8013,
       2077, 1012,  102]), array([  101,  2026,  7954,  2001,  2200, 17087,  1998, 14894,  3993,
        1012,   102]), array([ 101, 3892, 2295, 1045, 3641, 2000, 


Test:  78%|███████▊  | 25/32 [00:02<00:00, 13.63it/s][A

[array([  101,  2200, 14044,  1010,  7570, 13102,  6590,  3468,  1010,
        3716,  3085,  1010,  1998, 12367,  8082,  1012,   102]), array([ 101, 2026, 2684, 2363, 2307, 2729, 1012,  102]), array([ 101, 2057, 2409, 2068, 2000, 5293, 2009, 1010, 2057, 2106, 1050,
       1005, 1056, 2215, 2000, 3524, 1012,  102]), array([ 101, 2941, 1010, 2074, 2562, 3788, 1012,  102]), array([  101,  2021,  2009,  2001, 19424,  1998,  2980,  1999,  2045,
        1012,   102]), array([  101,  2035,  2016,  2106,  2001,  2507,  2033,  1996,  2448,
        2105,  1998,  9828,  1998, 18667,  2673,  1012,   102]), array([ 101, 2200, 5379, 3095, 1998, 3208, 1012,  102]), array([ 101, 2027, 2036, 2024, 1996, 2190, 3976, 1999, 1996, 2181, 1012,
        102]), array([ 101, 1996, 2396, 1999, 1996, 8975, 2369, 2009, 2003, 4658, 2205,
        999,  102]), array([ 101, 2498, 2008, 2569, 2055, 2023, 2173, 1012,  102]), array([  101,  2058,  3597, 23461,  2061,  6649,  2008,  2009,  2001,
        1996, 18700,  1997


Test:  84%|████████▍ | 27/32 [00:02<00:00, 13.69it/s][A

[array([ 101, 3892, 1045, 2439, 2035, 4847, 2005, 2023, 2194, 1012,  102]), array([ 101, 2027, 2699, 2613, 2524, 2000, 2131, 2033, 2000, 2272, 2067,
       2021, 1045, 4188, 1012,  102]), array([ 101, 1996, 3609, 2008, 2016, 3594, 2006, 2026, 6513, 1005, 1055,
       2606, 3504, 2307, 1012,  102]), array([  101,  1996,  2833,  2182,  2003, 12090,  1012,   102]), array([ 101, 2057, 1005, 2222, 5121, 2022, 2067,  999,  102]), array([ 101, 1996, 2311, 2993, 3504, 4704, 1012,  102]), array([  101,  2065,  1045,  2071,  2507,  5717,  3340,  1045, 13366,
        2052,  1012,   102]), array([  101,  1996,  5404,  2003,  1050,  1005,  1056,  2919,  1010,
        2021,  1996,  2833,  2001,  2625,  2084, 16166,  1012,   102])]



Test:  91%|█████████ | 29/32 [00:02<00:00, 12.99it/s][A
Test:  97%|█████████▋| 31/32 [00:02<00:00, 13.26it/s][A
                                                     [A

In [None]:
# Attach predicted labels (0/1) as the 'Category' column of the submission.
test_df = test_df.assign(Category=predictions)

In [None]:
# NOTE(review): 'batch_36' in the filename does not match test_batch_size (32);
# presumably it refers to the training batch size — confirm.
test_df.to_csv(path_or_buf='submission(bert_large)(batch_36).csv', index=False)