## Imports

In [1]:
import os
import sys 
import json
import glob
import random
import re
import collections
import time

import numpy as np
import pandas as pd
import pydicom
import cv2
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch import nn
from torch.utils import data as torch_data
from sklearn import model_selection as sk_model_selection
from torch.nn import functional as torch_functional

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score

import sys
sys.path.append("../input/torchio/")
import torchio as tio

import nibabel as nib
import numpy as np
import imageio
import os
import re
import shutil
import pandas as pd
from PIL import Image
import glob

In [2]:
data_directory = './rsna-test-jpg/'
input_monaipath = "/kaggle/input/monai-v060-deep-learning-in-healthcare-imaging/"
monaipath = "/kaggle/tmp/monai/"

root_dir = '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification'

In [3]:
!mkdir -p {monaipath}
!cp -r {input_monaipath}/* {monaipath}

## Configs

In [4]:
mri_types = ['FLAIR', 'T1w', 'T1wCE', 'T2w']
SIZE = 256
NUM_IMAGES = 64
BATCH_SIZE = 4
N_EPOCHS = 16
SEED = 42
LEARNING_RATE = 0.0005
LR_DECAY = 0.9

sys.path.append(monaipath)

from monai.networks.nets.densenet import DenseNet121

## Preprocessing for Test Data 

In [5]:
preprocessing_transforms = (
    tio.ToCanonical(),
    tio.Resample(1, image_interpolation='bspline'),
    tio.Resample('T1w', image_interpolation='nearest'),
)
preprocess = tio.Compose(preprocessing_transforms)
test_set = tio.datasets.RSNAMICCAI(root_dir, train=False, transform=preprocess)

In [6]:
import shutil
import multiprocessing as mp
from pathlib import Path
from tqdm.notebook import tqdm

def preprocess_dataset(dataset, out_dir, parallel=True, demo=False):
    out_dir = Path(out_dir)
    labels_name = 'train_labels.csv'
    if demo:
        dataset._subjects = dataset._subjects[:5]
    out_dir.mkdir(exist_ok=True, parents=True)
    shutil.copy(dataset.root_dir / labels_name, out_dir / labels_name)
    subjects_dir = out_dir / ('train' if dataset.train else 'test')
    if parallel:
        loader = torch.utils.data.DataLoader(
            dataset,
            num_workers=mp.cpu_count(),
            collate_fn=lambda x: x[0],
        )
        iterable = loader
    else:
        iterable = dataset
    for subject in tqdm(iterable):
        subject_dir = subjects_dir / subject.BraTS21ID
        for name, image in tqdm(subject.get_images_dict().items(), leave=False):
                image_dir = subject_dir / name
                image_dir.mkdir(exist_ok=True, parents=True)
                image_path = image_dir / f'{name}.nii.gz'
                image.save(image_path)

In [7]:
!rm -rf ./rsna-test

In [8]:
out_dir = 'rsna-test'
if not Path(out_dir).is_dir():
    preprocess_dataset(test_set, out_dir)

  0%|          | 0/87 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000981787

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000863878



  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.0010352

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000631195

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000876314



  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000856223



  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.00084837

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.0004

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000997044



  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.0001

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000882981



  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000885482



  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000914145

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000100501



  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000298343

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000298343



  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.0005

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000199052



  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.0001

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000564005

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000981472

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000902178

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:9.96785e-05



  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000563574



  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000984698

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000843646

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000298089

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000298089

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000125685



  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000132827

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.0005379



  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000132827

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000895005

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000911107

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.00019943

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.00019943

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.0004

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000198551



  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.00019943

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.00019943

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000526114

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000879

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.00080743

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000692008



  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000748198



  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000729492



  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000497436



  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000498721

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000519397

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.0002

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.0002



  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000172238



  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000541773

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000541773

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000800582

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000820838



  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.00087282

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.00087282

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000763781

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000739309



  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000998291

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000998291

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000395169

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000299034



  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000397268

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000397268

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000675442

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000760811



  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000684474

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000684474

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000840292

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.00039763



  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000164251

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000586097

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000783375



  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000210919

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000800219

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000674505



  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000586156

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000805848

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000586156

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000151524

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.00106723

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.00029886



  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000842656

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000842656

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000398575

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000556253

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000935313

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000890021

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000591329



  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.0001

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000199535



  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000114347

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000131044

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000101327

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000397207



  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000570696

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000570696

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000803261

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000738335



  0%|          | 0/4 [00:00<?, ?it/s]

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000318897

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000318897

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000750812

ImageSeriesReader (0x56375f8589d0): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000286866



  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

In [9]:
def read_niifile(niifile):  # Read niifile file
    img = nib.load(niifile)  # Download niifile file (actually extract the file)
    img_fdata = img.get_fdata()  # Get niifile data
    return img_fdata


def save_fig(file, savepicdir):  # Save as picture
    fdata = read_niifile(file)  # Call the above function to get data
    (x, y, z) = fdata.shape  # Get data shape information: (length, width, dimension-Number of slices)
    for k in range(z):
        silce = fdata[:, :, k]  # Three positions represent three slices at different angles
        imageio.imwrite(os.path.join(savepicdir, '{}.jpg'.format(k)), silce)

In [10]:
!rm -rf rsna-test-jpg/

In [11]:
os.mkdir('./rsna-test-jpg/')
os.mkdir('./rsna-test-jpg/test')
# test_list = os.listdir('./rsna-test/test')
# for i in sorted(test_list):
#     if not os.path.isdir(f'./rsna-test-jpg/test/{i}'):
#         os.mkdir(f'./rsna-test-jpg/test/{i}')
#     if not os.path.isdir(f'./rsna-test-jpg/test/{i}/FLAIR'):
#         os.mkdir(f'./rsna-test-jpg/test/{i}/FLAIR')

test_list = os.listdir('../input/rsna-miccai-brain-tumor-radiogenomic-classification/test')

for i in sorted(test_list):
    if not os.path.isdir(f'./rsna-test-jpg/test/{i}'):
        os.mkdir(f'./rsna-test-jpg/test/{i}')
    if not os.path.isdir(f'./rsna-test-jpg/test/{i}/FLAIR'):
        os.mkdir(f'./rsna-test-jpg/test/{i}/FLAIR')
    if not os.path.isdir(f'./rsna-test-jpg/test/{i}/T1w'):
        os.mkdir(f'./rsna-test-jpg/test/{i}/T1w')
    if not os.path.isdir(f'./rsna-test-jpg/test/{i}/T1wCE'):
        os.mkdir(f'./rsna-test-jpg/test/{i}/T1wCE')
    if not os.path.isdir(f'./rsna-test-jpg/test/{i}/T2w'):
        os.mkdir(f'./rsna-test-jpg/test/{i}/T2w')

In [12]:
paths_flair = glob.glob('./rsna-test/test/*/FLAIR/*')
paths_t1w = glob.glob('./rsna-test/test/*/T1w/*')
paths_t1wce = glob.glob('./rsna-test/test/*/T1wCE/*')
paths_t2w = glob.glob('./rsna-test/test/*/T2w/*')

for i in range(len(paths_flair)):
    savedir = f"rsna-test-jpg/test/{paths_flair[i].split('/')[-3]}/FLAIR/"
    img = save_fig(paths_flair[i], savedir)
    del img
for i in range(len(paths_t1w)):
    savedir = f"rsna-test-jpg/test/{paths_t1w[i].split('/')[-3]}/T1w/"
    img = save_fig(paths_t1w[i], savedir)
    del img
for i in range(len(paths_t1wce)):
    savedir = f"rsna-test-jpg/test/{paths_t1wce[i].split('/')[-3]}/T1wCE/"
    img = save_fig(paths_t1wce[i], savedir)
    del img
for i in range(len(paths_t2w)):
    savedir = f"rsna-test-jpg/test/{paths_t2w[i].split('/')[-3]}/T2w/"
    img = save_fig(paths_t2w[i], savedir)
    del img

## Utils 

In [13]:
def load_dicom_image(path, img_size=SIZE):

    data = cv2.imread(path, -1)
    if np.min(data)==np.max(data):
        data = np.zeros((img_size,img_size))
        return data
    
    data = cv2.resize(data, (img_size, img_size))
    return data


def natural_sort(l): 
    convert = lambda text: int(text) if text.isdigit() else text.lower()
    alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
    return sorted(l, key=alphanum_key)


def load_dicom_images_3d(scan_id, num_imgs=NUM_IMAGES, img_size=SIZE, mri_type="FLAIR", split="train"):
    files = natural_sort(glob.glob(f"{data_directory}/{split}/{scan_id}/{mri_type}/*.jpg"))
    every_nth = len(files) / num_imgs
    indexes = [min(int(round(i*every_nth)), len(files)-1) for i in range(0,num_imgs)]
    
    files_to_load = [files[i] for i in indexes]
    
    img3d = np.stack([load_dicom_image(f) for f in files_to_load]).T 
    
    img3d = img3d - np.min(img3d)
    if np.max(img3d) != 0:
        img3d = img3d / np.max(img3d)
    
    return np.expand_dims(img3d,0)


load_dicom_images_3d("00001", mri_type="FLAIR", split="test").shape

(1, 256, 256, 64)

In [15]:
def set_seed(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

set_seed(SEED)

In [16]:
def build_model():
    model = DenseNet121(spatial_dims=3, in_channels=1, out_channels=1)
    return model    

## Dataset 

In [17]:
class Dataset(torch_data.Dataset):
    def __init__(self, paths, targets=None, mri_type=None, split="train"):
        self.paths = paths
        self.targets = targets
        self.mri_type = mri_type
        self.split = split
          
    def __len__(self):
        return len(self.paths)
    
    def __getitem__(self, index):
        scan_id = self.paths[index]
        if self.targets is None:
            data = load_dicom_images_3d(str(scan_id).zfill(5), mri_type=self.mri_type[index], split=self.split)
        else:
            data = load_dicom_images_3d(str(scan_id).zfill(5), mri_type=self.mri_type[index], split="train")
            
        if self.targets is None:
            return {"X": data, "id": scan_id}
        else:
            return {"X": data, "y": torch.tensor(self.targets[index], dtype=torch.float)}


### Model Files

In [18]:
modelfiles = ['../input/rsna-densenet-deep-auc/FLAIR-e0-auc0.670.pth', '../input/rsna-densenet-deep-auc/T1w-e1-auc0.633.pth', '../input/rsna-densenet-deep-auc/T1wCE-e0-auc0.656.pth', '../input/rsna-densenet-deep-auc/T2w-e1-auc0.671.pth']

## Predict 

In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def predict(modelfile, df, mri_type, split):
    print("Predict:", modelfile, mri_type, df.shape)
    df.loc[:,"MRI_Type"] = mri_type
    data_retriever = Dataset(
        df.index.values, 
        mri_type=df["MRI_Type"].values,
        split=split
    )

    data_loader = torch_data.DataLoader(
        data_retriever,
        batch_size=4,
        shuffle=False,
        num_workers=8,
    )
   
    model = build_model()
    model.to(device)
    
    checkpoint = torch.load(modelfile)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    
    y_pred = []
    ids = []

    for e, batch in enumerate(data_loader,1):
        print(f"{e}/{len(data_loader)}", end="\r")
        with torch.no_grad():
            tmp_pred = torch.sigmoid(model(torch.tensor(batch["X"]).float().to(device)).squeeze(1)).cpu().numpy().squeeze()
            if tmp_pred.size == 1:
                y_pred.append(tmp_pred)
            else:
                y_pred.extend(tmp_pred.tolist())
            ids.extend(batch["id"].numpy().tolist())
            
    preddf = pd.DataFrame({"BraTS21ID": ids, "MGMT_value": y_pred}) 
    preddf = preddf.set_index("BraTS21ID")
    return preddf

In [21]:
submission = pd.read_csv(f"../input/rsna-miccai-brain-tumor-radiogenomic-classification/sample_submission.csv", index_col="BraTS21ID")

submission["MGMT_value"] = 0
for m, mtype in zip(modelfiles, mri_types):
    pred = predict(m, submission, mtype, split="test")
    submission["MGMT_value"] += pred["MGMT_value"]

submission["MGMT_value"] /= len(modelfiles)
submission["MGMT_value"].to_csv("submission.csv")
shutil.rmtree('./rsna-test')
shutil.rmtree('./rsna-test-jpg')

Predict: ../input/rsna-densenet-deep-auc/FLAIR-e0-auc0.670.pth FLAIR (87, 1)
1/22



Predict: ../input/rsna-densenet-deep-auc/T1w-e1-auc0.633.pth T1w (87, 2)
Predict: ../input/rsna-densenet-deep-auc/T1wCE-e0-auc0.656.pth T1wCE (87, 2)
Predict: ../input/rsna-densenet-deep-auc/T2w-e1-auc0.671.pth T2w (87, 2)
22/22

In [22]:
submission

Unnamed: 0_level_0,MGMT_value,MRI_Type
BraTS21ID,Unnamed: 1_level_1,Unnamed: 2_level_1
1,0.454308,T2w
13,0.578598,T2w
15,0.550014,T2w
27,0.557468,T2w
37,0.529028,T2w
...,...,...
826,0.408969,T2w
829,0.579819,T2w
833,0.408246,T2w
997,0.423868,T2w
