In [None]:
def _train_one_epoch(net, loader, optimizer, criterion, writer, epoch, labeled_names):
    """Run one optimisation pass of ``net`` over ``loader``.

    Logs the running loss / DSC to ``writer`` every 10 batches and a
    predictions-vs-labels figure every 100 batches.

    Returns:
        ``(mean_loss, mean_dsc)`` averaged over the batches of ``loader``.
    """
    net.train()
    num_batches = len(loader)
    print("Number of batches: ", num_batches)
    step_offset = epoch * num_batches

    running_loss = 0.0
    running_dsc = 0.0

    for idx, (images, labels, img_names) in enumerate(loader):
        # ``optimizer`` was built from ``net.parameters()``, so this clears every grad.
        optimizer.zero_grad()

        labels = utils.to_var(labels).to(device)
        images = utils.to_var(images).to(device)

        net_predictions = net(images)
        segmentation_classes = utils.getTargetSegmentation(labels)
        loss = criterion(net_predictions, segmentation_classes)

        running_loss += loss.item()
        # Detach so the metric neither extends nor keeps alive the autograd graph,
        # and accumulate a plain float so the running sum never holds tensors.
        running_dsc += float(utils.compute_dsc(net_predictions.detach(), labels))

        loss.backward()
        optimizer.step()

        step = step_offset + idx
        mean_loss = running_loss / (idx + 1)

        if idx % 10 == 0:
            writer.add_scalar("Loss/train", mean_loss, step)
            writer.add_scalar("Dice/train", running_dsc / (idx + 1), step)

        if idx % 100 == 0:
            with torch.no_grad():
                y_pred = torch.argmax(torch.softmax(net_predictions, dim=1), dim=1)
            annotated_img_names = [
                f"{img} (Unlabeled)" if labeled_names and img not in labeled_names else img
                for img in img_names
            ]
            # NOTE(review): passes the global ``batch_size``; the last batch of an
            # epoch may be smaller -- confirm plot_net_predictions handles that.
            writer.add_figure(
                "predictions vs. actuals",
                utils.plot_net_predictions(images, labels, y_pred, batch_size, annotated_img_names),
                global_step=step,
            )

        printProgressBar(
            idx + 1,
            num_batches,
            prefix=f"[Training] Epoch: {epoch} ",
            length=15,
            suffix=f" Loss: {mean_loss:.4f}, ",
        )

    return running_loss / num_batches, running_dsc / num_batches


def _validate_one_epoch(net, val_loader, criterion, writer, epoch):
    """Evaluate ``net`` on ``val_loader`` with gradients disabled.

    Logs the running loss / DSC to ``writer`` every 10 batches.

    Returns:
        ``(mean_loss, mean_dsc)`` averaged over the batches of ``val_loader``.
    """
    net.eval()
    num_batches = len(val_loader)
    step_offset = epoch * num_batches

    running_loss = 0.0
    running_dsc = 0.0

    with torch.no_grad():
        for idx, (images, labels, _img_names) in enumerate(val_loader):
            labels = utils.to_var(labels).to(device)
            images = utils.to_var(images).to(device)

            net_predictions = net(images)
            segmentation_classes = utils.getTargetSegmentation(labels)
            running_loss += criterion(net_predictions, segmentation_classes).item()
            running_dsc += float(utils.compute_dsc(net_predictions, labels))

            if idx % 10 == 0:
                writer.add_scalar("Loss/val", running_loss / (idx + 1), step_offset + idx)
                writer.add_scalar("Dice/val", running_dsc / (idx + 1), step_offset + idx)

            printProgressBar(
                idx + 1,
                num_batches,
                prefix=f"[Validation] Epoch: {epoch} ",
                length=15,
                suffix=f" Loss: {running_loss / (idx + 1):.4f}, ",
            )

    return running_loss / num_batches, running_dsc / num_batches


def runTraining(writer: SummaryWriter, loader, val_loader, modelName="Test_Model", labeled_dataset_names=None):
    """Train a fresh ``UNet`` and validate it after every epoch.

    Relies on notebook-level globals defined elsewhere: ``num_classes``,
    ``device``, ``lr``, ``total_epochs``, ``batch_size``, ``UNet``, ``utils``
    and ``printProgressBar``.

    Each epoch:
      * trains on ``loader`` with cross-entropy loss and Adam,
      * evaluates on ``val_loader`` without gradients,
      * saves ``net.state_dict()`` to ``./models/{modelName}/{epoch}_Epoch``
        whenever the mean validation loss beats the best seen so far,
      * writes the per-epoch curves as ``.npy`` files under
        ``Results/Statistics/{modelName}`` (``Losses.npy`` holds train losses).

    Args:
        writer: TensorBoard writer used for scalars and figures.
        loader: Training loader yielding ``(images, labels, img_names)``.
        val_loader: Validation loader yielding ``(images, labels, img_names)``.
        modelName: Used in log output and as the checkpoint / statistics
            sub-directory name.
        labeled_dataset_names: Optional collection of image names that have
            ground truth; any other name is tagged "(Unlabeled)" in logged
            figures. ``None`` or empty disables the tagging.

    Raises:
        ValueError: If ``loader`` or ``val_loader`` has no batches (the epoch
            means would otherwise divide by zero).
    """
    num_batches = len(loader)
    num_val_batches = len(val_loader)
    if num_batches == 0 or num_val_batches == 0:
        raise ValueError(
            f"runTraining needs non-empty loaders (train={num_batches}, val={num_val_batches})"
        )

    print("-" * 40)
    print(f"~~~~~~~~  Starting the training for {modelName}... ~~~~~~")
    print("-" * 40)

    ## CREATION OF YOUR MODEL
    print(f" Model Name: {modelName}")
    net = UNet(num_classes).to(device)

    print(
        "Total params: {0:,}".format(
            sum(p.numel() for p in net.parameters() if p.requires_grad)
        )
    )

    CE_loss = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    # Build the lookup set once; membership is then O(1) per logged image.
    labeled_names = set(labeled_dataset_names) if labeled_dataset_names is not None else set()

    train_losses = []
    train_dc_losses = []
    val_losses = []
    val_dc_losses = []

    # Start at +inf so the first epoch always produces a checkpoint.
    best_loss_val = float("inf")

    directory = f"Results/Statistics/{modelName}"
    os.makedirs(directory, exist_ok=True)

    # START THE TRAINING
    for epoch in range(total_epochs):
        train_loss, train_dc_loss = _train_one_epoch(
            net, loader, optimizer, CE_loss, writer, epoch, labeled_names
        )
        train_losses.append(train_loss)
        train_dc_losses.append(train_dc_loss)

        val_loss, dc_loss = _validate_one_epoch(net, val_loader, CE_loss, writer, epoch)
        val_losses.append(val_loss)
        val_dc_losses.append(dc_loss)

        if val_loss < best_loss_val:
            best_loss_val = val_loss
            os.makedirs(f"./models/{modelName}", exist_ok=True)
            torch.save(net.state_dict(), f"./models/{modelName}/{epoch}_Epoch")

        printProgressBar(
            num_batches,
            num_batches,
            done=f"[Epoch: {epoch}, TrainLoss: {train_loss:.4f}, TrainDice: {train_dc_loss:.4f}, ValLoss: {val_loss:.4f}]",
        )

        # Persist every curve each epoch so an interrupted run still leaves data.
        np.save(os.path.join(directory, "Losses.npy"), train_losses)
        np.save(os.path.join(directory, "TrainDice.npy"), train_dc_losses)
        np.save(os.path.join(directory, "ValLosses.npy"), val_losses)
        np.save(os.path.join(directory, "ValDice.npy"), val_dc_losses)

    writer.flush()
