In [16]:
# DenseNet V2: Enhanced Medical Imaging Model for OSIC Pulmonary Fibrosis Progression
# Production-Ready Single-Flow Notebook with Advanced Uncertainty Quantification

import os
import cv2
import pydicom
import pandas as pd
import numpy as np 
import matplotlib.pyplot as plt 
import random
from tqdm import tqdm 
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
from PIL import Image
import json
from pathlib import Path
import joblib
import warnings
import pickle

# Albumentations for medical augmentations
import albumentations as albu
from albumentations.pytorch import ToTensorV2

warnings.filterwarnings('ignore')

def seed_everything(seed=42):
    """Ensure reproducibility across all random operations"""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    
seed_everything(42)

# Configuration
DATA_DIR = Path("../input/osic-pulmonary-fibrosis-progression")
TRAIN_DIR = DATA_DIR / "train"
TEST_DIR = DATA_DIR / "test"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("🚀 DenseNet V2 - Enhanced Medical Imaging Model")
print("=" * 60)
print(f"📱 Device: {DEVICE}")
if torch.cuda.is_available():
    print(f"🔥 GPU: {torch.cuda.get_device_name()}")
    print(f"💾 Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
print("=" * 60)


# QUICK RECOVERY: Load Auto-Saved Data (Run this after restarting kernel)
def quick_recovery():
    """Quick recovery of auto-saved data after kernel restart"""
    global train_df, A, TAB, P, train_patients, val_patients
    
    print("🔄 QUICK RECOVERY MODE")
    print("=" * 40)
    
    # Check Kaggle vs local environment
    if os.path.exists('/kaggle/working/auto_save_data'):
        auto_save_dir = "/kaggle/working/auto_save_data"
        print("🐰 Using Kaggle persistent auto-save data")
    elif os.path.exists('auto_save_data'):
        auto_save_dir = "auto_save_data"
        print("🏠 Using local auto-save data")
    else:
        print("❌ No auto-saved data found. Run full notebook first.")
        return False
    
    try:
        # Load core data
        print("📊 Loading core data...")
        train_df = pd.read_csv(f"{auto_save_dir}/train_df_backup.csv")
        
        with open(f"{auto_save_dir}/decay_coefficients_A_backup.pkl", 'rb') as f:
            A = pickle.load(f)
        
        with open(f"{auto_save_dir}/tabular_features_TAB_backup.pkl", 'rb') as f:
            TAB = pickle.load(f)
        
        with open(f"{auto_save_dir}/patient_list_P_backup.pkl", 'rb') as f:
            P = pickle.load(f)
        
        print(f"✅ Loaded: train_df ({train_df.shape}), A ({len(A)}), TAB ({len(TAB)}), P ({len(P)})")
        
        # Load splits if available
        if os.path.exists(f"{auto_save_dir}/train_patients_backup.pkl"):
            print("🔄 Loading train/val splits...")
            
            with open(f"{auto_save_dir}/train_patients_backup.pkl", 'rb') as f:
                train_patients = pickle.load(f)
            
            with open(f"{auto_save_dir}/val_patients_backup.pkl", 'rb') as f:
                val_patients = pickle.load(f)
            
            print(f"✅ Loaded: train_patients ({len(train_patients)}), val_patients ({len(val_patients)})")
        
        # Show metadata
        if os.path.exists(f"{auto_save_dir}/processing_metadata.json"):
            with open(f"{auto_save_dir}/processing_metadata.json", 'r') as f:
                metadata = json.load(f)
            print(f"📅 Data from: {metadata.get('processing_timestamp', 'Unknown')}")
        
        # Load model if available
        if os.path.exists(f"{auto_save_dir}/model_weights_backup.pth"):
            print("🏗️ Loading model...")
            try:
                global model
                model = WorkingDenseNetModel(tabular_dim=4).to(DEVICE)
                model.load_state_dict(torch.load(f"{auto_save_dir}/model_weights_backup.pth", map_location=DEVICE))
                print("✅ Model weights loaded")
            except:
                print("⚠️ Model loading failed (need to run model definition cells first)")
        
        # Show training results if available
        if os.path.exists(f"{auto_save_dir}/training_results_backup.json"):
            with open(f"{auto_save_dir}/training_results_backup.json", 'r') as f:
                results = json.load(f)
            print(f"📈 Previous training: MAE = {results.get('best_val_mae', 'N/A')}")
        
        print("🎉 Quick recovery complete! Core variables restored.")
        print("💡 Tip: If model loading failed, run model definition cells first, then call quick_recovery() again")
        return True
        
    except Exception as e:
        print(f"❌ Recovery failed: {e}")
        return False

print("✅ Quick recovery system ready!")
print("💡 Usage after kernel restart:")
print("   quick_recovery()  # Restore all auto-saved data")
# Uncomment the line below to auto-recover on kernel restart
# quick_recovery()


# Cell 2: Load Data and Create Tabular Features
train_df = pd.read_csv('../input/osic-pulmonary-fibrosis-progression/train.csv')
print(f"Loaded dataset with shape: {train_df.shape}")

def get_tab_features(df_row):
    """Extract tabular features (returns 4 features)"""
    vector = [(df_row['Age'] - 30) / 30] 
    
    # Sex encoding
    if df_row['Sex'] == 'Male':
        vector.append(0)
    else:
        vector.append(1)
    
    # Smoking status encoding
    smoking_status = df_row['SmokingStatus']
    if smoking_status == 'Never smoked':
        vector.extend([0, 0])
    elif smoking_status == 'Ex-smoker':
        vector.extend([1, 1])
    elif smoking_status == 'Currently smokes':
        vector.extend([0, 1])
    else:
        vector.extend([1, 0])
    return np.array(vector)

# Calculate linear decay coefficients for each patient
A = {} 
TAB = {} 
P = [] 

print("Calculating linear decay coefficients...")
for patient in tqdm(train_df['Patient'].unique()):
    sub = train_df[train_df['Patient'] == patient].copy()
    fvc = sub['FVC'].values
    weeks = sub['Weeks'].values
    
    if len(weeks) > 1:
        c = np.vstack([weeks, np.ones(len(weeks))]).T
        try:
            a, b = np.linalg.lstsq(c, fvc, rcond=None)[0]
            A[patient] = a
            TAB[patient] = get_tab_features(sub.iloc[0])
            P.append(patient)
        except:
            # Use fallback method for patients with insufficient data
            A[patient] = (fvc[-1] - fvc[0]) / (weeks[-1] - weeks[0]) if len(weeks) > 1 else 0.0
            TAB[patient] = get_tab_features(sub.iloc[0])
            P.append(patient)
    else:
        A[patient] = 0.0
        TAB[patient] = get_tab_features(sub.iloc[0])
        P.append(patient)

print(f"Processed {len(P)} patients with decay coefficients")


# Auto-Save: Critical Data After Processing
print("💾 Auto-saving critical data...")
import os
import pickle
import json
from datetime import datetime

# Create auto-save directory (Kaggle-aware)
if os.path.exists('/kaggle/working'):
    auto_save_dir = "/kaggle/working/auto_save_data"
    print("🐰 Using Kaggle persistent storage")
else:
    auto_save_dir = "auto_save_data"
    print("🏠 Using local storage")

os.makedirs(auto_save_dir, exist_ok=True)

# Save immediately after processing
try:
    # Save core dataframes and dictionaries
    train_df.to_csv(f"{auto_save_dir}/train_df_backup.csv", index=False)
    
    with open(f"{auto_save_dir}/decay_coefficients_A_backup.pkl", 'wb') as f:
        pickle.dump(A, f)
    
    with open(f"{auto_save_dir}/tabular_features_TAB_backup.pkl", 'wb') as f:
        pickle.dump(TAB, f)
    
    with open(f"{auto_save_dir}/patient_list_P_backup.pkl", 'wb') as f:
        pickle.dump(P, f)
    
    # Save processing metadata
    metadata = {
        'processed_patients': len(P),
        'total_decay_coefficients': len(A),
        'tabular_features_dim': len(list(TAB.values())[0]) if TAB else 0,
        'processing_timestamp': datetime.now().isoformat()
    }
    
    with open(f"{auto_save_dir}/processing_metadata.json", 'w') as f:
        json.dump(metadata, f, indent=2)
    
    print(f"✅ Auto-saved to {auto_save_dir}/")
    print(f"   - train_df_backup.csv")
    print(f"   - decay_coefficients_A_backup.pkl") 
    print(f"   - tabular_features_TAB_backup.pkl")
    print(f"   - patient_list_P_backup.pkl")
    print(f"   - processing_metadata.json")
    
except Exception as e:
    print(f"⚠️ Auto-save failed: {e}")



# Cell 3: Medical-Specific Augmentations (FIXED)
class MedicalAugmentation:
    def __init__(self, augment=True):
        if augment:
            self.transform = albu.Compose([
                # Geometric augmentations
                albu.Rotate(limit=15, p=0.7),
                albu.HorizontalFlip(p=0.5),
                albu.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=15, p=0.7),
                
                # Medical-specific augmentations
                albu.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.7),
                albu.GaussNoise(var_limit=(10.0, 50.0), p=0.5),
                albu.RandomGamma(gamma_limit=(80, 120), p=0.5),
                
                # Lung-specific augmentations for robustness
                albu.GridDistortion(num_steps=5, distort_limit=0.3, p=0.3),
                albu.OpticalDistortion(distort_limit=0.3, shift_limit=0.3, p=0.3),
                
                # Cutout for robustness
                albu.CoarseDropout(max_holes=8, max_height=32, max_width=32, p=0.3),
                
                # Normalization
                albu.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ToTensorV2()
            ])
        else:
            self.transform = albu.Compose([
                albu.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ToTensorV2()
            ])
    
    def __call__(self, image):
        return self.transform(image=image)['image']




# Cell 4: Enhanced DenseNet Model with ALL Improvements
class UltraAdvancedDenseNetModel(nn.Module):
    """
    Ultra-Enhanced DenseNet model with ALL improvements:
    - Multi-scale feature extraction
    - Cross-modal attention
    - Uncertainty quantification
    - Spatial attention
    - Channel attention
    - Feature pyramid network
    """
    
    def __init__(self, tabular_dim=4, dropout_rate=0.5):
        super(UltraAdvancedDenseNetModel, self).__init__()
        
        # DenseNet121 backbone with pretrained weights
        densenet = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
        self.features = densenet.features
        
        # Multi-scale processing branches
        self.scale_branches = nn.ModuleList([
            self._create_scale_branch(kernel_size=7, stride=2, padding=3),  # Large scale
            self._create_scale_branch(kernel_size=5, stride=2, padding=2),  # Medium scale
            self._create_scale_branch(kernel_size=3, stride=2, padding=1),  # Small scale
        ])
        
        # Feature pyramid network
        self.fpn_conv1 = nn.Conv2d(1024, 512, 1)
        self.fpn_conv2 = nn.Conv2d(1024, 512, 1)
        self.fpn_conv3 = nn.Conv2d(1024, 512, 1)
        
        # Spatial attention mechanism
        self.spatial_attention = SpatialAttention()
        
        # Channel attention mechanism
        self.channel_attention = ChannelAttention(512 * 3)
        
        # Cross-modal attention
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=1536, num_heads=12, dropout=0.2, batch_first=True
        )
        
        # Enhanced tabular processing with residual connections
        self.tabular_processor = nn.Sequential(
            nn.Linear(tabular_dim, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 32),
            nn.BatchNorm1d(32),
            nn.ReLU()
        )
        
        # Multi-modal fusion with attention
        self.fusion_attention = nn.MultiheadAttention(
            embed_dim=1568, num_heads=8, dropout=0.1, batch_first=True
        )
        
        self.fusion_layer = nn.Sequential(
            nn.Linear(1568, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout_rate/2)
        )
        
        # Uncertainty quantification heads
        self.mean_head = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1)
        )
        
        self.log_var_head = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
        
        # Initialize weights
        self._initialize_weights()
        
    def _create_scale_branch(self, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                
    def forward(self, images, tabular):
        batch_size = images.size(0)
        
        # Multi-scale processing
        scale_features = []
        for i, branch in enumerate(self.scale_branches):
            if i == 0:
                scale_feat = branch(images)
            elif i == 1:
                downsampled = F.avg_pool2d(images, kernel_size=2)
                scale_feat = branch(downsampled)
                scale_feat = F.interpolate(scale_feat, scale_factor=2, mode='bilinear', align_corners=False)
            else:
                downsampled = F.avg_pool2d(images, kernel_size=4)
                scale_feat = branch(downsampled)
                scale_feat = F.interpolate(scale_feat, scale_factor=4, mode='bilinear', align_corners=False)
            
            scale_features.append(scale_feat)
        
        # Concatenate multi-scale features
        multi_scale = torch.cat(scale_features, dim=1)
        
        # Pass through DenseNet features
        img_features = self.features(multi_scale)
        
        # Apply spatial attention
        img_features = self.spatial_attention(img_features)
        
        # Feature pyramid processing
        fpn1 = self.fpn_conv1(img_features)
        fpn2 = self.fpn_conv2(img_features)
        fpn3 = self.fpn_conv3(img_features)
        
        # Global pooling for each FPN level
        fpn1_pool = F.adaptive_avg_pool2d(fpn1, (1, 1)).view(batch_size, -1)
        fpn2_pool = F.adaptive_avg_pool2d(fpn2, (1, 1)).view(batch_size, -1)
        fpn3_pool = F.adaptive_avg_pool2d(fpn3, (1, 1)).view(batch_size, -1)
        
        # Concatenate FPN features
        fpn_combined = torch.cat([fpn1_pool, fpn2_pool, fpn3_pool], dim=1)
        
        # Apply channel attention
        fpn_combined = self.channel_attention(fpn_combined.unsqueeze(-1).unsqueeze(-1))
        fpn_combined = fpn_combined.view(batch_size, -1)
        
        # Process tabular data
        tab_features = self.tabular_processor(tabular)
        
        # Cross-modal attention
        img_expanded = fpn_combined.unsqueeze(1)
        tab_expanded = tab_features.unsqueeze(1)
        
        # Attention between image and tabular features
        attended_img, _ = self.cross_attention(
            img_expanded, tab_expanded, tab_expanded
        )
        attended_img = attended_img.squeeze(1)
        
        # Fusion with attention
        combined_features = torch.cat([attended_img, tab_features], dim=1)
        combined_expanded = combined_features.unsqueeze(1)
        
        fused_features, _ = self.fusion_attention(
            combined_expanded, combined_expanded, combined_expanded
        )
        fused_features = fused_features.squeeze(1)
        
        # Final fusion
        final_features = self.fusion_layer(fused_features)
        
        # Predict mean and log variance
        mean_pred = self.mean_head(final_features)
        log_var = self.log_var_head(final_features)
        
        return mean_pred.squeeze(), log_var.squeeze()

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x_cat = torch.cat([avg_out, max_out], dim=1)
        x_cat = self.conv1(x_cat)
        return x * self.sigmoid(x_cat)

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
           
        self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
                               nn.ReLU(),
                               nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return x * self.sigmoid(out)
        
# Cell 5: Enhanced Dataset Class (FIXED)
class OSICDenseNetDataset(Dataset):
    """Enhanced dataset with medical augmentations and robust loading"""
    
    def __init__(self, patients, A_dict, TAB_dict, data_dir, split='train', augment=True):
        self.patients = [p for p in patients if p not in ['ID00011637202177653955184', 'ID00052637202186188008618']]
        self.A_dict = A_dict
        self.TAB_dict = TAB_dict
        self.data_dir = Path(data_dir)
        self.split = split
        self.augment = augment
        self.augmentor = MedicalAugmentation(augment=augment)
        
        # Prepare image paths for each patient
        self.patient_images = {}
        for patient in self.patients:
            patient_dir = self.data_dir / patient
            if patient_dir.exists():
                image_files = [f for f in patient_dir.iterdir() if f.suffix.lower() == '.dcm']
                if image_files:
                    self.patient_images[patient] = image_files
        
        # Filter patients with available images
        self.valid_patients = [p for p in self.patients if p in self.patient_images]
        print(f"Dataset {split}: {len(self.valid_patients)} patients with images")
    
    def __len__(self):
        # For training, use multiple samples per patient
        if self.split == 'train':
            return len(self.valid_patients) * 6  # More augmented samples
        else:
            return len(self.valid_patients)
    
    def __getitem__(self, idx):
        if self.split == 'train':
            patient_idx = idx % len(self.valid_patients)
        else:
            patient_idx = idx
            
        patient = self.valid_patients[patient_idx]
        
        # Get random image for this patient
        available_images = self.patient_images[patient]
        if len(available_images) > 1:
            selected_image = np.random.choice(available_images)
        else:
            selected_image = available_images[0]
        
        # Load and preprocess image
        img = self.load_and_preprocess_dicom(selected_image)
        
        # Apply augmentations
        img_tensor = self.augmentor(img)
        
        # Get tabular features
        tab_features = torch.tensor(self.TAB_dict[patient], dtype=torch.float32)
        
        # Get target (decay coefficient)
        target = torch.tensor(self.A_dict[patient], dtype=torch.float32)
        
        return img_tensor, tab_features, target, patient
    
    def load_and_preprocess_dicom(self, path):
        """Enhanced DICOM loading with better preprocessing"""
        try:
            # Load DICOM
            dcm = pydicom.dcmread(str(path))
            img = dcm.pixel_array.astype(np.float32)
            
            # Handle different DICOM formats
            if len(img.shape) == 3:
                img = img[img.shape[0]//2]  # Take middle slice if 3D
            
            # Resize to target size
            img = cv2.resize(img, (512, 512))
            
            # Normalize to 0-255 range
            img_min, img_max = img.min(), img.max()
            if img_max > img_min:
                img = (img - img_min) / (img_max - img_min) * 255
            else:
                img = np.zeros_like(img)
            
            # Convert to 3-channel
            img = np.stack([img, img, img], axis=2).astype(np.uint8)
            
            return img
            
        except Exception as e:
            print(f"Error loading DICOM {path}: {e}")
            # Return a black image as fallback
            return np.zeros((512, 512, 3), dtype=np.uint8)


# Cell 5.1: Define PICP Loss and helper
import torch.nn.functional as F

def picp_loss(y_true, y_lower, y_upper, target_coverage=0.95):
    """
    Prediction Interval Coverage Probability loss.
    Penalizes intervals that cover less than target_coverage of the data.
    """
    # Boolean mask of which targets lie inside the interval
    in_interval = ((y_true >= y_lower) & (y_true <= y_upper)).float()
    picp = in_interval.mean()           # actual coverage
    penalty = F.relu(target_coverage - picp)
    return penalty, picp.item()



# Cell 6: CORRECTED Working Model with Fixed Dimensions
class WorkingDenseNetModel(nn.Module):
    """
    CORRECTED model with proper dimension matching
    """
    
    def __init__(self, tabular_dim=4, dropout_rate=0.4):
        super(WorkingDenseNetModel, self).__init__()
        
        # DenseNet121 backbone
        densenet = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
        self.features = densenet.features
        
        # Spatial attention
        self.spatial_attention = SpatialAttention()
        
        # Enhanced tabular processing
        self.tabular_processor = nn.Sequential(
            nn.Linear(tabular_dim, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 256),  # Increased to 256
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 512),  # Final tabular features: 512
            nn.BatchNorm1d(512),
            nn.ReLU()
        )
        
        # Cross-modal attention (fixed dimensions)
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=1024, num_heads=8, dropout=0.2, batch_first=True
        )
        
        # Multi-modal fusion (corrected input size)
        self.fusion_layer = nn.Sequential(
            nn.Linear(1024 + 512, 768),  # 1024 (img) + 512 (tab) = 1536 -> 768
            nn.BatchNorm1d(768),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(768, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout_rate/2)
        )
        
        # Uncertainty quantification heads
        self.mean_head = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
        
        self.log_var_head = nn.Sequential(
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 1)
        )
        
    def forward(self, images, tabular):
        batch_size = images.size(0)
        
        # Extract image features
        img_features = self.features(images)  # [B, 1024, H, W]
        
        # Apply spatial attention
        img_features = self.spatial_attention(img_features)
        
        # Global average pooling
        img_features = F.adaptive_avg_pool2d(img_features, (1, 1))
        img_features = img_features.view(batch_size, -1)  # [B, 1024]
        
        # Process tabular data
        tab_features = self.tabular_processor(tabular)  # [B, 512]
        
        # Cross-modal attention
        img_expanded = img_features.unsqueeze(1)  # [B, 1, 1024]
        tab_expanded = tab_features.unsqueeze(1)  # [B, 1, 512]
        
        # Project tabular to same dimension for attention
        tab_proj = F.linear(tab_expanded, 
                           torch.randn(1024, 512).to(images.device))  # [B, 1, 1024]
        
        attended_img, _ = self.cross_attention(
            img_expanded, tab_proj, tab_proj
        )
        attended_img = attended_img.squeeze(1)  # [B, 1024]
        
        # Fusion
        combined_features = torch.cat([attended_img, tab_features], dim=1)  # [B, 1536]
        fused_features = self.fusion_layer(combined_features)
        
        # Predict mean and log variance
        mean_pred = self.mean_head(fused_features)
        log_var = self.log_var_head(fused_features)
        
        return mean_pred.squeeze(), log_var.squeeze()

print("✅ CORRECTED Working model defined!")

# Cell 6.5: Data Preparation and Loaders
print("🔄 Creating data loaders...")

# Split patients into train and validation
from sklearn.model_selection import train_test_split

patients_list = list(P)
train_patients, val_patients = train_test_split(
    patients_list, 
    test_size=0.2, 
    random_state=42,
    shuffle=True
)

print(f"Train patients: {len(train_patients)}")
print(f"Validation patients: {len(val_patients)}")

# Create datasets
train_dataset = OSICDenseNetDataset(
    patients=train_patients,
    A_dict=A,
    TAB_dict=TAB,
    data_dir=TRAIN_DIR,
    split='train',
    augment=True
)

val_dataset = OSICDenseNetDataset(
    patients=val_patients,
    A_dict=A,
    TAB_dict=TAB,
    data_dir=TRAIN_DIR,
    split='val',
    augment=False
)

# Create data loaders
# Cell 6.5: DataLoaders (Optimized)
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=6,
    pin_memory=True,
    drop_last=True,
    persistent_workers=True
)
val_loader = DataLoader(
    val_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    drop_last=False,
    persistent_workers=True
)

print(f"✅ Data loaders created!")
print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")





# Cell 7: Final Trainer with AMP, LLL, Full Metrics -------------------

from torch.cuda.amp import autocast, GradScaler


# Cell 7: Enhanced Trainer with PICP integrated

class PicpTrainer:
    def __init__(self, model, device, lr=1e-4, lambda_picp=0.5):
        self.model = model.to(device)
        self.device = device
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.scaler = torch.cuda.amp.GradScaler()
        self.lambda_picp = lambda_picp

    def laplace_nll(self, y_true, y_pred, log_var):
        sigma = torch.exp(log_var / 2.0)
        sigma = torch.clamp(sigma, min=1.0, max=500.0)
        abs_errors = torch.abs(y_true - y_pred)
        log_likelihood = -torch.log(2.0 * sigma) - abs_errors / sigma
        return -log_likelihood.mean()  # <--- do NOT use .item() here

    def picp_loss(self, y_true, y_pred, log_var):
        sigma = torch.exp(log_var / 2.0)
        lower = y_pred - 1.96 * sigma
        upper = y_pred + 1.96 * sigma
        inside = ((y_true >= lower) & (y_true <= upper)).float()
        picp = inside.mean()
        return torch.abs(picp - 0.95)  # Encourages PICP ~ 0.95

    def train(self, train_loader, val_loader, epochs=30, patience=5):
        best_val_mae = float('inf')
        patience_counter = 0

        # 💡 Ensure all model params require grad
        for p in self.model.parameters():
            p.requires_grad = True

        for epoch in range(epochs):
            self.model.train()
            train_loss = 0.0

            for batch in train_loader:
                self.optimizer.zero_grad()
                with torch.cuda.amp.autocast():
                    images, targets = batch
                    images = images.to(self.device)
                    targets = targets.to(self.device)

                    outputs = self.model(images)
                    preds, log_var = outputs[:, 0], outputs[:, 1]

                    nll = self.laplace_nll(targets, preds, log_var)
                    p_loss = self.picp_loss(targets, preds, log_var)
                    loss = nll + self.lambda_picp * p_loss

                if not loss.requires_grad:
                    print("❌ Loss has no grad!")
                    continue

                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()

                train_loss += loss.item()

            # 💡 You can log loss here
            print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}")

        return best_val_mae



class SimpleTrainer:
    def __init__(self, model, device, lr=1e-4):
        self.model = model
        self.device = device
        self.lr = lr
        self.best_val_mae = float('inf')
        self.best_val_lll = float('inf')
        self.scaler = GradScaler()

    def uncertainty_loss(self, mean_pred, log_var, targets, reduction='mean'):
        var = torch.exp(log_var)
        mse_loss = (mean_pred - targets) ** 2
        uncertainty_penalty = torch.mean(torch.abs(log_var - torch.log(mse_loss + 1e-6)))
        loss = 0.5 * (mse_loss / var + log_var) + 0.1 * uncertainty_penalty
        return loss.mean() if reduction == 'mean' else loss.sum()

    def laplace_log_likelihood(self, y_true, y_pred, log_var):
        sigma = torch.exp(log_var / 2.0)
        sigma = torch.clamp(sigma, min=5.0, max=500.0)
        abs_errors = torch.abs(y_true - y_pred)
        log_likelihood = -torch.log(2.0 * sigma) - abs_errors / sigma
        return torch.mean(log_likelihood)

    def train(self, train_loader, val_loader, epochs=30, patience=8):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=4, verbose=True)
        patience_counter = 0

        for epoch in range(epochs):
            self.model.train()
            train_loss, train_mae, train_lll, train_batches = 0.0, 0.0, 0.0, 0

            for batch_idx, (images, tabular, targets, _) in enumerate(train_loader):
                images = images.to(self.device, non_blocking=True)
                tabular = tabular.to(self.device, non_blocking=True)
                targets = targets.to(self.device, non_blocking=True)

                optimizer.zero_grad()
                with autocast():
                    mean_pred, log_var = self.model(images, tabular)
                    loss = self.uncertainty_loss(mean_pred, log_var, targets)
                    mae = F.l1_loss(mean_pred, targets)
                    lll = self.laplace_log_likelihood(targets, mean_pred, log_var)

                self.scaler.scale(loss).backward()
                self.scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.scaler.step(optimizer)
                self.scaler.update()

                train_loss += loss.item()
                train_mae += mae.item()
                train_lll += lll.item()
                train_batches += 1

            self.model.eval()
            val_loss, val_mae, val_lll = 0.0, 0.0, 0.0
            val_predictions, val_targets, val_sigmas = [], [], []

            with torch.no_grad():
                for batch_idx, (images, tabular, targets, _) in enumerate(val_loader):
                    images = images.to(self.device)
                    tabular = tabular.to(self.device)
                    targets = targets.to(self.device)

                    with autocast():
                        mean_pred, log_var = self.model(images, tabular)
                        loss = self.uncertainty_loss(mean_pred, log_var, targets)
                        mae = F.l1_loss(mean_pred, targets)
                        lll = self.laplace_log_likelihood(targets, mean_pred, log_var)

                    sigma = torch.exp(log_var / 2.0)
                    sigma = torch.clamp(sigma, min=50.0, max=500.0)

                    val_loss += loss.item()
                    val_mae += mae.item()
                    val_lll += lll.item()
                    val_predictions.extend(mean_pred.cpu().numpy())
                    val_targets.extend(targets.cpu().numpy())
                    val_sigmas.extend(sigma.cpu().numpy())

            # Metrics and logging
            if train_batches > 0 and len(val_predictions) > 0:
                avg_train_loss = train_loss / train_batches
                avg_train_mae = train_mae / train_batches
                avg_train_lll = train_lll / train_batches

                avg_val_loss = val_loss / len(val_loader)
                avg_val_mae = val_mae / len(val_loader)
                avg_val_lll = val_lll / len(val_loader)

                val_predictions = np.array(val_predictions)
                val_targets = np.array(val_targets)
                val_sigmas = np.array(val_sigmas)

                val_rmse = np.sqrt(np.mean((val_predictions - val_targets) ** 2))
                ss_res = np.sum((val_targets - val_predictions) ** 2)
                ss_tot = np.sum((val_targets - np.mean(val_targets)) ** 2)
                r2 = 1 - (ss_res / ss_tot) if ss_tot != 0 else -float('inf')
                avg_sigma = np.mean(val_sigmas)

                print(f"\nEpoch {epoch+1}/{epochs}")
                print(f"Train Loss: {avg_train_loss:.6f} | Train LLL: {avg_train_lll:.6f} | Train MAE: {avg_train_mae:.6f}")
                print(f"Val Loss: {avg_val_loss:.6f} | Val LLL: {avg_val_lll:.6f} | MAE: {avg_val_mae:.6f} | RMSE: {val_rmse:.6f} | R²: {r2:.6f} | Sigma: {avg_sigma:.2f}")

                scheduler.step(avg_val_mae)

                if avg_val_lll < self.best_val_lll:
                    self.best_val_lll = avg_val_lll
                    self.best_val_mae = avg_val_mae
                    torch.save(self.model.state_dict(), 'best_working_model.pth')
                    print("✅ New best model saved! (Best Laplace Log Likelihood)")
                    patience_counter = 0
                else:
                    patience_counter += 1
                    if patience_counter >= patience:
                        print(f"⏹️ Early stopping at epoch {epoch+1}")
                        break
                print("-" * 80)
        return self.best_val_mae

print("✅ Enhanced trainer with AMP + LLL defined!")









# CORRECTED SimpleTrainer with FIXED sigma learning
class CorrectedSimpleTrainer:
    """
    CORRECTED trainer with proper sigma learning and debugging
    """
    
    def __init__(self, model, device, lr=1e-4):
        self.model = model
        self.device = device
        self.lr = lr
        self.best_val_mae = float('inf')
        self.best_val_lll = float('inf')  # For actual log-likelihood (negative values)
        
    def uncertainty_loss(self, mean_pred, log_var, targets, reduction='mean'):
        """Uncertainty-aware loss function"""
        var = torch.exp(log_var)
        mse_loss = (mean_pred - targets) ** 2
        
        # Add penalty for poor uncertainty estimation
        uncertainty_penalty = torch.mean(torch.abs(log_var - torch.log(mse_loss + 1e-6)))
        
        loss = 0.5 * (mse_loss / var + log_var) + 0.05 * uncertainty_penalty
        
        if reduction == 'mean':
            return loss.mean()
        return loss.sum()
        
    def laplace_log_likelihood(self, y_true, y_pred, log_var):
        """
        Calculate ACTUAL Laplace Log Likelihood
        Returns actual log-likelihood (negative values, higher is better)
        """
        # Convert log variance to standard deviation (sigma)
        sigma = torch.exp(log_var / 2.0)
        
        # Much smaller bounds - allow dynamic learning!
        sigma = torch.clamp(sigma, min=2.0, max=200.0)
        
        abs_errors = torch.abs(y_true - y_pred)
        
        # ACTUAL log-likelihood: log(1/(2σ)) - |y-μ|/σ
        # = -log(2σ) - |y-μ|/σ
        log_likelihood = -torch.log(2.0 * sigma) - abs_errors / sigma
        
        # Return ACTUAL log-likelihood (negative values, higher is better)
        return torch.mean(log_likelihood)
        
    def train(self, train_loader, val_loader, epochs=30, patience=8):
        # Better optimizer settings for uncertainty learning
        optimizer = torch.optim.AdamW(
            self.model.parameters(), 
            lr=5e-5,  # Lower learning rate for better uncertainty learning
            weight_decay=1e-5
        )
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=0.5, patience=4, verbose=True  # mode='max' for log-likelihood
        )
        
        patience_counter = 0
        
        for epoch in range(epochs):
            # Training phase
            self.model.train()
            train_loss = 0.0
            train_mae = 0.0
            train_lll = 0.0
            train_batches = 0
            
            for batch_idx, (images, tabular, targets, _) in enumerate(train_loader):
                try:
                    images = images.to(self.device)
                    tabular = tabular.to(self.device) 
                    targets = targets.to(self.device)
                    
                    optimizer.zero_grad()
                    
                    # Forward pass
                    mean_pred, log_var = self.model(images, tabular)
                    
                    # Calculate losses and metrics
                    loss = self.uncertainty_loss(mean_pred, log_var, targets)
                    mae = F.l1_loss(mean_pred, targets)
                    lll = self.laplace_log_likelihood(targets, mean_pred, log_var)
                    
                    # Backward pass
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                    optimizer.step()
                    
                    train_loss += loss.item()
                    train_mae += mae.item()
                    train_lll += lll.item()
                    train_batches += 1
                    
                except Exception as e:
                    print(f"Error in training batch {batch_idx}: {e}")
                    continue
            
            # Validation phase
            self.model.eval()
            val_loss = 0.0
            val_mae = 0.0
            val_lll = 0.0
            val_predictions = []
            val_targets = []
            val_sigmas = []
            val_log_vars = []
            
            with torch.no_grad():
                for batch_idx, (images, tabular, targets, _) in enumerate(val_loader):
                    try:
                        images = images.to(self.device)
                        tabular = tabular.to(self.device)
                        targets = targets.to(self.device)
                        
                        mean_pred, log_var = self.model(images, tabular)
                        
                        # Calculate all metrics
                        loss = self.uncertainty_loss(mean_pred, log_var, targets)
                        mae = F.l1_loss(mean_pred, targets)
                        lll = self.laplace_log_likelihood(targets, mean_pred, log_var)
                        
                        # Calculate sigma values for debugging
                        sigma = torch.exp(log_var / 2.0)
                        sigma = torch.clamp(sigma, min=2.0, max=200.0)
                        
                        val_loss += loss.item()
                        val_mae += mae.item()
                        val_lll += lll.item()
                        
                        val_predictions.extend(mean_pred.cpu().numpy())
                        val_targets.extend(targets.cpu().numpy())
                        val_sigmas.extend(sigma.cpu().numpy())
                        val_log_vars.extend(log_var.cpu().numpy())
                        
                    except Exception as e:
                        print(f"Error in validation batch {batch_idx}: {e}")
                        continue
            
            # Calculate comprehensive metrics
            if train_batches > 0 and len(val_predictions) > 0:
                # Average training metrics
                avg_train_loss = train_loss / train_batches
                avg_train_mae = train_mae / train_batches
                avg_train_lll = train_lll / train_batches
                
                # Average validation metrics
                avg_val_loss = val_loss / len(val_loader)
                avg_val_mae = val_mae / len(val_loader)
                avg_val_lll = val_lll / len(val_loader)
                
                # Convert to numpy arrays for additional metrics
                val_predictions = np.array(val_predictions)
                val_targets = np.array(val_targets)
                val_sigmas = np.array(val_sigmas)
                val_log_vars = np.array(val_log_vars)
                
                # Calculate RMSE
                val_rmse = np.sqrt(np.mean((val_predictions - val_targets) ** 2))
                
                # Calculate R²
                ss_res = np.sum((val_targets - val_predictions) ** 2)
                ss_tot = np.sum((val_targets - np.mean(val_targets)) ** 2)
                r2 = 1 - (ss_res / ss_tot) if ss_tot != 0 else -float('inf')
                
                # Sigma statistics
                avg_sigma = np.mean(val_sigmas)
                min_sigma = np.min(val_sigmas)
                max_sigma = np.max(val_sigmas)
                
                # Log variance statistics for debugging
                avg_log_var = np.mean(val_log_vars)
                
                # Enhanced printing with all metrics and debugging info
                print(f"Epoch {epoch+1}/{epochs}")
                print(f"Train Loss: {avg_train_loss:.6f} | Train LLL: {avg_train_lll:.6f} | Train MAE: {avg_train_mae:.6f}")
                print(f"Val Loss: {avg_val_loss:.6f} | Val Laplace Log Likelihood: {avg_val_lll:.6f} | MAE: {avg_val_mae:.6f} | RMSE: {val_rmse:.6f} | R²: {r2:.6f}")
                print(f"Sigma Stats: Avg={avg_sigma:.2f}, Range=[{min_sigma:.2f}, {max_sigma:.2f}] | Avg Log-Var: {avg_log_var:.4f}")
                
                # Learning rate scheduling (using log-likelihood now)
                scheduler.step(avg_val_lll)
                
                # Early stopping and model saving (using actual LLL - higher is better)
                if avg_val_lll > self.best_val_lll:  # Higher log-likelihood is better
                    self.best_val_lll = avg_val_lll
                    self.best_val_mae = avg_val_mae
                    torch.save(self.model.state_dict(), 'best_corrected_model.pth')
                    print("✅ New best model saved! (Best Actual Log Likelihood)")
                    patience_counter = 0
                else:
                    patience_counter += 1
                    
                if patience_counter >= patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    break
                    
                print("-" * 100)
        
        return self.best_val_mae

print("✅ CORRECTED SimpleTrainer defined with proper sigma learning!")


# Test and Retrain with CORRECTED Trainer
print("🔧 Testing CORRECTED trainer with proper sigma learning...")

# Load the best model if available
if 'model' in globals():
    try:
        if os.path.exists('best_working_model.pth'):
            model.load_state_dict(torch.load('best_working_model.pth'))
            print("✅ Loaded best_working_model.pth as starting point")
        else:
            print("⚠️ Starting with current model weights")
        
        # Create CORRECTED trainer
        corrected_trainer = CorrectedSimpleTrainer(model, DEVICE, lr=5e-5)
        
        # Test forward pass first
        print("🔍 Testing model with corrected trainer...")
        test_batch = next(iter(train_loader))
        images, tabular, targets, _ = test_batch
        images = images.to(DEVICE)
        tabular = tabular.to(DEVICE)
        targets = targets.to(DEVICE)
        
        with torch.no_grad():
            mean_pred, log_var = model(images, tabular)
            sigma = torch.exp(log_var / 2.0)
            sigma_clamped = torch.clamp(sigma, min=2.0, max=200.0)
            lll = corrected_trainer.laplace_log_likelihood(targets, mean_pred, log_var)
        
        print(f"✅ Test Results:")
        print(f"   Raw log_var: {log_var[:5].detach().cpu().numpy()}")
        print(f"   Raw sigma: {sigma[:5].detach().cpu().numpy()}")
        print(f"   Clamped sigma: {sigma_clamped[:5].detach().cpu().numpy()}")
        print(f"   Laplace Log Likelihood: {lll.item():.6f}")
        
        print("\n🚀 Starting CORRECTED training...")
        print("Expected improvements:")
        print("   - Dynamic sigma values (not fixed at 50.0)")
        print("   - Negative LLL values (better uncertainty)")
        print("   - Progress toward LLL = -6.0")
        print("=" * 60)
        
        # Start corrected training
        best_val_mae_corrected = corrected_trainer.train(
            train_loader, 
            val_loader, 
            epochs=25, 
            patience=8
        )
        
        print(f"🎯 CORRECTED training completed!")
        print(f"   Best validation MAE: {best_val_mae_corrected:.6f}")
        print(f"   Best validation LLL: {corrected_trainer.best_val_lll:.6f}")
        
    except Exception as e:
        print(f"❌ Error in corrected training: {e}")
        import traceback
        traceback.print_exc()
        
else:
    print("❌ No model found. Run previous cells first!")










# Cell 8: Initialize Corrected Model and Test
print("🔄 Replacing with CORRECTED working model...")

# Delete old model
if 'model' in globals():
    del model
torch.cuda.empty_cache()

# Initialize corrected model
model = WorkingDenseNetModel(tabular_dim=4).to(DEVICE)
print(f"✅ Corrected model initialized!")
print(f"📊 Total parameters: {sum(p.numel() for p in model.parameters()):,}")

# Test model with actual batch
try:
    if 'train_loader' in globals():
        test_batch = next(iter(train_loader))
        images, tabular, targets, _ = test_batch
        images = images.to(DEVICE)
        tabular = tabular.to(DEVICE)
        
        print(f"🔍 Input shapes:")
        print(f"   Images: {images.shape}")
        print(f"   Tabular: {tabular.shape}")
        
        with torch.no_grad():
            mean_pred, log_var = model(images, tabular)
            print(f"✅ Model forward pass successful!")
            print(f"   Mean prediction: {mean_pred.shape} - {mean_pred[:3]}")
            print(f"   Log variance: {log_var.shape} - {log_var[:3]}")
    else:
        # Create a dummy test if data loaders aren't available
        print("⚠️ Data loaders not found, creating dummy test...")
        dummy_images = torch.randn(2, 3, 512, 512).to(DEVICE)
        dummy_tabular = torch.randn(2, 4).to(DEVICE)
        
        with torch.no_grad():
            mean_pred, log_var = model(dummy_images, dummy_tabular)
            print(f"✅ Model forward pass successful with dummy data!")
            print(f"   Mean prediction: {mean_pred.shape} - {mean_pred}")
            print(f"   Log variance: {log_var.shape} - {log_var}")
            
except Exception as e:
    print(f"❌ Model test failed: {e}")
    import traceback
    traceback.print_exc()



# Cell 9: Start Training with PICP-aware Trainer
# Cell: Start Training
if 'model' in globals():
    print("🚀 Starting training with CORRECTED model...")
    
    trainer = PicpTrainer(model, DEVICE, lr=1e-4, lambda_picp=0.5)

    best_val_mae = trainer.train(
        train_loader, 
        val_loader, 
        epochs=30,
        patience=8
    )

    print(f"🎯 Training completed! Best validation MAE: {best_val_mae:.6f}")
else:
    print("❌ No model found. Run previous cells first!")



🚀 DenseNet V2 - Enhanced Medical Imaging Model
📱 Device: cuda
🔥 GPU: Tesla P100-PCIE-16GB
💾 Memory: 17.1 GB
✅ Quick recovery system ready!
💡 Usage after kernel restart:
   quick_recovery()  # Restore all auto-saved data
Loaded dataset with shape: (1549, 7)
Calculating linear decay coefficients...


100%|██████████| 176/176 [00:00<00:00, 1492.36it/s]

Processed 176 patients with decay coefficients
💾 Auto-saving critical data...
🐰 Using Kaggle persistent storage
✅ Auto-saved to /kaggle/working/auto_save_data/
   - train_df_backup.csv
   - decay_coefficients_A_backup.pkl
   - tabular_features_TAB_backup.pkl
   - patient_list_P_backup.pkl
   - processing_metadata.json
✅ CORRECTED Working model defined!
🔄 Creating data loaders...
Train patients: 140
Validation patients: 36





Dataset train: 138 patients with images
Dataset val: 36 patients with images
✅ Data loaders created!
Train batches: 103, Val batches: 5
✅ Enhanced trainer with AMP + LLL defined!
✅ CORRECTED SimpleTrainer defined with proper sigma learning!
🔧 Testing CORRECTED trainer with proper sigma learning...
✅ Loaded best_working_model.pth as starting point
🔍 Testing model with corrected trainer...
✅ Test Results:
   Raw log_var: [1.4879323  0.8238814  1.240476   1.0700326  0.96476096]
   Raw sigma: [2.1042647 1.509745  1.8593705 1.7074761 1.619926 ]
   Clamped sigma: [2.1042647 2.        2.        2.        2.       ]
   Laplace Log Likelihood: -3.345077

🚀 Starting CORRECTED training...
Expected improvements:
   - Dynamic sigma values (not fixed at 50.0)
   - Negative LLL values (better uncertainty)
   - Progress toward LLL = -6.0
Error in training batch 0: element 0 of tensors does not require grad and does not have a grad_fn
Error in training batch 1: element 0 of tensors does not require gra

KeyboardInterrupt: 

In [None]:
# Cell 10: Training Execution
print("Starting Progressive Training...")
best_val_mae = trainer.train(
    train_loader, 
    val_loader, 
    epochs=40, 
    patience=10
)




# Auto-Save: Model Training Results
print("💾 Auto-saving model training results...")

try:
    # Save training results
    if 'best_val_mae' in globals():
        training_results = {
            'best_val_mae': float(best_val_mae),
            'training_completed': True,
            'training_timestamp': datetime.now().isoformat(),
            'device_used': str(DEVICE),
            'model_parameters': sum(p.numel() for p in model.parameters()) if 'model' in globals() else 0
        }
        
        with open(f"{auto_save_dir}/training_results_backup.json", 'w') as f:
            json.dump(training_results, f, indent=2)
        
        print(f"✅ Training results saved: MAE = {best_val_mae:.6f}")
    
    # Auto-save model weights if training completed
    if 'model' in globals():
        torch.save(model.state_dict(), f"{auto_save_dir}/model_weights_backup.pth")
        print("✅ Model weights auto-saved")
    
    # Save trainer state if available
    if 'trainer' in globals():
        trainer_state = {
            'lr': trainer.lr,
            'best_val_mae': float(trainer.best_val_mae) if hasattr(trainer, 'best_val_mae') else None,
            'trainer_class': 'SimpleTrainer'
        }
        
        with open(f"{auto_save_dir}/trainer_state_backup.json", 'w') as f:
            json.dump(trainer_state, f, indent=2)
        
        print("✅ Trainer state auto-saved")
    
except Exception as e:
    print(f"⚠️ Training auto-save failed: {e}")



# Cell 11: TTAPredictor for Enhanced Inference
class TTAPredictor:
    def __init__(self, model, num_augmentations=5):
        self.model = model
        self.num_augmentations = num_augmentations
        self.augmentor = MedicalAugmentation(augment=True)
        self.model.eval()
    
    def predict(self, image, tabular):
        # Original prediction
        with torch.no_grad():
            mean_pred, log_var = self.model(image.unsqueeze(0), tabular.unsqueeze(0))
            mean_preds = [mean_pred.item()]
            log_vars = [log_var.item()]
        
        # Augmented predictions
        for _ in range(self.num_augmentations):
            try:
                # Apply augmentation
                aug_img = self.augmentor(image.permute(1, 2, 0).numpy().astype(np.uint8))
                aug_img = aug_img.to(DEVICE)
                
                # Predict
                with torch.no_grad():
                    mean_pred, log_var = self.model(aug_img.unsqueeze(0), tabular.unsqueeze(0))
                    mean_preds.append(mean_pred.item())
                    log_vars.append(log_var.item())
                    
            except Exception as e:
                print(f"Error in TTA: {e}")
                continue
        
        # Ensemble predictions
        mean_ensemble = np.median(mean_preds)
        log_var_ensemble = np.median(log_vars)
        
        # Calculate uncertainty (standard deviation)
        std = np.sqrt(np.exp(log_var_ensemble))
        
        return mean_ensemble, std
# Cell 13: Option 1 - Quick Confidence Head (Recommended - 5-10 minutes)
print("🔧 Adding Confidence Estimation to Existing Model...")

class ConfidenceHead(nn.Module):
    """Simple confidence estimation head"""
    def __init__(self, input_dim=256):
        super(ConfidenceHead, self).__init__()
        self.confidence_head = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Softplus()  # Ensures positive confidence values
        )
    
    def forward(self, features):
        return self.confidence_head(features)

class ModelWithConfidence(nn.Module):
    """Wrapper to add confidence to your existing model"""
    def __init__(self, base_model):
        super(ModelWithConfidence, self).__init__()
        self.base_model = base_model
        
        # Freeze the base model
        for param in self.base_model.parameters():
            param.requires_grad = False
        
        # Add confidence head (takes features before final prediction)
        self.confidence_head = ConfidenceHead(input_dim=256)  # Adjust based on your model
        
    def forward(self, images, tabular):
        # Get features from your trained model (before final prediction)
        with torch.no_grad():
            # Access the fusion layer output from your model
            batch_size = images.size(0)
            img_features = self.base_model.features(images)
            img_features = self.base_model.spatial_attention(img_features)
            img_features = F.adaptive_avg_pool2d(img_features, (1, 1)).view(batch_size, -1)
            tab_features = self.base_model.tabular_processor(tabular)
            
            # Get the fusion features (this is what we'll use for confidence)
            combined_features = torch.cat([img_features, tab_features], dim=1)
            fusion_features = self.base_model.fusion_layer(combined_features)
        
        # Get original FVC prediction
        mean_pred, log_var = self.base_model(images, tabular)
        
        # Predict confidence using the fusion features
        confidence = self.confidence_head(fusion_features.detach())
        
        return mean_pred, confidence.squeeze()

print("✅ ConfidenceHead and ModelWithConfidence classes defined!")

# Cell 14: Quick Confidence Trainer
class ConfidenceTrainer:
    """Quick trainer for confidence head only"""
    def __init__(self, model, device):
        self.model = model
        self.device = device
        
    def confidence_loss(self, fvc_pred, confidence, targets):
        """Loss that encourages reasonable confidence intervals"""
        mse_loss = F.mse_loss(fvc_pred, targets)
        
        # Penalty for overconfident predictions (small confidence with large error)
        errors = torch.abs(fvc_pred - targets)
        confidence_penalty = torch.mean(errors / (confidence + 1e-6))
        
        # Penalty for underconfident predictions (very large confidence)
        overconfidence_penalty = torch.mean(confidence) * 0.1
        
        return mse_loss + confidence_penalty + overconfidence_penalty
    
    def train_confidence(self, train_loader, val_loader, epochs=10):
        """Train only the confidence head - FAST!"""
        # Only train the confidence head
        optimizer = torch.optim.Adam(self.model.confidence_head.parameters(), lr=1e-3)
        
        print(f"🚀 Training confidence head for {epochs} epochs...")
        
        for epoch in range(epochs):
            self.model.train()
            train_loss = 0
            
            for batch_idx, (images, tabular, targets, _) in enumerate(train_loader):
                try:
                    images, tabular, targets = images.to(self.device), tabular.to(self.device), targets.to(self.device)
                    
                    optimizer.zero_grad()
                    fvc_pred, confidence = self.model(images, tabular)
                    loss = self.confidence_loss(fvc_pred, confidence, targets)
                    loss.backward()
                    optimizer.step()
                    
                    train_loss += loss.item()
                except Exception as e:
                    print(f"Error in batch {batch_idx}: {e}")
                    continue
            
            # Validation
            self.model.eval()
            val_loss = 0
            val_confidences = []
            
            with torch.no_grad():
                for images, tabular, targets, _ in val_loader:
                    try:
                        images, tabular, targets = images.to(self.device), tabular.to(self.device), targets.to(self.device)
                        fvc_pred, confidence = self.model(images, tabular)
                        loss = self.confidence_loss(fvc_pred, confidence, targets)
                        val_loss += loss.item()
                        val_confidences.extend(confidence.cpu().numpy())
                    except Exception as e:
                        continue
            
            avg_train_loss = train_loss / len(train_loader)
            avg_val_loss = val_loss / len(val_loader)
            avg_confidence = np.mean(val_confidences)
            
            print(f"Epoch {epoch+1}/{epochs}: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Avg Confidence: {avg_confidence:.2f}")
        
        torch.save(self.model.state_dict(), 'model_with_confidence.pth')
        print("✅ Confidence model saved to 'model_with_confidence.pth'!")
        
        return avg_val_loss

print("✅ ConfidenceTrainer class defined!")



# Cell 15: Option 2 - Quantile Regression Head (More Advanced)
print("📊 Defining Quantile Regression approach...")

class QuantileRegressionHead(nn.Module):
    """Quantile regression for confidence intervals"""
    def __init__(self, input_dim=256):
        super(QuantileRegressionHead, self).__init__()
        self.lower_quantile = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
        self.upper_quantile = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
    
    def forward(self, features):
        lower = self.lower_quantile(features)
        upper = self.upper_quantile(features)
        return lower.squeeze(), upper.squeeze()

def quantile_loss(predictions, targets, quantile):
    """Quantile regression loss"""
    errors = targets - predictions
    return torch.mean(torch.max(quantile * errors, (quantile - 1) * errors))

class ModelWithQuantiles(nn.Module):
    """Model with quantile regression for confidence intervals"""
    def __init__(self, base_model):
        super(ModelWithQuantiles, self).__init__()
        self.base_model = base_model
        
        # Freeze base model
        for param in self.base_model.parameters():
            param.requires_grad = False
            
        self.quantile_head = QuantileRegressionHead()
    
    def forward(self, images, tabular):
        # Get fusion features (same as confidence model)
        batch_size = images.size(0)
        with torch.no_grad():
            img_features = self.base_model.features(images)
            img_features = self.base_model.spatial_attention(img_features)
            img_features = F.adaptive_avg_pool2d(img_features, (1, 1)).view(batch_size, -1)
            tab_features = self.base_model.tabular_processor(tabular)
            combined_features = torch.cat([img_features, tab_features], dim=1)
            fusion_features = self.base_model.fusion_layer(combined_features)
        
        # Original prediction
        mean_pred, _ = self.base_model(images, tabular)
        
        # Quantile predictions
        lower_pred, upper_pred = self.quantile_head(fusion_features.detach())
        
        # Calculate confidence and final FVC
        confidence = upper_pred - lower_pred
        final_fvc = (lower_pred + upper_pred) / 2
        
        return final_fvc, confidence, lower_pred, upper_pred

class QuantileTrainer:
    """Trainer for quantile regression model"""
    def __init__(self, model, device):
        self.model = model
        self.device = device
    
    def train_quantiles(self, train_loader, val_loader, epochs=10):
        optimizer = torch.optim.Adam(self.model.quantile_head.parameters(), lr=1e-3)
        
        print(f"🚀 Training quantile regression for {epochs} epochs...")
        
        best_val_loss = float('inf')
        
        for epoch in range(epochs):
            self.model.train()
            train_loss = 0
            
            for batch_idx, (images, tabular, targets, _) in enumerate(train_loader):
                try:
                    images, tabular, targets = images.to(self.device), tabular.to(self.device), targets.to(self.device)
                    
                    optimizer.zero_grad()
                    final_fvc, confidence, lower_pred, upper_pred = self.model(images, tabular)
                    
                    # Quantile losses
                    lower_loss = quantile_loss(lower_pred, targets, 0.2)  # 20th percentile
                    upper_loss = quantile_loss(upper_pred, targets, 0.8)  # 80th percentile
                    
                    # Combined loss
                    loss = lower_loss + upper_loss + F.mse_loss(final_fvc, targets)
                    
                    loss.backward()
                    optimizer.step()
                    
                    train_loss += loss.item()
                except Exception as e:
                    continue
            
            # Validation
            self.model.eval()
            val_loss = 0
            with torch.no_grad():
                for images, tabular, targets, _ in val_loader:
                    try:
                        images, tabular, targets = images.to(self.device), tabular.to(self.device), targets.to(self.device)
                        final_fvc, confidence, lower_pred, upper_pred = self.model(images, tabular)
                        
                        lower_loss = quantile_loss(lower_pred, targets, 0.2)
                        upper_loss = quantile_loss(upper_pred, targets, 0.8)
                        loss = lower_loss + upper_loss + F.mse_loss(final_fvc, targets)
                        val_loss += loss.item()
                    except Exception as e:
                        continue
            
            avg_val_loss = val_loss / len(val_loader)
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
            
            print(f"Epoch {epoch+1}/{epochs}: Train Loss: {train_loss/len(train_loader):.4f}, Val Loss: {avg_val_loss:.4f}")
        
        torch.save(self.model.state_dict(), 'quantile_model.pth')
        print("✅ Quantile model saved!")
        
        return best_val_loss

print("✅ Quantile Regression classes defined!")


# Cell 16: Execute Confidence Training
print("🚀 Setting up and training confidence estimation...")

# Check if we have a trained model
if 'model' in globals() and 'train_loader' in globals():
    try:
        # Load the best model weights if available
        import os
        if os.path.exists('best_working_model.pth'):
            model.load_state_dict(torch.load('best_working_model.pth'))
            print("✅ Loaded best_working_model.pth")
        elif os.path.exists('best_densenet_model.pth'):
            model.load_state_dict(torch.load('best_densenet_model.pth'))
            print("✅ Loaded best_densenet_model.pth")
        else:
            print("⚠️ No saved model found, using current model weights")
        
        model.eval()
        
        # Create model with confidence (Option 1 - Recommended)
        print("🔧 Creating ModelWithConfidence...")
        confidence_model = ModelWithConfidence(model).to(DEVICE)
        
        # Train confidence head (this is FAST - only 10 epochs!)
        print("⚡ Training confidence head (Option 1)...")
        conf_trainer = ConfidenceTrainer(confidence_model, DEVICE)
        conf_val_loss = conf_trainer.train_confidence(train_loader, val_loader, epochs=10)
        
        print("✅ Confidence training completed!")
        
        # Optional: Also create quantile model for comparison
        print("📊 Creating quantile model (Option 2)...")
        quantile_model = ModelWithQuantiles(model).to(DEVICE)
        
        print("⚡ Training quantile regression...")
        quant_trainer = QuantileTrainer(quantile_model, DEVICE)
        quant_val_loss = quant_trainer.train_quantiles(train_loader, val_loader, epochs=8)
        
        # Compare models and save the better one
        print(f"\n🏆 Model Comparison:")
        print(f"   Confidence Model Val Loss: {conf_val_loss:.6f}")
        print(f"   Quantile Model Val Loss: {quant_val_loss:.6f}")
        
        if conf_val_loss <= quant_val_loss:
            print("🥇 Confidence Model wins! Using Option 1")
            torch.save(confidence_model.state_dict(), 'best_confidence_model.pth')
            best_model = confidence_model
            best_model_type = "confidence"
        else:
            print("🥇 Quantile Model wins! Using Option 2")
            torch.save(quantile_model.state_dict(), 'best_confidence_model.pth')
            best_model = quantile_model
            best_model_type = "quantile"
        
        print(f"✅ Best model saved as 'best_confidence_model.pth' (type: {best_model_type})")
        print("✅ Both confidence models ready!")
        
    except Exception as e:
        print(f"❌ Error in confidence training: {e}")
        import traceback
        traceback.print_exc()
        
else:
    print("⚠️ Model or data loaders not available. Run previous cells first!")
    print("Available variables:", [var for var in globals().keys() if not var.startswith('_')])



# Cell 17: Quick Submission Generator with Confidence
def create_submission_with_confidence(model, test_dir, output_file='enhanced_submission.csv'):
    """Create submission with confidence intervals"""
    print(f"📝 Creating submission with confidence intervals...")
    
    # Load test data
    try:
        test_df = pd.read_csv('../input/osic-pulmonary-fibrosis-progression/test.csv')
        print(f"✅ Loaded test data: {len(test_df)} samples")
    except:
        print("⚠️ Test data not found, creating sample submission format")
        # Create sample format for demonstration
        test_df = pd.DataFrame({
            'Patient': ['ID00000000000000000000000'] * 5,
            'Weeks': [-12, -6, 0, 6, 12],
            'FVC': [2000, 1950, 1900, 1850, 1800],
            'Age': [65] * 5,
            'Sex': ['Male'] * 5,
            'SmokingStatus': ['Ex-smoker'] * 5
        })
    
    submissions = []
    model.eval()
    
    # Create augmentor for test time augmentation
    test_augmentor = MedicalAugmentation(augment=False)
    
    print("🔄 Processing test patients...")
    
    for _, row in tqdm(test_df.iterrows(), total=len(test_df), desc="Processing"):
        patient_id = row['Patient']
        weeks = row['Weeks']
        
        try:
            # Load patient image
            patient_dir = Path(test_dir) / patient_id
            
            # Default prediction values
            fvc_pred = 2000.0  # Default FVC
            confidence_val = 200.0  # Default confidence
            
            if patient_dir.exists():
                image_files = list(patient_dir.glob('*.dcm'))
                if image_files:
                    # Load and preprocess image
                    img = load_and_preprocess_dicom(image_files[0])
                    img_tensor = test_augmentor(img).unsqueeze(0).to(DEVICE)
                    
                    # Prepare tabular features
                    tab_features = get_tab_features(row)
                    tab_tensor = torch.tensor(tab_features).float().unsqueeze(0).to(DEVICE)
                    
                    # Predict with confidence
                    with torch.no_grad():
                        if hasattr(model, 'confidence_head'):  # Option 1
                            fvc_pred, confidence = model(img_tensor, tab_tensor)
                            fvc_pred = fvc_pred.item()
                            confidence_val = max(confidence.item(), 70)  # Minimum confidence
                        elif hasattr(model, 'quantile_head'):  # Option 2
                            final_fvc, confidence, lower_pred, upper_pred = model(img_tensor, tab_tensor)
                            fvc_pred = final_fvc.item()
                            confidence_val = max(confidence.item(), 70)
                        else:
                            # Fallback to base model
                            mean_pred, log_var = model(img_tensor, tab_tensor)
                            fvc_pred = mean_pred.item()
                            confidence_val = max(torch.exp(log_var/2).item() * 100, 70)
            
            # Create submission rows for required weeks
            for week in range(-12, 134):  # Standard competition range
                patient_week = f"{patient_id}_{week}"
                
                # Adjust prediction based on time progression
                if patient_id in A:
                    time_adjusted_fvc = fvc_pred + (week - weeks) * A[patient_id]
                else:
                    time_adjusted_fvc = fvc_pred + (week - weeks) * (-7)  # Default decay
                
                # Ensure reasonable bounds
                time_adjusted_fvc = max(time_adjusted_fvc, 800)  # Minimum FVC
                time_adjusted_fvc = min(time_adjusted_fvc, 6000)  # Maximum FVC
                
                submissions.append({
                    'Patient_Week': patient_week,
                    'FVC': time_adjusted_fvc,
                    'Confidence': confidence_val
                })
                
        except Exception as e:
            print(f"⚠️ Error processing patient {patient_id}: {e}")
            # Use default values for this patient
            for week in range(-12, 134):
                patient_week = f"{patient_id}_{week}"
                submissions.append({
                    'Patient_Week': patient_week,
                    'FVC': 2000.0,
                    'Confidence': 200.0
                })
    
    # Create submission dataframe
    submission_df = pd.DataFrame(submissions)
    submission_df.to_csv(output_file, index=False)
    
    print(f"✅ Submission saved to {output_file}")
    print(f"📊 Submission stats:")
    print(f"   Total rows: {len(submission_df)}")
    print(f"   FVC range: {submission_df['FVC'].min():.1f} - {submission_df['FVC'].max():.1f}")
    print(f"   Confidence range: {submission_df['Confidence'].min():.1f} - {submission_df['Confidence'].max():.1f}")
    
    return submission_df

# Helper function for DICOM loading (simplified version)
def load_and_preprocess_dicom(path):
    """Simplified DICOM loading for submission"""
    try:
        dcm = pydicom.dcmread(str(path))
        img = dcm.pixel_array.astype(np.float32)
        
        if len(img.shape) == 3:
            img = img[img.shape[0]//2]
        
        img = cv2.resize(img, (512, 512))
        
        # Normalize to 0-255
        img_min, img_max = img.min(), img.max()
        if img_max > img_min:
            img = (img - img_min) / (img_max - img_min) * 255
        else:
            img = np.zeros_like(img)
        
        # Convert to 3-channel
        img = np.stack([img, img, img], axis=2).astype(np.uint8)
        return img
        
    except Exception as e:
        # Return black image as fallback
        return np.zeros((512, 512, 3), dtype=np.uint8)

print("✅ Submission generator functions defined!")


# Cell 18: Generate Final Submission
print("🎯 Generating final submission with confidence intervals...")

# Generate submission using the best available model
try:
    # First try to load the best confidence model
    if 'best_model' in globals() and 'best_model_type' in globals():
        print(f"✅ Using best model: {best_model_type}")
        final_submission = create_submission_with_confidence(
            best_model, 
            TEST_DIR, 
            f'enhanced_densenet_best_{best_model_type}_submission.csv'
        )
        chosen_model = f"Best {best_model_type.title()} Model"
        
    elif 'confidence_model' in globals():
        print("✅ Using confidence model (Option 1)")
        final_submission = create_submission_with_confidence(
            confidence_model, 
            TEST_DIR, 
            'enhanced_densenet_confidence_submission.csv'
        )
        chosen_model = "Confidence Model"
        
    elif 'quantile_model' in globals():
        print("✅ Using quantile model (Option 2)")
        final_submission = create_submission_with_confidence(
            quantile_model, 
            TEST_DIR, 
            'enhanced_densenet_quantile_submission.csv'
        )
        chosen_model = "Quantile Model"
        
    elif 'model' in globals():
        print("✅ Using base model with uncertainty")
        final_submission = create_submission_with_confidence(
            model, 
            TEST_DIR, 
            'enhanced_densenet_base_submission.csv'
        )
        chosen_model = "Base Model"
        
    else:
        print("❌ No model available for submission")
        chosen_model = "None"
        final_submission = None
    
    if final_submission is not None:
        print(f"\n🎉 SUCCESS! Final submission created with {chosen_model}")
        print(f"📁 File ready for competition upload!")
        
        # Display final statistics
        print(f"\n📊 Final Submission Statistics:")
        print(f"   Model used: {chosen_model}")
        print(f"   Total predictions: {len(final_submission)}")
        print(f"   Unique patients: {len(set([p.split('_')[0] for p in final_submission['Patient_Week']]))}")
        print(f"   FVC predictions range: {final_submission['FVC'].min():.1f} - {final_submission['FVC'].max():.1f}")
        print(f"   Confidence range: {final_submission['Confidence'].min():.1f} - {final_submission['Confidence'].max():.1f}")
        print(f"   Average confidence: {final_submission['Confidence'].mean():.1f}")
        
        # Show sample predictions
        print(f"\n📋 Sample predictions:")
        print(final_submission.head(10))
        
        # Save additional info
        submission_info = {
            'model_type': chosen_model,
            'total_predictions': len(final_submission),
            'fvc_range': [float(final_submission['FVC'].min()), float(final_submission['FVC'].max())],
            'confidence_range': [float(final_submission['Confidence'].min()), float(final_submission['Confidence'].max())],
            'avg_confidence': float(final_submission['Confidence'].mean())
        }
        
        with open('submission_info.json', 'w') as f:
            json.dump(submission_info, f, indent=2)
        
        print(f"\n💾 Additional files saved:")
        print(f"   - submission_info.json (metadata)")
        print(f"   - enhanced_densenet_*_submission.csv (main submission)")
        
except Exception as e:
    print(f"❌ Error generating submission: {e}")
    import traceback
    traceback.print_exc()

