### Usage within repository

In [4]:
import sys
sys.path.append('../../../')

### Mount Google drive when using Google Colab

In [None]:
from google.colab import drive
drive.mount('/content/drive')

### Usage with Python wheel (e.g. when using Google Colab)

In [2]:
# pip install ../../../artefacts/bilevel_optimisation-1.0.0-py3-none-any.whl

Processing /home/florianthaler/Documents/research/bilevel_optimisation/artefacts/bilevel_optimisation-1.0.0-py3-none-any.whl
Installing collected packages: bilevel-optimisation
Successfully installed bilevel-optimisation-1.0.0

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
import os
import torch
from torch.utils.data import DataLoader

In [5]:
from bilevel_optimisation.bilevel.Bilevel import Bilevel
from bilevel_optimisation.data.SolverSpec import SolverSpec
from bilevel_optimisation.dataset.ImageDataset import TestImageDataset, TrainingImageDataset
from bilevel_optimisation.evaluation.Evaluation import evaluate_on_test_data
from bilevel_optimisation.factories.BuildFactory import build_solver_factory
from bilevel_optimisation.utils.DatasetUtils import collate_function
from bilevel_optimisation.utils.SetupUtils import set_up_regulariser, load_optimiser_class, set_up_measurement_model
from bilevel_optimisation.utils.SetupUtils import set_up_inner_energy, set_up_outer_loss, set_up_bilevel_problem
from bilevel_optimisation.utils.ConfigUtils import load_app_config
from bilevel_optimisation.utils.FileSystemUtils import save_foe_model
from bilevel_optimisation.utils.SeedingUtils import seed_random_number_generators
from bilevel_optimisation import solver
from bilevel_optimisation.visualisation.Visualisation import visualise_training_stats, visualise_filter_stats

In [6]:
seed_random_number_generators(123)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.float32

### Setup training dataset

In [7]:
train_data_root_dir = '/home/florianthaler/Documents/data/image_data/BSDS300/images/train'
train_image_dataset = TrainingImageDataset(root_path=train_data_root_dir, dtype=dtype)
batch_size = 32
crop_size = 64
train_loader = DataLoader(train_image_dataset, batch_size=batch_size, collate_fn=lambda x: collate_function(x, crop_size=crop_size))

### Setup test dataset

In [8]:
test_data_root_dir = '/home/florianthaler/Documents/data/image_data/some_images'
test_image_dataset = TestImageDataset(root_path=test_data_root_dir, dtype=dtype)
test_loader = DataLoader(test_image_dataset, batch_size=len(test_image_dataset), shuffle=False,
                         collate_fn=lambda x: collate_function(x, crop_size=-1))

### Setup regulariser

In [9]:
custom_config_dir_path = 'example_training_I'
    # custom_config_dir_path = /home/florianthaler/Documents/research/bilevel_optimisation/bilevel_optimisation/config_data/custom/training_config

app_name = 'bilevel_optimisation'
configuring_module = '[DENOISING] train'
config = load_app_config(app_name, custom_config_dir_path, configuring_module)

In [8]:
regulariser = set_up_regulariser(config)

### Setup bilevel instance

In [8]:
bilevel = set_up_bilevel_problem(regulariser.parameters(), config)

### Initial test loss and psnr

In [9]:
psnr, test_loss = evaluate_on_test_data(test_loader, regulariser, config, device, dtype, -1, path_to_data_dir=None)

test_loss_list = [test_loss]
psnr_list = [psnr]

### Training loop

In [10]:
train_loss_list = []
filters_list = []
filter_weights_list = []

evaluation_freq = 2
max_num_iterations = 4

path_to_eval_dir = '/home/florianthaler/Documents/data/evaluation/foe_bilevel_denoising/eval'

In [11]:
for k, batch in enumerate(train_loader):

    batch_ = batch.to(device=device, dtype=dtype)
    with torch.no_grad():
        measurement_model = set_up_measurement_model(batch_, config)
        inner_energy = set_up_inner_energy(measurement_model, regulariser, config)
        inner_energy = inner_energy.to(device=device, dtype=dtype)

        outer_loss = set_up_outer_loss(batch_, config)
        train_loss = bilevel.forward(outer_loss, inner_energy)

        train_loss_list.append(train_loss.detach().cpu().item())
        filters_list.append(regulariser.get_filters())
        filter_weights_list.append(regulariser.get_filter_weights())
        print('[TRAIN] iteration [{:d} / {:d}]: '
              'loss = {:.5f}'.format(k + 1, max_num_iterations, train_loss.detach().cpu().item()))

        if (k + 1) % evaluation_freq == 0:
            print('[TRAIN] evaluate on test dataset')

            psnr, test_loss = evaluate_on_test_data(test_loader, regulariser, config, device,
                                                    dtype, k, path_to_eval_dir)
            print('[TRAIN] denoised test images')
            print('[TRAIN]   > average psnr: {:.5f}'.format(psnr))
            print('[TRAIN]   > test loss: {:.5f}'.format(test_loss))

            psnr_list.append(psnr)
            test_loss_list.append(test_loss)

        if (k + 1) == max_num_iterations:
            print('[TRAIN] reached maximal number of iterations')
            break
        else:
            k += 1

[TRAIN] iteration [1 / 4]: loss = 239.41830
[TRAIN] iteration [2 / 4]: loss = 107.70458
[TRAIN] evaluate on test dataset
[TRAIN] denoised test images
[TRAIN]   > average psnr: 29.27701
[TRAIN]   > test loss: 182.36945
[TRAIN] iteration [3 / 4]: loss = 108.14030
[TRAIN] iteration [4 / 4]: loss = 97.03108
[TRAIN] evaluate on test dataset
[TRAIN] denoised test images
[TRAIN]   > average psnr: 29.31618
[TRAIN]   > test loss: 180.73187
[TRAIN] reached maximal number of iterations


### Store trained model and visualisations of training results

In [12]:
save_foe_model(regulariser, os.path.join(path_to_eval_dir, 'models'), model_dir_name='final')

visualise_training_stats(train_loss_list, test_loss_list, psnr_list, evaluation_freq, path_to_eval_dir)
visualise_filter_stats(filters_list, filter_weights_list, path_to_eval_dir)