In [1]:
import copy
import os
import time

import torch
from torch import Tensor
from torch.nn import Module, functional
from torch.optim import Optimizer, Adam
from torch.optim.lr_scheduler import LRScheduler, StepLR
from torch.utils.tensorboard import SummaryWriter

from data_loading.dataset import generate_dataloaders
from metrics.iou import calculate_iou, CHAOSIoUTracker
from unet import UNet

## Hyperparameters

In [2]:
# Value passed to UNet(n_classes=...).
N_CLASSES = 4
# Batch size requested from generate_dataloaders.
BATCH_SIZE = 32
# Number of epochs passed to train_model.
N_EPOCHS = 50
# Learning rate and weight decay handed to the Adam optimizer.
LR = 1e-3
WEIGHT_DECAY = 1e-4
# step_size handed to the StepLR scheduler.
STEP_SIZE = 30

## Generating Dataloaders for feeding the NN

In [3]:
# Build the loaders consumed by train_model, which indexes the result with the
# keys "train" and "valid". validation_split=0.20 presumably holds out 20% of
# the data for validation -- confirm in data_loading.dataset.
dataloaders = generate_dataloaders(batch_size=BATCH_SIZE, validation_split=0.20)

## Model Definition

In [4]:
# Instantiate the project's UNet, configured with n_classes=N_CLASSES.
model = UNet(n_classes=N_CLASSES)

## Log Functions

In [6]:
def print_metric_to_console(name: str, train_metric: dict, valid_metric: dict) -> None:
    """
    Write one metric family to stdout.

    Prints an upper-cased header line (`NAME:`) followed by one tab-indented
    line for the training values and one for the validation values. Every
    entry is rendered as `label: score` with five decimal places.

    Both metric dicts are expected to look like:
        {
        "metric1": metric_value1,
        "metric2": metric_value2,
        ... }
    """
    print(name.upper() + ":")
    for phase_label, scores in (("train", train_metric), ("valid", valid_metric)):
        rendered = []
        for label, value in scores.items():
            rendered.append(f"{label}: {value:.5f}")
        print("\t(" + phase_label + ") " + " ".join(rendered))


def print_metric_to_tb(writer: SummaryWriter, train_metric: dict, valid_metric: dict, epoch: int) -> None:
    """
    Send training and validation metric values to TensorBoard.

    For every key of `train_metric`, one `add_scalars` call is issued on
    `writer` with that key as the main tag, the "train" and "valid" values as
    the two series, and `epoch` as the global step. A key missing from
    `valid_metric` raises KeyError.

    Both metric dicts are expected to look like:
        {
        "metric1": metric_value1,
        "metric2": metric_value2,
        ... }
    """
    for tag, train_value in train_metric.items():
        writer.add_scalars(
            main_tag=tag,
            tag_scalar_dict={"train": train_value, "valid": valid_metric[tag]},
            global_step=epoch,
        )


# Module-level TensorBoard writers used by train_model: loss scalars are logged
# through WRITER_LOSS (directory "logs/loss") and IoU scalars through
# WRITER_ACCURACY (directory "logs/accuracy").
WRITER_LOSS = SummaryWriter("logs/loss")
WRITER_ACCURACY = SummaryWriter("logs/accuracy")

## Training Loop

In [7]:
def generate_run_name(
        model_name: str,
        epochs: int,
        batch_size: int,
        lr: float,
        step_size: int,
        weight_decay: float
) -> str:
    """
    Build an identifier for a training run from its hyperparameters.

    The result has the form
        <model_name>_EP<epochs>_BS<batch_size>_LR<lr>_Step<step_size>_WD<weight_decay>_<timestamp>
    where the timestamp is the current local time formatted as
    YYYY-MM-DD_HH-MM-SS, so two calls made in different seconds yield
    different names.
    """
    stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
    fields = (
        f"{model_name}",
        f"EP{epochs}",
        f"BS{batch_size}",
        f"LR{lr}",
        f"Step{step_size}",
        f"WD{weight_decay}",
        stamp,
    )
    return "_".join(fields)


def train_model(
        model: Module,
        dataloaders: dict,
        n_epochs: int,
        optimizer: Optimizer,
        scheduler: LRScheduler,
        device: torch.device = torch.device("cpu")
) -> None:
    """
    Train `model` for `n_epochs` epochs, keep the weights with the lowest
    validation loss, and save them under ./models.

    Every epoch runs a "train" phase (gradients enabled, optimizer stepped on
    each batch) followed by a "valid" phase (gradients disabled). Once the
    "valid" phase finishes, the per-epoch losses and IoU scores of both phases
    are printed to the console and written to the module-level TensorBoard
    writers WRITER_LOSS and WRITER_ACCURACY. The scheduler is stepped once per
    epoch, after both phases.

    Args:
        model: network to optimize; it is moved to `device` in place.
        dataloaders: mapping with the keys "train" and "valid", each an
            iterable yielding (images, masks) batches.
        n_epochs: number of epochs to run.
        optimizer: optimizer already bound to the model's parameters.
        scheduler: learning-rate scheduler bound to `optimizer`.
        device: device on which the model and the batches are placed.

    Side effects:
        On return, `model` holds the best validation weights, and those
        weights have been written to ./models/<run_name>.pth (the directory is
        created when missing).

    NOTE(review): `calculate_loss` is neither imported nor defined in the
    visible notebook cells; it has to exist in the kernel namespace before this
    function is called. It presumably accumulates batch-size-weighted sums into
    the `losses` dict (they are divided by the sample count below) -- confirm.
    """
    model.to(device)
    best_model_wts = copy.deepcopy(model.state_dict())
    # Infinity guarantees the first validation result is always accepted.
    best_loss = float("inf")

    for epoch in range(n_epochs):
        print(f"\nEpoch {epoch + 1}/{n_epochs}\n{'=' * 30}")
        print(f"LR: {optimizer.param_groups[0]['lr']}")

        start_time = time.time()

        for phase in ["train", "valid"]:
            if phase == "train":
                model.train()
            else:
                model.eval()

            # Running sums filled in by calculate_loss; averaged after the phase.
            losses = {"bce": 0.0, "dice": 0.0, "loss": 0.0}
            total_loss = 0.0
            epoch_samples = 0
            accuracy_tracker = CHAOSIoUTracker()

            # Autograd is only needed while training.
            with torch.set_grad_enabled(phase == "train"):
                for images, masks in dataloaders[phase]:
                    images, masks = images.to(device), masks.to(device)

                    optimizer.zero_grad()
                    prediction = model(images)

                    loss = calculate_loss(prediction=prediction, ground_truth=masks, losses=losses)

                    iou_score = calculate_iou(prediction=prediction, ground_truth=masks)
                    accuracy_tracker.update(batch_iou=iou_score, batch_size=masks.size(0))

                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                    total_loss += loss.item()
                    epoch_samples += masks.size(0)

            avg_loss = {name: value / epoch_samples for name, value in losses.items()}
            avg_accuracy = accuracy_tracker.get_results()

            if phase == "train":
                # Held until the validation phase so both can be reported together.
                train_loss = avg_loss
                train_accuracy = avg_accuracy
            else:
                # Report loss.
                print_metric_to_console(name="LOSS", train_metric=train_loss, valid_metric=avg_loss)
                print_metric_to_tb(writer=WRITER_LOSS, train_metric=train_loss, valid_metric=avg_loss, epoch=epoch)

                # Report accuracy.
                print_metric_to_console(
                    name="ACCURACY(IoU)",
                    train_metric=train_accuracy,
                    valid_metric=avg_accuracy
                )
                print_metric_to_tb(
                    writer=WRITER_ACCURACY,
                    train_metric=train_accuracy,
                    valid_metric=avg_accuracy,
                    epoch=epoch
                )

                # Model selection uses the mean of the per-batch validation losses.
                epoch_loss = total_loss / len(dataloaders[phase])
                if epoch_loss < best_loss:
                    print("Saving best model...")
                    best_loss = epoch_loss
                    best_model_wts = copy.deepcopy(model.state_dict())

        finish_time = round(time.time() - start_time)
        print(f"TIME: {finish_time // 60}min {finish_time % 60}s")
        scheduler.step()

    print(f"Best validation loss: {best_loss:.5f}")
    model.load_state_dict(best_model_wts)

    # The epoch count comes from the argument so the file name reflects this
    # run; the remaining fields are still read from the notebook-level
    # constants, which the caller is expected to have used for the
    # optimizer, scheduler and loaders.
    run_name = generate_run_name(model_name="UNet",
                                 weight_decay=WEIGHT_DECAY,
                                 lr=LR,
                                 step_size=STEP_SIZE,
                                 batch_size=BATCH_SIZE,
                                 epochs=n_epochs
                                 )

    # torch.save does not create parent directories; make sure the target
    # exists so a finished training run is not lost at the very last step.
    checkpoint_dir = "./models"
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(model.state_dict(), f"{checkpoint_dir}/{run_name}.pth")


In [8]:
%reload_ext tensorboard
%tensorboard --logdir logs/losses

Reusing TensorBoard on port 6006 (pid 14628), started 34 days, 0:36:34 ago. (Use '!kill 14628' to kill it.)

In [9]:
%tensorboard --logdir logs/accuracies

Reusing TensorBoard on port 6006 (pid 12688), started 12:38:15 ago. (Use '!kill 12688' to kill it.)

## Define Optimizer, LR Scheduler and Start Training

In [10]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Only parameters that still require gradients are handed to Adam.
optimizer = Adam(
    (param for param in model.parameters() if param.requires_grad),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)
# StepLR scales the learning rate by gamma every STEP_SIZE scheduler steps.
lr_scheduler = StepLR(optimizer, step_size=STEP_SIZE, gamma=0.1)

train_model(model=model, dataloaders=dataloaders, n_epochs=N_EPOCHS,
            optimizer=optimizer, scheduler=lr_scheduler, device=device)


Epoch 1/50
LR: 0.001
LOSS:
	(train) bce: 0.43447 dice: 0.95125 loss: 0.69286
	(valid) bce: 0.46896 dice: 0.94256 loss: 0.70576
ACCURACY(IOU):
	(train) liver: 0.15494 r_kidney: 0.44651 l_kidney: 0.17873 spleen: 0.40679
	(valid) liver: 0.12595 r_kidney: 0.55906 l_kidney: 0.50787 spleen: 0.49606
Saving best model...
TIME: 4min 29s

Epoch 2/50
LR: 0.001
LOSS:
	(train) bce: 0.22822 dice: 0.90627 loss: 0.56725
	(valid) bce: 0.27485 dice: 0.94278 loss: 0.60882
ACCURACY(IOU):
	(train) liver: 0.34500 r_kidney: 0.60236 l_kidney: 0.56004 spleen: 0.61122
	(valid) liver: 0.13511 r_kidney: 0.55906 l_kidney: 0.54724 spleen: 0.61811
TIME: 4min 31s

Epoch 3/50
LR: 0.001
LOSS:
	(train) bce: 0.13644 dice: 0.85791 loss: 0.49717
	(valid) bce: 0.12346 dice: 0.83316 loss: 0.47831
ACCURACY(IOU):
	(train) liver: 0.41428 r_kidney: 0.60236 l_kidney: 0.55216 spleen: 0.44430
	(valid) liver: 0.51328 r_kidney: 0.55906 l_kidney: 0.18539 spleen: 0.20596
TIME: 4min 21s

Epoch 4/50
LR: 0.001
LOSS:
	(train) bce: 0.08852


KeyboardInterrupt

