In [9]:
from dataclasses import dataclass
from itertools import product
from pathlib import Path
from typing import Callable, NamedTuple, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import skimage as sk
import skimage.io as skio
from numpy import ndarray
from PIL import Image

# Input images

In [10]:
# Folders holding the input plates and receiving the generated images.
data_dir = Path("data")
extra_dir = Path("extra")
out_dir = Path("output")
adjust_dir = Path("output/adjust")
out_dir.mkdir(parents=True, exist_ok=True)

# Gather the input files: *.jpg and *.tif from data_dir, everything from extra_dir.
low_res_imgs = [*data_dir.glob("*.jpg")]
high_res_imgs = [*data_dir.glob("*.tif")]
extra_imgs = [*extra_dir.glob("*")]

# One count per collection, in the order low-res, high-res, extra.
print(f"number of images = {len(low_res_imgs)}")
print(f"number of images = {len(high_res_imgs)}")
print(f"number of images = {len(extra_imgs)}")

number of images = 3
number of images = 11
number of images = 4


# Aligning Channels

## Helper Functions

In [26]:
class Pixel(NamedTuple):
    """A (row, col) position; both coordinates default to 0."""

    row: int = 0
    col: int = 0


class Displacement(NamedTuple):
    """Start of each color as a row.

    Each field is the row index at which that colour channel starts in the
    stacked plate. Fields default to 0 so that ``Displacement()`` and
    ``Displacement(g=..., r=...)`` are both valid.
    """

    # Plain ints (not Pixel instances): the values are used directly as slice
    # bounds and are scaled arithmetically when moving between pyramid levels.
    r: int = 0
    g: int = 0
    b: int = 0


class Offset(NamedTuple):
    """Per-axis margin in pixels; sub_image shrinks its window by twice this amount."""

    row: int = 0  # margin along the row axis
    col: int = 0  # margin along the column axis


class ChannelSize(NamedTuple):
    """Height and width of a single colour channel, in pixels."""

    h: int = 0  # number of rows in one channel
    w: int = 0  # number of columns in one channel


class Pic:
    """Bundle an image array with a channel displacement and a channel size.

    Attributes:
        im: the image array as passed in.
        d: the Displacement associated with the image.
        ch_size: the ChannelSize associated with the image.
    """

    def __init__(
        self,
        im: ndarray,
        d: Displacement = Displacement(),
        ch_size: ChannelSize = ChannelSize(),
    ) -> None:
        # The arguments are stored as given; nothing is copied or validated.
        self.im, self.d, self.ch_size = im, d, ch_size

In [27]:
def est_channel_height(img: ndarray):
    """Estimate one channel's height as a third of the image's row count (floored)."""
    total_rows = img.shape[0]
    return total_rows // 3

def est_channel_width(img: ndarray):
    """Estimate one channel's width as the image's full column count."""
    total_cols = img.shape[1]
    return total_cols

def max_channel_height(img: "Pic"):
    """Return the largest height shared by all three channels of ``img``.

    If ``img`` exposes ready-made ``r``/``g``/``b`` arrays they are used as-is.
    Otherwise (the Pic class only stores ``im``, ``d`` and ``ch_size``) each
    channel is sliced out of ``img.im`` starting at the rows in ``img.d``.

    Args:
        img: object with either ``r``, ``g``, ``b`` arrays or ``im``, ``d``,
            ``ch_size`` attributes.

    Returns:
        The smallest row count among the three channels.
    """
    if all(hasattr(img, name) for name in ("r", "g", "b")):
        planes = (img.r, img.g, img.b)
    else:
        h, w = img.ch_size.h, img.ch_size.w
        # Slicing clips at the image edge, so a channel that runs off the
        # bottom is reported with its real (shorter) height.
        planes = tuple(
            img.im[start : start + h, :w] for start in (img.d.r, img.d.g, img.d.b)
        )
    return min(plane.shape[0] for plane in planes)

def max_channel_width(img: "Pic"):
    """Return the largest width shared by all three channels of ``img``.

    If ``img`` exposes ready-made ``r``/``g``/``b`` arrays they are used as-is.
    Otherwise (the Pic class only stores ``im``, ``d`` and ``ch_size``) each
    channel is sliced out of ``img.im`` starting at the rows in ``img.d``.

    Args:
        img: object with either ``r``, ``g``, ``b`` arrays or ``im``, ``d``,
            ``ch_size`` attributes.

    Returns:
        The smallest column count among the three channels.
    """
    if all(hasattr(img, name) for name in ("r", "g", "b")):
        planes = (img.r, img.g, img.b)
    else:
        h, w = img.ch_size.h, img.ch_size.w
        # Slicing clips at the image edge, so a requested width larger than
        # the image is reported as the real (narrower) width.
        planes = tuple(
            img.im[start : start + h, :w] for start in (img.d.r, img.d.g, img.d.b)
        )
    return min(plane.shape[1] for plane in planes)

def channels(img, disp: "Displacement", h: int, w: int) -> Tuple[ndarray, ndarray, ndarray]:
    """Cut the red, green and blue channel matrices out of a stacked plate.

    Args:
        img: 2-D array containing the three channels stacked vertically.
        disp: start row of each channel (``disp.r``, ``disp.g``, ``disp.b``);
            the rows are used directly as slice bounds and are expected to be
            non-negative.
        h: requested channel height in rows.
        w: requested channel width in columns (taken from column 0).

    Returns:
        ``(R_mat, G_mat, B_mat)``, all with the same shape. When a channel
        would run past the bottom of ``img`` all three are trimmed to the
        shortest one, so the result can always be stacked with ``np.dstack``.
    """
    planes = [img[start : start + h, :w] for start in (disp.r, disp.g, disp.b)]
    # Slices clip silently at the image edge; equalise the heights here rather
    # than let the caller fail on mismatched shapes.
    common_h = min(plane.shape[0] for plane in planes)
    R_mat, G_mat, B_mat = (plane[:common_h] for plane in planes)
    return R_mat, G_mat, B_mat

In [28]:
def sub_image(
    img: ndarray,
    start: Pixel = Pixel(),
    ch_size: ChannelSize = None,
    offset: Offset = Offset(),
    pad_val: float = 0,
) -> ndarray:
    """Return a sub-matrix extracted from ``img``.

    The window starts at ``start`` and measures
    ``(ch_size.h - 2 * offset.row, ch_size.w - 2 * offset.col)``.

    Args:
        img: source image; only its first two dimensions are used for bounds.
        start: top-left corner (row, col) of the window.
        ch_size: nominal channel size; when omitted it is estimated from
            ``img`` with ``est_channel_height`` / ``est_channel_width``.
        offset: margin subtracted twice from each dimension of ``ch_size``.
        pad_val: currently unused; kept so existing calls stay valid.

    Returns:
        The requested window as a slice of ``img`` (no copy is made), or
        ``None`` when the window is empty or does not fit inside ``img``.
    """
    # ch_size must be resolved before the window size can be computed.
    if ch_size is None:
        ch_size = ChannelSize(h=est_channel_height(img), w=est_channel_width(img))

    res_h = ch_size.h - 2 * offset.row
    res_w = ch_size.w - 2 * offset.col
    R, C = img.shape[:2]

    # NOTE(review): the window is anchored at `start`, so the whole 2*offset
    # reduction comes off the bottom/right; confirm whether it was meant to be
    # centred (start + offset) instead.
    row_end = start.row + res_h
    col_end = start.col + res_w

    # End indices are exclusive, so a window ending exactly at R / C is valid.
    fits = (
        res_h > 0
        and res_w > 0
        and start.row >= 0
        and start.col >= 0
        and row_end <= R
        and col_end <= C
    )
    if not fits:
        return None

    return img[start.row : row_end, start.col : col_end]

## Alignment Algorithms

### Basic

In [None]:
def align_basic(img: ndarray) -> Displacement:
    """Return the channel start rows obtained by simply dividing the image in 3.

    No search is performed: the green channel is assumed to start one
    estimated channel height down and the red channel two channel heights
    down; blue keeps the Displacement default.
    """
    ch_h = est_channel_height(img)  # floor division, so the rows are integers
    return Displacement(g=ch_h, r=ch_h * 2)

### SSD

In [21]:
def ssd(a: ndarray, b: ndarray) -> float:
    """Return the sum of squared differences between matrix ``a`` and matrix ``b``.

    Both inputs are converted to float64 first. With unsigned 8-bit images the
    subtraction and the squaring would otherwise wrap around modulo 256 and
    produce a meaningless score.
    """
    diff = np.asarray(a, dtype=np.float64) - np.asarray(b, dtype=np.float64)
    return float(np.sum(diff * diff))

### NCC

In [22]:
def ncc(a: ndarray, b: ndarray) -> float:
    """Return the normalized cross-correlation between matrix ``a`` and matrix ``b``.

    The score is the dot product of the two flattened inputs divided by the
    product of their Euclidean norms. Higher means more similar, so callers
    that pick the best candidate must maximise this score.

    Args:
        a, b: arrays of identical shape (checked with an assertion).

    Returns:
        The normalised dot product, or 0.0 when either input is all zeros.
    """
    assert a.shape == b.shape
    # float64 avoids integer overflow when the inputs are uint8 images.
    vec_a = np.asarray(a, dtype=np.float64).ravel()
    vec_b = np.asarray(b, dtype=np.float64).ravel()
    denom = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
    if denom == 0:
        return 0.0
    return float(vec_a @ vec_b / denom)

## Alignment Computations

In [24]:
def _best_row(img, ref, est_row, window, ch_size, offset, metric, use_min):
    """Return the start row within ``est_row ± window`` that scores best against ``ref``."""
    score = {}
    for d in range(-window, window + 1):  # +1 so that +window is searched too
        candidate = sub_image(img, Pixel(row=est_row + d, col=0), ch_size, offset)
        if candidate is None:
            # The shifted window falls outside the image; skip this shift.
            continue
        score[d] = metric(ref, candidate)

    if not score:
        return est_row

    # Select the shift (dictionary key) whose *score* is best, not the
    # smallest/largest key.
    pick = min if use_min else max
    return est_row + pick(score, key=score.get)


def align(
    pic: Pic,
    img: ndarray,
    dis_est: Displacement = None,
    window: int = 20,
    ch_size: ChannelSize = None,
    offset=Offset(row=20, col=20),
    metric: Callable = ssd,
    use_min=True,
) -> Displacement:
    """Search for the start rows of the G and R channels that best match B.

    Only vertical shifts are searched; every window is taken from column 0.

    Args:
        pic: not used by the search; kept so the signature is unchanged.
        img: 2-D plate with the three channels stacked vertically.
        dis_est: estimated start row of each channel. Defaults to B at row 0,
            G at one channel height and R at two channel heights.
        window: shifts in ``[-window, +window]`` around each estimate are tried.
        ch_size: nominal channel size; estimated from ``img`` when omitted.
        offset: margin handed to ``sub_image`` to shrink the compared window.
        metric: ``metric(reference, candidate)`` returning a comparable score.
        use_min: ``True`` when a lower score is better (e.g. SSD), ``False``
            when a higher score is better (e.g. a correlation).

    Returns:
        A Displacement with absolute start rows. B is left at its estimate. A
        channel keeps its estimate when none of its shifted windows fits
        inside ``img``; if the B reference window itself does not fit,
        ``dis_est`` is returned unchanged.
    """
    # if ch_size is not specified, initialize it
    if ch_size is None:
        ch_size = ChannelSize(h=est_channel_height(img), w=est_channel_width(img))

    # if dis_est is not specified, initialize it (NamedTuples are immutable,
    # so the estimate is built in one go)
    if dis_est is None:
        dis_est = Displacement(b=0, g=ch_size.h, r=ch_size.h * 2)

    b_ref = sub_image(img, Pixel(row=dis_est.b, col=0), ch_size, offset)
    if b_ref is None:
        return dis_est

    G_start = _best_row(img, b_ref, dis_est.g, window, ch_size, offset, metric, use_min)
    R_start = _best_row(img, b_ref, dis_est.r, window, ch_size, offset, metric, use_min)

    return Displacement(b=dis_est.b, g=G_start, r=R_start)

### Image Pyramid

In [None]:
def pyramid(
    img: ndarray, *args
) -> Displacement:
    """Coarse-to-fine alignment using an image pyramid.

    Images smaller than 1500 * 500 elements are aligned directly with
    ``align``. Larger images are halved, aligned recursively, and the coarse
    result is scaled back up and refined with a small search at this level.

    Args:
        img: 2-D plate with the three channels stacked vertically.
        *args: forwarded to ``align`` after ``(pic, img)``, i.e. in the order
            ``dis_est, window, ch_size, offset, metric, use_min``. On images
            large enough to be halved, ``dis_est`` is replaced by the estimate
            from the coarser level and ``window`` by a small refinement window.

    Returns:
        The Displacement found at the resolution of ``img``.
    """
    if img.size < 1500 * 500:
        return align(Pic(img), img, *args)

    # cv2.resize expects dsize as (width, height), i.e. (columns, rows).
    im_resize = cv2.resize(img, (img.shape[1] // 2, img.shape[0] // 2))

    # dis_est (index 0) and ch_size (index 2) are expressed in pixels of this
    # level, so they are cleared before descending to the half-size image.
    coarse_args = tuple(None if i in (0, 2) else arg for i, arg in enumerate(args))
    coarse = pyramid(im_resize, *coarse_args)

    # Scale the coarse start rows back to this level's row count.
    scale = img.shape[0] / im_resize.shape[0]
    estimate = Displacement(
        r=int(np.round(coarse.r * scale)),
        g=int(np.round(coarse.g * scale)),
        b=int(np.round(coarse.b * scale)),
    )

    # One pyramid step doubles the resolution, so a few rows either side of
    # the scaled estimate are enough to recover the lost precision.
    refine_window = 4
    return align(Pic(img), img, estimate, refine_window, *args[2:])

# Test and Display Results

In [29]:
def compute(img, out_dir: Path, algorithm, *args, show=False, adjust=False):
    """Align one plate, optionally display it, and save the colour result.

    Args:
        img: path of the input image (read as 8-bit greyscale).
        out_dir: directory that receives the result; created if missing.
        algorithm: callable invoked as ``algorithm(im, *args)`` that returns a
            Displacement with the start row of each channel.
        *args: extra positional arguments forwarded to ``algorithm``.
        show: when true, display the input plate and the aligned result.
        adjust: accepted but currently has no effect (see the note below).

    Returns:
        The aligned image as an ``H x W x 3`` array in R, G, B order.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    # read input file
    print(img.name)
    im = cv2.imread(str(img), cv2.IMREAD_GRAYSCALE)
    if im is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"could not read image: {img}")
    if show:
        plt.figure()
        plt.imshow(im, cmap=plt.get_cmap("gray"))

    # initialize variables
    h, w = est_channel_height(im), est_channel_width(im)

    # do adjustments
    if adjust:
        # NOTE(review): no adjustment is wired in yet. The candidates are
        # awb_grey, awb_white, fix_exposure and crop_borders.
        pass

    # compute displacements
    d = algorithm(im, *args)
    print(f'"({d.b}, 0), ({d.g}, 0), ({d.r}, 0)"')

    # combine channels and display result
    R_mat, G_mat, B_mat = channels(im, d, h, w)
    result = np.dstack([R_mat, G_mat, B_mat])
    if show:
        plt.figure()
        plt.title(f'"{R_mat.shape}, {G_mat.shape}, {B_mat.shape}"')
        plt.imshow(result)

    # save the image; the explicit suffix makes the file recognisable as PNG
    out_dir.mkdir(parents=True, exist_ok=True)
    fname = out_dir / f"{img.stem}.png"
    Image.fromarray(result).save(fname, "PNG")
    return result

## Testing low resolution images

In [None]:
if __name__ == "__main__":
    for img_path in low_res_imgs:
        # Baseline: split the plate into thirds without any search.
        compute(img_path, out_dir, align_basic)
        # Search-based alignment. pyramid() hands images below its size
        # threshold straight to align(), whose default metric is ssd.
        # Both calls name the output after img_path.stem, so this result
        # replaces the baseline file on disk.
        compute(img_path, out_dir, pyramid)

## Testing high resolution images

In [None]:
if __name__ == "__main__":
    for img_path in high_res_imgs:
        # pyramid() is the coarse-to-fine entry point; align() defaults to ssd.
        compute(img_path, out_dir, pyramid)

In [None]:
# NOTE(review): this cell repeats the previous high-resolution run.
if __name__ == "__main__":
    for img_path in high_res_imgs:
        # pyramid() is the coarse-to-fine entry point; align() defaults to ssd.
        compute(img_path, out_dir, pyramid)

## Testing extra images

In [None]:
if __name__ == "__main__":
    for img_path in extra_imgs:
        # pyramid() is the coarse-to-fine entry point; align() defaults to ssd.
        compute(img_path, out_dir, pyramid)

# Adjustments

## Normalize Exposures

In [None]:
def fix_exposure(mat: ndarray, show=False) -> ndarray:
    """Stretch ``mat`` linearly so its minimum becomes 0 and its maximum becomes 1.

    Args:
        mat: input matrix; it is not modified.
        show: when true, display the stretched matrix with ``plt.imshow``.

    Returns:
        A new float array with the stretched values. A constant input has no
        range to stretch and yields all zeros.
    """
    values = np.asarray(mat, dtype=float)
    lowest = np.amin(values)
    unit_len = np.amax(values) - lowest
    if unit_len == 0:
        # Avoid dividing by zero on a flat image.
        stretched = np.zeros_like(values)
    else:
        stretched = (values - lowest) / unit_len
    if show:
        plt.imshow(stretched)
    return stretched

## Crop Borders

In [None]:
def find_border(mat, axis=None):
    """Locate a row cutoff and a column cutoff from adjacent-line similarity.

    ``mat`` is zero-padded with one extra row at the bottom and one extra
    column at the right. Every row is then scored by its dot product with the
    next row, and every column by its dot product with the next column. The
    cutoff on each axis is the index just after the lowest-scoring line.

    NOTE(review): the padding makes the last row/column score exactly 0, so
    for a non-negative image with no earlier zero-scoring pair the cutoffs
    equal the full shape of ``mat``.

    Args:
        mat: non-empty 2-D array.
        axis: not used; optional so that ``find_border(mat)`` is a valid call.

    Returns:
        ``[r_cutoff, c_cutoff]`` as Python ints.
    """
    # float64 keeps the dot products from overflowing on uint8 images.
    values = np.asarray(mat, dtype=float)

    # zero pad mat at rightmost and bottom
    mat_padded = np.pad(values, ((0, 1), (0, 1)), mode="constant", constant_values=0)

    # find row border: score[i] = row i . row i+1
    row_scores = np.sum(mat_padded[:-1, :] * mat_padded[1:, :], axis=1)
    r_cutoff = int(np.argmin(row_scores)) + 1

    # find col border: score[j] = col j . col j+1
    col_scores = np.sum(mat_padded[:, :-1] * mat_padded[:, 1:], axis=0)
    c_cutoff = int(np.argmin(col_scores)) + 1

    return [r_cutoff, c_cutoff]

def crop_borders(mat, show=False):
    """Crop ``mat`` at the cutoffs reported by ``find_border``.

    Args:
        mat: 2-D array to crop; it is not modified.
        show: when true, display the cropped matrix with ``plt.imshow``.

    Returns:
        The part of ``mat`` before the row and column cutoffs.
    """
    r_cutoff, c_cutoff = find_border(mat, axis=None)
    # NOTE(review): keeps the rows/columns *before* each cutoff; confirm that
    # this is the side of the border that should survive.
    cropped = mat[:r_cutoff, :c_cutoff]
    if show:
        plt.imshow(cropped)
    return cropped

## Auto White Balance (AWB)

In [None]:
# Automatic (AWB)
# • Grey World: force average color of scene to grey
# • White World: force brightest object to white

def awb_grey(im, show=False):
    """Grey-world white balance: scale the image so its average colour is grey (0.5).

    For a 3-D image the average is taken per channel (over rows and columns)
    and each channel gets its own scale factor. A single factor for the whole
    image would change brightness only and leave the colour cast in place.
    Lower-dimensional input is scaled by one global factor.

    Args:
        im: image array; it is not modified.
        show: when true, display the balanced image with ``plt.imshow``.

    Returns:
        A new float array with the balanced values. Channels whose average is
        0 are left unscaled.
    """
    pixels = np.asarray(im, dtype=float)

    # Compute the mean color over the entire image
    if pixels.ndim == 3:
        avg_color = pixels.mean(axis=(0, 1), keepdims=True)
    else:
        avg_color = np.asarray(pixels.mean())

    # Scale the average color to be grey (0.5), guarding against a zero mean
    safe_avg = np.where(avg_color == 0, 1.0, avg_color)
    scaling = np.where(avg_color == 0, 1.0, 0.5 / safe_avg)

    # Apply the scaling to the entire image
    balanced_im = pixels * scaling
    if show:
        plt.imshow(balanced_im)
    return balanced_im

def awb_white(im, show=False):
    """White-world white balance: scale the image so its brightest colour is white (1.0).

    For a 3-D image the maximum is taken per channel (over rows and columns)
    and each channel gets its own scale factor. A single factor for the whole
    image would change brightness only and leave the colour cast in place.
    Lower-dimensional input is scaled by one global factor.

    Args:
        im: image array; it is not modified.
        show: when true, display the balanced image with ``plt.imshow``.

    Returns:
        A new float array with the balanced values. Channels whose maximum is
        0 are left unscaled.
    """
    pixels = np.asarray(im, dtype=float)

    # Compute the brightest color over the entire image
    if pixels.ndim == 3:
        brightest_color = pixels.max(axis=(0, 1), keepdims=True)
    else:
        brightest_color = np.asarray(pixels.max())

    # Scale the brightest color to be white (1.0), guarding against a zero maximum
    safe_max = np.where(brightest_color == 0, 1.0, brightest_color)
    scaling = np.where(brightest_color == 0, 1.0, 1.0 / safe_max)

    # Apply the scaling to the entire image
    balanced_im = pixels * scaling
    if show:
        plt.imshow(balanced_im)
    return balanced_im

## Apply Adjustments

In [None]:
# im = low_res_imgs[:1]
# plt.show(im)

# awb_grey(im, show=True)
# #awb_white(im, show=True)
# fix_exposure(im, show=True)
# #crop_borders(im, show=True)

# im_aligned = save_n_display(input, out_dir, align_ssd)

In [None]:
# compute() reads a single image path (it uses img.name and str(img)), so pass
# the first low-resolution image itself rather than a one-element list slice.
im_aligned = compute(low_res_imgs[0], out_dir, align_basic)
#awb_grey(im_aligned, show=True)
awb_white(im_aligned, show=True)
fix_exposure(im_aligned, show=True)
#crop_borders(im_aligned, show=True)