In [1]:
import time
from tqdm import tqdm
import sys
import glob
import gc
import os
sys.path.append('./lib_models')

import pandas as pd
import numpy as np
import scipy as sp
import cv2
from PIL import Image
from matplotlib import pyplot as plt
import sklearn.metrics
import warnings
import pydicom
import dicomsdl
from joblib import Parallel, delayed
import pickle
import gzip
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from multiprocessing import Pool
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch import nn
from torchvision.io import read_image
import segmentation_models_pytorch as smp
import timm
from timm.utils import AverageMeter
from timm.models import resnet
import timm_new

from monai.transforms import Resize
import  monai.transforms as transforms

from timm.models.layers.conv2d_same import Conv2dSame
from conv3d_same import Conv3dSame


import wandb
sys.path.append('./lib_models')

wandb.login(key = '585f58f321685308f7933861d9dde7488de0970b')

  from .autonotebook import tqdm as notebook_tqdm
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjunseonglee[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/junseonglee/.netrc


True

# Parameters

In [2]:
backbone = 'timm/resnet10t.c3_in1k'

IS_WANDB = True
PROJECT_NAME = 'RSNA_ABTD'
GROUP_NAME= 'model_test'
RUN_NAME=   f'{backbone}_separated_ratio'

if not IS_WANDB:
    PROJECT_NAME = 'Dummy_Project'

BASE_PATH  = '/home/junseonglee/Desktop/01_codes/inputs/rsna-2023-abdominal-trauma-detection'
TRAIN_PATH = f'{BASE_PATH}/train_images'
DATA_PATH = f'{BASE_PATH}/3d_preprocessed'

seg_inference_dir = f'{BASE_PATH}/seg_infer_results'
cropped_img_dir   = f'{BASE_PATH}/3d_preprocessed_crop_ratio'

if not os.path.isdir(DATA_PATH):
    os.mkdir(DATA_PATH)

RESOL = 128
UP_RESOL = 128
N_CHANNELS = 6
BATCH_SIZE = 8
ACCUM_STEPS = 3
N_WORKERS  = 8
LR = 0.001
N_EPOCHS = 200
EARLY_STOP_COUNT = 30
N_FOLDS  = 5
N_PREPROCESS_CHUNKS = 12
n_blocks = 4
drop_rate = 0.2
drop_path_rate = 0.2
p_mixup = 0.0

DROP_REGION= {'HOLES': [3, 20],
                'SIZE': [5, 20],
                'PROB': 0.5,
                'FILL': (-3, 3)}

wandb_config = {
    'RESOL': RESOL,
    'BACKBONE': backbone,
    'N_CHANNELS': N_CHANNELS,
    'N_EPOCHS': N_EPOCHS,
    'N_FOLDS': N_FOLDS,
    'EARLY_STOP_COUNT': EARLY_STOP_COUNT,
    'BATCH_SIZE': BATCH_SIZE,    
    'LR': LR,
    'N_EPOCHS': N_EPOCHS,
    'DROP_RATE': drop_rate,
    'DROP_PATH_RATE': drop_path_rate,
    'MIXUP_RATE': p_mixup,
    'DROP_REGION': DROP_REGION
}

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

In [3]:
# Mask related parameters
# Order 0: Bowel, 1: left kidney, 2: right kidney, 3: liver, 4: spleen

chan_keys = ['bowel', 'left_kidney', 'right_kidney', 'liver', 'spleen', 'total']
chan_dict = {}
for i in range(0, 6):
    chan_dict[i] = chan_keys[i]

train_meta_df = pd.read_csv(f'{BASE_PATH}/train_meta.csv')
np.unique(train_meta_df['fold'].to_numpy(), return_counts = True)

(array([0, 1, 2, 3, 4]), array([929, 947, 948, 951, 936]))

In [4]:
def compress(name, data):
    with gzip.open(name, 'wb') as f:
        pickle.dump(data, f)

def decompress(name):
    with gzip.open(name, 'rb') as f:
        data = pickle.load(f)
    return data


def compress_fast(name, data):  
    np.save(name, data)

def decompress_fast(name):
    data = np.load(f'{name}.npy')
    return data

# Model

In [5]:
def convert_3d(module):

    module_output = module
    if isinstance(module, torch.nn.BatchNorm2d):
        module_output = torch.nn.BatchNorm3d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig
            
    elif isinstance(module, Conv2dSame):
        module_output = Conv3dSame(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))

    elif isinstance(module, torch.nn.Conv2d):
        module_output = torch.nn.Conv3d(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
            padding_mode=module.padding_mode
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))

    elif isinstance(module, torch.nn.MaxPool2d):
        module_output = torch.nn.MaxPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            ceil_mode=module.ceil_mode,
        )
    elif isinstance(module, torch.nn.AvgPool2d):
        module_output = torch.nn.AvgPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            ceil_mode=module.ceil_mode,
        )

    for name, child in module.named_children():
        module_output.add_module(
            name, convert_3d(child)
        )
    del module

    return module_output

In [6]:
class Timm3DModel(nn.Module):
    def __init__(self, backbone, n_channels, n_labels, segtype='unet', pretrained=False):
        super(Timm3DModel, self).__init__()
        self.n_labels = n_labels
        self.encoder = timm_new.create_model(
            backbone,
            in_chans=n_channels,
            features_only=True,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            pretrained=pretrained
        )
        g = self.encoder(torch.rand(1, n_channels, 64, 64))
        encoder_channels = [1] + [_.shape[1] for _ in g]
        decoder_channels = [256, 128, 64, 32, 16]

        self.avgpool = nn.AvgPool2d(5, 4, 2)
        
        [_.shape[1] for _ in g]
        self.convs1x1 = nn.ModuleList()    
        self.batchnorms = nn.ModuleList()    
        self.batchnorms13 = nn.ModuleList()
        for i in range(0, len(g)):
            self.convs1x1.append(nn.Conv2d(g[i].shape[1], self.n_labels, 1))
        del g
        gc.collect()
        
    def forward(self,x):
        batch_size = x.shape[0]
        global_features = self.encoder(x)[:n_blocks]        
        for i in range(0, len(global_features)):
            global_features[i] = self.convs1x1[i](global_features[i])
        return global_features
    
    
class Timm3DModelClassifier(nn.Module):
    def __init__(self, backbone, n_channels, n_labels, segtype='unet', pretrained=False):
        super(Timm3DModelClassifier, self).__init__()
        self.model_3d = Timm3DModel(backbone, n_channels, n_labels, segtype, pretrained)
        self.model_3d = convert_3d(self.model_3d)
        self.n_channels = n_channels
        self.n_labels = n_labels    
        
    def forward(self, x):
        batch_size = x.shape[0]
        x = self.model_3d(x)
        pooled_features = []
        for i in range(0, len(x)):
            pooled_features.append(torch.reshape(torch.mean(x[i], dim = (2, 3, 4)), (batch_size, self.n_labels, 1)))
        pooled_features = torch.cat(pooled_features, dim=2)     
        labels = torch.mean(pooled_features, dim = 2)
        return labels

In [7]:
class AbdominalClassifier(nn.Module):
    def __init__(self, device = DEVICE):
        super().__init__()
        self.device = device
        self.upsample = torch.nn.Upsample(size = [UP_RESOL, UP_RESOL, UP_RESOL])
        
        self.model3d_bowel        = Timm3DModelClassifier(backbone, 1, 2)      
        self.model3d_extrav       = Timm3DModelClassifier(backbone, 1, 2)
        self.model3d_kidney_left  = Timm3DModelClassifier(backbone, 1, 3)
        self.model3d_kidney_right = Timm3DModelClassifier(backbone, 1, 3)
        self.model3d_liver        = Timm3DModelClassifier(backbone, 1, 3)
        self.model3d_spleen       = Timm3DModelClassifier(backbone, 1, 3)
        
        self.flatten  = nn.Flatten()
        self.dropout  = nn.Dropout(p=0.5)
        self.softmax  = nn.Softmax(dim=1)        
        self.maxpool  = nn.MaxPool1d(5, 1)
        
    def forward(self, x_bowel, x_kidney_left, x_kidney_right, x_liver, x_spleen, x_total):
        bowel_label        = self.model3d_bowel(x_bowel)
        extrav_label       = self.model3d_extrav(x_total)
        kidney_label_left  = self.model3d_kidney_left(x_kidney_left)
        kidney_label_right = self.model3d_kidney_right(x_kidney_right)
        kidney_label       = (kidney_label_left + kidney_label_right)/2
        liver_label        = self.model3d_liver(x_liver)
        spleen_label       = self.model3d_spleen(x_spleen)
        
        
        labels = torch.cat([bowel_label, extrav_label, kidney_label, liver_label, spleen_label], dim = 1)
        
        bowel_soft = self.softmax(bowel_label)
        extrav_soft = self.softmax(extrav_label)
        kidney_soft = self.softmax(kidney_label)
        liver_soft = self.softmax(liver_label)
        spleen_soft = self.softmax(spleen_label)

        any_in = torch.cat([1-bowel_soft[:,0:1], 1-extrav_soft[:,0:1], 
                            1-kidney_soft[:,0:1], 1-liver_soft[:,0:1], 1-spleen_soft[:,0:1]], dim = 1) 
        any_in = self.maxpool(any_in)
        any_not_in = 1-any_in
        any_in = torch.cat([any_not_in, any_in], dim = 1)

        return labels, any_in
    

In [8]:
model = AbdominalClassifier()

def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp

print(get_n_params(model))
del model
gc.collect()

86478624


0

# Metric & Loss

In [9]:
weights = np.ones(2)
weights[1] = 2
crit_bowel  = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
weights[1] = 6
crit_extrav = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
crit_any = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))

weights = np.ones((3))
weights[1] = 2
weights[2] = 4
crit_kidney = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
crit_liver  = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
crit_spleen = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))

In [10]:
def normalize_to_one(tensor):
    norm = torch.sum(tensor, 1)
    for i in range(0, tensor.shape[1]):
        tensor[:,i]/=norm
    return tensor

def apply_softmax_to_labels(X_out):
    softmax = nn.Softmax(dim=1)

    X_out[:,:2]    = normalize_to_one(softmax(X_out[:,:2]))
    X_out[:,2:4]   = normalize_to_one(softmax(X_out[:,2:4]))
    X_out[:,4:7]   = normalize_to_one(softmax(X_out[:,4:7]))
    X_out[:,7:10]  = normalize_to_one(softmax(X_out[:,7:10]))
    X_out[:,10:13] = normalize_to_one(softmax(X_out[:,10:13]))

    return X_out

def calculate_score(X_outs, ys, step = 'train'):
    X_outs = X_outs.astype(np.float64)
    ys     = ys.astype(np.float64)

    bowel_weights  =  ys[:,0] + 2*ys[:,1]
    extrav_weights = ys[:,2] + 6*ys[:,3]
    kidney_weights = ys[:,4] + 2*ys[:,5] + 4*ys[:,6]
    liver_weights  = ys[:,7] + 2*ys[:,8] + 4*ys[:,9]
    spleen_weights = ys[:,10] + 2*ys[:,11] + 4*ys[:,12]
    any_in_weights = ys[:,13] + 6*ys[:,14]
    

    bowel_loss  = sklearn.metrics.log_loss(ys[:,:2], X_outs[:,:2], sample_weight = bowel_weights)
    extrav_loss = sklearn.metrics.log_loss(ys[:,2:4], X_outs[:,2:4], sample_weight = extrav_weights)
    kidney_loss = sklearn.metrics.log_loss(ys[:,4:7], X_outs[:,4:7], sample_weight = kidney_weights)
    liver_loss  = sklearn.metrics.log_loss(ys[:,7:10], X_outs[:,7:10], sample_weight = liver_weights)
    spleen_loss = sklearn.metrics.log_loss(ys[:,10:13], X_outs[:,10:13], sample_weight = spleen_weights)
    any_in_loss = sklearn.metrics.log_loss(ys[:,13:15], X_outs[:,13:15], sample_weight =  any_in_weights)
    
    avg_loss = (bowel_loss + extrav_loss + kidney_loss + liver_loss + spleen_loss + any_in_loss)/6

    losses= {f'{step}_bowel_metric': bowel_loss, f'{step}_extrav_metric': extrav_loss, f'{step}_kidney_metric': kidney_loss,
             f'{step}_liver_metric': liver_loss, f'{step}_spleen_metric': spleen_loss, f'{step}_any_in_metric': any_in_loss,
             f'{step}_avg_metric': avg_loss}

    wandb.log(losses)
    return avg_loss

def calculate_loss(X_out, X_any, y):
    batch_size = X_out.shape[0]
    bowel_loss  = crit_bowel(X_out[:,:2], y[:,:2])
    extrav_loss = crit_extrav(X_out[:,2:4], y[:,2:4])
    kidney_loss = crit_kidney(X_out[:,4:7], y[:,4:7])
    liver_loss  = crit_liver(X_out[:,7:10], y[:,7:10])
    spleen_loss = crit_spleen(X_out[:,10:13], y[:,10:13])
    any_in_loss = crit_any(X_any,  torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1))
    
    avg_loss = (bowel_loss + extrav_loss + kidney_loss + liver_loss + spleen_loss + any_in_loss)/6
    return bowel_loss, extrav_loss, kidney_loss, liver_loss, spleen_loss, any_in_loss, avg_loss

# Augmentations

In [11]:
def mixup(inputs, truth, clip=[0, 1]):
    indices = torch.randperm(inputs.size(0))
    shuffled_input = inputs[indices]
    shuffled_labels = truth[indices]

    lam = np.random.uniform(clip[0], clip[1])
    inputs = inputs * lam + shuffled_input * (1 - lam)
    return inputs, truth, shuffled_labels, lam

transforms_train = transforms.Compose([
    transforms.RandFlipd(keys=chan_keys, prob=0.5, spatial_axis=1),
    transforms.RandFlipd(keys=chan_keys, prob=0.5, spatial_axis=2),
    #transforms.RandAffined(keys=chan_keys, translate_range=[int(x*y) for x, y in zip([RESOL, RESOL, RESOL], [0.3, 0.3, 0.3])], padding_mode='zeros', prob=0.7),
    transforms.RandGridDistortiond(keys=chan_keys, prob=0.5, distort_limit=(-0.01, 0.01), mode="nearest"),    
])

remain_transforms_train = transforms.Compose([
    #transforms.RandCoarseDropout(holes = DROP_REGION['HOLES'][0], max_holes = DROP_REGION['HOLES'][1],
    #                        spatial_size = DROP_REGION['SIZE'][0]*np.ones(3, int), max_spatial_size =DROP_REGION['SIZE'][1]*np.ones(3, int), 
    #                        prob = DROP_REGION['PROB'], 
    #                        fill_value = DROP_REGION['FILL'])
])



transforms_common_preprocessing = transforms.Compose([
    #transforms.HistogramNormalize(num_bins = 256, min = 0, max = 255)
])

# Dataset

In [12]:
class AbdominalCTDataset(Dataset):
    def __init__(self, meta_df, is_train = True, transform_set = None, remain_transforms_set = None):
        self.meta_df = meta_df
        self.is_train = is_train
        self.transform_set = transform_set
        self.remain_transforms_set = remain_transforms_set
        self.data_3ds = []        
        for i in tqdm(range(0, len(self.meta_df))):
            tmp_data_3ds = {}
            base_name = self.meta_df.iloc[i]['cropped_path']            
            for j in range(0, 6):
                tmp_data_3d = decompress_fast(f'{base_name}_{j}')[None]
                tmp_data_3d = torch.from_numpy(tmp_data_3d)
                tmp_data_3ds[chan_dict[j]] = tmp_data_3d            
            self.data_3ds.append(tmp_data_3ds)

    def __len__(self):
        return len(self.meta_df)
    
    def __getitem__(self, idx):
        row = self.meta_df.iloc[idx]
        label = row[['bowel_healthy','bowel_injury',
                    'extravasation_healthy','extravasation_injury',
                    'kidney_healthy','kidney_low','kidney_high',
                    'liver_healthy','liver_low','liver_high',
                    'spleen_healthy','spleen_low','spleen_high', 'any_injury']]
        
        data_3d = self.data_3ds[idx].copy()
        
        if self.is_train:
            if self.transform_set is not None:
                data_3d = self.transform_set(data_3d)

            if self.remain_transforms_set is not None:   
                for i in range(0, 6):
                    data_3d[chan_dict[i]] = self.remain_transforms_set(data_3d[chan_dict[i]])
        
        label = label.to_numpy().astype(np.float32)                    
        label = torch.from_numpy(label)
                    
        return data_3d['bowel'], data_3d['left_kidney'], data_3d['right_kidney'], \
                data_3d['liver'], data_3d['spleen'], data_3d['total'], label        


In [13]:
#data_3d= torch.rand((6, 128, 128, 128))*0.5
#data_3d = remain_transforms_train(data_3d)
#torch.max(data_3d)
#print(data_3d)

# Train loop

In [14]:
def train_func(model, train_loader, scaler, scheduler, optimizer, epoch, accum_points, accum_scale):
    train_meters = {'loss': AverageMeter()}
    model.train()
    X_outs=[]
    ys=[]
    accum_counter = 0
    counter = 0
    for X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, y in train_loader:
        X_bowel, X_lkid, X_rkid = X_bowel.to(DEVICE), X_lkid.to(DEVICE), X_rkid.to(DEVICE)
        X_liv, X_spl, X_tot     = X_liv.to(DEVICE), X_spl.to(DEVICE), X_tot.to(DEVICE)
        y = y.to(DEVICE)
        current_lr = float(scheduler.get_last_lr()[0])
        wandb.log({})
        
        batch_size = X_bowel.shape[0]
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=True):  
            X_out, X_any  = model(X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot)
            do_mixup = False
            if np.random.random() < p_mixup:
                do_mixup = True
                X, y, labels_shuffled, lam = mixup(X, y)                
            
            bowel_loss, extrav_loss, kidney_loss, liver_loss, spleen_loss, any_in_loss, avg_loss = calculate_loss(X_out, X_any, y)
            if do_mixup:
                bowel_loss2, extrav_loss2, kidney_loss2, liver_loss2, spleen_loss2, any_in_loss2, avg_loss2 = calculate_loss(X_out, X_any, labels_shuffled)
                bowel_loss  = bowel_loss * lam  + bowel_loss2 * (1 - lam)
                extrav_loss = extrav_loss * lam  + extrav_loss2 * (1 - lam)
                kidney_loss = kidney_loss * lam  + kidney_loss2 * (1 - lam)         
                liver_loss  = liver_loss * lam  + liver_loss2 * (1 - lam) 
                spleen_loss = spleen_loss * lam  + spleen_loss2 * (1 - lam) 
                any_in_loss = any_in_loss * lam  + any_in_loss2 * (1 - lam) 
                avg_loss = avg_loss * lam  + avg_loss2 * (1 - lam)       
                
            step = 'train'
            wandb.log({ 'lr': current_lr,
                        f'{step}_bowel_loss': bowel_loss.item(),
                        f'{step}_extrav_loss': extrav_loss.item(),
                        f'{step}_kidney_loss': kidney_loss.item(),
                        f'{step}_liver_loss': liver_loss.item(),
                        f'{step}_spleen_loss': spleen_loss.item(),
                        f'{step}_any_loss': any_in_loss.item(),
                        f'{step}_avg_loss': avg_loss.item()
                        })
            
            scaler.scale(avg_loss/accum_scale[accum_counter]).backward()
            if(counter==accum_points[accum_counter]):
                scaler.step(optimizer)
                scheduler.step()
                scaler.update()    
                accum_counter+=1                
        counter+=1                   

        #Metric calculation
        y_any = torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1)    
        X_out = apply_softmax_to_labels(X_out).detach().to('cpu').numpy()
        X_any = X_any.detach().to('cpu').numpy()
        X_out = np.hstack([X_out, X_any])
        X_outs.append(X_out)

        y     = y.to('cpu').numpy()[:,:-1]
        y_any = y_any.to('cpu').numpy()
        y     = np.hstack([y, y_any])
        ys.append(y)

        trn_loss = avg_loss.item()      
        train_meters['loss'].update(trn_loss, n=batch_size)     
        #pbar.set_description(f'Train loss: {trn_loss}')   
        
        
    print('Epoch {:d} / trn/loss={:.4f}'.format(epoch+1, train_meters['loss'].avg))    

    X_outs = np.vstack(X_outs) 
    ys     = np.vstack(ys)
    metric = calculate_score(X_outs, ys, 'train')                 
    print('Epoch {:d} / train/metric={:.4f}'.format(epoch+1, metric))   

    del X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, X_outs, y, ys, X_any
    gc.collect()
    torch.cuda.empty_cache()    
    return scheduler, scaler, optimizer


def valid_func(model, valid_loader):
    X_outs=[]
    ys=[]
    model.eval()
    for X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, y in valid_loader:
        batch_size = y.shape[0]
        X_bowel, X_lkid, X_rkid = X_bowel.to(DEVICE), X_lkid.to(DEVICE), X_rkid.to(DEVICE)
        X_liv, X_spl, X_tot     = X_liv.to(DEVICE), X_spl.to(DEVICE), X_tot.to(DEVICE)
        y = y.to(DEVICE)           
        with torch.cuda.amp.autocast(enabled=True):                
            with torch.no_grad():                 
                X_out, X_any = model(X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot)                                          
                y_any = torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1)              
                X_out = apply_softmax_to_labels(X_out).to('cpu').numpy()

                X_any = X_any.to('cpu').numpy()
                X_out = np.hstack([X_out, X_any])
                X_outs.append(X_out)

                y     = y.to('cpu').numpy()[:,:-1]
                y_any = y_any.to('cpu').numpy()
                y     = np.hstack([y, y_any])
                ys.append(y)

    X_outs = np.vstack(X_outs) 
    ys     = np.vstack(ys)
    metric = calculate_score(X_outs, ys, 'valid')                
    print('Epoch {:d} / val/metric={:.4f}'.format(epoch+1, metric))           
    
    del X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, X_outs, y, ys, X_any
    gc.collect()        
    torch.cuda.empty_cache()   
    return metric 

In [15]:
model = AbdominalClassifier()
model.to(DEVICE)

wandb.init(
    config = wandb_config,
    project= PROJECT_NAME,
    group  = GROUP_NAME,
    name   = RUN_NAME,
    dir    = BASE_PATH)

backbone = backbone.replace('/', '_')

if __name__ == '__main__':
    train_dataset = AbdominalCTDataset(train_meta_df[train_meta_df['fold']!=0], is_train = True, transform_set  = transforms_train, 
                                        remain_transforms_set = remain_transforms_train)
    valid_dataset = AbdominalCTDataset(train_meta_df[train_meta_df['fold']==0], is_train = False, transform_set = None,
                                        remain_transforms_set = None)        
    
    train_loader = DataLoader(dataset = train_dataset, shuffle = True, batch_size = BATCH_SIZE, pin_memory = False, 
                            num_workers = N_WORKERS, drop_last = False)

    valid_loader = DataLoader(dataset = valid_dataset, shuffle = False, batch_size = BATCH_SIZE, pin_memory = False, 
                            num_workers = N_WORKERS, drop_last = False)     
    
    ttl_iters = N_EPOCHS * len(train_loader)
    
    #gradient accumulation for stability of the training
    accum_len = int(np.ceil(len(train_loader)/ACCUM_STEPS)+0.001)
    accum_points = np.zeros(accum_len, int)
    accum_scale  = np.zeros(accum_len, int)
    
    prev_step = -1
    for i in range(0, accum_len):
        accum_points[i] = min(prev_step+ACCUM_STEPS, len(train_loader)-1)
        accum_scale[i]  = accum_points[i] - prev_step
        prev_step = accum_points[i]

    #Scheduler & optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr = LR)
    n_batch_iters = int(np.ceil(len(train_loader)/ACCUM_STEPS)+0.001)
    #scheduler = CosineAnnealingLR(optimizer, T_max=ttl_iters, eta_min=1e-6)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LR, 
                                                    steps_per_epoch= n_batch_iters, epochs = N_EPOCHS)

    scaler = torch.cuda.amp.GradScaler(enabled=True)
    val_metrics = np.ones(N_EPOCHS)*100

    gc.collect()

    for epoch in tqdm(range(0, N_EPOCHS), leave = False):     

        scheduler, scaler, optimizer = train_func(model, train_loader, scaler, scheduler, optimizer, epoch, accum_points, accum_scale)
        metric                       = valid_func(model, valid_loader)
        
        #Save the best model    
        if(metric < np.min(val_metrics)):
            try:
                os.makedirs(f'{BASE_PATH}/weights')
            except:
                a = 1
            best_metric = metric
            print(f'Best val_metric {best_metric} at epoch {epoch+1}!')
            torch.save(model, f'{BASE_PATH}/weights/{backbone}_lr{LR}_epochs_{N_EPOCHS}_resol{UP_RESOL}_batch{BATCH_SIZE}.pt')    
            not_improve_counter=0
            val_metrics[epoch] = metric
            continue                    
        val_metrics[epoch] = metric                        
        
        #Early stopping
        not_improve_counter+=1
        if(not_improve_counter == EARLY_STOP_COUNT):
            print(f'Not improved for {not_improve_counter} epochs, terminate the train')
            break
wandb.log({'best_total_log_loss': best_metric})
wandb.finish()

100%|██████████| 3782/3782 [01:38<00:00, 38.25it/s]
100%|██████████| 929/929 [00:35<00:00, 26.48it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch 1 / trn/loss=1.0044
Epoch 1 / train/metric=0.7420
Epoch 1 / val/metric=0.6674
Best val_metric 0.6673919152737042 at epoch 1!


  0%|          | 1/200 [05:01<16:39:25, 301.34s/it]

Epoch 2 / trn/loss=0.8968
Epoch 2 / train/metric=0.6445
Epoch 2 / val/metric=0.6235
Best val_metric 0.6234755591239058 at epoch 2!


  1%|          | 2/200 [09:52<16:15:00, 295.46s/it]

Epoch 3 / trn/loss=0.8667
Epoch 3 / train/metric=0.6177
Epoch 3 / val/metric=0.6087
Best val_metric 0.6086653604035522 at epoch 3!


  2%|▏         | 3/200 [14:45<16:06:18, 294.31s/it]

Epoch 4 / trn/loss=0.8567
Epoch 4 / train/metric=0.6089
Epoch 4 / val/metric=0.6014
Best val_metric 0.6013808430036833 at epoch 4!


  2%|▏         | 4/200 [19:38<16:00:02, 293.89s/it]

Epoch 5 / trn/loss=0.8490
Epoch 5 / train/metric=0.6026
Epoch 5 / val/metric=0.5956
Best val_metric 0.5956333367306278 at epoch 5!


  2%|▎         | 5/200 [24:32<15:55:18, 293.94s/it]

Epoch 6 / trn/loss=0.8437
Epoch 6 / train/metric=0.5987
Epoch 6 / val/metric=0.5905
Best val_metric 0.5905096442998758 at epoch 6!


  3%|▎         | 6/200 [29:26<15:50:08, 293.86s/it]

Epoch 7 / trn/loss=0.8379
Epoch 7 / train/metric=0.5942
Epoch 7 / val/metric=0.5773
Best val_metric 0.5773230003639047 at epoch 7!


  4%|▎         | 7/200 [34:18<15:43:05, 293.19s/it]

Epoch 8 / trn/loss=0.8211
Epoch 8 / train/metric=0.5790
Epoch 8 / val/metric=0.5595
Best val_metric 0.5594779906844848 at epoch 8!


  4%|▍         | 8/200 [39:12<15:38:47, 293.37s/it]

Epoch 9 / trn/loss=0.8105
Epoch 9 / train/metric=0.5690
Epoch 9 / val/metric=0.5555
Best val_metric 0.5555250599633342 at epoch 9!


  4%|▍         | 9/200 [44:06<15:34:36, 293.59s/it]

Epoch 10 / trn/loss=0.8076
Epoch 10 / train/metric=0.5669
Epoch 10 / val/metric=0.5510
Best val_metric 0.5509811448537765 at epoch 10!


  5%|▌         | 10/200 [48:59<15:29:18, 293.47s/it]

Epoch 11 / trn/loss=0.8054
Epoch 11 / train/metric=0.5661
Epoch 11 / val/metric=0.5502
Best val_metric 0.5501604041681606 at epoch 11!


  6%|▌         | 11/200 [53:54<15:25:56, 293.95s/it]

Epoch 12 / trn/loss=0.8033
Epoch 12 / train/metric=0.5633
Epoch 12 / val/metric=0.5456
Best val_metric 0.545578822268815 at epoch 12!


  6%|▌         | 12/200 [58:48<15:21:28, 294.09s/it]

Epoch 13 / trn/loss=0.8034
Epoch 13 / train/metric=0.5643


  6%|▋         | 13/200 [1:03:41<15:14:53, 293.55s/it]

Epoch 13 / val/metric=0.5488
Epoch 14 / trn/loss=0.8047
Epoch 14 / train/metric=0.5656


  7%|▋         | 14/200 [1:08:35<15:10:26, 293.69s/it]

Epoch 14 / val/metric=0.5462
Epoch 15 / trn/loss=0.7987
Epoch 15 / train/metric=0.5603
Epoch 15 / val/metric=0.5444
Best val_metric 0.5443859571738096 at epoch 15!


  8%|▊         | 15/200 [1:13:30<15:06:50, 294.11s/it]

Epoch 16 / trn/loss=0.8004
Epoch 16 / train/metric=0.5609


  8%|▊         | 16/200 [1:18:24<15:02:00, 294.13s/it]

Epoch 16 / val/metric=0.5447
Epoch 17 / trn/loss=0.7988
Epoch 17 / train/metric=0.5598
Epoch 17 / val/metric=0.5407
Best val_metric 0.5407216195673618 at epoch 17!


  8%|▊         | 17/200 [1:23:19<14:57:53, 294.39s/it]

Epoch 18 / trn/loss=0.7954
Epoch 18 / train/metric=0.5581


  9%|▉         | 18/200 [1:28:13<14:52:31, 294.24s/it]

Epoch 18 / val/metric=0.5505
Epoch 19 / trn/loss=0.7967
Epoch 19 / train/metric=0.5589
Epoch 19 / val/metric=0.5394
Best val_metric 0.5393605096136288 at epoch 19!


 10%|▉         | 19/200 [1:33:07<14:47:43, 294.27s/it]

Epoch 20 / trn/loss=0.7968
Epoch 20 / train/metric=0.5589
Epoch 20 / val/metric=0.5381
Best val_metric 0.5380551845802015 at epoch 20!


 10%|█         | 20/200 [1:38:01<14:42:47, 294.27s/it]

Epoch 21 / trn/loss=0.7953
Epoch 21 / train/metric=0.5593
Epoch 21 / val/metric=0.5365
Best val_metric 0.5364786954691629 at epoch 21!


 10%|█         | 21/200 [1:42:54<14:36:24, 293.77s/it]

Epoch 22 / trn/loss=0.7967
Epoch 22 / train/metric=0.5594


 11%|█         | 22/200 [1:47:48<14:32:01, 293.94s/it]

Epoch 22 / val/metric=0.5533
Epoch 23 / trn/loss=0.7990
Epoch 23 / train/metric=0.5612


 12%|█▏        | 23/200 [1:52:43<14:27:32, 294.08s/it]

Epoch 23 / val/metric=0.5410
Epoch 24 / trn/loss=0.7960
Epoch 24 / train/metric=0.5579
Epoch 24 / val/metric=0.5357
Best val_metric 0.5356680835732083 at epoch 24!


 12%|█▏        | 24/200 [1:57:37<14:23:00, 294.21s/it]

Epoch 25 / trn/loss=0.7934
Epoch 25 / train/metric=0.5552


 12%|█▎        | 25/200 [2:02:31<14:17:55, 294.15s/it]

Epoch 25 / val/metric=0.5443
Epoch 26 / trn/loss=0.7897
Epoch 26 / train/metric=0.5531


 13%|█▎        | 26/200 [2:07:26<14:13:54, 294.45s/it]

Epoch 26 / val/metric=0.5417
Epoch 27 / trn/loss=0.7926
Epoch 27 / train/metric=0.5563


 14%|█▎        | 27/200 [2:12:21<14:09:14, 294.53s/it]

Epoch 27 / val/metric=0.5511
Epoch 28 / trn/loss=0.7917
Epoch 28 / train/metric=0.5550


 14%|█▍        | 28/200 [2:17:16<14:04:25, 294.57s/it]

Epoch 28 / val/metric=0.5376
Epoch 29 / trn/loss=0.7889
Epoch 29 / train/metric=0.5517


 14%|█▍        | 29/200 [2:22:09<13:58:40, 294.27s/it]

Epoch 29 / val/metric=0.5447
Epoch 30 / trn/loss=0.7881
Epoch 30 / train/metric=0.5522


 15%|█▌        | 30/200 [2:27:04<13:53:58, 294.35s/it]

Epoch 30 / val/metric=0.5495
Epoch 31 / trn/loss=0.7887
Epoch 31 / train/metric=0.5516


 16%|█▌        | 31/200 [2:31:58<13:48:51, 294.27s/it]

Epoch 31 / val/metric=0.5517
Epoch 32 / trn/loss=0.7968
Epoch 32 / train/metric=0.5604


 16%|█▌        | 32/200 [2:36:52<13:43:29, 294.10s/it]

Epoch 32 / val/metric=0.5428
Epoch 33 / trn/loss=0.7928
Epoch 33 / train/metric=0.5563


 16%|█▋        | 33/200 [2:41:46<13:38:30, 294.07s/it]

Epoch 33 / val/metric=0.5402
Epoch 34 / trn/loss=0.7910
Epoch 34 / train/metric=0.5542


 17%|█▋        | 34/200 [2:46:40<13:33:41, 294.10s/it]

Epoch 34 / val/metric=0.5378
Epoch 35 / trn/loss=0.7866
Epoch 35 / train/metric=0.5494
Epoch 35 / val/metric=0.5347
Best val_metric 0.5347487492166493 at epoch 35!


 18%|█▊        | 35/200 [2:51:35<13:29:24, 294.33s/it]

Epoch 36 / trn/loss=0.7910
Epoch 36 / train/metric=0.5529
Epoch 36 / val/metric=0.5268
Best val_metric 0.5267589615024719 at epoch 36!


 18%|█▊        | 36/200 [2:56:29<13:24:07, 294.19s/it]

Epoch 37 / trn/loss=0.7865
Epoch 37 / train/metric=0.5497


 18%|█▊        | 37/200 [3:01:23<13:19:05, 294.14s/it]

Epoch 37 / val/metric=0.5438
Epoch 38 / trn/loss=0.7883
Epoch 38 / train/metric=0.5521


 19%|█▉        | 38/200 [3:06:18<13:14:49, 294.38s/it]

Epoch 38 / val/metric=0.5457
Epoch 39 / trn/loss=0.7849
Epoch 39 / train/metric=0.5490


 20%|█▉        | 39/200 [3:11:12<13:09:40, 294.29s/it]

Epoch 39 / val/metric=0.5357
Epoch 40 / trn/loss=0.7857
Epoch 40 / train/metric=0.5512


 20%|██        | 40/200 [3:16:06<13:04:37, 294.23s/it]

Epoch 40 / val/metric=0.5394
Epoch 41 / trn/loss=0.7817
Epoch 41 / train/metric=0.5475


 20%|██        | 41/200 [3:21:01<13:00:08, 294.40s/it]

Epoch 41 / val/metric=0.5573
Epoch 42 / trn/loss=0.7875
Epoch 42 / train/metric=0.5530


 21%|██        | 42/200 [3:25:54<12:54:23, 294.07s/it]

Epoch 42 / val/metric=0.5504
Epoch 43 / trn/loss=0.7850
Epoch 43 / train/metric=0.5499


 22%|██▏       | 43/200 [3:30:48<12:49:18, 294.00s/it]

Epoch 43 / val/metric=0.5288
Epoch 44 / trn/loss=0.7835
Epoch 44 / train/metric=0.5486


 22%|██▏       | 44/200 [3:35:42<12:44:29, 294.03s/it]

Epoch 44 / val/metric=0.5366
Epoch 45 / trn/loss=0.7831
Epoch 45 / train/metric=0.5476


 22%|██▎       | 45/200 [3:40:36<12:39:36, 294.04s/it]

Epoch 45 / val/metric=0.5391
Epoch 46 / trn/loss=0.7858
Epoch 46 / train/metric=0.5504


 23%|██▎       | 46/200 [3:45:30<12:34:38, 294.01s/it]

Epoch 46 / val/metric=0.5318
Epoch 47 / trn/loss=0.7791
Epoch 47 / train/metric=0.5461
Epoch 47 / val/metric=0.5193
Best val_metric 0.5193423416192084 at epoch 47!


 24%|██▎       | 47/200 [3:50:25<12:30:31, 294.32s/it]

Epoch 48 / trn/loss=0.7797
Epoch 48 / train/metric=0.5463


 24%|██▍       | 48/200 [3:55:18<12:24:35, 293.92s/it]

Epoch 48 / val/metric=0.5248
Epoch 49 / trn/loss=0.7742
Epoch 49 / train/metric=0.5406


 24%|██▍       | 49/200 [4:00:12<12:19:36, 293.88s/it]

Epoch 49 / val/metric=0.5369
Epoch 50 / trn/loss=0.7796
Epoch 50 / train/metric=0.5456


 25%|██▌       | 50/200 [4:05:06<12:15:10, 294.07s/it]

Epoch 50 / val/metric=0.5325
Epoch 51 / trn/loss=0.7748
Epoch 51 / train/metric=0.5407


 26%|██▌       | 51/200 [4:10:01<12:10:59, 294.36s/it]

Epoch 51 / val/metric=0.5252
Epoch 52 / trn/loss=0.7705
Epoch 52 / train/metric=0.5389


 26%|██▌       | 52/200 [4:14:54<12:05:15, 294.02s/it]

Epoch 52 / val/metric=0.5511
Epoch 53 / trn/loss=0.7759
Epoch 53 / train/metric=0.5434


 26%|██▋       | 53/200 [4:19:48<12:00:16, 293.99s/it]

Epoch 53 / val/metric=0.5258
Epoch 54 / trn/loss=0.7710
Epoch 54 / train/metric=0.5381


 27%|██▋       | 54/200 [4:24:42<11:54:58, 293.83s/it]

Epoch 54 / val/metric=0.5230
Epoch 55 / trn/loss=0.7701
Epoch 55 / train/metric=0.5398


 28%|██▊       | 55/200 [4:29:37<11:50:44, 294.10s/it]

Epoch 55 / val/metric=0.5221
Epoch 56 / trn/loss=0.7709
Epoch 56 / train/metric=0.5383


 28%|██▊       | 56/200 [4:34:32<11:46:30, 294.38s/it]

Epoch 56 / val/metric=0.5319
Epoch 57 / trn/loss=0.7661
Epoch 57 / train/metric=0.5350
Epoch 57 / val/metric=0.5173
Best val_metric 0.5173366571998507 at epoch 57!


 28%|██▊       | 57/200 [4:39:26<11:41:19, 294.26s/it]

Epoch 58 / trn/loss=0.7617
Epoch 58 / train/metric=0.5293


 29%|██▉       | 58/200 [4:44:20<11:36:51, 294.44s/it]

Epoch 58 / val/metric=0.5310
Epoch 59 / trn/loss=0.7577
Epoch 59 / train/metric=0.5276


 30%|██▉       | 59/200 [4:49:14<11:31:28, 294.25s/it]

Epoch 59 / val/metric=0.5177
Epoch 60 / trn/loss=0.7647
Epoch 60 / train/metric=0.5317


 30%|███       | 60/200 [4:54:08<11:26:20, 294.14s/it]

Epoch 60 / val/metric=0.5303
Epoch 61 / trn/loss=0.7596
Epoch 61 / train/metric=0.5281


 30%|███       | 61/200 [4:59:02<11:21:20, 294.10s/it]

Epoch 61 / val/metric=0.5204
Epoch 62 / trn/loss=0.7590
Epoch 62 / train/metric=0.5282


 31%|███       | 62/200 [5:03:56<11:16:32, 294.15s/it]

Epoch 62 / val/metric=0.5371
Epoch 63 / trn/loss=0.7592
Epoch 63 / train/metric=0.5297
Epoch 63 / val/metric=0.5099
Best val_metric 0.5098514817845469 at epoch 63!


 32%|███▏      | 63/200 [5:08:51<11:11:49, 294.23s/it]

Epoch 64 / trn/loss=0.7474
Epoch 64 / train/metric=0.5193


 32%|███▏      | 64/200 [5:13:45<11:06:54, 294.23s/it]

Epoch 64 / val/metric=0.5414
Epoch 65 / trn/loss=0.7554
Epoch 65 / train/metric=0.5258
Epoch 65 / val/metric=0.5076
Best val_metric 0.5075754292341164 at epoch 65!


 32%|███▎      | 65/200 [5:18:40<11:02:29, 294.44s/it]

Epoch 66 / trn/loss=0.7503
Epoch 66 / train/metric=0.5223
Epoch 66 / val/metric=0.5001
Best val_metric 0.5000953818585667 at epoch 66!


 33%|███▎      | 66/200 [5:23:34<10:57:31, 294.42s/it]

Epoch 67 / trn/loss=0.7506
Epoch 67 / train/metric=0.5224


 34%|███▎      | 67/200 [5:28:28<10:52:16, 294.26s/it]

Epoch 67 / val/metric=0.5205
Epoch 68 / trn/loss=0.7525
Epoch 68 / train/metric=0.5231


 34%|███▍      | 68/200 [5:33:21<10:46:43, 293.96s/it]

Epoch 68 / val/metric=0.5056
Epoch 69 / trn/loss=0.7430
Epoch 69 / train/metric=0.5152


 34%|███▍      | 69/200 [5:38:16<10:42:04, 294.08s/it]

Epoch 69 / val/metric=0.5306
Epoch 70 / trn/loss=0.7350
Epoch 70 / train/metric=0.5093


 35%|███▌      | 70/200 [5:43:11<10:37:55, 294.42s/it]

Epoch 70 / val/metric=0.5531
Epoch 71 / trn/loss=0.7433
Epoch 71 / train/metric=0.5156


 36%|███▌      | 71/200 [5:48:05<10:32:49, 294.34s/it]

Epoch 71 / val/metric=0.5021
Epoch 72 / trn/loss=0.7371
Epoch 72 / train/metric=0.5112


 36%|███▌      | 72/200 [5:53:00<10:28:02, 294.39s/it]

Epoch 72 / val/metric=0.5181
Epoch 73 / trn/loss=0.7391
Epoch 73 / train/metric=0.5130


 36%|███▋      | 73/200 [5:57:54<10:23:02, 294.35s/it]

Epoch 73 / val/metric=0.5023
Epoch 74 / trn/loss=0.7417
Epoch 74 / train/metric=0.5139
Epoch 74 / val/metric=0.4946
Best val_metric 0.4946221208407218 at epoch 74!


 37%|███▋      | 74/200 [6:02:49<10:18:32, 294.55s/it]

Epoch 75 / trn/loss=0.7384
Epoch 75 / train/metric=0.5120


 38%|███▊      | 75/200 [6:07:43<10:13:29, 294.47s/it]

Epoch 75 / val/metric=0.5022
Epoch 76 / trn/loss=0.7321
Epoch 76 / train/metric=0.5068


 38%|███▊      | 76/200 [6:12:37<10:08:18, 294.34s/it]

Epoch 76 / val/metric=0.5037
Epoch 77 / trn/loss=0.7321
Epoch 77 / train/metric=0.5064


 38%|███▊      | 77/200 [6:17:31<10:03:12, 294.25s/it]

Epoch 77 / val/metric=0.5053
Epoch 78 / trn/loss=0.7353
Epoch 78 / train/metric=0.5100


 39%|███▉      | 78/200 [6:22:25<9:57:52, 294.03s/it] 

Epoch 78 / val/metric=0.5220
Epoch 79 / trn/loss=0.7266
Epoch 79 / train/metric=0.5028


 40%|███▉      | 79/200 [6:27:18<9:52:41, 293.89s/it]

Epoch 79 / val/metric=0.5072
Epoch 80 / trn/loss=0.7280
Epoch 80 / train/metric=0.5032
Epoch 80 / val/metric=0.4874
Best val_metric 0.48744185415697777 at epoch 80!


 40%|████      | 80/200 [6:32:13<9:48:16, 294.13s/it]

Epoch 81 / trn/loss=0.7240
Epoch 81 / train/metric=0.5002


 40%|████      | 81/200 [6:37:06<9:42:42, 293.81s/it]

Epoch 81 / val/metric=0.4978
Epoch 82 / trn/loss=0.7280
Epoch 82 / train/metric=0.5046


 41%|████      | 82/200 [6:41:58<9:36:22, 293.07s/it]

Epoch 82 / val/metric=0.5142
Epoch 83 / trn/loss=0.7249
Epoch 83 / train/metric=0.5015


 42%|████▏     | 83/200 [6:46:52<9:31:59, 293.33s/it]

Epoch 83 / val/metric=0.5058
Epoch 84 / trn/loss=0.7205
Epoch 84 / train/metric=0.4980


 42%|████▏     | 84/200 [6:51:45<9:27:19, 293.44s/it]

Epoch 84 / val/metric=0.5034
Epoch 85 / trn/loss=0.7260
Epoch 85 / train/metric=0.5017


 42%|████▎     | 85/200 [6:56:39<9:22:39, 293.56s/it]

Epoch 85 / val/metric=0.5023
Epoch 86 / trn/loss=0.7163
Epoch 86 / train/metric=0.4935


 43%|████▎     | 86/200 [7:01:34<9:18:21, 293.88s/it]

Epoch 86 / val/metric=0.5184
Epoch 87 / trn/loss=0.7214
Epoch 87 / train/metric=0.5003


 44%|████▎     | 87/200 [7:06:28<9:13:28, 293.88s/it]

Epoch 87 / val/metric=0.5122
Epoch 88 / trn/loss=0.7173
Epoch 88 / train/metric=0.4946


 44%|████▍     | 88/200 [7:11:21<9:08:34, 293.88s/it]

Epoch 88 / val/metric=0.5065
Epoch 89 / trn/loss=0.7142
Epoch 89 / train/metric=0.4915


 44%|████▍     | 89/200 [7:16:15<9:03:39, 293.87s/it]

Epoch 89 / val/metric=0.5011
Epoch 90 / trn/loss=0.7118
Epoch 90 / train/metric=0.4898


 45%|████▌     | 90/200 [7:21:10<8:59:02, 294.03s/it]

Epoch 90 / val/metric=0.5132
Epoch 91 / trn/loss=0.7156
Epoch 91 / train/metric=0.4944


 46%|████▌     | 91/200 [7:26:03<8:54:00, 293.95s/it]

Epoch 91 / val/metric=0.5000
Epoch 92 / trn/loss=0.7056
Epoch 92 / train/metric=0.4853


 46%|████▌     | 92/200 [7:30:57<8:48:47, 293.78s/it]

Epoch 92 / val/metric=0.5104
Epoch 93 / trn/loss=0.7043
Epoch 93 / train/metric=0.4853


 46%|████▋     | 93/200 [7:35:49<8:42:57, 293.25s/it]

Epoch 93 / val/metric=0.5117
Epoch 94 / trn/loss=0.7141
Epoch 94 / train/metric=0.4935


 47%|████▋     | 94/200 [7:40:41<8:37:21, 292.85s/it]

Epoch 94 / val/metric=0.5015
Epoch 95 / trn/loss=0.7123
Epoch 95 / train/metric=0.4912


 48%|████▊     | 95/200 [7:45:33<8:32:07, 292.64s/it]

Epoch 95 / val/metric=0.4960
Epoch 96 / trn/loss=0.7071
Epoch 96 / train/metric=0.4867


 48%|████▊     | 96/200 [7:50:25<8:26:55, 292.45s/it]

Epoch 96 / val/metric=0.4961
Epoch 97 / trn/loss=0.6996
Epoch 97 / train/metric=0.4810


 48%|████▊     | 97/200 [7:55:19<8:22:57, 292.98s/it]

Epoch 97 / val/metric=0.5011
Epoch 98 / trn/loss=0.7052
Epoch 98 / train/metric=0.4856


 49%|████▉     | 98/200 [8:00:13<8:18:39, 293.32s/it]

Epoch 98 / val/metric=0.5127
Epoch 99 / trn/loss=0.7048
Epoch 99 / train/metric=0.4866


 50%|████▉     | 99/200 [8:05:05<8:12:58, 292.85s/it]

Epoch 99 / val/metric=0.5009
Epoch 100 / trn/loss=0.6921
Epoch 100 / train/metric=0.4743


wandb: Network error (ReadTimeout), entering retry loop.


Epoch 100 / val/metric=0.4834
Best val_metric 0.48337520229861686 at epoch 100!


 50%|█████     | 100/200 [8:09:59<8:08:41, 293.21s/it]

Epoch 101 / trn/loss=0.6993
Epoch 101 / train/metric=0.4813


 50%|█████     | 101/200 [8:14:54<8:04:35, 293.69s/it]

Epoch 101 / val/metric=0.5159
Epoch 102 / trn/loss=0.6981
Epoch 102 / train/metric=0.4786


 51%|█████     | 102/200 [8:19:48<7:59:44, 293.72s/it]

Epoch 102 / val/metric=0.4935
Epoch 103 / trn/loss=0.7042
Epoch 103 / train/metric=0.4848


 52%|█████▏    | 103/200 [8:24:42<7:55:09, 293.91s/it]

Epoch 103 / val/metric=0.4992
Epoch 104 / trn/loss=0.6902
Epoch 104 / train/metric=0.4725


 52%|█████▏    | 104/200 [8:29:37<7:50:49, 294.26s/it]

Epoch 104 / val/metric=0.4945
Epoch 105 / trn/loss=0.6842
Epoch 105 / train/metric=0.4664


 52%|█████▎    | 105/200 [8:34:32<7:46:11, 294.43s/it]

Epoch 105 / val/metric=0.5012
Epoch 106 / trn/loss=0.6878
Epoch 106 / train/metric=0.4698


 53%|█████▎    | 106/200 [8:39:26<7:41:11, 294.38s/it]

Epoch 106 / val/metric=0.4867
Epoch 107 / trn/loss=0.6845
Epoch 107 / train/metric=0.4671


 54%|█████▎    | 107/200 [8:44:18<7:34:54, 293.49s/it]

Epoch 107 / val/metric=0.4992
Epoch 108 / trn/loss=0.6844
Epoch 108 / train/metric=0.4679


 54%|█████▍    | 108/200 [8:49:12<7:30:17, 293.67s/it]

Epoch 108 / val/metric=0.5027
Epoch 109 / trn/loss=0.6796
Epoch 109 / train/metric=0.4638


 55%|█████▍    | 109/200 [8:54:06<7:25:49, 293.95s/it]

Epoch 109 / val/metric=0.5102
Epoch 110 / trn/loss=0.6843
Epoch 110 / train/metric=0.4692


 55%|█████▌    | 110/200 [8:59:02<7:21:33, 294.37s/it]

Epoch 110 / val/metric=0.4994
Epoch 111 / trn/loss=0.6832
Epoch 111 / train/metric=0.4681


 56%|█████▌    | 111/200 [9:03:56<7:16:42, 294.41s/it]

Epoch 111 / val/metric=0.5050
Epoch 112 / trn/loss=0.6774
Epoch 112 / train/metric=0.4616


 56%|█████▌    | 112/200 [9:08:50<7:11:45, 294.38s/it]

Epoch 112 / val/metric=0.4997
Epoch 113 / trn/loss=0.6803
Epoch 113 / train/metric=0.4646


 56%|█████▋    | 113/200 [9:13:45<7:06:43, 294.29s/it]

Epoch 113 / val/metric=0.4943
Epoch 114 / trn/loss=0.6722
Epoch 114 / train/metric=0.4572
Epoch 114 / val/metric=0.4760
Best val_metric 0.4760328437074442 at epoch 114!


 57%|█████▋    | 114/200 [9:18:40<7:02:12, 294.57s/it]

Epoch 115 / trn/loss=0.6812
Epoch 115 / train/metric=0.4653


 57%|█████▊    | 115/200 [9:23:34<6:57:08, 294.45s/it]

Epoch 115 / val/metric=0.4862
Epoch 116 / trn/loss=0.6676
Epoch 116 / train/metric=0.4539


 58%|█████▊    | 116/200 [9:28:29<6:52:18, 294.51s/it]

Epoch 116 / val/metric=0.4940
Epoch 117 / trn/loss=0.6721
Epoch 117 / train/metric=0.4587


 58%|█████▊    | 117/200 [9:33:23<6:47:11, 294.35s/it]

Epoch 117 / val/metric=0.4866
Epoch 118 / trn/loss=0.6626
Epoch 118 / train/metric=0.4490


 59%|█████▉    | 118/200 [9:38:16<6:41:50, 294.03s/it]

Epoch 118 / val/metric=0.5063
Epoch 119 / trn/loss=0.6643
Epoch 119 / train/metric=0.4522


 60%|█████▉    | 119/200 [9:43:10<6:36:49, 293.94s/it]

Epoch 119 / val/metric=0.4922
Epoch 120 / trn/loss=0.6665
Epoch 120 / train/metric=0.4523


 60%|██████    | 120/200 [9:48:04<6:32:11, 294.14s/it]

Epoch 120 / val/metric=0.4864
Epoch 121 / trn/loss=0.6606
Epoch 121 / train/metric=0.4477


 60%|██████    | 121/200 [9:52:57<6:26:45, 293.74s/it]

Epoch 121 / val/metric=0.4993
Epoch 122 / trn/loss=0.6598
Epoch 122 / train/metric=0.4473


 61%|██████    | 122/200 [9:57:51<6:22:01, 293.87s/it]

Epoch 122 / val/metric=0.4788
Epoch 123 / trn/loss=0.6587
Epoch 123 / train/metric=0.4471


 62%|██████▏   | 123/200 [10:02:46<6:17:32, 294.19s/it]

Epoch 123 / val/metric=0.4894
Epoch 124 / trn/loss=0.6548
Epoch 124 / train/metric=0.4433


 62%|██████▏   | 124/200 [10:07:41<6:13:01, 294.50s/it]

Epoch 124 / val/metric=0.4831
Epoch 125 / trn/loss=0.6552
Epoch 125 / train/metric=0.4438


 62%|██████▎   | 125/200 [10:12:35<6:07:52, 294.30s/it]

Epoch 125 / val/metric=0.4894
Epoch 126 / trn/loss=0.6528
Epoch 126 / train/metric=0.4405


 63%|██████▎   | 126/200 [10:17:29<6:02:58, 294.30s/it]

Epoch 126 / val/metric=0.4936
Epoch 127 / trn/loss=0.6431
Epoch 127 / train/metric=0.4330


 64%|██████▎   | 127/200 [10:22:24<5:58:07, 294.35s/it]

Epoch 127 / val/metric=0.5009
Epoch 128 / trn/loss=0.6467
Epoch 128 / train/metric=0.4361


 64%|██████▍   | 128/200 [10:27:18<5:53:00, 294.18s/it]

Epoch 128 / val/metric=0.4871
Epoch 129 / trn/loss=0.6387
Epoch 129 / train/metric=0.4294


 64%|██████▍   | 129/200 [10:32:13<5:48:26, 294.46s/it]

Epoch 129 / val/metric=0.4894
Epoch 130 / trn/loss=0.6414
Epoch 130 / train/metric=0.4319


 65%|██████▌   | 130/200 [10:37:06<5:43:07, 294.11s/it]

Epoch 130 / val/metric=0.4900
Epoch 131 / trn/loss=0.6369
Epoch 131 / train/metric=0.4298


 66%|██████▌   | 131/200 [10:41:59<5:37:58, 293.89s/it]

Epoch 131 / val/metric=0.5096
Epoch 132 / trn/loss=0.6353
Epoch 132 / train/metric=0.4266


 66%|██████▌   | 132/200 [10:46:54<5:33:23, 294.16s/it]

Epoch 132 / val/metric=0.4884
Epoch 133 / trn/loss=0.6337
Epoch 133 / train/metric=0.4257


 66%|██████▋   | 133/200 [10:51:47<5:28:07, 293.84s/it]

Epoch 133 / val/metric=0.4926
Epoch 134 / trn/loss=0.6339
Epoch 134 / train/metric=0.4256


                                                       

Epoch 134 / val/metric=0.5033
Not improved for 20 epochs, terminate the train






0,1
best_total_log_loss,▁
lr,▁▁▁▂▂▂▃▃▄▅▅▆▆▇▇██████████▇▇▇▇▇▇▆▆▆▅▅▅▅▄▄
train_any_in_metric,▇██▆▇▆▆▆▆▆▆▆▆▇▆▆▆▅▅▅▅▄▄▆▅▄▄▄▄▄▃▃▄▃▃▂▂▂▂▁
train_any_loss,▂▇▆▃▄▁▆▃▁█▂▂▁▃▃▃▃▃▆▁▃▁▄▅▂▃▇▄▃▂▄▁▄▆▅▂▄▂▂▅
train_avg_loss,▃█▄▃▄▁▆▂▁▇▃▃▁▃▃▄▃▃▇▁▃▁▅▅▂▃▇▄▅▂▄▁▄▅▆▁▄▃▃▅
train_avg_metric,█▇▆▅▅▅▅▅▅▅▅▅▅▅▅▅▅▄▄▄▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁
train_bowel_loss,▂▃▃▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▃▁▃▁▁▁▁▂▁▁▁▁▁▁▄▃▁
train_bowel_metric,█▅▄▄▄▄▄▄▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▃▃▂▂▃▂▂▂▃▂▂▂▂▁▁▁
train_extrav_loss,▅█▂▅▅▂▅▂▂▄▃▂▂▃▂▂▂▅▅▂▁▂▃▂▂▄▁▅▅▁▂▂▁▃▁▂▂▃▂▃
train_extrav_metric,▇▇▇▇▆▆▆█▇▆█▇▆▇▆▅▇█▆▅▆▅▆▄▄▄▅▃▄▃▅▃▃▃▁▄▂▂▂▁

0,1
best_total_log_loss,0.47603
lr,0.00046
train_any_in_metric,0.52078
train_any_loss,0.88861
train_avg_loss,0.35342
train_avg_metric,0.4256
train_bowel_loss,0.01205
train_bowel_metric,0.10919
train_extrav_loss,0.32776
train_extrav_metric,0.53728


In [16]:
#Execute this cell to fininsh the wandb run when you stopped training.

import wandb
try: 
    wandb.log({'best_total_log_loss': best_metric})
    wandb.finish()
    
except:
    print('Wandb is already finished!')

Wandb is already finished!
