In [None]:
import datetime
import configparser
import os

import numpy as np
import pandas as pd

In [None]:
# --- Configuration -----------------------------------------------------------
# Identifiers of the model / mesh / configuration; together they select the
# input and output directories below.
MODEL_NAME = '4SM'
MESH_NAME = 'composite_fine'
CONFIG_NAME = 'comb'
DEGREE = 1  # presumably the FEM element degree; appears in every path below

K = 9
ROMBERG_K = 5  # not used in this part of the notebook

# Boundary points with z below this value get a zero-potential Dirichlet
# condition in the forward model (units as in the mesh; presumably metres).
GROUNDED_PLATE_AT = -0.088

# ``{name}`` stays a literal placeholder, to be filled later via str.format.
FILENAME_PATTERN = ('FEM/solutions/paper/'
                    f'{MODEL_NAME}/{MESH_NAME}_{DEGREE}/'
                    f'{CONFIG_NAME}_sampled/{K}/'
                    f'{{name}}.npz')
# Directory holding the eigensource (ES) files and their electrode images.
ES_PREFIX = ('ES/paper/'
             f'{MODEL_NAME}/{MESH_NAME}/{DEGREE}/{CONFIG_NAME}/')

In [None]:
# Electrode table (columns NAME, X, Y, Z): 5 columns A-E evenly spaced in X
# over [-6e-3, 6e-3] times 13 rows evenly spaced in Z over [0.046, 0.076],
# all at Y = 0.  Names are '<column>_<row:02d>', e.g. 'A_00' ... 'E_12'.
# Built with a comprehension so that no loop variables (x, z, ...) leak into
# the notebook namespace and ELECTRODES is only ever a DataFrame.
ELECTRODES = pd.DataFrame(
    [{'NAME': f'{col}_{row:02d}',
      'X': x,
      'Y': 0.0,
      'Z': z}
     for col, x in zip('ABCDE', np.linspace(-6e-3, 6e-3, 5))
     for row, z in enumerate(np.linspace(0.046, 0.076, 13))])

In [None]:
class Electrode_kESI(object):
    """kESI electrode model loaded from a sampled ``.npz`` solution file.

    The file is read for the keys ``X``, ``Y``, ``Z`` (sampling-grid
    coordinates along each axis), ``LOCATION`` (electrode position
    ``x, y, z``), ``BASE_CONDUCTIVITY`` and ``CORRECTION_POTENTIAL`` (a 3-D
    array indexed by grid positions along X, Y and Z).

    Parameters
    ----------
    filename : str
        Path of the ``.npz`` file.  ``CORRECTION_POTENTIAL`` is not cached; it
        is re-read on every :meth:`correction_potential` call (presumably to
        avoid keeping a large array in memory).
    decimals_tolerance : int or None
        If given, grid and query coordinates are rounded to this many decimals
        before being matched, to absorb floating-point noise.
    dx : float
        ``0.15 * dx`` is added to the electrode distance in
        :meth:`base_potential` (presumably a softening term that avoids the
        singularity at the electrode - TODO confirm the 0.15 factor).
    """

    def __init__(self, filename, decimals_tolerance=None, dx=0):
        self.filename = filename
        self.decimals_tolerance = decimals_tolerance
        self.dx = dx
        with np.load(filename) as fh:
            self._X = self.round(fh['X'])
            self._Y = self.round(fh['Y'])
            self._Z = self.round(fh['Z'])
            self.x, self.y, self.z = fh['LOCATION']
            self.base_conductivity = fh['BASE_CONDUCTIVITY']

    def round(self, A):
        """Round ``A`` to ``decimals_tolerance`` decimals (no-op when it is None)."""
        if self.decimals_tolerance is None:
            return A
        return np.round(A, decimals=self.decimals_tolerance)

    def _grid_indices(self, grid, query):
        """Return indices ``idx`` into ``grid`` with ``grid[idx[i]] == round(query[i])``.

        Raises AssertionError when a query value is missing from the grid or
        repeated (both make the match count differ from ``len(query)``).
        """
        query = self.round(query)
        common, idx_grid, idx_query = np.intersect1d(grid, query,
                                                     return_indices=True)
        assert len(common) == len(query), \
            'requested coordinates do not match the sampled grid'
        # ``intersect1d`` reports matches in sorted-value order; scatter them
        # back so position i refers to query[i] even for non-ascending queries.
        idx = np.empty(len(query), dtype=idx_grid.dtype)
        idx[idx_query] = idx_grid
        return idx

    def correction_potential(self, X, Y, Z):
        """Return the stored correction potential on the grid spanned by X, Y, Z.

        ``X``, ``Y`` and ``Z`` are 3-D coordinate arrays in ``ij`` meshgrid
        layout: the axis values are taken from ``X[:, 0, 0]``, ``Y[0, :, 0]``
        and ``Z[0, 0, :]``.  Every axis value must be present (after rounding)
        in the sampled grid; the result has the same axis order as the inputs.
        """
        idx_x = self._grid_indices(self._X, X[:, 0, 0])
        idx_y = self._grid_indices(self._Y, Y[0, :, 0])
        idx_z = self._grid_indices(self._Z, Z[0, 0, :])

        with np.load(self.filename) as fh:
            return fh['CORRECTION_POTENTIAL'][np.ix_(idx_x, idx_y, idx_z)]

    def base_potential(self, X, Y, Z):
        """Return ``1 / (4 pi sigma (r + 0.15 dx))``, r being the distance to the electrode.

        ``sigma`` is ``BASE_CONDUCTIVITY``; with ``dx == 0`` this is the
        point-source potential ``1 / (4 pi sigma r)``.  Broadcasts over X, Y, Z.
        """
        return (0.25 / (np.pi * self.base_conductivity)
                / (self.dx * 0.15
                   + np.sqrt(np.square(X - self.x)
                             + np.square(Y - self.y)
                             + np.square(Z - self.z))))

        
class Electrode_kCSD(object):
    """Electrode whose position is read from the ``LOCATION`` entry of an ``.npz`` file.

    Only the coordinates are loaded; they are exposed as ``x``, ``y``, ``z``.
    """

    def __init__(self, filename):
        with np.load(filename) as archive:
            location = archive['LOCATION']
        self.x, self.y, self.z = location

# FEM images

In [None]:
# Ground-truth forward model settings: reuse the mesh and degree chosen in the
# configuration cell; the FEM .ini file lives under FEM/fem_configs/.
FEM_DEGREE = DEGREE
FEM_MESH_NAME = MESH_NAME
FEM_CONFIG_GT = ('FEM/fem_configs/paper/'
                 f'{MODEL_NAME}/{FEM_MESH_NAME}_{FEM_DEGREE}/'
                 f'{CONFIG_NAME}.ini')

In [None]:
import dolfin
import FEM.fem_common as fc
import scipy.interpolate as si

In [None]:
class ForwardModel(object):
    """FEM forward model mapping a CSD to the electric potential it generates.

    Wraps an ``fc.FunctionManagerINI`` built from ``config`` (path of an FEM
    ``.ini`` file).  Calling the instance assembles and solves

        sum_i sigma_i * integral_{Omega_i} grad(v) . grad(w) dx
            = integral csd * w dx            (for all test functions w)

    for the potential ``v``.  ``Omega_i`` are the mesh subdomains, and
    ``sigma_i`` are their conductivities read from the FEM config.  A
    zero-potential Dirichlet condition is applied on boundary points whose z
    coordinate is below the module-level ``GROUNDED_PLATE_AT``.
    """
    # XXX: duplicated code with FEM classes
    def __init__(self, config):
        """Load the function space, mesh subdomains and conductivity config.

        Parameters
        ----------
        config : str
            Path of the FEM ``.ini`` file passed to ``fc.FunctionManagerINI``.
        """
        self.fm = fc.FunctionManagerINI(config)
        
        self.V = self.fm.function_space
        mesh = self.fm.mesh

        n = self.V.dim()
        d = mesh.geometry().dim()

        # Coordinates of every degree of freedom, reshaped in place to
        # (n_dofs, geometric_dim); presumably a guard for dolfin versions that
        # return a flat array - TODO confirm still needed.
        self.dof_coords = self.V.tabulate_dof_coordinates()
        self.dof_coords.resize((n, d))
        
        # Reused on every call to hold the CSD sampled at the DOF coordinates.
        self.csd_f = self.fm.function()
        
        
        # Drop a 5-character extension (presumably '.xdmf') to locate the
        # companion '<mesh>_subdomains.xdmf' file with the cell markers.
        mesh_filename = self.fm.getpath('fem', 'mesh')[:-5]
        with dolfin.XDMFFile(mesh_filename + '_subdomains.xdmf') as fh:
            # Markers of dimension 3, i.e. volume (cell) markers of a 3-D mesh.
            mvc = dolfin.MeshValueCollection("size_t", mesh, 3)
            fh.read(mvc, "subdomains")
            self.subdomains = dolfin.cpp.mesh.MeshFunctionSizet(mesh, mvc)
            # Measure tagged with the subdomains so that dx(i) integrates over
            # the cells marked i.
            self.dx = dolfin.Measure("dx")(subdomain_data=self.subdomains)
            
            
        # Conductivity settings, read from the file referenced by the
        # [fem] 'config' entry of the FEM configuration.
        self.config = configparser.ConfigParser()
        self.config.read(self.fm.getpath('fem', 'config'))

    @property
    def CONDUCTIVITY(self):
        """Yield ``(volume_id, conductivity)`` for every conductive config section.

        This is a generator re-created on each access, so it reflects the
        current state of ``self.config``.
        """
        for section in self.config.sections():
            if self._is_conductive_volume(section):
                yield (self.config.getint(section, 'volume'),
                       self.config.getfloat(section, 'conductivity'))

    def _is_conductive_volume(self, section):
        """True when ``section`` defines both a 'volume' id and a 'conductivity'."""
        return (self.config.has_option(section, 'volume')
                and self.config.has_option(section, 'conductivity')) 
        
    def __call__(self, csd_interpolator):
        """Solve for the potential generated by a CSD.

        Parameters
        ----------
        csd_interpolator : callable
            Called once with the ``(n_dofs, dim)`` array of DOF coordinates; it
            must return the CSD value at each of them.

        Returns
        -------
        dolfin Function
            The potential, from ``self.fm.function()``.

        Errors raised by dolfin propagate to the caller.
        ``forward_eigensources_to_images`` catches ``RuntimeError`` from this
        call (presumably solver failures).
        """
        # The CSD is set by point evaluation at the DOF locations.
        self.csd_f.vector()[:] = csd_interpolator(self.dof_coords)
        
        # Ground (v = 0) boundary points below GROUNDED_PLATE_AT; the global
        # is read at call time.
        dirichlet_bc_gt = dolfin.DirichletBC(self.V,
                                     dolfin.Constant(0),
                                     (lambda x, on_boundary:
                                      on_boundary and x[2] < GROUNDED_PLATE_AT))
        test = self.fm.test_function()
        trial = self.fm.trial_function()
        potential = self.fm.function()
        
        
        dx = self.dx
        # Bilinear form: piecewise-constant conductivity, one term per subdomain.
        # NOTE(review): if no config section defines a conductive volume, ``sum``
        # returns the integer 0 and the assembly below will fail.
        a = sum(dolfin.Constant(c)
                * dolfin.inner(dolfin.grad(trial),
                               dolfin.grad(test))
                * dx(i)
                for i, c
                in self.CONDUCTIVITY)
        # Right-hand side: the source integrated over the whole mesh.
        L = self.csd_f * test * dx
        
        b = dolfin.assemble(L)
        A = dolfin.assemble(a)
        dirichlet_bc_gt.apply(A, b)
        
        # Conjugate gradients with ILU preconditioning.
        solver = dolfin.KrylovSolver("cg", "ilu")
        solver.parameters["maximum_iterations"] = 10000
        solver.parameters["absolute_tolerance"] = 1E-8
        # solver.parameters["monitor_convergence"] = True  # Goes to Jupyter server output stream
        solver.solve(A, potential.vector(), b)
        
        return potential

In [None]:
try:
    if FEM_CONFIG_GT != _LAST_FEM_CONFIG_GT:
        raise NameError
except NameError:
    %time fem_gt = ForwardModel(FEM_CONFIG_GT)
    _LAST_FEM_CONFIG_GT = FEM_CONFIG_GT

In [None]:
def forward_eigensources_to_images(fem, ELECTRODES, eigensource_filenames):
    """Forward-model each eigensource and sample its potential at every electrode.

    Parameters
    ----------
    fem : callable
        Forward model (e.g. a ``ForwardModel`` instance).  It is called with a
        ``scipy.interpolate.RegularGridInterpolator`` of the eigensource CSD
        and must return a callable ``v(x, y, z)`` giving the potential.
    ELECTRODES : pandas.DataFrame
        Electrode table with ``NAME``, ``X``, ``Y`` and ``Z`` columns.
    eigensource_filenames : iterable of str
        ``.npz`` files, each holding the grid axes ``X``, ``Y``, ``Z`` and the
        gridded ``CSD`` values of one eigensource.

    Returns
    -------
    dict
        Maps each electrode name to a 1-D array with one potential per
        eigensource, in the order of ``eigensource_filenames``.  If ``fem``
        raises ``RuntimeError`` for an eigensource, its entries are NaN, and a
        ``RuntimeWarning`` naming the file is emitted instead of failing
        silently.
    """
    import warnings

    def _failed_potential(x, y, z):
        return np.nan

    # Electrode positions do not depend on the eigensource: extract them once.
    electrodes = [(row.NAME, row.X, row.Y, row.Z)
                  for _, row in ELECTRODES.iterrows()]
    potentials = {name: [] for name, *_ in electrodes}

    for filename in eigensource_filenames:
        with np.load(filename) as fh:
            XYZ = [fh[c] for c in 'XYZ']
            CSD = fh['CSD']
        # Outside the sampled grid the CSD is taken to be zero.
        csd_interpolator = si.RegularGridInterpolator(XYZ,
                                                      CSD,
                                                      bounds_error=False,
                                                      fill_value=0)
        try:
            v = fem(csd_interpolator)
        except RuntimeError as error:
            warnings.warn(f'Forward model failed for {filename!r} ({error}); '
                          'its electrode images are set to NaN.',
                          RuntimeWarning)
            v = _failed_potential

        for name, x, y, z in electrodes:
            potentials[name].append(v(x, y, z))

    return {name: np.array(a) for name, a in potentials.items()}

In [None]:
# For every combination of reconstruction method, electrode layout and source
# layout, forward-model the eigensource files ES00..ES{n_ele-1} (presumably one
# per electrode of the layout) and save their potentials at ELECTRODES.
for method, (electrodes, n_ele), sources in [(m, e, s)
                                             for m in ['kesi', 'kcsd']
                                             for e in [('wide', 65),
                                                       ('narrow', 20)]
                                             for s in ['wide', 'narrow']]:
    eigenimages = forward_eigensources_to_images(
        fem_gt,
        ELECTRODES,
        [os.path.join(ES_PREFIX,
                      f'{method}_{electrodes}__SRCS_{sources}_ES{i:02d}.npz')
         for i in range(n_ele)])
    np.savez_compressed(
        os.path.join(ES_PREFIX,
                     f'{method}_{electrodes}__SRCS_{sources}_ES_IMAGES_AT_ELECTRODES_{FEM_MESH_NAME}_{FEM_DEGREE}.npz'),
        **eigenimages)