<a href="https://www.kaggle.com/code/yannicksteph/cnn-cv-brain-prediction?scriptVersionId=143948558" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# | CNN | CV | Brain | Prediction |
## Convolutional Neural Networks (CNN) with Computer Vision (CV) for Brain Prediction
# <b>1 <span style='color:#78D118'>|</span> Introduction</b>

Glioblastoma, the most common and aggressive form of brain cancer in adults, poses significant challenges in diagnosis and treatment. The presence of MGMT promoter methylation in the tumor has been identified as a crucial prognostic factor and an indicator of chemotherapy responsiveness. However, the current genetic analysis of brain cancer requires invasive procedures and time-consuming processes.

The objective of this project is to improve the diagnosis and treatment strategies for glioblastoma patients, minimizing the need for invasive procedures and streamlining the genetic analysis process. By leveraging radiogenomics, the aim is to develop a non-invasive method to predict the genetic profile of the tumor solely through imaging.

To address these issues, the Radiological Society of North America (RSNA) and the Medical Image Computing and Computer Assisted Intervention Society (MICCAI Society) have collaborated on a competition focusing on glioblastoma diagnosis and treatment planning. The competition involves using MRI scans to develop a model that can accurately predict the genetic subtype of glioblastoma by detecting the presence of MGMT promoter methylation.

Successful outcomes from this competition will contribute to less invasive diagnostic procedures and more tailored treatment approaches for brain cancer patients. This abstract provides an overview of the project's objectives, the competition's context, and the potential impact on the management and survival rates of individuals affected by glioblastoma.

## Dataset Overview

The Radiological Society of North America (RSNA®) is a non-profit organization representing 31 radiologic subspecialties from 145 countries worldwide. RSNA promotes excellence in patient care and healthcare delivery through education, research, and technological innovation.

RSNA provides high-quality educational resources, publishes five top peer-reviewed journals, hosts the world's largest radiology conference, and is dedicated to shaping the future of the profession through the RSNA Research & Education (R&E) Foundation, which has funded $66 million in grants since its establishment. Additionally, RSNA actively supports and facilitates research in medical imaging artificial intelligence (AI) by sponsoring ongoing AI challenge competitions.

The Medical Image Computing and Computer Assisted Intervention Society (MICCAI Society) is committed to advancing research, education, and practice in the field of medical image computing, computer-assisted interventions, biomedical imaging, and medical robotics. The society achieves this objective by organizing high-quality international conferences, workshops, tutorials, and publications that promote the exchange and dissemination of advanced knowledge, expertise, and experiences produced by leading institutions, scientists, physicians, and educators worldwide.

A complete list of acknowledgments can be found on this page.

[RSNA-MICCAI Brain Tumor Radiogenomic Classification](https://www.kaggle.com/competitions/rsna-miccai-brain-tumor-radiogenomic-classification/data?select=train_labels.csv)

## Research Efforts
During the course of this project, we conducted extensive research to explore various methodologies for predicting the genetic subtype of glioblastoma based on MGMT promoter methylation. One of the methods we investigated was the use of the Unit-net architecture, a convolutional neural network designed specifically for medical image analysis.

[| UNIT-NET | CV | BRAIN | Classification |](https://www.kaggle.com/code/yannicksteph/rsna-miccai-brain-tumor-classification)

However, despite our efforts, we did not achieve conclusive results with the Unit-net model. The complexity and variability of glioblastoma tumors, as well as the limited availability of labeled data, presented significant challenges in training an effective Unit-net model for this task. As a result, we decided to pursue alternative approaches that showed more promise in accurately predicting the genetic subtype of glioblastoma.

## Objectives
- Predict the genetic subtype of glioblastoma by estimating the presence of MGMT promoter methylation as a value between 0 and 1.

## References and Research Sources
- MGMT
    - [Is it Possible to Predict MGMT Promoter Methylation from Brain Tumor MRI Scans using Deep Learning Models](https://arxiv.org/abs/2201.06086)
    - [Automatic Prediction of MGMT Status in Glioblastoma via Deep Learning-Based MR Image Analysis](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7530505/)
    - [MRI-Based Deep-Learning Method for Determining Glioma MGMT Promoter Methylation Status](https://www.ajnr.org/content/42/5/845.abstract)
    - [Improving MGMT methylation status prediction of glioblastoma through optimizing radiomics features using genetic algorithm-based machine learning approach](https://www.nature.com/articles/s41598-022-17707-w)
- Data augmentation
    - [Data augmentation for deep learning based accelerated MRI reconstruction with limited data](https://arxiv.org/abs/2106.14947)
    - [Data augmentation: how to overcome small radiology datasets](https://www.quantib.com/blog/image-augmentation-how-to-overcome-small-radiology-datasets?hs_amp=true)



## Implementation
To achieve the aforementioned objectives, we will follow these steps:

- **Training Setup Explained**
- **Setup**
- **Data Retrieval**
- **Data Preparation** 
- **Model Creation** 
- **Model Training** 
- **Model Evaluation** 

# <b>2 <span style='color:#78D118'>|</span> Training Setup Explained</b>

## Cross-Validation

- **Splitting Method**
    - **Type:** Stratified K-Fold Cross-Validation
        - **Explanation:** Employed to ensure a balanced class distribution within each fold, enhancing model assessment and generalization.
        - **Configuration:**
            - **Number of Folds:** Typically set to 5.
            - **Validation Fold:** Specifically, the first fold is designated as the validation fold in this case.
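
To make the splitting scheme concrete, here is a minimal sketch of stratified fold assignment in plain Python. It is not the `stratifiedTrainValidSplit` helper from SKit used in this notebook, whose internals are not shown here. The function name `stratified_folds` and the 60/40 toy labels are assumptions for illustration only. In practice, scikit-learn's `StratifiedKFold` is the usual library choice.

```python
import random
from collections import defaultdict


def stratified_folds(labels, num_folds=5, seed=123):
    """Assign sample indices to folds so each fold keeps the class ratio."""
    rng = random.Random(seed)

    # Group the sample indices by class label
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    # Shuffle each class separately, then deal its indices round-robin
    folds = [[] for _ in range(num_folds)]
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        for position, index in enumerate(indices):
            folds[position % num_folds].append(index)
    return folds


# Toy labels (hypothetical class balance): 60 negatives, 40 positives
labels = [0] * 60 + [1] * 40
folds = stratified_folds(labels, num_folds=5, seed=123)

# The first fold is held out for validation, the rest is used for training
validation_ids = folds[0]
train_ids = [index for fold in folds[1:] for index in fold]
```

With these toy labels, every fold holds 12 negatives and 8 positives, so the validation fold mirrors the 60/40 ratio of the full set.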

## Processing

- **Width, Height, Channels:**
    - Images are first resized to a uniform 128x128 pixels. MRI slices converted to images are grayscale, so each image has a single channel.
    - It's important to note that due to Kaggle limitations, we cannot further increase the image size to 224x224 pixels.
- **Sequence:**
    - Each input consists of a sequence of 32 consecutive images; inputs are processed in batches.
    - This sequence captures the depth (slice) dimension of the MRI volume and allows the model to analyze a series of images to make predictions.
- **Scale:**
    - During preprocessing, images are scaled down to 85% of their original size. 
    - This scaling operation is applied to both the test and validation datasets.
    - For the training dataset, we apply augmentation instead, as described in the Augmentation section below.
    - The purpose of this scaling is to remove any empty space around the brain border, ensuring that the model focuses primarily on the brain region itself.
- **Central Focus:**
    - In our dataset processing pipeline, we place a strong emphasis on the central slices, where the region of interest (ROI) lies.
    - Specifically, we prioritize the center image and include 16 images before and 16 images after the central image in each sequence. 
    - This approach ensures that the most informative parts of the MRI scans, corresponding to the central brain area and the ROI, receive the most attention during model training.
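
The central-focus idea can be sketched as an index selection over the ordered slices of one scan. This is only an illustration under assumptions: the real selection is done inside SKit's `DICOMLoader`, whose code is not shown here. The function name `central_slice_indices` and the handling of scans shorter than the window are hypothetical.

```python
def central_slice_indices(num_slices, window=32):
    """Return the indices of `window` consecutive slices around the middle slice."""
    # Assumption: a scan with fewer slices than the window is kept whole
    if num_slices <= window:
        return list(range(num_slices))

    center = num_slices // 2
    start = center - window // 2
    return list(range(start, start + window))


# A scan with 100 slices keeps the 32 slices around index 50
selected = central_slice_indices(100, window=32)
```

For a 100-slice scan this keeps indices 34 to 65, dropping the first and last slices, which mostly show the edges of the head rather than the central brain area.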

## Augmentation

- **Data Augmentation**
    - We add augmented copies amounting to 400% of the original training set (`AUGMENTATION_FRACTION = 4`), so the training data grows to five times its original size. This is essential as the original dataset consists of only approximately 500 samples, with the validation set containing just 100 samples.
- **Crop Augmentation**
    - Images are randomly cropped while preserving between 85% and 95% of their original size. This enhances the model's ability to recognize different brain regions.
- **Rotation Augmentation**
    - Random rotations ranging from 4 to 12 degrees aid the model in becoming orientation-invariant, allowing it to handle variations in the orientation of brain scans effectively.
- **Translation Augmentation**
    - Random translations are applied both horizontally (2 to 6 pixels) and vertically (0 to 2 pixels), simulating minor positional variations commonly encountered in medical imaging.
- **Blur Augmentation**
    - A Gaussian blur with a 3x3 kernel is applied, with its strength (sigma) drawn at random between 0 and 0.15, mimicking real-world imaging imperfections and improving the model's generalization.
- **Contrast and Brightness Augmentation**
    - Image contrast is dynamically scaled between 0.8 and 1.2, while brightness is adjusted between -2 and 2. This adaptation accommodates varying lighting conditions, making the model more robust to different lighting scenarios.
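
As a quick sanity check on the resulting dataset size, the arithmetic can be sketched as follows. It assumes, as the loader code later in this notebook does, that augmented copies are appended to the untouched original samples. The helper name `augmented_dataset_size` and the round figure of 500 samples are illustrative, and the fractional part is handled with plain arithmetic rather than the loader's first-decimal rule.

```python
def augmented_dataset_size(num_original, fraction):
    """Total number of samples once augmented copies are appended."""
    full_copies = int(fraction)          # e.g. 4 -> four full augmented copies
    remainder = fraction - full_copies   # e.g. 4.5 -> an extra 50% of the rows
    num_augmented = full_copies * num_original + int(num_original * remainder)
    return num_original + num_augmented


# About 500 original samples with an augmentation fraction of 4
total_samples = augmented_dataset_size(500, 4)
```

With a fraction of 4, roughly 500 original samples become 2,500 training samples: the 500 originals plus 2,000 augmented copies.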

## Train

- **Batch Size**
    - Training is performed in batches of 8 images at a time, a limitation imposed by Kaggle.
- **Epochs**
    - The model undergoes 32 training epochs.
- **Optimizer**
    - **Type:** Stochastic Gradient Descent (SGD)
        - **Explanation:** SGD iteratively adjusts the model weights, using gradients of the loss computed on batches of training data, to minimize the loss.
        - **Configuration:**
            - **Learning Rate:** A learning rate of 0.001 strikes a balance between convergence speed and stability.
- **Loss Function**
    - **Type:** Binary Cross-Entropy
        - **Explanation:**  Binary Cross-Entropy is used for training, specifically suited for binary classification tasks such as this one.
- **Compilation Metrics**
    - **Type:** Area Under the ROC Curve (AUC)
        - **Explanation:** AUC is employed as a metric, measuring the model's ability to discriminate between positive and negative classes.
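
To make the three quantities above concrete, here is a small plain-Python sketch of the SGD update rule, the binary cross-entropy loss, and the AUC. It is for intuition only and is not how Keras implements them. In particular, the Keras `AUC` metric uses a threshold-based approximation, so its values can differ slightly from the exact pairwise computation below. All function names are illustrative.

```python
import math


def sgd_step(weights, gradients, learning_rate=0.001):
    """One SGD update: move each weight against its gradient."""
    return [w - learning_rate * g for w, g in zip(weights, gradients)]


def binary_cross_entropy(y_true, y_prob, eps=1e-7):
    """Mean of -[y * log(p) + (1 - y) * log(1 - p)] over all samples."""
    total = 0.0
    for y, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(y_true)


def roc_auc(y_true, y_score):
    """Probability that a random positive is scored above a random negative."""
    positives = [s for y, s in zip(y_true, y_score) if y == 1]
    negatives = [s for y, s in zip(y_true, y_score) if y == 0]
    wins = 0.0
    for pos in positives:
        for neg in negatives:
            if pos > neg:
                wins += 1.0
            elif pos == neg:
                wins += 0.5  # ties count for half
    return wins / (len(positives) * len(negatives))


# Toy predictions for four patients (two unmethylated, two methylated)
toy_labels = [0, 0, 1, 1]
toy_scores = [0.1, 0.4, 0.35, 0.8]
toy_loss = binary_cross_entropy(toy_labels, toy_scores)
toy_auc = roc_auc(toy_labels, toy_scores)
```

An AUC of 0.5 corresponds to random ranking and 1.0 to perfect separation. In the toy example, three of the four positive/negative pairs are ranked correctly.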

### Model Architecture

- **DeepScanModel**
    - The Model is a 3D Convolutional Neural Network (CNN) designed specifically for processing sequences of medical images. 
    - It takes four inputs, one per MRI sequence type (FLAIR, T1w, T1wCE, T2w), concatenates them, and processes them to perform binary classification.


# <b>2 <span style='color:#78D118'>|</span> Setup</b>

## <b>2.1 <span style='color:#78D118'>|</span> Imports</b>

In [None]:
!pip install -q pydicom
!pip install -q git+https://github.com/YanSteph/SKit.git

In [None]:
import numpy as np
import pandas as pd
import os
import glob
import random

# Dicom
import pydicom

# Enum
from enum import Enum

# CV
import cv2

# Tensorflow
import tensorflow as tf
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.metrics import AUC
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import Sequence
from tensorflow.keras.models import (
    Model,
    load_model
)
from tensorflow.keras.callbacks import (
    Callback, 
    ModelCheckpoint, 
    EarlyStopping
)
from tensorflow.keras.layers import (
    Input,
    Conv3D,
    BatchNormalization,
    MaxPooling3D,
    MaxPool3D,
    Flatten,
    Dense,
    Dropout,
    Resizing,
    Rescaling,
    RandomFlip,
    RandomRotation,
    concatenate,
    GlobalAveragePooling3D,
    Reshape,
    LeakyReLU,
    ReLU
)

# Keras
import keras
from tensorflow.keras.utils import plot_model


# Skit
from skit.Debug import Debug
from skit.InternalDebug import InternalDebug
from skit.image import average_image_size
from skit.dataset import stratifiedTrainValidSplit
from skit.Summarizable import Summarizable
from skit.tensorflow import configure_gpu_memory
from skit.dicom import (
    DICOMLoader, 
    ImageFormat
)
from skit.utils import (
    ls, 
    mkdir, 
    count_files
)
from skit.show import (
    show_text, 
    show_images, 
    show_donut, 
    show_history, 
    show_best_history,
    show_confusion_matrix, 
    show_histogram
)

## <b>2.2 <span style='color:#78D118'>|</span> Constants</b>

In [None]:
class MRIType(Enum):
    FLAIR = "FLAIR"
    T1w = "T1w"
    T1wCE = "T1wCE"
    T2w = "T2w"
    
class DatasetType(Enum):
    TRAIN = "train"
    VALIDATION = "validation"
    TEST = "test"

In [None]:
# Global
# ----
VERSION         = "V1"
VERBOSITY       = 2
SEED            = 123
SCAN_CATEGORIES = [mri_type.value for mri_type in MRIType]
EXCLUDED_IDS    = [109, 123, 709]

# Paths
# ----
RUN_DIR = './run'
INPUT_PATH = "../input/rsna-miccai-brain-tumor-radiogenomic-classification"

# Train
TRAIN_DATASET_PATH = INPUT_PATH + "/train"
TRAIN_DATASET_DF_DIR = INPUT_PATH + "/train_labels.csv"

# Test
TEST_DATASET_PATH = INPUT_PATH + "/test"
TEST_DATASET_DF_DIR = INPUT_PATH + "/sample_submission.csv"

# Submission
SUBMISSION_DATASET_DF_DIR = '/kaggle/working/submission.csv'

# TF Callback Paths
# ----
LOGS_PATH = f'{RUN_DIR}/logs'
BEST_MODEL_PATH = f'{RUN_DIR}/models'
BEST_MODEL_H5_DIR = f'{BEST_MODEL_PATH}/model_{VERSION}.h5'

# Fold
# ----
NUM_SPLIT_FOLDS = 5
SELECTED_VALIDATION_FOLD = 1

# Dicom Loader
# ----
MAX_THREADS_DICOM_LOADER = 8

# Image
# ----
IMG_WIDTH_SIZE, IMG_HEIGHT_SIZE, IMG_CHAN = (128, 128, 1)
IMG_SIZE = (IMG_WIDTH_SIZE, IMG_HEIGHT_SIZE)

IMG_SEQ                  = 32
IMG_SCALE                = .95
IMG_ROTATE               = 0 
IMG_ENABLE_CENTRAL_FOCUS = True

SHUFFLE    = True

# Augmentation
# ----
AUGMENTATION_FRACTION               = 4
AUGMENTATION_CROP_LIMITS            = (0.85, 0.95)
AUGMENTATION_ROTATION_LIMITS        = (4, 12)
AUGMENTATION_TRANSLATION_X_Y_LIMITS = ((2, 6), (0, 2))
AUGMENTATION_BLUR                   = (0, 0.15)
AUGMENTATION_CONSTRAST_BRIGHT       = ((0.8, 1.2),(-2, 2))

# Model
# ----
INPUT_SHAPE = (IMG_WIDTH_SIZE, IMG_HEIGHT_SIZE, IMG_SEQ, IMG_CHAN) # Format sample: (128, 128, 32, 1)

MODEL_NAME = "Mult3DCNN4Input"
BATCH_SIZE = 8
EPOCHS     = 26

COMPILE_OPTIMIZER = SGD(learning_rate=0.001)
COMPILE_LOSS = 'binary_crossentropy'
COMPILE_METRICS = [AUC(name='auc')]

# TF Callback
# ----
TF_CALL_BACK_BEST_MODEL_MONITOR  = "val_auc"
TF_CALL_BACK_EARLY_STOP_MONITOR  = "auc"
TF_CALL_BACK_EARLY_STOP_PATIENTE = 6

In [None]:
if SEED is not None:
    os.environ["PYTHONHASHSEED"] = str(SEED)
    random.seed(SEED)  # the augmentation code draws from Python's `random` module
    np.random.seed(SEED)
    tf.random.set_seed(SEED)

### <b><span style='color:#78D118'>|</span> Debug</b>

In [None]:
Debug.set_debug_mode(False)

# Train
DEBUG_DICOM_TRAIN_AUGMENTATION       = False
DEBUG_SCANDATASET_TRAIN_AUGMENTATION = False

# Validation
DEBUG_DICOM_VALIDATION       = False
DEBUG_SCANDATASET_VALIDATION = False

# Test
DEBUG_DICOM_TEST       = False
DEBUG_SCANDATASET_TEST = False

## <b>2.3 <span style='color:#78D118'>|</span> Methods</b>

In [None]:
class DICOMLoaderAugmentation(DICOMLoader):
    def __init__(self,
        df,
        input_path,
        scan_categories,
        fraction_augmented                = 0,
        crop_limits                       = None,
        rotation_limits                   = None, 
        translation_x_y_limits            = None, 
        blur_limits                       = None,
        contrast_bright_alpha_beta_limits = None,
        num_imgs                          = None,
        size                              = (224, 224),
        rotate_angle                      = 0,
        enable_center_focus               = False,
        id_column_name                    = "ID",
        label_column_name                 = "Label",
        image_format                      = ImageFormat.WHDC,
        max_threads                       = 8,
        image_file_sorter                 = lambda x: int(x[:-4].split("-")[-1]),
        shuffle                           = False,
        seed                              = None,
        debug_mode                        = False
    ):        
        self.__crop_limits                       = crop_limits
        self.__rotation_limits                   = rotation_limits
        self.__translation_x_y_limits            = translation_x_y_limits
        self.__blur_limits                       = blur_limits
        self.__contrast_bright_alpha_beta_limits = contrast_bright_alpha_beta_limits
        self.__seed                              = seed
        self.__debug                             = InternalDebug(debug_mode=debug_mode)
        self.__fraction_augmented                = fraction_augmented
        df = df.copy()
        df = DICOMLoaderAugmentation.__data_augmentation(df, fraction_augmented, shuffle, self.__debug, seed)
          
        super().__init__(
            df, 
            input_path,
            scan_categories, 
            num_imgs, 
            size, 
            0, # NOTE: We define here the scale for augmented data
            rotate_angle,
            enable_center_focus,
            id_column_name, 
            label_column_name, 
            image_format,
            max_threads,
            image_file_sorter,
            False
        )
        
    # ---------------- #
    # Enum
    # ---------------- #
    
    class OrigineType(Enum):
        ORIGINAL = "ORIGINAL"
        AUGMENTED = "AUGMENTED"

    # ---------------- #
    # Public
    # ---------------- #

    def load_all_scans(
        self,
        row,
        show_progress=True
    ):
        self.__debug.log("== load_all_scans ==")
        
        # Load all scans for the current index
        # ----
        scans_images = super().load_all_scans(row, show_progress)
        
        if self.__fraction_augmented > 0 and self.__is_augmented(row):
            scans_images = self.__augmented(row, scans_images)
       
        return scans_images
    
    # ---------------- #
    # Private
    # ---------------- #
    
    # ---- ---- ---- #
    # Augmentation
    # ---- ---- ---- #
    
    def __augmented(self, row, scans_images):
        self.__debug.log("== augmented ==")
        
        # Default every random parameter to None so the debug log below
        # never references an unbound name when a limit is disabled
        # ----
        crop_random = rotation_random = tx_random = ty_random = None
        blur_random = alpha_random = beta_radom = None

        # Random Crop
        # ----
        if self.__crop_limits is not None:
            min_crop, max_crop = self.__crop_limits

            crop_random = random.uniform(min_crop, max_crop)

        # Random Rotation (random sign, so the rotation goes either way)
        # ----
        if self.__rotation_limits is not None:
            min_rot, max_rot = self.__rotation_limits

            rotation_random = random.uniform(min_rot, max_rot) * random.choice([-1, 1])

        # Random Translation (random sign on each axis)
        # ----
        if self.__translation_x_y_limits is not None:
            tx, ty = self.__translation_x_y_limits

            min_tx, max_tx = tx
            min_ty, max_ty = ty

            tx_random = random.uniform(min_tx, max_tx) * random.choice([-1, 1])
            ty_random = random.uniform(min_ty, max_ty) * random.choice([-1, 1])

        # Random Blur
        # ----
        if self.__blur_limits is not None:
            min_blur, max_blur = self.__blur_limits

            blur_random = random.uniform(min_blur, max_blur)

        # Random Contrast and Brightness
        # ----
        if self.__contrast_bright_alpha_beta_limits is not None:
            alpha, beta = self.__contrast_bright_alpha_beta_limits

            min_alpha, max_alpha = alpha
            min_beta, max_beta   = beta

            alpha_random = random.uniform(min_alpha, max_alpha)
            beta_radom = random.uniform(min_beta, max_beta)
        
        # Log
        # ----
        self.__debug.log(
                f"Generation\n"
                f"- Crop: {crop_random}\n"
                f"- Rotation: {rotation_random}\n"
                f"- translation x: {tx_random}\n"
                f"- translation y: {ty_random}\n"
                f"- Blur: {blur_random}\n"
                f"- Contrast Brightness alpha: {alpha_random}\n"
                f"- Contrast Brightness Beta: {beta_radom}\n"               
        )

        for scan_type, images in scans_images.items():
            
            augmented_images = []
            
            # Format Normalize
            # ----
            images = self.format(images, "normalize")
      
            for image in images:
                # Remove chan
                # ----
                image = np.squeeze(image, axis=-1)
                
                # Crop
                # ----
                if self.__crop_limits is not None:
                    image = self.__crop_img(image, crop_random)
                
                # Resize
                # ----
                image = self._resize_img(image)
                
                # Rotation
                # ----
                if self.__rotation_limits is not None:
                    image = self.__rotation_img(image, rotation_random)
                
                # Translation
                # ----
                if self.__translation_x_y_limits is not None:
                    image = self.__translation_img(image, tx_random, ty_random)
                
                # Blur
                # ----
                if self.__blur_limits is not None:
                    image = self.__blur_img(image, 3, blur_random)
                
                # Contrast and Brightness
                # ----
                if self.__contrast_bright_alpha_beta_limits is not None:
                    image = self.__contrast_and_brightness_img(image, alpha_random, beta_radom)
                  
                # Save
                # ----
                augmented_images.append(image)
        
            # Numpy array
            # ----
            images = np.array(augmented_images)
            
            # Add chan
            # ----
            images = np.expand_dims(images, axis=-1)
            
            # Format Train
            # ----
            images = self.format(images, "default")

            scans_images[scan_type] = images
        
        return scans_images

    def __blur_img(self, image, kernel_size, blur_random):
        return cv2.GaussianBlur(image, (kernel_size, kernel_size), blur_random)

    def __translation_img(self, image, tx, ty):
        height, width = image.shape[:2]
        translation_matrix = np.float32([[1, 0, tx], [0, 1, ty]])
        return cv2.warpAffine(image, translation_matrix, (width, height))
    
    def __rotation_img(self, image, rotation):
        height, width = image.shape[:2]
        center = (width / 2, height / 2)
        rotation_matrix = cv2.getRotationMatrix2D(center, rotation, 1.0)
        return cv2.warpAffine(image, rotation_matrix, (width, height))
    
    def __contrast_and_brightness_img(self, image, alpha, beta):
        return cv2.convertScaleAbs(image, alpha=alpha, beta=beta)
    
    def __crop_img(self, image, crop):
        # Skip if no crop
        # ----
        if crop <= 0:
            return image

        # Calculate the center of the image
        # ----
        center_x, center_y = image.shape[1] / 2, image.shape[0] / 2

        # Calculate the dimensions of the scaled image
        # ----
        width_scaled, height_scaled = image.shape[1] * crop, image.shape[0] * crop

        # Calculate the coordinates for cropping the image
        # ----
        left_x, right_x = center_x - width_scaled / 2, center_x + width_scaled / 2
        top_y, bottom_y = center_y - height_scaled / 2, center_y + height_scaled / 2

        # Crop the image using the calculated coordinates
        # ----
        return image[int(top_y):int(bottom_y), int(left_x):int(right_x)]
    
    def __is_augmented(self, row):
        return self.df.loc[row, "origine_type"] == DICOMLoaderAugmentation.OrigineType.AUGMENTED.value
    
    # ---------------- #
    # Private Static
    # ---------------- #
   
    @staticmethod
    def __data_augmentation(df, percentage, shuffle, debug, seed=None):
        if percentage <= 0:
            return df
        debug.log("== __data_augmentation ==")
        
        df["origine_type"] = DICOMLoaderAugmentation.OrigineType.ORIGINAL.value
        
        # Calculate the number of times to copy
        # ----
        num_copies = int(percentage)
        debug.log("num copies:", num_copies)
        rest_num_copies = DICOMLoaderAugmentation.__extract_first_decimal(percentage)
        debug.log("rest copies:", rest_num_copies, "%")
        
        # Calculate the number of rows to copy
        # ----
        num_rows = len(df)
        debug.log("Initial size:", num_rows)
        
        # Initialize a new DataFrame with copy the structure
        # ----
        augmented_df = df.copy().iloc[0:0]
        
        # Copy the DataFrame 'num_copies' times (e.g. 4.5 -> 4 full copies)
        # ----
        for _ in range(num_copies):
            random_sample = df.sample(n=num_rows, random_state=seed)
            augmented_df = pd.concat([augmented_df, random_sample], ignore_index=True)
        
        # Copy the remaining fraction (e.g. 4.5 -> 50% of the rows);
        # only the first decimal digit of the fraction is taken into account
        # ----
        if rest_num_copies > 0:
            num_samples = int(num_rows * (rest_num_copies / 100))
            random_samples = df.sample(n=num_samples, random_state=seed)
            augmented_df = pd.concat([augmented_df, random_samples], ignore_index=True)
               
        # Identified as Augmented
        # ----
        augmented_df["origine_type"] = DICOMLoaderAugmentation.OrigineType.AUGMENTED.value
        debug.log("Augmented size add:", len(augmented_df))
        
        # Merge with Original
        # ----
        df = pd.concat([df, augmented_df], ignore_index=True)
        
        # Shuffle
        # ----
        if shuffle:
            df = df.sample(frac=1, random_state=seed)
          
        # Reset Ids
        # ----
        df = df.reset_index(drop=True)
        debug.log("Augmented size Total:", len(df))
        
        return df
    
    @staticmethod
    def __extract_first_decimal(number):
        # Convert the number to a string for easier manipulation
        number_str = str(number)

        # Find the index of the decimal point
        decimal_point_index = number_str.find(".")

        # Extract the first decimal digit from the string
        if decimal_point_index != -1 and decimal_point_index + 1 < len(number_str):
            first_decimal_digit = number_str[decimal_point_index + 1]
            
            if len(first_decimal_digit) == 1:
                first_decimal_digit += "0"
        else:
            # If there's no decimal part, return 0
            first_decimal_digit = "0"

        return int(first_decimal_digit)

In [None]:
class ScanDataset(Sequence, Summarizable):
    def __init__(
        self,
        dicom_loader,
        batch_size,
        subset     = "train",
        shuffle    = True,
        debug_mode = False
    ):
        """
        Initializes the ScanDataset object.

        Parameters
        ----------
        dicom_loader : object
            The DICOMLoader object to load DICOM images.
        batch_size : int
            The size of each batch.
        subset : str, optional
            Subset of the data to return: "train", "validation", or any
            other value (e.g. "test"). Only "train" and "validation"
            return y_batch along with x_batch.
        shuffle : bool, optional
            Whether to shuffle the dataset.
        debug_mode : bool, optional
            Whether to print debug information.
        """
        self.__dicom_loader = dicom_loader
        self.__batch_size   = batch_size
        self.__is_trainable = subset.lower() in ["validation", "train"]
        self.__shuffle      = shuffle
        self.__debug        = InternalDebug(debug_mode=debug_mode)
        self.__indices      = np.arange(self.__dicom_loader.len)
        
        if self.__shuffle:
            np.random.shuffle(self.__indices)
        
    # ---------------- #
    # Public methods
    # ---------------- #

    def show_batch(
        self,
        row,
        columns=None,
        figure_size=(5, 5),
        color_map='hot'
    ):
        # Get the batch of images and labels (if in training mode)
        # ----
        if self.__is_trainable:
            x_batch, y_batch = self[row]
        else:
            x_batch = self[row]
            y_batch = None

        # Determine the number of columns to show
        # ----
        if columns is None:
            columns = self.__dicom_loader.num_imgs
            
        # Determine the number of input tensors
        # ----
        num_input_tensors = len(self.__dicom_loader.scan_categories)

        # Loop through each sub-batch in the main batch
        # ----
        for i in range(len(x_batch[0])):
            # Display the batch number and label (if available)
            # ----
            label_info = f" {y_batch[i]}" if y_batch is not None else ""

            # Loop through each input tensor
            # ----
            for j in range(num_input_tensors):
                images = x_batch[j][i]
                scan_type = self.__dicom_loader.scan_categories[j]

                if self.__dicom_loader.image_format == ImageFormat.WHDC:
                    images = ImageFormat.swap_dimensions(images, ImageFormat.DWHC)

                # Generate labels for each image in the set
                # ----
                labels = [
                            f"Batch: {i + 1} \nImg: {k + 1} \nType: {scan_type} \nLabel:{label_info}"
                            for k in range(len(images))
                         ]

                # Show
                # ----
                show_images(
                    images,
                    y=labels,
                    columns=columns,
                    figure_size=figure_size,
                    color_map=color_map
                )

    # Overriding the summary method
    def summary(self, train_dataset=None):
        super().summary()

        print("Additional summary details specific:")

        # 0 is the index of the first batch
        # ----
        if self.__is_trainable:
            batch_x, batch_y = self[0]
        else:
            batch_x = self[0]

        # Checking the batch format
        # ----
        print("Batch_x format:")
        for i, x in enumerate(batch_x):
            print(f"- Scan type {i+1}: {x.shape}")

        if self.__is_trainable:
            print(f"Batch_y format: {batch_y.shape}")
            
        print("=" * 40)

    def on_epoch_end(self):
        """
        Shuffles the dataset at the end of each epoch if shuffle is True.
        """
        if self.__shuffle:
            np.random.shuffle(self.__indices)

    # ---------------- #
    # Private methods
    # ---------------- #

    def __getitem__(self, ids):
        """
        Retrieves a batch of data by batch index.

        Parameters
        ----------
        ids : int
            The batch index.

        Returns
        -------
        tuple or list
            (batch_x, batch_y) for the train and validation subsets,
            otherwise only batch_x.
        """
        self.__debug.log("== __getitem__ ==")
        # Calculate the start and end indices for the batch
        # ----
        from_id = ids * self.__batch_size
        to_id = (ids + 1) * self.__batch_size

        self.__debug.log(f"Batch ID: {ids}")
         
        # Get the indices and labels for the current batch
        # ----
        batch_indices = self.__indices[from_id: to_id]
   
        self.__debug.log("batch_indices:", batch_indices.tolist())

        batches_y = []

        # Initialize a list to hold batches for each input tensor
        # ----
        batches_x = [[] for _ in range(len(self.__dicom_loader.scan_categories))]

        # Loop through each index in the batch
        # ----
        for i in batch_indices:
            self.__debug.log("Processing batch index:", i)

            # Store label
            # ----
            label = self.__dicom_loader.gel_label(i)
            batches_y.append(label)
            self.__debug.log("Label:", label)

            # Load all scans for the current index
            # ----
            batch_x_image_paths = self.__dicom_loader.load_all_scans(i, show_progress=False)

            # Loop through each scan type and its corresponding images
            # ----
            for j, (scan_type, images) in enumerate(batch_x_image_paths.items()):
                self.__debug.log("Processing Scan Type:", scan_type, "Number of Images Loaded:", len(images))
                batches_x[j].append(images)

        # Convert to batch x y
        # ----
        batch_x = [np.array(b) for b in batches_x]
        batch_y = np.array(batches_y)

        self.__debug.log(f"Final batch shapes - batch_x: {[x.shape for x in batch_x]}, batch_y: {batch_y.shape}")

        # Return the image batches and labels if in training mode, otherwise just the image batches
        # ----
        if self.__is_trainable:
            return batch_x, batch_y
        else:
            return batch_x

    def __len__(self):
        """
        Calculates the number of batches in the dataset.

        Returns
        -------
        int
            The number of batches.
        """
        return int(np.ceil(self.__dicom_loader.len / self.__batch_size))
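
The batch arithmetic used by `__getitem__` and `__len__` can be checked in isolation. The sketch below is a simplified stand-in, not the notebook's `ScanDataset`: `num_batches`, `batch_bounds`, and the sample counts are hypothetical names and values chosen for illustration. It shows that rounding up keeps the final, smaller batch.

```python
import math
import random


def num_batches(num_samples, batch_size):
    """Number of batches needed to cover every sample (the last may be partial)."""
    return math.ceil(num_samples / batch_size)


def batch_bounds(batch_index, batch_size):
    """Half-open [start, end) slice bounds for a given batch index."""
    return batch_index * batch_size, (batch_index + 1) * batch_size


# Hypothetical dataset of 10 samples split into batches of 4
indices = list(range(10))
batch_size = 4

# Optional per-epoch shuffle, mirroring the idea behind on_epoch_end
rng = random.Random(42)
rng.shuffle(indices)

batches = []
for b in range(num_batches(len(indices), batch_size)):
    start, end = batch_bounds(b, batch_size)
    # Slicing past the end of a list is safe: the last batch is simply shorter
    batches.append(indices[start:end])

print([len(batch) for batch in batches])  # → [4, 4, 2]
```

Using a floor instead of a ceiling here would silently drop the last two samples of every epoch.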

## <b>3 <span style='color:#78D118'>|</span> Data Retrieval</b>



In [None]:
# Train
train_df = pd.read_csv(TRAIN_DATASET_DF_DIR)
train_df.rename(columns = { "BraTS21ID": "ID", "MGMT_value": "Label"}, inplace=True)

index_to_remove = train_df[train_df['ID'].isin(EXCLUDED_IDS)].index
train_df.drop(index_to_remove, inplace=True)

train_df.reset_index(drop=True, inplace=True)
show_text("b", "Train", False)
display(train_df.head())

# Test
test_df = pd.read_csv(TEST_DATASET_DF_DIR)
test_df.rename(columns = { "BraTS21ID": "ID", "MGMT_value": "Label"}, inplace=True)
show_text("b", "Test", False)
display(test_df.head())

## <b>4 <span style='color:#78D118'>|</span> Data Preparation</b>

### <b>4.1 <span style='color:#78D118'>|</span> Split Data</b>

In [None]:
# Split into train and validation sets using stratified K-folds
# ----
train_df, valid_df = stratifiedTrainValidSplit(
    train_df,
    x_feature_columns = ['ID'],
    y_target_columns = ['Label'],
    num_splits = NUM_SPLIT_FOLDS,
    selected_fold = SELECTED_VALIDATION_FOLD,
    seed = SEED,
    shuffle = SHUFFLE
)

# Show
# ----
show_donut(
    [len(train_df), len(valid_df)],
    ["Train", "Validation"], # TODO: Add test
    colors = ["lightsteelblue","coral"],
    figsize = (8,8),
    title = "Dataset Distribution"
)
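
`stratifiedTrainValidSplit` is defined earlier in the notebook, so its internals are not shown here. As a rough, dependency-free illustration of the idea behind a stratified fold split (not the helper's actual implementation), the sketch below deals the samples of each class into folds in round-robin order and holds one fold out for validation. The function name `stratified_fold_split` and the toy labels are hypothetical.

```python
import random
from collections import defaultdict


def stratified_fold_split(labels, num_splits, selected_fold, seed=0, shuffle=True):
    """Return (train_indices, valid_indices) with class ratios roughly preserved."""
    # Group sample indices by class label
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    rng = random.Random(seed)
    train, valid = [], []
    for label in sorted(by_class):
        members = by_class[label]
        if shuffle:
            rng.shuffle(members)
        # Deal the members of this class across folds in round-robin order
        for position, idx in enumerate(members):
            if position % num_splits == selected_fold:
                valid.append(idx)
            else:
                train.append(idx)
    return sorted(train), sorted(valid)


# Toy binary labels: 10 negatives and 10 positives
labels = [0] * 10 + [1] * 10
train_idx, valid_idx = stratified_fold_split(
    labels, num_splits=5, selected_fold=0, seed=42
)

print(len(train_idx), len(valid_idx))     # → 16 4
print(sum(labels[i] for i in valid_idx))  # positives in validation → 2
```

Because each class is split separately, the validation fold keeps the same 50/50 label balance as the full toy set. For real work, scikit-learn's `StratifiedKFold` is a maintained implementation of the same idea.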

### <b>4.2 <span style='color:#78D118'>|</span> Train, Validation and Test</b>

In [None]:
# ---- ---- ---- #
#      Train
# ---- ---- ---- #

# Dicom loader
# ----
train_dicom_loader = DICOMLoaderAugmentation(
    train_df,
    input_path          = TRAIN_DATASET_PATH,
    scan_categories     = SCAN_CATEGORIES,
    
    # Image
    num_imgs            = IMG_SEQ,
    size                = IMG_SIZE,
    rotate_angle        = IMG_ROTATE,
    enable_center_focus = IMG_ENABLE_CENTRAL_FOCUS,
    shuffle             = SHUFFLE,
    
    # Augmentation
    fraction_augmented                = AUGMENTATION_FRACTION,
    crop_limits                       = AUGMENTATION_CROP_LIMITS,
    rotation_limits                   = AUGMENTATION_ROTATION_LIMITS,
    translation_x_y_limits            = AUGMENTATION_TRANSLATION_X_Y_LIMITS,
    blur_limits                       = AUGMENTATION_BLUR,
    contrast_bright_alpha_beta_limits = AUGMENTATION_CONSTRAST_BRIGHT,
    
    # Settings
    max_threads = MAX_THREADS_DICOM_LOADER,
    seed        = SEED,
    debug_mode  = DEBUG_DICOM_TRAIN_AUGMENTATION,
)

# Dataset
# ----
train_dataset = ScanDataset(
    dicom_loader = train_dicom_loader,
    batch_size   = BATCH_SIZE,
    subset       = DatasetType.TRAIN.value,
    shuffle      = SHUFFLE,
    debug_mode   = DEBUG_SCANDATASET_TRAIN_AUGMENTATION
)

# ---- ---- ---- #
#   Validation
# ---- ---- ---- #

# Dicom loader
# ----
val_dicom_loader = DICOMLoader(
    valid_df,
    input_path          = TRAIN_DATASET_PATH,
    scan_categories     = SCAN_CATEGORIES,
    num_imgs            = IMG_SEQ,
    size                = IMG_SIZE,
    scale               = IMG_SCALE,
    rotate_angle        = IMG_ROTATE,
    max_threads         = MAX_THREADS_DICOM_LOADER,
    enable_center_focus = IMG_ENABLE_CENTRAL_FOCUS,
    debug_mode          = DEBUG_DICOM_VALIDATION
)

# Dataset
# ----
val_dataset = ScanDataset(
    dicom_loader = val_dicom_loader,
    batch_size   = BATCH_SIZE,
    subset       = DatasetType.VALIDATION.value,
    shuffle      = False,
    debug_mode   = DEBUG_SCANDATASET_VALIDATION
)

# ---- ---- ---- #
#      Test
# ---- ---- ---- #

# Dicom loader
# ----
test_dicom_loader = DICOMLoader(
    test_df,
    input_path          = TEST_DATASET_PATH,
    scan_categories     = SCAN_CATEGORIES,
    num_imgs            = IMG_SEQ,
    size                = IMG_SIZE,
    scale               = IMG_SCALE,
    rotate_angle        = IMG_ROTATE,
    max_threads         = MAX_THREADS_DICOM_LOADER,
    enable_center_focus = IMG_ENABLE_CENTRAL_FOCUS,
    debug_mode          = DEBUG_DICOM_TEST
)

# Dataset
# ----
test_dataset = ScanDataset(
    dicom_loader = test_dicom_loader,
    batch_size   = BATCH_SIZE,
    subset       = DatasetType.TEST.value,
    shuffle      = False,
    debug_mode   = DEBUG_SCANDATASET_TEST
)

### <b>4.3 <span style='color:#78D118'>|</span> Preview</b>

In [None]:
train_dicom_loader.show_all(0)

In [None]:
train_dataset.show_batch(0, columns=10, figure_size=(10, 10))

### <b>4.4 <span style='color:#78D118'>|</span> Summary</b>

In [None]:
def summary():
    separator = "\n" + "---- " * 20 + "\n"
    
    def print_section(section_name, section_info):   
        print(separator)
        print(section_name)
        print("\n")
        for key, value in section_info:
            print(f"{key}: {value}")
        
    version_section = [
        ("VERSION", VERSION),
    ]
    print_section("Version", version_section)
    
    # Image
    image_section = [
        ("IMG_SEQ", IMG_SEQ),
        ("IMG_SCALE", IMG_SCALE),
        ("IMG_ROTATE", IMG_ROTATE),
        ("IMG_ENABLE_CENTRAL_FOCUS", IMG_ENABLE_CENTRAL_FOCUS),
        ("INPUT_SHAPE", INPUT_SHAPE),
    ]
    print_section("Image", image_section)

    # Model
    model_section = [
        ("MODEL_NAME", MODEL_NAME),
        ("BATCH_SIZE", BATCH_SIZE),
        ("EPOCHS", EPOCHS),
        ("OPTIMIZER", type(COMPILE_OPTIMIZER)),
        ("COMPILE_LOSS", COMPILE_LOSS),
        ("COMPILE_METRICS", COMPILE_METRICS),
        ("SHUFFLE", SHUFFLE),
    ]
    print_section("Model", model_section)

    # Train
    # ----
    print(separator)
    print("Train")
    print(separator)
    train_dicom_loader.summary()
    print(separator)
    train_dataset.summary()

    # Val
    # ----
    print(separator)
    print("Validation")
    print(separator)
    val_dicom_loader.summary()
    print(separator)
    val_dataset.summary()
    print(separator)
    
    
    # Test
    # ----
    print(separator)
    print("Test")
    print(separator)
    test_dicom_loader.summary()
    print(separator)
    test_dataset.summary()
    print(separator)



summary()

# <b>5 <span style='color:#78D118'>|</span> Model creation</b>

In [None]:
class DeepScanModel(Model):
    def __init__(self, input_shape, model_name="My3DCNNModel"):
        # Define input layers: one per scan category (4 are assumed here)
        # ----
        self.input_layers = [Input(shape=input_shape) for _ in range(4)]

        # Build CNN models for each input
        # ----
        self.cnn_models = [self.build_cnn_branch(input_layer) for input_layer in self.input_layers]

        # Concatenate outputs of CNN models
        # ----
        concatenated = concatenate(self.cnn_models)

        # Add Global Average Pooling and Dense layers
        # ----
        x = self.build_head(concatenated)

        # Define the final model
        # ----
        super(DeepScanModel, self).__init__(inputs=self.input_layers, outputs=x, name=model_name)

    def build_cnn_branch(self, input_layer):        
        x = Conv3D(64, 3)(input_layer)
        x = ReLU()(x)
        x = MaxPool3D(2)(x)
        x = BatchNormalization()(x)
        
        x = Conv3D(128, 3)(x)
        x = ReLU()(x)
        x = MaxPool3D(2)(x)
        x = BatchNormalization()(x)
        x = Dropout(0.1)(x)

        x = Conv3D(256, 3)(x)
        x = ReLU()(x)
        x = MaxPool3D(2)(x)
        x = BatchNormalization()(x)
        x = Dropout(0.2)(x)
        
        return x

    def build_head(self, x):
        x = GlobalAveragePooling3D()(x)

        x = Dense(1024)(x)
        x = ReLU()(x)
        x = Dropout(0.3)(x)

        x = Dense(1, activation="sigmoid")(x)
        
        return x

    def show_graph(self):
        display(plot_model(self, show_shapes=True, show_layer_names=True))


In [None]:
model = DeepScanModel(input_shape=INPUT_SHAPE, model_name=MODEL_NAME)
model.show_graph()
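
Since each branch uses `Conv3D` and `MaxPool3D` with their default `valid` padding, every stage shrinks the volume, and an input that is too small will fail at build time. The sketch below reproduces that shape arithmetic in plain Python so a candidate `INPUT_SHAPE` can be checked without building the model. The helper names and the example shape `(64, 128, 128, 1)` are hypothetical; the notebook's actual `INPUT_SHAPE` is defined elsewhere.

```python
def conv3d_valid(size, kernel=3):
    """Length of one axis after a stride-1 convolution with 'valid' padding."""
    return size - kernel + 1


def maxpool3d(size, pool=2):
    """Length of one axis after max pooling with pool size equal to stride."""
    return size // pool


def branch_output_shape(input_shape, filters=(64, 128, 256)):
    """Shape after the three Conv3D -> MaxPool3D stages of one branch."""
    depth, height, width, _channels = input_shape
    spatial = [depth, height, width]
    for _ in filters:
        # ReLU, BatchNormalization and Dropout leave the shape unchanged
        spatial = [maxpool3d(conv3d_valid(s)) for s in spatial]
        if min(spatial) < 1:
            raise ValueError("input is too small for three conv/pool stages")
    return (*spatial, filters[-1])


# Hypothetical input: 64 slices of 128x128 pixels with 1 channel
shape = branch_output_shape((64, 128, 128, 1))
print(shape)  # → (6, 14, 14, 256)

# Four branches concatenated on the channel axis feed the pooling head
num_branches = 4
print(num_branches * shape[-1])  # → 1024 features after GlobalAveragePooling3D
```

Under these assumptions every spatial axis needs at least 22 elements, which is worth keeping in mind when reducing the number of slices per scan.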

# <b>6 <span style='color:#78D118'>|</span> Model Training</b>

In [None]:
# Logs and models dir
# ----
mkdir(LOGS_PATH)
mkdir(BEST_MODEL_PATH)

# Each model
# ----
show_text("sep")
show_text("h3","Run model")

# Callbacks bestmodel
# ----
bestmodel_callback = ModelCheckpoint(
    filepath       = BEST_MODEL_H5_DIR,
    verbose        = VERBOSITY,
    monitor        = TF_CALL_BACK_BEST_MODEL_MONITOR,
    mode           = 'max',
    save_best_only = True
)

# Callbacks EarlyStopping
# ----
earlystopping_callback = EarlyStopping(
    monitor   = TF_CALL_BACK_EARLY_STOP_MONITOR,
    min_delta = 0,
    patience  = TF_CALL_BACK_EARLY_STOP_PATIENTE,
    verbose   = VERBOSITY,
    mode      = 'auto',
    baseline  = None,
    restore_best_weights = True,
)

# Compile
# ----
model.compile(
    optimizer = COMPILE_OPTIMIZER,
    loss      = COMPILE_LOSS,
    metrics   = COMPILE_METRICS 
)

# Train
# ----
history = model.fit(
    x               = train_dataset,
    validation_data = val_dataset,
    epochs          = EPOCHS,
    shuffle         = SHUFFLE,
    verbose         = VERBOSITY,
    callbacks       = [bestmodel_callback, earlystopping_callback]
)
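
The interaction between `patience` and `restore_best_weights` is easier to see on a toy metric history. The sketch below is a simplified illustration of the patience rule, not Keras' implementation: the function `early_stopping_epoch` and the loss values are hypothetical, and epochs are zero-indexed.

```python
import operator


def early_stopping_epoch(values, patience, mode="min"):
    """Return (stopped_epoch, best_epoch) for a list of per-epoch metric values.

    stopped_epoch is None when training would run through every epoch.
    """
    is_better = operator.lt if mode == "min" else operator.gt
    best_value, best_epoch, wait = None, None, 0
    for epoch, value in enumerate(values):
        if best_value is None or is_better(value, best_value):
            # Improvement: remember it and reset the patience counter
            best_value, best_epoch, wait = value, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return None, best_epoch


# Hypothetical validation losses; the final 0.60 is never reached
val_loss = [0.70, 0.65, 0.66, 0.67, 0.64, 0.68, 0.69, 0.70, 0.60]
stopped, best = early_stopping_epoch(val_loss, patience=3)
print(stopped, best)  # → 7 4
```

With `restore_best_weights = True`, the model would end up with the weights from epoch 4 in this example rather than those from the epoch at which training stopped.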

# <b>7 <span style='color:#78D118'>|</span> Model Evaluation</b>

In [None]:
# Show result
# ----
show_best_history(
    history,
    metric = "auc",
    add_metric = ["val_auc"]
)
# Show history
# ----
show_history(
    history,
    title = "AUC History",
    y_label = "AUC",
    metrics = ["auc", "val_auc"],
    metric_labels = ["Train AUC", "Validation AUC"]
)

show_history(
    history,
    title = "Loss History",
    y_label = "Loss",
    metrics = ["loss", "val_loss"],
    metric_labels = ["Train Loss", "Validation Loss"]
)
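
For readers unfamiliar with the metric plotted above, AUC can be read as the probability that a randomly chosen positive case receives a higher score than a randomly chosen negative one. The sketch below computes it exactly by comparing every positive/negative pair. It assumes the `auc` history entries come from a Keras AUC metric, which uses a thresholded approximation, so values may differ slightly. The function name and toy data are hypothetical.

```python
def roc_auc(labels, scores):
    """Pairwise ROC AUC: P(positive score > negative score), ties count as 0.5."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("AUC needs at least one positive and one negative sample")

    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


# Toy example: one of the four positive/negative pairs is ranked wrongly
labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
print(roc_auc(labels, scores))  # → 0.75
```

A value of 0.5 corresponds to random ranking, so a validation AUC that hovers near 0.5 indicates the model is not separating the two classes.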

# <b>8 <span style='color:#78D118'>|</span> Submission</b>

In [None]:
model = load_model(BEST_MODEL_H5_DIR, custom_objects={'DeepScanModel': DeepScanModel})

In [None]:
def generate_predictions(model, test_dataset, test_df):
    predictions = []

    for batch_idx in range(len(test_dataset)):
        scan_type_1, scan_type_2, scan_type_3, scan_type_4 = test_dataset[batch_idx]
        
        batch_predictions = model.predict([scan_type_1, scan_type_2, scan_type_3, scan_type_4])

        predictions.append(batch_predictions)

    # Flatten the predictions list
    # ----
    submission = test_df.copy()
    submission["Label"] = [item[0] for sublist in predictions for item in sublist]
    submission.rename(columns={"ID": "BraTS21ID", "Label": "MGMT_value"}, inplace=True)

    return submission

submission = generate_predictions(model, test_dataset, test_df)
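
The flattening step in `generate_predictions` relies on each batch output having shape `(batch, 1)` and on the last batch possibly being smaller. The sketch below checks that logic on plain nested lists standing in for the arrays returned by `model.predict`. The function `flatten_predictions`, the scores, and the IDs are hypothetical.

```python
def flatten_predictions(batches):
    """Flatten per-batch outputs of shape (batch, 1) into one flat list of scores."""
    return [row[0] for batch in batches for row in batch]


# Hypothetical outputs for 5 test samples with a batch size of 2:
# the last batch holds a single sample
batch_outputs = [
    [[0.91], [0.12]],
    [[0.48], [0.77]],
    [[0.33]],
]
test_ids = [1, 13, 27, 31, 40]  # hypothetical patient IDs

scores = flatten_predictions(batch_outputs)

# There must be exactly one score per test row before filling the column
assert len(scores) == len(test_ids)
print(list(zip(test_ids, scores)))
```

The pairing of IDs and scores is only meaningful because the test `ScanDataset` is created with `shuffle = False`, so batches come back in the same order as the rows of `test_df`.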

In [None]:
show_histogram(
    submission["MGMT_value"], 
    xlabel="MGMT value", 
    ylabel="Count", 
    title="Result"
)

display(submission.head(100))

In [None]:
submission.to_csv("submission.csv", index=False)