In [1]:
from tqdm import tqdm
import sys
import glob
import gc
import os
sys.path.append('./lib_models')

import pandas as pd
import numpy as np
import scipy as sp
import cv2
from matplotlib import pyplot as plt
import sklearn.metrics
import warnings
import pydicom
import dicomsdl
from joblib import Parallel, delayed
import h5py
import bz2
import pickle
import gzip
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from multiprocessing import Pool

import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch import nn

import segmentation_models_pytorch as smp
from resnet3d import generate_model
import timm
from timm.utils import AverageMeter
import wandb

sys.path.append('./lib_models')

wandb.login(key = '585f58f321685308f7933861d9dde7488de0970b')
warnings.filterwarnings('ignore', category=UserWarning)
os.environ['CUDA_LAUNCH_BLOCKING']='1'

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjunseonglee[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/junseonglee/.netrc


# Parameters

In [2]:
BASE_PATH  = '/home/junseonglee/01_codes/input/rsna-2023-abdominal-trauma-detection'
TRAIN_PATH = f'{BASE_PATH}/train_images'
DATA_PATH = f'{BASE_PATH}/3d_preprocessed'

if not os.path.isdir(DATA_PATH):
    os.mkdir(DATA_PATH)

RESOL = 128
BATCH_SIZE = 8
LR = 0.001
N_EPOCHS = 30
N_FOLDS  = 5
N_PREPROCESS_CHUNKS = 12
train_df = pd.read_csv(f'{BASE_PATH}/train.csv')
train_df = train_df.sort_values(by=['patient_id'])
n_blocks = 4
drop_rate = 0.2
drop_path_rate = 0.2

backbone = 'resnet18d'


wandb_config = {
    'RESOL': RESOL,
    'BATCH_SIZE': BATCH_SIZE,
    'LR': LR,
    'N_EPOCHS': N_EPOCHS,
    

}

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#DEVICE = 'cpu'

# Data split

In [3]:
train_df = pd.read_csv(f'{BASE_PATH}/train.csv')
train_meta = pd.read_csv(f'{BASE_PATH}/train_series_meta.csv')
train_df = train_df.sort_values(by=['patient_id'])
train_df

TRAIN_PATH = BASE_PATH + "/train_images/"
n_chunk = 8
patients = os.listdir(TRAIN_PATH)
n_patients = len(patients)
rng_patients = np.linspace(0, n_patients+1, n_chunk+1, dtype = int)
patients_cts = glob.glob(f'{TRAIN_PATH}/*/*')
n_cts = len(patients_cts)
patients_cts_arr = np.zeros((n_cts, 2), int)
data_paths=[]
for i in range(0, n_cts):
    patient, ct = patients_cts[i].split('/')[-2:]
    patients_cts_arr[i] = patient, ct
    data_paths.append(f'{BASE_PATH}/3d_preprocessed/{patients_cts_arr[i,0]}_{patients_cts_arr[i,1]}.pkl')
TRAIN_IMG_PATH = BASE_PATH + '/processed' 

#Generate tables for training
train_meta_df = pd.DataFrame(patients_cts_arr, columns = ['patient_id', 'series'])

#5-fold splitting
train_df['fold'] = 0
labels = train_df[['bowel_healthy','bowel_injury',
                    'extravasation_healthy','extravasation_injury',
                    'kidney_healthy','kidney_low','kidney_high',
                    'liver_healthy','liver_low','liver_high',
                    'spleen_healthy','spleen_low','spleen_high',
                    'any_injury']].to_numpy()

mskf = MultilabelStratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=0)
counter = 0
for train_index, test_index in mskf.split(np.ones(len(train_df)), labels):
    for i in range(0, len(test_index)):
        train_df['fold'][test_index[i]] = counter
    counter+=1

train_meta_df = train_meta_df.join(train_df.set_index('patient_id'), on='patient_id')
train_meta_df['path']=data_paths
train_meta_df.to_csv(f'{BASE_PATH}/train_meta.csv', index = False)
np.unique(train_df['fold'].to_numpy(), return_counts = True)


(array([0, 1, 2, 3, 4]), array([630, 629, 630, 629, 629]))

# Dataset

In [4]:
def compress(name, data):
    with gzip.open(name, 'wb') as f:
        pickle.dump(data, f)

def decompress(name):
    with gzip.open(name, 'rb') as f:
        data = pickle.load(f)
    return data

In [5]:
'''
def standardize_pixel_array(dcm: pydicom.dataset.FileDataset) -> np.ndarray:
    """
    Source : https://www.kaggle.com/competitions/rsna-2023-abdominal-trauma-detection/discussion/427217
    """
    # Correct DICOM pixel_array if PixelRepresentation == 1.
    pixel_array = dcm.pixel_array
    if dcm.PixelRepresentation == 1:
        bit_shift = dcm.BitsAllocated - dcm.BitsStored
        dtype = pixel_array.dtype 
        new_array = (pixel_array << bit_shift).astype(dtype) >>  bit_shift
        pixel_array = pydicom.pixel_data_handlers.util.apply_modality_lut(new_array, dcm)
    return pixel_array
'''
def standardize_pixel_array(dcm: pydicom.dataset.FileDataset) -> np.ndarray:
    # Correct DICOM pixel_array if PixelRepresentation == 1.
    pixel_array = dcm.pixel_array
    if dcm.PixelRepresentation == 1:
        bit_shift = dcm.BitsAllocated - dcm.BitsStored
        dtype = pixel_array.dtype 
        pixel_array = (pixel_array << bit_shift).astype(dtype) >>  bit_shift
    return pixel_array


In [6]:
# Read each slice and stack them to make 3d data
def process_3d(save_path, data_path = TRAIN_PATH):
    tmp = save_path.split('/')[-1][:-4]
    tmp = tmp.split('_')
    patient, study = int(tmp[0]), int(tmp[1])
    imgs = {}    
    
    for f in sorted(glob.glob(data_path + f'/{patient}/{study}/*.dcm')):      
        pixel_rep = 0
        bit_shift = 0
        dtype = 0
        try:
            dicom = pydicom.dcmread(f)        
            img, pixel_rep, bit_shift, dtype = standardize_pixel_array(dicom)
            img = img.astype(float)
            break
        except:
            continue
            
    for f in sorted(glob.glob(data_path + f'/{patient}/{study}/*.dcm')):
        #For the case that some of the image can't be read -> error without this though don't know why  
        img = dicomsdl.open(f).pixelData(storedvalue=True).astype(float)
        #dicom = pydicom.dcmread(f)
        #img = standardize_pixel_array(dicom).astype(float)
        #ind = int((f.split('/')[-1])[:-4])
        pos_z = -int((f.split('/')[-1])[:-4])
        imgs[pos_z] = img


    sample_z = np.linspace(0, len(imgs)-1, RESOL, dtype=int)

    imgs_3d = []
    for i, k in enumerate(sorted(imgs.keys())):
        if i in sample_z:
            img = imgs[k]
            imgs_3d.append(cv2.resize(img, (RESOL, RESOL))[None])
    
    imgs_3d = np.vstack(imgs_3d)
    
    
    nu = np.zeros((RESOL, RESOL, RESOL))

    for i in range(0, len(imgs_3d[0,0])):
        nu[:,:,i] = cv2.resize(imgs_3d[:,:,i], (RESOL, RESOL))
    imgs_3d  = nu            

    # To deal with random image edge    

    imgs_3d = ((imgs_3d - imgs_3d.min()) / (imgs_3d.max() - imgs_3d.min()))

    if dicom.PhotometricInterpretation == "MONOCHROME1":
        imgs_3d = 1.0 - imgs_3d
    
    #Samplewise standardization to deal with the variety of the test datset.
    std = np.std(imgs_3d)
    avg = np.average(imgs_3d)
    imgs_3d = (imgs_3d-avg)/std
    imgs_3d = imgs_3d.astype(np.float32)

    #here to
    compress(save_path, imgs_3d)                      

    del imgs, img, nu
    gc.collect()

    return imgs_3d

In [7]:
# Preprocess dataset
rng_samples = np.linspace(0, len(train_meta_df), N_PREPROCESS_CHUNKS+1, dtype = int)
def process_3d_wrapper(process_ind, rng_samples = rng_samples, train_meta_df = train_meta_df):
    for i in tqdm(range(rng_samples[process_ind], rng_samples[process_ind+1])):
        if not os.path.isfile(train_meta_df.iloc[i]['path']):
            process_3d(train_meta_df.iloc[i]['path'])

In [8]:
%%time
#if __name__ == '__main__':
Parallel(n_jobs = N_PREPROCESS_CHUNKS)(delayed(process_3d_wrapper)(i) for i in range(N_PREPROCESS_CHUNKS))
    #with Pool(N_PREPROCESS_CHUNKS) as p:
    #    p.map(process_3d_wrapper, range(0, N_PREPROCESS_CHUNKS))

CPU times: user 35.3 ms, sys: 145 ms, total: 181 ms
Wall time: 539 ms


100%|██████████| 392/392 [00:00<00:00, 29438.46it/s]
100%|██████████| 393/393 [00:00<00:00, 29438.71it/s]
100%|██████████| 393/393 [00:00<00:00, 29790.38it/s]
100%|██████████| 393/393 [00:00<00:00, 29786.08it/s]
100%|██████████| 392/392 [00:00<00:00, 26415.29it/s]
100%|██████████| 393/393 [00:00<00:00, 16791.57it/s]
100%|██████████| 393/393 [00:00<00:00, 29532.59it/s]
100%|██████████| 393/393 [00:00<00:00, 30154.42it/s]
100%|██████████| 392/392 [00:00<00:00, 29592.64it/s]
100%|██████████| 392/392 [00:00<00:00, 29634.24it/s]
100%|██████████| 393/393 [00:00<00:00, 29382.03it/s]
100%|██████████| 392/392 [00:00<00:00, 29533.11it/s]


[None, None, None, None, None, None, None, None, None, None, None, None]

In [9]:
class AbdominalCTDataset(Dataset):
    def __init__(self, meta_df, is_train = True):
        self.meta_df = meta_df
        self.is_train = is_train
        
    def __len__(self):
        return len(self.meta_df)
    
    def __getitem__(self, idx):
        row = self.meta_df.iloc[idx]
        label = row[['bowel_healthy','bowel_injury',
                    'extravasation_healthy','extravasation_injury',
                    'kidney_healthy','kidney_low','kidney_high',
                    'liver_healthy','liver_low','liver_high',
                    'spleen_healthy','spleen_low','spleen_high', 'any_injury']]

        #To avoid loading issue when applying multiprocessing to the unzip module
        try:
            data_3d = decompress(row['path'])
        except:
            data_3d = process_3d(row['path'])           

        data_3d = data_3d.reshape(1, RESOL, RESOL, RESOL).astype(np.float32)  # channel, 3D 
        data_3d = torch.from_numpy(data_3d)
        
        #augmentation  
        #if self.is_train:            
        #    random_angle = np.random.rand(1)[0]*360.0-180.0
        #    data_3d = transforms.functional.rotate(data_3d, random_angle, transforms.InterpolationMode.BILINEAR)
            

        label = label.to_numpy().astype(np.float32)
                
        label = torch.from_numpy(label)
        return data_3d, label        

train_dataset = AbdominalCTDataset(train_meta_df)
data_3d, label = train_dataset[0]
print(label)

del train_dataset, data_3d, label
gc.collect()

tensor([1., 0., 1., 0., 1., 0., 0., 0., 1., 0., 1., 0., 0., 1.])


71

In [10]:
'''
#normalization parameter
train_dataset = AbdominalCTDataset(train_meta_df)
data_3d, label = train_dataset[0]

avgs = np.zeros(len(train_dataset))
stds = np.zeros(len(train_dataset))
for i in tqdm(range(0, len(train_dataset))):
    data_3d, label = train_dataset[i]
    data_3d = data_3d.numpy()
    avgs[i] = np.average(data_3d)
    stds[i] = np.std(data_3d)
print(np.average(avgs))
print(np.average(stds))    

del train_dataset, data_3d, label, avgs, stds
gc.collect()
'''

'\n#normalization parameter\ntrain_dataset = AbdominalCTDataset(train_meta_df)\ndata_3d, label = train_dataset[0]\n\navgs = np.zeros(len(train_dataset))\nstds = np.zeros(len(train_dataset))\nfor i in tqdm(range(0, len(train_dataset))):\n    data_3d, label = train_dataset[i]\n    data_3d = data_3d.numpy()\n    avgs[i] = np.average(data_3d)\n    stds[i] = np.std(data_3d)\nprint(np.average(avgs))\nprint(np.average(stds))    \n\ndel train_dataset, data_3d, label, avgs, stds\ngc.collect()\n'

# Model

In [11]:
from timm.models.layers.conv2d_same import Conv2dSame
from conv3d_same import Conv3dSame


def convert_3d(module):

    module_output = module
    if isinstance(module, torch.nn.BatchNorm2d):
        module_output = torch.nn.BatchNorm3d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig
            
    elif isinstance(module, Conv2dSame):
        module_output = Conv3dSame(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))

    elif isinstance(module, torch.nn.Conv2d):
        module_output = torch.nn.Conv3d(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
            padding_mode=module.padding_mode
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))

    elif isinstance(module, torch.nn.MaxPool2d):
        module_output = torch.nn.MaxPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            ceil_mode=module.ceil_mode,
        )
    elif isinstance(module, torch.nn.AvgPool2d):
        module_output = torch.nn.AvgPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            ceil_mode=module.ceil_mode,
        )

    for name, child in module.named_children():
        module_output.add_module(
            name, convert_3d(child)
        )
    del module

    return module_output


#m = TimmSegModel(backbone)
#m = convert_3d(m)
#out = m(torch.rand(1, 1, 128,128,128))
#for i in range(0, len(out)):
#    print(out[i].shape)

In [12]:
class TimmSegModel(nn.Module):
    def __init__(self, backbone, segtype='unet', pretrained=False):
        super(TimmSegModel, self).__init__()

        self.encoder = timm.create_model(
            backbone,
            in_chans=1,
            features_only=True,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            pretrained=pretrained
        )
        g = self.encoder(torch.rand(1, 1, 64, 64))
        encoder_channels = [1] + [_.shape[1] for _ in g]
        decoder_channels = [256, 128, 64, 32, 16]
        if segtype == 'unet':
            self.decoder = smp.unet.decoder.UnetDecoder(
                encoder_channels=encoder_channels[:n_blocks+1],
                decoder_channels=decoder_channels[:n_blocks],
                n_blocks=n_blocks,
            )
        self.avgpool = nn.AvgPool2d(3, 2, 1)
    
    def forward(self,x):
        global_features = self.encoder(x)[:n_blocks]        
        for i in range(0, len(global_features)):
            global_features[i] = self.avgpool(global_features[i])
        return global_features
        #seg_features = self.decoder(*global_features)
        #seg_features = self.segmentation_head(seg_features)


In [13]:
class AbdominalClassifier(nn.Module):
    def __init__(self, model_depth, device = DEVICE):
        super().__init__()
        self.device = device
        #self.resnet3d = generate_model(model_depth = model_depth, n_input_channels = 1)
        self.resnet3d = TimmSegModel(backbone)
        self.resnet3d = convert_3d(self.resnet3d)
        self.flatten  = nn.Flatten()
        self.dropout  = nn.Dropout(p=0.5)
        self.softmax  = nn.Softmax(dim=1)
        size_res_out  = 37440
        self.fc_bowel = nn.Linear(size_res_out, 2)
        self.fc_extrav= nn.Linear(size_res_out, 2)
        self.fc_kidney= nn.Linear(size_res_out, 3)
        self.fc_liver = nn.Linear(size_res_out, 3)
        self.fc_spleen= nn.Linear(size_res_out, 3)
        
        self.maxpool  = nn.MaxPool1d(5, 1)

    def forward(self, x):
        x = self.resnet3d(x)
        pooled_features = []
        for i in range(0, len(x)):        
            pooled_features.append(self.flatten(torch.sum(x[i], dim = 1)))

        for i in range(0, 4):
            x[i] = self.flatten(pooled_features[i])
        x = torch.cat(x, axis = 1)
        x     = self.dropout(x)
        bowel = self.fc_bowel(x)
        extrav= self.fc_extrav(x)
        kidney= self.fc_kidney(x)
        liver = self.fc_liver(x)
        spleen= self.fc_spleen(x)

        labels = torch.cat([bowel, extrav, kidney, liver, spleen], dim = 1)

        bowel_soft = self.softmax(bowel)
        extrav_soft = self.softmax(extrav)
        kidney_soft = self.softmax(kidney)
        liver_soft = self.softmax(liver)
        spleen_soft = self.softmax(spleen)

        any_in = torch.cat([1-bowel_soft[:,0:1], 1-extrav_soft[:,0:1], 
                            1-kidney_soft[:,0:1], 1-liver_soft[:,0:1], 1-spleen_soft[:,0:1]], dim = 1) 
        any_in = self.maxpool(any_in)
        any_not_in = 1-any_in
        any_in = torch.cat([any_not_in, any_in], dim = 1)

        return labels, any_in

In [14]:
model = AbdominalClassifier(10)

def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp

print(get_n_params(model))
del model
gc.collect()

43614957


0

# Train

In [15]:
model = AbdominalClassifier(18)
model.to(DEVICE)


#scheduler = CosineAnnealingLR(optimizer, T_max=ttl_iters, eta_min=1e-6)


weights = np.ones(2)
weights[1] = 2
crit_bowel  = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
weights[1] = 6
crit_extrav = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
crit_any = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))

weights = np.ones((3))
weights[1] = 2
weights[2] = 4
crit_kidney = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
crit_liver  = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
crit_spleen = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))


In [16]:
def normalize_to_one(tensor):
    norm = torch.sum(tensor, 1)
    for i in range(0, tensor.shape[1]):
        tensor[:,i]/=norm
    return tensor

def apply_softmax_to_labels(X_out):
    softmax = nn.Softmax(dim=1)

    X_out[:,:2]    = normalize_to_one(softmax(X_out[:,:2]))
    X_out[:,2:4]   = normalize_to_one(softmax(X_out[:,2:4]))
    X_out[:,4:7]   = normalize_to_one(softmax(X_out[:,4:7]))
    X_out[:,7:10]  = normalize_to_one(softmax(X_out[:,7:10]))
    X_out[:,10:13] = normalize_to_one(softmax(X_out[:,10:13]))

    return X_out

def calculate_score(X_outs, ys):
    bowel_weights =  ys[:,0] + 2*ys[:,1]
    extrav_weights = ys[:,2] + 6*ys[:,3]
    kidney_weights = ys[:,4] + 2*ys[:,5] + 4*ys[:,6]
    liver_weights  = ys[:,7] + 2*ys[:,8] + 4*ys[:,9]
    spleen_weights = ys[:,10] + 2*ys[:,11] + 4*ys[:,12]
    any_in_weights = ys[:,13] + 6*ys[:,14]

    loss = (
             sklearn.metrics.log_loss(ys[:,:2], X_outs[:,:2], sample_weight = bowel_weights)
           + sklearn.metrics.log_loss(ys[:,2:4], X_outs[:,2:4], sample_weight = extrav_weights)
           + sklearn.metrics.log_loss(ys[:,4:7], X_outs[:,4:7], sample_weight = kidney_weights)
           + sklearn.metrics.log_loss(ys[:,7:10], X_outs[:,7:10], sample_weight = liver_weights)
           + sklearn.metrics.log_loss(ys[:,10:13], X_outs[:,10:13], sample_weight = spleen_weights)
           + sklearn.metrics.log_loss(ys[:,13:15], X_outs[:,13:15], sample_weight =  any_in_weights)
           ) / 6
    return loss

In [17]:
if __name__ == '__main__':
    train_dataset = AbdominalCTDataset(train_meta_df[train_meta_df['fold']!=0], is_train = True)
    valid_dataset = AbdominalCTDataset(train_meta_df[train_meta_df['fold']==0], is_train = False)

    train_loader = DataLoader(dataset = train_dataset, shuffle = True, batch_size = BATCH_SIZE, pin_memory = True, 
                            num_workers = 8, drop_last = False)

    valid_loader = DataLoader(dataset = valid_dataset, shuffle = False, batch_size = BATCH_SIZE, pin_memory = True, 
                            num_workers = 8, drop_last = False)          
    

    optimizer = torch.optim.AdamW(model.parameters(), lr = LR)
    ttl_iters = N_EPOCHS * len(train_loader)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LR, 
                                                    steps_per_epoch=len(train_loader), epochs = N_EPOCHS)

    scaler = torch.cuda.amp.GradScaler(enabled=True)
    val_metrics = np.ones(N_EPOCHS)*100

    for epoch in range(0, N_EPOCHS):
        train_meters = {'loss': AverageMeter()}
        val_meters   = {'loss': AverageMeter()}
        
        model.train()
        pbar = tqdm(train_loader, leave=False)  

        X_outs=[]
        ys=[]

        for X, y in pbar:
            batch_size = X.shape[0]
            X, y = X.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=True):  
                X_out, X_any  = model(X)
                loss  = crit_bowel(X_out[:,:2], y[:,:2])
                loss += crit_extrav(X_out[:,2:4], y[:,2:4])
                loss += crit_kidney(X_out[:,4:7], y[:,4:7])
                loss += crit_liver(X_out[:,7:10], y[:,7:10])
                loss += crit_spleen(X_out[:,10:13], y[:,10:13])

                loss += crit_any(X_any,  torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1))  
                loss /= 6
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scheduler.step()
                scaler.update()          

            #Metric calculation
            y_any = torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1)    
            X_out = apply_softmax_to_labels(X_out).detach().to('cpu').numpy()
            X_any = X_any.detach().to('cpu').numpy()
            X_out = np.hstack([X_out, X_any])
            X_outs.append(X_out)

            y     = y.to('cpu').numpy()[:,:-1]
            y_any = y_any.to('cpu').numpy()
            y     = np.hstack([y, y_any])
            ys.append(y)

            trn_loss = loss.item()      
            train_meters['loss'].update(trn_loss, n=X.size(0))     
            pbar.set_description(f'Train loss: {trn_loss}')   
        print('Epoch {:d} / trn/loss={:.4f}'.format(epoch+1, train_meters['loss'].avg))    

        X_outs = np.vstack(X_outs) 
        ys     = np.vstack(ys)
        metric = calculate_score(X_outs, ys)                 
        print('Epoch {:d} / train/metric={:.4f}'.format(epoch+1, metric))   

        del X, X_outs, y, ys, X_any
        gc.collect()

        X_outs=[]
        ys=[]
        model.eval()
        for X, y in tqdm(valid_loader, leave=False):        
            batch_size = X.shape[0]        
            X, y = X.to(DEVICE), y.to(DEVICE)
                 
            with torch.cuda.amp.autocast(enabled=True):                
                with torch.no_grad():                 
                    X_out, X_any = model(X)                                           
                    y_any = torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1)              
                              
                    X_out = apply_softmax_to_labels(X_out).to('cpu').numpy()

                    X_any = X_any.to('cpu').numpy()
                    X_out = np.hstack([X_out, X_any])
                    X_outs.append(X_out)

                    y     = y.to('cpu').numpy()[:,:-1]
                    y_any = y_any.to('cpu').numpy()
                    y     = np.hstack([y, y_any])
                    ys.append(y)

        X_outs = np.vstack(X_outs) 
        ys     = np.vstack(ys)
        metric = calculate_score(X_outs, ys)                
        print('Epoch {:d} / val/metric={:.4f}'.format(epoch+1, metric))   
        
        #Save the best model    
        if(metric < np.min(val_metrics)):
            try:
                os.makedirs(f'{BASE_PATH}/weights')
            except:
                a = 1
            best_metric = metric
            print(f'Best val_metric {best_metric} at epoch {epoch+1}!')
            torch.save(model, f'{BASE_PATH}/weights/best.pt')    
        val_metrics[epoch] = metric
        
        del X, X_outs, y, ys, X_any
        gc.collect()        

                                       

RuntimeError: mat1 and mat2 shapes cannot be multiplied (8x37440 and 299520x2)