In [None]:
import os
import glob
from collections import defaultdict


# Locations of the dataset on disk, all resolved relative to one root folder
base_dir = "./DATA"
# Image volumes, organised into train/val/test folders
data_dir = os.path.join(base_dir, "ADNI_CROPPED_128")
# Metadata XML files
xml_dir = os.path.join(base_dir, "ADNI_METADATA")

I have a directory with three subdirectories: train, val, and test. Each of these has two subdirectories, AD and CN, and each AD and CN directory contains many nii.gz files whose names are scan IDs such as "I12345". I also have another directory containing many XML files, each named like "ADNI_013_S_0575_MPRAGE_S28210_I44926", where the final part is the scan ID ("I44926") and the second, third, and fourth parts form the subject ID ("013_S_0575"). For every scan in the first directory, prepend the subject ID corresponding to that scan to the scan's filename (for example, "I44926.nii.gz" becomes "013_S_0575_I44926.nii.gz"). Do this in Python.


In [None]:
def prepend_subject_id(data_dir, xml_dir):
    """
    Rename every scan under ``data_dir`` so its filename starts with its subject ID.

    The scan-to-subject mapping is built from the XML *filenames* in ``xml_dir``
    only; the XML contents are never opened. A name such as
    "ADNI_013_S_0575_MPRAGE_S28210_I44926.xml" maps scan "I44926" to subject
    "013_S_0575", so "I44926.nii.gz" becomes "013_S_0575_I44926.nii.gz".

    The function is safe to run more than once: a scan whose prefixed name
    already exists is left untouched instead of being renamed over the existing
    file, and files that already carry their subject ID are reported as such.

    Args:
        data_dir: Directory containing ``<split>/<condition>/*.nii.gz`` for the
            splits train/val/test and the conditions AD/CN.
        xml_dir: Directory holding the metadata XML files.
    """
    # Create mapping from scan ID to subject ID
    scan_to_subject = {}

    # Parse XML filenames to extract subject ID and scan ID
    print("Creating scan-to-subject mapping...")
    for xml_file in os.listdir(xml_dir):
        if xml_file.endswith(".xml"):
            # From format like "ADNI_013_S_0575_MPRAGE_S28210_I44926.xml"
            # Extract the 013_S_0575 (subject ID) and I44926 (scan ID)
            parts = xml_file.split("_")
            if len(parts) >= 6 and parts[-1].startswith("I"):
                scan_id = parts[-1].split(".")[0]  # Extract scan ID (e.g., "I12345")
                subject_id = "_".join(
                    parts[1:4]
                )  # Parts 2-4 form subject ID (013_S_0575)
                scan_to_subject[scan_id] = subject_id

    print(f"Found {len(scan_to_subject)} scan-to-subject mappings")

    # Process each directory and rename files
    for split in ["train", "val", "test"]:
        for condition in ["AD", "CN"]:
            dir_path = os.path.join(data_dir, split, condition)
            if not os.path.exists(dir_path):
                print(f"Directory {dir_path} does not exist, skipping...")
                continue

            print(f"Processing {dir_path}...")
            for file_path in glob.glob(os.path.join(dir_path, "*.nii.gz")):
                file_name = os.path.basename(file_path)
                scan_id = file_name.split(".")[0]  # Extract scan ID (e.g., "I12345")

                if scan_id in scan_to_subject:
                    subject_id = scan_to_subject[scan_id]
                    new_name = f"{subject_id}_{file_name}"
                    new_path = os.path.join(dir_path, new_name)

                    # os.rename silently replaces an existing destination on
                    # POSIX, which would destroy a scan left by an earlier run.
                    if os.path.exists(new_path):
                        print(f"Skipping {file_name}: {new_name} already exists")
                        continue

                    print(f"Renaming {file_name} to {new_name}")
                    os.rename(file_path, new_path)
                else:
                    # A stem like "013_S_0575_I44926" is the result of a
                    # previous run, not a scan with missing metadata.
                    trailing_id = scan_id.rsplit("_", 1)[-1]
                    owner = scan_to_subject.get(trailing_id)
                    if owner is not None and scan_id == f"{owner}_{trailing_id}":
                        print(f"{file_name} already carries its subject ID, skipping")
                    else:
                        print(f"Could not find subject ID for scan {scan_id}")

    print("File renaming complete!")


# prepend_subject_id(data_dir, xml_dir)

Now create another function that checks whether any subject in the train or validation split also appears in the test split.


In [None]:
def check_subject_overlap_and_scans(data_dir):
    """
    Report whether the held-out test split shares subjects with train/val.

    Walks ``<data_dir>/<split>/<condition>/*.nii.gz`` for the train, val and
    test splits and the AD and CN conditions. The subject ID is taken from the
    first three underscore-separated pieces of each filename (for example
    "013_S_0575" in "013_S_0575_I44926.nii.gz"); filenames without an
    underscore are ignored. A summary is printed, including every scan of each
    subject found on both sides.

    Returns:
        True when no subject appears in both train/val and test, else False.
    """
    scans_of = defaultdict(list)  # subject ID -> one record per scan file
    everyone = set()
    development = set()  # subjects seen in train or val
    held_out = set()  # subjects seen in test

    for split_name in ("train", "val", "test"):
        bucket = held_out if split_name == "test" else development
        for label in ("AD", "CN"):
            folder = os.path.join(data_dir, split_name, label)
            if not os.path.exists(folder):
                continue

            for path in glob.glob(os.path.join(folder, "*.nii.gz")):
                base = os.path.basename(path)
                if "_" not in base:
                    continue

                pieces = base.split("_")
                subject = "_".join(pieces[:3])
                everyone.add(subject)
                bucket.add(subject)
                scans_of[subject].append(
                    {
                        "scan_id": pieces[-1].split(".")[0],
                        "split": split_name,
                        "condition": label,
                        "filename": base,
                    }
                )

    shared = development & held_out

    # Headline numbers
    print(f"Total subjects in train/val: {len(development)}")
    print(f"Total subjects in test: {len(held_out)}")
    print(f"Total subjects in dataset: {len(everyone)}")

    repeat_count = sum(1 for records in scans_of.values() if len(records) > 1)
    print(f"\nSubjects with multiple scans: {repeat_count}")

    if not shared:
        print("\n✓ No subject overlap found between train/val and test splits")
        return True

    # Leakage found: list every scan of each offending subject
    print(f"\nWARNING: Found {len(shared)} subjects in both train/val and test!")
    print("\nOverlapping subjects and their scans:")
    for subject in sorted(shared):
        print(f"\nSubject {subject} appears in multiple splits:")
        for record in scans_of[subject]:
            print(f"  - {record['filename']} ({record['split']}/{record['condition']})")
    return False


# Scan data_dir for subjects shared between train/val and test and print the report.
# The call returns True when no overlap is found and False otherwise.
check_subject_overlap_and_scans(data_dir)

In [14]:
from collections import defaultdict
import os
import glob
import shutil
import random


def _greedy_fill(candidates, scan_counts, filled, target):
    """Pick subjects, in order, for one split until its scan target is reached.

    A subject is taken while the split is still below ``target`` and taking it
    would not overshoot ``target`` by more than 10%.

    Returns:
        ``(chosen, leftover, filled)``: the subjects taken, the subjects passed
        over (original order kept), and the split's updated scan count.
    """
    chosen, leftover = [], []
    for subject in candidates:
        scan_count = scan_counts[subject]
        if filled < target and filled + scan_count <= target * 1.1:
            chosen.append(subject)
            filled += scan_count
        else:
            leftover.append(subject)
    return chosen, leftover, filled


def reorganize_data_split(
    data_dir, output_dir, train_pct=0.8, val_pct=0.1, test_pct=0.1, random_seed=42
):
    """
    Rebuild the train/val/test splits so that every subject lives in exactly one
    split, while keeping each condition's scan counts close to the requested
    ratios.

    Scans are read from ``<data_dir>/<split>/<condition>/*.nii.gz`` (all three
    existing splits are pooled) and copied to
    ``<output_dir>/<split>/<condition>/``. The subject ID is taken from the
    first three underscore-separated parts of the filename; files without an
    underscore are ignored. Subjects are assigned greedily, largest first, to
    test, then val, with the rest going to train.

    A subject that has scans under more than one condition keeps the split it
    was given for the first condition processed, so it can never end up in
    train for one condition and test for the other.

    Note: this seeds Python's global ``random`` module with ``random_seed``.

    Args:
        data_dir: Directory containing original data
        output_dir: Directory to save reorganized data
        train_pct: Percentage for training set (default 0.8)
        val_pct: Percentage for validation set (default 0.1)
        test_pct: Percentage for test set (default 0.1)
        random_seed: Random seed for reproducibility

    Raises:
        ValueError: If the three percentages do not sum to 1.0 (within 0.001).
    """
    random.seed(random_seed)

    # Check percentages
    if abs(train_pct + val_pct + test_pct - 1.0) > 0.001:
        raise ValueError("Split percentages must sum to 1.0")

    splits = ["train", "val", "test"]
    conditions = ["AD", "CN"]

    # Create output directory structure
    for split in splits:
        for condition in conditions:
            os.makedirs(os.path.join(output_dir, split, condition), exist_ok=True)

    # Collect all subjects and their scans
    subjects_by_condition = defaultdict(list)  # condition -> subjects, first-seen order
    subject_to_scans = defaultdict(list)  # subject -> scan records (any condition)
    scan_counts_by_condition = defaultdict(dict)  # condition -> {subject: scan count}
    total_scans_by_condition = defaultdict(int)

    # First pass: collect all subjects and their scans
    for split in splits:
        for condition in conditions:
            dir_path = os.path.join(data_dir, split, condition)
            if not os.path.exists(dir_path):
                continue

            for file_path in glob.glob(os.path.join(dir_path, "*.nii.gz")):
                file_name = os.path.basename(file_path)

                # Files should be named like "013_S_0575_I44926.nii.gz"
                if "_" not in file_name:
                    continue

                parts = file_name.split("_")
                subject_id = "_".join(parts[:3])
                scan_id = parts[-1].split(".")[0]

                # Track unique subjects by condition; the dict doubles as an
                # O(1) "seen" check and as the per-condition scan counter.
                per_subject = scan_counts_by_condition[condition]
                if subject_id not in per_subject:
                    subjects_by_condition[condition].append(subject_id)
                    per_subject[subject_id] = 0
                per_subject[subject_id] += 1

                # Store scan information
                subject_to_scans[subject_id].append(
                    {
                        "scan_id": scan_id,
                        "condition": condition,
                        "filename": file_name,
                        "original_path": file_path,
                    }
                )

                # Count total scans by condition
                total_scans_by_condition[condition] += 1

    # Print summary statistics
    for condition in subjects_by_condition:
        print(f"Condition {condition}:")
        print(f"  - Total subjects: {len(subjects_by_condition[condition])}")
        print(f"  - Total scans: {total_scans_by_condition[condition]}")

    # Create new splits aiming for scan percentages
    new_splits = {}
    subject_split = {}  # subject -> split, shared by all conditions

    for condition in subjects_by_condition:
        subjects = subjects_by_condition[condition].copy()
        random.shuffle(subjects)
        scan_counts = scan_counts_by_condition[condition]

        # Calculate target scan counts
        total_scans = total_scans_by_condition[condition]
        target_train_scans = int(total_scans * train_pct)
        target_val_scans = int(total_scans * val_pct)
        target_test_scans = total_scans - target_train_scans - target_val_scans

        print(f"\n{condition} target scan counts:")
        print(f"  - Train: {target_train_scans} scans ({train_pct*100:.1f}%)")
        print(f"  - Val: {target_val_scans} scans ({val_pct*100:.1f}%)")
        print(f"  - Test: {target_test_scans} scans ({test_pct*100:.1f}%)")

        # Initialize splits
        assigned = {"train": [], "val": [], "test": []}
        current_counts = {"train": 0, "val": 0, "test": 0}

        # Subjects already placed while processing another condition must stay
        # in that split; their scans count toward this condition's totals.
        unassigned = []
        for subject in subjects:
            earlier_split = subject_split.get(subject)
            if earlier_split is None:
                unassigned.append(subject)
            else:
                assigned[earlier_split].append(subject)
                current_counts[earlier_split] += scan_counts[subject]

        # Sort subjects by number of scans (descending); the sort is stable, so
        # the shuffle above decides the order among equally sized subjects.
        subjects_by_scans = sorted(
            unassigned, key=lambda subject: scan_counts[subject], reverse=True
        )

        # First, assign subjects to test set (to ensure it gets close to target)
        picked, remaining_subjects, current_counts["test"] = _greedy_fill(
            subjects_by_scans, scan_counts, current_counts["test"], target_test_scans
        )
        assigned["test"].extend(picked)

        # Then assign to validation set
        picked, subjects_for_train, current_counts["val"] = _greedy_fill(
            remaining_subjects, scan_counts, current_counts["val"], target_val_scans
        )
        assigned["val"].extend(picked)

        # Remaining subjects go to training set
        for subject in subjects_for_train:
            assigned["train"].append(subject)
            current_counts["train"] += scan_counts[subject]

        # Remember the placement so later conditions reuse it
        for split in splits:
            for subject in assigned[split]:
                subject_split[subject] = split
        new_splits[condition] = assigned

        # Print actual scan counts
        print(f"\n{condition} actual split:")
        for split in splits:
            subjects_count = len(new_splits[condition][split])
            scans_count = current_counts[split]
            percentage = scans_count / total_scans * 100 if total_scans > 0 else 0
            print(
                f"  - {split}: {subjects_count} subjects, {scans_count} scans ({percentage:.1f}%)"
            )

    # Copy files to new locations
    copied_files = 0
    for subject_id, scans in subject_to_scans.items():
        target_split = subject_split.get(subject_id)
        if not target_split:
            continue

        for scan in scans:
            dst_path = os.path.join(
                output_dir, target_split, scan["condition"], scan["filename"]
            )
            shutil.copy2(scan["original_path"], dst_path)
            copied_files += 1

    print(f"\nCopied {copied_files} files to new split structure")

    # Verify no subject overlap
    verify_no_overlap(output_dir)

    # Verify actual scan percentages
    verify_scan_percentages(
        output_dir,
        target_train_pct=train_pct,
        target_val_pct=val_pct,
        target_test_pct=test_pct,
    )


def verify_no_overlap(data_dir):
    """
    Check that the test split shares no subject with the train/val splits.

    Subject IDs are read from the first three underscore-separated pieces of
    every ``*.nii.gz`` filename under ``<data_dir>/<split>/<condition>/``;
    filenames without an underscore are ignored. Subject counts are printed.

    Returns:
        True when train/val and test are disjoint, False otherwise.
    """
    seen = {"train": set(), "val": set(), "test": set()}

    for split_name, bucket in seen.items():
        for label in ("AD", "CN"):
            folder = os.path.join(data_dir, split_name, label)
            if not os.path.exists(folder):
                continue

            for path in glob.glob(os.path.join(folder, "*.nii.gz")):
                base = os.path.basename(path)
                # Expected shape: "013_S_0575_I44926.nii.gz"
                if "_" in base:
                    bucket.add("_".join(base.split("_")[:3]))

    development = seen["train"] | seen["val"]
    held_out = seen["test"]
    shared = development & held_out

    print("\nVerification results:")
    print(f"Total subjects in train/val: {len(development)}")
    print(f"Total subjects in test: {len(held_out)}")
    print(f"Total subjects in dataset: {len(development | held_out)}")

    if not shared:
        print("\n✓ No subject overlap found between train/val and test splits")
        return True

    print(f"\nWARNING: Found {len(shared)} subjects in both train/val and test!")
    return False


def verify_scan_percentages(
    data_dir, target_train_pct=0.8, target_val_pct=0.1, target_test_pct=0.1
):
    """
    Print, per condition, how the scans are spread over train/val/test and how
    far that is from the requested fractions.

    Scans are the ``*.nii.gz`` files under ``<data_dir>/<split>/<condition>/``.
    A condition is reported only if at least one of its folders exists and it
    has at least one scan. The distribution is flagged when any split is 0.05
    or more away from its target fraction. Nothing is returned.
    """
    print("\nVerifying scan distribution:")

    # condition -> scans per split, created the first time one of the
    # condition's folders is found
    tallies = {}
    for split_name in ("train", "val", "test"):
        for label in ("AD", "CN"):
            folder = os.path.join(data_dir, split_name, label)
            if os.path.exists(folder):
                found = len(glob.glob(os.path.join(folder, "*.nii.gz")))
                per_split = tallies.setdefault(label, {"train": 0, "val": 0, "test": 0})
                per_split[split_name] = found

    goals = (
        ("Train", "train", target_train_pct),
        ("Val", "val", target_val_pct),
        ("Test", "test", target_test_pct),
    )

    for label, per_split in tallies.items():
        grand_total = sum(per_split.values())
        if grand_total <= 0:
            continue

        print(f"\n{label} scan distribution:")
        gaps = []
        for title, key, goal in goals:
            share = per_split[key] / grand_total
            print(
                f"  - {title}: {per_split[key]} scans ({share*100:.1f}%), target: {goal*100:.1f}%"
            )
            gaps.append(abs(share - goal))

        # Compare the worst split against a tolerance of 0.05 of the dataset
        worst = max(gaps)
        if worst < 0.05:
            print(f"  ✓ {label} scan distribution is close to target percentages")
        else:
            print(
                f"  ⚠️ {label} scan distribution deviates from targets by up to {worst*100:.1f}%"
            )


# Example usage: rebuild the splits from data_dir into ./DATA/OUTPUT using the
# default 80/10/10 ratios and seed. data_dir is not assigned in this cell, so it
# must already be defined (it is set alongside base_dir earlier in the notebook).
if __name__ == "__main__":
    output_dir = "./DATA/OUTPUT"  # destination root; split/condition folders are created inside it
    reorganize_data_split(data_dir, output_dir)

Condition AD:
  - Total subjects: 94
  - Total scans: 306
Condition CN:
  - Total subjects: 117
  - Total scans: 543

AD target scan counts:
  - Train: 244 scans (80.0%)
  - Val: 30 scans (10.0%)
  - Test: 32 scans (10.0%)

AD actual split:
  - train: 84 subjects, 241 scans (78.8%)
  - val: 6 subjects, 32 scans (10.5%)
  - test: 4 subjects, 33 scans (10.8%)

CN target scan counts:
  - Train: 434 scans (80.0%)
  - Val: 54 scans (10.0%)
  - Test: 55 scans (10.0%)

CN actual split:
  - train: 105 subjects, 429 scans (79.0%)
  - val: 6 subjects, 54 scans (9.9%)
  - test: 6 subjects, 60 scans (11.0%)

Copied 849 files to new split structure

Verification results:
Total subjects in train/val: 201
Total subjects in test: 10
Total subjects in dataset: 211

✓ No subject overlap found between train/val and test splits

Verifying scan distribution:

AD scan distribution:
  - Train: 241 scans (78.8%), target: 80.0%
  - Val: 32 scans (10.5%), target: 10.0%
  - Test: 33 scans (10.8%), target: 10.0%
