In [34]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import os

def translate_no_pad(im, translation):
    """
    Shift an image by a (down, right) offset while keeping the canvas size.

    :param im: image array; only its first two dimensions (height, width) are read here
    :param translation: pair `(ty, tx)` -- pixels to move down and pixels to move right
    :return: the output of `cv2.warpAffine`, sized to the input's (width, height)
    """
    shift_down, shift_right = translation
    height, width = im.shape[:2]
    # 2x3 affine matrix: identity in the left 2x2 block, translation in the last column.
    affine = np.array(
        [[1, 0, shift_right],
         [0, 1, shift_down]],
        dtype=np.float32,
    )
    return cv2.warpAffine(im, affine, (width, height))

In [91]:
# Directory prefixes (trailing slash included) for input images and generated outputs.
# NOTE(review): read_img and save_img declare the same strings as their own parameter
# defaults and do not read these globals; `out_dir` is reassigned by later cells.
img_dir = 'data/'
out_dir = 'out_path/'

def read_img(img_name, img_dir = 'data/'):
    """
    Load an image from disk and scale its pixel values to floats in [0, 1].

    :param img_name: file name of the image
    :param img_dir: directory prefix (including trailing separator) prepended to `img_name`
    :return: floating-point array with values in [0, 1]
    """
    img = plt.imread(img_dir + img_name)
    if np.issubdtype(img.dtype, np.integer):
        # Normalise by the largest value the stored integer type can hold
        # (255 for 8-bit data, 65535 for 16-bit data) so the maximum maps to exactly 1.0.
        return img / np.iinfo(img.dtype).max
    # Floating-point data (matplotlib returns PNGs this way) is already in [0, 1].
    return img.astype(np.float64)

def translate_circ(img, translation):
    """
    Circularly shift an image: content pushed past one edge re-enters at the opposite edge.

    :param img: array with at least two dimensions
    :param translation: `(dh, dw)` -- shift along axis 0 (rows) and along axis 1 (columns)
    :return: the rolled array
    """
    dh, dw = translation[0], translation[1]
    return np.roll(img, shift=(dh, dw), axis=(0, 1))

def split_into_channels(img):
    """
    Cut a vertically stacked plate into three equal-height strips.

    The top strip is returned as blue, the middle as green and the bottom as red.
    Rows left over when the height is not a multiple of three are dropped.

    :param img: array whose first axis is the stacked height
    :return: `(r, g, b)` strips
    """
    strip_h = img.shape[0] // 3
    blue = img[:strip_h]
    green = img[strip_h:2 * strip_h]
    red = img[2 * strip_h:3 * strip_h]
    return red, green, blue

def ssd(img1, img2):
    """
    Sum of squared differences between two images (lower means more similar).

    :param img1: first array
    :param img2: second array, broadcast-compatible with `img1`
    :return: scalar sum of the element-wise squared differences
    """
    return np.sum(np.square(img1 - img2))

def ncc(img_1, img_2):
    """
    Negated normalized cross-correlation of two images.

    Each image is mean-centred and scaled to unit L2 norm before the element-wise
    product is summed. The result is negated so that, like `ssd`, a lower score
    means a better match (-1 for perfectly correlated inputs, +1 for perfectly
    anti-correlated ones).

    A constant image has zero norm after centring, so the division yields NaN.

    :param img_1: first array
    :param img_2: second array with the same shape as `img_1`
    :return: scalar score in [-1, 1]
    """
    n1 = img_1 - np.mean(img_1)
    n2 = img_2 - np.mean(img_2)
    n1 = n1 / np.sqrt(np.sum(n1**2))
    n2 = n2 / np.sqrt(np.sum(n2**2))
    # Correlate the normalised arrays, not the raw inputs.
    return -np.sum(n1 * n2)
    
def naive_align_two_channels(c1, c2, metric=ssd, min_tr_h=-15, max_tr_h=15, min_tr_w=-15, max_tr_w=15, translate=translate_circ):
    """
    Exhaustively search a window of shifts for the one that best aligns `c1` to `c2`.

    Every `(dh, dw)` with `min_tr_h <= dh <= max_tr_h` and `min_tr_w <= dw <= max_tr_w`
    is applied to `c1` and scored against `c2`; the lowest score wins.

    :param c1: channel that gets shifted
    :param c2: reference channel
    :param metric: callable `metric(a, b)` returning a score where lower is better
    :param translate: callable `translate(img, (dh, dw))` applying a shift
    :return: `(best_shift, best_score)`; `((0, 0), inf)` if no candidate scores below infinity
    """
    best_shift = (0, 0)
    best_score = float('inf')
    row_shifts = range(min_tr_h, max_tr_h + 1)
    col_shifts = range(min_tr_w, max_tr_w + 1)
    for dh in row_shifts:
        for dw in col_shifts:
            candidate = (dh, dw)
            score = metric(translate(c1, candidate), c2)
            # Strict comparison: ties keep the earliest candidate and NaN scores never win.
            if not score < best_score:
                continue
            best_shift, best_score = candidate, score
    return best_shift, best_score

def stack_channels(r, g, b):
    """
    Combine three channel arrays depth-wise, in R, G, B order.

    :param r: red channel
    :param g: green channel
    :param b: blue channel
    :return: array produced by `np.dstack` (shape `(h, w, 3)` for 2-D inputs)
    """
    return np.dstack((r, g, b))

def process_hard_coded(r, g, b, trans_r, trans_g, translate=None):
    """
    Align the red and green channels to blue using caller-supplied translations.

    :param r: red channel
    :param g: green channel
    :param b: blue channel (left unshifted)
    :param trans_r: `(dh, dw)` shift applied to `r`
    :param trans_g: `(dh, dw)` shift applied to `g`
    :param translate: callable `translate(img, (dh, dw))`; defaults to `translate_circ`
    :return: stacked RGB image built from the shifted channels
    """
    # Resolve the default at call time so the function does not depend on a
    # notebook-global `translate` having been assigned by some other cell.
    if translate is None:
        translate = translate_circ
    r = translate(r, trans_r)
    g = translate(g, trans_g)
    return stack_channels(r, g, b)

def crop_border(img, ratio = 0.2):
    """
    Trim a fixed fraction off every side of an image.

    `ratio` is the total fraction removed per axis, so half of it is cut from
    each side (the default removes 10% from top, bottom, left and right).

    :param img: array whose first two axes are height and width
    :param ratio: total fraction of each axis to discard
    :return: the central region of `img`
    """
    h, w = img.shape[:2]
    near = ratio / 2
    far = 1 - ratio / 2
    top, bottom = int(h * near), int(h * far)
    left, right = int(w * near), int(w * far)
    return img[top:bottom, left:right]

def process_single_scale(img, metric = None, precrop = False, translate = translate_circ):
    """
    Align an unsplit plate with a single-scale exhaustive search.

    :param img: unsplit image containing the three stacked channels
    :param metric: scoring callable for the search; when falsy no alignment is done
    :param precrop: if true, the search runs on border-cropped copies of the channels
        (the resulting shifts are still applied to the uncropped channels)
    :param translate: callable `translate(img, (dh, dw))` used for searching and for the final shift
    :return: `(aligned_image, trans_r, trans_g)`; both translations are `(0, 0)` when `metric` is falsy
    """
    r, g, b = split_into_channels(img)
    if not metric:
        return stack_channels(r, g, b), (0, 0), (0, 0)

    if precrop:
        r_ref, g_ref, b_ref = (crop_border(channel) for channel in (r, g, b))
    else:
        r_ref, g_ref, b_ref = r, g, b

    # Blue is the reference; red and green are each aligned to it independently.
    trans_r = naive_align_two_channels(r_ref, b_ref, metric=metric, translate=translate)[0]
    trans_g = naive_align_two_channels(g_ref, b_ref, metric=metric, translate=translate)[0]

    r_shifted = translate(r, trans_r)
    g_shifted = translate(g, trans_g)
    return stack_channels(r_shifted, g_shifted, b), trans_r, trans_g

def save_img(fname, img_out, out_dir = 'out_path/'):
    """
    Write an image to `out_dir + fname` after converting it to 8-bit.

    :param fname: output file name
    :param img_out: image array; values are multiplied by 255 and cast to uint8
    :param out_dir: directory prefix (including trailing separator)
    """
    as_bytes = (img_out * 255).astype(np.uint8)
    target = out_dir + fname
    plt.imsave(target, as_bytes)

In [97]:
### Multiscale Pyramid Algorithm
import skimage
import math

def compute_pyramid_alignment(img1, img2, metric, refine=2, translate=translate_circ):
    """
    Estimate the `(dh, dw)` shift aligning `img1` to `img2` with a coarse-to-fine pyramid.

    The images are halved repeatedly until the larger side falls between
    `min_size` and `2 * min_size` pixels. That coarsest level is searched over
    `start_search_space` in both directions. The estimate is then doubled and
    re-searched within +/- `refine` pixels at each finer level. Levels whose
    rescaled size would reach `max_size` are not searched; the estimate is
    simply doubled through them.

    :param img1: channel to be shifted
    :param img2: reference channel
    :param metric: callable `metric(a, b)` where a lower score is a better match
    :param refine: half-width of the search window at every level after the coarsest
    :param translate: callable `translate(img, (dh, dw))` used during the search
    :return: `(dh, dw)` expressed in pixels of the full-resolution input
    """
    min_size, max_size = 100, 4000
    resize_ratio = 2
    h, w = img1.shape[:2]

    # Clamp at 0: when both sides are already below min_size the logarithms are
    # negative, which would otherwise skip the search loop and return (0, 0).
    num_times_downscale = max(
        0,
        math.floor(math.log2(h / min_size)),
        math.floor(math.log2(w / min_size)),
    )
    scale_factor = pow(1 / resize_ratio, num_times_downscale)

    num_pyramid_search = 0
    start_search_space = (-20, 20)
    best_trans = (0, 0)

    while scale_factor * h < max_size and scale_factor * w < max_size and num_times_downscale - num_pyramid_search >= 0:
        num_pyramid_search += 1
        img1_rs = skimage.transform.rescale(img1, scale_factor)
        img2_rs = skimage.transform.rescale(img2, scale_factor)

        if num_pyramid_search == 1:
            # Coarsest level: wide window centred on zero.
            min_tr_h, max_tr_h = start_search_space
            min_tr_w, max_tr_w = start_search_space
        else:
            # Finer levels: narrow window around the upscaled previous estimate.
            min_tr_h, max_tr_h = best_trans[0] - refine, best_trans[0] + refine
            min_tr_w, max_tr_w = best_trans[1] - refine, best_trans[1] + refine

        best_trans, _ = naive_align_two_channels(img1_rs, img2_rs, metric, min_tr_h, max_tr_h, min_tr_w, max_tr_w, translate)

        scale_factor *= resize_ratio
        # Express the estimate in the next (finer) level's pixels, unless this was the last level.
        if num_times_downscale - num_pyramid_search >= 0:
            best_trans = best_trans[0] * resize_ratio, best_trans[1] * resize_ratio

    # If the search stopped early because of max_size, keep scaling the estimate
    # up through the remaining levels so it ends in full-resolution pixels.
    while num_times_downscale - num_pyramid_search > 0:
        num_pyramid_search += 1
        best_trans = best_trans[0] * resize_ratio, best_trans[1] * resize_ratio

    return best_trans

# Lookup table of contrast adjustments, keyed by the names accepted as
# `auto_contrast` in process_pyramid_scale. Each entry takes an image and
# returns the adjusted image.
contrast_funcs = {
    # Global histogram equalization.
    'equalize': lambda x: skimage.exposure.equalize_hist(x),
    # Stretch the image's own intensity range onto (0, 1).
    'rescale': lambda x: skimage.exposure.rescale_intensity(x, in_range='image', out_range=(0,1)),
    # Adaptive histogram equalization (skimage's equalize_adapthist) with clip_limit=0.008.
    'adapt_eq': lambda x: skimage.exposure.equalize_adapthist(x, clip_limit=0.008)
}

def process_pyramid_scale(img, conv_filter = True, refine = 2, metric = ssd, precrop = True, postcrop = True, auto_contrast='adapt_eq', translate = translate_circ):
    """
    Align an unsplit plate with the multiscale pyramid search and post-process it.

    :param img: unsplit image containing the three stacked channels
    :param conv_filter: if true, alignment is computed on Sobel edge maps instead of raw intensities
    :param refine: half-width of the per-level refinement window passed to the pyramid search
    :param metric: scoring callable used by the pyramid search; when falsy no alignment is done
    :param precrop: if true, borders are cropped before computing the alignment
        (the shifts are still applied to the uncropped channels)
    :param postcrop: if true, the aligned image is cropped with `crop_border(..., 0.1)`
    :param auto_contrast: key into `contrast_funcs`, or a falsy value to skip contrast adjustment
    :param translate: callable `translate(img, (dh, dw))` used for searching and for the final shift
    :return: `(processed_image, trans_r, trans_g)`
    """
    r, g, b = split_into_channels(img)
    trans_r, trans_g = (0,0), (0,0)

    if metric:
        r_pr, g_pr, b_pr = r, g, b
        # Crops a fixed portion of the image to remove border artifacts to aid in alignment.
        if precrop:
            r_pr, g_pr, b_pr = map(crop_border, (r, g, b))
        # The Sobel filter is used to create a feature map of edges. The alignment
        # is computed on these edge feature maps.
        if conv_filter:
            r_pr, g_pr, b_pr = map(skimage.filters.sobel, (r_pr, g_pr, b_pr))

        # Score with the caller-selected metric rather than a fixed one.
        trans_r = compute_pyramid_alignment(r_pr, b_pr, metric, refine, translate)
        trans_g = compute_pyramid_alignment(g_pr, b_pr, metric, refine, translate)
        r = translate(r, trans_r)
        g = translate(g, trans_g)

    img_aligned = stack_channels(r, g, b)
    # Keep values inside [0, 1] before cropping / contrast adjustment.
    img_aligned = np.clip(img_aligned, 0, 1)

    # Crop a fixed portion of the aligned image to remove border artifacts
    if postcrop:
        img_aligned = crop_border(img_aligned, 0.1)

    # Applies the auto-contrast method, if specified
    if auto_contrast:
        img_aligned = contrast_funcs[auto_contrast](img_aligned)
    return img_aligned, trans_r, trans_g

def align_image_pyramid_scale(img_name, save=False, img_dir='data/', out_dir='out/', conv_filter=True, refine=2, metric=ssd, precrop=True, postcrop=True, auto_contrast='adapt_eq', translate=translate_circ):
    """
    Read, align and optionally save one image using the pyramid search.

    :param img_name: file name of the image inside `img_dir`
    :param save: if true, the result is written to `out_dir` as `<base name>.jpg`
    :param img_dir: directory prefix the image is read from
    :param out_dir: directory prefix the result is written to
    :return: `(aligned_image, trans_r, trans_g)` as returned by `process_pyramid_scale`

    The remaining parameters are forwarded unchanged to `process_pyramid_scale`.
    """
    # splitext splits on the last dot only, so names with several dots (or none)
    # do not break tuple unpacking.
    name, _ = os.path.splitext(img_name)
    img = read_img(img_name, img_dir=img_dir)
    img_aligned, trans_r, trans_g = process_pyramid_scale(img, conv_filter, refine, metric, precrop, postcrop, auto_contrast, translate)
    if save:
        save_img(name + '.jpg', img_aligned, out_dir=out_dir)
    return img_aligned, trans_r, trans_g


In [107]:
### visualize channels of emir.tiff with and without edge detection
import imageio
img_name = 'emir.tif'
out_dir = 'out/alignment_tests/'
img = read_img(img_name)
# Split the plate, then trim the borders of every channel.
r, g, b = [crop_border(channel) for channel in split_into_channels(img)]
# Sobel edge maps of the cropped channels.
r_edge, g_edge, b_edge = [skimage.filters.sobel(channel) for channel in (r, g, b)]
# Only the red/blue channels and their edge maps are written out, as 8-bit images.
r, b, r_edge, b_edge = [(channel * 255).astype(np.uint8) for channel in (r, b, r_edge, b_edge)]
for fname, channel_img in (
    ('emir_r.jpg', r),
    ('emir_b.jpg', b),
    ('emir_r_edge.jpg', r_edge),
    ('emir_b_edge.jpg', b_edge),
):
    imageio.imwrite(out_dir + fname, channel_img)

In [54]:
### Functions to visualize different contrast manipulations. 
# The histogram plotting code is taken from https://scikit-image.org/docs/stable/auto_examples/color_exposure/plot_equalize.html.
def plot_img_and_hist(image, axes, bins=256):
    """Draw an image together with its intensity histogram and cumulative distribution.

    :param image: image to display
    :param axes: pair `(ax_img, ax_hist)` -- axes for the image and for the histogram
    :param bins: number of histogram bins
    :return: `(ax_img, ax_hist, ax_cdf)`, where `ax_cdf` is a twin of `ax_hist` sharing its x-axis
    """
    ax_img, ax_hist = axes
    ax_cdf = ax_hist.twinx()

    # Image panel, shown in grayscale with the axis frame hidden.
    ax_img.imshow(image, cmap=plt.cm.gray)
    ax_img.set_axis_off()

    # Histogram panel over the flattened pixel values.
    ax_hist.hist(image.ravel(), bins=bins, histtype='step', color='black')
    ax_hist.ticklabel_format(axis='y', style='scientific', scilimits=(0, 0))
    ax_hist.set_xlabel('Pixel intensity')
    ax_hist.set_xlim(0, 1)
    ax_hist.set_yticks([])

    # Cumulative distribution overlaid in red on the twin axis.
    cdf_values, bin_centers = skimage.exposure.cumulative_distribution(image, bins)
    ax_cdf.plot(bin_centers, cdf_values, 'r')
    ax_cdf.set_yticks([])

    return ax_img, ax_hist, ax_cdf

def plot_contrast_changes_with_histograms(img):
    """
    Show the original image next to three contrast adjustments, with histograms.

    The top row holds the images (sharing x/y axes with the first panel) and the
    bottom row holds each image's histogram and cumulative distribution. The
    columns are: default, intensity rescaling, histogram equalization and
    adaptive histogram equalization.

    :param img: image to display
    """
    fig = plt.figure(figsize=(16, 10))
    axes = np.zeros((2, 4), dtype=object)
    first_panel = fig.add_subplot(2, 4, 1)
    axes[0, 0] = first_panel
    for col in range(1, 4):
        axes[0, col] = fig.add_subplot(2, 4, col + 1, sharex=first_panel, sharey=first_panel)
    for col in range(4):
        axes[1, col] = fig.add_subplot(2, 4, col + 5)

    img_equalized = skimage.exposure.equalize_hist(img)
    img_rescale = skimage.exposure.rescale_intensity(img, in_range='image', out_range=(0,1))
    img_adapteq = skimage.exposure.equalize_adapthist(img, clip_limit=0.008)

    panels = (
        ('Default', img),
        ('Rescaling', img_rescale),
        ('Histogram equalization', img_equalized),
        ('Adaptive equalization', img_adapteq),
    )
    for col, (title, panel) in enumerate(panels):
        ax_img, ax_hist, ax_cdf = plot_img_and_hist(panel, axes[:, col])
        ax_img.set_title(title)

def plot_contrast_changes_simple(img):
    """
    Show the original image and three contrast adjustments in a 2x2 grid, without histograms.

    :param img: image to display
    """
    img_equalized = skimage.exposure.equalize_hist(img)
    img_rescale = skimage.exposure.rescale_intensity(img, in_range='image', out_range=(0,1))
    img_adapteq = skimage.exposure.equalize_adapthist(img, clip_limit=0.008)

    rows, columns = 2, 2
    fig = plt.figure(figsize=(8, 8))
    panels = (
        ('Normal', img),
        ('img_equalized', img_equalized),
        ('img_rescale', img_rescale),
        ('img_adapteq', img_adapteq),
    )
    for position, (title, panel) in enumerate(panels, start=1):
        fig.add_subplot(rows, columns, position)
        plt.imshow(panel)
        plt.title(title)

In [None]:
### Code to view single scale search results
out_dir = 'out/ssd/'
translate = translate_circ
for img_name in os.listdir('data'):
    # Skip .tif files (presumably too large for the exhaustive single-scale search).
    if '.tif' in img_name:
        continue
    img = read_img(img_name)
    print(img_name)
    # process_single_scale's border-crop flag is named `precrop`.
    img_aligned, trans_r, trans_g = process_single_scale(img, metric=ssd, precrop=True, translate=translate)
    # One figure per image; otherwise each imshow overwrites the previous one.
    plt.figure()
    plt.imshow(img_aligned)
    plt.title(img_name)
    plt.show()
    # To persist the results, call: save_img(img_name, img_aligned, out_dir=out_dir)


In [72]:
### Code to produce final outputs for pyramid scale search algorithm (with bells & whistles)
# To turn extensions off, set auto_contrast=None and conv_filter=False
out_dir = 'out/pyramid/'
out_file = 'translations.txt'
# Make sure the target directory exists before writing the log and the images.
os.makedirs(out_dir, exist_ok=True)
# Append mode: re-running this cell adds new lines after the existing ones.
# The `with` block closes the file even if an image fails to process.
with open(out_dir + out_file, "a") as f:
    for img_name in os.listdir('data'):
        img_aligned, trans_r, trans_g = align_image_pyramid_scale(img_name, save=True, out_dir=out_dir)
        mess = 'Image {0}: ({1}, {2}), ({3}, {4})\n'.format(img_name, trans_r[0], trans_r[1], trans_g[0], trans_g[1])
        f.write(mess)

In [98]:
### Code to produce outputs for the extra images with the pyramid search (contrast adjustment disabled)
out_dir = 'out/extra_results_no_contrast/'
out_file = 'translations.txt'
# Make sure the target directory exists before writing the log and the images.
os.makedirs(out_dir, exist_ok=True)
# Append mode: re-running this cell adds new lines after the existing ones.
# The `with` block closes the file even if an image fails to process.
with open(out_dir + out_file, "a") as f:
    for img_name in os.listdir('extra_data'):
        img_aligned, trans_r, trans_g = align_image_pyramid_scale(img_name, save=True, out_dir=out_dir, img_dir='extra_data/', auto_contrast=None)
        mess = 'Image {0}: ({1}, {2}), ({3}, {4})\n'.format(img_name, trans_r[0], trans_r[1], trans_g[0], trans_g[1])
        f.write(mess)