# Participation in this competition helps contribute to a meaningful cause 
As mentioned in the competition overview, "If successful, you'll help brain cancer patients receive less invasive diagnoses and treatments. The introduction of new and customized treatment strategies before surgery has the potential to improve the management, survival, and prospects of patients with brain cancer."

**Dataset**

The dataset is divided as follows: each independent case has a dedicated folder identified by a five-digit number. Within each of these “case” folders there are four sub-folders, each corresponding to one of the structural multi-parametric MRI (mpMRI) scans, in DICOM format. The exact mpMRI scans included are:

* Fluid Attenuated Inversion Recovery (FLAIR)
* T1-weighted pre-contrast (T1w)
* T1-weighted post-contrast (T1Gd)
* T2-weighted (T2)

**Files**

* train/ - folder containing the training files, with each top-level folder representing a subject. 
* train_labels.csv - file containing the target MGMT_value for each subject in the training data (i.e. the presence of MGMT promoter methylation)
* test/ - the test files, which use the same structure as train/; your task is to predict the MGMT_value for each subject in the test data. NOTE: the total size of the rerun test set (Public and Private) is ~5x the size of the Public test set
* sample_submission.csv - a sample submission file in the correct format

Each independent case is labeled with MGMT_value: 1 indicates the presence of MGMT promoter methylation and 0 indicates its absence.

## Understanding Dataset

For any Data Science project, it is crucial to understand the data, as best as possible, using visualization and Exploratory Data Analysis (EDA) techniques. 

Let's start with understanding the training labels file. But before that, let's make a cell to keep all the imported libraries in one place.

In [None]:
#!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
#!python pytorch-xla-env-setup.py --apt-packages libomp5 libopenblas-dev

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

from collections import OrderedDict
import copy
import csv
import cv2
import glob
import json
import matplotlib.pyplot as plt
import numpy as np # linear algebra
from operator import itemgetter
import os
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import random
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import models, transforms
#import torch_xla
#import torch_xla.debug.metrics as met
#import torch_xla.distributed.parallel_loader as pl
#import torch_xla.utils.utils as xu
#import torch_xla.core.xla_model as xm
#import torch_xla.distributed.xla_multiprocessing as xmp
#import torch_xla.test.test_utils as test_utils

# Seed Python's built-in `random` module so that the patient samples, slice
# choices and Cutout holes drawn with it later in the notebook are reproducible.
# NOTE: this does not seed NumPy or PyTorch random number generators.
random.seed(37)
# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

EDA code taken from awesome work at https://www.kaggle.com/ihelon/brain-tumor-eda-with-animations-and-modeling

In [None]:
# Load the training labels (one row per patient: BraTS21ID, MGMT_value) and preview them
train_df = pd.read_csv(
    filepath_or_buffer="../input/rsna-miccai-brain-tumor-radiogenomic-classification/train_labels.csv"
)
train_df.head(n=5)

In [None]:
# Bar chart of the number of patients in each MGMT_value class
_label_fig, _label_ax = plt.subplots(figsize=(5, 5))
sns.countplot(x="MGMT_value", data=train_df, ax=_label_ax);

It's a near-even split between patients with a methylated MGMT promoter (MGMT_value = 1) and an unmethylated one (MGMT_value = 0). This is convenient, as we don't need to worry much about stratifying the data. It also gives us the simplest possible classifier: one that always assigns MGMT = 1 to every patient, because it is the slightly more common class. Such a classifier would achieve an accuracy of 307/585, or 52.5%, which is nothing to write home about.

In [None]:
def load_dicom(path):
    """Read one DICOM slice and rescale its pixels to the full uint8 range.

    The minimum pixel value is mapped to 0 and the maximum to 255. A constant
    slice (e.g. all black) is returned as all zeros instead of dividing by 0.

    Args:
        path: Path to a ``.dcm`` file.

    Returns:
        2-D ``np.uint8`` array with the slice's pixel data.
    """
    # ``pydicom.read_file`` is a deprecated alias (removed in pydicom 3.0).
    dicom = pydicom.dcmread(path)
    # Work in float so ``data - min`` cannot overflow for signed integer pixels.
    data = dicom.pixel_array.astype(np.float64)
    data = data - np.min(data)
    peak = np.max(data)
    if peak != 0:
        data = data / peak
    return (data * 255).astype(np.uint8)


def visualize_sample(
    brats21id, 
    slice_i,
    mgmt_value,
    types=("FLAIR", "T1w", "T1wCE", "T2w")
):
    """Plot one slice of every requested modality for a training patient.

    Args:
        brats21id: Patient ID (int or str); zero-padded to 5 digits to form the folder name.
        slice_i: Relative slice position; 0.5 selects the middle slice. Values
            outside [0, 1) are clamped to the first/last slice.
        mgmt_value: Label shown in the figure title.
        types: Modality sub-folder names; one panel is drawn per entry.
    """
    plt.figure(figsize=(16, 5))
    patient_path = os.path.join(
        "../input/rsna-miccai-brain-tumor-radiogenomic-classification/train/", 
        str(brats21id).zfill(5),
    )
    n_panels = len(types)
    for i, t in enumerate(types, 1):
        # Sort numerically on the trailing number of "Image-<n>.dcm" so that,
        # e.g., Image-10 comes after Image-2 (a plain string sort would not).
        t_paths = sorted(
            glob.glob(os.path.join(patient_path, t, "*")), 
            key=lambda x: int(x[:-4].split("-")[-1]),
        )
        plt.subplot(1, n_panels, i)
        if not t_paths:
            # Missing modality: leave an empty, labelled panel instead of crashing.
            plt.title(f"{t} (no images)", fontsize=16)
            plt.axis("off")
            continue
        idx = min(max(int(len(t_paths) * slice_i), 0), len(t_paths) - 1)
        data = load_dicom(t_paths[idx])
        plt.imshow(data, cmap="gray")
        plt.title(f"{t}", fontsize=16)
        plt.axis("off")

    plt.suptitle(f"MGMT_value: {mgmt_value}", fontsize=16)
    plt.show()

In [None]:
# Middle slice of every modality for 10 randomly drawn training patients
for row_pos in random.sample(range(len(train_df)), 10):
    patient_row = train_df.iloc[row_pos]
    _brats21id, _mgmt_value = patient_row["BraTS21ID"], patient_row["MGMT_value"]
    visualize_sample(brats21id=_brats21id, mgmt_value=_mgmt_value, slice_i=0.5)

Some images show apparent tumorous growth, whereas in many others it is hard to detect with the human eye. This is where ML-based techniques can help us find patterns that are not obvious to a human observer. Before proceeding to full-scale ML analysis of the images, let's see how effective a classifier we can build using only the number of images of each modality for each patient. Note, this is very important: the number of images of each type should tell us *absolutely nothing* about the MGMT status of the patient, because the number of images is determined by the MRI machine settings before anyone knows what type of tumor the patient has. Nevertheless, there may be spurious correlations that an algorithm could pick up, which could lead to errors when it is used on images from new patients.

In [None]:
# Root folder of the competition data (contains train/, test/ and train_labels.csv)
train_path = '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification'
# Derive the labels path from train_path so the data location is defined in one place.
train_df = pd.read_csv(os.path.join(train_path, 'train_labels.csv'))
train_df.head()

In [None]:
# Write out a csv file with 6 columns: Patient ID, # FLAIR images, # T1w images,
# # T1wCE images, # T2w images, MGMT type
full_list = []
IDs = train_df['BraTS21ID'].values
MGMTs = train_df['MGMT_value'].values
# Folder names are zero-padded 5-digit IDs; pad the CSV IDs to match.
labels = {str(ID).zfill(5): MGMT for ID, MGMT in zip(IDs, MGMTs)}
patients = glob.glob(f'{train_path}/train/*')
patient_IDs = sorted(os.path.basename(p) for p in patients)
modes = ['FLAIR', 'T1w', 'T1wCE', 'T2w']
for p in patient_IDs:
    subdict = {'Patient': p, 'MGMT': labels[p]}
    for mode in modes:
        subdict[mode] = len(glob.glob(f'{train_path}/train/{p}/{mode}/*.dcm'))
    full_list.append(subdict)

# The csv module requires newline='' on the file; otherwise its '\r\n' row
# terminator is translated to '\r\r\n' on Windows, producing blank rows.
with open('patient_image_count.csv', mode='w', newline='') as image_count_file:
    fieldnames = ['Patient', 'FLAIR', 'T1w', 'T1wCE', 'T2w', 'MGMT']
    image_writer = csv.DictWriter(image_count_file, fieldnames=fieldnames)
    image_writer.writeheader()
    image_writer.writerows(full_list)

In [None]:
# Per-patient image counts are the only features; hold out 25% of patients for evaluation.
simple_df = pd.read_csv('patient_image_count.csv')
count_columns = ['FLAIR', 'T1w', 'T1wCE', 'T2w']
X, y = simple_df[count_columns], simple_df['MGMT']
print(simple_df.head(10))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=37)

In [None]:
# Baseline classifiers trained on image counts only. The tree-based models are
# randomized by default (feature permutation / bootstrapping), so they are seeded
# here to make the reported scores reproducible between runs.
print('------------SVM--------------')
svc = SVC()
svc.fit(X_train, y_train)
svc_prediction = svc.predict(X_test)
print(classification_report(y_test, svc_prediction))
print('-------------Decision Tree--------------')
DT = DecisionTreeClassifier(random_state=37)
DT.fit(X_train, y_train)
DT_prediction = DT.predict(X_test)
print(classification_report(y_test, DT_prediction))
print('-----------Random Forest----------------')
RF = RandomForestClassifier(random_state=37)
RF.fit(X_train, y_train)
RF_prediction = RF.predict(X_test)
print(classification_report(y_test, RF_prediction))

Notice that all three simple sklearn classifiers with default hyperparameters achieve higher accuracy than the "guess the most common class for everything" baseline, with the SVM getting close to 60%. This establishes a hard floor for performance: any complex model that fails to achieve at least 60% validation-set accuracy is not worth using, as it can't even beat models that are essentially fact-free.

Alright, a bit more EDA before we get to the real models.  From the images seen above, we know that not all images are in the same orientation, even for the same patient.  Are there different orientations for different images within the same patient *and* same modality?  If so, we need to know this in order to properly do the alignment necessary for considering 3D structure.  Looking at the available information from the DICOM, the terms most likely to be relevant for determining positioning are:

- (0018, 1310) Acquisition Matrix
- (0018, 1314) Flip Angle
- (0018, 5100) Patient Position
- (0020, 0037) Image Orientation (Patient)

In [None]:
# Load the JSON index of usable images: patient ID -> {modality: ordered list of image file names}.
# The context manager closes the file handle, which was previously left open.
with open('../input/brain-tumor-valid-images-in-order/all_valid_image_names.json') as f:
    good_image_names = json.load(f)
print('Number of Patients with good images: ', len(good_image_names))

In [None]:
patient = '00018'
# DICOM (group, element) tags most likely to reveal how each slice is positioned/oriented
codes = {
    'Acquisition_Matrix': ('0018', '1310'),
    'Flip_Angle': ('0018', '1314'),
    'Patient Position': ('0018', '5100'),
    'Image Orientation (Patient)': ('0020', '0037'),
}
modes = ['FLAIR', 'T1wCE', 'T1w', 'T2w']
def count_orientations(patient, mode, good_images=good_image_names):
    """Collect the positioning-related DICOM tag values of one patient/modality.

    Args:
        patient: Zero-padded patient ID (folder name under ``train/``).
        mode: Modality sub-folder, e.g. 'FLAIR'.
        good_images: Mapping patient -> {modality: list of image file names}.

    Returns:
        Dict mapping each name in the global ``codes`` to a list holding that
        tag's value for every listed image, in list order. Multi-valued tags are
        converted to tuples rounded to 4 decimals so they are hashable and
        robust to float noise (callers put them in sets).
    """
    # glob(...)[0] doubles as an existence check: a missing file raises IndexError.
    image_list = [glob.glob(f'{train_path}/train/{patient}/{mode}/{im}')[0]
                  for im in good_images[patient][mode]]

    print('Number of good images for patient %s in mode %s: %i' % (patient, mode, len(image_list)))

    code_lists = {k: [] for k in codes}
    for im in image_list:
        # Only header tags are needed; skip reading the (large) pixel data.
        dicom = pydicom.dcmread(im, stop_before_pixels=True)
        for name, tag in codes.items():
            data = dicom[tag].value
            if isinstance(data, (pydicom.multival.MultiValue, list)):
                data = tuple(np.array(data).round(decimals=4))
            code_lists[name].append(data)

    return code_lists

# For each modality of `patient`, report how many distinct values each tag takes
for mode_name in modes:
    print('-' * 25)
    for setting, observed in count_orientations(patient, mode_name).items():
        distinct = set(observed)
        print(f'Number of distinct {setting}: {len(distinct)}')
        print(f'{setting} settings: ')
        print(distinct)


In [None]:
# Middle slice of every modality for the first patient in train_df
i = 0
_brats21id, _mgmt_value = train_df.iloc[i]["BraTS21ID"], train_df.iloc[i]["MGMT_value"]
visualize_sample(brats21id=_brats21id, mgmt_value=_mgmt_value, slice_i=0.5)

In [None]:
#This is from the furcifer notebook
def _dicom2array(path, voi_lut=True, fix_monochrome=True, resize=False):
    dicom = pydicom.read_file(path)
    if voi_lut:
        data = apply_voi_lut(dicom.pixel_array, dicom)
    else:
        data = dicom.pixel_array
    if fix_monochrome and dicom.PhotometricInterpretation == "MONOCHROME1":
        data = np.amax(data) - data
    #Normalize the data: subtract off the minimum, divide by the maximum, convert to 256 uint8
    data = data - np.min(data)
    data = data/np.max(data)
    data = (data * 255).astype(np.uint8)
    
    #Resize images to target value
    if resize:
        data = cv2.resize(data, (256, 256))
    return data

In [None]:
# Show one random slice from the middle half of each modality for a single patient.
patient = '00000'
modes = ['FLAIR', 'T1w', 'T1wCE', 'T2w']
chosen_images = []
for m in modes:
    image_names = good_image_names[patient][m]
    n_images = len(image_names)
    floor = int(n_images * 0.25)
    # For a single-image series int(0.25*n) == int(0.75*n), so the range was empty
    # and random.choice raised; always keep at least one candidate slice.
    ceiling = max(int(n_images * 0.75), floor + 1)
    i_slice = random.choice(range(floor, ceiling))

    chosen_image = image_names[i_slice]
    path = f'{train_path}/train/{patient}/{m}/{chosen_image}'
    chosen_images.append(_dicom2array(path))

plt.figure(figsize=(16, 5))

for i, m in enumerate(modes):
    plt.subplot(1, len(modes), i + 1)
    plt.imshow(chosen_images[i], cmap="gray")
    plt.title(f"{m}", fontsize=16)
    plt.axis("off")

plt.suptitle(f' Patient: {patient}', fontsize=16)
plt.show()

In [None]:
# Scan every patient/modality and report tags whose value is not constant across slices.
# mismatch_counts[m] counts (patient, tag) pairs with more than one distinct value.
mismatch_counts = dict.fromkeys(modes, 0)
for p, per_mode in good_image_names.items():
    print('Patient: ', p)
    for m in per_mode:
        for k, v in count_orientations(p, m).items():
            n_distinct = len(set(v))
            if n_distinct > 1:
                print('Mismatch:  mode %s setting %s has %i distinct values' % (m, k, n_distinct))
                mismatch_counts[m] += 1
        

In [None]:
def orientation_imager(patient, mode, good_images=good_image_names, key='Image Orientation (Patient)'):
    """Plot the middle slice of each distinct value of a DICOM tag for one patient/modality.

    Slices are grouped by the value of ``codes[key]``. Multi-valued tags are
    rounded to 4 decimals and converted to tuples. Each group gets one row,
    titled with its tag value.

    Args:
        patient: Zero-padded patient ID (folder name under ``train/``).
        mode: Modality sub-folder, e.g. 'T1w'.
        good_images: Mapping patient -> {modality: list of image file names}.
        key: Name of the tag in the global ``codes`` dict to group by.
    """
    # glob(...)[0] doubles as an existence check: a missing file raises IndexError.
    image_list = [glob.glob(f'{train_path}/train/{patient}/{mode}/{im}')[0]
                  for im in good_images[patient][mode]]

    # Group image paths by tag value; only the header is needed for this.
    orientation_groups = {}
    for im in image_list:
        dicom = pydicom.dcmread(im, stop_before_pixels=True)
        data = dicom[codes[key]].value
        if isinstance(data, (pydicom.multival.MultiValue, list)):
            data = tuple(np.array(data).round(decimals=4))
        orientation_groups.setdefault(data, []).append(im)

    n = len(orientation_groups)
    fig = plt.figure(figsize=(16, 5 * max(n, 1)))
    fig.suptitle(f'Patient: {patient}', fontsize=16)

    for i, (orientation, images) in enumerate(orientation_groups.items(), start=1):
        chosen_im = images[len(images) // 2]
        ax = fig.add_subplot(n, 1, i)
        ax.imshow(_dicom2array(chosen_im))
        ax.axis('off')
        ax.set_title(f'{orientation}')

    # Show once, after every panel exists. Calling plt.show() inside the loop
    # (as before) finalizes the figure after the first panel.
    plt.show()

In [None]:
orientation_imager('00018', 'T1w')

In [None]:
# Patients for whom at least one modality's DICOM files lack the "slice location"
# tag, so their slices could not be ordered the same way as everyone else's.
core_weirds = ('00109 00157 00170 00186 00353 00367 00414 00561 00563 00564 00565 '
               '00756 00834 00839').split()

Because we have a fixed data set & each patient can be expensive to analyze in detail, we have produced a JSON file which has, as keys, the patient IDs, and values are dictionaries, where each dictionary has the modalities (FLAIR, etc) as keys and lists of image names as the values.  The image names within each list are in order according to the metadata in the corresponding DICOM files, where we have used 'slice location', accessed via this command: pydicom.dcmread(im)[('0020', '1041')].value.  However, there is a handful of patients with DICOM images that do not contain this information.  We will probably return to them later, but for the moment they have been classified as "weird" and are excluded from our initial training & analysis.

Several of these helper functions came from or were inspired by this notebook: https://www.kaggle.com/furcifer/no-baseline-pytorch-cnn-for-mri?scriptVersionId=68186710


In [None]:
# Inspect one raw DICOM slice directly (edit the path below to look at other files)
import matplotlib.pyplot as plt
from pydicom import dcmread
from pydicom.data import get_testdata_file

sample_dcm_path = '../input/rsna-miccai-brain-tumor-radiogenomic-classification/train/00000/FLAIR/Image-119.dcm'
ds = dcmread(sample_dcm_path)
# Raw pixel values, no rescaling, in grayscale
plt.imshow(ds.pixel_array, cmap='gray')
plt.show()

If one looks at the raw images, one sees lots of blank space around the actual brain, which is not useful for classification. The circumscriber functions below find the edges of the brain and crop away the surrounding blank space. Note: this circumscribes the brain rather than eliminating *all* of the blank space; some blank space remains in the corners because PyTorch needs a cuboid input (like putting something in a box rather than shrink-wrapping it).

In [None]:
# Hyperparameters shared across the notebook (optimizer, schedule, Cutout, slicing, split)
HP_control = dict(
    OPTIMIZER='Adam', ILR=1e-3, SCHEDULER='StepLR', STEP_SIZE=10, GAMMA=0.1,
    HOLE_SIZE=10, HOLE_LIMIT=5, NUM_SLICES=17, VAL_SPLIT=0.25, MIN_HEIGHT_WIDTH=32,
)
NUM_SLICES, VAL_SPLIT, MIN_HEIGHT_WIDTH = (
    HP_control[name] for name in ('NUM_SLICES', 'VAL_SPLIT', 'MIN_HEIGHT_WIDTH')
)

In [None]:
# Some patients have only a handful of images for a modality, which makes that
# modality unusable for them. Collect, per modality, the patients whose ordered
# image list is shorter than NUM_SLICES + 2, then combine with core_weirds.
threshold = NUM_SLICES + 2
modes = ['FLAIR', 'T1w', 'T1wCE', 'T2w']
underpopulated = {
    mode: [pid for pid, per_mode in good_image_names.items() if len(per_mode[mode]) < threshold]
    for mode in modes
}

for mode in modes:
    print(f'{mode} underpopulated list: ', underpopulated[mode])
    print(f'{mode} underpopulated length: ', len(underpopulated[mode]))

final_exclusions = {mode: core_weirds + underpopulated[mode] for mode in modes}
print('final exclusions: ', final_exclusions)

In [None]:
#Eliminate as much blank space around a 2D slice as possible
def _circumscriber2D(img: np.array) -> np.array:
    vmin = 0
    hmin = 0
    vlimit, hlimit = img.shape
    
    for i in range(vlimit):
        if np.max(img[i, :]) == 0:
            vmin += 1
        else:
            break
    vmax = vmin + 1
    for i in range(vmin+1, vlimit):
        if np.max(img[i, :]) > 0:
            vmax += 1
        else:
            break

    for j in range(hlimit):
        if np.max(img[:, j]) == 0:
            hmin += 1
        else:
            break
    hmax = hmin + 1
    for j in range(hmin+1, hlimit):
        if np.max(img[:, j]) > 0:
            hmax += 1
        else:
            break
    return img[vmin: vmax, hmin:hmax]

In [None]:
#Eliminate as much blank space around a 3D cube as possible.
def _padded_span(start, stop, limit, min_size):
    """Widen [start, stop) to at least *min_size*, shifted to stay within [0, limit)."""
    size = max(stop - start, min_size)
    start = max(0, min(start, limit - size))
    return start, min(start + size, limit)


def _circumscriber3D(img: np.array, min_height=MIN_HEIGHT_WIDTH, min_width=MIN_HEIGHT_WIDTH) -> np.array:
    """Crop a 3-D volume to the in-plane bounding box of its non-blank voxels.

    Axes 0 and 1 are cropped; axis 2 (slices) is kept whole. An index along
    axis 0/1 counts as content when any voxel in that plane is > 0. The box
    spans from the first to the LAST content index, so interior blank planes no
    longer cut the crop short.

    Each extent is widened to at least *min_height* (axis 0) and *min_width*
    (axis 1) when the volume is large enough. If the widened window would run
    past the far edge, it is shifted back inside the volume rather than
    silently truncated.

    Args:
        img: 3-D array (in-plane, in-plane, slices).
        min_height: Minimum extent along axis 0.
        min_width: Minimum extent along axis 1.

    Returns:
        The cropped view of *img*, or *img* unchanged if it is empty or has no
        voxel > 0.
    """
    if img.size == 0:
        return img
    rows = np.flatnonzero(img.max(axis=(1, 2)) > 0)
    cols = np.flatnonzero(img.max(axis=(0, 2)) > 0)
    if rows.size == 0 or cols.size == 0:
        return img
    v0, v1 = _padded_span(rows[0], rows[-1] + 1, img.shape[0], min_height)
    h0, h1 = _padded_span(cols[0], cols[-1] + 1, img.shape[1], min_width)
    return img[v0:v1, h0:h1, :]

This is the Cutout data augmenter: when active, it will randomly zero out cubes of the brain image.  My guess is that, with the proper block size ('side' parameter), this will be a more effective regularizer than SliceSkip, but I could be mistaken.
Besides the size of the blocks, the number of dropped blocks will likely also matter a fair bit, but I'll add that functionality later.  I have changed the name from DropBlock to Cutout in keeping with the convention described in the DropBlock paper: https://arxiv.org/pdf/1810.12890.pdf
The full DropBlock regularizer applies this to all the feature maps produced by the convolutional layers, whereas what I'm doing here is making random changes to the input.

In [None]:
def _Cutout(img: np.array, side: int, hole_count: int) -> np.array:
    ycap = img.shape[0] - side
    xcap = img.shape[1] - side
    
    #Safety check
    if ycap <= 5 or xcap <= 5:
        return img
    
    nholes = random.randint(1, hole_count)
    ycorners = random.sample(range(0, ycap), nholes)
    xcorners = random.sample(range(0, xcap), nholes)
    
    if len(img.shape) == 3:
        dcap = img.shape[2] - min(side, img.shape[2])
        if dcap == 0:
            dcorners = [0] * nholes
        else:
            dcorners = random.sample(range(0, dcap), nholes)
    
    for i in range(nholes):
        if len(img.shape) == 3:
            img[ycorners[i]:ycorners[i]+side, xcorners[i]:xcorners[i]+side, dcorners[i]:dcorners[i]+side] = 0
        else:
            img[ycorners[i]:ycorners[i]+side, xcorners[i]:xcorners[i]+side] = 0
    
    return img

In [None]:
#This was adapted from the furcifer notebook
def plot_imgs(imgs, cols=4, size=7, is_rgb=True, title="", cmap='gray', img_size=(512,512), max_imgs=4):
    """Show the first slices of a 3-D volume in a grid.

    Slices are taken along the LAST axis (``imgs[:, :, i]``). At most
    *max_imgs* slices are shown (default 4, matching the previous behavior);
    fewer are shown if the volume is shallower instead of raising IndexError.
    The grid height is now derived from the number of plotted slices. It was
    previously derived from ``len(imgs)`` (the first axis), which produced a
    figure dozens of rows tall.

    Args:
        imgs: 3-D array indexed as (rows, cols, slice).
        cols: Number of grid columns.
        size: Size in inches of each grid cell.
        is_rgb: Unused; kept for backward compatibility.
        title: Figure super-title.
        cmap: Matplotlib colormap for each slice.
        img_size: Unused; kept for backward compatibility.
        max_imgs: Maximum number of slices to show.
    """
    n = min(imgs.shape[2], max_imgs)
    rows = max(1, -(-n // cols))  # ceil(n / cols), at least one row
    fig = plt.figure(figsize=(cols * size, rows * size))
    for i in range(n):
        fig.add_subplot(rows, cols, i + 1)
        plt.imshow(imgs[:, :, i], cmap=cmap)
    plt.suptitle(title)
    plt.show()

In [None]:
# Per-split preprocessing. Both splits currently just convert the uint8 numpy
# volume to a float tensor with ToTensor (last axis becomes the channel axis,
# values scaled to [0, 1]). Spatial augmentations (RandomRotation,
# RandomHorizontalFlip) and Normalize are left out: it is unclear whether they
# suit these 3-D volumes, and the images are already normalized upstream.
chosen_transforms = {
    split_name: transforms.Compose([transforms.ToTensor()])
    for split_name in ('train', 'val')
}

In [None]:
#This was inspired by load_rand_dicom_images in the furcifer notebook
def load_FULL_brain(scan_id, split = 'train', modality='FLAIR', image_names=None, sliceSkip=0.0, cutOut=False):
    """
    Load every listed slice of one modality, in order, as a single 3D np array.

    Invalid *split* / *modality* / *sliceSkip* values print a message and fall
    back to 'train' / 'FLAIR' / 0. Each slice is dropped independently with
    probability *sliceSkip*. All-black slices are discarded, the volume is
    transposed and cropped with _circumscriber3D, and optionally Cutout is
    applied.

    Returns None (after printing a hint) when *image_names* is empty or None.
    """
    if split not in ("train", "test"):
        print('Please request a valid split: train or test.  Defaulting to train.')
        split = "train"

    if modality not in ('FLAIR', 'T1w', 'T1wCE', 'T2w'):
        print('Please select an appropriate modality: FLAIR, T1w, T1wCE, or T2w')
        print('Defaulting to FLAIR')
        modality = 'FLAIR'

    if sliceSkip >= 1:
        sliceSkip = 0
        print('Please choose a valid sliceSkip number, from 0 to 1 inclusive/exclusive')

    if not image_names:
        print('Please input the valid image names in the proper order')
        return None

    kept_paths = []
    for name in image_names:
        # One draw per candidate slice; the slice is skipped with probability sliceSkip.
        if random.random() < sliceSkip:
            continue
        kept_paths.append(glob.glob(f'{train_path}/{split}/{scan_id}/{modality}/{name}')[0])

    slices = [_dicom2array(p) for p in kept_paths]
    volume = np.array([s for s in slices if np.max(s) > 0]).T
    volume = _circumscriber3D(volume)
    # Only use Cutout data augmentation if explicitly requested
    if cutOut:
        volume = _Cutout(volume, HP_control['HOLE_SIZE'], HP_control['HOLE_LIMIT'])
    return volume

In [None]:
class MultiSliceLoader(Dataset):
    """Dataset yielding one MRI modality of one patient as a single-channel volume.

    Patients are discovered from the sub-folders of ``<path>/<train|test>``,
    filtered by *exclude*, and, for 'train'/'val', split by sorted ID order:
    the first ``1 - val_split`` fraction is 'train' and the rest is 'val'.

    Args:
        label_file: CSV (relative to *path*) with BraTS21ID and MGMT_value columns.
        path: Data root containing the train/ and test/ folders.
        modality: Modality sub-folder to load (e.g. 'FLAIR', 'T2w').
        num_slices: Stored for API compatibility; __getitem__ currently uses
            every other listed slice instead.
        split: 'train' or 'val' (both read train/), or 'test' (reads test/).
        good_images: Mapping patient ID -> {modality: ordered list of image file names}.
        exclude: Iterable of patient IDs to drop; None excludes nothing.
        val_split: Fraction of the non-excluded patients held out as 'val'.
        transform: Callable applied to the cropped numpy volume; it must return
            a torch tensor (e.g. torchvision ``transforms.ToTensor()``).
        cutout: If True, apply _Cutout augmentation to each volume.
    """

    def __init__(self, label_file, path, modality, num_slices=1, split='train', good_images=None, exclude=None, 
                 val_split=0.25, transform=None, cutout=False):
        train_data = pd.read_csv(os.path.join(path, label_file))
        # Folder names are zero-padded 5-digit IDs; pad the CSV IDs to match.
        self.labels = {str(b).zfill(5): m
                       for b, m in zip(train_data['BraTS21ID'], train_data['MGMT_value'])}
        self.path = path
        self.split = split
        self.modality = modality
        self.num_slices = num_slices
        self.good_images = good_images
        self.exclude = exclude
        self.transform = transform
        self.cutout = cutout
        self.splitdir = 'test' if self.split == 'test' else 'train'

        # exclude=None previously raised TypeError on the membership test.
        excluded = set(exclude) if exclude else set()
        patient_dirs = sorted(glob.glob(os.path.join(path, self.splitdir, '*')))
        self.ids = [pid for pid in (os.path.basename(d) for d in patient_dirs) if pid not in excluded]
        stop = int(len(self.ids) * (1 - val_split))
        if split == 'train':
            self.ids = self.ids[:stop]
        elif split == 'val':
            self.ids = self.ids[stop:]

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        """Return ``(volume, label)`` (or just ``volume`` for 'test'); None if no images are listed."""
        p_id = self.ids[idx]
        mode = self.modality
        patient_images = self.good_images[p_id][mode]

        if len(patient_images) == 0:
            print('There doesnt seem to be anything here! Lets go back and tell Master Luke...')
            return None

        # Take every other slice, starting at a random phase (0 or 1): this halves
        # the depth and varies which slices are used. A single-slice series must
        # start at 0, or the selection would be empty and np.stack would fail.
        first = random.choice([0, 1]) if len(patient_images) > 1 else 0
        targets = patient_images[first::2]
        target_paths = [f'{self.path}/{self.splitdir}/{p_id}/{mode}/{t}' for t in targets]
        image = np.stack([_dicom2array(p) for p in target_paths]).T

        image = _circumscriber3D(image)

        if self.cutout:
            image = _Cutout(image, HP_control['HOLE_SIZE'], HP_control['HOLE_LIMIT'])

        image = self.transform(image)
        # Add a leading singleton dim, then resize the trailing spatial dims to 128.
        image = F.interpolate(image.unsqueeze(0), 128)
        # image is already a tensor: torch.tensor(tensor) would copy it and warn.
        image = image.to(torch.float32)

        if self.split != 'test':
            label = torch.tensor(self.labels[p_id], dtype=torch.long)
            return image, label
        return image

In [None]:
class MultiModeLoader(Dataset):
    """Dataset yielding every listed modality of one patient as a dict of volumes.

    Patients are discovered from the sub-folders of ``<path>/<train|test>``,
    filtered by *exclude*, and, for 'train'/'val', split by sorted ID order:
    the first ``1 - val_split`` fraction is 'train' and the rest is 'val'.

    Args:
        label_file: CSV (relative to *path*) with BraTS21ID and MGMT_value columns.
        path: Data root containing the train/ and test/ folders.
        num_slices: Stored for API compatibility; __getitem__ currently uses
            every other listed slice instead.
        split: 'train' or 'val' (both read train/), or 'test' (reads test/).
        good_images: Mapping patient ID -> {modality: ordered list of image file names}.
        exclude: Iterable of patient IDs to drop; None excludes nothing.
        val_split: Fraction of the non-excluded patients held out as 'val'.
        transform: Callable applied to each cropped numpy volume; it must return
            a torch tensor (e.g. torchvision ``transforms.ToTensor()``).
        cutout: If True, apply _Cutout augmentation to each volume.
    """

    def __init__(self, label_file, path, num_slices=1, split='train', good_images=None, exclude=None, 
                 val_split=0.25, transform=None, cutout=False):
        train_data = pd.read_csv(os.path.join(path, label_file))
        # Folder names are zero-padded 5-digit IDs; pad the CSV IDs to match.
        self.labels = {str(b).zfill(5): m
                       for b, m in zip(train_data['BraTS21ID'], train_data['MGMT_value'])}
        self.path = path
        self.split = split
        self.num_slices = num_slices
        self.good_images = good_images
        self.exclude = exclude
        self.transform = transform
        self.cutout = cutout
        self.splitdir = 'test' if self.split == 'test' else 'train'

        # exclude=None previously raised TypeError on the membership test.
        excluded = set(exclude) if exclude else set()
        patient_dirs = sorted(glob.glob(os.path.join(path, self.splitdir, '*')))
        self.ids = [pid for pid in (os.path.basename(d) for d in patient_dirs) if pid not in excluded]
        stop = int(len(self.ids) * (1 - val_split))
        if split == 'train':
            self.ids = self.ids[:stop]
        elif split == 'val':
            self.ids = self.ids[stop:]

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        """Return ``({modality: volume}, label)`` (dict only for 'test'); None if any modality is empty."""
        p_id = self.ids[idx]
        output = {}
        for mode, image_names in self.good_images[p_id].items():
            if len(image_names) == 0:
                print('There doesnt seem to be anything here! Lets go back and tell Master Luke...')
                return None

            # Every other slice from a random phase (0 or 1); a single-slice series
            # must start at 0, or the selection would be empty.
            first = random.choice([0, 1]) if len(image_names) > 1 else 0
            targets = image_names[first::2]
            target_paths = [f'{self.path}/{self.splitdir}/{p_id}/{mode}/{t}' for t in targets]
            image = np.stack([_dicom2array(p) for p in target_paths]).T

            image = _circumscriber3D(image)

            if self.cutout:
                image = _Cutout(image, HP_control['HOLE_SIZE'], HP_control['HOLE_LIMIT'])

            image = self.transform(image)
            # Add a leading singleton dim, then resize the trailing spatial dims to 64.
            image = F.interpolate(image.unsqueeze(0), 64)
            # image is already a tensor: torch.tensor(tensor) would copy it and warn.
            output[mode] = image.to(torch.float32)

        if self.split != 'test':
            label = torch.tensor(self.labels[p_id], dtype=torch.long)
            return output, label
        return output
        

In [None]:
# Smoke-test MultiModeLoader: build train/val loaders and print the tensor shape
# of every modality in the first batch of each.
train_bs = 1
val_bs = 1
# Union of the per-modality exclusion lists: a patient excluded for any modality is dropped.
exclusion = list({pid for excluded in final_exclusions.values() for pid in excluded})
train_dataset = MultiModeLoader('train_labels.csv', train_path, num_slices=NUM_SLICES, split='train',
                                good_images=good_image_names, exclude=exclusion, val_split=VAL_SPLIT,
                                transform=chosen_transforms['train'])
val_dataset = MultiModeLoader('train_labels.csv', train_path, num_slices=NUM_SLICES, split='val',
                              good_images=good_image_names, exclude=exclusion, val_split=VAL_SPLIT,
                              transform=chosen_transforms['val'])
train_loader = DataLoader(train_dataset, batch_size=train_bs, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=val_bs, shuffle=False)

dataloaders = {'train': train_loader, 'val': val_loader}


def _print_first_multimode_batch(loader, banner):
    """Print *banner*, then each modality key and tensor shape of the loader's first batch."""
    for batch_imgs, _batch_label in loader:
        print(banner)
        for modality, tensor in batch_imgs.items():
            print('Key: ', modality)
            print('Shape: ', tensor.shape)
        break


_print_first_multimode_batch(train_loader, '------------Iteration--------------')
_print_first_multimode_batch(val_loader, '-----------Iteration---------------')

In [None]:
#Let's see if MultiSliceLoader works...
mode = 'T2w'
train_bs = val_bs = 1
train_dataset = MultiSliceLoader('train_labels.csv', train_path, modality=mode, num_slices=NUM_SLICES, split='train',
                                 good_images=good_image_names, exclude=final_exclusions[mode], val_split=VAL_SPLIT,
                                 transform=chosen_transforms['train'])
val_dataset = MultiSliceLoader('train_labels.csv', train_path, modality=mode, num_slices=NUM_SLICES, split='val',
                              good_images=good_image_names, exclude=final_exclusions[mode], val_split=VAL_SPLIT,
                              transform=chosen_transforms['val'])
train_loader = DataLoader(train_dataset, batch_size=train_bs, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=val_bs, shuffle=False)

dataloaders = {'train': train_loader, 'val': val_loader}

# For the first batch of each split: image shape, min/mean/max intensity, label shape.
for loader in (train_loader, val_loader):
    for img, label in loader:
        print('---------------Iteration---------------------')
        print(img.shape)
        for stat in (img.min(), img.mean(), img.max()):
            print(stat)
        print(label.shape)
        break

In [None]:
# Report how many images passed the quality filter, and how many of them the
# `underpopulated` lookup lists for the current `mode`.
print('len(good_image_names): ', len(good_image_names))
print('len(underpopulated[mode]): ', len(underpopulated[mode]))

In [None]:
class FeatureFinder(nn.Module):
    """Single-modality 3D feature extractor used as a branch of MultiModeModel.

    Pipeline: 7x7x7 conv -> ReLU -> BatchNorm -> 2x max-pool -> adaptive
    average pool to 8x8x8.

    Args:
        init_features: number of output channels of the convolution.

    Input has a single channel (``in_channels=1``), i.e. (batch, 1, D, H, W);
    the output is (batch, init_features, 8, 8, 8) regardless of D, H, W.
    """

    def __init__(self, init_features=64):
        super(FeatureFinder, self).__init__()

        features = init_features
        # kernel 7 with padding 3 keeps D, H and W unchanged
        self.conv1 = nn.Conv3d(in_channels=1, out_channels=features, kernel_size=7, padding=3)
        self.bnorm = nn.BatchNorm3d(features)
        self.maxpool = nn.MaxPool3d(kernel_size=2, stride=(2,2,2), padding=1)
        self.avgpool = nn.AdaptiveAvgPool3d(8)

    def forward(self, x):
        # Move the input to wherever this module's parameters live instead of a
        # notebook-global `device`: the module then works before that global is
        # defined and follows any later `.to(...)`.
        x = x.to(self.conv1.weight.device)
        x = F.relu(self.conv1(x))
        x = self.bnorm(x)
        x = self.maxpool(x)
        x = self.avgpool(x)

        return x
        

In [None]:
class MultiModeModel(nn.Module):
    """Fuse four per-modality feature extractors into one logit per sample.

    Each modality in the input dict goes through its own sub-model. The
    results are concatenated along dim 1, flattened and passed through a
    two-layer MLP (2**17 -> 256 -> 1).

    Args:
        FLAIR, T1w, T1wCE, T2w: per-modality sub-models. ``linear1`` expects
            2**17 flattened features per sample after concatenation, which is
            what four default ``FeatureFinder`` outputs give (4 * 64 * 8**3).
    """

    def __init__(self, FLAIR, T1w, T1wCE, T2w):
        super(MultiModeModel, self).__init__()

        self.FLAIR = FLAIR
        self.T1w = T1w
        self.T1wCE = T1wCE
        self.T2w = T2w

        self.flat = nn.Flatten()
        self.linear1 = nn.Linear(2**17, 256)
        self.linear2 = nn.Linear(256, 1)

    def forward(self, x):
        """Return logits of shape (batch,) for a dict mapping modality -> input batch.

        Keys other than 'FLAIR', 'T1w' and 'T1wCE' are routed to the T2w
        sub-model. Features are concatenated in the dict's iteration order.
        """
        branches = {'FLAIR': self.FLAIR, 'T1w': self.T1w, 'T1wCE': self.T1wCE}
        temp = [branches.get(k, self.T2w)(v) for k, v in x.items()]
        T = torch.cat(tuple(temp), dim=1)
        T = self.flat(T)
        T = F.relu(self.linear1(T))
        T = self.linear2(T)
        # Remove only the size-1 logit dimension, never the batch dimension, so
        # the output is (batch,) for every batch size. That matches the (batch,)
        # labels BCEWithLogitsLoss is given; a blanket squeeze() before the MLP
        # gave (batch, 1) here whenever batch > 1.
        return T.squeeze(1)

In [None]:
class SingleModeModel(nn.Module):
    """Plain 3D CNN producing one logit per sample for a single modality.

    Four conv layers with ReLU, separated by max-pools with stride (1, 2, 2)
    (depth shrinks by one slice, height and width roughly halve). These are
    followed by an adaptive average pool to 4x4x4 and a linear head.

    Args:
        init_features: channels of the first conv; later convs use 2x, 4x, 8x.

    Input has a single channel (``in_channels=1``), i.e. (batch, 1, D, H, W);
    the output is (batch,).
    """

    def __init__(self, init_features=32):
        super(SingleModeModel, self).__init__()

        features = init_features
        self.conv1 = nn.Conv3d(in_channels=1, out_channels=features, kernel_size=7, padding=2)
        self.conv2 = nn.Conv3d(in_channels=features, out_channels=features * 2, kernel_size=5, padding=1)
        self.conv3 = nn.Conv3d(in_channels=features * 2, out_channels=features * 4, kernel_size=3, padding=1)
        self.conv4 = nn.Conv3d(in_channels=features * 4, out_channels=features * 8, kernel_size=3, padding=1)

        self.maxpool = nn.MaxPool3d(kernel_size=2, stride=(1,2,2))

        AVGPOOL = (4,4,4)
        # Fixed-size pooling makes linear1's input size independent of the volume size.
        self.avgpool = nn.AdaptiveAvgPool3d(AVGPOOL)
        self.flat = nn.Flatten()
        self.linear1 = nn.Linear(features * 8 * AVGPOOL[0] * AVGPOOL[1] * AVGPOOL[2], 1)

    def forward(self, x):
        # Follow this module's own parameters rather than a notebook-global `device`.
        x = x.to(self.conv1.weight.device)
        x = self.maxpool(F.relu(self.conv1(x)))
        x = self.maxpool(F.relu(self.conv2(x)))
        x = self.maxpool(F.relu(self.conv3(x)))
        x = F.relu(self.conv4(x))
        x = self.flat(self.avgpool(x))
        x = self.linear1(x)
        # Drop only the size-1 logit dimension: the output stays (batch,) for
        # batch == 1 and batch > 1 alike, matching the (batch,) labels.
        return x.squeeze(1)
        

In [None]:
# Smoke-test SingleModeModel's forward pass on one random single-channel volume.
SMM, test = SingleModeModel(), torch.randn(1, 1, 100, 150, 200)
SMM(test)

In [None]:
#Let's build a proper 3D CNN.  This is the encoder half of the U-Net used in a previous competition for identifying anomalous regions
#in MRI brain scans, expanded into 3D.  References: https://arxiv.org/abs/1505.04597,
#https://github.com/mateuszbuda/brain-segmentation-pytorch/blob/master/unet.py

class Encoder(nn.Module):
    """3D U-Net encoder (four double-conv blocks) followed by a linear head.

    Blocks are separated by max-pools with stride (1, 2, 2): depth shrinks by
    one slice while height and width roughly halve. An adaptive average pool
    to 4x4x4 fixes the head's input size regardless of the volume size.

    Args:
        in_channels: channels of the input volume.
        out_channels: number of logits the head produces per sample.
        init_features: channels of the first block; later blocks use 2x, 4x, 8x.

    Input is (batch, in_channels, D, H, W). The output is (batch,) when
    out_channels == 1, otherwise (batch, out_channels).
    """

    def __init__(self, in_channels=1, out_channels=1, init_features=32):
        super(Encoder, self).__init__()

        features = init_features
        self.encoder1 = Encoder._block(in_channels, features, name='enc1')
        self.encoder2 = Encoder._block(features, features * 2, name='enc2')
        self.encoder3 = Encoder._block(features * 2, features * 4, name='enc3')
        self.encoder4 = Encoder._block(features * 4, features * 8, name='enc4')

        self.maxpool = nn.MaxPool3d(kernel_size=2, stride=(1,2,2))

        #Now that features have been encoded, let's use an average pooling layer
        AVG_SIZE = (4,4,4)
        self.avg = nn.AdaptiveAvgPool3d(AVG_SIZE)
        self.flat = nn.Flatten()
        self.linear1 = nn.Linear(features * 8 * AVG_SIZE[0] * AVG_SIZE[1] * AVG_SIZE[2], out_channels)

    def forward(self, x):
        # Follow this module's own parameters rather than a notebook-global `device`.
        x = x.to(self.linear1.weight.device)
        x = self.maxpool(self.encoder1(x))
        x = self.maxpool(self.encoder2(x))
        x = self.maxpool(self.encoder3(x))
        x = self.encoder4(x)
        x = self.flat(self.avg(x))
        x = self.linear1(x)
        # With a single logit, drop only that dimension so the batch dimension
        # survives for batch == 1 and the output matches (batch,) labels.
        if x.size(1) == 1:
            x = x.squeeze(1)
        return x

    @staticmethod
    def _block(in_channels, features, name):
        """Return (3x3x3 conv -> BatchNorm -> ReLU) x 2; padding=1 keeps D, H, W."""
        return nn.Sequential(
            OrderedDict([
                         (
                             name + 'conv1',
                             nn.Conv3d(in_channels=in_channels,
                                       out_channels=features,
                                       kernel_size=3,
                                       padding=1,
                                       bias=False)
                         ),
                         (name + 'norm1', nn.BatchNorm3d(num_features=features)),
                         (name + 'relu1', nn.ReLU(inplace=True)),
                         (
                             name + 'conv2',
                             nn.Conv3d(in_channels=features,
                                       out_channels=features,
                                       kernel_size=3,
                                       padding=1,
                                       bias=False)
                         ),
                         (name + 'norm2', nn.BatchNorm3d(num_features=features)),
                         (name + 'relu2', nn.ReLU(inplace=True))
            ])
        )


In [None]:
# Shape check for Encoder(): uncomment the call below (and the debug prints in
# Encoder.forward, if present) to see the tensor size after every stage.
enc_model, test_input = Encoder(), torch.randn(1, 1, 100, 150, 200)
#enc_model(test_input)

In [None]:
class MultiEncoder(nn.Module):
    """Late-fusion model: one score per modality, combined by a linear layer.

    Args:
        FLAIR_model, T1w_model, T1wCE_model, T2w_model: per-modality models,
            each expected to return one value per sample (e.g. ``Encoder``).
            ``classifier`` is ``nn.Linear(4, 1)``, so the four outputs together
            must give 4 features per sample.
    """

    def __init__(self, FLAIR_model, T1w_model, T1wCE_model, T2w_model):
        super(MultiEncoder, self).__init__()

        self.FLAIR_model = FLAIR_model
        self.T1w_model = T1w_model
        self.T1wCE_model = T1wCE_model
        self.T2w_model = T2w_model
        self.classifier = nn.Linear(4, 1)

    def forward(self, x):
        """Return fused logits of shape (batch,) for a dict of modality -> input batch.

        Keys other than 'FLAIR', 'T1w' and 'T1wCE' are routed to T2w_model.
        """
        branches = {'FLAIR': self.FLAIR_model, 'T1w': self.T1w_model, 'T1wCE': self.T1wCE_model}
        temp = []
        for k, v in x.items():
            t = branches.get(k, self.T2w_model)(v)
            # Sub-models may return (batch,), (batch, 1) or - after a blanket
            # squeeze with batch 1 - just (1,). Normalise to (batch, n) so the
            # concatenation along dim 1 below is always valid.
            temp.append(t.reshape(v.shape[0], -1))
        T = torch.cat(tuple(temp), dim=1)
        X = self.classifier(F.relu(T))
        # One logit per sample, matching the (batch,) labels.
        return X.squeeze(1)

In [None]:
FLAIR_model = Encoder()
T1w_model = Encoder()
T1wCE_model = Encoder()
T2w_model = Encoder()
multi_model = MultiEncoder(FLAIR_model, T1w_model, T1wCE_model, T2w_model)
# Print the shapes of the first training batch. When the notebook runs top to
# bottom, train_loader at this point comes from the MultiSliceLoader test cell,
# whose batches are single tensors rather than {modality: tensor} dicts, so
# handle both kinds instead of failing on .items().
for imgs, label in train_loader:
    print('------------Iteration--------------')
    if isinstance(imgs, dict):
        for k, v in imgs.items():
            print('Key: ', k)
            print('Shape: ', v.shape)
    else:
        print('Shape: ', imgs.shape)
    #print(multi_model(imgs))
    break

In [None]:
model_type = 'multi'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#device = xm.xla_device()
#torch.set_default_tensor_type('torch.FloatTensor')

# Rebuild the loaders for the chosen model type. The test cells above rebind
# train_loader / val_loader / dataloaders, and run top to bottom the last one
# leaves the single-modality MultiSliceLoader loaders in place. train_model reads
# the global `dataloaders`, so a 'multi' model would otherwise receive plain
# tensors instead of {modality: tensor} dicts.
if model_type == 'single':
    train_dataset = MultiSliceLoader('train_labels.csv', train_path, modality=mode, num_slices=NUM_SLICES, split='train',
                                     good_images=good_image_names, exclude=final_exclusions[mode], val_split=VAL_SPLIT,
                                     transform=chosen_transforms['train'])
    val_dataset = MultiSliceLoader('train_labels.csv', train_path, modality=mode, num_slices=NUM_SLICES, split='val',
                                   good_images=good_image_names, exclude=final_exclusions[mode], val_split=VAL_SPLIT,
                                   transform=chosen_transforms['val'])
    #model = Encoder()
    model = SingleModeModel()
elif model_type == 'multi':
    train_dataset = MultiModeLoader('train_labels.csv', train_path, num_slices=NUM_SLICES, split='train',
                                    good_images=good_image_names, exclude=exclusion, val_split=VAL_SPLIT,
                                    transform=chosen_transforms['train'])
    val_dataset = MultiModeLoader('train_labels.csv', train_path, num_slices=NUM_SLICES, split='val',
                                  good_images=good_image_names, exclude=exclusion, val_split=VAL_SPLIT,
                                  transform=chosen_transforms['val'])
    #FLAIR = Encoder()
    #T1w = Encoder()
    #T1wCE = Encoder()
    #T2w = Encoder()
    #model = MultiEncoder(FLAIR, T1w, T1wCE, T2w)
    FLAIR = FeatureFinder()
    T1w = FeatureFinder()
    T1wCE = FeatureFinder()
    T2w = FeatureFinder()
    model = MultiModeModel(FLAIR, T1w, T1wCE, T2w)
else:
    raise ValueError("model_type must be 'single' or 'multi', got {!r}".format(model_type))

train_loader = DataLoader(train_dataset, batch_size=train_bs, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=val_bs, shuffle=False)
dataloaders = {'train': train_loader, 'val': val_loader}

# Use the real dataset lengths: train_model divides the summed loss and the
# number of correct predictions by these, so an estimate skews both metrics.
dataset_sizes = {phase: len(loader.dataset) for phase, loader in dataloaders.items()}
train_size, val_size = dataset_sizes['train'], dataset_sizes['val']
print('dataset_sizes: ', dataset_sizes)
criterion = nn.BCEWithLogitsLoss()
#According to this page, the model should be sent to the device before setting the optimizer
#https://pytorch.org/docs/stable/optim.html

#Send model to device and set a default optimizer
model = model.to(device)
optimizer_ft = optim.Adam(model.parameters(), lr=1e-4)

#If one is selected, overwrite it
if HP_control['OPTIMIZER'] == 'Adam':
    optimizer_ft = optim.Adam(model.parameters(), lr=HP_control['ILR'])

lr_scheduler = None
if HP_control['SCHEDULER'] == 'StepLR':
    # Decay LR by a factor of GAMMA every STEP_SIZE epochs
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=HP_control['STEP_SIZE'], gamma=HP_control['GAMMA'])
elif HP_control['SCHEDULER'] == 'ExponentialLR':
    # Decays the learning rate of each parameter group by gamma every epoch. When last_epoch=-1, sets initial lr as lr.
    lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer_ft, gamma=HP_control['GAMMA'])

In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs, device='cpu',
                loaders=None, sizes=None):
    """Train a one-logit binary classifier and keep the best validation weights.

    Each epoch runs a 'train' phase (gradients on, optimizer steps) and then a
    'val' phase (gradients off). Loss and accuracy are printed per phase.

    Args:
        model: module mapping a batch to raw logits, one per sample (the models
            in this notebook are trained with ``nn.BCEWithLogitsLoss``).
        criterion: loss, called as ``criterion(output, label.float())``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler stepped once after every 'train' phase, or None.
        num_epochs: number of train/val epochs.
        device: device the labels are moved to; the inputs are passed to the
            model unchanged.
        loaders: ``{'train': iterable, 'val': iterable}`` of ``(img, label)``
            batches; defaults to the global ``dataloaders``.
        sizes: ``{'train': int, 'val': int}`` sample counts used to average loss
            and accuracy; defaults to the global ``dataset_sizes``.

    Returns:
        ``model`` with the weights from the epoch with the highest validation
        accuracy loaded. Exception: if any phase's accuracy drops below 0.25,
        training stops at once and the model is returned with its current
        weights (so that its predictions can be inverted).
    """
    if loaders is None:
        loaders = dataloaders
    if sizes is None:
        sizes = dataset_sizes

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    print('Initiating training loop...')
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0
            for img, label in loaders[phase]:
                label = label.float().to(device)

                #Zero the parameter gradients
                optimizer.zero_grad()

                #Forward pass, track history only in train mode
                with torch.set_grad_enabled(phase == 'train'):
                    output = model(img)
                    loss = criterion(output, label)

                    #Backward pass, optimize only in training mode
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # The output is a logit, so the predicted class is 1 exactly when
                # logit > 0 (sigmoid > 0.5). Rounding the raw logit instead yields
                # arbitrary integers that rarely equal a 0/1 label.
                pred = (output.detach() > 0).float()
                # Flatten both sides so (B,) vs (B, 1) cannot broadcast to (B, B).
                running_corrects += (pred.reshape(-1) == label.reshape(-1)).sum().item()
                # criterion averages over the batch; weight by batch size so the
                # division by the sample count below gives a per-sample mean.
                running_loss += loss.item() * label.size(0)

            if phase == 'train' and scheduler is not None:
                scheduler.step()

            epoch_loss = running_loss / sizes[phase]
            epoch_acc = running_corrects / sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
            phase, epoch_loss, epoch_acc))

            #Eject if accuracy super-low --> flip to get super-high accuracy
            if epoch_acc < 0.25:
                print('Accuracy absymal, ejecting and saving model for flip')
                return model

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
    

In [None]:
# Train the model configured above for 25 epochs; train_model hands it back with
# the best validation-accuracy weights loaded (unless it ejects early).
model = train_model(model, criterion, optimizer_ft, lr_scheduler, device=device, num_epochs=25)

In [None]:
class ReverseModel(nn.Module):
    """Wrap a binary classifier so that its predictions are inverted.

    The classifiers in this notebook output raw logits (they are trained with
    ``nn.BCEWithLogitsLoss``), and a logit is inverted by negation:
    ``sigmoid(-z) == 1 - sigmoid(z)``. Computing ``1 - z`` on a logit would
    only move the decision threshold from 0 to 1.

    Args:
        BaseModel: the wrapped model.
        logits: True (default) when ``BaseModel`` outputs logits. Pass False
            for a model that already outputs probabilities to get ``1 - p``.
    """

    def __init__(self, BaseModel, logits=True):
        super(ReverseModel, self).__init__()

        self.base = BaseModel
        self.logits = logits

    def forward(self, x):
        x = self.base(x)
        return -x if self.logits else 1 - x

In [None]:
# let's write our simplest cnn, then we can add variations or different models for improvement in results

class SimpleCNN(nn.Module):
    """Small 2D CNN: four conv/max-pool stages and a two-layer classifier head.

    Takes 4-channel images and returns 2 logits per sample. ``fc1`` expects
    8 channels * 5 * 5 = 200 features, which is what e.g. a 512x512 input
    leaves after the pooling stages (factors 4, 4, 2, 2).
    """

    def __init__(self):
        super().__init__()
        # Stage 1 has no dropout; stages 2-4 apply channel dropout after the conv.
        self.conv1 = nn.Conv2d(4, 64, kernel_size=9)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=7)
        self.conv2_drop = nn.Dropout2d()
        self.conv3 = nn.Conv2d(32, 16, kernel_size=5)
        self.conv3_drop = nn.Dropout2d()
        self.conv4 = nn.Conv2d(16, 8, kernel_size=3)
        self.conv4_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(200, 256)
        self.fc2 = nn.Linear(256, 2)

    def forward(self, x):
        # (conv, optional dropout, max-pool kernel) per stage
        stages = (
            (self.conv1, None, 4),
            (self.conv2, self.conv2_drop, 4),
            (self.conv3, self.conv3_drop, 2),
            (self.conv4, self.conv4_drop, 2),
        )
        for conv, drop, pool_size in stages:
            h = conv(x)
            if drop is not None:
                h = drop(h)
            x = F.relu(F.max_pool2d(h, pool_size))
        flat = x.view(x.shape[0], -1)
        hidden = F.dropout(F.relu(self.fc1(flat)), training=self.training)
        return self.fc2(hidden)

In [None]:
# Fresh SimpleCNN with its own loss and Adam optimizer.
criterion = nn.CrossEntropyLoss()
model = SimpleCNN()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
n_epochs = 1

In [None]:
# Show the architecture, then run a forward pass on one random 4-channel
# 512x512 image; as the cell's last expression, its output is displayed.
print(model)
model(torch.randn(1, 4, 512, 512))

In [None]:
# The global `lr_scheduler` is bound to `optimizer_ft` (the previous model's
# optimizer): passing it here would decay that optimizer's LR and never this
# one's. Build the configured scheduler for `optimizer` instead.
simple_scheduler = None
if HP_control['SCHEDULER'] == 'StepLR':
    simple_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=HP_control['STEP_SIZE'], gamma=HP_control['GAMMA'])
elif HP_control['SCHEDULER'] == 'ExponentialLR':
    simple_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=HP_control['GAMMA'])
# NOTE(review): train_model feeds `label.float()` to the criterion and thresholds
# a single logit, while SimpleCNN emits 2 logits for CrossEntropyLoss and expects
# 4-channel 2D images; confirm the loaders/loop suit this model before relying
# on this run.
simpleCNN = train_model(model, criterion, optimizer, simple_scheduler, num_epochs=25, device=device)