In [1]:
import gc
import torch
import pandas as pd
import utils.constants as const

from functools import partial

from utils.losses import MosLoss
from utils.xception import Mos_Xception
from utils.data_loader import get_data_loaders
from utils.metrics import accuracy, macro_f1

from fastai.vision.all import *
from fastai.callback.tracker import SaveModelCallback

  from .autonotebook import tqdm as notebook_tqdm
INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.13 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations


In [2]:
device = torch.device('cuda:0')

In [3]:
mom = 0.9531249799613819
alpha = 0.7216872276263672
eps = 0.02241164469571523
wd = 0.04258532691804346

In [4]:
opt_func = partial(ranger, mom=mom, alpha=alpha, eps=eps)
loss = MosLoss(const.NUM_CLASSES)
metrics = [accuracy, macro_f1]

Focal Loss with gamma =  0


In [5]:
for fold in range(1, 6):
    model_dir = f'model_weights/fold_{fold}'
    data_csv_path = f'data/splits/data_fold_{fold}.csv'

    df = pd.read_csv(data_csv_path)
    train_dl, val_dl = get_data_loaders(df)
    dls = DataLoaders(train_dl, val_dl)
    net = Mos_Xception(const.NUM_CLASSES)

    learn = Learner(
        dls,
        net,
        wd=wd,
        opt_func=opt_func,
        metrics=metrics,
        loss_func=loss,
        model_dir=model_dir,
    )
    cb = SaveModelCallback(monitor='macro_f1')
    learn.fit_one_cycle(
        60,
        2e-03,
        div=25,
        pct_start=0.3,
        cbs=[cb],
    )

    del learn
    del net
    gc.collect()
    torch.cuda.empty_cache()

epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.602891,4.957448,0.134615,0.087594,01:01
1,4.197834,3.431082,0.25641,0.123054,00:39
2,3.369275,3.058239,0.375,0.257804,00:39
3,3.042055,2.819066,0.548077,0.422615,00:38
4,2.641143,2.2003,0.615385,0.441391,00:38
5,2.053046,1.547862,0.698718,0.550259,00:38
6,1.551101,1.050462,0.798077,0.675129,00:38
7,1.167136,0.906231,0.830128,0.660052,00:38
8,0.932619,0.654148,0.878205,0.726767,00:38
9,0.777766,0.846382,0.807692,0.676171,00:38


Better model found at epoch 0 with macro_f1 value: 0.08759427220261469.
Better model found at epoch 1 with macro_f1 value: 0.12305363275281869.
Better model found at epoch 2 with macro_f1 value: 0.25780449574818604.
Better model found at epoch 3 with macro_f1 value: 0.422615306915192.
Better model found at epoch 4 with macro_f1 value: 0.4413906047514279.
Better model found at epoch 5 with macro_f1 value: 0.5502585477070604.
Better model found at epoch 6 with macro_f1 value: 0.6751290536839707.
Better model found at epoch 8 with macro_f1 value: 0.7267666658694192.
Better model found at epoch 10 with macro_f1 value: 0.7413872537125428.
Better model found at epoch 11 with macro_f1 value: 0.7622021173042627.
Better model found at epoch 12 with macro_f1 value: 0.7980131227072272.
Better model found at epoch 16 with macro_f1 value: 0.8208172846055394.
Better model found at epoch 17 with macro_f1 value: 0.8315511091855177.
Better model found at epoch 20 with macro_f1 value: 0.8824748184048292

epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.382487,4.761877,0.118971,0.04743,00:38
1,4.078896,3.44423,0.170418,0.07318,00:38
2,3.330579,3.059434,0.376206,0.284803,00:38
3,3.020276,2.812936,0.578778,0.368829,00:40
4,2.61693,2.128139,0.63344,0.454982,00:40
5,2.063987,1.559386,0.726688,0.560635,00:39
6,1.555974,1.085812,0.807074,0.612801,00:39
7,1.158152,0.840473,0.845659,0.68078,00:39
8,0.97139,0.678199,0.868167,0.716872,00:39
9,0.804987,0.583359,0.874598,0.742121,00:40


Better model found at epoch 0 with macro_f1 value: 0.047430177996105646.
Better model found at epoch 1 with macro_f1 value: 0.07318035056885888.
Better model found at epoch 2 with macro_f1 value: 0.2848026379047131.
Better model found at epoch 3 with macro_f1 value: 0.36882861670381684.
Better model found at epoch 4 with macro_f1 value: 0.45498216399215113.
Better model found at epoch 5 with macro_f1 value: 0.5606345159372673.
Better model found at epoch 6 with macro_f1 value: 0.6128009890815105.
Better model found at epoch 7 with macro_f1 value: 0.680780127326264.
Better model found at epoch 8 with macro_f1 value: 0.7168719560355283.
Better model found at epoch 9 with macro_f1 value: 0.7421213909787098.
Better model found at epoch 13 with macro_f1 value: 0.773230903657351.
Better model found at epoch 15 with macro_f1 value: 0.7829729721997039.
Better model found at epoch 17 with macro_f1 value: 0.7940495934894857.
Better model found at epoch 18 with macro_f1 value: 0.8342303025789201.

epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.472619,4.82907,0.189711,0.073038,00:37
1,4.119749,3.421255,0.29582,0.109524,00:37
2,3.350256,3.058428,0.398714,0.218115,00:37
3,3.046516,2.821576,0.514469,0.310493,00:37
4,2.651398,2.215427,0.607717,0.397191,00:38
5,2.056003,1.490474,0.768489,0.48786,00:37
6,1.538081,1.168326,0.749196,0.468394,00:38
7,1.161356,0.814226,0.836013,0.579415,00:38
8,0.903536,0.694141,0.868167,0.618889,00:38
9,0.753425,0.768245,0.836013,0.660817,00:38


Better model found at epoch 0 with macro_f1 value: 0.07303793363184323.
Better model found at epoch 1 with macro_f1 value: 0.10952374550121814.
Better model found at epoch 2 with macro_f1 value: 0.21811464582075998.
Better model found at epoch 3 with macro_f1 value: 0.31049279080682324.
Better model found at epoch 4 with macro_f1 value: 0.39719084242321795.
Better model found at epoch 5 with macro_f1 value: 0.48785986655330343.
Better model found at epoch 7 with macro_f1 value: 0.5794152151552001.
Better model found at epoch 8 with macro_f1 value: 0.6188892079802346.
Better model found at epoch 9 with macro_f1 value: 0.6608167776436147.
Better model found at epoch 11 with macro_f1 value: 0.6748756107902579.
Better model found at epoch 12 with macro_f1 value: 0.8172111494398262.
Better model found at epoch 15 with macro_f1 value: 0.8225774406718704.
Better model found at epoch 18 with macro_f1 value: 0.8257956776728258.
Better model found at epoch 19 with macro_f1 value: 0.8276685056765

epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.521451,4.920293,0.173633,0.072467,00:41
1,4.142995,3.425236,0.299035,0.178588,00:41
2,3.326428,3.019375,0.427653,0.303338,00:41
3,2.988019,2.729686,0.601286,0.463933,00:41
4,2.549215,2.054034,0.649518,0.522563,00:41
5,1.973051,1.436698,0.800643,0.645506,00:41
6,1.487794,0.985516,0.829582,0.679276,00:41
7,1.147141,0.811767,0.864952,0.674822,00:41
8,0.899736,0.702837,0.864952,0.778522,00:42
9,0.733875,0.71357,0.823151,0.647288,00:41


Better model found at epoch 0 with macro_f1 value: 0.07246690538259729.
Better model found at epoch 1 with macro_f1 value: 0.1785875231710116.
Better model found at epoch 2 with macro_f1 value: 0.3033382533158861.
Better model found at epoch 3 with macro_f1 value: 0.463932536659538.
Better model found at epoch 4 with macro_f1 value: 0.5225628791781217.
Better model found at epoch 5 with macro_f1 value: 0.6455057786281835.
Better model found at epoch 6 with macro_f1 value: 0.6792763506589874.
Better model found at epoch 8 with macro_f1 value: 0.7785219965319287.
Better model found at epoch 16 with macro_f1 value: 0.8172224381112427.
Better model found at epoch 20 with macro_f1 value: 0.8356935290560004.
Better model found at epoch 25 with macro_f1 value: 0.8925277283538874.
Better model found at epoch 45 with macro_f1 value: 0.9114522467701598.
Better model found at epoch 49 with macro_f1 value: 0.9344420840946097.


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.495385,4.906269,0.118971,0.064725,00:42
1,4.136719,3.443349,0.144695,0.076139,00:41
2,3.359202,3.081496,0.414791,0.237892,00:40
3,3.041584,2.826746,0.498392,0.290864,00:39
4,2.659464,2.15398,0.691318,0.462978,00:40
5,2.091464,1.481249,0.742765,0.572938,00:40
6,1.538229,1.086415,0.829582,0.652126,00:40
7,1.180183,0.840393,0.829582,0.663254,00:40
8,0.919647,0.833486,0.823151,0.66427,00:40
9,0.74695,0.594523,0.871383,0.673791,00:41


Better model found at epoch 0 with macro_f1 value: 0.06472482796027462.
Better model found at epoch 1 with macro_f1 value: 0.07613868783261755.
Better model found at epoch 2 with macro_f1 value: 0.23789155438668308.
Better model found at epoch 3 with macro_f1 value: 0.29086412512566284.
Better model found at epoch 4 with macro_f1 value: 0.4629781891555092.
Better model found at epoch 5 with macro_f1 value: 0.5729382604606704.
Better model found at epoch 6 with macro_f1 value: 0.6521256190015141.
Better model found at epoch 7 with macro_f1 value: 0.6632543688354897.
Better model found at epoch 8 with macro_f1 value: 0.6642704797746662.
Better model found at epoch 9 with macro_f1 value: 0.6737912461824348.
Better model found at epoch 10 with macro_f1 value: 0.6921715728378842.
Better model found at epoch 11 with macro_f1 value: 0.7261251322491215.
Better model found at epoch 13 with macro_f1 value: 0.801691149845972.
Better model found at epoch 15 with macro_f1 value: 0.8293822592768068.