Please note that this tutorial focuses on reconstructing CSD timeseries
at a subset of the _CSD_ grid.  For the sake of simplicity, it uses kCSD
(cross)kernels only.
To learn how to create kESI (cross)kernels, please consult
`tutorial_*_basics_explained.ipynb`.  To compare the reconstructed CSD
with the kCSD reconstruction at all nodes of the _CSD_ grid, please run
one of the `tutorial_slice[_basics_explained].ipynb` notebooks.

# Requirements

## Memory

The code of the notebook requires at least 1.8 GB (1.7 GiB) of free memory.


## Environment

1. Anaconda Python distribution (tested with `Miniconda3-py39_4.12.0-Linux-x86_64.sh`, _conda v. 4.12.0_).
2. Jupyter server (see `extras/jupyter_server.sh` for details).
3. Anaconda environments (run `setup_conda_envs.sh`).

# FRR: Fast Reciprocal Reconstructor kernel construction tools

## Electrode object

The implementation of the electrode object is the minimum necessary to construct a kCSD (cross)kernel.

In [None]:
import collections

Electrode = collections.namedtuple('Electrode',
                                   ['x', 'y', 'z', 'conductivity'])

We use the same electrode positions as the `tutorial_slice[_basics_explained].ipynb` notebooks.

In [None]:
CONDUCTIVITY = 0.3  # S/m

ELECTRODES_XYZ = [(0.0, 0.0, 5e-05),
                  (5e-05, 0.0, 0.00015),
                  (5e-05, -5e-05, 0.00025)]

electrodes = [Electrode(x, y, z, CONDUCTIVITY) for x, y, z in ELECTRODES_XYZ]

## Model source

We want to use CSD basis functions 36 μm in diameter ($R = 18\mu{}m$).

In [None]:
from kesi.common import SphericalSplineSourceKCSD

SRC_R = 18e-6  # m

spline_nodes = [SRC_R / 3, SRC_R]
spline_polynomials = [[1],
                      [0,
                       6.75 / SRC_R,
                       -13.5 / SRC_R ** 2,
                       6.75 / SRC_R ** 3]]
model_src = SphericalSplineSourceKCSD(0, 0, 0,
                                      spline_nodes,
                                      spline_polynomials)
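The spline coefficients can be sanity-checked numerically. The standalone sketch below makes two assumptions, inferred from the values rather than taken from the kESI documentation. First, each inner list holds polynomial coefficients in ascending powers of the radial distance $r$. Second, the $i$-th polynomial applies up to `spline_nodes[i]`, with zero CSD beyond the last node. Under these assumptions the profile equals 1 up to $R/3$ and falls smoothly (with zero slope at both ends) to 0 at $R$. The helper `radial_profile` is hypothetical.

```python
import numpy as np
from numpy.polynomial import polynomial as P

SRC_R = 18e-6  # m

# Assumption: ascending-order coefficients; the i-th polynomial is used on
# [spline_nodes[i-1], spline_nodes[i]) and the profile is 0 beyond SRC_R.
spline_nodes = [SRC_R / 3, SRC_R]
spline_polynomials = [[1],
                      [0,
                       6.75 / SRC_R,
                       -13.5 / SRC_R ** 2,
                       6.75 / SRC_R ** 3]]


def radial_profile(r):
    """Hypothetical evaluation of the spline CSD profile at distances r."""
    r = np.asarray(r, dtype=float)
    result = np.zeros_like(r)
    lower = 0.0
    for upper, coefficients in zip(spline_nodes, spline_polynomials):
        inside = (r >= lower) & (r < upper)
        # P.polyval expects coefficients in ascending order
        result[inside] = P.polyval(r[inside], coefficients)
        lower = upper
    return result


outer = spline_polynomials[1]
value_at_inner_node = P.polyval(SRC_R / 3, outer)
value_at_outer_node = P.polyval(SRC_R, outer)
print(value_at_inner_node, value_at_outer_node)  # approx. 1 and 0

# The slope of the outer polynomial vanishes at both nodes (smooth joins)
slope = P.polyder(outer)
print(P.polyval(SRC_R / 3, slope) * SRC_R, P.polyval(SRC_R, slope) * SRC_R)
```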

## Convolver object

In [None]:
import numpy as np
from kesi.kernel.constructor import Convolver

ROMBERG_K = 5
Z_MIN = 0
Z_MAX = 3e-4
XY_AMP = 1.5e-4

_h_min = SRC_R * 2**(1 - ROMBERG_K)
_X = _Y = np.linspace(-XY_AMP, XY_AMP, int(np.floor(2 * XY_AMP / _h_min)) + 1)
_Z = np.linspace(Z_MIN, Z_MAX, int(np.floor((Z_MAX - Z_MIN) / _h_min)) + 1)

_csd_grid = _pot_grid = [_X, _Y, _Z]

convolver = Convolver(_pot_grid, _csd_grid)

for _h in convolver.steps('POT'):
    assert _h >= _h_min, f'{_h} < {_h_min}'
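The value of `_h_min` appears to come from the Romberg quadrature used below. Sampling a model source of diameter $2R$ at $2^K + 1$ equidistant points needs a step of $2R / 2^K = R \cdot 2^{1 - K}$. Taking $\lfloor L / h_{min} \rfloor + 1$ nodes over an interval of length $L$ makes the actual step $L / \lfloor L / h_{min} \rfloor$ no smaller than $h_{min}$, which is what the assertion verifies. The standalone sketch below repeats this arithmetic for the parameters above; the helper `n_nodes` is hypothetical.

```python
import numpy as np

SRC_R = 18e-6  # m, model source radius
ROMBERG_K = 5
Z_MIN, Z_MAX = 0, 3e-4  # m
XY_AMP = 1.5e-4  # m

# Smallest admissible step: the source diameter split into 2 ** ROMBERG_K intervals
h_min = 2 * SRC_R / 2 ** ROMBERG_K
assert np.isclose(h_min, SRC_R * 2 ** (1 - ROMBERG_K))


def n_nodes(length, h):
    # floor(length / h) intervals, so the actual step is at least h
    return int(np.floor(length / h)) + 1


n_xy = n_nodes(2 * XY_AMP, h_min)
n_z = n_nodes(Z_MAX - Z_MIN, h_min)
step_xy = 2 * XY_AMP / (n_xy - 1)
step_z = (Z_MAX - Z_MIN) / (n_z - 1)

print(n_xy, n_z)  # nodes along X (and Y) and along Z
print(step_xy >= h_min, step_z >= h_min)
```

With these parameters there are 267 nodes along each axis.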

## Convolver interface

In [None]:
from kesi.kernel.constructor import ConvolverInterfaceIndexed
from scipy.integrate import romb

ROMBERG_N = 2 ** ROMBERG_K + 1
ROMBERG_WEIGHTS = romb(np.identity(ROMBERG_N), dx=2 ** -ROMBERG_K)

SRC_MASK = ((convolver.SRC_Z > Z_MIN + SRC_R)
            & (convolver.SRC_Z < Z_MAX - SRC_R)
            & (abs(convolver.SRC_X) < XY_AMP - SRC_R)
            & (abs(convolver.SRC_Y) < XY_AMP - SRC_R))
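Applying `romb` to the rows of an identity matrix returns the weight of each sample in the Romberg rule. As a result, integrating any function sampled at `ROMBERG_N` equidistant points of $[0, 1]$ reduces to a dot product with `ROMBERG_WEIGHTS`. The minimal standalone check below uses only `scipy.integrate.romb`; the test function is arbitrary.

```python
import numpy as np
from scipy.integrate import romb

ROMBERG_K = 5
ROMBERG_N = 2 ** ROMBERG_K + 1
DX = 2 ** -ROMBERG_K  # step, so the samples span [0, 1]

# Each row of the identity is a unit sample vector; integrating it
# yields the quadrature weight of the corresponding sample.
ROMBERG_WEIGHTS = romb(np.identity(ROMBERG_N), dx=DX)

x = np.linspace(0, 1, ROMBERG_N)
samples = np.exp(-x) * np.cos(3 * x)  # arbitrary smooth test function

via_weights = ROMBERG_WEIGHTS @ samples
via_romb = romb(samples, dx=DX)
print(via_weights, via_romb)
```

Since the rule integrates constants exactly over a unit interval, the weights sum to 1. They are also symmetric.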

In [None]:
print(SRC_MASK.sum())

In [None]:
convolver_interface = ConvolverInterfaceIndexed(convolver,
                                                model_src.csd,
                                                ROMBERG_WEIGHTS,
                                                SRC_MASK)

## Potential At Electrode object

### Potential At Electrode: analytical solution of the forward problem (kCSD)

In [None]:
from kesi.kernel import potential_basis_functions as pbf

In [None]:
pbf_kcsd = pbf.Analytical(convolver_interface,
                          potential=model_src.potential)

## Kernel constructor and cross-kernel constructor

In [None]:
from kesi.kernel.constructor import KernelConstructor, CrossKernelConstructor

kernel_constructor = KernelConstructor()

### Cross-kernel for reconstruction in coordinate planes

To calculate the cross-kernel matrix, we need to select nodes of the _CSD_ grid.
We are going to visualise the current source density in the coordinate planes,
so the boolean mask selects the nodes closest to those planes.
First, we define the planes unequivocally by their intersection point.

In [None]:
coordinate_x = 25e-6
coordinate_y = -25e-6
coordinate_z = 150e-6

intersection = [coordinate_x,
                coordinate_y,
                coordinate_z]

We find the indices of the node of the _CSD_ grid closest to the intersection in terms of Manhattan distance.

In [None]:
indices_of_coordinates = [np.argmin(abs(_C - _c))
                          for _C, _c in zip(convolver.CSD_GRID,
                                            intersection)]
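On a rectilinear grid, the Manhattan distance between a node and a point is a sum of per-axis terms, and each term depends on a single index. Minimising each term separately, as the per-axis `np.argmin` above does, therefore minimises the sum. The standalone sketch below compares this shortcut with a brute-force search. The small random grid and point are made up for illustration.

```python
import itertools
import numpy as np

rng = np.random.default_rng(0)

# A small, irregular rectilinear grid and an arbitrary point
GRID = [np.sort(rng.uniform(-1, 1, n)) for n in (7, 5, 6)]
point = rng.uniform(-1, 1, 3)

# Per-axis nearest indices (the approach used in the notebook)
per_axis = tuple(int(np.argmin(abs(_C - _c))) for _C, _c in zip(GRID, point))

# Brute force: Manhattan distance from the point to every node of the grid
best = min(itertools.product(*(range(len(_C)) for _C in GRID)),
           key=lambda idx: sum(abs(_C[i] - _c)
                               for _C, i, _c in zip(GRID, idx, point)))
print(per_axis, best)
```

The same argument applies to the squared Euclidean distance, so the selected node is also the Euclidean nearest.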

Using these indices, we select the nodes of the _CSD_ grid closest to the coordinate planes.

In [None]:
CSD_MASK_CP = np.zeros(convolver.csd_shape,
                       dtype=bool)
CSD_MASK_CP[indices_of_coordinates[0], :, :] = True
CSD_MASK_CP[:, indices_of_coordinates[1], :] = True
CSD_MASK_CP[:, :, indices_of_coordinates[2]] = True

We count the selected nodes.

In [None]:
n_csd_nodes = CSD_MASK_CP.sum()
print(f'{n_csd_nodes} nodes of the CSD grid selected (coordinate planes).')
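The number of selected nodes can be predicted with the inclusion-exclusion principle. For a grid of shape $(n_x, n_y, n_z)$, the planes contain $n_y n_z$, $n_x n_z$ and $n_x n_y$ nodes. Each pair of planes intersects in a line of $n_z$, $n_y$ or $n_x$ nodes, and all three planes share one node:

$$n = n_y n_z + n_x n_z + n_x n_y - n_x - n_y - n_z + 1.$$

The standalone sketch below checks this formula against a brute-force mask on a small grid. The helper `count_plane_nodes` and the chosen intersection node are hypothetical.

```python
import numpy as np


def count_plane_nodes(shape):
    nx, ny, nz = shape
    return ny * nz + nx * nz + nx * ny - nx - ny - nz + 1


shape = (4, 5, 6)
ix, iy, iz = 1, 3, 2  # arbitrary intersection node

MASK = np.zeros(shape, dtype=bool)
MASK[ix, :, :] = True
MASK[:, iy, :] = True
MASK[:, :, iz] = True
print(MASK.sum(), count_plane_nodes(shape))  # 60 60
```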

We use the `CSD_MASK_CP` to create a cross-kernel constructor.

In [None]:
kernel_constructor.crosskernel_cp = CrossKernelConstructor(convolver_interface,
                                                           CSD_MASK_CP)

To retrieve the three CSD planes from the CSD vector, we define an auxiliary function `to_planes()`.
The function uses three index arrays to select (and arrange) the appropriate elements of the vector.

In [None]:
# As we reconstruct the CSD at n_csd_nodes points,
# n_csd_nodes is an invalid index for the
# reconstructed CSD vector (it marks unselected nodes).

_CSD_IDX = np.full_like(CSD_MASK_CP, n_csd_nodes,
                        dtype=np.int32)
_CSD_IDX[CSD_MASK_CP] = np.arange(n_csd_nodes)

COORDINATE_PLANE_INDICES = [_CSD_IDX[indices_of_coordinates[0], :, :].copy(),
                            _CSD_IDX[:, indices_of_coordinates[1], :].copy(),
                            _CSD_IDX[:, :, indices_of_coordinates[2]].copy()
                            ]
del _CSD_IDX

# We test whether all indices are valid.

for _A in COORDINATE_PLANE_INDICES:
    assert _A.min() >= 0 and _A.max() < n_csd_nodes

def to_planes(CSD):
    return [CSD[IDX] for IDX in COORDINATE_PLANE_INDICES]
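The index arrays act as a lookup table. Each selected node stores its position in the reconstructed vector, so fancy indexing with a whole plane of indices rebuilds that plane. The small-scale sketch below uses an arbitrary shape and intersection, together with a made-up field defined on the whole grid. It assumes, as the notebook itself does, that the vector follows the C order of the boolean mask. The round trip recovers each plane exactly.

```python
import numpy as np

shape = (3, 4, 5)
ix, iy, iz = 1, 2, 3  # arbitrary intersection node

MASK = np.zeros(shape, dtype=bool)
MASK[ix, :, :] = True
MASK[:, iy, :] = True
MASK[:, :, iz] = True
n = MASK.sum()

# Position of every selected node in the vector; n marks unselected nodes
IDX = np.full(shape, n, dtype=np.int32)
IDX[MASK] = np.arange(n)

PLANE_INDICES = [IDX[ix, :, :], IDX[:, iy, :], IDX[:, :, iz]]

# A made-up field on the whole grid and its values at the selected nodes
FIELD = np.arange(np.prod(shape)).reshape(shape)
VECTOR = FIELD[MASK]  # boolean selection follows C order, as in IDX[MASK]

planes = [VECTOR[P] for P in PLANE_INDICES]
print([p.shape for p in planes])  # [(4, 5), (3, 5), (3, 4)]
```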

### Cross-kernel for reconstruction along a coordinate line

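For the sake of simplicity, we repeat all the steps above except for defining an auxiliary function, as the vector format of the reconstruction suits our purposes.  The selected line is parallel to the Z axis and passes through the node closest to the intersection of the coordinate planes.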

In [None]:
_CSD_MASK = np.zeros(convolver.csd_shape,
                     dtype=bool)
_CSD_MASK[indices_of_coordinates[0], indices_of_coordinates[1], :] = True

kernel_constructor.crosskernel_cl = CrossKernelConstructor(convolver_interface,
                                                           _CSD_MASK)

del _CSD_MASK

# Reconstructor

## Construction of kernels

In [None]:
%%time
B = kernel_constructor.potential_basis_functions_at_electrodes(electrodes,
                                                               pbf_kcsd)

In [None]:
KERNEL = kernel_constructor.kernel(B)

In [None]:
%%time
CROSSKERNEL_CP = kernel_constructor.crosskernel_cp(B)

In [None]:
%%time
CROSSKERNEL_CL = kernel_constructor.crosskernel_cl(B)

In [None]:
del B  # the array is large and no longer needed

## Reconstructors

In [None]:
from kesi._verbose import _CrossKernelReconstructor as Reconstructor
from kesi._engine import _LinearKernelSolver as KernelSolver

kernel_solver = KernelSolver(KERNEL)
reconstructor_cp = Reconstructor(kernel_solver,
                                 CROSSKERNEL_CP)
reconstructor_cl = Reconstructor(kernel_solver,
                                 CROSSKERNEL_CL)
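Conceptually, a kCSD-style cross-kernel reconstructor estimates the CSD at the selected nodes as $\tilde{K} (K + \lambda I)^{-1} V$. Here $\tilde{K}$ is the cross-kernel, $K$ the kernel and $V$ the potentials. The NumPy sketch below illustrates this formula with made-up basis-function values. It is not the implementation of `_LinearKernelSolver` or `_CrossKernelReconstructor`, whose internals (including any default regularization) are assumptions here. The formula is linear in $V$, so a matrix with one column per timepoint is reconstructed column by column.

```python
import numpy as np

rng = np.random.default_rng(42)

N_ELECTRODES, N_CSD_NODES, T = 3, 10, 4

# Hypothetical basis-function values at the electrodes and at the CSD nodes
PHI_ELECTRODES = rng.normal(size=(N_ELECTRODES, 50))
PHI_CSD = rng.normal(size=(N_CSD_NODES, 50))

KERNEL = PHI_ELECTRODES @ PHI_ELECTRODES.T  # N x N, symmetric
CROSSKERNEL = PHI_CSD @ PHI_ELECTRODES.T    # N~ x N


def reconstruct(potentials, regularization_parameter=0.0):
    """Hypothetical kCSD-style estimate: CROSSKERNEL (KERNEL + lambda I)^-1 V."""
    A = KERNEL + regularization_parameter * np.identity(len(KERNEL))
    return CROSSKERNEL @ np.linalg.solve(A, potentials)


V = rng.normal(size=(N_ELECTRODES, T))  # one column per timepoint
CSD = reconstruct(V)
print(CSD.shape)  # (10, 4)
```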

# Visualisation

In [None]:
import matplotlib.pyplot as plt
import cbf


class CoordinatePlanesVisualisation(object):
    def __init__(self,
                 grid,
                 plane_intersection,
                 dpi=35,
                 cmap=cbf.bwr,
                 amp=None,
                 length_factor=1,
                 length_unit='$m$',
                 unit_factor=1,
                 unit=''):
        self.grid = grid
        self.plane_intersection = np.array(plane_intersection)
        self.dpi = dpi
        self.cmap = cmap
        self.amp = amp
        self.length_factor = length_factor
        self.length_unit = length_unit
        self.unit_factor = unit_factor
        self.unit = unit
    
    def start_new_image(self, title, wx, wy, wz):
        self.fig = plt.figure(figsize=((wx + wy) / self.dpi,
                                       (wz + wy) / self.dpi))
        if title is not None:
            self.fig.suptitle(title)

        gs = plt.GridSpec(2, 2,
                          figure=self.fig,
                          width_ratios=[wx, wy],
                          height_ratios=[wz, wy])

        self.ax_xz = self.fig.add_subplot(gs[0, 0])
        self.ax_xz.set_aspect('equal')
        self.ax_xz.set_ylabel(f'Z [{self.length_unit}]')
        self.ax_xz.set_xlabel(f'X [{self.length_unit}]')

        self.ax_yx = self.fig.add_subplot(gs[1, 1])
        self.ax_yx.set_aspect('equal')
        self.ax_yx.set_ylabel(f'X [{self.length_unit}]')
        self.ax_yx.set_xlabel(f'Y [{self.length_unit}]')

        self.ax_yz = self.fig.add_subplot(gs[0, 1],
                                          sharey=self.ax_xz,
                                          sharex=self.ax_yx)
        self.ax_yz.set_aspect('equal')

        self.cax = self.fig.add_subplot(gs[1, 0])
        self.cax.set_visible(False)

    def finish_image(self):
        x, y, z = self.length_factor * self.plane_intersection

        self.ax_xz.axvline(x, ls=':', color=cbf.BLACK)
        self.ax_xz.axhline(z, ls=':', color=cbf.BLACK)

        self.ax_yx.axvline(y, ls=':', color=cbf.BLACK)
        self.ax_yx.axhline(x, ls=':', color=cbf.BLACK)

        self.ax_yz.axvline(y, ls=':', color=cbf.BLACK)
        self.ax_yz.axhline(z, ls=':', color=cbf.BLACK)
        self.fig.colorbar(self.im, ax=self.cax,
                          orientation='horizontal',
                          label=self.unit)

    def _plot_planes(self, DATA_PLANES, amp):
        DATA_ZY = DATA_PLANES[0].T * self.unit_factor
        DATA_ZX = DATA_PLANES[1].T * self.unit_factor
        DATA_XY = DATA_PLANES[2] * self.unit_factor
        
        def _extent(first, second):
            _first = self.grid[first] * self.length_factor
            _second = self.grid[second] * self.length_factor
            return (_first.min(), _first.max(),
                    _second.min(), _second.max())

        self.ax_xz.imshow(DATA_ZX,
                          vmin=-amp * self.unit_factor,
                          vmax=amp * self.unit_factor,
                          cmap=self.cmap,
                          origin='lower',
                          extent=_extent(0, 2))
        self.ax_yx.imshow(DATA_XY,
                          vmin=-amp * self.unit_factor,
                          vmax=amp * self.unit_factor,
                          cmap=self.cmap,
                          origin='lower',
                          extent=_extent(1, 0))
        self.im = self.ax_yz.imshow(DATA_ZY,
                                    vmin=-amp * self.unit_factor,
                                    vmax=amp * self.unit_factor,
                                    cmap=self.cmap,
                                    origin='lower',
                                    extent=_extent(1, 2))

    def plot_planes(self,
                    DATA_PLANES,
                    title=None,
                    amp=None):

        DATA_YZ, DATA_XZ, DATA_XY = DATA_PLANES
        wx, wy = DATA_XY.shape
        wz = DATA_YZ.shape[1]
        assert DATA_YZ.shape[0] == wy
        assert DATA_XZ.shape[0] == wx
        assert DATA_XZ.shape[1] == wz
        
        self.start_new_image(title, wx, wy, wz)
        self._plot_planes(DATA_PLANES,
                          amp if amp is not None else max(abs(_A).max() for _A in DATA_PLANES))
        self.finish_image()

In [None]:
csd_plotter = CoordinatePlanesVisualisation([_x.flatten() for _x in convolver.CSD_GRID],
                                            intersection,
                                            unit_factor=1e-12,
                                            unit='$\\frac{\\mu{}A}{mm^3}$',
                                            length_factor=1e6,
                                            length_unit='$\\mu{}m$')

# Reconstruction

## Reconstruction in coordinate planes

Potential values (given in $\mu{}V$) are stored in a vector `POTENTIALS`.
Each of its $N$ elements corresponds to an electrode and was calculated from the ground-truth CSD
in the `tutorial_slice.ipynb` notebook (`GT_V` therein).

In [None]:
POTENTIALS = [-126548.99283768,
              -119140.53772061,
              -73225.23872045,
              ]

As the potential input is a vector, the reconstructor returns
a vector of CSD values.  Each of its $\tilde{\underline{N}}$
elements corresponds to a selected node of the _CSD_ grid.

In [None]:
%%time
CSD_CP = reconstructor_cp(POTENTIALS)

It should be the same as the kCSD reconstruction in the
`tutorial_slice[_basics_explained].ipynb` notebooks (restricted to the coordinate planes).

In [None]:
csd_plotter.plot_planes(to_planes(CSD_CP),
                        title='kCSD reconstruction from slice tutorial notebook')

## Reconstruction of timeseries

Potential values (given in $\mu{}V$) are stored in an $N \times T$ matrix
`POTENTIAL_TIMESERIES`.
Each of its $N$ rows corresponds to an electrode, while each of its $T$
columns corresponds to a timepoint.  The potentials are sums of time-modulated
components (the columns of `COMPONENTS`).  Note that the first component is the negative of `POTENTIALS`.

In [None]:
COMPONENTS = [[126548.99283768, -193132.43450924,  102791.12206456],
              [119140.53772061,  -23621.73043093, -154596.99314908],
              [ 73225.23872045,  105789.71746181,   79168.62902192]]

T = 2048
T_START = 0
T_END = 2

TIME = np.linspace(T_START, T_END, T)

class spikes(object):
    time_constant = 0.05

    def __init__(self, *times):
        self.times = times
        
    def __call__(self, TIME):
        return sum(self.alpha((TIME - t) / self.time_constant)
                   for t in self.times)

    def alpha(self, TIME):
        return np.where(TIME < 0, 0, TIME * np.exp(-TIME))
    

POTENTIAL_TIMESERIES = np.matmul(COMPONENTS,
                                 [spikes(1.15, 1.25, 1.35, 1.45)(TIME),
                                  0.10 * np.sin(TIME * 3 * np.pi),
                                  0.05 * np.cos(TIME * 6 * np.pi)])
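Each spike in the first component is an alpha function:

$$\frac{t - t_i}{\tau} \exp\left(-\frac{t - t_i}{\tau}\right) \quad \text{for } t \geq t_i,$$

where $\tau$ is `spikes.time_constant` (0.05 s, i.e. 50 ms). Such a pulse peaks $\tau$ after the spike time, at a height of $1/e$. The standalone sketch below re-implements the logic of `spikes.alpha` and checks this numerically for a single spike. The 0.1 ms sampling is an arbitrary choice.

```python
import numpy as np

TIME_CONSTANT = 0.05  # s, as in spikes.time_constant


def alpha(scaled_time):
    # Zero before the spike, t * exp(-t) afterwards (time in units of tau)
    return np.where(scaled_time < 0, 0, scaled_time * np.exp(-scaled_time))


spike_time = 1.15  # s
TIME = np.linspace(0, 2, 20001)  # 0.1 ms resolution
PULSE = alpha((TIME - spike_time) / TIME_CONSTANT)

peak_time = TIME[np.argmax(PULSE)]
print(peak_time, PULSE.max())  # close to 1.2 s and 1/e
```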

The reconstructor (`reconstructor_cl`) returns an $\tilde{\underline{N}} \times T$ matrix of CSD values.
Each of its $T$ columns corresponds to a timepoint,
while each of its $\tilde{\underline{N}}$ rows corresponds to a selected node
of the _CSD_ grid (here, a node of the coordinate line).

In [None]:
%%time
CSD_TIMESERIES = reconstructor_cl(POTENTIAL_TIMESERIES)

In [None]:
dpi = 150 # 35
cmap = cbf.bwr
unit_factor = 1e-12
unit = '$\\frac{\\mu{}A}{mm^3}$'
length_factor = 1e6
length_unit = '$\\mu{}m$'
time_factor = 1000
time_unit = '$ms$'

_Z = convolver.CSD_GRID[2].flatten() * length_factor

_amp = abs(CSD_TIMESERIES).max()

plt.figure(figsize=(tuple(_x / dpi for _x in CSD_TIMESERIES.shape[::-1])))
plt.ylabel(f"Z [{length_unit}]")
plt.xlabel(f"time [{time_unit}]")

plt.axhline(intersection[2] * length_factor,
            ls=':',
            color=cbf.BLACK)
plt.imshow(CSD_TIMESERIES * unit_factor,
           vmin=-_amp * unit_factor,
           vmax=_amp * unit_factor,
           cmap=cmap,
           origin='lower',
           extent=(T_START * time_factor, T_END * time_factor, _Z.min(), _Z.max()))

plt.colorbar(label=unit)

## Regularization

In [None]:
EIGENVALUES = np.linalg.eigvalsh(KERNEL)[::-1]

plt.plot(EIGENVALUES,
         marker='o')

plt.yscale('log')

In [None]:
REGULARIZATION_PARAMETERS = np.logspace(10, 20, 10 * 10 + 1)

### Leave-one-out cross-validation

We can choose one regularization parameter for all timepoints.

In [None]:
from kesi.common import cv

In [None]:
%%time
CV_ERRORS = cv(kernel_solver, POTENTIAL_TIMESERIES, REGULARIZATION_PARAMETERS)

Note that, as cross-validation does not depend on the cross-kernel, we use the `kernel_solver` object instead of a reconstructor.
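Leave-one-out cross-validation holds out one electrode at a time, fits the remaining potentials and predicts the held-out one, so only the kernel matrix is involved. The plain-NumPy sketch below illustrates the idea; it is not the implementation of `kesi.common.cv`. The helper names, the made-up kernel, and the use of the L2 norm over electrodes are all assumptions. The sketch also checks the closed form known for kernel ridge regression: the held-out residual of electrode $i$ equals $[(K + \lambda I)^{-1} V]_i / [(K + \lambda I)^{-1}]_{ii}$.

```python
import numpy as np


def loo_errors(kernel, potentials, regularization_parameters):
    """Hypothetical helper: L2 norm of leave-one-out errors per parameter."""
    n = len(kernel)
    errors = []
    for lambda_ in regularization_parameters:
        residuals = np.empty(n)
        for i in range(n):
            keep = np.arange(n) != i
            # Fit on all electrodes but i ...
            A = kernel[np.ix_(keep, keep)] + lambda_ * np.identity(n - 1)
            beta = np.linalg.solve(A, potentials[keep])
            # ... and predict the potential at electrode i
            residuals[i] = potentials[i] - kernel[i, keep] @ beta
        errors.append(np.linalg.norm(residuals))
    return np.array(errors)


def loo_residuals_closed_form(kernel, potentials, lambda_):
    INV = np.linalg.inv(kernel + lambda_ * np.identity(len(kernel)))
    return (INV @ potentials) / np.diag(INV)


rng = np.random.default_rng(1)
PHI = rng.normal(size=(6, 20))
KERNEL = PHI @ PHI.T  # made-up kernel for 6 electrodes
V = rng.normal(size=6)
LAMBDAS = np.logspace(-2, 2, 5)

ERRORS = loo_errors(KERNEL, V, LAMBDAS)
print(ERRORS)
```

The explicit loop makes the procedure transparent. The closed form avoids refitting and needs only one matrix inverse per parameter.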

In [None]:
regularization_parameter = REGULARIZATION_PARAMETERS[np.argmin(CV_ERRORS)]

In [None]:
plt.plot(REGULARIZATION_PARAMETERS,
         CV_ERRORS,
         color=cbf.BLUE)
plt.axvline(regularization_parameter,
            ls=(0, (1, 2)),
            color=cbf.BLUE)
plt.xscale('log')
plt.xlabel('regularization parameter')
plt.yscale('log')
plt.ylabel('L2 norm of cross-validation error')

As the smallest value of `REGULARIZATION_PARAMETERS` has been chosen (four orders of magnitude smaller than any of the kernel eigenvalues), cross-validation seems to disfavour regularization.

We can also try to calculate a separate regularization parameter for each timepoint.

In [None]:
%%time
_CV_ERRORS = np.transpose([cv(kernel_solver, _V, REGULARIZATION_PARAMETERS)
                           for _V in POTENTIAL_TIMESERIES.T])

In [None]:
_REGULARIZATION_PARAMETERS = REGULARIZATION_PARAMETERS[_CV_ERRORS.argmin(axis=0)]

In [None]:
from matplotlib import cm
from matplotlib.colors import LogNorm

# _vmin = 10 ** np.floor(np.log10(_CV_ERRORS.min()))
# _vmax = 10 ** np.ceil(np.log10(_CV_ERRORS.max()))
_vmin = _CV_ERRORS.min()
_vmax = _CV_ERRORS.max()

levels = np.logspace(np.log10(_vmin),
                     np.log10(_vmax),
                     256)

_T, _R = np.meshgrid(TIME * time_factor,
                     REGULARIZATION_PARAMETERS)

plt.figure(figsize=(12, 5))

plt.contourf(_T, _R, _CV_ERRORS,
             levels,
             norm=LogNorm(vmin=_vmin, vmax=_vmax),
             cmap=cm.copper)
plt.plot(TIME * time_factor, _REGULARIZATION_PARAMETERS,
         ls="--",
         color=cbf.BLUE)
plt.axhline(EIGENVALUES.min(),
            ls=":",
            color=cbf.GREEN)
plt.axhline(EIGENVALUES.max(),
            ls=":",
            color=cbf.GREEN)
plt.xlim(T_START * time_factor, T_END * time_factor)
plt.ylim(1 / 1.2 * REGULARIZATION_PARAMETERS[0], 1.2 * REGULARIZATION_PARAMETERS[-1])
plt.yscale("log")
plt.xlabel(f"time [{time_unit}]")
plt.ylabel("regularization parameter")

plt.colorbar(label="CV error")