In [None]:
import os
import glob
from collections import defaultdict


# Directory layout (paths are relative to the notebook's working directory)
base_dir = "./DATA"
# Scans organised as {train,val,test}/{AD,CN}/*.nii.gz
data_dir = os.path.join(base_dir, "ADNI_CROPPED_128")
# ADNI XML metadata files whose names encode subject and scan IDs
xml_dir = os.path.join(base_dir, "ADNI_METADATA")

I have a directory with three subdirectories: train, val, and test. Each of these has two subdirectories, AD and CN, and each AD and CN directory contains many .nii.gz files named by a scan ID such as "I12345". I also have another directory containing many XML files, each named like "ADNI_013_S_0575_MPRAGE_S28210_I44926". In these names the last part is the scan ID ("I44926"), and the second, third, and fourth parts together form the subject ID ("013_S_0575"). For every scan in the first directory, prepend the corresponding subject ID to the scan's file name. Do this in Python.


In [None]:
def prepend_subject_id(
    data_dir, xml_dir, splits=("train", "val", "test"), conditions=("AD", "CN")
):
    """
    Rename every scan ``<scan_id>.nii.gz`` under ``data_dir`` to
    ``<subject_id>_<scan_id>.nii.gz`` using the ADNI XML metadata filenames.

    The scan-to-subject mapping is parsed from XML filenames such as
    ``ADNI_013_S_0575_MPRAGE_S28210_I44926.xml``: the 2nd-4th
    underscore-separated fields form the subject ID (``013_S_0575``) and the
    last field (without extension) is the scan ID (``I44926``).

    Safe to re-run: files that already carry their subject prefix are skipped
    silently, and a rename whose target name already exists is skipped with a
    message instead of overwriting the existing file.

    Args:
        data_dir: Directory containing ``<split>/<condition>/*.nii.gz``.
        xml_dir: Directory containing the ADNI ``.xml`` metadata files.
        splits: Split sub-directories of ``data_dir`` to process.
        conditions: Condition sub-directories to process inside each split.
    """
    # Create mapping from scan ID to subject ID
    scan_to_subject = {}

    print("Creating scan-to-subject mapping...")
    # Sorted so that the mapping (and any duplicate resolution) is deterministic.
    for xml_file in sorted(os.listdir(xml_dir)):
        if not xml_file.endswith(".xml"):
            continue
        # "ADNI_013_S_0575_MPRAGE_S28210_I44926.xml"
        #   -> subject "013_S_0575", scan "I44926"
        parts = xml_file.split("_")
        if len(parts) >= 6 and parts[-1].startswith("I"):
            scan_id = parts[-1].split(".")[0]
            scan_to_subject[scan_id] = "_".join(parts[1:4])

    print(f"Found {len(scan_to_subject)} scan-to-subject mappings")

    for split in splits:
        for condition in conditions:
            dir_path = os.path.join(data_dir, split, condition)
            if not os.path.isdir(dir_path):
                print(f"Directory {dir_path} does not exist, skipping...")
                continue

            print(f"Processing {dir_path}...")
            for file_path in sorted(glob.glob(os.path.join(dir_path, "*.nii.gz"))):
                file_name = os.path.basename(file_path)
                # "I12345" for a raw scan, "013_S_0575_I12345" once renamed
                stem = file_name.split(".")[0]
                scan_id = stem.rsplit("_", 1)[-1]
                subject_id = scan_to_subject.get(scan_id)

                # Only accept the two expected shapes: bare scan ID, or the
                # exact "<subject>_<scan>" name produced by an earlier run.
                if subject_id is None or stem not in (
                    scan_id,
                    f"{subject_id}_{scan_id}",
                ):
                    print(f"Could not find subject ID for scan {stem}")
                    continue
                if stem != scan_id:
                    # Already prefixed by a previous run; nothing to do.
                    continue

                new_name = f"{subject_id}_{file_name}"
                new_path = os.path.join(dir_path, new_name)
                if os.path.exists(new_path):
                    # os.rename would silently clobber this on POSIX.
                    print(f"Skipping {file_name}: {new_name} already exists")
                    continue
                print(f"Renaming {file_name} to {new_name}")
                os.rename(file_path, new_path)

    print("File renaming complete!")


# prepend_subject_id(data_dir, xml_dir)

Now create another function that checks whether any subjects in the train or validation splits also appear in the test split.


In [None]:
def check_subject_overlap_and_scans(data_dir):
    """
    Report whether any subject in the train or val split also has scans in
    the test split, and list every scan belonging to each such subject.

    Files are expected to be named ``<subject_id>_<scan_id>.nii.gz`` (e.g.
    ``013_S_0575_I44926.nii.gz``); files without an underscore are ignored.

    Args:
        data_dir: Directory containing ``{train,val,test}/{AD,CN}`` folders.

    Returns:
        True if the train/val and test subject sets are disjoint, else False.
    """
    scans_by_subject = defaultdict(list)
    train_val_subjects, test_subjects = set(), set()

    for split in ("train", "val", "test"):
        for condition in ("AD", "CN"):
            folder = os.path.join(data_dir, split, condition)
            if not os.path.exists(folder):
                continue

            for path in glob.glob(os.path.join(folder, "*.nii.gz")):
                name = os.path.basename(path)
                if "_" not in name:
                    continue  # not yet prefixed with a subject ID

                fields = name.split("_")
                subject = "_".join(fields[:3])
                bucket = test_subjects if split == "test" else train_val_subjects
                bucket.add(subject)
                scans_by_subject[subject].append(
                    {
                        "scan_id": fields[-1].split(".")[0],
                        "split": split,
                        "condition": condition,
                        "filename": name,
                    }
                )

    leaked = train_val_subjects & test_subjects

    print(f"Total subjects in train/val: {len(train_val_subjects)}")
    print(f"Total subjects in test: {len(test_subjects)}")
    print(f"Total subjects in dataset: {len(scans_by_subject)}")

    repeat_count = sum(1 for scans in scans_by_subject.values() if len(scans) > 1)
    print(f"\nSubjects with multiple scans: {repeat_count}")

    if not leaked:
        print("\n✓ No subject overlap found between train/val and test splits")
        return True

    print(f"\nWARNING: Found {len(leaked)} subjects in both train/val and test!")
    print("\nOverlapping subjects and their scans:")
    for subject in sorted(leaked):
        print(f"\nSubject {subject} appears in multiple splits:")
        for scan in scans_by_subject[subject]:
            print(f"  - {scan['filename']} ({scan['split']}/{scan['condition']})")
    return False


# Run the overlap check; the returned bool is shown as the cell's output.
overlap_free = check_subject_overlap_and_scans(data_dir)
overlap_free

Total subjects in train/val: 204
Total subjects in test: 68
Total subjects in dataset: 211

Subjects with multiple scans: 176


Overlapping subjects and their scans:

Subject 003_S_0981 appears in multiple splits:
  - 003_S_0981_I80850.nii.gz (train/CN)
  - 003_S_0981_I53046.nii.gz (train/CN)
  - 003_S_0981_I27140.nii.gz (train/CN)
  - 003_S_0981_I126046.nii.gz (test/CN)

Subject 011_S_0008 appears in multiple splits:
  - 011_S_0008_I12209.nii.gz (train/CN)
  - 011_S_0008_I7211.nii.gz (test/CN)

Subject 011_S_0010 appears in multiple splits:
  - 011_S_0010_I91038.nii.gz (train/AD)
  - 011_S_0010_I91071.nii.gz (train/AD)
  - 011_S_0010_I8460.nii.gz (train/AD)
  - 011_S_0010_I14868.nii.gz (test/AD)

Subject 011_S_0021 appears in multiple splits:
  - 011_S_0021_I14130.nii.gz (train/CN)
  - 011_S_0021_I7679.nii.gz (train/CN)
  - 011_S_0021_I28808.nii.gz (train/CN)
  - 011_S_0021_I125318.nii.gz (train/CN)
  - 011_S_0021_I193526.nii.gz (test/CN)

Subject 011_S_0183 appears in multiple splits

False

In [13]:
from collections import defaultdict
import os
import glob
import shutil
import random


def reorganize_data_split(
    data_dir, output_dir, train_pct=0.8, val_pct=0.1, test_pct=0.1, random_seed=42
):
    """
    Reorganize data splits so that every subject lands in exactly one split
    (no subject overlap between train, val and test) while approximating the
    requested train/val/test proportions.

    Splitting is done at the subject level and stratified by the set of
    conditions a subject's scans carry. Subjects whose scans are all AD (or
    all CN) are split within that group. A subject with scans under both AD
    and CN forms a separate "AD+CN" group, so all of its scans still go to the
    same split. Each scan is copied into its own condition folder.

    Args:
        data_dir: Directory containing the original ``{train,val,test}/{AD,CN}``
            folders with files named like ``013_S_0575_I44926.nii.gz``.
        output_dir: Directory to save reorganized data. Existing files there
            are not removed, so use a fresh directory when changing the seed
            or the percentages.
        train_pct: Fraction of subjects for the training set (default 0.8)
        val_pct: Fraction of subjects for the validation set (default 0.1)
        test_pct: Fraction of subjects for the test set (default 0.1)
        random_seed: Seed for a local RNG; identical inputs and seed give an
            identical split across kernel restarts.

    Returns:
        The result of ``verify_no_overlap(output_dir)``.

    Raises:
        ValueError: If the split percentages do not sum to 1.0.
    """
    # Check percentages
    if abs(train_pct + val_pct + test_pct - 1.0) > 0.001:
        raise ValueError("Split percentages must sum to 1.0")

    # Local RNG: reproducible without reseeding the global ``random`` module.
    rng = random.Random(random_seed)
    splits = ["train", "val", "test"]
    conditions = ["AD", "CN"]

    # Create output directory structure
    for split in splits:
        for condition in conditions:
            os.makedirs(os.path.join(output_dir, split, condition), exist_ok=True)

    stale = [
        path
        for split in splits
        for condition in conditions
        for path in glob.glob(os.path.join(output_dir, split, condition, "*.nii.gz"))
    ]
    if stale:
        print(
            f"WARNING: {output_dir} already contains {len(stale)} scans; files "
            "from a previous run are not removed and may cause overlap."
        )

    # First pass: collect every scan, grouped by subject
    subject_to_scans = defaultdict(list)
    for split in splits:
        for condition in conditions:
            dir_path = os.path.join(data_dir, split, condition)
            if not os.path.exists(dir_path):
                continue

            for file_path in glob.glob(os.path.join(dir_path, "*.nii.gz")):
                file_name = os.path.basename(file_path)

                # Files should be named like "013_S_0575_I44926.nii.gz";
                # files without a subject prefix are skipped.
                if "_" not in file_name:
                    continue
                parts = file_name.split("_")
                subject_id = "_".join(parts[:3])
                subject_to_scans[subject_id].append(
                    {
                        "scan_id": parts[-1].split(".")[0],
                        "condition": condition,
                        "filename": file_name,
                        "original_path": file_path,
                    }
                )

    # Stratify subjects by the sorted set of conditions of their scans.
    # Iterating in sorted order (rather than set order, which changes with
    # Python's hash randomization between runs) makes the shuffle reproducible.
    subjects_by_group = defaultdict(list)
    for subject_id in sorted(subject_to_scans):
        group = "+".join(
            sorted({scan["condition"] for scan in subject_to_scans[subject_id]})
        )
        subjects_by_group[group].append(subject_id)

    # Assign exactly one split per subject
    subject_to_split = {}
    for group in sorted(subjects_by_group):
        subjects = subjects_by_group[group]
        rng.shuffle(subjects)
        total_subjects = len(subjects)

        train_count = int(total_subjects * train_pct)
        val_count = int(total_subjects * val_pct)
        test_count = total_subjects - train_count - val_count

        for index, subject_id in enumerate(subjects):
            if index < train_count:
                subject_to_split[subject_id] = "train"
            elif index < train_count + val_count:
                subject_to_split[subject_id] = "val"
            else:
                subject_to_split[subject_id] = "test"

        print(
            f"{group} split: {train_count}/{val_count}/{test_count} subjects for train/val/test"
        )

    # Copy every scan of a subject into that subject's split
    copied_files = 0
    for subject_id, scans in subject_to_scans.items():
        target_split = subject_to_split[subject_id]
        for scan in scans:
            dst_path = os.path.join(
                output_dir, target_split, scan["condition"], scan["filename"]
            )
            shutil.copy2(scan["original_path"], dst_path)
            copied_files += 1

    print(f"Copied {copied_files} files to new split structure")

    # Verify no subject overlap
    return verify_no_overlap(output_dir)


def verify_no_overlap(data_dir):
    """
    Check that no subject has scans in both a train/val split and the test split.

    Subject IDs are the first three underscore-separated fields of each
    ``*.nii.gz`` filename (``013_S_0575`` from ``013_S_0575_I44926.nii.gz``);
    files without an underscore are ignored.

    Args:
        data_dir: Directory containing ``{train,val,test}/{AD,CN}`` folders.

    Returns:
        True when the train/val and test subject sets are disjoint, else False.
    """

    def subjects_for(split_names):
        found = set()
        for split in split_names:
            for condition in ("AD", "CN"):
                folder = os.path.join(data_dir, split, condition)
                if not os.path.exists(folder):
                    continue
                names = (
                    os.path.basename(p)
                    for p in glob.glob(os.path.join(folder, "*.nii.gz"))
                )
                found.update("_".join(n.split("_")[:3]) for n in names if "_" in n)
        return found

    train_val_subjects = subjects_for(("train", "val"))
    test_subjects = subjects_for(("test",))
    shared = train_val_subjects & test_subjects

    print("\nVerification results:")
    print(f"Total subjects in train/val: {len(train_val_subjects)}")
    print(f"Total subjects in test: {len(test_subjects)}")
    print(f"Total subjects in dataset: {len(train_val_subjects | test_subjects)}")

    if shared:
        print(
            f"\nWARNING: Found {len(shared)} subjects in both train/val and test!"
        )
        return False
    print("\n✓ No subject overlap found between train/val and test splits")
    return True


if __name__ == "__main__":
    # Example usage: write a subject-disjoint copy of the dataset.
    output_dir = "./DATA/OUTPUT"
    reorganize_data_split(data_dir=data_dir, output_dir=output_dir)

AD split: 75/9/10 subjects for train/val/test
CN split: 93/11/13 subjects for train/val/test
Copied 849 files to new split structure

Verification results:
Total subjects in train/val: 188
Total subjects in test: 23
Total subjects in dataset: 211

✓ No subject overlap found between train/val and test splits
