Test kESI vs. kCSD in the four-sphere model geometry, where the forward FEM simulation matches the kESI assumptions.

The comparison is handicapped: the ground-truth (GT) CSDs are kCSD eigensources, which gives kCSD an advantage.

In [None]:
import configparser
import os
import collections
import itertools

import numpy as np
import pandas as pd

import kesi
import kesi._verbose as verbose
import _common_new as common
import _fast_reciprocal_reconstructor as frr

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

import cbf

In [None]:
# Tag appended to the output filename of the save cell; 'handicapped' means
# the GT CSDs are kCSD eigensources (see the notebook description).
COMPARISON = 'handicapped'

# Geometry / mesh / element degree under which the sampled kESI correction
# potentials were computed (used in FILENAME_PATTERN below).
INV_GEOMETRY = 'four_spheres_csf_3_mm'
INV_MESH = 'normal'
INV_DEGREE = 2

# Geometry / mesh / element degree of the forward FEM model that generates
# the ground-truth potentials; same geometry as above, different mesh/degree.
FWD_GEOMETRY = INV_GEOMETRY
FWD_MESH = 'coarse'
FWD_DEGREE = 3

In [None]:
# `sampled/{K}` directory level of the correction-potential files; its exact
# meaning is defined by the sampling script - TODO confirm.
K = 9

# `{{name}}` survives the f-string as `{name}` and is filled per electrode.
FILENAME_PATTERN = f'FEM/solutions/paper/{INV_GEOMETRY}/{INV_MESH}/{INV_DEGREE}/sampled/{K}/{{name}}.npz'

In [None]:
# Mesh and conductivity (INI) configuration of the forward FEM model.
FEM_MESH = f'FEM/meshes/meshes/{FWD_GEOMETRY}_plain/{FWD_MESH}.xdmf'
FEM_CONFIG = f'FEM/model_properties/{FWD_GEOMETRY}.ini'

# FEM

In [None]:
import dolfin
import FEM.fem_common as fc
import scipy.interpolate as si

In [None]:
# Outer-boundary points with z below this value get a zero-potential
# Dirichlet condition in ForwardModel.__call__ (same length unit as the mesh).
GROUNDED_PLATE_AT = -0.088


class ForwardModel(object):
    """
    FEM forward model: potential generated by a given CSD.

    Solves the weak form of ``-div(sigma * grad(V)) = CSD`` with
    piecewise-constant conductivities ``sigma`` (one per marked subdomain),
    ``V = 0`` on the part of the outer boundary below ``GROUNDED_PLATE_AT``
    and the natural (zero-flux) condition on the rest of the boundary.

    Parameters
    ----------
    mesh : str
        Path to the mesh file (``.xdmf``).  Subdomain markers are read from
        ``<mesh path without extension>_subdomains.xdmf``.
    degree : int
        Degree of the continuous Lagrange ('CG') elements.
    config : str
        Path to an INI file; every section having both ``volume`` and
        ``conductivity`` options defines the conductivity of the subdomain
        marked with ``volume``.
    """
    # XXX: duplicated code with FEM classes
    def __init__(self, mesh, degree, config):
        self.fm = fc.FunctionManager(mesh, degree, 'CG')
        self.config = configparser.ConfigParser()
        self.config.read(config)

        # Drop the extension (e.g. '.xdmf') to locate the sibling subdomain file.
        mesh_filename = os.path.splitext(mesh)[0]

        self.V = self.fm.function_space
        mesh = self.fm.mesh

        n = self.V.dim()
        d = mesh.geometry().dim()

        # One row of coordinates per degree of freedom.  ``reshape`` rather
        # than in-place ``ndarray.resize``: resize refuses arrays it does not
        # own (or that are referenced elsewhere) and silently pads/truncates
        # on a size mismatch, whereas reshape fails loudly.
        self.dof_coords = self.V.tabulate_dof_coordinates().reshape((n, d))

        # Reused by every __call__ to hold the interpolated CSD.
        self.csd_f = self.fm.function()

        with dolfin.XDMFFile(mesh_filename + '_subdomains.xdmf') as fh:
            mvc = dolfin.MeshValueCollection("size_t", mesh, 3)
            fh.read(mvc, "subdomains")
            self.subdomains = dolfin.cpp.mesh.MeshFunctionSizet(mesh, mvc)
            # ``self.dx(i)`` integrates over the cells marked ``i``.
            self.dx = dolfin.Measure("dx")(subdomain_data=self.subdomains)

    @property
    def CONDUCTIVITY(self):
        """Yield ``(volume_id, conductivity)`` for every conductive INI section."""
        for section in self.config.sections():
            if self._is_conductive_volume(section):
                yield (self.config.getint(section, 'volume'),
                       self.config.getfloat(section, 'conductivity'))

    def _is_conductive_volume(self, section):
        """Whether ``section`` defines both ``volume`` and ``conductivity``."""
        return (self.config.has_option(section, 'volume')
                and self.config.has_option(section, 'conductivity'))

    def __call__(self, csd_interpolator):
        """
        Solve for the potential generated by a CSD.

        Parameters
        ----------
        csd_interpolator : callable
            Called once with the ``(n_dofs, dim)`` array of DOF coordinates;
            must return the CSD values at those points.

        Returns
        -------
        Function created by ``self.fm.function()`` holding the potential.
        """
        self.csd_f.vector()[:] = csd_interpolator(self.dof_coords)

        # Ground the outer boundary below GROUNDED_PLATE_AT.
        dirichlet_bc_gt = dolfin.DirichletBC(self.V,
                                             dolfin.Constant(0),
                                             (lambda x, on_boundary:
                                              on_boundary
                                              and x[2] < GROUNDED_PLATE_AT))
        test = self.fm.test_function()
        trial = self.fm.trial_function()
        potential = self.fm.function()

        dx = self.dx
        # Bilinear form: sum over conductive subdomains i of
        # sigma_i * <grad V, grad v> integrated over subdomain i.
        a = sum(dolfin.Constant(c)
                * dolfin.inner(dolfin.grad(trial),
                               dolfin.grad(test))
                * dx(i)
                for i, c
                in self.CONDUCTIVITY)
        # Source term integrated over the whole domain.
        L = self.csd_f * test * dx

        b = dolfin.assemble(L)
        A = dolfin.assemble(a)
        dirichlet_bc_gt.apply(A, b)

        solver = dolfin.KrylovSolver("cg", "ilu")
        solver.parameters["maximum_iterations"] = 10000
        solver.parameters["absolute_tolerance"] = 1E-8
        # solver.parameters["monitor_convergence"] = True  # Goes to Jupyter server output stream
        solver.solve(A, potential.vector(), b)

        return potential

In [None]:
# Forward FEM model producing the ground-truth potentials ("images").
%time fem_gt = ForwardModel(FEM_MESH, FWD_DEGREE, FEM_CONFIG)

# kernel construction

In [None]:
# Electrode grids in use; uncomment a letter to add that grid.
electrode_grid_names = [
    # 'A',
    'B',
    # 'C',
    'D',
    # 'E',
]

# '<grid>_<NN>' names for contacts 00-11, contact-major order:
# B_00, D_00, B_01, D_01, ...
electrode_names = ['{}_{:02d}'.format(grid, contact)
                   for contact, grid in itertools.product(range(12),
                                                          electrode_grid_names)]

In [None]:
class Electrode(object):
    """
    Electrode with a grid-sampled kESI correction potential.

    Only the metadata (sampling grid, location, base conductivity) is kept in
    memory; the correction potential is re-read from ``filename`` on every
    ``correction_leadfield`` call.

    Parameters
    ----------
    filename : str
        Path to the ``.npz`` file with the sampled correction potential.  It
        must provide ``X``, ``Y``, ``Z`` (sampling grid nodes per axis),
        ``LOCATION`` (electrode ``x, y, z``), ``BASE_CONDUCTIVITY`` and
        ``CORRECTION_POTENTIAL`` (indexed by grid node indices).
    """

    def __init__(self, filename):
        self.filename = filename
        with np.load(filename) as archive:
            self.SAMPLING_GRID = [archive[axis] for axis in ('X', 'Y', 'Z')]
            location = archive['LOCATION']
            self.x, self.y, self.z = location
            self.base_conductivity = archive['BASE_CONDUCTIVITY']

    def correction_leadfield(self, X, Y, Z):
        """
        Correction of the electrode's leadfield for violation of the kCSD
        assumptions, evaluated at the given points.

        Parameters
        ----------
        X, Y, Z : np.array
            Coordinate matrices of the same shape.  Every point is expected
            to coincide with a node of the sampling grid (no interpolation).

        Returns
        -------
        np.array
            Correction potential values, shaped like ``X``.
        """
        with np.load(self.filename) as archive:
            samples = archive['CORRECTION_POTENTIAL']
            return self._correction_leadfield(samples, [X, Y, Z])

    def _correction_leadfield(self, SAMPLES, XYZ):
        # Points lie on grid nodes, so a plain lookup replaces the
        # (time-consuming) interpolation.
        node_index = self._sampling_grid_indices(XYZ)
        return SAMPLES[node_index]

    def _sampling_grid_indices(self, XYZ):
        # For each axis: index of the first node >= the coordinate, which is
        # the node itself when the coordinate is a grid node.
        per_axis = []
        for nodes, coordinates in zip(self.SAMPLING_GRID, XYZ):
            per_axis.append(np.searchsorted(nodes, coordinates))
        return tuple(per_axis)

In [None]:
# One Electrode per name; each reads its sampled correction-potential file.
electrodes = [Electrode(FILENAME_PATTERN.format(name=name))
              for name in electrode_names]

In [None]:
# Electrode positions as a DataFrame with columns NAME, X, Y, Z.
# NOTE: this explicit loop leaves `name` and `electrode` (the LAST electrode)
# bound at module level; later cells may read `electrode`, so check before
# turning the loop into a comprehension.
ELECTRODES = []
for name, electrode in zip(electrode_names, electrodes):
    ELECTRODES.append({'NAME': name,
                       'X': electrode.x,
                       'Y': electrode.y,
                       'Z': electrode.z})
ELECTRODES = pd.DataFrame(ELECTRODES)

In [None]:
# Sanity plot: electrodes in the XZ plane against origin-centred circles of
# radii 0.090, 0.086, 0.082 and 0.079 (the sphere radii of the model).
plt.scatter(ELECTRODES.X, ELECTRODES.Z, marker='.')
plt.gca().add_artist(plt.Circle((0,0), radius=0.090, ls=':', edgecolor=cbf.BLACK, facecolor='none'))
plt.gca().add_artist(plt.Circle((0,0), radius=0.086, ls=':', edgecolor=cbf.VERMILION, facecolor='none'))
plt.gca().add_artist(plt.Circle((0,0), radius=0.082, ls=':', edgecolor=cbf.BLUE, facecolor='none'))
plt.gca().add_artist(plt.Circle((0,0), radius=0.079, ls=':', edgecolor=cbf.PURPLE, facecolor='none'))
plt.xlim(-0.03, 0.03)
plt.ylim(0.04, 0.08)
plt.gca().set_aspect('equal')

In [None]:
# Distinct spacings between consecutive electrode Z levels, scaled by 1000
# (mm, assuming coordinates in metres).
set(np.diff(sorted(set(ELECTRODES.Z * 1000))))

In [None]:
# Lowest electrode.
ELECTRODES.Z.min()

In [None]:
# Sampling grid of the correction potentials (presumably shared by all
# electrodes - TODO confirm).  Taken explicitly from the last electrode, which
# is what the loop variable `electrode` leaked by the ELECTRODES cell refers
# to, so this cell no longer depends on a name leaked from another cell.
XX, YY, ZZ = electrodes[-1].SAMPLING_GRID

In [None]:
# Romberg integration uses ROMBERG_N = 2**ROMBERG_K + 1 nodes per axis.
ROMBERG_K = 5

In [None]:
# Grid step of the sampling grid (assumes XX is uniformly spaced).
dx = (XX[-1] - XX[0]) / (len(XX) - 1)
# Source support radius: 2**(ROMBERG_K - 1) steps, so a source diameter
# spans exactly ROMBERG_N grid nodes.
SRC_R_MAX = (2**(ROMBERG_K - 1)) * dx
ROMBERG_N = 2**ROMBERG_K + 1
print(SRC_R_MAX)

In [None]:
# Half-widths of the source region in Y and X.
H_Y = 2e-2
H_X = 2e-2
# Sub-grids used by the convolver: the source region padded by SRC_R_MAX + dx
# so source supports fit; Z runs from 2.5e-2 (minus the padding) up to 7.9e-2,
# the smallest sphere radius of the model.
X = XX[abs(XX) <= H_X + SRC_R_MAX + dx]
Y = YY[abs(YY) <= H_Y + SRC_R_MAX + dx]
Z = ZZ[(ZZ >= 2.5e-2 - SRC_R_MAX - dx)
       & (ZZ <= 7.9e-2)]

In [None]:
# Both grid arguments are the same [X, Y, Z] sub-grid (presumably the source
# and CSD grids - see frr.Convolver).
convolver = frr.Convolver([X, Y, Z],
                          [X, Y, Z])

In [None]:
# Spline radial profile of the model source, with knots at sd and 3 * sd:
# constant 1 for r < sd, then a cubic falling to 0 at r = 3 * sd = SRC_R_MAX.
# Assuming the coefficients are in increasing powers of r, the cubic
# 2.25 r/sd - 1.5 (r/sd)**2 + 0.25 (r/sd)**3 matches value and slope at both
# knots.
sd = SRC_R_MAX / 3


def source(x, y, z):
    """
    Return the kCSD spherical spline source centred at ``(x, y, z)``.

    Uses the base conductivity of the last loaded electrode (the object the
    leaked loop variable ``electrode`` refers to), presumably identical for all
    electrodes - TODO confirm.
    """
    return common.SphericalSplineSourceKCSD(x, y, z,
                                            [sd, 3 * sd],
                                            [[1],
                                             [0,
                                              2.25 / sd,
                                              -1.5 / sd ** 2,
                                              0.25 / sd ** 3]],
                                            electrodes[-1].base_conductivity)


# Template source at the origin; its `csd` / `potential` feed the kernels.
model_src = source(0, 0, 0)

In [None]:
# Source centres used as base functions: strictly more than SRC_R_MAX from
# both Z ends of the source grid, inside the |X| <= H_X, |Y| <= H_Y box, and
# with the whole support (radius SRC_R_MAX) strictly inside the 0.079 sphere.
SRC_MASK = (((convolver.SRC_Z < convolver.SRC_Z.max() - SRC_R_MAX)
             & (convolver.SRC_Z > convolver.SRC_Z.min() + SRC_R_MAX))
            & (abs(convolver.SRC_Y) <= H_Y)
            & (abs(convolver.SRC_X) <= H_X)
            & (np.square(convolver.SRC_X)
               + np.square(convolver.SRC_Y)
               + np.square(convolver.SRC_Z)
                < np.square(0.079 - SRC_R_MAX))
            )

In [None]:
# Number of selected sources vs. the full source-grid shape.
SRC_MASK.sum(), SRC_MASK.shape

In [None]:
# The CSD is estimated at every node of the CSD grid.
CSD_MASK = np.ones(convolver.shape('CSD'),
                   dtype=bool)

# Kernels

In [None]:
from scipy.integrate import romb

# Romberg weights for ROMBERG_N equally spaced nodes (integrating each row of
# the identity yields one weight per node), scaled by 2**-ROMBERG_K so that
# they sum to 1.
ROMBERG_WEIGHTS = romb(np.identity(ROMBERG_N)) * 2 ** -ROMBERG_K

# Convolver restricted to the sources selected by SRC_MASK.
convolver_interface = frr.ConvolverInterfaceIndexed(convolver,
                                                    model_src.csd,
                                                    ROMBERG_WEIGHTS,
                                                    SRC_MASK)

kernel_constructor = frr.KernelConstructor()
# Cross-kernel evaluated at the CSD_MASK nodes.
kernel_constructor.create_crosskernel = frr.CrossKernelConstructor(convolver_interface,
                                                                   CSD_MASK)

# Base-function potentials at the electrodes: purely analytical (kCSD), and
# analytical corrected numerically (kESI, presumably with the electrodes'
# sampled correction potentials).
pae_kcsd = frr.PAE_Analytical(convolver_interface,
                              potential=model_src.potential)
pae_kesi = frr.PAE_AnalyticalCorrectedNumerically(convolver_interface,
                                                  potential=model_src.potential)

Warning: no subtraction of kCSD out of the slice possible

In [None]:
%%time
# Base images at the electrodes, kCSD variant.
PHI_KCSD = kernel_constructor.create_base_images_at_electrodes(electrodes,
                                                               pae_kcsd)

In [None]:
%%time
# Base images at the electrodes, kESI variant (numerically corrected).
PHI_KESI = kernel_constructor.create_base_images_at_electrodes(electrodes,
                                                               pae_kesi)

# kernel analysis

In [None]:
# kCSD kernel and its eigendecomposition, reordered by decreasing eigenvalue
# (np.linalg.eigh returns them in ascending order).
KERNEL_KCSD = kernel_constructor.create_kernel(PHI_KCSD)

EIGENVALUES_KCSD, EIGENVECTORS_KCSD = np.linalg.eigh(KERNEL_KCSD)
EIGENVALUES_KCSD, EIGENVECTORS_KCSD = EIGENVALUES_KCSD[::-1], EIGENVECTORS_KCSD[:, ::-1]
# NOTE(review): round-off can make the smallest eigenvalues slightly negative,
# which gives NaN here - worth checking EIGENVALUES_KCSD.min().
LAMBDA_KCSD = np.sqrt(EIGENVALUES_KCSD)
# Eigensources: column i is PHI_KCSD @ v_i / LAMBDA_KCSD[i] (the matmul with
# a diagonal matrix just scales the eigenvector columns).
EIGENSOURCES_KCSD = np.matmul(PHI_KCSD,
                              np.matmul(EIGENVECTORS_KCSD,
                                        np.diag(1. / LAMBDA_KCSD)))

In [None]:
# Same decomposition for the kESI kernel (its eigensources are not needed).
KERNEL_KESI = kernel_constructor.create_kernel(PHI_KESI)

EIGENVALUES_KESI, EIGENVECTORS_KESI = np.linalg.eigh(KERNEL_KESI)
EIGENVALUES_KESI, EIGENVECTORS_KESI = EIGENVALUES_KESI[::-1], EIGENVECTORS_KESI[:, ::-1]
LAMBDA_KESI = np.sqrt(EIGENVALUES_KESI)
# EIGENSOURCES_KESI = np.matmul(FWD,
#                               np.matmul(EIGENVECTORS_KESI,
#                                         np.diag(1. / LAMBDA_KESI)))

In [None]:
# KCSD_TO_KESI_ES = np.matmul(EIGENSOURCES_KESI.T, EIGENSOURCES)

In [None]:
# assert (KCSD_TO_KESI_ES.max(axis=0) > 0).all()

In [None]:
# plt.plot(abs(np.diag(np.matmul(EIGENVECTORS_KESI.T, EIGENVECTORS))), marker='.')
# plt.plot(abs(np.diag(KCSD_TO_KESI_ES)), marker='.')
# plt.ylim(0.94, 1.00)

In [None]:
# for i, KESI_PROJECTION in enumerate(KCSD_TO_KESI_ES.T):
#     _idx = np.argmax(abs(KESI_PROJECTION))
#     if KESI_PROJECTION[_idx] > 0:
#         EIGENSOURCES[:, i] += EIGENSOURCES_KESI[:, _idx]
#     else:
#         EIGENSOURCES[:, i] -= EIGENSOURCES_KESI[:, _idx]

In [None]:
# del EIGENSOURCES_KESI

## crosskernels

In [None]:
%%time
# Cross-kernel reshaped to the CSD grid shape plus a trailing axis
# (matched against the electrode-sized solution in `reconstruct`).
CROSSKERNEL_KCSD = kernel_constructor.create_crosskernel(PHI_KCSD).reshape(convolver.shape('CSD') + (-1,))

In [None]:
%%time
# Same for kESI.
CROSSKERNEL_KESI = kernel_constructor.create_crosskernel(PHI_KESI).reshape(convolver.shape('CSD') + (-1,))

In [None]:
# Base images are no longer needed; free the memory.
del PHI_KCSD, PHI_KESI

# IMAGES

In [None]:
%%time
# Ground-truth CSDs: every kCSD eigensource rendered on the CSD grid.
GT_CSD = []

_SRC = np.zeros(convolver.shape('SRC'))
# The loop target `_SRC[SRC_MASK]` writes each eigensource's weights directly
# into the masked entries of the full source-grid array.
for i, _SRC[SRC_MASK] in enumerate(EIGENSOURCES_KCSD.T):
    print(i)
    GT_CSD.append(convolver.base_weights_to_csd(_SRC, model_src.csd, (ROMBERG_N,) * 3))
    
del _SRC

In [None]:
%%time
# Images: forward-FEM potentials at the electrodes for every ground-truth
# CSD.  The CSD is linearly interpolated from the CSD grid and taken as 0
# outside it.
IMAGE = []

for i, _CSD in enumerate(GT_CSD):
    print(i)
    _csd = si.RegularGridInterpolator(
                                  [getattr(convolver, f'CSD_{x}').flatten()
                                   for x in 'XYZ'],
                                  _CSD,
                                  bounds_error=False,
                                  fill_value=0)
    _v = fem_gt(_csd)
    # Evaluate the potential at every electrode position.
    IMAGE.append(np.array(list(map(_v, ELECTRODES.X, ELECTRODES.Y, ELECTRODES.Z))))

In [None]:
# Drop loop leftovers (the last CSD, interpolator and potential).
del _CSD, _csd, _v

In [None]:
%%time

# Save the GT CSDs and images together with their coordinates.  The filename
# (including its spelling) is presumably what downstream notebooks load -
# keep it unchanged.
FILENAME = f'git_paper_4SM__two_paralel_linear_electrodes_images_INV_{INV_GEOMETRY}_{INV_MESH}_{INV_DEGREE}_FWD_{FWD_GEOMETRY}_{FWD_MESH}_{FWD_DEGREE}_{COMPARISON}.npz'

# CSD_X/Y/Z: CSD grid; POT_X/Y/Z: electrode positions.
kwargs = {attr: getattr(convolver, attr) for attr in map('CSD_{}'.format, 'XYZ')}
kwargs.update((f'POT_{attr}', ELECTRODES[attr]) for attr in 'XYZ')

np.savez_compressed(FILENAME,
                    CSD=GT_CSD,
                    POT=IMAGE,
                    **kwargs)

In [None]:
# For every image: its coordinates in the kernel eigenbases, divided by the
# corresponding sqrt-eigenvalues (stems: kCSD, line: kESI).  The dotted
# guides mark the value 1 and the image's own eigensource index i, where the
# kCSD stem would be expected at 1 if the image matched the kCSD model
# exactly - TODO confirm.
for i, V in enumerate(IMAGE):
    plt.figure()
    plt.title(i)
    plt.axhline(1, ls=':', color=cbf.BLACK)
    plt.axvline(i, ls=':', color=cbf.BLACK)
    plt.stem(np.matmul(V, EIGENVECTORS_KCSD) / LAMBDA_KCSD)
    plt.plot(np.matmul(V, EIGENVECTORS_KESI) / LAMBDA_KESI)

## Reconstruction errors

In [None]:
def _mean_absolute(x):
    # Mean of |x| over all entries.
    return np.abs(x).mean()


def _root_mean_square(x):
    # Square root of the mean of x**2 over all entries.
    return np.sqrt(np.square(x).mean())


def _max_absolute(x):
    # Largest |x| over all entries.
    return np.abs(x).max()


# Error norms, in the order they are reported/plotted.
norms = {
    'L1': _mean_absolute,
    'L2': _root_mean_square,
    'Linf': _max_absolute,
}


def add_norms_to_dict(d, key_template, DATA):
    """Store every norm of DATA in ``d`` under ``key_template.format(<norm name>)``."""
    d.update((key_template.format(norm_name), norm(DATA))
             for norm_name, norm in norms.items())


def reconstruct(_KERNEL, _CROSSKERNEL, V, _rp=0):
    """
    Kernel CSD estimate from electrode potentials ``V``.

    Solves ``(_KERNEL + _rp * I) beta = V`` and returns
    ``_CROSSKERNEL @ beta`` (contracting the last axis of the cross-kernel).
    """
    regularized_kernel = _KERNEL + _rp * np.eye(len(_KERNEL))
    beta = np.linalg.solve(regularized_kernel, V)
    return _CROSSKERNEL @ beta

In [None]:
# Kernel spectra (decreasing order), log scale.
plt.plot(EIGENVALUES_KCSD)
plt.plot(EIGENVALUES_KESI)
plt.yscale('log')

In [None]:
# Candidate regularization parameters: 1e3 ... 1e17, five per decade.
REGULARIZATION_PARAMETERS = np.logspace(3, 17, 5 * 14 + 1)

In [None]:
# Reconstructors used only for cross-validation (`common.cv`).  Their
# "crosskernel" is diag(LAMBDA) @ EIGENVECTORS.T, i.e. they produce estimates
# in the kernel eigenbasis rather than on the CSD grid.
# NOTE: relies on private kesi APIs (`_CrossKernelReconstructor`,
# `_engine._LinearKernelSolver`).
es_reconstructors = {method: verbose.VerboseFFR._CrossKernelReconstructor(
                                                     kesi._engine._LinearKernelSolver(
                                                         _KERNEL),
                                                     np.matmul(np.diag(_LAMBDA),
                                                               _EIGENVECTORS.T))
                     for method, _KERNEL, _EIGENVECTORS, _LAMBDA
                     in [('kCSD', KERNEL_KCSD, EIGENVECTORS_KCSD, LAMBDA_KCSD),
                         ('kESI', KERNEL_KESI, EIGENVECTORS_KESI, LAMBDA_KESI),
                         ]
                     }

In [None]:
%%time
# Reconstruction errors for every image, both without regularization
# (IMAGE_ERRORS) and with the regularization parameter minimising the CV
# error (IMAGE_ERRORS_CV).  Columns: ES (eigensource index, 0-based),
# GT_<norm> and ERR_<method>_<norm>.
IMAGE_ERRORS = []
IMAGE_ERRORS_CV = []

for i, (_CSD_GT, V) in enumerate(zip(GT_CSD, IMAGE)):
    print(i)
    row = {'ES': i}
    add_norms_to_dict(row, 'GT_{}', _CSD_GT)
    row_cv = row.copy()
    # Rows are appended now and filled in place below.
    IMAGE_ERRORS.append(row)
    IMAGE_ERRORS_CV.append(row_cv)


    for method, _KERNEL, _CROSSKERNEL in [
        ('kCSD', KERNEL_KCSD, CROSSKERNEL_KCSD),
        ('kESI', KERNEL_KESI, CROSSKERNEL_KESI),
        ]:
        # CV error for every candidate regularization parameter.
        _ERRORS = common.cv(es_reconstructors[method],
                            V,
                            REGULARIZATION_PARAMETERS)
        for _row, _rp in [(row, 0),
                          (row_cv, REGULARIZATION_PARAMETERS[np.argmin(_ERRORS)])]:
            # `{{}}` leaves a `{}` placeholder for the norm name.
            add_norms_to_dict(_row,
                              f'ERR_{method}_{{}}',
                              reconstruct(_KERNEL, _CROSSKERNEL, V, _rp) - _CSD_GT)
    
del _CSD_GT, _ERRORS, _CROSSKERNEL, _KERNEL

IMAGE_ERRORS = pd.DataFrame(IMAGE_ERRORS)
IMAGE_ERRORS_CV = pd.DataFrame(IMAGE_ERRORS_CV)

## Reconstruction error plots (no regularization)

In [None]:
# Relative reconstruction error (error norm / GT norm) per eigensource,
# without regularization; dotted lines at 1, 0.5, 0.1 and 0.05.
for norm in norms:
    plt.figure()
    plt.title(norm)
    for method, color, ls in [('kCSD', cbf.SKY_BLUE, '-'),
                              ('kESI', cbf.VERMILION, ':'),
                             ]:
        plt.plot(IMAGE_ERRORS.ES,
                 IMAGE_ERRORS[f'ERR_{method}_{norm}'] / IMAGE_ERRORS[f'GT_{norm}'],
                 color=color,
                 ls=ls,
                 label=method)
    plt.legend(loc='best')
    plt.axhline(1, color=cbf.BLACK, ls=':')
    plt.axhline(0.5, color=cbf.BLACK, ls=':')
    plt.axhline(0.1, color=cbf.BLACK, ls=':')
    plt.axhline(0.05, color=cbf.BLACK, ls=':')
    plt.yscale('log')

## Reconstruction error plots (regularization)

In [None]:
# Relative reconstruction error per eigensource with the CV-selected
# regularization parameter.
for norm in norms:
    plt.figure()
    plt.title(norm)
    for method, color, ls in [('kCSD', cbf.SKY_BLUE, '-'),
                              ('kESI', cbf.VERMILION, ':'),
                             ]:
        plt.plot(IMAGE_ERRORS_CV.ES,
                 IMAGE_ERRORS_CV[f'ERR_{method}_{norm}'] / IMAGE_ERRORS_CV[f'GT_{norm}'],
                 color=color,
                 ls=ls,
                 label=method)
    plt.legend(loc='best')
    plt.axhline(1, color=cbf.BLACK, ls=':')
    plt.axhline(0.5, color=cbf.BLACK, ls=':')
    plt.axhline(0.1, color=cbf.BLACK, ls=':')
    plt.axhline(0.05, color=cbf.BLACK, ls=':')
    plt.yscale('log')

## compare approaches

In [None]:
# Both settings on one figure per norm: solid = no regularization,
# dotted = CV-selected regularization.
for norm in norms:
    plt.figure()
    plt.title(norm)
    for method, color in [('kCSD', cbf.SKY_BLUE),
                          ('kESI', cbf.VERMILION),
                         ]:
        plt.plot(IMAGE_ERRORS.ES,
                 IMAGE_ERRORS[f'ERR_{method}_{norm}'] / IMAGE_ERRORS[f'GT_{norm}'],
                 color=color,
                 ls='-',
                 label=method)
        plt.plot(IMAGE_ERRORS_CV.ES,
                 IMAGE_ERRORS_CV[f'ERR_{method}_{norm}'] / IMAGE_ERRORS_CV[f'GT_{norm}'],
                 color=color,
                 ls=':',
                 label=f'{method} (CV)')

    plt.legend(loc='best')
    plt.axhline(1, color=cbf.BLACK, ls=':')
    plt.axhline(0.5, color=cbf.BLACK, ls=':')
    plt.axhline(0.1, color=cbf.BLACK, ls=':')
    plt.axhline(0.05, color=cbf.BLACK, ls=':')
    plt.yscale('log')

In [None]:
# Grid step and source radius multiplied by 1e3 (mm, assuming metres).
dx * 1e3, SRC_R_MAX * 1e3

In [None]:
class CardinalPlaneVisualisation(object):
    """
    Plot a 3D field on the three axis-aligned planes through a point.

    Layout: XZ plane (top left), YZ plane (top right, sharing the Z axis with
    XZ) and the XY plane drawn as Y (horizontal) vs X (vertical) (bottom
    right).  The bottom-left cell only hosts the colorbar.

    Parameters
    ----------
    grid : sequence of three 1D arrays
        Node coordinates along X, Y and Z (ascending).
    plane_intersection : sequence of three floats
        Point ``(x, y, z)`` the planes pass through.  Along each axis the
        plane is the first grid node not smaller than the coordinate
        (``np.searchsorted``).
    dpi : float
        Grid nodes per inch of figure size.
    cmap
        Colormap of the images.
    amp : float, optional
        Default colour-scale amplitude (range ``[-amp, amp]``), used when a
        plotting method gets no ``amp``.  When also ``None``, the maximum
        absolute value of the plotted data is used.
    length_factor : float
        Multiplier converting grid coordinates to displayed lengths.
    length_unit : str
        Label of the displayed length unit.
    unit_factor : float
        Multiplier converting data values to displayed values.
    unit : str
        Colorbar label.
    """
    # Radii (in grid units) of origin-centred spheres drawn over
    # reconstructions and over the ground truth, respectively.
    SPHERE_RADII = [0.079, 0.082, 0.086, 0.090]
    SPHERE_RADII_GT = [0.079, 0.080, 0.085, 0.090]

    def __init__(self,
                 grid,
                 plane_intersection,
                 dpi=17,
                 cmap=cbf.bwr,
                 amp=None,
                 length_factor=1,
                 length_unit='$m$',
                 unit_factor=1,
                 unit=''):
        self.grid = grid
        self.plane_intersection = np.array(plane_intersection)
        self.indices = [np.searchsorted(g, a)
                        for a, g in zip(plane_intersection,
                                        grid)]
        self.dpi = dpi
        self.cmap = cmap
        self.amp = amp
        self.length_factor = length_factor
        self.length_unit = length_unit
        self.unit_factor = unit_factor
        self.unit = unit

    def start_new_image(self, title, wx, wy, wz):
        """Create the figure and axes for a ``wx`` x ``wy`` x ``wz`` volume."""
        self.fig = plt.figure(figsize=((wx + wy) / self.dpi,
                                       (wz + wy) / self.dpi))
        if title is not None:
            self.fig.suptitle(title)

        gs = plt.GridSpec(2, 2,
                          figure=self.fig,
                          width_ratios=[wx, wy],
                          height_ratios=[wz, wy])

        self.ax_xz = self.fig.add_subplot(gs[0, 0])
        self.ax_xz.set_aspect('equal')
        self.ax_xz.set_ylabel(f'Z [{self.length_unit}]')
        self.ax_xz.set_xlabel(f'X [{self.length_unit}]')

        self.ax_yx = self.fig.add_subplot(gs[1, 1])
        self.ax_yx.set_aspect('equal')
        self.ax_yx.set_ylabel(f'X [{self.length_unit}]')
        self.ax_yx.set_xlabel(f'Y [{self.length_unit}]')

        self.ax_yz = self.fig.add_subplot(gs[0, 1],
                                          sharey=self.ax_xz,
                                          sharex=self.ax_yx)
        self.ax_yz.set_aspect('equal')

        # Placeholder cell; the colorbar takes its space in finish_image().
        self.cax = self.fig.add_subplot(gs[1, 0])
        self.cax.set_visible(False)

    def finish_image(self):
        """Mark where the planes cross each other and add the colorbar."""
        x, y, z = self.length_factor * self.plane_intersection

        self.ax_xz.axvline(x, ls=':', color=cbf.BLACK)
        self.ax_xz.axhline(z, ls=':', color=cbf.BLACK)

        self.ax_yx.axvline(y, ls=':', color=cbf.BLACK)
        self.ax_yx.axhline(x, ls=':', color=cbf.BLACK)

        self.ax_yz.axvline(y, ls=':', color=cbf.BLACK)
        self.ax_yz.axhline(z, ls=':', color=cbf.BLACK)
        self.fig.colorbar(self.im, ax=self.cax,
                          orientation='horizontal',
                          label=self.unit)

    def _amplitude(self, amp, data_max):
        """Colour-scale amplitude: ``amp``, else ``self.amp``, else ``data_max``."""
        if amp is not None:
            return amp
        if self.amp is not None:
            return self.amp
        return data_max

    def plot_volume(self, DATA, title=None, amp=None):
        """
        Plot the three cardinal planes of a volume in a new figure.

        Parameters
        ----------
        DATA : np.ndarray
            3D array indexed ``[x, y, z]`` on ``self.grid``.
        title : str, optional
            Figure title.
        amp : float, optional
            Colour-scale amplitude; defaults to ``self.amp`` or, when that is
            ``None``, to ``abs(DATA).max()``.
        """
        self.start_new_image(title, *DATA.shape)
        ix, iy, iz = self.indices
        self._plot_planes([DATA[ix:ix+1, :, :],
                           DATA[:, iy:iy+1, :],
                           DATA[:, :, iz:iz+1],
                           ],
                          self._amplitude(amp, abs(DATA).max()))
        self.finish_image()

    def _plot_planes(self, DATA_PLANES, amp):
        # DATA_PLANES: slabs of shape (1, wy, wz), (wx, 1, wz), (wx, wy, 1).
        # Transposes put the vertical axis first, as imshow expects.
        DATA_ZY = DATA_PLANES[0][0, :, :].T * self.unit_factor
        DATA_ZX = DATA_PLANES[1][:, 0, :].T * self.unit_factor
        DATA_XY = DATA_PLANES[2][:, :, 0] * self.unit_factor

        def _extent(first, second):
            # imshow extent with grid axis `first` horizontal, `second` vertical.
            _first = self.grid[first] * self.length_factor
            _second = self.grid[second] * self.length_factor
            return (_first.min(), _first.max(),
                    _second.min(), _second.max())

        self.ax_xz.imshow(DATA_ZX,
                          vmin=-amp * self.unit_factor,
                          vmax=amp * self.unit_factor,
                          cmap=self.cmap,
                          origin='lower',
                          extent=_extent(0, 2))
        self.ax_yx.imshow(DATA_XY,
                          vmin=-amp * self.unit_factor,
                          vmax=amp * self.unit_factor,
                          cmap=self.cmap,
                          origin='lower',
                          extent=_extent(1, 0))
        # Kept for the colorbar (all three images share the colour scale).
        self.im = self.ax_yz.imshow(DATA_ZY,
                                    vmin=-amp * self.unit_factor,
                                    vmax=amp * self.unit_factor,
                                    cmap=self.cmap,
                                    origin='lower',
                                    extent=_extent(1, 2))

    def plot_planes(self,
                    DATA_PLANES,
                    title=None,
                    amp=None):
        """
        Plot precomputed plane slabs ``(DATA_YZ, DATA_XZ, DATA_XY)``.

        The slabs have shapes ``(1, wy, wz)``, ``(wx, 1, wz)`` and
        ``(wx, wy, 1)``; ``amp`` defaults to ``self.amp`` or, when that is
        ``None``, to the largest absolute value over all slabs.
        """
        DATA_YZ, DATA_XZ, DATA_XY = DATA_PLANES
        wx, wy, _ = DATA_XY.shape
        wz = DATA_YZ.shape[2]
        assert DATA_YZ.shape[1] == wy
        assert DATA_XZ.shape[0] == wx
        assert DATA_XZ.shape[2] == wz

        self.start_new_image(title, wx, wy, wz)
        self._plot_planes(DATA_PLANES,
                          self._amplitude(amp,
                                          max(abs(_A).max() for _A in DATA_PLANES)))
        self.finish_image()

    def compare_with_gt(self, GT, CSD, title=''):
        """
        Plot the ground truth, a reconstruction and their difference.

        All three figures share one colour scale.  The error figure's title
        reports ``||CSD - GT||_2 / ||GT||_2``.
        """
        ERROR = CSD - GT
        # Relative to the norm of THIS ground truth (not of any global data).
        error_L2 = np.sqrt(np.square(ERROR).sum() / np.square(GT).sum())
        amp = max(abs(CSD).max(),
                  abs(GT).max(),
                  abs(ERROR).max())
        self.plot_volume(GT,
                         title='GT CSD',
                         amp=amp)
        self._add_spheres(self.SPHERE_RADII_GT)
        self.plot_volume(CSD,
                         title=f'{title} reconstruction',
                         amp=amp)
        self._add_spheres(self.SPHERE_RADII)
        self.plot_volume(ERROR,
                         title=f'{title} error (GT normalized L2 norm: {error_L2:.2g})',
                         amp=amp)

    def _add_spheres(self, sphere_radii):
        """Draw cross-sections of origin-centred spheres on the current axes."""
        for c, ax in zip(self.plane_intersection,
                         [self.ax_yz,    # plane x = c
                          self.ax_xz,    # plane y = c
                          self.ax_yx]):  # plane z = c
            c2 = np.square(c)
            for r2 in np.square(sphere_radii):
                # A sphere that does not reach the plane has no cross-section
                # (its radius would be the sqrt of a negative number).
                if r2 >= c2:
                    self._plot_circle(ax, np.sqrt(r2 - c2))

    def _plot_circle(self, ax, r):
        """Draw a dotted origin-centred circle of radius ``r`` (grid units)."""
        ax.add_artist(plt.Circle((0, 0), r * self.length_factor,
                                 facecolor='none',
                                 edgecolor=cbf.BLACK,
                                 linestyle=':'))

    @property
    def PLANES_XYZ(self):
        """For each plane (x=c, y=c, z=c): the grid with that axis replaced by ``[c]``."""
        return [[[c] if i == j else A for j, A in enumerate(self.grid)]
                for i, c in enumerate(self.plane_intersection)]

In [None]:
# Plotter for CSDs on the CSD grid, with planes through (0, 0, 0.058).
# Lengths are multiplied by 1e3 and labelled mm; values are multiplied by 1e-3
# and labelled pA/mm^3.
csd_plotter = CardinalPlaneVisualisation([_x.flatten() for _x in convolver.CSD_GRID],
                                         [0, 0, 0.058],
                                         unit_factor=1e-3,
                                         unit='$\\frac{pA}{mm^3}$',
                                         length_factor=1e3,
                                         length_unit='$mm$')

In [None]:
# Inspect one image (0-based eigensource index _ES): CV error curves of both
# methods, with dotted lines at the selected regularization parameters, and
# the corresponding reconstructions collected in `csd`.
# _ES = 1
_ES = 2
# _ES = 7

csd = {'GT': GT_CSD[_ES]}
_V = IMAGE[_ES]

plt.figure()
plt.title('CV')
plt.xscale('log')
plt.yscale('log')

for method, _KERNEL, _CROSSKERNEL in [
    ('kCSD', KERNEL_KCSD, CROSSKERNEL_KCSD),
    ('kESI', KERNEL_KESI, CROSSKERNEL_KESI),
    ]:
    _ERRORS = common.cv(es_reconstructors[method],
                        _V,
                        REGULARIZATION_PARAMETERS)
    regularization_parameter = REGULARIZATION_PARAMETERS[np.argmin(_ERRORS)]
    _l = plt.plot(REGULARIZATION_PARAMETERS,
                  _ERRORS,
                 label=method)
    plt.axvline(regularization_parameter,
                ls=':',
                color=_l[0].get_color())

    csd[method] = reconstruct(_KERNEL, _CROSSKERNEL, _V, regularization_parameter)
plt.legend(loc='best')

In [None]:
# Plot every CSD in `csd` (GT and reconstructions from the previous cell),
# zoomed to the electrode region (axis limits in mm), with electrode
# positions marked by crosses.
# x, y, z = 1e-3 * np.array([0, 0, 58])

for method, _CSD in csd.items():
    csd_plotter.plot_volume(_CSD, method)
    csd_plotter._add_spheres(csd_plotter.SPHERE_RADII)

    csd_plotter.ax_xz.set_xlim(-20, 20)
    csd_plotter.ax_yx.set_xlim(-20, 20)
    csd_plotter.ax_yz.set_xlim(-20, 20)
    
    csd_plotter.ax_xz.set_ylim(25, 90)
    csd_plotter.ax_yx.set_ylim(-20, 20)
    csd_plotter.ax_yz.set_ylim(25, 90)
    
    csd_plotter.ax_xz.scatter(ELECTRODES.X * 1e3, ELECTRODES.Z * 1e3, marker='x', color=cbf.BLACK)
    csd_plotter.ax_yx.scatter(ELECTRODES.Y * 1e3, ELECTRODES.X * 1e3, marker='x', color=cbf.BLACK)
    csd_plotter.ax_yz.scatter(ELECTRODES.Y * 1e3, ELECTRODES.Z * 1e3, marker='x', color=cbf.BLACK)

#     fig, axes = crude_plot_data(_CSD,
#                                 x=x, y=y, z=z,
#                                 grid=[c.flatten() for c in convolver.CSD_MESH])
#     fig.suptitle(method)

#     for ax, cx, cy, c in zip(axes[0] + axes[1][1:],
#                              'XYY',
#                              'ZZX',
#                              [y, x, z]):
#         ax.set_xlim(-0.02, 0.02)
#         ax.set_ylim(0.025 if cy == 'Z' else -0.02,
#                     0.09 if cy == 'Z' else 0.02)
#         ax.scatter(ELECTRODES[cx], ELECTRODES[cy], marker='x', color=cbf.BLACK)
#         for r, ls in [(90, '-'),
#                       (86, '-.'),
#                       (82, '--'),
#                       (79, ':'),
#                       ]:
#             ax.add_artist(plt.Circle((0,0),
#                                      radius=np.sqrt(np.square(r * 1e-3) - np.square(c)),
#                           ls=ls,
#                           edgecolor=cbf.BLACK,
#                           facecolor='none'))

In [None]:
# For every image: CV-regularized kCSD and kESI reconstructions plotted next
# to the GT, zoomed to the electrode region (mm).
# NOTE(review): labels here are 1-based (`start=1`), while GT_CSD/IMAGE
# indices, IMAGE_ERRORS.ES and `_ES` in the single-image cell are 0-based.
for _ES, (_GT, _V) in enumerate(zip(GT_CSD, IMAGE), start=1):
    print(_ES)

    csd = {f'ES #{_ES} GT': _GT}
    
#     plt.figure()
#     plt.title('CV')
#     plt.xscale('log')
#     plt.yscale('log')

    for method, _KERNEL, _CROSSKERNEL in [
        ('kCSD', KERNEL_KCSD, CROSSKERNEL_KCSD),
        ('kESI', KERNEL_KESI, CROSSKERNEL_KESI),
        ]:
        print(_ES, method)
        # Regularization parameter minimising the CV error.
        _ERRORS = common.cv(es_reconstructors[method],
                            _V,
                            REGULARIZATION_PARAMETERS)
        regularization_parameter = REGULARIZATION_PARAMETERS[np.argmin(_ERRORS)]
#         _l = plt.plot(REGULARIZATION_PARAMETERS,
#                       _ERRORS,
#                      label=method)
#         plt.axvline(regularization_parameter,
#                     ls=':',
#                     color=_l[0].get_color())

        csd[f'ES #{_ES} {method}'] = reconstruct(_KERNEL, _CROSSKERNEL, _V, regularization_parameter)
#     plt.legend(loc='best')

    for method, _CSD in csd.items():
        csd_plotter.plot_volume(_CSD, method)
        csd_plotter._add_spheres(csd_plotter.SPHERE_RADII)

        csd_plotter.ax_xz.set_xlim(-20, 20)
        csd_plotter.ax_yx.set_xlim(-20, 20)
        csd_plotter.ax_yz.set_xlim(-20, 20)

        csd_plotter.ax_xz.set_ylim(25, 90)
        csd_plotter.ax_yx.set_ylim(-20, 20)
        csd_plotter.ax_yz.set_ylim(25, 90)

        csd_plotter.ax_xz.scatter(ELECTRODES.X * 1e3, ELECTRODES.Z * 1e3, marker='x', color=cbf.BLACK)
        csd_plotter.ax_yx.scatter(ELECTRODES.Y * 1e3, ELECTRODES.X * 1e3, marker='x', color=cbf.BLACK)
        csd_plotter.ax_yz.scatter(ELECTRODES.Y * 1e3, ELECTRODES.Z * 1e3, marker='x', color=cbf.BLACK)