In [1]:
import sys

# Make the repository root importable (relative to the notebook's working directory).
# The membership guard keeps this cell idempotent: re-running it does not append duplicates.
REPO_ROOT = '../../'
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

In [None]:
pip install ../../artefacts/bilevel_optimisation-0.1.0-py3-none-any.whl

In [2]:
import os
from functools import partial

import torch
from torch.utils.data import DataLoader

In [20]:
from bilevel_optimisation.bilevel.Bilevel import Bilevel
from bilevel_optimisation.data.SolverSpec import SolverSpec
from bilevel_optimisation.dataset.ImageDataset import TestImageDataset, TrainingImageDataset
from bilevel_optimisation.evaluation.Evaluation import evaluate_on_test_data
from bilevel_optimisation.factories.BuildFactory import build_solver_factory
from bilevel_optimisation.utils.DatasetUtils import collate_function
from bilevel_optimisation.utils.SetupUtils import set_up_regulariser, load_optimiser_class, set_up_measurement_model
from bilevel_optimisation.utils.SetupUtils import set_up_inner_energy, set_up_outer_loss, set_up_bilevel_problem
from bilevel_optimisation.utils.ConfigUtils import load_configs
from bilevel_optimisation.utils.FileSystemUtils import save_foe_model
from bilevel_optimisation.utils.SeedingUtils import seed_random_number_generators
from bilevel_optimisation import solver
from bilevel_optimisation.visualisation.Visualisation import visualise_training_stats, visualise_filter_stats

In [4]:
# Seed the random number generators before any data or model objects are created below,
# so that runs of this notebook are reproducible.
seed_random_number_generators(123)

# Work in single precision, on the GPU when CUDA is available and on the CPU otherwise.
dtype = torch.float32
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

### Set up training dataset

In [5]:
# Root directory of the BSDS300 training images. The default is a machine-specific path; set the
# BSDS300_TRAIN_DIR environment variable instead of editing it.
train_data_root_dir = os.environ.get('BSDS300_TRAIN_DIR',
                                     '/home/florianthaler/Documents/data/image_data/BSDS300/images/train')
train_image_dataset = TrainingImageDataset(root_path=train_data_root_dir, dtype=dtype)
batch_size = 32
crop_size = 64
# ``partial`` binds crop_size now. A lambda would look up the global ``crop_size`` at every batch
# (so reassigning it in another cell would silently change training), and a lambda cannot be
# pickled, which breaks ``num_workers > 0`` under the 'spawn' multiprocessing start method.
train_loader = DataLoader(train_image_dataset, batch_size=batch_size,
                          collate_fn=partial(collate_function, crop_size=crop_size))

### Set up test dataset

In [6]:
test_data_root_dir = '../../data/test_images'
test_image_dataset = TestImageDataset(root_path=test_data_root_dir, dtype=dtype)
# The whole test set is loaded as one batch; an empty dataset would give batch_size=0, which
# DataLoader rejects with an unrelated-looking error, so fail early with a clear message.
if len(test_image_dataset) == 0:
    raise FileNotFoundError('no test images found in {:s}'.format(test_data_root_dir))
# Single, unshuffled batch. crop_size=-1 presumably disables cropping (full images) — confirm in
# collate_function. ``partial`` (unlike a lambda) is picklable for multi-worker loading.
test_loader = DataLoader(test_image_dataset, batch_size=len(test_image_dataset), shuffle=False,
                         collate_fn=partial(collate_function, crop_size=-1))

### Load configuration and set up regulariser

In [7]:
# Default configuration, overridden by the custom configuration of this training example.
default_config_dir_path = '../../data/configs/default'
custom_config_dir_path = '../../data/configs/custom/example_training_I'

config = load_configs(
    '[DENOISING] train',
    default_config_dir_path=default_config_dir_path,
    custom_config_dir_path=custom_config_dir_path,
    configuring_module='train',
)

In [8]:
# Build the regulariser from the loaded configuration. Its parameters are the ones passed to the
# bilevel problem (set_up_bilevel_problem) in the next section.
regulariser = set_up_regulariser(config)

### Set up bilevel problem

In [21]:
# Hand the regulariser's parameters to the bilevel problem (presumably these are what its outer
# optimiser updates during training — confirm in set_up_bilevel_problem).
trainable_parameters = regulariser.parameters()
bilevel = set_up_bilevel_problem(trainable_parameters, config)

### Initial test loss and PSNR

In [10]:
# Evaluate the untrained regulariser once, so that the test curves start at the initial state.
# Iteration index -1 marks the pre-training evaluation; path_to_data_dir=None presumably means
# that nothing is written to disk.
pre_training_iteration = -1
psnr, test_loss = evaluate_on_test_data(test_loader, regulariser, config, device, dtype,
                                        pre_training_iteration, path_to_data_dir=None)

psnr_list, test_loss_list = [psnr], [test_loss]

### Training loop

In [23]:
# Histories filled by the training loop: one entry per training iteration.
train_loss_list = []
filters_list = []
filter_weights_list = []

# Evaluate on the test set every ``evaluation_freq`` iterations; stop after ``max_num_iterations``.
evaluation_freq = 2
max_num_iterations = 20
# The training loop computes ``(k + 1) % evaluation_freq``, which fails for 0.
assert evaluation_freq > 0, 'evaluation_freq must be a positive integer'
assert max_num_iterations > 0, 'max_num_iterations must be a positive integer'

# Output directory for test evaluations, saved models and visualisations. The default is a
# machine-specific path; set the FOE_EVAL_DIR environment variable instead of editing it.
path_to_eval_dir = os.environ.get('FOE_EVAL_DIR',
                                  '/home/florianthaler/Documents/data/evaluation/foe_bilevel_denoising/eval')

In [24]:
def iterate_batches(loader, num_iterations):
    """Yield ``(k, batch)`` for ``k = 0, ..., num_iterations - 1``, re-iterating ``loader`` as needed.

    One pass over ``loader`` may hold fewer than ``num_iterations`` batches; iteration then restarts
    from the beginning of the loader (a new epoch) instead of silently ending training early.

    Raises:
        ValueError: if a full pass over ``loader`` yields no batches (would otherwise loop forever).
    """
    k = 0
    while k < num_iterations:
        epoch_is_empty = True
        for batch in loader:
            epoch_is_empty = False
            yield k, batch
            k += 1
            if k == num_iterations:
                return
        if epoch_is_empty:
            raise ValueError('training data loader yielded no batches')


# One bilevel training step per batch, with periodic evaluation on the test set.
for k, batch in iterate_batches(train_loader, max_num_iterations):

    batch_ = batch.to(device=device, dtype=dtype)
    # NOTE(review): the whole step runs under no_grad and the returned loss is never
    # back-propagated here, so the parameter update presumably happens inside ``bilevel.forward``
    # (enabling gradients internally where needed) — confirm.
    with torch.no_grad():
        measurement_model = set_up_measurement_model(batch_, config)
        inner_energy = set_up_inner_energy(measurement_model, regulariser, config)
        inner_energy = inner_energy.to(device=device, dtype=dtype)

        outer_loss = set_up_outer_loss(batch_, config)
        train_loss = bilevel.forward(outer_loss, inner_energy)

        # Record the loss and the current filters / filter weights for the visualisations.
        # NOTE(review): if get_filters()/get_filter_weights() return views of the live parameters,
        # every entry will alias the final values — confirm they return copies.
        train_loss_value = train_loss.detach().cpu().item()
        train_loss_list.append(train_loss_value)
        filters_list.append(regulariser.get_filters())
        filter_weights_list.append(regulariser.get_filter_weights())
        print('[TRAIN] iteration [{:d} / {:d}]: '
              'loss = {:.5f}'.format(k + 1, max_num_iterations, train_loss_value))

        if (k + 1) % evaluation_freq == 0:
            print('[TRAIN] evaluate on test dataset')

            psnr, test_loss = evaluate_on_test_data(test_loader, regulariser, config, device,
                                                    dtype, k, path_to_eval_dir)
            print('[TRAIN] denoised test images')
            print('[TRAIN]   > average psnr: {:.5f}'.format(psnr))
            print('[TRAIN]   > test loss: {:.5f}'.format(test_loss))

            psnr_list.append(psnr)
            test_loss_list.append(test_loss)

print('[TRAIN] reached maximal number of iterations')

[TRAIN] iteration [1 / 20]: loss = 103.95058
[TRAIN] iteration [2 / 20]: loss = 235.45084
[TRAIN] evaluate on test dataset
[TRAIN] denoised test images
[TRAIN]   > average psnr: 26.38792
[TRAIN]   > test loss: 355.11237
[TRAIN] iteration [3 / 20]: loss = 167.14388
[TRAIN] iteration [4 / 20]: loss = 149.15738
[TRAIN] evaluate on test dataset
[TRAIN] denoised test images
[TRAIN]   > average psnr: 27.67525
[TRAIN]   > test loss: 263.80737
[TRAIN] iteration [5 / 20]: loss = 136.16898
[TRAIN] iteration [6 / 20]: loss = 134.09842
[TRAIN] evaluate on test dataset
[TRAIN] denoised test images
[TRAIN]   > average psnr: 28.21934
[TRAIN]   > test loss: 232.66138
[TRAIN] iteration [7 / 20]: loss = 131.24289
[TRAIN] iteration [8 / 20]: loss = 106.87519
[TRAIN] evaluate on test dataset
[TRAIN] denoised test images
[TRAIN]   > average psnr: 28.38552
[TRAIN]   > test loss: 223.92314
[TRAIN] iteration [9 / 20]: loss = 93.58693
[TRAIN] iteration [10 / 20]: loss = 149.70557
[TRAIN] evaluate on test datas

### Store trained model and visualisations of training results

In [25]:
# Persist the trained regulariser in the 'models' subdirectory of the evaluation directory,
# under the model directory name 'final'.
path_to_model_dir = os.path.join(path_to_eval_dir, 'models')
save_foe_model(regulariser, path_to_model_dir, model_dir_name='final')

# Produce visualisations of the recorded training/test statistics and of the filter history,
# written to the evaluation directory.
visualise_training_stats(train_loss_list, test_loss_list, psnr_list, evaluation_freq, path_to_eval_dir)
visualise_filter_stats(filters_list, filter_weights_list, path_to_eval_dir)