In [1]:
import gc
import torch
import pandas as pd
import utils.constants as const

from functools import partial

from utils.losses import MosLoss
from utils.xception import Mos_Xception
from utils.data_loader import get_data_loaders
from utils.metrics import accuracy, macro_f1

from fastai.vision.all import *
from fastai.callback.tracker import SaveModelCallback

  from .autonotebook import tqdm as notebook_tqdm
INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.21 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations


In [2]:
# Target the first CUDA device (type 'cuda', index 0).
# NOTE(review): not referenced by the training cells shown here — presumably
# used by later cells; confirm before removing.
device = torch.device('cuda', 0)

In [3]:
# Average every tuned hyperparameter (and its F1) over all rows of the search results.
params_df = pd.read_csv('results/params.csv')
# numeric_only=True skips any non-numeric column (e.g. a run label), which on
# pandas >= 2.0 would otherwise make DataFrame.mean raise a TypeError.
avg_params_series = params_df.mean(axis=0, numeric_only=True)
# Single-row frame so each averaged value can be pulled out by column name.
avg_params_df = pd.DataFrame(
    data=[avg_params_series.values],
    columns=avg_params_series.index,
)
# Display the averages already computed above instead of recomputing the mean.
avg_params_series

mom      0.741241
alpha    0.712219
eps      0.007894
wd       0.077542
f1       0.824052
dtype: float64

In [4]:
# Display the single-row table of averaged hyperparameters (one row, one column per parameter).
avg_params_df

Unnamed: 0,mom,alpha,eps,wd,f1
0,0.741241,0.712219,0.007894,0.077542,0.824052


In [5]:
# Unpack each averaged hyperparameter from the single-row frame as a plain
# Python scalar; Series.item() raises if the frame does not have exactly one row.
mom, alpha, eps, wd = (
    avg_params_df[column].item()
    for column in ('mom', 'alpha', 'eps', 'wd')
)

In [6]:
# Metrics reported every epoch; 'macro_f1' is also the quantity the
# checkpoint callback in the training loop monitors.
metrics = [accuracy, macro_f1]
# Project loss over const.NUM_CLASSES classes. Presumably its constructor emits
# the "Focal Loss with gamma = ..." line shown in this cell's output.
loss = MosLoss(const.NUM_CLASSES)
# Ranger optimizer factory with the averaged momentum/alpha/eps pre-bound.
opt_func = partial(ranger, eps=eps, alpha=alpha, mom=mom)

Focal Loss with gamma =  0


In [7]:
# One-cycle schedule shared by every fold. Tune it here rather than editing
# literals inside the loop.
N_FOLDS = 5            # folds are numbered 1..N_FOLDS in the file/directory names
EPOCHS = 60
LR_MAX = 2e-03         # peak learning rate of the one-cycle schedule
LR_DIV = 25            # fit_one_cycle `div`: starting LR is LR_MAX / LR_DIV
PCT_START = 0.3        # fraction of the cycle spent increasing the LR
MONITOR = 'macro_f1'   # metric SaveModelCallback uses to decide which epoch to keep

# Train a fresh model on each cross-validation fold. SaveModelCallback writes
# the best checkpoint (by validation macro-F1) under model_weights/fold_{fold}.
for fold in range(1, N_FOLDS + 1):
    model_dir = f'model_weights/fold_{fold}'
    data_csv_path = f'data/splits/data_fold_{fold}.csv'

    df = pd.read_csv(data_csv_path)
    train_dl, val_dl = get_data_loaders(df)
    dls = DataLoaders(train_dl, val_dl)
    # Fresh, untrained network per fold so folds do not share weights.
    net = Mos_Xception(const.NUM_CLASSES)

    learn = Learner(
        dls,
        net,
        wd=wd,
        opt_func=opt_func,
        metrics=metrics,
        loss_func=loss,
        model_dir=model_dir,
    )
    cb = SaveModelCallback(monitor=MONITOR)
    learn.fit_one_cycle(
        EPOCHS,
        LR_MAX,
        div=LR_DIV,
        pct_start=PCT_START,
        cbs=[cb],
    )

    # Drop this fold's learner and model, then collect reference cycles and
    # release cached GPU memory before the next fold allocates a new model.
    del learn
    del net
    gc.collect()
    torch.cuda.empty_cache()

epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.218,4.538208,0.108974,0.107445,01:06
1,3.849785,3.210979,0.166667,0.153232,01:00
2,3.206007,2.948654,0.471154,0.314615,00:50
3,2.733851,2.158173,0.663462,0.511302,00:48
4,2.030416,1.390083,0.730769,0.563337,00:49
5,1.440167,0.974477,0.858974,0.679199,00:49
6,1.051142,0.715534,0.855769,0.7103,00:48
7,0.760756,0.611793,0.88141,0.719715,00:46
8,0.657553,0.423815,0.929487,0.769274,00:44
9,0.576446,1.015526,0.817308,0.636543,00:46


  return F.conv2d(input, weight, bias, self.stride,


Better model found at epoch 0 with macro_f1 value: 0.10744549568078977.
Better model found at epoch 1 with macro_f1 value: 0.15323157095404455.
Better model found at epoch 2 with macro_f1 value: 0.31461526965732245.
Better model found at epoch 3 with macro_f1 value: 0.5113015902517793.
Better model found at epoch 4 with macro_f1 value: 0.5633369219390725.
Better model found at epoch 5 with macro_f1 value: 0.6791994944626524.
Better model found at epoch 6 with macro_f1 value: 0.7103000381261252.
Better model found at epoch 7 with macro_f1 value: 0.7197145621185793.
Better model found at epoch 8 with macro_f1 value: 0.7692743199237595.
Better model found at epoch 10 with macro_f1 value: 0.8330626588386372.
Better model found at epoch 19 with macro_f1 value: 0.8655815117904999.
Better model found at epoch 21 with macro_f1 value: 0.9197506513745828.
Better model found at epoch 26 with macro_f1 value: 0.9371114198700405.
Better model found at epoch 31 with macro_f1 value: 0.9373957741754819

epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.264956,4.55495,0.202572,0.084449,00:44
1,3.860215,3.19259,0.347267,0.226875,00:44
2,3.183466,2.844607,0.546624,0.429438,00:44
3,2.635085,2.038682,0.697749,0.547006,00:44
4,1.91246,1.349947,0.790997,0.646482,00:43
5,1.390773,0.831477,0.884244,0.719413,00:46
6,1.029158,0.581282,0.916399,0.694008,00:45
7,0.816825,0.505411,0.926045,0.741604,00:43
8,0.684918,0.414225,0.916399,0.707785,00:43
9,0.609217,0.725482,0.826367,0.692236,00:44


Better model found at epoch 0 with macro_f1 value: 0.0844492464615216.
Better model found at epoch 1 with macro_f1 value: 0.22687469881019692.
Better model found at epoch 2 with macro_f1 value: 0.4294382818968678.
Better model found at epoch 3 with macro_f1 value: 0.5470056653147153.
Better model found at epoch 4 with macro_f1 value: 0.6464817400525825.
Better model found at epoch 5 with macro_f1 value: 0.7194127749879228.
Better model found at epoch 7 with macro_f1 value: 0.7416041641673085.
Better model found at epoch 10 with macro_f1 value: 0.8216792960906641.
Better model found at epoch 17 with macro_f1 value: 0.8778031861023219.
Better model found at epoch 28 with macro_f1 value: 0.9549786036664992.
Better model found at epoch 32 with macro_f1 value: 0.9673493221465422.
Better model found at epoch 35 with macro_f1 value: 0.9734467378902604.
Better model found at epoch 39 with macro_f1 value: 0.9822158707474893.


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.330852,4.607382,0.157556,0.056198,00:57
1,3.891535,3.214992,0.344051,0.15253,00:57
2,3.204437,2.957956,0.508039,0.327165,00:57
3,2.703763,2.216852,0.623794,0.414269,00:58
4,1.969512,1.418118,0.77492,0.517611,00:53
5,1.389579,0.995804,0.768489,0.550069,00:52
6,1.005738,0.620149,0.900322,0.682332,00:55
7,0.777403,0.83814,0.797428,0.608672,00:55
8,0.689747,0.512068,0.913183,0.69729,00:53
9,0.624452,0.604392,0.88746,0.706428,00:54


Better model found at epoch 0 with macro_f1 value: 0.05619800401024788.
Better model found at epoch 1 with macro_f1 value: 0.15253007819595904.
Better model found at epoch 2 with macro_f1 value: 0.32716507939483347.
Better model found at epoch 3 with macro_f1 value: 0.4142688228445481.
Better model found at epoch 4 with macro_f1 value: 0.5176113246404971.
Better model found at epoch 5 with macro_f1 value: 0.5500693032369409.
Better model found at epoch 6 with macro_f1 value: 0.6823319894619996.
Better model found at epoch 8 with macro_f1 value: 0.6972903469073506.
Better model found at epoch 9 with macro_f1 value: 0.7064280018084957.
Better model found at epoch 10 with macro_f1 value: 0.8092482902110991.
Better model found at epoch 18 with macro_f1 value: 0.8397614105612403.
Better model found at epoch 19 with macro_f1 value: 0.9230795432572566.
Better model found at epoch 20 with macro_f1 value: 0.9271350645689408.
Better model found at epoch 39 with macro_f1 value: 0.9353501533994568

epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.323751,4.580914,0.170418,0.135106,01:01
1,3.895346,3.211316,0.382637,0.240232,00:58
2,3.209976,2.915032,0.508039,0.363374,00:53
3,2.686677,2.155744,0.652733,0.49869,00:51
4,1.949023,1.326573,0.762058,0.616308,00:53
5,1.383703,1.073965,0.752412,0.645141,00:59
6,1.018063,0.703507,0.845659,0.70018,00:52
7,0.837501,0.54544,0.903537,0.732005,00:55
8,0.714138,0.590157,0.858521,0.730049,00:58
9,0.628869,0.771521,0.832797,0.706548,00:56


Better model found at epoch 0 with macro_f1 value: 0.13510587993580037.
Better model found at epoch 1 with macro_f1 value: 0.24023239056506232.
Better model found at epoch 2 with macro_f1 value: 0.3633735238902109.
Better model found at epoch 3 with macro_f1 value: 0.49868963194783894.
Better model found at epoch 4 with macro_f1 value: 0.6163084344865083.
Better model found at epoch 5 with macro_f1 value: 0.6451407417233027.
Better model found at epoch 6 with macro_f1 value: 0.7001797744912264.
Better model found at epoch 7 with macro_f1 value: 0.73200507148967.
Better model found at epoch 10 with macro_f1 value: 0.7608982359334044.
Better model found at epoch 12 with macro_f1 value: 0.7858896648808382.
Better model found at epoch 14 with macro_f1 value: 0.7979118949683601.
Better model found at epoch 18 with macro_f1 value: 0.8631406654424293.
Better model found at epoch 24 with macro_f1 value: 0.8886867466198795.
Better model found at epoch 33 with macro_f1 value: 0.9206367825722666.

epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.409146,4.655257,0.176849,0.078694,00:50
1,3.928831,3.218301,0.356913,0.194193,00:51
2,3.192704,2.838488,0.530547,0.309762,00:51
3,2.634944,2.085931,0.681672,0.461723,00:52
4,1.975378,1.387347,0.797428,0.561576,00:49
5,1.404483,0.925521,0.836013,0.661933,00:50
6,1.001878,0.69998,0.874598,0.709673,00:51
7,0.804687,0.615493,0.88746,0.6954,00:49
8,0.648958,0.415318,0.92926,0.753629,00:51
9,0.626619,0.941507,0.81672,0.662767,00:52


Better model found at epoch 0 with macro_f1 value: 0.07869417847822817.
Better model found at epoch 1 with macro_f1 value: 0.19419280109532508.
Better model found at epoch 2 with macro_f1 value: 0.3097617502054242.
Better model found at epoch 3 with macro_f1 value: 0.4617225478977442.
Better model found at epoch 4 with macro_f1 value: 0.5615759963470389.
Better model found at epoch 5 with macro_f1 value: 0.66193334914984.
Better model found at epoch 6 with macro_f1 value: 0.7096727721558488.
Better model found at epoch 8 with macro_f1 value: 0.7536293091319175.
Better model found at epoch 10 with macro_f1 value: 0.7721399038635703.
Better model found at epoch 15 with macro_f1 value: 0.8552505764395204.
Better model found at epoch 22 with macro_f1 value: 0.8705607985565478.
Better model found at epoch 24 with macro_f1 value: 0.9059264637524154.
Better model found at epoch 29 with macro_f1 value: 0.9088099528484045.
Better model found at epoch 31 with macro_f1 value: 0.915372908095095.
B