In [None]:
import sys
from pathlib import Path
import numpy as np
from scipy.ndimage import center_of_mass
#import warnings
#warnings.filterwarnings("error")
import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
import numpy as np

# Put the pipeline root on the import path so the project-local `library`
# package imported below can be found. It is resolved from the current working
# directory: '../src' resolved, then two `.parent` steps up.
# NOTE(review): confirm this is the intended repo root for where the notebook runs.
PIPELINE_ROOT = Path('../src').resolve().parent.parent
# Only append once, so re-running this cell does not keep growing sys.path.
if PIPELINE_ROOT.as_posix() not in sys.path:
    sys.path.append(PIPELINE_ROOT.as_posix())
print(PIPELINE_ROOT)

from library.registration.brain_structure_manager import BrainStructureManager
from library.utilities.algorithm import umeyama

In [None]:
def brain_to_atlas_transform(brain_coord, r, t):
    """Map a single 3-D point from brain space into atlas space.

    Computes ``r @ brain_coord + t`` for one point.

    Args:
        brain_coord: 3-element sequence or array (x, y, z).
        r: Linear part of the transform; expected to be (3, 3), e.g. the
            matrix returned by ``umeyama``.
        t: Translation. Any 3-element array is accepted, whether shaped
            (3,), (3, 1) or (1, 3). A scalar is broadcast as before.

    Returns:
        Flat array with the 3 transformed coordinates.
    """
    brain_coord = np.asarray(brain_coord).reshape(3, 1)  # column vector
    t = np.asarray(t)
    # Force a 3-element translation into a column vector. Otherwise a 1-D (3,)
    # `t` broadcasts against the (3, 1) product into a (3, 3) matrix, and the
    # function silently returns the wrong point.
    if t.size == 3:
        t = t.reshape(3, 1)
    atlas_coord = r @ brain_coord + t
    return atlas_coord.ravel()  # back to a flat row of 3 coordinates

def calculate_distance(com1, com2):
    """Return the Euclidean distance between two points.

    Both arguments may be any equal-length array-like (lists, tuples or NumPy
    arrays). They are converted to float arrays first, so plain Python
    sequences work. This also means unsigned-integer inputs cannot wrap
    around on subtraction.
    """
    return np.linalg.norm(np.asarray(com1, dtype=float) - np.asarray(com2, dtype=float))

def plot_point_sets_3d(point_sets):
    """Build an interactive 3-D scatter plot of one or more labelled point sets.

    Args:
        point_sets: Iterable of ``(data, label)`` pairs. ``data`` is laid out
            coordinate-major with shape (3, N): rows are x, y, z and each
            column is one point. It is transposed before being tabulated.
            ``label`` is used to colour that set.

    Returns:
        The figure produced by ``plotly.express.scatter_3d``.
    """
    frames = []
    for data, label in point_sets:
        frame = pd.DataFrame(np.asarray(data).T, columns=['x', 'y', 'z'])
        frame['label'] = label
        frames.append(frame)
    # Concatenate once (not inside the loop) and renumber rows so the index is unique.
    if frames:
        df = pd.concat(frames, ignore_index=True)
    else:
        df = pd.DataFrame(columns=['x', 'y', 'z', 'label'])
    return px.scatter_3d(df, x='x', y='y', z='z', color='label')

In [None]:
# 'Atlas' is the moving brain whose COMs are registered onto the fixed
# 'Allen' reference. The fixed manager is attached as `brain.fixed_brain`
# so both COM sets can be read from `brain` in the next cell.
animal = 'Atlas'
brain = BrainStructureManager(animal)
brain.fixed_brain = BrainStructureManager('Allen')

In [None]:
# Centres of mass (COMs) for the moving brain and for the fixed 'Allen' reference.
moving_coms = brain.get_coms(annotator_id=1)
# Exclusions tried in earlier runs:
#del moving_coms['Sp5C_L']
#del moving_coms['RtTg']
fixed_coms = brain.fixed_brain.get_coms(annotator_id=1)

# Landmark structures used to fit the registration. Only structures present in
# BOTH COM sets can be used. Any that are missing are reported and skipped,
# rather than surfacing later as a bare KeyError.
registration_structures = ['SC','IC','PBG_L','PBG_R','3N_L','3N_R','4N_L','4N_R','SNR_L','SNR_R','VLL_L','VLL_R']
available_keys = fixed_coms.keys() & moving_coms.keys()
missing_keys = [s for s in registration_structures if s not in available_keys]
if missing_keys:
    print(f'Skipping structures missing from one of the COM sets: {missing_keys}')
common_keys = [s for s in registration_structures if s in available_keys]
# Fitting a 3-D transform needs at least 3 (non-collinear) landmarks.
assert len(common_keys) >= 3, f'Only {len(common_keys)} landmark structures available'

brain_regions = sorted(moving_coms.keys())
# One row per landmark structure, in the sorted order of `brain_regions`.
fixed_points = np.array([fixed_coms[s] for s in brain_regions if s in common_keys])
moving_points = np.array([moving_coms[s] for s in brain_regions if s in common_keys])

In [None]:
# Same landmark selection (and ordering) as `fixed_points` / `moving_points`,
# but keyed by structure name for per-structure lookups.
_landmark_regions = [s for s in brain_regions if s in common_keys]
fixed_point_dict = {s: fixed_coms[s] for s in _landmark_regions}
moving_point_dict = {s: moving_coms[s] for s in _landmark_regions}

In [None]:
# Inspect the fixed (Allen) COMs selected as landmarks.
fixed_point_dict

In [None]:
# Per-structure distance between the fixed (Allen) COM and the *unregistered*
# moving COM, i.e. the error before any transform is applied.
distances = []
for structure in common_keys:
    fixed_point = np.array(fixed_point_dict[structure])
    moving_point = np.array(moving_point_dict[structure])
    d = calculate_distance(fixed_point, moving_point)
    distances.append(d)
    # Divide by 25 for display only; `d` stays in the original COM units.
    # NOTE(review): presumably converts um to 25 um atlas voxels - confirm.
    moving_voxel = np.round(moving_point / 25)
    print(f'{structure} COM={moving_voxel} distance={round(d,2)}')

In [None]:
# Saved printouts from earlier runs, in the same format as the per-structure
# distance loop, kept for comparison. Evaluated as a bare string, so this
# cell only echoes the text.
"""
SC COM=[357.  93. 229.] distance=309.4
IC COM=[407.  78. 224.] distance=290.55
PBG_L COM=[383. 136. 143.] distance=464.15
PBG_R COM=[382. 141. 311.] distance=357.12
3N_L COM=[368. 152. 223.] distance=100.88
3N_R COM=[368. 154. 234.] distance=106.66
4N_L COM=[386. 149. 215.] distance=95.34
4N_R COM=[386. 151. 237.] distance=64.32
SNR_L COM=[357. 193. 165.] distance=597.31
SNR_R COM=[355. 193. 294.] distance=559.41
VLL_L COM=[392. 172. 148.] distance=944.63
VLL_R COM=[391. 181. 303.] distance=721.04

SC COM=[356.  92. 229.] distance=327.88
IC COM=[406.  77. 225.] distance=316.61
PBG_L COM=[382. 135. 143.] distance=479.08
PBG_R COM=[381. 141. 312.] distance=356.87
3N_L COM=[366. 152. 223.] distance=82.6
3N_R COM=[367. 153. 234.] distance=74.17
4N_L COM=[385. 148. 216.] distance=83.52
4N_R COM=[385. 150. 238.] distance=42.41
SNR_L COM=[356. 193. 166.] distance=589.6
SNR_R COM=[354. 192. 294.] distance=546.13
VLL_L COM=[390. 172. 149.] distance=934.07
VLL_R COM=[390. 180. 303.] distance=730.59

SC COM=[364.  90. 230.] distance=153.14
IC COM=[416.  86. 226.] distance=57.8
PBG_L COM=[378. 144. 141.] distance=241.66
PBG_R COM=[376. 147. 316.] distance=192.25
3N_L COM=[362. 153. 224.] distance=105.12
3N_R COM=[363. 154. 235.] distance=60.78
4N_L COM=[382. 153. 216.] distance=76.64
4N_R COM=[380. 154. 239.] distance=117.63
SNR_L COM=[341. 195. 165.] distance=322.87
SNR_R COM=[338. 195. 296.] distance=307.14
VLL_L COM=[382. 179. 148.] distance=722.29
VLL_R COM=[377. 187. 307.] distance=516.72
"""

In [None]:
# Compare coordinate ranges of the fixed and moving landmarks before registration.
# Both arrays have one row per structure, so statistics are taken over axis 0.
print(fixed_points.shape)
for _stat_name, _stat in (('Mean', np.mean), ('Min', np.min), ('Max', np.max)):
    print(_stat_name)
    print(_stat(fixed_points, axis=0))
    print(_stat(moving_points, axis=0))

In [None]:
# Before registration: overlay the moving COMs on the Allen COMs.
# Both sets are transposed to the (3, N) layout the plotting helper expects.
unaligned_point_sets = [
    (moving_points.T, 'unaligned moving centers'),
    (fixed_points.T, 'Allen centers'),
]
plot_point_sets_3d(unaligned_point_sets)

In [None]:
# Fit the transform that maps the moving COMs onto the fixed (Allen) COMs.
# Both inputs are passed transposed, i.e. one column per landmark structure.
# NOTE(review): `umeyama` is project-local. Presumably it returns a linear part
# `r` and a translation `t` to be applied as `r @ X + t` (which is how they
# are used later); confirm whether `r` includes a scale factor.
r, t = umeyama(moving_points.T, fixed_points.T)

In [None]:
# Inspect the fitted translation.
t

In [None]:
# Apply the fitted transform to all moving COMs at once. The result keeps the
# transposed layout: one column per landmark structure.
reg_points = r @ moving_points.T + t

In [None]:
# Compare coordinate ranges of the fixed and registered landmarks.
# `fixed_points` has one row per structure (stats over axis 0), whereas
# `reg_points` has one column per structure (stats over axis 1).
for _stat_name, _stat in (('Mean', np.mean), ('Min', np.min), ('Max', np.max)):
    print(_stat_name)
    print(_stat(fixed_points, axis=0))
    print(_stat(reg_points, axis=1))

In [None]:
# Registered (atlas-space) COM for every landmark structure: the fitted
# transform (r, t) is applied to each moving COM. The original version copied
# the unregistered moving COMs, despite the name.
reg_point_dict = {
    s: brain_to_atlas_transform(moving_coms[s], r, t)
    for s in brain_regions
    if s in common_keys
}

In [None]:
# Per-structure distance between the fixed (Allen) COM and the *registered*
# moving COM, i.e. the residual error after applying (r, t).
distances = []
for structure in common_keys:
    fixed_point = np.array(fixed_point_dict[structure])
    moving_point = np.array(moving_point_dict[structure])
    reg_point = brain_to_atlas_transform(moving_point, r, t)
    d = calculate_distance(fixed_point, reg_point)
    distances.append(d)
    # Divide by 25 for display only; `d` stays in the original COM units.
    # NOTE(review): presumably converts um to 25 um atlas voxels - confirm.
    reg_voxel = np.round(reg_point / 25)
    print(f'{structure} COM={reg_voxel} distance={round(d,2)}')

In [None]:
# Registration error summary: (n structures, mean, min, max, total) distance, rounded.
(
    len(distances),
    round(np.mean(distances)),
    round(min(distances)),
    round(max(distances)),
    round(np.sum(distances)),
)

In [None]:
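# Summary tuples recorded from earlier runs: (n structures, mean, min, max, sum) of the COM distances.
# MD594 (33, 436, 111, 2838, 14391)
# MD589 (33, 429, 107, 2922, 14146)
# MD589 (31, 276, 53, 818, 8554) # with RtTg and Sp5C_L excluded
# MD589 (33, 276, 77, 529, 9098)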

In [None]:
# After registration: overlay the transformed moving COMs on the Allen COMs.
# `reg_points` already has the (3, N) layout; `fixed_points` is transposed to match.
registered_point_sets = [
    (reg_points, 'registered moving centers'),
    (fixed_points.T, 'Allen centers'),
]
plot_point_sets_3d(registered_point_sets)