In [None]:
import os
import cv2
from matplotlib import pyplot as plt

# Shifting the image by a margin of pixels
import skimage.transform as trans
from scipy import signal
from PIL import Image
from scipy import stats as stat
from itertools import product

# Image Analysis
import numpy as np
from scipy.fft import fft, ifft

import glob
import os

from tifffile import imread
from matplotlib import pyplot as plt
from skimage import io, exposure, data
import numpy as np
from PIL import Image
from scipy import stats as st

from scipy.optimize import minimize
from numpy import diff
from scipy.signal import find_peaks

from skimage import exposure
from skimage.filters import unsharp_mask
# from wand.image import Image as ImageWand
from numpy.polynomial import polynomial as P
import cv2
from PIL import Image
from tqdm import tqdm

In [None]:
import numpy as np
import cv2

def save_video(filename, frames, fps=15.0):
    """Write a stack of frames to an MP4 file using OpenCV's ``mp4v`` codec.

    Args:
        filename: Output path for the video file.
        frames: Array-like of shape (N, H, W) (grayscale) or (N, H, W, 3)
            (RGB); OpenCV's writer expects uint8 data.
        fps: Frame rate stored in the container (default 15.0).

    Raises:
        ValueError: if ``frames`` is not 3- or 4-dimensional.
        OSError: if OpenCV cannot open ``filename`` for writing.
    """
    frames = np.asarray(frames)
    if frames.ndim not in (3, 4):
        raise ValueError(
            f"Expected frames of shape (N, H, W) or (N, H, W, 3), got {frames.shape}")

    h, w = frames.shape[1], frames.shape[2]
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(filename, fourcc, fps, (w, h), isColor=True)
    # Without this check a failed open silently drops every frame.
    if not out.isOpened():
        raise OSError(f"Could not open video writer for {filename}")

    # The writer wants BGR; convert from grayscale or RGB as appropriate.
    conversion = cv2.COLOR_GRAY2BGR if frames.ndim == 3 else cv2.COLOR_RGB2BGR
    try:
        for fr in frames:
            out.write(cv2.cvtColor(fr, conversion))
    finally:
        out.release()
    print(f"Saved {filename}")

# Registration 1 - Optimized

In [None]:
from numba import jit, prange
import numpy as np

# # Dirty Implementation of Shifting Images
# def ShiftedImage_2D(Image, XShift, YShift):
#     # Quick guard
#     if (XShift == 0 and YShift == 0):
#         return Image;

#     M = np.float32([
#     [1, 0, XShift],
#     [0, 1, YShift]
#     ]);

#     shifted = cv2.warpAffine(Image, M, (Image.shape[1], Image.shape[0]));
#     shifted_image = shifted

#     # Shift Down
#     if (YShift > 0):
#         shifted_image = shifted_image[YShift:];
#         shifted_image = np.pad(shifted_image, ((YShift, 0), (0, 0)), 'edge'); # Pad Up

#     # Shift Up
#     if (YShift < 0):
#         shifted_image = shifted_image[:shifted.shape[0] - abs(YShift)];
#         shifted_image = np.pad(shifted_image, ((0, abs(YShift)), (0, 0)), 'edge'); # Pad Down

#     # Shift Left
#     if (XShift > 0):
#         shifted_image = np.delete(shifted_image, slice(0, XShift), 1);
#         shifted_image = np.pad(shifted_image, ((0, 0), (XShift, 0)), 'edge'); # Pad Left

#     if (XShift < 0):
#         shifted_image = np.delete(shifted_image, slice(shifted.shape[1] - abs(XShift), shifted.shape[1]), 1);
#         shifted_image = np.pad(shifted_image, ((0, 0), (0, abs(XShift))), 'edge'); # Pad Right

#     return shifted_image

from numba import jit, prange
import numpy as np

@jit(nopython=True, parallel=True)
def ShiftedImage_2D_numba(Image, XShift, YShift):
    """Translate an image by (XShift, YShift) pixels, replicating edge pixels.

    Output pixel (row, col) receives the source pixel at
    (row - YShift, col - XShift). The source coordinates are clamped into the
    image, so uncovered borders repeat the nearest edge pixel. Accepts 2D
    (H, W) or 3D (H, W, C) arrays; the result has the same shape and dtype as
    ``Image``. A zero shift returns a copy.
    """
    if XShift == 0 and YShift == 0:
        return Image.copy()

    height = Image.shape[0]
    width = Image.shape[1]
    out = np.empty_like(Image)

    # Rows are distributed across threads; numba would run a nested prange
    # serially anyway, so the inner loops are plain ranges.
    for row in prange(height):
        # Clamping covers both the in-bounds and the edge-replication cases.
        src_row = min(max(row - YShift, 0), height - 1)
        for col in range(width):
            src_col = min(max(col - XShift, 0), width - 1)
            if Image.ndim == 3:
                for ch in range(Image.shape[2]):
                    out[row, col, ch] = Image[src_row, src_col, ch]
            else:
                out[row, col] = Image[src_row, src_col]

    return out


# sum of absolute differences (SAD) metric alignment
# Optimized version
@jit(nopython=True, parallel=True)
def SAD(a, b):
    """Mean absolute difference between two arrays with the same element count.

    Despite the name, the summed absolute differences are divided by the
    number of elements, so this is the mean absolute error (MAE). Each value
    is converted to float before subtracting, so unsigned inputs (e.g. uint8)
    cannot wrap around.

    Raises:
        ValueError: if ``a`` and ``b`` hold a different number of elements.
            numba does not bounds-check, so a mismatch would otherwise read
            past the end of the smaller array.
    """
    if a.size != b.size:
        raise ValueError("SAD: inputs must have the same number of elements")
    flat_a = a.ravel()
    flat_b = b.ravel()
    total = 0.0
    # numba recognises `+=` on a scalar inside prange as a parallel reduction.
    for i in prange(flat_a.size):
        total += abs(float(flat_a[i]) - float(flat_b[i]))
    return total / flat_a.size

# We use a Tree Search Algorithm to find possible alignment
# Let Image_1 be the original
# Let Image_2 be the aligned
# Displacement object is our nodes, [x,y]
# Assumption, there is always a better alignment up, down, left, and right if its not the same image
def alignment_MAE(Image_1, Image_2, depth_cap):
    """Greedy breadth-first search for the integer shift of Image_2 that minimises SAD against Image_1.

    The search starts at (0, 0). For each dequeued displacement, its 8
    neighbours are shifted with ``ShiftedImage_2D_numba`` and scored with
    ``SAD``. A neighbour is enqueued only if it beats the best score found so
    far.

    Args:
        Image_1: Reference image.
        Image_2: Image that is shifted and compared against ``Image_1``.
        depth_cap: Maximum number of nodes to expand. 0 expands none. A
            negative value never reaches the cap, so the search runs until no
            improving neighbour remains.

    Returns:
        (displacement, score): ``displacement`` is a new ``[x, y]`` list and
        ``score`` is the SAD at that displacement.
    """
    from collections import deque

    best_sad = SAD(Image_1, Image_2)
    best_displacement = (0, 0)
    # A set gives O(1) membership tests; the queue gives O(1) pops from the front.
    visited = {best_displacement}
    queue = deque([best_displacement])
    expanded = 0

    while expanded != depth_cap and queue:
        x, y = queue.popleft()
        expanded += 1

        neighbours = (
            (x, y - 1),      # -y
            (x, y + 1),      # +y
            (x + 1, y),      # +x
            (x - 1, y),      # -x
            (x - 1, y - 1),  # Diagonal
            (x + 1, y + 1),  # Diagonal
            (x + 1, y - 1),  # Diagonal
            (x - 1, y + 1),  # Diagonal
        )

        for move in neighbours:
            if move in visited:
                continue
            visited.add(move)  # Marked as Visited

            # Perform shift and calculate
            candidate = ShiftedImage_2D_numba(Image_2, move[0], move[1])
            cand_sad = SAD(Image_1, candidate)

            if cand_sad < best_sad:
                best_sad = cand_sad
                best_displacement = move
                queue.append(move)

    # Return a fresh list: callers mutate it (e.g. ``displacement[0] = 0``).
    return list(best_displacement), best_sad

from numba import jit
import numpy as np

# This was a good fix for edge detection
@jit(nopython=True, parallel=True)
def compute_row_means_2d(img):
    """Return a float64 vector holding the arithmetic mean of each row of a 2D array."""
    n_rows, n_cols = img.shape
    means = np.empty(n_rows, dtype=np.float64)
    # Rows are independent, so they are spread across threads.
    for r in prange(n_rows):
        acc = 0.0
        for c in range(n_cols):
            acc += float(img[r, c])
        means[r] = acc / n_cols
    return means

@jit(nopython=True, parallel=True)
def compute_row_means_3d(img):
    """Return a float64 vector with each row's mean taken over all columns and channels."""
    n_rows, n_cols, n_channels = img.shape
    means = np.empty(n_rows, dtype=np.float64)
    values_per_row = n_cols * n_channels
    # Rows are independent, so they are spread across threads.
    for r in prange(n_rows):
        acc = 0.0
        for c in range(n_cols):
            for ch in range(n_channels):
                acc += float(img[r, c, ch])
        means[r] = acc / values_per_row
    return means

@jit(nopython=True)
def edge_detection_numba_fixed(img):
    """
    Locate top and bottom edges from the row-to-row brightness change.

    The image is reduced to one mean brightness per row, and the forward
    difference between consecutive rows is taken. Differences with magnitude
    >= 150 that lie within 100 rows of either border are zeroed (presumably
    frame-border artefacts). The top edge is the largest rise in the upper
    half; the bottom edge is the steepest fall in the lower half.

    Args:
        img: Input image, 2D (H, W) or 3D (H, W, C) numpy array.

    Returns:
        tuple: (top_edge_row, bottom_edge_row). These are indices into the
        difference array; index i is the step from row i to row i + 1. If a
        half is empty, the defaults are 0 for the top and H // 2 for the bottom.
    """
    if img.ndim == 3:
        brightness = compute_row_means_3d(img)
    else:
        brightness = compute_row_means_2d(img)

    height = img.shape[0]
    half = height // 2
    n_steps = len(brightness) - 1

    # Forward difference with border-artefact suppression in one pass.
    steps = np.empty(n_steps, dtype=np.float64)
    for k in range(n_steps):
        step = brightness[k + 1] - brightness[k]
        near_border = k <= 100 or k >= height - 100
        if near_border and abs(step) >= 150:
            step = 0.0
        steps[k] = step

    # Strongest rise in the upper half (first occurrence wins ties).
    top_edge = 0
    best_rise = -np.inf
    for k in range(min(half, n_steps)):
        if steps[k] > best_rise:
            best_rise = steps[k]
            top_edge = k

    # Strongest fall in the lower half (first occurrence wins ties).
    bottom_edge = half
    best_fall = np.inf
    for k in range(half, n_steps):
        if steps[k] < best_fall:
            best_fall = steps[k]
            bottom_edge = k

    return top_edge, bottom_edge

# # @jit(nopython=True, parallel=True)
# def edge_detection_optimized(img):
#     """
#     Ultra-fast vertical edge detection using brightness gradient analysis.
    
#     Returns:
#         tuple: (top_edge_row, bottom_edge_row) positions
#     """
#     # Vectorized row-wise brightness computation
#     if img.ndim == 3:
#         row_brightness = np.mean(np.mean(img, axis=2), axis=1)
#     else:
#         row_brightness = np.mean(img, axis=1)
    
#     # Calculate first derivative efficiently
#     gradient = np.diff(row_brightness)
    
#     # Vectorized noise suppression for edge regions
#     height = img.shape[0]
#     edge_mask = (np.arange(len(gradient)) <= 100) | (np.arange(len(gradient)) >= height - 100)
#     extreme_mask = (np.abs(gradient) >= 150)
#     gradient[edge_mask & extreme_mask] = 0
    
#     # Find valid extrema with spatial constraints
#     half_height = height // 2
    
#     # Top edge: strongest positive gradient in upper half
#     upper_region = gradient[:half_height]
#     if len(upper_region) > 0 and np.max(upper_region) > 0:
#         top_edge = np.argmax(upper_region)
#     else:
#         top_edge = 0
    
#     # Bottom edge: strongest negative gradient in lower half  
#     lower_region = gradient[half_height:]
#     if len(lower_region) > 0 and np.min(lower_region) < 0:
#         bottom_edge = half_height + np.argmin(lower_region)
#     else:
#         bottom_edge = height - 1
        
#     return top_edge, bottom_edge

def remove_stage_jitter_MAE_opt(img_array, iteration_depth : int = 1000, m = False, verbose : bool = False, mcm : bool = False):
    """Register frames 1..N-1 of a time-lapse against frame 0 (numba-accelerated path).

    Every frame is contrast-stretched with ``exposure.rescale_intensity``
    first. Shifts are then estimated as follows:

    - X shift: greedy SAD search (``alignment_MAE``) against the first frame.
    - Y shift: the mean offset between the top/bottom edges that
      ``edge_detection_numba_fixed`` finds in the base and current frames,
      truncated to an int.

    Args:
        img_array: Stack of frames, expected (N, H, W). For (N, H, W, C),
            only channel 0 is aligned.
        iteration_depth: Maximum number of search nodes ``alignment_MAE``
            expands per frame.
        m: Unused.
        verbose: If True, print the scores and per-frame X/Y shifts at the end.
        mcm: If True, force the X shift to 0 so that only Y is corrected.

    Returns:
        (shifted_images, scores): one entry per frame *after* the first, so
        both lists have length N - 1. The base frame itself is not included.
    """
    scores = []
    X_shifts = []
    Y_shifts = []
    shifted_images = []

    if img_array.ndim != 3:
        print("Give a series of grayscale images. Shape: {}".format(img_array.shape))

    base = exposure.rescale_intensity(img_array[0])
    base_top, base_bottom = edge_detection_numba_fixed(base)

    if base.ndim == 3:
        base = base[:, :, 0]  # Reduce to 2D; only channel 0 is aligned

    for _frame in tqdm(img_array[1:]):
        template_image = exposure.rescale_intensity(_frame)  # Stretch contrast of dim frames
        template_top, template_bottom = edge_detection_numba_fixed(template_image)

        if template_image.ndim == 3:
            template_image = template_image[:, :, 0]  # Reduce to 2D

        displacement, score = alignment_MAE(base, template_image, iteration_depth)
        scores.append(score)

        if mcm:
            displacement[0] = 0

        # Y correction comes from edge positions, not from the SAD search.
        y_shift = int(np.mean([(base_top - template_top), (base_bottom - template_bottom)]))
        X_shifts.append(displacement[0])
        Y_shifts.append(y_shift)
        shifted_images.append(ShiftedImage_2D_numba(template_image, displacement[0], y_shift))  # X,Y

    if verbose:
        print("Scores:", scores)
        print("The X_Shifts:", X_shifts)
        print("The Y_Shifts:", Y_shifts)

    return shifted_images, scores

In [None]:
# Test registration 1 optimized
import nd2
import time

# Input time-lapse; edit to point at local data.
ND2_PATH = '/Volumes/Server_Data/1-HHLN/22.20 Lauren Data/1_10_lauren_replicate_1/3-SR_1_10_6hPre-C_Plain_M9_TS_MC2.nd2'
# ND2_PATH = '/Users/hiram/Documents/EVERYTHING/20-29 Research/22 OliveiraLab/22.12 ND2 analyzer/nd2-analyzer/SR_1_5_2h_Pre-C_3h_IPTG_After10h_05_MC.nd2'
test_dataset = nd2.imread(ND2_PATH, dask=True)
# Index 0 on axes 1 and 2 for every time point (assumed position/channel -- TODO confirm axis order).
test_pos_over_time = test_dataset[:, 0, 0].compute()

# Preview the first frame with a contrast stretch
plt.imshow(exposure.rescale_intensity(test_pos_over_time[0]), cmap='gray')
plt.show()

# Map 16-bit values to 8 bits (assumes data use the uint16 range -- TODO confirm bit depth).
frames = ((test_pos_over_time / 65535) * 255).astype(np.uint8)

# perf_counter is monotonic and high-resolution; time.time() can jump with clock changes.
begin = time.perf_counter()
aligned, scores = remove_stage_jitter_MAE_opt(frames)
end = time.perf_counter()

print(f'Total time {end - begin}')

# Save both timelapses and evaluate
save_video("original_1_opt.mp4", frames)
save_video("registered_1_opt.mp4", np.array(aligned))


# Registration 1 - Good

In [None]:
# Shifting images by whole pixels with edge replication
def ShiftedImage_2D(Image, XShift, YShift):
    """Translate an image by integer (XShift, YShift) pixels, replicating edge pixels.

    Output pixel (r, c) is ``Image[clip(r - YShift), clip(c - XShift)]``, with
    the indices clipped into the image. Positive XShift moves content right
    and positive YShift moves it down. Uncovered borders repeat the nearest
    edge row/column. Shifts at least as large as the image are handled (every
    pixel becomes an edge pixel). Any dtype works, and trailing axes (e.g.
    channels) are carried along.

    Args:
        Image: Array whose first two axes are rows and columns.
        XShift: Integer horizontal shift in pixels.
        YShift: Integer vertical shift in pixels.

    Returns:
        The shifted image, with the same shape and dtype. For a zero shift,
        ``Image`` itself is returned (not a copy).
    """
    # Quick guard
    if XShift == 0 and YShift == 0:
        return Image

    h, w = Image.shape[0], Image.shape[1]
    src_rows = np.clip(np.arange(h) - YShift, 0, h - 1)
    src_cols = np.clip(np.arange(w) - XShift, 0, w - 1)
    # (h, 1) row indices broadcast against (w,) column indices -> one (h, w) gather.
    return Image[src_rows[:, None], src_cols]

def ShiftedImage_3D(Image, XShift, YShift, plot_debug=False):
    """Translate an (H, W, C) image by integer (XShift, YShift) pixels, filling uncovered pixels with 0.

    Output pixel (r, c) is ``Image[r - YShift, c - XShift]`` when that source
    pixel lies inside the image, and 0 otherwise. Positive XShift moves
    content right and positive YShift moves it down. Shifts at least as large
    as the image produce an all-zero result instead of a wrongly sized array.

    Args:
        Image: Array whose first two axes are rows and columns (typically (H, W, C)).
        XShift: Integer horizontal shift in pixels.
        YShift: Integer vertical shift in pixels.
        plot_debug: If True, display the shifted image with matplotlib
            (off by default, so that calling this in a loop renders nothing).

    Returns:
        The shifted array, with the same shape and dtype. For a zero shift,
        ``Image`` itself is returned.
    """
    # Quick guard
    if XShift == 0 and YShift == 0:
        return Image

    h, w = Image.shape[0], Image.shape[1]
    shifted_image = np.zeros_like(Image)

    # Destination window that receives source pixels. It is empty when the
    # shift pushes the whole image out of frame.
    row_lo, row_hi = max(YShift, 0), min(h + YShift, h)
    col_lo, col_hi = max(XShift, 0), min(w + XShift, w)
    if row_lo < row_hi and col_lo < col_hi:
        shifted_image[row_lo:row_hi, col_lo:col_hi] = Image[
            row_lo - YShift:row_hi - YShift, col_lo - XShift:col_hi - XShift]

    if plot_debug:
        plt.imshow(shifted_image)
        plt.show()

    return shifted_image

def SAD(A,B):
    """Mean absolute difference: sum of |A - B| over all elements divided by A.size.

    The subtraction is done in float64, so unsigned inputs cannot wrap around.
    """
    abs_diff = np.abs(np.subtract(A.ravel(), B.ravel(), dtype=np.float64))
    return abs_diff.sum() / A.size

# sum of absolute differences (SAD) metric alignment, quick n dirty
# We use a Tree Search Algorithm to find possible alignment
# Let Image_1 be the original
# Let Image_2 be the aligned
# Displacement object is our nodes, [x,y]
# Assumption, there is always a better alignment up, down, left, and right if its not the same image
def alignment_MAE(Image_1, Image_2, depth_cap):
    """Greedy breadth-first search for the integer shift of Image_2 that minimises SAD against Image_1.

    The search starts at (0, 0). For each dequeued displacement, its 8
    neighbours are shifted with ``ShiftedImage_2D`` and scored with ``SAD``.
    A neighbour is enqueued only if it beats the best score found so far.

    Args:
        Image_1: Reference image.
        Image_2: Image that is shifted and compared against ``Image_1``.
        depth_cap: Maximum number of nodes to expand. 0 expands none. A
            negative value never reaches the cap, so the search runs until no
            improving neighbour remains.

    Returns:
        (displacement, score): ``displacement`` is a new ``[x, y]`` list and
        ``score`` is the SAD at that displacement.
    """
    from collections import deque

    best_sad = SAD(Image_1, Image_2)
    best_displacement = (0, 0)
    # A set gives O(1) membership tests; the queue gives O(1) pops from the front.
    visited = {best_displacement}
    queue = deque([best_displacement])
    expanded = 0

    while expanded != depth_cap and queue:
        x, y = queue.popleft()
        expanded += 1

        neighbours = (
            (x, y - 1),      # -y
            (x, y + 1),      # +y
            (x + 1, y),      # +x
            (x - 1, y),      # -x
            (x - 1, y - 1),  # Diagonal
            (x + 1, y + 1),  # Diagonal
            (x + 1, y - 1),  # Diagonal
            (x - 1, y + 1),  # Diagonal
        )

        for move in neighbours:
            if move in visited:
                continue
            visited.add(move)  # Marked as Visited

            # Perform shift and calculate
            candidate = ShiftedImage_2D(Image_2, move[0], move[1])
            cand_sad = SAD(Image_1, candidate)

            if cand_sad < best_sad:
                best_sad = cand_sad
                best_displacement = move
                queue.append(move)

    # Return a fresh list: callers mutate it (e.g. ``displacement[0] = 0``).
    return list(best_displacement), best_sad

# Vec4f is (x1, y1, x2, y2)
def y_shift_emphasis(image, block_threshold, MAE_shift, max_dy=10):
    """Return the smallest and largest y coordinate among near-horizontal line segments.

    Segments are detected with OpenCV's LineSegmentDetector on ``image``
    converted to uint8. A segment counts as near-horizontal when its two
    endpoints differ vertically by less than ``max_dy`` pixels. Every such
    segment is considered, and only y coordinates contribute to the extent.

    Args:
        image: 2D grayscale image.
        block_threshold: Unused.
        MAE_shift: Unused.
        max_dy: Vertical tolerance for a segment to count as horizontal (default 10).

    Returns:
        (top_y, bottom_y): minimum and maximum endpoint y over the
        near-horizontal segments.

    Raises:
        ValueError: if no (near-horizontal) line segment is detected.
    """
    # Need to turn into uint8 for Straight Line Detection
    img = np.uint8(image)
    lsd = cv2.createLineSegmentDetector(0)
    lines_contour = lsd.detect(img)[0]
    # detect() returns None for the lines when nothing is found.
    if lines_contour is None:
        raise ValueError("No line segments detected")

    # Rows are (x1, y1, x2, y2); keep the two y coordinates.
    segments = np.asarray(lines_contour, dtype=np.float64).reshape(-1, 4)
    ys = segments[:, [1, 3]]
    horizontal_ys = ys[np.abs(ys[:, 0] - ys[:, 1]) < max_dy]
    if horizontal_ys.size == 0:
        raise ValueError("No near-horizontal line segments detected")

    return horizontal_ys.min(), horizontal_ys.max()

# Takes in image and returns the edges for top and bottom parametrically, (x,y)
# Takes in RGB image
def edge_cropping_estimation_vertical(img, m):
    """Estimate top/bottom edge rows by clustering strong row-brightness gradients.

    Steps:
      1. Compute the mean brightness of every row (over all columns and channels).
      2. Take the forward difference between consecutive rows. Differences
         with magnitude >= 150 within 100 rows of either border are zeroed.
      3. ``np.argpartition(dydx, m)[m:]`` keeps the indices of every
         difference except the ``m`` smallest. NOTE(review): the variable name
         suggests "top m", but this keeps ``len - m`` indices -- confirm intent.
      4. Sort the kept indices and start a new cluster wherever consecutive
         indices are at least 100 apart. Each cluster is reduced to its
         truncated mean.

    The first and last cluster centres are returned. The derivative and the
    cluster centres are also plotted.

    Args:
        img: 2D (H, W) or 3D (H, W, C) image.
        m: Number of smallest differences to discard (``False`` is treated as 0).

    Returns:
        (top, bottom): first and last cluster centre, as indices into the
        difference array.
    """
    rows = img.shape[0]
    # Per-row mean. For (H, W, C) this equals the mean of the per-pixel channel means.
    row_means = np.asarray(img, dtype=np.float64).reshape(rows, -1).mean(axis=1)

    # ================ Vertical axis squish ================
    dydx_vertical = np.diff(row_means)  # rows are unit-spaced
    y_verticle_dydx = np.arange(1, rows)

    # Zero implausibly large jumps near the top or bottom border.
    idx = np.arange(dydx_vertical.size)
    near_border = (idx <= 100) | (idx >= rows - 100)
    dydx_vertical[near_border & (np.abs(dydx_vertical) >= 150)] = 0

    # numpy deprecates boolean partition indices; the default m=False means 0.
    m = int(m)
    top_m_derivatives_ind = np.argpartition(dydx_vertical, m)[m:]
    sorted_ind_m = np.sort(top_m_derivatives_ind)

    # Group indices closer than 100 rows; every index, including the last, is counted.
    clustered_sorted_ind_m = []
    cluster = [sorted_ind_m[0]]
    for ind in sorted_ind_m[1:]:
        if ind - cluster[-1] >= 100:
            clustered_sorted_ind_m.append(int(sum(cluster) / len(cluster)))
            cluster = [ind]
        else:
            cluster.append(ind)
    clustered_sorted_ind_m.append(int(sum(cluster) / len(cluster)))

    print("The VERTICAL DERIVATIVE:")
    plt.plot(y_verticle_dydx, dydx_vertical)
    for i in clustered_sorted_ind_m:
        plt.axvline(x=i, color='r')
    plt.show()

    top = clustered_sorted_ind_m[0]
    bottom = clustered_sorted_ind_m[-1]

    return top, bottom

# Takes in image and returns the edges for top and bottom parametrically, (x,y)
# Assumes Bottom is always min and top is always max
def edge_cropping_estimation_vertical_high_low_distr(img):
    """Estimate the top and bottom edge rows from the row-brightness gradient.

    The image is reduced to one mean brightness per row, and the forward
    difference between rows is taken. Differences with magnitude >= 150
    within 100 rows of either border are zeroed. The image is split at half
    its *height* (H // 2):

    - top: step with the largest increase in [0, H // 2)
    - bottom: step with the largest decrease in [H // 2, H - 1)

    This matches ``edge_detection_numba_fixed``. The derivative and both
    edges are also plotted.

    Args:
        img: 2D (H, W) or 3D (H, W, C) image.

    Returns:
        (top, bottom): indices into the difference array (index i is the step
        from row i to row i + 1). If a half is empty, the defaults are 0 for
        the top and H // 2 for the bottom.
    """
    rows = img.shape[0]
    # Per-row mean. For (H, W, C) this equals the mean of the per-pixel channel means.
    row_means = np.asarray(img, dtype=np.float64).reshape(rows, -1).mean(axis=1)

    # ================ Vertical axis squish ================
    dydx_vertical = np.diff(row_means)  # rows are unit-spaced
    y_verticle_dydx = np.arange(1, rows)

    # Zero implausibly large jumps near the top or bottom border.
    idx = np.arange(dydx_vertical.size)
    near_border = (idx <= 100) | (idx >= rows - 100)
    dydx_vertical[near_border & (np.abs(dydx_vertical) >= 150)] = 0

    # Search each half directly. Repeatedly zeroing the global extreme could
    # loop forever once the remaining extreme was itself a zeroed entry.
    half = rows // 2
    upper = dydx_vertical[:half]
    lower = dydx_vertical[half:]
    top = int(np.argmax(upper)) if upper.size else 0
    bottom = half + int(np.argmin(lower)) if lower.size else half

    print("The VERTICAL DERIVATIVE (Pattern Distribution):")
    plt.plot(y_verticle_dydx, dydx_vertical)
    plt.axvline(x=top, color='r')
    plt.axvline(x=bottom, color='r')
    plt.show()

    return top, bottom

def remove_stage_jitter_MAE(img_array, iteration_depth : int = 1000, m = False, verbose : bool = False, mcm : bool = False):
    """Register frames 1..N-1 of a time-lapse against frame 0 (cv2 shifting path).

    Each frame is contrast-stretched with ``exposure.rescale_intensity``.
    Shifts are estimated as follows:

    - X shift: ``alignment_MAE`` search against the first frame.
    - Y shift: mean offset of the edges found by
      ``edge_cropping_estimation_vertical_high_low_distr``, truncated to an int.

    The score of every frame is printed, followed by summaries of the scores
    and shifts.

    Args:
        img_array: Stack of frames, expected (N, H, W). For 3D frames, only
            channel 0 is aligned.
        iteration_depth: Maximum number of nodes ``alignment_MAE`` expands per frame.
        m: Unused.
        verbose: Unused.
        mcm: If True, the X shift is forced to 0.

    Returns:
        (shifted_images, scores), each of length N - 1 (the base frame is excluded).
    """
    if img_array.ndim != 3:
        print("Give a series of grayscale images. Shape: {}".format(img_array.shape))

    def _as_2d(frame):
        # Keep only channel 0 of a multi-channel frame.
        return frame[:, :, 0] if frame.ndim == 3 else frame

    base = exposure.rescale_intensity(img_array[0])
    base_top, base_bottom = edge_cropping_estimation_vertical_high_low_distr(base)
    base = _as_2d(base)

    scores, X_shifts, Y_shifts, shifted_images = [], [], [], []

    for frame in img_array[1:]:
        template = exposure.rescale_intensity(frame)  # Stretch contrast of dim frames
        template_top, template_bottom = edge_cropping_estimation_vertical_high_low_distr(template)
        template = _as_2d(template)

        displacement, score = alignment_MAE(base, template, iteration_depth)
        scores.append(score)
        print("SCORE:", score)

        if mcm:
            displacement[0] = 0

        y_shift = int(np.mean([(base_top - template_top), (base_bottom - template_bottom)]))
        X_shifts.append(displacement[0])
        Y_shifts.append(y_shift)
        shifted_images.append(ShiftedImage_2D(template, displacement[0], y_shift))  # X,Y

    print("Scores:", scores)
    print("The X_Shifts:", X_shifts)
    print("The Y_Shifts:", Y_shifts)

    return shifted_images, scores

In [None]:
# Test registration 1
import nd2

# Input time-lapse (alternative dataset kept below for reference)
ND2_PATH = '/Volumes/Server_Data/1-HHLN/22.20 Lauren Data/1_10_lauren_replicate_1/3-SR_1_10_6hPre-C_Plain_M9_TS_MC2.nd2'
# ND2_PATH = '/Users/hiram/Documents/EVERYTHING/20-29 Research/22 OliveiraLab/22.12 ND2 analyzer/nd2-analyzer/SR_1_5_2h_Pre-C_3h_IPTG_After10h_05_MC.nd2'
test_dataset = nd2.imread(ND2_PATH, dask=True)
test_pos_over_time = test_dataset[:, 0, 0].compute()

# Normalization: min-max stretch the whole stack to [0, 65535], then map to 8 bits.
norm = cv2.normalize(test_pos_over_time, None, 0, 65535, cv2.NORM_MINMAX)
plt.imshow(norm[0], cmap='gray')
plt.show()
frames = ((norm / 65535) * 255).astype(np.uint8)  # (N, H, W) uint8
frames_rgb = np.stack([frames] * 3, axis=-1)  # (N, H, W, 3); not used by the calls below
aligned, scores = remove_stage_jitter_MAE(frames)

# Save both timelapses and evaluate
save_video("original_1.mp4", frames)
save_video("registered_1.mp4", np.array(aligned))

# Registration 2 - Not good

In [None]:
# Shifting images by whole pixels with edge replication
def ShiftedImage_2D(Image, XShift, YShift):
    """Translate an image by integer (XShift, YShift) pixels, replicating edge pixels.

    Output pixel (r, c) is ``Image[clip(r - YShift), clip(c - XShift)]``, with
    the indices clipped into the image. Positive XShift moves content right
    and positive YShift moves it down. Uncovered borders repeat the nearest
    edge row/column. Shifts at least as large as the image are handled (every
    pixel becomes an edge pixel). Any dtype works, and trailing axes (e.g.
    channels) are carried along.

    Args:
        Image: Array whose first two axes are rows and columns.
        XShift: Integer horizontal shift in pixels.
        YShift: Integer vertical shift in pixels.

    Returns:
        The shifted image, with the same shape and dtype. For a zero shift,
        ``Image`` itself is returned (not a copy).
    """
    # Quick guard
    if XShift == 0 and YShift == 0:
        return Image

    h, w = Image.shape[0], Image.shape[1]
    src_rows = np.clip(np.arange(h) - YShift, 0, h - 1)
    src_cols = np.clip(np.arange(w) - XShift, 0, w - 1)
    # (h, 1) row indices broadcast against (w,) column indices -> one (h, w) gather.
    return Image[src_rows[:, None], src_cols]

# sum of absolute differences (SAD) metric alignment, quick n dirty
def SAD(A, B):
    """Mean absolute difference: sum of |A - B| over all elements divided by A.size.

    The subtraction is done in float64, so unsigned inputs cannot wrap around.
    """
    abs_diff = np.abs(np.subtract(A.ravel(), B.ravel(), dtype=np.float64))
    return abs_diff.sum() / A.size

# We use a Tree Search Algorithm to find possible alignment
# Let Image_1 be the original
# Let Image_2 be the aligned
# Displacement object is our nodes, [x,y]
# Assumption, there is always a better alignment up, down, left, and right if its not the same image
def alignment_MAE(Image_1, Image_2, depth_cap):
    """Greedy breadth-first search for the integer shift of Image_2 that minimises SAD against Image_1.

    The search starts at (0, 0). For each dequeued displacement, its 8
    neighbours are shifted with ``ShiftedImage_2D`` and scored with ``SAD``.
    A neighbour is enqueued only if it beats the best score found so far.

    Args:
        Image_1: Reference image.
        Image_2: Image that is shifted and compared against ``Image_1``.
        depth_cap: Maximum number of nodes to expand. 0 expands none. A
            negative value never reaches the cap, so the search runs until no
            improving neighbour remains.

    Returns:
        (displacement, score): ``displacement`` is a new ``[x, y]`` list and
        ``score`` is the SAD at that displacement.
    """
    from collections import deque

    best_sad = SAD(Image_1, Image_2)
    best_displacement = (0, 0)
    # A set gives O(1) membership tests; the queue gives O(1) pops from the front.
    visited = {best_displacement}
    queue = deque([best_displacement])
    expanded = 0

    while expanded != depth_cap and queue:
        x, y = queue.popleft()
        expanded += 1

        neighbours = (
            (x, y - 1),      # -y
            (x, y + 1),      # +y
            (x + 1, y),      # +x
            (x - 1, y),      # -x
            (x - 1, y - 1),  # Diagonal
            (x + 1, y + 1),  # Diagonal
            (x + 1, y - 1),  # Diagonal
            (x - 1, y + 1),  # Diagonal
        )

        for move in neighbours:
            if move in visited:
                continue
            visited.add(move)  # Marked as Visited

            # Perform shift and calculate
            candidate = ShiftedImage_2D(Image_2, move[0], move[1])
            cand_sad = SAD(Image_1, candidate)

            if cand_sad < best_sad:
                best_sad = cand_sad
                best_displacement = move
                queue.append(move)

    # Return a fresh list: callers may mutate it.
    return list(best_displacement), best_sad

# Vec4f is (x1, y1, x2, y2)
def y_shift_emphasis(image, block_threshold, MAE_shift, plot_debug=False, max_dy=10):
    """Return the smallest and largest y coordinate among near-horizontal line segments.

    Segments are detected with OpenCV's LineSegmentDetector on ``image``
    converted to uint8. A segment counts as near-horizontal when its two
    endpoints differ vertically by less than ``max_dy`` pixels. Every such
    segment is considered, and only y coordinates contribute to the extent.

    Args:
        image: 2D grayscale image.
        block_threshold: Unused.
        MAE_shift: Unused.
        plot_debug: If True, show the detected segments drawn on the image.
        max_dy: Vertical tolerance for a segment to count as horizontal (default 10).

    Returns:
        (top_y, bottom_y): minimum and maximum endpoint y over the
        near-horizontal segments.

    Raises:
        ValueError: if no (near-horizontal) line segment is detected.
    """
    # Need to turn into uint8 for Straight Line Detection
    img = np.uint8(image)
    lsd = cv2.createLineSegmentDetector(0)
    lines_contour = lsd.detect(img)[0]
    # detect() returns None for the lines when nothing is found.
    if lines_contour is None:
        raise ValueError("No line segments detected")

    if plot_debug:
        drawn_img = lsd.drawSegments(img, lines_contour)
        plt.imshow(drawn_img, cmap='gray')
        plt.title('Drawn Image')
        plt.show()

    # Rows are (x1, y1, x2, y2); keep the two y coordinates.
    segments = np.asarray(lines_contour, dtype=np.float64).reshape(-1, 4)
    ys = segments[:, [1, 3]]
    horizontal_ys = ys[np.abs(ys[:, 0] - ys[:, 1]) < max_dy]
    if horizontal_ys.size == 0:
        raise ValueError("No near-horizontal line segments detected")

    return horizontal_ys.min(), horizontal_ys.max()

def remove_stage_jitter_MAE(input_images, iteration_depth, plot_debug=False):
    """Register frames 1..N-1 of a stack against frame 0 (line-detection variant).

    Shifts are estimated as follows:

    - X shift: ``alignment_MAE`` search against the first frame.
    - Y shift: mean offset of the near-horizontal line extents that
      ``y_shift_emphasis`` finds in the base and current frame, truncated to
      an int.

    Each shifted frame is contrast-stretched with
    ``exposure.rescale_intensity`` before it is returned.

    Args:
        input_images: Stack of frames (N, H, W). For 3D frames, only channel 0 is used.
        iteration_depth: Maximum number of nodes ``alignment_MAE`` expands per frame.
        plot_debug: If True, show a 50/50 blend of the base and each shifted frame.

    Returns:
        (result, scores), each of length N - 1 (the base frame is excluded).
    """
    scores = []
    result = []

    base = input_images[0]
    if base.ndim == 3:
        base = base[:, :, 0]  # Reduce to the 2D

    base_top, base_bottom = y_shift_emphasis(base, 15, 0)

    for iteration, img in enumerate(input_images[1:], start=1):
        template_image = img
        if template_image.ndim == 3:
            template_image = template_image[:, :, 0]  # Reduce to the 2D

        template_top, template_bottom = y_shift_emphasis(template_image, 15, 0)

        displacement, score = alignment_MAE(base, template_image, iteration_depth)
        scores.append(score)
        y_shift = int(np.mean([(base_top - template_top), (base_bottom - template_bottom)]))
        shifted_image = ShiftedImage_2D(template_image, displacement[0], y_shift)  # X,Y

        if plot_debug:
            # Build the overlay only when it will actually be shown.
            background = Image.fromarray(np.uint8(base))
            overlay = Image.fromarray(np.uint8(shifted_image))
            new_img = Image.blend(background, overlay, 0.5)
            print("Overlay to show frame jitter")
            print("Overlay for image", iteration)
            plt.imshow(new_img, cmap='gray')
            plt.title('New Image')
            plt.show()

        result.append(exposure.rescale_intensity(shifted_image))  # Stretch contrast of dim frames

    print("Scores:", scores)

    return result, scores

In [None]:
import matplotlib.pyplot as plt
# Load images
import nd2

# Dataset used for the Registration 2 experiment
ND2_PATH = '/Users/hiram/Documents/EVERYTHING/20-29 Research/22 OliveiraLab/22.12 ND2 analyzer/nd2-analyzer/SR_1_5_2h_Pre-C_3h_IPTG_After10h_05_MC.nd2'
test_dataset = nd2.imread(ND2_PATH, dask=True)
test_pos_over_time = test_dataset[:, 0, 0].compute()

In [None]:
plt.imshow(test_pos_over_time[0], cmap='gray')
plt.title('First frame (raw)')
plt.show()

In [None]:
# Normalization: min-max stretch the stack to [0, 65535], then map to 8 bits
norm = cv2.normalize(test_pos_over_time, None, 0, 65535, cv2.NORM_MINMAX)
plt.imshow(norm[0], cmap='gray')
plt.title('First frame (normalized)')
plt.show()
# frames: (N,H,W) uint8
frames = ((norm / 65535) * 255).astype(np.uint8)
frames_rgb = np.stack([frames] * 3, axis=-1)  # (N,H,W,3)
aligned, scores = remove_stage_jitter_MAE(frames, 1000)

In [None]:
# Save both timelapses and evaluate
save_video("original.mp4", frames)
registered_frames = np.array(aligned)  # list of 2D frames -> (N-1, H, W)
save_video("registered.mp4", registered_frames)

## Registration with scikit-image phase cross-correlation (not good)

In [None]:
from skimage.registration import phase_cross_correlation


def register_timestack(images, upsample_factor=100):
    """Rigidly register every frame of a stack to the first frame using phase cross-correlation.

    Args:
        images: Sequence or array of 2D frames (N, H, W).
        upsample_factor: Sub-pixel precision passed to ``phase_cross_correlation``
            (100 means 1/100 pixel). The applied shift is rounded to whole pixels.

    Returns:
        (registered, shifts):

        - ``registered`` is a list of frames. The first is unchanged; each
          later frame is rolled by its rounded shift. ``np.roll`` wraps
          pixels around the borders.
        - ``shifts`` holds ``(0, 0)`` for the first frame, followed by the
          (row, col) sub-pixel shift arrays estimated for the others.
    """
    reference = images[0]
    registered = [reference]  # First frame as reference
    shifts = [(0, 0)]

    for i in range(1, len(images)):
        shift, _, _ = phase_cross_correlation(
            reference, images[i],
            upsample_factor=upsample_factor,  # for sub-pixel precision
        )
        # Round to the nearest pixel. astype(int) truncates toward zero,
        # turning 2.7 into 2 and -0.6 into 0.
        pixel_shift = tuple(int(s) for s in np.round(shift))
        registered.append(np.roll(images[i], pixel_shift, axis=(0, 1)))
        shifts.append(shift)

    return registered, shifts

In [None]:
import matplotlib.pyplot as plt


# frames: a NumPy array of shape (N, H, W) or (N, H, W, 3)
def show_frames_matplotlib(frames, delay=0.1):
    """Play a stack of frames as an in-place animation in a Jupyter output cell.

    Args:
        frames: Array of shape (N, H, W) or (N, H, W, 3). Grayscale frames are
            drawn with a fixed 0..255 display range (assumes uint8 data).
        delay: Pause between frames, in seconds.
    """
    plt.ion()  # interactive mode on
    fig, ax = plt.subplots()
    for frame in frames:
        ax.clear()
        if frame.ndim == 2:
            ax.imshow(frame, cmap='gray', vmin=0, vmax=255)
        else:
            ax.imshow(frame)
        ax.axis('off')
        # clear=True replaces the previous frame instead of appending one image per frame.
        display(fig, clear=True)
        plt.pause(delay)
    plt.ioff()
    # The last frame is already on screen; closing avoids rendering it a second time.
    plt.close(fig)

In [None]:

import numpy as np
import cv2

# Normalization
# np.asarray accepts both a dask array and the already-computed NumPy array that
# the loading cell produces; `.compute()` only exists on the former.
norm = cv2.normalize(np.asarray(test_pos_over_time), None, 0, 65535, cv2.NORM_MINMAX)
# plt.imshow(norm[0], cmap='gray')
# frames: (N,H,W) uint8
frames = ((norm / 65535) * 255).astype(np.uint8)
frames_rgb = np.stack([frames] * 3, axis=-1)  # (N,H,W,3)

# frames_rgb: (N, H, W, 3) uint8 in RGB
h, w = frames_rgb.shape[1], frames_rgb.shape[2]
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
out = cv2.VideoWriter("timelapse.mp4", fourcc, 15.0, (w, h), isColor=True)
if not out.isOpened():
    raise OSError("Could not open video writer for timelapse.mp4")

try:
    for fr in frames_rgb:
        out.write(cv2.cvtColor(fr, cv2.COLOR_RGB2BGR))
finally:
    out.release()
print("Saved timelapse.mp4")


In [None]:
np.shape(frames_rgb)

In [None]:
# Register every frame to the first one with phase cross-correlation
registered, shifts = register_timestack(images=frames)

In [None]:

import numpy as np
import cv2

# Write the phase-correlation result as a video. The distinct filename keeps
# it from overwriting the "registered.mp4" written by the Registration 2 section.
# `registered` frames are 2D slices of the uint8 `frames` stack.
h, w = frames.shape[1], frames.shape[2]
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
out = cv2.VideoWriter("registered_skimage.mp4", fourcc, 15.0, (w, h), isColor=True)
if not out.isOpened():
    raise OSError("Could not open video writer for registered_skimage.mp4")

try:
    for fr in registered:
        out.write(cv2.cvtColor(fr, cv2.COLOR_GRAY2BGR))
finally:
    out.release()
print("Saved registered_skimage.mp4")
