In [1]:
# Import stuff
from torchvision import transforms, datasets
import os
from PIL import Image
import numpy as np
import platform
import torch
import torch.nn as nn
import torch.optim as optim
import time
from torch.utils.data import DataLoader, random_split
from pathlib import Path

#### Run this if in Google Colab

In [2]:
def in_colab() -> bool:
    try:
        import google.colab
        return True
    except Exception:
        return False

REPO_URL = "https://github.com/Tiromachelan/pneumonia_classification.git"

if in_colab():
    if not Path("pneumonia_classification").exists():
        !rm -r 128x128_data
        !rm -r pneumonia_classification
        !git clone {REPO_URL}
        !mv pneumonia_classification/128x128_data .

rm: cannot remove '128x128_data': No such file or directory
rm: cannot remove 'pneumonia_classification': No such file or directory
Cloning into 'pneumonia_classification'...
remote: Enumerating objects: 5855, done.[K
remote: Counting objects: 100% (5855/5855), done.[K
remote: Compressing objects: 100% (5846/5846), done.[K
remote: Total 5855 (delta 8), reused 5851 (delta 6), pack-reused 0 (from 0)[K
Receiving objects: 100% (5855/5855), 13.85 MiB | 24.07 MiB/s, done.
Resolving deltas: 100% (8/8), done.


In [None]:
# Delete images in 128x128_data if needed
normal_dir = "128x128_data/NORMAL"
pneumonia_dir = "128x128_data/PNEUMONIA"

for file in os.listdir(normal_dir) + os.listdir(pneumonia_dir):
    img_path = os.path.join(normal_dir if file in os.listdir(normal_dir) else pneumonia_dir, file)
    os.remove(img_path)

In [3]:
# Select device
if torch.cuda.is_available():
    device = torch.device("cuda")
elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device.type}\n")

Using device: cuda



# Data Preparation

#### Resize the images to 128 x 128

In [None]:
# Set up paths
normal_dir = "data/NORMAL"
pneumonia_dir = "data/PNEUMONIA"

# Find which images have the smallest size
min_size = float('inf')
min_dimensions = (float('inf'), float('inf'))
smallest_img_path = ""

for image in os.listdir(normal_dir) + os.listdir(pneumonia_dir):
    if image.endswith((".jpeg", ".jpg", ".png")):
        img_path = os.path.join(normal_dir if image in os.listdir(normal_dir) else pneumonia_dir, image)
        with Image.open(img_path, "r") as img:
            width, height = img.size
            if width * height < min_size:
                min_size = width * height
                min_dimensions = (width, height)
                smallest_img_path = img_path

print(f"Smallest image dimensions: {min_dimensions}")
print(f"Smallest image path: {smallest_img_path}")

# Paths for resized images
resized_normal_dir = "128x128_data/NORMAL"
resized_pneumonia_dir = "128x128_data/PNEUMONIA"

# Resize all images to 128x128 pixels and save them
for image in os.listdir(normal_dir) + os.listdir(pneumonia_dir):
    if image.endswith((".jpeg", ".jpg", ".png")):
        img_path = os.path.join(normal_dir if image in os.listdir(normal_dir) else pneumonia_dir, image)
        with Image.open(img_path, "r") as img:
            img = img.convert("L")  # Convert to grayscale
            width, height = img.size
            if width > height:
                cropped_width = 128
                cropped_height = int(height * 128 / width)
            else:
                cropped_height = 128
                cropped_width = int(width * 128 / height)
            img = img.resize((cropped_width, cropped_height)) # Resize preserving ratio
            left = 0
            upper = (cropped_height - 128) // 2
            right = 128
            lower = upper + 128
            img = img.crop((left, upper, right, lower)) # Center crop

            if "NORMAL" in img_path:
                save_path = os.path.join(resized_normal_dir, image)
            else:
                save_path = os.path.join(resized_pneumonia_dir, image)
            img.save(save_path)

# ~30 seconds

#### Ensure that all of the images are the same shape

In [None]:
# Convert each image to a tensor and ensure they are all 1 x 128 x 128
counter = 0
for image in os.listdir(resized_normal_dir) + os.listdir(resized_pneumonia_dir):
    if image.endswith((".jpeg", ".jpg", ".png")):
        img_path = os.path.join(resized_normal_dir if image in os.listdir(resized_normal_dir) else resized_pneumonia_dir, image)
        with Image.open(img_path, "r") as img:
            img_tensor = transforms.ToTensor()(img)
            if img_tensor.shape != (1, 128, 128):
                print(f"{image} has shape {img_tensor.shape}")
                counter += 1
print(f"{counter} images with incorrect shape")

#### Load the data

In [5]:
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

dataset = datasets.ImageFolder(
    root="128x128_data",
    transform=transform
)

print(f"Classes: {dataset.classes}")

# Split data
train_size = int(0.8 * len(dataset))
test_size = int(0.1 * len(dataset))
val_size = len(dataset) - train_size - test_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

print(f"Train size: {len(train_dataset)}")
print(f"Val size: {len(val_dataset)}")
print(f"Test size: {len(test_dataset)}")

# DataLoaders
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Classes: ['NORMAL', 'PNEUMONIA']
Train size: 4684
Val size: 587
Test size: 585


# Fully Connected Neural Network

#### Model

In [6]:
# Define the network
class MLP(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_dim, 2048),
            nn.ReLU(),
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Linear(512, 2)
        )
    def forward(self, x):
        return self.net(x)


#### Parameters

In [7]:
# Define the training parameters
model = MLP(128*128).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.1)
epochs = 100
history = {
    "train_loss": [], "train_acc": [],
    "val_loss":   [], "val_acc":   []
}

# Calculate accuracy from logits
def accuracy_from_logits(logits, y):
    preds = logits.argmax(1)  # choose class with highest predicted score
    return (preds == y).float().mean().item()  # fraction of correct predictions


#### Training

In [8]:
train_start_time = time.time()
for epoch in range(epochs):
    epoch_start_time = time.time()

    # Training
    model.train()
    running_loss, running_correct, total = 0.0, 0, 0

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad(set_to_none=True)
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * x.size(0)
        running_correct += (logits.argmax(1) == y).sum().item()
        total += x.size(0)

    train_loss = running_loss / total
    train_acc = running_correct / total

    # Validation
    model.eval()
    val_running_loss, val_running_correct, val_total = 0.0, 0, 0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criterion(logits, y)
            val_running_loss += loss.item() * x.size(0)
            val_running_correct += (logits.argmax(1) == y).sum().item()
            val_total += x.size(0)

    val_loss = val_running_loss / val_total
    val_acc  = val_running_correct / val_total

    history["train_loss"].append(train_loss)
    history["train_acc"].append(train_acc)
    history["val_loss"].append(val_loss)
    history["val_acc"].append(val_acc)

    epoch_time = time.time() - epoch_start_time
    print(f"Epoch {epoch:02d} | "
          f"train: loss={train_loss:.4f}, acc={train_acc:.4f} | "
          f"val: loss={val_loss:.4f}, acc={val_acc:.4f} | "
          f"time: {epoch_time:.2f}s")

# Total training time
total_time = time.time() - train_start_time
print(f"\nTotal training time: {total_time:.2f}s")

# Save the trained model
torch.save({
    "model_state": model.state_dict(),
    "history": history
}, "mlp_relu_128x128.pth")

# 5:41 on MPS
# 7:50 on CPU
# 6:20 on L4

Epoch 00 | train: loss=14374.0622, acc=0.7205 | val: loss=168.0005, acc=0.7496 | time: 4.72s
Epoch 01 | train: loss=874.1040, acc=0.7319 | val: loss=2.5472, acc=0.7428 | time: 3.86s
Epoch 02 | train: loss=3.4634, acc=0.7316 | val: loss=0.5656, acc=0.7445 | time: 3.78s
Epoch 03 | train: loss=0.5681, acc=0.7316 | val: loss=0.5349, acc=0.7445 | time: 3.82s
Epoch 04 | train: loss=0.7637, acc=0.7316 | val: loss=0.5702, acc=0.7445 | time: 3.81s
Epoch 05 | train: loss=0.5831, acc=0.7316 | val: loss=0.5683, acc=0.7445 | time: 3.84s
Epoch 06 | train: loss=0.5862, acc=0.7316 | val: loss=0.5724, acc=0.7445 | time: 3.81s
Epoch 07 | train: loss=0.5837, acc=0.7316 | val: loss=0.5685, acc=0.7445 | time: 3.80s
Epoch 08 | train: loss=0.5847, acc=0.7316 | val: loss=0.5738, acc=0.7445 | time: 3.81s
Epoch 09 | train: loss=0.5821, acc=0.7316 | val: loss=0.5763, acc=0.7445 | time: 3.88s
Epoch 10 | train: loss=0.5837, acc=0.7316 | val: loss=0.5690, acc=0.7445 | time: 3.84s
Epoch 11 | train: loss=0.5874, acc=