# HOMEWORK 2
For this homework you have to complete and implement colour balancing with:
* Gray world algorithm
* Scale-by-max algorithm

You are free to use your own images. Experiment with more images and think about the effect each of the algorithms has on the resulting (balanced) image.

### Colour Balancing
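In this notebook we show different types of colour balancing based on von Kries' hypothesis, which models a change of illuminant as an independent scaling of each colour channel (i.e. multiplication by a diagonal matrix).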

In [2]:
import cv2
import numpy as np
from matplotlib import pyplot as plt
# Default size (width, height in inches) for every figure created below.
plt.rcParams.update({'figure.figsize': [15, 5]})

In [None]:
img = cv2.imread('./data/sea.jpg')
# cv2.imread does not raise on a missing or unreadable file: it returns None,
# which would otherwise only surface later as an opaque error inside cvtColor.
if img is None:
    raise FileNotFoundError("Could not read image './data/sea.jpg'")
# OpenCV loads channels in BGR order; convert to RGB for matplotlib.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.imshow(img);

### White patch
In the white patch algorithm we choose a pixel, or a group of pixels, that we know should be white. We then scale each colour channel of the image so that this white patch becomes pure white.

In [None]:
# Define white patch and the coefficients.
# The reference is the square patch of side 2*radius+1 centred on (row, col);
# radius = 0 uses only the single pixel at (row, col).
row, col = 485, 864
radius = 0
patch = img[max(row - radius, 0):row + radius + 1,
            max(col - radius, 0):col + radius + 1]
if patch.size == 0:
    raise IndexError('White patch centre ({}, {}) lies outside the image'.format(row, col))
# Average colour of the patch (one value per R, G, B channel).
white = patch.reshape(-1, 3).mean(axis=0)
if np.any(white == 0):
    # A zero channel would give an infinite gain (and NaN for black pixels).
    raise ValueError('White patch at ({}, {}) has a zero channel: {}'.format(row, col, white))
coeffs = 255.0 / white

# Apply white balancing: broadcasting multiplies each colour channel by its own gain.
balanced = (img * coeffs).astype(np.float32)

# White patching does not guarantee that the dynamic range is preserved, so the
# image is normalised to [0, 1] and clipped.
balanced = np.clip(balanced / 255, 0.0, 1.0)

plt.subplot(121), plt.imshow(img)
plt.subplot(122), plt.imshow(balanced);

### Gray world
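This algorithm assumes that a scene, on average, is gray. Each colour channel is therefore scaled so that the mean values of the three channels become equal.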

In [None]:
# Load your image
img = cv2.imread('./data/unbalanced_light.png')
# cv2.imread returns None (rather than raising) when the file cannot be read.
if img is None:
    raise FileNotFoundError("Could not read image './data/unbalanced_light.png'")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# Compute the mean values for all three colour channels (red, green, blue)
# in a single reduction over all pixels.
mean_r, mean_g, mean_b = img.reshape(-1, 3).mean(axis=0)


def calibrate(image, pivot_channel=None):
    """Colour-balance an RGB image under the gray-world assumption.

    Gray world assumes the scene averages to gray, i.e. after scaling channel
    c by a gain k_c the three channel means must be equal:
    k_r * mean_r == k_g * mean_g == k_b * mean_b.
    That gives only 2 equations for 3 coefficients, so one coefficient (the
    pivot channel) is fixed to 1. The remaining two are then
    k_c = mean_pivot / mean_c.

    Parameters
    ----------
    image : array_like, shape (H, W, 3)
        RGB image; the result is clipped to [0, 255], so 8-bit-range values
        are expected.
    pivot_channel : int or None
        Index (0=R, 1=G, 2=B) of the channel whose coefficient is fixed to 1.
        If None, the channel with the largest mean is used, so every gain is
        >= 1 and the image is only brightened.

    Returns
    -------
    numpy.ndarray of int, shape (H, W, 3)
        Balanced image, truncated to integers and clipped to [0, 255].
    """
    image = np.asarray(image)
    means = [float(np.mean(image[:, :, c])) for c in range(3)]
    if pivot_channel is None:
        pivot_channel = means.index(max(means))

    # An all-zero channel has mean 0: keep its gain at 1 instead of producing
    # inf/NaN coefficients (its pixels stay 0 either way).
    ks = [means[pivot_channel] / m if m > 0 else 1.0 for m in means]
    # Multiplying by a diagonal matrix scales each channel independently
    # (von Kries); truncate to int and clip back into the displayable range.
    return np.clip((image @ np.diag(ks)).astype(int), 0, 255)


# Apply color balancing and generate the balanced image. Show the original and the balanced image side by side.
# The size is set on this figure only, instead of mutating plt.rcParams, so
# later cells keep the notebook-wide default figure size.
plt.figure(figsize=(15, 12))
panels = [
    ('Original', img),
    ('Gray world (pivot: highest-mean channel)', calibrate(img)),
    ('Gray world (pivot: green)', calibrate(img, 1)),
    ('Gray world (pivot: blue)', calibrate(img, 2)),
]
for position, (caption, picture) in enumerate(panels, start=1):
    plt.subplot(2, 2, position)
    plt.imshow(picture)
    plt.title(caption)

### Scale-by-max
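This is a straightforward algorithm that scales each colour channel by its maximum value. Note that it is sensitive to noise and saturation: a single bright or clipped pixel determines each channel's scale factor. If every channel already reaches 255, the image is left essentially unchanged.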

In [None]:
# Load your image
img = cv2.imread('./data/sea.jpg')
# cv2.imread returns None (rather than raising) when the file cannot be read.
if img is None:
    raise FileNotFoundError("Could not read image './data/sea.jpg'")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# Compute the maximum values for all three colour channels (red, green, blue)
max_r, max_g, max_b = img.reshape(-1, 3).max(axis=0)

print(max_r, max_g, max_b)

# Apply scale-by-max balancing. Dividing each channel by its own maximum maps
# that maximum to 1.0, which gives a float image in [0, 1] that imshow displays
# directly. If every channel already reaches 255, this is just img / 255.
scales = np.array([max_r, max_g, max_b], dtype=np.float64)
# An all-zero channel would otherwise get an infinite gain (0 * inf = NaN).
scales[scales == 0] = 1.0
balanced = img / scales

# Explicit figure size, independent of any rcParams changes made by earlier cells.
plt.figure(figsize=(15, 5))
plt.subplot(121), plt.imshow(img)
plt.subplot(122), plt.imshow(balanced);