# RSNA-MICCAI Brain Tumor Radiogenomic Classification - **An approach with PyTorch EfficientNet 3D**

## **Problem Description**:

There are structural multi-parametric MRI (mpMRI) scans for different subjects, in DICOM format. The exact mpMRI scans included are:

* Fluid Attenuated Inversion Recovery (FLAIR)
* T1-weighted pre-contrast (T1w)
* T1-weighted post-contrast (T1Gd)
* T2-weighted (T2)

`train_labels.csv` - this file contains the target **MGMT_value** for each subject in the training data **(i.e. the presence of MGMT promoter methylation)**.

So, it's a binary classification problem.

## **An EfficientNet3D solution**:

* For each patient, we consider 4 sequences (FLAIR, T1w, T1Gd, T2), and for each of those sequences we take the 64 middle slices. Each slice is resized to shape (256, 256).

* Construct an EfficientNet-3D in PyTorch with input shape (256, 256, 256) (the four sequences concatenated along the slice axis) or (4, 256, 256, 64) (one channel per sequence).

* Perform binary classification.


### ⚡ **Inference kernel:** https://www.kaggle.com/furcifer/torch-effnet3d-for-mri-no-inference/


### **Importing libraries**

In [None]:
import os
import glob
from tqdm import tqdm_notebook as tqdm
import math
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms, utils
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import cv2
from sklearn.metrics import roc_auc_score


import warnings
warnings.filterwarnings("ignore")

### **Importing EfficientNet-3D**

In [None]:
import sys
sys.path.append('../input/efficientnetpyttorch3d/EfficientNet-PyTorch-3D')
from efficientnet_pytorch_3d import EfficientNet3D

### **Inspecting Labels**

In [None]:
# Root directory of the competition data. NOTE: `path` is also read as a
# module-level global by load_3d_dicom_images, so renaming it breaks that function.
path = '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification'
# Label table stored as train_labels.csv directly under the competition root.
train_data = pd.read_csv(os.path.join(path, 'train_labels.csv'))
print('Num of train samples:', len(train_data))
# Last expression of the cell: preview the first rows of the label table.
train_data.head()

### **MRI Slice Loading/Processing**

In [None]:
def dicom2array(path, voi_lut=True, fix_monochrome=True, remove_black_boundary=True):
    """Read one DICOM slice and return it as a 256x256 ``uint8`` image.

    Processing steps, in order:
      1. optionally apply the VOI LUT stored in the file,
      2. optionally invert MONOCHROME1 images so that higher values are brighter,
      3. min-max scale the intensities to 0..255,
      4. optionally crop away the all-black border around the content,
      5. resize to 256x256 with ``cv2.resize``.

    Args:
        path: Path of the ``.dcm`` file to read.
        voi_lut: If true, transform the raw pixel data with ``apply_voi_lut``
            (the "human-friendly" view); otherwise use the raw pixel array.
        fix_monochrome: If true, invert the image when the file's
            ``PhotometricInterpretation`` is ``"MONOCHROME1"``.
        remove_black_boundary: If true, crop to the bounding box of the
            non-zero pixels, provided that box is more than 10 pixels wide
            and more than 10 pixels tall.

    Returns:
        The processed slice after resizing to 256x256. An entirely black
        slice is returned as all zeros.
    """
    dicom = pydicom.dcmread(path)
    # VOI LUT (if available by DICOM device) is used to
    # transform raw DICOM data to "human-friendly" view
    if voi_lut:
        data = apply_voi_lut(dicom.pixel_array, dicom)
    else:
        data = dicom.pixel_array
    # depending on this value, the image may look inverted - fix that:
    if fix_monochrome and dicom.PhotometricInterpretation == "MONOCHROME1":
        data = np.amax(data) - data

    # Min-max scale to 0..255. A constant slice has a peak of 0 after the
    # shift; dividing by it would produce NaNs, so it is left as all zeros.
    data = data - np.min(data)
    peak = np.max(data)
    if peak > 0:
        data = data / peak
    data = (data * 255).astype(np.uint8)

    if remove_black_boundary:
        rows, cols = np.nonzero(data)
        if rows.size > 0:
            row_lo, row_hi = rows.min(), rows.max()
            col_lo, col_hi = cols.min(), cols.max()
            # Only crop when the content is big enough; tiny specks would be
            # blown up into noise by the resize below.
            if (row_hi - row_lo) > 10 and (col_hi - col_lo) > 10:
                # Crop BOTH axes; the upper bounds are inclusive, hence +1.
                data = data[row_lo:row_hi + 1, col_lo:col_hi + 1]
    data = cv2.resize(data, (256, 256))
    return data

def _natural_key(name):
    """Sort key ordering embedded numbers numerically ("Image-2" before "Image-10")."""
    import re  # local import: `re` is not among the notebook's top-level imports

    # re.split with a capture group alternates text/digit tokens, so tokens at
    # the same position always have the same type and compare safely.
    return [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", name)]


def _load_sequence_volume(scan_id, split, sequence, remove_black_boundary, n_slices=64):
    """Load the middle ``n_slices`` slices of one MRI sequence as a (256, 256, n_slices) array."""
    # NOTE: reads the module-level `path` (competition root), like the original code.
    # Assumes the slice index is encoded in the file name (e.g. "Image-12.dcm") -
    # TODO confirm; without digits this degrades to plain lexicographic order.
    files = sorted(glob.glob(f"{path}/{split}/{scan_id}/{sequence}/*.dcm"), key=_natural_key)
    if not files:
        # Missing sequence: stand in with an all-zero volume.
        return np.zeros((256, 256, n_slices))

    # Clamp the window start at 0: a negative start would wrap around and keep
    # only the last few files when a series has fewer than `n_slices` slices.
    start = max(0, len(files) // 2 - n_slices // 2)
    window = files[start:start + n_slices]

    # Stack as (n, 256, 256) and reverse the axes so slices run along the last axis.
    volume = np.array(
        [dicom2array(f, remove_black_boundary=remove_black_boundary) for f in window]
    ).T

    missing = n_slices - volume.shape[-1]
    if missing > 0:
        # Short series are zero-padded at the end up to `n_slices` slices.
        volume = np.concatenate((volume, np.zeros((256, 256, missing))), axis=-1)
    return volume


def load_3d_dicom_images(scan_id, split = "train", channel_expand = True, remove_black_boundary=True):
    """Load the FLAIR, T1w, T1wCE and T2w volumes of one case.

    For each sequence the middle 64 slices are read with ``dicom2array``;
    sequences with fewer than 64 slices are zero-padded at the end and
    sequences with no files at all become all-zero volumes.

    Args:
        scan_id: Case folder name, e.g. ``"00000"``.
        split: Sub-folder of the data root to read from (``"train"`` by default).
        channel_expand: If true, stack the four sequences along a new trailing
            axis -> shape (256, 256, 64, 4). If false, concatenate them along
            the slice axis -> shape (256, 256, 256).
        remove_black_boundary: Forwarded to ``dicom2array``.

    Returns:
        A numpy array with the shape described under ``channel_expand``.
    """
    volumes = [
        _load_sequence_volume(scan_id, split, sequence, remove_black_boundary)
        for sequence in ("FLAIR", "T1w", "T1wCE", "T2w")
    ]
    if channel_expand:
        return np.moveaxis(np.array(volumes), 0, -1)
    return np.concatenate(volumes, axis=-1)

In [None]:
# Sanity check: with channel_expand=False the four sequence volumes are
# concatenated along the last axis instead of being stacked as channels.
load_3d_dicom_images("00000", channel_expand = False).shape

In [None]:
# Load case "00000" with the default channel_expand=True (one trailing axis entry
# per sequence) and with boundary cropping disabled, so the next cell can show
# what the uncropped slices look like.
slices = load_3d_dicom_images("00000", remove_black_boundary=False)
print(slices.shape)

In [None]:
# doing a little more cleaning up
# Take the first slice of the first sequence from the uncropped volume.
s_slice = slices[:,:,0,0]
plt.imshow(s_slice)
plt.title("Lots of black pixels")
plt.show()
# Bounding box of the non-zero pixels (x: indices along axis 0, y: along axis 1).
# NOTE(review): np.min/np.max raise on an empty array, so this cell fails if
# the chosen slice is entirely black.
(x, y) = np.where(s_slice > 0)
# The upper bounds are exclusive here, so the last non-zero row/column is dropped.
ns_slice = s_slice[np.min(x):np.max(x),np.min(y):np.max(y)]
plt.title("Less black pixels")
plt.imshow(ns_slice)
plt.show()

### **Visualization**

In [None]:
# https://www.kaggle.com/josepc/rsna-effnet/

views = ['FLAIR', 'T1w', 'T1wCE', 'T2w']
def load_imgs(idx):
    """Return ``{view name: volume}`` for one case, with the slice axis moved first.

    The case is loaded with ``load_3d_dicom_images`` (one trailing channel per
    entry of ``views``); each channel is then returned with its first and last
    axes swapped so that indexing ``volume[k]`` yields the k-th slice.
    """
    volume = load_3d_dicom_images(idx)
    return {
        view: volume[:, :, :, channel].swapaxes(0, -1)
        for channel, view in enumerate(views)
    }

# Only one case is needed for the video below. The previous loop over cases
# 10..31 overwrote `imgs` on every iteration, so it decoded 21 volumes just
# to throw them away; load the case it ended on directly instead.
idx = str(31).zfill(5)
imgs = load_imgs(idx)

In [None]:
# In-notebook video playback does not work here; download the generated file to view it.

from IPython.display import HTML
from base64 import b64encode
import matplotlib.animation as animation

def play(filename):
    """Return an IPython ``HTML`` object embedding the MP4 at ``filename``.

    The whole file is read into memory and inlined as a base64 data URI, so
    this is only suitable for small videos.

    Args:
        filename: Path of the MP4 file to embed.

    Returns:
        An ``HTML`` display object containing a looping, auto-playing
        ``<video>`` element.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(filename, 'rb') as fh:
        video = fh.read()
    src = 'data:video/mp4;base64,' + b64encode(video).decode()
    html = '<video width=500 controls autoplay loop><source src="%s" type="video/mp4"></video>' % src
    return HTML(html)

def create_video(imgs, output='/kaggle/working/vis_video.mp4', duration=30, subplot=True, 
                frame_delay=200):
    """Render slices as an animation and save it to ``output``.

    Args:
        imgs: With ``subplot=True``, a mapping from each name in the
            module-level ``views`` list to an array indexed by slice along
            axis 0 (as returned by ``load_imgs``). With ``subplot=False``, a
            single array indexed by slice along axis 0.
        output: Destination file passed to ``ArtistAnimation.save``.
        duration: Number of frames to render (not seconds). Frame ``k`` shows
            slice ``k % n_slices``, so the volume loops if ``duration`` is larger.
        subplot: If true, draw the four views in a 2x2 grid; otherwise draw a
            single image per frame.
        frame_delay: Delay between frames in milliseconds.
    """
    ims = []
    if not subplot:
        fig, ax = plt.subplots(figsize=(15, 10))
        n_slices = imgs.shape[0]
        for i in range(duration):
            im = ax.imshow(imgs[i % n_slices], animated=True)
            ims.append([im])
    else:
        # Create only the figure that is actually animated; an extra unused
        # figure would stay open and show up as an empty plot in the notebook.
        fig, ax = plt.subplots(2, 2, figsize=(10, 10))
        # ax.flat walks the grid row by row, matching views[2*i + j].
        panels = list(zip(ax.flat, views))
        for axis, view in panels:
            axis.set_title(view)
        for k in range(duration):
            ims.append([
                axis.imshow(imgs[view][k % imgs[view].shape[0]], animated=True)
                for axis, view in panels
            ])

    ani = animation.ArtistAnimation(fig, ims, interval=frame_delay, blit=True, repeat_delay=1000)

    try:
        ani.save(output)
    finally:
        # Close once, after saving, so the static figure is not also displayed
        # and is released even if encoding fails.
        plt.close(fig)

In [None]:
# Render 60 frames (300 ms apart) of the four views held in `imgs` (the dict
# produced by load_imgs in the visualization cell) to the default output path...
create_video(imgs, duration=60, subplot=True, frame_delay=300)
# ...and embed that same file in the notebook output.
play('/kaggle/working/vis_video.mp4')

### **Data Loader**

In [None]:
# let's write a simple PyTorch Dataset (it is wrapped in DataLoaders below)


class BrainTumor(Dataset):
    """Dataset of 4-channel 3D MRI volumes with MGMT labels.

    Case IDs are the sorted sub-folder names of the split directory. The
    labelled ``train`` folder is divided deterministically: the first
    ``validation_split`` fraction of the sorted IDs forms the ``"valid"``
    split and the rest forms the ``"train"`` split. Any other ``split`` value
    (e.g. ``"test"``) uses every case in the folder of that name.

    Args:
        path: Root directory holding ``train_labels.csv`` and the split folders.
        split: ``"train"``, ``"valid"`` or the name of another folder such as ``"test"``.
        validation_split: Fraction of the sorted training IDs held out for validation.

    Items:
        ``(volume, label)`` for every split except ``"test"``, which yields
        ``volume`` only. ``volume`` is a float32 tensor with the sequence
        axis first (the loader's trailing axis is permuted to the front);
        ``label`` is a ``torch.long`` scalar.
    """

    def __init__(self, path = '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification', split = "train", validation_split = 0.2):
        # labels: zero-padded 5-digit case id -> MGMT_value
        train_data = pd.read_csv(os.path.join(path, 'train_labels.csv'))
        self.labels = {
            str(b).zfill(5): m
            for b, m in zip(train_data["BraTS21ID"], train_data["MGMT_value"])
        }

        self.split = split
        # Folder the images actually live in. Validation cases are carved out
        # of the training folder; there is no "valid" directory on disk, so
        # loading from one would silently return all-zero volumes.
        self.data_split = "train" if split == "valid" else split

        all_ids = [
            os.path.basename(p)
            for p in sorted(glob.glob(os.path.join(path, self.data_split, "*")))
        ]
        cut = int(len(all_ids) * validation_split)
        if split == "valid":
            self.ids = all_ids[:cut]   # first 20% (by default) as validation
        elif split == "train":
            self.ids = all_ids[cut:]   # remaining 80% as train
        else:
            self.ids = all_ids

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        scan_id = self.ids[idx]
        # NOTE(review): load_3d_dicom_images resolves files from the
        # module-level `path`, not from the `path` given to __init__ - confirm
        # that both point at the same root.
        imgs = load_3d_dicom_images(scan_id, self.data_split)

        # Rescale to (0, 1]; the epsilon keeps an all-zero volume from dividing by zero.
        imgs = imgs - imgs.min()
        imgs = (imgs + 1e-5) / (imgs.max() - imgs.min() + 1e-5)

        # Move the trailing sequence axis to the front.
        volume = torch.tensor(imgs, dtype = torch.float32).permute(-1, 0, 1, 2)

        if self.split != "test":
            label = self.labels[scan_id]
            return volume, torch.tensor(label, dtype = torch.long)
        return volume

In [None]:
# Build the train/validation datasets and wrap each one in a shuffling DataLoader.
train_bs, val_bs = 4, 2

train_dataset = BrainTumor()
val_dataset = BrainTumor(split="valid")

train_loader = DataLoader(train_dataset, batch_size=train_bs, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=val_bs, shuffle=True)

In [None]:
# Peek at the first training batch: tensor shape, value range and label shape.
for img, label in train_loader:
    for stat in (img.shape, img.max(), img.mean(), img.min(), label.shape):
        print(stat)
    break

# Same for the first validation batch, shapes only.
for img, label in val_loader:
    for stat in (img.shape, label.shape):
        print(stat)
    break

### **Model: EfficientNet-3D B0**

In [None]:
PATH = "../input/rsna-efficientnet3db0/best_roc_0.29_loss_1826.83.pt" # using a pretrained weight

# Two output logits (one per class) and four input channels (one per MRI sequence).
model = EfficientNet3D.from_name("efficientnet-b0", override_params={'num_classes': 2}, in_channels=4)
# Load onto the CPU first: without map_location a checkpoint saved from a GPU
# cannot be restored on a CPU-only machine. The model is moved to the
# training device later.
model.load_state_dict(torch.load(PATH, map_location="cpu"))

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(),lr = 0.0007, weight_decay=0.08)
# Cut the learning rate to a quarter when the monitored value (the validation
# loss passed to scheduler.step) has not decreased for 3 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.25)
n_epochs = 10

In [None]:
print(model)
# Smoke-test the forward pass on one random volume. Run it in eval mode and
# without autograd: in train mode the random batch would update the BatchNorm
# running statistics of the pretrained weights and build an unused graph.
_was_training = model.training
model.eval()
with torch.no_grad():
    _smoke_output = model(torch.randn(1, 4, 256, 256, 64))
model.train(_was_training)  # restore whatever mode the model was in
_smoke_output

### **Training**

In [None]:
# helper
def one_hot(arr):
    """Encode binary labels as two-element one-hot rows.

    A value equal to 0 becomes ``[1, 0]``; every other value becomes ``[0, 1]``.
    Returns a new list with one freshly created row per input element.
    """
    encoded = []
    for value in arr:
        if value == 0:
            encoded.append([1, 0])
        else:
            encoded.append([0, 1])
    return encoded

In [None]:
# let's train
gpu = torch.device(f"cuda:0" if torch.cuda.is_available() else "cpu")
model.to(gpu)

# Per-epoch history, plotted after training.
train_loss = []
val_loss = []
train_roc = []
val_roc = []
best_roc = 0.0

for epoch in range(n_epochs):  # loop over the dataset multiple times
    # ---------------- training pass ----------------
    y_all = []
    scores_all = []
    running_loss = 0.0  # sum of the per-batch losses over the epoch

    model.train()
    for x, y in tqdm(train_loader):
        x = x.to(gpu)
        y = y.to(gpu)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(x)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()

        # collect statistics
        running_loss += loss.item()
        y_all.extend(y.tolist())
        # Probability of the positive class (MGMT_value == 1) is the ranking score.
        scores_all.extend(torch.softmax(outputs.detach(), dim=1)[:, 1].tolist())

    # AUC is a rank statistic over the whole epoch; it must not be scaled by
    # the batch size.
    roc = roc_auc_score(y_all, scores_all)
    print(f"epoch {epoch+1} train loss: {running_loss} train roc: {roc}")

    train_loss.append(running_loss)
    train_roc.append(roc)

    # ---------------- validation pass ----------------
    y_all = []
    scores_all = []
    running_loss = 0.0

    model.eval()
    # No gradients are needed for evaluation; skipping them saves memory and time.
    with torch.no_grad():
        for x, y in tqdm(val_loader):
            x = x.to(gpu)
            y = y.to(gpu)

            # forward
            outputs = model(x)
            loss = criterion(outputs, y)

            # collect statistics
            running_loss += loss.item()
            y_all.extend(y.tolist())
            scores_all.extend(torch.softmax(outputs, dim=1)[:, 1].tolist())

    roc = roc_auc_score(y_all, scores_all)
    # The plateau scheduler monitors the summed validation loss.
    scheduler.step(running_loss)

    print(f"epoch {epoch+1} val loss: {running_loss} val roc: {roc}")

    val_loss.append(running_loss)
    val_roc.append(roc)

    # Keep a checkpoint whenever the validation AUC improves.
    if roc > best_roc:
        best_roc = roc
        torch.save(model.state_dict(), f'best_roc_{round(roc, 2)}_loss_{round(running_loss, 2)}.pt')

In [None]:
# One figure per metric: the training and validation curves over the epochs.
for train_hist, val_hist, names, y_label in (
    (train_loss, val_loss, ['train loss', 'val loss'], 'loss'),
    (train_roc, val_roc, ['train roc', 'val roc'], 'roc auc'),
):
    plt.plot(train_hist, label = names[0])
    plt.plot(val_hist, label = names[1])
    plt.xlabel('epochs')
    plt.ylabel(y_label)
    plt.legend(names)
    plt.show()

In [None]:
# Read the competition's sample submission and write it back unchanged as
# submission.csv. No model predictions are filled in by this cell.
submission = pd.read_csv("../input/rsna-miccai-brain-tumor-radiogenomic-classification/sample_submission.csv")
submission.to_csv("submission.csv", index=False)
# Last expression of the cell: display the table.
submission