In [12]:
# Josh Burgess
# 300652214

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.models as models
from transformers import ViTForImageClassification
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import matplotlib.pyplot as plt

# Check for GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cpu


In [13]:
# Load pre-trained models for binary pneumonia classification

# Load ResNet-50 (CNN baseline)
print("Loading ResNet-50:")
resnet = models.resnet50(pretrained=True)
# Modify final layer for binary classification 
resnet.fc = nn.Linear(resnet.fc.in_features, 2)
resnet = resnet.to(device)

# Load ViT-Base (ViT baseline)  
print("Loading ViT-Base:")
vit = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224',
    num_labels=2,
    ignore_mismatched_sizes=True
)
vit = vit.to(device)

print(f"ResNet-50 final layer: {resnet.fc}")
print(f"ViT classifier: {vit.classifier}")

Loading ResNet-50:
Loading ViT-Base:


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ResNet-50 final layer: Linear(in_features=2048, out_features=2, bias=True)
ViT classifier: Linear(in_features=768, out_features=2, bias=True)


In [None]:
# Train a model and return training history

def train_model(model, train_loader, val_loader, num_epochs=5, learning_rate=1e-4):

    # Set up loss function and optimiser
    criterion = nn.CrossEntropyLoss()
    optimiser = optim.Adam(model.parameters(), lr=learning_rate)
    
    # Store training metrics over time
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        
        # Process each batch of training data
        for batch_idx, (images, labels) in enumerate(train_loader):
            # Move data to GPU/CPU device
            images, labels = images.to(device), labels.to(device)
            
            optimiser.zero_grad()

            if hasattr(model, 'classifier') and 'ViT' in str(type(model)):
                # ViT returns ImageClassifierOutput, need to extract logits
                outputs = model(images).logits
                
            else:
                # ResNet returns tensor directly
                outputs = model(images)

            loss = criterion(outputs, labels) # Calc loss between predictions and true labels
            loss.backward() # Backward pass (calc gradients)
            optimiser.step() # Update model weights based on gradients
            
            train_loss += loss.item()

            # Calc training accuracy
            _, predicted = torch.max(outputs, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()
            
            if batch_idx % 5 == 0:  # Print every 5 batches
                print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}')

        
        # Calculate training accuracy for this epoch
        train_acc = 100 * train_correct / train_total
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        
        # Evaluate without updating weights
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)

                if hasattr(model, 'classifier') and 'ViT' in str(type(model)):
                    # ViT returns ImageClassifierOutput, need to extract logits
                    outputs = model(images).logits

                    
                else:
                    # ResNet returns tensor directly
                    outputs = model(images)

                loss = criterion(outputs, labels)
                val_loss += loss.item()
                
                # Calculate accuracy
                max_values, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        # Calculate average metrics for this epoch
        train_loss /= len(train_loader)
        val_loss /= len(val_loader)
        val_acc = 100 * correct / total
        
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)
        
        print(f'Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
    
    return history

In [21]:
# Import
from data_utils import create_splits, get_train_val_splits, get_transforms
from dataset import create_data_loaders, create_subset_loader

# Load and split your data
print("Loading chest X-ray data...")
normal_files, pneumonia_files = create_splits()
normal_train, normal_val, pneumonia_train, pneumonia_val = get_train_val_splits(normal_files, pneumonia_files)

# Get transforms
train_transform, val_transform = get_transforms()

# Create data loaders
train_loader, val_loader = create_data_loaders(
    normal_train, normal_val, pneumonia_train, pneumonia_val,
    train_transform, val_transform, batch_size=32
)

# Create 500-image subset for testing
subset_loader = create_subset_loader(normal_train, pneumonia_train, train_transform)

print(f"Training samples: {len(normal_train) + len(pneumonia_train)}")
print(f"Validation samples: {len(normal_val) + len(pneumonia_val)}")
print(f"Subset samples: {len(subset_loader.dataset)}")

Loading chest X-ray data...
Training samples: 4185
Validation samples: 1047
Subset samples: 500


In [22]:
# Test training on small subset
print("Training ResNet-50 on 500-image subset")
resnet_history = train_model(resnet, subset_loader, subset_loader, num_epochs=2, learning_rate=1e-4)

print("\nTraining ViT on 500-image subset") 
vit_history = train_model(vit, subset_loader, subset_loader, num_epochs=2, learning_rate=1e-4)

print(f"ResNet final validation accuracy: {resnet_history['val_acc'][-1]:.2f}%")
print(f"ViT final validation accuracy: {vit_history['val_acc'][-1]:.2f}%")

Training ResNet-50 on 500-image subset
Epoch 1/2, Batch 0/16, Loss: 0.0500
Epoch 1/2, Batch 5/16, Loss: 0.0123
Epoch 1/2, Batch 10/16, Loss: 0.0089
Epoch 1/2, Batch 15/16, Loss: 0.0019


KeyError: 'train_acc'