 # Environment Setup

 **Kaggle Environment Requirements**:
 - Python 3.11.11
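 - Key Packages:
   - tensorflow==2.18.0
   - antspynet==0.3.0 (installed from the ValV/ANTsPyNet GitHub repository)
   - antspyx (provides the `ants` module; installed as an ANTsPyNet dependency)
   - nibabel==5.3.2
   - plotly==5.24.1
   - scipy==1.15.2
   - matplotlib
   - scikit-learn
   - statsmodels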

 **Installation Command**:
 ```
 %pip install git+https://github.com/ValV/ANTsPyNet.git
 %pip install tensorflow==2.18.0 nibabel==5.3.2 plotly==5.24.1 scipy==1.15.2 scikit-learn statsmodels
 ```

 # Dependencies

 Install ANTsPyNet (and the packages it depends on) from GitHub; `%%capture` suppresses the installation output.

In [None]:
%%capture
# Install ANTsPyNet from the ValV/ANTsPyNet GitHub repository; %%capture hides pip's output.
%pip install git+https://github.com/ValV/ANTsPyNet.git

 # Data Processing

 ## Utility Functions

 Functions for processing MRI data and extracting surface coordinates

In [None]:
import os
from glob import glob
from os import path as osp

import nibabel as nib
import numpy as np

from scipy.ndimage import binary_erosion, zoom


def get_surface_coords(mask: np.ndarray) -> np.ndarray:
    """
    Returns coordinates of surface voxels of a binary 3D mask

    A voxel is on the surface when it is foreground in ``mask`` but is
    removed by one binary erosion (SciPy's default structuring element).
    Erosion treats everything outside the array as background, so foreground
    voxels on the array border always count as surface.

    Args:
        mask: 3D array; any non-zero / True voxel is treated as foreground

    Returns:
        coords: (N, mask.ndim) integer array of surface voxel indices
    """
    # Cast to bool first: XOR on a non-boolean mask (e.g. label value 2)
    # would keep interior voxels (2 ^ True == 3), and float masks do not
    # support ``^`` at all.
    mask = np.asarray(mask, dtype=bool)
    eroded = binary_erosion(mask)
    # Foreground voxels that did not survive erosion form the surface.
    return np.argwhere(mask & ~eroded)


def downsample(coords: np.ndarray, factor: int = 16) -> np.ndarray:
    """
    Thin out an array by keeping every ``factor``-th row, starting at row 0

    Args:
        coords: Array of 3D coordinates (any array that supports slicing)
        factor: Step between kept rows

    Returns:
        The strided selection ``coords[0], coords[factor], ...``
    """
    every_nth = slice(None, None, factor)
    return coords[every_nth]

 ## Data Loading

 Load sample MRI scans from CONTROLS and PATIENTS directories

In [None]:
PATH_PREFIX = osp.join('/', 'kaggle', 'input', 'brainsearch-classification')
index = 4  # position (in sorted filename order) of the scan shown per group

# Load MRI scans. glob() returns paths in filesystem-dependent order, so sort
# them to make `index` select the same scan on every run.
path_controls = sorted(glob(osp.join(PATH_PREFIX, 'CONTROLS', '*.nii')))
path_patients = sorted(glob(osp.join(PATH_PREFIX, 'PATIENTS', '*.nii')))

source_controls = nib.load(path_controls[index])
source_patients = nib.load(path_patients[index])

# Get image data as numpy arrays
image_control = source_controls.get_fdata()
image_patient = source_patients.get_fdata()

# Extract surface coordinates (every voxel with intensity > 0 is foreground)
coords_control = get_surface_coords(image_control > 0)
coords_patient = get_surface_coords(image_patient > 0)

 ## Interactive 3D Visualization

 Visualize brain surface using Plotly (requires Plotly installation)

In [None]:
import plotly.graph_objects as go


def scatter3d(coords: np.ndarray, color: str, name: str) -> go.Scatter3d:
    """
    Build a Plotly 3D marker trace from an (N, 3) coordinate array

    Args:
        coords: Array whose first three columns are used as x, y and z
        color: Marker color
        name: Trace name shown in the legend

    Returns:
        Plotly Scatter3d object with 2-px markers and hover disabled
    """
    xs, ys, zs = coords[:, 0], coords[:, 1], coords[:, 2]
    marker_style = dict(size=2, color=color)
    return go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode='markers',
        marker=marker_style,
        name=name,
        hoverinfo='skip',
    )


# Interactive figure: every 24th control surface voxel, drawn in gray
control_points = downsample(coords_control, 24)
figure = go.Figure(data=[scatter3d(control_points, 'gray', 'Control')])

# Hide each axis together with its background pane so only the cloud shows
hidden_axis = dict(visible=False, showbackground=False)
figure.update_layout(
    title='3D Scatter Plot of a Brain (Surface Voxels)',
    scene=dict(
        xaxis=dict(hidden_axis),
        yaxis=dict(hidden_axis),
        zaxis=dict(hidden_axis),
    ),
    margin=dict(l=0, r=0, b=0, t=40),
)

figure.show()

 ## Multi-view Visualization

 Show sagittal, axial, and coronal views using Matplotlib

In [None]:
import matplotlib.pyplot as plt


def plot_angle(ax, coords: np.ndarray, color: str, label: str):
    """
    Scatter 3D points on a Matplotlib 3D axis with the columns reversed

    The last column of ``coords`` goes to the plot's x axis and the first
    to its z axis. The camera angle is set separately by the caller.

    Args:
        ax: Matplotlib 3D axis object
        coords: Array of 3D coordinates, one point per row
        color: Color for the points
        label: Label for the legend
    """
    reversed_columns = list(reversed(coords.T))
    ax.scatter(*reversed_columns, s=1, c=color, label=label)


# Multi-view visualization: the same two point clouds in three 3D subplots,
# each seen from its own camera angle.
# NOTE(review): the view names assume the raw scan's voxel axis order (which
# plot_angle reverses) — confirm they match the rendered planes.
fig = plt.figure(figsize=(18, 6))
angles = ((90, -180), (0, 0), (0, 90))  # (elevation, azimuth)
views = ('Sagittal', 'Axial', 'Coronal')

for position, ((elevation, azimuth), view) in enumerate(
    zip(angles, views), start=1
):
    ax = fig.add_subplot(1, 3, position, projection='3d')
    # Patients are thinned less (every 12th voxel) than controls (every 24th)
    plot_angle(ax, downsample(coords_patient, 12), 'yellow', 'Patient')
    plot_angle(ax, downsample(coords_control, 24), 'gray', 'Control')
    ax.view_init(elev=elevation, azim=azimuth)
    ax.set_title(view)
    # Identical fixed bounds on every panel keep the views comparable
    ax.set_xlim(0, 200)
    ax.set_ylim(0, 200)
    ax.set_zlim(0, 150)
    ax.axis('off')

plt.tight_layout()
plt.show()

 # Dataset Preparation

 Custom dataset class for loading and preprocessing MRI data

In [None]:
import random
import tensorflow as tf
import ants
from urllib.request import urlretrieve


# Passed to prefetch() so tf.data chooses the buffer size dynamically.
AUTOTUNE = tf.data.AUTOTUNE
# Axis codes every scan is reoriented to in BrainMRIDataset._preprocess.
TARGET_ORIENTATION = 'RAS'
# Local cache of the MNI152 template used for registration.
TEMPLATE_PATH = './MNI152_T1_1mm.nii.gz'
# Download source used when TEMPLATE_PATH does not exist yet.
TEMPLATE_URL = 'https://github.com/Jfortin1/MNITemplate/raw/master/inst/extdata/MNI152_T1_1mm.nii.gz'


class BrainMRIDataset:
    """
    Custom dataset for loading and preprocessing brain MRI scans

    Key Features:
    - Handles both control and patient data
    - Performs spatial normalization to MNI152 template
    - Supports template-based registration
    - Generates TensorFlow datasets

    Args:
        path_prefix: Root directory containing 'CONTROLS' and 'PATIENTS' folders
        use_template: Whether to use template-based normalization (bool or path to custom template)
        shuffle: Whether to shuffle samples
    """

    def __init__(
        self,
        path_prefix: str,
        use_template: bool | str = False,
        shuffle: bool = False,
    ):
        self.shapes = set()
        self.samples = []
        self.label_map = {}

        # Load MNI152 template if needed
        if not osp.isfile(TEMPLATE_PATH):
            urlretrieve(TEMPLATE_URL, TEMPLATE_PATH)
        self.template = ants.image_read(TEMPLATE_PATH)

        # Configure normalization
        self.normalize = False
        if isinstance(use_template, bool):
            self.normalize = use_template
        elif osp.isfile(use_template):
            self.template = ants.image_read(use_template)
            self.normalize = True

        # Load and preprocess data
        data = []
        for idx, label in enumerate(sorted(ls(path_prefix))):
            if not osp.isdir(osp.join(path_prefix, label)):
                continue
            path_class = osp.join(path_prefix, label)
            self.label_map[idx] = label
            data.append((path_class, idx))

        # Process all MRI files
        c = 0
        for path_class, idx in data:
            for i, nii in enumerate(glob(osp.join(path_class, '*.nii'))):
                array, label = self._preprocess(nii, idx)
                self.shapes.add(array.shape)
                self.samples.append((array, label))
                print(f"{i + 1:03d}:{c + 1:03d} Loaded '{nii}'")
                c += 1

        # Handle dataset shuffling
        if shuffle:
            random.shuffle(self.samples)

        self.num_classes = len(self.label_map)

    def __len__(self):
        return len(self.samples)

    def _preprocess(self, path: str, label: int) -> tuple:
        """
        Preprocess MRI scan:
        1. Reorient to RAS coordinate system
        2. Perform template-based normalization (if enabled)
        3. Intensity normalization (z-score)

        Args:
            path: Path to .nii file
            label: Class label (0=control, 1=patient)

        Returns:
            tuple: (preprocessed_array, label)
        """
        # Load NIfTI file
        img = nib.load(path)
        data = img.get_fdata()

        # Reorient to RAS coordinate system
        orient = nib.orientations.io_orientation(img.affine)
        transform = nib.orientations.ornt_transform(
            orient, nib.orientations.axcodes2ornt(TARGET_ORIENTATION)
        )
        data = nib.orientations.apply_orientation(data, transform)
        affine = img.affine.dot(
            nib.orientations.inv_ornt_aff(transform, img.shape)
        )

        # ANTs-based processing
        itk = ants.from_numpy(data, spacing=img.header.get_zooms())
        itk.set_origin(list(affine[:3, 3]))

        # Template-based normalization
        if self.normalize and self.template is not None:
            res = ants.registration(
                fixed=self.template, moving=itk, type_of_transform='Affine'
            )
            warped = ants.resample_image(
                res['warpedmovout'],
                self.template.shape,
                use_voxels=True,
                interp_type=0,
            )
            arr = warped.numpy()
        else:
            arr = itk.numpy()

        # Intensity normalization (z-score)
        arr = (arr - arr.mean()) / (arr.std() + 1e-6)
        arr = np.expand_dims(
            arr, axis=-1
        )  # Add channel dimension: (D, H, W, 1)

        return arr, label

    def __getitem__(self, index: int) -> tuple:
        return self.samples[index]

    def generator(self):
        """Generator function for tf.data API"""
        for i in range(len(self)):
            yield self[i]


def build_dataset(
    path_root: str,
    batch_size: int = 2,
    use_template: bool | str = False,
    shuffle: bool = True,
) -> tuple:
    """
    Build TensorFlow dataset from brain MRI data

    Args:
        path_root: Root directory containing MRI data
        batch_size: Batch size for training
        use_template: Template normalization: True, False, or a path to a
            custom template file (forwarded to BrainMRIDataset)
        shuffle: Whether to shuffle data

    Returns:
        tuple: (tf.data.Dataset, BrainMRIDataset instance)

    Raises:
        ValueError: If the loaded scans do not all share one shape.
    """
    brains = BrainMRIDataset(
        path_root, use_template=use_template, shuffle=shuffle
    )

    # from_generator needs one fixed element shape; scans of different sizes
    # (e.g. loaded without template normalization) cannot be batched.
    if len(brains.shapes) != 1:
        raise ValueError(
            f"Expected all scans to share one shape, got {sorted(brains.shapes)}"
        )
    (shape,) = brains.shapes

    dataset = tf.data.Dataset.from_generator(
        brains.generator,
        output_signature=(
            tf.TensorSpec(shape=shape, dtype=tf.float32),
            tf.TensorSpec(shape=(), dtype=tf.int32),
        ),
    )
    if shuffle:
        # Keep the shuffled order fixed across passes: callers split this
        # dataset with take()/skip(), and reshuffling on every iteration
        # would mix validation samples into the training split.
        dataset = dataset.shuffle(buffer_size=64, reshuffle_each_iteration=False)
    dataset = dataset.batch(batch_size).prefetch(AUTOTUNE)
    return dataset, brains

 # Model Definition

 Create 3D ResNet model using ANTsPyNet

In [None]:
import antspynet


def build_model(
    num_classes: int,
    shape: tuple = (128, 128, 128, 1),
    learning_rate: float = 1e-4,
) -> tf.keras.Model:
    """
    Build and compile an ANTsPyNet 3D ResNet classifier

    The network comes from ``antspynet.create_resnet_model_3d`` in
    'classification' mode with squeeze-and-excitation blocks enabled. It is
    compiled with Adam, sparse categorical cross-entropy (integer labels)
    and an accuracy metric.

    Args:
        num_classes: Number of output classes (2 for control/patient)
        shape: Input volume shape (depth, height, width, channels)
        learning_rate: Learning rate for the Adam optimizer

    Returns:
        Compiled Keras model
    """
    # NOTE(review): in ANTsPyNet, `layers` appears to set per-stage filter
    # exponents (filters ~ lowest_resolution * 2**layer) while the number of
    # residual blocks per stage is `residual_block_schedule` — confirm that
    # (2, 2, 2, 2) here is intended rather than a block schedule.
    architecture = dict(
        input_image_size=shape,
        number_of_outputs=num_classes,
        layers=(2, 2, 2, 2),
        lowest_resolution=32,  # filter count of the first convolution
        mode='classification',
        squeeze_and_excite=True,  # channel attention
    )
    network = antspynet.create_resnet_model_3d(**architecture)

    optimizer = tf.keras.optimizers.Adam(learning_rate)
    network.compile(
        optimizer=optimizer,
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return network

 # Training Setup

 Configure training parameters and data pipeline

In [None]:
from datetime import datetime


# Configuration
root = osp.join('/', 'kaggle', 'input', 'brainsearch-classification')
batch_size = 2
epochs = 20

# Create log directory (one timestamped directory per run)
path_logs = osp.join('runs', datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
callback_tensorboard = tf.keras.callbacks.TensorBoard(
    log_dir=path_logs,
    update_freq='epoch',
)
writer_log = tf.summary.create_file_writer(path_logs)

# Build dataset
tfds, ds = build_dataset(
    root, batch_size=batch_size, use_template=False, shuffle=True
)

# Train/validation split (80/20), counted in batches.
# NOTE: take()/skip() give disjoint splits only if the dataset order is
# stable between passes.
num_batches = -(-len(ds) // batch_size)  # ceil: the last batch may be partial
index_split = int(0.8 * (len(ds) // batch_size))
if not 0 < index_split < num_batches:
    raise ValueError(
        f"Cannot split {len(ds)} samples (batch_size={batch_size}) into "
        "non-empty training and validation sets"
    )
dataset_train = tfds.take(index_split).prefetch(AUTOTUNE)
dataset_valid = tfds.skip(index_split).prefetch(AUTOTUNE)

# Initialize model. The dataset records the set of observed sample shapes;
# the model input needs the single shape they all share.
if len(ds.shapes) != 1:
    raise ValueError(f"Samples do not share a single shape: {sorted(ds.shapes)}")
(input_shape,) = ds.shapes
model = build_model(ds.num_classes, shape=input_shape, learning_rate=5e-5)
model.summary()

 # Training Execution

 Train the 3D ResNet model

In [None]:
# Train the 3D ResNet. Both splits are finite, so each epoch is one full
# pass over the training batches followed by one pass over the validation
# batches. No step counts have to be kept in sync with the dataset size,
# and no validation batch is evaluated more than once per epoch.
model.fit(
    dataset_train,
    validation_data=dataset_valid,
    epochs=epochs,
    callbacks=[callback_tensorboard],
)

# Save trained model
model.save(osp.join(path_logs, 'antspynet_3d_classifier.h5'))

 # Model Interpretation

 Visualize model decisions using Grad-CAM

In [None]:
import io
from tensorflow.keras.models import Model


def make_gradcam_heatmap(
    volume: tf.Tensor,
    model: tf.keras.Model,
    last_conv_layer_name: str,
    pred_index: int = None,
) -> np.ndarray:
    """
    Generate Grad-CAM heatmap for 3D volume

    Gradients of the class score with respect to the chosen convolution
    layer's output are averaged into per-channel weights. The ReLU of the
    weighted channel sum of the first batch item's feature map is the
    heatmap. It is scaled to [0, 1] and resampled onto the input grid.

    Args:
        volume: Input volume (batch_size, D, H, W, C)
        model: Trained Keras model
        last_conv_layer_name: Name of the convolutional layer to explain
        pred_index: Class index to generate heatmap for (default: class
            predicted for the first batch item)

    Returns:
        heatmap: (D, H, W, 1) float array in [0, 1], aligned with volume[0]
    """
    # model.inputs is already a list of input tensors; pass it as-is rather
    # than nesting it inside another list.
    grad_model = Model(
        model.inputs,
        [model.get_layer(last_conv_layer_name).output, model.output],
    )

    # Compute gradients
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(volume)
        if pred_index is None:
            pred_index = tf.argmax(predictions[0])
        class_channel = predictions[:, pred_index]

    # Channel weights: gradients averaged over the batch and spatial axes
    grads = tape.gradient(class_channel, conv_outputs)
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2, 3))

    # Weighted channel sum of the first item's feature map -> (d, h, w, 1).
    # keepdims avoids tf.squeeze, which would also drop size-1 spatial axes.
    heatmap = tf.reduce_sum(
        conv_outputs[0] * pooled_grads, axis=-1, keepdims=True
    )

    # Keep positive evidence only and scale to [0, 1]; the epsilon avoids a
    # division by zero when no voxel has positive evidence.
    heatmap = tf.maximum(heatmap, 0)
    heatmap = heatmap / (tf.math.reduce_max(heatmap) + 1e-8)

    # tf.image.resize resamples only two spatial axes, so upsample in two
    # passes: (h, w) with depth as the batch axis, then (d, w) with height
    # as the batch axis.
    depth, height, width = volume.shape[1:4]
    heatmap = tf.image.resize(heatmap, (height, width))  # (d, H, W, 1)
    heatmap = tf.transpose(heatmap, (1, 0, 2, 3))  # (H, d, W, 1)
    heatmap = tf.image.resize(heatmap, (depth, width))  # (H, D, W, 1)
    heatmap = tf.transpose(heatmap, (1, 0, 2, 3))  # (D, H, W, 1)
    return heatmap.numpy()


def display_heatmap(
    volume: np.ndarray,
    heatmap: np.ndarray,
    alpha: float = 0.5,
    slice_axis: int = -1,
    slice_index: int = None,
):
    """
    Display heatmap overlay on one slice of an MRI volume

    Args:
        volume: MRI volume, (X, Y, Z) or with a trailing channel (X, Y, Z, 1)
        heatmap: Activation heatmap on the same voxel grid as ``volume``
        alpha: Heatmap opacity
        slice_axis: Spatial axis to slice along; negative values count from
            the last spatial axis. NOTE(review): for RAS-oriented volumes,
            axis 0 yields sagittal, 1 coronal and 2 axial slices — confirm
            for template-registered data.
        slice_index: Slice index to display (default: middle slice)

    Raises:
        ValueError: If ``slice_axis`` is not an axis of the volume.
    """
    if volume.ndim == 4:
        volume = volume[..., 0]
    if heatmap.ndim == 4:
        heatmap = heatmap[..., 0]

    # Resolve negative axes once so the default index and the slice itself
    # refer to the same axis (e.g. -2 -> 1).
    if not -volume.ndim <= slice_axis < volume.ndim:
        raise ValueError(
            f"slice_axis {slice_axis} is out of range for a "
            f"{volume.ndim}-D volume"
        )
    axis = slice_axis % volume.ndim

    # Default to middle slice
    if slice_index is None:
        slice_index = volume.shape[axis] // 2

    # Extract slice
    v_slice = np.take(volume, slice_index, axis=axis)
    h_slice = np.take(heatmap, slice_index, axis=axis)

    # Create visualization
    plt.figure(figsize=(6, 6))
    plt.imshow(v_slice, cmap='gray', interpolation='none')
    plt.imshow(h_slice, cmap='jet', alpha=alpha, interpolation='none')
    plt.title(f"Grad-CAM Overlay (axis={slice_axis}, index={slice_index})")
    plt.axis('off')
    plt.tight_layout()
    plt.show()

 ## Generate and Visualize Heatmaps

In [None]:
# Identify last convolutional layer
names_conv3d = [
    layer.name
    for layer in model.layers
    if isinstance(layer, tf.keras.layers.Conv3D)
]
print("Convolutional layers:", names_conv3d)

# Generate heatmap for validation sample
for sample in dataset_valid.take(1):
    volume, label = sample
    volume = tf.tile(
        volume[:1], [2, 1, 1, 1, 1]
    )  # Duplicate for batch processing
    heatmap = make_gradcam_heatmap(
        volume, model, last_conv_layer_name=names_conv3d[-1]
    )

    # display_heatmap slices its inputs itself, so give it the whole
    # unbatched volume (not a pre-cut slice) and a heatmap on the same grid.
    brain = volume[0].numpy()  # (D, H, W, 1)
    cam = heatmap[..., 0]
    if cam.shape != brain.shape[:3]:
        # Linearly upsample a coarser heatmap onto the volume grid
        cam = zoom(
            cam, [t / s for t, s in zip(brain.shape[:3], cam.shape)], order=1
        )

    # Display the middle slice along the depth axis with the overlay
    idx_slice = volume.shape[1] // 2
    display_heatmap(brain, cam, slice_axis=0, slice_index=idx_slice)

 ## Interactive 3D Heatmap Visualization

 Requires Plotly

In [None]:
def render_heatmap_3d(
    volume: np.ndarray,
    heatmap: np.ndarray,
    threshold: float = 0.6,
    downsample_factor: int = 8,
):
    """
    Render 3D heatmap on top of the brain surface using Plotly

    Args:
        volume: MRI volume: (B, D, H, W, C) (first item and channel used),
            (D, H, W, C) (first channel used) or (D, H, W)
        heatmap: Activation heatmap, (d, h, w) or (d, h, w, 1); linearly
            resampled onto the volume grid when the shapes differ
        threshold: Minimum activation value to display
        downsample_factor: Keep every n-th activation point
    """
    # Reduce inputs to plain 3D arrays
    if volume.ndim == 5:
        volume = volume[0, :, :, :, 0]
    elif volume.ndim == 4:
        volume = volume[..., 0]
    if heatmap.ndim == 4:
        heatmap = heatmap[..., 0]

    # Put the heatmap on the volume's voxel grid so both share coordinates
    if heatmap.shape != volume.shape:
        factors = [t / s for t, s in zip(volume.shape, heatmap.shape)]
        heatmap = zoom(heatmap, factors, order=1)

    # Rotate both arrays identically so activations line up with the brain
    volume = np.rot90(volume, 1, axes=(1, 2))
    heatmap = np.rot90(heatmap, 1, axes=(1, 2))

    # Create brain surface plot (every 60th surface voxel)
    coords = downsample(get_surface_coords(volume > 0), 60)
    surface = go.Scatter3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        mode='markers',
        marker=dict(size=2, color='gray'),
        name='Brain',
        hoverinfo='skip',
    )

    # Voxels at or above the activation threshold
    coords = np.argwhere(heatmap >= threshold)
    if downsample_factor > 1:
        coords = downsample(coords, downsample_factor)
    values = heatmap[tuple(coords.T)]

    # Create heatmap plot
    heatmaps = go.Scatter3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        mode='markers',
        marker=dict(size=2, color=values, colorscale='jet', opacity=0.5),
        name='Activation',
        hoverinfo='skip',
    )

    # Configure layout
    layout = go.Layout(
        title='3D Heatmap (Grad-CAM)',
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
        ),
        margin=dict(l=0, r=0, b=0, t=40),
    )

    fig = go.Figure(data=[heatmaps, surface], layout=layout)
    fig.show()


# Show the Grad-CAM activations from the cell above in 3D
render_heatmap_3d(volume=volume.numpy(), heatmap=heatmap)

 # Training Metrics Visualization

 Load and display training logs

In [None]:
from tensorboard.backend.event_processing.event_accumulator import (
    EventAccumulator,
)


def display_logs(logdir: str = 'runs'):
    """
    Display training metrics from TensorBoard logs

    Looks in ``logdir/train``, ``logdir/validation`` and ``logdir`` itself.
    For each directory it reads the most recently modified event file and
    plots the scalar summaries plus the single-value tensor summaries whose
    tag starts with 'epoch', all on one figure.

    Args:
        logdir: Run directory containing log files
    """

    def extract_scalar_from_tensor(tensor_proto):
        # Single-element tensors become floats; anything else is ignored.
        t = tf.make_ndarray(tensor_proto)
        return float(t.item()) if t.size == 1 else None

    def extract_events(path):
        files = glob(osp.join(path, '*tfevents*'))
        if not files:
            return None
        # glob() order is filesystem-dependent; read the newest file.
        event = EventAccumulator(max(files, key=osp.getmtime))
        event.Reload()
        events = {}

        # Process scalar events
        for tag in event.Tags().get('scalars', []):
            events[tag] = [(x.step, x.value) for x in event.Scalars(tag)]

        # Process tensor events
        for tag in event.Tags().get('tensors', []):
            if not tag.startswith('epoch'):
                continue
            values = []
            for e in event.Tensors(tag):
                val = extract_scalar_from_tensor(e.tensor_proto)
                if val is not None:
                    values.append((e.step, val))
            if values:
                events[tag] = values
        return events

    # Process all subdirectories
    for subdir in ['train', 'validation', '.']:
        path = osp.join(logdir, subdir)
        if not osp.isdir(path):
            continue

        # Drop empty series: zip(*[]) cannot be unpacked into two sequences
        logs = {tag: pts for tag, pts in (extract_events(path) or {}).items() if pts}
        if not logs:
            continue

        print(f"\n{subdir.upper()} LOGS")
        plt.figure(figsize=(10, 6))

        # Plot each metric
        for tag, points in logs.items():
            steps, values = zip(*points)
            plt.plot(steps, values, label=tag)

        plt.title(f"{subdir.title()} Metrics")
        plt.xlabel('Epoch')
        plt.ylabel('Value')
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.show()


# Plot the metrics TensorBoard recorded for this run
display_logs(logdir=path_logs)

# Miscellaneous

Print the installed packages behind all loaded Python modules, with their versions and dependencies (for environment pinning)

In [None]:
import sys
import pkg_resources
import subprocess
from collections import defaultdict


# Get a list of all imported modules
imported_modules = {
    name: module
    for name, module in sys.modules.items()
    if module and getattr(module, '__file__', None)
}

# Get the version and dependencies of each imported module
package_info = defaultdict(dict)

for name in imported_modules:
    try:
        # Skip non-package modules
        if name.startswith('_') or name in sys.builtin_module_names:
            continue

        dist = pkg_resources.get_distribution(name)
        version = dist.version
        package_info[name]['version'] = version

        # Get dependencies using pip
        result = subprocess.run(
            ['pip', 'show', name], capture_output=True, text=True
        )
        dependencies = []
        for line in result.stdout.split('\n'):
            if line.startswith('Requires: '):
                dependencies = [
                    d.strip()
                    for d in line.split('Requires: ')[1].split(',')
                    if d.strip()
                ]
                break

        package_info[name]['dependencies'] = dependencies
    except (pkg_resources.DistributionNotFound, subprocess.CalledProcessError):
        pass

# Print the package information
for name, info in package_info.items():
    print(f"Package: {name}")
    print(f"Version: {info.get('version', 'N/A')}")
    print(f"Dependencies: {', '.join(info.get('dependencies', []))}")
    print()