Progresive Training

**mae based**

In [8]:
# Imports and environment setup
import os
import random
import numpy as np
import pandas as pd
from pathlib import Path
from datetime import datetime
from tqdm import tqdm
import cv2
import pydicom
import albumentations as albu
from albumentations.pytorch import ToTensorV2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
from sklearn.model_selection import train_test_split

# Set seeds for reproducibility
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
seed_everything()

# Paths and device setup
DATA_DIR = Path("../input/osic-pulmonary-fibrosis-progression")
TRAIN_DIR = DATA_DIR / "train"
TEST_DIR = DATA_DIR / "test"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# Load train.csv
train_df = pd.read_csv(DATA_DIR / "train.csv")
print(f"Loaded train dataset: {train_df.shape}")

# Tabular feature extraction function
def get_tab_features(row):
    age_scaled = (row['Age'] - 30) / 30
    sex_encoded = 0 if row['Sex'] == 'Male' else 1
    smoking = row['SmokingStatus']
    smoke_map = {
        'Never smoked': [0,0],
        'Ex-smoker': [1,1],
        'Currently smokes': [0,1]
    }
    smoke_vec = smoke_map.get(smoking, [1,0])
    return np.array([age_scaled, sex_encoded, smoke_vec[0], smoke_vec[1]])

# Compute linear decay coefficients per patient
A = {}
TAB = {}
P = []
print("Calculating linear decay coefficients per patient...")
for patient in tqdm(train_df['Patient'].unique()):
    sub = train_df[train_df['Patient'] == patient].sort_values('Weeks')
    fvc = sub['FVC'].values
    weeks = sub['Weeks'].values
    if len(weeks) > 1:
        c = np.vstack([weeks, np.ones(len(weeks))]).T
        try:
            a, _ = np.linalg.lstsq(c, fvc, rcond=None)[0]
            A[patient] = a
        except:
            A[patient] = (fvc[-1] - fvc[0]) / (weeks[-1] - weeks[0])
    else:
        A[patient] = 0.0
    TAB[patient] = get_tab_features(sub.iloc[0])
    P.append(patient)
print(f"Processed {len(P)} patients")

# Medical Image Augmentations
class MedicalAugmentation:
    def __init__(self, augment=True):
        if augment:
            self.transform = albu.Compose([
                albu.Rotate(limit=15, p=0.7),
                albu.HorizontalFlip(p=0.5),
                albu.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=15, p=0.7),
                albu.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.7),
                albu.GaussNoise(var_limit=(10, 50), p=0.5),
                albu.RandomGamma(gamma_limit=(80, 120), p=0.5),
                albu.GridDistortion(num_steps=5, distort_limit=0.3, p=0.3),
                albu.OpticalDistortion(distort_limit=0.3, shift_limit=0.3, p=0.3),
                albu.CoarseDropout(max_holes=8, max_height=32, max_width=32, p=0.3),
                albu.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229,0.224,0.225]),
                ToTensorV2()
            ])
        else:
            self.transform = albu.Compose([
                albu.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229,0.224,0.225]),
                ToTensorV2()
            ])
    def __call__(self, image):
        return self.transform(image=image)['image']

# Dataset class
class OSICDenseNetDataset(Dataset):
    def __init__(self, patients, A_dict, TAB_dict, data_dir, split='train', augment=True):
        self.patients = [p for p in patients if p not in ['ID00011637202177653955184', 'ID00052637202186188008618']]
        self.A_dict = A_dict
        self.TAB_dict = TAB_dict
        self.data_dir = Path(data_dir)
        self.split = split
        self.augmentor = MedicalAugmentation(augment=augment)
        self.patient_images = {}
        for p in self.patients:
            p_dir = self.data_dir / p
            if p_dir.exists():
                images = [f for f in p_dir.iterdir() if f.suffix.lower() == '.dcm']
                if images:
                    self.patient_images[p] = images
        self.valid_patients = [p for p in self.patients if p in self.patient_images]
        print(f"Dataset {split}: {len(self.valid_patients)} patients with images")
    def __len__(self):
        if self.split == 'train':
            return len(self.valid_patients) * 6  # augment repeat
        else:
            return len(self.valid_patients)
    def __getitem__(self, idx):
        patient_idx = idx % len(self.valid_patients) if self.split == 'train' else idx
        patient = self.valid_patients[patient_idx]
        images = self.patient_images[patient]
        img_path = np.random.choice(images) if len(images) > 1 else images[0]
        img = self.load_and_preprocess_dicom(img_path)
        img_tensor = self.augmentor(img)
        tab = torch.tensor(self.TAB_dict[patient], dtype=torch.float32)
        target = torch.tensor(self.A_dict[patient], dtype=torch.float32)
        return img_tensor, tab, target, patient
    def load_and_preprocess_dicom(self, path):
        try:
            dcm = pydicom.dcmread(str(path))
            img = dcm.pixel_array.astype(np.float32)
            if len(img.shape) == 3:
                img = img[img.shape[0] // 2]
            img = cv2.resize(img, (512, 512))
            img_min, img_max = img.min(), img.max()
            if img_max > img_min:
                img = (img - img_min) / (img_max - img_min) * 255
            else:
                img = np.zeros_like(img)
            img = np.stack([img, img, img], axis=2).astype(np.uint8)
            return img
        except Exception as e:
            print(f"Error loading DICOM {path}: {e}")
            return np.zeros((512, 512, 3), dtype=np.uint8)

# Spatial Attention Module
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        attn = torch.cat([avg_out, max_out], dim=1)
        attn = self.conv1(attn)
        attn = self.sigmoid(attn)
        return x * attn   # ✅ multiply with original features, not with reduced map


# Model with spatial attention + cross-modal attention + quantile regression output
class WorkingDenseNetModel(nn.Module):
    def __init__(self, tabular_dim=4, dropout_rate=0.4):
        super().__init__()
        densenet = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
        self.features = densenet.features
        self.spatial_attention = SpatialAttention()
        self.tabular_processor = nn.Sequential(
            nn.Linear(tabular_dim, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU()
        )
        embed_dim = 1024
        self.query_proj = nn.Linear(embed_dim, embed_dim)
        self.key_proj = nn.Linear(512, embed_dim)
        self.value_proj = nn.Linear(512, embed_dim)
        self.cross_attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=8, dropout=0.2, batch_first=True)
        self.fusion_layer = nn.Sequential(
            nn.Linear(embed_dim + 512, 768),
            nn.BatchNorm1d(768),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(768, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout_rate / 2)
        )
        # Main prediction head (mean FVC)
        self.mean_head = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
        # Quantile regression heads
        self.lower_quantile_head = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
        self.upper_quantile_head = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
    def forward(self, images, tabular):
        batch_size = images.size(0)
        img_features = self.features(images)
        img_features = self.spatial_attention(img_features)
        img_features = F.adaptive_avg_pool2d(img_features, (1, 1)).view(batch_size, -1)  # (B, 1024)
        tab_features = self.tabular_processor(tabular)  # (B, 512)
        queries = self.query_proj(img_features).unsqueeze(1)  # (B, 1, 1024)
        keys = self.key_proj(tab_features).unsqueeze(1)       # (B, 1, 1024)
        values = self.value_proj(tab_features).unsqueeze(1)   # (B, 1, 1024)
        attended, _ = self.cross_attention(query=queries, key=keys, value=values)
        attended = attended.squeeze(1)  # (B, 1024)
        combined = torch.cat([attended, tab_features], dim=1)  # (B, 1536)
        fused = self.fusion_layer(combined)                    # (B, 256)
        mean_fvc = self.mean_head(fused).squeeze(1)            # (B,)
        lower_q = self.lower_quantile_head(fused).squeeze(1)   # (B,)
        upper_q = self.upper_quantile_head(fused).squeeze(1)   # (B,)
        return mean_fvc, lower_q, upper_q

# Quantile loss implementation
def quantile_loss(predictions, targets, quantile):
    errors = targets - predictions
    return torch.mean(torch.max(quantile * errors, (quantile - 1) * errors))

# Combined loss for training
def combined_loss(mean_pred, lower_pred, upper_pred, target):
    mse = F.mse_loss(mean_pred, target)
    lower_loss = quantile_loss(lower_pred, target, 0.2)
    upper_loss = quantile_loss(upper_pred, target, 0.8)
    return mse + lower_loss + upper_loss

# R² score implementation in PyTorch
def r2_score_torch(y_pred, y_true):
    ss_res = torch.sum((y_true - y_pred) ** 2)
    ss_tot = torch.sum((y_true - torch.mean(y_true)) ** 2)
    return 1 - ss_res / (ss_tot + 1e-8)

# Dataset splitting for training/validation
train_patients, val_patients = train_test_split(P, test_size=0.2, random_state=42)

# Instantiate datasets & loaders
train_dataset = OSICDenseNetDataset(train_patients, A, TAB, TRAIN_DIR, split='train', augment=True)
val_dataset = OSICDenseNetDataset(val_patients, A, TAB, TRAIN_DIR, split='val', augment=False)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=2, pin_memory=True)
print(f"Training with {len(train_loader)} batches, validating with {len(val_loader)} batches")

# Training loop class
class Trainer:
    def __init__(self, model, device, learning_rate=1e-4):
        self.model = model
        self.device = device
        self.lr = learning_rate
        self.best_val_mae = float('inf')
    def train(self, train_loader, val_loader, epochs=30, patience=8):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=4, verbose=True)
        patience_counter = 0
        for epoch in range(epochs):
            self.model.train()
            train_losses, train_maes, train_r2s = [], [], []
            for images, tabular, targets, _ in train_loader:
                images, tabular, targets = images.to(self.device), tabular.to(self.device), targets.to(self.device)
                optimizer.zero_grad()
                mean_pred, lower_pred, upper_pred = self.model(images, tabular)
                loss = combined_loss(mean_pred, lower_pred, upper_pred, targets)
                mae = F.l1_loss(mean_pred, targets)
                r2 = r2_score_torch(mean_pred, targets)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                optimizer.step()
                train_losses.append(loss.item())
                train_maes.append(mae.item())
                train_r2s.append(r2.item())
            avg_train_loss = np.mean(train_losses)
            avg_train_mae = np.mean(train_maes)
            avg_train_r2 = np.mean(train_r2s)
            # Validation
            self.model.eval()
            val_losses, val_maes, val_r2s = [], [], []
            with torch.no_grad():
                for images, tabular, targets, _ in val_loader:
                    images, tabular, targets = images.to(self.device), tabular.to(self.device), targets.to(self.device)
                    mean_pred, lower_pred, upper_pred = self.model(images, tabular)
                    loss = combined_loss(mean_pred, lower_pred, upper_pred, targets)
                    mae = F.l1_loss(mean_pred, targets)
                    r2 = r2_score_torch(mean_pred, targets)
                    val_losses.append(loss.item())
                    val_maes.append(mae.item())
                    val_r2s.append(r2.item())
            avg_val_loss = np.mean(val_losses)
            avg_val_mae = np.mean(val_maes)
            avg_val_r2 = np.mean(val_r2s)
            print(f"Epoch {epoch+1}/{epochs} - "
                  f"Train loss: {avg_train_loss:.4f}, MAE: {avg_train_mae:.4f}, R²: {avg_train_r2:.4f} | "
                  f"Val loss: {avg_val_loss:.4f}, MAE: {avg_val_mae:.4f}, R²: {avg_val_r2:.4f}")
            scheduler.step(avg_val_mae)
            if avg_val_mae < self.best_val_mae:
                self.best_val_mae = avg_val_mae
                torch.save(self.model.state_dict(), "best_working_model.pth")
                print("Best model saved!")
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    break
        return self.best_val_mae

# Instantiate model and trainer
model = WorkingDenseNetModel(tabular_dim=4).to(DEVICE)
trainer = Trainer(model, DEVICE, learning_rate=1e-4)

# Start training
best_mae = trainer.train(train_loader, val_loader, epochs=30, patience=8)
print(f"Training complete with best validation MAE: {best_mae:.4f}")

# Test-time augmentation predictor
class TTAPredictor:
    def __init__(self, model, augmentor, device, num_augmentations=5):
        self.model = model
        self.augmentor = augmentor
        self.device = device
        self.num_augmentations = num_augmentations
        self.model.eval()
    def predict(self, image, tabular_features):
        means, lowers, uppers = [], [], []
        with torch.no_grad():
            img_tensor = self.augmentor(image=image)['image'].to(self.device)
            tab_tensor = torch.tensor(tabular_features).float().unsqueeze(0).to(self.device)
            m, l, u = self.model(img_tensor.unsqueeze(0), tab_tensor)
            means.append(m.item())
            lowers.append(l.item())
            uppers.append(u.item())
            for _ in range(self.num_augmentations):
                aug_img = self.augmentor(image=image)['image'].to(self.device)
                m, l, u = self.model(aug_img.unsqueeze(0), tab_tensor)
                means.append(m.item())
               


Using device: cuda
Loaded train dataset: (1549, 7)
Calculating linear decay coefficients per patient...


100%|██████████| 176/176 [00:00<00:00, 1352.89it/s]

Processed 176 patients





Dataset train: 138 patients with images
Dataset val: 36 patients with images
Training with 103 batches, validating with 5 batches
Epoch 1/30 - Train loss: 51.0877, MAE: 5.0442, R²: -0.9643 | Val loss: 68.2587, MAE: 5.7449, R²: -0.5497
Best model saved!
Epoch 2/30 - Train loss: 38.4311, MAE: 4.2857, R²: -0.3855 | Val loss: 56.0758, MAE: 5.0694, R²: -0.1019
Best model saved!
Epoch 3/30 - Train loss: 33.4124, MAE: 4.0172, R²: -0.1599 | Val loss: 52.8034, MAE: 4.9098, R²: -0.0648
Best model saved!
Epoch 4/30 - Train loss: 32.6211, MAE: 3.9740, R²: -0.2348 | Val loss: 54.1547, MAE: 4.9816, R²: -0.0732
Epoch 5/30 - Train loss: 32.4139, MAE: 4.0162, R²: -0.1434 | Val loss: 53.8372, MAE: 4.9236, R²: -0.0445
Epoch 6/30 - Train loss: 32.5289, MAE: 4.0367, R²: -0.2262 | Val loss: 56.7123, MAE: 5.0993, R²: -0.0964
Epoch 7/30 - Train loss: 32.4337, MAE: 4.0171, R²: -0.1252 | Val loss: 53.8248, MAE: 4.9535, R²: -0.0692
Epoch 8/30 - Train loss: 32.1055, MAE: 3.9782, R²: -0.1686 | Val loss: 55.7477, M

**LLL based**
Just a small typo in the end

In [5]:
# Imports and environment setup
import os
import random
import numpy as np
import pandas as pd
from pathlib import Path
from datetime import datetime
from tqdm import tqdm

import cv2
import pydicom
import albumentations as albu
from albumentations.pytorch import ToTensorV2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models

# Set seeds for reproducibility
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything()

# Paths and device setup
DATA_DIR = Path("../input/osic-pulmonary-fibrosis-progression")
TRAIN_DIR = DATA_DIR / "train"
TEST_DIR = DATA_DIR / "test"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# Load train.csv
train_df = pd.read_csv(DATA_DIR / "train.csv")
print(f"Loaded train dataset: {train_df.shape}")

# Tabular feature extraction function
def get_tab_features(row):
    age_scaled = (row['Age'] - 30) / 30
    sex_encoded = 0 if row['Sex'] == 'Male' else 1
    smoking = row['SmokingStatus']
    smoke_map = {
        'Never smoked': [0,0],
        'Ex-smoker': [1,1],
        'Currently smokes': [0,1]
    }
    smoke_vec = smoke_map.get(smoking, [1,0])
    return np.array([age_scaled, sex_encoded, smoke_vec[0], smoke_vec[1]])

# Compute linear decay coefficients per patient
A = {}
TAB = {}
P = []

print("Calculating linear decay coefficients per patient...")
for patient in tqdm(train_df['Patient'].unique()):
    sub = train_df[train_df['Patient'] == patient].sort_values('Weeks')
    fvc = sub['FVC'].values
    weeks = sub['Weeks'].values

    if len(weeks) > 1:
        c = np.vstack([weeks, np.ones(len(weeks))]).T
        try:
            a, _ = np.linalg.lstsq(c, fvc, rcond=None)[0]
            A[patient] = a
        except:
            A[patient] = (fvc[-1] - fvc[0]) / (weeks[-1] - weeks[0])
    else:
        A[patient] = 0.0
    TAB[patient] = get_tab_features(sub.iloc[0])
    P.append(patient)

print(f"Processed {len(P)} patients")

# Medical Image Augmentations
class MedicalAugmentation:
    def __init__(self, augment=True):
        if augment:
            self.transform = albu.Compose([
                albu.Rotate(limit=15, p=0.7),
                albu.HorizontalFlip(p=0.5),
                albu.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=15, p=0.7),
                albu.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.7),
                albu.GaussNoise(var_limit=(10, 50), p=0.5),
                albu.RandomGamma(gamma_limit=(80, 120), p=0.5),
                albu.GridDistortion(num_steps=5, distort_limit=0.3, p=0.3),
                albu.OpticalDistortion(distort_limit=0.3, shift_limit=0.3, p=0.3),
                albu.CoarseDropout(max_holes=8, max_height=32, max_width=32, p=0.3),
                albu.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229,0.224,0.225]),
                ToTensorV2()
            ])
        else:
            self.transform = albu.Compose([
                albu.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229,0.224,0.225]),
                ToTensorV2()
            ])

    def __call__(self, image):
        return self.transform(image=image)['image']

# Dataset class
class OSICDenseNetDataset(Dataset):
    def __init__(self, patients, A_dict, TAB_dict, data_dir, split='train', augment=True):
        self.patients = [p for p in patients if p not in ['ID00011637202177653955184', 'ID00052637202186188008618']]
        self.A_dict = A_dict
        self.TAB_dict = TAB_dict
        self.data_dir = Path(data_dir)
        self.split = split
        self.augmentor = MedicalAugmentation(augment=augment)
        self.patient_images = {}
        for p in self.patients:
            p_dir = self.data_dir / p
            if p_dir.exists():
                images = [f for f in p_dir.iterdir() if f.suffix.lower() == '.dcm']
                if images:
                    self.patient_images[p] = images
        self.valid_patients = [p for p in self.patients if p in self.patient_images]
        print(f"Dataset {split}: {len(self.valid_patients)} patients with images")

    def __len__(self):
        if self.split == 'train':
            return len(self.valid_patients) * 6  # augment repeat
        else:
            return len(self.valid_patients)

    def __getitem__(self, idx):
        patient_idx = idx % len(self.valid_patients) if self.split == 'train' else idx
        patient = self.valid_patients[patient_idx]
        images = self.patient_images[patient]
        img_path = np.random.choice(images) if len(images) > 1 else images[0]
        img = self.load_and_preprocess_dicom(img_path)
        img_tensor = self.augmentor(img)
        tab = torch.tensor(self.TAB_dict[patient], dtype=torch.float32)
        target = torch.tensor(self.A_dict[patient], dtype=torch.float32)
        return img_tensor, tab, target, patient

    def load_and_preprocess_dicom(self, path):
        try:
            dcm = pydicom.dcmread(str(path))
            img = dcm.pixel_array.astype(np.float32)
            if len(img.shape) == 3:
                img = img[img.shape[0] // 2]
            img = cv2.resize(img, (512, 512))
            img_min, img_max = img.min(), img.max()
            if img_max > img_min:
                img = (img - img_min) / (img_max - img_min) * 255
            else:
                img = np.zeros_like(img)
            img = np.stack([img, img, img], axis=2).astype(np.uint8)
            return img
        except Exception as e:
            print(f"Error loading DICOM {path}: {e}")
            return np.zeros((512, 512, 3), dtype=np.uint8)

# Model with spatial attention + cross-modal attention + quantile regression output
class WorkingDenseNetModel(nn.Module):
    def __init__(self, tabular_dim=4, dropout_rate=0.4):
        super().__init__()
        densenet = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
        self.features = densenet.features
        self.spatial_attention = SpatialAttention()

        self.tabular_processor = nn.Sequential(
            nn.Linear(tabular_dim, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU()
        )

        embed_dim = 1024
        self.query_proj = nn.Linear(embed_dim, embed_dim)
        self.key_proj = nn.Linear(512, embed_dim)
        self.value_proj = nn.Linear(512, embed_dim)

        self.cross_attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=8, dropout=0.2, batch_first=True)

        self.fusion_layer = nn.Sequential(
            nn.Linear(embed_dim + 512, 768),
            nn.BatchNorm1d(768),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(768, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout_rate / 2)
        )

        # Main prediction head (mean FVC)
        self.mean_head = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

        # Quantile regression heads
        self.lower_quantile_head = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
        self.upper_quantile_head = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, images, tabular):
        batch_size = images.size(0)
        img_features = self.features(images)
        img_features = self.spatial_attention(img_features)
        img_features = F.adaptive_avg_pool2d(img_features, (1, 1)).view(batch_size, -1)  # (B, 1024)

        tab_features = self.tabular_processor(tabular)  # (B, 512)

        queries = self.query_proj(img_features).unsqueeze(1)  # (B, 1, 1024)
        keys = self.key_proj(tab_features).unsqueeze(1)       # (B, 1, 1024)
        values = self.value_proj(tab_features).unsqueeze(1)   # (B, 1, 1024)

        attended, _ = self.cross_attention(query=queries, key=keys, value=values)
        attended = attended.squeeze(1)  # (B, 1024)

        combined = torch.cat([attended, tab_features], dim=1)  # (B, 1536)
        fused = self.fusion_layer(combined)                    # (B, 256)

        mean_fvc = self.mean_head(fused).squeeze(1)            # (B,)
        lower_q = self.lower_quantile_head(fused).squeeze(1)   # (B,)
        upper_q = self.upper_quantile_head(fused).squeeze(1)   # (B,)

        return mean_fvc, lower_q, upper_q

# Quantile loss implementation
def quantile_loss(predictions, targets, quantile):
    errors = targets - predictions
    return torch.mean(torch.max(quantile * errors, (quantile - 1) * errors))

# Combined loss for training
def combined_loss(mean_pred, lower_pred, upper_pred, target):
    mse = F.mse_loss(mean_pred, target)
    lower_loss = quantile_loss(lower_pred, target, 0.2)
    upper_loss = quantile_loss(upper_pred, target, 0.8)
    return mse + lower_loss + upper_loss

# Dataset splitting for training/validation
train_patients, val_patients = train_test_split(P, test_size=0.2, random_state=42)

# Instantiate datasets & loaders
train_dataset = OSICDenseNetDataset(train_patients, A, TAB, TRAIN_DIR, split='train', augment=True)
val_dataset = OSICDenseNetDataset(val_patients, A, TAB, TRAIN_DIR, split='val', augment=False)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=2, pin_memory=True)

print(f"Training with {len(train_loader)} batches, validating with {len(val_loader)} batches")

# Training loop class
class Trainer:
    def __init__(self, model, device, learning_rate=1e-4):
        self.model = model
        self.device = device
        self.lr = learning_rate
        self.best_val_lll = -float('inf')  # Track best (max) LLL

    def laplace_log_likelihood_metric(self, pred_mean, pred_lower, pred_upper, targets):
        # Approximate log variance from quantile upper/lower bounds, avoiding direct use of predicted variance
        sigma = (pred_upper - pred_lower) / 2
        sigma = np.clip(sigma, 1e-3, 1e3)
        delta = np.abs(pred_mean - targets)
        score = -np.log(2 * sigma) - delta / sigma
        return np.mean(score)

    def train(self, train_loader, val_loader, epochs=30, patience=8):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=4, verbose=True)
        patience_counter = 0

        for epoch in range(epochs):
            self.model.train()
            train_losses, train_maes = [], []

            for images, tabular, targets, _ in train_loader:
                images, tabular, targets = images.to(self.device), tabular.to(self.device), targets.to(self.device)
                optimizer.zero_grad()
                mean_pred, lower_pred, upper_pred = self.model(images, tabular)
                loss = combined_loss(mean_pred, lower_pred, upper_pred, targets)
                mae = F.l1_loss(mean_pred, targets)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                optimizer.step()
                train_losses.append(loss.item())
                train_maes.append(mae.item())

            avg_train_loss = np.mean(train_losses)
            avg_train_mae = np.mean(train_maes)

            # Validation
            self.model.eval()
            val_losses, val_maes = [], []
            val_pred_means = []
            val_pred_lowers = []
            val_pred_uppers = []
            val_targets_list = []

            with torch.no_grad():
                for images, tabular, targets, _ in val_loader:
                    images, tabular, targets = images.to(self.device), tabular.to(self.device), targets.to(self.device)
                    mean_pred, lower_pred, upper_pred = self.model(images, tabular)
                    loss = combined_loss(mean_pred, lower_pred, upper_pred, targets)
                    mae = F.l1_loss(mean_pred, targets)

                    val_losses.append(loss.item())
                    val_maes.append(mae.item())

                    val_pred_means.extend(mean_pred.cpu().numpy())
                    val_pred_lowers.extend(lower_pred.cpu().numpy())
                    val_pred_uppers.extend(upper_pred.cpu().numpy())
                    val_targets_list.extend(targets.cpu().numpy())

            val_pred_means = np.array(val_pred_means)
            val_pred_lowers = np.array(val_pred_lowers)
            val_pred_uppers = np.array(val_pred_uppers)
            val_targets_list = np.array(val_targets_list)

            avg_val_loss = np.mean(val_losses)
            avg_val_mae = np.mean(val_maes)

            # Calculate LLL and R^2
            val_lll = self.laplace_log_likelihood_metric(val_pred_means, val_pred_lowers, val_pred_uppers, val_targets_list)
            ss_res = np.sum((val_targets_list - val_pred_means) ** 2)
            ss_tot = np.sum((val_targets_list - np.mean(val_targets_list)) ** 2)
            val_r2 = 1 - (ss_res / ss_tot) if ss_tot != 0 else float('-inf')

            print(f"Epoch {epoch+1}/{epochs}")
            print(f"Train Loss: {avg_train_loss:.6f}, MAE: {avg_train_mae:.6f}")
            print(f"Val Loss: {avg_val_loss:.6f}, MAE: {avg_val_mae:.6f}, R²: {val_r2:.6f}, LLL (main): {val_lll:.6f}")

            # Scheduler step and model save based on max LLL
            scheduler.step(val_lll)
            if val_lll > self.best_val_lll:
                self.best_val_lll = val_lll
                torch.save(self.model.state_dict(), 'best_working_model.pth')
                print("✅ New best model saved (based on LLL)!")
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    break

        return self.best_val_lll


# Instantiate model and trainer
model = WorkingDenseNetModel(tabular_dim=4).to(DEVICE)
trainer = Trainer(model, DEVICE, learning_rate=1e-4)

# Start training
best_mae = trainer.train(train_loader, val_loader, epochs=30, patience=8)
print(f"Training complete with best validation MAE: {best_mae:.4f}")

# Test-time augmentation predictor
class TTAPredictor:
    def __init__(self, model, augmentor, device, num_augmentations=5):
        self.model = model
        self.augmentor = augmentor
        self.device = device
        self.num_augmentations = num_augmentations
        self.model.eval()

    def predict(self, image, tabular_features):
        means, lowers, uppers = [], [], []
        with torch.no_grad():
            img_tensor = self.augmentor(image=image)['image'].to(self.device)
            tab_tensor = torch.tensor(tabular_features).float().unsqueeze(0).to(self.device)

            m, l, u = self.model(img_tensor.unsqueeze(0), tab_tensor)
            means.append(m.item())
            lowers.append(l.item())
            uppers.append(u.item())

            for _ in range(self.num_augmentations):
                aug_img = self.augmentor(image=image)['image'].to(self.device)
                m, l, u = self.model(aug_img.unsqueeze(0), tab_tensor)
                means.append(m.item())
                lowers.append(l.item())
                uppers.append(u.item())

        mean_pred = np.median(means)
        lower_pred = np.median(lowers)
        upper_pred = np.median(uppers)
        confidence = upper_pred - lower_pred
        return mean_pred, confidence

test_augmentor = MedicalAugmentation(augment=False)
tta_predictor = TTAPredictor(model, test_augmentor, DEVICE)

# Optional: implement submission function, evaluation, or extend with confidence heads per your requirement



Using device: cuda
Loaded train dataset: (1549, 7)
Calculating linear decay coefficients per patient...


100%|██████████| 176/176 [00:00<00:00, 1305.71it/s]

Processed 176 patients





Dataset train: 138 patients with images
Dataset val: 36 patients with images
Training with 103 batches, validating with 5 batches
Epoch 1/30
Train Loss: 51.087677, MAE: 5.044207
Val Loss: 68.258679, MAE: 5.744914, R²: -0.279940, LLL (main): -9.970015
✅ New best model saved (based on LLL)!
Epoch 2/30
Train Loss: 38.431060, MAE: 4.285720
Val Loss: 56.075777, MAE: 5.069439, R²: -0.077374, LLL (main): -3.848270
✅ New best model saved (based on LLL)!
Epoch 3/30
Train Loss: 33.412413, MAE: 4.017181
Val Loss: 52.803370, MAE: 4.909763, R²: -0.006549, LLL (main): -3.456856
✅ New best model saved (based on LLL)!
Epoch 4/30
Train Loss: 32.621137, MAE: 3.973981
Val Loss: 54.154743, MAE: 4.981632, R²: -0.034543, LLL (main): -3.431411
✅ New best model saved (based on LLL)!
Epoch 5/30
Train Loss: 32.413872, MAE: 4.016197
Val Loss: 53.837245, MAE: 4.923565, R²: -0.033564, LLL (main): -3.517206
Epoch 6/30
Train Loss: 32.528891, MAE: 4.036727
Val Loss: 56.712250, MAE: 5.099263, R²: -0.095244, LLL (main)

**1st lll trying to push r^2 to be between 0.1 to 0.5**

In [9]:
# Imports and environment setup
import os
import random
import numpy as np
import pandas as pd
from pathlib import Path
from datetime import datetime
from tqdm import tqdm
import cv2
import pydicom
import albumentations as albu
from albumentations.pytorch import ToTensorV2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models

# Seed for reproducibility
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything()

# Paths and device
DATA_DIR = Path("../input/osic-pulmonary-fibrosis-progression")
TRAIN_DIR = DATA_DIR / "train"
TEST_DIR = DATA_DIR / "test"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# Load train.csv
train_df = pd.read_csv(DATA_DIR / "train.csv")
print(f"Loaded train dataset: {train_df.shape}")

# Tabular feature extraction
def get_tab_features(row):
    age_scaled = (row['Age'] - 30) / 30
    sex_encoded = 0 if row['Sex'] == 'Male' else 1
    smoke_map = {
        'Never smoked': [0, 0],
        'Ex-smoker': [1, 1],
        'Currently smokes': [0, 1]
    }
    smoking = row['SmokingStatus']
    smoke_vec = smoke_map.get(smoking, [1, 0])
    return np.array([age_scaled, sex_encoded, smoke_vec[0], smoke_vec[1]])

# Compute linear decay coefficients
A = {}
TAB = {}
P = []

print("Calculating linear decay coefficients per patient...")
for patient in tqdm(train_df['Patient'].unique()):
    sub = train_df[train_df['Patient'] == patient].sort_values('Weeks')
    fvc = sub['FVC'].values
    weeks = sub['Weeks'].values
    if len(weeks) > 1:
        c = np.vstack([weeks, np.ones(len(weeks))]).T
        try:
            a, _ = np.linalg.lstsq(c, fvc, rcond=None)[0]
            A[patient] = a
        except:
            A[patient] = (fvc[-1] - fvc[0]) / (weeks[-1] - weeks[0])
    else:
        A[patient] = 0.0
    TAB[patient] = get_tab_features(sub.iloc[0])
    P.append(patient)
print(f"Processed {len(P)} patients.")

# Medical Image Augmentation
class MedicalAugmentation:
    def __init__(self, augment=True):
        if augment:
            self.transform = albu.Compose([
                albu.Rotate(limit=15, p=0.7),
                albu.HorizontalFlip(p=0.5),
                albu.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=15, p=0.7),
                albu.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.7),
                albu.GaussNoise(var_limit=(10, 50), p=0.5),
                albu.RandomGamma(gamma_limit=(80, 120), p=0.5),
                albu.GridDistortion(num_steps=5, distort_limit=0.3, p=0.3),
                albu.OpticalDistortion(distort_limit=0.3, shift_limit=0.3, p=0.3),
                albu.CoarseDropout(max_holes=8, max_height=32, max_width=32, p=0.3),
                albu.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ToTensorV2()
            ])
        else:
            self.transform = albu.Compose([
                albu.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ToTensorV2()
            ])

    def __call__(self, image):
        return self.transform(image=image)['image']

# Dataset
class OSICDenseNetDataset(Dataset):
    def __init__(self, patients, A_dict, TAB_dict, data_dir, split='train', augment=True):
        self.patients = [p for p in patients if p not in ['ID00011637202177653955184', 'ID00052637202186188008618']]
        self.A_dict = A_dict
        self.TAB_dict = TAB_dict
        self.data_dir = Path(data_dir)
        self.split = split
        self.augmentor = MedicalAugmentation(augment=augment)
        self.patient_images = {}
        for p in self.patients:
            pd = self.data_dir / p
            if pd.exists():
                imgs = [f for f in pd.iterdir() if f.suffix.lower() == '.dcm']
                if imgs:
                    self.patient_images[p] = imgs
        self.valid_patients = [p for p in self.patients if p in self.patient_images]
        print(f"Dataset {split}: {len(self.valid_patients)} patients with images")

    def __len__(self):
        return len(self.valid_patients) * 6 if self.split == 'train' else len(self.valid_patients)

    def __getitem__(self, idx):
        idx_mod = idx % len(self.valid_patients) if self.split == 'train' else idx
        patient = self.valid_patients[idx_mod]
        images = self.patient_images[patient]
        img_path = np.random.choice(images) if len(images) > 1 else images[0]
        img = self.load_and_preprocess_dicom(img_path)
        img_tensor = self.augmentor(img)
        tab_feat = torch.tensor(self.TAB_dict[patient], dtype=torch.float32)
        tgt = torch.tensor(self.A_dict[patient], dtype=torch.float32)
        return img_tensor, tab_feat, tgt, patient

    def load_and_preprocess_dicom(self, path):
        try:
            dcm = pydicom.dcmread(str(path))
            img = dcm.pixel_array.astype(np.float32)
            if len(img.shape) == 3:
                img = img[img.shape[0] // 2]
            img = cv2.resize(img, (512, 512))
            mn, mx = img.min(), img.max()
            if mx > mn:
                img = (img - mn) / (mx - mn) * 255
            else:
                img = np.zeros_like(img)
            img = np.stack([img, img, img], axis=2).astype(np.uint8)
            return img
        except Exception as e:
            print(f"Error loading DICOM {path}: {e}")
            return np.zeros((512, 512, 3), dtype=np.uint8)

# Attention Module
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x_cat = torch.cat([avg_out, max_out], dim=1)
        x_cat = self.conv1(x_cat)
        return x * self.sigmoid(x_cat)

# Model with uncertainty added as auxiliary output
class WorkingDenseNetModel(nn.Module):
    def __init__(self, tab_dim=4, dropout=0.4):
        super().__init__()
        dnet = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
        self.features = dnet.features
        self.spatial_attention = SpatialAttention()
        self.tabular_processor = nn.Sequential(
            nn.Linear(tab_dim, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU()
        )
        edim = 1024
        self.query_proj = nn.Linear(edim, edim)
        self.key_proj = nn.Linear(512, edim)
        self.value_proj = nn.Linear(512, edim)
        self.cross_attention = nn.MultiheadAttention(embed_dim=edim, num_heads=8, dropout=0.2, batch_first=True)
        self.fusion_layer = nn.Sequential(
            nn.Linear(edim + 512, 768),
            nn.BatchNorm1d(768),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(768, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout / 2)
        )
        self.mean_head = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
        self.uncertainty_head = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1),
            nn.Softplus()  # ensures positive uncertainty
        )
    def forward(self, images, tabular):
        b = images.size(0)
        img_feat = self.features(images)
        img_feat = self.spatial_attention(img_feat)
        img_feat = F.adaptive_avg_pool2d(img_feat, (1, 1)).view(b, -1)
        tab_feat = self.tabular_processor(tabular)
        q = self.query_proj(img_feat).unsqueeze(1)
        k = self.key_proj(tab_feat).unsqueeze(1)
        v = self.value_proj(tab_feat).unsqueeze(1)
        attn_out, _ = self.cross_attention(q, k, v)
        attn_out = attn_out.squeeze(1)
        combined = torch.cat([attn_out, tab_feat], dim=1)
        fused_feat = self.fusion_layer(combined)
        mean_pred = self.mean_head(fused_feat).squeeze(1)
        uncertainty = self.uncertainty_head(fused_feat).squeeze(1)
        return mean_pred, uncertainty

# Loss combining MSE as primary and uncertainty weighted auxiliary loss
def combined_loss(mean_pred, uncertainty, targets, unc_weight=0.1):
    mse = F.mse_loss(mean_pred, targets)
    # Encourage uncertainty to reflect prediction error
    unc_loss = torch.mean(((mean_pred - targets).pow(2) / (uncertainty + 1e-6)) + torch.log(uncertainty + 1e-6))
    return mse + unc_weight * unc_loss

# Laplace Log Likelihood metric for reporting
def laplace_log_likelihood_metric(mean_pred, uncertainty, targets):
    sigma = np.sqrt(uncertainty)
    sigma = np.clip(sigma, 1e-3, 1e3)
    delta = np.abs(mean_pred - targets)
    score = -np.log(2 * sigma) - delta / sigma
    return np.mean(score)

# Dataset split
from sklearn.model_selection import train_test_split
train_patients, val_patients = train_test_split(P, test_size=0.2, random_state=42)

train_dataset = OSICDenseNetDataset(train_patients, A, TAB, TRAIN_DIR, split='train', augment=True)
val_dataset = OSICDenseNetDataset(val_patients, A, TAB, TRAIN_DIR, split='val', augment=False)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=2, pin_memory=True)

print(f"Training with {len(train_loader)} batches, validating with {len(val_loader)} batches.")

# Trainer class with MSE+uncertainty loss and LLL + R² reporting
class Trainer:
    def __init__(self, model, device, lr=1e-4):
        self.model = model
        self.device = device
        self.lr = lr
        self.best_val_lll = -float('inf')

    def train(self, train_loader, val_loader, epochs=30, patience=8):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=0.5, patience=4, verbose=True
        )
        patience_counter = 0

        for epoch in range(epochs):
            self.model.train()
            train_losses, train_maes = [], []

            for images, tabular, targets, _ in train_loader:
                images, tabular, targets = images.to(self.device), tabular.to(self.device), targets.to(self.device)
                optimizer.zero_grad()
                mean_pred, uncertainty = self.model(images, tabular)
                loss = combined_loss(mean_pred, uncertainty, targets)
                mae = F.l1_loss(mean_pred, targets)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                optimizer.step()
                train_losses.append(loss.item())
                train_maes.append(mae.item())

            avg_train_loss = np.mean(train_losses)
            avg_train_mae = np.mean(train_maes)

            self.model.eval()
            val_losses, val_maes = [], []
            val_preds, val_vars, val_targets = [], [], []

            with torch.no_grad():
                for images, tabular, targets, _ in val_loader:
                    images, tabular, targets = images.to(self.device), tabular.to(self.device), targets.to(self.device)
                    mean_pred, uncertainty = self.model(images, tabular)
                    loss = combined_loss(mean_pred, uncertainty, targets)
                    mae = F.l1_loss(mean_pred, targets)

                    val_losses.append(loss.item())
                    val_maes.append(mae.item())
                    val_preds.extend(mean_pred.cpu().numpy())
                    val_vars.extend(uncertainty.cpu().numpy())
                    val_targets.extend(targets.cpu().numpy())

            val_preds = np.array(val_preds)
            val_vars = np.array(val_vars)
            val_targets = np.array(val_targets)

            avg_val_loss = np.mean(val_losses)
            avg_val_mae = np.mean(val_maes)

            # Calculate LLL and R2
            val_lll = laplace_log_likelihood_metric(val_preds, val_vars, val_targets)
            ss_res = np.sum((val_targets - val_preds) ** 2)
            ss_tot = np.sum((val_targets - np.mean(val_targets)) ** 2)
            val_r2 = 1 - (ss_res / ss_tot) if ss_tot != 0 else float('-inf')

            print(
                f"Epoch {epoch+1}/{epochs} | "
                f"Train Loss: {avg_train_loss:.4f}, MAE: {avg_train_mae:.4f} | "
                f"Val Loss: {avg_val_loss:.4f}, MAE: {avg_val_mae:.4f}, R²: {val_r2:.4f}, LLL (main): {val_lll:.4f}"
            )

            # Scheduler and early stopping based on LLL (maximize)
            scheduler.step(val_lll)

            if val_lll > self.best_val_lll:
                self.best_val_lll = val_lll
                torch.save(self.model.state_dict(), "best_working_model.pth")
                print("✅ New best model saved based on LLL!")
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    break

        return self.best_val_lll

# Initialize and train
model = WorkingDenseNetModel(tab_dim=4).to(DEVICE)
trainer = Trainer(model, DEVICE, lr=1e-4)
best_lll = trainer.train(train_loader, val_loader, epochs=30, patience=8)
print(f"Training complete. Best validation LLL: {best_lll:.4f}")


Using device: cuda
Loaded train dataset: (1549, 7)
Calculating linear decay coefficients per patient...


100%|██████████| 176/176 [00:00<00:00, 1165.67it/s]

Processed 176 patients.





Dataset train: 138 patients with images
Dataset val: 36 patients with images
Training with 103 batches, validating with 5 batches.
Epoch 1/30 | Train Loss: 51.0170, MAE: 5.0713 | Val Loss: 66.9944, MAE: 5.7251, R²: -0.2759, LLL (main): -5.6618
✅ New best model saved based on LLL!
Epoch 2/30 | Train Loss: 36.1358, MAE: 4.2566 | Val Loss: 53.2650, MAE: 5.0861, R²: -0.0407, LLL (main): -4.5835
✅ New best model saved based on LLL!
Epoch 3/30 | Train Loss: 31.5575, MAE: 4.0260 | Val Loss: 49.5153, MAE: 4.8130, R²: 0.0172, LLL (main): -3.9880
✅ New best model saved based on LLL!
Epoch 4/30 | Train Loss: 30.9680, MAE: 4.0346 | Val Loss: 54.1459, MAE: 5.0264, R²: -0.0846, LLL (main): -4.1635
Epoch 5/30 | Train Loss: 30.4053, MAE: 4.0412 | Val Loss: 51.4621, MAE: 4.9320, R²: -0.0332, LLL (main): -3.9351
✅ New best model saved based on LLL!
Epoch 6/30 | Train Loss: 30.9999, MAE: 4.0573 | Val Loss: 51.8511, MAE: 4.9948, R²: -0.0316, LLL (main): -3.8778
✅ New best model saved based on LLL!
Epoch 7

**Here is a set of advanced, practical strategies with implementation guidance that have proven effective in similar medical regression tasks and can be implemented within your Kaggle time budget:**

1. Use a Multi-Task Learning approach:
Add an auxiliary regression target to jointly predict the raw FVC value (or percentile) alongside decay slope 
A
A. This helps the model learn finer representations.

E.g., predict both 
A
A (decay rate) and baseline FVC.

Combine their losses weighted appropriately.

2. Use Longitudinal Information or Patient Embeddings:
Your current model uses single image slices without explicit temporal or patient-specific embeddings beyond tabular data. Enhancing temporal modeling with:

Adding trainable patient embeddings or encoding patient ID with learnable embedding vectors.

Using sequences of images as input, e.g., with temporal CNN, RNN, or transformers, to better capture progression trends.

3. Refine Loss Function by Adding an MAE/MSE term weighted more heavily:
Because you want better predictive fit (and R² reflects that), increase the weight of standard MSE or MAE loss compared to uncertainty loss, or first pretrain model on pure regression loss then finetune uncertainty.

In [10]:
# Imports and environment setup
import os
import random
import numpy as np
import pandas as pd
from pathlib import Path
from datetime import datetime
from tqdm import tqdm
import cv2
import pydicom
import albumentations as albu
from albumentations.pytorch import ToTensorV2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
from sklearn.model_selection import train_test_split

# Seed for reproducibility
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
seed_everything()

# Paths and device
DATA_DIR = Path("../input/osic-pulmonary-fibrosis-progression")
TRAIN_DIR = DATA_DIR / "train"
TEST_DIR = DATA_DIR / "test"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# Load train.csv
train_df = pd.read_csv(DATA_DIR / "train.csv")
print(f"Loaded train dataset: {train_df.shape}")

# Tabular feature extraction
def get_tab_features(row):
    age_scaled = (row['Age'] - 30) / 30
    sex_encoded = 0 if row['Sex'] == 'Male' else 1
    smoke_map = {
        'Never smoked': [0, 0],
        'Ex-smoker': [1, 1],
        'Currently smokes': [0, 1]
    }
    smoking = row['SmokingStatus']
    smoke_vec = smoke_map.get(smoking, [1, 0])
    return np.array([age_scaled, sex_encoded, smoke_vec[0], smoke_vec[1]])

# Compute linear decay coefficients per patient
A = {}
TAB = {}
BASELINE_FVC = {}
P = []

print("Calculating linear decay coefficients and baseline FVC per patient...")
for patient in tqdm(train_df['Patient'].unique()):
    sub = train_df[train_df['Patient'] == patient].sort_values('Weeks')
    fvc = sub['FVC'].values
    weeks = sub['Weeks'].values
    if len(weeks) > 1:
        c = np.vstack([weeks, np.ones(len(weeks))]).T
        try:
            a, b = np.linalg.lstsq(c, fvc, rcond=None)[0]
            A[patient] = a
            BASELINE_FVC[patient] = b
        except:
            A[patient] = (fvc[-1] - fvc[0]) / (weeks[-1] - weeks[0])
            BASELINE_FVC[patient] = fvc[0]
    else:
        A[patient] = 0.0
        BASELINE_FVC[patient] = fvc[0]
    TAB[patient] = get_tab_features(sub.iloc[0])
    P.append(patient)
print(f"Processed {len(P)} patients.")

# Medical Image Augmentation
class MedicalAugmentation:
    def __init__(self, augment=True):
        if augment:
            self.transform = albu.Compose([
                albu.Rotate(limit=15, p=0.7),
                albu.HorizontalFlip(p=0.5),
                albu.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=15, p=0.7),
                albu.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.7),
                albu.GaussNoise(var_limit=(10, 50), p=0.5),
                albu.RandomGamma(gamma_limit=(80, 120), p=0.5),
                albu.GridDistortion(num_steps=5, distort_limit=0.3, p=0.3),
                albu.OpticalDistortion(distort_limit=0.3, shift_limit=0.3, p=0.3),
                albu.CoarseDropout(max_holes=8, max_height=32, max_width=32, p=0.3),
                albu.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229,0.224,0.225]),
                ToTensorV2()
            ])
        else:
            self.transform = albu.Compose([
                albu.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229,0.224,0.225]),
                ToTensorV2()
            ])
    def __call__(self, image):
        return self.transform(image=image)['image']

# Dataset
class OSICDenseNetMultiTaskDataset(Dataset):
    def __init__(self, patients, A_dict, BASELINE_dict, TAB_dict, data_dir, split='train', augment=True):
        self.patients = [p for p in patients if p not in ['ID00011637202177653955184', 'ID00052637202186188008618']]
        self.A_dict = A_dict
        self.BASELINE_dict = BASELINE_dict
        self.TAB_dict = TAB_dict
        self.data_dir = Path(data_dir)
        self.split = split
        self.augmentor = MedicalAugmentation(augment=augment)
        self.patient_images = {}
        for p in self.patients:
            pd = self.data_dir / p
            if pd.exists():
                imgs = [f for f in pd.iterdir() if f.suffix.lower() == '.dcm']
                if imgs:
                    self.patient_images[p] = imgs
        self.valid_patients = [p for p in self.patients if p in self.patient_images]
        print(f"Dataset {split}: {len(self.valid_patients)} patients with images")

    def __len__(self):
        return len(self.valid_patients) * 6 if self.split == 'train' else len(self.valid_patients)

    def __getitem__(self, idx):
        idx_mod = idx % len(self.valid_patients) if self.split == 'train' else idx
        patient = self.valid_patients[idx_mod]
        images = self.patient_images[patient]
        img_path = np.random.choice(images) if len(images) > 1 else images[0]
        img = self.load_and_preprocess_dicom(img_path)
        img_tensor = self.augmentor(img)
        tab_feat = torch.tensor(self.TAB_dict[patient], dtype=torch.float32)
        decay_target = torch.tensor(self.A_dict[patient], dtype=torch.float32)
        baseline_target = torch.tensor(self.BASELINE_dict[patient], dtype=torch.float32)
        return img_tensor, tab_feat, decay_target, baseline_target

    def load_and_preprocess_dicom(self, path):
        try:
            dcm = pydicom.dcmread(str(path))
            img = dcm.pixel_array.astype(np.float32)
            if len(img.shape) == 3:
                img = img[img.shape[0] // 2]
            img = cv2.resize(img, (512, 512))
            mn, mx = img.min(), img.max()
            if mx > mn:
                img = (img - mn) / (mx - mn) * 255
            else:
                img = np.zeros_like(img)
            img = np.stack([img, img, img], axis=2).astype(np.uint8)
            return img
        except Exception as e:
            print(f"Error loading DICOM {path}: {e}")
            return np.zeros((512, 512, 3), dtype=np.uint8)

# Attention Module
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x_cat = torch.cat([avg_out, max_out], dim=1)
        x_cat = self.conv1(x_cat)
        return x * self.sigmoid(x_cat)

# Multi-task model with decay rate and baseline FVC prediction
class MultiTaskDenseNetModel(nn.Module):
    def __init__(self, tab_dim=4, dropout=0.4):
        super().__init__()
        dnet = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
        self.features = dnet.features
        self.spatial_attention = SpatialAttention()
        self.tabular_processor = nn.Sequential(
            nn.Linear(tab_dim, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(128, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(256, 512), nn.BatchNorm1d(512), nn.ReLU()
        )
        edim = 1024
        self.query_proj = nn.Linear(edim, edim)
        self.key_proj = nn.Linear(512, edim)
        self.value_proj = nn.Linear(512, edim)
        self.cross_attention = nn.MultiheadAttention(embed_dim=edim, num_heads=8, dropout=0.2, batch_first=True)
        self.fusion_layer = nn.Sequential(
            nn.Linear(edim + 512, 768), nn.BatchNorm1d(768), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(768, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(dropout / 2)
        )
        # Predict decay rate A
        self.decay_head = nn.Sequential(
            nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, 1)
        )
        # Predict baseline FVC
        self.baseline_fvc_head = nn.Sequential(
            nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, images, tabular):
        b = images.size(0)
        img_feat = self.features(images)
        img_feat = self.spatial_attention(img_feat)
        img_feat = F.adaptive_avg_pool2d(img_feat, (1, 1)).view(b, -1)
        tab_feat = self.tabular_processor(tabular)
        q = self.query_proj(img_feat).unsqueeze(1)
        k = self.key_proj(tab_feat).unsqueeze(1)
        v = self.value_proj(tab_feat).unsqueeze(1)
        attn_out, _ = self.cross_attention(q, k, v)
        attn_out = attn_out.squeeze(1)
        combined = torch.cat([attn_out, tab_feat], dim=1)
        fused_feat = self.fusion_layer(combined)
        decay_pred = self.decay_head(fused_feat).squeeze(1)
        baseline_pred = self.baseline_fvc_head(fused_feat).squeeze(1)
        return decay_pred, baseline_pred

# Multi-task combined MSE loss with alpha weighting
def multitask_loss(decay_pred, baseline_pred, decay_target, baseline_target, alpha=0.7):
    decay_loss = F.mse_loss(decay_pred, decay_target)
    baseline_loss = F.mse_loss(baseline_pred, baseline_target)
    return alpha * decay_loss + (1 - alpha) * baseline_loss

# Laplace Log Likelihood metric for reporting using normal approximation of residuals with std dev from data
def laplace_log_likelihood_metric(pred, target, eps=1e-6):
    # residual std dev estimated from residual error per batch (used for reporting only)
    residuals = np.abs(pred - target)
    sigma = np.maximum(np.std(residuals), eps)
    delta = np.abs(pred - target)
    score = -np.log(2 * sigma) - delta / sigma
    return np.mean(score)

# Dataset splitting
train_patients, val_patients = train_test_split(P, test_size=0.2, random_state=42)

train_dataset = OSICDenseNetMultiTaskDataset(train_patients, A, BASELINE_FVC, TAB, TRAIN_DIR, augment=True, split='train')
val_dataset = OSICDenseNetMultiTaskDataset(val_patients, A, BASELINE_FVC, TAB, TRAIN_DIR, augment=False, split='val')

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=2, pin_memory=True)

print(f"Training with {len(train_loader)} batches, validating with {len(val_loader)} batches.")

# Trainer class monitoring R² as main metric with LLL printed every epoch
class MultiTaskTrainer:
    def __init__(self, model, device, lr=1e-4, alpha=0.7):
        self.model = model
        self.device = device
        self.lr = lr
        self.alpha = alpha
        self.best_val_r2 = -float('inf')

    def train(self, train_loader, val_loader, epochs=25, patience=6):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3, verbose=True)
        patience_counter = 0

        for epoch in range(epochs):
            self.model.train()
            train_losses = []

            for images, tabular, decay_targets, baseline_targets in train_loader:
                images, tabular = images.to(self.device), tabular.to(self.device)
                decay_targets, baseline_targets = decay_targets.to(self.device), baseline_targets.to(self.device)

                optimizer.zero_grad()
                decay_pred, baseline_pred = self.model(images, tabular)
                loss = multitask_loss(decay_pred, baseline_pred, decay_targets, baseline_targets, alpha=self.alpha)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                optimizer.step()
                train_losses.append(loss.item())

            avg_train_loss = np.mean(train_losses)

            self.model.eval()
            val_decay_preds, val_baseline_preds = [], []
            val_decay_targets, val_baseline_targets = [], []

            with torch.no_grad():
                for images, tabular, decay_targets, baseline_targets in val_loader:
                    images, tabular = images.to(self.device), tabular.to(self.device)
                    decay_targets, baseline_targets = decay_targets.to(self.device), baseline_targets.to(self.device)

                    decay_pred, baseline_pred = self.model(images, tabular)
                    val_decay_preds.extend(decay_pred.cpu().numpy())
                    val_baseline_preds.extend(baseline_pred.cpu().numpy())
                    val_decay_targets.extend(decay_targets.cpu().numpy())
                    val_baseline_targets.extend(baseline_targets.cpu().numpy())

            val_decay_preds = np.array(val_decay_preds)
            val_decay_targets = np.array(val_decay_targets)

            ss_res = np.sum((val_decay_targets - val_decay_preds) ** 2)
            ss_tot = np.sum((val_decay_targets - np.mean(val_decay_targets)) ** 2)
            val_r2 = 1 - (ss_res / ss_tot) if ss_tot != 0 else float('-inf')

            # Calculate LLL (reported every epoch, not used for early stop here)
            val_lll = laplace_log_likelihood_metric(val_decay_preds, val_decay_targets)

            print(
                f"Epoch {epoch+1}/{epochs} | "
                f"Train Loss: {avg_train_loss:.4f} | "
                f"Val R² (decay): {val_r2:.4f} | "
                f"LLL (reported): {val_lll:.4f}"
            )

            scheduler.step(val_r2)

            if val_r2 > self.best_val_r2:
                self.best_val_r2 = val_r2
                torch.save(self.model.state_dict(), "best_multitask_model.pth")
                print("✅ New best model saved based on R²!")
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    break

        return self.best_val_r2

# Initialize model and trainer
model = MultiTaskDenseNetModel(tab_dim=4).to(DEVICE)
trainer = MultiTaskTrainer(model, DEVICE, lr=1e-4, alpha=0.7)

# Train
best_r2 = trainer.train(train_loader, val_loader, epochs=25, patience=6)
print(f"Training complete. Best validation R²: {best_r2:.4f}")


Using device: cuda
Loaded train dataset: (1549, 7)
Calculating linear decay coefficients and baseline FVC per patient...


100%|██████████| 176/176 [00:00<00:00, 1333.15it/s]

Processed 176 patients.





Dataset train: 138 patients with images
Dataset val: 36 patients with images
Training with 103 batches, validating with 5 batches.
Epoch 1/25 | Train Loss: 2661169.8471 | Val R² (decay): -0.2947 | LLL (reported): -3.4533
✅ New best model saved based on R²!
Epoch 2/25 | Train Loss: 2657601.7621 | Val R² (decay): -0.1562 | LLL (reported): -3.3961
✅ New best model saved based on R²!
Epoch 3/25 | Train Loss: 2633535.4430 | Val R² (decay): -0.0591 | LLL (reported): -3.3695
✅ New best model saved based on R²!
Epoch 4/25 | Train Loss: 2602165.7464 | Val R² (decay): -0.1813 | LLL (reported): -3.3938
Epoch 5/25 | Train Loss: 2551404.4211 | Val R² (decay): -0.1923 | LLL (reported): -3.4261
Epoch 6/25 | Train Loss: 2476379.3847 | Val R² (decay): -0.1963 | LLL (reported): -3.4291
Epoch 7/25 | Train Loss: 2387398.4345 | Val R² (decay): -0.1774 | LLL (reported): -3.4425
Epoch 8/25 | Train Loss: 2304322.7439 | Val R² (decay): -0.1407 | LLL (reported): -3.4127
Epoch 9/25 | Train Loss: 2236971.8507 | V

**Doc previous work and strategies that was done long long back from google doc**

In [12]:
# Imports and environment setup
import os
import random
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import cv2
import pydicom
import albumentations as albu
from albumentations.pytorch import ToTensorV2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
from sklearn.model_selection import train_test_split

# Seed for reproducibility
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything()

# Paths and device setup
DATA_DIR = Path("../input/osic-pulmonary-fibrosis-progression")
TRAIN_DIR = DATA_DIR / "train"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# Load training data
train_df = pd.read_csv(DATA_DIR / "train.csv")
print(f"Train dataset loaded: {train_df.shape}")

# Tabular features extraction
def get_tab_features(row):
    age_scaled = (row['Age'] - 30) / 30
    sex_encoded = 0 if row['Sex'] == 'Male' else 1
    smoke_map = {
        'Never smoked': [0, 0],
        'Ex-smoker': [1, 1],
        'Currently smokes': [0, 1]
    }
    smoking = row['SmokingStatus']
    smoke_vec = smoke_map.get(smoking, [1, 0])
    return np.array([age_scaled, sex_encoded, smoke_vec[0], smoke_vec[1]])

# Calculate linear decay coefficients and baseline FVC per patient
A = {}
BASELINE_FVC = {}
TAB = {}
patients_list = []

print("Calculating decay coefficients and baseline FVC...")
for patient in tqdm(train_df['Patient'].unique()):
    sub = train_df[train_df['Patient'] == patient].sort_values('Weeks')
    fvc = sub['FVC'].values
    weeks = sub['Weeks'].values
    if len(weeks) > 1:
        c = np.vstack([weeks, np.ones(len(weeks))]).T
        try:
            a, b = np.linalg.lstsq(c, fvc, rcond=None)[0]
            A[patient] = a
            BASELINE_FVC[patient] = b
        except:
            A[patient] = (fvc[-1] - fvc[0]) / (weeks[-1] - weeks[0])
            BASELINE_FVC[patient] = fvc[0]
    else:
        A[patient] = 0.0
        BASELINE_FVC[patient] = fvc[0]
    TAB[patient] = get_tab_features(sub.iloc[0])
    patients_list.append(patient)
print(f"Processed {len(patients_list)} patients")

# Compute normalization stats for targets
decay_values = np.array(list(A.values()), dtype=np.float32)
baseline_values = np.array(list(BASELINE_FVC.values()), dtype=np.float32)
decay_mean, decay_std = decay_values.mean(), decay_values.std()
baseline_mean, baseline_std = baseline_values.mean(), baseline_values.std()
print(f"Decay mean/std: {decay_mean:.3f}/{decay_std:.3f}")
print(f"Baseline mean/std: {baseline_mean:.3f}/{baseline_std:.3f}")

# Medical image augmentation class
class MedicalAugmentation:
    def __init__(self, augment=True):
        if augment:
            self.transform = albu.Compose([
                albu.Rotate(limit=15, p=0.7),
                albu.HorizontalFlip(p=0.5),
                albu.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=15, p=0.7),
                albu.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.7),
                albu.GaussNoise(var_limit=(10, 50), p=0.5),
                albu.RandomGamma(gamma_limit=(80, 120), p=0.5),
                albu.GridDistortion(num_steps=5, distort_limit=0.3, p=0.3),
                albu.OpticalDistortion(distort_limit=0.3, shift_limit=0.3, p=0.3),
                albu.CoarseDropout(max_holes=8, max_height=32, max_width=32, p=0.3),
                albu.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ToTensorV2()
            ])
        else:
            self.transform = albu.Compose([
                albu.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ToTensorV2()
            ])
    def __call__(self, image):
        return self.transform(image=image)['image']

# Dataset class for multi-task regression: decay slope and baseline FVC
class OSICDenseNetMultiTaskDataset(Dataset):
    def __init__(self, patients, A_dict, BASELINE_dict, TAB_dict, data_dir, split='train', augment=True):
        self.patients = [p for p in patients if p not in ['ID00011637202177653955184', 'ID00052637202186188008618']]
        self.A_dict = A_dict
        self.BASELINE_dict = BASELINE_dict
        self.TAB_dict = TAB_dict
        self.data_dir = Path(data_dir)
        self.split = split
        self.augmentor = MedicalAugmentation(augment=augment)
        self.patient_images = {}
        for p in self.patients:
            p_dir = self.data_dir / p
            if p_dir.exists():
                images = [f for f in p_dir.iterdir() if f.suffix.lower() == '.dcm']
                if images:
                    self.patient_images[p] = images
        self.valid_patients = [p for p in self.patients if p in self.patient_images]
        print(f"Dataset {split}: {len(self.valid_patients)} patients with images")

    def __len__(self):
        return len(self.valid_patients) * 6 if self.split == 'train' else len(self.valid_patients)

    def __getitem__(self, idx):
        idx_mod = idx % len(self.valid_patients) if self.split == 'train' else idx
        patient = self.valid_patients[idx_mod]
        images = self.patient_images[patient]
        img_path = np.random.choice(images) if len(images) > 1 else images[0]
        img = self.load_and_preprocess_dicom(img_path)
        img_tensor = self.augmentor(img)
        tab_feat = torch.tensor(self.TAB_dict[patient], dtype=torch.float32)

        # ✅ Normalized targets
        decay_target = (torch.tensor(self.A_dict[patient], dtype=torch.float32) - decay_mean) / decay_std
        baseline_target = (torch.tensor(self.BASELINE_dict[patient], dtype=torch.float32) - baseline_mean) / baseline_std

        return img_tensor, tab_feat, decay_target, baseline_target

    def load_and_preprocess_dicom(self, path):
        try:
            dcm = pydicom.dcmread(str(path))
            img = dcm.pixel_array.astype(np.float32)
            if len(img.shape) == 3:
                img = img[img.shape[0] // 2]
            img = cv2.resize(img, (512, 512))
            mn, mx = img.min(), img.max()
            if mx > mn:
                img = (img - mn) / (mx - mn) * 255
            else:
                img = np.zeros_like(img)
            img = np.stack([img, img, img], axis=2).astype(np.uint8)
            return img
        except Exception as e:
            print(f"Error loading DICOM {path}: {e}")
            return np.zeros((512, 512, 3), dtype=np.uint8)

# Spatial attention module
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x_cat = torch.cat([avg_out, max_out], dim=1)
        x_cat = self.conv(x_cat)
        return x * self.sigmoid(x_cat)

# Multi-task model predicting decay slope and baseline FVC
class MultiTaskDenseNetModel(nn.Module):
    def __init__(self, tab_dim=4, dropout=0.4):
        super().__init__()
        dnet = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
        self.features = dnet.features
        self.spatial_attention = SpatialAttention()
        self.tabular_processor = nn.Sequential(
            nn.Linear(tab_dim, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(128, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(256, 512), nn.BatchNorm1d(512), nn.ReLU()
        )
        edim = 1024
        self.query_proj = nn.Linear(edim, edim)
        self.key_proj = nn.Linear(512, edim)
        self.value_proj = nn.Linear(512, edim)
        self.cross_attention = nn.MultiheadAttention(embed_dim=edim, num_heads=8, dropout=0.2, batch_first=True)
        self.fusion_layer = nn.Sequential(
            nn.Linear(edim + 512, 768), nn.BatchNorm1d(768), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(768, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(dropout / 2)
        )
        self.decay_head = nn.Sequential(
            nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, 1)
        )
        self.baseline_fvc_head = nn.Sequential(
            nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, images, tabular):
        b = images.size(0)
        img_feat = self.features(images)
        img_feat = self.spatial_attention(img_feat)
        img_feat = F.adaptive_avg_pool2d(img_feat, (1, 1)).view(b, -1)
        tab_feat = self.tabular_processor(tabular)
        q = self.query_proj(img_feat).unsqueeze(1)
        k = self.key_proj(tab_feat).unsqueeze(1)
        v = self.value_proj(tab_feat).unsqueeze(1)
        attn_out, _ = self.cross_attention(q, k, v)
        attn_out = attn_out.squeeze(1)
        combined = torch.cat([attn_out, tab_feat], dim=1)
        fused_feat = self.fusion_layer(combined)
        decay_pred = self.decay_head(fused_feat).squeeze(1)
        baseline_pred = self.baseline_fvc_head(fused_feat).squeeze(1)
        return decay_pred, baseline_pred

# Loss function: weighted MSE
def multitask_loss(decay_pred, baseline_pred, decay_target, baseline_target, alpha=0.7):
    decay_loss = F.mse_loss(decay_pred, decay_target)
    baseline_loss = F.mse_loss(baseline_pred, baseline_target)
    return alpha * decay_loss + (1 - alpha) * baseline_loss

# Laplace Log Likelihood metric calculated every epoch, regardless of usage in training
def laplace_log_likelihood_metric(pred, target, eps=1e-6):
    residuals = np.abs(pred - target)
    sigma = np.maximum(np.std(residuals), eps)
    delta = np.abs(pred - target)
    score = -np.log(2 * sigma) - delta / sigma
    return np.mean(score)

# Dataset splitting
train_patients, val_patients = train_test_split(patients_list, test_size=0.2, random_state=42)
train_dataset = OSICDenseNetMultiTaskDataset(train_patients, A, BASELINE_FVC, TAB, TRAIN_DIR, augment=True, split='train')
val_dataset = OSICDenseNetMultiTaskDataset(val_patients, A, BASELINE_FVC, TAB, TRAIN_DIR, augment=False, split='val')
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=2, pin_memory=True)

print(f"Training with {len(train_loader)} batches, validating with {len(val_loader)} batches.")

# Trainer class printing LLL every epoch plus tracking R² for early stopping
class MultiTaskTrainer:
    def __init__(self, model, device, lr=1e-4, alpha=0.7):
        self.model = model
        self.device = device
        self.lr = lr
        self.alpha = alpha
        self.best_val_r2 = -float('inf')

    def train(self, train_loader, val_loader, epochs=25, patience=6):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3, verbose=True)
        patience_counter = 0

        for epoch in range(epochs):
            self.model.train()
            train_losses = []

            for images, tabular, decay_targets, baseline_targets in train_loader:
                images, tabular = images.to(self.device), tabular.to(self.device)
                decay_targets, baseline_targets = decay_targets.to(self.device), baseline_targets.to(self.device)

                optimizer.zero_grad()
                decay_pred, baseline_pred = self.model(images, tabular)
                loss = multitask_loss(decay_pred, baseline_pred, decay_targets, baseline_targets, alpha=self.alpha)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                optimizer.step()
                train_losses.append(loss.item())

            avg_train_loss = np.mean(train_losses)

            self.model.eval()
            val_decay_preds, val_baseline_preds = [], []
            val_decay_targets, val_baseline_targets = [], []

            with torch.no_grad():
                for images, tabular, decay_targets, baseline_targets in val_loader:
                    images, tabular = images.to(self.device), tabular.to(self.device)
                    decay_targets, baseline_targets = decay_targets.to(self.device), baseline_targets.to(self.device)

                    decay_pred, baseline_pred = self.model(images, tabular)
                    val_decay_preds.extend(decay_pred.cpu().numpy())
                    val_baseline_preds.extend(baseline_pred.cpu().numpy())
                    val_decay_targets.extend(decay_targets.cpu().numpy())
                    val_baseline_targets.extend(baseline_targets.cpu().numpy())

            val_decay_preds = np.array(val_decay_preds)
            val_decay_targets = np.array(val_decay_targets)

            ss_res = np.sum((val_decay_targets - val_decay_preds) ** 2)
            ss_tot = np.sum((val_decay_targets - np.mean(val_decay_targets)) ** 2)
            val_r2 = 1 - (ss_res / ss_tot) if ss_tot != 0 else float('-inf')

            val_lll = laplace_log_likelihood_metric(val_decay_preds, val_decay_targets)

            print(
                f"Epoch {epoch+1}/{epochs} | Train Loss: {avg_train_loss:.4f} | "
                f"Val R² (decay): {val_r2:.4f} | LLL (reported): {val_lll:.4f}"
            )

            scheduler.step(val_r2)

            if val_r2 > self.best_val_r2:
                self.best_val_r2 = val_r2
                torch.save(self.model.state_dict(), "best_multitask_model.pth")
                print("✅ New best model saved based on R²!")
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    break

        return self.best_val_r2

# Instantiate and train model
model = MultiTaskDenseNetModel(tab_dim=4).to(DEVICE)
trainer = MultiTaskTrainer(model, DEVICE, lr=1e-4, alpha=0.7)
best_r2 = trainer.train(train_loader, val_loader, epochs=25, patience=6)
print(f"Training complete. Best validation R²: {best_r2:.4f}")


Using device: cuda
Train dataset loaded: (1549, 7)
Calculating decay coefficients and baseline FVC...


100%|██████████| 176/176 [00:00<00:00, 1357.83it/s]

Processed 176 patients
Decay mean/std: -4.524/6.118
Baseline mean/std: 2817.542/835.401





Dataset train: 138 patients with images
Dataset val: 36 patients with images
Training with 103 batches, validating with 5 batches.
Epoch 1/25 | Train Loss: 0.8692 | Val R² (decay): -0.0046 | LLL (reported): -1.5115
✅ New best model saved based on R²!
Epoch 2/25 | Train Loss: 0.8185 | Val R² (decay): -0.0125 | LLL (reported): -1.5359
Epoch 3/25 | Train Loss: 0.7874 | Val R² (decay): -0.0294 | LLL (reported): -1.5628
Epoch 4/25 | Train Loss: 0.7837 | Val R² (decay): -0.0335 | LLL (reported): -1.5590
Epoch 5/25 | Train Loss: 0.7918 | Val R² (decay): -0.0179 | LLL (reported): -1.5404
Epoch 6/25 | Train Loss: 0.7842 | Val R² (decay): -0.0373 | LLL (reported): -1.5557
Epoch 7/25 | Train Loss: 0.7965 | Val R² (decay): -0.0432 | LLL (reported): -1.5647
Early stopping at epoch 7
Training complete. Best validation R²: -0.0046
