In [2]:
from datetime import date
import tifffile
import napari
import os,glob
import pprint
import numpy as np
import pickle
import matplotlib.pyplot as plt
np.set_printoptions(suppress=True)
import scipy.ndimage as ndi
# from skimage.transform import warp


from datetime import date

# Date stamp in MMDDYY form; it is embedded in the name of the saved results file.
today = date.today()
todaystamp = f"{today:%m%d%y}"
print(todaystamp)


def wahba(X,Y):
    """Least-squares rigid alignment of two matched point sets (Wahba / Kabsch).

    Finds a proper rotation ``R`` (det = +1) and a translation ``T`` such that
    ``X @ R + T`` is as close as possible to ``Y`` in the least-squares sense.
    Points are ROW vectors. Works for any dimension ``D``; the original
    implementation was limited to ``D == 3`` and gives identical results there.

    Parameters
    ----------
    X, Y : array-like of shape (N, D)
        Matched points; row ``i`` of ``X`` corresponds to row ``i`` of ``Y``.

    Returns
    -------
    numpy.ndarray of shape (D + 1, D)
        ``R`` stacked on top of ``T``, so that ``np.c_[X, ones] @ result``
        maps points from the first input onto the second input.

    Raises
    ------
    ValueError
        If the inputs are not 2-D arrays of identical shape.
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    if X.ndim != 2 or X.shape != Y.shape:
        raise ValueError(
            f"X and Y must both have shape (N, D); got {X.shape} and {Y.shape}"
        )
    dim = X.shape[1]

    # Remove the centroids so that only the rotation is left to estimate.
    X0 = X - X.mean(axis=0)
    Y0 = Y - Y.mean(axis=0)

    # SVD of the cross-covariance matrix.
    U, _, Vt = np.linalg.svd(X0.T @ Y0)

    # Force det(R) = +1: flip the last singular direction when U @ Vt would be
    # a reflection (det(Vt) equals det(V)).
    M = np.eye(dim)
    M[-1, -1] = np.linalg.det(U) * np.linalg.det(Vt)
    R = U @ M @ Vt

    # Translation is the mean residual once the rotation has been applied.
    T = np.mean(Y - X @ R, axis=0)
    return np.r_[R, T[None, :]]

100123


# load data

In [25]:
invivo_f = '/Users/shuonan/Dropbox/project/multimodal_2023/xinxindata/G128_data/G128_invivos.pkl'
exvivo_f = '/Users/shuonan/Dropbox/project/multimodal_2023/xinxindata/G128_data/G128_exvivos.pkl'

# SECURITY: pickle.load can execute arbitrary code -- only open files you
# created yourself or otherwise trust.
# Context managers make sure the file handles are closed once loading is done.
with open(invivo_f, 'rb') as fh:
    invivo = pickle.load(fh)
# 144um, because each z is 1.5 um!

# NOTE(review): `exvivo` is used by the following cells but was never assigned
# in the notebook as saved (only `exvivo_f` was defined). Loading it here is the
# presumed intent -- confirm that no extra preprocessing (e.g. cropping) was lost.
with open(exvivo_f, 'rb') as fh:
    exvivo = pickle.load(fh)

In [31]:
# Display the shapes of exvivo and invivo together as a tuple.
exvivo.shape, invivo.shape

((2, 72, 1147, 1464), (2, 222, 768, 1035))

# specify save directory 

In [32]:
# Output directory for the results pickle written by save_rez (key 'q').
# Keep the trailing '/' so that a path built by plain string concatenation stays valid.
base_dir='/Users/shuonan/Downloads/'

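# run: open the napari viewer and register exvivo to invivo (keys: s = scale, t = rigid transform, d = deform, c = clear points, q = save)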

In [33]:
# Channel labels; axis 0 of both image stacks is split into one layer per label.
channel_name = ["BV", 'tomatos']
C = len(channel_name)

# Initial (empty) 'label' feature table handed to the point layers.
features = {'label': np.empty(0, dtype=int)}
# Log of the transforms applied so far; the key handlers append to it.
all_transform = []

# Invivo stack: one layer per channel, named "invivo <channel>".
viewer = napari.view_image(
    invivo,
    channel_axis=0,
    name=[f'invivo {c}' for c in channel_name],
    colormap='red',
    visible=[True, False],
)

# Exvivo stack in the same viewer, named "exvivo <channel>".
ex_layer = viewer.add_image(
    exvivo,
    channel_axis=0,
    name=[f'exvivo {c}' for c in channel_name],
    colormap="green",
    visible=[True, False],
)
# Points layer holding the landmarks picked on the invivo volume; every point
# displays its 'label' feature as text.
pl_in = viewer.add_points(
    size=10,
    edge_width=1,
    edge_color='red',
    face_color='transparent',
    name="invivo points",
    text='label',
    features=features,
    ndim=3,
    out_of_slice_display=True,
)


@pl_in.events.data.connect
def update_feature_default_invivo():
    """Renumber the invivo landmarks 1..N and set the next point's label to N+1.

    Connected to the layer's ``data`` event and also called once directly
    below, so the labels are refreshed whenever the point data changes.
    Reads and mutates the module-level ``pl_in`` layer; returns nothing.
    """
    no_of_points = len(pl_in.data)
    # Label that the next added point will receive.
    pl_in.feature_defaults['label'] = no_of_points + 1
    # Rewrite the labels of the existing points so they stay consecutive.
    pl_in.properties["label"][0:no_of_points] = range(1, no_of_points+1)
    pl_in.text.values[0:no_of_points] = [str(i) for i in range(1, no_of_points+1)]
    pl_in.text.color = 'red'
    pl_in.text.translation = np.array([-10, 0])


# Initialise the label state and start in point-adding mode.
update_feature_default_invivo()
pl_in.mode = 'add'

# Points layer holding the landmarks picked on the exvivo volume; every point
# displays its 'label' feature as text.
pl_ex = viewer.add_points(
    size=10,
    edge_width=1,
    edge_color='green',
    face_color='transparent',
    name="exvivo points",
    text='label',
    features=features,
    ndim=3,
    out_of_slice_display=True,
)


@pl_ex.events.data.connect
def update_feature_default_exvivo():
    """Renumber the exvivo landmarks 1..N and set the next point's label to N+1.

    Connected to the layer's ``data`` event and also called once directly
    below, so the labels are refreshed whenever the point data changes.
    Reads and mutates the module-level ``pl_ex`` layer; returns nothing.
    """
    no_of_points = len(pl_ex.data)
    # Label that the next added point will receive.
    pl_ex.feature_defaults['label'] = no_of_points + 1
    # Rewrite the labels of the existing points so they stay consecutive.
    pl_ex.properties["label"][0:no_of_points] = range(1, no_of_points+1)
    pl_ex.text.values[0:no_of_points] = [str(i) for i in range(1, no_of_points+1)]
    pl_ex.text.color = 'green'
    pl_ex.text.translation = np.array([-10, 0])


# Initialise the label state and start in point-adding mode.
update_feature_default_exvivo()
pl_ex.mode = 'add'


def _rms_spread(points):
    """Root-mean-square distance of the points from their centroid."""
    return np.sqrt(np.sum((points - np.mean(points, 0)) ** 2) / len(points))


@viewer.bind_key('s', overwrite = True)
def scale(viewer):
    """Key 's': estimate one isotropic scale from the landmarks and apply it to exvivo.

    The scale is the ratio of the RMS spread of the invivo landmarks to that of
    the exvivo landmarks. Every exvivo image channel is resampled with
    ``ndi.zoom`` by that factor (cubic spline), the exvivo landmark coordinates
    are multiplied by it, and the factor is appended to ``all_transform``.

    Nothing is changed when fewer than two landmark pairs exist or when either
    landmark set has no spread, because the ratio would then be 0, inf or NaN.

    Raises
    ------
    AssertionError
        If the two point layers hold a different number of points.
    """
    print('learning scale..')
    m_invivo = viewer.layers["invivo points"].data
    m_exvivo = viewer.layers["exvivo points"].data
    assert len(m_invivo) == len(m_exvivo), \
        "invivo and exvivo must have the same number of points"

    if len(m_exvivo) < 2:
        print('need at least 2 matched point pairs to learn a scale; nothing changed')
        return

    s_invivo = _rms_spread(m_invivo)
    s_exvivo = _rms_spread(m_exvivo)
    # Coincident points (zero spread) or NaN coordinates give no usable ratio.
    if not (s_invivo > 0 and s_exvivo > 0):
        print('selected points have no spread; cannot learn a scale, nothing changed')
        return

    scl = s_invivo/s_exvivo
    print(f'scale difference: {scl}')

    m_exvivo_new = scl*m_exvivo
    # Resample all channels first; the layers are only touched once every
    # channel has been zoomed successfully.
    exvivo_scaled_small = np.array([ndi.zoom(ex_layer[c].data, scl, order=3) for c in range(C)])
    all_transform.append(dict(scale=scl))
    for c in range(C):
        ex_layer[c].data = exvivo_scaled_small[c]
    viewer.layers["exvivo points"].data = m_exvivo_new
    

    
@viewer.bind_key('t', overwrite = True)
def transform(viewer):
    """Key 't': fit a rigid transform exvivo -> invivo from the landmarks and apply it.

    ``wahba`` returns a 4x3 matrix ``bhat`` (rotation ``R`` stacked on the
    translation ``T``) such that ``exvivo_points @ R + T ~ invivo_points``.
    Each exvivo channel is resampled onto the invivo grid with
    ``ndi.affine_transform``; the exvivo landmarks are moved with the same
    transform and ``bhat`` is appended to ``all_transform``.

    Nothing is changed when no landmarks have been placed (the fit would be NaN).

    Raises
    ------
    AssertionError
        If the two point layers hold a different number of points.
    """
    print('applying rigid..')
    m_invivo = viewer.layers["invivo points"].data
    m_exvivo = viewer.layers["exvivo points"].data
    assert len(m_invivo) == len(m_exvivo), \
        "invivo and exvivo must have the same number of points"

    if len(m_exvivo) == 0:
        print('no matched points selected; nothing changed')
        return

    bhat = wahba(m_exvivo,m_invivo)
    rot, trans = bhat[:3,:3], bhat[-1]

    # affine_transform pulls values: an output (invivo) coordinate o is read
    # from input (exvivo) coordinate matrix @ o + offset. Inverting
    # o = x @ R + T gives x = R @ o - R @ T, hence matrix = R, offset = -R @ T.
    offset = -rot@trans

    # Resample onto the invivo grid; the layer name follows the configured
    # channel list instead of a hard-coded channel name.
    out_shape = viewer.layers[f"invivo {channel_name[-1]}"].data.shape
    exvivo_affined = np.array([ndi.affine_transform(ex_layer[c].data, rot,
                                                    output_shape = out_shape, offset = offset, order=3) for c in range(C)])

    # Move the landmarks with the same transform, using homogeneous coordinates.
    homog = np.c_[m_exvivo, np.ones((m_exvivo.shape[0],1))]   # Nx4
    m_exvivo_new = homog@bhat

    all_transform.append(dict(bhat=bhat))
    for c in range(C):
        ex_layer[c].data = exvivo_affined[c]
    viewer.layers["exvivo points"].data = m_exvivo_new
    print("done transform!")
    

    

    
@viewer.bind_key('d', overwrite = True)
def deform(viewer):    
    """Key 'd': smooth, landmark-driven deformable warp of the exvivo image.

    Places each landmark's displacement at its (integer) location in an
    otherwise zero field, blurs the field with a Gaussian, rescales each
    component by a least-squares gain, then warps the image channels and
    moves the landmarks by the resulting field.

    NOTE(review): `section_id`, `exvivo_layer` and `warp` are not defined
    anywhere in the visible notebook (the `skimage.transform.warp` import at
    the top is commented out), so pressing 'd' looks like it would raise
    NameError unless they are defined elsewhere -- confirm.
    NOTE(review): the field has 2 components and only the first two
    coordinates of each point are used, whereas the point layers in this
    notebook are created with ndim=3 -- this handler appears to be written
    for 2-D sections; confirm before use.
    """
    print('applying deformable..')
    # Sigma passed to the Gaussian filter that spreads the landmark shifts.
    ksz = 200
    m_invivo = viewer.layers[f"matched section {section_id}"].data
    m_exvivo = viewer.layers[f"matched section {section_id+1}"].data
    assert(len(m_invivo)==len(m_exvivo))    
    # new coords = old coords + shift, so the field stores the shift to add to the old coords
    shift = m_invivo-m_exvivo
    vec_field = np.zeros(exvivo_layer[0].data.shape + (2,))  # M1,M2,2
    # Seed the field: one impulse per landmark at its truncated integer position.
    for p,loc in enumerate(m_exvivo):
        vec_field[int(loc[0]),int(loc[1])] = shift[p]
    # Blur each displacement component independently.
    for c in range(2):
        vec_field[...,c] = ndi.gaussian_filter(vec_field[...,c], ksz)
    # Sample the blurred field back at the landmark positions.
    A = np.zeros_like(m_exvivo)
    for p,loc in enumerate(m_exvivo):
        A[p] = vec_field[int(loc[0]),int(loc[1])]
    # Least-squares solve A @ x = shift; only the diagonal of x is kept, i.e.
    # one gain per displacement component.
    diag_step,_,_,_ = np.linalg.lstsq(A,shift,rcond=None)
    step = np.diag(diag_step)
    # Element-wise product: `step` broadcasts over the last (component) axis.
    vec_field_total =vec_field*step;
    all_transform.append(dict(vec_field_total=vec_field_total))
    # Coordinates each output pixel is sampled from: identity grid minus the field.
    mapx_base, mapy_base = np.meshgrid(np.arange(exvivo_layer[0].data.shape[0]),np.arange(exvivo_layer[0].data.shape[1]), indexing='ij')
    mapx=mapx_base-vec_field_total[:,:,0]
    mapy=mapy_base-vec_field_total[:,:,1]
    for c in range(C):
        img_de = warp(exvivo_layer[c].data,np.array((mapx,mapy)), order = 3)            
        exvivo_layer[c].data = img_de    
    
    # Move the landmarks by the field value at their (truncated) location.
    m_exvivo_new = np.zeros_like(m_exvivo)
    for p,loc in enumerate(m_exvivo):
        new_s = vec_field_total[int(loc[0]),int(loc[1])]
        m_exvivo_new[p] = loc+new_s
    viewer.layers[f"matched section {section_id+1}"].data = m_exvivo_new
        
# @viewer.bind_key('f', overwrite=True)
# def correct_coords(viewer):
#     ex_roi_cent = foo_l[-1]
#     in_roi_cent = cent_in
#     dst_ex = distance.cdist(foo_l[-1], viewer.layers['exvivo points'].data)
#     dst_in = distance.cdist(cent_in, viewer.layers['invivo points'].data)
#     m_ex_label = np.argmin(dst_ex,0)
#     m_in_label = np.argmin(dst_in,0)
#     m_ex_loc_corrected = foo_l[-1][m_ex_label]
#     m_in_loc_corrected = cent_in[m_in_label]
#     viewer.layers['exvivo points'].data = m_ex_loc_corrected
#     viewer.layers['invivo points'].data = m_in_loc_corrected
#     print('corrected the exvivo and invivo selected points coordinates!')
    
    
@viewer.bind_key('c',overwrite = True)
def clear_selected(viewer):
    """Key 'c': empty both landmark layers by assigning a (0, 3) array to each."""
    for layer_name in ("invivo points", "exvivo points"):
        viewer.layers[layer_name].data = np.empty((0, 3))

@viewer.bind_key('q', overwrite = True)
def save_rez(viewer):
    """Key 'q': pickle the current registration state into ``base_dir``.

    The file ``transformed_<todaystamp>.pkl`` holds a dict with:

    - ``'transformed'``: the current exvivo image data, stacked over channels
    - ``'transformations'``: the ``all_transform`` list
    - ``'pcd_invivo'`` / ``'pcd_exvivo'``: the current landmark coordinates

    Raises
    ------
    AssertionError
        If the two point layers hold a different number of points.
    """
    m_invivo = viewer.layers["invivo points"].data
    m_exvivo = viewer.layers["exvivo points"].data
    # Check before stacking the image data so a mismatch fails cheaply.
    assert len(m_invivo) == len(m_exvivo), \
        "invivo and exvivo must have the same number of points"

    out_dict = {
        'transformed': np.array([ex_layer[c].data for c in range(C)]),
        'transformations': all_transform,
        'pcd_invivo': m_invivo,
        'pcd_exvivo': m_exvivo,
    }

    # os.path.join works whether or not base_dir ends with a separator.
    out_path = os.path.join(base_dir, f'transformed_{todaystamp}.pkl')
    # The context manager guarantees the file is flushed and closed.
    with open(out_path, 'wb') as fh:
        pickle.dump(out_dict, fh)
    print("saved the results!")
    

applying rigid..
done transform!
