In [None]:
import datetime
import configparser
import os

import numpy as np
import pandas as pd

import kesi
import kesi._verbose as verbose
import _common_new as common
import _fast_reciprocal_reconstructor as frr

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

import cbf

In [None]:
# Which precomputed FEM solution set to load (all parts of the directory path).
MODEL_NAME, MESH_NAME, CONFIG_NAME = 'finite_slice', 'composite', 'grid_3d'
DEGREE = 3  # FEM element degree, appears as `<MESH_NAME>_<DEGREE>` in the path
K = 9       # last directory level of the solution path -- TODO document meaning

# Per-electrode solution file template; fill with `.format(name=<electrode>)`.
FILENAME_PATTERN = (f'FEM/solutions/paper/{MODEL_NAME}/{MESH_NAME}_{DEGREE}/'
                    f'{CONFIG_NAME}_sampled/{K}/'
                    f'{{name}}.npz')

In [None]:
# Physical parameters of the slice model (SI units).
BRAIN_CONDUCTIVITY, SALINE_CONDUCTIVITY = 0.3, 1.5  # S / m
H = 3e-4  # m; electrode plots below span Z in [0, H]

# Unused alternatives kept for reference:
# SD_GT = H / 16
# RESOLUTION = 32

# FEM

In [None]:
import dolfin
import FEM.fem_common as fc
import scipy.interpolate as si

In [None]:
class ForwardModel(object):
    """Finite-element forward model mapping a CSD to the potential.

    The mesh and function space come from ``fc.FunctionManagerINI``; the
    subdomain markers are read from ``<mesh path without extension>_subdomains.xdmf``
    and the per-subdomain conductivities from the file returned by
    ``self.fm.getpath('fem', 'config')``.
    """
    # XXX: duplicated code with FEM classes
    def __init__(self, config):
        """Load the mesh, function space, subdomains and conductivities.

        Parameters
        ----------
        config : str
            Path of the INI file handed to ``fc.FunctionManagerINI``.
        """
        self.fm = fc.FunctionManagerINI(config)

        self.V = self.fm.function_space
        mesh = self.fm.mesh

        n = self.V.dim()
        d = mesh.geometry().dim()

        # Coordinates of every degree of freedom, one row each; `__call__`
        # evaluates the CSD interpolator at these points.
        self.dof_coords = self.V.tabulate_dof_coordinates()
        self.dof_coords.resize((n, d))

        # FEM function reused to hold the interpolated CSD.
        self.csd_f = self.fm.function()

        # The subdomain file sits next to the mesh file.  `splitext` strips
        # the mesh file's extension whatever its length (expected `.xdmf`).
        mesh_stem, _ = os.path.splitext(self.fm.getpath('fem', 'mesh'))
        with dolfin.XDMFFile(mesh_stem + '_subdomains.xdmf') as fh:
            mvc = dolfin.MeshValueCollection("size_t", mesh, 3)
            fh.read(mvc, "subdomains")
            self.subdomains = dolfin.cpp.mesh.MeshFunctionSizet(mesh, mvc)
            self.dx = dolfin.Measure("dx")(subdomain_data=self.subdomains)

        self.config = configparser.ConfigParser()
        self.config.read(self.fm.getpath('fem', 'config'))

    @property
    def CONDUCTIVITY(self):
        """Yield ``(volume, conductivity)`` for every config section that
        defines both a ``volume`` and a ``conductivity`` option."""
        for section in self.config.sections():
            if self._is_conductive_volume(section):
                yield (self.config.getint(section, 'volume'),
                       self.config.getfloat(section, 'conductivity'))

    def _is_conductive_volume(self, section):
        """Whether `section` has both a ``volume`` and a ``conductivity``."""
        return (self.config.has_option(section, 'volume')
                and self.config.has_option(section, 'conductivity'))

    def __call__(self, csd_interpolator):
        """Solve for the potential generated by a CSD.

        Parameters
        ----------
        csd_interpolator : callable
            Called once with the ``(n_dofs, dim)`` array of DOF coordinates;
            must return the CSD value at each of them.

        Returns
        -------
        A FEM function (``self.fm.function()``) holding the potential, which
        is fixed to 0 on boundary points with ``z > 0``.

        Raises
        ------
        ValueError
            If no config section defines a conductive volume (there would be
            no stiffness term to assemble).
        """
        self.csd_f.vector()[:] = csd_interpolator(self.dof_coords)

        dirichlet_bc_gt = dolfin.DirichletBC(self.V,
                                             dolfin.Constant(0),
                                             (lambda x, on_boundary:
                                              on_boundary and x[2] > 0))
        test = self.fm.test_function()
        trial = self.fm.trial_function()
        potential = self.fm.function()

        dx = self.dx
        # Stiffness form: sum over subdomains i of
        # sigma_i * <grad(trial), grad(test)> integrated over dx(i).
        stiffness_terms = [dolfin.Constant(c)
                           * dolfin.inner(dolfin.grad(trial),
                                          dolfin.grad(test))
                           * dx(i)
                           for i, c in self.CONDUCTIVITY]
        if not stiffness_terms:
            raise ValueError('no section of the FEM config defines both '
                             '`volume` and `conductivity`')
        a = sum(stiffness_terms)
        L = self.csd_f * test * dx

        b = dolfin.assemble(L)
        A = dolfin.assemble(a)
        dirichlet_bc_gt.apply(A, b)

        solver = dolfin.KrylovSolver("cg", "ilu")
        solver.parameters["maximum_iterations"] = 10000
        solver.parameters["absolute_tolerance"] = 1E-8
        # solver.parameters["monitor_convergence"] = True  # Goes to Jupyter server output stream
        solver.solve(A, potential.vector(), b)

        return potential

In [None]:
FEM_CONFIG_GT = 'FEM/fem_configs/paper/finite_slice/fine_3/fem.ini'
# FEM_CONFIG_GT = 'FEM/fem_configs/paper/finite_slice/composite_3/grid_3d.ini'

In [None]:
%time fem_gt = ForwardModel(FEM_CONFIG_GT)

# kernel construction

In [None]:
# Electrode grids in use; further grids that were commented out:
# 'minus24', '0', '24'.
electrode_grid_names = ['minus12', '12']
# Names '<grid>_0_<i>' for i = 6, 12, ..., 60; the grid name varies fastest.
electrode_names = [f'{grid}_0_{i}'
                   for i in range(6, 61, 6)
                   for grid in electrode_grid_names]

In [None]:
class Electrode_kESI(object):
    """Electrode backed by a precomputed solution file (``.npz``).

    The file must contain the grid axes ``X``, ``Y``, ``Z``, the electrode
    ``LOCATION`` (x, y, z), ``BASE_CONDUCTIVITY`` and a
    ``CORRECTION_POTENTIAL`` array indexed by (X, Y, Z) grid positions.

    Parameters
    ----------
    filename : str
        Path of the ``.npz`` file.
    decimals_tolerance : int or None
        If not None, grid coordinates (stored and queried) are rounded to
        this many decimals before being matched.
    dx : float
        ``0.15 * dx`` is added to the distance in `base_potential`, keeping
        it finite at the electrode location when ``dx > 0``.
    """
    def __init__(self, filename, decimals_tolerance=None, dx=0):
        self.filename = filename
        self.decimals_tolerance = decimals_tolerance
        self.dx = dx
        with np.load(filename) as fh:
            self._X = self.round(fh['X'])
            self._Y = self.round(fh['Y'])
            self._Z = self.round(fh['Z'])
            self.x, self.y, self.z = fh['LOCATION']
            self.base_conductivity = fh['BASE_CONDUCTIVITY']

    def round(self, A):
        """Round `A` to `decimals_tolerance` decimals (no-op if None)."""
        if self.decimals_tolerance is None:
            return A
        return np.round(A, decimals=self.decimals_tolerance)

    def _grid_indices(self, grid, query):
        """Indices into `grid` of every value of `query`, in `query` order.

        ``np.intersect1d`` returns the sorted intersection, so its index
        array into `grid` follows ascending order.  Reordering it by the
        positions in `query` keeps the result aligned with the requested
        axis even when that axis is not ascending.

        Raises AssertionError if some `query` value is missing from `grid`
        (or `query` contains duplicates).
        """
        common_values, idx_grid, idx_query = np.intersect1d(
            grid, self.round(query), return_indices=True)
        assert len(common_values) == len(query)
        return idx_grid[np.argsort(idx_query)]

    def correction_potential(self, X, Y, Z):
        """Correction potential sampled on the grid spanned by X, Y, Z.

        X, Y, Z are 3D arrays whose axes are read from ``X[:, 0, 0]``,
        ``Y[0, :, 0]`` and ``Z[0, 0, :]`` (e.g. from
        ``np.meshgrid(..., indexing='ij')``).  Every axis value must be
        present in the stored grid.  The result is aligned with the query
        axes in the order given.
        """
        IDX_X = self._grid_indices(self._X, X[:, 0, 0])
        IDX_Y = self._grid_indices(self._Y, Y[0, :, 0])
        IDX_Z = self._grid_indices(self._Z, Z[0, 0, :])

        with np.load(self.filename) as fh:
            return fh['CORRECTION_POTENTIAL'][np.ix_(IDX_X, IDX_Y, IDX_Z)]

    def base_potential(self, X, Y, Z):
        """``1 / (4 pi sigma (r + 0.15 dx))`` with ``r`` the distance from
        the electrode location and ``sigma = base_conductivity``."""
        return (0.25 / (np.pi * self.base_conductivity)
                / (self.dx * 0.15
                   + np.sqrt(np.square(X - self.x)
                             + np.square(Y - self.y)
                             + np.square(Z - self.z))))


class Electrode_kCSD(object):
    """Electrode exposing only its location (``x``, ``y``, ``z``), read from
    the ``LOCATION`` entry of a solution file."""
    def __init__(self, filename):
        with np.load(filename) as fh:
            self.x, self.y, self.z = fh['LOCATION']

In [None]:
ELECTRODES = []
for name in electrode_names:
    electrode = Electrode_kESI(FILENAME_PATTERN.format(name=name))
    ELECTRODES.append({'NAME': name,
                       'X': electrode.x,
                       'Y': electrode.y,
                       'Z': electrode.z})
ELECTRODES = pd.DataFrame(ELECTRODES)

In [None]:
plt.scatter(ELECTRODES.X, ELECTRODES.Z)
plt.xlim(-H / 2, H / 2)
plt.ylim(0, H)

In [None]:
set(np.diff(sorted(set(ELECTRODES.Z * 1000_000))))

In [None]:
XX = electrode._X
YY = electrode._Y
ZZ = electrode._Z

In [None]:
assert electrode.base_conductivity == BRAIN_CONDUCTIVITY

In [None]:
ROMBERG_K = 6

In [None]:
dx = (XX[-1] - XX[0]) / (len(XX) - 1)
SRC_R_MAX = (2**(ROMBERG_K - 1)) * dx
ROMBERG_N = 2**ROMBERG_K + 1
print(SRC_R_MAX)

In [None]:
electrodes_kesi = [Electrode_kESI(FILENAME_PATTERN.format(name=name),
                                  decimals_tolerance=16,
                                  dx=dx)
                   for name in electrode_names]
electrodes_kcsd = [Electrode_kCSD(FILENAME_PATTERN.format(name=name)) for name in electrode_names]

In [None]:
H_Y = H/4
X = XX
Y = YY[abs(YY) <= H_Y + SRC_R_MAX + dx]
Z = ZZ

In [None]:
convolver = frr.ckESI_convolver([X,
                                 Y,
                                 Z],
                                [X,
                                 Y,
                                 Z])

In [None]:
sd = SRC_R_MAX / 3
def source(x, y, z):
    return common.SphericalSplineSourceKCSD(x, y, z,
#                                             [SRC_R_MAX],
#                                             [[1]],
                                             [sd, 3 * sd],
                                             [[1],
                                              [0,
                                               2.25 / sd,
                                               -1.5 / sd ** 2,
                                               0.25 / sd ** 3]],
                                             BRAIN_CONDUCTIVITY)

model_src = source(0, 0, 0)

In [None]:
SRC_IDX = ((abs(convolver.SRC_X) < abs(convolver.SRC_X.max()) - SRC_R_MAX)
                 & (abs(convolver.SRC_Y) <= H_Y)
                 & ((convolver.SRC_Z > SRC_R_MAX)
                    & (convolver.SRC_Z < H - SRC_R_MAX)))

In [None]:
SRC_IDX.sum(), SRC_IDX.shape

In [None]:
CSD_IDX = np.zeros(convolver.shape('CSD'),
                   dtype=bool)

# Kernels

Warning: no subtraction of kCSD out of the slice possible

In [None]:
%%time
# kCSD kernels: built with the `Electrode_kCSD` objects (location only).
kcsd_kernels = frr.ckESI_kernel_constructor(model_src,
                                            convolver,
                                            SRC_IDX,
                                            CSD_IDX,
                                            electrodes_kcsd,
                                            weights=ROMBERG_N)

In [None]:
%%time
# kESI kernels: same construction but with the `Electrode_kESI` objects,
# which additionally provide base and correction potentials.
kesi_kernels = frr.ckESI_kernel_constructor(model_src,
                                            convolver,
                                            SRC_IDX,
                                            CSD_IDX,
                                            electrodes_kesi,
                                            weights=ROMBERG_N)

# kernel analysis

In [None]:
n = len(kesi_kernels._pre_kernel)
FWD = kesi_kernels._pre_kernel * n
PHI = kcsd_kernels._pre_kernel * n
KERNEL = kcsd_kernels.kernel * n

EIGENVALUES, EIGENVECTORS = np.linalg.eigh(KERNEL)
EIGENVALUES, EIGENVECTORS = EIGENVALUES[::-1], EIGENVECTORS[:, ::-1]
LAMBDA = np.sqrt(EIGENVALUES)
EIGENSOURCES = np.matmul(PHI,
                         np.matmul(EIGENVECTORS,
                                   np.diag(1. / LAMBDA)))

In [None]:
KERNEL_KESI = kesi_kernels.kernel * n
EIGENVALUES_KESI, EIGENVECTORS_KESI = np.linalg.eigh(KERNEL_KESI)
EIGENVALUES_KESI, EIGENVECTORS_KESI = EIGENVALUES_KESI[::-1], EIGENVECTORS_KESI[:, ::-1]
LAMBDA_KESI = np.sqrt(EIGENVALUES_KESI)
EIGENSOURCES_KESI = np.matmul(FWD,
                              np.matmul(EIGENVECTORS_KESI,
                                        np.diag(1. / LAMBDA_KESI)))

In [None]:
KCSD_TO_KESI_ES = np.matmul(EIGENSOURCES_KESI.T, EIGENSOURCES)

In [None]:
assert (KCSD_TO_KESI_ES.max(axis=0) > 0).all()

In [None]:
for i, KESI_PROJECTION in enumerate(KCSD_TO_KESI_ES.T):
    _idx = np.argmax(abs(KESI_PROJECTION))
    if KESI_PROJECTION[_idx] > 0:
        EIGENSOURCES[:, i] += EIGENSOURCES_KESI[:, _idx]
    else:
        EIGENSOURCES[:, i] -= EIGENSOURCES_KESI[:, _idx]

In [None]:
del EIGENSOURCES_KESI

## crosskernels

In [None]:
%%time
_SRC = np.zeros(convolver.shape('SRC'))
CROSSKERNEL = np.empty(convolver.shape('CSD') + (len(KERNEL),))

for i, _SRC[SRC_IDX] in enumerate(PHI.T):
    print(i)
    CROSSKERNEL[:, :, :, i] = convolver.base_weights_to_csd(_SRC, model_src.csd, (ROMBERG_N,) * 3)
    
del _SRC

In [None]:
%%time
_SRC = np.zeros(convolver.shape('SRC'))
CROSSKERNEL_KESI = np.empty(convolver.shape('CSD') + (len(KERNEL_KESI),))

for i, _SRC[SRC_IDX] in enumerate(FWD.T):
    print(i)
    CROSSKERNEL_KESI[:, :, :, i] = convolver.base_weights_to_csd(_SRC, model_src.csd, (ROMBERG_N,) * 3)
    
del _SRC

In [None]:
del PHI, FWD

# IMAGES

In [None]:
%%time
GT_CSD = []

_SRC = np.zeros(convolver.shape('SRC'))
for i, _SRC[SRC_IDX] in enumerate(EIGENSOURCES.T):
    print(i)
    GT_CSD.append(convolver.base_weights_to_csd(_SRC, model_src.csd, (ROMBERG_N,) * 3))
    
del _SRC

In [None]:
%%time
IMAGE = []

for i, _CSD in enumerate(GT_CSD):
    print(i)
    _csd = si.RegularGridInterpolator(
                                  [getattr(convolver, f'CSD_{x}').flatten()
                                   for x in 'XYZ'],
                                  _CSD,
                                  bounds_error=False,
                                  fill_value=0)
    _v = fem_gt(_csd)
    IMAGE.append(np.array(list(map(_v, ELECTRODES.X, ELECTRODES.Y, ELECTRODES.Z))))


In [None]:
del _CSD, _csd, _v

In [None]:
for i, V in enumerate(IMAGE):
    plt.figure()
    plt.title(f'eigensource {i}')
    # Dotted guides at value 1 and at index i.
    plt.axhline(1, ls=':', color=cbf.BLACK)
    plt.axvline(i, ls=':', color=cbf.BLACK)
    # Electrode potentials projected on each method's kernel eigenvectors,
    # each divided by that method's own sqrt-eigenvalues (the kESI curve
    # must use LAMBDA_KESI, not the kCSD LAMBDA).
    plt.stem(np.matmul(V, EIGENVECTORS) / LAMBDA, label='kCSD')
    plt.plot(np.matmul(V, EIGENVECTORS_KESI) / LAMBDA_KESI, label='kESI')
    plt.legend(loc='best')

## Images error (reconstructed)

In [None]:
%%time
IMAGE_ERRORS = []
for i, (_CSD_GT, V) in enumerate(zip(GT_CSD, IMAGE)):
    print(i)
    row = {'ES': i,
           'GT_L1': abs(_CSD_GT).mean(),
           'GT_L2': np.sqrt(np.square(_CSD_GT).mean()),
           'GT_Linf': abs(_CSD_GT).max(),
           }
    IMAGE_ERRORS.append(row)
    for method, _KERNEL, _CROSSKERNEL in [
        ('kCSD', KERNEL, CROSSKERNEL),
        ('kESI', KERNEL_KESI, CROSSKERNEL_KESI),
        ]:
        _DIFF = np.matmul(_CROSSKERNEL,
                          np.linalg.solve(_KERNEL, V)) - _CSD_GT
        row[f'ERR_{method}_L1'] = abs(_DIFF).mean()
        row[f'ERR_{method}_L2'] = np.sqrt(np.square(_DIFF).mean())
        row[f'ERR_{method}_Linf'] = abs(_DIFF).max()
    print(row)
    
del _CSD_GT, _DIFF, _CROSSKERNEL, _KERNEL

In [None]:
IMAGE_ERRORS = pd.DataFrame(IMAGE_ERRORS)

In [None]:
IMAGE_ERRORS

In [None]:
for norm in ['L1', 'L2', 'Linf']:
    plt.figure()
    plt.title(norm)
    for method, color, ls in [('kCSD', cbf.SKY_BLUE, '-'),
                              ('kESI', cbf.VERMILION, ':'),
                             ]:
        plt.plot(IMAGE_ERRORS.ES,
                 IMAGE_ERRORS[f'ERR_{method}_{norm}'] / IMAGE_ERRORS[f'GT_{norm}'],
                 color=color,
                 ls=ls,
                 label=method)
    plt.legend(loc='best')
    plt.axhline(1, color=cbf.BLACK, ls=':')
    plt.axhline(0.5, color=cbf.BLACK, ls=':')
    plt.axhline(0.1, color=cbf.BLACK, ls=':')
    plt.axhline(0.05, color=cbf.BLACK, ls=':')
    plt.yscale('log')

## Images error (reconstructed regularized)

In [None]:
plt.plot(EIGENVALUES)
plt.plot(EIGENVALUES_KESI)
plt.yscale('log')

In [None]:
REGULARIZATION_PARAMETERS = np.logspace(11, 17, 5 * 6 + 1)

In [None]:
es_reconstructors = {method: verbose.VerboseFFR._CrossKernelReconstructor(
                                                     kesi._engine._LinearKernelSolver(
                                                         _KERNEL),
                                                     np.matmul(np.diag(_LAMBDA),
                                                               _EIGENVECTORS.T))
                     for method, _KERNEL, _EIGENVECTORS, _LAMBDA
                     in [('kCSD', KERNEL, EIGENVECTORS, LAMBDA),
                         ('kESI', KERNEL_KESI, EIGENVECTORS_KESI, LAMBDA_KESI),
                         ]
                     }

In [None]:
%%time
IMAGE_ERRORS_CV = []
for i, (_CSD_GT, V) in enumerate(zip(GT_CSD, IMAGE)):
    print(i)
    row = {'ES': i,
           'GT_L1': abs(_CSD_GT).mean(),
           'GT_L2': np.sqrt(np.square(_CSD_GT).mean()),
           'GT_Linf': abs(_CSD_GT).max(),
           }
    IMAGE_ERRORS_CV.append(row)

    for method, _KERNEL, _CROSSKERNEL in [
        ('kCSD', KERNEL, CROSSKERNEL),
        ('kESI', KERNEL_KESI, CROSSKERNEL_KESI),
        ]:
        _ERRORS = common.cv(es_reconstructors[method],
                            V,
                            REGULARIZATION_PARAMETERS)
        regularization_parameter = REGULARIZATION_PARAMETERS[np.argmin(_ERRORS)]

        _DIFF = np.matmul(_CROSSKERNEL,
                          np.linalg.solve(_KERNEL
                                          + regularization_parameter * np.identity(len(_KERNEL)),
                                          V)) - _CSD_GT
        row[f'ERR_{method}_L1'] = abs(_DIFF).mean()
        row[f'ERR_{method}_L2'] = np.sqrt(np.square(_DIFF).mean())
        row[f'ERR_{method}_Linf'] = abs(_DIFF).max()
    
del _CSD_GT, _DIFF, _CROSSKERNEL, _KERNEL

In [None]:
IMAGE_ERRORS_CV = pd.DataFrame(IMAGE_ERRORS_CV)

In [None]:
IMAGE_ERRORS_CV

In [None]:
for norm in ['L1', 'L2', 'Linf']:
    plt.figure()
    plt.title(norm)
    for method, color, ls in [('kCSD', cbf.SKY_BLUE, '-'),
                              ('kESI', cbf.VERMILION, ':'),
                             ]:
        plt.plot(IMAGE_ERRORS_CV.ES,
                 IMAGE_ERRORS_CV[f'ERR_{method}_{norm}'] / IMAGE_ERRORS_CV[f'GT_{norm}'],
                 color=color,
                 ls=ls,
                 label=method)
    plt.legend(loc='best')
    plt.axhline(1, color=cbf.BLACK, ls=':')
    plt.axhline(0.5, color=cbf.BLACK, ls=':')
    plt.axhline(0.1, color=cbf.BLACK, ls=':')
    plt.axhline(0.05, color=cbf.BLACK, ls=':')
    plt.yscale('log')

## compare approaches

In [None]:
for norm in ['L1', 'L2', 'Linf']:
    assert abs(IMAGE_ERRORS[f'GT_{norm}'] - IMAGE_ERRORS_CV[f'GT_{norm}']).max() == 0

In [None]:
for norm in ['L1', 'L2', 'Linf']:
    plt.figure()
    plt.title(norm)
    for method, color in [('kCSD', cbf.SKY_BLUE),
                          ('kESI', cbf.VERMILION),
                         ]:
        plt.plot(IMAGE_ERRORS.ES,
                 IMAGE_ERRORS[f'ERR_{method}_{norm}'] / IMAGE_ERRORS[f'GT_{norm}'],
                 color=color,
                 ls='-',
                 label=method)
        plt.plot(IMAGE_ERRORS_CV.ES,
                 IMAGE_ERRORS_CV[f'ERR_{method}_{norm}'] / IMAGE_ERRORS_CV[f'GT_{norm}'],
                 color=color,
                 ls=':',
                 label=f'{method} (CV)')

    plt.legend(loc='best')
    plt.axhline(1, color=cbf.BLACK, ls=':')
    plt.axhline(0.5, color=cbf.BLACK, ls=':')
    plt.axhline(0.1, color=cbf.BLACK, ls=':')
    plt.axhline(0.05, color=cbf.BLACK, ls=':')
    plt.yscale('log')