In [1]:
import gc
import torch
import pandas as pd

from functools import partial

from utils.losses import MosLoss
from utils.xception import Mos_Xception
from utils.data_loader import get_data_loaders
from utils.metrics import accuracy, macro_f1

from fastai.vision.all import *
from fastai.callback.tracker import SaveModelCallback

  from .autonotebook import tqdm as notebook_tqdm
INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.21 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations


In [2]:
# Handle to GPU #0 (type 'cuda', index 0), equivalent to the 'cuda:0' spec string.
# Nothing visible in this notebook hands `device` to fastai; the Learner below
# chooses its own default device, so this is only kept for ad-hoc use.
device = torch.device('cuda', 0)

In [3]:
# Average every hyperparameter column of results/params.csv over its rows
# (presumably one row per tuning run / fold — TODO confirm) and use the means
# as the training configuration below.
params_df = pd.read_csv('results/params.csv')
avg_params_series = params_df.mean(axis=0)

# Same averages as a single-row DataFrame whose columns are the CSV headers.
avg_params_df = pd.DataFrame(
    columns=avg_params_series.index,
    data=[avg_params_series.to_numpy()],
)

# Unwrap each one-element column into a plain Python scalar via `.item()`.
mom, alpha, eps, wd = (
    avg_params_df[param_name].item()
    for param_name in ('mom', 'alpha', 'eps', 'wd')
)

In [4]:
# Optimizer factory: fastai's Ranger with the averaged hyperparameters bound in.
# Weight decay is intentionally NOT bound here; it is passed as `Learner(wd=...)`.
opt_func = partial(ranger, eps=eps, alpha=alpha, mom=mom)

# Project loss; the `2` is presumably the number of classes (aedes vs non-aedes)
# — TODO confirm. The "Focal Loss with gamma =  0" line printed after this cell
# appears to come from this constructor; a focal loss with gamma=0 reduces to
# (possibly weighted) cross-entropy.
loss = MosLoss(2)

# NOTE(review): in the logged training output, accuracy 0.997076 (about one error)
# coexists with macro_f1 0.975854, and majority-class accuracy 0.675439 with
# macro_f1 0.669. Neither pair matches a dataset-level binary macro F1, so
# `macro_f1` may be averaged per batch. Consider fastai's
# F1Score(average='macro') (accumulated over the epoch) — TODO confirm.
metrics = [accuracy, macro_f1]

Focal Loss with gamma =  0


In [5]:
# Cross-validated training: a fresh Xception per fold, checkpointed on the best
# validation macro_f1 into a fold-specific weights directory.
EXPERIMENT = 'comparison/aedes_vs_non_aedes'  # shared sub-path for data and weights
N_FOLDS = 5        # folds on disk are numbered 1..N_FOLDS
N_EPOCHS = 60
MAX_LR = 2e-03     # peak learning rate of the one-cycle schedule
SEED = 42          # re-applied at the start of every fold (see below)

for fold in range(1, N_FOLDS + 1):
    model_dir = f'model_weights/{EXPERIMENT}/fold_{fold}'
    data_csv_path = f'data/{EXPERIMENT}/data_fold_{fold}.csv'

    # Seed before building the loaders and the model so each fold's weight init
    # and shuffling are reproducible, even when one fold is re-run on its own.
    # `set_seed` comes from the fastai star import; pass reproducible=True for
    # deterministic cuDNN kernels at some speed cost.
    set_seed(SEED)

    df = pd.read_csv(data_csv_path)
    train_dl, val_dl = get_data_loaders(df)
    dls = DataLoaders(train_dl, val_dl)
    net = Mos_Xception(2)  # presumably 2 output classes — TODO confirm

    learn = Learner(
        dls,
        net,
        wd=wd,
        opt_func=opt_func,
        metrics=metrics,
        loss_func=loss,
        model_dir=model_dir,
    )
    # Writes a checkpoint under model_dir whenever validation macro_f1 improves.
    cb = SaveModelCallback(monitor='macro_f1')
    learn.fit_one_cycle(
        N_EPOCHS,
        MAX_LR,
        div=25,         # schedule starts at MAX_LR / 25
        pct_start=0.3,  # first 30% of iterations ramp the LR up
        cbs=[cb],
    )

    # Release this fold's learner and model before the next fold allocates new
    # ones, so GPU memory does not accumulate across folds.
    del learn, net
    gc.collect()
    torch.cuda.empty_cache()

epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,1.297172,1.233671,0.675439,0.669209,00:57
1,1.170684,0.988914,0.675439,0.669209,00:56
2,0.916764,0.59913,0.935673,0.828012,00:54
3,0.603079,0.283187,0.991228,0.948772,00:57
4,0.328963,0.110243,0.991228,0.950902,01:00
5,0.204975,0.105118,0.979532,0.898869,01:02
6,0.167067,0.449804,0.906433,0.669805,01:17
7,0.16399,0.027297,0.997076,0.975854,01:28
8,0.123404,0.008284,1.0,1.0,01:21
9,0.112257,0.732649,0.833333,0.620854,01:17


  return F.conv2d(input, weight, bias, self.stride,


Better model found at epoch 0 with macro_f1 value: 0.669209255021612.
Better model found at epoch 2 with macro_f1 value: 0.8280119777877519.
Better model found at epoch 3 with macro_f1 value: 0.9487717818687559.
Better model found at epoch 4 with macro_f1 value: 0.9509023454694084.
Better model found at epoch 7 with macro_f1 value: 0.9758536125259385.
Better model found at epoch 8 with macro_f1 value: 1.0.


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,1.305321,1.271091,0.675439,0.669209,01:14
1,1.173772,1.087776,0.675439,0.669209,01:15
2,0.916753,0.72503,0.795322,0.765518,01:16
3,0.610373,0.425313,0.959064,0.914617,01:20
4,0.339193,0.251207,0.964912,0.918357,01:20
5,0.204148,0.272265,0.97076,0.943308,01:20
6,0.174451,0.16758,0.976608,0.922917,01:13
7,0.10973,0.168836,0.97076,0.943308,01:06
8,0.140778,0.128096,0.982456,0.97121,01:02
9,0.123712,0.155225,0.973684,0.944665,01:03


Better model found at epoch 0 with macro_f1 value: 0.669209255021612.
Better model found at epoch 2 with macro_f1 value: 0.7655180136913882.
Better model found at epoch 3 with macro_f1 value: 0.914616891760518.
Better model found at epoch 4 with macro_f1 value: 0.918357214526108.
Better model found at epoch 5 with macro_f1 value: 0.9433084815826382.
Better model found at epoch 8 with macro_f1 value: 0.9712100764732343.
Better model found at epoch 17 with macro_f1 value: 0.9722763699371887.
Better model found at epoch 18 with macro_f1 value: 0.974188344424279.
Better model found at epoch 19 with macro_f1 value: 0.9750487329434697.
Better model found at epoch 23 with macro_f1 value: 1.0.


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,1.326381,1.264544,0.675439,0.669209,00:56
1,1.167452,1.095535,0.675439,0.669209,00:57
2,0.908479,0.784635,0.897661,0.833508,00:57
3,0.600073,0.400205,0.982456,0.925091,01:06
4,0.327363,0.145006,0.997076,0.975854,01:11
5,0.214921,0.120749,0.991228,0.950902,01:17
6,0.159888,0.076547,0.994152,0.951707,01:17
7,0.125911,0.061229,0.997076,0.975854,01:17
8,0.070896,0.148197,0.976608,0.923179,01:16
9,0.130931,0.106873,0.982456,0.939668,01:17


Better model found at epoch 0 with macro_f1 value: 0.669209255021612.
Better model found at epoch 2 with macro_f1 value: 0.8335079041121186.
Better model found at epoch 3 with macro_f1 value: 0.9250906898936875.
Better model found at epoch 4 with macro_f1 value: 0.9758536125259385.


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,1.332821,1.257193,0.675439,0.669209,00:58
1,1.177279,1.035882,0.675439,0.669209,00:58
2,0.897392,0.643032,0.888889,0.829045,00:59
3,0.577398,0.408324,0.935673,0.92147,00:58
4,0.305486,0.146224,0.988304,0.926756,01:00
5,0.161511,0.154287,0.967836,0.935017,00:59
6,0.121405,0.033628,1.0,1.0,00:57
7,0.179066,0.072356,0.991228,0.942138,00:58
8,0.109581,0.015849,1.0,1.0,01:00
9,0.146783,0.012025,1.0,1.0,00:59


Better model found at epoch 0 with macro_f1 value: 0.669209255021612.
Better model found at epoch 2 with macro_f1 value: 0.8290445534144516.
Better model found at epoch 3 with macro_f1 value: 0.921470342522974.
Better model found at epoch 4 with macro_f1 value: 0.9267559579953469.
Better model found at epoch 5 with macro_f1 value: 0.9350165537300039.
Better model found at epoch 6 with macro_f1 value: 1.0.


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,1.295533,1.23632,0.677419,0.671172,00:58
1,1.155248,1.055883,0.677419,0.671172,00:58
2,0.893284,0.817367,0.780059,0.770385,00:57
3,0.585915,0.538807,0.818182,0.797832,00:57
4,0.341437,0.353796,0.923754,0.850921,00:58
5,0.18534,0.155216,0.973607,0.868173,00:58
6,0.164212,0.281294,0.950147,0.859656,00:56
7,0.134654,0.175067,0.961877,0.887851,00:57
8,0.15111,0.152106,0.985337,0.902324,00:58
9,0.08744,0.063742,0.985337,0.918395,00:58


Better model found at epoch 0 with macro_f1 value: 0.6711717455055464.
Better model found at epoch 2 with macro_f1 value: 0.7703846817319303.
Better model found at epoch 3 with macro_f1 value: 0.7978319862869948.
Better model found at epoch 4 with macro_f1 value: 0.8509207467788849.
Better model found at epoch 5 with macro_f1 value: 0.8681733343543958.
Better model found at epoch 7 with macro_f1 value: 0.887850687154573.
Better model found at epoch 8 with macro_f1 value: 0.9023239680887963.
Better model found at epoch 9 with macro_f1 value: 0.9183951901953983.
Better model found at epoch 11 with macro_f1 value: 0.9757828020054867.
Better model found at epoch 30 with macro_f1 value: 1.0.
