In [None]:
import pandas as pd
import numpy as np

def _draw_stratified_sample(df, stratify_cols, frac, random_state):
    """Sample roughly ``frac`` of every stratum of ``df`` (at least one row each)."""
    # Iterate the groups explicitly instead of using ``groupby(...).apply``:
    # newer pandas versions no longer hand the grouping columns to the applied
    # function, which would strip the stratification columns from the result.
    parts = [
        grp.sample(
            n=max(1, int(len(grp) * frac)),
            random_state=random_state,
            replace=False,
        )
        for _, grp in df.groupby(stratify_cols)
        # Defensive: never try to draw a row from an empty stratum.
        if len(grp) > 0
    ]
    if not parts:
        # No usable strata (e.g. every key is missing): empty frame, same columns.
        return df.iloc[0:0].reset_index(drop=True)
    return pd.concat(parts, ignore_index=True)


def create_stratified_healthcare_samples(
    df,
    sample_sizes=(25_000, 50_000, 75_000),
    stratify_cols=('race', 'gender', 'age'),
    random_state=42
):
    """
    Create stratified samples for privacy technique comparison.

    Every combination of values in ``stratify_cols`` is sampled without
    replacement at the fraction ``n / len(df)`` (capped at 1.0), keeping at
    least one row per combination. Because per-group counts are truncated to
    integers and floored at one row, the resulting size is close to, but not
    necessarily equal to, the target. Each sample is written to
    ``diabetic_data_<label>.csv`` in the current working directory.

    Parameters:
    - df: Full cleaned dataset
    - sample_sizes: Iterable of target sample sizes
    - stratify_cols: Column name, or iterable of column names, to stratify on
      (quasi-identifiers)
    - random_state: For reproducibility (the same seed is used for each group)

    Returns:
    - Dictionary of sample DataFrames keyed by size label: ``'25k'`` for
      25_000; sizes that are not a multiple of 1000 use the exact number
      (``'500'``) so that distinct sizes never overwrite each other.
      A size whose creation fails is reported on stdout and omitted.
      An empty ``df`` yields an empty dictionary.
    """
    # Accept a single column name as well as any iterable of names; groupby
    # needs a real list (a tuple would be interpreted as one key).
    if isinstance(stratify_cols, str):
        stratify_cols = [stratify_cols]
    else:
        stratify_cols = list(stratify_cols)

    print(f"Creating stratified samples from {len(df):,} records...")
    print(f"Stratifying on: {stratify_cols}")

    samples = {}

    if len(df) == 0:
        # Nothing to sample from; also avoids dividing by zero below.
        print("✗ Input dataset is empty; no samples created.")
        return samples

    # Store full dataset distributions for comparison
    full_distributions = {
        col: df[col].value_counts(normalize=True) for col in stratify_cols
    }

    for n in sample_sizes:
        print(f"\n{'='*50}")
        print(f"Creating {n:,}-record sample...")

        # Calculate sampling fraction
        frac = min(1.0, n / len(df))  # Prevent fraction > 1

        # Best-effort per size: a failure is reported and the loop moves on.
        try:
            sample_df = _draw_stratified_sample(
                df, stratify_cols, frac, random_state
            )

            # Validate sample size
            actual_size = len(sample_df)
            print(f"✓ Created sample: {actual_size:,} rows (target: {n:,})")

            # Distribution validation
            print(f"\nDistribution Validation:")
            for col in stratify_cols:
                full_dist = full_distributions[col]
                # Align on the full dataset's categories so a category that is
                # missing from the sample counts as a share of 0.
                sample_dist = (
                    sample_df[col]
                    .value_counts(normalize=True)
                    .reindex(full_dist.index, fill_value=0.0)
                )
                max_diff = (sample_dist - full_dist).abs().max()

                print(f"  {col}: Max distribution difference = {max_diff:.4f}")

            # 'Nk' for whole thousands; exact number otherwise to avoid
            # collisions such as 500 and 900 both mapping to '0k'.
            label = f'{n//1000}k' if n % 1000 == 0 else str(n)

            # Save sample
            filename = f'diabetic_data_{label}.csv'
            sample_df.to_csv(filename, index=False)
            print(f"✓ Saved to '{filename}'")

            # Store in results
            samples[label] = sample_df

        except Exception as e:
            print(f"✗ Error creating {n:,} sample: {e}")

    return samples

# Usage: load the cleaned dataset and build the default stratified samples.
# na_filter=False switches off pandas' missing-value detection on read.
df = pd.read_csv(
    'diabetic_data_final.csv',
    na_filter=False,
)
samples = create_stratified_healthcare_samples(df=df)