In [1]:
from tqdm import tqdm
import sys
import glob
import gc
import os
sys.path.append('./lib_models')

import pandas as pd
import numpy as np
import scipy as sp
import cv2
from matplotlib import pyplot as plt
import sklearn.metrics
import warnings
import pydicom
import dicomsdl
from joblib import Parallel, delayed
import pickle
import gzip
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from multiprocessing import Pool

import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch import nn
from resnet3d import generate_model
from timm.utils import AverageMeter
import wandb


# Authenticate with Weights & Biases without embedding the secret in the notebook.
# The key is taken from the WANDB_API_KEY environment variable; when it is unset,
# key=None lets wandb fall back to its own credential lookup.
# NOTE(review): an API key was previously hard-coded on this line and has been
# published with the notebook -- it should be revoked and rotated.
wandb.login(key = os.environ.get('WANDB_API_KEY'))
warnings.filterwarnings('ignore', category=UserWarning)
# Make CUDA kernel launches synchronous so errors surface at the offending call.
os.environ['CUDA_LAUNCH_BLOCKING']='1'

  from .autonotebook import tqdm as notebook_tqdm
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjunseonglee[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/junseonglee/.netrc


# Parameters

In [2]:
# Dataset root. The historical location stays the default so existing setups keep
# working; set the RSNA_BASE_PATH environment variable to run on another machine.
BASE_PATH  = os.environ.get(
    'RSNA_BASE_PATH',
    '/home/junseonglee/Desktop/01_codes/inputs/rsna-2023-abdominal-trauma-detection')
TRAIN_PATH = f'{BASE_PATH}/train_images'
DATA_PATH = f'{BASE_PATH}/3d_preprocessed'

# exist_ok avoids the check-then-create race and also creates missing parents.
os.makedirs(DATA_PATH, exist_ok=True)

RESOL = 256               # edge length of the cubic volume fed to the network
BATCH_SIZE = 8
LR = 0.001
N_EPOCHS = 30
N_FOLDS  = 5
N_PREPROCESS_CHUNKS = 24  # number of joblib workers / index chunks for preprocessing
N_WORKERS = 12            # DataLoader worker processes
train_df = pd.read_csv(f'{BASE_PATH}/train.csv')
train_df = train_df.sort_values(by=['patient_id'])


wandb_config = {
    'RESOL': RESOL,
    'BATCH_SIZE': BATCH_SIZE,
    'LR': LR,
    'N_EPOCHS': N_EPOCHS,
}

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#DEVICE = 'cpu'

# Data split

In [3]:
# Patient-level labels, sorted by patient_id before folds are assigned.
train_df = pd.read_csv(f'{BASE_PATH}/train.csv')
train_meta = pd.read_csv(f'{BASE_PATH}/train_series_meta.csv')
train_df = train_df.sort_values(by=['patient_id'])

TRAIN_PATH = BASE_PATH + "/train_images/"
n_chunk = 8
patients = os.listdir(TRAIN_PATH)
n_patients = len(patients)
rng_patients = np.linspace(0, n_patients+1, n_chunk+1, dtype = int)
# One entry per <patient>/<series> directory. glob order is filesystem-dependent,
# so sort it to keep the row order of the series table reproducible.
patients_cts = sorted(glob.glob(f'{TRAIN_PATH}/*/*'))
n_cts = len(patients_cts)
patients_cts_arr = np.zeros((n_cts, 2), int)
data_paths=[]
for i in range(0, n_cts):
    patient, ct = patients_cts[i].split('/')[-2:]
    patients_cts_arr[i] = int(patient), int(ct)
    # Cache file of the preprocessed volume for this (patient, series) pair.
    data_paths.append(f'{DATA_PATH}/{patients_cts_arr[i,0]}_{patients_cts_arr[i,1]}.pkl')
TRAIN_IMG_PATH = BASE_PATH + '/processed' 

#Generate tables for training
train_meta_df = pd.DataFrame(patients_cts_arr, columns = ['patient_id', 'series'])

#5-fold splitting
labels = train_df[['bowel_healthy','bowel_injury',
                    'extravasation_healthy','extravasation_injury',
                    'kidney_healthy','kidney_low','kidney_high',
                    'liver_healthy','liver_low','liver_high',
                    'spleen_healthy','spleen_low','spleen_high',
                    'any_injury']].to_numpy()

mskf = MultilabelStratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=0)
# The splitter yields POSITIONAL indices into `labels`. They are written into a
# plain array and attached in one step, because train_df keeps its pre-sort index
# labels and `train_df['fold'][pos] = ...` would (a) address rows by label rather
# than position and (b) be a chained assignment that pandas may apply to a copy.
fold_by_position = np.zeros(len(train_df), dtype=int)
for counter, (train_index, test_index) in enumerate(mskf.split(np.ones(len(train_df)), labels)):
    fold_by_position[test_index] = counter
train_df['fold'] = fold_by_position

# Every series inherits the labels and fold of its patient.
train_meta_df = train_meta_df.join(train_df.set_index('patient_id'), on='patient_id')
train_meta_df['path']=data_paths
train_meta_df.to_csv(f'{BASE_PATH}/train_meta.csv', index = False)
np.unique(train_df['fold'].to_numpy(), return_counts = True)


(array([0, 1, 2, 3, 4]), array([630, 629, 630, 629, 629]))

# Dataset

In [4]:
def compress(name, data):
    """Pickle `data` and write it gzip-compressed to the file path `name`."""
    gz_file = gzip.open(name, 'wb')
    with gz_file:
        pickle.Pickler(gz_file).dump(data)

def decompress(name):
    """Read the gzip-compressed pickle stored at `name` and return the object.

    Unpickling executes code embedded in the file, so only use this on files
    written by this notebook.
    """
    with gzip.open(name, 'rb') as gz_file:
        return pickle.load(gz_file)

In [5]:
def standardize_pixel_array(dcm: "pydicom.dataset.FileDataset") -> "tuple[np.ndarray, int, int, np.dtype]":
    """
    Return the pixel values of a DICOM slice, sign-extended when they are signed.

    Source : https://www.kaggle.com/competitions/rsna-2023-abdominal-trauma-detection/discussion/427217

    When PixelRepresentation == 1 the values are shifted left and back right by
    (BitsAllocated - BitsStored) bits so the stored sign bit is propagated. No
    modality LUT / rescale is applied here.

    Parameters
    ----------
    dcm : object exposing `pixel_array`, `PixelRepresentation`, `BitsAllocated`
        and `BitsStored` (a pydicom dataset).

    Returns
    -------
    (pixel_array, pixel_rep, bit_shift, dtype)
        Always a 4-tuple, so callers can unpack it unconditionally. For
        PixelRepresentation != 1 the array is returned unchanged and `bit_shift`
        is 0 (previously a bare 0 was returned in that case, which made the
        caller's 4-way unpacking raise).
    """
    pixel_array = dcm.pixel_array
    pixel_rep = dcm.PixelRepresentation
    dtype = pixel_array.dtype
    bit_shift = 0
    if pixel_rep == 1:
        bit_shift = dcm.BitsAllocated - dcm.BitsStored
        # Left shift moves the stored sign bit into the dtype's sign bit; the
        # right shift on the signed dtype then sign-extends it back down.
        pixel_array = (pixel_array << bit_shift).astype(dtype) >> bit_shift

    return pixel_array, pixel_rep, bit_shift, dtype

In [6]:
# Read each slice and stack them to make 3d data
def process_3d(save_path, data_path = TRAIN_PATH):
    """
    Read every slice of one CT series, stack them into a RESOL^3 volume, cache it.

    Parameters
    ----------
    save_path : str
        Destination of the cache file. Its basename must be
        '<patient>_<series>.<3-char extension>'; both ids are parsed from it and
        used to locate the slices under `data_path`.
    data_path : str
        Root directory that contains '<patient>/<series>/*.dcm'.

    Returns
    -------
    np.ndarray
        float32 array of shape (RESOL, RESOL, RESOL), min-max scaled, inverted
        for MONOCHROME1 data, then standardized per sample. The same array is
        written to `save_path` through `compress`.

    Raises
    ------
    FileNotFoundError
        If the series directory contains no '.dcm' file.
    """
    stem = save_path.split('/')[-1][:-4]
    id_parts = stem.split('_')
    patient, study = int(id_parts[0]), int(id_parts[1])

    slice_files = sorted(glob.glob(data_path + f'/{patient}/{study}/*.dcm'))
    if not slice_files:
        raise FileNotFoundError(f'No DICOM slices found in {data_path}/{patient}/{study}')

    # Only PhotometricInterpretation is needed from pydicom, so read the header
    # of the first parseable slice and skip pixel decoding. If no header can be
    # read, the volume is treated as not MONOCHROME1.
    photometric = None
    for f in slice_files:
        try:
            photometric = pydicom.dcmread(f, stop_before_pixels=True).PhotometricInterpretation
            break
        except Exception:
            continue

    # Pixel data is decoded with dicomsdl (stored values, no rescale).
    # Slices are keyed by the negated integer file name, which fixes the
    # stacking order used below.
    imgs = {}
    for f in slice_files:
        img = dicomsdl.open(f).pixelData(storedvalue=True).astype(float)
        pos_z = -int((f.split('/')[-1])[:-4])
        imgs[pos_z] = img

    # Keep at most RESOL evenly spaced slices; a set makes the membership test O(1).
    sample_z = set(np.linspace(0, len(imgs)-1, RESOL, dtype=int).tolist())

    imgs_3d = [cv2.resize(imgs[k], (RESOL, RESOL))[None]
               for i, k in enumerate(sorted(imgs.keys())) if i in sample_z]
    imgs_3d = np.vstack(imgs_3d)

    # Second pass resizes every plane along the last axis so the slice axis is
    # brought to RESOL as well (series with fewer than RESOL slices get stretched).
    nu = np.zeros((RESOL, RESOL, RESOL))
    for i in range(0, len(imgs_3d[0,0])):
        nu[:,:,i] = cv2.resize(imgs_3d[:,:,i], (RESOL, RESOL))
    imgs_3d  = nu

    # Min-max scale to [0, 1]; a constant volume maps to zeros instead of NaN.
    v_min, v_max = imgs_3d.min(), imgs_3d.max()
    if v_max > v_min:
        imgs_3d = (imgs_3d - v_min) / (v_max - v_min)
    else:
        imgs_3d = np.zeros_like(imgs_3d)

    if photometric == "MONOCHROME1":
        imgs_3d = 1.0 - imgs_3d

    #Samplewise standardization to deal with the variety of the test dataset.
    std = np.std(imgs_3d)
    avg = np.average(imgs_3d)
    # Guard against std == 0 so a degenerate volume is never cached as NaN.
    imgs_3d = (imgs_3d-avg)/std if std > 0 else imgs_3d-avg
    imgs_3d = imgs_3d.astype(np.float32)

    compress(save_path, imgs_3d)

    del imgs, nu
    gc.collect()

    return imgs_3d

In [7]:
# Preprocess dataset
# Boundaries that cut the rows of train_meta_df into N_PREPROCESS_CHUNKS contiguous chunks.
rng_samples = np.linspace(0, len(train_meta_df), N_PREPROCESS_CHUNKS+1, dtype = int)
def process_3d_wrapper(process_ind, rng_samples = rng_samples, train_meta_df = train_meta_df):
    """Preprocess every series of chunk `process_ind` whose cache file does not exist yet."""
    first_row, end_row = rng_samples[process_ind], rng_samples[process_ind+1]
    for row_pos in tqdm(range(first_row, end_row)):
        cache_file = train_meta_df.iloc[row_pos]['path']
        if os.path.isfile(cache_file):
            continue
        process_3d(cache_file)

In [8]:
%%time
# One delayed call per chunk index, executed by N_PREPROCESS_CHUNKS joblib workers.
chunk_jobs = [delayed(process_3d_wrapper)(chunk_id) for chunk_id in range(N_PREPROCESS_CHUNKS)]
Parallel(n_jobs = N_PREPROCESS_CHUNKS)(chunk_jobs)

100%|██████████| 196/196 [00:00<00:00, 22263.60it/s]
100%|██████████| 196/196 [00:00<00:00, 30610.80it/s]
100%|██████████| 196/196 [00:00<00:00, 22318.00it/s]
100%|██████████| 197/197 [00:00<00:00, 22893.02it/s]
100%|██████████| 196/196 [00:00<00:00, 37052.49it/s]
100%|██████████| 196/196 [00:00<00:00, 22747.82it/s]
100%|██████████| 196/196 [00:00<00:00, 18958.62it/s]
100%|██████████| 196/196 [00:00<00:00, 27403.70it/s]
100%|██████████| 197/197 [00:00<00:00, 19213.53it/s]
100%|██████████| 197/197 [00:00<00:00, 37535.91it/s]
100%|██████████| 196/196 [00:00<00:00, 18969.12it/s]
100%|██████████| 196/196 [00:00<00:00, 20773.85it/s]
100%|██████████| 196/196 [00:00<00:00, 18790.48it/s]
100%|██████████| 196/196 [00:00<00:00, 37677.42it/s]
100%|██████████| 196/196 [00:00<00:00, 36328.76it/s]
100%|██████████| 196/196 [00:00<00:00, 37348.76it/s]
100%|██████████| 197/197 [00:00<00:00, 19380.73it/s]
100%|██████████| 196/196 [00:00<00:00, 36785.56it/s]
100%|██████████| 197/197 [00:00<00:00, 37371.2

CPU times: user 95.2 ms, sys: 228 ms, total: 324 ms
Wall time: 637 ms


100%|██████████| 197/197 [00:00<00:00, 37789.98it/s]


[None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None]

In [9]:
class AbdominalCTDataset(Dataset):
    """
    Dataset over the rows of a series-level metadata frame.

    Each item is `(volume, label)`:
      * volume -- float32 tensor of shape (1, RESOL, RESOL, RESOL), loaded from the
        cache file in the row's 'path' column (rebuilt with `process_3d` on failure);
      * label  -- float32 tensor holding the 14 columns of LABEL_COLUMNS, in order.

    Parameters
    ----------
    meta_df : pandas.DataFrame
        Must provide a 'path' column and every column in LABEL_COLUMNS.
    is_train : bool
        Stored on the instance; currently not used by `__getitem__` (the rotation
        augmentation it used to gate is disabled).
    """

    # Column order defines the layout of the label vector.
    LABEL_COLUMNS = ['bowel_healthy','bowel_injury',
                     'extravasation_healthy','extravasation_injury',
                     'kidney_healthy','kidney_low','kidney_high',
                     'liver_healthy','liver_low','liver_high',
                     'spleen_healthy','spleen_low','spleen_high', 'any_injury']

    def __init__(self, meta_df, is_train = True):
        self.meta_df = meta_df
        self.is_train = is_train
        
    def __len__(self):
        return len(self.meta_df)
    
    def __getitem__(self, idx):
        row = self.meta_df.iloc[idx]
        label = row[self.LABEL_COLUMNS]

        # A cache file can be missing or unreadable (e.g. partially written by a
        # parallel preprocessing worker); rebuild it from the DICOM slices then.
        # `Exception` rather than a bare except, so KeyboardInterrupt/SystemExit
        # stop the run instead of triggering a full re-preprocess.
        try:
            data_3d = decompress(row['path'])
        except Exception:
            data_3d = process_3d(row['path'])           

        data_3d = data_3d.reshape(1, RESOL, RESOL, RESOL).astype(np.float32)  # channel, 3D 
        data_3d = torch.from_numpy(data_3d)

        label = label.to_numpy().astype(np.float32)
        label = torch.from_numpy(label)
        return data_3d, label        

# Smoke test: load the first sample, show its label vector, then free everything.
probe_dataset = AbdominalCTDataset(train_meta_df)
probe_volume, probe_label = probe_dataset[0]
print(probe_label)

del probe_dataset, probe_volume, probe_label
gc.collect()

tensor([1., 0., 1., 0., 1., 0., 0., 1., 0., 0., 1., 0., 0., 0.])


46

In [10]:
'''
#normalization parameter
train_dataset = AbdominalCTDataset(train_meta_df)
data_3d, label = train_dataset[0]

avgs = np.zeros(len(train_dataset))
stds = np.zeros(len(train_dataset))
for i in tqdm(range(0, len(train_dataset))):
    data_3d, label = train_dataset[i]
    data_3d = data_3d.numpy()
    avgs[i] = np.average(data_3d)
    stds[i] = np.std(data_3d)
print(np.average(avgs))
print(np.average(stds))    

del train_dataset, data_3d, label, avgs, stds
gc.collect()
'''

'\n#normalization parameter\ntrain_dataset = AbdominalCTDataset(train_meta_df)\ndata_3d, label = train_dataset[0]\n\navgs = np.zeros(len(train_dataset))\nstds = np.zeros(len(train_dataset))\nfor i in tqdm(range(0, len(train_dataset))):\n    data_3d, label = train_dataset[i]\n    data_3d = data_3d.numpy()\n    avgs[i] = np.average(data_3d)\n    stds[i] = np.std(data_3d)\nprint(np.average(avgs))\nprint(np.average(stds))    \n\ndel train_dataset, data_3d, label, avgs, stds\ngc.collect()\n'

# Model

In [11]:
class AbdominalClassifier(nn.Module):
    """
    3D-ResNet backbone with one linear head per organ plus a derived any-injury output.

    `forward` returns `(labels, any_in)`:
      * labels -- raw logits of the five heads concatenated along dim 1 in the order
        bowel(2), extravasation(2), kidney(3), liver(3), spleen(3);
      * any_in -- two columns `[1 - p, p]`, where p is the maximum over the five
        heads of `1 - softmax(head)[:, 0]`.
    """

    def __init__(self, model_depth, device = DEVICE):
        super().__init__()
        self.device = device
        self.resnet3d = generate_model(model_depth = model_depth, n_input_channels = 1)
        self.flatten  = nn.Flatten()
        self.dropout  = nn.Dropout(p=0.5)
        self.softmax  = nn.Softmax(dim=1)
        # Width of the concatenated, flattened backbone features.
        # NOTE(review): presumably tied to RESOL and the backbone depth -- confirm before changing either.
        feature_dim   = 56832
        self.fc_bowel = nn.Linear(feature_dim, 2)
        self.fc_extrav= nn.Linear(feature_dim, 2)
        self.fc_kidney= nn.Linear(feature_dim, 3)
        self.fc_liver = nn.Linear(feature_dim, 3)
        self.fc_spleen= nn.Linear(feature_dim, 3)
        
        # Kernel of 5 over the 5 per-head injury probabilities -> one value per sample.
        self.maxpool  = nn.MaxPool1d(5, 1)

    def forward(self, x):
        # The backbone output is indexed 0..3 and concatenated, so it is expected
        # to be a mutable sequence of feature maps.
        stage_maps = self.resnet3d(x)
        stage_maps[:4] = [self.flatten(stage_maps[k]) for k in range(4)]
        features = self.dropout(torch.cat(stage_maps, axis = 1))

        bowel  = self.fc_bowel(features)
        extrav = self.fc_extrav(features)
        kidney = self.fc_kidney(features)
        liver  = self.fc_liver(features)
        spleen = self.fc_spleen(features)
        head_logits = [bowel, extrav, kidney, liver, spleen]

        labels = torch.cat(head_logits, dim = 1)

        # Column 0 of every head is its "healthy" class; 1 - P(healthy) is the
        # probability of any injury level for that head.
        injured = [1 - self.softmax(logits)[:, 0:1] for logits in head_logits]
        any_in = self.maxpool(torch.cat(injured, dim = 1))
        any_in = torch.cat([1 - any_in, any_in], dim = 1)

        return labels, any_in

In [12]:
# Throwaway instance (model_depth=10) used only to report the parameter count.
model = AbdominalClassifier(10)

def get_n_params(model):
    """Return the total number of scalar elements over all of `model.parameters()`."""
    return sum(param.numel() for param in model.parameters())

# Print the parameter count, then drop the throwaway model and reclaim its memory.
print(get_n_params(model))
del model
gc.collect()

15094349


0

# Train

In [13]:
model = AbdominalClassifier(10)
model.to(DEVICE)


def _weighted_ce(class_weights):
    """Return a CrossEntropyLoss that owns a fresh weight tensor on DEVICE.

    The tensor is built from a Python list rather than `torch.from_numpy` on a
    shared, later-mutated array: on a CPU device `from_numpy(...).to(DEVICE)`
    shares memory with the array, so re-using the array for the next criterion
    silently rewrote the weights of the previous one. float64 matches the dtype
    the numpy-based weights had.
    """
    return nn.CrossEntropyLoss(
        weight = torch.tensor(class_weights, dtype=torch.float64, device=DEVICE))


# Class weights per head: healthy / injury for the binary heads,
# healthy / low / high for the organ heads.
crit_bowel  = _weighted_ce([1, 2])
crit_extrav = _weighted_ce([1, 6])
crit_any    = _weighted_ce([1, 6])

crit_kidney = _weighted_ce([1, 2, 4])
crit_liver  = _weighted_ce([1, 2, 4])
crit_spleen = _weighted_ce([1, 2, 4])


In [14]:
def normalize_to_one(tensor):
    """Divide `tensor` in place by its sums over dim 1 and return the same tensor."""
    row_totals = torch.sum(tensor, 1)
    tensor.div_(row_totals.unsqueeze(1))
    return tensor

def normalize_arr_to_one(arr):
    """Divide every column of the 2-D array `arr` in place by the row sums; returns `arr`."""
    row_totals = np.sum(arr, axis = 1)
    for col in range(len(arr[0])):
        np.divide(arr[:, col], row_totals, out=arr[:, col])
    return arr

def apply_softmax_to_labels(X_out):
    """Replace, in place, each head's logits in `X_out` by renormalized softmax probabilities.

    Head column ranges: [0:2], [2:4], [4:7], [7:10], [10:13]. Returns `X_out`.
    """
    softmax = nn.Softmax(dim=1)

    for start, stop in ((0, 2), (2, 4), (4, 7), (7, 10), (10, 13)):
        X_out[:, start:stop] = normalize_to_one(softmax(X_out[:, start:stop]))

    return X_out

def apply_normalization_to_labels(X_out):
    """Renormalize, in place, each head's column group of the array `X_out` to sum to one per row.

    Head column ranges: [0:2], [2:4], [4:7], [7:10], [10:13]. Returns `X_out`.
    """
    for start, stop in ((0, 2), (2, 4), (4, 7), (7, 10), (10, 13)):
        X_out[:, start:stop] = normalize_arr_to_one(X_out[:, start:stop])

    return X_out

def calc_log_loss(ys, X_outs, bowel_weights, extrav_weights, kidney_weights, \
                  liver_weights, spleen_weights, any_in_weights):
    """Mean of the six sample-weighted log losses (five organ heads plus any-injury).

    `ys` and `X_outs` are sliced by column range per head:
    [0:2], [2:4], [4:7], [7:10], [10:13] and [13:15].
    """
    head_specs = (
        (slice(0, 2),   bowel_weights),
        (slice(2, 4),   extrav_weights),
        (slice(4, 7),   kidney_weights),
        (slice(7, 10),  liver_weights),
        (slice(10, 13), spleen_weights),
        (slice(13, 15), any_in_weights),
    )
    # Explicit left-to-right accumulation keeps the floating-point summation order.
    (first_cols, first_weights), remaining = head_specs[0], head_specs[1:]
    total = sklearn.metrics.log_loss(ys[:, first_cols], X_outs[:, first_cols],
                                     sample_weight = first_weights)
    for cols, sample_weights in remaining:
        total = total + sklearn.metrics.log_loss(ys[:, cols], X_outs[:, cols],
                                                 sample_weight = sample_weights)
    loss = total / 6
    return loss
    

def calculate_score(X_outs, ys):
    """Weighted multi-head log loss of predictions `X_outs` against one-hot targets `ys`.

    Both arrays are cast to float64. Per-sample weights are derived from the
    targets (1 for the first class of a head, larger for injury classes) and
    passed on to `calc_log_loss`.
    """
    preds   = X_outs.astype(np.float64)
    targets = ys.astype(np.float64)

    w_bowel  = targets[:,0] + 2*targets[:,1]
    w_extrav = targets[:,2] + 6*targets[:,3]
    w_kidney = targets[:,4] + 2*targets[:,5] + 4*targets[:,6]
    w_liver  = targets[:,7] + 2*targets[:,8] + 4*targets[:,9]
    w_spleen = targets[:,10] + 2*targets[:,11] + 4*targets[:,12]
    w_any    = targets[:,13] + 6*targets[:,14]

    return calc_log_loss(targets, preds, w_bowel, w_extrav, w_kidney,
                         w_liver, w_spleen, w_any)

def calculate_loss(X_out, X_any, y):
    """Training loss: the six weighted cross-entropy terms, each scaled by its batch weight share.

    X_out -- concatenated head logits (columns [0:2], [2:4], [4:7], [7:10], [10:13]);
    X_any -- two-column any-injury output of the model;
    y     -- targets with the same 13 head columns plus `any_injury` in column 13.

    A head's share is the sum of its per-sample weights in this batch divided by
    the sum over all six heads.
    """
    n_rows = X_out.shape[0]
    y_np = y.clone().detach().cpu().numpy()

    # Per-sample weights of every head, derived from the targets.
    w_bowel  = y_np[:,0] + 2*y_np[:,1]
    w_extrav = y_np[:,2] + 6*y_np[:,3]
    w_kidney = y_np[:,4] + 2*y_np[:,5] + 4*y_np[:,6]
    w_liver  = y_np[:,7] + 2*y_np[:,8] + 4*y_np[:,9]
    w_spleen = y_np[:,10] + 2*y_np[:,11] + 4*y_np[:,12]
    w_any    = (np.ones(np.shape(y_np[:,13])) - y_np[:,13]) + 6*y_np[:,13]

    head_sums = [np.sum(w) for w in (w_bowel, w_extrav, w_kidney, w_liver, w_spleen, w_any)]
    # Left-to-right accumulation, same order as a + b + c + d + e + f.
    total_weight = head_sums[0]
    for head_sum in head_sums[1:]:
        total_weight = total_weight + head_sum
    ratios = [head_sum / total_weight for head_sum in head_sums]

    # Two-column [not injured, injured] target for the any-injury term.
    any_target = torch.cat([torch.ones(n_rows, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1)

    terms = (
        (crit_bowel,  X_out[:,:2],    y[:,:2]),
        (crit_extrav, X_out[:,2:4],   y[:,2:4]),
        (crit_kidney, X_out[:,4:7],   y[:,4:7]),
        (crit_liver,  X_out[:,7:10],  y[:,7:10]),
        (crit_spleen, X_out[:,10:13], y[:,10:13]),
        (crit_any,    X_any,          any_target),
    )

    loss = None
    for (criterion, logits, target), ratio in zip(terms, ratios):
        weighted_term = criterion(logits, target) * ratio
        if loss is None:
            loss = weighted_term
        else:
            loss += weighted_term

    return loss

In [15]:
if __name__ == '__main__':
    # Fold 0 is held out for validation; all other folds are used for training.
    train_dataset = AbdominalCTDataset(train_meta_df[train_meta_df['fold']!=0], is_train = True)
    valid_dataset = AbdominalCTDataset(train_meta_df[train_meta_df['fold']==0], is_train = False)

    train_loader = DataLoader(dataset = train_dataset, shuffle = True, batch_size = BATCH_SIZE, pin_memory = False, 
                            num_workers = N_WORKERS, drop_last = False)

    valid_loader = DataLoader(dataset = valid_dataset, shuffle = False, batch_size = BATCH_SIZE, pin_memory = False, 
                            num_workers = N_WORKERS, drop_last = False)          
    

    optimizer = torch.optim.AdamW(model.parameters(), lr = LR)
    ttl_iters = N_EPOCHS * len(train_loader)
    # OneCycleLR is stepped once per batch (see the training loop below).
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LR, 
                                                    steps_per_epoch=len(train_loader), epochs = N_EPOCHS)

    scaler = torch.cuda.amp.GradScaler(enabled=True)
    # Initialised to 100 so that the first epoch's metric always counts as an improvement.
    val_metrics = np.ones(N_EPOCHS)*100

    for epoch in range(0, N_EPOCHS):
        train_meters = {'loss': AverageMeter()}
        val_meters   = {'loss': AverageMeter()}
        
        model.train()
        pbar = tqdm(train_loader, leave=False)  

        X_outs=[]
        ys=[]

        for X, y in pbar:
            batch_size = X.shape[0]
            X, y = X.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            # Only the forward pass and the loss run under autocast; backward and
            # the optimizer/scheduler steps are issued outside of it.
            with torch.cuda.amp.autocast(enabled=True):  
                X_out, X_any  = model(X)
                loss = calculate_loss(X_out, X_any, y)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scheduler.step()
            scaler.update()          

            #Metric calculation
            y_any = torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1)    
            X_out = apply_softmax_to_labels(X_out).detach().to('cpu').numpy()
            X_any = X_any.detach().to('cpu').numpy()
            X_out = np.hstack([X_out, X_any])
            X_outs.append(X_out)

            # Replace the single any_injury column by the two-column y_any.
            y     = y.to('cpu').numpy()[:,:-1]
            y_any = y_any.to('cpu').numpy()
            y     = np.hstack([y, y_any])
            ys.append(y)

            trn_loss = loss.item()      
            train_meters['loss'].update(trn_loss, n=X.size(0))     
            pbar.set_description(f'Train loss: {trn_loss}')   
        print('Epoch {:d} / trn/loss={:.4f}'.format(epoch+1, train_meters['loss'].avg))    

        X_outs = np.vstack(X_outs) 
        ys     = np.vstack(ys)
        metric = calculate_score(X_outs, ys)                 
        print('Epoch {:d} / train/metric={:.4f}'.format(epoch+1, metric))   

        del X, X_outs, y, ys, X_any
        gc.collect()

        X_outs=[]
        ys=[]
        model.eval()
        for X, y in tqdm(valid_loader, leave=False):        
            batch_size = X.shape[0]        
            X, y = X.to(DEVICE), y.to(DEVICE)
                 
            with torch.cuda.amp.autocast(enabled=True):                
                with torch.no_grad():                 
                    X_out, X_any = model(X)                                           
                    y_any = torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1)              
                              
                    X_out = apply_softmax_to_labels(X_out).to('cpu').numpy()

                    X_any = X_any.to('cpu').numpy()
                    X_out = np.hstack([X_out, X_any])
                    X_outs.append(X_out)

                    y     = y.to('cpu').numpy()[:,:-1]
                    y_any = y_any.to('cpu').numpy()
                    y     = np.hstack([y, y_any])
                    ys.append(y)

        X_outs = np.vstack(X_outs) 
        ys     = np.vstack(ys)
        metric = calculate_score(X_outs, ys)                
        print('Epoch {:d} / val/metric={:.4f}'.format(epoch+1, metric))   
        
        #Save the best model    
        if(metric < np.min(val_metrics)):
            # exist_ok tolerates the directory already being there, while real
            # failures (e.g. permissions) are no longer silently swallowed.
            os.makedirs(f'{BASE_PATH}/weights', exist_ok=True)
            best_metric = metric
            print(f'Best val_metric {best_metric} at epoch {epoch+1}!')
            torch.save(model, f'{BASE_PATH}/weights/best.pt')    
        val_metrics[epoch] = metric
        
        del X, X_outs, y, ys, X_any
        gc.collect()        

                                                                                  

Epoch 1 / trn/loss=1.1237
Best scales: [0.92219788 0.44487828 0.77525975 0.67475441 0.67475441 0.69858797
 0.69858797 0.51114335 0.        ]
Epoch 1 / train/metric=0.6871


                                                 

Best scales: [ 8.21434358  4.24757155  0.1         0.13200884  0.1         0.22219469
  2.52353917 19.56398344  0.        ]
Epoch 1 / val/metric=0.5686
Best val_metric 0.5685991180104767 at epoch 1!


                                                                                  

Epoch 2 / trn/loss=1.1402
Best scales: [0.95477161 0.42970047 0.92219788 0.67475441 0.67475441 0.49370479
 0.69858797 0.51114335 0.        ]
Epoch 2 / train/metric=0.7086


                                                 

Best scales: [ 5.60716994  0.33700643 12.03377841  1.49926843  0.92219788 68.26071834
  0.54789012  0.83099419  0.        ]
Epoch 2 / val/metric=0.5644
Best val_metric 0.5643671367117412 at epoch 2!


                                                                                  

Epoch 3 / trn/loss=1.2029
Best scales: [0.65173396 0.38720388 0.89073546 0.65173396 0.6294989  0.58727866
 0.83099419 0.54789012 0.        ]
Epoch 3 / train/metric=0.7705


                                                 

Best scales: [  0.89073546   0.273644     0.1          2.70495973 100.
  20.25501939   0.23004301   0.56724261   0.        ]
Epoch 3 / val/metric=0.5503
Best val_metric 0.5503281731331932 at epoch 3!


                                                                                  

Epoch 4 / trn/loss=1.2187
Best scales: [0.6294989  0.42970047 0.77525975 0.60802243 0.65173396 0.52919787
 0.72326339 0.54789012 0.        ]
Epoch 4 / train/metric=0.7996


                                                 

Best scales: [100.           0.26430815   0.20729218   2.19638537   0.25529081
   0.29331663   3.57078596   1.0234114    0.        ]
Epoch 4 / val/metric=0.6046


                                                                                  

Epoch 5 / trn/loss=1.1708
Best scales: [0.74881039 0.41504048 0.69858797 0.54789012 0.69858797 0.49370479
 0.67475441 0.6294989  0.        ]
Epoch 5 / train/metric=0.7439


                                                 

Best scales: [2.70495973 0.67475441 0.18679136 0.1        0.67475441 1.44811823
 1.60705282 1.60705282 0.        ]
Epoch 5 / val/metric=0.5964


                                                                                  

Epoch 6 / trn/loss=1.0704
Best scales: [0.77525975 0.42970047 0.92219788 0.74881039 0.74881039 0.6294989
 0.65173396 0.51114335 0.        ]
Epoch 6 / train/metric=0.6523


                                                 

Best scales: [7.66341087 1.35099352 0.4605922  0.56724261 0.69858797 1.3987131
 1.0969858  1.05956018 0.        ]
Epoch 6 / val/metric=0.5790


                                                                                  

Epoch 7 / trn/loss=1.0188
Best scales: [0.95477161 0.41504048 0.77525975 0.67475441 0.74881039 0.6294989
 0.67475441 0.49370479 0.        ]
Epoch 7 / train/metric=0.5945


                                                 

Best scales: [2.61267523 3.21764175 0.67475441 0.83099419 1.84642494 2.19638537
 0.6294989  0.41504048 0.        ]
Epoch 7 / val/metric=0.5540


                                                                                  

Epoch 8 / trn/loss=1.0057
Best scales: [0.92219788 0.44487828 0.92219788 0.74881039 0.89073546 0.67475441
 0.60802243 0.41504048 0.        ]
Epoch 8 / train/metric=0.5787


                                                 

Best scales: [0.32550886 0.20022004 0.51114335 0.4605922  0.56724261 0.92219788
 0.34891012 0.31440355 0.        ]
Epoch 8 / val/metric=0.5363
Best val_metric 0.5363354210430628 at epoch 8!


                                                                                  

Epoch 9 / trn/loss=0.9914
Best scales: [0.92219788 0.38720388 0.92219788 0.80264335 0.83099419 0.72326339
 0.6294989  0.49370479 0.        ]
Epoch 9 / train/metric=0.5653


                                                 

Best scales: [0.69858797 0.28330961 0.74881039 0.49370479 0.37399373 0.36123427
 0.37399373 0.273644   0.        ]
Epoch 9 / val/metric=0.5434


                                                                                  

Epoch 10 / trn/loss=0.9756
Best scales: [0.89073546 0.38720388 0.95477161 0.83099419 0.95477161 0.77525975
 0.6294989  0.4605922  0.        ]
Epoch 10 / train/metric=0.5544


                                                 

Best scales: [0.80264335 0.47686117 0.58727866 0.47686117 1.17584955 0.83099419
 0.20729218 0.25529081 0.        ]
Epoch 10 / val/metric=0.5418


                                                                                  

Epoch 11 / trn/loss=0.9806
Best scales: [0.95477161 0.42970047 0.9884959  0.74881039 0.92219788 0.77525975
 0.56724261 0.42970047 0.        ]
Epoch 11 / train/metric=0.5574


                                                 

Best scales: [0.83099419 0.14649714 0.67475441 0.77525975 0.83099419 0.67475441
 0.36123427 0.60802243 0.        ]
Epoch 11 / val/metric=0.5359
Best val_metric 0.5358745968276217 at epoch 11!


                                                                                  

Epoch 12 / trn/loss=0.9733
Best scales: [0.89073546 0.38720388 0.92219788 0.80264335 0.89073546 0.72326339
 0.60802243 0.42970047 0.        ]
Epoch 12 / train/metric=0.5519


                                                 

Best scales: [1.05956018 0.44487828 0.47686117 0.58727866 0.74881039 1.26038293
 0.6294989  0.58727866 0.        ]
Epoch 12 / val/metric=0.5465


                                                                                  

Epoch 13 / trn/loss=0.9580
Best scales: [0.9884959  0.40088063 0.95477161 0.83099419 0.89073546 0.80264335
 0.60802243 0.4605922  0.        ]
Epoch 13 / train/metric=0.5443


                                                 

Best scales: [1.21738273 0.18041864 0.95477161 0.49370479 1.30490198 2.27396575
 0.69858797 0.37399373 0.        ]
Epoch 13 / val/metric=0.5452


                                                                                  

Epoch 14 / trn/loss=0.9526
Best scales: [0.9884959  0.4605922  0.95477161 0.77525975 0.80264335 0.69858797
 0.60802243 0.44487828 0.        ]
Epoch 14 / train/metric=0.5398


                                                 

Best scales: [0.54789012 0.37399373 1.49926843 0.6294989  1.17584955 1.13573336
 0.44487828 0.37399373 0.        ]
Epoch 14 / val/metric=0.5451


                                                                                  

Epoch 15 / trn/loss=0.9331
Best scales: [0.95477161 0.44487828 0.92219788 0.83099419 0.83099419 0.67475441
 0.60802243 0.4605922  0.        ]
Epoch 15 / train/metric=0.5275


                                                 

Best scales: [0.74881039 0.24658111 1.17584955 1.05956018 0.6294989  0.80264335
 0.34891012 0.4605922  0.        ]
Epoch 15 / val/metric=0.5386


                                                                                  

Epoch 16 / trn/loss=0.9207
Best scales: [0.9884959  0.40088063 0.95477161 0.86034644 0.92219788 0.83099419
 0.60802243 0.4605922  0.        ]
Epoch 16 / train/metric=0.5183


                                                 

Best scales: [0.80264335 0.21461412 0.40088063 0.42970047 0.9884959  0.60802243
 0.95477161 0.52919787 0.        ]
Epoch 16 / val/metric=0.5431


                                                                                  

Epoch 17 / trn/loss=0.9050
Best scales: [0.89073546 0.38720388 0.83099419 0.77525975 0.80264335 0.74881039
 0.67475441 0.49370479 0.        ]
Epoch 17 / train/metric=0.5062


                                                 

Best scales: [0.80264335 0.32550886 0.89073546 0.58727866 0.77525975 0.95477161
 0.58727866 0.4605922  0.        ]
Epoch 17 / val/metric=0.5558


                                                                                  

Epoch 18 / trn/loss=0.8817
Best scales: [0.89073546 0.4605922  0.86034644 0.77525975 0.83099419 0.74881039
 0.6294989  0.4605922  0.        ]
Epoch 18 / train/metric=0.4887


                                                 

Best scales: [1.44811823 0.19338918 1.05956018 0.49370479 0.23816856 0.52919787
 0.86034644 0.6294989  0.        ]
Epoch 18 / val/metric=0.5541


                                                                                  

Epoch 19 / trn/loss=0.8552
Best scales: [0.86034644 0.42970047 0.89073546 0.83099419 0.86034644 0.67475441
 0.6294989  0.4605922  0.        ]
Epoch 19 / train/metric=0.4674


                                                 

Best scales: [3.00183581 1.0969858  2.27396575 0.52919787 1.26038293 5.80522552
 0.28330961 0.38720388 0.        ]
Epoch 19 / val/metric=0.5811


                                                                                  

Epoch 20 / trn/loss=0.8268
Best scales: [0.86034644 0.42970047 0.92219788 0.69858797 0.77525975 0.67475441
 0.6294989  0.47686117 0.        ]
Epoch 20 / train/metric=0.4464


                                                 

Best scales: [0.33700643 1.0969858  1.3987131  1.30490198 0.72326339 1.60705282
 0.32550886 1.21738273 0.        ]
Epoch 20 / val/metric=0.6040


                                                                                  

Epoch 21 / trn/loss=0.7999
Best scales: [0.67475441 0.42970047 0.92219788 0.83099419 0.72326339 0.65173396
 0.67475441 0.52919787 0.        ]
Epoch 21 / train/metric=0.4280


                                                 

Best scales: [1.49926843 0.32550886 1.3987131  0.273644   0.4605922  0.69858797
 0.72326339 2.19638537 0.        ]
Epoch 21 / val/metric=0.6244


                                                                                  

Epoch 22 / trn/loss=0.7577
Best scales: [0.80264335 0.41504048 0.92219788 0.77525975 0.69858797 0.65173396
 0.67475441 0.56724261 0.        ]
Epoch 22 / train/metric=0.3983


                                                 

Best scales: [1.60705282 0.52919787 0.40088063 1.0234114  0.67475441 1.05956018
 0.6294989  1.0969858  0.        ]
Epoch 22 / val/metric=0.6491


                                                                                  

Epoch 23 / trn/loss=0.7086
Best scales: [0.77525975 0.40088063 0.92219788 0.72326339 0.69858797 0.6294989
 0.67475441 0.52919787 0.        ]
Epoch 23 / train/metric=0.3622


                                                 

Best scales: [2.70495973 0.9884959  0.47686117 1.21738273 2.43744415 2.43744415
 0.40088063 1.0969858  0.        ]
Epoch 23 / val/metric=0.6947


                                                                                  

Epoch 24 / trn/loss=0.6589
Best scales: [0.89073546 0.40088063 0.92219788 0.80264335 0.65173396 0.6294989
 0.74881039 0.60802243 0.        ]
Epoch 24 / train/metric=0.3275


                                                 

Best scales: [1.72258597 1.72258597 0.42970047 0.83099419 0.6294989  0.92219788
 0.42970047 1.0234114  0.        ]
Epoch 24 / val/metric=0.7223


                                                                                  

Epoch 25 / trn/loss=0.6208
Best scales: [0.77525975 0.41504048 0.83099419 0.77525975 0.65173396 0.6294989
 0.6294989  0.60802243 0.        ]
Epoch 25 / train/metric=0.2975


                                                 

Best scales: [2.70495973 0.92219788 0.28330961 0.44487828 0.69858797 1.35099352
 0.51114335 0.56724261 0.        ]
Epoch 25 / val/metric=0.7190


                                                                                  

Epoch 26 / trn/loss=0.5836
Best scales: [0.69858797 0.38720388 0.86034644 0.77525975 0.72326339 0.56724261
 0.69858797 0.58727866 0.        ]
Epoch 26 / train/metric=0.2715


                                                 

Best scales: [1.55222536 0.83099419 0.80264335 0.86034644 0.67475441 1.84642494
 0.36123427 0.56724261 0.        ]
Epoch 26 / val/metric=0.7722


                                                                                  

Epoch 27 / trn/loss=0.5538
Best scales: [0.72326339 0.38720388 0.89073546 0.77525975 0.69858797 0.72326339
 0.69858797 0.58727866 0.        ]
Epoch 27 / train/metric=0.2513


                                                 

Best scales: [1.84642494 0.9884959  0.69858797 0.9884959  0.60802243 1.13573336
 0.74881039 0.89073546 0.        ]
Epoch 27 / val/metric=0.8353


                                                                                  

Epoch 28 / trn/loss=0.5278
Best scales: [0.69858797 0.37399373 0.89073546 0.72326339 0.6294989  0.60802243
 0.74881039 0.58727866 0.        ]
Epoch 28 / train/metric=0.2317


                                                 

Best scales: [2.61267523 1.30490198 0.77525975 1.30490198 1.05956018 1.44811823
 0.69858797 1.17584955 0.        ]
Epoch 28 / val/metric=0.8153


                                                                                  

Epoch 29 / trn/loss=0.5196
Best scales: [0.65173396 0.41504048 0.80264335 0.69858797 0.67475441 0.67475441
 0.74881039 0.56724261 0.        ]
Epoch 29 / train/metric=0.2252


                                                

KeyboardInterrupt: 

: 