<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Setting-up-imports" data-toc-modified-id="Setting-up-imports-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Setting up imports</a></span></li><li><span><a href="#Setting-up-Constant-Hyperparameters" data-toc-modified-id="Setting-up-Constant-Hyperparameters-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Setting up Constant Hyperparameters</a></span></li><li><span><a href="#Setting-up-Parameters-and-Functions-for-Training" data-toc-modified-id="Setting-up-Parameters-and-Functions-for-Training-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Setting up Parameters and Functions for Training</a></span><ul class="toc-item"><li><span><a href="#Hyperparameters-Search-Space" data-toc-modified-id="Hyperparameters-Search-Space-3.1"><span class="toc-item-num">3.1&nbsp;&nbsp;</span>Hyperparameters Search Space</a></span></li><li><span><a href="#Creating-the-training-function" data-toc-modified-id="Creating-the-training-function-3.2"><span class="toc-item-num">3.2&nbsp;&nbsp;</span>Creating the training function</a></span></li><li><span><a href="#Creating-the-evaluation-function" data-toc-modified-id="Creating-the-evaluation-function-3.3"><span class="toc-item-num">3.3&nbsp;&nbsp;</span>Creating the evaluation function</a></span></li></ul></li><li><span><a href="#Running-the-training" data-toc-modified-id="Running-the-training-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Running the training</a></span><ul class="toc-item"><li><span><a href="#Loading-data-for-training" data-toc-modified-id="Loading-data-for-training-4.1"><span class="toc-item-num">4.1&nbsp;&nbsp;</span>Loading data for training</a></span></li><li><span><a href="#Configuring-the-Tuner-with-a-Scheduler-and-a-Search-Algorithm" data-toc-modified-id="Configuring-the-Tuner-with-a-Scheduler-and-a-Search-Algorithm-4.2"><span class="toc-item-num">4.2&nbsp;&nbsp;</span>Configuring the Tuner with a Scheduler and a Search Algorithm</a></span></li><li><span><a href="#Running-the-Tuner" data-toc-modified-id="Running-the-Tuner-4.3"><span class="toc-item-num">4.3&nbsp;&nbsp;</span>Running the Tuner</a></span></li></ul></li><li><span><a href="#Evaluating-the-best-Results" data-toc-modified-id="Evaluating-the-best-Results-5"><span class="toc-item-num">5&nbsp;&nbsp;</span>Evaluating the best Results</a></span><ul class="toc-item"><li><span><a href="#Gathering-activation-maps-&amp;-predictions-from-the-Network" data-toc-modified-id="Gathering-activation-maps-&amp;-predictions-from-the-Network-5.1"><span class="toc-item-num">5.1&nbsp;&nbsp;</span>Gathering activation maps &amp; predictions from the Network</a></span></li></ul></li></ul></div>

# Setting up imports

In [None]:
import os
from itertools import product

import torch
from torch.nn import CrossEntropyLoss, Sequential
from torch.nn.functional import normalize
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torchvision.transforms import CenterCrop, Resize, RandomCrop, GaussianBlur
from torchvision.utils import save_image

import ray
from ray import tune
from ray.air import session, RunConfig, CheckpointConfig
from ray.air.checkpoint import Checkpoint
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.optuna import OptunaSearch
from ray.tune.search import ConcurrencyLimiter


from dataset import POCDataReader, POCvsCS9DataReader, data_augment_, POCDataset
from metrics import Metrics, EvaluationMetrics
from models import UNet, DeepCrack, SubUNet, DenSubUNet
from loss import *
from pipelines import InputPipeline
from pipelines.filters import CrackBinaryFilter, BGBinaryFilter, SequenceFilters, SumFilters
# from pipelines.filters.small_kernel import FrangiFilter, LaplacianFilter, SatoFilter, SobelFilter
from pipelines.filters.medium_kernel import FrangiFilter, LaplacianFilter, SatoFilter, SobelFilter
# from pipelines.filters.large_kernel import FrangiFilter, LaplacianFilter, SatoFilter, SobelFilter
from train import training_loop, validation_loop
from train_tqdm import evaluation_loop


# Setting up Constant Hyperparameters

In [None]:
# --- Training length / Ray Tune sampling ---
EPOCHS: int = 40          # epochs run by every trial
NUM_SAMPLES: int = 1      # forwarded to TuneConfig(num_samples=...)

# --- Data augmentation ---
NUM_AUGMENT: int = 1      # forwarded to data_augment_(n=...) for the training split

# --- Hardware / per-trial resources ---
LOAD_DATA_ON_GPU: bool = True   # when True, DataLoaders run without worker processes or pinned memory
GPUS_PER_TRIAL: int = 1
CPUS_PER_TRIAL: int = 20

##### Selecting the CUDA device

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

# Setting up Parameters and Functions for Training

## Hyperparameters Search Space

##### Preload Loss Functions
Build a list of loss functions per type (pixel / volume) for grid or random search.

In [None]:
# Pixel-wise losses. Both weight class 0 by .65 and class 1 by .35.
# NOTE(review): class 1 appears to be the crack class (its channel is used for
# the heatmap and label images later on), so these weights down-weight cracks
# relative to background -- confirm this is intended.
pixel_loss_list = [
    CrossEntropyLoss(weight=torch.tensor([.65, .35])),
    FocalLoss(gamma=1.4, weight=torch.tensor([.65, .35])),
]

# Region / overlap ("volume") losses.
volume_loss_list = [
    JaccardLoss(),
    TverskyLoss(beta=0.7, alpha=0.3),
    FocalTverskyLoss(gamma=1.4, beta=0.7, alpha=0.3),
]

##### Preload Pipeline
Build a list of all possible filters to apply in the input pipeline.

In [None]:
# Global filters available for the input pipeline.
# NOTE(review): in the visible code `search_space` sets "Pipe Filter" to
# `normalize` directly and does not use this list.
filter_list = [normalize]

# Candidates for the pipeline's additional input channel ("Pipe Layer");
# None presumably means "no extra channel".
layer_list = [
    None,
    SobelFilter(),
    LaplacianFilter(threshold=0.75),
    FrangiFilter(),
    SatoFilter(),
    SumFilters(FrangiFilter(), SatoFilter()),  # combined Frangi + Sato response
]

##### Search Space
A dict of every hyperparameter passed to the training function: the ones we want to explore (wrapped in `tune.grid_search`) and the constant ones as plain values.

In [None]:
# Full grid: 3 networks x 3 loss combiners x 5 ratios x 3 volume losses
# x 2 pixel losses x 6 pipeline layers = 1620 trials (times NUM_SAMPLES).
# NOTE(review): if PixelLoss / VolumeLoss ignore `ratio` (and the loss type they
# do not use), many of these trials are duplicates -- confirm and prune.
# Constant entries are plain values so that every trial still finds them in `config`.
search_space = {
    # Model and optimisation
    "Network": tune.grid_search([UNet, DeepCrack, SubUNet]),
    "Optimizer": Adam,
    "Learning Rate": 1e-4,  # alternative: tune.loguniform(1e-6, 1e-3)
    "Batch Size": 16,

    # Loss: a combiner mixing one pixel-wise and one volume-wise loss
    "Loss Combiner": tune.grid_search([MeanLoss, PixelLoss, VolumeLoss]),
    "Loss Combiner_ratio": tune.grid_search([0, .25, .5, .75, 1]),
    "Loss Volume": tune.grid_search(volume_loss_list),
    "Loss Pixel": tune.grid_search(pixel_loss_list),

    # Input pipeline
    "Pipe Filter": normalize,
    "Pipe Layer": tune.grid_search(layer_list),
}

## Creating the training function

In [None]:
def _build_poc_dataset(data, inpip):
    """Wrap one data split in a ``POCDataset`` and precompute its transform.

    Used by ``train`` so the training and validation splits are built with
    exactly the same options.

    Args:
        data: one split as returned by the data reader's ``split()``.
        inpip: the ``InputPipeline`` used as the dataset transform.

    Returns:
        The ``POCDataset`` with ``precompute_transform()`` already called.
    """
    dataset = POCDataset(
        data=data,
        transform=inpip,
        target_transform=None,
        negative_mining=False,
        load_on_gpu=LOAD_DATA_ON_GPU)
    dataset.precompute_transform()
    return dataset


def _make_dataloader(dataset, batch_size, device, sampler=None, shuffle=True):
    """Build a ``DataLoader``; add worker / pinned-memory options only for CPU-resident data.

    With ``LOAD_DATA_ON_GPU`` set, the loader runs in the main process without
    pinned memory. Otherwise half of the trial's CPUs are used as loader workers
    and batches are pinned for ``device``.

    Args:
        dataset: dataset to iterate.
        batch_size: samples per batch (cast to int).
        device: target device for pinned memory (CPU-resident data only).
        sampler: optional sampler; DataLoader rejects it together with ``shuffle=True``.
        shuffle: whether to reshuffle the data every epoch.
    """
    loader_kwargs = {"batch_size": int(batch_size), "sampler": sampler, "shuffle": shuffle}
    if not LOAD_DATA_ON_GPU:
        loader_kwargs.update(
            num_workers=CPUS_PER_TRIAL // 2,
            pin_memory=True,
            pin_memory_device=device)
    return DataLoader(dataset, **loader_kwargs)


def train(config, train_data, val_data):
    """Ray Tune trainable: train one hyperparameter configuration and report every epoch.

    Builds the input pipeline, datasets, network, loss, optimizer and LR scheduler
    from ``config`` (keys as in ``search_space``) and resumes from a Ray checkpoint
    when one is provided. Then, for each epoch, it runs ``training_loop`` and
    ``validation_loop``, saves a checkpoint and reports the validation metrics
    with ``session.report``.

    Args:
        config: one point of the search space, supplied by Ray Tune.
        train_data: training split, bound with ``tune.with_parameters``.
        val_data: validation split, bound with ``tune.with_parameters``.

    Checkpoint layout (directory ``model/``, relative to the current working directory):
        ``checkpoint.pt`` -- tuple ``(model_state, optimizer_state, scheduler_state)``;
        the model state is taken from the network without any DataParallel wrapper.
        ``epoch.txt`` -- last completed epoch, so a resumed trial continues its count.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    inpip = InputPipeline(
        filter=config["Pipe Filter"],
        additional_channel=config["Pipe Layer"])
    if LOAD_DATA_ON_GPU:
        inpip = inpip.to(device)

    train_dataset = _build_poc_dataset(train_data, inpip)
    # A dataset-provided sampler replaces shuffling (DataLoader forbids both).
    training_dataloader = _make_dataloader(
        train_dataset,
        batch_size=config["Batch Size"],
        device=device,
        sampler=train_dataset.sampler,
        shuffle=train_dataset.sampler is None)

    val_dataset = _build_poc_dataset(val_data, inpip)
    validation_dataloader = _make_dataloader(
        val_dataset,
        batch_size=config["Batch Size"],
        device=device,
        shuffle=True)

    model = config["Network"](n_channels=inpip.nb_channel, n_classes=2)
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model.to(device)
    # Save / restore through the unwrapped network so checkpoint keys carry no
    # DataParallel "module." prefix and load into a plain network as well.
    core_model = model.module if isinstance(model, torch.nn.DataParallel) else model

    loss_fn = MultiscaleLoss(config["Loss Combiner"](
        config["Loss Pixel"],
        config["Loss Volume"],
        ratio=config["Loss Combiner_ratio"])).to(device)

    optimizer = config["Optimizer"](model.parameters(), lr=config["Learning Rate"], betas=(0.9, 0.99))
    lr_scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS // 3)

    start_epoch = 1
    loaded_checkpoint = session.get_checkpoint()
    if loaded_checkpoint:
        with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:
            # map_location: the checkpoint may have been written on another device.
            model_state, optimizer_state, scheduler_state = torch.load(
                os.path.join(loaded_checkpoint_dir, "checkpoint.pt"), map_location=device)
            epoch_file = os.path.join(loaded_checkpoint_dir, "epoch.txt")
            # Checkpoints without an epoch file restart the count at 1.
            if os.path.exists(epoch_file):
                with open(epoch_file) as f:
                    start_epoch = int(f.read()) + 1
        core_model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)
        lr_scheduler.load_state_dict(scheduler_state)

    train_metrics = Metrics(
        buffer_size=len(training_dataloader),
        mode="Training",
        hyperparam=config,
        device=device)

    val_metrics = Metrics(
        buffer_size=len(validation_dataloader),
        mode="Validation",
        hyperparam=config,
        device=device)

    try:
        for epoch in range(start_epoch, EPOCHS + 1):  # loop over the dataset multiple times
            training_loop(epoch, training_dataloader, model, loss_fn, optimizer, lr_scheduler, train_metrics, device)
            validation_loop(epoch, validation_dataloader, model, loss_fn, val_metrics, device)

            # The directory registered here is what `session.get_checkpoint()`
            # returns if the trial is resumed later.
            os.makedirs("model", exist_ok=True)
            torch.save(
                (core_model.state_dict(), optimizer.state_dict(), lr_scheduler.state_dict()),
                os.path.join("model", "checkpoint.pt"))
            with open(os.path.join("model", "epoch.txt"), "w") as f:
                f.write(str(epoch))
            checkpoint = Checkpoint.from_directory("model")
            session.report(metrics=val_metrics.get_metrics(epoch), checkpoint=checkpoint)
    finally:
        # Release the TensorBoard writers even if an epoch raises.
        train_metrics.close_tensorboard()
        val_metrics.close_tensorboard()


## Creating the evaluation function

In [None]:
def evaluate(test_data, result):
    """Evaluate the best checkpoint of one Ray Tune trial on the test split.

    Rebuilds the trial's input pipeline and network from ``result.config``, loads
    the weights stored in the first entry of ``result.best_checkpoints`` and runs
    ``evaluation_loop``, which feeds an ``EvaluationMetrics`` instance.

    Args:
        test_data: test split as returned by the data reader's ``split()``.
        result: a Ray Tune ``Result`` for one trial.

    Returns:
        None. Returns early when the trial has no checkpoint (presumably a trial
        that stopped before its first report).
    """
    if not result.best_checkpoints:
        return None

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    inpip = InputPipeline(
        filter=result.config["Pipe Filter"],
        additional_channel=result.config["Pipe Layer"])
    if LOAD_DATA_ON_GPU:
        inpip = inpip.to(device)

    test_dataset = POCDataset(
        test_data,
        transform=inpip,
        target_transform=None,
        negative_mining=False,
        load_on_gpu=LOAD_DATA_ON_GPU)

    if LOAD_DATA_ON_GPU:
        evaluation_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)
    else:
        evaluation_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=2*CPUS_PER_TRIAL, pin_memory=True, pin_memory_device=device)

    # Each best_checkpoints entry pairs a checkpoint with the metrics reported with it.
    best_checkpoint, best_checkpoint_metrics = result.best_checkpoints[0]

    best_trained_model = result.config["Network"](n_channels=inpip.nb_channel, n_classes=2).to(device)

    # as_directory() reuses the checkpoint folder instead of copying it into a new
    # temporary directory on every call (as to_directory() does); map_location lets
    # GPU-saved weights load on a CPU-only host.
    with best_checkpoint.as_directory() as checkpoint_dir:
        model_state, _, _ = torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"), map_location=device)
    best_trained_model.load_state_dict(model_state)
    # Put dropout / batch-norm layers (if any) in inference mode before testing.
    best_trained_model.eval()

    test_metrics = EvaluationMetrics(
        buffer_size=len(evaluation_dataloader),
        hyperparam=result.config,
        epochs=best_checkpoint_metrics["Epoch"],  # epoch at which the best checkpoint was saved
        device=device)

    evaluation_loop(dataloader=evaluation_dataloader, model=best_trained_model, metric=test_metrics, device=device)


# Running the training

## Loading data for training

##### Loading POC2 dataset for analysis

In [None]:
# Read the POC dataset (kept on the CPU at this stage) and split it three ways.
data_reader = POCDataReader(
    root_dir="../data/POC",
    load_on_gpu=False,
    verbose=True,
)
train_data, val_data, test_data = data_reader.split()

# Augment the training split only, with a fixed seed.
train_data = data_augment_(
    train_data,
    n=NUM_AUGMENT,
    load_on_gpu=False,
    verbose=True,
    seed=1234,
)

##### Or loading the POC2 or CS9 dataset for a training comparison

In [None]:
# data_reader = POCvsCS9DataReader(root_dir="../data/POCvsCS9", dataset="cs9", load_on_gpu=False, verbose=True)
# train_data, val_data, test_data = data_reader.split()

## Configuring the Tuner with a Scheduler and a Search Algorithm

Uses the Ray Tune library.

##### Create a new Tune experiment

In [None]:
# Optional Ray Tune plug-ins, currently disabled:
#   scheduler = ASHAScheduler(max_t=EPOCHS, grace_period=2, reduction_factor=2)
#   search_algo = OptunaSearch()   # or HyperOptSearch() (not imported)
# To enable them, pass `scheduler=` / `search_alg=` to TuneConfig below.

# Maximise the validation CrackIoU reported by `train`; run at most 4 trials at once.
tune_config = tune.TuneConfig(
    metric="CrackIoU",
    mode="max",
    num_samples=NUM_SAMPLES,
    max_concurrent_trials=4,
)

# Bind the data splits to `train` and reserve CPUs / GPUs for each trial.
tune_trainable = tune.with_resources(
    trainable=tune.with_parameters(
        train,
        train_data=train_data,
        val_data=val_data,
    ),
    resources={"cpu": CPUS_PER_TRIAL, "gpu": GPUS_PER_TRIAL},
)

# Keep only each trial's best checkpoint (ranked by CrackIoU); no extra checkpoint at the end.
tuner = tune.Tuner(
    trainable=tune_trainable,
    tune_config=tune_config,
    param_space=search_space,
    run_config=RunConfig(
        local_dir="~/POC-Project/ray_results/",
        checkpoint_config=CheckpointConfig(
            num_to_keep=1,
            checkpoint_score_attribute="CrackIoU",
            checkpoint_score_order="max",
            checkpoint_at_end=False,
        ),
    ),
)

##### Or load a previous one from disk (with its results)

In [None]:
# tuner = tune.Tuner.restore(
#     path="/home/pirl/POC-Project/ray_results/train_2023-06-30_18-36-49_small_kernel/",
#     trainable=None, #tune_trainable,
#     resume_unfinished=False,
#     resume_errored=False,
#     restart_errored=False,
# )

# results = tuner.get_results()

## Running the Tuner

In [None]:
# Launch every trial of the search space; the returned ResultGrid is iterated
# below, one Result per trial.
results = tuner.fit()

# Evaluating the best Results

In [None]:
# Best trial across every reported epoch (scope="all"), ranked by validation CrackIoU.
best_result = results.get_best_result(metric="CrackIoU", mode="max", scope="all")
print(f"Best trial config: {best_result.config}")
print(f"Best trial final validation loss: {best_result.metrics['Loss']}")
print(f"Best trial final validation CrackIoU: {best_result.metrics['CrackIoU']}")

# Test-set evaluation of every trial's best checkpoint (evaluate skips trials without one).
for result in results:
    evaluate(test_data=test_data, result=result)

## Gathering activation maps & predictions from the Network

In [None]:
from my_utils import show_img
    
def print_activation_map(result, test_data):
    """Display the first test image and the output of the encoder's first block.

    Rebuilds the trial's input pipeline and network from ``result.config``, loads
    the weights of ``result.best_checkpoints[0]`` and shows, with ``show_img``,
    the first (unshuffled) test image and then element 0 of
    ``encoder.block1``'s output summed over dim 1 (presumably the channels).

    Args:
        result: a Ray Tune ``Result``; assumes it has at least one checkpoint and
            that the network exposes ``encoder.block1``.
        test_data: test split as returned by the data reader's ``split()``.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    inpip = InputPipeline(
        filter=result.config["Pipe Filter"],
        additional_channel=result.config["Pipe Layer"])
    if LOAD_DATA_ON_GPU:
        inpip = inpip.to(device)

    test_dataset = POCDataset(
        test_data,
        transform=inpip,
        target_transform=None,
        negative_mining=False,
        load_on_gpu=LOAD_DATA_ON_GPU)

    if LOAD_DATA_ON_GPU:
        evaluation_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    else:
        evaluation_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2*CPUS_PER_TRIAL, pin_memory=True, pin_memory_device=device)

    # Batches stay on the CPU when LOAD_DATA_ON_GPU is False; move them next to
    # the model so the forward pass does not mix devices.
    images = next(iter(evaluation_dataloader))[0].to(device)

    best_trained_model = result.config["Network"](n_channels=inpip.nb_channel, n_classes=2).to(device)

    # as_directory() avoids a fresh temporary copy of the checkpoint on each call;
    # map_location lets GPU-saved weights load on a CPU-only host.
    with result.best_checkpoints[0][0].as_directory() as checkpoint_dir:
        model_state, _, _ = torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"), map_location=device)
    best_trained_model.load_state_dict(model_state)

    best_trained_model.eval()
    first_block = best_trained_model.encoder.block1

    with torch.inference_mode():
        activation_map = first_block(images)[0].sum(dim=1, keepdim=True)

    show_img(images)
    show_img(activation_map)

In [None]:
from my_utils import show_img
    
def print_prediction_proba(result, test_data):
    """Display the network's predictions for the first two test images.

    Rebuilds the trial's input pipeline and network from ``result.config``, loads
    the weights of ``result.best_checkpoints[0]`` and shows, with ``show_img``:
    a heatmap (input divided by its max, plus 0.1 x the class-1 output added to
    channel 1), the class-1 output, and the per-pixel argmax over classes.

    Args:
        result: a Ray Tune ``Result``; assumes it has at least one checkpoint.
        test_data: test split as returned by the data reader's ``split()``.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    inpip = InputPipeline(
        filter=result.config["Pipe Filter"],
        additional_channel=result.config["Pipe Layer"])
    if LOAD_DATA_ON_GPU:
        inpip = inpip.to(device)

    test_dataset = POCDataset(
        test_data,
        transform=inpip,
        target_transform=None,
        negative_mining=False,
        load_on_gpu=LOAD_DATA_ON_GPU)

    if LOAD_DATA_ON_GPU:
        evaluation_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=False)
    else:
        evaluation_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=False, num_workers=2*CPUS_PER_TRIAL, pin_memory=True, pin_memory_device=device)

    # Batches stay on the CPU when LOAD_DATA_ON_GPU is False; move them next to
    # the model so the forward pass does not mix devices.
    images = next(iter(evaluation_dataloader))[0].to(device)

    best_trained_model = result.config["Network"](n_channels=inpip.nb_channel, n_classes=2).to(device)

    # as_directory() avoids a fresh temporary copy of the checkpoint on each call;
    # map_location lets GPU-saved weights load on a CPU-only host.
    with result.best_checkpoints[0][0].as_directory() as checkpoint_dir:
        model_state, _, _ = torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"), map_location=device)
    best_trained_model.load_state_dict(model_state)

    best_trained_model.eval()
    with torch.inference_mode():
        # Named "proba", but this is the raw network output (presumably per-class
        # probabilities -- depends on the network's last layer).
        prediction_proba = best_trained_model(images)

        heatmap = images.clone().detach()
        heatmap /= heatmap.max()
        heatmap[:,1] += .1 * prediction_proba[:,1]

    show_img(heatmap)
    show_img(prediction_proba[:,1:])
    show_img(prediction_proba.argmax(dim=1, keepdim=True))

In [None]:

dataset = 'cs9'


def _to_unit_range(tensor):
    """Rescale ``tensor`` in place so its minimum becomes 0 and its maximum 1; return it.

    A constant tensor is left as all zeros instead of being divided by zero,
    which would fill it with NaN.
    """
    tensor -= tensor.min()
    peak = tensor.max()
    if peak > 0:
        tensor /= peak
    return tensor


def get_visualisation_files(result, test_data):
    """Write per-sample visualisation images for one trial's best checkpoint.

    For every test sample (batch size 1, unshuffled) this writes under
    ``../imgs_net/<file name>/``:

    * ``img.png`` -- first 3 input channels rescaled to [0, 1] (written once);
    * ``label.png`` -- label channels from index 1 on, expanded to 3 channels
      (assumes a 2-channel label; written once);
    * ``activation_maps/<dataset>.png`` -- element 0 of ``encoder.block1``'s
      output summed over dim 1, rescaled to [0, 1];
    * ``heatmap/<dataset>.png`` -- first 3 input channels divided by their max,
      plus 0.5 x the class-1 output added to channel 1;
    * ``mask_bin/<dataset>.png`` -- per-pixel argmax over classes.

    ``<dataset>`` is the module-level ``dataset`` name, so runs on different
    datasets write their images next to each other.

    Args:
        result: a Ray Tune ``Result``; assumes it has at least one checkpoint and
            that the network exposes ``encoder.block1``.
        test_data: test split as returned by the data reader's ``split()``.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    inpip = InputPipeline(
        filter=result.config["Pipe Filter"],
        additional_channel=result.config["Pipe Layer"])
    if LOAD_DATA_ON_GPU:
        inpip = inpip.to(device)

    test_dataset = POCDataset(
        test_data,
        transform=inpip,
        target_transform=None,
        negative_mining=False,
        load_on_gpu=LOAD_DATA_ON_GPU)

    if LOAD_DATA_ON_GPU:
        evaluation_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    else:
        evaluation_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2*CPUS_PER_TRIAL, pin_memory=True, pin_memory_device=device)

    best_trained_model = result.config["Network"](n_channels=inpip.nb_channel, n_classes=2).to(device)

    # as_directory() avoids a fresh temporary copy of the checkpoint on each call;
    # map_location lets GPU-saved weights load on a CPU-only host.
    with result.best_checkpoints[0][0].as_directory() as checkpoint_dir:
        model_state, _, _ = torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"), map_location=device)
    best_trained_model.load_state_dict(model_state)

    best_trained_model.eval()
    first_block = best_trained_model.encoder.block1

    for item in evaluation_dataloader:
        # Batches stay on the CPU when LOAD_DATA_ON_GPU is False; move the input
        # next to the model.
        image = item[0].to(device)
        label = item[1]
        fname = item[2][0]
        fpath = f"../imgs_net/{fname}"

        # Create each sub-folder independently: an earlier interrupted run may
        # have left `fpath` without all of them.
        for sub_dir in ("activation_maps", "heatmap", "mask_bin"):
            os.makedirs(os.path.join(fpath, sub_dir), exist_ok=True)

        if not os.path.exists(fpath + "/img.png"):
            img = _to_unit_range(image[0, :3].clone())
            save_image(img, f"{fpath}/img.png")
        if not os.path.exists(fpath + "/label.png"):
            save_image(label[0, 1:].expand(3, -1, -1), f"{fpath}/label.png")

        with torch.inference_mode():
            activation_map = first_block(image)[0].sum(dim=1, keepdim=True)[0].expand(3, -1, -1).clone()
            save_image(_to_unit_range(activation_map), f"{fpath}/activation_maps/{dataset}.png")

            prediction_proba = best_trained_model(image)[0]

            heatmap = image.clone().detach()[0, :3]
            heatmap /= heatmap.max()
            heatmap[1] += .5 * prediction_proba[1]
            save_image(heatmap, f"{fpath}/heatmap/{dataset}.png")

            mask_bin = prediction_proba.argmax(dim=0, keepdim=True).expand(3, -1, -1).float()
            save_image(mask_bin, f"{fpath}/mask_bin/{dataset}.png")

In [None]:
# Export image / label / activation-map / heatmap / mask files for every trial.
# For inline display instead, call print_activation_map(result, test_data) or
# print_prediction_proba(result, test_data) inside the loop.
for result in results:
    get_visualisation_files(result, test_data)