In [None]:
# IMPORTANT: SOME KAGGLE DATA SOURCES ARE PRIVATE
# RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES.
import kagglehub
# Interactive Kaggle authentication; the cell comment above notes it is needed
# before importing private Kaggle data sources.
kagglehub.login()


In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.

# Fetch the OSIC competition files through kagglehub; the returned value is
# presumably the local directory holding them (confirm against kagglehub docs).
# NOTE(review): the loaders further down read hard-coded /kaggle/input/... paths
# and do not use this variable — pass it to them if running outside Kaggle.
osic_pulmonary_fibrosis_progression_path = kagglehub.competition_download(
    'osic-pulmonary-fibrosis-progression'
)

print('Data source import complete.')


Claude AI — tabular data only

In [None]:
# SECTION 1: IMPORTS AND CONFIGURATION
import os
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings('ignore')

def set_seed(seed=42):
    """Seed every RNG the pipeline touches and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG, PyTorch's CPU RNG and all
    CUDA device RNGs with ``seed``; disables cuDNN autotuning.
    """
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

class Config:
    """Hyper-parameters for the tabular regression pipeline (used via ``cfg``)."""

    # --- reproducibility -------------------------------------------------
    SEED = 42

    # --- data splits (fractions of the full dataset) ---------------------
    TRAIN_SPLIT = 0.7
    VAL_SPLIT = 0.15
    TEST_SPLIT = 0.15
    BATCH_SIZE = 64

    # --- model -------------------------------------------------------------
    MODEL_ARCHITECTURE = "ResNet"  # Options: "ResNet", "DenseNet", "EfficientNet"
    HIDDEN_DIMS = [512, 256, 128, 64]
    USE_BATCH_NORM = True
    DROPOUT_RATE = 0.3

    # --- optimisation ------------------------------------------------------
    LEARNING_RATE = 1e-3
    WEIGHT_DECAY = 1e-4
    NUM_EPOCHS = 100

    # --- early stopping ----------------------------------------------------
    PATIENCE = 15
    MIN_DELTA = 0.001

# Global configuration instance read by the functions and classes below.
cfg = Config()
set_seed(cfg.SEED)  # seed all RNGs before any data or model is created

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

# SECTION 2: DATA LOADING AND PREPROCESSING

def load_and_prepare_data(csv_path='/kaggle/input/your-dataset/train.csv', target_col=None):
    """
    Load a tabular dataset and split it into train/val/test partitions.

    Args:
        csv_path: CSV file to read. The default is a placeholder path; point it
            at your real file. If the file is missing, unreadable or empty, a
            synthetic regression dataset is generated instead.
        target_col: name of the target column. ``None`` keeps the original
            heuristic: ``'target'`` if present, otherwise the last column.

    Returns:
        X_train, X_val, X_test (DataFrames) and y_train, y_val, y_test
        (ndarrays). Validation/test sizes come from cfg.VAL_SPLIT and
        cfg.TEST_SPLIT; the remainder is used for training.

    Raises:
        KeyError: an explicit ``target_col`` is not a column of the data.
        ValueError: the chosen target column is not numeric.
    """
    try:
        df = pd.read_csv(csv_path)
        print(f"Loaded dataset with shape: {df.shape}")
    except (OSError, pd.errors.EmptyDataError) as err:
        # Only "file unavailable / empty" falls back to demo data. A bare
        # `except:` also swallowed KeyboardInterrupt and real parse errors,
        # silently training on synthetic data instead.
        print(f"Could not read {csv_path!r}: {err}")
        print("Creating synthetic dataset for demonstration...")
        np.random.seed(cfg.SEED)
        n_samples = 2000
        n_features = 20

        # Create correlated features with some noise
        X = np.random.randn(n_samples, n_features)
        # Target with linear, interaction and quadratic terms plus small noise
        y = (3 * X[:, 0] + 2 * X[:, 1] - 1.5 * X[:, 2] +
             0.5 * X[:, 0] * X[:, 1] + 0.3 * X[:, 2]**2 +
             np.random.randn(n_samples) * 0.1)

        feature_names = [f'feature_{i}' for i in range(n_features)]
        df = pd.DataFrame(X, columns=feature_names)
        df['target'] = y

    # Resolve the target column
    if target_col is None:
        target_col = 'target' if 'target' in df.columns else df.columns[-1]
    elif target_col not in df.columns:
        raise KeyError(f"Target column {target_col!r} not found in columns {list(df.columns)}")

    # Fail early with a clear message (e.g. OSIC's last column is SmokingStatus)
    # instead of an opaque error from y.mean() further down.
    if not pd.api.types.is_numeric_dtype(df[target_col]):
        raise ValueError(
            f"Target column {target_col!r} is not numeric (dtype {df[target_col].dtype}); "
            "pass target_col explicitly, e.g. 'FVC' for OSIC data."
        )

    X = df.drop(columns=[target_col])
    y = df[target_col].values

    print(f"Features shape: {X.shape}")
    print(f"Target shape: {y.shape}")
    print(f"Target statistics: mean={y.mean():.4f}, std={y.std():.4f}")

    # First carve off val+test, then split that remainder into val and test.
    X_train, X_temp, y_train, y_temp = train_test_split(
        X, y, test_size=(cfg.VAL_SPLIT + cfg.TEST_SPLIT),
        random_state=cfg.SEED, shuffle=True
    )

    val_size = cfg.VAL_SPLIT / (cfg.VAL_SPLIT + cfg.TEST_SPLIT)
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp, test_size=(1 - val_size),
        random_state=cfg.SEED, shuffle=True
    )

    print(f"Split sizes - Train: {len(X_train)}, Val: {len(X_val)}, Test: {len(X_test)}")

    return X_train, X_val, X_test, y_train, y_val, y_test

def create_preprocessing_pipeline(X_train):
    """Build an unfitted preprocessing transformer for the columns of ``X_train``.

    Numeric columns are standardised. Object and pandas ``category`` columns
    are one-hot encoded; categories unseen at fit time encode as all zeros.
    Any other column is passed through unchanged. When no numeric or
    categorical column exists, a plain ``StandardScaler`` is returned.
    """
    from sklearn.preprocessing import OneHotEncoder

    numeric_features = X_train.select_dtypes(include=[np.number]).columns
    # 'category' must be listed explicitly: select_dtypes('object') does not
    # match it, and such columns would otherwise fall into the passthrough
    # remainder and reach torch.FloatTensor as non-numeric values.
    categorical_features = X_train.select_dtypes(include=['object', 'category']).columns

    preprocessors = []

    if len(numeric_features) > 0:
        numeric_transformer = Pipeline([
            ('scaler', StandardScaler())
        ])
        preprocessors.append(('num', numeric_transformer, numeric_features))

    if len(categorical_features) > 0:
        # sparse_output requires scikit-learn >= 1.2
        categorical_transformer = Pipeline([
            ('onehot', OneHotEncoder(handle_unknown='ignore', sparse_output=False))
        ])
        preprocessors.append(('cat', categorical_transformer, categorical_features))

    if not preprocessors:
        return StandardScaler()

    return ColumnTransformer(
        transformers=preprocessors,
        remainder='passthrough'
    )

# SECTION 3: ADVANCED MODEL ARCHITECTURES

class ResNetBlock(nn.Module):
    """Fully connected residual block.

    Computes ``relu(bn2(fc2(dropout(relu(bn1(fc1(x)))))) + shortcut(x))``,
    where ``shortcut`` is a Linear projection when ``in_dim != out_dim`` and
    the identity otherwise. Both BatchNorm layers become ``nn.Identity`` when
    ``cfg.USE_BATCH_NORM`` is False.
    """

    def __init__(self, in_dim, out_dim, dropout_rate=0.1):
        super().__init__()
        use_bn = cfg.USE_BATCH_NORM
        # Creation order is kept fixed so parameter initialisation consumes the RNG identically.
        self.fc1 = nn.Linear(in_dim, out_dim)
        self.fc2 = nn.Linear(out_dim, out_dim)
        self.bn1 = nn.BatchNorm1d(out_dim) if use_bn else nn.Identity()
        self.bn2 = nn.BatchNorm1d(out_dim) if use_bn else nn.Identity()
        self.dropout = nn.Dropout(dropout_rate)
        self.shortcut = nn.Identity() if in_dim == out_dim else nn.Linear(in_dim, out_dim)

    def forward(self, x):
        residual = self.shortcut(x)
        hidden = self.dropout(F.relu(self.bn1(self.fc1(x))))
        hidden = self.bn2(self.fc2(hidden))
        return F.relu(hidden + residual)

class DenseBlock(nn.Module):
    """DenseNet-style block for flat feature vectors.

    Each of the ``num_layers`` layers (BN -> ReLU -> Linear -> Dropout) reads
    the concatenation of the block input and every earlier layer's output and
    contributes ``growth_rate`` new features. The resulting width,
    ``in_dim + num_layers * growth_rate``, is exposed as ``output_dim``.
    """

    def __init__(self, in_dim, growth_rate=32, num_layers=3, dropout_rate=0.1):
        super().__init__()
        self.layers = nn.ModuleList()
        width = in_dim

        for _ in range(num_layers):
            self.layers.append(nn.Sequential(
                nn.BatchNorm1d(width) if cfg.USE_BATCH_NORM else nn.Identity(),
                nn.ReLU(),
                nn.Linear(width, growth_rate),
                nn.Dropout(dropout_rate),
            ))
            width += growth_rate

        self.output_dim = width

    def forward(self, x):
        # Grow the feature vector in place: [x] -> [x, f1] -> [x, f1, f2] ...
        stacked = x
        for layer in self.layers:
            stacked = torch.cat([stacked, layer(stacked)], dim=1)
        return stacked

class ResNetRegressor(nn.Module):
    """Stack of ``ResNetBlock``s ending in a single-output linear head.

    Input -> optional BatchNorm -> one ResNetBlock per entry of
    ``hidden_dims`` -> dropout -> Linear(.., 1). ``forward`` returns a 1-D
    tensor of shape ``(batch,)``.
    """

    def __init__(self, input_dim, hidden_dims, dropout_rate=0.3):
        super().__init__()
        self.input_bn = nn.BatchNorm1d(input_dim) if cfg.USE_BATCH_NORM else nn.Identity()

        layers = []
        prev_dim = input_dim

        for hidden_dim in hidden_dims:
            layers.append(ResNetBlock(prev_dim, hidden_dim, dropout_rate))
            prev_dim = hidden_dim

        self.layers = nn.Sequential(*layers)
        self.output_layer = nn.Linear(prev_dim, 1)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        x = self.input_bn(x)
        x = self.layers(x)
        x = self.dropout(x)
        # squeeze(-1), not squeeze(): a batch of one must stay shape (1,),
        # otherwise callers iterating over the numpy output get a 0-d array.
        return self.output_layer(x).squeeze(-1)

class DenseNetRegressor(nn.Module):
    """DenseNet-style MLP regressor with a single-output head.

    For every entry of ``hidden_dims`` except the last, a 3-layer
    ``DenseBlock`` (growth rate ``hidden_dim // 4``, at least 1) is followed by
    a transition (BN -> ReLU -> Linear -> Dropout) down to ``hidden_dims[-1]``
    features. ``forward`` returns a 1-D tensor of shape ``(batch,)``.
    """

    def __init__(self, input_dim, hidden_dims, dropout_rate=0.3):
        super().__init__()
        self.input_bn = nn.BatchNorm1d(input_dim) if cfg.USE_BATCH_NORM else nn.Identity()

        blocks = []
        current_dim = input_dim

        for hidden_dim in hidden_dims[:-1]:
            # Guard against a zero-width growth rate for hidden dims below 4.
            growth_rate = max(1, hidden_dim // 4)
            block = DenseBlock(current_dim, growth_rate=growth_rate, num_layers=3, dropout_rate=dropout_rate)
            blocks.append(block)
            current_dim = block.output_dim

            # Transition layer compresses the concatenated features.
            transition = nn.Sequential(
                nn.BatchNorm1d(current_dim) if cfg.USE_BATCH_NORM else nn.Identity(),
                nn.ReLU(),
                nn.Linear(current_dim, hidden_dims[-1]),
                nn.Dropout(dropout_rate)
            )
            blocks.append(transition)
            current_dim = hidden_dims[-1]

        self.features = nn.Sequential(*blocks)
        self.output_layer = nn.Linear(current_dim, 1)

    def forward(self, x):
        x = self.input_bn(x)
        x = self.features(x)
        # squeeze(-1) keeps the batch dimension for a batch of one.
        return self.output_layer(x).squeeze(-1)

class EfficientNetRegressor(nn.Module):
    """Plain MLP regressor with SiLU activations and progressive dropout.

    Despite the name, there is no convolution or attention here: each entry of
    ``hidden_dims`` becomes Linear -> (BatchNorm) -> SiLU -> Dropout, where the
    dropout rate for layer ``i`` is ``dropout_rate * (0.5 + 0.5 * i / len(hidden_dims))``.
    ``forward`` returns a 1-D tensor of shape ``(batch,)``.
    """

    def __init__(self, input_dim, hidden_dims, dropout_rate=0.3):
        super().__init__()
        self.input_bn = nn.BatchNorm1d(input_dim) if cfg.USE_BATCH_NORM else nn.Identity()

        layers = []
        prev_dim = input_dim

        for i, hidden_dim in enumerate(hidden_dims):
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim) if cfg.USE_BATCH_NORM else nn.Identity(),
                nn.SiLU(),  # Swish activation
                # Dropout grows from 0.5x to just under 1x dropout_rate with depth.
                nn.Dropout(dropout_rate * (0.5 + 0.5 * i / len(hidden_dims)))
            ])
            prev_dim = hidden_dim

        self.layers = nn.Sequential(*layers)
        self.output_layer = nn.Linear(prev_dim, 1)

    def forward(self, x):
        x = self.input_bn(x)
        x = self.layers(x)
        # squeeze(-1) keeps the batch dimension for a batch of one.
        return self.output_layer(x).squeeze(-1)

def create_model(model_name, input_dim):
    """Instantiate the regressor called ``model_name`` and prepare it for training.

    The model is built from ``cfg.HIDDEN_DIMS`` / ``cfg.DROPOUT_RATE``, moved
    to ``DEVICE``, and then every ``nn.Linear`` is re-initialised with
    Kaiming-normal weights (fan_out mode, ReLU gain) and zero biases.

    Raises:
        ValueError: if ``model_name`` is not one of the registered architectures.
    """
    registry = {
        "ResNet": ResNetRegressor,
        "DenseNet": DenseNetRegressor,
        "EfficientNet": EfficientNetRegressor,
    }
    if model_name in registry:
        model_cls = registry[model_name]
    else:
        raise ValueError(f"Unsupported model: {model_name}. Choose from {list(registry.keys())}")

    model = model_cls(
        input_dim=input_dim,
        hidden_dims=cfg.HIDDEN_DIMS,
        dropout_rate=cfg.DROPOUT_RATE,
    )
    model = model.to(DEVICE)

    def _kaiming_linear(module):
        # Only Linear layers are re-initialised; BatchNorm keeps its defaults.
        if not isinstance(module, nn.Linear):
            return
        nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)

    model.apply(_kaiming_linear)
    return model

# SECTION 4: TRAINING AND EVALUATION

def laplace_log_likelihood(actual, predicted, confidence=1.0, sigma_min=1e-6, delta_max=None):
    """Mean Laplace log likelihood of ``predicted`` against ``actual``.

    Per sample: ``-sqrt(2) * delta / sigma - ln(sqrt(2) * sigma)`` with
    ``sigma = max(confidence, sigma_min)`` and ``delta = |actual - predicted|``,
    optionally capped at ``delta_max``. Higher (closer to 0) is better.

    Args:
        actual, predicted: array-likes of equal length.
        confidence: scalar or per-sample array of predicted uncertainty (sigma).
        sigma_min: lower bound on sigma. The default only avoids division by
            zero; the OSIC competition metric uses 70.
        delta_max: optional cap on the absolute error (OSIC uses 1000);
            ``None`` leaves errors uncapped.

    Returns:
        The mean metric as a NumPy float.
    """
    actual = np.asarray(actual, dtype=float)
    predicted = np.asarray(predicted, dtype=float)
    sigma = np.maximum(np.asarray(confidence, dtype=float), sigma_min)
    delta = np.abs(actual - predicted)
    if delta_max is not None:
        delta = np.minimum(delta, delta_max)
    metric = -np.sqrt(2) * delta / sigma - np.log(np.sqrt(2) * sigma)
    return np.mean(metric)

class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    A call counts as an improvement when ``val_loss`` is lower than the best
    loss seen so far by more than ``min_delta``; the counter then resets and,
    if ``restore_best_weights`` is set, a CPU copy of the model's state_dict is
    kept. After ``patience`` consecutive non-improving calls, ``early_stop``
    becomes True and the saved weights (if any) are loaded back into the model.
    """

    def __init__(self, patience=10, min_delta=0.001, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        # Running state
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False
        self.best_weights = None

    def _record_best(self, val_loss, model):
        """Store ``val_loss`` as the new best and optionally snapshot the weights."""
        self.best_loss = val_loss
        self.counter = 0
        if self.restore_best_weights:
            self.best_weights = {name: tensor.cpu().clone()
                                 for name, tensor in model.state_dict().items()}

    def __call__(self, val_loss, model):
        if val_loss < self.best_loss - self.min_delta:
            self._record_best(val_loss, model)
            return

        self.counter += 1
        if self.counter < self.patience:
            return

        self.early_stop = True
        if self.restore_best_weights and self.best_weights:
            model.load_state_dict(self.best_weights)
        print(f"\nEarly stopping triggered after {self.counter} epochs without improvement.")

def train_epoch(model, train_loader, optimizer, criterion, max_grad_norm=1.0):
    """Run one optimisation pass over ``train_loader``.

    Args:
        max_grad_norm: gradient-norm clipping threshold applied before each
            optimizer step; ``None`` disables clipping.

    Returns:
        (avg_loss, r2): ``avg_loss`` is the per-sample mean of ``criterion``
        (assuming the criterion averages over its batch, as ``nn.MSELoss()``
        does), and ``r2`` is computed over every training prediction made
        during the epoch.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    loss_sum, n_seen = 0.0, 0
    all_preds, all_targets = [], []

    for data, target in train_loader:
        data, target = data.to(DEVICE), target.to(DEVICE)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        # Gradient clipping for stability
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)

        optimizer.step()

        # Weight by batch size so a short final batch does not skew the mean.
        batch_size = len(target)
        loss_sum += loss.item() * batch_size
        n_seen += batch_size
        # reshape(-1) keeps 1-sample batches iterable (a 0-d array is not).
        all_preds.extend(output.detach().reshape(-1).cpu().numpy())
        all_targets.extend(target.detach().reshape(-1).cpu().numpy())

    if n_seen == 0:
        raise ValueError("train_loader yielded no batches")

    avg_loss = loss_sum / n_seen
    r2 = r2_score(all_targets, all_preds)
    return avg_loss, r2

def validate_epoch(model, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Returns:
        (avg_loss, r2, lll): per-sample mean of ``criterion`` (assuming it
        averages over its batch), R² over all validation samples, and
        ``laplace_log_likelihood`` with the predictions' standard deviation as
        the confidence.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    model.eval()
    loss_sum, n_seen = 0.0, 0
    all_preds, all_targets = [], []

    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            output = model(data)
            loss = criterion(output, target)
            # Weight by batch size so a short final batch does not skew the mean.
            batch_size = len(target)
            loss_sum += loss.item() * batch_size
            n_seen += batch_size
            # reshape(-1) keeps 1-sample batches iterable (a 0-d array is not).
            all_preds.extend(output.reshape(-1).cpu().numpy())
            all_targets.extend(target.reshape(-1).cpu().numpy())

    if n_seen == 0:
        raise ValueError("val_loader yielded no batches")

    avg_loss = loss_sum / n_seen
    r2 = r2_score(all_targets, all_preds)
    # NOTE(review): the spread of the predictions is used as a single global
    # confidence; it is not a per-sample uncertainty estimate.
    lll = laplace_log_likelihood(all_targets, all_preds, confidence=np.std(all_preds))
    return avg_loss, r2, lll

def train_and_validate(model, train_loader, val_loader, optimizer, criterion, scheduler, early_stopper):
    """Train for up to ``cfg.NUM_EPOCHS`` epochs with LR scheduling and early stopping.

    ``scheduler`` (may be None) is stepped with the validation loss each epoch.
    ``early_stopper`` is called with the validation loss and the model after
    every epoch; training stops when it sets ``early_stop``.

    Returns:
        (model, (train_losses, val_losses, train_r2s, val_r2s)). When the
        stopper recorded best weights, the returned model carries them whether
        training ended by early stopping or by exhausting the epoch budget.
    """
    train_losses, val_losses, train_r2s, val_r2s = [], [], [], []

    for epoch in range(cfg.NUM_EPOCHS):
        # Training
        train_loss, train_r2 = train_epoch(model, train_loader, optimizer, criterion)
        train_losses.append(train_loss)
        train_r2s.append(train_r2)

        # Validation
        val_loss, val_r2, val_lll = validate_epoch(model, val_loader, criterion)
        val_losses.append(val_loss)
        val_r2s.append(val_r2)

        # Learning rate scheduling on the monitored metric
        if scheduler is not None:
            scheduler.step(val_loss)

        # Print every epoch for the first 10, then every 5th
        if (epoch + 1) % 5 == 0 or epoch < 10:
            current_lr = optimizer.param_groups[0]['lr']
            print(f"Epoch {epoch+1:3d}: Train Loss={train_loss:.4f} R²={train_r2:.4f} | "
                  f"Val Loss={val_loss:.4f} R²={val_r2:.4f} LLL={val_lll:.4f} | LR={current_lr:.2e}")

        # Early stopping
        early_stopper(val_loss, model)
        if early_stopper.early_stop:
            break

    # The stopper reloads its best snapshot only when patience runs out. If
    # the epoch budget ran out first, the model still holds the last epoch's
    # weights, so reload the best ones here.
    if (not early_stopper.early_stop
            and early_stopper.restore_best_weights
            and early_stopper.best_weights):
        model.load_state_dict(early_stopper.best_weights)
        print("\nRestored best validation weights after the final epoch.")

    return model, (train_losses, val_losses, train_r2s, val_r2s)

# SECTION 5: MAIN EXECUTION AND EVALUATION

def _to_loader(features, targets, shuffle):
    """Wrap a feature matrix and target vector (as float32) in a DataLoader of cfg.BATCH_SIZE."""
    dataset = TensorDataset(torch.FloatTensor(features), torch.FloatTensor(targets))
    return DataLoader(dataset, batch_size=cfg.BATCH_SIZE, shuffle=shuffle)


def _predict(model, loader):
    """Run ``model`` in eval mode over ``loader``; return all predictions as one flat array."""
    model.eval()
    chunks = []
    with torch.no_grad():
        for data, _ in loader:
            # reshape(-1) keeps a 1-sample batch as a 1-element array
            chunks.append(model(data.to(DEVICE)).reshape(-1).cpu().numpy())
    return np.concatenate(chunks)


def _plot_diagnostics(history, y_true, y_pred, r2):
    """Plot loss and R² curves, predicted-vs-actual and residuals in a 2x2 figure."""
    train_losses, val_losses, train_r2s, val_r2s = history
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # Loss curves
    ax = axes[0, 0]
    ax.plot(train_losses, label='Train Loss', alpha=0.8)
    ax.plot(val_losses, label='Validation Loss', alpha=0.8)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss (MSE)')
    ax.set_title('Training and Validation Loss')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # R² curves
    ax = axes[0, 1]
    ax.plot(train_r2s, label='Train R²', alpha=0.8)
    ax.plot(val_r2s, label='Validation R²', alpha=0.8)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('R² Score')
    ax.set_title('R² Score Progress')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Predictions vs actuals, with the identity line for reference
    ax = axes[1, 0]
    ax.scatter(y_true, y_pred, alpha=0.6, s=20)
    lo = min(y_true.min(), y_pred.min())
    hi = max(y_true.max(), y_pred.max())
    ax.plot([lo, hi], [lo, hi], 'r--', alpha=0.8)
    ax.set_xlabel('Actual Values')
    ax.set_ylabel('Predicted Values')
    ax.set_title(f'Predictions vs Actuals (R² = {r2:.4f})')
    ax.grid(True, alpha=0.3)

    # Residual plot
    ax = axes[1, 1]
    residuals = y_true - y_pred
    ax.scatter(y_pred, residuals, alpha=0.6, s=20)
    ax.axhline(y=0, color='r', linestyle='--', alpha=0.8)
    ax.set_xlabel('Predicted Values')
    ax.set_ylabel('Residuals')
    ax.set_title('Residual Plot')
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


def main():
    """Run the tabular regression pipeline end to end.

    Loads and splits the data, fits preprocessing on the training split only,
    standardises the target, and trains ``cfg.MODEL_ARCHITECTURE`` with AdamW,
    ReduceLROnPlateau and early stopping. It then reports test-set
    MAE/MSE/R²/LLL on the original target scale and draws diagnostic plots.

    Returns:
        (model, preprocessing_pipeline, target_scaler)
    """
    print("="*60)
    print("ROBUST ML PIPELINE - STARTING EXECUTION")
    print("="*60)

    # Load and prepare data
    print("\n1. Loading and preparing data...")
    X_train, X_val, X_test, y_train, y_val, y_test = load_and_prepare_data()

    # Preprocessing is fitted on training data only, to avoid leakage
    print("\n2. Creating preprocessing pipeline...")
    preprocessing_pipeline = create_preprocessing_pipeline(X_train)
    X_train_processed = preprocessing_pipeline.fit_transform(X_train)
    X_val_processed = preprocessing_pipeline.transform(X_val)
    X_test_processed = preprocessing_pipeline.transform(X_test)

    # Standardise targets for better convergence (metrics are reported after
    # inverting this transform). StandardScaler comes from the cell's imports.
    target_scaler = StandardScaler()
    y_train_scaled = target_scaler.fit_transform(y_train.reshape(-1, 1)).ravel()
    y_val_scaled = target_scaler.transform(y_val.reshape(-1, 1)).ravel()
    y_test_scaled = target_scaler.transform(y_test.reshape(-1, 1)).ravel()

    print(f"Processed feature dimension: {X_train_processed.shape[1]}")

    train_loader = _to_loader(X_train_processed, y_train_scaled, shuffle=True)
    val_loader = _to_loader(X_val_processed, y_val_scaled, shuffle=False)
    test_loader = _to_loader(X_test_processed, y_test_scaled, shuffle=False)

    # Create model
    print(f"\n3. Creating {cfg.MODEL_ARCHITECTURE} model...")
    model = create_model(cfg.MODEL_ARCHITECTURE, X_train_processed.shape[1])
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Training components. `verbose` is deprecated on PyTorch LR schedulers;
    # the current LR is already printed by train_and_validate.
    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=cfg.LEARNING_RATE,
        weight_decay=cfg.WEIGHT_DECAY
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=cfg.PATIENCE // 2
    )
    early_stopper = EarlyStopping(patience=cfg.PATIENCE, min_delta=cfg.MIN_DELTA)

    # Train model
    print(f"\n4. Training model for up to {cfg.NUM_EPOCHS} epochs...")
    print("-" * 80)

    model, history = train_and_validate(
        model, train_loader, val_loader, optimizer, criterion, scheduler, early_stopper
    )

    # Final evaluation on test set
    print("\n" + "="*60)
    print("FINAL EVALUATION ON TEST SET")
    print("="*60)

    test_preds_scaled = _predict(model, test_loader)
    test_preds_original = target_scaler.inverse_transform(test_preds_scaled.reshape(-1, 1)).ravel()
    # test_loader is unshuffled, so y_test is already aligned with the
    # predictions; using it directly avoids a lossy float32 round trip.
    test_targets_original = np.asarray(y_test, dtype=np.float64).ravel()

    # Metrics on the original target scale
    mae = mean_absolute_error(test_targets_original, test_preds_original)
    mse = mean_squared_error(test_targets_original, test_preds_original)
    r2 = r2_score(test_targets_original, test_preds_original)
    lll = laplace_log_likelihood(test_targets_original, test_preds_original,
                                 confidence=np.std(test_preds_original))

    print("\nFinal Test Metrics:")
    print(f"MAE:  {mae:.4f}")
    print(f"MSE:  {mse:.4f}")
    print(f"R²:   {r2:.4f}")
    print(f"LLL:  {lll:.4f}")

    print("\n5. Creating diagnostic plots...")
    _plot_diagnostics(history, test_targets_original, test_preds_original, r2)

    print("\n" + "="*60)
    print("PIPELINE EXECUTION COMPLETED SUCCESSFULLY!")
    print("="*60)

    return model, preprocessing_pipeline, target_scaler

# Execute the pipeline
if __name__ == "__main__":
    # Run the tabular pipeline and keep its artifacts as module-level names
    # for interactive inspection in later cells.
    (model,
     preprocessing_pipeline,
     target_scaler) = main()

Includes CT scans

In [None]:
# =============================================================================
# COMPLETE MULTI-MODAL OSIC PULMONARY FIBROSIS MODEL
# Integrating Tabular Data + CT DICOM Images
# Professional Implementation with 10+ Years ML Experience
# =============================================================================

# SECTION 1: IMPORTS AND CONFIGURATION
import os
import gc
import random
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Union
from dataclasses import dataclass
import logging
from tqdm.auto import tqdm

# Deep Learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.cuda.amp import autocast, GradScaler

# Medical Imaging
import pydicom
from pydicom.pixel_data_handlers.util import apply_modality_lut, apply_voi_lut
import cv2

# Vision Models
import torchvision.models as models
from torchvision.models import efficientnet_b3, EfficientNet_B3_Weights

# ML Tools
from sklearn.model_selection import GroupKFold, train_test_split
from sklearn.preprocessing import StandardScaler, RobustScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline

# Augmentation
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Configure logging
# Root logging: INFO and above, each line prefixed with time and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)  # shared by the code in this cell

# =============================================================================
# CONFIGURATION
# =============================================================================

from dataclasses import dataclass, field
from typing import List, Tuple

# A stray `@dataclass` line used to sit directly above these imports. A
# decorator must precede a def/class, so it was a SyntaxError that prevented
# this whole cell from running.


@dataclass
class Config:
    """Centralized configuration for early detection pipeline.

    Groups paths, image preprocessing, model, training, class-balance,
    augmentation and runtime settings in one dataclass; the module-level
    ``config`` instance below is the one read by the rest of this cell.
    """
    # Paths
    # NOTE(review): load_osic_data expects OSIC train.csv columns, but this is
    # not the OSIC competition directory — confirm the intended data source.
    data_dir: Path = Path("/kaggle/input/lung-fibrosis-detection")
    output_dir: Path = Path("/kaggle/working")

    # Data Processing
    img_size: int = 224
    n_slices: int = 5
    window_center: int = -600   # HU window centre
    window_width: int = 1500    # HU window width

    # Classification Classes
    classes: List[str] = field(default_factory=lambda: ["Normal", "Early_Fibrosis", "Advanced_Fibrosis"])
    n_classes: int = 3  # keep equal to len(classes)

    # Model Architecture
    model_type: str = "MultiModal"
    backbone: str = "efficientnet"
    tabular_hidden_dims: Tuple[int, ...] = (512, 256, 128, 64)
    fusion_method: str = "attention"
    dropout_rate: float = 0.4
    use_batch_norm: bool = True

    # Training
    n_folds: int = 5
    batch_size: int = 32
    num_epochs: int = 100
    learning_rate: float = 1e-4
    weight_decay: float = 1e-4
    scheduler: str = 'cosine'
    patience: int = 15
    min_delta: float = 0.001
    gradient_clip: float = 1.0

    # Class Balance
    use_class_weights: bool = True
    focal_loss_alpha: float = 0.25
    focal_loss_gamma: float = 2.0

    # Augmentation
    use_augmentation: bool = True
    aug_prob: float = 0.7
    use_mixup: bool = False

    # Advanced
    use_mixed_precision: bool = True
    num_workers: int = 2
    pin_memory: bool = True
    seed: int = 42

    # Validation
    val_split: float = 0.15
    test_split: float = 0.15


config = Config()

# =============================================================================
# REPRODUCIBILITY
# =============================================================================

def set_seed(seed: int = 42):
    """Make runs repeatable.

    Exports PYTHONHASHSEED, seeds Python ``random``, NumPy's global RNG,
    PyTorch's CPU RNG and all CUDA device RNGs, and forces deterministic cuDNN.
    """
    # Per the Python docs, PYTHONHASHSEED is read at interpreter start-up, so
    # it only affects processes launched after this point, not this one.
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(config.seed)

# Use a CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {DEVICE}")
if torch.cuda.is_available():
    logger.info(f"GPU: {torch.cuda.get_device_name(0)}")

# =============================================================================
# SECTION 2: DATA LOADING AND PREPROCESSING
# =============================================================================

def load_osic_data():
    """Load the OSIC ``train.csv`` from ``config.data_dir``, or synthesize a stand-in.

    Returns:
        DataFrame containing at least the columns Patient, Weeks, FVC, Percent,
        Age, Sex and SmokingStatus.

    If the CSV cannot be read (missing file, I/O error, empty file), a
    synthetic OSIC-like dataset is generated instead so the rest of the
    pipeline can run as a demo.

    Raises:
        ValueError: if the CSV is read but lacks required columns. This used to
            fall back to synthetic data silently, hiding malformed real data.
    """
    required_cols = ['Patient', 'Weeks', 'FVC', 'Percent', 'Age', 'Sex', 'SmokingStatus']
    csv_path = config.data_dir / "train.csv"

    try:
        train_df = pd.read_csv(csv_path)
    except (OSError, pd.errors.EmptyDataError) as e:
        logger.warning(f"Could not load OSIC data: {e}")
        logger.info("Creating synthetic data for demonstration...")
        return _make_synthetic_osic_data()

    logger.info(f"Loaded OSIC dataset with shape: {train_df.shape}")

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    missing = [col for col in required_cols if col not in train_df.columns]
    if missing:
        raise ValueError(f"{csv_path} is missing required columns: {missing}")

    return train_df


def _make_synthetic_osic_data(n_patients: int = 100) -> pd.DataFrame:
    """Generate OSIC-like longitudinal FVC records, seeded from ``config.seed``."""
    np.random.seed(config.seed)
    records = []

    for patient_id in range(n_patients):
        n_visits = np.random.randint(3, 10)
        baseline_fvc = np.random.normal(2500, 700)
        age = np.random.randint(50, 85)
        sex = np.random.choice(['Male', 'Female'])
        smoking = np.random.choice(['Never smoked', 'Ex-smoker', 'Currently smokes'])

        week = 0
        for visit in range(n_visits):
            if visit > 0:
                # Strictly increasing follow-up weeks, 1-5 weeks apart. The
                # previous `visit * randint(1, 6)` could go backwards or repeat.
                week += np.random.randint(1, 6)
            # Simulate FVC decline
            fvc = baseline_fvc - np.random.normal(5, 2) * week
            fvc = max(fvc, 500)  # Minimum FVC

            records.append({
                'Patient': f'ID{patient_id:04d}',
                'Weeks': week,
                'FVC': fvc,
                'Percent': np.random.normal(50, 15),
                'Age': age,
                'Sex': sex,
                'SmokingStatus': smoking
            })

    return pd.DataFrame(records)

# =============================================================================
# SECTION 3: MEDICAL IMAGE PROCESSING
# =============================================================================

class DicomProcessor:
    """Load CT DICOM slices and map them to [0, 1] through a fixed HU window.

    With the defaults (center -600, width 1500) the window spans
    [-1350, 150] HU; values outside it are clipped to its bounds.
    """

    def __init__(self, window_center: int = -600, window_width: int = 1500):
        self.window_center = window_center
        self.window_width = window_width

    def _window_bounds(self) -> tuple:
        """Return the (low, high) bounds of the configured window."""
        half_width = self.window_width // 2
        return self.window_center - half_width, self.window_center + half_width

    def load_dicom(self, path: str) -> Optional[np.ndarray]:
        """Read ``path`` and return a float32 image in [0, 1], or None on any failure.

        Failures (unreadable file, unsupported pixel encoding, ...) are logged
        at DEBUG level and reported as None, so callers can skip bad slices.
        """
        try:
            dcm = pydicom.dcmread(path)
            img = dcm.pixel_array.astype(np.float32)

            # Modality LUT: stored values -> Hounsfield units
            if hasattr(dcm, 'RescaleSlope') and hasattr(dcm, 'RescaleIntercept'):
                img = img * float(dcm.RescaleSlope) + float(dcm.RescaleIntercept)

            img = self.apply_windowing(img)
            # Map the *window* (not this slice's own min/max) onto [0, 1], so a
            # given HU value gets the same intensity in every slice. Per-image
            # min-max scaling stretched slices that don't span the window.
            return self._to_unit_range(img)

        except Exception as e:
            logger.debug(f"Failed to load DICOM {path}: {e}")
            return None

    def apply_windowing(self, img: np.ndarray) -> np.ndarray:
        """Clip ``img`` to the lung window bounds."""
        min_val, max_val = self._window_bounds()
        return np.clip(img, min_val, max_val)

    def _to_unit_range(self, img: np.ndarray) -> np.ndarray:
        """Linearly map the window bounds to [0, 1] and return float32."""
        lo, hi = self._window_bounds()
        span = max(hi - lo, 1e-6)  # guard a degenerate zero-width window
        return ((img - lo) / span).astype(np.float32)

# =============================================================================
# SECTION 4: FEATURE ENGINEERING
# =============================================================================

class TabularFeatureEngineer:
    """Per-patient tabular feature engineering for OSIC-style records.

    A per-patient summary is computed once per patient in ``train_df``: the
    baseline (earliest-week) visit values plus an FVC trend. ``get_features``
    turns a summary and a query week into a fixed-length, robust-scaled
    float32 vector whose entries follow ``FEATURE_NAMES``.

    NOTE(review): FVCSlope, FVCStdDev and NumMeasurements use *all* of a
    patient's visits, including ones that may later serve as targets, so these
    features (and ExpectedFVC in particular) can leak target information.
    """

    # Order of the entries produced by ``_encode_features``.
    FEATURE_NAMES = (
        'Age', 'SexMale', 'SmokeNever', 'SmokeEx', 'SmokeCurrent',
        'BaselineFVC', 'BaselinePercent', 'WeekDelta', 'WeekDeltaSquared',
        'FVCSlope', 'FVCStdDev', 'AgeXWeek', 'ExpectedFVC', 'NumMeasurements',
    )

    def __init__(self, train_df: pd.DataFrame):
        self.train_df = train_df
        self.scaler = RobustScaler()
        self.patient_features = {}
        self._prepare_features()

    @property
    def n_features(self) -> int:
        """Length of the vectors returned by ``get_features``."""
        return len(self.FEATURE_NAMES)

    def _prepare_features(self):
        """Summarise every patient, then fit the scaler."""
        # One groupby pass instead of re-filtering the whole frame per patient;
        # sort=False keeps first-appearance order, like .unique().
        for patient, patient_data in self.train_df.groupby('Patient', sort=False):
            # Stable sort so the baseline row is deterministic on tied weeks.
            patient_data = patient_data.sort_values('Weeks', kind='mergesort')
            baseline = patient_data.iloc[0]

            if len(patient_data) > 1:
                weeks = patient_data['Weeks'].to_numpy(dtype=float)
                fvc = patient_data['FVC'].to_numpy(dtype=float)
                # Least-squares slope of FVC over weeks; undefined (set to 0)
                # when every visit falls in the same week.
                slope = np.polyfit(weeks, fvc, 1)[0] if weeks.max() > weeks.min() else 0.0
                std_dev = np.std(fvc)
            else:
                slope = 0
                std_dev = 0

            self.patient_features[patient] = {
                'Age': baseline['Age'],
                'Sex': baseline['Sex'],
                'SmokingStatus': baseline['SmokingStatus'],
                'BaselineFVC': baseline['FVC'],
                'BaselinePercent': baseline['Percent'],
                'BaselineWeeks': baseline['Weeks'],
                'FVCSlope': slope,
                'FVCStdDev': std_dev,
                'NumMeasurements': len(patient_data)
            }

        self._fit_scaler()

    def _fit_scaler(self):
        """Fit the scaler on each patient's features encoded at week 0."""
        features = [self._encode_features(stats, 0) for stats in self.patient_features.values()]
        if not features:
            raise ValueError("train_df contains no patients to fit the scaler on")
        self.scaler.fit(np.stack(features))

    def _encode_features(self, stats: Dict, current_week: float) -> np.ndarray:
        """Encode one patient's summary at ``current_week`` (see FEATURE_NAMES)."""
        # One-hot encode categoricals
        sex_male = 1 if stats['Sex'] == 'Male' else 0

        smoke_never = 1 if stats['SmokingStatus'] == 'Never smoked' else 0
        smoke_ex = 1 if stats['SmokingStatus'] == 'Ex-smoker' else 0
        smoke_current = 1 if stats['SmokingStatus'] == 'Currently smokes' else 0

        # Time relative to the baseline visit
        week_delta = current_week - stats['BaselineWeeks']
        week_squared = week_delta ** 2

        # Interaction feature
        age_week = stats['Age'] * week_delta / 100

        # Linear extrapolation of FVC from the baseline
        expected_fvc = stats['BaselineFVC'] + stats['FVCSlope'] * week_delta

        # Fixed divisors bring raw values to roughly unit scale before the RobustScaler.
        features = [
            stats['Age'] / 100,
            sex_male,
            smoke_never,
            smoke_ex,
            smoke_current,
            stats['BaselineFVC'] / 5000,
            stats['BaselinePercent'] / 100,
            week_delta / 52,
            week_squared / (52 ** 2),
            stats['FVCSlope'] / 100,
            stats['FVCStdDev'] / 1000,
            age_week,
            expected_fvc / 5000,
            stats['NumMeasurements'] / 10
        ]

        return np.array(features, dtype=np.float32)

    def get_features(self, patient_id: str, week: float) -> np.ndarray:
        """Return the scaled float32 feature vector for ``patient_id`` at ``week``.

        Unknown patients get an all-zero vector of length ``n_features``.
        """
        if patient_id not in self.patient_features:
            return np.zeros(self.n_features, dtype=np.float32)

        features = self._encode_features(self.patient_features[patient_id], week)
        # Cast so known and unknown patients share one dtype (the scaler returns float64).
        return self.scaler.transform(features[np.newaxis, :])[0].astype(np.float32)

# =============================================================================
# SECTION 5: MULTI-MODAL DATASET
# =============================================================================

class OSICMultiModalDataset(Dataset):
    """Dataset pairing per-visit tabular features with a 3-channel CT image.

    One sample is created per row of ``patient_data``. The columns used are
    ``Patient``, ``Weeks`` and ``FVC``, and ``FVC`` is the target. When
    ``img_dir`` exists, up to ``n_slices`` DICOM files per patient are chosen
    at evenly spaced positions along the (numerically ordered) scan. Three of
    those, again evenly spaced, are stacked as the image channels. Patients
    without usable slices get an all-zero image.

    Args:
        patient_data: visit-level rows.
        feature_engineer: supplies scaled tabular features per (patient, week).
        img_dir: directory holding one ``<patient>/*.dcm`` folder per patient.
        transform: albumentations pipeline whose ``'image'`` output is used
            as-is. When None, the HWC array is transposed to a CHW float tensor.
        n_slices: number of candidate slices kept per patient.
        is_train: stored on the instance; not used inside this class.
    """

    def __init__(
        self,
        patient_data: pd.DataFrame,
        feature_engineer: TabularFeatureEngineer,
        img_dir: Optional[Path] = None,
        transform: Optional[A.Compose] = None,
        n_slices: int = 3,
        is_train: bool = True
    ):
        self.patient_data = patient_data
        self.feature_engineer = feature_engineer
        self.img_dir = img_dir
        self.transform = transform
        self.n_slices = n_slices
        self.is_train = is_train
        self.dicom_processor = DicomProcessor()

        # patient ID -> list of selected slice paths, or None when none exist
        self.patient_images = {}
        self.has_images = img_dir is not None and img_dir.exists()
        if self.has_images:
            self._load_patient_images()

        # Column-wise zip is much faster than DataFrame.iterrows().
        self.samples = [
            {'patient': patient, 'week': week, 'fvc': fvc}
            for patient, week, fvc in zip(
                patient_data['Patient'], patient_data['Weeks'], patient_data['FVC']
            )
        ]

        logger.info(f"Dataset created with {len(self.samples)} samples")

    @staticmethod
    def _dicom_sort_key(path: Path):
        """Order numerically named slices by number (``2.dcm`` before ``10.dcm``).

        Plain ``sorted`` on paths is lexicographic (``1, 10, 100, 11, ...``),
        which scrambles slice order. Non-numeric names go after the numeric
        ones, ordered by file name.
        """
        stem = path.stem
        return (0, int(stem), '') if stem.isdecimal() else (1, 0, path.name)

    def _load_patient_images(self):
        """Select up to ``n_slices`` evenly spaced DICOM paths for every patient."""
        for patient in self.patient_data['Patient'].unique():
            patient_dir = self.img_dir / patient
            dicom_files = (
                sorted(patient_dir.glob("*.dcm"), key=self._dicom_sort_key)
                if patient_dir.exists()
                else []
            )
            if not dicom_files:
                self.patient_images[patient] = None
                continue

            n_files = len(dicom_files)
            if n_files >= self.n_slices:
                indices = np.linspace(0, n_files - 1, self.n_slices, dtype=int)
            else:
                indices = range(n_files)
            self.patient_images[patient] = [dicom_files[i] for i in indices]

    def _build_image(self, patient) -> np.ndarray:
        """Return an HWC, 3-channel image for ``patient`` (all zeros if no slice loads)."""
        size = config.img_size
        slices = self.patient_images.get(patient) if self.has_images else None

        images = []
        if slices:
            # Spread the three channels over all selected slices instead of
            # taking the first three, which come from one end of the scan
            # whenever n_slices > 3.
            if len(slices) > 3:
                picks = np.linspace(0, len(slices) - 1, 3, dtype=int)
                slices = [slices[i] for i in picks]
            for slice_path in slices:
                img = self.dicom_processor.load_dicom(str(slice_path))
                if img is not None:
                    images.append(cv2.resize(img, (size, size)))

        if not images:
            return np.zeros((size, size, 3), dtype=np.float32)

        # Pad to three channels by repeating the last loaded slice (a -> aaa, ab -> abb).
        while len(images) < 3:
            images.append(images[-1])
        return np.stack(images[:3], axis=-1)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        patient = sample['patient']

        tabular = self.feature_engineer.get_features(patient, sample['week'])
        img = self._build_image(patient)

        if self.transform:
            img = self.transform(image=img)['image']
        else:
            img = torch.from_numpy(img.transpose(2, 0, 1)).float()

        return {
            'image': img,
            'tabular': torch.tensor(tabular, dtype=torch.float32),
            'target': torch.tensor(sample['fvc'], dtype=torch.float32)
        }

# =============================================================================
# SECTION 6: AUGMENTATION
# =============================================================================

def get_transforms(is_train: bool = True, img_size: int = 224) -> A.Compose:
    """Return the albumentations pipeline for training (``is_train``) or evaluation.

    Every pipeline begins by resizing to ``img_size`` x ``img_size``. It ends
    with ImageNet mean/std normalisation followed by ``ToTensorV2``. The
    training pipeline inserts these steps in between:
      - random 90-degree rotations;
      - horizontal flips;
      - a mild shift/scale/rotate;
      - one of Gaussian noise, blur or brightness-contrast.
    """
    steps = [A.Resize(img_size, img_size)]
    if is_train:
        steps += [
            A.RandomRotate90(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=10, p=0.5),
            A.OneOf(
                [
                    A.GaussNoise(var_limit=(10, 50)),
                    A.GaussianBlur(blur_limit=3),
                    A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1),
                ],
                p=0.5,
            ),
        ]
    steps += [
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return A.Compose(steps)

# =============================================================================
# SECTION 7: NEURAL NETWORK ARCHITECTURES
# =============================================================================

# Tabular Network Components
class ResNetBlock(nn.Module):
    """Two-layer residual MLP block for tabular features.

    out = relu(bn2(fc2(dropout(relu(bn1(fc1(x)))))) + shortcut(x))

    ``shortcut`` is a Linear projection when ``in_dim != out_dim`` and the
    identity otherwise. The batch norms become identities when
    ``config.use_batch_norm`` is false.
    """
    def __init__(self, in_dim, out_dim, dropout_rate=0.1):
        super().__init__()
        use_bn = config.use_batch_norm
        self.fc1 = nn.Linear(in_dim, out_dim)
        self.fc2 = nn.Linear(out_dim, out_dim)
        self.bn1 = nn.BatchNorm1d(out_dim) if use_bn else nn.Identity()
        self.bn2 = nn.BatchNorm1d(out_dim) if use_bn else nn.Identity()
        self.dropout = nn.Dropout(dropout_rate)
        self.shortcut = nn.Identity() if in_dim == out_dim else nn.Linear(in_dim, out_dim)

    def forward(self, x):
        hidden = self.dropout(F.relu(self.bn1(self.fc1(x))))
        residual = self.bn2(self.fc2(hidden))
        return F.relu(residual + self.shortcut(x))

# Attention Mechanisms
class CrossModalAttention(nn.Module):
    """Fuse one image vector and one tabular vector via self-attention over the pair.

    The forward pass works as follows:
      1. Project both inputs to ``hidden_dim``.
      2. Stack them into a length-2 sequence.
      3. Run 8-head ``nn.MultiheadAttention`` with query = key = value.
      4. Average the two output tokens.

    Args:
        img_dim: width of the image feature vector.
        tab_dim: width of the tabular feature vector.
        hidden_dim: shared projection width; must be divisible by 8.
    """
    def __init__(self, img_dim, tab_dim, hidden_dim=256):
        super().__init__()
        self.img_proj = nn.Linear(img_dim, hidden_dim)
        self.tab_proj = nn.Linear(tab_dim, hidden_dim)
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True)

    def forward(self, img_feat, tab_feat):
        tokens = torch.stack(
            (self.img_proj(img_feat), self.tab_proj(tab_feat)), dim=1
        )  # [B, 2, hidden]
        attended, _ = self.attention(tokens, tokens, tokens)
        return attended.mean(dim=1)  # [B, hidden]

# Main Multi-Modal Model
class OSICMultiModalModel(nn.Module):
    """Image + tabular regression model with a learned log-variance head.

    Args:
        tabular_dim: width of the tabular feature vector.
        hidden_dims: widths of the stacked ``ResNetBlock`` layers in the
            tabular branch. When empty, the branch is the identity.
        dropout_rate: dropout used in the tabular blocks and the regression head.
        backbone: "efficientnet" (B3), "resnet50", or any other value for DenseNet-121.
        fusion_method: one of:
            - "attention": ``CrossModalAttention``;
            - "gated": softmax-gated branches, concatenated;
            - any other value: plain concatenation.

    ``forward(image, tabular)`` returns ``(prediction, log_variance)``.
    """

    def __init__(
        self,
        tabular_dim: int = 14,
        hidden_dims: Tuple[int, ...] = (512, 256, 128, 64),
        dropout_rate: float = 0.3,
        backbone: str = "efficientnet",
        fusion_method: str = "attention"
    ):
        super().__init__()

        # Image backbone: ImageNet-pretrained, with its classifier replaced by
        # Identity so it outputs the penultimate feature vector.
        if backbone == "efficientnet":
            self.image_backbone = efficientnet_b3(weights=EfficientNet_B3_Weights.DEFAULT)
            img_feat_dim = self.image_backbone.classifier[1].in_features
            self.image_backbone.classifier = nn.Identity()
        elif backbone == "resnet50":
            # IMAGENET1K_V1 is exactly what the deprecated ``pretrained=True`` loaded.
            self.image_backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
            img_feat_dim = self.image_backbone.fc.in_features
            self.image_backbone.fc = nn.Identity()
        else:
            # Any other backbone name falls back to DenseNet-121.
            self.image_backbone = models.densenet121(
                weights=models.DenseNet121_Weights.IMAGENET1K_V1
            )
            img_feat_dim = self.image_backbone.classifier.in_features
            self.image_backbone.classifier = nn.Identity()

        # Tabular network: a chain of residual MLP blocks.
        tab_layers = []
        prev_dim = tabular_dim
        for hidden_dim in hidden_dims:
            tab_layers.append(ResNetBlock(prev_dim, hidden_dim, dropout_rate))
            prev_dim = hidden_dim
        self.tabular_net = nn.Sequential(*tab_layers)
        # Width of the last block, or of the raw input when hidden_dims is empty.
        tab_feat_dim = prev_dim

        # Fusion method
        self.fusion_method = fusion_method
        if fusion_method == "attention":
            attn_dim = 256
            self.fusion = CrossModalAttention(img_feat_dim, tab_feat_dim, hidden_dim=attn_dim)
            fusion_dim = attn_dim
        elif fusion_method == "gated":
            # Two softmax weights scale the image and tabular branches; forward()
            # concatenates the scaled vectors, hence the summed width.
            self.gate = nn.Sequential(
                nn.Linear(img_feat_dim + tab_feat_dim, 256),
                nn.ReLU(),
                nn.Linear(256, 2),
                nn.Softmax(dim=1)
            )
            fusion_dim = img_feat_dim + tab_feat_dim
        else:  # concat
            fusion_dim = img_feat_dim + tab_feat_dim

        # Final regression head
        self.head = nn.Sequential(
            nn.Linear(fusion_dim, 256),
            nn.BatchNorm1d(256) if config.use_batch_norm else nn.Identity(),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128) if config.use_batch_norm else nn.Identity(),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, 1)
        )

        # Uncertainty head: predicts a log-variance per sample.
        self.uncertainty_head = nn.Sequential(
            nn.Linear(fusion_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal init (fan_out, ReLU gain) with zero bias for the new Linear layers.

        Modules inside ``image_backbone`` are skipped, so pretrained weights
        are never overwritten.
        """
        backbone_ids = {id(m) for m in self.image_backbone.modules()}
        for m in self.modules():
            if isinstance(m, nn.Linear) and id(m) not in backbone_ids:
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, image, tabular):
        """Return ``(prediction, log_variance)``, each squeezed to shape [B]."""
        img_features = self.image_backbone(image)
        tab_features = self.tabular_net(tabular)

        if self.fusion_method == "attention":
            combined = self.fusion(img_features, tab_features)
        elif self.fusion_method == "gated":
            gates = self.gate(torch.cat([img_features, tab_features], dim=1))
            # The branches have different widths, so they cannot be summed;
            # scale each by its gate and concatenate (matches fusion_dim).
            combined = torch.cat(
                [gates[:, 0:1] * img_features, gates[:, 1:2] * tab_features], dim=1
            )
        else:
            combined = torch.cat([img_features, tab_features], dim=1)

        prediction = self.head(combined).squeeze(-1)
        log_variance = self.uncertainty_head(combined).squeeze(-1)

        return prediction, log_variance

# =============================================================================
# SECTION 8: LOSS FUNCTIONS AND METRICS
# =============================================================================

class RobustLoss(nn.Module):
    """Weighted sum of a Gaussian heteroscedastic NLL term and an L1 term.

    With ``s`` the predicted log-variance, ``mu`` the prediction and ``y`` the target:

        loss = alpha * 0.5 * mean(exp(-s) * (mu - y)**2 + s) + beta * mean(|mu - y|)
    """
    def __init__(self, alpha: float = 0.7, beta: float = 0.3):
        super().__init__()
        self.alpha, self.beta = alpha, beta
        self.mse = nn.MSELoss()  # kept as an attribute; forward() does not use it
        self.mae = nn.L1Loss()

    def forward(self, pred_mean, pred_log_var, target):
        squared_error = (pred_mean - target) ** 2
        nll = 0.5 * torch.mean(torch.exp(-pred_log_var) * squared_error + pred_log_var)
        return self.alpha * nll + self.beta * self.mae(pred_mean, target)

def laplace_log_likelihood(y_true, y_pred, sigma):
    """Return the mean OSIC-style Laplace log likelihood (higher is better).

    ``sigma`` is floored at 70 and the absolute error capped at 1000, then

        score_i = -sqrt(2) * delta_i / sigma_i - ln(sqrt(2) * sigma_i)
    """
    sqrt2 = np.sqrt(2)
    sigma_clipped = np.maximum(sigma, 70)  # OSIC specific
    abs_err = np.minimum(np.abs(y_true - y_pred), 1000)
    per_sample = -(sqrt2 * abs_err / sigma_clipped) - np.log(sqrt2 * sigma_clipped)
    return np.mean(per_sample)

# =============================================================================
# SECTION 9: TRAINING PIPELINE
# =============================================================================

class EarlyStopping:
    """Stop after ``patience`` consecutive calls without a ``min_delta`` improvement.

    Higher scores are better. Each improvement snapshots the model's
    ``state_dict`` to CPU. When patience runs out, ``early_stop`` becomes True
    and the snapshot (if any) is loaded back into the model.
    """
    def __init__(self, patience=10, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.best_score = -float('inf')
        self.counter = 0
        self.early_stop = False
        self.best_weights = None

    def __call__(self, score, model):
        if score > self.best_score + self.min_delta:
            self.best_score = score
            self.counter = 0
            self.best_weights = {
                name: tensor.cpu().clone() for name, tensor in model.state_dict().items()
            }
            return

        self.counter += 1
        if self.counter < self.patience:
            return

        self.early_stop = True
        if self.best_weights:
            model.load_state_dict(self.best_weights)
        logger.info(f"Early stopping triggered after {self.counter} epochs")

def train_epoch(model, loader, optimizer, criterion, scaler, config):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    The forward pass and loss run under ``autocast`` when
    ``config.use_mixed_precision`` is set. How the step is taken depends on
    ``scaler``:
      - When truthy (a ``GradScaler`` in ``main``), the loss is scaled for
        backward and gradients are unscaled before being clipped to
        ``config.gradient_clip``.
      - Otherwise, a plain backward / clip / step is used.
    """
    model.train()
    batch_losses = []

    for batch in tqdm(loader, desc='Training'):
        images, tabular, targets = (
            batch[key].to(DEVICE) for key in ('image', 'tabular', 'target')
        )
        optimizer.zero_grad()

        with autocast(enabled=config.use_mixed_precision):
            predictions, log_var = model(images, tabular)
            loss = criterion(predictions, log_var, targets)

        if not scaler:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.gradient_clip)
            optimizer.step()
        else:
            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient norms.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.gradient_clip)
            scaler.step(optimizer)
            scaler.update()

        batch_losses.append(loss.item())

    return np.mean(batch_losses)

@torch.no_grad()
def validate_epoch(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns:
        ``(mean batch loss, metrics)``, where ``metrics`` has keys 'mae',
        'rmse', 'r2' and 'lll'. The Laplace log likelihood ('lll') uses
        sigma = exp(0.5 * log_var) from the model's uncertainty output.
    """
    model.eval()
    batch_losses = []
    collected = {'pred': [], 'target': [], 'sigma': []}

    for batch in tqdm(loader, desc='Validation'):
        images = batch['image'].to(DEVICE)
        tabular = batch['tabular'].to(DEVICE)
        targets = batch['target'].to(DEVICE)

        predictions, log_var = model(images, tabular)
        batch_losses.append(criterion(predictions, log_var, targets).item())

        collected['pred'].extend(predictions.cpu().numpy())
        collected['target'].extend(targets.cpu().numpy())
        collected['sigma'].extend(torch.exp(0.5 * log_var).cpu().numpy())

    y_pred, y_true, sigma = (np.array(collected[k]) for k in ('pred', 'target', 'sigma'))

    metrics = {
        'mae': mean_absolute_error(y_true, y_pred),
        'rmse': np.sqrt(mean_squared_error(y_true, y_pred)),
        'r2': r2_score(y_true, y_pred),
        'lll': laplace_log_likelihood(y_true, y_pred, sigma),
    }
    return np.mean(batch_losses), metrics

# =============================================================================
# SECTION 10: MAIN EXECUTION
# =============================================================================

def _split_patients(train_df):
    """Split unique patient IDs into (train, val, test) using the config split fractions.

    Splitting by patient keeps every visit of a patient inside a single split.
    """
    patients = train_df['Patient'].unique()
    holdout = config.val_split + config.test_split
    train_patients, temp_patients = train_test_split(
        patients,
        test_size=holdout,
        random_state=config.seed
    )
    val_patients, test_patients = train_test_split(
        temp_patients,
        test_size=config.test_split / holdout,
        random_state=config.seed
    )
    return train_patients, val_patients, test_patients


def _make_loader(dataset, shuffle, drop_last=False):
    """Wrap ``dataset`` in a DataLoader using the batch/worker settings from ``config``."""
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory,
        drop_last=drop_last
    )


def main():
    """Run the end-to-end OSIC FVC regression pipeline.

    Steps:
      1. Load the data.
      2. Build the tabular feature engineer.
      3. Split patients into train/val/test.
      4. Train with early stopping on validation R².
      5. Evaluate the weights with the best validation R² (the ones saved to
         ``best_model.pth``) on the test split.
      6. Draw diagnostic plots.

    Returns:
        ``(model, feature_engineer, train_history)``. ``model`` holds the
        best-validation weights.
    """
    print("="*80)
    print("OSIC MULTI-MODAL PIPELINE - COMPLETE IMPLEMENTATION")
    print("="*80)

    # Create output directory
    config.output_dir.mkdir(exist_ok=True, parents=True)

    # 1. Load and prepare data
    logger.info("Loading OSIC data...")
    train_df = load_osic_data()

    # 2. Create feature engineer
    # NOTE(review): the feature engineer (and its scaler) is built from ALL
    # patients, including those later assigned to val/test. If per-patient
    # statistics such as FVCSlope are computed from every FVC measurement,
    # validation/test targets leak into the inputs. Confirm in
    # TabularFeatureEngineer before trusting the reported metrics.
    logger.info("Creating feature engineer...")
    feature_engineer = TabularFeatureEngineer(train_df)

    # 3. Patient-level splits (avoiding data leakage across splits)
    train_patients, val_patients, test_patients = _split_patients(train_df)

    train_data = train_df[train_df['Patient'].isin(train_patients)]
    val_data = train_df[train_df['Patient'].isin(val_patients)]
    test_data = train_df[train_df['Patient'].isin(test_patients)]

    logger.info(f"Train samples: {len(train_data)} | Val: {len(val_data)} | Test: {len(test_data)}")

    # 4. Create transforms
    train_transform = get_transforms(is_train=True, img_size=config.img_size)
    val_transform = get_transforms(is_train=False, img_size=config.img_size)

    # 5. Create datasets (without a train/ folder the image branch receives zeros)
    img_dir = config.data_dir / "train" if (config.data_dir / "train").exists() else None

    train_dataset = OSICMultiModalDataset(
        train_data, feature_engineer, img_dir, train_transform, config.n_slices, is_train=True
    )
    val_dataset = OSICMultiModalDataset(
        val_data, feature_engineer, img_dir, val_transform, config.n_slices, is_train=False
    )
    test_dataset = OSICMultiModalDataset(
        test_data, feature_engineer, img_dir, val_transform, config.n_slices, is_train=False
    )

    # 6. Create data loaders. drop_last on training presumably avoids a
    # trailing size-1 batch, which BatchNorm1d rejects in training mode.
    train_loader = _make_loader(train_dataset, shuffle=True, drop_last=True)
    val_loader = _make_loader(val_dataset, shuffle=False)
    test_loader = _make_loader(test_dataset, shuffle=False)

    # 7. Create model
    logger.info(f"Creating {config.backbone} model...")
    model = OSICMultiModalModel(
        tabular_dim=14,  # Based on feature engineer
        hidden_dims=config.tabular_hidden_dims,
        dropout_rate=config.dropout_rate,
        backbone=config.backbone,
        fusion_method=config.fusion_method
    ).to(DEVICE)

    # Model summary
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Total parameters: {total_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,}")

    # 8. Training setup
    criterion = RobustLoss()
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )

    if config.scheduler == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=10, T_mult=2, eta_min=1e-6
        )
    else:
        # mode='max' because the monitored value is validation R². The
        # deprecated ``verbose`` flag is omitted; the LR is logged every epoch.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=0.5, patience=5
        )

    early_stopper = EarlyStopping(patience=config.patience, min_delta=config.min_delta)
    scaler = GradScaler() if config.use_mixed_precision else None

    # 9. Training loop
    logger.info("Starting training...")
    best_score = -float('inf')
    best_state = None  # CPU copy of the weights written to best_model.pth
    train_history = {'train_loss': [], 'val_loss': [], 'val_r2': [], 'val_lll': []}

    for epoch in range(config.num_epochs):
        logger.info(f"Epoch {epoch + 1}/{config.num_epochs}")

        # Train
        train_loss = train_epoch(model, train_loader, optimizer, criterion, scaler, config)

        # Validate
        val_loss, val_metrics = validate_epoch(model, val_loader, criterion)

        # Update scheduler
        if config.scheduler == 'cosine':
            scheduler.step()
        else:
            scheduler.step(val_metrics['r2'])
        current_lr = optimizer.param_groups[0]['lr']

        # Log metrics
        logger.info(
            f"Train Loss: {train_loss:.4f} | "
            f"Val Loss: {val_loss:.4f} | "
            f"R2: {val_metrics['r2']:.4f} | "
            f"MAE: {val_metrics['mae']:.4f} | "
            f"LLL: {val_metrics['lll']:.4f} | "
            f"LR: {current_lr:.2e}"
        )

        # Save history
        train_history['train_loss'].append(train_loss)
        train_history['val_loss'].append(val_loss)
        train_history['val_r2'].append(val_metrics['r2'])
        train_history['val_lll'].append(val_metrics['lll'])

        # Early stopping
        current_score = val_metrics['r2']
        early_stopper(current_score, model)

        if current_score > best_score:
            best_score = current_score
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_score': best_score,
                'config': config
            }, config.output_dir / 'best_model.pth')
            logger.info(f"New best model saved with R2: {best_score:.4f}")

        if early_stopper.early_stop:
            break

        # Memory cleanup
        torch.cuda.empty_cache()
        gc.collect()

    # Evaluate the checkpointed best weights, not whatever the last epoch left.
    # Training can finish without early stopping firing, and EarlyStopping's
    # own snapshot ignores improvements smaller than min_delta.
    if best_state is not None:
        model.load_state_dict(best_state)
        logger.info(f"Restored best weights (val R2: {best_score:.4f}) for test evaluation")

    # 10. Final evaluation on test set
    logger.info("Evaluating on test set...")
    test_loss, test_metrics = validate_epoch(model, test_loader, criterion)

    print("\n" + "="*80)
    print("FINAL TEST SET RESULTS")
    print("="*80)
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test MAE: {test_metrics['mae']:.4f}")
    print(f"Test RMSE: {test_metrics['rmse']:.4f}")
    print(f"Test R²: {test_metrics['r2']:.4f}")
    print(f"Test LLL: {test_metrics['lll']:.4f}")

    # 11. Create visualizations
    logger.info("Creating visualizations...")
    create_training_plots(train_history)

    # 12. Prediction diagnostics (best-effort: failures are logged, not raised)
    try:
        analyze_model_predictions(model, test_loader, test_data, feature_engineer)
    except Exception as e:
        logger.warning(f"Could not create prediction analysis: {e}")

    print("\n" + "="*80)
    print("PIPELINE COMPLETED SUCCESSFULLY!")
    print("="*80)

    return model, feature_engineer, train_history

def create_training_plots(history):
    """Plot loss, R² and LLL curves plus a text summary, then save and show the figure.

    ``history`` must hold non-empty lists under 'train_loss', 'val_loss',
    'val_r2' and 'val_lll'. The figure is written to
    ``config.output_dir / 'training_curves.png'``.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    loss_ax, r2_ax, lll_ax, summary_ax = axes[0, 0], axes[0, 1], axes[1, 0], axes[1, 1]

    loss_ax.plot(history['train_loss'], label='Train Loss', alpha=0.7)
    loss_ax.plot(history['val_loss'], label='Validation Loss', alpha=0.7)
    r2_ax.plot(history['val_r2'], label='Validation R²', color='green', alpha=0.7)
    lll_ax.plot(history['val_lll'], label='Validation LLL', color='red', alpha=0.7)

    curve_panels = (
        (loss_ax, 'Loss', 'Training and Validation Loss'),
        (r2_ax, 'R² Score', 'R² Score Progression'),
        (lll_ax, 'Laplace Log Likelihood', 'Laplace Log Likelihood Progression'),
    )
    for ax, ylabel, title in curve_panels:
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    summary_lines = (
        (0.8, f"Best Validation R²: {max(history['val_r2']):.4f}"),
        (0.6, f"Best Validation LLL: {max(history['val_lll']):.4f}"),
        (0.4, f"Final Train Loss: {history['train_loss'][-1]:.4f}"),
        (0.2, f"Final Val Loss: {history['val_loss'][-1]:.4f}"),
    )
    for y_pos, text in summary_lines:
        summary_ax.text(0.1, y_pos, text, transform=summary_ax.transAxes, fontsize=12)
    summary_ax.set_title('Training Summary')
    summary_ax.axis('off')

    plt.tight_layout()
    plt.savefig(config.output_dir / 'training_curves.png', dpi=300, bbox_inches='tight')
    plt.show()

@torch.no_grad()
def analyze_model_predictions(model, loader, data, feature_engineer):
    """Plot and print prediction diagnostics for ``model`` over every batch of ``loader``.

    Draws four panels:
      - predicted vs. true FVC;
      - residuals;
      - predicted uncertainty (exp(0.5 * log_var)) vs. absolute error;
      - the residual histogram.

    Saves the figure to ``config.output_dir / 'prediction_analysis.png'`` and
    prints summary statistics. ``data`` and ``feature_engineer`` are accepted
    but not used.
    """
    model.eval()

    pred_chunks, target_chunks, sigma_chunks = [], [], []
    for batch in tqdm(loader, desc='Analyzing predictions'):
        pred_mean, log_var = model(batch['image'].to(DEVICE), batch['tabular'].to(DEVICE))
        pred_chunks.extend(pred_mean.cpu().numpy())
        target_chunks.extend(batch['target'].cpu().numpy())
        sigma_chunks.extend(torch.exp(0.5 * log_var).cpu().numpy())

    predictions = np.array(pred_chunks)
    targets = np.array(target_chunks)
    uncertainties = np.array(sigma_chunks)
    residuals = targets - predictions
    abs_errors = np.abs(residuals)

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    ax_fit, ax_resid, ax_unc, ax_hist = axes.ravel()

    # Predictions vs Targets, with the identity line for reference
    ax_fit.scatter(targets, predictions, alpha=0.6, s=20)
    lo = min(targets.min(), predictions.min())
    hi = max(targets.max(), predictions.max())
    ax_fit.plot([lo, hi], [lo, hi], 'r--', alpha=0.8)
    ax_fit.set_xlabel('True FVC')
    ax_fit.set_ylabel('Predicted FVC')
    ax_fit.set_title('Predictions vs True Values')
    ax_fit.grid(True, alpha=0.3)

    # Residual plot
    ax_resid.scatter(predictions, residuals, alpha=0.6, s=20)
    ax_resid.axhline(y=0, color='r', linestyle='--', alpha=0.8)
    ax_resid.set_xlabel('Predicted FVC')
    ax_resid.set_ylabel('Residuals')
    ax_resid.set_title('Residual Plot')
    ax_resid.grid(True, alpha=0.3)

    # Uncertainty vs Error
    ax_unc.scatter(uncertainties, abs_errors, alpha=0.6, s=20)
    ax_unc.set_xlabel('Predicted Uncertainty')
    ax_unc.set_ylabel('Absolute Error')
    ax_unc.set_title('Uncertainty vs Error')
    ax_unc.grid(True, alpha=0.3)

    # Error distribution
    ax_hist.hist(residuals, bins=30, alpha=0.7, density=True)
    ax_hist.set_xlabel('Residuals')
    ax_hist.set_ylabel('Density')
    ax_hist.set_title('Residual Distribution')
    ax_hist.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(config.output_dir / 'prediction_analysis.png', dpi=300, bbox_inches='tight')
    plt.show()

    summary = (
        ("Mean Absolute Error", np.mean(abs_errors)),
        ("Root Mean Square Error", np.sqrt(np.mean(residuals**2))),
        ("R² Score", r2_score(targets, predictions)),
        ("Mean Uncertainty", np.mean(uncertainties)),
        ("Std Uncertainty", np.std(uncertainties)),
    )
    print("\nPrediction Analysis:")
    for label, value in summary:
        print(f"{label}: {value:.4f}")

# Execute the complete pipeline. In a notebook __name__ is "__main__", so this
# cell runs the pipeline and binds model / feature_engineer / history globally.
if __name__ == "__main__":
    try:
        model, feature_engineer, history = main()
    except Exception as e:
        # logger.exception records the full traceback, not just the message;
        # a bare ``raise`` re-raises the original exception unchanged.
        logger.exception(f"Pipeline failed with error: {e}")
        raise
    else:
        logger.info("Pipeline completed successfully!")


# Disease Progression Prediction
The code above implements the disease-progression (FVC regression) pipeline.

# =============================================================================
# EARLY DETECTION OF PULMONARY FIBROSIS - CLASSIFICATION PIPELINE
# Multi-Modal Deep Learning for Screening & Diagnosis
# Professional Implementation for Medical Imaging
# =============================================================================

In [None]:
# =============================================================================
# EARLY DETECTION OF PULMONARY FIBROSIS - CLASSIFICATION PIPELINE
# Multi-Modal Deep Learning for Screening & Diagnosis
# Professional Implementation for Medical Imaging
# =============================================================================

# SECTION 1: IMPORTS AND CONFIGURATION
import os
import gc
import random
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Union
from dataclasses import dataclass
import logging
from tqdm.auto import tqdm

# Deep Learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.cuda.amp import autocast, GradScaler

# Medical Imaging
import pydicom
from pydicom.pixel_data_handlers.util import apply_modality_lut, apply_voi_lut
import cv2

# Vision Models
import torchvision.models as models
from torchvision.models import efficientnet_b3, EfficientNet_B3_Weights

# ML Tools
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.preprocessing import StandardScaler, RobustScaler, LabelEncoder
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, confusion_matrix, classification_report
)
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.utils.class_weight import compute_class_weight

# Augmentation
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Configure logging: INFO level with timestamped records, accessed through a
# single module-level logger.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# =============================================================================
# CONFIGURATION FOR EARLY DETECTION
# =============================================================================

from dataclasses import dataclass, field
from typing import List, Tuple

@dataclass
class Config:
    """Centralized configuration for early detection pipeline.

    ``__post_init__`` coerces the path fields to ``Path`` and rejects
    inconsistent class/split settings with ``ValueError``.
    """
    # Paths
    data_dir: Path = Path("/kaggle/input/lung-fibrosis-detection")
    output_dir: Path = Path("/kaggle/working")

    # Data Processing
    img_size: int = 224
    n_slices: int = 5
    window_center: int = -600
    window_width: int = 1500

    # Classification Classes
    classes: List[str] = field(default_factory=lambda: ["Normal", "Early_Fibrosis", "Advanced_Fibrosis"])
    n_classes: int = 3  # must equal len(classes); checked in __post_init__

    # Model Architecture
    model_type: str = "MultiModal"
    backbone: str = "efficientnet"
    tabular_hidden_dims: Tuple[int, ...] = (512, 256, 128, 64)
    fusion_method: str = "attention"
    dropout_rate: float = 0.4
    use_batch_norm: bool = True

    # Training
    n_folds: int = 5
    batch_size: int = 32
    num_epochs: int = 100
    learning_rate: float = 1e-4
    weight_decay: float = 1e-4
    scheduler: str = 'cosine'
    patience: int = 15
    min_delta: float = 0.001
    gradient_clip: float = 1.0

    # Class Balance
    use_class_weights: bool = True
    focal_loss_alpha: float = 0.25
    focal_loss_gamma: float = 2.0

    # Augmentation
    use_augmentation: bool = True
    aug_prob: float = 0.7
    use_mixup: bool = False

    # Advanced
    use_mixed_precision: bool = True
    num_workers: int = 2
    pin_memory: bool = True
    seed: int = 42

    # Validation (fractions of the data held out; the remainder trains)
    val_split: float = 0.15
    test_split: float = 0.15

    def __post_init__(self):
        """Normalise path fields and validate interdependent settings.

        Raises:
            ValueError: in any of these cases:
                - ``n_classes`` differs from ``len(classes)``;
                - a split fraction is negative;
                - ``val_split + test_split`` leaves nothing for training.
        """
        # Accept plain strings too; path arithmetic with ``/`` needs Path objects.
        self.data_dir = Path(self.data_dir)
        self.output_dir = Path(self.output_dir)

        if self.n_classes != len(self.classes):
            raise ValueError(
                f"n_classes={self.n_classes} does not match len(classes)={len(self.classes)}"
            )
        if self.val_split < 0 or self.test_split < 0:
            raise ValueError("val_split and test_split must be non-negative")
        if self.val_split + self.test_split >= 1:
            raise ValueError("val_split + test_split must be < 1 to leave data for training")


config = Config()

# =============================================================================
# REPRODUCIBILITY
# =============================================================================

def set_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices); force deterministic cuDNN.

    Exporting ``PYTHONHASHSEED`` only affects interpreters started afterwards
    (e.g. subprocesses). The running interpreter's hash randomisation is
    fixed at startup.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(config.seed)

# Device configuration: use CUDA when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {DEVICE}")

# =============================================================================
# SECTION 2: SYNTHETIC DATA GENERATION FOR DEMONSTRATION
# =============================================================================

def create_synthetic_detection_data():
    """Generate a synthetic screening cohort for the three-class detection demo.

    Re-seeds NumPy's global RNG with ``config.seed`` and then draws, in this order,
    800 'Normal', 150 'Early_Fibrosis' and 100 'Advanced_Fibrosis' patients. Each
    cohort has its own distributions for demographics, lung function (FVC, DLCO),
    environmental exposure and symptoms.

    Returns:
        pd.DataFrame: one row per patient with columns Patient_ID ('P00000', ...), Age,
        Sex, Smoking_Status, FVC, DLCO, Dust_Exposure, Family_History, Dyspnea_Score,
        Cough, Weight_Loss and Diagnosis.
    """
    np.random.seed(config.seed)

    # Per-cohort generative parameters. Every patient performs the same sequence of
    # draws, so this table (plus the seed) fully determines the output.
    cohort_specs = [
        {   # Healthy controls: age 20-70, good lung function, lower exposure risk.
            'diagnosis': 'Normal', 'size': 800,
            'age_range': (20, 70), 'sex_p': None,
            'smoking_p': [0.6, 0.3, 0.1],
            'fvc': (3500, 500), 'dlco': (85, 15),
            'dust_levels': ['None', 'Low', 'Moderate'], 'dust_p': [0.7, 0.2, 0.1],
            'family_p': [0.9, 0.1],
            'dyspnea_bounds': (0, 2),   # randint's upper bound is exclusive -> 0-1
            'cough_p': [0.8, 0.2],
            'weight_loss_p': None,      # weight loss is fixed at 0 for this cohort
        },
        {   # Early (subclinical) fibrosis: age 40-75, mildly reduced function.
            'diagnosis': 'Early_Fibrosis', 'size': 150,
            'age_range': (40, 75), 'sex_p': [0.6, 0.4],
            'smoking_p': [0.4, 0.5, 0.1],
            'fvc': (3000, 400), 'dlco': (70, 12),
            'dust_levels': ['None', 'Low', 'Moderate', 'High'], 'dust_p': [0.3, 0.3, 0.3, 0.1],
            'family_p': [0.7, 0.3],
            'dyspnea_bounds': (1, 4),   # -> 1-3
            'cough_p': [0.4, 0.6],
            'weight_loss_p': [0.8, 0.2],
        },
        {   # Advanced fibrosis: age 50-80, significantly reduced function.
            'diagnosis': 'Advanced_Fibrosis', 'size': 100,
            'age_range': (50, 80), 'sex_p': [0.7, 0.3],
            'smoking_p': [0.3, 0.6, 0.1],
            'fvc': (2200, 400), 'dlco': (50, 15),
            'dust_levels': ['Low', 'Moderate', 'High'], 'dust_p': [0.2, 0.4, 0.4],
            'family_p': [0.5, 0.5],
            'dyspnea_bounds': (3, 5),   # -> 3-4
            'cough_p': [0.2, 0.8],
            'weight_loss_p': [0.4, 0.6],
        },
    ]

    records = []
    for spec in cohort_specs:
        for _ in range(spec['size']):
            age = np.random.uniform(*spec['age_range'])
            sex = np.random.choice(['Male', 'Female'], p=spec['sex_p'])
            smoking = np.random.choice(['Never', 'Former', 'Current'], p=spec['smoking_p'])
            fvc = np.random.normal(*spec['fvc'])
            dlco = np.random.normal(*spec['dlco'])  # diffusion capacity
            dust_exposure = np.random.choice(spec['dust_levels'], p=spec['dust_p'])
            family_history = np.random.choice([0, 1], p=spec['family_p'])
            dyspnea = np.random.randint(*spec['dyspnea_bounds'])
            cough = np.random.choice([0, 1], p=spec['cough_p'])
            weight_loss_p = spec['weight_loss_p']
            if weight_loss_p is None:
                weight_loss = 0
            else:
                weight_loss = np.random.choice([0, 1], p=weight_loss_p)

            records.append({
                'Patient_ID': f'P{len(records):05d}',
                'Age': age,
                'Sex': sex,
                'Smoking_Status': smoking,
                'FVC': fvc,
                'DLCO': dlco,
                'Dust_Exposure': dust_exposure,
                'Family_History': family_history,
                'Dyspnea_Score': dyspnea,
                'Cough': cough,
                'Weight_Loss': weight_loss,
                'Diagnosis': spec['diagnosis'],
            })

    df = pd.DataFrame(records)
    logger.info(f"Created synthetic dataset with {len(df)} patients")
    logger.info(f"Class distribution:\n{df['Diagnosis'].value_counts()}")

    return df

# =============================================================================
# SECTION 3: MEDICAL IMAGE PROCESSING (ENHANCED FOR DETECTION)
# =============================================================================

class DicomProcessor:
    """Load CT DICOM slices and turn them into contrast-enhanced float images in [0, 1].

    Pipeline per slice: rescale stored values (RescaleSlope/RescaleIntercept), clip to a
    lung window, apply CLAHE, then min-max normalise to [0, 1].
    """

    def __init__(self, window_center: int = -600, window_width: int = 1500):
        # Window parameters in the units produced by the rescale step (HU for CT).
        self.window_center = window_center
        self.window_width = window_width

    def load_dicom(self, path: str) -> Optional[np.ndarray]:
        """Read one DICOM file and return a preprocessed float32 image in [0, 1].

        Any failure (unreadable file, missing pixel data, CLAHE rejecting the array) is
        logged at debug level and reported as ``None`` so callers can skip the slice.
        """
        try:
            dcm = pydicom.dcmread(path)
            img = dcm.pixel_array.astype(np.float32)

            # Apply the modality rescale when both tags are present.
            if hasattr(dcm, 'RescaleSlope') and hasattr(dcm, 'RescaleIntercept'):
                img = img * dcm.RescaleSlope + dcm.RescaleIntercept

            img = self.apply_windowing(img)
            img = self.enhance_fibrosis_patterns(img)

            # Normalize to [0, 1]; the epsilon keeps a constant image from dividing by zero.
            img = (img - img.min()) / (img.max() - img.min() + 1e-6)

            return img.astype(np.float32)

        except Exception as e:
            # Deliberate best-effort: one bad slice must not abort a whole batch.
            logger.debug(f"Failed to load DICOM {path}: {e}")
            return None

    def apply_windowing(self, img: np.ndarray) -> np.ndarray:
        """Clip ``img`` to [center - width//2, center + width//2] (the lung window)."""
        min_val = self.window_center - self.window_width // 2
        max_val = self.window_center + self.window_width // 2
        return np.clip(img, min_val, max_val)

    def enhance_fibrosis_patterns(self, img: np.ndarray) -> np.ndarray:
        """Apply CLAHE (clip limit 3.0, 8x8 tiles) and return the result as float32 in 0..255."""
        clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
        enhanced = clahe.apply(self._to_uint8(img)).astype(np.float32)
        return enhanced

    @staticmethod
    def _to_uint8(img: np.ndarray) -> np.ndarray:
        """Min-max rescale ``img`` to a 0..255 uint8 array; a constant image maps to zeros."""
        lo = img.min()
        span = img.max() - lo
        if not span > 0:
            # Constant (or NaN-containing) slice, e.g. fully clipped by the window:
            # the unguarded division would be 0/0 and cast NaN to uint8.
            return np.zeros(img.shape, dtype=np.uint8)
        return ((img - lo) / span * 255).astype(np.uint8)

# =============================================================================
# SECTION 4: FEATURE ENGINEERING FOR DETECTION
# =============================================================================

class DetectionFeatureEngineer:
    """Hand-crafted, standardised tabular features for fibrosis detection.

    The constructor fits a ``StandardScaler`` on the engineered features of ``train_df``;
    ``get_features`` applies the same engineering and scaling to one patient record.

    Expected input columns: Age, Sex, Smoking_Status, FVC, DLCO, Dust_Exposure,
    Family_History, Dyspnea_Score, Cough, Weight_Loss.
    """

    # Column order of the engineered feature matrix built by _extract_numerical_features.
    FEATURE_NAMES = (
        'age_norm', 'age_risk', 'sex_male',
        'smoke_never', 'smoke_former', 'smoke_current',
        'fvc_norm', 'fvc_reduced', 'dlco_norm', 'dlco_reduced',
        'dust_none', 'dust_low', 'dust_moderate', 'dust_high',
        'family_history', 'dyspnea_score', 'cough', 'weight_loss',
        'smoking_risk', 'exposure_risk', 'symptom_score',
        'age_smoking', 'age_exposure', 'male_exposure',
    )

    def __init__(self, train_df: pd.DataFrame):
        self.train_df = train_df
        self.scaler = StandardScaler()
        # One LabelEncoder per categorical column. They are fitted but not used by
        # _extract_numerical_features, which one-hot encodes these columns itself.
        self.label_encoders = {}
        self._prepare_features()

    @property
    def n_features(self) -> int:
        """Width of the feature vector returned by ``get_features``."""
        return len(self.FEATURE_NAMES)

    def _prepare_features(self):
        """Fit the categorical encoders and the feature scaler on ``self.train_df``."""
        categorical_features = ['Sex', 'Smoking_Status', 'Dust_Exposure']
        for feature in categorical_features:
            le = LabelEncoder()
            le.fit(self.train_df[feature])
            self.label_encoders[feature] = le

        numerical_features = self._extract_numerical_features(self.train_df)
        self.scaler.fit(numerical_features)

    def _extract_numerical_features(self, df: pd.DataFrame) -> np.ndarray:
        """Engineer the unscaled feature matrix for every row of ``df``.

        Column-wise (vectorised) rather than a per-row Python loop: the same arithmetic
        in the same order, computed once per column in float64 and cast to float32.

        Returns:
            float32 array of shape ``(len(df), 24)``, columns ordered as ``FEATURE_NAMES``.
        """
        def numeric(column):
            return df[column].to_numpy(dtype=np.float64)

        def is_value(column, value):
            return (df[column] == value).to_numpy(dtype=np.float64)

        age = numeric('Age')
        fvc = numeric('FVC')
        dlco = numeric('DLCO')

        # Demographics (age-related risk, male predominance in IPF)
        age_norm = age / 100.0
        age_risk = (age > 60).astype(np.float64)  # higher risk after 60
        sex_male = is_value('Sex', 'Male')

        # Smoking status (one-hot)
        smoke_never = is_value('Smoking_Status', 'Never')
        smoke_former = is_value('Smoking_Status', 'Former')
        smoke_current = is_value('Smoking_Status', 'Current')

        # Lung function (key diagnostic markers) with fixed-scale normalisation
        fvc_norm = fvc / 5000.0
        fvc_reduced = (fvc < 2500).astype(np.float64)
        dlco_norm = dlco / 100.0
        dlco_reduced = (dlco < 60).astype(np.float64)

        # Environmental/occupational exposure (one-hot)
        dust_none = is_value('Dust_Exposure', 'None')
        dust_low = is_value('Dust_Exposure', 'Low')
        dust_moderate = is_value('Dust_Exposure', 'Moderate')
        dust_high = is_value('Dust_Exposure', 'High')

        # Family history and symptoms
        family_history = numeric('Family_History')
        dyspnea_score = numeric('Dyspnea_Score') / 4.0  # normalise the 0-4 scale
        cough = numeric('Cough')
        weight_loss = numeric('Weight_Loss')

        # Composite risk scores
        smoking_risk = smoke_former * 0.5 + smoke_current * 1.0
        exposure_risk = dust_low * 0.25 + dust_moderate * 0.5 + dust_high * 1.0
        symptom_score = dyspnea_score + cough * 0.3 + weight_loss * 0.4

        # Interaction features
        age_smoking = age_norm * smoking_risk
        age_exposure = age_norm * exposure_risk
        male_exposure = sex_male * exposure_risk

        columns = [
            age_norm, age_risk, sex_male,
            smoke_never, smoke_former, smoke_current,
            fvc_norm, fvc_reduced, dlco_norm, dlco_reduced,
            dust_none, dust_low, dust_moderate, dust_high,
            family_history, dyspnea_score, cough, weight_loss,
            smoking_risk, exposure_risk, symptom_score,
            age_smoking, age_exposure, male_exposure,
        ]
        # column_stack keeps a 2-D (0, 24) shape even for an empty frame.
        return np.column_stack(columns).astype(np.float32)

    def get_features(self, patient_data: Dict) -> np.ndarray:
        """Return the scaled feature vector (length ``n_features``) for one patient record."""
        temp_df = pd.DataFrame([patient_data])
        features = self._extract_numerical_features(temp_df)
        return self.scaler.transform(features)[0]

# =============================================================================
# SECTION 5: MULTI-MODAL DATASET FOR DETECTION
# =============================================================================

class FibrosisDetectionDataset(Dataset):
    """Per-patient dataset yielding a 3-channel image, tabular features and a class label.

    Images come from up to ``n_slices`` DICOM slices under ``img_dir/<Patient_ID>/`` when
    that directory exists; otherwise a synthetic texture is generated from the diagnosis.
    NOTE(review): the synthetic texture is derived from the label itself, so in demo mode
    the image branch can read the diagnosis directly; demo metrics say nothing about
    performance on real CT scans.
    """

    def __init__(
        self,
        patient_data: pd.DataFrame,
        feature_engineer: DetectionFeatureEngineer,
        img_dir: Optional[Path] = None,
        transform: Optional[A.Compose] = None,
        n_slices: int = 5,
        is_train: bool = True
    ):
        self.patient_data = patient_data
        self.feature_engineer = feature_engineer
        self.img_dir = img_dir
        self.transform = transform
        self.n_slices = n_slices
        self.is_train = is_train
        self.dicom_processor = DicomProcessor()

        # Label mapping follows config.classes order (index i <-> config.classes[i]).
        self.class_to_idx = {cls: idx for idx, cls in enumerate(config.classes)}
        self.idx_to_class = {idx: cls for cls, idx in self.class_to_idx.items()}

        # Check for DICOM availability
        self.has_images = img_dir is not None and img_dir.exists()
        if self.has_images:
            self._load_patient_images()

        logger.info(f"Detection dataset created with {len(patient_data)} patients")
        logger.info(f"Class distribution: {patient_data['Diagnosis'].value_counts().to_dict()}")

    def _load_patient_images(self):
        """Map each patient that has a DICOM folder to up to ``n_slices`` evenly spaced slice paths."""
        self.patient_images = {}
        for patient_id in self.patient_data['Patient_ID']:
            patient_dir = self.img_dir / patient_id
            if not patient_dir.exists():
                continue
            dicom_files = sorted(patient_dir.glob("*.dcm"))
            if not dicom_files:
                continue
            n_files = len(dicom_files)
            if n_files >= self.n_slices:
                # Spread the selected slices across the whole (sorted) series.
                indices = np.linspace(0, n_files - 1, self.n_slices, dtype=int)
            else:
                indices = range(n_files)
            self.patient_images[patient_id] = [dicom_files[i] for i in indices]

    def __len__(self):
        return len(self.patient_data)

    def __getitem__(self, idx):
        row = self.patient_data.iloc[idx]
        patient_id = row['Patient_ID']
        diagnosis = row['Diagnosis']

        # Get tabular features
        tabular = self.feature_engineer.get_features(row.to_dict())

        # Get image (H x W x 3, float32)
        if self.has_images and patient_id in self.patient_images:
            img = self._load_slice_stack(patient_id)
        else:
            # Generate synthetic texture patterns based on diagnosis for demo
            img = self._generate_synthetic_ct_pattern(diagnosis)

        # Apply augmentations (the transform is expected to return a CHW tensor)
        if self.transform:
            img = self.transform(image=img)['image']
        else:
            img = torch.from_numpy(img.transpose(2, 0, 1)).float()

        label = self.class_to_idx[diagnosis]

        return {
            'image': img,
            'tabular': torch.tensor(tabular, dtype=torch.float32),
            'target': torch.tensor(label, dtype=torch.long),
            'patient_id': patient_id
        }

    def _load_slice_stack(self, patient_id) -> np.ndarray:
        """Load the patient's slices and stack three of them as an (img_size, img_size, 3) image.

        Only three channels are used, so loading stops after three slices decode
        successfully. With fewer, the loaded slices are cycled (2 -> [s0, s1, s0],
        1 -> [s0, s0, s0]); with none, an all-zero image is returned.
        """
        images = []
        for slice_path in self.patient_images[patient_id][:self.n_slices]:
            img = self.dicom_processor.load_dicom(str(slice_path))
            if img is None:
                continue
            images.append(cv2.resize(img, (config.img_size, config.img_size)))
            if len(images) == 3:
                break

        if not images:
            return np.zeros((config.img_size, config.img_size, 3), dtype=np.float32)
        return np.stack([images[i % len(images)] for i in range(3)], axis=-1)

    def _generate_synthetic_ct_pattern(self, diagnosis: str) -> np.ndarray:
        """Generate a synthetic CT-like (img_size, img_size, 3) float32 texture in [0, 1]."""
        img = np.random.rand(config.img_size, config.img_size, 3).astype(np.float32)

        if diagnosis == 'Normal':
            # Smooth, uniform pattern
            img = cv2.GaussianBlur(img, (15, 15), 0) * 0.5 + 0.3
        elif diagnosis == 'Early_Fibrosis':
            # Subtle irregular patterns
            noise = np.random.rand(config.img_size, config.img_size, 3) * 0.3
            img = cv2.GaussianBlur(img, (7, 7), 0) * 0.6 + noise + 0.2
        else:  # Advanced_Fibrosis
            # Prominent irregular, honeycomb-like patterns
            noise = np.random.rand(config.img_size, config.img_size, 3) * 0.5
            honeycomb = np.sin(np.linspace(0, 10*np.pi, config.img_size))
            honeycomb = np.outer(honeycomb, honeycomb)[:, :, np.newaxis] * 0.3
            img = img * 0.4 + noise + np.repeat(honeycomb, 3, axis=2) + 0.1

        # The float64 noise terms would otherwise make the dtype depend on the class.
        return np.clip(img, 0, 1).astype(np.float32)

# =============================================================================
# SECTION 6: AUGMENTATION FOR MEDICAL DETECTION
# =============================================================================

def get_detection_transforms(is_train: bool = True, img_size: int = 224,
                             max_pixel_value: float = 1.0) -> A.Compose:
    """Build the albumentations pipeline for training or evaluation.

    Args:
        is_train: include random augmentations when True; otherwise only resize + normalise.
        img_size: output height and width.
        max_pixel_value: intensity of a white pixel in the incoming images. The datasets
            in this pipeline produce float images in [0, 1], hence the default of 1.0;
            pass 255.0 for uint8 inputs. ``A.Normalize`` multiplies mean/std by this
            value, so albumentations' own default (255) on [0, 1] inputs squashes every
            image into roughly [-2.12, -2.10].

    Returns:
        A.Compose ending in ``ToTensorV2`` (CHW tensor, ImageNet-normalised).
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    if not is_train:
        return A.Compose([
            A.Resize(img_size, img_size),
            A.Normalize(mean=imagenet_mean, std=imagenet_std, max_pixel_value=max_pixel_value),
            ToTensorV2()
        ])

    # GaussNoise's var_limit is a variance in absolute intensity units: (5, 15) means a
    # standard deviation of roughly 2-4 levels, which only makes sense on a 0-255 scale.
    # Rescale it to the actual intensity range so [0, 1] images are not swamped.
    # NOTE(review): newer albumentations releases deprecate var_limit; confirm against
    # the installed version.
    noise_var_scale = (max_pixel_value / 255.0) ** 2

    return A.Compose([
        A.Resize(img_size, img_size),
        A.RandomRotate90(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(
            shift_limit=0.05,  # Conservative shifts for medical
            scale_limit=0.1,
            rotate_limit=5,    # Small rotations
            p=0.5
        ),
        A.OneOf([
            A.GaussNoise(var_limit=(5 * noise_var_scale, 15 * noise_var_scale)),
            A.GaussianBlur(blur_limit=3),
            A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1),
            A.RandomGamma(gamma_limit=(80, 120)),
        ], p=0.6),
        A.CoarseDropout(
            max_holes=8,
            max_height=32,
            max_width=32,
            p=0.3
        ),
        A.Normalize(mean=imagenet_mean, std=imagenet_std, max_pixel_value=max_pixel_value),
        ToTensorV2()
    ])

# =============================================================================
# SECTION 7: NEURAL NETWORK ARCHITECTURES FOR DETECTION
# =============================================================================

class FocalLoss(nn.Module):
    """Focal loss for multi-class logits, down-weighting well-classified samples.

    Per sample: ``alpha * (1 - p_t) ** gamma * CE`` where ``p_t = exp(-CE)`` is the
    predicted probability of the true class. ``reduction`` is 'mean', 'sum', or
    anything else to return the unreduced per-sample losses.
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        prob_true = per_sample_ce.neg().exp()
        modulating = (1 - prob_true) ** self.gamma
        loss = self.alpha * modulating * per_sample_ce

        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'mean':
            return loss.mean()
        return loss

class AttentionModule(nn.Module):
    """Channel attention that re-weights the channels of a ``[B, C, H, W]`` tensor.

    Average- and max-pooled channel descriptors go through a shared bottleneck MLP.
    Their sum is squashed by a single sigmoid into one weight in (0, 1) per channel,
    as in CBAM's channel-attention module. Summing two separate sigmoids would give
    weights in (0, 2) and amplify features instead of gating them.
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Keep at least one hidden unit when in_channels < reduction.
        hidden = max(1, in_channels // reduction)
        # Parameter names (fc.0.*, fc.2.*) match the previous layout, so existing
        # state dicts still load; the sigmoid is applied in forward() after the sum.
        self.fc = nn.Sequential(
            nn.Linear(in_channels, hidden),
            nn.ReLU(),
            nn.Linear(hidden, in_channels),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        avg_out = self.fc(self.avg_pool(x).view(b, c))
        max_out = self.fc(self.max_pool(x).view(b, c))
        attention = torch.sigmoid(avg_out + max_out)
        return x * attention.view(b, c, 1, 1)

class FibrosisDetectionModel(nn.Module):
    """Multi-modal classifier fusing a CNN image embedding with an MLP tabular embedding.

    Image features from a pretrained backbone (EfficientNet-B3, or ResNet-50 for any other
    ``backbone`` value) and tabular features are each projected to 256 dims. They are
    treated as a two-token sequence, mixed with multi-head self-attention, mean-pooled
    and classified by a small MLP head.
    """

    def __init__(
        self,
        n_classes: int = 3,
        tabular_dim: int = 24,
        hidden_dims: Tuple[int, ...] = (512, 256, 128),
        dropout_rate: float = 0.4,
        backbone: str = "efficientnet"
    ):
        super().__init__()

        # Image backbone with pretrained weights; its classification head is replaced by
        # Identity so the forward pass returns pooled feature vectors.
        if backbone == "efficientnet":
            self.image_backbone = efficientnet_b3(weights=EfficientNet_B3_Weights.DEFAULT)
            img_feat_dim = self.image_backbone.classifier[1].in_features
            self.image_backbone.classifier = nn.Identity()
        else:
            self.image_backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
            img_feat_dim = self.image_backbone.fc.in_features
            self.image_backbone.fc = nn.Identity()

        # Channel attention over the pooled image features (applied on a 1x1 map in forward).
        self.image_attention = AttentionModule(img_feat_dim)

        # Tabular MLP: Linear -> BatchNorm -> ReLU -> Dropout per hidden layer (no residuals).
        tab_layers = []
        prev_dim = tabular_dim
        for hidden_dim in hidden_dims:
            tab_layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate)
            ])
            prev_dim = hidden_dim

        self.tabular_net = nn.Sequential(*tab_layers)
        # With no hidden layers the tabular net is an identity and emits tabular_dim features.
        tab_feat_dim = prev_dim

        # Cross-modal fusion over the [image, tabular] token pair
        self.fusion = nn.MultiheadAttention(
            embed_dim=256,
            num_heads=8,
            dropout=dropout_rate,
            batch_first=True
        )

        self.img_proj = nn.Linear(img_feat_dim, 256)
        self.tab_proj = nn.Linear(tab_feat_dim, 256)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(64, n_classes)
        )

        self._init_weights()

    def _init_weights(self):
        """Kaiming-initialise every Linear layer outside the pretrained image backbone.

        Skipping the backbone guarantees its pretrained weights are never overwritten,
        whatever Linear layers it may contain.
        """
        for name, m in self.named_modules():
            if name == 'image_backbone' or name.startswith('image_backbone.'):
                continue
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, image, tabular):
        """Return class logits ``[B, n_classes]`` for an image batch and a tabular batch."""
        img_features = self.image_backbone(image)  # [B, C]
        # Attention expects [B, C, H, W]; flatten(1) (not squeeze()) restores [B, C]
        # and keeps the batch axis when B == 1.
        img_features = self.image_attention(img_features[:, :, None, None]).flatten(1)
        img_projected = self.img_proj(img_features)

        tab_features = self.tabular_net(tabular)
        tab_projected = self.tab_proj(tab_features)

        # Cross-modal attention
        img_seq = img_projected.unsqueeze(1)  # [B, 1, 256]
        tab_seq = tab_projected.unsqueeze(1)  # [B, 1, 256]
        combined_seq = torch.cat([img_seq, tab_seq], dim=1)  # [B, 2, 256]

        attended, _ = self.fusion(combined_seq, combined_seq, combined_seq)
        fused_features = attended.mean(dim=1)  # [B, 256]

        logits = self.classifier(fused_features)

        return logits

# =============================================================================
# SECTION 8: TRAINING PIPELINE FOR DETECTION
# =============================================================================

class EarlyStoppingDetection:
    """Early stopping on a higher-is-better validation metric (e.g. F1).

    A call counts as an improvement only when ``score`` beats the best so far by more
    than ``min_delta``. Improvements snapshot the model's state dict on the CPU. After
    ``patience`` consecutive non-improving calls, ``early_stop`` is set and the
    snapshot (if any) is loaded back into the model.
    """

    def __init__(self, patience=15, min_delta=0.001, metric='f1'):
        self.patience = patience
        self.min_delta = min_delta
        self.metric = metric  # name used only in the log message
        self.best_score = 0.0
        self.counter = 0
        self.early_stop = False
        self.best_weights = None

    def __call__(self, score, model):
        if score > self.best_score + self.min_delta:
            self.best_score = score
            self.counter = 0
            snapshot = model.state_dict()
            self.best_weights = {name: tensor.cpu().clone() for name, tensor in snapshot.items()}
            return

        self.counter += 1
        if self.counter < self.patience:
            return

        self.early_stop = True
        if self.best_weights:
            model.load_state_dict(self.best_weights)
        logger.info(f"Early stopping triggered. Best {self.metric}: {self.best_score:.4f}")

def train_detection_epoch(model, loader, optimizer, criterion, scaler):
    """Run one optimisation pass over ``loader``.

    Uses autocast when ``config.use_mixed_precision`` is set and clips gradients to
    ``config.gradient_clip``. When a ``scaler`` (GradScaler) is given, gradients are
    unscaled before clipping so the clip threshold applies to true magnitudes.

    Returns:
        tuple: (mean of per-batch losses, training accuracy in percent).
    """
    model.train()
    running_loss = 0
    n_correct = 0
    n_seen = 0

    for batch in tqdm(loader, desc='Training'):
        images, tabular, targets = (batch[key].to(DEVICE) for key in ('image', 'tabular', 'target'))

        optimizer.zero_grad()

        with autocast(enabled=config.use_mixed_precision):
            logits = model(images, tabular)
            loss = criterion(logits, targets)

        if not scaler:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.gradient_clip)
            optimizer.step()
        else:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.gradient_clip)
            scaler.step(optimizer)
            scaler.update()

        running_loss += loss.item()
        predicted = torch.max(logits.data, 1)[1]
        n_seen += targets.size(0)
        n_correct += (predicted == targets).sum().item()

    accuracy = 100. * n_correct / n_seen
    return running_loss / len(loader), accuracy

@torch.no_grad()
def validate_detection_epoch(model, loader, criterion):
    """Evaluate ``model`` on every batch of ``loader`` in eval mode without gradients.

    Returns:
        tuple: ``(metrics, all_preds, all_targets)``.

        * ``metrics`` holds:
            * ``loss``: the mean per-batch loss.
            * ``accuracy``: a fraction in [0, 1].
            * ``precision``, ``recall``, ``f1``: support-weighted scores.
            * ``auc``: weighted one-vs-rest AUC; 0.0 when sklearn cannot compute it,
              e.g. when a class is absent from the targets.
        * ``all_preds`` and ``all_targets``: per-sample class indices.
    """
    model.eval()
    total_loss = 0
    all_preds = []
    all_targets = []
    all_probs = []

    for batch in tqdm(loader, desc='Validation'):
        images = batch['image'].to(DEVICE)
        tabular = batch['tabular'].to(DEVICE)
        targets = batch['target'].to(DEVICE)

        logits = model(images, tabular)
        loss = criterion(logits, targets)

        total_loss += loss.item()

        # Get predictions and probabilities
        probs = F.softmax(logits, dim=1)
        _, preds = torch.max(logits, 1)

        all_preds.extend(preds.cpu().numpy())
        all_targets.extend(targets.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())

    # Calculate metrics
    avg_loss = total_loss / len(loader)
    accuracy = accuracy_score(all_targets, all_preds)
    precision = precision_score(all_targets, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_targets, all_preds, average='weighted', zero_division=0)
    f1 = f1_score(all_targets, all_preds, average='weighted', zero_division=0)

    # Multi-class AUC. Catch only ValueError (sklearn's signal that AUC is undefined);
    # a bare except would also swallow KeyboardInterrupt and genuine bugs.
    try:
        auc = roc_auc_score(all_targets, np.array(all_probs), multi_class='ovr', average='weighted')
    except ValueError as err:
        logger.warning(f"Could not compute AUC, reporting 0.0: {err}")
        auc = 0.0

    metrics = {
        'loss': avg_loss,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auc': auc
    }

    return metrics, all_preds, all_targets

# =============================================================================
# SECTION 9: MAIN EXECUTION FOR EARLY DETECTION
# =============================================================================

def main_detection():
    """Run the end-to-end early-detection experiment on synthetic data.

    Steps:
    1. Generate the synthetic cohort.
    2. Split it by patient, stratified on diagnosis.
    3. Fit feature engineering on the training split only.
    4. Build the datasets and loaders.
    5. Train the multi-modal model with early stopping.
    6. Evaluate the best-validation-F1 weights on the test split.
    7. Save the plots to ``config.output_dir``.

    Returns:
        tuple: ``(model, feature_engineer, history)``.

        * ``model`` holds the best-validation-F1 weights.
        * ``history`` maps metric names to per-epoch lists.
    """
    print("="*80)
    print("LUNG FIBROSIS EARLY DETECTION PIPELINE")
    print("="*80)

    # Create output directory
    config.output_dir.mkdir(exist_ok=True, parents=True)

    # 1. Create synthetic detection data
    logger.info("Creating synthetic detection dataset...")
    df = create_synthetic_detection_data()

    # 2. Stratified patient-level splits to maintain class balance.
    patients = df['Patient_ID'].unique()
    # groupby() sorts its index while unique() keeps order of appearance, and
    # train_test_split pairs `stratify` with `patients` by position: align explicitly.
    labels = df.groupby('Patient_ID')['Diagnosis'].first().loc[patients]
    holdout_fraction = config.val_split + config.test_split

    train_patients, temp_patients = train_test_split(
        patients,
        test_size=holdout_fraction,
        stratify=labels,
        random_state=config.seed
    )

    val_patients, test_patients = train_test_split(
        temp_patients,
        test_size=config.test_split / holdout_fraction,
        stratify=labels.loc[temp_patients],
        random_state=config.seed
    )

    train_data = df[df['Patient_ID'].isin(train_patients)]
    val_data = df[df['Patient_ID'].isin(val_patients)]
    test_data = df[df['Patient_ID'].isin(test_patients)]

    logger.info(f"Train: {len(train_data)} | Val: {len(val_data)} | Test: {len(test_data)}")

    # 3. Fit the feature engineer (scaler statistics) on the training split only;
    #    fitting on the whole cohort would leak validation/test statistics into training.
    logger.info("Creating feature engineer...")
    feature_engineer = DetectionFeatureEngineer(train_data)

    # 4. Create transforms
    train_transform = get_detection_transforms(is_train=True, img_size=config.img_size)
    val_transform = get_detection_transforms(is_train=False, img_size=config.img_size)

    # 5. Create datasets
    train_dataset = FibrosisDetectionDataset(
        train_data, feature_engineer, None, train_transform, config.n_slices, is_train=True
    )
    val_dataset = FibrosisDetectionDataset(
        val_data, feature_engineer, None, val_transform, config.n_slices, is_train=False
    )
    test_dataset = FibrosisDetectionDataset(
        test_data, feature_engineer, None, val_transform, config.n_slices, is_train=False
    )

    # 6. Handle class imbalance with weighted sampling
    if config.use_class_weights:
        # Weights are computed in config.classes order, so weight i belongs to class
        # index i as used by the dataset. np.unique would sort the names alphabetically
        # and swap the Normal/Advanced weights.
        class_weights = compute_class_weight(
            'balanced',
            classes=np.array(config.classes),
            y=train_data['Diagnosis']
        )
        class_weight_dict = dict(zip(config.classes, class_weights))
        logger.info(f"Class weights: {class_weight_dict}")

        # Create weighted sampler
        sample_weights = [class_weight_dict[label] for label in train_data['Diagnosis']]
        sampler = WeightedRandomSampler(sample_weights, len(sample_weights))
        shuffle = False
    else:
        sampler = None
        shuffle = True

    # 7. Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        sampler=sampler,
        shuffle=shuffle,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory,
        drop_last=True  # avoids size-1 batches, which BatchNorm1d rejects in train mode
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory
    )

    # 8. Create model; the tabular input width comes from the fitted scaler, not a constant.
    logger.info("Creating detection model...")
    model = FibrosisDetectionModel(
        n_classes=config.n_classes,
        tabular_dim=feature_engineer.scaler.mean_.shape[0],
        hidden_dims=config.tabular_hidden_dims,
        dropout_rate=config.dropout_rate,
        backbone=config.backbone
    ).to(DEVICE)

    # Model summary
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Total parameters: {total_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,}")

    # 9. Training setup
    if config.use_class_weights:
        # The weighted sampler already presents classes in balanced proportions;
        # weighting the loss as well would compensate twice and over-favour minorities.
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = FocalLoss(alpha=config.focal_loss_alpha, gamma=config.focal_loss_gamma)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )

    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=20, T_mult=2, eta_min=1e-6
    )

    early_stopper = EarlyStoppingDetection(
        patience=config.patience, min_delta=config.min_delta, metric='f1'
    )
    scaler = GradScaler() if config.use_mixed_precision else None

    # 10. Training loop
    logger.info("Starting training...")
    best_f1 = 0.0
    best_state = None  # CPU copy of the weights saved as best_detection_model.pth
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [], 'val_f1': [], 'val_auc': []
    }

    for epoch in range(config.num_epochs):
        logger.info(f"Epoch {epoch + 1}/{config.num_epochs}")

        # Train
        train_loss, train_acc = train_detection_epoch(model, train_loader, optimizer, criterion, scaler)

        # Validate
        val_metrics, val_preds, val_targets = validate_detection_epoch(model, val_loader, criterion)

        # Update scheduler
        scheduler.step()

        # Log metrics
        logger.info(
            f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
            f"Val Loss: {val_metrics['loss']:.4f} | Val Acc: {val_metrics['accuracy']:.4f} | "
            f"Val F1: {val_metrics['f1']:.4f} | Val AUC: {val_metrics['auc']:.4f}"
        )

        # Save history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_metrics['loss'])
        history['val_acc'].append(val_metrics['accuracy'])
        history['val_f1'].append(val_metrics['f1'])
        history['val_auc'].append(val_metrics['auc'])

        # Checkpoint BEFORE early stopping runs: the stopper may load older weights into
        # `model`, and those must not be saved under this epoch's F1.
        if val_metrics['f1'] > best_f1:
            best_f1 = val_metrics['f1']
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_f1': best_f1,
                'config': config
            }, config.output_dir / 'best_detection_model.pth')
            logger.info(f"New best model saved with F1: {best_f1:.4f}")

        # Early stopping
        early_stopper(val_metrics['f1'], model)
        if early_stopper.early_stop:
            break

        # Memory cleanup
        torch.cuda.empty_cache()
        gc.collect()

    # Evaluate the checkpointed best weights, not whatever the last epoch left behind.
    if best_state is not None:
        model.load_state_dict(best_state)

    # 11. Final evaluation
    logger.info("Evaluating on test set...")
    test_metrics, test_preds, test_targets = validate_detection_epoch(model, test_loader, criterion)

    print("\n" + "="*80)
    print("FINAL TEST SET RESULTS - EARLY DETECTION")
    print("="*80)
    print(f"Test Accuracy: {test_metrics['accuracy']:.4f}")
    print(f"Test Precision: {test_metrics['precision']:.4f}")
    print(f"Test Recall: {test_metrics['recall']:.4f}")
    print(f"Test F1-Score: {test_metrics['f1']:.4f}")
    print(f"Test AUC: {test_metrics['auc']:.4f}")

    # Explicit labels keep the report/matrix aligned with config.classes even if a
    # class never appears in the test targets or predictions.
    class_indices = list(range(len(config.classes)))

    # Classification report
    print("\nDetailed Classification Report:")
    print(classification_report(test_targets, test_preds, labels=class_indices,
                                target_names=config.classes, zero_division=0))

    # Confusion Matrix
    cm = confusion_matrix(test_targets, test_preds, labels=class_indices)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=config.classes, yticklabels=config.classes)
    plt.title('Confusion Matrix - Fibrosis Detection')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.tight_layout()
    plt.savefig(config.output_dir / 'confusion_matrix.png', dpi=300, bbox_inches='tight')
    plt.show()

    # 12. Create visualizations
    create_detection_plots(history)

    print("\n" + "="*80)
    print("EARLY DETECTION PIPELINE COMPLETED SUCCESSFULLY!")
    print("="*80)

    return model, feature_engineer, history

def create_detection_plots(history, output_dir=None):
    """Plot and save the per-epoch training curves of the detection model.

    Draws a 2x2 grid (loss, accuracy, validation F1, validation AUC), saves it as
    ``detection_training_curves.png`` and shows it.

    Args:
        history: Mapping of per-epoch sequences with keys ``train_loss``,
            ``val_loss``, ``train_acc``, ``val_acc``, ``val_f1`` and ``val_auc``.
        output_dir: Directory for the PNG. Defaults to the global
            ``config.output_dir`` (the previous, implicit behaviour).
    """
    if output_dir is None:
        output_dir = config.output_dir

    def as_percent(values):
        # The accuracy axis is labelled in percent, but the two accuracy series were
        # previously scaled differently (only val_acc was multiplied by 100).
        # Rescale any series that looks like a fraction in [0, 1] so both curves
        # share one scale; a series already in percent is left untouched.
        values = list(values)
        if values and max(values) <= 1.0:
            return [v * 100 for v in values]
        return values

    def finish(ax, ylabel, title):
        # Shared decoration for every panel.
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Loss curves
    axes[0, 0].plot(history['train_loss'], label='Train Loss', alpha=0.7)
    axes[0, 0].plot(history['val_loss'], label='Validation Loss', alpha=0.7)
    finish(axes[0, 0], 'Loss', 'Training and Validation Loss')

    # Accuracy curves, both on the percent scale
    axes[0, 1].plot(as_percent(history['train_acc']), label='Train Accuracy', alpha=0.7)
    axes[0, 1].plot(as_percent(history['val_acc']), label='Validation Accuracy', alpha=0.7)
    finish(axes[0, 1], 'Accuracy (%)', 'Training and Validation Accuracy')

    # F1 score progression
    axes[1, 0].plot(history['val_f1'], label='Validation F1', color='green', alpha=0.7)
    finish(axes[1, 0], 'F1 Score', 'F1 Score Progression')

    # AUC progression
    axes[1, 1].plot(history['val_auc'], label='Validation AUC', color='red', alpha=0.7)
    finish(axes[1, 1], 'AUC Score', 'AUC Score Progression')

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, 'detection_training_curves.png'),
                dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls in a long session do not accumulate figures.
    plt.close(fig)

# Execute the detection pipeline
if __name__ == "__main__":
    try:
        model, feature_engineer, history = main_detection()
        logger.info("Early detection pipeline completed successfully!")
    except Exception as e:
        logger.error(f"Pipeline failed with error: {e}")
        # A bare ``raise`` re-raises the active exception with its original
        # traceback, without adding this line as an extra frame like ``raise e``.
        raise


## ChatGPT

In [None]:
# =============================================================================
# COMPLETE MULTI-MODAL OSIC PULMONARY FIBROSIS MODEL (with epoch metric prints)
# =============================================================================

# SECTION 1: IMPORTS AND CONFIGURATION
import os
import gc
import random
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Union
from dataclasses import dataclass
import logging
from tqdm.auto import tqdm
import sys

# Deep Learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.cuda.amp import autocast, GradScaler

# Medical Imaging
import pydicom
from pydicom.pixel_data_handlers.util import apply_modality_lut, apply_voi_lut
import cv2

# Vision Models
import torchvision.models as models
from torchvision.models import efficientnet_b3, EfficientNet_B3_Weights

# ML Tools
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import RobustScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Configure logging: timestamped messages at INFO level and above, via the
# root logger that basicConfig sets up.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

@dataclass
class Config:
    """Paths and hyper-parameters for the multi-modal OSIC pipeline.

    ``data_dir`` and ``output_dir`` accept ``str`` or ``Path`` and are stored as
    ``Path`` (later code joins them with ``/``). ``val_split`` and ``test_split``
    are patient-level fractions. Each must lie in (0, 1), and together they must
    leave a non-empty training share; this is checked in ``__post_init__``.
    """

    # Paths
    data_dir: Path = Path("/kaggle/input/osic-pulmonary-fibrosis-progression")
    output_dir: Path = Path("/kaggle/working")

    # Image preprocessing
    img_size: int = 224
    n_slices: int = 3
    window_center: int = -600
    window_width: int = 1500

    # Model
    model_type: str = "MultiModal"
    backbone: str = "efficientnet"
    # Variable-length tuple of hidden widths (``Tuple[int]`` would mean a 1-tuple).
    tabular_hidden_dims: Tuple[int, ...] = (512, 256, 128, 64)
    fusion_method: str = "attention"
    dropout_rate: float = 0.3
    use_batch_norm: bool = True

    # Training
    n_folds: int = 5
    batch_size: int = 32
    num_epochs: int = 50
    learning_rate: float = 1e-3
    weight_decay: float = 1e-4
    scheduler: str = 'cosine'
    patience: int = 10
    min_delta: float = 0.001
    gradient_clip: float = 1.0

    # Augmentation / regularisation
    use_augmentation: bool = True
    aug_prob: float = 0.5
    use_mixup: bool = True
    mixup_alpha: float = 0.2

    # Runtime
    use_mixed_precision: bool = True
    num_workers: int = 2
    pin_memory: bool = True
    seed: int = 42

    # Patient-level split fractions
    val_split: float = 0.15
    test_split: float = 0.15

    def __post_init__(self):
        """Normalise path types and reject split fractions that cannot be used."""
        self.data_dir = Path(self.data_dir)
        self.output_dir = Path(self.output_dir)

        for name in ('val_split', 'test_split'):
            value = getattr(self, name)
            if not 0.0 < value < 1.0:
                raise ValueError(f"{name} must be in (0, 1), got {value}")
        if self.val_split + self.test_split >= 1.0:
            raise ValueError(
                "val_split + test_split must be < 1 to leave training patients, "
                f"got {self.val_split + self.test_split}"
            )
        if self.window_width <= 0:
            raise ValueError(f"window_width must be positive, got {self.window_width}")

config = Config()

def set_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) with ``seed``.

    Also exports ``PYTHONHASHSEED`` and switches cuDNN into deterministic,
    non-benchmarking mode.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed(config.seed)

# Prefer the GPU when one is visible to PyTorch.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {DEVICE}")
if DEVICE.type == "cuda":
    logger.info(f"GPU: {torch.cuda.get_device_name(0)}")

# =============================================================================
# SECTION 2: DATA LOADING AND PREPROCESSING
# =============================================================================

def _make_synthetic_osic_data(n_patients: int = 100, seed: int = 42) -> pd.DataFrame:
    """Generate OSIC-like rows with a per-patient FVC decline over increasing weeks."""
    # Reseeds NumPy's global RNG, as the original inline generator did.
    np.random.seed(seed)
    records = []

    for patient_id in range(n_patients):
        n_visits = np.random.randint(3, 10)
        baseline_fvc = np.random.normal(2500, 700)
        age = np.random.randint(50, 85)
        sex = np.random.choice(['Male', 'Female'])
        smoking = np.random.choice(['Never smoked', 'Ex-smoker', 'Currently smokes'])

        week = 0
        for visit in range(n_visits):
            if visit > 0:
                # Strictly increasing visit times, 1-5 weeks after the previous visit.
                # (``visit * randint(1, 6)`` could repeat or go backwards.)
                week += np.random.randint(1, 6)
            fvc = baseline_fvc - np.random.normal(5, 2) * week
            fvc = max(fvc, 500)  # keep FVC physiologically positive

            records.append({
                'Patient': f'ID{patient_id:04d}',
                'Weeks': week,
                'FVC': fvc,
                'Percent': np.random.normal(50, 15),
                'Age': age,
                'Sex': sex,
                'SmokingStatus': smoking
            })

    return pd.DataFrame(records)


def load_osic_data():
    """Load the OSIC ``train.csv``, falling back to synthetic data on failure.

    Returns:
        DataFrame containing at least ``Patient``, ``Weeks``, ``FVC``, ``Percent``,
        ``Age``, ``Sex`` and ``SmokingStatus``.

    Any error while reading or validating the CSV is logged as a warning. A
    synthetic OSIC-like dataset, seeded with ``config.seed``, is then returned so
    the pipeline can still run without the competition data.
    """
    required_cols = ['Patient', 'Weeks', 'FVC', 'Percent', 'Age', 'Sex', 'SmokingStatus']
    try:
        train_df = pd.read_csv(config.data_dir / "train.csv")
        logger.info(f"Loaded OSIC dataset with shape: {train_df.shape}")

        # Explicit check instead of ``assert`` so validation survives ``python -O``,
        # and the message names the missing columns.
        missing = [col for col in required_cols if col not in train_df.columns]
        if missing:
            raise ValueError(f"Missing required columns: {missing}")

        return train_df
    except Exception as e:
        # Deliberate best-effort fallback for demonstration environments.
        logger.warning(f"Could not load OSIC data: {e}")
        logger.info("Creating synthetic data for demonstration...")
        return _make_synthetic_osic_data(n_patients=100, seed=config.seed)

# =============================================================================
# SECTION 3: MEDICAL IMAGE PROCESSING
# =============================================================================

class DicomProcessor:
    """Load CT DICOM slices and map them to [0, 1] through a fixed intensity window.

    Args:
        window_center: Centre of the display window (default ``-600``).
        window_width: Width of the display window (default ``1500``).

    With the defaults, values are clipped to [-1350, 150] after rescaling and then
    mapped linearly to [0, 1].
    """

    def __init__(self, window_center: int = -600, window_width: int = 1500):
        self.window_center = window_center
        self.window_width = window_width

    def _window_bounds(self) -> Tuple[int, int]:
        """Return the (lower, upper) clip limits of the window."""
        half = self.window_width // 2
        return self.window_center - half, self.window_center + half

    def load_dicom(self, path: str) -> Optional[np.ndarray]:
        """Read one DICOM file and return a windowed float32 image in [0, 1].

        Returns ``None`` (logged at DEBUG level) if the file cannot be read or
        decoded, so callers can skip unreadable slices.
        """
        try:
            dcm = pydicom.dcmread(path)
            img = dcm.pixel_array.astype(np.float32)

            # Modality LUT: stored values -> rescaled intensities, when both tags exist.
            if hasattr(dcm, 'RescaleSlope') and hasattr(dcm, 'RescaleIntercept'):
                img = img * float(dcm.RescaleSlope) + float(dcm.RescaleIntercept)

            return self.normalize(self.apply_windowing(img))

        except Exception as e:
            logger.debug(f"Failed to load DICOM {path}: {e}")
            return None

    def apply_windowing(self, img: np.ndarray) -> np.ndarray:
        """Clip intensities to the lung window."""
        lower, upper = self._window_bounds()
        return np.clip(img, lower, upper)

    def normalize(self, img: np.ndarray) -> np.ndarray:
        """Map intensities linearly from the window bounds to [0, 1] (float32).

        The fixed window limits are used instead of each image's own min/max, so
        the same intensity always maps to the same output value across slices and
        patients. A per-image min/max stretch would blow a slice covering only part
        of the window up to the full [0, 1] range. Values outside the window are
        clipped.
        """
        lower, upper = self._window_bounds()
        span = max(upper - lower, 1e-6)  # guard against a degenerate zero-width window
        return np.clip((img - lower) / span, 0.0, 1.0).astype(np.float32)

# =============================================================================
# SECTION 4: FEATURE ENGINEERING
# =============================================================================

class TabularFeatureEngineer:
    """Build per-patient tabular statistics and scaled features for (patient, week).

    For every patient in ``train_df`` it stores the baseline (earliest-week) values
    and FVC trajectory statistics. It then fits ``self.scaler`` on the encoded
    features of every observed (patient, week) row.

    NOTE(review): ``FVCSlope``, ``FVCStdDev`` and ``NumMeasurements`` use *all* of a
    patient's FVC measurements, including the visits later used as regression
    targets, so ``expected_fvc`` is close to a linear fit of the target itself.
    Confirm this leakage is acceptable for the intended evaluation.
    """

    # Length of the vector produced by ``_encode_features``.
    N_FEATURES = 14

    def __init__(self, train_df: pd.DataFrame):
        self.train_df = train_df
        self.scaler = RobustScaler()
        self.patient_features = {}
        self._prepare_features()

    def _prepare_features(self):
        """Compute baseline and trajectory statistics per patient, then fit the scaler."""
        # One groupby pass instead of re-filtering the whole frame for every patient
        # (O(rows * patients)). sort=False keeps first-appearance order, the same
        # order as iterating ``unique()``.
        for patient, patient_data in self.train_df.groupby('Patient', sort=False):
            patient_data = patient_data.sort_values('Weeks', kind='stable')

            # Baseline = earliest visit
            baseline = patient_data.iloc[0]

            slope = 0
            std_dev = 0
            if len(patient_data) > 1:
                weeks = patient_data['Weeks'].to_numpy(dtype=float)
                fvc = patient_data['FVC'].to_numpy(dtype=float)
                std_dev = np.std(fvc)
                # A line fit needs at least two distinct weeks; with identical weeks
                # polyfit is rank-deficient, so the slope stays 0.
                if np.ptp(weeks) > 0:
                    slope = np.polyfit(weeks, fvc, 1)[0]

            self.patient_features[patient] = {
                'Age': baseline['Age'],
                'Sex': baseline['Sex'],
                'SmokingStatus': baseline['SmokingStatus'],
                'BaselineFVC': baseline['FVC'],
                'BaselinePercent': baseline['Percent'],
                'BaselineWeeks': baseline['Weeks'],
                'FVCSlope': slope,
                'FVCStdDev': std_dev,
                'NumMeasurements': len(patient_data)
            }

        self._fit_scaler()

    def _fit_scaler(self):
        """Fit the scaler on the encoded features of every observed (patient, week) row.

        Fitting on the same rows that ``get_features`` later transforms keeps the
        time-dependent columns (week delta, its square, age*week, expected FVC) on
        the scale they actually take. Encoding every patient at ``week=0`` would fit
        them on ``-BaselineWeeks`` instead.
        """
        rows = [
            self._encode_features(self.patient_features[patient], week)
            for patient, week in zip(self.train_df['Patient'], self.train_df['Weeks'])
            if patient in self.patient_features
        ]
        if not rows:
            raise ValueError("train_df contains no rows to fit the feature scaler on")
        self.scaler.fit(np.vstack(rows))

    def _encode_features(self, stats: Dict, current_week: float) -> np.ndarray:
        """Encode one patient's statistics at ``current_week`` as a float32 vector.

        Returns ``N_FEATURES`` values in this order:
        age, sex, three smoking one-hots, baseline FVC, baseline percent,
        week delta, week delta squared, FVC slope, FVC std, age*week, expected FVC,
        number of measurements (each divided by a fixed constant).
        """
        sex_male = 1 if stats['Sex'] == 'Male' else 0

        smoke_never = 1 if stats['SmokingStatus'] == 'Never smoked' else 0
        smoke_ex = 1 if stats['SmokingStatus'] == 'Ex-smoker' else 0
        smoke_current = 1 if stats['SmokingStatus'] == 'Currently smokes' else 0

        # Time relative to the patient's baseline visit
        week_delta = current_week - stats['BaselineWeeks']
        week_squared = week_delta ** 2

        # Interaction feature
        age_week = stats['Age'] * week_delta / 100

        # Linear extrapolation from the baseline along the fitted slope
        expected_fvc = stats['BaselineFVC'] + stats['FVCSlope'] * week_delta

        features = [
            stats['Age'] / 100,
            sex_male,
            smoke_never,
            smoke_ex,
            smoke_current,
            stats['BaselineFVC'] / 5000,
            stats['BaselinePercent'] / 100,
            week_delta / 52,
            week_squared / (52 ** 2),
            stats['FVCSlope'] / 100,
            stats['FVCStdDev'] / 1000,
            age_week,
            expected_fvc / 5000,
            stats['NumMeasurements'] / 10
        ]

        return np.array(features, dtype=np.float32)

    def get_features(self, patient_id: str, week: float) -> np.ndarray:
        """Return the scaled feature vector for ``patient_id`` at ``week``.

        Unknown patients get an all-zero vector of length ``N_FEATURES``.
        """
        stats = self.patient_features.get(patient_id)
        if stats is None:
            return np.zeros(self.N_FEATURES, dtype=np.float32)

        features = self._encode_features(stats, week)
        return self.scaler.transform(features.reshape(1, -1))[0]

# =============================================================================
# SECTION 5: MULTI-MODAL DATASET
# =============================================================================

class OSICMultiModalDataset(Dataset):
    """Pair each (patient, week, FVC) row with tabular features and a CT image.

    One sample is produced per row of ``patient_data``. ``__getitem__`` returns a
    dict with ``image`` (CHW float tensor), ``tabular`` (float32 tensor from the
    feature engineer) and ``target`` (the row's FVC as a float32 tensor).

    Args:
        patient_data: Rows with ``Patient``, ``Weeks`` and ``FVC`` columns.
        feature_engineer: Supplies scaled tabular features via ``get_features``.
        img_dir: Directory holding one sub-directory of ``*.dcm`` files per patient.
            When ``None`` or missing, every sample gets an all-zero image.
        transform: Albumentations pipeline applied to the HWC float image; expected
            to end in ``ToTensorV2``. Without it the image is converted to a CHW
            float tensor directly.
        n_slices: Number of evenly spaced slices selected per patient.
        is_train: Stored on the instance; not used inside this class.
    """

    # The image is built with this many channels, one CT slice per channel.
    N_CHANNELS = 3

    def __init__(
        self,
        patient_data: pd.DataFrame,
        feature_engineer: TabularFeatureEngineer,
        img_dir: Optional[Path] = None,
        transform: Optional[A.Compose] = None,
        n_slices: int = 3,
        is_train: bool = True
    ):
        self.patient_data = patient_data
        self.feature_engineer = feature_engineer
        self.img_dir = img_dir
        self.transform = transform
        self.n_slices = n_slices
        self.is_train = is_train
        self.dicom_processor = DicomProcessor()

        # patient id -> list of selected slice paths, or None when no slices exist
        self.patient_images = {}

        self.has_images = img_dir is not None and img_dir.exists()
        if self.has_images:
            self._load_patient_images()

        # Read the three columns directly rather than via iterrows(), which builds
        # a Series per row.
        self.samples = [
            {'patient': patient, 'week': week, 'fvc': fvc}
            for patient, week, fvc in zip(
                patient_data['Patient'], patient_data['Weeks'], patient_data['FVC']
            )
        ]

        logger.info(f"Dataset created with {len(self.samples)} samples")

    @staticmethod
    def _slice_sort_key(path: Path):
        """Sort numerically named slice files by number, any others by name after them.

        Plain lexicographic sorting would order ``10.dcm`` before ``2.dcm``, so the
        "evenly spaced" selection below would not follow slice order. This assumes
        slice files are numbered by position — confirm for the dataset in use.
        """
        stem = path.stem
        return (0, int(stem), '') if stem.isdecimal() else (1, 0, stem)

    def _load_patient_images(self):
        """Select ``n_slices`` evenly spaced DICOM paths for each patient (no pixel I/O)."""
        for patient in self.patient_data['Patient'].unique():
            patient_dir = self.img_dir / patient
            dicom_files = (
                sorted(patient_dir.glob("*.dcm"), key=self._slice_sort_key)
                if patient_dir.exists() else []
            )
            if not dicom_files:
                self.patient_images[patient] = None
                continue

            n_files = len(dicom_files)
            if n_files >= self.n_slices:
                indices = np.linspace(0, n_files - 1, self.n_slices, dtype=int)
            else:
                indices = range(n_files)
            self.patient_images[patient] = [dicom_files[i] for i in indices]

    def _load_image(self, patient) -> np.ndarray:
        """Return an (img_size, img_size, N_CHANNELS) float32 image for ``patient``.

        Uses at most ``N_CHANNELS`` slices, one per channel. Missing channels
        repeat the last loaded slice. If nothing loads, an all-zero image is returned.
        """
        size = config.img_size
        blank = np.zeros((size, size, self.N_CHANNELS), dtype=np.float32)

        slices = self.patient_images.get(patient) if self.has_images else None
        if not slices:
            return blank

        images = []
        for slice_path in slices[:self.N_CHANNELS]:
            img = self.dicom_processor.load_dicom(str(slice_path))
            if img is not None:
                images.append(cv2.resize(img, (size, size)))

        if not images:
            return blank

        # Pad by repeating the last slice (1 slice -> [a, a, a], 2 -> [a, b, b]).
        images += [images[-1]] * (self.N_CHANNELS - len(images))
        return np.stack(images, axis=-1)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        patient = sample['patient']

        tabular = self.feature_engineer.get_features(patient, sample['week'])
        img = self._load_image(patient)

        if self.transform:
            img = self.transform(image=img)['image']
        else:
            img = torch.from_numpy(img.transpose(2, 0, 1)).float()

        return {
            'image': img,
            'tabular': torch.tensor(tabular, dtype=torch.float32),
            'target': torch.tensor(sample['fvc'], dtype=torch.float32)
        }

# =============================================================================
# SECTION 6: AUGMENTATION
# =============================================================================

def get_transforms(is_train: bool = True, img_size: int = 224) -> A.Compose:
    """Return the albumentations pipeline for the CT images.

    The dataset feeds float32 HWC images already scaled to [0, 1], so the
    pipeline is configured for that range:

    * ``Normalize`` uses ``max_pixel_value=1.0``. With the default of 255 it
      computes ``(img - mean*255) / (std*255)``, which maps every [0, 1] image to
      an almost constant tensor.
    * ``GaussNoise`` variance is given on the [0, 1] scale. The former
      ``var_limit=(10, 50)`` is a uint8-scale setting. If applied in pixel units,
      it adds noise with a std of roughly 3-7 to a [0, 1] image.

    Args:
        is_train: Include random augmentations when True; otherwise resize + normalize only.
        img_size: Output height and width.

    Returns:
        ``A.Compose`` ending in ``ToTensorV2`` (CHW tensor).
    """
    # ImageNet statistics, matching the ImageNet-pretrained image backbones.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    steps = [A.Resize(img_size, img_size)]
    if is_train:
        steps += [
            A.RandomRotate90(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.ShiftScaleRotate(
                shift_limit=0.1,
                scale_limit=0.1,
                rotate_limit=10,
                p=0.5
            ),
            A.OneOf([
                # Same noise strength as var_limit=(10, 50) on 0-255 images.
                # NOTE(review): assumes the installed albumentations applies
                # var_limit in the image's own units (pre-2.0 API) — confirm.
                A.GaussNoise(var_limit=(10 / 255 ** 2, 50 / 255 ** 2)),
                A.GaussianBlur(blur_limit=3),
                A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1),
            ], p=0.5),
        ]
    steps += [
        A.Normalize(mean=imagenet_mean, std=imagenet_std, max_pixel_value=1.0),
        ToTensorV2(),
    ]
    return A.Compose(steps)

# =============================================================================
# SECTION 7: NEURAL NETWORK ARCHITECTURES
# =============================================================================

# Tabular Network Components
class ResNetBlock(nn.Module):
    """Two-layer residual MLP block for tabular features.

    Computes ``relu(bn2(fc2(dropout(relu(bn1(fc1(x)))))) + shortcut(x))``. The
    shortcut is a linear projection when the widths differ, otherwise identity.
    Batch norm layers become identity when ``config.use_batch_norm`` is False.
    """
    def __init__(self, in_dim, out_dim, dropout_rate=0.1):
        super().__init__()
        use_bn = config.use_batch_norm
        # Submodules are created in this fixed order (affects init RNG and state_dict keys).
        self.fc1 = nn.Linear(in_dim, out_dim)
        self.fc2 = nn.Linear(out_dim, out_dim)
        self.bn1 = nn.BatchNorm1d(out_dim) if use_bn else nn.Identity()
        self.bn2 = nn.BatchNorm1d(out_dim) if use_bn else nn.Identity()
        self.dropout = nn.Dropout(dropout_rate)
        self.shortcut = nn.Identity() if in_dim == out_dim else nn.Linear(in_dim, out_dim)

    def forward(self, x):
        residual = self.shortcut(x)
        hidden = self.dropout(F.relu(self.bn1(self.fc1(x))))
        return F.relu(self.bn2(self.fc2(hidden)) + residual)

# Attention Mechanisms
class CrossModalAttention(nn.Module):
    """Fuse an image vector and a tabular vector with self-attention over the pair.

    Both inputs are projected to ``hidden_dim`` and treated as a two-token
    sequence. The sequence goes through 8-head self-attention and is mean-pooled
    back to one ``[B, hidden_dim]`` vector.
    """
    def __init__(self, img_dim, tab_dim, hidden_dim=256):
        super().__init__()
        self.img_proj = nn.Linear(img_dim, hidden_dim)
        self.tab_proj = nn.Linear(tab_dim, hidden_dim)
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True)

    def forward(self, img_feat, tab_feat):
        img_token = self.img_proj(img_feat)
        tab_token = self.tab_proj(tab_feat)
        # [B, 2, hidden]: token 0 = image, token 1 = tabular
        tokens = torch.stack((img_token, tab_token), dim=1)
        attended, _attn_weights = self.attention(tokens, tokens, tokens)
        return attended.mean(dim=1)

# Main Multi-Modal Model
class OSICMultiModalModel(nn.Module):
    """Image + tabular regressor returning an FVC prediction and a log-variance.

    Args:
        tabular_dim: Width of the tabular feature vector.
        hidden_dims: Output widths of the stacked tabular ``ResNetBlock``s. An empty
            sequence passes the tabular features through unchanged.
        dropout_rate: Dropout used in the tabular blocks and the regression head.
        backbone: ``"efficientnet"`` (B3) or ``"resnet50"``; any other value selects
            DenseNet-121. Each is ImageNet-pretrained with its classifier removed.
        fusion_method: ``"attention"`` (``CrossModalAttention``), ``"gated"``
            (softmax gates scale each modality before concatenation) or anything
            else for plain concatenation.

    ``forward(image, tabular)`` returns ``(prediction, log_variance)``, each of
    shape ``[B]``.
    """

    # Output width of the CrossModalAttention fusion.
    ATTENTION_DIM = 256

    def __init__(
        self,
        tabular_dim: int = 14,
        hidden_dims: Tuple[int, ...] = (512, 256, 128, 64),
        dropout_rate: float = 0.3,
        backbone: str = "efficientnet",
        fusion_method: str = "attention"
    ):
        super().__init__()

        # Image backbone (registered first, as before)
        self.image_backbone, img_feat_dim = self._build_backbone(backbone)

        # Tabular network
        tab_layers = []
        prev_dim = tabular_dim
        for hidden_dim in hidden_dims:
            tab_layers.append(ResNetBlock(prev_dim, hidden_dim, dropout_rate))
            prev_dim = hidden_dim
        self.tabular_net = nn.Sequential(*tab_layers)
        # Width after the last block (tabular_dim when hidden_dims is empty).
        tab_feat_dim = prev_dim

        # Fusion method
        self.fusion_method = fusion_method
        if fusion_method == "attention":
            self.fusion = CrossModalAttention(
                img_feat_dim, tab_feat_dim, hidden_dim=self.ATTENTION_DIM
            )
            fusion_dim = self.ATTENTION_DIM
        elif fusion_method == "gated":
            self.gate = nn.Sequential(
                nn.Linear(img_feat_dim + tab_feat_dim, 256),
                nn.ReLU(),
                nn.Linear(256, 2),
                nn.Softmax(dim=1)
            )
            # Gated features are concatenated (see forward), so the width is the sum.
            fusion_dim = img_feat_dim + tab_feat_dim
        else:  # concat
            fusion_dim = img_feat_dim + tab_feat_dim

        # Final regression head
        self.head = nn.Sequential(
            nn.Linear(fusion_dim, 256),
            nn.BatchNorm1d(256) if config.use_batch_norm else nn.Identity(),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128) if config.use_batch_norm else nn.Identity(),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, 1)
        )

        # Uncertainty head: predicts a per-sample log-variance
        self.uncertainty_head = nn.Sequential(
            nn.Linear(fusion_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

        self._init_weights()

    @staticmethod
    def _build_backbone(backbone: str):
        """Return ``(pretrained model with identity classifier, feature width)``."""
        if backbone == "efficientnet":
            net = efficientnet_b3(weights=EfficientNet_B3_Weights.DEFAULT)
            feat_dim = net.classifier[1].in_features
            net.classifier = nn.Identity()
        elif backbone == "resnet50":
            # IMAGENET1K_V1 is what the deprecated ``pretrained=True`` loaded.
            net = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
            feat_dim = net.fc.in_features
            net.fc = nn.Identity()
        else:
            net = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
            feat_dim = net.classifier.in_features
            net.classifier = nn.Identity()
        return net, feat_dim

    def _init_weights(self):
        """Kaiming-initialise every ``nn.Linear`` outside the pretrained image backbone."""
        for child in self.children():
            if child is self.image_backbone:
                continue  # keep the ImageNet weights intact
            for m in child.modules():
                if isinstance(m, nn.Linear):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)

    def forward(self, image, tabular):
        # Extract features
        img_features = self.image_backbone(image)
        tab_features = self.tabular_net(tabular)

        # Fusion
        if self.fusion_method == "attention":
            combined = self.fusion(img_features, tab_features)
        elif self.fusion_method == "gated":
            gates = self.gate(torch.cat([img_features, tab_features], dim=1))
            # The modalities have different widths, so they cannot be summed; scale
            # each by its gate and concatenate, matching the head's fusion_dim.
            combined = torch.cat(
                [gates[:, 0:1] * img_features, gates[:, 1:2] * tab_features], dim=1
            )
        else:
            combined = torch.cat([img_features, tab_features], dim=1)

        # Predictions, each squeezed to shape [B]
        prediction = self.head(combined).squeeze(-1)
        log_variance = self.uncertainty_head(combined).squeeze(-1)

        return prediction, log_variance

# =============================================================================
# SECTION 8: LOSS FUNCTIONS AND METRICS
# =============================================================================

class RobustLoss(nn.Module):
    """Weighted sum of a heteroscedastic Gaussian NLL and an L1 term.

    ``loss = alpha * 0.5 * mean(exp(-log_var) * (pred - target)^2 + log_var)
           + beta * mean(|pred - target|)``
    """
    def __init__(self, alpha: float = 0.7, beta: float = 0.3):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.mse = nn.MSELoss()  # kept as an attribute; not used by forward
        self.mae = nn.L1Loss()

    def forward(self, pred_mean, pred_log_var, target):
        squared_error = (pred_mean - target).pow(2)
        # Errors are down-weighted where the predicted variance is large, and the
        # +log_var term penalises inflating the variance.
        weighted = squared_error * torch.exp(-pred_log_var) + pred_log_var
        nll = 0.5 * weighted.mean()
        return self.alpha * nll + self.beta * self.mae(pred_mean, target)

def laplace_log_likelihood(y_true, y_pred, sigma):
    """Mean modified Laplace log-likelihood used by the OSIC competition (higher is better).

    ``sigma`` is floored at 70 and the absolute error is capped at 1000 before
    computing ``-sqrt(2) * |err| / sigma - log(sqrt(2) * sigma)``.
    """
    root2 = np.sqrt(2)
    sigma_floor = np.maximum(sigma, 70)
    abs_err = np.minimum(np.abs(y_true - y_pred), 1000)
    per_sample = -(root2 * abs_err / sigma_floor) - np.log(root2 * sigma_floor)
    return np.mean(per_sample)

# =============================================================================
# SECTION 9: TRAINING PIPELINE
# =============================================================================

class EarlyStopping:
    """Stop training when a higher-is-better score stops improving.

    A score counts as an improvement only if it beats the best so far by more
    than ``min_delta``; each improvement snapshots the model's ``state_dict`` on
    CPU. After ``patience`` consecutive non-improving calls, ``early_stop`` is set
    and the snapshot is loaded back into the model.
    """
    def __init__(self, patience=10, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.best_score = -float('inf')
        self.counter = 0
        self.early_stop = False
        self.best_weights = None

    def __call__(self, score, model):
        if score > self.best_score + self.min_delta:
            self.best_score = score
            self.counter = 0
            self.best_weights = {
                name: tensor.cpu().clone() for name, tensor in model.state_dict().items()
            }
            return

        self.counter += 1
        if self.counter < self.patience:
            return

        self.early_stop = True
        if self.best_weights:
            model.load_state_dict(self.best_weights)
        logger.info(f"Early stopping triggered after {self.counter} epochs")

def train_epoch(model, loader, optimizer, criterion, scaler, config):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Forward and loss run under ``autocast`` when ``config.use_mixed_precision``
    is set. With a truthy ``scaler`` the backward pass goes through it, and
    gradients are unscaled before clipping to ``config.gradient_clip``.
    """
    model.train()
    batch_losses = []

    for batch in tqdm(loader, desc='Training'):
        images, tabular, targets = (
            batch[key].to(DEVICE) for key in ('image', 'tabular', 'target')
        )

        optimizer.zero_grad()

        with autocast(enabled=config.use_mixed_precision):
            predictions, log_var = model(images, tabular)
            loss = criterion(predictions, log_var, targets)

        if not scaler:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.gradient_clip)
            optimizer.step()
        else:
            scaler.scale(loss).backward()
            # Unscale first so the clip threshold applies to the true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.gradient_clip)
            scaler.step(optimizer)
            scaler.update()

        batch_losses.append(loss.item())

    return np.mean(batch_losses)

@torch.no_grad()
def validate_epoch(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns:
        ``(mean batch loss, metrics)`` where ``metrics`` has ``mae``, ``rmse``,
        ``r2`` and ``lll``. ``lll`` is the OSIC Laplace log-likelihood, using
        ``exp(0.5 * log_var)`` as the per-sample sigma.
    """
    model.eval()
    batch_losses = []
    preds, truths, sigmas = [], [], []

    for batch in tqdm(loader, desc='Validation'):
        images = batch['image'].to(DEVICE)
        tabular = batch['tabular'].to(DEVICE)
        batch_targets = batch['target'].to(DEVICE)

        predictions, log_var = model(images, tabular)
        batch_losses.append(criterion(predictions, log_var, batch_targets).item())

        preds.extend(predictions.cpu().numpy())
        truths.extend(batch_targets.cpu().numpy())
        sigmas.extend(torch.exp(0.5 * log_var).cpu().numpy())

    y_pred, y_true, sigma = (np.array(seq) for seq in (preds, truths, sigmas))

    metrics = {
        'mae': mean_absolute_error(y_true, y_pred),
        'rmse': np.sqrt(mean_squared_error(y_true, y_pred)),
        'r2': r2_score(y_true, y_pred),
        'lll': laplace_log_likelihood(y_true, y_pred, sigma),
    }
    return np.mean(batch_losses), metrics

def main():
    """Run the end-to-end OSIC multi-modal training pipeline.

    Steps:
      1. load the data and build the tabular feature engineer;
      2. split patients (not rows) into train/val/test;
      3. build datasets and loaders;
      4. train with early stopping on validation R², checkpointing the best
         epoch to ``config.output_dir / 'best_model.pth'``;
      5. restore the best-epoch weights and report test-set metrics;
      6. produce training plots and (best-effort) prediction analysis.

    Relies on module-level names defined elsewhere in the file (``config``,
    ``logger``, ``DEVICE``, ``sys``, ``gc``, the dataset/model/loss classes
    and the ``train_epoch`` / ``validate_epoch`` helpers).
    """
    print("="*80)
    print("OSIC MULTI-MODAL PIPELINE - COMPLETE IMPLEMENTATION")
    print("="*80)

    # Create output directory
    config.output_dir.mkdir(exist_ok=True, parents=True)

    # 1. Load and prepare data
    logger.info("Loading OSIC data...")
    train_df = load_osic_data()

    # 2. Create feature engineer
    # NOTE(review): the engineer is built from the full dataframe, i.e. before
    # the patient split below. If it computes any cross-patient statistics
    # (scalers, population means), val/test information leaks into training.
    # Confirm against TabularFeatureEngineer.
    logger.info("Creating feature engineer...")
    feature_engineer = TabularFeatureEngineer(train_df)

    # 3. Patient-level splits: all rows of a patient land in exactly one split,
    # so repeated visits of the same patient cannot leak across splits.
    patients = train_df['Patient'].unique()
    holdout_frac = config.val_split + config.test_split
    train_patients, temp_patients = train_test_split(
        patients,
        test_size=holdout_frac,
        random_state=config.seed
    )
    val_patients, test_patients = train_test_split(
        temp_patients,
        test_size=config.test_split / holdout_frac,
        random_state=config.seed
    )

    # Create datasets
    train_data = train_df[train_df['Patient'].isin(train_patients)]
    val_data = train_df[train_df['Patient'].isin(val_patients)]
    test_data = train_df[train_df['Patient'].isin(test_patients)]

    logger.info(f"Train samples: {len(train_data)} | Val: {len(val_data)} | Test: {len(test_data)}")

    # 4. Create transforms
    train_transform = get_transforms(is_train=True, img_size=config.img_size)
    val_transform = get_transforms(is_train=False, img_size=config.img_size)

    # 5. Create datasets (img_dir is None when no image folder is present)
    img_dir = config.data_dir / "train" if (config.data_dir / "train").exists() else None

    train_dataset = OSICMultiModalDataset(
        train_data, feature_engineer, img_dir, train_transform, config.n_slices, is_train=True
    )
    val_dataset = OSICMultiModalDataset(
        val_data, feature_engineer, img_dir, val_transform, config.n_slices, is_train=False
    )
    test_dataset = OSICMultiModalDataset(
        test_data, feature_engineer, img_dir, val_transform, config.n_slices, is_train=False
    )

    # 6. Create data loaders
    def _make_loader(dataset, shuffle, drop_last=False):
        """Build a DataLoader that shares the batch/worker settings from config."""
        return DataLoader(
            dataset,
            batch_size=config.batch_size,
            shuffle=shuffle,
            num_workers=config.num_workers,
            pin_memory=config.pin_memory,
            drop_last=drop_last
        )

    train_loader = _make_loader(train_dataset, shuffle=True, drop_last=True)
    val_loader = _make_loader(val_dataset, shuffle=False)
    test_loader = _make_loader(test_dataset, shuffle=False)

    # 7. Create model
    logger.info(f"Creating {config.backbone} model...")
    model = OSICMultiModalModel(
        # NOTE(review): hard-coded; must equal the number of features the
        # feature engineer emits. Consider deriving it from feature_engineer.
        tabular_dim=14,
        hidden_dims=config.tabular_hidden_dims,
        dropout_rate=config.dropout_rate,
        backbone=config.backbone,
        fusion_method=config.fusion_method
    ).to(DEVICE)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Total parameters: {total_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,}")

    # 8. Training setup
    criterion = RobustLoss()
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )

    if config.scheduler == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=10, T_mult=2, eta_min=1e-6
        )
    else:
        # mode='max' because the monitored metric (validation R²) is
        # higher-is-better. The deprecated `verbose` flag is not passed; the
        # current LR is printed each epoch instead.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=0.5, patience=5
        )

    early_stopper = EarlyStopping(patience=config.patience, min_delta=config.min_delta)
    scaler = GradScaler() if config.use_mixed_precision else None

    # 9. Training loop
    print("Starting training...")
    best_score = -float('inf')
    best_epoch = -1
    # CPU copy of the weights from the best validation epoch. The final test
    # evaluation must use these, not whatever the last (possibly
    # `patience`-epochs-worse) epoch left in the model.
    best_state = None
    train_history = {'train_loss': [], 'val_loss': [], 'val_r2': [], 'val_lll': []}

    for epoch in range(config.num_epochs):
        train_loss = train_epoch(model, train_loader, optimizer, criterion, scaler, config)
        val_loss, val_metrics = validate_epoch(model, val_loader, criterion)

        if config.scheduler == 'cosine':
            scheduler.step()
        else:
            scheduler.step(val_metrics['r2'])
        current_lr = optimizer.param_groups[0]['lr']

        # Print metrics for this epoch
        print(f"Epoch [{epoch+1}/{config.num_epochs}], "
              f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
              f"MAE: {val_metrics['mae']:.4f}, RMSE: {val_metrics['rmse']:.4f}, "
              f"R2: {val_metrics['r2']:.4f}, LLL: {val_metrics['lll']:.4f}, "
              f"LR: {current_lr:.2e}")
        sys.stdout.flush()

        train_history['train_loss'].append(train_loss)
        train_history['val_loss'].append(val_loss)
        train_history['val_r2'].append(val_metrics['r2'])
        train_history['val_lll'].append(val_metrics['lll'])

        early_stopper(val_metrics['r2'], model)

        if val_metrics['r2'] > best_score:
            best_score = val_metrics['r2']
            best_epoch = epoch
            best_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_score': best_score,
                'config': config
            }, config.output_dir / 'best_model.pth')

        if early_stopper.early_stop:
            print("Early stopping triggered.")
            break

        torch.cuda.empty_cache()
        gc.collect()

    # 10. Final evaluation on test set, using the best validation checkpoint
    # (the same weights that were written to best_model.pth).
    if best_state is not None:
        model.load_state_dict(best_state)
        logger.info(f"Restored best model from epoch {best_epoch + 1} (val R2={best_score:.4f})")
    test_loss, test_metrics = validate_epoch(model, test_loader, criterion)

    print("\n" + "="*80)
    print("FINAL TEST SET RESULTS")
    print("="*80)
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test MAE: {test_metrics['mae']:.4f}")
    print(f"Test RMSE: {test_metrics['rmse']:.4f}")
    print(f"Test R²: {test_metrics['r2']:.4f}")
    print(f"Test LLL: {test_metrics['lll']:.4f}")

    # 11. Create visualizations
    logger.info("Creating visualizations...")
    create_training_plots(train_history)

    # 12. Feature importance analysis (best-effort: a failure here must not
    # discard the results computed above).
    try:
        analyze_model_predictions(model, test_loader, test_data, feature_engineer)
    except Exception as e:
        logger.warning(f"Could not create prediction analysis: {e}")

if __name__ == "__main__":
    main()


| Goal Component                                                   |    Covered? | Evidence in current pipeline                                                                                        | Required Changes / Clarifications                                                                                                                                                                                                     |                 Priority | Next Action (code-level)                                                                                                                                            |
| ---------------------------------------------------------------- | ----------: | ------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -----------------------: | ------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **Target: early prediction of pulmonary fibrosis (progression)** | **Partial** | The pipeline predicts FVC (continuous) with uncertainty; FVC decline serves as a proxy for progression. | Confirm the *exact* early-prediction definition (e.g., future FVC at Δweeks, binary progression within X weeks, or a risk score). For a binary/early-warning target, add a classification head or apply a threshold to the predicted decline. | High | Decide the target form (continuous future FVC vs. binary event). Add target-generation code in `load_osic_data()` to create `target_week` / `delta_week` and the label. |
| **Time horizon / label generation** | No | The current code uses the FVC already recorded on each row as that row's target. | Create labels for *future* FVC (e.g., FVC at baseline + 12 weeks); this requires aligning scans and timestamps per patient. | High | Implement `build_future_target(df, horizon_weeks)` that pairs each baseline row with the follow-up FVC at `week + horizon`, using the nearest match within a tolerance window. |
| **DICOM CT inclusion (per-patient imaging)**                     |         Yes | `OSICMultiModalDataset` reads patient DICOM folders, selects slices, applies windowing and builds 3-channel inputs. | Consider selecting slices relative to baseline scan date used for the prediction; support 2D slice selection around lung region or small 3D patches.                                                                                  |                     High | In dataset `__getitem__`, select CTs corresponding to the *baseline* visit (match by scan date or nearest week). Add `scan_date` or mapping file.                   |
| **Tabular features for early prediction** | Yes | `TabularFeatureEngineer` computes baseline statistics, slope, expected FVC, and time features. | Ensure features are computed **only from data available at prediction time** (no leakage of future measurements). | Critical | In `get_features(patient, week)`, restrict computation to rows at or before `week`. Modify `_prepare_features` to store temporal sequences and compute features per week on demand. |
| **Model architecture (image + tabular fusion)**                  |         Yes | `OSICMultiModalModel` with attention fusion; uncertainty head included.                                             | Good choice. For early prediction consider adding sequence modeling (if using longitudinal inputs) or include baseline+follow-up branches.                                                                                            |                   Medium | If using multiple prior visits, add LSTM/temporal encoder for tabular sequences and fuse with image embedding.                                                      |
| **Loss / uncertainty modeling**                                  |         Yes | `RobustLoss` combines heteroscedastic NLL and MAE; LLL metric computed.                                             | If target becomes binary progression, add classification loss (CE/BCE) or multi-task loss (regression + classification).                                                                                                              | High (depends on target) | For regression: keep `RobustLoss`. For binary: add `BCEWithLogitsLoss` and weight multi-task losses.                                                                |
| **Evaluation & splits (grouped by patient)**                     |         Yes | Patient-level train/val/test splitting used.                                                                        | For early prediction evaluate at the *prediction horizon*; report time-dependent metrics (MAE, AUC if binary) and calibration of uncertainty (coverage).                                                                              |                 Critical | Implement `evaluate_at_horizon(loader, horizon)` and compute per-horizon metrics; save per-epoch metrics.                                                           |
| **Data leakage checks**                                          |     Partial | Group-split implemented but feature engineer may use future info.                                                   | Ensure feature engineering and image selection never use future target/visits.                                                                                                                                                        |                 Critical | Add unit tests/assertions that for each sample, engineered features and images correspond to time ≤ prediction time.                                                |
| **Augmentation & robustness**                                    |         Yes | Augmentations for slices present.                                                                                   | Consider domain-specific augmentations (intensity scaling, elastic deformations) and test sensitivity to slice selection.                                                                                                             |                   Medium | Add augmentation config toggles and an ablation routine to test slice-counts (`n_slices=1..5`).                                                                     |
| **Interpretability / clinical utility** | No | Diagnostic plots exist (predicted vs. true, residuals). | Add SHAP or feature importance for the tabular branch, and Grad-CAM or saliency maps for the image backbone, to explain predictions. | Medium | Compute permutation feature importance for the tabular inputs and implement Grad-CAM visualization for sample scans. |
| **Deployment / inference pipeline** | No | Training and saving of the best model exist. | An inference function is needed that accepts baseline tabular data + the baseline CT and outputs predicted risk/FVC + uncertainty, along with a human-readable report. | Low→Medium | Implement `predict_for_patient(patient_id, week)` that loads the preprocessing pipeline and the model, then returns outputs + visualizations. |
| **Next experiments to prioritize**                               |           — | —                                                                                                                   | 1) Define target horizon; 2) Fix temporal leakage; 3) Match CT to baseline visit; 4) Train & evaluate at horizon; 5) Add binary progression head if required.                                                                         |                        — | Implement `build_future_target()`, update dataset mapping to baseline CTs, and run one short experiment (2–5 epochs) printing per-epoch metrics.                    |
