# Install and Import Required Libraries


In [None]:
%pip install tensorflow nibabel dicom2nifti nilearn matplotlib numpy antspyx

In [None]:
import os
import glob
import time
import random
import shutil
import subprocess
import dicom2nifti
import ants
import nibabel as nib
import numpy as np
import nilearn.plotting as plotting
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET
from sklearn.model_selection import train_test_split
from scipy.ndimage import zoom

# Utility functions


In [None]:
def get_nii_files(base_dir, prefix=None):
    """Return the paths of every ``.nii.gz`` file under ``base_dir``.

    The tree is walked recursively with ``os.walk`` and paths come back in
    that walk order. When ``prefix`` is truthy, only files whose *basename*
    starts with it are kept; ``None`` or an empty string keeps every NIfTI.
    """
    return [
        os.path.join(dirpath, name)
        for dirpath, _subdirs, names in os.walk(base_dir)
        for name in names
        if name.endswith(".nii.gz") and (not prefix or name.startswith(prefix))
    ]


def display_nii_stats(base_dir):
    """Display a table of NIfTI file counts grouped by filename prefix.

    The prefix is the part of each basename before the first underscore
    (e.g. ``"ss"`` for ``ss_scan.nii.gz``; the whole basename if there is no
    underscore). Rows are sorted by count, descending, and a final
    ``"Total"`` row sums the counts. Uses IPython's ``display``, so it is
    meant to run inside a notebook. Prints a message and returns early when
    no NIfTI files are found.
    """
    from collections import Counter

    nii_files = get_nii_files(base_dir)
    if not nii_files:
        print(f"No NIfTI files found in {base_dir}.")
        return

    prefix_counts = Counter(os.path.basename(f).split("_")[0] for f in nii_files)

    # Display prefix counts in a pandas grid
    prefix_df = pd.DataFrame(
        {"Prefix": list(prefix_counts.keys()), "Count": list(prefix_counts.values())}
    ).sort_values(by="Count", ascending=False)
    prefix_df.loc["Total"] = prefix_df.sum(numeric_only=True)
    print("\nGrid Display for Prefix Counts:")
    display(prefix_df)


def _frequency_table(values, column, total):
    """Return a DataFrame of distinct ``values``, their frequency and percentage of ``total``, most frequent first."""
    from collections import Counter

    counts = Counter(values)
    table = pd.DataFrame(
        {column: list(counts.keys()), "Frequency": list(counts.values())}
    )
    table["Percentage"] = (table["Frequency"] / total * 100).round(1)
    return table.sort_values(by="Frequency", ascending=False)


def display_comprehensive_stats(base_dir, prefix=""):
    """Display comprehensive statistics about all NIfTI files starting with the given prefix.

    First shows the prefix-count table for all of ``base_dir``
    (``display_nii_stats``). It then plots the first matching file and shows
    three frequency tables over the matching files: image shape, voxel size
    (``header.get_zooms()``) and axis orientation codes
    (``nib.aff2axcodes``). Unreadable files are reported and excluded;
    percentages are relative to the files that were read, so each table sums
    to ~100%. An empty ``prefix`` matches every file. Uses IPython's
    ``display``, so it is meant to run inside a notebook.
    """
    display_nii_stats(base_dir)

    print(f"Analysing files with prefix '{prefix}'")
    nii_files = get_nii_files(base_dir, prefix=prefix)

    if not nii_files:
        if prefix:
            print(f"No files found with prefix '{prefix}'.")
        else:
            print("No files found.")
        return

    # Plot the first NIfTI file as a quick visual sanity check
    first_img = nib.load(nii_files[0])
    plotting.plot_anat(first_img, title=f"Displaying: {nii_files[0]}")
    plotting.show()

    dimensions = []
    voxel_sizes = []
    orientations = []

    for file in nii_files:
        try:
            img = nib.load(file)
            shape = img.shape
            zooms = img.header.get_zooms()
            axcodes = nib.aff2axcodes(img.affine)
        except Exception as e:
            print(f"Error reading {file}: {e}")
            continue
        # Append only after all three reads succeed, so the lists stay aligned.
        dimensions.append(shape)
        voxel_sizes.append(zooms)
        orientations.append(axcodes)

    n_read = len(dimensions)
    n_failed = len(nii_files) - n_read
    if n_failed:
        print(f"{n_failed} of {len(nii_files)} files could not be read and are excluded.")
    if n_read == 0:
        return

    # Display in grids
    for title, values, column in (
        ("Dimensions", dimensions, "Dimension"),
        ("Voxel Sizes", voxel_sizes, "Voxel Size"),
        ("Orientations", orientations, "Orientation"),
    ):
        print(f"\nGrid Display for {title}:")
        display(_frequency_table(values, column, n_read))

# Convert DICOM to NIfTI


In [None]:
def convert_dicom_to_nifti(base_dir):
    """Convert every folder of DICOM files under ``base_dir`` to NIfTI, in place.

    Each directory that directly contains ``.dcm`` files is converted with
    ``dicom2nifti.convert_directory`` (gzip-compressed, reoriented), writing
    the ``.nii.gz`` output into that same directory. dicom2nifti chooses the
    output filenames itself, so success is detected by looking for NIfTI
    files that appeared after the call, not for a fixed name.

    Only the ``.dcm`` files are deleted, and only once at least one new NIfTI
    file exists; other files in the folder are never touched. A folder that
    already holds a ``.nii.gz`` is treated as converted and skipped.
    """
    # Pre-scan directories to convert
    dirs_to_convert = []
    for root, _, files in os.walk(base_dir):
        if any(f.endswith(".dcm") for f in files):
            dirs_to_convert.append((root, files))

    # Process directories with a progress bar
    for root, files in tqdm(
        dirs_to_convert, desc="Converting DICOM to NIfTI", unit="folder"
    ):
        if any(f.endswith(".nii.gz") for f in files):
            # A previous run already produced output here; avoid redundant conversion.
            continue

        try:
            dicom2nifti.convert_directory(root, root, compression=True, reorient=True)
        except Exception as e:
            print(f"\nFailed to convert {root}: {e}")
            continue

        produced = sorted(f for f in os.listdir(root) if f.endswith(".nii.gz"))
        if not produced:
            # The converter may skip a series without raising; keep the DICOMs
            # so no data is lost.
            print(f"\nNo NIfTI produced for {root}; keeping DICOM files.")
            continue

        print(f"\nConverted: {root} -> {', '.join(produced)}")

        # Remove only the DICOM files after a verified conversion
        for file in files:
            if file.endswith(".dcm"):
                os.remove(os.path.join(root, file))


# Convert the raw ADNI download in place; each folder's DICOM files are deleted
# after a successful conversion (see convert_dicom_to_nifti).
convert_dicom_to_nifti("./data/adni-2-4")

# Split and Name


## rename to image id


In [None]:
# Flatten the per-scan folders: move every NIfTI found in a sub-folder of
# base_dir up into base_dir, renamed after the folder that contained it
# (presumably the image ID, e.g. "I44926").
# Change these paths as needed
base_dir = "./data/adni-2-4"
base_dir_abs = os.path.abspath(base_dir)

for root, dirs, files in os.walk(base_dir):
    if os.path.abspath(root) == base_dir_abs:
        # Files already at the top level are flattened output (e.g. from a
        # previous run); renaming them would collapse them all into
        # "adni-2-4.nii.gz".
        continue
    for file in files:
        if not file.endswith(".nii.gz"):
            continue
        # The immediate directory of the file
        parent_dir = os.path.basename(root)
        source_path = os.path.join(root, file)
        target_path = os.path.join(base_dir, f"{parent_dir}.nii.gz")

        # Two NIfTIs in one folder (or a repeated folder name) map to the same
        # target; never silently overwrite an earlier file.
        if os.path.exists(target_path):
            print(f"Skipping {source_path}: {target_path} already exists")
            continue

        print(f"Moving {source_path} to {target_path}")
        shutil.move(source_path, target_path)

## prepend subject id


In [None]:
def prepend_subject_id_single_dir(data_dir, xml_dir):
    """Prefix each ``<scan_id>.nii.gz`` in ``data_dir`` with its ADNI subject ID.

    The scan->subject mapping is built from the metadata XML *filenames* in
    ``xml_dir`` (file contents are not read). For a name like
    ``ADNI_013_S_0575_MPRAGE_S28210_I44926.xml`` the scan ID is the last
    underscore token without extension (``I44926``) and the subject ID is
    tokens 1-3 (``013_S_0575``). Only names with at least six tokens whose
    last token starts with ``"I"`` contribute.

    Each top-level ``*.nii.gz`` whose name (up to the first dot) is a known
    scan ID is renamed, e.g. ``I44926.nii.gz`` -> ``013_S_0575_I44926.nii.gz``.
    Files without a match are reported and left as they are.
    """
    print("Creating scan-to-subject mapping...")
    scan_to_subject = {}
    for xml_name in os.listdir(xml_dir):
        if not xml_name.endswith(".xml"):
            continue
        tokens = xml_name.split("_")
        if len(tokens) < 6 or not tokens[-1].startswith("I"):
            continue
        scan_to_subject[tokens[-1].split(".")[0]] = "_".join(tokens[1:4])

    print(f"Found {len(scan_to_subject)} scan-to-subject mappings")

    for nii_path in glob.glob(os.path.join(data_dir, "*.nii.gz")):
        nii_name = os.path.basename(nii_path)
        scan_id = nii_name.split(".")[0]
        subject_id = scan_to_subject.get(scan_id)
        if subject_id is None:
            print(f"Could not find subject ID for scan {scan_id}")
            continue
        renamed = f"{subject_id}_{nii_name}"
        print(f"Renaming {nii_name} to {renamed}")
        os.rename(nii_path, os.path.join(data_dir, renamed))

    print("File renaming complete!")

# Example usage: prefix each flattened "<scan_id>.nii.gz" with its subject ID,
# looked up from the ADNI metadata XML filenames.
prepend_subject_id_single_dir("./data/adni-2-4", "./data/metadata")

## Split by research group


In [None]:
def get_research_group(xml_path):
    """Return the text of the first ``<researchGroup>`` element in ``xml_path``.

    The root element and all of its descendants are searched in document
    order. Returns ``None`` if there is no such element or if the file cannot
    be read or parsed (the error is printed, not raised).
    """
    try:
        document_root = ET.parse(xml_path).getroot()
        first_match = next(document_root.iter("researchGroup"), None)
    except Exception as e:
        print(f"Error parsing {xml_path}: {e}")
        return None
    return None if first_match is None else first_match.text


def restructure_dataset(data_dir, metadata_dir, output_dir):
    """Copy each NIfTI in ``data_dir`` into ``output_dir/AD`` or ``output_dir/CN``.

    Files are expected to be named ``<subjectID>_<imageID>.nii.gz`` (or just
    ``<imageID>.nii.gz``). The image ID is the last underscore token with the
    extension removed. The scan's metadata is the XML in ``metadata_dir``
    whose last underscore token is exactly that image ID
    (e.g. ``ADNI_013_S_0575_MPRAGE_S28210_I44926.xml`` for ``I44926``). Its
    ``researchGroup`` decides the destination. Scans with another research
    group, or with no metadata, are reported and not copied.
    """
    # Create output subdirectories for AD and CN
    for group in ["AD", "CN"]:
        os.makedirs(os.path.join(output_dir, group), exist_ok=True)

    # Index the metadata once instead of re-listing the directory per scan.
    # Keying on the exact last token (rather than a filename suffix match)
    # stops an ID from matching a longer ID that merely ends with it.
    # Sorted so that duplicate IDs resolve deterministically.
    xml_by_image_id = {}
    for candidate in sorted(os.listdir(metadata_dir)):
        if candidate.endswith(".xml"):
            image_key = candidate[: -len(".xml")].split("_")[-1]
            xml_by_image_id.setdefault(image_key, candidate)

    # Process each nii.gz file in the data_dir
    print("Processing nii.gz files...")
    for file_path in glob.glob(os.path.join(data_dir, "*.nii.gz")):
        file_name = os.path.basename(file_path)
        image_id = file_name.split("_")[-1].split(".")[0]

        xml_file = xml_by_image_id.get(image_id)
        if xml_file is None:
            print(f"No XML found for {file_name}")
            continue

        group = get_research_group(os.path.join(metadata_dir, xml_file))
        if group in {"AD", "CN"}:
            print(f"Copying {file_name} to {group} folder")
            shutil.copy(file_path, os.path.join(output_dir, group, file_name))
        else:
            print(f"Research group for {file_name} is invalid: {group}")

    print("Restructuring complete!")


# Example usage: copy the AD and CN scans into per-group folders under
# ./data/adni-2-4-cond, according to each scan's metadata researchGroup.
restructure_dataset("./data/adni-2-4", "./data/metadata", "./data/adni-2-4-cond")

## Filter files by subject spread

I used this to pick a fixed number of files from the CN folder: to make the split 50/50 I needed only 107 CN files, and I wanted them to come from as wide a range of subjects as possible.


In [None]:
def filter_files_wide_spread(data_dir, output_dir, count=100, seed=None):
    """Copy up to ``count`` NIfTI files to ``output_dir``, spread over as many subjects as possible.

    Top-level ``*.nii.gz`` files in ``data_dir`` are grouped by subject ID,
    which is everything before the last underscore of the filename
    (e.g. ``013_S_0575`` for ``013_S_0575_I44926.nii.gz``). Files without
    an underscore are ignored.

    * With at least ``count`` subjects, ``count`` subjects are drawn at random
      and one random file is taken from each.
    * Otherwise one random file is taken from every subject. The remaining
      slots are then filled round-robin (in subject order) from each subject's
      other files, never repeating a file. If the data runs out, fewer than
      ``count`` files are copied.

    Args:
        data_dir: Folder containing the ``*.nii.gz`` files.
        output_dir: Destination folder (created if missing).
        count: Number of files to select.
        seed: Optional seed for a private ``random.Random`` so the selection
            is reproducible; ``None`` uses the global ``random`` state.
    """
    rng = random.Random(seed) if seed is not None else random

    # Ensure output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Group files by subject id. Sorted so that, with a fixed seed, the result
    # does not depend on the order the filesystem lists the directory.
    subject_files = {}
    for file_path in sorted(glob.glob(os.path.join(data_dir, "*.nii.gz"))):
        subject_id, sep, _ = os.path.basename(file_path).rpartition("_")
        if not sep:
            continue  # Skip files that do not follow the expected naming format.
        subject_files.setdefault(subject_id, []).append(file_path)

    subjects = list(subject_files)

    if len(subjects) >= count:
        # Case 1: enough distinct subjects -> one random file from each of
        # `count` randomly chosen subjects.
        chosen_subjects = rng.sample(subjects, count)
        selected_files = [rng.choice(subject_files[s]) for s in chosen_subjects]
    else:
        # Case 2: one file per subject first, remembering each subject's pick.
        first_pick = {s: rng.choice(subject_files[s]) for s in subjects}
        selected_files = list(first_pick.values())

        # Each subject's other candidates must exclude that subject's OWN
        # first pick; excluding any other file lets the pick be chosen twice.
        extra_files = [
            [f for f in subject_files[s] if f != first_pick[s]] for s in subjects
        ]

        # Round-robin over subjects until `count` is reached or files run out.
        remaining = count - len(selected_files)
        while remaining > 0 and any(extra_files):
            for files_list in extra_files:
                if remaining == 0:
                    break
                if files_list:
                    selected_files.append(files_list.pop(0))
                    remaining -= 1

    # Copy the selected files to the output directory
    for file_path in selected_files:
        file_name = os.path.basename(file_path)
        dest_path = os.path.join(output_dir, file_name)
        print(f"Copying {file_name} to {output_dir}")
        shutil.copy(file_path, dest_path)

    print("File filtering complete!")


# Example usage: pick 107 CN scans, spread over as many subjects as possible,
# so that the AD/CN split is 50/50.
filter_files_wide_spread(
    "./data/adni-2-4-cond/CN", "./data/adni-2-4-cond/CN-filtered", count=107
)

# Skull Stripping


In [None]:
def run_synthstrip(freesurfer_home, input_path, ss_output_path):
    """Skull-strip one NIfTI file with FreeSurfer's ``mri_synthstrip``.

    Skips the file if ``ss_output_path`` already exists. The FreeSurfer
    environment is prepared by sourcing ``<freesurfer_home>/SetUpFreeSurfer.sh``
    in a bash subshell. On success the original ``input_path`` is DELETED.
    A failing command, or one that exits cleanly without writing the output,
    is reported and leaves the input untouched. Nothing is returned.
    """
    import shlex

    if os.path.exists(ss_output_path):  # Avoid redundant processing
        print(f"Skipping {input_path}, output already exists.")
        return

    try:
        env = os.environ.copy()
        env["FREESURFER_HOME"] = freesurfer_home
        env["SUBJECTS_DIR"] = os.path.join(freesurfer_home, "subjects")

        # Every path is spliced into a bash command line, so quote it: a space
        # or shell metacharacter would otherwise split or alter the arguments.
        setup_script = shlex.quote(os.path.join(freesurfer_home, "SetUpFreeSurfer.sh"))
        command = [
            "/bin/bash",
            "-c",  # Use bash explicitly (`source` is a bash builtin)
            f"source {setup_script} && "
            f"mri_synthstrip -i {shlex.quote(input_path)} -o {shlex.quote(ss_output_path)}",
        ]

        start_time = time.time()
        subprocess.run(command, check=True, env=env)
        elapsed_time = time.time() - start_time
    except subprocess.CalledProcessError as e:
        print(f"❌ Failed to process {input_path}: {e}")
        return

    if not os.path.exists(ss_output_path):
        # Exit status 0 but no output: keep the original rather than lose it.
        print(f"❌ Failed to process {input_path}: no output written to {ss_output_path}")
        return

    print(
        f"✔ Processed: {input_path} -> {ss_output_path} (Time: {elapsed_time:.2f}s)"
    )

    # Delete the original file only after the output is confirmed on disk
    os.remove(input_path)
    print(f"🗑️ Deleted original file: {input_path}")


def skull_strip_nifti(base_dir, freesurfer_home="/Applications/freesurfer/7.4.1"):
    """Skull-strip every NIfTI under ``base_dir``, one file at a time.

    Every ``*.nii.gz`` whose name does not already start with ``ss_`` is
    handed to ``run_synthstrip``, which writes ``ss_<name>`` into the same
    folder. An exception for one file is printed and the batch continues.
    """
    pending = [
        (freesurfer_home, os.path.join(folder, name), os.path.join(folder, "ss_" + name))
        for folder, _subdirs, names in os.walk(base_dir)
        for name in names
        if name.endswith(".nii.gz") and not name.startswith("ss_")
    ]

    n_pending = len(pending)
    if not n_pending:
        print("✅ No new NIfTI files to process.")
        return

    print(f"🔍 Found {n_pending} files to process.")
    batch_start = time.time()

    for position, (fs_home, source, target) in enumerate(pending, start=1):
        print(f"[{position}/{n_pending}] Processing: {source}")
        try:
            run_synthstrip(fs_home, source, target)
        except Exception as e:
            print(f"⚠️ Error processing {source}: {e}")

    print(f"✅ Finished processing all files in {time.time() - batch_start:.2f}s.")


# Skull-strip everything in the AD/CN dataset; each successfully processed
# input is replaced by an "ss_"-prefixed output in the same folder.
skull_strip_nifti("./data/adni-2-4-cond")

## Visualise the Stripped Scans


In [None]:
# NOTE(review): DATA is not defined in this notebook as shown — presumably the
# skull-stripped dataset folder; confirm it is set in an earlier cell.
ss_files = get_nii_files(DATA, "ss_")

# Display the first few NIfTI file paths
print("First few NIfTI files:")
for file in ss_files[:5]:
    print(file)

# Spot-check every 50th skull-stripped scan (not all of them).
# cut_coords=(0, 0, 0) places the cuts at world coordinate (0, 0, 0) mm.
for scan in ss_files[::50]:
    img = nib.load(scan)
    plotting.plot_anat(
        img,
        title=f"Anatomical View: {scan}",
        annotate=False,
        draw_cross=False,
        cut_coords=(0, 0, 0),
    )
    plt.show()

# Orientation Standardisation


In [None]:
def check_las_orientation(file_list):
    """
    Report how many NIfTI files in ``file_list`` are oriented LAS+.

    Orientation comes from each image's affine via ``nib.aff2axcodes``; a
    file counts as LAS+ when that yields ``("L", "A", "S")``. Unreadable files
    are reported and skipped, but still count towards the total shown.

    Args:
        file_list (list): Paths to ``.nii.gz`` files.

    Returns:
        None: results are printed.
    """
    las_target = ("L", "A", "S")
    matching = 0
    mismatched = []

    for path in file_list:
        try:
            axcodes = nib.aff2axcodes(nib.load(path).affine)
        except Exception as e:
            print(f"Error reading {path}: {e}")
            continue
        if axcodes == las_target:
            matching += 1
        else:
            mismatched.append((path, axcodes))

    print(f"\nTotal LAS+ files: {matching}/{len(file_list)}")

    if not mismatched:
        print("All files are already in LAS+ orientation.")
        return

    print("\nFiles not in LAS+ orientation:")
    for path, axcodes in mismatched:
        print(f"{path}: {axcodes}")


# NOTE(review): DATA is not defined in this notebook as shown; confirm it is
# set in an earlier cell.
nii_files = get_nii_files(DATA)

check_las_orientation(nii_files)

Reorientation isn't needed here because it is applied during the augmentation step.


# Crop and Reshape


In [None]:
def safe_load_nifti(file_path):
    """
    Load a NIfTI file and return ``(data, affine, header)``.

    ``data`` is the floating-point array from ``get_fdata()``. There is no
    fallback loader: on any error the message is printed and
    ``(None, None, None)`` is returned instead of raising.
    """
    try:
        nifti = nib.load(file_path)
        voxels = nifti.get_fdata()
    except Exception as e:
        print(f"Failed to load {file_path}: {e}")
        return None, None, None
    return voxels, nifti.affine, nifti.header


def crop_brain_from_mri(img_data, padding=3):
    """
    Crop out empty space around the brain in 3D MRI scans.

    Voxels brighter than 10% of the volume's mean intensity are treated as
    tissue. The crop is the tight bounding box of those voxels, grown by
    ``padding`` voxels on every side and clipped to the volume.

    Returns:
        (cropped_img, crop_coords): ``crop_coords`` holds one ``(start, stop)``
        pair per axis, usable directly as slice bounds (``stop`` exclusive).
        If no voxel passes the threshold, ``img_data`` itself is returned with
        full-extent bounds.
    """
    # Use a low threshold to capture brain tissue while excluding noise
    mask = img_data > np.mean(img_data) * 0.1
    coords = np.argwhere(mask)

    # If no significant tissue is found, return original image
    if len(coords) == 0:
        return img_data, (
            (0, img_data.shape[0]),
            (0, img_data.shape[1]),
            (0, img_data.shape[2]),
        )

    mins = coords.min(axis=0)
    maxs = coords.max(axis=0)
    # argwhere gives inclusive indices but slice stops are exclusive, so the
    # stop is max + 1 (+ padding). Without the +1 the last tissue plane on each
    # axis would be dropped and the padding would be asymmetric.
    cropped_mins = [max(0, int(m) - padding) for m in mins]
    cropped_maxs = [
        min(img_data.shape[i], int(m) + padding + 1) for i, m in enumerate(maxs)
    ]

    cropped_img = img_data[
        cropped_mins[0] : cropped_maxs[0],
        cropped_mins[1] : cropped_maxs[1],
        cropped_mins[2] : cropped_maxs[2],
    ]

    crop_coords = (
        (cropped_mins[0], cropped_maxs[0]),
        (cropped_mins[1], cropped_maxs[1]),
        (cropped_mins[2], cropped_maxs[2]),
    )
    return cropped_img, crop_coords


def preprocess_crop_and_reshape_mri(file_path, target_shape, padding=3):
    """
    Load, crop, and reshape (by interpolation) an MRI scan to a target shape.

    Parameters:
    -----------
    file_path : str
        Path to the MRI scan file (.nii.gz).
    target_shape : sequence of ints
        Desired output shape after cropping and reshaping. A tuple or a list
        is accepted; it is normalised to a tuple before comparison.
    padding : int, optional
        Additional padding around the brain region (default: 3 voxels).

    Returns:
    --------
    final_img : numpy.ndarray
        Processed image data after cropping and reshaping (cubic spline,
        ``scipy.ndimage.zoom`` with ``order=3``, applied only when the cropped
        shape differs from ``target_shape``).
    crop_coords : tuple
        ``(start, stop)`` voxel bounds of the crop on each axis, as returned
        by ``crop_brain_from_mri``.
    affine : numpy.ndarray
        The affine of the ORIGINAL scan. It is not adjusted for the crop
        offset or the resize, so it does not describe ``final_img``'s grid.

    All three values are ``None`` if the file cannot be loaded.
    """
    # Load the image (the header is not needed here)
    img_data, affine, _header = safe_load_nifti(file_path)
    if img_data is None:
        return None, None, None

    # Crop the brain
    cropped_img, crop_coords = crop_brain_from_mri(img_data, padding=padding)

    # Compare as a tuple: a list would never equal ndarray.shape and would
    # force a needless interpolation pass even when the shapes already match.
    target_shape = tuple(target_shape)
    if cropped_img.shape != target_shape:
        zoom_factors = [t / s for t, s in zip(target_shape, cropped_img.shape)]
        final_img = zoom(cropped_img, zoom_factors, order=3)
    else:
        final_img = cropped_img

    return final_img, crop_coords, affine


# Batch driver: crop and reshape every NIfTI under input_dir into output_dir.
def batch_preprocess_mri_dataset(input_dir, output_dir, target_shape, padding=3):
    """
    Crop and reshape every ``.nii.gz`` under ``input_dir`` and save the results.

    The sub-folder layout of ``input_dir`` is mirrored under ``output_dir``,
    and each output keeps its input filename. Every volume goes through
    ``preprocess_crop_and_reshape_mri`` and is saved with the affine that
    function returns. Per-file failures are collected rather than raised.

    NOTE(review): that affine is the original scan's, which no longer matches
    the cropped/resized voxel grid. This is fine if only the voxel array is
    used downstream, but world coordinates in the saved files are not
    meaningful.

    Returns:
        (processed_files, failed_files): lists of basenames.
    """
    os.makedirs(output_dir, exist_ok=True)
    processed_files, failed_files = [], []

    for folder, _subdirs, names in os.walk(input_dir):
        for name in names:
            if not name.endswith(".nii.gz"):
                continue
            destination_dir = os.path.join(output_dir, os.path.relpath(folder, input_dir))
            os.makedirs(destination_dir, exist_ok=True)
            source = os.path.join(folder, name)
            destination = os.path.join(destination_dir, name)

            try:
                volume, _crop_coords, affine = preprocess_crop_and_reshape_mri(
                    source, target_shape, padding
                )
                if volume is None:
                    print(f"Skipping {name} due to loading error.")
                    failed_files.append(name)
                    continue
                nib.save(nib.Nifti1Image(volume, affine), destination)
                processed_files.append(name)
                print(f"Processed: {name}")
            except Exception as e:
                print(f"Failed to process {name}: {e}")
                failed_files.append(name)

    print("\nProcessing Summary:")
    print(f"Total files processed: {len(processed_files)}")
    print(f"Total files failed: {len(failed_files)}")
    if failed_files:
        print("Failed files:")
        for name in failed_files:
            print(name)
    return processed_files, failed_files

# Spatial Normalisation


## Reasons I do NOT Need Spatial Normalization

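1. ADNI T1-weighted data is acquired under a standardised protocol
   - ADNI follows a standardised acquisition protocol, so voxel sizes and orientations are broadly consistent across participants (the remaining voxel-size differences are handled by the resampling step later in this notebook).
   - Scanner-setting variation is smallest within a single ADNI phase (e.g. ADNI-1 only, without ADNI-2 or ADNI-3); in that case native-space alignment is more likely to be sufficient, while combining phases introduces more variation.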
2. CNNs Learn Spatial Features
   - If you're using a deep learning model (e.g., CNN), it can learn spatial variations on its own.
   - Adding spatial normalization could remove subtle differences in brain shape that might be relevant for classification.
3. Preserving Native Brain Shape
   - Some models benefit from analyzing brain atrophy without forced alignment to MNI space.
   - If you want to measure structural differences in their original form (e.g., hippocampal shrinkage), keeping scans in native space may be better.


# Bias Field Correction


In [None]:
# Get all skull-stripped files
# NOTE(review): DATA is not defined in this notebook as shown; confirm it is
# set in an earlier cell.
nii_files = get_nii_files(DATA, "ss_")

# Apply N4 Bias Field Correction to all files; output is "bc_<name>" next to
# each input. Only the filename is prefixed: str.replace("ss_", ...) on the
# whole path would also rewrite "ss_" inside directory names
# (e.g. ".../class_AD/...") or later in the filename.
for file in nii_files:
    bias_corrected_path = os.path.join(
        os.path.dirname(file), "bc_" + os.path.basename(file)
    )

    if not os.path.exists(bias_corrected_path):
        # Load the skull-stripped image
        input_image = ants.image_read(file)

        # Apply N4 Bias Field Correction
        bias_corrected = ants.n4_bias_field_correction(input_image)

        # Save the bias-corrected image
        ants.image_write(bias_corrected, bias_corrected_path)
        print(f"Bias-corrected image saved to: {bias_corrected_path}")
    else:
        print(f"Skipping {file}, bias-corrected file already exists.")

# Load one bias-corrected image for comparison
if nii_files:
    original_image = nib.load(nii_files[0])
    bias_corrected_image = nib.load(
        os.path.join(
            os.path.dirname(nii_files[0]), "bc_" + os.path.basename(nii_files[0])
        )
    )

    # Plot orthogonal cuts of the original scan at world coordinate (0, 0, 0) mm
    plotting.plot_anat(original_image, title="Original Image", cut_coords=(0, 0, 0))
    plt.show()

    # Same cuts for the bias-corrected scan
    plotting.plot_anat(
        bias_corrected_image, title="Bias-Corrected Image", cut_coords=(0, 0, 0)
    )
    plt.show()

    # Convert to numpy arrays
    original_data = original_image.get_fdata()
    bias_corrected_data = bias_corrected_image.get_fdata()

    # Choose a slice index
    slice_idx = original_data.shape[2] // 2  # Middle voxel slice along axis 2

    # Compute absolute difference
    difference = np.abs(
        original_data[:, :, slice_idx] - bias_corrected_data[:, :, slice_idx]
    )

    # Plot difference heatmap
    plt.figure(figsize=(5, 5))
    plt.imshow(difference, cmap="hot")
    plt.colorbar(label="Intensity Difference")
    plt.title("Difference Map")
    plt.show()
else:
    print("Not enough files for comparison.")

# Voxel Standardisation


## Check for uniform voxel size


In [None]:
def check_uniform_voxel_size(file_list):
    """
    Checks if all NIfTI files in the provided list have uniform voxel size.

    Voxel size is each file's ``header.get_zooms()``. Unreadable files are
    reported and left out of the comparison.

    Args:
        file_list (list): List of file paths to .nii.gz files.

    Returns:
        None: Prints the result of the check. An empty list (or one where no
        file can be read) is reported as such rather than as a mismatch.
    """
    from collections import Counter

    voxel_size_counts = Counter()
    for file in file_list:
        try:
            voxel_size_counts[nib.load(file).header.get_zooms()] += 1
        except Exception as e:
            print(f"Error reading {file}: {e}")

    if not voxel_size_counts:
        print("No readable files to check.")
    elif len(voxel_size_counts) == 1:
        print("All files have uniform voxel size:", next(iter(voxel_size_counts)))
    else:
        print("Files have different voxel sizes:")
        for voxel_size, count in voxel_size_counts.items():
            print(f"Voxel size: {voxel_size}, Count: {count}")


# Check voxel size for all NIfTI files
# (only the bias-corrected "bc_" files; nii_files is reused below by
# test_resample_single_image(nii_files[0]))
nii_files = get_nii_files(DATA, prefix="bc_")
check_uniform_voxel_size(nii_files)

## Resample Voxel Size


In [None]:
def resample(file_path, output_path, target=(1, 1, 1)):
    """Resample a NIfTI file to voxel spacing ``target`` with ANTs and save it.

    ``target`` is a physical spacing in mm (``use_voxels=False``), default
    1x1x1 mm; ANTs' default interpolation is used. The result is written to
    ``output_path``. Errors are printed rather than raised, and nothing is
    returned.
    """
    try:
        ants.image_write(
            ants.resample_image(ants.image_read(file_path), target, use_voxels=False),
            output_path,
        )
    except Exception as e:
        print(f"Error resampling {file_path}: {e}")
    else:
        print(f"Resampled and saved: {output_path}")


def test_resample_single_image(file_path):
    """Resample one bias-corrected image and show it before and after.

    The resampled copy is written next to the input as ``resampled_<name>``
    (1x1x1 mm, via ``resample``). Both voxel sizes are printed and both
    images are plotted. This is a manual notebook check, not a pytest test,
    despite the ``test_`` prefix.
    """
    # Prefix the filename only; str.replace("bc_", ...) on the full path would
    # also rewrite any "bc_" in directory names or later in the filename.
    resampled_file_path = os.path.join(
        os.path.dirname(file_path), "resampled_" + os.path.basename(file_path)
    )
    resample(file_path, resampled_file_path)

    # resample() prints errors instead of raising, so make sure there is
    # something to load before continuing.
    if not os.path.exists(resampled_file_path):
        print(f"No resampled output at {resampled_file_path}; nothing to display.")
        return

    # Load the original and resampled images
    original_img = nib.load(file_path)
    resampled_img = nib.load(resampled_file_path)

    # Display the resolutions
    print(f"Original resolution: {original_img.header.get_zooms()}")
    print(f"Resampled resolution: {resampled_img.header.get_zooms()}")

    # cut_coords=(0, 0, 0) places the orthogonal cuts at world coordinate
    # (0, 0, 0) mm, not at the middle voxel slice.
    plotting.plot_anat(original_img, title="Original Image", cut_coords=(0, 0, 0))
    plt.show()

    plotting.plot_anat(resampled_img, title="Resampled Image", cut_coords=(0, 0, 0))
    plt.show()


# Test the resampling on a single image
# (raises IndexError if no "bc_" files were found above)
test_resample_single_image(nii_files[0])

In [None]:
def resample_all_bc_files(base_dir):
    """Resample all bias-corrected NIfTI files in the directory to 1x1x1 mm voxel size.

    Every ``bc_*.nii.gz`` under ``base_dir`` is resampled with ``resample``
    into ``resampled_<name>`` next to it (e.g. ``bc_ss_x.nii.gz`` ->
    ``resampled_bc_ss_x.nii.gz``). Files whose output already exists are
    skipped. Errors are reported by ``resample`` rather than raised.
    """
    for file_path in get_nii_files(base_dir, prefix="bc_"):
        # Prefix the filename only; str.replace("bc_", ...) on the full path
        # would also rewrite "bc_" inside directory names (e.g. ".../abc_x/").
        resampled_file_path = os.path.join(
            os.path.dirname(file_path), "resampled_" + os.path.basename(file_path)
        )
        if os.path.exists(resampled_file_path):
            print(f"Skipping {file_path}, resampled file already exists.")
        else:
            resample(file_path, resampled_file_path)


# Resample every bias-corrected scan under DATA; existing outputs are skipped.
resample_all_bc_files(DATA)

# Final Preprocessing Clean


In [None]:
# Inspect the resampled files (dimensions, voxel sizes, orientations) before cleaning up.
display_comprehensive_stats(DATA, "resampled")

In [None]:
def remove_non_resampled_files(base_dir):
    """Delete every file under ``base_dir`` whose name lacks the ``resampled_`` prefix.

    The whole tree is walked; directories are kept. The deletion is
    irreversible, and each removed path is printed.
    """
    for folder, _subdirs, names in os.walk(base_dir):
        doomed = (os.path.join(folder, n) for n in names if not n.startswith("resampled_"))
        for path in doomed:
            os.remove(path)
            print(f"Removed: {path}")


# WARNING: irreversibly deletes every file under DATA whose name does not start
# with "resampled_" (e.g. any remaining ss_/bc_ss_ intermediates).
remove_non_resampled_files(DATA)

In [None]:
# Re-check the statistics after cleanup; only "resampled_" files should remain.
display_comprehensive_stats(DATA, "resampled")

In [None]:
def rename_files_as_directory(base_dir):
    """Rename each NIfTI file under ``base_dir`` to ``<immediate folder name>.nii.gz``.

    Intended for a one-scan-per-folder layout, e.g. ``.../I44926/x.nii.gz``
    becomes ``.../I44926/I44926.nii.gz``. Files that are not ``.nii.gz``
    (such as ``.DS_Store``) are left alone. A file is never renamed onto an
    existing one: if a folder holds several NIfTI files, only the first in
    listing order is renamed and the others are reported and kept.
    """
    for root, _, files in os.walk(base_dir):
        immediate_dir = os.path.basename(root)
        new_file_name = f"{immediate_dir}.nii.gz"
        new_file_path = os.path.join(root, new_file_name)
        for file in files:
            if not file.endswith(".nii.gz") or file == new_file_name:
                continue
            file_path = os.path.join(root, file)
            # os.rename silently overwrites on POSIX; refuse to lose a scan.
            if os.path.exists(new_file_path):
                print(f"Skipping {file_path}: {new_file_path} already exists")
                continue
            os.rename(file_path, new_file_path)
            print(f"Renamed: {file_path} -> {new_file_path}")


# Rename each remaining scan after its folder (presumably the image ID, which
# the split step below matches against the metadata XML names).
rename_files_as_directory(DATA)


# Check for duplicate file names
def check_duplicate_file_names(base_dir):
    """Report basenames that occur more than once anywhere under ``base_dir``.

    Only basenames are compared; the folders they live in are ignored. A
    name is printed once for every occurrence after its first.
    """
    seen = set()
    repeats = []
    for _folder, _subdirs, names in os.walk(base_dir):
        for name in names:
            if name in seen:
                repeats.append(name)
            seen.add(name)

    if not repeats:
        print("No duplicate file names found.")
        return

    print("Duplicate file names found:")
    print(*repeats, sep="\n")


# Basenames must be unique: the split step copies files into flat per-class
# folders by basename, so a duplicate would overwrite another scan.
check_duplicate_file_names(DATA)

# Create the Split Dataset


In [None]:
# Create output directories: OUTPUT/<train|val|test>/<AD|CN>.
# NOTE(review): OUTPUT is not defined in this notebook as shown; confirm it is
# set in an earlier cell.
for split in ["train", "val", "test"]:
    for group in ["AD", "CN"]:
        os.makedirs(os.path.join(OUTPUT, split, group), exist_ok=True)


# Function to get research group from XML
def get_research_group(xml_path):
    """Return the first ``researchGroup`` element's text in ``xml_path``, or ``None``.

    This re-declares the helper defined earlier in the notebook with the same
    contract: the root and all descendants are searched in document order,
    and read/parse errors are printed and yield ``None``.
    """
    try:
        tree_root = ET.parse(xml_path).getroot()
        texts = [node.text for node in tree_root.iter("researchGroup")]
    except Exception as e:
        print(f"Error parsing {xml_path}: {e}")
        return None
    return texts[0] if texts else None


# Collect all NIfTI files and their corresponding research groups.
# NOTE(review): DATA and METADATA are not defined in this notebook as shown;
# confirm they are set in an earlier cell.

# Index the metadata directory once (instead of re-listing it for every scan),
# keyed on the exact image ID, i.e. the last "_" token of the XML name:
# "ADNI_013_S_0575_MPRAGE_S28210_I44926.xml" -> "I44926".
xml_by_image_id = {}
for xml_name in sorted(os.listdir(METADATA)):
    if xml_name.endswith(".xml"):
        xml_by_image_id.setdefault(xml_name[: -len(".xml")].split("_")[-1], xml_name)

file_groups = {"AD": [], "CN": []}
# Subject ID of each collected file (tokens 1-3 of the XML name, e.g.
# "013_S_0575"), used below to keep all scans of a subject in one split.
file_subjects = {}
for root, _, files in os.walk(DATA):
    for file in files:
        if not file.endswith(".nii.gz"):
            continue
        # Accept both "<imageID>.nii.gz" and "<subjectID>_<imageID>.nii.gz".
        image_id = file.split(".")[0].split("_")[-1]
        xml_file = xml_by_image_id.get(image_id)
        if xml_file is None:
            continue
        group = get_research_group(os.path.join(METADATA, xml_file))
        if group in file_groups:
            nii_path = os.path.join(root, file)
            file_groups[group].append(nii_path)
            xml_parts = xml_file.split("_")
            # Fall back to the image ID (no grouping) for non-ADNI-style names.
            file_subjects[nii_path] = (
                "_".join(xml_parts[1:4]) if len(xml_parts) >= 6 else image_id
            )

# Split the data into train (80%), validation (10%) and test (10%) sets.
# The split is made over SUBJECTS, not files: scans of the same subject in
# both train and test leak subject identity and inflate validation/test
# scores. The ratios therefore hold for subjects; file counts are approximate.
# NOTE(review): a subject whose researchGroup differs between scans can still
# appear in both classes' splits; check this if diagnosis conversions are kept.
train_files = {"AD": [], "CN": []}
val_files = {"AD": [], "CN": []}
test_files = {"AD": [], "CN": []}

for group in ["AD", "CN"]:
    subjects = sorted({file_subjects[f] for f in file_groups[group]})
    train_subjects, temp_subjects = train_test_split(
        subjects, test_size=0.2, random_state=42
    )
    val_subjects, test_subjects = train_test_split(
        temp_subjects, test_size=0.5, random_state=42
    )
    for split_files, split_subjects in (
        (train_files, set(train_subjects)),
        (val_files, set(val_subjects)),
        (test_files, set(test_subjects)),
    ):
        split_files[group].extend(
            f for f in file_groups[group] if file_subjects[f] in split_subjects
        )


# Function to copy files to the output directory
def copy_files(file_list, split, group, output_dir=None):
    """Copy each path in ``file_list`` into ``<output_dir>/<split>/<group>/``.

    ``output_dir`` defaults to the notebook-level ``OUTPUT``, looked up at
    call time. The destination folder is created if missing. Basenames are
    kept, so an existing file with the same name is overwritten.
    """
    if output_dir is None:
        output_dir = OUTPUT
    dest_dir = os.path.join(output_dir, split, group)
    os.makedirs(dest_dir, exist_ok=True)
    for nii_path in file_list:
        shutil.copy(nii_path, os.path.join(dest_dir, os.path.basename(nii_path)))


# Copy files to the respective directories (OUTPUT/<split>/<group>/)
for group in ["AD", "CN"]:
    copy_files(train_files[group], "train", group)
    copy_files(val_files[group], "val", group)
    copy_files(test_files[group], "test", group)

print("Dataset split and creation completed.")

In [None]:
# Function to print dataset statistics
def print_dataset_statistics(output_dir=None):
    """Print the number of AD and CN scans in each split and in total.

    Counts ``.nii.gz`` files in ``<output_dir>/<split>/<group>``; other
    entries (e.g. macOS ``.DS_Store`` files) are not counted. ``output_dir``
    defaults to the notebook-level ``OUTPUT``, looked up at call time.
    """
    if output_dir is None:
        output_dir = OUTPUT

    def count_scans(split, group):
        folder = os.path.join(output_dir, split, group)
        return sum(1 for name in os.listdir(folder) if name.endswith(".nii.gz"))

    total_ad = 0
    total_cn = 0
    for split in ["train", "val", "test"]:
        ad_count = count_scans(split, "AD")
        cn_count = count_scans(split, "CN")
        total_ad += ad_count
        total_cn += cn_count
        print(f"{split.capitalize()} set: AD={ad_count}, CN={cn_count}")
    print(f"Total: AD={total_ad}, CN={total_cn}")


# Print dataset statistics: per-split AD/CN counts read back from the output folders.
print_dataset_statistics()

## Check Current state before reshaping


In [None]:
# NOTE(review): OUTPUT_CROPPED is not defined in this notebook as shown;
# confirm it is set in an earlier cell.
display_comprehensive_stats(OUTPUT_CROPPED)