<a href="https://colab.research.google.com/github/antoniog18/BME-4790-Projects/blob/main/BME_4790_Final_Project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("awsaf49/brats2020-training-data")

print("Path to dataset files:", path)

Using Colab cache for faster access to the 'brats2020-training-data' dataset.
Path to dataset files: /kaggle/input/brats2020-training-data


# Task
Train a U-Net model for brain tumor segmentation on the BraTS 2020 dataset located at "/root/.cache/kagglehub/datasets/awsaf49/brats2020-training-data/versions/3", and then summarize the model's performance using the evaluation metrics.
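The notebook imports MONAI's `DiceLoss` and `DiceMetric`, so the evaluation metric here is presumably the Dice coefficient, computed per tumor sub-region. Below is a minimal numpy sketch for one pair of binary masks. It is an illustration, not MONAI's implementation, and scoring two empty masks as 1.0 is an assumed convention (libraries handle that case differently).

```python
import numpy as np

def dice_coefficient(pred: np.ndarray, target: np.ndarray) -> float:
    """Dice = 2 * |P and T| / (|P| + |T|) for one pair of binary masks."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    denom = pred.sum() + target.sum()
    if denom == 0:
        # Both masks empty: scored as 1.0 here by convention (an assumption)
        return 1.0
    intersection = np.logical_and(pred, target).sum()
    return float(2.0 * intersection / denom)

pred = np.array([[1, 1, 0], [0, 0, 0]])
target = np.array([[1, 0, 0], [1, 0, 0]])
print(dice_coefficient(pred, target))  # 0.5
```

For a 3-channel BraTS mask, the same function would be applied to each channel separately and the per-channel scores reported or averaged.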

## Install and Import Libraries

### Subtask:
Install necessary libraries such as `nibabel` for NIfTI file handling and `monai` for medical image AI, along with standard libraries like `numpy` and `torch`. Then, import all required modules.


**Reasoning**:
The first step is to install the necessary libraries, `nibabel` and `monai`, using pip as specified in the instructions. This needs to be in a separate code block.



In [None]:
!pip install nibabel monai
print("nibabel and monai installed.")

nibabel and monai installed.


**Reasoning**:
Now that the necessary libraries are installed, I will import all required modules as specified in the instructions to prepare the environment for data loading, preprocessing, model definition, and training.



In [None]:
import numpy as np
import torch
import nibabel as nib
from monai.transforms import Compose, LoadImaged, Orientationd, ScaleIntensityRanged, CropForegroundd, AsDiscrete, EnsureChannelFirstd
from monai.data import decollate_batch, Dataset, DataLoader
from monai.networks.nets import UNet
from monai.losses import DiceLoss
from monai.metrics import DiceMetric

print("Required libraries imported successfully.")

Required libraries imported successfully.


**Reasoning**:
An earlier version of the import cell failed because `AddChanneld` could not be imported from `monai.transforms`: the transform was deprecated and later removed from MONAI, and it is not available in the installed version (1.5.1). The cell above therefore imports `EnsureChannelFirstd` instead. The two are not drop-in equivalents: `AddChanneld` always prepended a new singleton channel axis, whereas `EnsureChannelFirstd` moves an existing channel axis to the front (or adds one when the metadata indicates there is no channel), based on the image metadata or an explicit `channel_dim` argument.
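`AddChanneld` and `EnsureChannelFirstd` differ in how they treat an existing channel axis. The numpy illustration below uses plain array operations as stand-ins for the two MONAI transforms. The shapes are assumed to match a 240 × 240 slice with four modality channels stored channel-last.

```python
import numpy as np

# A 2D slice stored channel-last: (H, W, C) with 4 modality channels
slice_hwc = np.zeros((240, 240, 4), dtype=np.float32)

# AddChanneld-style behaviour: always prepend a new singleton axis
added = slice_hwc[np.newaxis, ...]
print(added.shape)  # (1, 240, 240, 4) -> channels are still last

# EnsureChannelFirstd(channel_dim=2)-style behaviour: move the existing channel axis to the front
moved = np.moveaxis(slice_hwc, 2, 0)
print(moved.shape)  # (4, 240, 240)
```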



## Prepare Data Paths

### Subtask:
Utilize the `path` variable from the Kaggle dataset download to identify and organize the paths to the individual NIfTI image files (T1, T1Gd, T2, T2-FLAIR) and their corresponding segmentation masks for each patient in the BraTS dataset.


**Reasoning**:
I need to define the base path, import the `os` module, iterate through the dataset directory to find patient folders, and then construct the file paths for the different MRI modalities (T1, T1Gd, T2, T2-FLAIR) and their corresponding segmentation masks for each patient.



In [None]:
import os

base_path = '/root/.cache/kagglehub/datasets/awsaf49/brats2020-training-data/versions/3/BraTS2020_training_data/'

data_dicts = []

# List all patient directories
patient_dirs = sorted([d for d in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, d))])

for patient_id in patient_dirs:
    patient_path = os.path.join(base_path, patient_id)

    # Construct paths for image modalities and segmentation mask
    # Assuming the naming convention: <patient_id>_<modality>.nii.gz
    image_t1 = os.path.join(patient_path, f"{patient_id}_t1.nii.gz")
    image_t1gd = os.path.join(patient_path, f"{patient_id}_t1ce.nii.gz") # Often t1ce for contrast-enhanced T1
    image_t2 = os.path.join(patient_path, f"{patient_id}_t2.nii.gz")
    image_flair = os.path.join(patient_path, f"{patient_id}_flair.nii.gz")
    label_seg = os.path.join(patient_path, f"{patient_id}_seg.nii.gz")

    data_dicts.append({
        "image": [image_t1, image_t1gd, image_t2, image_flair],
        "label": label_seg
    })

print(f"Found {len(data_dicts)} patient data entries.")

# Split data into training and validation sets (e.g., 80% train, 20% validation)
split_idx = int(len(data_dicts) * 0.8)
train_files = data_dicts[:split_idx]
val_files = data_dicts[split_idx:]

print(f"Training files: {len(train_files)}")
print(f"Validation files: {len(val_files)}")

# Safely print first entries if lists are not empty
if train_files:
    print("First entry of train_files:", train_files[0])
else:
    print("train_files is empty.")
if val_files:
    print("First entry of val_files:", val_files[0])
else:
    print("val_files is empty.")

Found 1 patient data entries.
Training files: 0
Validation files: 1
train_files is empty.
First entry of val_files: {'image': ['/root/.cache/kagglehub/datasets/awsaf49/brats2020-training-data/versions/3/BraTS2020_training_data/content/content_t1.nii.gz', '/root/.cache/kagglehub/datasets/awsaf49/brats2020-training-data/versions/3/BraTS2020_training_data/content/content_t1ce.nii.gz', '/root/.cache/kagglehub/datasets/awsaf49/brats2020-training-data/versions/3/BraTS2020_training_data/content/content_t2.nii.gz', '/root/.cache/kagglehub/datasets/awsaf49/brats2020-training-data/versions/3/BraTS2020_training_data/content/content_flair.nii.gz'], 'label': '/root/.cache/kagglehub/datasets/awsaf49/brats2020-training-data/versions/3/BraTS2020_training_data/content/content_seg.nii.gz'}


**Reasoning**:
The previous output shows that only one directory (`content`) was found. As a result, `train_files` ended up empty, and the constructed paths (e.g., `content/content_t1.nii.gz`) do not point to real files, which means `base_path` is still one level too high. This cell moves `base_path` down to `content/data/`. Rather than assuming one folder per patient, it globs for `.nii.gz` files in that directory and groups them by the patient-ID prefix of each filename.



In [None]:
import os
import glob

base_path = '/root/.cache/kagglehub/datasets/awsaf49/brats2020-training-data/versions/3/BraTS2020_training_data/content/data/'

data_dicts = []

# Get all NIfTI files in the base_path
all_nifti_files = glob.glob(os.path.join(base_path, "*.nii.gz"))

# Extract unique patient IDs from the filenames
patient_ids = set()
for filepath in all_nifti_files:
    filename = os.path.basename(filepath)
    # Assuming filename format is BraTS20_XXXXX_modality.nii.gz or BraTS20_XXXXX_seg.nii.gz
    # The patient ID is the part before the last underscore
    parts = filename.split('_')
    if len(parts) >= 2:
        # patient_id will be something like 'BraTS20_00001'
        patient_id = '_'.join(parts[:-1])
        if patient_id.startswith('BraTS20'): # Filter to ensure it's a valid patient ID
            patient_ids.add(patient_id)

# Sort patient IDs for consistent order
sorted_patient_ids = sorted(list(patient_ids))

# For each unique patient ID, construct the data dictionary
for patient_id in sorted_patient_ids:
    image_t1 = os.path.join(base_path, f"{patient_id}_t1.nii.gz")
    image_t1gd = os.path.join(base_path, f"{patient_id}_t1ce.nii.gz") # Often t1ce for contrast-enhanced T1
    image_t2 = os.path.join(base_path, f"{patient_id}_t2.nii.gz")
    image_flair = os.path.join(base_path, f"{patient_id}_flair.nii.gz")
    label_seg = os.path.join(base_path, f"{patient_id}_seg.nii.gz")

    # Ensure all expected files exist for this patient
    if (os.path.exists(image_t1) and os.path.exists(image_t1gd) and
        os.path.exists(image_t2) and os.path.exists(image_flair) and
        os.path.exists(label_seg)):

        data_dicts.append({
            "image": [image_t1, image_t1gd, image_t2, image_flair],
            "label": label_seg
        })

print(f"Found {len(data_dicts)} patient data entries.")

# Split data into training and validation sets (e.g., 80% train, 20% validation)
split_idx = int(len(data_dicts) * 0.8)
train_files = data_dicts[:split_idx]
val_files = data_dicts[split_idx:]

print(f"Training files: {len(train_files)}")
print(f"Validation files: {len(val_files)}")

# Safely print first entries if lists are not empty
if train_files:
    print("First entry of train_files:", train_files[0])
else:
    print("train_files is empty.")
if val_files:
    print("First entry of val_files:", val_files[0])
else:
    print("val_files is empty.")


Found 0 patient data entries.
Training files: 0
Validation files: 0
train_files is empty.
val_files is empty.


**Reasoning**:
The previous attempts failed to correctly identify all patient files, likely due to unexpected nesting or a non-standard directory structure for the BraTS dataset from this Kaggle source. To robustly find all NIfTI files and their corresponding patient IDs, I will perform a recursive search for all `.nii.gz` files starting from the top-level dataset path. Then, I will extract unique patient IDs from these filenames and group the modality images and segmentation masks for each patient.
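Grouping files by patient depends on splitting each filename into a patient ID and a modality suffix. The sketch below assumes the BraTS 2020 naming pattern `<patient_id>_<suffix>.nii.gz` (e.g. `BraTS20_Training_001_t1ce.nii.gz`). The helper `split_brats_filename` and the suffix list are hypothetical illustrations, not code from the original notebook.

```python
import os

# Assumed BraTS 2020 naming: <patient_id>_<suffix>.nii.gz, e.g. BraTS20_Training_001_t1ce.nii.gz
KNOWN_SUFFIXES = ("t1", "t1ce", "t2", "flair", "seg")

def split_brats_filename(filepath: str):
    """Return (patient_id, suffix) for a BraTS NIfTI file name, or None if it does not match."""
    name = os.path.basename(filepath)
    if not name.endswith(".nii.gz"):
        return None
    stem = name[: -len(".nii.gz")]
    # rpartition splits on the LAST underscore, so underscores inside the ID are preserved
    patient_id, sep, suffix = stem.rpartition("_")
    if not sep or not patient_id or suffix not in KNOWN_SUFFIXES:
        return None
    return patient_id, suffix

print(split_brats_filename("BraTS20_Training_001_t1ce.nii.gz"))  # ('BraTS20_Training_001', 't1ce')
print(split_brats_filename("volume_237_slice_51.h5"))             # None
```

Returning `None` for names that do not match lets the caller skip unrelated files explicitly. Otherwise, a heuristic `'_'.join(parts[:-1])` would silently invent patient IDs for them.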



In [None]:
import os
import glob

# Set the base path to the top-level directory where the dataset was extracted
base_dataset_path = '/root/.cache/kagglehub/datasets/awsaf49/brats2020-training-data/versions/3/'

data_dicts = []

# Find all NIfTI files recursively within the base_dataset_path
all_nifti_files = glob.glob(os.path.join(base_dataset_path, '**', '*.nii.gz'), recursive=True)

# Dictionary to store files for each patient ID
patient_files = {}

for filepath in all_nifti_files:
    filename = os.path.basename(filepath)

    # Extract patient_id from filename (e.g., BraTS20_00001_t1.nii.gz -> BraTS20_00001)
    # Or for the observed 'data_t1.nii.gz' -> 'data'
    parts = filename.split('_')
    if len(parts) >= 2: # Ensure there's a modality/seg part
        patient_id = '_'.join(parts[:-1])

        # Get the directory where this patient's files are located
        patient_dir = os.path.dirname(filepath)

        if patient_id not in patient_files:
            patient_files[patient_id] = {'dir': patient_dir, 'files': []}
        patient_files[patient_id]['files'].append(filepath)

# Now, for each patient_id, construct the full data dictionary
for patient_id, info in patient_files.items():
    patient_dir = info['dir']

    image_t1 = os.path.join(patient_dir, f"{patient_id}_t1.nii.gz")
    image_t1gd = os.path.join(patient_dir, f"{patient_id}_t1ce.nii.gz")
    image_t2 = os.path.join(patient_dir, f"{patient_id}_t2.nii.gz")
    image_flair = os.path.join(patient_dir, f"{patient_id}_flair.nii.gz")
    label_seg = os.path.join(patient_dir, f"{patient_id}_seg.nii.gz")

    # Ensure all expected files exist for this patient
    # Note: This might filter out patients with incomplete data
    if (os.path.exists(image_t1) and os.path.exists(image_t1gd) and
        os.path.exists(image_t2) and os.path.exists(image_flair) and
        os.path.exists(label_seg)):

        data_dicts.append({
            "image": [image_t1, image_t1gd, image_t2, image_flair],
            "label": label_seg
        })

print(f"Found {len(data_dicts)} patient data entries.")

# Split data into training and validation sets (e.g., 80% train, 20% validation)
split_idx = int(len(data_dicts) * 0.8)
train_files = data_dicts[:split_idx]
val_files = data_dicts[split_idx:]

print(f"Training files: {len(train_files)}")
print(f"Validation files: {len(val_files)}")

# Safely print first entries if lists are not empty
if train_files:
    print("First entry of train_files:", train_files[0])
else:
    print("train_files is empty.")
if val_files:
    print("First entry of val_files:", val_files[0])
else:
    print("val_files is empty.")

Found 0 patient data entries.
Training files: 0
Validation files: 0
train_files is empty.
val_files is empty.


**Reasoning**:
The previous code block resulted in an empty `all_nifti_files` list, meaning no `.nii.gz` files were found recursively within the `base_dataset_path`. This suggests an issue with the path or the actual presence of the files. To diagnose this, I will first list the immediate contents of the `base_dataset_path` and then retry the recursive `glob.glob` with a print statement to confirm how many files are found. If `glob` still yields no files, I will use `os.walk` as a more robust way to enumerate all `.nii.gz` files and explicitly print their paths to understand the actual directory structure.
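A recursive `glob` with `**` and an `os.walk` traversal should return the same visible files. If both report zero matches, the files are absent (or have a different extension) rather than being missed by the search method. The small self-check below runs on a throwaway directory. The helper names are hypothetical, and the directory layout is an assumption that mimics the dataset.

```python
import glob
import os
import tempfile

def find_with_glob(root: str, suffix: str) -> list[str]:
    """Recursive glob; note that glob skips hidden names (starting with '.') by default."""
    return sorted(glob.glob(os.path.join(root, "**", f"*{suffix}"), recursive=True))

def find_with_walk(root: str, suffix: str) -> list[str]:
    """Equivalent search with os.walk, which also descends into hidden directories."""
    found = []
    for dirpath, _, filenames in os.walk(root):
        found.extend(os.path.join(dirpath, name) for name in filenames if name.endswith(suffix))
    return sorted(found)

with tempfile.TemporaryDirectory() as root:
    nested = os.path.join(root, "BraTS2020_training_data", "content", "data")
    os.makedirs(nested)
    for name in ("volume_1_slice_0.h5", "volume_1_slice_1.h5", "metadata.csv"):
        open(os.path.join(nested, name), "w").close()
    h5_via_glob = find_with_glob(root, ".h5")
    h5_via_walk = find_with_walk(root, ".h5")
    nifti_via_glob = find_with_glob(root, ".nii.gz")

print(len(h5_via_glob), len(h5_via_walk), len(nifti_via_glob))  # 2 2 0
```

Because both methods already descend into every subdirectory, pointing either one at a deeper starting directory cannot find additional files.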



In [None]:
import os
import glob

# Set the base path to the top-level directory where the dataset was extracted
base_dataset_path = '/root/.cache/kagglehub/datasets/awsaf49/brats2020-training-data/versions/3/'

print(f"Listing contents of: {base_dataset_path}")
# List direct contents of the base path to inspect structure
try:
    print(os.listdir(base_dataset_path))
except FileNotFoundError:
    print(f"Error: {base_dataset_path} not found.")

data_dicts = []

# Find all NIfTI files recursively within the base_dataset_path
# This is a re-attempt of the previous glob with a diagnostic print
all_nifti_files = glob.glob(os.path.join(base_dataset_path, '**', '*.nii.gz'), recursive=True)
print(f"Found {len(all_nifti_files)} NIfTI files using recursive glob.")

# If glob still finds nothing, use os.walk for explicit enumeration and inspection
if not all_nifti_files:
    print("Recursive glob found no files. Attempting os.walk to find .nii.gz files.")
    for root, _, files in os.walk(base_dataset_path):
        for file in files:
            if file.endswith('.nii.gz'):
                all_nifti_files.append(os.path.join(root, file))
    print(f"Found {len(all_nifti_files)} NIfTI files using os.walk.")
    if all_nifti_files:
        print("First 5 NIfTI files found (if any):\n", '\n'.join(all_nifti_files[:5]))
    else:
        print("No NIfTI files found even with os.walk. Dataset might be missing or path is wrong.")

# Dictionary to store files for each patient ID
patient_files = {}

for filepath in all_nifti_files:
    filename = os.path.basename(filepath)

    # Attempt to extract patient_id. BraTS filenames are typically BraTS20_XXXXX_modality.nii.gz
    # Or, as observed previously, 'data_modality.nii.gz' if a simplified dataset was provided.
    parts = filename.split('_')
    patient_id_found = False
    if len(parts) >= 2: # At least 'ID_modality.nii.gz'
        if parts[0].startswith('BraTS20'): # Typical BraTS ID
            patient_id = '_'.join(parts[:-1])
            patient_id_found = True
        elif parts[0] == 'data' and len(parts) == 2: # e.g., 'data_t1.nii.gz'
            patient_id = parts[0]
            patient_id_found = True

    if patient_id_found:
        patient_dir = os.path.dirname(filepath)

        if patient_id not in patient_files:
            patient_files[patient_id] = {'dir': patient_dir, 'files': []}
        patient_files[patient_id]['files'].append(filepath)
    else:
        print(f"Warning: Could not extract patient ID from {filename}")

# Now, for each unique patient_id, construct the full data dictionary
for patient_id, info in patient_files.items():
    patient_dir = info['dir']

    image_t1 = os.path.join(patient_dir, f"{patient_id}_t1.nii.gz")
    image_t1gd = os.path.join(patient_dir, f"{patient_id}_t1ce.nii.gz")
    image_t2 = os.path.join(patient_dir, f"{patient_id}_t2.nii.gz")
    image_flair = os.path.join(patient_dir, f"{patient_id}_flair.nii.gz")
    label_seg = os.path.join(patient_dir, f"{patient_id}_seg.nii.gz")

    # Ensure all expected files exist for this patient
    # This filters out patients with incomplete data or incorrectly inferred IDs
    if (os.path.exists(image_t1) and os.path.exists(image_t1gd) and
        os.path.exists(image_t2) and os.path.exists(image_flair) and
        os.path.exists(label_seg)): # All modalities and segmentation mask must exist

        data_dicts.append({
            "image": [image_t1, image_t1gd, image_t2, image_flair],
            "label": label_seg
        })

print(f"Final found {len(data_dicts)} patient data entries.")

# Split data into training and validation sets (e.g., 80% train, 20% validation)
# Handle case where data_dicts might be empty to avoid division by zero or IndexError
if len(data_dicts) > 0:
    split_idx = int(len(data_dicts) * 0.8)
    train_files = data_dicts[:split_idx]
    val_files = data_dicts[split_idx:]
else:
    train_files = []
    val_files = []

print(f"Training files: {len(train_files)}")
print(f"Validation files: {len(val_files)}")

# Safely print first entries if lists are not empty
if train_files:
    print("First entry of train_files:", train_files[0])
else:
    print("train_files is empty.")
if val_files:
    print("First entry of val_files:", val_files[0])
else:
    print("val_files is empty.")

Listing contents of: /root/.cache/kagglehub/datasets/awsaf49/brats2020-training-data/versions/3/
['BraTS20 Training Metadata.csv', 'BraTS2020_training_data']
Found 0 NIfTI files using recursive glob.
Recursive glob found no files. Attempting os.walk to find .nii.gz files.
Found 0 NIfTI files using os.walk.
No NIfTI files found even with os.walk. Dataset might be missing or path is wrong.
Final found 0 patient data entries.
Training files: 0
Validation files: 0
train_files is empty.
val_files is empty.


**Reasoning**:
The `os.listdir` output shows that the top-level path contains only `BraTS20 Training Metadata.csv` and the `BraTS2020_training_data` directory. Because `os.walk` is already recursive, restricting the search root to `BraTS2020_training_data` cannot find files that the previous search missed. Re-running the search from that subdirectory only confirms the result. If it still finds no `.nii.gz` files, the dataset does not contain NIfTI files at all, and the next step is to check which file formats it does contain.



In [None]:
import os
import glob

# Corrected base path to the directory containing patient data
base_dataset_path = '/root/.cache/kagglehub/datasets/awsaf49/brats2020-training-data/versions/3/BraTS2020_training_data/'

data_dicts = []

print(f"Listing contents of: {base_dataset_path}")
try:
    # Use os.walk to find all .nii.gz files, assuming they could be nested
    all_nifti_files = []
    for root, _, files in os.walk(base_dataset_path):
        for file in files:
            if file.endswith('.nii.gz'):
                all_nifti_files.append(os.path.join(root, file))
    print(f"Found {len(all_nifti_files)} NIfTI files using os.walk.")
    if all_nifti_files:
        print("First 5 NIfTI files found (if any):\n", '\n'.join(all_nifti_files[:5]))
    else:
        print("No NIfTI files found even with os.walk. Dataset might be missing or path is wrong.")

except FileNotFoundError:
    print(f"Error: {base_dataset_path} not found. Please check the dataset path.")
    all_nifti_files = []

# Dictionary to store files for each patient ID
patient_files = {}

for filepath in all_nifti_files:
    filename = os.path.basename(filepath)

    # Extract patient_id from filename (e.g., BraTS20_00001_t1.nii.gz -> BraTS20_00001)
    # The format can be 'BraTS20_XXXXX_modality.nii.gz' or 'data_modality.nii.gz'
    parts = filename.split('_')
    patient_id_found = False
    patient_id = None

    if len(parts) >= 2:
        # Check for typical BraTS ID format
        if parts[0].startswith('BraTS20'):
            patient_id = '_'.join(parts[:-1])
            patient_id_found = True
        # Check for the simplified 'data_modality.nii.gz' format if it's the only patient
        elif parts[0] == 'data' and len(parts) >= 2 and any(mod in filename for mod in ['t1', 't1ce', 't2', 'flair', 'seg']): # Ensure it's a modality file
            patient_id = parts[0]
            patient_id_found = True

    if patient_id_found and patient_id:
        patient_dir = os.path.dirname(filepath)

        if patient_id not in patient_files:
            patient_files[patient_id] = {'dir': patient_dir, 'files': []}
        patient_files[patient_id]['files'].append(filepath)
    else:
        # This warning helps diagnose if patient IDs are not extracted correctly
        print(f"Warning: Could not reliably extract patient ID from filename: {filename}")

# Now, for each unique patient_id, construct the full data dictionary
for patient_id, info in patient_files.items():
    patient_dir = info['dir']

    # Assuming consistent naming for modalities and segmentation mask
    image_t1 = os.path.join(patient_dir, f"{patient_id}_t1.nii.gz")
    image_t1gd = os.path.join(patient_dir, f"{patient_id}_t1ce.nii.gz")
    image_t2 = os.path.join(patient_dir, f"{patient_id}_t2.nii.gz")
    image_flair = os.path.join(patient_dir, f"{patient_id}_flair.nii.gz")
    label_seg = os.path.join(patient_dir, f"{patient_id}_seg.nii.gz")

    # Ensure all expected files exist for this patient before adding to data_dicts
    if (os.path.exists(image_t1) and os.path.exists(image_t1gd) and
        os.path.exists(image_t2) and os.path.exists(image_flair) and
        os.path.exists(label_seg)):

        data_dicts.append({
            "image": [image_t1, image_t1gd, image_t2, image_flair],
            "label": label_seg
        })

print(f"Final found {len(data_dicts)} patient data entries.")

# Split data into training and validation sets (e.g., 80% train, 20% validation)
if len(data_dicts) > 0:
    split_idx = int(len(data_dicts) * 0.8)
    train_files = data_dicts[:split_idx]
    val_files = data_dicts[split_idx:]
else:
    train_files = []
    val_files = []

print(f"Training files: {len(train_files)}")
print(f"Validation files: {len(val_files)}")

# Safely print first entries if lists are not empty
if train_files:
    print("First entry of train_files:", train_files[0])
else:
    print("train_files is empty.")
if val_files:
    print("First entry of val_files:", val_files[0])
else:
    print("val_files is empty.")

Listing contents of: /root/.cache/kagglehub/datasets/awsaf49/brats2020-training-data/versions/3/BraTS2020_training_data/
Found 0 NIfTI files using os.walk.
No NIfTI files found even with os.walk. Dataset might be missing or path is wrong.
Final found 0 patient data entries.
Training files: 0
Validation files: 0
train_files is empty.
val_files is empty.


### Data Format Mismatch

The current dataset at `/root/.cache/kagglehub/datasets/awsaf49/brats2020-training-data/versions/3/` does not appear to contain any NIfTI (`.nii.gz`) files, which are typically used for medical image segmentation tasks like BraTS. The `os.walk` command, which performs a thorough recursive search, found no such files.

Upon reviewing the kernel state, it seems there are `.h5` (HDF5) files present (e.g., `volume_159_slice_141.h5`). If the dataset provided is indeed in `.h5` format, the subsequent steps of loading and processing NIfTI files will fail.

To proceed, please clarify the expected file format of the dataset or confirm if a different dataset containing NIfTI files should be used. As the subtask explicitly requires processing NIfTI files, I am unable to continue without this clarification or a corrected dataset.

## Prepare Data Paths (Retry: Verify Dataset Format)

### Subtask:
Re-verify the contents of the dataset path to confirm the absence of NIfTI files, and check for the presence of alternative formats like HDF5 (.h5) files, as the previous attempts to locate NIfTI files failed. If no NIfTI files are found, indicate that the task cannot proceed with the current dataset.


**Reasoning**:
I need to explicitly list the contents of the base dataset path, then recursively search for `.nii.gz` and `.h5` files using `os.walk` to verify the dataset format, as previous attempts to locate NIfTI files failed. This will help determine if the data is present in an unexpected format or if the path is still incorrect.
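This export does not include the cell that performed this check, which also defined the `h5_files` list used by the later cells. The following is a sketch of what such a cell could look like. The helper name `survey_extensions` is hypothetical, and the demo runs on a temporary directory rather than the real dataset.

```python
import os
import tempfile
from collections import Counter

def survey_extensions(root: str):
    """Walk `root`, count files by extension, and collect .nii.gz and .h5 paths (both sorted)."""
    counts = Counter()
    nifti_files, h5_files = [], []
    for dirpath, _, filenames in os.walk(root):
        for name in filenames:
            full_path = os.path.join(dirpath, name)
            # os.path.splitext would report ".gz" for ".nii.gz", so check that case first
            ext = ".nii.gz" if name.endswith(".nii.gz") else (os.path.splitext(name)[1] or "<none>")
            counts[ext] += 1
            if ext == ".nii.gz":
                nifti_files.append(full_path)
            elif ext == ".h5":
                h5_files.append(full_path)
    return counts, sorted(nifti_files), sorted(h5_files)

# Demo on a throwaway directory laid out like .../BraTS2020_training_data/content/data
with tempfile.TemporaryDirectory() as demo_root:
    data_dir = os.path.join(demo_root, "BraTS2020_training_data", "content", "data")
    os.makedirs(data_dir)
    for name in ("volume_1_slice_0.h5", "volume_1_slice_1.h5", "metadata.csv"):
        open(os.path.join(data_dir, name), "w").close()
    demo_counts, demo_nifti, demo_h5 = survey_extensions(demo_root)

print(demo_counts)
```

In the notebook, `counts, nifti_files, h5_files = survey_extensions(path)` would report the extension counts and provide `h5_files`. Because the paths are sorted here, the file order, and therefore any positional train/validation split, would differ from the original run.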



## Finish task - Data Format Mismatch

### Subtask:
Summarize the critical issue encountered and state why the task cannot proceed as planned.

**Summary of Issue:**

The core data preparation step has failed because the dataset located at `/root/.cache/kagglehub/datasets/awsaf49/brats2020-training-data/versions/3/` does not contain the expected NIfTI (`.nii.gz`) files, which are essential for this BraTS segmentation task. Instead, the recursive search revealed that the dataset primarily consists of HDF5 (`.h5`) files.

Due to this fundamental data format mismatch, the current task plan, which assumes the availability and processing of NIfTI files for preprocessing, model training, and evaluation using MONAI's NIfTI-specific functionalities, cannot be executed.

Therefore, the task cannot proceed further as designed until either:
1. A dataset in the correct NIfTI format is provided.
2. The task's requirements are adjusted to process the discovered HDF5 files, which would necessitate significant changes to the data loading and preprocessing pipeline.

## Inspect HDF5 Data Structure

### Subtask:
Investigate the internal structure of the discovered HDF5 (.h5) files to understand how the multi-modal images (T1, T1Gd, T2, T2-FLAIR) and their corresponding segmentation masks are stored within each file.


**Reasoning**:
Since I have confirmed that the dataset contains HDF5 files, I will now inspect the internal structure of a sample HDF5 file using `h5py` to understand how the image modalities and segmentation masks are stored, as this is crucial for data loading and preprocessing.



In [None]:
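import glob
import os

import h5py

# `h5_files` is expected from the dataset-verification step, whose code cell is not part of this export.
# If it is not defined, rebuild it with a recursive search under the downloaded dataset `path`.
if "h5_files" not in globals():
    h5_files = glob.glob(os.path.join(path, "**", "*.h5"), recursive=True)

# Select a sample .h5 file from the h5_files list
if h5_files: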
    sample_h5_file = h5_files[0]
    print(f"Inspecting sample HDF5 file: {sample_h5_file}")

    try:
        with h5py.File(sample_h5_file, 'r') as f:
            print("\nKeys within the HDF5 file:")
            for key in f.keys():
                print(f" - {key}")
                if isinstance(f[key], h5py.Dataset):
                    print(f"   Shape: {f[key].shape}, Dtype: {f[key].dtype}")
                elif isinstance(f[key], h5py.Group):
                    # If it's a group, print its immediate members
                    print("   Group members:")
                    for sub_key in f[key].keys():
                        print(f"    - {sub_key}")
                        if isinstance(f[key][sub_key], h5py.Dataset):
                            print(f"      Shape: {f[key][sub_key].shape}, Dtype: {f[key][sub_key].dtype}")

            # Explicitly check for the 'image' and 'mask' keys that the data loader relies on
            print("\nDetailed inspection of potential 'image' and 'mask' keys:")
            if 'image' in f:
                print(f"'image' dataset found. Shape: {f['image'].shape}, Dtype: {f['image'].dtype}")
            else:
                print("No 'image' key found directly.")

            if 'mask' in f:
                print(f"'mask' dataset found. Shape: {f['mask'].shape}, Dtype: {f['mask'].dtype}")
            else:
                print("No 'mask' key found directly.")

    except Exception as e:
        print(f"Error opening or inspecting HDF5 file {sample_h5_file}: {e}")
else:
    print("No HDF5 files found in the 'h5_files' list. Cannot perform inspection.")

Inspecting sample HDF5 file: /root/.cache/kagglehub/datasets/awsaf49/brats2020-training-data/versions/3/BraTS2020_training_data/content/data/volume_237_slice_51.h5

Keys within the HDF5 file:
 - image
   Shape: (240, 240, 4), Dtype: float64
 - mask
   Shape: (240, 240, 3), Dtype: uint8

Detailed inspection of potential 'image' and 'mask' keys:
'image' dataset found. Shape: (240, 240, 4), Dtype: float64
'mask' dataset found. Shape: (240, 240, 3), Dtype: uint8


### Summary of HDF5 Data Structure

The inspection of a sample HDF5 file (`volume_237_slice_51.h5`) revealed that each `.h5` file contains two primary datasets:

1.  **`image`**: shape `(240, 240, 4)`, `dtype` `float64`. This is a single 2D slice whose last dimension (size 4) holds the four MRI modalities (T1, T1Gd, T2, T2-FLAIR) stacked as channels. The file does not record which channel holds which modality, so the channel order should be confirmed against the dataset documentation rather than assumed.
2.  **`mask`**: shape `(240, 240, 3)`, `dtype` `uint8`. This is the corresponding 2D segmentation mask. Its last dimension (size 3) most likely holds one binary mask per tumor sub-region (e.g., enhancing tumor, edema, necrotic core), although the channel-to-label mapping is likewise not stored in the file.

Each HDF5 file therefore represents a single 2D slice of a 3D MRI volume, containing all four modalities and the 3-channel mask for that slice. The filename (e.g., `volume_237_slice_51.h5`) encodes the volume and slice indices. The dataset is thus organized per slice rather than per 3D volume, with the modalities concatenated along the channel dimension of the `image` dataset.
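Because each file holds one slice, recovering a 3D volume means grouping files by volume index and stacking them in slice order. The sketch below assumes the `volume_<v>_slice_<s>.h5` naming inferred from the filenames shown above. It uses synthetic arrays in place of HDF5 reads, and the helpers `parse_slice_name` and `stack_slices` are hypothetical.

```python
import re
import numpy as np

# Assumed naming, inferred from files such as volume_237_slice_51.h5
SLICE_NAME = re.compile(r"volume_(\d+)_slice_(\d+)\.h5$")

def parse_slice_name(filename: str) -> tuple[int, int]:
    """Return (volume_index, slice_index) from a name like 'volume_237_slice_51.h5'."""
    match = SLICE_NAME.search(filename)
    if match is None:
        raise ValueError(f"Unexpected HDF5 filename: {filename}")
    return int(match.group(1)), int(match.group(2))

def stack_slices(slices_by_index: dict[int, np.ndarray]) -> np.ndarray:
    """Stack per-slice (H, W, C) arrays into a channel-first (C, H, W, D) volume."""
    ordered = [slices_by_index[i] for i in sorted(slices_by_index)]
    volume_hwcd = np.stack(ordered, axis=-1)  # (H, W, C, D)
    return np.moveaxis(volume_hwcd, 2, 0)     # (C, H, W, D)

print(parse_slice_name("volume_237_slice_51.h5"))  # (237, 51)

# Synthetic stand-ins with the reported per-slice image shape (not real data)
fake_slices = {i: np.full((240, 240, 4), i, dtype=np.float32) for i in range(3)}
volume = stack_slices(fake_slices)
print(volume.shape)  # (4, 240, 240, 3)
```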

## Prepare Data Paths (for HDF5 files)

### Subtask:
Organize the discovered HDF5 (.h5) file paths and split them into training and validation sets, similar to how NIfTI file paths were intended to be prepared.


**Reasoning**:
I need to organize the HDF5 file paths into dictionaries for training and validation, following the instructions to map both 'image' and 'label' to the HDF5 file path, and then split them into train and validation sets.



In [None]:
data_dicts_h5 = []

# Iterate through the previously identified h5_files list
for filepath in h5_files:
    # Create a dictionary for each HDF5 file
    # 'image' and 'label' both point to the same HDF5 file path
    # as this file contains both image and mask data
    data_dicts_h5.append({
        "image": filepath,
        "label": filepath
    })

print(f"Found {len(data_dicts_h5)} HDF5 data entries.")

# Split data into training and validation sets (e.g., 80% train, 20% validation)
# Handle case where data_dicts_h5 might be empty to avoid errors
if len(data_dicts_h5) > 0:
    split_idx = int(len(data_dicts_h5) * 0.8)
    train_files_h5 = data_dicts_h5[:split_idx]
    val_files_h5 = data_dicts_h5[split_idx:]
else:
    train_files_h5 = []
    val_files_h5 = []

print(f"Training files (HDF5): {len(train_files_h5)}")
print(f"Validation files (HDF5): {len(val_files_h5)}")

# Safely print first entries if lists are not empty
if train_files_h5:
    print("First entry of train_files_h5:", train_files_h5[0])
else:
    print("train_files_h5 is empty.")
if val_files_h5:
    print("First entry of val_files_h5:", val_files_h5[0])
else:
    print("val_files_h5 is empty.")

Found 57195 HDF5 data entries.
Training files (HDF5): 45756
Validation files (HDF5): 11439
First entry of train_files_h5: {'image': '/root/.cache/kagglehub/datasets/awsaf49/brats2020-training-data/versions/3/BraTS2020_training_data/content/data/volume_237_slice_51.h5', 'label': '/root/.cache/kagglehub/datasets/awsaf49/brats2020-training-data/versions/3/BraTS2020_training_data/content/data/volume_237_slice_51.h5'}
First entry of val_files_h5: {'image': '/root/.cache/kagglehub/datasets/awsaf49/brats2020-training-data/versions/3/BraTS2020_training_data/content/data/volume_315_slice_65.h5', 'label': '/root/.cache/kagglehub/datasets/awsaf49/brats2020-training-data/versions/3/BraTS2020_training_data/content/data/volume_315_slice_65.h5'}
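The positional 80/20 split above cuts the file list in whatever order `h5_files` was built. Slices from the same volume (patient) can therefore end up in both subsets, and because neighbouring slices of one brain are highly correlated, this tends to inflate validation scores. The file count is consistent with whole volumes: 57,195 = 369 × 155, i.e. 369 volumes of 155 axial slices each, matching the 369 BraTS 2020 training cases. Below is a sketch of a volume-level split. It assumes the `volume_<v>_slice_<s>.h5` naming seen above, and the helper `split_by_volume` is hypothetical, not part of the original notebook.

```python
import random
import re
from collections import defaultdict

VOLUME_PATTERN = re.compile(r"volume_(\d+)_slice_\d+\.h5$")

def split_by_volume(filepaths, val_fraction=0.2, seed=0):
    """Return (train, val) lists of {"image", "label"} dicts with no volume shared between them."""
    slices_by_volume = defaultdict(list)
    for fp in filepaths:
        match = VOLUME_PATTERN.search(fp)
        if match is None:
            raise ValueError(f"Unexpected HDF5 filename: {fp}")
        slices_by_volume[int(match.group(1))].append(fp)

    volume_ids = sorted(slices_by_volume)
    random.Random(seed).shuffle(volume_ids)  # reproducible shuffle of whole volumes
    n_val = max(1, round(len(volume_ids) * val_fraction))
    val_ids = set(volume_ids[:n_val])

    def to_dicts(ids):
        # Same file for image and label, as in data_dicts_h5
        return [{"image": fp, "label": fp} for v in sorted(ids) for fp in sorted(slices_by_volume[v])]

    return to_dicts(set(volume_ids) - val_ids), to_dicts(val_ids)

# Demo: 10 synthetic volumes x 4 slices each
demo_paths = [f"volume_{v}_slice_{s}.h5" for v in range(10) for s in range(4)]
train_demo, val_demo = split_by_volume(demo_paths)
print(len(train_demo), len(val_demo))  # 32 8
```

In the notebook, `train_files_h5, val_files_h5 = split_by_volume(h5_files)` would replace the positional split. Shuffling volume IDs with a fixed seed keeps the split reproducible without depending on the order of the file listing.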


## Preprocess and Create Data Loaders

### Subtask:
Implement a custom PyTorch `Dataset` and `DataLoader` pipeline for the BraTS HDF5 files. The pipeline reads the multi-modal MRI slices, applies preprocessing such as intensity normalization, converts the arrays to channel-first format, and wraps the training and validation splits prepared above.


**Reasoning**:
I need to define a custom `LoadHDF5d` transform to correctly load the image and mask data from the HDF5 files as identified in the previous step. This transform will open the HDF5 file, extract the 'image' and 'mask' datasets, and ensure they are converted to the correct data types (float32).



In [None]:
import h5py
import numpy as np
from monai.transforms import MapTransform

class LoadHDF5d(MapTransform):
    """Custom transform to load 'image' and 'mask' from an HDF5 file."""

    def __call__(self, data):
        d = dict(data)
        filepath = d["image"] # Both 'image' and 'label' point to the same HDF5 file

        with h5py.File(filepath, 'r') as f:
            # Load image data and convert to float32
            image_data = f['image'][()] # [()] reads the entire dataset into memory
            d["image"] = image_data.astype(np.float32)

            # Load mask data and convert to float32
            mask_data = f['mask'][()]
            d["label"] = mask_data.astype(np.float32) # Use float32 for masks as well for consistency with DiceLoss

        return d

print("Custom LoadHDF5d transform defined.")

Custom LoadHDF5d transform defined.


**Reasoning**:
Now that the custom HDF5 loader is defined, I will define the MONAI transforms for training and validation, and then create the Dataset and DataLoader instances for both, as specified in the instructions.
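`ScaleIntensityRanged` applies a linear intensity window: values in `[a_min, a_max]` are mapped to `[b_min, b_max]`, and with `clip=True` anything outside the window saturates at the bounds. The NumPy sketch below reproduces that arithmetic on a few sample values; `scale_intensity_range` is a hypothetical helper for illustration, not part of MONAI.

```python
import numpy as np

def scale_intensity_range(x, a_min, a_max, b_min, b_max, clip=True):
    """Hypothetical helper: linearly map [a_min, a_max] to [b_min, b_max].

    Mirrors the arithmetic of MONAI's ScaleIntensityRanged for one array.
    """
    y = (x - a_min) / (a_max - a_min)   # a_min -> 0, a_max -> 1
    y = y * (b_max - b_min) + b_min     # rescale to the target range
    if clip:
        y = np.clip(y, b_min, b_max)    # saturate values outside the window
    return y.astype(np.float32)

# Same window as the transforms in this notebook: [-500, 3000] -> [0, 1]
x = np.array([-1000.0, -500.0, 1250.0, 3000.0, 5000.0], dtype=np.float32)
out = scale_intensity_range(x, -500.0, 3000.0, 0.0, 1.0)
print(out.tolist())  # [0.0, 0.0, 0.5, 1.0, 1.0]
```

Values below `a_min` and above `a_max` collapse to the same output, so the window should cover the range where the MRI intensities actually carry information.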



In [None]:
from monai.transforms import Compose, EnsureChannelFirstd, ScaleIntensityRanged, Orientationd, CropForegroundd
from monai.data import Dataset, DataLoader

# 2. Define train_transforms
train_transforms = Compose(
    [
        LoadHDF5d(keys=["image", "label"]), # Custom transform to load HDF5 data
        EnsureChannelFirstd(keys=["image", "label"], channel_dim=2), # Change (H, W, C) to (C, H, W)
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-500.0,
            a_max=3000.0,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        # Add other transforms as needed for training (e.g., augmentation, spatial transforms)
        # For BraTS, typical transforms might include SpatialPadd, RandCropByPosNegLabeld, RandFlipd, RandRotated, etc.
        # For now, keeping it simple as per the general instruction for preprocessing.
    ]
)

# 3. Define val_transforms
# Validation uses the same preprocessing as training; random augmentations, if added, belong only in train_transforms.
val_transforms = Compose(
    [
        LoadHDF5d(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-500.0,
            a_max=3000.0,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
    ]
)

# 4. Create train_ds and val_ds
train_ds = Dataset(data=train_files_h5, transform=train_transforms)
val_ds = Dataset(data=val_files_h5, transform=val_transforms)

print(f"Training dataset size: {len(train_ds)}")
print(f"Validation dataset size: {len(val_ds)}")

# 5. Create train_loader and val_loader
batch_size = 8 # Can be adjusted based on GPU memory
num_workers = 4 # Can be adjusted based on CPU cores

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

print(f"Training DataLoader created with batch size {batch_size} and {num_workers} workers.")
print(f"Validation DataLoader created with batch size {batch_size} and {num_workers} workers.")

# Optionally, test fetching a batch
for batch_data in train_loader:
    print(f"Sample train batch image shape: {batch_data['image'].shape}")
    print(f"Sample train batch label shape: {batch_data['label'].shape}")
    break

Training dataset size: 45756
Validation dataset size: 11439
Training DataLoader created with batch size 8 and 4 workers.
Validation DataLoader created with batch size 8 and 4 workers.
Sample train batch image shape: torch.Size([8, 4, 240, 240])
Sample train batch label shape: torch.Size([8, 3, 240, 240])


**Reasoning**:
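Although a training batch loaded, the previous cell's `val_transforms` calls `EnsureChannelFirstd` without `channel_dim`. The HDF5 arrays are plain NumPy arrays with no metadata, so MONAI cannot infer which axis of the `(H, W, C)` data holds the channels, and loading a validation batch fails. Specifying `channel_dim=-1` (the same axis as `channel_dim=2` for these 3-D arrays) moves the channels to the front, giving `(C, H, W)`.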
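Conceptually, `EnsureChannelFirstd(channel_dim=-1)` moves the last axis to the front. The NumPy sketch below shows the same reshaping on synthetic arrays; the shapes `(240, 240, 4)` and `(240, 240, 3)` are assumed from the batch shapes printed above, and the array contents are placeholders.

```python
import numpy as np

# Synthetic stand-ins for one HDF5 slice: 4 MRI modalities, 3 mask channels.
# Filled with distinct values so the axis move can be checked element-wise.
image_hwc = np.arange(240 * 240 * 4, dtype=np.float32).reshape(240, 240, 4)
mask_hwc = np.zeros((240, 240, 3), dtype=np.float32)

# Move the last (channel) axis to the front: (H, W, C) -> (C, H, W)
image_chw = np.moveaxis(image_hwc, -1, 0)
mask_chw = np.moveaxis(mask_hwc, -1, 0)

# For a 3-D array, channel_dim=2 and channel_dim=-1 name the same axis
assert np.array_equal(np.moveaxis(image_hwc, 2, 0), image_chw)

print(image_chw.shape, mask_chw.shape)  # (4, 240, 240) (3, 240, 240)
```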



In [None]:
from monai.transforms import Compose, EnsureChannelFirstd, ScaleIntensityRanged, Orientationd, CropForegroundd
from monai.data import Dataset, DataLoader

# 2. Define train_transforms
train_transforms = Compose(
    [
        LoadHDF5d(keys=["image", "label"]), # Custom transform to load HDF5 data
        EnsureChannelFirstd(keys=["image", "label"], channel_dim=2), # Change (H, W, C) to (C, H, W), explicitly stating channel_dim
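        # NOTE: the [-500, 3000] window is a fixed choice; check it against the actual
        # intensity range of these HDF5 images, since values outside it are clipped.
        ScaleIntensityRanged(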
            keys=["image"],
            a_min=-500.0,
            a_max=3000.0,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        # Add other transforms as needed for training (e.g., augmentation, spatial transforms)
        # For BraTS, typical transforms might include SpatialPadd, RandCropByPosNegLabeld, RandFlipd, RandRotated, etc.
        # For now, keeping it simple as per the general instruction for preprocessing.
    ]
)

# 3. Define val_transforms
# Validation uses the same preprocessing as training; random augmentations, if added, belong only in train_transforms.
val_transforms = Compose(
    [
        LoadHDF5d(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-500.0,
            a_max=3000.0,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
    ]
)

# 4. Create train_ds and val_ds
train_ds = Dataset(data=train_files_h5, transform=train_transforms)
val_ds = Dataset(data=val_files_h5, transform=val_transforms)

print(f"Training dataset size: {len(train_ds)}")
print(f"Validation dataset size: {len(val_ds)}")

# 5. Create train_loader and val_loader
batch_size = 8 # Can be adjusted based on GPU memory
num_workers = 4 # Can be adjusted based on CPU cores

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

print(f"Training DataLoader created with batch size {batch_size} and {num_workers} workers.")
print(f"Validation DataLoader created with batch size {batch_size} and {num_workers} workers.")

# Optionally, test fetching a batch
for batch_data in train_loader:
    print(f"Sample train batch image shape: {batch_data['image'].shape}")
    print(f"Sample train batch label shape: {batch_data['label'].shape}")
    break

Training dataset size: 45756
Validation dataset size: 11439
Training DataLoader created with batch size 8 and 4 workers.
Validation DataLoader created with batch size 8 and 4 workers.
Sample train batch image shape: torch.Size([8, 4, 240, 240])
Sample train batch label shape: torch.Size([8, 3, 240, 240])


**Reasoning**:
The previous code cell executed successfully, creating the data loaders and confirming the batch shapes: images are `(8, 4, 240, 240)` and labels are `(8, 3, 240, 240)`. Any `num_workers` warnings are advisory and do not require a code fix. The next step is to define the UNET model for brain tumor segmentation.



## Define UNET Model

### Subtask:
Instantiate a UNET model from MONAI's `monai.networks.nets` module with appropriate input channels (4 for multi-modal MRI), output channels (3 for the segmentation mask), and a suitable number of feature maps and layers.


**Reasoning**:
The UNet is instantiated in a Python code cell, with input and output channels matching the HDF5 data structure (4 image channels, 3 mask channels).



In [None]:
from monai.networks.nets import UNet

# Define the UNET model
# Input channels: 4 (T1, T1Gd, T2, FLAIR)
# Output channels: 3 (different tumor subregions)
# spatial_dims: 2 for 2D slices
# channels: list of feature map numbers for each layer
# strides: list of strides for each layer

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = UNet(
    spatial_dims=2,
    in_channels=4,  # T1, T1Gd, T2, FLAIR
    out_channels=3, # Necrotic/Non-enhancing tumor core, Edema, Enhancing tumor
    channels=(16, 32, 64, 128, 256), # Example feature map sizes
    strides=(2, 2, 2, 2), # Example strides for downsampling
).to(device)

print(f"UNET model created and moved to {device}.")
print(f"Model input channels: {model.in_channels}")
print(f"Model output channels: {model.out_channels}")
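With `strides=(2, 2, 2, 2)`, each 240×240 slice is halved four times, so the coarsest feature maps are 15×15. A common rule of thumb is that the input height and width should be divisible by the product of the strides (here 16) so the decoder can upsample back to the exact input size for the skip connections. The small sketch below checks this arithmetic; `encoder_sizes` is a hypothetical helper and assumes each stride-2 stage halves the size exactly.

```python
from math import prod

def encoder_sizes(size, strides):
    """Hypothetical helper: spatial size after each strided downsampling stage."""
    sizes = [size]
    for s in strides:
        sizes.append(sizes[-1] // s)  # exact halving when divisible
    return sizes

strides = (2, 2, 2, 2)   # same strides as the UNet above
size = 240               # slice height/width from the printed batch shapes

print(encoder_sizes(size, strides))   # [240, 120, 60, 30, 15]
print(size % prod(strides) == 0)      # True: 240 is divisible by 16
```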

## Set Up Training Components

### Subtask:
Configure the essential components for model training: select an appropriate loss function (e.g., DiceLoss or a combination from MONAI), choose an optimizer (e.g., Adam), and set up the device (CPU or GPU) for computation.


**Reasoning**:
The device has already been set up in the previous step. Now I need to define the loss function using `DiceLoss` from `monai.losses` and the optimizer using `torch.optim.Adam` with the model's parameters and a learning rate.



In [None]:
from monai.losses import DiceLoss
import torch.optim as optim

# Device setup was already done in the previous step, ensuring model is on the correct device.
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 2. Define the loss function
# DiceLoss for multi-label segmentation: each output channel is an independent binary prediction,
# so sigmoid=True. to_onehot_y=False because the mask already has one channel per tumor subregion.
loss_function = DiceLoss(sigmoid=True, to_onehot_y=False)

# 3. Define the optimizer
# Using Adam optimizer with a learning rate of 1e-4
optimizer = optim.Adam(model.parameters(), lr=1e-4)

print(f"Loss function defined: {loss_function.__class__.__name__}")
print(f"Optimizer defined: {optimizer.__class__.__name__}")
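To make the loss concrete, the NumPy sketch below computes a per-channel soft Dice loss with a sigmoid on the logits, following the formula described in MONAI's `DiceLoss` documentation under its default settings (no squared predictions, per-sample sums over the spatial axes, smoothing constants of 1e-5, mean reduction). It is an illustration under those assumptions, not MONAI's implementation; `soft_dice_loss` is a hypothetical name.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def soft_dice_loss(logits, target, smooth_nr=1e-5, smooth_dr=1e-5):
    """Hypothetical soft Dice loss, averaged over batch and channels.

    logits, target: arrays of shape (B, C, H, W); target holds 0/1 values.
    Each channel is an independent binary mask (multi-label), so a sigmoid
    is applied per element instead of a softmax across channels.
    """
    pred = sigmoid(logits)
    spatial = tuple(range(2, logits.ndim))              # sum over H, W only
    intersection = np.sum(pred * target, axis=spatial)  # shape (B, C)
    denominator = np.sum(pred, axis=spatial) + np.sum(target, axis=spatial)
    dice = (2.0 * intersection + smooth_nr) / (denominator + smooth_dr)
    return float(np.mean(1.0 - dice))                   # mean over (B, C)

# Toy 4x4 example with some foreground in each of the 3 channels
target = np.zeros((1, 3, 4, 4))
target[0, 0, :2, :2] = 1.0
target[0, 1, 2:, 2:] = 1.0
target[0, 2, :2, 2:] = 1.0
good_logits = np.where(target > 0, 10.0, -10.0)  # confident and correct
bad_logits = -good_logits                        # confident and inverted

good_loss = soft_dice_loss(good_logits, target)
bad_loss = soft_dice_loss(bad_logits, target)
print(f"good: {good_loss:.4f}, bad: {bad_loss:.4f}")
```

Because every channel in `target` contains some foreground, correct logits give a loss near 0 and inverted logits a loss near 1.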

## Train the UNET Model

### Subtask:
Implement the training loop for the UNET model, including forward and backward passes, loss calculation, weight updates, and periodic validation and checkpointing.


**Reasoning**:
I need to set up the training loop as described, including defining epochs, initializing the DiceMetric, and then creating the main training loop with forward/backward passes, loss calculation, optimization, and a validation phase with metric calculation and checkpointing.



In [None]:
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
import torch
import time
import copy

# 1. Set the number of training epochs
max_epochs = 50 # Reduced for demonstration, can be increased for full training

# 2. Initialize DiceMetric for evaluation.
# The mask has 3 channels, each an independent binary mask for one tumor subregion
# (multi-label, with no background channel), so every channel is a real class.
# include_background=True keeps channel 0 in the score; include_background=False
# would silently drop the first tumor subregion from the metric.
# reduction="mean_batch" returns one Dice score per channel, averaged over the batch.
dice_metric = DiceMetric(include_background=True, reduction="mean_batch")

# Keep track of best metric and best model weights
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []

# Create a copy of the model's state_dict to save the best model weights
best_model_wts = copy.deepcopy(model.state_dict())

print("Starting training...")

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train() # Set model to training mode
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        # print(f"{step}/{len(train_loader)}, train_loss: {loss.item():.4f}")
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    # Validation phase
    if (epoch + 1) % 5 == 0 or epoch == max_epochs -1: # Validate every 5 epochs or on the last epoch
        model.eval() # Set model to evaluation mode
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = val_data["image"].to(device), val_data["label"].to(device)
                # The 2D slices fit in memory, so direct inference is used instead of sliding-window inference.
                val_outputs = model(val_inputs)

                # Convert logits to binary masks: sigmoid per channel (matching DiceLoss(sigmoid=True)),
                # then threshold at 0.5, because DiceMetric expects binarized predictions.
                val_outputs = (torch.sigmoid(val_outputs) > 0.5).float()

                # val_labels is already (B, 3, H, W) with one 0/1 mask per channel,
                # so no one-hot conversion is needed.
                dice_metric(y_pred=val_outputs, y=val_labels)

            # With reduction="mean_batch", aggregate() returns one score per channel;
            # average them into a single scalar for model selection.
            metric_per_class = dice_metric.aggregate()
            metric = metric_per_class.mean().item()
            metric_values.append(metric)
            dice_metric.reset()

            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(best_model_wts, "best_metric_model.pth")
                print(f"saved new best metric model at epoch {best_metric_epoch} with dice: {best_metric:.4f}")
            print(f"Current epoch {epoch + 1} validation dice: {metric:.4f}")
            print(f"Best validation dice: {best_metric:.4f} at epoch {best_metric_epoch}")

print(f"Training completed. Best_metric: {best_metric:.4f} at epoch {best_metric_epoch}")

# Load the best model weights
model.load_state_dict(best_model_wts)
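The validation score above is a per-channel hard Dice on thresholded predictions. The NumPy sketch below shows that arithmetic on toy data, assuming samples whose ground truth is empty for a channel are skipped rather than scored (the behavior of `DiceMetric` with its default `ignore_empty=True`). `per_class_dice` is a hypothetical helper, intended to be similar in spirit to `reduction="mean_batch"`, not a reimplementation of MONAI.

```python
import numpy as np

def per_class_dice(probs, target, threshold=0.5):
    """Hypothetical hard Dice per class, skipping empty-ground-truth samples.

    probs, target: arrays of shape (B, C, H, W). Predictions are binarized
    at `threshold` before scoring. Returns an array of shape (C,).
    """
    pred = (probs > threshold).astype(np.float64)
    spatial = tuple(range(2, probs.ndim))
    intersection = np.sum(pred * target, axis=spatial)
    denominator = np.sum(pred, axis=spatial) + np.sum(target, axis=spatial)
    target_sum = np.sum(target, axis=spatial)
    with np.errstate(invalid="ignore", divide="ignore"):
        # NaN marks (sample, class) pairs with no ground-truth foreground
        dice = np.where(target_sum > 0, 2.0 * intersection / denominator, np.nan)
    return np.nanmean(dice, axis=0)  # average over the batch, per class

target = np.zeros((2, 3, 4, 4))
target[:, 0, :2, :2] = 1.0   # class 0 present in both samples
target[0, 1, 2:, 2:] = 1.0   # class 1 present only in sample 0
target[1, 2, :, :1] = 1.0    # class 2 present only in sample 1
probs = target * 0.9         # correct after thresholding...
probs[0, 1] = 0.0            # ...except class 1 is missed in sample 0

scores = per_class_dice(probs, target)
print(scores.tolist())       # [1.0, 0.0, 1.0]
print(float(np.mean(scores)))  # single scalar used for model selection
```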