In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import os
import pandas as pd
base_path = "../rsna-miccai-brain-tumor-radiogenomic-classification"
dataframe = pd.read_csv(os.path.join(base_path, 'train_labels.csv'))

In [3]:
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
from mgmt_conf.preprocessing.utils import normalize_intensity, clahe, random_noise, random_rotate, random_flip, load_complete_mri
from scipy import ndimage
from abc import abstractmethod
import nibabel as nib
from typing import Union, Tuple
from collections import OrderedDict

class _BaseMGMTDataset(Dataset):
    def __init__(self, base_path, dataframe, image_size, depth, augment) -> None:
        self.base_path = base_path
        self.dataframe = dataframe
        self.image_size = image_size
        self.depth = depth
        self.augment = augment

    def _get_patient_id(self, idx) -> str:
        return str(self.dataframe["BraTS21ID"].iloc[idx]).zfill(5)
    
    def _get_target(self, idx) -> int:
        return self.dataframe["MGMT_value"].iloc[idx]

    @staticmethod
    def _crop_on_nonzero_voxels(x):
        argwhere = np.argwhere(x)
        min_z, min_x, min_y = (
            np.min(argwhere[:, 0]),
            np.min(argwhere[:, 1]),
            np.min(argwhere[:, 2]),
        )
        max_z, max_x, max_y = (
            np.max(argwhere[:, 0]),
            np.max(argwhere[:, 1]),
            np.max(argwhere[:, 2]),
        )

        x = x[min_z:max_z, min_x:max_x, min_y:max_y]
        
        return x
        
    def _preprocess(self, x: np.ndarray) -> np.ndarray:
        [height, width, depth] = x.shape
        scale = [
            self.image_size[0] * 1.0 / height,
            self.image_size[1] * 1.0 / width,
            self.depth * 1 / depth,
        ]
        x = ndimage.zoom(x, scale, order=3)

        x = clahe(x)
        x = normalize_intensity(x)
        if self.augment:
            x = random_noise(x)

        x = np.expand_dims(x, 0)

        return x

    @abstractmethod
    def __getitem__(self, idx):
        raise NotImplementedError


    def __len__(self):
        return len(self.dataframe)



class _BaseMGMTNiftiDataset(_BaseMGMTDataset):
    conversion_dict = {"FLAIR": "flair", "T1wCE": "t1ce", "T1w": "t1", "T2w": "t2"}

    def __init__(self, base_path, dataframe, tumor_centered, image_size, depth, augment) -> None:
        super(_BaseMGMTNiftiDataset, self).__init__(base_path=base_path, dataframe=dataframe, image_size=image_size, depth=depth, augment=augment)
        self.tumor_centered = tumor_centered

    def _prepare_mask(self, patient_id):
        mask = nib.load(
            f"{self.base_path}/BraTS2021_{patient_id}/BraTS2021_{patient_id}_seg.nii.gz"
        ).get_fdata()
        argwhere = np.argwhere(mask)
        min_x, min_y, min_z = (
            np.min(argwhere[:, 0]),
            np.min(argwhere[:, 1]),
            np.min(argwhere[:, 2]),
        )
        max_x, max_y, max_z = (
            np.max(argwhere[:, 0]),
            np.max(argwhere[:, 1]),
            np.max(argwhere[:, 2]),
        )
        self.mask = (min_x, min_y, min_z, max_x, max_y, max_z)

    def _prepare_nifti_volume(self, patient_id, modality):
        x = nib.load(
            f"{self.base_path}/BraTS2021_{patient_id}/BraTS2021_{patient_id}_{self.conversion_dict[modality]}.nii.gz"
        ).get_fdata()

        
        if self.tumor_centered:
            self._prepare_mask(patient_id)
            x = x[
            self.mask[0] : self.mask[3],
            self.mask[1] : self.mask[4],
            self.mask[2] : self.mask[5],
            ]


        else:
            x = super()._crop_on_nonzero_voxels(x)
            
        if self.augment:
            x = random_rotate(x)
            p = np.random.rand(1)[0]
            if p > 0.5:
                x = random_flip(x)
        
        return x
    
class UnimodalMGMTNiftiDataset(_BaseMGMTNiftiDataset):
    def __init__(self, base_path, dataframe, modality, tumor_centered, image_size, depth, augment) -> None:
        super(UnimodalMGMTNiftiDataset, self).__init__(base_path=base_path, dataframe=dataframe, tumor_centered=tumor_centered, image_size=image_size, depth=depth, augment=augment)
        self.modality = modality

    def __getitem__(self, idx):
        patient_id = super()._get_patient_id(idx)
        x = super()._prepare_nifti_volume(patient_id, self.modality)
        x = super()._preprocess(x)

        target = super()._get_target(idx)

        return torch.tensor(x).float(), torch.tensor(target)

class MultimodalMGMTNiftiDataset(_BaseMGMTNiftiDataset):
    def __init__(self, base_path, dataframe, modalities, fusion, tumor_centered, image_size, depth, augment) -> None:
        super(MultimodalMGMTNiftiDataset, self).__init__(base_path=base_path, dataframe=dataframe, tumor_centered=tumor_centered, image_size=image_size, depth=depth, augment=augment)
        self.modalities = modalities
        self.fusion = fusion

    def __getitem__(self, idx):
        patient_id = super()._get_patient_id(idx)
        target = super()._get_target(idx)

        item = OrderedDict()
        if self.fusion == "early":
            seed = np.random.randint(0, 10000) + idx
        for modality in self.modalities:
            if self.fusion == "early":
                np.random.seed(seed)
            x = super()._prepare_nifti_volume(patient_id, modality)
            x = super()._preprocess(x)
            item[modality] = torch.tensor(x).float()

        if self.fusion == "early":
            return torch.cat(tuple(item.values()), 0), torch.tensor(target)

        else:
            return item, torch.tensor(target)

In [104]:
class _BaseMGMTDicomDataset(_BaseMGMTDataset):
    def __init__(self, base_path, dataframe, image_size, depth, split, augment) -> None:
        super(_BaseMGMTDicomDataset, self).__init__(base_path=base_path, dataframe=dataframe, image_size=image_size, depth=depth, augment=augment)
        self.split = split

    def _prepare_dicom_volume(self, idx, modality):
        patient_id, path_to_split = self._get_patient_id_and_path_to_split(idx)
        x = load_complete_mri(path_to_split, patient_id, modality, target_size=None, voi_lut=False, rotate=0, scale=False)
        x = super()._crop_on_nonzero_voxels(x)

        if self.augment:
            x = random_rotate(x)
            p = np.random.rand(1)[0]
            if p > 0.5:
                x = random_flip(x)
        return x
        

    def _get_patient_id_and_path_to_split(self, idx) -> Tuple[str, str]:
        return str(self.dataframe["BraTS21ID"].iloc[idx]).zfill(5), os.path.join(
            self.base_path, self.split
        )

class UnimodalMGMTDicomDataset(_BaseMGMTDicomDataset):
    def __init__(self, base_path, dataframe, modality, image_size, depth, split, augment) -> None:
        super(UnimodalMGMTDicomDataset, self).__init__(base_path=base_path, dataframe=dataframe, image_size=image_size, depth=depth, split=split, augment=augment)
        self.modality = modality

    def __getitem__(self, idx):
        x = super()._prepare_dicom_volume(idx, self.modality)
        x = super()._preprocess(x)

        target = super()._get_target(idx)

        return torch.tensor(x).float(), torch.tensor(target)

class MultimodalMGMTDicomDataset(_BaseMGMTDicomDataset):
    def __init__(self, base_path, dataframe, modalities, image_size, depth, split, augment) -> None:
        super(MultimodalMGMTDicomDataset, self).__init__(base_path=base_path, dataframe=dataframe, image_size=image_size, depth=depth, split=split, augment=augment)
        self.modalities = modalities
    
    def __getitem__(self, idx):
        item = OrderedDict()
        target = super()._get_target(idx)
        for modality in self.modalities:
            x = super()._prepare_dicom_volume(idx, modality)
            x = super()._preprocess(x)
            item[modality] = torch.tensor(x).float()

        else:
            return item, torch.tensor(target)


In [105]:
data = MultimodalMGMTDicomDataset(base_path=base_path, dataframe=dataframe, modalities=("FLAIR", "T1wCE"), image_size=(180, 180), depth=64, split='train', augment=False)
img, target = data[2]

In [106]:
img["FLAIR"].size()

torch.Size([1, 180, 180, 64])

In [65]:
fig = plt.figure(figsize=(12, 8))
ims = []
i = 1
data = np.concatenate([img["FLAIR"].squeeze().numpy(), img["T1wCE"].squeeze().numpy()], 1)
for snapshot in range(0, data.shape[-1]): 
    im = plt.imshow(data[:, :, snapshot], animated=True, cmap='gray')
    plt.axis("off")
    ims.append([im])


ani = animation.ArtistAnimation(fig, ims, interval=100, blit=False, repeat_delay=1000)
plt.close()

HTML(ani.to_jshtml())

In [66]:
torch.unique(img["FLAIR"])

tensor([-2.3093, -2.2444, -2.2100,  ...,  4.0669,  4.0688,  4.0898])

ONCOPOLE DATA

In [4]:
class _BaseMGMTPrivateDataset(_BaseMGMTDataset):
    conversion_dict = {
        "T1 gado": "T1wCE",
        "T1 Gado": "T1wCE",
        "T1 Gado post op": "T1wCE",
        "T1 Gado postop": "T1wCE",
        "T1 3D gado": "T1wCE",
        "3D T1 Gado": "T1wCE",
        "Ax 3D T1 Gado": "T1wCE",
        "3D T1 gado": "T1wCE",
        "Ax T1 gado": "T1wCE",
        "T1 3D Gado": "T1wCE",
        "3D t1 gado": "T1wCE",
        "T2 flair": "FLAIR",
        "Ax T2 flair": "FLAIR",
        "Ax T2 Flair": "FLAIR",
        "Ax Flair": "FLAIR",
        "Flair Ax": "FLAIR",
        "T2 Flair Ax": "FLAIR",
        "Flair Long": "FLAIR",
        "Flair long": "FLAIR",
        "3D flair": "FLAIR",
        "3D Flair": "FLAIR",
        "Flair": "FLAIR",
        "T2 Flair": "FLAIR",
        "Flair post op": "FLAIR",
        "Flair postop": "FLAIR"
    }

    def __init__(self, base_path, dataframe, image_size, depth) -> None:
        super(_BaseMGMTPrivateDataset, self).__init__(base_path=base_path, dataframe=dataframe, image_size=image_size, depth=depth, augment=False)

    def _get_patient_id(self, idx) -> str:
        return self.dataframe.iloc[idx].NNNPP
    
    def _get_target(self, idx) -> int:
        return self.dataframe.iloc[idx]["% Méthylé"]

    def _prepare_private_volume(self, patient_id, modality):

        if modality == "FLAIR":
            n = "FL"
        else:
            n = "T1CE"

        x = nib.load(
            f"{self.base_path}/{patient_id}/{n}_to_SRI_brain.nii.gz"
        ).get_fdata()   

        x = super()._crop_on_nonzero_voxels(x)
        return x
        
class UnimodalMGMTPrivateDataset(_BaseMGMTPrivateDataset):
    def __init__(self, base_path, dataframe, modality, image_size, depth) -> None:
        super(UnimodalMGMTPrivateDataset, self).__init__(base_path=base_path, dataframe=dataframe, image_size=image_size, depth=depth)
        self.modality = modality

    def __getitem__(self, idx):
        patient_id = super()._get_patient_id(idx)
        print(patient_id)
        x = super()._prepare_private_volume(patient_id, self.modality)
        x = super()._preprocess(x)

        target = super()._get_target(idx)

        return torch.tensor(x).float(), torch.tensor(target)



In [19]:
class UnimodalTumorCenteredMGMTPrivateDataset(UnimodalMGMTPrivateDataset):
    def __init__(self, base_path, dataframe, modality):
        super(UnimodalTumorCenteredMGMTPrivateDataset, self).__init__(base_path=base_path, dataframe=dataframe, modality=modality, image_size=None, depth=None)

    def __getitem__(self, idx):
        patient_id = super()._get_patient_id(idx)
        target = super()._get_target(idx)
        return torch.load(os.path.join(self.base_path, patient_id, f"{self.modality.lower()}.pt")), torch.tensor(target)

class MultimodalMGMTPrivateDataset(_BaseMGMTPrivateDataset):
    def __init__(self, base_path, dataframe, modalities, fusion, image_size, depth) -> None:
        super(MultimodalMGMTPrivateDataset, self).__init__(base_path=base_path, dataframe=dataframe, image_size=image_size, depth=depth)
        self.modalities = modalities
        self.fusion = fusion

    def __getitem__(self, idx):
        patient_id = super()._get_patient_id(idx)
        target = super()._get_target(idx)

        item = OrderedDict()

        for modality in self.modalities:
            x = super()._prepare_private_volume(patient_id, modality)
            x = super()._preprocess(x)
            item[modality] = torch.tensor(x).float()

        if self.fusion == "early":
            return torch.cat(tuple(item.values()), 0), torch.tensor(target)

        else:
            return item, torch.tensor(target)

class MultimodalTumorCenteredMGMTPrivateDataset(MultimodalMGMTPrivateDataset):
    def __init__(self, base_path, dataframe, modalities):
        super(MultimodalTumorCenteredMGMTPrivateDataset, self).__init__(base_path=base_path, dataframe=dataframe, modality=modalities, image_size=None, depth=None)

    def __getitem__(self, idx):
        patient_id = super()._get_patient_id(idx)
        target = super()._get_target(idx)
        x = list()
        for modality in self.modalities:
            x.append(torch.load(os.path.join(self.base_path, patient_id, f"{modality.lower()}.pt")))

        return torch.stack(x), torch.tensor(target)

In [20]:
def create_df(folder, only_positive=False):

    df_path_methylated = os.path.join(folder, "listing_all_but_non_methylated.xlsx")
    df_m = pd.read_excel(df_path_methylated, sheet_name=2, skiprows=[1], header=1)
    df_m = df_m[["NNNPP", "% Méthylé"]]
    df_m = df_m.drop(0)
    df_m = df_m.drop(28)

    if only_positive:
        return df_m
    
    
    df_path_unmethylated = os.path.join(folder, "listing_non_methylated.xlsx")
    df_u = pd.read_excel(df_path_unmethylated, skiprows=0, header=2)
    df_u = df_u[["NNNPP"]]
    df_u = df_u.drop(0)
    df_u["% Méthylé"] = np.zeros(len(df_u))
    df = pd.concat([df_m, df_u], ignore_index=True)
    return df

In [21]:
path_oncopole = "../onco_data/"
df = create_df(path_oncopole)

In [39]:
#data = UnimodalMGMTPrivateDataset(os.path.join(path_oncopole, "IRMs"), df, ("T1wCE", "FLAIR"), 'early', (180, 180), 64)
dataset = UnimodalTumorCenteredMGMTPrivateDataset(os.path.join(path_oncopole, "IRMs"), df, "T1wCE")

In [45]:
img, t = dataset[56]

In [46]:
torch.unique(img)

tensor([-1.2199, -1.2153, -1.2069,  ...,  5.8121,  5.8628,  5.8889])

In [47]:
img.shape

torch.Size([77, 40, 36])

In [48]:
fig = plt.figure(figsize=(12, 8))
ims = []
i = 1
#data = np.concatenate([img[0].numpy(), img[1].numpy()], 1)
data = img.numpy()
for snapshot in range(0, data.shape[-1]): 
    im = plt.imshow(data[:, :, snapshot], animated=True, cmap='gray')
    plt.axis("off")
    ims.append([im])


ani = animation.ArtistAnimation(fig, ims, interval=100, blit=False, repeat_delay=1000)
plt.close()

HTML(ani.to_jshtml())

In [18]:
torch.unique(img)

tensor([-2.2947, -2.2130, -2.1730,  ...,  5.9426,  6.2805,  6.7016])

# SPLITTING

In [1]:
import pandas as pd

In [6]:
dataframe = pd.read_csv(os.path.join("../rsna-miccai-brain-tumor-radiogenomic-classification", 'train_labels.csv'))
dataframe_train, dataframe_val = clean_segmentation_dataframe(dataframe, random_state=42)

In [5]:
from typing import Optional, List
from sklearn.model_selection import train_test_split


def clean_segmentation_dataframe(
    dataframe: pd.DataFrame, random_state: int = 42, init: bool = False
):
    """
    Clean dataset with segmentation mask to be respectful to original implementation.

    Args:
        dataframe (pd.DataFrame) : the original dataframe from the Kaggle competition.
        init (bool) : Whether or not to remove all files that are not intertasks compatible. Must be done once after downloading the data.. Default to `False`.

    Returns:
        pd.DataFrame : the train dataframe
        pd.DataFrame : the val dataframe
    """
    dataframe = clean_dataset(dataframe)
    #cleans = [308, 197, 169, 794, 998, 564, 408, 245]
    #dataframe = clean_dataset(dataframe, cleans)
    dataframe_train, dataframe_val = train_test_split(
        dataframe,
        test_size=0.3,
        stratify=dataframe["MGMT_value"],
        random_state=random_state,
    )

    clean_train = [794, 998, 564, 408, 245]
    clean_val = [308, 197, 169]
    #cleans = [308, 197, 169, 794, 998, 564, 408, 245]
    dataframe_train = clean_dataset(dataframe_train, clean_train)
    dataframe_val = clean_dataset(dataframe_val, clean_val)

    if init:
        filenames = os.listdir(
            "../rsna-miccai-brain-tumor-radiogenomic-classification/archive"
        )
        filenames = [filename.split("_")[-1] for filename in filenames]
        filenames = sorted(filenames)
        current = list(dataframe["BraTS21ID"].values)
        current = [str(o).zfill(5) for o in current]

        list_to_delete = []
        for o in filenames:
            if o not in current:
                list_to_delete.append(o)

        for el in list_to_delete[:-11]:
            os.system(
                "rm -r"
                f" ../rsna-miccai-brain-tumor-radiogenomic-classification/archive/BraTS2021_{el}"
            )

    return dataframe_train, dataframe_val


def clean_dataset(
    dataframe: pd.DataFrame, ids: List[int] = [109, 123, 709]
) -> pd.DataFrame:
    """
    Returns the cleaned dataframe according to the given instruction https://www.kaggle.com/competitions/rsna-miccai-brain-tumor-radiogenomic-classification/data .

    Args:
        dataframe (pd.DataFrame) : the initial dataframe.
        ids (List[int]) : index to be removed. Default to `[109, 123, 709]`.

    Returns:
        pd.DataFrame : the cleaned dataframe.
    """
    return dataframe[~dataframe["BraTS21ID"].isin(ids)]



In [8]:
dataframe_bis = pd.read_csv(os.path.join("../rsna-miccai-brain-tumor-radiogenomic-classification", 'train_labels.csv'))
dataframe_bis = clean_dataset(dataframe_bis)
dataframe_train_bis, dataframe_val_bis = train_test_split(dataframe, test_size=0.3, stratify=dataframe["MGMT_value"], random_state=42)

In [11]:
idx_val = dataframe_val.BraTS21ID.values
idx_val_bis = dataframe_val_bis.BraTS21ID.values

In [12]:
idx_val, idx_val_bis

(array([ 698,  478,  723,  619,  799,  221,  607,  403,  275,  517,  697,
         568,  128,  611,  386,  613,  524,   30,   24,  736,  338,  594,
         165,   11,  454,  122,  494,  836,  316,  160,  567,  659,  172,
         472,  688,  481,  751,  299,  146,  550,  663, 1010,  810,  139,
         667,   48,  383,  628,   60,  344,  309,  241,   78,  559,  387,
         588,   46,  513,  650,  758,  807, 1001,  157,  703,  649,  378,
         219,  405,  737,  413,  623,  510,  201,  377,  705,  718,  259,
          96,  645,  523,  571,   98,   53,  615,  565,  488,  456,  192,
         778,  262,  267,  652,  294,  332,  554, 1000,  149,  445,  446,
         414,  556,  704,  838,   77,  604,  764, 1002,   90, 1008,  621,
         511,  496,  808,    9,  558,  756,  676,  775,  759,  138,  470,
          72,    6,   66,  162,  537, 1009,   63,   94,  525,   58,  493,
         298,  753,  140,  247,  819,  382,  260,   18,  418,  801,  626,
         191,  730,  464,  767,   49, 

In [13]:
def prepare_dataframes(csv_path: str, random_state: int) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Returns both validation and training dataframe.

    Args:
        csv_path (str) : path leading to the csv file.
        random_state (int) : to ensure reproducibility.
    
    Returns:
        pd.DataFrame : the train dataframe.
        pd.DataFrame : the val dataframe.
    """
    dataframe = pd.read_csv(csv_path)
    dataframe = clean_dataset(dataframe, [109, 123, 709, 794, 998, 564, 408, 245, 308, 197, 169])
    dataframe_train, dataframe_val = train_test_split(
        dataframe,
        test_size=0.3,
        stratify=dataframe["MGMT_value"],
        random_state=random_state,
    )

    return dataframe_train, dataframe_val


NameError: name 'Tuple' is not defined

In [14]:
import os

In [16]:
if not os.path.isdir("regular"):
    os.makedirs("regular")