In [1]:
pip install '../input/rsna-monai-packages/monai-0.6.0-202107081903-py3-none-any.whl'

Processing /kaggle/input/rsna-monai-packages/monai-0.6.0-202107081903-py3-none-any.whl
Installing collected packages: monai
Successfully installed monai-0.6.0
Note: you may need to restart the kernel to use updated packages.


In [2]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import os

import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score
import glob

In [3]:
import albumentations as A
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from albumentations.pytorch import ToTensorV2
from sklearn.metrics import roc_auc_score, accuracy_score
from torch.optim import lr_scheduler
from tqdm import tqdm
import re

from tensorflow import keras
import tensorflow as tf
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import nilearn as nl
import nibabel as nib
import cv2
from IPython.display import clear_output
from PIL import Image
from pydicom import read_file, dcmread
from scipy.ndimage import zoom

In [4]:
NUM_IMAGES_3D = 128
TRAINING_BATCH_SIZE = 1
TEST_BATCH_SIZE = 2
IMAGE_SIZE = 256
N_EPOCHS = 1
do_valid = True
n_workers = 4
type_ = "T1wCE"
MODEL_NAME = 'version1.enhanced'

In [5]:
def load_dicom_image(path, img_size=IMAGE_SIZE, voi_lut=True, rotate=0):
    dicom = pydicom.read_file(path)
    data = dicom.pixel_array
    if voi_lut:
        data = apply_voi_lut(dicom.pixel_array, dicom)
    else:
        data = dicom.pixel_array

    if rotate > 0:
        rot_choices = [
            0,
            cv2.ROTATE_90_CLOCKWISE,
            cv2.ROTATE_90_COUNTERCLOCKWISE,
            cv2.ROTATE_180,
        ]
        data = cv2.rotate(data, rot_choices[rotate])

    data = cv2.resize(data, (img_size, img_size))
    return data

In [6]:
def dice_coef(y_true, y_pred, smooth=1.0):
    class_num = 4
    for i in range(class_num):
        y_true_f = K.flatten(y_true[:,:,:,i])
        y_pred_f = K.flatten(y_pred[:,:,:,i])
        intersection = K.sum(y_true_f * y_pred_f)
        loss = ((2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth))
   #     K.print_tensor(loss, message='loss value for class {} : '.format(SEGMENT_CLASSES[i]))
        if i == 0:
            total_loss = loss
        else:
            total_loss = total_loss + loss
    total_loss = total_loss / class_num
#    K.print_tensor(total_loss, message=' total dice coef: ')
    return total_loss


 
# define per class evaluation of dice coef
# inspired by https://github.com/keras-team/keras/issues/9395
def dice_coef_necrotic(y_true, y_pred, epsilon=1e-6):
    intersection = K.sum(K.abs(y_true[:,:,:,1] * y_pred[:,:,:,1]))
    return (2. * intersection) / (K.sum(K.square(y_true[:,:,:,1])) + K.sum(K.square(y_pred[:,:,:,1])) + epsilon)

def dice_coef_edema(y_true, y_pred, epsilon=1e-6):
    intersection = K.sum(K.abs(y_true[:,:,:,2] * y_pred[:,:,:,2]))
    return (2. * intersection) / (K.sum(K.square(y_true[:,:,:,2])) + K.sum(K.square(y_pred[:,:,:,2])) + epsilon)

def dice_coef_enhancing(y_true, y_pred, epsilon=1e-6):
    intersection = K.sum(K.abs(y_true[:,:,:,3] * y_pred[:,:,:,3]))
    return (2. * intersection) / (K.sum(K.square(y_true[:,:,:,3])) + K.sum(K.square(y_pred[:,:,:,3])) + epsilon)



# Computing Precision 
def precision(y_true, y_pred):
        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
        precision = true_positives / (predicted_positives + K.epsilon())
        return precision

    
# Computing Sensitivity      
def sensitivity(y_true, y_pred):
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    return true_positives / (possible_positives + K.epsilon())


# Computing Specificity
def specificity(y_true, y_pred):
    true_negatives = K.sum(K.round(K.clip((1-y_true) * (1-y_pred), 0, 1)))
    possible_negatives = K.sum(K.round(K.clip(1-y_true, 0, 1)))
    return true_negatives / (possible_negatives + K.epsilon())
############ load trained model ################
seg_model = keras.models.load_model('../input/modelperclasseval/model_per_class.h5', 
                                   custom_objects={ 'accuracy' : tf.keras.metrics.MeanIoU(num_classes=4),
                                                   "dice_coef": dice_coef,
                                                   "precision": precision,
                                                   "sensitivity":sensitivity,
                                                   "specificity":specificity,
                                                   "dice_coef_necrotic": dice_coef_necrotic,
                                                   "dice_coef_edema": dice_coef_edema,
                                                   "dice_coef_enhancing": dice_coef_enhancing
                                                  }, compile=False)

In [7]:
planes_df = pd.read_csv('../input/labels-with-planes/train_labels_with_planes.csv',
                        header=0,
                        names=['ID', 'MGMT_Value', 'FLAIR', 'T1w', 'T1wCE', 'T2w'])

planes_df['ID'] = [(5-len(str(n)))*'0'+str(n) for n in planes_df['ID']]

# create key-value format
AXIS = {}
col_names = ['MGMT_Value', 'FLAIR', 'T1w', 'T1wCE', 'T2w']
for index, row in planes_df.iterrows():
#     for name in col_names:
    AXIS[row['ID']] = {name: row[name] for name in col_names}

In [8]:
DIM = 128
SLICES = 64
MR_DIM = 256
TRAIN_DIR = r'../input/rsna-miccai-brain-tumor-radiogenomic-classification/train'
TEST_DIR = r'../input/rsna-miccai-brain-tumor-radiogenomic-classification/test'
flair = tf.keras.models.load_model('../input/mrequalizer-weights/flair_part1.h5')
t1ce = tf.keras.models.load_model('../input/mrequalizer-weights/t1ce_part1.h5')
t2 = tf.keras.models.load_model('../input/mrequalizer-weights/t2_part1.h5')

enhancer_models = {'FLAIR': flair, 'T1wCE': t1ce, 'T2w': t2}



    
# def axial_2_coronal(mri, show=False):
# #     print('init shape:', mri.shape)
#     cor = np.moveaxis(mri, 0, 2)
#     cor = get_cropped_region(cor, reshape=(DIM, DIM))
#     cor = np.moveaxis(cor, 0, 1)
#     cor = np.rot90(cor, 2)
#     if show:
#         show_mri(cor)
#     return cor

# def sagittal_2_coronal(mri, show=False):
# #     print('init shape:', mri.shape)
#     cor = np.moveaxis(mri, 1, 2)
#     cor = get_cropped_region(cor, reshape=(DIM, DIM))
#     cor = np.moveaxis(cor, 0, 1)
#     cor = np.rot90(cor, 3)
#     if show:
#         show_mri(cor)
#     return cor
    
def get_num(name):
#     print(f'filename: {name}')
    return int(name[name.index('-')+1:name.index('.')])

def load_mri_arr(folder):
    for file in os.listdir(folder):
        slice_path = os.path.join(folder, file)
        dicom = read_file(slice_path)
        img = apply_voi_lut(dicom.pixel_array, dicom)
        shape = (img.shape[0], img.shape[1], len(os.listdir(folder)))
    out = np.empty(shape)
#     files = [os.path.join(folder, file) for file in os.listdir(folder)]
    data = sorted(os.listdir(folder), key=get_num)
    for i, file in enumerate(data):
        slice_path = os.path.join(folder, file)
        img = read_file(slice_path).pixel_array 
        out[:, :, i] = img
    out = out/np.max(out)
    return out

def show_mri(mri):
    for i in range(mri.shape[2]):
        clear_output(wait=True)
        plt.axis(False)
        plt.imshow(mri[:, :, i], cmap='gray')
        plt.show()


def lowleft_upright(img_arr):
    ret,thresh = cv2.threshold(img_arr,20,255,0)
    contours,hierarchy = cv2.findContours(thresh, 1, 2)
    cnt2cons = [cnt for cnt in contours if cv2.contourArea(cnt) >= 100]
    xmin, ymin = 500, 500
    xmax, ymax = -1, -1
    for cnt in cnt2cons:
        area = cv2.contourArea(cnt)
        x,y,w,h = cv2.boundingRect(cnt)
        img = np.zeros((240, 240, 3))
        cv2.rectangle(img,(x,y),(x+w,y+h),(255, 0, 0),2)
        xmin = min(xmin, x)
        xmax = max(xmax, x+w)
        ymin = min(ymin, y)
        ymax = max(ymax, y+h)
    return xmin, xmax, ymin, ymax


def get_cropped_region(img, reshape=None):
    img = img*255.
    img = img.astype(np.uint8)
    xmin_true, ymin_true = 500, 500
    xmax_true, ymax_true = -1, -1
#     print('marker 1')
    for i in range(img.shape[-1]):
#         print('marker 2')
        if np.sum(img[:, :, i]) == 0:
#             print('marker 3')
            continue
        xmin, xmax, ymin, ymax = lowleft_upright(img[:, :, i])
        xmin_true = min(xmin, xmin_true)
        ymin_true = min(ymin, ymin_true)
        xmax_true = max(xmax, xmax_true)
        ymax_true = max(ymax, ymax_true)
        
    count_black = 0
    for i in range(img.shape[-1]):
        count_black += int(np.sum(img[:, :, i]) == 0)
#     print(count_black)
    if ymax_true-ymin_true < 0 or xmax_true-xmin_true < 0 or img.shape[-1]-count_black < 0:
        print(ymax_true-ymin_true, xmax_true-xmin_true, img.shape[-1]-count_black)
    cropped = np.empty((ymax_true-ymin_true, xmax_true-xmin_true, img.shape[-1]-count_black))
    curr_slice = 0
    for i in range(img.shape[-1]):
        if np.sum(img[:, :, i]) == 0:
            continue
        
        crop = img[:, :, i][ymin_true:ymax_true, xmin_true:xmax_true]
        
        cropped[:, :, curr_slice] = crop
        curr_slice += 1
    cropped = cropped/np.max(cropped)
    if reshape is not None:
        reshaped = np.empty((reshape[0], reshape[1], cropped.shape[-1]))
        for i in range(cropped.shape[-1]):
            resized = cv2.resize(cropped[:, :, i], reshape, interpolation=cv2.INTER_AREA)
            reshaped[:, :, i] = resized
        return reshaped
    return cropped


def apply_equalizer(mri, img_type):
    original_dim = (mri.shape[1], mri.shape[0])
    new = np.empty((mri.shape[0], mri.shape[1], 2*mri.shape[2]-1))
    cur = 0
    for i in range(mri.shape[-1]-1):
#         print(mri[:, :, i].shape)
        in1 = cv2.resize(mri[:, :, i], (MR_DIM, MR_DIM))
        in2 = cv2.resize(mri[:, :, i+1], (MR_DIM, MR_DIM))
        output = get_middle_image(in1, in2, img_type)
        output = cv2.resize(output, original_dim)        
        new[:, :, cur] = mri[:, :, i]
        new[:, :, cur+1] = output
        cur += 2
    new[:, :, cur] = mri[:, :, -1]
    return new

def get_middle_image(in1, in2, img_type, show=False): # provide to 256x256 images
#     print(in1.shape)
    in1 = cv2.resize(in1, (MR_DIM, MR_DIM))
    in2 = cv2.resize(in2, (MR_DIM, MR_DIM))
    in1 = in1.reshape(1, MR_DIM, MR_DIM, 1)
    in2 = in2.reshape(1, MR_DIM, MR_DIM, 1)
    data = [in1, in2]
    output = np.array(enhancer_models[img_type](data))
    if show:
        f, axarr = plt.subplots(1, 3, squeeze=False)
        axarr[0][0].axis(False)
        axarr[0][1].axis(False)
        axarr[0][2].axis(False)
        axarr[0][0].imshow(in1[0, :, :, 0], cmap='gray', vmin=0., vmax=1.)
        axarr[0][1].imshow(output[0, :, :, 0], cmap='gray', vmin=0., vmax=1.)
        axarr[0][2].imshow(in2[0, :, :, 0], cmap='gray', vmin=0., vmax=1.)
        plt.show()
    return output[0, :, :, 0]

def get_image(folder, img_type, equalizer_iters=0):
    loaded = load_mri_arr(folder)
#     loaded = get_cropped_region(loaded)
#     print('Before:', loaded.shape[-1])
#     for i in range(equalizer_iters):
    while loaded.shape[-1] < DIM and equalizer_iters != 0:
        print('hihihihihasdf;ladf')
        loaded = apply_equalizer(loaded, img_type)
#     print('After:', loaded.shape[-1])
    return loaded

def resize(mri, shape):
    init_shape = mri.shape
    resized = zoom(mri, (shape[0]/init_shape[0], shape[1]/init_shape[1], shape[2]/init_shape[2]))
    return resized

def load_input_for_seg(folder, return_label=True, equalizer_iters=0):
    idnum = folder[-5:]
    data = AXIS[idnum]
    label = data['MGMT_Value']
    inp = np.empty((DIM, DIM, DIM, 2))
    
    path = os.path.join(folder, 'FLAIR')
    loaded = get_cropped_region(get_image(path, 'flair', equalizer_iters=equalizer_iters), reshape=(DIM, DIM))
    loaded = resize(loaded, (DIM, DIM, DIM))
    if data['FLAIR'] == 'Sagittal':
        loaded = sagittal_2_coronal(loaded, False)
    if data['FLAIR'] == 'Axial':
        loaded = axial_2_coronal(loaded, False)
    inp[:, :, :, 0] = np.moveaxis(loaded, 2, 0)
    
    path = os.path.join(folder, 'T1wCE')
    loaded = get_cropped_region(get_image(path, 't1ce', equalizer_iters=equalizer_iters), reshape=(DIM, DIM))
    loaded = resize(loaded, (DIM, DIM, DIM))
    if data['T1wCE'] == 'Sagittal':
        loaded = sagittal_2_coronal(loaded, False)
    if data['T1wCE'] == 'Axial':
        loaded = axial_2_coronal(loaded, False)
    inp[:, :, :, 1] = np.moveaxis(loaded, 2, 0)
    
    if return_label:
        return inp, label
    return inp


def get_mask(inp):
    out = seg_model.predict(inp)
    out = np.argmax(out, axis=-1)
#     print(out.shape)
    out[out != 0] = 1
    out[:12] = np.zeros((12, DIM, DIM), dtype=np.float32)
    out[-12:] = np.zeros((12, DIM, DIM), dtype=np.float32)
    return out

def find_largest_tumor_slice(mask): # return ind of largest tumor slice
    ind = -1
    largest_size = -1
    for i in range(mask.shape[-1]):
        curr = mask[:, :, i]
        tum_size = cv2.countNonZero(curr)
        if tum_size >= largest_size:
            largest_size = tum_size
            ind = i
    return ind

def load_input(
    case_id,
    num_imgs=NUM_IMAGES_3D,
    img_size=IMAGE_SIZE,
    rotate=0,
    equalizer_iters=0
):
    idnum = case_id[-5:]
#     print(idnum)
    data = AXIS[idnum]
    label = data['MGMT_Value']
    path2 = f"../input/rsna-miccai-brain-tumor-radiogenomic-classification/train/{case_id}"
    path = f"../input/rsna-miccai-brain-tumor-radiogenomic-classification/train/{case_id}/{type_}"
    flair_path = f"../input/rsna-miccai-brain-tumor-radiogenomic-classification/train/{case_id}/FLAIR"
    t1wce_path = f"../input/rsna-miccai-brain-tumor-radiogenomic-classification/train/{case_id}/T1wCE"

    try:
        loaded = get_cropped_region(get_image(path, type_, equalizer_iters=equalizer_iters), reshape=(DIM, DIM))
        loaded = resize(loaded, (DIM, DIM, DIM))
        loaded_flair = get_cropped_region(get_image(flair_path, type_, equalizer_iters=equalizer_iters), reshape=(DIM, DIM))
        loaded_flair = resize(loaded_flair, (DIM, DIM, DIM))
        loaded_t1wce = get_cropped_region(get_image(t1wce_path, type_, equalizer_iters=equalizer_iters), reshape=(DIM, DIM))
        loaded_t1wce = resize(loaded_t1wce, (DIM, DIM, DIM))
        seg_inp = np.empty((DIM, DIM, DIM, 2))
        seg_inp[:, :, :, 0] = np.moveaxis(loaded_flair, 2, 0)
        seg_inp[:, :, :, 1] = np.moveaxis(loaded_t1wce, 2, 0)
#         if data['FLAIR'] == 'Sagittal':
#             loadexd = sagittal_2_coronal(loaded, False)
#         if data['FLAIR'] == 'Axial':
#             loaded = axial_2_coronal(loaded, False)

        mask = get_mask(seg_inp)
        middle = find_largest_tumor_slice(mask)
        sl_each_side = SLICES//2
        e1 = max(0, middle-sl_each_side)
        e2 = min(loaded.shape[-1]-1, middle+sl_each_side)

        if middle-e1 < e2-middle:
            e2 += 32-(middle-e1)
        elif e2-middle < middle-e1:
            e1 -= 32-(e2-middle)
        to_return = np.empty((DIM, DIM, e2-e1))
        count = 0
#         print(e1, e2, middle)
        for i in range(e1, e2):
            to_return[:, :, count] = loaded[i]
            count += 1
        to_return = np.expand_dims(to_return, axis=0)
        to_return = (to_return-np.min(to_return))/(np.max(to_return)-np.min(to_return))
    #         print('marker 1')
        return to_return
    except Exception as e:
        print(e)
        print(f'Failed on case {idnum}')
        return np.zeros((1, DIM, DIM, SLICES))    

In [9]:
os.mkdir('t1wce_train_npy_files_ori_same')

In [10]:
for i, case in enumerate(os.listdir('../input/rsna-miccai-brain-tumor-radiogenomic-classification/train')):
    case_id = str(case).zfill(5)
    loaded = load_input(case_id, equalizer_iters=1)
    np.save(f't1wce_train_npy_files_ori_same/{case_id}.npy', loaded)
    print(i+1, end=' ')

1 hihihihihasdf;ladf
hihihihihasdf;ladf
2 3 hihihihihasdf;ladf
hihihihihasdf;ladf
4 5 6 hihihihihasdf;ladf
hihihihihasdf;ladf
7 8 hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
9 hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
10 hihihihihasdf;ladf
hihihihihasdf;ladf
11 hihihihihasdf;ladf
hihihihihasdf;ladf
12 13 hihihihihasdf;ladf
hihihihihasdf;ladf
14 15 hihihihihasdf;ladf
hihihihihasdf;ladf
16 17 hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
18 hihihihihasdf;ladf
hihihihihasdf;ladf
19 hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
20 21 hihihihihasdf;ladf
hihihihihasdf;ladf
22 hihihihihasdf;ladf
hihihihihasdf;ladf
23 24 hihihihihasdf;ladf
hihihihihasdf;ladf
25 hihihihihasdf;ladf
hihihihihasdf;ladf
26 27 hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
28 29 hihihihiha



-501 -501 0
negative dimensions are not allowed
Failed on case 00709
44 hihihihihasdf;ladf
hihihihihasdf;ladf
45 46 47 48 49 hihihihihasdf;ladf
hihihihihasdf;ladf
50 hihihihihasdf;ladf
hihihihihasdf;ladf
51 hihihihihasdf;ladf
hihihihihasdf;ladf
52 53 hihihihihasdf;ladf
hihihihihasdf;ladf
54 hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
55 hihihihihasdf;ladf
hihihihihasdf;ladf
56 hihihihihasdf;ladf
hihihihihasdf;ladf
57 hihihihihasdf;ladf
hihihihihasdf;ladf
58 59 60 hihihihihasdf;ladf
hihihihihasdf;ladf
61 hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
62 63 hihihihihasdf;ladf
hihihihihasdf;ladf
64 65 66 hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
hihihihihasdf;ladf
67 hihihihihasdf;ladf
hihihihihasdf;ladf
68 69 70 71 hihihihihas

In [11]:
# case = np.load('./t1wce_train_npy_files_ori_same/00502.npy')

In [12]:
# for sl in range(case.shape[-1]):
#     clear_output(wait=True)
#     plt.axis('off')
#     plt.imshow(case[0, :, :, sl], cmap='gray')
#     plt.show()