In [2]:
#!/usr/bin/env python3

rm: ./output/*: No such file or directory


In [63]:
from __future__ import annotations

import glob
import itertools
import os
import sys
from copy import deepcopy
from dataclasses import dataclass, field
from functools import lru_cache, reduce
from itertools import chain, product
from skimage import img_as_float,img_as_ubyte
from os import PathLike
from pathlib import Path
from typing import (Dict, Iterable, List, Mapping, NamedTuple, Optional,
                    Sequence)
import time
import cv2
import matplotlib.pyplot as plt
import numpy as np
from numpy import ndarray
from PIL import Image

## Declarations

### Global Variables

### Classes

In [4]:
# Type alias for a single colour channel (one third of the scanned plate).
# It is a plain numpy array; the pipeline treats it as 2-D (rows x cols),
# e.g. `R, C = mat.shape` in find_borders.
Channel = ndarray


@dataclass
class Pic:
    """The three colour channels of one plate, each held as a separate array.

    Note the positional field order is (red, blue, green); the call sites in
    this notebook construct it with keyword arguments.
    """

    red: Channel  # red plate (bottom third of the scan, see initialize_pic)
    blue: Channel  # blue plate (top third); used as the alignment reference
    green: Channel  # green plate (middle third)


@dataclass
class Offset:
    """A (row, col) displacement between two channels.

    Both components default to 0, so ``Offset()`` is the identity
    (no-shift) offset instead of raising ``TypeError`` for missing fields.
    """

    row: int = 0
    col: int = 0

    def __sub__(self, other: "Offset") -> "Offset":
        """Return the component-wise difference ``self - other``."""
        return Offset(self.row - other.row, self.col - other.col)


@dataclass
class Pixel:
    """A (row, col) position in an image; adding an Offset moves it."""

    row: int
    col: int

    def __add__(self, other: Offset) -> Pixel:
        """Return a new Pixel displaced by ``other``; ``self`` is left untouched."""
        moved_row = self.row + other.row
        moved_col = self.col + other.col
        return Pixel(moved_row, moved_col)

## Alignment Algorithms

In [5]:
def ssd(x: ndarray, y: ndarray) -> float:
    """Sum of squared differences between two equally-shaped arrays.

    Both inputs are promoted to float64 before subtracting. Without this,
    unsigned-integer inputs (e.g. the uint8 edge maps produced by
    ``cv2.Canny``) wrap around on subtraction and squaring, so a pixel pair
    (0, 255) would contribute 1 instead of 65025.
    """
    diff = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    return np.sum(diff * diff)

def gradient(x: ndarray, y: ndarray) -> float:
    """Score two images by the SSD of their Laplacians.

    Each input is filtered with ``cv2.Laplacian`` (64-bit float output) and
    the two second-derivative maps are compared with ``ssd``; a lower value
    means the edge structure of the two images agrees more closely.
    """
    lap_x = cv2.Laplacian(x, cv2.CV_64F)
    lap_y = cv2.Laplacian(y, cv2.CV_64F)
    return ssd(lap_x, lap_y)

In [6]:
def initialize_pic(img: ndarray) -> Pic:
    """Split a vertically stacked plate scan into its three channels.

    The rows are cut into three equal horizontal strips, ordered top to
    bottom as blue, green, red. Up to two trailing rows that keep the height
    from dividing by three are dropped from the bottom first.
    """
    usable_rows = len(img) - len(img) % 3
    blue, green, red = np.split(img[:usable_rows], 3)
    return Pic(red=red, blue=blue, green=green)


def calc_align_basic(
    reference_img: Channel,
    other: Channel,
    metric: Callable = ssd,
    offset_window: int = 15,
    r_estim: int = 0,
    c_estim: int = 0,
) -> Offset:
    """Brute-force search for the shift of ``other`` that best matches ``reference_img``.

    Every (row, col) shift within ``offset_window`` of the estimate
    ``(r_estim, c_estim)`` is tried by circularly rolling ``other`` with
    ``np.roll`` and scoring ``metric(reference_img, shifted)``. The lowest
    score wins; on ties the earliest candidate in row-major order is kept.

    Returns:
        Offset holding the absolute winning shift (estimate included).
    """
    span = range(-offset_window, offset_window + 1)
    candidates = [(r_estim + dr, c_estim + dc) for dr, dc in product(span, span)]
    scores = {
        shift: metric(reference_img, np.roll(other, shift=shift, axis=(0, 1)))
        for shift in candidates
    }
    best_row, best_col = min(scores, key=scores.get)
    return Offset(best_row, best_col)

def shift_2d_replace(data, offset: "Offset", constant=0):
    """Shift a 2-D array by ``offset`` without wrap-around.

    Behaves like ``np.roll`` along both axes, except that the rows/columns
    that would wrap around to the opposite edge are overwritten with
    ``constant``. ``np.roll`` returns a copy, so ``data`` is not modified.

    :param data: The 2d numpy array to be shifted
    :param offset: Shift to apply. ``offset.row`` moves content down
        (positive) or up (negative); ``offset.col`` moves content right
        (positive) or left (negative).
    :param constant: The constant to replace rolled values with
    :return: The shifted array with "constant" where roll occurs
    """
    dx, dy = offset.col, offset.row

    # Horizontal shift, then blank the columns that wrapped around.
    shifted_data: ndarray = np.roll(data, dx, axis=1)
    if dx < 0:
        shifted_data[:, dx:] = constant
    elif dx > 0:
        shifted_data[:, 0:dx] = constant

    # Vertical shift, then blank the rows that wrapped around.
    shifted_data = np.roll(shifted_data, dy, axis=0)
    if dy < 0:
        shifted_data[dy:, :] = constant
    elif dy > 0:
        shifted_data[0:dy, :] = constant
    return shifted_data


def merge(pic: "Pic") -> ndarray:
    """Stack the channels of ``pic`` into a single RGB array.

    The channels are stacked along a new last axis in (red, green, blue)
    order, so 2-D channels of shape (H, W) give an (H, W, 3) array. This
    returns a numpy array, not a PIL Image; wrap it with ``Image.fromarray``
    to get an image object.
    """
    return np.stack([pic.red, pic.green, pic.blue], axis=-1)

In [7]:
def calc_align_pyramid(
    to_align: Channel,
    ref: Channel,
    threshold: int = 800,
    refine_window: int = 2,
) -> Offset:
    """Estimate the shift between two channels with a coarse-to-fine image pyramid.

    The argument order mirrors the base case ``calc_align_basic(to_align, ref)``:
    the returned Offset is the shift to apply to ``ref`` (the second argument)
    so that it lines up with ``to_align``. NOTE(review): the parameter names
    read the other way round; ``align_pic`` calls this as
    ``calc_align(blue, red)`` and shifts the red channel by the result.

    Args:
        to_align: first channel, passed as ``reference_img`` to ``calc_align_basic``.
        ref: second channel, the one whose shift is searched.
        threshold: recursion stops once both channels have fewer than
            ``threshold ** 2`` pixels; that level is searched exhaustively.
        refine_window: half-width of the search window used at each finer
            level around the doubled coarse estimate. ``0`` reproduces plain
            doubling of the coarse estimate with no refinement.

    Returns:
        Offset in pixels of the input resolution.
    """
    if to_align.size < threshold ** 2 and ref.size < threshold ** 2:
        return calc_align_basic(to_align, ref)

    # cv2.resize expects dsize as (width, height) == (cols, rows).
    to_align_resized = cv2.resize(
        src=to_align, dsize=(to_align.shape[1] // 2, to_align.shape[0] // 2)
    )
    ref_resized = cv2.resize(src=ref, dsize=(ref.shape[1] // 2, ref.shape[0] // 2))
    coarse = calc_align_pyramid(to_align_resized, ref_resized, threshold, refine_window)

    # One pixel at the coarser level spans two pixels here, so the doubled
    # estimate is only accurate to about +-1 px; refine it at this resolution.
    return calc_align_basic(
        to_align,
        ref,
        offset_window=refine_window,
        r_estim=coarse.row * 2,
        c_estim=coarse.col * 2,
    )


def align_pic_basic(pic: Pic) -> (Pic, Offset, Offset):
    """Align red and green onto blue using the exhaustive window search.

    Returns ``(aligned_pic, red_offset, green_offset)`` as produced by ``align_pic``.
    """
    exhaustive_search = calc_align_basic
    return align_pic(pic, calc_align=exhaustive_search)


def align_pic_with_pyramid(pic: Pic) -> (Pic, Offset, Offset):
    """Align red and green onto blue using the image-pyramid search.

    Returns ``(aligned_pic, red_offset, green_offset)`` as produced by ``align_pic``.
    """
    pyramid_search = calc_align_pyramid
    return align_pic(pic, calc_align=pyramid_search)


# Duplicate `merge` definition removed: a function with identical code is
# already defined earlier in this notebook (cell In [6]); redefining it here
# only shadowed that one.


def align_pic(
    pic: Pic, calc_align: Callable, align_metric: Callable = gradient
) -> (Pic, Offset, Offset):
    """Shift the red and green channels of ``pic`` so they line up with blue.

    Args:
        pic: picture whose blue channel is the fixed reference.
        calc_align: search routine called as ``calc_align(blue, channel)``,
            returning the Offset to apply to ``channel``.
        align_metric: NOTE(review): accepted but never used; the metric is
            whatever ``calc_align`` uses internally. Kept for signature
            compatibility.

    Returns:
        ``(aligned_pic, red_offset, green_offset)``. ``aligned_pic`` is a new
        Pic whose blue array is the same object as ``pic.blue``.
    """
    reference = pic.blue

    def _shift_onto_reference(channel):
        # Find this channel's offset, then apply it without wrap-around.
        found = calc_align(reference, channel)
        return shift_2d_replace(channel, found), found

    red_aligned, red_offset = _shift_onto_reference(pic.red)
    green_aligned, green_offset = _shift_onto_reference(pic.green)
    return Pic(blue=reference, red=red_aligned, green=green_aligned), red_offset, green_offset

## Adjustments

In [8]:
def fix_exposure(x: ndarray, show=False):
    """Linearly stretch ``x`` so its minimum maps to 0 and its maximum to 1.

    Args:
        x: array of intensities.
        show: if True, also draw the result with ``plt.imshow``.

    Returns:
        Array of the same shape with values in [0, 1]. A constant input
        (max == min) has no range to stretch, so it is returned as all zeros
        rather than the NaNs that ``0 / 0`` would produce.
    """
    lo = np.amin(x)
    span = np.amax(x) - lo
    if span == 0:
        fixed = np.zeros_like(x, dtype=np.float64)
    else:
        fixed = (x - lo) / span
    if show:
        plt.imshow(fixed)
    return fixed

In [9]:
def awb_grey(im: ndarray, show=False):
    """Grey-world balance: rescale ``im`` so its mean value becomes 0.5.

    Args:
        im: array of intensities (one channel in this notebook).
        show: if True, also draw the result with ``plt.imshow``.

    Returns:
        ``im`` multiplied by ``0.5 / mean(im)``. When the mean is exactly 0
        there is nothing to scale against (the factor would be ``inf`` and
        the output NaN), so an unscaled float copy of ``im`` is returned.
    """
    avg_color = np.mean(im)
    if avg_color == 0:
        balanced_im = np.array(im, dtype=np.float64)
    else:
        # Scale the average intensity to mid-grey (0.5).
        balanced_im = im * (0.5 / avg_color)
    if show:
        plt.imshow(balanced_im)
    return balanced_im

def awb_white(im: ndarray, show=False):
    """White-patch balance: rescale ``im`` so its brightest value becomes 1.0.

    Args:
        im: array of intensities (one channel in this notebook).
        show: if True, also draw the result with ``plt.imshow``.

    Returns:
        ``im`` multiplied by ``1.0 / max(im)``. When the maximum is exactly 0
        (e.g. an all-black channel), ``1.0 / 0`` would give ``inf`` and turn
        pixels into NaN/inf, so an unscaled float copy of ``im`` is returned.
    """
    brightest_color = np.amax(im)
    if brightest_color == 0:
        balanced_im = np.array(im, dtype=np.float64)
    else:
        # Scale the brightest value to white (1.0).
        balanced_im = im * (1.0 / brightest_color)
    if show:
        plt.imshow(balanced_im)
    return balanced_im
    

In [10]:
def find_borders(mat) -> (int, int, int, int):
    """Estimate crop positions for the plate border on each side of one channel.

    For each row i, the score is the dot product of row i with row i-1
    (row 0 pairs with the last row). Columns are scored the same way. A cut
    is placed at the minimum score inside a search band near each edge:

    * top:    rows ``R//20 : R//4``; average of the argmins for raw
      intensities and for the vertical Sobel derivative.
    * bottom: rows ``3*(R//4)`` up to about ``R - R/15`` (raw) and
      ``R - R/20`` (Sobel), combined the same way.
    * left / right: columns ``C//20 : C//4`` and ``3*(C//4) : -C//20``,
      horizontal Sobel derivative only.

    Returns:
        ``(top, bottom, left, right)`` as absolute indices into ``mat``, suitable
        for ``mat[top:bottom, left:right]``. (Previously these were indices
        relative to each search band, which made ``bottom``/``right`` tiny.)
    """
    R, C = mat.shape
    sobel_c = cv2.Sobel(mat, cv2.CV_64F, 1, 0, ksize=5)
    sobel_r = cv2.Sobel(mat, cv2.CV_64F, 0, 1, ksize=5)
    # Promote so integer inputs cannot overflow in the dot products.
    mat_f = np.asarray(mat, dtype=np.float64)

    # score[i] = <line i, line i-1>; np.roll(..., 1) places line i-1 at index i.
    val_r = np.einsum("ij,ij->i", mat_f, np.roll(mat_f, 1, axis=0))
    val_rs = np.einsum("ij,ij->i", sobel_r, np.roll(sobel_r, 1, axis=0))
    val_c = np.einsum("ij,ij->j", sobel_c, np.roll(sobel_c, 1, axis=1))

    # find row border (both argmins in a pair share the same band start)
    top_lo, top_hi = R // 20, R // 4
    r_up_cutoff = top_lo + (
        np.argmin(val_r[top_lo:top_hi]) + np.argmin(val_rs[top_lo:top_hi])
    ) // 2
    bot_lo = R // 4 * 3
    r_bot_cutoff = bot_lo + (
        np.argmin(val_r[bot_lo:-R // 15]) + np.argmin(val_rs[bot_lo:-R // 20])
    ) // 2

    # find col border
    left_lo, left_hi = C // 20, C // 4
    c_left_cutoff = left_lo + np.argmin(val_c[left_lo:left_hi])
    right_lo = C // 4 * 3
    c_right_cutoff = right_lo + np.argmin(val_c[right_lo:-C // 20])
    return int(r_up_cutoff), int(r_bot_cutoff), int(c_left_cutoff), int(c_right_cutoff)

def crop_borders(pic:Pic):
    """Crop all three channels of ``pic`` in place to one shared window.

    ``find_borders`` is run on red, green and blue. For each side, the three
    estimates are averaged and truncated with ``int``, and every channel is
    then sliced with the same window.
    """
    estimates = [find_borders(channel) for channel in (pic.red, pic.green, pic.blue)]
    up_cut, bot_cut, left_cut, right_cut = (
        int(np.mean(side)) for side in zip(*estimates)
    )
    window = (slice(up_cut, bot_cut), slice(left_cut, right_cut))
    pic.red = pic.red[window]
    pic.green = pic.green[window]
    pic.blue = pic.blue[window]

In [27]:
def edge_detection(mat, low_threshold=80, high_threshold=180):
    """Return the Canny edge map of a single channel.

    Args:
        mat: channel image. OpenCV's ``Canny`` requires an 8-bit image, so
            float channels must be converted (e.g. ``img_as_ubyte``) first.
        low_threshold: lower hysteresis threshold for ``cv2.Canny``.
        high_threshold: upper hysteresis threshold for ``cv2.Canny``.
            The defaults keep the previously hard-coded 80 / 180.

    Returns:
        The edge map returned by ``cv2.Canny``.
    """
    return cv2.Canny(mat, low_threshold, high_threshold)


def align_pic_with_canny(
    pic: Pic, calc_align: Callable=calc_align_pyramid, align_metric: Callable = gradient
) -> (Pic, Offset, Offset):
    """Align red and green onto blue by matching Canny edge maps.

    Offsets are searched on the edge maps of each channel (so the channels
    must be acceptable to ``cv2.Canny``, i.e. 8-bit), but they are applied
    to the original channel data.

    Args:
        pic: picture whose blue channel is the fixed reference.
        calc_align: search routine called as ``calc_align(blue_edges, edges)``.
        align_metric: NOTE(review): accepted but never used; kept for
            signature compatibility.

    Returns:
        ``(aligned_pic, red_offset, green_offset)``.
    """
    red_edges, green_edges, blue_edges = (
        edge_detection(channel) for channel in (pic.red, pic.green, pic.blue)
    )

    def _align_on_edges(channel, channel_edges):
        # Search on edges, shift the real pixels.
        found = calc_align(blue_edges, channel_edges)
        return shift_2d_replace(channel, found), found

    red_aligned, red_offset = _align_on_edges(pic.red, red_edges)
    green_aligned, green_offset = _align_on_edges(pic.green, green_edges)
    return Pic(blue=pic.blue, red=red_aligned, green=green_aligned), red_offset, green_offset

In [28]:
def adjust(pic:Pic):
    """Normalise every channel of ``pic`` in place: grey-world, then min-max stretch.

    NOTE(review): for a channel with a positive mean, ``awb_grey`` only
    multiplies by a positive constant, and the following ``fix_exposure``
    min-max stretch cancels any such scaling. The grey-world step therefore
    has no effect on the final values beyond floating-point rounding.
    """
    #crop_borders(pic)
    channel_names = ("red", "green", "blue")
    for name in channel_names:
        setattr(pic, name, awb_grey(getattr(pic, name)))
    for name in channel_names:
        setattr(pic, name, fix_exposure(getattr(pic, name)))

## Compute and Output

In [64]:
def main(p: os.PathLike, f_out) -> None:
    """Process one plate scan: split, adjust, (optionally) align, save, log offsets.

    Args:
        p: path to one input image.
        f_out: writable text stream that receives one
            ``"<red_offset> <green_offset>"`` line for this image.

    The composited image is written as JPEG (quality 50) to
    ``output/<input stem>``; note that no file extension is appended.
    """
    out_dir = Path("output")
    out_dir.mkdir(parents=True, exist_ok=True)
    img_path = Path(p)

    # read image and convert to floating point
    img = img_as_float(plt.imread(img_path))

    # initialize and adjust pictures (adjust mutates pic in place)
    pic = initialize_pic(img)
    adjust(pic)

    # Default: direct stack, no alignment, zero offsets.
    aligned_pic = pic
    red_offset, green_offset = Offset(0, 0), Offset(0, 0)

    # To align instead, uncomment exactly one of:
    # aligned_pic, red_offset, green_offset = align_pic_basic(pic)
    # aligned_pic, red_offset, green_offset = align_pic_with_pyramid(pic)

    # Canny edge alignment (cv2.Canny needs 8-bit channels):
    # pic.red = img_as_ubyte(pic.red)
    # pic.green = img_as_ubyte(pic.green)
    # pic.blue = img_as_ubyte(pic.blue)
    # aligned_pic, red_offset, green_offset = align_pic_with_canny(pic)

    aligned_img = Image.fromarray(img_as_ubyte(merge(aligned_pic)))
    aligned_img.save(out_dir / img_path.stem, format="jpeg", optimize=True, quality=50)
    print(red_offset, green_offset, file=f_out)

    # plt.figure()
    # plt.imshow(aligned_img)
    

In [65]:
# Remove files left in ./output by a previous run. Unlike `!rm ./output/*`,
# this is a silent no-op when the directory is missing or empty, and it is
# plain Python, so it does not depend on a POSIX shell.
for stale_file in Path("output").glob("*"):
    if stale_file.is_file():
        stale_file.unlink()

In [66]:
if __name__ == "__main__":
    from contextlib import redirect_stdout

    data = Path("data")
    fname = 'offset_low.txt'
    # Process every .jpg under data/ in sorted (reproducible) order. The
    # `with` block guarantees the log file is closed and sys.stdout is
    # restored even if one image raises; stray prints also land in the log.
    with open(fname, 'w') as f_out, redirect_stdout(f_out):
        start = time.perf_counter()
        for p in sorted(data.rglob("*.jpg")):
            print(p, file=f_out)
            main(p, f_out)
        elapsed = time.perf_counter() - start
        print(f'"runtime is {elapsed} seconds"', file=f_out)

In [67]:
if __name__ == "__main__":
    from contextlib import redirect_stdout

    data = Path("data")
    fname = 'offset_high.txt'
    # Process every .tif under data/ in sorted (reproducible) order. The
    # `with` block guarantees the log file is closed and sys.stdout is
    # restored even if one image raises; stray prints also land in the log.
    with open(fname, 'w') as f_out, redirect_stdout(f_out):
        start = time.perf_counter()
        for p in sorted(data.rglob("*.tif")):
            print(p, file=f_out)
            main(p, f_out)
        elapsed = time.perf_counter() - start
        print(f'"runtime is {elapsed} seconds"', file=f_out)