In [1]:
import os
import sys
import re
import random
import pandas as pd
import numpy as np
import cv2

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import data as torch_data
import time

import sklearn
from sklearn import model_selection as sk_model_selection
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import f1_score
from sklearn.metrics import roc_auc_score
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score
from skimage.transform import resize

import nibabel as nib
import matplotlib.pyplot as plt

sys.path.append('../')
from unet_down import UNet

In [2]:
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and, when CUDA is
    available, every CUDA device), exports ``PYTHONHASHSEED`` and pins cuDNN
    to deterministic behaviour.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    # Only affects Python processes launched after this point; the hash seed
    # of the already-running kernel is fixed at interpreter start-up.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        # The cuDNN autotuner may pick different convolution algorithms from
        # run to run, so it has to be off for repeatable results.
        torch.backends.cudnn.benchmark = False
        
def get_train_valid_split(label_path):
    """Read the label CSV and build stratified 5-fold train/validation splits.

    Rows of the cases '00109', '00123' and '00709' are removed before the
    folds are built.

    Args:
        label_path: Path to a CSV with 'BraTS21ID' and 'MGMT_value' columns.

    Returns:
        Tuple ``(X, y, folds)``: the array of case ids, the array of labels
        and a list of ``(train_indices, valid_indices)`` pairs, one per fold.
        Shuffling is driven by the module-level ``SEED``.
    """
    excluded_ids = ['00109', '00123', '00709']
    column_types = {'BraTS21ID': 'str', 'MGMT_value': 'int'}

    labels = pd.read_csv(label_path, dtype=column_types)
    keep_rows = ~labels['BraTS21ID'].isin(excluded_ids)
    labels = labels[keep_rows].reset_index(drop=True)

    X = labels['BraTS21ID'].values
    y = labels['MGMT_value'].values

    splitter = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)
    folds = list(splitter.split(X, y))
    return X, y, folds


def load_raw_voxel(patient_id,mri_type):
    """Load one raw volume and min-max rescale its intensities to 0~255.

    Args:
        patient_id: Case id used in the folder and file name under
            ``RAW_DATA_PATH``.
        mri_type: Sequence suffix of the file name.

    Returns:
        Float array with the same shape as the stored volume, scaled so that
        its minimum maps to 0 and its maximum to 255. A constant volume
        (max == min) is returned as all zeros instead of NaN.
    """
    volume_file = f'{RAW_DATA_PATH}/BraTS2021_{patient_id}/BraTS2021_{patient_id}_{mri_type}.nii.gz'
    voxels = nib.load(volume_file).get_fdata().astype('float')

    _min = voxels.min()
    intensity_range = voxels.max() - _min
    if intensity_range == 0:
        # Dividing by a zero range would fill the whole volume with NaN.
        return np.zeros_like(voxels)
    return (voxels - _min) / intensity_range * 255.0

def load_mask(patient_id):
    """Load the mask volume of one case as a float array.

    Args:
        patient_id: Case id used in the mask file name under ``MASKS_PATH``.

    Returns:
        The data of the NIfTI file read with nibabel, cast to float.
    """
    mask_file = f'{MASKS_PATH}/BraTS2021_{patient_id}.nii.gz'
    mask_image = nib.load(mask_file)
    return mask_image.get_fdata().astype('float')

def non_0_voxel_mask(voxel,mask):
    """Crop a volume and its mask along the third axis to the mask's extent.

    A slice ``mask[:, :, k]`` counts as occupied when its maximum is not 0.
    The crop runs from the first to the last occupied slice (inclusive). When
    no slice is occupied, nothing is cropped.

    Args:
        voxel: 3-D array cropped with the same slice range as ``mask``.
        mask: 3-D array that determines the slice range.

    Returns:
        Tuple ``(cropped_voxel, cropped_mask)``.
    """
    depth = mask.shape[2]

    def _occupied(k):
        return np.max(mask[:, :, k]) != 0

    # First occupied slice scanning forwards, last one scanning backwards;
    # the defaults keep the full range when every slice is empty.
    first = next((k for k in range(depth) if _occupied(k)), 0)
    last = next((k for k in range(depth - 1, -1, -1) if _occupied(k)), depth - 1)

    keep = slice(first, last + 1)
    return voxel[:, :, keep], mask[:, :, keep]

def find_largest_countours(contours):
    """Return the contour with the greatest ``cv2.contourArea``.

    Args:
        contours: Non-empty iterable of OpenCV contours.

    Returns:
        The contour with the largest area (the first one on ties).

    Raises:
        ValueError: If ``contours`` is empty.
    """
    return max(contours, key=cv2.contourArea)


def get_area_over_image_ratio(image, mask):
    """Ratio of the largest mask contour area to the largest image contour area.

    Args:
        image: 2-D single-channel array; pixels above 1 count as foreground.
        mask: 2-D single-channel array; pixels above 0.5 count as foreground.

    Returns:
        ``area(largest mask contour) / area(largest image contour)``, or 0
        when either input yields no contour or when the largest image contour
        encloses no area.
    """
    _, image_thresh = cv2.threshold(image, 1, 255, cv2.THRESH_BINARY)
    # findContours in RETR_TREE mode only accepts 8-bit single-channel input,
    # so the binarised array is cast first (a no-op for uint8 input).
    image_contours, _ = cv2.findContours(
        image=image_thresh.astype(np.uint8),
        mode=cv2.RETR_TREE,
        method=cv2.CHAIN_APPROX_NONE,
    )
    if not image_contours:
        return 0

    _, mask_thresh = cv2.threshold(mask, 0.5, 1, cv2.THRESH_BINARY)
    mask_contours, _ = cv2.findContours(
        image=mask_thresh.astype(np.uint8),
        mode=cv2.RETR_TREE,
        method=cv2.CHAIN_APPROX_NONE,
    )
    if not mask_contours:
        return 0

    image_area = cv2.contourArea(find_largest_countours(image_contours))
    if image_area == 0:
        # A contour made of a single pixel or a straight line has zero area;
        # dividing by it would raise ZeroDivisionError.
        return 0

    mask_area = cv2.contourArea(find_largest_countours(mask_contours))
    return mask_area / image_area

def get_percent_volume(raw_mask):
    """Replace each non-zero label by its share (in %) of all labelled voxels.

    Every voxel whose value is a non-zero label ``v`` becomes
    ``100 * count(v) / count(all non-zero voxels)``. Zero voxels stay zero
    and ``raw_mask`` itself is not modified.

    Args:
        raw_mask: NumPy array of labels, where 0 means "unlabelled".

    Returns:
        Array of the same shape as ``raw_mask``. Floating-point inputs keep
        their dtype; other dtypes are converted to float so the percentages
        are not truncated.
    """
    if np.issubdtype(raw_mask.dtype, np.floating):
        mask_per = raw_mask.copy()
    else:
        mask_per = raw_mask.astype(float)

    labels, counts = np.unique(raw_mask, return_counts=True)
    # Filter labels and counts together so each label keeps its own count.
    labelled = labels != 0
    labels, counts = labels[labelled], counts[labelled]
    total_mask = counts.sum()

    for label, count in zip(labels, counts):
        # Select on the untouched input: a percentage already written to the
        # output could otherwise be mistaken for a later label value.
        mask_per[raw_mask == label] = (count / total_mask) * 100

    return mask_per

In [11]:
# ---------------------- TC 90 90 90 ---------------------------
def construct_target_volume(scan_id,mri_type,scale_size=260):
    """Load the stored volume of one case and sequence from under ``TC_PATH``.

    Args:
        scan_id: Case id used in the folder and file name.
        mri_type: Sequence suffix of the file name.
        scale_size: Accepted for call compatibility; not used by this function.

    Returns:
        The data of the NIfTI file as a float array.
    """
    # NOTE(review): the volume is always read from TC_PATH, whichever model
    # directory is configured elsewhere in the notebook -- confirm this is
    # the intended data source.
    volume_file = f'{TC_PATH}/BraTS2021_{scan_id}/BraTS2021_{scan_id}_{mri_type}.nii.gz'
    volume_image = nib.load(volume_file)
    return volume_image.get_fdata().astype('float')


class Dataset(torch_data.Dataset):
    """Dataset that yields one volume per case id.

    Each item is a dict with key "X" (the volume from
    ``construct_target_volume`` as a float tensor with a leading axis of
    size 1) plus either "id" (prediction mode) or "y" (the label as a long
    tensor).

    Args:
        ids: Indexable collection of case ids.
        targets: Labels aligned with ``ids``; only read when ``if_pred`` is
            False.
        mri_type: Sequence name forwarded to ``construct_target_volume``.
        if_pred: When True, items carry the case id instead of a label.
    """

    def __init__(self, ids, targets, mri_type, if_pred = False):
        self.ids, self.targets = ids, targets
        self.mri_type, self.if_pred = mri_type, if_pred

    def __len__(self):
        """Number of cases in the dataset."""
        return len(self.ids)

    def __getitem__(self, index):
        """Build the sample dict for the case at position ``index``."""
        scan_id = self.ids[index]
        volume = construct_target_volume(scan_id, self.mri_type, scale_size=SCALE_SIZE)

        sample = {"X": torch.tensor(volume).float().unsqueeze(0)}
        if self.if_pred:
            sample["id"] = scan_id
        else:
            sample["y"] = torch.tensor(self.targets[index], dtype=torch.long)
        return sample

In [12]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def non_0_voxel(voxel):
    """Slice range along the third axis that spans the occupied slices.

    A slice ``voxel[:, :, k]`` counts as occupied when its maximum is not 0.

    Args:
        voxel: 3-D array.

    Returns:
        ``slice(first, last + 1)`` covering the first through the last
        occupied slice, or the full range when no slice is occupied.
    """
    depth = voxel.shape[2]
    occupied = [k for k in range(depth) if np.max(voxel[:, :, k]) != 0]

    if occupied:
        first, last = occupied[0], occupied[-1]
    else:
        # Nothing occupied: keep every slice.
        first, last = 0, depth - 1

    return slice(first, last + 1)

def predict(model_path,X_valid, y_valid,fold,mri_type):
    """Run the best checkpoint of one fold over the given cases.

    Args:
        model_path: Root directory that contains
            ``<mri_type>/<mri_type>-fold<fold>-best.pth``.
        X_valid: Case ids to predict.
        y_valid: Labels of those cases; handed to the dataset, which does not
            read them in prediction mode.
        fold: Fold index used to select the checkpoint.
        mri_type: Sequence name used to select both the data and the
            checkpoint.

    Returns:
        Tuple ``(preds_all, ids_all)``: the softmax probability of class 1
        for every case and the matching case ids, in loader order.
    """
    start = time.time()
    print(f"Predict: {mri_type} Fold:{fold}")

    # Prediction-mode dataset: each item carries the volume and its case id.
    data_retriever = Dataset(
        X_valid,
        y_valid,
        mri_type,
        if_pred = True
    )

    data_loader = torch_data.DataLoader(
        data_retriever,
        batch_size=2,
        shuffle=False,
        num_workers=4
    )
    model = UNet(in_channels=1,
                 out_channels=2,
                 n_blocks=4,
                 input_shape = INPUT_SIZE,
                 start_filters=32,
                 activation='relu',
                 normalization='batch',
                 conv_mode='same',
                 dim=3,
                 hidden_channels=2048)

    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint saved on a GPU still loads when running on the CPU fallback.
    # torch.load unpickles the file: only load checkpoints from a trusted source.
    checkpoint = torch.load(
        f'{model_path}/{mri_type}/{mri_type}-fold{fold}-best.pth',
        map_location=device,
    )
    model.load_state_dict(checkpoint["model_state_dict"])

    model.to(device)
    model.eval()

    ids_all = []
    preds_all = []
    n_batches = len(data_loader)

    # Inference only: no gradients are needed for any batch.
    with torch.no_grad():
        for step, batch in enumerate(data_loader, 1):
            print(f"{step}/{n_batches}", end="\r")
            outputs = model(batch["X"].to(device))

            # Probability assigned to class index 1.
            preds = outputs.softmax(dim=1)[:,1]
            ids_all.extend(batch["id"])
            preds_all.extend(preds.tolist())

    del model
    print(f'Inference completed in {int(time.time()-start)}s')

    return preds_all,ids_all

In [21]:
# Sequences for which predictions are generated.
mri_types_capital = ['flair', 't1', 't1ce', 't2']

# Shared prefixes of every absolute location used below.
_PROJECT_ROOT = '/mnt/24CC5B14CC5ADF9A/Brain_Tumor_Classification'
_DATASETS_DIR = f'{_PROJECT_ROOT}/Datasets'

# Pre-processed 90x90x90 volumes, one directory per region variant
# (presumably whole brain / whole tumour / tumour core / enhancing tumour -- confirm).
WB_PATH = f'{_DATASETS_DIR}/Data_WB_90_90_90'
WT_PATH = f'{_DATASETS_DIR}/Data_WT_90_90_90'
TC_PATH = f'{_DATASETS_DIR}/Data_TC_90_90_90'
ET_PATH = f'{_DATASETS_DIR}/Data_ET_90_90_90'

# NOTE(review): checkpoints come from the ET_90_90_90 model directory while
# construct_target_volume reads its volumes from TC_PATH -- confirm that the
# data variant and the model variant are meant to differ.
model_path = f'{_PROJECT_ROOT}/jupyter/Code/U-Net/Model/ET_90_90_90'
RAW_DATA_PATH = f'{_DATASETS_DIR}/Data_Prep_Segmentation'
MASKS_PATH = f'{_DATASETS_DIR}/Data_Masks_Train'
LABEL_PATH = './train_labels.csv'

SCALE_SIZE = 90
SEED = 42

In [22]:
# Shape handed to UNet as `input_shape` inside predict().
INPUT_SIZE = (90,90,90)
# Seed all random number generators before the folds are built.
set_seed(SEED)

# X: case ids, y: labels, SPLIT: list of (train_idx, valid_idx) pairs, one per fold.
X,y,SPLIT = get_train_valid_split(LABEL_PATH)

In [23]:
def get_results(folds,mri_types,output_dir='./predictions/ET_90_90_90'):
    """Predict the validation cases of every fold and store them as CSV files.

    For each sequence and fold, the fold's validation ids are taken from the
    module-level ``SPLIT``/``X``/``y``, scored with ``predict`` and written to
    ``<output_dir>/<mri_type>/<mri_type>_fold<fold>.csv`` with the columns
    'BraTS21ID' and 'MGMT_value'.

    Args:
        folds: Iterable of fold indices into ``SPLIT``.
        mri_types: Iterable of sequence names.
        output_dir: Root directory for the CSV files; missing directories are
            created.

    Returns:
        DataFrame with all predictions stacked, with the extra columns
        'mri_type' and 'fold' identifying where each row came from. Empty
        (but with those columns) when there was nothing to predict.
    """
    collected = []
    for mri_type in mri_types:
        # to_csv does not create missing directories.
        type_dir = f'{output_dir}/{mri_type}'
        os.makedirs(type_dir, exist_ok=True)

        for fold in folds:
            # Only the validation half of the split is scored.
            _, valid_idx = SPLIT[fold]
            X_valid, y_valid = X[valid_idx], y[valid_idx]
            preds, ids = predict(model_path, X_valid, y_valid, fold, mri_type)

            result_df = pd.DataFrame({'BraTS21ID': ids, 'MGMT_value': preds})
            result_df.to_csv(f'{type_dir}/{mri_type}_fold{fold}.csv', index=False)

            collected.append(result_df.assign(mri_type=mri_type, fold=fold))

    if not collected:
        return pd.DataFrame(columns=['BraTS21ID', 'MGMT_value', 'mri_type', 'fold'])
    return pd.concat(collected, ignore_index=True)

In [24]:
# Run inference for fold indices 0-4 of every listed sequence.
folds = range(5)
# Same list as the one assigned in the configuration cell, repeated here.
mri_types_capital = ['flair','t1','t1ce','t2']
# NOTE(review): `df` is bound to whatever get_results returns -- confirm it
# yields a value before relying on `df` further down.
df = get_results(folds,mri_types_capital)

Predict: flair Fold:0
Inference completed in 7s
Predict: flair Fold:1
Inference completed in 7s
Predict: flair Fold:2
Inference completed in 7s
Predict: flair Fold:3
Inference completed in 7s
Predict: flair Fold:4
Inference completed in 7s
Predict: t1 Fold:0
Inference completed in 7s
Predict: t1 Fold:1
Inference completed in 7s
Predict: t1 Fold:2
Inference completed in 7s
Predict: t1 Fold:3
Inference completed in 7s
Predict: t1 Fold:4
Inference completed in 7s
Predict: t1ce Fold:0
Inference completed in 7s
Predict: t1ce Fold:1
Inference completed in 7s
Predict: t1ce Fold:2
Inference completed in 7s
Predict: t1ce Fold:3
Inference completed in 7s
Predict: t1ce Fold:4
Inference completed in 7s
Predict: t2 Fold:0
Inference completed in 7s
Predict: t2 Fold:1
Inference completed in 7s
Predict: t2 Fold:2
Inference completed in 7s
Predict: t2 Fold:3
Inference completed in 7s
Predict: t2 Fold:4
Inference completed in 5s
