# Alignment Error Visualization

This notebook collects COM data from the database and tries to quantify some alignment errors. The main results are shown in the plots at the end of the notebook.

In [1]:
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
from collections import OrderedDict
from IPython.display import HTML
from itertools import combinations

PIPELINE_ROOT = Path('./').absolute().parents[1]
PIPELINE_ROOT = PIPELINE_ROOT.as_posix()
sys.path.append(PIPELINE_ROOT)
print(PIPELINE_ROOT)


/home/eddyod/programming/pipeline/src


In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
from library.controller.sql_controller import SqlController
from library.image_manipulation.filelocation_manager import FileLocationManager
from library.atlas.atlas_utilities import apply_affine_transform, get_affine_transformation, list_coms, \
    compute_affine_transformation, compute_affine_transformation_centroid, get_umeyama
from library.atlas.brain_structure_manager import BrainStructureManager
from library.utilities.utilities_process import M_UM_SCALE, SCALING_FACTOR, random_string, \
read_image, write_image


XGBoost Version: 2.1.4


In [4]:
def absolute_sum(l):
    la = np.array(l)
    nabs = (np.array(la*la))
    return np.sum(la, axis=0)

def sum_square_com(com):
    ss = np.sqrt(sum([s**2 for s in com]))
    return ss

def convert_com(com):
    scales = np.array([0.452*32, 0.452*32, 20])
    return com * scales


def fetch_coms(animal):
    coms = {}
    dirpath = f'/net/birdstore/Active_Atlas_Data/data_root/atlas_data/{animal}/com'
    files = sorted(os.listdir(dirpath))
    for file in files:
        structure = Path(file).stem
        filepath = os.path.join(dirpath, file)
        com = np.loadtxt(filepath)
        coms[structure] = com
    return coms

In [5]:
moving_name = 'AtlasV8'
fixed_name = 'Allen'
moving_all = fetch_coms(moving_name)
fixed_all = list_coms(fixed_name)
common_keys = list(moving_all.keys() & fixed_all.keys())
moving_common = np.array([moving_all[s] for s in common_keys])
fixed_common = np.array([fixed_all[s] for s in common_keys])
print(f'{moving_name} len={len(moving_all.keys())}')
print(f'{fixed_name} len={len(fixed_all.keys())}')

print(len(common_keys))

AtlasV8 len=53
Allen len=37
37


In [6]:
# these keys have very high errors: 'RtTg', 'AP'
midbrain_keys = {
            "3N_L",
            "3N_R",
            "4N_L",
            "4N_R",
            "IC",
            "PBG_L",
            "PBG_R",
            "SC",
            "SNC_L",
            "SNC_R",
            "SNR_L",
            "SNR_R",
        }
bad_keys = ('RtTg', 'AP', '3N_L', '3N_R')
bad_keys = ('RtTg', 'AP')
#bad_keys = ('RtTg',)
#bad_keys = ()
good_keys = set(common_keys) - set(bad_keys)
print(f'#good_keys={len(good_keys)}')

#good_keys=35


In [7]:
moving_src = np.array([moving_all[s] for s in good_keys])
fixed_src = np.array([fixed_all[s] for s in good_keys])
transformation_matrix = compute_affine_transformation(moving_src, fixed_src)
#A, t, transformation_matrix = compute_affine_transformation_centroid(moving_src, fixed_src)
#transformation_matrix = get_umeyama(moving_src, fixed_src, scaling=True)
print(np.round(transformation_matrix, 2))

[[ 9.70000e-01 -4.00000e-02  1.00000e-02 -1.33169e+03]
 [ 1.30000e-01  1.23000e+00  7.00000e-02 -4.36247e+03]
 [ 1.00000e-02 -2.00000e-02  1.15000e+00  6.45840e+02]
 [ 0.00000e+00  0.00000e+00  0.00000e+00  1.00000e+00]]


In [8]:
df_list = []
error = []
transformed_dict = {}
for structure in common_keys:
    moving0 = np.array(moving_all[structure])
    fixed0 = np.array(fixed_all[structure]) 
    transformed = apply_affine_transform(moving0, transformation_matrix)
    transformed = [x for x in transformed]
    difference = [a - b for a, b in zip(transformed, fixed0)]
    ss = sum_square_com(difference)
    row = [structure, np.round(moving0), np.round(fixed0), 
           np.round(transformed), np.round(difference), ss]
    df_list.append(row)
    error.append(ss)
    transformed_dict[structure] = transformed
print('RMS', sum(error)/len(df_list))
# MD589 to Allen RMS 260.0211852431133
# MD585 to Allen RMS 263.314352291951
# MD594 to Allen RMS 250.79820210419254

RMS 237.06805950085774


In [9]:
structure = 'SC'
print(f'{moving_name} {structure} {np.round(np.array(moving_all[structure]))}')
print(f'{fixed_name} {structure} {np.round(np.array(fixed_all[structure]))}')
print(f'{moving_name} transformed {structure} {np.round(np.array(transformed_dict[structure]))}')

AtlasV8 SC [10949.  3924.  4399.]
Allen SC [9140. 2388. 5692.]
AtlasV8 transformed SC [9209. 2235. 5682.]


In [10]:
columns = ['structure', moving_name, fixed_name, 'transformed', 'difference', 'sumsquares']
df = pd.DataFrame(df_list, columns=columns)
df.index.name = 'Index'
df = df.round(2)
df.sort_values(by=['sumsquares'], inplace=True)
HTML(df.to_html(index=False))

structure,AtlasV8,Allen,transformed,difference,sumsquares
LRt_L,"[14380.0, 7519.0, 3300.0]","[12344.0, 6991.0, 4393.0]","[12380.0, 7042.0, 4362.0]","[35.0, 52.0, -31.0]",69.81
5N_R,"[11981.0, 6184.0, 5809.0]","[10193.0, 5286.0, 7293.0]","[10136.0, 5262.0, 7261.0]","[-57.0, -24.0, -32.0]",69.86
SNC_R,"[10093.0, 6235.0, 5686.0]","[8356.0, 5123.0, 7092.0]","[8298.0, 5067.0, 7108.0]","[-58.0, -55.0, 16.0]",81.78
PBG_R,"[11244.0, 5053.0, 6299.0]","[9401.0, 3846.0, 7833.0]","[9473.0, 3804.0, 7843.0]","[73.0, -42.0, 10.0]",84.5
5N_L,"[12015.0, 6306.0, 3088.0]","[10193.0, 5287.0, 4092.0]","[10128.0, 5219.0, 4130.0]","[-65.0, -67.0, 39.0]",101.26
IC,"[12075.0, 3840.0, 4365.0]","[10400.0, 2325.0, 5675.0]","[10306.0, 2277.0, 5651.0]","[-94.0, -48.0, -24.0]",108.21
LRt_R,"[14325.0, 7269.0, 5549.0]","[12344.0, 6991.0, 6993.0]","[12366.0, 6889.0, 6952.0]","[22.0, -102.0, -40.0]",111.77
VLL_L,"[11308.0, 6446.0, 2936.0]","[9464.0, 5176.0, 3890.0]","[9433.0, 5288.0, 3949.0]","[-31.0, 113.0, 59.0]",131.02
6N_L,"[12585.0, 6276.0, 4057.0]","[10771.0, 5215.0, 5291.0]","[10697.0, 5327.0, 5248.0]","[-75.0, 113.0, -43.0]",141.56
3N_L,"[11014.0, 5264.0, 4317.0]","[9102.0, 3794.0, 5521.0]","[9215.0, 3891.0, 5560.0]","[113.0, 96.0, 39.0]",153.4


In [11]:
l = [-2,3,4]
print(l)
l2 = [j**2 for j in l]
print(l2)
l3 = sum(l2)
print(l3)
l4 = np.sqrt(l3)
print(l4)

[-2, 3, 4]
[4, 9, 16]
29
5.385164807134504
