In [1]:
"""
Pediatric Pneumonia Screening: Multi-Center Cross-Domain Evaluation
Training & Evaluation Pipeline

Description:
    This script implements the training and validation loop for three architectures:
    EfficientNet-B0, ConvNeXt-Tiny, and ViT-Base-16.
    It supports cross-domain evaluation on external datasets.

Note for Reviewers:
    - Paths are configured for the Kaggle environment but can be adapted for local use.
    - Change MODEL_ARCHITECTURE and RANDOM_SEED in the 'Config' class to reproduce specific runs.
"""

import os
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score
from tqdm.notebook import tqdm
import warnings

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")

# ===================================================================
# 1. Reproducibility & Configuration
# ===================================================================

def seed_everything(seed=42):
    """
    Seed every random number generator this script relies on.

    Covers Python's ``random`` module, the ``PYTHONHASHSEED`` environment
    variable, NumPy, and PyTorch (CPU plus current/all CUDA devices), then
    puts cuDNN into deterministic mode with auto-tuning disabled.

    Args:
        seed: Integer seed applied to all generators (default 42).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Every one of these callables accepts the seed as its only argument.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN's autotuned kernels for run-to-run repeatability.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

class Config:
    """
    Central experiment configuration, evaluated once at class-definition time.

    All settings are plain class attributes. Edit MODEL_ARCHITECTURE and
    RANDOM_SEED to reproduce a specific run; the derived values (MODEL_ID,
    LEARNING_RATE, dataset paths) follow from them automatically.
    """

    # --- A. Experiment Settings (Modify these to reproduce different runs) ---
    # Options: "EfficientNet_B0", "ConvNeXt_Tiny", "ViT_Base_16"
    MODEL_ARCHITECTURE = "ConvNeXt_Tiny"

    # Seeds used in the study: [42, 378, 1024, 2025, 4096]
    RANDOM_SEED = 42

    # --- B. Hyperparameters ---
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # os.cpu_count() returns None when the count cannot be determined; fall
    # back to 0 so the value is always a valid DataLoader `num_workers`
    # (0 means "load data in the main process").
    N_CPU = os.cpu_count() or 0
    EPOCHS = 15
    BATCH_SIZE = 32
    IMG_SIZE = 224  # Images are resized to IMG_SIZE x IMG_SIZE
    USE_LABEL_SMOOTHING = True
    LABEL_SMOOTHING_FACTOR = 0.1
    USE_TTA = True  # Test-Time Augmentation

    # --- C. Path Configuration (Auto-detects Environment) ---
    if os.path.exists('/kaggle/input'):
        # Kaggle Environment
        DATA_ROOT = '/kaggle/input'
        OUTPUT_DIR = './'
    else:
        # Local Environment (Reviewer: Change this path to your data directory)
        DATA_ROOT = './data'
        OUTPUT_DIR = './weights'
        # Side effect at class-definition time: make sure checkpoints can be written.
        os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Dataset Paths (Relative to DATA_ROOT)
    # Note: These paths assume the directory structure described in README.md
    # You may need to adjust the CSV filename based on your data preparation step.
    SOURCE_CSV_PATH = os.path.join(DATA_ROOT, 'chexpert-processed/chexpert_20k_balanced.csv')
    EXTERNAL_TEST_DIR = os.path.join(DATA_ROOT, 'chest-xray-pneumonia/chest_xray/test')

    # --- D. Model Specifics ---
    # Identifier combining architecture and seed; used to name checkpoint files.
    MODEL_ID = f"{MODEL_ARCHITECTURE}_Seed{RANDOM_SEED}"

    # Learning rates tailored for stability
    LEARNING_RATES = {
        "EfficientNet_B0": 1e-4,
        "ConvNeXt_Tiny": 1e-4,
        "ViT_Base_16": 3e-5
    }
    # Raises KeyError at class-definition time if MODEL_ARCHITECTURE is not listed above.
    LEARNING_RATE = LEARNING_RATES[MODEL_ARCHITECTURE]

# Build the configuration object and seed all RNGs before anything else runs
config = Config()
seed_everything(config.RANDOM_SEED)

# Echo the active settings so every run's log is self-describing
print("--- Experiment Configuration ---")
for banner_line in (
    f"Model: {config.MODEL_ARCHITECTURE}",
    f"Seed:  {config.RANDOM_SEED}",
    f"Device: {config.DEVICE}",
    f"Data Root: {config.DATA_ROOT}",
):
    print(banner_line)

# ===================================================================
# 2. Data Loading
# ===================================================================

class CustomDataset(Dataset):
    """
    Image dataset backed by a pandas DataFrame.

    Each row must provide a 'Path' column (image file location) and a
    'label' column (class index). Samples that fail to load or transform
    are returned as ``None`` so a filtering collate function can drop them.
    """

    def __init__(self, dataframe, transform=None):
        self.df = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Positional lookup, so a non-contiguous DataFrame index is fine.
        # Path adjustment for Kaggle/Local compatibility might be needed here.
        row = self.df.iloc[idx]
        img_path, label = row['Path'], row['label']

        try:
            image = Image.open(img_path).convert('RGB')
            image = self.transform(image) if self.transform else image
            sample = (image, torch.tensor(label, dtype=torch.long))
        except Exception:
            # Robustness: any failure marks the sample as unusable rather
            # than aborting the whole epoch.
            return None
        return sample

def collate_fn_robust(batch):
    """
    Collate a batch after dropping samples that are ``None``.

    Args:
        batch: List of samples as produced by the dataset; entries that are
            ``None`` (e.g. images that failed to load) are discarded.

    Returns:
        The result of PyTorch's default collate applied to the remaining
        samples.

    Raises:
        RuntimeError: If no usable sample remains. Without this check the
            default collate fails on the empty list with an opaque
            IndexError that hides the real cause.
    """
    usable = [sample for sample in batch if sample is not None]
    if not usable:
        raise RuntimeError(
            "collate_fn_robust: every sample in this batch was None "
            "(all images failed to load); check the dataset paths."
        )
    return torch.utils.data.dataloader.default_collate(usable)

# Transforms (ImageNet normalization)
# Pieces shared by the training and evaluation pipelines.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
_INPUT_SIZE = (config.IMG_SIZE, config.IMG_SIZE)

# Training: resize, then random augmentation, then tensor conversion + normalization.
train_transforms = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.RandAugment(num_ops=2, magnitude=9),  # Strong augmentation for generalization
        transforms.RandomAffine(
            degrees=15,
            translate=(0.1, 0.1),
            scale=(0.85, 1.15),
        ),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# Evaluation: the same resize and normalization, with no random augmentation.
eval_transforms = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# ===================================================================
# 3. Model Initialization
# ===================================================================

def get_model(architecture, num_classes=2):
    """
    Build one of the supported backbones with ImageNet weights and swap its
    final linear layer for a fresh ``num_classes``-way classifier.

    Args:
        architecture: "EfficientNet_B0", "ConvNeXt_Tiny" or "ViT_Base_16".
        num_classes: Output size of the replacement head (default 2).

    Returns:
        The torchvision model with its classification head replaced.

    Raises:
        ValueError: If ``architecture`` is not one of the supported names.
    """
    supported = ("EfficientNet_B0", "ConvNeXt_Tiny", "ViT_Base_16")
    if architecture not in supported:
        raise ValueError(f"Unknown architecture: {architecture}")

    if architecture == "EfficientNet_B0":
        net = models.efficientnet_b0(weights='IMAGENET1K_V1')
        # The final Linear layer sits at index 1 of the classifier.
        net.classifier[1] = nn.Linear(net.classifier[1].in_features, num_classes)
        return net

    if architecture == "ConvNeXt_Tiny":
        net = models.convnext_tiny(weights='IMAGENET1K_V1')
        # The final Linear layer sits at index 2 of the classifier.
        net.classifier[2] = nn.Linear(net.classifier[2].in_features, num_classes)
        return net

    # Only "ViT_Base_16" remains; its head lives under `heads.head`.
    net = models.vit_b_16(weights='IMAGENET1K_V1')
    net.heads.head = nn.Linear(net.heads.head.in_features, num_classes)
    return net

# ===================================================================
# 4. Training & Evaluation Engine
# ===================================================================

def train_one_epoch(model, loader, criterion, optimizer, device):
    """
    Run one optimization pass over ``loader``.

    Args:
        model: Network to train; switched to train mode here.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device that each batch is moved to.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``.
    """
    model.train()
    batches = tqdm(loader, desc="Training", leave=False)
    total_loss = 0.0

    for batch_images, batch_labels in batches:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        # Standard step: clear stale gradients, forward, backward, update.
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_images), batch_labels)
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()
        total_loss += loss_value
        batches.set_postfix(loss=loss_value)

    return total_loss / len(loader)

def evaluate(model, loader, device, desc="Evaluating", use_tta=False):
    """
    Score ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Iterable yielding ``(images, labels)`` batches.
        device: Device that each image batch is moved to.
        desc: Label shown on the progress bar.
        use_tta: When True, average the softmax outputs of each batch and
            of the same batch flipped along dimension 3.

    Returns:
        Tuple ``(f1, report)``: the weighted F1 score and the text
        classification report for the classes 'Normal' and 'Pneumonia'.
    """
    model.eval()
    predicted = []
    expected = []

    def _scores(batch):
        # Plain forward pass unless test-time augmentation is requested.
        if not use_tta:
            return model(batch)
        # Test-Time Augmentation: Average of Original + Horizontal Flip
        plain_logits = model(batch)
        flipped_logits = model(torch.flip(batch, dims=[3]))
        plain_probs = torch.softmax(plain_logits, dim=1)
        flipped_probs = torch.softmax(flipped_logits, dim=1)
        return (plain_probs + flipped_probs) / 2

    with torch.no_grad():
        for images, labels in tqdm(loader, desc=desc, leave=False):
            scores = _scores(images.to(device))
            # Index of the highest score per row is the predicted class.
            batch_preds = torch.max(scores, 1).indices
            predicted.extend(batch_preds.cpu().numpy())
            expected.extend(labels.cpu().numpy())

    report = classification_report(
        expected,
        predicted,
        target_names=['Normal', 'Pneumonia'],
        digits=4,
    )
    f1 = f1_score(expected, predicted, average='weighted')
    return f1, report

# ===================================================================
# 5. Main Execution Loop
# ===================================================================

if __name__ == "__main__":
    # Entry point: train on the source CSV, keep the checkpoint with the best
    # validation F1, then (if available) score it on the external test set.

    # --- Load Data ---
    if os.path.exists(config.SOURCE_CSV_PATH):
        print("Loading Source Data...")
        full_df = pd.read_csv(config.SOURCE_CSV_PATH)
        # Stratified split based on Seed
        train_df, val_df = train_test_split(
            full_df, test_size=0.2,
            random_state=config.RANDOM_SEED,
            stratify=full_df['label']
        )

        # os.cpu_count() can return None; DataLoader requires an integer.
        n_workers = config.N_CPU or 0

        train_loader = DataLoader(
            CustomDataset(train_df, transform=train_transforms),
            batch_size=config.BATCH_SIZE, shuffle=True,
            num_workers=n_workers, collate_fn=collate_fn_robust
        )
        val_loader = DataLoader(
            CustomDataset(val_df, transform=eval_transforms),
            batch_size=config.BATCH_SIZE, shuffle=False,
            num_workers=n_workers, collate_fn=collate_fn_robust
        )

        # --- Setup Model ---
        model = get_model(config.MODEL_ARCHITECTURE).to(config.DEVICE)

        # Loss & Optimizer
        # Honor the USE_LABEL_SMOOTHING switch; 0.0 disables smoothing.
        smoothing = config.LABEL_SMOOTHING_FACTOR if config.USE_LABEL_SMOOTHING else 0.0
        criterion = nn.CrossEntropyLoss(label_smoothing=smoothing)
        optimizer = torch.optim.AdamW(model.parameters(), lr=config.LEARNING_RATE)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.EPOCHS, eta_min=1e-7)

        # --- Training Loop ---
        # Define the checkpoint path up front so it exists even before the
        # first save, and start below the minimum possible F1 (0.0) so the
        # first epoch always writes a checkpoint for external evaluation.
        save_path = os.path.join(config.OUTPUT_DIR, f"{config.MODEL_ID}_best.pth")
        best_f1 = -1.0
        print("\nStarting Training...")

        for epoch in range(1, config.EPOCHS + 1):
            avg_loss = train_one_epoch(model, train_loader, criterion, optimizer, config.DEVICE)
            val_f1, _ = evaluate(model, val_loader, config.DEVICE, desc="Validating")
            scheduler.step()

            print(f"Epoch {epoch}/{config.EPOCHS} | Train Loss: {avg_loss:.4f} | Val F1: {val_f1:.4f}")

            if val_f1 > best_f1:
                best_f1 = val_f1
                torch.save(model.state_dict(), save_path)
                print(f"--> Best model saved to {save_path}")

        print(f"\nTraining Complete. Best Validation F1: {best_f1:.4f}")

        # --- External Evaluation (Example: Kaggle Pediatric) ---
        if os.path.exists(config.EXTERNAL_TEST_DIR):
            print("\nRunning External Evaluation...")
            # Load best weights onto the configured device
            model.load_state_dict(torch.load(save_path, map_location=config.DEVICE))

            test_dataset = datasets.ImageFolder(config.EXTERNAL_TEST_DIR, transform=eval_transforms)
            test_loader = DataLoader(test_dataset, batch_size=config.BATCH_SIZE, num_workers=n_workers)

            ext_f1, ext_report = evaluate(model, test_loader, config.DEVICE, desc="Testing", use_tta=config.USE_TTA)
            print(f"External Test F1 (TTA={config.USE_TTA}): {ext_f1:.4f}")
            print(ext_report)
        else:
            print(f"\nExternal test directory not found at {config.EXTERNAL_TEST_DIR}. Skipping external evaluation.")

    else:
        print(f"Source CSV not found at {config.SOURCE_CSV_PATH}. Please check the path.")

--- Experiment Configuration ---
Model: ConvNeXt_Tiny
Seed:  42
Device: cuda
Data Root: /kaggle/input
Source CSV not found at /kaggle/input/chexpert-processed/chexpert_20k_balanced.csv. Please check the path.
