# DK52 Detection-based Alignment Experiment

In [1]:
import json
import pickle
from pathlib import Path

import numpy as np
import pandas as pd
import torch

torch.set_default_dtype(torch.float64)

## Prepare Input Data

### Atlas COMs

In [2]:
from scipy import ndimage

def get_atlas_coms(
    atlas_box_size=(1000, 1000, 300),
    atlas_box_scales=(10, 10, 20),
    atlas_raw_scale=10
):
    atlas_box_size = np.array(atlas_box_size)
    atlas_box_scales = np.array(atlas_box_scales)
    atlas_box_center = atlas_box_size / 2

    atlas_dir = Path('/net/birdstore/Active_Atlas_Data/data_root/atlas_data/atlasV7')
    origin_dir = atlas_dir / 'origin'
    volume_dir = atlas_dir / 'structure'

    atlas_coms = {}

    for origin_file, volume_file in zip(sorted(origin_dir.iterdir()), sorted(volume_dir.iterdir())):
        assert origin_file.stem == volume_file.stem
        name = origin_file.stem

        origin = np.loadtxt(origin_file)

        volume = np.load(volume_file)
        volume = np.rot90(volume, axes=(0, 1))
        volume = np.flip(volume, axis=0)

        # computer volume center of mass in raw array coordinates
        com = (origin + ndimage.measurements.center_of_mass(volume))

        # transform into the atlas box coordinates that neuroglancer assumes
        com = atlas_box_center + com * atlas_raw_scale / atlas_box_scales

        atlas_coms[name] = com

    return atlas_coms

atlas_coms = get_atlas_coms()
atlas_coms

{'10N_L': array([703.3414018 , 547.92857128, 135.63021318]),
 '10N_R': array([703.3414018 , 547.92857128, 164.36978682]),
 '12N': array([697.22688379, 567.09882511, 149.25      ]),
 '3N_L': array([364.55527451, 391.50561214, 143.23580737]),
 '3N_R': array([364.55527451, 391.50561214, 156.76419263]),
 '4N_L': array([399.44078717, 396.10405799, 137.62875327]),
 '4N_R': array([399.44078717, 396.10405799, 162.37124673]),
 '5N_L': array([456.79919285, 502.07948465,  81.90055473]),
 '5N_R': array([456.79919285, 502.07948465, 218.09944527]),
 '6N_L': array([517.15970035, 512.49218793, 132.31348358]),
 '6N_R': array([517.15970035, 512.49218793, 167.68651642]),
 '7N_L': array([529.63289179, 614.42425173,  92.59598246]),
 '7N_R': array([529.63289179, 614.42425173, 207.40401754]),
 '7n_L': array([494.89181904, 540.30703027,  96.03164647]),
 '7n_R': array([494.89181904, 540.30703027, 203.96835353]),
 'AP': array([702.30623016, 525.65250533, 151.        ]),
 'Amb_L': array([612.43978309, 603.825796

### Brain COMs (detection)

In [3]:
with open('DK52_detection_with_noise.pkl', 'rb') as f:
    detection_list = pickle.load(f)
detection_map = {}
brain_coms = {}
# sids = []
# score_maps = []
# score_map_sizes = []
# score_map_scales = []
# max_scores = []
for detection in detection_list:
    detection_map[detection['id']] = detection
    brain_coms[detection['id']] = np.array(detection['center'])

#     sids.append(structure['id'])
#     score_map = structure['score']
#     score_maps.append(score_map)
#     score_map_sizes.append(score_map.shape)
#     score_map_scales.append(np.array(structure['scale']))
#     max_scores.append(score_map.max())
# score_maps = np.array(score_maps)
# score_map_sizes = np.array(score_map_sizes)
# score_map_scales = np.array(score_map_scales)
# max_scores = np.array(max_scores)
brain_coms

{'10N_L': array([45737, 17865,   253]),
 '10N_R': array([45237, 17399,   219]),
 '12N': array([46654, 18792,   249]),
 '3N_L': array([36076, 14529,   230]),
 '3N_R': array([36199, 14801,   225]),
 '4N_L': array([34831, 16623,   180]),
 '4N_R': array([37779, 14777,   233]),
 '5N_L': array([38733, 20351,   172]),
 '5N_R': array([38819, 19394,   315]),
 '6N_L': array([40600, 19468,   255]),
 '6N_R': array([40743, 19653,   214]),
 '7N_L': array([42239, 26123,   184]),
 '7N_R': array([34295, 18803,   193]),
 'AP': array([46314, 17280,   236]),
 'Amb_L': array([44456, 23431,   149]),
 'Amb_R': array([44863, 21010,   308]),
 'DC_L': array([44072, 18564,   106]),
 'DC_R': array([42405, 15158,   356]),
 'IC': array([34528,  9665,   228]),
 'LC_L': array([39564, 16854,   208]),
 'LC_R': array([40645, 16549,   288]),
 'LRt_L': array([49883, 23887,   210]),
 'LRt_R': array([48479, 23556,   307]),
 'PBG_L': array([35437, 16642,   147]),
 'PBG_R': array([36209, 15054,   336]),
 'Pn_L': array([36332,

### Brain COMs (manual)

In [4]:
with open('dk52_coms_man.json', 'r') as f:
    brain_coms_man = json.load(f)
for k in brain_coms_man.keys():
    brain_coms_man[k] = np.array(brain_coms_man[k])
brain_coms_man

{'10N_L': array([45542, 17755,   218]),
 '10N_R': array([45714, 17685,   247]),
 '12N': array([45596, 18540,   234]),
 '3N_L': array([34961, 16285,   232]),
 '3N_R': array([34749, 16231,   243]),
 '4N_L': array([36247, 16030,   223]),
 '4N_R': array([36247, 16056,   247]),
 '5N_L': array([38765, 19548,   172]),
 '5N_R': array([38914, 18314,   309]),
 '6N_L': array([40686, 19124,   257]),
 '6N_R': array([40599, 19425,   220]),
 '7N_L': array([44309, 20809,   307]),
 '7N_R': array([42465, 22042,   309]),
 'AP': array([46190, 16965,   230]),
 'Amb_L': array([44094, 21914,   172]),
 'DC_L': array([41986, 17520,   130]),
 'DC_R': array([42227, 16113,   343]),
 'IC': array([37176,  9511,   222]),
 'LC_L': array([39808, 16939,   188]),
 'LC_R': array([39934, 16233,   278]),
 'LRt_L': array([47708, 22042,   183]),
 'LRt_R': array([47409, 21332,   296]),
 'PBG_L': array([35560, 15975,   138]),
 'PBG_R': array([35814, 15440,   330]),
 'Pn_L': array([35794, 24118,   220]),
 'Pn_R': array([35751, 

### Scales

In [5]:
atlas_scale = np.diag([10, 10, 20])
atlas_scale

array([[10,  0,  0],
       [ 0, 10,  0],
       [ 0,  0, 20]])

In [6]:
brain_scale = np.diag([0.325, 0.325, 20])
brain_scale

array([[ 0.325,  0.   ,  0.   ],
       [ 0.   ,  0.325,  0.   ],
       [ 0.   ,  0.   , 20.   ]])

## Define Transformation

In [7]:
from torch.nn import Module
from torch.nn import Parameter

class LandmarkAffineTransform(Module):
    """Affine transform for landmark."""

    def __init__(self):
        super().__init__()
        self._linear_matrix = Parameter(torch.eye(3))
        self._translation = Parameter(torch.zeros(3))

    def init_guess(self, fixed_landmark, moving_landmark):
        """Make an initial guess of parameters."""
        fixed_landmark = np.array(fixed_landmark)
        moving_landmark = np.array(moving_landmark)
        r, t = umeyama(moving_landmark.T, fixed_landmark.T)
        r = r.T
        t = t.T[0]
        self._linear_matrix = Parameter(torch.tensor(r))
        self._translation = Parameter(torch.tensor(t))
        return r, t

    def transform_numpy(self, moving):
        """Apply transform to the moving object and get numpy result."""
        return self.forward(torch.tensor(moving)).detach().numpy()

    def get_linear_matrix(self):
        """Get the linear transform matrix."""
        return self._linear_matrix.detach().clone().numpy()

    def get_translation(self):
        """Get the translation vector."""
        return self._translation.detach().clone().numpy()

    def forward(self, moving_landmark):
        """Forward the model."""
        return moving_landmark @ self._linear_matrix + self._translation

def umeyama(src, dst, with_scaling=True):
    """The Umeyama algorithm to register landmarks with rigid transform.
    See the paper "Least-squares estimation of transformation parameters
    between two point patterns".
    """
    src = np.array(src)
    dst = np.array(dst)
    assert src.shape == dst.shape
    assert len(src.shape) == 2
    dim, n_pts = src.shape

    src_mean = np.mean(src, axis=1).reshape(-1, 1)
    dst_mean = np.mean(dst, axis=1).reshape(-1, 1)

    src_demean = src - src_mean
    dst_demean = dst - dst_mean

    u, s, vh = np.linalg.svd(dst_demean @ src_demean.T / n_pts)

    # deal with reflection
    e = np.ones(dim)
    if np.linalg.det(u) * np.linalg.det(vh) < 0:
        print("reflection detected")
        e[-1] = -1

    r = u @ np.diag(e) @ vh

    if with_scaling:
        src_var = (src_demean ** 2).sum(axis=0).mean()
        c = sum(s * e) / src_var
        r *= c

    t = dst_mean - r @ src_mean

    return r, t

## Define Loss

In [8]:
from torch.nn.functional import grid_sample

def get_scores(fix_landmarks, mov_landmarks, score_maps, score_map_scales):
    score_maps = torch.tensor(score_maps)
    n, d, h, w = score_maps.shape
    score_map_size = torch.tensor([d - 1, h - 1, w - 1])
    score_map_scales = torch.tensor(score_map_scales)

    grid = (mov_landmarks - fix_landmarks) / score_map_scales
    grid = grid / (score_map_size / 2) # normalize to [-1, 1]

    scores = grid_sample(
        score_maps.view(n, 1, d, h, w),
        grid.view(n, 1, 1, 1, 3),
        mode='bilinear',
        padding_mode='zeros',
        align_corners=True
    )
    return scores.flatten()

def score_map_loss(fix_landmarks, mov_landmarks, score_maps, score_map_scales):
    scores = get_scores(fix_landmarks, mov_landmarks, score_maps, score_map_scales)
    return -scores.sum()

## Define Registration

In [9]:
def registrate(
    atlas_landmarks, atlas_scale,
    brain_landmarks, brain_scale,
    score_maps, score_map_scales,
    n_iter_affine=4000,
    n_iter_score=200
):
    dtype = torch.float64

    atlas_landmarks = torch.tensor(atlas_landmarks, dtype=dtype)
    brain_landmarks = torch.tensor(brain_landmarks, dtype=dtype)
    atlas_scale = torch.tensor(atlas_scale, dtype=dtype)
    brain_scale = torch.tensor(brain_scale, dtype=dtype)

    transform = LandmarkAffineTransform()
    transform.init_guess(
        brain_landmarks @ brain_scale,
        atlas_landmarks @ atlas_scale
    )
    r0 = transform.get_linear_matrix()
    t0 = transform.get_translation()
    
    optimizer = torch.optim.Adam(transform.parameters())

    print('Optimizing with MSEloss')
    loss_fun = torch.nn.MSELoss()
    print_step = n_iter_affine // 10
    for i_iter in range(n_iter_affine):

        def closure():
            optimizer.zero_grad()
            fix_landmarks = brain_landmarks @ brain_scale
            mov_landmarks = transform(atlas_landmarks @ atlas_scale)
            loss = loss_fun(fix_landmarks, mov_landmarks)
            loss.backward()
            if i_iter % print_step == 0:
                print(i_iter, loss)
            return loss

        optimizer.step(closure)
    r1 = transform.get_linear_matrix()
    t1 = transform.get_translation()

    print('Optimizing with score_map_loss')
    loss_fun = score_map_loss
    print_step = n_iter_score // 10
    for i_iter in range(n_iter_score):

        def closure():
            optimizer.zero_grad()
            fix_landmarks = brain_landmarks
            mov_landmarks = transform(atlas_landmarks @ atlas_scale) @ torch.inverse(brain_scale)
            loss = loss_fun(fix_landmarks, mov_landmarks, score_maps, score_map_scales)
            loss.backward()
            if i_iter % print_step == 0:
                print(i_iter, loss)
            return loss

        optimizer.step(closure)
    r2 = transform.get_linear_matrix()
    t2 = transform.get_translation()

    return r0, t0, r1, t1, r2, t2

## Assemble Pipeline

In [10]:
def align_and_diagnose(
    atlas_coms, atlas_scale,
    brain_coms, brain_scale,
    brain_coms_man,
    detection_map,
    selection=None
):
    common_structures = set(atlas_coms.keys()) & set(brain_coms.keys()) & set(brain_coms_man.keys())
    common_structures = sorted(common_structures)
    
    # Registrate with selected structures

    if selection is None:
        selection = common_structures
    
    atlas_landmarks = np.array([atlas_coms[name] for name in selection])
    brain_landmarks = np.array([brain_coms[name] for name in selection])
    
    score_maps = np.array([detection_map[name]['score'] for name in selection])
    score_map_scales = np.array([detection_map[name]['scale'] for name in selection])

    r0, t0, r1, t1, r2, t2 = registrate(
        atlas_landmarks, atlas_scale,
        brain_landmarks, brain_scale,
        score_maps, score_map_scales
    )
    
    # Diagnose with all structures

    selection = common_structures
    
    atlas_landmarks = np.array([atlas_coms[name] for name in selection])
    brain_landmarks = np.array([brain_coms[name] for name in selection])
    brain_landmarks_man = np.array([brain_coms_man[name] for name in selection])
    
    score_maps = np.array([detection_map[name]['score'] for name in selection])
    score_map_scales = np.array([detection_map[name]['scale'] for name in selection])
    
    # Extract diagnostic information

    brain_coms_phys = brain_landmarks @ brain_scale
    brain_coms_man_phys = brain_landmarks_man @ brain_scale
    atlas_coms0_phys = atlas_landmarks @ atlas_scale @ r0 + t0
    atlas_coms1_phys = atlas_landmarks @ atlas_scale @ r1 + t1
    atlas_coms2_phys = atlas_landmarks @ atlas_scale @ r2 + t2
    
    score0 = get_scores(
        brain_coms_phys @ np.linalg.inv(brain_scale),
        atlas_coms0_phys @ np.linalg.inv(brain_scale),
        score_maps,
        score_map_scales
    )
    score1 = get_scores(
        brain_coms_phys @ np.linalg.inv(brain_scale),
        atlas_coms1_phys @ np.linalg.inv(brain_scale),
        score_maps,
        score_map_scales
    )
    score2 = get_scores(
        brain_coms_phys @ np.linalg.inv(brain_scale),
        atlas_coms2_phys @ np.linalg.inv(brain_scale),
        score_maps,
        score_map_scales
    )
    
    def compute_dist(pos1, pos2):
        return np.sqrt(np.square(pos2 - pos1).sum(axis=-1))

    dist0_det = compute_dist(brain_coms_phys, atlas_coms0_phys)
    dist1_det = compute_dist(brain_coms_phys, atlas_coms1_phys)
    dist2_det = compute_dist(brain_coms_phys, atlas_coms2_phys)
    dist0_man = compute_dist(brain_coms_man_phys, atlas_coms0_phys)
    dist1_man = compute_dist(brain_coms_man_phys, atlas_coms1_phys)
    dist2_man = compute_dist(brain_coms_man_phys, atlas_coms2_phys)
    
    max_scores = [score_map.max() for score_map in score_maps]
    score_map_sizes = np.array([score_map.shape for score_map in score_maps])
    score_map_size_phys = (score_map_sizes - 1) * score_map_scales @ brain_scale
    
    diag = {}
    diag['name'] = selection
    diag['max_score'] = max_scores
    diag['dx'] = score_map_size_phys.T[0]
    diag['dy'] = score_map_size_phys.T[1]
    diag['dz'] = score_map_size_phys.T[2]

    diag['atlas_rig_dist_man'] = dist0_man
    diag['atlas_rig_dist_det'] = dist0_det
    diag['atlas_rig_score'] = score0

    diag['atlas_aff_dist_man'] = dist1_man
    diag['atlas_aff_dist_det'] = dist1_det
    diag['atlas_aff_score'] = score1

    diag['atlas_det_dist_man'] = dist0_man
    diag['atlas_det_dist_det'] = dist0_det
    diag['atlas_det_score'] = score0
    diag = pd.DataFrame(diag)
    
    return diag, [r0, t0, r1, t1, r2, t2]

## Run Experiments

### Baseline

In [11]:
diag, transforms = align_and_diagnose(
    atlas_coms, atlas_scale,
    brain_coms, brain_scale,
    brain_coms_man,
    detection_map,
    selection=None
)
diag.to_csv('dk52_com_diag_baseline.csv')
diag.sort_values('max_score', ascending=False)

Optimizing with MSEloss
0 tensor(347250.6903, grad_fn=<MseLossBackward>)


  Variable._execution_engine.run_backward(


400 tensor(330615.2910, grad_fn=<MseLossBackward>)
800 tensor(330605.0419, grad_fn=<MseLossBackward>)
1200 tensor(330595.3107, grad_fn=<MseLossBackward>)
1600 tensor(330585.7069, grad_fn=<MseLossBackward>)
2000 tensor(330576.2090, grad_fn=<MseLossBackward>)
2400 tensor(330566.7603, grad_fn=<MseLossBackward>)
2800 tensor(330557.3433, grad_fn=<MseLossBackward>)
3200 tensor(330547.9482, grad_fn=<MseLossBackward>)
3600 tensor(330538.5690, grad_fn=<MseLossBackward>)
Optimizing with score_map_loss
0 tensor(-3.5247, grad_fn=<NegBackward>)
20 tensor(-4.5699, grad_fn=<NegBackward>)
40 tensor(-8.1178, grad_fn=<NegBackward>)
60 tensor(-8.6702, grad_fn=<NegBackward>)
80 tensor(-8.9471, grad_fn=<NegBackward>)
100 tensor(-8.9915, grad_fn=<NegBackward>)
120 tensor(-9.0611, grad_fn=<NegBackward>)
140 tensor(-9.0644, grad_fn=<NegBackward>)
160 tensor(-9.0905, grad_fn=<NegBackward>)
180 tensor(-9.0473, grad_fn=<NegBackward>)


Unnamed: 0,name,max_score,dx,dy,dz,atlas_rig_dist_man,atlas_rig_dist_det,atlas_rig_score,atlas_aff_dist_man,atlas_aff_dist_det,atlas_aff_score,atlas_det_dist_man,atlas_det_dist_det,atlas_det_score
10,6N_R,4.724561,594.75,594.75,600.0,811.269317,934.877806,0.0,788.782543,913.048188,0.0,811.269317,934.877806,0.0
5,4N_L,4.629459,594.75,594.75,600.0,332.180325,951.464056,0.0,327.774597,924.224928,0.0,332.180325,951.464056,0.0
9,6N_L,4.519583,594.75,594.75,600.0,679.123038,629.366303,0.0,669.286785,613.361061,0.0,679.123038,629.366303,0.0
27,SNC_L,4.469259,594.75,594.75,600.0,142.70113,152.566546,-0.727826,412.469873,411.126386,2.105887,142.70113,152.566546,-0.727826
6,4N_R,4.453532,594.75,594.75,600.0,266.238388,899.861889,0.0,243.07094,865.240729,0.0,266.238388,899.861889,0.0
23,PBG_R,4.16253,594.75,594.75,600.0,103.386935,317.734067,0.96048,228.192088,388.022701,0.277834,103.386935,317.734067,0.96048
0,10N_L,4.044719,594.75,594.75,600.0,643.48443,670.891055,0.0,480.558765,531.967005,0.0,643.48443,670.891055,0.0
25,Pn_R,4.022889,653.25,653.25,600.0,228.236774,1587.71898,0.0,369.046021,1386.131611,0.0,228.236774,1587.71898,0.0
1,10N_R,3.997256,594.75,594.75,600.0,565.634806,1074.097548,0.0,380.594623,937.222989,0.0,565.634806,1074.097548,0.0
38,VCA_R,3.988726,789.75,789.75,600.0,178.326092,1210.954057,0.0,26.984538,1277.456592,0.0,178.326092,1210.954057,0.0


### Kui's Good List

In [12]:
kui_list = ['5N_L', '12N', '6N_L', '6N_R', 'LC_L', 'LC_R', 'PBG_R', 'AP', 'Amb_L', 'VLL_L']
kui_list

['5N_L',
 '12N',
 '6N_L',
 '6N_R',
 'LC_L',
 'LC_R',
 'PBG_R',
 'AP',
 'Amb_L',
 'VLL_L']

In [13]:
diag, transforms = align_and_diagnose(
    atlas_coms, atlas_scale,
    brain_coms, brain_scale,
    brain_coms_man,
    detection_map,
    selection=kui_list
)
diag.to_csv('dk52_com_diag_kui_list.csv')
diag.sort_values('max_score', ascending=False)

Optimizing with MSEloss
0 tensor(84453.0993, grad_fn=<MseLossBackward>)
400 tensor(75433.7292, grad_fn=<MseLossBackward>)
800 tensor(75427.8097, grad_fn=<MseLossBackward>)
1200 tensor(75422.2241, grad_fn=<MseLossBackward>)
1600 tensor(75416.7348, grad_fn=<MseLossBackward>)
2000 tensor(75411.2847, grad_fn=<MseLossBackward>)
2400 tensor(75405.8982, grad_fn=<MseLossBackward>)
2800 tensor(75400.4909, grad_fn=<MseLossBackward>)
3200 tensor(75395.0536, grad_fn=<MseLossBackward>)
3600 tensor(75389.6315, grad_fn=<MseLossBackward>)
Optimizing with score_map_loss
0 tensor(5.3633, grad_fn=<NegBackward>)
20 tensor(0.8080, grad_fn=<NegBackward>)
40 tensor(-0.3125, grad_fn=<NegBackward>)
60 tensor(-1.2131, grad_fn=<NegBackward>)
80 tensor(-1.4526, grad_fn=<NegBackward>)
100 tensor(-1.8005, grad_fn=<NegBackward>)
120 tensor(-1.9337, grad_fn=<NegBackward>)
140 tensor(-2.0081, grad_fn=<NegBackward>)
160 tensor(-2.0425, grad_fn=<NegBackward>)
180 tensor(-2.0730, grad_fn=<NegBackward>)


Unnamed: 0,name,max_score,dx,dy,dz,atlas_rig_dist_man,atlas_rig_dist_det,atlas_rig_score,atlas_aff_dist_man,atlas_aff_dist_det,atlas_aff_score,atlas_det_dist_man,atlas_det_dist_det,atlas_det_score
10,6N_R,4.724561,594.75,594.75,600.0,764.191061,889.923287,0.0,730.832224,853.820385,0.0,764.191061,889.923287,0.0
5,4N_L,4.629459,594.75,594.75,600.0,147.671547,1021.781732,0.0,143.047858,1010.827402,0.0,147.671547,1021.781732,0.0
9,6N_L,4.519583,594.75,594.75,600.0,736.319374,701.461143,0.0,781.969377,742.0561,0.0,736.319374,701.461143,0.0
27,SNC_L,4.469259,594.75,594.75,600.0,254.435118,422.578016,-1.451316,321.980634,52.194812,1.183642,254.435118,422.578016,-1.451316
6,4N_R,4.453532,594.75,594.75,600.0,125.735939,774.993278,0.0,111.96663,718.882467,0.0,125.735939,774.993278,0.0
23,PBG_R,4.16253,594.75,594.75,600.0,191.618643,61.230554,2.983016,311.493271,118.865617,0.032735,191.618643,61.230554,2.983016
0,10N_L,4.044719,594.75,594.75,600.0,607.865088,857.573013,0.0,484.439774,659.056652,0.0,607.865088,857.573013,0.0
25,Pn_R,4.022889,653.25,653.25,600.0,281.120439,1603.775606,0.0,584.523068,1716.894177,0.0,281.120439,1603.775606,0.0
1,10N_R,3.997256,594.75,594.75,600.0,567.88174,963.531151,0.0,449.272061,951.406637,0.0,567.88174,963.531151,0.0
38,VCA_R,3.988726,789.75,789.75,600.0,293.787881,1283.955365,0.0,344.940101,1393.815812,0.0,293.787881,1283.955365,0.0


### Top Detection Score List

In [14]:
k_top = 12
common_structures = set(atlas_coms.keys()) & set(brain_coms.keys()) & set(brain_coms_man.keys())
common_structures = sorted(common_structures)
max_scores = np.array([detection_map[name]['score'].max() for name in common_structures])
indices = max_scores.argsort()[-k_top:]
top_score_list = [common_structures[i] for i in indices]
top_score_list

['PBG_L',
 'RtTg',
 'VCA_R',
 '10N_R',
 'Pn_R',
 '10N_L',
 'PBG_R',
 '4N_R',
 'SNC_L',
 '6N_L',
 '4N_L',
 '6N_R']

In [15]:
diag, transforms = align_and_diagnose(
    atlas_coms, atlas_scale,
    brain_coms, brain_scale,
    brain_coms_man,
    detection_map,
    selection=top_score_list
)
diag.to_csv('dk52_com_diag_top_score_list.csv')
diag.sort_values('max_score', ascending=False)

Optimizing with MSEloss
0 tensor(244884.0640, grad_fn=<MseLossBackward>)
400 tensor(205770.1516, grad_fn=<MseLossBackward>)
800 tensor(205740.9131, grad_fn=<MseLossBackward>)
1200 tensor(205719.5532, grad_fn=<MseLossBackward>)
1600 tensor(205698.5911, grad_fn=<MseLossBackward>)
2000 tensor(205677.8190, grad_fn=<MseLossBackward>)
2400 tensor(205657.1549, grad_fn=<MseLossBackward>)
2800 tensor(205636.5586, grad_fn=<MseLossBackward>)
3200 tensor(205616.0762, grad_fn=<MseLossBackward>)
3600 tensor(205595.4952, grad_fn=<MseLossBackward>)
Optimizing with score_map_loss
0 tensor(-0.5013, grad_fn=<NegBackward>)
20 tensor(-0.7158, grad_fn=<NegBackward>)
40 tensor(-1.0923, grad_fn=<NegBackward>)
60 tensor(-1.1198, grad_fn=<NegBackward>)
80 tensor(-1.1528, grad_fn=<NegBackward>)
100 tensor(-1.1652, grad_fn=<NegBackward>)
120 tensor(-1.1784, grad_fn=<NegBackward>)
140 tensor(-1.1823, grad_fn=<NegBackward>)
160 tensor(-1.1873, grad_fn=<NegBackward>)
180 tensor(-1.1861, grad_fn=<NegBackward>)


Unnamed: 0,name,max_score,dx,dy,dz,atlas_rig_dist_man,atlas_rig_dist_det,atlas_rig_score,atlas_aff_dist_man,atlas_aff_dist_det,atlas_aff_score,atlas_det_dist_man,atlas_det_dist_det,atlas_det_score
10,6N_R,4.724561,594.75,594.75,600.0,710.876086,837.501486,0.0,699.325295,828.585402,0.0,710.876086,837.501486,0.0
5,4N_L,4.629459,594.75,594.75,600.0,478.046679,1018.238177,0.0,494.914156,994.72131,0.0,478.046679,1018.238177,0.0
9,6N_L,4.519583,594.75,594.75,600.0,739.363125,691.832645,0.0,697.933438,656.550732,0.0,739.363125,691.832645,0.0
27,SNC_L,4.469259,594.75,594.75,600.0,187.984841,352.867696,0.0,653.624785,840.530224,0.0,187.984841,352.867696,0.0
6,4N_R,4.453532,594.75,594.75,600.0,426.579773,1060.267192,0.0,477.926244,1090.863768,0.0,426.579773,1060.267192,0.0
23,PBG_R,4.16253,594.75,594.75,600.0,281.089167,448.330451,0.0,511.758661,682.987593,0.0,281.089167,448.330451,0.0
0,10N_L,4.044719,594.75,594.75,600.0,318.199467,715.725258,0.0,183.51695,862.030499,0.0,318.199467,715.725258,0.0
25,Pn_R,4.022889,653.25,653.25,600.0,282.792787,1470.157989,0.0,638.599971,916.055074,0.0,282.792787,1470.157989,0.0
1,10N_R,3.997256,594.75,594.75,600.0,270.242638,730.834629,0.0,221.013756,378.838238,0.0,270.242638,730.834629,0.0
38,VCA_R,3.988726,789.75,789.75,600.0,243.992136,1125.126387,0.0,465.286394,944.606944,0.0,243.992136,1125.126387,0.0
