In [1]:
import gc
import torch
import pandas as pd

from functools import partial

from utils.losses import MosLoss
from utils.xception import Mos_Xception
from utils.data_loader import get_data_loaders
from utils.metrics import accuracy, macro_f1

from fastai.vision.all import *
from fastai.callback.tracker import SaveModelCallback

  from .autonotebook import tqdm as notebook_tqdm
INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.14 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations


In [2]:
# Target device for training: the first CUDA GPU when one is usable, otherwise
# the CPU, so the notebook still runs on a machine without CUDA.
# NOTE(review): `device` is not referenced by the other cells shown here —
# confirm whether the data loaders / model are expected to pick it up.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
# Fixed hyperparameters for every fold.
# NOTE(review): the many significant digits suggest these came out of an
# automated search — confirm and record their provenance.

# Forwarded to the `ranger` optimiser factory (see `opt_func`).
mom, alpha, eps = 0.9531249799613819, 0.7216872276263672, 0.02241164469571523

# Weight decay handed to the Learner.
wd = 0.04258532691804346

In [4]:
# Metrics reported after each epoch; `macro_f1` is also the quantity the
# checkpoint callback monitors in the training loop.
# NOTE(review): `fastai.vision.all` is star-imported *after* `utils.metrics`
# and fastai also exports an `accuracy` metric, so `accuracy` here probably
# resolves to fastai's implementation rather than the project one — confirm
# which is intended.
metrics = [accuracy, macro_f1]

# Project loss, constructed with argument 2 (presumably the number of classes,
# matching Mos_Xception(2) — confirm). One instance is shared by all folds.
loss = MosLoss(2)

# Optimiser factory: `ranger` with the tuned mom / alpha / eps pre-bound.
opt_func = partial(ranger, mom=mom, alpha=alpha, eps=eps)

Focal Loss with gamma =  0


In [5]:
# Train one fresh Mos_Xception per fold (folds 1..5). Each fold reads its own
# CSV, trains for 60 epochs with the one-cycle schedule, and releases the model
# before the next fold starts.
for fold in range(1, 6):
    # Per-fold input CSV and checkpoint directory.
    data_csv_path = f'data/comparison/aedes_vs_non_aedes/data_fold_{fold}.csv'
    model_dir = f'model_weights/comparison/aedes_vs_non_aedes/fold_{fold}'

    # Build the train / validation loaders from the fold's CSV and wrap them
    # in a fastai DataLoaders container.
    df = pd.read_csv(data_csv_path)
    train_dl, val_dl = get_data_loaders(df)
    dls = DataLoaders(train_dl, val_dl)

    # New, untrained network for this fold.
    net = Mos_Xception(2)

    learn = Learner(
        dls,
        net,
        loss_func=loss,
        opt_func=opt_func,
        metrics=metrics,
        wd=wd,
        model_dir=model_dir,
    )

    # Checkpoint whenever validation macro_f1 improves (the "Better model
    # found at epoch ..." lines in the output come from this callback).
    cb = SaveModelCallback(monitor='macro_f1')
    learn.fit_one_cycle(
        n_epoch=60,
        lr_max=2e-03,
        div=25,
        pct_start=0.3,
        cbs=[cb],
    )

    # Drop the references to the learner and the network, then collect garbage
    # and release cached CUDA memory before the next fold allocates its model.
    del learn, net
    gc.collect()
    torch.cuda.empty_cache()

epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,1.317906,1.293087,0.675439,0.669209,01:11
1,1.230637,1.173206,0.675439,0.669209,01:17
2,1.114098,0.970289,0.675439,0.669209,01:12
3,0.918375,0.623732,0.944444,0.853533,01:16
4,0.630771,0.339774,0.991228,0.948772,01:15
5,0.371952,0.135288,0.994152,0.972918,01:16
6,0.251992,0.053589,0.997076,0.997065,01:17
7,0.212317,0.102846,0.991228,0.974188,01:17
8,0.129751,0.044545,0.997076,0.975854,01:16
9,0.134929,0.026157,0.994152,0.972918,01:17


Better model found at epoch 0 with macro_f1 value: 0.669209255021612.
Better model found at epoch 3 with macro_f1 value: 0.8535330078096894.
Better model found at epoch 4 with macro_f1 value: 0.9487717818687559.
Better model found at epoch 5 with macro_f1 value: 0.9729181693428174.
Better model found at epoch 6 with macro_f1 value: 0.9970645568168788.
Better model found at epoch 15 with macro_f1 value: 1.0.


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,1.337345,1.294085,0.675439,0.669209,01:09
1,1.236559,1.198462,0.675439,0.669209,01:08
2,1.117119,1.044421,0.675439,0.669209,01:07
3,0.882942,0.72775,0.795322,0.764322,01:07
4,0.580438,0.437551,0.964912,0.916227,01:13
5,0.379445,0.270538,0.97076,0.943308,01:08
6,0.240799,0.25806,0.967836,0.942504,01:08
7,0.149395,0.211764,0.964912,0.918357,01:11
8,0.112908,0.275332,0.967836,0.919162,01:13
9,0.122587,0.208627,0.973684,0.945107,01:10


Better model found at epoch 0 with macro_f1 value: 0.669209255021612.
Better model found at epoch 3 with macro_f1 value: 0.7643224867476673.
Better model found at epoch 4 with macro_f1 value: 0.9162266509254556.
Better model found at epoch 5 with macro_f1 value: 0.9433084815826382.
Better model found at epoch 9 with macro_f1 value: 0.9451072124756335.
Better model found at epoch 12 with macro_f1 value: 0.9688109161793371.
Better model found at epoch 13 with macro_f1 value: 0.9722763699371887.
Better model found at epoch 16 with macro_f1 value: 0.9758536125259385.
Better model found at epoch 19 with macro_f1 value: 1.0.


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,1.309657,1.309633,0.675439,0.669209,01:17
1,1.232185,1.223068,0.675439,0.669209,01:18
2,1.121137,1.076952,0.675439,0.669209,01:18
3,0.908325,0.804596,0.906433,0.838617,01:18
4,0.592407,0.477121,0.947368,0.811797,01:19
5,0.362449,0.279979,0.964912,0.840909,01:19
6,0.211272,0.112844,0.994152,0.951707,01:19
7,0.176249,0.114506,0.991228,0.950902,01:16
8,0.132029,0.070778,0.997076,0.975854,01:15
9,0.101573,0.077218,0.997076,0.975854,01:16


Better model found at epoch 0 with macro_f1 value: 0.669209255021612.
Better model found at epoch 3 with macro_f1 value: 0.8386167356638158.
Better model found at epoch 5 with macro_f1 value: 0.8409087611124962.
Better model found at epoch 6 with macro_f1 value: 0.951707225051877.
Better model found at epoch 8 with macro_f1 value: 0.9758536125259385.


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,1.309555,1.274917,0.675439,0.669209,01:02
1,1.241909,1.192495,0.675439,0.669209,01:06
2,1.117187,1.004869,0.675439,0.669209,01:10
3,0.89122,0.67287,0.845029,0.784588,01:09
4,0.613238,0.388783,0.935673,0.899067,01:12
5,0.368319,0.152137,0.997076,0.975854,00:56
6,0.198441,0.136451,0.979532,0.915521,00:44
7,0.164577,0.537192,0.915205,0.886102,00:43
8,0.136175,0.023151,1.0,1.0,00:43
9,0.132364,0.062001,0.991228,0.942138,00:43


Better model found at epoch 0 with macro_f1 value: 0.669209255021612.
Better model found at epoch 3 with macro_f1 value: 0.784587740229295.
Better model found at epoch 4 with macro_f1 value: 0.8990674442065895.
Better model found at epoch 5 with macro_f1 value: 0.9758536125259385.
Better model found at epoch 8 with macro_f1 value: 1.0.


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,1.324909,1.26666,0.677419,0.671172,00:44
1,1.240196,1.169759,0.677419,0.671172,00:44
2,1.108539,1.048025,0.677419,0.671172,00:44
3,0.906007,0.866952,0.780059,0.770385,00:44
4,0.620819,0.556007,0.847507,0.816919,00:44
5,0.373381,0.409793,0.906158,0.844614,00:44
6,0.231403,0.20388,0.967742,0.843149,00:44
7,0.154761,0.372893,0.917889,0.693788,00:44
8,0.139878,0.070841,0.997067,0.975783,00:44
9,0.136642,0.123973,0.985337,0.918395,00:44


Better model found at epoch 0 with macro_f1 value: 0.6711717455055464.
Better model found at epoch 3 with macro_f1 value: 0.7703846817319303.
Better model found at epoch 4 with macro_f1 value: 0.8169186080221735.
Better model found at epoch 5 with macro_f1 value: 0.8446135216849052.
Better model found at epoch 8 with macro_f1 value: 0.9757828020054867.
Better model found at epoch 25 with macro_f1 value: 1.0.
