Please note that this tutorial focuses on reconstructing CSD timeseries
at a subset of the _CSD_ grid.  For the sake of simplicity it uses kCSD
(cross)kernels only.
To learn how to create kESI (cross)kernels, please consult
`tutorial_*_basics_explained.ipynb`.  To compare the reconstructed CSD
with the kCSD reconstruction at all nodes of the _CSD_ grid, please run
one of the `tutorial_slice[_basics_explained].ipynb` notebooks.

# Requirements

## Memory

The code of the notebook requires at least 1.8 GB (1.7 GiB) of free memory.


## Environment

1. Anaconda Python distribution (tested with `Miniconda3-py39_4.12.0-Linux-x86_64.sh`, _conda v. 4.12.0_).
2. Jupyter server (see `extras/jupyter_server.sh` for details).
3. Anaconda environments (run `setup_conda_envs.sh`).

# Kernel construction tools

## Electrode object

The implementation of the electrode object is the minimum necessary to construct a kCSD (cross)kernel: it stores only the electrode position and the conductivity.

In [None]:
import collections

Electrode = collections.namedtuple('Electrode',
                                   ['x', 'y', 'z', 'conductivity'])

We use the same electrode positions as the `tutorial_slice[_basics_explained].ipynb` notebooks.

In [None]:
CONDUCTIVITY = 0.3  # S/m

ELECTRODES_XYZ = [(0.0, 0.0, 5e-05),
                  (5e-05, 0.0, 0.00015),
                  (5e-05, -5e-05, 0.00025)]

electrodes = [Electrode(x, y, z, CONDUCTIVITY) for x, y, z in ELECTRODES_XYZ]

## Model source

We want to use CSD bases 36μm wide ($R = 18\mu{}m$).

In [None]:
from kesi.common import SphericalSplineSourceKCSD

SRC_R = 18e-6  # m

spline_nodes = [SRC_R / 3, SRC_R]
spline_polynomials = [[1],
                      [0,
                       6.75 / SRC_R,
                       -13.5 / SRC_R ** 2,
                       6.75 / SRC_R ** 3]]
model_src = SphericalSplineSourceKCSD(0, 0, 0,
                                      spline_nodes,
                                      spline_polynomials)
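The second polynomial expands $6.75\,(r/R)\,(1 - r/R)^2$, so the radial CSD profile equals 1 up to $r = R/3$ and then decays to 0 at $r = R$, with zero slope at both ends of the cubic segment. A quick numerical check in plain NumPy, assuming (as the coefficient lists suggest) that each polynomial lists its coefficients in increasing powers of $r$; the helper `radial_profile` is hypothetical and not part of kESI:

```python
import numpy as np
from numpy.polynomial import polynomial as P

SRC_R = 18e-6  # m, as above

# Same spline as above; assumed: coefficients in increasing powers of r
spline_nodes = [SRC_R / 3, SRC_R]
spline_polynomials = [[1],
                      [0, 6.75 / SRC_R, -13.5 / SRC_R ** 2, 6.75 / SRC_R ** 3]]


def radial_profile(r):
    """Hypothetical helper: piecewise-polynomial CSD profile, zero beyond SRC_R."""
    r = np.atleast_1d(np.asarray(r, dtype=float))
    # 0 for r <= R/3, 1 for R/3 < r <= R, 2 (outside the support) for r > R
    segment = np.searchsorted(spline_nodes, r, side='left')
    profile = np.zeros_like(r)
    for i, coefficients in enumerate(spline_polynomials):
        inside = segment == i
        profile[inside] = P.polyval(r[inside], coefficients)
    return profile


# Value and slope (scaled by SRC_R) of the cubic at both nodes:
# approximately (1, 0) at SRC_R / 3 and (0, 0) at SRC_R.
outer = spline_polynomials[1]
slope = P.polyder(outer)
print(P.polyval(SRC_R / 3, outer), SRC_R * P.polyval(SRC_R / 3, slope))
print(P.polyval(SRC_R, outer), SRC_R * P.polyval(SRC_R, slope))
print(radial_profile([0, SRC_R / 2, 2 * SRC_R]))
```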

## Convolver object

In [None]:
import numpy as np
from kesi.kernel.constructor import Convolver

ROMBERG_K = 5
Z_MIN = 0
Z_MAX = 3e-4
XY_AMP = 1.5e-4

_h_min = SRC_R * 2**(1 - ROMBERG_K)
_X = _Y = np.linspace(-XY_AMP, XY_AMP, int(np.floor(2 * XY_AMP / _h_min)) + 1)
_Z = np.linspace(Z_MIN, Z_MAX, int(np.floor((Z_MAX - Z_MIN) / _h_min)) + 1)

_csd_grid = _pot_grid = [_X, _Y, _Z]

convolver = Convolver(_pot_grid, _csd_grid)

for _c, _h in zip("XYZ", convolver.steps("POT")):
    assert _h >= _h_min, f"{_c}:\t{_h} < {_h_min}"
    if _h >= 2 * _h_min:
        print(f"You can reduce number of nodes of quadrature for {_c} dimension")
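With `int(np.floor(L / h_min)) + 1` nodes over a length $L$, the actual step is $h = L / \lfloor L / h_{min} \rfloor$. It is never smaller than $h_{min}$ and, for $L \ge h_{min}$, always smaller than $2 h_{min}$ (since $\lfloor x \rfloor > x / 2$ for $x \ge 1$). So, assuming `convolver.steps("POT")` reports this spacing, the "You can reduce..." message should not appear for grids built this way. Note also that $h_{min} = R \cdot 2^{1-K} = 2R / 2^K$ is the spacing of $2^K + 1$ equally spaced points across the source diameter. A pure-NumPy re-computation for the values used above:

```python
import numpy as np

SRC_R = 18e-6      # m
ROMBERG_K = 5
Z_MIN, Z_MAX = 0, 3e-4
XY_AMP = 1.5e-4

# 2**ROMBERG_K intervals spanning the source diameter 2 * SRC_R
h_min = 2 * SRC_R / 2 ** ROMBERG_K
assert np.isclose(h_min, SRC_R * 2 ** (1 - ROMBERG_K))

for name, length in [("X/Y", 2 * XY_AMP), ("Z", Z_MAX - Z_MIN)]:
    n_nodes = int(np.floor(length / h_min)) + 1
    step = length / (n_nodes - 1)
    # step >= h_min by construction; step < 2 * h_min since floor(x) > x / 2 for x >= 1
    print(f"{name}: {n_nodes} nodes, step = {step:.4e} m ({step / h_min:.3f} h_min)")
```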

## Convolver interface

In [None]:
from kesi.kernel.constructor import ConvolverInterfaceIndexed
from scipy.integrate import romb

ROMBERG_N = 2 ** ROMBERG_K + 1
ROMBERG_WEIGHTS = romb(np.identity(ROMBERG_N), dx=2 ** -ROMBERG_K)

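# Keep only the sources whose support (a ball of radius SRC_R)
# lies entirely inside the grid box.
SRC_MASK = ((convolver.SRC_Z > Z_MIN + SRC_R)
            & (convolver.SRC_Z < Z_MAX - SRC_R)
            & (abs(convolver.SRC_X) < XY_AMP - SRC_R)
            & (abs(convolver.SRC_Y) < XY_AMP - SRC_R))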
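Applying `romb` to the rows of an identity matrix yields the Romberg quadrature weight of each of the `ROMBERG_N` nodes. With `dx=2 ** -ROMBERG_K` the nodes span the unit interval, so the weights sum to 1. A standalone SciPy/NumPy check that these weights integrate low-order polynomials on $[0, 1]$ exactly (how `ConvolverInterfaceIndexed` scales them to the source support is not shown here):

```python
import numpy as np
from scipy.integrate import romb

romberg_k = 5
romberg_n = 2 ** romberg_k + 1           # 33 equally spaced nodes
weights = romb(np.identity(romberg_n), dx=2 ** -romberg_k)

x = np.linspace(0, 1, romberg_n)         # the nodes the weights refer to

# Quadrature vs. exact integral of x**power over [0, 1]
for power in range(4):
    print(power, weights @ x ** power, 1 / (power + 1))
```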

In [None]:
print(SRC_MASK.sum())  # number of selected sources

In [None]:
convolver_interface = ConvolverInterfaceIndexed(convolver,
                                                model_src.csd,
                                                ROMBERG_WEIGHTS,
                                                SRC_MASK)

## Potential Basis Functions

In [None]:
from kesi.kernel.potential_basis_functions import Analytical as PBF

In [None]:
pbf = PBF(convolver_interface,
          potential=model_src.potential)

## Kernel constructor and cross-kernel constructor

In [None]:
from kesi.kernel.constructor import KernelConstructor, CrossKernelConstructor

kernel_constructor = KernelConstructor()

### Cross-kernel for reconstruction along a coordinate line

As we want to plot a 2D reconstruction (Z vs. T), the boolean mask selects
the nodes of the _CSD_ grid closest to the given coordinate line (parallel to the Z axis).

In [None]:
coordinate_x = 25e-6
coordinate_y = -25e-6

coordinate_line = [coordinate_x,
                   coordinate_y]

We find the indices of the _CSD_ grid node closest to the coordinate line in terms of Manhattan distance (the closest X and Y coordinates are searched for independently).

In [None]:
indices_of_coordinates = [np.argmin(abs(_C - _c))
                          for _C, _c in zip(convolver.CSD_GRID,
                                            coordinate_line)]
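Because the grid is a tensor product, minimizing `abs(_C - _c)` separately along X and Y minimizes both the Manhattan and the Euclidean distance to the line. Below is a toy sketch with small hypothetical axes padded to broadcastable shapes, which is assumed here to resemble `convolver.CSD_GRID`. `np.argmin` returns a flat index, and that equals the per-axis index only for such singleton-padded arrays:

```python
import numpy as np

# Hypothetical 1D axes padded to broadcastable shapes (nx, 1, 1) and (1, ny, 1)
GRID_X = np.linspace(-1.0, 1.0, 5).reshape(-1, 1, 1)   # step 0.5
GRID_Y = np.linspace(-1.0, 1.0, 9).reshape(1, -1, 1)   # step 0.25

line_xy = [0.3, -0.3]

# Per-axis search, as in the cell above
indices = [int(np.argmin(abs(_C - _c))) for _C, _c in zip([GRID_X, GRID_Y], line_xy)]

# Brute-force check over all (x, y) nodes with the Euclidean distance
dist = np.hypot(GRID_X[:, :, 0] - line_xy[0], GRID_Y[:, :, 0] - line_xy[1])
brute_force = tuple(int(i) for i in np.unravel_index(np.argmin(dist), dist.shape))

print(indices, brute_force)
```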

With the indices, we select the nodes of the _CSD_ grid closest to the coordinate line.

In [None]:
_CSD_MASK = np.zeros(convolver.csd_shape,
                     dtype=bool)
_CSD_MASK[indices_of_coordinates[0], indices_of_coordinates[1], :] = True

kernel_constructor.crosskernel = CrossKernelConstructor(convolver_interface,
                                                        _CSD_MASK)

del _CSD_MASK
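The mask marks a single column of the _CSD_ grid, so it selects as many nodes as there are Z nodes. For boolean-mask indexing NumPy returns the selected elements in C (row-major) order, here in order of increasing Z index. Assuming `CrossKernelConstructor` keeps that order, the rows of the reconstructed CSD line up with `convolver.CSD_GRID[2]` as used in the plot further down. A toy illustration with a small hypothetical grid shape:

```python
import numpy as np

csd_shape = (4, 5, 6)               # hypothetical (nx, ny, nz)
i, j = 2, 1                         # hypothetical indices of the closest node

mask = np.zeros(csd_shape, dtype=bool)
mask[i, j, :] = True

# A field whose value at each node is just its Z index
Z_INDEX = np.broadcast_to(np.arange(csd_shape[2]), csd_shape)

print(mask.sum())        # number of selected nodes == nz
print(Z_INDEX[mask])     # selected nodes come in increasing Z order
```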

# Reconstructor

## Construction of kernels

In [None]:
%%time
B = kernel_constructor.potential_basis_functions_at_electrodes(electrodes,
                                                               pbf)

In [None]:
KERNEL = kernel_constructor.kernel(B)

In [None]:
%%time
CROSSKERNEL = kernel_constructor.crosskernel(B)

In [None]:
del B  # the array is large and no longer needed

## Reconstructors

In [None]:
from kesi._verbose import _CrossKernelReconstructor as Reconstructor
from kesi._engine import _LinearKernelSolver as KernelSolver


reconstructor = Reconstructor(KernelSolver(KERNEL),
                              CROSSKERNEL)

# Visualisation

In [None]:
import matplotlib.pyplot as plt

import cbf

# Reconstruction

## Reconstruction of timeseries

Potential values (given in $\mu{}V$) are stored in an $N \times T$ matrix
`POTENTIAL_TIMESERIES`.
Each of its $N$ rows corresponds to an electrode, while each of its $T$
columns corresponds to a timepoint.  The potentials are a sum of time-modulated
components.  The components are proportional to kCSD eigensources
forward-modelled in the slice model (see the `tutorial_slice[_basics_explained].ipynb`
notebooks for details).  Note that the last two components are modulated periodically,
while the first component is strongly localized in time.

In [None]:
COMPONENTS = [[126548.99283768, -193132.43450924,  102791.12206456],
              [119140.53772061,  -23621.73043093, -154596.99314908],
              [ 73225.23872045,  105789.71746181,   79168.62902192]]

T = 2048
T_START = 0
T_END = 2

TIME = np.linspace(T_START, T_END, T)


class Spikes(object):
    time_constant = 0.05

    def __init__(self, *times):
        self.times = times
        
    def __call__(self, TIME):
        return sum(self.alpha((TIME - t) / self.time_constant - 1)
                   for t in self.times)

    def alpha(self, TIME):
        return np.where(TIME < 0, 0, TIME * np.exp(-TIME))
    

POTENTIAL_TIMESERIES = np.matmul(COMPONENTS,
                                 [Spikes(1.20, 1.25, 1.40, 1.45)(TIME),
                                  0.10 * np.sin(TIME * 3 * np.pi),
                                  0.05 * np.cos(TIME * 6 * np.pi)])
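Each spike of the first component is an alpha function $f(s) = s\,e^{-s}$ (zero for $s < 0$) of the rescaled time $s = (t - t_k)/\tau - 1$, where $\tau$ is `Spikes.time_constant` (0.05 s). Hence a spike at $t_k$ starts at $t_k + \tau$ and peaks at $t_k + 2\tau$ with amplitude $1/e$. A standalone check of that arithmetic; the function `alpha_spike` is a hypothetical re-implementation of a single term of `Spikes.__call__`:

```python
import numpy as np

TAU = 0.05                                   # s, as Spikes.time_constant
toy_time = np.linspace(0, 2, 2048)           # s, same sampling as TIME


def alpha_spike(time, t_k, tau=TAU):
    """Hypothetical single-spike version of Spikes.__call__."""
    s = (time - t_k) / tau - 1
    return np.where(s < 0, 0, s * np.exp(-s))


wave = alpha_spike(toy_time, 1.20)
# The peak should be close to 1.20 + 2 * TAU = 1.30 s, with height close to 1/e
print(toy_time[np.argmax(wave)], wave.max(), 1 / np.e)
```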

The reconstructor returns a $\tilde{N} \times T$ matrix of CSD values.
Each of its $T$ columns corresponds to a timepoint
while each of its $\tilde{N}$ rows corresponds to a selected node
of the _CSD_ grid.
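In kCSD-style methods the cross-kernel reconstruction has the form $\tilde{K}(K + \lambda I)^{-1} V$. Here $K$ is the $N \times N$ kernel, $\tilde{K}$ the $\tilde{N} \times N$ cross-kernel, $V$ the $N \times T$ potentials and $\lambda$ the regularization parameter. Below is a minimal NumPy sketch of that formula with random stand-in matrices. It is not the implementation of `_CrossKernelReconstructor`, whose internals are not shown here. It illustrates the shapes, and that one linear solve handles all $T$ timepoints at once:

```python
import numpy as np

rng = np.random.default_rng(0)
n_electrodes, n_selected, n_times = 3, 7, 11          # hypothetical sizes

# Random stand-ins: a symmetric positive-definite kernel and a cross-kernel
toy_factor = rng.normal(size=(n_electrodes, n_electrodes))
toy_kernel = toy_factor @ toy_factor.T + n_electrodes * np.eye(n_electrodes)
toy_crosskernel = rng.normal(size=(n_selected, n_electrodes))
toy_potentials = rng.normal(size=(n_electrodes, n_times))


def reconstruct(crosskernel, kernel, potentials, regularization_parameter=0.0):
    """CSD = crosskernel @ (kernel + lambda * I)^-1 @ potentials."""
    n = kernel.shape[0]
    beta = np.linalg.solve(kernel + regularization_parameter * np.eye(n), potentials)
    return crosskernel @ beta


toy_csd = reconstruct(toy_crosskernel, toy_kernel, toy_potentials)
print(toy_csd.shape)                                   # n_selected x n_times

# Using the kernel itself as the cross-kernel, without regularization,
# reproduces the measured potentials at the electrodes.
print(np.allclose(reconstruct(toy_kernel, toy_kernel, toy_potentials), toy_potentials))
```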

In [None]:
%%time
CSD_TIMESERIES = reconstructor(POTENTIAL_TIMESERIES)

In [None]:
dpi = 150 # 35
cmap = cbf.bwr
unit_factor = 1e-12
unit = '$\\frac{\\mu{}A}{mm^3}$'
length_factor = 1e6
length_unit = '$\\mu{}m$'
time_factor = 1000
time_unit = '$ms$'

_T, _Z = np.meshgrid(TIME * time_factor,
                     convolver.CSD_GRID[2].flatten() * length_factor)

_amp = abs(CSD_TIMESERIES * unit_factor).max()

# figsize is in inches; with dpi=dpi the whole figure measures about T x Ñ pixels
plt.figure(figsize=tuple(_x / dpi for _x in CSD_TIMESERIES.shape[::-1]),
           dpi=dpi)
plt.ylabel(f"Z [{length_unit}]")
plt.xlabel(f"time [{time_unit}]")

plt.contourf(_T, _Z, CSD_TIMESERIES * unit_factor,
             256,
             vmin=-_amp,
             vmax=_amp,
             cmap=cmap)

plt.colorbar(label=unit)

## Regularization

In [None]:
EIGENVALUES = np.linalg.eigvalsh(KERNEL)[::-1]

plt.plot(EIGENVALUES,
         marker='o')

plt.yscale('log')

In [None]:
REGULARIZATION_PARAMETERS = np.logspace(10, 20, 10 * 10 + 1)

### Leave-one-out cross-validation

We can choose one regularization parameter for all timepoints.
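Leave-one-out cross-validation repeatedly drops one electrode, fits the remaining potentials with the regularized kernel and predicts the potential at the dropped electrode. The sketch below shows that procedure in NumPy, assuming the kCSD-style prediction $K_{i,\setminus i}(K_{\setminus i,\setminus i} + \lambda I)^{-1} V_{\setminus i}$ and an L2 norm of the prediction errors. It is not `kesi.common.cv`, whose exact error definition may differ, and the toy data are random:

```python
import numpy as np


def loo_cv_error(kernel, potentials, regularization_parameter):
    """L2 norm of leave-one-out prediction errors (potentials: N or N x T)."""
    n = kernel.shape[0]
    errors = []
    for i in range(n):
        rest = np.arange(n) != i               # all electrodes but the i-th
        beta = np.linalg.solve(kernel[np.ix_(rest, rest)]
                               + regularization_parameter * np.eye(n - 1),
                               potentials[rest])
        errors.append(potentials[i] - kernel[i, rest] @ beta)
    return np.linalg.norm(errors)


rng = np.random.default_rng(0)
toy_factor = rng.normal(size=(6, 6))
toy_kernel = toy_factor @ toy_factor.T         # symmetric positive (semi)definite
toy_potentials = rng.normal(size=6)
toy_parameters = np.logspace(-3, 3, 13)

toy_errors = [loo_cv_error(toy_kernel, toy_potentials, rp) for rp in toy_parameters]
best_parameter = toy_parameters[np.argmin(toy_errors)]
print(best_parameter)
```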

In [None]:
from kesi.common import cv

In [None]:
%%time
CV_ERRORS = cv(reconstructor, POTENTIAL_TIMESERIES, REGULARIZATION_PARAMETERS)

In [None]:
regularization_parameter = REGULARIZATION_PARAMETERS[np.argmin(CV_ERRORS)]

In [None]:
plt.plot(REGULARIZATION_PARAMETERS,
         CV_ERRORS,
         color=cbf.BLUE)
plt.axvline(regularization_parameter,
            ls=(0, (1, 2)),
            color=cbf.BLUE)
plt.xscale('log')
plt.xlabel('regularization parameter')
plt.yscale('log')
plt.ylabel('L2 norm of cross-validation error')

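As the smallest value of `REGULARIZATION_PARAMETERS` has been chosen (4 orders of magnitude smaller than any of the kernel eigenvalues), cross-validation seems to favour no regularization at all.  Since the chosen value lies at the lower boundary of the tested range, the actual optimum may be even smaller.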
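To see why such a small parameter amounts to almost no regularization, consider the eigenbasis of the kernel. There, $(K + \lambda I)^{-1}$ scales the component along an eigenvector with eigenvalue $\mu_k$ by $1/(\mu_k + \lambda)$ instead of $1/\mu_k$, that is, by a relative factor $\mu_k / (\mu_k + \lambda)$. A short illustration with a hypothetical eigenvalue spectrum (not the one computed above), where 1e10 is the smallest tested parameter:

```python
import numpy as np

toy_eigenvalues = np.logspace(14, 18, 5)     # hypothetical kernel spectrum
for toy_parameter in [1e10, 1e14, 1e18]:
    # Relative damping of each eigencomponent; 1 means "not regularized"
    factors = toy_eigenvalues / (toy_eigenvalues + toy_parameter)
    print(f"{toy_parameter:.0e}:", np.round(factors, 4))
```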

We can also try to calculate a separate regularization parameter for each timepoint.
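For a separate parameter per timepoint, the per-column errors can also be computed in one pass with the closed-form leave-one-out identity $V_i - \hat{V}_{-i} = [(K + \lambda I)^{-1} V]_i / [(K + \lambda I)^{-1}]_{ii}$. The identity holds for predictions of the form $K_{i,\setminus i}(K_{\setminus i,\setminus i} + \lambda I)^{-1} V_{\setminus i}$. Below is a NumPy sketch on random toy data that selects the parameter column-wise, as the notebook does below with `CV_ERRORS_SERIES.argmin(axis=0)`. Whether `kesi.common.cv` uses this identity is not claimed:

```python
import numpy as np


def loo_errors_per_column(kernel, potentials, regularization_parameters):
    """Array (len(parameters) x T) of L2 norms of LOO errors, column by column."""
    n = kernel.shape[0]
    result = []
    for rp in regularization_parameters:
        inverse = np.linalg.inv(kernel + rp * np.eye(n))
        # LOO residuals of all electrodes and timepoints at once (N x T)
        residuals = (inverse @ potentials) / np.diag(inverse)[:, np.newaxis]
        result.append(np.linalg.norm(residuals, axis=0))
    return np.array(result)


rng = np.random.default_rng(1)
toy_factor = rng.normal(size=(6, 6))
toy_kernel = toy_factor @ toy_factor.T
toy_series = rng.normal(size=(6, 50))            # 6 electrodes x 50 timepoints
toy_parameters = np.logspace(-3, 3, 13)

toy_error_series = loo_errors_per_column(toy_kernel, toy_series, toy_parameters)
toy_parameter_series = toy_parameters[toy_error_series.argmin(axis=0)]
```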

In [None]:
%%time
CV_ERRORS_SERIES = np.transpose([cv(reconstructor, _V, REGULARIZATION_PARAMETERS)
                                for _V in POTENTIAL_TIMESERIES.T])

In [None]:
REGULARIZATION_PARAMETERS_SERIES = REGULARIZATION_PARAMETERS[CV_ERRORS_SERIES.argmin(axis=0)]

In [None]:
from matplotlib.colors import LogNorm

_vmin = CV_ERRORS_SERIES.min()
_vmax = CV_ERRORS_SERIES.max()

levels = np.logspace(np.log10(_vmin),
                     np.log10(_vmax),
                     256)

_T, _R = np.meshgrid(TIME * time_factor,
                     REGULARIZATION_PARAMETERS)

fig = plt.figure(figsize=(16, 10))
gs = plt.GridSpec(2, 2,
                  figure=fig,
                  width_ratios=[12, 1],
                  height_ratios=[1, 2],
                  wspace=0)
ax_rp = fig.add_subplot(gs[0, 0])
ax_cb = fig.add_subplot(gs[1, 1])
ax_err = fig.add_subplot(gs[1, 0])

ax_cb.set_visible(False)


_contourf = ax_err.contourf(_T, _R, CV_ERRORS_SERIES,
                            levels,
                            norm=LogNorm(vmin=_vmin, vmax=_vmax),
                            cmap=cbf.wo)

for (ax, ls) in [(ax_rp, "-"),
                 (ax_err, "--")]:
    ax.plot(TIME * time_factor, REGULARIZATION_PARAMETERS_SERIES,
            ls=ls,
            color=cbf.BLUE)

for ax in [ax_rp, ax_err]:
    ax.axhline(EIGENVALUES.min(),
               ls=":",
               color=cbf.GREEN)
    ax.axhline(EIGENVALUES.max(),
               ls=":",
               color=cbf.GREEN)
    
    ax.set_xlim(T_START * time_factor, T_END * time_factor)
    ax.set_ylim(REGULARIZATION_PARAMETERS[0], REGULARIZATION_PARAMETERS[-1])
    ax.set_yscale("log")
    ax.set_ylabel("regularization parameter")

ax_err.set_xlabel(f"time [{time_unit}]")
ax_rp.spines[['right', 'top']].set_visible(False)
ax_rp.set_title("CV-selected regularization parameter")


fig.colorbar(_contourf, label="CV error", ax=ax_cb)