In [1]:
import time
from tqdm import tqdm
import sys
import glob
import gc
import os
sys.path.append('./lib_models')

import pandas as pd
import numpy as np
import scipy as sp
import cv2
from PIL import Image
from matplotlib import pyplot as plt
import sklearn.metrics
import warnings
import pydicom
import dicomsdl
from joblib import Parallel, delayed
import pickle
import gzip
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from multiprocessing import Pool
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch import nn
from torchvision.io import read_image
import segmentation_models_pytorch as smp
import timm
from timm.utils import AverageMeter
from timm.models import resnet
import timm_new

from monai.transforms import Resize
import  monai.transforms as transforms

from timm.models.layers.conv2d_same import Conv2dSame
from conv3d_same import Conv3dSame


import wandb
sys.path.append('./lib_models')

wandb.login(key = '585f58f321685308f7933861d9dde7488de0970b')

  from .autonotebook import tqdm as notebook_tqdm
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


KeyboardInterrupt: 

# Parameters

In [None]:
backbone = 'seresnext50_32x4d.racm_in1k'

IS_WANDB = True
PROJECT_NAME = 'RSNA_ABTD'
GROUP_NAME= 'augmentation'
RUN_NAME=   f'{backbone}_drop_patch'

if not IS_WANDB:
    PROJECT_NAME = 'Dummy_Project'

BASE_PATH  = '/home/junseonglee/Desktop/01_codes/inputs/rsna-2023-abdominal-trauma-detection'
TRAIN_PATH = f'{BASE_PATH}/train_images'
DATA_PATH = f'{BASE_PATH}/3d_preprocessed'

seg_inference_dir = f'{BASE_PATH}/seg_infer_results'
cropped_img_dir   = f'{BASE_PATH}/3d_preprocessed_crop'

if not os.path.isdir(DATA_PATH):
    os.mkdir(DATA_PATH)

RESOL = 128
UP_RESOL = 128
N_CHANNELS = 6
BATCH_SIZE = 16
ACCUM_STEPS = 1
N_WORKERS  = 8
LR = 0.001
N_EPOCHS = 100
EARLY_STOP_COUNT = 20
N_FOLDS  = 5
N_PREPROCESS_CHUNKS = 12
n_blocks = 4
drop_rate = 0.2
drop_path_rate = 0.2
p_mixup = 0.0

DROP_REGION= {'HOLES': [3, 20],
                'SIZE': [5, 20],
                'PROB': 0.5,
                'FILL': (-3, 3)}

wandb_config = {
    'RESOL': RESOL,
    'BACKBONE': backbone,
    'N_CHANNELS': N_CHANNELS,
    'N_EPOCHS': N_EPOCHS,
    'N_FOLDS': N_FOLDS,
    'EARLY_STOP_COUNT': EARLY_STOP_COUNT,
    'BATCH_SIZE': BATCH_SIZE,    
    'LR': LR,
    'N_EPOCHS': N_EPOCHS,
    'DROP_RATE': drop_rate,
    'DROP_PATH_RATE': drop_path_rate,
    'MIXUP_RATE': p_mixup,
    'DROP_REGION': DROP_REGION
}

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

In [None]:
# Mask related parameters
# Order 1: Bowel, 2: left kidney, 3: right kidney, 4: liver, 5: spleen

train_meta_df = pd.read_csv(f'{BASE_PATH}/train_meta.csv')
np.unique(train_meta_df['fold'].to_numpy(), return_counts = True)

(array([0, 1, 2, 3, 4]), array([929, 947, 948, 951, 936]))

In [None]:
def compress(name, data):
    with gzip.open(name, 'wb') as f:
        pickle.dump(data, f)

def decompress(name):
    with gzip.open(name, 'rb') as f:
        data = pickle.load(f)
    return data


def compress_fast(name, data):  
    np.save(name, data)

def decompress_fast(name):
    data = np.load(f'{name}.npy')
    return data

# Model

In [None]:
def convert_3d(module):

    module_output = module
    if isinstance(module, torch.nn.BatchNorm2d):
        module_output = torch.nn.BatchNorm3d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig
            
    elif isinstance(module, Conv2dSame):
        module_output = Conv3dSame(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))

    elif isinstance(module, torch.nn.Conv2d):
        module_output = torch.nn.Conv3d(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
            padding_mode=module.padding_mode
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))

    elif isinstance(module, torch.nn.MaxPool2d):
        module_output = torch.nn.MaxPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            ceil_mode=module.ceil_mode,
        )
    elif isinstance(module, torch.nn.AvgPool2d):
        module_output = torch.nn.AvgPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            ceil_mode=module.ceil_mode,
        )

    for name, child in module.named_children():
        module_output.add_module(
            name, convert_3d(child)
        )
    del module

    return module_output

In [None]:
class TimmSegModel(nn.Module):
    def __init__(self, backbone, segtype='unet', pretrained=False):
        super(TimmSegModel, self).__init__()

        self.encoder = timm_new.create_model(
            backbone,
            in_chans=N_CHANNELS,
            features_only=True,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            pretrained=pretrained
        )
        g = self.encoder(torch.rand(1, N_CHANNELS, 64, 64))
        encoder_channels = [1] + [_.shape[1] for _ in g]
        decoder_channels = [256, 128, 64, 32, 16]

        self.avgpool = nn.AvgPool2d(5, 4, 2)
        
        [_.shape[1] for _ in g]
        self.convs1x1 = nn.ModuleList()    
        self.batchnorms = nn.ModuleList()    
        self.batchnorms13 = nn.ModuleList()
        for i in range(0, len(g)):
            self.convs1x1.append(nn.Conv2d(g[i].shape[1], 13, 1))
            self.batchnorms.append(nn.BatchNorm2d(g[i].shape[1]))
            self.batchnorms13.append(nn.BatchNorm2d(13))

        del g
        gc.collect()
    def forward(self,x):
        global_features = self.encoder(x)[:n_blocks]        
        for i in range(0, len(global_features)):
            global_features[i] = self.convs1x1[i](global_features[i])

        return global_features

In [None]:
class AbdominalClassifier(nn.Module):
    def __init__(self, device = DEVICE):
        super().__init__()
        self.device = device
        self.upsample = torch.nn.Upsample(size = [UP_RESOL, UP_RESOL, UP_RESOL])
        self.resnet3d = TimmSegModel(backbone)
        self.resnet3d = convert_3d(self.resnet3d)
        self.flatten  = nn.Flatten()
        self.dropout  = nn.Dropout(p=0.5)
        self.softmax  = nn.Softmax(dim=1)        
        self.maxpool  = nn.MaxPool1d(5, 1)
    def forward(self, x):
        batch_size = x.shape[0]
        x = self.upsample(x)
        x = self.resnet3d(x)
        pooled_features = []
        for i in range(0, len(x)):        
            pooled_features.append(torch.reshape(torch.mean(x[i], dim = (2, 3, 4)), (batch_size, 13, 1)))
            
        x = torch.cat(pooled_features, dim=2)
        labels = torch.mean(x, dim=2)
        
        bowel_soft = self.softmax(labels[:,0:2])
        extrav_soft = self.softmax(labels[:,2:4])
        kidney_soft = self.softmax(labels[:,4:7])
        liver_soft = self.softmax(labels[:,7:10])
        spleen_soft = self.softmax(labels[:,10:13])

        any_in = torch.cat([1-bowel_soft[:,0:1], 1-extrav_soft[:,0:1], 
                            1-kidney_soft[:,0:1], 1-liver_soft[:,0:1], 1-spleen_soft[:,0:1]], dim = 1) 
        any_in = self.maxpool(any_in)
        any_not_in = 1-any_in
        any_in = torch.cat([any_not_in, any_in], dim = 1)

        return labels, any_in

In [None]:
model = AbdominalClassifier()

def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp

print(get_n_params(model))
del model
gc.collect()

28521267


0

# Metric & Loss

In [None]:
weights = np.ones(2)
weights[1] = 2
crit_bowel  = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
weights[1] = 6
crit_extrav = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
crit_any = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))

weights = np.ones((3))
weights[1] = 2
weights[2] = 4
crit_kidney = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
crit_liver  = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
crit_spleen = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))

In [None]:
def normalize_to_one(tensor):
    norm = torch.sum(tensor, 1)
    for i in range(0, tensor.shape[1]):
        tensor[:,i]/=norm
    return tensor

def apply_softmax_to_labels(X_out):
    softmax = nn.Softmax(dim=1)

    X_out[:,:2]    = normalize_to_one(softmax(X_out[:,:2]))
    X_out[:,2:4]   = normalize_to_one(softmax(X_out[:,2:4]))
    X_out[:,4:7]   = normalize_to_one(softmax(X_out[:,4:7]))
    X_out[:,7:10]  = normalize_to_one(softmax(X_out[:,7:10]))
    X_out[:,10:13] = normalize_to_one(softmax(X_out[:,10:13]))

    return X_out

def calculate_score(X_outs, ys, step = 'train'):
    X_outs = X_outs.astype(np.float64)
    ys     = ys.astype(np.float64)

    bowel_weights  =  ys[:,0] + 2*ys[:,1]
    extrav_weights = ys[:,2] + 6*ys[:,3]
    kidney_weights = ys[:,4] + 2*ys[:,5] + 4*ys[:,6]
    liver_weights  = ys[:,7] + 2*ys[:,8] + 4*ys[:,9]
    spleen_weights = ys[:,10] + 2*ys[:,11] + 4*ys[:,12]
    any_in_weights = ys[:,13] + 6*ys[:,14]
    

    bowel_loss  = sklearn.metrics.log_loss(ys[:,:2], X_outs[:,:2], sample_weight = bowel_weights)
    extrav_loss = sklearn.metrics.log_loss(ys[:,2:4], X_outs[:,2:4], sample_weight = extrav_weights)
    kidney_loss = sklearn.metrics.log_loss(ys[:,4:7], X_outs[:,4:7], sample_weight = kidney_weights)
    liver_loss  = sklearn.metrics.log_loss(ys[:,7:10], X_outs[:,7:10], sample_weight = liver_weights)
    spleen_loss = sklearn.metrics.log_loss(ys[:,10:13], X_outs[:,10:13], sample_weight = spleen_weights)
    any_in_loss = sklearn.metrics.log_loss(ys[:,13:15], X_outs[:,13:15], sample_weight =  any_in_weights)
    
    avg_loss = (bowel_loss + extrav_loss + kidney_loss + liver_loss + spleen_loss + any_in_loss)/6

    losses= {f'{step}_bowel_metric': bowel_loss, f'{step}_extrav_metric': extrav_loss, f'{step}_kidney_metric': kidney_loss,
             f'{step}_liver_metric': liver_loss, f'{step}_spleen_metric': spleen_loss, f'{step}_any_in_metric': any_in_loss,
             f'{step}_avg_metric': avg_loss}

    wandb.log(losses)
    return avg_loss

def calculate_loss(X_out, X_any, y):
    batch_size = X_out.shape[0]
    bowel_loss  = crit_bowel(X_out[:,:2], y[:,:2])
    extrav_loss = crit_extrav(X_out[:,2:4], y[:,2:4])
    kidney_loss = crit_kidney(X_out[:,4:7], y[:,4:7])
    liver_loss  = crit_liver(X_out[:,7:10], y[:,7:10])
    spleen_loss = crit_spleen(X_out[:,10:13], y[:,10:13])
    any_in_loss = crit_any(X_any,  torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1))
    
    avg_loss = (bowel_loss + extrav_loss + kidney_loss + liver_loss + spleen_loss + any_in_loss)/6
    return bowel_loss, extrav_loss, kidney_loss, liver_loss, spleen_loss, any_in_loss, avg_loss

# Augmentations

In [None]:
def mixup(inputs, truth, clip=[0, 1]):
    indices = torch.randperm(inputs.size(0))
    shuffled_input = inputs[indices]
    shuffled_labels = truth[indices]

    lam = np.random.uniform(clip[0], clip[1])
    inputs = inputs * lam + shuffled_input * (1 - lam)
    return inputs, truth, shuffled_labels, lam

transforms_train = transforms.Compose([
    transforms.RandFlipd(keys=["image"], prob=0.5, spatial_axis=1),
    transforms.RandFlipd(keys=["image"], prob=0.5, spatial_axis=2),
    transforms.RandAffined(keys=["image"], translate_range=[int(x*y) for x, y in zip([RESOL, RESOL, RESOL], [0.3, 0.3, 0.3])], padding_mode='zeros', prob=0.7),
    transforms.RandGridDistortiond(keys=("image"), prob=0.5, distort_limit=(-0.01, 0.01), mode="nearest"),    
])

remain_transforms_train = transforms.Compose([
    transforms.RandRotate(range_x = np.pi, range_y = np.pi, range_z = np.pi, prob = 0.1)
    #transforms.RandCoarseDropout(holes = DROP_REGION['HOLES'][0], max_holes = DROP_REGION['HOLES'][1],
    #                        spatial_size = DROP_REGION['SIZE'][0]*np.ones(3, int), max_spatial_size =DROP_REGION['SIZE'][1]*np.ones(3, int), 
    #                        prob = DROP_REGION['PROB'], 
    #                        fill_value = DROP_REGION['FILL'])
])



transforms_common_preprocessing = transforms.Compose([
    #transforms.HistogramNormalize(num_bins = 256, min = 0, max = 255)
])

# Dataset

In [None]:
#data_3d= torch.rand((6, 128, 128, 128))*0.5
#data_3d = remain_transforms_train(data_3d)
#torch.max(data_3d)
#print(data_3d)

In [None]:
class AbdominalCTDataset(Dataset):
    def __init__(self, meta_df, is_train = True, transform_set = None, remain_transforms_set = None):
        self.meta_df = meta_df
        self.is_train = is_train
        self.transform_set = transform_set
        self.remain_transforms_set = remain_transforms_set
        self.data_3ds = []
        for i in tqdm(range(0, len(self.meta_df))):
            tmp_data_3d = decompress_fast(self.meta_df.iloc[i]['cropped_path'])[None]            
            tmp_data_3d = torch.from_numpy(tmp_data_3d)
            self.data_3ds.append(tmp_data_3d)
        self.data_3ds = torch.cat(self.data_3ds, dim = 0)

    def __len__(self):
        return len(self.meta_df)
    
    def __getitem__(self, idx):
        row = self.meta_df.iloc[idx]
        label = row[['bowel_healthy','bowel_injury',
                    'extravasation_healthy','extravasation_injury',
                    'kidney_healthy','kidney_low','kidney_high',
                    'liver_healthy','liver_low','liver_high',
                    'spleen_healthy','spleen_low','spleen_high', 'any_injury']]
        
        data_3d = torch.clone(self.data_3ds[idx])
        
        if self.is_train:
            if self.transform_set is not None:
                data_3d = self.transform_set({'image':data_3d})
                data_3d = data_3d['image']        

            if self.remain_transforms_set is not None:   
                data_3d = self.remain_transforms_set(data_3d)
        
        label = label.to_numpy().astype(np.float32)                    
        label = torch.from_numpy(label)
                
        
        return data_3d, label        

gc.collect()

0

# Train loop

In [None]:
def train_func(model, train_loader, scaler, scheduler, optimizer, epoch, accum_points, accum_scale):
    train_meters = {'loss': AverageMeter()}
    model.train()
    X_outs=[]
    ys=[]
    accum_counter = 0
    counter = 0
    last_count_on = False
    for X, y in train_loader:
        X, y = X.to(DEVICE), y.to(DEVICE)
        current_lr = float(scheduler.get_last_lr()[0])
        wandb.log({'lr': current_lr})
        
        batch_size = X.shape[0]
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=True):  
            X_out, X_any  = model(X)
            do_mixup = False
            if np.random.random() < p_mixup:
                do_mixup = True
                X, y, labels_shuffled, lam = mixup(X, y)                
            
            bowel_loss, extrav_loss, kidney_loss, liver_loss, spleen_loss, any_in_loss, avg_loss = calculate_loss(X_out, X_any, y)
            if do_mixup:
                bowel_loss2, extrav_loss2, kidney_loss2, liver_loss2, spleen_loss2, any_in_loss2, avg_loss2 = calculate_loss(X_out, X_any, labels_shuffled)
                bowel_loss  = bowel_loss * lam  + bowel_loss2 * (1 - lam)
                extrav_loss = extrav_loss * lam  + extrav_loss2 * (1 - lam)
                kidney_loss = kidney_loss * lam  + kidney_loss2 * (1 - lam)         
                liver_loss  = liver_loss * lam  + liver_loss2 * (1 - lam) 
                spleen_loss = spleen_loss * lam  + spleen_loss2 * (1 - lam) 
                any_in_loss = any_in_loss * lam  + any_in_loss2 * (1 - lam) 
                avg_loss = avg_loss * lam  + avg_loss2 * (1 - lam)       
                
            step = 'train'
            wandb.log({f'{step}_bowel_loss': bowel_loss.item(),
                        f'{step}_extrav_loss': extrav_loss.item(),
                        f'{step}_kidney_loss': kidney_loss.item(),
                        f'{step}_liver_loss': liver_loss.item(),
                        f'{step}_spleen_loss': spleen_loss.item(),
                        f'{step}_any_loss': any_in_loss.item(),
                        f'{step}_avg_loss': avg_loss.item()
                        })
            
            scaler.scale(avg_loss/accum_scale[accum_counter]).backward()
            if(counter==accum_points[accum_counter]):
                scaler.step(optimizer)
                scheduler.step()
                scaler.update()    
                accum_counter+=1                
        counter+=1                   

        #Metric calculation
        y_any = torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1)    
        X_out = apply_softmax_to_labels(X_out).detach().to('cpu').numpy()
        X_any = X_any.detach().to('cpu').numpy()
        X_out = np.hstack([X_out, X_any])
        X_outs.append(X_out)

        y     = y.to('cpu').numpy()[:,:-1]
        y_any = y_any.to('cpu').numpy()
        y     = np.hstack([y, y_any])
        ys.append(y)

        trn_loss = avg_loss.item()      
        train_meters['loss'].update(trn_loss, n=X.size(0))     
        #pbar.set_description(f'Train loss: {trn_loss}')   
        
        
    print('Epoch {:d} / trn/loss={:.4f}'.format(epoch+1, train_meters['loss'].avg))    

    X_outs = np.vstack(X_outs) 
    ys     = np.vstack(ys)
    metric = calculate_score(X_outs, ys, 'train')                 
    print('Epoch {:d} / train/metric={:.4f}'.format(epoch+1, metric))   

    del X, X_outs, y, ys, X_any
    gc.collect()
    torch.cuda.empty_cache()    
    return scheduler, scaler, optimizer


def valid_func(model, valid_loader):
        val_meters   = {'loss': AverageMeter()}
        X_outs=[]
        ys=[]
        model.eval()
        for X, y in valid_loader:            
            batch_size = X.shape[0]        
            X, y = X.to(DEVICE), y.to(DEVICE)                 
            with torch.cuda.amp.autocast(enabled=True):                
                with torch.no_grad():                 
                    X_out, X_any = model(X)                                           
                    y_any = torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1)              
                    X_out = apply_softmax_to_labels(X_out).to('cpu').numpy()

                    X_any = X_any.to('cpu').numpy()
                    X_out = np.hstack([X_out, X_any])
                    X_outs.append(X_out)

                    y     = y.to('cpu').numpy()[:,:-1]
                    y_any = y_any.to('cpu').numpy()
                    y     = np.hstack([y, y_any])
                    ys.append(y)

        X_outs = np.vstack(X_outs) 
        ys     = np.vstack(ys)
        metric = calculate_score(X_outs, ys, 'valid')                
        print('Epoch {:d} / val/metric={:.4f}'.format(epoch+1, metric))           
        
        del X, X_outs, y, ys, X_any
        gc.collect()        
        torch.cuda.empty_cache()   
        return metric 

In [None]:
model = AbdominalClassifier()
model.to(DEVICE)

wandb.init(
    config = wandb_config,
    project= PROJECT_NAME,
    group  = GROUP_NAME,
    name   = RUN_NAME,
    dir    = BASE_PATH)

backbone = backbone.replace('/', '_')

if __name__ == '__main__':
    train_dataset = AbdominalCTDataset(train_meta_df[train_meta_df['fold']!=0], is_train = True, transform_set  = transforms_train, 
                                        remain_transforms_set = remain_transforms_train)
    valid_dataset = AbdominalCTDataset(train_meta_df[train_meta_df['fold']==0], is_train = False, transform_set = None,
                                        remain_transforms_set = None)        
    
    train_loader = DataLoader(dataset = train_dataset, shuffle = True, batch_size = BATCH_SIZE, pin_memory = False, 
                            num_workers = N_WORKERS, drop_last = False)

    valid_loader = DataLoader(dataset = valid_dataset, shuffle = False, batch_size = BATCH_SIZE, pin_memory = False, 
                            num_workers = N_WORKERS, drop_last = False)     
    
    ttl_iters = N_EPOCHS * len(train_loader)
    
    #gradient accumulation for stability of the training
    accum_len = int(np.ceil(len(train_loader)/ACCUM_STEPS)+0.001)
    accum_points = np.zeros(accum_len, int)
    accum_scale  = np.zeros(accum_len, int)
    
    prev_step = -1
    for i in range(0, accum_len):
        accum_points[i] = min(prev_step+ACCUM_STEPS, len(train_loader)-1)
        accum_scale[i]  = accum_points[i] - prev_step
        prev_step = accum_points[i]

    #Scheduler & optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr = LR)
    n_batch_iters = len(train_loader)
    #scheduler = CosineAnnealingLR(optimizer, T_max=ttl_iters, eta_min=1e-6)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LR, 
                                                    steps_per_epoch= n_batch_iters, epochs = N_EPOCHS)

    scaler = torch.cuda.amp.GradScaler(enabled=True)
    val_metrics = np.ones(N_EPOCHS)*100

    gc.collect()

    for epoch in tqdm(range(0, N_EPOCHS), leave = False):     

        scheduler, scaler, optimizer = train_func(model, train_loader, scaler, scheduler, optimizer, epoch, accum_points, accum_scale)
        metric                       = valid_func(model, valid_loader)
        
        #Save the best model    
        if(metric < np.min(val_metrics)):
            try:
                os.makedirs(f'{BASE_PATH}/weights')
            except:
                a = 1
            best_metric = metric
            print(f'Best val_metric {best_metric} at epoch {epoch+1}!')
            torch.save(model, f'{BASE_PATH}/weights/{backbone}_lr{LR}_epochs_{N_EPOCHS}_resol{UP_RESOL}_batch{BATCH_SIZE}.pt')    
            not_improve_counter=0
            val_metrics[epoch] = metric
            continue                    
        val_metrics[epoch] = metric                        
        
        #Early stopping
        not_improve_counter+=1
        if(not_improve_counter == EARLY_STOP_COUNT):
            print(f'Not improved for {not_improve_counter} epochs, terminate the train')
            break
wandb.log({'best_total_log_loss': best_metric})
wandb.finish()

100%|██████████| 3782/3782 [01:22<00:00, 45.64it/s]
100%|██████████| 929/929 [00:15<00:00, 60.45it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1 / trn/loss=0.9179
Epoch 1 / train/metric=0.6644
Epoch 1 / val/metric=0.6090
Best val_metric 0.6090197348795995 at epoch 1!


  1%|          | 1/100 [04:04<6:43:03, 244.28s/it]

Epoch 2 / trn/loss=0.8585
Epoch 2 / train/metric=0.6108
Epoch 2 / val/metric=0.6008
Best val_metric 0.6007511023174777 at epoch 2!


  2%|▏         | 2/100 [07:43<6:14:41, 229.40s/it]

Epoch 3 / trn/loss=0.8529
Epoch 3 / train/metric=0.6061
Epoch 3 / val/metric=0.5974
Best val_metric 0.5974259743845357 at epoch 3!


  3%|▎         | 3/100 [11:26<6:06:22, 226.63s/it]

Epoch 4 / trn/loss=0.8470
Epoch 4 / train/metric=0.6015
Epoch 4 / val/metric=0.5871
Best val_metric 0.5870917796636846 at epoch 4!


  4%|▍         | 4/100 [15:07<5:58:51, 224.28s/it]

Epoch 5 / trn/loss=0.8408
Epoch 5 / train/metric=0.5972
Epoch 5 / val/metric=0.5770
Best val_metric 0.5770244789954976 at epoch 5!


  5%|▌         | 5/100 [18:59<5:59:21, 226.96s/it]

Epoch 6 / trn/loss=0.8166
Epoch 6 / train/metric=0.5742


  6%|▌         | 6/100 [22:35<5:50:15, 223.57s/it]

Epoch 6 / val/metric=0.5846
Epoch 7 / trn/loss=0.8137
Epoch 7 / train/metric=0.5720
Epoch 7 / val/metric=0.5341
Best val_metric 0.5341452922650022 at epoch 7!


  7%|▋         | 7/100 [26:29<5:51:34, 226.83s/it]

Epoch 8 / trn/loss=0.8054
Epoch 8 / train/metric=0.5659


  8%|▊         | 8/100 [30:09<5:44:39, 224.78s/it]

Epoch 8 / val/metric=0.5371
Epoch 9 / trn/loss=0.8006
Epoch 9 / train/metric=0.5630


  9%|▉         | 9/100 [33:48<5:37:47, 222.72s/it]

Epoch 9 / val/metric=0.5554
Epoch 10 / trn/loss=0.8003
Epoch 10 / train/metric=0.5630


 10%|█         | 10/100 [37:34<5:35:59, 223.99s/it]

Epoch 10 / val/metric=0.5417
Epoch 11 / trn/loss=0.7953
Epoch 11 / train/metric=0.5593


 11%|█         | 11/100 [41:15<5:30:29, 222.80s/it]

Epoch 11 / val/metric=0.5617
Epoch 12 / trn/loss=0.7977
Epoch 12 / train/metric=0.5609


 12%|█▏        | 12/100 [44:57<5:26:27, 222.59s/it]

Epoch 12 / val/metric=0.5454
Epoch 13 / trn/loss=0.7930
Epoch 13 / train/metric=0.5575


 13%|█▎        | 13/100 [48:39<5:22:38, 222.51s/it]

Epoch 13 / val/metric=0.5700
Epoch 14 / trn/loss=0.7892
Epoch 14 / train/metric=0.5549


 14%|█▍        | 14/100 [52:16<5:16:41, 220.94s/it]

Epoch 14 / val/metric=0.5416
Epoch 15 / trn/loss=0.7887
Epoch 15 / train/metric=0.5536


 15%|█▌        | 15/100 [56:01<5:14:25, 221.95s/it]

Epoch 15 / val/metric=0.6031
Epoch 16 / trn/loss=0.7846
Epoch 16 / train/metric=0.5531


 16%|█▌        | 16/100 [59:38<5:08:43, 220.51s/it]

Epoch 16 / val/metric=0.6574
Epoch 17 / trn/loss=0.7825
Epoch 17 / train/metric=0.5500


 17%|█▋        | 17/100 [1:03:19<5:05:12, 220.63s/it]

Epoch 17 / val/metric=0.5448
Epoch 18 / trn/loss=0.7884
Epoch 18 / train/metric=0.5546


 18%|█▊        | 18/100 [1:06:59<5:01:15, 220.44s/it]

Epoch 18 / val/metric=0.5642
Epoch 19 / trn/loss=0.7783
Epoch 19 / train/metric=0.5460


 19%|█▉        | 19/100 [1:10:41<4:58:13, 220.90s/it]

Epoch 19 / val/metric=0.5752
Epoch 20 / trn/loss=0.7753
Epoch 20 / train/metric=0.5439


 20%|██        | 20/100 [1:14:23<4:54:56, 221.20s/it]

Epoch 20 / val/metric=0.5476
Epoch 21 / trn/loss=0.7803
Epoch 21 / train/metric=0.5484


 21%|██        | 21/100 [1:18:02<4:50:42, 220.79s/it]

Epoch 21 / val/metric=0.5404
Epoch 22 / trn/loss=0.7747
Epoch 22 / train/metric=0.5435


 22%|██▏       | 22/100 [1:21:44<4:47:32, 221.19s/it]

Epoch 22 / val/metric=0.5379
Epoch 23 / trn/loss=0.7749
Epoch 23 / train/metric=0.5438


 23%|██▎       | 23/100 [1:25:24<4:43:03, 220.56s/it]

Epoch 23 / val/metric=0.5540
Epoch 24 / trn/loss=0.7641
Epoch 24 / train/metric=0.5360


 24%|██▍       | 24/100 [1:29:03<4:38:49, 220.13s/it]

Epoch 24 / val/metric=0.5612
Epoch 25 / trn/loss=0.7684
Epoch 25 / train/metric=0.5389


 25%|██▌       | 25/100 [1:32:43<4:35:02, 220.04s/it]

Epoch 25 / val/metric=0.5659
Epoch 26 / trn/loss=0.7594
Epoch 26 / train/metric=0.5316


 26%|██▌       | 26/100 [1:36:27<4:33:08, 221.46s/it]

Epoch 26 / val/metric=0.6712
Epoch 27 / trn/loss=0.7565
Epoch 27 / train/metric=0.5302


                                                     

Epoch 27 / val/metric=0.5372
Not improved for 20 epochs, terminate the train






0,1
best_total_log_loss,▁
lr,▁▁▁▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇▇▇████
train_any_in_metric,▄██▇█▆▄▃▂▃▃▂▃▂▂▃▂▃▂▁▂▂▂▁▁▁▂
train_any_loss,▃▃▅▄▅▅▂▅▃▅▄▂█▄▄▆▄▅▄▄▄▃▃▅▅▄▂▄▃▆▄▄▆▃▄▃▃▅▁▄
train_avg_loss,▅▃▅▄▅▆▃▆▃▅▄▃█▅▅█▅▆▅▅▅▄▄▅▆▄▂▅▃▆▄▃▆▃▇▃▄▆▁▅
train_avg_metric,█▅▅▅▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁
train_bowel_loss,▄▂▁▁▁▁▁▁▁▁▁▄▅▁█▅▁▁▁▁▁▁▄▁▁▂▁▁▁▅▁▁▁▄▁▁▁█▁▄
train_bowel_metric,█▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁
train_extrav_loss,▄▆▆▁▆▁▁█▂▂▅▃▇▄▂▆▃█▃▄▂▁▃▂▅▁▃▄▁▄▂▄▄▄▇▂▄█▁▂
train_extrav_metric,█▆▆▅▅▆▇▅▅▄▄▅▅▄▅▂▃▅▅▄▅▄▄▂▃▃▁

0,1
best_total_log_loss,0.53415
lr,0.00098
train_any_in_metric,0.69596
train_any_loss,2.32557
train_avg_loss,0.86988
train_avg_metric,0.53017
train_bowel_loss,0.02187
train_bowel_metric,0.1571
train_extrav_loss,0.98188
train_extrav_metric,0.55704


In [None]:
#Execute this cell to fininsh the wandb run when you stopped training.

import wandb
try:
    wandb.log({'best_total_log_loss': best_metric})
    wandb.finish()
    
except:
    print('Wandb is already finished!')

Wandb is already finished!


: 