In [1]:
import torch
import numpy as np
import PIL

print(torch.cuda.is_available())

True


In [2]:
!nvidia-smi

Sat Jul 27 21:55:17 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.29.06              Driver Version: 545.29.06    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3080 ...    Off | 00000000:01:00.0  On |                  N/A |
| N/A   52C    P5              22W / 125W |    174MiB / 16384MiB |     41%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [3]:
import pickle
import numpy as np
from skimage import io

from tqdm import tqdm, tqdm_notebook
from PIL import Image
from pathlib import Path

from torchvision import transforms, models, datasets
from multiprocessing.pool import ThreadPool
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

from matplotlib import colors, pyplot as plt
%matplotlib inline

In [4]:
DATA_MODES = ['train', 'val', 'test']
RESCALE_SIZE = 224
DEVICE = torch.device("cuda")

In [5]:
train_transforms = transforms.Compose([
    transforms.Resize((RESCALE_SIZE, RESCALE_SIZE)),
    #transforms.AutoAugment(),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x / 255),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

In [6]:
train_dir = "./train/simpsons_dataset"

In [7]:
import math

image_datasets = datasets.ImageFolder(train_dir, train_transforms)

In [8]:
dataset_len = len(image_datasets)
p1 = (dataset_len / 100)
train_size = math.floor(p1 * 70)
test_size = dataset_len - train_size

In [9]:
train_dataset, val_dataset = torch.utils.data.random_split(image_datasets, [train_size, test_size])

In [10]:
len(train_dataset), len(val_dataset)

(14653, 6280)

In [11]:
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=32,
    shuffle=True
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=128
)

loaders = {
    "train": train_dataloader,
    "valid": val_dataloader
}

In [12]:
model_resnet18 = models.resnet18(pretrained=True)
model_resnet18



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [13]:
for param in model_resnet18.parameters():
    param.requires_grad = False

In [14]:
for param in model_resnet18.avgpool.parameters():
    param.requires_grad = True
    
for param in model_resnet18.layer4.parameters():
    param.requires_grad = True

for param in model_resnet18.layer3.parameters():
    param.requires_grad = True

for param in model_resnet18.layer2.parameters():
    param.requires_grad = True

In [15]:
num_classes = 42
model_resnet18.fc = nn.Sequential(
    nn.Sequential(
        nn.Linear(512, 512),
        nn.BatchNorm1d(512),
        nn.Dropout(0.5),
        nn.ReLU()
    ),
    nn.Sequential(
        nn.Linear(512, 512),
        nn.Linear(512, num_classes)
    )
) 

In [16]:
import torch.optim as optim

optimizer = optim.Adam(
    (
        {
            "params": model_resnet18.layer2.parameters(),
            "lr": 1e-5
        },
        {
            "params": model_resnet18.layer3.parameters(),
            "lr": 1e-4
        },
        {
            "params": model_resnet18.layer4.parameters(),
            "lr": 1e-3
        },
        {
            "params": model_resnet18.avgpool.parameters()
        },
        {
            "params": model_resnet18.fc.parameters()
        }
    ), lr=1e-2)
scheduler = optim.lr_scheduler.StepLR(optimizer, 5, gamma=0.5)

In [17]:
!rm -rf logs

In [18]:
from catalyst import dl
        
runner = dl.SupervisedRunner()

In [19]:
runner.train(
    model=model_resnet18,
    engine=dl.GPUEngine("cuda"),
    optimizer=optimizer,
    criterion=nn.CrossEntropyLoss(),
    scheduler=scheduler,
    callbacks=[
        dl.CriterionCallback(input_key="logits", target_key="targets", metric_key="loss"),
        dl.BackwardCallback(metric_key="loss"),
        dl.OptimizerCallback(metric_key="loss"), 
        dl.AccuracyCallback(input_key="logits", target_key="targets"),
        dl.SchedulerCallback(),
        dl.PrecisionRecallF1SupportCallback(
            input_key="logits", target_key="targets", num_classes=num_classes, log_on_batch=False
        )
    ],
    loaders=loaders,
    num_epochs=30,
    verbose=True,
    logdir="logs/resnet18",
    load_best_on_end=True,
)

1/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (1/30) accuracy01: 0.4448918309836864 | accuracy01/std: 0.14677202263423728 | f1/_macro: 0.1911559226572267 | f1/_micro: 0.44488683108055643 | f1/_weighted: 0.4279054927507017 | loss: 2.6606458742236385 | loss/mean: 2.6606458742236385 | loss/std: 1.3747384786155932 | lr: 1e-05 | momentum: 0.9 | precision/_macro: 0.19780175630811855 | precision/_micro: 0.4448918310243636 | precision/_weighted: 0.420551177039226 | recall/_macro: 0.19384639480145335 | recall/_micro: 0.4448918310243636 | recall/_weighted: 0.44489183102436364




1/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (1/30) accuracy01: 0.637579617834395 | accuracy01/std: 0.039323259169996655 | f1/_macro: 0.27683675651641576 | f1/_micro: 0.6375746178736054 | f1/_weighted: 0.6118302934678506 | loss: 1.4510923236798328 | loss/mean: 1.4510923236798328 | loss/std: 0.1769923214046265 | lr: 1e-05 | momentum: 0.9 | precision/_macro: 0.306965687814668 | precision/_micro: 0.6375796178343949 | precision/_weighted: 0.6355023066597371 | recall/_macro: 0.2804407909988929 | recall/_micro: 0.6375796178343949 | recall/_weighted: 0.6375796178343949
* Epoch (1/30) lr: 1e-05 | momentum: 0.9


2/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (2/30) accuracy01: 0.7004026479463334 | accuracy01/std: 0.08663286755799994 | f1/_macro: 0.32180926834043166 | f1/_micro: 0.7003976479576207 | f1/_weighted: 0.6807226344165431 | loss: 1.2084862982495406 | loss/mean: 1.2084862982495406 | loss/std: 0.3704791779336468 | lr: 1e-05 | momentum: 0.9 | precision/_macro: 0.3467623137157684 | precision/_micro: 0.7004026479219272 | precision/_weighted: 0.6730897417815036 | recall/_macro: 0.3225186934787151 | recall/_micro: 0.7004026479219272 | recall/_weighted: 0.7004026479219273


2/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (2/30) accuracy01: 0.7722929936305729 | accuracy01/std: 0.03570325239415732 | f1/_macro: 0.35799632565968265 | f1/_micro: 0.7722879936629442 | f1/_weighted: 0.7476571550102555 | loss: 0.9281579757192332 | loss/mean: 0.9281579757192332 | loss/std: 0.1370279412988915 | lr: 1e-05 | momentum: 0.9 | precision/_macro: 0.36682927840791985 | precision/_micro: 0.7722929936305732 | precision/_weighted: 0.7437155892647882 | recall/_macro: 0.3705804482090539 | recall/_micro: 0.7722929936305732 | recall/_weighted: 0.7722929936305732
* Epoch (2/30) lr: 1e-05 | momentum: 0.9


3/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (3/30) accuracy01: 0.7614822902590278 | accuracy01/std: 0.07760568308255143 | f1/_macro: 0.3719757650452519 | f1/_micro: 0.7614772903488067 | f1/_weighted: 0.7438596319422691 | loss: 0.9667829228799634 | loss/mean: 0.9667829228799634 | loss/std: 0.3554976115399556 | lr: 1e-05 | momentum: 0.9 | precision/_macro: 0.40177367136803904 | precision/_micro: 0.7614822903159763 | precision/_weighted: 0.7338963197916293 | recall/_macro: 0.36935938366204113 | recall/_micro: 0.7614822903159763 | recall/_weighted: 0.7614822903159761


3/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (3/30) accuracy01: 0.8001592356687894 | accuracy01/std: 0.030457677732822083 | f1/_macro: 0.4047926024075301 | f1/_micro: 0.8001542357000333 | f1/_weighted: 0.780242021519 | loss: 0.8367080917783605 | loss/mean: 0.8367080917783605 | loss/std: 0.1541323337325265 | lr: 1e-05 | momentum: 0.9 | precision/_macro: 0.4188086510597948 | precision/_micro: 0.8001592356687898 | precision/_weighted: 0.7792693163264809 | recall/_macro: 0.4102487603222193 | recall/_micro: 0.8001592356687898 | recall/_weighted: 0.8001592356687899
* Epoch (3/30) lr: 1e-05 | momentum: 0.9


4/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (4/30) accuracy01: 0.7839350303488686 | accuracy01/std: 0.07808917432416214 | f1/_macro: 0.3914950234590963 | f1/_micro: 0.7839300304010979 | f1/_weighted: 0.7686412834666371 | loss: 0.8772319718521134 | loss/mean: 0.8772319718521134 | loss/std: 0.3657857115932357 | lr: 1e-05 | momentum: 0.9 | precision/_macro: 0.4050248175626693 | precision/_micro: 0.7839350303692076 | precision/_weighted: 0.7579771654330467 | recall/_macro: 0.39152808424461116 | recall/_micro: 0.7839350303692076 | recall/_weighted: 0.7839350303692076


4/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (4/30) accuracy01: 0.7952229299363055 | accuracy01/std: 0.03252172365741848 | f1/_macro: 0.4059945380009533 | f1/_micro: 0.7952179299677432 | f1/_weighted: 0.7751712079461324 | loss: 0.8419680558951796 | loss/mean: 0.8419680558951796 | loss/std: 0.16502773920889544 | lr: 1e-05 | momentum: 0.9 | precision/_macro: 0.4626774082488154 | precision/_micro: 0.7952229299363057 | precision/_weighted: 0.7832506560239892 | recall/_macro: 0.39733357089851495 | recall/_micro: 0.7952229299363057 | recall/_weighted: 0.7952229299363057
* Epoch (4/30) lr: 1e-05 | momentum: 0.9


5/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (5/30) accuracy01: 0.8110967037629443 | accuracy01/std: 0.06960265883486108 | f1/_macro: 0.4188635867821601 | f1/_micro: 0.8110917037774952 | f1/_weighted: 0.7978873709953579 | loss: 0.7732692798323124 | loss/mean: 0.7732692798323124 | loss/std: 0.32567796796643217 | lr: 1e-05 | momentum: 0.9 | precision/_macro: 0.43133413349090804 | precision/_micro: 0.811096703746673 | precision/_weighted: 0.7884314185289317 | recall/_macro: 0.4187143816262151 | recall/_micro: 0.811096703746673 | recall/_weighted: 0.8110967037466731


5/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (5/30) accuracy01: 0.8277070063694266 | accuracy01/std: 0.030611094089198303 | f1/_macro: 0.4407527605185924 | f1/_micro: 0.8277020063996304 | f1/_weighted: 0.8081077697388943 | loss: 0.7938380826051069 | loss/mean: 0.7938380826051069 | loss/std: 0.1746707383438919 | lr: 1e-05 | momentum: 0.9 | precision/_macro: 0.48164908958917974 | precision/_micro: 0.8277070063694267 | precision/_weighted: 0.8136741181866326 | recall/_macro: 0.44275999820731105 | recall/_micro: 0.8277070063694267 | recall/_weighted: 0.8277070063694268
* Epoch (5/30) lr: 5e-06 | momentum: 0.9


6/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (6/30) accuracy01: 0.8788643964140893 | accuracy01/std: 0.05872454790386848 | f1/_macro: 0.49468998929387 | f1/_micro: 0.8788593963977896 | f1/_weighted: 0.8702452508511422 | loss: 0.46696502599164263 | loss/mean: 0.46696502599164263 | loss/std: 0.24005470653700822 | lr: 5e-06 | momentum: 0.9 | precision/_macro: 0.5153176367600328 | precision/_micro: 0.8788643963693441 | precision/_weighted: 0.8657079621049218 | recall/_macro: 0.4930900230149663 | recall/_micro: 0.8788643963693441 | recall/_weighted: 0.8788643963693441


6/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (6/30) accuracy01: 0.8595541401273884 | accuracy01/std: 0.03024542361025914 | f1/_macro: 0.5054898344239882 | f1/_micro: 0.8595491401564731 | f1/_weighted: 0.8488836795320549 | loss: 0.5728745045175978 | loss/mean: 0.5728745045175978 | loss/std: 0.13463967324704218 | lr: 5e-06 | momentum: 0.9 | precision/_macro: 0.6029447107407506 | precision/_micro: 0.8595541401273885 | precision/_weighted: 0.8602684787654292 | recall/_macro: 0.5015857544557191 | recall/_micro: 0.8595541401273885 | recall/_weighted: 0.8595541401273885
* Epoch (6/30) lr: 5e-06 | momentum: 0.9


7/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (7/30) accuracy01: 0.9056165972553581 | accuracy01/std: 0.05207817957650172 | f1/_macro: 0.5529617055913012 | f1/_micro: 0.9056115973114379 | f1/_weighted: 0.9007966979272346 | loss: 0.3524534701945097 | loss/mean: 0.3524534701945097 | loss/std: 0.1907217288753389 | lr: 5e-06 | momentum: 0.9 | precision/_macro: 0.5724892434191429 | precision/_micro: 0.9056165972838327 | precision/_weighted: 0.8978600208048318 | recall/_macro: 0.5510536604912617 | recall/_micro: 0.9056165972838327 | recall/_weighted: 0.9056165972838326


7/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (7/30) accuracy01: 0.8702229299363057 | accuracy01/std: 0.03201273063101741 | f1/_macro: 0.52414514327266 | f1/_micro: 0.8702179299650338 | f1/_weighted: 0.8642592293894715 | loss: 0.574849543753703 | loss/mean: 0.574849543753703 | loss/std: 0.18331863363916304 | lr: 5e-06 | momentum: 0.9 | precision/_macro: 0.5865716689636156 | precision/_micro: 0.8702229299363058 | precision/_weighted: 0.8747022231386662 | recall/_macro: 0.5400994394036156 | recall/_micro: 0.8702229299363058 | recall/_weighted: 0.8702229299363058
* Epoch (7/30) lr: 5e-06 | momentum: 0.9


8/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (8/30) accuracy01: 0.9155804271877959 | accuracy01/std: 0.05056901356588731 | f1/_macro: 0.5964754389442628 | f1/_micro: 0.9155754272435747 | f1/_weighted: 0.9129783899979603 | loss: 0.31372888569399726 | loss/mean: 0.31372888569399726 | loss/std: 0.2015185412173847 | lr: 5e-06 | momentum: 0.9 | precision/_macro: 0.6115867885944569 | precision/_micro: 0.9155804272162698 | precision/_weighted: 0.9114074601533458 | recall/_macro: 0.5908511196035398 | recall/_micro: 0.9155804272162698 | recall/_weighted: 0.9155804272162696


8/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (8/30) accuracy01: 0.8850318471337583 | accuracy01/std: 0.025633574255119602 | f1/_macro: 0.573915500884687 | f1/_micro: 0.8850268471620053 | f1/_weighted: 0.8840757486587915 | loss: 0.5290364944820952 | loss/mean: 0.5290364944820952 | loss/std: 0.1893980352997739 | lr: 5e-06 | momentum: 0.9 | precision/_macro: 0.6039053665309206 | precision/_micro: 0.8850318471337579 | precision/_weighted: 0.8924828438474673 | recall/_macro: 0.5863106419146761 | recall/_micro: 0.8850318471337579 | recall/_weighted: 0.8850318471337579
* Epoch (8/30) lr: 5e-06 | momentum: 0.9


9/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (9/30) accuracy01: 0.9202893605120293 | accuracy01/std: 0.04609897873806104 | f1/_macro: 0.5949783309079967 | f1/_micro: 0.9202843605676689 | f1/_weighted: 0.9177842060972197 | loss: 0.30161761993179154 | loss/mean: 0.30161761993179154 | loss/std: 0.197111770576226 | lr: 5e-06 | momentum: 0.9 | precision/_macro: 0.6012697433506153 | precision/_micro: 0.9202893605405037 | precision/_weighted: 0.9160251806432901 | recall/_macro: 0.5953756442581858 | recall/_micro: 0.9202893605405037 | recall/_weighted: 0.9202893605405037


9/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (9/30) accuracy01: 0.8866242038216559 | accuracy01/std: 0.027057021950984275 | f1/_macro: 0.5818172264966124 | f1/_micro: 0.8866192038498527 | f1/_weighted: 0.8873051824962483 | loss: 0.5209451169724676 | loss/mean: 0.5209451169724676 | loss/std: 0.19780341145025299 | lr: 5e-06 | momentum: 0.9 | precision/_macro: 0.5956717950057346 | precision/_micro: 0.8866242038216561 | precision/_weighted: 0.8947710977202016 | recall/_macro: 0.6010759158973215 | recall/_micro: 0.8866242038216561 | recall/_weighted: 0.886624203821656
* Epoch (9/30) lr: 5e-06 | momentum: 0.9


10/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (10/30) accuracy01: 0.9286835460396653 | accuracy01/std: 0.045039302779904616 | f1/_macro: 0.6252622944059885 | f1/_micro: 0.928678546058449 | f1/_weighted: 0.9273920606615556 | loss: 0.27612232067499887 | loss/mean: 0.27612232067499887 | loss/std: 0.2179660575199515 | lr: 5e-06 | momentum: 0.9 | precision/_macro: 0.6307079175524677 | precision/_micro: 0.9286835460315294 | precision/_weighted: 0.9267373839103903 | recall/_macro: 0.6247236655940496 | recall/_micro: 0.9286835460315294 | recall/_weighted: 0.9286835460315295


10/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (10/30) accuracy01: 0.8808917197452227 | accuracy01/std: 0.024055891512390983 | f1/_macro: 0.584725050595587 | f1/_micro: 0.8808867197736031 | f1/_weighted: 0.8808678339104173 | loss: 0.5546746102867611 | loss/mean: 0.5546746102867611 | loss/std: 0.20574337724748126 | lr: 5e-06 | momentum: 0.9 | precision/_macro: 0.6136407150872003 | precision/_micro: 0.880891719745223 | precision/_weighted: 0.8922249728740159 | recall/_macro: 0.6026377263572491 | recall/_micro: 0.880891719745223 | recall/_weighted: 0.880891719745223
* Epoch (10/30) lr: 2.5e-06 | momentum: 0.9


11/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (11/30) accuracy01: 0.9538661024679966 | accuracy01/std: 0.03746571645995389 | f1/_macro: 0.6815031923789538 | f1/_micro: 0.9538611025308156 | f1/_weighted: 0.9523288761808419 | loss: 0.1688031548532828 | loss/mean: 0.1688031548532828 | loss/std: 0.13418858347289453 | lr: 2.5e-06 | momentum: 0.9 | precision/_macro: 0.6935938329009661 | precision/_micro: 0.9538661025046066 | precision/_weighted: 0.9517768375655162 | recall/_macro: 0.6829202702970076 | recall/_micro: 0.9538661025046066 | recall/_weighted: 0.9538661025046067


11/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (11/30) accuracy01: 0.9039808917197453 | accuracy01/std: 0.021481046831808355 | f1/_macro: 0.6598736684408615 | f1/_micro: 0.9039758917474005 | f1/_weighted: 0.9051197915156549 | loss: 0.4666282222529126 | loss/mean: 0.4666282222529126 | loss/std: 0.17648229759203587 | lr: 2.5e-06 | momentum: 0.9 | precision/_macro: 0.6886496629703668 | precision/_micro: 0.9039808917197452 | precision/_weighted: 0.9130474353641488 | recall/_macro: 0.6887355476926589 | recall/_micro: 0.9039808917197452 | recall/_weighted: 0.9039808917197452
* Epoch (11/30) lr: 2.5e-06 | momentum: 0.9


12/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (12/30) accuracy01: 0.9653995769232 | accuracy01/std: 0.033240314999671644 | f1/_macro: 0.7280321120350334 | f1/_micro: 0.9653945769043508 | f1/_weighted: 0.9645941174151106 | loss: 0.12436791843361143 | loss/mean: 0.12436791843361143 | loss/std: 0.13340025514973403 | lr: 2.5e-06 | momentum: 0.9 | precision/_macro: 0.7312849641514668 | precision/_micro: 0.9653995768784549 | precision/_weighted: 0.9641662106113358 | recall/_macro: 0.7289413377410069 | recall/_micro: 0.9653995768784549 | recall/_weighted: 0.9653995768784549


12/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (12/30) accuracy01: 0.8998407643312103 | accuracy01/std: 0.021628775095706654 | f1/_macro: 0.6493489824312219 | f1/_micro: 0.8998357643589927 | f1/_weighted: 0.903655998045739 | loss: 0.501487462156138 | loss/mean: 0.501487462156138 | loss/std: 0.2052343637862714 | lr: 2.5e-06 | momentum: 0.9 | precision/_macro: 0.6549254994774136 | precision/_micro: 0.8998407643312102 | precision/_weighted: 0.9121779910439116 | recall/_macro: 0.6676251088080933 | recall/_micro: 0.8998407643312102 | recall/_weighted: 0.8998407643312101
* Epoch (12/30) lr: 2.5e-06 | momentum: 0.9


13/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (13/30) accuracy01: 0.965740803975681 | accuracy01/std: 0.03226409009794979 | f1/_macro: 0.731945839047412 | f1/_micro: 0.9657358039568223 | f1/_weighted: 0.9654230323057676 | loss: 0.11864323751830724 | loss/mean: 0.11864323751830724 | loss/std: 0.11440381370009967 | lr: 2.5e-06 | momentum: 0.9 | precision/_macro: 0.7302145325448367 | precision/_micro: 0.9657408039309356 | precision/_weighted: 0.9654997899478448 | recall/_macro: 0.7380282622623933 | recall/_micro: 0.9657408039309356 | recall/_weighted: 0.9657408039309356


13/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (13/30) accuracy01: 0.9081210191082801 | accuracy01/std: 0.022755872956619138 | f1/_macro: 0.6671964134181694 | f1/_micro: 0.9081160191358094 | f1/_weighted: 0.9103640024549257 | loss: 0.5209731080349843 | loss/mean: 0.5209731080349843 | loss/std: 0.23840685333719766 | lr: 2.5e-06 | momentum: 0.9 | precision/_macro: 0.6619249004508535 | precision/_micro: 0.9081210191082802 | precision/_weighted: 0.9159937188398988 | recall/_macro: 0.6888950714603735 | recall/_micro: 0.9081210191082802 | recall/_weighted: 0.9081210191082802
* Epoch (13/30) lr: 2.5e-06 | momentum: 0.9


14/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (14/30) accuracy01: 0.9716099092783487 | accuracy01/std: 0.030230429558831196 | f1/_macro: 0.7708109677013235 | f1/_micro: 0.9716049092593344 | f1/_weighted: 0.9711440786690121 | loss: 0.10275294122499289 | loss/mean: 0.10275294122499289 | loss/std: 0.12613669542108336 | lr: 2.5e-06 | momentum: 0.9 | precision/_macro: 0.7733798094208085 | precision/_micro: 0.971609909233604 | precision/_weighted: 0.9710646612772927 | recall/_macro: 0.7726177340771905 | recall/_micro: 0.971609909233604 | recall/_weighted: 0.971609909233604


14/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (14/30) accuracy01: 0.9128980891719748 | accuracy01/std: 0.02103065727646622 | f1/_macro: 0.6952938261502805 | f1/_micro: 0.9128930891993596 | f1/_weighted: 0.9135259186081018 | loss: 0.5754846395200984 | loss/mean: 0.5754846395200984 | loss/std: 0.2663586674011446 | lr: 2.5e-06 | momentum: 0.9 | precision/_macro: 0.7098560479553027 | precision/_micro: 0.9128980891719746 | precision/_weighted: 0.9174896737888891 | recall/_macro: 0.6989202211092616 | recall/_micro: 0.9128980891719746 | recall/_weighted: 0.9128980891719745
* Epoch (14/30) lr: 2.5e-06 | momentum: 0.9


15/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (15/30) accuracy01: 0.9733160445407528 | accuracy01/std: 0.027382083302904932 | f1/_macro: 0.7899474372094162 | f1/_micro: 0.9733110445216929 | f1/_weighted: 0.9732304052858447 | loss: 0.09355656344971558 | loss/mean: 0.09355656344971558 | loss/std: 0.1025081070065209 | lr: 2.5e-06 | momentum: 0.9 | precision/_macro: 0.7880763785645912 | precision/_micro: 0.9733160444960076 | precision/_weighted: 0.97343697674876 | recall/_macro: 0.7959127056331601 | recall/_micro: 0.9733160444960076 | recall/_weighted: 0.9733160444960077


15/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (15/30) accuracy01: 0.9089171974522294 | accuracy01/std: 0.02123111826194083 | f1/_macro: 0.6933998362246638 | f1/_micro: 0.9089121974797344 | f1/_weighted: 0.9097820137882143 | loss: 0.5625003712192463 | loss/mean: 0.5625003712192463 | loss/std: 0.23579978438702956 | lr: 2.5e-06 | momentum: 0.9 | precision/_macro: 0.6993579117611362 | precision/_micro: 0.9089171974522293 | precision/_weighted: 0.9159149176791042 | recall/_macro: 0.7160192272884178 | recall/_micro: 0.9089171974522293 | recall/_weighted: 0.9089171974522293
* Epoch (15/30) lr: 1.25e-06 | momentum: 0.9


16/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (16/30) accuracy01: 0.9832798744284449 | accuracy01/std: 0.021159114123581377 | f1/_macro: 0.8249491995400207 | f1/_micro: 0.9832748744538695 | f1/_weighted: 0.9829314873906514 | loss: 0.05745394000805086 | loss/mean: 0.05745394000805086 | loss/std: 0.07331516887815195 | lr: 1.25e-06 | momentum: 0.9 | precision/_macro: 0.8309013996412085 | precision/_micro: 0.9832798744284447 | precision/_weighted: 0.9828166606447887 | recall/_macro: 0.8245081241445954 | recall/_micro: 0.9832798744284447 | recall/_weighted: 0.9832798744284447


16/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (16/30) accuracy01: 0.9176751592356689 | accuracy01/std: 0.021631413090795273 | f1/_macro: 0.7098610566665169 | f1/_micro: 0.9176701592629114 | f1/_weighted: 0.9181059683384276 | loss: 0.5553680028885035 | loss/mean: 0.5553680028885035 | loss/std: 0.2618591419784949 | lr: 1.25e-06 | momentum: 0.9 | precision/_macro: 0.721530178385869 | precision/_micro: 0.9176751592356688 | precision/_weighted: 0.9219260189453002 | recall/_macro: 0.7267598455355511 | recall/_micro: 0.9176751592356688 | recall/_weighted: 0.9176751592356688
* Epoch (16/30) lr: 1.25e-06 | momentum: 0.9


17/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (17/30) accuracy01: 0.9869651265952372 | accuracy01/std: 0.021726888492638256 | f1/_macro: 0.8649965660750667 | f1/_micro: 0.9869601266205665 | f1/_weighted: 0.986824407419812 | loss: 0.04653604547167374 | loss/mean: 0.04653604547167374 | loss/std: 0.07216950066161829 | lr: 1.25e-06 | momentum: 0.9 | precision/_macro: 0.8741550145466196 | precision/_micro: 0.9869651265952365 | precision/_weighted: 0.9869141156716519 | recall/_macro: 0.8635957178387619 | recall/_micro: 0.9869651265952365 | recall/_weighted: 0.9869651265952365


17/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (17/30) accuracy01: 0.9149681528662421 | accuracy01/std: 0.023895180170560346 | f1/_macro: 0.7082874657415356 | f1/_micro: 0.9149631528935651 | f1/_weighted: 0.9161128152693504 | loss: 0.6084977486330992 | loss/mean: 0.6084977486330992 | loss/std: 0.27566038074984306 | lr: 1.25e-06 | momentum: 0.9 | precision/_macro: 0.7070149685343167 | precision/_micro: 0.914968152866242 | precision/_weighted: 0.9198507736130461 | recall/_macro: 0.7230804424421016 | recall/_micro: 0.914968152866242 | recall/_weighted: 0.9149681528662421
* Epoch (17/30) lr: 1.25e-06 | momentum: 0.9


18/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (18/30) accuracy01: 0.9891489797311137 | accuracy01/std: 0.018344940442365735 | f1/_macro: 0.8810360816788613 | f1/_micro: 0.9891439797563871 | f1/_weighted: 0.9891454937325018 | loss: 0.03486819107412856 | loss/mean: 0.03486819107412856 | loss/std: 0.05739950210894821 | lr: 1.25e-06 | momentum: 0.9 | precision/_macro: 0.8838623758175634 | precision/_micro: 0.989148979731113 | precision/_weighted: 0.9892995018148512 | recall/_macro: 0.8808147567226796 | recall/_micro: 0.989148979731113 | recall/_weighted: 0.9891489797311129


18/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (18/30) accuracy01: 0.9143312101910829 | accuracy01/std: 0.023910829271659318 | f1/_macro: 0.7177156986810477 | f1/_micro: 0.914326210218425 | f1/_weighted: 0.9147720339451957 | loss: 0.6641488531592544 | loss/mean: 0.6641488531592544 | loss/std: 0.3075805038569691 | lr: 1.25e-06 | momentum: 0.9 | precision/_macro: 0.7347520055925458 | precision/_micro: 0.9143312101910828 | precision/_weighted: 0.9182135501512403 | recall/_macro: 0.7155090291103253 | recall/_micro: 0.9143312101910828 | recall/_weighted: 0.9143312101910828
* Epoch (18/30) lr: 1.25e-06 | momentum: 0.9


19/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (19/30) accuracy01: 0.989967924657067 | accuracy01/std: 0.016893491172437707 | f1/_macro: 0.8773809524861814 | f1/_micro: 0.9899629246823201 | f1/_weighted: 0.989772953141002 | loss: 0.036034183810030096 | loss/mean: 0.036034183810030096 | loss/std: 0.06778103645145163 | lr: 1.25e-06 | momentum: 0.9 | precision/_macro: 0.8757288495233434 | precision/_micro: 0.9899679246570668 | precision/_weighted: 0.989669066729868 | recall/_macro: 0.8801768581713373 | recall/_micro: 0.9899679246570668 | recall/_weighted: 0.9899679246570668


19/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (19/30) accuracy01: 0.9143312101910825 | accuracy01/std: 0.02322450989747919 | f1/_macro: 0.7234568020076806 | f1/_micro: 0.914326210218425 | f1/_weighted: 0.9174525113650281 | loss: 0.6168872042066731 | loss/mean: 0.6168872042066731 | loss/std: 0.27659001338984496 | lr: 1.25e-06 | momentum: 0.9 | precision/_macro: 0.7295309862019285 | precision/_micro: 0.9143312101910828 | precision/_weighted: 0.9238106224777664 | recall/_macro: 0.7481896202527519 | recall/_micro: 0.9143312101910828 | recall/_weighted: 0.9143312101910828
* Epoch (19/30) lr: 1.25e-06 | momentum: 0.9


20/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (20/30) accuracy01: 0.9897631884255784 | accuracy01/std: 0.01745568521163819 | f1/_macro: 0.9084890535011423 | f1/_micro: 0.9897581884508369 | f1/_weighted: 0.9898193998590133 | loss: 0.03805083785738291 | loss/mean: 0.03805083785738291 | loss/std: 0.08007878619619666 | lr: 1.25e-06 | momentum: 0.9 | precision/_macro: 0.9070854822750662 | precision/_micro: 0.9897631884255784 | precision/_weighted: 0.9899293103634598 | recall/_macro: 0.9109881332126096 | recall/_micro: 0.9897631884255784 | recall/_weighted: 0.9897631884255783


20/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (20/30) accuracy01: 0.90843949044586 | accuracy01/std: 0.027257763435795834 | f1/_macro: 0.7178613225475372 | f1/_micro: 0.9084344904733795 | f1/_weighted: 0.910698244199759 | loss: 0.6688306294429075 | loss/mean: 0.6688306294429075 | loss/std: 0.27819246975712286 | lr: 1.25e-06 | momentum: 0.9 | precision/_macro: 0.7180796130295833 | precision/_micro: 0.9084394904458599 | precision/_weighted: 0.9156255960631661 | recall/_macro: 0.736668801083157 | recall/_micro: 0.9084394904458599 | recall/_weighted: 0.9084394904458599
* Epoch (20/30) lr: 6.25e-07 | momentum: 0.9


21/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (21/30) accuracy01: 0.9928342318979044 | accuracy01/std: 0.014228605142162216 | f1/_macro: 0.9255993002058821 | f1/_micro: 0.9928292319230851 | f1/_weighted: 0.9927276660841604 | loss: 0.025183909889376514 | loss/mean: 0.025183909889376514 | loss/std: 0.052337033150688834 | lr: 6.25e-07 | momentum: 0.9 | precision/_macro: 0.9313540633067185 | precision/_micro: 0.9928342318979049 | precision/_weighted: 0.9927198464473298 | recall/_macro: 0.9229063638358727 | recall/_micro: 0.9928342318979049 | recall/_weighted: 0.9928342318979049


21/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (21/30) accuracy01: 0.9168789808917197 | accuracy01/std: 0.02364083835256391 | f1/_macro: 0.7284962506063108 | f1/_micro: 0.916873980918986 | f1/_weighted: 0.9195135000846693 | loss: 0.6012752803267948 | loss/mean: 0.6012752803267948 | loss/std: 0.28149315242106343 | lr: 6.25e-07 | momentum: 0.9 | precision/_macro: 0.723838420532281 | precision/_micro: 0.9168789808917197 | precision/_weighted: 0.9249391931156283 | recall/_macro: 0.7522647362780871 | recall/_micro: 0.9168789808917197 | recall/_weighted: 0.9168789808917198
* Epoch (21/30) lr: 6.25e-07 | momentum: 0.9


22/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (22/30) accuracy01: 0.9939944038763395 | accuracy01/std: 0.013471356507744147 | f1/_macro: 0.9178689213723237 | f1/_micro: 0.9939894039014902 | f1/_weighted: 0.9938860772563451 | loss: 0.019686334273263693 | loss/mean: 0.019686334273263693 | loss/std: 0.04168763270041428 | lr: 6.25e-07 | momentum: 0.9 | precision/_macro: 0.9203976083539428 | precision/_micro: 0.9939944038763393 | precision/_weighted: 0.9938322186278503 | recall/_macro: 0.9162679311156717 | recall/_micro: 0.9939944038763393 | recall/_weighted: 0.9939944038763394


22/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (22/30) accuracy01: 0.9213375796178342 | accuracy01/std: 0.021773208218990283 | f1/_macro: 0.7456789966971 | f1/_micro: 0.9213325796449687 | f1/_weighted: 0.9238441037693811 | loss: 0.6057048362531481 | loss/mean: 0.6057048362531481 | loss/std: 0.27886822797284416 | lr: 6.25e-07 | momentum: 0.9 | precision/_macro: 0.7392578994416368 | precision/_micro: 0.9213375796178344 | precision/_weighted: 0.9285154664719364 | recall/_macro: 0.7713789376814103 | recall/_micro: 0.9213375796178344 | recall/_weighted: 0.9213375796178345
* Epoch (22/30) lr: 6.25e-07 | momentum: 0.9


23/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (23/30) accuracy01: 0.9952910666757656 | accuracy01/std: 0.01155508120969964 | f1/_macro: 0.9562598987653217 | f1/_micro: 0.9952860667008842 | f1/_weighted: 0.9952792384718353 | loss: 0.015336740710751142 | loss/mean: 0.015336740710751142 | loss/std: 0.03521051898695285 | lr: 6.25e-07 | momentum: 0.9 | precision/_macro: 0.9645319842128172 | precision/_micro: 0.9952910666757661 | precision/_weighted: 0.9953610616012477 | recall/_macro: 0.9534014424589452 | recall/_micro: 0.9952910666757661 | recall/_weighted: 0.9952910666757662


23/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (23/30) accuracy01: 0.9135350318471339 | accuracy01/std: 0.02413958879624367 | f1/_macro: 0.722652825328196 | f1/_micro: 0.9135300318744998 | f1/_weighted: 0.9175189278516522 | loss: 0.6389931479077432 | loss/mean: 0.6389931479077432 | loss/std: 0.28184972490148014 | lr: 6.25e-07 | momentum: 0.9 | precision/_macro: 0.7280577448932167 | precision/_micro: 0.9135350318471338 | precision/_weighted: 0.9244857646732854 | recall/_macro: 0.7438766484813172 | recall/_micro: 0.9135350318471338 | recall/_weighted: 0.9135350318471338
* Epoch (23/30) lr: 6.25e-07 | momentum: 0.9


24/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (24/30) accuracy01: 0.9954275574967579 | accuracy01/std: 0.013159691266164867 | f1/_macro: 0.9412531650212398 | f1/_micro: 0.995422557521873 | f1/_weighted: 0.9953690976955023 | loss: 0.014965849382317754 | loss/mean: 0.014965849382317754 | loss/std: 0.043956465677315606 | lr: 6.25e-07 | momentum: 0.9 | precision/_macro: 0.9397692460659309 | precision/_micro: 0.9954275574967584 | precision/_weighted: 0.9953538902616796 | recall/_macro: 0.9432386936569528 | recall/_micro: 0.9954275574967584 | recall/_weighted: 0.9954275574967584


24/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (24/30) accuracy01: 0.9149681528662419 | accuracy01/std: 0.023208398025352996 | f1/_macro: 0.7206677675924725 | f1/_micro: 0.9149631528935651 | f1/_weighted: 0.9192967793879251 | loss: 0.6659649683411712 | loss/mean: 0.6659649683411712 | loss/std: 0.2951854972526099 | lr: 6.25e-07 | momentum: 0.9 | precision/_macro: 0.7275343395176466 | precision/_micro: 0.914968152866242 | precision/_weighted: 0.9276520934635915 | recall/_macro: 0.7488728997179009 | recall/_micro: 0.914968152866242 | recall/_weighted: 0.9149681528662421
* Epoch (24/30) lr: 6.25e-07 | momentum: 0.9


25/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (25/30) accuracy01: 0.9959052753702314 | accuracy01/std: 0.010942222003474434 | f1/_macro: 0.9559618704721583 | f1/_micro: 0.995900275395334 | f1/_weighted: 0.995930925086737 | loss: 0.015369953613538804 | loss/mean: 0.015369953613538804 | loss/std: 0.03673922200409467 | lr: 6.25e-07 | momentum: 0.9 | precision/_macro: 0.9537921286665724 | precision/_micro: 0.9959052753702313 | precision/_weighted: 0.9960231909686729 | recall/_macro: 0.960502492069081 | recall/_micro: 0.9959052753702313 | recall/_weighted: 0.9959052753702315


25/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (25/30) accuracy01: 0.9176751592356688 | accuracy01/std: 0.021166256251433507 | f1/_macro: 0.7240001987978694 | f1/_micro: 0.9176701592629114 | f1/_weighted: 0.9206349702803929 | loss: 0.6460245460841306 | loss/mean: 0.6460245460841306 | loss/std: 0.2912056612271297 | lr: 6.25e-07 | momentum: 0.9 | precision/_macro: 0.7292833596091538 | precision/_micro: 0.9176751592356688 | precision/_weighted: 0.9267854379371937 | recall/_macro: 0.7428641487895573 | recall/_micro: 0.9176751592356688 | recall/_weighted: 0.9176751592356688
* Epoch (25/30) lr: 3.125e-07 | momentum: 0.9


26/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (26/30) accuracy01: 0.996860711117177 | accuracy01/std: 0.010052390442224769 | f1/_macro: 0.950732920490717 | f1/_micro: 0.9968557111422559 | f1/_weighted: 0.9967942638607754 | loss: 0.012563153033255035 | loss/mean: 0.012563153033255035 | loss/std: 0.03859749359179221 | lr: 3.125e-07 | momentum: 0.9 | precision/_macro: 0.9623104578159781 | precision/_micro: 0.9968607111171773 | precision/_weighted: 0.9968067440054431 | recall/_macro: 0.945023354903903 | recall/_micro: 0.9968607111171773 | recall/_weighted: 0.9968607111171773


26/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (26/30) accuracy01: 0.9165605095541399 | accuracy01/std: 0.02147193231704691 | f1/_macro: 0.7210371327172522 | f1/_micro: 0.9165555095814159 | f1/_weighted: 0.9201269613683355 | loss: 0.6593658854247658 | loss/mean: 0.6593658854247658 | loss/std: 0.3193752141109012 | lr: 3.125e-07 | momentum: 0.9 | precision/_macro: 0.7198136609642377 | precision/_micro: 0.9165605095541401 | precision/_weighted: 0.9259900334058028 | recall/_macro: 0.7423075592422469 | recall/_micro: 0.9165605095541401 | recall/_weighted: 0.9165605095541401
* Epoch (26/30) lr: 3.125e-07 | momentum: 0.9


27/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (27/30) accuracy01: 0.9975431652221387 | accuracy01/std: 0.00937056000237924 | f1/_macro: 0.973419403913419 | f1/_micro: 0.9975381652472002 | f1/_weighted: 0.9975421351216092 | loss: 0.008606763044386697 | loss/mean: 0.008606763044386697 | loss/std: 0.029945919960106548 | lr: 3.125e-07 | momentum: 0.9 | precision/_macro: 0.9720303409448258 | precision/_micro: 0.9975431652221388 | precision/_weighted: 0.9975612942952096 | recall/_macro: 0.9749514155322959 | recall/_micro: 0.9975431652221388 | recall/_weighted: 0.9975431652221387


27/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (27/30) accuracy01: 0.9187898089171976 | accuracy01/std: 0.022101625547699528 | f1/_macro: 0.7325904276453341 | f1/_micro: 0.918784808944407 | f1/_weighted: 0.9228734069714462 | loss: 0.647600011442118 | loss/mean: 0.647600011442118 | loss/std: 0.3046503867149971 | lr: 3.125e-07 | momentum: 0.9 | precision/_macro: 0.7267324920099778 | precision/_micro: 0.9187898089171974 | precision/_weighted: 0.9293708640304627 | recall/_macro: 0.760050166239357 | recall/_micro: 0.9187898089171974 | recall/_weighted: 0.9187898089171975
* Epoch (27/30) lr: 3.125e-07 | momentum: 0.9


28/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (28/30) accuracy01: 0.9977479014536276 | accuracy01/std: 0.008081496696746267 | f1/_macro: 0.9666371621297423 | f1/_micro: 0.9977429014786835 | f1/_weighted: 0.9977282077380997 | loss: 0.00849584680376636 | loss/mean: 0.00849584680376636 | loss/std: 0.02509412943480479 | lr: 3.125e-07 | momentum: 0.9 | precision/_macro: 0.9789454731638794 | precision/_micro: 0.9977479014536272 | precision/_weighted: 0.9978876204426004 | recall/_macro: 0.9625418910022933 | recall/_micro: 0.9977479014536272 | recall/_weighted: 0.9977479014536274


28/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (28/30) accuracy01: 0.9197452229299362 | accuracy01/std: 0.022353662629403265 | f1/_macro: 0.7362275945746495 | f1/_micro: 0.9197402229571175 | f1/_weighted: 0.9220278631227924 | loss: 0.6600656903853083 | loss/mean: 0.6600656903853083 | loss/std: 0.3119184368129308 | lr: 3.125e-07 | momentum: 0.9 | precision/_macro: 0.74468175655914 | precision/_micro: 0.9197452229299363 | precision/_weighted: 0.9271770043307335 | recall/_macro: 0.7472272282389569 | recall/_micro: 0.9197452229299363 | recall/_weighted: 0.9197452229299363
* Epoch (28/30) lr: 3.125e-07 | momentum: 0.9


29/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (29/30) accuracy01: 0.998089128506108 | accuracy01/std: 0.007767416553007332 | f1/_macro: 0.9757656835345292 | f1/_micro: 0.9980841285311557 | f1/_weighted: 0.9980913484135701 | loss: 0.007148500647579816 | loss/mean: 0.007148500647579816 | loss/std: 0.02088125861217212 | lr: 3.125e-07 | momentum: 0.9 | precision/_macro: 0.9727109235098491 | precision/_micro: 0.998089128506108 | precision/_weighted: 0.9981396084435564 | recall/_macro: 0.979758367143894 | recall/_micro: 0.998089128506108 | recall/_weighted: 0.9980891285061081


29/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (29/30) accuracy01: 0.9194267515923567 | accuracy01/std: 0.02388290569719119 | f1/_macro: 0.7305164334287401 | f1/_micro: 0.9194217516195473 | f1/_weighted: 0.9221028081501816 | loss: 0.6668772015601967 | loss/mean: 0.6668772015601967 | loss/std: 0.3127477589247396 | lr: 3.125e-07 | momentum: 0.9 | precision/_macro: 0.7292908167072221 | precision/_micro: 0.9194267515923567 | precision/_weighted: 0.9271200185600691 | recall/_macro: 0.7514018947810511 | recall/_micro: 0.9194267515923567 | recall/_weighted: 0.9194267515923568
* Epoch (29/30) lr: 3.125e-07 | momentum: 0.9


30/30 * Epoch (train):   0%|          | 0/458 [00:00<?, ?it/s]

train (30/30) accuracy01: 0.997747901453627 | accuracy01/std: 0.008593127916584222 | f1/_macro: 0.9816266402281535 | f1/_micro: 0.9977429014786835 | f1/_weighted: 0.9977601710393026 | loss: 0.007360248117538803 | loss/mean: 0.007360248117538803 | loss/std: 0.023250174255772575 | lr: 3.125e-07 | momentum: 0.9 | precision/_macro: 0.9788412855710024 | precision/_micro: 0.9977479014536272 | precision/_weighted: 0.9978329735712502 | recall/_macro: 0.9855915819848254 | recall/_micro: 0.9977479014536272 | recall/_weighted: 0.9977479014536274


30/30 * Epoch (valid):   0%|          | 0/50 [00:00<?, ?it/s]

valid (30/30) accuracy01: 0.9195859872611466 | accuracy01/std: 0.02334563948756022 | f1/_macro: 0.7371632765368668 | f1/_micro: 0.9195809872883325 | f1/_weighted: 0.9222059463966924 | loss: 0.6821966795974473 | loss/mean: 0.6821966795974473 | loss/std: 0.3267253254208207 | lr: 3.125e-07 | momentum: 0.9 | precision/_macro: 0.747491542908389 | precision/_micro: 0.9195859872611465 | precision/_weighted: 0.9272567892566954 | recall/_macro: 0.7435132081459291 | recall/_micro: 0.9195859872611465 | recall/_weighted: 0.9195859872611466
* Epoch (30/30) lr: 1.5625e-07 | momentum: 0.9
Top models:
logs/resnet18/checkpoints/model.0030.pth	30.0000


In [20]:
%load_ext tensorboard
%tensorboard --logdir logs

In [21]:
import os

class TestDataset(Dataset):
    def __init__(self, root, transforms):
        super().__init__()
        self.files = sorted(list(Path(root).rglob('*.jpg')), key=self.extract_image_number)
        self.transforms = transforms

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        x = self.load_sample(self.files[index])
        return self.transforms(x), 0

    def load_sample(self, file):
        image = Image.open(file)
        image.load()
        return image

    def extract_image_number(self, path):
        filename = os.path.basename(path)
        return int(filename[3:-4])

In [22]:
test_transforms = transforms.Compose([
    transforms.Resize((RESCALE_SIZE, RESCALE_SIZE)),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x / 255),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

In [23]:
test_dir = "./testset/testset"
test_dataset = TestDataset(test_dir, test_transforms)

In [24]:
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=128
)

In [25]:
result = []

for prediction in runner.predict_loader(engine=dl.GPUEngine("cuda"), loader=test_dataloader):
    result += prediction['logits'].detach().cpu().numpy().argmax(1).tolist()

In [26]:
from sklearn.preprocessing import LabelEncoder

label_encoder = LabelEncoder()
label_encoder.fit(image_datasets.classes)

In [27]:
import pandas as pd

sample_submission_path = 'sample_submission.csv'
submission = pd.read_csv(sample_submission_path)
submission.head()

Unnamed: 0,Id,Expected
0,img0.jpg,bart_simpson
1,img1.jpg,bart_simpson
2,img2.jpg,bart_simpson
3,img3.jpg,bart_simpson
4,img4.jpg,bart_simpson


In [28]:
submission['Expected'] = label_encoder.inverse_transform(result)
submission

Unnamed: 0,Id,Expected
0,img0.jpg,nelson_muntz
1,img1.jpg,bart_simpson
2,img2.jpg,mayor_quimby
3,img3.jpg,nelson_muntz
4,img4.jpg,lisa_simpson
...,...,...
986,img986.jpg,sideshow_bob
987,img987.jpg,nelson_muntz
988,img988.jpg,ned_flanders
989,img989.jpg,charles_montgomery_burns


In [29]:
submission_path = 'submission.csv'
submission.to_csv(submission_path, index=None)