In [1]:
import time
from tqdm import tqdm
import sys
import glob
import gc
import os
sys.path.append('./lib_models')
#sys.path.append('')

import pandas as pd
import numpy as np
import scipy as sp
import cv2
from PIL import Image
from matplotlib import pyplot as plt
import sklearn.metrics
import warnings
import pydicom
import dicomsdl
from joblib import Parallel, delayed
import pickle
import gzip
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from multiprocessing import Pool
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch import nn
from torchvision.io import read_image
import segmentation_models_pytorch as smp
import timm
from timm.utils import AverageMeter
from timm.models import resnet
import timm_new

from monai.transforms import Resize
import  monai.transforms as transforms

from timm.models.layers.conv2d_same import Conv2dSame
from conv3d_same import Conv3dSame


import wandb

wandb.login(key = '585f58f321685308f7933861d9dde7488de0970b')

  from .autonotebook import tqdm as notebook_tqdm
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjunseonglee[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/junseonglee/.netrc


True

# Parameters

In [2]:
backbone = 'timm/resnet10t.c3_in1k'

IS_WANDB = True
PROJECT_NAME = 'RSNA_ABTD'
GROUP_NAME= 'momdel_test'
RUN_NAME=   f'{backbone}_lstm_with_organ_embedding'

if not IS_WANDB:
    PROJECT_NAME = 'Dummy_Project'

BASE_PATH  = '/home/junseonglee/Desktop/01_codes/inputs/rsna-2023-abdominal-trauma-detection'
TRAIN_PATH = f'{BASE_PATH}/train_images'
DATA_PATH = f'{BASE_PATH}/3d_preprocessed'

seg_inference_dir = f'{BASE_PATH}/seg_infer_results'
cropped_img_dir   = f'{BASE_PATH}/3d_preprocessed_crop_ratio'

if not os.path.isdir(DATA_PATH):
    os.mkdir(DATA_PATH)

RESOL = 128
UP_RESOL = 128
N_CHANNELS = 6
BATCH_SIZE = 8
ACCUM_STEPS = 3
N_WORKERS  = 8
LR = 0.0002
N_EPOCHS = 200
EARLY_STOP_COUNT = N_EPOCHS
N_FOLDS  = 5
N_PREPROCESS_CHUNKS = 12
PCT_START = 0.3
n_blocks = 4
drop_rate = 0.2
drop_path_rate = 0.2
p_mixup = 0.0



DROP_REGION= {'HOLES': [3, 20],
                'SIZE': [5, 20],
                'PROB': 0.5,
                'FILL': (-3, 3)}

wandb_config = {
    'RESOL': RESOL,
    'BACKBONE': backbone,
    'N_CHANNELS': N_CHANNELS,
    'N_EPOCHS': N_EPOCHS,
    'N_FOLDS': N_FOLDS,
    'EARLY_STOP_COUNT': EARLY_STOP_COUNT,
    'BATCH_SIZE': BATCH_SIZE,    
    'LR': LR,
    'N_EPOCHS': N_EPOCHS,
    'DROP_RATE': drop_rate,
    'DROP_PATH_RATE': drop_path_rate,
    'MIXUP_RATE': p_mixup,
    'DROP_REGION': DROP_REGION,
    'PCT_START': PCT_START
}

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

In [3]:
# Mask related parameters
# Order 0: Bowel, 1: left kidney, 2: right kidney, 3: liver, 4: spleen

chan_keys = ['bowel', 'left_kidney', 'right_kidney', 'liver', 'spleen', 'total']
chan_dict = {}
for i in range(0, 6):
    chan_dict[i] = chan_keys[i]

train_meta_df = pd.read_csv(f'{BASE_PATH}/train_meta.csv')
np.unique(train_meta_df['fold'].to_numpy(), return_counts = True)

(array([0, 1, 2, 3, 4]), array([929, 947, 948, 951, 936]))

In [4]:
def compress(name, data):
    with gzip.open(name, 'wb') as f:
        pickle.dump(data, f)

def decompress(name):
    with gzip.open(name, 'rb') as f:
        data = pickle.load(f)
    return data


def compress_fast(name, data):
    with open(name, 'wb') as f:
        pickle.dump(data, f)

def decompress_fast(name):
    with open(name, 'rb') as f:
        data = pickle.load(f)
    return data

# Model

In [5]:
def convert_3d(module):

    module_output = module
    if isinstance(module, torch.nn.BatchNorm2d):
        module_output = torch.nn.BatchNorm3d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig
            
    elif isinstance(module, Conv2dSame):
        module_output = Conv3dSame(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))

    elif isinstance(module, torch.nn.Conv2d):
        module_output = torch.nn.Conv3d(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
            padding_mode=module.padding_mode
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))

    elif isinstance(module, torch.nn.MaxPool2d):
        module_output = torch.nn.MaxPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            ceil_mode=module.ceil_mode,
        )
    elif isinstance(module, torch.nn.AvgPool2d):
        module_output = torch.nn.AvgPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            ceil_mode=module.ceil_mode,
        )

    for name, child in module.named_children():
        module_output.add_module(
            name, convert_3d(child)
        )
    del module

    return module_output

In [6]:
class Timm3DModel(nn.Module):
    def __init__(self, backbone, n_channels, n_labels, segtype='unet', pretrained=False):
        super(Timm3DModel, self).__init__()
        self.n_labels = n_labels
        self.encoder = timm_new.create_model(
            backbone,
            in_chans=n_channels,
            features_only=True,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            pretrained=pretrained
        )
        g = self.encoder(torch.rand(1, n_channels, 64, 64))
        encoder_channels = [1] + [_.shape[1] for _ in g]
        decoder_channels = [256, 128, 64, 32, 16]

        self.avgpool = nn.AvgPool2d(5, 4, 2)
        
        [_.shape[1] for _ in g]
        self.convs1x1 = nn.ModuleList()    
        self.batchnorms = nn.ModuleList()    
        self.batchnorms13 = nn.ModuleList()
        for i in range(0, len(g)):
            self.convs1x1.append(nn.Conv2d(g[i].shape[1], self.n_labels, 1))
        del g
        gc.collect()
        
    def forward(self,x):
        batch_size = x.shape[0]
        global_features = self.encoder(x)[:n_blocks]        
        for i in range(0, len(global_features)):
            global_features[i] = self.convs1x1[i](global_features[i])
        return global_features
    
    
class Timm3DModelClassifierEmbed(nn.Module):
    def __init__(self, backbone, n_channels, n_labels, segtype='unet', pretrained=False):
        super(Timm3DModelClassifierEmbed, self).__init__()
        self.model_3d = Timm3DModel(backbone, n_channels, n_labels, segtype, pretrained)
        self.model_3d = convert_3d(self.model_3d)
        self.n_channels = n_channels
        self.n_labels = n_labels                        
        
    def forward(self, x):
        batch_size = x.shape[0]
        x = self.model_3d(x)
        pooled_features = []
        for i in range(0, len(x)):
            pooled_features.append(torch.reshape(torch.mean(x[i], dim = (2, 3, 4)), (batch_size, self.n_labels, 1)))
        pooled_features = torch.cat(pooled_features, dim=2)     
        labels = nn.Flatten()(pooled_features)
        #labels = torch.mean(pooled_features, dim = 2)
        return labels

In [7]:
class AbdominalClassifier(nn.Module):
    def __init__(self, device = DEVICE):
        super().__init__()
        self.device = device
        
        self.model3d_bowel        = Timm3DModelClassifierEmbed(backbone, 1, 32)      
        self.model3d_extrav       = Timm3DModelClassifierEmbed(backbone, 1, 32)
        self.model3d_kidney_left  = Timm3DModelClassifierEmbed(backbone, 1, 32)
        self.model3d_kidney_right = Timm3DModelClassifierEmbed(backbone, 1, 32)
        self.model3d_liver        = Timm3DModelClassifierEmbed(backbone, 1, 32)
        self.model3d_spleen       = Timm3DModelClassifierEmbed(backbone, 1, 32)
        
        self.flatten  = nn.Flatten()
        self.dropout  = nn.Dropout(p=0.5)
        self.softmax  = nn.Softmax(dim=1)        
        self.maxpool  = nn.MaxPool1d(5, 1)
        
        self.lstm = nn.LSTM(input_size =128, hidden_size = 256, num_layers=5, batch_first=True, bidirectional=True)
        
        self.head = nn.Linear(512, 13)
        
    def forward(self, x_bowel, x_kidney_left, x_kidney_right, x_liver, x_spleen, x_total):
        bs = x_bowel.shape[0]
        
        bowel_emb        = torch.reshape(self.model3d_bowel(x_bowel), (bs, 1, 128))
        extrav_emb       = torch.reshape(self.model3d_extrav(x_total), (bs, 1, 128))
        kidney_left_emb  = torch.reshape(self.model3d_kidney_left(x_kidney_left), (bs, 1, 128))
        kidney_right_emb = torch.reshape(self.model3d_kidney_right(x_kidney_right), (bs, 1, 128))
        liver_emb        = torch.reshape(self.model3d_liver(x_liver), (bs, 1, 128))
        spleen_emb       = torch.reshape(self.model3d_spleen(x_spleen), (bs, 1, 128))
        
        all_embs = torch.cat([bowel_emb, extrav_emb, kidney_left_emb, kidney_right_emb, liver_emb, spleen_emb], dim = 1)
        all_embs = self.lstm(all_embs)

        labels   = torch.mean(all_embs[0], dim = 1) 
        labels   = self.head(labels)

        bowel_soft = self.softmax(labels[:,:2])
        extrav_soft = self.softmax(labels[:,2:4])
        kidney_soft = self.softmax(labels[:,4:7])
        liver_soft = self.softmax(labels[:,7:10])
        spleen_soft = self.softmax(labels[:,10:13])

        any_in = torch.cat([1-bowel_soft[:,0:1], 1-extrav_soft[:,0:1], 
                            1-kidney_soft[:,0:1], 1-liver_soft[:,0:1], 1-spleen_soft[:,0:1]], dim = 1) 
        any_in = self.maxpool(any_in)
        any_not_in = 1-any_in
        any_in = torch.cat([any_not_in, any_in], dim = 1)

        #any_in = torch.log(any_in + 1e-6)  # 1e-6은 0을 처리하기 위한 작은 값
        return labels, any_in
    

In [8]:
model = AbdominalClassifier()

def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp

print(get_n_params(model))
del model
gc.collect()

93764765


0

# Metric & Loss

In [9]:
weights = np.ones(2)
weights[1] = 2
crit_bowel  = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
weights[1] = 6
crit_extrav = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
crit_any = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))

weights = np.ones((3))
weights[1] = 2
weights[2] = 4
crit_kidney = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
crit_liver  = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
crit_spleen = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))

In [10]:
def normalize_to_one(tensor):
    norm = torch.sum(tensor, 1)
    for i in range(0, tensor.shape[1]):
        tensor[:,i]/=norm
    return tensor

def apply_softmax_to_labels(X_out):
    softmax = nn.Softmax(dim=1)

    X_out[:,:2]    = normalize_to_one(softmax(X_out[:,:2]))
    X_out[:,2:4]   = normalize_to_one(softmax(X_out[:,2:4]))
    X_out[:,4:7]   = normalize_to_one(softmax(X_out[:,4:7]))
    X_out[:,7:10]  = normalize_to_one(softmax(X_out[:,7:10]))
    X_out[:,10:13] = normalize_to_one(softmax(X_out[:,10:13]))

    return X_out

def calculate_score(X_outs, ys, step = 'train'):
    X_outs = X_outs.astype(np.float64)
    ys     = ys.astype(np.float64)

    isnan_x = np.isnan(X_outs).astype(int)
    isnan_y = np.isnan(ys).astype(int)
    
    if(np.max(isnan_x)>0):
        print('xnan')
    if(np.max(isnan_y)>0):
        print('ynan')
        
    #X_outs[:, 13:15] = nn.Softmax(dim=1)(torch.from_numpy(X_outs[:, 13:15])).numpy()
    bowel_weights  =  ys[:,0] + 2*ys[:,1]
    extrav_weights = ys[:,2] + 6*ys[:,3]
    kidney_weights = ys[:,4] + 2*ys[:,5] + 4*ys[:,6]
    liver_weights  = ys[:,7] + 2*ys[:,8] + 4*ys[:,9]
    spleen_weights = ys[:,10] + 2*ys[:,11] + 4*ys[:,12]
    any_in_weights = ys[:,13] + 6*ys[:,14]

    bowel_loss  = sklearn.metrics.log_loss(ys[:,:2], X_outs[:,:2], sample_weight = bowel_weights.astype(np.float64))
    extrav_loss = sklearn.metrics.log_loss(ys[:,2:4], X_outs[:,2:4], sample_weight = extrav_weights.astype(np.float64))
    kidney_loss = sklearn.metrics.log_loss(ys[:,4:7], X_outs[:,4:7], sample_weight = kidney_weights.astype(np.float64))
    liver_loss  = sklearn.metrics.log_loss(ys[:,7:10], X_outs[:,7:10], sample_weight = liver_weights.astype(np.float64))
    spleen_loss = sklearn.metrics.log_loss(ys[:,10:13], X_outs[:,10:13], sample_weight = spleen_weights.astype(np.float64))
    any_in_loss = sklearn.metrics.log_loss(ys[:,13:15], X_outs[:,13:15], sample_weight =  any_in_weights.astype(np.float64))
    
    avg_loss = (bowel_loss + extrav_loss + kidney_loss + liver_loss + spleen_loss + any_in_loss)/6

    losses= {f'{step}_bowel_metric': bowel_loss, f'{step}_extrav_metric': extrav_loss, f'{step}_kidney_metric': kidney_loss,
             f'{step}_liver_metric': liver_loss, f'{step}_spleen_metric': spleen_loss, f'{step}_any_in_metric': any_in_loss,
             f'{step}_avg_metric': avg_loss}

    wandb.log(losses)
    return avg_loss

def calculate_loss(X_out, X_any, y):
    batch_size = X_out.shape[0]
    bowel_loss  = crit_bowel(X_out[:,:2], y[:,:2])
    extrav_loss = crit_extrav(X_out[:,2:4], y[:,2:4])
    kidney_loss = crit_kidney(X_out[:,4:7], y[:,4:7])
    liver_loss  = crit_liver(X_out[:,7:10], y[:,7:10])
    spleen_loss = crit_spleen(X_out[:,10:13], y[:,10:13])
    any_in_loss = crit_any(X_any,  torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1))
    
    avg_loss = (bowel_loss + extrav_loss + kidney_loss + liver_loss + spleen_loss + any_in_loss)/6
    return bowel_loss, extrav_loss, kidney_loss, liver_loss, spleen_loss, any_in_loss, avg_loss

# Augmentations

In [11]:
def mixup(inputs, truth, clip=[0, 1]):
    indices = torch.randperm(inputs.size(0))
    shuffled_input = inputs[indices]
    shuffled_labels = truth[indices]

    lam = np.random.uniform(clip[0], clip[1])
    inputs = inputs * lam + shuffled_input * (1 - lam)
    return inputs, truth, shuffled_labels, lam

transforms_train = transforms.Compose([
    transforms.RandFlipd(keys=chan_keys, prob=0.5, spatial_axis=0),    
    transforms.RandFlipd(keys=chan_keys, prob=0.5, spatial_axis=1),
    transforms.RandFlipd(keys=chan_keys, prob=0.5, spatial_axis=2),
    #transforms.RandAffined(keys=chan_keys, translate_range=[int(x*y) for x, y in zip([RESOL, RESOL, RESOL], [0.3, 0.3, 0.3])], padding_mode='zeros', prob=0.7),
    transforms.RandGridDistortiond(keys=chan_keys, prob=0.5, distort_limit=(-0.01, 0.01), mode="nearest"),    
])

remain_transforms_train = transforms.Compose([
    #transforms.RandCoarseDropout(holes = DROP_REGION['HOLES'][0], max_holes = DROP_REGION['HOLES'][1],
    #                        spatial_size = DROP_REGION['SIZE'][0]*np.ones(3, int), max_spatial_size =DROP_REGION['SIZE'][1]*np.ones(3, int), 
    #                        prob = DROP_REGION['PROB'], 
    #                        fill_value = DROP_REGION['FILL'])
])



transforms_common_preprocessing = transforms.Compose([
    #transforms.HistogramNormalize(num_bins = 256, min = 0, max = 255)
])

# Dataset

In [12]:
class AbdominalCTDataset(Dataset):
    def __init__(self, meta_df, is_train = True, transform_set = None, remain_transforms_set = None):
        self.meta_df = meta_df
        self.is_train = is_train
        self.transform_set = transform_set
        self.remain_transforms_set = remain_transforms_set
        self.data_3ds = []        
        for i in tqdm(range(0, len(self.meta_df))):
            tmp_data_3ds = {}
            base_name = self.meta_df.iloc[i]['cropped_path']            
            for j in range(0, 6):
                tmp_data_3d = decompress_fast(f'{base_name}_{j}').unsqueeze(0)
                #tmp_data_3d = torch.from_numpy(tmp_data_3d)
                tmp_data_3ds[chan_dict[j]] = tmp_data_3d            
            self.data_3ds.append(tmp_data_3ds)

    def __len__(self):
        return len(self.meta_df)
    
    def __getitem__(self, idx):
        row = self.meta_df.iloc[idx]
        label = row[['bowel_healthy','bowel_injury',
                    'extravasation_healthy','extravasation_injury',
                    'kidney_healthy','kidney_low','kidney_high',
                    'liver_healthy','liver_low','liver_high',
                    'spleen_healthy','spleen_low','spleen_high', 'any_injury']]
        
        data_3d = self.data_3ds[idx].copy()
        
        if self.is_train:
            if self.transform_set is not None:
                data_3d = self.transform_set(data_3d)

            if self.remain_transforms_set is not None:   
                for i in range(0, 6):
                    data_3d[chan_dict[i]] = self.remain_transforms_set(data_3d[chan_dict[i]])
        
        label = label.to_numpy().astype(np.float32)                    
        label = torch.from_numpy(label)
                    
        return data_3d['bowel'], data_3d['left_kidney'], data_3d['right_kidney'], \
                data_3d['liver'], data_3d['spleen'], data_3d['total'], label        


In [13]:
#data_3d= torch.rand((6, 128, 128, 128))*0.5
#data_3d = remain_transforms_train(data_3d)
#torch.max(data_3d)
#print(data_3d)

# Train loop

In [14]:
def train_func(model, train_loader, scaler, scheduler, optimizer, epoch, accum_points, accum_scale):
    train_meters = {'loss': AverageMeter()}
    model.train()
    X_outs=[]
    ys=[]
    accum_counter = 0
    counter = 0
    for X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, y in train_loader:
        X_bowel, X_lkid, X_rkid = X_bowel.to(DEVICE), X_lkid.to(DEVICE), X_rkid.to(DEVICE)
        X_liv, X_spl, X_tot     = X_liv.to(DEVICE), X_spl.to(DEVICE), X_tot.to(DEVICE)
        y = y.to(DEVICE)
        current_lr = float(scheduler.get_last_lr()[0])
        
        batch_size = X_bowel.shape[0]
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=True):  
            X_out, X_any  = model(X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot)
            bowel_loss, extrav_loss, kidney_loss, liver_loss, spleen_loss, any_in_loss, avg_loss = calculate_loss(X_out, X_any, y)
                
            step = 'train'
            wandb.log({ 'lr': current_lr,
                        f'{step}_bowel_loss': bowel_loss.item(),
                        f'{step}_extrav_loss': extrav_loss.item(),
                        f'{step}_kidney_loss': kidney_loss.item(),
                        f'{step}_liver_loss': liver_loss.item(),
                        f'{step}_spleen_loss': spleen_loss.item(),
                        f'{step}_any_loss': any_in_loss.item(),
                        f'{step}_avg_loss': avg_loss.item()
                        })
            
            scaler.scale(avg_loss/accum_scale[accum_counter]).backward()
            if(counter==accum_points[accum_counter]):
                scaler.step(optimizer)
                scheduler.step()
                scaler.update()    
                accum_counter+=1                
        counter+=1                   

        #Metric calculation
        y_any = torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1)    
        X_out = apply_softmax_to_labels(X_out).detach().to('cpu').numpy()
        X_any = X_any.detach().to('cpu').numpy()
        X_out = np.hstack([X_out, X_any])
        X_outs.append(X_out)

        y     = y.to('cpu').numpy()[:,:-1]
        y_any = y_any.to('cpu').numpy()
        y     = np.hstack([y, y_any])
        ys.append(y)

        trn_loss = avg_loss.item()      
        train_meters['loss'].update(trn_loss, n=batch_size)     
        #pbar.set_description(f'Train loss: {trn_loss}')   
        
        
    print('Epoch {:d} / trn/loss={:.4f}'.format(epoch+1, train_meters['loss'].avg))    

    X_outs = np.vstack(X_outs) 
    ys     = np.vstack(ys)
    metric = calculate_score(X_outs, ys, 'train')                 
    print('Epoch {:d} / train/metric={:.4f}'.format(epoch+1, metric))   

    del X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, X_outs, y, ys, X_any
    gc.collect()
    torch.cuda.empty_cache()    
    return scheduler, scaler, optimizer


def valid_func(model, valid_loader, epoch):
    X_outs=[]
    ys=[]
    model.eval()
    for X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, y in valid_loader:
        batch_size = y.shape[0]
        X_bowel, X_lkid, X_rkid = X_bowel.to(DEVICE), X_lkid.to(DEVICE), X_rkid.to(DEVICE)
        X_liv, X_spl, X_tot     = X_liv.to(DEVICE), X_spl.to(DEVICE), X_tot.to(DEVICE)
        y = y.to(DEVICE)           
        with torch.cuda.amp.autocast(enabled=True):                
            with torch.no_grad():                 
                X_out, X_any = model(X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot)                                          
                y_any = torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1)              
                X_out = apply_softmax_to_labels(X_out).to('cpu').numpy()

                X_any = X_any.to('cpu').numpy()
                X_out = np.hstack([X_out, X_any])
                X_outs.append(X_out)

                y     = y.to('cpu').numpy()[:,:-1]
                y_any = y_any.to('cpu').numpy()
                y     = np.hstack([y, y_any])
                ys.append(y)

    X_outs = np.vstack(X_outs) 
    ys     = np.vstack(ys)
    metric = calculate_score(X_outs, ys, 'valid')                
    print('Epoch {:d} / val/metric={:.4f}'.format(epoch+1, metric))           
    
    del X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, X_outs, y, ys, X_any
    gc.collect()        
    torch.cuda.empty_cache()   
    return metric 

In [15]:
model = AbdominalClassifier()
model.to(DEVICE)

wandb.init(
    config = wandb_config,
    project= PROJECT_NAME,
    group  = GROUP_NAME,
    name   = RUN_NAME,
    dir    = BASE_PATH)

backbone = backbone.replace('/', '_')

if __name__ == '__main__':
    train_dataset = AbdominalCTDataset(train_meta_df[train_meta_df['fold']!=0], is_train = True, transform_set  = transforms_train, 
                                        remain_transforms_set = remain_transforms_train)
    valid_dataset = AbdominalCTDataset(train_meta_df[train_meta_df['fold']==0], is_train = False, transform_set = None,
                                        remain_transforms_set = None)
    
    train_loader = DataLoader(dataset = train_dataset, shuffle = True, batch_size = BATCH_SIZE, pin_memory = False, 
                            num_workers = N_WORKERS, drop_last = False)

    valid_loader = DataLoader(dataset = valid_dataset, shuffle = False, batch_size = BATCH_SIZE, pin_memory = False, 
                            num_workers = N_WORKERS, drop_last = False)     
    
    ttl_iters = N_EPOCHS * len(train_loader)
    
    #gradient accumulation for stability of the training
    accum_len = int(np.ceil(len(train_loader)/ACCUM_STEPS)+0.001)
    accum_points = np.zeros(accum_len, int)
    accum_scale  = np.zeros(accum_len, int)
    
    prev_step = -1
    for i in range(0, accum_len):
        accum_points[i] = min(prev_step+ACCUM_STEPS, len(train_loader)-1)
        accum_scale[i]  = accum_points[i] - prev_step
        prev_step = accum_points[i]

    #Scheduler & optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr = LR)
    n_batch_iters = int(np.ceil(len(train_loader)/ACCUM_STEPS)+0.001)
    #scheduler = CosineAnnealingLR(optimizer, T_max=ttl_iters, eta_min=1e-6)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LR, pct_start= PCT_START,
                                                    steps_per_epoch= n_batch_iters, epochs = N_EPOCHS)

    scaler = torch.cuda.amp.GradScaler(enabled=True)
    val_metrics = np.ones(N_EPOCHS)*100

    gc.collect()

    for epoch in tqdm(range(0, N_EPOCHS), leave = False):     

        scheduler, scaler, optimizer = train_func(model, train_loader, scaler, scheduler, optimizer, epoch, accum_points, accum_scale)
        metric                       = valid_func(model, valid_loader, epoch)
        
        #Save the best model    
        if(metric < np.min(val_metrics)):
            try:
                os.makedirs(f'{BASE_PATH}/weights')
            except:
                a = 1
            best_metric = metric
            print(f'Best val_metric {best_metric} at epoch {epoch+1}!')
            torch.save(model.state_dict(), f'{BASE_PATH}/weights/{backbone}_lr{LR}_epochs_{N_EPOCHS}_resol{UP_RESOL}_batch{BATCH_SIZE*ACCUM_STEPS}_best.pt') 
            if(metric < 0.48):
                torch.save(model.state_dict(), f'{BASE_PATH}/weights/{backbone}_lr{LR}_epochs_{N_EPOCHS}_resol{UP_RESOL}_batch{BATCH_SIZE*ACCUM_STEPS}_{metric}.pt') 
            not_improve_counter=0
            val_metrics[epoch] = metric
            continue                    
        val_metrics[epoch] = metric                        
        
        #Early stopping
        not_improve_counter+=1
        if(not_improve_counter == EARLY_STOP_COUNT):
            print(f'Not improved for {not_improve_counter} epochs, terminate the train')
            break
wandb.log({'best_total_log_loss': best_metric})
wandb.finish()

100%|██████████| 3782/3782 [01:22<00:00, 45.69it/s]
100%|██████████| 929/929 [00:29<00:00, 31.25it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch 1 / trn/loss=1.1385
Epoch 1 / train/metric=0.8649
Epoch 1 / val/metric=0.8416
Best val_metric 0.8415715837165322 at epoch 1!


  0%|          | 1/200 [04:32<15:02:21, 272.07s/it]

Epoch 2 / trn/loss=0.9711
Epoch 2 / train/metric=0.7130
Epoch 2 / val/metric=0.6118
Best val_metric 0.6117981292574327 at epoch 2!


  1%|          | 2/200 [08:50<14:31:39, 264.14s/it]

Epoch 3 / trn/loss=0.8620
Epoch 3 / train/metric=0.6149
Epoch 3 / val/metric=0.6072
Best val_metric 0.6072271286814738 at epoch 3!


  2%|▏         | 3/200 [13:10<14:20:41, 262.14s/it]

Epoch 4 / trn/loss=0.8588
Epoch 4 / train/metric=0.6114
Epoch 4 / val/metric=0.6066
Best val_metric 0.6065658971643674 at epoch 4!


  2%|▏         | 4/200 [17:30<14:13:16, 261.21s/it]

Epoch 5 / trn/loss=0.8581
Epoch 5 / train/metric=0.6110
Epoch 5 / val/metric=0.6059
Best val_metric 0.6059121823467746 at epoch 5!


  2%|▎         | 5/200 [21:51<14:08:27, 261.06s/it]

Epoch 6 / trn/loss=0.8564
Epoch 6 / train/metric=0.6096
Epoch 6 / val/metric=0.6018
Best val_metric 0.6017964658200537 at epoch 6!


  3%|▎         | 6/200 [26:12<14:04:08, 261.07s/it]

Epoch 7 / trn/loss=0.8555
Epoch 7 / train/metric=0.6094
Epoch 7 / val/metric=0.5984
Best val_metric 0.5983697269663023 at epoch 7!


  4%|▎         | 7/200 [30:33<13:59:45, 261.06s/it]

Epoch 8 / trn/loss=0.8553
Epoch 8 / train/metric=0.6081
Epoch 8 / val/metric=0.5950
Best val_metric 0.5950354793413667 at epoch 8!


  4%|▍         | 8/200 [34:53<13:54:54, 260.91s/it]

Epoch 9 / trn/loss=0.8486
Epoch 9 / train/metric=0.6021
Epoch 9 / val/metric=0.5889
Best val_metric 0.5889080376430723 at epoch 9!


  4%|▍         | 9/200 [39:14<13:50:41, 260.95s/it]

Epoch 10 / trn/loss=0.8455
Epoch 10 / train/metric=0.6024
Epoch 10 / val/metric=0.5817
Best val_metric 0.5816514646762462 at epoch 10!


  5%|▌         | 10/200 [43:36<13:46:41, 261.06s/it]

Epoch 11 / trn/loss=0.8386
Epoch 11 / train/metric=0.5953
Epoch 11 / val/metric=0.5746
Best val_metric 0.5745815270233611 at epoch 11!


  6%|▌         | 11/200 [47:57<13:42:21, 261.07s/it]

Epoch 12 / trn/loss=0.8341
Epoch 12 / train/metric=0.5949
Epoch 12 / val/metric=0.5726
Best val_metric 0.5726184964119039 at epoch 12!


  6%|▌         | 12/200 [52:21<13:41:04, 262.04s/it]

Epoch 13 / trn/loss=0.8191
Epoch 13 / train/metric=0.5778
Epoch 13 / val/metric=0.5602
Best val_metric 0.5602418556012838 at epoch 13!


  6%|▋         | 13/200 [56:42<13:36:08, 261.87s/it]

Epoch 14 / trn/loss=0.8187
Epoch 14 / train/metric=0.5754
Epoch 14 / val/metric=0.5566
Best val_metric 0.5565576972656776 at epoch 14!


  7%|▋         | 14/200 [1:01:04<13:31:37, 261.81s/it]

Epoch 15 / trn/loss=0.8113
Epoch 15 / train/metric=0.5729


  8%|▊         | 15/200 [1:05:25<13:26:47, 261.66s/it]

Epoch 15 / val/metric=0.5606
Epoch 16 / trn/loss=0.8089
Epoch 16 / train/metric=0.5670
Epoch 16 / val/metric=0.5477
Best val_metric 0.5477384647645802 at epoch 16!


  8%|▊         | 16/200 [1:09:46<13:21:55, 261.50s/it]

Epoch 17 / trn/loss=0.8015
Epoch 17 / train/metric=0.5597


  8%|▊         | 17/200 [1:14:08<13:17:19, 261.42s/it]

Epoch 17 / val/metric=0.5508
Epoch 18 / trn/loss=0.7980
Epoch 18 / train/metric=0.5576
Epoch 18 / val/metric=0.5416
Best val_metric 0.5415710988050267 at epoch 18!


  9%|▉         | 18/200 [1:18:29<13:13:01, 261.43s/it]

Epoch 19 / trn/loss=0.7923
Epoch 19 / train/metric=0.5538


 10%|▉         | 19/200 [1:22:51<13:08:37, 261.42s/it]

Epoch 19 / val/metric=0.5676
Epoch 20 / trn/loss=0.7964
Epoch 20 / train/metric=0.5595
Epoch 20 / val/metric=0.5318
Best val_metric 0.5317792556736137 at epoch 20!


 10%|█         | 20/200 [1:27:12<13:04:38, 261.55s/it]

Epoch 21 / trn/loss=0.7917
Epoch 21 / train/metric=0.5539
Epoch 21 / val/metric=0.5299
Best val_metric 0.5298975144411658 at epoch 21!


 10%|█         | 21/200 [1:31:34<13:00:29, 261.62s/it]

Epoch 22 / trn/loss=0.7933
Epoch 22 / train/metric=0.5539


 11%|█         | 22/200 [1:35:55<12:55:24, 261.37s/it]

Epoch 22 / val/metric=0.5338
Epoch 23 / trn/loss=0.7864
Epoch 23 / train/metric=0.5517
Epoch 23 / val/metric=0.5262
Best val_metric 0.5262136015208331 at epoch 23!


 12%|█▏        | 23/200 [1:40:17<12:51:14, 261.44s/it]

Epoch 24 / trn/loss=0.7786
Epoch 24 / train/metric=0.5436


 12%|█▏        | 24/200 [1:44:38<12:46:44, 261.39s/it]

Epoch 24 / val/metric=0.5275
Epoch 25 / trn/loss=0.7813
Epoch 25 / train/metric=0.5480
Epoch 25 / val/metric=0.5162
Best val_metric 0.5162136864101524 at epoch 25!


 12%|█▎        | 25/200 [1:49:00<12:43:21, 261.72s/it]

Epoch 26 / trn/loss=0.7766
Epoch 26 / train/metric=0.5421


 13%|█▎        | 26/200 [1:53:21<12:38:21, 261.50s/it]

Epoch 26 / val/metric=0.5439
Epoch 27 / trn/loss=0.7782
Epoch 27 / train/metric=0.5453


 14%|█▎        | 27/200 [1:57:43<12:33:43, 261.41s/it]

Epoch 27 / val/metric=0.5216
Epoch 28 / trn/loss=0.7715
Epoch 28 / train/metric=0.5394
Epoch 28 / val/metric=0.5120
Best val_metric 0.511976614459293 at epoch 28!


 14%|█▍        | 28/200 [2:02:05<12:30:05, 261.66s/it]

Epoch 29 / trn/loss=0.7796
Epoch 29 / train/metric=0.5484


 14%|█▍        | 29/200 [2:06:26<12:25:13, 261.48s/it]

Epoch 29 / val/metric=0.5168
Epoch 30 / trn/loss=0.7822
Epoch 30 / train/metric=0.5482


 15%|█▌        | 30/200 [2:10:47<12:20:12, 261.25s/it]

Epoch 30 / val/metric=0.5164
Epoch 31 / trn/loss=0.7762
Epoch 31 / train/metric=0.5438


 16%|█▌        | 31/200 [2:15:08<12:15:56, 261.28s/it]

Epoch 31 / val/metric=0.5165
Epoch 32 / trn/loss=0.7783
Epoch 32 / train/metric=0.5458


 16%|█▌        | 32/200 [2:19:29<12:11:44, 261.34s/it]

Epoch 32 / val/metric=0.5522
Epoch 33 / trn/loss=0.7675
Epoch 33 / train/metric=0.5381


 16%|█▋        | 33/200 [2:23:51<12:07:20, 261.32s/it]

Epoch 33 / val/metric=0.5350
Epoch 34 / trn/loss=0.7699
Epoch 34 / train/metric=0.5410


 17%|█▋        | 34/200 [2:28:11<12:02:31, 261.15s/it]

Epoch 34 / val/metric=0.5253
Epoch 35 / trn/loss=0.7742
Epoch 35 / train/metric=0.5429


 18%|█▊        | 35/200 [2:32:32<11:57:53, 261.05s/it]

Epoch 35 / val/metric=0.5135
Epoch 36 / trn/loss=0.7684
Epoch 36 / train/metric=0.5378


 18%|█▊        | 36/200 [2:36:53<11:53:23, 261.00s/it]

Epoch 36 / val/metric=0.5159
Epoch 37 / trn/loss=0.7631
Epoch 37 / train/metric=0.5314


 18%|█▊        | 37/200 [2:41:15<11:49:26, 261.14s/it]

Epoch 37 / val/metric=0.5337
Epoch 38 / trn/loss=0.7713
Epoch 38 / train/metric=0.5414


 19%|█▉        | 38/200 [2:45:36<11:45:38, 261.35s/it]

Epoch 38 / val/metric=0.5124
Epoch 39 / trn/loss=0.7707
Epoch 39 / train/metric=0.5457
Epoch 39 / val/metric=0.5080
Best val_metric 0.508041735391506 at epoch 39!


 20%|█▉        | 39/200 [2:49:59<11:41:49, 261.55s/it]

Epoch 40 / trn/loss=0.7636
Epoch 40 / train/metric=0.5344
Epoch 40 / val/metric=0.5043
Best val_metric 0.5042843741849313 at epoch 40!


 20%|██        | 40/200 [2:54:20<11:37:25, 261.53s/it]

Epoch 41 / trn/loss=0.7613
Epoch 41 / train/metric=0.5330


 20%|██        | 41/200 [2:58:41<11:32:55, 261.48s/it]

Epoch 41 / val/metric=0.5226
Epoch 42 / trn/loss=0.7625
Epoch 42 / train/metric=0.5327


 21%|██        | 42/200 [3:03:02<11:28:11, 261.34s/it]

Epoch 42 / val/metric=0.5123
Epoch 43 / trn/loss=0.7755
Epoch 43 / train/metric=0.5482


 22%|██▏       | 43/200 [3:07:23<11:23:13, 261.11s/it]

Epoch 43 / val/metric=0.5096
Epoch 44 / trn/loss=0.7669
Epoch 44 / train/metric=0.5426
Epoch 44 / val/metric=0.5009
Best val_metric 0.5009389172496869 at epoch 44!


 22%|██▏       | 44/200 [3:11:44<11:19:12, 261.23s/it]

Epoch 45 / trn/loss=0.7524
Epoch 45 / train/metric=0.5231


 22%|██▎       | 45/200 [3:16:05<11:14:24, 261.06s/it]

Epoch 45 / val/metric=0.5102
Epoch 46 / trn/loss=0.7570
Epoch 46 / train/metric=0.5306


 23%|██▎       | 46/200 [3:20:26<11:10:06, 261.08s/it]

Epoch 46 / val/metric=0.5317
Epoch 47 / trn/loss=0.7550
Epoch 47 / train/metric=0.5291


 24%|██▎       | 47/200 [3:24:47<11:05:49, 261.11s/it]

Epoch 47 / val/metric=0.5235
Epoch 48 / trn/loss=0.7612
Epoch 48 / train/metric=0.5332


 24%|██▍       | 48/200 [3:29:08<11:01:25, 261.09s/it]

Epoch 48 / val/metric=0.5088
Epoch 49 / trn/loss=0.7600
Epoch 49 / train/metric=0.5305
Epoch 49 / val/metric=0.5000
Best val_metric 0.5000163817834882 at epoch 49!


 24%|██▍       | 49/200 [3:33:30<10:57:23, 261.21s/it]

Epoch 50 / trn/loss=0.7593
Epoch 50 / train/metric=0.5322


 25%|██▌       | 50/200 [3:37:51<10:53:06, 261.24s/it]

Epoch 50 / val/metric=0.5186
Epoch 51 / trn/loss=0.7586
Epoch 51 / train/metric=0.5306


 26%|██▌       | 51/200 [3:42:12<10:48:33, 261.17s/it]

Epoch 51 / val/metric=0.5071
Epoch 52 / trn/loss=0.7594
Epoch 52 / train/metric=0.5302


 26%|██▌       | 52/200 [3:46:34<10:44:19, 261.21s/it]

Epoch 52 / val/metric=0.5630
Epoch 53 / trn/loss=0.7591
Epoch 53 / train/metric=0.5299


 26%|██▋       | 53/200 [3:50:55<10:39:47, 261.14s/it]

Epoch 53 / val/metric=0.5140
Epoch 54 / trn/loss=0.7585
Epoch 54 / train/metric=0.5302


 27%|██▋       | 54/200 [3:55:16<10:35:34, 261.20s/it]

Epoch 54 / val/metric=0.5260
Epoch 55 / trn/loss=0.7531
Epoch 55 / train/metric=0.5266


 28%|██▊       | 55/200 [3:59:37<10:31:15, 261.21s/it]

Epoch 55 / val/metric=0.5074
Epoch 56 / trn/loss=0.7596
Epoch 56 / train/metric=0.5322
Epoch 56 / val/metric=0.4996
Best val_metric 0.49957182450786525 at epoch 56!


 28%|██▊       | 56/200 [4:03:59<10:27:21, 261.40s/it]

Epoch 57 / trn/loss=0.7501
Epoch 57 / train/metric=0.5229


 28%|██▊       | 57/200 [4:08:20<10:22:28, 261.18s/it]

Epoch 57 / val/metric=0.5043
Epoch 58 / trn/loss=0.7449
Epoch 58 / train/metric=0.5184


 29%|██▉       | 58/200 [4:12:41<10:18:17, 261.25s/it]

Epoch 58 / val/metric=0.5017
Epoch 59 / trn/loss=0.7491
Epoch 59 / train/metric=0.5236


 30%|██▉       | 59/200 [4:17:02<10:13:57, 261.26s/it]

Epoch 59 / val/metric=0.5030
Epoch 60 / trn/loss=0.7451
Epoch 60 / train/metric=0.5199


 30%|███       | 60/200 [4:21:24<10:09:45, 261.33s/it]

Epoch 60 / val/metric=0.5067
Epoch 61 / trn/loss=0.7539
Epoch 61 / train/metric=0.5296


 30%|███       | 61/200 [4:25:45<10:05:18, 261.28s/it]

Epoch 61 / val/metric=0.5050
Epoch 62 / trn/loss=0.7440
Epoch 62 / train/metric=0.5195


 31%|███       | 62/200 [4:30:06<10:01:04, 261.34s/it]

Epoch 62 / val/metric=0.5033
Epoch 63 / trn/loss=0.7462
Epoch 63 / train/metric=0.5192
Epoch 63 / val/metric=0.4988
Best val_metric 0.49877882577245986 at epoch 63!


 32%|███▏      | 63/200 [4:34:28<9:56:56, 261.43s/it] 

Epoch 64 / trn/loss=0.7490
Epoch 64 / train/metric=0.5231


 32%|███▏      | 64/200 [4:38:49<9:52:25, 261.36s/it]

Epoch 64 / val/metric=0.5052
Epoch 65 / trn/loss=0.7438
Epoch 65 / train/metric=0.5174
Epoch 65 / val/metric=0.4920
Best val_metric 0.49203139553801223 at epoch 65!


 32%|███▎      | 65/200 [4:43:11<9:48:06, 261.38s/it]

Epoch 66 / trn/loss=0.7393
Epoch 66 / train/metric=0.5149


 33%|███▎      | 66/200 [4:47:32<9:43:37, 261.32s/it]

Epoch 66 / val/metric=0.4936
Epoch 67 / trn/loss=0.7419
Epoch 67 / train/metric=0.5185
Epoch 67 / val/metric=0.4882
Best val_metric 0.4881675303533289 at epoch 67!


 34%|███▎      | 67/200 [4:51:54<9:39:43, 261.53s/it]

Epoch 68 / trn/loss=0.7343
Epoch 68 / train/metric=0.5112


 34%|███▍      | 68/200 [4:56:15<9:34:57, 261.35s/it]

Epoch 68 / val/metric=0.4947
Epoch 69 / trn/loss=0.7363
Epoch 69 / train/metric=0.5113


 34%|███▍      | 69/200 [5:00:35<9:30:08, 261.13s/it]

Epoch 69 / val/metric=0.4969
Epoch 70 / trn/loss=0.7374
Epoch 70 / train/metric=0.5129
Epoch 70 / val/metric=0.4741
Best val_metric 0.47414746621552256 at epoch 70!


 35%|███▌      | 70/200 [5:04:57<9:26:04, 261.27s/it]

Epoch 71 / trn/loss=0.7284
Epoch 71 / train/metric=0.5079


 36%|███▌      | 71/200 [5:09:18<9:21:29, 261.16s/it]

Epoch 71 / val/metric=0.5350
Epoch 72 / trn/loss=0.7304
Epoch 72 / train/metric=0.5096


 36%|███▌      | 72/200 [5:13:39<9:17:12, 261.19s/it]

Epoch 72 / val/metric=0.5065
Epoch 73 / trn/loss=0.7316
Epoch 73 / train/metric=0.5109


 36%|███▋      | 73/200 [5:18:00<9:12:53, 261.21s/it]

Epoch 73 / val/metric=0.4934
Epoch 74 / trn/loss=0.7221
Epoch 74 / train/metric=0.4997


 37%|███▋      | 74/200 [5:22:22<9:08:26, 261.16s/it]

Epoch 74 / val/metric=0.4854
Epoch 75 / trn/loss=0.7260
Epoch 75 / train/metric=0.5063
Epoch 75 / val/metric=0.4718
Best val_metric 0.47180072856126937 at epoch 75!


 38%|███▊      | 75/200 [5:26:43<9:04:23, 261.30s/it]

Epoch 76 / trn/loss=0.7296
Epoch 76 / train/metric=0.5055


 38%|███▊      | 76/200 [5:31:05<9:00:04, 261.32s/it]

Epoch 76 / val/metric=0.5027
Epoch 77 / trn/loss=0.7244
Epoch 77 / train/metric=0.5057


 38%|███▊      | 77/200 [5:35:26<8:55:50, 261.38s/it]

Epoch 77 / val/metric=0.4906
Epoch 78 / trn/loss=0.7229
Epoch 78 / train/metric=0.5019


 39%|███▉      | 78/200 [5:39:48<8:51:32, 261.42s/it]

Epoch 78 / val/metric=0.5126
Epoch 79 / trn/loss=0.7173
Epoch 79 / train/metric=0.4988


 40%|███▉      | 79/200 [5:44:09<8:47:20, 261.49s/it]

Epoch 79 / val/metric=0.4972
Epoch 80 / trn/loss=0.7192
Epoch 80 / train/metric=0.5030


 40%|████      | 80/200 [5:48:30<8:42:38, 261.32s/it]

Epoch 80 / val/metric=0.4818
Epoch 81 / trn/loss=0.7160
Epoch 81 / train/metric=0.5051


 40%|████      | 81/200 [5:52:51<8:38:16, 261.32s/it]

Epoch 81 / val/metric=0.4877
Epoch 82 / trn/loss=0.7156
Epoch 82 / train/metric=0.4965


 41%|████      | 82/200 [5:57:12<8:33:40, 261.19s/it]

Epoch 82 / val/metric=0.4811
Epoch 83 / trn/loss=0.7192
Epoch 83 / train/metric=0.5024


 42%|████▏     | 83/200 [6:01:33<8:29:11, 261.13s/it]

Epoch 83 / val/metric=0.4823
Epoch 84 / trn/loss=0.7311
Epoch 84 / train/metric=0.5093


 42%|████▏     | 84/200 [6:05:55<8:25:01, 261.22s/it]

Epoch 84 / val/metric=0.4768
Epoch 85 / trn/loss=0.7083
Epoch 85 / train/metric=0.4919


 42%|████▎     | 85/200 [6:10:16<8:20:50, 261.31s/it]

Epoch 85 / val/metric=0.5319
Epoch 86 / trn/loss=0.7154
Epoch 86 / train/metric=0.4970


 43%|████▎     | 86/200 [6:14:38<8:16:27, 261.30s/it]

Epoch 86 / val/metric=0.4807
Epoch 87 / trn/loss=0.7134
Epoch 87 / train/metric=0.4976


 44%|████▎     | 87/200 [6:18:59<8:12:20, 261.42s/it]

Epoch 87 / val/metric=0.4770
Epoch 88 / trn/loss=0.7167
Epoch 88 / train/metric=0.4968


 44%|████▍     | 88/200 [6:23:21<8:08:09, 261.51s/it]

Epoch 88 / val/metric=0.4809
Epoch 89 / trn/loss=0.7054
Epoch 89 / train/metric=0.4901
Epoch 89 / val/metric=0.4648
Best val_metric 0.46484606172825443 at epoch 89!


 44%|████▍     | 89/200 [6:27:42<8:03:45, 261.49s/it]

Epoch 90 / trn/loss=0.7094
Epoch 90 / train/metric=0.4934


 45%|████▌     | 90/200 [6:32:03<7:59:06, 261.34s/it]

Epoch 90 / val/metric=0.4936
Epoch 91 / trn/loss=0.7016
Epoch 91 / train/metric=0.4847


 46%|████▌     | 91/200 [6:36:24<7:54:27, 261.17s/it]

Epoch 91 / val/metric=0.4881
Epoch 92 / trn/loss=0.7103
Epoch 92 / train/metric=0.4938


 46%|████▌     | 92/200 [6:40:45<7:49:59, 261.11s/it]

Epoch 92 / val/metric=0.4856
Epoch 93 / trn/loss=0.7033
Epoch 93 / train/metric=0.4892


 46%|████▋     | 93/200 [6:45:06<7:45:31, 261.05s/it]

Epoch 93 / val/metric=0.4792
Epoch 94 / trn/loss=0.7152
Epoch 94 / train/metric=0.4984


 47%|████▋     | 94/200 [6:49:27<7:41:04, 260.99s/it]

Epoch 94 / val/metric=0.4699
Epoch 95 / trn/loss=0.7041
Epoch 95 / train/metric=0.4914


 48%|████▊     | 95/200 [6:53:48<7:36:44, 261.00s/it]

Epoch 95 / val/metric=0.4718
Epoch 96 / trn/loss=0.7026
Epoch 96 / train/metric=0.4884


 48%|████▊     | 96/200 [6:58:09<7:32:25, 261.02s/it]

Epoch 96 / val/metric=0.4784
Epoch 97 / trn/loss=0.6918
Epoch 97 / train/metric=0.4779


 48%|████▊     | 97/200 [7:02:30<7:27:52, 260.90s/it]

Epoch 97 / val/metric=0.4727
Epoch 98 / trn/loss=0.6981
Epoch 98 / train/metric=0.4839


 49%|████▉     | 98/200 [7:06:50<7:23:31, 260.89s/it]

Epoch 98 / val/metric=0.4708
Epoch 99 / trn/loss=0.6958
Epoch 99 / train/metric=0.4810
Epoch 99 / val/metric=0.4568
Best val_metric 0.45684252793353625 at epoch 99!


 50%|████▉     | 99/200 [7:11:12<7:19:40, 261.19s/it]

Epoch 100 / trn/loss=0.6984
Epoch 100 / train/metric=0.4849


 50%|█████     | 100/200 [7:15:33<7:15:06, 261.07s/it]

Epoch 100 / val/metric=0.4775
Epoch 101 / trn/loss=0.6995
Epoch 101 / train/metric=0.4852


 50%|█████     | 101/200 [7:19:54<7:10:36, 260.98s/it]

Epoch 101 / val/metric=0.4692
Epoch 102 / trn/loss=0.7029
Epoch 102 / train/metric=0.4896


 51%|█████     | 102/200 [7:24:15<7:06:28, 261.11s/it]

Epoch 102 / val/metric=0.4708
Epoch 103 / trn/loss=0.6936
Epoch 103 / train/metric=0.4799


 52%|█████▏    | 103/200 [7:28:36<7:02:03, 261.07s/it]

Epoch 103 / val/metric=0.4698
Epoch 104 / trn/loss=0.6950
Epoch 104 / train/metric=0.4806


 52%|█████▏    | 104/200 [7:32:57<6:57:46, 261.11s/it]

Epoch 104 / val/metric=0.4665
Epoch 105 / trn/loss=0.6887
Epoch 105 / train/metric=0.4728


 52%|█████▎    | 105/200 [7:37:19<6:53:36, 261.22s/it]

Epoch 105 / val/metric=0.4754
Epoch 106 / trn/loss=0.6909
Epoch 106 / train/metric=0.4787


 53%|█████▎    | 106/200 [7:41:40<6:49:21, 261.29s/it]

Epoch 106 / val/metric=0.4728
Epoch 107 / trn/loss=0.6919
Epoch 107 / train/metric=0.4773


 54%|█████▎    | 107/200 [7:46:02<6:45:00, 261.29s/it]

Epoch 107 / val/metric=0.4690
Epoch 108 / trn/loss=0.6837
Epoch 108 / train/metric=0.4691


 54%|█████▍    | 108/200 [7:50:23<6:40:37, 261.27s/it]

Epoch 108 / val/metric=0.4781
Epoch 109 / trn/loss=0.6836
Epoch 109 / train/metric=0.4706


 55%|█████▍    | 109/200 [7:54:44<6:36:12, 261.23s/it]

Epoch 109 / val/metric=0.4776
Epoch 110 / trn/loss=0.6816
Epoch 110 / train/metric=0.4709


 55%|█████▌    | 110/200 [7:59:05<6:31:53, 261.27s/it]

Epoch 110 / val/metric=0.4987
Epoch 111 / trn/loss=0.6795
Epoch 111 / train/metric=0.4682


 56%|█████▌    | 111/200 [8:03:27<6:27:33, 261.28s/it]

Epoch 111 / val/metric=0.4625
Epoch 112 / trn/loss=0.6764
Epoch 112 / train/metric=0.4670


 56%|█████▌    | 112/200 [8:07:48<6:23:13, 261.29s/it]

Epoch 112 / val/metric=0.4595
Epoch 113 / trn/loss=0.6688
Epoch 113 / train/metric=0.4602


 56%|█████▋    | 113/200 [8:12:10<6:18:57, 261.35s/it]

Epoch 113 / val/metric=0.4631
Epoch 114 / trn/loss=0.6768
Epoch 114 / train/metric=0.4659


 57%|█████▋    | 114/200 [8:16:32<6:14:52, 261.54s/it]

Epoch 114 / val/metric=0.4686
Epoch 115 / trn/loss=0.6777
Epoch 115 / train/metric=0.4660


 57%|█████▊    | 115/200 [8:20:53<6:10:21, 261.43s/it]

Epoch 115 / val/metric=0.4669
Epoch 116 / trn/loss=0.6853
Epoch 116 / train/metric=0.4765


 58%|█████▊    | 116/200 [8:25:14<6:05:52, 261.34s/it]

Epoch 116 / val/metric=0.4701
Epoch 117 / trn/loss=0.6746
Epoch 117 / train/metric=0.4651


 58%|█████▊    | 117/200 [8:29:36<6:01:41, 261.47s/it]

Epoch 117 / val/metric=0.4571
Epoch 118 / trn/loss=0.6768
Epoch 118 / train/metric=0.4649


 59%|█████▉    | 118/200 [8:33:57<5:57:22, 261.49s/it]

Epoch 118 / val/metric=0.4581
Epoch 119 / trn/loss=0.6760
Epoch 119 / train/metric=0.4661


 60%|█████▉    | 119/200 [8:38:19<5:53:06, 261.56s/it]

Epoch 119 / val/metric=0.4652
Epoch 120 / trn/loss=0.6762
Epoch 120 / train/metric=0.4670


 60%|██████    | 120/200 [8:42:40<5:48:35, 261.44s/it]

Epoch 120 / val/metric=0.4726
Epoch 121 / trn/loss=0.6746
Epoch 121 / train/metric=0.4649


 60%|██████    | 121/200 [8:47:01<5:44:11, 261.41s/it]

Epoch 121 / val/metric=0.4665
Epoch 122 / trn/loss=0.6640
Epoch 122 / train/metric=0.4591


 61%|██████    | 122/200 [8:51:23<5:39:56, 261.50s/it]

Epoch 122 / val/metric=0.4614
Epoch 123 / trn/loss=0.6668
Epoch 123 / train/metric=0.4591


 62%|██████▏   | 123/200 [8:55:44<5:35:28, 261.41s/it]

Epoch 123 / val/metric=0.5303
Epoch 124 / trn/loss=0.6698
Epoch 124 / train/metric=0.4586


 62%|██████▏   | 124/200 [9:00:05<5:30:58, 261.29s/it]

Epoch 124 / val/metric=0.4625
Epoch 125 / trn/loss=0.6634
Epoch 125 / train/metric=0.4551
Epoch 125 / val/metric=0.4555
Best val_metric 0.45545979802556125 at epoch 125!


 62%|██████▎   | 125/200 [9:04:27<5:26:55, 261.54s/it]

Epoch 126 / trn/loss=0.6711
Epoch 126 / train/metric=0.4604


 63%|██████▎   | 126/200 [9:08:48<5:22:19, 261.34s/it]

Epoch 126 / val/metric=0.4567
Epoch 127 / trn/loss=0.6557
Epoch 127 / train/metric=0.4487


 64%|██████▎   | 127/200 [9:13:09<5:17:49, 261.22s/it]

Epoch 127 / val/metric=0.4689
Epoch 128 / trn/loss=0.6543
Epoch 128 / train/metric=0.4459
Epoch 128 / val/metric=0.4537
Best val_metric 0.4536799986375812 at epoch 128!


 64%|██████▍   | 128/200 [9:17:31<5:13:39, 261.39s/it]

Epoch 129 / trn/loss=0.6541
Epoch 129 / train/metric=0.4498


 64%|██████▍   | 129/200 [9:21:52<5:09:15, 261.35s/it]

Epoch 129 / val/metric=0.4654
Epoch 130 / trn/loss=0.6573
Epoch 130 / train/metric=0.4527


 65%|██████▌   | 130/200 [9:26:14<5:04:56, 261.38s/it]

Epoch 130 / val/metric=0.4612
Epoch 131 / trn/loss=0.6612
Epoch 131 / train/metric=0.4524


 66%|██████▌   | 131/200 [9:30:35<5:00:34, 261.37s/it]

Epoch 131 / val/metric=0.4606
Epoch 132 / trn/loss=0.6516
Epoch 132 / train/metric=0.4469


 66%|██████▌   | 132/200 [9:34:56<4:56:10, 261.33s/it]

Epoch 132 / val/metric=0.4718
Epoch 133 / trn/loss=0.6459
Epoch 133 / train/metric=0.4436


 66%|██████▋   | 133/200 [9:39:17<4:51:44, 261.27s/it]

Epoch 133 / val/metric=0.4658
Epoch 134 / trn/loss=0.6580
Epoch 134 / train/metric=0.4531


 67%|██████▋   | 134/200 [9:43:39<4:47:19, 261.21s/it]

Epoch 134 / val/metric=0.4539
Epoch 135 / trn/loss=0.6489
Epoch 135 / train/metric=0.4426
Epoch 135 / val/metric=0.4477
Best val_metric 0.44772652692332365 at epoch 135!


 68%|██████▊   | 135/200 [9:48:00<4:43:03, 261.29s/it]

Epoch 136 / trn/loss=0.6428
Epoch 136 / train/metric=0.4386


 68%|██████▊   | 136/200 [9:52:21<4:38:29, 261.08s/it]

Epoch 136 / val/metric=0.4561
Epoch 137 / trn/loss=0.6434
Epoch 137 / train/metric=0.4405


 68%|██████▊   | 137/200 [9:56:42<4:34:10, 261.12s/it]

Epoch 137 / val/metric=0.4629
Epoch 138 / trn/loss=0.6483
Epoch 138 / train/metric=0.4424


 69%|██████▉   | 138/200 [10:01:02<4:29:39, 260.97s/it]

Epoch 138 / val/metric=0.4558
Epoch 139 / trn/loss=0.6358
Epoch 139 / train/metric=0.4329


 70%|██████▉   | 139/200 [10:05:24<4:25:26, 261.09s/it]

Epoch 139 / val/metric=0.4526
Epoch 140 / trn/loss=0.6437
Epoch 140 / train/metric=0.4385


 70%|███████   | 140/200 [10:09:45<4:21:02, 261.03s/it]

Epoch 140 / val/metric=0.4484
Epoch 141 / trn/loss=0.6307
Epoch 141 / train/metric=0.4283
Epoch 141 / val/metric=0.4467
Best val_metric 0.44666794313768615 at epoch 141!


 70%|███████   | 141/200 [10:14:06<4:16:53, 261.25s/it]

Epoch 142 / trn/loss=0.6348
Epoch 142 / train/metric=0.4340


 71%|███████   | 142/200 [10:18:28<4:12:32, 261.25s/it]

Epoch 142 / val/metric=0.4628
Epoch 143 / trn/loss=0.6324
Epoch 143 / train/metric=0.4283


 72%|███████▏  | 143/200 [10:22:49<4:08:08, 261.21s/it]

Epoch 143 / val/metric=0.4561
Epoch 144 / trn/loss=0.6344
Epoch 144 / train/metric=0.4308


 72%|███████▏  | 144/200 [10:27:10<4:03:39, 261.07s/it]

Epoch 144 / val/metric=0.4569
Epoch 145 / trn/loss=0.6330
Epoch 145 / train/metric=0.4278


 72%|███████▎  | 145/200 [10:31:31<3:59:24, 261.18s/it]

Epoch 145 / val/metric=0.4553
Epoch 146 / trn/loss=0.6271
Epoch 146 / train/metric=0.4255


 73%|███████▎  | 146/200 [10:35:52<3:55:07, 261.25s/it]

Epoch 146 / val/metric=0.4715
Epoch 147 / trn/loss=0.6233
Epoch 147 / train/metric=0.4188
Epoch 147 / val/metric=0.4452
Best val_metric 0.44522028818898757 at epoch 147!


 74%|███████▎  | 147/200 [10:40:14<3:50:50, 261.34s/it]

Epoch 148 / trn/loss=0.6280
Epoch 148 / train/metric=0.4246


 74%|███████▍  | 148/200 [10:44:35<3:46:21, 261.18s/it]

Epoch 148 / val/metric=0.4479
Epoch 149 / trn/loss=0.6224
Epoch 149 / train/metric=0.4197


 74%|███████▍  | 149/200 [10:48:56<3:42:01, 261.20s/it]

Epoch 149 / val/metric=0.4703
Epoch 150 / trn/loss=0.6302
Epoch 150 / train/metric=0.4307


 75%|███████▌  | 150/200 [10:53:17<3:37:39, 261.19s/it]

Epoch 150 / val/metric=0.4735
Epoch 151 / trn/loss=0.6195
Epoch 151 / train/metric=0.4187


 76%|███████▌  | 151/200 [10:57:38<3:33:17, 261.16s/it]

Epoch 151 / val/metric=0.4557
Epoch 152 / trn/loss=0.6215
Epoch 152 / train/metric=0.4209


 76%|███████▌  | 152/200 [11:02:00<3:29:00, 261.26s/it]

Epoch 152 / val/metric=0.4574
Epoch 153 / trn/loss=0.6218
Epoch 153 / train/metric=0.4214
Epoch 153 / val/metric=0.4408
Best val_metric 0.4407957692542428 at epoch 153!


 76%|███████▋  | 153/200 [11:06:21<3:24:43, 261.36s/it]

Epoch 154 / trn/loss=0.6152
Epoch 154 / train/metric=0.4161


 77%|███████▋  | 154/200 [11:10:42<3:20:15, 261.21s/it]

Epoch 154 / val/metric=0.4544
Epoch 155 / trn/loss=0.6131
Epoch 155 / train/metric=0.4133


 78%|███████▊  | 155/200 [11:15:03<3:15:50, 261.13s/it]

Epoch 155 / val/metric=0.4545
Epoch 156 / trn/loss=0.6115
Epoch 156 / train/metric=0.4139


 78%|███████▊  | 156/200 [11:19:25<3:11:33, 261.22s/it]

Epoch 156 / val/metric=0.4619
Epoch 157 / trn/loss=0.6208
Epoch 157 / train/metric=0.4194


 78%|███████▊  | 157/200 [11:23:46<3:07:13, 261.24s/it]

Epoch 157 / val/metric=0.4424
Epoch 158 / trn/loss=0.6076
Epoch 158 / train/metric=0.4094


 79%|███████▉  | 158/200 [11:28:07<3:02:52, 261.26s/it]

Epoch 158 / val/metric=0.4633
Epoch 159 / trn/loss=0.6117
Epoch 159 / train/metric=0.4111


 80%|███████▉  | 159/200 [11:32:29<2:58:36, 261.38s/it]

Epoch 159 / val/metric=0.4485
Epoch 160 / trn/loss=0.6115
Epoch 160 / train/metric=0.4113


 80%|████████  | 160/200 [11:36:50<2:54:10, 261.27s/it]

Epoch 160 / val/metric=0.4515
Epoch 161 / trn/loss=0.6118
Epoch 161 / train/metric=0.4153


 80%|████████  | 161/200 [11:41:11<2:49:44, 261.14s/it]

Epoch 161 / val/metric=0.4589
Epoch 162 / trn/loss=0.6123
Epoch 162 / train/metric=0.4125


 81%|████████  | 162/200 [11:45:32<2:45:27, 261.24s/it]

Epoch 162 / val/metric=0.4558
Epoch 163 / trn/loss=0.6048
Epoch 163 / train/metric=0.4064


 82%|████████▏ | 163/200 [11:49:54<2:41:07, 261.29s/it]

Epoch 163 / val/metric=0.4622
Epoch 164 / trn/loss=0.6028
Epoch 164 / train/metric=0.4037


 82%|████████▏ | 164/200 [11:54:15<2:36:46, 261.28s/it]

Epoch 164 / val/metric=0.4441
Epoch 165 / trn/loss=0.6060
Epoch 165 / train/metric=0.4066


 82%|████████▎ | 165/200 [11:58:36<2:32:23, 261.25s/it]

Epoch 165 / val/metric=0.4551
Epoch 166 / trn/loss=0.6011
Epoch 166 / train/metric=0.4042


 83%|████████▎ | 166/200 [12:02:57<2:28:03, 261.29s/it]

Epoch 166 / val/metric=0.4495
Epoch 167 / trn/loss=0.6080
Epoch 167 / train/metric=0.4111


 84%|████████▎ | 167/200 [12:07:19<2:23:42, 261.28s/it]

Epoch 167 / val/metric=0.4533
Epoch 168 / trn/loss=0.6017
Epoch 168 / train/metric=0.4039


 84%|████████▍ | 168/200 [12:11:40<2:19:19, 261.25s/it]

Epoch 168 / val/metric=0.4559
Epoch 169 / trn/loss=0.5994
Epoch 169 / train/metric=0.4035


 84%|████████▍ | 169/200 [12:16:01<2:14:57, 261.21s/it]

Epoch 169 / val/metric=0.4581
Epoch 170 / trn/loss=0.5945
Epoch 170 / train/metric=0.3982


 85%|████████▌ | 170/200 [12:20:22<2:10:35, 261.19s/it]

Epoch 170 / val/metric=0.4503
Epoch 171 / trn/loss=0.6001
Epoch 171 / train/metric=0.4029


 86%|████████▌ | 171/200 [12:24:43<2:06:13, 261.14s/it]

Epoch 171 / val/metric=0.4433
Epoch 172 / trn/loss=0.5926
Epoch 172 / train/metric=0.3991


 86%|████████▌ | 172/200 [12:29:04<2:01:49, 261.07s/it]

Epoch 172 / val/metric=0.4586
Epoch 173 / trn/loss=0.5942
Epoch 173 / train/metric=0.4000


 86%|████████▋ | 173/200 [12:33:25<1:57:32, 261.19s/it]

Epoch 173 / val/metric=0.4591
Epoch 174 / trn/loss=0.5989
Epoch 174 / train/metric=0.4036


 87%|████████▋ | 174/200 [12:37:46<1:53:07, 261.06s/it]

Epoch 174 / val/metric=0.4517
Epoch 175 / trn/loss=0.5923
Epoch 175 / train/metric=0.3952


 88%|████████▊ | 175/200 [12:42:07<1:48:44, 260.99s/it]

Epoch 175 / val/metric=0.4530
Epoch 176 / trn/loss=0.5974
Epoch 176 / train/metric=0.3997


 88%|████████▊ | 176/200 [12:46:29<1:44:28, 261.17s/it]

Epoch 176 / val/metric=0.4487
Epoch 177 / trn/loss=0.5855
Epoch 177 / train/metric=0.3893


 88%|████████▊ | 177/200 [12:50:50<1:40:08, 261.24s/it]

Epoch 177 / val/metric=0.4538
Epoch 178 / trn/loss=0.5921
Epoch 178 / train/metric=0.3975


 89%|████████▉ | 178/200 [12:55:11<1:35:47, 261.24s/it]

Epoch 178 / val/metric=0.4612
Epoch 179 / trn/loss=0.5953
Epoch 179 / train/metric=0.3978


 90%|████████▉ | 179/200 [12:59:33<1:31:28, 261.35s/it]

Epoch 179 / val/metric=0.4594
Epoch 180 / trn/loss=0.5917
Epoch 180 / train/metric=0.3958


 90%|█████████ | 180/200 [13:03:54<1:27:06, 261.34s/it]

Epoch 180 / val/metric=0.4505
Epoch 181 / trn/loss=0.5917
Epoch 181 / train/metric=0.3964


 90%|█████████ | 181/200 [13:08:15<1:22:44, 261.29s/it]

Epoch 181 / val/metric=0.4675
Epoch 182 / trn/loss=0.5861
Epoch 182 / train/metric=0.3925


 91%|█████████ | 182/200 [13:12:36<1:18:21, 261.20s/it]

Epoch 182 / val/metric=0.4647
Epoch 183 / trn/loss=0.5933
Epoch 183 / train/metric=0.3979


 92%|█████████▏| 183/200 [13:16:57<1:13:59, 261.16s/it]

Epoch 183 / val/metric=0.4582
Epoch 184 / trn/loss=0.5850
Epoch 184 / train/metric=0.3899


 92%|█████████▏| 184/200 [13:21:19<1:09:39, 261.21s/it]

Epoch 184 / val/metric=0.4666
Epoch 185 / trn/loss=0.5889
Epoch 185 / train/metric=0.3944


 92%|█████████▎| 185/200 [13:25:40<1:05:18, 261.25s/it]

Epoch 185 / val/metric=0.4704
Epoch 186 / trn/loss=0.5882
Epoch 186 / train/metric=0.3938


 93%|█████████▎| 186/200 [13:30:01<1:00:57, 261.22s/it]

Epoch 186 / val/metric=0.4682
Epoch 187 / trn/loss=0.5843
Epoch 187 / train/metric=0.3902


 94%|█████████▎| 187/200 [13:34:22<56:35, 261.21s/it]  

Epoch 187 / val/metric=0.4561
Epoch 188 / trn/loss=0.5828
Epoch 188 / train/metric=0.3891


 94%|█████████▍| 188/200 [13:38:44<52:14, 261.20s/it]

Epoch 188 / val/metric=0.4530
Epoch 189 / trn/loss=0.5839
Epoch 189 / train/metric=0.3899


 94%|█████████▍| 189/200 [13:43:05<47:52, 261.17s/it]

Epoch 189 / val/metric=0.4559
Epoch 190 / trn/loss=0.5951
Epoch 190 / train/metric=0.3999


 95%|█████████▌| 190/200 [13:47:26<43:31, 261.16s/it]

Epoch 190 / val/metric=0.4529
Epoch 191 / trn/loss=0.5885
Epoch 191 / train/metric=0.3943


 96%|█████████▌| 191/200 [13:51:47<39:11, 261.29s/it]

Epoch 191 / val/metric=0.4637
Epoch 192 / trn/loss=0.5850
Epoch 192 / train/metric=0.3911


 96%|█████████▌| 192/200 [13:56:09<34:50, 261.34s/it]

Epoch 192 / val/metric=0.4591
Epoch 193 / trn/loss=0.5878
Epoch 193 / train/metric=0.3935


 96%|█████████▋| 193/200 [14:00:30<30:28, 261.26s/it]

Epoch 193 / val/metric=0.4547
Epoch 194 / trn/loss=0.5885
Epoch 194 / train/metric=0.3941


 97%|█████████▋| 194/200 [14:04:51<26:07, 261.32s/it]

Epoch 194 / val/metric=0.4547
Epoch 195 / trn/loss=0.5899
Epoch 195 / train/metric=0.3947


 98%|█████████▊| 195/200 [14:09:12<21:45, 261.06s/it]

Epoch 195 / val/metric=0.4535
Epoch 196 / trn/loss=0.5825
Epoch 196 / train/metric=0.3886


 98%|█████████▊| 196/200 [14:13:33<17:24, 261.10s/it]

Epoch 196 / val/metric=0.4542
Epoch 197 / trn/loss=0.5800
Epoch 197 / train/metric=0.3862


 98%|█████████▊| 197/200 [14:17:54<13:03, 261.08s/it]

Epoch 197 / val/metric=0.4575
Epoch 198 / trn/loss=0.5896
Epoch 198 / train/metric=0.3949


 99%|█████████▉| 198/200 [14:22:15<08:42, 261.08s/it]

Epoch 198 / val/metric=0.4489
Epoch 199 / trn/loss=0.5876
Epoch 199 / train/metric=0.3927


100%|█████████▉| 199/200 [14:26:36<04:21, 261.13s/it]

Epoch 199 / val/metric=0.4564
Epoch 200 / trn/loss=0.5815
Epoch 200 / train/metric=0.3876


                                                     

Epoch 200 / val/metric=0.4562






0,1
best_total_log_loss,▁
lr,▁▁▂▃▃▄▅▆▆▇███████▇▇▇▇▆▆▆▅▅▄▄▄▃▃▂▂▂▂▁▁▁▁▁
train_any_in_metric,▅██▆▅▄▅▄▄▅▄▅▅▄▄▄▄▄▄▄▄▃▃▃▃▂▃▂▂▂▂▂▁▂▂▂▁▂▂▁
train_any_loss,▄▁▄▆▆▃▃▂▅▃▆▇▄▇▃▄▄▃▄▄▃▄▆▃▂▃▄▁▅▂▆▃█▃▃▁▅▃▂▄
train_avg_loss,▆▂█▅▆▆▄▃▆▃▅▆▄▆▃▆▃▃▅▃▄▄▇▃▄▄▆▁▄▂▆▃▇▃▃▁▅▄▃▅
train_avg_metric,█▆▅▅▅▄▄▄▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
train_bowel_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁
train_bowel_metric,█▂▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁▁▁▁▁▁▁▁▁▁▁
train_extrav_loss,█▁█▅▅▄▂▂▄▂▁▅▅▂▂▆▁▂▅▂▆▂▂▂▂█▃▁▃▁▂▂▁▂▄▁▂▇▃▂
train_extrav_metric,▇▆▆▇█▇▆▆▆▆▆▅▆▅▅▅▅▅▅▄▄▄▄▄▄▄▄▃▃▃▂▂▂▂▂▁▁▁▁▁

0,1
best_total_log_loss,0.4408
lr,0.0
train_any_in_metric,0.50977
train_any_loss,1.62094
train_avg_loss,0.7384
train_avg_metric,0.38755
train_bowel_loss,0.95638
train_bowel_metric,0.12835
train_extrav_loss,0.39302
train_extrav_metric,0.47196


In [16]:
#Execute this cell to fininsh the wandb run when you stopped training.
import wandb
try: 
    wandb.log({'best_total_log_loss': best_metric})
    wandb.finish()
    
except: 
    print('Wandb is already finished!')

Wandb is already finished!


In [17]:
ind = 23561
train_meta_df[train_meta_df['patient_id']==ind]

Unnamed: 0,patient_id,series,bowel_healthy,bowel_injury,extravasation_healthy,extravasation_injury,kidney_healthy,kidney_low,kidney_high,liver_healthy,liver_low,liver_high,spleen_healthy,spleen_low,spleen_high,any_injury,fold,path,mask_path,cropped_path
3540,23561,19317,1,0,1,0,1,0,0,1,0,0,1,0,0,0,4,/home/junseonglee/Desktop/01_codes/inputs/rsna...,/home/junseonglee/Desktop/01_codes/inputs/rsna...,/home/junseonglee/Desktop/01_codes/inputs/rsna...


In [18]:
valid_dataset = AbdominalCTDataset(train_meta_df[train_meta_df['patient_id']==ind], is_train = False, transform_set = None,
                                    remain_transforms_set = None)

valid_loader = DataLoader(dataset = valid_dataset, shuffle = False, batch_size = BATCH_SIZE, pin_memory = False, 
                        num_workers = N_WORKERS, drop_last = False)     

100%|██████████| 1/1 [00:00<00:00,  4.38it/s]


In [19]:
X_outs=[]
ys=[]
model.eval()
model.load_state_dict(torch.load(f'{BASE_PATH}/weights/231003_resnet10t_dicomO_std_CV0.485.pt'))
for X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, y in valid_loader:
    batch_size = y.shape[0]
    X_bowel, X_lkid, X_rkid = X_bowel.to(DEVICE), X_lkid.to(DEVICE), X_rkid.to(DEVICE)
    X_liv, X_spl, X_tot     = X_liv.to(DEVICE), X_spl.to(DEVICE), X_tot.to(DEVICE)
    y = y.to(DEVICE)           
    with torch.cuda.amp.autocast(enabled=True):                
        with torch.no_grad():                 
            X_out, X_any = model(X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot)                                          
            y_any = torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1)              
            X_out = apply_softmax_to_labels(X_out).to('cpu').numpy()

            X_any = X_any.to('cpu').numpy()
            X_out = np.hstack([X_out, X_any])
            X_outs.append(X_out)

            y     = y.to('cpu').numpy()[:,:-1]
            y_any = y_any.to('cpu').numpy()
            y     = np.hstack([y, y_any])
            ys.append(y)

X_outs = np.vstack(X_outs) 
ys     = np.vstack(ys)
#metric = calculate_score(X_outs, ys, 'valid')                      

del X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, X_any
gc.collect()        
torch.cuda.empty_cache()   

FileNotFoundError: [Errno 2] No such file or directory: '/home/junseonglee/Desktop/01_codes/inputs/rsna-2023-abdominal-trauma-detection/weights/231003_resnet10t_dicomO_std_CV0.485.pt'

In [None]:
np.average(X_outs, axis = 0)


In [None]:
len(X_outs)