# Tutorial \#7 (23 Feb. 2024)

Faisal Jayousi 

jayousi@unice.fr

## What is DeepInverse all about?
DeepInverse is a PyTorch-based Python library for solving imaging inverse problems with deep learning. It provides a range of linear and non-linear forward operators, such as blurring, denoising and downsampling.
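A forward operator $A$ maps an image $x$ to measurements $y = A(x) + n$. As a minimal sketch, independent of DeepInverse's own API, here is a downsampling operator and its adjoint in NumPy (`downsample` and `downsample_adjoint` are hypothetical helpers), together with the standard adjoint test $\langle Ax, y\rangle = \langle x, A^\top y\rangle$:

```python
import numpy as np

def downsample(x, factor=2):
    """Forward operator A: keep every `factor`-th pixel in each direction."""
    return x[::factor, ::factor]

def downsample_adjoint(y, shape, factor=2):
    """Adjoint A^T: put the measurements back on the fine grid, zeros elsewhere."""
    x = np.zeros(shape, dtype=y.dtype)
    x[::factor, ::factor] = y
    return x

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 8))                  # ground-truth image
A_x = downsample(x)                              # noiseless measurements A x
y = A_x + 0.01 * rng.standard_normal(A_x.shape)  # noisy measurements y = A x + n

# Adjoint test: <A x, y> must equal <x, A^T y>
lhs = np.vdot(A_x, y)
rhs = np.vdot(x, downsample_adjoint(y, x.shape))
print(y.shape, np.isclose(lhs, rhs))  # (4, 4) True
```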

## Importing DeepInverse
* You need to install the DeepInverse module for this tutorial: ```pip install deepinv```
* To import it: use ```import deepinv```
* Documentation: https://deepinv.github.io/deepinv/index.html

In [1]:
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import deepinv as dinv
from deepinv.optim.data_fidelity import L1
from deepinv.optim.prior import PnP
from deepinv.training_utils import train, test
from deepinv.unfolded import unfolded_builder





## LISTA

The subject of this tutorial is LISTA (Learned Iterative Soft-Thresholding Algorithm). In classical ISTA, the stepsize and the thresholds of the proximal (soft-thresholding) operator are fixed by hand; in LISTA, a fixed number of iterations is unrolled into a network and these parameters are learned from training data. For a use case in compressed sensing, see https://deepinv.github.io/deepinv/auto_examples/unfolded/demo_LISTA.html.
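To make the difference concrete, here is a minimal NumPy sketch (not DeepInverse's implementation) of an unrolled ISTA forward pass in which every iteration $k$ has its own stepsize $\gamma_k$ and threshold $\tau_k$. The helper names `soft_threshold` and `unrolled_ista`, and the random test problem, are illustrative assumptions. With constant values $\gamma_k = 1/L$ and $\tau_k = \lambda/L$ it reduces to classical ISTA; a learned variant would instead fit the per-iteration values on training pairs.

```python
import numpy as np

def soft_threshold(v, tau):
    """Proximal operator of tau * ||.||_1: elementwise soft-thresholding."""
    return np.sign(v) * np.maximum(np.abs(v) - tau, 0.0)

def unrolled_ista(A, y, stepsizes, thresholds):
    """Run one ISTA iteration per (stepsize, threshold) pair, starting from x = 0.

    Iteration k computes x <- soft_threshold(x - gamma_k * A^T (A x - y), tau_k).
    """
    x = np.zeros(A.shape[1])
    for gamma, tau in zip(stepsizes, thresholds):
        grad = A.T @ (A @ x - y)  # gradient of 0.5 * ||A x - y||^2
        x = soft_threshold(x - gamma * grad, tau)
    return x

# Classical ISTA as a special case: constant stepsize 1/L and threshold lam/L
rng = np.random.default_rng(0)
A = rng.standard_normal((20, 50)) / np.sqrt(20)
x_true = np.zeros(50)
x_true[[3, 17, 41]] = [1.0, -2.0, 1.5]  # sparse ground truth
y = A @ x_true

L = np.linalg.norm(A, 2) ** 2  # Lipschitz constant of the gradient
lam, n_iter = 0.01, 200
x_hat = unrolled_ista(A, y, [1.0 / L] * n_iter, [lam / L] * n_iter)
```

Passing the parameters as lists, one entry per iteration, mirrors how the unrolled model later in this notebook stores one `stepsize` and one `lambda` per iteration.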

**NB**: Using a GPU for this tutorial is recommended. However, the parameters have been scaled down so that the notebook also runs on a CPU in a short time.

The steps are as follows:
1. Data generation: as a first approach, we will use the well-known MNIST dataset. It can be downloaded using ```datasets.MNIST()```. See the documentation for more details on the arguments.
2. Generate blurred images: you can use ```dinv.datasets.generate_dataset()```. An appropriate blurring model needs to be defined first. For simplicity, choose a Gaussian kernel with $\sigma=3$.
3. Load the data using PyTorch's dataloaders. Don't forget to shuffle the training data. You may use DeepInverse's ```dinv.datasets.HDF5Dataset()```.
4. Train the model on the training data. See DeepInverse's ```training_utils```.
5. Test your model on the test dataset. Plot the learned stepsizes and regularisation parameters.
6. Test your model on a generated SMLM (single-molecule localisation microscopy) image. How well does it work?
7. Redo everything but with generated SMLM images (use ~250 images for training).


### Data generation

In [2]:
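# Function to generate a synthetic SMLM image: isolated bright pixels ("molecules") on a dark background
def generate_ground_truth(n: int, n_mol: int, margin: int) -> np.ndarray:
    """
    n: side length of the (n, n) output image
    n_mol: number of molecules to place (two draws may land on the same pixel)
    margin: molecules are placed at indices in [margin, n - margin], away from the edges
    """
    gt = np.zeros((n, n), dtype=np.double)
    for _ in range(n_mol):
        # random.randint includes both endpoints
        i = random.randint(margin, n - margin)
        j = random.randint(margin, n - margin)
        gt[i, j] = 255.
    return gt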
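Because `generate_ground_truth` draws each position independently, two molecules can land on the same pixel, so the image may contain fewer than `n_mol` spots. Also, since `random.randint` includes its upper bound, `margin=0` can produce the index `n` and raise an `IndexError`. If exactly `n_mol` distinct molecules are wanted, one option is to sample positions without replacement; `generate_ground_truth_distinct` below is a hypothetical variant, not part of the tutorial.

```python
import numpy as np

def generate_ground_truth_distinct(n, n_mol, margin, seed=None):
    """Hypothetical variant of generate_ground_truth with exactly n_mol distinct molecules.

    Allowed indices are margin..n-margin (inclusive), as with random.randint.
    """
    if margin < 1:
        raise ValueError("margin must be >= 1 so that index n - margin stays inside the image")
    rng = np.random.default_rng(seed)
    side = n - 2 * margin + 1                                  # allowed rows (and columns)
    flat = rng.choice(side * side, size=n_mol, replace=False)  # distinct flat positions
    rows, cols = np.unravel_index(flat, (side, side))
    gt = np.zeros((n, n), dtype=np.double)
    gt[rows + margin, cols + margin] = 255.0
    return gt

gt_example = generate_ground_truth_distinct(28, 4, 4, seed=0)
```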

In [3]:
# Reproducibility
torch.manual_seed(0)
random.seed(234)

device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cpu


In [4]:
BASE_DIR = Path(".")
ORIGINAL_DATA_DIR = BASE_DIR
DATA_DIR = BASE_DIR
RESULTS_DIR = BASE_DIR
CKPT_DIR = BASE_DIR


In [5]:
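def blur_operator(x, s=3.0):
    """Periodic Gaussian blur of a 2D array x; s is the standard deviation in pixels."""
    # Signed pixel coordinates in FFT order (0, 1, ..., -2, -1): circular distances to pixel (0, 0)
    ty = np.fft.fftfreq(x.shape[0], d=1.0 / x.shape[0])
    tx = np.fft.fftfreq(x.shape[1], d=1.0 / x.shape[1])
    Y, X = np.meshgrid(ty, tx, indexing="ij")
    h = np.exp(-(X**2 + Y**2) / (2.0 * s**2))
    h = h / np.sum(h)  # normalised Gaussian kernel (PSF); it plays the role of the operator A
    # Circular convolution h * x computed in the Fourier domain
    return np.real(np.fft.ifft2(np.fft.fft2(h) * np.fft.fft2(x)))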
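A few sanity checks help confirm that an FFT-based blur like `blur_operator` behaves as intended. The sketch below uses a standalone copy, `gaussian_blur_fft` (a hypothetical name; it also returns the kernel), and checks that it matches a direct circular convolution, preserves the total intensity (the kernel sums to 1) and uses a kernel that is symmetric around pixel (0, 0).

```python
import numpy as np

def gaussian_blur_fft(x, s=3.0):
    """Circular Gaussian blur via the FFT; returns (blurred image, kernel)."""
    ty = np.fft.fftfreq(x.shape[0], d=1.0 / x.shape[0])  # 0, 1, ..., -1
    tx = np.fft.fftfreq(x.shape[1], d=1.0 / x.shape[1])
    Y, X = np.meshgrid(ty, tx, indexing="ij")
    h = np.exp(-(X**2 + Y**2) / (2.0 * s**2))
    h /= h.sum()
    return np.real(np.fft.ifft2(np.fft.fft2(h) * np.fft.fft2(x))), h

def circular_conv_direct(x, h):
    """Reference circular convolution: (h * x)[n] = sum_k h[k] x[n - k]."""
    out = np.zeros_like(x)
    for i in range(h.shape[0]):
        for j in range(h.shape[1]):
            out += h[i, j] * np.roll(x, shift=(i, j), axis=(0, 1))
    return out

rng = np.random.default_rng(0)
x = rng.random((28, 28))
y, h = gaussian_blur_fft(x)
```

The direct version is far slower (one shifted copy per kernel pixel), which is why the FFT route is preferred for periodic boundaries.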


In [6]:
img_size = 28
n_channels = 1
operation = "compressed-sensing"
train_dataset_name = "MNIST_train"

# Load the MNIST training and test sets; ToTensor() converts images to tensors with values in [0, 1].
train_test_transform = transforms.Compose([transforms.ToTensor()])
train_base_dataset = datasets.MNIST(
    root=ORIGINAL_DATA_DIR, train=True, transform=train_test_transform, download=True
)
test_base_dataset = datasets.MNIST(
    root=ORIGINAL_DATA_DIR, train=False, transform=train_test_transform, download=True
)







In [7]:
# Use parallel data loading on a GPU to speed up training; on a CPU, where all computations run,
# use synchronous data loading.
num_workers = 4 if torch.cuda.is_available() else 0


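# Gaussian blur with circular boundary conditions (sigma = 1 pixel here; step 2 suggests sigma = 3).
physics = dinv.physics.Blur(dinv.physics.blur.gaussian_blur(sigma=(1, 1), angle=0), padding='circular', device=device)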
my_dataset_name = "demo_LISTA"
n_images_max = (
    1000 if torch.cuda.is_available() else 200
)  # maximal number of images used for training
measurement_dir = DATA_DIR / train_dataset_name / operation
generated_datasets_path = dinv.datasets.generate_dataset(
    train_dataset=train_base_dataset,
    test_dataset=test_base_dataset,
    physics=physics,
    device=device,
    save_dir=measurement_dir,
    train_datapoints=n_images_max,
    test_datapoints=8,
    num_workers=num_workers,
    dataset_filename=str(my_dataset_name),
)

train_dataset = dinv.datasets.HDF5Dataset(path=generated_datasets_path, train=True)
test_dataset = dinv.datasets.HDF5Dataset(path=generated_datasets_path, train=False)

Computing train measurement vectors from base dataset...




Computing test measurement vectors from base dataset...



Dataset has been saved in MNIST_train\compressed-sensing





In [8]:
# Select the data fidelity term: f(x) = 1/2 ||Ax - y||_2^2, matching the problem stated below
data_fidelity = dinv.optim.data_fidelity.L2()

# Set up the trainable denoising prior; here, the soft-threshold in a wavelet basis.
# If the prior is initialized with a list of length max_iter,
# then a distinct weight is trained for each PGD iteration.
# To share one trained prior across all iterations, initialize it with a single model instead.
max_iter = 10  # Number of unrolled iterations
level = 2  # number of wavelet decomposition levels
prior = [
    PnP(denoiser=dinv.models.WaveletPrior(wv="db8", level=level, device=device))
    for _ in range(max_iter)
]

# Unrolled optimization algorithm parameters

lamb = [1.0] * max_iter  # initialization of the regularization parameter.
# A distinct lamb is trained for each iteration.

stepsize = [1.0] * max_iter  # initialization of the stepsizes.
# A distinct stepsize is trained for each iteration.

sigma_denoiser_init = 0.05
sigma_denoiser = [sigma_denoiser_init * torch.ones(level, 3)] * max_iter
# A distinct sigma_denoiser is trained for each iteration.

params_algo = {  # wrap all the restoration parameters in a 'params_algo' dictionary
    "stepsize": stepsize,
    "g_param": sigma_denoiser,
    "lambda": lamb,
}

trainable_params = [
    "g_param",
    "stepsize",
    "lambda",
]  # define which parameters from 'params_algo' are trainable

# Define the unfolded trainable model.
model = unfolded_builder(
    iteration="PGD",
    params_algo=params_algo.copy(),
    trainable_params=trainable_params,
    data_fidelity=data_fidelity,
    max_iter=max_iter,
    prior=prior,
).to(device)
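Note that `[sigma_denoiser_init * torch.ones(level, 3)] * max_iter` builds a list holding the *same* tensor object `max_iter` times. Whether this matters depends on whether the consumer copies each entry (PyTorch's `torch.tensor(el)`, for instance, always copies its input). The pitfall is plain Python list semantics, shown here with NumPy arrays as a stand-in for tensors:

```python
import numpy as np

max_iter = 3
shared = [0.05 * np.ones((2, 3))] * max_iter                  # one array object, 3 references
distinct = [0.05 * np.ones((2, 3)) for _ in range(max_iter)]  # 3 independent arrays

shared[0][0, 0] = 1.0    # in-place edit of the "first" entry
distinct[0][0, 0] = 1.0

print(shared[1][0, 0], distinct[1][0, 0])  # 1.0 0.05
```

A list comprehension is the safer default whenever per-iteration values are meant to be independent.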



### Define algorithm

Same as in the previous tutorials, we will solve the following problem $$\min_x \frac{1}{2}\|Ax-y\|_2^2 + \lambda\|x\|_1$$

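In this tutorial, we will focus exclusively on $\ell^1$ regularisation. Note that the prior defined above soft-thresholds the coefficients of $x$ in a wavelet (db8) basis, so the $\ell^1$ penalty effectively acts on those coefficients rather than on the pixels of $x$. The library provides a wide range of other priors, both classical and learned; see https://deepinv.github.io/deepinv/deepinv.models.html. In addition, $\lambda$ is not fixed by hand: it is one of the trainable parameters, together with the stepsizes and thresholds.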
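For reference, one proximal gradient (ISTA) step for this problem, with stepsize $\gamma_k$ and regularisation parameter $\lambda_k$ at iteration $k$, reads

$$x_{k+1} = \operatorname{prox}_{\gamma_k \lambda_k \|\cdot\|_1}\left(x_k - \gamma_k A^\top (A x_k - y)\right), \qquad \left[\operatorname{prox}_{\tau \|\cdot\|_1}(v)\right]_i = \operatorname{sign}(v_i)\,\max\left(|v_i| - \tau,\, 0\right).$$

In the unrolled model, each of the `max_iter` iterations has its own $\gamma_k$ and $\lambda_k$ (the `stepsize` and `lambda` lists above). With the wavelet prior used in the code, the soft-thresholding step is carried out by the wavelet denoiser, whose thresholds are the trainable `g_param`.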

In [9]:
# Training parameters
epochs = 4
learning_rate = 1e-3


# Choose optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Choose supervised training loss
losses = [dinv.loss.SupLoss(metric=dinv.metric.mse())]

# Logging parameters
verbose = True
wandb_vis = False  # plot curves and images in Weights & Biases

# Batch sizes and data loaders
train_batch_size = 1
test_batch_size = 8

train_dataloader = DataLoader(
    train_dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True
)
test_dataloader = DataLoader(
    test_dataset, batch_size=test_batch_size, num_workers=num_workers, shuffle=False
)

In [10]:
train(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=test_dataloader,
    epochs=epochs,
    losses=losses,
    physics=physics,
    optimizer=optimizer,
    device=device,
    verbose=verbose,      # logging options defined in the previous cell
    wandb_vis=wandb_vis,
)

The model has 80 trainable parameters
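The count of 80 can be checked by hand from the settings above: each of the `max_iter = 10` iterations has one stepsize, one $\lambda$ and a `(level, 3)` tensor of denoiser thresholds (presumably one per decomposition level and detail orientation).

```python
max_iter = 10
level = 2
per_iteration = {
    "stepsize": 1,         # one scalar per iteration
    "lambda": 1,           # one scalar per iteration
    "g_param": level * 3,  # entries of torch.ones(level, 3)
}
n_params = max_iter * sum(per_iteration.values())
print(n_params)  # 80
```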


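OSError: [WinError 123] The filename, directory name, or volume label syntax is incorrect: './24-02-23-11:58:18'

**NB**: this error is specific to Windows. Judging from the path in the message, `train` creates a checkpoint folder named after the current date and time, and that name contains colons (`:`), which are not allowed in Windows paths. Running the notebook on Linux or macOS (e.g. in WSL or Google Colab) avoids the problem.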
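If you create timestamped folders in your own code, the fix is simply to avoid `:` in the format string. A small sketch (the timestamp matches the one in the error above; the variable names are illustrative):

```python
from datetime import datetime

stamp = datetime(2024, 2, 23, 11, 58, 18)
unsafe = stamp.strftime("%y-%m-%d-%H:%M:%S")  # '24-02-23-11:58:18' (invalid on Windows)
safe = stamp.strftime("%y-%m-%d-%H-%M-%S")    # '24-02-23-11-58-18' (valid on all platforms)
```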

In [None]:
# Test model and plot results
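For step 5, a common way to quantify reconstruction quality is the peak signal-to-noise ratio (PSNR). Below is a minimal NumPy helper, `psnr` (a hypothetical name, not a DeepInverse function), assuming images scaled to [0, 1] as produced by `transforms.ToTensor()`:

```python
import numpy as np

def psnr(x_hat, x, max_val=1.0):
    """Peak signal-to-noise ratio (dB) between a reconstruction x_hat and the ground truth x.

    max_val = 1.0 matches images scaled to [0, 1], as done by transforms.ToTensor().
    """
    x_hat = np.asarray(x_hat, dtype=float)
    x = np.asarray(x, dtype=float)
    mse = np.mean((x_hat - x) ** 2)
    if mse == 0:
        return np.inf  # perfect reconstruction
    return 10.0 * np.log10(max_val**2 / mse)

x = np.zeros((28, 28))
print(psnr(x + 0.1, x))  # approximately 20 dB (MSE = 0.01)
```

PyTorch tensors need converting first, e.g. `out.detach().cpu().numpy()`. To plot the learned per-iteration stepsizes and regularisation parameters, one option is to inspect the trained model's parameters with `model.named_parameters()` and pick the relevant entries by name.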

### Test on SMLM image

In [None]:
gt = generate_ground_truth(28, 4, 4)
gt = torch.from_numpy(gt).to(torch.float)


blurred_image = ...
out = ...

plt.subplot(131)
plt.imshow(gt, cmap='gray')
plt.subplot(132)
plt.imshow(blurred_image[0, 0], cmap="gray")
plt.subplot(133)
plt.imshow(out.detach().numpy()[0, 0], cmap='gray')
plt.show()
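Two details matter when feeding the SMLM image to a model trained on MNIST. First, `generate_ground_truth` uses intensity 255, whereas `transforms.ToTensor()` scales MNIST images to [0, 1], so rescaling is a reasonable assumption. Second, judging from the plotting code, which indexes `blurred_image[0, 0]`, the operators work on batched 4D arrays of shape (B, C, H, W). A NumPy sketch of both steps (indexing with `None` works the same way on PyTorch tensors):

```python
import numpy as np

gt = np.zeros((28, 28))
gt[10, 12] = 255.0                   # one molecule, same intensity as generate_ground_truth

gt_scaled = gt / 255.0               # match the [0, 1] range of ToTensor()-converted MNIST
gt_batched = gt_scaled[None, None]   # add batch and channel axes: (H, W) -> (B, C, H, W)
print(gt_batched.shape, gt_batched.max())  # (1, 1, 28, 28) 1.0
```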

### Retraining on SMLM images