Demo notebook for inverse problems
==================================

**Table of contents**
1. [Defining an inverse problem](#defining-an-inverse-problem)
1. [Example on CelebaHQ dataset with DPS algorithm](#example-on-celebahq-dataset-with-dps-algorithm)
1. [Example on FFHQ dataset with $\Pi\text{GDM}$ algorithm](#example-on-ffhq-dataset-with--algorithm)


# Defining an inverse problem

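An inverse problem has the form
\begin{equation*}
    y = A x + \sigma \ \epsilon
\end{equation*}
where
- $x$ is the unknown image to recover
- $y$ is the observation
- $A$ is the degradation operator
- $\sigma \ \epsilon$ is Gaussian noise distributed as $\mathcal{N}(0, \sigma^2 I)$, with $\epsilon \sim \mathcal{N}(0, I)$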
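
As a minimal sketch of this forward model, the snippet below builds a toy problem with NumPy. The operator `A` (averaging adjacent pairs of entries), the signal size, and the seed are illustrative assumptions, not the operators used later in this notebook.

```python
import numpy as np

rng = np.random.default_rng(0)

# toy signal with 8 entries (stands in for a flattened image)
x = rng.standard_normal(8)

# hypothetical degradation operator: average each pair of adjacent entries
# A has shape (4, 8), so the observation is smaller than the signal
A = np.kron(np.eye(4), np.full((1, 2), 0.5))

# observation y = A x + sigma * epsilon, with epsilon ~ N(0, I)
sigma = 0.1
epsilon = rng.standard_normal(A.shape[0])
y = A @ x + sigma * epsilon

print(x.shape, y.shape)
```

Since $\epsilon$ has unit variance, the noise term $\sigma \ \epsilon$ has standard deviation $\sigma$ per entry.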

Let's see how to

1. Load an image
1. Load the degradation operator $A$
1. Plot the observation $y$


## Loading an image

In [None]:
from py_source.utils import load_image, display_image
import matplotlib.pyplot as plt

# load the image
img_path = "./material/celebahq_img/00010.jpg"
x_origin = load_image(img_path)

# plot the image
fig, ax = plt.subplots()

display_image(x_origin, ax)


## Loading the degradation operator

Let's load the SR16 degradation operator.

It reduces the resolution of the image by a factor of $16$.
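
To give a feel for the sizes involved, here is a small sketch that assumes the super-resolution degradation amounts to averaging non-overlapping $16 \times 16$ blocks. The stored operator `sr16.pt` may be implemented differently, and the helper `block_average` is hypothetical, written with NumPy.

```python
import numpy as np


def block_average(img, factor):
    """Average non-overlapping factor x factor blocks of a (C, H, W) image."""
    c, h, w = img.shape
    blocks = img.reshape(c, h // factor, factor, w // factor, factor)
    return blocks.mean(axis=(2, 4))


# a 3-channel 256 x 256 image, the size used by the models in this notebook
img = np.random.default_rng(0).random((3, 256, 256))
low_res = block_average(img, 16)

print(low_res.shape)              # (3, 16, 16)
print(low_res.reshape(-1).shape)  # (768,)
```

Under this assumption the flattened observation has $3 \cdot 16 \cdot 16 = 768$ entries, so it has to be reshaped into channels, height, and width before plotting.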

In [None]:
import math
import torch

# load the degradation operator
path_operator = "./material/degradation_operators/sr16.pt"
degradation_operator = torch.load(path_operator, map_location="cpu")

# apply degradation operator
# NOTE: it operates on a batch of images
y = degradation_operator.H(x_origin[None])

# reshape to plot the observation
# NOTE: y is a square image with 3 channels
n_channels = 3
n_pixel_per_channel = y.shape[1] // n_channels
height = width = int(math.sqrt(n_pixel_per_channel))

y_reshaped = y.reshape(n_channels, height, width)

# plot the image
fig, ax = plt.subplots()
display_image(y_reshaped, ax)

Let's add Gaussian noise with $\sigma = 0.1$

In [None]:
# add noise
sigma = 0.1
y_reshaped_noised = y_reshaped + sigma * torch.randn_like(y_reshaped)

# plot the three images side by side
fig, axes = plt.subplots(1, 3)

images = (x_origin, y_reshaped, y_reshaped_noised)
titles = ("original", "degraded", "degraded + noise")

for ax, img, title in zip(axes, images, titles):
    display_image(img, ax)
    ax.set_title(title)

fig.tight_layout()

In the end, the inverse problem can be defined using the tuple ``(y, degradation_operator, sigma)``

**Note**:
In practice, we only have access to $y$, $A$, and $\sigma$, not to $x$ (``x_origin`` in the code).

# Example on CelebaHQ dataset with DPS algorithm

The model details and checkpoint can be found on [Hugging Face](https://huggingface.co/google/ddpm-celebahq-256).

Here, the package ``Diffusers`` is used under the hood to load the model.

Beforehand, let's load the model and perform unconditional sampling

In [None]:
import torch
from py_source.utils import load_epsilon_net
from py_source.sampling.unconditional import unconditional_sampling


# load the noise predictor with 1000 diffusion steps
device = "cuda:0"
n_steps = 1000
torch.set_default_device(device)

eps_net = load_epsilon_net("celebahq", n_steps, device)

# check unconditional generation
# NOTE: the batch size of the initial noise sets the number of generated samples
initial_noise = torch.randn((1, 3, 256, 256), device=device)
generated_images = unconditional_sampling(eps_net, initial_noise, display_im=False)

# plot image
fig, ax = plt.subplots()
display_image(generated_images[0], ax)

Now we have all the building blocks to solve linear inverse problems.

Let's solve the SR16 problem with the CelebaHQ model as prior, using the DPS algorithm [1].


.. [1] Chung, Hyungjin, et al. "Diffusion posterior sampling for general noisy inverse problems." arXiv preprint arXiv:2209.14687 (2022).
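
DPS [1] steers each diffusion step toward the observation by following the gradient of the squared residual $\|y - A \hat{x}_0\|^2$ with respect to the current iterate, where $\hat{x}_0$ is the denoised estimate. The sketch below isolates that data-consistency gradient on a toy linear problem with NumPy. It is a simplification under stated assumptions: the denoiser is replaced by the identity, and the operator, sizes, and step size are illustrative. It is not the `dps` function used in this notebook.

```python
import numpy as np

rng = np.random.default_rng(0)

# toy linear problem: 4 measurements of an 8-dimensional signal
A = rng.standard_normal((4, 8))
x_true = rng.standard_normal(8)
y = A @ x_true


def residual_sq(x):
    """Squared measurement residual ||y - A x||^2."""
    r = y - A @ x
    return float(r @ r)


def residual_sq_grad(x):
    """Gradient of ||y - A x||^2 with respect to x: -2 A^T (y - A x)."""
    return -2.0 * A.T @ (y - A @ x)


# current iterate (stands in for the denoised estimate; identity denoiser assumed)
x = rng.standard_normal(8)

# one gradient step toward measurement consistency
# step size 1 / L, with L = 2 * sigma_max(A)^2 the Lipschitz constant of the gradient
step_size = 0.5 / np.linalg.norm(A, 2) ** 2
x_guided = x - step_size * residual_sq_grad(x)

print(residual_sq(x), residual_sq(x_guided))
```

With this step size the residual cannot increase. In an actual sampler the gradient would also flow through the noise predictor, which is what makes each guided step expensive.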


In [None]:
device = "cuda:0"
torch.set_default_device(device)


# first, define the inverse problem

# load the image
img_path = "./material/celebahq_img/00010.jpg"
x_origin = load_image(img_path, device)


# load the degradation operator
path_operator = "./material/degradation_operators/sr16.pt"
degradation_operator = torch.load(path_operator, map_location=device)

# apply degradation operator
y = degradation_operator.H(x_origin[None])
y = y.squeeze(0)

# add noise
sigma = 0.01
y = y + sigma * torch.randn_like(y)

inverse_problem = (y, degradation_operator, sigma)

In [None]:
from py_source.utils import load_epsilon_net
from py_source.sampling.dps import dps

# load model with 500 diffusion steps
n_steps = 500
eps_net = load_epsilon_net("celebahq", n_steps, device)

# solve problem
initial_noise = torch.randn((1, 3, 256, 256), device=device)
reconstruction = dps(initial_noise, inverse_problem, eps_net)

In [None]:
# plot results

# reshape y
n_channels = 3
n_pixel_per_channel = y.shape[0] // n_channels
height = width = int(math.sqrt(n_pixel_per_channel))

y_reshaped = y.reshape(n_channels, height, width)

# init figure
fig, axes = plt.subplots(1, 3)

images = (x_origin, y_reshaped, reconstruction[0])
titles = ("original", "degraded", "reconstruction")

# display figures
for ax, img, title in zip(axes, images, titles):
    display_image(img, ax)
    ax.set_title(title)

fig.tight_layout()

# Example on FFHQ dataset with $\Pi\text{GDM}$ algorithm

The model details can be found in [Diffusion Posterior Sampling for General Noisy Inverse Problems](https://arxiv.org/abs/2209.14687) in the Experiment section.

The model checkpoint, ``ffhq_10m.pt``, can be downloaded [here](https://drive.google.com/drive/folders/1jElnRoFv7b31fG0v6pTSQkelbSX3xGZh)

Beforehand, let's load the model and perform unconditional sampling

In [None]:

# load the noise predictor with 1000 diffusion steps
device = "cuda:0"
n_steps = 1000
torch.set_default_device(device)

eps_net = load_epsilon_net("ffhq", n_steps, device)

# check unconditional generation
# NOTE: the batch size of the initial noise sets the number of generated samples
initial_noise = torch.randn((1, 3, 256, 256), device=device)
generated_images = unconditional_sampling(eps_net, initial_noise, display_im=False)

# plot image
fig, ax = plt.subplots()
display_image(generated_images[0], ax)

Now let's solve an inpainting problem with the FFHQ prior using the $\Pi\text{GDM}$ algorithm [1].

.. [1] Song, Jiaming, et al. "Pseudoinverse-guided diffusion models for inverse problems." International Conference on Learning Representations. 2023.

In [None]:
device = "cuda:0"
torch.set_default_device(device)


# first, define the inverse problem

# load the image
img_path = "./material/ffhq_img/00018.png"
x_origin = load_image(img_path, device)

# load the degradation operator
path_operator = "./material/degradation_operators/inpainting_middle.pt"
degradation_operator = torch.load(path_operator, map_location=device)

# apply degradation operator
y = degradation_operator.H(x_origin[None])
y = y.squeeze(0)

# add noise
sigma = 0.01
y = y + sigma * torch.randn_like(y)

inverse_problem = (y, degradation_operator, sigma)


This algorithm leverages the singular value decomposition (SVD) of the degradation operator, $A = U^\top \Sigma V$.

Therefore, let's use ``EpsilonNetSVD`` to apply the corresponding transformations.
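
To illustrate why an SVD is convenient, the sketch below uses NumPy on a toy inpainting-style operator that keeps 3 of 5 entries. It follows NumPy's convention $A = U \, \mathrm{diag}(s) \, V^\top$ (`np.linalg.svd` returns `U, s, Vh`), which may name the factors differently from the formula above. The mask and sizes are illustrative assumptions, and `EpsilonNetSVD` is not used here.

```python
import numpy as np

# toy inpainting operator: keep entries 0, 1 and 4 of a length-5 signal
kept = [0, 1, 4]
A = np.eye(5)[kept]  # shape (3, 5)

# thin SVD: A = U @ diag(s) @ Vh
U, s, Vh = np.linalg.svd(A, full_matrices=False)

# pseudo-inverse from the SVD factors: A^+ = Vh.T @ diag(1 / s) @ U.T
# (valid here because all singular values are nonzero)
A_pinv = Vh.T @ np.diag(1.0 / s) @ U.T

x = np.arange(1.0, 6.0)  # the signal [1, 2, 3, 4, 5]
y = A @ x                # observed entries only

# mapping y back to signal space fills the missing entries with zeros
x_filled = A_pinv @ y
print(np.round(x_filled, 6))
```

Once the factors are known, applying $A$, its transpose, or its pseudo-inverse reduces to cheap products with those factors.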

In [None]:
from py_source.utils import load_epsilon_net
from py_source.sampling.pgdm import pgdm_svd
from py_source.sampling.epsilon_net import EpsilonNetSVD

# load model with 500 diffusion steps
n_steps = 500
eps_net = load_epsilon_net("ffhq", n_steps, device)

eps_net_svd = EpsilonNetSVD(
    net=eps_net.net,
    alphas_cumprod=eps_net.alphas_cumprod,
    timesteps=eps_net.timesteps,
    H_func=degradation_operator,
    device=device,
)

# solve problem
initial_noise = torch.randn((1, 3, 256, 256), device=device)
reconstruction = pgdm_svd(initial_noise, inverse_problem, eps_net_svd)

In [None]:
# plot results

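# reshape y
# pad y to a full-size vector; entries that were not observed are set to -1
y_reshaped = -torch.ones(3 * 256 * 256, device=device)
y_reshaped[: y.shape[0]] = y
# apply V to the padded vector, then reshape to (channels, height, width)
y_reshaped = degradation_operator.V(y_reshaped[None])
y_reshaped = y_reshaped.reshape(3, 256, 256)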


# init figure
fig, axes = plt.subplots(1, 3)

images = (x_origin, y_reshaped, reconstruction[0])
titles = ("original", "degraded", "reconstruction")

# display figures
for ax, img, title in zip(axes, images, titles):
    display_image(img, ax)
    ax.set_title(title)

fig.tight_layout()