# Oct 24

üîç What's Missing From Our Original Plan (OSIC-Only)
Looking back at our discussion, here's what you haven't done yet that we agreed was feasible:
* .Proper data splitting - You only did simple train/val split. We discussed 5-fold cross-validation OR holding out a proper test set (20%) that you evaluate ONCE at the end. Right now you have no true held-out evaluation.\n
* .Clinical baseline comparison - You need to train a clinical-only model (Age, Sex, Smoking, baseline FVC) and show that adding CT scans improves predictions. This proves CT adds value beyond simple demographics.\n
* .Comprehensive evaluation metrics - You only reported R¬≤, MAE, LLL. Missing: classification metrics (accuracy, sensitivity, specificity for fast/moderate/slow progressors), confusion matrix, calibration plots, and prediction scatter plots.\n
* .Interpretability/Visualization - No attention maps showing which lung regions drive predictions, no feature importance analysis, no training curves saved as figures.\n
* .Error analysis - Which patients did the model fail on? Are there patterns in failures? This builds trust and shows limitations honestly.\n
Do these 5 things with your current code and OSIC dataset, and you'll have a complete, publishable story showing that baseline CT scans can predict fibrosis progression! üéØ



In [7]:
import os
import cv2
import pydicom
import pandas as pd
import numpy as np 
import matplotlib.pyplot as plt 
import seaborn as sns
import random
from tqdm import tqdm 
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import r2_score, mean_squared_error, confusion_matrix, classification_report
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
from pathlib import Path
import albumentations as albu
from albumentations.pytorch import ToTensorV2
import warnings
from PIL import Image
import torchvision.transforms as transforms

warnings.filterwarnings('ignore')

def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = True

seed_everything(42)

# Configuration
DATA_DIR = Path("../input/osic-pulmonary-fibrosis-progression")
TRAIN_DIR = DATA_DIR / "train"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("üöÄ PUBLICATION-READY OSIC Model - Comprehensive Analysis")
print("=" * 70)
print(f"üì± Device: {DEVICE}")

# Load Data
train_df = pd.read_csv('../input/osic-pulmonary-fibrosis-progression/train.csv')
print(f"Loaded dataset with shape: {train_df.shape}")

def get_optimized_tab_features(df_row):
    """Optimized tabular features - simpler but more effective"""
    vector = []
    
    # Basic but effective features
    age = df_row['Age']
    vector.extend([
        (age - 50) / 30,  # Centered age
        age / 100,  # Scaled age
    ])
    
    # Simple sex encoding
    if df_row['Sex'] == 'Male':
        vector.append(1.0)
    else:
        vector.append(0.0)
    
    # Simple smoking status
    smoking_status = df_row['SmokingStatus']
    if smoking_status == 'Never smoked':
        vector.extend([1, 0, 0])
    elif smoking_status == 'Ex-smoker':
        vector.extend([0, 1, 0])
    elif smoking_status == 'Currently smokes':
        vector.extend([0, 0, 1])
    else:
        vector.extend([0, 0, 0])
    
    # FVC features
    if 'FVC' in df_row:
        fvc = df_row['FVC']
        vector.extend([
            fvc / 3000,  # Normalized FVC
            (fvc - 2500) / 1000,  # Centered FVC
        ])
    
    # Percent predicted (approximate)
    if 'FVC' in df_row and 'Age' in df_row:
        fvc = df_row['FVC']
        age = df_row['Age']
        sex = df_row['Sex']
        
        # Approximate percent predicted FVC
        if sex == 'Male':
            pp_fvc = fvc / (27.63 - 0.112 * age) if age > 0 else 0.8
        else:
            pp_fvc = fvc / (21.78 - 0.101 * age) if age > 0 else 0.8
            
        vector.append(min(pp_fvc, 2.0))  # Cap at 200%
    
    return np.array(vector)

def calculate_lll(actual, predicted, sigma):
    """Calculate Log Laplace Likelihood"""
    sigma = np.maximum(sigma, 1e-6)  # Avoid division by zero
    delta = np.abs(actual - predicted)
    return -np.sqrt(2) * delta / sigma - np.log(sigma * np.sqrt(2))

# Improved coefficient calculation
A = {} 
TAB = {} 
P = []

print("Calculating optimized linear decay coefficients...")
for patient in tqdm(train_df['Patient'].unique()):
    sub = train_df[train_df['Patient'] == patient].copy().sort_values('Weeks')
    fvc = sub['FVC'].values
    weeks = sub['Weeks'].values
    
    if len(weeks) >= 2:
        try:
            # Simple robust slope calculation
            if len(weeks) == 2:
                slope = (fvc[1] - fvc[0]) / (weeks[1] - weeks[0])
            else:
                # Use Theil-Sen estimator for robustness
                slopes = []
                for i in range(len(weeks)):
                    for j in range(i+1, len(weeks)):
                        if weeks[j] != weeks[i]:
                            slope = (fvc[j] - fvc[i]) / (weeks[j] - weeks[i])
                            slopes.append(slope)
                slope = np.median(slopes) if slopes else 0.0
            
            A[patient] = slope
        except:
            A[patient] = 0.0
    else:
        A[patient] = 0.0
    
    TAB[patient] = get_optimized_tab_features(sub.iloc[0])
    P.append(patient)

print(f"Processed {len(P)} patients with optimized features")

# Analyze target distribution
decay_values = np.array(list(A.values()))
print(f"Target statistics: mean={decay_values.mean():.4f}, std={decay_values.std():.4f}")
print(f"Target range: [{decay_values.min():.4f}, {decay_values.max():.4f}]")

# ============================================================================
# 1. PROPER DATA SPLITTING - 5-Fold Cross Validation
# ============================================================================

def create_progressor_categories(decay_values, n_bins=3):
    """Categorize patients into progressor types based on slope"""
    bin_edges = np.percentile(decay_values, [33, 66])
    categories = np.digitize(decay_values, bin_edges)
    return categories

print("\nüìä Setting up 5-Fold Cross Validation...")
progressor_categories = create_progressor_categories(decay_values)
patients_array = np.array(P)

# Store fold results
fold_results = []

# ============================================================================
# 2. CLINICAL BASELINE MODEL
# ============================================================================

class ClinicalBaselineModel:
    """Clinical-only baseline model using traditional ML"""
    def __init__(self):
        self.models = {
            'random_forest': RandomForestRegressor(n_estimators=100, random_state=42),
            'linear': LinearRegression()
        }
        self.best_model = None
        
    def train(self, X_train, y_train):
        # Train both models
        for name, model in self.models.items():
            model.fit(X_train, y_train)
        
        # Select best based on training performance (simple approach)
        self.best_model = self.models['random_forest']
        return self.best_model
    
    def predict(self, X):
        return self.best_model.predict(X)
    
    def get_feature_importance(self, feature_names):
        if hasattr(self.best_model, 'feature_importances_'):
            return dict(zip(feature_names, self.best_model.feature_importances_))
        return {}

# ============================================================================
# ENHANCED DEEP LEARNING MODEL (Your existing model with improvements)
# ============================================================================

class OptimizedAugmentation:
    def __init__(self, augment=True):
        if augment:
            self.transform = albu.Compose([
                albu.Rotate(limit=10, p=0.5),
                albu.HorizontalFlip(p=0.4),
                albu.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=10, p=0.6),
                albu.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.4),
                albu.GaussNoise(var_limit=(5.0, 20.0), p=0.3),
                albu.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ToTensorV2()
            ])
        else:
            self.transform = albu.Compose([
                albu.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ToTensorV2()
            ])
    
    def __call__(self, image):
        return self.transform(image=image)['image']

class OptimizedDenseNetModel(nn.Module):
    def __init__(self, tabular_dim=10, dropout_rate=0.2):
        super(OptimizedDenseNetModel, self).__init__()
        
        # DenseNet121 backbone
        densenet = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
        self.features = densenet.features
        
        # Freeze early layers, unfreeze later layers
        for i, param in enumerate(self.features.parameters()):
            param.requires_grad = i > 100  # Only unfreeze later layers
        
        # Global pooling
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        
        # Simple but effective tabular processor
        self.tabular_processor = nn.Sequential(
            nn.Linear(tabular_dim, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
        )
        
        # Feature fusion
        self.fusion_layer = nn.Sequential(
            nn.Linear(1024 + 256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
        )
        
        # Output heads
        self.mean_head = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
        
        self.log_var_head = nn.Sequential(
            nn.Linear(256, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Tanh()  # Constrain output
        )
        
        # Grad-CAM attributes
        self.gradients = None
        self.activations = None
        
        # Initialize output layers for better convergence
        self._initialize_weights()
    
    def _initialize_weights(self):
        for m in [self.mean_head, self.log_var_head]:
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0.0, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
    
    def activations_hook(self, grad):
        self.gradients = grad
    
    def forward(self, images, tabular):
        batch_size = images.size(0)
        
        # Extract image features
        img_features = self.features(images)
        
        # Register hook for Grad-CAM
        if img_features.requires_grad:
            h = img_features.register_hook(self.activations_hook)
        self.activations = img_features.detach()
        
        img_features = self.global_pool(img_features).view(batch_size, -1)
        
        # Process tabular data
        tab_features = self.tabular_processor(tabular)
        
        # Feature fusion
        combined_features = torch.cat([img_features, tab_features], dim=1)
        fused_features = self.fusion_layer(combined_features)
        
        # Predict mean and log variance
        mean_pred = self.mean_head(fused_features)
        log_var = self.log_var_head(fused_features)
        
        return mean_pred.squeeze(), log_var.squeeze()
    
    def get_activations_gradient(self):
        return self.gradients
    
    def get_activations(self, x):
        return self.features(x)

class OptimizedOSICDataset(Dataset):
    def __init__(self, patients, A_dict, TAB_dict, data_dir, split='train'):
        self.patients = [p for p in patients if p not in ['ID00011637202177653955184', 'ID00052637202186188008618']]
        self.A_dict = A_dict
        self.TAB_dict = TAB_dict
        self.data_dir = Path(data_dir)
        self.split = split
        self.augmentor = OptimizedAugmentation(augment=(split=='train'))
        
        # Prepare image paths
        self.patient_images = {}
        for patient in self.patients:
            patient_dir = self.data_dir / patient
            if patient_dir.exists():
                image_files = [f for f in patient_dir.iterdir() if f.suffix.lower() == '.dcm']
                if image_files:
                    self.patient_images[patient] = image_files
        
        self.valid_patients = [p for p in self.patients if p in self.patient_images]
        print(f"Dataset {split}: {len(self.valid_patients)} patients with images")
    
    def __len__(self):
        if self.split == 'train':
            return len(self.valid_patients) * 8
        else:
            return len(self.valid_patients)
    
    def __getitem__(self, idx):
        if self.split == 'train':
            patient_idx = idx % len(self.valid_patients)
        else:
            patient_idx = idx
            
        patient = self.valid_patients[patient_idx]
        
        # Get random image
        available_images = self.patient_images[patient]
        selected_image = random.choice(available_images) if available_images else available_images[0]
        
        # Load and preprocess image
        img = self.load_dicom(selected_image)
        img_tensor = self.augmentor(img)
        
        # Get tabular features
        tab_features = torch.tensor(self.TAB_dict[patient], dtype=torch.float32)
        
        # Get target (clipped to reasonable range)
        target = torch.tensor(self.A_dict[patient], dtype=torch.float32)
        
        return img_tensor, tab_features, target, patient, str(selected_image)
    
    def load_dicom(self, path):
        try:
            dcm = pydicom.dcmread(str(path))
            img = dcm.pixel_array.astype(np.float32)
            
            if len(img.shape) == 3:
                img = img[img.shape[0]//2]
            
            img = cv2.resize(img, (384, 384))
            
            # Normalize
            img_min, img_max = img.min(), img.max()
            if img_max > img_min:
                img = (img - img_min) / (img_max - img_min) * 255
            else:
                img = np.zeros_like(img)
            
            # Apply CLAHE
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            img = clahe.apply(img.astype(np.uint8))
            
            # Convert to 3-channel
            img = np.stack([img, img, img], axis=2).astype(np.uint8)
            
            return img
            
        except Exception as e:
            print(f"Error loading {path}: {e}")
            return np.zeros((384, 384, 3), dtype=np.uint8)

class OptimizedTrainer:
    def __init__(self, model, device, lr=1e-4, model_type='deep_learning'):
        self.model = model
        self.device = device
        self.lr = lr
        self.model_type = model_type
        self.best_val_r2 = -float('inf')
        self.best_val_mae = float('inf')
        self.best_val_lll = -float('inf')
        self.train_losses = []
        self.val_losses = []
        self.train_r2_scores = []
        self.val_r2_scores = []
        
    def uncertainty_loss(self, mean_pred, log_var, targets):
        var = torch.exp(log_var)
        mse_loss = (mean_pred - targets) ** 2
        return 0.5 * (mse_loss / var + log_var).mean()
    
    def train(self, train_loader, val_loader, epochs=50):
        if self.model_type == 'deep_learning':
            return self._train_deep_learning(train_loader, val_loader, epochs)
        else:
            return self._train_clinical_baseline(train_loader, val_loader)
    
    def _train_deep_learning(self, train_loader, val_loader, epochs=50):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=0.5, patience=5, verbose=True
        )
        
        patience_counter = 0
        
        for epoch in range(epochs):
            # Training
            self.model.train()
            train_loss = 0.0
            train_batches = 0
            train_predictions, train_targets = [], []
            
            for images, tabular, targets, _, _ in train_loader:
                images, tabular, targets = images.to(self.device), tabular.to(self.device), targets.to(self.device)
                
                optimizer.zero_grad()
                mean_pred, log_var = self.model(images, tabular)
                
                # Combined loss
                mse_loss = F.mse_loss(mean_pred, targets)
                uncertainty_loss = self.uncertainty_loss(mean_pred, log_var, targets)
                
                # Start with more MSE focus, transition to uncertainty
                if epoch < 20:
                    loss = 0.7 * mse_loss + 0.3 * uncertainty_loss
                else:
                    loss = 0.3 * mse_loss + 0.7 * uncertainty_loss
                
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                optimizer.step()
                
                train_loss += loss.item()
                train_batches += 1
                
                # Store predictions for R¬≤ calculation
                train_predictions.extend(mean_pred.detach().cpu().numpy())
                train_targets.extend(targets.cpu().numpy())
            
            # Calculate training R¬≤
            train_r2 = r2_score(train_targets, train_predictions) if len(train_targets) > 0 else 0
            self.train_r2_scores.append(train_r2)
            self.train_losses.append(train_loss / train_batches)
            
            # Validation
            self.model.eval()
            val_predictions, val_targets, val_log_vars = [], [], []
            
            with torch.no_grad():
                for images, tabular, targets, _, _ in val_loader:
                    images, tabular, targets = images.to(self.device), tabular.to(self.device), targets.to(self.device)
                    mean_pred, log_var = self.model(images, tabular)
                    
                    # Convert to numpy properly
                    mean_pred_np = mean_pred.cpu().numpy()
                    log_var_np = log_var.cpu().numpy()
                    targets_np = targets.cpu().numpy()
                    
                    if mean_pred_np.ndim == 0:
                        val_predictions.append(mean_pred_np.item())
                        val_log_vars.append(log_var_np.item())
                        val_targets.append(targets_np.item())
                    else:
                        val_predictions.extend(mean_pred_np.tolist())
                        val_log_vars.extend(log_var_np.tolist())
                        val_targets.extend(targets_np.tolist())
            
            if len(val_predictions) > 0:
                val_pred_np = np.array(val_predictions)
                val_target_np = np.array(val_targets)
                val_log_var_np = np.array(val_log_vars)
                val_sigma_np = np.exp(val_log_var_np / 2)
                
                # Calculate metrics
                r2 = r2_score(val_target_np, val_pred_np)
                mae = np.mean(np.abs(val_pred_np - val_target_np))
                rmse = np.sqrt(mean_squared_error(val_target_np, val_pred_np))
                lll_values = calculate_lll(val_target_np, val_pred_np, val_sigma_np)
                avg_lll = np.mean(lll_values)
                
                avg_train_loss = train_loss / train_batches if train_batches > 0 else 0
                current_lr = optimizer.param_groups[0]['lr']
                
                self.val_r2_scores.append(r2)
                self.val_losses.append(avg_train_loss)  # Simplified for plotting
                
                print(f"Epoch {epoch+1}: LR={current_lr:.2e}, Loss={avg_train_loss:.4f}")
                print(f"          Train R¬≤={train_r2:.4f}, Val R¬≤={r2:.4f}")
                print(f"          Val MAE={mae:.4f}, RMSE={rmse:.4f}, LLL={avg_lll:.4f}")
                
                # Update scheduler
                scheduler.step(r2)
                
                # Save best model
                if r2 > self.best_val_r2:
                    self.best_val_r2 = r2
                    self.best_val_mae = mae
                    self.best_val_lll = avg_lll
                    torch.save(self.model.state_dict(), f'best_model_fold_{len(fold_results)}.pth')
                    print(f"üéØ NEW BEST! R¬≤: {r2:.4f}")
                    patience_counter = 0
                else:
                    patience_counter += 1
                
                if patience_counter >= 10:
                    print(f"Early stopping at epoch {epoch+1}")
                    break
                
                print("-" * 60)
        
        return self.best_val_r2, self.best_val_mae, self.best_val_lll
    
    def _train_clinical_baseline(self, train_loader, val_loader):
        # Extract features and targets
        X_train, y_train = [], []
        for _, tabular, targets, _, _ in train_loader:
            X_train.extend(tabular.numpy())
            y_train.extend(targets.numpy())
        
        X_val, y_val = [], []
        for _, tabular, targets, _, _ in val_loader:
            X_val.extend(tabular.numpy())
            y_val.extend(targets.numpy())
        
        # Train clinical model
        self.model.train(np.array(X_train), np.array(y_train))
        
        # Predict
        y_pred = self.model.predict(np.array(X_val))
        
        # Calculate metrics
        r2 = r2_score(y_val, y_pred)
        mae = np.mean(np.abs(y_pred - y_val))
        rmse = np.sqrt(mean_squared_error(y_val, y_pred))
        
        # For clinical model, use fixed sigma for LLL calculation
        sigma = np.std(y_pred - y_val)
        lll_values = calculate_lll(y_val, y_pred, np.full_like(y_val, sigma))
        avg_lll = np.mean(lll_values)
        
        print(f"Clinical Baseline - R¬≤: {r2:.4f}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, LLL: {avg_lll:.4f}")
        
        return r2, mae, avg_lll

# ============================================================================
# 3. COMPREHENSIVE EVALUATION METRICS
# ============================================================================

class ComprehensiveEvaluator:
    def __init__(self):
        self.metrics = {}
        
    def calculate_classification_metrics(self, true_slopes, pred_slopes):
        """Calculate progressor classification metrics"""
        # Define progressor categories based on tertiles
        slow_threshold = np.percentile(true_slopes, 33)
        fast_threshold = np.percentile(true_slopes, 66)
        
        true_categories = np.zeros_like(true_slopes)
        true_categories[true_slopes <= slow_threshold] = 0  # Slow
        true_categories[(true_slopes > slow_threshold) & (true_slopes <= fast_threshold)] = 1  # Moderate
        true_categories[true_slopes > fast_threshold] = 2  # Fast
        
        pred_categories = np.zeros_like(pred_slopes)
        pred_categories[pred_slopes <= slow_threshold] = 0
        pred_categories[(pred_categories > slow_threshold) & (pred_slopes <= fast_threshold)] = 1
        pred_categories[pred_slopes > fast_threshold] = 2
        
        # Calculate metrics
        accuracy = np.mean(true_categories == pred_categories)
        cm = confusion_matrix(true_categories, pred_categories)
        report = classification_report(true_categories, pred_categories, 
                                     target_names=['Slow', 'Moderate', 'Fast'], output_dict=True)
        
        return {
            'accuracy': accuracy,
            'confusion_matrix': cm,
            'classification_report': report,
            'true_categories': true_categories,
            'pred_categories': pred_categories
        }
    
    def create_calibration_plot(self, true_slopes, pred_slopes, save_path='calibration_plot.png'):
        """Create calibration plot for regression"""
        plt.figure(figsize=(10, 6))
        
        # Bin predictions
        bins = np.linspace(min(pred_slopes), max(pred_slopes), 10)
        bin_centers = []
        true_means = []
        
        for i in range(len(bins)-1):
            mask = (pred_slopes >= bins[i]) & (pred_slopes < bins[i+1])
            if np.sum(mask) > 0:
                bin_centers.append((bins[i] + bins[i+1]) / 2)
                true_means.append(np.mean(true_slopes[mask]))
        
        plt.plot(bin_centers, true_means, 'o-', label='Calibration')
        plt.plot([min(pred_slopes), max(pred_slopes)], [min(pred_slopes), max(pred_slopes)], '--', color='gray', label='Perfect calibration')
        plt.xlabel('Predicted Slope')
        plt.ylabel('Actual Slope')
        plt.title('Calibration Plot')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
    
    def create_scatter_plot(self, true_slopes, pred_slopes, save_path='scatter_plot.png'):
        """Create prediction vs actual scatter plot"""
        plt.figure(figsize=(8, 6))
        plt.scatter(true_slopes, pred_slopes, alpha=0.6, s=50)
        plt.plot([min(true_slopes), max(true_slopes)], [min(true_slopes), max(true_slopes)], 'r--', alpha=0.8)
        plt.xlabel('Actual Slope')
        plt.ylabel('Predicted Slope')
        plt.title('Predicted vs Actual Slope')
        plt.grid(True, alpha=0.3)
        
        # Add R¬≤ to plot
        r2 = r2_score(true_slopes, pred_slopes)
        plt.text(0.05, 0.95, f'R¬≤ = {r2:.3f}', transform=plt.gca().transAxes, 
                bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))
        
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
    
    def plot_training_curves(self, trainer, save_path='training_curves.png'):
        """Plot training and validation curves"""
        plt.figure(figsize=(12, 4))
        
        plt.subplot(1, 2, 1)
        plt.plot(trainer.train_losses, label='Training Loss')
        plt.plot(trainer.val_losses, label='Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.title('Training and Validation Loss')
        plt.grid(True, alpha=0.3)
        
        plt.subplot(1, 2, 2)
        plt.plot(trainer.train_r2_scores, label='Training R¬≤')
        plt.plot(trainer.val_r2_scores, label='Validation R¬≤')
        plt.xlabel('Epoch')
        plt.ylabel('R¬≤ Score')
        plt.legend()
        plt.title('Training and Validation R¬≤')
        plt.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()

# ============================================================================
# 4. INTERPRETABILITY - GRAD-CAM (CORRECTED)
# ============================================================================

class GradCAM:
    def __init__(self, model):
        self.model = model
        self.gradients = None
        self.activations = None
        
        # Register hooks
        self.model.features[-1].register_forward_hook(self.save_activations)
        self.model.features[-1].register_backward_hook(self.save_gradients)
    
    def save_activations(self, module, input, output):
        self.activations = output.detach()
    
    def save_gradients(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()
    
    def __call__(self, x):
        # Forward pass
        self.model.zero_grad()
        
        # Get predictions
        mean_pred, _ = self.model(x, torch.zeros(x.size(0), 9).to(x.device))  # dummy tabular
        
        # Backward pass
        target = mean_pred.sum()  # Use sum of predictions for Grad-CAM
        target.backward()
        
        # Check if we have gradients and activations
        if self.gradients is None or self.activations is None:
            raise RuntimeError("Gradients or activations not captured. Check hook registration.")
        
        # Pool gradients across spatial dimensions
        pooled_gradients = torch.mean(self.gradients, dim=[0, 2, 3])
        
        # Weight activations by gradients
        weighted_activations = torch.zeros_like(self.activations)
        for i in range(pooled_gradients.size(0)):
            weighted_activations[:, i, :, :] = self.activations[:, i, :, :] * pooled_gradients[i]
        
        # Heatmap
        heatmap = torch.mean(weighted_activations, dim=1).squeeze()
        heatmap = F.relu(heatmap)
        
        # Normalize
        if torch.max(heatmap) > 0:
            heatmap /= torch.max(heatmap)
        
        return heatmap.detach().cpu().numpy()

def create_attention_map(model, image_tensor, original_image, save_path):
    """Create and save attention map using Grad-CAM"""
    model.eval()
    
    # Move to device and enable gradients
    image_tensor = image_tensor.unsqueeze(0).to(DEVICE).requires_grad_(True)
    
    try:
        # Get Grad-CAM
        gradcam = GradCAM(model)
        heatmap = gradcam(image_tensor)
        
        # Resize heatmap to match original image
        heatmap = cv2.resize(heatmap, (original_image.shape[1], original_image.shape[0]))
        heatmap = np.uint8(255 * heatmap)
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        
        # Convert original image to RGB if needed
        if len(original_image.shape) == 2:
            original_image = cv2.cvtColor(original_image, cv2.COLOR_GRAY2RGB)
        elif original_image.shape[2] == 1:
            original_image = cv2.cvtColor(original_image, cv2.COLOR_GRAY2RGB)
        
        # Ensure both images have same data type and range
        original_image = original_image.astype(np.float32)
        heatmap = heatmap.astype(np.float32)
        
        # Superimpose heatmap on original image
        superimposed_img = heatmap * 0.4 + original_image * 0.6
        superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)
        
        # Save result
        plt.figure(figsize=(12, 4))
        
        plt.subplot(1, 3, 1)
        plt.imshow(original_image.astype(np.uint8))
        plt.title('Original CT')
        plt.axis('off')
        
        plt.subplot(1, 3, 2)
        plt.imshow(heatmap.astype(np.uint8))
        plt.title('Attention Map')
        plt.axis('off')
        
        plt.subplot(1, 3, 3)
        plt.imshow(superimposed_img)
        plt.title('Overlay')
        plt.axis('off')
        
        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
        
        print(f"‚úÖ Attention map saved: {save_path}")
        
    except Exception as e:
        print(f"‚ùå Error creating attention map: {e}")
        # Create a placeholder image to avoid breaking the pipeline
        plt.figure(figsize=(8, 6))
        plt.text(0.5, 0.5, f"Grad-CAM Failed\n{str(e)}", 
                ha='center', va='center', transform=plt.gca().transAxes)
        plt.axis('off')
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()

# ============================================================================
# 5. ERROR ANALYSIS (CORRECTED)
# ============================================================================

def perform_error_analysis(true_slopes, pred_slopes, patients, tabular_data, save_path='error_analysis.png'):
    """Analyze patterns in prediction errors"""
    errors = np.abs(pred_slopes - true_slopes)
    
    # Identify worst predictions (handle empty arrays)
    if len(errors) > 0:
        worst_indices = np.argsort(errors)[-min(10, len(errors)):]  # Top 10 worst or all if less
    else:
        worst_indices = []
    
    print("\nüîç ERROR ANALYSIS - Worst Predictions:")
    print("=" * 60)
    for i, idx in enumerate(worst_indices):
        if idx < len(patients) and idx < len(true_slopes):
            print(f"{i+1}. Patient {patients[idx]}: True={true_slopes[idx]:.2f}, Pred={pred_slopes[idx]:.2f}, Error={errors[idx]:.2f}")
    
    # Plot error distribution
    plt.figure(figsize=(15, 5))
    
    plt.subplot(1, 3, 1)
    if len(errors) > 0:
        plt.hist(errors, bins=min(20, len(errors)), alpha=0.7, edgecolor='black')
    plt.xlabel('Absolute Error')
    plt.ylabel('Frequency')
    plt.title('Error Distribution')
    plt.grid(True, alpha=0.3)
    
    plt.subplot(1, 3, 2)
    if len(true_slopes) > 0 and len(errors) > 0:
        plt.scatter(true_slopes, errors, alpha=0.6)
    plt.xlabel('True Slope')
    plt.ylabel('Absolute Error')
    plt.title('Error vs True Value')
    plt.grid(True, alpha=0.3)
    
    plt.subplot(1, 3, 3)
    # Analyze errors by feature (using first tabular feature as example)
    if len(tabular_data) > 0 and len(errors) > 0:
        feature_errors = []
        for i in range(min(5, tabular_data.shape[1])):  # Limit to first 5 features
            if len(tabular_data) > i:
                corr = np.corrcoef(tabular_data[:, i], errors)[0, 1] if len(errors) > 1 else 0
                feature_errors.append((i, abs(corr)))
        
        # Plot top correlated features with errors
        if feature_errors:
            feature_errors.sort(key=lambda x: x[1], reverse=True)
            features, correlations = zip(*feature_errors[:5])
            
            plt.bar(range(len(features)), correlations)
            plt.xlabel('Feature Index')
            plt.ylabel('|Correlation with Error|')
            plt.title('Top Features Correlated with Error')
            plt.xticks(range(len(features)), [f'F{i}' for i in features])
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    
    return worst_indices, errors

# ============================================================================
# MAIN EXECUTION - CORRECTED CALLS
# ============================================================================

def run_comprehensive_analysis():
    print("\nüéØ STARTING COMPREHENSIVE 5-FOLD ANALYSIS")
    print("=" * 70)
    
    # Initialize evaluator
    evaluator = ComprehensiveEvaluator()
    
    # 5-Fold Cross Validation
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    
    for fold, (train_idx, val_idx) in enumerate(skf.split(patients_array, progressor_categories)):
        print(f"\nüìä FOLD {fold + 1}/5")
        print("-" * 50)
        
        # Get patients for this fold
        train_patients = patients_array[train_idx].tolist()
        val_patients = patients_array[val_idx].tolist()
        
        print(f"Train: {len(train_patients)} patients, Val: {len(val_patients)} patients")
        
        # Get tabular dimension
        if train_patients and train_patients[0] in TAB:
            tabular_dim = len(TAB[train_patients[0]])
        else:
            tabular_dim = 9  # fallback
            
        # Create datasets
        train_dataset = OptimizedOSICDataset(train_patients, A, TAB, TRAIN_DIR, 'train')
        val_dataset = OptimizedOSICDataset(val_patients, A, TAB, TRAIN_DIR, 'val')
        
        # Skip fold if no valid patients
        if len(train_dataset.valid_patients) == 0 or len(val_dataset.valid_patients) == 0:
            print(f"‚ö†Ô∏è  Skipping fold {fold+1} - no valid patients with images")
            continue
            
        # Data loaders
        train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2, pin_memory=True)
        val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=2, pin_memory=True)
        
        # ====================================================================
        # 2. CLINICAL BASELINE COMPARISON
        # ====================================================================
        print("\nü©∫ TRAINING CLINICAL BASELINE MODEL...")
        clinical_model = ClinicalBaselineModel()
        clinical_trainer = OptimizedTrainer(clinical_model, DEVICE, model_type='clinical')
        clinical_r2, clinical_mae, clinical_lll = clinical_trainer.train(train_loader, val_loader)
        
        # ====================================================================
        # DEEP LEARNING MODEL
        # ====================================================================
        print("\nüß† TRAINING DEEP LEARNING MODEL...")
        dl_model = OptimizedDenseNetModel(tabular_dim=tabular_dim).to(DEVICE)
        dl_trainer = OptimizedTrainer(dl_model, DEVICE, lr=1e-4, model_type='deep_learning')
        dl_r2, dl_mae, dl_lll = dl_trainer.train(train_loader, val_loader, epochs=30)
        
        # ====================================================================
        # COMPREHENSIVE EVALUATION
        # ====================================================================
        print("\nüìà PERFORMING COMPREHENSIVE EVALUATION...")
        
        # Get predictions for comprehensive evaluation
        dl_model.eval()
        all_preds, all_targets, all_patients = [], [], []
        all_tabular = []
        
        with torch.no_grad():
            for images, tabular, targets, patients, _ in val_loader:
                images, tabular = images.to(DEVICE), tabular.to(DEVICE)
                mean_pred, _ = dl_model(images, tabular)
                
                # Handle both scalar and tensor cases
                pred_np = mean_pred.cpu().numpy()
                if pred_np.ndim == 0:
                    all_preds.append(pred_np.item())
                else:
                    all_preds.extend(pred_np.tolist())
                    
                all_targets.extend(targets.numpy())
                all_patients.extend(patients)
                all_tabular.extend(tabular.cpu().numpy())
        
        all_preds = np.array(all_preds)
        all_targets = np.array(all_targets)
        all_tabular = np.array(all_tabular)
        
        # Skip if no predictions
        if len(all_preds) == 0:
            print("‚ö†Ô∏è  No predictions generated, skipping evaluation")
            continue
            
        # 3.1 Classification Metrics
        classification_results = evaluator.calculate_classification_metrics(all_targets, all_preds)
        print(f"Classification Accuracy: {classification_results['accuracy']:.4f}")
        
        # 3.2 Calibration Plot
        evaluator.create_calibration_plot(all_targets, all_preds, f'calibration_fold_{fold+1}.png')
        
        # 3.3 Scatter Plot
        evaluator.create_scatter_plot(all_targets, all_preds, f'scatter_fold_{fold+1}.png')
        
        # 3.4 Training Curves
        evaluator.plot_training_curves(dl_trainer, f'training_curves_fold_{fold+1}.png')
        
        # ====================================================================
        # 4. INTERPRETABILITY - GRAD-CAM
        # ====================================================================
        print("\nüëÅÔ∏è  GENERATING ATTENTION MAPS...")
        # Get one sample for visualization
        if len(val_dataset) > 0:
            try:
                sample_idx = 0
                sample_image, sample_tabular, sample_target, sample_patient, sample_path = val_dataset[sample_idx]
                
                # Load original image properly
                original_img = val_dataset.load_dicom(Path(sample_path))
                
                create_attention_map(
                    dl_model, 
                    sample_image, 
                    original_img, 
                    f'attention_map_fold_{fold+1}.png'
                )
            except Exception as e:
                print(f"‚ö†Ô∏è  Could not generate attention map: {e}")
        else:
            print("‚ö†Ô∏è  No validation samples for attention map")
        
        # ====================================================================
        # 5. ERROR ANALYSIS
        # ====================================================================
        print("\nüîç PERFORMING ERROR ANALYSIS...")
        worst_indices, errors = perform_error_analysis(
            all_targets, all_preds, all_patients, all_tabular, f'error_analysis_fold_{fold+1}.png'
        )
        
        # Store fold results
        fold_results.append({
            'fold': fold + 1,
            'clinical_r2': clinical_r2,
            'clinical_mae': clinical_mae,
            'clinical_lll': clinical_lll,
            'dl_r2': dl_r2,
            'dl_mae': dl_mae,
            'dl_lll': dl_lll,
            'classification_accuracy': classification_results['accuracy'],
            'mean_error': np.mean(errors) if len(errors) > 0 else 0,
            'std_error': np.std(errors) if len(errors) > 0 else 0
        })
        
        print(f"\n‚úÖ FOLD {fold + 1} COMPLETED")
        print(f"Clinical R¬≤: {clinical_r2:.4f}, DL R¬≤: {dl_r2:.4f}")
        print(f"Improvement: {(dl_r2 - clinical_r2):.4f}")
    
    # ========================================================================
    # FINAL SUMMARY AND COMPARISON
    # ========================================================================
    print("\n" + "="*70)
    print("üéØ FINAL COMPREHENSIVE RESULTS")
    print("="*70)
    
    if not fold_results:
        print("‚ùå No results generated from any fold")
        return pd.DataFrame()
        
    # Convert to DataFrame for easy analysis
    results_df = pd.DataFrame(fold_results)
    
    print("\nüìä 5-FOLD CROSS VALIDATION RESULTS:")
    print(results_df.round(4))
    
    print("\nüìà AVERAGE PERFORMANCE ACROSS FOLDS:")
    avg_results = results_df.mean()
    print(avg_results.round(4))
    
    print(f"\nüí™ DEEP LEARNING IMPROVEMENT OVER CLINICAL BASELINE:")
    print(f"R¬≤ Improvement: {avg_results['dl_r2'] - avg_results['clinical_r2']:.4f}")
    print(f"MAE Improvement: {avg_results['clinical_mae'] - avg_results['dl_mae']:.4f}")
    print(f"LLL Improvement: {avg_results['dl_lll'] - avg_results['clinical_lll']:.4f}")
    
    # Create final comparison plot
    plt.figure(figsize=(10, 6))
    x_pos = np.arange(len(results_df))
    width = 0.35
    
    plt.bar(x_pos - width/2, results_df['clinical_r2'], width, label='Clinical Baseline', alpha=0.7)
    plt.bar(x_pos + width/2, results_df['dl_r2'], width, label='DL Model', alpha=0.7)
    
    plt.xlabel('Fold')
    plt.ylabel('R¬≤ Score')
    plt.title('Clinical vs Deep Learning Model Performance (5-Fold CV)')
    plt.xticks(x_pos, [f'Fold {i+1}' for i in range(len(results_df))])
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig('final_comparison.png', dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"\n‚úÖ COMPREHENSIVE ANALYSIS COMPLETED!")
    print("üìÅ Generated files:")
    print("   - calibration_fold_*.png (Calibration plots)")
    print("   - scatter_fold_*.png (Prediction scatter plots)")
    print("   - training_curves_fold_*.png (Training history)")
    print("   - attention_map_fold_*.png (Grad-CAM visualizations)")
    print("   - error_analysis_fold_*.png (Error analysis)")
    print("   - final_comparison.png (Model comparison)")
    
    return results_df

if __name__ == "__main__":
    final_results = run_comprehensive_analysis()

üöÄ PUBLICATION-READY OSIC Model - Comprehensive Analysis
üì± Device: cuda
Loaded dataset with shape: (1549, 7)
Calculating optimized linear decay coefficients...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 176/176 [00:00<00:00, 1151.90it/s]

Processed 176 patients with optimized features
Target statistics: mean=-4.8107, std=6.7150
Target range: [-39.0741, 11.1389]

üìä Setting up 5-Fold Cross Validation...

üéØ STARTING COMPREHENSIVE 5-FOLD ANALYSIS

üìä FOLD 1/5
--------------------------------------------------
Train: 140 patients, Val: 36 patients





Dataset train: 139 patients with images
Dataset val: 35 patients with images

ü©∫ TRAINING CLINICAL BASELINE MODEL...
Clinical Baseline - R¬≤: -0.3138, MAE: 4.8450, RMSE: 6.0912, LLL: -3.2784

üß† TRAINING DEEP LEARNING MODEL...
Epoch 1: LR=1.00e-04, Loss=47.9357
          Train R¬≤=-0.2942, Val R¬≤=-0.2428
          Val MAE=4.2587, RMSE=5.9245, LLL=-4.7148
üéØ NEW BEST! R¬≤: -0.2428
------------------------------------------------------------
Epoch 2: LR=1.00e-04, Loss=34.1029
          Train R¬≤=0.0616, Val R¬≤=-0.0150
          Val MAE=3.8338, RMSE=5.3541, LLL=-4.2742
üéØ NEW BEST! R¬≤: -0.0150
------------------------------------------------------------
Epoch 3: LR=1.00e-04, Loss=31.5362
          Train R¬≤=0.1306, Val R¬≤=-0.2412
          Val MAE=4.4853, RMSE=5.9206, LLL=-4.8216
------------------------------------------------------------
Epoch 4: LR=1.00e-04, Loss=31.2274
          Train R¬≤=0.1356, Val R¬≤=-0.2128
          Val MAE=4.5661, RMSE=5.8524, LLL=-4.8216
---------