In [1]:
import time
from tqdm import tqdm
import sys
import glob
import gc
import os
sys.path.append('./lib_models')
#sys.path.append('')

import pandas as pd
import numpy as np
import scipy as sp
import cv2
from PIL import Image
from matplotlib import pyplot as plt
import sklearn.metrics
import warnings
import pydicom
import dicomsdl
from joblib import Parallel, delayed
import pickle
import gzip
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from multiprocessing import Pool
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch import nn
from torchvision.io import read_image
import segmentation_models_pytorch as smp
import timm
from timm.utils import AverageMeter
from timm.models import resnet
import timm_new

from monai.transforms import Resize
import  monai.transforms as transforms

from timm.models.layers.conv2d_same import Conv2dSame
from conv3d_same import Conv3dSame


import wandb

wandb.login(key = '585f58f321685308f7933861d9dde7488de0970b')

  from .autonotebook import tqdm as notebook_tqdm
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjunseonglee[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/junseonglee/.netrc


True

# Parameters

In [2]:
backbone = 'timm/resnet10t.c3_in1k'

IS_WANDB = True
PROJECT_NAME = 'RSNA_ABTD'
GROUP_NAME= 'momdel_test'
RUN_NAME=   f'{backbone}_bowel'

if not IS_WANDB:
    PROJECT_NAME = 'Dummy_Project'

BASE_PATH  = '/home/junseonglee/Desktop/01_codes/inputs/rsna-2023-abdominal-trauma-detection'
TRAIN_PATH = f'{BASE_PATH}/train_images'
DATA_PATH = f'{BASE_PATH}/3d_preprocessed'

seg_inference_dir = f'{BASE_PATH}/seg_infer_results'
cropped_img_dir   = f'{BASE_PATH}/3d_preprocessed_crop_ratio'

if not os.path.isdir(DATA_PATH):
    os.mkdir(DATA_PATH)

RESOL = 128
UP_RESOL = 128
N_CHANNELS = 6
BATCH_SIZE = 8
ACCUM_STEPS = 3
N_WORKERS  = 8
LR = 0.0002
N_EPOCHS = 200
EARLY_STOP_COUNT = N_EPOCHS
N_FOLDS  = 5
SELECT_FOLD = 1
N_PREPROCESS_CHUNKS = 12
PCT_START = 0.3
n_blocks = 4
drop_rate = 0.2
drop_path_rate = 0.2
p_mixup = 0.0



DROP_REGION= {'HOLES': [3, 20],
                'SIZE': [5, 20],
                'PROB': 0.5,
                'FILL': (-3, 3)}

wandb_config = {
    'RESOL': RESOL,
    'BACKBONE': backbone,
    'N_CHANNELS': N_CHANNELS,
    'N_EPOCHS': N_EPOCHS,
    'N_FOLDS': N_FOLDS,
    'EARLY_STOP_COUNT': EARLY_STOP_COUNT,
    'BATCH_SIZE': BATCH_SIZE,    
    'LR': LR,
    'N_EPOCHS': N_EPOCHS,
    'DROP_RATE': drop_rate,
    'DROP_PATH_RATE': drop_path_rate,
    'MIXUP_RATE': p_mixup,
    'DROP_REGION': DROP_REGION,
    'PCT_START': PCT_START
}

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

In [3]:
# Mask related parameters
# Order 0: Bowel, 1: left kidney, 2: right kidney, 3: liver, 4: spleen

chan_keys = ['bowel', 'left_kidney', 'right_kidney', 'liver', 'spleen', 'total']
chan_dict = {}
for i in range(0, 6):
    chan_dict[i] = chan_keys[i]

train_meta_df = pd.read_csv(f'{BASE_PATH}/train_meta.csv')
np.unique(train_meta_df['fold'].to_numpy(), return_counts = True)

(array([0, 1, 2, 3, 4]), array([929, 947, 948, 951, 936]))

In [4]:
#train_injured_df = train_meta_df.loc[(train_meta_df['fold']!=SELECT_FOLD)& (train_meta_df['any_injury']==1)].copy()
#train_meta_df = pd.concat([train_meta_df, train_injured_df])
#len(train_meta_df)

In [5]:
def compress(name, data):
    with gzip.open(name, 'wb') as f:
        pickle.dump(data, f)

def decompress(name):
    with gzip.open(name, 'rb') as f:
        data = pickle.load(f)
    return data


def compress_fast(name, data):
    with open(name, 'wb') as f:
        pickle.dump(data, f)

def decompress_fast(name):
    with open(name, 'rb') as f:
        data = pickle.load(f)
    return data

# Model

In [6]:
def convert_3d(module):

    module_output = module
    if isinstance(module, torch.nn.BatchNorm2d):
        module_output = torch.nn.BatchNorm3d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig
            
    elif isinstance(module, Conv2dSame):
        module_output = Conv3dSame(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))

    elif isinstance(module, torch.nn.Conv2d):
        module_output = torch.nn.Conv3d(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
            padding_mode=module.padding_mode
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))

    elif isinstance(module, torch.nn.MaxPool2d):
        module_output = torch.nn.MaxPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            ceil_mode=module.ceil_mode,
        )
    elif isinstance(module, torch.nn.AvgPool2d):
        module_output = torch.nn.AvgPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            ceil_mode=module.ceil_mode,
        )

    for name, child in module.named_children():
        module_output.add_module(
            name, convert_3d(child)
        )
    del module

    return module_output

In [7]:
class Timm3DModel(nn.Module):
    def __init__(self, backbone, n_channels, n_labels, segtype='unet', pretrained=False):
        super(Timm3DModel, self).__init__()
        self.n_labels = n_labels
        self.encoder = timm_new.create_model(
            backbone,
            in_chans=n_channels,
            features_only=True,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            pretrained=pretrained
        )
        g = self.encoder(torch.rand(1, n_channels, 64, 64))
        encoder_channels = [1] + [_.shape[1] for _ in g]
        decoder_channels = [256, 128, 64, 32, 16]

        self.avgpool = nn.AvgPool2d(5, 4, 2)
        
        [_.shape[1] for _ in g]
        self.convs1x1 = nn.ModuleList()    
        self.batchnorms = nn.ModuleList()    
        self.batchnorms13 = nn.ModuleList()
        for i in range(0, len(g)):
            self.convs1x1.append(nn.Conv2d(g[i].shape[1], self.n_labels, 1))
        del g
        gc.collect()
        
    def forward(self,x):
        batch_size = x.shape[0]
        global_features = self.encoder(x)[:n_blocks]        
        for i in range(0, len(global_features)):
            global_features[i] = self.convs1x1[i](global_features[i])
        return global_features
    
    
class Timm3DModelClassifierEmbed(nn.Module):
    def __init__(self, backbone, n_channels, n_labels, segtype='unet', pretrained=False):
        super(Timm3DModelClassifierEmbed, self).__init__()
        self.model_3d = Timm3DModel(backbone, n_channels, n_labels, segtype, pretrained)
        self.model_3d = convert_3d(self.model_3d)
        self.n_channels = n_channels
        self.n_labels = n_labels                        
        
    def forward(self, x):
        batch_size = x.shape[0]
        x = self.model_3d(x)
        pooled_features = []
        for i in range(0, len(x)):
            pooled_features.append(torch.reshape(torch.mean(x[i], dim = (2, 3, 4)), (batch_size, self.n_labels, 1)))
        pooled_features = torch.cat(pooled_features, dim=2)     
        labels = nn.Flatten()(pooled_features)
        #labels = torch.mean(pooled_features, dim = 2)
        return labels

In [8]:
class AbdominalClassfierOne(nn.Module):
    def __init__(self, n_output, device = DEVICE):
        super().__init__()
        self.device = device        
        self.n_output = 2
        self.model3d        = Timm3DModelClassifierEmbed(backbone, 1, 4)
        self.head           = nn.Linear(16, n_output)

class AbdominalClassifierAtt(nn.Module):
    def __init__(self, device = DEVICE):
        super().__init__()
        self.device = device
        
        self.model3d_bowel        = Timm3DModelClassifierEmbed(backbone, 1, 4)      
        self.model3d_extrav       = Timm3DModelClassifierEmbed(backbone, 1, 4)
        self.model3d_kidney_left  = Timm3DModelClassifierEmbed(backbone, 1, 4)
        self.model3d_kidney_right = Timm3DModelClassifierEmbed(backbone, 1, 4)
        self.model3d_liver        = Timm3DModelClassifierEmbed(backbone, 1, 4)
        self.model3d_spleen       = Timm3DModelClassifierEmbed(backbone, 1, 4)
        
        self.flatten  = nn.Flatten()
        self.dropout  = nn.Dropout(p=0.5)
        self.softmax  = nn.Softmax(dim=1)        
        self.maxpool  = nn.MaxPool1d(5, 1)
        
        self.lstm = nn.LSTM(input_size = 16, hidden_size = 5, num_layers=1, batch_first=True, bidirectional=True)
        self.attentions = nn.ModuleList()
        #for i in range(0, 3):
        #    self.attentions.append(nn.MultiheadAttention(20, 1, dropout=0.1, batch_first=True))
        
        self.head = nn.Linear(10, 13)
        
    def forward(self, x_bowel, x_kidney_left, x_kidney_right, x_liver, x_spleen, x_total):
        bs = x_bowel.shape[0]
        bowel_emb        = torch.reshape(self.model3d_bowel(x_bowel), (bs, 1, 16))
        extrav_emb       = torch.reshape(self.model3d_extrav(x_total), (bs, 1, 16))
        kidney_left_emb  = torch.reshape(self.model3d_kidney_left(x_kidney_left), (bs, 1, 16))
        kidney_right_emb = torch.reshape(self.model3d_kidney_right(x_kidney_right), (bs, 1, 16))
        liver_emb        = torch.reshape(self.model3d_liver(x_liver), (bs, 1, 16))
        spleen_emb       = torch.reshape(self.model3d_spleen(x_spleen), (bs, 1, 16))
        
        all_embs = torch.cat([bowel_emb, extrav_emb, kidney_left_emb, kidney_right_emb, liver_emb, spleen_emb], dim = 1)
        #for i in range(0, 3):
        #    all_embs, _ = self.attentions[i](all_embs, all_embs, all_embs)

        all_embs = self.lstm(all_embs)
        #labels = nn.Flatten()(all_embs[0])
        labels   = torch.mean(all_embs[0], dim = 1) 
        labels   = self.head(labels)

        bowel_soft = self.softmax(labels[:,:2])
        extrav_soft = self.softmax(labels[:,2:4])
        kidney_soft = self.softmax(labels[:,4:7])
        liver_soft = self.softmax(labels[:,7:10])
        spleen_soft = self.softmax(labels[:,10:13])

        any_in = torch.cat([1-bowel_soft[:,0:1], 1-extrav_soft[:,0:1], 
                            1-kidney_soft[:,0:1], 1-liver_soft[:,0:1], 1-spleen_soft[:,0:1]], dim = 1) 
        any_in = self.maxpool(any_in)
        any_not_in = 1-any_in
        any_in = torch.cat([any_not_in, any_in], dim = 1)

        #any_in = torch.log(any_in + 1e-6)  # 1e-6은 0을 처리하기 위한 작은 값
        return labels, any_in
    

In [9]:
model = AbdominalClassifierAtt()

def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp

print(get_n_params(model))
del model
gc.collect()

86487919


0

# Metric & Loss

In [10]:
weights = np.ones(2)
weights[1] = 2
crit_bowel  = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
weights[1] = 6
crit_extrav = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
crit_any = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))

weights = np.ones((3))
weights[1] = 2
weights[2] = 4
crit_kidney = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
crit_liver  = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
crit_spleen = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))

In [11]:
def normalize_to_one(tensor):
    norm = torch.sum(tensor, 1)
    for i in range(0, tensor.shape[1]):
        tensor[:,i]/=norm
    return tensor

def apply_softmax_to_labels(X_out):
    softmax = nn.Softmax(dim=1)

    X_out[:,:2]    = normalize_to_one(softmax(X_out[:,:2]))
    X_out[:,2:4]   = normalize_to_one(softmax(X_out[:,2:4]))
    X_out[:,4:7]   = normalize_to_one(softmax(X_out[:,4:7]))
    X_out[:,7:10]  = normalize_to_one(softmax(X_out[:,7:10]))
    X_out[:,10:13] = normalize_to_one(softmax(X_out[:,10:13]))

    return X_out

def calculate_score(X_outs, ys, step = 'train'):
    X_outs = X_outs.astype(np.float64)
    ys     = ys.astype(np.float64)

    isnan_x = np.isnan(X_outs).astype(int)
    isnan_y = np.isnan(ys).astype(int)
    
    if(np.max(isnan_x)>0):
        print('xnan')
    if(np.max(isnan_y)>0):
        print('ynan')
        
    #X_outs[:, 13:15] = nn.Softmax(dim=1)(torch.from_numpy(X_outs[:, 13:15])).numpy()
    bowel_weights  =  ys[:,0] + 2*ys[:,1]
    extrav_weights = ys[:,2] + 6*ys[:,3]
    kidney_weights = ys[:,4] + 2*ys[:,5] + 4*ys[:,6]
    liver_weights  = ys[:,7] + 2*ys[:,8] + 4*ys[:,9]
    spleen_weights = ys[:,10] + 2*ys[:,11] + 4*ys[:,12]
    any_in_weights = ys[:,13] + 6*ys[:,14]

    bowel_loss  = sklearn.metrics.log_loss(ys[:,:2], X_outs[:,:2], sample_weight = bowel_weights.astype(np.float64))
    extrav_loss = sklearn.metrics.log_loss(ys[:,2:4], X_outs[:,2:4], sample_weight = extrav_weights.astype(np.float64))
    kidney_loss = sklearn.metrics.log_loss(ys[:,4:7], X_outs[:,4:7], sample_weight = kidney_weights.astype(np.float64))
    liver_loss  = sklearn.metrics.log_loss(ys[:,7:10], X_outs[:,7:10], sample_weight = liver_weights.astype(np.float64))
    spleen_loss = sklearn.metrics.log_loss(ys[:,10:13], X_outs[:,10:13], sample_weight = spleen_weights.astype(np.float64))
    any_in_loss = sklearn.metrics.log_loss(ys[:,13:15], X_outs[:,13:15], sample_weight =  any_in_weights.astype(np.float64))
    
    avg_loss = (bowel_loss + extrav_loss + kidney_loss + liver_loss + spleen_loss + any_in_loss)/6

    losses= {f'{step}_bowel_metric': bowel_loss, f'{step}_extrav_metric': extrav_loss, f'{step}_kidney_metric': kidney_loss,
             f'{step}_liver_metric': liver_loss, f'{step}_spleen_metric': spleen_loss, f'{step}_any_in_metric': any_in_loss,
             f'{step}_avg_metric': avg_loss}

    wandb.log(losses)
    return avg_loss

def calculate_loss(X_out, X_any, y):
    batch_size = X_out.shape[0]
    bowel_loss  = crit_bowel(X_out[:,:2], y[:,:2])
    extrav_loss = crit_extrav(X_out[:,2:4], y[:,2:4])
    kidney_loss = crit_kidney(X_out[:,4:7], y[:,4:7])
    liver_loss  = crit_liver(X_out[:,7:10], y[:,7:10])
    spleen_loss = crit_spleen(X_out[:,10:13], y[:,10:13])
    any_in_loss = crit_any(X_any,  torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1))
    
    avg_loss = (bowel_loss + extrav_loss + kidney_loss + liver_loss + spleen_loss + any_in_loss)/6
    return bowel_loss, extrav_loss, kidney_loss, liver_loss, spleen_loss, any_in_loss, avg_loss

# Augmentations

In [12]:
def mixup(inputs, truth, clip=[0, 1]):
    indices = torch.randperm(inputs.size(0))
    shuffled_input = inputs[indices]
    shuffled_labels = truth[indices]

    lam = np.random.uniform(clip[0], clip[1])
    inputs = inputs * lam + shuffled_input * (1 - lam)
    return inputs, truth, shuffled_labels, lam

transforms_train = transforms.Compose([
    transforms.RandFlipd(keys=chan_keys, prob=0.5, spatial_axis=0),    
    transforms.RandFlipd(keys=chan_keys, prob=0.5, spatial_axis=1),
    transforms.RandFlipd(keys=chan_keys, prob=0.5, spatial_axis=2),
    #transforms.RandAffined(keys=chan_keys, translate_range=[int(x*y) for x, y in zip([RESOL, RESOL, RESOL], [0.3, 0.3, 0.3])], padding_mode='zeros', prob=0.7),
    transforms.RandGridDistortiond(keys=chan_keys, prob=0.5, distort_limit=(-0.01, 0.01), mode="nearest"),    
])

remain_transforms_train = transforms.Compose([
    #transforms.RandCoarseDropout(holes = DROP_REGION['HOLES'][0], max_holes = DROP_REGION['HOLES'][1],
    #                        spatial_size = DROP_REGION['SIZE'][0]*np.ones(3, int), max_spatial_size =DROP_REGION['SIZE'][1]*np.ones(3, int), 
    #                        prob = DROP_REGION['PROB'], 
    #                        fill_value = DROP_REGION['FILL'])
])



transforms_common_preprocessing = transforms.Compose([
    #transforms.HistogramNormalize(num_bins = 256, min = 0, max = 255)
])

# Dataset

In [13]:
class AbdominalCTDataset(Dataset):
    def __init__(self, meta_df, is_train = True, transform_set = None, remain_transforms_set = None):
        self.meta_df = meta_df
        self.is_train = is_train
        self.transform_set = transform_set
        self.remain_transforms_set = remain_transforms_set
        self.data_3ds = []        
        for i in tqdm(range(0, len(self.meta_df))):
            tmp_data_3ds = {}
            base_name = self.meta_df.iloc[i]['cropped_path']            
            for j in range(0, 6):
                tmp_data_3d = decompress_fast(f'{base_name}_{j}').unsqueeze(0)
                #tmp_data_3d = torch.from_numpy(tmp_data_3d)
                tmp_data_3ds[chan_dict[j]] = tmp_data_3d            
            self.data_3ds.append(tmp_data_3ds)

    def __len__(self):
        return len(self.meta_df)
    
    def __getitem__(self, idx):
        row = self.meta_df.iloc[idx]
        label = row[['bowel_healthy','bowel_injury',
                    'extravasation_healthy','extravasation_injury',
                    'kidney_healthy','kidney_low','kidney_high',
                    'liver_healthy','liver_low','liver_high',
                    'spleen_healthy','spleen_low','spleen_high', 'any_injury']]
        
        data_3d = self.data_3ds[idx].copy()
        
        if self.is_train:
            if self.transform_set is not None:
                data_3d = self.transform_set(data_3d)

            if self.remain_transforms_set is not None:   
                for i in range(0, 6):
                    data_3d[chan_dict[i]] = self.remain_transforms_set(data_3d[chan_dict[i]])
        
        label = label.to_numpy().astype(np.float32)                    
        label = torch.from_numpy(label)
                    
        return data_3d['bowel'], data_3d['left_kidney'], data_3d['right_kidney'], \
                data_3d['liver'], data_3d['spleen'], data_3d['total'], label        


In [14]:
#data_3d= torch.rand((6, 128, 128, 128))*0.5
#data_3d = remain_transforms_train(data_3d)
#torch.max(data_3d)
#print(data_3d)

# Train loop

In [15]:
def train_func(model, train_loader, scaler, scheduler, optimizer, epoch, accum_points, accum_scale):
    train_meters = {'loss': AverageMeter()}
    model.train()
    X_outs=[]
    ys=[]
    accum_counter = 0
    counter = 0
    optimizer.zero_grad()
    for X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, y in train_loader:
        X_bowel, X_lkid, X_rkid = X_bowel.to(DEVICE), X_lkid.to(DEVICE), X_rkid.to(DEVICE)
        X_liv, X_spl, X_tot     = X_liv.to(DEVICE), X_spl.to(DEVICE), X_tot.to(DEVICE)
        y = y.to(DEVICE)
        current_lr = float(scheduler.get_last_lr()[0])
        
        batch_size = X_bowel.shape[0]
        with torch.cuda.amp.autocast(enabled=True):  
            X_out, X_any  = model(X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot)
            bowel_loss, extrav_loss, kidney_loss, liver_loss, spleen_loss, any_in_loss, avg_loss = calculate_loss(X_out, X_any, y)
                
            step = 'train'
            wandb.log({ 'lr': current_lr,
                        f'{step}_bowel_loss': bowel_loss.item(),
                        f'{step}_extrav_loss': extrav_loss.item(),
                        f'{step}_kidney_loss': kidney_loss.item(),
                        f'{step}_liver_loss': liver_loss.item(),
                        f'{step}_spleen_loss': spleen_loss.item(),
                        f'{step}_any_loss': any_in_loss.item(),
                        f'{step}_avg_loss': avg_loss.item()
                        })
            
            scaler.scale(avg_loss/accum_scale[accum_counter]).backward()
            if(counter==accum_points[accum_counter]):
                scaler.step(optimizer)
                scheduler.step()
                scaler.update()    
                accum_counter+=1                
                optimizer.zero_grad()
        counter+=1

        #Metric calculation
        y_any = torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1)    
        X_out = apply_softmax_to_labels(X_out).detach().to('cpu').numpy()
        X_any = X_any.detach().to('cpu').numpy()
        X_out = np.hstack([X_out, X_any])
        X_outs.append(X_out)

        y     = y.to('cpu').numpy()[:,:-1]
        y_any = y_any.to('cpu').numpy()
        y     = np.hstack([y, y_any])
        ys.append(y)

        trn_loss = avg_loss.item()      
        train_meters['loss'].update(trn_loss, n=batch_size)     
        #pbar.set_description(f'Train loss: {trn_loss}')   
        
        
    print('Epoch {:d} / trn/loss={:.4f}'.format(epoch+1, train_meters['loss'].avg))    

    X_outs = np.vstack(X_outs) 
    ys     = np.vstack(ys)
    metric = calculate_score(X_outs, ys, 'train')                 
    print('Epoch {:d} / train/metric={:.4f}'.format(epoch+1, metric))   

    del X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, X_outs, y, ys, X_any
    gc.collect()
    torch.cuda.empty_cache()    
    return scheduler, scaler, optimizer


def valid_func(model, valid_loader, epoch):
    X_outs=[]
    ys=[]
    model.eval()
    for X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, y in valid_loader:
        batch_size = y.shape[0]
        X_bowel, X_lkid, X_rkid = X_bowel.to(DEVICE), X_lkid.to(DEVICE), X_rkid.to(DEVICE)
        X_liv, X_spl, X_tot     = X_liv.to(DEVICE), X_spl.to(DEVICE), X_tot.to(DEVICE)
        y = y.to(DEVICE)           
        with torch.cuda.amp.autocast(enabled=True):                
            with torch.no_grad():                 
                X_out, X_any = model(X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot)                                          
                y_any = torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1)              
                X_out = apply_softmax_to_labels(X_out).to('cpu').numpy()

                X_any = X_any.to('cpu').numpy()
                X_out = np.hstack([X_out, X_any])
                X_outs.append(X_out)

                y     = y.to('cpu').numpy()[:,:-1]
                y_any = y_any.to('cpu').numpy()
                y     = np.hstack([y, y_any])
                ys.append(y)

    X_outs = np.vstack(X_outs) 
    ys     = np.vstack(ys)
    metric = calculate_score(X_outs, ys, 'valid')                
    print('Epoch {:d} / val/metric={:.4f}'.format(epoch+1, metric))           
    
    del X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, X_outs, y, ys, X_any
    gc.collect()        
    torch.cuda.empty_cache()   
    return metric 

In [16]:
model = AbdominalClassifierAtt()
model.to(DEVICE)

wandb.init(
    config = wandb_config,
    project= PROJECT_NAME,
    group  = GROUP_NAME,
    name   = RUN_NAME,
    dir    = BASE_PATH)

backbone = backbone.replace('/', '_')

if __name__ == '__main__':
    train_dataset = AbdominalCTDataset(train_meta_df[train_meta_df['fold']!=SELECT_FOLD], is_train = True, transform_set  = transforms_train, 
                                        remain_transforms_set = remain_transforms_train)
    valid_dataset = AbdominalCTDataset(train_meta_df[train_meta_df['fold']==SELECT_FOLD], is_train = False, transform_set = None,
                                        remain_transforms_set = None)
    
    train_loader = DataLoader(dataset = train_dataset, shuffle = True, batch_size = BATCH_SIZE, pin_memory = False, 
                            num_workers = N_WORKERS, drop_last = False)

    valid_loader = DataLoader(dataset = valid_dataset, shuffle = False, batch_size = BATCH_SIZE, pin_memory = False, 
                            num_workers = N_WORKERS, drop_last = False)     
    
    ttl_iters = N_EPOCHS * len(train_loader)
    
    #gradient accumulation for stability of the training
    accum_len = int(np.ceil(len(train_loader)/ACCUM_STEPS)+0.001)
    accum_points = np.zeros(accum_len, int)
    accum_scale  = np.zeros(accum_len, int)
    
    prev_step = -1
    for i in range(0, accum_len):
        accum_points[i] = min(prev_step+ACCUM_STEPS, len(train_loader)-1)
        accum_scale[i]  = accum_points[i] - prev_step
        prev_step = accum_points[i]

    #Scheduler & optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr = LR)
    n_batch_iters = int(np.ceil(len(train_loader)/ACCUM_STEPS)+0.001)
    #scheduler = CosineAnnealingLR(optimizer, T_max=ttl_iters, eta_min=1e-6)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LR, pct_start= PCT_START,
                                                    steps_per_epoch= n_batch_iters, epochs = N_EPOCHS)

    scaler = torch.cuda.amp.GradScaler(enabled=True)
    val_metrics = np.ones(N_EPOCHS)*100

    gc.collect()

    for epoch in tqdm(range(0, N_EPOCHS), leave = False):     

        scheduler, scaler, optimizer = train_func(model, train_loader, scaler, scheduler, optimizer, epoch, accum_points, accum_scale)
        metric                       = valid_func(model, valid_loader, epoch)
        
        #Save the best model    
        if(metric < np.min(val_metrics)):
            try:
                os.makedirs(f'{BASE_PATH}/weights')
            except:
                a = 1
            best_metric = metric
            print(f'Best val_metric {best_metric} at epoch {epoch+1}!')
            torch.save(model.state_dict(), f'{BASE_PATH}/weights/{RUN_NAME}_best.pt') 
            if(metric < 0.48):
                torch.save(model.state_dict(), f'{BASE_PATH}/weights/{RUN_NAME}_{metric}.pt') 
            not_improve_counter=0
            val_metrics[epoch] = metric
            continue                    
        val_metrics[epoch] = metric                        
        
        #Early stopping
        not_improve_counter+=1
        if(not_improve_counter == EARLY_STOP_COUNT):
            print(f'Not improved for {not_improve_counter} epochs, terminate the train')
            break
wandb.log({'best_total_log_loss': best_metric})
wandb.finish()

100%|██████████| 3764/3764 [01:21<00:00, 46.37it/s]
100%|██████████| 947/947 [00:31<00:00, 29.91it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch 1 / trn/loss=1.1211
Epoch 1 / train/metric=0.8570
Epoch 1 / val/metric=0.8526
Best val_metric 0.8525582238853167 at epoch 1!


  0%|          | 1/200 [04:27<14:45:38, 267.03s/it]

Epoch 2 / trn/loss=1.1037
Epoch 2 / train/metric=0.8414
Epoch 2 / val/metric=0.8378
Best val_metric 0.8378091142892287 at epoch 2!


  1%|          | 2/200 [08:40<14:15:03, 259.11s/it]

Epoch 3 / trn/loss=1.0871
Epoch 3 / train/metric=0.8264
Epoch 3 / val/metric=0.8238
Best val_metric 0.8238412527865989 at epoch 3!


  2%|▏         | 3/200 [12:54<14:02:46, 256.69s/it]

Epoch 4 / trn/loss=1.0705
Epoch 4 / train/metric=0.8114
Epoch 4 / val/metric=0.8102
Best val_metric 0.8102208053500474 at epoch 4!


  2%|▏         | 4/200 [17:08<13:55:26, 255.75s/it]

Epoch 5 / trn/loss=1.0540
Epoch 5 / train/metric=0.7964
Epoch 5 / val/metric=0.7961
Best val_metric 0.796060440792108 at epoch 5!


  2%|▎         | 5/200 [21:23<13:49:55, 255.36s/it]

Epoch 6 / trn/loss=1.0373
Epoch 6 / train/metric=0.7812
Epoch 6 / val/metric=0.7818
Best val_metric 0.7817543682883193 at epoch 6!


  3%|▎         | 6/200 [25:38<13:45:07, 255.19s/it]

Epoch 7 / trn/loss=1.0199
Epoch 7 / train/metric=0.7653
Epoch 7 / val/metric=0.7674
Best val_metric 0.7674328280465524 at epoch 7!


  4%|▎         | 7/200 [29:53<13:40:51, 255.19s/it]

Epoch 8 / trn/loss=1.0013
Epoch 8 / train/metric=0.7483
Epoch 8 / val/metric=0.7507
Best val_metric 0.7507354857750252 at epoch 8!


  4%|▍         | 8/200 [34:08<13:36:30, 255.16s/it]

Epoch 9 / trn/loss=0.9814
Epoch 9 / train/metric=0.7300
Epoch 9 / val/metric=0.7338
Best val_metric 0.7338251310510827 at epoch 9!


  4%|▍         | 9/200 [38:24<13:32:41, 255.30s/it]

Epoch 10 / trn/loss=0.9610
Epoch 10 / train/metric=0.7112
Epoch 10 / val/metric=0.7175
Best val_metric 0.7174716062950419 at epoch 10!


  5%|▌         | 10/200 [42:39<13:28:29, 255.31s/it]

Epoch 11 / trn/loss=0.9417
Epoch 11 / train/metric=0.6933
Epoch 11 / val/metric=0.7012
Best val_metric 0.7012419469661372 at epoch 11!


  6%|▌         | 11/200 [46:54<13:24:16, 255.32s/it]

Epoch 12 / trn/loss=0.9238
Epoch 12 / train/metric=0.6767
Epoch 12 / val/metric=0.6871
Best val_metric 0.6870897699135862 at epoch 12!


  6%|▌         | 12/200 [51:10<13:20:22, 255.44s/it]

Epoch 13 / trn/loss=0.9079
Epoch 13 / train/metric=0.6618
Epoch 13 / val/metric=0.6738
Best val_metric 0.6737572958464385 at epoch 13!


  6%|▋         | 13/200 [55:26<13:16:15, 255.48s/it]

Epoch 14 / trn/loss=0.8935
Epoch 14 / train/metric=0.6484
Epoch 14 / val/metric=0.6623
Best val_metric 0.6623159951681767 at epoch 14!


  7%|▋         | 14/200 [59:41<13:11:46, 255.41s/it]

Epoch 15 / trn/loss=0.8809
Epoch 15 / train/metric=0.6366
Epoch 15 / val/metric=0.6523
Best val_metric 0.6522685125883952 at epoch 15!


  8%|▊         | 15/200 [1:03:56<13:07:42, 255.47s/it]

Epoch 16 / trn/loss=0.8699
Epoch 16 / train/metric=0.6264
Epoch 16 / val/metric=0.6432
Best val_metric 0.6432121327101602 at epoch 16!


  8%|▊         | 16/200 [1:08:12<13:03:14, 255.40s/it]

Epoch 17 / trn/loss=0.8603
Epoch 17 / train/metric=0.6175
Epoch 17 / val/metric=0.6357
Best val_metric 0.6357092970107717 at epoch 17!


  8%|▊         | 17/200 [1:12:28<12:59:34, 255.60s/it]

Epoch 18 / trn/loss=0.8518
Epoch 18 / train/metric=0.6097
Epoch 18 / val/metric=0.6290
Best val_metric 0.6289614526094399 at epoch 18!


  9%|▉         | 18/200 [1:16:43<12:55:25, 255.63s/it]

Epoch 19 / trn/loss=0.8433
Epoch 19 / train/metric=0.6020
Epoch 19 / val/metric=0.6224
Best val_metric 0.6223630507037679 at epoch 19!


 10%|▉         | 19/200 [1:20:59<12:51:04, 255.60s/it]

Epoch 20 / trn/loss=0.8342
Epoch 20 / train/metric=0.5940
Epoch 20 / val/metric=0.6157
Best val_metric 0.6157392525140568 at epoch 20!


 10%|█         | 20/200 [1:25:15<12:47:05, 255.70s/it]

Epoch 21 / trn/loss=0.8242
Epoch 21 / train/metric=0.5854
Epoch 21 / val/metric=0.6072
Best val_metric 0.6071598592046724 at epoch 21!


 10%|█         | 21/200 [1:29:31<12:42:52, 255.71s/it]

Epoch 22 / trn/loss=0.8151
Epoch 22 / train/metric=0.5776
Epoch 22 / val/metric=0.5996
Best val_metric 0.5995951947765402 at epoch 22!


 11%|█         | 22/200 [1:33:46<12:38:24, 255.64s/it]

Epoch 23 / trn/loss=0.8081
Epoch 23 / train/metric=0.5716
Epoch 23 / val/metric=0.5942
Best val_metric 0.5942057203451906 at epoch 23!


 12%|█▏        | 23/200 [1:38:02<12:34:25, 255.74s/it]

Epoch 24 / trn/loss=0.8007
Epoch 24 / train/metric=0.5654
Epoch 24 / val/metric=0.5901
Best val_metric 0.5900824056076225 at epoch 24!


 12%|█▏        | 24/200 [1:42:17<12:29:34, 255.54s/it]

Epoch 25 / trn/loss=0.7952
Epoch 25 / train/metric=0.5607


 12%|█▎        | 25/200 [1:46:32<12:24:58, 255.42s/it]

Epoch 25 / val/metric=0.5924
Epoch 26 / trn/loss=0.7896
Epoch 26 / train/metric=0.5562
Epoch 26 / val/metric=0.5858
Best val_metric 0.5858027399032136 at epoch 26!


 13%|█▎        | 26/200 [1:50:48<12:21:08, 255.57s/it]

Epoch 27 / trn/loss=0.7835
Epoch 27 / train/metric=0.5511


 14%|█▎        | 27/200 [1:55:03<12:16:19, 255.37s/it]

Epoch 27 / val/metric=0.5898
Epoch 28 / trn/loss=0.7812
Epoch 28 / train/metric=0.5493
Epoch 28 / val/metric=0.5766
Best val_metric 0.5766140651936785 at epoch 28!


 14%|█▍        | 28/200 [1:59:19<12:12:43, 255.60s/it]

Epoch 29 / trn/loss=0.7783
Epoch 29 / train/metric=0.5470
Epoch 29 / val/metric=0.5761
Best val_metric 0.5760663682657684 at epoch 29!


 14%|█▍        | 29/200 [2:03:35<12:08:14, 255.53s/it]

Epoch 30 / trn/loss=0.7727
Epoch 30 / train/metric=0.5424
Epoch 30 / val/metric=0.5733
Best val_metric 0.5732791642456657 at epoch 30!


 15%|█▌        | 30/200 [2:07:50<12:03:57, 255.51s/it]

Epoch 31 / trn/loss=0.7702
Epoch 31 / train/metric=0.5405


 16%|█▌        | 31/200 [2:12:05<11:59:30, 255.45s/it]

Epoch 31 / val/metric=0.5766
Epoch 32 / trn/loss=0.7677
Epoch 32 / train/metric=0.5388
Epoch 32 / val/metric=0.5655
Best val_metric 0.5655004775980504 at epoch 32!


 16%|█▌        | 32/200 [2:16:21<11:55:26, 255.52s/it]

Epoch 33 / trn/loss=0.7633
Epoch 33 / train/metric=0.5353


 16%|█▋        | 33/200 [2:20:36<11:50:59, 255.45s/it]

Epoch 33 / val/metric=0.5665
Epoch 34 / trn/loss=0.7646
Epoch 34 / train/metric=0.5365


 17%|█▋        | 34/200 [2:24:52<11:46:29, 255.36s/it]

Epoch 34 / val/metric=0.5737
Epoch 35 / trn/loss=0.7616
Epoch 35 / train/metric=0.5345
Epoch 35 / val/metric=0.5598
Best val_metric 0.5597974429679416 at epoch 35!


 18%|█▊        | 35/200 [2:29:07<11:42:42, 255.53s/it]

Epoch 36 / trn/loss=0.7590
Epoch 36 / train/metric=0.5326
Epoch 36 / val/metric=0.5596
Best val_metric 0.5595594393245845 at epoch 36!


 18%|█▊        | 36/200 [2:33:23<11:38:33, 255.57s/it]

Epoch 37 / trn/loss=0.7555
Epoch 37 / train/metric=0.5298


 18%|█▊        | 37/200 [2:37:38<11:33:56, 255.44s/it]

Epoch 37 / val/metric=0.5845
Epoch 38 / trn/loss=0.7493
Epoch 38 / train/metric=0.5249
Epoch 38 / val/metric=0.5540
Best val_metric 0.5540229820712169 at epoch 38!


 19%|█▉        | 38/200 [2:41:54<11:30:03, 255.58s/it]

Epoch 39 / trn/loss=0.7472
Epoch 39 / train/metric=0.5233


 20%|█▉        | 39/200 [2:46:10<11:25:41, 255.54s/it]

Epoch 39 / val/metric=0.5553
Epoch 40 / trn/loss=0.7449
Epoch 40 / train/metric=0.5224


 20%|██        | 40/200 [2:50:24<11:20:52, 255.33s/it]

Epoch 40 / val/metric=0.5562
Epoch 41 / trn/loss=0.7427
Epoch 41 / train/metric=0.5203
Epoch 41 / val/metric=0.5466
Best val_metric 0.5466247643753431 at epoch 41!


 20%|██        | 41/200 [2:54:40<11:16:54, 255.44s/it]

Epoch 42 / trn/loss=0.7407
Epoch 42 / train/metric=0.5194


 21%|██        | 42/200 [2:58:54<11:11:47, 255.11s/it]

Epoch 42 / val/metric=0.5540
Epoch 43 / trn/loss=0.7439
Epoch 43 / train/metric=0.5226


 22%|██▏       | 43/200 [3:03:10<11:07:44, 255.19s/it]

Epoch 43 / val/metric=0.5574
Epoch 44 / trn/loss=0.7374
Epoch 44 / train/metric=0.5171


 22%|██▏       | 44/200 [3:07:25<11:03:23, 255.15s/it]

Epoch 44 / val/metric=0.5514
Epoch 45 / trn/loss=0.7324
Epoch 45 / train/metric=0.5134


 22%|██▎       | 45/200 [3:11:41<10:59:30, 255.29s/it]

Epoch 45 / val/metric=0.5599
Epoch 46 / trn/loss=0.7324
Epoch 46 / train/metric=0.5147


 23%|██▎       | 46/200 [3:15:56<10:55:24, 255.35s/it]

Epoch 46 / val/metric=0.5686
Epoch 47 / trn/loss=0.7315
Epoch 47 / train/metric=0.5138


 24%|██▎       | 47/200 [3:20:11<10:50:57, 255.28s/it]

Epoch 47 / val/metric=0.5615
Epoch 48 / trn/loss=0.7275
Epoch 48 / train/metric=0.5104
Epoch 48 / val/metric=0.5400
Best val_metric 0.5400357538384696 at epoch 48!


 24%|██▍       | 48/200 [3:24:27<10:47:07, 255.44s/it]

Epoch 49 / trn/loss=0.7262
Epoch 49 / train/metric=0.5097


 24%|██▍       | 49/200 [3:28:43<10:42:56, 255.48s/it]

Epoch 49 / val/metric=0.5516
Epoch 50 / trn/loss=0.7222
Epoch 50 / train/metric=0.5066


 25%|██▌       | 50/200 [3:32:58<10:38:18, 255.33s/it]

Epoch 50 / val/metric=0.5495
Epoch 51 / trn/loss=0.7212
Epoch 51 / train/metric=0.5059


 26%|██▌       | 51/200 [3:37:13<10:34:15, 255.40s/it]

Epoch 51 / val/metric=0.5437
Epoch 52 / trn/loss=0.7139
Epoch 52 / train/metric=0.5007


 26%|██▌       | 52/200 [3:41:28<10:29:43, 255.29s/it]

Epoch 52 / val/metric=0.5586
Epoch 53 / trn/loss=0.7249
Epoch 53 / train/metric=0.5102


 26%|██▋       | 53/200 [3:45:43<10:25:28, 255.30s/it]

Epoch 53 / val/metric=0.5475
Epoch 54 / trn/loss=0.7180
Epoch 54 / train/metric=0.5036


 27%|██▋       | 54/200 [3:49:59<10:21:16, 255.32s/it]

Epoch 54 / val/metric=0.5438
Epoch 55 / trn/loss=0.7141
Epoch 55 / train/metric=0.5011


 28%|██▊       | 55/200 [3:54:14<10:16:56, 255.28s/it]

Epoch 55 / val/metric=0.5539
Epoch 56 / trn/loss=0.7073
Epoch 56 / train/metric=0.4951
Epoch 56 / val/metric=0.5363
Best val_metric 0.5363175804854731 at epoch 56!


 28%|██▊       | 56/200 [3:58:30<10:13:07, 255.47s/it]

Epoch 57 / trn/loss=0.7108
Epoch 57 / train/metric=0.4988


 28%|██▊       | 57/200 [4:02:45<10:08:28, 255.30s/it]

Epoch 57 / val/metric=0.5829
Epoch 58 / trn/loss=0.7054
Epoch 58 / train/metric=0.4940
Epoch 58 / val/metric=0.5343
Best val_metric 0.5343474145210495 at epoch 58!


 29%|██▉       | 58/200 [4:07:00<10:04:24, 255.38s/it]

Epoch 59 / trn/loss=0.7012
Epoch 59 / train/metric=0.4904


 30%|██▉       | 59/200 [4:11:16<10:00:12, 255.41s/it]

Epoch 59 / val/metric=0.5363
Epoch 60 / trn/loss=0.7028
Epoch 60 / train/metric=0.4927


 30%|███       | 60/200 [4:15:31<9:55:57, 255.41s/it] 

Epoch 60 / val/metric=0.5738
Epoch 61 / trn/loss=0.6988
Epoch 61 / train/metric=0.4885


 30%|███       | 61/200 [4:19:46<9:51:28, 255.31s/it]

Epoch 61 / val/metric=0.5404
Epoch 62 / trn/loss=0.7006
Epoch 62 / train/metric=0.4902


 31%|███       | 62/200 [4:24:02<9:47:22, 255.38s/it]

Epoch 62 / val/metric=0.5667
Epoch 63 / trn/loss=0.6981
Epoch 63 / train/metric=0.4888


 32%|███▏      | 63/200 [4:28:17<9:43:04, 255.36s/it]

Epoch 63 / val/metric=0.5408
Epoch 64 / trn/loss=0.6942
Epoch 64 / train/metric=0.4840


 32%|███▏      | 64/200 [4:32:32<9:38:35, 255.26s/it]

Epoch 64 / val/metric=0.5440
Epoch 65 / trn/loss=0.6961
Epoch 65 / train/metric=0.4869


 32%|███▎      | 65/200 [4:36:47<9:34:09, 255.18s/it]

Epoch 65 / val/metric=0.5367


In [None]:
#Execute this cell to fininsh the wandb run when you stopped training.
import wandb
try: 
    wandb.log(
        {'best_total_log_loss': best_metric})
    wandb.finish()
    
except: 
    print('Wandb is already finished!')

In [None]:
ind = 23561
train_meta_df[train_meta_df['patient_id']==ind]

In [None]:
valid_dataset = AbdominalCTDataset(train_meta_df[train_meta_df['patient_id']==ind], is_train = False, transform_set = None,
                                    remain_transforms_set = None)

valid_loader = DataLoader(dataset = valid_dataset, shuffle = False, batch_size = BATCH_SIZE, pin_memory = False, 
                        num_workers = N_WORKERS, drop_last = False)     

In [None]:
X_outs=[]
ys=[]
model.eval()
model.load_state_dict(torch.load(f'{BASE_PATH}/weights/231003_resnet10t_dicomO_std_CV0.485.pt'))
for X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, y in valid_loader:
    batch_size = y.shape[0]
    X_bowel, X_lkid, X_rkid = X_bowel.to(DEVICE), X_lkid.to(DEVICE), X_rkid.to(DEVICE)
    X_liv, X_spl, X_tot     = X_liv.to(DEVICE), X_spl.to(DEVICE), X_tot.to(DEVICE)
    y = y.to(DEVICE)           
    with torch.cuda.amp.autocast(enabled=True):                
        with torch.no_grad():                 
            X_out, X_any = model(X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot)                                          
            y_any = torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1)              
            X_out = apply_softmax_to_labels(X_out).to('cpu').numpy()

            X_any = X_any.to('cpu').numpy()
            X_out = np.hstack([X_out, X_any])
            X_outs.append(X_out)

            y     = y.to('cpu').numpy()[:,:-1]
            y_any = y_any.to('cpu').numpy()
            y     = np.hstack([y, y_any])
            ys.append(y)

X_outs = np.vstack(X_outs) 
ys     = np.vstack(ys)
#metric = calculate_score(X_outs, ys, 'valid')                      

del X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, X_any
gc.collect()        
torch.cuda.empty_cache()   

In [None]:
np.average(X_outs, axis = 0)


In [None]:
len(X_outs)