# 02 - Train Base Classifier

Goals:
1. Load data with configurable subset size
2. Create PyTorch DataLoaders with proper transforms
3. Train ResNet-50 for multi-label classification
4. Evaluate and validate F1 ≥ 0.40
5. Generate and save predictions for conformal methods

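This notebook targets MPS (Metal) acceleration on an Apple M3 Pro; the device is read from `CONFIG["device"]` and falls back to CPU when the requested accelerator is unavailable.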

In [None]:
# Imports

import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
import json
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms, models
from PIL import Image

# Add src to path
PROJECT_ROOT = Path("..").resolve()
sys.path.insert(0, str(PROJECT_ROOT))

from config import CONFIG, LABELS, DATA_DIR, MODELS_DIR, RESULTS_DIR, print_config

# Print configuration
print_config()

In [None]:
# Set seeds for reproducibility

def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Imported locally so this cell does not depend on a file-level import.
    import random

    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed is documented as seeding all devices, so no
    # MPS-specific call is made here.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Seed the generators before any shuffling or sampling happens in later cells
set_seed(CONFIG["seed"])
print(f"Random seed set to: {CONFIG['seed']}")

In [None]:
# Setting up device

def _pick_device(requested):
    """Turn the configured device name into a usable torch.device.

    Prints which backend was selected; an unavailable accelerator falls back
    to CPU, and any name other than "mps"/"cuda" selects CPU directly.
    """
    if requested == "mps":
        if not torch.backends.mps.is_available():
            print("MPS not available, falling back to CPU")
            return torch.device("cpu")
        print("MPS (Metal Performance Shaders) is available")
        return torch.device("mps")

    if requested == "cuda":
        if not torch.cuda.is_available():
            print("CUDA not available, falling back to CPU")
            return torch.device("cpu")
        print(f"CUDA is available: {torch.cuda.get_device_name(0)}")
        return torch.device("cuda")

    print("Using CPU")
    return torch.device("cpu")


device = _pick_device(CONFIG["device"])

print(f"Device: {device}")

In [None]:
# Load metadata

print("Loading metadata...")
metadata_path = DATA_DIR / "metadata_with_splits.csv"
df = pd.read_csv(metadata_path)
print(f"Total images in metadata: {len(df):,}")

# The image root directory is stored as text in a small side file
kaggle_path_file = DATA_DIR / "kaggle_path.txt"
KAGGLE_PATH = Path(kaggle_path_file.read_text().strip())
print(f"Images location: {KAGGLE_PATH}")

In [None]:
# Apply subset if configured

subset_size = CONFIG["subset_size"]

# Share of the requested subset taken from each split
split_ratios = {'train': 0.70, 'val': 0.10, 'cal': 0.10, 'test': 0.10}

if subset_size is not None and subset_size < len(df):
    print(f"\nApplying subset: {subset_size:,} images")

    # Sample proportionally from each split to maintain ratios
    df_subset = []
    for split, ratio in split_ratios.items():
        split_df = df[df['split'] == split]
        n_samples = int(subset_size * ratio)

        # Count images per patient once, instead of re-filtering the whole
        # split for every patient inside the loop below.
        images_per_patient = split_df['Patient ID'].value_counts().to_dict()

        # Sample by patient to avoid data leakage
        patients = split_df['Patient ID'].unique()
        np.random.shuffle(patients)

        # Take patients until we have enough images
        selected_patients = []
        count = 0
        for p in patients:
            # Patients without a countable ID contribute zero images
            patient_images = images_per_patient.get(p, 0)
            if count + patient_images <= n_samples * 1.2:  # Allow 20% margin
                selected_patients.append(p)
                count += patient_images
            if count >= n_samples:
                break

        split_subset = split_df[split_df['Patient ID'].isin(selected_patients)]
        df_subset.append(split_subset)

    df = pd.concat(df_subset, ignore_index=True)
    print(f"Subset created: {len(df):,} images")

print(f"\nSplit distribution:")
print(df['split'].value_counts())

In [None]:
# Filter to only images with at least one disease
print("\nFiltering to images with at least one disease label...")

def has_disease(finding_labels_str):
    """Return True when any '|'-separated finding appears in LABELS."""
    return any(finding in LABELS for finding in finding_labels_str.split("|"))

# Count before filtering
total_before = len(df)

# Evaluate the disease mask once and reuse it for both the count and the filter
disease_mask = df['Finding Labels'].apply(has_disease)
no_finding_count = (~disease_mask).sum()

# Filter
df = df[disease_mask].reset_index(drop=True)

print(f"  Before filtering: {total_before:,} images")
print(f"  'No Finding' removed: {no_finding_count:,} images")
print(f"  After filtering: {len(df):,} images with diseases")

print(f"\nNew split distribution:")
print(df['split'].value_counts())

# Verify positive label rate improved (estimated on the first 1000 rows only)
sample_findings = [row.split("|") for row in df['Finding Labels'].head(1000)]
sample_labels = np.array([
    [1 if label in findings else 0 for label in LABELS]
    for findings in sample_findings
])
new_positive_rate = sample_labels.mean() * 100
print(f"\nPositive label rate (sample): {new_positive_rate:.1f}%")

In [None]:
# Define dataset class

class ChestXrayDataset(Dataset):
    """
    Multi-label ChestX-ray14 dataset backed by a metadata DataFrame.

    Each item is an ``(image, labels)`` pair where ``labels`` is a float32
    multi-hot tensor with one entry per name in ``labels``.
    """

    def __init__(self, dataframe, image_root, labels, transform=None):
        self.df = dataframe.reset_index(drop=True)
        self.image_root = Path(image_root)
        self.labels = labels
        self.transform = transform

        # Prefer the "images_*/images" subfolders; when none match, treat the
        # root itself as the only image directory.
        nested_dirs = list(self.image_root.glob("images_*/images"))
        self.image_dirs = nested_dirs if nested_dirs else [self.image_root]

        print(f"Image directories found: {len(self.image_dirs)}")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        image_name = record['Image Index']

        # Resolve the file on disk; a missing file is a hard error
        image_path = self._find_image(image_name)
        if image_path is None:
            raise FileNotFoundError(f"Image not found: {image_name}")

        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image)

        # Multi-hot target built from the '|'-separated findings string
        target = torch.from_numpy(self._parse_labels(record['Finding Labels']))
        return image, target

    def _find_image(self, image_name):
        """Return the first existing path for ``image_name``, or None."""
        candidates = (folder / image_name for folder in self.image_dirs)
        return next((path for path in candidates if path.exists()), None)

    def _parse_labels(self, finding_labels_str):
        """Build a float32 multi-hot vector from a '|'-separated label string."""
        vector = np.zeros(len(self.labels), dtype=np.float32)
        for name in finding_labels_str.split("|"):
            if name in self.labels:
                vector[self.labels.index(name)] = 1.0
        return vector

In [None]:
# Define transforms

image_size = CONFIG["image_size"]

# ImageNet stats
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
target_size = (image_size, image_size)

# Random augmentations used for training only
augmentations = [
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.RandomAffine(degrees=0, translate=(0.05, 0.05), scale=(0.95, 1.05)),
]

train_transform = transforms.Compose(
    [transforms.Resize(target_size)]
    + augmentations
    + [
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
)

# Evaluation pipeline: resize and normalise, no randomness
eval_transform = transforms.Compose([
    transforms.Resize(target_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

print(f"Image size: {image_size}x{image_size}")
print(f"Train augmentations: flip, rotation, affine")

In [None]:
# Create datasets and dataloaders

print("\nCreating datasets...")


def _rows_for(split_name):
    """Select the metadata rows belonging to one split."""
    return df[df['split'] == split_name]


def _make_dataset(frame, transform):
    """Wrap one split's rows in a ChestXrayDataset rooted at KAGGLE_PATH."""
    return ChestXrayDataset(frame, KAGGLE_PATH, LABELS, transform=transform)


train_df = _rows_for('train')
val_df = _rows_for('val')
cal_df = _rows_for('cal')
test_df = _rows_for('test')

# Only the training split is augmented
train_dataset = _make_dataset(train_df, train_transform)
val_dataset = _make_dataset(val_df, eval_transform)
cal_dataset = _make_dataset(cal_df, eval_transform)
test_dataset = _make_dataset(test_df, eval_transform)

print(f"Train: {len(train_dataset):,} images")
print(f"Val:   {len(val_dataset):,} images")
print(f"Cal:   {len(cal_dataset):,} images")
print(f"Test:  {len(test_dataset):,} images")

# Create dataloaders
batch_size = CONFIG["batch_size"]
num_workers = 0
pin_memory = False

loader_options = dict(batch_size=batch_size, num_workers=num_workers,
                      pin_memory=pin_memory)

# Shuffle only the training loader; evaluation loaders keep dataset order
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)
cal_loader = DataLoader(cal_dataset, shuffle=False, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

print(f"\nBatch size: {batch_size}")
print(f"Train batches: {len(train_loader)}")

In [None]:
# Test data loading

# Smoke test: pull a single batch from the training loader and inspect it
print("\nTesting data loading...")
images, labels = next(iter(train_loader))
print(f"Image batch shape: {images.shape}")
print(f"Label batch shape: {labels.shape}")
print(f"Label example: {labels[0].numpy()}")
# Map the positions equal to 1 in the first label vector back to label names
print(f"Active labels: {[LABELS[i] for i in torch.where(labels[0] == 1)[0].tolist()]}")

In [None]:
# Define model

class ChestXrayClassifier(nn.Module):
    """
    ResNet-50 based multi-label classifier.

    The backbone's ``fc`` layer is replaced by ``nn.Identity`` and a separate
    head (dropout -> linear -> sigmoid) maps backbone features to one
    independent probability per class.

    Args:
        num_classes: Number of output labels.
        dropout: Dropout probability applied before the linear layer.
        pretrained: When true, load the ``IMAGENET1K_V2`` weights; otherwise
            the backbone is created without pretrained weights.
    """

    def __init__(self, num_classes=14, dropout=0.5, pretrained=True):
        super().__init__()

        # Load pretrained ResNet-50
        weights = models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
        self.backbone = models.resnet50(weights=weights)

        # Get feature dimension
        num_features = self.backbone.fc.in_features

        # Replace final layer with custom head
        self.backbone.fc = nn.Identity()

        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(num_features, num_classes),
            nn.Sigmoid()  # Multi-label: independent probabilities
        )

        # Track which layers are frozen
        self.backbone_frozen = False

    def freeze_backbone(self):
        """Disable gradients for every backbone parameter."""
        # NOTE(review): this only clears requires_grad; any BatchNorm layers in
        # the backbone presumably still update their running statistics while
        # the model is in train() mode - confirm that is intended.
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone_frozen = True
        print("Backbone frozen")

    def unfreeze_backbone(self, layers=('layer3', 'layer4')):
        """Re-enable gradients for the named direct children of the backbone.

        Args:
            layers: Names of backbone child modules to unfreeze. The default
                is an immutable tuple so it cannot be mutated between calls.
        """
        for name, module in self.backbone.named_children():
            if name in layers:
                for param in module.parameters():
                    param.requires_grad = True
                print(f"Unfroze: {name}")
        self.backbone_frozen = False

    def forward(self, x):
        """Return per-class probabilities for the input batch ``x``."""
        features = self.backbone(x)
        return self.classifier(features)

# Build the classifier from the configured hyperparameters and move it to the
# selected device (Module.to returns the module itself)
model = ChestXrayClassifier(
    num_classes=len(LABELS),
    dropout=CONFIG["dropout"],
    pretrained=CONFIG["pretrained"]
).to(device)

# Tally parameter counts, overall and for those still requiring gradients
total_params = 0
trainable_params = 0
for param in model.parameters():
    total_params += param.numel()
    if param.requires_grad:
        trainable_params += param.numel()

print(f"\nModel: ResNet-50")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

In [None]:
# Define loss function with class weights

def _label_rows_from_metadata(dataloader):
    """Return per-sample label vectors read from dataset metadata, or None.

    The shortcut is taken only when the loader's dataset exposes a ``df``
    frame with a 'Finding Labels' column plus a ``_parse_labels`` method, and
    the loader uses a plain sequential/shuffling sampler that covers every
    row once (no ``drop_last``, no sampling with replacement).
    """
    dataset = getattr(dataloader, "dataset", None)
    frame = getattr(dataset, "df", None)
    parse = getattr(dataset, "_parse_labels", None)
    if frame is None or not callable(parse) or "Finding Labels" not in frame:
        return None
    if getattr(dataloader, "drop_last", False):
        return None

    sampler = getattr(dataloader, "sampler", None)
    if sampler is not None:
        plain = (torch.utils.data.SequentialSampler, torch.utils.data.RandomSampler)
        if type(sampler) not in plain or getattr(sampler, "replacement", False):
            return None
        if len(sampler) != len(frame):
            return None

    return [parse(findings) for findings in frame["Finding Labels"]]


def compute_class_weights(dataloader, num_classes):
    """Compute per-class weights for an imbalanced multi-label dataset.

    ``weight_k = total_samples / (2 * (positives_k + 1))``, clamped to
    ``[0.5, 10.0]``.

    Labels are read straight from the dataset metadata when possible, which
    avoids loading and transforming every image just to count labels.
    Otherwise the loader is iterated batch by batch.

    Args:
        dataloader: Loader yielding ``(inputs, labels)`` batches.
        num_classes: Number of label columns.

    Returns:
        A 1-D tensor with ``num_classes`` weights.
    """
    label_counts = torch.zeros(num_classes)
    total_samples = 0

    label_rows = _label_rows_from_metadata(dataloader)
    if label_rows is not None:
        total_samples = len(label_rows)
        if label_rows:
            label_counts += torch.from_numpy(np.sum(label_rows, axis=0))
    else:
        for _, labels in tqdm(dataloader, desc="Computing class weights"):
            label_counts += labels.sum(dim=0)
            total_samples += labels.size(0)

    # Weight = total / (2 * count), capped for rare classes
    weights = total_samples / (2 * (label_counts + 1))
    weights = torch.clamp(weights, min=0.5, max=10.0)  # Prevent extreme weights

    return weights

print("Computing class weights...")
class_weights = compute_class_weights(train_loader, len(LABELS))
# Keep the weights on the same device as the model outputs they multiply
class_weights = class_weights.to(device)

print("\nClass weights:")
for label, weight in zip(LABELS, class_weights):
    print(f"  {label}: {weight:.2f}")

In [None]:
# Define weighted BCE loss

class WeightedBCELoss(nn.Module):
    """Binary cross-entropy on probabilities with weighted positive terms.

    Args:
        weights: Per-class multipliers applied to the positive-target term;
            must broadcast against the prediction tensor.
        eps: Probabilities are clamped to ``[eps, 1 - eps]`` before taking
            logarithms.
    """

    def __init__(self, weights, eps=1e-7):
        super().__init__()
        self.weights = weights
        self.eps = eps

    def forward(self, predictions, targets):
        """Return the mean weighted BCE over all elements.

        Args:
            predictions: Probabilities in ``[0, 1]``.
            targets: Binary targets with the same shape as ``predictions``.
        """
        # Clamp rather than adding eps inside the log: log(p + eps) becomes
        # positive when p saturates at 1, which yields negative loss terms.
        probs = predictions.clamp(min=self.eps, max=1.0 - self.eps)
        bce = -(
            targets * torch.log(probs) * self.weights +
            (1 - targets) * torch.log(1 - probs)
        )
        return bce.mean()

# Loss used for both training and evaluation, weighted by the class weights
# computed from the training data
criterion = WeightedBCELoss(class_weights)

In [None]:
# Define training functions

def train_epoch(model, loader, criterion, optimizer, device, grad_clip=1.0):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Module to train; switched to train mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Callable returning a scalar loss from outputs and labels.
        optimizer: Optimizer stepping the model parameters.
        device: Device each batch is moved to.
        grad_clip: Maximum gradient norm passed to ``clip_grad_norm_``.

    Returns:
        Sum of the per-batch losses divided by the number of batches.
    """
    model.train()
    running_loss = 0.0

    progress = tqdm(loader, desc="Training")
    for images, labels in progress:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()

        # Clip before stepping so the update uses the clipped gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        running_loss += loss.item()
        progress.set_postfix({'loss': f'{loss.item():.4f}'})

    return running_loss / len(loader)


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Score ``model`` on every batch of ``loader`` without gradients.

    Returns:
        Tuple ``(mean_batch_loss, predictions, labels)`` where the last two
        are the per-batch outputs and labels concatenated on the CPU.
    """
    model.eval()
    running_loss = 0.0
    pred_chunks, label_chunks = [], []

    for images, labels in tqdm(loader, desc="Evaluating"):
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        running_loss += criterion(outputs, labels).item()

        # Collect on the CPU so results are not held on the accelerator
        pred_chunks.append(outputs.cpu())
        label_chunks.append(labels.cpu())

    all_preds = torch.cat(pred_chunks)
    all_labels = torch.cat(label_chunks)

    return running_loss / len(loader), all_preds, all_labels


def compute_metrics(predictions, labels, threshold=0.5):
    """Compute macro-averaged and per-label precision/recall/F1.

    Args:
        predictions: 2-D tensor of scores; values above ``threshold`` count
            as positive predictions.
        labels: 2-D tensor of binary targets, one column per label.
        threshold: Decision threshold applied to ``predictions``.

    Returns:
        Dict with 'macro_f1', 'macro_precision', 'macro_recall' and the list
        'per_label_f1'.
    """
    preds_binary = (predictions > threshold).float()

    def _column_scores(k):
        """Return (precision, recall, f1) for label column ``k``."""
        pred_col = preds_binary[:, k]
        true_col = labels[:, k]
        predicted_pos = pred_col == 1
        actual_pos = true_col == 1

        tp = (predicted_pos & actual_pos).sum().item()
        fp = (predicted_pos & (true_col == 0)).sum().item()
        fn = ((pred_col == 0) & actual_pos).sum().item()

        # The small constant keeps the ratios defined when a count is zero
        precision = tp / (tp + fp + 1e-7)
        recall = tp / (tp + fn + 1e-7)
        f1 = 2 * precision * recall / (precision + recall + 1e-7)
        return precision, recall, f1

    scores = [_column_scores(k) for k in range(labels.shape[1])]
    precisions = [p for p, _, _ in scores]
    recalls = [r for _, r, _ in scores]
    f1_scores = [f for _, _, f in scores]

    return {
        'macro_f1': np.mean(f1_scores),
        'macro_precision': np.mean(precisions),
        'macro_recall': np.mean(recalls),
        'per_label_f1': f1_scores
    }

In [None]:
# Training loop

print("\n" + "="*60)
print("STARTING TRAINING")
print("="*60)

# Phase 1: Train head only
print("\n--- Phase 1: Training classification head (backbone frozen) ---")
model.freeze_backbone()

# Only parameters that still require gradients are handed to the optimizer
optimizer = optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=CONFIG["phase1_lr"],
    weight_decay=CONFIG["weight_decay"]
)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=1)

# Start below any reachable F1 so the first epoch always writes a checkpoint;
# starting at 0.0 with a strict ">" would leave no best_model.pth for phase 2
# whenever validation F1 stays at zero.
best_f1 = -1.0
patience_counter = 0
history = {'train_loss': [], 'val_loss': [], 'val_f1': []}

# Make sure the checkpoint directory exists before the first save
MODELS_DIR.mkdir(parents=True, exist_ok=True)

for epoch in range(CONFIG["phase1_epochs"]):
    print(f"\nEpoch {epoch+1}/{CONFIG['phase1_epochs']}")

    train_loss = train_epoch(model, train_loader, criterion, optimizer, device,
                             CONFIG["gradient_clip_norm"])
    val_loss, val_preds, val_labels = evaluate(model, val_loader, criterion, device)
    metrics = compute_metrics(val_preds, val_labels)

    scheduler.step()

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['val_f1'].append(metrics['macro_f1'])

    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss:   {val_loss:.4f}")
    print(f"  Val F1:     {metrics['macro_f1']:.4f}")

    # Checkpoint on improvement; otherwise count towards early stopping
    if metrics['macro_f1'] > best_f1:
        best_f1 = metrics['macro_f1']
        torch.save(model.state_dict(), MODELS_DIR / "best_model.pth")
        patience_counter = 0
        print(f"  ✓ New best model saved!")
    else:
        patience_counter += 1
        if patience_counter >= CONFIG["early_stopping_patience"]:
            print(f"  Early stopping triggered!")
            break

In [None]:
# Fine-tune backbone

print("\n--- Phase 2: Fine-tuning backbone layers 3-4 ---")
# Resume from the best phase-1 checkpoint. map_location keeps the load working
# when the file was written on a device that is not available now.
model.load_state_dict(torch.load(MODELS_DIR / "best_model.pth", map_location=device))
model.unfreeze_backbone(['layer3', 'layer4'])

# Rebuild the optimizer so the newly unfrozen parameters are included
optimizer = optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=CONFIG["phase2_lr"],
    weight_decay=CONFIG["weight_decay"]
)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1)

# Early-stopping patience restarts for this phase; best_f1 carries over
patience_counter = 0

for epoch in range(CONFIG["phase2_epochs"]):
    print(f"\nEpoch {CONFIG['phase1_epochs'] + epoch + 1}/{CONFIG['phase1_epochs'] + CONFIG['phase2_epochs']}")

    train_loss = train_epoch(model, train_loader, criterion, optimizer, device,
                             CONFIG["gradient_clip_norm"])
    val_loss, val_preds, val_labels = evaluate(model, val_loader, criterion, device)
    metrics = compute_metrics(val_preds, val_labels)

    scheduler.step()

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['val_f1'].append(metrics['macro_f1'])

    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss:   {val_loss:.4f}")
    print(f"  Val F1:     {metrics['macro_f1']:.4f}")

    # Only overwrite the checkpoint when this phase beats the best F1 so far
    if metrics['macro_f1'] > best_f1:
        best_f1 = metrics['macro_f1']
        torch.save(model.state_dict(), MODELS_DIR / "best_model.pth")
        patience_counter = 0
        print(f"  ✓ New best model saved!")
    else:
        patience_counter += 1
        if patience_counter >= CONFIG["early_stopping_patience"]:
            print(f"  Early stopping triggered!")
            break

print(f"\n✓ Training complete! Best F1: {best_f1:.4f}")

In [None]:
# Plot training history

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Loss plot
ax1 = axes[0]
ax1.plot(history['train_loss'], label='Train', marker='o')
ax1.plot(history['val_loss'], label='Val', marker='s')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training & Validation Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# F1 plot
ax2 = axes[1]
ax2.plot(history['val_f1'], label='Val F1', marker='o', color='green')
ax2.axhline(y=CONFIG["acceptable_f1"], color='red', linestyle='--',
            label=f'Target F1 ({CONFIG["acceptable_f1"]})')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Macro F1')
ax2.set_title('Validation F1 Score')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
# Create the output folder first; savefig fails if the directory is missing
figures_dir = RESULTS_DIR / "figures"
figures_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(figures_dir / "training_history.png", dpi=150)
plt.show()

In [None]:
# Final evaluation on test set

print("\n" + "="*60)
print("FINAL EVALUATION")
print("="*60)

# Load best model. map_location keeps the load working when the checkpoint
# was written on a device that is not available now.
model.load_state_dict(torch.load(MODELS_DIR / "best_model.pth", map_location=device))
model.eval()

# Evaluate on test set
test_loss, test_preds, test_labels = evaluate(model, test_loader, criterion, device)
test_metrics = compute_metrics(test_preds, test_labels)

print(f"\nTest Results:")
print(f"  Loss:      {test_loss:.4f}")
print(f"  Macro F1:  {test_metrics['macro_f1']:.4f}")
print(f"  Precision: {test_metrics['macro_precision']:.4f}")
print(f"  Recall:    {test_metrics['macro_recall']:.4f}")

print(f"\nPer-Label F1 Scores:")
for label, f1 in zip(LABELS, test_metrics['per_label_f1']):
    print(f"  {label:<20}: {f1:.4f}")

# Check if acceptable
if test_metrics['macro_f1'] >= CONFIG["acceptable_f1"]:
    print(f"\n✓ F1 score ({test_metrics['macro_f1']:.4f}) meets target ({CONFIG['acceptable_f1']})")
else:
    print(f"\n✗ F1 score ({test_metrics['macro_f1']:.4f}) below target ({CONFIG['acceptable_f1']})")
    print("  Consider: more epochs, larger subset, or hyperparameter tuning")

In [None]:
# Generate predictions for conformal methods

print("\n" + "="*60)
print("GENERATING PREDICTIONS FOR CONFORMAL METHODS")
print("="*60)

@torch.no_grad()
def get_predictions(model, loader, device):
    """Run ``model`` over ``loader`` and return outputs and labels as arrays.

    Returns:
        Tuple ``(predictions, labels)`` of NumPy arrays, each the per-batch
        results stacked vertically in loader order.
    """
    model.eval()
    pred_batches, label_batches = [], []

    for images, labels in tqdm(loader, desc="Generating predictions"):
        batch_outputs = model(images.to(device))
        pred_batches.append(batch_outputs.cpu().numpy())
        label_batches.append(labels.numpy())

    return np.vstack(pred_batches), np.vstack(label_batches)

# Generate predictions for all splits
print("\nTrain set...")
# The training loader shuffles and applies random augmentation, so the
# training split is scored through a deterministic loader with the evaluation
# transform instead; saved rows then follow train_df order.
train_eval_dataset = ChestXrayDataset(train_df, KAGGLE_PATH, LABELS, transform=eval_transform)
train_eval_loader = DataLoader(train_eval_dataset, batch_size=batch_size, shuffle=False,
                               num_workers=num_workers, pin_memory=pin_memory)
train_preds, train_labels = get_predictions(model, train_eval_loader, device)

print("Calibration set...")
cal_preds, cal_labels = get_predictions(model, cal_loader, device)

print("Test set...")
test_preds, test_labels_np = get_predictions(model, test_loader, device)

print(f"\nPrediction shapes:")
print(f"  Train: {train_preds.shape}")
print(f"  Cal:   {cal_preds.shape}")
print(f"  Test:  {test_preds.shape}")

In [None]:
# Save predictions

print("\nSaving predictions...")

# parents=True so a missing results directory does not abort the save
predictions_dir = RESULTS_DIR / "predictions"
predictions_dir.mkdir(parents=True, exist_ok=True)

np.save(predictions_dir / "train_preds.npy", train_preds)
np.save(predictions_dir / "train_labels.npy", train_labels)
np.save(predictions_dir / "cal_preds.npy", cal_preds)
np.save(predictions_dir / "cal_labels.npy", cal_labels)
np.save(predictions_dir / "test_preds.npy", test_preds)
np.save(predictions_dir / "test_labels.npy", test_labels_np)

print(f"✓ Predictions saved to {predictions_dir}")

# Save metrics
metrics_path = RESULTS_DIR / "metrics" / "classifier_metrics.json"
# The metrics folder is created here as well; open() would fail without it
metrics_path.parent.mkdir(parents=True, exist_ok=True)
with open(metrics_path, 'w') as f:
    json.dump({
        'test_loss': float(test_loss),
        'macro_f1': float(test_metrics['macro_f1']),
        'macro_precision': float(test_metrics['macro_precision']),
        'macro_recall': float(test_metrics['macro_recall']),
        'per_label_f1': {label: float(f1) for label, f1 in zip(LABELS, test_metrics['per_label_f1'])},
        # Path values are stringified so the config stays JSON-serialisable
        'config': {k: str(v) if isinstance(v, Path) else v for k, v in CONFIG.items()},
        'timestamp': datetime.now().isoformat()
    }, f, indent=2)
print(f"✓ Metrics saved to {metrics_path}")

In [None]:
# Summary

# Final recap: headline test metrics and the artefacts written by this notebook
print("\n" + "="*60)
print("TRAINING COMPLETE")
print("="*60)
# NOTE(review): the image count shown is the configured subset size, not the
# number of rows left after the disease filter - confirm this is the figure
# intended for the report.
print(f"""
Model: ResNet-50 (pretrained on ImageNet)
Dataset: ChestX-ray14 subset ({CONFIG['subset_size']} images)

Final Test Metrics:
  Macro F1:  {test_metrics['macro_f1']:.4f}
  Precision: {test_metrics['macro_precision']:.4f}
  Recall:    {test_metrics['macro_recall']:.4f}

Files saved:
  - {MODELS_DIR / 'best_model.pth'}
  - {predictions_dir / 'train_preds.npy'}
  - {predictions_dir / 'cal_preds.npy'}
  - {predictions_dir / 'test_preds.npy'}
  - {metrics_path}

Next step: Run notebook 03_standard_conformal.ipynb
""")