<img src=https://brand.uark.edu/_resources/images/UA_Logo_Horizontal.jpg width="400" height="96">

### _Artificial intelligence for image processing and analysis._

# Notebook 2.2 Convolution and image filtering
---
##### The purpose of this notebook is to introduce the mathematical convolution operation and to show how it applies to image filtering and pattern recognition.



### Required packages
---
##### **_Run this code chunk first. If you encounter an error when trying to run code chunks in this notebook, then first try re-running this chunk._**

In [None]:
# Import all of the necessary packages
import numpy as np
import matplotlib.pyplot as plt
import imageio as io
import torch

# Convolution
---
##### One mathematical operation that is very important for image processing is the [_convolution_](https://en.wikipedia.org/wiki/Convolution). A standard definition of a _convolution_ is:
> "A mathematical operation on two functions that produces a third function that expresses how the shape of one is modified by the other."
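
Written out for two continuous functions $f$ and $g$, the definition is the integral below. For a discrete image $I$ and kernel $K$, the integral becomes a double sum over the kernel indices $m$ and $n$:

$$(f * g)(t) = \int_{-\infty}^{\infty} f(\tau)\, g(t - \tau)\, d\tau$$

$$(I * K)[i, j] = \sum_{m}\sum_{n} I[i - m,\, j - n]\, K[m, n]$$

The minus signs mean that a true convolution reflects the kernel (a 180° rotation in 2D) before sliding it. For kernels that are unchanged by that rotation, such as the Gaussian blur used below, the reflection makes no difference.
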
##### In image processing, a convolution modifies an _input image_ according to a _kernel_ (a small matrix of weights). During the operation, the _kernel_ is slid, or _convolved_, across the entire _input image_. At each position, the overlapping pixel values are multiplied by the kernel weights and summed to produce one pixel of the output image.
##### To perform the convolution operation in Python, the package `torch` (PyTorch) can be used. It will also be useful for neural networks, which are covered in more depth later in this module.
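`torch.nn.functional.conv2d` expects its input shaped (batch, channels, height, width) and its weights shaped (out_channels, in_channels / groups, kernel height, kernel width). The filtering cells below reshape NumPy arrays into those layouts with `.permute(...)` and `.unsqueeze(...)`. The sketch below mimics that reshaping with NumPy alone so the shapes can be inspected without PyTorch; the random image and its 300 × 451 size are stand-in assumptions, not one of the notebook's images.

```python
import numpy as np

# Stand-in for io.imread(...) / 255: a random H x W x 3 RGB image (assumed size)
height, width, kernel_size = 300, 451, 15
image = np.random.rand(height, width, 3)

# Mirrors .permute(2, 0, 1).unsqueeze(0): H x W x C -> 1 x C x H x W
nchw_image = np.transpose(image, (2, 0, 1))[np.newaxis]
print(nchw_image.shape)  # (1, 3, 300, 451)

# Mirrors .permute(2, 0, 1).unsqueeze(1) on a k x k x 3 filter stack:
# k x k x C -> C x 1 x k x k, i.e. one single-channel kernel per color channel
filter_stack = np.zeros((kernel_size, kernel_size, 3))
weights = np.transpose(filter_stack, (2, 0, 1))[:, np.newaxis]
print(weights.shape)  # (3, 1, 15, 15)

# With groups=3 and no padding, each channel is filtered on its own and
# shrinks by kernel_size - 1 pixels in each spatial dimension
out_shape = (1, 3, height - kernel_size + 1, width - kernel_size + 1)
print(out_shape)  # (1, 3, 286, 437)
```

Setting `groups=3` is what keeps the color channels separate: each output channel only sees the matching input channel.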
##### Below is an animated example of the convolution operation: the bottom blue grid is the input image, the small shaded window moving across it is the kernel, and the top green grid is the output of the convolution.
<img src="https://github.com/vdumoulin/conv_arithmetic/blob/master/gif/same_padding_no_strides.gif?raw=true" alt="same_padding_no_strides.gif">
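The sliding-window description above can be sketched directly in NumPy. The helper names `correlate2d_valid` and `convolve2d_valid` are hypothetical, written only for illustration. One detail worth knowing: PyTorch's `conv2d` computes cross-correlation (the kernel is not flipped), which matches a true convolution whenever the kernel is unchanged by a 180° rotation, as is the case for the blurring and sharpening kernels in this notebook. The small non-symmetric kernel below makes the difference visible.

```python
import numpy as np

def correlate2d_valid(image, kernel):
    """Slide the kernel over the image without flipping it (cross-correlation).

    Only positions where the kernel fits entirely inside the image are kept,
    which matches conv2d's default of no padding.
    """
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            # Multiply the kernel with the patch under it and sum the products
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

def convolve2d_valid(image, kernel):
    """True convolution: cross-correlation with the kernel rotated 180 degrees."""
    return correlate2d_valid(image, kernel[::-1, ::-1])

image = np.arange(16, dtype=float).reshape(4, 4)
kernel = np.array([[1.0, 0.0],
                   [0.0, -1.0]])

print(correlate2d_valid(image, kernel))  # 3 x 3 array, every entry -5.0
print(convolve2d_valid(image, kernel))   # 3 x 3 array, every entry 5.0
```

The double loop is written for clarity, not speed; libraries such as PyTorch perform the same computation far more efficiently.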


# Image Filtering
---
##### A common image processing task is to filter an image with a _kernel_. Typical examples are blurring an image, sharpening it, or reducing its noise. These kernels vary in size and in how strongly they filter; examples of each are shown below.
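The filtering cells below use no padding, so a k × k kernel drops k // 2 pixels from every border (the 15 × 15 blur shrinks the image by 14 pixels in each dimension). A common workaround, sketched here as an assumption rather than what this notebook does, is to pad the image first, for example by reflecting its edges with `np.pad(..., mode="reflect")`, so the output keeps the input's size. `valid_output_size` and `box_blur_same` are hypothetical helpers, and the blur assumes an odd kernel size.

```python
import numpy as np

def valid_output_size(size, kernel_size, padding=0, stride=1):
    """Output length along one axis for a sliding window (no dilation)."""
    return (size + 2 * padding - kernel_size) // stride + 1

def box_blur_same(image, kernel_size):
    """Average each pixel with its neighbors while keeping the image size.

    Assumes an odd kernel_size. Edges use reflect padding, an assumed choice;
    zero padding would darken the borders instead.
    """
    pad = kernel_size // 2
    padded = np.pad(image, pad, mode="reflect")
    out = np.zeros_like(image, dtype=float)
    for di in range(kernel_size):
        for dj in range(kernel_size):
            # Add the shifted copy of the image seen by this kernel tap
            out += padded[di:di + image.shape[0], dj:dj + image.shape[1]]
    return out / kernel_size**2

print(valid_output_size(451, 15))             # 437: no padding
print(valid_output_size(451, 15, padding=7))  # 451: "same" size

image = np.random.rand(40, 60)
print(box_blur_same(image, 5).shape)          # (40, 60)
```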

### Image blurring
---

In [None]:
# Load an example image
test_image = io.imread("imageio:chelsea.png") / 255

# Generate a 2D Gaussian filter (for image blurring)
kernel_size = 15
x = np.arange(0,kernel_size)
x = x - (kernel_size // 2)
xv, yv = np.meshgrid(x,x)
std = kernel_size/3
gaussian_kernel = np.exp(-(xv*xv + yv*yv) / (2*std*std))
gaussian_kernel = gaussian_kernel / np.sum(gaussian_kernel)

gaussian_filter = np.zeros((kernel_size,kernel_size,3))
gaussian_filter[:,:,0] = gaussian_kernel
gaussian_filter[:,:,1] = gaussian_kernel
gaussian_filter[:,:,2] = gaussian_kernel

# Convolve example image using pytorch (you do not have to concern yourself with how this code works)
torch_test_image = torch.tensor(test_image).permute(2,0,1).unsqueeze(0).float()
torch_filter = torch.tensor(gaussian_filter).permute(2,0,1).unsqueeze(1).float()

blur_image = torch.nn.functional.conv2d(torch_test_image,torch_filter,groups=3).squeeze().permute(1,2,0)

# Create a figure to display the...
fig = plt.figure(figsize=(20,20))

# Initial Image
ax = fig.add_subplot(1, 2, 1)
ax.set_title('Initial Image')
plot1 = plt.imshow(test_image)

# Blurred Image
ax = fig.add_subplot(1, 2, 2)
ax.set_title('Blurred Image')
plot2 = plt.imshow(blur_image)

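The 2D Gaussian built in the cell above is _separable_: it equals the outer product of a 1D Gaussian with itself, because $e^{-(x^2+y^2)/2\sigma^2} = e^{-x^2/2\sigma^2}\, e^{-y^2/2\sigma^2}$. The check below rebuilds the kernel both ways using the same `kernel_size` and `std` choices. Separability is why a Gaussian blur can be run as two 1D passes (about 2k multiplications per pixel instead of k × k), although this notebook does not rely on that.

```python
import numpy as np

kernel_size = 15
std = kernel_size / 3
x = np.arange(kernel_size) - kernel_size // 2

# 2D kernel built directly from a meshgrid, as in the blurring cell
xv, yv = np.meshgrid(x, x)
kernel_2d = np.exp(-(xv * xv + yv * yv) / (2 * std * std))
kernel_2d /= kernel_2d.sum()

# Same kernel from a normalized 1D Gaussian: outer(g, g)
g = np.exp(-(x * x) / (2 * std * std))
g /= g.sum()
kernel_outer = np.outer(g, g)

print(np.allclose(kernel_2d, kernel_outer))  # True
```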
### Image sharpening
---

In [None]:
# Load an example image
test_image = io.imread("imageio:chelsea.png") / 255

# Generate a sharpening filter (2 x identity - averaging filter)
kernel_size = 3

# Scaled identity filter (2 at the center, 0 elsewhere)
identity_kernel = np.zeros((kernel_size,kernel_size))
identity_kernel[kernel_size // 2, kernel_size // 2] = 2

# Averaging filter
average_kernel = np.ones((kernel_size,kernel_size))
average_kernel /= (kernel_size**2)

# Sharpen filter (its weights sum to 1, so overall brightness is preserved)
sharpen_kernel = identity_kernel - average_kernel

sharpen_filter = np.zeros((kernel_size,kernel_size,3))
sharpen_filter[:,:,0] = sharpen_kernel
sharpen_filter[:,:,1] = sharpen_kernel
sharpen_filter[:,:,2] = sharpen_kernel

# Convolve example image using pytorch (you do not have to concern yourself with how this code works)
torch_test_image = torch.tensor(test_image).permute(2,0,1).unsqueeze(0).float()
torch_filter = torch.tensor(sharpen_filter).permute(2,0,1).unsqueeze(1).float()

sharp_image = torch.nn.functional.conv2d(torch_test_image,torch_filter,groups=3).squeeze().permute(1,2,0)

# Clip values to the valid [0, 1] range; otherwise plt.imshow warns that it is clipping the data
sharp_image[sharp_image > 1] = 1
sharp_image[sharp_image < 0] = 0

# Create a figure to display the...
fig = plt.figure(figsize=(20,20))

# Initial Image
ax = fig.add_subplot(1, 2, 1)
ax.set_title('Initial Image')
plot1 = plt.imshow(test_image)

# Sharpened Image
ax = fig.add_subplot(1, 2, 2)
ax.set_title('Sharpened Image')
plot2 = plt.imshow(sharp_image)

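Because convolution is linear, filtering with `2 * identity - average` gives `2 * image - blurred`, which can be rewritten as `image + (image - blurred)`: the sharpened image is the original plus its fine detail. This formulation is often called unsharp masking. The NumPy check below uses a small random grayscale image and a hypothetical `correlate_valid` helper; since the identity kernel only keeps the center pixel, the "original" term is the image cropped by one pixel on each side.

```python
import numpy as np

def correlate_valid(image, kernel):
    """Cross-correlation with no padding (what conv2d computes)."""
    kh, kw = kernel.shape
    out = np.zeros((image.shape[0] - kh + 1, image.shape[1] - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

rng = np.random.default_rng(0)
image = rng.random((20, 30))

identity = np.zeros((3, 3))
identity[1, 1] = 2                   # scaled identity, as in the sharpening cell
average = np.full((3, 3), 1 / 9)     # 3 x 3 averaging kernel
sharpened = correlate_valid(image, identity - average)

# Unsharp masking: original + (original - blurred), cropped to the valid region
cropped = image[1:-1, 1:-1]
blurred = correlate_valid(image, average)
unsharp = cropped + (cropped - blurred)

print(np.allclose(sharpened, unsharp))  # True
print((identity - average).sum())       # 1 up to rounding: flat regions keep their brightness
```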
### Image de-noising
---

In [None]:
# Load an example image; keep an untouched copy and a separate copy to add noise to
test_image = io.imread("imageio:coins.png") / 255
initial_image = test_image.copy()
noisy_image = test_image.copy()

# Randomly generate noise in the image
number_of_noise_pixels = 1000
x_random = np.random.randint(0,high=test_image.shape[1],size=number_of_noise_pixels)
y_random = np.random.randint(0,high=test_image.shape[0],size=number_of_noise_pixels)

for i in range(number_of_noise_pixels):
  if np.random.random() > 0.5:
    noisy_image[y_random[i],x_random[i]] = 1
  else:
    noisy_image[y_random[i],x_random[i]] = 0

# The median is not a weighted sum, so this filter cannot be done with conv2d;
# instead, slide a kernel_size x kernel_size window manually and take its median
kernel_size = 3
kernel_length = kernel_size // 2

denoised_image = np.zeros((test_image.shape[0]-kernel_length*2,test_image.shape[1]-kernel_length*2))

for i in range(kernel_length,noisy_image.shape[0]-kernel_length):
  for j in range(kernel_length,noisy_image.shape[1]-kernel_length):
    # The + 1 makes each slice cover the full window (Python slices exclude the end index)
    denoised_image[i-kernel_length,j-kernel_length] = np.median(noisy_image[i-kernel_length:i+kernel_length+1,j-kernel_length:j+kernel_length+1])

# Create a figure to display the...
fig = plt.figure(figsize=(30,30))

# Initial Image
ax = fig.add_subplot(1, 3, 1)
ax.set_title('Initial Image')
plot1 = plt.imshow(initial_image,cmap='gray')

# Noisy Image
ax = fig.add_subplot(1, 3, 2)
ax.set_title('Noisy Image')
plot2 = plt.imshow(noisy_image,cmap='gray')

# De-noised Image
ax = fig.add_subplot(1, 3, 3)
ax.set_title('De-noised Image')
plot3 = plt.imshow(denoised_image,cmap='gray')

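The median filter slides a window just like a convolution, but it is not a convolution in the strict sense because the median is not a weighted sum. NumPy 1.20 and later provide `numpy.lib.stride_tricks.sliding_window_view`, which exposes every window as an extra pair of axes, so the double loop in the cell above can be replaced by a single call to `np.median`. The sketch below checks the vectorized version against an explicit loop on a small synthetic salt-and-pepper image; both helper names are hypothetical.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def median_filter_valid(image, kernel_size=3):
    """Median of every kernel_size x kernel_size window (no padding)."""
    windows = sliding_window_view(image, (kernel_size, kernel_size))
    # windows has shape (H - k + 1, W - k + 1, k, k); take the median of each window
    return np.median(windows, axis=(-2, -1))

def median_filter_loop(image, kernel_size=3):
    """Reference loop version; the + 1 makes each slice cover the full window."""
    r = kernel_size // 2
    out = np.zeros((image.shape[0] - 2 * r, image.shape[1] - 2 * r))
    for i in range(r, image.shape[0] - r):
        for j in range(r, image.shape[1] - r):
            out[i - r, j - r] = np.median(image[i - r:i + r + 1, j - r:j + r + 1])
    return out

rng = np.random.default_rng(1)
image = np.full((12, 16), 0.5)
# Sprinkle a few salt (1) and pepper (0) pixels over a flat gray image
noise_rows = rng.integers(0, 12, size=10)
noise_cols = rng.integers(0, 16, size=10)
image[noise_rows, noise_cols] = rng.integers(0, 2, size=10)

fast = median_filter_valid(image)
print(fast.shape)                                    # (10, 14)
print(np.allclose(fast, median_filter_loop(image)))  # True
```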
### Custom filters
---
##### In the following code chunk, you can create a custom filter and convolve it with any of the included images. Feel free to experiment; here are some example filters to get you started (decimal values are allowed).
1. Averaging filter (this will blur the image):
$\begin{bmatrix}
  \frac{1}{n} & \frac{1}{n} & \frac{1}{n} \\[0.3em]
  \frac{1}{n} & \frac{1}{n} & \frac{1}{n} \\[0.3em]
  \frac{1}{n} & \frac{1}{n} & \frac{1}{n}
\end{bmatrix} \text{where } n \text{ is the total number of elements in the matrix (9 in this case).} $

2. Inverted colors:
$\begin{bmatrix}
  0 & 0 & 0 \\[0.3em]
  0 & -1 & 0\\[0.3em]
  0 & 0 & 0
\end{bmatrix}$
3. Rough highlight of edges (this will be covered in the next notebook):
$\begin{bmatrix}
  3 & 0 & -3 \\[0.3em]
  0 & 0 & 0\\[0.3em]
  -3 & 0 & 3
\end{bmatrix}$
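
A quick way to predict what a kernel will do is to apply it to a simple test pattern. The sketch below (NumPy only, with a hypothetical `apply_kernel` helper) applies the three example kernels above to a uniform gray patch. The averaging kernel's weights sum to 1, so flat regions are unchanged; the edge kernel's weights sum to 0, so flat regions become 0; the inverted-colors kernel negates every pixel, which pushes values outside the displayable [0, 1] range. That last case is why the interactive cell below shifts and clips its output before displaying it.

```python
import numpy as np

def apply_kernel(image, kernel):
    """Cross-correlation with no padding: one output per fully covered position."""
    kh, kw = kernel.shape
    out = np.zeros((image.shape[0] - kh + 1, image.shape[1] - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

n = 9
averaging = np.full((3, 3), 1 / n)
inverted = np.array([[0, 0, 0], [0, -1, 0], [0, 0, 0]], dtype=float)
edges = np.array([[3, 0, -3], [0, 0, 0], [-3, 0, 3]], dtype=float)

flat = np.full((6, 6), 0.4)           # a uniform gray patch
print(apply_kernel(flat, averaging))  # all about 0.4: weights sum to 1
print(apply_kernel(flat, edges))      # all 0: weights sum to 0
print(apply_kernel(flat, inverted))   # all -0.4: outside the displayable range
```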

In [None]:
# Import packages used only in this cell
import ipywidgets as widgets
from IPython.display import clear_output, display

# Form for picking an image to load and changing the kernel size
# --- Form starts here --- #
#@title Play around with these settings to filter an image with a custom kernel. Change one of these settings to clear the output. { run: "auto" }
image_name = "astronaut.png" #@param ["astronaut.png","chelsea.png","coins.png","wikkie.png"]
kernel_size =  3#@param {type:"integer"}
# --- Form ends here --- #

# Load the sample image within imageio
sample_image = io.imread("imageio:" + image_name) / 255

# Custom kernel
items = [widgets.FloatText(value=0,layout=widgets.Layout(width='50%')) for i in range(kernel_size**2)]
items[int((kernel_size**2//2))].value = 1
items.append(widgets.Button(description="Convolve image!"))
grid = widgets.GridBox(items, layout=widgets.Layout(grid_template_columns=("repeat("+str(kernel_size) + ", 100px)")))
display(grid)
#output = widgets.Output()
#display(grid,output)

# Display the...
fig = plt.figure(figsize=(20,20))

# Initial image
fig.add_subplot(1,2,1)
ax1 = plt.gca()
ax1.set_title('Initial Image')
if len(sample_image.shape) == 3:
  plot1 = plt.imshow(sample_image)
else:
  plot1 = plt.imshow(sample_image,cmap='gray')

# Convolve the image with the custom kernel when the 'Convolve image!' button is clicked
def on_button_clicked(b):
  clear_output(wait=True)
  display(grid)
  
  # Convert input image for convolution
  if len(sample_image.shape) == 2:
    torch_image = torch.tensor(np.expand_dims(sample_image,2)).permute(2,0,1).unsqueeze(0).float()
  else:
    torch_image = torch.tensor(sample_image).permute(2,0,1).unsqueeze(0).float()

  # Extract all values from grid of custom kernel
  custom_kernel = np.zeros((kernel_size**2))
  for i in range(int(kernel_size)**2):
    custom_kernel[i] = items[i].value

  # Prepare custom filter
  custom_kernel = np.reshape(custom_kernel,(kernel_size,kernel_size))
  
  if torch_image.shape[1] == 3:
    custom_filter = np.zeros((kernel_size,kernel_size,3))
    for i in range(3):
      custom_filter[:,:,i] = custom_kernel
    torch_filter = torch.tensor(custom_filter).permute(2,0,1).unsqueeze(1).float()
  else:
    custom_filter = custom_kernel
    torch_filter = torch.tensor(custom_filter).unsqueeze(0).unsqueeze(1).float()

  filtered_image = torch.nn.functional.conv2d(torch_image,torch_filter,groups=torch_image.shape[1]).squeeze()

  if filtered_image.shape[0] == 3:
    filtered_image = filtered_image.permute(1,2,0)

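  # If the kernel has a negative mean (e.g. the inverted-colors kernel), most output
  # values are negative; shift the output so its mean brightness matches the input
  if np.mean(custom_kernel) < 0:
    filtered_image -= (torch.mean(filtered_image) - torch.mean(torch_image))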

  filtered_image[filtered_image>1] = 1
  filtered_image[filtered_image<0] = 0
  
  # Display the...
  fig = plt.figure(figsize=(20,20))

  # Initial image
  fig.add_subplot(1,2,1)
  ax1 = plt.gca()
  ax1.set_title('Initial Image')
  if len(sample_image.shape) == 3:
    plot1 = plt.imshow(sample_image)
  else:
    plot1 = plt.imshow(sample_image,cmap='gray')

  # Filtered image
  fig.add_subplot(1,2,2)
  ax2 = plt.gca()
  ax2.set_title('Filtered Image')
  if len(filtered_image.shape) == 3:
    plot2 = plt.imshow(filtered_image)
  else:
    plot2 = plt.imshow(filtered_image,cmap='gray')
  plt.show()
items[-1].on_click(on_button_clicked)
