In [None]:
import datetime
import configparser
import os

import numpy as np
import pandas as pd

import kesi
import kesi._verbose as verbose
import _common_new as common
import _fast_reciprocal_reconstructor as frr

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

import cbf

In [None]:
# Components of the path to the set of sampled FEM solutions used below.
MODEL_NAME = 'circular_slice'
MESH_NAME = 'finer'
DEGREE = 3
K = 9

# Path template of a single sampled solution.  `{{name}}` is escaped so that
# a literal `{name}` placeholder survives the f-string; it is filled in later
# with `FILENAME_PATTERN.format(name=...)`.
FILENAME_PATTERN = f'FEM/solutions/paper/{MODEL_NAME}/{MESH_NAME}/{DEGREE}/sampled/{K}/{{name}}.npz'

In [None]:
# Physical parameters of the model.
BRAIN_CONDUCTIVITY = 0.3  # S / m
SALINE_CONDUCTIVITY = 1.5  # S / m
H = 3e-4  # m

# Disabled settings, kept for reference:
# SD_GT = H / 16
# RESOLUTION = 32

# FEM

In [None]:
import dolfin
import FEM.fem_common as fc
import scipy.interpolate as si

In [None]:
class ForwardModel(object):
    """
    FEM forward model: given a CSD, solve for the potential it generates.

    The weak form
    ``sum_i c_i * inner(grad(u), grad(v)) * dx(i) == csd * v * dx``
    is assembled over the mesh subdomains (``i`` and ``c_i`` being the volume
    markers and conductivities read from the config), with the potential
    fixed to 0 on the part of the boundary where ``x[2] > 0``.
    """
    # XXX: duplicated code with FEM classes

    def __init__(self, mesh, degree, config):
        """
        Parameters
        ----------
        mesh : str
            Path to the mesh file.  The subdomain markers are read from
            the sibling file ``<mesh without extension>_subdomains.xdmf``.
        degree : int
            Degree of the 'CG' function space.
        config : str
            Path to the INI file.  Every section with both ``volume``
            and ``conductivity`` options defines a conductive volume.

        Raises
        ------
        FileNotFoundError
            If `config` could not be read.
        """
        self.config = configparser.ConfigParser()
        # ConfigParser.read() silently skips files it cannot open, which
        # would leave the model without any conductive volume and fail
        # obscurely at assembly time; fail early instead.
        if not self.config.read(config):
            raise FileNotFoundError(
                f'Unable to read model properties: {config!r}')

        self.fm = fc.FunctionManager(mesh, degree, 'CG')
        # NOTE: a disabled alternative constructed the manager as
        # fc.FunctionManagerINI(config) and obtained the mesh and config
        # paths from self.fm.getpath('fem', ...).

        self.V = self.fm.function_space
        fem_mesh = self.fm.mesh

        n = self.V.dim()
        d = fem_mesh.geometry().dim()

        # one row of coordinates per degree of freedom
        self.dof_coords = self.V.tabulate_dof_coordinates()
        self.dof_coords.resize((n, d))

        # reusable function holding the CSD sampled at the DOFs
        self.csd_f = self.fm.function()

        with dolfin.XDMFFile(self._subdomains_path(mesh)) as fh:
            mvc = dolfin.MeshValueCollection("size_t", fem_mesh, 3)
            fh.read(mvc, "subdomains")
            self.subdomains = dolfin.cpp.mesh.MeshFunctionSizet(fem_mesh, mvc)
            self.dx = dolfin.Measure("dx")(subdomain_data=self.subdomains)

    @staticmethod
    def _subdomains_path(mesh):
        """
        Path of the subdomain file accompanying the mesh file `mesh`.

        The extension of `mesh` (whatever its length) is replaced
        with the ``_subdomains.xdmf`` suffix.
        """
        return os.path.splitext(mesh)[0] + '_subdomains.xdmf'

    @property
    def CONDUCTIVITY(self):
        """
        Yield a ``(volume, conductivity)`` pair (int, float) for every
        config section which defines a conductive volume.
        """
        for section in self.config.sections():
            if self._is_conductive_volume(section):
                yield (self.config.getint(section, 'volume'),
                       self.config.getfloat(section, 'conductivity'))

    def _is_conductive_volume(self, section):
        """
        Tell whether `section` of the config has both the ``volume``
        and the ``conductivity`` option.
        """
        return (self.config.has_option(section, 'volume')
                and self.config.has_option(section, 'conductivity'))

    def __call__(self, csd_interpolator):
        """
        Solve for the potential generated by the given CSD.

        Parameters
        ----------
        csd_interpolator : callable
            Called with the (n, d) array of DOF coordinates; its result
            is assigned to the vector of the CSD function (one value
            per degree of freedom).

        Returns
        -------
        The potential, as a function obtained from the function manager
        (``self.fm.function()``).
        """
        self.csd_f.vector()[:] = csd_interpolator(self.dof_coords)

        # potential grounded at the part of the boundary where x[2] > 0
        dirichlet_bc_gt = dolfin.DirichletBC(self.V,
                                             dolfin.Constant(0),
                                             (lambda x, on_boundary:
                                              on_boundary and x[2] > 0))
        test = self.fm.test_function()
        trial = self.fm.trial_function()
        potential = self.fm.function()

        dx = self.dx
        # bilinear form: conductivity-weighted sum over the marked volumes
        a = sum(dolfin.Constant(c)
                * dolfin.inner(dolfin.grad(trial),
                               dolfin.grad(test))
                * dx(i)
                for i, c
                in self.CONDUCTIVITY)
        L = self.csd_f * test * dx

        b = dolfin.assemble(L)
        A = dolfin.assemble(a)
        dirichlet_bc_gt.apply(A, b)

        solver = dolfin.KrylovSolver("cg", "ilu")
        solver.parameters["maximum_iterations"] = 10000
        solver.parameters["absolute_tolerance"] = 1E-8
        # solver.parameters["monitor_convergence"] = True  # Goes to Jupyter server output stream
        solver.solve(A, potential.vector(), b)

        return potential

In [None]:
%%time
# Forward model used below to compute the potential generated by a CSD.
# NOTE(review): the mesh ('fine'), the degree (3) and the model name are
# hard-coded here instead of being derived from MESH_NAME ('finer'), DEGREE
# and MODEL_NAME - confirm that the different mesh is intended.
fem_gt = ForwardModel('FEM/meshes/meshes/circular_slice/fine.xdmf',
                      3,
                      'FEM/model_properties/circular_slice.ini')

# kernel construction

In [None]:
# Electrode grids included in the setup.
# Disabled alternatives, kept for reference: 'minus24', '0', '24'.
electrode_grid_names = ['minus12', '12']

# Names follow the '<grid>_0_<i>' pattern with i = 6, 12, ..., 60;
# for every i all the grids are listed before moving on to the next i.
electrode_names = [f'{grid}_0_{index}'
                   for index in range(6, 61, 6)
                   for grid in electrode_grid_names]

In [None]:
class Electrode(object):
    """
    Electrode backed by an ``.npz`` file with its sampled correction potential.
    """

    def __init__(self, filename):
        """
        Load the sampling grid, the location and the base conductivity.

        Parameters
        ----------

        filename : str
            Path to the sampled correction potential.
        """
        self.filename = filename
        with np.load(filename) as archive:
            grid = []
            for axis in 'XYZ':
                grid.append(archive[axis])
            location = archive['LOCATION']
            conductivity = archive['BASE_CONDUCTIVITY']

        self.SAMPLING_GRID = grid
        self.x, self.y, self.z = location
        self.base_conductivity = conductivity

    def correction_leadfield(self, X, Y, Z):
        """
        Correction of the leadfield of the electrode
        for violation of kCSD assumptions

        The samples are read from the file at every call
        (they are not kept in memory between the calls).

        Parameters
        ----------
        X, Y, Z : np.array
            Coordinate matrices of the same shape.
        """
        with np.load(self.filename) as archive:
            samples = archive['CORRECTION_POTENTIAL']
            return self._correction_leadfield(samples, [X, Y, Z])

    def _correction_leadfield(self, SAMPLES, XYZ):
        """
        Pick from `SAMPLES` the values at the points `XYZ`.

        The points are expected to be nodes of the sampling grid,
        thus a lookup replaces the time-consuming interpolation.
        """
        indices = self._sampling_grid_indices(XYZ)
        return SAMPLES[indices]

    def _sampling_grid_indices(self, XYZ):
        """
        Per-axis indices (a tuple of `np.searchsorted` results)
        of the coordinates `XYZ` in the sampling grid.
        """
        indices = []
        for nodes, coordinates in zip(self.SAMPLING_GRID, XYZ):
            indices.append(np.searchsorted(nodes, coordinates))
        return tuple(indices)

In [None]:
# Table of electrode names and locations (one row per electrode).
# NOTE: the loop variable `electrode` (the last electrode loaded) is reused
# by cells below, so it is deliberately left in the namespace.
ELECTRODES = []
for name in electrode_names:
    electrode = Electrode(FILENAME_PATTERN.format(name=name))
    ELECTRODES.append({'NAME': name,
                       'X': electrode.x,
                       'Y': electrode.y,
                       'Z': electrode.z})
ELECTRODES = pd.DataFrame(ELECTRODES)

In [None]:
# Electrode locations in the XZ plane, within -H/2 <= x <= H/2 and 0 <= z <= H.
plt.scatter(ELECTRODES.X, ELECTRODES.Z)
plt.xlim(-H / 2, H / 2)
plt.ylim(0, H)

In [None]:
# Distinct differences between consecutive unique Z coordinates of the
# electrodes, scaled by 1e6 (micrometres, assuming Z is in metres like H).
set(np.diff(sorted(set(ELECTRODES.Z * 1000_000))))

In [None]:
# Axes of the sampling grid, taken from `electrode`, i.e. the last electrode
# loaded while building ELECTRODES.
# NOTE(review): presumably all electrodes share the same grid - confirm.
XX, YY, ZZ = electrode.SAMPLING_GRID

In [None]:
# Sanity check: the base conductivity stored with the sampled solution
# matches BRAIN_CONDUCTIVITY (only the last electrode loaded is checked).
assert electrode.base_conductivity == BRAIN_CONDUCTIVITY

In [None]:
# Order of the Romberg quadrature: 2**ROMBERG_K + 1 samples are used
# (see ROMBERG_N).
ROMBERG_K = 6

In [None]:
# Mean spacing of the sampling grid along X (the grid step if it is regular).
dx = (XX[-1] - XX[0]) / (len(XX) - 1)
# Source radius: 2**(ROMBERG_K - 1) grid steps, so that the 2**ROMBERG_K + 1
# quadrature samples spaced by dx span 2 * SRC_R_MAX.
SRC_R_MAX = (2**(ROMBERG_K - 1)) * dx
ROMBERG_N = 2**ROMBERG_K + 1
print(SRC_R_MAX)

In [None]:
# Electrode objects (one per name), in the order of `electrode_names`.
electrodes = [Electrode(FILENAME_PATTERN.format(name=name))
              for name in electrode_names]

In [None]:
# Grid used by the convolver: the full sampling grid along X and Z,
# and along Y only the nodes with |y| <= H_Y + SRC_R_MAX + dx.
H_Y = H/4
X = XX
Y = YY[abs(YY) <= H_Y + SRC_R_MAX + dx]
Z = ZZ

In [None]:
# The same grid is passed for both arguments of the convolver.
# NOTE(review): presumably these are the potential and the CSD grids -
# confirm against _fast_reciprocal_reconstructor.
convolver = frr.Convolver([X, Y, Z],
                          [X, Y, Z])

In [None]:
# A third of the source radius, so that 3 * sd equals SRC_R_MAX.
sd = SRC_R_MAX / 3


def source(x, y, z):
    """
    Build a `common.SphericalSplineSourceKCSD` for the point (x, y, z).

    The source is parametrized with the module-level `sd` (read at call time)
    and BRAIN_CONDUCTIVITY.
    """
    # Disabled alternative, kept for reference:
    # nodes [SRC_R_MAX] with coefficients [[1]].
    # NOTE(review): presumably the nodes are radii delimiting the spline
    # pieces and each coefficient list defines a polynomial of the radius
    # (ascending powers) - confirm against _common_new.
    nodes = [sd, 3 * sd]
    coefficients = [[1],
                    [0,
                     2.25 / sd,
                     -1.5 / sd ** 2,
                     0.25 / sd ** 3]]
    return common.SphericalSplineSourceKCSD(x, y, z,
                                            nodes,
                                            coefficients,
                                            BRAIN_CONDUCTIVITY)


model_src = source(0, 0, 0)

In [None]:
# Mask of the source positions used: more than SRC_R_MAX away from
# the extreme X coordinate (|x| < |max x| - SRC_R_MAX), within |y| <= H_Y,
# and with SRC_R_MAX < z < H - SRC_R_MAX.
SRC_MASK = ((abs(convolver.SRC_X) < abs(convolver.SRC_X.max()) - SRC_R_MAX)
            & (abs(convolver.SRC_Y) <= H_Y)
            & ((convolver.SRC_Z > SRC_R_MAX)
               & (convolver.SRC_Z < H - SRC_R_MAX)))

In [None]:
# Number of selected source positions and the shape of the mask.
SRC_MASK.sum(), SRC_MASK.shape

In [None]:
# All nodes of the CSD grid are selected (the mask is all True).
CSD_MASK = np.ones(convolver.shape('CSD'),
                   dtype=bool)

# Kernels

In [None]:
from scipy.integrate import romb

# Integrating the rows of the identity matrix yields the weight Romberg's
# rule assigns to each of the ROMBERG_N samples (for unit sample spacing);
# the 2 ** -ROMBERG_K factor rescales them to an interval of unit length.
ROMBERG_WEIGHTS = romb(np.identity(ROMBERG_N)) * 2 ** -ROMBERG_K

convolver_interface = frr.ConvolverInterfaceIndexed(convolver,
                                                    model_src.csd,
                                                    ROMBERG_WEIGHTS,
                                                    SRC_MASK)

# The cross-kernel constructor is attached to the kernel constructor
# as its `create_crosskernel` attribute.
kernel_constructor = frr.KernelConstructor()
kernel_constructor.create_crosskernel = frr.CrossKernelConstructor(convolver_interface,
                                                                   CSD_MASK)

In [None]:
%%time
# kCSD: kernels built with the analytical potential of the model source only.
pae_kcsd = frr.PAE_Analytical(convolver_interface,
                              potential=model_src.potential)

PHI_KCSD = kernel_constructor.create_base_images_at_electrodes(electrodes,
                                                               pae_kcsd)
KERNEL_KCSD = kernel_constructor.create_kernel(PHI_KCSD)
# The cross-kernel is reshaped to the CSD grid shape with one trailing axis
# (presumably one entry per electrode, as it is later multiplied by the
# vector of potentials at the electrodes).
CROSSKERNEL_KCSD = kernel_constructor.create_crosskernel(PHI_KCSD).reshape(convolver.shape('CSD') + (-1,))

In [None]:
%%time
# kESI: same pipeline as for kCSD, but using the numerically corrected
# variant of the analytical potential.
pae_kesi = frr.PAE_AnalyticalCorrectedNumerically(convolver_interface,
                                                  potential=model_src.potential)

PHI_KESI = kernel_constructor.create_base_images_at_electrodes(electrodes,
                                                               pae_kesi)
KERNEL_KESI = kernel_constructor.create_kernel(PHI_KESI)
# The cross-kernel is reshaped to the CSD grid shape with one trailing axis.
CROSSKERNEL_KESI = kernel_constructor.create_crosskernel(PHI_KESI).reshape(convolver.shape('CSD') + (-1,))

# kernel analysis

In [None]:
%%time

EIGENVALUES_KCSD, EIGENVECTORS_KCSD = np.linalg.eigh(KERNEL_KCSD)
EIGENVALUES_KCSD, EIGENVECTORS_KCSD = EIGENVALUES_KCSD[::-1], EIGENVECTORS_KCSD[:, ::-1]
LAMBDA_KCSD = np.sqrt(EIGENVALUES_KCSD)
EIGENSOURCES_KCSD = np.matmul(PHI_KCSD,
                              np.matmul(EIGENVECTORS_KCSD,
                                        np.diag(1. / LAMBDA_KCSD)))

del PHI_KCSD

In [None]:
%%time

EIGENVALUES_KESI, EIGENVECTORS_KESI = np.linalg.eigh(KERNEL_KESI)
EIGENVALUES_KESI, EIGENVECTORS_KESI = EIGENVALUES_KESI[::-1], EIGENVECTORS_KESI[:, ::-1]
LAMBDA_KESI = np.sqrt(EIGENVALUES_KESI)
EIGENSOURCES_KESI = np.matmul(PHI_KESI,
                              np.matmul(EIGENVECTORS_KESI,
                                        np.diag(1. / LAMBDA_KESI)))

del PHI_KESI

In [None]:
# Entry [j, i] is the dot product of the j-th kESI eigensource
# and the i-th kCSD eigensource.
KCSD_TO_KESI_ES = np.matmul(EIGENSOURCES_KESI.T, EIGENSOURCES_KCSD)

In [None]:
# Every kCSD eigensource has a positive projection on some kESI eigensource.
# NOTE(review): the matching done in the next cell uses the largest
# *absolute* projection and handles its sign explicitly, while this check
# looks at the signed maximum - confirm this is the intended condition.
assert (KCSD_TO_KESI_ES.max(axis=0) > 0).all()

In [None]:
# Mix every kCSD eigensource with the kESI eigensource it projects on most
# strongly (by absolute value); the kESI eigensource is added or subtracted
# so that the signs of both agree.
# NOTE: EIGENSOURCES_KCSD is modified in place (no copy of the large array
# is made), thus this cell cannot be re-run on its own.
for i, KESI_PROJECTION in enumerate(KCSD_TO_KESI_ES.T):
    _idx = np.argmax(abs(KESI_PROJECTION))
    if KESI_PROJECTION[_idx] > 0:
        EIGENSOURCES_KCSD[:, i] += EIGENSOURCES_KESI[:, _idx]
    else:
        EIGENSOURCES_KCSD[:, i] -= EIGENSOURCES_KESI[:, _idx]

# The same array under a new name; the old names are removed.
EIGENSOURCES_MIXED = EIGENSOURCES_KCSD
del EIGENSOURCES_KCSD, EIGENSOURCES_KESI
# Halve the sum to obtain the mean of both eigensources.
EIGENSOURCES_MIXED *= 0.5

# IMAGES

In [None]:
%%time
# Ground-truth CSD profile of every mixed eigensource.
GT_CSD = []

_SRC = np.zeros(convolver.shape('SRC'))
# `_SRC[SRC_MASK]` is the loop target: at every iteration the consecutive
# column of EIGENSOURCES_MIXED is written into the masked nodes of the source
# grid (the remaining nodes stay 0).
for i, _SRC[SRC_MASK] in enumerate(EIGENSOURCES_MIXED.T):
    print(i)
    GT_CSD.append(convolver.base_weights_to_csd(_SRC, model_src.csd, (ROMBERG_N,) * 3))
    
del _SRC

In [None]:
%%time
# Potentials at the electrodes generated by every ground-truth CSD,
# obtained with the FEM forward model.
IMAGE = []

for i, _CSD in enumerate(GT_CSD):
    print(i)
    # Interpolator of the CSD defined on the CSD grid of the convolver;
    # 0 is returned for points outside of the grid.
    _csd = si.RegularGridInterpolator(
                                  [getattr(convolver, f'CSD_{x}').flatten()
                                   for x in 'XYZ'],
                                  _CSD,
                                  bounds_error=False,
                                  fill_value=0)
    _v = fem_gt(_csd)
    # The potential is evaluated at the location of every electrode.
    IMAGE.append(np.array(list(map(_v, ELECTRODES.X, ELECTRODES.Y, ELECTRODES.Z))))

del _CSD, _csd, _v

In [None]:
# One figure per image: the potentials projected on the eigenvectors of both
# kernels and divided by the respective LAMBDA (kCSD as stems, kESI as
# a line).  The dotted lines mark the value 1 and the index of the image.
for i, V in enumerate(IMAGE):
    plt.figure()
    plt.title(i)
    plt.axhline(1, ls=':', color=cbf.BLACK)
    plt.axvline(i, ls=':', color=cbf.BLACK)
    plt.stem(np.matmul(V, EIGENVECTORS_KCSD) / LAMBDA_KCSD)
    plt.plot(np.matmul(V, EIGENVECTORS_KESI) / LAMBDA_KESI)

## Image errors (reconstructed)

In [None]:
%%time
# Errors of the unregularized reconstruction of every ground-truth CSD.
IMAGE_ERRORS = []
for i, (_CSD_GT, V) in enumerate(zip(GT_CSD, IMAGE)):
    print(i)
    # Norms of the ground truth (L1 and L2 are means over the grid nodes).
    row = {'ES': i,
           'GT_L1': abs(_CSD_GT).mean(),
           'GT_L2': np.sqrt(np.square(_CSD_GT).mean()),
           'GT_Linf': abs(_CSD_GT).max(),
           }
    # The row is appended first and filled in afterwards (it is the same
    # dict object).
    IMAGE_ERRORS.append(row)
    for method, _KERNEL, _CROSSKERNEL in [
        ('kCSD', KERNEL_KCSD, CROSSKERNEL_KCSD),
        ('kESI', KERNEL_KESI, CROSSKERNEL_KESI),
        ]:
        # reconstruction (cross-kernel times the kernel solution)
        # minus the ground truth
        _DIFF = np.matmul(_CROSSKERNEL,
                          np.linalg.solve(_KERNEL, V)) - _CSD_GT
        row[f'ERR_{method}_L1'] = abs(_DIFF).mean()
        row[f'ERR_{method}_L2'] = np.sqrt(np.square(_DIFF).mean())
        row[f'ERR_{method}_Linf'] = abs(_DIFF).max()
    print(row)
    
del _CSD_GT, _DIFF, _CROSSKERNEL, _KERNEL

In [None]:
# Convert the list of rows to a table (one row per eigensource).
IMAGE_ERRORS = pd.DataFrame(IMAGE_ERRORS)

In [None]:
# Display the table of errors.
IMAGE_ERRORS

In [None]:
# Relative error (norm of the error divided by the norm of the ground truth)
# of both methods for every eigensource; one figure per norm.
for norm in ('L1', 'L2', 'Linf'):
    plt.figure()
    plt.title(norm)

    for method, color, ls in (('kCSD', cbf.SKY_BLUE, '-'),
                              ('kESI', cbf.VERMILION, ':')):
        _relative_error = (IMAGE_ERRORS[f'ERR_{method}_{norm}']
                           / IMAGE_ERRORS[f'GT_{norm}'])
        plt.plot(IMAGE_ERRORS.ES,
                 _relative_error,
                 color=color,
                 linestyle=ls,
                 label=method)

    plt.legend(loc='best')

    # dotted reference lines at 100%, 50%, 10% and 5% of the ground truth norm
    for _level in (1, 0.5, 0.1, 0.05):
        plt.axhline(_level, color=cbf.BLACK, linestyle=':')

    plt.yscale('log')

## Image errors (reconstructed, regularized)

In [None]:
# Eigenvalues of both kernels (in descending order) in the logarithmic scale.
plt.plot(EIGENVALUES_KCSD)
plt.plot(EIGENVALUES_KESI)
plt.yscale('log')

In [None]:
# Candidate regularization parameters: 1e11 to 1e17, 5 steps per decade
# (31 values, both ends included).
REGULARIZATION_PARAMETERS = np.logspace(11, 17, 5 * 6 + 1)

In [None]:
# Per-method reconstructor used for the cross-validation below.  It is built
# from a linear solver of the kernel and the diag(LAMBDA) @ EIGENVECTORS.T
# matrix passed as the cross-kernel.
# NOTE(review): presumably it reconstructs in the eigensource basis instead
# of the full CSD grid - confirm against kesi._verbose.
es_reconstructors = {method: verbose.VerboseFFR._CrossKernelReconstructor(
                                                     kesi._engine._LinearKernelSolver(
                                                         _KERNEL),
                                                     np.matmul(np.diag(_LAMBDA),
                                                               _EIGENVECTORS.T))
                     for method, _KERNEL, _EIGENVECTORS, _LAMBDA
                     in [('kCSD', KERNEL_KCSD, EIGENVECTORS_KCSD, LAMBDA_KCSD),
                         ('kESI', KERNEL_KESI, EIGENVECTORS_KESI, LAMBDA_KESI),
                         ]
                     }

In [None]:
%%time
# Errors of the regularized reconstruction of every ground-truth CSD;
# the regularization parameter is selected for every image and method
# separately.
IMAGE_ERRORS_CV = []
for i, (_CSD_GT, V) in enumerate(zip(GT_CSD, IMAGE)):
    print(i)
    # Norms of the ground truth (L1 and L2 are means over the grid nodes).
    row = {'ES': i,
           'GT_L1': abs(_CSD_GT).mean(),
           'GT_L2': np.sqrt(np.square(_CSD_GT).mean()),
           'GT_Linf': abs(_CSD_GT).max(),
           }
    # The row is appended first and filled in afterwards (it is the same
    # dict object).
    IMAGE_ERRORS_CV.append(row)

    for method, _KERNEL, _CROSSKERNEL in [
        ('kCSD', KERNEL_KCSD, CROSSKERNEL_KCSD),
        ('kESI', KERNEL_KESI, CROSSKERNEL_KESI),
        ]:
        # The candidate for which `common.cv` reports the lowest error
        # is selected.
        _ERRORS = common.cv(es_reconstructors[method],
                            V,
                            REGULARIZATION_PARAMETERS)
        regularization_parameter = REGULARIZATION_PARAMETERS[np.argmin(_ERRORS)]

        # reconstruction with the kernel regularized by adding
        # the parameter to its diagonal, minus the ground truth
        _DIFF = np.matmul(_CROSSKERNEL,
                          np.linalg.solve(_KERNEL
                                          + regularization_parameter * np.identity(len(_KERNEL)),
                                          V)) - _CSD_GT
        row[f'ERR_{method}_L1'] = abs(_DIFF).mean()
        row[f'ERR_{method}_L2'] = np.sqrt(np.square(_DIFF).mean())
        row[f'ERR_{method}_Linf'] = abs(_DIFF).max()
    
del _CSD_GT, _DIFF, _CROSSKERNEL, _KERNEL

In [None]:
# Convert the list of rows to a table (one row per eigensource).
IMAGE_ERRORS_CV = pd.DataFrame(IMAGE_ERRORS_CV)

In [None]:
# Display the table of errors of the regularized reconstruction.
IMAGE_ERRORS_CV

In [None]:
# Relative error (norm of the error divided by the norm of the ground truth)
# of both regularized methods for every eigensource; one figure per norm.
for norm in ('L1', 'L2', 'Linf'):
    plt.figure()
    plt.title(norm)

    for method, color, ls in (('kCSD', cbf.SKY_BLUE, '-'),
                              ('kESI', cbf.VERMILION, ':')):
        _relative_error = (IMAGE_ERRORS_CV[f'ERR_{method}_{norm}']
                           / IMAGE_ERRORS_CV[f'GT_{norm}'])
        plt.plot(IMAGE_ERRORS_CV.ES,
                 _relative_error,
                 color=color,
                 linestyle=ls,
                 label=method)

    plt.legend(loc='best')

    # dotted reference lines at 100%, 50%, 10% and 5% of the ground truth norm
    for _level in (1, 0.5, 0.1, 0.05):
        plt.axhline(_level, color=cbf.BLACK, linestyle=':')

    plt.yscale('log')

## compare approaches

In [None]:
# Sanity check: both tables refer to exactly the same ground truth norms,
# so their relative errors are directly comparable.
for norm in ['L1', 'L2', 'Linf']:
    assert abs(IMAGE_ERRORS[f'GT_{norm}'] - IMAGE_ERRORS_CV[f'GT_{norm}']).max() == 0

In [None]:
# Relative errors of the unregularized (solid lines) and the regularized
# (dotted lines, labelled "CV") reconstructions; one figure per norm,
# one color per method.
for norm in ('L1', 'L2', 'Linf'):
    plt.figure()
    plt.title(norm)

    for method, color in (('kCSD', cbf.SKY_BLUE),
                          ('kESI', cbf.VERMILION)):
        _relative_error = (IMAGE_ERRORS[f'ERR_{method}_{norm}']
                           / IMAGE_ERRORS[f'GT_{norm}'])
        plt.plot(IMAGE_ERRORS.ES,
                 _relative_error,
                 color=color,
                 linestyle='-',
                 label=method)

        _relative_error_cv = (IMAGE_ERRORS_CV[f'ERR_{method}_{norm}']
                              / IMAGE_ERRORS_CV[f'GT_{norm}'])
        plt.plot(IMAGE_ERRORS_CV.ES,
                 _relative_error_cv,
                 color=color,
                 linestyle=':',
                 label=f'{method} (CV)')

    plt.legend(loc='best')

    # dotted reference lines at 100%, 50%, 10% and 5% of the ground truth norm
    for _level in (1, 0.5, 0.1, 0.05):
        plt.axhline(_level, color=cbf.BLACK, linestyle=':')

    plt.yscale('log')