# Chest X-Ray Pneumonia Classification

This notebook provides an interactive interface for training and evaluating pneumonia classification models.

## Setup

Make sure you have:
1. Downloaded the dataset using `data/download_data.sh`
2. Installed all requirements: `pip install -r requirements.txt`

## 1. Import Libraries

In [None]:
import sys
import os

# Make the project's `../src` directory importable; presumably this is where
# the `data` and `models` packages imported below live (confirm against the
# repository layout). The path is relative, so it is resolved against the
# kernel's current working directory, and it must run before those imports.
sys.path.append('../src')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm

from data import get_data_loaders, ChestXRayDataset, get_transforms
from models.resnet_cbam import get_resnet_cbam
from models.multimodal import get_multimodal_model

# Plot appearance: matplotlib style sheet first, then the seaborn colour palette.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')

# Compute device: CUDA when PyTorch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

## 2. Data Exploration

In [None]:
# Data-loading settings shared by the rest of the notebook.
DATA_DIR = '../data/chest_xray'
BATCH_SIZE = 32
IMG_SIZE = 224
NUM_WORKERS = 4

# Build the train / validation / test loaders with the project helper.
print('Loading data...')
loader_kwargs = dict(
    data_dir=DATA_DIR,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    img_size=IMG_SIZE,
)
train_loader, val_loader, test_loader = get_data_loaders(**loader_kwargs)

# Report how many samples each split holds.
num_train_samples = len(train_loader.dataset)
num_val_samples = len(val_loader.dataset)
num_test_samples = len(test_loader.dataset)
print(f'Training samples: {num_train_samples}')
print(f'Validation samples: {num_val_samples}')
print(f'Test samples: {num_test_samples}')

In [None]:
# Visualize sample images
def denormalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo per-channel normalization so an image tensor can be visualized.

    Computes ``tensor * std + mean`` channel-wise and returns the result as a
    new tensor; the input is never modified.

    Args:
        tensor: Image tensor whose third-from-last dimension is the channel
            axis, i.e. ``(C, H, W)`` or a batch ``(N, C, H, W)``.
        mean: Per-channel means to add back. The defaults are assumed to match
            the normalization applied by the data pipeline - confirm against
            ``get_transforms``.
        std: Per-channel standard deviations to multiply by (same caveat).

    Returns:
        A new tensor with the same shape as ``tensor``. Non-floating-point
        inputs are promoted to the default floating dtype.

    Raises:
        ValueError: If ``tensor`` has fewer than three dimensions.
    """
    if tensor.dim() < 3:
        raise ValueError(
            f'Expected a (C, H, W) or (N, C, H, W) tensor, got shape {tuple(tensor.shape)}'
        )

    num_channels = tensor.shape[-3]
    # Only as many channels as there are statistics get denormalized; any
    # remaining channels pass through unchanged (scale 1, shift 0).
    num_stats = min(num_channels, len(mean), len(std))

    stat_dtype = tensor.dtype if tensor.is_floating_point() else torch.get_default_dtype()
    scale = torch.ones(num_channels, dtype=stat_dtype, device=tensor.device)
    shift = torch.zeros(num_channels, dtype=stat_dtype, device=tensor.device)
    scale[:num_stats] = torch.as_tensor(std[:num_stats], dtype=stat_dtype, device=tensor.device)
    shift[:num_stats] = torch.as_tensor(mean[:num_stats], dtype=stat_dtype, device=tensor.device)

    # Shape (C, 1, 1) broadcasts over H and W and over any leading batch dims,
    # so batched input is handled per channel rather than per sample.
    return tensor * scale.view(-1, 1, 1) + shift.view(-1, 1, 1)

# Get a batch of training data
images, labels = next(iter(train_loader))

# Display images
class_names = ['NORMAL', 'PNEUMONIA']
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.ravel()

# Never index past the end of the batch: with a batch smaller than the grid
# (e.g. BATCH_SIZE < 8) a fixed range(8) would raise IndexError.
num_shown = min(len(axes), len(images))

for i in range(num_shown):
    img = denormalize(images[i])
    # Channels-first -> channels-last, which is the layout imshow expects.
    img = img.permute(1, 2, 0).numpy()
    # Denormalized values can fall slightly outside [0, 1]; clip for display.
    img = np.clip(img, 0, 1)

    axes[i].imshow(img)
    # int() turns the label (possibly a 0-d tensor) into a plain list index.
    axes[i].set_title(f'Class: {class_names[int(labels[i])]}')
    axes[i].axis('off')

# Blank out any panels that did not receive an image.
for unused_ax in axes[num_shown:]:
    unused_ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Class distribution
def get_class_distribution(dataset, num_classes=2):
    """Count how many samples of each class a dataset contains.

    Args:
        dataset: Iterable yielding ``(sample, label)`` pairs. Labels may be
            Python ints, NumPy integers or 0-d tensors; each is coerced with
            ``int()`` before counting.
        num_classes: Number of classes pre-seeded with a zero count, so every
            expected class appears in the result even when it has no samples.

    Returns:
        dict mapping integer class index to sample count. Keys
        ``0 .. num_classes - 1`` are always present; labels outside that range
        are counted under their own key rather than raising ``KeyError``.
    """
    class_counts = {class_idx: 0 for class_idx in range(num_classes)}
    # NOTE: iterating the dataset calls its __getitem__ for every sample,
    # which presumably loads and transforms each image - this can be slow on
    # large datasets.
    for _, label in dataset:
        # Tensors hash by identity, so a raw 0-d tensor label would never
        # match the integer keys; normalize to a plain int first.
        label = int(label)
        class_counts[label] = class_counts.get(label, 0) + 1
    return class_counts

# Count the labels in each split.
train_dist = get_class_distribution(train_loader.dataset)
val_dist = get_class_distribution(val_loader.dataset)
test_dist = get_class_distribution(test_loader.dataset)

# One bar chart per split, side by side.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
split_dists = [train_dist, val_dist, test_dist]
split_titles = ['Train', 'Validation', 'Test']

for panel_idx in range(len(split_dists)):
    ax = axes[panel_idx]
    dist = split_dists[panel_idx]
    title = split_titles[panel_idx]

    ax.bar(class_names, [dist[0], dist[1]])
    ax.set_title(f'{title} Set Distribution')
    ax.set_ylabel('Count')
    ax.set_xlabel('Class')

    # Write each count just above its bar (bars follow the dict's key order).
    for i, v in enumerate(dist.values()):
        ax.text(i, v, str(v), ha='center', va='bottom')

plt.tight_layout()
plt.show()

## 3. Model Training

In [None]:
# Hyper-parameters and model selection
MODEL_TYPE = 'resnet_cbam'  # or 'multimodal'
NUM_CLASSES = 2
EPOCHS = 20
LEARNING_RATE = 0.001
WEIGHT_DECAY = 1e-4

# Reject unsupported model types up front, then build the requested one.
if MODEL_TYPE not in ('resnet_cbam', 'multimodal'):
    raise ValueError(f'Unknown model type: {MODEL_TYPE}')

if MODEL_TYPE == 'resnet_cbam':
    model = get_resnet_cbam(num_classes=NUM_CLASSES, pretrained=True)
else:
    model = get_multimodal_model(num_classes=NUM_CLASSES, pretrained_vision=True)

# Move the weights to the selected device and report the model size.
model = model.to(device)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Created {MODEL_TYPE} model')
print(f'Total parameters: {total_params:,}')
print(f'Trainable parameters: {trainable_params:,}')

In [None]:
# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Halve the learning rate once the monitored metric has failed to improve for
# more than `patience` epochs. mode='min' because the training loop steps the
# scheduler with the validation loss.
# The `verbose` argument is intentionally not passed: it is deprecated since
# PyTorch 2.2 (newer releases may reject it). Inspect the current rate with
# `optimizer.param_groups[0]['lr']` or `scheduler.get_last_lr()` instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

In [None]:
# Training function
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Train the model for one full pass over ``train_loader``.

    Args:
        model: Network to optimize; switched to training mode here.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called once per batch.
        device: Device the batches are moved to before the forward pass.

    Returns:
        tuple ``(mean_loss, accuracy)``: the average of the per-batch loss
        values and the accuracy in percent over all samples seen. Returns
        ``(0.0, 0.0)`` if the loader yields no batches.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    # Count batches ourselves: tqdm only refreshes `pbar.n` periodically
    # (every `miniters` iterations / `mininterval` seconds), so dividing by
    # `pbar.n + 1` can use a stale count and misreport the running loss.
    num_batches = 0

    pbar = tqdm(train_loader, desc='Training')
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        num_batches += 1
        running_loss += loss.item()
        # Predicted class = index of the largest score along dim 1.
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        pbar.set_postfix({
            'loss': running_loss / num_batches,
            'acc': 100. * correct / total
        })

    # An empty loader would otherwise end in a ZeroDivisionError.
    if num_batches == 0:
        return 0.0, 0.0

    return running_loss / num_batches, 100. * correct / total

def validate(model, val_loader, criterion, device):
    """Evaluate the model on ``val_loader`` without updating any weights.

    Args:
        model: Network to evaluate; switched to eval mode here.
        val_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        tuple ``(mean_loss, accuracy)``: the average of the per-batch loss
        values and the accuracy in percent over all samples seen. Returns
        ``(0.0, 0.0)`` if the loader yields no batches.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    # Count batches ourselves: tqdm only refreshes `pbar.n` periodically
    # (every `miniters` iterations / `mininterval` seconds), so dividing by
    # `pbar.n + 1` can use a stale count and misreport the running loss.
    num_batches = 0

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        pbar = tqdm(val_loader, desc='Validation')
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            num_batches += 1
            running_loss += loss.item()
            # Predicted class = index of the largest score along dim 1.
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            pbar.set_postfix({
                'loss': running_loss / num_batches,
                'acc': 100. * correct / total
            })

    # An empty loader would otherwise end in a ZeroDivisionError.
    if num_batches == 0:
        return 0.0, 0.0

    return running_loss / num_batches, 100. * correct / total

In [None]:
# Training loop
# Per-epoch history, consumed by the training-curve plots further down.
train_losses = []
val_losses = []
train_accs = []
val_accs = []
best_val_acc = 0.0

# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before the first "best model" is written.
checkpoint_dir = '../checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)
best_model_path = f'{checkpoint_dir}/best_model_{MODEL_TYPE}.pth'

for epoch in range(1, EPOCHS + 1):
    print(f'\nEpoch {epoch}/{EPOCHS}')
    print('='*60)

    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    train_losses.append(train_loss)
    train_accs.append(train_acc)

    # Validate
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    # Update learning rate (the plateau scheduler monitors validation loss)
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']

    # Print epoch summary
    print(f'\nTrain Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
    print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
    # Report the rate explicitly so scheduler reductions are visible in the log.
    print(f'Learning rate: {current_lr:.2e}')

    # Save best model (selected on validation accuracy, strict improvement only)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'val_loss': val_loss,
        }, best_model_path)
        print(f'Best model saved with validation accuracy: {val_acc:.2f}%')

print(f'\nTraining completed! Best validation accuracy: {best_val_acc:.2f}%')

In [None]:
# Plot training curves: loss on the left panel, accuracy on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Each entry: (axes, train series, val series, train label, val label, y-label, title)
curve_panels = [
    (ax1, train_losses, val_losses,
     'Train Loss', 'Val Loss', 'Loss', 'Training and Validation Loss'),
    (ax2, train_accs, val_accs,
     'Train Accuracy', 'Val Accuracy', 'Accuracy (%)', 'Training and Validation Accuracy'),
]

for panel_ax, train_series, val_series, train_label, val_label, y_label, panel_title in curve_panels:
    panel_ax.plot(train_series, label=train_label)
    panel_ax.plot(val_series, label=val_label)
    panel_ax.set_xlabel('Epoch')
    panel_ax.set_ylabel(y_label)
    panel_ax.set_title(panel_title)
    panel_ax.legend()
    panel_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 4. Model Evaluation

In [None]:
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc

# Load best model
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written during a CUDA session can still be loaded when the
# notebook is re-run on a CPU-only machine.
checkpoint = torch.load(f'../checkpoints/best_model_{MODEL_TYPE}.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Evaluate on test set: collect predictions, ground truth and class probabilities
all_preds = []
all_labels = []
all_probs = []

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc='Evaluating'):
        images = images.to(device)
        outputs = model(images)
        # Softmax over dim 1 turns the raw scores into per-class probabilities.
        probs = torch.softmax(outputs, dim=1)
        _, predicted = outputs.max(1)

        # Labels were never moved to the device, so only the model outputs
        # need .cpu() before conversion to NumPy.
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.numpy())
        all_probs.extend(probs.cpu().numpy())

all_preds = np.array(all_preds)
all_labels = np.array(all_labels)
all_probs = np.array(all_probs)

# Print classification report
print('\nClassification Report:')
print('='*60)
print(classification_report(all_labels, all_preds, target_names=class_names))

In [None]:
# Confusion matrix of test-set predictions (rows: true class, columns: predicted class)
cm = confusion_matrix(all_labels, all_preds)

cm_fig, cm_ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    ax=cm_ax,
)
cm_ax.set_ylabel('True Label')
cm_ax.set_xlabel('Predicted Label')
cm_ax.set_title('Confusion Matrix')
cm_fig.tight_layout()
plt.show()

In [None]:
# ROC curve, scored with the probability assigned to class index 1
positive_scores = all_probs[:, 1]
fpr, tpr, _ = roc_curve(all_labels, positive_scores)
roc_auc = auc(fpr, tpr)

roc_fig, roc_ax = plt.subplots(figsize=(8, 6))
roc_ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
# Diagonal reference line: the expected curve of a random classifier.
roc_ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
roc_ax.set_xlim([0.0, 1.0])
roc_ax.set_ylim([0.0, 1.05])
roc_ax.set_xlabel('False Positive Rate')
roc_ax.set_ylabel('True Positive Rate')
roc_ax.set_title('Receiver Operating Characteristic (ROC) Curve')
roc_ax.legend(loc="lower right")
roc_ax.grid(alpha=0.3)
roc_fig.tight_layout()
plt.show()

print(f'\nROC-AUC Score: {roc_auc:.4f}')

## 5. Visualize Predictions

In [None]:
# Visualize some predictions
model.eval()
images, labels = next(iter(test_loader))
images_device = images.to(device)

with torch.no_grad():
    outputs = model(images_device)
    probs = torch.softmax(outputs, dim=1)
    _, predicted = outputs.max(1)

# Bring the results back to the CPU so they can be compared with the CPU
# labels and used as plain Python indices without relying on implicit
# cross-device scalar handling.
probs = probs.cpu()
predicted = predicted.cpu()

# Display predictions
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.ravel()

# Never index past the end of the batch: with a batch smaller than the grid
# (e.g. BATCH_SIZE < 8) a fixed range(8) would raise IndexError.
num_shown = min(len(axes), len(images))

for i in range(num_shown):
    img = denormalize(images[i])
    # Channels-first -> channels-last, which is the layout imshow expects.
    img = img.permute(1, 2, 0).numpy()
    img = np.clip(img, 0, 1)

    axes[i].imshow(img)

    true_idx = int(labels[i])
    pred_idx = int(predicted[i])
    true_label = class_names[true_idx]
    pred_label = class_names[pred_idx]
    # Confidence = softmax probability of the predicted class, in percent.
    confidence = probs[i, pred_idx].item() * 100

    # Green title for a correct prediction, red for a wrong one.
    color = 'green' if true_idx == pred_idx else 'red'
    axes[i].set_title(f'True: {true_label}\nPred: {pred_label} ({confidence:.1f}%)', color=color)
    axes[i].axis('off')

# Blank out any panels that did not receive an image.
for unused_ax in axes[num_shown:]:
    unused_ax.axis('off')

plt.tight_layout()
plt.show()

## 6. Save Model for Inference

In [None]:
# Save final model: the model's current weights bundled with the metadata
# needed to rebuild it for inference. In a top-to-bottom run these are the
# weights reloaded from the best checkpoint in the evaluation section.
final_save_path = f'../checkpoints/final_model_{MODEL_TYPE}.pth'
final_bundle = dict(
    model_state_dict=model.state_dict(),
    model_type=MODEL_TYPE,
    num_classes=NUM_CLASSES,
    img_size=IMG_SIZE,
    class_names=class_names,
)
torch.save(final_bundle, final_save_path)

print(f'Model saved to {final_save_path}')