In [1]:
import os,sys
from collections import defaultdict
import json
import numpy as np
import pickle
from skimage import io
import matplotlib
import pandas as pd
from matplotlib import pyplot as plt
from tqdm import tqdm
from scipy.ndimage.morphology import distance_transform_edt
%load_ext autoreload
%autoreload 2

In [7]:
# Directory added to sys.path so that the `lib` package imported below can be found.
# NOTE(review): this is a machine-specific absolute path; adjust it for your checkout.
PATH = '/home/eddyod/programming/pipeline_utility/src'

# Append only once: re-running this cell must not pile up duplicate sys.path entries.
if PATH not in sys.path:
    sys.path.append(PATH)
from lib.file_location import DATA_PATH, ROOT_DIR

from lib.sqlcontroller import SqlController
from lib.utilities_atlas import load_original_volume_all_known_structures_v3, get_centroid_3d, \
    load_alignment_results_v3, transform_points, average_location, \
    convert_to_original_name, name_unsided_to_color, paired_structures, \
    convert_to_left_name, convert_to_right_name, load_original_volume_v2, save_alignment_results_v3, \
    convert_transform_forms, transform_volume_v4, volume_to_polydata, singular_structures, \
    MESH_DIR, average_shape, convert_to_surround_name, mirror_volume_v2, save_original_volume, \
    save_mesh_stl, get_surround_volume_v2, transform_volume_v4, high_contrast_colors, \
    plot_centroid_means_and_covars_3d, all_known_structures_sided, load_data, \
    get_instance_mesh_filepath, images_to_volume_v2, find_contour_points, load_cropbox_v2, \
    load_mean_shape, \
    display_volume_sections, get_structure_mean_positions_filepath
from lib.atlas_aligner import Aligner
from lib.utilities_alignment import convert_resolution_string_to_um
# Name of the atlas being built, and the directory under DATA_PATH that later
# cells read from and write to.
atlas_name = 'atlasV8'
ATLAS_PATH = os.path.join(os.path.join(DATA_PATH, 'atlas_data'), atlas_name)

In [8]:
# Brains: MD589 is the fixed (reference) brain; the others are the moving brains.
fixed_brain_name = 'MD589'
moving_brain_names = ['MD585', 'MD594']

# Working resolution, both as the string used in volume specs and as a float.
resolution = '10.0um'
resolution_um = 10.0

# The structure list is obtained from SqlController for the fixed brain;
# the "R" entry is removed from it before use.
sqlController = SqlController(fixed_brain_name)
structures = sqlController.get_structures_list()
structures.remove("R")

# Spec describing the fixed brain's volume.
fixed_brain_spec = dict(name=fixed_brain_name, vol_type='annotationAsScore', resolution=resolution)

# Accumulator: one {structure: centroid} dict per brain, appended to by later cells.
structure_centroids_all_brains_um_wrt_fixed = []

In [9]:
# Load the fixed brain's structure volumes and compute one centroid per structure.
# NOTE(review): the saved output of this cell shows
# "TypeError: get_domain_origin() got an unexpected keyword argument 'stack'",
# raised while loading; that has to be resolved in the library code before the
# notebook can run end to end.
fixed_brain = load_original_volume_all_known_structures_v3(
    stack_spec=fixed_brain_spec,
    structures=structures,
    in_bbox_wrt='wholebrain',
)
fixed_brain_structure_centroids = get_centroid_3d(fixed_brain)

# Scale the centroids by the working resolution.
fixed_brain_structure_centroids_um = dict(
    (name, centroid * resolution_um)
    for name, centroid in fixed_brain_structure_centroids.items()
)

# The fixed brain contributes the first entry of the per-brain centroid list.
structure_centroids_all_brains_um_wrt_fixed.append(fixed_brain_structure_centroids_um)

10N_L /net/birdstore/Active_Atlas_Data/data_root/CSHL_volumes/MD589/10.0um_annotationAsScoreVolume/10N_L.npy
10N_L /net/birdstore/Active_Atlas_Data/data_root/CSHL_volumes/MD589/10.0um_annotationAsScoreVolume/10N_L_origin_wrt_wholebrain.txt


TypeError: get_domain_origin() got an unexpected keyword argument 'stack'

## Compute instance centroids

In [None]:
# For every moving brain: compute its structure centroids, map them through the
# stored moving-to-fixed transform, scale by the working resolution and append
# the result to the per-brain centroid list.
for brain_m in moving_brain_names:
    moving_brain_spec = dict(name=brain_m, vol_type='annotationAsScore', resolution=resolution)
    print('Brain', moving_brain_spec)

    moving_brain = load_original_volume_all_known_structures_v3(
        stack_spec=moving_brain_spec,
        structures=structures,
        in_bbox_wrt='wholebrain',
    )
    moving_brain_structure_centroids_input_resol = get_centroid_3d(moving_brain)

    # Load registration.
    # Alignment results fp: os.path.join(reg_root_dir, alignment_spec['stack_m']['name'], warp_basename, warp_basename + '_' + what + '.' + ext)
    alignment_spec = {'stack_m': moving_brain_spec, 'stack_f': fixed_brain_spec, 'warp_setting': 109}
    transform_parameters_moving_brain_to_fixed_brain = load_alignment_results_v3(
        alignment_spec=alignment_spec, what='parameters')

    # Transform the moving brain's centroids into alignment with the fixed brain,
    # pairing each structure name with its transformed point.
    transformed_moving_brain_structure_centroids_input_resol_wrt_fixed = dict(zip(
        moving_brain_structure_centroids_input_resol.keys(),
        transform_points(pts=list(moving_brain_structure_centroids_input_resol.values()),
                         transform=transform_parameters_moving_brain_to_fixed_brain),
    ))

    transformed_moving_brain_structure_centroids_um_wrt_fixed = {
        name: centroid * resolution_um
        for name, centroid in transformed_moving_brain_structure_centroids_input_resol_wrt_fixed.items()
    }

    structure_centroids_all_brains_um_wrt_fixed.append(transformed_moving_brain_structure_centroids_um_wrt_fixed)


In [None]:
# Regroup the per-brain centroid dicts into {structure: [centroid, centroid, ...]}.
structure_centroids_all_brains_um_grouped_by_structure_wrt_fixed = defaultdict(list)
for centroids_of_one_brain in structure_centroids_all_brains_um_wrt_fixed:
    for structure_name, centroid_um in centroids_of_one_brain.items():
        structure_centroids_all_brains_um_grouped_by_structure_wrt_fixed[structure_name].append(centroid_um)

# Switch off auto-creation of entries: from here on, looking up a structure that
# was never added raises KeyError instead of silently inserting an empty list.
structure_centroids_all_brains_um_grouped_by_structure_wrt_fixed.default_factory = None

In [None]:
# average_location() is called on the grouped centroids and returns five values;
# of these, only nominal_centroids is used by the cells below.
(
    nominal_centroids,
    instance_centroids_wrt_canonicalAtlasSpace_um,
    canonical_center_wrt_fixed_um,
    canonical_normal,
    transform_matrix_to_canonicalAtlasSpace_um,
) = average_location(structure_centroids_all_brains_um_grouped_by_structure_wrt_fixed)

In [None]:
# Persist the nominal centroids as a pickle inside the atlas directory; a later
# cell reads the same file back.
filepath = os.path.join(ATLAS_PATH, '1um_meanPositions.pkl')
with open(filepath, mode='wb') as pickle_file:
    pickle.dump(nominal_centroids, pickle_file)

In [None]:
# Note that all shapes have voxel resolution matching input resolution (10.0 micron).
#
# For every structure: load its instance volume from each brain, rigidly register
# all instances to the first one (the fixed brain's), average the aligned
# instances into a mean shape, and save that mean shape (volume + origin) under
# ATLAS_PATH.
for structure in tqdm(structures):
    # Structures whose name contains '_L' are mirrored (last axis reversed) after
    # loading; all other structures are used as loaded.
    is_left = '_L' in structure

    # Load instance volumes. The fixed brain comes first, so its instance becomes
    # the registration template below.
    instance_volumes = []
    for brain_name in [fixed_brain_name] + moving_brain_names:
        brain_spec = {'name': brain_name, 'vol_type': 'annotationAsScore', 'resolution': resolution}
        instance_vol, _ = load_original_volume_v2(stack_spec=brain_spec,
                                                  structure=structure,
                                                  return_origin_instead_of_bbox=True,
                                                  crop_to_minimal=True)
        instance_volumes.append(instance_vol[..., ::-1] if is_left else instance_vol)

    # Use the first instance as registration target.
    # Register every other instance to the first instance.
    template_instance_volume = instance_volumes[0]
    template_instance_centroid_wrt_templateOrigin = get_centroid_3d(template_instance_volume).astype(np.int16)
    # The template is expressed relative to its own centroid: its origin is minus the centroid.
    template_instance_wrt_templateCentroid = (template_instance_volume, - template_instance_centroid_wrt_templateOrigin)
    aligned_moving_instance_wrt_templateCentroid_all_instances = []

    for moving_instance_volume in instance_volumes[1:]:
        # Compute transform.
        aligner = Aligner({0: template_instance_wrt_templateCentroid},
                          {0: (moving_instance_volume, np.array((0,0,0)))},
                          labelIndexMap_m2f={0:0},
                          verbose=False)
        aligner.set_centroid(centroid_m='structure_centroid', centroid_f='structure_centroid')
        aligner.compute_gradient(smooth_first=True)
        lr = 1.
        ### max_iter_num was originally 100 and 1000
        _, _ = aligner.optimize(tf_type='rigid',
                                history_len=100,
                                max_iter_num=2 if structure in ['SC', 'IC'] else 3,
                                grad_computation_sample_number=None,
                                full_lr=np.array([lr, lr, lr, 0.1, 0.1, 0.1]),
                                terminate_thresh_trans=.01)

        # Transform instances.
        T = convert_transform_forms(aligner=aligner, out_form=(3, 4), select_best='max_value')
        aligned_moving_instance_volume, aligned_moving_instance_origin_wrt_templateCentroid = \
            transform_volume_v4(volume=(moving_instance_volume, (0, 0, 0)), transform=T,
                                return_origin_instead_of_bbox=True)
        aligned_moving_instance_wrt_templateCentroid_all_instances.append(
            (aligned_moving_instance_volume, aligned_moving_instance_origin_wrt_templateCentroid))

    # Template plus every aligned instance, each as a (volume, origin) pair.
    # (The per-instance meshes that used to be generated here were never used,
    # so that computation has been dropped.)
    volume_origin_list = [template_instance_wrt_templateCentroid] + aligned_moving_instance_wrt_templateCentroid_all_instances

    # IC and SC boundaries are particularly jagged, so do a larger value smoothing.
    sigma = 5. if structure in ('IC', 'SC') else 2.

    mean_shape_wrt_templateCentroid = \
        average_shape(volume_origin_list=volume_origin_list, force_symmetric=(structure in singular_structures),
                      sigma=sigma,
                      )

    # Save mean shape: element 0 goes to structure/<name>.npy, element 1 to origin/<name>.txt.
    filename = f'{structure}.npy'
    filepath = os.path.join(ATLAS_PATH, 'structure', filename)
    np.save(filepath, np.ascontiguousarray(mean_shape_wrt_templateCentroid[0]))

    filename = f'{structure}.txt'
    filepath = os.path.join(ATLAS_PATH, 'origin', filename)
    np.savetxt(filepath, mean_shape_wrt_templateCentroid[1])


# Combine standard shapes with standard centroid locations

In [None]:
# Resolution of the atlas, as the string passed to load_mean_shape and as a float.
atlas_resolution = '10.0um'
atlas_resolution_um = 10.0

# Read back the nominal centroids pickled by the earlier cell.
# A context manager is used so the file handle is closed right after loading.
# NOTE: pickle.load executes arbitrary code from the file; only load pickles
# produced by this pipeline.
filepath = os.path.join(ATLAS_PATH, '1um_meanPositions.pkl')
with open(filepath, 'rb') as f:
    nominal_centroids = pickle.load(f)

# Divide the stored centroids by the atlas resolution to match the mean shapes.
nominal_centroids_10um = {s: c / atlas_resolution_um for s, c in nominal_centroids.items()}

# One mean shape per structure; the cell below indexes each as [0] (volume) and [1] (origin).
mean_shapes = {name_u: load_mean_shape(atlas_name=atlas_name, structure=name_u, resolution=atlas_resolution)
                    for name_u in structures}

In [None]:
# Place every mean shape at its nominal centroid, then save its origin, its
# volume and a mesh of the thresholded volume.
# NOTE(review): the origin/structure files written here use the same paths as
# the files written by the mean-shape cell earlier in the notebook. If
# load_mean_shape reads those same files, re-running this cell without
# regenerating the mean shapes would shift the origins a second time - confirm.
for structure in tqdm(structures):
    structure_mean_shape = mean_shapes[structure]
    nominal_centroid = nominal_centroids_10um[structure]

    if '_L' not in structure:
        # Shift the stored origin by the nominal centroid.
        mean_shape = (structure_mean_shape[0],
                      structure_mean_shape[1] + nominal_centroid)
    else:
        # Names containing '_L' go through mirror_volume_v2 instead.
        mean_shape = mirror_volume_v2(volume=structure_mean_shape,
                                      centroid_wrt_origin=-structure_mean_shape[1],
                                      new_centroid=nominal_centroid)

    volume, origin = mean_shape[0], mean_shape[1]

    # save origin, this is also the important one
    filepath = os.path.join(ATLAS_PATH, 'origin', f'{structure}.txt')
    np.savetxt(filepath, origin)

    # Save volume with stated level. This is the important one
    filepath = os.path.join(ATLAS_PATH, 'structure', f'{structure}.npy')
    np.save(filepath, volume)

    # mesh: built from the volume thresholded at 0.9, paired with the same origin
    aligned_volume = (volume >= 0.9, origin)
    aligned_structure = volume_to_polydata(volume=aligned_volume,
                                           num_simplify_iter=3,
                                           smooth=False,
                                           return_vertex_face_list=False)
    filepath = os.path.join(ATLAS_PATH, 'mesh', f'{structure}.stl')
    save_mesh_stl(aligned_structure, filepath)
