## Imports and Defines

In [1]:
import io
import os
import csv
import sys
import tqdm
import torch
import zipfile
import requests
import numpy as np
from torch import nn
from itertools import product
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

## Hyperparameters

In [2]:
IMAGE_SIZE = [512]
EPOCHS = [2, 5]
LR = [0.005, 0.01]
MOMENTUM = [0.2, 0.9]
WEIGHT_DECAY = [0.0005, 0.001]
BATCH_SIZE = [25]

keys = ["image_size", "epochs", "lr", "momentum", "weight_decay", "batch_size"]
combos = product(IMAGE_SIZE, EPOCHS, LR, MOMENTUM, WEIGHT_DECAY, BATCH_SIZE)

combos = [dict(zip(keys, combo)) for combo in combos]


## Database Setups and Images Transforms

In [None]:
def db_setup(img_size):
    if "google.colab" in sys.modules:
        from google.colab import drive
        drive.mount('/content/drive')
        BASE_PATH = "/content/drive/MyDrive/chest_xray/processed/"
    else:
        BASE_PATH = os.path.join(os.getcwd(), "chest_xray/processed/")


    db_train = datasets.ImageFolder(root=BASE_PATH+'train', transform=None)
    db_val = datasets.ImageFolder(root=BASE_PATH+'val', transform=None)
    db_test = datasets.ImageFolder(root=BASE_PATH+'test', transform=None)


    db_train.transform = transforms.Compose([
        transforms.Grayscale(1),
        transforms.RandomResizedCrop(size=[img_size, img_size], scale=(0.5,1.)),
        transforms.RandomHorizontalFlip(0.5),
        transforms.ToTensor(),
        transforms.Normalize(             # Normalize using ImageNet's mean and standard deviation
            mean=0.485,
            std=0.225
        )
    ])


    db_val.transform = transforms.Compose([
        transforms.Grayscale(1),
        transforms.Resize([img_size, img_size]),
        # Resize the short side of the image to 256
        transforms.CenterCrop([img_size, img_size]),       # Crop a center patch of the image of size 224x224
        transforms.ToTensor(),            # Convert the image to tensor format
        transforms.Normalize(             # Normalize using ImageNet's mean and standard deviation
            mean= 0.406,
            std=0.225
        )
    ])

    db_test.transform = transforms.Compose([
        transforms.Grayscale(1),
        transforms.Resize([img_size, img_size]),           # Resize the short side of the image to 256
        transforms.CenterCrop([img_size, img_size]),       # Crop a center patch of the image of size 224x224
        transforms.ToTensor(),            # Convert the image to tensor format
        transforms.Normalize(             # Normalize using ImageNet's mean and standard deviation
            mean=0.485,
            std=0.229
        )
    ])


    return db_train, db_val, db_test

## Training and Testing Funcitons

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def fit_one_epoch(model, opt, loader):
    model.train(True)
    loss = nn.CrossEntropyLoss()
    losses, accuracies = [], []
    for images, labels in tqdm.tqdm(loader):
        images = images.to(device)
        labels = labels.to(device)

        pred = model(images)
        l = loss(pred, labels)
        acc = (pred.argmax(1) == labels).float().mean()

        l.backward()
        opt.step()
        opt.zero_grad()

        losses.append(l.detach().item())
        accuracies.append(acc.detach().item())
    return np.mean(losses), np.mean(accuracies)


@torch.no_grad()
def eval(model, loader):
    model.train(False)
    loss = nn.CrossEntropyLoss()
    accuracies, losses = [], []
    all_preds, all_labels = [], []
    for images, labels in tqdm.tqdm(loader):
        images = images.to(device)
        labels = labels.to(device)

        pred = model(images)

        all_preds.extend(pred.argmax(1).cpu().tolist())
        all_labels.extend(labels.tolist())

        acc = (pred.argmax(1) == labels).float().mean()
        l = loss(pred, labels)

        accuracies.append(acc.detach().item())
        losses.append(l.detach().item())
    return np.mean(losses), np.mean(accuracies), all_preds, all_labels

@torch.no_grad()
def plot_cm(all_labels, all_preds):
    cm = confusion_matrix(all_labels, all_preds)

    # Plot
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot(cmap='gray')
    plt.title("Confusion Matrix")
    plt.show()

    return cm

def fit(model, loader_train, loader_val, epochs=50, opt=None):
    assert opt is not None
    hist_tr_loss, hist_val_loss, hist_tr_acc, hist_val_acc = [], [], [], []
    for epoch in range(epochs + 1):
        tr_l, tr_acc = fit_one_epoch(model, opt, loader_train)
        val_l, val_acc, p, l = eval(model, loader_val)

        print(f"Finished epoch {epoch + 1} of {epochs}: Train Loss = {tr_l:.3f}  Val Loss = {val_l:.3f}   Train Acc = {tr_acc:.3f}   Val Acc = {val_acc:.3f}", flush=True)
        hist_tr_loss.append(tr_l)
        hist_val_loss.append(val_l)
        hist_tr_acc.append(tr_acc)
        hist_val_acc.append(val_acc)
    return hist_tr_loss, hist_val_loss, hist_tr_acc, hist_val_acc


def plot_training_history(combo_stats):
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    for combo, stats in combo_stats.items():
        c = dict(combo)
        label = f"lr={c['LR']}, mom={c['MOMENTUM']}, bs={c['BATCH_SIZE']}"
        plt.plot(stats["hist_tr_acc"], label=f"{label} train")
        plt.plot(stats["hist_val_acc"], label=f"{label} val", linestyle="--")
    plt.ylim([0.4, 1.05])
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.title("Training vs Validation Accuracy")
    plt.legend(fontsize=7)

    plt.subplot(1, 2, 2)
    for combo, stats in combo_stats.items():
        c = dict(combo)
        label = f"lr={c['LR']}, mom={c['MOMENTUM']}, bs={c['BATCH_SIZE']}"
        plt.plot(stats["hist_tr_loss"], label=f"{label} train")
        plt.plot(stats["hist_val_loss"], label=f"{label} val", linestyle="--")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training vs Validation Loss")
    plt.legend(fontsize=7)

    plt.tight_layout()
    plt.show()


## Fitting Different Hyperparameters

In [None]:
combo_stats = {}

for combo in combos:

    combo_key = tuple(combo.items())


    db_train, db_val, db_test = db_setup(combo["image_size"])

    model = models.resnet18()
    model.conv1 = nn.Conv2d(
        in_channels=1,
        out_channels=64,
        kernel_size=(7, 7),
        stride=(2, 2),
        padding=(3, 3),
        bias=False
    )

    loader_train = DataLoader(db_train, batch_size=combo["batch_size"], shuffle=True, drop_last=True)
    loader_val = DataLoader(db_val, batch_size=combo["batch_size"], shuffle=False)
    loader_test = DataLoader(db_test, batch_size=combo["batch_size"], shuffle=False)

    model = model.to(device)

    opt = torch.optim.SGD(model.parameters(), lr=combo["lr"], momentum=combo["momentum"], weight_decay=combo["weight_decay"])

    hist_tr_loss, hist_val_loss, hist_tr_acc, hist_val_acc = fit(model, loader_train, loader_val, epochs=combo["epochs"], opt=opt)

    combo_stats[combo_key] = {
        "hist_tr_loss"   : hist_tr_loss,
        "hist_val_loss" : hist_val_loss,
        "hist_tr_acc"    : hist_tr_acc,
        "hist_val_acc"   : hist_val_acc
    }

Mounted at /content/drive


100%|██████████| 163/163 [1:41:20<00:00, 37.30s/it]
100%|██████████| 36/36 [12:23<00:00, 20.65s/it]

Finished epoch 0 of 2: Train Loss = 0.541  Val Loss = 0.353   Train Acc = 0.830   Val Acc = 0.859



100%|██████████| 163/163 [1:33:10<00:00, 34.30s/it]
100%|██████████| 36/36 [06:27<00:00, 10.76s/it]

Finished epoch 1 of 2: Train Loss = 0.310  Val Loss = 0.356   Train Acc = 0.862   Val Acc = 0.844





Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


100%|██████████| 163/163 [1:36:14<00:00, 35.43s/it]
100%|██████████| 36/36 [06:26<00:00, 10.75s/it]

Finished epoch 0 of 2: Train Loss = 0.542  Val Loss = 0.670   Train Acc = 0.831   Val Acc = 0.723



100%|██████████| 163/163 [1:36:27<00:00, 35.51s/it]
100%|██████████| 36/36 [06:26<00:00, 10.74s/it]

Finished epoch 1 of 2: Train Loss = 0.306  Val Loss = 0.453   Train Acc = 0.871   Val Acc = 0.803





Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


100%|██████████| 163/163 [1:34:35<00:00, 34.82s/it]
100%|██████████| 36/36 [06:05<00:00, 10.15s/it]

Finished epoch 0 of 2: Train Loss = 0.485  Val Loss = 0.603   Train Acc = 0.841   Val Acc = 0.700



100%|██████████| 163/163 [1:33:52<00:00, 34.56s/it]
100%|██████████| 36/36 [06:05<00:00, 10.16s/it]

Finished epoch 1 of 2: Train Loss = 0.342  Val Loss = 0.262   Train Acc = 0.868   Val Acc = 0.899





Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


100%|██████████| 163/163 [1:33:01<00:00, 34.24s/it]
 19%|█▉        | 7/36 [01:06<04:36,  9.52s/it]

## Choosing Best Hyperparameter Combination

In [None]:
plot_training_history(combo_stats)

best_combo = None
best_acc = -1
best_loss = None

for combo, stats in combo_stats.items():
    final_acc = stats["hist_test_acc"][-1]
    final_loss = stats["hist_test_loss"][-1]

    if final_acc > best_acc:
        best_acc = final_acc
        best_loss = final_loss
        best_combo = combo

print(f'Best Test accuracy: {best_acc*100:.2f}%')
print(f'Best Test loss: {best_loss:.3f}')

## Refit on Best

In [None]:
db_train, db_val, db_test = db_setup(best_combo["image_size"])

model = models.resnet50()

loader_train = DataLoader(db_train, batch_size=best_combo["batch_size"], shuffle=True, drop_last=True)
loader_val = DataLoader(db_val, batch_size=best_combo["batch_size"], shuffle=False)
loader_test = DataLoader(db_test, batch_size=best_combo["batch_size"], shuffle=False)

model = model.to(device)

opt = torch.optim.SGD(model.parameters(), lr=best_combo["lr"], momentum=best_combo["momentum"], weight_decay=best_combo["weight_decay"])

fit(model, loader_train, loader_val, epochs=best_combo["epochs"], opt=opt)


## Results

In [None]:

l, acc, all_labels, all_preds = eval(model, loader_test)
cm = plot_cm(all_labels, all_preds)
print(cm)