Demo notebook for algorithm evaluation
======================================

The algorithm is evaluated on three aspects:
- Correctness
- Perception
- Efficiency

The evaluation is performed on three tasks, each applied to three images:
- Inpainting (middle)
- Super-resolution $\times 16$
- Outpainting (half)

# Correctness

Here we check how well the algorithm approximates the posterior distribution $p(x | y)$.

To do so, we work in a setting where the posterior can be evaluated explicitly: **the Gaussian mixture case**.

In this setting, the score of the diffusion model, the transition kernels and, most importantly, the posterior $p(x | y)$ all have analytic expressions.
We use the Sliced Wasserstein (SW) distance to compare the true posterior with the approximate posterior.
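For intuition, the SW distance projects both sample sets onto random directions and averages the resulting one-dimensional Wasserstein distances; in one dimension, the optimal coupling between two empirical distributions with the same number of samples simply pairs the sorted samples. Below is a minimal Monte Carlo sketch of this estimator. It assumes equal sample sizes, and ``sliced_wasserstein_sketch`` and ``n_projections`` are hypothetical names; this is not necessarily how ``evaluation.gmm.sliced_wasserstein`` is implemented.

```python
import torch


def sliced_wasserstein_sketch(
    x: torch.Tensor, y: torch.Tensor, n_projections: int = 500, p: int = 2
) -> torch.Tensor:
    """Monte Carlo estimate of the sliced p-Wasserstein distance.

    Hypothetical helper (not the project's API). Assumes x and y both have
    shape (n_samples, dim) with the same n_samples.
    """
    dim = x.shape[1]
    # draw random directions uniformly on the unit sphere
    theta = torch.randn(n_projections, dim, dtype=x.dtype, device=x.device)
    theta = theta / theta.norm(dim=1, keepdim=True)
    # project both sample sets onto every direction: (n_samples, n_projections)
    x_proj = x @ theta.T
    y_proj = y @ theta.T
    # in 1D, the optimal coupling pairs sorted samples
    x_sorted, _ = torch.sort(x_proj, dim=0)
    y_sorted, _ = torch.sort(y_proj, dim=0)
    # average the p-th power of the 1D gaps over samples and directions
    sw_p = (x_sorted - y_sorted).abs().pow(p).mean()
    return sw_p.pow(1.0 / p)


# usage: identical sample sets give exactly 0, shifted ones a positive distance
torch.manual_seed(0)
a = torch.randn(300, 2)
print(sliced_wasserstein_sketch(a, a), sliced_wasserstein_sketch(a, a + 5.0))
```

Sorting makes each projection cost $O(n \log n)$, which is what makes SW much cheaper than the full Wasserstein distance in higher dimensions.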

In [None]:
import torch
from torch.distributions import (
    MixtureSameFamily,
    Categorical,
    MultivariateNormal,
)

from evaluation.gmm import generate_inverse_problem, load_gmm_epsilon_net


dim = 2
n_samples = 300
n_steps = 300
sigma = 0.1

device = "cpu"
torch.set_default_device(device)


# define the prior distribution: a mixture of 25 Gaussians centered on a 5x5 grid
means = torch.tensor(
    [[8 * i, 8 * j] * (dim // 2) for i in range(-2, 3) for j in range(-2, 3)], dtype=torch.float32
)
n_mixtures = means.shape[0]
covs = torch.eye(dim)[None, :].repeat(n_mixtures, 1, 1)
weights = torch.rand(n_mixtures)
weights = weights / weights.sum()

prior = MixtureSameFamily(Categorical(weights), MultivariateNormal(means, covs))

# deduce the posterior
obs, degradation_operator, posterior = generate_inverse_problem(
    prior, dim, sigma, A=torch.tensor([[1, 0]], dtype=torch.float32)
)

# define inverse problem
inverse_problem = (obs, degradation_operator, sigma)

**Note**:
The ``degradation_operator`` was defined to mask the y-coordinate, so only the x-coordinate is observed
(see ``A=torch.tensor([[1, 0]], dtype=torch.float32)``).
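Since the prior is a Gaussian mixture $\sum_k w_k \mathcal{N}(\mu_k, \Sigma_k)$ and the observation is linear with Gaussian noise, $y = Ax + \sigma \varepsilon$, the posterior is again a Gaussian mixture with components $\mathcal{N}(m_k, C_k)$, where $C_k = (\Sigma_k^{-1} + A^\top A / \sigma^2)^{-1}$ and $m_k = C_k(\Sigma_k^{-1}\mu_k + A^\top y / \sigma^2)$, and with weights proportional to $w_k \, \mathcal{N}(y; A\mu_k, \sigma^2 I + A \Sigma_k A^\top)$. The sketch below computes these quantities with ``torch.distributions``; it only illustrates what ``generate_inverse_problem`` is expected to return, and ``gmm_posterior_sketch`` is a hypothetical name.

```python
import torch
from torch.distributions import Categorical, MixtureSameFamily, MultivariateNormal


def gmm_posterior_sketch(
    prior: MixtureSameFamily, A: torch.Tensor, y: torch.Tensor, sigma: float
) -> MixtureSameFamily:
    """Posterior of a Gaussian-mixture prior under y = A x + sigma * noise.

    Hypothetical helper; assumes A has shape (m, d) and y has shape (m,).
    """
    weights = prior.mixture_distribution.probs               # (K,)
    means = prior.component_distribution.mean                # (K, d)
    covs = prior.component_distribution.covariance_matrix    # (K, d, d)
    m = A.shape[0]

    # per-component posterior covariance: (Sigma_k^{-1} + A^T A / sigma^2)^{-1}
    post_covs = torch.linalg.inv(torch.linalg.inv(covs) + (A.T @ A) / sigma**2)
    post_covs = 0.5 * (post_covs + post_covs.mT)  # remove round-off asymmetry
    # per-component posterior mean: C_k (Sigma_k^{-1} mu_k + A^T y / sigma^2)
    rhs = torch.linalg.solve(covs, means.unsqueeze(-1)) + (A.T @ y / sigma**2).unsqueeze(-1)
    post_means = (post_covs @ rhs).squeeze(-1)

    # reweight components by the evidence N(y; A mu_k, sigma^2 I + A Sigma_k A^T)
    pred_covs = A @ covs @ A.T + sigma**2 * torch.eye(m)
    log_evidence = MultivariateNormal(means @ A.T, pred_covs).log_prob(y)
    post_weights = torch.softmax(weights.log() + log_evidence, dim=0)

    return MixtureSameFamily(
        Categorical(post_weights), MultivariateNormal(post_means, post_covs)
    )
```

With the notebook's ``A = [[1, 0]]`` and identity covariances, each component's posterior variance along x shrinks to $1 / (1 + 1/\sigma^2)$ while its variance along the unobserved y-coordinate stays at 1, which is why the posterior appears as vertical stripes of clusters.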

Next, let's instantiate the diffusion model and use it to solve the inverse problem with the DPS algorithm.
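In DPS (Diffusion Posterior Sampling, Chung et al., 2023), each reverse step combines a standard DDPM ancestral update with a gradient step on the data-fidelity term $\|y - A\hat{x}_0(x_t)\|$, where $\hat{x}_0(x_t) = (x_t - \sqrt{1 - \bar\alpha_t}\,\epsilon_\theta(x_t, t)) / \sqrt{\bar\alpha_t}$ is the denoised estimate obtained from the epsilon network. The following is a minimal sketch under simplifying assumptions: a linear operator given as a matrix ``A``, a fixed step size ``zeta``, and a toy 2D standard Gaussian prior, for which the exact noise predictor is $\sqrt{1 - \bar\alpha_t}\, x_t$. The names ``dps_sketch`` and ``eps_net_toy`` are hypothetical; this is not the code of ``sampling.dps.dps``.

```python
import torch


def dps_sketch(x_T, y, A, eps_net, alphas_cumprod, zeta=1.0):
    """Minimal DPS loop (hypothetical helper, not sampling.dps.dps).

    Assumes a DDPM with cumulative products alphas_cumprod of shape (T,),
    a linear operator given as a matrix A of shape (m, d), and a batch of
    samples of shape (n, d).
    """
    x = x_T
    for t in reversed(range(alphas_cumprod.shape[0])):
        a_t = alphas_cumprod[t]
        a_prev = alphas_cumprod[t - 1] if t > 0 else torch.tensor(1.0)
        x = x.detach().requires_grad_(True)

        # denoised estimate of x_0 from the predicted noise
        eps = eps_net(x, t)
        x0_hat = (x - (1 - a_t).sqrt() * eps) / a_t.sqrt()

        # data-fidelity term ||y - A x0_hat||, one value per sample
        residual = torch.linalg.vector_norm(y - x0_hat @ A.T, dim=-1)
        grad = torch.autograd.grad(residual.sum(), x)[0]

        with torch.no_grad():
            # DDPM ancestral step: mean and variance of q(x_{t-1} | x_t, x0_hat)
            beta_t = 1 - a_t / a_prev
            mean = (a_prev.sqrt() * beta_t / (1 - a_t)) * x0_hat + (
                (1 - beta_t).sqrt() * (1 - a_prev) / (1 - a_t)
            ) * x
            var = beta_t * (1 - a_prev) / (1 - a_t)
            x_prev = mean + var.sqrt() * torch.randn_like(x) if t > 0 else mean
            # likelihood guidance: step against the gradient of the residual norm
            x = x_prev - zeta * grad
    return x.detach()


# usage on a toy problem: standard Gaussian prior, observe the first coordinate
torch.manual_seed(0)
betas = torch.linspace(1e-4, 2e-2, 300)
alphas_cumprod = torch.cumprod(1 - betas, dim=0)


def eps_net_toy(x, t):
    # exact noise predictor E[eps | x_t] for the prior N(0, I)
    return (1 - alphas_cumprod[t]).sqrt() * x


A = torch.tensor([[1.0, 0.0]])
y = torch.tensor([1.5])
samples = dps_sketch(torch.randn(100, 2), y, A, eps_net_toy, alphas_cumprod, zeta=0.1)
```

The guidance only moves the observed coordinate toward ``y``; the masked coordinate is left to the prior, which mirrors the GMM experiment in this notebook.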

In [None]:
from sampling.dps import dps


# load the diffusion model (epsilon network) associated with the prior
eps_net = load_gmm_epsilon_net(prior=prior, dim=dim, n_steps=n_steps)

# solve problem
initial_noise = torch.randn((n_samples, dim), device=device)
reconstruction = dps(initial_noise, inverse_problem, eps_net)


Let's plot the prior, the posterior and the DPS reconstruction to see what they look like.

In [None]:
import matplotlib.pyplot as plt

# sample the prior and posterior
samples_prior = prior.sample((n_samples,))
samples_posterior = posterior.sample((n_samples,))

# init figure
fig, ax = plt.subplots()

arr_samples = (samples_prior, samples_posterior, reconstruction)
arr_labels = ("prior", "posterior", "DPS")

# plot
for samples, label in zip(arr_samples, arr_labels):
    ax.scatter(
        samples[:, 0], samples[:, 1], alpha=0.5, label=label
    )

ax.set_xlabel("x")
ax.set_ylabel("y")
fig.legend(loc="upper center", ncols=len(arr_labels))


Finally, let's compute the SW distance between samples of the true posterior and the DPS reconstruction.

In [None]:
from evaluation.gmm import sliced_wasserstein

distance = sliced_wasserstein(samples_posterior, reconstruction)

print(f"Sliced Wasserstein distance: {distance:.4f}")

# Perception

Here we assess how *perceptually* close the algorithm's reconstruction is to the ground truth.

For that we use the LPIPS metric introduced in [1], which has been shown to correlate well with human perceptual judgments.
The smaller this metric, the better.


.. [1] Zhang, Richard, et al. "The unreasonable effectiveness of deep features as a perceptual metric." Proceedings of the IEEE conference on computer vision and pattern recognition. 2018.
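In [1], the distance between two images is computed from the activations of several layers of a pretrained network (e.g., AlexNet or VGG): at each layer, the feature vectors are unit-normalized along the channel axis, scaled channel-wise by learned weights, and the squared $\ell_2$ differences are averaged over spatial positions; the per-layer values are then summed. The sketch below reproduces only this aggregation step on precomputed feature maps. Feature extraction and the learned weights are left out, and ``lpips_from_features`` is a hypothetical helper, not the interface of ``evaluation.perception.LPIPS``.

```python
import torch


def lpips_from_features(feats_x, feats_y, weights, eps=1e-10):
    """LPIPS-style distance from precomputed feature maps.

    Hypothetical helper. feats_x / feats_y: lists of tensors of shape
    (N, C_l, H_l, W_l) taken from the same layers of a pretrained network;
    weights: list of non-negative per-channel weights of shape (C_l,).
    Returns one distance per image pair, shape (N,).
    """
    total = 0.0
    for fx, fy, w in zip(feats_x, feats_y, weights):
        # unit-normalize each spatial feature vector along the channel axis
        fx = fx / (fx.norm(dim=1, keepdim=True) + eps)
        fy = fy / (fy.norm(dim=1, keepdim=True) + eps)
        # channel-wise weighting, squared L2 over channels: (N, H_l, W_l)
        diff = (w.view(1, -1, 1, 1) * (fx - fy)).pow(2).sum(dim=1)
        # average over spatial positions, then sum over layers
        total = total + diff.mean(dim=(1, 2))
    return total
```

Because features are normalized per position, the metric is insensitive to the overall activation scale of each layer and focuses on differences in feature direction.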

Let's solve the SR $\times 16$ problem with DPS and compute the LPIPS metric on the result.

In [None]:
import torch

from utils import load_epsilon_net, load_image
from sampling.dps import dps

device = "cuda:0"
n_steps = 300
torch.set_default_device(device)


# load the image
img_path = "./material/celebahq_img/00010.jpg"
x_origin = load_image(img_path, device)


# load the degradation operator
path_operator = "./material/degradation_operators/sr16.pt"
degradation_operator = torch.load(path_operator, map_location=device)

# apply degradation operator
y = degradation_operator.H(x_origin[None])
y = y.squeeze(0)

# add noise
sigma = 0.01
y = y + sigma * torch.randn_like(y)

# define inverse problem
inverse_problem = (y, degradation_operator, sigma)

# load model
eps_net = load_epsilon_net("celebahq", n_steps, device)

# solve problem
initial_noise = torch.randn((1, 3, 256, 256), device=device)
reconstruction = dps(initial_noise, inverse_problem, eps_net)
# clamp to the valid pixel range [-1, 1] (clamp is not in-place, so reassign)
reconstruction = reconstruction.clamp(-1, 1)
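The content of ``sr16.pt`` is specific to this project and is not shown here. Purely as an illustration, and as an assumption rather than a description of that file, a $\times 16$ super-resolution operator can be modelled as averaging non-overlapping $16 \times 16$ pixel blocks and flattening the result. For a $3 \times 256 \times 256$ image this yields $3 \times 16 \times 16 = 768$ measurements, i.e. the kind of flat ``y`` that the plotting code below reshapes into an image. The class ``AvgPoolSRSketch`` is hypothetical and implements only the forward map ``H``; the real operator may expose more methods.

```python
import torch
import torch.nn.functional as F


class AvgPoolSRSketch:
    """Hypothetical x16 super-resolution operator (not the content of sr16.pt).

    H maps a batch of images (N, C, H, W) to flattened low-resolution
    measurements (N, C * H/16 * W/16) by averaging non-overlapping blocks.
    """

    def __init__(self, factor: int = 16):
        self.factor = factor

    def H(self, x: torch.Tensor) -> torch.Tensor:
        # average each factor x factor block (stride defaults to kernel size)
        low_res = F.avg_pool2d(x, kernel_size=self.factor)
        # flatten channel-major: (N, C, h, w) -> (N, C * h * w)
        return low_res.flatten(start_dim=1)
```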

Let's plot the results.

In [None]:
# plot results
import math

import matplotlib.pyplot as plt
from utils import display_image


# reshape the flattened measurement y into an image (channels, height, width)
n_channels = 3
n_pixel_per_channel = y.shape[0] // n_channels
height = width = int(math.sqrt(n_pixel_per_channel))

y_reshaped = y.reshape(n_channels, height, width)

# init figure
fig, axes = plt.subplots(1, 3)

images = (x_origin, y_reshaped, reconstruction[0])
titles = ("original", "degraded", "reconstruction")

# display figures
for ax, img, title in zip(axes, images, titles):
    display_image(img, ax)
    ax.set_title(title)

fig.tight_layout()

In [None]:
from evaluation.perception import LPIPS

lpips = LPIPS()
print(f"LPIPS: {lpips.score(reconstruction.clamp(-1, 1), x_origin)}")

# Efficiency

Here we measure the algorithm's run time and memory consumption.

For the run time, we record the wall-clock time the algorithm needs to solve a problem.
Let's take the previous example.

In [None]:
import time, torch

from utils import load_epsilon_net, load_image
from sampling.dps import dps

device = "cuda:0"
n_steps = 300
torch.set_default_device(device)


# load the image
img_path = "./material/celebahq_img/00010.jpg"
x_origin = load_image(img_path, device)


# load the degradation operator
path_operator = "./material/degradation_operators/sr16.pt"
degradation_operator = torch.load(path_operator, map_location=device)

# apply degradation operator
y = degradation_operator.H(x_origin[None])
y = y.squeeze(0)

# add noise
sigma = 0.01
y = y + sigma * torch.randn_like(y)

# define inverse problem
inverse_problem = (y, degradation_operator, sigma)

# load model
eps_net = load_epsilon_net("celebahq", n_steps, device)

# solve problem
initial_noise = torch.randn((1, 3, 256, 256), device=device)

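# CUDA kernels run asynchronously: synchronize so the timer covers the actual computation
torch.cuda.synchronize()
start_time = time.perf_counter()
reconstruction = dps(initial_noise, inverse_problem, eps_net)
torch.cuda.synchronize()
finish_time = time.perf_counter()

print(f"Elapsed time: {finish_time - start_time:.5f} s")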

As it is difficult to measure the algorithm's memory consumption directly, we estimate it by monitoring the output of the ``nvidia-smi -l`` command while the algorithm runs.
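As a complementary, PyTorch-side measurement, the CUDA caching allocator records the peak memory it has allocated on a device. The helper below (hypothetical name ``profile_run``) times a call with CUDA synchronization and reads ``torch.cuda.max_memory_allocated`` after resetting the peak statistics. This counts only tensor memory managed by PyTorch, so it is typically lower than what ``nvidia-smi`` reports, since the latter also includes the CUDA context and cached but unused blocks.

```python
import time

import torch


def profile_run(fn, *args, device=None, **kwargs):
    """Run fn(*args, **kwargs) and return (result, elapsed_seconds, peak_bytes).

    Hypothetical helper. peak_bytes is None when the target device is not CUDA.
    Note: the `device` keyword is consumed here and not forwarded to fn.
    """
    if device is None:
        use_cuda = torch.cuda.is_available()
    else:
        use_cuda = torch.device(device).type == "cuda"
    if use_cuda:
        # start from a clean peak counter and an idle device
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.synchronize(device)
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if use_cuda:
        # wait for queued kernels before stopping the clock
        torch.cuda.synchronize(device)
    elapsed = time.perf_counter() - start
    peak = torch.cuda.max_memory_allocated(device) if use_cuda else None
    return result, elapsed, peak
```

In this notebook it could be called as ``profile_run(dps, initial_noise, inverse_problem, eps_net, device=device)``, and the returned peak divided by ``2**20`` to express it in MiB.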
