In [11]:
import reset
from pathlib import Path
import numpy as np
import scipy as sp
import tifffile
import math
import plotly.graph_objects as go
import matplotlib.pyplot as plt

In [12]:
# Pull the per-animal settings defined in the `reset` module into this namespace.
(
    animal_loc, animal_id, hemi,
    divisor, down_x, down_y,
    grid, frame_len, max_pos,
) = (
    reset.animal_loc, reset.animal_id, reset.hemi,
    reset.divisor, reset.down_x, reset.down_y,
    reset.grid, reset.frame_len, reset.max_pos,
)

In [13]:
# Stack positions are numbered from 1; kept as strings for building file names.
pos_list = [str(pos) for pos in range(1, max_pos + 1)]

# x-merged stacks are read from `image_loc`; y-merged results go to `save_loc`.
image_loc, save_loc = (animal_loc / stage / hemi for stage in ("xmerged", "ymerged"))

# Make sure the output directory exists before anything is written.
save_loc.mkdir(parents=True, exist_ok=True)

# One input TIFF path per position, in position order (index 0 -> position 1).
image_path_list = [
    f"{image_loc}/{animal_id}_{hemi}_{pos}.tif"
    for pos in pos_list
]

In [14]:
# Tile height in rows; top()/middle()/bottom() place tiles at offsets based on it.
height = down_y

# Canvas size: room for three tiles stacked vertically, 3 * down_x columns wide.
frame_height = 3 * down_y
frame_width = 3 * down_x

# One shared canvas (frame_len x frame_height x frame_width, float64) that
# top()/middle()/bottom() write into. Zero-initialised on purpose: np.empty would
# leave arbitrary memory in every pixel that no tile covers (unused columns, rows
# below the stitched extent), and that garbage would end up in the saved TIFF.
frame = np.zeros((frame_len, frame_height, frame_width))

In [15]:
def find_brightest_frame(path_num1, path_num2, path_num3):
    """
    Find the brightest frame index across three image stacks.

    For each stack (1-based indices into ``image_path_list``), the mean
    intensity of each of the first ``frame_len`` frames is computed. The stack
    whose brightest frame has the highest mean wins, and that frame's index is
    stored in the module-level ``max_frame`` (read by ``y_connector`` and the
    preview cells) and returned.

    Returns
    -------
    int
        Index of the brightest frame.
    """
    global max_frame

    # (peak mean brightness, frame index of that peak), one entry per stack
    peaks = []
    for path_num in (path_num1, path_num2, path_num3):
        # Read each stack once; re-reading the whole file per frame made this
        # frame_len times slower than necessary.
        stack = tifffile.imread(image_path_list[path_num - 1])
        brightness = [np.average(stack[fn, :, :]) for fn in range(frame_len)]
        del stack

        peak_index = int(np.argmax(brightness))
        peaks.append((brightness[peak_index], peak_index))

    # Pick the frame index belonging to the overall brightest peak
    peak_values, peak_indices = zip(*peaks)
    max_frame = peak_indices[int(np.argmax(peak_values))]

    print(f"Brightest frame across sections: {max_frame}")
    return max_frame

In [16]:
def y_connector(path_num1, path_num2, rect_y, rect_y_move, rect_y_start):
    """
    Estimate a vertical join offset between two neighbouring image stacks.

    Using frame ``max_frame`` of each stack (1-based indices into
    ``image_path_list``), the top ``rect_y`` rows of the first image are
    correlated (Pearson, ``np.corrcoef``) against each ``rect_y``-row window of
    the second image starting at rows ``rect_y_start`` ..
    ``rect_y_start + rect_y_move - 1``.

    The result, ``(index of the best window within the search range) + rect_y``,
    is stored in the module-level ``start_y`` (read by ``nine_sections``) and
    also returned.

    NOTE(review): the result does not include ``rect_y_start``, while
    middle()/bottom() use their start_y arguments as overlap heights in rows;
    the assembly cells pass hand-chosen values instead. Confirm this
    conversion is the intended one.
    """
    global start_y

    # Keep only the reference frame; copy so the full stack can actually be freed
    # (a plain slice is a view that keeps the whole stack alive).
    im1 = tifffile.imread(image_path_list[path_num1 - 1])
    im1_f = im1[max_frame, :, :].copy()
    del im1

    im2 = tifffile.imread(image_path_list[path_num2 - 1])
    im2_f = im2[max_frame, :, :].copy()
    del im2

    # Correlate the fixed template strip against each candidate window
    template = im1_f[:rect_y, :].ravel()
    cor = np.array([
        np.corrcoef(template, im2_f[i:i + rect_y, :].ravel())[0, 1]
        for i in range(rect_y_start, rect_y_start + rect_y_move)
    ])

    # A constant strip (e.g. zero padding) makes corrcoef return NaN, and
    # np.argmax treats NaN as the maximum; rank NaNs lowest instead.
    cor = np.where(np.isnan(cor), -np.inf, cor)

    start_y = int(np.argmax(cor)) + rect_y
    return start_y

# Offset estimation for three vertically adjacent sections
def nine_sections(f1, f2, f3, frame_name, rect_y, rect_y_move, rect_y_start):
    """
    Estimate the join offsets for a vertical stack of three sections.

    Runs ``y_connector`` on the pairs (f1, f2) and (f2, f3) and stores each
    resulting ``start_y`` in the module-level ``target1_start_y`` and
    ``target2_start_y``. ``frame_name`` is accepted but not used here.
    """
    global target1_start_y, target2_start_y

    search_args = (rect_y, rect_y_move, rect_y_start)

    # y_connector publishes its result through the global `start_y`
    y_connector(f1, f2, *search_args)
    target1_start_y = start_y

    y_connector(f2, f3, *search_args)
    target2_start_y = start_y

    print("First section target frame and start y:", max_frame, target1_start_y)
    print("Second section target frame and start y:", max_frame, target2_start_y)

def top(f1, frame_name=None):
    """
    Start a new assembly with stack ``f1`` as the top tile.

    The canvas is cleared first so that pixels this assembly never writes
    (columns beyond a tile's width, rows below the stitched extent) do not keep
    data from a previous assembly. Frames ``0 .. frame_len-1`` of the stack
    (1-based index into ``image_path_list``) are then copied into rows
    ``[0, height)`` and columns ``[0, width)`` of the canvas.

    Parameters
    ----------
    f1 : int
        1-based position of the top stack.
    frame_name : ndarray, optional
        Canvas to write into. Defaults to the module-level ``frame``, looked up
        at call time.
    """
    if frame_name is None:
        # Resolved at call time: a signature default of `frame` is frozen to the
        # array that existed when this cell ran, so re-running the allocation
        # cell would silently decouple it from the canvas that is plotted/saved.
        frame_name = frame

    im1 = tifffile.imread(image_path_list[f1 - 1])
    width1 = im1.shape[2]

    frame_name[:] = 0
    frame_name[:frame_len, :height, :width1] = im1[:frame_len]
    del im1

def middle(f2, start_y1, frame_name=None):
    """
    Place stack ``f2`` below the top tile, overlapping it by ``start_y1`` rows.

    In the overlap (canvas rows ``[height - start_y1, height)``), the existing
    top-tile pixels are averaged with the first ``start_y1`` rows of the new
    image. The remaining rows of the new image then fill canvas rows
    ``[height, 2*height - start_y1)``.

    Parameters
    ----------
    f2 : int
        1-based position of the middle stack in ``image_path_list``.
    start_y1 : int
        Overlap with the top tile, in rows (0 = no overlap, at most ``height``).
    frame_name : ndarray, optional
        Canvas to write into. Defaults to the module-level ``frame``, looked up
        at call time.

    Raises
    ------
    ValueError
        If ``start_y1`` is outside ``[0, height]``.
    """
    if frame_name is None:
        # Resolve at call time so a re-allocated `frame` is honoured.
        frame_name = frame
    if not 0 <= start_y1 <= height:
        raise ValueError(f"start_y1 must be in [0, {height}], got {start_y1}")

    im2 = tifffile.imread(image_path_list[f2 - 1])
    width2 = im2.shape[2]

    overlap_start = height - start_y1  # first canvas row shared with the top tile

    # Average the overlapping strip, then copy the rest of the tile below it.
    blended = (frame_name[:frame_len, overlap_start:height, :width2]
               + im2[:frame_len, :start_y1, :]) / 2
    frame_name[:frame_len, overlap_start:height, :width2] = blended
    frame_name[:frame_len, height:2 * height - start_y1, :width2] = im2[:frame_len, start_y1:, :]
    del im2

def bottom(f3, start_y1, start_y2, frame_name=None):
    """
    Place stack ``f3`` below the middle tile, overlapping it by ``start_y2`` rows.

    The middle tile is assumed to end at canvas row ``2*height - start_y1``, so
    ``start_y1`` should match the value passed to ``middle``. Canvas rows
    ``[2*height - start_y1 - start_y2, 2*height - start_y1)`` are averaged with
    the first ``start_y2`` rows of the new image. The rest of the image fills
    rows up to ``3*height - start_y1 - start_y2``.

    Parameters
    ----------
    f3 : int
        1-based position of the bottom stack in ``image_path_list``.
    start_y1 : int
        Overlap used for the middle tile, in rows.
    start_y2 : int
        Overlap with the middle tile, in rows.
    frame_name : ndarray, optional
        Canvas to write into. Defaults to the module-level ``frame``, looked up
        at call time.

    Raises
    ------
    ValueError
        If either overlap is outside ``[0, height]``.
    """
    if frame_name is None:
        # Resolve at call time so a re-allocated `frame` is honoured.
        frame_name = frame
    if not 0 <= start_y1 <= height:
        raise ValueError(f"start_y1 must be in [0, {height}], got {start_y1}")
    if not 0 <= start_y2 <= height:
        raise ValueError(f"start_y2 must be in [0, {height}], got {start_y2}")

    im3 = tifffile.imread(image_path_list[f3 - 1])
    width3 = im3.shape[2]

    middle_end = 2 * height - start_y1       # first canvas row below the middle tile
    overlap_start = middle_end - start_y2    # first canvas row shared with it

    # Average the overlapping strip, then copy the rest of the tile below it.
    blended = (frame_name[:frame_len, overlap_start:middle_end, :width3]
               + im3[:frame_len, :start_y2, :]) / 2
    frame_name[:frame_len, overlap_start:middle_end, :width3] = blended
    frame_name[:frame_len, middle_end:middle_end + height - start_y2, :width3] = im3[:frame_len, start_y2:, :]
    del im3

In [17]:
# Correlation-search parameters for y_connector (adjust as needed):
#   rect_y       - height in rows of the strip being matched
#   rect_y_move  - number of candidate window positions tried
#   rect_y_start - first candidate row, chosen so every window lies inside a
#                  down_y-row image
rect_y = 20
rect_y_move = 200
rect_y_start = down_y - rect_y - rect_y_move

# A negative start would make y_connector slice from the end of the image and
# correlate against the wrong rows.
if rect_y_start < 0:
    raise ValueError(
        f"rect_y + rect_y_move ({rect_y + rect_y_move}) exceeds image height {down_y}"
    )

In [None]:
# Pick the reference frame, then estimate the two vertical join offsets for sections 1-3
max_frame = find_brightest_frame(1, 2, 3)
nine_sections(f1=1, f2=2, f3=3, frame_name=frame,
              rect_y=rect_y, rect_y_move=rect_y_move, rect_y_start=rect_y_start)

In [8]:
# Assemble section 1 (top) and section 2 (middle) into `frame`.
# NOTE(review): the middle overlap is the literal 75 rather than the
# target1_start_y estimated above — confirm this manual override is intentional.
top(f1=1)
middle(f2=2, start_y1=75)

In [10]:
# Add section 3 below section 2.
# NOTE(review): bottom() treats row 2*height - start_y1 as the end of the middle
# tile, but middle() above was called with start_y1=75 while this passes
# target1_start_y. Unless the two are equal, section 3 is misplaced. Confirm.
bottom(f3=3, start_y1=target1_start_y, start_y2=260)

In [None]:
# Preview the stitched stack (sections 1-3) at the reference frame.
# go.Heatmap draws row 0 at the bottom by default, so reverse the y axis to show
# the image in array/TIFF orientation. Anchor only y to x for square pixels; a
# mutual x<->y scaleanchor forms a constraint loop.
fig = go.Figure(data=go.Heatmap(z=frame[max_frame], colorscale="Viridis"))
fig.update_layout(
    title=f"Sections 1-3 stitched, frame {max_frame}",
    yaxis=dict(scaleanchor="x", autorange="reversed"),
)
fig.show()

In [None]:
# Save the aligned stack for sections 1-3.
# The canvas is float64; casting out-of-range floats straight to uint8 is
# undefined in NumPy, so clip to [0, 255] first.
# NOTE(review): if the source TIFFs are not 8-bit, this still discards range —
# confirm the intended output dtype.
save_path = save_loc / f"{animal_id}_{hemi}_1.tif"
tifffile.imwrite(save_path, np.clip(frame, 0, 255).astype('uint8'))
print(f"Aligned image saved to {save_path}")

In [None]:
# Pick the reference frame, then estimate the two vertical join offsets for sections 4-6
max_frame = find_brightest_frame(4, 5, 6)
nine_sections(f1=4, f2=5, f3=6, frame_name=frame,
              rect_y=rect_y, rect_y_move=rect_y_move, rect_y_start=rect_y_start)

In [18]:
# Assemble section 4 (top) and section 5 (middle) into `frame`.
# start_y1=0 means no averaged overlap: section 5 starts directly below section 4.
# NOTE(review): bottom() for section 6 is called with start_y1=175, not 0 — confirm.
top(f1=4)
middle(f2=5, start_y1=0)

In [15]:
# Add section 6 below section 5.
# NOTE(review): middle() was called with start_y1=0, so section 5 ends at row
# 2*height. With start_y1=175, bottom() places section 6 as if the middle tile
# ended 175 rows earlier. Section 6 then overwrites the last 175 rows of
# section 5, beyond the 230-row averaged strip. Confirm this hand-tuned value.
bottom(f3=6, start_y1=175, start_y2=230)

In [None]:
# Preview the stitched stack (sections 4-6) at the reference frame.
# go.Heatmap draws row 0 at the bottom by default, so reverse the y axis to show
# the image in array/TIFF orientation. Anchor only y to x for square pixels; a
# mutual x<->y scaleanchor forms a constraint loop.
fig = go.Figure(data=go.Heatmap(z=frame[max_frame], colorscale="Viridis"))
fig.update_layout(
    title=f"Sections 4-6 stitched, frame {max_frame}",
    yaxis=dict(scaleanchor="x", autorange="reversed"),
)
fig.show()

In [19]:
# Save the aligned stack for sections 4-6.
# The canvas is float64; casting out-of-range floats straight to uint8 is
# undefined in NumPy, so clip to [0, 255] first.
# NOTE(review): if the source TIFFs are not 8-bit, this still discards range —
# confirm the intended output dtype.
save_path = save_loc / f"{animal_id}_{hemi}_2.tif"
tifffile.imwrite(save_path, np.clip(frame, 0, 255).astype('uint8'))
print(f"Aligned image saved to {save_path}")

Aligned image saved to /Volumes/BaffaloSSDPUTU3C1TB/rbak_data/rbak004/ymerged/l/rbak004_l_2.tif
