In [1]:
import os
import sys

import torch
import pandas as pd
import numpy as np

from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, datasets
from torchmetrics.classification import MulticlassF1Score
from sklearn.model_selection import train_test_split
from PIL import Image
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


# Global constants

In [2]:
HEIGHT, WIDTH = 224, 224
GLOBAL_MEAN, GLOBAL_STD = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
BATCH_SIZE = 150
# pretrained model downloaded from https://huggingface.co/google/vit-base-patch16-224
PRETRAINED_MODEL_PATH = "./vit_224"
NUM_CLASSES = 30

TRAIN_SIZE = 38_000
VAL_SIZE = 7_373

# Define Dataloaders

In [3]:
class SportsDataset(Dataset):
    def __init__(self, root_dir, labels, transform=None, is_train=True):
        super(SportsDataset, self).__init__()
        self.root_dir = root_dir
        self.labels = labels
        self.transform = transform
        self.is_train = is_train
        self.id2label = False
        self.label2id = False
        
        if is_train:
            unique_labels = np.unique(self.labels['label'])
            self.id2label = {
                i: label
                for i, label in enumerate(unique_labels)
            }

            self.label2id = {
                label: i
                for i, label in enumerate(unique_labels)
            }
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
            
        image_path = f"{self.root_dir}/{self.labels.iloc[idx, 0]}"
        
        with Image.open(image_path) as img:
            img = img.convert("RGB")
            
            if self.transform:
                img = self.transform(img)
            else:
                img.copy()
                
        if self.is_train:
            return img, self.label2id[self.labels.iloc[idx, 1]]
        
        return img

In [4]:
train_augmentations = transforms.Compose([
    transforms.Resize(size=(HEIGHT, WIDTH)),
    transforms.AutoAugment(),
    transforms.RandAugment(),
    transforms.ToTensor(),
    transforms.Normalize(mean=GLOBAL_MEAN, std=GLOBAL_STD),
])

validation_augmentation = transforms.Compose([
    transforms.Resize(size=(HEIGHT, WIDTH)),
    transforms.ToTensor(),
    transforms.Normalize(mean=GLOBAL_MEAN, std=GLOBAL_STD),
])

In [5]:
train_pd = pd.read_csv('train.csv')
train_pd, val_pd = train_test_split(train_pd, test_size=0.2, random_state=42)
test_pd = pd.read_csv('test.csv')

train_dataset = SportsDataset('./train', train_pd, transform=train_augmentations)
val_dataset = SportsDataset('./train', val_pd, transform=validation_augmentation)
test_dataset = SportsDataset('./test', test_pd, transform=validation_augmentation, is_train=False)


train_loader = DataLoader(
    dataset=train_dataset, batch_size=BATCH_SIZE, num_workers=10, shuffle=True,
)
test_loader = DataLoader(
    dataset=test_dataset, batch_size=BATCH_SIZE, num_workers=10, shuffle=False
)
val_loader = DataLoader(
    dataset=val_dataset, batch_size=BATCH_SIZE, num_workers=10, shuffle=False
)

# Vit Transformer

In [6]:
import pytorch_lightning as pl
from transformers import ViTForImageClassification, AdamW
import torch.nn as nn

from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.callbacks import TQDMProgressBar

In [7]:
# https://stackoverflow.com/questions/72958447/pytorch-lightning-how-can-i-output-a-summary-of-training-to-the-console
class LoggingCallback(pl.Callback):
    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        rank_zero_info("***** Test results *****")
        metrics = trainer.callback_metrics
        for key in sorted(metrics):
            if key not in ["log", "progress_bar"]:
                rank_zero_info("{} = {}\n".format(key, str(metrics[key])))

See reference tutorial on transfer learning from Hugging Face: https://github.com/NielsRogge/Transformers-Tutorials/blob/master/VisionTransformer/Fine_tuning_the_Vision_Transformer_on_CIFAR_10_with_PyTorch_Lightning.ipynb

In [8]:
class ViTSportsTransferModel(pl.LightningModule):
    def __init__(self, model_path, num_classes, id2label, label2id):
        super(ViTSportsTransferModel, self).__init__()
        self.vit = ViTForImageClassification.from_pretrained(model_path,
            num_labels=num_classes,
            id2label=id2label,
            label2id=label2id,
            ignore_mismatched_sizes=True
        )
        
        self.num_classes = num_classes
        self.metric = MulticlassF1Score(self.num_classes, average='micro')
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, pixel_values):
        outputs = self.vit(pixel_values=pixel_values)
        return outputs.logits
        
    def common_step(self, batch, batch_idx):
        pixel_values = batch[0]
        labels = batch[1]
        logits = self(pixel_values)

        
        loss = self.criterion(logits, labels)
        f1_score = self.metric(logits, labels)

        return loss, f1_score
      
    def training_step(self, batch, batch_idx):
        loss, f1_score = self.common_step(batch, batch_idx)     
        self.log("training_loss", loss)
        self.log("training_f1", f1_score)

        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, f1_score = self.common_step(batch, batch_idx)     
        self.log("validation_loss", loss, on_epoch=True, prog_bar=True)
        self.log("validation_f1", f1_score, on_epoch=True, prog_bar=True)

        return loss

    def configure_optimizers(self):
        return AdamW(self.parameters(), lr=5e-5)

    def train_dataloader(self):
        return train_loader

    def val_dataloader(self):
        return val_loader

# Train

In [9]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

In [10]:
model = ViTSportsTransferModel(
    num_classes=NUM_CLASSES,
    model_path=PRETRAINED_MODEL_PATH,
    id2label=train_dataset.id2label,
    label2id=train_dataset.label2id
)

trainer = Trainer(
    gpus=1,
    callbacks=[EarlyStopping(monitor='validation_loss'), LoggingCallback(),],
    max_epochs=12,
    log_every_n_steps=50,
#     enable_progress_bar=False,
)

trainer.fit(model)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at ./vit_224 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([30, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([30]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  f"Setting `Trainer(gpus={gpus!r})` is deprecated in v1.7 and will be removed"
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name      | Type                      | Params
--------------------------------------------------------
0 | vit       | ViTForImageClassification | 85.8 M
1 | metric    | MulticlassF1Score         | 0    

Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:02<00:00,  1.18s/it]

***** Test results *****
validation_f1 = tensor(0.0300, device='cuda:0')

validation_loss = tensor(3.5296, device='cuda:0')



Epoch 0:  80%|███████▉  | 242/303 [05:50<01:28,  1.45s/it, loss=0.541, v_num=45]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/61 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/61 [00:00<?, ?it/s][A
Epoch 0:  80%|████████  | 243/303 [05:52<01:27,  1.45s/it, loss=0.541, v_num=45]
Epoch 0:  81%|████████  | 244/303 [05:52<01:25,  1.45s/it, loss=0.541, v_num=45]
Epoch 0:  81%|████████  | 245/303 [05:53<01:23,  1.44s/it, loss=0.541, v_num=45]
Epoch 0:  81%|████████  | 246/303 [05:53<01:22,  1.44s/it, loss=0.541, v_num=45]
Epoch 0:  82%|████████▏ | 247/303 [05:54<01:20,  1.43s/it, loss=0.541, v_num=45]
Epoch 0:  82%|████████▏ | 248/303 [05:54<01:18,  1.43s/it, loss=0.541, v_num=45]
Epoch 0:  82%|████████▏ | 249/303 [05:55<01:17,  1.43s/it, loss=0.541, v_num=45]
Epoch 0:  83%|████████▎ | 250/303 [05:55<01:15,  1.42s/it, loss=0.541, v_num=45]
Epoch 0:  83%|████████▎ | 251/303 [05:56<01:13,  1.42s/it, loss=0.541, v_num=45]
Epoch 0:  83%|████████▎ | 252/303 [

***** Test results *****
training_f1 = tensor(0.8581, device='cuda:0')

training_loss = tensor(0.5790, device='cuda:0')

validation_f1 = tensor(0.9136, device='cuda:0')

validation_loss = tensor(0.3462, device='cuda:0')



Epoch 0: 100%|██████████| 303/303 [06:22<00:00,  1.26s/it, loss=0.541, v_num=45, validation_loss=0.346, validation_f1=0.914]
Epoch 1:  80%|███████▉  | 242/303 [05:49<01:28,  1.45s/it, loss=0.432, v_num=45, validation_loss=0.346, validation_f1=0.914]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/61 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/61 [00:00<?, ?it/s][A
Epoch 1:  80%|████████  | 243/303 [05:51<01:26,  1.45s/it, loss=0.432, v_num=45, validation_loss=0.346, validation_f1=0.914]
Epoch 1:  81%|████████  | 244/303 [05:52<01:25,  1.44s/it, loss=0.432, v_num=45, validation_loss=0.346, validation_f1=0.914]
Epoch 1:  81%|████████  | 245/303 [05:52<01:23,  1.44s/it, loss=0.432, v_num=45, validation_loss=0.346, validation_f1=0.914]
Epoch 1:  81%|████████  | 246/303 [05:53<01:21,  1.44s/it, loss=0.432, v_num=45, validation_loss=0.346, validation_f1=0.914]
Epoch 1:  82%|████████▏ | 247/303 [05:53<01:20,  1.43s/it, loss=0.432, v_num=45, validation_los

***** Test results *****
training_f1 = tensor(0.8108, device='cuda:0')

training_loss = tensor(0.6475, device='cuda:0')

validation_f1 = tensor(0.9247, device='cuda:0')

validation_loss = tensor(0.2887, device='cuda:0')



Epoch 1: 100%|██████████| 303/303 [06:22<00:00,  1.26s/it, loss=0.432, v_num=45, validation_loss=0.289, validation_f1=0.925]
Epoch 2:  80%|███████▉  | 242/303 [05:51<01:28,  1.45s/it, loss=0.374, v_num=45, validation_loss=0.289, validation_f1=0.925]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/61 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/61 [00:00<?, ?it/s][A
Epoch 2:  80%|████████  | 243/303 [05:53<01:27,  1.45s/it, loss=0.374, v_num=45, validation_loss=0.289, validation_f1=0.925]
Epoch 2:  81%|████████  | 244/303 [05:53<01:25,  1.45s/it, loss=0.374, v_num=45, validation_loss=0.289, validation_f1=0.925]
Epoch 2:  81%|████████  | 245/303 [05:54<01:23,  1.45s/it, loss=0.374, v_num=45, validation_loss=0.289, validation_f1=0.925]
Epoch 2:  81%|████████  | 246/303 [05:54<01:22,  1.44s/it, loss=0.374, v_num=45, validation_loss=0.289, validation_f1=0.925]
Epoch 2:  82%|████████▏ | 247/303 [05:55<01:20,  1.44s/it, loss=0.374, v_num=45, validation_los

***** Test results *****
training_f1 = tensor(0.8919, device='cuda:0')

training_loss = tensor(0.3382, device='cuda:0')

validation_f1 = tensor(0.9320, device='cuda:0')

validation_loss = tensor(0.2572, device='cuda:0')



Epoch 2: 100%|██████████| 303/303 [06:23<00:00,  1.27s/it, loss=0.374, v_num=45, validation_loss=0.257, validation_f1=0.932]
Epoch 3:  80%|███████▉  | 242/303 [05:52<01:28,  1.46s/it, loss=0.309, v_num=45, validation_loss=0.257, validation_f1=0.932]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/61 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/61 [00:00<?, ?it/s][A
Epoch 3:  80%|████████  | 243/303 [05:54<01:27,  1.46s/it, loss=0.309, v_num=45, validation_loss=0.257, validation_f1=0.932]
Epoch 3:  81%|████████  | 244/303 [05:54<01:25,  1.45s/it, loss=0.309, v_num=45, validation_loss=0.257, validation_f1=0.932]
Epoch 3:  81%|████████  | 245/303 [05:55<01:24,  1.45s/it, loss=0.309, v_num=45, validation_loss=0.257, validation_f1=0.932]
Epoch 3:  81%|████████  | 246/303 [05:55<01:22,  1.45s/it, loss=0.309, v_num=45, validation_loss=0.257, validation_f1=0.932]
Epoch 3:  82%|████████▏ | 247/303 [05:56<01:20,  1.44s/it, loss=0.309, v_num=45, validation_los

***** Test results *****
training_f1 = tensor(0.8986, device='cuda:0')

training_loss = tensor(0.3333, device='cuda:0')

validation_f1 = tensor(0.9365, device='cuda:0')

validation_loss = tensor(0.2379, device='cuda:0')



Epoch 3: 100%|██████████| 303/303 [06:24<00:00,  1.27s/it, loss=0.309, v_num=45, validation_loss=0.238, validation_f1=0.937]
Epoch 4:  80%|███████▉  | 242/303 [05:50<01:28,  1.45s/it, loss=0.254, v_num=45, validation_loss=0.238, validation_f1=0.937]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/61 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/61 [00:00<?, ?it/s][A
Epoch 4:  80%|████████  | 243/303 [05:52<01:27,  1.45s/it, loss=0.254, v_num=45, validation_loss=0.238, validation_f1=0.937]
Epoch 4:  81%|████████  | 244/303 [05:52<01:25,  1.45s/it, loss=0.254, v_num=45, validation_loss=0.238, validation_f1=0.937]
Epoch 4:  81%|████████  | 245/303 [05:53<01:23,  1.44s/it, loss=0.254, v_num=45, validation_loss=0.238, validation_f1=0.937]
Epoch 4:  81%|████████  | 246/303 [05:53<01:22,  1.44s/it, loss=0.254, v_num=45, validation_loss=0.238, validation_f1=0.937]
Epoch 4:  82%|████████▏ | 247/303 [05:54<01:20,  1.44s/it, loss=0.254, v_num=45, validation_los

***** Test results *****
training_f1 = tensor(0.9189, device='cuda:0')

training_loss = tensor(0.3159, device='cuda:0')

validation_f1 = tensor(0.9361, device='cuda:0')

validation_loss = tensor(0.2353, device='cuda:0')



Epoch 4: 100%|██████████| 303/303 [06:23<00:00,  1.26s/it, loss=0.254, v_num=45, validation_loss=0.235, validation_f1=0.936]
Epoch 5:  80%|███████▉  | 242/303 [05:50<01:28,  1.45s/it, loss=0.247, v_num=45, validation_loss=0.235, validation_f1=0.936]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/61 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/61 [00:00<?, ?it/s][A
Epoch 5:  80%|████████  | 243/303 [05:52<01:27,  1.45s/it, loss=0.247, v_num=45, validation_loss=0.235, validation_f1=0.936]
Epoch 5:  81%|████████  | 244/303 [05:53<01:25,  1.45s/it, loss=0.247, v_num=45, validation_loss=0.235, validation_f1=0.936]
Epoch 5:  81%|████████  | 245/303 [05:53<01:23,  1.44s/it, loss=0.247, v_num=45, validation_loss=0.235, validation_f1=0.936]
Epoch 5:  81%|████████  | 246/303 [05:54<01:22,  1.44s/it, loss=0.247, v_num=45, validation_loss=0.235, validation_f1=0.936]
Epoch 5:  82%|████████▏ | 247/303 [05:54<01:20,  1.44s/it, loss=0.247, v_num=45, validation_los

***** Test results *****
training_f1 = tensor(0.9324, device='cuda:0')

training_loss = tensor(0.1879, device='cuda:0')

validation_f1 = tensor(0.9397, device='cuda:0')

validation_loss = tensor(0.2287, device='cuda:0')



Epoch 5: 100%|██████████| 303/303 [06:23<00:00,  1.27s/it, loss=0.247, v_num=45, validation_loss=0.229, validation_f1=0.940]
Epoch 6:  80%|███████▉  | 242/303 [05:50<01:28,  1.45s/it, loss=0.221, v_num=45, validation_loss=0.229, validation_f1=0.940]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/61 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/61 [00:00<?, ?it/s][A
Epoch 6:  80%|████████  | 243/303 [05:52<01:26,  1.45s/it, loss=0.221, v_num=45, validation_loss=0.229, validation_f1=0.940]
Epoch 6:  81%|████████  | 244/303 [05:52<01:25,  1.45s/it, loss=0.221, v_num=45, validation_loss=0.229, validation_f1=0.940]
Epoch 6:  81%|████████  | 245/303 [05:53<01:23,  1.44s/it, loss=0.221, v_num=45, validation_loss=0.229, validation_f1=0.940]
Epoch 6:  81%|████████  | 246/303 [05:53<01:21,  1.44s/it, loss=0.221, v_num=45, validation_loss=0.229, validation_f1=0.940]
Epoch 6:  82%|████████▏ | 247/303 [05:54<01:20,  1.43s/it, loss=0.221, v_num=45, validation_los

***** Test results *****
training_f1 = tensor(0.9527, device='cuda:0')

training_loss = tensor(0.2113, device='cuda:0')

validation_f1 = tensor(0.9402, device='cuda:0')

validation_loss = tensor(0.2318, device='cuda:0')



Epoch 6: 100%|██████████| 303/303 [06:22<00:00,  1.26s/it, loss=0.221, v_num=45, validation_loss=0.232, validation_f1=0.940]
Epoch 7:  80%|███████▉  | 242/303 [05:50<01:28,  1.45s/it, loss=0.195, v_num=45, validation_loss=0.232, validation_f1=0.940]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/61 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/61 [00:00<?, ?it/s][A
Epoch 7:  80%|████████  | 243/303 [05:52<01:27,  1.45s/it, loss=0.195, v_num=45, validation_loss=0.232, validation_f1=0.940]
Epoch 7:  81%|████████  | 244/303 [05:53<01:25,  1.45s/it, loss=0.195, v_num=45, validation_loss=0.232, validation_f1=0.940]
Epoch 7:  81%|████████  | 245/303 [05:53<01:23,  1.44s/it, loss=0.195, v_num=45, validation_loss=0.232, validation_f1=0.940]
Epoch 7:  81%|████████  | 246/303 [05:54<01:22,  1.44s/it, loss=0.195, v_num=45, validation_loss=0.232, validation_f1=0.940]
Epoch 7:  82%|████████▏ | 247/303 [05:54<01:20,  1.44s/it, loss=0.195, v_num=45, validation_los

***** Test results *****
training_f1 = tensor(0.9324, device='cuda:0')

training_loss = tensor(0.2270, device='cuda:0')

validation_f1 = tensor(0.9383, device='cuda:0')

validation_loss = tensor(0.2275, device='cuda:0')



Epoch 7: 100%|██████████| 303/303 [06:23<00:00,  1.26s/it, loss=0.195, v_num=45, validation_loss=0.227, validation_f1=0.938]
Epoch 8:  80%|███████▉  | 242/303 [05:51<01:28,  1.45s/it, loss=0.176, v_num=45, validation_loss=0.227, validation_f1=0.938]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/61 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/61 [00:00<?, ?it/s][A
Epoch 8:  80%|████████  | 243/303 [05:53<01:27,  1.45s/it, loss=0.176, v_num=45, validation_loss=0.227, validation_f1=0.938]
Epoch 8:  81%|████████  | 244/303 [05:53<01:25,  1.45s/it, loss=0.176, v_num=45, validation_loss=0.227, validation_f1=0.938]
Epoch 8:  81%|████████  | 245/303 [05:54<01:23,  1.45s/it, loss=0.176, v_num=45, validation_loss=0.227, validation_f1=0.938]
Epoch 8:  81%|████████  | 246/303 [05:54<01:22,  1.44s/it, loss=0.176, v_num=45, validation_loss=0.227, validation_f1=0.938]
Epoch 8:  82%|████████▏ | 247/303 [05:55<01:20,  1.44s/it, loss=0.176, v_num=45, validation_los

***** Test results *****
training_f1 = tensor(0.9392, device='cuda:0')

training_loss = tensor(0.1702, device='cuda:0')

validation_f1 = tensor(0.9424, device='cuda:0')

validation_loss = tensor(0.2217, device='cuda:0')



Epoch 8: 100%|██████████| 303/303 [06:23<00:00,  1.27s/it, loss=0.176, v_num=45, validation_loss=0.222, validation_f1=0.942]
Epoch 9:  80%|███████▉  | 242/303 [05:50<01:28,  1.45s/it, loss=0.184, v_num=45, validation_loss=0.222, validation_f1=0.942]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/61 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/61 [00:00<?, ?it/s][A
Epoch 9:  80%|████████  | 243/303 [05:52<01:27,  1.45s/it, loss=0.184, v_num=45, validation_loss=0.222, validation_f1=0.942]
Epoch 9:  81%|████████  | 244/303 [05:53<01:25,  1.45s/it, loss=0.184, v_num=45, validation_loss=0.222, validation_f1=0.942]
Epoch 9:  81%|████████  | 245/303 [05:53<01:23,  1.44s/it, loss=0.184, v_num=45, validation_loss=0.222, validation_f1=0.942]
Epoch 9:  81%|████████  | 246/303 [05:54<01:22,  1.44s/it, loss=0.184, v_num=45, validation_loss=0.222, validation_f1=0.942]
Epoch 9:  82%|████████▏ | 247/303 [05:54<01:20,  1.44s/it, loss=0.184, v_num=45, validation_los

***** Test results *****
training_f1 = tensor(0.9595, device='cuda:0')

training_loss = tensor(0.1238, device='cuda:0')

validation_f1 = tensor(0.9340, device='cuda:0')

validation_loss = tensor(0.2518, device='cuda:0')



Epoch 9: 100%|██████████| 303/303 [06:23<00:00,  1.26s/it, loss=0.184, v_num=45, validation_loss=0.252, validation_f1=0.934]
Epoch 10:  80%|███████▉  | 242/303 [05:50<01:28,  1.45s/it, loss=0.152, v_num=45, validation_loss=0.252, validation_f1=0.934]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/61 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/61 [00:00<?, ?it/s][A
Epoch 10:  80%|████████  | 243/303 [05:52<01:27,  1.45s/it, loss=0.152, v_num=45, validation_loss=0.252, validation_f1=0.934]
Epoch 10:  81%|████████  | 244/303 [05:53<01:25,  1.45s/it, loss=0.152, v_num=45, validation_loss=0.252, validation_f1=0.934]
Epoch 10:  81%|████████  | 245/303 [05:53<01:23,  1.44s/it, loss=0.152, v_num=45, validation_loss=0.252, validation_f1=0.934]
Epoch 10:  81%|████████  | 246/303 [05:54<01:22,  1.44s/it, loss=0.152, v_num=45, validation_loss=0.252, validation_f1=0.934]
Epoch 10:  82%|████████▏ | 247/303 [05:54<01:20,  1.44s/it, loss=0.152, v_num=45, validati

***** Test results *****
training_f1 = tensor(0.9324, device='cuda:0')

training_loss = tensor(0.1622, device='cuda:0')

validation_f1 = tensor(0.9398, device='cuda:0')

validation_loss = tensor(0.2306, device='cuda:0')



Epoch 10: 100%|██████████| 303/303 [06:23<00:00,  1.27s/it, loss=0.152, v_num=45, validation_loss=0.231, validation_f1=0.940]
Epoch 11:  80%|███████▉  | 242/303 [05:50<01:28,  1.45s/it, loss=0.135, v_num=45, validation_loss=0.231, validation_f1=0.940]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/61 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/61 [00:00<?, ?it/s][A
Epoch 11:  80%|████████  | 243/303 [05:52<01:27,  1.45s/it, loss=0.135, v_num=45, validation_loss=0.231, validation_f1=0.940]
Epoch 11:  81%|████████  | 244/303 [05:53<01:25,  1.45s/it, loss=0.135, v_num=45, validation_loss=0.231, validation_f1=0.940]
Epoch 11:  81%|████████  | 245/303 [05:53<01:23,  1.44s/it, loss=0.135, v_num=45, validation_loss=0.231, validation_f1=0.940]
Epoch 11:  81%|████████  | 246/303 [05:54<01:22,  1.44s/it, loss=0.135, v_num=45, validation_loss=0.231, validation_f1=0.940]
Epoch 11:  82%|████████▏ | 247/303 [05:54<01:20,  1.44s/it, loss=0.135, v_num=45, validat

***** Test results *****
training_f1 = tensor(0.9595, device='cuda:0')

training_loss = tensor(0.1485, device='cuda:0')

validation_f1 = tensor(0.9385, device='cuda:0')

validation_loss = tensor(0.2394, device='cuda:0')



Epoch 11: 100%|██████████| 303/303 [06:23<00:00,  1.26s/it, loss=0.135, v_num=45, validation_loss=0.239, validation_f1=0.939]
Epoch 11: 100%|██████████| 303/303 [06:23<00:00,  1.26s/it, loss=0.135, v_num=45, validation_loss=0.239, validation_f1=0.939]

`Trainer.fit` stopped: `max_epochs=12` reached.


Epoch 11: 100%|██████████| 303/303 [06:25<00:00,  1.27s/it, loss=0.135, v_num=45, validation_loss=0.239, validation_f1=0.939]


In [11]:
preds = []
model = model.to('cuda')
for batch in tqdm(test_loader):
    labels = [
        train_dataset.id2label[label_id]
        for label_id in 
        torch.argmax(model(batch.to('cuda')), dim=1).tolist()
    ]
    preds.extend(labels)
    
test_pd['label'] = preds
test_pd.to_csv('submission.csv', index=False)

100%|██████████| 130/130 [01:07<00:00,  1.94it/s]
