In [1]:
import time
from tqdm import tqdm
import sys
import glob
import gc
import os
sys.path.append('./lib_models')


#sys.path.append('')

import pandas as pd
import numpy as np
import scipy as sp
import cv2
from PIL import Image
from matplotlib import pyplot as plt
import sklearn.metrics
import warnings
import pydicom
import dicomsdl
from joblib import Parallel, delayed
import pickle
import gzip
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from multiprocessing import Pool
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch import nn
from torchvision.io import read_image
import segmentation_models_pytorch as smp
import timm
from timm.utils import AverageMeter
from timm.models import resnet
import timm_new

from monai.transforms import Resize
import  monai.transforms as transforms

from timm.models.layers.conv2d_same import Conv2dSame
from conv3d_same import Conv3dSame
import wandb

from FixRes.imnet_extract.pnasnet import pnasnet5large

wandb.login(key = '585f58f321685308f7933861d9dde7488de0970b')

  from .autonotebook import tqdm as notebook_tqdm
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjunseonglee[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/junseonglee/.netrc


True

# Parameters

In [2]:
backbone = 'timm/resnet10t.c3_in1k'

IS_WANDB = True
PROJECT_NAME = 'RSNA_ABTD'
GROUP_NAME= 'momdel_test'
RUN_NAME=   f'{backbone}_Seg3d(Load)XEffNetb0-2d_1/4AvgPoolshrink'

if not IS_WANDB:
    PROJECT_NAME = 'Dummy_Project'

BASE_PATH  = '/home/junseonglee/Desktop/01_codes/inputs/rsna-2023-abdominal-trauma-detection'
TRAIN_PATH = f'{BASE_PATH}/train_images'
DATA_PATH = f'{BASE_PATH}/3d_preprocessed'

seg_inference_dir = f'{BASE_PATH}/seg_infer_results'
cropped_img_dir   = f'{BASE_PATH}/3d_preprocessed_crop_ratio'

if not os.path.isdir(DATA_PATH):
    os.mkdir(DATA_PATH)

RESOL = 160
UP_RESOL = 128
N_CHANNELS = 6
BATCH_SIZE = 4
ACCUM_STEPS = 2
N_WORKERS  = 8
LR = 0.0002
N_EPOCHS = 200
EARLY_STOP_COUNT = N_EPOCHS
N_FOLDS  = 5
SELECT_FOLD = 1
N_PREPROCESS_CHUNKS = 12
PCT_START = 0.3
n_blocks = 4
drop_rate = 0.2
drop_path_rate = 0.2
p_mixup = 0.0

SEG_OUT_DIM = 5

DROP_REGION= {'HOLES': [3, 20],
                'SIZE': [5, 20],
                'PROB': 0.5,
                'FILL': (-3, 3)}

wandb_config = {
    'RESOL': RESOL,
    'BACKBONE': backbone,
    'N_CHANNELS': N_CHANNELS,
    'N_EPOCHS': N_EPOCHS,
    'N_FOLDS': N_FOLDS,
    'EARLY_STOP_COUNT': EARLY_STOP_COUNT,
    'BATCH_SIZE': BATCH_SIZE,    
    'LR': LR,
    'N_EPOCHS': N_EPOCHS,
    'DROP_RATE': drop_rate,
    'DROP_PATH_RATE': drop_path_rate,
    'MIXUP_RATE': p_mixup,
    'DROP_REGION': DROP_REGION,
    'PCT_START': PCT_START
}

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

In [3]:
#def load_pnasnet5large():
    #model= pnasnet5large(pretrained = False)
    #pretrained_dict=torch.load(f'{BASE_PATH}/PNASNet.pth', map_location='cpu')['model']

    #model_dict = model.state_dict()
    ##for k in model_dict.keys():
    #    if(('module.'+k) in pretrained_dict.keys()):
    #        model_dict[k]=pretrained_dict.get(('module.'+k))
    
#    return model

#model = load_pnasnet5large()
#inputs = torch.zeros(1, 3, 128, 128)
#outputs = model(inputs)
#outputs.shape


In [4]:
# Mask related parameters
# Order 0: Bowel, 1: left kidney, 2: right kidney, 3: liver, 4: spleen

chan_keys = ['bowel', 'left_kidney', 'right_kidney', 'liver', 'spleen', 'total']
chan_dict = {}
for i in range(0, 6):
    chan_dict[i] = chan_keys[i]

train_meta_df = pd.read_csv(f'{BASE_PATH}/train_meta.csv')
np.unique(train_meta_df['fold'].to_numpy(), return_counts = True)

(array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
 array([474, 480, 461, 471, 469, 466, 476, 480, 463, 471]))

In [5]:
#train_injured_df = train_meta_df.loc[(train_meta_df['fold']!=SELECT_FOLD)& (train_meta_df['any_injury']==1)].copy()
#train_meta_df = pd.concat([train_meta_df, train_injured_df])
#len(train_meta_df)

In [6]:
def compress(name, data):
    with gzip.open(name, 'wb') as f:
        pickle.dump(data, f)

def decompress(name):
    with gzip.open(name, 'rb') as f:
        data = pickle.load(f)
    return data


def compress_fast(name, data):
    with open(name, 'wb') as f:
        pickle.dump(data, f)

def decompress_fast(name):
    with open(name, 'rb') as f:
        data = pickle.load(f)
    return data

# Model

In [7]:
def convert_3d(module):

    module_output = module
    if isinstance(module, torch.nn.BatchNorm2d):
        module_output = torch.nn.BatchNorm3d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig
            
    elif isinstance(module, Conv2dSame):
        module_output = Conv3dSame(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))

    elif isinstance(module, torch.nn.Conv2d):
        module_output = torch.nn.Conv3d(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
            padding_mode=module.padding_mode
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))

    elif isinstance(module, torch.nn.MaxPool2d):
        module_output = torch.nn.MaxPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            ceil_mode=module.ceil_mode,
        )
    elif isinstance(module, torch.nn.AvgPool2d):
        module_output = torch.nn.AvgPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            ceil_mode=module.ceil_mode,
        )

    for name, child in module.named_children():
        module_output.add_module(
            name, convert_3d(child)
        )
    del module

    return module_output

In [8]:
class Timm3DModel(nn.Module):
    def __init__(self, backbone, n_channels, n_labels, segtype='unet', pretrained=False):
        super(Timm3DModel, self).__init__()
        self.n_labels = n_labels
        self.encoder = timm_new.create_model(
            backbone,
            in_chans=n_channels,
            features_only=True,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            pretrained=pretrained
        )
        g = self.encoder(torch.rand(1, n_channels, 64, 64))
        encoder_channels = [1] + [_.shape[1] for _ in g]
        decoder_channels = [256, 128, 64, 32, 16]

        self.avgpool = nn.AvgPool2d(5, 4, 2)
        
        [_.shape[1] for _ in g]
        self.convs1x1 = nn.ModuleList()    
        self.batchnorms = nn.ModuleList()    
        self.batchnorms13 = nn.ModuleList()
        for i in range(0, len(g)):
            self.convs1x1.append(nn.Conv2d(g[i].shape[1], self.n_labels, 1))
        del g
        gc.collect()
        
    def forward(self,x):
        batch_size = x.shape[0]
        global_features = self.encoder(x)[:n_blocks]        
        for i in range(0, len(global_features)):
            #print(global_features[i].shape)
            global_features[i] = self.convs1x1[i](global_features[i])
        return global_features
    

class TimmSegModel(nn.Module):
    def __init__(self, backbone, in_chans = 1, segtype='unet', pretrained=False):
        super(TimmSegModel, self).__init__()

        self.encoder = timm_new.create_model(
            backbone,
            in_chans=in_chans,
            features_only=True,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            pretrained=pretrained
        )
        g = self.encoder(torch.rand(1, in_chans, 64, 64))
        encoder_channels = [1] + [_.shape[1] for _ in g]
        decoder_channels = [256, 128, 64, 32, 16]
        if segtype == 'unet':
            self.decoder = smp.unet.decoder.UnetDecoder(
                encoder_channels=encoder_channels[:n_blocks+1],
                decoder_channels=decoder_channels[:n_blocks],
                n_blocks=n_blocks,
            )

        self.segmentation_head = nn.Conv2d(decoder_channels[n_blocks-1], SEG_OUT_DIM, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    def forward(self,x):
        global_features = [0] + self.encoder(x)[:n_blocks]
        seg_features = self.decoder(*global_features)
        seg_features = self.segmentation_head(seg_features)
        return seg_features

class CNNSegXClassifier(nn.Module):
    def __init__(self, backbone = backbone, device = DEVICE):
        super().__init__()
        self.device = DEVICE
        self.seg_model   = TimmSegModel(backbone)
        self.seg_model = convert_3d(self.seg_model)
        self.seg_model.load_state_dict(torch.load(f'{BASE_PATH}/seg_models_backup/231001_timm3d_res10tc_CV0.938.pt'))
        
        #self.pointwise_conv3d = nn.Conv3d(SEG_OUT_DIM, 1, kernel_size = (1, 1, 1))
        #self.get_yz_conv3d = nn.Conv3d(SEG_OUT_DIM, 1, kernel_size = (1, 1, 1))
        #self.get_zx_conv3d = nn.Conv3d(SEG_OUT_DIM, 1, kernel_size = (1, 1, 1))
        #self.class_model = load_pnasnet5large() 
        self.class_model = timm_new.create_model('timm/tf_efficientnet_b0.ns_jft_in1k', pretrained = True,
                                                drop_rate=0.5, drop_path_rate=0.5)
        self.class_model.conv_stem = nn.Conv2d((SEG_OUT_DIM+1)*3, 32, kernel_size=(3, 3), stride = (2, 2), bias = False)
        self.class_model.classifier = nn.Linear(in_features=1280, out_features=13, bias=True)
        
    def forward(self, x):
        #x = torch.cat([x_bowel, x_kidney_left, x_kidney_right, x_liver, x_spleen, x_total], dim = 1)
        x_origin = torch.clone(x)
        x = self.seg_model(x)
        x = torch.cat([x, x_origin], dim = 1)
        #x = self.pointwise_conv3d(x)
        #x = nn.AvgPool3d((4, 4, 4))(x)
        xy = torch.mean(x, dim = 2)
        yz = torch.mean(x, dim = 4)
        zx = torch.mean(x, dim = 3)
        
        slices_cat = torch.cat([xy, yz, zx], dim = 1)
        output = self.class_model(slices_cat)
        return output

#inputs = torch.zeros(1, 1, 128, 128, 128)
#model = CNNSegXClassifier()
#outputs = model(inputs, inputs, inputs, inputs, inputs, inputs)    

In [9]:
class AbdominalClassifier(nn.Module):
    def __init__(self, device = DEVICE):
        super().__init__()
        self.device = device
        
        self.model    = CNNSegXClassifier()
        self.flatten  = nn.Flatten()
        self.dropout  = nn.Dropout(p=0.5)
        self.softmax  = nn.Softmax(dim=1)        
        self.maxpool  = nn.MaxPool1d(5, 1)
        
        #self.lstm = nn.LSTM(input_size =32, hidden_size = 8, num_layers=5, batch_first=True, bidirectional=True, dropout = drop_rate)        
        #self.head = nn.Linear(16, 13)
        
    def forward(self, x):
        bs = x.shape[0]
        
        labels    = self.model(x)

        bowel_soft = self.softmax(labels[:,:2])
        extrav_soft = self.softmax(labels[:,2:4])
        kidney_soft = self.softmax(labels[:,4:7])
        liver_soft = self.softmax(labels[:,7:10])
        spleen_soft = self.softmax(labels[:,10:13])

        any_in = torch.cat([1-bowel_soft[:,0:1], 1-extrav_soft[:,0:1], 
                            1-kidney_soft[:,0:1], 1-liver_soft[:,0:1], 1-spleen_soft[:,0:1]], dim = 1) 
        any_in = self.maxpool(any_in)
        any_not_in = 1-any_in
        any_in = torch.cat([any_not_in, any_in], dim = 1)

        #any_in = torch.log(any_in + 1e-6)  # 1e-6은 0을 처리하기 위한 작은 값
        return labels, any_in

In [10]:
model = AbdominalClassifier()

def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp

print(get_n_params(model))
del model
gc.collect()

24942406


52

# Metric & Loss

In [11]:
weights = np.ones(2)
weights[1] = 2
crit_bowel  = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
weights[1] = 6
crit_extrav = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
crit_any = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))

weights = np.ones((3))
weights[1] = 2
weights[2] = 4
crit_kidney = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
crit_liver  = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
crit_spleen = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))

In [12]:
def normalize_to_one(tensor):
    norm = torch.sum(tensor, 1)
    for i in range(0, tensor.shape[1]):
        tensor[:,i]/=norm
    return tensor

def apply_softmax_to_labels(X_out):
    softmax = nn.Softmax(dim=1)

    X_out[:,:2]    = normalize_to_one(softmax(X_out[:,:2]))
    X_out[:,2:4]   = normalize_to_one(softmax(X_out[:,2:4]))
    X_out[:,4:7]   = normalize_to_one(softmax(X_out[:,4:7]))
    X_out[:,7:10]  = normalize_to_one(softmax(X_out[:,7:10]))
    X_out[:,10:13] = normalize_to_one(softmax(X_out[:,10:13]))

    return X_out

def calculate_score(X_outs, ys, step = 'train'):
    X_outs = X_outs.astype(np.float64)
    ys     = ys.astype(np.float64)

    isnan_x = np.isnan(X_outs).astype(int)
    isnan_y = np.isnan(ys).astype(int)
    
    if(np.max(isnan_x)>0):
        print('xnan')
    if(np.max(isnan_y)>0):
        print('ynan')
        
    X_outs = np.nan_to_num(X_outs, nan = 0.0)
        
    #X_outs[:, 13:15] = nn.Softmax(dim=1)(torch.from_numpy(X_outs[:, 13:15])).numpy()
    bowel_weights  =  ys[:,0] + 2*ys[:,1]
    extrav_weights = ys[:,2] + 6*ys[:,3]
    kidney_weights = ys[:,4] + 2*ys[:,5] + 4*ys[:,6]
    liver_weights  = ys[:,7] + 2*ys[:,8] + 4*ys[:,9]
    spleen_weights = ys[:,10] + 2*ys[:,11] + 4*ys[:,12]
    any_in_weights = ys[:,13] + 6*ys[:,14]

    bowel_loss  = sklearn.metrics.log_loss(ys[:,:2], X_outs[:,:2], sample_weight = bowel_weights.astype(np.float64))
    extrav_loss = sklearn.metrics.log_loss(ys[:,2:4], X_outs[:,2:4], sample_weight = extrav_weights.astype(np.float64))
    kidney_loss = sklearn.metrics.log_loss(ys[:,4:7], X_outs[:,4:7], sample_weight = kidney_weights.astype(np.float64))
    liver_loss  = sklearn.metrics.log_loss(ys[:,7:10], X_outs[:,7:10], sample_weight = liver_weights.astype(np.float64))
    spleen_loss = sklearn.metrics.log_loss(ys[:,10:13], X_outs[:,10:13], sample_weight = spleen_weights.astype(np.float64))
    any_in_loss = sklearn.metrics.log_loss(ys[:,13:15], X_outs[:,13:15], sample_weight =  any_in_weights.astype(np.float64))
    
    avg_loss = (bowel_loss + extrav_loss + kidney_loss + liver_loss + spleen_loss + any_in_loss)/6

    losses= {f'{step}_bowel_metric': bowel_loss, f'{step}_extrav_metric': extrav_loss, f'{step}_kidney_metric': kidney_loss,
             f'{step}_liver_metric': liver_loss, f'{step}_spleen_metric': spleen_loss, f'{step}_any_in_metric': any_in_loss,
             f'{step}_avg_metric': avg_loss}

    wandb.log(losses)
    return avg_loss

def calculate_loss(X_out, X_any, y):
    batch_size = X_out.shape[0]
    bowel_loss  = crit_bowel(X_out[:,:2], y[:,:2])
    extrav_loss = crit_extrav(X_out[:,2:4], y[:,2:4])
    kidney_loss = crit_kidney(X_out[:,4:7], y[:,4:7])
    liver_loss  = crit_liver(X_out[:,7:10], y[:,7:10])
    spleen_loss = crit_spleen(X_out[:,10:13], y[:,10:13])
    any_in_loss = crit_any(X_any,  torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1))
    
    avg_loss = (bowel_loss + extrav_loss + kidney_loss + liver_loss + spleen_loss + any_in_loss)/6
    return bowel_loss, extrav_loss, kidney_loss, liver_loss, spleen_loss, any_in_loss, avg_loss

# Augmentations

In [13]:
def mixup(inputs, truth, clip=[0, 1]):
    indices = torch.randperm(inputs.size(0))
    shuffled_input = inputs[indices]
    shuffled_labels = truth[indices]

    lam = np.random.uniform(clip[0], clip[1])
    inputs = inputs * lam + shuffled_input * (1 - lam)
    return inputs, truth, shuffled_labels, lam

transforms_train = transforms.Compose([
    transforms.RandFlipd(keys=['all'], prob=0.5, spatial_axis=0),    
    transforms.RandFlipd(keys=['all'], prob=0.5, spatial_axis=1),
    transforms.RandFlipd(keys=['all'], prob=0.5, spatial_axis=2),
    #transforms.RandAffined(keys=chan_keys, translate_range=[int(x*y) for x, y in zip([RESOL, RESOL, RESOL], [0.3, 0.3, 0.3])], padding_mode='zeros', prob=0.7),
    transforms.RandGridDistortiond(keys=['all'], prob=0.5, distort_limit=(-0.01, 0.01), mode="nearest"),    
])

remain_transforms_train = transforms.Compose([
])

'''
transforms.RandCoarseDropout(holes = DROP_REGION['HOLES'][0], max_holes = DROP_REGION['HOLES'][1],
                        spatial_size = DROP_REGION['SIZE'][0]*np.ones(3, int), max_spatial_size =DROP_REGION['SIZE'][1]*np.ones(3, int), 
                        prob = DROP_REGION['PROB'], 
                        fill_value = DROP_REGION['FILL']),
transforms.RandGaussianNoise(prob=0.2),
transforms.RandBiasField(prob=0.2),
transforms.RandAdjustContrast(prob=0.2),
transforms.RandGaussianSmooth(prob=0.2),
transforms.RandGaussianSharpen(prob=0.2),
transforms.RandHistogramShift(prob=0.2),
transforms.RandGibbsNoise(prob=0.2),
transforms.RandKSpaceSpikeNoise(prob=0.2),
transforms.RandRicianNoise(prob=0.2),    
'''



transforms_common_preprocessing = transforms.Compose([
    #transforms.HistogramNormalize(num_bins = 256, min = 0, max = 255)
])

# Dataset

In [14]:
class AbdominalCTDataset(Dataset):
    def __init__(self, meta_df, is_train = True, transform_set = None, remain_transforms_set = None):
        self.meta_df = meta_df
        self.is_train = is_train
        self.transform_set = transform_set
        self.remain_transforms_set = remain_transforms_set
        self.data_3ds = []        
        for i in tqdm(range(0, len(self.meta_df))):
            tmp_data_3ds = {}
            base_name = self.meta_df.iloc[i]['path']                        
            tmp_data_3d = decompress(f'{base_name}').unsqueeze(0)
            #tmp_data_3d = torch.from_numpy(tmp_data_3d)
            tmp_data_3ds['all'] = tmp_data_3d            
            self.data_3ds.append(tmp_data_3ds)

    def __len__(self):
        return len(self.meta_df)
    
    def __getitem__(self, idx):
        row = self.meta_df.iloc[idx]
        label = row[['bowel_healthy','bowel_injury',
                    'extravasation_healthy','extravasation_injury',
                    'kidney_healthy','kidney_low','kidney_high',
                    'liver_healthy','liver_low','liver_high',
                    'spleen_healthy','spleen_low','spleen_high', 'any_injury']]
        
        data_3d = self.data_3ds[idx].copy()
        
        if self.is_train:
            if self.transform_set is not None:
                data_3d = self.transform_set(data_3d)

            if self.remain_transforms_set is not None:   
                for i in range(0, 1):
                    data_3d['all'] = self.remain_transforms_set(data_3d['all'].type(torch.float32)).type(torch.float16)
        
        label = label.to_numpy().astype(np.float32)                    
        label = torch.from_numpy(label)
                    
        return data_3d['all'], label        


In [15]:
#data_3d= torch.rand((6, 128, 128, 128))*0.5
#data_3d = remain_transforms_train(data_3d)
#torch.max(data_3d)
#print(data_3d)

# Train loop

In [16]:
def train_func(model, train_loader, scaler, scheduler, optimizer, epoch, accum_points, accum_scale):
    train_meters = {'loss': AverageMeter()}
    model.train()
    X_outs=[]
    ys=[]
    accum_counter = 0
    counter = 0
    optimizer.zero_grad()
    for X, y in train_loader:
        X, y = X.to(DEVICE), y.to(DEVICE)
        current_lr = float(scheduler.get_last_lr()[0])
        
        batch_size = X.shape[0]
        with torch.cuda.amp.autocast(enabled=True):  
            X_out, X_any  = model(X)
            bowel_loss, extrav_loss, kidney_loss, liver_loss, spleen_loss, any_in_loss, avg_loss = calculate_loss(X_out, X_any, y)
                
            step = 'train'
            wandb.log({ 'lr': current_lr,
                        f'{step}_bowel_loss': bowel_loss.item(),
                        f'{step}_extrav_loss': extrav_loss.item(),
                        f'{step}_kidney_loss': kidney_loss.item(),
                        f'{step}_liver_loss': liver_loss.item(),
                        f'{step}_spleen_loss': spleen_loss.item(),
                        f'{step}_any_loss': any_in_loss.item(),
                        f'{step}_avg_loss': avg_loss.item()
                        })
            
            scaler.scale(avg_loss/accum_scale[accum_counter]).backward()
            if(counter==accum_points[accum_counter]):
                scaler.step(optimizer)
                scheduler.step()
                scaler.update()    
                accum_counter+=1                
                optimizer.zero_grad()
        counter+=1

        #Metric calculation
        y_any = torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1)    
        X_out = apply_softmax_to_labels(X_out).detach().to('cpu').numpy()
        X_any = X_any.detach().to('cpu').numpy()
        X_out = np.hstack([X_out, X_any])
        X_outs.append(X_out)

        y     = y.to('cpu').numpy()[:,:-1]
        y_any = y_any.to('cpu').numpy()
        y     = np.hstack([y, y_any])
        ys.append(y)

        trn_loss = avg_loss.item()      
        train_meters['loss'].update(trn_loss, n=batch_size)     
        #pbar.set_description(f'Train loss: {trn_loss}')   
        
        
    print('Epoch {:d} / trn/loss={:.4f}'.format(epoch+1, train_meters['loss'].avg))    

    X_outs = np.vstack(X_outs) 
    ys     = np.vstack(ys)
    metric = calculate_score(X_outs, ys, 'train')                 
    print('Epoch {:d} / train/metric={:.4f}'.format(epoch+1, metric))   

    del X, X_outs, y, ys, X_any
    gc.collect()
    torch.cuda.empty_cache()    
    return scheduler, scaler, optimizer


def valid_func(model, valid_loader, epoch):
    X_outs=[]
    ys=[]
    model.eval()
    for X, y in valid_loader:
        batch_size = y.shape[0]
        X, y = X.to(DEVICE), y.to(DEVICE)           
        with torch.cuda.amp.autocast(enabled=True):                
            with torch.no_grad():                 
                X_out, X_any = model(X)                                          
                y_any = torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1)              
                X_out = apply_softmax_to_labels(X_out).to('cpu').numpy()

                X_any = X_any.to('cpu').numpy()
                X_out = np.hstack([X_out, X_any])
                X_outs.append(X_out)

                y     = y.to('cpu').numpy()[:,:-1]
                y_any = y_any.to('cpu').numpy()
                y     = np.hstack([y, y_any])
                ys.append(y)

    X_outs = np.vstack(X_outs) 
    ys     = np.vstack(ys)
    metric = calculate_score(X_outs, ys, 'valid')                
    print('Epoch {:d} / val/metric={:.4f}'.format(epoch+1, metric))           
    
    del X, X_outs, y, ys, X_any
    gc.collect()        
    torch.cuda.empty_cache()   
    return metric 

In [17]:
model = AbdominalClassifier()
model.to(DEVICE)

wandb.init(
    config = wandb_config,
    project= PROJECT_NAME,
    group  = GROUP_NAME,
    name   = RUN_NAME,
    dir    = BASE_PATH)

backbone = backbone.replace('/', '_')

if __name__ == '__main__':
    train_dataset = AbdominalCTDataset(train_meta_df[train_meta_df['fold']!=SELECT_FOLD], is_train = True, transform_set  = transforms_train, 
                                        remain_transforms_set = None)
    valid_dataset = AbdominalCTDataset(train_meta_df[train_meta_df['fold']==SELECT_FOLD], is_train = False, transform_set = None,
                                        remain_transforms_set = None)
    
    train_loader = DataLoader(dataset = train_dataset, shuffle = True, batch_size = BATCH_SIZE, pin_memory = True, 
                            num_workers = N_WORKERS*2, drop_last = False)

    valid_loader = DataLoader(dataset = valid_dataset, shuffle = False, batch_size = BATCH_SIZE, pin_memory = True, 
                            num_workers = N_WORKERS, drop_last = False)     
    
    ttl_iters = N_EPOCHS * len(train_loader)
    
    #gradient accumulation for stability of the training
    accum_len = int(np.ceil(len(train_loader)/ACCUM_STEPS)+0.001)
    accum_points = np.zeros(accum_len, int)
    accum_scale  = np.zeros(accum_len, int)
    
    prev_step = -1
    for i in range(0, accum_len):
        accum_points[i] = min(prev_step+ACCUM_STEPS, len(train_loader)-1)
        accum_scale[i]  = accum_points[i] - prev_step
        prev_step = accum_points[i]

    #Scheduler & optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr = LR, weight_decay = 0.001)
    n_batch_iters = int(np.ceil(len(train_loader)/ACCUM_STEPS)+0.001)
    #scheduler = CosineAnnealingLR(optimizer, T_max=ttl_iters, eta_min=1e-6)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LR, pct_start= PCT_START,
                                                    steps_per_epoch= n_batch_iters, epochs = N_EPOCHS)

    scaler = torch.cuda.amp.GradScaler(enabled=True)
    val_metrics = np.ones(N_EPOCHS)*100

    gc.collect()

    for epoch in tqdm(range(0, N_EPOCHS), leave = False):     

        scheduler, scaler, optimizer = train_func(model, train_loader, scaler, scheduler, optimizer, epoch, accum_points, accum_scale)
        metric                       = valid_func(model, valid_loader, epoch)
        
        #Save the best model    
        if(metric < np.min(val_metrics)):
            try:
                os.makedirs(f'{BASE_PATH}/weights')
            except:
                a = 1
            best_metric = metric
            print(f'Best val_metric {best_metric} at epoch {epoch+1}!')
            torch.save(model.state_dict(), f'{BASE_PATH}/weights/{backbone}_lr{LR}_epochs_{N_EPOCHS}_resol{UP_RESOL}_batch{BATCH_SIZE*ACCUM_STEPS}_best.pt') 
            if(metric < 0.48):
                torch.save(model.state_dict(), f'{BASE_PATH}/weights/{backbone}_lr{LR}_epochs_{N_EPOCHS}_resol{UP_RESOL}_batch{BATCH_SIZE*ACCUM_STEPS}_{metric}.pt') 
            not_improve_counter=0
            val_metrics[epoch] = metric
            continue                    
        val_metrics[epoch] = metric                        
        
        #Early stopping
        not_improve_counter+=1
        if(not_improve_counter == EARLY_STOP_COUNT):
            print(f'Not improved for {not_improve_counter} epochs, terminate the train')
            break
wandb.log({'best_total_log_loss': best_metric})
wandb.finish()

100%|██████████| 4231/4231 [01:30<00:00, 46.96it/s]
100%|██████████| 480/480 [00:10<00:00, 47.11it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch 1 / trn/loss=1.0306
Epoch 1 / train/metric=0.7703


  0%|          | 1/200 [10:09<33:40:04, 609.07s/it]

Epoch 1 / val/metric=0.6843
Best val_metric 0.684292526795967 at epoch 1!
Epoch 2 / trn/loss=0.8882
Epoch 2 / train/metric=0.6399


  1%|          | 2/200 [20:16<33:27:42, 608.40s/it]

Epoch 2 / val/metric=0.6314
Best val_metric 0.6314384368990048 at epoch 2!
Epoch 3 / trn/loss=0.8455
Epoch 3 / train/metric=0.6013


  2%|▏         | 3/200 [30:25<33:17:48, 608.47s/it]

Epoch 3 / val/metric=0.6198
Best val_metric 0.6197818946287917 at epoch 3!
Epoch 4 / trn/loss=0.8321
Epoch 4 / train/metric=0.5896


  2%|▏         | 4/200 [40:33<33:07:28, 608.41s/it]

Epoch 4 / val/metric=0.6127
Best val_metric 0.6126641797367512 at epoch 4!
Epoch 5 / trn/loss=0.8337
Epoch 5 / train/metric=0.5907


  2%|▎         | 5/200 [50:41<32:56:55, 608.29s/it]

Epoch 5 / val/metric=0.6228
Epoch 6 / trn/loss=0.8300
Epoch 6 / train/metric=0.5878


  3%|▎         | 6/200 [1:00:50<32:47:12, 608.42s/it]

Epoch 6 / val/metric=0.6147
Epoch 7 / trn/loss=0.8294
Epoch 7 / train/metric=0.5870


  4%|▎         | 7/200 [1:10:58<32:36:37, 608.28s/it]

Epoch 7 / val/metric=0.6121
Best val_metric 0.6120897740112475 at epoch 7!
Epoch 8 / trn/loss=0.8263
Epoch 8 / train/metric=0.5846


  4%|▍         | 8/200 [1:21:06<32:26:04, 608.15s/it]

Epoch 8 / val/metric=0.6037
Best val_metric 0.6036770761945652 at epoch 8!
Epoch 9 / trn/loss=0.8256
Epoch 9 / train/metric=0.5841


  4%|▍         | 9/200 [1:31:14<32:15:34, 608.04s/it]

Epoch 9 / val/metric=0.6108
Epoch 10 / trn/loss=0.8239
Epoch 10 / train/metric=0.5827


  5%|▌         | 10/200 [1:41:22<32:05:32, 608.06s/it]

Epoch 10 / val/metric=0.6154
Epoch 11 / trn/loss=0.8193
Epoch 11 / train/metric=0.5789
Epoch 11 / val/metric=0.6022
Best val_metric 0.6022414197959641 at epoch 11!


  6%|▌         | 11/200 [1:51:30<31:55:42, 608.16s/it]

Epoch 12 / trn/loss=0.8220
Epoch 12 / train/metric=0.5817


  6%|▌         | 12/200 [2:01:39<31:45:49, 608.24s/it]

Epoch 12 / val/metric=0.6039
Epoch 13 / trn/loss=0.8200
Epoch 13 / train/metric=0.5794


  6%|▋         | 13/200 [2:11:47<31:35:56, 608.33s/it]

Epoch 13 / val/metric=0.6035
Epoch 14 / trn/loss=0.8166
Epoch 14 / train/metric=0.5777


  7%|▋         | 14/200 [2:21:55<31:25:31, 608.23s/it]

Epoch 14 / val/metric=0.6006
Best val_metric 0.6005825757437941 at epoch 14!
Epoch 15 / trn/loss=0.8173
Epoch 15 / train/metric=0.5784


  8%|▊         | 15/200 [2:32:03<31:15:12, 608.18s/it]

Epoch 15 / val/metric=0.6043
Epoch 16 / trn/loss=0.8170
Epoch 16 / train/metric=0.5782


  8%|▊         | 16/200 [2:42:11<31:04:55, 608.13s/it]

Epoch 16 / val/metric=0.6058
Epoch 17 / trn/loss=0.8160
Epoch 17 / train/metric=0.5779
Epoch 17 / val/metric=0.5991
Best val_metric 0.5990852835852948 at epoch 17!


  8%|▊         | 17/200 [2:52:19<30:54:44, 608.11s/it]

Epoch 18 / trn/loss=0.8071
Epoch 18 / train/metric=0.5707


  9%|▉         | 18/200 [3:02:27<30:44:30, 608.08s/it]

Epoch 18 / val/metric=0.5970
Best val_metric 0.5969616775869079 at epoch 18!
Epoch 19 / trn/loss=0.8100
Epoch 19 / train/metric=0.5738
Epoch 19 / val/metric=0.5919
Best val_metric 0.5919044117864658 at epoch 19!


 10%|▉         | 19/200 [3:12:36<30:34:52, 608.24s/it]

Epoch 20 / trn/loss=0.8042
Epoch 20 / train/metric=0.5689


 10%|█         | 20/200 [3:22:44<30:24:30, 608.17s/it]

Epoch 20 / val/metric=0.6052
Epoch 21 / trn/loss=0.8014
Epoch 21 / train/metric=0.5669


 10%|█         | 21/200 [3:32:52<30:14:17, 608.14s/it]

Epoch 21 / val/metric=0.5870
Best val_metric 0.5870213451793445 at epoch 21!
Epoch 22 / trn/loss=0.8002
Epoch 22 / train/metric=0.5664


 11%|█         | 22/200 [3:43:00<30:03:45, 608.01s/it]

Epoch 22 / val/metric=0.5933
Epoch 23 / trn/loss=0.7985
Epoch 23 / train/metric=0.5645
Epoch 23 / val/metric=0.5754
Best val_metric 0.575362821645878 at epoch 23!


 12%|█▏        | 23/200 [3:53:08<29:53:47, 608.06s/it]

Epoch 24 / trn/loss=0.7939
Epoch 24 / train/metric=0.5615


 12%|█▏        | 24/200 [4:03:16<29:43:24, 607.98s/it]

Epoch 24 / val/metric=0.5903
Epoch 25 / trn/loss=0.7893
Epoch 25 / train/metric=0.5582


 12%|█▎        | 25/200 [4:13:24<29:33:28, 608.05s/it]

Epoch 25 / val/metric=0.5837
Epoch 26 / trn/loss=0.7874
Epoch 26 / train/metric=0.5564


 13%|█▎        | 26/200 [4:23:32<29:23:09, 607.99s/it]

Epoch 26 / val/metric=0.5733
Best val_metric 0.5733390585309759 at epoch 26!
Epoch 27 / trn/loss=0.7923
Epoch 27 / train/metric=0.5606


 14%|█▎        | 27/200 [4:33:40<29:12:59, 607.98s/it]

Epoch 27 / val/metric=0.5775
Epoch 28 / trn/loss=0.7837
Epoch 28 / train/metric=0.5536


 14%|█▍        | 28/200 [4:43:48<29:03:13, 608.10s/it]

Epoch 28 / val/metric=0.5902
Epoch 29 / trn/loss=0.7800
Epoch 29 / train/metric=0.5516


 14%|█▍        | 29/200 [4:53:56<28:53:02, 608.08s/it]

Epoch 29 / val/metric=0.5772
Epoch 30 / trn/loss=0.7715
Epoch 30 / train/metric=0.5448


 15%|█▌        | 30/200 [5:04:04<28:42:45, 608.03s/it]

Epoch 30 / val/metric=0.5896
Epoch 31 / trn/loss=0.7688
Epoch 31 / train/metric=0.5439


 16%|█▌        | 31/200 [5:14:13<28:33:08, 608.21s/it]

Epoch 31 / val/metric=0.5594
Best val_metric 0.5593589104378524 at epoch 31!
Epoch 32 / trn/loss=0.7634
Epoch 32 / train/metric=0.5383


 16%|█▌        | 32/200 [5:24:21<28:22:46, 608.13s/it]

Epoch 32 / val/metric=0.5445
Best val_metric 0.5445229603486809 at epoch 32!
Epoch 33 / trn/loss=0.7648
Epoch 33 / train/metric=0.5398


 16%|█▋        | 33/200 [5:34:29<28:12:28, 608.08s/it]

Epoch 33 / val/metric=0.5663
Epoch 34 / trn/loss=0.7529
Epoch 34 / train/metric=0.5311


 17%|█▋        | 34/200 [5:44:37<28:02:25, 608.11s/it]

Epoch 34 / val/metric=0.5623
Epoch 35 / trn/loss=0.7484
Epoch 35 / train/metric=0.5279


 18%|█▊        | 35/200 [5:54:45<27:52:07, 608.05s/it]

Epoch 35 / val/metric=0.5515
Epoch 36 / trn/loss=0.7384
Epoch 36 / train/metric=0.5186


 18%|█▊        | 36/200 [6:04:53<27:41:50, 607.99s/it]

Epoch 36 / val/metric=0.5647
Epoch 37 / trn/loss=0.7204
Epoch 37 / train/metric=0.5049


 18%|█▊        | 37/200 [6:15:01<27:32:06, 608.14s/it]

Epoch 37 / val/metric=0.5853
Epoch 38 / trn/loss=0.7204
Epoch 38 / train/metric=0.5059


 19%|█▉        | 38/200 [6:25:09<27:21:49, 608.08s/it]

Epoch 38 / val/metric=0.5551
Epoch 39 / trn/loss=0.7115
Epoch 39 / train/metric=0.4990


 20%|█▉        | 39/200 [6:35:17<27:11:31, 608.02s/it]

Epoch 39 / val/metric=0.5476
Epoch 40 / trn/loss=0.7146
Epoch 40 / train/metric=0.5000


 20%|██        | 40/200 [6:45:25<27:01:32, 608.08s/it]

Epoch 40 / val/metric=0.5278
Best val_metric 0.5278089410977943 at epoch 40!
Epoch 41 / trn/loss=0.7037
Epoch 41 / train/metric=0.4914


 20%|██        | 41/200 [6:55:33<26:51:28, 608.10s/it]

Epoch 41 / val/metric=0.5504
Epoch 42 / trn/loss=0.6957
Epoch 42 / train/metric=0.4881


 21%|██        | 42/200 [7:05:42<26:41:36, 608.21s/it]

Epoch 42 / val/metric=0.5397
Epoch 43 / trn/loss=0.6889
Epoch 43 / train/metric=0.4798


 22%|██▏       | 43/200 [7:15:49<26:30:59, 608.02s/it]

Epoch 43 / val/metric=0.5304
Epoch 44 / trn/loss=0.6884
Epoch 44 / train/metric=0.4803


 22%|██▏       | 44/200 [7:25:58<26:21:06, 608.12s/it]

Epoch 44 / val/metric=0.5877
Epoch 45 / trn/loss=0.6733
Epoch 45 / train/metric=0.4681


 22%|██▎       | 45/200 [7:36:05<26:10:36, 607.98s/it]

Epoch 45 / val/metric=0.5708
Epoch 46 / trn/loss=0.6736
Epoch 46 / train/metric=0.4681


 23%|██▎       | 46/200 [7:46:13<26:00:16, 607.90s/it]

Epoch 46 / val/metric=0.5413
Epoch 47 / trn/loss=0.6735
Epoch 47 / train/metric=0.4679


 24%|██▎       | 47/200 [7:56:21<25:50:12, 607.92s/it]

Epoch 47 / val/metric=0.5463
Epoch 48 / trn/loss=0.6652
Epoch 48 / train/metric=0.4613


 24%|██▍       | 48/200 [8:06:29<25:39:48, 607.82s/it]

Epoch 48 / val/metric=0.5779
Epoch 49 / trn/loss=0.6507
Epoch 49 / train/metric=0.4499


 24%|██▍       | 49/200 [8:16:36<25:29:43, 607.83s/it]

Epoch 49 / val/metric=0.5654
Epoch 50 / trn/loss=0.6436
Epoch 50 / train/metric=0.4433


 25%|██▌       | 50/200 [8:26:44<25:19:33, 607.82s/it]

Epoch 50 / val/metric=0.5176
Best val_metric 0.5176499287767846 at epoch 50!
Epoch 51 / trn/loss=0.6423
Epoch 51 / train/metric=0.4443


 26%|██▌       | 51/200 [8:36:52<25:09:43, 607.94s/it]

Epoch 51 / val/metric=0.5320
Epoch 52 / trn/loss=0.6286
Epoch 52 / train/metric=0.4329
Epoch 52 / val/metric=0.5077
Best val_metric 0.5076705434563363 at epoch 52!


 26%|██▌       | 52/200 [8:47:01<25:00:00, 608.11s/it]

Epoch 53 / trn/loss=0.6262
Epoch 53 / train/metric=0.4307


 26%|██▋       | 53/200 [8:57:09<24:49:55, 608.13s/it]

Epoch 53 / val/metric=0.5693
Epoch 54 / trn/loss=0.6138
Epoch 54 / train/metric=0.4201


 27%|██▋       | 54/200 [9:07:18<24:40:17, 608.34s/it]

Epoch 54 / val/metric=0.5408
Epoch 55 / trn/loss=0.6098
Epoch 55 / train/metric=0.4172


 28%|██▊       | 55/200 [9:17:26<24:29:51, 608.22s/it]

Epoch 55 / val/metric=0.5158
Epoch 56 / trn/loss=0.5953
Epoch 56 / train/metric=0.4030


 28%|██▊       | 56/200 [9:27:34<24:19:34, 608.16s/it]

Epoch 56 / val/metric=0.5216
Epoch 57 / trn/loss=0.5855
Epoch 57 / train/metric=0.3955


 28%|██▊       | 57/200 [9:37:42<24:09:41, 608.26s/it]

Epoch 57 / val/metric=0.5728
Epoch 58 / trn/loss=0.5809
Epoch 58 / train/metric=0.3921


 29%|██▉       | 58/200 [9:47:50<23:59:20, 608.17s/it]

Epoch 58 / val/metric=0.6308
Epoch 59 / trn/loss=0.5733
Epoch 59 / train/metric=0.3870


 30%|██▉       | 59/200 [9:57:58<23:49:00, 608.09s/it]

Epoch 59 / val/metric=0.5568
Epoch 60 / trn/loss=0.5680
Epoch 60 / train/metric=0.3817


 30%|███       | 60/200 [10:08:06<23:38:55, 608.11s/it]

Epoch 60 / val/metric=0.5507
Epoch 61 / trn/loss=0.5615
Epoch 61 / train/metric=0.3764


 30%|███       | 61/200 [10:18:14<23:28:34, 608.02s/it]

Epoch 61 / val/metric=0.5957
Epoch 62 / trn/loss=0.5530
Epoch 62 / train/metric=0.3701


 31%|███       | 62/200 [10:28:22<23:18:17, 607.96s/it]

Epoch 62 / val/metric=0.6253
Epoch 63 / trn/loss=0.5343
Epoch 63 / train/metric=0.3530


 32%|███▏      | 63/200 [10:38:30<23:08:11, 607.97s/it]

Epoch 63 / val/metric=0.5794
Epoch 64 / trn/loss=0.5309
Epoch 64 / train/metric=0.3507


 32%|███▏      | 64/200 [10:48:38<22:58:21, 608.10s/it]

Epoch 64 / val/metric=0.6244
Epoch 65 / trn/loss=0.5320
Epoch 65 / train/metric=0.3541


 32%|███▎      | 65/200 [10:58:46<22:48:08, 608.06s/it]

Epoch 65 / val/metric=0.5923
Epoch 66 / trn/loss=0.5158
Epoch 66 / train/metric=0.3408


 33%|███▎      | 66/200 [11:08:54<22:37:55, 608.03s/it]

Epoch 66 / val/metric=0.5819
Epoch 67 / trn/loss=0.5105
Epoch 67 / train/metric=0.3340


 34%|███▎      | 67/200 [11:19:02<22:27:45, 608.01s/it]

Epoch 67 / val/metric=0.6615
Epoch 68 / trn/loss=0.5064
Epoch 68 / train/metric=0.3304


 34%|███▍      | 68/200 [11:29:11<22:18:07, 608.24s/it]

Epoch 68 / val/metric=0.5884
Epoch 69 / trn/loss=0.4909
Epoch 69 / train/metric=0.3167


 34%|███▍      | 69/200 [11:39:20<22:08:16, 608.37s/it]

Epoch 69 / val/metric=0.7044
Epoch 70 / trn/loss=0.4859
Epoch 70 / train/metric=0.3128


 35%|███▌      | 70/200 [11:49:28<21:57:53, 608.26s/it]

Epoch 70 / val/metric=0.6857
Epoch 71 / trn/loss=0.4857
Epoch 71 / train/metric=0.3147


 36%|███▌      | 71/200 [11:59:36<21:47:34, 608.17s/it]

Epoch 71 / val/metric=0.5923
Epoch 72 / trn/loss=0.4869
Epoch 72 / train/metric=0.3132


 36%|███▌      | 72/200 [12:09:45<21:37:56, 608.41s/it]

Epoch 72 / val/metric=0.6215
Epoch 73 / trn/loss=0.4609
Epoch 73 / train/metric=0.2918


 36%|███▋      | 73/200 [12:19:53<21:27:35, 608.31s/it]

Epoch 73 / val/metric=0.6916
Epoch 74 / trn/loss=0.4571
Epoch 74 / train/metric=0.2896


 37%|███▋      | 74/200 [12:30:01<21:17:22, 608.28s/it]

Epoch 74 / val/metric=0.6626
Epoch 75 / trn/loss=0.4671
Epoch 75 / train/metric=0.3004


 38%|███▊      | 75/200 [12:40:09<21:07:11, 608.26s/it]

Epoch 75 / val/metric=0.6300
Epoch 76 / trn/loss=0.4497
Epoch 76 / train/metric=0.2823


 38%|███▊      | 76/200 [12:50:17<20:56:47, 608.13s/it]

Epoch 76 / val/metric=0.6557
Epoch 77 / trn/loss=0.4397
Epoch 77 / train/metric=0.2735


 38%|███▊      | 77/200 [13:00:25<20:46:50, 608.21s/it]

Epoch 77 / val/metric=0.7488
Epoch 78 / trn/loss=0.4237
Epoch 78 / train/metric=0.2612


 39%|███▉      | 78/200 [13:10:33<20:36:35, 608.16s/it]

Epoch 78 / val/metric=0.7297
Epoch 79 / trn/loss=0.4295
Epoch 79 / train/metric=0.2665


 40%|███▉      | 79/200 [13:20:42<20:26:34, 608.22s/it]

Epoch 79 / val/metric=0.7116
Epoch 80 / trn/loss=0.4187
Epoch 80 / train/metric=0.2573


 40%|████      | 80/200 [13:30:50<20:16:25, 608.21s/it]

Epoch 80 / val/metric=0.6320
Epoch 81 / trn/loss=0.4119
Epoch 81 / train/metric=0.2511


 40%|████      | 81/200 [13:40:59<20:06:30, 608.33s/it]

Epoch 81 / val/metric=0.6689
Epoch 82 / trn/loss=0.4076
Epoch 82 / train/metric=0.2459


 41%|████      | 82/200 [13:51:07<19:56:26, 608.36s/it]

Epoch 82 / val/metric=0.7004
Epoch 83 / trn/loss=0.4097
Epoch 83 / train/metric=0.2504


 42%|████▏     | 83/200 [14:01:15<19:46:02, 608.23s/it]

Epoch 83 / val/metric=0.6929
Epoch 84 / trn/loss=0.4036
Epoch 84 / train/metric=0.2441


 42%|████▏     | 84/200 [14:11:23<19:35:46, 608.16s/it]

Epoch 84 / val/metric=0.7678
Epoch 85 / trn/loss=0.3854
Epoch 85 / train/metric=0.2290


 42%|████▎     | 85/200 [14:21:31<19:25:25, 608.05s/it]

Epoch 85 / val/metric=0.6852
Epoch 86 / trn/loss=0.3927
Epoch 86 / train/metric=0.2358


 43%|████▎     | 86/200 [14:31:39<19:15:20, 608.08s/it]

Epoch 86 / val/metric=0.7324
Epoch 87 / trn/loss=0.3767
Epoch 87 / train/metric=0.2195


 44%|████▎     | 87/200 [14:41:47<19:05:07, 608.03s/it]

Epoch 87 / val/metric=0.7453
Epoch 88 / trn/loss=0.3714
Epoch 88 / train/metric=0.2161


 44%|████▍     | 88/200 [14:51:56<18:55:25, 608.27s/it]

Epoch 88 / val/metric=0.8163
Epoch 89 / trn/loss=0.3671
Epoch 89 / train/metric=0.2133


 44%|████▍     | 89/200 [15:02:04<18:45:35, 608.42s/it]

Epoch 89 / val/metric=0.7971
Epoch 90 / trn/loss=0.3629
Epoch 90 / train/metric=0.2087


 45%|████▌     | 90/200 [15:12:13<18:35:16, 608.33s/it]

Epoch 90 / val/metric=0.7774
Epoch 91 / trn/loss=0.3627
Epoch 91 / train/metric=0.2092


 46%|████▌     | 91/200 [15:22:21<18:24:57, 608.24s/it]

Epoch 91 / val/metric=0.7543
Epoch 92 / trn/loss=0.3399
Epoch 92 / train/metric=0.1889


 46%|████▌     | 92/200 [15:32:29<18:14:59, 608.33s/it]

Epoch 92 / val/metric=0.7477
Epoch 93 / trn/loss=0.3474
Epoch 93 / train/metric=0.1952


 46%|████▋     | 93/200 [15:42:37<18:04:42, 608.24s/it]

Epoch 93 / val/metric=0.8789
Epoch 94 / trn/loss=0.3324
Epoch 94 / train/metric=0.1809


 47%|████▋     | 94/200 [15:52:45<17:54:20, 608.12s/it]

Epoch 94 / val/metric=0.7984
Epoch 95 / trn/loss=0.3374
Epoch 95 / train/metric=0.1863


 48%|████▊     | 95/200 [16:02:53<17:44:16, 608.15s/it]

Epoch 95 / val/metric=0.8560
Epoch 96 / trn/loss=0.3273
Epoch 96 / train/metric=0.1775


 48%|████▊     | 96/200 [16:13:01<17:34:02, 608.10s/it]

Epoch 96 / val/metric=0.8389
Epoch 97 / trn/loss=0.3195
Epoch 97 / train/metric=0.1725


 48%|████▊     | 97/200 [16:23:09<17:23:54, 608.10s/it]

Epoch 97 / val/metric=0.8642
Epoch 98 / trn/loss=0.3179
Epoch 98 / train/metric=0.1697


 49%|████▉     | 98/200 [16:33:18<17:13:54, 608.18s/it]

Epoch 98 / val/metric=0.8984
Epoch 99 / trn/loss=0.3170
Epoch 99 / train/metric=0.1704


 50%|████▉     | 99/200 [16:43:26<17:03:47, 608.19s/it]

Epoch 99 / val/metric=0.8899
Epoch 100 / trn/loss=0.2991
Epoch 100 / train/metric=0.1536


 50%|█████     | 100/200 [16:53:34<16:53:32, 608.12s/it]

Epoch 100 / val/metric=0.9285
Epoch 101 / trn/loss=0.3074
Epoch 101 / train/metric=0.1626


 50%|█████     | 101/200 [17:03:42<16:43:32, 608.21s/it]

Epoch 101 / val/metric=0.8154
Epoch 102 / trn/loss=0.3061
Epoch 102 / train/metric=0.1585


 51%|█████     | 102/200 [17:13:50<16:33:07, 608.04s/it]

Epoch 102 / val/metric=0.9970
Epoch 103 / trn/loss=0.3036
Epoch 103 / train/metric=0.1588


 52%|█████▏    | 103/200 [17:23:58<16:22:56, 608.00s/it]

Epoch 103 / val/metric=0.8862
Epoch 104 / trn/loss=0.2936
Epoch 104 / train/metric=0.1486


 52%|█████▏    | 104/200 [17:34:06<16:12:40, 607.92s/it]

Epoch 104 / val/metric=0.9802
Epoch 105 / trn/loss=0.2888
Epoch 105 / train/metric=0.1438


 52%|█████▎    | 105/200 [17:44:13<16:02:31, 607.91s/it]

Epoch 105 / val/metric=0.8652
Epoch 106 / trn/loss=0.2842
Epoch 106 / train/metric=0.1410


 53%|█████▎    | 106/200 [17:54:21<15:52:19, 607.87s/it]

Epoch 106 / val/metric=1.0182
Epoch 107 / trn/loss=0.2725
Epoch 107 / train/metric=0.1303


 54%|█████▎    | 107/200 [18:04:29<15:42:12, 607.87s/it]

Epoch 107 / val/metric=0.9878
Epoch 108 / trn/loss=0.2734
Epoch 108 / train/metric=0.1307


 54%|█████▍    | 108/200 [18:14:37<15:32:04, 607.87s/it]

Epoch 108 / val/metric=0.9758
Epoch 109 / trn/loss=0.2683
Epoch 109 / train/metric=0.1285


 55%|█████▍    | 109/200 [18:24:46<15:22:25, 608.19s/it]

Epoch 109 / val/metric=0.9741
Epoch 110 / trn/loss=0.2647
Epoch 110 / train/metric=0.1242


 55%|█████▌    | 110/200 [18:34:54<15:12:18, 608.20s/it]

Epoch 110 / val/metric=1.0009
Epoch 111 / trn/loss=0.2578
Epoch 111 / train/metric=0.1186


 56%|█████▌    | 111/200 [18:45:03<15:02:16, 608.28s/it]

Epoch 111 / val/metric=1.1161
Epoch 112 / trn/loss=0.2475
Epoch 112 / train/metric=0.1101


 56%|█████▌    | 112/200 [18:55:10<14:51:58, 608.16s/it]

Epoch 112 / val/metric=1.0938


In [None]:
#Execute this cell to fininsh the wandb run when you stopped training.+
import wandb
try: 
    wandb.log(
        {'best_total_log_loss': best_metric})
    wandb.finish()
    
except: 
    print('Wandb is already finished!')

In [None]:
ind = 23561
train_meta_df[train_meta_df['patient_id']==ind]

In [None]:
valid_dataset = AbdominalCTDataset(train_meta_df[train_meta_df['patient_id']==ind], is_train = False, transform_set = None,
                                    remain_transforms_set = None)

valid_loader = DataLoader(dataset = valid_dataset, shuffle = False, batch_size = BATCH_SIZE, pin_memory = False, 
                        num_workers = N_WORKERS, drop_last = False)     

In [None]:
X_outs=[]
ys=[]
model.eval()
model.load_state_dict(torch.load(f'{BASE_PATH}/weights/231003_resnet10t_dicomO_std_CV0.485.pt'))
for X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, y in valid_loader:
    batch_size = y.shape[0]
    X_bowel, X_lkid, X_rkid = X_bowel.to(DEVICE), X_lkid.to(DEVICE), X_rkid.to(DEVICE)
    X_liv, X_spl, X_tot     = X_liv.to(DEVICE), X_spl.to(DEVICE), X_tot.to(DEVICE)
    y = y.to(DEVICE)           
    with torch.cuda.amp.autocast(enabled=True):                
        with torch.no_grad():                 
            X_out, X_any = model(X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot)                                          
            y_any = torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1)              
            X_out = apply_softmax_to_labels(X_out).to('cpu').numpy()

            X_any = X_any.to('cpu').numpy()
            X_out = np.hstack([X_out, X_any])
            X_outs.append(X_out)

            y     = y.to('cpu').numpy()[:,:-1]
            y_any = y_any.to('cpu').numpy()
            y     = np.hstack([y, y_any])
            ys.append(y)

X_outs = np.vstack(X_outs) 
ys     = np.vstack(ys)
#metric = calculate_score(X_outs, ys, 'valid')                      

del X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, X_any
gc.collect()        
torch.cuda.empty_cache()   

In [None]:
np.average(X_outs, axis = 0)


In [None]:
len(X_outs)