In [1]:
!pip install ../input/libauc-116/libauc-1.1.6-py3-none-any.whl

Processing /kaggle/input/libauc-116/libauc-1.1.6-py3-none-any.whl
Installing collected packages: libauc
Successfully installed libauc-1.1.6


## Imports

In [2]:
import os
import sys 
import json
import glob
import random
import re
import collections
import time

import numpy as np
import pandas as pd
import pydicom
import cv2
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch import nn
from torch.utils import data as torch_data
from sklearn import model_selection as sk_model_selection
from torch.nn import functional as torch_functional

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score

from libauc.losses import AUCMLoss, CrossEntropyLoss
from libauc.optimizers import PESG, Adam

import warnings 
warnings.filterwarnings('ignore')

In [3]:
data_directory = '../input/rsnatraintestcombined/RSNA-train'
input_monaipath = "/kaggle/input/monai-v060-deep-learning-in-healthcare-imaging/"
monaipath = "/kaggle/tmp/monai/"

In [4]:
!mkdir -p {monaipath}
!cp -r {input_monaipath}/* {monaipath}

## Configs

In [5]:
mri_types = ['FLAIR', 'T1w', 'T1wCE', 'T2w']
SIZE = 256
NUM_IMAGES = 64
BATCH_SIZE = 4
N_EPOCHS = 20
SEED = 2001
LEARNING_RATE = 0.0005
LR_DECAY = 0.9

sys.path.append(monaipath)

from monai.networks.nets.densenet import DenseNet121

## Functions to load images

In [6]:
def load_dicom_image(path, img_size=SIZE):
#     dicom = pydicom.read_file(path)
#     data = dicom.pixel_array
    data = cv2.imread(path, -1)
    if np.min(data)==np.max(data):
        data = np.zeros((img_size,img_size))
        return data
    
    data = cv2.resize(data, (img_size, img_size))
    return data


def natural_sort(l): 
    convert = lambda text: int(text) if text.isdigit() else text.lower()
    alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
    return sorted(l, key=alphanum_key)


def load_dicom_images_3d(scan_id, num_imgs=NUM_IMAGES, img_size=SIZE, mri_type="FLAIR", split="train"):
    files = natural_sort(glob.glob(f"{data_directory}/{split}/{scan_id}/{mri_type}/*.jpg"))
    
    every_nth = len(files) / num_imgs
    indexes = [min(int(round(i*every_nth)), len(files)-1) for i in range(0,num_imgs)]
    
    files_to_load = [files[i] for i in indexes]
    
    img3d = np.stack([load_dicom_image(f) for f in files_to_load]).T 
    
    img3d = img3d - np.min(img3d)
    if np.max(img3d) != 0:
        img3d = img3d / np.max(img3d)
    
    return np.expand_dims(img3d,0)


load_dicom_images_3d("00000", mri_type=mri_types[0]).shape

(1, 256, 256, 64)

In [7]:
def set_seed(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

set_seed(SEED)

## train / test splits

In [8]:
samples_to_exclude = [109, 123, 709]
target_files = [1,15,47,80,119,129,145,153,174,181,182,190,200,213,252,256,264,287,323,333,335,393,422,428,434,458,460,463,467,489,492,553,573,592,603,647,662,721,762,821,825,826,833,997]
train_df = pd.read_csv(f"../input/rsna-miccai-brain-tumor-radiogenomic-classification/train_labels.csv")
print("original shape", train_df.shape)
train_df = train_df[~train_df.BraTS21ID.isin(samples_to_exclude)]
add = pd.DataFrame({'BraTS21ID' : target_files, 
                    'MGMT_value' : np.zeros(len(target_files))})
train_df = train_df.append(add)
train_df.MGMT_value = train_df.MGMT_value.astype('int64')
print("new shape", train_df.shape)
display(train_df)

df_train, df_valid = sk_model_selection.train_test_split(
    train_df, 
    test_size=0.2, 
    random_state=SEED, 
    stratify=train_df["MGMT_value"],
)


original shape (585, 2)
new shape (626, 2)


Unnamed: 0,BraTS21ID,MGMT_value
0,0,1
1,2,1
2,3,0
3,5,1
4,6,1
...,...,...
39,821,0
40,825,0
41,826,0
42,833,0


In [9]:
df_train.tail()

Unnamed: 0,BraTS21ID,MGMT_value
174,259,0
385,564,1
227,329,1
426,615,1
300,436,1


## Model and training classes

In [10]:
class Dataset(torch_data.Dataset):
    def __init__(self, paths, targets=None, mri_type=None, split="train"):
        self.paths = paths
        self.targets = targets
        self.mri_type = mri_type
        self.split = split
          
    def __len__(self):
        return len(self.paths)
    
    def __getitem__(self, index):
        scan_id = self.paths[index]
        if self.targets is None:
            data = load_dicom_images_3d(str(scan_id).zfill(5), mri_type=self.mri_type[index], split=self.split)
        else:
            data = load_dicom_images_3d(str(scan_id).zfill(5), mri_type=self.mri_type[index], split="train")
            
        if self.targets is None:
            return {"X": data, "id": scan_id}
        else:
            return {"X": data, "y": torch.tensor(self.targets[index], dtype=torch.float)}


In [11]:
def build_model():
    model = DenseNet121(spatial_dims=3, in_channels=1, out_channels=1)
    return model    

In [12]:
class Trainer:
    def __init__(
        self, 
        model, 
        device, 
        optimizer, 
        criterion
    ):
        self.model = model
        self.device = device
        self.optimizer = optimizer
        self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=LR_DECAY)
        self.criterion = criterion

        self.best_valid_score = .0
        self.n_patience = 0
        self.lastmodel = None
        
        self.val_losses = []
        self.train_losses = []
        self.val_auc = []
        
    def fit(self, epochs, train_loader, valid_loader, save_path, patience):      
        for n_epoch in range(1, epochs + 1):
            self.info_message("EPOCH: {}", n_epoch)
            
            train_loss, train_time = self.train_epoch(train_loader)
            valid_loss, valid_auc, valid_time = self.valid_epoch(valid_loader)
            
            self.train_losses.append(train_loss)
            self.val_losses.append(valid_loss)
            self.val_auc.append(valid_auc)
            
            self.info_message(
                "[Epoch Train: {}] loss: {:.4f}, time: {:.2f} s",
                n_epoch, train_loss, train_time
            )
            
            self.info_message(
                "[Epoch Valid: {}] loss: {:.4f}, auc: {:.4f}, time: {:.2f} s",
                n_epoch, valid_loss, valid_auc, valid_time
            )

            if self.best_valid_score < valid_auc: 
                self.save_model(n_epoch, save_path, valid_loss, valid_auc)
                self.info_message(
                     "auc improved from {:.4f} to {:.4f}. Saved model to '{}'", 
                    self.best_valid_score, valid_auc, self.lastmodel
                )
                self.best_valid_score = valid_auc
                self.n_patience = 0
            else:
                self.n_patience += 1
            
            if self.n_patience >= patience:
                self.info_message("\nValid auc didn't improve last {} epochs.", patience)
                break
            
    def train_epoch(self, train_loader):
        self.model.train()
        t = time.time()
        sum_loss = 0

        for step, batch in enumerate(train_loader, 1):
            X = torch.tensor(batch["X"]).float().to(self.device)
            targets = batch["y"].to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(X).squeeze(1)
            loss = self.criterion(outputs, targets)
                
            loss.backward()

            sum_loss += loss.detach().item()
            
            self.optimizer.step()
            
            message = 'Train Step {}/{}, train_loss: {:.4f}'
            self.info_message(message, step, len(train_loader), sum_loss/step, end="\r")
            
        self.lr_scheduler.step()
        
        return sum_loss/len(train_loader), int(time.time() - t)
    
    def valid_epoch(self, valid_loader):
        self.model.eval()
        t = time.time()
        sum_loss = 0
        y_all = []
        outputs_all = []

        for step, batch in enumerate(valid_loader, 1):
            with torch.no_grad():
                targets = batch["y"].to(self.device)

                output = torch.sigmoid(self.model(torch.tensor(batch["X"]).float().to(self.device)).squeeze(1))
                loss = self.criterion(output, targets)
                sum_loss += loss.detach().item()

                y_all.extend(batch["y"].tolist())
                outputs_all.extend(output.tolist())

            message = 'Valid Step {}/{}, valid_loss: {:.4f}'
            self.info_message(message, step, len(valid_loader), sum_loss/step, end="\r")
            
        auc = roc_auc_score(y_all, outputs_all)
        
        return sum_loss/len(valid_loader), auc, int(time.time() - t)
    
    def save_model(self, n_epoch, save_path, loss, auc):
        self.lastmodel = f"{save_path}-e{n_epoch}-loss{loss:.3f}-auc{auc:.3f}.pth"
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "best_valid_score": self.best_valid_score,
                "n_epoch": n_epoch,
            },
            self.lastmodel,
        )
        
    def display_plots(self, mri_type):
        plt.figure(figsize=(10,5))
        plt.title("{}: Training and Validation Loss")
        plt.plot(self.val_losses,label="val")
        plt.plot(self.train_losses,label="train")
        plt.xlabel("iterations")
        plt.ylabel("Loss")
        plt.legend()
        plt.show()
        plt.close()
        
        plt.figure(figsize=(10,5))
        plt.title("{}: Validation AUC-ROC")
        plt.plot(self.val_auc,label="val")
        plt.xlabel("iterations")
        plt.ylabel("AUC")
        plt.legend()
        plt.show()
        plt.close()
    
    @staticmethod
    def info_message(message, *args, end="\n"):
        print(message.format(*args), end=end)

## Deep-auc optimization 

In [14]:

def train_mri_type(df_train, df_valid, mri_type, model_path):
    if mri_type=="all":
        train_list = []
        valid_list = []
        for mri_type in mri_types:
            df_train.loc[:,"MRI_Type"] = mri_type
            train_list.append(df_train.copy())
            df_valid.loc[:,"MRI_Type"] = mri_type
            valid_list.append(df_valid.copy())

        df_train = pd.concat(train_list)
        df_valid = pd.concat(valid_list)
    else:
        df_train.loc[:,"MRI_Type"] = mri_type
        df_valid.loc[:,"MRI_Type"] = mri_type


    
    train_data_retriever = Dataset(
        df_train["BraTS21ID"].values, 
        df_train["MGMT_value"].values, 
        df_train["MRI_Type"].values
    )

    valid_data_retriever = Dataset(
        df_valid["BraTS21ID"].values, 
        df_valid["MGMT_value"].values,
        df_valid["MRI_Type"].values
    )

    train_loader = torch_data.DataLoader(
        train_data_retriever,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=8,
    )

    valid_loader = torch_data.DataLoader(
        valid_data_retriever, 
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=8,
    )
    
    def save_model(n_epoch, save_path, auc, model, optimizer, best_valid_score):
        lastmodel = f"{save_path}-e{n_epoch}-auc{auc:.3f}.pth"
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "best_valid_score": best_valid_score,
                "n_epoch": n_epoch,
            },
            lastmodel,
        )
    
    model = build_model()
    model.to(device)
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint["model_state_dict"])

    Loss = AUCMLoss()
    optimizer = PESG(model, 
                 a=Loss.a, 
                 b=Loss.b, 
                 alpha=Loss.alpha, 
                 lr=0.0005, 
                 gamma=500, 
                 margin=1.0, 
                 weight_decay=1e-5)

    best_val_auc = 0
    for epoch in range(2):
        if epoch > 0:
                optimizer.update_regularizer(decay_factor=10)
        for idx, data in enumerate(train_loader):
            train_data = torch.tensor(data['X']).float().cuda()
            train_labels = data['y'].cuda()
            y_pred = model(train_data)
            loss = Loss(y_pred, train_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # validation
            if idx % 400 == 0:
                model.eval()
            with torch.no_grad():    
                test_pred = []
                test_true = [] 
                for jdx, data in enumerate(valid_loader):
                    test_data = torch.tensor(data['X']).float().cuda()
                    test_label = data['y']
                    y_pred = model(test_data)
                    test_pred.append(y_pred.cpu().detach().numpy())
                    test_true.append(test_label.numpy())

                test_true = np.concatenate(test_true)
                test_pred = np.concatenate(test_pred)
                val_auc =  roc_auc_score(test_true, test_pred) 
                model.train()
                
                if best_val_auc < val_auc:
                    save_model(epoch, mri_type, val_auc, model, optimizer, best_val_auc)

                    best_val_auc = val_auc

            print ('Epoch=%s, BatchID=%s, Val_AUC=%.4f, lr=%.4f'%(epoch, idx, val_auc,  optimizer.lr))

    print ('Best Val_AUC is %.4f'%best_val_auc)
    
modelfiles = ['../input/2-monoi-models-random-seeds/FLAIR-e5-loss0.712-auc0.676.pth', '../input/2-monoi-models-random-seeds/T1w-e18-loss0.727-auc0.623.pth', '../input/2-monoi-models-random-seeds/T1wCE-e18-loss0.727-auc0.650.pth', '../input/2-monoi-models-random-seeds/T2w-e7-loss0.712-auc0.645.pth']

for model, m in zip(modelfiles, mri_types):
    train_mri_type(df_train, df_valid, m, model) 


Epoch=0, BatchID=0, Val_AUC=0.6696, lr=0.0005
Epoch=0, BatchID=1, Val_AUC=0.6308, lr=0.0005
Epoch=0, BatchID=2, Val_AUC=0.6316, lr=0.0005
Epoch=0, BatchID=3, Val_AUC=0.6326, lr=0.0005
Epoch=0, BatchID=4, Val_AUC=0.6346, lr=0.0005
Epoch=0, BatchID=5, Val_AUC=0.6368, lr=0.0005
Epoch=0, BatchID=6, Val_AUC=0.6373, lr=0.0005
Epoch=0, BatchID=7, Val_AUC=0.6371, lr=0.0005
Epoch=0, BatchID=8, Val_AUC=0.6373, lr=0.0005
Epoch=0, BatchID=9, Val_AUC=0.6376, lr=0.0005
Epoch=0, BatchID=10, Val_AUC=0.6381, lr=0.0005
Epoch=0, BatchID=11, Val_AUC=0.6394, lr=0.0005
Epoch=0, BatchID=12, Val_AUC=0.6444, lr=0.0005
Epoch=0, BatchID=13, Val_AUC=0.6457, lr=0.0005
Epoch=0, BatchID=14, Val_AUC=0.6459, lr=0.0005
Epoch=0, BatchID=15, Val_AUC=0.6454, lr=0.0005
Epoch=0, BatchID=16, Val_AUC=0.6487, lr=0.0005
Epoch=0, BatchID=17, Val_AUC=0.6502, lr=0.0005
Epoch=0, BatchID=18, Val_AUC=0.6525, lr=0.0005
Epoch=0, BatchID=19, Val_AUC=0.6522, lr=0.0005
Epoch=0, BatchID=20, Val_AUC=0.6527, lr=0.0005
Epoch=0, BatchID=21, Va