In [1]:
# poc registration (class definition)

import numpy as np
from scipy.fft import fft2, ifft2, fftshift
from scipy.ndimage import shift

def poc_reg(image_series, mag, cut, target):
    """
    Register a series of images using Phase Only Correlation (POC).

    Parameters:
    image_series : np.ndarray
        3D numpy array of the images to be registered (height x width x num_images)
    mag : int
        Magnification factor for the upsampling of the correlation.
    cut : int
        Number of pixels to cut from the borders.
    target : np.ndarray, optional
        Target image for the registration. If not provided, the mean image will be used.

    Returns:
    regged_image_series : np.ndarray
        Registered image series.
    dif_y : np.ndarray
        Displacement along y axis for each image.
    dif_x : np.ndarray
        Displacement along x axis for each image.
    """

    debug = False
    debug_itr = []

    # Subset the image series (cut the borders)
    if cut == 0:
     subset = image_series  # 全体を選択
    else:
     subset = image_series[cut:-cut, cut:-cut, :]

    # Compute the target/mean image
    if target is not None:
        mean_image = target
    else:
        mean_image = np.mean(subset, axis=2)

    #plt.imshow(mean_image)
    # Image dimensions
    w_row, w_col, n_im = subset.shape

    # 2D Hanning window
    wx, wy = np.meshgrid(
        0.5 - 0.5 * np.cos(2 * np.pi / mean_image.shape[1] * np.linspace(0, mean_image.shape[1], mean_image.shape[1])),
        0.5 - 0.5 * np.cos(2 * np.pi / mean_image.shape[0] * np.linspace(0, mean_image.shape[0], mean_image.shape[0]))
    )
    mean_image *= wx * wy

    # FFT of the mean image
    ft_mean = fft2(mean_image)

    dif_y = np.zeros(n_im)
    dif_x = np.zeros(n_im)

    def process_image(i):
        nonlocal debug
        if i in debug_itr:
            debug = True
        else:
            debug = False

        temp = subset[:, :, i].astype(float)
#        temp *= wx * wy
        ft_temp = fft2(temp)

        # Phase-only cross-power spectrum
        factorial = ft_temp * np.conj(ft_mean) / np.sqrt(np.abs(ft_mean) * np.abs(ft_temp))

        if mag != 1:
            zeros_ft = np.zeros_like(factorial)
            factorial = np.vstack((
                np.hstack((factorial[:w_row//2, :w_col//2], np.tile(zeros_ft[:w_row//2, :w_col], (1, mag-1)), factorial[:w_row//2, w_col//2:])),
                np.tile(zeros_ft, (mag-1, mag)),
                np.hstack((factorial[w_row//2:, :w_col//2], np.tile(zeros_ft[w_row//2:, :w_col], (1, mag-1)), factorial[w_row//2:, w_col//2:]))
            ))

        # Inverse FFT and correlation peak detection
        c = fftshift(np.real(ifft2(factorial)))
        c1 = np.max(c, axis=0)
        i1 = np.argmax(c, axis=0)
        i2 = np.argmax(c1)

        # Sub-pixel displacement calculation
        val_y = (c[i1[i2]-1, i2] - c[i1[i2]+1, i2]) / (2 * (c[i1[i2]-1, i2] + c[i1[i2]+1, i2] - 2 * c[i1[i2], i2])) + i1[i2]
        val_x = (c[i1[i2], i2-1] - c[i1[i2], i2+1]) / (2 * (c[i1[i2], i2-1] + c[i1[i2], i2+1] - 2 * c[i1[i2], i2])) + i2

        dif_y[i] = -(val_y - (w_row // 2 * mag + 1)) / mag
        dif_x[i] = -(val_x - (w_col // 2 * mag + 1)) / mag

        # Handle large displacements
        if abs(dif_y[i]) > 2000:
            dif_y[i] = np.finfo(float).eps
            print(f"A displacement > 100 pixels was detected in frame {i} and was assumed as eps")
        # if abs(dif_x[i]) > 2000:
        #     dif_x[i] = np.finfo(float).eps
        #     print(f"A displacement > 100 pixels was detected in frame {i} and was assumed as eps")

        return dif_y[i], dif_x[i]

    # Use multiprocessing to parallelize the process
    for i in range(n_im):
     dif_y[i], dif_x[i] = process_image(i)

    dif_y, dif_x = np.array(dif_y), np.array(dif_x)

    # Apply the translation to the image series
    #regged_image_series = np.array([shift(image_series[:, :, i], (dif_y[i], dif_x[i])) for i in range(n_im)])

    #return regged_image_series.astype(np.uint16), dif_y, dif_x
    return dif_y, dif_x


In [2]:
import reset
from pathlib import Path
import numpy as np
import tifffile

In [3]:
# Assign variables from reset
animal_loc = reset.animal_loc
animal_id = reset.animal_id
hemi = reset.hemi
frame_len = reset.frame_len

In [4]:
# Define paths for input and output images
image_loc = animal_loc / "padding" / hemi
save_loc = animal_loc / "zmerged"

# Ensure save location exists
save_loc.mkdir(parents=True, exist_ok=True)

image_path_list = [f"{image_loc}/{animal_id}_{hemi}_{n}.tif" for n in range(1,3)]

In [5]:
image_path1 = image_path_list[0]
targetimage = tifffile.imread(image_path1).astype('float64')
meantargetimage = np.mean(targetimage, axis=0)

In [6]:
image_path2 = image_path_list[1]
print(image_path2)
# .tifファイルを読み込み、3次元のnumpy配列に変換
image_series = tifffile.imread(image_path2)
image_series = np.transpose(image_series, (1, 2, 0))
# image_seriesは高さ x 幅 x フレーム数 の3次元numpy配列になります
print(image_series.shape)

/Volumes/BaffaloSSDPUTU3C1TB/rbak_data/rbak006/padding/l/rbak006_l_2.tif
(3072, 2160, 62)


In [7]:
dif_y, dif_x =poc_reg(image_series,1,0,meantargetimage)

In [8]:
w_row, w_col, n_im = image_series.shape
regged_image_series = np.array([shift(image_series[:, :, i], (dif_y[i], dif_x[i])) for i in range(n_im)])

In [21]:
approx_start = 30
approx_end = 40

In [None]:
corr_list = []
for fn1 in range(approx_start,approx_end):
    for fn2 in range(0,frame_len):
        corr = abs(np.corrcoef(targetimage[fn1],regged_image_series[fn2])[0,1])
        print(corr)
        corr_list.append(corr)

corr_max = np.argmax(corr_list)
print(corr_max)
f1_max = corr_max // (approx_end-approx_start) + approx_start
f2_max = corr_max % (frame_len)

poc_reg_y = dif_y[f2_max]
poc_reg_x = dif_x[f2_max]

new_dif_y = np.full(frame_len, poc_reg_y)
new_dif_x = np.full(frame_len, poc_reg_x)

  c /= stddev[:, None]
  c /= stddev[None, :]


[1.         0.70876956]
[1.         0.70876956]
[1.         0.70876956]
[1.         0.70876956]
[1.         0.70876956]
[1.         0.70876956]


KeyboardInterrupt: 

In [33]:
print(max(corr_list))

0.8150362574288988


In [34]:
newtargetimage = targetimage[f1_max]

IndexError: index 85 is out of bounds for axis 0 with size 62

In [25]:
dif_y, dif_x =poc_reg(image_series,1,0,newtargetimage)

In [26]:
poc_reg_y = dif_y[f2_max]
poc_reg_x = dif_x[f2_max]

new_dif_y = np.full(frame_len, poc_reg_y)
new_dif_x = np.full(frame_len, poc_reg_x)

In [27]:
print(f1_max, f2_max, poc_reg_y, poc_reg_x)

49 12 1.0824398720928912 -96.40751420981019


In [28]:
w_row, w_col, n_im = image_series.shape
new_regged_image_series = np.array([shift(image_series[:, :, i], (new_dif_y[i], new_dif_x[i])) for i in range(n_im)])

In [29]:
im1 = targetimage[:f1_max]
new_regged_image_series = new_regged_image_series[f2_max:]

In [30]:
img = np.concatenate((im1, new_regged_image_series), axis = 0)

In [19]:
# SAVE
img = img.astype('uint8')
save_path = save_loc / f"{animal_id}_{hemi}.tif"
tifffile.imwrite(save_path,img)