In [1]:
!pip uninstall pylopt -y

Found existing installation: PyLOpt 1.0.0
Uninstalling PyLOpt-1.0.0:
  Successfully uninstalled PyLOpt-1.0.0


In [3]:
!pip install ../../artefacts/pylopt-1.0.0-py3-none-any.whl

Processing /home/florianthaler/Documents/research/pylopt/artefacts/pylopt-1.0.0-py3-none-any.whl
Installing collected packages: pylopt
Successfully installed pylopt-1.0.0

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [4]:
import os
import torch
from torch.utils.data import DataLoader
from pathlib import Path

In [5]:
# Seed torch's global random number generator so repeated runs of the notebook
# draw the same random numbers.
torch.manual_seed(123)

# Floating point precision handed to the datasets, the energy and the training call.
dtype = torch.float32
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

### Setup training dataset

In [6]:
from pylopt.dataset.ImageDataset import TestImageDataset, TrainingImageDataset

# Root directory of the training images. The location can be overridden through
# the PYLOPT_TRAIN_DATA_DIR environment variable so that the notebook also runs
# on machines other than the author's; when the variable is unset the original
# path is used unchanged.
train_data_root_dir = os.environ.get(
    'PYLOPT_TRAIN_DATA_DIR',
    '/home/florianthaler/Documents/data/image_data/BSDS300/images/train',
)
train_image_dataset = TrainingImageDataset(root_path=train_data_root_dir, dtype=dtype)

# NOTE(review): batch_size and crop_size are not consumed by any visible cell;
# presumably they are leftovers or are read further down — confirm.
batch_size = 128
crop_size = 64

### Setup test dataset

In [7]:
from pylopt.dataset.dataset_utils import collate_function

# Root directory of the test images. The location can be overridden through the
# PYLOPT_TEST_DATA_DIR environment variable; when the variable is unset the
# original path is used unchanged.
test_data_root_dir = os.environ.get(
    'PYLOPT_TEST_DATA_DIR',
    '/home/florianthaler/Documents/data/image_data/some_images',
)
test_image_dataset = TestImageDataset(root_path=test_data_root_dir, dtype=dtype)

# A single batch holds the entire test set (batch_size == len(dataset)), in a
# fixed order. The collate function is called with crop_size=-1; presumably
# this disables cropping — TODO confirm against collate_function.
test_loader = DataLoader(test_image_dataset, batch_size=len(test_image_dataset), shuffle=False,
                         collate_fn=lambda x: collate_function(x, crop_size=-1))

### Setup regulariser (filters: non-trainable, potentials: trainable)

In [8]:
# Lookup tables: short model name -> file name of the stored pretrained model.
# The filter table is used below to build the path handed to ImageFilter.from_file.
PRETRAINED_FILTER_MODELS = {
    'chen-ranftl-pock_2014_scaled_7x7': 'filters_7x7_chen-ranftl-pock_2014_scaled.pt',
    'pylopt_2025_7x7_II': 'filters_7x7_pylopt_2025_II.pt',
}
# The key previously ended in a stray ')' ('pylopt_2025_student_t_I)'), so a
# lookup by the intended model name raised KeyError.
PRETRAINED_POTENTIAL_MODELS = {
    'pylopt_2025_student_t_I': 'student_t_potential_2025_pylopt_I.pt',
}

In [26]:
from pylopt.utils.file_system_utils import create_experiment_dir, get_repo_root_path
from pylopt.regularisers.fields_of_experts.ImageFilter import ImageFilter
from pylopt.regularisers.fields_of_experts.potential import StudentT
from pylopt.regularisers.fields_of_experts.FieldsOfExperts import FieldsOfExperts

# Locate the pretrained filter file under <repo root>/data/model_data/.
repo_root_path = get_repo_root_path(Path().resolve())
pretrained_filter_file = PRETRAINED_FILTER_MODELS['chen-ranftl-pock_2014_scaled_7x7']
pretrained_filter_path = os.path.join(repo_root_path, 'data', 'model_data', pretrained_filter_file)

# Filters: loaded from file and frozen, i.e. their parameters are not trainable.
image_filter = ImageFilter.from_file(pretrained_filter_path)
image_filter.freeze()

# Potentials: freshly initialised and trainable.
potential_init_options = {'mode': 'uniform', 'multiplier': 0.0001}
potential = StudentT(num_marginals=48, init_options=potential_init_options, trainable=True)

# The regulariser combines the trainable potential with the frozen filters.
regulariser = FieldsOfExperts(potential, image_filter)

### Setup lower-level solution method

In [10]:
# Name of the solver used for the lower-level problem, together with its options.
# Both are passed on unchanged to BilevelOptimisation, solve_lower and
# TrainingMonitor in the cells below.
method_lower = 'nag'
options_lower = {
    'max_num_iterations': 300,
    'rel_tol': 1e-5,
    'lip_const': 1e5,
    'batch_optimisation': False,
}

### Setup BilevelOptimisation class

In [11]:
from pylopt.bilevel_problem import BilevelOptimisation

# Experiment artefacts go into ./experiment relative to the current working
# directory. exist_ok=True creates the directory if it is missing and is a
# no-op otherwise, which avoids the check-then-create race of a separate
# os.path.exists() test.
path_to_experiment_dir = os.path.join(os.getcwd(), 'experiment')
os.makedirs(path_to_experiment_dir, exist_ok=True)

# Bilevel problem with the identity as forward operator and noise level 0.1
# (the same operator and noise level are used for the evaluation cells below).
bilevel_optimisation = BilevelOptimisation(method_lower, 
                                           options_lower, 
                                           operator=torch.nn.Identity(),
                                           noise_level=0.1,
                                           differentiation_method='hessian_free',
                                           path_to_experiments_dir=path_to_experiment_dir)

In [12]:
# Scalar handed as `lam` to Energy, TrainingMonitor and BilevelOptimisation.learn
# in the cells below; presumably the regularisation weight — TODO confirm.
lam = 10

### Initial test loss and psnr

In [13]:
from pylopt.energy import Energy, MeasurementModel
from pylopt.lower_problem import solve_lower
from pylopt.utils.evaluation_utils import compute_psnr

# The test loader yields the whole test set as one batch, so fetching the first
# batch is enough; next(iter(...)) avoids materialising the loader into a list.
test_batch_clean = next(iter(test_loader))
test_batch_clean = test_batch_clean.to(device=device, dtype=dtype)

# Energy built from the measurement model (identity operator, noise level 0.1),
# the regulariser with its initial (untrained) potentials, and lam.
measurement_model = MeasurementModel(test_batch_clean, torch.nn.Identity(), noise_level=0.1)
energy = Energy(measurement_model, regulariser, lam)
energy.to(device=device, dtype=dtype)
test_batch_noisy = measurement_model.get_noisy_observation()

# Solve the lower-level problem once and report the mean PSNR between the clean
# data and the obtained solution.
lower_prob_result = solve_lower(energy=energy, method=method_lower, options=options_lower)
psnr_per_image = compute_psnr(energy.measurement_model.get_clean_data(), lower_prob_result.solution)
psnr = torch.mean(psnr_per_image).detach().cpu().item()

print('psnr [dB]: {:.5f}'.format(psnr))

psnr [dB]: 24.16104


### Callbacks and schedulers

In [14]:
def l2_loss_func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return half the sum of squared element-wise differences of ``x`` and ``y``.

    The sum runs over all elements, so the result is a scalar (0-dim) tensor.
    """
    residual = x - y
    sum_of_squares = torch.sum(residual ** 2)
    return 0.5 * sum_of_squares

In [16]:
from torch.utils.tensorboard import SummaryWriter
from pylopt.bilevel_problem.callbacks import SaveModel, PlotFiltersAndPotentials, TrainingMonitor
from pylopt.bilevel_problem.scheduler import CosineAnnealingLRScheduler

# One TensorBoard writer, shared by the plotting and monitoring callbacks;
# its event files go to <experiment dir>/tensorboard.
tensorboard_log_dir = os.path.join(path_to_experiment_dir, 'tensorboard')
tb_writer = SummaryWriter(log_dir=tensorboard_log_dir)

# Plots filters and potentials (plotting_freq=2).
plot_callback = PlotFiltersAndPotentials(
    test_image_dataset,
    path_to_data_dir=path_to_experiment_dir,
    plotting_freq=2,
    tb_writer=tb_writer,
)

# Stores the model in the experiment directory (save_freq=2).
save_callback = SaveModel(
    path_to_data_dir=path_to_experiment_dir,
    save_freq=2,
)

# Evaluates on the test dataset (evaluation_freq=1), using the same lower-level
# solver settings, operator, noise level and lam as the rest of the notebook.
monitor_callback = TrainingMonitor(
    test_image_dataset,
    method_lower,
    options_lower,
    l2_loss_func,
    path_to_experiment_dir,
    operator=torch.nn.Identity(),
    noise_level=0.1,
    lam=lam,
    evaluation_freq=1,
    tb_writer=tb_writer,
)

callbacks = [plot_callback, save_callback, monitor_callback]

# Cosine annealing of the learning rate between steps 50 and 80, without
# restarts, down to lr_min.
cosine_scheduler = CosineAnnealingLRScheduler(
    step_begin=50,
    restart_cycle=None,
    step_end=80,
    lr_min=1e-5,
)
schedulers = [cosine_scheduler]

### Setup logger

In [24]:
import logging
import sys

# logging.basicConfig does nothing when the root logger already has handlers,
# so detach all existing ones first (iterating over a copy, because the handler
# list is modified while looping).
root_logger = logging.getLogger()
for stale_handler in list(root_logger.handlers):
    root_logger.removeHandler(stale_handler)

# Send INFO-level (and above) records to stdout so they show up in the cell output.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
    stream=sys.stdout,
)

# Logger named "pylopt": INFO level, records are passed on to the root handler.
logger = logging.getLogger("pylopt")
logger.propagate = True
logger.setLevel(logging.INFO)

### Start training

In [None]:
# Upper-level optimiser: its name and its options.
# NOTE(review): 'lr' is a one-element list; presumably one learning rate per
# parameter group (only the potentials are trainable here) — confirm.
max_num_iterations = 100
method_upper = 'adam'
options_upper = {
    'max_num_iterations': max_num_iterations,
    'lr': [1e-1],
    'alternating': True,
}

# Run the training on the training dataset, with the callbacks and learning-rate
# schedulers configured above.
bilevel_optimisation.learn(
    regulariser,
    lam,
    l2_loss_func,
    train_image_dataset,
    optimisation_method_upper=method_upper,
    optimisation_options_upper=options_upper,
    dtype=dtype,
    device=device,
    callbacks=callbacks,
    schedulers=schedulers,
)

2025-10-02 13:45:30,690 [INFO] root: [TrainingMonitor] compute initial test loss and initial psnr
2025-10-02 13:45:30,895 [INFO] root: [TrainingMonitor]   > average psnr: 20.51734
2025-10-02 13:45:30,895 [INFO] root: [TrainingMonitor]   > test loss: 1370.61719
2025-10-02 13:45:30,895 [INFO] root: [TrainingMonitor]   > evaluation took [ms]: 183.80284
2025-10-02 13:45:32,458 [INFO] root: [TrainingMonitor] log statistics and hyperparameters
2025-10-02 13:45:32,458 [INFO] root: [TrainingMonitor]   > lr for group potentials: 0.100
2025-10-02 13:45:32,642 [INFO] root: [TrainingMonitor]   > average psnr: 20.60477
2025-10-02 13:45:32,642 [INFO] root: [TrainingMonitor]   > test loss: 1343.30200
2025-10-02 13:45:32,642 [INFO] root: [TrainingMonitor]   > evaluation took [ms]: 177.28717
2025-10-02 13:45:33,022 [INFO] root: [TRAIN] iteration [1 / 100]: loss = 582.50275
2025-10-02 13:45:36,582 [INFO] root: [TrainingMonitor] log statistics and hyperparameters
2025-10-02 13:45:36,582 [INFO] root: [Tra

### Store trained model and visualisations of training results

In [13]:
# NOTE(review): this cell looks stale. None of save_foe_model, path_to_eval_dir,
# visualise_training_stats, visualise_filter_stats, train_loss_list,
# test_loss_list, psnr_list, evaluation_freq or filters_list is defined or
# imported in the visible part of the notebook, so on a fresh kernel the cell
# fails with NameError at the first line. Confirm where these names should come
# from, or replace the cell (a SaveModel callback is already registered for training).

# Store the regulariser under <path_to_eval_dir>/models with directory name 'final'.
save_foe_model(regulariser, os.path.join(path_to_eval_dir, 'models'), model_dir_name='final')

# Visualise the recorded training statistics and filter statistics.
visualise_training_stats(train_loss_list, test_loss_list, psnr_list, evaluation_freq, path_to_eval_dir)
visualise_filter_stats(filters_list, path_to_eval_dir)