In [None]:
# Install dependencies (run this cell once if you haven't already).
# `%pip` (rather than `!pip`) installs into the environment of the running kernel.
# NOTE(review): versions are unpinned; pin them (e.g. `torch==X.Y.Z`) for reproducible runs.
# The `ResNet18_Weights` API used below needs a reasonably recent torchvision — confirm the installed version.
%pip install torch torchvision pandas matplotlib seaborn scikit-learn numpy

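# X-Ray Classification with PyTorch
A minimal chest X-ray image classifier. Images are loaded with `torchvision.datasets.ImageFolder` (one sub-folder per class). An ImageNet-pretrained ResNet-18 is then fine-tuned with a new classification head. The notebook reports per-epoch training loss and test accuracy, plus a confusion matrix on the test set.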

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision.models import resnet18, ResNet18_Weights
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Reproducibility: seed the RNGs behind the new head's weight init, DataLoader shuffling
# and the random training augmentations. This makes CPU runs repeatable; some CUDA/MPS
# kernels can still be nondeterministic.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Set device
# Prefer CUDA (NVIDIA), then MPS (Apple Silicon), otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

In [None]:
# 1. Transforms
# ImageNet channel statistics: the pretrained backbone was trained on inputs normalised this
# way, so reusing them helps transfer learning.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def get_transforms(split="train", image_size=224):
    """Return the torchvision preprocessing pipeline for a dataset split.

    Args:
        split: ``"train"`` adds random augmentation (horizontal flip, rotation up to 15
            degrees, brightness/contrast jitter). Any other value (e.g. ``"test"``) returns
            the deterministic evaluation pipeline.
        image_size: side length in pixels that images are resized to (square, aspect ratio
            not preserved). Defaults to 224.

    Returns:
        A ``torchvision.transforms.Compose`` that maps a PIL image to a normalised tensor.
    """
    resize = torchvision.transforms.Resize((image_size, image_size))
    to_normalized_tensor = [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]

    if split == "train":
        # Augmentations operate on the PIL image, i.e. before ToTensor/Normalize.
        # NOTE(review): horizontal flips mirror anatomy (e.g. heart side) — confirm this
        # augmentation is appropriate for the labels being predicted.
        augment = [
            torchvision.transforms.RandomHorizontalFlip(p=0.5),
            torchvision.transforms.RandomRotation(15),
            torchvision.transforms.ColorJitter(brightness=0.2, contrast=0.2),
        ]
    else:
        augment = []

    return torchvision.transforms.Compose([resize, *augment, *to_normalized_tensor])

In [None]:
# 2. Datasets & Loaders
# ImageFolder expects structure: root/class_x/xxx.png, root/class_y/yyy.png
train_dir = 'data/chestx-ray/train'
test_dir = 'data/chestx-ray/test'

print("Loading datasets...")
trainset = ImageFolder(root=train_dir, transform=get_transforms("train"))
testset = ImageFolder(root=test_dir, transform=get_transforms("test"))

# ImageFolder derives label indices independently for each root from its sub-folder names.
# A class folder that is missing (or extra) in one split would silently shift the indices of
# later classes and corrupt every test metric, so fail fast instead.
if testset.class_to_idx != trainset.class_to_idx:
    raise ValueError(
        f"Train/test class folders differ: train={trainset.classes}, test={testset.classes}"
    )

class_names = trainset.classes
print(f"Classes: {class_names}")
print(f"Train size: {len(trainset)}, Test size: {len(testset)}")

BATCH_SIZE = 32
# Shuffle only the training data; evaluation order is kept fixed.
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# 3. Model Architecture
def model_arch(num_classes):
    """Build a ResNet-18 with torchvision's default pretrained weights and a fresh head.

    Only the final fully connected layer is replaced: it becomes a newly initialised
    ``nn.Linear`` mapping the backbone's feature width to ``num_classes`` logits.
    No layers are frozen here.

    Args:
        num_classes: number of output classes (logits).

    Returns:
        The ``nn.Module``; the caller moves it to the target device.
    """
    net = resnet18(weights=ResNet18_Weights.DEFAULT)
    head_in = net.fc.in_features
    net.fc = nn.Linear(head_in, num_classes)
    return net


# The output width follows the class folders discovered by ImageFolder.
model = model_arch(num_classes=len(class_names)).to(device)

In [None]:
# 4. Training Loop
def train_model(model, trainloader, testloader, criterion, optimizer, device, num_epochs=5):
    """Train ``model`` for ``num_epochs`` epochs and evaluate on ``testloader`` after each.

    Args:
        model: network to optimise (updated in place). This function does not move it,
            so it must already be on ``device``.
        trainloader: iterable of ``(images, labels)`` batches used for gradient updates.
        testloader: iterable of ``(images, labels)`` batches scored with ``evaluate`` after
            every epoch. It is used for monitoring only; nothing here selects or stops on it.
        criterion: loss ``criterion(outputs, labels)``. Assumed to return the batch *mean*
            (the default ``reduction='mean'``, as with ``nn.CrossEntropyLoss()`` in this
            notebook).
        optimizer: optimiser over ``model``'s parameters.
        device: device each batch is moved to.
        num_epochs: number of passes over ``trainloader``.

    Returns:
        ``(train_losses, test_accuracies)``, one entry per epoch. Each loss is the
        per-sample average over the epoch, and each accuracy is a fraction in [0, 1].
        An epoch with no training samples records a NaN loss.
    """
    train_losses = []
    test_accuracies = []

    print("Starting Training...")

    for epoch in range(num_epochs):
        # evaluate() leaves the model in eval mode, so re-enable training mode each epoch
        # (affects training-mode layers such as batch norm).
        model.train()
        loss_sum = 0.0
        n_seen = 0

        for images, labels in trainloader:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight each batch mean by its size, so a smaller final batch does not skew the
            # epoch average the way a plain mean of batch means would.
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            n_seen += batch_size

        epoch_loss = loss_sum / n_seen if n_seen else float("nan")
        accuracy, _, _ = evaluate(model, testloader, device)

        train_losses.append(epoch_loss)
        test_accuracies.append(accuracy)

        print(f"Epoch {epoch+1}/{num_epochs} | Loss: {epoch_loss:.4f} | Test Acc: {accuracy*100:.2f}%")

    return train_losses, test_accuracies

def evaluate(model, loader, device):
    """Score ``model`` on every batch of ``loader`` and return accuracy plus raw predictions.

    The model is switched to eval mode and left there. Gradients are disabled.

    Args:
        model: classifier returning logits of shape ``(batch, num_classes)``.
        loader: iterable of ``(images, labels)`` batches; ``labels`` are class-index tensors.
        device: device the images are moved to for the forward pass.

    Returns:
        ``(accuracy, all_labels, all_preds)``. ``accuracy`` is the fraction of correct
        predictions (NaN if ``loader`` yields no batches). The two NumPy arrays hold the
        true and predicted class indices in loader order.
    """
    model.eval()
    pred_batches = []
    label_batches = []

    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images.to(device))
            # Labels are only compared on the CPU, so they never need to visit `device`.
            pred_batches.append(outputs.argmax(dim=1).cpu())
            label_batches.append(labels.cpu())

    if not label_batches:
        # Nothing to score: report NaN explicitly instead of NumPy's empty-mean warning.
        empty = np.array([], dtype=np.int64)
        return float("nan"), empty, empty.copy()

    all_preds = torch.cat(pred_batches).numpy()
    all_labels = torch.cat(label_batches).numpy()
    accuracy = (all_preds == all_labels).mean()

    return accuracy, all_labels, all_preds

In [None]:
# 5. Plotting
def plot_results(train_losses, test_accuracies, all_labels, all_preds, class_names):
    """Plot per-epoch training dynamics next to the test-set confusion matrix.

    Args:
        train_losses: per-epoch training loss (as returned by ``train_model``).
        test_accuracies: per-epoch test accuracy as a fraction in [0, 1].
        all_labels: true class indices for the test set.
        all_preds: predicted class indices, aligned with ``all_labels``.
        class_names: display names, where ``class_names[i]`` names class index ``i``.
    """
    plt.style.use('ggplot')
    fig, ax = plt.subplots(1, 2, figsize=(15, 6))

    # Loss & Accuracy on twin y-axes. Epochs are numbered from 1 to match the training log.
    loss_epochs = range(1, len(train_losses) + 1)
    acc_epochs = range(1, len(test_accuracies) + 1)
    loss_line, = ax[0].plot(loss_epochs, train_losses, label='Training Loss', color='tab:orange', linewidth=2)
    ax[0].set_ylabel('Loss', color='tab:orange')
    ax2 = ax[0].twinx()
    acc_line, = ax2.plot(acc_epochs, test_accuracies, label='Test Accuracy', color='tab:blue', linewidth=2)
    ax2.set_ylabel('Accuracy', color='tab:blue')
    # The two lines live on different Axes, so build one legend from both handles.
    # Put it on the twin, which is drawn on top.
    ax2.legend(handles=[loss_line, acc_line], loc='center right')
    ax[0].set_title("Training Dynamics")
    ax[0].set_xlabel("Epochs")

    # Confusion Matrix. Pin the label set so that a class absent from both labels and
    # predictions still gets a row and column, keeping the matrix aligned with class_names.
    cm = confusion_matrix(all_labels, all_preds, labels=np.arange(len(class_names)))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax[1],
                xticklabels=class_names, yticklabels=class_names)
    ax[1].set_title("Confusion Matrix")
    ax[1].set_ylabel("Actual")
    ax[1].set_xlabel("Predicted")

    plt.tight_layout()
    plt.show()

In [None]:
# 6. Run Training
# Cross-entropy on the model's logits; Adam (lr=1e-4) over all parameters, none frozen.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

# Fit on the training split. Test accuracy is printed after every epoch for monitoring only.
losses, accuracies = train_model(
    model,
    trainloader,
    testloader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    num_epochs=5,
)

# Final pass over the test split. The per-sample labels and predictions feed the confusion matrix.
final_acc, labels, preds = evaluate(model, testloader, device=device)
print(f"Final Test Accuracy: {final_acc*100:.2f}%")

plot_results(losses, accuracies, all_labels=labels, all_preds=preds, class_names=class_names)