# multi-variable stratification (diagnosis + site)

## here we create train/validation/test splits

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# --- 1. Load and Prepare Data ---

metadata_df = pd.read_csv('../metadata/metadata.csv')

print("--- Original Data Info ---")
print(f"Total number of patients: {len(metadata_df)}")
print("\nOriginal Diagnosis Distribution:")
print(metadata_df['clinical_diagnosis'].value_counts(normalize=True))
print("\nOriginal Site Distribution:")
print(metadata_df['site'].value_counts(normalize=True))
print("\n" + "="*50 + "\n")


# --- 2. Create the Combination Key for Stratification ---
# One label per diagnosis-site pair, e.g. 'AD_Custodio'.
metadata_df['stratify_key'] = metadata_df['clinical_diagnosis'] + '_' + metadata_df['site']

print("--- Combination Key Counts ---")
print(metadata_df['stratify_key'].value_counts())
print("\n" + "="*50 + "\n")


# --- 3. CRITICAL: Check for Rare Combinations ---
# Stratified splitting requires at least 2 samples per class. If any combination
# has only 1 sample, train_test_split raises a ValueError.
key_counts = metadata_df['stratify_key'].value_counts()
if (key_counts < 2).any():
    print("!!! WARNING: Found combinations with only 1 sample. !!!")
    print("The following combinations are too rare for stratification:")
    print(key_counts[key_counts < 2])
    print("\nYou must handle this before proceeding. Options:")
    print("1. Get more data for these groups.")
    print("2. Group rare sites together (e.g., 'SiteC' and 'SiteD' become 'Other_Site').")
    print("3. Remove these single-sample patients (last resort).")
    # Stop here with a clear message instead of letting the split fail later.
    raise ValueError("Stratification key has combinations with fewer than 2 samples.")


# --- 4. Perform the Stratified Splits using the Combination Key ---

# Define split proportions
TEST_SIZE = 0.20        # 20% of patients for the test set
VALIDATION_SIZE = 0.25  # 25% of the remaining 80% for validation (i.e., 20% of total)
RANDOM_STATE = 42

# First split: (Train + Validation) vs. Test
# We split the DataFrame directly since each row is a patient
train_val_df, test_df = train_test_split(
    metadata_df,
    test_size=TEST_SIZE,
    random_state=RANDOM_STATE,
    stratify=metadata_df['stratify_key']  # Stratify on our new key
)

# Second split: Train vs. Validation
# This split only sees train_val_df, so every combination there also needs >= 2 samples.
train_df, val_df = train_test_split(
    train_val_df,
    test_size=VALIDATION_SIZE,
    random_state=RANDOM_STATE,
    stratify=train_val_df['stratify_key']  # Stratify again on the same key
)

# Clean up the temporary key from the final dataframes
train_df = train_df.drop(columns=['stratify_key'])
val_df = val_df.drop(columns=['stratify_key'])
test_df = test_df.drop(columns=['stratify_key'])


# --- 5. Verification ---
print("--- Verification of Splits ---")
print(f"Total patients in sets: {len(train_df) + len(val_df) + len(test_df)} (should match total)")
print(f"Training patients:   {len(train_df)}")
print(f"Validation patients: {len(val_df)}")
print(f"Test patients:       {len(test_df)}\n")

print("--- Diagnosis Distribution ---")
print("Train Set Diagnosis Distribution:")
print(train_df['clinical_diagnosis'].value_counts(normalize=True).sort_index())
print("\nValidation Set Diagnosis Distribution:")
print(val_df['clinical_diagnosis'].value_counts(normalize=True).sort_index())
print("\nTest Set Diagnosis Distribution:")
print(test_df['clinical_diagnosis'].value_counts(normalize=True).sort_index())
print("\n" + "="*50 + "\n")


print("--- Site Distribution ---")
print("Train Set Site Distribution:")
print(train_df['site'].value_counts(normalize=True).sort_index())
print("\nValidation Set Site Distribution:")
print(val_df['site'].value_counts(normalize=True).sort_index())
print("\nTest Set Site Distribution:")
print(test_df['site'].value_counts(normalize=True).sort_index())
```
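Option 2 from the warning (grouping rare sites) could be automated instead of edited by hand. The sketch below is one possible approach and is not part of the original notebook; the function name collapse_rare_strata and the "_Other_Site" suffix are hypothetical. Any diagnosis-site combination with fewer than min_count patients falls back to a per-diagnosis "Other_Site" key, so diagnosis itself stays stratified.

```python
import pandas as pd


def collapse_rare_strata(
    df: pd.DataFrame,
    diag_col: str = "clinical_diagnosis",
    site_col: str = "site",
    min_count: int = 2,
) -> pd.Series:
    """Build diagnosis_site keys, pooling rare sites into '<diagnosis>_Other_Site'.

    Hypothetical helper: a combination with fewer than `min_count` rows keeps
    its diagnosis but loses its site, so diagnosis is still stratified.
    """
    key = df[diag_col] + "_" + df[site_col]
    key_size = key.map(key.value_counts())   # size of the stratum each row belongs to
    fallback = df[diag_col] + "_Other_Site"  # diagnosis-only fallback key
    return key.where(key_size >= min_count, fallback)


# Toy example: sites B and C are too rare on their own and get pooled.
toy = pd.DataFrame({
    "clinical_diagnosis": ["AD"] * 4 + ["CN"] * 4,
    "site": ["A", "A", "B", "C", "A", "A", "B", "C"],
})
print(collapse_rare_strata(toy).tolist())
```

The pooled groups can themselves end up below min_count (for example, a single rare site for one diagnosis), so the rare-combination check should be re-run on the collapsed key before splitting.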

```text
--- Original Data Info ---
Total number of patients: 714

Original Diagnosis Distribution:
clinical_diagnosis
AD     0.333333
CN     0.333333
FTD    0.333333
Name: proportion, dtype: float64

Original Site Distribution:
site
Custodio      0.334734
Lopera        0.214286
Matallana     0.210084
Avila         0.096639
Slachevsky    0.085434
Bruno         0.037815
Behrens       0.021008
Name: proportion, dtype: float64


--- Combination Key Counts ---
stratify_key
FTD_Matallana     86
CN_Custodio       80
FTD_Custodio      80
AD_Custodio       79
AD_Lopera         76
CN_Lopera         62
CN_Matallana      39
AD_Matallana      25
AD_Slachevsky     24
FTD_Avila         23
AD_Avila          23
FTD_Slachevsky    23
CN_Avila          23
FTD_Lopera        15
CN_Slachevsky     14
CN_Behrens        11
AD_Bruno           9
CN_Bruno           9
FTD_Bruno          9
AD_Behrens         2
FTD_Behrens        2
Name: count, dtype: int64


--- Verification of Splits ---
Total patients in sets: 714 (should
```

## Analysis of Your Results
### 1. Successful Combination and Stratification

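Combination Key Counts: The output lists every diagnosis-site pair. The smallest strata are AD_Behrens and FTD_Behrens, with 2 patients each. Because no stratum has a single patient, train_test_split could stratify the first split without errors. The second split needs the same minimum within train_val_df; it also succeeded here, but 2-patient strata leave no margin: had one of those two patients been assigned to the test set, the single remaining patient would have made the second split fail.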
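A small guard can be run before each of the two splits, so that a too-small stratum in train_val_df is reported with a clear message rather than surfacing as a scikit-learn error. This is a sketch, not the notebook's code; the name require_min_stratum_size and its stage argument are hypothetical.

```python
import pandas as pd


def require_min_stratum_size(keys: pd.Series, min_count: int = 2, stage: str = "split") -> None:
    """Raise ValueError if any stratum in `keys` has fewer than `min_count` rows."""
    counts = keys.value_counts()
    too_small = counts[counts < min_count]
    if not too_small.empty:
        raise ValueError(
            f"{stage}: strata with fewer than {min_count} rows: {too_small.to_dict()}"
        )


# Hypothetical usage with the notebook's variable names:
# require_min_stratum_size(metadata_df['stratify_key'], stage="first split")
# ...first train_test_split...
# require_min_stratum_size(train_val_df['stratify_key'], stage="second split")
```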

### 2. Split Proportions Match the Targets

#### Patient Counts
You started with 714 patients and ended with:

- Training: 428 patients (~60%)
- Validation: 143 patients (~20%)
- Test: 143 patients (~20%)

The counts sum back to the total: 428 + 143 + 143 = 714.
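The exact counts follow from how a fractional test_size is turned into a whole number of patients. The sketch below assumes scikit-learn's convention for a float test_size (the held-out part is ceil(test_size * n) and the remainder goes to training); under that assumption it reproduces the 428/143/143 split reported above. The function name split_sizes is hypothetical.

```python
import math


def split_sizes(n_total: int, test_size: float = 0.20, val_size: float = 0.25) -> tuple[int, int, int]:
    """Patient counts (train, validation, test) for the two-stage split.

    Assumes the held-out part of each split is rounded up (ceil) and the
    rest goes to training.
    """
    n_test = math.ceil(test_size * n_total)    # first split: 20% of all patients
    n_train_val = n_total - n_test
    n_val = math.ceil(val_size * n_train_val)  # second split: 25% of the remaining 80%
    n_train = n_train_val - n_val
    return n_train, n_val, n_test


print(split_sizes(714))  # (428, 143, 143)
```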

### 3. Distribution Balance (the Key Success Metric)

#### Diagnosis Distribution
The proportions match closely across the sets:

- Original: AD (33.3%), CN (33.3%), FTD (33.3%)
- Train: AD (33.1%), CN (33.1%), FTD (33.6%)
- Validation: AD (33.5%), CN (34.2%), FTD (32.1%)
- Test: AD (33.5%), CN (32.8%), FTD (33.5%)

This is close to ideal stratification. The small differences arise because each set must contain a whole number of patients from each stratum, so the target proportions can only be approximated.
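The whole-patient constraint can be made concrete. With 143 patients and 3 diagnoses, an exact one-third share would need 47.67 patients per class, so even the most even possible allocation (48/48/47) misses 33.33%, and each single patient moves a class share by about 0.7 percentage points. The gaps of about one point in the validation and test sets therefore correspond to roughly one or two patients. Stratifying on the 21 diagnosis-site keys rather than on diagnosis alone adds more of these rounding constraints, which can widen the per-diagnosis gaps slightly. A minimal sketch; most_even_split is a hypothetical helper.

```python
def most_even_split(n_patients: int, n_classes: int) -> list[float]:
    """Class proportions of the most even whole-patient allocation."""
    base, extra = divmod(n_patients, n_classes)
    counts = [base + 1] * extra + [base] * (n_classes - extra)
    return [c / n_patients for c in counts]


shares = most_even_split(143, 3)
print([f"{s:.2%}" for s in shares])               # ['33.57%', '33.57%', '32.87%']
print(f"one patient = {1 / 143:.2%} of the set")  # one patient = 0.70% of the set
```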
#### Site Distribution
The site proportions are equally consistent across the three sets. For example:

- Custodio: ~33.4% in all three sets.
- Lopera: ~21.5% in all three sets.

Even the smallest sites, Behrens (15 patients in total) and Bruno (27), are represented in roughly the same proportions; small variations are expected given their low absolute counts.
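Instead of reading the printed value counts one split at a time, the proportions can be placed side by side and summarized as a single number: the largest absolute gap between any split and the full dataset. This is a sketch; distribution_table and max_abs_deviation are hypothetical names, and the commented usage assumes the notebook's variables.

```python
import pandas as pd


def distribution_table(splits: dict[str, pd.DataFrame], column: str) -> pd.DataFrame:
    """One column per split holding the normalized value counts of `column`."""
    table = pd.concat(
        {name: df[column].value_counts(normalize=True) for name, df in splits.items()},
        axis=1,
    )
    return table.fillna(0.0)  # a category missing from a split has share 0


def max_abs_deviation(table: pd.DataFrame, reference: str) -> float:
    """Largest absolute difference between any split and the reference column."""
    return float(table.sub(table[reference], axis=0).abs().to_numpy().max())


# Hypothetical usage:
# sites = distribution_table(
#     {"all": metadata_df, "train": train_df, "val": val_df, "test": test_df}, "site")
# print(sites)
# print(max_abs_deviation(sites, "all"))
```

The same two calls work for "clinical_diagnosis", which gives one comparable number per variable.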

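## Conclusion
You now have balanced, non-overlapping training, validation, and test sets. Stratifying on the diagnosis-site key keeps both clinical_diagnosis and site distributed alike across the three sets, so evaluation results are not skewed by a split that over- or under-represents a diagnosis or a site. Note that stratification balances the splits but does not remove site effects from the data itself: site and diagnosis are correlated in this dataset (for example, Lopera contributes 76 AD but only 15 FTD patients), so a model can still pick up site-specific signals.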
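The claim that the sets are non-overlapping can be checked directly rather than inferred from the row counts. A minimal sketch, assuming the split DataFrames still carry the index labels they had in metadata_df (train_test_split returns row subsets of the original DataFrame, and drop(columns=...) keeps the index). It has to run before saving, since index=False discards the index. If metadata.csv has a patient-ID column, checking that column is stronger; its name is not shown here. The helper name assert_disjoint is hypothetical.

```python
from itertools import combinations

import pandas as pd


def assert_disjoint(**splits: pd.DataFrame) -> None:
    """Raise AssertionError if any two splits share an index label."""
    for (name_a, df_a), (name_b, df_b) in combinations(splits.items(), 2):
        shared = df_a.index.intersection(df_b.index)
        assert shared.empty, f"{name_a} and {name_b} share {len(shared)} rows"


# Hypothetical usage with the notebook's variables:
# assert_disjoint(train=train_df, val=val_df, test=test_df)
```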

```python
# --- Define filenames ---
train_output_path = '../metadata/train_metadata.csv'
val_output_path = '../metadata/validation_metadata.csv'
test_output_path = '../metadata/test_metadata.csv'

# --- Save DataFrames to CSV ---
# index=False prevents pandas from writing the DataFrame index
# as an extra column in the CSV files.
train_df.to_csv(train_output_path, index=False)
val_df.to_csv(val_output_path, index=False)
test_df.to_csv(test_output_path, index=False)

print("DataFrames saved successfully:")
print(f"- Training data: {train_output_path}")
print(f"- Validation data: {val_output_path}")
print(f"- Test data: {test_output_path}")
```

```text
DataFrames saved successfully:
- Training data: ../metadata/train_metadata.csv
- Validation data: ../metadata/validation_metadata.csv
- Test data: ../metadata/test_metadata.csv
```
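A quick round-trip check can confirm that each saved file reloads to the same data and that no index column slipped in (pandas names an unnamed written index "Unnamed: 0" on reload). This is a sketch; save_and_verify is a hypothetical helper, and the toy frame, including its "age" column, is made up for illustration.

```python
import tempfile
from pathlib import Path

import pandas as pd


def save_and_verify(df: pd.DataFrame, path: Path) -> pd.DataFrame:
    """Write `df` without its index, read it back, and check the round trip."""
    df.to_csv(path, index=False)
    reloaded = pd.read_csv(path)
    assert "Unnamed: 0" not in reloaded.columns, "index was written as a column"
    # After a split the index is non-contiguous; the reloaded file has a fresh RangeIndex.
    pd.testing.assert_frame_equal(df.reset_index(drop=True), reloaded, check_dtype=False)
    return reloaded


# Toy frame whose index looks like a post-split subset (hypothetical data).
with tempfile.TemporaryDirectory() as tmp:
    toy = pd.DataFrame({"site": ["Avila", "Bruno"], "age": [71, 64]}, index=[5, 9])
    save_and_verify(toy, Path(tmp) / "toy.csv")
```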
