## Imports

In [None]:
# autoreload makes jupyter re-import the local *.py modules we import (glioma,
# config, transform, model) whenever they are edited on disk.  otherwise you'd
# need to restart the kernel every time one of them changes.
%load_ext autoreload
%autoreload 2

In [None]:
# before importing anything we can setup the environment.  the imports will see these
# options so it's a good place to set debug options and other environment variables
# that we might want to control from the notebook.
import os

# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [None]:
#!pip install torch torchvision opencv-python efficientnet_pytorch

In [None]:
import numpy as np
import pandas as pd
import torch
import torchvision
import matplotlib.pyplot as plt
import seaborn as sns
import gc

In [None]:
from tqdm.notebook import tqdm

In [None]:
# show the current working directory.  note: this only prints it, it does not
# change it -- all the relative paths used in this notebook ("../input",
# "../output", "model-000.state", ...) are resolved against this directory.
!pwd


In [None]:
# make "." (the working directory printed above) the first module search path
# so the local glioma / config / transform / model modules are found before any
# installed packages with the same names.
import sys

sys.path.insert(0, ".")

In [None]:
import glioma
import config
import transform
import model

## Config

In [None]:
# i've moved the config to config.py
# you can make changes there if needed.

In [None]:
# pick the torch device once; batches (in collate_fn) and the model are moved
# here, so the notebook runs unchanged with or without a GPU
if torch.cuda.is_available():
    compute_device = torch.device("cuda")
else:
    compute_device = torch.device("cpu")
compute_device

In [None]:
# switch torch multiprocessing to the "spawn" start method.  presumably this is
# for worker processes that touch CUDA (collate_fn moves batches to
# compute_device) -- TODO confirm.  set_start_method raises RuntimeError once a
# start method has already been fixed (e.g. when this cell is re-run), which is
# safe to ignore.
try:
    torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
    pass
else:
    print("spawned")

In [None]:
# create the output folders used for sample images, model checkpoints and
# per-epoch submissions.  exist_ok makes this cell safe to re-run, unlike a
# plain `mkdir`, which reports "File exists" errors on every run after the first
for output_dir in (
    "../output",
    "../output/images",
    "../output/models",
    "../output/submissions",
):
    os.makedirs(output_dir, exist_ok=True)

## Data Access

In [None]:
# dataset locations, relative to the working directory
DATA_ROOT = "../input/glioma-mcd-2025/Data_122824"
# NOTE: the test set lives two levels above DATA_ROOT, i.e. under ../input
TEST_ROOT = f"{DATA_ROOT}/../../Oneshot_testingV2/Test-2.1"
TRAINING_ROOT = f"{DATA_ROOT}/Glioma_MDC_2025_training"

In [None]:
# the training data is later split into training/validation subsets
# (training_split / validation_split); the test data is what do_inference
# writes submission predictions for
glioma_training_data = glioma.CellTrainingDataset(TRAINING_ROOT)
glioma_test_data = glioma.CellTestDataset(TEST_ROOT)

## Augmentation

### Example

In [None]:
# an example cell to show the transforms
example_cell = glioma_training_data[0]
# presumably restores the cell's untransformed data -- confirm in glioma.py
example_cell.reset()

In [None]:
# an example of a transform sequence with all the probabilities set to 1.0 to
# demonstrate them.  the first argument of each transform is its probability.
# NOTE(review): Turn() presumably rotates the cell by 90 degrees so that the
# following X-only scale/shift acts on the other axis -- TODO confirm.
always_transforms = transform.TransformSequence(
    [
        transform.ResetCell(),
        transform.RandomRotate(1.0),
        transform.RandomHFlip(1.0),
        transform.RandomVFlip(1.0),
        transform.RandomXScale(1.0, config.MAX_CROP, config.MAX_PAD),
        transform.Turn(),
        transform.RandomXScale(1.0, config.MAX_CROP, config.MAX_PAD),
        transform.Turn(),
        transform.RandomXShift(1.0, config.MAX_LEFT_SHIFT, config.MIN_RIGHT_SHIFT),
        transform.Turn(),
        transform.RandomXShift(1.0, config.MAX_LEFT_SHIFT, config.MIN_RIGHT_SHIFT),
        transform.Turn(),
        transform.CenterBoxCrop(1.0),
        # colour / contrast adjustments
        transform.RGBLevels(1.0),
        transform.Saturation(1.0),
        transform.Brightness(1.0),
        transform.GammaContrast(1.0),
        transform.CLAHE(1.0),
        transform.Equalize(1.0),
    ]
)

In [None]:
# blur-only demo sequence.  Blur's arguments are presumably
# (probability, kernel_size, fade), judging by the .kernel_size / .fade
# attributes the training schedule sets later -- confirm in transform.py
blur_transforms = transform.TransformSequence(
    [
        transform.ResetCell(),
        transform.Blur(1.0, 32, 0.5),
        transform.CenterBoxCrop(1.0),
    ]
)
# blur_example_cell = blur_transforms(example_cell, display=True)

In [None]:
# HSD-only demo sequence (applied after the crop)
hsd_transforms = transform.TransformSequence(
    [
        transform.ResetCell(),
        transform.CenterBoxCrop(1.0),
        transform.HSD(1.0, 1.0),
    ]
)

In [None]:
# demonstrate the transforms (uncomment to display)
# transformed_example_cell = always_transforms(example_cell, display=True)

In [None]:
# demonstrate the transforms (uncomment to display)
# hsd_example_cell = hsd_transforms(example_cell, display=True)

### Data Pipelines

In [None]:
# the training blur is kept as a named object so the training schedule cells
# can change its kernel_size / fade between train() calls
training_blur = transform.Blur(1.0, 0, 1.0)

In [None]:
# random augmentation for training.  the first argument of each transform is
# the probability that it is applied to a cell.
training_transforms = transform.TransformSequence(
    [
        transform.ResetCell(),
        training_blur,
        # geometric augmentation
        transform.RandomRotate(0.5),
        transform.RandomHFlip(0.2),
        transform.RandomVFlip(0.2),
        transform.RandomXScale(0.2, config.MAX_CROP, config.MAX_PAD),
        transform.Turn(),
        transform.RandomXScale(0.2, config.MAX_CROP, config.MAX_PAD),
        transform.Turn(),
        transform.RandomXShift(0.2, config.MAX_LEFT_SHIFT, config.MIN_RIGHT_SHIFT),
        transform.Turn(),
        transform.RandomXShift(0.2, config.MAX_LEFT_SHIFT, config.MIN_RIGHT_SHIFT),
        transform.Turn(),
        transform.CenterBoxCrop(1.0),
        # colour / contrast augmentation
        transform.HSD(0.5, 0.2),
        transform.RGBLevels(0.2),
        transform.Saturation(0.2),
        transform.Brightness(0.2),
        transform.GammaContrast(0.2),
        transform.CLAHE(0.1),
        transform.Equalize(0.01),
    ]
)

In [None]:
# separate blur object for validation so the schedule can blur validation
# data independently of training data
validation_blur = transform.Blur(1.0, 0, 1.0)

In [None]:
# inference_blur = transform.Blur(1.0, 0, 1.0)

In [None]:
# transforms without random augmentation; also used for the test pipeline
validation_transforms = transform.TransformSequence(
    [
        transform.ResetCell(),
        validation_blur,
        transform.CenterBoxCrop(1),
    ]
)

In [None]:
# inference_transforms = transform.TransformSequence(
#     [
#         transform.ResetCell(),
#         inference_blur,
#         transform.CenterBoxCrop(1),
#     ]
# )

In [None]:
# the training, validation and test datasets
training_dataset = glioma_training_data.training_split
validation_dataset = glioma_training_data.validation_split
test_dataset = glioma_test_data

In [None]:
# the collate function turns the list of cells in a batch into the cells
# themselves plus an image tensor X and a target tensor y on compute_device
def collate_fn(
    batch,
):
    """Collate a list of cells into ``(cells, X, y)``.

    ``X`` stacks every ``cell.X`` into one float tensor and moves its last axis
    to position 1 (presumably HWC -> CHW, the layout ``save_image`` expects
    later).  ``y`` holds every ``cell.y`` as a float tensor.  Both tensors are
    placed on ``compute_device``; the list of cells is returned unchanged.
    """
    images = np.stack([cell.X for cell in batch])
    targets = np.asarray([cell.y for cell in batch])

    X = torch.from_numpy(images).float().to(compute_device).permute(0, 3, 1, 2)
    y = torch.from_numpy(targets).float().to(compute_device)
    return batch, X, y

In [None]:
# final rescaling applied to every collated batch before it reaches the model
def data_fixup(
    batch,
):
    """Rescale a collated ``(cells, X, y)`` batch for the model.

    Images are mapped from [0, 1] to [-1, 1] and targets from {-1, 1} to
    {0, 1}, so the targets can be compared directly with predicted
    probabilities by the loss functions.
    """
    cells, images, targets = batch
    rescaled_images = (images - 0.5) * 2.0
    rescaled_targets = (targets + 1.0) / 2.0
    return cells, rescaled_images, rescaled_targets

In [None]:
# training pipeline
training_pipeline = training_dataset
training_pipeline = training_pipeline.shuffle()
training_pipeline = training_pipeline.map(training_transforms)
training_pipeline = training_pipeline.batch(config.BATCH_SIZE)
training_pipeline = training_pipeline.map(collate_fn)
training_pipeline = training_pipeline.map(data_fixup)

In [None]:
# training validation pipeline
training_validation_pipeline = training_dataset
training_validation_pipeline = training_validation_pipeline.map(validation_transforms)
training_validation_pipeline = training_validation_pipeline.batch(config.BATCH_SIZE)
training_validation_pipeline = training_validation_pipeline.map(collate_fn)
training_validation_pipeline = training_validation_pipeline.map(data_fixup)

In [None]:
# validation pipeline
validation_pipeline = validation_dataset
validation_pipeline = validation_pipeline.map(validation_transforms)
validation_pipeline = validation_pipeline.batch(config.BATCH_SIZE)
validation_pipeline = validation_pipeline.map(collate_fn)
validation_pipeline = validation_pipeline.map(data_fixup)

In [None]:
# test pipeline
test_pipeline = test_dataset
test_pipeline = test_pipeline.map(validation_transforms)
test_pipeline = test_pipeline.batch(config.BATCH_SIZE)
test_pipeline = test_pipeline.map(collate_fn)
test_pipeline = test_pipeline.map(data_fixup)

## Loaders

In [None]:
def loader_collate_fn(
    batch,
):
    """Unwrap a DataLoader batch of size one.

    The pipelines already batch and collate, so the (commented-out) loaders
    use ``batch_size=1`` and this simply hands back the single pre-built batch.
    """
    single_batch = batch[0]
    return single_batch

In [None]:
# training_loader = torch.utils.data.DataLoader(
#     training_pipeline,
#     batch_size=1,
#     num_workers=config.NUM_WORKERS,
#     # pin_memory=True,
#     collate_fn=loader_collate_fn,
#     persistent_workers=True,
# )

In [None]:
# training_validation_loader = torch.utils.data.DataLoader(
#     training_validation_pipeline,
#     batch_size=1,
#     num_workers=config.NUM_WORKERS,
#     # pin_memory=True,
#     collate_fn=loader_collate_fn,
#     persistent_workers=True,
# )

In [None]:
# validation_loader = torch.utils.data.DataLoader(
#     validation_pipeline,
#     batch_size=1,
#     num_workers=config.NUM_WORKERS,
#     # pin_memory=True,
#     collate_fn=loader_collate_fn,
#     persistent_workers=True,
# )

In [None]:
# test_loader = torch.utils.data.DataLoader(
#     test_pipeline,
#     batch_size=1,
#     num_workers=config.NUM_WORKERS,
#     # pin_memory=True,
#     collate_fn=loader_collate_fn,
#     persistent_workers=True,
# )

## Model

### HPV Classifier

In [None]:
# create the model
glioma_model = model.HPV_Classifier()

In [None]:
# initialise from model-000.state (relative to the working directory);
# presumably this loads only the feature-extractor weights -- confirm in model.py
glioma_model.load_feature_extractor("model-000.state")

In [None]:
# load the model with default weights from model-001.state
glioma_model.load("model-001.state")

In [None]:
# move the parameters to the compute device before any optimiser is built
glioma_model = glioma_model.to(compute_device)

In [None]:
# we should freeze these by default !  the later schedule cells call
# unfreeze_features() once the rest of the model has been trained
glioma_model.freeze_features()
# glioma_model.unfreeze_features()

In [None]:
# release memory left over from building / loading the model
gc.collect()

In [None]:
# the model may give two outputs; the functions below decide how to turn its
# output into class probabilities.  NOTE(review): probabilities_from_model_1
# applies a sigmoid to the whole output, i.e. it assumes one logit per cell --
# confirm against model.HPV_Classifier

In [None]:
# treat the model output as one logit per cell and squash it with a sigmoid
def probabilities_from_model_1(
    model_output,
):
    """Convert raw model logits to per-cell class probabilities.

    Applies a sigmoid and removes every size-1 dimension *except* the batch
    dimension, so ``(batch, 1)`` becomes ``(batch,)``.  A bare ``squeeze()``
    would also drop the batch dimension when a batch holds a single cell,
    producing a 0-d tensor that breaks the losses (target size mismatch),
    ``torch.cat`` in the training/validation loops and the per-cell ``zip``
    in inference.
    """
    probabilities = torch.sigmoid(model_output)
    # walk backwards so squeezing a dimension doesn't shift the ones still to visit
    for dim in reversed(range(1, probabilities.dim())):
        probabilities = probabilities.squeeze(dim)

    return probabilities

In [None]:
# difference of the outputs is the class logits
# def probabilities_from_model_2(
#     model_output,
# ):
#     difference = model_output[:, 0] - model_output[:, 1]
#     probabilities = torch.sigmoid(difference)

#     return probabilities

In [None]:
# something completely different
# def probabilities_from_model_3(
#     model_output,
# ):
#     value = model_output[:, 0] * model_output[:, 1]
#     probabilities = torch.sigmoid(value)

#     return probabilities

### Optimiser

In [None]:
def make_optimiser_1(
    model,
    learning_rate=config.LEARNING_RATE,
    weight_decay=config.WEIGHT_DECAY,
    amsgrad=config.AMS_GRAD,
):
    """Build an Adam optimiser over every parameter of ``model``.

    Defaults come from ``config``; pass ``learning_rate`` explicitly to run a
    schedule stage at a different rate.
    """
    adam_options = dict(
        lr=learning_rate,
        weight_decay=weight_decay,
        amsgrad=amsgrad,
    )
    return torch.optim.Adam(model.parameters(), **adam_options)


# optimiser_1 = make_optimiser_1(glioma_model)

In [None]:
# def make_optimiser_2(
#     model,
#     learning_rate=config.LEARNING_RATE,
#     weight_decay=config.WEIGHT_DECAY,
#     momentum_decay=config.MOMENTUM_DECAY,
# ):
#     return torch.optim.NAdam(
#         model.parameters(),
#         lr=learning_rate,
#         betas=(0.9, 0.999),
#         eps=1e-08,
#         weight_decay=weight_decay,
#         momentum_decay=momentum_decay,
#     )


# optimiser_2 = make_optimiser_2(glioma_model)

In [None]:
# def make_optimiser_3(
#     model,
#     learning_rate=config.LEARNING_RATE,
#     momentum=config.SGD_MOMENTUM,
#     weight_decay=config.WEIGHT_DECAY,
# ):
#     return torch.optim.SGD(
#         model.parameters(),
#         lr=learning_rate,
#         momentum=momentum,
#         weight_decay=weight_decay,
#     )


# optimiser_3 = make_optimiser_3(glioma_model)

### Loss

In [None]:
# NOTE: for single classes,
# the actual_class is the same as the actual_class_probabilities

In [None]:
# binary cross entropy loss
def loss_landscale_1(
    predicted_class_probabilities: torch.Tensor,
    actual_class_probabilities: torch.Tensor,
) -> torch.Tensor:
    """Mean binary cross entropy between predicted and target probabilities.

    The predictions must already be probabilities (e.g. sigmoid outputs), not
    logits, and both tensors must have the same shape.
    """
    return torch.nn.functional.binary_cross_entropy(
        input=predicted_class_probabilities,
        target=actual_class_probabilities,
    )

In [None]:
# standard L2 loss
def loss_landscale_2(
    predicted_class_probabilities: torch.Tensor,
    actual_class_probabilities: torch.Tensor,
) -> torch.Tensor:
    """Mean squared error between predicted and target probabilities."""
    return torch.nn.functional.mse_loss(
        input=predicted_class_probabilities,
        target=actual_class_probabilities,
    )

In [None]:
# standard L1 loss
def loss_landscale_3(
    predicted_class_probabilities: torch.Tensor,
    actual_class_probabilities: torch.Tensor,
) -> torch.Tensor:
    """Mean absolute error between predicted and target probabilities."""
    return torch.nn.functional.l1_loss(
        input=predicted_class_probabilities,
        target=actual_class_probabilities,
    )

### Metrics

In [None]:
def calculate_metrics(
    predicted_class_probability,
    actual_class,
):
    """Binary classification metrics for one pass over a dataset.

    Args:
        predicted_class_probability: numpy array of predicted probabilities of
            the positive class (1).
        actual_class: numpy array of ground-truth labels 0/1 (float labels are
            cast to int).

    Returns:
        dict with ``accuracy``, ``f1``, ``precision`` and ``recall`` (computed on
        the predictions thresholded at 0.5) and ``roc_auc`` (computed on the raw
        probabilities).  ``roc_auc`` is NaN when ``actual_class`` contains only
        one class, where it is undefined.
    """
    # sklearn is not imported anywhere else in the notebook, so without this the
    # function fails with NameError in a fresh kernel
    import sklearn.metrics

    predicted_class = (predicted_class_probability >= 0.5).astype(int)
    actual_class = actual_class.astype(int)

    # roc_auc_score raises ValueError when only one class is present, which
    # would otherwise abort the epoch after all the work has been done
    if np.unique(actual_class).size < 2:
        roc_auc = float("nan")
    else:
        roc_auc = sklearn.metrics.roc_auc_score(
            actual_class, predicted_class_probability
        )

    return {
        "accuracy": sklearn.metrics.accuracy_score(actual_class, predicted_class),
        "f1": sklearn.metrics.f1_score(actual_class, predicted_class),
        "precision": sklearn.metrics.precision_score(actual_class, predicted_class),
        "recall": sklearn.metrics.recall_score(actual_class, predicted_class),
        "roc_auc": roc_auc,
    }

### Training, Validation and Inference

In [None]:
def do_training(
    epoch,
    model,
    optimiser,
    dataset,
    loss_fn,
    output_fn,
):
    """Run one training pass over ``dataset`` and return its metrics.

    Args:
        epoch: epoch number, used in the progress bar and the sample image name.
        model: the network being trained; switched to train mode here.
        optimiser: steps ``model``'s parameters once per batch.
        dataset: iterable with ``len`` yielding ``(cells, X, y)`` batches.
        loss_fn: the loss that is back-propagated, called as
            ``loss_fn(predicted_probabilities, targets)``.
        output_fn: maps the raw model output to class probabilities.

    Returns:
        ``calculate_metrics`` results plus the mean per-batch ``loss`` (from
        ``loss_fn``) and ``loss_1`` / ``loss_2`` / ``loss_3`` (BCE, L2 and L1,
        logged for comparison only).
    """
    # ensure model is in train mode
    model.train()

    batch_losses = []
    batch_losses_1 = []
    batch_losses_2 = []
    batch_losses_3 = []
    predicted_class_probabilities = []
    actual_classes = []

    tqdm_iterator = tqdm(
        enumerate(dataset),
        total=len(dataset),
        desc=f"{epoch:04d} Training _.________ :",
    )
    for i, (_, X, actual_class) in tqdm_iterator:
        # forward pass
        optimiser.zero_grad()
        model_output = model(X)
        predicted_class_probability = output_fn(model_output)
        batch_loss = loss_fn(predicted_class_probability, actual_class)

        # the three comparison losses are only logged, never back-propagated:
        # compute them on detached predictions so no extra autograd graphs are
        # built and kept alive for the rest of the epoch
        detached_probability = predicted_class_probability.detach()
        with torch.no_grad():
            loss_1 = loss_landscale_1(detached_probability, actual_class)
            loss_2 = loss_landscale_2(detached_probability, actual_class)
            loss_3 = loss_landscale_3(detached_probability, actual_class)

        # backward pass and weight update
        batch_loss.backward()
        optimiser.step()

        # log losses and predictions (detached for the same reason as above)
        batch_losses.append(batch_loss.detach())
        predicted_class_probabilities.append(detached_probability)
        actual_classes.append(actual_class)
        batch_losses_1.append(loss_1)
        batch_losses_2.append(loss_2)
        batch_losses_3.append(loss_3)

        # save an example data batch
        if i == 0:
            torchvision.utils.save_image(
                X.detach(),
                f"../output/images/sample_training_data_epoch_{epoch:04d}.png",
                normalize=True,
            )

        # update the progress bar
        tqdm_iterator.set_description_str(
            f"{epoch:04d} Training {batch_loss.item():.8f} >"
        )

    # stick all the predictions and ground truths together and move them to the
    # cpu as numpy arrays for the metric calculations
    predicted_class_probabilities = torch.cat(predicted_class_probabilities).numpy(
        force=True
    )
    actual_classes = torch.cat(actual_classes).numpy(force=True)

    metrics = calculate_metrics(predicted_class_probabilities, actual_classes)
    metrics["loss"] = torch.stack(batch_losses).mean().item()
    metrics["loss_1"] = torch.stack(batch_losses_1).mean().item()
    metrics["loss_2"] = torch.stack(batch_losses_2).mean().item()
    metrics["loss_3"] = torch.stack(batch_losses_3).mean().item()

    # single quotes around 'loss': reusing the enclosing double quote inside an
    # f-string replacement field is a SyntaxError before Python 3.12 (PEP 701)
    tqdm_iterator.set_description_str(f"{epoch:04d} Training {metrics['loss']:.8f} =")
    tqdm_iterator.refresh()
    tqdm_iterator.close()

    return metrics

In [None]:
def do_validation(
    epoch,
    model,
    dataset,
    loss_fn,
    output_fn,
):
    """Evaluate ``model`` on ``dataset`` without updating it; return its metrics.

    Args:
        epoch: epoch number, used in the progress bar; a sample image is saved
            only when it is 0.
        model: the network to evaluate; switched to eval mode here.
        dataset: iterable with ``len`` yielding ``(cells, X, y)`` batches.
        loss_fn: the main loss, called as ``loss_fn(predicted, targets)``.
        output_fn: maps the raw model output to class probabilities.

    Returns:
        ``calculate_metrics`` results plus the mean per-batch ``loss`` and
        ``loss_1`` / ``loss_2`` / ``loss_3`` (BCE, L2 and L1), i.e. the same keys
        as ``do_training``.
    """
    # eval mode + no_grad: no weight updates and no autograd graphs
    model.eval()

    batch_losses = []
    batch_losses_1 = []
    batch_losses_2 = []
    batch_losses_3 = []
    predicted_class_probabilities = []
    actual_classes = []

    with torch.no_grad():

        tqdm_iterator = tqdm(
            enumerate(dataset),
            total=len(dataset),
            desc=f"{epoch:04d} Validating _.________ :",
        )
        for i, (_, X, actual_class) in tqdm_iterator:
            # forward pass
            model_output = model(X)
            predicted_class_probability = output_fn(model_output)
            batch_loss = loss_fn(predicted_class_probability, actual_class)

            # calculate all three losses for comparison
            loss_1 = loss_landscale_1(predicted_class_probability, actual_class)
            loss_2 = loss_landscale_2(predicted_class_probability, actual_class)
            loss_3 = loss_landscale_3(predicted_class_probability, actual_class)

            # log losses and predictions
            batch_losses.append(batch_loss)
            predicted_class_probabilities.append(predicted_class_probability)
            actual_classes.append(actual_class)
            batch_losses_1.append(loss_1)
            batch_losses_2.append(loss_2)
            batch_losses_3.append(loss_3)

            # save an example data batch (first epoch only)
            if epoch == 0 and i == 0:
                torchvision.utils.save_image(
                    X.detach(),
                    f"../output/images/sample_validation_data_epoch_{epoch:04d}.png",
                    normalize=True,
                )

            # update the progress bar
            tqdm_iterator.set_description_str(
                f"{epoch:04d} Validating {batch_loss.item():.8f} >"
            )

        # stick all the predictions and ground truths together and move them to
        # the cpu as numpy arrays for the metric calculations
        predicted_class_probabilities = torch.cat(
            predicted_class_probabilities
        ).numpy(force=True)
        actual_classes = torch.cat(actual_classes).numpy(force=True)

        metrics = calculate_metrics(predicted_class_probabilities, actual_classes)
        metrics["loss"] = torch.stack(batch_losses).mean().item()
        metrics["loss_1"] = torch.stack(batch_losses_1).mean().item()
        metrics["loss_2"] = torch.stack(batch_losses_2).mean().item()
        metrics["loss_3"] = torch.stack(batch_losses_3).mean().item()

        # single quotes around 'loss': reusing the enclosing double quote inside
        # an f-string replacement field is a SyntaxError before Python 3.12
        tqdm_iterator.set_description_str(
            f"{epoch:04d} Validating {metrics['loss']:.8f} ="
        )
        tqdm_iterator.refresh()
        tqdm_iterator.close()

    return metrics

In [None]:
def do_inference(
    model,
    dataset,
    output_fn,
    epoch=None,
):
    """Predict every cell in ``dataset`` and write a submission CSV.

    Each cell's prediction (probability >= 0.5 -> 1, else 0) is stored on
    ``cell.mitosis`` and written as an ``Image ID,Label ID,Prediction`` row; the
    rows are sorted and a 1-based ``Row ID`` is prepended.

    Args:
        model: the network to run; switched to eval mode here.
        dataset: iterable with ``len`` yielding ``(cells, X, y)`` batches.
        output_fn: maps the raw model output to class probabilities.
        epoch: when given, the file is
            ``../output/submissions/epoch_{epoch:04d}.csv`` and, for
            ``epoch > 0``, the predictions are compared with the previous
            epoch's file if it exists.  When ``None`` the file is
            ``../output/submission.csv``.
    """
    # ensure model is in eval mode
    model.eval()

    prefix = "" if epoch is None else f"{epoch:04d} "
    predictions = []

    with torch.no_grad():

        for i, (cells, X, _) in tqdm(
            enumerate(dataset),
            total=len(dataset),
            desc=f"{prefix}Infering :",
        ):
            # forward pass
            model_output = model(X)
            predicted_class_probability = output_fn(model_output)

            # threshold at 0.5; atleast_1d keeps a single-cell batch iterable
            # even if output_fn returned a 0-d tensor
            predicted_class = np.atleast_1d(
                predicted_class_probability.numpy(force=True) >= 0.5
            ).astype(int)

            for cell, prediction in zip(cells, predicted_class):
                cell.mitosis = prediction
                predictions.append(f"{cell.image.label},{cell.label},{prediction}")

            # save an example data batch
            if epoch == 0 and i == 0:
                torchvision.utils.save_image(
                    X.detach(),
                    f"../output/images/sample_inference_data_epoch_{epoch:04d}.png",
                    normalize=True,
                )

    predictions = sorted(predictions)

    # compare to the previous epoch's predictions, when that file exists (it
    # may not, e.g. if the model state was restored without the submissions)
    if epoch is not None and epoch > 0:
        last_filename = f"../output/submissions/epoch_{epoch-1:04d}.csv"
        if os.path.exists(last_filename):
            with open(last_filename, "r") as csv_file:
                # drop the header; rows are "Row ID,Image ID,Label ID,Prediction"
                last_predictions = [line.strip() for line in csv_file.readlines()[1:]]

            mitotic = 0
            same = 0
            different = 0
            for i, (last, current) in enumerate(zip(last_predictions, predictions)):
                if last == f"{i+1},{current}":
                    same += 1
                else:
                    different += 1
                # the prediction is the last field of the row
                if current.rsplit(",", 1)[-1] == "1":
                    mitotic += 1
            print(
                f"{epoch-1:04d} Inference has {same} same and {different} different predictions - {mitotic} mitotic."
            )
        else:
            print(f"{epoch-1:04d} Inference has no previous submission to compare with.")

    # write out the results
    filename = "../output/submission.csv"
    if epoch is not None:
        filename = f"../output/submissions/epoch_{epoch:04d}.csv"

    with open(filename, "w") as csv_file:
        # write the header
        csv_file.write("Row ID,Image ID,Label ID,Prediction\n")
        # write the predictions
        for i, prediction in enumerate(predictions):
            csv_file.write(f"{i+1},{prediction}\n")

### Training Logs

In [None]:
# LOGS
# training_log holds one [epoch, subset, metric, value] row per logged metric.
# this empty default is replaced by the next cell, which reloads a saved log.
training_log = []

In [None]:
# reload the training logs if they have been saved so we can continue training
training_log = None
try:
    training_log = pd.read_csv(config.TRAINING_LOG)
except FileNotFoundError:
    pass

if training_log is None:
    training_log = []
else:
    # training_log = training_log.drop(columns="Unnamed: 0")
    # convert back to a list
    training_log = training_log.to_dict("records")
    # convert the values to a list [epoch, subset, metric, value]]
    training_log = [list(row.values()) for row in training_log]

In [None]:
training_log

In [None]:
# find the most recent epoch recorded in the training log
def last_completed_epoch():
    """Return the highest epoch in ``training_log``, or -1 when it is empty.

    -1 makes ``last_completed_epoch() + 1`` the epoch to run next (0 on a fresh
    start).
    """
    return max((row[0] for row in training_log), default=-1)

In [None]:
# restore the model from the checkpoint of the last logged epoch, if any
def reload_latest_model():
    """Load ``../output/models/{epoch:04d}.state`` for the last logged epoch.

    Does nothing when the training log is empty (no epoch completed yet).
    """
    latest_epoch = last_completed_epoch()
    if latest_epoch != -1:
        glioma_model.load(f"../output/models/{latest_epoch:04d}.state")


reload_latest_model()

In [None]:
def plot_simple_metric(
    metric: str,
    metrics: pd.DataFrame,
):
    """Plot one metric per subset against epoch; save ``../output/{metric}.png``.

    ``metrics`` is the training-log frame with ``epoch``, ``subset``,
    ``metric`` and ``value`` columns.
    """
    metric_df = metrics[metrics["metric"] == metric]
    # draw on a fresh figure rather than whatever pyplot considers "current",
    # and close it even if saving fails so later plots aren't drawn on top of it
    figure, ax = plt.subplots()
    try:
        sns.lineplot(
            data=metric_df,
            x="epoch",
            y="value",
            hue="subset",
            ax=ax,
        )
        ax.legend(loc="upper left", fontsize=8)
        figure.savefig(f"../output/{metric}.png", dpi=300)
    finally:
        plt.close(figure)

In [None]:
def plot_loss_metric(
    metric: str,
    metrics: pd.DataFrame,
):
    """Plot a loss metric per subset, on linear and log y-axes.

    Saves ``../output/{metric}.png`` and ``../output/{metric}_log.png``.
    ``metrics`` is the training-log frame with ``epoch``, ``subset``,
    ``metric`` and ``value`` columns.
    """
    metric_df = metrics[metrics["metric"] == metric]
    # draw on a fresh figure rather than whatever pyplot considers "current",
    # and close it even if saving fails so later plots aren't drawn on top of it
    figure, ax = plt.subplots()
    try:
        sns.lineplot(
            data=metric_df,
            x="epoch",
            y="value",
            hue="subset",
            ax=ax,
        )
        ax.legend(loc="upper left", fontsize=8)
        figure.savefig(f"../output/{metric}.png", dpi=300)
        # create a log version
        ax.set_yscale("log")
        figure.savefig(f"../output/{metric}_log.png", dpi=300)
    finally:
        plt.close(figure)

In [None]:
def plot_all_loss_metric(
    metrics: pd.DataFrame,
):
    """Plot every ``loss*`` metric for every subset on one set of axes.

    Colour encodes the subset and line style the loss.  Saves
    ``../output/all_loss.png``, a log-scale copy ``all_loss_log.png`` and a
    log-scale copy limited to epochs >= 400 (``all_loss_log_400.png``).
    """
    metric_df = metrics[metrics["metric"].str.startswith("loss")]
    # draw on a fresh figure rather than whatever pyplot considers "current",
    # and close it even if saving fails so later plots aren't drawn on top of it
    figure, ax = plt.subplots()
    try:
        sns.lineplot(
            data=metric_df,
            x="epoch",
            y="value",
            hue="subset",
            style="metric",
            ax=ax,
        )
        ax.legend(loc="upper left", fontsize=8)
        figure.savefig("../output/all_loss.png", dpi=300)
        # create a log version
        ax.set_yscale("log")
        figure.savefig("../output/all_loss_log.png", dpi=300)
        # start from epoch 400 (set on these axes, not pyplot's current ones)
        ax.set_xlim(left=400)
        figure.savefig("../output/all_loss_log_400.png", dpi=300)
    finally:
        plt.close(figure)

In [None]:
def plot_near_one_metric(
    metric: str,
    metrics: pd.DataFrame,
    recent_epochs: int = 10,
):
    """Plot a score metric (expected near 1) per subset, plus two zoomed copies.

    Saves ``../output/{metric}.png``, ``{metric}_zoom.png`` (y in [0.9, 1.0])
    and ``{metric}_zoom_auto.png`` (y fitted to the values of the last
    ``recent_epochs`` completed epochs).

    Args:
        metric: name of the metric to plot.
        metrics: training-log frame with ``epoch``, ``subset``, ``metric`` and
            ``value`` columns.
        recent_epochs: how many of the latest epochs the auto zoom covers.
    """
    metric_df = metrics[metrics["metric"] == metric]
    # draw on a fresh figure rather than whatever pyplot considers "current",
    # and close it even if saving fails so later plots aren't drawn on top of it
    figure, ax = plt.subplots()
    try:
        sns.lineplot(
            data=metric_df,
            x="epoch",
            y="value",
            hue="subset",
            ax=ax,
        )
        ax.legend(loc="upper left", fontsize=8)
        figure.savefig(f"../output/{metric}.png", dpi=300)
        # create a zoomed in version near 1
        ax.set_ylim(0.9, 1.0)
        figure.savefig(f"../output/{metric}_zoom.png", dpi=300)
        # fit the y-range to the most recent epochs
        start_epoch = max(last_completed_epoch() - recent_epochs, 0)
        recent_values = metric_df.loc[metric_df["epoch"] >= start_epoch, "value"]
        min_value = recent_values.min()
        max_value = recent_values.max()
        # set_ylim rejects NaN limits (no rows yet, or only NaN values such as
        # an undefined roc_auc), so keep the [0.9, 1.0] zoom in that case
        if np.isfinite(min_value) and np.isfinite(max_value):
            ax.set_ylim(min_value, max_value)
        figure.savefig(f"../output/{metric}_zoom_auto.png", dpi=300)
    finally:
        plt.close(figure)

In [None]:
def update_plots(
    df: pd.DataFrame,
):
    """Regenerate every training-curve figure from the training-log frame ``df``."""
    plot_all_loss_metric(df)

    for loss_name in ("loss", "loss_1", "loss_2", "loss_3"):
        plot_loss_metric(loss_name, df)

    for score_name in ("f1", "accuracy", "precision", "recall", "roc_auc"):
        plot_near_one_metric(score_name, df)

## Go !

### The training ...

In [None]:
def train(
    model,
    optimiser,
    loss_fn,
    output_fn,
    epochs=config.EPOCHS,
):
    """Train for ``epochs`` more epochs, continuing after the last logged epoch.

    Each epoch: two training passes (see the NOTE below), evaluation on the
    training cells (validation transforms) and on the validation cells, a
    test-set submission, then the training log, the model state and the plots
    are written under ../output.
    """
    start_epoch = last_completed_epoch() + 1

    for epoch in range(start_epoch, start_epoch + epochs):

        # NOTE(review): the model is trained twice per epoch -- once on the
        # augmented training_pipeline and once on training_validation_pipeline
        # (validation transforms, no random augmentation) -- and only the second
        # pass's metrics are kept as "training".  confirm this is intentional.
        for training_source in (training_pipeline, training_validation_pipeline):
            training_metrics = do_training(
                epoch,
                model,
                optimiser,
                training_source,
                loss_fn,
                output_fn,
            )

        # consistent (non-augmented) metrics for the training cells, then the
        # held-out validation cells
        training_dataset_metrics = do_validation(
            epoch,
            model,
            training_validation_pipeline,
            loss_fn,
            output_fn,
        )
        validation_dataset_metrics = do_validation(
            epoch,
            model,
            validation_pipeline,
            loss_fn,
            output_fn,
        )

        # write this epoch's test-set submission
        do_inference(
            model,
            test_pipeline,
            output_fn,
            epoch=epoch,
        )

        # record all the metrics in the training log
        for subset_name, subset_metrics in (
            ("training", training_metrics),
            ("training_dataset", training_dataset_metrics),
            ("validation_dataset", validation_dataset_metrics),
        ):
            training_log.extend(
                [epoch, subset_name, metric_name, metric_value]
                for metric_name, metric_value in subset_metrics.items()
            )

        # save the training log and model state
        df = pd.DataFrame(training_log, columns=["epoch", "subset", "metric", "value"])
        df.to_csv(config.TRAINING_LOG, index=False)
        model.save(f"../output/models/{epoch:04d}.state")

        # update the graphs from the training log
        update_plots(df)

### Steps ...

In [None]:
# baseline submission for the loaded weights, before any training.
# epoch=-1 writes ../output/submissions/epoch_-001.csv and skips the comparison
# with a previous submission.
do_inference(
    glioma_model,
    test_pipeline,
    probabilities_from_model_1,
    epoch=-1,
)

In [None]:
# baseline validation metrics for the loaded weights.  because epoch=0 is
# passed, this also writes the epoch-0 sample validation image.
do_validation(
    0,
    glioma_model,
    validation_pipeline,
    loss_landscale_1,
    probabilities_from_model_1,
)

In [None]:
# make sure the feature extractor is frozen for the first schedule stages.
# each train() call below continues numbering after the last logged epoch.
glioma_model.freeze_features()

In [None]:
# stage: Adam + BCE at lr=1e-3 (the largest rate in this schedule), 20 epochs
optimiser_1 = make_optimiser_1(glioma_model, learning_rate=0.001)
train(
    glioma_model,
    optimiser_1,
    loss_landscale_1,
    probabilities_from_model_1,
    epochs=20,
)

In [None]:
# stage: Adam + BCE at lr=5e-4, 20 epochs
optimiser_1 = make_optimiser_1(glioma_model, learning_rate=0.0005)
train(
    glioma_model,
    optimiser_1,
    loss_landscale_1,
    probabilities_from_model_1,
    epochs=20,
)

In [None]:
# stage: Adam + BCE at lr=1e-4, 20 epochs
optimiser_1 = make_optimiser_1(glioma_model, learning_rate=0.0001)
train(
    glioma_model,
    optimiser_1,
    loss_landscale_1,
    probabilities_from_model_1,
    epochs=20,
)

In [None]:
# stage: lr=5e-5, 20 epochs
optimiser_1 = make_optimiser_1(glioma_model, learning_rate=0.00005)
train(
    glioma_model,
    optimiser_1,
    loss_landscale_1,
    probabilities_from_model_1,
    epochs=20,
)

In [None]:
# stage: back to lr=1e-4, 20 epochs
optimiser_1 = make_optimiser_1(glioma_model, learning_rate=0.0001)
train(
    glioma_model,
    optimiser_1,
    loss_landscale_1,
    probabilities_from_model_1,
    epochs=20,
)

In [None]:
# stage: ramp training_blur.fade 0.1 -> 1.0 in ten 10-epoch steps at lr=1e-4.
# NOTE(review): training_blur.kernel_size has not been changed yet at this
# point (it still has its constructor value), so this ramp may have no visible
# effect -- confirm in transform.Blur
optimiser_1 = make_optimiser_1(glioma_model, learning_rate=0.0001)
for i in range(1, 11):
    training_blur.fade = i / 10
    train(
        glioma_model,
        optimiser_1,
        loss_landscale_1,
        probabilities_from_model_1,
        epochs=10,
    )

In [None]:
# stage: lr=5e-5, 10 epochs
optimiser_1 = make_optimiser_1(glioma_model, learning_rate=0.00005)
train(
    glioma_model,
    optimiser_1,
    loss_landscale_1,
    probabilities_from_model_1,
    epochs=10,
)

In [None]:
# stage: with fade fixed at 1.0, step the training blur kernel down
# 16 -> 8 -> 4 -> 2 -> 1 -> 0, 10 epochs each, at lr=1e-4
optimiser_1 = make_optimiser_1(glioma_model, learning_rate=0.0001)
training_blur.fade = 1.0
for i in reversed([0, 1, 2, 4, 8, 16]):
    training_blur.kernel_size = i
    train(
        glioma_model,
        optimiser_1,
        loss_landscale_1,
        probabilities_from_model_1,
        epochs=10,
    )

In [None]:
# stage: lr=1e-4 with whatever blur the previous cell left (kernel_size 0), 10 epochs
optimiser_1 = make_optimiser_1(glioma_model, learning_rate=0.0001)
train(
    glioma_model,
    optimiser_1,
    loss_landscale_1,
    probabilities_from_model_1,
    epochs=10,
)

In [None]:
# stage: lr=5e-5, 10 epochs.  (the loss name was previously misspelt
# "loss_landscalpe_1", which raised NameError when this cell ran)
optimiser_1 = make_optimiser_1(glioma_model, learning_rate=0.00005)
train(
    glioma_model,
    optimiser_1,
    loss_landscale_1,
    probabilities_from_model_1,
    epochs=10,
)

In [None]:
# one low-rate optimiser (lr=1e-5) shared by the following blur-schedule cells
optimiser_1 = make_optimiser_1(glioma_model, learning_rate=0.00001)

In [None]:
# step the training blur kernel up 1 -> 32, 10 epochs each; validation blur
# and the training blur's fade keep whatever values were last set
for i in [1, 2, 4, 8, 16, 32]:
    training_blur.kernel_size = i
    train(
        glioma_model,
        optimiser_1,
        loss_landscale_1,
        probabilities_from_model_1,
        epochs=10,
    )

In [None]:
# step the blur kernel down 16 -> 1 for training AND validation together,
# 20 epochs each, so validation is blurred the same way as training
for i in reversed([1, 2, 4, 6, 8, 10, 12, 16]):
    validation_blur.kernel_size = i
    training_blur.kernel_size = i
    train(
        glioma_model,
        optimiser_1,
        loss_landscale_1,
        probabilities_from_model_1,
        epochs=20,
    )

In [None]:
# repeat the joint blur schedule (now including 14), 10 epochs each
for i in reversed([1, 2, 4, 6, 8, 10, 12, 14, 16]):
    validation_blur.kernel_size = i
    training_blur.kernel_size = i
    train(
        glioma_model,
        optimiser_1,
        loss_landscale_1,
        probabilities_from_model_1,
        epochs=10,
    )

In [None]:
# fresh optimiser (discarding the accumulated Adam state) before fine-tuning.
# it is built from model.parameters(), i.e. all parameters, so the feature
# extractor is presumably updated too once unfrozen -- confirm that
# unfreeze_features() only toggles requires_grad
optimiser_1 = make_optimiser_1(glioma_model, learning_rate=0.00001)

In [None]:
# fine-tune the whole model from here on
glioma_model.unfreeze_features()

In [None]:
# joint blur schedule 16 -> 1 again, now with the feature extractor unfrozen
for i in reversed([1, 2, 4, 6, 8, 10, 12, 14, 16]):
    validation_blur.kernel_size = i
    training_blur.kernel_size = i
    train(
        glioma_model,
        optimiser_1,
        loss_landscale_1,
        probabilities_from_model_1,
        epochs=10,
    )

In [None]:
# final 50 epochs at the smallest blur kernel (1) for training and validation
for i in reversed([1]):
    validation_blur.kernel_size = i
    training_blur.kernel_size = i
    train(
        glioma_model,
        optimiser_1,
        loss_landscale_1,
        probabilities_from_model_1,
        epochs=50,
    )