<a href="https://colab.research.google.com/github/MachaBar/PROJET-LU2IN013/blob/main/tp2425/tp5_exo2_inverse_problem_with_generative_prior.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Exercise 2: Generative prior for imaging inverse problems

<br/><br/>

<a target="_blank" href="https://colab.research.google.com/github/generativemodelingmva/generativemodelingmva.github.io/blob/main/tp2425/tp5_exo2_inverse_problem_with_generative_prior.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

<br/><br/>

**Authors:**

Bruno Galerne: https://www.idpoisson.fr/galerne/

Arthur Leclaire: https://perso.telecom-paristech.fr/aleclaire/

<br/><br/>
You should complete the code regions marked with `### ... ###`


<br/><br/>

In [None]:
# from google.colab import drive
# drive.mount('/content/drive/')

# # default directory:
# %cd /content/drive/MyDrive/Colab\ Notebooks
# # we advise to create a specific directory on your Google drive:
# %cd /content/drive/MyDrive/genmod2425

In [None]:
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import torch.autograd as autograd
import os

from PIL import Image
from IPython.display import display

# Device configuration: use the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Fix the random seed so that random latent codes, measurement noise and hence
# all figures are reproducible under "Restart & Run All".
SEED = 0
torch.manual_seed(SEED)

def stretch(x):
    """Affinely rescale the values of ``x`` so that [min, max] maps to [0, 1].

    Args:
        x: tensor of any shape.

    Returns:
        A tensor with the same shape, dtype and device as ``x``. When ``x`` is
        constant (max == min) there is no range to stretch, and a tensor of
        zeros is returned.
    """
    m = torch.min(x)
    M = torch.max(x)
    if M > m:
        return (x - m) / (M - m)
    # zeros_like keeps x's device and dtype; torch.zeros(x.shape) would always
    # be a CPU tensor with the default dtype.
    return torch.zeros_like(x)

def imshow(img, unnormalize=False, zoom_factor=3, stretch_opt=False):
    """Display an image tensor (C x H x W) in the notebook and return it as a PIL image.

    Args:
        img: image tensor. It is cloned, detached and moved to the CPU, so the
            caller's tensor is left untouched.
        unnormalize: if True, map values via ``img * 0.5 + 0.5`` ([-1, 1] -> [0, 1]).
        zoom_factor: integer pixel-replication factor applied to H and W.
        stretch_opt: if True, rescale the values to [0, 1] with ``stretch``.

    Returns:
        The displayed PIL image. Gray levels are inverted (``1 - img``).
    """
    pic = img.clone().detach().to('cpu')
    pic = pic * 0.5 + 0.5 if unnormalize else pic
    pic = stretch(pic) if stretch_opt else pic
    if zoom_factor != 1:
        # Kronecker product with a block of ones = nearest-neighbour upsampling.
        block = torch.ones(1, zoom_factor, zoom_factor)
        pic = torch.kron(pic, block)
    # NOTE(review): with unnormalize=False, images in [-1, 1] (Normalize((0.5,), (0.5,))
    # data, tanh generator outputs) give 1 - img in [0, 2]. This is outside the
    # [0, 1] float range expected by to_pil_image; confirm the intended scaling.
    pil_img = torchvision.transforms.functional.to_pil_image(1 - pic)
    display(pil_img)
    return pil_img


## Introduction

In this notebook we use a generative model as an image prior to solve an imaging inverse problem.
This amounts to restricting the space of candidate images to the subset
$$
\{x = G(z),~z\in\mathbb{R}^k\} \subset \mathbb{R}^d
$$
when solving the least-squares inverse problem
$$
\min_{x} \|Ax - y \|^2.
$$

## Load pretrained generative network:

In [None]:
# load a pre-trained generative network:
# fetch the weight file GAN_G_net_ep100.pth that is loaded into G_Net below (IPython shell escape).
!wget -c 'https://www.idpoisson.fr/galerne/mva/GAN_G_net_ep100.pth'

# Generator network:
class G_Net(nn.Module):
    """Fully connected GAN generator: latent code of size k -> 1 x 28 x 28 image.

    Architecture: Linear(k, 256) -> LeakyReLU(0.2) -> Linear(256, 512)
    -> LeakyReLU(0.2) -> Linear(512, 784) -> tanh, reshaped to
    (batch_size, 1, 28, 28). The tanh keeps the output values in [-1, 1].
    The attribute names fc1/fc2/fc3 are the keys of the saved state dict,
    so they must not be renamed.
    """

    def __init__(self, k):
        super().__init__()
        self.fc1 = nn.Linear(k, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 784)

    def forward(self, x):
        slope = 0.2
        hidden = F.leaky_relu(self.fc1(x), negative_slope=slope)
        hidden = F.leaky_relu(self.fc2(hidden), negative_slope=slope)
        pixels = torch.tanh(self.fc3(hidden))
        # batch_size x channels x H x W
        return pixels.view(-1, 1, 28, 28)

def show(G, z=None, nrow=10):
    """Display a grid of images generated by ``G`` and return it as a PIL image.

    Args:
        G: generator network mapping a batch of latent codes to image tensors.
        z: optional batch of latent codes. Pass a fixed ``z`` to follow the
            same samples across calls (e.g. during an optimization). When
            omitted, 100 codes of shape (100, k) are drawn using the global
            latent size ``k``. That default only suits generators that take
            flat k-dimensional codes such as ``G_Net``.
        nrow: number of images per row of the grid (default 10).

    Returns:
        The PIL image displayed by ``imshow``.
    """
    with torch.no_grad():
        # `is None` is the idiomatic identity check; `z == None` dispatches to
        # Tensor.__eq__ when z is a tensor.
        if z is None:
            z = torch.randn(100, k).to(device)
        genimages = G(z)
        return imshow(torchvision.utils.make_grid(genimages.to('cpu'), nrow=nrow))

# initialize generator (with random weights)
k = 32  # latent dimension; must match the pre-trained checkpoint
G_net = G_Net(k).to(device)
z = torch.randn(100, k).to(device)  # fixed codes used to display the pre-trained generator

print("Generator with random init:")
show(G_net);
# weights_only=True restricts unpickling to tensors / state dicts. The file was
# downloaded from the web, and an unrestricted pickle load can execute arbitrary code.
G_net.load_state_dict(torch.load('GAN_G_net_ep100.pth', map_location=device, weights_only=True))
G_net.eval()
G_net.requires_grad_(False)  # weights stay frozen: only latent codes are optimized later
print("Pretrained generator:")
show(G_net, z=z);


## Operator of inverse problem

Implement an operator $A$ that performs crude subsampling with stride $s$ (keep one pixel out of $s$ along each axis).
Also implement a version of $A$ that applies to a batch of images.

In [None]:
def A(x, s=2):
    """Crude subsampling operator: keep one pixel out of ``s`` along each axis.

    Args:
        x: gray-level image of shape 1 x M x N. Leading dimensions are
            preserved, so a batch also works.
        s: subsampling stride (positive integer).

    Returns:
        ``x[..., ::s, ::s]``, of shape 1 x ceil(M/s) x ceil(N/s). Slicing is
        linear and differentiable, so gradients flow back to ``x``.
    """
    return x[..., ::s, ::s]

def batchA(x, s=2):
    """Batched crude subsampling: keep one pixel out of ``s`` in every image.

    Args:
        x: batch of gray-level images of shape b x 1 x M x N.
        s: subsampling stride (positive integer).

    Returns:
        Tensor of shape b x 1 x ceil(M/s) x ceil(N/s). Each image is
        subsampled exactly as the single-image operator does.
    """
    return x[:, :, ::s, ::s]

## Input data

We will consider images from the **MNIST test set** and a generative model trained as a GAN on the disjoint **MNIST training set**.

In [None]:
# input
# transformtest = transforms.ToTensor()
# ToTensor gives [0, 1]; Normalize((0.5,), (0.5,)) maps pixels to [-1, 1] like the generator output.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
datatest = datasets.MNIST('.', train=False, download=True, transform=transform)
input_idx = 0 # a seven
#input_idx = 99 # a nine
#input_idx = 3 # a zero
#input_idx = 5 # a one
#input_idx = 21 # a six
#input_idx = 1 # a two
#input_idx = 4 # a four
#input_idx = 1984 # an unusual 2, hard


x0 = datatest[input_idx][0].to(device)
_, M, N = x0.shape
d = M*N
imshow(x0);

# compute a sample of the direct model y = A x0 + w, with w white Gaussian noise
sigma = 0.05  # noise standard deviation (set to 0 for noiseless measurements)
Ax0 = A(x0)
y = Ax0 + sigma * torch.randn_like(Ax0)

imshow(y);

## Pseudo-inverse of operator

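Compute $A^+ y$ by applying gradient descent, initialized at $x = 0$, to the convex function
$$
f(x) = \|Ax - y \|^2.
$$
Every gradient $\nabla f(x) = 2A^T(Ax - y)$ lies in the range of $A^T$. The iterates therefore stay in that subspace and converge to the minimum-norm minimizer of $f$, which is precisely $A^+ y$.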

In [None]:
# Start from x = 0: the gradient 2 A^T (Ax - y) always lies in the range of A^T,
# so (momentum) gradient descent stays there and converges to the minimum-norm
# least-squares solution, i.e. A^+ y.
x = torch.zeros_like(x0, requires_grad=True)

# compute pseudo-inverse of y
optimizer = optim.SGD([x], lr = 0.01, momentum = 0.9)
niter = 1000
for it in range(niter):
    optimizer.zero_grad()
    fx = torch.sum((A(x) - y) ** 2)  # f(x) = ||Ax - y||^2
    fx.backward()
    optimizer.step()
    if fx.item() < 1e-10:
        print("Convergence reached:")
        print("iteration ", it, "fx = ", fx.item())
        imshow(x);
        break
    if it % (niter//10) == niter//10 - 1:
        print("iteration ", it, "fx = ", fx.item())
        imshow(x);

## GAN prior for solving the inverse problem
We solve for
$$
\hat x = G(\hat z)
\quad \text{with $\hat z$ solution of}\quad
\min_{z\in\mathbb{R}^{k}} \|A(G(z)) - y \|^2.
$$
Since the result strongly depends on the initialization, we optimize jointly over a batch of $b=10$ latent codes $z$.

**Exercise:**
Implement a script that:
1. Initializes an array $z = (z^0, \dots, z^{b-1})$ of $b$ random latent codes.
2. Optimize $z$ to minimize the sum
$$
\sum_{j=0}^{b-1} \|A(G(z^j)) - y \|^2
$$
using ```optim.Adam([z], lr = 0.01)``` as optimizer for ```niter = 10**4```.
3. Display the $b=10$ corresponding images $G(z^j)$ at initialization and at 10 intermediary steps as well as the iteration number and the value of the function to optimize.

In [None]:

# initialize z: a batch of b random latent codes, optimized jointly
# (b is also reused by the DCGAN experiment below)
b = 10
z = torch.randn(b, k, device=device, requires_grad=True)

optimizer = optim.Adam([z], lr = .01)

print('Solution:')
imshow(x0, False);
niter = 10**4
losslist = []

for it in range(niter):

    optimizer.zero_grad()
    # sum_j ||A(G(z^j)) - y||^2 : y (1 x m x n) broadcasts against
    # batchA(G_net(z)) (b x 1 x m x n)
    fx = torch.sum((batchA(G_net(z)) - y) ** 2)
    fx.backward()
    optimizer.step()
    losslist.append(fx.item())

    if it==0 or it%(niter//10) == niter//10-1:
        print("iteration ", it, "fx = ", fx.item())
        show(G_net,z=z)

plt.plot(losslist)
plt.xlabel('iteration')
plt.ylabel(r'$\sum_j \|A(G(z^j)) - y\|^2$')
plt.title('GAN-prior reconstruction loss');

# Repeat the experiment with a DCGAN pre-trained with WGAN-GP

In [None]:
nz = 100  # size of the latent code (input channels of the first layer)
ngf = 64  # base number of generator feature maps


class Generator(nn.Module):
    """DCGAN generator: latent code nz x 1 x 1 -> image 1 x 28 x 28 in [-1, 1].

    Four (ConvTranspose2d, BatchNorm2d, ReLU) stages grow the spatial size
    1 -> 4 -> 8 -> 16 -> 32. A final 1x1 transposed convolution with padding=2
    crops 32 x 32 to 28 x 28, followed by tanh. The layer order inside
    ``self.main`` fixes the state-dict keys (main.0, main.1, ...).
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride, padding) of each 4x4 upsampling stage
        stages = [
            (nz, ngf * 8, 1, 0),       # -> (ngf*8) x 4 x 4
            (ngf * 8, ngf * 4, 2, 1),  # -> (ngf*4) x 8 x 8
            (ngf * 4, ngf * 2, 2, 1),  # -> (ngf*2) x 16 x 16
            (ngf * 2, ngf, 2, 1),      # -> (ngf) x 32 x 32
        ]
        layers = []
        for c_in, c_out, step, pad in stages:
            layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=step, padding=pad, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        # output size. 1 x 28 x 28
        layers += [
            nn.ConvTranspose2d(ngf, 1, kernel_size=1, stride=1, padding=2, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)

# download a pre-trained DCGAN (see Practical Session 3)
G = Generator().to(device)
# map_location lets the checkpoint load on a CPU-only machine in case it was
# saved from a GPU.
G.load_state_dict(torch.hub.load_state_dict_from_url(
    'https://perso.telecom-paristech.fr/aleclaire/mva/tp/wgan_epoch100.pt',
    progress=False, map_location=device))
G.eval();  # Turn generator in evaluation mode to fix BatchNorm layers
G.requires_grad_(False)

# Batch of b latent codes (b = 10, as in the previous exercise). b is defined
# here so this cell does not depend on state created by an earlier cell.
b = 10
zt = torch.randn(b, nz, 1, 1).to(device)
show(G, zt);


In [None]:
### ... ###

## Adjust the exercise for simple denoising ($A= \mathsf{Id}$)

Note that this is equivalent to plain GAN inversion: we look for $z_*$ such that $G(z_*)$ is as close as possible to the observed image.

On a batch of sampled images $x$, compute the average value of $\|G(z_*) - x\|$.

In [None]:
### ... ###

## Repeat the exercise with the operator A defined in the next cell

We will now define a new operator $A$ that sums the gray levels of the image along every row, every column, every diagonal and every anti-diagonal.

In [None]:
# We suppose that:
#  - the input x is a square gray-level image of size 1xMxN with M=N
#  - the output y is 1D tensor of size number of measurements m

def sum_all_diagonal_matrix(mat: torch.tensor):
    """Sum every diagonal of a square matrix.

    Adapted from: https://stackoverflow.com/questions/57347896/sum-all-diagonals-in-feature-maps-in-parallel-in-pytorch

    Args:
        mat: square 2D tensor of shape n x n.

    Returns:
        1D tensor of length 2n-1. Entry i is the sum of the diagonal with offset
        i - (n-1) (column index minus row index), i.e. diagonals are ordered
        from the bottom-left corner to the top-right corner. dtype and device
        follow ``mat``.

    Raises:
        ValueError: if ``mat`` is not square. The strided view below assumes
            rows of length 3n and would otherwise return wrong sums.
    """
    n, n_cols = mat.shape
    if n != n_cols:
        raise ValueError(f"expected a square matrix, got shape {tuple(mat.shape)}")
    # Zero padding with mat's dtype/device, so torch.cat does not mix dtypes.
    zero_mat = torch.zeros((n, n), dtype=mat.dtype, device=mat.device)
    mat_padded = torch.cat((zero_mat, mat, zero_mat), 1)  # n x 3n, contiguous
    # Row stride 3n+1 shifts each row one column further than the previous one:
    # mat_strided[i, j] = mat_padded[i, i + j] = mat[i, i + j - n] (offset j - n).
    mat_strided = mat_padded.as_strided((n, 2*n), (3*n + 1, 1))
    sum_diags = torch.sum(mat_strided, 0)
    # Column 0 (offset -n) only ever contains padding: drop it.
    return sum_diags[1:]


def axial_and_diagonal_sum(x):
    """Sum a square gray-level image along rows, columns, diagonals and anti-diagonals.

    Args:
        x: image tensor of shape 1 x N x N.

    Returns:
        1D tensor of length 6N - 2, the concatenation of
        - yhori (N): sum of each row,
        - yvert (N): sum of each column,
        - ydiag (2N-1): diagonal sums, offsets -(N-1) .. N-1,
        - y_anti_diag (2N-1): sums over {i + j = s} for s = 0 .. 2N-2.
    """
    _, M, N = x.shape
    xmat = x.reshape(M, N)
    yhori = torch.sum(x, dim=2).flatten()
    yvert = torch.sum(x, dim=1).flatten()
    ydiag = sum_all_diagonal_matrix(xmat).flatten()
    # Flipping the columns turns anti-diagonals into diagonals; the final flip
    # orders them by increasing i + j.
    y_anti_diag = sum_all_diagonal_matrix(xmat.flip(1)).flatten()
    return torch.cat((yhori, yvert, ydiag, y_anti_diag.flip(0)))

# test of axial_and_diagonal_sum(x)
t = torch.diag(1+torch.arange(4)).unsqueeze(0)
print("Test of axial_and_diagonal_sum(x):")
print("Input:", t)
print("Output:", axial_and_diagonal_sum(t))
# Expected: 4 row sums, 4 column sums, 7 diagonal sums (only the main diagonal
# is non-zero) and 7 anti-diagonal sums ordered by increasing i + j.
assert axial_and_diagonal_sum(t).tolist() == [1, 2, 3, 4, 1, 2, 3, 4,
                                               0, 0, 0, 10, 0, 0, 0,
                                               1, 0, 2, 0, 3, 0, 4]

opA = axial_and_diagonal_sum

    def batchopA(x):
    # apply opA to each image of a batch and return a tensor:
    listAx = []
    for bidx in range(x.shape[0]):
        listAx.append(opA(x[bidx,:,:]))
    Ax = torch.stack(listAx)
    return(Ax)


### ... Try y = Ax ... ##

# print('Plot of y: (yhori, yvert, ydiag, y_anti_diag)')
# plt.figure(figsize=(20,4))
# plt.bar(range(y.numel()), y.to('cpu').numpy());
