# Train-Test Split with Stratification
## Mammography Dataset - Stratified Data Splitting for Imbalanced Classification


In [None]:
import sys
from pathlib import Path
sys.path.append(str(Path().absolute().parent / 'src'))

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

from data_processing.data_loader import load_mammography_data
from data_processing.class_imbalance import print_class_distribution


## 1. Load Dataset


In [None]:
# Load the mammography dataset: X receives the features, y the target labels
X, y = load_mammography_data()

# Report that loading worked and show the shape of both objects
print("Dataset loaded successfully!")
print(
    f"Features shape (X): {X.shape}",
    f"Target shape (y): {y.shape}",
    sep="\n",
)

# Class breakdown of the full label vector, before any splitting
print("\nOriginal dataset class distribution:")
print_class_distribution(y)


## 2. Train-Test Split with Stratification


In [None]:
# Stratified hold-out split of the dataset into training and test sets.
#   test_size=0.2   -> 20% of the samples for testing, 80% for training
#   random_state=42 -> fixed seed so the split is reproducible
#   stratify=y      -> keep the class proportions of y in both subsets
#                      (critical for imbalanced datasets)
split_options = {"test_size": 0.2, "random_state": 42, "stratify": y}
X_train, X_test, y_train, y_test = train_test_split(X, y, **split_options)

# Sizes used in the report below
total_samples = len(X)
n_train = X_train.shape[0]
n_test = X_test.shape[0]
rule = "=" * 70

print(rule)
print("TRAIN-TEST SPLIT COMPLETED")
print(rule)
print(f"\nTraining set size: {n_train:,} samples ({n_train/total_samples*100:.1f}%)")
print(f"Test set size:     {n_test:,} samples ({n_test/total_samples*100:.1f}%)")
print(f"\nFeature dimensions: {X_train.shape[1]} features")
print(rule)


## 3. Verify Class Distribution in Training Set


In [None]:
# Heading plus a 70-character rule, then the class breakdown of the
# training labels via the project helper print_class_distribution
print("TRAINING SET CLASS DISTRIBUTION", "=" * 70, sep="\n")
print_class_distribution(y_train)


## 4. Verify Class Distribution in Test Set


In [None]:
# Heading plus a 70-character rule, then the class breakdown of the
# test labels via the project helper print_class_distribution
print("TEST SET CLASS DISTRIBUTION", "=" * 70, sep="\n")
print_class_distribution(y_test)


## 5. Comparison: Original vs Train vs Test Distributions


In [None]:
# Calculate class distributions
def get_class_stats(y):
    """Summarise how often each class label occurs in ``y``.

    Args:
        y: Array-like of class labels.

    Returns:
        A dict mapping every distinct label in ``y`` to
        ``{'count': occurrences, 'percentage': share of len(y) in percent}``.
        An empty ``y`` produces an empty dict.
    """
    labels, frequencies = np.unique(y, return_counts=True)
    n_samples = len(y)
    return {
        label: {'count': freq, 'percentage': (freq / n_samples) * 100}
        for label, freq in zip(labels, frequencies)
    }

# Per-class statistics for the full label vector and for each split
original_stats = get_class_stats(y)
train_stats = get_class_stats(y_train)
test_stats = get_class_stats(y_test)

# Comparison table: one row per dataset (Original, Training, Test).
# Every per-class column is derived from the same three stats dicts, so
# each column is built with a comprehension over them. Label 0 is reported
# as "Benign" and label 1 as "Malignant".
stats_by_set = (original_stats, train_stats, test_stats)
comparison_data = {
    'Set': ['Original', 'Training', 'Test'],
    'Total Samples': [len(y), len(y_train), len(y_test)],
    'Benign (0) Count': [s[0]['count'] for s in stats_by_set],
    'Benign (0) %': [f"{s[0]['percentage']:.2f}%" for s in stats_by_set],
    'Malignant (1) Count': [s[1]['count'] for s in stats_by_set],
    'Malignant (1) %': [f"{s[1]['percentage']:.2f}%" for s in stats_by_set],
    # Majority-to-minority ratio, i.e. class-0 count divided by class-1 count
    'Imbalance Ratio': [
        f"{s[0]['count']/s[1]['count']:.1f}:1" for s in stats_by_set
    ],
}
comparison_df = pd.DataFrame(comparison_data)

wide_rule = "=" * 100
print(wide_rule)
print("CLASS DISTRIBUTION COMPARISON")
print(wide_rule)
print(comparison_df.to_string(index=False))
print(wide_rule)

# Stratification check: compare the class-0 percentage of each split with
# the class-0 percentage of the full dataset.
print("\n" + wide_rule)
print("STRATIFICATION VERIFICATION")
print(wide_rule)
original_ratio = original_stats[0]['percentage']
train_ratio = train_stats[0]['percentage']
test_ratio = test_stats[0]['percentage']

# Absolute deviation from the original share, in percentage points
train_gap = abs(train_ratio - original_ratio)
test_gap = abs(test_ratio - original_ratio)

print(f"Original Benign percentage: {original_ratio:.2f}%")
print(f"Training Benign percentage: {train_ratio:.2f}%")
print(f"Test Benign percentage:     {test_ratio:.2f}%")
print(f"\nDifference (Train - Original): {train_gap:.2f}%")
print(f"Difference (Test - Original):  {test_gap:.2f}%")

# Both splits must stay within 0.5 percentage points of the original share
if train_gap < 0.5 and test_gap < 0.5:
    verdict = "\n✅ Stratification successful! Class distributions are well-preserved."
else:
    verdict = "\n⚠️  Warning: Significant difference in class distributions detected."
print(verdict)
print(wide_rule)


## 6. Summary

The dataset has been successfully split into training and test sets with:
- **Stratification**: Class distribution is preserved in both sets
- **80/20 Split**: 80% for training, 20% for testing
- **Reproducibility**: `random_state=42` ensures consistent splits

**Key Points:**
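- Stratification is crucial for imbalanced datasets: it ensures both sets contain representative samples from each class, whereas a purely random split can leave the test set with too few minority-class samples for a reliable evaluation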
- The class imbalance ratio should be similar across original, training, and test sets
- This split is ready for model training and evaluation
