# Part III: Data-Driven Approaches

In this tutorial, we explore data-driven methods for image reconstruction. We consider the limited-angle CT problem with additional noise on the sinogram. Our test images are composed of simple shapes that can be generated with the ```utils``` module. Namely, the shapes are random weighted norm balls in $\ell^p$ norms, i.e., we consider sets

$$ \left\{x\in\mathbb{R}^2: \left(\sum_{i=1}^2 w_i |x_i - m_i|^p\right)^{1/p} < r\right\} $$

where $r>0$, $w\in\mathbb{R}^2$ with positive entries $w_i>0$, and $m\in\mathbb{R}^2$ are sampled randomly.
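A minimal sketch of how such a set can be rasterized into an image mask, assuming pixel centres on a uniform grid over $[-1,1]^2$ and a finite $p$. The helper ```weighted_norm_ball_mask``` is hypothetical; the actual ```random_weighted_norm``` in ```utils``` may use other coordinates or sampling ranges.

```python
import numpy as np

def weighted_norm_ball_mask(n, p, w, m, r):
    """Hypothetical helper: boolean n x n mask of a weighted l^p ball (finite p only).

    Assumes pixel centres on a uniform grid over [-1, 1]^2.
    """
    t = np.linspace(-1.0, 1.0, n)
    x1, x2 = np.meshgrid(t, t, indexing="ij")
    # weighted l^p distance of every pixel centre to the midpoint m
    dist = (w[0] * np.abs(x1 - m[0]) ** p + w[1] * np.abs(x2 - m[1]) ** p) ** (1.0 / p)
    return dist < r

# usage: random weights, midpoints and radii for a few values of p
rng = np.random.default_rng(0)
for p in (1, 2, 4):
    mask = weighted_norm_ball_mask(64, p,
                                   w=rng.uniform(0.5, 2.0, size=2),
                                   m=rng.uniform(-0.3, 0.3, size=2),
                                   r=rng.uniform(0.3, 0.6))
    print(p, mask.sum())
```

For unit weights, $\|z\|_q \le \|z\|_p$ whenever $p \le q$, so the balls grow with $p$: a diamond at $p=1$, a disc at $p=2$, approaching a square as $p\to\infty$.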

Let's look at the shapes this produces for $p = 1, 2, 4, 8, 16$.

In [None]:
import matplotlib.pyplot as plt
from utils import random_weighted_norm

IMG_SIZE = 64 # image size is set globally
IMG_KWARGS = {'vmin':0., 'vmax':1., 'cmap':'bone'} # kwargs for plotting
rwn = random_weighted_norm(img_size=IMG_SIZE)

# plot all shapes
P = 5
fig, ax = plt.subplots(1, P, figsize=(20,15))

for i in range(P):
    ax[i].imshow(rwn(p=2**i)[0], **IMG_KWARGS)
    ax[i].set_title('p=' + str(2**i))

## Supervised learning with a post-processing approach

We are interested in the limited-angle CT problem. Let $R$ denote the limited-angle Radon operator. We want to solve the inverse problem
$$ y = Rx + \epsilon,$$
where $\epsilon$ denotes the noise on the sinogram.

<div class="alert alert-block alert-info">
<b>Question:</b> How do we incorporate data into our reconstruction method?
</div>

We first focus on so-called supervised learning, where we are given a data set 

$$\mathcal{T} = \{(\text{inp}_1,\text{oup}_1), \ldots, (\text{inp}_T, \text{oup}_T)\}$$

of $T$ input-output pairs. Furthermore, we use a **post-processing** approach. This means that we have a mapping $f:X\to X$, and the reconstruction is defined as

$$x = f(R^\dagger y)$$

where $R^\dagger$ denotes the pseudo-inverse or **naive inversion** as explored in Tutorial_01. As we already saw, this inversion can lead to unfavorable results, so we hope that the mapping $f$ removes the artifacts and recovers a better solution. In this setting, the input-output pairs consist of 

* $\text{inp}_i$: the input, in our case the noisy sinogram data of the $i$th ground truth image,
* $\text{oup}_i$: the output in our case is the $i$th ground truth image.


### Some technical details

For the learning setup, we now switch the array/tensor backend from ```numpy``` to ```torch```. Furthermore, images now have additional dimensions. Namely, the images $x$ have shape

$$
B \times C \times N\times N
$$

where

* $B$ denotes the batch size, i.e., how many images we process simultaneously in one tensor,
* $C$ denotes the number of channels, we always use $C=1$, but for RGB it would be $C=3$,
* $N$ the image dimension, in our case given by the global variable ```IMG_SIZE```.

The following commands are relevant when we switch between ```numpy``` and ```torch```:

* ```torch.tensor(x)```: converts a ```numpy``` array to a ```torch``` tensor.
* ```x.numpy()```: converts a ```torch``` tensor that does not require gradients to a ```numpy``` array. Commands like ```plt.imshow``` do this internally.
* ```x.detach()```: if the tensor ```x``` requires gradients (e.g., because it is the output of the network), we need to detach it from the computation graph first.
* ```x[b,c,...]```: ```plt.imshow``` only works for 2D arrays, so we have to select one batch element $b$ and a channel $c$.
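A small sketch of these conversions on a random stand-in batch (the array ```imgs_np``` is hypothetical and only serves to show the shapes):

```python
import numpy as np
import torch

# hypothetical batch of B grayscale N x N images stored as a numpy array
B, N = 4, 64
imgs_np = np.random.rand(B, N, N)

# numpy -> torch, then insert the channel axis (C = 1) to obtain B x C x N x N
x = torch.tensor(imgs_np, dtype=torch.float32)[:, None, ...]
print(x.shape)  # torch.Size([4, 1, 64, 64])

# a tensor that depends on a parameter with requires_grad=True is part of the autograd graph,
# so calling .numpy() on it directly would raise an error; detach it first
y = x * torch.ones(1, requires_grad=True)
img = y.detach()[0, 0, ...].numpy()  # batch element 0, channel 0 -> N x N array
print(img.shape)  # (64, 64)
```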

## The forward operator

As already mentioned, our forward operator is the limited-angle Radon transform. Here, we reuse the same functionality as before and integrate it into the ```torch``` framework.

Here's a short discussion of this approach:

* Disadvantages:
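    * Not a clean native ```torch``` solution: conversions between ```numpy``` and ```torch``` yield additional overhead. However, since we do not intend to work on the GPU, this is acceptable. The conversion also cuts the operator out of the ```autograd``` graph, so no gradients flow through $R$; the post-processing approach does not need them, since $R$ and $R^\dagger$ are only applied before the network.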
    * The underlying Radon functions are called from ```skimage```, which do not support batched input. In our case, we simply wrap a for-loop around the batch dimension. Performance-wise this is really bad, since such loops in Python are slow :(

* Advantages:
    * We can simply reuse the code from before without changing much :) It's a quick solution, at the price of slower code.

In [None]:
import torch
import numpy as np
from tqdm.notebook import trange
from operators import Radon

NUM_THETAS = 10
ANGLES = (0, 90)
theta = np.linspace(ANGLES[0], ANGLES[1], endpoint=False, num=NUM_THETAS)

class Radon_torch:
    """Wraps the numpy Radon operator for batched torch tensors (no autograd support)."""
    def __init__(self, **kwargs):
        self.R = Radon(**kwargs)

    def __call__(self, x, pbar=False):
        # x: B x 1 x N x N image tensor -> B x N x num_theta sinogram tensor
        x = x.detach().numpy()
        k = torch.zeros((x.shape[0], x.shape[2], self.R.num_theta))
        for i in (trange(x.shape[0]) if pbar else range(x.shape[0])):
            k[i] = torch.tensor(self.R(x[i, 0, ...]))
        return k

    def inverse(self, k, pbar=False):
        # k: B x N x num_theta sinogram tensor -> B x 1 x N x N image tensor
        k = k.detach().numpy()
        x = torch.zeros((k.shape[0], 1, k.shape[1], k.shape[1]))
        for i in (trange(k.shape[0]) if pbar else range(k.shape[0])):
            x[i, 0] = torch.tensor(self.R.inverse(k[i]))
        return x

R = Radon_torch(theta=theta)
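To see what "limited angle" means for the settings above, the following sketch recomputes the angle grid: with ```endpoint=False```, the 10 angles are spaced 9° apart and only cover $[0°, 90°)$, i.e., half of the 180° range that parallel-beam CT needs for full angular coverage.

```python
import numpy as np

NUM_THETAS = 10
ANGLES = (0, 90)
theta = np.linspace(ANGLES[0], ANGLES[1], endpoint=False, num=NUM_THETAS)  # 0, 9, ..., 81 degrees

step = theta[1] - theta[0]             # spacing between neighbouring projection angles
coverage = NUM_THETAS * step / 180     # fraction of the 180-degree range needed for full-angle CT
print(step, coverage)
```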

## Create Dataset

Now we create our own dataset. As mentioned before, it consists of input-output pairs

$$(Rx + \epsilon, x)$$

for randomly sampled images $x$. For convenience, we store this data in a ```torch.utils.data.TensorDataset``` and create a loader with ```torch.utils.data.DataLoader```. The loader later lets us iterate over the data in shuffled mini-batches.
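A minimal sketch of what such a loader yields, using random stand-in tensors with the shapes used in this notebook (64 detector pixels, 10 angles); the tensors themselves are hypothetical:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# hypothetical stand-ins: 20 sinograms of size 64 x 10 and 20 images of size 1 x 64 x 64
sinos = torch.randn(20, 64, 10)
imgs = torch.rand(20, 1, 64, 64)

loader = DataLoader(TensorDataset(sinos, imgs), batch_size=5, shuffle=True)

for data, x in loader:
    # each iteration yields one batch of matching (input, target) pairs
    print(data.shape, x.shape)  # torch.Size([5, 64, 10]) torch.Size([5, 1, 64, 64])
print(len(loader))  # 4 batches per epoch
```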

In [None]:
from tqdm.notebook import trange, tqdm

train_p = 2
num_images = 1000
noise_lvl = 0.02
# ground truth images with a channel axis: num_images x 1 x N x N
x = torch.tensor(rwn(p=train_p, B=num_images), dtype=torch.float)[:, None, ...]
# noisy sinograms: num_images x N x NUM_THETAS
sinograms = R(x, pbar=True)
data = sinograms + noise_lvl * torch.randn_like(sinograms)

train_dataset = torch.utils.data.TensorDataset(data, x)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=5, shuffle=True)

## Let's look at our Data!

To get data from the loader, we first turn it into an iterator with ```iter(train_loader)```. We then obtain the first batch with ```next(iter(train_loader))```.

In [None]:
fig, ax = plt.subplots(1, 3, figsize=(15, 8))
data, x = next(iter(train_loader))

# aspect='auto' stretches the narrow N x NUM_THETAS sinogram to a readable size
for i, (z, title, kwargs) in enumerate([(data[0, ...], 'Noisy data', {'cmap': 'bone', 'aspect': 'auto'}),
                                        (R.inverse(data)[0, 0, ...], 'Naive inversion', IMG_KWARGS),
                                        (x[0, 0, ...], 'Target', IMG_KWARGS),
                                        ]):
    ax[i].imshow(z, **kwargs)
    ax[i].set_title(title)

# Loading the model

We now define the neural network model that we want to train. Here, we use the celebrated UNet architecture from this paper:

<center>
*Ronneberger, O., Fischer, P., & Brox, T. (2015). U-net: Convolutional networks for biomedical image segmentation. In Medical Image Computing and Computer-Assisted Intervention–MICCAI 2015: 18th International Conference, Munich, Germany, October 5-9, 2015, Proceedings, Part III 18 (pp. 234-241). Springer International Publishing.*
</center>


The model architecture is reimplemented (and slightly compressed) in the ```models``` module. We will now load the model and check how many parameters we will train in the following. 

**Spoiler**: Quite a lot.
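To get a feeling for where the parameters come from: a 2D convolution with $C_\text{in}$ input channels, $C_\text{out}$ output channels, a $k\times k$ kernel and a bias has $C_\text{out}(C_\text{in}k^2 + 1)$ trainable parameters. The channel numbers below are illustrative and not necessarily those used in the ```models``` module.

```python
def conv2d_params(c_in, c_out, k, bias=True):
    """Number of trainable parameters of a k x k 2D convolution layer."""
    return c_out * (c_in * k * k + (1 if bias else 0))

# a single 3 x 3 layer mapping 64 -> 128 channels already has about 74k parameters
print(conv2d_params(64, 128, 3))  # 73856
```

A UNet stacks many such layers with growing channel numbers, which quickly adds up to a large total.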

In [None]:
from models import UNet
model = UNet()

model_parameters = filter(lambda p: p.requires_grad, model.parameters())
num_params = sum([np.prod(p.size()) for p in model_parameters])

print('Loaded the model with ' + str(num_params) + ' trainable parameters')

## The reconstruction operator

Our reconstruction operator can now be defined as $f_\theta \circ R^\dagger$, where $f_\theta$ denotes the post-processing network with parameters $\theta$.

In [None]:
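def recon(k, R):
    """Post-processing reconstruction f_theta(R^dagger k); uses the global `model`."""
    Rinv = R.inverse(k)  # naive inversion, B x 1 x N x N
    return model(Rinv), Rinv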

## How is the performance before training?

In [None]:
fig, ax = plt.subplots(1,3, figsize = (15,8))
data, x = next(iter(train_loader))
x_recon, Rinv = recon(data, R)

for i, (z, title) in enumerate([(x.detach()[0,0,...],'Target'), 
                                (Rinv[0,0,...], 'Naive inversion (network input)'), 
                                (x_recon.detach()[0,0,...], 'Model recon, before training')]):
    ax[i].imshow(z, **IMG_KWARGS)
    ax[i].set_title(title)

# Training the model
In order to train the model we consider the following minimization problem:

$$
\min_\theta \mathbb{E}_{(x,y)\sim \mathcal{T}}\left[ \ell(f_\theta(x), y)\right]$$

where

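* $\theta$ denotes the parameters of the neural network $f_\theta$,
* $x$ is the noisy, badly reconstructed input, i.e., the naive inversion of the sinogram data,
* $y$ is the clean ground truth image,
* $\ell$ is the squared $L^2$ distance, i.e., $\ell(\hat y, y) = \|\hat y - y\|^2$ (```nn.MSELoss``` additionally averages over all pixels in the batch).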

To solve this problem, we can use stochastic gradient descent, which yields the update

$$\theta \gets \theta - \alpha \ \nabla_\theta \left(\sum_{i=1}^B \ell(f_\theta(x_i), y_i)\right).$$

Here, $B$ denotes the batch size and $(x_1,y_1),\ldots, (x_B,y_B)$ are the pairs of the mini-batch drawn from ```train_loader``` in each step. The parameter $\alpha$ denotes the step size (learning rate).
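The update can be illustrated on a toy problem. The sketch below runs mini-batch SGD for a hypothetical linear model $f_\theta(x) = x^\top\theta$ with the squared loss, for which the gradient of $\sum_i \ell(f_\theta(x_i), y_i)$ is $2X_B^\top(X_B\theta - y_B)$. It only illustrates the update rule, not the UNet training.

```python
import numpy as np

rng = np.random.default_rng(0)
# hypothetical toy regression data: y = X @ params_true + small noise
T, d = 200, 3
X = rng.normal(size=(T, d))
params_true = np.array([1.0, -2.0, 0.5])
Y = X @ params_true + 0.01 * rng.normal(size=T)

params = np.zeros(d)
alpha, B = 0.01, 5
for epoch in range(20):
    perm = rng.permutation(T)                # reshuffle, like shuffle=True in the DataLoader
    for start in range(0, T, B):
        idx = perm[start:start + B]          # one mini-batch (x_1, y_1), ..., (x_B, y_B)
        residual = X[idx] @ params - Y[idx]
        grad = 2 * X[idx].T @ residual       # gradient of the summed squared loss
        params = params - alpha * grad       # SGD step

print(params)  # close to params_true
```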

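Instead of plain SGD, we employ the Adam optimizer, which we define in the following cell. Additionally, we define a learning rate scheduler: ```ReduceLROnPlateau``` lowers the learning rate (by a factor of 10 by default) once the loss has not improved for ```patience``` consecutive calls to ```scheduler.step```.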
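For reference, a sketch of the standard Adam update with the defaults $\beta_1=0.9$, $\beta_2=0.999$, $\varepsilon=10^{-8}$ that ```torch.optim.Adam``` also uses, applied to a toy quadratic. The helper ```adam_step``` is hypothetical and not the ```torch``` implementation.

```python
import numpy as np

def adam_step(params, grad, m, v, t, lr=5e-4, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update; t is the 1-based step counter."""
    m = beta1 * m + (1 - beta1) * grad        # running mean of the gradients
    v = beta2 * v + (1 - beta2) * grad ** 2   # running mean of the squared gradients
    m_hat = m / (1 - beta1 ** t)              # bias corrections
    v_hat = v / (1 - beta2 ** t)
    params = params - lr * m_hat / (np.sqrt(v_hat) + eps)
    return params, m, v

# minimize 0.5 * ||params - c||^2, whose gradient is params - c
c = np.array([1.0, -3.0])
params, m, v = np.zeros(2), np.zeros(2), np.zeros(2)
for t in range(1, 2001):
    params, m, v = adam_step(params, params - c, m, v, t, lr=0.05)
print(params)  # approximately [1, -3]
```

Because of the normalization by $\sqrt{\hat v}$, each coordinate moves by roughly the learning rate in the first step, independent of the gradient's scale.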

In [None]:
import torch.nn as nn

loss_fct = nn.MSELoss()
opt = torch.optim.Adam(model.parameters(), lr=0.0005)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, patience=30)

### &#128221; <span style="color:darkorange"> Task 3.1 </span>
#### Training the network

We now define the training loop that updates the parameters $\theta$. Here, we use the ```autograd``` functionality of ```torch``` to compute the gradients.

In [None]:
epochs = 3
for e in trange(epochs, desc="Epochs"):
    bbar = tqdm(total=len(train_loader), desc="Batches in epoch: " + str(e))
    for data, x in train_loader:
        opt.zero_grad()                 # zero out gradients from the previous step
        x_recon, Rinv = recon(data, R)  # pseudo-inverse and post-processed reconstruction
        loss = loss_fct(x_recon, x)     # compute the loss
        loss.backward()                 # compute the gradients
        opt.step()                      # make a step of the optimizer
        scheduler.step(loss.item())     # make a scheduler step

        # additional diagnostics: compare with the loss of the naive inversion
        loss_naive = loss_fct(Rinv, x).item()
        bbar.set_postfix(loss=loss.item(), naive_loss=loss_naive,
                         lr=opt.param_groups[0]['lr'])
        bbar.update(1)
    bbar.close()

## How does the reconstruction look now?

In [None]:
fig, ax = plt.subplots(1,3, figsize = (15,8))
data, x = next(iter(train_loader))
x_recon, Rinv = recon(data, R)

for i, (z, title) in enumerate([(x.detach()[0,0,...],'Target'), 
                                (Rinv[0,0,...], 'Naive inversion (network input)'), 
                                (x_recon.detach()[0,0,...], 'Model recon, after training')]):
    ax[i].imshow(z, **IMG_KWARGS)
    ax[i].set_title(title)

## Saving and loading the models
If you want, you can save or load models with the cells below.

In [None]:
import datetime
# UTC time stamp, separated from the p value by a dash
name = 'UNet-train-p-' + str(train_p) + '-' + datetime.datetime.now(datetime.timezone.utc).strftime('%Y%m%dT%H%M%SZ') + '.pt'
save_model = True
if save_model:
    torch.save(model.state_dict(), name)

In [None]:
load_model = False
name = 'UNet-train-p-220240916T171815Z.pt'  # replace with the file name of your saved model
if load_model:
    model.load_state_dict(torch.load(name))

### &#128221; <span style="color:darkorange"> Task 3.2 </span>
#### Evaluating the results

The cell below allows you to test the performance of your trained model.

Your task is to try out different shapes, noise levels and angle ranges, and to evaluate the performance of your model :)

In [None]:
# %% Test
import ipywidgets as widgets
from ipywidgets import interactive
from IPython.display import display

def plot_result(logp, noise_lvl, angle):
    theta = np.linspace(angle[0], angle[1], endpoint=False, num=NUM_THETAS)
    R = Radon_torch(theta=theta)
    p = 2**logp if logp < 10 else float('inf')
    x = torch.tensor(rwn(p=p), dtype=torch.float)  # 1 x N x N

    sinogram = R(x[:, None, ...])
    data = sinogram + noise_lvl * torch.randn_like(sinogram)

    with torch.no_grad():
        x_recon, Rinv = recon(data, R)

    # convert everything to N x N numpy arrays for plotting and error computation
    x_np = x[0].numpy()
    results = [(x_np, 'Ground truth, p=' + str(p)),
               (Rinv[0, 0].numpy(), 'Naive recon'),
               (x_recon[0, 0].numpy(), 'Network recon')]

    fig, ax = plt.subplots(1, 3, figsize=(20, 15))
    for i, (z, title) in enumerate(results):
        ax[i].imshow(z, **IMG_KWARGS)
        ax[i].set_title(title + ', error: ' + str(round(float(np.linalg.norm(x_np - z)), 4)))


p_slider = widgets.FloatSlider(min=0.0, max=10., step=1, value=np.log2(train_p), continuous_update=False)
n_slider = widgets.FloatSlider(min=0.0, max=.03, step=0.001, value=0.01, continuous_update=False)
a_slider = widgets.FloatRangeSlider(value=[0, 90], min=0, max=180, step=10, continuous_update=False)

interactive_plot = interactive(plot_result, logp=p_slider, angle=a_slider, noise_lvl=n_slider)
display(interactive_plot)

# How can we combine model-based and data-driven approaches?

A popular way to combine model-based and data-driven approaches is given by so-called plug-and-play (PnP) methods. The starting point is a variational minimization problem

$$\min_x \frac{1}{2}\, \|Ax - y\|^2 + \lambda J(x).$$

This problem can be solved via prox-based methods, for example an ADMM update scheme

$$
\begin{align*}
x &\gets \operatorname*{arg min}_{x}\ \frac{1}{2}\, \|Ax - y\|^2 + \frac{\rho}{2} \|{x - (v - u)}\|^2,\\
v &\gets \operatorname*{arg min}_{v}\ \lambda J(v) + \frac{\rho}{2} \|{v - (x + u)}\|^2,\\
u &\gets u + x - v.
\end{align*}
$$

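Here, the first line is a linear least-squares problem whose minimizer solves the normal equations $(A^*A + \rho I)\,x = A^*y + \rho\,(v - u)$, so it can be computed with a linear solver (e.g., the CG iteration). The last line is explicit. The second line is in fact the prox operator of $J$, since 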

$$\operatorname{prox}_{\lambda/\rho\ J}(x+u) =  \operatorname*{arg min}_{v}\ \lambda J(v) + \frac{\rho}{2} \|{v - (x + u)}\|^2.$$
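Since $A^*A + \rho I$ is symmetric positive definite, the $x$-update can be computed with the conjugate gradient method. A minimal numpy sketch, with a small random matrix as a hypothetical stand-in for the Radon operator:

```python
import numpy as np

def cg(apply_M, b, x0, max_it=100, tol=1e-10):
    """Conjugate gradient for M x = b, M symmetric positive definite (given as a function)."""
    x = x0.copy()
    r = b - apply_M(x)                      # residual
    p = r.copy()                            # search direction
    rs = r @ r
    stop = (tol * np.linalg.norm(b)) ** 2   # relative stopping criterion
    for _ in range(max_it):
        if rs <= stop:
            break
        Mp = apply_M(p)
        step = rs / (p @ Mp)
        x = x + step * p
        r = r - step * Mp
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x

rng = np.random.default_rng(0)
A = rng.normal(size=(30, 20))   # hypothetical stand-in for the discretized Radon operator
y = rng.normal(size=30)
v, u, rho = rng.normal(size=20), np.zeros(20), 0.4

def apply_M(z):
    """Applies A^T A + rho I without forming the matrix."""
    return A.T @ (A @ z) + rho * z

# x-update of ADMM: solve (A^T A + rho I) x = A^T y + rho (v - u)
x_new = cg(apply_M, A.T @ y + rho * (v - u), x0=np.zeros(20))
```

Only matrix-vector products with $A$ and $A^*$ are needed, which is why CG suits operators like the Radon transform that are applied rather than stored as matrices.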

Evaluating this prox can be complicated and relies on a possibly hand-crafted functional $J$. The idea of PnP methods is to replace a prox step of this kind with an arbitrary map $D$, i.e., the iteration takes the form

$$
\begin{align*}
x &\gets \operatorname*{arg min}_{x}\ \frac{1}{2}\, \|Ax - y\|^2 + \frac{\rho}{2} \|{x - (v - u)}\|^2,\\
v &\gets D_\lambda(x + u),\\
u &\gets u + x - v.
\end{align*}
$$

Typically, this map $D_\lambda$ is a denoiser, i.e., a mapping that takes a noisy image and outputs a clean version of it.
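A toy sketch of the PnP-ADMM iteration in numpy. To make the result checkable, the map $D$ here is soft thresholding, which is exactly the prox of $\tfrac{\lambda}{\rho}\|\cdot\|_1$; with this choice the iteration reduces to ordinary ADMM for $J = \|\cdot\|_1$, whose solution for $A = I$ is known in closed form. Passing a learned denoiser as ```denoiser``` instead gives the PnP variant. The function names are hypothetical and differ from the ```optimizers``` module used later.

```python
import numpy as np

def soft_threshold(z, tau):
    """Prox of tau * ||.||_1: componentwise shrinkage towards zero."""
    return np.sign(z) * np.maximum(np.abs(z) - tau, 0.0)

def pnp_admm(A, y, denoiser, rho=1.0, max_it=200):
    """Scaled-form ADMM where the v-update is an arbitrary map (hypothetical helper)."""
    n = A.shape[1]
    x, v, u = np.zeros(n), np.zeros(n), np.zeros(n)
    M = A.T @ A + rho * np.eye(n)  # system matrix of the x-update
    for _ in range(max_it):
        x = np.linalg.solve(M, A.T @ y + rho * (v - u))  # x-update
        v = denoiser(x + u)                               # prox step replaced by D
        u = u + x - v                                     # dual update
    return x, v

lam, rho = 0.5, 1.0
A = np.eye(4)
y = np.array([2.0, -0.3, 0.8, -1.5])

# D = prox of (lam / rho) * ||.||_1, i.e. plain ADMM for J = ||.||_1
x_admm, v_admm = pnp_admm(A, y, lambda z: soft_threshold(z, lam / rho), rho=rho)
print(v_admm)  # close to soft_threshold(y, lam) = [1.5, 0, 0.3, -1]
```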

## Training the denoiser

We want to try out the PnP approach from above. To do so, we first train a denoiser $D$. We use the same setup as before and only change the dataset. Namely, we train the denoiser to minimize

$$ \mathbb{E}_{x,\varepsilon} \|D(x + \varepsilon) - x\|_2^2,$$

where $\varepsilon$ is Gaussian noise with standard deviation ```noise_lvl```.
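A useful baseline for the training loss: since the noise is not clipped, the trivial "denoiser" $D = \mathrm{id}$ has a mean squared error of about $\sigma^2 = 0.04$ for ```noise_lvl = 0.2```, so a trained denoiser should reach a clearly smaller value of ```nn.MSELoss```. The sketch below checks this on random stand-in images (hypothetical data):

```python
import numpy as np

rng = np.random.default_rng(0)
sigma = 0.2                                 # same value as noise_lvl in the next cell
clean = rng.uniform(size=(200, 1, 64, 64))  # hypothetical stand-in images
noisy = clean + sigma * rng.normal(size=clean.shape)

# loss of the identity map, averaged over all entries like nn.MSELoss
mse_identity = np.mean((noisy - clean) ** 2)
print(mse_identity)  # approximately sigma**2 = 0.04
```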

In [None]:
noise_lvl = 0.2
x    = torch.tensor(rwn(p=2, B=num_images), dtype=torch.float)[:, None, ...]
data = x + noise_lvl *  torch.normal(0., 1., size=x.shape)

train_dataset_denoising = torch.utils.data.TensorDataset(data, x)
train_loader_denoising = torch.utils.data.DataLoader(train_dataset_denoising, batch_size=5, shuffle=True)

## Look at the data again

The data looks different now: the inputs are noisy images instead of noisy sinograms.

In [None]:
data, x = next(iter(train_loader_denoising))

fig, ax = plt.subplots(1,2, figsize = (10,8))
for i, (z,title) in enumerate([(data, 'Noisy data'), (x, 'Original')]):
    ax[i].imshow(z[0,0,...], **IMG_KWARGS)
    ax[i].set_title(title)

### &#128221; <span style="color:darkorange"> Task 3.3 </span>
#### Train again

With the same setup as before, you can now train the denoising model.

In [None]:
denoiser = UNet()
opt = torch.optim.Adam(denoiser.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, patience=30)

In [None]:
epochs = 3

for e in trange(epochs, desc="Epochs"):
    bbar = tqdm(total=len(train_loader_denoising), desc="Batches in epoch: " + str(e))
    for data, x in train_loader_denoising:
        opt.zero_grad()              # zero out gradients from the previous step
        x_recon = denoiser(data)     # denoise the noisy images
        loss = loss_fct(x_recon, x)  # compute the loss
        loss.backward()              # compute the gradients
        opt.step()                   # make a step of the optimizer
        scheduler.step(loss.item())  # make a scheduler step

        bbar.set_postfix(loss=loss.item(), lr=opt.param_groups[0]['lr'])
        bbar.update(1)
    bbar.close()

In [None]:
save_model = True
name = 'UNet-denoiser-train-p-' + str(train_p) + '-' + datetime.datetime.now(datetime.timezone.utc).strftime('%Y%m%dT%H%M%SZ') + '.pt'
if save_model:
    torch.save(denoiser.state_dict(), name)

In [None]:
load_model = False
name = 'UNet-denoiser-train-p-220240916T175542Z.pt'  # replace with the file name of your saved model
if load_model:
    denoiser.load_state_dict(torch.load(name))

# How well does it work?
The cell below shows the denoising performance of the trained network.

In [None]:
data, x = next(iter(train_loader_denoising))
x_recon = denoiser(data).detach()

fig, ax = plt.subplots(1,3, figsize = (10,8))
for i, (z,title) in enumerate([(x, 'Original'), (data, 'Noisy data'), (x_recon, 'Model recon')]):
    ax[i].imshow(z[0,0,...], **IMG_KWARGS)
    ax[i].set_title(title)

## Defining the PnP prox substitute

Our original goal was to substitute the prox in the ADMM iteration. The module ```optimizers``` defines the optimizer ```admm```, which has the following signature for initialization:

```admm(R, x0, Rx, rho=0.4, lamda=1., verbosity=0, prox=model_prox, max_it=35, max_inner_it=1)```

where

* ```R``` is the linear operator, i.e., the Radon transform,
* ```x0``` is the initial guess,
* ```Rx``` is the measured (noisy) sinogram data,
* ```rho``` is the iteration parameter $\rho$,
* ```lamda``` scales the influence of the prox,
* ```prox``` defines the prox mapping,
* ```max_it``` determines the number of steps,
* ```max_inner_it``` determines the number of inner iterations.

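We just have to define the prox mapping. Here, we use a relaxed variant that returns the convex combination $(1-\lambda)\,x + \lambda\, D(x)$ of the input and the denoiser output, with $\lambda$ clipped to $[0,1]$: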

In [None]:
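def model_quasi_prox(x, lamda):
    # convex combination of the identity and the learned denoiser, lamda clipped to [0, 1]
    lamda = max(min(lamda, 1.), 0.)
    # x is an N x N numpy array; the denoiser expects a 1 x 1 x N x N torch tensor
    x_t = torch.tensor(x, dtype=torch.float)[None, None, ...]
    with torch.no_grad():
        denoised = denoiser(x_t).numpy()[0, 0, ...]
    return (1 - lamda) * x + lamda * denoised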

### &#128221; <span style="color:darkorange"> Task 3.4 </span>
#### Test the performance on the CT problem

With the following cell, you can now test how well the denoiser and the PnP-ADMM iteration perform on the CT reconstruction task. How does the noise level influence the performance?

In [None]:
from optimizers import admm

min_angle = 0
max_angle = 90
theta = np.linspace(min_angle, max_angle, endpoint=False, num=NUM_THETAS)
R_np = Radon(theta=theta)

def plot_recon(noise_lvl, logp, lamda):
    p = 2**logp if logp < 10 else float('inf')
    x = rwn(p=p)[0]

    data = R_np(x)
    data += noise_lvl * np.random.normal(0, 1, size=data.shape)
    x0 = R_np.inverse(data)

    pnpadmm = admm(R_np, x0, data, rho=0.4, lamda=lamda, verbosity=0, prox=model_quasi_prox, max_it=35, max_inner_it=5)
    pnpadmm.solve()

    # apply the denoiser directly to the naive reconstruction for comparison
    with torch.no_grad():
        x_model = denoiser(torch.tensor(x0, dtype=torch.float)[None, None, ...]).numpy()[0, 0, ...]

    def err(z):
        return str(round(float(np.linalg.norm(z - x)), 4))

    fig, ax = plt.subplots(1, 4, figsize=(20, 15))
    for i, (z, title) in enumerate([(x, 'Original'),
                                    (x0, 'Naive recon, error: ' + err(x0)),
                                    (x_model, 'Denoiser output, error: ' + err(x_model)),
                                    (pnpadmm.x, 'PnP ADMM recon, error: ' + err(pnpadmm.x)),
                                    ]):
        ax[i].imshow(z, **IMG_KWARGS)
        ax[i].set_title(title)


n_slider = widgets.FloatSlider(min=0.0, max=.1, step=0.001, value=0.02, continuous_update=False)
p_slider = widgets.FloatSlider(min=0.0, max=10., step=1, value=np.log2(train_p), continuous_update=False)
l_slider = widgets.FloatSlider(min=0.0, max=5., step=0.1, value=5., continuous_update=False)
interactive_plot = interactive(plot_recon, noise_lvl=n_slider, logp=p_slider, lamda=l_slider)
display(interactive_plot)