In [1]:
import time
from tqdm import tqdm
import sys
import glob
import gc
import os
sys.path.append('./lib_models')
#sys.path.append('')

import pandas as pd
import numpy as np
import scipy as sp
import cv2
from PIL import Image
from matplotlib import pyplot as plt
import sklearn.metrics
import warnings
import pydicom
import dicomsdl
from joblib import Parallel, delayed
import pickle
import gzip
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from multiprocessing import Pool
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch import nn
from torchvision.io import read_image
import segmentation_models_pytorch as smp
import timm
from timm.utils import AverageMeter
from timm.models import resnet
import timm_new

from monai.transforms import Resize
import  monai.transforms as transforms

from timm.models.layers.conv2d_same import Conv2dSame
from conv3d_same import Conv3dSame


import wandb

wandb.login(key = '585f58f321685308f7933861d9dde7488de0970b')

  from .autonotebook import tqdm as notebook_tqdm
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjunseonglee[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/junseonglee/.netrc


True

# Parameters

In [2]:
backbone = 'timm/resnet10t.c3_in1k'

IS_WANDB = True
PROJECT_NAME = 'RSNA_ABTD'
GROUP_NAME= 'preprocessing_test'
RUN_NAME=   f'{backbone}_2nd_std_opt-seg0.938'

if not IS_WANDB:
    PROJECT_NAME = 'Dummy_Project'

BASE_PATH  = '/home/junseonglee/Desktop/01_codes/inputs/rsna-2023-abdominal-trauma-detection'
TRAIN_PATH = f'{BASE_PATH}/train_images'
DATA_PATH = f'{BASE_PATH}/3d_preprocessed'

seg_inference_dir = f'{BASE_PATH}/seg_infer_results'
cropped_img_dir   = f'{BASE_PATH}/3d_preprocessed_crop_ratio'

if not os.path.isdir(DATA_PATH):
    os.mkdir(DATA_PATH)

RESOL = 128
UP_RESOL = 128
N_CHANNELS = 6
BATCH_SIZE = 8
ACCUM_STEPS = 3
N_WORKERS  = 8
LR = 0.005
N_EPOCHS = 200
EARLY_STOP_COUNT = 30
N_FOLDS  = 5
N_PREPROCESS_CHUNKS = 12
PCT_START = 0.3
n_blocks = 4
drop_rate = 0.2
drop_path_rate = 0.2
p_mixup = 0.0



DROP_REGION= {'HOLES': [3, 20],
                'SIZE': [5, 20],
                'PROB': 0.5,
                'FILL': (-3, 3)}

wandb_config = {
    'RESOL': RESOL,
    'BACKBONE': backbone,
    'N_CHANNELS': N_CHANNELS,
    'N_EPOCHS': N_EPOCHS,
    'N_FOLDS': N_FOLDS,
    'EARLY_STOP_COUNT': EARLY_STOP_COUNT,
    'BATCH_SIZE': BATCH_SIZE,    
    'LR': LR,
    'N_EPOCHS': N_EPOCHS,
    'DROP_RATE': drop_rate,
    'DROP_PATH_RATE': drop_path_rate,
    'MIXUP_RATE': p_mixup,
    'DROP_REGION': DROP_REGION,
    'PCT_START': PCT_START
}

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

In [3]:
# Mask related parameters
# Order 0: Bowel, 1: left kidney, 2: right kidney, 3: liver, 4: spleen

chan_keys = ['bowel', 'left_kidney', 'right_kidney', 'liver', 'spleen', 'total']
chan_dict = {}
for i in range(0, 6):
    chan_dict[i] = chan_keys[i]

train_meta_df = pd.read_csv(f'{BASE_PATH}/train_meta.csv')
np.unique(train_meta_df['fold'].to_numpy(), return_counts = True)

(array([0, 1, 2, 3, 4]), array([929, 947, 948, 951, 936]))

In [4]:
def compress(name, data):
    with gzip.open(name, 'wb') as f:
        pickle.dump(data, f)

def decompress(name):
    with gzip.open(name, 'rb') as f:
        data = pickle.load(f)
    return data


def compress_fast(name, data):
    with open(name, 'wb') as f:
        pickle.dump(data, f)

def decompress_fast(name):
    with open(name, 'rb') as f:
        data = pickle.load(f)
    return data

# Model

In [5]:
def convert_3d(module):

    module_output = module
    if isinstance(module, torch.nn.BatchNorm2d):
        module_output = torch.nn.BatchNorm3d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig
            
    elif isinstance(module, Conv2dSame):
        module_output = Conv3dSame(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))

    elif isinstance(module, torch.nn.Conv2d):
        module_output = torch.nn.Conv3d(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
            padding_mode=module.padding_mode
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))

    elif isinstance(module, torch.nn.MaxPool2d):
        module_output = torch.nn.MaxPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            ceil_mode=module.ceil_mode,
        )
    elif isinstance(module, torch.nn.AvgPool2d):
        module_output = torch.nn.AvgPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            ceil_mode=module.ceil_mode,
        )

    for name, child in module.named_children():
        module_output.add_module(
            name, convert_3d(child)
        )
    del module

    return module_output

In [6]:
class Timm3DModel(nn.Module):
    def __init__(self, backbone, n_channels, n_labels, segtype='unet', pretrained=False):
        super(Timm3DModel, self).__init__()
        self.n_labels = n_labels
        self.encoder = timm_new.create_model(
            backbone,
            in_chans=n_channels,
            features_only=True,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            pretrained=pretrained
        )
        g = self.encoder(torch.rand(1, n_channels, 64, 64))
        encoder_channels = [1] + [_.shape[1] for _ in g]
        decoder_channels = [256, 128, 64, 32, 16]

        self.avgpool = nn.AvgPool2d(5, 4, 2)
        
        [_.shape[1] for _ in g]
        self.convs1x1 = nn.ModuleList()    
        self.batchnorms = nn.ModuleList()    
        self.batchnorms13 = nn.ModuleList()
        for i in range(0, len(g)):
            self.convs1x1.append(nn.Conv2d(g[i].shape[1], self.n_labels, 1))
        del g
        gc.collect()
        
    def forward(self,x):
        batch_size = x.shape[0]
        global_features = self.encoder(x)[:n_blocks]        
        for i in range(0, len(global_features)):
            global_features[i] = self.convs1x1[i](global_features[i])
        return global_features
    
    
class Timm3DModelClassifier(nn.Module):
    def __init__(self, backbone, n_channels, n_labels, segtype='unet', pretrained=False):
        super(Timm3DModelClassifier, self).__init__()
        self.model_3d = Timm3DModel(backbone, n_channels, n_labels, segtype, pretrained)
        self.model_3d = convert_3d(self.model_3d)
        self.n_channels = n_channels
        self.n_labels = n_labels                        
        
    def forward(self, x):
        batch_size = x.shape[0]
        x = self.model_3d(x)
        pooled_features = []
        for i in range(0, len(x)):
            pooled_features.append(torch.reshape(torch.mean(x[i], dim = (2, 3, 4)), (batch_size, self.n_labels, 1)))
        pooled_features = torch.cat(pooled_features, dim=2)     
        labels = torch.mean(pooled_features, dim = 2)
        return labels

In [7]:
class AbdominalClassifier(nn.Module):
    def __init__(self, device = DEVICE):
        super().__init__()
        self.device = device
        
        self.model3d_bowel        = Timm3DModelClassifier(backbone, 1, 2)      
        self.model3d_extrav       = Timm3DModelClassifier(backbone, 1, 2)
        self.model3d_kidney_left  = Timm3DModelClassifier(backbone, 1, 3)
        self.model3d_kidney_right = Timm3DModelClassifier(backbone, 1, 3)
        self.model3d_liver        = Timm3DModelClassifier(backbone, 1, 3)
        self.model3d_spleen       = Timm3DModelClassifier(backbone, 1, 3)
        
        self.flatten  = nn.Flatten()
        self.dropout  = nn.Dropout(p=0.5)
        self.softmax  = nn.Softmax(dim=1)        
        self.maxpool  = nn.MaxPool1d(5, 1)
        
    def forward(self, x_bowel, x_kidney_left, x_kidney_right, x_liver, x_spleen, x_total):
        bowel_label        = self.model3d_bowel(x_bowel)
        extrav_label       = self.model3d_extrav(x_total)
        kidney_label_left  = self.model3d_kidney_left(x_kidney_left)
        kidney_label_right = self.model3d_kidney_right(x_kidney_right)
        kidney_label       = (kidney_label_left + kidney_label_right)/2
        liver_label        = self.model3d_liver(x_liver)
        spleen_label       = self.model3d_spleen(x_spleen)
        
        
        labels = torch.cat([bowel_label, extrav_label, kidney_label, liver_label, spleen_label], dim = 1)
        
        bowel_soft = self.softmax(bowel_label)
        extrav_soft = self.softmax(extrav_label)
        kidney_soft = self.softmax(kidney_label)
        liver_soft = self.softmax(liver_label)
        spleen_soft = self.softmax(spleen_label)

        any_in = torch.cat([1-bowel_soft[:,0:1], 1-extrav_soft[:,0:1], 
                            1-kidney_soft[:,0:1], 1-liver_soft[:,0:1], 1-spleen_soft[:,0:1]], dim = 1) 
        any_in = self.maxpool(any_in)
        any_not_in = 1-any_in
        any_in = torch.cat([any_not_in, any_in], dim = 1)

        any_in = torch.log(any_in + 1e-6)  # 1e-6은 0을 처리하기 위한 작은 값
        return labels, any_in
    

In [8]:
model = AbdominalClassifier()

def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp

print(get_n_params(model))
del model
gc.collect()

86478624


0

# Metric & Loss

In [9]:
weights = np.ones(2)
weights[1] = 2
crit_bowel  = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
weights[1] = 6
crit_extrav = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
crit_any = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))

weights = np.ones((3))
weights[1] = 2
weights[2] = 4
crit_kidney = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
crit_liver  = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
crit_spleen = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))

In [10]:
def normalize_to_one(tensor):
    norm = torch.sum(tensor, 1)
    for i in range(0, tensor.shape[1]):
        tensor[:,i]/=norm
    return tensor

def apply_softmax_to_labels(X_out):
    softmax = nn.Softmax(dim=1)

    X_out[:,:2]    = normalize_to_one(softmax(X_out[:,:2]))
    X_out[:,2:4]   = normalize_to_one(softmax(X_out[:,2:4]))
    X_out[:,4:7]   = normalize_to_one(softmax(X_out[:,4:7]))
    X_out[:,7:10]  = normalize_to_one(softmax(X_out[:,7:10]))
    X_out[:,10:13] = normalize_to_one(softmax(X_out[:,10:13]))

    return X_out

def calculate_score(X_outs, ys, step = 'train'):
    X_outs = X_outs.astype(np.float64)
    ys     = ys.astype(np.float64)

    isnan_x = np.isnan(X_outs).astype(int)
    isnan_y = np.isnan(ys).astype(int)
    
    if(np.max(isnan_x)>0):
        print('xnan')
    if(np.max(isnan_y)>0):
        print('ynan')
        
    X_outs[:, 13:15] = nn.Softmax(dim=1)(torch.from_numpy(X_outs[:, 13:15])).numpy()
    bowel_weights  =  ys[:,0] + 2*ys[:,1]
    extrav_weights = ys[:,2] + 6*ys[:,3]
    kidney_weights = ys[:,4] + 2*ys[:,5] + 4*ys[:,6]
    liver_weights  = ys[:,7] + 2*ys[:,8] + 4*ys[:,9]
    spleen_weights = ys[:,10] + 2*ys[:,11] + 4*ys[:,12]
    any_in_weights = ys[:,13] + 6*ys[:,14]

    bowel_loss  = sklearn.metrics.log_loss(ys[:,:2], X_outs[:,:2], sample_weight = bowel_weights.astype(np.float64))
    extrav_loss = sklearn.metrics.log_loss(ys[:,2:4], X_outs[:,2:4], sample_weight = extrav_weights.astype(np.float64))
    kidney_loss = sklearn.metrics.log_loss(ys[:,4:7], X_outs[:,4:7], sample_weight = kidney_weights.astype(np.float64))
    liver_loss  = sklearn.metrics.log_loss(ys[:,7:10], X_outs[:,7:10], sample_weight = liver_weights.astype(np.float64))
    spleen_loss = sklearn.metrics.log_loss(ys[:,10:13], X_outs[:,10:13], sample_weight = spleen_weights.astype(np.float64))
    any_in_loss = sklearn.metrics.log_loss(ys[:,13:15], X_outs[:,13:15], sample_weight =  any_in_weights.astype(np.float64))
    
    avg_loss = (bowel_loss + extrav_loss + kidney_loss + liver_loss + spleen_loss + any_in_loss)/6

    losses= {f'{step}_bowel_metric': bowel_loss, f'{step}_extrav_metric': extrav_loss, f'{step}_kidney_metric': kidney_loss,
             f'{step}_liver_metric': liver_loss, f'{step}_spleen_metric': spleen_loss, f'{step}_any_in_metric': any_in_loss,
             f'{step}_avg_metric': avg_loss}

    wandb.log(losses)
    return avg_loss

def calculate_loss(X_out, X_any, y):
    batch_size = X_out.shape[0]
    bowel_loss  = crit_bowel(X_out[:,:2], y[:,:2])
    extrav_loss = crit_extrav(X_out[:,2:4], y[:,2:4])
    kidney_loss = crit_kidney(X_out[:,4:7], y[:,4:7])
    liver_loss  = crit_liver(X_out[:,7:10], y[:,7:10])
    spleen_loss = crit_spleen(X_out[:,10:13], y[:,10:13])
    any_in_loss = crit_any(X_any,  torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1))
    
    avg_loss = (bowel_loss + extrav_loss + kidney_loss + liver_loss + spleen_loss + any_in_loss)/6
    return bowel_loss, extrav_loss, kidney_loss, liver_loss, spleen_loss, any_in_loss, avg_loss

# Augmentations

In [11]:
def mixup(inputs, truth, clip=[0, 1]):
    indices = torch.randperm(inputs.size(0))
    shuffled_input = inputs[indices]
    shuffled_labels = truth[indices]

    lam = np.random.uniform(clip[0], clip[1])
    inputs = inputs * lam + shuffled_input * (1 - lam)
    return inputs, truth, shuffled_labels, lam

transforms_train = transforms.Compose([
    transforms.RandFlipd(keys=chan_keys, prob=0.5, spatial_axis=0),    
    transforms.RandFlipd(keys=chan_keys, prob=0.5, spatial_axis=1),
    transforms.RandFlipd(keys=chan_keys, prob=0.5, spatial_axis=2),
    #transforms.RandAffined(keys=chan_keys, translate_range=[int(x*y) for x, y in zip([RESOL, RESOL, RESOL], [0.3, 0.3, 0.3])], padding_mode='zeros', prob=0.7),
    transforms.RandGridDistortiond(keys=chan_keys, prob=0.5, distort_limit=(-0.01, 0.01), mode="nearest"),    
])

remain_transforms_train = transforms.Compose([
    #transforms.RandCoarseDropout(holes = DROP_REGION['HOLES'][0], max_holes = DROP_REGION['HOLES'][1],
    #                        spatial_size = DROP_REGION['SIZE'][0]*np.ones(3, int), max_spatial_size =DROP_REGION['SIZE'][1]*np.ones(3, int), 
    #                        prob = DROP_REGION['PROB'], 
    #                        fill_value = DROP_REGION['FILL'])
])



transforms_common_preprocessing = transforms.Compose([
    #transforms.HistogramNormalize(num_bins = 256, min = 0, max = 255)
])

# Dataset

In [12]:
class AbdominalCTDataset(Dataset):
    def __init__(self, meta_df, is_train = True, transform_set = None, remain_transforms_set = None):
        self.meta_df = meta_df
        self.is_train = is_train
        self.transform_set = transform_set
        self.remain_transforms_set = remain_transforms_set
        self.data_3ds = []        
        for i in tqdm(range(0, len(self.meta_df))):
            tmp_data_3ds = {}
            base_name = self.meta_df.iloc[i]['cropped_path']            
            for j in range(0, 6):
                tmp_data_3d = decompress_fast(f'{base_name}_{j}').unsqueeze(0)
                #tmp_data_3d = torch.from_numpy(tmp_data_3d)
                tmp_data_3ds[chan_dict[j]] = tmp_data_3d            
            self.data_3ds.append(tmp_data_3ds)

    def __len__(self):
        return len(self.meta_df)
    
    def __getitem__(self, idx):
        row = self.meta_df.iloc[idx]
        label = row[['bowel_healthy','bowel_injury',
                    'extravasation_healthy','extravasation_injury',
                    'kidney_healthy','kidney_low','kidney_high',
                    'liver_healthy','liver_low','liver_high',
                    'spleen_healthy','spleen_low','spleen_high', 'any_injury']]
        
        data_3d = self.data_3ds[idx].copy()
        
        if self.is_train:
            if self.transform_set is not None:
                data_3d = self.transform_set(data_3d)

            if self.remain_transforms_set is not None:   
                for i in range(0, 6):
                    data_3d[chan_dict[i]] = self.remain_transforms_set(data_3d[chan_dict[i]])
        
        label = label.to_numpy().astype(np.float32)                    
        label = torch.from_numpy(label)
                    
        return data_3d['bowel'], data_3d['left_kidney'], data_3d['right_kidney'], \
                data_3d['liver'], data_3d['spleen'], data_3d['total'], label        


In [13]:
#data_3d= torch.rand((6, 128, 128, 128))*0.5
#data_3d = remain_transforms_train(data_3d)
#torch.max(data_3d)
#print(data_3d)

# Train loop

In [14]:
def train_func(model, train_loader, scaler, scheduler, optimizer, epoch, accum_points, accum_scale):
    train_meters = {'loss': AverageMeter()}
    model.train()
    X_outs=[]
    ys=[]
    accum_counter = 0
    counter = 0
    for X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, y in train_loader:
        X_bowel, X_lkid, X_rkid = X_bowel.to(DEVICE), X_lkid.to(DEVICE), X_rkid.to(DEVICE)
        X_liv, X_spl, X_tot     = X_liv.to(DEVICE), X_spl.to(DEVICE), X_tot.to(DEVICE)
        y = y.to(DEVICE)
        current_lr = float(scheduler.get_last_lr()[0])
        
        batch_size = X_bowel.shape[0]
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=True):  
            X_out, X_any  = model(X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot)
            bowel_loss, extrav_loss, kidney_loss, liver_loss, spleen_loss, any_in_loss, avg_loss = calculate_loss(X_out, X_any, y)
                
            step = 'train'
            wandb.log({ 'lr': current_lr,
                        f'{step}_bowel_loss': bowel_loss.item(),
                        f'{step}_extrav_loss': extrav_loss.item(),
                        f'{step}_kidney_loss': kidney_loss.item(),
                        f'{step}_liver_loss': liver_loss.item(),
                        f'{step}_spleen_loss': spleen_loss.item(),
                        f'{step}_any_loss': any_in_loss.item(),
                        f'{step}_avg_loss': avg_loss.item()
                        })
            
            scaler.scale(avg_loss/accum_scale[accum_counter]).backward()
            if(counter==accum_points[accum_counter]):
                scaler.step(optimizer)
                scheduler.step()
                scaler.update()    
                accum_counter+=1                
        counter+=1                   

        #Metric calculation
        y_any = torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1)    
        X_out = apply_softmax_to_labels(X_out).detach().to('cpu').numpy()
        X_any = X_any.detach().to('cpu').numpy()
        X_out = np.hstack([X_out, X_any])
        X_outs.append(X_out)

        y     = y.to('cpu').numpy()[:,:-1]
        y_any = y_any.to('cpu').numpy()
        y     = np.hstack([y, y_any])
        ys.append(y)

        trn_loss = avg_loss.item()      
        train_meters['loss'].update(trn_loss, n=batch_size)     
        #pbar.set_description(f'Train loss: {trn_loss}')   
        
        
    print('Epoch {:d} / trn/loss={:.4f}'.format(epoch+1, train_meters['loss'].avg))    

    X_outs = np.vstack(X_outs) 
    ys     = np.vstack(ys)
    metric = calculate_score(X_outs, ys, 'train')                 
    print('Epoch {:d} / train/metric={:.4f}'.format(epoch+1, metric))   

    del X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, X_outs, y, ys, X_any
    gc.collect()
    torch.cuda.empty_cache()    
    return scheduler, scaler, optimizer


def valid_func(model, valid_loader, epoch):
    X_outs=[]
    ys=[]
    model.eval()
    for X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, y in valid_loader:
        batch_size = y.shape[0]
        X_bowel, X_lkid, X_rkid = X_bowel.to(DEVICE), X_lkid.to(DEVICE), X_rkid.to(DEVICE)
        X_liv, X_spl, X_tot     = X_liv.to(DEVICE), X_spl.to(DEVICE), X_tot.to(DEVICE)
        y = y.to(DEVICE)           
        with torch.cuda.amp.autocast(enabled=True):                
            with torch.no_grad():                 
                X_out, X_any = model(X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot)                                          
                y_any = torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1)              
                X_out = apply_softmax_to_labels(X_out).to('cpu').numpy()

                X_any = X_any.to('cpu').numpy()
                X_out = np.hstack([X_out, X_any])
                X_outs.append(X_out)

                y     = y.to('cpu').numpy()[:,:-1]
                y_any = y_any.to('cpu').numpy()
                y     = np.hstack([y, y_any])
                ys.append(y)

    X_outs = np.vstack(X_outs) 
    ys     = np.vstack(ys)
    metric = calculate_score(X_outs, ys, 'valid')                
    print('Epoch {:d} / val/metric={:.4f}'.format(epoch+1, metric))           
    
    del X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, X_outs, y, ys, X_any
    gc.collect()        
    torch.cuda.empty_cache()   
    return metric 

In [15]:
model = AbdominalClassifier()
model.to(DEVICE)

wandb.init(
    config = wandb_config,
    project= PROJECT_NAME,
    group  = GROUP_NAME,
    name   = RUN_NAME,
    dir    = BASE_PATH)

backbone = backbone.replace('/', '_')

if __name__ == '__main__':
    train_dataset = AbdominalCTDataset(train_meta_df[train_meta_df['fold']!=0], is_train = True, transform_set  = transforms_train, 
                                        remain_transforms_set = remain_transforms_train)
    valid_dataset = AbdominalCTDataset(train_meta_df[train_meta_df['fold']==0], is_train = False, transform_set = None,
                                        remain_transforms_set = None)
    
    train_loader = DataLoader(dataset = train_dataset, shuffle = True, batch_size = BATCH_SIZE, pin_memory = False, 
                            num_workers = N_WORKERS, drop_last = False)

    valid_loader = DataLoader(dataset = valid_dataset, shuffle = False, batch_size = BATCH_SIZE, pin_memory = False, 
                            num_workers = N_WORKERS, drop_last = False)     
    
    ttl_iters = N_EPOCHS * len(train_loader)
    
    #gradient accumulation for stability of the training
    accum_len = int(np.ceil(len(train_loader)/ACCUM_STEPS)+0.001)
    accum_points = np.zeros(accum_len, int)
    accum_scale  = np.zeros(accum_len, int)
    
    prev_step = -1
    for i in range(0, accum_len):
        accum_points[i] = min(prev_step+ACCUM_STEPS, len(train_loader)-1)
        accum_scale[i]  = accum_points[i] - prev_step
        prev_step = accum_points[i]

    #Scheduler & optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr = LR)
    n_batch_iters = int(np.ceil(len(train_loader)/ACCUM_STEPS)+0.001)
    #scheduler = CosineAnnealingLR(optimizer, T_max=ttl_iters, eta_min=1e-6)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LR, pct_start= PCT_START,
                                                    steps_per_epoch= n_batch_iters, epochs = N_EPOCHS)

    scaler = torch.cuda.amp.GradScaler(enabled=True)
    val_metrics = np.ones(N_EPOCHS)*100

    gc.collect()

    for epoch in tqdm(range(0, N_EPOCHS), leave = False):     

        scheduler, scaler, optimizer = train_func(model, train_loader, scaler, scheduler, optimizer, epoch, accum_points, accum_scale)
        metric                       = valid_func(model, valid_loader, epoch)
        
        #Save the best model    
        if(metric < np.min(val_metrics)):
            try:
                os.makedirs(f'{BASE_PATH}/weights')
            except:
                a = 1
            best_metric = metric
            print(f'Best val_metric {best_metric} at epoch {epoch+1}!')
            torch.save(model.state_dict(), f'{BASE_PATH}/weights/{backbone}_lr{LR}_epochs_{N_EPOCHS}_resol{UP_RESOL}_batch{BATCH_SIZE*ACCUM_STEPS}.pt')    
            not_improve_counter=0
            val_metrics[epoch] = metric
            continue                    
        val_metrics[epoch] = metric                        
        
        #Early stopping
        not_improve_counter+=1
        if(not_improve_counter == EARLY_STOP_COUNT):
            print(f'Not improved for {not_improve_counter} epochs, terminate the train')
            break
wandb.log({'best_total_log_loss': best_metric})
wandb.finish()

100%|██████████| 3782/3782 [01:27<00:00, 43.16it/s]
100%|██████████| 929/929 [00:31<00:00, 29.80it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch 1 / trn/loss=0.8748
Epoch 1 / train/metric=0.6359
Epoch 1 / val/metric=0.5753
Best val_metric 0.5752747294629229 at epoch 1!


  0%|          | 1/200 [04:31<15:00:21, 271.47s/it]

Epoch 2 / trn/loss=0.8038
Epoch 2 / train/metric=0.5717
Epoch 2 / val/metric=0.5573
Best val_metric 0.5572913849431376 at epoch 2!


  1%|          | 2/200 [08:47<14:26:50, 262.68s/it]

Epoch 3 / trn/loss=0.8008
Epoch 3 / train/metric=0.5666


  2%|▏         | 3/200 [13:04<14:13:56, 260.08s/it]

Epoch 3 / val/metric=0.5581
Epoch 4 / trn/loss=0.7943
Epoch 4 / train/metric=0.5629
Epoch 4 / val/metric=0.5498
Best val_metric 0.5498073912789178 at epoch 4!


  2%|▏         | 4/200 [17:22<14:06:43, 259.20s/it]

Epoch 5 / trn/loss=0.7932
Epoch 5 / train/metric=0.5619
Epoch 5 / val/metric=0.5446
Best val_metric 0.5445510173353698 at epoch 5!


  2%|▎         | 5/200 [21:41<14:01:28, 258.92s/it]

Epoch 6 / trn/loss=0.7922
Epoch 6 / train/metric=0.5620


  3%|▎         | 6/200 [25:59<13:56:05, 258.59s/it]

Epoch 6 / val/metric=0.5510
Epoch 7 / trn/loss=0.7968
Epoch 7 / train/metric=0.5645


  4%|▎         | 7/200 [30:17<13:51:58, 258.64s/it]

Epoch 7 / val/metric=0.5453
Epoch 8 / trn/loss=0.7933
Epoch 8 / train/metric=0.5613


  4%|▍         | 8/200 [34:36<13:47:50, 258.70s/it]

Epoch 8 / val/metric=0.5515
Epoch 9 / trn/loss=0.7904
Epoch 9 / train/metric=0.5593
Epoch 9 / val/metric=0.5395
Best val_metric 0.539505760875706 at epoch 9!


  4%|▍         | 9/200 [38:55<13:43:49, 258.79s/it]

Epoch 10 / trn/loss=0.7947
Epoch 10 / train/metric=0.5595


  5%|▌         | 10/200 [43:14<13:38:58, 258.62s/it]

Epoch 10 / val/metric=0.5549
Epoch 11 / trn/loss=0.7990
Epoch 11 / train/metric=0.5664


  6%|▌         | 11/200 [47:32<13:34:58, 258.72s/it]

Epoch 11 / val/metric=0.5434
Epoch 12 / trn/loss=0.7910
Epoch 12 / train/metric=0.5554
Epoch 12 / val/metric=0.5327
Best val_metric 0.5326668023090585 at epoch 12!


  6%|▌         | 12/200 [51:52<13:31:22, 258.95s/it]

Epoch 13 / trn/loss=0.7845
Epoch 13 / train/metric=0.5517


  6%|▋         | 13/200 [56:11<13:26:55, 258.90s/it]

Epoch 13 / val/metric=0.5507
Epoch 14 / trn/loss=0.7916
Epoch 14 / train/metric=0.5573


  7%|▋         | 14/200 [1:00:30<13:23:04, 259.05s/it]

Epoch 14 / val/metric=0.5475
Epoch 15 / trn/loss=0.7865
Epoch 15 / train/metric=0.5530


  8%|▊         | 15/200 [1:04:50<13:19:08, 259.18s/it]

Epoch 15 / val/metric=0.5479
Epoch 16 / trn/loss=0.7946
Epoch 16 / train/metric=0.5584


  8%|▊         | 16/200 [1:09:08<13:14:30, 259.08s/it]

Epoch 16 / val/metric=0.5583
Epoch 17 / trn/loss=0.7957
Epoch 17 / train/metric=0.5586


  8%|▊         | 17/200 [1:13:27<13:09:38, 258.90s/it]

Epoch 17 / val/metric=0.5510
Epoch 18 / trn/loss=0.7889
Epoch 18 / train/metric=0.5549


  9%|▉         | 18/200 [1:17:46<13:05:06, 258.82s/it]

Epoch 18 / val/metric=0.5676
Epoch 19 / trn/loss=0.7909
Epoch 19 / train/metric=0.5561


 10%|▉         | 19/200 [1:22:05<13:01:43, 259.13s/it]

Epoch 19 / val/metric=0.5455
Epoch 20 / trn/loss=0.7845
Epoch 20 / train/metric=0.5521
Epoch 20 / val/metric=0.5248
Best val_metric 0.5248208714189742 at epoch 20!


 10%|█         | 20/200 [1:26:25<12:57:51, 259.28s/it]

Epoch 21 / trn/loss=0.7851
Epoch 21 / train/metric=0.5515


 10%|█         | 21/200 [1:30:44<12:53:25, 259.25s/it]

Epoch 21 / val/metric=0.5385
Epoch 22 / trn/loss=0.7916
Epoch 22 / train/metric=0.5547


 11%|█         | 22/200 [1:35:03<12:49:00, 259.22s/it]

Epoch 22 / val/metric=0.5448
Epoch 23 / trn/loss=0.7831
Epoch 23 / train/metric=0.5496


 12%|█▏        | 23/200 [1:39:23<12:44:56, 259.30s/it]

Epoch 23 / val/metric=0.5374
Epoch 24 / trn/loss=0.7857
Epoch 24 / train/metric=0.5517


 12%|█▏        | 24/200 [1:43:41<12:39:59, 259.09s/it]

Epoch 24 / val/metric=0.5366
Epoch 25 / trn/loss=0.7921
Epoch 25 / train/metric=0.5600


 12%|█▎        | 25/200 [1:48:01<12:35:46, 259.12s/it]

Epoch 25 / val/metric=0.5363
Epoch 26 / trn/loss=0.7944
Epoch 26 / train/metric=0.5591
Epoch 26 / val/metric=0.5117
Best val_metric 0.5117374426160188 at epoch 26!


 13%|█▎        | 26/200 [1:52:20<12:31:49, 259.25s/it]

Epoch 27 / trn/loss=0.7918
Epoch 27 / train/metric=0.5560


 14%|█▎        | 27/200 [1:56:39<12:27:09, 259.13s/it]

Epoch 27 / val/metric=0.5297
Epoch 28 / trn/loss=0.7870
Epoch 28 / train/metric=0.5547


 14%|█▍        | 28/200 [2:00:58<12:22:39, 259.07s/it]

Epoch 28 / val/metric=0.5145
Epoch 29 / trn/loss=0.7880
Epoch 29 / train/metric=0.5550


 14%|█▍        | 29/200 [2:05:17<12:17:56, 258.93s/it]

Epoch 29 / val/metric=0.5329
Epoch 30 / trn/loss=0.7799
Epoch 30 / train/metric=0.5511


 15%|█▌        | 30/200 [2:09:35<12:13:35, 258.92s/it]

Epoch 30 / val/metric=0.5339
Epoch 31 / trn/loss=0.7852
Epoch 31 / train/metric=0.5530


 16%|█▌        | 31/200 [2:13:55<12:09:25, 258.97s/it]

Epoch 31 / val/metric=0.5362
Epoch 32 / trn/loss=0.7865
Epoch 32 / train/metric=0.5540


 16%|█▌        | 32/200 [2:18:14<12:05:11, 259.00s/it]

Epoch 32 / val/metric=0.5334
Epoch 33 / trn/loss=0.7813
Epoch 33 / train/metric=0.5487


 16%|█▋        | 33/200 [2:22:33<12:01:11, 259.11s/it]

Epoch 33 / val/metric=0.5307
Epoch 34 / trn/loss=0.7819
Epoch 34 / train/metric=0.5510


 17%|█▋        | 34/200 [2:26:52<11:56:36, 259.01s/it]

Epoch 34 / val/metric=0.5337
Epoch 35 / trn/loss=0.7883
Epoch 35 / train/metric=0.5557


 18%|█▊        | 35/200 [2:31:11<11:52:14, 259.00s/it]

Epoch 35 / val/metric=0.5174
Epoch 36 / trn/loss=0.7772
Epoch 36 / train/metric=0.5460


 18%|█▊        | 36/200 [2:35:30<11:47:47, 258.95s/it]

Epoch 36 / val/metric=0.5385
Epoch 37 / trn/loss=0.7697
Epoch 37 / train/metric=0.5398


 18%|█▊        | 37/200 [2:39:48<11:43:23, 258.92s/it]

Epoch 37 / val/metric=0.5188
Epoch 38 / trn/loss=0.7776
Epoch 38 / train/metric=0.5474


 19%|█▉        | 38/200 [2:44:07<11:38:50, 258.83s/it]

Epoch 38 / val/metric=0.5497
Epoch 39 / trn/loss=0.7761
Epoch 39 / train/metric=0.5455


 20%|█▉        | 39/200 [2:48:26<11:34:38, 258.88s/it]

Epoch 39 / val/metric=0.5332
Epoch 40 / trn/loss=0.7774
Epoch 40 / train/metric=0.5489


 20%|██        | 40/200 [2:52:45<11:30:02, 258.77s/it]

Epoch 40 / val/metric=0.5408
Epoch 41 / trn/loss=0.7728
Epoch 41 / train/metric=0.5442


 20%|██        | 41/200 [2:57:03<11:25:47, 258.79s/it]

Epoch 41 / val/metric=0.5188
Epoch 42 / trn/loss=0.7734
Epoch 42 / train/metric=0.5425


 21%|██        | 42/200 [3:01:22<11:21:23, 258.76s/it]

Epoch 42 / val/metric=0.5555
Epoch 43 / trn/loss=0.7796
Epoch 43 / train/metric=0.5474


 22%|██▏       | 43/200 [3:05:41<11:17:03, 258.75s/it]

Epoch 43 / val/metric=0.5200
Epoch 44 / trn/loss=0.7629
Epoch 44 / train/metric=0.5330


 22%|██▏       | 44/200 [3:09:59<11:12:36, 258.69s/it]

Epoch 44 / val/metric=0.5340
Epoch 45 / trn/loss=0.7691
Epoch 45 / train/metric=0.5419


 22%|██▎       | 45/200 [3:14:18<11:08:10, 258.65s/it]

Epoch 45 / val/metric=0.5125
Epoch 46 / trn/loss=0.7673
Epoch 46 / train/metric=0.5397
Epoch 46 / val/metric=0.5105
Best val_metric 0.5104905504212512 at epoch 46!


 23%|██▎       | 46/200 [3:18:37<11:04:04, 258.73s/it]

Epoch 47 / trn/loss=0.7730
Epoch 47 / train/metric=0.5462


 24%|██▎       | 47/200 [3:22:55<10:59:10, 258.50s/it]

Epoch 47 / val/metric=0.5179
Epoch 48 / trn/loss=0.7713
Epoch 48 / train/metric=0.5417


 24%|██▍       | 48/200 [3:27:13<10:54:43, 258.44s/it]

Epoch 48 / val/metric=0.5116
Epoch 49 / trn/loss=0.7663
Epoch 49 / train/metric=0.5401


 24%|██▍       | 49/200 [3:31:31<10:50:10, 258.35s/it]

Epoch 49 / val/metric=0.5261
Epoch 50 / trn/loss=0.7682
Epoch 50 / train/metric=0.5394


 25%|██▌       | 50/200 [3:35:50<10:46:09, 258.47s/it]

Epoch 50 / val/metric=0.5409
Epoch 51 / trn/loss=0.7726
Epoch 51 / train/metric=0.5436


 26%|██▌       | 51/200 [3:40:09<10:41:58, 258.51s/it]

Epoch 51 / val/metric=0.5275
Epoch 52 / trn/loss=0.7705
Epoch 52 / train/metric=0.5426
Epoch 52 / val/metric=0.5099
Best val_metric 0.5098904788878159 at epoch 52!


 26%|██▌       | 52/200 [3:44:27<10:37:48, 258.57s/it]

Epoch 53 / trn/loss=0.7650
Epoch 53 / train/metric=0.5378


 26%|██▋       | 53/200 [3:48:45<10:33:10, 258.44s/it]

Epoch 53 / val/metric=0.5232
Epoch 54 / trn/loss=0.7674
Epoch 54 / train/metric=0.5398


 27%|██▋       | 54/200 [3:53:04<10:28:53, 258.45s/it]

Epoch 54 / val/metric=0.5219
Epoch 55 / trn/loss=0.7680
Epoch 55 / train/metric=0.5389


 28%|██▊       | 55/200 [3:57:22<10:24:24, 258.38s/it]

Epoch 55 / val/metric=0.5702
Epoch 56 / trn/loss=0.7649
Epoch 56 / train/metric=0.5386


 28%|██▊       | 56/200 [4:01:40<10:19:57, 258.32s/it]

Epoch 56 / val/metric=0.5257
Epoch 57 / trn/loss=0.7592
Epoch 57 / train/metric=0.5343


 28%|██▊       | 57/200 [4:05:59<10:15:45, 258.36s/it]

Epoch 57 / val/metric=0.5403
Epoch 58 / trn/loss=0.7710
Epoch 58 / train/metric=0.5427


 29%|██▉       | 58/200 [4:10:16<10:10:59, 258.16s/it]

Epoch 58 / val/metric=0.5193
Epoch 59 / trn/loss=0.7594
Epoch 59 / train/metric=0.5333


 30%|██▉       | 59/200 [4:14:35<10:07:07, 258.35s/it]

Epoch 59 / val/metric=0.5118
Epoch 60 / trn/loss=0.7612
Epoch 60 / train/metric=0.5347


 30%|███       | 60/200 [4:18:53<10:02:32, 258.24s/it]

Epoch 60 / val/metric=0.5251
Epoch 61 / trn/loss=0.7668
Epoch 61 / train/metric=0.5405


 30%|███       | 61/200 [4:23:11<9:58:12, 258.22s/it] 

Epoch 61 / val/metric=0.6048
Epoch 62 / trn/loss=0.7513
Epoch 62 / train/metric=0.5292


 31%|███       | 62/200 [4:27:30<9:53:50, 258.19s/it]

Epoch 62 / val/metric=0.5256
Epoch 63 / trn/loss=0.7747
Epoch 63 / train/metric=0.5435


 32%|███▏      | 63/200 [4:31:47<9:49:21, 258.11s/it]

Epoch 63 / val/metric=0.5239
Epoch 64 / trn/loss=0.7636
Epoch 64 / train/metric=0.5358


 32%|███▏      | 64/200 [4:36:06<9:45:08, 258.15s/it]

Epoch 64 / val/metric=0.5502
Epoch 65 / trn/loss=0.7591
Epoch 65 / train/metric=0.5334


 32%|███▎      | 65/200 [4:40:24<9:40:40, 258.08s/it]

Epoch 65 / val/metric=0.5157
Epoch 66 / trn/loss=0.7609
Epoch 66 / train/metric=0.5348


 33%|███▎      | 66/200 [4:44:42<9:36:28, 258.13s/it]

Epoch 66 / val/metric=0.5232
Epoch 67 / trn/loss=0.7609
Epoch 67 / train/metric=0.5355


 34%|███▎      | 67/200 [4:49:00<9:32:11, 258.13s/it]

Epoch 67 / val/metric=0.5211
Epoch 68 / trn/loss=0.7626
Epoch 68 / train/metric=0.5384


 34%|███▍      | 68/200 [4:53:18<9:27:58, 258.17s/it]

Epoch 68 / val/metric=0.5190
Epoch 69 / trn/loss=0.7578
Epoch 69 / train/metric=0.5338


 34%|███▍      | 69/200 [4:57:36<9:23:08, 257.93s/it]

Epoch 69 / val/metric=0.5374
Epoch 70 / trn/loss=0.7548
Epoch 70 / train/metric=0.5304


 35%|███▌      | 70/200 [5:01:54<9:19:23, 258.18s/it]

Epoch 70 / val/metric=0.5247
Epoch 71 / trn/loss=0.7547
Epoch 71 / train/metric=0.5293


 36%|███▌      | 71/200 [5:06:12<9:14:43, 258.01s/it]

Epoch 71 / val/metric=0.5459
Epoch 72 / trn/loss=0.7590
Epoch 72 / train/metric=0.5331


 36%|███▌      | 72/200 [5:10:30<9:10:14, 257.92s/it]

Epoch 72 / val/metric=0.5557
Epoch 73 / trn/loss=0.7531
Epoch 73 / train/metric=0.5298


 36%|███▋      | 73/200 [5:14:47<9:05:47, 257.86s/it]

Epoch 73 / val/metric=0.5144
Epoch 74 / trn/loss=0.7418
Epoch 74 / train/metric=0.5185


 37%|███▋      | 74/200 [5:19:06<9:01:46, 257.98s/it]

Epoch 74 / val/metric=0.5598
Epoch 75 / trn/loss=0.7559
Epoch 75 / train/metric=0.5304


 38%|███▊      | 75/200 [5:23:23<8:57:15, 257.88s/it]

Epoch 75 / val/metric=0.5221
Epoch 76 / trn/loss=0.7505
Epoch 76 / train/metric=0.5258


 38%|███▊      | 76/200 [5:27:41<8:53:03, 257.93s/it]

Epoch 76 / val/metric=0.5281
Epoch 77 / trn/loss=0.7503
Epoch 77 / train/metric=0.5281


 38%|███▊      | 77/200 [5:32:00<8:49:02, 258.07s/it]

Epoch 77 / val/metric=0.5140
Epoch 78 / trn/loss=0.7551
Epoch 78 / train/metric=0.5324


 39%|███▉      | 78/200 [5:36:18<8:44:41, 258.04s/it]

Epoch 78 / val/metric=0.5100
Epoch 79 / trn/loss=0.7467
Epoch 79 / train/metric=0.5229
Epoch 79 / val/metric=0.5041
Best val_metric 0.504135394512725 at epoch 79!


 40%|███▉      | 79/200 [5:40:36<8:40:28, 258.09s/it]

Epoch 80 / trn/loss=0.7453
Epoch 80 / train/metric=0.5224


 40%|████      | 80/200 [5:44:54<8:35:58, 257.98s/it]

Epoch 80 / val/metric=0.5597
Epoch 81 / trn/loss=0.7491
Epoch 81 / train/metric=0.5268


 40%|████      | 81/200 [5:49:12<8:31:34, 257.94s/it]

Epoch 81 / val/metric=0.5073
Epoch 82 / trn/loss=0.7432
Epoch 82 / train/metric=0.5217
Epoch 82 / val/metric=0.4993
Best val_metric 0.4992911459527776 at epoch 82!


 41%|████      | 82/200 [5:53:30<8:27:26, 258.03s/it]

Epoch 83 / trn/loss=0.7492
Epoch 83 / train/metric=0.5260


 42%|████▏     | 83/200 [5:57:48<8:23:00, 257.95s/it]

Epoch 83 / val/metric=0.5001
Epoch 84 / trn/loss=0.7477
Epoch 84 / train/metric=0.5236


 42%|████▏     | 84/200 [6:02:05<8:18:35, 257.90s/it]

Epoch 84 / val/metric=0.5066
Epoch 85 / trn/loss=0.7448
Epoch 85 / train/metric=0.5218


 42%|████▎     | 85/200 [6:06:23<8:14:13, 257.86s/it]

Epoch 85 / val/metric=0.5116
Epoch 86 / trn/loss=0.7447
Epoch 86 / train/metric=0.5226


 43%|████▎     | 86/200 [6:10:41<8:09:59, 257.89s/it]

Epoch 86 / val/metric=0.5096
Epoch 87 / trn/loss=0.7455
Epoch 87 / train/metric=0.5238


 44%|████▎     | 87/200 [6:14:59<8:05:36, 257.85s/it]

Epoch 87 / val/metric=0.5139
Epoch 88 / trn/loss=0.7433
Epoch 88 / train/metric=0.5213


 44%|████▍     | 88/200 [6:19:17<8:01:20, 257.86s/it]

Epoch 88 / val/metric=0.5083
Epoch 89 / trn/loss=0.7364
Epoch 89 / train/metric=0.5179


 44%|████▍     | 89/200 [6:23:35<7:57:05, 257.89s/it]

Epoch 89 / val/metric=0.5206
Epoch 90 / trn/loss=0.7424
Epoch 90 / train/metric=0.5198


 45%|████▌     | 90/200 [6:27:53<7:52:55, 257.96s/it]

Epoch 90 / val/metric=0.5545
Epoch 91 / trn/loss=0.7331
Epoch 91 / train/metric=0.5155


 46%|████▌     | 91/200 [6:32:11<7:48:37, 257.96s/it]

Epoch 91 / val/metric=0.5322
Epoch 92 / trn/loss=0.7426
Epoch 92 / train/metric=0.5236


 46%|████▌     | 92/200 [6:36:29<7:44:17, 257.94s/it]

Epoch 92 / val/metric=0.5089
Epoch 93 / trn/loss=0.7448
Epoch 93 / train/metric=0.5217


 46%|████▋     | 93/200 [6:40:47<7:40:00, 257.95s/it]

Epoch 93 / val/metric=0.5275
Epoch 94 / trn/loss=0.7437
Epoch 94 / train/metric=0.5203


 47%|████▋     | 94/200 [6:45:04<7:35:33, 257.86s/it]

Epoch 94 / val/metric=0.5082
Epoch 95 / trn/loss=0.7366
Epoch 95 / train/metric=0.5157


 48%|████▊     | 95/200 [6:49:22<7:31:22, 257.93s/it]

Epoch 95 / val/metric=0.5105
Epoch 96 / trn/loss=0.7377
Epoch 96 / train/metric=0.5186


 48%|████▊     | 96/200 [6:53:40<7:27:04, 257.93s/it]

Epoch 96 / val/metric=0.5023
Epoch 97 / trn/loss=0.7371
Epoch 97 / train/metric=0.5179


 48%|████▊     | 97/200 [6:57:58<7:22:53, 258.00s/it]

Epoch 97 / val/metric=0.5081
Epoch 98 / trn/loss=0.7340
Epoch 98 / train/metric=0.5154
Epoch 98 / val/metric=0.4985
Best val_metric 0.49847422133937336 at epoch 98!


 49%|████▉     | 98/200 [7:02:17<7:18:49, 258.13s/it]

Epoch 99 / trn/loss=0.7285
Epoch 99 / train/metric=0.5091


 50%|████▉     | 99/200 [7:06:34<7:14:10, 257.93s/it]

Epoch 99 / val/metric=0.5036
Epoch 100 / trn/loss=0.7364
Epoch 100 / train/metric=0.5168


 50%|█████     | 100/200 [7:10:53<7:10:04, 258.04s/it]

Epoch 100 / val/metric=0.5145
Epoch 101 / trn/loss=0.7356
Epoch 101 / train/metric=0.5153


 50%|█████     | 101/200 [7:15:11<7:05:43, 258.01s/it]

Epoch 101 / val/metric=0.5075
Epoch 102 / trn/loss=0.7328
Epoch 102 / train/metric=0.5140


 51%|█████     | 102/200 [7:19:28<7:01:07, 257.83s/it]

Epoch 102 / val/metric=0.4988
Epoch 103 / trn/loss=0.7404
Epoch 103 / train/metric=0.5198


 52%|█████▏    | 103/200 [7:23:46<6:56:40, 257.74s/it]

Epoch 103 / val/metric=0.5144
Epoch 104 / trn/loss=0.7302
Epoch 104 / train/metric=0.5109


 52%|█████▏    | 104/200 [7:28:04<6:52:42, 257.94s/it]

Epoch 104 / val/metric=0.5045
Epoch 105 / trn/loss=0.7334
Epoch 105 / train/metric=0.5132


 52%|█████▎    | 105/200 [7:32:22<6:48:21, 257.92s/it]

Epoch 105 / val/metric=0.5348
Epoch 106 / trn/loss=0.7367
Epoch 106 / train/metric=0.5156


 53%|█████▎    | 106/200 [7:36:40<6:44:03, 257.91s/it]

Epoch 106 / val/metric=0.5288
Epoch 107 / trn/loss=0.7298
Epoch 107 / train/metric=0.5092
Epoch 107 / val/metric=0.4854
Best val_metric 0.4854485944553019 at epoch 107!


 54%|█████▎    | 107/200 [7:40:58<6:40:06, 258.13s/it]

Epoch 108 / trn/loss=0.7248
Epoch 108 / train/metric=0.5082


 54%|█████▍    | 108/200 [7:45:16<6:35:32, 257.96s/it]

Epoch 108 / val/metric=0.5146
Epoch 109 / trn/loss=0.7274
Epoch 109 / train/metric=0.5105


 55%|█████▍    | 109/200 [7:49:34<6:31:16, 257.98s/it]

Epoch 109 / val/metric=0.4966
Epoch 110 / trn/loss=0.7246
Epoch 110 / train/metric=0.5066


 55%|█████▌    | 110/200 [7:53:51<6:26:41, 257.79s/it]

Epoch 110 / val/metric=0.5210
Epoch 111 / trn/loss=0.7236
Epoch 111 / train/metric=0.5066


 56%|█████▌    | 111/200 [7:58:09<6:22:29, 257.85s/it]

Epoch 111 / val/metric=0.5032
Epoch 112 / trn/loss=0.7221
Epoch 112 / train/metric=0.5042


 56%|█████▌    | 112/200 [8:02:27<6:18:07, 257.82s/it]

Epoch 112 / val/metric=0.4986
Epoch 113 / trn/loss=0.7262
Epoch 113 / train/metric=0.5076


 56%|█████▋    | 113/200 [8:06:45<6:14:01, 257.95s/it]

Epoch 113 / val/metric=0.4915
Epoch 114 / trn/loss=0.7230
Epoch 114 / train/metric=0.5061


 57%|█████▋    | 114/200 [8:11:03<6:09:44, 257.95s/it]

Epoch 114 / val/metric=0.5000
Epoch 115 / trn/loss=0.7192
Epoch 115 / train/metric=0.5029


 57%|█████▊    | 115/200 [8:15:21<6:05:22, 257.92s/it]

Epoch 115 / val/metric=0.5002
Epoch 116 / trn/loss=0.7198
Epoch 116 / train/metric=0.5047


 58%|█████▊    | 116/200 [8:19:39<6:01:01, 257.88s/it]

Epoch 116 / val/metric=0.4921
Epoch 117 / trn/loss=0.7246
Epoch 117 / train/metric=0.5065


 58%|█████▊    | 117/200 [8:23:57<5:56:46, 257.91s/it]

Epoch 117 / val/metric=0.5014
Epoch 118 / trn/loss=0.7126
Epoch 118 / train/metric=0.4981


 59%|█████▉    | 118/200 [8:28:14<5:52:23, 257.84s/it]

Epoch 118 / val/metric=0.4857
Epoch 119 / trn/loss=0.7091
Epoch 119 / train/metric=0.4963


 60%|█████▉    | 119/200 [8:32:32<5:48:06, 257.86s/it]

Epoch 119 / val/metric=0.4917
Epoch 120 / trn/loss=0.7157
Epoch 120 / train/metric=0.4996


 60%|██████    | 120/200 [8:36:50<5:43:50, 257.88s/it]

Epoch 120 / val/metric=0.4989
Epoch 121 / trn/loss=0.7179
Epoch 121 / train/metric=0.5038


 60%|██████    | 121/200 [8:41:08<5:39:23, 257.77s/it]

Epoch 121 / val/metric=0.5201
Epoch 122 / trn/loss=0.7114
Epoch 122 / train/metric=0.4974


 61%|██████    | 122/200 [8:45:25<5:34:59, 257.69s/it]

Epoch 122 / val/metric=0.4923
Epoch 123 / trn/loss=0.7106
Epoch 123 / train/metric=0.4957


 62%|██████▏   | 123/200 [8:49:43<5:30:43, 257.70s/it]

Epoch 123 / val/metric=0.5101
Epoch 124 / trn/loss=0.7121
Epoch 124 / train/metric=0.4975


 62%|██████▏   | 124/200 [8:54:01<5:26:25, 257.70s/it]

Epoch 124 / val/metric=0.5222
Epoch 125 / trn/loss=0.7083
Epoch 125 / train/metric=0.4957


 62%|██████▎   | 125/200 [8:58:19<5:22:14, 257.79s/it]

Epoch 125 / val/metric=0.5093
Epoch 126 / trn/loss=0.7088
Epoch 126 / train/metric=0.4969


 63%|██████▎   | 126/200 [9:02:37<5:17:57, 257.80s/it]

Epoch 126 / val/metric=0.4951
Epoch 127 / trn/loss=0.7085
Epoch 127 / train/metric=0.4955


 64%|██████▎   | 127/200 [9:06:54<5:13:37, 257.77s/it]

Epoch 127 / val/metric=0.5108
Epoch 128 / trn/loss=0.7068
Epoch 128 / train/metric=0.4943


 64%|██████▍   | 128/200 [9:11:12<5:09:26, 257.87s/it]

Epoch 128 / val/metric=0.4908
Epoch 129 / trn/loss=0.7027
Epoch 129 / train/metric=0.4923


 64%|██████▍   | 129/200 [9:15:31<5:05:16, 257.98s/it]

Epoch 129 / val/metric=0.4987
Epoch 130 / trn/loss=0.7041
Epoch 130 / train/metric=0.4930


 65%|██████▌   | 130/200 [9:19:48<5:00:55, 257.93s/it]

Epoch 130 / val/metric=0.5065
Epoch 131 / trn/loss=0.7068
Epoch 131 / train/metric=0.4956


 66%|██████▌   | 131/200 [9:24:06<4:56:37, 257.93s/it]

Epoch 131 / val/metric=0.5047
Epoch 132 / trn/loss=0.7018
Epoch 132 / train/metric=0.4913


 66%|██████▌   | 132/200 [9:28:24<4:52:19, 257.93s/it]

Epoch 132 / val/metric=0.5012
Epoch 133 / trn/loss=0.7010
Epoch 133 / train/metric=0.4889


 66%|██████▋   | 133/200 [9:32:42<4:47:58, 257.88s/it]

Epoch 133 / val/metric=0.5000
Epoch 134 / trn/loss=0.7044
Epoch 134 / train/metric=0.4933


 67%|██████▋   | 134/200 [9:37:00<4:43:42, 257.91s/it]

Epoch 134 / val/metric=0.4956
Epoch 135 / trn/loss=0.7040
Epoch 135 / train/metric=0.4942


 68%|██████▊   | 135/200 [9:41:18<4:39:20, 257.85s/it]

Epoch 135 / val/metric=0.4922
Epoch 136 / trn/loss=0.7070
Epoch 136 / train/metric=0.4955


 68%|██████▊   | 136/200 [9:45:35<4:34:52, 257.70s/it]

Epoch 136 / val/metric=0.4926
Epoch 137 / trn/loss=0.7016
Epoch 137 / train/metric=0.4904


                                                      

Epoch 137 / val/metric=0.5213
Not improved for 30 epochs, terminate the train






0,1
best_total_log_loss,▁
lr,▁▁▁▂▂▂▃▄▄▅▆▆▇▇▇██████████▇▇▇▇▇▆▆▆▆▅▅▅▄▄▄
train_any_in_metric,▄▅▅▇▆▇█▆▅▆▆▅▆▅▅▄▄▄▃▄▄▅▄▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▁▁
train_any_loss,▆▂▃▃▆▆▂▄▄▇▇▅▄▂▃▅▃▄█▁▁▆▃▄▃▆▁▅▅█▂▆▂▃▃▆▅▃▇▄
train_avg_loss,▅▄▄▂▆▅▂▄▃▇▅█▃▂▃▄▄▃▆▂▁▄▄▃▅▃▂▅▅█▃▅▂▂▃▅▄▂▅▄
train_avg_metric,█▇▇▇▇▇▇▇▇▆▆▆▅▅▅▅▅▅▄▅▄▃▄▄▄▄▃▄▃▃▃▂▂▂▂▂▂▁▁▁
train_bowel_loss,▁▁▅▁▅▁▁▁▁▁▁▁▁▁▁▁▁▁▅▁▁▁▁▁▁▁▁▁▅█▁▁▁▁▇▁▄▁▁▁
train_bowel_metric,▇▆▅▆▆▆▆▆▆█▇▆▆▇▇▇▇▇▆▆▆▆▅▅▆▅▅▄▄▄▄▄▂▃▂▂▃▂▁▂
train_extrav_loss,▅█▃▁▂▅▃▅▃▇▂▃▃▃▃▃▂▅▅▃▂▂▅▃▃▂▂▂▂▅▂▂▅▅▂▂▂▃▄▅
train_extrav_metric,▁▁▂▂▄▃▂██▄▅▄▄▅▅▆▅▆▅▅▃▄▅▄▄▄▃▅▄▃▄▃▂▃▃▂▃▃▃▃

0,1
best_total_log_loss,0.48545
lr,0.00211
train_any_in_metric,0.59096
train_any_loss,2.1881
train_avg_loss,0.81173
train_avg_metric,0.49035
train_bowel_loss,0.02251
train_bowel_metric,0.14529
train_extrav_loss,0.54561
train_extrav_metric,0.61512


In [16]:
#Execute this cell to fininsh the wandb run when you stopped training.
import wandb
try: 
    wandb.log({'best_total_log_loss': best_metric})
    wandb.finish()
    
except:
    print('Wandb is already finished!')

Wandb is already finished!


In [17]:
ind = 23561
train_meta_df[train_meta_df['patient_id']==ind]

Unnamed: 0,patient_id,series,bowel_healthy,bowel_injury,extravasation_healthy,extravasation_injury,kidney_healthy,kidney_low,kidney_high,liver_healthy,liver_low,liver_high,spleen_healthy,spleen_low,spleen_high,any_injury,fold,path,mask_path,cropped_path
3540,23561,19317,1,0,1,0,1,0,0,1,0,0,1,0,0,0,4,/home/junseonglee/Desktop/01_codes/inputs/rsna...,/home/junseonglee/Desktop/01_codes/inputs/rsna...,/home/junseonglee/Desktop/01_codes/inputs/rsna...


In [24]:
valid_dataset = AbdominalCTDataset(train_meta_df[train_meta_df['patient_id']==ind], is_train = False, transform_set = None,
                                    remain_transforms_set = None)

valid_loader = DataLoader(dataset = valid_dataset, shuffle = False, batch_size = BATCH_SIZE, pin_memory = False, 
                        num_workers = N_WORKERS, drop_last = False)     

100%|██████████| 1/1 [00:00<00:00, 50.20it/s]


In [25]:
X_outs=[]
ys=[]
model.eval()
model.load_state_dict(torch.load(f'{BASE_PATH}/weights/231003_resnet10t_dicomO_std_CV0.485.pt'))
for X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, y in valid_loader:
    batch_size = y.shape[0]
    X_bowel, X_lkid, X_rkid = X_bowel.to(DEVICE), X_lkid.to(DEVICE), X_rkid.to(DEVICE)
    X_liv, X_spl, X_tot     = X_liv.to(DEVICE), X_spl.to(DEVICE), X_tot.to(DEVICE)
    y = y.to(DEVICE)           
    with torch.cuda.amp.autocast(enabled=True):                
        with torch.no_grad():                 
            X_out, X_any = model(X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot)                                          
            y_any = torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1)              
            X_out = apply_softmax_to_labels(X_out).to('cpu').numpy()

            X_any = X_any.to('cpu').numpy()
            X_out = np.hstack([X_out, X_any])
            X_outs.append(X_out)

            y     = y.to('cpu').numpy()[:,:-1]
            y_any = y_any.to('cpu').numpy()
            y     = np.hstack([y, y_any])
            ys.append(y)

X_outs = np.vstack(X_outs) 
ys     = np.vstack(ys)
#metric = calculate_score(X_outs, ys, 'valid')                      

del X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, X_any
gc.collect()        
torch.cuda.empty_cache()   

In [26]:
np.average(X_outs, axis = 0)


array([ 0.8515625 ,  0.14868164,  0.6298828 ,  0.3701172 ,  0.9379883 ,
        0.03500366,  0.02693176,  0.94873047,  0.04693604,  0.00417709,
        0.9091797 ,  0.07794189,  0.01283264, -0.46212062, -0.9941018 ],
      dtype=float32)

In [27]:
len(X_outs)

1