In [1]:
import numpy as np
import os
import matplotlib.pyplot as plt
import seaborn as sns
from glob import glob
from torch.autograd import Variable
import sys  

sys.path.insert(0, '../src')

In [2]:
def load_data(t, th, debug):
    '''
    t : train or test
    th : 특정 label의 갯수 임계치
    '''
    
    if debug:
        data_list = glob(f'../data/{t}/*npy')[:100]
    else:
        data_list = glob(f'../data/{t}/*npy')
    
    if t == 'train':
        X_list = list()
        y_list = list()
        for data_path in data_list:
            data_arr = np.load(data_path)
            y = data_arr[:,:,-1]
            cut_off = np.where(y < 0, 0, y)
            if (cut_off > 0).sum() < th:
                continue
            
            X = data_arr[:,:,:9]
            
            X_list.append(X)
            y_list.append(y)
        return np.array(X_list), np.array(y_list)
    else:
        X_list = list()
        
        for data_path in data_list:
            data_arr = np.load(data_path)
            X = data_arr[:,:,:9]
            X_list.append(X)
        return np.array(X_list)

In [3]:
X_train, y_train = load_data('train', 0, False)

In [4]:
X_train.shape

(76345, 40, 40, 9)

In [5]:
from torch.utils.data import DataLoader, Dataset
import torch

class u_net_dataset(Dataset):
    def __init__(self, X_arr, y_arr, trans):
        self.X = X_arr.transpose(0, 3, 1, 2) # B H W C - > B C H W
        self.y = y_arr
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, index):
        
        data = self.X[index]
        target = self.y[index]
        
        data = torch.from_numpy(data).float() 
        target = torch.from_numpy(target).long()
        
        return data, target

In [6]:
from sklearn.model_selection import train_test_split

In [7]:
SEED = 42
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=SEED)

train_dataset = u_net_dataset(X_train, y_train, None)
train_dataloader = DataLoader(train_dataset, batch_size=512, shuffle=True, num_workers=6)

val_dataset = u_net_dataset(X_val, y_val, None)
val_dataloader = DataLoader(val_dataset, batch_size=512, shuffle=True, num_workers=6)

In [8]:
import torch
from torch import nn
from tqdm.notebook import tqdm
from unet_model import UNet
# import segmentation_models_pytorch as smp

In [9]:
# class build_model(nn.Module):
#     def __init__(self, model_name='resnet18'):
    
#         super().__init__()

#         self.conv1 = nn.Conv2d(9, 3, 1)
#         self.model = smp.Unet('efficientnet-b0', encoder_weights='imagenet')
        
#     def forward(self, x):
#         x = self.conv1(x)
#         x = self.model(x)
#         return x

In [10]:
from sklearn.metrics import f1_score

def mae(y_true, y_pred) :
    
    y_true, y_pred = np.array(y_true), np.array(y_pred)
    
    y_true = y_true.reshape(1, -1)[0]
    
    y_pred = y_pred.reshape(1, -1)[0]
    
    over_threshold = y_true >= 0.1
    
    return np.mean(np.abs(y_true[over_threshold] - y_pred[over_threshold]))

def fscore(y_true, y_pred):
    
    y_true, y_pred = np.array(y_true), np.array(y_pred)
    
    y_true = y_true.reshape(1, -1)[0]
    
    y_pred = y_pred.reshape(1, -1)[0]
    
    remove_NAs = y_true >= 0
    
    y_true = np.where(y_true[remove_NAs] >= 0.1, 1, 0)
    
    y_pred = np.where(y_pred[remove_NAs] >= 0.1, 1, 0)
    
    return(f1_score(y_true, y_pred))

def maeOverFscore(y_true, y_pred):
    
    return mae(y_true, y_pred) / (fscore(y_true, y_pred) + 1e-07)

In [11]:
from loss import dice_loss
import torch.nn.functional as F

def calc_loss(pred, target, bce_weight=0.5):
    bce = F.binary_cross_entropy_with_logits(pred, target)

    pred = F.sigmoid(pred)
    dice = dice_loss(pred, target)

    loss = bce * bce_weight + dice * (1 - bce_weight)

    return loss

In [17]:
# model = ResNetUNet(n_class=1)
model = UNet(n_channels=9, n_classes=1)
# model = build_model()
model = model.cuda()

In [15]:
N_epoch = 100
lr = 0.001

In [18]:
loss_fn = nn.L1Loss()
opt = torch.optim.Adam(model.parameters(),lr=lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=10)


best_score = 99999
for epoch in range(N_epoch):
    
    model.train()
    train_loss = 0
    pred_list = list()
    label_list = list()
    for data, label in tqdm(train_dataloader):
        data = data.cuda()
        label = label.cuda()
        
        opt.zero_grad()
        
        pred = model(data)
#         pred = F.relu(pred)
        pred = pred.squeeze(1)
        loss = loss_fn(pred, label)# + calc_loss(pred.long(), (label > 0).long())
        loss.backward()
        opt.step()
        
        pred_list.append(pred.cpu().detach().numpy())
        label_list.append(label.cpu().detach().numpy())
        train_loss += loss.item() / len(train_dataloader)
    
    trn_score = maeOverFscore(np.concatenate(label_list), np.concatenate(pred_list))
    model.eval()
    val_loss = 0
    pred_list = list()
    label_list = list()
    for data, label in tqdm(val_dataloader):
        data = data.cuda()
        label = label.cuda()
        
        pred = model(data)
        loss = loss_fn(pred.squeeze(1), label)
        
        pred_list.append(pred.cpu().detach().numpy())
        label_list.append(label.cpu().detach().numpy())
        val_loss += loss.item() / len(val_dataloader)
    val_score = maeOverFscore(np.concatenate(label_list), np.concatenate(pred_list))
    
    scheduler.step(val_score)
    
    if best_score > val_score:
        best_score = val_score
        
    print(f'train loss : {train_loss}')
    print(f'train score : {trn_score}')
    print(f'val loss : {val_loss}')
    print(f'val score : {val_score}')

HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.36658532340079
train score : 55.98802838315464
val loss : 48.42852594032885
val score : 8.888214284736987


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.33987005439897
train score : 6.608607072259034
val loss : 48.558324311425295
val score : 5.68699448968851


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.543622788911044
train score : 5.496294002512335
val loss : 48.78283096750578
val score : 5.003361328169913


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.3262605467811
train score : 4.983710608842951
val loss : 48.832499465843036
val score : 5.476523961655967


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.72460315966359
train score : 4.780707798374048
val loss : 48.79943147599697
val score : 5.378353083729722


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.32129685754577
train score : 4.488835425644576
val loss : 49.085793022314704
val score : 20.003306261186168


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.319294883000374
train score : 4.3455971880021504
val loss : 48.61099177549282
val score : 11.974837906632539


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.772336909050715
train score : 4.187966246216137
val loss : 49.86903400421142
val score : 19.90362951458968


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.71660782881082
train score : 4.07453594602819
val loss : 48.898201958338404
val score : 17.82695248772987


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.71580686730642
train score : 4.022407221153109
val loss : 48.8317334425946
val score : 5.491708426181774


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.31421952955423
train score : 3.9218288357646025
val loss : 49.07314424564442
val score : 5.115325575264233


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 44.1141941993187
train score : 3.8938058230086154
val loss : 48.869985021402435
val score : 11.46257960011016


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.71369865325588
train score : 3.860563707815969
val loss : 48.51326562886436
val score : 5.025199581946621


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.71330483201891
train score : 3.8112267609119557
val loss : 48.683806691070394
val score : 4.622776131388924


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.31269412438076
train score : 3.80745517219721
val loss : 48.612604243059955
val score : 10.700223848082375


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.71173368580638
train score : 3.741716246931868
val loss : 48.61765677134196
val score : 4.680731129654193


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 44.11169564953692
train score : 3.7189669866337667
val loss : 48.55442867577076
val score : 4.491060270329788


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.711185401429745
train score : 3.6925830252508267
val loss : 49.32630798195799
val score : 4.3630399109305245


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.313467396485315
train score : 3.8912133297444815
val loss : 48.47724827801188
val score : 4.656032206494269


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.58507723299166
train score : 3.7411777214155912
val loss : 48.39702646260461
val score : 4.166333724237911


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.31027691488466
train score : 3.6666605309623703
val loss : 48.67544659301638
val score : 3.986528753506731


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.3099692704156
train score : 3.6279331866626703
val loss : 48.54097755476833
val score : 4.787153003879107


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.310012039045496
train score : 3.620830156187211
val loss : 48.66535413016876
val score : 4.45453330758188


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.309930586380275
train score : 3.613077373196518
val loss : 48.513251292457184
val score : 5.714529909673718


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.709304646899305
train score : 3.581496364976533
val loss : 48.40950009773175
val score : 5.119262921056416


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.70922335358337
train score : 3.55757073107939
val loss : 48.395720827827844
val score : 4.101724616866026


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.709193899234144
train score : 3.561586663495407
val loss : 48.54413713614146
val score : 5.252907416792332


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.30856498411549
train score : 3.5415010242382086
val loss : 48.92328750838836
val score : 3.806234972408927


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.70853202845903
train score : 3.5235091482827325
val loss : 49.074270566304534
val score : 9.672450418989053


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.30780470768609
train score : 3.485950462495004
val loss : 48.5402763384084
val score : 4.5967539675372935


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 44.04197993138801
train score : 3.5042070689526827
val loss : 48.91310342152914
val score : 4.030261696014415


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.30753540455675
train score : 3.467956158305863
val loss : 48.778687266260384
val score : 4.077434280192075


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.30787148202459
train score : 3.483861247400561
val loss : 48.450974003970614
val score : 5.7995713173297725


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 44.0699468517676
train score : 3.497539893941026
val loss : 48.682593568166084
val score : 4.2836194292655705


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.30833401344719
train score : 3.498787030660281
val loss : 48.39579455504815
val score : 4.004208977645937


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.30797682361055
train score : 3.4743721377815677
val loss : 48.9610313055416
val score : 4.6106618903983465


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 44.10779004575064
train score : 3.456354005176077
val loss : 48.81871821706494
val score : 5.331797999792465


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.52750953845677
train score : 3.467948202168664
val loss : 48.620736367007105
val score : 8.736610601641008


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.30723582531016
train score : 3.4462577293427783
val loss : 48.78001430953542
val score : 4.002421294075261


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.305958708065255
train score : 3.362351784337725
val loss : 48.63935001989205
val score : 3.74855335176912


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.82057205208889
train score : 3.338106181261419
val loss : 48.67145128970344
val score : 3.6626764986411406


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.30528547919045
train score : 3.3283479809546996
val loss : 48.67213247319062
val score : 3.619627613249382


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.70547549352051
train score : 3.3267030924226604
val loss : 48.67127637589971
val score : 3.6355147235034746


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.92538409953313
train score : 3.3106317475735185
val loss : 48.5628308745722
val score : 5.995055507529848


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.52223059280464
train score : 3.306424256606733
val loss : 48.392544079075265
val score : 3.541611750261019


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.30471450136971
train score : 3.2957966368418057
val loss : 48.911308288574205
val score : 3.7626281249905906


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.704730488453045
train score : 3.284685510479833
val loss : 48.78141408960025
val score : 4.01409960264742


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.30486929255227
train score : 3.3009606866875685
val loss : 49.09095246742169
val score : 5.884629846003364


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.3043308010325
train score : 3.2736268449199892
val loss : 48.67149448394776
val score : 3.5774992042797678


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.70446270847071
train score : 3.2604378534776877
val loss : 48.449639402826634
val score : 7.4050655499549265


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.70431078442685
train score : 3.248500716163597
val loss : 48.455225435892736
val score : 4.154199028114734


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.303828453272594
train score : 3.2437197384491774
val loss : 48.84089868689577
val score : 3.5182803021416924


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.30385370406632
train score : 3.235658779211438
val loss : 48.99455054973563
val score : 3.8529517705843666


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.30385063985983
train score : 3.229235567296735
val loss : 48.532167003800474
val score : 3.823707029696635


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.30312027748053
train score : 3.204638858149737
val loss : 48.550829566270096
val score : 3.8205229641878904


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.30320284375919
train score : 3.215259965112834
val loss : 48.53553032552202
val score : 4.020144622510667


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 44.50369732823847
train score : 3.1882891846739168
val loss : 48.67195294151703
val score : 3.6098801452608478


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.30268425022561
train score : 3.180088167479218
val loss : 48.71391605672738
val score : 3.3651362989103375


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.302641019814956
train score : 3.1767211204936525
val loss : 48.63732264637947
val score : 3.36835777781196


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.7027005527479
train score : 3.163671509177391
val loss : 48.817252149681245
val score : 4.060915488973363


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.78814609491575
train score : 3.1520802294078902
val loss : 48.39270812471708
val score : 3.5600531567441327


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.70229987772183
train score : 3.1417795355646114
val loss : 48.62446768606702
val score : 4.198375684562353


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.56115765652309
train score : 3.1327828069820898
val loss : 48.673767863710715
val score : 3.647915959089617


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.70169500093908
train score : 3.1155432477189504
val loss : 48.808354618648686
val score : 3.814137866521812


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.4591607760017
train score : 3.1070565717353498
val loss : 48.8117772102356
val score : 3.4802073392946915


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.30127529073506
train score : 3.1019760571516977
val loss : 48.70740242004395
val score : 3.4577336092611746


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.7015098814542
train score : 3.0994398453589533
val loss : 53.73604420026143
val score : 24.316237452213425


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.7029184428975
train score : 3.1858187800094013
val loss : 48.81477409054836
val score : 3.538531981728984


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.300908985889215
train score : 3.0891644764392283
val loss : 48.64049986004829
val score : 3.7152824783401335


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.299495731076846
train score : 3.0176827319910147
val loss : 48.806675299505386
val score : 3.539542764767967


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.299476926618546
train score : 3.0050976208430997
val loss : 48.87643108367919
val score : 3.3619873972686167


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.69938115573799
train score : 2.9942402206044383
val loss : 48.52953760673602
val score : 3.3565524244687293


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.29862501484653
train score : 2.9761211079306937
val loss : 48.644667593638104
val score : 3.421686710299667


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.29838026845828
train score : 2.9664150350044154
val loss : 48.783141979575156
val score : 3.5713407141382656


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.29825711740802
train score : 2.9592452236867177
val loss : 48.531800784418984
val score : 3.804446267340487


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.298185770493006
train score : 2.950757621890601
val loss : 48.66992184494933
val score : 3.284704367050609


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.69787104247758
train score : 2.939625247682599
val loss : 48.82603296736876
val score : 3.6779894317478794


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.297531748004246
train score : 2.960478221508262
val loss : 48.746736437579
val score : 6.852711360903849


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.298914592682095
train score : 3.025088255580935
val loss : 77.86283156077067
val score : 262.1729762470328


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.63387934630737
train score : 2.9987587502821302
val loss : 48.389964492122324
val score : 3.44637950501381


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.29621965251864
train score : 2.950085208477835
val loss : 48.756282946219045
val score : 3.480941939645551


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.29476745150362
train score : 2.965054104509781
val loss : 48.48428768788774
val score : 3.516769187091597


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.293261709809315
train score : 2.9633754572851543
val loss : 48.71090365573764
val score : 6.40034228673816


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.694186917568224
train score : 3.025025930599239
val loss : 48.47618657772739
val score : 3.928790366897222


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.69877920340125
train score : 2.9890490811994876
val loss : 48.53176905314127
val score : 3.5062217379694127


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.297310991554205
train score : 2.929789729782114
val loss : 48.88958613822859
val score : 6.982411843118622


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.69710405388226
train score : 2.9329334889856415
val loss : 48.60706598162652
val score : 3.314105771264946


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 44.49618122062335
train score : 2.8811575275139387
val loss : 48.95163535773753
val score : 3.3912544251852865


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.47748680698375
train score : 2.869036686446871
val loss : 48.52949762195348
val score : 3.3091110184706083


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.69454505136235
train score : 2.930303421891292
val loss : 48.71320888549091
val score : 7.227674442796764


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.29301046840847
train score : 3.0280614179244423
val loss : 48.64538316726684
val score : 3.3839306919633354


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.28795664279412
train score : 2.943816822514183
val loss : 48.773082853853694
val score : 3.627166616378716


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.68633716333038
train score : 2.887434005263486
val loss : 48.88301839927833
val score : 3.3136016496164915


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.68391783336798
train score : 2.871880171412206
val loss : 48.658590385814506
val score : 3.3248319391586563


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train loss : 43.681762621862156
train score : 2.860145574176966
val loss : 48.801944291094934
val score : 3.35984383187818


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))

KeyboardInterrupt: 