# stitch the sections based on the transformations learned in napari
- this notebook accounts for the out-of-memory (OOM) problems

In [39]:
import pickle
import numpy as np
import glob, sys, os
import tifffile
import scipy.ndimage as ndi
from skimage.transform import warp
import matplotlib.pyplot as plt

def pad_2d_OLD(a, out_shape):
    """Crop and/or zero-pad a 2-D array to ``out_shape`` (legacy helper).

    Each axis is handled independently: an axis longer than the target is
    cropped (leading entries kept), and an axis shorter than the target is
    filled with zeros at the end. The data stays anchored at ``[0, 0]``.

    Parameters
    ----------
    a : array_like, 2-D
        Input plane.
    out_shape : sequence of two ints
        Target ``(rows, cols)`` shape.

    Returns
    -------
    numpy.ndarray
        A new float64 array of shape ``out_shape``.
    """
    a = np.asarray(a)
    assert a.ndim == 2
    assert len(out_shape) == a.ndim
    out = np.zeros(out_shape)
    # Overlapping region of input and output; everything else stays zero.
    rows = min(a.shape[0], out_shape[0])
    cols = min(a.shape[1], out_shape[1])
    out[:rows, :cols] = a[:rows, :cols]
    return out


def pad_2d(a, o_shape=None, *, out_shape=None):
    """Crop and/or zero-pad a 2-D array to a target shape.

    Each axis is handled independently: an axis longer than the target is
    cropped (leading entries kept), and an axis shorter than the target is
    padded with zeros at the end. The data stays anchored at ``[0, 0]``.

    Parameters
    ----------
    a : array_like, 2-D
        Input plane.
    o_shape : sequence of two ints
        Target ``(rows, cols)`` shape.
    out_shape : sequence of two ints, keyword-only
        Alias for ``o_shape``, so ``pad_2d(a, out_shape=shape)`` also works.

    Returns
    -------
    numpy.ndarray
        A new float64 array of shape ``o_shape``. It is always a copy, never a
        view of ``a``.

    Raises
    ------
    TypeError
        If neither or both of ``o_shape`` / ``out_shape`` are given.
    """
    if o_shape is None:
        o_shape = out_shape
    elif out_shape is not None:
        raise TypeError("pad_2d() got both 'o_shape' and 'out_shape'")
    if o_shape is None:
        raise TypeError("pad_2d() missing required argument 'o_shape'")

    a = np.asarray(a)
    assert a.ndim == 2
    out = np.zeros(o_shape)
    # Overlapping region of input and target; the rest of `out` stays zero.
    rows = min(a.shape[0], o_shape[0])
    cols = min(a.shape[1], o_shape[1])
    out[:rows, :cols] = a[:rows, :cols]
    return out

# set the path

In [None]:
# Input/output locations -- edit these for your machine.

# Directory holding one sub-directory per channel; each contains the
# per-section tifs (`*section{N}.tif*`).
file_dir = '/pathtodownloadfile/tifs/'

# Directory holding the napari registration results that are READ below
# (`section_{i}_{i+1}_*.pkl`, with the learned transformations).
base_dir = '/pathtodownloadfile/napari_rez/'

# Directory where the transformed sections are written (default: same as base_dir).
out_dir = '/pathtodownloadfile/napari_rez/'

# Paths used on the original analysis machine, kept for reference:
# file_dir = '/home/ubuntu/largevolume2/massimo_xinxin/data/G128/exvivo/'
# base_dir = '/home/ubuntu/largevolume2/massimo_xinxin/napari_rez/G128/MERFISH_transformation/'

if not os.path.exists(out_dir):
    os.makedirs(out_dir)
    print("The new directory is created!")

# set the number of sections 

In [None]:
# Sections to process. Section 1 is the reference frame; sections 2..N are
# mapped onto it further below. (Use e.g. [1, 2, 3] for a quick test run.)
section_range = list(range(1, 13))

# One sub-directory per channel. Sort so that the channel order (and
# `all_files[0]`, used for the output shape) is reproducible: glob.glob
# returns entries in an arbitrary, filesystem-dependent order.
all_files = sorted(glob.glob(os.path.join(file_dir, '*')))
# The channel name is the 2nd '_'-separated token of the directory name.
channel_name = [os.path.basename(f).split('_')[1] for f in all_files]
print(channel_name)
C = len(channel_name)

# set the output shape

In [None]:
# Output (z, x, y) shape: the element-wise maximum over all section stacks of
# the first channel directory, so every section fits after cropping/padding.
# Only the TIFF metadata is read (no pixel data), which keeps memory low.
f = all_files[0]
shape_list = []
for f_path in glob.glob(f + '/*section*.tif*'):
    with tifffile.TiffFile(f_path) as tif:
        shape_list.append(np.array(tif.series[0].shape))
if not shape_list:
    raise FileNotFoundError(f'no *section*.tif* files found in {f}')
out_shape = np.max(np.array(shape_list), 0)
print(out_shape)

# load all the transformations (from the napari results)

In [None]:
# T[i] is the list of transformation steps saved by napari for the section
# pair (i-1, i), i.e. file `section_{i-1}_{i}_<date>.pkl`. The loop below
# applies T[i], T[i-1], ..., T[2] to section i.
rez_date = '021923'  # date suffix of the napari result files
T = dict()
for section_id in section_range[:-1]:
    rez_path = os.path.join(base_dir, f'section_{section_id}_{section_id+1}_{rez_date}.pkl')
    # NOTE: pickle can execute code on load -- only open trusted result files.
    with open(rez_path, 'rb') as fh:
        rez = pickle.load(fh)
    T[section_id + 1] = rez['transformations']

# `trans` expects the last step of each pair to be the total vector field.
for steps in T.values():
    assert list(steps[-1].keys())[0] == 'vec_field_total'

## function that applies `T`s

In [None]:
def trans(source, T_dict, o_shape):
    '''Apply one pair's napari transformation (affine, then deformable) to a stack.

    Parameters
    ----------
    source : ndarray, shape (c, z, x, y)
        Multi-channel stack to transform. z can be 8, 16, 24, ...
    T_dict : sequence of dict
        Transformation steps as loaded from napari (``rez['transformations']``).
        Despite the name, it is iterated as a sequence. Recognised keys per step:
        ``'bhat'`` (extended with a ``(0, 0, 1)`` column and composed into the
        3x3 matrix ``B``), ``'scale'`` (multiplies the first two columns of
        ``B``) and ``'vec_field_total'`` (an ``(x, y, 2)`` displacement field).
    o_shape : sequence of 3 ints
        Output shape ``(z, x, y)``, e.g.
        ``(source.shape[1],) + exvivos[section_id][0].shape[1:]``.

    Returns
    -------
    ndarray, shape (c, *o_shape)
        Every channel of ``source``, first resampled with
        ``ndi.affine_transform`` and then with ``skimage.transform.warp``,
        both with cubic interpolation (order=3).
    '''
    assert len(source.shape) == 4  # (c, z, x, y)
    o_shape = tuple(int(s) for s in o_shape)
    n_channels = source.shape[0]  # previously hard-coded to 2

    # Compose the affine matrix and accumulate the total displacement field.
    all_vec_f_3 = np.zeros(o_shape + (3,))
    B = np.eye(3)
    for step in T_dict:
        for k, v in step.items():
            if k == 'bhat':
                B = B @ np.c_[v, np.array((0, 0, 1))]
            if k == 'scale':
                B[:, :2] *= v
            if k == 'vec_field_total':
                # Target shape passed positionally to match pad_2d's `o_shape` parameter.
                vf = np.stack([pad_2d(v[..., c], o_shape[1:]) for c in range(2)], -1)
                # Zero displacement along z: prepend a zero component -> (x, y, 3),
                # which then broadcasts over every z plane.
                vf_3 = np.concatenate((np.zeros(vf.shape[:-1])[..., None], vf), 2)
                all_vec_f_3 += vf_3

    # Affine step acts on (x, y) only; the z axis maps to itself.
    R_3 = np.eye(3)
    R_3[1:, 1:] = np.linalg.inv(B[:2, :2]).T
    offset_3 = np.zeros(3)
    offset_3[1:] = -B[-1, :-1] @ np.linalg.inv(B[:2, :2])
    print('running rigid..')
    transformed_all = np.array([
        ndi.affine_transform(source[c].astype('float32'), R_3, offset=offset_3,
                             output_shape=o_shape, order=3)
        for c in range(n_channels)])

    # Sampling coordinates for the deformable step: identity grid minus the
    # displacement field. Built once and shared by all channels.
    grid_z, grid_x, grid_y = np.meshgrid(np.arange(o_shape[0]), np.arange(o_shape[1]),
                                         np.arange(o_shape[2]), indexing='ij')
    coords = np.array((grid_z - all_vec_f_3[:, :, :, 0],
                       grid_x - all_vec_f_3[:, :, :, 1],
                       grid_y - all_vec_f_3[:, :, :, 2]))
    del grid_z, grid_x, grid_y, all_vec_f_3  # free memory before warping
    print('running deformable..')
    deformed_all = np.array([warp(transformed_all[c], coords, order=3)
                             for c in range(n_channels)])
    return deformed_all

# now apply the transformation to all the sections

In [None]:
# Bring every section >= 2 into the frame of section 1 by chaining the pairwise
# transformations T[section_id], T[section_id-1], ..., T[2], then save it.
for section_id in range(2, section_range[-1] + 1):
    print(f'...running for sec{section_id}...')
    all_channels_img = []
    for c in range(C):
        f = all_files[c]
        f_path = glob.glob(f + f'/*section{section_id}.tif*')
        assert len(f_path) == 1
        img = tifffile.imread(f_path[0]).astype(float)
        img /= img.max()  # normalise each channel to its own maximum
        all_channels_img.append(img)
    source = np.stack(all_channels_img, axis=0)
    del all_channels_img  # the per-channel arrays are now duplicated in `source`
    print(source.shape)

    # No defensive copies: `trans` does not modify its input.
    moved = source
    del source
    for sec_id_p in range(section_id, 1, -1):
        print(f'...applying T({sec_id_p})...')
        before, moved = moved, trans(moved, T[sec_id_p], out_shape)
        z_show = min(4, before.shape[1] - 1, moved.shape[1] - 1)
        plt.figure(figsize=(20, 10))
        plt.subplot(1, 2, 1)
        plt.imshow(before[0][z_show])
        plt.subplot(1, 2, 2)
        plt.imshow(moved[0][z_show])
        # Render now; otherwise one open figure per step accumulates in memory.
        plt.show()
        del before
    out = moved
    print(out.shape)
    # Write to out_dir under the file name that is read back as
    # out_dir/tranformed_sec{N}.pkl (it previously went to a hard-coded path).
    with open(os.path.join(out_dir, f'tranformed_sec{section_id}.pkl'), 'wb') as fh:
        pickle.dump(out, fh)


# load all transformed results

In [None]:
# Collect every section, keyed 'sec{N}'. Section 1 is the reference: it is only
# normalised and cropped/zero-padded in x/y to out_shape[1:]. All other
# sections are loaded from out_dir/tranformed_sec{N}.pkl.
data_dict = dict()
for section_id in section_range:
    print(section_id)
    if section_id == 1:
        all_channels_img = []
        for c in range(C):
            f = all_files[c]
            f_path = glob.glob(f + f'/*section{section_id}.tif*')
            assert len(f_path) == 1
            img = tifffile.imread(f_path[0]).astype(float)
            img /= img.max()
            # Build a new (z, x, y) stack: padded planes of size out_shape[1:]
            # generally do not fit back into `img`, which has the raw x/y size.
            img = np.stack([pad_2d(plane, out_shape[1:]) for plane in img])
            all_channels_img.append(img)
        data_dict[f'sec{section_id}'] = np.stack(all_channels_img, axis=0)
    else:
        f_path = os.path.join(out_dir, f'tranformed_sec{section_id}.pkl')
        with open(f_path, 'rb') as fh:
            data_dict[f'sec{section_id}'] = pickle.load(fh)
    print(data_dict[f'sec{section_id}'].shape)

### check the results

In [None]:
# Quick visual check: z-plane 4 of channel `c` for sections 1-4 in a 2x2 grid,
# with the colour range clipped to the 50% .. 99.99% intensity quantiles.
c = 0
plt.figure(figsize=(20, 20))
for panel, sec in enumerate(range(1, 5), start=1):
    key = f'sec{sec}'
    plane = data_dict[key][c][4].astype(float)
    lo = np.quantile(plane, 0.5)
    hi = np.quantile(plane, 0.9999)
    plt.subplot(2, 2, panel)
    plt.imshow(plane, vmin=lo, vmax=hi)
    plt.title(key)
plt.tight_layout()

### stitch. 

In [None]:
# Stack all sections along z (axis 1) with a single concatenate, in the order
# [sec_N, ..., sec_2, sec1]; the z order inside each section is unchanged.
# Concatenating once avoids re-copying the growing volume for every section
# (quadratic copying and a higher peak memory).
parts = [data_dict[f'sec{section_id}'] for section_id in section_range[1:][::-1]]
parts.append(data_dict['sec1'])
vol_out_p = np.concatenate(parts, axis=1)
print(vol_out_p.shape)