# napari GUI for manual matching

## what this is for 
this is for stitching problems. say we have two (or more) 3d images that are supposed to be next to each other (consecutive in z), and some cells in the first image can also be seen in the second image. we can then use these cells to find the registration between them.


## features
1. we can look at multiple channels (here we use 5 channels as an example).
2. the steps are: select the cell spots -> learn the scaling -> learn the rigid transformation -> learn the nonrigid (deformable) transformation. all these steps rely on cell centers that the user selects by hand.
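for the scaling step, a single isotropic scale can be read off the matched centers alone, as the ratio of how spread out the two point sets are around their centroids (this is what the `s` key handler computes). a minimal numpy sketch; `estimate_scale` and the toy points are illustrative and not part of the notebook:

```python
import numpy as np

def estimate_scale(m_fixed, m_moving):
    """Ratio of the RMS distances to the centroid of two matched point sets.

    Multiplying the moving points by the returned value gives them the same
    spread as the fixed points (translation is ignored here).
    """
    m_fixed = np.asarray(m_fixed, dtype=float)
    m_moving = np.asarray(m_moving, dtype=float)
    assert m_fixed.shape == m_moving.shape, "need the same number of matched points"
    s_fixed = np.sqrt(np.sum((m_fixed - m_fixed.mean(0)) ** 2) / len(m_fixed))
    s_moving = np.sqrt(np.sum((m_moving - m_moving.mean(0)) ** 2) / len(m_moving))
    return s_fixed / s_moving

# toy example: the moving section is the fixed one enlarged 1.25x and shifted
fixed = np.array([[10.0, 10.0], [40.0, 12.0], [25.0, 50.0], [60.0, 45.0]])
moving = 1.25 * fixed + np.array([5.0, -3.0])
scl = estimate_scale(fixed, moving)
print(scl)  # approximately 0.8, i.e. shrink the moving section by 1/1.25
```

the leftover translation is not handled by this step; that is what the rigid step is for.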

## how to use:
to use this, we need to make sure the data files (tiff images) are organized in the layout the notebook expects!

this notebook therefore does two things:

1. load the data and arrange it in the expected layout -- this is necessary for loading the channels correctly.
2. launch napari to do the manual registration and save the results.

loading the results and stitching the original images is covered in the other notebook.

# 0. set up the environment and define the functions needed

In [3]:
from datetime import date
import tifffile
import napari
import os 
import glob
import pprint
import numpy as np
import pickle
# import helpers
import matplotlib.pyplot as plt
np.set_printoptions(suppress=True)
import scipy.ndimage as ndi
from skimage.transform import warp
from scipy.spatial import distance
import math

today = date.today()
todaystamp = today.strftime("%m%d%y")
print(todaystamp)

import sys 
sys.path.append('../functions/')
from transformations import wahba


070924
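`wahba` comes from the author's own `../functions/transformations.py`, which is not shown here. judging from how its output is used in the `t` handler below (`np.c_[points, ones] @ bhat` and `offset = -bhat[:2,:2] @ bhat[-1]`), it appears to return a 3x2 matrix holding a rotation plus a translation. as a stand-in for experimenting without that module, here is a hedged sketch of the standard Kabsch / orthogonal-Procrustes least-squares rotation fit in that same convention. `fit_rigid` and `apply_rigid_to_image` are hypothetical names, and this is not necessarily what `wahba` does internally:

```python
import numpy as np
import scipy.ndimage as ndi

def fit_rigid(src, dst):
    """Least-squares rotation + translation (Kabsch / orthogonal Procrustes).

    Returns a 3x2 matrix `bhat` such that np.c_[src, ones] @ bhat ~= dst
    (row-vector convention, as assumed for `wahba` in the `t` handler).
    """
    src = np.asarray(src, dtype=float)
    dst = np.asarray(dst, dtype=float)
    mu_s, mu_d = src.mean(0), dst.mean(0)
    H = (src - mu_s).T @ (dst - mu_d)      # 2x2 cross-covariance
    U, _, Vt = np.linalg.svd(H)
    d = np.sign(np.linalg.det(U @ Vt))     # -1 would mean a reflection
    R = U @ np.diag([1.0, d]) @ Vt         # proper rotation, row convention
    t = mu_d - mu_s @ R
    return np.vstack([R, t])

def apply_rigid_to_image(img, bhat, order=1):
    """Resample the moving image into the fixed section's frame.

    ndi.affine_transform pulls values: output[o] = img[matrix @ o + offset].
    The forward map is y = R.T @ x + t, so x = R @ (y - t) for orthogonal R,
    i.e. matrix = R and offset = -R @ t (same as the `t` handler).
    """
    R, t = bhat[:2, :2], bhat[2]
    return ndi.affine_transform(img, R, offset=-R @ t,
                                output_shape=img.shape, order=order)

# toy check: rotate + translate some points, then recover the transform
theta = 0.3
R_true = np.array([[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]])
t_true = np.array([12.0, -7.0])
src = np.random.default_rng(0).uniform(0, 100, size=(6, 2))
bhat = fit_rigid(src, src @ R_true + t_true)
print(np.allclose(bhat[:2, :2], R_true), np.allclose(bhat[2], t_true))  # True True
```

the reflection guard (`d`) keeps the fit a proper rotation; without it, badly placed or mislabeled points can produce a mirrored solution.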


# 1. set up the directories
`file_dir` example:
```
!ls -lt /Users/shuonan/project/project_1/multimodal_experiments/data_from_jan_2024/tifs/ 

total 0
drwx------@ 6 shuonanchen  staff  192 Feb 18 19:00 MERFISH_DAPI
drwx------@ 6 shuonanchen  staff  192 Feb 18 18:59 MERFISH_GCaMP
drwx------@ 6 shuonanchen  staff  192 Feb 18 18:59 MERFISH_polyA
drwx------@ 6 shuonanchen  staff  192 Feb 17 16:13 MERFISH_tdTomato_G126
drwx------@ 7 shuonanchen  staff  224 Feb 17 16:13 MERFISH_BV_G126
```

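**note that `file_dir` should be the folder that contains one subfolder per channel, not a channel folder itself. the channel name is read from each subfolder's name as its second `_`-separated token (e.g. `MERFISH_DAPI` -> `DAPI`), so a folder name without any `_`, such as plain `images`, will not parse; give it a prefix, e.g. `images_DAPI`. even if you have only one channel, `file_dir` should still be the parent of the folder that holds the tif images.**
for example, say your 6 ex vivo sections are stored in `Users/your_name/project/experiment_data/data_from_jan_2024/images/` and named `mouse_10_section1.tif`, `mouse_10_section2.tif`, `mouse_10_section3.tif`, ... (the loader looks for files matching `*section{section_id}.tif*`, so a name like `mouse_10_section_1` would not be found).

then your `file_dir` must be `Users/your_name/project/experiment_data/data_from_jan_2024/`, not `Users/your_name/project/experiment_data/data_from_jan_2024/images/`.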
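before loading, it can help to check that the layout matches what the loading loop expects: one subfolder per channel under `file_dir`, and exactly one file per section matching `*section{k}.tif*` in each. a hypothetical helper (`check_layout` is not part of the notebook), sketched with the standard library only:

```python
import glob
import os

def check_layout(file_dir, section_ids):
    """Return (channel_folder, section_id, n_matches) for every section that
    does not have exactly one file matching '*section{k}.tif*', using the
    same glob pattern as the loading cell. An empty list means the layout
    looks right."""
    channel_dirs = sorted(d for d in glob.glob(os.path.join(file_dir, '*'))
                          if os.path.isdir(d))
    if not channel_dirs:
        raise ValueError(f'no channel folders found under {file_dir}')
    problems = []
    for d in channel_dirs:
        for k in section_ids:
            hits = glob.glob(os.path.join(d, f'*section{k}.tif*'))
            if len(hits) != 1:
                problems.append((os.path.basename(d), k, len(hits)))
    return problems
```

`check_layout(file_dir, range(1, 5))` returning `[]` means the `assert` in the loading loop should pass for sections 1 to 4.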

In [None]:
# this is where the results will be saved.
base_dir = '/path_to_downloadfile/napari_rez/'

# this is where the files are located.
file_dir = '/path_to_downloadfile/tifs/'
print(sorted(os.listdir(file_dir)))  # this should list the channel-specific folders; each folder has one tif per section

In [4]:

all_channel_files = glob.glob(os.path.join(file_dir, '*'))  # this should be the list of channel-specific folders.
channel_name = [os.path.basename(f).split('_')[1] for f in all_channel_files]  # e.g. MERFISH_DAPI -> DAPI
print(channel_name)  # these are all the available channels.

['tdTomato', 'BV', 'DAPI', 'GCaMP', 'polyA']
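the channel name is the second `_`-separated token of each folder name, which is why `MERFISH_tdTomato_G126` becomes `tdTomato`. a folder without any `_` (such as a single `images` folder) makes `split('_')[1]` raise an `IndexError`. a sketch of the same parsing with a fallback to the whole folder name (`channel_from_folder` and the fallback are assumptions, not the notebook's behavior):

```python
import os

def channel_from_folder(path):
    """Second '_'-separated token of the folder name, as in the cell above;
    folders without '_' fall back to the whole name instead of crashing."""
    name = os.path.basename(os.path.normpath(path))
    parts = name.split('_')
    return parts[1] if len(parts) > 1 else name

folders = ['/data/tifs/MERFISH_tdTomato_G126', '/data/tifs/MERFISH_BV_G126',
           '/data/tifs/MERFISH_DAPI', '/data/tifs/MERFISH_GCaMP',
           '/data/tifs/MERFISH_polyA']
print([channel_from_folder(f) for f in folders])
# ['tdTomato', 'BV', 'DAPI', 'GCaMP', 'polyA']
```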


In [5]:
# specify how many channels the GUI should transform (the first C channels, in `channel_name` order).
# C = len(channel_name)  # use this to transform all the channels.
C = 1  # by default, assume only one channel is used for stitching.

# 2. load and arrange the data

In [6]:
exvivos = []
exvivos.append(np.nan)  # placeholder so that exvivos[k] is section k (section ids start from 1).
for section_id in range(1,5):  # only looking at section 1,2,3,4. 
    all_channels_img = []
    for f in all_channel_files:  # f is the folder for this channel. 
        f_path = glob.glob(f+f'/*section{section_id}.tif*')
        assert len(f_path)==1, f'expected exactly one file for section {section_id} in {f}, found {len(f_path)}'
        f_path = f_path[0]
        img = tifffile.imread(f_path)
        img = img.astype(float)
        img /= img.max()
        all_channels_img.append(img)  # these should all be the same sizes. 
    exvivos.append(np.stack(all_channels_img, axis = 0))
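the loop above reads every channel of a section with `tifffile`, rescales each channel to a maximum of 1, and stacks them into a `(c, z, x, y)` array. the same normalization as a standalone function, sketched as a hypothetical `normalize_and_stack` that also reports mismatched shapes and leaves an all-zero channel untouched instead of dividing by zero:

```python
import numpy as np

def normalize_and_stack(channel_imgs):
    """Scale each channel of one section to [0, 1] by its own max, then stack
    into a (c, z, x, y) array."""
    shapes = {np.shape(img) for img in channel_imgs}
    if len(shapes) != 1:
        raise ValueError(f'channels differ in shape: {sorted(shapes)}')
    out = []
    for img in channel_imgs:
        img = np.asarray(img, dtype=float)
        peak = img.max()
        out.append(img / peak if peak > 0 else img)  # avoid 0/0 on empty channels
    return np.stack(out, axis=0)
```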

In [7]:
len(exvivos)  # this should be 5 == 4+1 (index 0 is the nan placeholder, plus the 4 sections loaded)

5

In [8]:
# check the size of each
[m.shape for m in exvivos[1:]]  # these are (c,z,x,y): c is the number of channels loaded, z is the number of z-slices taken for each section.

[(2, 8, 2241, 2740),
 (2, 8, 2241, 2741),
 (2, 8, 2645, 2943),
 (2, 8, 2645, 2943)]

**okay, the data is ready; now we can launch the GUI.**
# 3. launch GUI
- notes
    - only look at one pair at a time.
    - don't forget to save (`q`).
- how to use
    1. select `section_id`: if `section_id=1`, then section 1 (red) and section 2 (green) are loaded. the larger index is always the source and is mapped onto the smaller one (section 2 is transformed to match section 1).
    2. focus on the BV layer and adjust its gamma and contrast. hide the other channels when necessary.
    3. go to the two `matched section ...` points layers and add points in both the red and the green one. the number of red and green points must be the same, and point `i` in one layer must correspond to point `i` in the other.
    4. apply the transformations you need until you are done with this layer (BV), then switch to the tomato layer (hide BV, show tomato).
    5. you can clear the selected points, or add more points based on the tomato layer.
    6. check the console output (in this notebook) for the relevant information.
    7. the saved results contain (1) all the transformations applied and (2) the transformed image.
- keys (in the order you would typically use them)
    - `s`: scaling --> learn the scale and rescale the moving image
    - `t`: transform --> rigid transformation (rotation + translation)
    - `c`: clear --> clear all matched points.
    - `d`: deformable --> apply a deformable (nonrigid) transformation learned from the matched points.
    - `q`: quit --> it doesn't actually close the GUI, it only saves the results, so don't press it until you are done.
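each press of `s`, `t` or `d` appends one dict to `all_transform` (`scale`, `bhat` or `vec_field_total`), and `q` pickles that list under `'transformations'` next to the transformed image. the full stitching lives in the other notebook; as a hedged sketch of how the saved list could be read back, here is a hypothetical `replay_on_points` that pushes extra coordinates (for example other cell centers in the moving section) through the same steps, following how the GUI handlers move the green points:

```python
import pickle
import numpy as np

def replay_on_points(points, transformations):
    """Map (N, 2) moving-section coordinates through the saved steps:
      {'scale': s}            -> p * s
      {'bhat': B} (3x2)       -> [p, 1] @ B
      {'vec_field_total': V}  -> p + V[int(p0), int(p1)]
    """
    pts = np.asarray(points, dtype=float).copy()
    for step in transformations:
        if 'scale' in step:
            pts = pts * step['scale']
        elif 'bhat' in step:
            pts = np.c_[pts, np.ones(len(pts))] @ step['bhat']
        elif 'vec_field_total' in step:
            V = step['vec_field_total']
            idx = pts.astype(int)  # truncation, like int() in the `d` handler
            pts = pts + V[idx[:, 0], idx[:, 1]]
        else:
            raise KeyError(f'unknown transformation: {list(step)}')
    return pts

def load_results(path):
    """Load one pickle written by the `q` key."""
    with open(path, 'rb') as f:
        return pickle.load(f)
```

usage would look like `rez = load_results(path)` followed by `replay_on_points(pts, rez['transformations'])`; the transformed image itself is stored under `'transformed_section_{k}'`.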

In [9]:
# choose your section_id
section_id = 2


In [10]:

features = {'label': np.empty(0, dtype=int)}
all_transform = []
# red: first z-slice of section `section_id` (fixed); green: last z-slice of section `section_id+1` (moving).
viewer = napari.view_image(exvivos[section_id][:,0],channel_axis=0,
                           contrast_limits=(0,exvivos[section_id][:,0].max()), blending = 'additive',
                           name=[f'sec {section_id}...{c}' for c in channel_name],
                           colormap='red', visible = [i < 2 for i in range(len(channel_name))],)  # show the first two channels
exvivo_layer = viewer.add_image(exvivos[section_id+1][:,-1],channel_axis=0,
                                contrast_limits=(0,exvivos[section_id+1][:,-1].max()), blending = 'additive',
                                name=[f'sec {section_id+1}...{c}' for c in channel_name],
                                colormap="green", visible = [i < 2 for i in range(len(channel_name))],)
pl_in = viewer.add_points(
    size=5, edge_width=1, edge_color='red',face_color='transparent',    
    name=f"matched section {section_id}", text='label', features=features,ndim=2,out_of_slice_display=True)
@pl_in.events.data.connect
def update_feature_default_invivo():  
    no_of_points = len(pl_in.data)
    pl_in.feature_defaults['label'] = no_of_points + 1
    pl_in.properties["label"][0:no_of_points] = range(1, no_of_points+1)
    pl_in.text.values[0:no_of_points] = [str(i) for i in range(1, no_of_points+1)]
    pl_in.text.color = 'red'
    pl_in.text.translation = np.array([-10, 0])
update_feature_default_invivo()
pl_in.mode = 'add'

pl_ex = viewer.add_points(
    size=5, edge_width=1, edge_color='green',face_color='transparent',    
    name=f"matched section {section_id+1}", text='label', features=features,ndim=2,out_of_slice_display=True)
@pl_ex.events.data.connect
def update_feature_default_exvivo():  
    no_of_points = len(pl_ex.data)
    pl_ex.feature_defaults['label'] = no_of_points + 1
    pl_ex.properties["label"][0:no_of_points] = range(1, no_of_points+1)
    pl_ex.text.values[0:no_of_points] = [str(i) for i in range(1, no_of_points+1)]
    pl_ex.text.color = 'green'
    pl_ex.text.translation = np.array([-10, 0])
update_feature_default_exvivo()
pl_ex.mode = 'add'

@viewer.bind_key('s', overwrite = True)
def scale(viewer):  
    print('learning scale..')
    m_invivo = viewer.layers[f"matched section {section_id}"].data
    m_exvivo = viewer.layers[f"matched section {section_id+1}"].data 
    assert(len(m_invivo)==len(m_exvivo))    
    s_invivo = np.sqrt(np.sum((m_invivo-np.mean(m_invivo,0))**2)/len(m_invivo))
    s_exvivo = np.sqrt(np.sum((m_exvivo-np.mean(m_exvivo,0))**2)/len(m_exvivo))
    scl = s_invivo/s_exvivo
    print(f'scale difference: {scl}')
    m_exvivo_new = scl*m_exvivo    
    exvivo_scaled_small = np.array([ndi.zoom(exvivo_layer[c].data, scl, order=3) for c in range(C)])    
    all_transform.append(dict(scale=scl))
    for c in range(C):
        exvivo_layer[c].data = exvivo_scaled_small[c]
    viewer.layers[f"matched section {section_id+1}"].data = m_exvivo_new    
    

@viewer.bind_key('t', overwrite = True)
def transform(viewer):    
    print('applying rigid..')
    m_invivo = viewer.layers[f"matched section {section_id}"].data
    m_exvivo = viewer.layers[f"matched section {section_id+1}"].data
    assert(len(m_invivo)==len(m_exvivo))
    bhat = wahba(m_exvivo,m_invivo)
    offset = -(bhat[:2,:2])@bhat[-1]
    exvivo_affined = np.array([ndi.affine_transform(exvivo_layer[c].data, bhat[:2,:2],
                                                    output_shape = exvivo_layer[0].data.shape,offset = offset, order=3) for c in range(C)])
    foo = np.c_[m_exvivo, np.ones((m_exvivo.shape[0],1))]   # Nx3: points in homogeneous coordinates
    m_exvivo_update = foo@bhat
    all_transform.append(dict(bhat=bhat))
    for c in range(C):
        exvivo_layer[c].data = exvivo_affined[c]
    viewer.layers[f"matched section {section_id+1}"].data = m_exvivo_update

    
@viewer.bind_key('d', overwrite = True)
def deform(viewer):    
    print('applying deformable..')
    ksz = 200
    m_invivo = viewer.layers[f"matched section {section_id}"].data
    m_exvivo = viewer.layers[f"matched section {section_id+1}"].data
    assert(len(m_invivo)==len(m_exvivo))    
    shift = m_invivo-m_exvivo   # so the newcoords = old+shift --> vecfield will have how much shift needed to apply to the old coords
    vec_field = np.zeros(exvivo_layer[0].data.shape + (2,))  # M1,M2,2
    for p,loc in enumerate(m_exvivo):
        vec_field[int(loc[0]),int(loc[1])] = shift[p]
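    for k in range(2):  # smooth both vector components (row, col); unrelated to the channel count C
        vec_field[...,k] = ndi.gaussian_filter(vec_field[...,k], ksz)
    A = np.zeros_like(m_exvivo)  # smoothed field sampled at the matched points
    for p,loc in enumerate(m_exvivo):
        A[p] = vec_field[int(loc[0]),int(loc[1])]
    sol,_,_,_ = np.linalg.lstsq(A,shift,rcond=None)  # 2x2 least-squares fit of A @ sol ~ shift
    step = np.diag(sol)  # keep only the per-component scale factors
    vec_field_total = vec_field*step  # element-wise, broadcast over the last axis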
    all_transform.append(dict(vec_field_total=vec_field_total))
    mapx_base, mapy_base = np.meshgrid(np.arange(exvivo_layer[0].data.shape[0]),np.arange(exvivo_layer[0].data.shape[1]), indexing='ij')
    mapx=mapx_base-vec_field_total[:,:,0]
    mapy=mapy_base-vec_field_total[:,:,1]
    for c in range(C):
        img_de = warp(exvivo_layer[c].data,np.array((mapx,mapy)), order = 3)            
        exvivo_layer[c].data = img_de    
    
    m_exvivo_new = np.zeros_like(m_exvivo)  # POINTS
    for p,loc in enumerate(m_exvivo):
        new_s = vec_field_total[int(loc[0]),int(loc[1])]
        m_exvivo_new[p] = loc+new_s
    viewer.layers[f"matched section {section_id+1}"].data = m_exvivo_new
    
@viewer.bind_key('c', overwrite = True)
def clear_selected(viewer):        
    viewer.layers[f"matched section {section_id}"].data = np.empty((0, 2))
    viewer.layers[f"matched section {section_id+1}"].data = np.empty((0, 2))

    
@viewer.bind_key('q', overwrite = True)
def save_rez(viewer):     
    out_dict = dict()
    out = np.array([exvivo_layer[c].data for c in range(C)] )
    out_dict[f'transformed_section_{section_id+1}']=out
    out_dict['transformations']=all_transform    
    out_path = os.path.join(base_dir, f'section_{section_id}_{section_id+1}_{todaystamp}.pkl')
    with open(out_path, 'wb') as fh:  # close the file so the pickle is fully written
        pickle.dump(out_dict, fh)
    print(f"saved the results to {out_path}!")
    

applying rigid..
learning scale..
scale difference: 0.9982306505448446
applying rigid..
learning scale..
scale difference: 1.0020945622556776
applying rigid..
applying rigid..
applying deformable..
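the `d` step turns a handful of point displacements into a dense displacement field and then warps the moving image with it. below is a hedged, self-contained sketch of that idea with scipy only. `smooth_displacement` and `warp_with_field` are hypothetical names, `scipy.ndimage.map_coordinates` stands in for `skimage.transform.warp` (both accept a per-pixel coordinate array), and the per-component rescaling is a simpler variant of the notebook's `np.linalg.lstsq` + `np.diag` step, not a copy of it:

```python
import numpy as np
import scipy.ndimage as ndi

def smooth_displacement(shape, pts, shift, sigma):
    """Spread sparse point displacements into a dense (M1, M2, 2) field.

    pts: (N, 2) moving-section points; shift: (N, 2) target minus pts.
    Each component is Gaussian-smoothed, then rescaled by a per-component
    least-squares factor so the field roughly reproduces `shift` at `pts`.
    """
    pts = np.asarray(pts, dtype=float)
    shift = np.asarray(shift, dtype=float)
    idx = pts.astype(int)
    field = np.zeros(tuple(shape) + (2,))
    field[idx[:, 0], idx[:, 1]] = shift
    for k in range(2):  # smooth both vector components
        field[..., k] = ndi.gaussian_filter(field[..., k], sigma)
        a = field[idx[:, 0], idx[:, 1], k]  # smoothed values at the points
        denom = a @ a
        field[..., k] *= (a @ shift[:, k]) / denom if denom > 0 else 0.0
    return field

def warp_with_field(img, field, order=1):
    """Pull-back warp: output[y] = img[y - field[y]], same sign as the `d` handler."""
    rows, cols = np.meshgrid(np.arange(img.shape[0]), np.arange(img.shape[1]),
                             indexing='ij')
    coords = np.array([rows - field[..., 0], cols - field[..., 1]])
    return ndi.map_coordinates(img, coords, order=order)

# one point pushed 3 px along the column axis
field = smooth_displacement((64, 64), [[32.0, 32.0]], [[0.0, 3.0]], sigma=5)
print(np.round(field[32, 32], 6))  # the displacement is reproduced at the point
```

note that the point update in the `d` handler (`loc + field[loc]`) evaluates the field at the old location, while the image warp evaluates it at the new one; with a wide kernel (the notebook uses `ksz = 200`) the two are close, but they are not exact inverses.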


In [11]:
todaystamp

'022123'