# In-vivo to ex-vivo mapping using napari
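- Key hints (press in the napari window). Matched points are paired by their order (the label numbers shown next to them), so add them in the same order in both point layers:
    - s: scale --> learn an isotropic scale factor from the matched points and rescale the in-vivo images and points.
    - t: transform --> fit a rigid transformation to the matched points and apply it to the in-vivo images and points.
    - r: reverse --> undo the most recent transform (only rigid transforms applied with `t` can be reversed).
    - c: clear --> remove all matched points from both point layers.
    - d: deformable --> apply a deformable transformation based on the matched points (note: no `d` binding is defined in the key-binding cell below).
    - q: quit --> despite the name, this does not close the GUI. It saves the results (transformed in-vivo images, transformation history, and matched points), and pressing it again on the same day overwrites that day's file. Don't press it until you are done, but don't forget to save!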

In [6]:

# Setup cell: imports, numpy print options, and a date stamp used to name the saved results file.
from datetime import date
import tifffile
import napari, pprint, sys
import numpy as np
import pickle
import matplotlib.pyplot as plt
np.set_printoptions(suppress=True)  # print small floats in fixed-point rather than scientific notation
import scipy.ndimage as ndi
from datetime import date  # NOTE(review): duplicates the first import of this cell; harmless but redundant
today = date.today()
todaystamp = today.strftime("%m%d%y")  # MMDDYY (e.g. 070924); used in the output file name when saving
print(todaystamp)
# Make the project-local helper module importable; this must run before `import transformations`.
sys.path.append('./functions/')
import transformations  # provides `wahba`, used by the rigid-transform key binding

070924


In [None]:
# Input directory holding the pickled volumes; results are written to its `transformed/` subfolder.
datadir = '/Users/shuonan/Dropbox/project/multimodal_2023/xinxindata/G130/'
base_dir = datadir + 'transformed/'

In [44]:
# Load the 2x-downsampled in-vivo and ex-vivo volumes.
# Axis 0 is the channel axis (the viewer below uses channel_axis=0); the remaining axes are
# presumably (z, y, x) -- confirm against the printed shapes.
invivo_f = datadir + 'invivo_2xds.pkl'
exvivo_f = datadir + 'exvivo_2xds.pkl'

# NOTE: pickle.load can execute arbitrary code -- only load files from a trusted source.
# `with` guarantees the file handles are closed (the original left them open).
with open(invivo_f, 'rb') as fh:
    invivo = pickle.load(fh)
with open(exvivo_f, 'rb') as fh:
    # Reverse axis 1 (presumably z) of every channel so the ex-vivo stack runs in the same
    # direction as the in-vivo one -- TODO confirm.
    # Original note: "144um, because each z is 1.5 um!". 48 slices x 1.5 um is 72 um, so 144 um
    # presumably refers to the stack before 2x downsampling -- confirm.
    exvivo = pickle.load(fh)[:, ::-1]

# Clamp negative (and zero) intensities to 0.
invivo[invivo <= 0] = 0
exvivo[exvivo <= 0] = 0
print(invivo.shape, exvivo.shape)

(2, 269, 784, 733) (2, 48, 1517, 1618)


In [66]:
# Channel order along axis 0 of both volumes; used for layer names and per-channel loops.
channel_name = ["BV", 'tomatos']
C = len(channel_name)

# Feature schema shared by both point layers: every point carries an integer 'label'.
features = {'label': np.empty(0, dtype=int)}
# History of applied transforms, appended to by the key bindings: {'scale': s} or {'bhat': B}.
all_transform = []

# One image layer per channel. Only the first channel is visible initially, and each channel's
# contrast spans 0..its own maximum. These per-channel lists are built from C, so they stay in
# sync with `channel_name` (previously they were hard-coded for exactly two channels).
viewer = napari.view_image(exvivo, channel_axis=0,
                           name=[f'exvivo {c}' for c in channel_name],
                           colormap="green",
                           visible=[c == 0 for c in range(C)],
                           contrast_limits=[(0, exvivo[c].max()) for c in range(C)],
                           blending='additive',
                           gamma=.5,)
# `in_layer` is the list of per-channel in-vivo image layers; the key bindings replace their data.
in_layer = viewer.add_image(invivo, channel_axis=0,
                            name=[f'invivo {c}' for c in channel_name],
                            colormap='red',
                            visible=[c == 0 for c in range(C)],
                            contrast_limits=[(0, invivo[c].max()) for c in range(C)],
                            blending='additive',
                            gamma=.5,)
# Point layer for landmarks clicked on the in-vivo volume (red outline, numbered labels).
pl_in = viewer.add_points(size=10, edge_width=1, edge_color='red', face_color='transparent',
                          name="invivo points", text='label', features=features, ndim=3,
                          out_of_slice_display=True)

@pl_in.events.data.connect
def update_feature_default_invivo():
    """Renumber the in-vivo points 1..N and set the label for the next point added.

    Connected to the layer's data event, so it runs whenever the point data changes
    (adding, deleting, or replacing points). It is also called once directly below.
    """
    no_of_points = len(pl_in.data)
    pl_in.feature_defaults['label'] = no_of_points + 1  # label given to the next new point
    # NOTE(review): depending on the napari version, `properties` may return a copy, in which
    # case this write does not reach the layer -- confirm; the text update below is what is shown.
    pl_in.properties["label"][0:no_of_points] = range(1, no_of_points + 1)
    pl_in.text.values[0:no_of_points] = [str(i) for i in range(1, no_of_points + 1)]
    pl_in.text.color = 'red'
    pl_in.text.translation = np.array([-10, 0])  # offset the label text from its point
update_feature_default_invivo()
pl_in.mode = 'add'  # start in point-adding mode

# Point layer for landmarks clicked on the ex-vivo volume (green outline, numbered labels).
pl_ex = viewer.add_points(
    size=10, edge_width=1, edge_color='green', face_color='transparent', name="exvivo points",
    text='label', features=features, ndim=3, out_of_slice_display=True)
@pl_ex.events.data.connect
def update_feature_default_exvivo():
    """Renumber the ex-vivo points 1..N and set the label for the next point added.

    Connected to the layer's data event, so it runs whenever the point data changes
    (adding, deleting, or replacing points). It is also called once directly below.
    """
    no_of_points = len(pl_ex.data)
    pl_ex.feature_defaults['label'] = no_of_points + 1  # label given to the next new point
    # NOTE(review): depending on the napari version, `properties` may return a copy, in which
    # case this write does not reach the layer -- confirm; the text update below is what is shown.
    pl_ex.properties["label"][0:no_of_points] = range(1, no_of_points + 1)
    pl_ex.text.values[0:no_of_points] = [str(i) for i in range(1, no_of_points + 1)]
    pl_ex.text.color = 'green'
    pl_ex.text.translation = np.array([-10, 0])  # offset the label text from its point
update_feature_default_exvivo()
pl_ex.mode = 'add'  # start in point-adding mode


def _rms_spread(points):
    """Return the root-mean-square distance of the rows of `points` from their centroid."""
    return np.sqrt(np.sum((points - np.mean(points, 0)) ** 2) / len(points))


@viewer.bind_key('s', overwrite = True)
def scale(viewer):
    """Learn an isotropic scale from the matched points and rescale the in-vivo data.

    The scale factor is the ratio of the RMS spread (distance from the centroid) of the
    ex-vivo points to that of the in-vivo points. Every in-vivo channel is resampled with
    `ndi.zoom` (linear interpolation), the in-vivo points are multiplied by the same factor,
    and ``{'scale': factor}`` is appended to `all_transform`.

    Raises:
        ValueError: if the two point layers hold different numbers of points, hold fewer
            than two points, or either point set has zero spread (the scale would be
            nan/inf/0).
    """
    print('learning scale..')
    m_invivo = viewer.layers["invivo points"].data
    m_exvivo = viewer.layers["exvivo points"].data
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if len(m_invivo) != len(m_exvivo):
        raise ValueError(
            f"in-vivo and ex-vivo point counts differ: {len(m_invivo)} vs {len(m_exvivo)}")
    if len(m_invivo) < 2:
        raise ValueError("need at least two matched point pairs to learn a scale")
    s_invivo = _rms_spread(m_invivo)
    s_exvivo = _rms_spread(m_exvivo)
    if s_invivo == 0 or s_exvivo == 0:
        raise ValueError("matched points have zero spread; cannot learn a scale")
    scl = s_exvivo / s_invivo
    print(f'scale difference (ex/in): {scl}')
    m_invivo_new = scl * m_invivo
    invivo_scaled_small = np.array([ndi.zoom(in_layer[c].data, scl, order=1) for c in range(C)])
    all_transform.append(dict(scale=scl))
    for c in range(C):
        in_layer[c].data = invivo_scaled_small[c]
    viewer.layers["invivo points"].data = m_invivo_new
    

    
@viewer.bind_key('t', overwrite = True)
def transform(viewer):
    """Fit a rigid transform to the matched points and apply it to the in-vivo data.

    `transformations.wahba` returns `bhat`, used here as a matrix whose first three rows are
    the linear part R and whose last row is the translation t. A point (row vector) p maps
    to ``p @ R + t``, which is exactly ``[p, 1] @ bhat``. Each in-vivo channel is resampled
    onto the ex-vivo grid with `ndi.affine_transform`, which needs the output->input mapping
    ``R @ o - R @ t``. That equals the inverse of the point mapping only if R is orthonormal,
    as a Wahba rotation presumably is -- confirm. ``{'bhat': bhat}`` is appended to
    `all_transform`.

    Raises:
        ValueError: if the two point layers hold different numbers of points.
    """
    # NOTE(review): called before any locals exist, so this only pushes `viewer` to the console.
    viewer.update_console(locals())
    print('applying rigid..')
    m_invivo = viewer.layers["invivo points"].data
    m_exvivo = viewer.layers["exvivo points"].data
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(m_invivo) != len(m_exvivo):
        raise ValueError(
            f"in-vivo and ex-vivo point counts differ: {len(m_invivo)} vs {len(m_exvivo)}")
    bhat = transformations.wahba(m_invivo, m_exvivo)
    offset = -(bhat[:3, :3]) @ bhat[-1]
    target_shape = viewer.layers['exvivo tomatos'].data.shape  # resample onto the ex-vivo grid
    invivo_affined = np.array([
        ndi.affine_transform(in_layer[c].data, bhat[:3, :3],
                             output_shape=target_shape, offset=offset, order=1)
        for c in range(C)
    ])
    homogeneous = np.c_[m_invivo, np.ones((m_invivo.shape[0], 1))]  # N x 4: [x, y, z, 1]
    m_invivo_new = homogeneous @ bhat
    all_transform.append(dict(bhat=bhat))
    for c in range(C):
        in_layer[c].data = invivo_affined[c]
    viewer.layers["invivo points"].data = m_invivo_new
    print("done transform!")
    
@viewer.bind_key('r', overwrite = True)
def inv_transform(viewer):
    """Undo the most recent rigid transform (the last ``{'bhat': ...}`` entry in `all_transform`).

    With R = bhat[:3, :3] and t = bhat[-1], the forward step mapped points p -> p @ R + t. The
    images are resampled back with matrix R.T and offset t, which inverts the forward
    resampling only if R is orthonormal (presumably true for a Wahba rotation -- confirm).
    The points are mapped back exactly with ``(p - t) @ inv(R)``. The images are resampled
    onto the ex-vivo-sized grid, so voxels cropped by the forward transform are not recovered.

    Raises:
        ValueError: if no transform has been applied yet, or the last one was not rigid
            (e.g. a scale step).
    """
    print('inverting the process......')
    viewer.update_console(locals())
    # Previously an empty history crashed with an unrelated IndexError.
    if not all_transform or 'bhat' not in all_transform[-1]:
        raise ValueError('can only invert affine transformations!')
    lastbhat = all_transform[-1]['bhat']
    rot, shift = lastbhat[:3, :3], lastbhat[-1]
    m_invivo = viewer.layers["invivo points"].data
    target_shape = viewer.layers['exvivo tomatos'].data.shape
    invivo_inv_affined = np.array([
        ndi.affine_transform(in_layer[c].data, rot.T,
                             output_shape=target_shape, offset=shift, order=1)
        for c in range(C)
    ])
    for c in range(C):
        in_layer[c].data = invivo_inv_affined[c]
    # Row-vector inverse of p -> p @ rot + shift; same as the original inv(rot.T) column form.
    m_invivo_new = (m_invivo - shift) @ np.linalg.inv(rot)
    viewer.layers["invivo points"].data = m_invivo_new
    all_transform.pop()  # drop the undone step in place (no need to rebind the global)
    print('inverted the transformation!')
    
    
        
    
@viewer.bind_key('c',overwrite = True)
def clear_selected(viewer):
    """Remove every matched point, first from the in-vivo layer, then from the ex-vivo layer."""
    for layer_name in ("invivo points", "exvivo points"):
        viewer.layers[layer_name].data = np.empty((0, 3))  # zero 3-D points

@viewer.bind_key('q', overwrite = True)
def save_rez(viewer):
    """Pickle the current in-vivo images, transform history, and matched points.

    Writes ``base_dir + 'transformed_<MMDDYY>.pkl'`` using today's stamp, so saving twice
    on the same day overwrites the earlier file. The pickled dict has the keys
    'transformed' (stacked in-vivo channels), 'transformations' (`all_transform`),
    'pcd_invivo' and 'pcd_exvivo' (the matched point arrays).

    Raises:
        ValueError: if the two point layers hold different numbers of points; in that case
            nothing is written.
    """
    m_invivo = viewer.layers["invivo points"].data
    m_exvivo = viewer.layers["exvivo points"].data
    # Validate before the (potentially large) image copy; `assert` is stripped under -O.
    if len(m_invivo) != len(m_exvivo):
        raise ValueError(
            f"in-vivo and ex-vivo point counts differ: {len(m_invivo)} vs {len(m_exvivo)}")
    out_dict = {
        'transformed': np.array([in_layer[c].data for c in range(C)]),
        'transformations': all_transform,
        'pcd_invivo': m_invivo,
        'pcd_exvivo': m_exvivo,
    }
    # `with` guarantees the file is flushed and closed (the original never closed it).
    with open(base_dir + f'transformed_{todaystamp}.pkl', 'wb') as fh:
        pickle.dump(out_dict, fh)
    print("saved the results!")
    

learning scale..
scale difference (ex/in): 0.9785990833839613
applying rigid..
done transform!
applying rigid..
done transform!
learning scale..
scale difference (ex/in): 1.0006087714288114
applying rigid..
done transform!
saved the results!
