# Neighborhood-Based Filtering (using pytorch)

## Lab session #1 2023

### with julien.rabin (at) ensicaen.fr


![Logo](Ensicaen-logo.png) 

________________________________
### Last name : 
### First name : 
### Group :
### Date : 
________________________________

In this notebook, the goal is to implement various filters using the PyTorch library :
- [At the beginning](#0---import--load-image), a reminder on loading, reading and manipulating an 8-bit RGB image as a PyTorch tensor
- [Local Filters](#a-local-filters) :
    - Gaussian kernel (with `torch.nn.functional.conv1d` or `conv2d`) :  [section 1](#1---gaussian-filtering)
    - Approximation with a box filter (with `torch.nn.AvgPool2d`)
    - Comparison with bilateral filtering (using `torch.nn.Unfold`)
    - Application to guided image filtering using cross-bilateral filtering
- [Patch-based Non-Local filters](#b-non-local-patch-based-filters) :
    - Non-Local Means
    - Non-Local PCA
- [Patch-based Auto-Encoder](#c---auto-encoder-patch-processing)


________________________________
<a id='cell_0'></a>
## 0 - Import & Load image


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

# for Jupyter notebook
%matplotlib inline 


In [None]:
file_name = 'fake.png' # 
img0 = plt.imread(file_name)
print(img0.max())
if img0.max() > 1. : # jpg -> 255., png -> 1
    img0 = img0 / 255.
plt.figure()
plt.imshow(img0)
plt.title("float RGB format")
plt.axis("off");

#### convert to tensor and resize image (bilinear interpolation)

In [None]:
img0_tsr = torch.tensor(img0).unsqueeze(0).permute(0,3,1,2) # 1 x 3 x height x width 
img_size = 256 # 128
img0_tsr = torch.nn.functional.interpolate(img0_tsr, size=(img_size,img_size), mode='bilinear')
print(img0_tsr.shape)

plt.imshow(img0_tsr[0].permute(1,2,0))
plt.title("show tensor")
plt.axis("off");
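As a reminder, the conversion between an 8-bit RGB array and a float PyTorch tensor can be sketched as follows. This is only an illustration on a synthetic array (the names `img_u8`, `tsr` and `back` are not part of the notebook); it assumes a `height x width x 3` layout on the NumPy side and `1 x 3 x height x width` on the PyTorch side.

```python
import numpy as np
import torch

# synthetic 8-bit RGB image (height x width x 3), standing in for a loaded file
rng = np.random.default_rng(0)
img_u8 = rng.integers(0, 256, size=(4, 6, 3), dtype=np.uint8)

# uint8 H x W x 3  ->  float32 1 x 3 x H x W with values in [0, 1]
tsr = torch.from_numpy(img_u8).float().div(255.).permute(2, 0, 1).unsqueeze(0)

# back to uint8 H x W x 3, e.g. for saving or display
back = (tsr[0].permute(1, 2, 0) * 255.).round().clamp(0, 255).to(torch.uint8).numpy()

print(tsr.shape, back.dtype)
```

Rounding before the cast to `uint8` avoids losing one gray level to float truncation.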


### add gaussian noise

In [None]:
img_tsr = img0_tsr.clone() + 30./255. * torch.randn_like(img0_tsr)
img_tsr = torch.clamp(img_tsr,0.,1.)

H = img_tsr.size(2)
W = img_tsr.size(3)

plt.figure(figsize = (10,10))
plt.imshow(torch.cat((img0_tsr,img_tsr),3).squeeze(0).permute(1,2,0))
plt.title("Ground Truth -- Noisy data")

### selection of an interest point for comparison

In [None]:
point = np.array([710,510]) # tip of the nose
#point = np.array([502,845]) # edge of the ear

# `point` is expressed on a 1024-pixel grid : rescale it to the resized image
interest_point = (point/1024.* img_size).astype(int)
index_interest_point = interest_point[0]*W+interest_point[1]

patch_ori = img_tsr[:,:,interest_point[0]-16:interest_point[0]+17, interest_point[1]-16:interest_point[1]+17].squeeze(0).clone() # do not forget clone() !
patch_ori[:,16,16] = torch.tensor([0,1,0])
patch_ori = patch_ori.permute(1,2,0)
plt.imshow(patch_ori)

# A. Local Filters

<a id='cell_A'></a>
<a id='section_A'></a>

## 1 - Gaussian filtering 
### with nn.functional.conv1d

$$
    \hat u (x) = \frac{1}{C(x)}\sum_{y \in \Omega} u(y) e^{- \tfrac{1}{2\sigma^2} \|x- y\|^2 }
$$

In [None]:
sig_pix = 1.0
# definition of the 1D Gaussian kernel
k = int(3 * sig_pix)
x = torch.arange(-k,k+1,1) # size : 2k+1
g = ... # define the Gaussian kernel and normalize it
g /= g.sum()
plt.plot(g)

# convolution with kernel g using torch.nn.functional.conv2d, or twice torch.nn.functional.conv1d ... but not with torch.nn.Conv1d() !
# tip : you can process color channels as a batch
img_tsr_gauss_xy = ...

print(img_tsr_gauss_xy.shape)

res = torch.cat((img0_tsr,img_tsr,img_tsr_gauss_xy),axis=3)
plt.figure(figsize = (10,10))
plt.imshow(res.squeeze(0).permute(1,2,0))

### display ROI

In [None]:
patch_gauss = img_tsr_gauss_xy[:,:,interest_point[0]-16:interest_point[0]+17, interest_point[1]-16:interest_point[1]+17].squeeze(0).clone()
patch_gauss[:,16,16] = torch.tensor([0,1,0]) # green pixel
patch_gauss = patch_gauss.permute(1,2,0) 
plt.imshow(torch.cat((patch_ori,patch_gauss), 1))

- the two parameters of this filter are the size of the kernel and the gaussian standard deviation `sig_pix`
- vary these two parameters and display the results in a figure
- compute the difference between the original and the filtered images for different values of `sig_pix`
- what do you conclude from these two experiments ?
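One possible way to organize the Gaussian filter is sketched below, as a point of comparison rather than as the expected solution. It assumes a batch of size 1, a kernel truncated at $3\sigma$ as in the cell above, and reflect padding at the borders (a choice made here, not imposed by the notebook). The function names `gaussian_kernel_1d` and `gaussian_filter` are illustrative.

```python
import torch
import torch.nn.functional as F

def gaussian_kernel_1d(sig_pix):
    # 1D Gaussian kernel truncated at 3 sigma, normalized to sum to 1
    k = int(3 * sig_pix)
    x = torch.arange(-k, k + 1, dtype=torch.float32)
    g = torch.exp(-x**2 / (2 * sig_pix**2))
    return g / g.sum()

def gaussian_filter(img, sig_pix):
    # img : 1 x C x H x W float tensor (batch of size 1 assumed)
    _, C, H, W = img.shape
    g = gaussian_kernel_1d(sig_pix)
    k = g.numel() // 2
    x = img.reshape(C, 1, H, W)                  # color channels as a batch
    x = F.pad(x, (k, k, k, k), mode='reflect')   # assumed border handling
    x = F.conv2d(x, g.view(1, 1, 1, -1))         # horizontal pass
    x = F.conv2d(x, g.view(1, 1, -1, 1))         # vertical pass
    return x.reshape(1, C, H, W)

torch.manual_seed(0)
img = torch.rand(1, 3, 16, 16)
out = gaussian_filter(img, sig_pix=1.0)
print(out.shape)
```

The two 1D passes rely on the separability of the Gaussian kernel, which costs $2(2k+1)$ multiplications per pixel instead of $(2k+1)^2$.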

## 2 - Comparison with box filters

- change the previous code to implement a box-filter, that is a kernel which is constant over a rectangular domain
- compare with gaussian filtering for large kernel : can you see / show the effect of using an anisotropic kernel ?
- compare your result with torch.nn.AvgPool2d 

$$
\hat u (x) = \frac{1}{(2r+1)^2}\sum_{\|y-x\|_\infty \le r} u(y) 
$$
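The formula above can be checked numerically with the following sketch, which compares a constant kernel applied with `conv2d` to `torch.nn.AvgPool2d`. It assumes zero padding in both cases (`count_include_pad=True`) and a batch of size 1; `box_filter` is an illustrative name.

```python
import torch
import torch.nn.functional as F

def box_filter(img, r):
    # img : 1 x C x H x W ; constant kernel over a (2r+1) x (2r+1) window
    _, C, H, W = img.shape
    size = 2 * r + 1
    kernel = torch.full((1, 1, size, size), 1.0 / size**2)
    x = F.conv2d(img.reshape(C, 1, H, W), kernel, padding=r)  # zero padding
    return x.reshape(1, C, H, W)

torch.manual_seed(0)
img = torch.rand(1, 3, 16, 16)
r = 2
out_conv = box_filter(img, r)
out_pool = torch.nn.AvgPool2d(kernel_size=2 * r + 1, stride=1, padding=r,
                              count_include_pad=True)(img)
print((out_conv - out_pool).abs().max())
```

With `count_include_pad=False` the two results would differ near the borders, since the average would then be taken over the valid pixels only.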

## 3 - Bilateral Filtering

$$
    \hat u (x) = \frac{1}{C(x)}\sum_{y \in \Omega} u(y) e^{- \tfrac{1}{2\sigma^2} \|x- y\|^2 } e^{- \tfrac{1}{2h^2} \|u(x)- u(y)\|^2 }
$$

### complete the following code to implement the bilateral filter using fold / unfold PyTorch functions

In [None]:
# parameters
half_win_size = 5
win_size = 2*half_win_size+1
sig_pix = 2.
sig_col = 20./255.

d = win_size**2 # patch dim
N = img_tsr.size(2) * img_tsr.size(3) # number of patches (= number of pixels with appropriate padding)

# precomputation of the gaussian spatial weights
dx = torch.arange(-half_win_size,half_win_size+1,1) # win_size
weight_pix = ... # compute the 2D gaussian weights

# overlapping patch decomposition (sliding window)
patch = torch.nn.Unfold(kernel_size=win_size, dilation=1, padding=half_win_size, stride=1)(img_tsr.view(3,1,img_tsr.size(2),img_tsr.size(3))) # 3 x d x N (color channels as batch)

pix_center = win_size**2//2
patch_diff = ... # using broadcasting, compute the difference between the window color values and the central pixel at index pix_center
weight_color = ... # compute the color weights
weight = weight_pix * weight_color


synth = torch.sum(patch * weight, 1, keepdim=True) / torch.sum(weight, 1, keepdim=True)

# add green pixel to interest point
synth[:,:,index_interest_point] = torch.tensor([0.,1.,0.]).view(3,1)

synth = synth.view(1,3,img_tsr.size(2),img_tsr.size(3))

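# window centered on the interest point, before and after weighting
# (assumption : `color` and `post_color` were not defined in the original notebook)
color = patch[:,:,index_interest_point].reshape(1,3,win_size,win_size)
post_color = color * weight[:,:,index_interest_point].reshape(1,1,win_size,win_size)

fig, ax = plt.subplots(1,5, figsize=(20, 20))
ax[0].imshow(color.squeeze(0).permute(1,2,0).squeeze(2)); ax[0].set_title("color map")
ax[1].imshow(weight_pix.view(win_size,win_size)); ax[1].set_title("spatial weight")
ax[2].imshow(weight_color[:,:,index_interest_point].view(win_size,win_size)); ax[2].set_title("color weight")
ax[3].imshow(weight[:,:,index_interest_point].view(win_size,win_size)); ax[3].set_title("weight map")
ax[4].imshow(post_color.squeeze(0).permute(1,2,0)); ax[4].set_title("color after processing")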

plt.figure()
plt.imshow(synth.squeeze(0).permute(1,2,0))


In [None]:
patch_bilateral = synth[:,:,interest_point[0]-16:interest_point[0]+17, interest_point[1]-16:interest_point[1]+17].squeeze(0).clone() # clone() : do not modify synth in place
patch_bilateral[:,16,16] = torch.tensor([0,1,0])
patch_bilateral = patch_bilateral.permute(1,2,0) 
plt.figure(figsize=(15,5))
plt.imshow(torch.cat((patch_ori,patch_gauss,patch_bilateral), 1))

In [None]:
diff = (synth - img_tsr) # in [-1,1]
diff = 0.5 * torch.clamp(10*diff,-1,1) + 0.5 # amplified by 10, then mapped to [0,1]
res_bilateral = torch.cat((img0_tsr,img_tsr,synth,diff),axis=3)

plt.figure(figsize=(20,20))
plt.imshow(res_bilateral.squeeze(0).permute(1,2,0))
plt.title(" bilateral filtering")
plt.axis("off")

- the two main parameters of this filter are the Gaussian standard deviations `sig_pix` (spatial) and `sig_col` (color)
- vary these two parameters and display the results in a figure
- compute the difference between the original and the filtered images for different values of `sig_pix`
- what do you conclude from these two experiments ?
- what happens if you change the exponential kernel ? for instance, with an indicator function as in the box filter
- what happens if you change the color space ? for instance, try the Luv or Lab representation
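For these experiments, it can help to wrap the filter in a function of `sig_pix` and `sig_col`. The sketch below is one possible arrangement, not the reference solution: it assumes a batch of size 1, a color weight shared by the three channels (squared Euclidean distance in RGB), and reflect padding instead of the zero padding of `Unfold`. The name `bilateral_filter` is illustrative.

```python
import torch
import torch.nn.functional as F

def bilateral_filter(img, half_win_size=5, sig_pix=2.0, sig_col=20. / 255.):
    # img : 1 x C x H x W float tensor in [0, 1]
    _, C, H, W = img.shape
    win_size = 2 * half_win_size + 1
    # spatial Gaussian weights, shape 1 x d x 1 with d = win_size**2
    dx = torch.arange(-half_win_size, half_win_size + 1, dtype=torch.float32)
    d2 = dx.view(-1, 1)**2 + dx.view(1, -1)**2
    weight_pix = torch.exp(-d2 / (2 * sig_pix**2)).reshape(1, -1, 1)
    # sliding windows, color channels as batch : C x d x N
    x = F.pad(img.reshape(C, 1, H, W), (half_win_size,) * 4, mode='reflect')
    patch = F.unfold(x, kernel_size=win_size)
    # color weights from the difference with the central pixel : 1 x d x N
    center = win_size**2 // 2
    patch_diff = patch - patch[:, center:center + 1, :]
    weight_color = torch.exp(-(patch_diff**2).sum(0, keepdim=True) / (2 * sig_col**2))
    weight = weight_pix * weight_color
    # normalized weighted average : C x 1 x N
    out = (patch * weight).sum(1, keepdim=True) / weight.sum(1, keepdim=True)
    return out.reshape(1, C, H, W)

torch.manual_seed(0)
img = torch.rand(1, 3, 16, 16)
out = bilateral_filter(img, half_win_size=3)
print(out.shape)
```

A quick sanity check is that a sharp step edge is left untouched when `sig_col` is much smaller than the step height, whereas a Gaussian filter would blur it.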

## 3.2 - Joint / Cross-Bilateral Filtering

see Eisemann and Durand [2004] and Petschnigg et al. [2004] 

$$
    \hat u (x) = \frac{1}{C(x)}\sum_{y \in \Omega} u(y) g_\sigma(y-x) g_h(v(y)-v(x)) 
$$

- adapt the previous code for *joint-bilateral* filtering of an image $u$ using a guide $v$
- Test your guided filter using the image pair 'cakeNo-flash.jpg' and 'cakeFlash.jpg' (or 'cave-flash.bmp' and 'cave-noflash.bmp')
- Compare with simple bilateral filtering (where query, keys and values are the same : $u=v$), can you see some improvement ?
- Do you notice any transfer artifacts ?
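The only change with respect to the bilateral filter is that the range weights $g_h$ are computed from the guide $v$ while the average is taken over $u$. A minimal sketch under the same assumptions as before (batch of size 1, reflect padding, range weight shared across channels) is given below; `joint_bilateral_filter` and the synthetic step image are illustrative and do not replace the flash / no-flash pairs.

```python
import torch
import torch.nn.functional as F

def joint_bilateral_filter(u, v, half_win_size=5, sig_pix=2.0, sig_col=20. / 255.):
    # u : image to filter, v : guide, both 1 x C x H x W with the same H and W
    _, C, H, W = u.shape
    win_size = 2 * half_win_size + 1
    dx = torch.arange(-half_win_size, half_win_size + 1, dtype=torch.float32)
    d2 = dx.view(-1, 1)**2 + dx.view(1, -1)**2
    weight_pix = torch.exp(-d2 / (2 * sig_pix**2)).reshape(1, -1, 1)

    def windows(t):
        # sliding windows with channels as batch : C x d x N
        t = F.pad(t.reshape(-1, 1, H, W), (half_win_size,) * 4, mode='reflect')
        return F.unfold(t, kernel_size=win_size)

    patch_u, patch_v = windows(u), windows(v)
    center = win_size**2 // 2
    diff_v = patch_v - patch_v[:, center:center + 1, :]     # differences in the guide
    weight_col = torch.exp(-(diff_v**2).sum(0, keepdim=True) / (2 * sig_col**2))
    weight = weight_pix * weight_col                         # 1 x d x N
    out = (patch_u * weight).sum(1, keepdim=True) / weight.sum(1, keepdim=True)
    return out.reshape(1, C, H, W)

torch.manual_seed(0)
clean = torch.zeros(1, 3, 16, 16)
clean[..., 8:] = 1.                                  # vertical step edge
u = clean + 0.1 * torch.randn_like(clean)            # noisy image to filter
v = clean                                            # noise-free guide
out = joint_bilateral_filter(u, v, half_win_size=3, sig_pix=2.0, sig_col=0.05)
print(((u - clean)**2).mean(), ((out - clean)**2).mean())
```

Setting `v = u` gives back the standard bilateral filter, which makes the comparison asked above straightforward.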

________

# B. Non-Local patch-based filters
<a id='part_B'></a>
<!-- Link with #[part_B](#part_B) -->

## 4 - Non Local Filtering

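with the search restricted to a neighborhood $W$ of `win_size` x `win_size` pixels around $x$. In the formula below, $p(x)$ denotes the patch centered at $x$, and $\varepsilon$ corresponds to the parameter `sigma` in the code.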

**WARNING** : this algorithm is very slow in python when using nested loops : use a small image for debugging ! (parameter `img_size` at the beginning)

$$
    \hat u (x) = \frac{1}{C(x)}\sum_{y \in W} u(y) e^{- \tfrac{1}{2\varepsilon^2} \|p(x) - p(y)\|^2 }
$$
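For reference, the formula above can also be evaluated without a loop over pixels, by looping over the `win_size` x `win_size` offsets instead. The sketch below is an alternative organization, not the per-pixel loop used in the next cell. It assumes a batch of size 1 and uses replicate padding of the patch map, which plays the same role as clipping the neighbor coordinates to the image domain. `nl_means` is an illustrative name.

```python
import torch
import torch.nn.functional as F

def nl_means(img, half_patch_size=1, half_win_size=5, sigma=0.1):
    # img : 1 x C x H x W float tensor
    _, C, H, W = img.shape
    patch_size = 2 * half_patch_size + 1
    hw = half_win_size
    # patch p(x) stored at every pixel : 1 x d x H x W
    P = F.unfold(img, kernel_size=patch_size, padding=half_patch_size).reshape(1, -1, H, W)
    # replicate padding <=> neighbor coordinates clipped to the image
    P_pad = F.pad(P, (hw, hw, hw, hw), mode='replicate')
    u_pad = F.pad(img, (hw, hw, hw, hw), mode='replicate')
    num = torch.zeros_like(img)
    den = torch.zeros_like(img[:, :1])
    for dx in range(2 * hw + 1):
        for dy in range(2 * hw + 1):
            Q = P_pad[:, :, dx:dx + H, dy:dy + W]             # patches p(y), y = x + offset
            dist = ((P - Q)**2).sum(1, keepdim=True)          # ||p(x) - p(y)||^2
            w = torch.exp(-dist / (2 * sigma**2))
            num += w * u_pad[:, :, dx:dx + H, dy:dy + W]      # weighted central pixels u(y)
            den += w
    return num / den

torch.manual_seed(0)
img = torch.rand(1, 3, 12, 12)
out = nl_means(img, half_patch_size=1, half_win_size=2, sigma=0.5)
print(out.shape)
```

The cost is `win_size**2` vectorized passes over the image, which is usually much faster in Python than `N` iterations over individual pixels.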

In [None]:
# parameters
half_patch_size = 1
patch_size = 2*half_patch_size+1
half_win_size = 5
win_size = 2*half_win_size+1

sigma = 10./255. * patch_size

d = 3*patch_size**2 # patch dim
N = img_tsr.size(2) * img_tsr.size(3) # number of patches (= number of pixels with appropriate padding)

# overlapping patch decomposition (sliding window)
patch = torch.nn.Unfold(kernel_size=patch_size, dilation=1, padding=half_patch_size, stride=1)(img_tsr) # 1 x d x N (no batch)
print(patch.shape)

synth = torch.zeros_like(patch)
dx = torch.arange(-half_win_size,half_win_size+1,1)
dx, dy = torch.meshgrid(dx, dx, indexing='ij') # explicit 'ij' indexing (PyTorch >= 1.10), avoids a warning
for i in range(N) :
    # fetch query patch at index i
    query = patch[:,:,i].unsqueeze(2).clone() # [1 x d x 1]

    # (x,y) : 2D patch coordinates such that i = x * W + y
    x = ...
    y = ...

    # extraction of patches in neighborhood [win_size x win_size]
    xx = torch.clamp(x+dx,0,H-1) # torch.clamp keeps the coordinates as tensors
    yy = torch.clamp(y+dy,0,W-1)
    I = xx * W + yy # 1D coordinates
    p = patch[0:1,:,I].clone() # [1 x d x win_size x win_size]
    p = p.view(1,d,win_size**2) # 1 x d x N'

    # distance computation between 'query' and 'p'
    patch_dist = ... # with format [1 1 N']

    # weighted mean
    weight_patch = torch.exp( - patch_dist /2. /float(sigma)**2 )  # 1 x 1 x N'
    synth[:,:,i:i+1] = ... # compute the weighted mean of 'p' with 'weight_patch'

    if (x==interest_point[0] and y==interest_point[1]) :
        query = query.view(3,patch_size,patch_size)
        plt.imshow(query.permute(1,2,0).numpy())
        plt.title(f"query patch ({x},{y})")

        weight_patch = weight_patch.view(win_size,win_size)
        #weight_patch[half_win_size, half_win_size] = 0 # discard self comparison
        
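        # central pixel color of each patch in the search window, before and after weighting
        # (assumption : `color` and `post_color` were not defined in the original notebook)
        n = patch_size**2
        color = p[:, [n*c + n//2 for c in range(3)], :].reshape(1,3,win_size,win_size)
        post_color = color * weight_patch
        fig, ax = plt.subplots(1,3, figsize=(20, 20))
        ax[0].imshow(color.squeeze(0).permute(1,2,0).squeeze(2)); ax[0].set_title("color map")
        ax[1].imshow(weight_patch.view(win_size,win_size)); ax[1].set_title("weight map")
        ax[2].imshow(post_color.squeeze(0).permute(1,2,0)); ax[2].set_title("color after processing")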

# extraction of the central pixel
n = patch_size**2
color_center = [n*c + n//2  for c in range(3)]
synth = synth[:,color_center,:]
synth = synth.view(1,3,H,W)

# show green pixel of interest
x=interest_point[0]; y=interest_point[1]
synth[:,:,x,y] = torch.tensor([0.,1.,0.]).view(1,3)

#plt.figure(); plt.imshow(synth.squeeze(0).permute(1,2,0))


In [None]:
diff = (synth - img_tsr) # in [-1,1]
diff = 0.5 * torch.clamp(10*diff,-1,1) + 0.5 # amplified by 10, then mapped to [0,1]
res_NLmeans = torch.cat((img0_tsr,img_tsr,synth,diff),axis=3)

plt.figure(figsize=(20,20))
plt.imshow(res_NLmeans.squeeze(0).permute(1,2,0))
plt.title(" NL-means filtering")
plt.axis("off")

In [None]:
patch_NLM = synth[:,:,interest_point[0]-16:interest_point[0]+17, interest_point[1]-16:interest_point[1]+17].squeeze(0).clone() # clone() : do not modify synth in place
patch_NLM[:,16,16] = torch.tensor([0,1,0])
patch_NLM = patch_NLM.permute(1,2,0) 
plt.figure(figsize=(20,5))
plt.imshow(torch.cat((patch_ori,patch_gauss,patch_bilateral,patch_NLM), 1))

#### Experiments

- compute the MSE and the PSNR between the original image (ground truth `img0_tsr`) and the filtered image `synth`
- compare with bilateral filtering : what is the benefit of NL-means over bilateral filtering ?
- what are the effects of the two parameters `win_size` and `sigma` ?
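For the first item, a minimal sketch of the two metrics is given below. It assumes images with values in $[0,1]$, so that the peak value is 1 and $\mathrm{PSNR} = 10 \log_{10}(1 / \mathrm{MSE})$; the helper names `mse` and `psnr` are illustrative.

```python
import torch

def mse(ref, img):
    # mean squared error over all pixels and channels
    return torch.mean((ref - img)**2)

def psnr(ref, img, max_val=1.0):
    # peak signal-to-noise ratio in dB, for images with values in [0, max_val]
    return 10. * torch.log10(max_val**2 / mse(ref, img))

# toy check : a constant error of 0.1 gives MSE = 0.01, i.e. PSNR = 20 dB
ref = torch.zeros(1, 3, 8, 8)
img = torch.full((1, 3, 8, 8), 0.1)
print(mse(ref, img).item(), psnr(ref, img).item())
```

In the notebook, the call would be `psnr(img0_tsr, synth)`; note that the green marker written into `synth` slightly biases the value.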

## 4.2 fast Non-Local PCA

implement the fast NL-PCA filter

$$
    \hat p(x) = \bar p + \sum_{d=1}^D \rho \left( \langle p(x) - \bar p, v_d \rangle \right) v_d
$$
$$
\text{with } \rho (x) = \begin{cases} x & \text{if } |x| > \tau \\ 0 & \text{otherwise } \end{cases}
$$
where $\bar p$ is the mean patch and $v_d$ are the eigenvectors of the patch covariance matrix
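The projection, thresholding and reconstruction steps of this formula can be sketched as follows, here on a random image. This is an illustration under assumptions: a single global PCA over all patches, the covariance normalized by the number of patches, and `torch.linalg.eigh` for the eigendecomposition. The names `nl_pca_patches` and `tau` are illustrative.

```python
import torch
import torch.nn.functional as F

def nl_pca_patches(img, half_patch_size=1, tau=0.05):
    # img : 1 x C x H x W ; returns the filtered patches, shape 1 x d x N
    patch_size = 2 * half_patch_size + 1
    patch = F.unfold(img, kernel_size=patch_size, padding=half_patch_size)[0]  # d x N
    patch_mean = patch.mean(dim=1, keepdim=True)                  # mean patch, d x 1
    patch_centered = patch - patch_mean
    patch_cov = patch_centered @ patch_centered.T / patch.size(1)  # d x d
    # symmetric eigendecomposition : eigenvectors are the columns of eigvec
    eigval, eigvec = torch.linalg.eigh(patch_cov)
    coeff = eigvec.T @ patch_centered                              # <p(x) - mean, v_d>
    coeff = coeff * (coeff.abs() > tau)                            # hard threshold rho
    return (patch_mean + eigvec @ coeff).unsqueeze(0)

torch.manual_seed(0)
img = torch.rand(1, 3, 16, 16)
patches = F.unfold(img, kernel_size=3, padding=1)
rec_all = nl_pca_patches(img, half_patch_size=1, tau=0.0)    # every coefficient is kept
rec_none = nl_pca_patches(img, half_patch_size=1, tau=1e3)   # only the mean patch remains
print(rec_all.shape)
```

With `tau = 0` the patches are reconstructed exactly, since the eigenvectors form an orthonormal basis; denoising only comes from discarding the small coefficients.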

In [None]:
# parameters
half_patch_size = 1
patch_size = 2*half_patch_size+1
half_win_size = 5
win_size = 2*half_win_size+1

sigma = torch.tensor(5./255. * patch_size)

d = 3*patch_size**2 # patch dim
N = img_tsr.size(2) * img_tsr.size(3) # number of patches (= number of pixels with appropriate padding)

# overlapping patch decomposition (sliding window)
patch = torch.nn.Unfold(kernel_size=patch_size, dilation=1, padding=half_patch_size, stride=1)(img_tsr) # 1 x d x N (no batch)

# computes mean, covariance
patch_mean = ...
patch_centered = (patch - patch_mean) # broadcasting
patch_cov = ... # compute the covariance matrix

# compute PCA using torch.linalg.eigh (torch.symeig is deprecated and absent from recent PyTorch releases)
patch_eig, patch_vect = ... 

In [None]:
# different filtering function
identity =  lambda val, thres : val
hard_thres = lambda val, thres : val * (abs(val) > thres) 
soft_thres = lambda val, thres : torch.sign(val) * torch.max(abs(val) - thres, torch.tensor(0.)) 

In [None]:
x = torch.linspace(-3*sigma,sigma*3,1000)
t = torch.tensor(1.)
plt.figure()
plt.plot(x, identity(x,sigma), '--')
plt.plot(x, hard_thres(x,sigma), '-')
plt.plot(x, soft_thres(x,sigma), '--')
#plt.plot(patch_eig, -1 + torch.zeros_like(patch_eig), 'x')
plt.legend(['identity', 'hard_thres', 'soft_thres']) # append 'eigenvalues' if the previous line is uncommented

plt.figure()
plt.plot(patch_eig)
plt.title("eigen values")

val = patch_vect.transpose(1,0) @ patch_centered.squeeze(0)
val = val.view(-1)
val, _ = val.sort()
plt.figure()
plt.plot(val)
print(val.shape)
plt.title("projection on eigen vectors")

In [None]:
thres_fun = hard_thres

# perform filtering and reconstruction of patch using : patch_mean, patch_centered, patch_vect and thres_fun()
patch_NLPCA = ...

In [None]:
# aggregation of patches (actual averaging of patch) 
# sum overlapping patch
synth_NLPCA = torch.nn.Fold((img_tsr.size(2),img_tsr.size(3)), patch_size, dilation=1, padding=half_patch_size, stride=1)(patch_NLPCA)
# normalisation by number of overlapping patch
synth_NLPCA /= torch.nn.Fold((img_tsr.size(2),img_tsr.size(3)), patch_size, dilation=1, padding=half_patch_size, stride=1)(torch.ones_like(patch))
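The normalization step can be checked in isolation: `Fold` sums the overlapping patches, so folding a tensor of ones gives the number of patches covering each pixel. The sketch below, on a small random tensor with illustrative names, verifies that unfolding and then aggregating in this way gives back the input.

```python
import torch

torch.manual_seed(0)
img = torch.rand(1, 3, 8, 10)
patch_size, half_patch_size = 3, 1

unfold = torch.nn.Unfold(kernel_size=patch_size, padding=half_patch_size)
fold = torch.nn.Fold(output_size=(8, 10), kernel_size=patch_size, padding=half_patch_size)

patch = unfold(img)                      # 1 x 27 x 80
summed = fold(patch)                     # sum of the overlapping patches
count = fold(torch.ones_like(patch))     # number of patches covering each pixel
recon = summed / count                   # average over the overlaps

print(count.min().item(), count.max().item())
print((recon - img).abs().max().item())
```

The count is 9 in the interior and smaller near the borders, which is why a constant division by `patch_size**2` would darken the edges.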

In [None]:
diff = (synth_NLPCA - img_tsr) # in [-1,1]
diff = 0.5 * torch.clamp(10*diff,-1,1) + 0.5 # amplified by 10, then mapped to [0,1]
res_NLPCA = torch.cat((img0_tsr,img_tsr,synth_NLPCA,diff),axis=3)

plt.figure(figsize=(20,20))
plt.imshow(res_NLPCA.squeeze(0).permute(1,2,0))
plt.title(" NL-PCA filtering")
plt.axis("off")

In [None]:
patch_NLPCA = synth_NLPCA[:,:,interest_point[0]-16:interest_point[0]+17, interest_point[1]-16:interest_point[1]+17].squeeze(0).clone() # clone() : do not modify synth_NLPCA in place
patch_NLPCA[:,16,16] = torch.tensor([0,1,0])
patch_NLPCA = patch_NLPCA.permute(1,2,0) 
plt.figure(figsize=(20,5))
plt.imshow(torch.cat((patch_ori,patch_bilateral,patch_NLM,patch_NLPCA), 1))
plt.title("noisy / bilateral / NL means / NL PCA")

- Compute the PSNR and compare with the original NL-means filter : what do you notice ?
- Test various filtering functions, such as `soft_thres` (you can also use different activation functions from torch.nn; warning : the function has to be odd, i.e. f(-x) = -f(x))
- What happens if you use a different image to compute the PCA ?

______

# C - Auto-Encoder patch processing

The goal of this part is to replace the eigenvectors computed by PCA with two generic linear layers.
One obtains a shallow neural network that can be trained on the image.


- define a simple MLP that directly processes a batch of patches and outputs patch tensors with the same dimensions
- train this shallow neural network on the image with an MSE loss function using gradient descent
- aggregate the patches to create a new image
- test different neural network architectures and other losses
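A possible starting point for the steps above is sketched below on a random image. Everything in it is an assumption made for illustration: the hidden dimension of 8, the ReLU non-linearity, the Adam optimizer (plain gradient descent with `torch.optim.SGD` works as well, only more slowly), the learning rate and the number of steps.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
img = torch.rand(1, 3, 16, 16)            # stand-in for the noisy image
patch_size, half_patch_size = 3, 1

# patches as a batch of vectors : N x d
patch = F.unfold(img, kernel_size=patch_size, padding=half_patch_size)  # 1 x d x N
data = patch[0].T.contiguous()
d = data.size(1)

# shallow auto-encoder : two linear layers with a bottleneck (sizes are assumptions)
model = torch.nn.Sequential(
    torch.nn.Linear(d, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, d),
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

losses = []
for step in range(200):
    optimizer.zero_grad()
    loss = F.mse_loss(model(data), data)  # reconstruct each patch from itself
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

# aggregation : average the overlapping reconstructed patches
with torch.no_grad():
    patch_ae = model(data).T.unsqueeze(0).contiguous()   # 1 x d x N
    synth_ae = F.fold(patch_ae, (16, 16), patch_size, padding=half_patch_size)
    synth_ae = synth_ae / F.fold(torch.ones_like(patch_ae), (16, 16), patch_size,
                                 padding=half_patch_size)

print(losses[0], losses[-1], synth_ae.shape)
```

The bottleneck plays the role of the thresholding in NL-PCA: with a hidden dimension equal to `d` and no non-linearity, the network could simply learn the identity and would not denoise.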