In [None]:
import copy
import os
import time

import torch
from torch import Tensor
from torch.nn import Module, functional
from torch.optim import Optimizer, Adam
from torch.optim.lr_scheduler import LRScheduler, StepLR
from torch.utils.tensorboard import SummaryWriter

from data_loading.dataset import generate_dataloaders
from unet import UNet

## Hyperparameters

In [None]:
# Value passed to UNet as its n_classes argument.
N_CLASSES = 4
# Mini-batch size handed to generate_dataloaders.
BATCH_SIZE = 32
# Number of train + validation passes performed by train_model.
N_EPOCHS = 50
# Initial learning rate given to Adam.
LR = .001
# Weight-decay coefficient given to Adam.
WEIGHT_DECAY = .0001
# StepLR scales the learning rate by its gamma every STEP_SIZE epochs.
STEP_SIZE = 30

# Seed PyTorch's global RNG so weight initialisation (and any torch-based
# shuffling/splitting) is repeatable across "Restart & Run All".
# NOTE(review): if data_loading.dataset draws from `random` or NumPy, those
# generators need to be seeded separately — confirm.
SEED = 42
torch.manual_seed(SEED);  # trailing ';' hides the returned Generator repr in the cell output

## Generating Dataloaders to Feed the Network

In [None]:
# Build the loaders consumed by train_model, which indexes this object with the
# keys "train" and "valid".
# NOTE(review): validation_split=0.20 presumably holds out 20% of the samples for
# validation — confirm in data_loading.dataset.
dataloaders = generate_dataloaders(batch_size=BATCH_SIZE, validation_split=0.20)

## Model Definition

In [None]:
# Instantiate the UNet with N_CLASSES as its n_classes argument.
# NOTE(review): presumably this yields one output channel per class — confirm against unet.UNet.
model = UNet(n_classes=N_CLASSES)

## Loss Functions

In [None]:
def calculate_loss(prediction: Tensor, ground_truth: Tensor, losses: dict, bce_weight: float = 0.5) -> Tensor:
    """
    Computes the weighted sum of BCE-with-logits and Dice loss for one batch.

    Besides returning the combined loss, the function adds each component,
    multiplied by the batch size, to the running totals in ``losses`` so the
    caller can later divide by the number of samples seen.

    Args:
        prediction: Raw model outputs (logits); the sigmoid is applied inside
            the loss functions.
        ground_truth: Target tensor with the same shape as ``prediction``;
            its first dimension is used as the batch size.
        losses: Mutable dict with the keys "bce", "dice" and "loss"; updated in place.
        bce_weight: Weight of the BCE term; the Dice term gets ``1 - bce_weight``.

    Returns:
        The combined loss as a scalar tensor still attached to the autograd graph.
    """
    bce = functional.binary_cross_entropy_with_logits(prediction, ground_truth)
    dice = dice_loss(prediction, ground_truth)
    loss = bce * bce_weight + dice * (1 - bce_weight)

    batch_size = ground_truth.size(0)
    for name, value in (("bce", bce), ("dice", dice), ("loss", loss)):
        # .item() yields a plain Python float (no autograd history, no NumPy
        # scalar), so the running totals are accumulated in double precision.
        losses[name] += value.item() * batch_size

    return loss


def dice_loss(prediction: Tensor, ground_truth: Tensor, smooth: float = 1e-6) -> Tensor:
    """
    Soft Dice loss computed from raw logits.

    The sigmoid is applied to ``prediction``, every dimension after the first
    two is flattened, and a Dice coefficient is computed for each
    (dim 0, dim 1) pair. The loss is one minus the mean of those coefficients.

    Args:
        prediction: Raw model outputs (logits) with at least two dimensions.
        ground_truth: Target tensor with the same shape as ``prediction``.
        smooth: Constant added to numerator and denominator to avoid 0/0 when
            both prediction and target are empty.

    Returns:
        Scalar tensor ``1 - mean(dice)``.
    """
    probabilities = prediction.sigmoid()

    # reshape() instead of view(): view() raises on non-contiguous tensors
    # (e.g. spatially transposed/flipped masks), reshape() copies only if needed.
    probabilities = probabilities.reshape(probabilities.shape[0], probabilities.shape[1], -1)
    ground_truth = ground_truth.reshape(ground_truth.shape[0], ground_truth.shape[1], -1)

    intersection = torch.sum(probabilities * ground_truth, dim=2)
    # Sum of both "areas" (the Dice denominator), not a set union.
    cardinality = torch.sum(probabilities, dim=2) + torch.sum(ground_truth, dim=2)
    dice = (2. * intersection + smooth) / (cardinality + smooth)

    return 1 - dice.mean()


## Logging Functions

In [None]:
def print_metric_to_console(metric_name: str, train_metric: dict, valid_metric: dict) -> None:
    """
    Prints one metric group for the training and validation phases.

    The upper-cased metric name is printed as a header, followed by one
    tab-indented line per phase listing every ``name: value`` pair with five
    decimal places.

    Both metric arguments are dicts mapping a metric name to a number, e.g.
        {"bce": 0.41, "dice": 0.27, "loss": 0.34}
    """
    print(f"{metric_name.upper()}:")
    for phase, values in (("train", train_metric), ("valid", valid_metric)):
        entries = [f"{name}: {value:.5f}" for name, value in values.items()]
        line = " ".join(entries)
        print(f"\t({phase}) {line}")


def print_losses_to_tb(writer: SummaryWriter, train_losses: dict, valid_losses: dict, epoch: int) -> None:
    """
    Logs every loss in ``train_losses`` to TensorBoard next to its validation value.

    For each key of ``train_losses`` one ``add_scalars`` call is made with the
    key as main tag and a {"train": ..., "valid": ...} dict as payload, so both
    curves share a chart. ``valid_losses`` must contain the same keys.

    Args:
        writer: Object providing ``add_scalars`` (a SummaryWriter).
        train_losses: Loss name -> training value.
        valid_losses: Loss name -> validation value.
        epoch: Value passed as ``global_step``.
    """
    for loss_name, train_value in train_losses.items():
        writer.add_scalars(
            main_tag=loss_name,
            tag_scalar_dict={"train": train_value, "valid": valid_losses[loss_name]},
            global_step=epoch,
        )


# Shared TensorBoard writer that train_model passes to print_losses_to_tb.
# It is given "logs/losses", the same path the %tensorboard cell uses as --logdir.
WRITER = SummaryWriter("logs/losses")

## Training Loop

In [None]:
def generate_run_name(
        model_name: str,
        epochs: int,
        batch_size: int,
        lr: float,
        step_size: int,
        weight_decay: float
) -> str:
    """
    Builds an identifier for a training run from its hyperparameters.

    The result has the form
        <model_name>_EP<epochs>_BS<batch_size>_LR<lr>_Step<step_size>_WD<weight_decay>_<timestamp>
    where the timestamp is the current local time as YYYY-MM-DD_HH-MM-SS.
    """
    fields = (
        f"{model_name}",
        f"EP{epochs}",
        f"BS{batch_size}",
        f"LR{lr}",
        f"Step{step_size}",
        f"WD{weight_decay}",
        time.strftime("%Y-%m-%d_%H-%M-%S"),
    )
    return "_".join(fields)


def _run_phase(
        model: Module,
        dataloader,
        optimizer: Optimizer,
        device: torch.device,
        phase: str
) -> dict:
    """Runs one pass over ``dataloader`` and returns the per-sample average of each loss."""
    is_train = phase == "train"
    model.train(is_train)  # train(False) is what eval() does

    losses = {"bce": 0.0, "dice": 0.0, "loss": 0.0}
    n_samples = 0

    # Gradients are only tracked while training.
    with torch.set_grad_enabled(is_train):
        for batch_X, batch_y in dataloader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)

            if is_train:
                optimizer.zero_grad()

            predictions = model(batch_X)
            # calculate_loss also adds the batch-size-weighted components to `losses`.
            loss = calculate_loss(prediction=predictions, ground_truth=batch_y, losses=losses)

            if is_train:
                loss.backward()
                optimizer.step()

            n_samples += batch_y.size(0)

    if n_samples == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError below.
        raise ValueError(f"dataloaders['{phase}'] yielded no samples")

    return {loss_name: value / n_samples for loss_name, value in losses.items()}


def train_model(
        model: Module,
        dataloaders: dict,
        n_epochs: int,
        optimizer: Optimizer,
        scheduler: LRScheduler,
        device: torch.device = torch.device("cpu")
) -> None:
    """
    Trains ``model``, keeps the weights with the lowest validation loss and saves them.

    Every epoch runs a training pass over ``dataloaders["train"]`` and a
    validation pass over ``dataloaders["valid"]``, prints the averaged losses,
    logs them through the module-level ``WRITER`` and then steps ``scheduler``.
    After the last epoch the best weights are loaded back into ``model`` and
    written to ``./models/<run_name>.pth`` (the directory is created if missing).

    The run name is built from ``n_epochs`` and the module-level constants
    ``LR``, ``WEIGHT_DECAY``, ``STEP_SIZE`` and ``BATCH_SIZE``.

    Args:
        model: Network to train; moved to ``device`` in place.
        dataloaders: Mapping with the keys "train" and "valid"; each value
            yields ``(inputs, targets)`` batches.
        n_epochs: Number of epochs to run.
        optimizer: Optimizer updating ``model``'s parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        device: Device on which the model and batches are placed.

    Raises:
        ValueError: If one of the dataloaders yields no samples.
    """
    model.to(device)
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float("inf")

    for epoch in range(n_epochs):
        print(f"\nEpoch {epoch + 1}/{n_epochs}\n{'=' * 30}")
        print(f"LR: {optimizer.param_groups[0]['lr']}")

        start_time = time.time()

        train_losses = _run_phase(model, dataloaders["train"], optimizer, device, phase="train")
        valid_losses = _run_phase(model, dataloaders["valid"], optimizer, device, phase="valid")

        print_metric_to_console(metric_name="LOSS", train_metric=train_losses, valid_metric=valid_losses)
        print_losses_to_tb(writer=WRITER, train_losses=train_losses, valid_losses=valid_losses, epoch=epoch)

        # Select the best model on the same per-sample average that is printed
        # and logged, so a smaller final batch cannot skew the comparison.
        epoch_loss = valid_losses["loss"]
        if epoch_loss < best_loss:
            print("Saving best model...")
            best_loss = epoch_loss
            best_model_wts = copy.deepcopy(model.state_dict())

        finish_time = round(time.time() - start_time)
        print(f"TIME: {finish_time // 60}min {finish_time % 60}s")
        scheduler.step()

    print(f"Best validation loss: {best_loss:.5f}")
    model.load_state_dict(best_model_wts)

    run_name = generate_run_name(model_name="UNet",
                                 weight_decay=WEIGHT_DECAY,
                                 lr=LR,
                                 step_size=STEP_SIZE,
                                 batch_size=BATCH_SIZE,
                                 epochs=n_epochs  # the value actually trained for, not the global
                                 )

    # Make sure the target directory exists so the finished run is not lost at save time.
    os.makedirs("./models", exist_ok=True)
    torch.save(model.state_dict(), f"./models/{run_name}.pth")


In [None]:
# (Re)load the TensorBoard notebook extension and open the dashboard on logs/losses,
# the directory WRITER was created with.
%reload_ext tensorboard
%tensorboard --logdir logs/losses

## Define Optimizer, LR Scheduler and Start Training

In [None]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Adam only receives the parameters that require gradients.
optimizer = Adam(
    (param for param in model.parameters() if param.requires_grad),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)

# Scale the learning rate by 0.1 every STEP_SIZE epochs.
lr_scheduler = StepLR(optimizer, step_size=STEP_SIZE, gamma=0.1)

train_model(
    model=model,
    dataloaders=dataloaders,
    n_epochs=N_EPOCHS,
    optimizer=optimizer,
    scheduler=lr_scheduler,
    device=device,
)