In [1]:
import copy
import json
import os
import re
from datetime import datetime

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import scipy
import skimage
import torch
from PIL import Image
from skimage.registration import phase_cross_correlation
from tqdm import tqdm, trange
# Fractional overlap between neighbouring tiles. Used for the nominal tile spacing
# (tile size * (1 - OVERLAP_R)) and for the width of the border strips cut out for registration.
# NOTE(review): presumably matches the "20ov" in btag; confirm against the acquisition settings.
OVERLAP_R = 0.2



In [2]:
ptag = 'pair4'
btag = '220904_L35D719P5_topro_brn2_ctip2_4x_0_108na_50sw_11hdf_4z_20ov_21-49-38'
# btag='220902_L35D719P3_topro_brn2_ctip2_4x_11hdf_0_108na_50sw_4z_20ov_16-41-36'
# NOTE(review): the saved outputs in this notebook mention L35D719P3, i.e. they were produced
# with the commented-out btag above; re-run before trusting them for this btag.

zratio = 2.5/4
overlap_r = OVERLAP_R

save_path = f'/cajal/ACMUSERS/ziquanw/Lightsheet/stitch_by_ptreg/{ptag}/{btag.split("_")[1]}'
result_path = f'/cajal/ACMUSERS/ziquanw/Lightsheet/results/P4/{ptag}/{btag}'
# Tile result folders are named 'UltraII[AA x BB]' (see `root` below). Parse the two grid indices
# with a regex instead of fixed character offsets, so that any other entry in result_path
# (files, logs, ...) is skipped rather than crashing int() or being misparsed.
tile_dir_re = re.compile(r'UltraII\[(\d+) x (\d+)\]')
tile_loc = np.array([
    [int(m.group(1)), int(m.group(2))]
    for m in map(tile_dir_re.fullmatch, os.listdir(result_path))
    if m is not None
])
assert len(tile_loc) > 0, f'no UltraII[.. x ..] tile folders found in {result_path}'
ncol, nrow = tile_loc.max(0)+1
# print(tile_loc, nrow, ncol)
assert len(tile_loc) == nrow*ncol, f'tile of raw data is not complete, tile location: {tile_loc}'

ls_image_root = f'/lichtman/Felix/Lightsheet/P4/{ptag}/{btag}'
fn_ = ls_image_root.split('_')[-1]+'_'+'_'.join(btag.split('_')[1:-1])
root = result_path + '/UltraII[%02d x %02d]'
stack_names = [f for f in os.listdir(root % (0, 0)) if f.endswith('instance_center.zip')]
neighbor = [[-1, 0], [0, -1], [-1, -1], [1, 0], [0, 1], [1, 1], [1, -1], [-1, 1]]
# Directory name is spelled 'NIS_tranform'; later cells read/write this exact path.
os.makedirs(f'{save_path}/NIS_tranform', exist_ok=True)

In [3]:
# Determine the z extent covered by all segmentation stacks (zmin scaled by zratio).
# os.listdir order is arbitrary, so visit the stacks by increasing zmin: after the loop
# `stack_name` (reused by later cells to name the output JSON files) and `seg_shape` refer to
# the deepest stack, and zend is the maximum end over all stacks.
assert stack_names, f'no *instance_center.zip found in {root % (0, 0)}'
zend = 0
for stack_name in sorted(stack_names, key=lambda n: int(n.split('zmin')[1].split('_')[0])):
    print(stack_name)
    meta_name = stack_name.replace('instance_center', 'seg_meta')
    zstart = int(stack_name.split('zmin')[1].split('_')[0])
    zstart = int(zstart*zratio)
    seg_shape = torch.load(f'{root % (0, 0)}/{meta_name}')
    zend = max(zend, int(zstart + seg_shape[0].item()*zratio))

# Register the whole volume from the first slice.
zstart = 0
print(zstart, zend)
# Nominal left-top corner of every tile 'i-j' from the grid position and the tile overlap.
tile_lt_loc = {
    f'{i}-{j}': [i*seg_shape[1]*(1-overlap_r), j*seg_shape[2]*(1-overlap_r)] for i in range(ncol) for j in range(nrow)
}
tile_lt_loc

L35D719P3_NIScpp_results_zmin0_instance_center.zip
L35D719P3_NIScpp_results_zmin958_instance_center.zip
0 1048


{'0-0': [tensor(0.), tensor(0.)],
 '0-1': [tensor(0.), tensor(1605.6000)],
 '0-2': [tensor(0.), tensor(3211.2000)],
 '0-3': [tensor(0.), tensor(4816.8003)],
 '0-4': [tensor(0.), tensor(6422.3999)],
 '1-0': [tensor(1903.2001), tensor(0.)],
 '1-1': [tensor(1903.2001), tensor(1605.6000)],
 '1-2': [tensor(1903.2001), tensor(3211.2000)],
 '1-3': [tensor(1903.2001), tensor(4816.8003)],
 '1-4': [tensor(1903.2001), tensor(6422.3999)],
 '2-0': [tensor(3806.4001), tensor(0.)],
 '2-1': [tensor(3806.4001), tensor(1605.6000)],
 '2-2': [tensor(3806.4001), tensor(3211.2000)],
 '2-3': [tensor(3806.4001), tensor(4816.8003)],
 '2-4': [tensor(3806.4001), tensor(6422.3999)],
 '3-0': [tensor(5709.6001), tensor(0.)],
 '3-1': [tensor(5709.6001), tensor(1605.6000)],
 '3-2': [tensor(5709.6001), tensor(3211.2000)],
 '3-3': [tensor(5709.6001), tensor(4816.8003)],
 '3-4': [tensor(5709.6001), tensor(6422.3999)]}

In [4]:
def _center_out_order(n):
    """Return indices 0..n-1 ordered from the middle outwards: n//2, n//2+1, n//2-1, n//2+2, ..."""
    mid = n // 2
    order = [mid]
    for step in range(n - 1):
        offset = step // 2 + 1
        order.append(mid + offset if step % 2 == 0 else mid - offset)
    # Drop out-of-range candidates, then append any index that was never reached (ascending).
    order = [idx for idx in order if 0 <= idx < n]
    order += [idx for idx in range(n) if idx not in order]
    return order


row_order = _center_out_order(nrow)
col_order = _center_out_order(ncol)

print(col_order, row_order)


[2, 3, 1, 0] [2, 3, 1, 4, 0]


In [5]:
def get_tile_stack_overlap_area(fn, zrange, overlap_r):
    """Read a tile's z-slices and keep only the four border strips used for registration.

    Args:
        fn: printf-style path template with one integer field for the slice index (``fn % zi``).
        zrange: iterable of slice indices to read.
        overlap_r: fraction of the slice height/width kept for each border strip.

    Returns:
        A list of four arrays, each stacked along a new leading z axis, in the order
        [t, b, l, r]: first rows, last rows, first columns, last columns of every slice.

    Raises:
        ValueError: if ``zrange`` yields no slice indices.
    """
    stack_overlaps = [[], [], [], []]
    for zi in zrange:
        # Close each file deterministically. Pillow can keep TIFF file handles open after
        # loading, which otherwise leaks one descriptor per slice until GC runs.
        with Image.open(fn % zi) as pil_img:
            tile_img = np.array(pil_img)
        h, w = tile_img.shape
        # The +1 on the leading strips makes them as long as the trailing ones when h*overlap_r
        # is fractional.
        stack_overlaps[0].append(tile_img[:int(h*overlap_r)+1])
        stack_overlaps[1].append(tile_img[int(h*(1-overlap_r)):])
        stack_overlaps[2].append(tile_img[:, :int(w*overlap_r)+1])
        stack_overlaps[3].append(tile_img[:, int(w*(1-overlap_r)):])

    if not stack_overlaps[0]:
        raise ValueError(f'no slices to read for {fn!r}: zrange is empty')
    return [np.stack(strips) for strips in stack_overlaps]
    

In [6]:
def shift_image(image, shift):
    """Translate a tensor along its first three axes, filling vacated voxels with zeros.

    Args:
        image: tensor with at least three dimensions; axes 0, 1 and 2 are shifted.
        shift: sequence whose first three entries are the per-axis offsets. Each entry is
            truncated toward zero with ``int()`` (e.g. -0.8 -> 0, 1.7 -> 1). A positive offset
            moves content towards higher indices.

    Returns:
        A tensor with the same shape, dtype and device as ``image`` (``image`` itself if all
        offsets truncate to 0). An offset whose magnitude is at least the axis length yields
        all zeros; previously it produced a tensor of the wrong size.
    """
    for axis in range(3):
        s = int(shift[axis])
        if s == 0:
            continue
        n = image.shape[axis]
        if abs(s) >= n:
            # All content is shifted out of view along this axis.
            image = torch.zeros_like(image)
            continue
        pad_shape = list(image.shape)
        pad_shape[axis] = abs(s)
        # Create the padding on the image's device so CUDA tensors work too.
        pad = torch.zeros(pad_shape, dtype=image.dtype, device=image.device)
        if s < 0:
            image = torch.cat([image.narrow(axis, -s, n + s), pad], axis)
        else:
            image = torch.cat([pad, image.narrow(axis, 0, n - s)], axis)
    return image
    

In [7]:
# Disabled walkthrough: the commented-out code below reads the border strips of tiles (0,0)
# and (1,0) only. Kept for reference; apart from the string literal nothing here executes.
'''
One example
'''
# overlap_r = 0.2
# i = 0
# j = 0
# fn = f'{ls_image_root}/{fn_}_UltraII[{i:02d} x {j:02d}]_C01_xyz-Table Z%04d.ome.tif'
# registered_tile = {f'{i}-{j}': get_tile_stack_overlap_area(fn, range(zstart, zend), overlap_r)}

# i = 1
# j = 0
# fn = f'{ls_image_root}/{fn_}_UltraII[{i:02d} x {j:02d}]_C01_xyz-Table Z%04d.ome.tif'
# registered_tile[f'{i}-{j}'] = get_tile_stack_overlap_area(fn, range(zstart, zend), overlap_r)

'\nOne example\n'

In [8]:
# print(registered_tile['0-0'][2].shape, registered_tile['1-0'][3].shape)
# print(registered_tile['0-0'][3].shape, registered_tile['1-0'][2].shape)

In [9]:
# Disabled walkthrough: the commented-out code below runs one 3D phase cross-correlation
# between the facing strips of tiles (0,0) and (1,0) and shifts the moving strip.
# Nothing here executes apart from the string literal.
'''
One example
'''

# down_r = [1, 1, 1]

# image = registered_tile['0-0'][1].astype(float)
# offset_image = registered_tile['1-0'][0].astype(float)
# print('original shape', image.shape, offset_image.shape)
# image = torch.nn.functional.interpolate(torch.from_numpy(image)[None, None], scale_factor=down_r)[0,0].numpy()
# offset_image = torch.nn.functional.interpolate(torch.from_numpy(offset_image)[None, None], scale_factor=down_r)[0,0].numpy()
# print('downsampled shape', image.shape, offset_image.shape)

# time = datetime.now()
# # pixel precision first
# shift, error, diffphase = phase_cross_correlation(image, offset_image)

# print(f'Time consume {datetime.now()-time}. Detected pixel offset: {[s/r for s, r in zip(shift, down_r)]}', error, diffphase)

# offset_image = shift_image(torch.from_numpy(offset_image), shift).numpy()
# print('shifted shape', image.shape, offset_image.shape)



'\nOne example\n'

In [10]:
# Disabled walkthrough: the commented-out code below plots the reference strip, the shifted
# strip and their overlay for one slice. Nothing here executes apart from the string literal.
'''
One example
'''
# vis_slice = 200

# fig = plt.figure(figsize=(24, 9))
# ax1 = plt.subplot(1, 3, 1)
# ax2 = plt.subplot(1, 3, 2, sharex=ax1, sharey=ax1)
# ax3 = plt.subplot(1, 3, 3)

# ax1.imshow(image[:, vis_slice], cmap='gray')
# ax1.set_axis_off()
# ax1.set_title('Reference image')

# ax2.imshow(offset_image[:, vis_slice], cmap='gray')
# ax2.set_axis_off()
# ax2.set_title('Offset image')

# overlay = (image + offset_image)/2
# ax3.imshow(overlay[:, vis_slice], cmap='gray')
# ax3.set_axis_off()
# ax3.set_title("Overlay")

# plt.show()


'\nOne example\n'

In [11]:
# '''
# Average multiple shifts 
# '''
# def pcc_all_neighbor_tile(cur_ij, tile_overlap, tform_stack_coarse, tformed_tile_lt_loc, imax, jmax):
# #     neighbor = [[-1, 0], [0, -1], [1, 0], [0, 1]]
#     neighbor = [[1, 0], [0, 1]]
#     down_r = 0.5
#     i, j = cur_ij
#     print(datetime.now(), f'start read stack ({i},{j})')
#     if f'{i}-{j}' not in tile_overlap:
#         fn = f'{ls_image_root}/{fn_}_UltraII[{i:02d} x {j:02d}]_C01_xyz-Table Z%04d.ome.tif'
#         tile_overlap[f'{i}-{j}'] = get_tile_stack_overlap_area(fn, range(zstart, zend), overlap_r)

#     reference_images = tile_overlap[f'{i}-{j}']
#     print(datetime.now(), f'done read stack ({i},{j})')
#     for pi, pj in neighbor:
#         ## skip repeated pcc
# #         if i+pi == pre_ij[0] and j+pj == pre_ij[1]: continue
#         if i+pi >= imax or i+pi < 0: continue
#         if j+pj >= jmax or j+pj < 0: continue
#         print(datetime.now(), f'start read stack ({i+pi},{j+pj})')
#         if f'{i+pi}-{j+pj}' not in tile_overlap:
#             fn = f'{ls_image_root}/{fn_}_UltraII[{(i+pi):02d} x {(j+pj):02d}]_C01_xyz-Table Z%04d.ome.tif'
#             tile_overlap[f'{i+pi}-{j+pj}'] = get_tile_stack_overlap_area(fn, range(zstart, zend), overlap_r)
    
#         moving_image = tile_overlap[f'{i+pi}-{j+pj}']
#         print(datetime.now(), f'done read stack ({i+pi},{j+pj})')
#         if pi < 0:
#             moving_image = moving_image[1].astype(float)
#             reference_image = reference_images[0].astype(float)
#         elif pi > 0:
#             moving_image = moving_image[0].astype(float)
#             reference_image = reference_images[1].astype(float)
#         elif pj < 0:
#             moving_image = moving_image[3].astype(float)
#             reference_image = reference_images[2].astype(float)
#         elif pj > 0:
#             moving_image = moving_image[2].astype(float)
#             reference_image = reference_images[3].astype(float)

#         moving_image = torch.nn.functional.interpolate(torch.from_numpy(moving_image)[None, None], scale_factor=down_r)[0,0].numpy()
#         reference_image = torch.nn.functional.interpolate(torch.from_numpy(reference_image)[None, None], scale_factor=down_r)[0,0].numpy()
#         print(datetime.now(), f"start phase cross correlation mov ({i+pi},{j+pj}) to ({i},{j})")
#         shift, error, diffphase = phase_cross_correlation(reference_image, moving_image, overlap_ratio=0.8)
#         shift = [s/down_r for s in shift]

#         tz, tx, ty = shift
#         ## plus with tgt shifts
#         if len(tform_stack_coarse[f'{i}-{j}'][0]) > 0:
#             tz = tz + np.mean(tform_stack_coarse[f'{i}-{j}'][0])
#             tx = tx + np.mean(tform_stack_coarse[f'{i}-{j}'][1])
#             ty = ty + np.mean(tform_stack_coarse[f'{i}-{j}'][2])
#         shift = [tz, tx, ty]
#         print(datetime.now(), f"done phase cross correlation, shift: {shift}, error: {error:.6f}, dphase: {diffphase:.6f}")

#         ## it might connects multiple tiles    
#         tform_stack_coarse[f'{i+pi}-{j+pj}'][0].append(tz)
#         tform_stack_coarse[f'{i+pi}-{j+pj}'][1].append(tx)
#         tform_stack_coarse[f'{i+pi}-{j+pj}'][2].append(ty)


#     for pi, pj in neighbor:
#         ## skip repeated pcc to prevent infinite loop
# #         if i+pi == pre_ij[0] and j+pj == pre_ij[1]: continue
#         if i+pi >= imax or i+pi < 0: continue
#         if j+pj >= jmax or j+pj < 0: continue
#         tform_stack_coarse, _, tile_overlap = pcc_all_neighbor_tile([i+pi, j+pj], tile_overlap, tform_stack_coarse, None, imax, jmax)
    
#     if tformed_tile_lt_loc is not None:
#         for k in tform_stack_coarse:
#             tz, tx, ty = tform_stack_coarse[k]
#             tz = np.mean(tz)
#             tx = np.mean(tx)
#             ty = np.mean(ty)
#             tformed_tile_lt_loc[0][k][0] = tformed_tile_lt_loc[0][k][0] + tx
#             tformed_tile_lt_loc[0][k][1] = tformed_tile_lt_loc[0][k][1] + ty
    
#     return tform_stack_coarse, tformed_tile_lt_loc, tile_overlap

# # Init inputs


# tform_stack_coarse = {f'{i}-{j}': [[], [], []] for i in range(ncol) for j in range(nrow)}
# tformed_tile_lt_loc = {zstart: copy.deepcopy(tile_lt_loc)}

# # Start at imaging starting point (0,0)
# i, j = 0, 0
# tform_stack_coarse[f'{i}-{j}'] = [[0], [0], [0]]
# # tile_overlap = {}
# tform_stack_coarse, tformed_tile_lt_loc, tile_overlap = pcc_all_neighbor_tile([i, j], tile_overlap, tform_stack_coarse, tformed_tile_lt_loc, ncol, nrow)
# print(tform_stack_coarse)

In [12]:
# Read the four border strips of every tile once; later cells reuse this cache.
tile_overlap = {}
for i in range(ncol):
    for j in trange(nrow):
        key = f'{i}-{j}'
        if key in tile_overlap:
            continue
        fn = f'{ls_image_root}/{fn_}_UltraII[{i:02d} x {j:02d}]_C01_xyz-Table Z%04d.ome.tif'
        tile_overlap[key] = get_tile_stack_overlap_area(fn, range(zstart, zend), overlap_r)


100%|█████████████████████████████████████████████████████████████████████████████████████| 5/5 [05:32<00:00, 66.54s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████| 5/5 [05:38<00:00, 67.78s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████| 5/5 [05:27<00:00, 65.43s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████| 5/5 [05:17<00:00, 63.58s/it]


In [13]:
# Sanity check: every grid position 'i-j' should have been loaded.
tile_overlap.keys()

dict_keys(['0-0', '0-1', '0-2', '0-3', '0-4', '1-0', '1-1', '1-2', '1-3', '1-4', '2-0', '2-1', '2-2', '2-3', '2-4', '3-0', '3-1', '3-2', '3-3', '3-4'])

In [14]:
'''
Coarse reg by 3D PCC
'''
# Coarse registration of whole overlap strips with 3D phase cross-correlation.
# Pass 1 (first loop nest) registers every tile (i, j), j >= 1, against (i, j-1).
# Pass 2 (second loop nest) registers every tile (i, j), i >= 1, against (i-1, j).
# Each registration appends one pairwise (tz, tx, ty) to tform_stack_coarse_colrow['i-j'],
# giving [[tz...], [tx...], [ty...]] with the pass-1 shift (or the 0 seed for j == 0) first
# and, for i >= 1, the pass-2 shift second. Shifts are pairwise here; fuse_colrow_tform chains them.
tform_stack_coarse_colrow = {}
neighbor_row = [[0, -1], [0, 1]]
neighbor_col = [[-1, 0], [1, 0]]
down_r = 0.5
for i in range(ncol):
    j = 0
    # Tile (i, 0) anchors its column with a zero shift.
    tform_stack_coarse_colrow[f'{i}-{j}'] = [[0], [0], [0]]
    for j in range(1, nrow):
        # Lazily read the strips if the loading cell did not cache this tile.
        if f'{i}-{j}' not in tile_overlap:
            fn = f'{ls_image_root}/{fn_}_UltraII[{i:02d} x {j:02d}]_C01_xyz-Table Z%04d.ome.tif'
            tile_overlap[f'{i}-{j}'] = get_tile_stack_overlap_area(fn, range(zstart, zend), overlap_r)

        moving_image = tile_overlap[f'{i}-{j}']
        if f'{i}-{j}' not in tform_stack_coarse_colrow:
            tform_stack_coarse_colrow[f'{i}-{j}'] = [[], [], []]
        # Register against the first already-visited neighbour along j. Given the loop order
        # this is always (i, j-1); the `break` below stops after one registration.
        for pi, pj in neighbor_row:
            if f'{i+pi}-{j+pj}' in tform_stack_coarse_colrow: 
                if f'{i+pi}-{j+pj}' not in tile_overlap:
                    fn = f'{ls_image_root}/{fn_}_UltraII[{(i+pi):02d} x {(j+pj):02d}]_C01_xyz-Table Z%04d.ome.tif'
                    tile_overlap[f'{i+pi}-{j+pj}'] = get_tile_stack_overlap_area(fn, range(zstart, zend), overlap_r)

                reference_image = tile_overlap[f'{i+pi}-{j+pj}']
                # Strip indices follow get_tile_stack_overlap_area: 2/3 = first/last columns.
                # moving_image is rebound from the strip list to one strip, which is only safe
                # because of the `break` after this registration.
                if pj < 0:
                    moving_image = moving_image[2].astype(float)
                    reference_image = reference_image[3].astype(float)
                elif pj > 0:
                    moving_image = moving_image[3].astype(float)
                    reference_image = reference_image[2].astype(float)
                
                # Resample all three axes by down_r before PCC; the shift is rescaled afterwards.
                moving_image = torch.nn.functional.interpolate(torch.from_numpy(moving_image)[None, None], scale_factor=down_r)[0,0].numpy()
                reference_image = torch.nn.functional.interpolate(torch.from_numpy(reference_image)[None, None], scale_factor=down_r)[0,0].numpy()
                print(datetime.now(), f"start phase cross correlation mov ({i},{j}) to ({i+pi},{j+pj})")
                shift, error, diffphase = phase_cross_correlation(reference_image, moving_image, overlap_ratio=0.8)
                shift = [s/down_r for s in shift]

                tz, tx, ty = shift
                # The reference tile must hold exactly one (pass-1) entry at this point.
                assert len(tform_stack_coarse_colrow[f'{i+pi}-{j+pj}'][0]) == 1, len(tform_stack_coarse_colrow[f'{i+pi}-{j+pj}'][0])
                assert len(tform_stack_coarse_colrow[f'{i+pi}-{j+pj}'][1]) == 1, len(tform_stack_coarse_colrow[f'{i+pi}-{j+pj}'][1])
                assert len(tform_stack_coarse_colrow[f'{i+pi}-{j+pj}'][2]) == 1, len(tform_stack_coarse_colrow[f'{i+pi}-{j+pj}'][2])
#                 tz = tz + np.mean(tform_stack_coarse_colrow[f'{i+pi}-{j+pj}'][0])
#                 tx = tx + np.mean(tform_stack_coarse_colrow[f'{i+pi}-{j+pj}'][1])
#                 ty = ty + np.mean(tform_stack_coarse_colrow[f'{i+pi}-{j+pj}'][2])
#                 tz = tz + tform_stack_coarse_colrow[f'{i+pi}-{j+pj}'][0][0]
#                 tx = tx + tform_stack_coarse_colrow[f'{i+pi}-{j+pj}'][1][0]
#                 ty = ty + tform_stack_coarse_colrow[f'{i+pi}-{j+pj}'][2][0]
                shift = [tz, tx, ty]
                print(datetime.now(), f"done phase cross correlation, shift: {shift}, error: {error:.6f}, dphase: {diffphase:.6f}")
                tform_stack_coarse_colrow[f'{i}-{j}'][0].append(tz)
                tform_stack_coarse_colrow[f'{i}-{j}'][1].append(tx)
                tform_stack_coarse_colrow[f'{i}-{j}'][2].append(ty)

#                 for imgi, img in enumerate(tile_overlap[f'{i}-{j}']):
#                     tile_overlap[f'{i}-{j}'][imgi] = shift_image(torch.from_numpy(img.astype(float)), shift).numpy()
                    
                break
print(tform_stack_coarse_colrow)             
                
# Pass 2: register every (i, j), i >= 1, against (i-1, j) using the row strips.
# Relies on tile_overlap already holding every tile (no lazy loading here).
for j in range(nrow):
    for i in range(1, ncol):
        moving_image = tile_overlap[f'{i}-{j}']
        assert f'{i}-{j}' in tform_stack_coarse_colrow, f'{i}-{j}'
        for pi, pj in neighbor_col:
            if f'{i+pi}-{j+pj}' in tform_stack_coarse_colrow: 
                reference_image = tile_overlap[f'{i+pi}-{j+pj}']
                # Strip indices follow get_tile_stack_overlap_area: 0/1 = first/last rows.
                if pi < 0:
                    moving_image = moving_image[0].astype(float)
                    reference_image = reference_image[1].astype(float)
                elif pi > 0:
                    moving_image = moving_image[1].astype(float)
                    reference_image = reference_image[0].astype(float)
                
                moving_image = torch.nn.functional.interpolate(torch.from_numpy(moving_image)[None, None], scale_factor=down_r)[0,0].numpy()
                reference_image = torch.nn.functional.interpolate(torch.from_numpy(reference_image)[None, None], scale_factor=down_r)[0,0].numpy()
                print(datetime.now(), f"start phase cross correlation mov ({i},{j}) to ({i+pi},{j+pj})")
                shift, error, diffphase = phase_cross_correlation(reference_image, moving_image, overlap_ratio=0.8)
                shift = [s/down_r for s in shift]

                tz, tx, ty = shift
#                 print(tform_stack_coarse_colrow[f'{i+pi}-{j+pj}'][0])
#                 tz = tz + np.mean(tform_stack_coarse_colrow[f'{i+pi}-{j+pj}'][0])
#                 tx = tx + np.mean(tform_stack_coarse_colrow[f'{i+pi}-{j+pj}'][1])
#                 ty = ty + np.mean(tform_stack_coarse_colrow[f'{i+pi}-{j+pj}'][2])
#                 tz = tz + sum(tform_stack_coarse_colrow[f'{i+pi}-{j+pj}'][0])
#                 tx = tx + sum(tform_stack_coarse_colrow[f'{i+pi}-{j+pj}'][1])
#                 ty = ty + sum(tform_stack_coarse_colrow[f'{i+pi}-{j+pj}'][2])
                shift = [tz, tx, ty]
                print(datetime.now(), f"done phase cross correlation, shift: {shift}, error: {error:.6f}, dphase: {diffphase:.6f}")
                tform_stack_coarse_colrow[f'{i}-{j}'][0].append(tz)
                tform_stack_coarse_colrow[f'{i}-{j}'][1].append(tx)
                tform_stack_coarse_colrow[f'{i}-{j}'][2].append(ty)
                
                break

print(tform_stack_coarse_colrow)

2024-07-20 11:54:10.780035 start phase cross correlation mov (0,1) to (0,0)
2024-07-20 11:54:48.491996 done phase cross correlation, shift: [-2.0, 0.0, -20.0], error: 1.000000, dphase: 0.000000
2024-07-20 11:54:52.643842 start phase cross correlation mov (0,2) to (0,1)
2024-07-20 11:55:18.579352 done phase cross correlation, shift: [-2.0, 4.0, -24.0], error: 1.000000, dphase: 0.000000
2024-07-20 11:55:23.586519 start phase cross correlation mov (0,3) to (0,2)
2024-07-20 11:55:49.901431 done phase cross correlation, shift: [-2.0, -6.0, -10.0], error: 1.000000, dphase: 0.000000
2024-07-20 11:55:54.864749 start phase cross correlation mov (0,4) to (0,3)
2024-07-20 11:56:20.865912 done phase cross correlation, shift: [0.0, -8.0, 2.0], error: 1.000000, dphase: -0.000000
2024-07-20 11:56:25.771668 start phase cross correlation mov (1,1) to (1,0)
2024-07-20 11:56:51.791124 done phase cross correlation, shift: [0.0, 0.0, -16.0], error: 1.000000, dphase: 0.000000
2024-07-20 11:56:56.706638 star

In [15]:
# Display the pairwise coarse shifts collected by the previous cell.
tform_stack_coarse_colrow

{'0-0': [[0], [0], [0]],
 '0-1': [[-2.0], [0.0], [-20.0]],
 '0-2': [[-2.0], [4.0], [-24.0]],
 '0-3': [[-2.0], [-6.0], [-10.0]],
 '0-4': [[0.0], [-8.0], [2.0]],
 '1-0': [[0, 2.0], [0, -22.0], [0, -14.0]],
 '1-1': [[0.0, 2.0], [0.0, -24.0], [-16.0, -8.0]],
 '1-2': [[0.0, 0.0], [4.0, -16.0], [-26.0, 0.0]],
 '1-3': [[0.0, 4.0], [-6.0, -26.0], [-12.0, 8.0]],
 '1-4': [[0.0, 2.0], [-6.0, -22.0], [0.0, 10.0]],
 '2-0': [[0, 2.0], [0, -10.0], [0, -8.0]],
 '2-1': [[0.0, 2.0], [2.0, -12.0], [-14.0, -2.0]],
 '2-2': [[0.0, 4.0], [6.0, -14.0], [-26.0, 0.0]],
 '2-3': [[0.0, 4.0], [-4.0, -18.0], [-10.0, 6.0]],
 '2-4': [[0.0, 2.0], [-6.0, -18.0], [0.0, 10.0]],
 '3-0': [[0, 2.0], [0, -8.0], [0, -6.0]],
 '3-1': [[0.0, 2.0], [2.0, -10.0], [-14.0, -4.0]],
 '3-2': [[0.0, 2.0], [4.0, -14.0], [-24.0, 2.0]],
 '3-3': [[0.0, 2.0], [-4.0, -14.0], [-12.0, 4.0]],
 '3-4': [[0.0, 2.0], [-6.0, -12.0], [0.0, 10.0]]}

In [16]:
# Scratch check: a [[tz], [tx], [ty]] entry converts to shape (3, 1), so [:, 0] takes one value per axis.
np.array([[0], [0], [0]]).shape

(3, 1)

In [48]:
def fuse_colrow_tform(tform_colrow):
    """Chain pairwise tile shifts into one cumulative shift per tile.

    ``tform_colrow`` maps ``'i-j'`` to a list of per-axis lists (three axes for the coarse
    stage, two for the per-slice refinement). In each per-axis list, element 0 is the tile's
    shift relative to tile ``(i, j-1)`` (the seed ``0`` for ``j == 0``). For ``i > 0``,
    element 1 is its shift relative to tile ``(i-1, j)``.

    Fused shift of (i, j) = cumulative sum of element 0 along j
                          + sum over i' in 1..i of the j-average of element 1 of column i'.

    Args:
        tform_colrow: dict as described above. Keys may be in any insertion order.

    Returns:
        dict ``'i-j'`` -> list of floats (one per axis), in the key order of ``tform_colrow``.

    Raises:
        KeyError: if a tile (i, j-1), a column i-1, or a tile of the i-averaging is missing.
    """
    def _ij(key):
        i, j = key.split('-')
        return int(i), int(j)

    # Numeric (i, j) order guarantees (i, j-1) is fused before (i, j) regardless of how the
    # dict was built (plain string order would put '10-0' before '2-0').
    ordered = sorted(tform_colrow, key=_ij)
    i_values = sorted({_ij(k)[0] for k in ordered})
    j_values = sorted({_ij(k)[1] for k in ordered})

    ## accumulate along j
    tform = {}
    for k in ordered:
        i, j = _ij(k)
        tform[k] = np.array(tform_colrow[k])[:, 0]
        tform[k] = tform[k] + (tform[f'{i}-{j-1}'] if j > 0 else 0)

    ## average over j, then accumulate along i
    avg = {}
    for i in i_values:
        if i > 0:
            col_shifts = [np.array(tform_colrow[f'{i}-{j}'])[:, 1] for j in j_values]
            avg[i] = np.stack(col_shifts).mean(0) + avg[i-1]
        else:
            avg[i] = 0

    for k in ordered:
        tform[k] = tform[k] + avg[_ij(k)[0]]

    return {k: [float(t) for t in tform[k]] for k in tform_colrow}

In [49]:
zstart = 0
tformed_tile_lt_loc = {zstart: copy.deepcopy(tile_lt_loc)}
tform_stack_coarse = fuse_colrow_tform(tform_stack_coarse_colrow)
# Offset each tile's nominal left-top corner by its fused coarse (x, y) shift; the z component
# is not applied to the 2D corner positions.
lt_loc_coarse = tformed_tile_lt_loc[0]
for k in tform_stack_coarse_colrow:
    tz, tx, ty = tform_stack_coarse[k]
    corner = lt_loc_coarse[k]
    corner[0] = corner[0] + tx
    corner[1] = corner[1] + ty
    print(k, tformed_tile_lt_loc[zstart][k], tform_stack_coarse[k])


0-0 [tensor(0.), tensor(0.)] [0.0, 0.0, 0.0]
0-1 [tensor(0.), tensor(1585.6000)] [-2.0, 0.0, -20.0]
0-2 [tensor(4.), tensor(3167.2000)] [-4.0, 4.0, -44.0]
0-3 [tensor(-2.), tensor(4762.8003)] [-6.0, -2.0, -54.0]
0-4 [tensor(-10.), tensor(6370.3999)] [-6.0, -10.0, -52.0]
1-0 [tensor(1881.2001), tensor(-0.8000)] [2.0, -22.0, -0.8]
1-1 [tensor(1881.2001), tensor(1588.7999)] [2.0, -22.0, -16.8]
1-2 [tensor(1885.2001), tensor(3168.3999)] [2.0, -18.0, -42.8]
1-3 [tensor(1879.2001), tensor(4762.0005)] [2.0, -24.0, -54.8]
1-4 [tensor(1873.2001), tensor(6367.6001)] [2.0, -30.0, -54.8]
2-0 [tensor(3770.0002), tensor(0.4000)] [4.8, -36.4, 0.3999999999999999]
2-1 [tensor(3772.0002), tensor(1592.)] [4.8, -34.4, -13.6]
2-2 [tensor(3778.0002), tensor(3171.5999)] [4.8, -28.4, -39.6]
2-3 [tensor(3774.0002), tensor(4767.2002)] [4.8, -32.4, -49.6]
2-4 [tensor(3768.0002), tensor(6372.7998)] [4.8, -38.4, -49.6]
3-0 [tensor(5661.6001), tensor(1.6000)] [6.8, -48.0, 1.5999999999999999]
3-1 [tensor(5663.6001),

In [50]:

# Persist the fused coarse transform next to the other stitching outputs.
coarse_json_path = f'{save_path}/NIS_tranform/{stack_name.replace("instance_center.zip", "tform_coarse.json")}'
with open(coarse_json_path, 'w', encoding='utf-8') as f:
    json.dump(tform_stack_coarse, f, ensure_ascii=False, indent=4)

tform_stack_coarse

{'0-0': [0.0, 0.0, 0.0],
 '0-1': [-2.0, 0.0, -20.0],
 '0-2': [-4.0, 4.0, -44.0],
 '0-3': [-6.0, -2.0, -54.0],
 '0-4': [-6.0, -10.0, -52.0],
 '1-0': [2.0, -22.0, -0.8],
 '1-1': [2.0, -22.0, -16.8],
 '1-2': [2.0, -18.0, -42.8],
 '1-3': [2.0, -24.0, -54.8],
 '1-4': [2.0, -30.0, -54.8],
 '2-0': [4.8, -36.4, 0.3999999999999999],
 '2-1': [4.8, -34.4, -13.6],
 '2-2': [4.8, -28.4, -39.6],
 '2-3': [4.8, -32.4, -49.6],
 '2-4': [4.8, -38.4, -49.6],
 '3-0': [6.8, -48.0, 1.5999999999999999],
 '3-1': [6.8, -46.0, -12.4],
 '3-2': [6.8, -42.0, -36.4],
 '3-3': [6.8, -46.0, -48.4],
 '3-4': [6.8, -52.0, -48.4]}

In [51]:
# Reload the coarse transform from disk; the context manager closes the file handle that
# json.load(open(...)) used to leave open.
with open(f'{save_path}/NIS_tranform/{stack_name.replace("instance_center.zip", "tform_coarse.json")}', 'r', encoding='utf-8') as f:
    tform_stack_coarse = json.load(f)

tform_stack_coarse

{'0-0': [0.0, 0.0, 0.0],
 '0-1': [-2.0, 0.0, -20.0],
 '0-2': [-4.0, 4.0, -44.0],
 '0-3': [-6.0, -2.0, -54.0],
 '0-4': [-6.0, -10.0, -52.0],
 '1-0': [2.0, -22.0, -0.8],
 '1-1': [2.0, -22.0, -16.8],
 '1-2': [2.0, -18.0, -42.8],
 '1-3': [2.0, -24.0, -54.8],
 '1-4': [2.0, -30.0, -54.8],
 '2-0': [4.8, -36.4, 0.3999999999999999],
 '2-1': [4.8, -34.4, -13.6],
 '2-2': [4.8, -28.4, -39.6],
 '2-3': [4.8, -32.4, -49.6],
 '2-4': [4.8, -38.4, -49.6],
 '3-0': [6.8, -48.0, 1.5999999999999999],
 '3-1': [6.8, -46.0, -12.4],
 '3-2': [6.8, -42.0, -36.4],
 '3-3': [6.8, -46.0, -48.4],
 '3-4': [6.8, -52.0, -48.4]}

In [52]:
'''
Apply coarse to img
'''
# Shift every overlap strip by its tile's fused coarse offset [tz, tx, ty] (shift_image truncates
# each component to an integer). Strips are moved out of `tile_overlap` into a new dict, so
# re-running this cell fails loudly instead of silently shifting already-shifted strips again
# (or producing an empty result after `tile_overlap` is reset). Each tile's list is updated
# in place, so peak memory is unchanged.
assert tile_overlap, 'tile_overlap is empty: reload the tile strips before applying the coarse transform'
tile_overlap_coarse = {}
for k in tqdm(list(tile_overlap), desc='Apply coarse to img'):
    strips = tile_overlap.pop(k)
    for i, img in enumerate(strips):
        strips[i] = shift_image(torch.from_numpy(img.astype(float)), tform_stack_coarse[k]).numpy()
    tile_overlap_coarse[k] = strips

Apply coarse to img: 100%|██████████████████████████████████████████████████████████████| 20/20 [08:58<00:00, 26.91s/it]


In [53]:
# Rebind tile_overlap to a fresh dict; the shifted strips remain reachable through tile_overlap_coarse.
tile_overlap = {}

In [54]:
# One line per tile: shapes of its [t, b, l, r] strips after the coarse shift.
for k, strips in tile_overlap_coarse.items():
    print([strip.shape for strip in strips])

[(1048, 476, 2007), (1048, 476, 2007), (1048, 2379, 402), (1048, 2379, 402)]
[(1048, 476, 2007), (1048, 476, 2007), (1048, 2379, 402), (1048, 2379, 402)]
[(1048, 476, 2007), (1048, 476, 2007), (1048, 2379, 402), (1048, 2379, 402)]
[(1048, 476, 2007), (1048, 476, 2007), (1048, 2379, 402), (1048, 2379, 402)]
[(1048, 476, 2007), (1048, 476, 2007), (1048, 2379, 402), (1048, 2379, 402)]
[(1048, 476, 2007), (1048, 476, 2007), (1048, 2379, 402), (1048, 2379, 402)]
[(1048, 476, 2007), (1048, 476, 2007), (1048, 2379, 402), (1048, 2379, 402)]
[(1048, 476, 2007), (1048, 476, 2007), (1048, 2379, 402), (1048, 2379, 402)]
[(1048, 476, 2007), (1048, 476, 2007), (1048, 2379, 402), (1048, 2379, 402)]
[(1048, 476, 2007), (1048, 476, 2007), (1048, 2379, 402), (1048, 2379, 402)]
[(1048, 476, 2007), (1048, 476, 2007), (1048, 2379, 402), (1048, 2379, 402)]
[(1048, 476, 2007), (1048, 476, 2007), (1048, 2379, 402), (1048, 2379, 402)]
[(1048, 476, 2007), (1048, 476, 2007), (1048, 2379, 402), (1048, 2379, 402)]

In [55]:
'''
Refine 2D by PCC
'''
tform_stack_refine_colrow = [{} for _ in range(zstart, zend)]
neighbor_row = [[0, -1], [0, 1]]
neighbor_col = [[-1, 0], [1, 0]]
down_r = 1
max_shift = 0.05
for i in range(ncol):
    j = 0
    for zi in range(zstart, zend):
        tform_stack_refine_colrow[zi][f'{i}-{j}'] = [[0], [0]]
    
    for j in range(1, nrow):
        moving_image = tile_overlap_coarse[f'{i}-{j}']
        assert moving_image[0].shape[0] == zend-zstart, moving_image[0].shape[0]
        pre_shift = [0, 0]
        for zi in trange(moving_image[0].shape[0]-1, -1, -1, desc=f"mov ({i},{j})"):
#         for zi in trange(moving_image[0].shape[0], desc=f"mov ({i},{j})"):
#             if moving_image[0][zi].max() < 600: continue
#             if zi == 900: break
            if f'{i}-{j}' not in tform_stack_refine_colrow[zi]:
                tform_stack_refine_colrow[zi][f'{i}-{j}'] = [[], []]
                
            for pi, pj in neighbor_row:
                if f'{i+pi}-{j+pj}' in tform_stack_refine_colrow[zi]: 
                    reference_image = tile_overlap_coarse[f'{i+pi}-{j+pj}']
                    
                    if pj < 0:
                        moving_image2d = moving_image[2]
                        reference_image2d = reference_image[3]
                    elif pj > 0:
                        moving_image2d = moving_image[3]
                        reference_image2d = reference_image[2]
                    
                    reference_image2d = reference_image2d[zi].astype(float)
                    moving_image2d = moving_image2d[zi].astype(float)
                    moving_image2d = torch.nn.functional.interpolate(torch.from_numpy(moving_image2d)[None, None], scale_factor=down_r)[0,0].numpy()
                    reference_image2d = torch.nn.functional.interpolate(torch.from_numpy(reference_image2d)[None, None], scale_factor=down_r)[0,0].numpy()
#                     print(datetime.now(), f"start phase cross correlation mov ({i},{j}) to ({i+pi},{j+pj})")
#                     print(reference_image2d.shape, moving_image2d.shape)
                    shift, error, diffphase = phase_cross_correlation(reference_image2d, moving_image2d, overlap_ratio=0.99)
                    if shift[0] > reference_image2d.shape[0]*max_shift or shift[1] > reference_image2d.shape[1]*max_shift:
                        shift = pre_shift
                    else:
                        shift = [s/down_r for s in shift]

                    tx, ty = shift
#                     tx = tx + tform_stack_refine_colrow[zi][f'{i+pi}-{j+pj}'][0][0]
#                     ty = ty + tform_stack_refine_colrow[zi][f'{i+pi}-{j+pj}'][1][0]
                    shift = [tx, ty]
#                     print(datetime.now(), f"done phase cross correlation, shift: {shift}, error: {error:.6f}, dphase: {diffphase:.6f}")
                    tform_stack_refine_colrow[zi][f'{i}-{j}'][0].append(tx)
                    tform_stack_refine_colrow[zi][f'{i}-{j}'][1].append(ty)
                    pre_shift = shift
                    break
                    
# print(tform_stack_refine_colrow)             
                
for j in range(nrow):
    for i in range(1, ncol):
        moving_image = tile_overlap_coarse[f'{i}-{j}']
        for zi in trange(moving_image[0].shape[0]-1, -1, -1, desc=f"mov ({i},{j})"):
#         for zi in trange(moving_image[0].shape[0], desc=f"mov ({i},{j})"):
#             if moving_image[0][zi].max() < 700: continue
#             if zi == 900: break
            pre_shift = [0, 0]
            assert f'{i}-{j}' in tform_stack_refine_colrow[zi], f'{i}-{j}'
            for pi, pj in neighbor_col:
                if f'{i+pi}-{j+pj}' in tform_stack_refine_colrow[zi]: 
                    reference_image = tile_overlap_coarse[f'{i+pi}-{j+pj}']
                    if pi < 0:
                        moving_image2d = moving_image[0]
                        reference_image2d = reference_image[1]
                    elif pi > 0:
                        moving_image2d = moving_image[1]
                        reference_image2d = reference_image[0]

                    reference_image2d = reference_image2d[zi].astype(float)
                    moving_image2d = moving_image2d[zi].astype(float)
                    moving_image2d = torch.nn.functional.interpolate(torch.from_numpy(moving_image2d)[None, None], scale_factor=down_r)[0,0].numpy()
                    reference_image2d = torch.nn.functional.interpolate(torch.from_numpy(reference_image2d)[None, None], scale_factor=down_r)[0,0].numpy()
#                     print(datetime.now(), f"start phase cross correlation mov ({i},{j}) to ({i+pi},{j+pj})")
                    shift, error, diffphase = phase_cross_correlation(reference_image2d, moving_image2d, overlap_ratio=0.99)
                    if shift[0] > reference_image2d.shape[0]*max_shift or shift[1] > reference_image2d.shape[1]*max_shift:
                        shift = pre_shift
                    else:
                        shift = [s/down_r for s in shift]

                    tx, ty = shift
#                     tx = tx + np.mean(tform_stack_refine_colrow[zi][f'{i+pi}-{j+pj}'][0])
#                     ty = ty + np.mean(tform_stack_refine_colrow[zi][f'{i+pi}-{j+pj}'][1])
#                     tx = tx + sum(tform_stack_refine_colrow[zi][f'{i+pi}-{j+pj}'][0])
#                     ty = ty + sum(tform_stack_refine_colrow[zi][f'{i+pi}-{j+pj}'][1])
                    shift = [tx, ty]
#                     print(datetime.now(), f"done phase cross correlation, shift: {shift}, error: {error:.6f}, dphase: {diffphase:.6f}")
                    tform_stack_refine_colrow[zi][f'{i}-{j}'][0].append(tx)
                    tform_stack_refine_colrow[zi][f'{i}-{j}'][1].append(ty)
                    pre_shift = shift
                    break

# print(tform_stack_coarse_colrow)

  (src_amp * target_amp)
mov (0,1): 100%|████████████████████████████████████████████████████████████████████| 1048/1048 [02:21<00:00,  7.40it/s]
mov (0,2): 100%|████████████████████████████████████████████████████████████████████| 1048/1048 [02:22<00:00,  7.38it/s]
mov (0,3): 100%|████████████████████████████████████████████████████████████████████| 1048/1048 [02:24<00:00,  7.27it/s]
mov (0,4): 100%|████████████████████████████████████████████████████████████████████| 1048/1048 [02:22<00:00,  7.37it/s]
mov (1,1): 100%|████████████████████████████████████████████████████████████████████| 1048/1048 [02:20<00:00,  7.48it/s]
mov (1,2): 100%|████████████████████████████████████████████████████████████████████| 1048/1048 [02:21<00:00,  7.40it/s]
mov (1,3): 100%|████████████████████████████████████████████████████████████████████| 1048/1048 [02:22<00:00,  7.38it/s]
mov (1,4): 100%|████████████████████████████████████████████████████████████████████| 1048/1048 [02:22<00:00,  7.36it/s]
mov (2,

In [56]:
# tformed_tile_lt_loc holds a single z key (zstart) from the coarse stage.
len(tformed_tile_lt_loc)

1

In [57]:
# Fuse the per-slice pairwise (tx, ty) shifts and apply them on top of the coarse corners.
tform_stack_refine = [{} for _ in range(len(tform_stack_refine_colrow))]
tformed_tile_lt_loc_refine = {zi: copy.deepcopy(tformed_tile_lt_loc[0]) for zi in range(zstart, zend)}
# Iterate over the slices actually present in tform_stack_refine_colrow (indexed by slice
# offset), rather than the depth of whatever `moving_image` the previous cell left behind.
for zi in range(len(tform_stack_refine_colrow)-1, -1, -1):
    tform_stack_refine[zi] = fuse_colrow_tform(tform_stack_refine_colrow[zi])
    # tformed_tile_lt_loc_refine is keyed by absolute z (same as zi when zstart == 0).
    lt_loc_zi = tformed_tile_lt_loc_refine[zstart + zi]
    for k in tform_stack_refine_colrow[zi]:
        tx, ty = tform_stack_refine[zi][k]
        lt_loc_zi[k][0] = lt_loc_zi[k][0] + tx
        lt_loc_zi[k][1] = lt_loc_zi[k][1] + ty


In [None]:

# Persist the per-slice fused refinement (a list indexed by slice offset).
refine_json_path = f'{save_path}/NIS_tranform/{stack_name.replace("instance_center.zip", "tform_refine.json")}'
with open(refine_json_path, 'w', encoding='utf-8') as f:
    json.dump(tform_stack_refine, f, ensure_ascii=False, indent=4)

tform_stack_refine[950]

In [59]:

class image_stitch_QCer:
    """Per-pixel blending weights for pasting a tile over content already on the canvas.

    ``mov_adaptive_grad`` weights the incoming (moving) tile. Along axis 0 it rises linearly
    over the first ``overlap_w`` entries, is 1 in the middle, and falls over the last
    ``overlap_w`` entries. Axis 1 does the same with ``overlap_h``. The 2-D weight is the outer
    product. ``tgt_adaptive_grad`` is built from the complementary ramps, so the two weights sum
    to 1 at every pixel (up to float rounding).
    """

    def __init__(self, seg_shape, overlap_w, overlap_h):
        """Precompute the weight grids.

        Args:
            seg_shape: indexable; ``seg_shape[1]`` / ``seg_shape[2]`` give the tile extent along
                axis 0 / axis 1 (presumably the (z, x, y) shape of a segmentation stack -- confirm).
            overlap_w: ramp length along axis 0.
            overlap_h: ramp length along axis 1.
        """
        def ramps(overlap):
            # rise: 0, 1/o, ..., (o-1)/o    fall: 1, (o-1)/o, ..., 1/o
            rise = torch.arange(overlap).float() / overlap
            fall = torch.arange(overlap, 0, -1) / overlap
            return rise, fall

        x_rise, x_fall = ramps(overlap_w)
        y_rise, y_fall = ramps(overlap_h)
        x_inner = seg_shape[1] - 2 * overlap_w
        y_inner = seg_shape[2] - 2 * overlap_h

        x_mov = torch.cat([x_rise, torch.ones(x_inner), x_fall])
        y_mov = torch.cat([y_rise, torch.ones(y_inner), y_fall])
        x_tgt = torch.cat([x_fall, torch.zeros(x_inner), x_rise])
        y_tgt = torch.cat([y_fall, torch.zeros(y_inner), y_rise])

        self.mov_adaptive_grad = x_mov[:, None] * y_mov[None, :]
        self.tgt_adaptive_grad = (x_tgt[:, None] * y_mov[None, :]
                                  + x_mov[:, None] * y_tgt[None, :]
                                  + x_tgt[:, None] * y_tgt[None, :])

    def adaptive_image_stitch(self, mov_overlap, tgt_overlap, overlap_mask):
        """Blend overlapping pixel values with the precomputed weights.

        Args:
            mov_overlap: values of the moving tile at the True pixels of ``overlap_mask``.
            tgt_overlap: canvas values at the same pixels.
            overlap_mask: 2-D mask, or 3-D mask whose leading axis is ignored; its last two
                axes index the weight grids.

        Returns:
            numpy array ``w_mov * mov_overlap + w_tgt * tgt_overlap``, ordered as ``np.where``
            enumerates the mask (row-major).
        """
        coords = np.where(overlap_mask)
        rows, cols = coords[1:] if len(coords) == 3 else coords
        mov_weight = self.mov_adaptive_grad[rows, cols].numpy()
        tgt_weight = self.tgt_adaptive_grad[rows, cols].numpy()
        return mov_weight * mov_overlap + tgt_weight * tgt_overlap
        

In [62]:
# Display the coarse per-tile transforms. The stitching cell reads element 0 of each entry as the
# tile's z offset (rounded to whole slices); the other two elements are presumably x/y shifts -- TODO confirm.
tform_stack_coarse

{'0-0': [0.0, 0.0, 0.0],
 '0-1': [-2.0, 0.0, -20.0],
 '0-2': [-4.0, 4.0, -44.0],
 '0-3': [-6.0, -2.0, -54.0],
 '0-4': [-6.0, -10.0, -52.0],
 '1-0': [2.0, -22.0, -0.8],
 '1-1': [2.0, -22.0, -16.8],
 '1-2': [2.0, -18.0, -42.8],
 '1-3': [2.0, -24.0, -54.8],
 '1-4': [2.0, -30.0, -54.8],
 '2-0': [4.8, -36.4, 0.3999999999999999],
 '2-1': [4.8, -34.4, -13.6],
 '2-2': [4.8, -28.4, -39.6],
 '2-3': [4.8, -32.4, -49.6],
 '2-4': [4.8, -38.4, -49.6],
 '3-0': [6.8, -48.0, 1.5999999999999999],
 '3-1': [6.8, -46.0, -12.4],
 '3-2': [6.8, -42.0, -36.4],
 '3-3': [6.8, -46.0, -48.4],
 '3-4': [6.8, -52.0, -48.4]}

In [66]:
# Apply the stitching transforms to the raw light-sheet (LS) tiles and save one stitched image
# per z slice for visual QC.

# A refined (x, y) shift larger than 5% of the tile extent is treated as implausible.
tform_xy_max = [0.05*seg_shape[1], 0.05*seg_shape[2]]

tile_w = int(seg_shape[1].item()*(1-overlap_r))
tile_h = int(seg_shape[2].item()*(1-overlap_r))
overlap_w = int(seg_shape[1].item()*overlap_r)
overlap_h = int(seg_shape[2].item()*overlap_r)
max_pixel = None  # running max of per-tile 99th-percentile intensities, used for normalization
image_stitcher = image_stitch_QCer(seg_shape, overlap_w, overlap_h)
os.makedirs(f'{save_path}/LS_image_stitched', exist_ok=True)
pre_startx, pre_starty = {}, {}  # last placement of each tile, reused when its shift is implausible


def _tile_z_offset(ijstr):
    """Coarse z offset of tile ``'i-j'``, rounded to a whole number of slices."""
    i, j = map(int, ijstr.split('-'))
    return int(np.around(tform_stack_coarse[f'{i}-{j}'][0]))


def _shift_within_limit(tform_z, ijstr):
    """True when tile ``ijstr``'s refined (x, y) shift in one slice is within ``tform_xy_max``.

    NaN shifts compare False and are therefore treated as out of limit.
    """
    shift = tform_z[ijstr]
    return abs(shift[0]) <= tform_xy_max[0] and abs(shift[1]) <= tform_xy_max[1]


def _resolve_tile_start(zi, wsii, tz, ijstr):
    """Return the top-left (x, y) at which tile ``ijstr`` is pasted in stitched slice ``zi``.

    The refined location is used when the refine shift is plausible. Otherwise the tile's previous
    placement is reused. Failing that, the refined location of the next plausible slice is used,
    and finally the unrefined base location.
    """
    # NOTE(review): the plausibility check reads the refine transform at the tile's own slice
    # `wsii`, while the location comes from key `zi` (= wsii + tz), which the refine cell built
    # from tform_stack_refine[zi]. Confirm which index is intended.
    if _shift_within_limit(tform_stack_refine[wsii], ijstr):
        return tformed_tile_lt_loc_refine[zi][ijstr]
    if pre_startx[ijstr] is not None:
        return pre_startx[ijstr], pre_starty[ijstr]
    for zii in range(wsii, len(tform_stack_refine)):
        if zii + tz not in tformed_tile_lt_loc_refine:
            continue  # shifted slice falls outside the stitched z range
        if _shift_within_limit(tform_stack_refine[zii], ijstr):
            return tformed_tile_lt_loc_refine[zii + tz][ijstr]
    # No plausible refinement found: use the base location the refine cell started from,
    # instead of silently reusing another tile's position.
    return tformed_tile_lt_loc[0][ijstr]


def _paste_tile(wsi, tile_img, startx, starty):
    """Paste ``tile_img`` into ``wsi`` in place, with its top-left at (startx, starty).

    The tile is cropped to the canvas. Where the canvas already has content (> 0), the two are
    blended with ``image_stitcher``'s weights.
    """
    startx, starty = int(startx), int(starty)
    # Part of the tile that lands on the canvas: offset inside the tile, and canvas window.
    crop_x, crop_y = max(-startx, 0), max(-starty, 0)
    x0, y0 = max(startx, 0), max(starty, 0)
    x1 = min(startx + tile_img.shape[0], wsi.shape[0])
    y1 = min(starty + tile_img.shape[1], wsi.shape[1])
    tile_crop = tile_img[crop_x:crop_x + (x1 - x0), crop_y:crop_y + (y1 - y0)]

    cur_tile = wsi[x0:x1, y0:y1]
    fg_mask = cur_tile > 0
    if fg_mask.any():
        # The blending weights are laid out over the full, uncropped tile, so express the
        # overlap mask in full-tile coordinates. np.where enumerates it in the same row-major
        # order as boolean indexing, so values and weights stay aligned.
        full_mask = np.zeros(tile_img.shape, dtype=bool)
        full_mask[crop_x:crop_x + fg_mask.shape[0], crop_y:crop_y + fg_mask.shape[1]] = fg_mask
        tile_crop[fg_mask] = image_stitcher.adaptive_image_stitch(
            tile_crop[fg_mask], cur_tile[fg_mask], full_mask)
    wsi[x0:x1, y0:y1] = tile_crop


for zi in tqdm(tformed_tile_lt_loc_refine, desc='Apply transformation'):
    n_z = len(tformed_tile_lt_loc_refine)
    # Skip slices for which any tile's z-shifted source slice is out of range.
    if not all(0 <= zi - _tile_z_offset(ijstr) < n_z for ijstr in tformed_tile_lt_loc_refine[zi]):
        continue
    wsi = np.zeros((tile_w*ncol+overlap_w+1, tile_h*nrow+overlap_h+1), dtype=np.float32)
    for ijstr in tformed_tile_lt_loc_refine[zi]:
        pre_startx.setdefault(ijstr, None)
        pre_starty.setdefault(ijstr, None)
        i, j = map(int, ijstr.split('-'))
        tz = _tile_z_offset(ijstr)
        wsii = zi - tz  # source slice index of this tile

        fn = f'{fn_}_UltraII[{i:02d} x {j:02d}]_C01_xyz-Table Z{wsii:04d}.ome.tif'
        with Image.open(f'{ls_image_root}/{fn}') as tile_pil:  # release the file handle promptly
            tile_raw = np.asarray(tile_pil)
        max_pix = np.percentile(tile_raw, 99)
        max_pixel = max_pix if max_pixel is None else max(max_pix, max_pixel)
        tile_img = (tile_raw-tile_raw.min())/(max_pixel-tile_raw.min())

        startx, starty = _resolve_tile_start(zi, wsii, tz, ijstr)
        pre_startx[ijstr], pre_starty[ijstr] = startx, starty
        _paste_tile(wsi, tile_img, startx, starty)

    save_fn = f'{save_path}/LS_image_stitched/{btag.split("_")[1]}_TOPRO_C01_Z{zi:04d}.ptreg_stitch.tif'
    Image.fromarray(wsi).save(save_fn)

        
    

Apply transformation: 100%|███████████████████████████████████████████████████████| 1048/1048 [1:38:37<00:00,  5.65s/it]


In [None]:
# Leftover debug inspection: refined transform entry of the last tile visited by the stitching
# loop (relies on `zi` and `ijstr` leaking out of that loop).
tform_stack_refine[zi][ijstr]