In [None]:
import os
import h5py
import torch
import logging
import numpy as np

import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

from tqdm import tqdm
from datetime import datetime
from typing import Tuple, List
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms

import torchvision.transforms.functional as F

from torch.utils.tensorboard import SummaryWriter
from torchvision.models import resnet18, ResNet18_Weights

In [None]:
# Make logging.basicConfig usable inside an IPython notebook: reloading the
# logging module re-creates its root logger, so later configuration applies.
# https://stackoverflow.com/a/21475297
import importlib
importlib.reload(logging)

In [None]:
# Load the TensorBoard notebook extension; it provides the %tensorboard magic used further down.
%load_ext tensorboard

In [None]:
# Root folder for TensorBoard event files and model checkpoints; the training
# cell creates one timestamped sub-folder in here per run.
OUTPUT_PATH: str = "runs"

# Mini-batch size. The original estimate was roughly 7 GB of VRAM at this value.
# Smaller batches are often reported to generalize better, so this is a
# candidate for tuning.
BATCH_SIZE: int = 512

In [None]:
# Ensure the root output folder exists; an already existing folder is fine.
os.makedirs(name=OUTPUT_PATH, exist_ok=True)

In [None]:
# Configure the root logger for the whole notebook.
logging.basicConfig(
    format="[%(asctime)s:%(levelname)s]: %(message)s",
    level=logging.INFO,
    # basicConfig is a silent no-op when the root logger already has handlers
    # (which a notebook kernel may have installed); force=True removes them
    # first so this configuration always takes effect.
    force=True,
)

In [None]:
# TODO: set all hyperparameters here, or use a config file
# that way we can keep track of all the hyperparameters we used in each run on wandb

# Hyperparameters logged to wandb with the run. Training code should read from
# this dict so the logged values match what actually ran.
config = {
    'BATCH_SIZE': BATCH_SIZE,
    'EPOCHS': 100,
    'LEARNING_RATE': 0.001,
    'WEIGHT_DECAY': 0.0001,
    'MOMENTUM': 0.9,
    'NUM_WORKERS': 4,
    # Record the device that will really be used instead of assuming a GPU.
    'DEVICE': 'cuda' if torch.cuda.is_available() else 'cpu',
}

# WandB: Define metadata of the run
# TODO: replace these placeholder values before a real run.
run_name = 'test'
notes = 'Lorem ipsum dolor sit amet, consectetur adipiscing elit.'
tags = ['tag1', 'tag2']

In [None]:
# Wandb
import wandb
wandb.login()
%env "WANDB_NOTEBOOK_NAME" "train.ipynb"
%env WANDB_SILENT=False

run = wandb.init(name=run_name, notes = notes, tags = tags, project='PatchCamelyon',  entity='mi_ams',  config = config)

In [None]:
class PatchCamelyonDataset(Dataset):
    """Image patches and class labels read from a pair of HDF5 files.

    Args:
        data_path: HDF5 file containing the images in the dataset named
            ``data_key``. Each item is expected to be channels-last
            ``[H, W, C]`` with 0-255 pixel values (it is divided by 255 and
            permuted to ``[C, H, W]``).
        targets_path: HDF5 file containing the labels in the dataset named
            ``targets_key``. Singleton dimensions around each label are
            squeezed away.
        transform: optional callable applied to the ``[C, H, W]`` float image.
        data_key: name of the image dataset inside ``data_path``.
        targets_key: name of the label dataset inside ``targets_path``.
    """

    def __init__(
        self,
        data_path: str,
        targets_path: str,
        transform=None,
        data_key: str = "x",
        targets_key: str = "y",
    ) -> None:
        # Keep the file handles alive and select the datasets inside them. The
        # File object itself is a group: it has no `.shape` and cannot be
        # sliced like an array.
        # NOTE(review): these handles are opened in the main process and
        # inherited by DataLoader workers; if reads misbehave with
        # num_workers > 0, open the files lazily inside each worker instead.
        self._data_file = h5py.File(data_path, "r")
        self._targets_file = h5py.File(targets_path, "r")
        self.data = self._data_file[data_key]
        self.targets = self._targets_file[targets_key]
        self.transform = transform

    def __len__(self) -> int:
        return self.targets.shape[0]

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Scale 0-255 pixel values to [0, 1] floats.
        sample = torch.tensor(self.data[idx, :, :, :]).float() / 255.0
        # [x, y, channels] to [channels, x, y], the layout torchvision expects
        sample = torch.permute(sample, (2, 0, 1))

        # The targets are nested within singleton dimensions; squeeze them to a
        # scalar and cast to int64, since nn.CrossEntropyLoss requires Long
        # class indices.
        target = torch.tensor(self.targets[idx].squeeze()).long()

        if self.transform is not None:
            sample = self.transform(sample)

        return sample, target

In [None]:
def show(imgs: List[torch.tensor] | torch.tensor, labels: List[str | int] = None):    
    if type(imgs) != list:
        imgs = [imgs]

    if labels is None:
        labels = [""] * len(imgs)

    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(12, 6))

    for i, (img, label) in enumerate(zip(imgs, labels)):
        img = img.detach()
        img = F.to_pil_image(img)

        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
        axs[0, i].set_xlabel(label)

In [None]:
# Shared preprocessing for every split: torchvision's Resize with an int
# scales each patch so that its shorter side is 224 px.
resize_224 = transforms.Resize(224, antialias=True)

train_dataset = PatchCamelyonDataset(
    data_path="data/camelyonpatch_level_2_split_train_x.h5",
    targets_path="data/camelyonpatch_level_2_split_train_y.h5",
    transform=resize_224,
)

val_dataset = PatchCamelyonDataset(
    data_path="data/camelyonpatch_level_2_split_valid_x.h5",
    targets_path="data/camelyonpatch_level_2_split_valid_y.h5",
    transform=resize_224,
)

# The held-out test split is deliberately not loaded here. It must never be
# bound to `train_dataset`, and keeping it out of this cell keeps it out of
# training and model selection. Load it only for the final evaluation.

In [None]:
# Human-readable names for the two class indices.
class2label = {0: "No Tumor", 1: "Tumor"}

# Take the first five training items and split them into images and label names.
preview_pairs = [train_dataset[i] for i in range(5)]
samples = [image for image, _ in preview_pairs]
labels = [class2label[target.item()] for _, target in preview_pairs]

show(samples, labels)

In [None]:
def train_epoch(
    model: nn.Module,
    device: torch.device,
    train_loader: DataLoader,
    optimizer: optim.Optimizer,
    loss_fn: nn.Module
):
    """Run one optimization pass over ``train_loader``.

    Args:
        model: classifier whose outputs are scored with ``argmax(dim=1)``.
        device: device each batch is moved to before the forward pass.
        train_loader: iterable of ``(images, targets)`` batches.
        optimizer: optimizer that updates ``model``'s parameters.
        loss_fn: criterion called as ``loss_fn(preds, targets)``; assumed to
            return the batch-mean loss (the default reduction of PyTorch losses).

    Returns:
        ``(mean_loss, accuracy)`` as Python floats, both averaged per sample.
        Both are NaN when the loader yields no samples.
    """
    model.train()
    loss_sum, correct_preds, samples_n = 0.0, 0, 0

    for images, targets in tqdm(train_loader):
        images, targets = images.to(device), targets.to(device)

        optimizer.zero_grad()

        preds = model(images)

        loss = loss_fn(preds, targets)
        loss.backward()

        optimizer.step()

        # Weight each batch-mean loss by its batch size so a short final batch
        # does not skew the epoch average. .item() keeps the accumulators as
        # plain Python numbers instead of device tensors.
        batch_size = targets.size(0)
        samples_n += batch_size
        loss_sum += loss.item() * batch_size
        correct_preds += (preds.argmax(dim=1) == targets).sum().item()

    if samples_n == 0:
        return float("nan"), float("nan")

    return loss_sum / samples_n, correct_preds / samples_n

In [None]:
def test(
    model: nn.Module,
    device: torch.device,
    test_loader: DataLoader
):
    model.eval()
    test_loss, correct_preds, batches_n = 0, 0, 0

    with torch.no_grad():
        for _, (images, targets) in enumerate(tqdm(test_loader)):
            images, targets = images.to(device), targets.to(device)

            preds = model(images)
            test_loss += loss_fn(preds, targets)

            batches_n += 1
            preds = preds.argmax(dim=1)
            correct_preds += (preds == targets).sum()

    test_loss /= batches_n
    accuracy = correct_preds / len(test_loader.dataset)

    return test_loss, accuracy

In [None]:
# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Classification criterion: nn.CrossEntropyLoss takes raw logits and class-index targets.
loss_fn = nn.CrossEntropyLoss()

In [None]:
# Loader settings come from `config` so the values logged to wandb match what
# actually runs. Pinned host memory only helps when copying batches to a GPU.
loader_kwargs = dict(
    batch_size=config["BATCH_SIZE"],
    num_workers=config["NUM_WORKERS"],
    pin_memory=device.type == "cuda",
)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
# Shuffling does not change validation metrics; keep the order deterministic.
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Pretrained ResNet-18 whose classification head is replaced by a fresh
# two-way (no tumor / tumor) linear layer sized to the backbone's feature width.
model = resnet18(weights=ResNet18_Weights.DEFAULT)
num_classes = 2
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=num_classes)

In [None]:
model = model.to(device)

# The training cell needs an optimizer, which was never defined. Build it from
# the hyperparameters in `config` so wandb records the values actually used.
# NOTE(review): the MOMENTUM entry suggests SGD with momentum was intended;
# swap in another optimizer (and record it in `config`) if not.
optimizer = optim.SGD(
    model.parameters(),
    lr=config["LEARNING_RATE"],
    momentum=config["MOMENTUM"],
    weight_decay=config["WEIGHT_DECAY"],
)

In [None]:
%%wandb

best_val_accuracy = 0
run_id = datetime.utcnow().strftime("%Y-%m-%dT%H%M%S")
run_folder = os.path.join(OUTPUT_PATH, run_id)

os.makedirs(run_folder)
writer = SummaryWriter(run_folder)

for epoch in range(10):
    logging.info(f"starting epoch {epoch}")

    train_loss, train_accuracy = train_epoch(
        model=model,
        device=device,
        optimizer=optimizer,
        loss_fn=loss_fn,
        train_loader=train_loader
    )

    logging.info(f"the train accuracy was {train_accuracy} (loss: {train_loss})")

    val_loss, val_accuracy = test(model=model, device=device, test_loader=val_loader)

    logging.info(f"the validation accuracy was {val_accuracy} (loss: {val_loss})")

    # Log metrics to tensorboard
    writer.add_scalar("loss/train", train_loss, epoch)
    writer.add_scalar("loss/val", val_loss, epoch)

    writer.add_scalar("acc/train", train_accuracy, epoch)
    writer.add_scalar("acc/val", val_accuracy, epoch)
    
    # Log epoch metrics
    wandb.log(
        {
        "Epoch": epoch,
        "loss/train": train_loss,
        "loss/val": val_loss,
        "acc/train": train_accuracy,
        "acc/val": val_accuracy,
        }
    )
       
    # Pick the best model according to the validation
    # accuracy score
    if val_accuracy > best_val_accuracy:
        logging.info(f"found new best model at epoch {epoch} with accuracy {val_accuracy} (loss {val_loss})")

        best_val_accuracy = val_accuracy
        wandb.run.summary["best_val_accuracy"] = best_val_accuracy
        wandb.run.summary["best_val_epoch"] = epoch

        # Save the model to disk
        torch.save(
            model.state_dict(),
            os.path.join(run_folder, f"model_{epoch}.pt")
        )

In [None]:
# Final evaluation: build the held-out test split and its loader here (no
# `test_loader` existed before), with the same preprocessing as the other splits.
test_dataset = PatchCamelyonDataset(
    data_path="data/camelyonpatch_level_2_split_test_x.h5",
    targets_path="data/camelyonpatch_level_2_split_test_y.h5",
    transform=transforms.Resize(224, antialias=True),
)
test_loader = DataLoader(
    test_dataset,
    batch_size=config["BATCH_SIZE"],
    shuffle=False,
    num_workers=config["NUM_WORKERS"],
    pin_memory=device.type == "cuda",
)

test_loss, test_accuracy = test(model=model, device=device, test_loader=test_loader)

In [None]:
# Report the test accuracy in the notebook log and store it in the wandb run summary.
test_report = f"the test accuracy was {test_accuracy}"
logging.info(test_report)

run_summary = wandb.run.summary
run_summary["acc/test"] = test_accuracy

In [None]:
%tensorboard --logdir runs

In [None]:
# Finish the wandb run created by wandb.init() in this notebook.
# Call this once, after all metrics and summary values have been logged.
run.finish()