# [Aligning Big Brains and Atlases](https://biop.github.io/ijp-imagetoatlas/) in Python

This notebook demonstrates how to use ABBA from Python. In particular, it shows how to use the [BrainGlobe API](https://github.com/brainglobe) with ABBA.

Technically, this notebook relies on [JPype](https://github.com/jpype-project/jpype) and [PyImageJ](https://github.com/imagej/pyimagej) to bridge Python and Java. Surprisingly, there is almost no loss of functionality.

To run this notebook, an atlas and a sample dataset need to be downloaded.

Note that ABBA is designed for aligning serial sections. To register a full 3D dataset, have a look at [brainreg](https://github.com/brainglobe/brainreg) instead.


## What happens in this notebook?

We start by downloading an atlas from BrainGlobe, then start ABBA, download a few example serial sections, and register them.


In [None]:
# Import the BrainGlobe atlas API and list the available atlases

from bg_atlasapi import show_atlases
import numpy as np
show_atlases()

In [None]:
# Download the atlas of your choice from the list above; here, the Allen mouse atlas at 25 micron resolution

from bg_atlasapi.bg_atlas import BrainGlobeAtlas
atlas = BrainGlobeAtlas("allen_mouse_25um")
print(atlas.atlas_name)

In [None]:
# If you want to see the reference volume in napari (try the 3D viewer, it's awesome)
# import napari
# viewer = napari.view_image(atlas.reference)

In [None]:
# Dependencies needed to start PyImageJ; some of them may be resolved automatically as transitive dependencies (not verified)

imagej_core_dep = 'net.imagej:imagej:2.3.0'
imagej_legacy_dep = 'net.imagej:imagej-legacy:0.38.1'
abba_dep = 'ch.epfl.biop:ImageToAtlasRegister:0.2.4'

deps_pack = [imagej_core_dep, imagej_legacy_dep, abba_dep]

In [None]:
# Start ImageJ and show its UI
import imagej
ij = imagej.init(deps_pack, headless=False)
ij.ui().showUI()

In [None]:
# Importing necessary classes from Java for the next cell, the hard one

from scyjava import jimport

import jpype
import jpype.imports
from jpype.types import *
from jpype import JImplements, JOverride

BdvFunctions = jimport('bdv.util.BdvFunctions')
Atlas = jimport('ch.epfl.biop.atlas.struct.Atlas')
AtlasMap = jimport('ch.epfl.biop.atlas.struct.AtlasMap')
AtlasOntology = jimport('ch.epfl.biop.atlas.struct.AtlasOntology')
AtlasNode = jimport('ch.epfl.biop.atlas.struct.AtlasNode')
AtlasHelper = jimport('ch.epfl.biop.atlas.struct.AtlasHelper')
AffineTransform3D = jimport('net.imglib2.realtransform.AffineTransform3D')
ArrayList = jimport('java.util.ArrayList')
BdvOptions = jimport('bdv.util.BdvOptions')

The cell below translates the internal structure of the BrainGlobe API into the one used by ABBA. The two are very similar: in both, an atlas combines a label image, structural images, and an ontology that goes from the root to the leaves.

JPype allows Java interfaces to be implemented with Python objects.

The four interfaces that need to be implemented on top of the BrainGlobe API are:
* [Atlas](https://github.com/BIOP/ijp-atlas/blob/ebd31918ab6f4e52b46d5ab14a08e7d826602ed8/src/main/java/ch/epfl/biop/atlas/struct/Atlas.java)
* [AtlasMap](https://github.com/BIOP/ijp-atlas/blob/ebd31918ab6f4e52b46d5ab14a08e7d826602ed8/src/main/java/ch/epfl/biop/atlas/struct/AtlasMap.java)
* [AtlasOntology](https://github.com/BIOP/ijp-atlas/blob/ebd31918ab6f4e52b46d5ab14a08e7d826602ed8/src/main/java/ch/epfl/biop/atlas/struct/AtlasOntology.java)
* [AtlasNode](https://github.com/BIOP/ijp-atlas/blob/ebd31918ab6f4e52b46d5ab14a08e7d826602ed8/src/main/java/ch/epfl/biop/atlas/struct/AtlasNode.java)
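
To make the "ontology from the root to the leaves" idea concrete, here is a minimal pure-Python sketch of a tree of nodes indexed by id. The `ToyNode` class and `build_id_to_node_map` function are hypothetical stand-ins written for illustration; they are not the ABBA `AtlasNode` interface or the `AtlasHelper` implementation, and the ids and acronyms are made up.

```python
class ToyNode:
    """Hypothetical stand-in for an ontology node (not the ABBA AtlasNode)."""

    def __init__(self, node_id, acronym):
        self.id = node_id
        self.acronym = acronym
        self.parent = None
        self.children = []

    def add_child(self, child):
        # Link the child to this node in both directions
        child.parent = self
        self.children.append(child)
        return child


def build_id_to_node_map(root):
    """Walk the tree from the root to the leaves and index every node by id."""
    mapping = {}
    stack = [root]
    while stack:
        node = stack.pop()
        mapping[node.id] = node
        stack.extend(node.children)
    return mapping


# Toy ontology; ids and acronyms are illustrative only
root = ToyNode(1, "root")
grey = root.add_child(ToyNode(2, "grey"))
cortex = grey.add_child(ToyNode(3, "CTX"))
fibers = root.add_child(ToyNode(4, "fiber tracts"))

id_to_node = build_id_to_node_map(root)
print(sorted(id_to_node))
print(id_to_node[3].parent.acronym)
```

An id-to-node map like this is what makes a lookup by label value cheap once the tree has been built.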

In [None]:
# The core of this compatibility layer : 
# - translates BrainGlobe API to ch.epfl.biop.atlas.struct interfaces through JPype

@JImplements(AtlasMap)
class BrainGlobeMap(object):
    
    def __init__(self, bg_atlas):
        # The heavy setup lives in initialize() rather than here: when it was
        # placed in __init__, it ended up being called far too many times
        # (for a reason that has not been identified)
        self.atlas = bg_atlas
    
    @JOverride
    def setDataSource(self, dataSource):
        self.dataSource = dataSource

    @JOverride
    def initialize(self, atlasName):
        self.atlasName = str(atlasName)
        
        atlas_resolution_in__mm = JDouble(min(self.atlas.metadata['resolution'])/1000.0)
        
        vox_x_mm = self.atlas.metadata['resolution'][0] / 1000.0
        vox_y_mm = self.atlas.metadata['resolution'][1] / 1000.0
        vox_z_mm = self.atlas.metadata['resolution'][2] / 1000.0
        
        affine_transform = AffineTransform3D()
        affine_transform.scale(JDouble(vox_x_mm), JDouble(vox_y_mm), JDouble(vox_z_mm))
        
        # Convert each numpy volume to a BDV source, then close the temporary BDV window
        bss = BdvFunctions.show(ij.py.to_java(self.atlas.reference),JString(self.atlas.atlas_name+'_reference'), BdvOptions.options().sourceTransform(affine_transform))
        reference_sac = bss.getSources().get(0)
        bss.getBdvHandle().close()
        
        bss = BdvFunctions.show(ij.py.to_java(self.atlas.hemispheres),JString(self.atlas.atlas_name+'_hemispheres'), BdvOptions.options().sourceTransform(affine_transform))
        left_right_sac = bss.getSources().get(0)
        bss.getBdvHandle().close()
        
        bss = BdvFunctions.show(ij.py.to_java(self.atlas.annotation),JString(self.atlas.atlas_name+'_annotation'), BdvOptions.options().sourceTransform(affine_transform))
        self.annotation_sac = bss.getSources().get(0)
        bss.getBdvHandle().close()
        
        image_keys = ArrayList()
        image_keys.add(JString('reference'))
        image_keys.add(JString('X'))
        image_keys.add(JString('Y'))
        image_keys.add(JString('Z'))
        image_keys.add(JString('Left Right'))
        
        structural_images = {
            'reference':  reference_sac,
            'X': AtlasHelper.getCoordinateSac(0,JString('X')),
            'Y': AtlasHelper.getCoordinateSac(1,JString('Y')),
            'Z': AtlasHelper.getCoordinateSac(2,JString('Z')),
            'Left Right': left_right_sac
        } #return Map<String,SourceAndConverter>
        
        self.atlas_resolution_in__mm = atlas_resolution_in__mm
        self.affine_transform = affine_transform
        self.image_keys = image_keys
        self.structural_images = structural_images
        self.maxReference = JDouble(np.max(self.atlas.reference)*2)

    @JOverride
    def getDataSource(self):
        return self.dataSource #return URL
        
    @JOverride
    def getStructuralImages(self):
        return self.structural_images
    
    @JOverride
    def getImagesKeys(self):
        return self.image_keys

    @JOverride
    def getLabelImage(self):
        return self.annotation_sac #SourceAndConverter

    @JOverride
    def getAtlasPrecisionInMillimeter(self):
        return self.atlas_resolution_in__mm

    @JOverride
    def getCoronalTransform(self):
        return AffineTransform3D()
    
    @JOverride
    def getImageMax(self, key):
         return self.maxReference #double
    
    @JOverride        
    def labelRight(self):
        return JInt(1)

    @JOverride
    def labelLeft(self):
        return JInt(2)
        
@JImplements(AtlasOntology)
class BrainGlobeOntology(object):
    
    def __init__(self, bg_atlas):
        self.atlas = bg_atlas
        
    @JOverride
    def getName(self):
        return JString(self.atlas.atlas_name)
    
    @JOverride
    def initialize(self):
        self.root_node = BrainGlobeAtlasNode(self.atlas, self.atlas.structures.tree.root, None)
        self.idToAtlasNodeMap = AtlasHelper.buildIdToAtlasNodeMap(self.root_node)

    @JOverride
    def setDataSource(self, dataSource):
        self.dataSource = dataSource

    @JOverride
    def getDataSource(self):
        return self.dataSource #return URL

    @JOverride
    def getRoot(self):  
        return self.root_node #return AtlasNode
    
    @JOverride
    def getNodeFromId(self, index):
        return self.idToAtlasNodeMap.get(index) #return AtlasNode 
    
    @JOverride
    def getNamingProperty(self):
        return self.namingProperty
    
    @JOverride
    def setNamingProperty(self, namingProperty):
        self.namingProperty = namingProperty
    
@JImplements(AtlasNode)
class BrainGlobeAtlasNode(object):

    def __init__(self, bg_atlas, index, parent_node):
        self.atlas = bg_atlas
        self.id = index
        self.parent_node = parent_node
        children_nodes = []
        for child in bg_atlas.structures.tree.children(index):
            childNode = BrainGlobeAtlasNode(bg_atlas, child.identifier, self)
            children_nodes.append(childNode)
        self.children_nodes = ArrayList(children_nodes)
        self.namingKey = JString('acronym')
    
    @JOverride
    def getId(self):
        return JInt(self.id)
    
    @JOverride
    def getColor(self):
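        # Fill the array from the structure's RGB triplet
        # (assumption: the 4th component is alpha, set here to fully opaque)
        val = JInt[4]
        rgb = self.atlas.structures[self.id]['rgb_triplet']
        val[0] = int(rgb[0])
        val[1] = int(rgb[1])
        val[2] = int(rgb[2])
        val[3] = 255
        return val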

    @JOverride
    def data(self):
        dict_ori = self.atlas.structures[self.id]
        string_dict = {}
        for key in dict_ori.keys():
            try:
                string_dict[key] = JString(str(dict_ori[key]))
            except Exception:
                pass   
        return string_dict  # returning self.atlas.structures[self.id] directly caused an issue with the Java Map conversion
    
    @JOverride
    def parent(self):
        return self.parent_node

    @JOverride
    def children(self):
        return self.children_nodes
    
    @JOverride
    def toString(self):
        return self.data().get(self.namingKey)

    
@JImplements(Atlas)
class BrainGlobeAtlas(object):

    def __init__(self, bg_atlas):
        self.atlas = bg_atlas
        
    @JOverride
    def getMap(self):
        return self.bg_atlasmap
        
    @JOverride
    def getOntology(self):
        return self.bg_ontology
    
    @JOverride
    def initialize(self, mapURL, ontologyURL):
        self.bg_ontology = BrainGlobeOntology(self.atlas)
        self.bg_ontology.initialize()
        self.bg_ontology.setNamingProperty(JString('acronym'))
        self.bg_atlasmap = BrainGlobeMap(self.atlas)
        self.bg_atlasmap.initialize(self.atlas.atlas_name)
        self.dois = ArrayList()
        self.dois.add(JString('doi1')) #TODO
        self.dois.add(JString('doi2'))
        
    @JOverride
    def getDOIs(self):
        return self.dois

    @JOverride
    def getURL(self):
        return JString('BrainGlobe Atlas URL...')  
    
    @JOverride
    def getName(self):
        return JString(self.atlas.atlas_name)
    
    @JOverride
    def toString(self):
        return self.getName()
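
The scaling done in `BrainGlobeMap.initialize` can be checked without Java. The sketch below assumes, as the code above does, that the atlas resolution is given in micrometres per voxel along each axis. It builds the equivalent 4x4 scaling matrix in millimetres and takes the smallest voxel size as the atlas precision. The function names are hypothetical helpers for illustration, not part of ABBA or BrainGlobe.

```python
def voxel_scale_matrix(resolution_um):
    """Return (4x4 scaling matrix in mm per voxel, atlas precision in mm)."""
    scales_mm = [r / 1000.0 for r in resolution_um]  # micrometres -> millimetres
    matrix = [[0.0] * 4 for _ in range(4)]
    for i, scale in enumerate(scales_mm):
        matrix[i][i] = scale
    matrix[3][3] = 1.0  # homogeneous coordinate
    return matrix, min(scales_mm)


def voxel_to_mm(matrix, voxel):
    """Apply the scaling matrix to a voxel index (x, y, z)."""
    x, y, z = voxel
    homogeneous = (x, y, z, 1.0)
    return tuple(
        sum(matrix[row][col] * homogeneous[col] for col in range(4))
        for row in range(3)
    )


# Example with an isotropic 25 micron atlas
matrix, precision_mm = voxel_scale_matrix([25.0, 25.0, 25.0])
print(precision_mm)
print(voxel_to_mm(matrix, (40, 80, 120)))
```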

In [None]:
# Makes the atlas object
convertedAtlas = BrainGlobeAtlas(atlas)
convertedAtlas.initialize(None, None)

In [None]:
# Puts it in the scijava ObjectService for automatic discovery if necessary
ij.object().addObject(convertedAtlas)

In [None]:
# Starts ABBA

# But first, limit the logging level to INFO so the logger stays quiet
DebugTools = jimport('loci.common.DebugTools')
DebugTools.enableLogging('INFO')

# Ok, let's start ABBA and its BDV view (it's also possible to start it without any GUI, 
# or even to build another GUI with a Napari view, why not ?)

ABBABdvStartCommand = jimport('ch.epfl.biop.atlas.aligner.gui.bdv.ABBABdvStartCommand') # Command import
ij.command().run(ABBABdvStartCommand, True, 'ba', convertedAtlas, 'slicing_mode', 'coronal') # Starts it with the converted brainglobe atlas in the coronal orientation
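
Commands are run by passing parameter names and values as alternating positional arguments (`'ba', convertedAtlas, 'slicing_mode', 'coronal'`). For commands with many parameters, it can be more readable to keep them in a dictionary and flatten it just before the call. The helper below is a hypothetical convenience written for this notebook, not part of PyImageJ or ABBA.

```python
def to_command_args(params):
    """Flatten {'name': value, ...} into ['name', value, ...], preserving order."""
    args = []
    for key, value in params.items():
        if not isinstance(key, str):
            raise TypeError(f"parameter names must be strings, got {key!r}")
        args.extend([key, value])
    return args


# Plain Python values are used here so the sketch runs without a JVM
args = to_command_args({"slicing_mode": "coronal", "nb_control_points_x": 12})
print(args)
```

Under that assumption, the call above could be written as `ij.command().run(ABBABdvStartCommand, True, *to_command_args({'ba': convertedAtlas, 'slicing_mode': 'coronal'}))`.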

## Download serial sections examples

Download sections 30, 40, and 50 from the Zenodo repository: https://zenodo.org/record/4715656#.Ybe-8Fko_iE (around 100 MB per section).

Files are saved in the `images` folder of the current working directory. If a file has already been downloaded, its download is skipped.

In [None]:
import os
from bg_atlasapi import utils
from pathlib import Path
cwd = os.getcwd() # gets current path

utils.check_internet_connection()
base_zenodo_url = 'https://zenodo.org/record/4715656/'

basePath = cwd+'/images/'
os.makedirs(basePath, exist_ok=True) # make sure the output folder exists before downloading

def downloadIfNecessary(section_name):
    outputPath = Path(basePath+section_name)
    if not outputPath.exists(): # skip files that are already present
        url = base_zenodo_url+'files/'+section_name+'?download=1'
        utils.retrieve_over_http(url, outputPath)
    
downloadIfNecessary('S30.ome.tiff') #https://zenodo.org/record/4715656/files/S30.ome.tiff?download=1
downloadIfNecessary('S40.ome.tiff') #https://zenodo.org/record/4715656/files/S40.ome.tiff?download=1
downloadIfNecessary('S50.ome.tiff') #https://zenodo.org/record/4715656/files/S50.ome.tiff?download=1
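
The skip-if-present logic can be exercised without network access by injecting the download function. This is a minimal sketch: `download_if_necessary` and `fake_fetch` are hypothetical names, and `fake_fetch` simply writes a placeholder file where the real cell calls `utils.retrieve_over_http`.

```python
import tempfile
from pathlib import Path


def download_if_necessary(section_name, base_dir, fetch):
    """Fetch a section unless it already exists; return True if a download happened."""
    output_path = Path(base_dir) / section_name
    if output_path.exists():
        return False
    output_path.parent.mkdir(parents=True, exist_ok=True)
    url = f"https://zenodo.org/record/4715656/files/{section_name}?download=1"
    fetch(url, output_path)
    return True


calls = []


def fake_fetch(url, output_path):
    # Stand-in for a real HTTP download: record the URL and write a small file
    calls.append(url)
    Path(output_path).write_bytes(b"placeholder")


with tempfile.TemporaryDirectory() as tmp:
    images_dir = Path(tmp) / "images"
    first = download_if_necessary("S30.ome.tiff", images_dir, fake_fetch)
    second = download_if_necessary("S30.ome.tiff", images_dir, fake_fetch)

print(first, second, len(calls))
```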


In [None]:
# Let's get the multipositioner object 
MultiSlicePositioner = jimport('ch.epfl.biop.atlas.aligner.MultiSlicePositioner')

# There's only one multipositioner instance in the object service
# https://javadoc.scijava.org/SciJava/org/scijava/object/ObjectService.html
mp = ij.object().getObjects(MultiSlicePositioner).get(0)


In [None]:
# Let's import the files using Bio-Formats.
# The list of all commands is accessible here:
# https://github.com/BIOP/ijp-imagetoatlas/tree/master/src/main/java/ch/epfl/biop/atlas/aligner/command

ImportImageCommand = jimport('ch.epfl.biop.atlas.aligner.command.ImportImageCommand')

# Here we want to import images: check
# https://github.com/BIOP/ijp-imagetoatlas/blob/master/src/main/java/ch/epfl/biop/atlas/aligner/command/ImportImageCommand.java

File = jimport('java.io.File')

file_s30 = File(basePath+'S30.ome.tiff')
file_s40 = File(basePath+'S40.ome.tiff')
file_s50 = File(basePath+'S50.ome.tiff')

FileArray = JArray(File)
files = FileArray(3)

files[0] = file_s30
files[1] = file_s40
files[2] = file_s50

# Any missing input parameter will open a popup window asking the user for it
ij.command().run(ImportImageCommand, True,\
                 "files", files,\
                 "mp", mp,\
                 "split_rgb_channels", False,\
                 "slice_axis_initial", 5.0,\
                 "increment_between_slices", 0.04\
                )
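
As a rough guide to what `slice_axis_initial` and `increment_between_slices` mean, the sketch below assumes that imported slices are placed along the slicing axis at regular intervals, starting from the initial position, with both values in millimetres. This is an assumption about the import behaviour, and `initial_slice_positions` is a hypothetical helper, not an ABBA function.

```python
def initial_slice_positions(n_slices, slice_axis_initial_mm, increment_mm):
    """Positions along the slicing axis, assuming regular spacing between slices."""
    return [slice_axis_initial_mm + i * increment_mm for i in range(n_slices)]


# Same values as in the import command: 3 files, starting at 5.0 mm, 0.04 mm apart
positions = initial_slice_positions(3, 5.0, 0.04)
print([round(p, 3) for p in positions])
```

These starting positions are only approximate; each slice can still be moved individually afterwards.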



In [None]:
mp.selectSlice(mp.getSlices()) # select all slices

In [None]:
mp.getReslicedAtlas().setRotateY(0.05) # Small correction in Y slicing

In [None]:
mp.deselectSlice(mp.getSlices()) # deselect all

In [None]:
mp.selectSlice(mp.getSlices().get(2)) # select the last slice

In [None]:
# Gets the bigdataviewer view. First let's get the class
BdvMultislicePositionerView = jimport('ch.epfl.biop.atlas.aligner.gui.bdv.BdvMultislicePositionerView')

In [None]:
# view = ij.object().getObjects(BdvMultislicePositionerView).get(0) # Only one BigDataViewer view
# TODO : use fix in newer version to access the view through the object service

In [None]:
# Slices are always sorted by increasing z. To keep track of which is which, store references to them before moving them
slice30 = mp.getSlices().get(0) 
slice40 = mp.getSlices().get(1)
slice50 = mp.getSlices().get(2)
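
Why hold on to references rather than indices? Because the list is re-sorted by position, an index may point to a different slice after a move. The toy model below illustrates this in plain Python. `ToySlice` and `ToyPositioner` are hypothetical classes that only mimic the sorting behaviour described above; they are not ABBA's `MultiSlicePositioner`.

```python
class ToySlice:
    def __init__(self, name, z_mm):
        self.name = name
        self.z_mm = z_mm


class ToyPositioner:
    """Keeps slices and always returns them sorted by increasing z."""

    def __init__(self, slices):
        self._slices = list(slices)

    def get_slices(self):
        return sorted(self._slices, key=lambda s: s.z_mm)

    def move_slice(self, slice_, z_mm):
        slice_.z_mm = z_mm


toy = ToyPositioner([ToySlice("S30", 5.0), ToySlice("S40", 5.04), ToySlice("S50", 5.08)])

slice30 = toy.get_slices()[0]   # keep a reference before moving anything
toy.move_slice(slice30, 9.5)    # S30 now sits behind the two other slices

order_after_move = [s.name for s in toy.get_slices()]
print(order_after_move)         # index 0 no longer refers to S30
print(slice30.name)             # the stored reference still does
```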

In [None]:
mp.moveSlice(slice50,9.5)

In [None]:
mp.moveSlice(slice40,8.2)

In [None]:
mp.moveSlice(slice30,7.5)

In [None]:
# Simple actions are accessible through mp.whatever, but most actions are executed on selected slices
# Almost all actions are executed asynchronously

# For a registration : let's select all slices
mp.selectSlice(mp.getSlices()) # select all

In [None]:
# Let's run an affine registration on the green slice channel and on the reference atlas channel
# elastix needs to be set up first, see https://biop.github.io/ijp-imagetoatlas/installation.html
RegistrationElastixAffineCommand = jimport('ch.epfl.biop.atlas.aligner.command.RegistrationElastixAffineCommand')

ij.command().run(RegistrationElastixAffineCommand, True,
                 "mp", mp,\
                 "pixel_size_micrometer", 40,\
                 "show_imageplus_registration_result", False,\
                 "background_offset_value_moving",0,\
                 "atlas_image_channel",0,\
                 "slice_image_channel",1) # second channel, 0-based


In [None]:
# Now let's run a spline registration
RegistrationElastixSplineCommand = jimport('ch.epfl.biop.atlas.aligner.command.RegistrationElastixSplineCommand')

ij.command().run(RegistrationElastixSplineCommand, True,
                 "mp", mp,\
                 "nb_control_points_x", 12,\
                 "pixel_size_micrometer", 20,\
                 "show_imageplus_registration_result", False,\
                 "background_offset_value_moving",0,\
                 "atlas_image_channel",0,\
                 "slice_image_channel",1) # second channel, 0-based


In [None]:
# Let's wait for all registrations to finish
mp.waitForTasks()
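
Since most actions are executed asynchronously, a call like the registration commands above returns before the work is done, which is why an explicit wait is needed. The pure-Python analogy below uses `concurrent.futures` to show the same submit-then-wait pattern. It only illustrates the idea and says nothing about how ABBA schedules its tasks internally; `fake_registration` is a hypothetical placeholder.

```python
import time
from concurrent.futures import ThreadPoolExecutor, wait


def fake_registration(slice_name):
    # Placeholder for a long-running registration
    time.sleep(0.05)
    return f"{slice_name} registered"


with ThreadPoolExecutor(max_workers=2) as pool:
    # submit() returns immediately; the work continues in the background
    futures = [pool.submit(fake_registration, name) for name in ("S30", "S40", "S50")]

    # Block until every task has finished, similar in spirit to mp.waitForTasks()
    done, not_done = wait(futures)
    results = [future.result() for future in futures]

print(results)
```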