# brain_to_brain_inverse_registration
The purpose of this notebook is to demonstrate how to transform ClearMap2-detected points from raw space to atlas space when elastix fails during the usual transformation steps. Elastix sometimes fails during either the affine or the B-spline step (most often the B-spline step) with an error like: 
```
Description: itk::ERROR: AdvancedMattesMutualInformationMetric(0x20e2330): Too many samples map outside moving image buffer: 0 / 10000


Error occurred during actual registration.
```
If this happens, there may still be a way to transform raw points from this brain to atlas space. One way to do this is to use a different brain for which the transformations between raw space and atlas space have already been obtained. In this notebook, we will refer to this brain as the **working brain**. The trick is to find the transformation between the brain for which elastix is failing (we will call this the **faulty brain**) and the working brain; this brain-to-brain registration is much less likely to fail. Suppose that for the faulty brain we have a 642 nm signal channel whose raw-space points we want to transform to the atlas, as well as a 488 nm autofluorescence channel. The transformation steps to bring the faulty brain's raw points into atlas space are:

```
642_faulty -> 488_faulty -> 488_working -> atlas
```
where `488_faulty -> 488_working` is the only new transformation we should need. Note that because we are transforming points, we actually need the inverse of this transformation. Elastix estimates a transform that maps coordinates in the fixed image onto the moving image, so transformix carries points from fixed space to moving space. The registration we need is therefore:
- moving: 488_working
- fixed: 488_faulty

The same applies to all of the other arrows. Once these transformations have been computed with elastix, you can run this notebook to apply them to the raw-space points and obtain those points in atlas space.
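As a sketch of how the inverse brain-to-brain registration might be computed, the snippet below assembles an elastix command line with the faulty brain's 488 volume as the fixed image and the working brain's 488 volume as the moving image. It assumes the standard elastix command-line flags (`-f`, `-m`, `-p`, `-out`); `build_elastix_command`, the file paths, and the parameter-file names are hypothetical placeholders, not the files used by the pipeline.

```python
import shlex

def build_elastix_command(fixed_image, moving_image, parameter_files, output_dir):
    """Return an elastix argument list (it is not executed here).

    The fixed image is the space transformix will take points *from*;
    the moving image is the space it will take them *to*.
    """
    cmd = ["elastix", "-f", fixed_image, "-m", moving_image]
    for p in parameter_files:  # applied in order, e.g. affine first, then B-spline
        cmd += ["-p", p]
    cmd += ["-out", output_dir]
    return cmd

# Hypothetical paths: fixed = faulty 488, moving = working 488
cmd = build_elastix_command(
    fixed_image="faulty/downsized_for_atlas_ch488.tif",
    moving_image="working/downsized_for_atlas_ch488.tif",
    parameter_files=["affine.txt", "bspline.txt"],
    output_dir="faulty/elastix_inverse_transform_488_to_working_488",
)
print(shlex.join(cmd))
```

The list could be passed to `subprocess.run(cmd, check=True)`; note that elastix expects the `-out` directory to exist before it runs.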

This notebook assumes you have already detected the points with ClearMap2 using this pipeline: https://github.com/PrincetonUniversity/lightsheet_helper_scripts/tree/master/clearmap2/pipeline/cfos_cell_detection_pipeline

If you have not, you will need to modify the paths below. The general approach of the notebook should work regardless of how your files are organized. 

In [None]:
import tifffile, glob, os, sys, json
import numpy as np
import matplotlib.pyplot as plt
import cv2
import multiprocessing as mp
%matplotlib inline

In [None]:
sys.path.append('/jukebox/wang/ahoag/ClearMap2') # change to wherever you have clearmap2 cloned
import ClearMap.IO.Workspace as wsp
import ClearMap.IO.IO as io
import ClearMap.ParallelProcessing.BlockProcessing as bp
import ClearMap.ImageProcessing.Experts.Cells as cells
import ClearMap.Settings as settings
import ClearMap.Alignment.Resampling as res
import ClearMap.Alignment.Elastix as elx   

## Attempt to align points from one brain to another, then to the atlas
In this example, we work with Lightserv request `zimmerman_06`, sample `zimmerman_06-352`, whose elastix run failed during the inverse registration step (atlas -> 488 channel). The elastix transformations for sample `zimmerman_06-354` from the same request succeeded, so we use that sample as the working brain. 

In [None]:
request_name = "zimmerman_06"
faulty_brain = "zimmerman_06-352"
working_brain = "zimmerman_06-354"
output_rootpath = "/jukebox/witten/Chris/data/clearmap2" # set to rootpath of clearmap2 pipeline output files

In [None]:
faulty_clearmap_dir = os.path.join(output_rootpath,request_name,faulty_brain,'imaging_request_1','rawdata','resolution_3.6x')

In [None]:
ws_faulty = wsp.Workspace('CellMap', directory=faulty_clearmap_dir);
ws_faulty.debug=False
ws_faulty.info()

In [None]:
ch488_downsized_dir = os.path.join(faulty_clearmap_dir,'Ex_488_Em_0_downsized')
ch488_downsized_file = os.path.join(ch488_downsized_dir,'downsized_for_atlas_ch488.tif')
ch488_downsized_vol = tifffile.imread(ch488_downsized_file)
ch642_downsized_dir = os.path.join(faulty_clearmap_dir,'Ex_642_Em_2_downsized')
ch642_downsized_file = os.path.join(ch642_downsized_dir,'downsized_for_atlas_ch642.tif')
ch642_downsized_vol = tifffile.imread(ch642_downsized_file)

In [None]:
working_clearmap_dir = os.path.join(output_rootpath,request_name,working_brain,'imaging_request_1','rawdata','resolution_3.6x')

In [None]:
ws_working = wsp.Workspace('CellMap', directory=working_clearmap_dir);
ws_working.debug=False
ws_working.info()

In [None]:
ch488_downsized_dir_working = os.path.join(working_clearmap_dir,'Ex_488_Em_0_downsized')
ch488_downsized_file_working = os.path.join(ch488_downsized_dir_working,'downsized_for_atlas_ch488.tif')
ch488_downsized_vol_working = tifffile.imread(ch488_downsized_file_working)

In [None]:
cells_source_faulty = ws_faulty.source('cells', postfix='raw')
coordinates_faulty = np.hstack([cells_source_faulty[c][:,None] for c in 'xyz']);
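`cells_source_faulty` is indexed by field name (`'x'`, `'y'`, `'z'`, ...), much like a NumPy structured array. The sketch below uses a small hand-made structured array as a stand-in (the field values are invented) to show that the `hstack` idiom produces an `(N, 3)` array with columns in x, y, z order, and that `np.column_stack` is an equivalent spelling.

```python
import numpy as np

# Toy stand-in for a ClearMap cells source (values are invented)
cells = np.array(
    [(10, 20, 1200, 5), (11, 25, 1201, 7), (40, 8, 300, 3)],
    dtype=[("x", int), ("y", int), ("z", int), ("size", int)],
)

# Idiom used in the notebook: one (N, 1) column per field, stacked side by side
coords = np.hstack([cells[c][:, None] for c in "xyz"])

# Equivalent, slightly more direct spelling
coords_alt = np.column_stack([cells[c] for c in "xyz"])

print(coords.shape)  # (3, 3)
```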

In [None]:
# Verify that raw cells are in the right locations and we understand the x,y,z layout of these clearmap files
zplane=1200 # just pick a z plane where we know there will be cells
zplane_depth = 3
minplane = max(zplane-zplane_depth,0)
maxplane = zplane+zplane_depth
zplane_range = range(minplane,maxplane)
this_plane_coords = np.array([coord for coord in coordinates_faulty if coord[-1] in zplane_range])
xs = this_plane_coords[:,0]
ys = this_plane_coords[:,1]
fig,axes = plt.subplots(figsize=(15,8),nrows=1,ncols=2,sharex=True,sharey=True)
ax_tissue = axes[0]
stitched_z_plane = ws_faulty.source('stitched')[:,:,zplane] 
ax_tissue.imshow(stitched_z_plane,vmin=0,vmax=200,cmap='viridis')
ax_both=axes[1]
ax_both.imshow(stitched_z_plane,vmin=0,vmax=200,cmap='viridis')
ax_both.scatter(ys,xs,s=50,facecolors='none',edgecolors='r') # with very many cells, plotting every 10th one (ys[::10], xs[::10]) keeps the figure readable
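The list comprehension above loops over every detected cell in Python, which can be slow for millions of points, and `range(minplane, maxplane)` excludes `maxplane`, so the slab is slightly asymmetric around `zplane`. A vectorized alternative using a boolean mask is sketched below; `points_in_slab` is a hypothetical helper, and unlike the original it includes both ends of the slab.

```python
import numpy as np

def points_in_slab(points, axis, center, depth):
    """Return the rows of `points` whose coordinate along `axis` lies in
    [max(center - depth, 0), center + depth], both ends inclusive.

    Coordinates are truncated toward zero first, mirroring the int(...)
    used in the notebook's list comprehensions.
    """
    points = np.asarray(points)
    plane = np.trunc(points[:, axis]).astype(int)
    lo = max(center - depth, 0)
    hi = center + depth
    mask = (plane >= lo) & (plane <= hi)
    return points[mask]

pts = np.array([[5.0, 1.0, 1197.0],
                [6.0, 2.0, 1200.4],
                [7.0, 3.0, 1203.0],
                [8.0, 4.0, 1204.0]])
slab = points_in_slab(pts, axis=2, center=1200, depth=3)
print(slab[:, 0])  # [5. 6. 7.]
```

For the raw-space cell this would be `points_in_slab(coordinates_faulty, axis=2, center=zplane, depth=zplane_depth)`; in the downsized-space cells the plane index is column 0, so `axis=0`.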

## Resample to faulty 642 downsized space

In [None]:
coordinates_raw_swapped_axes = np.zeros_like(coordinates_faulty)
coordinates_raw_swapped_axes[:,0] = coordinates_faulty[:,2]
coordinates_raw_swapped_axes[:,1] = coordinates_faulty[:,1]
coordinates_raw_swapped_axes[:,2] = coordinates_faulty[:,0]
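The three assignments above reverse the column order, turning `(x, y, z)` rows into `(z, y, x)` rows to match the reversed `source_shape` passed to `resample_points` below. Assuming that is the only intent, the same result can be written with a reversed slice; the sketch checks the equivalence on toy data.

```python
import numpy as np

coords_xyz = np.array([[1, 2, 3], [4, 5, 6]])

# Explicit column-by-column swap, as in the notebook
swapped = np.zeros_like(coords_xyz)
swapped[:, 0] = coords_xyz[:, 2]
swapped[:, 1] = coords_xyz[:, 1]
swapped[:, 2] = coords_xyz[:, 0]

# Equivalent: reverse the column order (.copy() gives an independent array, not a view)
swapped_slice = coords_xyz[:, ::-1].copy()

print(swapped_slice)
```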

In [None]:
coordinates_resampled = res.resample_points(
    coordinates_raw_swapped_axes, sink=None, orientation=None,
    source_shape=io.shape(ws_faulty.filename('stitched'))[::-1],
    sink_shape=io.shape(ch642_downsized_file))
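Conceptually, resampling points without a reorientation (`orientation=None`) rescales each coordinate axis by the ratio of the sink shape to the source shape. The sketch below illustrates that idea with plain NumPy; it is an assumption about what `res.resample_points` computes in this case, not a copy of ClearMap's implementation, and `rescale_points` and the shapes are invented.

```python
import numpy as np

def rescale_points(points, source_shape, sink_shape):
    """Scale each coordinate axis by sink_shape / source_shape.

    The columns of `points` must be in the same axis order as both shapes.
    """
    factor = np.asarray(sink_shape, dtype=float) / np.asarray(source_shape, dtype=float)
    return np.asarray(points, dtype=float) * factor

# Invented shapes, already in (z, y, x) order to match the swapped points
source_shape = (2000, 4000, 3000)
sink_shape = (500, 1000, 750)
pts = np.array([[1200, 2000, 1500]])
result = rescale_points(pts, source_shape, sink_shape)
print(result)  # [[300. 500. 375.]]
```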

In [None]:
io.shape(ch642_downsized_file) # x,y,z where z are horizontal planes

In [None]:
coordinates_resampled

In [None]:
# visually verify that the points are still in the right locations
zplane=270 # downsized-space plane to inspect; pick one that contains cells
zplane_depth = 1
minplane = max(zplane-zplane_depth,0)
maxplane = zplane+zplane_depth
zplane_range = range(minplane,maxplane)
this_plane_coords_resampled = np.array([coord for coord in coordinates_resampled if int(coord[0]) in zplane_range])
ys = this_plane_coords_resampled[:,1]
zs = this_plane_coords_resampled[:,2]
fig,axes = plt.subplots(figsize=(15,8),nrows=1,ncols=2,sharex=True,sharey=True)
ax_tissue = axes[0]
resampled_z_plane = ch642_downsized_vol[:,:,zplane]
# resampled_z_plane_fixaxes = np.swapaxes(resampled_z_plane,0,1)
ax_tissue.imshow(resampled_z_plane,vmin=0,vmax=200,cmap='viridis')
ax_both=axes[1]
ax_both.imshow(resampled_z_plane,vmin=0,vmax=200,cmap='viridis')
ax_both.scatter(ys,zs,s=50,facecolors='none',edgecolors='r')

This still looks good.

## Transform from faulty 642 downsized to faulty 488 downsized

In [None]:
# Transform cell coordinates from 642-space to 488-space
elastix_inverse_dir = os.path.join(faulty_clearmap_dir,"elastix_inverse_transform","488_to_642")
coordinates_aligned_to_488 = elx.transform_points(
        coordinates_resampled, sink=None,
        transform_directory=elastix_inverse_dir,
        temp_file='/tmp/elastix_input_pipeline.bin',
        result_directory='/tmp/elastix_output_pipeline')
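`elx.transform_points` hands the points to transformix through a temporary file. For reference, transformix's standard text input for `-def` is a point file whose first line is `index` (voxel indices) or `point` (physical coordinates), whose second line is the number of points, followed by one point per line. The writer below is a sketch of that text format; ClearMap may use a different (e.g. binary) temporary format, as the `.bin` suffix above suggests, and `write_transformix_points` is a hypothetical helper.

```python
import os
import tempfile
import numpy as np

def write_transformix_points(path, points, kind="index"):
    """Write an (N, 3) array as a transformix text point file.

    kind: "index" for voxel indices, "point" for physical coordinates.
    """
    points = np.asarray(points, dtype=float)
    with open(path, "w") as f:
        f.write(f"{kind}\n{len(points)}\n")
        for p in points:
            f.write(" ".join(f"{v:g}" for v in p) + "\n")

path = os.path.join(tempfile.mkdtemp(), "points.txt")
write_transformix_points(path, [[300, 500, 375], [10.5, 20, 30]])
with open(path) as f:
    lines = f.read().splitlines()
print(lines)  # ['index', '2', '300 500 375', '10.5 20 30']
```

transformix would then be run with `-def points.txt`, `-tp` pointing at a transform parameter file and `-out` at an existing directory, and it writes the transformed points to `outputpoints.txt` there.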


In [None]:
coordinates_aligned_to_488

In [None]:
# Verify that the cells transformed to faulty 488 space are in the right locations
zplane=270 # downsized-space plane to inspect; pick one that contains cells
zplane_depth = 1
minplane = max(zplane-zplane_depth,0)
maxplane = zplane+zplane_depth
zplane_range = range(minplane,maxplane)
this_plane_coords_aligned_to_488 = np.array([coord for coord in coordinates_aligned_to_488 if int(coord[0]) in zplane_range])
ys = this_plane_coords_aligned_to_488[:,1]
zs = this_plane_coords_aligned_to_488[:,2]

fig,axes = plt.subplots(figsize=(15,8),nrows=1,ncols=2,sharex=True,sharey=True)
ax_tissue = axes[0]
resampled_z_plane = ch488_downsized_vol[:,:,zplane]
# resampled_z_plane_fixaxes = np.swapaxes(resampled_z_plane,0,1)
ax_tissue.imshow(resampled_z_plane,vmin=0,vmax=1200,cmap='viridis')
ax_both=axes[1]
ax_both.imshow(resampled_z_plane,vmin=0,vmax=1200,cmap='viridis')
ax_both.scatter(ys,zs,s=50,facecolors='none',edgecolors='r')

Still looks good. 

## Transformation from faulty 488 downsized to working 488 downsized
This is the step that differs from the usual pipeline. Here we transform between two different brains, then use the working brain's 488 -> atlas transformation to take the points the rest of the way to atlas space. 
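Conceptually, the whole notebook applies a chain of point mappings in order: raw -> downsized 642 -> faulty 488 -> working 488 -> atlas. A minimal sketch of that composition with `functools.reduce` is below; `chain` is a hypothetical helper, and the toy scalings and translations stand in for the real resampling and transformix steps purely for illustration.

```python
from functools import reduce
import numpy as np

def chain(points, steps):
    """Apply point-mapping functions in order: steps[0] first, steps[-1] last."""
    return reduce(lambda pts, step: step(pts), steps, np.asarray(points, dtype=float))

# Toy stand-ins for the real steps (each just scales or shifts the points)
steps = [
    lambda p: p * 0.25,              # raw -> downsized (resampling)
    lambda p: p + [1.0, 0.0, 0.0],   # faulty 642 -> faulty 488
    lambda p: p + [0.0, -2.0, 0.0],  # faulty 488 -> working 488 (the new step)
    lambda p: p + [0.0, 0.0, 3.0],   # working 488 -> atlas
]
print(chain([[4.0, 8.0, 12.0]], steps))  # [[2. 0. 6.]]
```

Order matters: applying the same toy steps in reverse gives a different result, just as the real transforms must be applied in the order of the chain.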

In [None]:
# Transform cell coordinates from 488-space (faulty) to 488-space (working).
# This directory should hold the registration with fixed = faulty 488 and moving = working 488.
elastix_inverse_dir_working = os.path.join(faulty_clearmap_dir,f"elastix_inverse_transform_488_to_{working_brain}_488")
coordinates_aligned_to_488_working = elx.transform_points(
        coordinates_aligned_to_488, sink=None,
        transform_directory=elastix_inverse_dir_working,
        temp_file='/tmp/elastix_input_pipeline.bin',
        result_directory='/tmp/elastix_output_pipeline')

In [None]:
coordinates_aligned_to_488_working

In [None]:
# Verify that the points now lie in the working brain's 488 space
zplane=270 # downsized-space plane to inspect; pick one that contains cells
zplane_depth = 1
minplane = max(zplane-zplane_depth,0)
maxplane = zplane+zplane_depth
zplane_range = range(minplane,maxplane)
this_plane_coords_aligned_to_488_working = np.array(
    [coord for coord in coordinates_aligned_to_488_working if int(coord[0]) in zplane_range])
ys = this_plane_coords_aligned_to_488_working[:,1]
zs = this_plane_coords_aligned_to_488_working[:,2]

fig,axes = plt.subplots(figsize=(15,8),nrows=1,ncols=2,sharex=True,sharey=True)
ax_tissue = axes[0]
resampled_z_plane = ch488_downsized_vol_working[:,:,zplane]
# resampled_z_plane_fixaxes = np.swapaxes(resampled_z_plane,0,1)
ax_tissue.imshow(resampled_z_plane,vmin=0,vmax=1200,cmap='viridis')
ax_both=axes[1]
ax_both.imshow(resampled_z_plane,vmin=0,vmax=1200,cmap='viridis')
ax_both.scatter(ys,zs,s=50,facecolors='none',edgecolors='r')

These points do not need to coincide with features in the working brain's image, since they are coordinates from another brain warped into this space. They only need to land in the right space, which they do.
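The visual check can be complemented by a number: the fraction of warped points that fall inside the bounds of the target volume. `fraction_in_bounds` below is a hypothetical helper; it assumes the point columns are in the same axis order as the shape you pass. Judging by the atlas lookup at the end of the notebook, which indexes volumes as `vol[col2, col1, col0]`, that would mean passing `vol.shape[::-1]`, but confirm this for each space.

```python
import numpy as np

def fraction_in_bounds(points, shape):
    """Fraction of points whose every coordinate c satisfies 0 <= c < shape[axis].

    Columns of `points` must be in the same axis order as `shape`.
    """
    points = np.asarray(points, dtype=float)
    if len(points) == 0:
        return 0.0
    inside = np.all((points >= 0) & (points < np.asarray(shape)), axis=1)
    return float(inside.mean())

shape = (100, 200, 50)
pts = np.array([[10, 10, 10], [99.5, 199, 49], [-1, 5, 5], [10, 250, 10]])
print(fraction_in_bounds(pts, shape))  # 0.5
```

Applied here, this might look like `fraction_in_bounds(coordinates_aligned_to_488_working, ch488_downsized_vol_working.shape[::-1])`.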

## Final step: Transform working 488 downsized -> Princeton Mouse Atlas
This is the final step: the result is the cell coordinates of the original faulty brain aligned to the Princeton Mouse Atlas. The `transform_directory` here is the working brain's inverse transform directory, i.e. the directory holding the elastix results of registering the atlas (moving) to the working brain's downsized 488 volume (fixed). 

In [None]:
# Transform cell coordinates from working 488-space to atlas
elastix_inverse_dir_working_atlas = os.path.join(working_clearmap_dir,"elastix_inverse_transform")
coordinates_aligned_to_atlas = elx.transform_points(
        coordinates_aligned_to_488_working, sink=None,
        transform_directory=elastix_inverse_dir_working_atlas,
        temp_file='/tmp/elastix_input_pipeline.bin',
        result_directory='/tmp/elastix_output_pipeline')

In [None]:
# Load the Princeton Mouse Atlas tissue volume
pma_file = '/jukebox/LightSheetTransfer/atlas/sagittal_atlas_20um_iso.tif'
pma_vol = tifffile.imread(pma_file)

In [None]:
# Verify that cells are still in the right locations 
zplane=150
zplane_depth = 1
minplane = max(zplane-zplane_depth,0)
maxplane = zplane+zplane_depth
zplane_range = range(minplane,maxplane)
this_plane_coords_aligned_to_atlas = np.array(
    [coord for coord in coordinates_aligned_to_atlas if int(coord[0]) in zplane_range])
ys = this_plane_coords_aligned_to_atlas[:,1]
zs = this_plane_coords_aligned_to_atlas[:,2]

fig,axes = plt.subplots(figsize=(15,8),nrows=1,ncols=2,sharex=True,sharey=True)
ax_tissue = axes[0]
atlas_z_plane = pma_vol[:,:,zplane]
# atlas_z_plane_fixaxes = np.swapaxes(atlas_z_plane,0,1)
ax_tissue.imshow(atlas_z_plane,vmin=0,vmax=500,cmap='viridis')
ax_both=axes[1]
ax_both.imshow(atlas_z_plane,vmin=0,vmax=500,cmap='viridis')
ax_both.scatter(ys,zs,s=50,facecolors='none',edgecolors='r')
ax_tissue.set_xlim(0,atlas_z_plane.shape[1])
ax_tissue.set_ylim(0,atlas_z_plane.shape[0])
plt.suptitle(f"Points from {faulty_brain} aligned \nvia {working_brain} transforms to atlas",
            fontsize=24)
plt.tight_layout()
savename = f'./{faulty_brain}_points_aligned_via_{working_brain}_to_atlas.png'
plt.savefig(savename)
print(f"Saved {savename}")

This still looks good. Some points fall outside the atlas volume, but they are removed when the points are saved below (cells with region ID 0 are dropped). The points are ready to save. 
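The region lookup in the next cell runs over the cells one at a time. A vectorized sketch of the same idea is below: truncate coordinates to integers, keep only points that fall inside the annotation volume, and read their labels with fancy indexing. It follows the notebook's indexing convention `vol[col2, col1, col0]`; `lookup_regions` and the toy volume are hypothetical. Checking bounds explicitly matters because NumPy treats negative indices as counting from the end of an axis, so an out-of-volume point with a negative coordinate would not raise an error.

```python
import numpy as np

def lookup_regions(points, annotation):
    """Return the annotation label at each point, 0 for points outside the volume.

    Points are (N, 3) with columns (c0, c1, c2); the volume is indexed as
    annotation[c2, c1, c0], matching the notebook's atlas lookup.
    """
    idx = np.trunc(np.asarray(points, dtype=float)).astype(int)
    c0, c1, c2 = idx[:, 0], idx[:, 1], idx[:, 2]
    inside = ((c2 >= 0) & (c2 < annotation.shape[0]) &
              (c1 >= 0) & (c1 < annotation.shape[1]) &
              (c0 >= 0) & (c0 < annotation.shape[2]))
    regions = np.zeros(len(idx), dtype=annotation.dtype)
    regions[inside] = annotation[c2[inside], c1[inside], c0[inside]]
    return regions

annotation = np.zeros((4, 5, 6), dtype=np.uint16)
annotation[1, 2, 3] = 42
pts = np.array([[3.7, 2.1, 1.0],    # -> annotation[1, 2, 3] = 42
                [-1.0, 2.0, 1.0],   # negative coordinate: outside, label 0
                [0.0, 0.0, 10.0]])  # beyond axis 0: outside, label 0
print(lookup_regions(pts, annotation))  # [42  0  0]
```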

In [None]:
# Per-cell size and background values (saved below as the 'size' and 'intensity' columns)
size_intensity = np.hstack([cells_source_faulty[c][:,None] for c in ['size','background']])

# Load atlas files
eroded_atlas_file = '/jukebox/LightSheetTransfer/atlas/annotation_sagittal_atlas_20um_16bit_hierarch_labels_60um_edge_80um_vent_erosion.tif'
segment_props_file = '/jukebox/LightSheetTransfer/atlas/PMA_16bit_hierarch_labels_segment_properties_info'
ontology_json_file = '/jukebox/LightSheetTransfer/atlas/PMA_ontology.json'

eroded_atlas_vol = np.array(tifffile.imread(eroded_atlas_file)).astype('uint16')
atlas_segments = np.unique(eroded_atlas_vol)
atlas_segments = np.array([x for x in atlas_segments if x!=0])

with open(segment_props_file,'r') as infile:
    segment_props_dict = json.load(infile)

with open(ontology_json_file,'r') as infile:
    ontology_dict = json.load(infile)

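# Record the brain region ID where each cell is detected, 0 if not in a region.
# Bounds are checked explicitly: NumPy treats negative indices as counting from
# the end of an axis, so out-of-volume points would not raise an IndexError.
cell_regions = np.zeros([len(coordinates_aligned_to_atlas), 1], dtype=int)
xyz = np.asarray(coordinates_aligned_to_atlas).astype(int)  # truncates like int()
for idx, (c0, c1, c2) in enumerate(xyz):
    if (0 <= c2 < eroded_atlas_vol.shape[0] and 0 <= c1 < eroded_atlas_vol.shape[1]
            and 0 <= c0 < eroded_atlas_vol.shape[2]):
        cell_regions[idx] = eroded_atlas_vol[c2, c1, c0]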

# Add brain region ID to transformed cell array 
cells_to_save = np.hstack((coordinates_aligned_to_atlas,size_intensity,cell_regions))
header = ['x','y','z','size','intensity','region']
dtypes = [int, int, int, int, float, int]
dt = {'names' : header, 'formats' : dtypes}
output_array = np.zeros(len(cells_to_save), dtype=dt)
for i,h in enumerate(header):
    output_array[h] = cells_to_save[:,i]
# Remove cells that are outside the atlas (boolean mask over the 1-D structured array)
output_array = output_array[cell_regions[:, 0] != 0]

# Save registered cells to cells_transformed_to_atlas.npy
savename = ws_faulty.filename('cells',postfix='transformed_to_atlas')
io.write(savename,output_array)
print(f'Saved registered cell detection results to: {savename}')
print()
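The saved file is a structured array with named fields. Assuming `io.write` writes a standard `.npy` file for this extension, it can be read back with `np.load`; the round trip below uses a toy array with the same field layout (the values are invented) and a temporary file.

```python
import os
import tempfile
import numpy as np

header = ['x', 'y', 'z', 'size', 'intensity', 'region']
dtypes = [int, int, int, int, float, int]
dt = {'names': header, 'formats': dtypes}

# Toy rows standing in for cells_to_save (invented values)
rows = np.array([[10.2, 20.9, 30.0, 5, 101.5, 42],
                 [11.0, 21.0, 31.0, 7, 98.0, 0]])
arr = np.zeros(len(rows), dtype=dt)
for i, h in enumerate(header):
    arr[h] = rows[:, i]  # float values stored in int fields are truncated toward zero

# Keep only cells assigned to a region
arr = arr[arr['region'] != 0]

path = os.path.join(tempfile.mkdtemp(), 'cells_transformed_to_atlas.npy')
np.save(path, arr)
loaded = np.load(path)
print(loaded.dtype.names)  # ('x', 'y', 'z', 'size', 'intensity', 'region')
print(loaded['region'])    # [42]
```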