Please note that this tutorial focuses on reconstructing CSD timeseries
at a subset of the _CSD_ grid.  For the sake of simplicity it uses kCSD
(cross)kernels only.
To learn how to create kESI (cross)kernels, please consult
_tutorial\_*\_basics\_explained.ipynb_.  To compare the reconstructed CSD
with the kCSD reconstruction at all nodes of the _CSD_ grid, please run
one of the _tutorial\_slice\[\_basics\_explained\].ipynb_ notebooks.

# Requirements

## Memory

The code of the notebook requires at least 1.8 GB (1.7 GiB) of free memory.


## Environment

1. Anaconda Python distribution (tested with _Miniconda3-py39\_4.12.0-Linux-x86\_64.sh_, _conda v. 4.12.0_).
2. Jupyter server (see _extras/jupyter\_server.sh_ for details).
3. Anaconda environments (run _setup\_conda\_envs.sh_).

# FRR: Fast Reciprocal Reconstructor kernel construction tools

## Electrode object

The implementation of the electrode object is the minimum necessary for constructing a kCSD (cross)kernel.

In [None]:
import collections

Electrode = collections.namedtuple('Electrode',
                                   ['x', 'y', 'z', 'conductivity'])

We use the same electrode positions as the _tutorial\_slice\[\_basics\_explained\].ipynb_ notebooks.

In [None]:
CONDUCTIVITY = 0.3  # S/m

ELECTRODES_XYZ = [(0.0, 0.0, 5e-05),
                  (5e-05, 0.0, 0.00015),
                  (5e-05, -5e-05, 0.00025)]

electrodes = [Electrode(x, y, z, CONDUCTIVITY) for x, y, z in ELECTRODES_XYZ]

## Model source

We want to use CSD bases $36\mu{}m$ wide ($R = 18\mu{}m$).

In [None]:
from common import SphericalSplineSourceKCSD

SRC_R = 18e-6  # m

spline_nodes = [SRC_R / 3, SRC_R]
spline_polynomials = [[1],
                      [0,
                       6.75 / SRC_R,
                       -13.5 / SRC_R ** 2,
                       6.75 / SRC_R ** 3]]
model_src = SphericalSplineSourceKCSD(0, 0, 0,
                                      spline_nodes,
                                      spline_polynomials)
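
The coefficients above can be sanity-checked numerically.  The sketch below assumes that each row of `spline_polynomials` lists coefficients in ascending powers of the distance $r$ from the source centre, and that the rows apply to $[0, R/3]$ and $[R/3, R]$ respectively (an assumption about `SphericalSplineSourceKCSD`, not something verified here).  Under that reading the unnormalised profile is flat near the centre and decays smoothly to zero at $r = R$.  The helper `profile()` is hypothetical.

```python
import numpy as np
from numpy.polynomial import polynomial as P

SRC_R = 18e-6  # m
nodes = [SRC_R / 3, SRC_R]
coefficients = [[1],
                [0, 6.75 / SRC_R, -13.5 / SRC_R ** 2, 6.75 / SRC_R ** 3]]

def profile(r):
    """Unnormalised radial profile (hypothetical helper, ascending powers of r)."""
    r = np.asarray(r, dtype=float)
    inner = P.polyval(r, coefficients[0])
    outer = P.polyval(r, coefficients[1])
    return np.where(r <= nodes[0], inner, np.where(r <= nodes[1], outer, 0.0))

# Derivative of the outer cubic, used to check smoothness at both nodes.
outer_slope = P.polyder(coefficients[1])

value_at_inner_node = P.polyval(nodes[0], coefficients[1])        # close to 1
value_at_outer_node = P.polyval(nodes[1], coefficients[1])        # close to 0
slope_at_nodes = P.polyval(np.array(nodes), outer_slope) * SRC_R  # close to 0
```

The cubic thus joins the constant piece at $r = R/3$ and the zero outside $r = R$ with matching values and slopes.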

## Convolver object

In [None]:
import numpy as np
from kesi.kernel.constructor import Convolver

ROMBERG_K = 5
Z_MIN = 0
Z_MAX = 3e-4
XY_AMP = 1.5e-4

_h_min = SRC_R * 2**(1 - ROMBERG_K)
_X = _Y = np.linspace(-XY_AMP, XY_AMP, int(np.floor(2 * XY_AMP / _h_min)) + 1)
_Z = np.linspace(Z_MIN, Z_MAX, int(np.floor((Z_MAX - Z_MIN) / _h_min)) + 1)

_csd_grid = _pot_grid = [_X, _Y, _Z]

convolver = Convolver(_pot_grid, _csd_grid)

for _h in convolver.steps('POT'):
    assert _h >= _h_min, f'{_h} < {_h_min}'

## Convolver interface

In [None]:
from kesi.kernel.constructor import ConvolverInterfaceIndexed
from scipy.integrate import romb

ROMBERG_N = 2 ** ROMBERG_K + 1
ROMBERG_WEIGHTS = romb(np.identity(ROMBERG_N), dx=2 ** -ROMBERG_K)

SRC_MASK = ((convolver.SRC_Z > Z_MIN + SRC_R)
            & (convolver.SRC_Z < Z_MAX - SRC_R)
            & (abs(convolver.SRC_X) < XY_AMP - SRC_R)
            & (abs(convolver.SRC_Y) < XY_AMP - SRC_R))

In [None]:
print(SRC_MASK.sum())
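
The trick behind `ROMBERG_WEIGHTS` is that Romberg integration is linear in the sampled values, so integrating the rows of an identity matrix yields the quadrature weight of each sample.  A minimal check of that claim on a polynomial follows (the integrand and the names `x`, `f` are ours):

```python
import numpy as np
from scipy.integrate import romb

ROMBERG_K = 5
ROMBERG_N = 2 ** ROMBERG_K + 1

# Row i of the identity matrix is a "function" equal to 1 at sample i
# and 0 elsewhere; its Romberg integral is the weight of sample i.
weights = romb(np.identity(ROMBERG_N), dx=2 ** -ROMBERG_K)

x = np.linspace(0, 1, ROMBERG_N)  # samples spaced by dx = 2 ** -ROMBERG_K
f = x ** 3                        # its integral over [0, 1] is 1/4

integral_from_weights = weights @ f
integral_from_romb = romb(f, dx=2 ** -ROMBERG_K)
```

With `dx=2 ** -ROMBERG_K` the weights sum to one, that is, they integrate over an interval of unit length.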

In [None]:
convolver_interface = ConvolverInterfaceIndexed(convolver,
                                                model_src.csd,
                                                ROMBERG_WEIGHTS,
                                                SRC_MASK)

## Potential At Electrode object

### Potential At Electrode: analytical solution of the forward problem (kCSD)

In [None]:
from kesi.kernel import potential_basis_functions as pbf

In [None]:
pbf_kcsd = pbf.Analytical(convolver_interface,
                          potential=model_src.potential)

## Kernel constructor and cross-kernel constructor

In [None]:
from kesi.kernel.constructor import KernelConstructor, CrossKernelConstructor

kernel_constructor = KernelConstructor()

To calculate the cross-kernel matrix we need to select nodes of the _CSD_ grid.
We are going to visualise current source density in the cardinal planes,
thus in the boolean mask we select the nodes closest to the planes.
First we define the planes unambiguously by their intersection point.

In [None]:
cardinal_plane_x = 5e-5
cardinal_plane_y = 0
cardinal_plane_z = 1.5e-4

cardinal_plane_intersection = [cardinal_plane_x,
                               cardinal_plane_y,
                               cardinal_plane_z]

We find the indices of the node of the _CSD_ grid closest to the intersection point.

In [None]:
cardinal_plane_indices = [np.argmin(abs(_C - _c))
                          for _C, _c in zip(convolver.CSD_GRID,
                                            cardinal_plane_intersection)]
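
As the requested planes rarely coincide with grid nodes, the selected planes are slightly shifted with respect to them.  The sketch below repeats the lookup for the X axis on a stand-in grid; it assumes 267 equally spaced nodes between $-150\mu{}m$ and $150\mu{}m$, which is what the `np.linspace()` call in the _Convolver object_ section yields for the parameters used there.

```python
import numpy as np

X = np.linspace(-1.5e-4, 1.5e-4, 267)  # stand-in for the X axis of the CSD grid
requested_x = 5e-5

idx = np.argmin(abs(X - requested_x))   # index of the nearest node
offset = X[idx] - requested_x           # shift of the selected plane
step = X[1] - X[0]

print(f'node #{idx} at {X[idx] * 1e6:.2f} um, offset {offset * 1e6:.2f} um')
```

The shift never exceeds half of the grid spacing.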

With the indices we select the nodes of the _CSD_ grid closest to the cardinal planes.
<!-- We define an auxiliary function `one_hot(i, n)` which returns an `n`-long vector in which all elements but the `i`-th are `0` (and the `i`-th element is `1`). -->

In [None]:
CSD_MASK = np.zeros(convolver.csd_shape,
                    dtype=bool)
CSD_MASK[cardinal_plane_indices[0], :, :] = True
CSD_MASK[:, cardinal_plane_indices[1], :] = True
CSD_MASK[:, :, cardinal_plane_indices[2]] = True

# def one_hot(i, n):
#     return np.arange(n) == i

# _masks = [np.reshape(one_hot(_idx, n),
#                      np.where(one_hot(i, 3), -1, 1))
#           for i, (_idx, n) in enumerate(zip(cardinal_plane_indices,
#                                             convolver.csd_shape))]

# CSD_MASK = _masks[0] | _masks[1] | _masks[2]

We count the selected nodes.

In [None]:
n_csd_nodes = CSD_MASK.sum()
print(f'{n_csd_nodes} nodes of the CSD grid selected.')
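
The count can be cross-checked by inclusion-exclusion: three mutually orthogonal planes share three lines and a single point.  A small sketch on a toy grid follows (the shape and the indices are arbitrary, not those of the tutorial):

```python
import numpy as np

shape = (5, 6, 7)  # toy grid
i, j, k = 2, 3, 4  # arbitrary plane indices

mask = np.zeros(shape, dtype=bool)
mask[i, :, :] = True
mask[:, j, :] = True
mask[:, :, k] = True

nx, ny, nz = shape
expected = (ny * nz + nx * nz + nx * ny  # three planes
            - (nx + ny + nz)             # minus three lines counted twice
            + 1)                         # plus the common point
print(mask.sum(), expected)
```

If the _CSD_ grid has 267 nodes per axis (as the `np.linspace()` calls in the _Convolver object_ section suggest), the same formula gives 213067 nodes, roughly 1% of the full grid.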

We also keep the coordinates of the selected nodes for further use.

In [None]:
CSD_X, CSD_Y, CSD_Z = [A[CSD_MASK] for A in convolver_interface.meshgrid('CSD')]

We use the `CSD_MASK` to create a cross-kernel constructor.

In [None]:
kernel_constructor.create_crosskernel = CrossKernelConstructor(convolver_interface,
                                                               CSD_MASK)

To retrieve the three CSD planes from the CSD vector we define an auxiliary function `to_planes()`.
The function uses three index arrays to select (and arrange) the appropriate elements of the vector.

In [None]:
# As we reconstruct CSD at n_csd_nodes points,
# n_csd_nodes is an invalid index value for the
# reconstructed CSD vector, so we use it to mark
# the nodes which have not been selected.

_CSD_IDX = np.full_like(CSD_MASK, n_csd_nodes,
                        dtype=np.int32)
_CSD_IDX[CSD_MASK] = np.arange(n_csd_nodes)

CARDINAL_PLANE_INDICES = [_CSD_IDX[cardinal_plane_indices[0], :, :].copy(),
                          _CSD_IDX[:, cardinal_plane_indices[1], :].copy(),
                          _CSD_IDX[:, :, cardinal_plane_indices[2]].copy()
                          ]
del _CSD_IDX

# We test whether all indices are valid.

for _A in CARDINAL_PLANE_INDICES:
    assert _A.min() >= 0 and _A.max() < CSD_MASK.sum()
    
def to_planes(CSD):
    return [CSD[IDX] for IDX in CARDINAL_PLANE_INDICES]
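
The index-array trick may be easier to follow on a toy example.  In the sketch below (toy shape and indices, all names ours) a full 3D array plays the role of a CSD known at every node, its masked elements play the role of the reconstructed CSD vector, and we recover the three planes from the vector alone.

```python
import numpy as np

shape = (3, 4, 5)
plane_indices = (1, 2, 3)

mask = np.zeros(shape, dtype=bool)
mask[plane_indices[0], :, :] = True
mask[:, plane_indices[1], :] = True
mask[:, :, plane_indices[2]] = True
n = mask.sum()

full = np.arange(np.prod(shape), dtype=float).reshape(shape)
vector = full[mask]  # values at the selected nodes only

# Position of every selected node within `vector`; n marks unselected nodes.
idx = np.full(shape, n, dtype=np.int32)
idx[mask] = np.arange(n)

plane_idx = [idx[plane_indices[0], :, :],
             idx[:, plane_indices[1], :],
             idx[:, :, plane_indices[2]]]
planes = [vector[p] for p in plane_idx]
```

Each element of `planes` equals the corresponding slice of `full`, although only `vector` and the index arrays were used to build it.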

# Reconstructor

## Construction of kernels

In [None]:
%%time
PHI = kernel_constructor.create_base_images_at_electrodes(electrodes,
                                                          pbf_kcsd)

In [None]:
KERNEL = kernel_constructor.create_kernel(PHI)

In [None]:
%%time
CROSSKERNEL = kernel_constructor.create_crosskernel(PHI)

In [None]:
del PHI  # the array is large and no longer needed

## Reconstructor object

In [None]:
from kesi._verbose import _CrossKernelReconstructor as Reconstructor
from kesi._engine import _LinearKernelSolver as KernelSolver

reconstructor = Reconstructor(KernelSolver(KERNEL),
                              CROSSKERNEL)
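
Conceptually, kernel-based reconstruction solves a small linear system for the measured potentials and maps its solution to the selected CSD nodes with the cross-kernel.  The sketch below illustrates the idea with random stand-in matrices.  It is an assumption about what the reconstructor computes (in particular about how the regularization parameter enters), not a copy of the `kesi` internals, and the function `reconstruct()` is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n_bases, n_electrodes, n_nodes, n_timepoints = 50, 3, 7, 2

phi = rng.normal(size=(n_bases, n_electrodes))   # potential bases at electrodes
csd_bases = rng.normal(size=(n_nodes, n_bases))  # CSD bases at selected nodes

kernel = phi.T @ phi           # (n_electrodes, n_electrodes)
crosskernel = csd_bases @ phi  # (n_nodes, n_electrodes)

def reconstruct(potential, regularization_parameter=0.0):
    """Hypothetical reconstructor: regularized kernel solve, then cross-kernel."""
    regularized = kernel + regularization_parameter * np.identity(len(kernel))
    beta = np.linalg.solve(regularized, potential)
    return crosskernel @ beta

potential = rng.normal(size=(n_electrodes, n_timepoints))
csd = reconstruct(potential)   # (n_nodes, n_timepoints)
```

In this sketch all timepoints are handled by a single solve, as the right-hand side may be a matrix with one column per timepoint.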

# Visualisation

In [None]:
import matplotlib.pyplot as plt
import cbf

class CardinalPlaneVisualisation(object):
    def __init__(self,
                 grid,
                 plane_intersection,
                 dpi=35,
                 cmap=cbf.bwr,
                 amp=None,
                 length_factor=1,
                 length_unit='$m$',
                 unit_factor=1,
                 unit=''):
        self.grid = grid
        self.plane_intersection = np.array(plane_intersection)
        self.dpi = dpi
        self.cmap = cmap
        self.amp = amp
        self.length_factor = length_factor
        self.length_unit = length_unit
        self.unit_factor = unit_factor
        self.unit = unit
    
    def start_new_image(self, title, wx, wy, wz):
        self.fig = plt.figure(figsize=((wx + wy) / self.dpi,
                                       (wz + wy) / self.dpi))
        if title is not None:
            self.fig.suptitle(title)

        gs = plt.GridSpec(2, 2,
                          figure=self.fig,
                          width_ratios=[wx, wy],
                          height_ratios=[wz, wy])

        self.ax_xz = self.fig.add_subplot(gs[0, 0])
        self.ax_xz.set_aspect('equal')
        self.ax_xz.set_ylabel(f'Z [{self.length_unit}]')
        self.ax_xz.set_xlabel(f'X [{self.length_unit}]')

        self.ax_yx = self.fig.add_subplot(gs[1, 1])
        self.ax_yx.set_aspect('equal')
        self.ax_yx.set_ylabel(f'X [{self.length_unit}]')
        self.ax_yx.set_xlabel(f'Y [{self.length_unit}]')

        self.ax_yz = self.fig.add_subplot(gs[0, 1],
                                          sharey=self.ax_xz,
                                          sharex=self.ax_yx)
        self.ax_yz.set_aspect('equal')

        self.cax = self.fig.add_subplot(gs[1, 0])
        self.cax.set_visible(False)

    def finish_image(self):
        x, y, z = self.length_factor * self.plane_intersection

        self.ax_xz.axvline(x, ls=':', color=cbf.BLACK)
        self.ax_xz.axhline(z, ls=':', color=cbf.BLACK)

        self.ax_yx.axvline(y, ls=':', color=cbf.BLACK)
        self.ax_yx.axhline(x, ls=':', color=cbf.BLACK)

        self.ax_yz.axvline(y, ls=':', color=cbf.BLACK)
        self.ax_yz.axhline(z, ls=':', color=cbf.BLACK)
        self.fig.colorbar(self.im, ax=self.cax,
                          orientation='horizontal',
                          label=self.unit)

    def _plot_planes(self, DATA_PLANES, amp):
        DATA_ZY = DATA_PLANES[0].T * self.unit_factor
        DATA_ZX = DATA_PLANES[1].T * self.unit_factor
        DATA_XY = DATA_PLANES[2] * self.unit_factor
        
        def _extent(first, second):
            _first = self.grid[first] * self.length_factor
            _second = self.grid[second] * self.length_factor
            return (_first.min(), _first.max(),
                    _second.min(), _second.max())

        self.ax_xz.imshow(DATA_ZX,
                          vmin=-amp * self.unit_factor,
                          vmax=amp * self.unit_factor,
                          cmap=self.cmap,
                          origin='lower',
                          extent=_extent(0, 2))
        self.ax_yx.imshow(DATA_XY,
                          vmin=-amp * self.unit_factor,
                          vmax=amp * self.unit_factor,
                          cmap=self.cmap,
                          origin='lower',
                          extent=_extent(1, 0))
        self.im = self.ax_yz.imshow(DATA_ZY,
                                    vmin=-amp * self.unit_factor,
                                    vmax=amp * self.unit_factor,
                                    cmap=self.cmap,
                                    origin='lower',
                                    extent=_extent(1, 2))

    def plot_planes(self,
                    DATA_PLANES,
                    title=None,
                    amp=None):

        DATA_YZ, DATA_XZ, DATA_XY = DATA_PLANES
        wx, wy = DATA_XY.shape
        wz = DATA_YZ.shape[1]
        assert DATA_YZ.shape[0] == wy
        assert DATA_XZ.shape[0] == wx
        assert DATA_XZ.shape[1] == wz
        
        self.start_new_image(title, wx, wy, wz)
        self._plot_planes(DATA_PLANES,
                          amp if amp is not None else max(abs(_A).max() for _A in DATA_PLANES))
        self.finish_image()

In [None]:
csd_plotter = CardinalPlaneVisualisation([_x.flatten() for _x in convolver.CSD_GRID],
                                         cardinal_plane_intersection,
                                         unit_factor=1e-12,
                                         unit='$\\frac{\\mu{}A}{mm^3}$',
                                         length_factor=1e6,
                                         length_unit='$\\mu{}m$')

# Reconstruction

Potential values (given in $\mu{}V$) are stored in a $N \times T$ matrix
(here $N = T = 3$).  Each of its $N$ rows corresponds to an electrode,
while each of its $T$ columns corresponds to a timepoint.
Values for the first timepoint were calculated from ground truth CSD
in the _tutorial\_slice.ipynb_ notebook (`GT_V` there).

In [None]:
POTENTIAL = np.array([[-126548.99282495,  -28969.86519294,    7709.33415335],
                      [-119140.53771587,   -3543.25964668,  -11594.77447688],
                      [ -73225.23867845,   15868.45761855,    5937.64717958]])

The reconstructor returns a $\tilde{\underline{N}} \times T$ matrix of CSD values.
Each of its $T$ columns corresponds to a timepoint,
while each of its $\tilde{\underline{N}}$ rows corresponds to a selected node
of the _CSD_ grid.

In [None]:
%%time
CSD = reconstructor(POTENTIAL)

Its first column should be the same as the kCSD reconstruction
in the _tutorial\_slice\[\_basics\_explained\].ipynb_ notebooks.

In [None]:
csd_plotter.plot_planes(to_planes(CSD[:, 0]),
                        title='kCSD reconstruction from slice tutorial notebook')

We may also plot the remaining timepoints.

In [None]:
for i in [1, 2]:
    csd_plotter.plot_planes(to_planes(CSD[:, i]),
                            title=f'kCSD reconstruction at timepoint #{i}')

## Regularization

In [None]:
EIGENVALUES = np.linalg.eigvalsh(KERNEL)[::-1]

plt.plot(EIGENVALUES,
         marker='o')

plt.yscale('log')

In [None]:
REGULARIZATION_PARAMETERS = np.logspace(10, 20, 10 * 10 + 1)

### Leave-one-out cross-validation

We can choose one regularization parameter for all timepoints.

In [None]:
from common import cv

In [None]:
%%time
CV_ERRORS = cv(reconstructor, POTENTIAL, REGULARIZATION_PARAMETERS)
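
For readers unfamiliar with the procedure, the sketch below shows one way leave-one-out cross-validation can be done for a regularized kernel solver: each electrode is left out in turn, the system is solved for the remaining ones, and the potential at the omitted electrode is predicted from its kernel row.  The function `loo_cv_errors()` and the toy data are ours; the actual `cv()` from `common` may differ, for example in how the errors are normalised.

```python
import numpy as np

def loo_cv_errors(kernel, potential, regularization_parameters):
    """Leave-one-out errors of a regularized kernel solver (hypothetical sketch)."""
    n = len(kernel)
    potential = np.asarray(potential, dtype=float).reshape(n, -1)
    errors = []
    for alpha in regularization_parameters:
        squared_error = 0.0
        for i in range(n):
            keep = np.arange(n) != i  # drop electrode i
            beta = np.linalg.solve(kernel[np.ix_(keep, keep)]
                                   + alpha * np.identity(n - 1),
                                   potential[keep])
            predicted = kernel[i, keep] @ beta  # prediction at electrode i
            squared_error += np.sum((predicted - potential[i]) ** 2)
        errors.append(np.sqrt(squared_error))
    return np.array(errors)

# Toy data: a random symmetric positive definite kernel and random potentials.
rng = np.random.default_rng(0)
A = rng.normal(size=(3, 10))
toy_kernel = A @ A.T
toy_potential = rng.normal(size=(3, 3))
toy_parameters = np.logspace(-3, 3, 7)

toy_errors = loo_cv_errors(toy_kernel, toy_potential, toy_parameters)
best = toy_parameters[np.argmin(toy_errors)]
```

With only three electrodes every fit rests on two measurements, so the error curve carries little information and the selected parameter should be treated with caution.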

In [None]:
regularization_parameter = REGULARIZATION_PARAMETERS[np.argmin(CV_ERRORS)]

In [None]:
plt.plot(REGULARIZATION_PARAMETERS,
         CV_ERRORS,
         color=cbf.BLUE)
plt.axvline(regularization_parameter,
            ls=(0, (1, 2)),
            color=cbf.BLUE)
plt.xscale('log')
plt.xlabel('regularization parameter')
plt.yscale('log')
plt.ylabel('L2 norm of cross-validation error')

As the smallest value of `REGULARIZATION_PARAMETERS` has been chosen (which is 4 orders of magnitude smaller than any of the kernel eigenvalues), it seems that cross-validation argues against regularization.
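
Why such a small parameter amounts to no regularization can be seen from the eigendecomposition.  Assuming the solver adds the parameter $\lambda$ to the diagonal of the kernel (Tikhonov-style, which is an assumption about the solver), the component of the solution along an eigenvector with eigenvalue $\mu_i$ is scaled by $\mu_i / (\mu_i + \lambda)$ relative to the unregularized one.  The eigenvalues below are made-up placeholders chosen only to satisfy the four-orders-of-magnitude relation; they are not values from the notebook.

```python
import numpy as np

eigenvalues = np.array([1e16, 1e15, 1e14])  # placeholder values, NOT from KERNEL
regularization_parameter = 1e10             # smallest of REGULARIZATION_PARAMETERS

filter_factors = eigenvalues / (eigenvalues + regularization_parameter)
print(filter_factors)  # all within 1e-4 of 1
```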

We can also try to calculate a separate regularization parameter for each timepoint.

In [None]:
for i, (color, ls) in enumerate([(cbf.BLUE, '-'),
                                 (cbf.VERMILION, '--'),
                                 (cbf.GREEN, ':'),
                                ]):
    _CV_ERRORS = cv(reconstructor, POTENTIAL[:, i], REGULARIZATION_PARAMETERS)
    _regularization_parameter = REGULARIZATION_PARAMETERS[np.argmin(_CV_ERRORS)]
    plt.plot(REGULARIZATION_PARAMETERS,
             _CV_ERRORS,
             color=color,
             ls=ls,
             label=f'timepoint #{i}')
    plt.axvline(_regularization_parameter,
                ls=(i, (1, 2)),
                color=color)

plt.xscale('log')
plt.xlabel('regularization parameter')
plt.yscale('log')
plt.ylabel('L2 norm of cross-validation error')
plt.legend(loc='lower right')

Once again, extreme values of the regularization parameter are selected.