In [1]:
import gc
import torch
import pandas as pd

from functools import partial

from utils.losses import MosLoss
from utils.xception import Mos_Xception
from utils.data_loader import get_data_loaders
from utils.metrics import accuracy, macro_f1

from fastai.vision.all import *
from fastai.callback.tracker import SaveModelCallback

  from .autonotebook import tqdm as notebook_tqdm
INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.21 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations


In [2]:
# Prefer the first CUDA device, but fall back to the CPU so that anything later
# moved onto `device` still works on a machine without a usable GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
# Read the hyperparameter table from results/params.csv and reduce it to a
# single averaged set of values (column-wise mean over all rows).
params_df = pd.read_csv('results/params.csv')

# numeric_only=True keeps the mean well-defined if the CSV ever carries a
# non-numeric column; recent pandas raises a TypeError on such columns otherwise.
avg_params_series = params_df.mean(axis=0, numeric_only=True)

# Fail fast instead of starting a long training run with unusable values:
# an empty CSV (or an all-NaN column) averages to NaN without raising anything.
required_params = ['mom', 'alpha', 'eps', 'wd']
missing_params = [name for name in required_params if name not in avg_params_series.index]
if missing_params:
    raise KeyError(f'results/params.csv has no numeric column(s): {missing_params}')
if avg_params_series[required_params].isna().any():
    raise ValueError('averaged hyperparameters contain NaN; check results/params.csv')

# One-row frame holding the averaged values, one column per averaged CSV column.
avg_params_df = pd.DataFrame(
    data=[avg_params_series.values],
    columns=avg_params_series.index,
)

# Plain Python scalars: mom / alpha / eps are bound into the optimizer factory
# and wd is passed to the Learner further down in the notebook.
mom = avg_params_df['mom'].item()
alpha = avg_params_df['alpha'].item()
eps = avg_params_df['eps'].item()
wd = avg_params_df['wd'].item()

In [4]:
# Optimizer factory: `ranger` with the averaged momentum / alpha / eps values
# pre-bound as keyword arguments.
ranger_settings = {'mom': mom, 'alpha': alpha, 'eps': eps}
opt_func = partial(ranger, **ranger_settings)

# The argument 7 matches Mos_Xception(7) in the training loop; presumably the
# number of target classes -- confirm against utils.losses.
loss = MosLoss(7)

# NOTE(review): `accuracy` is imported from utils.metrics, but the later
# `from fastai.vision.all import *` also exports a name `accuracy`, so the
# object used here is likely fastai's -- confirm which one is intended.
metrics = [accuracy, macro_f1]

Focal Loss with gamma =  0


In [5]:
# Train one model per fold (folds 1..5). Each iteration reads that fold's CSV,
# builds the loaders and a fresh network, runs a one-cycle schedule, and lets
# SaveModelCallback track the best `macro_f1` value seen during training.
for fold in range(1, 6):
    # Per-fold input CSV and weight directory.
    data_csv_path = f'data/comparison/aedes_only/data_fold_{fold}.csv'
    model_dir = f'model_weights/comparison/aedes_only/fold_{fold}'

    df = pd.read_csv(data_csv_path)
    train_dl, val_dl = get_data_loaders(df)
    dls = DataLoaders(train_dl, val_dl)

    # A new network instance for every fold, so no weights carry over.
    net = Mos_Xception(7)
    learn = Learner(
        dls,
        net,
        loss_func=loss,
        opt_func=opt_func,
        metrics=metrics,
        wd=wd,
        model_dir=model_dir,
    )

    # 60 epochs with a peak learning rate of 2e-03.
    cb = SaveModelCallback(monitor='macro_f1')
    learn.fit_one_cycle(60, lr_max=2e-03, div=25, pct_start=0.3, cbs=[cb])

    # Drop the references to this fold's learner and network, then collect
    # garbage and ask PyTorch to release cached CUDA memory before the next fold.
    del learn, net
    gc.collect()
    torch.cuda.empty_cache()

epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,4.370159,3.957532,0.220779,0.073264,00:51
1,3.624791,3.02009,0.354978,0.12331,00:44
2,3.070948,2.741557,0.450216,0.259687,00:31
3,2.741536,2.402056,0.5671,0.325461,00:31
4,2.337744,1.840602,0.649351,0.418156,00:31
5,1.852482,1.226323,0.744589,0.541991,00:31
6,1.407625,1.083843,0.805195,0.59756,00:31
7,1.133464,0.753535,0.848485,0.634756,00:31
8,0.919296,0.592731,0.87013,0.695073,00:32
9,0.78568,0.586145,0.865801,0.745666,00:32


  return F.conv2d(input, weight, bias, self.stride,


Better model found at epoch 0 with macro_f1 value: 0.07326416551980461.
Better model found at epoch 1 with macro_f1 value: 0.1233103431422293.
Better model found at epoch 2 with macro_f1 value: 0.2596866481798754.
Better model found at epoch 3 with macro_f1 value: 0.32546103351378314.
Better model found at epoch 4 with macro_f1 value: 0.41815564613321465.
Better model found at epoch 5 with macro_f1 value: 0.5419913932960719.
Better model found at epoch 6 with macro_f1 value: 0.5975603320381149.
Better model found at epoch 7 with macro_f1 value: 0.634755551710522.
Better model found at epoch 8 with macro_f1 value: 0.6950727695935737.
Better model found at epoch 9 with macro_f1 value: 0.745665949705949.
Better model found at epoch 12 with macro_f1 value: 0.8013967431844056.
Better model found at epoch 16 with macro_f1 value: 0.8566713678569566.
Better model found at epoch 17 with macro_f1 value: 0.8635145025820656.
Better model found at epoch 24 with macro_f1 value: 0.8752424932625156.
B

epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,4.411305,4.017834,0.190476,0.077029,00:56
1,3.652268,3.077077,0.255411,0.130642,00:49
2,3.083306,2.769945,0.393939,0.227788,00:49
3,2.748841,2.421012,0.545455,0.297323,00:50
4,2.375962,1.866805,0.662338,0.42264,00:52
5,1.893718,1.461197,0.692641,0.447895,00:49
6,1.428232,0.975113,0.787879,0.613306,00:53
7,1.158321,0.639722,0.922078,0.731576,00:51
8,0.903219,1.212259,0.770563,0.582213,00:53
9,0.8026,0.944813,0.809524,0.553516,00:54


Better model found at epoch 0 with macro_f1 value: 0.07702911913726676.
Better model found at epoch 1 with macro_f1 value: 0.13064166907123445.
Better model found at epoch 2 with macro_f1 value: 0.22778820561810004.
Better model found at epoch 3 with macro_f1 value: 0.2973232607550208.
Better model found at epoch 4 with macro_f1 value: 0.4226396860721149.
Better model found at epoch 5 with macro_f1 value: 0.44789458363259704.
Better model found at epoch 6 with macro_f1 value: 0.6133056635236817.
Better model found at epoch 7 with macro_f1 value: 0.7315760191336229.
Better model found at epoch 14 with macro_f1 value: 0.780733706075797.
Better model found at epoch 24 with macro_f1 value: 0.7993583409251611.
Better model found at epoch 26 with macro_f1 value: 0.8391183820704838.
Better model found at epoch 27 with macro_f1 value: 0.9118486387570739.
Better model found at epoch 28 with macro_f1 value: 0.9401589527816737.


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,4.519782,4.138027,0.151515,0.080076,00:51
1,3.720826,3.132098,0.220779,0.107579,00:51
2,3.118805,2.781239,0.467532,0.261292,00:52
3,2.773886,2.446661,0.584416,0.325409,00:51
4,2.354479,1.913447,0.640693,0.407058,00:52
5,1.893913,1.331434,0.735931,0.447881,00:50
6,1.409235,1.281044,0.731602,0.484408,00:52
7,1.148316,0.927754,0.800866,0.537301,00:52
8,0.921489,0.834261,0.805195,0.56719,00:53
9,0.805592,0.621777,0.883117,0.688432,00:53


Better model found at epoch 0 with macro_f1 value: 0.08007568339487943.
Better model found at epoch 1 with macro_f1 value: 0.1075791822933407.
Better model found at epoch 2 with macro_f1 value: 0.26129207232500223.
Better model found at epoch 3 with macro_f1 value: 0.3254090372209075.
Better model found at epoch 4 with macro_f1 value: 0.40705849768652613.
Better model found at epoch 5 with macro_f1 value: 0.447881489379957.
Better model found at epoch 6 with macro_f1 value: 0.4844084592803307.
Better model found at epoch 7 with macro_f1 value: 0.537301028957716.
Better model found at epoch 8 with macro_f1 value: 0.5671900694453768.
Better model found at epoch 9 with macro_f1 value: 0.6884315706709118.
Better model found at epoch 14 with macro_f1 value: 0.7150507065820377.
Better model found at epoch 16 with macro_f1 value: 0.8287460960272021.
Better model found at epoch 20 with macro_f1 value: 0.8410504544104411.
Better model found at epoch 25 with macro_f1 value: 0.8756456174751636.
B

epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,4.376873,3.971998,0.151515,0.081125,00:56
1,3.626286,3.04407,0.246753,0.166533,00:56
2,3.067554,2.735628,0.393939,0.228346,00:59
3,2.732551,2.354311,0.5671,0.299131,00:57
4,2.316303,1.793382,0.714286,0.524982,00:57
5,1.816578,1.170072,0.818182,0.662235,00:55
6,1.357503,0.879212,0.822511,0.66179,00:53
7,1.082018,1.165388,0.718615,0.601717,00:36
8,0.880294,1.231384,0.744589,0.485953,00:30
9,0.749846,0.891694,0.839827,0.618425,00:29


Better model found at epoch 0 with macro_f1 value: 0.08112476367840946.
Better model found at epoch 1 with macro_f1 value: 0.16653304923981613.
Better model found at epoch 2 with macro_f1 value: 0.22834565427163622.
Better model found at epoch 3 with macro_f1 value: 0.2991312521035558.
Better model found at epoch 4 with macro_f1 value: 0.5249823927465245.
Better model found at epoch 5 with macro_f1 value: 0.6622352647454942.
Better model found at epoch 10 with macro_f1 value: 0.7514199997744789.
Better model found at epoch 11 with macro_f1 value: 0.849550857087089.
Better model found at epoch 21 with macro_f1 value: 0.8659034911265785.
Better model found at epoch 23 with macro_f1 value: 0.9266764395890156.
Better model found at epoch 46 with macro_f1 value: 0.928400590469556.
Better model found at epoch 50 with macro_f1 value: 0.9522847522847522.


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,4.380165,3.991248,0.194805,0.078344,00:30
1,3.635571,3.0707,0.320346,0.134263,00:30
2,3.079682,2.750557,0.571429,0.325863,00:29
3,2.745876,2.387546,0.584416,0.388419,00:30
4,2.325059,1.805025,0.709957,0.455045,00:29
5,1.907454,1.365789,0.774892,0.470837,00:29
6,1.438761,1.066907,0.757576,0.560938,00:29
7,1.129927,0.784118,0.831169,0.682442,00:30
8,0.884887,0.803526,0.82684,0.614848,00:30
9,0.773901,0.719566,0.865801,0.610595,00:30


Better model found at epoch 0 with macro_f1 value: 0.07834420105096795.
Better model found at epoch 1 with macro_f1 value: 0.13426324701904024.
Better model found at epoch 2 with macro_f1 value: 0.32586313904944125.
Better model found at epoch 3 with macro_f1 value: 0.3884191895229946.
Better model found at epoch 4 with macro_f1 value: 0.45504546202628754.
Better model found at epoch 5 with macro_f1 value: 0.47083652638595974.
Better model found at epoch 6 with macro_f1 value: 0.5609382067770469.
Better model found at epoch 7 with macro_f1 value: 0.6824417353740663.
Better model found at epoch 12 with macro_f1 value: 0.8002476376669926.
Better model found at epoch 23 with macro_f1 value: 0.8580237950838364.
Better model found at epoch 34 with macro_f1 value: 0.8639874295279772.
Better model found at epoch 35 with macro_f1 value: 0.8930663965379787.
Better model found at epoch 38 with macro_f1 value: 0.9041486476202297.
Better model found at epoch 41 with macro_f1 value: 0.9398978446597