In [1]:
import os
import gc
import torch
import pandas as pd
import numpy as np
import utils.constants as const

from functools import partial

from utils.losses import MosLoss
from utils.xception import Mos_Xception
from utils.data_loader import get_data_loaders
from utils.metrics import accuracy, macro_f1

from fastai.vision.all import *
from fastai.distributed import setup_distrib, num_distrib

  from .autonotebook import tqdm as notebook_tqdm
INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.13 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations


In [2]:
device = torch.device('cuda:0')

In [3]:
df = pd.read_csv('data/splits/data_fold_1.csv')

In [4]:
train_dl, val_dl = get_data_loaders(df)

In [5]:
dls = DataLoaders(train_dl, val_dl)

In [6]:
loss = MosLoss(const.NUM_CLASSES)
freeze_bn = False
save_imgs = False
train_losses = []
valid_losses = []
valid_f1s = []
lr_hist = []
metrics = [accuracy, macro_f1]

Focal Loss with gamma =  0


In [7]:
result_dict = {}

for idx in range(40):
    mom = np.random.uniform(0.5, 0.999)
    alpha = np.random.uniform(0.5, 0.999)
    eps = np.random.uniform(1e-6, 0.1)
    wd = np.random.uniform(1e-6, 0.1)

    opt_func = partial(ranger, mom=mom, alpha=alpha, eps=eps)
    net = Mos_Xception(const.NUM_CLASSES)

    learn = Learner(
        dls,
        net,
        wd=wd,
        opt_func=opt_func,
        metrics=metrics,
        loss_func=loss,
    )

    learn.fit_one_cycle(
        20,
        2e-03,
        div=25,
        pct_start=0.3,
    )

    param_dict = {}
    param_dict['mom'] = mom
    param_dict['alpha'] = alpha
    param_dict['eps'] = eps
    param_dict['wd'] = wd
    param_dict['f1'] = learn.metrics[1].value.detach().cpu().numpy()

    result_dict[idx] = param_dict

    del net
    del learn
    gc.collect()
    torch.cuda.empty_cache()

epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.388383,4.626642,0.105769,0.031825,00:48
1,3.753207,3.094482,0.371795,0.207645,00:43
2,3.074545,2.64074,0.560897,0.427188,00:41
3,2.270517,1.555839,0.746795,0.547961,00:40
4,1.527138,0.945967,0.798077,0.604068,00:40
5,1.099044,0.671763,0.894231,0.705602,00:40
6,0.851959,0.614067,0.88141,0.749498,00:39
7,0.668011,0.603627,0.871795,0.748958,00:39
8,0.543086,0.466125,0.907051,0.727374,00:41
9,0.467613,0.334209,0.945513,0.840258,00:39


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.494915,4.695302,0.134615,0.094527,00:39
1,3.763239,3.057135,0.467949,0.313587,00:39
2,2.956151,2.381108,0.634615,0.492284,00:40
3,2.058707,1.312881,0.759615,0.624878,00:39
4,1.340815,0.753065,0.88141,0.720202,00:39
5,0.998069,0.763396,0.833333,0.700978,00:40
6,0.779395,0.524658,0.891026,0.777396,00:40
7,0.636154,0.41824,0.916667,0.759758,00:39
8,0.546426,0.324701,0.932692,0.786983,00:40
9,0.455163,0.376931,0.916667,0.799521,00:41


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.57075,4.8373,0.176282,0.05457,00:40
1,3.874661,3.150304,0.285256,0.146938,00:40
2,3.151832,2.830421,0.490385,0.371906,00:42
3,2.570969,1.937144,0.612179,0.528033,00:50
4,1.907531,1.357117,0.766026,0.589991,00:59
5,1.411213,0.989358,0.830128,0.655266,00:53
6,1.09115,0.944586,0.791667,0.617325,00:53
7,0.930081,0.821382,0.810897,0.680903,00:46
8,0.746845,0.646447,0.858974,0.732689,00:47
9,0.673956,0.555879,0.891026,0.700743,00:43


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.270496,4.35019,0.217949,0.073766,00:53
1,3.625669,2.999589,0.458333,0.251232,00:50
2,2.864018,2.14129,0.698718,0.557424,00:54
3,1.936285,1.152844,0.817308,0.600003,01:04
4,1.2812,0.757415,0.849359,0.708487,01:11
5,0.939317,0.744052,0.846154,0.707452,01:03
6,0.701419,0.624985,0.86859,0.701612,00:47
7,0.583972,0.322344,0.951923,0.817118,00:47
8,0.514692,0.562695,0.887821,0.733011,00:46
9,0.41194,0.525238,0.903846,0.728181,00:46


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.326081,4.49226,0.108974,0.038997,00:46
1,3.695799,3.088612,0.266026,0.20456,00:45
2,3.03404,2.586499,0.528846,0.458117,00:48
3,2.254935,1.62245,0.679487,0.53805,00:47
4,1.54596,0.960519,0.846154,0.660224,00:46
5,1.132451,0.599189,0.913462,0.716815,00:45
6,0.825319,0.624347,0.88141,0.691333,00:45
7,0.681744,0.550427,0.897436,0.735889,00:45
8,0.60109,0.365117,0.945513,0.835752,00:40
9,0.510621,0.31935,0.932692,0.815163,00:38


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.725876,5.139077,0.115385,0.091933,00:38
1,4.091434,3.23325,0.282051,0.123437,00:40
2,3.268309,2.948835,0.525641,0.400188,00:42
3,2.81146,2.364954,0.621795,0.513138,00:41
4,2.177929,1.664542,0.698718,0.554364,00:41
5,1.647619,1.183308,0.826923,0.653397,00:41
6,1.286002,0.905324,0.875,0.715345,00:42
7,1.035061,0.746001,0.88141,0.738702,00:41
8,0.837384,0.661846,0.865385,0.701492,00:41
9,0.702955,0.521253,0.897436,0.725483,00:41


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.46937,4.794403,0.182692,0.064437,00:41
1,3.854776,3.118571,0.400641,0.200614,00:41
2,3.156337,2.79628,0.576923,0.499314,00:43
3,2.571203,1.917367,0.641026,0.549293,00:42
4,1.872628,1.260337,0.788462,0.654411,00:42
5,1.360067,0.944178,0.817308,0.646178,00:41
6,1.050812,0.697309,0.900641,0.738395,00:41
7,0.848539,0.62482,0.875,0.698251,00:42
8,0.658498,0.41944,0.926282,0.737015,00:42
9,0.542286,0.44808,0.919872,0.782731,00:42


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.51086,4.792332,0.125,0.043676,00:41
1,3.883568,3.139797,0.371795,0.197589,00:42
2,3.164015,2.820096,0.589744,0.39833,00:38
3,2.562781,1.934967,0.682692,0.532946,00:38
4,1.889541,1.36105,0.714744,0.547054,00:41
5,1.380267,0.925286,0.862179,0.658726,00:41
6,1.040292,0.657581,0.891026,0.767992,00:41
7,0.877299,0.543946,0.916667,0.730249,00:41
8,0.68822,0.585415,0.884615,0.709458,00:41
9,0.585268,0.475172,0.919872,0.796984,00:42


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.604257,5.059247,0.102564,0.04794,00:42
1,4.09977,3.275434,0.288462,0.100655,00:43
2,3.262072,2.947927,0.464744,0.289577,00:43
3,2.817384,2.370878,0.650641,0.496604,00:43
4,2.209107,1.67196,0.724359,0.565393,00:44
5,1.698146,1.253266,0.727564,0.582919,00:43
6,1.331243,0.92848,0.830128,0.718876,00:43
7,1.077989,0.76434,0.86859,0.749317,00:43
8,0.948615,0.647346,0.891026,0.740912,00:43
9,0.783851,0.599688,0.871795,0.754585,00:43


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,4.952685,3.824902,0.285256,0.109001,00:40
1,3.420834,2.912424,0.512821,0.351944,00:41
2,2.478171,1.617266,0.673077,0.555516,00:40
3,1.501733,0.840403,0.849359,0.699917,00:41
4,1.001505,0.875589,0.801282,0.708859,00:40
5,0.812283,0.701729,0.858974,0.724652,00:40
6,0.706637,0.399001,0.926282,0.736542,00:40
7,0.552191,0.488802,0.884615,0.79803,00:39
8,0.481548,0.411132,0.907051,0.76252,00:38
9,0.402501,0.267804,0.948718,0.841184,00:38


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.474741,4.708409,0.163462,0.069329,00:42
1,3.783711,3.098645,0.36859,0.233397,00:42
2,3.110077,2.74506,0.535256,0.423661,00:42
3,2.433393,1.824699,0.676282,0.519363,00:42
4,1.735528,1.199501,0.801282,0.644647,00:43
5,1.224227,0.978705,0.798077,0.681427,00:44
6,0.920206,0.615394,0.894231,0.693991,00:43
7,0.743942,0.622626,0.862179,0.726454,00:44
8,0.617832,0.466572,0.919872,0.765215,00:42
9,0.554666,0.548922,0.884615,0.7366,00:41


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.153141,4.013957,0.147436,0.059207,00:40
1,3.462772,2.838948,0.516026,0.38684,00:41
2,2.307128,1.364546,0.766026,0.601398,00:41
3,1.382425,0.778783,0.826923,0.723804,00:40
4,0.948927,1.018397,0.794872,0.689967,00:40
5,0.789161,0.877922,0.798077,0.669952,00:40
6,0.671381,0.856814,0.804487,0.627315,00:39
7,0.583019,0.363488,0.935897,0.79982,00:39
8,0.46413,0.302937,0.948718,0.809928,00:39
9,0.359888,0.193341,0.977564,0.895898,00:39


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.067321,3.945699,0.173077,0.097208,00:43
1,3.345327,2.492831,0.570513,0.422649,00:47
2,2.045078,1.08635,0.807692,0.636695,00:45
3,1.221819,0.635113,0.86859,0.729317,00:45
4,0.948761,0.814656,0.836538,0.747909,00:43
5,0.832364,0.424534,0.916667,0.755081,00:43
6,0.658318,0.320222,0.958333,0.842376,00:46
7,0.488775,0.493722,0.903846,0.7895,00:52
8,0.446339,0.569004,0.875,0.732517,00:46
9,0.382112,0.408643,0.919872,0.803267,00:49


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.247444,4.199153,0.128205,0.047598,00:59
1,3.600575,3.046728,0.349359,0.257091,01:07
2,2.919365,2.312402,0.573718,0.48757,01:06
3,2.065107,1.34285,0.717949,0.556573,01:01
4,1.382361,0.940584,0.820513,0.674379,01:01
5,1.030403,0.666869,0.875,0.701826,01:02
6,0.79951,0.531652,0.891026,0.70156,00:53
7,0.686096,0.605934,0.875,0.714793,00:51
8,0.5352,0.776924,0.846154,0.741172,00:54
9,0.51918,0.625945,0.852564,0.756883,00:49


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.017891,3.910385,0.208333,0.102228,00:49
1,3.397613,2.757359,0.586538,0.467068,00:52
2,2.215454,1.19849,0.75641,0.611385,00:51
3,1.314138,0.683961,0.871795,0.759498,00:50
4,0.957775,1.026413,0.801282,0.618067,00:52
5,0.786148,0.83764,0.801282,0.664756,00:51
6,0.67362,0.390826,0.929487,0.82977,00:49
7,0.522382,0.370598,0.935897,0.85502,00:52
8,0.430772,0.300395,0.939103,0.87604,00:50
9,0.424124,0.226514,0.964744,0.826707,00:49


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.705911,5.094882,0.112179,0.040702,00:50
1,4.112972,3.265131,0.330128,0.143432,00:53
2,3.258223,2.959659,0.516026,0.362613,00:51
3,2.798096,2.339893,0.63141,0.49465,00:49
4,2.182274,1.63904,0.778846,0.570599,00:51
5,1.659749,1.142003,0.830128,0.651109,00:51
6,1.267836,0.85667,0.852564,0.685753,00:46
7,1.038953,0.70603,0.891026,0.765789,00:49
8,0.862748,0.54921,0.916667,0.757417,00:52
9,0.720063,0.516627,0.916667,0.724111,00:48


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.12729,4.158913,0.160256,0.057164,00:50
1,3.56958,3.030526,0.416667,0.226006,00:50
2,2.890696,2.236866,0.666667,0.450243,00:51
3,2.006766,1.299677,0.766026,0.627888,00:48
4,1.32011,0.827487,0.846154,0.665503,00:50
5,0.93284,0.598531,0.887821,0.76567,00:46
6,0.777465,0.580817,0.907051,0.739534,00:51
7,0.652724,0.49213,0.910256,0.756661,00:46
8,0.577191,0.474318,0.900641,0.756781,00:50
9,0.441695,0.401372,0.916667,0.797996,00:47


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.15308,4.411551,0.121795,0.050838,00:46
1,3.670273,3.116214,0.323718,0.169057,00:44
2,3.070531,2.722988,0.589744,0.469264,00:48
3,2.384152,1.677537,0.698718,0.556527,00:43
4,1.672342,1.089719,0.833333,0.601152,00:49
5,1.189459,0.777742,0.875,0.742738,00:43
6,0.892136,0.600909,0.903846,0.737221,00:48
7,0.764802,0.669924,0.86859,0.726092,00:43
8,0.640893,0.734135,0.855769,0.747496,00:47
9,0.508473,0.365645,0.935897,0.817808,00:44


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.559108,4.851271,0.192308,0.08195,00:47
1,3.834378,3.087545,0.426282,0.262517,00:44
2,3.071492,2.576144,0.605769,0.45651,00:51
3,2.224198,1.513752,0.772436,0.57771,00:45
4,1.537694,0.951546,0.852564,0.675225,00:48
5,1.06225,0.688414,0.86859,0.693888,00:45
6,0.854334,0.789347,0.823718,0.695182,00:47
7,0.65907,0.372745,0.939103,0.751914,00:47
8,0.565407,0.35697,0.935897,0.800007,00:44
9,0.504418,0.391582,0.929487,0.755431,00:49


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.287638,4.46257,0.134615,0.049848,00:54
1,3.671399,3.056487,0.365385,0.244936,00:51
2,2.959503,2.38888,0.628205,0.47789,00:48
3,2.103225,1.324012,0.769231,0.608257,00:49
4,1.371084,0.757999,0.88141,0.746664,00:46
5,0.981699,0.751638,0.842949,0.630792,00:51
6,0.7962,0.429759,0.926282,0.812783,00:47
7,0.629124,0.479885,0.929487,0.777881,00:50
8,0.544611,0.406147,0.942308,0.841501,00:56
9,0.436904,0.407689,0.916667,0.790515,00:51


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.152041,4.102851,0.230769,0.076196,00:52
1,3.528152,2.980815,0.394231,0.315717,00:49
2,2.781595,2.055377,0.679487,0.514212,00:46
3,1.851027,1.270018,0.740385,0.566384,00:51
4,1.209324,0.838944,0.826923,0.683908,00:46
5,0.926889,0.758864,0.871795,0.725323,00:50
6,0.685703,0.55513,0.907051,0.732565,00:48
7,0.609096,0.568272,0.88141,0.763387,00:46
8,0.539895,0.327851,0.945513,0.794719,00:53
9,0.442021,0.351268,0.932692,0.79538,00:46


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.26048,4.290683,0.208333,0.084727,00:50
1,3.611539,3.027643,0.423077,0.251665,00:52
2,2.90326,2.285857,0.644231,0.541721,00:50
3,2.04705,1.331547,0.785256,0.602783,00:52
4,1.360525,0.861894,0.836538,0.659349,00:52
5,0.984293,0.869475,0.826923,0.672235,00:49
6,0.768784,0.464875,0.907051,0.78368,00:51
7,0.624324,0.421268,0.916667,0.791861,00:52
8,0.534028,0.514653,0.894231,0.76217,00:49
9,0.460283,0.547199,0.897436,0.720086,00:53


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.695277,5.199708,0.195513,0.09019,00:50
1,4.100197,3.256835,0.240385,0.170367,00:49
2,3.255194,2.931527,0.519231,0.346982,00:48
3,2.768822,2.266013,0.61859,0.483254,00:53
4,2.105052,1.584389,0.721154,0.592841,00:48
5,1.597248,1.08701,0.794872,0.610978,00:46
6,1.229026,0.840826,0.839744,0.686528,00:50
7,1.00615,0.702608,0.891026,0.718889,00:48
8,0.838206,0.577685,0.910256,0.773215,00:46
9,0.722818,0.529097,0.897436,0.754678,00:52


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.233762,4.344634,0.092949,0.032588,00:54
1,3.620209,3.056325,0.310897,0.246365,00:50
2,2.94091,2.358531,0.528846,0.44118,00:56
3,2.072086,1.375879,0.75641,0.577283,00:54
4,1.395664,0.961443,0.788462,0.623181,00:51
5,0.945439,0.910666,0.791667,0.644506,00:56
6,0.746895,0.765964,0.820513,0.702812,00:54
7,0.635513,0.508323,0.907051,0.736093,00:48
8,0.535142,0.501108,0.897436,0.764641,00:54
9,0.456871,0.396325,0.916667,0.793812,00:52


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.337611,4.536916,0.134615,0.046183,00:53
1,3.671765,3.03539,0.490385,0.273991,00:53
2,2.887913,2.186141,0.628205,0.489452,00:51
3,1.899335,1.094779,0.817308,0.637706,00:55
4,1.222484,0.932632,0.794872,0.646083,00:53
5,0.946028,0.582301,0.894231,0.70652,00:51
6,0.731677,0.419765,0.916667,0.754202,00:54
7,0.635156,0.439684,0.926282,0.765765,00:53
8,0.562552,0.365485,0.932692,0.806118,00:52
9,0.432054,0.365813,0.919872,0.783623,00:54


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.733775,5.162425,0.115385,0.045193,00:50
1,4.145663,3.282354,0.275641,0.15062,00:55
2,3.277045,2.9593,0.448718,0.323441,00:52
3,2.832532,2.395918,0.602564,0.46497,00:50
4,2.249864,1.735307,0.705128,0.569577,00:54
5,1.724947,1.264755,0.791667,0.610541,00:52
6,1.389249,0.99402,0.798077,0.65653,00:50
7,1.104144,0.792783,0.836538,0.678607,00:53
8,0.935892,0.66613,0.884615,0.716972,00:52
9,0.810388,0.674229,0.862179,0.710239,00:51


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.413216,4.705816,0.157051,0.094254,00:52
1,3.881714,3.165684,0.272436,0.158133,00:51
2,3.197944,2.882307,0.535256,0.395109,00:55
3,2.642581,2.062723,0.666667,0.499974,00:52
4,1.947958,1.421503,0.724359,0.556337,00:50
5,1.432058,0.988738,0.820513,0.633308,00:53
6,1.128645,0.749155,0.862179,0.733914,00:51
7,0.922015,0.67142,0.88141,0.737801,00:51
8,0.752288,0.514603,0.903846,0.752318,00:54
9,0.666339,0.494857,0.907051,0.718384,00:52


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.493388,4.569917,0.035256,0.015159,00:54
1,3.71015,3.054714,0.419872,0.236017,00:52
2,2.976953,2.438682,0.63141,0.455532,00:51
3,2.135349,1.413846,0.705128,0.542968,00:54
4,1.4248,0.885544,0.814103,0.650037,00:53
5,1.05742,0.935396,0.782051,0.671019,00:51
6,0.786856,0.799718,0.836538,0.695867,00:54
7,0.690268,0.509033,0.903846,0.807875,00:53
8,0.555748,0.400221,0.939103,0.786594,00:52
9,0.472716,0.387187,0.929487,0.753805,00:55


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.130975,4.247964,0.13141,0.053347,00:50
1,3.558419,3.01604,0.426282,0.318163,00:54
2,2.797322,2.049629,0.596154,0.483261,00:54
3,1.870802,1.187185,0.814103,0.627828,00:50
4,1.251925,0.750763,0.88141,0.730681,00:55
5,0.916384,1.600143,0.650641,0.551619,00:54
6,0.740174,0.57564,0.88141,0.743021,00:53
7,0.606145,0.321977,0.942308,0.746456,00:55
8,0.490706,0.431722,0.926282,0.79964,00:52
9,0.4587,0.311048,0.945513,0.825538,00:50


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.43269,4.769997,0.141026,0.067205,00:47
1,3.872469,3.150481,0.326923,0.1891,00:46
2,3.206719,2.904166,0.461538,0.334474,00:49
3,2.693701,2.13156,0.612179,0.525406,00:47
4,2.005338,1.368118,0.778846,0.65409,00:46
5,1.475187,0.984069,0.855769,0.678199,00:50
6,1.117987,0.794996,0.849359,0.728069,00:47
7,0.934096,0.670184,0.858974,0.760445,00:46
8,0.74786,0.579903,0.891026,0.803002,00:50
9,0.607489,0.600938,0.884615,0.770167,00:47


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.611759,4.96698,0.144231,0.051192,00:54
1,3.967795,3.161952,0.266026,0.146727,00:52
2,3.213377,2.90847,0.496795,0.381487,00:51
3,2.680456,2.117062,0.650641,0.482642,00:56
4,1.966483,1.38278,0.727564,0.595366,00:52
5,1.417611,0.990224,0.830128,0.693272,00:51
6,1.086979,0.698704,0.894231,0.725357,00:53
7,0.906882,0.754412,0.858974,0.741117,00:53
8,0.743549,0.670938,0.842949,0.717512,00:50
9,0.625919,0.466268,0.916667,0.797837,00:54


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.596384,4.895792,0.083333,0.035889,00:50
1,3.855265,3.107989,0.375,0.193408,00:53
2,3.0786,2.636666,0.599359,0.402931,00:52
3,2.306072,1.653246,0.730769,0.5549,00:50
4,1.610137,0.962506,0.820513,0.669769,00:53
5,1.087664,0.739907,0.855769,0.708311,00:51
6,0.850458,0.613646,0.884615,0.695626,00:50
7,0.682979,0.531895,0.891026,0.707427,00:54
8,0.57503,0.477169,0.900641,0.708159,00:51
9,0.515116,0.35243,0.929487,0.77435,00:50


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.674442,5.079943,0.173077,0.056981,00:52
1,4.031203,3.206604,0.365385,0.156917,00:49
2,3.22102,2.889342,0.567308,0.419092,00:53
3,2.70891,2.181782,0.666667,0.47595,00:51
4,2.018736,1.483085,0.727564,0.571179,00:49
5,1.453114,0.997568,0.849359,0.684812,00:53
6,1.112949,0.852041,0.836538,0.697155,00:51
7,0.891857,0.556015,0.900641,0.757007,00:49
8,0.701522,0.502459,0.913462,0.800094,00:53
9,0.65653,0.573312,0.884615,0.734763,00:51


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.581265,4.966048,0.147436,0.067117,00:52
1,3.975334,3.194782,0.339744,0.154264,00:51
2,3.212471,2.889662,0.535256,0.356543,00:50
3,2.680914,2.179821,0.628205,0.484906,00:53
4,2.025129,1.486148,0.730769,0.550498,00:52
5,1.511409,1.00197,0.820513,0.660388,00:51
6,1.156821,0.809045,0.86859,0.705145,00:55
7,0.955123,0.663,0.878205,0.721162,00:52
8,0.823726,0.633358,0.878205,0.700553,00:51
9,0.672769,0.556624,0.897436,0.747288,00:53


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.476366,4.858248,0.144231,0.103026,00:54
1,3.960264,3.21154,0.330128,0.163631,00:55
2,3.225345,2.930242,0.477564,0.356859,00:55
3,2.731917,2.233766,0.682692,0.483138,00:51
4,2.074869,1.510005,0.75,0.567638,00:52
5,1.546765,1.03043,0.833333,0.655958,00:52
6,1.212858,0.807123,0.86859,0.750607,00:49
7,0.983949,0.644108,0.903846,0.768673,00:53
8,0.816777,0.57203,0.916667,0.759891,00:53
9,0.655034,0.475985,0.926282,0.827821,00:51


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.480604,4.740539,0.105769,0.063298,00:54
1,3.800413,3.090811,0.38141,0.185399,00:56
2,3.087356,2.638223,0.580128,0.419961,00:54
3,2.289302,1.595,0.660256,0.569058,00:53
4,1.600831,1.019678,0.791667,0.642905,00:52
5,1.167267,0.717864,0.865385,0.751488,00:53
6,0.88251,0.610941,0.875,0.706219,00:51
7,0.732744,0.714207,0.839744,0.738768,00:51
8,0.602704,0.450563,0.923077,0.78612,00:54
9,0.470917,0.295759,0.948718,0.82568,00:52


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.546942,4.711398,0.144231,0.09661,00:54
1,3.827469,3.10648,0.352564,0.183789,00:56
2,3.115403,2.752779,0.576923,0.454597,00:57
3,2.461554,1.844105,0.596154,0.506556,00:57
4,1.789437,1.221277,0.766026,0.607979,00:57
5,1.291622,0.897156,0.817308,0.692792,00:58
6,0.972619,0.8086,0.842949,0.727167,00:57
7,0.751177,0.590295,0.878205,0.681364,00:59
8,0.718023,0.495045,0.919872,0.755756,00:59
9,0.555606,0.459787,0.916667,0.773794,00:59


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.11875,4.208874,0.108974,0.064229,00:58
1,3.52528,2.965911,0.458333,0.341873,00:56
2,2.624452,1.71309,0.714744,0.522537,00:54
3,1.638039,0.989586,0.807692,0.678728,00:57
4,1.095752,0.891355,0.798077,0.607472,00:56
5,0.800682,0.486683,0.913462,0.789743,00:56
6,0.702318,0.554834,0.88141,0.719515,00:57
7,0.568287,0.432162,0.900641,0.779693,00:56
8,0.435289,0.317327,0.948718,0.826632,00:55
9,0.43484,0.34785,0.935897,0.850066,00:56


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.55151,4.849012,0.096154,0.040976,00:52
1,3.905227,3.159143,0.269231,0.113293,00:51
2,3.188901,2.876618,0.592949,0.422355,00:48
3,2.618545,2.004645,0.660256,0.507553,00:52
4,1.915954,1.311949,0.733974,0.561425,00:51
5,1.380123,0.910552,0.830128,0.729676,00:48
6,1.023581,0.788143,0.823718,0.682446,00:52
7,0.84886,0.635282,0.887821,0.729995,00:50
8,0.692253,0.603904,0.891026,0.716635,00:48
9,0.562774,0.52514,0.897436,0.808524,00:51


epoch,train_loss,valid_loss,accuracy,macro_f1,time
0,5.754382,5.229441,0.121795,0.045857,00:49
1,4.140485,3.298325,0.246795,0.102432,00:51
2,3.285674,2.9716,0.464744,0.311375,00:50
3,2.817011,2.348865,0.612179,0.510016,00:48
4,2.167916,1.636243,0.727564,0.56118,00:51
5,1.644097,1.157609,0.826923,0.676408,00:50
6,1.281562,0.873011,0.839744,0.701778,00:48
7,1.014905,0.684178,0.884615,0.738608,00:53
8,0.827829,0.645408,0.871795,0.723135,00:49
9,0.681789,0.494928,0.923077,0.809784,00:49


In [8]:
f1_list = []

for v in result_dict.values():
    f1_list.append(v['f1'])

In [9]:
best_idx = np.argmax(f1_list)

In [10]:
f1_list[best_idx]

array(0.95580472)

In [11]:
result_dict[best_idx]

{'mom': 0.9531249799613819,
 'alpha': 0.7216872276263672,
 'eps': 0.02241164469571523,
 'wd': 0.04258532691804346,
 'f1': array(0.95580472)}

In [12]:
with open('param.txt', 'w') as f:
    for k, v in result_dict[best_idx].items():
        f.write(f'{k}: {v}\n')