# Alignment Error Visualization

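This notebook loads the center-of-mass (COM) data for brain structures from the atlas data directory on disk. It fits an affine transformation that maps the COMs of a moving brain/atlas onto those of a fixed one, then quantifies the remaining alignment error for each structure. The main results are the table and bar plot near the end of the notebook.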

In [None]:
import os
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import SimpleITK as sitk
import matplotlib.pyplot as plt

# Make the repository's `library` package importable: the pipeline root is
# taken to be two directory levels above the current working directory.
_working_dir = Path('./').absolute()
PIPELINE_ROOT = _working_dir.parents[1].as_posix()
sys.path.append(PIPELINE_ROOT)
print(PIPELINE_ROOT)

In [None]:
from library.controller.sql_controller import SqlController
from library.image_manipulation.filelocation_manager import FileLocationManager
from library.atlas.atlas_utilities import affine_transform_point, get_affine_transformation, \
    fetch_coms, list_coms, compute_affine_transformation, affine_transform_volume
from library.atlas.brain_structure_manager import BrainStructureManager
from library.utilities.utilities_process import M_UM_SCALE, SCALING_FACTOR, random_string, \
    read_image, write_image


In [None]:
def sum_square_com(com):
    """
    Return the Euclidean length of ``com``.

    Despite the name, this is the square root of the sum of the squared
    components, i.e. the distance of the vector from the origin.
    """
    squared_total = sum(component**2 for component in com)
    return np.sqrt(squared_total)

def get_coms(animal, scaling_factor=1,
             data_root='/net/birdstore/Active_Atlas_Data/data_root/atlas_data'):
    """
    Fetch the COMs (centers of mass) of an animal's structures from disk.

    Every regular file in ``<data_root>/<animal>/com`` holds one structure's
    COM as numbers readable by ``np.loadtxt`` and is named after that
    structure (e.g. ``SC.txt`` -> ``'SC'``). The data is stored in micrometers.

    :param animal: animal/atlas name; used as a sub-directory of ``data_root``.
    :param scaling_factor: every COM is divided by this value.
    :param data_root: directory containing the per-animal data folders;
        defaults to the shared atlas data location.
    :return: dict mapping structure name -> numpy array, inserted in sorted
        file-name order. Empty if the ``com`` directory does not exist.
    """
    coms = {}
    dirpath = os.path.join(data_root, animal, 'com')
    if not os.path.isdir(dirpath):
        return coms
    for file in sorted(os.listdir(dirpath)):
        filepath = os.path.join(dirpath, file)
        # Sub-directories and other non-regular entries cannot be parsed by
        # np.loadtxt and would abort the whole load; skip them.
        if not os.path.isfile(filepath):
            continue
        com = np.loadtxt(filepath)
        com /= scaling_factor
        coms[Path(file).stem] = com
    return coms


In [None]:
moving_name = 'AtlasV8'
fixed_name = 'Allen'
moving_all = get_coms(moving_name, scaling_factor=1)
fixed_all = get_coms(fixed_name, scaling_factor=1)
# Structures excluded from evaluation even if they appear in good_keys.
bad_keys = ('10N_L', '10N_R')
# Hand-picked structures used to fit the affine transformation. Commented-out
# entries were tried and left out; trailing notes record their effect on the error.
good_keys = ['SC', 'IC']
good_keys.extend(['AP'])
good_keys.extend(['3N_L', '3N_R'])
good_keys.extend(['4N_L', '4N_R'])
good_keys.extend(['5N_L', '5N_R'])
good_keys.extend(['6N_L', '6N_R'])
good_keys.extend(['7N_L', '7N_R'])
#good_keys.extend(['7n_L','7n_R']) # upped the rms a bit
good_keys.extend(['Amb_L', 'Amb_R'])
good_keys.extend(['DC_L', 'DC_R'])
#good_keys.extend(['LC_L', 'LC_R'])
good_keys.extend(['LRt_L', 'LRt_R'])
#good_keys.extend(['PBG_L','PBG_R'])
good_keys.extend(['SNC_L', 'SNC_R'])
#good_keys.extend(['SNR_L','SNR_R']) # upped a bit
good_keys.extend(['Sp5C_L', 'Sp5C_R']) # improved 5
good_keys.extend(['Sp5I_L', 'Sp5I_R']) # improved 4
good_keys.extend(['Sp5O_L', 'Sp5O_R']) # improved 1
good_keys.extend(['VLL_L', 'VLL_R']) # improved 2

# Only structures present in BOTH COM sets can be used; indexing a missing one
# below would raise KeyError (e.g. when get_coms finds no COM directory).
shared_keys = moving_all.keys() & fixed_all.keys()
missing_keys = [s for s in good_keys if s not in shared_keys]
if missing_keys:
    print(f'Skipping structures missing from {moving_name} or {fixed_name}: {missing_keys}')
good_keys = [s for s in good_keys if s in shared_keys]
# To fit on every shared structure instead, use: good_keys = sorted(shared_keys)

# Evaluation set. A sorted list (not a set) gives a deterministic row order and
# can be indexed by later cells.
common_keys = sorted(set(good_keys) - set(bad_keys))
if not common_keys:
    raise ValueError(
        f'No structures to evaluate: none of good_keys is shared by {moving_name} and {fixed_name}')

moving_src = np.array([moving_all[s] for s in good_keys])
fixed_src = np.array([fixed_all[s] for s in good_keys])
print(len(good_keys))
transformation_matrix = compute_affine_transformation(moving_src, fixed_src)
# Sanity check: the second fitting point in moving and fixed space.
print(np.round(moving_src[1]))
print(np.round(fixed_src[1]))

df_list = []
error = []
transformed_dict = {}
for structure in common_keys:
    moving0 = np.array(moving_all[structure])
    fixed0 = np.array(fixed_all[structure])
    transformed = affine_transform_point(moving0, transformation_matrix)
    difference = [a - b for a, b in zip(transformed, fixed0)]
    # Euclidean distance between the transformed and the fixed COM.
    ss = sum_square_com(difference)
    row = [structure, np.round(moving0), np.round(fixed0),
           np.round(transformed), np.round(difference), ss]
    df_list.append(row)
    error.append(ss)
    transformed_dict[structure] = transformed

# NOTE: despite its name, `rms` is the MEAN per-structure Euclidean error, not a
# root-mean-square. The name and the "RMS:" label are kept because later cells
# plot it as the mean and the reference numbers recorded in the next cell were
# computed this way. The true RMS is reported separately.
rms = sum(error) / len(df_list)
true_rms = np.sqrt(np.mean(np.square(error)))
print(f'RMS: {rms} observations: {len(df_list)}')
print(f'True RMS (sqrt of mean squared error): {true_rms}')

In [None]:
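# Errors recorded from earlier runs of the cell above. "RMS" here is the value
# that cell prints: the mean per-structure Euclidean distance between the
# transformed and fixed COMs (presumably micrometers, as stored on disk).
# MD589 to Allen RMS 260.0211852431133
# MD585 to Allen RMS 263.314352291951
# MD594 to Allen RMS 250.79820210419254
# AtlasV8 DB to Allen RMS: 237.06805950085737 observations: 37
# MD585 to MD594 152.06606097021333 observations: 51
# MD585 to Allen 263.31435 observations: 37
# MD589 to Allen 260.02 observations: 37
# MD594 to Allen 250.79 observations: 37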

In [None]:
# Spot-check the fitted transformation on one structure: print its moving COM
# before and after the affine transform, the fixed COM, and the residual.
structure = 'SC'
if structure not in moving_all or structure not in fixed_all:
    # Fall back to a structure evaluated above. sorted() works whether
    # common_keys is a set or a list (indexing a set raises TypeError) and
    # makes the choice deterministic.
    structure = sorted(common_keys)[0]
com = moving_all[structure]
transformed_structure = affine_transform_point(com, transformation_matrix)

print(f'{moving_name} {structure} non trans {np.round(np.array(com))}')
print(f'{moving_name} {structure} apply trans {np.round(transformed_structure)}')
print(f'{fixed_name} {structure} {np.round(np.array(fixed_all[structure]))}')
diff = transformed_structure - fixed_all[structure]
print(f'{moving_name}->{fixed_name} error {structure} {diff}')
#print(f'{moving_name}->{fixed_name} tfm error {structure} {comdiff}')

In [None]:
# One row per evaluated structure, ordered from best to worst alignment.
# NOTE: 'sumsquares' holds the Euclidean distance (square root of the sum of
# squares), not the raw sum of squares.
columns = ['structure', moving_name, fixed_name, 'transformed', 'difference', 'sumsquares']
df = (
    pd.DataFrame(df_list, columns=columns)
    .rename_axis('Index')
    .round(4)
    .sort_values(by=['sumsquares'])
)
# Optional export:
#df.to_csv('/home/eddyod/programming/pipeline/docs/sphinx/source/_static/results.csv', index=False)
df.head(50)
#20	3N_R	[1079.0, 531.0, 485.0]	[910.0, 380.0, 17.0]	[873.0, 358.0, 298.0]	[-37.0, -21.0, 281.0]	284.3851
#7	4N_L	[1135.0, 529.0, 423.0]	[959.0, 378.0, 544.0]	[923.0, 362.0, 356.0]	[-36.0, -16.0, -188.0]	192.0327

In [None]:
# Per-structure alignment error (in the table's sorted order), with the mean
# error from the fitting cell drawn as a dashed reference line.
plt.figure(figsize=(30, 5))
plt.axhline(y=rms, linestyle='--', linewidth=2, color='red', label='Mean')
plt.text(0, rms, f"Mean RMS={round(rms,2)}")
plt.bar(df['structure'], df['sumsquares'])
plt.xlabel('structure')
plt.ylabel('distance between transformed and fixed COM')
# The axhline label is only displayed once a legend is drawn.
plt.legend()
plt.show()

In [None]:
# Scratch check (unrelated to the alignment analysis): report which of the
# fields being changed also appear in changed_fields.
fields_we_are_changing = ['active', 'scene_number']
changed_fields = ['copy_copy', 'active']
changed_fields.remove('copy_copy')

# Compute the set intersection once and test it by truthiness.
overlap = set(fields_we_are_changing).intersection(changed_fields)
if not overlap:
    print("empty set")
else:
    print(f"found {overlap}")
print(changed_fields)