# AI-Driven Early Prediction of Pulmonary Fibrosis Using Deep Learning


# Retrain Quantile MLP: 
Use the new, richer Master CSV (total $\approx 15 \text{ features}$) as input to the existing Quantile MLP structure. The increased dimensionality will allow the network to find correlations missed by the current shallow feature set.
## Ensemble Inference: 
Run the final prediction by averaging the 5 fold models (Out-Of-Fold) to generate the final, smoothest LLL/RMSE score.

In [3]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
import matplotlib.pyplot as plt
import warnings

warnings.filterwarnings('ignore')

# ==========================================
# 1. CONFIGURATION
# ==========================================
CONFIG = {
    "data_path": "/kaggle/input/feature-extraction-u-net-segmentation/master_dataset.csv",
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "n_folds": 5,
    "epochs": 100,
    "batch_size": 64,
    "learning_rate": 0.001,
    "weight_decay": 1e-4,
    "patience": 15,  # Early stopping
    "quantiles": [0.1, 0.5, 0.9],  # For LLL: q10, median, q90
    "seed": 42
}

# Set random seeds
torch.manual_seed(CONFIG['seed'])
np.random.seed(CONFIG['seed'])

# ==========================================
# 2. QUANTILE REGRESSION MLP
# ==========================================
class QuantileMLP(nn.Module):
    """
    Deep MLP for Quantile Regression
    Outputs 3 quantiles: [q10, q50 (median), q90]
    """
    def __init__(self, input_dim, hidden_dims=[256, 128, 64], n_quantiles=3, dropout=0.3):
        super().__init__()
        
        layers = []
        prev_dim = input_dim
        
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
            prev_dim = hidden_dim
        
        # Output layer: 3 quantile predictions
        layers.append(nn.Linear(prev_dim, n_quantiles))
        
        self.network = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.network(x)

# ==========================================
# 3. PINBALL LOSS (Quantile Loss)
# ==========================================
def quantile_loss(preds, target, quantiles):
    """
    Pinball Loss for Quantile Regression
    preds: (batch, 3) - predictions for [q10, q50, q90]
    target: (batch,) - true FVC values
    quantiles: [0.1, 0.5, 0.9]
    """
    target = target.unsqueeze(1)  # (batch, 1)
    errors = target - preds  # (batch, 3)
    
    quantiles_tensor = torch.tensor(quantiles, device=preds.device).unsqueeze(0)
    
    loss = torch.max(quantiles_tensor * errors, (quantiles_tensor - 1) * errors)
    return loss.mean()

# ==========================================
# 4. METRIC CALCULATIONS
# ==========================================
def calculate_metrics(y_true, y_pred_median, y_pred_q10, y_pred_q90):
    """
    Calculate comprehensive metrics:
    - R¬≤: Coefficient of determination
    - MSE: Mean Squared Error
    - RMSE: Root Mean Squared Error
    - MAE: Mean Absolute Error
    - RMAE: Relative Mean Absolute Error
    - LLL: Laplace Log Likelihood (Competition Metric)
    """
    y_true = np.array(y_true)
    y_pred_median = np.array(y_pred_median)
    y_pred_q10 = np.array(y_pred_q10)
    y_pred_q90 = np.array(y_pred_q90)
    
    # R¬≤
    ss_res = np.sum((y_true - y_pred_median) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    r2 = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0
    
    # MSE, RMSE, MAE
    mse = np.mean((y_true - y_pred_median) ** 2)
    rmse = np.sqrt(mse)
    mae = np.mean(np.abs(y_true - y_pred_median))
    
    # RMAE (Relative MAE)
    mean_true = np.mean(np.abs(y_true))
    rmae = mae / mean_true if mean_true > 0 else 0
    
    # LLL (Laplace Log Likelihood)
    # sigma = (q90 - q10) / 2.56  # Approximation for Laplace scale
    # For competition: sigma_clipped = max(70, sigma)
    sigma = np.maximum(70, (y_pred_q90 - y_pred_q10) / 2.56)
    delta = np.abs(y_true - y_pred_median)
    lll = -np.sqrt(2) * delta / sigma - np.log(np.sqrt(2) * sigma)
    lll_mean = np.mean(lll)
    
    return {
        'R2': r2,
        'MSE': mse,
        'RMSE': rmse,
        'MAE': mae,
        'RMAE': rmae,
        'LLL': lll_mean
    }

# ==========================================
# 5. DATASET CLASS
# ==========================================
class FibrosisFVCDataset(Dataset):
    def __init__(self, features, targets, weeks):
        self.features = torch.FloatTensor(features)
        self.targets = torch.FloatTensor(targets)
        self.weeks = torch.FloatTensor(weeks)
    
    def __len__(self):
        return len(self.features)
    
    def __getitem__(self, idx):
        return self.features[idx], self.targets[idx], self.weeks[idx]

# ==========================================
# 6. DATA PREPARATION
# ==========================================
def prepare_data(df):
    """
    Prepare features and targets from master dataset
    """
    print("üìä Preparing Data...")
    
    # Clinical/Demographic features
    clinical_features = ['Age', 'Sex', 'SmokingStatus', 'Weeks', 'Percent']
    
    # Image-derived biomarkers (from U-Net extraction)
    image_features = [
        'lung_vol_ml', 'hu_mean', 'hu_std', 'hu_skew', 'hu_kurt',
        'glcm_contrast', 'glcm_homogeneity', 'glcm_energy', 'glcm_correlation'
    ]
    
    # Encode categorical
    df['Sex'] = df['Sex'].map({'Male': 1, 'Female': 0})
    df['SmokingStatus'] = df['SmokingStatus'].map({
        'Never smoked': 0, 
        'Ex-smoker': 1, 
        'Currently smokes': 2
    })
    
    # Target: FVC value prediction (not slope!)
    all_features = clinical_features + image_features
    
    # Remove rows with missing features
    df_clean = df[all_features + ['FVC']].dropna()
    
    X = df_clean[all_features].values
    y = df_clean['FVC'].values
    weeks = df_clean['Weeks'].values
    
    print(f"‚úÖ Dataset Shape: {X.shape}")
    print(f"‚úÖ Features ({len(all_features)}): {all_features}")
    print(f"‚úÖ Target Range: FVC [{y.min():.0f}, {y.max():.0f}]")
    print(f"‚úÖ Samples: {len(X)}\n")
    
    return X, y, weeks, all_features

# ==========================================
# 7. TRAINING FUNCTION
# ==========================================
def train_epoch(model, loader, optimizer, quantiles, device):
    model.train()
    total_loss = 0
    
    for features, targets, _ in loader:
        features, targets = features.to(device), targets.to(device)
        
        optimizer.zero_grad()
        preds = model(features)
        loss = quantile_loss(preds, targets, quantiles)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(loader)

def evaluate_epoch(model, loader, quantiles, device):
    model.eval()
    all_preds = []
    all_targets = []
    
    with torch.no_grad():
        for features, targets, _ in loader:
            features, targets = features.to(device), targets.to(device)
            preds = model(features)
            
            all_preds.append(preds.cpu().numpy())
            all_targets.append(targets.cpu().numpy())
    
    all_preds = np.vstack(all_preds)  # (N, 3) -> [q10, q50, q90]
    all_targets = np.concatenate(all_targets)
    
    metrics = calculate_metrics(
        all_targets,
        all_preds[:, 1],  # median (q50)
        all_preds[:, 0],  # q10
        all_preds[:, 2]   # q90
    )
    
    return metrics, all_preds

# ==========================================
# 8. MAIN TRAINING LOOP (5-FOLD CV)
# ==========================================
def train_quantile_mlp():
    print("="*70)
    print("üöÄ PHASE 2: QUANTILE MLP TRAINING (5-FOLD ENSEMBLE)")
    print("="*70)
    print(f"Device: {CONFIG['device']}\n")
    
    # Load data
    df = pd.read_csv(CONFIG['data_path'])
    X, y, weeks, feature_names = prepare_data(df)
    
    # Initialize KFold
    kfold = KFold(n_splits=CONFIG['n_folds'], shuffle=True, random_state=CONFIG['seed'])
    
    # Storage for fold results
    fold_results = []
    fold_models = []
    oof_predictions = np.zeros((len(X), 3))  # Out-of-fold predictions
    
    # Training loop
    for fold, (train_idx, val_idx) in enumerate(kfold.split(X), 1):
        print(f"\n{'='*70}")
        print(f"üìÇ FOLD {fold}/{CONFIG['n_folds']}")
        print(f"{'='*70}")
        
        # Split data
        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]
        weeks_train, weeks_val = weeks[train_idx], weeks[val_idx]
        
        # Scale features
        scaler = StandardScaler()
        X_train_scaled = scaler.fit_transform(X_train)
        X_val_scaled = scaler.transform(X_val)
        
        # Create datasets
        train_dataset = FibrosisFVCDataset(X_train_scaled, y_train, weeks_train)
        val_dataset = FibrosisFVCDataset(X_val_scaled, y_val, weeks_val)
        
        train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False)
        
        # Initialize model
        model = QuantileMLP(
            input_dim=X.shape[1],
            hidden_dims=[256, 128, 64],
            n_quantiles=3,
            dropout=0.3
        ).to(CONFIG['device'])
        
        optimizer = optim.AdamW(
            model.parameters(),
            lr=CONFIG['learning_rate'],
            weight_decay=CONFIG['weight_decay']
        )
        
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=0.5, patience=5, verbose=False
        )
        
        # Training
        best_lll = -np.inf
        patience_counter = 0
        
        print(f"\n{'Epoch':<6} {'Train Loss':<12} {'R¬≤':<8} {'MSE':<10} {'RMSE':<8} {'MAE':<8} {'RMAE':<8} {'LLL':<8}")
        print("-" * 70)
        
        for epoch in range(1, CONFIG['epochs'] + 1):
            train_loss = train_epoch(model, train_loader, optimizer, CONFIG['quantiles'], CONFIG['device'])
            val_metrics, _ = evaluate_epoch(model, val_loader, CONFIG['quantiles'], CONFIG['device'])
            
            # Print metrics
            print(f"{epoch:<6} {train_loss:<12.4f} {val_metrics['R2']:<8.4f} "
                  f"{val_metrics['MSE']:<10.2f} {val_metrics['RMSE']:<8.2f} "
                  f"{val_metrics['MAE']:<8.2f} {val_metrics['RMAE']:<8.4f} "
                  f"{val_metrics['LLL']:<8.4f}")
            
            # Scheduler step
            scheduler.step(val_metrics['LLL'])
            
            # Early stopping based on LLL
            if val_metrics['LLL'] > best_lll:
                best_lll = val_metrics['LLL']
                patience_counter = 0
                # Save best model for this fold
                torch.save(model.state_dict(), f'best_model_fold{fold}.pth')
            else:
                patience_counter += 1
            
            if patience_counter >= CONFIG['patience']:
                print(f"‚ö†Ô∏è  Early stopping at epoch {epoch}")
                break
        
        # Load best model and evaluate
        model.load_state_dict(torch.load(f'best_model_fold{fold}.pth'))
        final_metrics, val_preds = evaluate_epoch(model, val_loader, CONFIG['quantiles'], CONFIG['device'])
        
        # Store OOF predictions
        oof_predictions[val_idx] = val_preds
        
        # Store results
        fold_results.append(final_metrics)
        fold_models.append((model, scaler))
        
        print(f"\n‚úÖ FOLD {fold} FINAL METRICS:")
        print(f"   R¬≤ = {final_metrics['R2']:.4f}")
        print(f"   RMSE = {final_metrics['RMSE']:.2f} mL")
        print(f"   LLL = {final_metrics['LLL']:.4f}")
    
    # ==========================================
    # 9. ENSEMBLE RESULTS
    # ==========================================
    print(f"\n{'='*70}")
    print("üèÜ FINAL ENSEMBLE RESULTS (5-FOLD AVERAGE)")
    print(f"{'='*70}")
    
    # Calculate OOF ensemble metrics
    oof_metrics = calculate_metrics(
        y,
        oof_predictions[:, 1],  # median
        oof_predictions[:, 0],  # q10
        oof_predictions[:, 2]   # q90
    )
    
    print(f"\nüìä Out-of-Fold (OOF) Ensemble Performance:")
    print(f"   R¬≤    = {oof_metrics['R2']:.4f}")
    print(f"   MSE   = {oof_metrics['MSE']:.2f}")
    print(f"   RMSE  = {oof_metrics['RMSE']:.2f} mL")
    print(f"   MAE   = {oof_metrics['MAE']:.2f} mL")
    print(f"   RMAE  = {oof_metrics['RMAE']:.4f}")
    print(f"   LLL   = {oof_metrics['LLL']:.4f}")
    
    # Benchmark comparison
    print(f"\nüéØ Benchmark Comparison:")
    print(f"   Target RMSE: < 170 mL  ‚Üí  {'‚úÖ PASSED' if oof_metrics['RMSE'] < 170 else '‚ùå NEEDS IMPROVEMENT'}")
    print(f"   Target LLL:  > -6.64   ‚Üí  {'‚úÖ PASSED' if oof_metrics['LLL'] > -6.64 else '‚ùå NEEDS IMPROVEMENT'}")
    
    # Average fold metrics
    avg_metrics = {k: np.mean([f[k] for f in fold_results]) for k in fold_results[0].keys()}
    print(f"\nüìà Average Across Folds:")
    for metric, value in avg_metrics.items():
        print(f"   {metric:<6} = {value:.4f}")
    
    return fold_models, oof_predictions, oof_metrics

# ==========================================
# 10. RUN TRAINING
# ==========================================
if __name__ == "__main__":
    models, oof_preds, final_metrics = train_quantile_mlp()
    print("\n‚úÖ Training Complete! Models saved as 'best_model_fold{1-5}.pth'")

üöÄ PHASE 2: QUANTILE MLP TRAINING (5-FOLD ENSEMBLE)
Device: cuda

üìä Preparing Data...
‚úÖ Dataset Shape: (1549, 14)
‚úÖ Features (14): ['Age', 'Sex', 'SmokingStatus', 'Weeks', 'Percent', 'lung_vol_ml', 'hu_mean', 'hu_std', 'hu_skew', 'hu_kurt', 'glcm_contrast', 'glcm_homogeneity', 'glcm_energy', 'glcm_correlation']
‚úÖ Target Range: FVC [827, 6399]
‚úÖ Samples: 1549


üìÇ FOLD 1/5

Epoch  Train Loss   R¬≤       MSE        RMSE     MAE      RMAE     LLL     
----------------------------------------------------------------------
1      1349.4300    -11.9587 7793084.50 2791.61  2681.74  0.9998   -58.7744
2      1343.6302    -11.9528 7789507.00 2790.97  2681.07  0.9996   -58.7609
3      1343.5650    -11.9473 7786223.50 2790.38  2680.46  0.9993   -58.7485
4      1347.6377    -11.9404 7782072.00 2789.64  2679.68  0.9990   -58.7328
5      1346.6469    -11.9324 7777280.00 2788.78  2678.79  0.9987   -58.7148
6      1343.3291    -11.9222 7771157.50 2787.68  2677.62  0.9983   -58.6912
7    