In [None]:
%matplotlib inline

import configparser

import numpy as np
import scipy.integrate as si

import matplotlib.pyplot as plt
import cbf

import kesi
import kesi._verbose as verbose
import _fast_reciprocal_reconstructor as frr
import common

In [None]:
# Identifiers of the FEM solution to use; they are interpolated into the
# path of the sampled solution files when the electrodes are loaded.
MODEL = '4SM_CSF_3_mm'
MESH = 'uniform_coarse'
DEGREE = 3
# Radius of the sphere used to select sources (see the SRC_IDX mask).
# NOTE(review): presumably in meters — confirm.
BRAIN_R = 0.079

# Note

Mind that, to fit the 4SM model, the electrodes are centered at `[4.7467814  4.04361142 4.40882958] 1e-3`.

In [None]:
# Contacts are listed group by group; within a group they are numbered
# (as strings) from '1' up to the contact count given here.
_CONTACTS_PER_GROUP = (('AMG', 8), ('HH', 8), ('HP', 8), ('PC', 8),
                       ('PM', 6), ('STG', 8), ('TP', 8))

ELECTRODE_ORDER = [(group, str(contact))
                   for group, n_contacts in _CONTACTS_PER_GROUP
                   for contact in range(1, n_contacts + 1)]

In [None]:
class Electrode:
    """
    Point electrode described by an `.npz` file with a sampled
    correction potential.

    The file provides the sampling grid (`X`, `Y`, `Z`), the electrode
    location (`LOCATION`), the base conductivity (`BASE_CONDUCTIVITY`)
    and the sampled correction potential (`CORRECTION_POTENTIAL`).
    """

    def __init__(self, filename, decimals_tolerance=None, dx=0):
        """
        Parameters
        ----------

        filename : str
            Path to the sampled correction potential.

        decimals_tolerance : int
            Number of decimals the coordinates are rounded to
            before being compared in the `.correction_potential()`
            method.  No rounding if `None`.

        dx : float
            Integration step from which the regularization
            parameter of the `.base_potential()` method
            is derived.
        """
        self.filename = filename
        self.decimals_tolerance = decimals_tolerance
        self.dx = dx
        with np.load(filename) as archive:
            self._X, self._Y, self._Z = [self.round(archive[axis])
                                         for axis in 'XYZ']
            self.x, self.y, self.z = archive['LOCATION']
            self.base_conductivity = archive['BASE_CONDUCTIVITY']

    @property
    def _epsilon(self):
        """
        Regularization parameter of the `.base_potential()` method.

        Note
        ----

        The 0.15 factor has been chosen with a toy numerical experiment.
        Further, more rigorous experiments are definitely recommended.
        """
        return 0.15 * self.dx

    def round(self, A):
        """
        Round `A` to `.decimals_tolerance` decimals.

        `A` itself is returned if `.decimals_tolerance` is `None`.
        """
        decimals = self.decimals_tolerance
        return A if decimals is None else np.round(A, decimals=decimals)

    def _stored_indices(self, stored, requested, expected_count):
        """
        Indices (in `stored`) of the coordinates shared with `requested`.

        `requested` is rounded before the comparison.  The indices follow
        the ascending order of the shared coordinates, as returned by
        `np.intersect1d()`.  An `AssertionError` is raised unless exactly
        `expected_count` coordinates are shared.
        """
        shared, indices, _ = np.intersect1d(stored,
                                            self.round(requested),
                                            return_indices=True)
        assert len(shared) == expected_count
        return indices

    def correction_potential(self, X, Y, Z):
        """
        Sampled correction potential at the requested grid nodes.

        Parameters
        ----------
        X, Y, Z : np.array
            Coordinate matrices with matrix indexing.
            Coordinates are expected to be - respectively -
            from `._X`, `._Y` and `._Z` attributes.
            May be obtained with
            `X, Y, Z = np.meshgrid(..., indexing='ij')`.
        """
        idx_x = self._stored_indices(self._X, X[:, 0, 0], np.shape(X)[0])
        idx_y = self._stored_indices(self._Y, Y[0, :, 0], np.shape(Y)[1])
        idx_z = self._stored_indices(self._Z, Z[0, 0, :], np.shape(Z)[2])

        # The potential is read from the file at every call
        # instead of being kept in memory.
        with np.load(self.filename) as archive:
            sampled = archive['CORRECTION_POTENTIAL']
            return sampled[np.ix_(idx_x, idx_y, idx_z)]

    def base_potential(self, X, Y, Z):
        """
        Regularized potential of a point source at the electrode location:
        `1 / (4 * pi * base_conductivity * (_epsilon + distance))`.
        """
        distance = np.sqrt(np.square(X - self.x)
                           + np.square(Y - self.y)
                           + np.square(Z - self.z))
        return (0.25 / (np.pi * self.base_conductivity)
                / (self._epsilon + distance))

In [None]:
# One Electrode per contact, in ELECTRODE_ORDER, loaded from the sampled FEM
# solution selected by MODEL, MESH and DEGREE.
# NOTE(review): `dx` is left at its default of 0, so the `_epsilon`
# regularization of `.base_potential()` is 0 for these electrodes — confirm
# that this is intended.
electrodes = [Electrode(f'FEM/solutions/paper/{MODEL}/{MESH}_{DEGREE}/epi_wroclaw_sampled/9/{group}_{n}.npz')
              for group, n in ELECTRODE_ORDER]

In [None]:
# Sampling grid axes, taken from the first electrode.
# NOTE(review): presumably all the electrode files share the same grid — confirm.
_X, _Y, _Z = [getattr(electrodes[0], f'_{c}') for c in 'XYZ']

In [None]:
# Insertion indices (into the grid axes scaled by 1e3) of the lower and upper
# limit of the CSD region along each axis.
# NOTE(review): the limits look like millimeters and the axes like meters —
# confirm the units.
_ix_csd = np.searchsorted(_X * 1e3, [16.15, 85.6])
_iy_csd = np.searchsorted(_Y * 1e3, [-24.04, 53.51])
_iz_csd = np.searchsorted(_Z * 1e3, [-32.2, 40.65])

In [None]:
# The largest of the mean node spacings of the three axes.
dx = max((_A[-1] - _A[0]) / (len(_A) - 1) for _A in [_X, _Y, _Z])

In [None]:
dx

In [None]:
# Romberg quadrature over ROMBERG_N = 2 ** ROMBERG_K + 1 equally spaced samples.
ROMBERG_K = 6
ROMBERG_N = 2 ** ROMBERG_K + 1
# Integrating the rows of an identity matrix yields the weight of each sample;
# the weights are then divided by the number of sample intervals
# (2 ** ROMBERG_K).
ROMBERG_WEIGHTS = si.romb(np.identity(ROMBERG_N)) * 2 ** -ROMBERG_K

# Half of the span of ROMBERG_N samples spaced by dx.
SRC_R_MAX = dx * 2 ** (ROMBERG_K - 1)

# NOTE(review): presumably the radial profile of the model source, given as
# polynomial coefficients (ascending powers of the radius) for the segments
# ending at the consecutive nodes; if so, the profile equals 1 up to
# SRC_R_MAX / 3 and falls continuously to 0 at SRC_R_MAX — confirm against
# common.SphericalSplineSourceKCSD.
spline_nodes = [SRC_R_MAX / 3, SRC_R_MAX]
spline_polynomials = [[1],
                      [0,
                       6.75 / SRC_R_MAX,
                       -13.5 / SRC_R_MAX ** 2,
                       6.75 / SRC_R_MAX ** 3]]
# Model source at the origin, using the base conductivity of the first electrode.
model_src = common.SphericalSplineSourceKCSD(0, 0, 0,
                                             spline_nodes,
                                             spline_polynomials,
                                             electrodes[0].base_conductivity)

print(SRC_R_MAX)

In [None]:
# The first list of axes is the sampling grid restricted to the CSD region
# extended by one node and by SRC_R_MAX; the second list runs from one node
# below the lower limit index up to (and including) the upper limit index.
# NOTE(review): presumably the first list is the potential grid and the second
# the CSD grid — confirm against frr.ckESI_convolver.
# NOTE(review): unlike Y and Z, the X axis has no upper bound in the first
# list — confirm that this is intended.  Mind also that an index of
# `_i?_csd[0] - 1` wraps to the last node if searchsorted returned 0.
convolver = frr.ckESI_convolver([_X[_X >= _X[_ix_csd[0] - 1] - SRC_R_MAX],
                                 _Y[(_Y >= _Y[_iy_csd[0] - 1] - SRC_R_MAX)
                                    & (_Y <= _Y[_iy_csd[1] + 1] + SRC_R_MAX)],
                                 _Z[(_Z >= _Z[_iz_csd[0] - 1] - SRC_R_MAX)
                                    & (_Z <= _Z[_iz_csd[1] + 1] + SRC_R_MAX)],
                                 ],
                                [_X[_ix_csd[0] - 1:_ix_csd[1] + 1],
                                 _Y[_iy_csd[0] - 1:_iy_csd[1] + 1],
                                 _Z[_iz_csd[0] - 1:_iz_csd[1] + 1],
                                 ])

In [None]:
# Mask of the convolver source nodes whose distance from the origin does not
# exceed BRAIN_R - SRC_R_MAX (a sphere of radius SRC_R_MAX centered at such
# a node stays within the radius BRAIN_R).
SRC_IDX = np.sqrt(sum(np.square(getattr(convolver, f'SRC_{c}')) for c in 'XYZ')) <= BRAIN_R - SRC_R_MAX

In [None]:
# Number of the selected source nodes.
SRC_IDX.sum()

In [None]:
# Bundle the convolver with the CSD profile of the model source, the Romberg
# weights and the mask of the selected sources.
convolver_interface = frr.ConvolverInterfaceIndexed(convolver,
                                                    model_src.csd,
                                                    ROMBERG_WEIGHTS,
                                                    SRC_IDX)

In [None]:
kernel_constructor = frr.ckESI_kernel_constructor()

In [None]:
# Mask selecting every node of the convolver's CSD grid.
CSD_IDX = np.ones(convolver.shape('CSD'),
                  dtype=bool)

In [None]:
CSD_IDX.shape

In [None]:
# Replace the `create_crosskernel` attribute of the kernel constructor with
# a constructor bound to the convolver interface and the CSD mask.
kernel_constructor.create_crosskernel = frr.ckESI_crosskernel_constructor(convolver_interface,
                                                                          CSD_IDX)

In [None]:
# NOTE(review): presumably evaluates the potential of the sources at an
# electrode using the analytical potential of the model source — confirm
# against frr.PAE_kESI_Analytical.
potential_at_electrode = frr.PAE_kESI_Analytical(convolver_interface,
                                                 potential=model_src.potential)

In [None]:
%%time
# NOTE(review): presumably Φ holds the potential of every selected source
# (rows) at every electrode (columns) — confirm against
# frr.ckESI_kernel_constructor.
Φ = kernel_constructor.create_base_images_at_electrodes(electrodes,
                                                        potential_at_electrode)

In [None]:
KERNEL = kernel_constructor.create_kernel(Φ)

In [None]:
# `eigh` returns the eigenvalues in ascending order; both the eigenvalues and
# the eigenvector columns are reversed to descending order.
EIGENVALUES, EIGENVECTORS = np.linalg.eigh(KERNEL)
EIGENVALUES, EIGENVECTORS = EIGENVALUES[::-1], EIGENVECTORS[:, ::-1]

In [None]:
# Φ · V · diag(λ ** -0.5)
# NOTE(review): a non-positive eigenvalue yields NaN or inf here — confirm
# that KERNEL is numerically positive definite.
EIGENSOURCES = np.matmul(Φ,
                         np.matmul(EIGENVECTORS,
                                   np.diag(1. / np.sqrt(EIGENVALUES))))

In [None]:
%%time
# Cross-kernel built from the base images at electrodes.
CROSSKERNEL = kernel_constructor.create_crosskernel(Φ)

In [None]:
# Φ is not used by the following cells.
del Φ

In [None]:
%%time
# The same cross-kernel constructor applied to the eigensources.
EIGENSOURCES_CSD = kernel_constructor.create_crosskernel(EIGENSOURCES)

In [None]:
del EIGENSOURCES

In [None]:
# diag(λ ** 0.5) · Vᵀ; applied to BETAS when the eigensources file is saved.
CROSSKERNEL_ES = np.matmul(np.diag(np.sqrt(EIGENVALUES)),
                           EIGENVECTORS.T)

In [None]:
# Unflatten every column of CROSSKERNEL onto the CSD grid; the last axis of
# the result indexes the columns.
_TMP = np.empty(convolver.shape('CSD') + CROSSKERNEL.shape[1:])

for i, _COL in enumerate(CROSSKERNEL.T):
    _TMP[:, :, :, i] = _COL.reshape(convolver.shape('CSD'))

# Swap the first two grid axes (the same swap is applied to CSD_SPACE).
# NOTE(review): CROSSKERNEL is rebound to the 4-D array, so this cell is not
# idempotent — it must not be run twice without recomputing CROSSKERNEL.
CROSSKERNEL = np.swapaxes(_TMP, 0, 1)
del _TMP

In [None]:
# Unflatten every column of EIGENSOURCES_CSD onto the CSD grid; the last axis
# of the result indexes the columns.
_TMP = np.empty(convolver.shape('CSD') + EIGENSOURCES_CSD.shape[1:])

for i, _COL in enumerate(EIGENSOURCES_CSD.T):
    _TMP[:, :, :, i] = _COL.reshape(convolver.shape('CSD'))

# Swap the first two grid axes (the same swap is applied to CSD_SPACE).
# NOTE(review): EIGENSOURCES_CSD is rebound to the 4-D array, so this cell is
# not idempotent — it must not be run twice without recomputing
# EIGENSOURCES_CSD.
EIGENSOURCES_CSD = np.swapaxes(_TMP, 0, 1)
del _TMP
print(EIGENSOURCES_CSD.max(), EIGENSOURCES_CSD.min())

In [None]:
# Coordinate matrices of the convolver's CSD mesh (matrix indexing), each with
# its first two axes swapped; stacked along a new leading axis.
CSD_SPACE = np.array([np.swapaxes(A, 0, 1)
                      for A in np.meshgrid(*convolver.CSD_MESH,
                                           indexing='ij')])

In [None]:
# NOTE(review): `_LinearKernelSolver` is a private class of `kesi._engine` —
# it may change without notice.
kernel_solver = kesi._engine._LinearKernelSolver(KERNEL)

In [None]:
# Spectrum of the kernel on a logarithmic scale.
# KERNEL is decomposed with `np.linalg.eigh()` elsewhere in this notebook
# (i.e. it is treated as symmetric), so the symmetric solver is used here
# as well: unlike `np.linalg.eigvals()` it always returns real eigenvalues
# in a defined (ascending) order, which is reversed to show the largest first.
plt.plot(np.linalg.eigvalsh(KERNEL)[::-1])
plt.yscale('log')
plt.xlabel('eigenvalue index')
plt.ylabel('eigenvalue')

In [None]:
# Candidate regularization parameters: 1e0 to 1e15, 10 per decade (151 values).
REGULARIZATION_PARAMETERS = np.logspace(0, 15, 10 * 15 + 1)

In [None]:
# Recorded data to be reconstructed.
# NOTE(review): presumably one row per electrode (in ELECTRODE_ORDER) and one
# column per time sample — confirm.
DATA = np.load('FEM/solutions/NOT_SOLUTIONS/epi_wroclaw/lfp_napad.npy')

In [None]:
DATA.shape

In [None]:
%%time
# Timing experiment: cross-validation over the first 100 columns of DATA.
# NOTE(review): every cell of this series overwrites ERRORS, which is later
# recomputed from the whole of DATA.
ERRORS = common.cv(kernel_solver, DATA[:, :100], REGULARIZATION_PARAMETERS)

In [None]:
%%time
# The same experiment for the first 1000 columns.
ERRORS = common.cv(kernel_solver, DATA[:, :1000], REGULARIZATION_PARAMETERS)

In [None]:
%%time
# The same experiment for the first 2500 columns.
ERRORS = common.cv(kernel_solver, DATA[:, :2500], REGULARIZATION_PARAMETERS)

In [None]:
%%time
# The same experiment for the first 5000 columns.
ERRORS = common.cv(kernel_solver, DATA[:, :5000], REGULARIZATION_PARAMETERS)

In [None]:
%%time
# The same experiment for the first 6000 columns.
ERRORS = common.cv(kernel_solver, DATA[:, :6000], REGULARIZATION_PARAMETERS)

In [None]:
%%time
# The same experiment for the first 7500 columns.
ERRORS = common.cv(kernel_solver, DATA[:, :7500], REGULARIZATION_PARAMETERS)

In [None]:
%%time
# The same experiment for the first 10000 columns.
ERRORS = common.cv(kernel_solver, DATA[:, :10000], REGULARIZATION_PARAMETERS)

In [None]:
# NOTE(review): presumably an estimate (in hours) of the time necessary to
# cross-validate the whole of DATA in chunks of 5000 columns, assuming about
# 23 s per chunk — confirm.
DATA.shape[1] / 5000 * 23 / 3600

In [None]:
%%time
_ERRORS = []
# _ERRORS_OLD = []
buffer_size = 5000
for i in range(int(np.ceil(DATA.shape[1] / buffer_size))):
    TMP = DATA[:, i * buffer_size:(i + 1) * buffer_size]
#     _ERRORS_OLD.append(TMP.shape[1] * np.array(common.cv(kernel_solver, TMP, REGULARIZATION_PARAMETERS)))
    _ERRORS.append(TMP.shape[1] * np.square(common.cv(kernel_solver, TMP, REGULARIZATION_PARAMETERS)))
#     assert len(_ERRORS[-1]) == len(REGULARIZATION_PARAMETERS)

# ERRORS_OLD = np.sum(_ERRORS_OLD, axis=0) / DATA.shape[1]
ERRORS = np.sqrt(np.sum(_ERRORS, axis=0) / DATA.shape[1])

In [None]:
# The regularization parameter of the lowest cross-validation error.
regularization_parameter = REGULARIZATION_PARAMETERS[np.argmin(ERRORS)]

In [None]:
# Cross-validation error against the regularization parameter (log-log);
# the vertical line marks the selected parameter.
plt.plot(REGULARIZATION_PARAMETERS, ERRORS)
# plt.plot(REGULARIZATION_PARAMETERS, ERRORS_OLD)
plt.xscale('log')
plt.yscale('log')
plt.axvline(regularization_parameter)

In [None]:
# One curve per chunk of DATA.  Mind that `_ERRORS` holds squared errors
# weighted by the chunk size, so the values are not directly comparable
# with ERRORS.
plt.plot(REGULARIZATION_PARAMETERS, np.transpose(_ERRORS))
# plt.plot(REGULARIZATION_PARAMETERS, ERRORS_OLD)
plt.xscale('log')
plt.yscale('log')
plt.axvline(regularization_parameter)

In [None]:
%%time
# Solve for the whole of DATA with the selected regularization parameter.
BETAS = kernel_solver(DATA, regularization_parameter)

In [None]:
# Axes of the target grid the cross-kernels are interpolated onto, in the
# Y, X, Z order.  The first and the last value of every axis match the limits
# used to select the CSD region of the sampling grid.
# NOTE(review): the values look like millimeters — confirm.
ORIGINAL_CSD_SPACE = [np.array([-24.04      , -19.47817609, -14.91635218, -10.35452826,
                                -5.79270435,  -1.23088044,   3.33094347,   7.89276738,
                                12.45459129,  17.01641521,  21.57823912,  26.14006303,
                                30.70188694,  35.26371085,  39.82553477,  44.38735868,
                                48.94918259,  53.5110065 ]),
                      np.array([16.15      , 20.78009875, 25.41019751, 30.04029626, 34.67039502,
                                39.30049377, 43.93059253, 48.56069128, 53.19079004, 57.82088879,
                                62.45098755, 67.0810863 , 71.71118506, 76.34128381, 80.97138257,
                                85.60148132]),
                      np.array([-32.2       , -27.64703526, -23.09407053, -18.54110579,
                                -13.98814105,  -9.43517632,  -4.88221158,  -0.32924684,
                                4.22371789,   8.77668263,  13.32964737,  17.8826121 ,
                                22.43557684,  26.98854158,  31.54150631,  36.09447105,
                                40.64743579]),
                      ] # Y, X, Z

In [None]:
from scipy.interpolate import RegularGridInterpolator

In [None]:
# Interpolate every cross-kernel column from the CSD grid onto the grid given
# by ORIGINAL_CSD_SPACE; points outside of the CSD grid are filled with 0.
# The interpolation axes and the target points are the same for every column,
# so they are built once instead of at every iteration.

# Axes of the CSD grid scaled by 1e3, in the axis order of CROSSKERNEL
# (the first two axes are swapped, matching the order of ORIGINAL_CSD_SPACE).
_csd_axes = [CSD_SPACE[1, :, 0, 0] * 1e3,
             CSD_SPACE[0, 0, :, 0] * 1e3,
             CSD_SPACE[2, 0, 0, :] * 1e3,
             ]
# Target points, shape: grid shape + (3,).
_target_points = np.stack(np.meshgrid(*ORIGINAL_CSD_SPACE, indexing='ij'),
                          axis=-1)

INTERPOLATED_CROSSKERNEL = []

for i in range(len(KERNEL)):
    interpolator = RegularGridInterpolator(_csd_axes,
                                           CROSSKERNEL[:, :, :, i],
                                           bounds_error=False,
                                           fill_value=0)
    INTERPOLATED_CROSSKERNEL.append(interpolator(_target_points))
# The last axis indexes the columns of the cross-kernel.
INTERPOLATED_CROSSKERNEL = np.stack(INTERPOLATED_CROSSKERNEL, axis=-1)

In [None]:
INTERPOLATED_CROSSKERNEL.shape

In [None]:
INTERPOLATED_CROSSKERNEL.dtype

In [None]:
# Value ranges of the matrices that are cast to float32 when saved.
print(CROSSKERNEL.max(), CROSSKERNEL.min())

In [None]:
print(BETAS.max(), BETAS.min())

In [None]:
# Limits of the candidate storage types, to be compared with the ranges above.
np.finfo(np.float16), np.finfo(np.float16).tiny

In [None]:
np.finfo(np.float32), np.finfo(np.float32).tiny

In [None]:
# Save the reconstruction matrices.  The CSD coordinates are scaled by 1e3;
# the cross-kernels and BETAS are cast to float32, the kernel keeps its dtype.
np.savez_compressed('git_stereotactic_epi_wroclaw_kESI_4SM.npz',
                    csd_space=CSD_SPACE * 1e3,
                    _kernel=KERNEL,
                    inflation_matrix=CROSSKERNEL.astype(np.float32),
                    _inflation_matrix=INTERPOLATED_CROSSKERNEL.astype(np.float32),
                    compressed_matrix=BETAS.astype(np.float32))

In [None]:
# Interpolate every eigensource CSD column from the CSD grid onto the grid
# given by ORIGINAL_CSD_SPACE; points outside of the CSD grid are filled
# with 0.  The interpolation axes and the target points are the same for
# every column, so they are built once instead of at every iteration.

# Axes of the CSD grid scaled by 1e3, in the axis order of EIGENSOURCES_CSD
# (the first two axes are swapped, matching the order of ORIGINAL_CSD_SPACE).
_es_csd_axes = [CSD_SPACE[1, :, 0, 0] * 1e3,
                CSD_SPACE[0, 0, :, 0] * 1e3,
                CSD_SPACE[2, 0, 0, :] * 1e3,
                ]
# Target points, shape: grid shape + (3,).
_es_target_points = np.stack(np.meshgrid(*ORIGINAL_CSD_SPACE, indexing='ij'),
                             axis=-1)

INTERPOLATED_ES_CSD = []

for i in range(len(KERNEL)):
    interpolator = RegularGridInterpolator(_es_csd_axes,
                                           EIGENSOURCES_CSD[:, :, :, i],
                                           bounds_error=False,
                                           fill_value=0)
    INTERPOLATED_ES_CSD.append(interpolator(_es_target_points))
# The last axis indexes the eigensources.
INTERPOLATED_ES_CSD = np.stack(INTERPOLATED_ES_CSD, axis=-1)

In [None]:
# Value ranges of the matrices that are cast to float32 when saved.
EIGENSOURCES_CSD.max(), EIGENSOURCES_CSD.min()

In [None]:
abs(np.matmul(CROSSKERNEL_ES, BETAS)).max()

In [None]:
# Sanity check: the interpolated eigensources contain neither NaN nor inf.
np.isnan(INTERPOLATED_ES_CSD).any(), np.isinf(INTERPOLATED_ES_CSD).any()

In [None]:
# Save the eigensource variant of the reconstruction matrices: the eigensource
# CSD profiles take the place of the cross-kernel and CROSSKERNEL_ES · BETAS
# takes the place of BETAS.  All the matrices but the coordinates are cast
# to float32.
np.savez_compressed('git_stereotactic_epi_wroclaw_kESI_4SM.eigensources.npz',
                    csd_space=CSD_SPACE * 1e3,
                    inflation_matrix=EIGENSOURCES_CSD.astype(np.float32),
                    _inflation_matrix=INTERPOLATED_ES_CSD.astype(np.float32),
                    compressed_matrix=np.matmul(CROSSKERNEL_ES, BETAS).astype(np.float32))