<a href="https://www.kaggle.com/code/abidur14004/illinosis-doc?scriptVersionId=269050689" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
import numpy as np
import pandas as pd
import os
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.preprocessing import LabelEncoder, StandardScaler, RobustScaler
from tqdm.auto import tqdm
import warnings
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import json
from datetime import datetime
from scipy import stats
warnings.filterwarnings('ignore')


# ==================== VIT IMAGE PREPROCESSING CONFIGURATION ====================
class ViTPreprocessingConfig:
    """Configuration for ViT preprocessing with multi-modal features"""
    def __init__(self):
        # Image settings
        self.image_size = 224
        self.channels = 3
        
        # ImageNet normalization
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        
        # Data augmentation settings
        self.horizontal_flip_prob = 0.5
        self.color_jitter = {
            'brightness': 0.2,
            'contrast': 0.2,
            'saturation': 0.1,
            'hue': 0.05
        }
        self.rotation_degrees = 10
        
        # Cross-validation
        self.n_folds = 5
        self.random_state = 42
        
        # BMI filtering
        self.bmi_range = (16, 45)
        
        # Image quality thresholds
        self.min_image_size = 1
        
        # Feature columns to use (exclude height/weight derivatives to prevent leakage)
        self.feature_columns = ['race', 'eyes', 'sex', 'hair', 'age']
        
        # Feature engineering settings
        self.reference_date = datetime(2025, 1, 1)
        self.age_bins = [0, 18, 25, 35, 45, 55, 65, 100]
        self.age_labels = ['<18', '18-24', '25-34', '35-44', '45-54', '55-64', '65+']


# ==================== BASIC FEATURE ENGINEERING (NO LEAKAGE) ====================

def calculate_age(birth_date, reference_date=None):
    """Calculate age from date of birth"""
    if pd.isna(birth_date):
        return np.nan
    
    if reference_date is None:
        reference_date = datetime.now()
    
    try:
        if isinstance(birth_date, str):
            birth_date = pd.to_datetime(birth_date, errors='coerce')
        
        if pd.isna(birth_date):
            return np.nan
            
        age = reference_date.year - birth_date.year
        
        if reference_date.month < birth_date.month or \
           (reference_date.month == birth_date.month and reference_date.day < birth_date.day):
            age -= 1
            
        return age if age >= 0 else np.nan
    except:
        return np.nan


def calculate_bmi(height_inches, weight_lbs):
    """Calculate BMI from height (inches) and weight (pounds) - used only for target"""
    if pd.isna(height_inches) or pd.isna(weight_lbs):
        return np.nan
    if height_inches <= 0 or weight_lbs <= 0:
        return np.nan
    
    height_m = height_inches * 0.0254
    weight_kg = weight_lbs * 0.453592
    bmi = weight_kg / (height_m ** 2)
    return bmi


def engineer_basic_features(df, config):
    """
    Engineer ONLY basic features that don't depend on dataset statistics or height/weight.
    Height and weight are used ONLY to compute the target BMI.
    NO height/weight-derived features to prevent trivial prediction/leakage.
    """
    df = df.copy()
    
    print("[1/3] Engineering demographic features...")
    # Calculate age
    birth_col = next((col for col in ['birth', 'dob', 'date_of_birth'] if col in df.columns), None)
    if birth_col:
        df['age'] = df[birth_col].apply(lambda x: calculate_age(x, config.reference_date))
        
        # Age groups
        df['age_group'] = pd.cut(df['age'], bins=config.age_bins, labels=config.age_labels, right=False)
        
        # Age transformations (independent)
        df['age_squared'] = df['age'] ** 2
        df['age_decade'] = (df['age'] // 10) * 10
    
    print("[2/3] Cleaning categorical features...")
    # Clean categorical features
    categorical_features = ['race', 'eyes', 'sex', 'hair']
    for feat in categorical_features:
        if feat in df.columns:
            df[feat] = clean_categorical_feature(df[feat])
    
    # Categorical interactions (no height/weight)
    if 'race' in df.columns and 'sex' in df.columns:
        df['race_sex'] = df['race'].astype(str) + '_' + df['sex'].astype(str)
    
    if 'race' in df.columns and 'age_group' in df.columns:
        df['race_age'] = df['race'].astype(str) + '_' + df['age_group'].astype(str)
    
    if 'eyes' in df.columns and 'hair' in df.columns:
        df['eye_hair_combo'] = df['eyes'].astype(str) + '_' + df['hair'].astype(str)
    
    print("[3/3] Engineering image features...")
    # Image path features
    if 'image_path' in df.columns:
        df['image_type'] = df['image_path'].apply(
            lambda x: 'front' if 'front' in str(x).lower() 
            else ('side' if 'side' in str(x).lower() 
            else ('inmates' if 'inmates' in str(x).lower() else 'unknown'))
        )
        df['has_front_image'] = (df['image_type'] == 'front').astype(int)
        df['has_side_image'] = (df['image_type'] == 'side').astype(int)
    
    # Compute target BMI (but do not create any features from height/weight)
    if 'height' in df.columns and 'weight' in df.columns:
        df['bmi'] = df.apply(lambda row: calculate_bmi(row['height'], row['weight']), axis=1)
    
    # Remove height and weight columns to prevent accidental usage
    df = df.drop(columns=['height', 'weight'], errors='ignore')
    
    # Remove rows with invalid BMI
    initial_count = len(df)
    df = df.dropna(subset=['bmi'])
    print(f"\nRemoved {initial_count - len(df)} rows with invalid BMI")
    print(f"Final shape after basic feature engineering: {df.shape}")
    
    return df


def engineer_statistical_features(df, reference_stats=None, fit=True):
    """
    Calculate statistical features (z-scores, percentiles, etc.)
    MUST be called separately for train/val to avoid leakage.
    Only apply to non-leakage features like age.
    """
    df = df.copy()
    
    # ONLY use age (no height/weight to prevent leakage)
    numeric_cols = ['age']
    existing_cols = [col for col in numeric_cols if col in df.columns]
    
    if reference_stats is None:
        reference_stats = {}
    
    for col in existing_cols:
        if df[col].notna().sum() == 0:
            continue
        
        if fit:
            # Calculate statistics from THIS data (training data)
            col_data = df[col].dropna()
            stats_dict = {
                'mean': float(col_data.mean()),
                'std': float(col_data.std()),
                'median': float(col_data.median()),
                'q25': float(col_data.quantile(0.25)),
                'q75': float(col_data.quantile(0.75))
            }
            reference_stats[col] = stats_dict
        else:
            # Use pre-calculated statistics (from training data)
            if col not in reference_stats:
                print(f"Warning: No reference stats for {col}, skipping")
                continue
            stats_dict = reference_stats[col]
        
        # Apply transformations using appropriate statistics
        mean = stats_dict['mean']
        std = stats_dict['std']
        median = stats_dict['median']
        
        # Z-score normalization
        df[f'{col}_zscore'] = (df[col] - mean) / (std + 1e-8)
        
        # Deviation from median
        df[f'{col}_deviation_from_median'] = df[col] - median
        
        # Binary: above/below mean
        df[f'{col}_above_mean'] = (df[col] > mean).astype(int)
        
        # Percentile rank (calculate on current data, this is okay)
        df[f'{col}_percentile'] = df[col].rank(pct=True) * 100
    
    return df, reference_stats


def clean_categorical_feature(series, valid_values=None):
    """Clean and standardize categorical features"""
    series = series.astype(str).str.strip().str.upper()
    series = series.replace(['', 'NAN', 'NONE', 'NULL', 'UNKNOWN', 'NOT AVAILABLE', 'VOID'], 'UNKNOWN')
    
    if valid_values is not None:
        series = series.apply(lambda x: x if x in valid_values else 'UNKNOWN')
    
    return series


def process_person_data(csv_path, config):
    """
    Load person.csv and apply BASIC feature engineering only
    Statistical features will be added per-fold to prevent leakage
    """
    df = pd.read_csv(csv_path, delimiter=';')
    
    print("=== Starting Basic Feature Engineering ===")
    print(f"Initial shape: {df.shape}")
    
    # Apply basic feature engineering (no dataset statistics)
    df = engineer_basic_features(df, config)
    
    print("\n=== Basic Feature Engineering Summary ===")
    print(f"Total features created: {df.shape[1]}")
    
    return df


# ==================== DATA MATCHING AND VALIDATION ====================

def match_images_to_data(df, image_dirs):
    """Match images to person data"""
    if isinstance(image_dirs, str):
        image_dirs = [image_dirs]
    
    # Collect all available images
    available_images = {}
    for image_dir in image_dirs:
        if not os.path.exists(image_dir):
            continue
        for f in os.listdir(image_dir):
            if f.lower().endswith(('.jpg', '.jpeg', '.png')):
                stem = os.path.splitext(f)[0].lower()
                available_images[stem] = os.path.join(image_dir, f)
    
    id_columns = ['id', 'ID', 'person_id', 'subject_id', 'doc_id', 'inmate_id']
    id_col = None
    for col in id_columns:
        if col in df.columns:
            id_col = col
            break
    
    if id_col is None:
        id_col = df.columns[0]
    
    df['image_stem'] = df[id_col].astype(str).str.lower()
    df['has_image'] = df['image_stem'].isin(available_images.keys())
    df['image_path'] = df['image_stem'].apply(lambda x: available_images.get(x, None))
    
    matched = df['has_image'].sum()
    print(f"\nMatched {matched}/{len(df)} records to images ({matched/len(df)*100:.1f}%)")
    
    df_matched = df[df['has_image']].copy()
    df_matched = df_matched.drop(columns=['name'], errors='ignore')
    df_matched = df_matched.rename(columns={'image_stem': 'name'})
    
    return df_matched


def validate_images_fast(df, min_size=1):
    """Fast image validation using basic checks only"""
    valid_indices = []
    failed_images = []
    
    print(f"Fast validating {len(df)} images...")
    
    for idx in tqdm(range(len(df)), desc="Quick validation"):
        img_name = df.loc[idx, 'name']
        image_path = df.loc[idx, 'image_path']
        
        if pd.isna(image_path) or not os.path.exists(image_path):
            failed_images.append((idx, img_name, "Not found"))
            continue
        
        # Quick file size check (much faster than loading)
        file_size = os.path.getsize(image_path)
        if file_size < 1000:  # Less than 1KB = likely corrupt
            failed_images.append((idx, img_name, "Too small"))
            continue
        
        valid_indices.append(idx)
    
    if failed_images and len(failed_images) < 100:
        failed_df = pd.DataFrame(failed_images, columns=['index', 'filename', 'reason'])
        failed_df.to_csv('failed_images_log.csv', index=False)
    
    print(f"Validated {len(valid_indices)}/{len(df)} images")
    return valid_indices


def filter_dataset(df, valid_indices, bmi_range=(16, 45)):
    """Filter dataset based on valid images and BMI range"""
    df_filtered = df.iloc[valid_indices].copy()
    
    if 'bmi' in df_filtered.columns:
        initial_len = len(df_filtered)
        df_filtered = df_filtered[
            (df_filtered['bmi'] >= bmi_range[0]) & 
            (df_filtered['bmi'] <= bmi_range[1])
        ].reset_index(drop=True)
        print(f"Filtered {initial_len - len(df_filtered)} samples outside BMI range {bmi_range}")
    
    return df_filtered


# ==================== K-FOLD CREATION (BEFORE STATISTICAL FEATURES) ====================

def create_kfold_splits(df, config):
    """
    Create K-fold splits BEFORE adding statistical features
    This is the key to preventing data leakage
    """
    use_stratified = False
    
    if 'bmi' in df.columns and len(df) >= config.n_folds * 10:
        df_temp = df.copy()
        try:
            df_temp['bmi_quartile'] = pd.qcut(df_temp['bmi'], q=config.n_folds, 
                                             labels=False, duplicates='drop')
            stratify_labels = df_temp['bmi_quartile'].values
            kf = StratifiedKFold(n_splits=config.n_folds, shuffle=True, 
                                random_state=config.random_state)
            split_iterator = kf.split(df, stratify_labels)
            use_stratified = True
            print(f"\nUsing Stratified K-Fold with {config.n_folds} folds")
        except:
            kf = KFold(n_splits=config.n_folds, shuffle=True, 
                      random_state=config.random_state)
            split_iterator = kf.split(df)
            print(f"\nUsing K-Fold with {config.n_folds} folds")
    else:
        kf = KFold(n_splits=config.n_folds, shuffle=True, 
                  random_state=config.random_state)
        split_iterator = kf.split(df)
        print(f"\nUsing K-Fold with {config.n_folds} folds")
    
    fold_splits = []
    for fold, (train_idx, val_idx) in enumerate(split_iterator):
        train_df = df.iloc[train_idx].reset_index(drop=True)
        val_df = df.iloc[val_idx].reset_index(drop=True)
        
        # Verify no data leakage
        train_names = set(train_df['name'].astype(str))
        val_names = set(val_df['name'].astype(str))
        assert len(train_names & val_names) == 0, f"Data leakage detected in fold {fold+1}!"
        
        fold_info = {
            'fold': fold + 1,
            'train_size': len(train_df),
            'val_size': len(val_df),
            'train_df': train_df,
            'val_df': val_df
        }
        
        if 'bmi' in df.columns:
            fold_info['train_bmi_stats'] = {
                'mean': float(train_df['bmi'].mean()),
                'std': float(train_df['bmi'].std()),
                'min': float(train_df['bmi'].min()),
                'max': float(train_df['bmi'].max())
            }
            fold_info['val_bmi_stats'] = {
                'mean': float(val_df['bmi'].mean()),
                'std': float(val_df['bmi'].std()),
                'min': float(val_df['bmi'].min()),
                'max': float(val_df['bmi'].max())
            }
        
        fold_splits.append(fold_info)
        
        print(f"  Fold {fold+1}: Train={len(train_df)}, Val={len(val_df)}")
    
    return fold_splits


# ==================== NORMALIZATION (PER-FOLD) ====================

def normalize_features(df, feature_encoders=None, scalers=None, fit=True):
    """
    Normalize categorical and numerical features
    MUST be called separately for train/val with fit=True/False
    """
    df = df.copy()
    
    if feature_encoders is None:
        feature_encoders = {}
    if scalers is None:
        scalers = {}
    
    # Categorical encoding
    categorical_cols = [col for col in df.columns if df[col].dtype == 'object']
    
    for col in categorical_cols:
        if col in ['name', 'image_path', 'image_stem']:
            continue
            
        if fit:
            le = LabelEncoder()
            try:
                df[f'{col}_encoded'] = le.fit_transform(df[col].fillna('UNKNOWN'))
                feature_encoders[col] = le
            except:
                pass
        else:
            if col in feature_encoders:
                le = feature_encoders[col]
                df[f'{col}_encoded'] = df[col].fillna('UNKNOWN').apply(
                    lambda x: le.transform([x])[0] if x in le.classes_ else -1
                )
    
    # Numerical scaling (exclude any potential leakage columns)
    numerical_cols = df.select_dtypes(include=[np.number]).columns
    scale_cols = [col for col in numerical_cols if not col.endswith('_encoded') and 
                  col not in ['name', 'image_path', 'bmi'] and 
                  not col.startswith('is_') and not col.startswith('has_') and 
                  'height' not in col and 'weight' not in col]
    
    if fit and len(scale_cols) > 0:
        scaler = RobustScaler()
        df[scale_cols] = scaler.fit_transform(df[scale_cols].fillna(df[scale_cols].median()))
        scalers['numerical'] = scaler
    elif 'numerical' in scalers and len(scale_cols) > 0:
        scaler = scalers['numerical']
        df[scale_cols] = scaler.transform(df[scale_cols].fillna(df[scale_cols].median()))
    
    return df, feature_encoders, scalers


# ==================== SAVE PREPROCESSED DATA ====================

def save_preprocessed_data(fold_splits, config, image_dirs, output_dir='illinois_doc_preprocessed'):
    """Save preprocessed fold splits with per-fold normalization"""
    os.makedirs(output_dir, exist_ok=True)
    
    print(f"\n=== Processing and Saving Folds to {output_dir} ===")
    
    # We'll store one set of encoders (they should be similar across folds)
    global_feature_encoders = {}
    
    for fold_info in fold_splits:
        fold_num = fold_info['fold']
        fold_dir = f'{output_dir}/fold_{fold_num}'
        os.makedirs(fold_dir, exist_ok=True)
        
        print(f"\nProcessing Fold {fold_num}...")
        
        train_df = fold_info['train_df'].copy()
        val_df = fold_info['val_df'].copy()
        
        # Add statistical features (FIT on train, TRANSFORM on val)
        print(f"  Adding statistical features (fit on train)...")
        train_df, train_stats = engineer_statistical_features(train_df, reference_stats=None, fit=True)
        val_df, _ = engineer_statistical_features(val_df, reference_stats=train_stats, fit=False)
        
        # Normalize features (FIT on train, TRANSFORM on val)
        print(f"  Normalizing features (fit on train)...")
        train_df, feature_encoders, scalers = normalize_features(train_df, fit=True)
        val_df, _, _ = normalize_features(val_df, feature_encoders, scalers, fit=False)
        
        # Store encoders from first fold
        if fold_num == 1:
            global_feature_encoders = feature_encoders
        
        # Save fold data
        train_df.to_csv(f'{fold_dir}/train.csv', index=False)
        val_df.to_csv(f'{fold_dir}/val.csv', index=False)
        
        # Save fold-specific statistics
        fold_metadata = {
            'fold': fold_num,
            'train_size': len(train_df),
            'val_size': len(val_df),
            'feature_count': len(train_df.columns),
            'train_bmi_stats': fold_info.get('train_bmi_stats', {}),
            'val_bmi_stats': fold_info.get('val_bmi_stats', {}),
            'reference_stats': train_stats
        }
        
        with open(f'{fold_dir}/metadata.json', 'w') as f:
            json.dump(fold_metadata, f, indent=4)
        
        print(f"  ✓ Fold {fold_num} saved: {len(train_df)} train, {len(val_df)} val samples")
    
    # Save global configuration
    config_dict = {
        'image_size': config.image_size,
        'channels': config.channels,
        'mean': config.mean,
        'std': config.std,
        'n_folds': config.n_folds,
        'bmi_range': list(config.bmi_range),
        'image_dirs': image_dirs,
        'random_state': config.random_state,
        'feature_columns': config.feature_columns,
        'augmentation': {
            'horizontal_flip_prob': config.horizontal_flip_prob,
            'color_jitter': config.color_jitter,
            'rotation_degrees': config.rotation_degrees
        },
        'age_bins': config.age_bins,
        'age_labels': config.age_labels
    }
    
    with open(f'{output_dir}/config.json', 'w') as f:
        json.dump(config_dict, f, indent=4)
    
    # Save feature encoders (from first fold as reference)
    encoder_dict = {}
    for feature, encoder in global_feature_encoders.items():
        encoder_dict[feature] = {
            'classes': encoder.classes_.tolist()
        }
    
    with open(f'{output_dir}/feature_encoders.json', 'w') as f:
        json.dump(encoder_dict, f, indent=4)
    
    # Create summary
    sample_train_df = pd.read_csv(f'{output_dir}/fold_1/train.csv')
    all_features = list(sample_train_df.columns)
    
    summary = {
        'total_folds': config.n_folds,
        'total_samples': fold_splits[0]['train_size'] + fold_splits[0]['val_size'],
        'total_features': len(all_features),
        'image_size': config.image_size,
        'normalization': {'mean': config.mean, 'std': config.std},
        'feature_columns': all_features,
        'categorical_features': [col for col in all_features if col.endswith('_encoded')],
        'numerical_features': [col for col in all_features if any(x in col for x in 
                              ['_zscore', '_percentile', '_squared', '_deviation', '_interaction'])],
        'fold_summary': [
            {
                'fold': f['fold'],
                'train_size': f['train_size'],
                'val_size': f['val_size'],
                'train_bmi_mean': f.get('train_bmi_stats', {}).get('mean', 0),
                'val_bmi_mean': f.get('val_bmi_stats', {}).get('mean', 0)
            }
            for f in fold_splits
        ]
    }
    
    with open(f'{output_dir}/summary.json', 'w') as f:
        json.dump(summary, f, indent=4)
    
    print(f"\n✓ All folds saved successfully")
    print(f"✓ Total features: {len(all_features)}")
    print(f"✓ Categorical features: {len(summary['categorical_features'])}")
    print(f"✓ Numerical features: {len(summary['numerical_features'])}")
    
    return output_dir


# ==================== DATA QUALITY & ANALYSIS ====================

def check_data_leakage(output_dir):
    """Verify no data leakage between train/val"""
    print("\n=== Checking for Data Leakage ===")
    
    issues_found = False
    
    # Load first fold
    train_df = pd.read_csv(f'{output_dir}/fold_1/train.csv')
    
    # Check for BMI-derived features
    bmi_features = [col for col in train_df.columns if 'bmi' in col.lower() and col != 'bmi']
    if bmi_features:
        print(f"⚠️  WARNING: BMI-derived features found: {bmi_features}")
        issues_found = True
    
    # Check for height/weight features
    hw_features = [col for col in train_df.columns if 'height' in col.lower() or 'weight' in col.lower() or 'bsa' in col.lower() or 'ponderal' in col.lower()]
    if hw_features:
        print(f"⚠️  WARNING: Height/weight-derived features found: {hw_features}")
        issues_found = True
    
    # Check correlations with BMI
    print("\nChecking feature correlations with BMI:")
    numeric_cols = train_df.select_dtypes(include=[np.number]).columns
    high_corr_features = []
    
    for col in numeric_cols:
        if col != 'bmi' and train_df[col].notna().sum() > 0:
            try:
                corr = train_df[[col, 'bmi']].corr().iloc[0, 1]
                if abs(corr) > 0.95:
                    print(f"  ⚠️  {col}: correlation = {corr:.4f}")
                    high_corr_features.append((col, corr))
                    issues_found = True
            except:
                pass
    
    if not issues_found:
        print("  ✓ No data leakage detected!")
        print("  ✓ No BMI-derived features found")
        print("  ✓ No height/weight-derived features found")
        print("  ✓ No suspiciously high correlations with BMI")
    
    return not issues_found


def analyze_dataset(df, output_dir):
    """Create analysis plots"""
    os.makedirs(output_dir, exist_ok=True)
    
    fig = plt.figure(figsize=(20, 10))
    
    # BMI Distribution
    plt.subplot(2, 4, 1)
    plt.hist(df['bmi'], bins=50, alpha=0.7, edgecolor='black', color='steelblue')
    plt.xlabel('BMI', fontsize=11)
    plt.ylabel('Frequency', fontsize=11)
    plt.title('BMI Distribution', fontsize=12, fontweight='bold')
    plt.grid(True, alpha=0.3)
    
    # Age Distribution
    plt.subplot(2, 4, 2)
    if 'age' in df.columns:
        plt.hist(df['age'].dropna(), bins=30, alpha=0.7, edgecolor='black', color='coral')
        plt.xlabel('Age', fontsize=11)
        plt.ylabel('Frequency', fontsize=11)
        plt.title('Age Distribution', fontsize=12, fontweight='bold')
        plt.grid(True, alpha=0.3)
    
    # BMI vs Age
    plt.subplot(2, 4, 3)
    
    if 'age' in df.columns:
        plt.scatter(df['age'], df['bmi'], alpha=0.3, s=10)
        plt.xlabel('Age', fontsize=11)
        plt.ylabel('BMI', fontsize=11)
        plt.title('BMI vs Age', fontsize=12, fontweight='bold')
        plt.grid(True, alpha=0.3)
    
    # Removed Height vs Weight plot to avoid leakage visualization
    
    # Sex Distribution
    plt.subplot(2, 4, 5)
    if 'sex' in df.columns:
        sex_counts = df['sex'].value_counts()
        sex_counts.plot(kind='bar', color=['lightblue', 'lightpink'], edgecolor='black')
        plt.xlabel('Sex', fontsize=11)
        plt.ylabel('Count', fontsize=11)
        plt.title('Sex Distribution', fontsize=12, fontweight='bold')
        plt.xticks(rotation=0)
        plt.grid(True, alpha=0.3, axis='y')
    
    # Race Distribution
    plt.subplot(2, 4, 6)
    if 'race' in df.columns:
        race_counts = df['race'].value_counts().head(10)
        race_counts.plot(kind='barh', color='lightgreen', edgecolor='black')
        plt.xlabel('Count', fontsize=11)
        plt.ylabel('Race', fontsize=11)
        plt.title('Top 10 Race Categories', fontsize=12, fontweight='bold')
        plt.grid(True, alpha=0.3, axis='x')
    
    # BMI Statistics by Sex
    plt.subplot(2, 4, 7)
    if 'sex' in df.columns and 'bmi' in df.columns:
        df.boxplot(column='bmi', by='sex', ax=plt.gca())
        plt.xlabel('Sex', fontsize=11)
        plt.ylabel('BMI', fontsize=11)
        plt.title('BMI Distribution by Sex', fontsize=12, fontweight='bold')
        plt.suptitle('')  # Remove automatic title
        plt.grid(True, alpha=0.3)
    
    # Statistics Summary
    plt.subplot(2, 4, 8)
    stats_text = f"""
    Dataset Statistics:
    
    Total Samples: {len(df):,}
    
    BMI:
      Mean: {df['bmi'].mean():.2f}
      Std: {df['bmi'].std():.2f}
      Min: {df['bmi'].min():.2f}
      Max: {df['bmi'].max():.2f}
    """
    
    if 'age' in df.columns:
        stats_text += f"""
    Age:
      Mean: {df['age'].mean():.1f}
      Range: {df['age'].min():.0f}-{df['age'].max():.0f}
    """
    
    plt.text(0.1, 0.5, stats_text, fontsize=10, family='monospace',
            verticalalignment='center', transform=plt.gca().transAxes)
    plt.axis('off')
    plt.title('Summary Statistics', fontsize=12, fontweight='bold')
    
    plt.tight_layout()
    plt.savefig(f'{output_dir}/dataset_analysis.png', dpi=150, bbox_inches='tight')
    plt.close()
    
    print(f"✓ Dataset analysis saved to {output_dir}/dataset_analysis.png")


def print_feature_summary(df):
    """Print summary of engineered features"""
    print("\n=== Feature Summary ===")
    
    all_features = df.columns.tolist()
    
    feature_categories = {
        'Age Features': [c for c in all_features if 'age' in c and 'bmi' not in c],
        'Statistical Features': [c for c in all_features if any(x in c for x in ['zscore', 'percentile', 'deviation', 'above_mean'])],
        'Categorical': [c for c in all_features if any(x in c for x in ['race', 'sex', 'eyes', 'hair', 'combo'])],
        'Image Features': [c for c in all_features if any(x in c for x in ['image_type', 'has_front', 'has_side'])],
    }
    
    for category, features in feature_categories.items():
        if features:
            print(f"\n{category}:")
            for feat in features[:10]:  # Show first 10
                print(f"  - {feat}")
            if len(features) > 10:
                print(f"  ... and {len(features) - 10} more")
    
    print(f"\nTotal features: {len(all_features)}")
    
    # Check for problematic features
    problematic = [f for f in all_features if 'bmi' in f.lower() and f != 'bmi']
    if problematic:
        print(f"\n⚠️  WARNING: Found {len(problematic)} BMI-derived features:")
        for f in problematic:
            print(f"  - {f}")
    
    hw_problematic = [f for f in all_features if 'height' in f.lower() or 'weight' in f.lower() or 'bsa' in f.lower() or 'ponderal' in f.lower()]
    if hw_problematic:
        print(f"\n⚠️  WARNING: Found {len(hw_problematic)} height/weight-derived features:")
        for f in hw_problematic:
            print(f"  - {f}")


# ==================== MAIN PREPROCESSING PIPELINE ====================

def main_preprocessing(person_csv_path, image_dirs, output_dir='illinois_doc_preprocessed', show_plots=True):
    """
    Complete preprocessing pipeline WITHOUT data leakage
    Statistical features are calculated per-fold
    """
    
    print("=" * 70)
    print("  ILLINOIS DOC DATASET - LEAKAGE-FREE PREPROCESSING PIPELINE")
    print("=" * 70)
    
    config = ViTPreprocessingConfig()
    
    # Step 1: Load and engineer BASIC features only
    print("\n[STEP 1] Loading data and engineering basic features...")
    df = process_person_data(person_csv_path, config)
    
    # Step 2: Match images
    print("\n[STEP 2] Matching images to data...")
    df_matched = match_images_to_data(df, image_dirs)
    
    if len(df_matched) == 0:
        raise ValueError("No matching images found!")
    
    # Step 3: Validate images
    print("\n[STEP 3] Validating images...")
    valid_indices = validate_images_fast(df_matched, min_size=config.min_image_size)
    
    if len(valid_indices) == 0:
        raise ValueError("No valid images found! Check failed_images_log.csv for details.")
    
    # Step 4: Filter by BMI range
    print("\n[STEP 4] Filtering dataset by BMI range...")
    df_filtered = filter_dataset(df_matched, valid_indices, config.bmi_range)
    
    if len(df_filtered) < config.n_folds * 2:
        raise ValueError(f"Not enough samples ({len(df_filtered)}) for {config.n_folds}-fold CV")
    
    print(f"✓ Retained {len(df_filtered)} samples within BMI range {config.bmi_range}")
    
    # Step 5: Analyze dataset
    print("\n[STEP 5] Analyzing dataset...")
    if show_plots:
        analyze_dataset(df_filtered, '/kaggle/working')
    
    print_feature_summary(df_filtered)
    
    # Step 6: Create K-fold splits (BEFORE statistical features)
    print("\n[STEP 6] Creating K-fold splits...")
    fold_splits = create_kfold_splits(df_filtered, config)
    
    # Step 7: Save preprocessed data (adds statistical features per-fold)
    print("\n[STEP 7] Adding per-fold statistical features and saving...")
    output_path = save_preprocessed_data(fold_splits, config, image_dirs, output_dir)
    
    # Step 8: Verify no data leakage
    print("\n[STEP 8] Verifying no data leakage...")
    is_clean = check_data_leakage(output_path)
    
    print("\n" + "=" * 70)
    print("  PREPROCESSING COMPLETE!")
    print("=" * 70)
    print(f"✓ Output directory: {output_path}")
    print(f"✓ Ready for training with {len(df_filtered)} samples")
    print(f"✓ BMI range: {df_filtered['bmi'].min():.2f} - {df_filtered['bmi'].max():.2f}")
    print(f"✓ No data leakage: {is_clean}")
    print("=" * 70 + "\n")
    
    return output_path, fold_splits


# ==================== UTILITY FUNCTIONS ====================

def load_config(preprocessed_dir='illinois_doc_preprocessed'):
    """Load preprocessing configuration"""
    config_path = f'{preprocessed_dir}/config.json'
    
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Configuration file not found: {config_path}")
    
    with open(config_path, 'r') as f:
        config_dict = json.load(f)
    
    config = ViTPreprocessingConfig()
    config.image_size = config_dict['image_size']
    config.mean = config_dict['mean']
    config.std = config_dict['std']
    config.n_folds = config_dict['n_folds']
    config.bmi_range = tuple(config_dict['bmi_range'])
    config.feature_columns = config_dict.get('feature_columns', [])
    
    if 'augmentation' in config_dict:
        aug = config_dict['augmentation']
        config.horizontal_flip_prob = aug.get('horizontal_flip_prob', 0.5)
        config.color_jitter = aug.get('color_jitter', config.color_jitter)
        config.rotation_degrees = aug.get('rotation_degrees', 10)
    
    if 'age_bins' in config_dict:
        config.age_bins = config_dict['age_bins']
        config.age_labels = config_dict['age_labels']
    
    image_dirs = config_dict.get('image_dirs', [config_dict.get('image_dir', '')])
    
    return config, image_dirs


def load_feature_encoders(preprocessed_dir='illinois_doc_preprocessed'):
    """Load feature encoders"""
    encoder_path = f'{preprocessed_dir}/feature_encoders.json'
    
    if not os.path.exists(encoder_path):
        return {}
    
    with open(encoder_path, 'r') as f:
        encoder_dict = json.load(f)
    
    feature_encoders = {}
    for feature, info in encoder_dict.items():
        le = LabelEncoder()
        le.classes_ = np.array(info['classes'])
        feature_encoders[feature] = le
    
    return feature_encoders


def verify_preprocessing(preprocessed_dir='illinois_doc_preprocessed'):
    """Verify preprocessing was successful"""
    
    print(f"\n=== Verifying Preprocessing: {preprocessed_dir} ===")
    
    if not os.path.exists(preprocessed_dir):
        print("✗ Preprocessed directory not found")
        return False
    
    if not os.path.exists(f'{preprocessed_dir}/config.json'):
        print("✗ Configuration file not found")
        return False
    
    if not os.path.exists(f'{preprocessed_dir}/feature_encoders.json'):
        print("✗ Feature encoders file not found")
        return False
    
    print("✓ Core files exist")
    
    config, image_dirs = load_config(preprocessed_dir)
    feature_encoders = load_feature_encoders(preprocessed_dir)
    
    print(f"✓ Configuration loaded: {config.n_folds} folds")
    print(f"✓ Feature encoders loaded: {len(feature_encoders)} encoders")
    
    all_folds_ok = True
    for fold in range(1, config.n_folds + 1):
        fold_dir = f'{preprocessed_dir}/fold_{fold}'
        if not os.path.exists(fold_dir):
            print(f"✗ Fold {fold} directory not found")
            all_folds_ok = False
            continue
        
        required_files = ['train.csv', 'val.csv', 'metadata.json']
        for file in required_files:
            if not os.path.exists(f'{fold_dir}/{file}'):
                print(f"✗ Fold {fold}: {file} not found")
                all_folds_ok = False
    
    if all_folds_ok:
        print(f"✓ All {config.n_folds} folds verified")
    
    # Load summary
    if os.path.exists(f'{preprocessed_dir}/summary.json'):
        with open(f'{preprocessed_dir}/summary.json', 'r') as f:
            summary = json.load(f)
        print(f"\n=== Dataset Summary ===")
        print(f"Total samples: {summary['total_samples']}")
        print(f"Total features: {summary['total_features']}")
        print(f"Categorical features: {len(summary.get('categorical_features', []))}")
        print(f"Numerical features: {len(summary.get('numerical_features', []))}")
    
    return all_folds_ok


def visualize_batch(dataloader, num_images=8, save_path='batch_visualization.png'):
    """Visualize a batch of preprocessed images"""
    batch = next(iter(dataloader))
    images = batch['image'][:num_images]
    bmis = batch['bmi'][:num_images]
    
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    
    fig, axes = plt.subplots(2, 4, figsize=(20, 10))
    axes = axes.flatten()
    
    for idx in range(num_images):
        img = images[idx] * std + mean
        img = torch.clamp(img, 0, 1)
        img = img.permute(1, 2, 0).numpy()
        
        axes[idx].imshow(img)
        
        title = f'BMI: {bmis[idx].item():.2f}'
        
        if 'categorical_features' in batch:
            cat_feats = batch['categorical_features'][idx].tolist()
            title += f'\nCat: {cat_feats[:3]}...'
        
        if 'numerical_features' in batch:
            num_feats = batch['numerical_features'][idx][:3].tolist()
            title += f'\nNum: [{num_feats[0]:.2f}, {num_feats[1]:.2f}, ...]'
        
        axes[idx].set_title(title, fontsize=9, fontweight='bold')
        axes[idx].axis('off')
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Batch visualization saved to {save_path}")


# ==================== PYTORCH DATASET CLASS ====================

class BMIDataset(Dataset):
    """PyTorch Dataset for BMI prediction"""
    
    def __init__(self, dataframe, transform=None):
        self.df = dataframe.reset_index(drop=True)
        self.transform = transform
        
        # Identify feature columns
        self.categorical_cols = [col for col in self.df.columns if col.endswith('_encoded')]
        self.numerical_cols = [col for col in self.df.columns if 
                              col.endswith(('_zscore', '_percentile', '_squared', 
                                          '_interaction', '_deviation', '_product',
                                          '_above_mean')) and 
                              self.df[col].dtype in [np.float32, np.float64, np.int32, np.int64] and
                              'height' not in col and 'weight' not in col]
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        
        # Load image
        img_path = row['image_path']
        try:
            with Image.open(img_path) as image:
                image = image.convert('RGB')
                if self.transform:
                    image = self.transform(image)
        except:
            # Return zero tensor if image fails
            image = torch.zeros(3, 224, 224)
        
        # Get BMI target
        bmi = torch.tensor(row['bmi'], dtype=torch.float32)
        
        # Get categorical features
        categorical_features = None
        if len(self.categorical_cols) > 0:
            categorical_features = torch.tensor([row[col] for col in self.categorical_cols], 
                                                dtype=torch.long)
        
        # Get numerical features
        numerical_features = None
        if len(self.numerical_cols) > 0:
            numerical_features = torch.tensor([row[col] for col in self.numerical_cols], 
                                             dtype=torch.float32)
        
        return {
            'image': image,
            'bmi': bmi,
            'categorical_features': categorical_features,
            'numerical_features': numerical_features,
            'image_name': row['name']
        }


def create_transforms(config_dict):
    """Create data transforms"""
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.05),
        transforms.ToTensor(),
        transforms.Normalize(mean=config_dict['mean'], std=config_dict['std'])
    ])
    
    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=config_dict['mean'], std=config_dict['std'])
    ])
    
    return train_transform, val_transform


def create_dataloaders(fold_dir, config, batch_size=32, num_workers=4):
    """Create PyTorch DataLoaders for a specific fold"""
    train_df = pd.read_csv(f'{fold_dir}/train.csv')
    val_df = pd.read_csv(f'{fold_dir}/val.csv')
    
    # Get config dict
    if isinstance(config, dict):
        config_dict = config
    else:
        config_dict = {
            'mean': config.mean,
            'std': config.std
        }
    
    train_transform, val_transform = create_transforms(config_dict)
    
    train_dataset = BMIDataset(train_df, transform=train_transform)
    val_dataset = BMIDataset(val_df, transform=val_transform)
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False
    )
    
    return train_loader, val_loader


# ==================== EXAMPLE USAGE ====================
if __name__ == "__main__":
    print("\n" + "=" * 70)
    print("  LEAKAGE-FREE ILLINOIS DOC PREPROCESSING")
    print("=" * 70 + "\n")
    
    # Set paths for Illinois DOC dataset
    person_csv_path = '/kaggle/input/illinois-doc-labeled-faces-dataset/person.csv'
    image_dirs = [
        '/kaggle/input/illinois-doc-labeled-faces-dataset/front/front',
        '/kaggle/input/illinois-doc-labeled-faces-dataset/inmates/inmates',
        '/kaggle/input/illinois-doc-labeled-faces-dataset/side/side'
    ]
    
    # Run preprocessing
    try:
        output_path, fold_splits = main_preprocessing(
            person_csv_path, 
            image_dirs, 
            output_dir='/kaggle/working/illinois_doc_preprocessed',
            show_plots=False  # Set to True in local environment
        )
        
        # Verify preprocessing
        verify_preprocessing(output_path)
        
        # Test loading first fold
        print("\n=== Testing Fold 1 DataLoader ===")
        config, _ = load_config(output_path)
        train_loader, val_loader = create_dataloaders(
            f'{output_path}/fold_1', 
            config, 
            batch_size=16,
            num_workers=2
        )
        
        # Test one batch
        batch = next(iter(train_loader))
        print(f"\nBatch contents:")
        print(f"  - Images: {batch['image'].shape}")
        print(f"  - BMI: {batch['bmi'].shape}")
        if batch['categorical_features'] is not None:
            print(f"  - Categorical features: {batch['categorical_features'].shape}")
        if batch['numerical_features'] is not None:
            print(f"  - Numerical features: {batch['numerical_features'].shape}")
        
        # Visualize batch
        visualize_batch(train_loader, num_images=8, 
                       save_path='illinois_doc_batch_viz_clean.png')
        
        print("\n" + "=" * 70)
        print("  ALL PREPROCESSING STEPS COMPLETED SUCCESSFULLY!")
        print("=" * 70)
        print(f"\nYou can now use the preprocessed data for training:")
        print(f"  - Data location: {output_path}")
        print(f"  - Number of folds: {config.n_folds}")
        
    except Exception as e:
        print(f"\n✗ Error during preprocessing: {str(e)}")
        import traceback
        traceback.print_exc()


  LEAKAGE-FREE ILLINOIS DOC PREPROCESSING

  ILLINOIS DOC DATASET - LEAKAGE-FREE PREPROCESSING PIPELINE

[STEP 1] Loading data and engineering basic features...
=== Starting Basic Feature Engineering ===
Initial shape: (61110, 22)
[1/3] Engineering demographic features...
[2/3] Cleaning categorical features...
[3/3] Engineering image features...

Removed 395 rows with invalid BMI
Final shape after basic feature engineering: (60715, 28)

=== Basic Feature Engineering Summary ===
Total features created: 28

[STEP 2] Matching images to data...

Matched 60715/60715 records to images (100.0%)

[STEP 3] Validating images...
Fast validating 60715 images...


Quick validation:   0%|          | 0/60715 [00:00<?, ?it/s]


✗ Error during preprocessing: 1288


Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/pandas/core/indexes/base.py", line 3805, in get_loc
    return self._engine.get_loc(casted_key)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "index.pyx", line 167, in pandas._libs.index.IndexEngine.get_loc
  File "index.pyx", line 196, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/hashtable_class_helper.pxi", line 2606, in pandas._libs.hashtable.Int64HashTable.get_item
  File "pandas/_libs/hashtable_class_helper.pxi", line 2630, in pandas._libs.hashtable.Int64HashTable.get_item
KeyError: 1288

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/tmp/ipykernel_19/1602423171.py", line 1128, in <cell line: 0>
    output_path, fold_splits = main_preprocessing(
                               ^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_19/1602423171.py", line 801, in main_preprocessing
    valid_indices = validate_images_fast(df

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTModel
import os
import sys
import time
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR, CosineAnnealingWarmRestarts
from sklearn.preprocessing import LabelEncoder
import cv2
from PIL import Image
import warnings
warnings.filterwarnings('ignore')


# ==================== CONFIGURATION ====================
class TrainingConfig:
    """Training configuration with all hyperparameters"""
    def __init__(self):
        # Paths
        self.preprocessed_dir = '/kaggle/working/illinois_doc_preprocessed'
        self.output_dir = '/kaggle/working/vit_bmi_checkpoints'
        self.log_dir = '/kaggle/working/logs'
        
        # Training
        self.n_folds = 5
        self.epochs = 20
        self.batch_size = 96
        self.accumulation_steps = 1
        
        # Optimizer
        self.lr = 3e-5
        self.vit_lr_multiplier = 0.1  # Lower LR for pretrained ViT
        self.weight_decay = 0.01
        self.max_grad_norm = 1.0
        
        # Scheduler
        self.warmup_pct = 0.1
        self.scheduler_type = 'onecycle'
        
        # Model architecture
        self.embed_dim = 32
        self.fusion_dim = 512
        self.unfreeze_last_n_layers = 4
        self.use_adaptive_pooling = False  # Set to True to use complex pooling
        
        # Loss
        self.loss_type = 'huber'  # 'mse', 'mae', 'huber', 'smooth_l1'
        self.huber_delta = 1.0
        
        # Regularization
        self.dropout = 0.2
        
        # Early stopping
        self.patience = 7
        self.min_delta = 0.0001
        
        # Data loading
        self.num_workers = 8
        self.pin_memory = True
        self.prefetch_factor = 4
        
        # Mixed precision
        self.use_amp = True
        
        # Validation
        self.val_every_n_epochs = 1
        
        # Grad-CAM visualization
        self.gradcam_frequency = 999  # Generate Grad-CAM every N epochs
        self.gradcam_samples = 4
        
        # Seed
        self.seed = 42


# ==================== GRAD-CAM FOR VIT (FIXED) ====================
class ViTGradCAM:
    """Improved Grad-CAM for Vision Transformer"""
    def __init__(self, model):
        self.model = model
        self.gradients = None
        self.activations = None
        self.handles = []
        
        # Target the last transformer layer
        self.target_layer = model.vit.encoder.layer[-1].output
        self._register_hooks()
    
    def _register_hooks(self):
        """Register forward and backward hooks"""
        def forward_hook(module, input, output):
            # output is a tuple (hidden_states,) for ViT
            if isinstance(output, tuple):
                self.activations = output[0].detach()
            else:
                self.activations = output.detach()
        
        def backward_hook(module, grad_input, grad_output):
            # grad_output is a tuple
            self.gradients = grad_output[0].detach()
        
        handle_f = self.target_layer.register_forward_hook(forward_hook)
        handle_b = self.target_layer.register_full_backward_hook(backward_hook)
        self.handles = [handle_f, handle_b]
    
    def remove_hooks(self):
        """Remove all hooks"""
        for handle in self.handles:
            handle.remove()
    
    def generate_cam(self, input_image, categorical_features=None, numerical_features=None):
        """Generate Class Activation Map"""
        self.model.eval()
        
        # Get device from model parameters
        device = next(self.model.parameters()).device
        input_image = input_image.to(device)
        
        if categorical_features is not None:
            categorical_features = categorical_features.to(device)
        if numerical_features is not None:
            numerical_features = numerical_features.to(device)
        
        # Enable gradient computation
        with torch.enable_grad():
            input_image = input_image.requires_grad_(True)
            
            # Forward pass
            output = self.model(input_image, categorical_features, numerical_features)
            
            # Backward pass
            self.model.zero_grad()
            output.sum().backward()
        
        # Get gradients and activations from hooks
        gradients = self.gradients  # [B, num_patches+1, hidden_dim]
        activations = self.activations  # [B, num_patches+1, hidden_dim]
        
        # Remove CLS token
        gradients = gradients[:, 1:, :]  # [B, num_patches, hidden_dim]
        activations = activations[:, 1:, :]  # [B, num_patches, hidden_dim]
        
        # Global average pooling of gradients
        weights = gradients.mean(dim=2, keepdim=True)  # [B, num_patches, 1]
        
        # Weight the activations
        cam = (weights * activations).sum(dim=2)  # [B, num_patches]
        
        # Apply ReLU
        cam = F.relu(cam)
        
        # Normalize
        cam = cam - cam.min()
        if cam.max() > 0:
            cam = cam / cam.max()
        
        # Reshape to 2D grid
        batch_size = cam.shape[0]
        num_patches = cam.shape[1]
        grid_size = int(np.sqrt(num_patches))
        
        if grid_size * grid_size == num_patches:
            cam = cam.reshape(batch_size, grid_size, grid_size)
        else:
            # Fallback for non-square patches
            cam = cam.reshape(batch_size, 14, 14)
        
        # Resize to image size
        cam = cam.unsqueeze(1)  # [B, 1, H, W]
        cam = F.interpolate(cam, size=(224, 224), mode='bilinear', align_corners=False)
        cam = cam.squeeze(1)  # [B, H, W]
        
        return cam[0].cpu().numpy()  # Return first image
    
    def visualize_cam(self, input_image, cam):
        """Create visualization overlaying CAM on original image"""
        # Denormalize image
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        
        img = input_image.cpu().squeeze() * std + mean
        img = torch.clamp(img, 0, 1)
        img = img.permute(1, 2, 0).numpy()
        img = (img * 255).astype(np.uint8)
        
        # Create heatmap
        heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        
        # Overlay
        result = heatmap * 0.4 + img * 0.6
        result = np.clip(result, 0, 255).astype(np.uint8)
        
        return img, heatmap, result


# ==================== ADAPTIVE POOLING (SIMPLIFIED) ====================
class SimplifiedAdaptivePooling(nn.Module):
    """Simplified adaptive pooling with attention weights"""
    def __init__(self, hidden_dim=768):
        super(SimplifiedAdaptivePooling, self).__init__()
        
        # Simple attention scorer
        self.attention_scorer = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.Tanh(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 2, 1)
        )
    
    def forward(self, patch_embeddings):
        """
        Args:
            patch_embeddings: [batch_size, num_patches, hidden_dim]
        Returns:
            pooled_output: [batch_size, hidden_dim]
            attention_weights: [batch_size, num_patches]
        """
        # Compute attention scores
        attention_scores = self.attention_scorer(patch_embeddings).squeeze(-1)  # [B, num_patches]
        attention_weights = F.softmax(attention_scores, dim=-1)  # [B, num_patches]
        
        # Weighted pooling
        pooled_output = torch.sum(patch_embeddings * attention_weights.unsqueeze(-1), dim=1)
        
        return pooled_output, attention_weights


# ==================== FEATURE FUSION (SIMPLIFIED) ====================
class SimplifiedFeatureFusion(nn.Module):
    """Simplified feature fusion for multi-modal inputs"""
    def __init__(self, visual_dim, categorical_dim, numerical_dim, fusion_dim=512):
        super(SimplifiedFeatureFusion, self).__init__()
        
        self.visual_dim = visual_dim
        self.categorical_dim = categorical_dim
        self.numerical_dim = numerical_dim
        
        # Calculate total input dimension
        total_dim = visual_dim
        if categorical_dim > 0:
            total_dim += categorical_dim
        if numerical_dim > 0:
            total_dim += numerical_dim
        
        # Simple fusion network
        self.fusion = nn.Sequential(
            nn.Linear(total_dim, fusion_dim * 2),
            nn.LayerNorm(fusion_dim * 2),
            nn.GELU(),
            nn.Dropout(0.15),
            nn.Linear(fusion_dim * 2, fusion_dim),
            nn.LayerNorm(fusion_dim),
            nn.GELU(),
            nn.Dropout(0.15)
        )
    
    def forward(self, visual_features, categorical_features=None, numerical_features=None):
        """
        Args:
            visual_features: [batch_size, visual_dim]
            categorical_features: [batch_size, categorical_dim] or None
            numerical_features: [batch_size, numerical_dim] or None
        """
        features_list = [visual_features]
        
        if categorical_features is not None and self.categorical_dim > 0:
            features_list.append(categorical_features)
        
        if numerical_features is not None and self.numerical_dim > 0:
            features_list.append(numerical_features)
        
        # Concatenate all features
        fused_features = torch.cat(features_list, dim=-1)
        
        # Apply fusion network
        output = self.fusion(fused_features)
        
        return output


# ==================== REGRESSION HEAD ====================
class RegressionHead(nn.Module):
    """Enhanced regression head with residual connections"""
    def __init__(self, input_dim=512, hidden_dims=[256, 128], dropout=0.2):
        super(RegressionHead, self).__init__()
        
        layers = []
        prev_dim = input_dim
        
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout)
            ])
            prev_dim = hidden_dim
        
        self.layers = nn.Sequential(*layers)
        
        # Final prediction layer
        self.output = nn.Linear(prev_dim, 1)
        
        # Residual connection if dimensions allow
        self.residual = nn.Linear(input_dim, hidden_dims[-1]) if input_dim != hidden_dims[-1] else nn.Identity()
    
    def forward(self, x):
        residual = self.residual(x)
        out = self.layers(x)
        
        # Add residual before final prediction
        if out.shape[-1] == residual.shape[-1]:
            out = out + residual
        
        return self.output(out)


# ==================== MAIN MODEL (FIXED) ====================
class ViTForBMI(nn.Module):
    """Vision Transformer for BMI Prediction with Multi-Modal Features"""
    
    def __init__(self, num_categorical_features=0, categorical_vocab_sizes=None, 
                 num_numerical_features=0, embed_dim=32, fusion_dim=512,
                 freeze_backbone=True, unfreeze_last_n_layers=4,
                 use_adaptive_pooling=False, dropout=0.2):
        super(ViTForBMI, self).__init__()
        
        # Load pretrained ViT
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        
        # Freeze backbone if specified
        if freeze_backbone:
            for param in self.vit.parameters():
                param.requires_grad = False
            
            # Unfreeze last N transformer layers
            for layer in self.vit.encoder.layer[-unfreeze_last_n_layers:]:
                for param in layer.parameters():
                    param.requires_grad = True
        
        # Pooling strategy
        self.use_adaptive_pooling = use_adaptive_pooling
        if use_adaptive_pooling:
            self.adaptive_pooling = SimplifiedAdaptivePooling(hidden_dim=768)
        
        # Categorical embeddings (FIXED)
        self.categorical_embeddings = None
        categorical_dim = 0
        if num_categorical_features > 0 and categorical_vocab_sizes:
            self.categorical_embeddings = nn.ModuleList([
                nn.Embedding(vocab_size + 1, embed_dim, padding_idx=0)  # +1 for padding
                for vocab_size in categorical_vocab_sizes
            ])
            categorical_dim = num_categorical_features * embed_dim
        
        # Numerical feature projection
        self.numerical_proj = None
        numerical_dim = 0
        if num_numerical_features > 0:
            numerical_dim = 128
            self.numerical_proj = nn.Sequential(
                nn.Linear(num_numerical_features, 256),
                nn.LayerNorm(256),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(256, numerical_dim),
                nn.LayerNorm(numerical_dim)
            )
        
        # Feature fusion (simplified)
        self.feature_fusion = SimplifiedFeatureFusion(
            visual_dim=768,
            categorical_dim=categorical_dim,
            numerical_dim=numerical_dim,
            fusion_dim=fusion_dim
        )
        
        # Regression head
        self.regression_head = RegressionHead(
            input_dim=fusion_dim,
            hidden_dims=[256, 128],
            dropout=dropout
        )
        
        # Initialize weights
        self._initialize_weights()
        
        print(f"\n{'='*70}")
        print("MODEL ARCHITECTURE - ViT for BMI Prediction")
        print(f"{'='*70}")
        print(f"Visual Feature Dim:      768 (ViT-Base)")
        print(f"Pooling Strategy:        {'Adaptive' if use_adaptive_pooling else 'CLS Token'}")
        print(f"Categorical Features:    {num_categorical_features} -> {categorical_dim}")
        print(f"Numerical Features:      {num_numerical_features} -> {numerical_dim}")
        print(f"Fusion Dim:              {fusion_dim}")
        print(f"Regression Head:         {fusion_dim} -> 256 -> 128 -> 1")
        print(f"Unfrozen ViT Layers:     Last {unfreeze_last_n_layers}")
        print(f"Dropout:                 {dropout}")
        print(f"{'='*70}\n")
    
    def _initialize_weights(self):
        """Initialize custom layer weights"""
        modules = [self.feature_fusion, self.regression_head]
        if self.use_adaptive_pooling:
            modules.append(self.adaptive_pooling)
        if self.numerical_proj is not None:
            modules.append(self.numerical_proj)
        
        for module in modules:
            for m in module.modules():
                if isinstance(m, nn.Linear):
                    nn.init.xavier_uniform_(m.weight)
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.LayerNorm):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
    
    def forward(self, pixel_values, categorical_features=None, numerical_features=None):
        """
        Args:
            pixel_values: [batch_size, 3, 224, 224]
            categorical_features: [batch_size, num_categorical] or None
            numerical_features: [batch_size, num_numerical] or None
        Returns:
            bmi_prediction: [batch_size, 1]
        """
        # Extract ViT features
        outputs = self.vit(pixel_values=pixel_values)
        last_hidden_state = outputs.last_hidden_state  # [B, num_patches+1, 768]
        
        # Get visual features
        if self.use_adaptive_pooling:
            # Remove CLS token, use adaptive pooling on patch tokens
            patch_tokens = last_hidden_state[:, 1:, :]  # [B, num_patches, 768]
            visual_features, _ = self.adaptive_pooling(patch_tokens)  # [B, 768]
        else:
            # Use CLS token (simpler and often works better)
            visual_features = last_hidden_state[:, 0, :]  # [B, 768]
        
        # Process categorical features (FIXED)
        cat_features = None
        if categorical_features is not None and self.categorical_embeddings is not None:
            cat_embeds = []
            for i, embedding_layer in enumerate(self.categorical_embeddings):
                feat = categorical_features[:, i]
                
                # Handle negative values (unknown) -> map to 0 (padding)
                feat = torch.where(feat < 0, torch.zeros_like(feat), feat + 1)
                
                # Clamp to valid range
                max_idx = embedding_layer.num_embeddings - 1
                feat = torch.clamp(feat, min=0, max=max_idx)
                
                emb = embedding_layer(feat)
                cat_embeds.append(emb)
            
            cat_features = torch.cat(cat_embeds, dim=-1)  # [B, categorical_dim]
        
        # Process numerical features
        num_features = None
        if numerical_features is not None and self.numerical_proj is not None:
            # Handle NaN and inf values
            numerical_features = torch.nan_to_num(numerical_features, nan=0.0, posinf=0.0, neginf=0.0)
            num_features = self.numerical_proj(numerical_features)  # [B, numerical_dim]
        
        # Feature fusion
        fused_features = self.feature_fusion(visual_features, cat_features, num_features)
        
        # Regression
        bmi_prediction = self.regression_head(fused_features)
        
        return bmi_prediction
    
    def visualize_predictions_with_gradcam(self, batch, gradcam, save_dir, num_samples=8):
        """Generate Grad-CAM visualizations (FIXED)"""
        self.eval()
        os.makedirs(save_dir, exist_ok=True)
        
        # Get device from model parameters
        device = next(self.parameters()).device
        
        images = batch['image'][:num_samples]
        targets = batch['bmi'][:num_samples]
        image_names = batch['image_name'][:num_samples]
        
        categorical_features = batch.get('categorical_features')
        if categorical_features is not None:
            categorical_features = categorical_features[:num_samples]
        
        numerical_features = batch.get('numerical_features')
        if numerical_features is not None:
            numerical_features = numerical_features[:num_samples]
        
        # Get predictions
        with torch.no_grad():
            predictions = self(
                images.to(device), 
                categorical_features.to(device) if categorical_features is not None else None,
                numerical_features.to(device) if numerical_features is not None else None
            )
        
        # Create figure
        fig, axes = plt.subplots(num_samples, 4, figsize=(16, 4*num_samples))
        if num_samples == 1:
            axes = axes.reshape(1, -1)
        
        for idx in range(num_samples):
            img_tensor = images[idx:idx+1]
            cat_feat = categorical_features[idx:idx+1] if categorical_features is not None else None
            num_feat = numerical_features[idx:idx+1] if numerical_features is not None else None
            
            # Generate Grad-CAM
            try:
                cam = gradcam.generate_cam(img_tensor, cat_feat, num_feat)
                original, heatmap, overlay = gradcam.visualize_cam(img_tensor, cam)
            except Exception as e:
                print(f"Warning: Grad-CAM failed for sample {idx}: {e}")
                # Use dummy images
                original = np.zeros((224, 224, 3), dtype=np.uint8)
                heatmap = np.zeros((224, 224, 3), dtype=np.uint8)
                overlay = np.zeros((224, 224, 3), dtype=np.uint8)
            
            pred_bmi = predictions[idx].item()
            true_bmi = targets[idx].item()
            error = abs(pred_bmi - true_bmi)
            
            # Original image
            axes[idx, 0].imshow(original)
            axes[idx, 0].set_title(f'Original\n{image_names[idx][:20]}', fontsize=10)
            axes[idx, 0].axis('off')
            
            # Heatmap
            axes[idx, 1].imshow(heatmap)
            axes[idx, 1].set_title(f'Attention Heatmap', fontsize=10)
            axes[idx, 1].axis('off')
            
            # Overlay
            axes[idx, 2].imshow(overlay)
            axes[idx, 2].set_title(f'Overlay', fontsize=10)
            axes[idx, 2].axis('off')
            
            # Prediction comparison
            axes[idx, 3].bar(['True', 'Pred'], [true_bmi, pred_bmi], 
                           color=['steelblue', 'coral'], alpha=0.7, edgecolor='black')
            axes[idx, 3].set_ylabel('BMI', fontsize=10)
            axes[idx, 3].set_title(f'True: {true_bmi:.2f} | Pred: {pred_bmi:.2f}\nError: {error:.2f}', 
                                  fontsize=10, fontweight='bold')
            axes[idx, 3].grid(True, alpha=0.3, axis='y')
            axes[idx, 3].set_ylim(15, 50)
        
        plt.tight_layout()
        plt.savefig(f'{save_dir}/gradcam_visualization.png', dpi=150, bbox_inches='tight')
        plt.close()
        
        print(f"  ✓ Grad-CAM saved to {save_dir}/gradcam_visualization.png")


def count_parameters(model):
    """Count trainable and total parameters"""
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen_params = total_params - trainable_params
    
    print(f"\nModel Parameter Summary:")
    print(f"{'='*50}")
    print(f"Total Parameters:      {total_params:,}")
    print(f"Trainable Parameters:  {trainable_params:,} ({trainable_params/total_params*100:.2f}%)")
    print(f"Frozen Parameters:     {frozen_params:,} ({frozen_params/total_params*100:.2f}%)")
    print(f"{'='*50}\n")
    
    return total_params, trainable_params


# ==================== DATASET (FIXED) ====================
class BMIDataset(torch.utils.data.Dataset):
    """Dataset for BMI prediction with proper feature detection"""
    def __init__(self, dataframe, transform=None):
        self.df = dataframe.reset_index(drop=True)
        self.transform = transform
        
        # Identify categorical features (encoded columns)
        self.categorical_cols = [col for col in self.df.columns if col.endswith('_encoded')]
        
        # Identify numerical features (FIXED to match preprocessing)
        self.numerical_cols = []
        for col in self.df.columns:
            # Skip non-feature columns
            if col in ['name', 'image_path', 'bmi', 'image_stem', 'has_image']:
                continue
            
            # Check if it's a numerical feature we want
            if self.df[col].dtype in [np.float32, np.float64, np.int32, np.int64]:
                # Include features that end with specific suffixes
                if any(col.endswith(suffix) for suffix in [
                    '_zscore', '_percentile', '_squared', '_deviation', 
                    '_above_mean', '_product', '_interaction'
                ]):
                    self.numerical_cols.append(col)
                # Include specific named features
                elif col in ['bsa', 'ponderal_index', 'height_m', 'weight_kg', 
                           'age', 'age_decade', 'height', 'weight']:
                    self.numerical_cols.append(col)
        
        print(f"  Dataset features: {len(self.categorical_cols)} categorical, {len(self.numerical_cols)} numerical")
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        
        # Load image
        img_path = row['image_path']
        try:
            with Image.open(img_path) as image:
                image = image.convert('RGB')
                if self.transform:
                    image = self.transform(image)
        except Exception as e:
            # Return zero tensor if image fails
            print(f"Warning: Failed to load image {img_path}: {e}")
            image = torch.zeros(3, 224, 224)
        
        # Get BMI target
        bmi = torch.tensor(row['bmi'], dtype=torch.float32)
        
        # Get categorical features
        categorical_features = None
        if len(self.categorical_cols) > 0:
            categorical_features = torch.tensor(
                [row[col] for col in self.categorical_cols], 
                dtype=torch.long
            )
        
        # Get numerical features
        numerical_features = None
        if len(self.numerical_cols) > 0:
            numerical_features = torch.tensor(
                [row[col] for col in self.numerical_cols], 
                dtype=torch.float32
            )
        
        return {
            'image': image,
            'bmi': bmi,
            'categorical_features': categorical_features,
            'numerical_features': numerical_features,
            'image_name': row['name']
        }


# ==================== DATA LOADING UTILITIES ====================
def get_transforms(config_dict):
    """Create data transforms with conservative augmentation for faces"""
    from torchvision import transforms
    
    # Conservative augmentation for face images
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.3),  # Reduced from 0.5
        transforms.RandomRotation(degrees=5),     # Reduced from 10
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.05),  # Reduced
        transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),  # Small shifts
        transforms.ToTensor(),
        transforms.Normalize(mean=config_dict['mean'], std=config_dict['std'])
    ])
    
    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=config_dict['mean'], std=config_dict['std'])
    ])
    
    return train_transform, val_transform


def load_config_and_encoders(preprocessed_dir):
    """Load preprocessing config and feature encoders"""
    with open(f'{preprocessed_dir}/config.json', 'r') as f:
        config_dict = json.load(f)
    
    with open(f'{preprocessed_dir}/feature_encoders.json', 'r') as f:
        encoder_dict = json.load(f)
    
    # Reconstruct encoders
    feature_encoders = {}
    for feature, info in encoder_dict.items():
        le = LabelEncoder()
        le.classes_ = np.array(info['classes'])
        feature_encoders[feature] = le
    
    return config_dict, feature_encoders


def get_categorical_vocab_sizes(feature_encoders):
    """Get vocabulary sizes for categorical features (FIXED)"""
    # Feature encoders are stored with original names ('race', 'sex', etc.)
    # Return in sorted order for consistency
    vocab_sizes = []
    for feat_name in sorted(feature_encoders.keys()):
        vocab_sizes.append(len(feature_encoders[feat_name].classes_))
    return vocab_sizes


def validate_fold_features(train_df, val_df):
    """Validate that train and val have same features"""
    train_features = set(train_df.columns)
    val_features = set(val_df.columns)
    
    if train_features != val_features:
        missing_in_val = train_features - val_features
        missing_in_train = val_features - train_features
        
        error_msg = "Feature mismatch between train and validation!\n"
        if missing_in_val:
            error_msg += f"Missing in val: {missing_in_val}\n"
        if missing_in_train:
            error_msg += f"Missing in train: {missing_in_train}\n"
        
        raise ValueError(error_msg)
    
    print(f"  ✓ Feature validation passed: {len(train_features)} features")


def get_dataloaders(fold, config, preprocess_config):
    """Create train and validation dataloaders with validation"""
    train_transform, val_transform = get_transforms(preprocess_config)
    
    # Load data
    train_df = pd.read_csv(f'{config.preprocessed_dir}/fold_{fold}/train.csv')
    val_df = pd.read_csv(f'{config.preprocessed_dir}/fold_{fold}/val.csv')
    
    print(f"\nFold {fold} dataset sizes:")
    print(f"  Train: {len(train_df):,} samples")
    print(f"  Val:   {len(val_df):,} samples")
    
    # Validate features match
    validate_fold_features(train_df, val_df)
    
    # Create datasets
    train_dataset = BMIDataset(train_df, transform=train_transform)
    val_dataset = BMIDataset(val_df, transform=val_transform)
    
    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory,
        prefetch_factor=config.prefetch_factor,
        persistent_workers=True if config.num_workers > 0 else False,
        drop_last=True  # Drop last incomplete batch for stable training
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size * 2,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory,
        prefetch_factor=config.prefetch_factor,
        persistent_workers=True if config.num_workers > 0 else False
    )
    
    return train_loader, val_loader


# ==================== LOSS FUNCTIONS ====================
def get_loss_function(config):
    """Get loss function based on config"""
    if config.loss_type == 'mse':
        return nn.MSELoss()
    elif config.loss_type == 'mae':
        return nn.L1Loss()
    elif config.loss_type == 'huber':
        return nn.HuberLoss(delta=config.huber_delta)
    elif config.loss_type == 'smooth_l1':
        return nn.SmoothL1Loss()
    else:
        raise ValueError(f"Unknown loss type: {config.loss_type}")


# ==================== METRICS ====================
def calculate_metrics(predictions, targets):
    """Calculate regression metrics"""
    predictions = np.array(predictions).flatten()
    targets = np.array(targets).flatten()
    
    mae = np.mean(np.abs(predictions - targets))
    mse = np.mean((predictions - targets) ** 2)
    rmse = np.sqrt(mse)
    
    # R² score
    ss_res = np.sum((targets - predictions) ** 2)
    ss_tot = np.sum((targets - np.mean(targets)) ** 2)
    r2 = 1 - (ss_res / ss_tot) if ss_tot != 0 else 0
    
    return {
        'mae': mae,
        'mse': mse,
        'rmse': rmse,
        'r2': r2
    }


# ==================== UTILITIES ====================
def set_seed(seed=42):
    """Set random seeds for reproducibility"""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def get_device():
    """Get available device"""
    if torch.cuda.is_available():
        device = torch.device('cuda')
        print(f"✓ Using GPU: {torch.cuda.get_device_name(0)}")
        print(f"  Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    else:
        device = torch.device('cpu')
        print("⚠ Using CPU")
    return device


# ==================== EARLY STOPPING ====================
class EarlyStopping:
    """Early stopping handler"""
    def __init__(self, patience=7, min_delta=0.0001, mode='min'):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.best_epoch = 0
    
    def __call__(self, score, epoch):
        if self.best_score is None:
            self.best_score = score
            self.best_epoch = epoch
            return False
        
        if self.mode == 'min':
            improved = score < (self.best_score - self.min_delta)
        else:
            improved = score > (self.best_score + self.min_delta)
        
        if improved:
            self.best_score = score
            self.best_epoch = epoch
            self.counter = 0
            return False
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return False


# ==================== TRAINING ====================
def train_epoch(model, train_loader, criterion, optimizer, scheduler, scaler, 
                device, config, epoch):
    """Train for one epoch (FIXED gradient accumulation)"""
    model.train()
    
    total_loss = 0
    num_batches = len(train_loader)
    
    pbar = tqdm(train_loader, desc=f'Epoch {epoch}', leave=False)
    
    optimizer.zero_grad()
    
    for batch_idx, batch in enumerate(pbar):
        # Move to device
        images = batch['image'].to(device, non_blocking=True)
        targets = batch['bmi'].to(device, non_blocking=True)
        
        categorical_features = batch.get('categorical_features')
        if categorical_features is not None:
            categorical_features = categorical_features.to(device, non_blocking=True)
        
        numerical_features = batch.get('numerical_features')
        if numerical_features is not None:
            numerical_features = numerical_features.to(device, non_blocking=True)
        
        # Forward pass with mixed precision
        with autocast(enabled=config.use_amp):
            predictions = model(images, categorical_features, numerical_features)
            loss = criterion(predictions.squeeze(), targets)
            
            # Scale loss for gradient accumulation
            loss = loss / config.accumulation_steps
        
        # Backward pass
        scaler.scale(loss).backward()
        
        # Update weights every accumulation_steps OR at the last batch (FIXED)
        is_accumulation_step = (batch_idx + 1) % config.accumulation_steps == 0
        is_last_batch = (batch_idx + 1) == num_batches
        
        if is_accumulation_step or is_last_batch:
            # Gradient clipping
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
            
            # Optimizer step
            scaler.step(optimizer)
            scaler.update()
            
            # Scheduler step
            if scheduler is not None:
                scheduler.step()
            
            optimizer.zero_grad()
        
        # Track loss
        total_loss += loss.item() * config.accumulation_steps
        
        # Update progress bar
        pbar.set_postfix({'loss': f'{total_loss / (batch_idx + 1):.4f}'})
    
    return total_loss / num_batches


@torch.no_grad()
def validate(model, val_loader, criterion, device, config):
    """Validate model"""
    model.eval()
    
    total_loss = 0
    all_predictions = []
    all_targets = []
    
    pbar = tqdm(val_loader, desc='Validating', leave=False)
    
    for batch in pbar:
        images = batch['image'].to(device, non_blocking=True)
        targets = batch['bmi'].to(device, non_blocking=True)
        
        categorical_features = batch.get('categorical_features')
        if categorical_features is not None:
            categorical_features = categorical_features.to(device, non_blocking=True)
        
        numerical_features = batch.get('numerical_features')
        if numerical_features is not None:
            numerical_features = numerical_features.to(device, non_blocking=True)
        
        # Forward pass
        with autocast(enabled=config.use_amp):
            predictions = model(images, categorical_features, numerical_features)
            loss = criterion(predictions.squeeze(), targets)
        
        total_loss += loss.item()
        
        # Collect predictions and targets
        all_predictions.extend(predictions.cpu().numpy().flatten())
        all_targets.extend(targets.cpu().numpy().flatten())
    
    avg_loss = total_loss / len(val_loader)
    metrics = calculate_metrics(all_predictions, all_targets)
    metrics['loss'] = avg_loss
    
    return metrics


# ==================== CHECKPOINTING ====================
def save_checkpoint(model, optimizer, scheduler, epoch, metrics, fold, config):
    """Save model checkpoint"""
    os.makedirs(config.output_dir, exist_ok=True)
    
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'metrics': metrics,
        'config': config.__dict__
    }
    
    path = f'{config.output_dir}/fold_{fold}_best.pt'
    torch.save(checkpoint, path)
    
    return path


def load_checkpoint(path, model, optimizer=None, scheduler=None):
    """Load model checkpoint"""
    checkpoint = torch.load(path, map_location='cpu', weights_only=False)
    
    model.load_state_dict(checkpoint['model_state_dict'])
    
    if optimizer and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    if scheduler and 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict']:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    
    return checkpoint['epoch'], checkpoint['metrics']


# ==================== TRAINING LOOP ====================
def train_fold(fold, config, device):
    """Train model for one fold"""
    print(f"\n{'='*70}")
    print(f"TRAINING FOLD {fold}/{config.n_folds}")
    print(f"{'='*70}\n")
    
    # Load preprocessing config and encoders
    preprocess_config, feature_encoders = load_config_and_encoders(config.preprocessed_dir)
    
    # Get dataloaders
    train_loader, val_loader = get_dataloaders(fold, config, preprocess_config)
    
    # Get feature dimensions from first batch
    sample_batch = next(iter(train_loader))
    num_categorical = sample_batch['categorical_features'].shape[1] if sample_batch['categorical_features'] is not None else 0
    num_numerical = sample_batch['numerical_features'].shape[1] if sample_batch['numerical_features'] is not None else 0
    categorical_vocab_sizes = get_categorical_vocab_sizes(feature_encoders)
    
    print(f"\nFeature configuration:")
    print(f"  Categorical features: {num_categorical}")
    print(f"  Numerical features:   {num_numerical}")
    print(f"  Vocabulary sizes:     {categorical_vocab_sizes}\n")
    
    # Create model
    model = ViTForBMI(
        num_categorical_features=num_categorical,
        categorical_vocab_sizes=categorical_vocab_sizes,
        num_numerical_features=num_numerical,
        embed_dim=config.embed_dim,
        fusion_dim=config.fusion_dim,
        unfreeze_last_n_layers=config.unfreeze_last_n_layers,
        use_adaptive_pooling=config.use_adaptive_pooling,
        dropout=config.dropout
    ).to(device)
    
    # Count parameters
    count_parameters(model)
    
    # Initialize Grad-CAM
    gradcam = ViTGradCAM(model)
    print("✓ Grad-CAM initialized\n")
    
    # Create visualization directory
    viz_dir = f'{config.output_dir}/gradcam_fold_{fold}'
    os.makedirs(viz_dir, exist_ok=True)
    
    # Get loss function
    criterion = get_loss_function(config)
    print(f"Loss function: {config.loss_type}\n")
    
    # Create optimizer with different learning rates for ViT and new layers
    vit_params = [p for n, p in model.named_parameters() if 'vit' in n and p.requires_grad]
    other_params = [p for n, p in model.named_parameters() if 'vit' not in n and p.requires_grad]
    
    optimizer = AdamW([
        {'params': vit_params, 'lr': config.lr * config.vit_lr_multiplier},
        {'params': other_params, 'lr': config.lr}
    ], weight_decay=config.weight_decay)
    
    # Calculate total training steps (FIXED)
    steps_per_epoch = len(train_loader) // config.accumulation_steps
    if len(train_loader) % config.accumulation_steps != 0:
        steps_per_epoch += 1  # Account for partial accumulation
    
    total_steps = steps_per_epoch * config.epochs
    warmup_steps = int(total_steps * config.warmup_pct)
    
    print(f"Training schedule:")
    print(f"  Steps per epoch: {steps_per_epoch}")
    print(f"  Total steps: {total_steps}")
    print(f"  Warmup steps: {warmup_steps}\n")
    
    # Create scheduler
    if config.scheduler_type == 'onecycle':
        scheduler = OneCycleLR(
            optimizer,
            max_lr=[config.lr * config.vit_lr_multiplier, config.lr],
            total_steps=total_steps,
            pct_start=config.warmup_pct,
            anneal_strategy='cos',
            div_factor=25,
            final_div_factor=10000
        )
    elif config.scheduler_type == 'cosine':
        scheduler = CosineAnnealingWarmRestarts(
            optimizer,
            T_0=steps_per_epoch * 5,
            T_mult=1
        )
    else:
        scheduler = None
    
    # Mixed precision scaler
    scaler = GradScaler(enabled=config.use_amp)
    
    # Early stopping
    early_stopping = EarlyStopping(patience=config.patience, min_delta=config.min_delta)
    
    # Training history
    history = {
        'train_losses': [],
        'val_losses': [],
        'val_maes': [],
        'val_mses': [],
        'val_rmses': [],
        'val_r2s': []
    }
    
    best_mae = float('inf')
    best_metrics = None
    
    print(f"Starting training...")
    print(f"{'='*70}\n")
    
    start_time = time.time()
    
    for epoch in range(1, config.epochs + 1):
        epoch_start = time.time()
        
        # Train
        train_loss = train_epoch(model, train_loader, criterion, optimizer, 
                                scheduler, scaler, device, config, epoch)
        history['train_losses'].append(train_loss)
        
        # Validate
        if epoch % config.val_every_n_epochs == 0 or epoch == config.epochs:
            val_metrics = validate(model, val_loader, criterion, device, config)
            
            history['val_losses'].append(val_metrics['loss'])
            history['val_maes'].append(val_metrics['mae'])
            history['val_mses'].append(val_metrics['mse'])
            history['val_rmses'].append(val_metrics['rmse'])
            history['val_r2s'].append(val_metrics['r2'])
            
            epoch_time = time.time() - epoch_start
            
            print(f"Epoch {epoch:02d}/{config.epochs} | "
                  f"Time: {epoch_time:.1f}s | "
                  f"Train Loss: {train_loss:.4f} | "
                  f"Val Loss: {val_metrics['loss']:.4f} | "
                  f"MAE: {val_metrics['mae']:.4f} | "
                  f"RMSE: {val_metrics['rmse']:.4f} | "
                  f"R²: {val_metrics['r2']:.4f}")
            
            # Generate Grad-CAM visualizations periodically
            if epoch % config.gradcam_frequency == 0 or epoch == 1:
                print("  Generating Grad-CAM visualizations...")
                try:
                    val_batch = next(iter(val_loader))
                    model.visualize_predictions_with_gradcam(
                        val_batch, gradcam, 
                        f'{viz_dir}/epoch_{epoch}',
                        num_samples=config.gradcam_samples
                    )
                except Exception as e:
                    print(f"  ⚠ Grad-CAM generation failed: {str(e)}")
            
            # Save best model
            if val_metrics['mae'] < best_mae:
                best_mae = val_metrics['mae']
                best_metrics = val_metrics.copy()
                checkpoint_path = save_checkpoint(model, optimizer, scheduler, 
                                                 epoch, val_metrics, fold, config)
                print(f"  ✓ New best model saved (MAE: {best_mae:.4f})")
            
            # Early stopping check
            early_stopping(val_metrics['mae'], epoch)
            if early_stopping.early_stop:
                print(f"\n  Early stopping triggered at epoch {epoch}")
                print(f"  Best epoch was {early_stopping.best_epoch}")
                break
        else:
            epoch_time = time.time() - epoch_start
            print(f"Epoch {epoch:02d}/{config.epochs} | "
                  f"Time: {epoch_time:.1f}s | "
                  f"Train Loss: {train_loss:.4f}")
    
    total_time = time.time() - start_time
    
    # Load best model for final evaluation
    checkpoint_path = f'{config.output_dir}/fold_{fold}_best.pt'
    if os.path.exists(checkpoint_path):
        _, _ = load_checkpoint(checkpoint_path, model)
        final_metrics = validate(model, val_loader, criterion, device, config)
    else:
        final_metrics = best_metrics
    
    # Generate final Grad-CAM visualizations
    print("\nGenerating final Grad-CAM visualizations...")
    try:
        val_batch = next(iter(val_loader))
        model.visualize_predictions_with_gradcam(
            val_batch, gradcam,
            f'{viz_dir}/final',
            num_samples=16  # More samples for final visualization
        )
        print(f"✓ Final Grad-CAM visualizations saved")
    except Exception as e:
        print(f"⚠ Final Grad-CAM generation failed: {str(e)}")
    
    # Cleanup hooks
    gradcam.remove_hooks()
    
    print(f"\n{'='*70}")
    print(f"FOLD {fold} RESULTS")
    print(f"{'='*70}")
    print(f"Training time:  {total_time/60:.2f} minutes")
    print(f"MAE:            {final_metrics['mae']:.4f}")
    print(f"MSE:            {final_metrics['mse']:.4f}")
    print(f"RMSE:           {final_metrics['rmse']:.4f}")
    print(f"R²:             {final_metrics['r2']:.4f}")
    print(f"{'='*70}\n")
    
    # Add final metrics to history
    history['final_metrics'] = final_metrics
    history['training_time'] = total_time
    
    return history


# ==================== VISUALIZATION ====================
def plot_training_history(all_histories, config):
    """Plot training history for all folds"""
    os.makedirs(config.log_dir, exist_ok=True)
    
    fig = plt.figure(figsize=(20, 12))
    
    # Training loss
    ax1 = plt.subplot(3, 3, 1)
    for fold_idx, history in enumerate(all_histories):
        epochs = range(1, len(history['train_losses']) + 1)
        ax1.plot(epochs, history['train_losses'], marker='o', 
                label=f'Fold {fold_idx+1}', linewidth=2, markersize=4, alpha=0.7)
    ax1.set_xlabel('Epoch', fontweight='bold')
    ax1.set_ylabel('Loss', fontweight='bold')
    ax1.set_title('Training Loss', fontweight='bold', fontsize=12)
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Validation loss
    ax2 = plt.subplot(3, 3, 2)
    for fold_idx, history in enumerate(all_histories):
        val_epochs = np.arange(config.val_every_n_epochs, 
                               len(history['val_losses']) * config.val_every_n_epochs + 1, 
                               config.val_every_n_epochs)
        ax2.plot(val_epochs, history['val_losses'], marker='s', 
                label=f'Fold {fold_idx+1}', linewidth=2, markersize=4, alpha=0.7)
    ax2.set_xlabel('Epoch', fontweight='bold')
    ax2.set_ylabel('Loss', fontweight='bold')
    ax2.set_title('Validation Loss', fontweight='bold', fontsize=12)
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    # MAE
    ax3 = plt.subplot(3, 3, 3)
    for fold_idx, history in enumerate(all_histories):
        val_epochs = np.arange(config.val_every_n_epochs, 
                               len(history['val_maes']) * config.val_every_n_epochs + 1, 
                               config.val_every_n_epochs)
        ax3.plot(val_epochs, history['val_maes'], marker='^', 
                label=f'Fold {fold_idx+1}', linewidth=2, markersize=4, alpha=0.7)
    ax3.set_xlabel('Epoch', fontweight='bold')
    ax3.set_ylabel('MAE', fontweight='bold')
    ax3.set_title('Validation MAE', fontweight='bold', fontsize=12)
    ax3.legend()
    ax3.grid(True, alpha=0.3)
    
    # RMSE
    ax4 = plt.subplot(3, 3, 4)
    for fold_idx, history in enumerate(all_histories):
        val_epochs = np.arange(config.val_every_n_epochs, 
                               len(history['val_rmses']) * config.val_every_n_epochs + 1, 
                               config.val_every_n_epochs)
        ax4.plot(val_epochs, history['val_rmses'], marker='d', 
                label=f'Fold {fold_idx+1}', linewidth=2, markersize=4, alpha=0.7)
    ax4.set_xlabel('Epoch', fontweight='bold')
    ax4.set_ylabel('RMSE', fontweight='bold')
    ax4.set_title('Validation RMSE', fontweight='bold', fontsize=12)
    ax4.legend()
    ax4.grid(True, alpha=0.3)
    
    # R² Score
    ax5 = plt.subplot(3, 3, 5)
    for fold_idx, history in enumerate(all_histories):
        val_epochs = np.arange(config.val_every_n_epochs, 
                               len(history['val_r2s']) * config.val_every_n_epochs + 1, 
                               config.val_every_n_epochs)
        ax5.plot(val_epochs, history['val_r2s'], marker='*', 
                label=f'Fold {fold_idx+1}', linewidth=2, markersize=6, alpha=0.7)
    ax5.set_xlabel('Epoch', fontweight='bold')
    ax5.set_ylabel('R²', fontweight='bold')
    ax5.set_title('Validation R²', fontweight='bold', fontsize=12)
    ax5.legend()
    ax5.grid(True, alpha=0.3)
    
    # Final metrics comparison
    fold_labels = [f'Fold {i+1}' for i in range(len(all_histories))]
    
    # Final MAE
    ax6 = plt.subplot(3, 3, 6)
    maes = [h['final_metrics']['mae'] for h in all_histories]
    bars = ax6.bar(fold_labels, maes, color='steelblue', edgecolor='black', alpha=0.7)
    ax6.axhline(np.mean(maes), color='red', linestyle='--', linewidth=2, 
               label=f'Mean: {np.mean(maes):.4f}')
    for bar, val in zip(bars, maes):
        ax6.text(bar.get_x() + bar.get_width()/2., bar.get_height(), 
                f'{val:.3f}', ha='center', va='bottom', fontweight='bold')
    ax6.set_ylabel('MAE', fontweight='bold')
    ax6.set_title('Final MAE by Fold', fontweight='bold', fontsize=12)
    ax6.legend()
    ax6.grid(True, alpha=0.3, axis='y')
    
    # Final RMSE
    ax7 = plt.subplot(3, 3, 7)
    rmses = [h['final_metrics']['rmse'] for h in all_histories]
    bars = ax7.bar(fold_labels, rmses, color='coral', edgecolor='black', alpha=0.7)
    ax7.axhline(np.mean(rmses), color='red', linestyle='--', linewidth=2, 
               label=f'Mean: {np.mean(rmses):.4f}')
    for bar, val in zip(bars, rmses):
        ax7.text(bar.get_x() + bar.get_width()/2., bar.get_height(), 
                f'{val:.3f}', ha='center', va='bottom', fontweight='bold')
    ax7.set_ylabel('RMSE', fontweight='bold')
    ax7.set_title('Final RMSE by Fold', fontweight='bold', fontsize=12)
    ax7.legend()
    ax7.grid(True, alpha=0.3, axis='y')
    
    # Final R²
    ax8 = plt.subplot(3, 3, 8)
    r2s = [h['final_metrics']['r2'] for h in all_histories]
    bars = ax8.bar(fold_labels, r2s, color='lightgreen', edgecolor='black', alpha=0.7)
    ax8.axhline(np.mean(r2s), color='red', linestyle='--', linewidth=2, 
               label=f'Mean: {np.mean(r2s):.4f}')
    for bar, val in zip(bars, r2s):
        ax8.text(bar.get_x() + bar.get_width()/2., bar.get_height(), 
                f'{val:.3f}', ha='center', va='bottom', fontweight='bold')
    ax8.set_ylabel('R²', fontweight='bold')
    ax8.set_title('Final R² by Fold', fontweight='bold', fontsize=12)
    ax8.legend()
    ax8.grid(True, alpha=0.3, axis='y')
    
    # Training time
    ax9 = plt.subplot(3, 3, 9)
    times = [h['training_time']/60 for h in all_histories]
    bars = ax9.bar(fold_labels, times, color='plum', edgecolor='black', alpha=0.7)
    for bar, val in zip(bars, times):
        ax9.text(bar.get_x() + bar.get_width()/2., bar.get_height(), 
                f'{val:.1f}m', ha='center', va='bottom', fontweight='bold')
    ax9.set_ylabel('Time (minutes)', fontweight='bold')
    ax9.set_title('Training Time by Fold', fontweight='bold', fontsize=12)
    ax9.grid(True, alpha=0.3, axis='y')
    
    plt.tight_layout()
    plot_path = f'{config.log_dir}/training_results.png'
    plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.close()
    
    print(f"✓ Training plots saved to: {plot_path}")


def save_results(all_histories, config):
    """Save training results to JSON"""
    os.makedirs(config.log_dir, exist_ok=True)
    
    def convert_to_native(obj):
        """Convert numpy types to native Python types"""
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, dict):
            return {key: convert_to_native(value) for key, value in obj.items()}
        elif isinstance(obj, list):
            return [convert_to_native(item) for item in obj]
        else:
            return obj
    
    results = {
        'config': config.__dict__,
        'folds': []
    }
    
    for fold_idx, history in enumerate(all_histories):
        fold_results = {
            'fold': fold_idx + 1,
            'final_metrics': convert_to_native(history['final_metrics']),
            'training_time': float(history['training_time']),
            'num_epochs': len(history['train_losses'])
        }
        results['folds'].append(fold_results)
    
    # Compute aggregate statistics
    results['aggregate'] = {
        'mae': {
            'mean': float(np.mean([h['final_metrics']['mae'] for h in all_histories])),
            'std': float(np.std([h['final_metrics']['mae'] for h in all_histories])),
            'min': float(np.min([h['final_metrics']['mae'] for h in all_histories])),
            'max': float(np.max([h['final_metrics']['mae'] for h in all_histories]))
        },
        'rmse': {
            'mean': float(np.mean([h['final_metrics']['rmse'] for h in all_histories])),
            'std': float(np.std([h['final_metrics']['rmse'] for h in all_histories])),
            'min': float(np.min([h['final_metrics']['rmse'] for h in all_histories])),
            'max': float(np.max([h['final_metrics']['rmse'] for h in all_histories]))
        },
        'r2': {
            'mean': float(np.mean([h['final_metrics']['r2'] for h in all_histories])),
            'std': float(np.std([h['final_metrics']['r2'] for h in all_histories])),'min': float(np.min([h['final_metrics']['r2'] for h in all_histories])),
            'max': float(np.max([h['final_metrics']['r2'] for h in all_histories]))
        },
        'total_training_time': float(sum([h['training_time'] for h in all_histories]))
    }
    
    results_path = f'{config.log_dir}/results.json'
    with open(results_path, 'w') as f:
        json.dump(results, f, indent=4)
    
    print(f"✓ Results saved to: {results_path}")


# ==================== MAIN TRAINING PIPELINE ====================
def main():
    """Main training pipeline"""
    print("\n" + "="*70)
    print("  ViT FOR BMI PREDICTION - FIXED VERSION")
    print("="*70 + "\n")
    
    # Configuration
    config = TrainingConfig()
    
    # Set seed
    set_seed(config.seed)
    
    # Get device
    device = get_device()
    
    print(f"\nConfiguration:")
    print(f"  Folds:              {config.n_folds}")
    print(f"  Epochs:             {config.epochs}")
    print(f"  Batch size:         {config.batch_size}")
    print(f"  Learning rate:      {config.lr}")
    print(f"  ViT LR multiplier:  {config.vit_lr_multiplier}")
    print(f"  Loss function:      {config.loss_type}")
    print(f"  Mixed precision:    {config.use_amp}")
    print(f"  Adaptive pooling:   {config.use_adaptive_pooling}")
    print(f"  Output dir:         {config.output_dir}")
    print(f"  Log dir:            {config.log_dir}\n")
    
    # Train all folds
    all_histories = []
    
    for fold in range(1, config.n_folds + 1):
        try:
            history = train_fold(fold, config, device)
            all_histories.append(history)
        except Exception as e:
            print(f"\n✗ Error training fold {fold}: {str(e)}")
            import traceback
            traceback.print_exc()
            continue
    
    if len(all_histories) == 0:
        print("\n✗ No folds completed successfully!")
        return
    
    # Plot results
    print(f"\n{'='*70}")
    print("GENERATING PLOTS AND SAVING RESULTS")
    print(f"{'='*70}\n")
    
    try:
        plot_training_history(all_histories, config)
        save_results(all_histories, config)
    except Exception as e:
        print(f"⚠ Error generating plots/results: {str(e)}")
    
    # Print final summary
    print(f"\n{'='*70}")
    print("FINAL CROSS-VALIDATION RESULTS")
    print(f"{'='*70}")
    print(f"{'Metric':<10} {'Mean':<12} {'Std':<12} {'Min':<12} {'Max':<12}")
    print("-"*70)
    
    maes = [h['final_metrics']['mae'] for h in all_histories]
    rmses = [h['final_metrics']['rmse'] for h in all_histories]
    r2s = [h['final_metrics']['r2'] for h in all_histories]
    
    print(f"{'MAE':<10} {np.mean(maes):<12.4f} {np.std(maes):<12.4f} "
          f"{np.min(maes):<12.4f} {np.max(maes):<12.4f}")
    print(f"{'RMSE':<10} {np.mean(rmses):<12.4f} {np.std(rmses):<12.4f} "
          f"{np.min(rmses):<12.4f} {np.max(rmses):<12.4f}")
    print(f"{'R²':<10} {np.mean(r2s):<12.4f} {np.std(r2s):<12.4f} "
          f"{np.min(r2s):<12.4f} {np.max(r2s):<12.4f}")
    
    total_time = sum([h['training_time'] for h in all_histories])
    print(f"\nTotal training time: {total_time/60:.2f} minutes ({total_time/3600:.2f} hours)")
    print(f"Completed folds: {len(all_histories)}/{config.n_folds}")
    print(f"{'='*70}\n")
    
    # Performance interpretation
    mean_mae = np.mean(maes)
    mean_r2 = np.mean(r2s)
    
    print("Performance Interpretation:")
    print("-" * 70)
    if mean_mae < 2.0:
        print(f"✓ Excellent MAE: {mean_mae:.2f} BMI units")
    elif mean_mae < 3.0:
        print(f"✓ Good MAE: {mean_mae:.2f} BMI units")
    elif mean_mae < 4.0:
        print(f"○ Acceptable MAE: {mean_mae:.2f} BMI units")
    else:
        print(f"⚠ High MAE: {mean_mae:.2f} BMI units - consider model improvements")
    
    if mean_r2 > 0.90:
        print(f"✓ Excellent R²: {mean_r2:.4f}")
    elif mean_r2 > 0.80:
        print(f"✓ Good R²: {mean_r2:.4f}")
    elif mean_r2 > 0.70:
        print(f"○ Acceptable R²: {mean_r2:.4f}")
    else:
        print(f"⚠ Low R²: {mean_r2:.4f} - consider model improvements")
    
    # Check for suspicious results
    if mean_mae < 0.5:
        print(f"\n⚠ WARNING: MAE too low ({mean_mae:.2f}) - possible data leakage!")
    if mean_r2 > 0.98:
        print(f"⚠ WARNING: R² too high ({mean_r2:.4f}) - possible data leakage!")
    
    print(f"{'='*70}\n")
    
    print("✓ Training complete!")
    print(f"✓ Models saved in: {config.output_dir}")
    print(f"✓ Logs saved in: {config.log_dir}")
    print(f"✓ Grad-CAM visualizations saved in: {config.output_dir}/gradcam_fold_*")


# ==================== INFERENCE UTILITIES ====================
def load_trained_model(fold, config, device):
    """Load a trained model from checkpoint"""
    # Load preprocessing config
    preprocess_config, feature_encoders = load_config_and_encoders(config.preprocessed_dir)
    
    # Get feature dimensions from fold data
    train_df = pd.read_csv(f'{config.preprocessed_dir}/fold_{fold}/train.csv')
    
    # Count features
    categorical_cols = [col for col in train_df.columns if col.endswith('_encoded')]
    numerical_cols = []
    for col in train_df.columns:
        if col in ['name', 'image_path', 'bmi', 'image_stem', 'has_image']:
            continue
        if train_df[col].dtype in [np.float32, np.float64, np.int32, np.int64]:
            if any(col.endswith(suffix) for suffix in [
                '_zscore', '_percentile', '_squared', '_deviation', 
                '_above_mean', '_product', '_interaction'
            ]):
                numerical_cols.append(col)
            elif col in ['bsa', 'ponderal_index', 'height_m', 'weight_kg', 
                       'age', 'age_decade', 'height', 'weight']:
                numerical_cols.append(col)
    
    num_categorical = len(categorical_cols)
    num_numerical = len(numerical_cols)
    categorical_vocab_sizes = get_categorical_vocab_sizes(feature_encoders)
    
    # Create model
    model = ViTForBMI(
        num_categorical_features=num_categorical,
        categorical_vocab_sizes=categorical_vocab_sizes,
        num_numerical_features=num_numerical,
        embed_dim=config.embed_dim,
        fusion_dim=config.fusion_dim,
        unfreeze_last_n_layers=config.unfreeze_last_n_layers,
        use_adaptive_pooling=config.use_adaptive_pooling,
        dropout=config.dropout
    ).to(device)
    
    # Load checkpoint
    checkpoint_path = f'{config.output_dir}/fold_{fold}_best.pt'
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    
    epoch, metrics = load_checkpoint(checkpoint_path, model)
    
    print(f"Loaded model from fold {fold}, epoch {epoch}")
    print(f"Checkpoint metrics: MAE={metrics['mae']:.4f}, R²={metrics['r2']:.4f}")
    
    return model, metrics


@torch.no_grad()
def predict_single_image(model, image_path, categorical_features, numerical_features, config, device):
    """Predict BMI for a single image"""
    from torchvision import transforms
    
    # Load preprocessing config
    preprocess_config, _ = load_config_and_encoders(config.preprocessed_dir)
    
    # Create transform
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=preprocess_config['mean'], std=preprocess_config['std'])
    ])
    
    # Load and preprocess image
    try:
        with Image.open(image_path) as image:
            image = image.convert('RGB')
            image_tensor = transform(image).unsqueeze(0).to(device)
    except Exception as e:
        raise ValueError(f"Failed to load image: {e}")
    
    # Prepare features
    if categorical_features is not None:
        categorical_features = torch.tensor(categorical_features, dtype=torch.long).unsqueeze(0).to(device)
    
    if numerical_features is not None:
        numerical_features = torch.tensor(numerical_features, dtype=torch.float32).unsqueeze(0).to(device)
    
    # Predict
    model.eval()
    prediction = model(image_tensor, categorical_features, numerical_features)
    
    return prediction.item()


# ==================== TEST MODEL CREATION ====================
if __name__ == "__main__":
    # Quick test of model creation
    print("\n" + "="*70)
    print("  TESTING MODEL CREATION")
    print("="*70 + "\n")
    
    # Test model
    test_model = ViTForBMI(
        num_categorical_features=4,
        categorical_vocab_sizes=[10, 5, 2, 8],
        num_numerical_features=15,
        embed_dim=32,
        fusion_dim=512,
        use_adaptive_pooling=False,
        dropout=0.2
    )
    
    count_parameters(test_model)
    
    # Test forward pass
    batch_size = 16
    pixel_values = torch.randn(batch_size, 3, 224, 224)
    categorical_features = torch.randint(0, 5, (batch_size, 4))
    numerical_features = torch.randn(batch_size, 15)
    
    output = test_model(pixel_values, categorical_features, numerical_features)
    
    print(f"Input shapes:")
    print(f"  Images: {pixel_values.shape}")
    print(f"  Categorical: {categorical_features.shape}")
    print(f"  Numerical: {numerical_features.shape}")
    print(f"\nOutput shape: {output.shape}")
    print(f"\n✓ Model test passed!\n")
    
    main()

2025-10-19 02:57:54.030749: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1760842674.429687      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1760842674.550505      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered



  TESTING MODEL CREATION



config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]


MODEL ARCHITECTURE - ViT for BMI Prediction
Visual Feature Dim:      768 (ViT-Base)
Pooling Strategy:        CLS Token
Categorical Features:    4 -> 128
Numerical Features:      15 -> 128
Fusion Dim:              512
Regression Head:         512 -> 256 -> 128 -> 1
Unfrozen ViT Layers:     Last 4
Dropout:                 0.2


Model Parameter Summary:
Total Parameters:      88,236,193
Trainable Parameters:  30,198,433 (34.22%)
Frozen Parameters:     58,037,760 (65.78%)

Input shapes:
  Images: torch.Size([16, 3, 224, 224])
  Categorical: torch.Size([16, 4])
  Numerical: torch.Size([16, 15])

Output shape: torch.Size([16, 1])

✓ Model test passed!


  ViT FOR BMI PREDICTION - FIXED VERSION

✓ Using GPU: Tesla T4
  Memory: 15.83 GB

Configuration:
  Folds:              5
  Epochs:             20
  Batch size:         96
  Learning rate:      3e-05
  ViT LR multiplier:  0.1
  Loss function:      huber
  Mixed precision:    True
  Adaptive pooling:   False
  Output dir:         /kaggle/wor

Traceback (most recent call last):
  File "/tmp/ipykernel_19/208876159.py", line 1474, in main
    history = train_fold(fold, config, device)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_19/208876159.py", line 1027, in train_fold
    preprocess_config, feature_encoders = load_config_and_encoders(config.preprocessed_dir)
                                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_19/208876159.py", line 690, in load_config_and_encoders
    with open(f'{preprocessed_dir}/config.json', 'r') as f:
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
FileNotFoundError: [Errno 2] No such file or directory: '/kaggle/working/illinois_doc_preprocessed/config.json'
Traceback (most recent call last):
  File "/tmp/ipykernel_19/208876159.py", line 1474, in main
    history = train_fold(fold, config, device)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_19/208876159.py", line 1027, in train_fold
 