# **Set up Dependencies and Datasets**

In [None]:
# System packages. `-y` is required: commands run with `!` cannot answer apt's
# "Do you want to continue? [Y/n]" prompt, so apt would abort the install.
# apt-get is used because apt's CLI is not meant to be scripted.
!sudo apt-get update
!sudo apt-get install -y python3-pip
# %pip (rather than !pip3) installs into the environment of the running kernel.
# Restart the kernel afterwards if any of these packages were already imported.
%pip install jupyter tensorflow==2.19.0 segmentation-models-3D h5py numpy matplotlib scikit-learn patchify kaggle

In [None]:
# Requires Kaggle API credentials: ~/.kaggle/kaggle.json with chmod 600.
!kaggle datasets download -d awsaf49/brats20-dataset-training-validation -p /root/BraTS2020
# -n: never overwrite already-extracted files, so re-running this cell does not
#     stop at unzip's interactive "replace?" prompt.
# -q: keep the (very long) extracted-file listing out of the notebook output.
!unzip -q -n /root/BraTS2020/brats20-dataset-training-validation.zip -d /root/BraTS2020

# **Imports and GPU Setup**

In [None]:
# Imports
import tensorflow as tf
import tensorflow.keras as keras
import segmentation_models_3D as sm
import numpy as np
import h5py
from patchify import patchify
import os
from matplotlib import pyplot as plt
from tensorflow.keras import backend as K
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
import logging
import sys
from IPython.utils import io

# Every artifact this notebook writes (log file, figures, trained model) goes
# under this directory.
output_dir = "output"
os.makedirs(output_dir, exist_ok=True)

# Root logging at INFO: records go both to output/training_log.txt and to the
# stdout object that is current when this line runs (on a first run, the
# notebook's own stdout, since sys.stdout is only redirected further below).
log_file = os.path.join(output_dir, "training_log.txt")
logging.basicConfig(
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler(sys.stdout),
    ],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

class LoggerWriter:
    """Minimal file-like object that forwards everything written to it to a logger.

    Intended to be installed as ``sys.stdout`` so that ``print`` output ends up
    in the configured log handlers.

    Args:
        logger: ``logging.Logger`` that receives the messages.
        level: logging level used for every forwarded message.
    """

    # Some libraries read these attributes on sys.stdout directly; provide
    # them so the redirect does not raise AttributeError.
    encoding = "utf-8"

    def __init__(self, logger, level):
        self.logger = logger
        self.level = level

    def write(self, message):
        """Log `message` with trailing whitespace removed; blank writes are dropped.

        ``print`` sends the text and the line terminator as separate writes, but
        ``sys.stdout.write("text\\n")`` arrives in one piece. Stripping the
        trailing newline keeps such records from carrying an empty extra line.

        Returns:
            The number of characters received, like ``io.TextIOBase.write``.
        """
        text = message.rstrip()
        if text:
            self.logger.log(self.level, text)
        return len(message)

    def flush(self):
        """Nothing is buffered here, so there is nothing to flush."""

    def isatty(self):
        """Report that this stream is not an interactive terminal."""
        return False

# Route everything written to stdout (print output and other writes) through
# the root logger so it also lands in the log file. The StreamHandler passed to
# basicConfig above was bound to the stdout object current at that time, so log
# records still reach the notebook instead of looping back into this writer.
sys.stdout = LoggerWriter(logging.getLogger(), logging.INFO)

# GPU setup with mixed precision.
gpus = tf.config.list_physical_devices('GPU')
if not gpus:
    logging.error(f"No GPU detected. Available devices: {tf.config.list_physical_devices()}")
    raise SystemError('GPU device not found')

# Memory growth is a per-device setting. Enable it on every visible GPU, not
# just the first, so TensorFlow does not pre-allocate the memory of the others.
# It can only be changed before a device is initialized, so re-running this
# cell in a live kernel raises RuntimeError; that is logged, not fatal.
for gpu in gpus:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
        logging.info(f"Found GPU: {gpu}")
    except RuntimeError as e:
        logging.error(f"GPU setup error: {e}")

# The precision policy does not depend on memory growth, so it is applied even
# if the step above failed.
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)
logging.info(f"TensorFlow version: {tf.__version__}")

logging.info("Cell 1 completed: Setup and GPU configuration done.")

# **Utility Functions**

In [None]:
def load_h5(file_path, modality='t1ce'):
    """Read the dataset named `modality` from the HDF5 file at `file_path`.

    The whole dataset is read into memory while the file is open. The file is
    then closed, and a float32 copy of the array is returned.
    """
    with h5py.File(file_path, 'r') as h5_handle:
        raw_volume = h5_handle[modality][:]
    return raw_volume.astype(np.float32)

def visualize_mri_and_segmentation(mri_data, seg_data, slice_idx=None):
    """Save a side-by-side PNG of one slice of an MRI volume and its mask.

    Args:
        mri_data: volume indexed as ``mri_data[:, :, k]``; shown in grayscale.
        seg_data: mask indexed the same way as `mri_data`.
        slice_idx: index ``k`` along axis 2. Defaults to ``mri_data.shape[2] // 2``.

    The figure is written to the module-level ``output_dir`` as
    ``mri_seg_slice_<k>.png`` and then closed. Nothing is returned.
    """
    k = mri_data.shape[2] // 2 if slice_idx is None else slice_idx
    logging.info(f"MRI Shape: {mri_data.shape}, Segmentation Shape: {seg_data.shape}")

    fig, (ax_mri, ax_seg) = plt.subplots(1, 2, figsize=(12, 6))
    ax_mri.imshow(mri_data[:, :, k], cmap="gray")
    ax_mri.set_title(f"MRI Scan (T1CE) - Slice {k}")
    ax_seg.imshow(seg_data[:, :, k], cmap="gray")
    ax_seg.set_title(f"Segmentation Mask - Slice {k}")

    fig.savefig(os.path.join(output_dir, f"mri_seg_slice_{k}.png"))
    plt.close(fig)
    logging.info(f"Saved visualization to {output_dir}/mri_seg_slice_{k}.png")

def load_data_from_folder(folder_path, max_subfolders=369, chunk_size=20):
    """Load the four MRI modalities and the mask for up to `max_subfolders` subjects.

    Subject directories are the entries of `folder_path` whose name contains
    ``'BraTS20_Training_'``, taken in sorted order. Each one is expected to hold
    ``<name>/<name>.h5`` with the datasets ``'t1'``, ``'t1ce'``, ``'t2'``,
    ``'flair'`` and ``'seg'``. Subjects without that file are skipped with a
    warning.

    Directories are walked `chunk_size` at a time. This only groups the progress
    log; every loaded volume stays in memory. NOTE(review): with the full
    dataset this can need a lot of RAM — confirm volume sizes before running on
    all subjects.

    A middle T1CE slice of the first loaded subject is saved as a PNG via
    ``visualize_mri_and_segmentation``.

    Returns:
        (mri_data_dict, seg_data_dict):
            - ``mri_data_dict`` maps subject name -> {modality: float32 array}.
            - ``seg_data_dict`` maps subject name -> float32 mask.

    Raises:
        FileNotFoundError: if no subject ``.h5`` file was found.
    """
    mri_data_dict = {}
    seg_data_dict = {}
    subfolders = sorted(f for f in os.listdir(folder_path) if 'BraTS20_Training_' in f)[:max_subfolders]

    for chunk_start in range(0, len(subfolders), chunk_size):
        chunk_end = min(chunk_start + chunk_size, len(subfolders))
        logging.info(f"Loading chunk {chunk_start}-{chunk_end}")
        for subfolder in subfolders[chunk_start:chunk_end]:
            h5_file = os.path.join(folder_path, subfolder, f"{subfolder}.h5")
            if not os.path.isfile(h5_file):
                logging.warning(f"Skipping {subfolder}: {h5_file} not found")
                continue
            logging.info(f"Loading data from {subfolder}")
            mri_data_dict[subfolder] = {
                modality: load_h5(h5_file, modality)
                for modality in ('t1', 't1ce', 't2', 'flair')
            }
            seg_data_dict[subfolder] = load_h5(h5_file, 'seg')

    # Without this check an empty result only surfaced later as an IndexError
    # on the first dictionary key, which hides the real cause (wrong path/format).
    if not mri_data_dict:
        raise FileNotFoundError(
            f"No subject .h5 files (BraTS20_Training_*/<name>.h5) found under {folder_path}"
        )

    first_key = next(iter(mri_data_dict))
    visualize_mri_and_segmentation(mri_data_dict[first_key]['t1ce'], seg_data_dict[first_key])
    logging.info("Cell 2 completed: Data loaded.")
    return mri_data_dict, seg_data_dict

# **Patch Extraction**

In [None]:
def _pad_to_patch_grid(volume, spatial_patch, step):
    """Zero-pad the leading spatial axes so a sliding window tiles them completely.

    For each axis ``i < len(spatial_patch)``, the padded length is the smallest
    ``L >= max(dim, patch)`` with ``(L - patch) % step == 0``. A window of size
    ``spatial_patch[i]`` moved by `step` then reaches the last voxel. Padding is
    added at the far end only. Remaining axes (e.g. channels) are not padded.
    """
    pad_width = []
    for dim, patch in zip(volume.shape, spatial_patch):
        overhang = max(dim - patch, 0)
        n_steps = -(-overhang // step)  # ceiling division
        pad_width.append((0, patch + n_steps * step - dim))
    pad_width.extend([(0, 0)] * (volume.ndim - len(spatial_patch)))
    return np.pad(volume, pad_width, mode='constant', constant_values=0)


def extract_patches(mri_data_dict, seg_data_dict, patch_size=(64, 64, 64, 4), step=32):
    """Cut every subject's stacked MRI volume and mask into overlapping 3-D patches.

    Args:
        mri_data_dict: subject -> {'t1', 't1ce', 't2', 'flair': 3-D array}.
        seg_data_dict: subject -> 3-D mask aligned with the MRI volumes.
        patch_size: ``(px, py, pz, channels)``. The last entry must equal the
            number of stacked modalities (4).
        step: stride of the sliding window along each spatial axis.

    Volumes are zero-padded so the window grid covers every voxel. For the
    defaults this is the same padding as rounding each axis up to a multiple of
    `step`. It also stays complete when the patch is not a multiple of `step`
    or an axis is shorter than the patch.

    Returns:
        (mri_patches_dict, seg_patches_dict): subject -> ``patchify`` output for
        the stacked MRI (channel order t1, t1ce, t2, flair) and for the mask.

    Raises:
        ValueError: if ``patch_size[-1]`` does not match the channel count.
    """
    spatial_patch = tuple(patch_size[:-1])
    mri_patches_dict = {}
    seg_patches_dict = {}

    for subfolder, modalities in mri_data_dict.items():
        logging.info(f"Processing subfolder: {subfolder}")

        mri_data_stacked = np.stack([modalities['t1'], modalities['t1ce'],
                                     modalities['t2'], modalities['flair']], axis=-1)
        if mri_data_stacked.shape[-1] != patch_size[-1]:
            raise ValueError(
                f"patch_size[-1]={patch_size[-1]} does not match "
                f"{mri_data_stacked.shape[-1]} stacked modalities"
            )
        mri_data_padded = _pad_to_patch_grid(mri_data_stacked, spatial_patch, step)
        mri_patches = patchify(mri_data_padded, patch_size=patch_size, step=step)
        logging.info(f"MRI patches shape: {mri_patches.shape}")

        seg_data_padded = _pad_to_patch_grid(seg_data_dict[subfolder], spatial_patch, step)
        seg_patches = patchify(seg_data_padded, patch_size=spatial_patch, step=step)
        logging.info(f"Segmentation patches shape: {seg_patches.shape}")

        mri_patches_dict[subfolder] = mri_patches
        seg_patches_dict[subfolder] = seg_patches

    logging.info("Cell 3 completed: Patches extracted.")
    return mri_patches_dict, seg_patches_dict

# **Preprocessing**

In [None]:
def preprocess_data(mri_patches_dict, seg_patches_dict, chunk_size=20, num_classes=4):
    """Normalize MRI patches, one-hot encode the masks, and flatten both to patch lists.

    Args:
        mri_patches_dict: subject -> patch grid whose last four axes are one patch
            ``(px, py, pz, channels)``. All leading axes index patches.
        seg_patches_dict: subject -> patch grid whose last three axes are one
            patch ``(px, py, pz)``. It must hold the same number of patches as
            the matching MRI grid.
        chunk_size: number of subjects concatenated together before the final
            concatenation.
        num_classes: number of one-hot classes.

    Each subject's MRI patches are z-scored per channel. Every modality gets
    its own mean and std, because a single statistic across modalities would
    leave the channels on different scales.

    Mask labels are truncated to int. Label 4 is mapped to 3, because the BraTS
    annotation scheme uses labels {0, 1, 2, 4}; without the remap, label 4 does
    not fit into num_classes=4.

    Returns:
        X: ``(N, px, py, pz, channels)`` normalized patches.
        y: ``(N, px, py, pz, num_classes)`` float32 one-hot masks.

    Raises:
        ValueError: if there are no subjects, or a mask label is outside
            ``[0, num_classes)`` after the 4 -> 3 remap.
    """
    if not mri_patches_dict:
        raise ValueError("preprocess_data: no subjects to preprocess")

    mri_patches = []
    seg_patches = []
    subfolders = list(mri_patches_dict.keys())
    # Row i of the identity matrix is the one-hot vector for label i (float32;
    # keras' to_categorical returns float64 in Keras 3, doubling memory).
    one_hot_table = np.eye(num_classes, dtype=np.float32)

    for chunk_start in range(0, len(subfolders), chunk_size):
        chunk_end = min(chunk_start + chunk_size, len(subfolders))
        logging.info(f"Preprocessing chunk {chunk_start}-{chunk_end}")
        chunk_mri = []
        chunk_seg = []

        for subfolder in subfolders[chunk_start:chunk_end]:
            mri_patch = mri_patches_dict[subfolder]
            seg_patch = seg_patches_dict[subfolder]

            # Statistics over every axis except the trailing channel axis.
            stat_axes = tuple(range(mri_patch.ndim - 1))
            mean = mri_patch.mean(axis=stat_axes, keepdims=True)
            std = mri_patch.std(axis=stat_axes, keepdims=True)
            mri_patch = (mri_patch - mean) / (std + 1e-8)

            num_patches = int(np.prod(mri_patch.shape[:-4]))
            mri_patch = mri_patch.reshape(num_patches, *mri_patch.shape[-4:])

            labels = seg_patch.reshape(num_patches, *seg_patch.shape[-3:]).astype(np.int64)
            labels = np.where(labels == 4, 3, labels)
            if labels.size and (labels.min() < 0 or labels.max() >= num_classes):
                raise ValueError(
                    f"{subfolder}: mask labels must lie in [0, {num_classes - 1}] after "
                    f"mapping 4 -> 3, got min {labels.min()} and max {labels.max()}"
                )

            chunk_mri.append(mri_patch)
            chunk_seg.append(one_hot_table[labels])

        mri_patches.append(np.concatenate(chunk_mri, axis=0))
        seg_patches.append(np.concatenate(chunk_seg, axis=0))
        del chunk_mri, chunk_seg

    X = np.concatenate(mri_patches, axis=0)
    y = np.concatenate(seg_patches, axis=0)
    logging.info(f"Preprocessed data shapes - X: {X.shape}, y: {y.shape}")
    logging.info("Cell 4 completed: Data preprocessed.")
    return X, y

# Load and process data: raw volumes -> overlapping patches -> normalized arrays.
with io.capture_output() as captured:
    folder_path = "/root/BraTS2020/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData"
    mri_data_dict, seg_data_dict = load_data_from_folder(folder_path)
    # Patch the raw masks loaded above; seg_patches_dict is only produced by this call.
    mri_patches_dict, seg_patches_dict = extract_patches(mri_data_dict, seg_data_dict)
    X, y = preprocess_data(mri_patches_dict, seg_patches_dict)
logging.info(f"Cell 4 output:\n{captured.stdout}")

# **Data Split and Model Definition**

In [None]:
# Split data
# NOTE(review): this is a random split over patches, not over subjects. With the
# extract_patches defaults (64^3 patches, step 32), neighbouring patches overlap,
# and patches from the same subject land in both sets, so validation scores are
# likely optimistic. A subject-level split needs per-patch subject ids, which
# preprocess_data does not return.
with io.capture_output() as captured:
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
logging.info(f"Training data shape: {X_train.shape}, Validation data shape: {X_val.shape}")
logging.info(f"Cell 5 output (split):\n{captured.stdout}")

# Define model.
# The global policy is 'mixed_float16', so the U-Net's layers compute in float16.
# The TensorFlow mixed-precision guide recommends a float32 output activation
# for numeric stability. The U-Net is therefore built with a linear head, and
# the softmax runs in a separate float32 layer. Input shape, output shape,
# loss and metric are unchanged.
unet_body = sm.Unet(
    backbone_name='resnet50',
    input_shape=(64, 64, 64, 4),
    classes=4,
    activation='linear',
    encoder_weights=None
)
softmax_fp32 = tf.keras.layers.Activation('softmax', dtype='float32', name='softmax_float32')(unet_body.output)
model = tf.keras.Model(inputs=unet_body.input, outputs=softmax_fp32)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=sm.losses.DiceLoss(),
    metrics=[sm.metrics.IOUScore()]
)

with io.capture_output() as captured:
    model.summary()
logging.info(f"Model summary:\n{captured.stdout}")
logging.info("Cell 5 completed: Data split and model defined.")

# **Training and Saving**

In [None]:
# Training.
# verbose=2 prints one summary line per epoch. Because sys.stdout is routed
# into the root logger earlier in the notebook, each line reaches the log file
# and the notebook as soon as its epoch finishes. fit() is deliberately not
# wrapped in io.capture_output: that would hold back every line until training
# ends and lose all of it if training crashes.
batch_size = 4
epochs = 30
history = model.fit(
    X_train, y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_data=(X_val, y_val),
    verbose=2
)

# Plot and save training history
loss_fig, loss_ax = plt.subplots()
loss_ax.plot(history.history['loss'], label='Training Loss')
loss_ax.plot(history.history['val_loss'], label='Validation Loss')
loss_ax.set(title='Training and validation loss', xlabel='Epoch', ylabel='Loss')
loss_ax.legend()
loss_fig.savefig(os.path.join(output_dir, "loss_plot.png"))
plt.close(loss_fig)
logging.info(f"Saved loss plot to {output_dir}/loss_plot.png")

# Save model.
# NOTE(review): Keras 3 treats the .h5 extension as a legacy format; consider
# switching to ".keras" if downstream loaders allow it.
model_file = os.path.join(output_dir, "3d_unet_brats2020.h5")
model.save(model_file)
logging.info(f"Model saved to {model_file}")
logging.info("Cell 6 completed: Training completed and outputs saved.")