In [None]:
import time
import copy

import torch
from torch.nn import Module, functional
from torch.optim import Optimizer, Adam
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import LRScheduler, StepLR
from torch import Tensor

from data_loading.dataset import generate_dataloaders
from metrics.iou import calculate_iou, CHAOSIoUTracker
from unet import UNet

## Hyperparameters

In [None]:
N_CLASSES = 4
BATCH_SIZE = 32
N_EPOCHS = 50
LR = .001
WEIGHT_DECAY = .0001
STEP_SIZE = 30

## Generating Dataloaders for feeding the NN

In [None]:
dataloaders = generate_dataloaders(batch_size=BATCH_SIZE, validation_split=0.20)

## Model Definition

In [None]:
model = UNet(n_classes=N_CLASSES)

## Loss Functions

In [None]:
def calculate_loss(prediction: Tensor, ground_truth: Tensor, losses: dict, bce_weight: float = 0.5) -> Tensor:
    bce = functional.binary_cross_entropy_with_logits(prediction, ground_truth)
    dice = dice_loss(prediction, ground_truth)
    loss = bce * bce_weight + dice * (1 - bce_weight)

    losses["bce"] += bce.data.cpu().numpy() * ground_truth.size(0)
    losses["dice"] += dice.data.cpu().numpy() * ground_truth.size(0)
    losses["loss"] += loss.data.cpu().numpy() * ground_truth.size(0)

    return loss


def dice_loss(prediction: Tensor, ground_truth: Tensor, smooth: float = 1e-6) -> Tensor:
    prediction = prediction.sigmoid()

    prediction = prediction.view(prediction.shape[0], prediction.shape[1], -1)
    ground_truth = ground_truth.view(ground_truth.shape[0], ground_truth.shape[1], -1)

    intersection = torch.sum(prediction * ground_truth, dim=2)
    union = torch.sum(prediction, dim=2) + torch.sum(ground_truth, dim=2)
    dice = (2. * intersection + smooth) / (union + smooth)

    return 1 - dice.mean()


## Log Functions

In [None]:
def print_metric_to_console(name: str, train_metric: dict, valid_metric: dict) -> None:
    """
    Formats and prints training and validation metric values to console.
    Expected metric dict:
        {
        "metric1": metric_name1_value,
        "metric2": metric_value2,
        "metric3": metric_value3,
        ... }
    """
    print(f"{name.upper()}:")
    metric_dict = {"train": train_metric, "valid": valid_metric}
    for phase, metric in metric_dict.items():
        output = " ".join(f"{cls}: {score:.5f}" for cls, score in metric.items())
        print(f"\t({phase}) {output}")


def print_metric_to_tb(writer: SummaryWriter, train_metric: dict, valid_metric: dict, epoch: int) -> None:
    """
    Prints training and validation metric values to Tensorboard.
    Expected metric dict:
        {
        "metric1": metric_name1_value,
        "metric2": metric_value2,
        "metric3": metric_value3,
        ... }
    """
    for cls in train_metric.keys():
        scalar_dict = {
            "train": train_metric[cls],
            "valid": valid_metric[cls]
        }
        writer.add_scalars(main_tag=cls, tag_scalar_dict=scalar_dict, global_step=epoch)


WRITER_LOSS = SummaryWriter("logs/loss")
WRITER_ACCURACY = SummaryWriter("logs/accuracy")

## Training Loop

In [None]:
def generate_run_name(
        model_name: str,
        epochs: int,
        batch_size: int,
        lr: float,
        step_size: int,
        weight_decay: float
) -> str:
    timestamp = time.strftime("%Y-%m-%d_%H-%M-%S")
    return f"{model_name}_EP{epochs}_BS{batch_size}_LR{lr}_Step{step_size}_WD{weight_decay}_{timestamp}"


def train_model(
        model: Module,
        dataloaders: dict,
        n_epochs: int,
        optimizer: Optimizer,
        scheduler: LRScheduler,
        device: torch.device = torch.device("cpu")
) -> None:
    model.to(device)
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 1e10

    for epoch in range(n_epochs):
        print(f"\nEpoch {epoch + 1}/{n_epochs}\n{'=' * 30}")
        print(f"LR: {optimizer.param_groups[0]['lr']}")

        start_time = time.time()

        for phase in ["train", "valid"]:
            if phase == "train":
                model.train()
            else:
                model.eval()

            losses = {"bce": 0.0, "dice": 0.0, "loss": 0.0}
            total_loss = 0.0
            epoch_samples = 0
            accuracy_tracker = CHAOSIoUTracker()

            with torch.set_grad_enabled(phase == "train"):
                for images, masks in dataloaders[phase]:
                    images, masks = images.to(device), masks.to(device)

                    optimizer.zero_grad()
                    prediction = model(images)

                    loss = calculate_loss(prediction=prediction, ground_truth=masks, losses=losses)

                    iou_score = calculate_iou(prediction=prediction, ground_truth=masks)
                    accuracy_tracker.update(batch_iou=iou_score, batch_size=masks.size(0))

                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                    total_loss += loss.item()
                    epoch_samples += masks.size(0)

            avg_loss = {name: value / epoch_samples for name, value in losses.items()}
            avg_accuracy = accuracy_tracker.get_results()

            if phase == "train":
                train_loss = avg_loss  # save the training loss
                train_accuracy = avg_accuracy  # save the training accuracy
            else:
                # print loss
                print_metric_to_console(name="LOSS", train_metric=train_loss, valid_metric=avg_loss)
                print_metric_to_tb(writer=WRITER_LOSS, train_metric=train_loss, valid_metric=avg_loss, epoch=epoch)

                # print accuracy
                print_metric_to_console(
                    name="ACCURACY(IoU)",
                    train_metric=train_accuracy,
                    valid_metric=avg_accuracy
                )
                print_metric_to_tb(
                    writer=WRITER_ACCURACY,
                    train_metric=train_accuracy,
                    valid_metric=avg_accuracy,
                    epoch=epoch
                )

                epoch_loss = total_loss / len(dataloaders[phase])
                if epoch_loss < best_loss:
                    print("Saving best model...")
                    best_loss = epoch_loss
                    best_model_wts = copy.deepcopy(model.state_dict())

        finish_time = round(time.time() - start_time)
        print(f"TIME: {finish_time // 60}min {finish_time % 60}s")
        scheduler.step()

    print(f"Best validation loss: {best_loss:.5f}")
    model.load_state_dict(best_model_wts)

    run_name = generate_run_name(model_name="UNet",
                                 weight_decay=WEIGHT_DECAY,
                                 lr=LR,
                                 step_size=STEP_SIZE,
                                 batch_size=BATCH_SIZE,
                                 epochs=N_EPOCHS
                                 )

    torch.save(model.state_dict(), f"./models/{run_name}.pth")


In [None]:
%reload_ext tensorboard
%tensorboard --logdir logs/losses

In [None]:
%tensorboard --logdir logs/accuracies

## Define Optimizer, LR Scheduler and Start Training

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=WEIGHT_DECAY)
lr_scheduler = StepLR(optimizer=optimizer, step_size=STEP_SIZE, gamma=0.1)

train_model(
    model=model,
    dataloaders=dataloaders,
    n_epochs=N_EPOCHS,
    optimizer=optimizer,
    scheduler=lr_scheduler,
    device=device
)

LOSS:
	(train) bce: 0.07829 dice: 0.83865 loss: 0.45847
	(valid) bce: 0.11657 dice: 0.82053 loss: 0.46855
ACCURACY(IOU):
	(train) liver: 0.43357 r_kidney: 0.32086 l_kidney: 0.34772 spleen: 0.32575
	(valid) liver: 0.39052 r_kidney: 0.16710 l_kidney: 0.20785 spleen: 0.18373
Saving best model...
TIME: 4min 30s

Epoch 5/50
LR: 0.001
LOSS:
	(train) bce: 0.05765 dice: 0.79601 loss: 0.42683
	(valid) bce: 0.06561 dice: 0.78339 loss: 0.42450
ACCURACY(IOU):
	(train) liver: 0.48079 r_kidney: 0.41840 l_kidney: 0.33328 spleen: 0.29140
	(valid) liver: 0.37566 r_kidney: 0.23205 l_kidney: 0.33591 spleen: 0.25403
Saving best model...
TIME: 4min 34s

Epoch 6/50
LR: 0.001


KeyboardInterrupt: 