# Homework 5 - Depth from Focus

The goal of this homework is to explore the focus properties of images captured by a device whose focus distance can be tuned (e.g. a DSLR camera). You will write a program that takes a sequence of images captured with different focus settings, uses these images to find the depth of each pixel in the scene, and then uses this depth map to estimate an all-in-focus image.

![](pics/example_scene.png)


Figure 1: An example scene to capture depth from focus

![](pics/focus_left.png)![](pics/focus_right.png)

Figure 2: Two images from a focal stack sequence of the scene from Figure 1. In the left image the foremost object is in focus (the cat). In the right image, the focus is in the background.

In [None]:
import os
import cv2
import numpy as np
import glob
from skimage import color
import skimage 
import operator
import scipy
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import matplotlib.cm as cm
from skimage.color import rgb2gray
import matplotlib.patches as patches
import skimage.morphology
from scipy import ndimage
from skimage import data, transform, exposure
from skimage.util import compare_images

from pathlib import Path

# Information on autoreload
%load_ext autoreload
%autoreload 2

# Import the code that you have to implement

# How to complete this assignment
<span style="color:white">*To complete the assignment, you will solve 6 problems, each of which consists of a coding task and a writing task that is documented in this notebook. To achieve a passing grade, you must successfully complete all of the coding tasks AND you must write up your results in a well-documented report.* </span>   

## <span style="color:red">Coding Tasks: </span>
You will implement functions in the following python files:
- src/code.py

In [None]:
import src.code as code # these are the functions you have to implement
import src.util as util # these are functions we will provide for you

## <span style="color:blue">Writing Tasks: </span>

All writing tasks necessary for a complete report are summarized at the end of this notebook.

Additional questions (marked with "*Questions*") are posed between the coding tasks. They are mainly meant to help you: try to answer them for yourself to familiarize yourself with a specific problem. In many cases, they also point you in the right direction.

**Important: Questions that are not posed (again) at the end of the notebook do not have to be answered in the report!**



# <span style="color:orange">Problem 1: Load dataset and explore </span>

In [None]:
path = 'Data//scene1//' # There are 2 more scenes available in the data folder

# Data downloaded from https://www.eecs.yorku.ca/~abuolaim/eccv_2018_autofocus/dataset.html

In [None]:
# Load dataset (implemented for you)

imgs,diopters = util.load_focal_stack_data(path,0.5)

In [None]:
N_imgs,dim_y,dim_x,_ = imgs.shape # (images, rows, columns, color channels)

print(imgs.shape)
print(diopters.shape)

### <span style="color:orange">Check - Load Data : Did you read the files in correct order?</span>
Reading filenames can be tricky, since different operating systems may return them in different orders. It is important that the focal stack is read in the correct order. One way to check this is to plot the loaded diopters and verify that they form a straight line.
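If the files turn out to be out of order, a common culprit is plain lexicographic sorting, which places `img_10` before `img_2`. The sketch below assumes, hypothetically, that the image filenames contain a numeric index; `natural_sort_key` and `is_monotonic` are illustrative helpers and not part of `util`. A monotonic diopter sequence is only a necessary condition, so the plot below should still show a straight line.

```python
import re
import numpy as np


def natural_sort_key(filename):
    """Split a filename into text and integer chunks so that 'img_2' sorts before 'img_10'."""
    return [int(chunk) if chunk.isdecimal() else chunk.lower()
            for chunk in re.split(r"(\d+)", filename)]


def is_monotonic(values):
    """True if the sequence is strictly increasing or strictly decreasing."""
    steps = np.diff(np.asarray(values, dtype=float))
    return bool(np.all(steps > 0) or np.all(steps < 0))


files = ["img_10.png", "img_2.png", "img_1.png"]
print(sorted(files))                        # ['img_1.png', 'img_10.png', 'img_2.png']
print(sorted(files, key=natural_sort_key))  # ['img_1.png', 'img_2.png', 'img_10.png']
```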

In [None]:
# Make sure that the plot shows a straight line! If not, the images were loaded in the wrong order. Order is important in this assignment!

plt.figure(figsize=(15,5))
plt.subplot(121)

plt.plot(diopters,'.')
plt.xlabel("Image Index")
plt.ylabel("Diopter")

plt.subplot(122)

# 1/diopters gives the focus distance in meters

focus_distances = 1/diopters
focus_distances = np.round(focus_distances, 3)
plt.plot(np.arange(0,len(focus_distances)),focus_distances,'.')

plt.title("Focus Distance - m")

plt.tight_layout()


util.save_fig_as_png('plot_focus_distance_and_diopter.png')

### <span style="color:orange">Check - Investigate the scene</span>

After plotting an example image (see below), look at the scene and ask yourself the following questions:
 - Which parts are in focus?
 - Which parts of the image contain high-frequency image features?
 - Where might the all-in-focus reconstruction be problematic?

In [None]:
# Let's plot an example of the scene
plt.figure(figsize=(15,15))
plt.imshow(imgs[0,:,:,:])
plt.title("Example Image")

In [None]:
print(imgs.shape)
print(diopters.shape)

### <span style="color:orange">Cut away data too close and too far</span>


In [None]:
# Discard a few images at the very beginning (focused very far away) and at the very end (focused close to the camera)

imgs = imgs[3:45,:,:,:]
diopters = diopters[3:45]


focus_distances = 1/diopters
print(focus_distances.shape)
focus_distances = np.round(focus_distances, 3)

### <span style="color:orange">Explore the focal stack</span>

You now have to implement the code.get_patch_ranges() function and choose a few reasonable image regions. See the comments in the function for further instructions.

In [None]:
my_patches = code.get_patch_ranges()

In [None]:
# Investigate my_patches a bit
print(my_patches[2,:])
print(my_patches.shape)

In [None]:
util.display_patches(imgs[0,:,:,:],my_patches)

util.save_fig_as_png('display_chosen_patches.png')

## Now display the different regions

First, we look at the focal stack of the complete image in a large grid view. The function util.plot_images_grid() shows the focal stack in an image grid.

In [None]:
# You might need to adapt the size so that it fits well
plt.figure(figsize=(15,9))
# This uses the default parameters, so it should display the complete (uncropped) images in a 3x3 grid
util.plot_images_grid(imgs,focus_distances)

util.save_fig_as_png("focal_stack_full")

### Now visualize the cropped regions of your patches so that you can explore the focal stack in more detail

Note: Make sure that plot_images_grid can handle all parameters, i.e. you should be able to run both the cell above and the cell below without changing any code.

In [None]:
m_plots = 5
n_plots = 4

for k in range(my_patches.shape[0]): # iterate over the patches (rows of my_patches)
    plt.figure(figsize=(15,15))
    xmin = my_patches[k,0]
    ymin = my_patches[k,2]
    xmax = my_patches[k,1]
    ymax = my_patches[k,3]
    util.plot_images_grid(imgs,focus_distances,m_plots,n_plots,xmin,xmax,ymin,ymax)

# <span style="color:orange">Problem 2: Correct for magnification</span>


You may have noticed that there is a small change in magnification that occurs as you change the focus. Once you have captured a focal stack of N images, you will need to compensate for this small change in magnification.

*Questions:* How can you see this in the images? What is an indicator for this?

The following snippet might help you realize this better:

In [None]:
# Visualize why you have to rescale

plt.figure(figsize=(20,20))
plt.subplot(121)

plt.imshow(imgs[0,:,:,:])
plt.title("Focused at " + str(focus_distances[0]) + ' m',fontsize=20)
plt.axis('off')

plt.subplot(122)

plt.imshow(imgs[-1,:,:,:])
plt.title("Focused at " + str(focus_distances[-1]) + ' m',fontsize=20)
plt.axis('off')

plt.tight_layout()

util.save_fig_as_png("two_images_at_different_focus_distances")


Let the focal stack be represented by $I(x,y,k)$, where $k \in [0,N-1]$ is the index into the focal stack and $(x,y)$ are pixel coordinates. Let $v_k$ be the focus distance (in meters) of image $k$, as determined in Problem 1. You can calculate the lens-to-sensor distance $u_k$ during each exposure using the Gaussian lens law. The camera used has a focal length of $f = 5.2$ mm.

<h1><center> $ \frac{1}{u_k} = \frac{1}{f} - \frac{1}{v_k}$ </center></h1>

## Derive the formula for $u_k$ and use it in code.calculate_magnification_factor().

<h1><center> $ u_k = \text{You have to derive this}$ </center></h1>

## With $u_k$ you can calculate the magnification factor for every focal distance as:

<h1><center> $ m_k = \frac{u_{N-1}}{u_k}$ </center></h1>

In [None]:
# Define a millimeter
mm = 1/1000.0
# Focal Length of 5.2 mm
focal_length = 5.2*mm 

magnifications = code.calculate_magnification_factor(diopters,focal_length)

#### Make sure that the plot below looks similar to the image provided in the examples

In [None]:
plt.plot(focus_distances,magnifications)
plt.xlabel("Focus distance [m]")
plt.ylabel("Magnification factor")

util.save_fig_as_png("focus_distance_vs_magnification")

### <span style="color:orange">Task: Implement a crop function</span>

First we will implement a function that crops a region of a given size from the center of an image.
This function will come in handy for the magnification correction: first we rescale the image, then we crop it back to the original resolution.
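A minimal sketch of the rescale-then-crop idea, assuming the magnification factor `m` is at least 1 so that the upscaled image can always be cropped back to its original size. The names `crop_center` and `scale_and_crop` are hypothetical and do not necessarily match the interface expected for `code.crop_image` or `code.scale_image_and_crop`; `scipy.ndimage.zoom` with `order=1` (bilinear interpolation) is just one possible resampling choice.

```python
import numpy as np
from scipy import ndimage


def crop_center(img, bounding):
    """Crop a region of size bounding = (height, width) from the center of img.

    Works for (H, W) and (H, W, C) arrays; axes beyond len(bounding) are kept whole."""
    start = tuple((dim - size) // 2 for dim, size in zip(img.shape, bounding))
    slices = tuple(slice(s, s + size) for s, size in zip(start, bounding))
    return img[slices]


def scale_and_crop(img, m):
    """Upscale an (H, W, C) image by m >= 1 and crop it back to (H, W, C)."""
    if m < 1:
        raise ValueError("this sketch only handles magnification factors >= 1")
    # Zoom the two spatial axes by m and leave the color axis untouched
    zoomed = ndimage.zoom(img, (m, m, 1), order=1)
    return crop_center(zoomed, img.shape[:2])


img = np.random.rand(40, 60, 3)
print(scale_and_crop(img, 1.1).shape)  # (40, 60, 3)
```

Upscaling and then cropping, rather than downscaling and padding, avoids introducing empty borders, which is why the sketch rejects factors below 1.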

In [None]:
test =  skimage.data.astronaut()

plt.imshow(test)
print(test.shape)

In [None]:
# code.crop_image crops a region of size `bounding` from the center of the image
bounding = (128,128)

In [None]:
test_cropped = code.crop_image(test, bounding)

# This image should be cropped exactly at the center of the image
print(test_cropped.shape)
plt.imshow(test_cropped)


### <span style="color:orange">Task: Implement scale image and crop in center</span>

In [None]:
# Equipped with code.crop_image, you can now implement code.scale_image_and_crop, which compensates for the slight change in magnification.
# How to implement this is up to you to figure out.

In [None]:
print(imgs[0,:,:,:].shape)
out = code.scale_image_and_crop(imgs[0,:,:,:],magnifications[0])
print(out.shape)

In [None]:
# Let's test if the magnification correction has worked

plt.figure(figsize=(20,22))

plt.subplot(131)

plt.imshow(imgs[0,:,:,:])
plt.title("Focused at " + str(focus_distances[0]) + 'm - Before correction',fontsize=20)
plt.axis('off')

plt.subplot(132)

plt.imshow(out)
plt.title("Focused at " + str(focus_distances[0]) + 'm - After correction',fontsize=20)
plt.axis('off')

plt.subplot(133)

plt.imshow(imgs[-1,:,:,:])
plt.title("Focused at " + str(focus_distances[-1]) + 'm',fontsize=20)
plt.axis('off')

plt.tight_layout()


util.save_fig_as_png("after_magnification_correction")

*Question for the report:* How do you check if the magnification correction has worked? At which parts in the image can this be seen?

### <span style="color:orange">Task: Implement a function that scales all images and crops them in the center</span>

In [None]:
# Now that we know how to correct one image, we have to implement a function that corrects for all images:
imgs_corrected = code.correct_focal_stack_scaling(imgs,diopters,focal_length)

In [None]:
print(imgs.shape) # Should be same size
print(imgs_corrected.shape) # should be same size

### <span style="color:orange">Task: Visualize the quality of the rescaling operations</span>

The following scikit-image example may serve as inspiration:
https://scikit-image.org/docs/dev/auto_examples/applications/plot_image_comparison.html
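One way to produce such comparisons is `skimage.util.compare_images`, which the linked example uses. The sketch below converts both images to grayscale first, since 2-D input is the safest choice across scikit-image versions; `compare_views` is a hypothetical helper and not necessarily how `util.plot_comparison` is implemented.

```python
import numpy as np
from skimage.color import rgb2gray
from skimage.util import compare_images


def compare_views(img1, img2, n_tiles=(8, 8)):
    """Checkerboard, absolute difference and 50/50 blend of two RGB images (grayscale output)."""
    g1, g2 = rgb2gray(img1), rgb2gray(img2)
    return {
        "checkerboard": compare_images(g1, g2, method="checkerboard", n_tiles=n_tiles),
        "diff": compare_images(g1, g2, method="diff"),
        "blend": compare_images(g1, g2, method="blend"),
    }


a = np.random.rand(64, 64, 3)
b = np.random.rand(64, 64, 3)
views = compare_views(a, b)
```

Misaligned edges show up as broken lines in the checkerboard view and as bright double contours in the difference view.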

### Checkerboard, difference and blend

In [None]:
img1 = imgs_corrected[0,:,:,:]
img2 = imgs_corrected[-1,:,:,:] # Index -1 selects the last image in the stack (NumPy negative indexing)

util.plot_comparison(img1, img2, method='checkerboard')
util.save_fig_as_png("checkerboard_comparison")

util.plot_comparison(img1, img2, method='diff')
util.save_fig_as_png("diff_comparison")

util.plot_comparison(img1, img2, method='blend')
util.save_fig_as_png("blend_comparison")


### <span style="color:orange">Task: Visualize the quality of the rescaling operations with methods from before</span>


In [None]:
m_plots = 5
n_plots = 4

for k in range(my_patches.shape[0]): # iterate over the patches (rows of my_patches)
    plt.figure(figsize=(15,15))
    xmin = my_patches[k,0]
    ymin = my_patches[k,2]
    xmax = my_patches[k,1]
    ymax = my_patches[k,3]
    util.plot_images_grid(imgs_corrected,focus_distances,m_plots,n_plots,xmin,xmax,ymin,ymax)

# <span style="color:orange">Problem 3: Focus Measure to calculate Depth-map</span>

Once you have corrected the focal stack, you will write a program to compute a depth map of the scene. We use the squared Laplacian as a focus measure. Our focus measure is:

<h1><center> $ M(x,y,k) = \sum\limits_{i = x - K}^{x + K} \sum\limits_{j = y - K}^{y + K} \left| \nabla^2 I'(i,j,k) \right|^2 $ </center></h1>

where $K$ is the half-width of the summation window, chosen based on the amount of texture in the scene. In other words, this operation convolves the squared Laplacian with a square (box-shaped) kernel of side length $2K+1$.

In principle, we could also use other structuring elements for the convolution, e.g. a disk, which has the advantage of being radially symmetric. Examples of different structuring elements are shown here: https://scikit-image.org/docs/dev/auto_examples/numpy_operations/plot_structuring_elements.html#sphx-glr-auto-examples-numpy-operations-plot-structuring-elements-py
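The window sum in $M(x,y,k)$ can be computed as a convolution with a binary footprint. A minimal sketch, assuming SciPy is available: `disk_footprint` follows the same $x^2 + y^2 \le r^2$ definition as `skimage.morphology.disk`, and `local_mean` divides by the number of pixels in the footprint, which turns the sum into an average without changing the arg max over $k$. All three helpers are hypothetical names.

```python
import numpy as np
from scipy import ndimage


def square_footprint(radius):
    """Binary square of side 2*radius + 1 (the window used in M with K = radius)."""
    return np.ones((2 * radius + 1, 2 * radius + 1), dtype=bool)


def disk_footprint(radius):
    """Binary disk of side 2*radius + 1: True where x^2 + y^2 <= radius^2."""
    y, x = np.mgrid[-radius:radius + 1, -radius:radius + 1]
    return x**2 + y**2 <= radius**2


def local_mean(focus, footprint):
    """Average `focus` over the footprint centered at every pixel."""
    kernel = footprint.astype(float)
    kernel /= kernel.sum()  # normalize: window sum -> window mean
    return ndimage.convolve(focus, kernel, mode="reflect")


# Usage: average a (stand-in) squared-Laplacian image over a disk with K = 5
focus = np.random.rand(64, 64)
smoothed = local_mean(focus, disk_footprint(5))
```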



The depth can then be calculated for each pixel by finding the index into the focal stack where the focus measure is maximal:

<h1><center> $ D(x,y) = \arg\max_{k} M(x,y,k) $ </center></h1>

Computing this equation tells you in which image of the stack a given pixel is in focus. To convert this index into an actual focus distance, look up the diopter value of that image from Problem 1 (focus distance = 1/diopter). Use the equations above to compute the depth of the scene from the focal stack you processed earlier. Here are a few guidelines.

1. You can implement the Laplacian operator in the focus measure above as a linear filter using the following kernel:
<br>
<h1><center>$L_{ij} = \frac{1}{6} \begin{bmatrix}
1 & 4 & 1\\ 
4 & -20 & 4 \\ 
1 & 4 & 1
\end{bmatrix} $ </center></h1>
<br>
    
You can then implement the Laplacian operator as a convolution with the captured image:
    
<h1><center>$
\nabla^2 I(x,y,k) = \sum\limits_{i = -1}^{1} \sum\limits_{j = -1}^{1} L(i,j)\cdot I(x-i,y-j,k)
$</center></h1>

**Note:** Do not write your own convolution code; use an existing convolution routine from one of the many Python packages that provide one (e.g. `scipy.ndimage.convolve` or `scipy.signal.convolve2d`).
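For example, the squared Laplacian can be sketched with `scipy.ndimage.convolve`. The boundary mode `'reflect'` is an assumption; other modes only change pixels near the border. The usage lines check the kernel on synthetic images: it gives zero on a linear intensity ramp, and exactly 2 on $x^2$ (so 4 after squaring).

```python
import numpy as np
from scipy import ndimage

# Laplacian kernel L from the text
LAPLACIAN = np.array([[1, 4, 1],
                      [4, -20, 4],
                      [1, 4, 1]], dtype=float) / 6.0


def squared_laplacian(gray):
    """Return the squared Laplacian of a 2-D grayscale image."""
    lap = ndimage.convolve(np.asarray(gray, dtype=float), LAPLACIAN, mode="reflect")
    return lap**2


# Usage on synthetic images (x = column index, y = row index)
yy, xx = np.mgrid[0:10, 0:12].astype(float)
ramp = squared_laplacian(3 * xx + 2 * yy)  # zero away from the border
parabola = squared_laplacian(xx**2)        # Laplacian of x^2 is 2, squared gives 4
```

Because the kernel is symmetric, convolution and correlation give the same result here, so routines that correlate rather than convolve would also work.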

In [None]:
L = code.get_laplacian()

plt.imshow(L)
plt.colorbar()
plt.title('Laplacian Kernel')

util.save_fig_as_png("laplacian_kernel")


In [None]:
img = imgs_corrected[0,:,:,:]
img.shape

In [None]:
plt.imshow(img)

In [None]:
K_disk = skimage.morphology.disk(25)
K_square = skimage.morphology.square(25)

plt.figure(figsize=(5,10))
plt.subplot(121)
util.plot_with_colorbar(K_disk)
plt.subplot(122)
util.plot_with_colorbar(K_square)
plt.tight_layout()

### <span style="color:orange">Task: Calculate the laplacian-image</span>

In [None]:
img_laplacian = code.filter_laplacian(rgb2gray(img))

In [None]:
plt.figure(figsize=(15,7))
# Display log(1 + x) to compress the dynamic range and enhance weak edges (no normalization is applied here)
plt.imshow(np.log(img_laplacian+1))
plt.colorbar()
plt.tight_layout()

util.save_fig_as_png("after_laplacian_filter")


### <span style="color:orange">Task: Filter the image with an averaging mask</span>

In [None]:
diameter = 50
out_filtered_square = code.filter_element(img_laplacian,diameter,element='square')
out_filtered_disk = code.filter_element(img_laplacian,diameter,element='disk')

In [None]:
plt.figure(figsize=(15,15))
plt.imshow(out_filtered_square)

util.save_fig_as_png("after_average_mask_square")


plt.figure(figsize=(15,15))
plt.imshow(out_filtered_disk)

util.save_fig_as_png("after_average_mask_disk")


### <span style="color:orange">Task: Apply a maximum-filter for the range</span>

In [None]:
# Let's stick with the square structuring element from before

out_maximum = code.filter_maximum(out_filtered_square,250)

plt.imshow(out_maximum)

util.save_fig_as_png("maximum_filter")


### <span style="color:orange">Task: Implement a function that performs the complete filtering</span>

Implement a function that combines the Laplacian filter and the structuring-element filtering. Also implement the norm_image() function in code.py.
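A minimal sketch of how the pieces could fit together, under these assumptions: the input is an RGB image, a `diameter` denotes the full width of the structuring element (radius `diameter // 2`), the averaging step uses the normalized footprint, and the maximum filter uses the same element type. `filter_images_sketch`, `make_footprint` and this `norm_image` are hypothetical and may differ from the interface expected in `code.py`.

```python
import numpy as np
from scipy import ndimage
from skimage.color import rgb2gray

# Laplacian kernel from Problem 3
LAPLACIAN = np.array([[1, 4, 1],
                      [4, -20, 4],
                      [1, 4, 1]], dtype=float) / 6.0


def make_footprint(diameter, element="square"):
    """Binary structuring element of side 2 * (diameter // 2) + 1."""
    r = diameter // 2
    if element == "square":
        return np.ones((2 * r + 1, 2 * r + 1), dtype=bool)
    if element == "disk":
        y, x = np.mgrid[-r:r + 1, -r:r + 1]
        return x**2 + y**2 <= r**2
    raise ValueError(f"unknown element: {element}")


def norm_image(img):
    """Min-max normalize an image to [0, 1]; a constant image maps to zeros."""
    img = np.asarray(img, dtype=float)
    lo, hi = img.min(), img.max()
    if hi == lo:
        return np.zeros_like(img)
    return (img - lo) / (hi - lo)


def filter_images_sketch(img_rgb, diameter_element, diameter_max, element="square"):
    """Grayscale -> squared Laplacian -> average over element -> maximum filter."""
    gray = rgb2gray(img_rgb)
    lap2 = ndimage.convolve(gray, LAPLACIAN, mode="reflect") ** 2
    kernel = make_footprint(diameter_element, element).astype(float)
    averaged = ndimage.convolve(lap2, kernel / kernel.sum(), mode="reflect")
    return ndimage.maximum_filter(averaged, footprint=make_footprint(diameter_max, element))
```

For large diameters such as 300, direct convolution with `ndimage.convolve` becomes slow; an FFT-based convolution such as `scipy.signal.fftconvolve` is a common alternative.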

In [None]:
diameter_element = 300
diameter_max = 50

img_idx = 30
output = code.filter_images(imgs_corrected[img_idx,:,:,:],diameter_element,diameter_max,element='disk')

plt.figure(figsize=(15,10))
plt.imshow(code.norm_image(output))

util.save_fig_as_png("complete_filtering")


### <span style="color:orange">Task: Filter the complete focal stack</span>

In [None]:
imgs_corrected.shape

In [None]:
# Choose the values that you found best from above

#diameter_element = 50
#diameter_max = 100

In [None]:
res = np.zeros(imgs_corrected.shape[0:3]) # One focus map per image; the color axis is dropped because we work on grayscale images from here on

for k in range(imgs_corrected.shape[0]):
    print("Process Image " + str(k))
    res[k,:,:] = code.filter_images(imgs_corrected[k,:,:,:],diameter_element,diameter_max,element='square')

In [None]:
res.shape

### <span style="color:orange">Task: Visualize the filter maps</span>

In [None]:
plt.figure(figsize=(15,10))
print(focus_distances.shape)
util.plot_filter_maps(res,focus_distances)
plt.tight_layout()

util.save_fig_as_png("filter_maps_all")


# <span style="color:orange">Problem 4: Extract the most focused points to create an All-In-Focus Image</span>

Let us recap what we have achieved so far:
1. We have corrected the focal stack for magnification.
2. We have applied a Laplacian filter and an averaging filter to the focal stack.  
  *Question:* Why?

The result is our focus metric, but we have not yet analyzed how well it actually works. That is what we investigate next.

First, we pick a few image coordinates at which we extract the focus metric along the focal stack.

Before you implement this, think about what you expect. Make sure you understand what you will plot and what result you expect; otherwise you may have trouble debugging your code!

In [None]:
# Method that picks several points on the image that are interesting to investigate more:
x,y = util.get_points_for_focus(imgs.shape[1],imgs.shape[2],3,3)

Next, a helper function overlays the chosen points on the image so that you can verify that they lie in interesting regions.

In [None]:
plt.figure(figsize=(15,10))
util.plot_chosen_points(img,x,y)

util.save_fig_as_png("plot_chosen_points")

Now we investigate how well the focus measure performs over the complete focal stack by plotting the focus-measure curves for the points you have just chosen.
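Extracting such curves amounts to a single NumPy fancy-indexing operation on the filtered stack `res` of shape (N, H, W). The sketch assumes the points are given as integer row (y) and column (x) indices; check which convention `util.get_points_for_focus` actually returns before reusing it. `focus_curves` is a hypothetical helper.

```python
import numpy as np


def focus_curves(focus_stack, rows, cols, normalize=False):
    """Focus measure of P pixels along an (N, H, W) stack, returned as an (N, P) array."""
    rows = np.asarray(rows, dtype=int)
    cols = np.asarray(cols, dtype=int)
    curves = focus_stack[:, rows, cols]
    if normalize:
        # Scale each curve to a peak of 1 so that peak positions are easy to compare
        peaks = curves.max(axis=0)
        curves = curves / np.where(peaks > 0, peaks, 1.0)
    return curves
```

The peak index of each curve, `np.argmax(curves, axis=0)`, is exactly the depth index $D(x,y)$ for that point, so a sharp, single peak indicates a reliable depth estimate.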

In [None]:
plt.figure(figsize=(15,10))
util.plot_focus_metric_distances(res,focus_distances, x,y)
plt.title("Focus Metric without normalization")

util.save_fig_as_png("plot_focus_metric")


*Question*: Analyze the line plots of the focus metric. Looking at the scene, do they make sense to you?

# <span style="color:orange">Problem 5: Recover an all-focus image of the scene</span>

![alt text](pics/depth_map.png)

Figure 3: A depth index map computed from the focal stack of Figure 2. The depth map was calculated using the equations above with K = 5.

Once you have computed a depth map of the scene, you can use it to recover an all-focus image of the scene. The all-focus image $A(x,y)$ can be computed simply as


 <h1><center> $ A(x,y) = I_{corr}(x,y,D(x,y))$ </center></h1>

![alt text](pics/all_in_focus.png)

Figure 4: An all-focus image computed from the focal stack of Figure 2 and the depth map of Figure 3.

### <span style="color:orange">Task: Recover the depth map</span>
Use the method described above; see the comments in code.compute_depth_map for more information.
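The arg max over the stack axis maps directly to `np.argmax`. A sketch, assuming the filtered stack has shape (N, H, W) and `focus_distances` holds one distance per image; `depth_from_focus` is a hypothetical name whose return order merely mirrors the call `code.compute_depth_map(res, focus_distances)` below.

```python
import numpy as np


def depth_from_focus(focus_stack, focus_distances):
    """Depth map from an (N, H, W) focus-measure stack.

    Returns (depth in meters, depth index D(x, y)), both of shape (H, W)."""
    depth_idx = np.argmax(focus_stack, axis=0)                   # index of the sharpest image per pixel
    depth = np.asarray(focus_distances, dtype=float)[depth_idx]  # look up that image's focus distance
    return depth, depth_idx
```

Pixels without texture have a flat focus curve, so their arg max is essentially arbitrary (an all-zero curve returns index 0); this is one reason such regions look noisy in the depth map.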

In [None]:
depth_map,depth_map_idx = code.compute_depth_map(res,focus_distances)

In [None]:
# Note: These images will probably not look exactly like the images provided in the examples!
# Don't try to reproduce them exactly: many different parameters are in play,
# and trying to match them exactly would only waste your time.
# You should get images that look conceptually the same, i.e. contain similar depth information.

plt.figure(figsize=(15,10))
util.plot_with_colorbar(depth_map,vmax=2)
plt.tight_layout()

util.save_fig_as_png("computed_depth_map")


plt.figure(figsize=(15,10))
util.plot_with_colorbar(depth_map_idx)
plt.tight_layout()

util.save_fig_as_png("computed_depth_map_indices")


### <span style="color:orange">Task: Calculate the All-in-Focus Image</span>
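The all-in-focus equation from Problem 5, $A(x,y) = I_{corr}(x,y,D(x,y))$, amounts to picking, for every pixel, the color value from the image selected by the depth index. A minimal sketch using NumPy advanced indexing, assuming the corrected stack has shape (N, H, W, C) and the index map has shape (H, W); `all_in_focus_sketch` is a hypothetical name.

```python
import numpy as np


def all_in_focus_sketch(stack, depth_idx):
    """A(x, y) = I_corr(x, y, D(x, y)) for an (N, H, W, C) stack and an (H, W) index map."""
    depth_idx = np.asarray(depth_idx, dtype=np.intp)
    rows, cols = np.indices(depth_idx.shape)  # per-pixel row and column coordinates
    return stack[depth_idx, rows, cols]       # shape (H, W, C)
```

The same indexing also works for a grayscale stack of shape (N, H, W), returning an (H, W) image.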

In [None]:
all_in_focus = code.combine_to_compute_all_in_focus(imgs_corrected,depth_map_idx)

In [None]:
plt.figure(figsize=(15,10))
plt.imshow(all_in_focus)

util.save_fig_as_png("all_in_focus_image")

# <span style="color:orange">Problem 6: Try a different dataset</span>

In [None]:
# Path
path = 'Data//scene6//' # There are 2 more scenes available in the data folder
imgs,diopters = util.load_focal_stack_data(path)
focus_distances = 1/diopters
print("Images loaded")

In [None]:
imgs_corrected = code.correct_focal_stack_scaling(imgs,diopters,focal_length)

res = np.zeros(imgs_corrected.shape[0:3]) # One focus map per image; the color axis is dropped because we work on grayscale images from here on
for k in range(imgs_corrected.shape[0]):
    print("Process Image " + str(k))
    res[k,:,:] = code.filter_images(imgs_corrected[k,:,:,:],diameter_element,diameter_max,element='square')

In [None]:
depth_map,depth_map_idx = code.compute_depth_map(res,focus_distances)
all_in_focus = code.combine_to_compute_all_in_focus(imgs_corrected,depth_map_idx)

In [None]:
plt.figure(figsize=(15,10))
plt.imshow(all_in_focus)







# <span style="color:blue">Writing Tasks: </span>

Please refer to the LaTeX template for your write-up. Your LaTeX report should document:

### <span style="color:blue">1. Abstract: </span>
3-4 sentences motivating the depth-of-field (DoF) problem and the solution approach introduced in this assignment.

### <span style="color:blue">2. Introduction: </span>
This section should answer: What is depth of field (DoF)? Why is a limited DoF problematic for some imaging applications? How can the DoF be extended by non-computational means? Explain the tradeoff between DoF and lateral resolution of an image taken with an optical system. Why is a limited DoF desired in portrait photography?

### <span style="color:blue">3. Methods: </span>
In this section you describe the solution approaches that you pursued to solve Problems 1-6.
Your descriptions should include the answers to the following questions:

**Problem 1:** Why does it make sense to visualize the focal stack in an image grid? Can you think of other useful visualization methods?

**Problem 2:** Why do you observe the mentioned change in magnification (explain with a formula)? At which parts of the images in the focal stack is this change obvious? Why do you need to correct for it, and what happens if you don't? What have you done to correct for the change in magnification in the focal stack? Describe the method in your own words. Did it work?
We introduced 3 different methods to visually check the magnification correction. Which method do you think is best? Explain why (Hint: look at spatial features)!

**Problem 3:** What happens if you convolve an image with a Laplacian kernel? How can this property be used to determine whether an image region is in focus or not? Can you think of a scenario where this approach fails?

**Problem 4:** What do you see in the focus metric? Do the results of your calculated focus metric make sense to you? Describe if your result is correct by using the example of two concrete points in the scene. Why did you pick exactly these two points? Is there also a point where the focus metric does not tell you much (failure case)?

**Problem 5:** What is a depth map and how does it help you to calculate an all-in-focus image? What did you do to calculate the all-in-focus image? Why did you choose this approach?

**Problem 6:** Why did you pick this particular dataset and not the other one? Did everything work as expected on the first attempt? If not, why?


### <span style="color:blue">4. Results: </span>
Not every image needs to be shown. Only show your best results or the results that give you the most insight into a specific topic you want to discuss. Choose the images you include carefully so that they form a coherent story.

### <span style="color:blue">5. Conclusion: </span> 
A very short summary of what you've learned and what you think about the assignment


