In [1]:
pip install '../input/rsna-monai-packages/monai-0.6.0-202107081903-py3-none-any.whl'

Processing /kaggle/input/rsna-monai-packages/monai-0.6.0-202107081903-py3-none-any.whl
Installing collected packages: monai
Successfully installed monai-0.6.0
Note: you may need to restart the kernel to use updated packages.


In [2]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import os

import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, roc_auc_score
import glob

In [3]:
import albumentations as A
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from albumentations.pytorch import ToTensorV2
from sklearn.metrics import roc_auc_score, accuracy_score
from torch.optim import lr_scheduler
from tqdm import tqdm
import re

from tensorflow import keras
import tensorflow as tf
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import nilearn as nl
import nibabel as nib
import cv2
from IPython.display import clear_output
from PIL import Image
from pydicom import read_file, dcmread
from scipy.ndimage import zoom

In [4]:
NUM_IMAGES_3D = 128
TRAINING_BATCH_SIZE = 1
TEST_BATCH_SIZE = 2
IMAGE_SIZE = 256
N_EPOCHS = 40
do_valid = True
n_workers = 4
type_ = "FLAIR"
MODEL_NAME = 'version1.enhanced'

In [5]:
def load_dicom_image(path, img_size=IMAGE_SIZE, voi_lut=True, rotate=0):
    dicom = pydicom.read_file(path)
    data = dicom.pixel_array
    if voi_lut:
        data = apply_voi_lut(dicom.pixel_array, dicom)
    else:
        data = dicom.pixel_array

    if rotate > 0:
        rot_choices = [
            0,
            cv2.ROTATE_90_CLOCKWISE,
            cv2.ROTATE_90_COUNTERCLOCKWISE,
            cv2.ROTATE_180,
        ]
        data = cv2.rotate(data, rot_choices[rotate])

    data = cv2.resize(data, (img_size, img_size))
    return data

In [6]:
def dice_coef(y_true, y_pred, smooth=1.0):
    class_num = 4
    for i in range(class_num):
        y_true_f = K.flatten(y_true[:,:,:,i])
        y_pred_f = K.flatten(y_pred[:,:,:,i])
        intersection = K.sum(y_true_f * y_pred_f)
        loss = ((2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth))
   #     K.print_tensor(loss, message='loss value for class {} : '.format(SEGMENT_CLASSES[i]))
        if i == 0:
            total_loss = loss
        else:
            total_loss = total_loss + loss
    total_loss = total_loss / class_num
#    K.print_tensor(total_loss, message=' total dice coef: ')
    return total_loss


 
# define per class evaluation of dice coef
# inspired by https://github.com/keras-team/keras/issues/9395
def dice_coef_necrotic(y_true, y_pred, epsilon=1e-6):
    intersection = K.sum(K.abs(y_true[:,:,:,1] * y_pred[:,:,:,1]))
    return (2. * intersection) / (K.sum(K.square(y_true[:,:,:,1])) + K.sum(K.square(y_pred[:,:,:,1])) + epsilon)

def dice_coef_edema(y_true, y_pred, epsilon=1e-6):
    intersection = K.sum(K.abs(y_true[:,:,:,2] * y_pred[:,:,:,2]))
    return (2. * intersection) / (K.sum(K.square(y_true[:,:,:,2])) + K.sum(K.square(y_pred[:,:,:,2])) + epsilon)

def dice_coef_enhancing(y_true, y_pred, epsilon=1e-6):
    intersection = K.sum(K.abs(y_true[:,:,:,3] * y_pred[:,:,:,3]))
    return (2. * intersection) / (K.sum(K.square(y_true[:,:,:,3])) + K.sum(K.square(y_pred[:,:,:,3])) + epsilon)



# Computing Precision 
def precision(y_true, y_pred):
        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
        precision = true_positives / (predicted_positives + K.epsilon())
        return precision

    
# Computing Sensitivity      
def sensitivity(y_true, y_pred):
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    return true_positives / (possible_positives + K.epsilon())


# Computing Specificity
def specificity(y_true, y_pred):
    true_negatives = K.sum(K.round(K.clip((1-y_true) * (1-y_pred), 0, 1)))
    possible_negatives = K.sum(K.round(K.clip(1-y_true, 0, 1)))
    return true_negatives / (possible_negatives + K.epsilon())
############ load trained model ################
seg_model = keras.models.load_model('../input/modelperclasseval/model_per_class.h5', 
                                   custom_objects={ 'accuracy' : tf.keras.metrics.MeanIoU(num_classes=4),
                                                   "dice_coef": dice_coef,
                                                   "precision": precision,
                                                   "sensitivity":sensitivity,
                                                   "specificity":specificity,
                                                   "dice_coef_necrotic": dice_coef_necrotic,
                                                   "dice_coef_edema": dice_coef_edema,
                                                   "dice_coef_enhancing": dice_coef_enhancing
                                                  }, compile=False)

In [7]:
DIM = 128
SLICES = 64
MR_DIM = 256
TRAIN_DIR = r'../input/rsna-miccai-brain-tumor-radiogenomic-classification/train'
TEST_DIR = r'../input/rsna-miccai-brain-tumor-radiogenomic-classification/test'
flair = tf.keras.models.load_model('../input/mrequalizer-weights/flair_part1.h5')
t1ce = tf.keras.models.load_model('../input/mrequalizer-weights/t1ce_part1.h5')
t2 = tf.keras.models.load_model('../input/mrequalizer-weights/t2_part1.h5')

enhancer_models = {'FLAIR': flair, 'T1wCE': t1ce, 'T2w': t2}

planes_df = pd.read_csv('../input/labels-with-planes/train_labels_with_planes.csv',
                        header=0,
                        names=['ID', 'MGMT_Value', 'FLAIR', 'T1w', 'T1wCE', 'T2w'])

planes_df['ID'] = [(5-len(str(n)))*'0'+str(n) for n in planes_df['ID']]

# create key-value format
AXIS = {}
col_names = ['MGMT_Value', 'FLAIR', 'T1w', 'T1wCE', 'T2w']
for index, row in planes_df.iterrows():
#     for name in col_names:
    AXIS[row['ID']] = {name: row[name] for name in col_names}

    
def axial_2_coronal(mri, show=False):
#     print('init shape:', mri.shape)
    cor = np.moveaxis(mri, 0, 2)
    cor = get_cropped_region(cor, reshape=(DIM, DIM))
    cor = np.moveaxis(cor, 0, 1)
    cor = np.rot90(cor, 2)
    if show:
        show_mri(cor)
    return cor

def sagittal_2_coronal(mri, show=False):
#     print('init shape:', mri.shape)
    cor = np.moveaxis(mri, 1, 2)
    cor = get_cropped_region(cor, reshape=(DIM, DIM))
    cor = np.moveaxis(cor, 0, 1)
    cor = np.rot90(cor, 3)
    if show:
        show_mri(cor)
    return cor
    
def get_num(name):
#     print(f'filename: {name}')
    return int(name[name.index('-')+1:name.index('.')])

def load_mri_arr(folder):
    for file in os.listdir(folder):
        slice_path = os.path.join(folder, file)
        dicom = read_file(slice_path)
        img = apply_voi_lut(dicom.pixel_array, dicom)
        shape = (img.shape[0], img.shape[1], len(os.listdir(folder)))
    out = np.empty(shape)
#     files = [os.path.join(folder, file) for file in os.listdir(folder)]
    data = sorted(os.listdir(folder), key=get_num)
    for i, file in enumerate(data):
        slice_path = os.path.join(folder, file)
        img = read_file(slice_path).pixel_array 
        out[:, :, i] = img
    out = out/np.max(out)
    return out

def show_mri(mri):
    for i in range(mri.shape[2]):
        clear_output(wait=True)
        plt.axis(False)
        plt.imshow(mri[:, :, i], cmap='gray')
        plt.show()


def lowleft_upright(img_arr):
    ret,thresh = cv2.threshold(img_arr,20,255,0)
    contours,hierarchy = cv2.findContours(thresh, 1, 2)
    cnt2cons = [cnt for cnt in contours if cv2.contourArea(cnt) >= 100]
    xmin, ymin = 500, 500
    xmax, ymax = -1, -1
    for cnt in cnt2cons:
        area = cv2.contourArea(cnt)
        x,y,w,h = cv2.boundingRect(cnt)
        img = np.zeros((240, 240, 3))
        cv2.rectangle(img,(x,y),(x+w,y+h),(255, 0, 0),2)
        xmin = min(xmin, x)
        xmax = max(xmax, x+w)
        ymin = min(ymin, y)
        ymax = max(ymax, y+h)
    return xmin, xmax, ymin, ymax


def get_cropped_region(img, reshape=None):
    img = img*255.
    img = img.astype(np.uint8)
    xmin_true, ymin_true = 500, 500
    xmax_true, ymax_true = -1, -1
#     print('marker 1')
    for i in range(img.shape[-1]):
#         print('marker 2')
        if np.sum(img[:, :, i]) == 0:
#             print('marker 3')
            continue
        xmin, xmax, ymin, ymax = lowleft_upright(img[:, :, i])
        xmin_true = min(xmin, xmin_true)
        ymin_true = min(ymin, ymin_true)
        xmax_true = max(xmax, xmax_true)
        ymax_true = max(ymax, ymax_true)
        
    count_black = 0
    for i in range(img.shape[-1]):
        count_black += int(np.sum(img[:, :, i]) == 0)
#     print(count_black)
    if ymax_true-ymin_true < 0 or xmax_true-xmin_true < 0 or img.shape[-1]-count_black < 0:
        print(ymax_true-ymin_true, xmax_true-xmin_true, img.shape[-1]-count_black)
    cropped = np.empty((ymax_true-ymin_true, xmax_true-xmin_true, img.shape[-1]-count_black))
    curr_slice = 0
    for i in range(img.shape[-1]):
        if np.sum(img[:, :, i]) == 0:
            continue
        
        crop = img[:, :, i][ymin_true:ymax_true, xmin_true:xmax_true]
        
        cropped[:, :, curr_slice] = crop
        curr_slice += 1
    cropped = cropped/np.max(cropped)
    if reshape is not None:
        reshaped = np.empty((reshape[0], reshape[1], cropped.shape[-1]))
        for i in range(cropped.shape[-1]):
            resized = cv2.resize(cropped[:, :, i], reshape, interpolation=cv2.INTER_AREA)
            reshaped[:, :, i] = resized
        return reshaped
    return cropped


def apply_equalizer(mri, img_type):
    original_dim = (mri.shape[1], mri.shape[0])
    new = np.empty((mri.shape[0], mri.shape[1], 2*mri.shape[2]-1))
    cur = 0
    for i in range(mri.shape[-1]-1):
#         print(mri[:, :, i].shape)
        in1 = cv2.resize(mri[:, :, i], (MR_DIM, MR_DIM))
        in2 = cv2.resize(mri[:, :, i+1], (MR_DIM, MR_DIM))
        output = get_middle_image(in1, in2, img_type)
        output = cv2.resize(output, original_dim)        
        new[:, :, cur] = mri[:, :, i]
        new[:, :, cur+1] = output
        cur += 2
    new[:, :, cur] = mri[:, :, -1]
    return new

def get_middle_image(in1, in2, img_type, show=False): # provide to 256x256 images
#     print(in1.shape)
    in1 = cv2.resize(in1, (MR_DIM, MR_DIM))
    in2 = cv2.resize(in2, (MR_DIM, MR_DIM))
    in1 = in1.reshape(1, MR_DIM, MR_DIM, 1)
    in2 = in2.reshape(1, MR_DIM, MR_DIM, 1)
    data = [in1, in2]
    output = np.array(enhancer_models[img_type](data))
    if show:
        f, axarr = plt.subplots(1, 3, squeeze=False)
        axarr[0][0].axis(False)
        axarr[0][1].axis(False)
        axarr[0][2].axis(False)
        axarr[0][0].imshow(in1[0, :, :, 0], cmap='gray', vmin=0., vmax=1.)
        axarr[0][1].imshow(output[0, :, :, 0], cmap='gray', vmin=0., vmax=1.)
        axarr[0][2].imshow(in2[0, :, :, 0], cmap='gray', vmin=0., vmax=1.)
        plt.show()
    return output[0, :, :, 0]

def get_image(folder, img_type, equalizer_iters=0):
    loaded = load_mri_arr(folder)
#     loaded = get_cropped_region(loaded)
#     print('Before:', loaded.shape[-1])
#     for i in range(equalizer_iters):
    while loaded.shape[-1] < DIM and equalizer_iters != 0:
        loaded = apply_equalizer(loaded, img_type)
#     print('After:', loaded.shape[-1])
    return loaded

def resize(mri, shape):
    init_shape = mri.shape
    resized = zoom(mri, (shape[0]/init_shape[0], shape[1]/init_shape[1], shape[2]/init_shape[2]))
    return resized

def load_input_for_seg(folder, return_label=True, equalizer_iters=0):
    idnum = folder[-5:]
    data = AXIS[idnum]
    label = data['MGMT_Value']
    inp = np.empty((DIM, DIM, DIM, 2))
    
    path = os.path.join(folder, 'FLAIR')
    loaded = get_cropped_region(get_image(path, 'flair', equalizer_iters=equalizer_iters), reshape=(DIM, DIM))
    loaded = resize(loaded, (DIM, DIM, DIM))
    if data['FLAIR'] == 'Sagittal':
        loaded = sagittal_2_coronal(loaded, False)
    if data['FLAIR'] == 'Axial':
        loaded = axial_2_coronal(loaded, False)
    inp[:, :, :, 0] = np.moveaxis(loaded, 2, 0)
    
    path = os.path.join(folder, 'T1wCE')
    loaded = get_cropped_region(get_image(path, 't1ce', equalizer_iters=equalizer_iters), reshape=(DIM, DIM))
    loaded = resize(loaded, (DIM, DIM, DIM))
    if data['T1wCE'] == 'Sagittal':
        loaded = sagittal_2_coronal(loaded, False)
    if data['T1wCE'] == 'Axial':
        loaded = axial_2_coronal(loaded, False)
    inp[:, :, :, 1] = np.moveaxis(loaded, 2, 0)
    
    if return_label:
        return inp, label
    return inp


def get_mask(inp):
    out = seg_model.predict(inp)
    out = np.argmax(out, axis=-1)
#     print(out.shape)
    out[out != 0] = 1
    out[:12] = np.zeros((12, DIM, DIM), dtype=np.float32)
    out[-12:] = np.zeros((12, DIM, DIM), dtype=np.float32)
    return out

def find_largest_tumor_slice(mask): # return ind of largest tumor slice
    ind = -1
    largest_size = -1
    for i in range(mask.shape[-1]):
        curr = mask[:, :, i]
        tum_size = cv2.countNonZero(curr)
        if tum_size >= largest_size:
            largest_size = tum_size
            ind = i
    return ind

def load_input(
    case_id,
    num_imgs=NUM_IMAGES_3D,
    img_size=IMAGE_SIZE,
    rotate=0,
    equalizer_iters=0
):
    files_dir = r'../input/flair-npy/npy_files'
    loaded = np.load(os.path.join(files_dir, case_id+'.npy'))
    return loaded
#     idnum = case_id[-5:]
# #     print(idnum)
#     data = AXIS[idnum]
#     label = data['MGMT_Value']
#     path2 = f"../input/rsna-miccai-brain-tumor-radiogenomic-classification/train/{case_id}"
#     path = f"../input/rsna-miccai-brain-tumor-radiogenomic-classification/train/{case_id}/{type_}"
#     flair_path = f"../input/rsna-miccai-brain-tumor-radiogenomic-classification/train/{case_id}/FLAIR"
#     t1wce_path = f"../input/rsna-miccai-brain-tumor-radiogenomic-classification/train/{case_id}/T1wCE"

#     try:
#         loaded = get_cropped_region(get_image(path, type_, equalizer_iters=equalizer_iters), reshape=(DIM, DIM))
#         loaded = resize(loaded, (DIM, DIM, DIM))
#         loaded_flair = get_cropped_region(get_image(flair_path, type_, equalizer_iters=equalizer_iters), reshape=(DIM, DIM))
#         loaded_flair = resize(loaded_flair, (DIM, DIM, DIM))
#         loaded_t1wce = get_cropped_region(get_image(t1wce_path, type_, equalizer_iters=equalizer_iters), reshape=(DIM, DIM))
#         loaded_t1wce = resize(loaded_t1wce, (DIM, DIM, DIM))
#         seg_inp = np.empty((DIM, DIM, DIM, 2))
#         seg_inp[:, :, :, 0] = np.moveaxis(loaded_flair, 2, 0)
#         seg_inp[:, :, :, 1] = np.moveaxis(loaded_t1wce, 2, 0)
#         if data['FLAIR'] == 'Sagittal':
#             loadexd = sagittal_2_coronal(loaded, False)
#         if data['FLAIR'] == 'Axial':
#             loaded = axial_2_coronal(loaded, False)

#         mask = get_mask(seg_inp)
#         middle = find_largest_tumor_slice(mask)
#         sl_each_side = SLICES//2
#         e1 = max(0, middle-sl_each_side)
#         e2 = min(loaded.shape[-1]-1, middle+sl_each_side)

#         if middle-e1 < e2-middle:
#             e2 += 32-(middle-e1)
#         elif e2-middle < middle-e1:
#             e1 -= 32-(e2-middle)
#         to_return = np.empty((DIM, DIM, e2-e1))
#         count = 0
# #         print(e1, e2, middle)
#         for i in range(e1, e2):
#             to_return[:, :, count] = loaded[i]
#             count += 1
#         to_return = np.expand_dims(to_return, axis=0)
#         to_return = (to_return-np.min(to_return))/(np.max(to_return)-np.min(to_return))
#     #         print('marker 1')
#         return to_return
#     except Exception as e:
#         print(e)
#         print(f'Failed on case {idnum}')
#         return np.zeros((1, DIM, DIM, SLICES))    

In [8]:
import random

import cv2
from torch.utils.data import Dataset


class BrainRSNADataset(Dataset):
    def __init__(
        self, data, transform=None, target="MGMT_value", mri_type="FLAIR", is_train=True,
        enhanced=False
    ):
        self.target = target
        self.data = data
        self.type = mri_type

        self.transform = transform
        self.is_train = is_train
        self.folder = "train" #if self.is_train else "test"
        self.enhanced = enhanced

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
#         print('INDEX:', index)
        row = self.data.loc[index]
        case_id = int(row.BraTS21ID)
        target = int(row[self.target])
        _3d_images = self.load_dicom_images_3d(case_id)
        _3d_images = torch.tensor(_3d_images).float()
        if self.is_train:
            return {"image": _3d_images, "target": target, "case_id": case_id}
        else:
            return {"image": _3d_images, "case_id": case_id}

    def load_dicom_images_3d(
        self,
        case_id,
        num_imgs=NUM_IMAGES_3D,
        img_size=IMAGE_SIZE,
        rotate=0,
    ):
#         print('marker 1')

        case_id = str(case_id).zfill(5)
#         case_id = '00000'
        loaded = load_input(case_id, equalizer_iters=1 if self.enhanced else 0)
#         print(loaded.shape)
        return loaded

In [9]:
ls ../input/

[0m[01;34mflair-npy[0m/
[01;34mlabels-with-planes[0m/
[01;34mmodel-x80-dcs65[0m/
[01;34mmodelperclasseval[0m/
[01;34mmonai-v060-deep-learning-in-healthcare-imaging[0m/
[01;34mmrequalizer-weights[0m/
[01;34mresnet10rsna[0m/
[01;34mrsna-miccai-brain-tumor-radiogenomic-classification[0m/
[01;34mrsna-monai-packages[0m/


In [10]:
import monai

# model 
model = monai.networks.nets.resnet10(spatial_dims=3, n_input_channels=1, n_classes=1)
device = torch.device("cuda")
model.to(device);
all_weights = os.listdir("../input/resnet10rsna")
fold_files = [f for f in all_weights if type_ in f]
criterion = nn.BCEWithLogitsLoss()

In [11]:
import argparse

import pandas as pd
from sklearn.model_selection import StratifiedKFold

# parser = argparse.ArgumentParser()
# parser.add_argument("--n_folds", default=5, type=int)
# args = parser.parse_args()

train = pd.read_csv("../input/rsna-miccai-brain-tumor-radiogenomic-classification/train_labels.csv")

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=518)
oof = []
targets = []
target = "MGMT_value"

for fold, (trn_idx, val_idx) in enumerate(
    skf.split(train, train[target])
):
    train.loc[val_idx, "fold"] = int(fold)


train.to_csv("train.csv", index=False)

In [12]:
train.head()

Unnamed: 0,BraTS21ID,MGMT_value,fold
0,0,1,1.0
1,2,1,2.0
2,3,0,0.0
3,5,1,1.0
4,6,1,2.0


In [13]:
data = pd.read_csv("./train.csv")
curr_fold = 1
train_df = data[data.fold != curr_fold].reset_index(drop=False)
val_df = data[data.fold == curr_fold].reset_index(drop=False)

In [14]:
train_df.head()

Unnamed: 0,index,BraTS21ID,MGMT_value,fold
0,1,2,1,2.0
1,2,3,0,0.0
2,4,6,1,2.0
3,6,9,0,4.0
4,7,11,1,2.0


In [15]:
train_dataset = BrainRSNADataset(data=train_df, mri_type=type_, is_train=True, enhanced=False)

valid_dataset = BrainRSNADataset(data=val_df, mri_type=type_, is_train=True, enhanced=False)


train_dl = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=TRAINING_BATCH_SIZE,
    shuffle=True,
    num_workers=n_workers,
    drop_last=True,
    pin_memory=True,
)

validation_dl = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=TEST_BATCH_SIZE,
    shuffle=False,
    num_workers=n_workers,
    pin_memory=True,
)

In [16]:
model = monai.networks.nets.resnet10(spatial_dims=3, n_input_channels=1, n_classes=1)
optimizer = optim.Adam(model.parameters(), lr=0.0001)

scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.5, last_epoch=-1, verbose=True)

model.zero_grad()
model.to(device)
best_loss = 9999
best_auc = 0
criterion = nn.BCEWithLogitsLoss()
best_model = None
final_thresh = 0.5
best_val_acc = 0
for counter in range(N_EPOCHS):

    epoch_iterator_train = tqdm(train_dl)
    tr_loss = 0.0
    for step, batch in enumerate(epoch_iterator_train):
        model.train()
        images, targets = batch["image"].to(device), batch["target"].to(device)

        outputs = model(images)
        targets = targets  # .view(-1, 1)
        loss = criterion(outputs.squeeze(1), targets.float())

        loss.backward()
        optimizer.step()
        model.zero_grad()
        optimizer.zero_grad()

        tr_loss += loss.item()
        epoch_iterator_train.set_postfix(
            batch_loss=(loss.item()), loss=(tr_loss / (step + 1))
        )
    scheduler.step()  # Update learning rate schedule

    if do_valid:
        with torch.no_grad():
            val_loss = 0.0
            preds = []
            true_labels = []
            case_ids = []
            epoch_iterator_val = tqdm(validation_dl)
            for step, batch in enumerate(epoch_iterator_val):
                model.eval()
                images, targets = batch["image"].to(device), batch["target"].to(device)

                outputs = model(images)
                targets = targets  # .view(-1, 1)
                loss = criterion(outputs.squeeze(1), targets.float())
                val_loss += loss.item()
                epoch_iterator_val.set_postfix(
                    batch_loss=(loss.item()), loss=(val_loss / (step + 1))
                )
                preds.append(outputs.sigmoid().detach().cpu().numpy())
                true_labels.append(targets.cpu().numpy())
                case_ids.append(batch["case_id"])
        preds = np.vstack(preds).T[0].tolist()
        true_labels = np.hstack(true_labels).tolist()
        case_ids = np.hstack(case_ids).tolist()
        auc_score = roc_auc_score(true_labels, preds)
        auc_score_adj_best = 0
        best_acc = 0
        best_recall = 0
        best_precision = 0
        best_specificity = 0
        best_sensitivity = 0
        best_f_score = 0
        for thresh in np.linspace(0, 1, 50):
            adj_preds = list(np.array(preds) > thresh)
            auc_score_adj = roc_auc_score(true_labels, adj_preds)
            if auc_score_adj > auc_score_adj_best:
                best_thresh = thresh
                auc_score_adj_best = auc_score_adj
        adj_preds = list(np.array(preds) > 0.5)
        if accuracy_score(true_labels, adj_preds) > best_acc:
            best_acc = max(best_acc, accuracy_score(true_labels, adj_preds))
            tn, fp, fn, tp = confusion_matrix(true_labels, adj_preds).ravel()
            specificity = tn/(tn+fp)
            sensitivity = tp/(tp+fn)
            precision = tp/(tp+fp)
            recall = tp/(tp+fn)
            f_score = 2*precision*recall/(precision+recall)
#         best_recall = max(best_recall, recall)
#         best_precision = max(best_precision, precision)
#         best_specificity = max(best_specificity, specificity)
#         best_sensitivity = max(best_sensitivity, sensitivity)
#         best_f_score = max(best_f_score, f_score)
                

        print(
            f"EPOCH {counter}/{N_EPOCHS}: Validation average loss: {val_loss/(step+1)} + AUC SCORE = {auc_score} + AUC SCORE THRESH {best_thresh} = {auc_score_adj_best}"
        )
        print(f'Best Accuracy: {best_acc}')
        print(f'Best Sensitivity: {sensitivity}')
        print(f'Best Specificity: {specificity}')
        print(f'Best Precision: {precision}')
        print(f'Best Recall: {recall}')
        print(f'Best AUROC: {auc_score_adj_best}')
        best_val_acc = max(best_val_acc, best_acc)
        if auc_score > best_auc:
            print("Saving the model...")
            final_thresh = best_thresh
            all_files = os.listdir("./")

            for f in all_files:
                if f"{MODEL_NAME}_{type_}_fold{curr_fold}" in f:
                    os.remove(f"./{f}")

            best_auc = auc_score
            best_model = f"./3d-{MODEL_NAME}_{type_}_fold{curr_fold}_{round(best_auc,3)}.pth"
            torch.save(
                model.state_dict(),
                f"./3d-{MODEL_NAME}_{type_}_fold{curr_fold}_{round(best_auc,3)}.pth",
            )

print(best_auc)

  0%|          | 0/468 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.


100%|██████████| 468/468 [00:40<00:00, 11.60it/s, batch_loss=0.667, loss=0.718]
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.


100%|██████████| 59/59 [00:02<00:00, 20.62it/s, batch_loss=0.752, loss=0.71] 
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 0/40: Validation average loss: 0.7097897721549212 + AUC SCORE = 0.5390029325513197 + AUC SCORE THRESH 0.44897959183673464 = 0.564076246334311
Best Accuracy: 0.49572649572649574
Best Sensitivity: 0.25806451612903225
Best Specificity: 0.7636363636363637
Best Precision: 0.5517241379310345
Best Recall: 0.25806451612903225
Best AUROC: 0.564076246334311
Saving the model...


100%|██████████| 468/468 [00:39<00:00, 11.87it/s, batch_loss=0.423, loss=0.689]
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.


100%|██████████| 59/59 [00:01<00:00, 30.20it/s, batch_loss=0.836, loss=0.696]
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 1/40: Validation average loss: 0.6955470139697447 + AUC SCORE = 0.5906158357771261 + AUC SCORE THRESH 0.5714285714285714 = 0.589149560117302
Best Accuracy: 0.5128205128205128
Best Sensitivity: 0.8548387096774194
Best Specificity: 0.12727272727272726
Best Precision: 0.5247524752475248
Best Recall: 0.8548387096774194
Best AUROC: 0.589149560117302
Saving the model...


100%|██████████| 468/468 [00:39<00:00, 11.85it/s, batch_loss=0.754, loss=0.695]
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.


100%|██████████| 59/59 [00:01<00:00, 30.65it/s, batch_loss=0.514, loss=0.701]
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 2/40: Validation average loss: 0.7007652990898844 + AUC SCORE = 0.5466275659824047 + AUC SCORE THRESH 0.4897959183673469 = 0.5790322580645162
Best Accuracy: 0.5897435897435898
Best Sensitivity: 0.7096774193548387
Best Specificity: 0.45454545454545453
Best Precision: 0.5945945945945946
Best Recall: 0.7096774193548387
Best AUROC: 0.5790322580645162


100%|██████████| 468/468 [00:39<00:00, 11.89it/s, batch_loss=0.986, loss=0.688]
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.


100%|██████████| 59/59 [00:02<00:00, 24.91it/s, batch_loss=0.627, loss=0.682]
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 3/40: Validation average loss: 0.6821490930298627 + AUC SCORE = 0.5612903225806452 + AUC SCORE THRESH 0.5306122448979591 = 0.5539589442815249
Best Accuracy: 0.5641025641025641
Best Sensitivity: 0.8387096774193549
Best Specificity: 0.2545454545454545
Best Precision: 0.5591397849462365
Best Recall: 0.8387096774193549
Best AUROC: 0.5539589442815249


100%|██████████| 468/468 [00:39<00:00, 11.90it/s, batch_loss=0.661, loss=0.672]
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.


100%|██████████| 59/59 [00:01<00:00, 30.63it/s, batch_loss=0.675, loss=0.687]
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 4/40: Validation average loss: 0.6871688083066778 + AUC SCORE = 0.5519061583577712 + AUC SCORE THRESH 0.4693877551020408 = 0.5608504398826979
Best Accuracy: 0.5641025641025641
Best Sensitivity: 0.5645161290322581
Best Specificity: 0.5636363636363636
Best Precision: 0.5932203389830508
Best Recall: 0.5645161290322581
Best AUROC: 0.5608504398826979


100%|██████████| 468/468 [00:39<00:00, 11.92it/s, batch_loss=0.631, loss=0.661]
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.


100%|██████████| 59/59 [00:01<00:00, 30.66it/s, batch_loss=0.833, loss=0.699]
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 5/40: Validation average loss: 0.6989505290985107 + AUC SCORE = 0.5454545454545454 + AUC SCORE THRESH 0.44897959183673464 = 0.5507331378299121
Best Accuracy: 0.5384615384615384
Best Sensitivity: 0.5806451612903226
Best Specificity: 0.4909090909090909
Best Precision: 0.5625
Best Recall: 0.5806451612903226
Best AUROC: 0.5507331378299121


100%|██████████| 468/468 [00:39<00:00, 11.75it/s, batch_loss=0.417, loss=0.65] 
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.


100%|██████████| 59/59 [00:01<00:00, 30.81it/s, batch_loss=0.615, loss=0.691]
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 6/40: Validation average loss: 0.690838425341299 + AUC SCORE = 0.5343108504398827 + AUC SCORE THRESH 0.5102040816326531 = 0.5718475073313782
Best Accuracy: 0.5982905982905983
Best Sensitivity: 0.9193548387096774
Best Specificity: 0.23636363636363636
Best Precision: 0.5757575757575758
Best Recall: 0.9193548387096774
Best AUROC: 0.5718475073313782


100%|██████████| 468/468 [00:39<00:00, 11.89it/s, batch_loss=0.416, loss=0.633]
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.


100%|██████████| 59/59 [00:01<00:00, 29.92it/s, batch_loss=0.46, loss=0.707] 
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 7/40: Validation average loss: 0.7065169215202332 + AUC SCORE = 0.5296187683284458 + AUC SCORE THRESH 0.5510204081632653 = 0.5505865102639296
Best Accuracy: 0.5470085470085471
Best Sensitivity: 0.967741935483871
Best Specificity: 0.07272727272727272
Best Precision: 0.5405405405405406
Best Recall: 0.967741935483871
Best AUROC: 0.5505865102639296


100%|██████████| 468/468 [00:39<00:00, 11.89it/s, batch_loss=1.05, loss=0.609] 
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.


100%|██████████| 59/59 [00:01<00:00, 30.88it/s, batch_loss=0.713, loss=0.693]
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 8/40: Validation average loss: 0.6934093312691834 + AUC SCORE = 0.5269794721407625 + AUC SCORE THRESH 0.4081632653061224 = 0.5596774193548386
Best Accuracy: 0.49572649572649574
Best Sensitivity: 0.46774193548387094
Best Specificity: 0.5272727272727272
Best Precision: 0.5272727272727272
Best Recall: 0.46774193548387094
Best AUROC: 0.5596774193548386


100%|██████████| 468/468 [00:39<00:00, 11.82it/s, batch_loss=0.817, loss=0.58] 
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:01<00:00, 29.54it/s, batch_loss=0.708, loss=0.692]
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 9/40: Validation average loss: 0.6917426081027015 + AUC SCORE = 0.5513196480938416 + AUC SCORE THRESH 0.44897959183673464 = 0.5828445747800586
Best Accuracy: 0.5470085470085471
Best Sensitivity: 0.8548387096774194
Best Specificity: 0.2
Best Precision: 0.5463917525773195
Best Recall: 0.8548387096774194
Best AUROC: 0.5828445747800586


100%|██████████| 468/468 [00:39<00:00, 11.83it/s, batch_loss=0.459, loss=0.489]
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:01<00:00, 30.77it/s, batch_loss=0.701, loss=0.702]
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 10/40: Validation average loss: 0.7017559469756434 + AUC SCORE = 0.5501466275659824 + AUC SCORE THRESH 0.5918367346938775 = 0.5831378299120235
Best Accuracy: 0.5555555555555556
Best Sensitivity: 0.8709677419354839
Best Specificity: 0.2
Best Precision: 0.5510204081632653
Best Recall: 0.8709677419354839
Best AUROC: 0.5831378299120235


100%|██████████| 468/468 [00:39<00:00, 11.86it/s, batch_loss=0.319, loss=0.418] 
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:01<00:00, 31.34it/s, batch_loss=0.51, loss=0.696] 
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 11/40: Validation average loss: 0.6958325237540876 + AUC SCORE = 0.5668621700879766 + AUC SCORE THRESH 0.6326530612244897 = 0.5611436950146628
Best Accuracy: 0.5470085470085471
Best Sensitivity: 0.9032258064516129
Best Specificity: 0.14545454545454545
Best Precision: 0.5436893203883495
Best Recall: 0.9032258064516129
Best AUROC: 0.5611436950146628


100%|██████████| 468/468 [00:39<00:00, 11.92it/s, batch_loss=0.23, loss=0.32]   
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:02<00:00, 28.23it/s, batch_loss=0.557, loss=0.739]
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 12/40: Validation average loss: 0.739312885171276 + AUC SCORE = 0.5055718475073314 + AUC SCORE THRESH 0.5510204081632653 = 0.554692082111437
Best Accuracy: 0.5726495726495726
Best Sensitivity: 0.9354838709677419
Best Specificity: 0.16363636363636364
Best Precision: 0.5576923076923077
Best Recall: 0.9354838709677419
Best AUROC: 0.554692082111437


100%|██████████| 468/468 [00:39<00:00, 11.84it/s, batch_loss=0.0919, loss=0.203]
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:01<00:00, 31.64it/s, batch_loss=0.599, loss=0.687]
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 13/40: Validation average loss: 0.6871568313089468 + AUC SCORE = 0.6181818181818182 + AUC SCORE THRESH 0.5306122448979591 = 0.656891495601173
Best Accuracy: 0.6581196581196581
Best Sensitivity: 0.7580645161290323
Best Specificity: 0.5454545454545454
Best Precision: 0.6527777777777778
Best Recall: 0.7580645161290323
Best AUROC: 0.656891495601173
Saving the model...


100%|██████████| 468/468 [00:39<00:00, 11.84it/s, batch_loss=0.156, loss=0.117] 
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:01<00:00, 31.77it/s, batch_loss=0.736, loss=0.683]
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 14/40: Validation average loss: 0.6825521719657769 + AUC SCORE = 0.632258064516129 + AUC SCORE THRESH 0.42857142857142855 = 0.628592375366569
Best Accuracy: 0.5555555555555556
Best Sensitivity: 0.43548387096774194
Best Specificity: 0.6909090909090909
Best Precision: 0.6136363636363636
Best Recall: 0.43548387096774194
Best AUROC: 0.628592375366569
Saving the model...


100%|██████████| 468/468 [00:39<00:00, 11.89it/s, batch_loss=0.0272, loss=0.0589]
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:01<00:00, 31.20it/s, batch_loss=0.607, loss=0.732] 
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 15/40: Validation average loss: 0.7316976995286295 + AUC SCORE = 0.6357771260997068 + AUC SCORE THRESH 0.5510204081632653 = 0.6274193548387097
Best Accuracy: 0.6410256410256411
Best Sensitivity: 0.9193548387096774
Best Specificity: 0.32727272727272727
Best Precision: 0.6063829787234043
Best Recall: 0.9193548387096774
Best AUROC: 0.6274193548387097
Saving the model...


100%|██████████| 468/468 [00:39<00:00, 11.91it/s, batch_loss=0.0227, loss=0.0314] 
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:02<00:00, 24.29it/s, batch_loss=0.534, loss=0.857] 
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 16/40: Validation average loss: 0.8566139861941338 + AUC SCORE = 0.6263929618768329 + AUC SCORE THRESH 0.5510204081632653 = 0.6475073313782991
Best Accuracy: 0.6410256410256411
Best Sensitivity: 0.9838709677419355
Best Specificity: 0.2545454545454545
Best Precision: 0.5980392156862745
Best Recall: 0.9838709677419355
Best AUROC: 0.6475073313782991


100%|██████████| 468/468 [00:39<00:00, 11.87it/s, batch_loss=0.00962, loss=0.0213]
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:01<00:00, 31.59it/s, batch_loss=0.636, loss=0.72] 
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 17/40: Validation average loss: 0.7196952271259437 + AUC SCORE = 0.6067448680351906 + AUC SCORE THRESH 0.42857142857142855 = 0.616275659824047
Best Accuracy: 0.5982905982905983
Best Sensitivity: 0.7580645161290323
Best Specificity: 0.41818181818181815
Best Precision: 0.5949367088607594
Best Recall: 0.7580645161290323
Best AUROC: 0.616275659824047


100%|██████████| 468/468 [00:39<00:00, 11.89it/s, batch_loss=0.00951, loss=0.0151]
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:01<00:00, 31.26it/s, batch_loss=1.03, loss=0.716] 
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 18/40: Validation average loss: 0.716332654326649 + AUC SCORE = 0.6392961876832846 + AUC SCORE THRESH 0.44897959183673464 = 0.6306451612903226
Best Accuracy: 0.6068376068376068
Best Sensitivity: 0.5967741935483871
Best Specificity: 0.6181818181818182
Best Precision: 0.6379310344827587
Best Recall: 0.5967741935483871
Best AUROC: 0.6306451612903226
Saving the model...


100%|██████████| 468/468 [00:39<00:00, 11.88it/s, batch_loss=0.00587, loss=0.0128] 
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:01<00:00, 31.34it/s, batch_loss=0.572, loss=0.711]
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 19/40: Validation average loss: 0.7107880138239618 + AUC SCORE = 0.6442815249266862 + AUC SCORE THRESH 0.4693877551020408 = 0.6577712609970674
Best Accuracy: 0.6495726495726496
Best Sensitivity: 0.7580645161290323
Best Specificity: 0.5272727272727272
Best Precision: 0.6438356164383562
Best Recall: 0.7580645161290323
Best AUROC: 0.6577712609970674
Saving the model...


100%|██████████| 468/468 [00:39<00:00, 11.85it/s, batch_loss=0.00398, loss=0.0097] 
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:01<00:00, 30.63it/s, batch_loss=0.24, loss=0.857]  
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 20/40: Validation average loss: 0.8574162158420531 + AUC SCORE = 0.5997067448680352 + AUC SCORE THRESH 0.6326530612244897 = 0.6173020527859238
Best Accuracy: 0.5641025641025641
Best Sensitivity: 0.9193548387096774
Best Specificity: 0.16363636363636364
Best Precision: 0.5533980582524272
Best Recall: 0.9193548387096774
Best AUROC: 0.6173020527859238


100%|██████████| 468/468 [00:39<00:00, 11.88it/s, batch_loss=0.00257, loss=0.00928]
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:02<00:00, 29.18it/s, batch_loss=0.405, loss=0.816] 
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 21/40: Validation average loss: 0.8159323657973337 + AUC SCORE = 0.6258064516129033 + AUC SCORE THRESH 0.5714285714285714 = 0.6203812316715543
Best Accuracy: 0.5897435897435898
Best Sensitivity: 0.8548387096774194
Best Specificity: 0.2909090909090909
Best Precision: 0.5760869565217391
Best Recall: 0.8548387096774194
Best AUROC: 0.6203812316715543


100%|██████████| 468/468 [00:39<00:00, 11.92it/s, batch_loss=0.00373, loss=0.00915]
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:01<00:00, 31.29it/s, batch_loss=0.525, loss=0.789]
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 22/40: Validation average loss: 0.7886460028209928 + AUC SCORE = 0.6595307917888564 + AUC SCORE THRESH 0.7346938775510203 = 0.6448680351906159
Best Accuracy: 0.6153846153846154
Best Sensitivity: 0.8387096774193549
Best Specificity: 0.36363636363636365
Best Precision: 0.5977011494252874
Best Recall: 0.8387096774193549
Best AUROC: 0.6448680351906159
Saving the model...


100%|██████████| 468/468 [00:39<00:00, 11.89it/s, batch_loss=0.00353, loss=0.00748] 
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:01<00:00, 30.83it/s, batch_loss=0.656, loss=0.71]  
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 23/40: Validation average loss: 0.7103194682779959 + AUC SCORE = 0.6736070381231672 + AUC SCORE THRESH 0.5102040816326531 = 0.6870967741935483
Best Accuracy: 0.6923076923076923
Best Sensitivity: 0.7903225806451613
Best Specificity: 0.5818181818181818
Best Precision: 0.6805555555555556
Best Recall: 0.7903225806451613
Best AUROC: 0.6870967741935483
Saving the model...


100%|██████████| 468/468 [00:39<00:00, 11.88it/s, batch_loss=0.00186, loss=0.00728] 
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:01<00:00, 31.36it/s, batch_loss=0.615, loss=0.745] 
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 24/40: Validation average loss: 0.7445755219560558 + AUC SCORE = 0.6648093841642229 + AUC SCORE THRESH 0.5714285714285714 = 0.6649560117302052
Best Accuracy: 0.6495726495726496
Best Sensitivity: 0.8064516129032258
Best Specificity: 0.4727272727272727
Best Precision: 0.6329113924050633
Best Recall: 0.8064516129032258
Best AUROC: 0.6649560117302052


100%|██████████| 468/468 [00:39<00:00, 11.93it/s, batch_loss=0.00255, loss=0.00714] 
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:02<00:00, 27.44it/s, batch_loss=0.311, loss=0.834] 
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 25/40: Validation average loss: 0.8338698241543971 + AUC SCORE = 0.6454545454545455 + AUC SCORE THRESH 0.6326530612244897 = 0.6385630498533725
Best Accuracy: 0.6410256410256411
Best Sensitivity: 0.9193548387096774
Best Specificity: 0.32727272727272727
Best Precision: 0.6063829787234043
Best Recall: 0.9193548387096774
Best AUROC: 0.6385630498533725


100%|██████████| 468/468 [00:39<00:00, 11.91it/s, batch_loss=0.00141, loss=0.0066]  
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:01<00:00, 30.94it/s, batch_loss=0.225, loss=0.893] 
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 26/40: Validation average loss: 0.8926065344426591 + AUC SCORE = 0.6149560117302052 + AUC SCORE THRESH 0.8163265306122448 = 0.624633431085044
Best Accuracy: 0.5811965811965812
Best Sensitivity: 0.9032258064516129
Best Specificity: 0.21818181818181817
Best Precision: 0.5656565656565656
Best Recall: 0.9032258064516129
Best AUROC: 0.624633431085044


100%|██████████| 468/468 [00:39<00:00, 11.91it/s, batch_loss=0.00147, loss=0.00647] 
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:02<00:00, 26.12it/s, batch_loss=0.548, loss=0.846] 
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 27/40: Validation average loss: 0.8461090317785235 + AUC SCORE = 0.6366568914956011 + AUC SCORE THRESH 0.836734693877551 = 0.620674486803519
Best Accuracy: 0.5897435897435898
Best Sensitivity: 0.8709677419354839
Best Specificity: 0.2727272727272727
Best Precision: 0.574468085106383
Best Recall: 0.8709677419354839
Best AUROC: 0.620674486803519


100%|██████████| 468/468 [00:39<00:00, 11.93it/s, batch_loss=0.00096, loss=0.00595] 
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:01<00:00, 31.39it/s, batch_loss=0.271, loss=0.945] 
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 28/40: Validation average loss: 0.9448151762707758 + AUC SCORE = 0.6002932551319649 + AUC SCORE THRESH 0.8571428571428571 = 0.6105571847507331
Best Accuracy: 0.5811965811965812
Best Sensitivity: 0.9032258064516129
Best Specificity: 0.21818181818181817
Best Precision: 0.5656565656565656
Best Recall: 0.9032258064516129
Best AUROC: 0.6105571847507331


100%|██████████| 468/468 [00:39<00:00, 11.97it/s, batch_loss=0.000565, loss=0.00585] 
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:02<00:00, 27.91it/s, batch_loss=0.271, loss=0.966] 
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 29/40: Validation average loss: 0.9663168187805656 + AUC SCORE = 0.6252199413489736 + AUC SCORE THRESH 0.7755102040816326 = 0.6255131964809384
Best Accuracy: 0.5726495726495726
Best Sensitivity: 0.9193548387096774
Best Specificity: 0.18181818181818182
Best Precision: 0.5588235294117647
Best Recall: 0.9193548387096774
Best AUROC: 0.6255131964809384


100%|██████████| 468/468 [00:39<00:00, 11.95it/s, batch_loss=0.000864, loss=0.00557] 
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:01<00:00, 31.34it/s, batch_loss=1.11, loss=0.712] 


EPOCH 30/40: Validation average loss: 0.7115785181522369 + AUC SCORE = 0.673900293255132 + AUC SCORE THRESH 0.4693877551020408 = 0.699266862170088
Best Accuracy: 0.6410256410256411
Best Sensitivity: 0.6129032258064516
Best Specificity: 0.6727272727272727
Best Precision: 0.6785714285714286
Best Recall: 0.6129032258064516
Best AUROC: 0.699266862170088
Saving the model...


100%|██████████| 468/468 [00:39<00:00, 11.95it/s, batch_loss=0.000601, loss=0.00574] 
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:02<00:00, 26.15it/s, batch_loss=0.551, loss=0.723] 
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 31/40: Validation average loss: 0.7232400210479558 + AUC SCORE = 0.6530791788856305 + AUC SCORE THRESH 0.6122448979591836 = 0.6417888563049853
Best Accuracy: 0.6068376068376068
Best Sensitivity: 0.6774193548387096
Best Specificity: 0.5272727272727272
Best Precision: 0.6176470588235294
Best Recall: 0.6774193548387096
Best AUROC: 0.6417888563049853


100%|██████████| 468/468 [00:39<00:00, 11.94it/s, batch_loss=0.000613, loss=0.00529] 
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:01<00:00, 31.91it/s, batch_loss=0.654, loss=0.773] 
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 32/40: Validation average loss: 0.7733431657365823 + AUC SCORE = 0.6741935483870967 + AUC SCORE THRESH 0.5714285714285714 = 0.6658357771260996
Best Accuracy: 0.6068376068376068
Best Sensitivity: 0.8548387096774194
Best Specificity: 0.32727272727272727
Best Precision: 0.5888888888888889
Best Recall: 0.8548387096774194
Best AUROC: 0.6658357771260996
Saving the model...


100%|██████████| 468/468 [00:39<00:00, 11.92it/s, batch_loss=0.000647, loss=0.00605]
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:02<00:00, 28.83it/s, batch_loss=0.911, loss=0.745] 
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 33/40: Validation average loss: 0.7454415633643078 + AUC SCORE = 0.675366568914956 + AUC SCORE THRESH 0.5918367346938775 = 0.6649560117302052
Best Accuracy: 0.6581196581196581
Best Sensitivity: 0.7258064516129032
Best Specificity: 0.5818181818181818
Best Precision: 0.6617647058823529
Best Recall: 0.7258064516129032
Best AUROC: 0.6649560117302052
Saving the model...


100%|██████████| 468/468 [00:39<00:00, 11.96it/s, batch_loss=0.000445, loss=0.00503]
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:01<00:00, 31.25it/s, batch_loss=0.405, loss=0.799] 
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 34/40: Validation average loss: 0.7993163709228827 + AUC SCORE = 0.6747800586510264 + AUC SCORE THRESH 0.6938775510204082 = 0.6618768328445748
Best Accuracy: 0.6581196581196581
Best Sensitivity: 0.8548387096774194
Best Specificity: 0.43636363636363634
Best Precision: 0.6309523809523809
Best Recall: 0.8548387096774194
Best AUROC: 0.6618768328445748


100%|██████████| 468/468 [00:39<00:00, 11.91it/s, batch_loss=0.000422, loss=0.00485] 
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:01<00:00, 30.96it/s, batch_loss=0.535, loss=0.8]   
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 35/40: Validation average loss: 0.7998888750828929 + AUC SCORE = 0.6806451612903226 + AUC SCORE THRESH 0.5714285714285714 = 0.7011730205278593
Best Accuracy: 0.6666666666666666
Best Sensitivity: 0.8387096774193549
Best Specificity: 0.4727272727272727
Best Precision: 0.6419753086419753
Best Recall: 0.8387096774193549
Best AUROC: 0.7011730205278593
Saving the model...


100%|██████████| 468/468 [00:39<00:00, 11.90it/s, batch_loss=0.000293, loss=0.00457]
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:01<00:00, 31.00it/s, batch_loss=0.464, loss=0.75]  
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 36/40: Validation average loss: 0.7502138054598186 + AUC SCORE = 0.6249266862170089 + AUC SCORE THRESH 0.6122448979591836 = 0.6335777126099706
Best Accuracy: 0.6239316239316239
Best Sensitivity: 0.8548387096774194
Best Specificity: 0.36363636363636365
Best Precision: 0.6022727272727273
Best Recall: 0.8548387096774194
Best AUROC: 0.6335777126099706


100%|██████████| 468/468 [00:39<00:00, 11.94it/s, batch_loss=0.000304, loss=0.00465]
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:02<00:00, 29.12it/s, batch_loss=0.486, loss=0.918] 
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 37/40: Validation average loss: 0.9183316615931059 + AUC SCORE = 0.604692082111437 + AUC SCORE THRESH 0.3877551020408163 = 0.6121700879765396
Best Accuracy: 0.5982905982905983
Best Sensitivity: 0.8387096774193549
Best Specificity: 0.32727272727272727
Best Precision: 0.5842696629213483
Best Recall: 0.8387096774193549
Best AUROC: 0.6121700879765396


100%|██████████| 468/468 [00:39<00:00, 11.93it/s, batch_loss=0.000266, loss=0.00481]
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:01<00:00, 29.79it/s, batch_loss=0.777, loss=0.847] 
  0%|          | 0/468 [00:00<?, ?it/s]

EPOCH 38/40: Validation average loss: 0.8470530419657796 + AUC SCORE = 0.6404692082111437 + AUC SCORE THRESH 0.673469387755102 = 0.6296187683284457
Best Accuracy: 0.6068376068376068
Best Sensitivity: 0.7903225806451613
Best Specificity: 0.4
Best Precision: 0.5975609756097561
Best Recall: 0.7903225806451613
Best AUROC: 0.6296187683284457


100%|██████████| 468/468 [00:39<00:00, 11.84it/s, batch_loss=0.000304, loss=0.00445]
  0%|          | 0/59 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 59/59 [00:01<00:00, 30.69it/s, batch_loss=0.411, loss=0.795] 


EPOCH 39/40: Validation average loss: 0.794994903128531 + AUC SCORE = 0.6533724340175953 + AUC SCORE THRESH 0.5918367346938775 = 0.6416422287390029
Best Accuracy: 0.6153846153846154
Best Sensitivity: 0.7903225806451613
Best Specificity: 0.41818181818181815
Best Precision: 0.6049382716049383
Best Recall: 0.7903225806451613
Best AUROC: 0.6416422287390029
0.6806451612903226


In [17]:
tta_true_labels = []
tta_preds = []
test_dataset = BrainRSNADataset(data=val_df, mri_type=type_, is_train=True)
test_dl = torch.utils.data.DataLoader(
        test_dataset, batch_size=1, shuffle=False, num_workers=4
    )

# preds_f = np.zeros(len(sample))
# for fold in range(5):
image_ids = []
model.load_state_dict(torch.load(best_model))
preds = []
labels = []
epoch_iterator_test = tqdm(test_dl)
with torch.no_grad():
    for  step, batch in enumerate(epoch_iterator_test):
        model.eval()
        images = batch["image"].to(device)

        outputs = model(images)
        preds.append(outputs.sigmoid().detach().cpu().numpy())
        image_ids.append(batch["case_id"].detach().cpu().numpy())
        labels.append(batch['target'].detach().cpu().numpy())

#     preds_f += np.vstack(preds).T[0]/5

#     ids_f = np.hstack(image_ids)

100%|██████████| 117/117 [00:02<00:00, 52.71it/s]


In [18]:
all_preds = []
for batch in preds:
    for pred in batch:
        all_preds.append(pred[0])

In [19]:
all_labels = []
for batch in labels:
    for lab in batch:
        all_labels.append(lab)

In [20]:
assert len(all_preds) == len(all_labels)

In [21]:
from sklearn.metrics import accuracy_score, roc_auc_score

thresh = final_thresh
all_preds_thresh = [val >= thresh for val in all_preds]
best_acc = accuracy_score(all_labels, all_preds_thresh)
auc_score = roc_auc_score(all_labels, all_preds_thresh)
tn, fp, fn, tp = confusion_matrix(all_labels, all_preds_thresh).ravel()
specificity = tn/(tn+fp)
sensitivity = tp/(tp+fn)
precision = tp/(tp+fp)
recall = tp/(tp+fn)

print(f'Best Accuracy: {best_acc}')
print(f'Best Sensitivity: {sensitivity}')
print(f'Best Specificity: {specificity}')
print(f'Best Precision: {precision}')
print(f'Best Recall: {recall}')
print(f'Best AUROC: {auc_score}')

Best Accuracy: 0.7094017094017094
Best Sensitivity: 0.8387096774193549
Best Specificity: 0.5636363636363636
Best Precision: 0.6842105263157895
Best Recall: 0.8387096774193549
Best AUROC: 0.7011730205278593
