In [None]:
import datetime
import configparser
import os
import collections
import itertools

import numpy as np
import pandas as pd

import kesi
import kesi._verbose as verbose
import _common_new as common
import _fast_reciprocal_reconstructor as frr

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

import cbf

In [None]:
# Identifiers of the model, mesh and configuration.  They are combined into
# FILENAME_PATTERN below; MODEL_NAME and CONFIG_NAME are also used for the
# path of the ground-truth FEM configuration.
MODEL_NAME = '4SM_CSF_3_mm'
MESH_NAME = 'uniform_normal'
CONFIG_NAME = 'comb'
DEGREE = 1
K = 9
# Boundary points with z below this value are grounded (zero Dirichlet
# condition) in ForwardModel.__call__.
GROUNDED_PLATE_AT = -0.088

# Path template of the per-electrode .npz files.  `{{name}}` is an escaped
# placeholder: it survives the f-string as `{name}` and is filled in later
# with `FILENAME_PATTERN.format(name=...)`.
FILENAME_PATTERN = f'FEM/solutions/paper/{MODEL_NAME}/{MESH_NAME}_{DEGREE}/{CONFIG_NAME}_sampled/{K}/{{name}}.npz'

# FEM

In [None]:
import dolfin
import FEM.fem_common as fc
import scipy.interpolate as si

In [None]:
# Mesh name and element degree selecting the FEM configuration file which is
# passed to ForwardModel (the "GT" forward model of this notebook).
FEM_MESH_NAME = 'uniform_coarse'
FEM_DEGREE = 3
FEM_CONFIG_GT = f'FEM/fem_configs/paper/{MODEL_NAME}/{FEM_MESH_NAME}_{FEM_DEGREE}/{CONFIG_NAME}.ini'

In [None]:
class ForwardModel(object):
    """
    FEM forward model: computes the potential generated by a given CSD.

    The function space, the mesh and the file paths come from
    ``fc.FunctionManagerINI(config)``.  Conductivities of the mesh
    subdomains are read from the INI file returned by
    ``getpath('fem', 'config')``.
    """
    # XXX: duplicated code with FEM classes

    def __init__(self, config):
        """
        Parameters
        ----------
        config
            Passed unchanged to ``fc.FunctionManagerINI``.
        """
        self.fm = fc.FunctionManagerINI(config)
        self.V = self.fm.function_space
        mesh = self.fm.mesh

        # Coordinates of the degrees of freedom, forced to the
        # (number of DOFs, geometric dimension) shape.
        self.dof_coords = self.V.tabulate_dof_coordinates()
        self.dof_coords.resize((self.V.dim(), mesh.geometry().dim()))

        # FEM function holding the CSD (overwritten by every call).
        self.csd_f = self.fm.function()

        # Subdomain markers are read from a file named after the mesh file:
        # its last 5 characters are replaced by '_subdomains.xdmf'.
        subdomains_path = self.fm.getpath('fem', 'mesh')[:-5] + '_subdomains.xdmf'
        with dolfin.XDMFFile(subdomains_path) as fh:
            markers = dolfin.MeshValueCollection("size_t", mesh, 3)
            fh.read(markers, "subdomains")
            self.subdomains = dolfin.cpp.mesh.MeshFunctionSizet(mesh, markers)
            self.dx = dolfin.Measure("dx")(subdomain_data=self.subdomains)

        self.config = configparser.ConfigParser()
        self.config.read(self.fm.getpath('fem', 'config'))

    @property
    def CONDUCTIVITY(self):
        """
        Yield a ``(volume, conductivity)`` pair for every config section
        which defines both options.
        """
        cfg = self.config
        for section in filter(self._is_conductive_volume, cfg.sections()):
            yield (cfg.getint(section, 'volume'),
                   cfg.getfloat(section, 'conductivity'))

    def _is_conductive_volume(self, section):
        """Tell whether `section` has both 'volume' and 'conductivity'."""
        return all(self.config.has_option(section, option)
                   for option in ('volume', 'conductivity'))

    def __call__(self, csd_interpolator):
        """
        Solve for the potential generated by a CSD.

        Parameters
        ----------
        csd_interpolator : callable
            Called with the array of DOF coordinates; its result is stored
            as the values of the CSD function.

        Returns
        -------
        The solution (a function created by ``self.fm.function()``).
        """
        self.csd_f.vector()[:] = csd_interpolator(self.dof_coords)

        def grounded(x, on_boundary):
            # The part of the boundary lying below the grounded plate level.
            return on_boundary and x[2] < GROUNDED_PLATE_AT

        grounding = dolfin.DirichletBC(self.V, dolfin.Constant(0), grounded)

        v = self.fm.test_function()
        u = self.fm.trial_function()
        potential = self.fm.function()

        # Bilinear form: conductivity-weighted grad(u) . grad(v) integrated
        # over every conductive subdomain.
        bilinear_form = sum(dolfin.Constant(sigma)
                            * dolfin.inner(dolfin.grad(u), dolfin.grad(v))
                            * self.dx(volume)
                            for volume, sigma in self.CONDUCTIVITY)
        # Linear form: the CSD integrated against the test function.
        linear_form = self.csd_f * v * self.dx

        rhs = dolfin.assemble(linear_form)
        stiffness = dolfin.assemble(bilinear_form)
        grounding.apply(stiffness, rhs)

        solver = dolfin.KrylovSolver("cg", "ilu")
        solver.parameters["maximum_iterations"] = 10000
        solver.parameters["absolute_tolerance"] = 1E-8
        # solver.parameters["monitor_convergence"] = True  # Goes to Jupyter server output stream
        solver.solve(stiffness, potential.vector(), rhs)

        return potential

In [None]:
# Build the forward model from the ground-truth FEM configuration and report
# how long the construction takes.
%time fem_gt = ForwardModel(FEM_CONFIG_GT)

# kernel construction

In [None]:
# Electrode grids in use.  Grids 'A', 'C' and 'E' are currently disabled.
electrode_grid_names = ['B', 'D']

# Names are ordered by electrode index first and by grid second:
# B_00, D_00, B_01, D_01, ..., B_11, D_11.
electrode_names = [f'{grid}_{index:02d}'
                   for index, grid in itertools.product(range(12),
                                                        electrode_grid_names)]

In [None]:
class Electrode_kESI(object):
    """
    Electrode described by an ``.npz`` file with a sampled correction potential.

    The file has to provide the arrays 'X', 'Y', 'Z' (grid nodes),
    'LOCATION' (three coordinates of the electrode), 'BASE_CONDUCTIVITY'
    and 'CORRECTION_POTENTIAL' (indexed by X, Y and Z node).
    """

    def __init__(self, filename, decimals_tolerance=None, dx=0):
        """
        Parameters
        ----------
        filename
            Path of the ``.npz`` file.
        decimals_tolerance : int or None
            Number of decimals the grid nodes are rounded to before they are
            matched (no rounding if None).
        dx : float
            Enters the denominator of `base_potential()` as ``0.15 * dx``.
        """
        self.filename = filename
        self.decimals_tolerance = decimals_tolerance
        self.dx = dx
        with np.load(filename) as fh:
            self._X, self._Y, self._Z = [self.round(fh[axis])
                                         for axis in 'XYZ']
            self.x, self.y, self.z = fh['LOCATION']
            self.base_conductivity = fh['BASE_CONDUCTIVITY']

    def round(self, A):
        """Round `A` to the configured number of decimals (if any)."""
        decimals = self.decimals_tolerance
        return A if decimals is None else np.round(A, decimals=decimals)

    def correction_potential(self, X, Y, Z):
        """
        Return the stored correction potential at the nodes of a 3D mesh.

        `X`, `Y` and `Z` are 3D coordinate arrays; the nodes along an axis
        are taken from the first row/column/pile of the appropriate array
        and every one of them has to match a node stored in the file.
        """
        indices = []
        for axis, (stored, MESH) in enumerate(zip([self._X, self._Y, self._Z],
                                                  [X, Y, Z])):
            # Select e.g. MESH[:, 0, 0] for axis 0.
            selector = [0, 0, 0]
            selector[axis] = slice(None)
            requested = self.round(MESH[tuple(selector)])

            matched, idx, _ = np.intersect1d(stored,
                                             requested,
                                             return_indices=True)
            assert len(matched) == np.shape(MESH)[axis]
            indices.append(idx)

        with np.load(self.filename) as fh:
            return fh['CORRECTION_POTENTIAL'][np.ix_(*indices)]

    def base_potential(self, X, Y, Z):
        """
        Return ``0.25 / (pi * conductivity) / (0.15 * dx + distance)``
        where the distance is measured from (X, Y, Z) to the electrode.
        """
        distance = np.sqrt(np.square(X - self.x)
                           + np.square(Y - self.y)
                           + np.square(Z - self.z))
        return (0.25 / (np.pi * self.base_conductivity)
                / (self.dx * 0.15 + distance))

        
class Electrode_kCSD(object):
    """Electrode defined only by the 'LOCATION' stored in an ``.npz`` file."""

    def __init__(self, filename):
        with np.load(filename) as fh:
            location = fh['LOCATION']
        # The three stored coordinates of the electrode.
        self.x, self.y, self.z = location

In [None]:
# Collect the name and location of every electrode, one record per electrode.
# NOTE: the loop variable `electrode` stays defined after the loop and is
# read by later cells (`electrode._X/_Y/_Z`, `electrode.base_conductivity`),
# which therefore use the last electrode loaded here.
ELECTRODES = []
for name in electrode_names:
    electrode = Electrode_kESI(FILENAME_PATTERN.format(name=name))
    ELECTRODES.append({'NAME': name,
                       'X': electrode.x,
                       'Y': electrode.y,
                       'Z': electrode.z})
# Turn the records into a table with the columns NAME, X, Y and Z.
ELECTRODES = pd.DataFrame(ELECTRODES)

In [None]:
# Electrode positions in the XZ plane, with dotted circles of the listed
# radii centred at the origin drawn for reference.
# NOTE(review): the radii used here (0.090 / 0.085 / 0.080 / 0.079) differ
# from the 90 / 86 / 82 / 79 (times 1e-3) used for the circles drawn over the
# CSD slices later in this notebook -- confirm which set matches MODEL_NAME.
plt.scatter(ELECTRODES.X, ELECTRODES.Z, marker='.')
plt.gca().add_artist(plt.Circle((0,0), radius=0.090, ls=':', edgecolor=cbf.BLACK, facecolor='none'))
plt.gca().add_artist(plt.Circle((0,0), radius=0.085, ls=':', edgecolor=cbf.VERMILION, facecolor='none'))
plt.gca().add_artist(plt.Circle((0,0), radius=0.080, ls=':', edgecolor=cbf.BLUE, facecolor='none'))
plt.gca().add_artist(plt.Circle((0,0), radius=0.079, ls=':', edgecolor=cbf.PURPLE, facecolor='none'))
# Zoom in on the region of interest.
plt.xlim(-0.03, 0.03)
plt.ylim(0.04, 0.08)
plt.gca().set_aspect('equal')

In [None]:
# Distinct differences between consecutive unique electrode Z coordinates
# (Z is multiplied by 1000 first; presumably a m -> mm conversion).
set(np.diff(sorted(set(ELECTRODES.Z * 1000))))

In [None]:
# Lowest Z coordinate among the electrodes.
ELECTRODES.Z.min()

In [None]:
# Grid nodes along every axis, taken from `electrode`, i.e. the last
# electrode loaded by the loop which built ELECTRODES.
# NOTE(review): presumably all electrode files share the same grid -- confirm.
XX = electrode._X
YY = electrode._Y
ZZ = electrode._Z

In [None]:
# Determines both the source radius (2**(ROMBERG_K - 1) grid steps) and the
# number of Romberg quadrature nodes (2**ROMBERG_K + 1), see SRC_R_MAX and
# ROMBERG_N.
ROMBERG_K = 5

In [None]:
# Mean spacing of the X nodes (equal to the grid step if XX is uniform --
# TODO confirm).
dx = (XX[-1] - XX[0]) / (len(XX) - 1)
# Source radius of 2**(ROMBERG_K - 1) grid steps: the diameter then spans
# 2**ROMBERG_K steps, i.e. ROMBERG_N = 2**ROMBERG_K + 1 nodes.
SRC_R_MAX = (2**(ROMBERG_K - 1)) * dx
ROMBERG_N = 2**ROMBERG_K + 1
print(SRC_R_MAX)

In [None]:
# Electrode objects for both methods, in the order of `electrode_names`.
# kESI electrodes round their grid nodes to 16 decimals and use the grid step
# `dx` in the denominator of their base potential.
electrodes_kesi = [Electrode_kESI(FILENAME_PATTERN.format(name=name),
                                  decimals_tolerance=16,
                                  dx=dx)
                   for name in electrode_names]
# kCSD electrodes only need the electrode location.
electrodes_kcsd = [Electrode_kCSD(FILENAME_PATTERN.format(name=name)) for name in electrode_names]

In [None]:
# Half-widths of the region (along Y and X) to which sources are restricted
# (see SRC_IDX).
H_Y = 2e-2
H_X = 2e-2
# Keep the nodes lying within the source region extended by one source radius
# and one extra grid step on each side.
X = XX[abs(XX) <= H_X + SRC_R_MAX + dx]
Y = YY[abs(YY) <= H_Y + SRC_R_MAX + dx]
# Along Z the same margin is applied below 2.5e-2 while the upper limit is
# fixed at 7.9e-2.
Z = ZZ[(ZZ >= 2.5e-2 - SRC_R_MAX - dx)
       & (ZZ <= 7.9e-2)]

In [None]:
# Both grids handed to the convolver are built from the same X, Y and Z
# nodes (passed as two separate lists).
# NOTE(review): presumably the first grid is for the sources and the second
# for the CSD -- confirm against `_fast_reciprocal_reconstructor`.
convolver = frr.ckESI_convolver([X, Y, Z], [X, Y, Z])

In [None]:
# Scale of the model source: 3 * sd equals SRC_R_MAX.
sd = SRC_R_MAX / 3


def source(x, y, z):
    """
    Create the model source centred at (x, y, z).

    The source is a `common.SphericalSplineSourceKCSD` built with the nodes
    ``[sd, 3 * sd]``, the coefficient lists given below and the base
    conductivity of the module-level `electrode`.

    NOTE(review): presumably the nodes are radii delimiting the pieces of a
    radial spline and each coefficient list holds polynomial coefficients of
    one piece -- confirm against `_common_new`.
    """
    nodes = [sd, 3 * sd]
    coefficients = [[1],
                    [0,
                     2.25 / sd,
                     -1.5 / sd ** 2,
                     0.25 / sd ** 3]]
    # An alternative tried before: nodes [SRC_R_MAX] with coefficients [[1]].
    return common.SphericalSplineSourceKCSD(x, y, z,
                                            nodes,
                                            coefficients,
                                            electrode.base_conductivity)


# Source at the origin; its `csd` and `potential` are used as model profiles.
model_src = source(0, 0, 0)

In [None]:
# Mask of the source grid nodes at which sources are placed.  A node is kept
# if it is:
#  - farther than SRC_R_MAX from both Z ends of the source grid,
#  - within the |X| <= H_X, |Y| <= H_Y region,
#  - closer to the origin than 0.079 - SRC_R_MAX.
_z_margin_ok = ((convolver.SRC_Z < convolver.SRC_Z.max() - SRC_R_MAX)
                & (convolver.SRC_Z > convolver.SRC_Z.min() + SRC_R_MAX))
_xy_region_ok = ((abs(convolver.SRC_Y) <= H_Y)
                 & (abs(convolver.SRC_X) <= H_X))
_squared_distance = (np.square(convolver.SRC_X)
                     + np.square(convolver.SRC_Y)
                     + np.square(convolver.SRC_Z))
_inside_sphere = _squared_distance < np.square(0.079 - SRC_R_MAX)

SRC_IDX = _z_margin_ok & _xy_region_ok & _inside_sphere

del _z_margin_ok, _xy_region_ok, _squared_distance, _inside_sphere

In [None]:
# Number of selected sources and the shape of the source grid mask.
SRC_IDX.sum(), SRC_IDX.shape

In [None]:
# All-False mask of the CSD grid shape: no CSD node is selected for the
# crosskernel constructor (full crosskernels are computed explicitly later
# with `convolver.base_weights_to_csd`).
CSD_IDX = np.zeros(convolver.shape('CSD'),
                   dtype=bool)

# Kernels

Warning: subtraction of kCSD outside of the slice is not possible.

In [None]:
from scipy.integrate import romb

# Romberg quadrature weights: integrating every row of the identity matrix
# (unit node spacing) yields the weight of the corresponding node; the
# weights are then rescaled by 2**-ROMBERG_K.
ROMBERG_WEIGHTS = romb(np.identity(ROMBERG_N)) * 2 ** -ROMBERG_K

# Interface to the convolver bound to the model CSD profile, the quadrature
# weights and the mask of the selected sources.
convolver_interface = frr.ConvolverInterfaceIndexed(convolver,
                                                    model_src.csd,
                                                    ROMBERG_WEIGHTS,
                                                    SRC_IDX)

kernel_constructor = frr.ckESI_kernel_constructor()

# CSD_IDX is all False, so no CSD node is selected for this crosskernel
# constructor.
kernel_constructor.create_crosskernel = frr.ckESI_crosskernel_constructor(convolver_interface,
                                                                          CSD_IDX)
# "Potential at electrode" objects of both methods; both are given the
# analytical potential of the model source.
pae_kcsd = frr.PAE_kCSD_Analytical(convolver_interface,
                                   potential=model_src.potential)
pae_kesi = frr.PAE_kESI_Analytical(convolver_interface,
                                   potential=model_src.potential)

In [None]:
%%time
# Base images at the kESI electrodes; used below as a matrix with one column
# per electrode (see `create_kernel(FWD)` and the iteration over `FWD.T`).
FWD = kernel_constructor.create_base_images_at_electrodes(electrodes_kesi,
                                                          pae_kesi)

In [None]:
%%time
# Base images at the kCSD electrodes; used below as a matrix with one column
# per electrode (see `create_kernel(PHI)` and the iteration over `PHI.T`).
PHI = kernel_constructor.create_base_images_at_electrodes(electrodes_kcsd,
                                                          pae_kcsd)

# kernel analysis

In [None]:
KERNEL = kernel_constructor.create_kernel(PHI)

# np.linalg.eigh() returns the eigenvalues in ascending order; reversing the
# last axis of both results puts the largest eigenvalue (and its
# eigenvector) first.
EIGENVALUES, EIGENVECTORS = (_A[..., ::-1] for _A in np.linalg.eigh(KERNEL))
LAMBDA = np.sqrt(EIGENVALUES)
# Eigensources: PHI times the eigenvectors, each scaled by 1 / LAMBDA.
EIGENSOURCES = PHI @ (EIGENVECTORS @ np.diag(1. / LAMBDA))

In [None]:
KERNEL_KESI = kernel_constructor.create_kernel(FWD)

# Eigendecomposition of the kESI kernel, reordered so that the eigenvalues
# descend (np.linalg.eigh() returns them ascending).
EIGENVALUES_KESI, EIGENVECTORS_KESI = (_A[..., ::-1]
                                       for _A in np.linalg.eigh(KERNEL_KESI))
LAMBDA_KESI = np.sqrt(EIGENVALUES_KESI)
# kESI eigensources: FWD times the eigenvectors, each scaled by
# 1 / LAMBDA_KESI.
EIGENSOURCES_KESI = FWD @ (EIGENVECTORS_KESI @ np.diag(1. / LAMBDA_KESI))

In [None]:
# Projections between the two families of eigensources: the entry [j, i] is
# the dot product of the kESI eigensource j and the kCSD eigensource i.
KCSD_TO_KESI_ES = np.matmul(EIGENSOURCES_KESI.T, EIGENSOURCES)

In [None]:
# Every kCSD eigensource (column) must have a positive projection on at
# least one kESI eigensource.
assert (KCSD_TO_KESI_ES.max(axis=0) > 0).all()

In [None]:
# Absolute dot products of the eigenvectors (first line) and of the
# eigensources (second line) with the same index in both methods.
plt.plot(abs(np.diag(np.matmul(EIGENVECTORS_KESI.T, EIGENVECTORS))), marker='.')
plt.plot(abs(np.diag(KCSD_TO_KESI_ES)), marker='.')
plt.ylim(0.94, 1.00)

In [None]:
# For every kCSD eigensource (column i) find the kESI eigensource with the
# largest absolute projection and add it to the kCSD one, with the sign
# flipped when the projection is negative.
# NOTE: EIGENSOURCES is modified in place, so the cell is not idempotent:
# running it again would add the kESI eigensources once more.
for i, KESI_PROJECTION in enumerate(KCSD_TO_KESI_ES.T):
    _idx = np.argmax(abs(KESI_PROJECTION))
    if KESI_PROJECTION[_idx] > 0:
        EIGENSOURCES[:, i] += EIGENSOURCES_KESI[:, _idx]
    else:
        EIGENSOURCES[:, i] -= EIGENSOURCES_KESI[:, _idx]

In [None]:
# The kESI eigensources are not used any more.
del EIGENSOURCES_KESI

## crosskernels

In [None]:
%%time
# Buffer of source weights on the full source grid; only the entries selected
# by SRC_IDX are ever overwritten.
_SRC = np.zeros(convolver.shape('SRC'))
# One CSD volume per column of PHI, stacked along the last axis.
CROSSKERNEL = np.empty(convolver.shape('CSD') + (len(KERNEL),))

# The loop target `_SRC[SRC_IDX]` makes every iteration write one column of
# PHI into the masked entries of _SRC before the body is executed.
for i, _SRC[SRC_IDX] in enumerate(PHI.T):
    print(i)
    CROSSKERNEL[:, :, :, i] = convolver.base_weights_to_csd(_SRC, model_src.csd, (ROMBERG_N,) * 3)
    
del _SRC

In [None]:
%%time
# Buffer of source weights on the full source grid; only the entries selected
# by SRC_IDX are ever overwritten.
_SRC = np.zeros(convolver.shape('SRC'))
# One CSD volume per column of FWD, stacked along the last axis.
CROSSKERNEL_KESI = np.empty(convolver.shape('CSD') + (len(KERNEL_KESI),))

# The loop target `_SRC[SRC_IDX]` makes every iteration write one column of
# FWD into the masked entries of _SRC before the body is executed.
for i, _SRC[SRC_IDX] in enumerate(FWD.T):
    print(i)
    CROSSKERNEL_KESI[:, :, :, i] = convolver.base_weights_to_csd(_SRC, model_src.csd, (ROMBERG_N,) * 3)
    
del _SRC

In [None]:
# Kernels and crosskernels have been built; the base images are not needed
# any more.
del PHI, FWD

# IMAGES

In [None]:
%%time
# Ground-truth CSD volumes: one per (combined) eigensource.
GT_CSD = []

_SRC = np.zeros(convolver.shape('SRC'))
# The loop target `_SRC[SRC_IDX]` writes one column of EIGENSOURCES into the
# masked entries of _SRC on every iteration.
for i, _SRC[SRC_IDX] in enumerate(EIGENSOURCES.T):
    print(i)
    GT_CSD.append(convolver.base_weights_to_csd(_SRC, model_src.csd, (ROMBERG_N,) * 3))
    
del _SRC

In [None]:
%%time
# Potentials at the electrodes generated by every ground-truth CSD, computed
# with the FEM forward model.
IMAGE = []

for i, _CSD in enumerate(GT_CSD):
    print(i)
    # Interpolator of the CSD on the CSD grid; it evaluates to 0 outside of
    # the grid instead of raising an error.
    _csd = si.RegularGridInterpolator(
                                  [getattr(convolver, f'CSD_{x}').flatten()
                                   for x in 'XYZ'],
                                  _CSD,
                                  bounds_error=False,
                                  fill_value=0)
    _v = fem_gt(_csd)
    # Evaluate the FEM solution at the location of every electrode.
    IMAGE.append(np.array(list(map(_v, ELECTRODES.X, ELECTRODES.Y, ELECTRODES.Z))))


In [None]:
# Drop the temporaries left by the loop which computed IMAGE.
del _CSD, _csd, _v

In [None]:
# One figure per image: the potentials projected on the eigenvectors of both
# kernels and divided by the appropriate LAMBDA (stems: kCSD, crosses: kESI).
# The dotted lines mark the value 1 and the index i of the eigensource.
for i, V in enumerate(IMAGE):
    plt.figure()
    plt.title(i)
    plt.axhline(1, ls=':', color=cbf.BLACK)
    plt.axvline(i, ls=':', color=cbf.BLACK)
    plt.stem(np.matmul(V, EIGENVECTORS) / LAMBDA)
    plt.plot(np.matmul(V, EIGENVECTORS_KESI) / LAMBDA_KESI,
             marker='x')

## Images error (reconstructed)

In [None]:
%%time
# Errors of unregularized reconstructions of every ground-truth CSD.
IMAGE_ERRORS = []
for i, (_CSD_GT, V) in enumerate(zip(GT_CSD, IMAGE)):
    print(i)
    # Norms of the ground truth: mean absolute value (L1), root mean square
    # (L2) and maximal absolute value (Linf).
    row = {'ES': i,
           'GT_L1': abs(_CSD_GT).mean(),
           'GT_L2': np.sqrt(np.square(_CSD_GT).mean()),
           'GT_Linf': abs(_CSD_GT).max(),
           }
    # The row is appended first and filled in below (same dict object).
    IMAGE_ERRORS.append(row)
    for method, _KERNEL, _CROSSKERNEL in [
        ('kCSD', KERNEL, CROSSKERNEL),
        ('kESI', KERNEL_KESI, CROSSKERNEL_KESI),
        ]:
        # Reconstruction (crosskernel times the solution of KERNEL x = V)
        # minus the ground truth.
        _DIFF = np.matmul(_CROSSKERNEL,
                          np.linalg.solve(_KERNEL, V)) - _CSD_GT
        row[f'ERR_{method}_L1'] = abs(_DIFF).mean()
        row[f'ERR_{method}_L2'] = np.sqrt(np.square(_DIFF).mean())
        row[f'ERR_{method}_Linf'] = abs(_DIFF).max()
    print(row)
    
del _CSD_GT, _DIFF, _CROSSKERNEL, _KERNEL

In [None]:
# Turn the list of rows into a table (one row per eigensource).
IMAGE_ERRORS = pd.DataFrame(IMAGE_ERRORS)

In [None]:
# Display the table of errors of the unregularized reconstructions.
IMAGE_ERRORS

In [None]:
# Relative error (error norm divided by the norm of the ground truth) of the
# unregularized reconstructions; one figure per norm.
_STYLES = {'kCSD': (cbf.SKY_BLUE, '-'),
           'kESI': (cbf.VERMILION, ':'),
           }

for norm in ('L1', 'L2', 'Linf'):
    plt.figure()
    plt.title(norm)
    for method, (color, ls) in _STYLES.items():
        _relative_error = (IMAGE_ERRORS[f'ERR_{method}_{norm}']
                           / IMAGE_ERRORS[f'GT_{norm}'])
        plt.plot(IMAGE_ERRORS.ES,
                 _relative_error,
                 color=color,
                 ls=ls,
                 label=method)
    plt.legend(loc='best')
    # Reference levels of the relative error.
    for _level in (1, 0.5, 0.1, 0.05):
        plt.axhline(_level, color=cbf.BLACK, ls=':')
    plt.yscale('log')

## Images error (reconstructed regularized)

In [None]:
# Eigenvalue spectra of the kCSD and kESI kernels (descending order) on a
# logarithmic scale.
plt.plot(EIGENVALUES)
plt.plot(EIGENVALUES_KESI)
plt.yscale('log')

In [None]:
# Candidate regularization parameters: 14 decades (1e3 to 1e17, both ends
# included) sampled at 5 values per decade.
REGULARIZATION_PARAMETERS = np.logspace(start=3,
                                        stop=17,
                                        num=5 * 14 + 1)

In [None]:
# One reconstructor per method, passed to `common.cv()` below.  Each combines
# a linear solver of the method's kernel with the matrix
# diag(LAMBDA) @ EIGENVECTORS.T used in place of a crosskernel.
# NOTE(review): presumably this maps potentials to the eigensource space for
# cross-validation -- confirm against `kesi._verbose`.
es_reconstructors = {method: verbose.VerboseFFR._CrossKernelReconstructor(
                                                     kesi._engine._LinearKernelSolver(
                                                         _KERNEL),
                                                     np.matmul(np.diag(_LAMBDA),
                                                               _EIGENVECTORS.T))
                     for method, _KERNEL, _EIGENVECTORS, _LAMBDA
                     in [('kCSD', KERNEL, EIGENVECTORS, LAMBDA),
                         ('kESI', KERNEL_KESI, EIGENVECTORS_KESI, LAMBDA_KESI),
                         ]
                     }

In [None]:
%%time
# Errors of regularized reconstructions; the regularization parameter is
# selected for every image and method as the one minimizing the error
# returned by `common.cv()`.
IMAGE_ERRORS_CV = []
for i, (_CSD_GT, V) in enumerate(zip(GT_CSD, IMAGE)):
    print(i)
    # Norms of the ground truth: mean absolute value (L1), root mean square
    # (L2) and maximal absolute value (Linf).
    row = {'ES': i,
           'GT_L1': abs(_CSD_GT).mean(),
           'GT_L2': np.sqrt(np.square(_CSD_GT).mean()),
           'GT_Linf': abs(_CSD_GT).max(),
           }
    # The row is appended first and filled in below (same dict object).
    IMAGE_ERRORS_CV.append(row)

    for method, _KERNEL, _CROSSKERNEL in [
        ('kCSD', KERNEL, CROSSKERNEL),
        ('kESI', KERNEL_KESI, CROSSKERNEL_KESI),
        ]:
        _ERRORS = common.cv(es_reconstructors[method],
                            V,
                            REGULARIZATION_PARAMETERS)
        regularization_parameter = REGULARIZATION_PARAMETERS[np.argmin(_ERRORS)]

        # Reconstruction with the kernel regularized by adding the selected
        # parameter to its diagonal, minus the ground truth.
        _DIFF = np.matmul(_CROSSKERNEL,
                          np.linalg.solve(_KERNEL
                                          + regularization_parameter * np.identity(len(_KERNEL)),
                                          V)) - _CSD_GT
        row[f'ERR_{method}_L1'] = abs(_DIFF).mean()
        row[f'ERR_{method}_L2'] = np.sqrt(np.square(_DIFF).mean())
        row[f'ERR_{method}_Linf'] = abs(_DIFF).max()
    
del _CSD_GT, _DIFF, _CROSSKERNEL, _KERNEL

In [None]:
# Turn the list of rows into a table (one row per eigensource).
IMAGE_ERRORS_CV = pd.DataFrame(IMAGE_ERRORS_CV)

In [None]:
# Display the table of errors of the regularized reconstructions.
IMAGE_ERRORS_CV

In [None]:
# Relative error (error norm divided by the norm of the ground truth) of the
# regularized reconstructions; one figure per norm.
_SERIES = (('kCSD', cbf.SKY_BLUE, '-'),
           ('kESI', cbf.VERMILION, ':'),
           )
_REFERENCE_LEVELS = (1, 0.5, 0.1, 0.05)

for norm in ('L1', 'L2', 'Linf'):
    plt.figure()
    plt.title(norm)
    _gt_norm = IMAGE_ERRORS_CV[f'GT_{norm}']
    for method, color, ls in _SERIES:
        plt.plot(IMAGE_ERRORS_CV.ES,
                 IMAGE_ERRORS_CV[f'ERR_{method}_{norm}'] / _gt_norm,
                 color=color,
                 ls=ls,
                 label=method)
    plt.legend(loc='best')
    for _level in _REFERENCE_LEVELS:
        plt.axhline(_level, color=cbf.BLACK, ls=':')
    plt.yscale('log')

## compare approaches

In [None]:
# Sanity check: both tables must hold exactly the same ground-truth norms,
# as the comparison below assumes they describe the same ground truth.
for norm in ['L1', 'L2', 'Linf']:
    assert abs(IMAGE_ERRORS[f'GT_{norm}'] - IMAGE_ERRORS_CV[f'GT_{norm}']).max() == 0

In [None]:
# Unregularized (solid) against cross-validated (dotted) relative errors of
# both methods; one figure per norm.
_COLORS = {'kCSD': cbf.SKY_BLUE,
           'kESI': cbf.VERMILION,
           }

for norm in ('L1', 'L2', 'Linf'):
    plt.figure()
    plt.title(norm)
    for method, color in _COLORS.items():
        for _errors, _ls, _label in [(IMAGE_ERRORS, '-', method),
                                     (IMAGE_ERRORS_CV, ':', f'{method} (CV)'),
                                     ]:
            plt.plot(_errors.ES,
                     _errors[f'ERR_{method}_{norm}'] / _errors[f'GT_{norm}'],
                     color=color,
                     ls=_ls,
                     label=_label)

    plt.legend(loc='best')
    # Reference levels of the relative error.
    for _level in (1, 0.5, 0.1, 0.05):
        plt.axhline(_level, color=cbf.BLACK, ls=':')
    plt.yscale('log')

In [None]:
def crude_plot_data(DATA,
                    x=None,
                    y=None,
                    z=None,
                    grid=None,
                    dpi=30,
                    cmap=cbf.bwr,
                    title=None,
                    amp=None):
    """
    Plot three orthogonal slices (XZ, YZ and YX) of a 3D array.

    Parameters
    ----------
    DATA : 3D array
        Values to plot, indexed as ``DATA[ix, iy, iz]``.
    x, y, z : optional
        Position of the slices.  Without `grid` they are indices of the
        array (default: the middle index of the axis).  With `grid` they are
        coordinates (default: mean of the grid nodes of the axis).
    grid : sequence of three 1D arrays, optional
        Nodes of the X, Y and Z axes.  They are used to locate the slices
        with ``np.searchsorted()`` (so they are expected to be sorted
        ascendingly) and as the extents of the images.
    dpi : number
        Number of array nodes per inch of the figure.
    cmap
        Colormap passed to ``imshow()``.
    title : optional
        Figure title (none if omitted).
    amp : number, optional
        The color scale spans [-amp, amp].  Defaults to ``abs(DATA).max()``.

    Returns
    -------
    ``(fig, ((ax_xz, ax_yz), (cax, ax_yx)))`` where `cax` is the invisible
    axes from which the colorbar takes its space.
    """
    wx, wy, wz = DATA.shape

    if grid is None:
        ix, iy, iz = [w // 2 if a is None else a
                      for a, w in zip([x, y, z],
                                      [wx, wy, wz])]
        x, y, z = ix, iy, iz
        # No extents: the images are drawn in index coordinates.
        extent_xz = extent_yx = extent_yz = None

    else:
        x, y, z = [g.mean() if a is None else a
                   for a, g in zip([x, y, z],
                                   grid)]
        # np.searchsorted() returns len(g) for a coordinate beyond the last
        # node; the index is clipped to the last valid one so that slicing
        # DATA cannot raise an IndexError.
        ix, iy, iz = [min(int(np.searchsorted(g, a)), w - 1)
                      for a, g, w in zip([x, y, z],
                                         grid,
                                         [wx, wy, wz])]
        x_range, y_range, z_range = [(grid[i].min(), grid[i].max())
                                     for i in range(3)]
        extent_xz = x_range + z_range
        extent_yx = y_range + x_range
        extent_yz = y_range + z_range

    # The bottom row holds the YX slice with X along its vertical axis, thus
    # its height has to be proportional to `wx`.
    fig = plt.figure(figsize=((wx + wy) / dpi,
                              (wz + wx) / dpi))
    if title is not None:
        fig.suptitle(title)
    gs = plt.GridSpec(2, 2,
                      figure=fig,
                      width_ratios=[wx, wy],
                      height_ratios=[wz, wx])

    ax_xz = fig.add_subplot(gs[0, 0])
    ax_xz.set_aspect('equal')
    ax_xz.set_ylabel('Z')
    ax_xz.set_xlabel('X')

    ax_yx = fig.add_subplot(gs[1, 1])
    ax_yx.set_aspect('equal')
    ax_yx.set_ylabel('X')
    ax_yx.set_xlabel('Y')

    # The YZ slice shares its Z axis with the XZ slice and its Y axis with
    # the YX slice.
    ax_yz = fig.add_subplot(gs[0, 1],
                            sharey=ax_xz,
                            sharex=ax_yx)
    ax_yz.set_aspect('equal')

    # Invisible placeholder in the remaining corner; the colorbar is
    # attached to it.
    cax = fig.add_subplot(gs[1, 0])
    cax.set_visible(False)

    if amp is None:
        amp = abs(DATA).max()

    image_kwargs = dict(vmin=-amp,
                        vmax=amp,
                        cmap=cmap,
                        origin='lower')
    # imshow() puts the first array axis vertically, hence the transpositions
    # of the slices which have Z as their vertical axis.
    ax_xz.imshow(DATA[:, iy, :].T,
                 extent=extent_xz,
                 **image_kwargs)
    ax_yx.imshow(DATA[:, :, iz],
                 extent=extent_yx,
                 **image_kwargs)
    im = ax_yz.imshow(DATA[ix, :, :].T,
                      extent=extent_yz,
                      **image_kwargs)

    # Dotted lines mark the requested position of the other two slices.
    ax_xz.axvline(x, ls=':', color=cbf.BLACK)
    ax_xz.axhline(z, ls=':', color=cbf.BLACK)

    ax_yx.axvline(y, ls=':', color=cbf.BLACK)
    ax_yx.axhline(x, ls=':', color=cbf.BLACK)

    ax_yz.axvline(y, ls=':', color=cbf.BLACK)
    ax_yz.axhline(z, ls=':', color=cbf.BLACK)
    fig.colorbar(im, ax=cax)

    return (fig, ((ax_xz, ax_yz),
                  (cax, ax_yx)))

In [None]:
# Index of the eigensource inspected in detail (alternatives are left
# commented out).
# _ES = 1
_ES = 2
# _ES = 7

# CSD volumes to compare: the ground truth and (filled in below) the
# regularized reconstruction of every method.
csd = {'GT': GT_CSD[_ES]}
_V = IMAGE[_ES]

plt.figure()
plt.title('CV')
plt.xscale('log')
plt.yscale('log')

for method, _KERNEL, _CROSSKERNEL in [
    ('kCSD', KERNEL, CROSSKERNEL),
    ('kESI', KERNEL_KESI, CROSSKERNEL_KESI),
    ]:
    # Cross-validation error curve; its minimum selects the regularization
    # parameter, which is marked with a dotted vertical line of the same
    # color as the curve.
    _ERRORS = common.cv(es_reconstructors[method],
                        _V,
                        REGULARIZATION_PARAMETERS)
    regularization_parameter = REGULARIZATION_PARAMETERS[np.argmin(_ERRORS)]
    _l = plt.plot(REGULARIZATION_PARAMETERS,
                  _ERRORS,
                 label=method)
    plt.axvline(regularization_parameter,
                ls=':',
                color=_l[0].get_color())

    # Reconstruction with the kernel regularized by adding the selected
    # parameter to its diagonal.
    csd[method] = np.matmul(_CROSSKERNEL,
                            np.linalg.solve(_KERNEL
                                            + regularization_parameter * np.identity(len(_KERNEL)),
                                            _V))
plt.legend(loc='best')

In [None]:
# Coordinates of the slices: 1e-3 * (0, 0, 58).
x, y, z = 1e-3 * np.array([0, 0, 58])

for method, _CSD in csd.items():
    fig, axes = crude_plot_data(_CSD,
                                x=x, y=y, z=z,
                                grid=[c.flatten() for c in convolver.CSD_MESH])
    fig.suptitle(method)

    # Iterate over the XZ, YZ and YX axes.  `cx` and `cy` name the ELECTRODES
    # columns shown along the horizontal and vertical axis; `c` is the
    # coordinate of the slice plane along the remaining (third) axis.
    for ax, cx, cy, c in zip(axes[0] + axes[1][1:],
                             'XYY',
                             'ZZX',
                             [y, x, z]):
        ax.set_xlim(-0.02, 0.02)
        ax.set_ylim(0.025 if cy == 'Z' else -0.02,
                    0.09 if cy == 'Z' else 0.02)
        ax.scatter(ELECTRODES[cx], ELECTRODES[cy], marker='x', color=cbf.BLACK)
        # Cross-sections of origin-centred spheres of radius r * 1e-3 with
        # the slice plane: circles of radius sqrt(r**2 - c**2).
        for r, ls in [(90, '-'),
                      (86, '-.'),
                      (82, '--'),
                      (79, ':'),
                      ]:
            ax.add_artist(plt.Circle((0,0),
                                     radius=np.sqrt(np.square(r * 1e-3) - np.square(c)),
                          ls=ls,
                          edgecolor=cbf.BLACK,
                          facecolor='none'))

In [None]:
%%time
# Number of noise realizations applied to every image.
N_NOISE = 100
NOISY_IMAGE_ERRORS_CV = []
NOISY_IMAGE_ERRORS = []

# Norms used both for the reconstruction errors and (see the end of the main
# loop) to summarize the errors over the noise realizations.
norms = {'L1': lambda x: np.abs(x).mean(),
         'L2': lambda x: np.sqrt(np.square(x).mean()),
         'Linf': lambda x: np.abs(x).max(),
         }

# Fixed seed: the same noise realizations are drawn on every run.
np.random.seed(42)
# One row of normal noise (standard deviation 0.05) per realization, one
# column per electrode.  It is rescaled by the standard deviation of every
# image below.
NOISE = np.random.normal(scale=0.05,
                         size=(N_NOISE, len(electrodes_kcsd)))

for i, (_CSD_GT, V) in enumerate(zip(GT_CSD, IMAGE)):
    print(i)
    # Missing keys default to empty lists, which collect one error per noise
    # realization.
    row_cv = collections.defaultdict(list,
         {'ES': i,
          'GT_L1': abs(_CSD_GT).mean(),
          'GT_L2': np.sqrt(np.square(_CSD_GT).mean()),
          'GT_Linf': abs(_CSD_GT).max(),
          })

    # `row` (unregularized) starts as a copy of `row_cv` (cross-validated);
    # both are appended first and filled in below.
    row = row_cv.copy()
    NOISY_IMAGE_ERRORS_CV.append(row_cv)
    NOISY_IMAGE_ERRORS.append(row)

#     for _NOISE in NOISE * np.sqrt(np.square(V).sum()):
    for _NOISE in NOISE * V.std():
        NOISY_V = V + _NOISE


        for method, _KERNEL, _CROSSKERNEL in [
            ('kCSD', KERNEL, CROSSKERNEL),
            ('kESI', KERNEL_KESI, CROSSKERNEL_KESI),
            ]:
            # Error of the unregularized reconstruction.
            _DIFF = np.matmul(_CROSSKERNEL,
                              np.linalg.solve(_KERNEL,
                                              NOISY_V)) - _CSD_GT
            for norm, _f in norms.items():
                row[f'ERR_{method}_{norm}'].append(_f(_DIFF))

            # Error of the reconstruction regularized with the parameter
            # minimizing the error returned by `common.cv()`.
            _ERRORS = common.cv(es_reconstructors[method],
                                NOISY_V,
                                REGULARIZATION_PARAMETERS)
            regularization_parameter = REGULARIZATION_PARAMETERS[np.argmin(_ERRORS)]

            _DIFF = np.matmul(_CROSSKERNEL,
                              np.linalg.solve(_KERNEL
                                              + regularization_parameter * np.identity(len(_KERNEL)),
                                              NOISY_V)) - _CSD_GT
            for norm, _f in norms.items():
                row_cv[f'ERR_{method}_{norm}'].append(_f(_DIFF))
                  
    # Summarize the errors collected over the noise realizations.
    for _row, method, norm in itertools.product([row, row_cv],
                                                ['kCSD', 'kESI'],
                                                norms):
        _ERRORS = _row[f'ERR_{method}_{norm}']
        # Quantiles are given in per mille; they are stored under keys with a
        # zero-padded three-digit suffix, e.g. `..._025` and `..._975`.
        for q in [25, 50,
                  100, 250, 500, 750, 900,
                  950, 975,
                  ]:
            _row[f'ERR_{method}_{norm}_{q:03d}'] = np.quantile(_ERRORS, q * 1e-3)

        # Replace the list of errors with a single value: the same norm
        # applied to the list.
        _row[f'ERR_{method}_{norm}'] = norms[norm](_ERRORS)

NOISY_IMAGE_ERRORS_CV = pd.DataFrame(NOISY_IMAGE_ERRORS_CV)
NOISY_IMAGE_ERRORS = pd.DataFrame(NOISY_IMAGE_ERRORS)

del _CSD_GT, _DIFF, _CROSSKERNEL, _KERNEL

In [None]:
# Summary errors of the noisy reconstructions relative to the norm of the
# ground truth: unregularized (solid) against cross-validated (dotted).
_METHOD_COLORS = (('kCSD', cbf.SKY_BLUE),
                  ('kESI', cbf.VERMILION),
                  )

for norm in norms:
    plt.figure()
    plt.title(norm)
    for method, color in _METHOD_COLORS:
        _variants = ((NOISY_IMAGE_ERRORS, '-', method),
                     (NOISY_IMAGE_ERRORS_CV, ':', f'{method} (CV)'),
                     )
        for _errors, _ls, _label in _variants:
            _relative_error = (_errors[f'ERR_{method}_{norm}']
                               / _errors[f'GT_{norm}'])
            plt.plot(_errors.ES,
                     _relative_error,
                     color=color,
                     ls=_ls,
                     label=_label)

    plt.legend(loc='best')
    # Reference levels of the relative error.
    for _level in (1, 0.5, 0.1, 0.05):
        plt.axhline(_level, color=cbf.BLACK, ls=':')
    plt.yscale('log')

In [None]:
# Quantiles (in per mille) delimiting the plotted band of errors; they have
# to be among the quantiles stored in the NOISY_IMAGE_ERRORS tables.
BOTTOM = 250
TOP = 750
# BOTTOM = 25
# TOP = 975

for norm in norms:
    plt.figure()
    plt.title(f'Noisy ({norm}; {(TOP - BOTTOM) / 10:g}%CI)')
    for method, color in [('kCSD', cbf.SKY_BLUE),
                          ('kESI', cbf.VERMILION),
                         ]:
        for q in [TOP, BOTTOM]:
            # Both quantile curves of a series share color and line style, so
            # only the first one is labelled; otherwise every entry would be
            # listed twice in the legend.  Labels starting with an underscore
            # are skipped by the legend.
            _labelled = q == TOP
            plt.plot(NOISY_IMAGE_ERRORS.ES,
                     NOISY_IMAGE_ERRORS[f'ERR_{method}_{norm}_{q:03d}'] / NOISY_IMAGE_ERRORS[f'GT_{norm}'],
                     color=color,
                     ls='-',
                     label=method if _labelled else '_nolegend_')
            plt.plot(NOISY_IMAGE_ERRORS_CV.ES,
                     NOISY_IMAGE_ERRORS_CV[f'ERR_{method}_{norm}_{q:03d}'] / NOISY_IMAGE_ERRORS_CV[f'GT_{norm}'],
                     color=color,
                     ls=':',
                     label=f'{method} (CV)' if _labelled else '_nolegend_')

    plt.legend(loc='best')
    # Reference levels of the relative error.
    plt.axhline(1, color=cbf.BLACK, ls=':')
    plt.axhline(0.5, color=cbf.BLACK, ls=':')
    plt.axhline(0.1, color=cbf.BLACK, ls=':')
    plt.axhline(0.05, color=cbf.BLACK, ls=':')
    plt.yscale('log')