In [None]:
import sys
from pathlib import Path
import numpy as np
from scipy.ndimage import center_of_mass
import warnings
warnings.filterwarnings("error")
import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
import numpy as np

PIPELINE_ROOT = Path('../src').resolve().parent.parent
sys.path.append(PIPELINE_ROOT.as_posix())
print(PIPELINE_ROOT)

from atlas.scripts.brain_structure_manager import BrainStructureManager
from library.utilities.algorithm import umeyama

In [None]:
# Moving brain to be registered (animal ID), with the Allen atlas attached as
# the fixed brain whose COMs the moving brain's COMs are compared against.
animal = 'DK55'
brain = BrainStructureManager(animal)
brain.fixed_brain = BrainStructureManager('Allen')

In [None]:
# Per-structure centers of mass (COMs): the moving brain's, and the Allen
# atlas's as returned for annotator_id=1.
moving_coms = brain.get_coms()
# Optional structure exclusions, currently disabled:
#del moving_coms['Sp5C_L']
#del moving_coms['RtTg']
fixed_coms = brain.fixed_brain.get_coms(annotator_id=1)

# Pair only the structures present in both brains. Both arrays are built from
# the same sorted name list, so row i of each refers to the same structure.
brain_regions = sorted(moving_coms.keys())
common_keys = fixed_coms.keys() & moving_coms.keys()

fixed_points = np.array(
    [fixed_coms[name] for name in brain_regions if name in common_keys]
)
moving_points = np.array(
    [moving_coms[name] for name in brain_regions if name in common_keys]
)

In [None]:
# Name-keyed views of the paired COMs, in the same sorted order as the rows
# of fixed_points / moving_points.
fixed_point_dict = {name: fixed_coms[name] for name in sorted(common_keys)}
moving_point_dict = {name: moving_coms[name] for name in sorted(common_keys)}

In [None]:
# Pre-registration sanity check: per-axis mean/min/max of the paired point sets
# (one row per structure). Large differences show how far apart the two
# coordinate frames start out before umeyama is fitted.
print(f'fixed_points {fixed_points.shape}, moving_points {moving_points.shape}')
pd.DataFrame(
    {
        f'{label} {stat_name}': stat(points, axis=0)
        for stat_name, stat in (('mean', np.mean), ('min', np.min), ('max', np.max))
        for label, points in (('fixed', fixed_points), ('moving', moving_points))
    }
).T

In [None]:
def brain_to_atlas_transform(brain_coord, r, t):
    """Map a single 3-D point from the moving brain's frame into atlas space.

    Computes ``r @ p + t`` for the column vector ``p`` built from
    ``brain_coord``.

    Parameters
    ----------
    brain_coord : array-like holding exactly 3 numbers
        The point to transform.
    r : array-like, shape (3, 3)
        Linear part of the transform (e.g. the matrix fitted by ``umeyama``).
    t : array-like holding exactly 3 numbers
        Translation; accepted either as a column ``(3, 1)`` or flat ``(3,)``.

    Returns
    -------
    numpy.ndarray, shape (3,)
        The transformed point as a flat vector.
    """
    point = np.asarray(brain_coord).reshape(3, 1)  # column vector
    # Reshape t as well: a flat (3,) translation added to a (3, 1) column would
    # broadcast to a (3, 3) matrix and silently produce a wrong point.
    translation = np.asarray(t).reshape(3, 1)
    return (r @ point + translation).ravel()

def calculate_distance(com1, com2):
    """Return the Euclidean distance between two points.

    Both inputs are converted to float arrays first, so plain lists/tuples are
    accepted and unsigned-integer coordinates cannot wrap around when
    subtracted.
    """
    return np.linalg.norm(np.asarray(com1, dtype=float) - np.asarray(com2, dtype=float))

def plot_point_sets_3d(point_sets):
    """Overlay several labelled 3-D point sets in one interactive scatter plot.

    Parameters
    ----------
    point_sets : iterable of (array-like, str)
        Each item is ``(data, label)``. ``data`` holds one point per column,
        i.e. shape ``(3, N)``; it is transposed to one row per point with
        columns ``x``, ``y``, ``z``.

    Returns
    -------
    plotly Figure
        3-D scatter coloured by label.
    """
    # Build every per-set frame first and concatenate once: growing a
    # DataFrame with pd.concat inside the loop is quadratic and leaves
    # duplicate index values across sets.
    frames = [
        pd.DataFrame(np.asarray(data).T, columns=['x', 'y', 'z']).assign(label=label)
        for data, label in point_sets
    ]
    if frames:
        df = pd.concat(frames, ignore_index=True)
    else:
        # pd.concat([]) raises; fall back to an empty frame with the expected columns.
        df = pd.DataFrame(columns=['x', 'y', 'z', 'label'])

    fig = px.scatter_3d(df, x='x', y='y', z='z', color='label')
    return fig

In [None]:
# Paired COMs before registration (point sets passed one point per column).
plot_point_sets_3d([
    (moving_points.T, 'unaligned moving centers'),
    (fixed_points.T, 'Allen centers'),
]).update_layout(title='Structure COMs before registration')

In [None]:
# Fit the moving -> atlas transform from the paired COMs (one point per column).
# NOTE(review): umeyama is project-local; its outputs are used below as
# r @ p + t, i.e. r is the linear part and t the translation. Confirm whether
# any scale factor is folded into r.
r, t = umeyama(moving_points.T, fixed_points.T)

In [None]:
# Inspect the fitted translation.
t

In [None]:
# Apply the fitted transform to every moving COM at once; the result keeps one
# point per column (same layout as moving_points.T).
reg_points = r @ moving_points.T + t

In [None]:
# Post-registration check: per-axis mean/min/max of the atlas COMs versus the
# registered moving COMs. reg_points holds one point per column
# (r @ moving_points.T + t), so it is transposed to one row per point here.
pd.DataFrame(
    {
        f'{label} {stat_name}': stat(points, axis=0)
        for stat_name, stat in (('mean', np.mean), ('min', np.min), ('max', np.max))
        for label, points in (('fixed', fixed_points), ('registered', reg_points.T))
    }
).T

In [None]:
# Per-structure registration error: distance between each atlas COM and the
# matching moving COM after applying the fitted (r, t) transform.
# Iterate in sorted order: common_keys is a set, and the iteration order of a
# set of strings is not stable across interpreter runs, which made the printed
# listing (and the order of `distances`) change from run to run.
distances = []
for structure in sorted(common_keys):
    fixed_point = np.array(fixed_point_dict[structure])
    moving_point = np.array(moving_point_dict[structure])
    reg_point = brain_to_atlas_transform(moving_point, r, t)
    d = calculate_distance(fixed_point, reg_point)
    distances.append(d)
    print(f'{structure} distance={round(d,2)}')

In [None]:
# Registration-error summary: (n_structures, mean, min, max, sum), each rounded.
(
    len(distances),
    *map(round, (np.mean(distances), min(distances), max(distances), np.sum(distances))),
)

In [None]:
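# Recorded results for other animals, in the format of the summary cell above:
# (n_structures, mean, min, max, sum) of per-structure distances after registration.
# MD594 (33, 436, 111, 2838, 14391)
# MD589 (33, 429, 107, 2922, 14146)
# MD589 (31, 276, 53, 818, 8554) # excluding RtTg and SPc
# MD589 (33, 276, 77, 529, 9098)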

In [None]:
# Paired COMs after registration; the registered moving centers are expected to
# lie close to the Allen centers. reg_points already has one point per column.
plot_point_sets_3d([
    (reg_points, 'registered moving centers'),
    (fixed_points.T, 'Allen centers'),
]).update_layout(title='Structure COMs after registration')