# Face2BMI: Complete End-to-End Pipeline
### From Data Preprocessing to Model Training, Validation & Ablation Study

This notebook provides a comprehensive pipeline for BMI prediction from facial images.

**Pipeline Overview:**
1. **Data Preprocessing**: Background removal and face segmentation using MediaPipe
2. **Feature Extraction**: Detection of 468 MediaPipe FaceMesh landmarks + 7 geometric features
3. **Data Merging**: Combine geometric features with BMI labels
4. **Model Training**: Multiple architectures (CNN, Hybrid, GNN)
5. **K-Fold Cross-Validation**: Robust evaluation with 5 folds
6. **Comprehensive Evaluation**: R², MAE, RMSE and MAPE metrics
7. **Ablation Study**: Component-wise model analysis

---

## 🔧 Part 1: Data Preprocessing & Feature Extraction

In [None]:
!pip install mediapipe -q

In [9]:
import cv2
import numpy as np
import os
from glob import glob
from tqdm import tqdm
import mediapipe as mp

# ==== Paths ====
input_folder = "/kaggle/input/morph/Dataset/Images/Train"
output_folder = "/kaggle/working/ROI"
os.makedirs(output_folder, exist_ok=True)

# ==== Collect image paths ====
image_paths = glob(os.path.join(input_folder, "*.jpg"))

if len(image_paths) == 0:
    print(f"‚ö†Ô∏è No images found in {input_folder}")
else:
    failed = []  # (path, reason) for every image that could not be processed

    # ==== Initialize MediaPipe ====
    mp_selfie_segmentation = mp.solutions.selfie_segmentation
    # Used as a context manager so the segmentation graph is closed after the loop.
    with mp_selfie_segmentation.SelfieSegmentation(model_selection=1) as selfie_segmentation:
        for path in tqdm(image_paths, desc="Removing background", ncols=80):
            try:
                img = cv2.imread(path)
                if img is None or img.size == 0:
                    failed.append((path, "unreadable image"))
                    continue

                img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                results = selfie_segmentation.process(img_rgb)

                # Keep pixels whose foreground (person) score exceeds 0.5; the
                # selfie model segments the whole person, not only the face.
                mask = results.segmentation_mask
                condition = mask > 0.5
                bg = np.zeros_like(img, dtype=np.uint8)  # black background
                output = np.where(condition[..., None], img, bg)

                # Crop to the bounding box of the remaining non-black pixels.
                gray = cv2.cvtColor(output, cv2.COLOR_BGR2GRAY)
                coords = cv2.findNonZero(gray)
                if coords is not None:
                    x, y, w, h = cv2.boundingRect(coords)
                    output = output[y:y + h, x:x + w]

                # Save
                save_path = os.path.join(output_folder, os.path.basename(path))
                if not cv2.imwrite(save_path, output):
                    failed.append((path, "cv2.imwrite returned False"))

            except Exception as exc:
                # Best-effort per image, but catch Exception (not a bare except)
                # so Ctrl-C still stops the loop, and record what went wrong.
                failed.append((path, repr(exc)))

    if failed:
        print(f"Skipped {len(failed)} of {len(image_paths)} images; first failure: {failed[0]}")


2025-10-27 03:17:55.193890: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1761535075.372003      37 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1761535075.426444      37 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


‚ö†Ô∏è No images found in /kaggle/working/upscaling_image


In [None]:
import cv2
import matplotlib.pyplot as plt
import os
from glob import glob

# ==== Path ====
roi_folder = "/kaggle/working/ROI"

# ==== Collect first 10 images (slots in a 2x5 grid) ====
image_paths = glob(os.path.join(roi_folder, "*.jpg"))[:10]

# ==== Display ====
plt.figure(figsize=(15, 6))
for slot, path in enumerate(image_paths, start=1):
    img = cv2.imread(path)
    if img is None:
        continue  # unreadable file: leave its grid slot empty
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; matplotlib wants RGB
    plt.subplot(2, 5, slot)
    plt.imshow(img)
    plt.axis('off')
    plt.title(os.path.basename(path))
plt.tight_layout()
plt.show()


import cv2
import mediapipe as mp
import os
from glob import glob
from tqdm import tqdm
import matplotlib.pyplot as plt

# ==== Paths ====
input_folder = "/kaggle/working/ROI"
output_folder = "/kaggle/working/face_mesh_output"
os.makedirs(output_folder, exist_ok=True)

# ==== Collect images ====
image_paths = glob(os.path.join(input_folder, "*.jpg"))

print(f"üìÅ Total images found: {len(image_paths)}")

if len(image_paths) == 0:
    raise ValueError("‚ùå No images found in the directory. Please check your input path.")

# ==== Mediapipe Setup ====
mp_face_mesh = mp.solutions.face_mesh

# ==== Processing ====
processed_images = []
# Used as a context manager so the FaceMesh graph is closed once the loop ends.
with mp_face_mesh.FaceMesh(static_image_mode=True,
                           max_num_faces=1,
                           refine_landmarks=True,
                           min_detection_confidence=0.5) as face_mesh:
    for path in tqdm(image_paths, desc="Drawing landmarks", ncols=80):
        img = cv2.imread(path)
        if img is None:
            continue

        h, w, _ = img.shape
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        results = face_mesh.process(img_rgb)

        if not results.multi_face_landmarks:
            continue  # no face found: the image is neither saved nor shown

        for face_landmarks in results.multi_face_landmarks:
            for idx, lm in enumerate(face_landmarks.landmark):
                # Scale the landmark by image width/height to get pixel coordinates.
                x, y = int(lm.x * w), int(lm.y * h)

                cv2.circle(img, (x, y), 1, (0, 255, 0), -1)  # green dot (BGR)
                # Landmark index label; OpenCV colours are BGR, so (0, 0, 250) is red.
                cv2.putText(img, str(idx), (x, y), cv2.FONT_HERSHEY_SIMPLEX,
                            0.25, (0, 0, 250), 1, cv2.LINE_AA)

        # Save output image
        save_path = os.path.join(output_folder, os.path.basename(path))
        cv2.imwrite(save_path, img)

        # Collect a few images for display
        if len(processed_images) < 10:
            processed_images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

print(f"\n‚úÖ FaceMesh landmarks drawn and saved in: {output_folder}")

# ==== Show sample outputs ====
if processed_images:
    plt.figure(figsize=(15, 6))
    for i, img in enumerate(processed_images):
        plt.subplot(2, 5, i+1)
        plt.imshow(img)
        plt.axis('off')
        plt.title(f"Image {i+1}")
    plt.tight_layout()
    plt.show()
else:
    print("‚ö†Ô∏è No landmarks detected in any image.")


In [None]:
import cv2
import mediapipe as mp
import os
from glob import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd
from scipy.spatial import distance
import numpy as np

# ==== Paths ====
input_folder = "/kaggle/working/ROI"
output_folder = "/kaggle/working/face_mesh_output"
csv_path = "/kaggle/working/features.csv"

os.makedirs(output_folder, exist_ok=True)

# ==== Collect images: all .jpg first, then .jpeg, then .png ====
image_paths = []
for pattern in ("*.jpg", "*.jpeg", "*.png"):
    image_paths.extend(glob(os.path.join(input_folder, pattern)))

print(f"üìÅ Total images found: {len(image_paths)}")

if not image_paths:
    raise ValueError("‚ùå No images found in the directory. Please check your input path.")

# ==== Mediapipe Setup ====
# One shared FaceMesh instance, reused by the processing loop below.
mp_face_mesh = mp.solutions.face_mesh
face_mesh = mp_face_mesh.FaceMesh(
    static_image_mode=True,
    max_num_faces=1,
    refine_landmarks=True,
    min_detection_confidence=0.5,
)

# ==== Function to Calculate Features ====
def calculate_geometric_features(landmarks):
    """Compute seven 2-D geometric face descriptors from FaceMesh landmarks.

    Args:
        landmarks: array of (x, y) pixel coordinates indexed by MediaPipe
            FaceMesh landmark id. The caller in this notebook passes an
            ``np.ndarray`` of shape (n_landmarks, 2). At least 468 rows are required.

    Returns:
        dict with keys 'cwjwr', 'cwufhr', 'par', 'asoe', 'fhlfhr', 'fwlfhr' and
        'meh'. Returns None when ``landmarks`` is None or too short, or when a
        computation raises (the error is printed).
    """
    if landmarks is None or len(landmarks) < 468:
        return None

    try:
        features = {}
        # Key landmarks
        chin = landmarks[152]
        forehead_center = landmarks[10]
        left_cheek = landmarks[234]
        right_cheek = landmarks[454]
        left_jaw = landmarks[172]
        right_jaw = landmarks[397]
        # "left"/"right" follow MediaPipe's subject-centric naming:
        # 263/362 are the corners of the subject's left eye, 33/133 of the right eye.
        left_eye_left = landmarks[263]    # outer corner
        left_eye_right = landmarks[362]   # inner corner
        right_eye_left = landmarks[33]    # outer corner
        right_eye_right = landmarks[133]  # inner corner
        nose_tip = landmarks[1]
        # Eyebrows must come from the same side as the eye they are measured
        # against: 300/336 bound the subject's left eyebrow, 70/107 the right one.
        left_eyebrow_inner = landmarks[336]
        left_eyebrow_outer = landmarks[300]
        right_eyebrow_inner = landmarks[107]
        right_eyebrow_outer = landmarks[70]

        # 1) CWJWR - Cheekbone Width to Jaw Width Ratio
        cheekbone_width = distance.euclidean(left_cheek, right_cheek)
        jaw_width = distance.euclidean(left_jaw, right_jaw)
        features['cwjwr'] = cheekbone_width / (jaw_width + 1e-6)

        # 2) CWUFHR - Cheekbone Width to Upper Face Height Ratio
        upper_face_height = distance.euclidean(forehead_center, nose_tip)
        features['cwufhr'] = cheekbone_width / (upper_face_height + 1e-6)

        # 3) PAR - Perimeter to Area Ratio of a 6-point face polygon
        face_contour_points = [left_jaw, left_cheek, forehead_center, right_cheek, right_jaw, chin]
        n_pts = len(face_contour_points)
        perimeter = sum(
            distance.euclidean(face_contour_points[i], face_contour_points[(i + 1) % n_pts])
            for i in range(n_pts)
        )
        # Shoelace formula for the polygon area.
        area = 0.5 * abs(sum(
            face_contour_points[i][0] * face_contour_points[(i + 1) % n_pts][1] -
            face_contour_points[(i + 1) % n_pts][0] * face_contour_points[i][1]
            for i in range(n_pts)
        ))
        features['par'] = perimeter / (area + 1e-6)

        # 4) ASoE - Average Size (corner-to-corner width) of the Eyes
        left_eye_width = distance.euclidean(left_eye_left, left_eye_right)
        right_eye_width = distance.euclidean(right_eye_left, right_eye_right)
        features['asoe'] = (left_eye_width + right_eye_width) / 2

        # 5) FHLFHR - Face Height to Lower Face Height Ratio
        face_height = distance.euclidean(forehead_center, chin)
        lower_face_height = distance.euclidean(nose_tip, chin)
        features['fhlfhr'] = face_height / (lower_face_height + 1e-6)

        # 6) FWLFHR - Face (jaw) Width to Lower Face Height Ratio
        features['fwlfhr'] = jaw_width / (lower_face_height + 1e-6)

        # 7) MEH - Mean Eyebrow Height: eyebrow midpoint to the outer eye corner
        #    on the same side of the face.
        left_eyebrow_height = distance.euclidean(
            (left_eyebrow_inner + left_eyebrow_outer) / 2, left_eye_left
        )
        right_eyebrow_height = distance.euclidean(
            (right_eyebrow_inner + right_eyebrow_outer) / 2, right_eye_left
        )
        features['meh'] = (left_eyebrow_height + right_eyebrow_height) / 2

        return features
    except Exception as e:
        print(f"‚ö†Ô∏è Error calculating features: {e}")
        return None

# ==== Processing ====
data = []              # one feature dict per detected face
processed_images = []  # up to 10 annotated RGB images for the preview

for path in tqdm(image_paths, desc="Processing images", ncols=80):
    img = cv2.imread(path)
    if img is None:
        continue

    h, w, _ = img.shape
    results = face_mesh.process(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    detected_faces = results.multi_face_landmarks
    if not detected_faces:
        continue  # no face: nothing saved for this image

    name = os.path.basename(path)
    for face_landmarks in detected_faces:
        # Scale each landmark by image width/height into pixel coordinates.
        coords = np.array([[lm.x * w, lm.y * h] for lm in face_landmarks.landmark])
        feats = calculate_geometric_features(coords)
        if feats:
            data.append({**feats, 'filename': name})

        # Draw every landmark; label only every 25th index to keep it readable.
        for idx, (x, y) in enumerate(coords.astype(int)):
            cv2.circle(img, (x, y), 1, (0, 255, 0), -1)
            if idx % 25 == 0:
                cv2.putText(img, str(idx), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 0, 255), 1)

    cv2.imwrite(os.path.join(output_folder, name), img)

    if len(processed_images) < 10:
        processed_images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

# ==== Save CSV ====
df = pd.DataFrame(data)
df.to_csv(csv_path, index=False)
print(f"\n‚úÖ Saved {len(df)} entries to {csv_path}")

# ==== Show sample outputs ====
if not processed_images:
    print("‚ö†Ô∏è No landmarks detected in any image.")
else:
    plt.figure(figsize=(30, 20))
    for slot, sample in enumerate(processed_images, start=1):
        plt.subplot(2, 5, slot)
        plt.imshow(sample)
        plt.axis('off')
        plt.title(f"Image {slot}")
    plt.tight_layout()
    plt.show()


In [None]:
import pandas as pd
import os

# ==== 1) Paths ====
person_csv = "/kaggle/input/person.csv"          # adjust if it's in a subfolder
features_csv = "/kaggle/working/features.csv"    # output file from the feature-extraction step
output_csv = "/kaggle/working/merged_features.csv"

# ==== 2) Load the CSV files ====
df_person = pd.read_csv(person_csv)
df_features = pd.read_csv(features_csv)

print("‚úÖ Loaded files:")
print(f"person.csv ‚Üí {df_person.shape}")
print(f"features.csv ‚Üí {df_features.shape}")

# ==== 3) Check column names ====
print("\nüß© Columns in person.csv:", df_person.columns.tolist())
print("üß© Columns in features.csv:", df_features.columns.tolist())

# ==== 4) Choose the merge key ====
# Default is 'id'; if either file lacks it, fall back to a shared column.
merge_key = "id"
if merge_key not in df_person.columns or merge_key not in df_features.columns:
    possible_keys = set(df_person.columns) & set(df_features.columns)
    if not possible_keys:
        raise ValueError("‚ùå No common column found for merging!")
    # features.csv identifies rows by 'filename', so prefer it; otherwise pick
    # deterministically (set iteration order is arbitrary and may differ per run).
    merge_key = "filename" if "filename" in possible_keys else sorted(possible_keys)[0]
    print(f"‚öôÔ∏è Auto-detected merge key: {merge_key}")

merged_df = pd.merge(df_person, df_features, on=merge_key, how="inner")


merged_df.to_csv(output_csv, index=False)

print(f"\n‚úÖ Merged successfully! Saved to: {output_csv}")
print("üßæ Final shape:", merged_df.shape)


merged_df.head()


---
## 🧠 Part 2: Model Training, Validation & Ablation Study
### Complete Training Pipeline with Multiple Model Architectures

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler, RobustScaler
from sklearn.metrics import mean_absolute_error, r2_score, mean_squared_error
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import warnings
import os
import gc
warnings.filterwarnings('ignore')

# ============================================================================
# GPU & SEED SETUP
# ============================================================================

def get_device():
    """Return the CUDA device when available, otherwise the CPU device.

    On CUDA, cuDNN autotuning is enabled. On GPUs with compute capability 8 or
    higher, TF32 matmul/convolution is also allowed. The chosen device is printed.
    """
    if not torch.cuda.is_available():
        print("Using CPU")
        return torch.device('cpu')

    cudnn = torch.backends.cudnn
    cudnn.benchmark = True
    cudnn.enabled = True
    major_capability = torch.cuda.get_device_capability()[0]
    if major_capability >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        cudnn.allow_tf32 = True
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    return torch.device('cuda')

def clear_memory():
    """Run Python garbage collection and, when CUDA exists, release cached GPU memory."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

def set_seed(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and, if present, all CUDA devices).

    The stdlib ``random`` module is seeded too, so any code path drawing from it
    is reproducible alongside NumPy/PyTorch.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(42)

# ============================================================================
# GNN SETUP (Optional)
# ============================================================================

try:
    from torch_geometric.nn import GCNConv, global_mean_pool, global_max_pool, BatchNorm
    GNN_AVAILABLE = True
except ImportError:
    GNN_AVAILABLE = False
    import importlib
    import subprocess
    import sys

    try:
        # Install into the interpreter running this kernel; a bare `pip` on PATH
        # may belong to a different Python installation.
        subprocess.run(
            [sys.executable, '-m', 'pip', 'install', 'torch-geometric', '-q'],
            check=False,
        )
        # Let the import system notice packages installed after start-up.
        importlib.invalidate_caches()
        from torch_geometric.nn import GCNConv, global_mean_pool, global_max_pool, BatchNorm
        GNN_AVAILABLE = True
    except Exception as exc:
        # Best-effort: broken installs can raise OSError/RuntimeError as well as
        # ImportError. Any failure just disables the GNN models. Catching
        # Exception (not a bare except) keeps KeyboardInterrupt working.
        print(f"torch_geometric unavailable, GNN models disabled: {exc}")

# ============================================================================
# CONFIGURATION
# ============================================================================

DATA_DIR = '/kaggle/input/bmi-dataset'
OUTPUT_DIR = '/kaggle/working'

# Inputs that live inside DATA_DIR
TRAIN_CSV = DATA_DIR + '/train.csv'
TEST_CSV = DATA_DIR + '/test.csv'
IMAGE_DIR = DATA_DIR + '/ROI'

os.makedirs(OUTPUT_DIR, exist_ok=True)  # no-op when it already exists

# ============================================================================
# DATA LOADING
# ============================================================================

def load_raw_data():
    """Read TRAIN_CSV and TEST_CSV as-is and print their row counts.

    No NaN imputation happens here, so no statistics can leak between splits.

    Returns:
        (train_df, test_df)
    """
    train_df, test_df = (pd.read_csv(csv_file) for csv_file in (TRAIN_CSV, TEST_CSV))
    print(f"Train: {len(train_df)}, Test: {len(test_df)}")
    return train_df, test_df

def fill_nan_from_reference(df, reference_df):
    """Return a copy of ``df`` with NaNs filled from ``reference_df`` statistics.

    - Numeric columns use the reference median, or 0 when the column is missing
      from the reference.
    - Object columns use the reference mode. They are left untouched when the
      reference lacks the column or has no mode.

    Pass training data as ``reference_df`` so validation/test never contribute.
    """
    filled = df.copy()
    reference_columns = set(reference_df.columns)

    for col in filled.select_dtypes(include=[np.number]).columns:
        if not filled[col].isna().any():
            continue
        value = reference_df[col].median() if col in reference_columns else 0
        filled[col] = filled[col].fillna(value)

    for col in filled.select_dtypes(include=['object']).columns:
        if not filled[col].isna().any() or col not in reference_columns:
            continue
        modes = reference_df[col].mode()
        if not modes.empty:
            filled[col] = filled[col].fillna(modes[0])

    return filled

def get_landmark_info(df):
    """List landmark column groups found in ``df``.

    Every ``<prefix>_x`` column that also has ``<prefix>_y`` yields a tuple
    ``(prefix, x_col, y_col, z_col_or_None)``, ordered by the sorted x column names.
    """
    present = set(df.columns)
    info = []
    for x_col in sorted(c for c in df.columns if c.endswith('_x')):
        stem = x_col[:-2]
        y_col = f'{stem}_y'
        if y_col not in present:
            continue
        z_col = f'{stem}_z'
        info.append((stem, x_col, y_col, z_col if z_col in present else None))
    return info

def get_base_feature_cols(df):
    """Return the known base feature names present in ``df``, in canonical order."""
    canonical = (
        'face_height', 'face_width_cheeks', 'face_width_jaw', 'face_ratio_height_width',
        'jaw_cheek_ratio', 'left_eye_width', 'right_eye_width', 'eye_width_ratio',
        'nose_length', 'nose_width', 'nose_ratio', 'mouth_width', 'face_oval_area',
        'face_oval_perimeter', 'face_compactness', 'age', 'sex_encoded',
        'race_Asian', 'race_Black', 'race_Hispanic', 'race_White',
    )
    available = set(df.columns)
    return [name for name in canonical if name in available]

def engineer_features_for_fold(train_df, val_df=None, test_df=None):
    """Impute and derive features for one CV fold without leaking val/test statistics.

    NaNs in every frame are filled from ``train_df`` statistics only (via
    ``fill_nan_from_reference``). All derived columns are computed row by row.

    Returns:
        ``(train, [val], [test], features)``: engineered copies of the frames that
        were passed (val/test only when given), followed by the feature column
        names. The list order is deterministic: base columns, then derived
        columns, then landmark coordinate statistics.
    """
    train_df = fill_nan_from_reference(train_df, train_df)
    if val_df is not None:
        val_df = fill_nan_from_reference(val_df, train_df)
    if test_df is not None:
        test_df = fill_nan_from_reference(test_df, train_df)

    dfs = {'train': train_df.copy()}
    if val_df is not None:
        dfs['val'] = val_df.copy()
    if test_df is not None:
        dfs['test'] = test_df.copy()

    features = get_base_feature_cols(train_df)

    # Row-wise geometric / demographic derivatives (frames are modified in place).
    for df in dfs.values():
        if 'face_height' in df.columns and 'face_width_cheeks' in df.columns:
            df['fwhr'] = df['face_width_cheeks'] / df['face_height'].clip(lower=1e-8)
            df['face_area'] = df['face_width_cheeks'] * df['face_height']

        if 'left_eye_width' in df.columns and 'right_eye_width' in df.columns:
            sum_eyes = df['left_eye_width'] + df['right_eye_width']
            df['eye_symmetry'] = 1 - abs(df['left_eye_width'] - df['right_eye_width']) / sum_eyes.clip(lower=1e-8)

        if 'nose_length' in df.columns and 'nose_width' in df.columns:
            df['nose_compactness'] = df['nose_width'] / df['nose_length'].clip(lower=1e-8)

        if 'age' in df.columns:
            df['age_squared'] = df['age'] ** 2
            df['age_log'] = np.log1p(df['age'])

    derived = ['fwhr', 'face_area', 'eye_symmetry', 'nose_compactness', 'age_squared', 'age_log']
    features.extend([f for f in derived if f in dfs['train'].columns])

    # Per-row mean / std / range over all *_x, *_y and *_z columns.
    for df in dfs.values():
        for suffix, coord_name in [('_x', 'x'), ('_y', 'y'), ('_z', 'z')]:
            # The column list is taken before coord_*_<axis> is added, so the
            # new statistics are not folded into each other.
            cols = [c for c in df.columns if c.endswith(suffix)]
            if cols:
                df[f'coord_mean_{coord_name}'] = df[cols].mean(axis=1)
                df[f'coord_std_{coord_name}'] = df[cols].std(axis=1)
                df[f'coord_range_{coord_name}'] = df[cols].max(axis=1) - df[cols].min(axis=1)

    for coord_name in ['x', 'y', 'z']:
        for stat in ['mean', 'std', 'range']:
            feat_name = f'coord_{stat}_{coord_name}'
            if feat_name in dfs['train'].columns:
                features.append(feat_name)

    # Order-preserving de-duplication. A `set` orders strings by hash, which
    # varies between interpreter sessions (PYTHONHASHSEED), and would make the
    # feature/scaler/model input column order irreproducible.
    features = list(dict.fromkeys(features))

    for df in dfs.values():
        for f in features:
            if f in df.columns:
                df[f] = df[f].fillna(0)

    results = [dfs['train']]
    if val_df is not None:
        results.append(dfs['val'])
    if test_df is not None:
        results.append(dfs['test'])
    results.append(features)

    return tuple(results)

def create_graphs(num_nodes):
    """Build 'local', 'regional' and 'global' edge sets over nodes 0..num_nodes-1.

    Node indices are treated as a sequence:
    - local: nodes up to 3 positions apart.
    - regional: nodes up to 6 apart, plus skip links between every 5th node and
      the next two multiples of 5.
    - global: nodes up to 12 apart, plus links i <-> i + num_nodes // 2 and a
      coarse lattice every ``max(num_nodes // 8, 1)`` nodes.

    Graphs with <= 10 nodes are fully connected in all three views. Every edge
    is stored in both directions.

    Returns:
        None if ``num_nodes <= 0``. Otherwise a dict mapping view name to a
        ``(2, num_edges)`` long tensor without repeated (src, dst) pairs.
    """
    if num_nodes <= 0:
        return None

    def to_edge_index(pairs):
        # Several rules below emit the same (src, dst) pair more than once;
        # message passing (e.g. GCNConv) would aggregate each copy separately
        # and over-weight it, so keep only the first occurrence.
        unique = list(dict.fromkeys((src, dst) for src, dst in pairs))
        if not unique:
            return torch.zeros(2, 0, dtype=torch.long)
        return torch.tensor(unique, dtype=torch.long).t().contiguous()

    graphs = {}

    if num_nodes <= 10:
        edges = [[i, j] for i in range(num_nodes) for j in range(num_nodes) if i != j]
        graphs['local'] = graphs['regional'] = graphs['global'] = to_edge_index(edges)
        return graphs

    local_edges = []
    for i in range(num_nodes):
        for j in range(1, min(4, num_nodes)):
            if i + j < num_nodes:
                local_edges.extend([[i, i + j], [i + j, i]])

    regional_edges = []
    for i in range(num_nodes):
        for j in range(1, min(7, num_nodes)):
            if i + j < num_nodes:
                regional_edges.extend([[i, i + j], [i + j, i]])
    for i in range(0, num_nodes - 5, 5):
        for j in range(i + 5, min(i + 15, num_nodes), 5):
            regional_edges.extend([[i, j], [j, i]])

    global_edges = []
    for i in range(num_nodes):
        for j in range(1, min(13, num_nodes)):
            if i + j < num_nodes:
                global_edges.extend([[i, i + j], [i + j, i]])
    mid = num_nodes // 2
    for i in range(min(mid, num_nodes - mid)):
        if mid + i < num_nodes:
            global_edges.extend([[i, mid + i], [mid + i, i]])
    step = max(num_nodes // 8, 1)
    for i in range(0, num_nodes, step):
        for j in range(i + step, num_nodes, step):
            global_edges.extend([[i, j], [j, i]])

    graphs['local'] = to_edge_index(local_edges)
    graphs['regional'] = to_edge_index(regional_edges)
    graphs['global'] = to_edge_index(global_edges)

    return graphs

# ============================================================================
# DATASET
# ============================================================================

class BMIDataset(Dataset):
    """Per-row samples of scaled tabular features, face image, landmark tensor and targets.

    Training sets can be virtually enlarged ``augment_mult`` times. Indices at or
    past ``len(df)`` map back to the original rows and get small Gaussian noise
    on features and landmarks.

    Args:
        df: table with 'BMI', 'image_filename' and ``feature_cols``. Optional
            columns: 'age', 'sex_encoded' and ``<prefix>_x/_y/_z`` landmarks.
        feature_cols: columns scaled with a RobustScaler.
        landmark_info: stored on the instance. Landmarks themselves are
            re-derived from the ``*_x/_y/_z`` column names.
        image_dir: directory joined with each 'image_filename'.
        is_training: enables augmentation transforms and ``augment_mult``.
        augment_mult: passes over ``df`` per epoch when training.
        feature_scaler / landmark_scaler: pre-fitted scalers to reuse (val/test).
        fit_scalers: fit fresh scalers on this data. This also happens when the
            corresponding scaler is None.
    """

    def __init__(self, df, feature_cols, landmark_info, image_dir,
                 is_training=True, augment_mult=1,
                 feature_scaler=None, landmark_scaler=None,
                 fit_scalers=False):

        self.df = df.reset_index(drop=True)
        self.feature_cols = feature_cols
        self.landmark_info = landmark_info
        self.image_dir = image_dir
        self.is_training = is_training
        self.augment_mult = augment_mult if is_training else 1

        raw_features = np.nan_to_num(self.df[feature_cols].values.astype(np.float32))

        if fit_scalers or feature_scaler is None:
            self.feature_scaler = RobustScaler()
            scaled_features = self.feature_scaler.fit_transform(raw_features)
        else:
            self.feature_scaler = feature_scaler
            scaled_features = self.feature_scaler.transform(raw_features)

        self.features = torch.from_numpy(scaled_features.astype(np.float32))

        self.targets = torch.from_numpy(
            np.nan_to_num(self.df['BMI'].values.astype(np.float32), nan=self.df['BMI'].median())
        )

        if 'age' in self.df.columns:
            self.age_targets = torch.from_numpy(
                np.nan_to_num(self.df['age'].values.astype(np.float32), nan=self.df['age'].median())
            )
        else:
            self.age_targets = torch.zeros(len(self.df))

        if 'sex_encoded' in self.df.columns:
            sex = np.clip(np.nan_to_num(self.df['sex_encoded'].values, nan=0).astype(np.int64), 0, 1)
            self.sex_targets = torch.from_numpy(sex)
        else:
            self.sex_targets = torch.zeros(len(self.df), dtype=torch.long)

        # BMI classes: 0 underweight (<18.5), 1 normal, 2 overweight, 3 obese (>=30).
        # The top edge is +inf so extreme BMIs still land in class 3.
        bmi_vals = np.nan_to_num(self.df['BMI'].values, nan=self.df['BMI'].median())
        bins = [0, 18.5, 25, 30, np.inf]
        cat = pd.cut(bmi_vals, bins=bins, labels=[0, 1, 2, 3], include_lowest=True)
        codes = np.asarray(pd.Series(cat).cat.codes, dtype=np.int64)
        # pd.cut yields code -1 for values outside the bins (or NaN). Map those
        # to class 1, the fallback the categorisation intends, rather than
        # letting a clip turn them into class 0.
        codes = np.where(codes < 0, 1, codes)
        self.bmi_categories = torch.from_numpy(codes)

        self.filenames = self.df['image_filename'].tolist()

        self.landmarks, self.landmark_scaler = self._extract_landmarks(
            landmark_scaler, fit_scalers
        )

        if is_training:
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop(224, scale=(0.75, 1.0), ratio=(0.9, 1.1)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.2, hue=0.08),
                transforms.RandomRotation(degrees=12),
                transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                transforms.RandomErasing(p=0.1, scale=(0.02, 0.1)),
            ])
        else:
            self.transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ])

        self._placeholder = None

    def _extract_landmarks(self, scaler, fit_scaler):
        """Stack ``<prefix>_x/_y[/_z]`` columns into a (rows, prefixes, dim) tensor and scale it.

        Prefixes are the sorted column-name stems before the last '_'. ``dim`` is 3
        if any ``*_z`` column exists, else 2. Missing or NaN coordinates become 0.
        A StandardScaler is fitted on all points, or ``scaler`` is reused.

        NOTE(review): any engineered column such as 'coord_mean_x' also matches the
        suffix test and becomes a pseudo-landmark 'coord_mean'; confirm intended.

        Returns:
            (landmark tensor, scaler), or (None, None) when there are no
            coordinate columns.
        """
        coord_cols = [c for c in self.df.columns if c.endswith(('_x', '_y', '_z'))]
        if not coord_cols:
            return None, None

        prefixes = sorted(set(c.rsplit('_', 1)[0] for c in coord_cols))
        has_z = any(c.endswith('_z') for c in coord_cols)
        axes = ('x', 'y', 'z') if has_z else ('x', 'y')
        dim = len(axes)

        # Column-wise gather instead of a per-row Python loop. reindex produces NaN
        # for absent columns, which nan_to_num turns into 0.
        data = np.stack(
            [
                self.df.reindex(columns=[f'{p}_{axis}' for p in prefixes]).to_numpy(dtype=np.float32)
                for axis in axes
            ],
            axis=-1,
        )
        data = np.nan_to_num(data)
        flat = data.reshape(-1, dim)

        if fit_scaler or scaler is None:
            fitted_scaler = StandardScaler()
            flat_scaled = fitted_scaler.fit_transform(flat)
        else:
            fitted_scaler = scaler
            flat_scaled = fitted_scaler.transform(flat)

        return torch.from_numpy(flat_scaled.reshape(data.shape).astype(np.float32)), fitted_scaler

    def _get_placeholder(self):
        """Return a stand-in image tensor for files that cannot be read.

        Real images are returned after ImageNet normalisation, where the mean
        colour maps to 0 in every channel. An all-zero tensor is therefore the
        neutral "mean colour" image.
        """
        if self._placeholder is None:
            self._placeholder = torch.zeros(3, 224, 224)
        return self._placeholder.clone()

    def __len__(self):
        return len(self.df) * self.augment_mult

    def __getitem__(self, idx):
        """Return the sample dict for ``idx`` (augmented copies wrap around ``len(df)``)."""
        orig_idx = idx % len(self.df)
        is_aug = idx >= len(self.df)

        feat = self.features[orig_idx].clone()
        if self.is_training and is_aug:
            noise_std = 0.02
            feat = feat + torch.randn_like(feat) * noise_std

        if self.landmarks is not None:
            graph_feat = self.landmarks[orig_idx].clone()
            if self.is_training and is_aug:
                graph_feat = graph_feat + torch.randn_like(graph_feat) * 0.01
        else:
            graph_feat = torch.zeros(1, 3)

        img_path = os.path.join(self.image_dir, self.filenames[orig_idx])
        try:
            with Image.open(img_path) as img:
                if img.mode != 'RGB':
                    img = img.convert('RGB')
                image = self.transform(img)
        except Exception:
            # Best-effort: missing or corrupt images fall back to the neutral placeholder.
            image = self._get_placeholder()

        return {
            'features': feat,
            'image': image,
            'graph_features': graph_feat,
            'target': self.targets[orig_idx],
            'age_target': self.age_targets[orig_idx],
            'sex_target': self.sex_targets[orig_idx],
            'bmi_category': self.bmi_categories[orig_idx],
        }

_COLLATE_KEYS = (
    'features', 'image', 'graph_features', 'target',
    'age_target', 'sex_target', 'bmi_category',
)


def collate_fn(batch):
    """Stack each known per-sample tensor along a new leading batch dimension."""
    return {key: torch.stack([sample[key] for sample in batch]) for key in _COLLATE_KEYS}

# ============================================================================
# MODEL COMPONENTS
# ============================================================================

class ChannelAttention(nn.Module):
    """Channel gate sigmoid(MLP(avg) + MLP(max)), pooled per graph.

    Args:
        dim: size of the last feature axis of ``x``.
        reduction: bottleneck factor of the shared MLP (hidden = dim // reduction).

    ``forward(x, batch)`` takes node features ``x`` with nodes on dim 0. It also
    takes an optional ``batch`` vector giving each node's graph id; without
    ``batch``, all rows are pooled together.
    """

    def __init__(self, dim, reduction=4):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(dim // reduction, dim, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, batch=None):
        if batch is None:
            avg = x.mean(0, keepdim=True)
            mx = x.max(0, keepdim=True)[0]
            return x * self.sigmoid(self.mlp(avg) + self.mlp(mx))

        bs = batch.max().item() + 1
        avgs, mxs = [], []
        for i in range(bs):
            mask = batch == i
            if mask.any():
                avgs.append(x[mask].mean(0))
                mxs.append(x[mask].max(0)[0])
            else:
                # Graph id with no nodes: keep a placeholder row so row i of the
                # gate still belongs to graph i when indexed by `batch` below.
                empty = x.new_zeros(x.shape[1:])
                avgs.append(empty)
                mxs.append(empty)
        if not avgs:
            return x
        avg = torch.stack(avgs)
        mx = torch.stack(mxs)
        return x * self.sigmoid(self.mlp(avg) + self.mlp(mx))[batch]

class SpatialAttention(nn.Module):
    """Per-row gate computed from the row's mean and max across the last axis.

    A small MLP (2 -> dim // 4 -> 1) maps the [mean, max] pair to one sigmoid
    weight that scales the whole row. ``batch`` is accepted for interface parity
    with ChannelAttention and is ignored.
    """

    def __init__(self, dim):
        super().__init__()
        hidden = dim // 4
        self.conv = nn.Sequential(nn.Linear(2, hidden), nn.ReLU(inplace=True), nn.Linear(hidden, 1))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, batch=None):
        row_mean = x.mean(-1, keepdim=True)
        row_max = x.max(-1, keepdim=True).values
        gate = self.sigmoid(self.conv(torch.cat([row_mean, row_max], -1)))
        return x * gate

class CBAM(nn.Module):
    """Channel gating followed by per-row gating (CBAM-style) for feature rows.

    Args:
        dim: feature width of the inputs.
        reduction: bottleneck factor for the channel-attention MLP.
    """

    def __init__(self, dim, reduction=4):
        super().__init__()
        self.channel = ChannelAttention(dim, reduction)
        self.spatial = SpatialAttention(dim)

    def forward(self, x, batch=None):
        """Apply channel attention, then spatial attention; ``batch`` is forwarded to both."""
        gated = self.channel(x, batch)
        return self.spatial(gated, batch)

if GNN_AVAILABLE:
    class MultiScaleGCN(nn.Module):
        """Landmark-graph encoder with local / regional / global GCN branches.

        Each branch is a stack of ``layers`` GCNConv layers (BatchNorm, ReLU,
        CBAM gating, dropout and a residual connection) over its own edge set.
        Per node, the three branch outputs are fused by self-attention across
        the three scales. The fused node features are then pooled per graph with
        mean, max and softmax-attention pooling and projected to ``out_dim``.

        Args:
            in_dim: per-node input feature size.
            hidden: width of every GCN branch.
            out_dim: size of the returned graph embedding.
            layers: number of GCN layers per branch.
            dropout: dropout probability used throughout.
        """

        def __init__(self, in_dim=3, hidden=128, out_dim=256, layers=3, dropout=0.3):
            super().__init__()
            self.input_proj = nn.Sequential(
                nn.Linear(in_dim, hidden), BatchNorm(hidden), nn.ReLU(), nn.Dropout(dropout)
            )

            self.local_convs = nn.ModuleList([GCNConv(hidden, hidden, improved=True) for _ in range(layers)])
            self.regional_convs = nn.ModuleList([GCNConv(hidden, hidden, improved=True) for _ in range(layers)])
            self.global_convs = nn.ModuleList([GCNConv(hidden, hidden, improved=True) for _ in range(layers)])
            self.bns = nn.ModuleList([BatchNorm(hidden) for _ in range(layers)])
            self.cbams = nn.ModuleList([CBAM(hidden) for _ in range(layers)])

            self.cross_attn = nn.MultiheadAttention(hidden, 4, dropout=dropout, batch_first=True)
            self.norm = nn.LayerNorm(hidden)
            self.node_attn = nn.Sequential(nn.Linear(hidden, hidden // 2), nn.ReLU(), nn.Linear(hidden // 2, 1))

            self.output = nn.Sequential(
                nn.Linear(hidden * 3, out_dim), nn.BatchNorm1d(out_dim), nn.ReLU(), nn.Dropout(dropout),
                nn.Linear(out_dim, out_dim), nn.BatchNorm1d(out_dim), nn.ReLU()
            )
            self.dropout = nn.Dropout(dropout)

        def forward(self, x, edges, batch):
            """Encode a batch of graphs.

            Args:
                x: node features, one row per node (all graphs concatenated).
                edges: mapping with 'local', 'regional' and 'global' edge_index
                    tensors that index rows of ``x``.
                batch: graph id of every row of ``x``.

            Returns:
                Tensor of shape (batch.max() + 1, out_dim).
            """
            x = self.input_proj(x)
            x_l, x_r, x_g = x.clone(), x.clone(), x.clone()

            for lc, rc, gc, bn, cbam in zip(self.local_convs, self.regional_convs, self.global_convs, self.bns, self.cbams):
                x_l = self.dropout(cbam(torch.relu(bn(lc(x_l, edges['local']))), batch)) + x_l
                x_r = self.dropout(cbam(torch.relu(bn(rc(x_r, edges['regional']))), batch)) + x_r
                x_g = self.dropout(cbam(torch.relu(bn(gc(x_g, edges['global']))), batch)) + x_g

            # Cross-scale fusion. Each node is a length-3 sequence (local,
            # regional, global) and nodes never attend to each other, so every
            # node of every graph goes through a single attention call. Rows stay
            # aligned with `batch` even if it is not sorted.
            scales = torch.stack([x_l, x_r, x_g], 1)
            attn_out, _ = self.cross_attn(scales, scales, scales)
            x_fused = self.norm(attn_out + scales).mean(1)

            x_mean = global_mean_pool(x_fused, batch)
            x_max = global_max_pool(x_fused, batch)

            # Attention pooling with the softmax taken *within each graph*. A
            # softmax over all nodes in the batch would make one sample's
            # embedding depend on the other samples in the batch. Scores are
            # shifted by their per-graph max for numerical stability. The ratio
            # of per-graph means equals the ratio of per-graph sums, i.e. the
            # softmax-weighted sum of the graph's node features. Every non-empty
            # graph has a weight of exactly 1 at its max node, so the clamp only
            # affects graph ids without nodes.
            scores = self.node_attn(x_fused)
            scores = scores - global_max_pool(scores, batch)[batch]
            weights = torch.exp(scores)
            denom = global_mean_pool(weights, batch).clamp_min(torch.finfo(weights.dtype).tiny)
            x_attn = global_mean_pool(x_fused * weights, batch) / denom

            return self.output(torch.cat([x_mean, x_max, x_attn], 1))
else:
    class MultiScaleGCN(nn.Module):
        """Attention-only fallback encoder used when the graph library is unavailable.

        Node features are projected to ``hidden`` and refined by ``layers``
        residual self-attention blocks. They are then averaged over nodes and
        projected to ``out_dim``. Graph edges are ignored.

        Args:
            in_dim: per-node input feature size.
            hidden: attention width.
            out_dim: size of the returned embedding.
            layers: number of self-attention blocks.
            dropout: dropout probability used throughout.
        """

        def __init__(self, in_dim=3, hidden=128, out_dim=256, layers=3, dropout=0.3):
            super().__init__()
            self.input_proj = nn.Linear(in_dim, hidden)
            self.attns = nn.ModuleList([
                nn.MultiheadAttention(hidden, 4, dropout=dropout, batch_first=True) for _ in range(layers)
            ])
            self.norms = nn.ModuleList([nn.LayerNorm(hidden) for _ in range(layers)])
            self.output = nn.Sequential(
                nn.Linear(hidden, out_dim), nn.BatchNorm1d(out_dim), nn.ReLU(), nn.Dropout(dropout),
                nn.Linear(out_dim, out_dim), nn.BatchNorm1d(out_dim), nn.ReLU()
            )
            self.dropout = nn.Dropout(dropout)

        def forward(self, x, edges=None, batch=None):
            """Encode dense node features; ``edges`` and ``batch`` are unused.

            A 2-D ``x`` is treated as a single graph (a leading batch axis of
            size 1 is added). Presumably ``x`` is (batch, num_nodes, in_dim)
            otherwise — TODO confirm against callers.
            """
            if x.dim() == 2:
                x = x.unsqueeze(0)
            x = torch.relu(self.input_proj(x))
            for attn, norm in zip(self.attns, self.norms):
                res = x
                x, _ = attn(x, x, x)
                x = norm(self.dropout(x) + res)
            return self.output(x.mean(1))

# ============================================================================
# MAIN MODEL
# ============================================================================

class HybridModel(nn.Module):
    """Multi-task BMI model fusing image, tabular and (optionally) landmark-graph features.

    Branches:
      * image: ImageNet-pretrained ResNet-18 without its classifier. All but
        the last 20 parameter tensors are frozen. A 512-d projection follows.
      * tabular: 3-layer MLP to 128-d.
      * graph (when ``use_gcn`` and ``num_landmarks > 0``): MultiScaleGCN to 256-d.
    The concatenated features go through a residual self-attention + LayerNorm
    block and a shared MLP. Four heads then predict BMI, age, sex (2 logits)
    and BMI category (4 logits). ``log_vars`` holds one learnable log-variance
    per task for uncertainty-weighted loss.

    Args:
        num_features: number of tabular input features.
        num_landmarks: number of graph nodes per sample; 0 disables the graph branch.
        landmark_dim: per-landmark feature size.
        dropout: base dropout rate (scaled down in deeper layers).
        use_gcn: whether to use the graph branch (also requires num_landmarks > 0).
    """

    def __init__(self, num_features, num_landmarks=0, landmark_dim=3, dropout=0.3, use_gcn=True):
        super().__init__()
        self.use_gcn = use_gcn and num_landmarks > 0

        resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        for p in list(resnet.parameters())[:-20]:
            p.requires_grad = False
        self.img_backbone = nn.Sequential(*list(resnet.children())[:-1])
        self.img_proj = nn.Sequential(
            nn.Linear(512, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(dropout)
        )

        self.tab_net = nn.Sequential(
            nn.Linear(num_features, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(dropout * 0.7),
            nn.Linear(256, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(dropout * 0.5)
        )

        gcn_dim = 256 if self.use_gcn else 0
        if self.use_gcn:
            self.gcn = MultiScaleGCN(landmark_dim, 128, gcn_dim, 3, dropout)

        combined = 512 + 128 + gcn_dim
        self.fusion_attn = nn.MultiheadAttention(combined, 8, dropout=dropout, batch_first=True)
        self.fusion_norm = nn.LayerNorm(combined)

        self.shared = nn.Sequential(
            nn.Linear(combined, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(dropout * 0.7)
        )

        self.bmi_head = nn.Sequential(
            nn.Linear(256, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(dropout * 0.5),
            nn.Linear(128, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Linear(64, 1)
        )
        self.age_head = nn.Sequential(
            nn.Linear(256, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(dropout * 0.5), nn.Linear(128, 1)
        )
        self.sex_head = nn.Sequential(
            nn.Linear(256, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(dropout * 0.5), nn.Linear(128, 2)
        )
        self.cat_head = nn.Sequential(
            nn.Linear(256, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(dropout * 0.5), nn.Linear(128, 4)
        )

        self.log_vars = nn.Parameter(torch.zeros(4))
        self._init_weights()

    def _init_weights(self):
        """Re-initialise every nn.Linear (Kaiming-normal weights, zero bias) and every
        BatchNorm1d / LayerNorm (unit weight, zero bias) in the model.

        This also covers Linear layers inside nested modules (e.g. the attention
        output projections). The ResNet convolutions and BatchNorm2d layers are
        not matched and keep their pretrained values.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm1d, nn.LayerNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def _encode_graph(self, graph_features, edges, bs):
        """Run the landmark-graph encoder on a (bs, num_nodes, dim) tensor.

        With the graph library available, the per-sample template graph in
        ``edges`` is replicated ``bs`` times (node ids offset by
        ``i * num_nodes``) and fed to the GCN. Otherwise the attention-only
        fallback consumes the dense tensor directly.
        """
        if not GNN_AVAILABLE:
            return self.gcn(graph_features)
        if edges is None:
            raise ValueError("HybridModel: the GCN branch needs graph edges "
                             "('local', 'regional', 'global') but edges is None")

        num_nodes = graph_features.size(1)
        x = graph_features.reshape(-1, graph_features.size(-1))
        batch_idx = torch.arange(bs, device=x.device).repeat_interleave(num_nodes)

        # (2, 1, E) + (1, bs, 1) -> (2, bs, E) -> (2, bs * E): graph 0's edges
        # first, then graph 1's, each shifted onto its own block of rows in x.
        offsets = (torch.arange(bs, device=x.device) * num_nodes).view(1, bs, 1)
        edges_batch = {
            scale: (edges[scale].to(x.device).unsqueeze(1) + offsets).reshape(2, -1)
            for scale in ('local', 'regional', 'global')
        }
        return self.gcn(x, edges_batch, batch_idx)

    def forward(self, features, image, graph_features=None, edges=None):
        """Predict all four tasks.

        Args:
            features: tabular features, one row per sample.
            image: image batch for the ResNet backbone.
            graph_features: per-sample landmark features (presumably
                (batch, num_landmarks, landmark_dim) — TODO confirm). Required
                when the graph branch is enabled.
            edges: template edge_index dict for one sample's landmark graph.

        Returns:
            dict with 'bmi' and 'age' (shape (batch,)), and 'sex' and
            'bmi_category' logits.

        Raises:
            ValueError: if the graph branch is enabled but its inputs are missing.
        """
        bs = features.size(0)

        img_feat = self.img_proj(self.img_backbone(image).flatten(1))
        tab_feat = self.tab_net(features)

        parts = [img_feat, tab_feat]
        if self.use_gcn:
            if graph_features is None:
                # Without this check the width mismatch only surfaces as an
                # opaque shape error inside fusion_attn.
                raise ValueError("HybridModel was built with the graph branch enabled "
                                 "but graph_features is None")
            parts.append(self._encode_graph(graph_features, edges, bs))
        combined = torch.cat(parts, 1)

        # Self-attention over a length-1 sequence, used as a residual fusion block.
        combined_u = combined.unsqueeze(1)
        attn_out, _ = self.fusion_attn(combined_u, combined_u, combined_u)
        fused = self.fusion_norm(attn_out.squeeze(1) + combined)

        shared = self.shared(fused)

        return {
            'bmi': self.bmi_head(shared).squeeze(-1),
            'age': self.age_head(shared).squeeze(-1),
            'sex': self.sex_head(shared),
            'bmi_category': self.cat_head(shared)
        }

# ============================================================================
# LOSS & TRAINING UTILITIES
# ============================================================================

class MultiTaskLoss(nn.Module):
    """Uncertainty-weighted sum of the BMI, age, sex and BMI-category losses.

    Each task loss L_i is scaled by exp(-s_i) and regularised by +s_i, where
    s_i = log_vars[i]. BMI and age use Smooth-L1; sex and category use
    cross-entropy.
    """

    def __init__(self):
        super().__init__()
        self.smooth_l1 = nn.SmoothL1Loss()
        self.ce = nn.CrossEntropyLoss()

    def forward(self, preds, targets, log_vars):
        """Return (total loss tensor, dict of per-task and total loss floats).

        ``preds``/``targets`` are dicts keyed 'bmi', 'age', 'sex' and
        'bmi_category'; ``log_vars`` holds one log-variance per task in that
        order.
        """
        task_losses = (
            self.smooth_l1(preds['bmi'], targets['bmi']),
            self.smooth_l1(preds['age'], targets['age']),
            self.ce(preds['sex'], targets['sex']),
            self.ce(preds['bmi_category'], targets['bmi_category']),
        )

        # Accumulate in the same left-to-right order as a flat
        # `w0*L0 + s0 + w1*L1 + s1 + ...` expression.
        total = 0
        for idx, task_loss in enumerate(task_losses):
            total = total + torch.exp(-log_vars[idx]) * task_loss
            total = total + log_vars[idx]

        bmi_loss, age_loss, sex_loss, cat_loss = task_losses
        summary = {
            'bmi': bmi_loss.item(),
            'age': age_loss.item(),
            'sex': sex_loss.item(),
            'cat': cat_loss.item(),
            'total': total.item(),
        }
        return total, summary

class EarlyStopping:
    """Track a lower-is-better score and flag when it stops improving.

    A call counts as an improvement when it is the first call or the score
    beats ``best_score`` by more than ``min_delta``. On improvement a CPU copy
    of the model's state_dict is kept in ``best_state`` and ``counter`` resets.
    Otherwise ``counter`` grows. ``early_stop`` becomes True once ``counter``
    reaches ``patience``.

    Args:
        patience: consecutive non-improving calls tolerated.
        min_delta: minimum decrease that counts as an improvement.
    """

    def __init__(self, patience=7, min_delta=0.0001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.best_state = None

    def __call__(self, score, model):
        """Record ``score`` for the current epoch and snapshot ``model`` if it is the best so far."""
        improved = self.best_score is None or score < self.best_score - self.min_delta
        if not improved:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return

        self.best_score = score
        snapshot = {}
        for name, tensor in model.state_dict().items():
            snapshot[name] = tensor.cpu().clone()
        self.best_state = snapshot
        self.counter = 0

def train_epoch(model, loader, criterion, optimizer, device, edges=None, scaler=None):
    """Train ``model`` for one pass over ``loader``.

    Uses CUDA mixed precision (autocast + GradScaler) on CUDA devices. Gradients
    are clipped to norm 1.0 before every optimizer step.

    Args:
        model: module taking (features, image, graph_features, edges) and
            exposing ``log_vars`` for ``criterion``.
        loader: iterable of batch dicts ('features', 'image', 'graph_features',
            'target', 'age_target', 'sex_target', 'bmi_category').
        criterion: callable (preds, targets, log_vars) -> (loss, dict with
            'total' and 'bmi' floats).
        optimizer: optimizer over the trainable parameters.
        device: torch.device to run on.
        edges: optional graph edges forwarded to the model.
        scaler: optional GradScaler. On CUDA, when omitted, one scaler is
            cached on ``optimizer`` and reused across epochs. A fresh scaler
            per epoch restarts from its initial loss scale every epoch, which
            can cause skipped optimizer steps while it re-calibrates.

    Returns:
        (mean total loss per batch, mean BMI loss per batch); (0.0, 0.0) for an
        empty loader.
    """
    model.train()
    total_loss, total_bmi = 0.0, 0.0
    num_batches = 0

    if scaler is None and device.type == 'cuda':
        scaler = getattr(optimizer, '_face2bmi_grad_scaler', None)
        if scaler is None:
            scaler = torch.cuda.amp.GradScaler()
            optimizer._face2bmi_grad_scaler = scaler

    for batch in loader:
        feat = batch['features'].to(device, non_blocking=True)
        img = batch['image'].to(device, non_blocking=True)
        graph = batch['graph_features'].to(device, non_blocking=True)
        targets = {'bmi': batch['target'].to(device, non_blocking=True),
                   'age': batch['age_target'].to(device, non_blocking=True),
                   'sex': batch['sex_target'].to(device, non_blocking=True),
                   'bmi_category': batch['bmi_category'].to(device, non_blocking=True)}

        optimizer.zero_grad(set_to_none=True)

        if scaler is not None:
            with torch.cuda.amp.autocast():
                preds = model(feat, img, graph, edges)
                loss, losses = criterion(preds, targets, model.log_vars)
            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            preds = model(feat, img, graph, edges)
            loss, losses = criterion(preds, targets, model.log_vars)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

        total_loss += losses['total']
        total_bmi += losses['bmi']
        num_batches += 1

    # Avoid ZeroDivisionError on an empty loader.
    num_batches = max(num_batches, 1)
    return total_loss / num_batches, total_bmi / num_batches

@torch.no_grad()
def validate_epoch(model, loader, criterion, device, edges=None):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Runs under CUDA autocast on CUDA devices, mirroring ``train_epoch``.

    Returns:
        (mean total loss per batch, mean BMI loss per batch, R2, MAE,
         list of BMI predictions, list of BMI targets)
    """
    model.eval()
    total_loss, total_bmi = 0, 0
    all_preds, all_targets = [], []

    for batch in loader:
        feat = batch['features'].to(device, non_blocking=True)
        img = batch['image'].to(device, non_blocking=True)
        graph = batch['graph_features'].to(device, non_blocking=True)
        targets = {'bmi': batch['target'].to(device, non_blocking=True),
                   'age': batch['age_target'].to(device, non_blocking=True),
                   'sex': batch['sex_target'].to(device, non_blocking=True),
                   'bmi_category': batch['bmi_category'].to(device, non_blocking=True)}

        if device.type == 'cuda':
            with torch.cuda.amp.autocast():
                preds = model(feat, img, graph, edges)
                loss, losses = criterion(preds, targets, model.log_vars)
        else:
            preds = model(feat, img, graph, edges)
            loss, losses = criterion(preds, targets, model.log_vars)

        total_loss += losses['total']
        total_bmi += losses['bmi']
        # Under autocast the final Linear of the BMI head emits float16
        # (resolution ~0.016 around BMI 25). Cast to float32 so metrics and
        # saved predictions are not quantised.
        all_preds.extend(preds['bmi'].float().cpu().numpy())
        all_targets.extend(targets['bmi'].cpu().numpy())

    n = len(loader)
    r2 = r2_score(all_targets, all_preds)
    mae = mean_absolute_error(all_targets, all_preds)
    return total_loss / n, total_bmi / n, r2, mae, all_preds, all_targets

# ============================================================================
# MAIN EXECUTION
# ============================================================================

def main():
    """Run stratified 5-fold CV on the training set, then evaluate the best fold on the test set.

    Steps:
      1. Load the raw train/test data and landmark metadata; build the landmark graphs.
      2. For each fold stratified on BMI quintiles: engineer features, fit
         scalers on the fold's training part, and train a HybridModel with
         early stopping on validation BMI loss.
      3. Pick the fold with the highest validation R2 and evaluate it on the test set.
      4. Write the checkpoint, predictions CSV, graph edges and a results
         pickle to OUTPUT_DIR.

    Returns:
        dict with CV results, test targets/predictions, test metrics and residuals.

    Raises:
        ValueError: if the feature columns engineered for the test set differ
            from those the selected fold's model was trained on.
    """
    device = get_device()

    batch_size = 128 if device.type == 'cuda' else 32
    num_workers = 4 if device.type == 'cuda' else 2

    num_epochs = 15
    patience = 8

    print("\n" + "="*60)
    print("Loading Raw Data")
    print("="*60)
    train_df_raw, test_df_raw = load_raw_data()

    landmark_info = get_landmark_info(train_df_raw)
    num_landmarks = len(landmark_info)
    landmark_dim = 3 if any(lm[3] is not None for lm in landmark_info) else 2
    print(f"Landmarks: {num_landmarks}, Dim: {landmark_dim}")

    edges = create_graphs(num_landmarks)
    if edges:
        print(f"Graph edges - L:{edges['local'].shape[1]} R:{edges['regional'].shape[1]} G:{edges['global'].shape[1]}")
    # The template graph is the same for every fold and for the test pass,
    # so it is moved to the device once.
    edges_device = {k: v.to(device) for k, v in edges.items()} if edges else None

    # Stratification labels (BMI quintiles) live in a separate Series, not a
    # DataFrame column. A 'bmi_bin' column would reach engineer_features_for_fold
    # for every fold (where a target-derived column could become a feature) but
    # not for the final test-set engineering.
    bmi_bins = pd.qcut(train_df_raw['BMI'], q=5, labels=False, duplicates='drop')
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

    print("\n" + "="*60)
    print("5-Fold Cross Validation (No Data Leakage)")
    print("="*60)

    cv_results = []

    for fold, (train_idx, val_idx) in enumerate(skf.split(train_df_raw, bmi_bins), 1):
        print(f"\n{'='*60}")
        print(f"FOLD {fold}/5")
        print(f"{'='*60}")

        fold_train_raw = train_df_raw.iloc[train_idx].copy()
        fold_val_raw = train_df_raw.iloc[val_idx].copy()
        print(f"Train samples: {len(fold_train_raw)}, Validation samples: {len(fold_val_raw)}")

        fold_train, fold_val, feature_cols = engineer_features_for_fold(
            fold_train_raw, fold_val_raw
        )

        print(f"Number of features: {len(feature_cols)}")

        train_ds = BMIDataset(
            fold_train, feature_cols, landmark_info, IMAGE_DIR,
            is_training=True, augment_mult=3,
            fit_scalers=True
        )

        val_ds = BMIDataset(
            fold_val, feature_cols, landmark_info, IMAGE_DIR,
            is_training=False, augment_mult=1,
            feature_scaler=train_ds.feature_scaler,
            landmark_scaler=train_ds.landmark_scaler,
            fit_scalers=False
        )

        train_loader = DataLoader(
            train_ds, batch_size=batch_size, shuffle=True,
            num_workers=num_workers, pin_memory=device.type=='cuda',
            collate_fn=collate_fn, persistent_workers=False
        )
        val_loader = DataLoader(
            val_ds, batch_size=batch_size, shuffle=False,
            num_workers=num_workers, pin_memory=device.type=='cuda',
            collate_fn=collate_fn, persistent_workers=False
        )

        model = HybridModel(len(feature_cols), num_landmarks, landmark_dim, 0.3, True).to(device)

        optimizer = optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=2e-3, weight_decay=1e-4
        )
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=5, T_mult=2, eta_min=1e-6
        )
        criterion = MultiTaskLoss()
        early_stop = EarlyStopping(patience=patience)

        best_epoch = 0
        print(f"\nStarting training for {num_epochs} epochs...")
        print("-" * 60)

        for epoch in range(num_epochs):
            train_loss, train_bmi = train_epoch(model, train_loader, criterion, optimizer, device, edges_device)
            val_loss, val_bmi, val_r2, val_mae, _, _ = validate_epoch(model, val_loader, criterion, device, edges_device)
            scheduler.step()

            early_stop(val_bmi, model)
            # EarlyStopping resets its counter to 0 exactly when this epoch improved.
            is_best = ""
            if early_stop.counter == 0:
                best_epoch = epoch + 1
                is_best = " *"

            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f"Train Loss: {train_loss:.4f}, BMI Loss: {train_bmi:.4f}")
            print(f"Validation Loss: {val_loss:.4f}, BMI Loss: {val_bmi:.4f}, R2: {val_r2:.4f}, MAE: {val_mae:.2f}{is_best}")
            print("-" * 60)

            if early_stop.early_stop:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

        model.load_state_dict(early_stop.best_state)
        _, _, val_r2, val_mae, preds, targets = validate_epoch(model, val_loader, criterion, device, edges_device)
        val_rmse = np.sqrt(mean_squared_error(targets, preds))

        print(f"\nFold {fold} Best Results (Epoch {best_epoch}):")
        print(f"  R2: {val_r2:.4f}")
        print(f"  MAE: {val_mae:.2f}")
        print(f"  RMSE: {val_rmse:.2f}")

        cv_results.append({
            'fold': fold, 'r2': val_r2, 'mae': val_mae, 'rmse': val_rmse,
            'best_epoch': best_epoch, 'state': early_stop.best_state,
            'feature_scaler': train_ds.feature_scaler,
            'landmark_scaler': train_ds.landmark_scaler,
            'feature_cols': feature_cols
        })

        del model, train_ds, val_ds, train_loader, val_loader
        clear_memory()

    print("\n" + "="*60)
    print("CROSS-VALIDATION SUMMARY")
    print("="*60)
    avg_r2 = np.mean([r['r2'] for r in cv_results])
    avg_mae = np.mean([r['mae'] for r in cv_results])
    avg_rmse = np.mean([r['rmse'] for r in cv_results])
    std_r2 = np.std([r['r2'] for r in cv_results])
    std_mae = np.std([r['mae'] for r in cv_results])
    std_rmse = np.std([r['rmse'] for r in cv_results])

    print(f"\nMetrics across all folds:")
    print(f"  R2:   {avg_r2:.4f} ¬± {std_r2:.4f}")
    print(f"  MAE:  {avg_mae:.2f} ¬± {std_mae:.2f}")
    print(f"  RMSE: {avg_rmse:.2f} ¬± {std_rmse:.2f}")

    print(f"\nPer-fold results:")
    for r in cv_results:
        print(f"  Fold {r['fold']}: R2={r['r2']:.4f}, MAE={r['mae']:.2f}, RMSE={r['rmse']:.2f} (Best Epoch: {r['best_epoch']})")

    best_fold_result = max(cv_results, key=lambda x: x['r2'])
    print(f"\nBest fold: {best_fold_result['fold']} (R2={best_fold_result['r2']:.4f})")

    print("\n" + "="*60)
    print("TEST SET EVALUATION")
    print("="*60)

    _, test_full, final_feature_cols = engineer_features_for_fold(
        train_df_raw, test_df=test_df_raw
    )
    # The checkpoint and scalers come from the best fold. If the test features
    # differ in count or order, loading fails, or worse, the scaler and weights
    # are silently applied to the wrong columns.
    if list(final_feature_cols) != list(best_fold_result['feature_cols']):
        raise ValueError(
            "Feature columns engineered for the test set do not match those used "
            f"to train fold {best_fold_result['fold']}"
        )

    final_model = HybridModel(
        len(final_feature_cols), num_landmarks, landmark_dim, 0.3, True
    ).to(device)
    final_model.load_state_dict(best_fold_result['state'])

    test_ds = BMIDataset(
        test_full, final_feature_cols, landmark_info, IMAGE_DIR,
        is_training=False, augment_mult=1,
        feature_scaler=best_fold_result['feature_scaler'],
        landmark_scaler=best_fold_result['landmark_scaler'],
        fit_scalers=False
    )

    test_loader = DataLoader(
        test_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=device.type=='cuda',
        collate_fn=collate_fn, persistent_workers=False
    )

    criterion = MultiTaskLoss()

    _, _, test_r2, test_mae, test_preds, test_targets = validate_epoch(
        final_model, test_loader, criterion, device, edges_device
    )
    test_rmse = np.sqrt(mean_squared_error(test_targets, test_preds))
    residuals = np.array(test_targets) - np.array(test_preds)

    print(f"\nTest Results:")
    print(f"  R2:   {test_r2:.4f}")
    print(f"  MAE:  {test_mae:.2f}")
    print(f"  RMSE: {test_rmse:.2f}")
    print(f"  Mean Residual: {residuals.mean():.2f} ¬± {residuals.std():.2f}")

    print(f"\nPrediction Accuracy:")
    for t in [1, 2, 3, 5]:
        pct = np.mean(np.abs(residuals) < t) * 100
        print(f"  Within ¬±{t} BMI: {pct:.1f}%")

    print("\n" + "="*60)
    print("SAVING RESULTS")
    print("="*60)

    save_path = os.path.join(OUTPUT_DIR, 'hybrid_model_v2.pth')
    torch.save({
        'model_state': best_fold_result['state'],
        'feature_scaler': best_fold_result['feature_scaler'],
        'landmark_scaler': best_fold_result['landmark_scaler'],
        'feature_cols': final_feature_cols,
        'num_landmarks': num_landmarks,
        'landmark_dim': landmark_dim,
        'graph_edges': edges,
        'cv_results': [{k: v for k, v in r.items() if k not in ['state', 'feature_scaler', 'landmark_scaler']}
                       for r in cv_results],
        'test_results': {'r2': test_r2, 'mae': test_mae, 'rmse': test_rmse}
    }, save_path)
    print(f"‚úì Model saved: {save_path}")

    pred_df = pd.DataFrame({
        'image_filename': test_df_raw['image_filename'].values,
        'actual_bmi': test_targets,
        'predicted_bmi': test_preds,
        'residual': residuals
    })
    pred_path = os.path.join(OUTPUT_DIR, 'predictions_v2.csv')
    pred_df.to_csv(pred_path, index=False)
    print(f"‚úì Predictions saved: {pred_path}")

    edges_path = os.path.join(OUTPUT_DIR, 'graph_edges.pth')
    torch.save(edges, edges_path)
    print(f"‚úì Graph edges saved: {edges_path}")

    results_dict = {
        'cv_results': cv_results,
        'test_targets': test_targets,
        'test_preds': test_preds,
        'test_r2': test_r2,
        'test_mae': test_mae,
        'test_rmse': test_rmse,
        'residuals': residuals
    }

    import pickle
    with open(os.path.join(OUTPUT_DIR, 'results_for_viz.pkl'), 'wb') as f:
        pickle.dump(results_dict, f)

    print(f"‚úì Results saved for visualization")

    del final_model, test_ds, test_loader
    clear_memory()

    return results_dict

if __name__ == '__main__':
    results = main()

In [None]:
# Step 1: Install required libraries
!pip install pillow -q

# Step 2: Import necessary modules
import os
from PIL import Image
import numpy as np

# Step 3: Create output directory
# Use the actual Kaggle input path (this should already exist as a dataset)
input_dir = '/kaggle/input/vip-attribute/data/data'
output_dir = '/kaggle/working/processed'
os.makedirs(output_dir, exist_ok=True)

# Step 4: Check if input directory exists
if not os.path.exists(input_dir):
    print(f"ERROR: Input directory '{input_dir}' does not exist!")
    print("Please check your dataset path.")
else:
    print(f"Input directory found: {input_dir}")

    # Step 5: Letterbox every image into a 224x224 black canvas
    target_size = (224, 224)
    min_size = 200

    processed_count = 0
    skipped_count = 0
    error_count = 0

    # Sorted so the processing order (and any error log) is reproducible.
    image_files = sorted(f for f in os.listdir(input_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg')))
    print(f"Found {len(image_files)} image files")

    for filename in image_files:
        input_path = os.path.join(input_dir, filename)
        try:
            # Context manager closes the source file deterministically;
            # convert() returns an independent in-memory image.
            with Image.open(input_path) as src:
                img = src.convert("RGB")
            width, height = img.size

            # Skip if either dimension is below min_size
            if width < min_size or height < min_size:
                skipped_count += 1
                continue

            # Downscale in place to fit target_size, preserving aspect ratio
            # (thumbnail never upscales)
            img.thumbnail(target_size, Image.LANCZOS)

            # Create a solid black background
            final_img = Image.new("RGB", target_size, (0, 0, 0))

            # Center the image on the black background
            offset = ((target_size[0] - img.size[0]) // 2, (target_size[1] - img.size[1]) // 2)
            final_img.paste(img, offset)

            # Save as JPEG (note: 'x.png' and 'x.jpg' map to the same output name)
            output_filename = os.path.splitext(filename)[0] + '.jpg'
            output_path = os.path.join(output_dir, output_filename)
            final_img.save(output_path, "JPEG", quality=95)

            processed_count += 1

        except Exception as e:
            print(f"Error processing {filename}: {e}")
            error_count += 1

    print(f"\n{'='*50}")
    print(f"Processing complete!")
    print(f"Processed: {processed_count} images")
    print(f"Skipped: {skipped_count} images (too small)")
    print(f"Errors: {error_count} images")
    print(f"Output directory: {output_dir}")
    print(f"{'='*50}")

In [None]:
import pandas as pd

# Raw attribute table (change the filename if needed) and the destination of
# the cleaned copy in the Kaggle working directory.
source_csv = "/kaggle/input/bmi-dataset/test.csv"
output_path = "/kaggle/working/cleaned_dataset.csv"

# Keep every column except the raw height/weight measurements, then write the
# cleaned table without the pandas index.
df = pd.read_csv(source_csv)
df = df.drop(columns=["height", "weight"])
df.to_csv(output_path, index=False)


In [None]:
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

# Set memory-efficient configurations
# NOTE(review): torch and numpy are already imported above (and possibly in
# earlier cells), so these variables likely no longer change their thread
# pools — confirm; they may only affect libraries initialised later.
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'

# ============================================================================
# CONFIGURATION
# ============================================================================

DATA_PATH = '/kaggle/working/cleaned_dataset.csv'
IMAGE_DIR = '/kaggle/input/bmi-dataset/ROI'
OUTPUT_DIR = './shap_analysis_lightweight'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ============================================================================
# DATA LOADING AND PREPARATION
# ============================================================================

print("="*80)
print("LIGHTWEIGHT SHAP ANALYSIS FOR BMI PREDICTION")
print("="*80)

# Load the dataset
print(f"\n1. Loading dataset from: {DATA_PATH}")
df = pd.read_csv(DATA_PATH)
print(f"   ‚úì Loaded {len(df)} samples with {len(df.columns)} columns")

if 'BMI' in df.columns:
    print(f"   ‚úì BMI column found: {df['BMI'].min():.1f} to {df['BMI'].max():.1f}")
else:
    print("   ‚ö† BMI column not found, creating synthetic BMI")
    df['BMI'] = np.random.uniform(18, 35, len(df))

# ============================================================================
# FEATURE ENGINEERING
# ============================================================================

print("\n3. Preparing features...")
exclude_cols = ['id', 'image_filename', 'image', 'BMI', 'ID', 'filename']
feature_cols = [col for col in df.columns 
                if col not in exclude_cols 
                and pd.api.types.is_numeric_dtype(df[col])]

print(f"   ‚úì Selected {len(feature_cols)} numeric features")

MAX_SAMPLES = 100
if len(df) > MAX_SAMPLES:
    df_sample = df.sample(MAX_SAMPLES, random_state=42)
    print(f"   ‚ö† Using subset of {MAX_SAMPLES} samples")
else:
    df_sample = df

X = df_sample[feature_cols].fillna(0).astype(np.float32)
y = df_sample['BMI'].values

print(f"   ‚úì X shape: {X.shape}, y range: {y.min():.1f} to {y.max():.1f}")

# ============================================================================
# TRAIN MODEL
# ============================================================================

print("\n4. Training model...")
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error, r2_score

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

model = RandomForestRegressor(n_estimators=50, max_depth=10, min_samples_split=5, random_state=42, n_jobs=-1)
model.fit(X_train_scaled, y_train)

y_pred = model.predict(X_test_scaled)
mae = mean_absolute_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

print(f"   ‚úì Model - MAE: {mae:.4f}, R¬≤: {r2:.4f}")

# ============================================================================
# SHAP ANALYSIS
# ============================================================================

print("\n5. SHAP analysis with FIXED plots...")

try:
    import shap

    explainer = shap.TreeExplainer(model)
    shap_sample_size = min(50, len(X_test_scaled))
    X_sample = X_test_scaled[:shap_sample_size]
    shap_values = explainer.shap_values(X_sample)

    print(f"   ‚úì SHAP values for {shap_sample_size} samples")

    # Beeswarm summary plot
    plt.figure(figsize=(14, 10))
    shap.summary_plot(shap_values, X_sample, feature_names=feature_cols, 
                     show=False, max_display=15)
    plt.title('SHAP Feature Importance for BMI Prediction', fontsize=16, fontweight='bold', pad=20)
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, 'shap_summary.png'), dpi=150, bbox_inches='tight')
    plt.close()

    # Bar plot of mean |SHAP| with enlarged margins so the x-axis label is not cut off
    plt.figure(figsize=(16, 12))
    shap.summary_plot(shap_values, X_sample, feature_names=feature_cols,
                     plot_type="bar", show=False, max_display=20)

    plt.title('Mean Absolute SHAP Values (Impact on Model Output)', 
              fontsize=18, fontweight='bold', pad=25)
    plt.xlabel('Mean |SHAP value| (average impact on model output)', fontsize=14)
    plt.ylabel('Features', fontsize=14)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=11)

    plt.subplots_adjust(left=0.18, right=0.98, top=0.90, bottom=0.15, hspace=0.3)

    plt.savefig(os.path.join(OUTPUT_DIR, 'shap_bar.png'), 
                dpi=200, bbox_inches='tight', facecolor='white')
    plt.close()

    print("   ‚úì FIXED bar plot saved - full x-axis label visible!")

    # Feature importance table. Rank is assigned after sorting so it is the
    # feature's position by importance (1 = most important). The previous
    # np.argsort(-x) + 1 gave the index of the k-th most important feature,
    # not each feature's own rank.
    mean_abs_shap = np.abs(shap_values).mean(axis=0)
    importance_df = pd.DataFrame({
        'Feature': feature_cols,
        'Mean_abs_SHAP': mean_abs_shap,
    }).sort_values('Mean_abs_SHAP', ascending=False)
    importance_df['Rank'] = np.arange(1, len(importance_df) + 1)

    importance_df.to_csv(os.path.join(OUTPUT_DIR, 'feature_importance.csv'), index=False)

    # Horizontal bar chart of the top features (most important at the top)
    plt.figure(figsize=(12, 10))
    top_n = min(15, len(importance_df))
    top_features = importance_df.head(top_n)

    bars = plt.barh(range(top_n), top_features['Mean_abs_SHAP'][::-1], 
                    color='steelblue', alpha=0.8, edgecolor='navy', linewidth=0.5)
    plt.yticks(range(top_n), top_features['Feature'][::-1], fontsize=11)
    plt.xlabel('Mean Absolute SHAP Value', fontsize=14, fontweight='bold')
    plt.title(f'Top {top_n} Most Important Features for BMI Prediction', 
              fontsize=16, fontweight='bold')
    plt.grid(axis='x', alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, 'top_features.png'), dpi=200, bbox_inches='tight')
    plt.close()

    print(f"\n‚úÖ SHAP analysis COMPLETE!")
    print(f"üìÅ Files saved to: {OUTPUT_DIR}")
    print("\nüèÜ TOP 5 FEATURES:")
    for i, row in enumerate(importance_df.head(5).itertuples(), 1):
        print(f"{i:2d}. {row.Feature:35s} | SHAP: {row.Mean_abs_SHAP:.4f}")

except Exception as e:
    print(f"‚ö† SHAP failed: {e}")
    print("Using model feature importance...")

    feature_importance = pd.DataFrame({
        'Feature': feature_cols,
        'Importance': model.feature_importances_
    }).sort_values('Importance', ascending=False)

    feature_importance.to_csv(os.path.join(OUTPUT_DIR, 'feature_importance.csv'), index=False)

    plt.figure(figsize=(12, 10))
    # Never plot more bars than there are features (range(15) would mismatch).
    top_n = min(15, len(feature_importance))
    plt.barh(range(top_n), feature_importance['Importance'].head(top_n)[::-1], 
             color='steelblue', alpha=0.8)
    plt.yticks(range(top_n), feature_importance['Feature'].head(top_n)[::-1])
    plt.xlabel('Feature Importance')
    plt.title('Top Features (Model Built-in Importance)')
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, 'feature_importance.png'), dpi=200, bbox_inches='tight')
    plt.close()

print(f"\nüéâ Analysis complete! Check '{OUTPUT_DIR}' for all plots and CSV files.")


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split, KFold
from sklearn.preprocessing import StandardScaler, RobustScaler
from sklearn.metrics import mean_absolute_error, r2_score, mean_squared_error
from sklearn.ensemble import RandomForestRegressor
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import warnings
import os
warnings.filterwarnings('ignore')

# ============================================================================
# FAST ABLATION CONFIGURATION
# ============================================================================

# üöÄ SPEED OPTIMIZATIONS
# FAST_MODE subsamples the training rows and shortens training so the whole
# ablation grid finishes quickly. Set it to False for full training.
FAST_MODE = True
SAMPLE_SIZE = 5000                    # rows sampled from train_df in fast mode
NUM_EPOCHS = 10 if FAST_MODE else 15  # max epochs per configuration
PATIENCE = 3 if FAST_MODE else 5      # early-stopping patience (epochs)

if FAST_MODE:
    print("="*80)
    print("‚ö° FAST ABLATION MODE ENABLED")
    print("="*80)
    print(f"üìä Sample Size: {SAMPLE_SIZE} (vs full dataset)")
    print(f"üîÑ Max Epochs: {NUM_EPOCHS} (vs 15)")
    print(f"‚è±Ô∏è  Patience: {PATIENCE} (vs 5)")
    print(f"üéØ Expected Time: ~5-15 minutes (vs 12 hours)")
    print("="*80 + "\n")
else:
    print(f"Full ablation mode: all rows, up to {NUM_EPOCHS} epochs, patience {PATIENCE}\n")

# ============================================================================
# ABLATION STUDY CONFIGURATION
# ============================================================================

# Reference configuration. Every ablation below copies it and overrides only
# the component(s) under study, so each run differs from FULL_MODEL in exactly
# the settings listed in its override.
#
# Keys and where they are consumed in this cell:
#   use_images / use_tabular  - which AblationModel branches are built
#   use_augmentation          - random image transforms for training images
#   use_derived_features      - prepare_features(include_derived=...)
#   use_dropout, dropout_rate - Dropout in the tabular and fusion MLPs
#   use_batch_norm            - BatchNorm1d in the tabular and fusion MLPs
#   use_pretrained            - ImageNet weights for the ResNet-18 backbone
#   augment_multiplier        - training rows are repeated this many times per
#                               epoch; kept identical across configurations so
#                               every run sees the same number of batches
#   feature_noise             - std of Gaussian noise added to the repeated
#                               copies of the (scaled) tabular features
_FULL_MODEL_CONFIG = {
    'use_images': True,
    'use_tabular': True,
    'use_augmentation': True,
    'use_derived_features': True,
    'use_dropout': True,
    'dropout_rate': 0.3,
    'use_batch_norm': True,
    'use_pretrained': True,
    'augment_multiplier': 3,
    'feature_noise': 0.02,
    'description': 'Full multimodal model with all features',
}


def _ablate(description, **overrides):
    """Return a fresh copy of the full-model config with *overrides* applied.

    Raises:
        KeyError: if an override names a key the full config does not have
            (guards against typos silently creating unused settings).
    """
    unknown = set(overrides) - set(_FULL_MODEL_CONFIG)
    if unknown:
        raise KeyError(f"Unknown ablation keys: {sorted(unknown)}")
    return {**_FULL_MODEL_CONFIG, **overrides, 'description': description}


ABLATION_CONFIGS = {
    'FULL_MODEL': dict(_FULL_MODEL_CONFIG),
    'NO_IMAGES': _ablate('Tabular features only (no images)',
                         use_images=False),
    'NO_TABULAR': _ablate('Images only (no tabular features)',
                          use_tabular=False, use_derived_features=False,
                          feature_noise=0.0),
    'NO_AUGMENTATION': _ablate('No data augmentation',
                               use_augmentation=False, feature_noise=0.0),
    # augment_multiplier stays at the FULL_MODEL value so that this ablation
    # changes only the feature set, not the amount of training per epoch.
    'NO_DERIVED_FEATURES': _ablate('No engineered/derived features',
                                   use_derived_features=False),
    'NO_DROPOUT': _ablate('No dropout regularization',
                          use_dropout=False, dropout_rate=0.0),
    'HIGH_DROPOUT': _ablate('High dropout (0.5 instead of 0.3)',
                            dropout_rate=0.5),
    'NO_BATCH_NORM': _ablate('No batch normalization',
                             use_batch_norm=False),
    'NO_PRETRAINED': _ablate('Random initialization (no pretrained weights)',
                             use_pretrained=False),
    'MINIMAL_MODEL': _ablate('Minimal model (basic features, no regularization)',
                             use_augmentation=False, use_derived_features=False,
                             use_dropout=False, dropout_rate=0.0,
                             use_batch_norm=False, use_pretrained=False,
                             feature_noise=0.0),
}

# ============================================================================
# SETUP
# ============================================================================

def set_seed(seed=42):
    """Seed every random number generator the ablation relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and,
    when available, all CUDA devices) so repeated runs are reproducible.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random  # local import keeps this cell self-contained

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(42)

# Paths (adjust these to your environment)
DATA_DIR = '/kaggle/input/bmi-dataset'
TRAIN_CSV = '/kaggle/input/bmi-dataset/train.csv'
TEST_CSV = '/kaggle/input/bmi-dataset/test.csv'
IMAGE_DIR = '/kaggle/input/bmi-dataset/ROI'

print("üî¨ ABLATION STUDY - BMI PREDICTION")
print("="*80)

# Load data. Fall back to synthetic data only when the CSVs cannot be read
# (missing/unreadable file -> OSError; empty or malformed CSV -> ValueError
# subclasses from pandas). Anything else, including KeyboardInterrupt,
# propagates instead of being silently replaced by fake data.
try:
    train_df = pd.read_csv(TRAIN_CSV)
    test_df = pd.read_csv(TEST_CSV)
    print(f"‚úÖ Data loaded: Train={train_df.shape}, Test={test_df.shape}")
except (OSError, ValueError) as err:
    print("‚ùå Error loading data. Using synthetic data for demonstration.")
    print(f"   Reason: {err}")
    # Create synthetic data
    n_train, n_test = 1000, 200
    train_df = pd.DataFrame({
        'BMI': np.random.normal(25, 5, n_train),
        'age': np.random.randint(18, 80, n_train),
        'sex_encoded': np.random.randint(0, 2, n_train),
        'face_height': np.random.normal(100, 10, n_train),
        'face_width_cheeks': np.random.normal(80, 8, n_train),
        'face_width_jaw': np.random.normal(75, 8, n_train),
        'left_eye_width': np.random.normal(20, 2, n_train),
        'right_eye_width': np.random.normal(20, 2, n_train),
        'nose_length': np.random.normal(30, 3, n_train),
        'nose_width': np.random.normal(25, 2.5, n_train),
        'image_filename': [f'img_{i}.jpg' for i in range(n_train)]
    })
    # NOTE: the synthetic "test" rows are drawn from the synthetic training
    # rows, so they are only good for smoke-testing the pipeline.
    test_df = train_df.sample(n_test, random_state=42)
    IMAGE_DIR = None  # no images exist for the synthetic rows

# üöÄ SAMPLE DATA FOR SPEED
# In fast mode, draw a fixed-seed subsample of the training rows up front so
# every ablation configuration trains on the same subset.
if FAST_MODE and len(train_df) > SAMPLE_SIZE:
    n_available = len(train_df)
    print(f"\n‚ö° Fast Mode: Sampling {SAMPLE_SIZE} rows from {n_available}")
    subsample = train_df.sample(SAMPLE_SIZE, random_state=42)
    train_df = subsample.reset_index(drop=True)
    print(f"‚úÖ Using {len(train_df)} samples for ablation study")

# Drop height/weight if present. These columns presumably determine BMI
# directly, so keeping them would leak the target into the features.
# errors='ignore' drops them from each frame independently, so a column that
# exists in only one of the frames no longer raises KeyError.
leaky_cols = ['height', 'weight']
train_df = train_df.drop(columns=leaky_cols, errors='ignore')
test_df = test_df.drop(columns=leaky_cols, errors='ignore')

# ============================================================================
# FEATURE ENGINEERING
# ============================================================================

def prepare_features(df, include_derived=True):
    """Select the tabular feature columns and optionally add engineered ones.

    Args:
        df: Input frame. It is copied and never modified in place.
        include_derived: When True, also add ``fwhr``, ``eye_symmetry`` and
            ``face_area``, each only if its source columns are present.

    Returns:
        ``(engineered_df, feature_names)``: the copied frame (plus any derived
        columns) and the ordered list of feature columns to feed the model.
    """
    out = df.copy()
    present = set(out.columns)

    candidate_cols = ('face_height', 'face_width_cheeks', 'face_width_jaw',
                      'left_eye_width', 'right_eye_width', 'nose_length',
                      'nose_width', 'age', 'sex_encoded')
    feature_names = [col for col in candidate_cols if col in present]

    if not include_derived:
        return out, feature_names

    eps = 1e-8  # lower bound on denominators to avoid division by zero
    has_face_dims = {'face_height', 'face_width_cheeks'} <= present
    has_eyes = {'left_eye_width', 'right_eye_width'} <= present

    if has_face_dims:
        # Width (at the cheeks) divided by face height.
        ratio = out['face_width_cheeks'] / out['face_height'].clip(lower=eps)
        out['fwhr'] = ratio.fillna(0)
        feature_names.append('fwhr')

    if has_eyes:
        left, right = out['left_eye_width'], out['right_eye_width']
        asymmetry = (left - right).abs() / (left + right).clip(lower=eps)
        out['eye_symmetry'] = (1 - asymmetry).fillna(0)
        feature_names.append('eye_symmetry')

    if has_face_dims:
        area = out['face_width_cheeks'] * out['face_height']
        out['face_area'] = area.fillna(0)
        feature_names.append('face_area')

    return out, feature_names

# ============================================================================
# DATASET CLASS
# ============================================================================

class AblationDataset(Dataset):
    """Dataset for ablation study."""
    def __init__(self, df, feature_cols, image_dir, config, is_training=True):
        self.df = df.reset_index(drop=True)
        self.feature_cols = feature_cols
        self.image_dir = image_dir
        self.config = config
        self.is_training = is_training
        
        if config['use_tabular']:
            self.features = df[feature_cols].values.astype(np.float32)
        self.targets = df['BMI'].values.astype(np.float32)
        
        if 'image_filename' in df.columns:
            self.filenames = df['image_filename'].tolist()
        else:
            self.filenames = [f'img_{i}.jpg' for i in range(len(df))]
        
        # Setup transforms
        if config['use_images']:
            if is_training and config['use_augmentation']:
                self.transform = transforms.Compose([
                    transforms.Resize((256, 256)),
                    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.ColorJitter(brightness=0.2, contrast=0.2),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                       std=[0.229, 0.224, 0.225])
                ])
            else:
                self.transform = transforms.Compose([
                    transforms.Resize((224, 224)),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                       std=[0.229, 0.224, 0.225])
                ])
    
    def __len__(self):
        if self.is_training and self.config['augment_multiplier'] > 1:
            return len(self.df) * self.config['augment_multiplier']
        return len(self.df)
    
    def __getitem__(self, idx):
        original_idx = idx % len(self.df)
        is_augmented = idx >= len(self.df)
        
        target = torch.tensor(self.targets[original_idx], dtype=torch.float32)
        
        # Tabular features
        if self.config['use_tabular']:
            features = torch.tensor(self.features[original_idx], dtype=torch.float32)
            if self.is_training and is_augmented and self.config['feature_noise'] > 0:
                noise = torch.randn_like(features) * self.config['feature_noise']
                features = features + noise
        else:
            features = torch.zeros(1)  # Dummy
        
        # Image
        if self.config['use_images'] and self.image_dir is not None:
            filename = self.filenames[original_idx]
            img_path = os.path.join(self.image_dir, filename)
            try:
                with Image.open(img_path) as img:
                    if img.mode != 'RGB':
                        img = img.convert('RGB')
                    image = self.transform(img)
            except:
                image = torch.zeros(3, 224, 224)
        else:
            # Dummy image
            image = torch.randn(3, 224, 224) * 0.1
        
        return {
            'features': features,
            'image': image,
            'target': target
        }

# ============================================================================
# MODEL ARCHITECTURE
# ============================================================================

class AblationModel(nn.Module):
    """Configurable image + tabular regressor used by the ablation study.

    Branches, each built only if enabled in *config*:
      * image: ResNet-18 with its last child module removed, all parameters
        frozen, giving a 512-d vector per sample;
      * tabular: MLP ``num_features -> 128 -> 64``.
    The enabled branch outputs are concatenated, passed through a fusion MLP
    ``-> 128 -> 64`` and a linear head that yields one value per sample.

    NOTE(review): the backbone is frozen for ``use_pretrained=False`` as well,
    so NO_PRETRAINED measures fixed *random* image features. ``model.train()``
    also leaves the frozen backbone's BatchNorm layers in training mode, so
    their running statistics still update. Confirm both are intended.
    """

    def __init__(self, num_features, config):
        """
        Args:
            num_features: Width of the tabular input (ignored when
                ``config['use_tabular']`` is False).
            config: One entry of ``ABLATION_CONFIGS``.

        Raises:
            ValueError: if neither images nor tabular features are enabled.
        """
        super().__init__()
        self.config = config

        # Image branch
        if config['use_images']:
            weights = (models.ResNet18_Weights.IMAGENET1K_V1
                       if config['use_pretrained'] else None)
            resnet = models.resnet18(weights=weights)
            for param in resnet.parameters():
                param.requires_grad = False
            # Drop the final classification layer; keep everything before it.
            self.image_features = nn.Sequential(*list(resnet.children())[:-1])
            image_dim = 512
        else:
            image_dim = 0

        # Tabular branch
        if config['use_tabular']:
            self.tabular_features = self._mlp(num_features, config)
            tabular_dim = 64
        else:
            tabular_dim = 0

        # Fusion
        combined_dim = image_dim + tabular_dim
        if combined_dim == 0:
            raise ValueError("Must use at least images or tabular features!")

        self.fusion = self._mlp(combined_dim, config)
        self.output = nn.Linear(64, 1)

    @staticmethod
    def _mlp(in_dim, config):
        """Build ``in_dim -> 128 -> 64`` honouring the BN / dropout flags.

        Dropout (when enabled) follows only the first hidden layer. The layer
        order is fixed so ``state_dict`` keys are stable across configs.
        """
        layers = [nn.Linear(in_dim, 128)]
        if config['use_batch_norm']:
            layers.append(nn.BatchNorm1d(128))
        layers.append(nn.ReLU())
        if config['use_dropout']:
            layers.append(nn.Dropout(config['dropout_rate']))

        layers.append(nn.Linear(128, 64))
        if config['use_batch_norm']:
            layers.append(nn.BatchNorm1d(64))
        layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def forward(self, features, image):
        """Return one prediction per sample, shape ``(batch,)``."""
        parts = []

        if self.config['use_images']:
            # Flatten the pooled backbone output to (batch, image_dim).
            parts.append(torch.flatten(self.image_features(image), 1))

        if self.config['use_tabular']:
            parts.append(self.tabular_features(features))

        fused = self.fusion(torch.cat(parts, dim=1))
        # squeeze(-1) rather than squeeze(): a batch of one must stay shape
        # (1,), otherwise the 0-d result cannot be iterated by the callers
        # that extend prediction lists.
        return self.output(fused).squeeze(-1)

# ============================================================================
# TRAINING FUNCTION
# ============================================================================

def _batch_to_device(batch, device):
    """Unpack an AblationDataset batch dict and move its tensors to *device*."""
    return (batch['features'].to(device),
            batch['image'].to(device),
            batch['target'].to(device))


def _evaluate_loader(model, loader, criterion, device):
    """Evaluate *model* on *loader* in eval mode without gradients.

    Returns:
        ``(mean_batch_loss, predictions, targets)``; the last two are flat
        lists of per-sample values in loader order.
    """
    model.eval()
    total_loss = 0.0
    preds, targets_seen = [], []
    with torch.no_grad():
        for batch in loader:
            features, images, targets = _batch_to_device(batch, device)
            outputs = model(features, images)
            total_loss += criterion(outputs, targets).item()
            # reshape(-1) also copes with a 0-d output from a batch of one.
            preds.extend(outputs.reshape(-1).cpu().numpy())
            targets_seen.extend(targets.reshape(-1).cpu().numpy())
    return total_loss / max(len(loader), 1), preds, targets_seen


def train_model(config, train_df, val_df, feature_cols, device,
                num_epochs=NUM_EPOCHS, patience=PATIENCE):
    """Train and evaluate one ablation configuration.

    Args:
        config: One entry of ``ABLATION_CONFIGS``.
        train_df: Raw training frame (``'BMI'`` target plus feature columns).
        val_df: Raw validation frame with the same columns.
        feature_cols: Unused; kept for backward compatibility. The tabular
            columns are always recomputed with ``prepare_features``.
        device: torch device to train on.
        num_epochs: Maximum number of training epochs.
        patience: Epochs without validation-loss improvement before stopping.

    Returns:
        dict with the best-epoch ``model``, the fitted ``scaler`` (None when
        tabular input is disabled), validation ``r2`` / ``mae`` / ``rmse``,
        the per-epoch ``history`` and the ``use_features`` actually used.
    """
    batch_size = 32

    # Prepare features
    include_derived = config['use_derived_features']
    train_eng, train_features = prepare_features(train_df, include_derived=include_derived)
    val_eng, _ = prepare_features(val_df, include_derived=include_derived)

    # Scale features: fit on the training split only, then apply to val.
    scaler = None
    if config['use_tabular']:
        use_features = [f for f in train_features if f in train_eng.columns]
        scaler = RobustScaler()
        train_eng[use_features] = scaler.fit_transform(
            train_eng[use_features].fillna(0).values)
        val_eng[use_features] = scaler.transform(
            val_eng[use_features].fillna(0).values)
    else:
        use_features = ['age']  # placeholder; dataset and model ignore it

    # Create datasets
    train_dataset = AblationDataset(train_eng, use_features, IMAGE_DIR,
                                    config, is_training=True)
    val_dataset = AblationDataset(val_eng, use_features, IMAGE_DIR,
                                  config, is_training=False)

    # A trailing training batch of size 1 makes BatchNorm1d raise in train
    # mode, so drop it only in that case.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=2,
                              drop_last=len(train_dataset) % batch_size == 1)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=2)

    # Create model
    model = AblationModel(len(use_features) if config['use_tabular'] else 1,
                          config).to(device)

    criterion = nn.MSELoss()
    optimizer = optim.AdamW((p for p in model.parameters() if p.requires_grad),
                            lr=1e-3, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                     factor=0.5, patience=2)

    history = {'train_loss': [], 'val_loss': [], 'val_r2': []}
    best_val_loss = float('inf')
    best_model_state = None
    patience_counter = 0

    for epoch in range(num_epochs):
        # Train
        model.train()
        train_loss = 0.0
        for batch in train_loader:
            features, images, targets = _batch_to_device(batch, device)
            optimizer.zero_grad()
            loss = criterion(model(features, images), targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item()
        avg_train_loss = train_loss / max(len(train_loader), 1)

        # Validate
        avg_val_loss, val_preds, val_targets = _evaluate_loader(
            model, val_loader, criterion, device)
        val_r2 = r2_score(val_targets, val_preds)

        scheduler.step(avg_val_loss)

        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['val_r2'].append(val_r2)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # state_dict() returns references to the live tensors, which the
            # optimizer updates in place; clone them to snapshot this epoch.
            best_model_state = {k: v.detach().clone()
                                for k, v in model.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"   Early stopping at epoch {epoch+1}")
            break

    # Load best model (if no epoch improved on inf, e.g. NaN losses, keep the
    # final weights instead of crashing on load_state_dict(None)).
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    else:
        print("   No finite validation loss observed; keeping final weights")

    # Final validation
    _, val_preds, val_targets = _evaluate_loader(model, val_loader, criterion, device)

    final_r2 = r2_score(val_targets, val_preds)
    final_mae = mean_absolute_error(val_targets, val_preds)
    final_rmse = np.sqrt(mean_squared_error(val_targets, val_preds))

    return {
        'model': model,
        'scaler': scaler,
        'r2': final_r2,
        'mae': final_mae,
        'rmse': final_rmse,
        'history': history,
        'use_features': use_features
    }

# ============================================================================
# RUN ABLATION STUDY
# ============================================================================

import time

print("\n" + "="*80)
print("üöÄ RUNNING FAST ABLATION STUDY")
print("="*80)

use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
print(f"Device: {device}\n")

# One fixed 80/20 hold-out split (rather than K-fold) keeps the ablation fast.
row_positions = range(len(train_df))
train_idx, val_idx = train_test_split(row_positions, test_size=0.2,
                                      random_state=42)
train_split = train_df.iloc[train_idx].copy()
val_split = train_df.iloc[val_idx].copy()
print(f"Train samples: {len(train_split)}, Val samples: {len(val_split)}\n")

ablation_results = {}
total_start = time.time()

import traceback

for i, (config_name, config) in enumerate(ABLATION_CONFIGS.items(), 1):
    print(f"\n{'='*60}")
    print(f"üß™ [{i}/{len(ABLATION_CONFIGS)}] {config_name}")
    print(f"üìù {config['description']}")
    print(f"{'='*60}")

    # Re-seed so every configuration starts from the same RNG state (weight
    # init, shuffling, augmentation, feature noise). Otherwise a config's
    # result would depend on how many random draws earlier configs consumed.
    set_seed(42)
    config_start = time.time()

    try:
        result = train_model(
            config, train_split, val_split,
            feature_cols=None,  # Will be computed inside
            device=device,
            num_epochs=NUM_EPOCHS,
            patience=PATIENCE
        )
    except Exception as e:
        # Record the failure and carry on with the remaining configurations,
        # but keep the traceback so the cause is visible.
        print(f"‚ùå Failed: {str(e)}")
        traceback.print_exc()
        ablation_results[config_name] = None
        continue

    config_time = time.time() - config_start
    ablation_results[config_name] = result

    print(f"‚úÖ Results (trained in {config_time:.1f}s):")
    print(f"   R¬≤: {result['r2']:.4f}")
    print(f"   MAE: {result['mae']:.2f}")
    print(f"   RMSE: {result['rmse']:.2f}")

total_time = time.time() - total_start
print(f"\n‚è±Ô∏è  Total training time: {total_time/60:.1f} minutes")

# ============================================================================
# VISUALIZE RESULTS
# ============================================================================

print("\n" + "="*80)
print("üìä CREATING ABLATION VISUALIZATIONS")
print("="*80)

# Filter successful results
successful_results = {k: v for k, v in ablation_results.items() if v is not None}
# Panels 3 and 4 need the FULL_MODEL baseline, which is absent if that run failed.
has_full_model = 'FULL_MODEL' in successful_results

if len(successful_results) > 0:
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    names = list(successful_results.keys())

    # 1. R^2 comparison
    ax1 = axes[0, 0]
    r2_scores = [successful_results[n]['r2'] for n in names]
    colors = ['red' if n == 'FULL_MODEL' else 'skyblue' for n in names]
    bars = ax1.barh(names, r2_scores, color=colors, edgecolor='black')
    ax1.set_xlabel('R¬≤ Score', fontsize=12)
    ax1.set_title('R¬≤ Score by Configuration', fontsize=14, fontweight='bold')
    ax1.grid(True, alpha=0.3, axis='x')
    for bar, score in zip(bars, r2_scores):
        ax1.text(bar.get_width() + 0.01, bar.get_y() + bar.get_height()/2,
                 f'{score:.3f}', va='center', fontsize=10)

    # 2. MAE comparison
    ax2 = axes[0, 1]
    mae_scores = [successful_results[n]['mae'] for n in names]
    colors = ['red' if n == 'FULL_MODEL' else 'coral' for n in names]
    bars = ax2.barh(names, mae_scores, color=colors, edgecolor='black')
    ax2.set_xlabel('MAE (kg/m¬≤)', fontsize=12)
    ax2.set_title('MAE by Configuration', fontsize=14, fontweight='bold')
    ax2.grid(True, alpha=0.3, axis='x')
    for bar, score in zip(bars, mae_scores):
        ax2.text(bar.get_width() + 0.05, bar.get_y() + bar.get_height()/2,
                 f'{score:.2f}', va='center', fontsize=10)

    # 3. Performance drop relative to the full model
    ax3 = axes[1, 0]
    if has_full_model:
        full_r2 = successful_results['FULL_MODEL']['r2']
        degradation = [(full_r2 - successful_results[n]['r2']) * 100 for n in names]
        colors = ['green' if d <= 0 else 'red' for d in degradation]
        ax3.barh(names, degradation, color=colors, edgecolor='black', alpha=0.7)
        # Absolute R^2 difference scaled by 100, not a relative percentage.
        ax3.set_xlabel('R¬≤ Drop (absolute, x100)', fontsize=12)
        ax3.set_title('Performance Drop vs Full Model', fontsize=14, fontweight='bold')
        ax3.axvline(x=0, color='black', linestyle='--', linewidth=2)
        ax3.grid(True, alpha=0.3, axis='x')
    else:
        ax3.text(0.5, 0.5, 'FULL_MODEL run unavailable', ha='center',
                 va='center', transform=ax3.transAxes)
        ax3.set_axis_off()

    # 4. Training history (full model)
    ax4 = axes[1, 1]
    if has_full_model:
        history = successful_results['FULL_MODEL']['history']
        epochs = range(1, len(history['train_loss']) + 1)
        ax4.plot(epochs, history['train_loss'], label='Train Loss', linewidth=2)
        ax4.plot(epochs, history['val_loss'], label='Val Loss', linewidth=2)
        ax4_twin = ax4.twinx()
        ax4_twin.plot(epochs, history['val_r2'], label='Val R¬≤',
                      color='green', linewidth=2, linestyle='--')
        ax4.set_xlabel('Epoch', fontsize=12)
        ax4.set_ylabel('Loss', fontsize=12)
        ax4_twin.set_ylabel('R¬≤ Score', fontsize=12)
        ax4.set_title('Full Model Training History', fontsize=14, fontweight='bold')
        ax4.legend(loc='upper left')
        ax4_twin.legend(loc='upper right')
        ax4.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('ablation_study_results.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("‚úÖ Visualization saved as 'ablation_study_results.png'")

# ============================================================================
# SUMMARY TABLE
# ============================================================================

print("\n" + "="*80)
print("üìã ABLATION STUDY SUMMARY")
print("="*80)

# The drop column is (FULL_MODEL R^2 - config R^2) * 100, i.e. an absolute
# difference scaled by 100. It is NaN when the FULL_MODEL run is unavailable,
# instead of the whole table failing with KeyError.
full_model_r2 = (successful_results['FULL_MODEL']['r2']
                 if 'FULL_MODEL' in successful_results else float('nan'))

summary_data = [
    {
        'Configuration': name,
        'Description': ABLATION_CONFIGS[name]['description'],
        'R¬≤': result['r2'],
        'MAE': result['mae'],
        'RMSE': result['rmse'],
        'R¬≤_Drop_%': (full_model_r2 - result['r2']) * 100,
    }
    for name, result in successful_results.items()
]

# Explicit columns keep sort_values valid even when no run succeeded.
summary_columns = ['Configuration', 'Description', 'R¬≤', 'MAE', 'RMSE', 'R¬≤_Drop_%']
summary_df = pd.DataFrame(summary_data, columns=summary_columns)
summary_df = summary_df.sort_values('R¬≤', ascending=False)
summary_df.to_csv('ablation_study_summary.csv', index=False)

print("\n" + summary_df.to_string(index=False))

print("\n" + "="*80)
print("üéØ KEY FINDINGS")
print("="*80)


def _print_impact_line(rank, name, drop, full_r2):
    """Print one ranked ablation with its signed R^2 change and relative drop.

    ``{-drop:+.4f}`` shows the signed R^2 change, so an ablation that
    *improved* R^2 prints as "+0.0123" rather than "--0.0123".
    """
    percentage = (drop / full_r2) * 100 if full_r2 > 0 else 0
    print(f"   {rank}. {name}: {-drop:+.4f} ({percentage:.1f}% degradation)")
    print(f"      ‚Üí {ABLATION_CONFIGS[name]['description']}")


# Rank ablations by how much R^2 they lose relative to the full model; this
# needs the FULL_MODEL baseline plus at least one other configuration.
if 'FULL_MODEL' in successful_results and len(successful_results) > 1:
    full_r2 = successful_results['FULL_MODEL']['r2']

    impact = {name: full_r2 - result['r2']
              for name, result in successful_results.items()
              if name != 'FULL_MODEL'}
    sorted_impact = sorted(impact.items(), key=lambda item: item[1], reverse=True)

    print("\nüîç Most Impactful Components (by R¬≤ drop):")
    for rank, (name, drop) in enumerate(sorted_impact[:5], 1):
        _print_impact_line(rank, name, drop, full_r2)

    print("\nüí° Least Impactful (most redundant):")
    # Smallest drop first, so rank 1 is the most redundant component.
    for rank, (name, drop) in enumerate(reversed(sorted_impact[-3:]), 1):
        _print_impact_line(rank, name, drop, full_r2)

# Closing summary of what this cell produced.
n_tested_configs = len(successful_results)
print("\n‚úÖ Ablation study complete!")
print(f"   Tested {n_tested_configs} configurations")
print("   Results saved to 'ablation_study_summary.csv'")
print("   Visualizations saved to 'ablation_study_results.png'")

# ============================================================================
# COMPONENT CONTRIBUTION ANALYSIS
# ============================================================================

print("\n" + "="*80)
print("üßÆ COMPONENT CONTRIBUTION ANALYSIS")
print("="*80)

# Component name -> ablation configuration that removes that component.
COMPONENT_ABLATIONS = {
    'Images': 'NO_IMAGES',
    'Tabular Features': 'NO_TABULAR',
    'Data Augmentation': 'NO_AUGMENTATION',
    'Derived Features': 'NO_DERIVED_FEATURES',
    'Dropout': 'NO_DROPOUT',
    'Batch Normalization': 'NO_BATCH_NORM',
    'Pretrained Weights': 'NO_PRETRAINED',
}

if 'FULL_MODEL' in successful_results:
    full_r2 = successful_results['FULL_MODEL']['r2']
    full_mae = successful_results['FULL_MODEL']['mae']

    # Only components whose ablation run succeeded are scored. A missing run
    # is reported as skipped rather than as a zero ("MINOR") contribution.
    contributions = {
        component: {
            'r2_contribution': full_r2 - successful_results[ablation]['r2'],
            'mae_improvement': successful_results[ablation]['mae'] - full_mae,
        }
        for component, ablation in COMPONENT_ABLATIONS.items()
        if ablation in successful_results
    }
    skipped = [component for component, ablation in COMPONENT_ABLATIONS.items()
               if ablation not in successful_results]
    if skipped:
        print(f"\nSkipped (ablation run unavailable): {', '.join(skipped)}")

    if contributions:
        print("\nüìä Individual Component Contributions:")
        contrib_df = pd.DataFrame(contributions).T
        contrib_df = contrib_df.sort_values('r2_contribution', ascending=False)

        for component, values in contrib_df.iterrows():
            r2_contrib = values['r2_contribution']
            mae_improve = values['mae_improvement']
            percentage = (r2_contrib / full_r2) * 100 if full_r2 > 0 else 0

            print(f"\n   {component}:")
            print(f"      R¬≤ Contribution: {r2_contrib:+.4f} ({percentage:+.1f}%)")
            print(f"      MAE Improvement: {mae_improve:+.2f} kg/m¬≤")

            # Rating by absolute R^2 change
            if abs(r2_contrib) > 0.05:
                rating = "üî¥ CRITICAL"
            elif abs(r2_contrib) > 0.02:
                rating = "üü° IMPORTANT"
            else:
                rating = "üü¢ MINOR"
            print(f"      Impact: {rating}")

        contrib_df.to_csv('component_contributions.csv')
        print("\n‚úÖ Component contributions saved to 'component_contributions.csv'")
    else:
        print("\nNo component ablations available to compare against FULL_MODEL.")

# ============================================================================
# DROPOUT RATE ANALYSIS
# ============================================================================

print("\n" + "="*80)
print("üéõÔ∏è DROPOUT RATE ANALYSIS")
print("="*80)

dropout_configs = ['NO_DROPOUT', 'FULL_MODEL', 'HIGH_DROPOUT']
dropout_available = [c for c in dropout_configs if c in successful_results]

if len(dropout_available) > 1:
    print("\nüìä Dropout Rate vs Performance:")
    # Take each rate from the config that produced the result rather than
    # duplicating the numbers here, so this table cannot drift out of sync
    # with ABLATION_CONFIGS. With use_dropout=False no Dropout layer is built,
    # so the effective rate is 0.0 regardless of 'dropout_rate'.
    dropout_rates = {
        c: (ABLATION_CONFIGS[c]['dropout_rate']
            if ABLATION_CONFIGS[c]['use_dropout'] else 0.0)
        for c in dropout_available
    }

    for config in dropout_available:
        rate = dropout_rates[config]
        result = successful_results[config]
        print(f"   Dropout {rate:.1f}: R¬≤={result['r2']:.4f}, MAE={result['mae']:.2f}")

    # Best configuration by R² (ties resolved by list order)
    best_dropout_config = max(dropout_available,
                              key=lambda x: successful_results[x]['r2'])
    optimal_rate = dropout_rates[best_dropout_config]
    print(f"\nüéØ Optimal Dropout Rate: {optimal_rate:.1f}")

# ============================================================================
# MODALITY COMPARISON
# ============================================================================

print("\n" + "="*80)
print("üîÄ MODALITY COMPARISON")
print("="*80)

modality_configs = {
    'Images Only': 'NO_TABULAR',
    'Tabular Only': 'NO_IMAGES',
    'Multimodal (Both)': 'FULL_MODEL'
}

available_modalities = {k: v for k, v in modality_configs.items()
                        if v in successful_results}

if len(available_modalities) >= 2:
    print("\nüìä Performance by Modality:")

    modality_results = []
    for name, config in available_modalities.items():
        result = successful_results[config]
        modality_results.append({
            'Modality': name,
            'R¬≤': result['r2'],
            'MAE': result['mae'],
            'RMSE': result['rmse']
        })
        print(f"   {name}:")
        print(f"      R¬≤: {result['r2']:.4f}")
        print(f"      MAE: {result['mae']:.2f} kg/m¬≤")

    # Synergy = gain of the multimodal model over the BEST single modality.
    # Comparing against the average of the two unimodal scores would report
    # "positive synergy" even when the weaker modality contributes nothing
    # (e.g. images 0.60, tabular 0.20, multimodal 0.60 -> +0.20 vs the mean).
    if all(c in successful_results for c in ('FULL_MODEL', 'NO_IMAGES', 'NO_TABULAR')):
        full_r2 = successful_results['FULL_MODEL']['r2']
        images_only_r2 = successful_results['NO_TABULAR']['r2']
        tabular_only_r2 = successful_results['NO_IMAGES']['r2']

        expected_combined = max(images_only_r2, tabular_only_r2)
        actual_combined = full_r2
        synergy = actual_combined - expected_combined

        print(f"\nüî¨ Multimodal Synergy Analysis:")
        print(f"   Images Only R¬≤: {images_only_r2:.4f}")
        print(f"   Tabular Only R¬≤: {tabular_only_r2:.4f}")
        print(f"   Best Single Modality R¬≤: {expected_combined:.4f}")
        print(f"   Actual Multimodal R¬≤: {actual_combined:.4f}")
        print(f"   Synergy Effect: {synergy:+.4f}")

        if synergy > 0.01:
            print(f"   ‚Üí ‚ú® POSITIVE SYNERGY: Modalities complement each other!")
        elif synergy < -0.01:
            print(f"   ‚Üí ‚ö†Ô∏è NEGATIVE SYNERGY: Potential interference between modalities")
        else:
            print(f"   ‚Üí ‚ûñ NEUTRAL: Multimodal model matches the best single modality")

# ============================================================================
# REGULARIZATION ANALYSIS
# ============================================================================

print("\n" + "="*80)
print("üõ°Ô∏è REGULARIZATION ANALYSIS")
print("="*80)

reg_configs = {
    'No Regularization': 'MINIMAL_MODEL',
    'Dropout Only': None,
    'Batch Norm Only': None,
    'Full Regularization': 'FULL_MODEL'
}

print("\nüìä Regularization Impact:")
if 'MINIMAL_MODEL' in successful_results and 'FULL_MODEL' in successful_results:
    minimal_r2 = successful_results['MINIMAL_MODEL']['r2']
    full_r2 = successful_results['FULL_MODEL']['r2']

    improvement = full_r2 - minimal_r2
    percentage = (improvement / minimal_r2) * 100 if minimal_r2 > 0 else 0

    print(f"   Minimal Model R¬≤: {minimal_r2:.4f}")
    print(f"   Full Regularization R¬≤: {full_r2:.4f}")
    print(f"   Improvement: {improvement:+.4f} ({percentage:+.1f}%)")

    # MINIMAL_MODEL may differ from FULL_MODEL in more than regularization;
    # surface those confounders instead of attributing the whole gap to it.
    non_reg_keys = ('use_augmentation', 'use_derived_features', 'use_pretrained')
    confounders = [
        k for k in non_reg_keys
        if ABLATION_CONFIGS['MINIMAL_MODEL'].get(k) != ABLATION_CONFIGS['FULL_MODEL'].get(k)
    ]
    if confounders:
        print(f"   (note: MINIMAL_MODEL also differs in {', '.join(confounders)}; "
              f"the gap is not purely a regularization effect)")

    if improvement > 0.02:
        print(f"   ‚Üí ‚úÖ Regularization provides significant benefit")
    elif improvement < -0.02:
        # Previously fell through to "minimal impact", hiding a real regression.
        print("   ‚Üí ‚ùå Full model UNDERPERFORMS the minimal model - the added components may be hurting")
    else:
        print(f"   ‚Üí ‚ö†Ô∏è Regularization has minimal impact")

# ============================================================================
# TRAINING EFFICIENCY ANALYSIS
# ============================================================================

print("\n" + "="*80)
print("‚ö° TRAINING EFFICIENCY ANALYSIS")
print("="*80)

print("\nüìä Epochs to Convergence:")
for name, result in successful_results.items():
    # Only results that carry a per-epoch validation R² history are reported.
    if not result or 'history' not in result:
        continue
    history = result['history']
    if 'val_r2' not in history:
        continue

    epochs_trained = len(history['train_loss'])
    final_r2 = result['r2']

    print(f"   {name}:")
    print(f"      Total epochs: {epochs_trained}")

    # "95% of final R²" only makes sense for a positive score: for R² <= 0 the
    # scaled target is at or above the final value, so the epoch count found
    # by the search would be misleading.
    if final_r2 <= 0:
        print("      Epochs to 95% performance: n/a (final R2 <= 0)")
        continue

    target_r2 = final_r2 * 0.95
    # First (1-based) epoch whose validation R² reaches the target; if none
    # does, count the full run.
    epochs_to_95 = next(
        (epoch for epoch, r2 in enumerate(history['val_r2'], 1) if r2 >= target_r2),
        epochs_trained,
    )
    print(f"      Epochs to 95% performance: {epochs_to_95}")
    efficiency = (epochs_to_95/epochs_trained)*100 if epochs_trained > 0 else 0
    print(f"      Efficiency: {efficiency:.1f}%")

# ============================================================================
# RECOMMENDATIONS
# ============================================================================

print("\n" + "="*80)
print("üí° RECOMMENDATIONS")
print("="*80)

# One rule per ablation:
#   (config, keep_threshold, keep_message, weak_threshold, weak_message)
# A component is "worth keeping" when removing it drops R² by more than
# keep_threshold, and "low value" when the drop is below weak_threshold
# (None = no low-value check for that component).
recommendation_rules = [
    ('NO_IMAGES', 0.05, "‚úÖ KEEP: Images are critical (R¬≤ drop: {:.3f})",
     0.01, "‚ùå CONSIDER REMOVING: Images add minimal value (R¬≤ drop: {:.3f})"),
    ('NO_TABULAR', 0.05, "‚úÖ KEEP: Tabular features are critical (R¬≤ drop: {:.3f})",
     0.01, "‚ùå CONSIDER REMOVING: Tabular features add minimal value (R¬≤ drop: {:.3f})"),
    ('NO_AUGMENTATION', 0.02, "‚úÖ KEEP: Data augmentation is important (R¬≤ drop: {:.3f})",
     0.005, "‚ö†Ô∏è OPTIONAL: Data augmentation has minimal impact (R¬≤ drop: {:.3f})"),
    ('NO_DERIVED_FEATURES', 0.02, "‚úÖ KEEP: Derived features are valuable (R¬≤ drop: {:.3f})",
     0.005, "‚ö†Ô∏è SIMPLIFY: Derived features add minimal value (R¬≤ drop: {:.3f})"),
    ('NO_DROPOUT', 0.01, "‚úÖ KEEP: Dropout prevents overfitting (R¬≤ drop: {:.3f})",
     None, None),
    ('NO_PRETRAINED', 0.03, "‚úÖ KEEP: Pretrained weights are crucial (R¬≤ drop: {:.3f})",
     0.01, "‚ö†Ô∏è OPTIONAL: Pretrained weights have minimal impact (R¬≤ drop: {:.3f})"),
]

recommendations = []

if 'FULL_MODEL' not in successful_results:
    # Every recommendation is an R² drop relative to FULL_MODEL; without that
    # baseline the drops would be measured against 0 and be meaningless.
    if len(successful_results) > 0:
        print("\n   FULL_MODEL result missing - cannot compute component recommendations.")
else:
    full_r2 = successful_results['FULL_MODEL']['r2']

    for config_name, keep_thr, keep_msg, weak_thr, weak_msg in recommendation_rules:
        if config_name not in successful_results:
            continue
        drop = full_r2 - successful_results[config_name]['r2']
        if drop > keep_thr:
            recommendations.append(keep_msg.format(drop))
        elif weak_thr is not None and drop < weak_thr:
            recommendations.append(weak_msg.format(drop))

    print("\nüéØ Based on the ablation study:")
    if not recommendations:
        print("   (no component crossed a recommendation threshold)")
    for i, rec in enumerate(recommendations, 1):
        print(f"   {i}. {rec}")

    # Overall strategy
    print("\nüìã Suggested Model Strategy:")

    if len(recommendations) > 0:
        # Components whose removal was flagged as critical (images / tabular)
        critical_components = [r for r in recommendations if '‚úÖ KEEP' in r and 'critical' in r.lower()]

        if len(critical_components) >= 2:
            print("   ‚Üí Use FULL MULTIMODAL approach")
            print("   ‚Üí Both modalities provide significant value")
        elif 'Images are critical' in str(recommendations):
            print("   ‚Üí Focus on IMAGE-BASED approach with light tabular features")
        elif 'Tabular features are critical' in str(recommendations):
            print("   ‚Üí Focus on TABULAR approach with light image features")
        else:
            print("   ‚Üí Use BALANCED approach with moderate complexity")

    # Complexity vs Performance tradeoff
    print("\n‚öñÔ∏è Complexity vs Performance Tradeoff:")
    if 'MINIMAL_MODEL' in successful_results:
        minimal = successful_results['MINIMAL_MODEL']
        full = successful_results['FULL_MODEL']

        performance_gain = full['r2'] - minimal['r2']

        print(f"   Minimal Model: R¬≤={minimal['r2']:.4f} (Low complexity)")
        print(f"   Full Model: R¬≤={full['r2']:.4f} (High complexity)")
        print(f"   Performance Gain: {performance_gain:+.4f}")

        if performance_gain > 0.05:
            print("   ‚Üí ‚úÖ Additional complexity is JUSTIFIED")
        elif performance_gain > 0.02:
            print("   ‚Üí ‚ö†Ô∏è Additional complexity shows MODERATE benefit")
        else:
            print("   ‚Üí ‚ùå Additional complexity NOT justified - use simpler model")

# ============================================================================
# EXPORT DETAILED RESULTS
# ============================================================================

print("\n" + "="*80)
print("üíæ EXPORTING DETAILED RESULTS")
print("="*80)

# Baseline for the relative columns, looked up once rather than per row.
has_full_baseline = 'FULL_MODEL' in successful_results
if has_full_baseline:
    full_r2 = successful_results['FULL_MODEL']['r2']

# Create comprehensive report: one row per configuration that produced a result
report_data = []
for name, result in successful_results.items():
    config = ABLATION_CONFIGS[name]
    # Tolerate results without a training history instead of raising KeyError.
    history = result.get('history') or {}

    report_entry = {
        'Configuration': name,
        'Description': config['description'],
        'Use_Images': config['use_images'],
        'Use_Tabular': config['use_tabular'],
        'Augmentation': config['use_augmentation'],
        'Derived_Features': config['use_derived_features'],
        'Dropout': config['dropout_rate'],
        'Batch_Norm': config['use_batch_norm'],
        'Pretrained': config['use_pretrained'],
        'R¬≤': result['r2'],
        'MAE': result['mae'],
        'RMSE': result['rmse'],
        'Epochs_Trained': len(history.get('train_loss', []))
    }

    # Add relative performance
    if has_full_baseline:
        report_entry['R¬≤_vs_Full'] = result['r2'] - full_r2
        report_entry['R¬≤_vs_Full_%'] = ((result['r2'] - full_r2) / full_r2) * 100 if full_r2 > 0 else 0

    report_data.append(report_entry)

detailed_report_df = pd.DataFrame(report_data)
if not detailed_report_df.empty:
    # An empty frame has no 'R¬≤' column, so sorting it would raise KeyError.
    detailed_report_df = detailed_report_df.sort_values('R¬≤', ascending=False)
detailed_report_df.to_csv('ablation_detailed_report.csv', index=False)

print("‚úÖ Detailed report saved to 'ablation_detailed_report.csv'")

# ============================================================================
# FINAL SUMMARY
# ============================================================================

_divider = "=" * 80
print("\n" + _divider)
print("üéì ABLATION STUDY COMPLETE")
print(_divider)

print(f"\nüìä Summary Statistics:")
print(f"   Configurations tested: {len(successful_results)}/{len(ABLATION_CONFIGS)}")
if successful_results:
    # Collect the scores once and reuse them for best / worst / spread.
    _r2_scores = [res['r2'] for res in successful_results.values()]
    _best_r2, _worst_r2 = max(_r2_scores), min(_r2_scores)
    print(f"   Best R¬≤: {_best_r2:.4f}")
    print(f"   Worst R¬≤: {_worst_r2:.4f}")
    print(f"   R¬≤ Range: {_best_r2 - _worst_r2:.4f}")

if 'FULL_MODEL' in successful_results:
    full_model = successful_results['FULL_MODEL']
    print(f"\nüèÜ Full Model Performance:")
    print(f"   R¬≤: {full_model['r2']:.4f}")
    print(f"   MAE: {full_model['mae']:.2f} kg/m¬≤")
    print(f"   RMSE: {full_model['rmse']:.2f} kg/m¬≤")

print(f"\nüìÅ Generated Files:")
for _artifact in ('ablation_study_summary.csv',
                  'ablation_detailed_report.csv',
                  'component_contributions.csv',
                  'ablation_study_results.png'):
    print(f"   ‚Ä¢ {_artifact}")

print(f"\n‚è±Ô∏è  Total Time: {total_time/60:.1f} minutes")

if FAST_MODE:
    print(f"\n‚ö° FAST MODE was enabled:")
    print(f"   ‚Ä¢ Used {SAMPLE_SIZE} samples instead of full dataset")
    print(f"   ‚Ä¢ Trained for {NUM_EPOCHS} epochs instead of 15")
    print(f"   ‚Ä¢ Set FAST_MODE=False and re-run for full training")

print("\n" + _divider)
print("‚ú® Thank you for using the Fast Ablation Study Framework!")
print(_divider)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split, KFold
from sklearn.preprocessing import StandardScaler, RobustScaler
from sklearn.metrics import mean_absolute_error, r2_score, mean_squared_error
from sklearn.ensemble import RandomForestRegressor
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import warnings
import os
import time
from scipy import stats
warnings.filterwarnings('ignore')

# ============================================================================
# FAST ABLATION CONFIGURATION
# ============================================================================

# Run-size knobs. FAST_MODE subsamples the training set to SAMPLE_SIZE rows;
# NUM_EPOCHS caps training length; PATIENCE is passed to the training routine
# (presumably an early-stopping window - confirm in train_model_multiple_runs).
FAST_MODE = True
SAMPLE_SIZE = 5000
NUM_EPOCHS = 30
PATIENCE = 3

_banner_rule = "=" * 80
for _banner_line in (
    _banner_rule,
    "‚ö° ENHANCED ABLATION STUDY WITH COMPREHENSIVE METRICS",
    _banner_rule,
    f"üìä Sample Size: {SAMPLE_SIZE}",
    f"üîÑ Max Epochs: {NUM_EPOCHS}",
    f"‚è±Ô∏è  Patience: {PATIENCE}",
    _banner_rule + "\n",
):
    print(_banner_line)

# ============================================================================
# ABLATION STUDY CONFIGURATION
# ============================================================================

# Reference (FULL_MODEL) settings. Every ablation is expressed as a small set
# of overrides on top of this, so each variant differs from the full model only
# in the factors it is meant to test.
_BASE_ABLATION_CONFIG = {
    'use_images': True,
    'use_tabular': True,
    'use_augmentation': True,
    'use_derived_features': True,
    'use_dropout': True,
    'dropout_rate': 0.3,
    'use_batch_norm': True,
    'use_pretrained': True,
    'augment_multiplier': 3,   # training set is virtually repeated this many times
    'feature_noise': 0.02,     # std of Gaussian noise added to repeated tabular rows
    'description': 'Full multimodal model with all features'
}


def _ablation_variant(**overrides):
    """Return a fresh copy of the base config with ``overrides`` applied.

    Raises:
        KeyError: if an override names a key the base config does not have
            (guards against silent typos such as ``use_dropot=False``).
    """
    unknown = set(overrides) - set(_BASE_ABLATION_CONFIG)
    if unknown:
        raise KeyError(f"Unknown ablation config keys: {sorted(unknown)}")
    return {**_BASE_ABLATION_CONFIG, **overrides}


ABLATION_CONFIGS = {
    'FULL_MODEL': _ablation_variant(
        description='Full multimodal model with all features'),
    'NO_IMAGES': _ablation_variant(
        use_images=False,
        description='Tabular features only (no images)'),
    'NO_TABULAR': _ablation_variant(
        use_tabular=False,
        use_derived_features=False,   # derived features are tabular too
        feature_noise=0.0,            # no tabular vector to perturb
        description='Images only (no tabular features)'),
    'NO_AUGMENTATION': _ablation_variant(
        use_augmentation=False,
        feature_noise=0.0,
        description='No data augmentation'),
    # augment_multiplier is kept at the base value (3) so this ablation changes
    # only the derived-feature factor; a different multiplier would also change
    # the number of training steps per epoch and confound the comparison.
    'NO_DERIVED_FEATURES': _ablation_variant(
        use_derived_features=False,
        description='No engineered/derived features'),
    'NO_DROPOUT': _ablation_variant(
        use_dropout=False,
        dropout_rate=0.0,
        description='No dropout regularization'),
    'HIGH_DROPOUT': _ablation_variant(
        dropout_rate=0.5,
        description='High dropout (0.5 instead of 0.3)'),
    'NO_BATCH_NORM': _ablation_variant(
        use_batch_norm=False,
        description='No batch normalization'),
    'NO_PRETRAINED': _ablation_variant(
        use_pretrained=False,
        description='Random initialization (no pretrained weights)'),
    'MINIMAL_MODEL': _ablation_variant(
        use_augmentation=False,
        use_derived_features=False,
        use_dropout=False,
        dropout_rate=0.0,
        use_batch_norm=False,
        use_pretrained=False,
        feature_noise=0.0,
        description='Minimal model (basic features, no regularization)'),
}

# ============================================================================
# SETUP
# ============================================================================

def set_seed(seed=42):
    """Seed the random number generators used by this notebook.

    Seeds Python's built-in ``random`` module, NumPy's global RNG, and torch
    (CPU, plus all CUDA devices when CUDA is available).

    Args:
        seed: Integer seed applied to every generator.
    """
    # Local import: the notebook's import cell does not bring in `random`.
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)

# Paths (all derived from DATA_DIR so the dataset location is set in one place)
DATA_DIR = '/kaggle/input/bmi-dataset'
TRAIN_CSV = os.path.join(DATA_DIR, 'train.csv')
TEST_CSV = os.path.join(DATA_DIR, 'test.csv')
IMAGE_DIR = os.path.join(DATA_DIR, 'ROI')

print("üî¨ ENHANCED ABLATION STUDY - BMI PREDICTION")
print("="*80)

# Load data. If the CSVs are missing or unreadable, fall back to a synthetic
# demo dataset. Only I/O and parse errors are caught (pandas parse errors are
# ValueError subclasses) so Ctrl-C and genuine bugs are not swallowed.
try:
    train_df = pd.read_csv(TRAIN_CSV)
    test_df = pd.read_csv(TEST_CSV)
    print(f"‚úÖ Data loaded: Train={train_df.shape}, Test={test_df.shape}")
except (OSError, ValueError) as err:
    print("‚ùå Error loading data. Using synthetic data for demonstration.")
    print(f"   Reason: {err}")
    n_train, n_test = 1000, 200
    train_df = pd.DataFrame({
        'BMI': np.random.normal(25, 5, n_train),
        'age': np.random.randint(18, 80, n_train),
        'sex_encoded': np.random.randint(0, 2, n_train),
        'face_height': np.random.normal(100, 10, n_train),
        'face_width_cheeks': np.random.normal(80, 8, n_train),
        'face_width_jaw': np.random.normal(75, 8, n_train),
        'left_eye_width': np.random.normal(20, 2, n_train),
        'right_eye_width': np.random.normal(20, 2, n_train),
        'nose_length': np.random.normal(30, 3, n_train),
        'nose_width': np.random.normal(25, 2.5, n_train),
        'image_filename': [f'img_{i}.jpg' for i in range(n_train)]
    })
    # NOTE: demo-only - the synthetic test set is a subset of the train set.
    test_df = train_df.copy().sample(n_test)
    # No images available; the dataset class substitutes placeholder tensors.
    IMAGE_DIR = None

# Sample data
if FAST_MODE and len(train_df) > SAMPLE_SIZE:
    print(f"\n‚ö° Fast Mode: Sampling {SAMPLE_SIZE} rows from {len(train_df)}")
    train_df = train_df.sample(SAMPLE_SIZE, random_state=42).reset_index(drop=True)
    print(f"‚úÖ Using {len(train_df)} samples for ablation study")

# Drop height/weight (presumably to avoid target leakage, since BMI is computed
# from them). Each frame is handled independently so a column present in only
# one of them neither survives there nor raises KeyError.
_leak_cols = ['height', 'weight']
train_df = train_df.drop(columns=_leak_cols, errors='ignore')
test_df = test_df.drop(columns=_leak_cols, errors='ignore')

# ============================================================================
# FEATURE ENGINEERING
# ============================================================================

def prepare_features(df, include_derived=True):
    """Select tabular feature columns and optionally append derived ones.

    Works on a copy; ``df`` itself is left untouched.

    The base features are those of the geometric/demographic columns listed
    below that exist in ``df``. With ``include_derived`` the following columns
    are added (each only when its inputs exist, NaNs filled with 0):

    * ``fwhr``         - face_width_cheeks / face_height (height clipped at 1e-8)
    * ``eye_symmetry`` - 1 - |left - right| / (left + right), denominator
      clipped at 1e-8
    * ``face_area``    - face_width_cheeks * face_height

    Args:
        df: Input DataFrame.
        include_derived: Whether to compute and append the derived columns.

    Returns:
        tuple: ``(frame, feature_names)`` - the (possibly extended) copy and the
        ordered list of feature column names to use.
    """
    out = df.copy()
    present = set(out.columns)
    candidates = ('face_height', 'face_width_cheeks', 'face_width_jaw',
                  'left_eye_width', 'right_eye_width', 'nose_length',
                  'nose_width', 'age', 'sex_encoded')
    selected = [col for col in candidates if col in present]

    if not include_derived:
        return out, selected

    has_face = {'face_height', 'face_width_cheeks'} <= present
    has_eyes = {'left_eye_width', 'right_eye_width'} <= present

    # Appended in a fixed order (fwhr, eye_symmetry, face_area) so the feature
    # vector layout stays stable.
    if has_face:
        safe_height = out['face_height'].clip(lower=1e-8)
        out['fwhr'] = (out['face_width_cheeks'] / safe_height).fillna(0)
        selected.append('fwhr')

    if has_eyes:
        left, right = out['left_eye_width'], out['right_eye_width']
        denom = (left + right).clip(lower=1e-8)
        out['eye_symmetry'] = (1 - (left - right).abs() / denom).fillna(0)
        selected.append('eye_symmetry')

    if has_face:
        out['face_area'] = (out['face_width_cheeks'] * out['face_height']).fillna(0)
        selected.append('face_area')

    return out, selected

# ============================================================================
# DATASET CLASS
# ============================================================================

class AblationDataset(Dataset):
    def __init__(self, df, feature_cols, image_dir, config, is_training=True):
        self.df = df.reset_index(drop=True)
        self.feature_cols = feature_cols
        self.image_dir = image_dir
        self.config = config
        self.is_training = is_training
        
        if config['use_tabular']:
            self.features = df[feature_cols].values.astype(np.float32)
        self.targets = df['BMI'].values.astype(np.float32)
        
        if 'image_filename' in df.columns:
            self.filenames = df['image_filename'].tolist()
        else:
            self.filenames = [f'img_{i}.jpg' for i in range(len(df))]
        
        if config['use_images']:
            if is_training and config['use_augmentation']:
                self.transform = transforms.Compose([
                    transforms.Resize((256, 256)),
                    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.ColorJitter(brightness=0.2, contrast=0.2),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                       std=[0.229, 0.224, 0.225])
                ])
            else:
                self.transform = transforms.Compose([
                    transforms.Resize((224, 224)),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                       std=[0.229, 0.224, 0.225])
                ])
    
    def __len__(self):
        if self.is_training and self.config['augment_multiplier'] > 1:
            return len(self.df) * self.config['augment_multiplier']
        return len(self.df)
    
    def __getitem__(self, idx):
        original_idx = idx % len(self.df)
        is_augmented = idx >= len(self.df)
        target = torch.tensor(self.targets[original_idx], dtype=torch.float32)
        
        if self.config['use_tabular']:
            features = torch.tensor(self.features[original_idx], dtype=torch.float32)
            if self.is_training and is_augmented and self.config['feature_noise'] > 0:
                noise = torch.randn_like(features) * self.config['feature_noise']
                features = features + noise
        else:
            features = torch.zeros(1)
        
        if self.config['use_images'] and self.image_dir is not None:
            filename = self.filenames[original_idx]
            img_path = os.path.join(self.image_dir, filename)
            try:
                with Image.open(img_path) as img:
                    if img.mode != 'RGB':
                        img = img.convert('RGB')
                    image = self.transform(img)
            except:
                image = torch.zeros(3, 224, 224)
        else:
            image = torch.randn(3, 224, 224) * 0.1
        
        return {'features': features, 'image': image, 'target': target}

# ============================================================================
# MODEL ARCHITECTURE
# ============================================================================

class AblationModel(nn.Module):
    """Configurable multimodal BMI regressor used by the ablation study.

    Optional branches:

    * image branch - ResNet-18 with its final layer removed (512-d pooled
      output), all backbone parameters frozen;
    * tabular branch - MLP ``num_features -> 128 -> 64``.

    Branch outputs are concatenated and passed through a fusion MLP
    ``-> 128 -> 64 -> 1``. BatchNorm and Dropout layers are inserted according
    to ``config['use_batch_norm']`` / ``config['use_dropout']``.

    Args:
        num_features: Width of the tabular feature vector.
        config: Ablation config dict (reads ``use_images``, ``use_pretrained``,
            ``use_tabular``, ``use_batch_norm``, ``use_dropout``,
            ``dropout_rate``).

    Raises:
        ValueError: If neither images nor tabular features are enabled.
    """

    def __init__(self, num_features, config):
        super().__init__()
        self.config = config

        if config['use_images']:
            if config['use_pretrained']:
                resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
            else:
                resnet = models.resnet18(weights=None)

            # The backbone is frozen in BOTH cases, so NO_PRETRAINED measures
            # fixed random features rather than training from scratch.
            # NOTE(review): the frozen BatchNorm layers inside the backbone
            # still update their running statistics whenever the model is in
            # train() mode - confirm this is intended.
            for param in resnet.parameters():
                param.requires_grad = False
            # Drop the final fc layer; the remaining trunk yields 512 channels.
            self.image_features = nn.Sequential(*list(resnet.children())[:-1])
            image_dim = 512
        else:
            image_dim = 0

        if config['use_tabular']:
            layers = []
            layers.append(nn.Linear(num_features, 128))
            if config['use_batch_norm']:
                layers.append(nn.BatchNorm1d(128))
            layers.append(nn.ReLU())
            if config['use_dropout']:
                layers.append(nn.Dropout(config['dropout_rate']))

            layers.append(nn.Linear(128, 64))
            if config['use_batch_norm']:
                layers.append(nn.BatchNorm1d(64))
            layers.append(nn.ReLU())

            self.tabular_features = nn.Sequential(*layers)
            tabular_dim = 64
        else:
            tabular_dim = 0

        combined_dim = image_dim + tabular_dim

        if combined_dim == 0:
            raise ValueError("Must use at least images or tabular features!")

        fusion_layers = []
        fusion_layers.append(nn.Linear(combined_dim, 128))
        if config['use_batch_norm']:
            fusion_layers.append(nn.BatchNorm1d(128))
        fusion_layers.append(nn.ReLU())
        if config['use_dropout']:
            fusion_layers.append(nn.Dropout(config['dropout_rate']))

        fusion_layers.append(nn.Linear(128, 64))
        if config['use_batch_norm']:
            fusion_layers.append(nn.BatchNorm1d(64))
        fusion_layers.append(nn.ReLU())

        self.fusion = nn.Sequential(*fusion_layers)
        self.output = nn.Linear(64, 1)

    def forward(self, features, image):
        """Return one prediction per sample, shape ``(batch,)``.

        ``features`` is ``(batch, num_features)`` and is only read when
        ``config['use_tabular']``; ``image`` is only read when
        ``config['use_images']`` (presumably ``(batch, 3, H, W)`` normalized
        RGB - confirm against the dataset).
        """
        feat_list = []

        if self.config['use_images']:
            # (batch, 512, 1, 1) -> (batch, 512)
            feat_list.append(torch.flatten(self.image_features(image), 1))

        if self.config['use_tabular']:
            feat_list.append(self.tabular_features(features))

        fused = self.fusion(torch.cat(feat_list, dim=1))
        # Squeeze only the trailing unit dimension: a bare squeeze() would turn
        # a batch of one into a 0-d tensor that cannot be iterated or
        # concatenated with other batches.
        return self.output(fused).squeeze(-1)

# ============================================================================
# UTILITY FUNCTIONS FOR METRICS
# ============================================================================

def count_parameters(model):
    """Return ``(total, trainable)`` parameter element counts for ``model``.

    ``trainable`` only includes parameters with ``requires_grad`` set.
    """
    n_total = n_trainable = 0
    for tensor in model.parameters():
        size = tensor.numel()
        n_total += size
        if tensor.requires_grad:
            n_trainable += size
    return n_total, n_trainable

def get_model_size(model):
    """Return the in-memory size of ``model`` in MB (1 MB = 1024**2 bytes).

    Counts both parameters and registered buffers (e.g. BatchNorm running
    statistics).
    """
    n_bytes = sum(t.nelement() * t.element_size() for t in model.parameters())
    n_bytes += sum(t.nelement() * t.element_size() for t in model.buffers())
    return n_bytes / 1024**2

def measure_inference_time(model, loader, device, num_batches=10):
    """Measure average forward-pass latency per sample, in milliseconds.

    Puts ``model`` in eval mode (and leaves it there), then times the forward
    pass under ``torch.no_grad()`` for at most ``num_batches`` batches from
    ``loader``. Each batch is a dict with ``'features'`` and ``'image'``
    tensors. Data loading and host-to-device copies are excluded from the
    timing.

    Returns:
        float: total forward time divided by total samples, in ms, or ``nan``
        if no batches were timed (empty loader or ``num_batches <= 0``).
    """
    model.eval()
    is_cuda = device.type == 'cuda'
    elapsed = 0.0
    n_samples = 0

    with torch.no_grad():
        for i, batch in enumerate(loader):
            if i >= num_batches:
                break

            features = batch['features'].to(device)
            images = batch['image'].to(device)

            if is_cuda:
                # CUDA ops are asynchronous: drain pending work (e.g. the copies
                # above) so it is not billed to this forward pass.
                torch.cuda.synchronize()
            start = time.perf_counter()
            model(features, images)
            if is_cuda:
                torch.cuda.synchronize()
            elapsed += time.perf_counter() - start
            n_samples += features.size(0)

    if n_samples == 0:
        return float('nan')
    # Weight by sample count so a short final batch does not skew the average.
    return elapsed / n_samples * 1000  # seconds -> ms

def calculate_confidence_interval(values, confidence=0.95):
    """Two-sided Student-t confidence interval for the mean of ``values``.

    Args:
        values: Sample of observations (e.g. per-run metric values).
        confidence: Coverage level in (0, 1).

    Returns:
        tuple: ``(mean, lower, upper)``. Fewer than two values yield NaN
        bounds, since the standard error is undefined.
    """
    dof = len(values) - 1
    center = np.mean(values)
    t_critical = stats.t.ppf((1 + confidence) / 2, dof)
    half_width = stats.sem(values) * t_critical
    return center, center - half_width, center + half_width

# ============================================================================
# TRAINING FUNCTION WITH MULTIPLE RUNS
# ============================================================================

def _evaluate_regression(model, loader, criterion, device):
    """Run *model* over every batch of *loader* without gradient tracking.

    Each batch is a mapping with ``'features'``, ``'image'`` and ``'target'``
    tensors; the model is called as ``model(features, images)``.

    Returns:
        ``(mean_batch_loss, preds, targets)``. ``mean_batch_loss`` is the sum of
        per-batch ``criterion`` values divided by ``len(loader)``. ``preds`` and
        ``targets`` are lists of per-sample numpy values taken from
        ``outputs.cpu().numpy()`` and ``targets.cpu().numpy()``.
    """
    model.eval()
    loss_sum = 0.0
    preds, targets_seen = [], []
    with torch.no_grad():
        for batch in loader:
            features = batch['features'].to(device)
            images = batch['image'].to(device)
            targets = batch['target'].to(device)

            outputs = model(features, images)
            loss_sum += criterion(outputs, targets).item()

            preds.extend(outputs.cpu().numpy())
            targets_seen.extend(targets.cpu().numpy())
    return loss_sum / len(loader), preds, targets_seen


def train_model_multiple_runs(config, train_df, val_df, feature_cols, device, 
                               num_runs=3, num_epochs=NUM_EPOCHS, patience=PATIENCE):
    """Train one ablation configuration ``num_runs`` times and aggregate the metrics.

    Each run does the following:
      * reseeds with ``set_seed(42 + run)``;
      * rebuilds features, scaler, datasets and model;
      * trains with AdamW, ReduceLROnPlateau and early stopping on validation MSE;
      * restores the weights from the epoch with the lowest validation loss;
      * evaluates them on ``val_df``.

    NOTE(review): ``val_df`` drives both early stopping/LR scheduling and the
    reported metrics, so the reported numbers are validation (not held-out test)
    scores.

    Args:
        config: ablation configuration dict. This function reads
            ``'use_derived_features'`` and ``'use_tabular'`` and passes the whole
            dict to ``AblationDataset`` / ``AblationModel``.
        train_df: training DataFrame, passed to ``prepare_features``.
        val_df: validation DataFrame, passed to ``prepare_features``.
        feature_cols: unused. It is kept for call-site compatibility; the feature
            list comes from ``prepare_features``.
        device: ``torch.device`` to train on.
        num_runs: number of independently seeded runs.
        num_epochs: maximum number of epochs per run.
        patience: epochs without validation-loss improvement before stopping.

    Returns:
        dict with the following entries:
          * ``<metric>_mean`` / ``<metric>_std`` (population std over runs);
          * ``<metric>_ci`` = ``(mean, lower, upper)`` 95% t-intervals for r2/mae/rmse;
          * model statistics taken from the first run;
          * ``'best_run'``: the run dict with the highest R^2;
          * ``'all_runs'``: every run dict, including its trained model and scaler.
    """
    import copy  # local import keeps this cell self-contained

    run_results = []
    
    for run in range(num_runs):
        print(f"   Run {run+1}/{num_runs}...")
        set_seed(42 + run)  # Different seed for each run
        
        # Prepare features; derived features are added only when the config asks for them.
        include_derived = bool(config['use_derived_features'])
        train_eng, train_features = prepare_features(train_df, include_derived=include_derived)
        val_eng, _ = prepare_features(val_df, include_derived=include_derived)
        
        if config['use_tabular']:
            use_features = [f for f in train_features if f in train_eng.columns]
            # Fit the scaler on the training split only, then reuse it for validation.
            scaler = RobustScaler()
            train_eng[use_features] = scaler.fit_transform(
                train_eng[use_features].fillna(0).values)
            val_eng[use_features] = scaler.transform(
                val_eng[use_features].fillna(0).values)
        else:
            # NOTE(review): the datasets still receive the (unscaled) 'age' column;
            # how it is consumed is defined in AblationDataset/AblationModel.
            use_features = ['age']
            scaler = None
        
        # Create datasets and loaders
        train_dataset = AblationDataset(train_eng, use_features, IMAGE_DIR, 
                                        config, is_training=True)
        val_dataset = AblationDataset(val_eng, use_features, IMAGE_DIR, 
                                      config, is_training=False)
        
        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, 
                                  num_workers=2)
        val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, 
                                num_workers=2)
        
        # Create model; only parameters with requires_grad are optimised.
        model = AblationModel(len(use_features) if config['use_tabular'] else 1, 
                              config).to(device)
        
        criterion = nn.MSELoss()
        optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                                lr=1e-3, weight_decay=1e-4)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', 
                                                         factor=0.5, patience=2)
        
        history = {'train_loss': [], 'val_loss': [], 'val_r2': []}
        best_val_loss = float('inf')
        best_model_state = None
        patience_counter = 0
        
        training_start = time.time()
        
        for epoch in range(num_epochs):
            # Train
            model.train()
            train_loss = 0
            for batch in train_loader:
                features = batch['features'].to(device)
                images = batch['image'].to(device)
                targets = batch['target'].to(device)
                
                optimizer.zero_grad()
                outputs = model(features, images)
                loss = criterion(outputs, targets)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                
                train_loss += loss.item()
            
            avg_train_loss = train_loss / len(train_loader)
            
            # Validate
            avg_val_loss, val_preds, val_targets = _evaluate_regression(
                model, val_loader, criterion, device)
            val_r2 = r2_score(val_targets, val_preds)
            
            scheduler.step(avg_val_loss)
            
            history['train_loss'].append(avg_train_loss)
            history['val_loss'].append(avg_val_loss)
            history['val_r2'].append(val_r2)
            
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                # state_dict() returns references to the live tensors, and a plain
                # dict .copy() is shallow, so later optimizer steps would overwrite
                # the snapshot in place. Deep-copy to freeze this epoch's weights.
                best_model_state = copy.deepcopy(model.state_dict())
                patience_counter = 0
            else:
                patience_counter += 1
            
            if patience_counter >= patience:
                break
        
        training_time = time.time() - training_start
        
        # Restore the best-validation-loss weights. Nothing was recorded if no epoch
        # ran or the loss never dropped below +inf (e.g. NaN losses).
        if best_model_state is not None:
            model.load_state_dict(best_model_state)
        
        # Final validation with the restored weights
        _, val_preds, val_targets = _evaluate_regression(
            model, val_loader, criterion, device)
        val_preds = np.array(val_preds)
        val_targets = np.array(val_targets)
        
        final_r2 = r2_score(val_targets, val_preds)
        final_mae = mean_absolute_error(val_targets, val_preds)
        final_rmse = np.sqrt(mean_squared_error(val_targets, val_preds))
        final_mape = np.mean(np.abs((val_targets - val_preds) / val_targets)) * 100
        
        # Residual statistics (target - prediction)
        residuals = val_targets - val_preds
        mean_residual = np.mean(residuals)
        std_residual = np.std(residuals)
        
        # Share of samples whose absolute error is below 1 / 2 / 3 BMI units (%)
        within_1 = np.mean(np.abs(residuals) < 1) * 100
        within_2 = np.mean(np.abs(residuals) < 2) * 100
        within_3 = np.mean(np.abs(residuals) < 3) * 100
        
        # Model statistics
        total_params, trainable_params = count_parameters(model)
        model_size = get_model_size(model)
        inference_time = measure_inference_time(model, val_loader, device)
        
        run_results.append({
            'r2': final_r2,
            'mae': final_mae,
            'rmse': final_rmse,
            'mape': final_mape,
            'mean_residual': mean_residual,
            'std_residual': std_residual,
            'within_1': within_1,
            'within_2': within_2,
            'within_3': within_3,
            'training_time': training_time,
            'inference_time': inference_time,
            'total_params': total_params,
            'trainable_params': trainable_params,
            'model_size': model_size,
            'epochs_trained': len(history['train_loss']),
            'history': history,
            'model': model,
            'scaler': scaler,
            'use_features': use_features
        })
    
    def _values(key):
        """Collect one metric across all runs."""
        return [r[key] for r in run_results]
    
    # Aggregate results (np.std is the population std, ddof=0)
    aggregated = {
        'r2_mean': np.mean(_values('r2')),
        'r2_std': np.std(_values('r2')),
        'r2_ci': calculate_confidence_interval(_values('r2')),
        
        'mae_mean': np.mean(_values('mae')),
        'mae_std': np.std(_values('mae')),
        'mae_ci': calculate_confidence_interval(_values('mae')),
        
        'rmse_mean': np.mean(_values('rmse')),
        'rmse_std': np.std(_values('rmse')),
        'rmse_ci': calculate_confidence_interval(_values('rmse')),
        
        'mape_mean': np.mean(_values('mape')),
        'mape_std': np.std(_values('mape')),
        
        'within_1_mean': np.mean(_values('within_1')),
        'within_2_mean': np.mean(_values('within_2')),
        'within_3_mean': np.mean(_values('within_3')),
        
        'training_time_mean': np.mean(_values('training_time')),
        'training_time_std': np.std(_values('training_time')),
        
        'inference_time_mean': np.mean(_values('inference_time')),
        'inference_time_std': np.std(_values('inference_time')),
        
        # Architecture statistics are identical across runs; take the first.
        'total_params': run_results[0]['total_params'],
        'trainable_params': run_results[0]['trainable_params'],
        'model_size': run_results[0]['model_size'],
        
        'epochs_trained_mean': np.mean(_values('epochs_trained')),
        
        'best_run': max(run_results, key=lambda x: x['r2']),
        'all_runs': run_results
    }
    
    return aggregated

# ============================================================================
# RUN ABLATION STUDY
# ============================================================================

# Top-level driver: train every configuration in ABLATION_CONFIGS on one shared
# train/val split and store the aggregated result (or None on failure) per config.
print("\n" + "="*80)
print("üöÄ RUNNING ENHANCED ABLATION STUDY")
print("="*80)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}\n")

# Use simple train/val split: a single 80/20 positional split of train_df
# (random_state=42), shared by every configuration so they see identical data.
# NOTE(review): train_model_multiple_runs uses this val split both for early
# stopping / LR scheduling and for the reported metrics.
train_idx, val_idx = train_test_split(range(len(train_df)), test_size=0.2, 
                                       random_state=42)
train_split = train_df.iloc[train_idx].copy()
val_split = train_df.iloc[val_idx].copy()

print(f"Train samples: {len(train_split)}, Val samples: {len(val_split)}\n")

ablation_results = {}  # config_name -> aggregated dict, or None if training raised
total_start = time.time()

NUM_RUNS_PER_CONFIG = 3  # Independent seeded runs per config (feeds mean/std and t-based CI)

for i, (config_name, config) in enumerate(ABLATION_CONFIGS.items(), 1):
    print(f"\n{'='*60}")
    print(f"üß™ [{i}/{len(ABLATION_CONFIGS)}] {config_name}")
    print(f"üìù {config['description']}")
    print(f"{'='*60}")
    
    config_start = time.time()
    
    try:
        # feature_cols is not read by train_model_multiple_runs; features come
        # from prepare_features inside each run.
        result = train_model_multiple_runs(
            config, train_split, val_split, 
            feature_cols=None,
            device=device,
            num_runs=NUM_RUNS_PER_CONFIG,
            num_epochs=NUM_EPOCHS,
            patience=PATIENCE
        )
        
        config_time = time.time() - config_start
        ablation_results[config_name] = result
        
        print(f"‚úÖ Results (trained in {config_time:.1f}s):")
        print(f"   R¬≤: {result['r2_mean']:.4f} ¬± {result['r2_std']:.4f}")
        print(f"   MAE: {result['mae_mean']:.2f} ¬± {result['mae_std']:.2f}")
        print(f"   RMSE: {result['rmse_mean']:.2f} ¬± {result['rmse_std']:.2f}")
        print(f"   Inference Time: {result['inference_time_mean']:.2f}ms ¬± {result['inference_time_std']:.2f}ms")
        print(f"   Total Params: {result['total_params']:,}")
        print(f"   Model Size: {result['model_size']:.2f} MB")
        
    except Exception as e:
        # Log the failure with its traceback, mark the config as failed (None)
        # and continue with the remaining configurations.
        print(f"‚ùå Failed: {str(e)}")
        import traceback
        traceback.print_exc()
        ablation_results[config_name] = None

total_time = time.time() - total_start
print(f"\n‚è±Ô∏è  Total training time: {total_time/60:.1f} minutes")

# ============================================================================
# CREATE COMPREHENSIVE ABLATION TABLE
# ============================================================================

# Build the ablation tables (CSV) and a 3x3-grid summary figure (PNG) from the
# successful configurations in ablation_results.
print("\n" + "="*80)
print("üìä CREATING COMPREHENSIVE ABLATION TABLE")
print("="*80)

# Drop configurations whose training raised (they were stored as None).
successful_results = {k: v for k, v in ablation_results.items() if v is not None}

if len(successful_results) > 0:
    # Prepare data for table: one row per successful configuration.
    table_data = []
    
    for config_name, result in successful_results.items():
        config = ABLATION_CONFIGS[config_name]
        
        # Create configuration summary: "+"-joined tags of the enabled
        # components (IMG, TAB, AUG, DER, DROP(rate), BN, PRE).
        config_summary = []
        if config['use_images']:
            config_summary.append("IMG")
        if config['use_tabular']:
            config_summary.append("TAB")
        if config['use_augmentation']:
            config_summary.append("AUG")
        if config['use_derived_features']:
            config_summary.append("DER")
        if config['use_dropout']:
            config_summary.append(f"DROP({config['dropout_rate']})")
        if config['use_batch_norm']:
            config_summary.append("BN")
        if config['use_pretrained']:
            config_summary.append("PRE")
        
        config_str = "+".join(config_summary)
        
        # Each *_ci entry is (mean, lower, upper); only the bounds are used below.
        r2_mean, r2_lower, r2_upper = result['r2_ci']
        mae_mean, mae_lower, mae_upper = result['mae_ci']
        rmse_mean, rmse_lower, rmse_upper = result['rmse_ci']
        
        table_data.append({
            'Configuration': config_name,
            'Components': config_str,
            'Total_Params': result['total_params'],
            'Trainable_Params': result['trainable_params'],
            'Model_Size_MB': result['model_size'],
            'Inference_Time_ms': result['inference_time_mean'],
            'Inference_Time_Std': result['inference_time_std'],
            'Training_Time_s': result['training_time_mean'],
            'Training_Time_Std': result['training_time_std'],
            'Epochs': result['epochs_trained_mean'],
            'R2_Mean': result['r2_mean'],
            'R2_Std': result['r2_std'],
            'R2_CI_Lower': r2_lower,
            'R2_CI_Upper': r2_upper,
            'MAE_Mean': result['mae_mean'],
            'MAE_Std': result['mae_std'],
            'MAE_CI_Lower': mae_lower,
            'MAE_CI_Upper': mae_upper,
            'RMSE_Mean': result['rmse_mean'],
            'RMSE_Std': result['rmse_std'],
            'RMSE_CI_Lower': rmse_lower,
            'RMSE_CI_Upper': rmse_upper,
            'MAPE_Mean': result['mape_mean'],
            'MAPE_Std': result['mape_std'],
            'Within_1_BMI_%': result['within_1_mean'],
            'Within_2_BMI_%': result['within_2_mean'],
            'Within_3_BMI_%': result['within_3_mean'],
        })
    
    # Create DataFrame, best mean R^2 first; this row order is reused by every plot below.
    comprehensive_table = pd.DataFrame(table_data)
    comprehensive_table = comprehensive_table.sort_values('R2_Mean', ascending=False)
    
    # Save full (numeric) table
    comprehensive_table.to_csv('ablation_comprehensive_table.csv', index=False)
    print("‚úÖ Comprehensive table saved to 'ablation_comprehensive_table.csv'")
    
    # Human-readable table: every value pre-formatted as a string
    # ("mean +/- std", bracketed CI, parameter count in millions).
    display_table = pd.DataFrame({
        'Setting': [row['Configuration'] for _, row in comprehensive_table.iterrows()],
        'Components': [row['Components'] for _, row in comprehensive_table.iterrows()],
        'Params (M)': [f"{row['Total_Params']/1e6:.2f}" for _, row in comprehensive_table.iterrows()],
        'Size (MB)': [f"{row['Model_Size_MB']:.1f}" for _, row in comprehensive_table.iterrows()],
        'Inf. Time (ms)': [f"{row['Inference_Time_ms']:.2f}¬±{row['Inference_Time_Std']:.2f}" 
                          for _, row in comprehensive_table.iterrows()],
        'R¬≤ Score': [f"{row['R2_Mean']:.4f}¬±{row['R2_Std']:.4f}" 
                    for _, row in comprehensive_table.iterrows()],
        'R¬≤ CI (95%)': [f"[{row['R2_CI_Lower']:.4f}, {row['R2_CI_Upper']:.4f}]" 
                       for _, row in comprehensive_table.iterrows()],
        'MAE': [f"{row['MAE_Mean']:.2f}¬±{row['MAE_Std']:.2f}" 
               for _, row in comprehensive_table.iterrows()],
        'RMSE': [f"{row['RMSE_Mean']:.2f}¬±{row['RMSE_Std']:.2f}" 
                for _, row in comprehensive_table.iterrows()],
        'MAPE (%)': [f"{row['MAPE_Mean']:.2f}¬±{row['MAPE_Std']:.2f}" 
                    for _, row in comprehensive_table.iterrows()],
        'Acc@1': [f"{row['Within_1_BMI_%']:.1f}" for _, row in comprehensive_table.iterrows()],
        'Acc@2': [f"{row['Within_2_BMI_%']:.1f}" for _, row in comprehensive_table.iterrows()],
        'Acc@3': [f"{row['Within_3_BMI_%']:.1f}" for _, row in comprehensive_table.iterrows()],
    })
    
    display_table.to_csv('ablation_display_table.csv', index=False)
    print("‚úÖ Display table saved to 'ablation_display_table.csv'")
    
    # Print formatted table
    print("\n" + "="*80)
    print("üìã ABLATION STUDY TABLE")
    print("="*80)
    print(display_table.to_string(index=False))
    
    # Create visualization: 3x3 grid (row 0: R^2 + inference time, row 1: size /
    # params / MAE CI, row 2: threshold accuracy spanning all columns).
    fig = plt.figure(figsize=(20, 12))
    gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)
    
    # 1. R^2 comparison with std-dev error bars; FULL_MODEL is highlighted in red.
    # NOTE: barh puts index 0 (best R^2) at the bottom since the y-axis is not inverted.
    ax1 = fig.add_subplot(gs[0, :2])
    configs = comprehensive_table['Configuration'].values
    r2_means = comprehensive_table['R2_Mean'].values
    r2_stds = comprehensive_table['R2_Std'].values
    colors = ['red' if c == 'FULL_MODEL' else 'skyblue' for c in configs]
    
    bars = ax1.barh(range(len(configs)), r2_means, xerr=r2_stds, 
                     color=colors, edgecolor='black', capsize=5)
    ax1.set_yticks(range(len(configs)))
    ax1.set_yticklabels(configs)
    ax1.set_xlabel('R¬≤ Score', fontsize=12, fontweight='bold')
    ax1.set_title('R¬≤ Score by Configuration (with std dev)', fontsize=14, fontweight='bold')
    ax1.grid(True, alpha=0.3, axis='x')
    
    # Value labels just past the error bar. NOTE: this loop rebinds the global `i`.
    for i, (mean, std) in enumerate(zip(r2_means, r2_stds)):
        ax1.text(mean + std + 0.01, i, f'{mean:.4f}¬±{std:.4f}', 
                va='center', fontsize=9)
    
    # 2. Inference Time Comparison (mean +/- std across runs)
    ax2 = fig.add_subplot(gs[0, 2])
    inf_times = comprehensive_table['Inference_Time_ms'].values
    inf_stds = comprehensive_table['Inference_Time_Std'].values
    
    bars = ax2.barh(range(len(configs)), inf_times, xerr=inf_stds,
                     color='coral', edgecolor='black', capsize=5)
    ax2.set_yticks(range(len(configs)))
    ax2.set_yticklabels(configs)
    ax2.set_xlabel('Inference Time (ms)', fontsize=10, fontweight='bold')
    ax2.set_title('Inference Speed', fontsize=12, fontweight='bold')
    ax2.grid(True, alpha=0.3, axis='x')
    
    # 3. Model Size Comparison
    ax3 = fig.add_subplot(gs[1, 0])
    sizes = comprehensive_table['Model_Size_MB'].values
    bars = ax3.bar(range(len(configs)), sizes, color='lightgreen', edgecolor='black')
    ax3.set_xticks(range(len(configs)))
    ax3.set_xticklabels(configs, rotation=45, ha='right')
    ax3.set_ylabel('Size (MB)', fontsize=10, fontweight='bold')
    ax3.set_title('Model Size', fontsize=12, fontweight='bold')
    ax3.grid(True, alpha=0.3, axis='y')
    
    # Labels use a fixed +0.5 MB offset above each bar.
    for i, size in enumerate(sizes):
        ax3.text(i, size + 0.5, f'{size:.1f}', ha='center', fontsize=8)
    
    # 4. Parameter Count
    ax4 = fig.add_subplot(gs[1, 1])
    params = comprehensive_table['Total_Params'].values / 1e6  # In millions
    bars = ax4.bar(range(len(configs)), params, color='plum', edgecolor='black')
    ax4.set_xticks(range(len(configs)))
    ax4.set_xticklabels(configs, rotation=45, ha='right')
    ax4.set_ylabel('Parameters (M)', fontsize=10, fontweight='bold')
    ax4.set_title('Total Parameters', fontsize=12, fontweight='bold')
    ax4.grid(True, alpha=0.3, axis='y')
    
    # Labels use a fixed +0.2 M offset above each bar.
    for i, param in enumerate(params):
        ax4.text(i, param + 0.2, f'{param:.2f}M', ha='center', fontsize=8)
    
    # 5. MAE with confidence intervals: asymmetric xerr = [mean - lower, upper - mean]
    ax5 = fig.add_subplot(gs[1, 2])
    mae_means = comprehensive_table['MAE_Mean'].values
    mae_lower = comprehensive_table['MAE_CI_Lower'].values
    mae_upper = comprehensive_table['MAE_CI_Upper'].values
    mae_err = [mae_means - mae_lower, mae_upper - mae_means]
    
    ax5.barh(range(len(configs)), mae_means, xerr=mae_err,
             color='wheat', edgecolor='black', capsize=5)
    ax5.set_yticks(range(len(configs)))
    ax5.set_yticklabels(configs)
    ax5.set_xlabel('MAE (kg/m¬≤)', fontsize=10, fontweight='bold')
    ax5.set_title('MAE with 95% CI', fontsize=12, fontweight='bold')
    ax5.grid(True, alpha=0.3, axis='x')
    
    # 6. Within threshold accuracy: grouped bars (width 0.25) for +/-1, 2, 3 BMI units
    ax6 = fig.add_subplot(gs[2, :])
    within_1 = comprehensive_table['Within_1_BMI_%'].values
    within_2 = comprehensive_table['Within_2_BMI_%'].values
    within_3 = comprehensive_table['Within_3_BMI_%'].values
    
    x = np.arange(len(configs))
    width = 0.25
    
    ax6.bar(x - width, within_1, width, label='Within ¬±1 BMI', color='lightcoral')
    ax6.bar(x, within_2, width, label='Within ¬±2 BMI', color='lightskyblue')
    ax6.bar(x + width, within_3, width, label='Within ¬±3 BMI', color='lightgreen')
    
    ax6.set_xticks(x)
    ax6.set_xticklabels(configs, rotation=45, ha='right')
    ax6.set_ylabel('Accuracy (%)', fontsize=12, fontweight='bold')
    ax6.set_title('Prediction Accuracy within Thresholds', fontsize=14, fontweight='bold')
    ax6.legend()
    ax6.grid(True, alpha=0.3, axis='y')
    
    plt.tight_layout()
    plt.savefig('ablation_comprehensive_results.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("‚úÖ Comprehensive visualization saved")

# ============================================================================
# STATISTICAL ANALYSIS
# ============================================================================

print("\n" + "="*80)
print("üìä STATISTICAL ANALYSIS")
print("="*80)

# Compare each ablated configuration against FULL_MODEL.
# Sign convention: positive R^2 / MAE deltas mean the ablated configuration is
# WORSE than FULL_MODEL (lower R^2, higher MAE).
if len(successful_results) > 1:
    print("\nüîç Performance Comparison (vs FULL_MODEL):")
    print("-" * 80)
    
    if 'FULL_MODEL' in successful_results:
        full_model = successful_results['FULL_MODEL']
        full_r2 = full_model['r2_mean']
        
        for config_name, result in successful_results.items():
            if config_name != 'FULL_MODEL':
                r2_diff = full_r2 - result['r2_mean']
                mae_diff = result['mae_mean'] - full_model['mae_mean']
                # Ratios relative to FULL_MODEL: < 1 means less time / fewer parameters.
                speed_ratio = result['inference_time_mean'] / full_model['inference_time_mean']
                param_ratio = result['total_params'] / full_model['total_params']
                
                # Relative change w.r.t. |R^2| so a negative baseline R^2 does not
                # flip the sign; undefined when the baseline R^2 is exactly 0.
                r2_rel = r2_diff / abs(full_r2) * 100 if full_r2 != 0 else float('nan')
                
                # A time ratio r < 1 means the config is 1/r times faster
                # (e.g. r = 0.5 -> "2.00x faster"), not "r times faster".
                if speed_ratio < 1:
                    speed_str = f"{1 / speed_ratio:.2f}x faster"
                elif speed_ratio > 1:
                    speed_str = f"{speed_ratio:.2f}x slower"
                elif speed_ratio == 1:
                    speed_str = "same as FULL_MODEL"
                else:  # NaN timing
                    speed_str = "n/a"
                
                if param_ratio < 1:
                    param_str = f"{param_ratio:.2f}x ({(1 - param_ratio) * 100:.1f}% fewer)"
                elif param_ratio > 1:
                    param_str = f"{param_ratio:.2f}x ({(param_ratio - 1) * 100:.1f}% more)"
                else:
                    param_str = f"{param_ratio:.2f}x (same count)"
                
                print(f"\n{config_name}:")
                print(f"  R¬≤ Œî: {r2_diff:+.4f} ({r2_rel:+.1f}%)")
                print(f"  MAE Œî: {mae_diff:+.2f} kg/m¬≤")
                print(f"  Speed: {speed_str}")
                print(f"  Params: {param_str}")
    else:
        print("FULL_MODEL is not among the successful configurations; skipping comparison.")

print("\n‚úÖ Enhanced ablation study complete!")
# The CSVs and the PNG are only written by the table/plot section when at least
# one configuration succeeded, so only list them in that case.
if successful_results:
    print("   Generated files:")
    for _generated_file in ("ablation_comprehensive_table.csv",
                            "ablation_display_table.csv",
                            "ablation_comprehensive_results.png"):
        print(f"   ‚Ä¢ {_generated_file}")
else:
    print("   No configuration finished successfully; no result files were generated.")