In [None]:
import gc
import torch
import pandas as pd

from functools import partial

from utils.losses import MosLoss
from utils.xception import Mos_Xception
from utils.data_loader import get_data_loaders
from utils.metrics import accuracy, macro_f1

from fastai.vision.all import *
from fastai.callback.tracker import SaveModelCallback

In [None]:
# Target device for the run: the first CUDA GPU when one is available,
# otherwise the CPU, so this name is still usable on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Training hyperparameters. The long decimal expansions look like the output
# of an automated search -- TODO confirm the source; values are kept verbatim.

# Weight decay, passed to the Learner as `wd`.
wd = 0.04258532691804346

# Keyword arguments bound into the `ranger` optimizer factory via `partial`.
eps = 0.02241164469571523
alpha = 0.7216872276263672
mom = 0.9531249799613819

In [None]:
# Loss object shared by every fold. The argument 7 is presumably the number
# of classes (the model is built with the same literal) -- keep them in sync.
loss = MosLoss(7)

# Metrics handed to the Learner.
# NOTE(review): `from fastai.vision.all import *` runs after the
# `utils.metrics` import and fastai also exports an `accuracy`, so this name
# likely resolves to fastai's version -- confirm which one is intended.
metrics = [accuracy, macro_f1]

# Optimizer factory: `ranger` with the hyperparameters defined above pre-bound.
opt_func = partial(ranger, eps=eps, alpha=alpha, mom=mom)

In [None]:
# Settings for the cross-validation run. N_CLASSES is the value given to
# Mos_Xception; the loss object is built elsewhere with the same literal 7,
# so the two must be kept in sync.
N_FOLDS = 5
N_CLASSES = 7
N_EPOCHS = 60
LR_MAX = 2e-03

# Train one fresh model per fold. Each fold reads its own CSV and uses its
# own checkpoint directory.
for fold in range(1, N_FOLDS + 1):
    model_dir = f'model_weights/comparison/aedes_only/fold_{fold}'
    data_csv_path = f'data/comparison/aedes_only/data_fold_{fold}.csv'

    df = pd.read_csv(data_csv_path)
    train_dl, val_dl = get_data_loaders(df)
    dls = DataLoaders(train_dl, val_dl)
    net = Mos_Xception(N_CLASSES)

    # Bound up front so the `finally` clause can always `del` it, even if
    # constructing the Learner itself raises.
    learn = None
    try:
        learn = Learner(
            dls,
            net,
            wd=wd,
            opt_func=opt_func,
            metrics=metrics,
            loss_func=loss,
            model_dir=model_dir,
        )
        # Checkpointing callback that tracks the `macro_f1` metric.
        cb = SaveModelCallback(monitor='macro_f1')
        # One-cycle schedule: N_EPOCHS epochs with a peak learning rate of LR_MAX.
        learn.fit_one_cycle(
            N_EPOCHS,
            LR_MAX,
            div=25,
            pct_start=0.3,
            cbs=[cb],
        )
    finally:
        # Runs on success, on error and on a manual interrupt alike, so this
        # cell never leaves `learn`/`net` bound in the kernel after a failed
        # fold. Dropping the references lets the collector reclaim the model
        # before cached CUDA memory is released for the next fold.
        del learn
        del net
        gc.collect()
        torch.cuda.empty_cache()