In [None]:
import sys
from pathlib import Path
import numpy as np
from scipy.ndimage import center_of_mass
#import warnings
#warnings.filterwarnings("error")
import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
import numpy as np

# Make the project-local `library` package (imported below) importable:
# resolve '../src' against the working directory, walk two levels up, and
# append that directory to the module search path.
pipeline_src = Path('../src').resolve()
PIPELINE_ROOT = pipeline_src.parent.parent
sys.path.append(PIPELINE_ROOT.as_posix())
print(PIPELINE_ROOT)

from library.registration.brain_structure_manager import BrainStructureManager
from library.registration.algorithm import umeyama

In [None]:
def brain_to_atlas_transform(brain_coord, r, t):
    """Apply the affine map ``r @ x + t`` to a single 3-D point.

    Parameters
    ----------
    brain_coord : array-like of 3 numbers
        Point in the moving (brain) coordinate space.
    r : array-like, shape (3, 3)
        Linear part of the transform (in this notebook, the matrix returned by
        ``umeyama``).
    t : array-like of 3 numbers
        Translation. Both a (3, 1) column and a flat (3,) vector are accepted.

    Returns
    -------
    numpy.ndarray, shape (3,)
        The transformed point as a flat row vector.
    """
    brain_coord = np.asarray(brain_coord).reshape(3, 1)  # column vector
    # Reshape t as well. A flat (3,) translation added to a (3, 1) column would
    # broadcast to a (3, 3) matrix and silently produce the wrong point.
    t = np.asarray(t).reshape(3, 1)
    atlas_coord = np.asarray(r) @ brain_coord + t
    return atlas_coord.ravel()  # back to a flat row vector

def calculate_distance(com1, com2):
    """Return the Euclidean distance between two points.

    Inputs may be NumPy arrays or plain sequences (lists/tuples) of equal
    length. They are converted to float arrays first, so the subtraction is
    element-wise and cannot wrap around for unsigned-integer inputs.
    """
    diff = np.asarray(com1, dtype=float) - np.asarray(com2, dtype=float)
    return np.linalg.norm(diff)

def plot_point_sets_3d(point_sets):
    """Build an interactive 3-D scatter plot of several labelled point sets.

    Parameters
    ----------
    point_sets : iterable of (data, label)
        ``data`` is transposed into rows of ``x, y, z``, so it is expected to be
        shaped (3, N), one column per point. ``label`` names the set and becomes
        its entry in the colour legend.

    Returns
    -------
    plotly figure
        The figure from ``plotly.express.scatter_3d``. Leave it as the last
        expression of a cell to display it.

    Raises
    ------
    ValueError
        If ``point_sets`` is empty.
    """
    # Build one frame per set and concatenate once. Growing a frame with
    # pd.concat inside the loop is quadratic and repeats index labels per set.
    frames = [
        pd.DataFrame(np.asarray(data).T, columns=['x', 'y', 'z']).assign(label=label)
        for data, label in point_sets
    ]
    if not frames:
        raise ValueError('point_sets must contain at least one (data, label) pair')
    df = pd.concat(frames, ignore_index=True)
    return px.scatter_3d(df, x='x', y='y', z='z', color='label')

In [None]:
# The 'Atlas' brain provides the moving set of COMs. The 'Allen' brain is
# attached as the fixed reference that later cells compare and register against.
animal = 'Atlas'
brain = BrainStructureManager(animal)
allen_brain = BrainStructureManager('Allen')
brain.fixed_brain = allen_brain

In [None]:
# Load the annotator-1 COMs for both brains (dict-like, keyed by structure name).
atlas_coms = brain.get_coms(annotator_id=1)
allen_coms = brain.fixed_brain.get_coms(annotator_id=1)
common_keys = allen_coms.keys() & atlas_coms.keys()
brain_regions = sorted(atlas_coms.keys())
# Point pairs for registration: midbrain structures with a COM in BOTH brains,
# in sorted order so that row i of allen_points and atlas_points is the same
# structure. Filtering on common_keys avoids a KeyError when a midbrain
# structure exists in the Atlas COMs but not in the Allen COMs.
midbrain_regions = [s for s in brain_regions if s in brain.midbrain_keys and s in common_keys]
allen_points = np.array([allen_coms[s] for s in midbrain_regions])
atlas_points = np.array([atlas_coms[s] for s in midbrain_regions])

In [None]:
# Per-structure COM lookups, restricted to structures present in both brains
# and kept in the sorted order of brain_regions.
shared_regions = [s for s in brain_regions if s in common_keys]
allen_point_dict = {s: allen_coms[s] for s in shared_regions}
atlas_point_dict = {s: atlas_coms[s] for s in shared_regions}

In [None]:
# Spot-check one structure's COM in the moving (Atlas) brain.
# 'SC' is presumably the superior colliculus; confirm against the structure table.
atlas_coms['SC']

In [None]:
# The same structure's COM in the fixed (Allen) brain, for side-by-side comparison.
allen_coms['SC']

In [None]:
# Pre-registration baseline: distance between each midbrain structure's COM in
# the moving (Atlas) brain and in the Allen brain, before any transform.
# Only structures with a COM in both brains can be compared; they are visited
# in sorted order so the printout is reproducible.
distances = []
midbrain_common = sorted(
    s for s in brain.midbrain_keys if s in allen_point_dict and s in atlas_point_dict
)
missing = sorted(set(brain.midbrain_keys) - set(midbrain_common))
if missing:
    print(f'Skipping structures without a COM in both brains: {missing}')
for structure in midbrain_common:
    allen_point = np.array(allen_point_dict[structure])
    atlas_point = np.array(atlas_point_dict[structure])
    d = calculate_distance(allen_point, atlas_point)
    distances.append(d)
    print(f'{structure} COM distance from Allen={round(d,2)} micrometers')
print()
if distances:
    print(f'n={len(distances)}, min={min(distances)} max={max(distances)}, mean={np.mean(distances)}')
else:
    print('n=0, no midbrain structures with a COM in both brains')

In [None]:
# Before registration: the raw moving-brain COMs plotted next to the Allen COMs.
unaligned_sets = [
    (atlas_points.T, 'unaligned moving centers'),
    (allen_points.T, 'Allen centers'),
]
plot_point_sets_3d(unaligned_sets)

In [None]:
# Fit the transform (r, t) that later cells apply to the moving COMs to bring
# them onto the Allen COMs. Both inputs are transposed so that each column
# holds one structure's point.
moving_columns, fixed_columns = atlas_points.T, allen_points.T
r, t = umeyama(moving_columns, fixed_columns)

In [None]:
# Inspect the fitted translation returned by umeyama (it is added to r @ x below).
t

In [None]:
# Apply the fitted transform to every moving COM at once. The result has one
# column per structure, matching the column layout that was passed to umeyama.
# t is reshaped to a (3, 1) column. A flat (3,) t would be added along the
# wrong axis here, or fail to broadcast.
reg_points = r @ atlas_points.T + np.asarray(t).reshape(3, 1)

In [None]:
# Compare coordinate-wise summary statistics of the Allen COMs (one point per
# row) with those of the registered COMs (one point per column). After a good
# fit, each pair of printed lines should be close.
for stat_name, stat in (('Mean', np.mean), ('Min', np.min), ('Max', np.max)):
    print(stat_name)
    print(stat(allen_points, axis=0))
    print(stat(reg_points, axis=1))

In [None]:
# Map each midbrain structure of the moving brain to its COM after applying the
# fitted (r, t): the per-structure counterpart of the reg_points columns.
reg_point_dict = {
    s: brain_to_atlas_transform(atlas_coms[s], r, t)
    for s in brain_regions
    if s in brain.midbrain_keys
}

In [None]:
# Post-registration check: transform each moving COM with (r, t) and measure
# its distance to the matching Allen COM. common_keys is a set of strings, and
# its iteration order changes between kernel sessions, so sort it to keep the
# printout reproducible.
distances = []
for structure in sorted(common_keys):
    allen_point = np.array(allen_point_dict[structure])
    atlas_point = np.array(atlas_point_dict[structure])
    reg_point = brain_to_atlas_transform(atlas_point, r, t)
    d = calculate_distance(allen_point, reg_point)
    distances.append(d)
    # The displayed COM is divided by 25 and rounded (presumably 25 um voxels; confirm).
    reg_voxel = np.round(reg_point / 25)
    print(f'{structure} COM={reg_voxel} distance={round(d,2)}')
print()
assert distances, 'no structures with a COM in both brains'
len(distances), round(np.mean(distances)), round(min(distances)), round(max(distances))

In [None]:
"""
add trans in loop from MD -> MD589
Pn_R COM=[385. 239. 255.] distance=628.58
SNR_L COM=[364. 192. 172.] distance=807.9
PBG_R COM=[379. 147. 310.] distance=208.75
Pn_L COM=[382. 237. 207.] distance=745.91
4N_R COM=[387. 151. 238.] distance=95.07
3N_L COM=[366. 153. 225.] distance=125.3
PBG_L COM=[381. 147. 149.] distance=278.35
3N_R COM=[375. 151. 239.] distance=305.8
SNR_R COM=[350. 198. 289.] distance=403.99
SC COM=[365.  94. 230.] distance=790.67
IC COM=[415.  80. 226.] distance=134.63
4N_L COM=[346. 129. 189.] distance=1299.23

(12, 485, 95, 1299)

add trans in loop from MD -> allen
SNR_L COM=[361. 194. 164.] distance=653.93
3N_R COM=[369. 152. 234.] distance=116.48
4N_R COM=[379. 149. 237.] distance=132.96
Pn_R COM=[385. 243. 253.] distance=724.89
IC COM=[409.  77. 224.] distance=284.73
SNR_R COM=[352. 194. 295.] distance=493.45
SC COM=[365.  90. 230.] distance=689.5
3N_L COM=[362. 151. 224.] distance=101.12
PBG_R COM=[376. 141. 313.] distance=309.93
Pn_L COM=[382. 241. 203.] distance=736.1
PBG_L COM=[376. 138. 141.] distance=392.79
4N_L COM=[379. 147. 215.] distance=164.75

(12, 400, 101, 736)

only transformation is on this page
4N_R COM=[350. 132. 213.] distance=1151.52
Pn_R COM=[384. 237. 257.] distance=565.89
3N_L COM=[369. 155. 229.] distance=250.5
SC COM=[370. 101. 234.] distance=988.87
PBG_R COM=[384. 151. 310.] distance=222.8
Pn_L COM=[382. 236. 209.] distance=773.35
PBG_L COM=[384. 146. 155.] distance=418.68
IC COM=[419.  86. 227.] distance=94.65
4N_L COM=[351. 130. 193.] distance=1140.79
SNR_R COM=[356. 199. 291.] distance=512.06
SNR_L COM=[370. 192. 176.] distance=955.02
3N_R COM=[376. 154. 239.] distance=326.34

(12, 617, 95, 1152)
"""

In [None]:
# After registration: the transformed moving COMs plotted next to the Allen COMs.
registered_sets = [
    (reg_points, 'registered moving centers'),
    (allen_points.T, 'Allen centers'),
]
plot_point_sets_3d(registered_sets)