In [None]:
# Install DICOM compression support packages
import subprocess
import sys

def install_package(package):
    """Install ``package`` into the running interpreter's environment with pip.

    Returns:
        True if pip exits with status 0, False if it exits non-zero.
    """
    pip_cmd = [sys.executable, '-m', 'pip', 'install', package, '--quiet']
    try:
        subprocess.check_call(pip_cmd)
    except subprocess.CalledProcessError:
        return False
    return True

print("Installing DICOM compression support...")
# pip distribution names. GDCM's Python bindings (the `import gdcm` module
# probed later in this notebook) are published on PyPI as 'python-gdcm'.
packages = [
    'pylibjpeg',
    'pylibjpeg-libjpeg',
    'python-gdcm',
]

for pkg in packages:
    if install_package(pkg):
        print(f"✅ {pkg} installed")
    else:
        print(f"⚠️ {pkg} installation failed (may already be installed)")

**V2 Testing**

In [None]:
import os
import cv2
import pydicom
import pandas as pd
import numpy as np 
import matplotlib.pyplot as plt 
import seaborn as sns
import random
from tqdm import tqdm 
from datetime import timedelta, datetime
from pathlib import Path
import json
import warnings
import pickle
import glob
from math import sqrt, log

# Image processing
from skimage import measure, morphology, segmentation
from skimage.transform import resize
from scipy.ndimage import binary_dilation, binary_erosion
from skimage.measure import label, regionprops
from sklearn.cluster import KMeans
from skimage.segmentation import clear_border

# Deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
from torch.cuda.amp import autocast, GradScaler

# Albumentations for medical augmentations
import albumentations as albu
from albumentations.pytorch import ToTensorV2

# Model selection
from sklearn.model_selection import train_test_split

warnings.filterwarnings('ignore')

def seed_everything(seed=42):
    """Seed every RNG this notebook relies on so reruns are reproducible.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch, and sets
    PYTHONHASHSEED. On CUDA machines it also seeds all devices and switches
    cuDNN to deterministic, non-benchmarking mode.
    """
    # Changing PYTHONHASHSEED at runtime only affects child processes; the
    # current interpreter's hash seed was fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)

# Configuration: competition data layout and compute device
DATA_DIR = Path("../input/osic-pulmonary-fibrosis-progression")
TRAIN_DIR, TEST_DIR = DATA_DIR / "train", DATA_DIR / "test"
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print("Pulmonary Fibrosis Progression Analysis - FIXED VERSION")
print(f"Device: {DEVICE}")


In [None]:
# =============================================================================
# PART 1: DATA LOADING AND EDA
# =============================================================================

# Load datasets. test.csv is optional, so only a *missing* file is tolerated;
# any other read error (malformed CSV, permissions) should surface rather than
# being silently turned into "file not found" by a bare except.
train_df = pd.read_csv(DATA_DIR / 'train.csv')
try:
    test_df = pd.read_csv(DATA_DIR / 'test.csv')
except FileNotFoundError:
    print(f'Train: {train_df.shape[0]} rows, Test: file not found')
    test_df = None
else:
    print(f'Train: {train_df.shape[0]} rows, Test: {test_df.shape[0]} rows')

print("\nTrain data sample:")
print(train_df.head())

print("\nTrain data statistics:")
print(train_df.describe())

# Basic EDA: size of the longitudinal data (several visits per patient)
n_patients = train_df["Patient"].nunique()
n_rows = len(train_df)
print(f'\nUnique patients in training data: {n_patients}')
print(f'Total observations: {n_rows}')
print(f'Average observations per patient: {n_rows / n_patients:.2f}')

# Check for missing values
print('\nMissing values:')
print(train_df.isnull().sum())

# Basic visualizations: 2x3 overview of the tabular columns, drawn on explicit axes
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
ax_corr, ax_fvc, ax_age, ax_smoking, ax_sex, ax_weeks = axes.flatten()

# Correlation matrix of the numeric columns
numeric_cols = train_df.select_dtypes(include=[np.number])
sns.heatmap(numeric_cols.corr(), annot=True, cmap='coolwarm', center=0, ax=ax_corr)
ax_corr.set_title('Correlation Matrix')

# FVC distribution
ax_fvc.hist(train_df['FVC'], bins=30, alpha=0.7)
ax_fvc.set_title('FVC Distribution')
ax_fvc.set_xlabel('FVC')

# Age vs FVC
ax_age.scatter(train_df['Age'], train_df['FVC'], alpha=0.6)
ax_age.set_xlabel('Age')
ax_age.set_ylabel('FVC')
ax_age.set_title('Age vs FVC')

# Smoking status distribution
train_df['SmokingStatus'].value_counts().plot(kind='bar', ax=ax_smoking)
ax_smoking.set_title('Smoking Status Distribution')
ax_smoking.tick_params(axis='x', labelrotation=45)

# Sex distribution
train_df['Sex'].value_counts().plot(kind='bar', ax=ax_sex)
ax_sex.set_title('Sex Distribution')

# Weeks distribution
ax_weeks.hist(train_df['Weeks'], bins=30, alpha=0.7)
ax_weeks.set_title('Weeks Distribution')
ax_weeks.set_xlabel('Weeks')

plt.tight_layout()
plt.show()

# Sample patient trajectories: FVC over time for the first three patients.
# Each patient's rows are sorted by Weeks before plotting so the 'o-' line
# connects visits in time order, whatever the row order in the CSV.
sample_patients = train_df['Patient'].unique()[:3]
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for ax, patient in zip(axes, sample_patients):
    patient_data = train_df[train_df['Patient'] == patient].sort_values('Weeks')
    ax.plot(patient_data['Weeks'], patient_data['FVC'], 'o-')
    ax.set_title(f"Patient: {patient[:10]}...")
    ax.set_xlabel('Weeks')
    ax.set_ylabel('FVC')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# =============================================================================
# PART 2: FEATURE ENGINEERING
# =============================================================================

print("Calculating linear decay coefficients...")

# Per-patient containers, filled by the decay-fitting loop in this cell:
#   A   - patient id -> fitted FVC slope (FVC per week)
#   TAB - patient id -> tabular feature vector from get_tab_features
#   P   - patient ids in processing order
A, TAB, P = {}, {}, []

def get_tab_features(df_row):
    """Encode one patient's baseline row as a 4-element float32 vector.

    Layout: [scaled age, sex flag, smoking bit 0, smoking bit 1]
      * age is mapped as (Age - 30) / 30
      * sex flag is 0 for 'Male', 1 for any other value
      * smoking bits: 'Never smoked' -> (0, 0), 'Ex-smoker' -> (1, 1),
        'Currently smokes' -> (0, 1), any other value -> (1, 0)

    ``df_row`` only needs item access for 'Age', 'Sex' and 'SmokingStatus',
    so a pandas Series or a plain dict both work.
    """
    smoking_codes = {
        'Never smoked': [0, 0],
        'Ex-smoker': [1, 1],
        'Currently smokes': [0, 1],
    }
    age_scaled = (df_row['Age'] - 30) / 30
    sex_flag = 0 if df_row['Sex'] == 'Male' else 1
    smoking_bits = smoking_codes.get(df_row['SmokingStatus'], [1, 0])
    return np.array([age_scaled, sex_flag, *smoking_bits], dtype=np.float32)

def _fit_decay_slope(weeks, fvc, patient=None):
    """Least-squares slope of FVC against Weeks (FVC per week) for one patient.

    Returns 0.0 when fewer than two *distinct* weeks are available. With a
    single time point the slope is undefined, and np.linalg.lstsq on the
    rank-deficient design would return a minimum-norm solution whose "slope"
    depends on the week value rather than on any change in FVC.
    If lstsq raises LinAlgError, falls back to the slope between the earliest
    and latest week (weeks are distinct here, so the denominator is non-zero).
    ``patient`` is only used in the failure message.
    """
    weeks = np.asarray(weeks, dtype=np.float64)
    fvc = np.asarray(fvc, dtype=np.float64)
    if np.unique(weeks).size < 2:
        return 0.0

    design = np.column_stack([weeks, np.ones_like(weeks)])  # FVC = slope * weeks + intercept
    try:
        slope, _intercept = np.linalg.lstsq(design, fvc, rcond=None)[0]
    except np.linalg.LinAlgError as e:
        print(f"Linear fit failed for patient {patient}: {e}")
        order = np.argsort(weeks, kind='stable')
        w_sorted, f_sorted = weeks[order], fvc[order]
        slope = (f_sorted[-1] - f_sorted[0]) / (w_sorted[-1] - w_sorted[0])
    return float(slope)


# Calculate slopes for each patient
decay_rates = []
for patient in tqdm(train_df['Patient'].unique(), desc="Processing patients"):
    sub = train_df[train_df['Patient'] == patient]
    slope = _fit_decay_slope(sub['Weeks'].values, sub['FVC'].values, patient)
    A[patient] = slope
    TAB[patient] = get_tab_features(sub.iloc[0])  # first row = baseline demographics
    P.append(patient)
    decay_rates.append(slope)

print(f"Processed {len(P)} patients with decay coefficients")

# Analyze decay rates: histogram (with a zero reference line) and box plot
fig, (ax_hist, ax_box) = plt.subplots(1, 2, figsize=(12, 6))

ax_hist.hist(decay_rates, bins=30, alpha=0.7)
ax_hist.set_title("Distribution of FVC Decay Rates")
ax_hist.set_xlabel("Decay Rate (FVC/Week)")
ax_hist.axvline(0, color='red', linestyle='--', alpha=0.7)

ax_box.boxplot(decay_rates)
ax_box.set_title("Decay Rate Box Plot")
ax_box.set_ylabel("Decay Rate (FVC/Week)")

plt.tight_layout()
plt.show()

rates = np.asarray(decay_rates)
print("Decay rate statistics:")
print(f"  Mean: {rates.mean():.3f}")
print(f"  Std: {rates.std():.3f}")
print(f"  Range: [{rates.min():.3f}, {rates.max():.3f}]")

In [None]:
# =============================================================================
# PART 3: ENHANCED DICOM IMAGE PROCESSING WITH COMPRESSION HANDLING
# =============================================================================

# Install required packages for DICOM compression handling
import sys
import subprocess

def install_dicom_dependencies(packages=None):
    """Install the pip packages used to decode compressed DICOM pixel data.

    Args:
        packages: iterable of pip requirement strings. Defaults to the GDCM
            bindings ('python-gdcm', which provides ``import gdcm``) plus
            pylibjpeg and its libjpeg plugin.

    Each package is installed independently. A failed install is reported
    and skipped so that one broken package does not block the others.
    """
    if packages is None:
        # 'pillow-simd' and 'opencv-python' are deliberately not installed.
        # pillow-simd is a fork of Pillow that also installs the `PIL` package,
        # so it clashes with an existing Pillow. cv2 is already imported by this
        # notebook, so installing another OpenCV distribution can only conflict.
        packages = [
            'python-gdcm',        # GDCM support for DICOM compression
            'pylibjpeg',          # JPEG support
            'pylibjpeg-libjpeg',  # JPEG Lossless support
        ]

    for package in packages:
        try:
            subprocess.check_call([sys.executable, '-m', 'pip', 'install', package, '--quiet'])
            print(f"✅ Installed: {package}")
        except subprocess.CalledProcessError:
            print(f"⚠️ Failed to install: {package} (may already be installed)")

# Try to install dependencies
print("Installing DICOM compression dependencies...")
install_dicom_dependencies()

# Packages pip just installed into this running interpreter can be missed by
# the import system's cached directory listings. The importlib docs recommend
# invalidate_caches() so that newly created modules become importable.
import importlib
importlib.invalidate_caches()

# Import with compression support; the *_AVAILABLE flags record the outcome.
try:
    import gdcm  # GDCM for DICOM compression
    print("✅ GDCM loaded successfully")
    GDCM_AVAILABLE = True
except ImportError:
    print("⚠️ GDCM not available - using fallback methods")
    GDCM_AVAILABLE = False

try:
    import pylibjpeg  # JPEG support
    print("✅ PyLibJPEG loaded successfully")
    PYLIBJPEG_AVAILABLE = True
except ImportError:
    print("⚠️ PyLibJPEG not available - using fallback methods")
    PYLIBJPEG_AVAILABLE = False

def _dicom_apply_rescale(dcm, img):
    """Map stored pixel values to modality units (HU for CT) via RescaleSlope/RescaleIntercept.

    Missing or empty attributes default to slope 1 / intercept 0. Values that
    cannot be parsed as numbers leave ``img`` unchanged.
    """
    try:
        slope = float(getattr(dcm, 'RescaleSlope', 1) or 1)
        intercept = float(getattr(dcm, 'RescaleIntercept', 0) or 0)
    except (TypeError, ValueError):
        return img
    return (img * slope + intercept).astype(np.float32)


def _dicom_raw_pixels(dcm):
    """Reinterpret *uncompressed* PixelData as a (Rows, Columns) float32 array.

    Returns None when there is no PixelData, when the transfer syntax says the
    data is compressed (encapsulated JPEG/RLE bytes are not a pixel grid and
    would come out as noise), or when the buffer is too short. 16-bit data is
    read little-endian and honours PixelRepresentation (1 = signed).
    """
    pixel_bytes = getattr(dcm, 'PixelData', None)
    if pixel_bytes is None:
        return None
    try:
        if dcm.file_meta.TransferSyntaxUID.is_compressed:
            return None
    except (AttributeError, ValueError):
        pass  # no/unknown transfer syntax: treat the bytes as native pixels

    rows, cols = int(dcm.Rows), int(dcm.Columns)
    n_pixels = rows * cols
    if len(pixel_bytes) >= 2 * n_pixels:
        signed = getattr(dcm, 'PixelRepresentation', 0) == 1
        img = np.frombuffer(pixel_bytes[:2 * n_pixels], dtype='<i2' if signed else '<u2')
    elif len(pixel_bytes) >= n_pixels:
        img = np.frombuffer(pixel_bytes[:n_pixels], dtype=np.uint8)
    else:
        return None
    return img.reshape((rows, cols)).astype(np.float32)


def load_and_preprocess_dicom_enhanced(path, verbose=False):
    """Load one DICOM file as a 512x512x3 uint8 image, trying several readers in turn.

    1. pydicom ``pixel_array``. pydicom decodes compressed pixel data itself
       with whatever handlers are installed, so no explicit ``decompress()``
       call (and no matching on its error text) is needed.
    2. The same, re-read with ``defer_size="1 KB"``. Only attempted when the
       file has a TransferSyntaxUID.
    3. Raw PixelData reinterpretation, for uncompressed data only.
    4. PIL, for files that are ordinary image formats.
    If every method fails, ``create_dummy_image()`` is returned.

    Methods 1-3 apply RescaleSlope/RescaleIntercept first. This way the
    [-1000, 3000] clip in ``process_dicom_image`` operates on Hounsfield Units
    for CT, as its comments intend, rather than on raw stored values.

    Args:
        path: file path (str or Path).
        verbose: print why each method failed.
    """
    try:
        # Method 1: standard pydicom read; pixel_array handles compressed data
        try:
            dcm = pydicom.dcmread(str(path), force=True)
            img = _dicom_apply_rescale(dcm, dcm.pixel_array.astype(np.float32))
            return process_dicom_image(img, path, verbose=verbose)
        except Exception as e:
            if verbose:
                print(f"Method 1 failed for {path}: {e}")

        # Method 2: re-read with deferred large elements, when a transfer syntax is known
        try:
            dcm = pydicom.dcmread(str(path), force=True, defer_size="1 KB")
            if hasattr(dcm, 'file_meta') and hasattr(dcm.file_meta, 'TransferSyntaxUID'):
                if verbose:
                    print(f"Transfer Syntax: {dcm.file_meta.TransferSyntaxUID}")
                img = _dicom_apply_rescale(dcm, dcm.pixel_array.astype(np.float32))
                return process_dicom_image(img, path, verbose=verbose)
        except Exception as e:
            if verbose:
                print(f"Method 2 failed for {path}: {e}")

        # Method 3: raw pixel bytes (uncompressed transfer syntaxes only)
        try:
            dcm = pydicom.dcmread(str(path), force=True, stop_before_pixels=False)
            img = _dicom_raw_pixels(dcm)
            if img is not None:
                return process_dicom_image(_dicom_apply_rescale(dcm, img), path, verbose=verbose)
        except Exception as e:
            if verbose:
                print(f"Method 3 failed for {path}: {e}")

        # Method 4: PIL fallback (some files are plain images with a .dcm suffix)
        try:
            from PIL import Image

            with Image.open(path) as pil_img:
                img = np.array(pil_img.convert('L')).astype(np.float32)
                return process_dicom_image(img, path, verbose=verbose)
        except Exception as e:
            if verbose:
                print(f"Method 4 failed for {path}: {e}")

        if verbose:
            print(f"All methods failed for {path} - returning dummy image")
        return create_dummy_image()

    except Exception as e:
        if verbose:
            print(f"Critical error loading {path}: {e}")
        return create_dummy_image()

def process_dicom_image(img, path, verbose=False):
    """Convert a decoded pixel array into a 512x512x3 uint8 image.

    Layout handling:
      * a trailing axis of size 1, 3 or 4 is treated as samples-per-pixel
        (pydicom's layout for colour data). Alpha is dropped and the colour
        channels are averaged into one gray channel. Indexing such an array
        along its first axis would select an image *row*, not a frame.
      * any remaining leading (frame) axes are collapsed by taking the middle
        index, e.g. (frames, rows, cols) -> the middle frame.

    The 2-D result is resized to 512x512 and clipped to [-1000, 3000] (the CT
    Hounsfield range when the caller has applied the DICOM rescale). It is
    then min-max scaled to [0, 255] (a constant image becomes 128) and
    replicated to three channels.

    Returns ``create_dummy_image()`` if the array cannot be reduced to 2-D or
    any step raises. ``path`` is used only in verbose messages.
    """
    try:
        img = np.asarray(img, dtype=np.float32)

        # Colour / single-sample trailing axis -> one gray channel
        if img.ndim >= 3 and img.shape[-1] in (1, 3, 4):
            img = img[..., :3].mean(axis=-1)

        # Multi-frame: collapse leading axes to their middle index
        while img.ndim > 2:
            img = img[img.shape[0] // 2]

        # Ensure 2D
        if img.ndim != 2:
            if verbose:
                print(f"Unexpected image shape {img.shape} for {path}")
            return create_dummy_image()

        # Resize to target size
        img = cv2.resize(img, (512, 512))

        # Clip to the CT Hounsfield range: -1000 (air) to +3000 (bone)
        img = np.clip(img, -1000, 3000)

        # Normalize to 0-255
        img_min, img_max = img.min(), img.max()
        if img_max > img_min:
            img = (img - img_min) / (img_max - img_min) * 255
        else:
            img = np.full_like(img, 128)  # constant image -> mid-gray instead of black

        img = np.clip(img, 0, 255)

        # Convert to 3-channel image
        return np.stack([img, img, img], axis=2).astype(np.uint8)

    except Exception as e:
        if verbose:
            print(f"Error processing image from {path}: {e}")
        return create_dummy_image()

def create_dummy_image():
    """Return a 512x512x3 uint8 placeholder used when DICOM loading fails.

    A single gray plane of N(128, 20) noise is drawn from NumPy's global RNG,
    clipped to [0, 255] and replicated into three identical channels.
    """
    gray = np.clip(np.random.normal(128, 20, (512, 512)), 0, 255).astype(np.uint8)
    return np.repeat(gray[:, :, np.newaxis], 3, axis=2)

# Legacy functions for backward compatibility
def load_scan_fixed(path):
    """Load all ``.dcm`` slices in a folder, sorted by InstanceNumber.

    Compressed slices are decompressed eagerly, so an undecodable file is
    reported and skipped here rather than failing later. Every returned slice
    has ``SliceThickness`` overwritten with the spacing between the first two
    slices: the ImagePositionPatient z difference, else the SliceLocation
    difference, else 1.0.

    Returns:
        list of pydicom datasets, or None if the folder cannot be listed or
        no slice could be read.
    """
    try:
        # Case-insensitive suffix match; sorted listing for a stable read order
        files = sorted(f for f in os.listdir(path) if f.lower().endswith('.dcm'))
        slices = []

        for f in files:
            try:
                dcm = pydicom.dcmread(os.path.join(path, f), force=True)
                # Decompress only when the transfer syntax is compressed, rather
                # than calling decompress() unconditionally and recognising the
                # "already uncompressed" case by the text of its error message.
                try:
                    compressed = bool(dcm.file_meta.TransferSyntaxUID.is_compressed)
                except (AttributeError, ValueError):
                    compressed = False  # no/unknown transfer syntax: nothing to decompress
                if compressed:
                    dcm.decompress()
                slices.append(dcm)
            except Exception as e:
                print(f"Warning: Could not read {f}: {e}")
                continue

        if not slices:
            return None

        slices.sort(key=lambda x: int(x.InstanceNumber) if hasattr(x, 'InstanceNumber') else 0)

        # Slice spacing from the first two slices; 1.0 for single-slice series
        slice_thickness = 1.0
        if len(slices) > 1:
            try:
                slice_thickness = np.abs(slices[0].ImagePositionPatient[2] - slices[1].ImagePositionPatient[2])
            except Exception:
                try:
                    slice_thickness = np.abs(slices[0].SliceLocation - slices[1].SliceLocation)
                except Exception:
                    slice_thickness = 1.0  # Default

        for s in slices:
            s.SliceThickness = slice_thickness

        return slices
    except Exception as e:
        print(f"Error loading scan from {path}: {e}")
        return None

def get_pixels_hu_fixed(scans):
    """Stack slices into an int16 volume in Hounsfield Units (HU).

    Each slice's pixels come from ``pixel_array``. If that attribute is
    absent but raw ``PixelData`` exists, the first Rows*Columns*2 bytes are
    read as uint16. The extracted array is kept in a local list: pydicom's
    ``pixel_array`` is a read-only property, so the previous assignment to it
    raised and the fallback slice was dropped. Slices that cannot be read are
    skipped with a message.

    Stored values of -2000 (out-of-scan padding) are set to 0. Each slice is
    then rescaled with its *own* RescaleSlope / RescaleIntercept (defaults
    1 / 0), since slices of one series need not share them.

    Returns:
        np.ndarray of shape (n_slices, rows, cols), dtype int16, or None if no
        slice was usable or conversion failed (e.g. slices differ in shape).
    """
    try:
        frames, slopes, intercepts = [], [], []

        for s in scans:
            try:
                if hasattr(s, 'pixel_array'):
                    pixels = np.asarray(s.pixel_array)
                elif hasattr(s, 'PixelData'):
                    rows, cols = int(s.Rows), int(s.Columns)
                    raw = s.PixelData
                    if len(raw) < rows * cols * 2:
                        continue  # too short for 16-bit pixels
                    pixels = np.frombuffer(raw[:rows * cols * 2], dtype=np.uint16).reshape((rows, cols))
                else:
                    continue
                slope = float(getattr(s, 'RescaleSlope', 1))
                intercept = float(getattr(s, 'RescaleIntercept', 0))
            except Exception as e:
                print(f"Could not extract pixels from slice: {e}")
                continue
            frames.append(pixels)
            slopes.append(slope)
            intercepts.append(intercept)

        if not frames:
            return None

        image = np.stack(frames).astype(np.int16)

        # Set out-of-scan pixels to 0
        image[image == -2000] = 0

        # Convert to HU slice by slice
        for i, (slope, intercept) in enumerate(zip(slopes, intercepts)):
            if slope != 1:
                image[i] = (slope * image[i].astype(np.float64)).astype(np.int16)
            image[i] += np.int16(intercept)

        return image
    except Exception as e:
        print(f"Error converting to HU: {e}")
        return None

# Update the main loading function
def load_and_preprocess_dicom_fixed(path):
    """Quietly load one DICOM file as a 512x512x3 uint8 image.

    Compatibility wrapper kept for existing call sites. It is equivalent to
    ``load_and_preprocess_dicom_enhanced(path, verbose=False)``.
    """
    quiet = False
    return load_and_preprocess_dicom_enhanced(path, verbose=quiet)

# Test DICOM loading on a sample patient (the first entry of P).
# NOTE(review): the loaders fall back to create_dummy_image(), which is also
# (512, 512, 3). The success count below therefore means "returned an image of
# the expected shape", not "decoded the DICOM". The verbose load of the first
# file prints which loading methods failed.
if TRAIN_DIR.exists():
    print("\nTesting enhanced DICOM loading...")
    sample_patient = P[0] if P else None
    if sample_patient:
        sample_path = TRAIN_DIR / sample_patient
        if sample_path.exists():
            # Path.glob order is filesystem-dependent; sort so every run tests the same files
            dcm_files = sorted(sample_path.glob('*.dcm'))
            if dcm_files:
                print(f"Testing with {len(dcm_files)} DICOM files from {sample_patient}")

                # Test first file with verbose output
                test_img = load_and_preprocess_dicom_enhanced(dcm_files[0], verbose=True)
                print(f"Loaded image shape: {test_img.shape}, dtype: {test_img.dtype}")
                print(f"Pixel value range: [{test_img.min()}, {test_img.max()}]")

                # Test a few more files without verbose output
                n_tested = min(5, len(dcm_files))
                success_count = 0
                for i, dcm_file in enumerate(dcm_files[:n_tested]):
                    try:
                        img = load_and_preprocess_dicom_fixed(dcm_file)
                        if img.shape == (512, 512, 3):
                            success_count += 1
                    except Exception as e:
                        print(f"❌ File {i+1}: {dcm_file.name} - Error: {e}")

                print(f"✅ Successfully loaded {success_count}/{n_tested} test files")
            else:
                print("No DICOM files found for sample patient")
        else:
            print(f"Sample patient directory not found: {sample_path}")
    else:
        print("No patients available for testing")
else:
    print("Training directory not found - using dummy data for testing")
    dummy_img = create_dummy_image()
    print(f"Created dummy image: {dummy_img.shape}, dtype: {dummy_img.dtype}")

In [None]:
# =============================================================================
# PART 4: MEDICAL AUGMENTATIONS
# =============================================================================

class MedicalAugmentation:
    """Albumentations pipeline for CT slices given as HxWx3 uint8 arrays.

    ``augment=True``: mild geometric, intensity, distortion and dropout
    transforms, followed by ImageNet normalisation and ToTensorV2.
    ``augment=False``: normalisation and ToTensorV2 only.
    Calling the object returns the transformed image.

    NOTE(review): GaussNoise(var_limit=...), CoarseDropout(max_holes=...,
    max_height=..., max_width=...) and OpticalDistortion(shift_limit=...) use
    older albumentations argument names that newer releases appear to have
    renamed or removed. Warnings are globally suppressed in this notebook, so
    an ignored argument would go unnoticed. Confirm against the installed
    version. Also, Rotate and ShiftScaleRotate both rotate, so their limits
    can combine.
    """

    def __init__(self, augment=True):
        to_model_input = [
            albu.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ]
        if not augment:
            self.transform = albu.Compose(to_model_input)
            return

        # Geometric augmentations - conservative for medical images
        geometric = [
            albu.Rotate(limit=10, p=0.5),
            albu.HorizontalFlip(p=0.5),
            albu.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=10, p=0.5),
        ]
        # Intensity augmentations
        intensity = [
            albu.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
            albu.GaussNoise(var_limit=(5.0, 25.0), p=0.3),
            albu.RandomGamma(gamma_limit=(90, 110), p=0.3),
        ]
        # Light distortions
        distortion = [
            albu.GridDistortion(num_steps=3, distort_limit=0.1, p=0.2),
            albu.OpticalDistortion(distort_limit=0.1, shift_limit=0.1, p=0.2),
        ]
        dropout = [albu.CoarseDropout(max_holes=4, max_height=16, max_width=16, p=0.2)]

        self.transform = albu.Compose(geometric + intensity + distortion + dropout + to_model_input)

    def __call__(self, image):
        augmented = self.transform(image=image)
        return augmented['image']

In [None]:
# =============================================================================
# PART 5: FIXED DATASET CLASS
# =============================================================================

class OSICDenseNetDataset(Dataset):
    """Patient-level dataset pairing one CT slice with tabular features and the decay target.

    Args:
        patients: patient ids to consider.
        A_dict: patient id -> decay coefficient (regression target).
        TAB_dict: patient id -> tabular feature vector.
        data_dir: directory holding one sub-folder of ``.dcm`` files per patient.
        split: 'train' draws a random slice on every access and repeats each
            patient 4x per epoch. Any other value returns one fixed (middle)
            slice per patient, so validation scores are reproducible between
            epochs.
        augment: forwarded to MedicalAugmentation.
        verbose: print the validation summary.

    Items are ``(image_tensor, tab_features, target, patient_id)``. If loading
    fails, a zero placeholder sample with id "dummy_patient" is returned.
    """

    @staticmethod
    def _dicom_sort_key(path):
        """Order numeric stems numerically ('2.dcm' < '10.dcm'), then other names alphabetically."""
        stem = Path(path).stem
        return (0, int(stem), '') if stem.isdecimal() else (1, 0, stem)

    def __init__(self, patients, A_dict, TAB_dict, data_dir, split='train', augment=True, verbose=True):
        self.patients = patients
        self.A_dict = A_dict
        self.TAB_dict = TAB_dict
        self.data_dir = Path(data_dir)
        self.split = split
        self.augment = augment
        self.augmentor = MedicalAugmentation(augment=augment)

        # Comprehensive patient validation
        self.patient_images = {}
        missing_dirs = []
        no_images = []
        valid_patients = []

        for patient in self.patients:
            patient_dir = self.data_dir / patient

            if not patient_dir.exists():
                missing_dirs.append(patient)
                continue

            # Sorted so slice order (and therefore the fixed validation slice)
            # does not depend on filesystem iteration order.
            image_files = sorted(
                (f for f in patient_dir.iterdir() if f.suffix.lower() == '.dcm'),
                key=self._dicom_sort_key,
            )

            if not image_files:
                no_images.append(patient)
                continue

            # Test load first image to ensure it's valid
            try:
                test_img = load_and_preprocess_dicom_fixed(image_files[0])
                if test_img is not None and test_img.shape == (512, 512, 3):
                    self.patient_images[patient] = image_files
                    valid_patients.append(patient)
                else:
                    no_images.append(patient)
            except Exception as e:
                if verbose:
                    print(f"Failed to load test image for {patient}: {e}")
                no_images.append(patient)

        self.valid_patients = valid_patients

        if verbose:
            print(f"\nDataset {split} validation:")
            print(f"  Input patients: {len(self.patients)}")
            print(f"  Missing directories: {len(missing_dirs)}")
            print(f"  No valid images: {len(no_images)}")
            print(f"  Valid patients: {len(self.valid_patients)}")

            if missing_dirs and len(missing_dirs) <= 5:
                print(f"  Missing dirs: {missing_dirs}")
            elif missing_dirs:
                print(f"  Missing dirs: {missing_dirs[:3]}... and {len(missing_dirs)-3} more")

    def __len__(self):
        if self.split == 'train':
            return len(self.valid_patients) * 4  # each patient seen 4x per epoch
        return len(self.valid_patients)

    def _select_image(self, patient):
        """Pick the slice to load: random for training, the middle slice otherwise."""
        images = self.patient_images[patient]
        if self.split == 'train':
            return random.choice(images)
        return images[len(images) // 2]

    def __getitem__(self, idx):
        try:
            if self.split == 'train':
                patient_idx = idx % len(self.valid_patients)
            else:
                patient_idx = idx

            patient = self.valid_patients[patient_idx]
            selected_image = self._select_image(patient)

            # Load and preprocess image
            img = load_and_preprocess_dicom_fixed(selected_image)
            if img is None:
                raise ValueError(f"Failed to load image for {patient}")

            # Apply augmentations
            img_tensor = self.augmentor(img)

            # Get tabular features
            if patient not in self.TAB_dict:
                raise ValueError(f"No tabular features for {patient}")
            tab_features = torch.tensor(self.TAB_dict[patient], dtype=torch.float32)

            # Get target (decay coefficient)
            if patient not in self.A_dict:
                raise ValueError(f"No decay coefficient for {patient}")
            target = torch.tensor(self.A_dict[patient], dtype=torch.float32)

            return img_tensor, tab_features, target, patient

        except Exception as e:
            print(f"Error loading sample {idx} (patient {patient if 'patient' in locals() else 'unknown'}): {e}")
            # Return dummy data to prevent training crash.
            # NOTE(review): this placeholder (zero image/features, target 0.0)
            # is still fed to the model; consider filtering "dummy_patient"
            # in the training loop.
            dummy_img = torch.zeros((3, 512, 512), dtype=torch.float32)
            dummy_tab = torch.zeros(4, dtype=torch.float32)
            dummy_target = torch.tensor(0.0, dtype=torch.float32)
            return dummy_img, dummy_tab, dummy_target, "dummy_patient"

In [None]:
# =============================================================================
# PART 6: FIXED MODEL ARCHITECTURE
# =============================================================================

class SpatialAttention(nn.Module):
    """CBAM-style spatial attention gate.

    Builds a 2-channel map from the per-pixel channel mean and channel max.
    That map is convolved to one channel (``kernel_size``, padding
    ``kernel_size // 2``, no bias), and the input is multiplied by the
    sigmoid of the result.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        pooled = torch.cat(
            [x.mean(dim=1, keepdim=True), x.max(dim=1, keepdim=True).values],
            dim=1,
        )
        gate = self.sigmoid(self.conv1(pooled))
        return x * gate

class ChannelAttention(nn.Module):
    """CBAM-style channel attention gate.

    Each channel is average- and max-pooled to 1x1. Both results pass through
    a shared bottleneck of 1x1 convs (in_planes -> in_planes // ratio ->
    in_planes, no bias, ReLU in between). The two outputs are summed and the
    input is scaled by their sigmoid.
    """

    def __init__(self, in_planes, ratio=16):
        super().__init__()
        hidden = in_planes // ratio
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_planes, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, in_planes, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        logits = self.fc(self.avg_pool(x)) + self.fc(self.max_pool(x))
        return x * self.sigmoid(logits)

class WorkingDenseNetModel(nn.Module):
    """DenseNet121 image branch + MLP tabular branch, fused into (mean, log-variance) heads.

    forward(images, tabular) -> (mean_pred [B], log_var [B]), with log_var
    clamped to [-10, 10].

    The cross-modal attention uses a single key/value token (the projected
    tabular features). A softmax over one key is identically 1 (apart from
    attention dropout in training), so the attention output does not depend
    on the image query. A residual connection adds it to the image features
    so the image information stays on that branch of the fusion. The residual
    adds no parameters, so the state_dict keys are unchanged.
    """

    def __init__(self, tabular_dim=4, dropout_rate=0.3):
        super(WorkingDenseNetModel, self).__init__()

        # DenseNet121 backbone (convolutional features only)
        densenet = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
        self.features = densenet.features

        self.num_image_features = 1024  # DenseNet121 feature channels

        # Attention mechanisms
        self.spatial_attention = SpatialAttention()
        self.channel_attention = ChannelAttention(self.num_image_features)

        # Tabular processing: tabular_dim -> 512
        self.tabular_processor = nn.Sequential(
            nn.Linear(tabular_dim, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU()
        )

        # Trainable cross-modal projections
        self.tab_to_img_projection = nn.Linear(512, self.num_image_features)
        self.img_to_tab_projection = nn.Linear(self.num_image_features, 512)

        # Cross-modal attention
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=self.num_image_features,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )

        # Multi-modal fusion: 1024 (image) + 512 (tabular) = 1536 -> 256
        fusion_input_dim = self.num_image_features + 512
        self.fusion_layer = nn.Sequential(
            nn.Linear(fusion_input_dim, 768),
            nn.BatchNorm1d(768),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(768, 384),
            nn.BatchNorm1d(384),
            nn.ReLU(),
            nn.Dropout(dropout_rate/2),
            nn.Linear(384, 256),
            nn.BatchNorm1d(256),
            nn.ReLU()
        )

        # Mean regression head
        self.mean_head = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

        # Log variance head (for uncertainty)
        self.log_var_head = nn.Sequential(
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform Linear weights with zero bias; BatchNorm1d to identity (weight 1, bias 0)."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, images, tabular):
        batch_size = images.size(0)

        # Extract image features
        img_features = self.features(images)  # [B, 1024, H, W]

        # Apply attention mechanisms
        img_features = self.channel_attention(img_features)
        img_features = self.spatial_attention(img_features)

        # Global average pooling
        img_features = F.adaptive_avg_pool2d(img_features, (1, 1))
        img_features = img_features.view(batch_size, -1)  # [B, 1024]

        # Process tabular data
        tab_features = self.tabular_processor(tabular)  # [B, 512]

        # Cross-modal attention: image token queries the projected tabular token
        img_expanded = img_features.unsqueeze(1)  # [B, 1, 1024]
        tab_proj = self.tab_to_img_projection(tab_features)  # [B, 1024]
        tab_expanded = tab_proj.unsqueeze(1)  # [B, 1, 1024]

        attended, attention_weights = self.cross_attention(
            img_expanded,  # Query
            tab_expanded,  # Key
            tab_expanded   # Value
        )
        # With one key token the attention output ignores the query; the
        # residual keeps the image features on this branch.
        attended_img = img_features + attended.squeeze(1)  # [B, 1024]

        # Image projected into tabular space for bidirectional fusion
        img_to_tab = self.img_to_tab_projection(img_features)  # [B, 512]
        enhanced_tab = tab_features + img_to_tab  # Residual connection

        # Multi-modal fusion
        combined_features = torch.cat([attended_img, enhanced_tab], dim=1)  # [B, 1536]
        fused_features = self.fusion_layer(combined_features)  # [B, 256]

        # Predictions
        mean_pred = self.mean_head(fused_features).squeeze(-1)  # [B]
        log_var = self.log_var_head(fused_features).squeeze(-1)  # [B]

        # Clamp log_var to a reasonable range to prevent numerical issues
        log_var = torch.clamp(log_var, min=-10, max=10)

        return mean_pred, log_var

# FIX 6: Properly defined ModelWithConfidence
class ModelWithConfidence(nn.Module):
    """Inference wrapper that reports a standard deviation instead of a log-variance.

    The wrapped model's forward(images, tabular) must return ``(mean_pred, log_var)``.
    This wrapper returns ``(mean_pred, sigma)`` with
    ``sigma = sqrt(exp(log_var)) = exp(log_var / 2)``, in the same units as
    ``mean_pred``; no extra scaling is applied.

    The wrapper holds the base model as a submodule, so its parameters are shared:
    loading a state dict into the wrapper changes the base model too.
    """

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def forward(self, images, tabular):
        mean_pred, log_var = self.base_model(images, tabular)
        # Halving inside the exponent turns a variance into a standard deviation.
        sigma = (0.5 * log_var).exp()
        return mean_pred, sigma

In [None]:
# Initialize the model
# NOTE(review): tabular_dim=4 must equal the length of the tabular vectors the dataset
# yields (TAB) — confirm; get_enhanced_tabular_features later in this notebook builds
# 10 values per patient.
model = WorkingDenseNetModel(tabular_dim=4).to(DEVICE)
print("Model initialized with {:,} parameters".format(sum(t.numel() for t in model.parameters())))

In [None]:
# Part 7: Data Split and Loaders
# -----------------------------

# Split at the patient level: every patient id lands in exactly one of the two lists.
from sklearn.model_selection import train_test_split

patients_list = list(P)
train_patients, val_patients = train_test_split(
    patients_list, test_size=0.2, random_state=42, shuffle=True
)

print(f"Train patients: {len(train_patients)}")
print(f"Validation patients: {len(val_patients)}")

# Both datasets read the same lookup tables and image directory; they differ only
# in the patient subset, the split name and whether augmentation is enabled.
_shared_dataset_kwargs = {'A_dict': A, 'TAB_dict': TAB, 'data_dir': TRAIN_DIR}
train_dataset = OSICDenseNetDataset(
    patients=train_patients, split='train', augment=True, **_shared_dataset_kwargs
)
val_dataset = OSICDenseNetDataset(
    patients=val_patients, split='val', augment=False, **_shared_dataset_kwargs
)

# Same batching/worker settings for both; training shuffles and drops the ragged last batch.
_shared_loader_kwargs = {'batch_size': 8, 'num_workers': 4, 'pin_memory': True}
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **_shared_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, drop_last=False, **_shared_loader_kwargs)

print(f"Data loaders created: {len(train_loader)} training batches, {len(val_loader)} validation batches")


In [None]:
# Part 8: Training with Uncertainty
# --------------------------------

# FIX 3: Corrected trainer with consistent sigma handling
class CorrectedSimpleTrainer:
    """
    Trainer for a model whose forward(images, tabular) returns (mean_pred, log_var).

    Optimises a heteroscedastic Gaussian negative log-likelihood (``uncertainty_loss``)
    and keeps the checkpoint with the best validation OSIC Laplace log-likelihood
    (``laplace_log_likelihood``). sigma is always ``exp(log_var / 2)`` in the same
    units as the targets; no rescaling is applied anywhere.

    Args:
        model: the network to train (assumed to already be on ``device``).
        device: device every batch is moved to.
        lr: AdamW learning rate (now actually used by ``train``).
        checkpoint_path: file the best ``state_dict`` is written to.
    """

    def __init__(self, model, device, lr=1e-4, checkpoint_path='best_model_fixed.pth'):
        self.model = model
        self.device = device
        self.lr = lr
        self.checkpoint_path = checkpoint_path
        self.best_val_mae = float('inf')
        self.best_val_lll = float('-inf')

    def uncertainty_loss(self, mean_pred, log_var, targets, reduction='mean'):
        """Gaussian NLL up to a constant: 0.5 * ((y - mu)^2 / exp(log_var) + log_var).

        reduction: 'mean' (default) or 'none' (per-sample tensor); any other value sums.
        """
        squared_error = (mean_pred - targets) ** 2
        loss = 0.5 * (squared_error / torch.exp(log_var) + log_var)
        if reduction == 'mean':
            return loss.mean()
        if reduction == 'none':
            return loss
        return loss.sum()

    def laplace_log_likelihood(self, y_true, y_pred, log_var, sigma_floor=70.0, delta_cap=1000.0):
        """Mean OSIC Laplace log-likelihood.

            sigma = max(exp(log_var / 2), sigma_floor)
            delta = min(|y_true - y_pred|, delta_cap)
            score = -sqrt(2) * delta / sigma - ln(sqrt(2) * sigma)

        The original version omitted the sqrt(2) factor on the error term and the
        error threshold, so its value disagreed with the competition metric.
        Pass ``delta_cap=None`` to disable the error threshold.

        NOTE(review): sigma_floor (70) and delta_cap (1000) are in ml; they are only
        meaningful if the targets are raw FVC in ml — confirm against the dataset.
        """
        sqrt2 = 2.0 ** 0.5
        sigma = torch.clamp(torch.exp(log_var / 2.0), min=sigma_floor)
        delta = torch.abs(y_true - y_pred)
        if delta_cap is not None:
            delta = torch.clamp(delta, max=delta_cap)
        log_likelihood = -sqrt2 * delta / sigma - torch.log(sqrt2 * sigma)
        return torch.mean(log_likelihood)

    def _train_epoch(self, train_loader, optimizer):
        """One optimisation pass; returns (avg_loss, avg_mae, avg_lll) or None if no batch succeeded."""
        self.model.train()
        total_loss = total_mae = total_lll = 0.0
        n_batches = 0

        for batch_idx, (images, tabular, targets, _) in enumerate(train_loader):
            try:
                images = images.to(self.device)
                tabular = tabular.to(self.device)
                targets = targets.to(self.device)

                optimizer.zero_grad()
                mean_pred, log_var = self.model(images, tabular)
                loss = self.uncertainty_loss(mean_pred, log_var, targets)

                # Monitoring metrics only; they do not contribute gradients.
                with torch.no_grad():
                    mae = F.l1_loss(mean_pred, targets)
                    lll = self.laplace_log_likelihood(targets, mean_pred, log_var)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                optimizer.step()

                total_loss += loss.item()
                total_mae += mae.item()
                total_lll += lll.item()
                n_batches += 1
            except Exception as e:
                print(f"Error in training batch {batch_idx}: {e}")
                continue

        if n_batches == 0:
            return None
        return total_loss / n_batches, total_mae / n_batches, total_lll / n_batches

    def _validate(self, val_loader):
        """Evaluate on `val_loader`; returns sample-weighted metrics, or None if no batch succeeded.

        Metrics are computed once over all successfully evaluated samples, so a failed
        batch or a smaller final batch no longer skews the averages.
        """
        self.model.eval()
        preds, trues, log_vars = [], [], []

        with torch.no_grad():
            for batch_idx, (images, tabular, targets, _) in enumerate(val_loader):
                try:
                    mean_pred, log_var = self.model(images.to(self.device), tabular.to(self.device))
                    batch_pred = mean_pred.float().cpu().reshape(-1)
                    batch_true = targets.float().cpu().reshape(-1)
                    batch_log_var = log_var.float().cpu().reshape(-1)
                    preds.append(batch_pred)
                    trues.append(batch_true)
                    log_vars.append(batch_log_var)
                except Exception as e:
                    print(f"Error in validation batch {batch_idx}: {e}")
                    continue

        if not preds:
            return None

        pred = torch.cat(preds)
        true = torch.cat(trues)
        log_var = torch.cat(log_vars)

        errors = (pred - true).numpy()
        true_np = true.numpy()
        ss_res = float(np.sum(errors ** 2))
        ss_tot = float(np.sum((true_np - np.mean(true_np)) ** 2))

        return {
            'loss': self.uncertainty_loss(pred, log_var, true).item(),
            'mae': float(np.mean(np.abs(errors))),
            'lll': self.laplace_log_likelihood(true, pred, log_var).item(),
            'rmse': float(np.sqrt(np.mean(errors ** 2))),
            'r2': 1 - (ss_res / ss_tot) if ss_tot != 0 else -float('inf'),
            # Raw sigma values (no scaling), for monitoring only.
            'sigmas': torch.exp(log_var / 2.0).numpy(),
        }

    def train(self, train_loader, val_loader, epochs=30, patience=8, restore_best=True):
        """Train with LR-on-plateau and early stopping, both driven by validation LLL.

        Returns the validation MAE of the best-LLL epoch (inf if no epoch improved).
        When ``restore_best`` is True and a checkpoint was saved during this call,
        the model is reset to that best checkpoint before returning; otherwise it
        keeps the weights of the last epoch run.
        """
        optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.lr,
            weight_decay=1e-5
        )
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=0.5, patience=4, verbose=True
        )

        patience_counter = 0
        saved_checkpoint = False

        for epoch in range(epochs):
            train_stats = self._train_epoch(train_loader, optimizer)
            val = self._validate(val_loader)
            if train_stats is None or val is None:
                print(f"Epoch {epoch+1}/{epochs}: no successful train/val batches, skipping metrics")
                continue

            avg_train_loss, avg_train_mae, avg_train_lll = train_stats
            sigmas = val['sigmas']

            print(f"Epoch {epoch+1}/{epochs}")
            print(f"Train Loss: {avg_train_loss:.6f} | Train LLL: {avg_train_lll:.6f} | Train MAE: {avg_train_mae:.6f}")
            print(f"Val Loss: {val['loss']:.6f} | Val LLL: {val['lll']:.6f} | MAE: {val['mae']:.6f} | RMSE: {val['rmse']:.6f} | R²: {val['r2']:.6f}")
            print(f"Raw Sigma Stats: Avg={np.mean(sigmas):.2f}, Range=[{np.min(sigmas):.2f}, {np.max(sigmas):.2f}]")

            scheduler.step(val['lll'])

            if val['lll'] > self.best_val_lll:
                self.best_val_lll = val['lll']
                self.best_val_mae = val['mae']
                torch.save(self.model.state_dict(), self.checkpoint_path)
                saved_checkpoint = True
                print("✅ New best model saved!")
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

            print("-" * 70)

        if restore_best and saved_checkpoint:
            # map_location='cpu' keeps loading device-agnostic; load_state_dict copies
            # the values into the model's existing (on-device) parameters.
            self.model.load_state_dict(torch.load(self.checkpoint_path, map_location='cpu'))
            print(f"Restored best model weights from {self.checkpoint_path}")

        return self.best_val_mae

In [None]:
# Part 9: Submission Generation with Confidence
# -------------------------------------------

class TTAPredictor:
    """Test-time augmentation (TTA) for a model returning ``(mean_pred, log_var)``.

    ``predict`` scores the un-augmented sample plus ``num_augmentations`` augmented
    copies and combines them with the median of the means and the median of the
    log-variances. Failed augmentations are reported and skipped.
    """

    def __init__(self, model, num_augmentations=5):
        self.model = model
        self.num_augmentations = num_augmentations
        # NOTE(review): MedicalAugmentation is defined elsewhere in the notebook; it is
        # assumed to take an HWC numpy image and return a CHW tensor — confirm.
        self.augmentor = MedicalAugmentation(augment=True)
        self.model.eval()

    def _target_device(self, fallback):
        """Device holding the model's parameters, or `fallback` for a parameter-less model."""
        first_param = next(self.model.parameters(), None)
        return fallback if first_param is None else first_param.device

    def _score(self, image, tabular, device):
        """Run the model on one unbatched sample; return (mean, log_var) as Python floats."""
        with torch.no_grad():
            mean_pred, log_var = self.model(
                image.unsqueeze(0).to(device), tabular.unsqueeze(0).to(device)
            )
        return mean_pred.item(), log_var.item()

    def predict(self, image, tabular):
        """Return ``(median mean prediction, std)`` for one sample.

        `image` (C, H, W) and `tabular` may live on any device: they are moved to the
        model's device for inference. This makes CPU inputs work with a GPU model. The
        previous version passed them through unchanged, so CPU inputs failed there.
        std is ``sqrt(exp(median log_var))``.
        """
        device = self._target_device(image.device)

        mean0, log_var0 = self._score(image, tabular, device)
        mean_preds, log_vars = [mean0], [log_var0]

        for _ in range(self.num_augmentations):
            try:
                # The augmentor works on numpy arrays, which requires a CPU copy.
                image_hwc = image.detach().cpu().permute(1, 2, 0).numpy()
                # NOTE(review): astype(np.uint8) truncates/wraps values outside 0..255; if
                # the dataset yields normalised floats (load_dicom_with_hu produces
                # [-1, 1]) this destroys the image — confirm the expected input range.
                aug_img = self.augmentor(image_hwc.astype(np.uint8))
                mean_i, log_var_i = self._score(aug_img, tabular, device)
                mean_preds.append(mean_i)
                log_vars.append(log_var_i)
            except Exception as e:
                print(f"Error in TTA: {e}")
                continue

        mean_ensemble = np.median(mean_preds)
        log_var_ensemble = np.median(log_vars)
        std = np.sqrt(np.exp(log_var_ensemble))
        return mean_ensemble, std

def create_submission_with_confidence(model, test_dir, output_file='submission.csv',
                                      sigma_floor=70.0, default_fvc=2000.0,
                                      default_confidence=200.0, default_slope=-7.0):
    """Create an OSIC submission CSV with columns Patient_Week, FVC, Confidence.

    Each row of ``DATA_DIR / 'test.csv'`` is handled as follows:

    - The model runs once on the first DICOM file in the patient's directory under
      ``test_dir``. Files are sorted by name, so the choice is deterministic.
    - The predicted mean is taken as the FVC at that row's week.
    - That FVC is extrapolated linearly to every week from -12 to 133, using the
      patient's slope from ``A``. Patients missing from ``A`` use ``default_slope``
      ml/week.
    - Confidence is the predicted standard deviation ``exp(log_var / 2)``, floored at
      ``sigma_floor``, and is the same for every week. Unlike the previous version,
      sigma is no longer multiplied by 70. Because the mean is used directly as FVC,
      sigma is already in the same units.
    - FVC values are not clipped.

    Patients without images keep ``default_fvc`` and ``default_confidence``, still
    extrapolated with the slope. Patients that raise an error get constant
    ``default_fvc`` and ``default_confidence`` rows. Each patient's rows are added
    all at once, so a mid-patient error cannot leave duplicate rows.

    Args:
        model: module returning ``(mean_pred, log_var)``. Pass the raw model, not
            ``ModelWithConfidence``, which already returns a standard deviation.
        test_dir: directory holding one sub-directory of ``.dcm`` files per patient.
        output_file: CSV path to write.
        sigma_floor, default_fvc, default_confidence, default_slope: see above.

    Returns:
        The submission DataFrame, which is also written to ``output_file``.
    """
    print(f"📝 Creating submission with confidence intervals...")

    test_df = pd.read_csv(DATA_DIR / 'test.csv')
    print(f"✅ Loaded test data: {len(test_df)} samples")

    submission_weeks = range(-12, 134)  # weeks -12..133 inclusive
    submissions = []
    model.eval()

    # Deterministic preprocessing (augmentation disabled) for inference
    test_augmentor = MedicalAugmentation(augment=False)

    print("🔄 Processing test patients...")

    for _, row in tqdm(test_df.iterrows(), total=len(test_df), desc="Processing"):
        patient_id = row['Patient']
        anchor_week = row['Weeks']

        try:
            fvc_pred = default_fvc
            confidence_val = default_confidence

            patient_dir = Path(test_dir) / patient_id
            image_files = sorted(patient_dir.glob('*.dcm')) if patient_dir.exists() else []
            if image_files:
                # NOTE(review): load_and_preprocess_dicom_fixed and get_tab_features are not
                # defined in the cells shown here (only load_and_preprocess_dicom is); if
                # they are missing, every patient falls into the error fallback below.
                img = load_and_preprocess_dicom_fixed(image_files[0])
                img_tensor = test_augmentor(img).unsqueeze(0).to(DEVICE)
                tab_tensor = torch.tensor(get_tab_features(row)).float().unsqueeze(0).to(DEVICE)

                with torch.no_grad():
                    mean_pred, log_var = model(img_tensor, tab_tensor)
                fvc_pred = mean_pred.item()
                confidence_val = max(torch.exp(log_var / 2).item(), sigma_floor)

            slope = A[patient_id] if patient_id in A else default_slope
            patient_rows = [
                {
                    'Patient_Week': f"{patient_id}_{week}",
                    'FVC': fvc_pred + (week - anchor_week) * slope,
                    'Confidence': confidence_val,
                }
                for week in submission_weeks
            ]

        except Exception as e:
            print(f"⚠️ Error processing patient {patient_id}: {e}")
            patient_rows = [
                {
                    'Patient_Week': f"{patient_id}_{week}",
                    'FVC': default_fvc,
                    'Confidence': default_confidence,
                }
                for week in submission_weeks
            ]

        submissions.extend(patient_rows)

    submission_df = pd.DataFrame(submissions)
    submission_df.to_csv(output_file, index=False)

    print(f"✅ Submission saved to {output_file}")
    print(f"📊 Submission stats:")
    print(f"   Total rows: {len(submission_df)}")
    print(f"   FVC raw range: {submission_df['FVC'].min():.1f} - {submission_df['FVC'].max():.1f}")
    print(f"   Confidence range: {submission_df['Confidence'].min():.1f} - {submission_df['Confidence'].max():.1f}")

    return submission_df


# Helper function for DICOM loading (for submission)
def load_and_preprocess_dicom(path, img_size=512):
    """Load one DICOM file as an 8-bit, 3-channel image for submission-time inference.

    Steps:
      1. Read the pixel data as float32.
      2. If SamplesPerPixel > 1 (colour data), average over the trailing samples
         axis to get grayscale. Previously an (H, W, 3) image was treated as a
         volume and collapsed to a single row.
      3. If the data is still 3-D (multi-frame), take the middle frame.
      4. Resize to ``img_size x img_size``.
      5. Min-max scale to 0-255.
      6. Replicate to 3 channels.

    No HU rescale or windowing is applied. Min-max scaling makes any positive
    linear rescale irrelevant.

    Returns:
        uint8 array of shape (img_size, img_size, 3). If the file cannot be read
        or decoded, an all-black image of that shape.
    """
    try:
        dcm = pydicom.dcmread(str(path))
        img = dcm.pixel_array.astype(np.float32)

        # Colour data carries a trailing samples axis; collapse it to grayscale.
        if int(getattr(dcm, 'SamplesPerPixel', 1)) > 1:
            img = img.mean(axis=-1).astype(np.float32)

        # Multi-frame volume (frames, H, W): use the middle frame.
        if img.ndim == 3:
            img = img[img.shape[0] // 2]

        img = cv2.resize(img, (img_size, img_size))

        # Normalize to 0-255 (a constant image becomes all zeros)
        img_min, img_max = img.min(), img.max()
        if img_max > img_min:
            img = (img - img_min) / (img_max - img_min) * 255
        else:
            img = np.zeros_like(img)

        # Convert to 3-channel
        return np.stack([img, img, img], axis=2).astype(np.uint8)

    except Exception:
        # Best effort: return a black image rather than aborting the submission.
        return np.zeros((img_size, img_size, 3), dtype=np.uint8)

In [None]:
# Part 10: Execute Training with Enhanced Error Handling
# -----------------------------------------------------

print("🚀 Starting training...")
print("📊 Dataset Summary:")
print(f"   Total patients: {len(P)}")
print(f"   Training patients: {len(train_dataset.valid_patients)}")
print(f"   Validation patients: {len(val_dataset.valid_patients)}")
print(f"   Training batches: {len(train_loader)}")
print(f"   Validation batches: {len(val_loader)}")

# Smoke-test one batch and one forward pass before committing to a long run
print("\n🧪 Testing data loading...")
try:
    test_batch = next(iter(train_loader))
    images, tabular, targets, patient_ids = test_batch
    print(f"✅ Successfully loaded test batch:")
    print(f"   Images shape: {images.shape}")
    print(f"   Tabular shape: {tabular.shape}")
    print(f"   Targets shape: {targets.shape}")
    print(f"   Sample patient: {patient_ids[0]}")

    # Run the sanity-check forward pass in eval mode. In train mode, BatchNorm layers
    # (if any) update their running statistics from this batch even under no_grad,
    # and dropout makes the printed sample prediction noisy. The trainer switches the
    # model back to train mode at the start of every epoch.
    model.eval()
    with torch.no_grad():
        test_pred, test_log_var = model(images.to(DEVICE), tabular.to(DEVICE))
        print(f"   Model output shapes: pred={test_pred.shape}, log_var={test_log_var.shape}")
        print(f"   Sample prediction: {test_pred[0].item():.3f} ± {torch.exp(test_log_var[0]/2).item():.3f}")

    print("✅ All systems ready for training!")

except Exception as e:
    print(f"❌ Error in data loading test: {e}")
    import traceback
    traceback.print_exc()
    print("\n🛠️ Please check data paths and DICOM loading issues above.")

# Start actual training
try:
    trainer = CorrectedSimpleTrainer(model, DEVICE, lr=5e-5)
    print(f"\n🎯 Training started with {sum(p.numel() for p in model.parameters()):,} parameters")

    best_val_mae = trainer.train(train_loader, val_loader, epochs=30, patience=8)
    print(f"🎯 Training completed! Best validation MAE: {best_val_mae:.6f}")

    # Submission generation is intentionally disabled in this run; uncomment to produce it.
    print("📝 Skipping submission generation as requested...")
    # final_submission = create_submission_with_confidence(model, TEST_DIR, 'enhanced_submission.csv')
    # print("✅ Submission ready!")

except KeyboardInterrupt:
    print("\n⏹️ Training interrupted by user")
except Exception as e:
    print(f"\n❌ Training failed with error: {e}")
    import traceback
    traceback.print_exc()
    print("\n💡 Try running individual cells to debug the issue")

In [None]:
# -------------------------
# LLL evaluation utilities
# -------------------------
import numpy as np
import pandas as pd
import torch
import os
from math import sqrt, log
from tqdm import tqdm

# sqrt(2) constant of the OSIC Laplace log-likelihood: -sqrt(2)*|y - y_hat|/sigma - ln(sqrt(2)*sigma)
SQRT2 = np.sqrt(2.0)

def laplace_score_per_sample(y_true, y_pred, sigma, sigma_floor=70.0, delta_cap=1000.0):
    """
    Per-sample Laplace Log-Likelihood, as defined by the OSIC competition metric:

        sigma_c = max(sigma, sigma_floor)
        delta   = min(|y_true - y_pred|, delta_cap)
        score   = -sqrt(2) * delta / sigma_c - ln(sqrt(2) * sigma_c)

    Inputs are numpy arrays or scalars, and are broadcast together.

    Args:
        sigma_floor: minimum sigma, 70 ml in OSIC.
        delta_cap: absolute errors are thresholded at this value, 1000 ml in OSIC.
            The previous version omitted this threshold. Pass None to disable it.

    Returns:
        Per-sample scores (not averaged).
    """
    sqrt2 = np.sqrt(2.0)
    sigma = np.maximum(sigma, sigma_floor)
    delta = np.abs(y_true - y_pred)
    if delta_cap is not None:
        delta = np.minimum(delta, delta_cap)
    term1 = - (sqrt2 * delta) / sigma
    term2 = - np.log(sqrt2 * sigma)
    return term1 + term2

def mean_laplace_score(y_true, y_pred, sigma, sigma_floor=70.0):
    """Average of ``laplace_score_per_sample`` over all samples, as a Python float.

    Inputs may be any array-likes (lists, pandas Series, ...); they are converted
    with ``np.array`` before scoring.
    """
    per_sample = laplace_score_per_sample(
        np.array(y_true),
        np.array(y_pred),
        np.array(sigma),
        sigma_floor=sigma_floor,
    )
    return float(per_sample.mean())

# -------------------------
# Evaluate on a DataLoader
# -------------------------
def evaluate_lll_from_loader(model, loader, device, mode='log_var', tta_predictor=None, sigma_floor=70.0, save_csv=True, out_dir=None):
    """
    Run `model` over `loader` and score every sample with the Laplace log-likelihood.

    The loader must yield ``(images, tabular, targets, patient_ids)`` batches.

    mode:
      - 'log_var'    : model(images, tabular) -> (mean_pred, log_var); sigma = sqrt(exp(log_var))
      - 'confidence' : model(images, tabular) -> (mean_pred, confidence); confidence used as sigma directly
      - 'tta'        : tta_predictor.predict(image, tabular) -> (mean, std) per sample; std used as sigma
    tta_predictor: object with ``predict(image, tabular)`` (e.g. TTAPredictor); required when mode == 'tta'.
    sigma_floor: minimum sigma used when scoring.
    save_csv / out_dir: when save_csv is True, write ``lll_predictions.csv`` to out_dir
        (default: the global ``auto_save_dir`` if defined, else the current directory).

    Returns:
        (mean_lll, df) with df columns ['patient', 'y_true', 'y_pred', 'sigma', 'lll'].

    Raises:
        ValueError: unknown `mode`, or mode == 'tta' without a `tta_predictor`
            (both are checked before any batch is processed).
    """
    valid_modes = ('log_var', 'confidence', 'tta')
    if mode not in valid_modes:
        raise ValueError(f"Unknown mode for evaluate_lll_from_loader: {mode!r} (expected one of {valid_modes})")
    if mode == 'tta' and tta_predictor is None:
        raise ValueError("mode='tta' requires a tta_predictor (e.g. TTAPredictor(model))")

    def _patient_id_to_str(pid):
        """Normalise one collated patient id (str, bytes, numpy/torch scalar, other) to str."""
        if isinstance(pid, str):
            return pid
        if isinstance(pid, bytes):
            return pid.decode('utf-8')
        if hasattr(pid, 'item'):
            # numpy scalars may hold bytes; torch scalars hold numbers (no .decode).
            value = pid.item()
            return value.decode('utf-8') if isinstance(value, bytes) else str(value)
        return str(pid)

    model.eval()
    device = torch.device(device)
    preds, trues, sigmas, patients = [], [], [], []

    with torch.no_grad():
        for images, tabular, targets, patient_ids in tqdm(loader, desc="Evaluating LLL"):
            images = images.to(device)
            tabular = tabular.to(device)

            if mode == 'log_var':
                mean_pred, log_var = model(images, tabular)
                mean_np = mean_pred.detach().cpu().numpy().astype(float)
                log_var_np = log_var.detach().cpu().numpy().astype(float)
                sigma_np = np.sqrt(np.exp(log_var_np))
            elif mode == 'confidence':
                mean_pred, confidence = model(images, tabular)
                mean_np = mean_pred.detach().cpu().numpy().astype(float)
                sigma_np = confidence.detach().cpu().numpy().astype(float)
            else:  # 'tta': one sample at a time
                per_sample = [
                    tta_predictor.predict(images[i].cpu(), tabular[i].cpu())
                    for i in range(images.shape[0])
                ]
                mean_np = np.array([float(m) for m, _ in per_sample], dtype=float)
                sigma_np = np.array([float(s) for _, s in per_sample], dtype=float)

            targets_np = targets.detach().cpu().numpy().astype(float)

            preds.extend(mean_np.tolist())
            trues.extend(targets_np.tolist())
            sigmas.extend(sigma_np.tolist())
            patients.extend(_patient_id_to_str(p) for p in patient_ids)

    preds = np.array(preds, dtype=float)
    trues = np.array(trues, dtype=float)
    sigmas = np.array(sigmas, dtype=float)
    lll_per_sample = laplace_score_per_sample(trues, preds, sigmas, sigma_floor=sigma_floor)
    mean_lll = float(np.mean(lll_per_sample))

    df = pd.DataFrame({
        'patient': patients,
        'y_true': trues,
        'y_pred': preds,
        'sigma': sigmas,
        'lll': lll_per_sample
    })

    if save_csv:
        if out_dir is None:
            out_dir = globals().get('auto_save_dir', '.')
        os.makedirs(out_dir, exist_ok=True)
        outpath = os.path.join(out_dir, 'lll_predictions.csv')
        df.to_csv(outpath, index=False)
        print(f"Saved per-sample predictions + lll to: {outpath}")

    print(f"Mean Laplace Log-Likelihood (LLL): {mean_lll:.6f}")
    return mean_lll, df

# -------------------------
# Helper: Convert slope -> FVC predictions and compute LLL per patient-week
# -------------------------
def compute_lll_from_slope_predictions(slope_df, cur_fvc_map, cur_week_map, weeks_to_predict=None, sigma_floor=70.0, save_csv=True, out_dir=None):
    """
    Expand per-patient slope predictions into per-(patient, week) FVC predictions.

    Despite the name, nothing is scored here. No ground truth is available, so
    `y_true` is left as NaN; fill it before calling `mean_laplace_score`.
    `sigma_floor` is accepted for API symmetry but is not used.

    For each patient present in both anchor maps:
        intercept = cur_fvc - slope * cur_week
        y_pred(w) = intercept + slope * w
        sigma(w)  = max(1e-6, sigma_slope * |w - cur_week|)
    Patients missing from either map are skipped.

    Args:
        slope_df: DataFrame with 'Patient' and 'pred_slope' columns and an optional
            'sigma_slope' column (defaults to 0.0).
        cur_fvc_map / cur_week_map: patient -> anchor FVC / anchor week.
        weeks_to_predict: iterable of weeks (default -12..133).
        save_csv / out_dir: when save_csv is True, write 'slope_to_fvc_expanded.csv'
            to out_dir (default: global ``auto_save_dir`` if defined, else '.').

    Returns:
        DataFrame with columns ['Patient', 'Week', 'y_true', 'y_pred', 'sigma'].
        The columns are present even when no patient can be expanded.
    """
    columns = ['Patient', 'Week', 'y_true', 'y_pred', 'sigma']
    if weeks_to_predict is None:
        weeks_to_predict = np.arange(-12, 134)

    rows = []
    for _, r in slope_df.iterrows():
        patient = r['Patient']
        slope = float(r['pred_slope'])
        sigma_slope = float(r.get('sigma_slope', 0.0))
        if patient not in cur_fvc_map or patient not in cur_week_map:
            continue
        cur_fvc = float(cur_fvc_map[patient])
        cur_week = int(cur_week_map[patient])

        intercept = cur_fvc - slope * cur_week

        for w in weeks_to_predict:
            rows.append({
                'Patient': patient,
                'Week': w,
                'y_true': np.nan,  # NaN (not None) keeps the column numeric
                'y_pred': intercept + slope * w,
                'sigma': max(1e-6, sigma_slope * abs(w - cur_week)),
            })

    df_expanded = pd.DataFrame(rows, columns=columns)
    if save_csv:
        out_dir = out_dir or globals().get('auto_save_dir', '.')
        os.makedirs(out_dir, exist_ok=True)
        df_expanded.to_csv(os.path.join(out_dir, 'slope_to_fvc_expanded.csv'), index=False)
        print(f"Saved expanded slope->FVC predictions to {out_dir}/slope_to_fvc_expanded.csv")
    return df_expanded

# -------------------------
# Example usage after training (run these cells)
# -------------------------
auto_save_dir = "./auto_save_data"
os.makedirs(auto_save_dir, exist_ok=True)
# Every evaluation writes a file named 'lll_predictions.csv', so each example gets its
# own sub-directory instead of silently overwriting the previous one.

# 1) Trained model returns (mean_pred, log_var) -> mode='log_var'
mean_lll, df = evaluate_lll_from_loader(
    model, val_loader, DEVICE, mode='log_var',
    out_dir=os.path.join(auto_save_dir, 'log_var'),
)

# 2) ModelWithConfidence returns (mean_pred, confidence) -> mode='confidence'.
# The wrapper shares `model`'s parameters, so loading a checkpoint into it also
# overwrites `model`'s weights. Only load when the checkpoint actually exists.
wrapped = ModelWithConfidence(model)
confidence_ckpt = 'model_with_confidence.pth'
if os.path.exists(confidence_ckpt):
    wrapped.load_state_dict(torch.load(confidence_ckpt, map_location='cpu'))
else:
    print(f"{confidence_ckpt} not found; evaluating the wrapper with the current weights of `model`")
wrapped.to(DEVICE).eval()
mean_lll, df = evaluate_lll_from_loader(
    wrapped, val_loader, DEVICE, mode='confidence',
    out_dir=os.path.join(auto_save_dir, 'confidence'),
)

# 3) Test-time augmentation (slower) using TTAPredictor
tta = TTAPredictor(model, num_augmentations=5)
mean_lll, df = evaluate_lll_from_loader(
    model, val_loader, DEVICE, mode='tta', tta_predictor=tta,
    out_dir=os.path.join(auto_save_dir, 'tta'),
)

# 4) Slope-based predictions. The per-patient inputs are not defined in the cells
# shown here, so this example only runs once they exist.
slope_inputs = ('patient_list', 'slope_list', 'sigma_list', 'cur_fvc_map', 'cur_week_map')
missing_inputs = [name for name in slope_inputs if name not in globals()]
if missing_inputs:
    print(f"Skipping slope example; define {', '.join(missing_inputs)} first")
else:
    slope_df = pd.DataFrame([
        {'Patient': p, 'pred_slope': s, 'sigma_slope': ssize}
        for p, s, ssize in zip(patient_list, slope_list, sigma_list)
    ])
    expanded = compute_lll_from_slope_predictions(
        slope_df, cur_fvc_map, cur_week_map, weeks_to_predict=np.arange(-12, 134)
    )
    # The helper leaves y_true empty; fill it with the true FVCs before scoring.
    if len(expanded) > 0 and expanded['y_true'].notna().all():
        mean_lll = mean_laplace_score(expanded['y_true'], expanded['y_pred'], expanded['sigma'], sigma_floor=70.0)
    else:
        print("Skipping slope LLL: fill expanded['y_true'] with true FVC values first")


In [None]:
# Install required packages
import subprocess
import sys

# torch_xla / cloud-tpu-client are only useful on a TPU runtime. Installing torch_xla
# on a GPU/CPU machine can pull in a torch build that does not match the installed one,
# and makes the `import torch_xla` probe below report a TPU that is not there.
# Set this to True on a TPU runtime that does not already ship torch_xla.
INSTALL_TPU_PACKAGES = False


def install_package(package):
    """pip-install `package` into the running kernel's interpreter; return True on success."""
    try:
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', package, '--quiet'])
        return True
    except subprocess.CalledProcessError:
        return False


print("Installing required packages...")
packages = [
    'pydicom',
    'pylibjpeg',
    'pylibjpeg-libjpeg',
    'python-gdcm',  # PyPI distribution of the GDCM bindings used by pydicom (`import gdcm`)
    'opencv-python-headless',
    'scikit-learn',
    'albumentations',
    'tqdm',
    'seaborn',
    'lightgbm',  # imported below as `lgb`
]
if INSTALL_TPU_PACKAGES:
    packages += ['torch_xla', 'cloud-tpu-client']

for pkg in packages:
    if install_package(pkg):
        print(f"✅ {pkg} installed")
    else:
        # pip exits successfully when a package is already present, so a failure here is real.
        print(f"⚠️ {pkg} installation failed")

# Import libraries
import os
import cv2
import pydicom
import pandas as pd
import numpy as np 
import matplotlib.pyplot as plt 
import seaborn as sns
import random
from tqdm import tqdm 
from datetime import timedelta, datetime
from pathlib import Path
import json
import warnings
import pickle
import glob
from math import sqrt, log
from sklearn.metrics import mean_squared_error, r2_score

# Image processing
from skimage import measure, morphology, segmentation
from skimage.transform import resize
from scipy.ndimage import binary_dilation, binary_erosion
from skimage.measure import label, regionprops
from sklearn.cluster import KMeans
from skimage.segmentation import clear_border

# Deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
from torch.cuda.amp import autocast, GradScaler

# Albumentations for medical augmentations
import albumentations as albu
from albumentations.pytorch import ToTensorV2

# Model selection
from sklearn.model_selection import train_test_split, GroupKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder
import lightgbm as lgb

# Accelerator probing. HAS_TPU only records whether torch_xla can be imported;
# it does not by itself prove that a TPU device is attached.
try:
    import torch_xla
    import torch_xla.core.xla_model as xm
except ImportError:
    HAS_TPU = False
else:
    HAS_TPU = True

# Silence library warnings for the rest of the notebook.
warnings.filterwarnings('ignore')

def seed_everything(seed=42):
    """Make runs repeatable by seeding every RNG the notebook relies on.

    This seeds Python's `random`, NumPy's global RNG and PyTorch. When CUDA is
    available, it also seeds all CUDA devices and switches cuDNN to deterministic,
    non-benchmarking mode.

    It also exports PYTHONHASHSEED. That only affects Python subprocesses started
    afterwards, not the running interpreter's hash seed.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)


def _select_device():
    """Return the best available device (TPU > CUDA GPU > CPU) and announce the choice."""
    if HAS_TPU:
        device = xm.xla_device()
        print("Using TPU accelerator")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
        print("Using GPU accelerator")
    else:
        device = torch.device("cpu")
        print("Using CPU")
    return device


# Configuration - detect and use the best available accelerator
DEVICE = _select_device()

# Competition data layout
DATA_DIR = Path("../input/osic-pulmonary-fibrosis-progression")
TRAIN_DIR = DATA_DIR.joinpath("train")
TEST_DIR = DATA_DIR.joinpath("test")

print("Pulmonary Fibrosis Progression Analysis - OPTIMIZED VERSION")
print(f"Device: {DEVICE}")

# =============================================================================
# PART 1: DATA LOADING AND EDA
# =============================================================================

# Load datasets. test.csv is optional (it may be absent outside the competition
# environment), so only a missing file is tolerated. The previous bare `except:`
# also hid parse errors and even KeyboardInterrupt.
train_df = pd.read_csv(DATA_DIR / 'train.csv')
try:
    test_df = pd.read_csv(DATA_DIR / 'test.csv')
    print(f'Train: {train_df.shape[0]} rows, Test: {test_df.shape[0]} rows')
except FileNotFoundError:
    print(f'Train: {train_df.shape[0]} rows, Test: file not found')
    test_df = None

print("\nTrain data sample:")
print(train_df.head())

print("\nTrain data statistics:")
print(train_df.describe())

# Basic EDA
print(f'\nUnique patients in training data: {train_df["Patient"].nunique()}')
print(f'Total observations: {len(train_df)}')
print(f'Average observations per patient: {len(train_df)/train_df["Patient"].nunique():.2f}')

# Check for missing values
print(f'\nMissing values:')
print(train_df.isnull().sum())

# =============================================================================
# PART 2: FEATURE ENGINEERING
# =============================================================================

print("Processing tabular features...")

# Per patient: baseline (first-visit) features plus a least-squares line
# FVC = slope * weeks + intercept over all of that patient's visits.
baseline_features = {}
patient_slopes = {}
patient_intercepts = {}

# groupby(sort=False) yields patients in order of first appearance, the same order
# as .unique(), so dict/list ordering (and the downstream split) is unchanged. It
# avoids re-scanning the whole frame with a boolean mask for every patient.
for patient, patient_data in train_df.groupby('Patient', sort=False):
    patient_data = patient_data.sort_values('Weeks')

    # Get baseline measurement (first visit)
    baseline = patient_data.iloc[0]
    baseline_features[patient] = {
        'Age': baseline['Age'],
        'Sex': baseline['Sex'],
        'SmokingStatus': baseline['SmokingStatus'],
        'BaselineFVC': baseline['FVC'],
        'BaselineWeeks': baseline['Weeks'],
        'Percent': baseline['Percent'] if 'Percent' in baseline else 50.0  # Fallback
    }

    # Calculate slope and intercept if multiple measurements
    if len(patient_data) > 1:
        weeks = patient_data['Weeks'].values
        fvc = patient_data['FVC'].values

        # Design matrix named `design` rather than `A`. The global `A` holds the
        # patient -> slope mapping, and the loop used to clobber it.
        design = np.vstack([weeks, np.ones(len(weeks))]).T
        slope, intercept = np.linalg.lstsq(design, fvc, rcond=None)[0]

        patient_slopes[patient] = slope
        patient_intercepts[patient] = intercept
    else:
        patient_slopes[patient] = 0.0
        patient_intercepts[patient] = baseline['FVC']

# Create enhanced tabular features
def get_enhanced_tabular_features(patient_id, row=None):
    """Build the standardized tabular feature vector (float32) for one patient.

    Feature order: Age, Sex, BaselineFVC, BaselineWeeks, Percent, Smoking_Never,
    Smoking_Ex, Smoking_Current, Slope, Intercept, and optionally WeekDelta.
    WeekDelta is appended only when `row` (which must provide 'Weeks') is given.

    NOTE(review): Slope and Intercept come from a least-squares fit over all of the
    patient's training visits. If the model's target is a later FVC of the same
    patient, these features leak it. A patient with a single visit gets Slope 0.
    Confirm this is intended.
    """
    base = baseline_features[patient_id]
    smoking_status = base['SmokingStatus']
    baseline_weeks_scaled = base['BaselineWeeks'] / 100

    vector = [
        (base['Age'] - 50) / 20,                          # Age, centred/scaled
        1 if base['Sex'] == 'Female' else 0,              # Sex: Female=1, Male=0
        (base['BaselineFVC'] - 2500) / 1000,              # BaselineFVC
        baseline_weeks_scaled,                            # BaselineWeeks
        (base['Percent'] - 50) / 20,                      # Percent
        1 if smoking_status == 'Never smoked' else 0,     # Smoking_Never
        1 if smoking_status == 'Ex-smoker' else 0,        # Smoking_Ex
        1 if smoking_status == 'Currently smokes' else 0, # Smoking_Current
        patient_slopes[patient_id] / 10,                  # Slope
        (patient_intercepts[patient_id] - 2500) / 1000,   # Intercept
    ]

    if row is not None:
        # Weeks since baseline, scaled; undoes the /100 scaling exactly as before.
        vector.append((row['Weeks'] - baseline_weeks_scaled * 100) / 50)

    return np.array(vector, dtype=np.float32)

# Expose the per-patient artifacts under the short names used by the rest of the notebook:
#   A   -> least-squares FVC slope per patient (the same dict object as patient_slopes)
#   P   -> patient ids (keys of baseline_features)
#   TAB -> tabular feature vector per patient
A = patient_slopes
P = list(baseline_features)
TAB = {pid: get_enhanced_tabular_features(pid) for pid in P}

print(f"Processed {len(P)} patients with enhanced features")

# =============================================================================
# PART 3: ENHANCED DICOM PROCESSING WITH HU PRESERVATION
# =============================================================================

def get_pixels_hu(dcm):
    """Convert a DICOM dataset's stored pixel values to Hounsfield Units (HU).

    Applies the modality rescale ``HU = stored * RescaleSlope + RescaleIntercept``,
    using slope 1 and intercept 0 when those attributes are absent.

    Args:
        dcm: a pydicom dataset (any object exposing ``pixel_array``).

    Returns:
        float32 ndarray of HU values. If pixel data cannot be decoded, a
        512x512 float32 array of zeros is returned (best-effort fallback).
    """
    try:
        pixel_array = dcm.pixel_array.astype(np.float32)
        # DICOM DS values may arrive as string-backed numbers; normalise to float
        intercept = float(getattr(dcm, 'RescaleIntercept', 0))
        slope = float(getattr(dcm, 'RescaleSlope', 1))
        return (pixel_array * slope + intercept).astype(np.float32, copy=False)
    except Exception:
        # Keep decoding failures best-effort, but unlike a bare ``except`` let
        # KeyboardInterrupt / SystemExit propagate.
        return np.zeros((512, 512), dtype=np.float32)

def load_dicom_with_hu(path):
    """Read one DICOM file and return its lung-windowed, normalized image.

    HU values are clipped to the window [-1200, 600] and mapped linearly onto
    [-1, 1] via ``(hu + 300) / 900``. Any read/decode error yields a 512x512
    float32 array filled with -1 (the bottom of the normalized range).
    """
    lung_low, lung_high = -1200, 600
    try:
        dataset = pydicom.dcmread(str(path), force=True)
        windowed = np.clip(get_pixels_hu(dataset), lung_low, lung_high)
        return (windowed + 300) / 900
    except Exception:
        return np.full((512, 512), -1, dtype=np.float32)

def _dicom_sort_key(path):
    """Order DICOM paths by numeric stem (``2.dcm`` before ``10.dcm``), then by name."""
    stem = path.stem
    if stem.isdecimal():
        return (0, int(stem), path.name)
    return (1, 0, path.name)


def load_three_slices(patient_dir, slice_idx=None, target_size=(256, 256)):
    """Load three neighbouring CT slices as one 2.5D image of shape (H, W, 3).

    Files are ordered by their numeric filename stem, so ``slice_idx - 1``,
    ``slice_idx`` and ``slice_idx + 1`` are consecutive files. A plain string
    sort would put ``10.dcm`` next to ``1.dcm``.

    Args:
        patient_dir: pathlib.Path of the patient's folder of ``*.dcm`` files.
        slice_idx: index of the centre slice in the sorted file list;
            defaults to the middle file.
        target_size: (height, width) of every returned channel.

    Returns:
        ``None`` if the folder has no ``.dcm`` files. Otherwise a float32 array of
        shape (height, width, 3). When fewer than three neighbours exist (volume
        edge), the last loaded slice is repeated. On any error, a -1-filled
        dummy of the same shape is returned.
    """
    height, width = target_size
    try:
        dicom_files = sorted(patient_dir.glob('*.dcm'), key=_dicom_sort_key)
        if not dicom_files:
            return None

        if slice_idx is None:
            slice_idx = len(dicom_files) // 2

        slices = []
        for i in range(max(0, slice_idx - 1), min(len(dicom_files), slice_idx + 2)):
            slice_img = load_dicom_with_hu(dicom_files[i])
            if slice_img.shape[:2] != (height, width):
                # cv2.resize takes dsize as (width, height)
                slice_img = cv2.resize(slice_img, (width, height), interpolation=cv2.INTER_AREA)
            slices.append(slice_img)

        # Pad up to three channels by repeating the last slice (or -1 if none loaded)
        while len(slices) < 3:
            slices.append(slices[-1] if slices else np.full((height, width), -1, dtype=np.float32))

        return np.stack(slices, axis=-1)

    except Exception:
        dummy_slice = np.full((height, width), -1, dtype=np.float32)
        return np.stack([dummy_slice] * 3, axis=-1)

# =============================================================================
# PART 4: MEDICAL AUGMENTATIONS FOR HU IMAGES
# =============================================================================

class MedicalAugmentation:
    """Resize a 2.5D slice stack, optionally augment it, and convert it to a CHW tensor.

    Expects float images of shape (H, W, C) normalized to [-1, 1], which is the
    range ``load_dicom_with_hu`` produces. ``target_size`` is (height, width).
    """

    def __init__(self, augment=True, target_size=(256, 256)):
        self.target_size = target_size
        self.augment = augment
        if augment:
            # NOTE(review): pixel-level magnitudes below now act on the [0, 1]-shifted
            # image (see __call__), i.e. roughly twice as strong in [-1, 1] units.
            self.transform = albu.Compose([
                albu.Rotate(limit=5, p=0.3),  # Reduced rotation for medical images
                albu.HorizontalFlip(p=0.3),
                albu.ShiftScaleRotate(shift_limit=0.03, scale_limit=0.05, rotate_limit=5, p=0.3),
                albu.GaussNoise(var_limit=(0.01, 0.05), p=0.2),  # Reduced noise for HU
                albu.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.3),
                ToTensorV2()
            ])
        else:
            self.transform = albu.Compose([
                ToTensorV2()
            ])

    def __call__(self, image):
        height, width = self.target_size
        if image.shape[:2] != (height, width):
            # cv2.resize takes dsize as (width, height)
            image = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)

        if not self.augment:
            # ToTensorV2 only transposes, so the [-1, 1] range passes through unchanged
            return self.transform(image=image)['image']

        # Albumentations' pixel transforms (GaussNoise, RandomBrightnessContrast)
        # clip float images to [0, 1], which would zero every negative (air/lung)
        # value of a [-1, 1] image. Shift into [0, 1] for the pipeline, map back after.
        shifted = ((image + 1.0) * 0.5).astype(np.float32)
        tensor = self.transform(image=shifted)['image']
        return tensor * 2.0 - 1.0

# =============================================================================
# PART 5: DATASET CLASS WITH 2.5D INPUT
# =============================================================================

class OSICDenseNetDataset(Dataset):
    """Patient-level dataset yielding (2.5D image tensor, tabular tensor, target, patient_id).

    Only patients whose folder under ``data_dir`` exists and contains at least
    one ``.dcm`` file are kept. The training split repeats every patient 4
    times per epoch (each repeat can receive different augmentation). If
    loading a sample fails, a zero-filled placeholder with matching shapes and
    patient id ``"dummy_patient"`` is returned instead of raising.
    """

    def __init__(self, patients, A_dict, TAB_dict, data_dir, split='train', augment=True, target_size=(256, 256)):
        self.patients = patients
        self.A_dict = A_dict
        self.TAB_dict = TAB_dict
        self.data_dir = Path(data_dir)
        self.split = split
        self.augment = augment
        self.target_size = target_size
        self.augmentor = MedicalAugmentation(augment=augment, target_size=target_size)

        # Keep only patients whose folder exists and holds at least one DICOM file;
        # any() stops at the first match instead of listing the whole folder.
        self.patient_dirs = {}
        valid_patients = []
        for patient in self.patients:
            patient_dir = self.data_dir / patient
            if patient_dir.exists() and any(patient_dir.glob('*.dcm')):
                self.patient_dirs[patient] = patient_dir
                valid_patients.append(patient)

        self.valid_patients = valid_patients
        print(f"Dataset {split}: {len(self.valid_patients)} valid patients")

    def __len__(self):
        return len(self.valid_patients) * (4 if self.split == 'train' else 1)

    def __getitem__(self, idx):
        try:
            patient = self.valid_patients[idx % len(self.valid_patients)]

            # 2.5D image (3 neighbouring slices as channels), then augmentation
            img = load_three_slices(self.patient_dirs[patient], target_size=self.target_size)
            img_tensor = self.augmentor(img)

            tab_features = torch.tensor(self.TAB_dict[patient], dtype=torch.float32)
            target = torch.tensor(self.A_dict[patient], dtype=torch.float32)

            return img_tensor, tab_features, target, patient

        except Exception:
            return self._dummy_sample()

    def _dummy_sample(self):
        """Placeholder sample shaped like a real one.

        The tabular width is taken from any TAB_dict entry. The failing patient
        may be unbound (empty dataset) or may itself be the missing key.
        """
        tab_dim = len(next(iter(self.TAB_dict.values()))) if self.TAB_dict else 0
        dummy_img = torch.zeros((3, self.target_size[0], self.target_size[1]), dtype=torch.float32)
        dummy_tab = torch.zeros(tab_dim, dtype=torch.float32)
        dummy_target = torch.tensor(0.0, dtype=torch.float32)
        return dummy_img, dummy_tab, dummy_target, "dummy_patient"

# =============================================================================
# PART 6: IMPROVED MODEL ARCHITECTURE
# =============================================================================

class EfficientNetModel(nn.Module):
    """EfficientNet-B0 image encoder fused with an MLP over tabular features.

    ``forward(images, tabular)`` returns ``(mean_pred, log_var)``, each of shape
    ``(batch,)``: the predicted mean of the target and the log-variance of a
    Gaussian around it.

    Args:
        tabular_dim: width of the tabular feature vector.
        in_channels: number of image channels. The 2.5D input stacks 3 slices,
            which already matches the pretrained RGB stem, so the default (3)
            keeps the pretrained stem untouched. Other values rebuild the stem
            via ``_adapt_stem``.
    """

    def __init__(self, tabular_dim=10, in_channels=3):
        super(EfficientNetModel, self).__init__()

        # ImageNet-pretrained EfficientNet-B0; only ``features`` is used in forward
        self.backbone = models.efficientnet_b0(pretrained=True)
        if in_channels != 3:
            self._adapt_stem(in_channels)

        # Channel width of the backbone's final feature map (input width of its classifier)
        self.num_image_features = self.backbone.classifier[1].in_features

        # Tabular processor
        self.tabular_processor = nn.Sequential(
            nn.Linear(tabular_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

        # Fusion of pooled image features and tabular embedding
        self.fusion = nn.Sequential(
            nn.Linear(self.num_image_features + 128, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU()
        )

        self.mean_head = nn.Linear(128, 1)
        self.log_var_head = nn.Linear(128, 1)

    def _adapt_stem(self, in_channels):
        """Replace the first conv so it accepts ``in_channels`` inputs.

        Each new input channel gets the channel-averaged pretrained kernel,
        scaled by 3 / in_channels so a constant input gives the same response
        as the original 3-channel stem.
        """
        old = self.backbone.features[0][0]
        new = nn.Conv2d(
            in_channels, old.out_channels,
            kernel_size=old.kernel_size,
            stride=old.stride,
            padding=old.padding,
            dilation=old.dilation,
            bias=old.bias is not None,  # Conv2d expects a bool, not the bias tensor
        )
        with torch.no_grad():
            mean_kernel = old.weight.mean(dim=1, keepdim=True)
            new.weight.copy_(mean_kernel.repeat(1, in_channels, 1, 1) * (3.0 / in_channels))
            if old.bias is not None:
                new.bias.copy_(old.bias)
        self.backbone.features[0][0] = new

    def forward(self, images, tabular):
        # Image features: backbone feature map -> global average pool -> (batch, C)
        img_features = self.backbone.features(images)
        img_features = F.adaptive_avg_pool2d(img_features, (1, 1))
        img_features = img_features.view(img_features.size(0), -1)

        # Tabular features
        tab_features = self.tabular_processor(tabular)

        # Fusion
        combined = torch.cat([img_features, tab_features], dim=1)
        fused = self.fusion(combined)

        # Predictions, squeezed from (batch, 1) to (batch,)
        mean_pred = self.mean_head(fused).squeeze(-1)
        log_var = self.log_var_head(fused).squeeze(-1)

        return mean_pred, log_var

# Size the tabular branch from the feature vectors already built in TAB.
tabular_dim = len(next(iter(TAB.values())))
model = EfficientNetModel(tabular_dim=tabular_dim).to(DEVICE)
print(f"Model parameters: {sum(param.numel() for param in model.parameters()):,}")

# =============================================================================
# PART 7: DATA SPLIT AND LOADERS
# =============================================================================

# Hold out the first of five GroupKFold folds for validation. Each patient is
# its own group, so no patient appears in both splits.
patients_list = list(P)
kfold = GroupKFold(n_splits=5)
train_idx, val_idx = next(kfold.split(patients_list, groups=patients_list))
train_patients, val_patients = (
    [patients_list[k] for k in train_idx],
    [patients_list[k] for k in val_idx],
)

print(f"Train patients: {len(train_patients)}, Validation patients: {len(val_patients)}")

# Both splits read from TRAIN_DIR at one resolution; only the training split
# is augmented (and repeated 4x per epoch inside the dataset).
TARGET_SIZE = (256, 256)
train_dataset, val_dataset = (
    OSICDenseNetDataset(patients=train_patients, A_dict=A, TAB_dict=TAB, data_dir=TRAIN_DIR,
                        split='train', augment=True, target_size=TARGET_SIZE),
    OSICDenseNetDataset(patients=val_patients, A_dict=A, TAB_dict=TAB, data_dir=TRAIN_DIR,
                        split='val', augment=False, target_size=TARGET_SIZE),
)

# Bigger batches and more loader workers when an accelerator is available.
BATCH_SIZE, NUM_WORKERS = (16, 4) if (HAS_TPU or torch.cuda.is_available()) else (4, 2)

train_loader, val_loader = (
    DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
               num_workers=NUM_WORKERS, pin_memory=True),
    DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
               num_workers=NUM_WORKERS, pin_memory=True),
)

print(f"Data loaders created: {len(train_loader)} train, {len(val_loader)} val batches")
print(f"Using batch size: {BATCH_SIZE}, Workers: {NUM_WORKERS}")

# =============================================================================
# PART 8: TRAINING WITH IMPROVED LOSS AND METRICS
# =============================================================================

def laplace_log_likelihood(y_true, y_pred, sigma, sigma_min=70, delta_max=1000):
    """Per-sample modified Laplace log-likelihood (OSIC competition metric).

        sigma_clipped = max(sigma, sigma_min)
        delta         = min(|y_true - y_pred|, delta_max)
        score         = -sqrt(2) * delta / sigma_clipped - ln(sqrt(2) * sigma_clipped)

    Args:
        y_true, y_pred: scalars or array-likes of equal shape.
        sigma: predicted uncertainty (scalar or array broadcastable to y_true).
        sigma_min: lower bound applied to sigma.
        delta_max: upper bound on the absolute error, so a single large miss
            cannot dominate; pass ``None`` to disable the clip.

    Returns:
        Element-wise score (higher is better); callers average it.
    """
    sigma_clipped = np.maximum(sigma, sigma_min)
    delta = np.abs(np.asarray(y_true) - np.asarray(y_pred))
    if delta_max is not None:
        delta = np.minimum(delta, delta_max)
    return -np.sqrt(2) * delta / sigma_clipped - np.log(np.sqrt(2) * sigma_clipped)

class ImprovedTrainer:
    """Training loop for a model that returns ``(mean_pred, log_var)``.

    - Loss: Gaussian negative log-likelihood, minimized with AdamW.
    - LR schedule: halved by ReduceLROnPlateau on validation loss.
    - Precision: AMP + GradScaler on non-XLA devices.
    - Checkpointing: the weights with the lowest validation loss are saved to
      ``best_model.pth``.

    Loaders must yield ``(images, tabular, targets, ids)`` batches.
    """

    def __init__(self, model, device, lr=1e-4):
        self.model = model
        self.device = device
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', patience=5, factor=0.5, verbose=True
        )
        self.best_loss = float('inf')
        self.scaler = GradScaler()

        # XLA devices (TPU) take a separate path: no AMP, xm.optimizer_step / xm.save
        self.is_tpu = hasattr(device, 'type') and device.type == 'xla'

    def gaussian_nll_loss(self, mean_pred, log_var, targets):
        """Gaussian negative log likelihood (constant term dropped), averaged over the batch."""
        return 0.5 * (torch.mean(torch.exp(-log_var) * (mean_pred - targets)**2 + log_var))

    def compute_metrics(self, mean_pred, log_var, targets):
        """MSE, RMSE, R² and Laplace log-likelihood for a set of predictions.

        ``sigma`` is recovered from the predicted log-variance as ``sqrt(exp(log_var))``.
        """
        mean_pred_np = mean_pred.detach().cpu().numpy()
        log_var_np = log_var.detach().cpu().numpy()
        targets_np = targets.detach().cpu().numpy()

        sigma_np = np.sqrt(np.exp(log_var_np))

        mse = mean_squared_error(targets_np, mean_pred_np)
        rmse = np.sqrt(mse)
        r2 = r2_score(targets_np, mean_pred_np)
        lll = np.mean(laplace_log_likelihood(targets_np, mean_pred_np, sigma_np))

        return {
            'mse': mse,
            'rmse': rmse,
            'r2': r2,
            'lll': lll
        }

    def _to_device(self, images, tabular, targets):
        """Move one batch to ``self.device`` (non-blocking copies off-TPU)."""
        copy_kwargs = {} if self.is_tpu else {'non_blocking': True}
        return tuple(t.to(self.device, **copy_kwargs) for t in (images, tabular, targets))

    def _summarize(self, loss_sum, n_samples, means, log_vars, targets):
        """Sample-weighted mean loss plus metrics computed once over the whole epoch.

        Averaging per-batch RMSE/R² is not the epoch RMSE/R². R² is also
        undefined (NaN) on single-sample batches.
        """
        if n_samples == 0:
            raise ValueError("loader produced no samples; cannot compute epoch loss/metrics")
        metrics = self.compute_metrics(torch.cat(means), torch.cat(log_vars), torch.cat(targets))
        return loss_sum / n_samples, metrics

    def train_epoch(self, loader):
        """One optimization pass over ``loader``; returns (mean loss per sample, metrics)."""
        self.model.train()
        loss_sum, n_samples = 0.0, 0
        means, log_vars, trues = [], [], []

        for images, tabular, targets, _ in tqdm(loader, desc="Training"):
            images, tabular, targets = self._to_device(images, tabular, targets)
            self.optimizer.zero_grad()

            if self.is_tpu:
                # TPU doesn't support AMP, use regular training
                mean_pred, log_var = self.model(images, tabular)
                loss = self.gaussian_nll_loss(mean_pred, log_var, targets)
                loss.backward()
                xm.optimizer_step(self.optimizer)
            else:
                # Use AMP for GPU
                with autocast():
                    mean_pred, log_var = self.model(images, tabular)
                    loss = self.gaussian_nll_loss(mean_pred, log_var, targets)

                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()

            batch_size = targets.size(0)
            loss_sum += loss.item() * batch_size
            n_samples += batch_size
            # float() undoes AMP half precision before metrics
            means.append(mean_pred.detach().float().cpu())
            log_vars.append(log_var.detach().float().cpu())
            trues.append(targets.detach().float().cpu())

        return self._summarize(loss_sum, n_samples, means, log_vars, trues)

    def validate(self, loader):
        """Evaluate ``loader`` without gradients; returns (mean loss per sample, metrics)."""
        self.model.eval()
        loss_sum, n_samples = 0.0, 0
        means, log_vars, trues = [], [], []

        with torch.no_grad():
            for images, tabular, targets, _ in tqdm(loader, desc="Validation"):
                images, tabular, targets = self._to_device(images, tabular, targets)

                mean_pred, log_var = self.model(images, tabular)
                loss = self.gaussian_nll_loss(mean_pred, log_var, targets)

                batch_size = targets.size(0)
                loss_sum += loss.item() * batch_size
                n_samples += batch_size
                means.append(mean_pred.float().cpu())
                log_vars.append(log_var.float().cpu())
                trues.append(targets.float().cpu())

        return self._summarize(loss_sum, n_samples, means, log_vars, trues)

    def train(self, train_loader, val_loader, epochs=20):
        """Run ``epochs`` train/validate cycles.

        Each epoch steps the LR scheduler on validation loss and saves the model
        to ``best_model.pth`` whenever validation loss improves.
        """
        for epoch in range(epochs):
            train_loss, train_metrics = self.train_epoch(train_loader)
            val_loss, val_metrics = self.validate(val_loader)

            self.scheduler.step(val_loss)

            print(f"\nEpoch {epoch+1}/{epochs}:")
            print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
            print("Train Metrics - MSE: {mse:.4f}, RMSE: {rmse:.4f}, R²: {r2:.4f}, LLL: {lll:.4f}".format(**train_metrics))
            print("Val Metrics   - MSE: {mse:.4f}, RMSE: {rmse:.4f}, R²: {r2:.4f}, LLL: {lll:.4f}".format(**val_metrics))

            if val_loss < self.best_loss:
                self.best_loss = val_loss
                # Save the model this trainer optimizes, not a same-named notebook global
                state = self.model.state_dict()
                if self.is_tpu:
                    xm.save(state, 'best_model.pth')
                else:
                    torch.save(state, 'best_model.pth')
                print("✅ New best model saved!")

# Start training
print("🚀 Starting training...")
trainer = ImprovedTrainer(model, DEVICE, lr=1e-4)
trainer.train(train_loader, val_loader, epochs=15)

# train() checkpoints the lowest-validation-loss weights but leaves `model` at its
# final-epoch weights. Restore that checkpoint so later cells (submission) use it.
# best_loss stays inf when no checkpoint was written this run, which avoids loading a stale file.
if trainer.best_loss < float('inf') and os.path.exists('best_model.pth'):
    model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
    print("Loaded best checkpoint from best_model.pth")

print("🎯 Training completed!")

# =============================================================================
# PART 9: BASELINE MODEL (LightGBM)
# =============================================================================

print("Training LightGBM baseline for comparison...")

# Tabular features (X) and slope targets (y) for every patient, in the order of P
X = np.array([TAB[patient] for patient in P])
y = np.array([A[patient] for patient in P])
groups = list(P)

# Score the baseline on the same patient hold-out as the network. Fitting and
# predicting on the same rows only measures in-sample fit, which is not
# comparable with the network's validation metrics.
row_of_patient = {patient: i for i, patient in enumerate(P)}
train_rows = [row_of_patient[p] for p in train_patients]
val_rows = [row_of_patient[p] for p in val_patients]

lgb_model = lgb.LGBMRegressor(n_estimators=100, random_state=42)
lgb_model.fit(X[train_rows], y[train_rows])

# Evaluate on held-out patients
y_val = y[val_rows]
lgb_preds = lgb_model.predict(X[val_rows])
lgb_mae = np.mean(np.abs(lgb_preds - y_val))
lgb_mse = mean_squared_error(y_val, lgb_preds)
lgb_rmse = np.sqrt(lgb_mse)
lgb_r2 = r2_score(y_val, lgb_preds)

print(f"LightGBM Baseline (validation) - MAE: {lgb_mae:.4f}, MSE: {lgb_mse:.4f}, RMSE: {lgb_rmse:.4f}, R²: {lgb_r2:.4f}")

# Refit on all patients: lgb_model is reused later as the submission fallback.
lgb_model = lgb.LGBMRegressor(n_estimators=100, random_state=42)
lgb_model.fit(X, y)

# =============================================================================
# PART 10: IMPROVED SUBMISSION GENERATION
# =============================================================================

def create_improved_submission(model, test_dir, output_file='submission.csv'):
    """Write a submission CSV with one (FVC, Confidence) row per test.csv row.

    The network and the LightGBM fallback are both trained on the per-patient
    slope target (``A`` = ``patient_slopes``), not on FVC. Their outputs are
    therefore projected linearly from the patient's baseline visit:

        FVC(week)  = BaselineFVC + slope * (week - BaselineWeeks)
        Confidence = max(sigma_slope * |week - BaselineWeeks|, 70)

    where ``sigma_slope = sqrt(exp(log_var))``.
    NOTE(review): assumes patient_slopes is the FVC change per week (ml/week);
    confirm against where it is computed.

    If the image/feature path fails for a row, the LightGBM slope is projected
    the same way with Confidence 200. Patients without tabular features get a
    flat 2500 ml. Rows are produced only for the (Patient, Weeks) pairs present
    in test.csv.

    Args:
        model: trained network returning ``(mean_pred, log_var)``.
        test_dir: directory with one DICOM folder per test patient.
        output_file: destination CSV path.

    Returns:
        The submission DataFrame (columns Patient_Week, FVC, Confidence).
    """
    print("Creating improved submission...")

    test_df = pd.read_csv(DATA_DIR / 'test.csv')
    submissions = []
    n_fallback = 0

    model.eval()
    augmentor = MedicalAugmentation(augment=False, target_size=TARGET_SIZE)

    for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
        patient_id = row['Patient']
        weeks = row['Weeks']
        patient_week = f"{patient_id}_{weeks}"

        try:
            base = baseline_features[patient_id]
            weeks_elapsed = weeks - base['BaselineWeeks']

            img = load_three_slices(Path(test_dir) / patient_id, target_size=TARGET_SIZE)
            if img is None:
                raise ValueError("No DICOM files found")
            img_tensor = augmentor(img).unsqueeze(0).to(DEVICE)

            # Same layout as the TAB vectors the model was trained on (built without a row)
            tab_features = get_enhanced_tabular_features(patient_id)
            tab_tensor = torch.tensor(tab_features).float().unsqueeze(0).to(DEVICE)

            with torch.no_grad():
                slope_pred, log_var = model(img_tensor, tab_tensor)
            slope = slope_pred.item()
            slope_sigma = float(np.sqrt(np.exp(log_var.item())))

            fvc_pred = base['BaselineFVC'] + slope * weeks_elapsed
            # Competition sigma floor (70) applied only at submission time
            confidence = max(slope_sigma * abs(weeks_elapsed), 70.0)

        except Exception:
            n_fallback += 1
            if patient_id in TAB:
                # TAB is keyed by baseline_features, so the baseline visit is available
                base = baseline_features[patient_id]
                slope = float(lgb_model.predict(TAB[patient_id].reshape(1, -1))[0])
                fvc_pred = base['BaselineFVC'] + slope * (weeks - base['BaselineWeeks'])
            else:
                fvc_pred = 2500  # Average FVC
            confidence = 200.0  # Conservative uncertainty

        submissions.append({
            'Patient_Week': patient_week,
            'FVC': fvc_pred,
            'Confidence': confidence
        })

    if n_fallback:
        print(f"{n_fallback} of {len(test_df)} rows used the fallback prediction")

    submission_df = pd.DataFrame(submissions)
    submission_df.to_csv(output_file, index=False)
    print(f"Submission saved to {output_file}")
    return submission_df

# Build the submission only when both test.csv metadata and the test image folder exist.
if test_df is None or not TEST_DIR.exists():
    print("No test data found - skipping submission")
else:
    submission = create_improved_submission(model, TEST_DIR, 'submission.csv')
    print("✅ Submission ready!")
    print(submission.head())

print("🎉 All done!")