 # Neural Denoisers



 This notebook reproduces some of the results presented in the paper [Deep Image Prior, Ulyanov et al. 2018](https://arxiv.org/abs/1711.10925).
 

 We will take an image $x_0$ from the internet and use a neural network (a UNet) to solve

 $$
     \hat{\theta} \in \operatorname*{arg\,min}_{\theta} \|D_{\theta}(x_{\sigma}) - x_\sigma \|^2 \;,
 $$

 where $D_{\theta}$ is a convolutional neural network and $x_\sigma = x_0 + \sigma z$, with $z$ a realization of standard Gaussian noise.

 More concretely, we will track the quantities $L_0(\theta_t) = \|D_{\theta_t}(x_{in}) - x_0 \|^2$ and $L_{\sigma}(\theta_t) = \|D_{\theta_t}(x_{in}) - x_\sigma \|^2$ along the optimization trajectory $\theta_{1:n}$. For the input $x_{in}$, we can choose either $x_\sigma$ or a random input that stays fixed during training.
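To make the two tracked quantities concrete, here is a minimal sketch on synthetic data. The names `tracked_losses` and `toy_net`, the random "image", and the single convolution standing in for $D_\theta$ are all assumptions made for illustration; they are not the network or the image used in this notebook. The sketch uses the sum of squared errors, as in the formulas above, whereas the training cell further down uses the mean, which only differs by a constant factor.

```python
import torch

torch.manual_seed(0)


def tracked_losses(out, x0, x_sigma):
    """Return L_0 and L_sigma (squared L2 distances) for a network output `out`."""
    return {
        "L_0": ((out - x0) ** 2).sum().item(),
        "L_sigma": ((out - x_sigma) ** 2).sum().item(),
    }


# Hypothetical stand-ins for the real image and network.
x0 = torch.rand(3, 32, 32) * 2 - 1  # "clean image" with values in [-1, 1]
sigma = 0.5
z = torch.randn_like(x0)
x_sigma = x0 + sigma * z  # noisy observation
toy_net = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)  # placeholder for D_theta

x_in = x_sigma[None]  # or a fixed random tensor of the same shape
with torch.no_grad():
    out = toy_net(x_in)[0]
losses = tracked_losses(out, x0, x_sigma)

# If the network reproduced the noisy target exactly, L_sigma would be 0
# while L_0 would equal sigma^2 * ||z||^2.
perfect_fit = tracked_losses(x_sigma, x0, x_sigma)
```

The last two lines illustrate why both curves are worth following: driving $L_\sigma$ all the way to zero means the output has also fitted the noise, so $L_0$ cannot go below $\sigma^2 \|z\|^2$ at that point.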

 ### Quick PyTorch Tutorial

 This is a **really fast** PyTorch tutorial that introduces the basic PyTorch functions that you will need.

 1. **The basic neural network class**: A neural network is an instance of a class that inherits from the basic `torch.nn.Module` class. You can see how it is built in the class `UNetWithoutSkip` below. The important thing for us is that it contains a bunch of **parameters**, which are the quantities that we are going to optimize. It also implements a `forward` function, which is responsible for taking the input and producing the output.

 2. **Calculating gradients with respect to the network weights**: The first thing we do is ask torch to "track" how the parameters of the network are used (composed) in a given computation. This is done by setting `mymodel.requires_grad_(True)`. Then, we can compute any quantity that uses `mymodel` and calculate the gradient of this quantity with respect to the parameters. Below is a dumb example:

    ```

    x_out = mymodel(x_in)

    loss = (x_out**2).sum()

    loss.backward()

    ```

 The code above calculates the gradient of the loss (the sum of the squared outputs) with respect to the parameters of `mymodel`.

 3. **Optimizing**: While it is possible to write gradient descent manually, by iterating through each element of `mymodel.parameters()` and updating it using its `.grad` attribute, the classes from `torch.optim` handle all this for us. There are several available options, but we will use the Adam optimizer.

 To do so, we first need to instantiate it with the parameters we want to optimize:

     ```optimizer = torch.optim.Adam(mymodel.parameters(), lr=1e-2)```

 Then, two methods are really useful:

 ```optimizer.zero_grad()``` -> Clears the stored gradients, so we can start a new gradient computation

 ```optimizer.step()``` -> Updates the parameters using the gradients that have been calculated (by YOU, through `loss.backward()`!)
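The three steps of the tutorial can be put together in a short, self-contained sketch. It is only an illustration under stated assumptions: a small `torch.nn.Linear` layer stands in for `mymodel`, the input is random, and `dumb_loss` is a hypothetical helper wrapping the same toy loss as above. It shows both the `torch.optim` route and one manual gradient descent step using `.grad`.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy setup: a linear layer stands in for `mymodel`.
mymodel = torch.nn.Linear(4, 2)
x_in = torch.randn(8, 4)


def dumb_loss(model):
    # Same toy loss as above: sum of the squared outputs.
    return (model(x_in) ** 2).sum()


# Option 1: let torch.optim handle the update.
optimizer = torch.optim.Adam(mymodel.parameters(), lr=1e-2)
adam_losses = []
for step in range(200):
    optimizer.zero_grad()  # clear gradients left over from the previous step
    loss = dumb_loss(mymodel)
    loss.backward()  # fills p.grad for every parameter p
    optimizer.step()  # Adam update based on p.grad
    adam_losses.append(loss.item())

# Option 2: one step of plain gradient descent written by hand.
manual_model = torch.nn.Linear(4, 2)
loss_before = dumb_loss(manual_model)
loss_before.backward()
with torch.no_grad():  # the update itself must not be tracked
    for p in manual_model.parameters():
        p -= 1e-3 * p.grad
        p.grad = None  # manual equivalent of zero_grad()
loss_after = dumb_loss(manual_model)
```

In both cases the loss should decrease; the optimizer classes mainly save us from writing the update rule (and, for Adam, its running statistics) ourselves.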

In [None]:
from torchvision.transforms import (
    ToPILImage,
    ToTensor,
    Normalize,
    Compose,
)
import numpy as np
from tqdm import tqdm
import PIL
import requests
import matplotlib.pyplot as plt
import torch
import platform
from matplotlib import animation
from matplotlib import rc
from typing import Callable
from functools import partial
import os
import pickle

rc("animation", html="jshtml")


def get_device():
    # Check if CUDA is available
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        print(f"Using {device} for computation.")
        return device
    elif torch.backends.mps.is_available():
        # Apple Silicon devices (macOS on M1, etc.). Note that platform.system()
        # returns "Darwin" on macOS, so a string check for "Apple" never matches.
        device = torch.device("mps")  # Use Apple Metal Performance Shaders
        print(f"Using {device} for computation.")
        return device
    else:
        device = torch.device("cpu")
        print("Using CPU for computation.")
        return device


# Get the device
device = get_device()


def image_grid(imgs, rows, cols):

    w, h = imgs[0].size
    grid = PIL.Image.new("RGB", size=(cols * w, rows * h))
    grid_w, grid_h = grid.size

    for i, (img, _) in enumerate(zip(imgs, range(cols * rows))):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid



In [None]:

torch.manual_seed(42)
url = "https://cdn.pixabay.com/photo/2024/08/15/19/19/highland-cow-8972000_960_720.jpg"
img = PIL.Image.open(requests.get(url, stream=True).raw).crop((194, 88, 706, 600))

print(img.size)
img

In [None]:
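# This cell defines the transforms from a PIL image to [-1, 1] data and back
img2data = Compose([ToTensor(), Normalize([0.5], [0.5])])
# Clamp to [0, 1] before ToPILImage so that out-of-range values (e.g. after adding
# noise) saturate instead of overflowing in the conversion to 8-bit integers
data2img = Compose([Normalize([-1.0], [2.0]), lambda t: t.clamp(0, 1), ToPILImage()])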

img_data = img2data(img)  # Transform img to [-1, 1]

In [None]:
# This cell creates a noisy version
sigma = 0.5
noise = torch.randn_like(img_data)
noised_img_data = img_data + sigma * noise
noised_img = data2img(noised_img_data)
noised_img



In [None]:
# Network related stuff, taken and adapted from https://github.com/NVlabs/edm2/tree/main
def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    value = np.asarray(value)
    if shape is not None:
        shape = tuple(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device("cpu")
    if memory_format is None:
        memory_format = torch.contiguous_format

    tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
    if shape is not None:
        tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
    tensor = tensor.contiguous(memory_format=memory_format)

    return tensor


def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None):
    if dtype is None:
        dtype = ref.dtype
    if device is None:
        device = ref.device
    return constant(
        value, shape=shape, dtype=dtype, device=device, memory_format=memory_format
    )


def normalize(x, dim=None, tol=1e-4):
    if dim is None:
        dim = list(range(1, x.ndim))
    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
    norm = norm + tol * (norm.numel() / x.numel()) ** 0.5
    return x  # / norm Deactivate the normalization


def resample(x, f=[1, 1], mode="keep"):
    if mode == "keep":
        return x
    f = np.float32(f)
    assert f.ndim == 1 and len(f) % 2 == 0
    pad = (len(f) - 1) // 2
    f = f / f.sum()
    f = np.outer(f, f)[np.newaxis, np.newaxis, :, :]
    f = const_like(x, f)
    c = x.shape[1]
    if mode == "down":
        return torch.nn.functional.conv2d(
            x, f.tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,)
        )
    assert mode == "up"
    return torch.nn.functional.conv_transpose2d(
        x, (f * 4).tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,)
    )


def mp_silu(x):
    return torch.nn.functional.silu(x) / 0.596


def mp_sum(a, b, t=0.5):
    return a.lerp(b, t) / np.sqrt((1 - t) ** 2 + t**2)


class MPConv(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel):
        super().__init__()
        self.out_channels = out_channels
        self.weight = torch.nn.Parameter(
            torch.randn(out_channels, in_channels, *kernel)
        )

    def forward(self, x, gain=1):
        w = self.weight.to(torch.float32)
        if self.training:
            with torch.no_grad():
                self.weight.copy_(normalize(w))  # forced weight normalization
        w = normalize(w)  # traditional weight normalization
        w = w * (gain / np.sqrt(w[0].numel()))  # magnitude-preserving scaling
        w = w.to(x.dtype)
        if w.ndim == 2:
            return x @ w.t()
        assert w.ndim == 4
        return torch.nn.functional.conv2d(x, w, padding=(w.shape[-1] // 2,))


class Block(torch.nn.Module):
    def __init__(
        self,
        in_channels,  # Number of input channels.
        out_channels,  # Number of output channels.
        flavor="enc",  # Flavor: 'enc' or 'dec'.
        resample_mode="keep",  # Resampling: 'keep', 'up', or 'down'.
        resample_filter=[1, 1],  # Resampling filter.
        dropout=0,  # Dropout probability.
        res_balance=0.3,  # Balance between main branch (0) and residual branch (1).
        clip_act=256,  # Clip output activations. None = do not clip.
    ):
        super().__init__()
        self.out_channels = out_channels
        self.flavor = flavor
        self.resample_filter = resample_filter
        self.resample_mode = resample_mode
        self.dropout = dropout
        self.res_balance = res_balance
        self.clip_act = clip_act
        self.emb_gain = torch.nn.Parameter(torch.zeros([]))
        self.conv_res0 = MPConv(
            out_channels if flavor == "enc" else in_channels,
            out_channels,
            kernel=[3, 3],
        )
        self.conv_res1 = MPConv(out_channels, out_channels, kernel=[3, 3])
        self.conv_skip = (
            MPConv(in_channels, out_channels, kernel=[1, 1])
            if in_channels != out_channels
            else None
        )

    def forward(self, x):
        # Main branch.
        x = resample(x, f=self.resample_filter, mode=self.resample_mode)
        if self.flavor == "enc":
            if self.conv_skip is not None:
                x = self.conv_skip(x)
            x = normalize(x, dim=1)  # pixel norm

        # Residual branch.
        y = self.conv_res0(mp_silu(x))
        if self.training and self.dropout != 0:
            y = torch.nn.functional.dropout(y, p=self.dropout)
        y = self.conv_res1(y)

        # Connect the branches.
        if self.flavor == "dec" and self.conv_skip is not None:
            x = self.conv_skip(x)
        x = mp_sum(x, y, t=self.res_balance)

        # Clip activations.
        if self.clip_act is not None:
            x = x.clip_(-self.clip_act, self.clip_act)
        return x


class UNetWithoutSkip(torch.nn.Module):
    def __init__(
        self,
        img_resolution,  # Image resolution.
        img_channels,  # Image channels.
        model_channels=192,  # Base multiplier for the number of channels.
        channel_mult=[
            1,
            2,
            3,
            4,
        ],  # Per-resolution multipliers for the number of channels.
        num_blocks=3,  # Number of residual blocks per resolution.
        **block_kwargs,  # Arguments for Block.
    ):
        super().__init__()
        cblock = [model_channels * x for x in channel_mult]
        self.out_gain = torch.nn.Parameter(torch.zeros([]))

        # Encoder.
        self.enc = torch.nn.ModuleDict()
        cout = img_channels + 1
        for level, channels in enumerate(cblock):
            res = img_resolution >> level
            if level == 0:
                cin = cout
                cout = channels
                self.enc[f"{res}x{res}_conv"] = MPConv(cin, cout, kernel=[3, 3])
            else:
                self.enc[f"{res}x{res}_down"] = Block(
                    cout, cout, flavor="enc", resample_mode="down", **block_kwargs
                )
            for idx in range(num_blocks):
                cin = cout
                cout = channels
                self.enc[f"{res}x{res}_block{idx}"] = Block(
                    cin, cout, flavor="enc", **block_kwargs
                )

        # Decoder.
        self.dec = torch.nn.ModuleDict()
        for level, channels in reversed(list(enumerate(cblock))):
            res = img_resolution >> level
            if level == len(cblock) - 1:
                self.dec[f"{res}x{res}_in0"] = Block(
                    cout, cout, flavor="dec", **block_kwargs
                )
                self.dec[f"{res}x{res}_in1"] = Block(
                    cout, cout, flavor="dec", **block_kwargs
                )
            else:
                self.dec[f"{res}x{res}_up"] = Block(
                    cout, cout, flavor="dec", resample_mode="up", **block_kwargs
                )
            for idx in range(num_blocks + 1):
                cin = cout
                cout = channels
                self.dec[f"{res}x{res}_block{idx}"] = Block(
                    cin, cout, flavor="dec", **block_kwargs
                )
        self.out_conv = MPConv(cout, img_channels, kernel=[3, 3])

    def forward(self, x):
        # Encoder.
        x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1)
        for name, block in self.enc.items():
            x = block(x)

        # Decoder.
        for name, block in self.dec.items():
            x = block(x)
        x = self.out_conv(x, gain=self.out_gain)
        return x



In [None]:
unet = torch.compile(
    UNetWithoutSkip(
        img_resolution=512,
        img_channels=3,
        model_channels=16,
        channel_mult=[1, 1, 2],
        num_blocks=2,
    )
).to(device)
unet = unet.requires_grad_(True)


In [None]:
N_steps = 4000
lr = 1e-2
stats = {"loss2real": [], "loss2noised": [], "recImage": []}
optimizer = torch.optim.Adam(unet.parameters(), lr=lr)


In [None]:
in_data = noised_img_data[None].to(device)
noised_img_data = noised_img_data.to(device)
img_data = img_data.to(device)
pbar = tqdm(range(N_steps))
for step in pbar:
    # Implement your solution here
    optimizer.zero_grad()
    pred = unet(in_data)[0]
    loss2noised = ((noised_img_data - pred) ** 2).mean()
    loss2noised.backward()
    optimizer.step()
    with torch.no_grad():
        loss2real = ((img_data - pred) ** 2).mean()
    if step % 10 == 0:
        stats["loss2real"].append(loss2real.item())
        stats["loss2noised"].append(loss2noised.item())
        # Detach from the autograd graph and move to CPU before converting to an image
        stats["recImage"].append(data2img(pred.detach().cpu()))
        pbar.set_postfix(
            {"loss noisy": loss2noised.item(), "loss clean": loss2real.item()}
        )


In [None]:
fig, ax = plt.subplots(1, 1, figsize=(3, 3))

ax.plot(stats["loss2real"], label=r"$L_{0}(\theta_t)$")
ax.plot(stats["loss2noised"], label=r"$L_{\sigma}(\theta_t)$")
ax.set_xlabel("t")
ax.legend()
fig.show()

fig, axes = plt.subplots(1, 2, figsize=(10, 10))
axes[0].imshow(image_grid(stats["recImage"][-64:], rows=8, cols=8))
axes[1].imshow(
    image_grid(
        [
            img,
            noised_img,
            stats["recImage"][np.argmin(stats["loss2real"])],
            stats["recImage"][np.argmin(stats["loss2noised"])],
        ],
        rows=2,
        cols=2,
    )
)
for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])

fig.tight_layout()
fig.show()



# Langevin sampling with denoisers

## Implementing a Langevin sampler for a given noise level
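Before using the pretrained network, it can help to sanity check the Langevin update on a case where everything is known in closed form. The sketch below is an illustration under stated assumptions, not the solution expected in the cell further down: `gaussian_denoiser` and `langevin_toy` are hypothetical names, and the data are assumed to follow $\mathcal{N}(0, \sigma_{data}^2 I)$. In that case the optimal denoiser is $D(x, \sigma) = \frac{\sigma_{data}^2}{\sigma_{data}^2 + \sigma^2} x$, the score obtained from Tweedie's formula $\nabla \log p_\sigma(x) = (D(x, \sigma) - x) / \sigma^2$ equals $-x / (\sigma_{data}^2 + \sigma^2)$, and the chain should settle on samples with standard deviation close to $\sqrt{\sigma_{data}^2 + \sigma^2}$.

```python
import torch

torch.manual_seed(0)

sigma_data, sigma = 0.5, 2.0  # same values as used in this notebook
var_noisy = sigma_data**2 + sigma**2


def gaussian_denoiser(x, sigma):
    # Posterior mean E[x_0 | x_sigma = x] when x_0 ~ N(0, sigma_data^2 I).
    return x * sigma_data**2 / (sigma_data**2 + sigma**2)


def langevin_toy(x, denoiser, sigma, n_steps, lr):
    # Unadjusted Langevin: x <- x + lr * score(x) + sqrt(2 * lr) * noise
    for _ in range(n_steps):
        score = (denoiser(x, sigma) - x) / sigma**2  # Tweedie's formula
        x = x + lr * score + (2 * lr) ** 0.5 * torch.randn_like(x)
    return x


# Start far from the target (all zeros) so that the chain has to do the work.
x = langevin_toy(
    torch.zeros(4096, 16), gaussian_denoiser, sigma, n_steps=500, lr=0.1
)
empirical_std = x.std().item()
target_std = var_noisy**0.5
```

Because the update is not corrected by an accept/reject step, the empirical standard deviation is expected to be slightly larger than the target; the gap shrinks as `lr` decreases.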

In [None]:
def denoiser_fn(x, sigma, class_labels):
    # Evaluate the pretrained denoiser D(x, sigma) for the given class labels
    with torch.no_grad():
        pred = net(
            x.to(device),
            torch.ones(x.shape[0], device=device) * sigma,
            class_labels=torch.nn.functional.one_hot(
                class_labels.long().to(device), 1000
            ),
        ).cpu()
    return pred.clamp(-1, 1)


def load_karras_net():
    # dnnlib is imported from the local "edm2" checkout, so we temporarily move
    # there and always restore the working directory, even if loading fails
    curr_dir = os.getcwd()
    os.chdir("edm2")
    try:
        import dnnlib

        path_to_model = 'https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions/edm2-img64-s-1073741-0.075.pkl'
        with dnnlib.util.open_url(path_to_model, verbose=True) as f:
            data = pickle.load(f)
        net = data['ema'].to(device)
    finally:
        os.chdir(curr_dir)
    return torch.compile(net)


net = load_karras_net()
sigma_data = 0.5

def run_langevin(
    initial_samples: torch.Tensor,
    denoiser_fn: Callable[[torch.Tensor, float], torch.Tensor],
    sigma: float,
    n_steps: int,
    lr: float,
) -> torch.Tensor:
    """
    Run the unadjusted Langevin algorithm at a fixed noise level sigma.

    initial_samples: Initialization for the Langevin algorithm.
    denoiser_fn: Function D(x, sigma), which gives the denoising for a given sigma.
    sigma: Amount of noise.
    n_steps: The number of steps to run Langevin.
    lr: The Langevin learning rate (step size).
    """
    # implement your code here
    samples = initial_samples.clone()
    pbar = tqdm(range(n_steps), desc=f"Langevin {sigma:.2f}")
    for i in pbar:
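        # Score of the noised density, obtained from the denoiser (Tweedie's formula)
        score = (denoiser_fn(samples, sigma) - samples) / (sigma**2)
        drift = score
        samples = samples + lr * drift + ((2 * lr) ** 0.5) * torch.randn_like(samples)
        if i % 10 == 0:  # refresh the progress bar every 10 steps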
            pbar.set_postfix(
                {
                    "score norm": torch.linalg.vector_norm(score[0]).item(),
                    "dinit": torch.linalg.vector_norm(
                        samples[0] - initial_samples[0]
                    ).item(),
                }
            )
    return samples



In [None]:
sigma = 2
batch_size = 2
n_langevin = 100
lr = 1e-1

In [None]:
initial_samples = torch.randn((batch_size, 3, 64, 64)) * (
    (sigma**2 + sigma_data**2) ** 0.5
)
# Use the same class index (240) for every sample in the batch
class_labels = torch.full((batch_size,), 240)
samples = initial_samples.clone()

denoiser_fixed_class_fn = partial(denoiser_fn, class_labels=class_labels)

In [None]:
samples = run_langevin(
    initial_samples=samples,
    denoiser_fn=denoiser_fixed_class_fn,
    sigma=sigma,
    lr=lr,
    n_steps=n_langevin,
)

In [None]:
fig, ax = plt.subplots(1, 1)
ax.imshow(
    image_grid(
        [
            data2img(i)
            for i in denoiser_fn(
                torch.cat((initial_samples, samples)), sigma, class_labels.repeat(2)
            )
        ],
        rows=2,
        cols=2,
    )
)
fig.show()