# An automated pipeline to create an atlas of in-situ hybridization gene expression data in the adult marmoset brain (ISBI 2023, Poon, C., et al.)

## 3D registration code

In [None]:
import os
import sys
sys.path.append("/disk/soft/SLURM/slurm/")
from slurm import slurm_tools
import numpy as np
import nibabel as nib
import json
import matplotlib.pyplot as plt
import glob
from PIL import Image
import cv2
from skimage.transform import resize

%load_ext autoreload
%autoreload 2
from IPython.display import clear_output
import matplotlib as mpl
dpi_default = mpl.rcParams['figure.dpi']

def json_read(filename):
    """Load a JSON file if it exists.

    Returns a ``(data, isfile)`` tuple: the parsed JSON content and ``True``
    when *filename* is an existing regular file, otherwise an empty dict and
    ``False``.
    """
    # Guard clause: anything that is not a regular file yields the empty result.
    if not os.path.isfile(filename):
        return {}, False

    with open(filename) as handle:
        return json.load(handle), True

### 1. Align gene images by each RGB channel to the last 2D transform, save as niftis

In [None]:
## 1. aligns the RGB gene niis wrt the last 2D transform, then outputs each aligned RGB image

def align_save_gene_rgbs(all_cmp, mid):
    """Submit SLURM jobs that apply each gene slice's 2D transform to its R/G/B images.

    For every entry of *all_cmp* whose name contains ``"gene"``, the original
    slices, the already aligned slices and the ``*InverseComposite.h5``
    transforms are globbed, sorted and paired by position. For each pair and
    each colour channel one ``apply_2D.sh`` job is submitted; the function
    waits for the submitted jobs after every slice.

    Parameters
    ----------
    all_cmp : iterable of str
        Compartment (sub-directory) names; only those containing "gene" are used.
    mid : str
        Marmoset ID, used as a path component.

    Notes
    -----
    Relies on the module-level names ``ch_workdir``, ``workdir`` and
    ``slurm_tools``.
    """
    print('------in 1 align_save_gene_rgbs')
    cores = "2"

    for gene in all_cmp:
        if "gene" not in gene:
            continue

        print('gene:',gene)

        # NOTE(review): inputs are globbed under `ch_workdir` while the output
        # folder below is created under `workdir` -- confirm both point to the
        # same tree.

        # original niftis
        files = sorted(glob.glob(ch_workdir+"/"+mid+"/"+gene+"/img_*.nii.gz"))

        # aligned genes
        files2 = sorted(glob.glob(ch_workdir+"/"+mid+"/out/"+gene+"/2/img/img_*.nii.gz"))

        # inverse transforms
        files3 = sorted(glob.glob(ch_workdir+"/"+mid+"/out/"+gene+"/2/trafo/*InverseComposite.h5"))

        # job ids are tracked per gene
        jobids = []

        os.makedirs(workdir+mid+"/out/"+gene+"/2/img/",exist_ok=True)

        # The three sorted lists are paired by position; zip() stops at the
        # shortest one, so any surplus files are ignored.
        for indx,(f1,f2,f3) in enumerate(zip(files,files2,files3)):

            for r in ["R","G","B"]:
                print(r)
                move = f1.replace("img_","img"+r+"_")  # original channel image
                ref = f2                               # aligned slice used as reference
                out = f2.replace("img_","img"+r+"_")   # does not yet exist

                logfolder = "/".join(out.split("/")[:-1])+"/log/"
                os.makedirs(logfolder,exist_ok=True)

                # transform path with the "InverseComposite.h5" suffix removed
                POSTTRAFO = f3.replace("InverseComposite.h5","")

                SLURM_commands = ["MOVE="+move+" REF="+ref+" CORES="+cores+" OUT1="+out+"  POSTTRAFO="+POSTTRAFO+" bash apply_2D.sh;"]

                # One log file per slice AND channel: previously the R, G and B
                # jobs of a slice all used the same log path and overwrote each
                # other's output.
                jobid, success = slurm_tools.slurm_submit(SLURM_commands,
                    name = "reg3D",
                    output = logfolder+'/log_2D'+str(indx)+'_'+r+'.out',
                    mem = '8GB',
                    cores = cores,
                    partition ="ish-adult", #"bigmem",
                    )

                if not success:
                    print("could not submit jobs")
                    print(format(jobid))
                    print(format(jobids))
                    slurm_tools.killall(jobids)
                # NOTE(review): execution continues after a failed submission;
                # presumably `jobid` is still convertible to int -- verify
                # against slurm_tools.
                jobids.append(int(jobid))

            # wait after every slice, so at most three jobs are pending at a time
            slurm_tools.wait_for_jobs(jobids)
            clear_output(wait=True)
            print('in 1 align_save_gene_rgbs:'+gene)

### 2. Make 3D sagittal stacks of each gene in each RGB channel, with padding in Z-plane

In [None]:
## 2. make 3D sagittal stacks of each gene (R,G,B), with z-margin (literal z, not z wrt sagittal)
# ie the directionality of the slices will be the same as the template backlit image
# prep for registration

def make_sagittal_stacks(all_cmp,mid):
    """Stack the 2D-registered slices of each backlit/gene compartment into 3D NIfTI volumes.

    For every compartment whose name contains ``"backlit"`` or ``"gene"`` the
    slice images are loaded, clipped to [0, 1], scaled to 8 bit, rotated and
    written in reverse order along axis 1 of a volume that has ``z_margin``
    extra slices on either side. For non-backlit/non-seg compartments the
    margins are filled with copies of the first/last real slice. Each volume
    is saved as ``<il without underscore>.nii.gz`` in the compartment's
    ``3D/`` folder.

    Parameters
    ----------
    all_cmp : iterable of str
        Compartment (sub-directory) names.
    mid : str
        Marmoset ID, used as a path component.

    Notes
    -----
    Relies on the module-level ``workdir``. A channel for which no slice files
    are found is skipped with a message (previously this either raised a
    NameError or re-used the volume of the previous channel/compartment).
    """
    print('------in 2 make_sagittal_stacks')

    z_margin = 20  # padding slices on each side of the stacking axis

    # Diagonal affine with the voxel sizes; identical for every output volume.
    mat = np.identity(4)
    scale = 2.24
    mat[1,1] = 0.6 / scale
    mat[0,0] = 0.1 / scale
    mat[2,2] = 0.1 / scale

    for cmp in all_cmp:

        if not ("backlit" in cmp or "gene" in cmp):
            continue

        # state 10: single grayscale stack (backlit / seg); state 2: gene with RGB channels
        state = 10 if "backlit" in cmp or "seg" in cmp else 2

        print('cmp: {}, state: {}'.format(cmp,state))

        # idir is actually the same dir as the 2D registration
        if "seg" in cmp:
            idir = workdir+mid+"/out_seg/"+str(cmp)+"/"+str(1)+'/img/'
        elif "backlit" in cmp:
            idir = workdir+mid+"/out/"+str(cmp)+"/"+str(10)+'/img/'
        else:
            idir = workdir+mid+"/out/"+str(cmp)+"/"+str(state)+'/img/'

        # backlit/seg: only the single grayscale image, else also R,G,B
        ilist = ["img_"] if state == 10 else  ["img_","imgR_","imgG_","imgB_"]

        for il in ilist:
            i_files = sorted(glob.glob(idir+"/"+il+"*.nii.gz"))

            print('num files: ',len(i_files))
            print('il: ',il)
            print('idir: ',idir)

            n_files = len(i_files)
            if n_files == 0:
                # Nothing to stack: without this guard img3D would be undefined
                # or stale from the previous iteration.
                print('no files found for {} in {}, skipping'.format(il, idir))
                continue

            for fi,fn in enumerate(i_files):
                tmp = nib.load(fn).get_fdata(dtype=np.float32)
                if fi == 0:
                    # volume size is derived from the first slice
                    shape = [tmp.shape[1],n_files+2*z_margin,tmp.shape[0]]

                    img3D = np.zeros(shape,dtype=np.uint8)
                    print('cmp: {}, n_files: {}, shape: {}, img3D: {}, tmp.max: {}'.format(cmp, n_files, shape, img3D.shape, tmp.max()))

                np.clip(tmp,0,1,tmp)  # in-place clip to [0, 1]

                # reverse order, rotate images to be in the same direction as template backlit
                img3D[:,n_files-fi-1+z_margin,:] = np.rot90(tmp*255,k=3)

            if state != 10:  #"backlit" not in cmp:
                # margins must still be empty before they are filled
                assert(img3D[:,:z_margin,:].max()==0)
                assert(img3D[:,z_margin+n_files:,:].max()==0)
                # replicate the first / last real slice into the margins
                img3D[:,:z_margin,:] = img3D[:,z_margin,None,:]
                img3D[:,z_margin+n_files:,:] = img3D[:,z_margin+n_files-1,None,:]

            if "seg" in cmp:
                odir3 = workdir+mid+"/out_seg/"+str(cmp)+"/3D/"
            else:
                odir3 = workdir+mid+"/out/"+str(cmp)+"/3D/"
            os.makedirs(odir3,exist_ok=True)

            print('odir3: {}'.format(odir3))

            new_image = nib.Nifti1Image(img3D, affine=mat)
            # stored as uint8, with a slope of 1/255 set in the header
            new_image.header["cal_min"] = 0
            new_image.header["cal_max"] = 1
            new_image.header["scl_slope"] = 1/255.0
            new_image.header["scl_inter"] = 0
            nib.save(new_image,odir3+il.replace("_","")+".nii.gz")

    clear_output(wait=True)


### 3. Align backlit images (specific to each marmoset) to the BMCA backlit template

In [None]:
## 3. align backlit with reference backlit brain, save as horizontal stack -- img_2_STPT.nii.gz

def align_bl_2_bmcabl(mid):
    """Submit and wait for the SLURM job that registers one marmoset's backlit stack.

    The job runs ``regme_avg3D_01.sh`` with the backlit 3D stack of *mid* as
    moving image and ``new_ref_bl.nii`` as reference; output image, transform
    prefix and log folder all live in the backlit ``3D/`` directory.
    Relies on the module-level ``workdir`` and ``slurm_tools``.
    """
    print('------in 3 align_bl_2_bmcabl:')

    n_cores = "80"
    submitted = []

    # All inputs and outputs of this step sit in the backlit 3D folder.
    stack_dir = f"{workdir}{mid}/out/01_backlit/3D/"
    moving = f"{stack_dir}/img.nii.gz"
    reference = f"{workdir}../ref/new_ref_bl.nii"
    warped = f"{stack_dir}/img_2_STPT.nii.gz"
    trafo_prefix = f"{stack_dir}/trafo_01_"
    log_dir = f"{stack_dir}/log_3D_bl2bmcabl/"

    for folder in (log_dir, stack_dir):
        os.makedirs(folder, exist_ok=True)

    print(f'odir3: {stack_dir}, move: {moving}, ref: {reference}, OUT1: {warped}, POSTTRAFO: {trafo_prefix}, logfolder: {log_dir}')

    # Parameters are handed to the shell script as environment variables.
    command = (
        f"MOVE={moving} REF={reference} CORES={n_cores} OUT1={warped}"
        f"  POSTTRAFO={trafo_prefix}  bash regme_avg3D_01.sh;"
    )

    jobid, success = slurm_tools.slurm_submit(
        [command],
        name="reg3D_bl2bmcabl",
        output=log_dir + '/log_01.out',
        mem='32GB',
        cores=n_cores,
        partition="ish-adult",  # "bigmem"
    )

    print("job number: "+jobid)

    if not success:
        print("could not submit jobs")
        print(format(jobid))
        print(format(submitted))
        slurm_tools.killall(submitted)
    submitted.append(int(jobid))

    # Block until the registration job has finished.
    slurm_tools.wait_for_jobs(submitted)


### 4. Make a backlit template that is larger in the XY plane and smaller in the Z plane

In [None]:
## 4. make backlit template that is larger in xy plane and smaller in z plane: new_ref_25mu.nii 
# only need to make this once

if True:
    # Build new_ref_25mu.nii from the backlit reference template: 4x more
    # samples along axes 0 and 2 and every 8th slice along axis 1.
    # Requires the module-level `workdir` to be defined before this cell runs.
    # The template path is spelled out here because no module-level `ref`
    # exists (it is only a local variable inside the functions above).
    ref_org = nib.load(workdir + "../ref/new_ref_bl.nii")
    ref_org_d = ref_org.get_fdata(dtype=np.float32)
    ref_org_s = ref_org.shape  # (256, 356, 230)
    dummy = np.zeros([ref_org_s[0]*4,ref_org_s[1]//8,ref_org_s[2]*4])  # 1024,44,920, (x,z,y)
    for a in range(ref_org_s[1]//8):  # 44
        # take every 8th slice and upsample it 4x in-plane
        dummy[:,a,:] = resize(ref_org_d[:,a*8,:],[ref_org_s[0]*4,ref_org_s[2]*4])  # larger in xy planes and smaller in z plane

    # rescale so that the maximum maps to 255 before the uint8 cast below
    dummy = 255*dummy / dummy.max()
    affine = ref_org.affine*1  # copy, so the loaded image's affine is not modified
    print(affine)
    # adapt the voxel sizes on the affine diagonal to the new sampling
    affine[0,0]/=4.0
    affine[1,1]*=8.0
    affine[2,2]/=4.0
    print(affine)
    new_image = nib.Nifti1Image(dummy.astype(np.uint8), affine=affine)
    new_image.header["cal_min"] = 0
    new_image.header["cal_max"] = 1
    new_image.header["scl_slope"] = 1/255.0
    new_image.header["scl_inter"] = 0
    nib.save(new_image,workdir + "../ref/new_ref_25mu.nii")  # open in 3D Slicer

### 5. Align each gene by each RGB channel to BMCA-aligned backlit images

In [None]:
## 5. align each gene (move: sagittal stacks in 2) to the template backlit, save as niftis
# output _2_STPT.nii.gz has repeating anterior and posterior slices

def align_gene_2_bmcabl(raw_cmp,seg_cmp,mid):
    """Submit SLURM jobs that apply the backlit 3D transform to the gene stacks.

    For every raw compartment containing ``"_gene"`` one job is submitted that
    runs ``apply_avg3D_01.sh`` for the ``img``, ``imgR``, ``imgG`` and ``imgB``
    stacks; for every segmentation compartment containing ``"_gene_seg"`` one
    job is submitted for its ``img`` stack. All jobs use the transform prefix
    ``trafo_01_`` from the backlit ``3D/`` folder, and the function returns
    only after all submitted jobs have finished.

    Parameters
    ----------
    raw_cmp : iterable of str
        Compartment names found under ``<workdir>/<mid>/out/``.
    seg_cmp : iterable of str
        Compartment names found under ``<workdir>/<mid>/out_seg/``.
    mid : str
        Marmoset ID, used as a path component.

    Notes
    -----
    Relies on the module-level ``workdir`` and ``slurm_tools``.
    """
    print('------in 4 align_gene_2_bmcabl:')

    jobids = []
    cores = "2"

    ref = workdir + "../ref/new_ref_bl.nii"
    odir3 = workdir+mid+"/out/01_backlit/3D/"
    POSTTRAFO = odir3+"/trafo_01_"

    for rcmp in raw_cmp:
        if "_gene" not in rcmp:
            continue

        logfolder = workdir+mid+"/out/"+rcmp+"/log_3D_gene2bmcabl/"
        os.makedirs(logfolder,exist_ok=True)

        # Collect the commands of ALL channels before submitting. The list used
        # to be reset inside the channel loop, so only the last channel (imgB)
        # was ever submitted.
        SLURM_commands = []
        for r in ["img","imgR","imgG","imgB"]:
            move = workdir+mid+"/out/"+rcmp+"/3D/"+r+".nii.gz"         # from sagittal stack (from 2)
            OUT1 = workdir+mid+"/out/"+rcmp+"/3D/"+r+"_2_STPT.nii.gz"
            SLURM_commands += ["MOVE="+move+" REF="+ref+" CORES="+cores+" OUT1="+OUT1+"  POSTTRAFO="+POSTTRAFO+"  bash apply_avg3D_01.sh;"]

        jobid, success = slurm_tools.slurm_submit(SLURM_commands,
            name = "reg3D_rawgene2bmcabl",
            output = logfolder+'/log_01.out',
            mem = '8GB',
            cores = cores,
            partition = "ish-adult", #"bigmem",
            )

        print("job number: "+jobid)

        if not success:
            print("could not submit jobs")
            print(format(jobid))
            print(format(jobids))
            slurm_tools.killall(jobids)
        # NOTE(review): execution continues after a failed submission --
        # verify that `jobid` is still convertible to int in that case.
        jobids.append(int(jobid))

    slurm_tools.wait_for_jobs(jobids)
    clear_output(wait=True)

    for scmp in seg_cmp:
        if "_gene_seg" not in scmp:
            continue

        logfolder = workdir+mid+"/out_seg/"+scmp+"/log_3D_gene2bmcabl/"
        os.makedirs(logfolder,exist_ok=True)

        SLURM_commands = []
        for s in ["img"]:
            move = workdir+mid+"/out_seg/"+scmp+"/3D/"+s+".nii.gz"
            OUT1 = workdir+mid+"/out_seg/"+scmp+"/3D/"+s+"_2_STPT.nii.gz"
            SLURM_commands += ["MOVE="+move+" REF="+ref+" CORES="+cores+" OUT1="+OUT1+"  POSTTRAFO="+POSTTRAFO+"  bash apply_avg3D_01.sh;"]

        jobid, success = slurm_tools.slurm_submit(SLURM_commands,
            name = "reg3D_seggene2bmcabl",
            output = logfolder+'/log_01.out',
            mem = '8GB',
            cores = cores,
            partition = "ish-adult", #"bigmem",
            )

        print("job number: "+jobid)

        if not success:
            print("could not submit jobs")
            print(format(jobid))
            print(format(jobids))
            slurm_tools.killall(jobids)
        jobids.append(int(jobid))

    # Also wait for the segmentation jobs: the next pipeline step reads their
    # *_2_STPT.nii.gz outputs and previously could start before they existed.
    slurm_tools.wait_for_jobs(jobids)

### 6. Save aligned backlit and gene niftis 

In [None]:
## 6. make backlit and gene stacks individually as R/G/B and also the full RGB
# aligned gene stacks are aligned to a lower res backlit template, because the original gene images are lower res
# to be specific, they are 8-bit, and there are fewer z-slices than the template
# so there is no reason to save them as higher bit

def save_bl_gene_niis(raw_cmp,seg_cmp,mid):
    """Mask the template-aligned stacks and export them as 8-bit NIfTI volumes.

    Every ``*_2_STPT.nii.gz`` volume of the selected compartments is loaded,
    the positions along axis 1 where the reference template is zero everywhere
    are set to zero, and the result (scaled by 255, cast to uint8) is written
    twice: next to the input as ``*_2_STPT_8bit.nii.gz`` and into
    ``<workdir>/<mid>/out/3D/low_res/``. Gene compartments additionally get a
    ``<cmp>_RGB.nii.gz`` whose voxels are a structured (R, G, B) uint8 record.

    Parameters
    ----------
    raw_cmp : iterable of str
        Compartment names; those containing "_backlit", or "_gene" without
        "seg", are processed.
    seg_cmp : iterable of str
        Compartment names; those containing "_gene_seg" are processed.
    mid : str
        Marmoset ID, used as a path component.

    Notes
    -----
    Relies on the module-level ``workdir``. The unused load of
    ``new_ref_25mu.nii`` was removed, so that file is no longer required here.
    """
    print('------in 5 save_bl_gene_niis:')
    # NOTE(review): absolute path, whereas the other steps build the template
    # path from `workdir` -- confirm they refer to the same file.
    ref = nib.load('/disk/charissa/ISH_reg_pipeline/ref/new_ref_bl.nii').get_fdata(dtype=np.float32)
    # maximum over axes 0 and 2: zero exactly where the template has no signal along axis 1
    z_mask = ref.max(axis=(0,2))

    all_out_3D = workdir+mid+"/out/3D/low_res/"
    os.makedirs(all_out_3D,exist_ok=True)

    # Data-type with fields R,G,B, each being a uint8
    rgb_dtype = np.dtype([('R', 'u1'), ('G', 'u1'), ('B', 'u1')])

    for rcmp in raw_cmp:
        # backlit compartments, or gene compartments that are not segmentations
        if not ("_backlit" in rcmp or ("_gene" in rcmp and "seg" not in rcmp)):
            continue

        ilist = ["img"] if "_backlit" in rcmp else  ["imgR","imgG","imgB"]

        for il in ilist:
            OUT1 = workdir+mid+"/out/"+rcmp+"/3D/"+il+"_2_STPT.nii.gz"
            gimg_ = nib.load(OUT1)
            gimg = gimg_.get_fdata(dtype=np.float32)  # (256, 356, 230)
            gimg[:,z_mask==0,:] = 0  # wherever the mask is black, in z, make it also black

            if il == "imgR":
                # first channel of a gene: allocate the combined RGB volume
                data8 = np.zeros(gimg.shape,dtype=rgb_dtype)
            if il in ["imgR","imgG","imgB"]:
                data8[il.replace("img","")] = gimg*255

            new_image = nib.Nifti1Image((gimg*255).astype(np.uint8), affine=gimg_.affine)
            new_image.header["cal_min"] = 0
            new_image.header["cal_max"] = 1
            new_image.header["scl_slope"] = 1/255.0
            new_image.header["scl_inter"] = 0
            nib.save(new_image,OUT1.replace(".nii.gz","_8bit.nii.gz"))
            nib.save(new_image,all_out_3D+"/"+rcmp+"_"+il+".nii.gz")

        if "imgB" in ilist:
            # all three channels are filled in: write the combined RGB volume
            new_image = nib.Nifti1Image(data8, affine=gimg_.affine)
            nib.save(new_image,all_out_3D+"/"+rcmp+"_RGB.nii.gz")
            print('saving RGB as '+all_out_3D+"/"+rcmp+"_RGB.nii.gz")

    for scmp in seg_cmp:
        if "_gene_seg" not in scmp:
            continue

        for il in ["img"]:
            OUT1 = workdir+mid+"/out_seg/"+scmp+"/3D/"+il+"_2_STPT.nii.gz"
            gimg_ = nib.load(OUT1)
            gimg = gimg_.get_fdata(dtype=np.float32)  # (256, 356, 230)
            gimg[:,z_mask==0,:] = 0  # wherever the mask is black, in z, make it also black

            # segmentation volumes are saved without the cal_*/scl_* header fields
            new_image = nib.Nifti1Image((gimg*255).astype(np.uint8), affine=gimg_.affine)
            nib.save(new_image,OUT1.replace(".nii.gz","_8bit.nii.gz"))
            nib.save(new_image,all_out_3D+"/"+scmp+"_"+il+".nii.gz")
            print('saving seg as '+all_out_3D+"/"+scmp+"_"+il+".nii.gz")

# Calling all the functions and setting some global path variables #
* where to find what
* read in some basic info from meta (incorrect!!)

In [None]:
# Root folders: `db` is where the per-marmoset meta/img2d folders are read from,
# `workdir` is the registration working tree used by all steps above.
db = '/disk/charissa/shimogori_adult/'
workdir = '/disk/charissa/ISH_reg_pipeline/data/'

# Marmoset IDs to process.
marm_ls = ['R08_0470']
#['R04_0239','R08_0450','R08_0457','R08_0470','R08_0478','R08_0585','R08_0679','R08_0688','R08_0730']

mu_bl = 3.12
mu_bf = 28.72

print("mu bl:",mu_bl)
print("mu bf:",mu_bf)

for d, mid in enumerate(marm_ls):
    print(mid)

    # Meta data is read but not used by the steps below.
    jdata,success = json_read(f"{db}{mid}/meta/01_backlit_000.json")
    jdata,success = json_read(f"{db}{mid}/meta/blockface_new/0108.json")

    # Compartment names are the entry names directly under out/ and out_seg/.
    raw_cmp = [path.rsplit("/", 1)[-1] for path in glob.glob(f"{workdir}{mid}/out/*")]
    seg_cmp = [path.rsplit("/", 1)[-1] for path in glob.glob(f"{workdir}{mid}/out_seg/*")]
    all_cmp = raw_cmp+seg_cmp

    os.makedirs(f"{workdir}{mid}/bf/",exist_ok=True)

    sdir = f"{db}{mid}/img2d/"

    # Run the pipeline steps in order (step 4 is the one-off template cell).
    align_save_gene_rgbs(all_cmp, mid)          #1
    make_sagittal_stacks(all_cmp,mid)           #2
    align_bl_2_bmcabl(mid)                      #3
    align_gene_2_bmcabl(raw_cmp,seg_cmp,mid)    #5
    save_bl_gene_niis(raw_cmp,seg_cmp,mid)      #6