# Requirements

## Environment

1. Anaconda Python distribution (tested with _Anaconda3-2021.11-Linux-x86\_64.sh_, _conda v. 4.12.0_).
2. Jupyter server (see _extras/jupyter\_server.sh_ for details).
3. Anaconda environments (run _setup\_conda\_envs.sh_).
4. gmsh (not necessary if you already have meshes in either MSH or XDMF format).


## Setup

### Geometry: mesh

You need to have a mesh in XDMF format.  Try:

    conda activate kesi37
    cd extras/FEM/_meshes_new
    snakemake finite_slice_fine.xdmf -j 1
    
It may take a while.  Sadly, errors are to be expected.  If the build fails due to an error in the _MSH_ rule, you may need to remove the intermediate _finite\_slice\_fine.msh_ file and remove the `-optimize_netgen` _gmsh_ flag from the _Snakefile_ to disable mesh optimization, which may crash _gmsh_ with a segmentation fault.

Now you should have the following files:
- _finite\_slice\_fine.xdmf_,
- _finite\_slice\_fine.h5_,
- _finite\_slice\_fine\_boundaries.xdmf_,
- _finite\_slice\_fine\_boundaries.h5_,
- _finite\_slice\_fine\_subdomains.xdmf_,
- _finite\_slice\_fine\_subdomains.h5_

(The _\*.xdmf_ files are headers for the _\*.h5_ files.)  The files are derived from the _finite\_slice\_fine.msh_ mesh saved in the _gmsh_ format, which is created from the blueprint _finite\_slice\_fine.geo_.  The blueprint itself is derived from the template _finite\_slice\_fine.geo.template_.

The template may also be used to derive meshes of other resolutions:
- _finite\_slice\_finest.xdmf_,
- _finite\_slice\_finer.xdmf_,
- _finite\_slice.xdmf_,
- _finite\_slice\_coarse.xdmf_,
- _finite\_slice\_coarser.xdmf_, and
- _finite\_slice\_coarsest.xdmf_.

There are also blueprints for _finite\_slice\_composite.xdmf_.

> There are also templates for single sphere geometries:
> - _one\_sphere\_composite.geo.template_,
> - _one\_sphere\_plain.geo.template_, and
> - _one\_sphere\_uniform\_cortex.geo.template_.
>
> The recommended meshes derived from them are:
> - _one\_sphere\_composite\_finest.xdmf_,
> - _one\_sphere\_plain\_finer.xdmf_, 
> - _one\_sphere\_plain\_finest.xdmf_, 
> - _one\_sphere\_uniform\_cortex\_fine.xdmf_,
> - _one\_sphere\_uniform\_cortex\_finer.xdmf_, and
> - _one\_sphere\_uniform\_cortex\_finest.xdmf_.
>
> For the four-sphere geometry, there are blueprints from which you can derive:
> - _four\_spheres\_separate\_cortex\_composite.xdmf_,
> - _four\_spheres\_separate\_cortex\_composite\_fine.xdmf_, and
> - _four\_spheres\_plain.xdmf_.

### Geometry: properties

Every mesh requires additional information, like the conductivity of its compartments.  Such information is stored in the following files:
- _extras/FEM/\_meshes\_new/finite\_slice.ini_,
- _extras/FEM/\_meshes\_new/one\_sphere.ini_, and
- _extras/FEM/\_meshes\_new/four\_spheres.ini_.

The format of such a file is:

    [<compartment name>]
    volume = <volume number>
    conductivity = <conductivity in SI units>
    
for a compartment and:

    [<surface name>]
    surface = <surface number>
    
for a boundary.  Additional information may be provided, like radius, thickness, or the conductivity associated with the external surface (for the subtraction method).
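
The compartment and boundary sections above can be read with the standard `configparser` module.  A minimal sketch, assuming a hypothetical properties file (the section names and numbers are made up, not the contents of _finite\_slice.ini_); it selects conductive volumes the same way as the `CONDUCTIVITY` property of the forward model defined later in this notebook, i.e. sections having both `volume` and `conductivity` options:

```python
import configparser

# Hypothetical properties file; section names and values are made up.
PROPERTIES = """
[slice]
volume = 1
conductivity = 0.3

[saline]
volume = 2
conductivity = 1.5

[top]
surface = 3
"""

properties = configparser.ConfigParser()
properties.read_string(PROPERTIES)


def conductive_volumes(config):
    """Map volume numbers to conductivities (SI units) of compartments."""
    return {config.getint(section, 'volume'):
                config.getfloat(section, 'conductivity')
            for section in config.sections()
            if config.has_option(section, 'volume')
            and config.has_option(section, 'conductivity')}


def surfaces(config):
    """Map boundary names to their surface numbers."""
    return {section: config.getint(section, 'surface')
            for section in config.sections()
            if config.has_option(section, 'surface')}


print(conductive_volumes(properties))  # {1: 0.3, 2: 1.5}
print(surfaces(properties))            # {'top': 3}
```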

### FEM

In _extras/FEM/fem\_configs_ you can find exemplary configurations of the FEM solver in the _fem_ section of the INI files:

    [fem]
    mesh = _meshes_new/finite_slice_composite.xdmf
    config = _meshes_new/finite_slice.ini
    element_type = CG
    degree = 3
    solution_metadata_filename = solutions/paper/finite_slice/composite_3/grid_3d.ini

`mesh` is the path to the mesh header, while `config` is the path to the properties of its compartments and boundaries.  `element_type` (`CG` for _continuous Galerkin_) and `degree` configure the finite elements used by FEM.
`solution_metadata_filename` defines both the directory where the solutions are to be stored and the INI file containing their metadata.


### Electrodes

Other sections of the INI files have the format:

    [<electrode name>]
    filename = <solution filename.h5>
    x = <X coordinate in meters>
    y = <Y coordinate in meters>
    z = <Z coordinate in meters>
    
and define the name and spatial location of an electrode.  They also define the (relative) path to the file where the leadfield correction is to be stored.
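
Since every section other than _fem_ describes an electrode, the electrode names, files and locations can be collected with `configparser` as well.  A minimal sketch, assuming a hypothetical configuration string written in the format above:

```python
import configparser

# Hypothetical FEM configuration with two electrodes.
DEMO_CONFIG = """
[fem]
mesh = _meshes_new/finite_slice_fine.xdmf
config = _meshes_new/finite_slice.ini
element_type = CG
degree = 3
solution_metadata_filename = solutions/tutorial/finite_slice/fine_3/demo.ini

[first]
filename = demo/first.h5
x = 0
y = 0
z = 0.5e-4

[second]
filename = demo/second.h5
x = 0.5e-4
y = 0
z = 1.5e-4
"""

demo_config = configparser.ConfigParser()
demo_config.read_string(DEMO_CONFIG)

# Every section except [fem] is an electrode: name -> (filename, (x, y, z)).
electrode_locations = {
    name: (demo_config.get(name, 'filename'),
           tuple(demo_config.getfloat(name, c) for c in 'xyz'))
    for name in demo_config.sections()
    if name != 'fem'}

for name, (filename, location) in electrode_locations.items():
    print(name, filename, location)
```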

# Preprocessing

Let's define a simple setup of three point electrodes:

    [fem]
    mesh = _meshes_new/finite_slice_fine.xdmf
    config = _meshes_new/finite_slice.ini
    element_type = CG
    degree = 3
    solution_metadata_filename = solutions/tutorial/finite_slice/fine_3/demo.ini
    
    [first]
    filename = demo/first.h5
    x = 0
    y = 0
    z = 0.5e-4
    
    [second]
    filename = demo/second.h5
    x = 0.5e-4
    y = 0
    z = 1.5e-4
    
    [third]
    filename = demo/third.h5
    x = 0.5e-4
    y = -0.5e-4
    z = 2.5e-4
    
Write the setup as _extras/FEM/fem\_configs/tutorial/finite\_slice/fine\_3/demo.ini_.
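
Instead of typing the file by hand, the same setup may be generated with `configparser`.  A minimal sketch reproducing the configuration above; for illustration it writes to a temporary directory, so `path` would have to be changed to the location given above to use it for real:

```python
import configparser
import os
import tempfile

# Electrode name -> (x, y, z) in meters, as in the setup above.
ELECTRODES = {'first': (0, 0, 0.5e-4),
              'second': (0.5e-4, 0, 1.5e-4),
              'third': (0.5e-4, -0.5e-4, 2.5e-4)}

setup = configparser.ConfigParser()
setup['fem'] = {'mesh': '_meshes_new/finite_slice_fine.xdmf',
                'config': '_meshes_new/finite_slice.ini',
                'element_type': 'CG',
                'degree': '3',
                'solution_metadata_filename':
                    'solutions/tutorial/finite_slice/fine_3/demo.ini'}
for name, (x, y, z) in ELECTRODES.items():
    # configparser stores option values as strings
    setup[name] = {'filename': f'demo/{name}.h5',
                   'x': repr(x), 'y': repr(y), 'z': repr(z)}

# Temporary location, for illustration only.
path = os.path.join(tempfile.mkdtemp(), 'demo.ini')
with open(path, 'w') as fh:
    setup.write(fh)
```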
 
For every electrode, solve for the leadfield correction:

    cd extras
    python paper_solve_slice_on_plate.py -c FEM/fem_configs/tutorial/finite_slice/fine_3/demo.ini -n first -o FEM/solutions/tutorial/finite_slice/fine_3/demo/first.ini
    python paper_solve_slice_on_plate.py -c FEM/fem_configs/tutorial/finite_slice/fine_3/demo.ini -n second -o FEM/solutions/tutorial/finite_slice/fine_3/demo/second.ini
    python paper_solve_slice_on_plate.py -c FEM/fem_configs/tutorial/finite_slice/fine_3/demo.ini -n third -o FEM/solutions/tutorial/finite_slice/fine_3/demo/third.ini
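
The three commands above differ only in the electrode name, so they may also be generated in a loop.  A minimal sketch using `subprocess`; the script name and its `-c`, `-n` and `-o` options are copied verbatim from the commands above, and the commands are only printed unless the (hypothetical) `RUN` flag is set to `True`.  It is meant to be run from the _extras_ directory:

```python
import subprocess

SOLVER_CONFIG = 'FEM/fem_configs/tutorial/finite_slice/fine_3/demo.ini'
OUTPUT_DIR = 'FEM/solutions/tutorial/finite_slice/fine_3/demo'
RUN = False  # set to True to actually solve (slow)


def solve_command(name):
    """Command solving for the leadfield correction of electrode `name`."""
    return ['python', 'paper_solve_slice_on_plate.py',
            '-c', SOLVER_CONFIG,
            '-n', name,
            '-o', f'{OUTPUT_DIR}/{name}.ini']


for name in ['first', 'second', 'third']:
    command = solve_command(name)
    print(' '.join(command))
    if RUN:
        # Raises CalledProcessError if the solver fails.
        subprocess.run(command, check=True)
```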
    
Then join the scattered metadata into a single INI file.

In [None]:
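import configparser

metadata = configparser.ConfigParser()
# Merge the FEM configuration with the metadata of every solution
# (later files add to, and may override, sections read from earlier ones).
for filename in ['FEM/fem_configs/tutorial/finite_slice/fine_3/demo.ini',
                 'FEM/solutions/tutorial/finite_slice/fine_3/demo/first.ini',
                 'FEM/solutions/tutorial/finite_slice/fine_3/demo/second.ini',
                 'FEM/solutions/tutorial/finite_slice/fine_3/demo/third.ini',
                 ]:
    metadata.read(filename)

# Write the merged metadata once, closing the file properly.
with open('FEM/solutions/tutorial/finite_slice/fine_3/demo.ini', 'w') as fh:
    metadata.write(fh)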

Sample the solutions on an N×N×N grid, where `N = 2**K + 1` (here `K = 9`, set with the `-k` option):

    mkdir -p FEM/solutions/tutorial/finite_slice/fine_3/demo_sampled/9/
    python paper_sample_slice_solution.py -c FEM/solutions/tutorial/finite_slice/fine_3/demo.ini -n first -k 9 -o FEM/solutions/tutorial/finite_slice/fine_3/demo_sampled/9/first.npz -r 0.0003
    python paper_sample_slice_solution.py -c FEM/solutions/tutorial/finite_slice/fine_3/demo.ini -n second -k 9 -o FEM/solutions/tutorial/finite_slice/fine_3/demo_sampled/9/second.npz -r 0.0003
    python paper_sample_slice_solution.py -c FEM/solutions/tutorial/finite_slice/fine_3/demo.ini -n third -k 9 -o FEM/solutions/tutorial/finite_slice/fine_3/demo_sampled/9/third.npz -r 0.0003
    
It may take several hours.
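
The long run time is plausibly dominated by the grid size: for `K = 9` the grid has `N = 2**9 + 1 = 513` nodes per axis, i.e. over 135 million sampling points per electrode.  A back-of-the-envelope check, assuming each sampled value is stored as a 64-bit float:

```python
K = 9
N = 2 ** K + 1
n_points = N ** 3
n_bytes = 8 * n_points  # assuming float64 samples

print(N)              # 513
print(n_points)       # 135005697
print(n_bytes / 1e9)  # about 1.08 (GB per sampled array)
```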

# Kernel construction

## Electrode object

An electrode object contains information about the spatial location of the electrode (`.x`, `.y` and `.z` attributes), which is the absolute minimum required by kESI (in this case: kCSD with the base function known in the potential space).  It may also provide additional information about:
- the base leadfield (`.base_potential()` method), which enables kCSD for an arbitrary base function in the CSD space,
- the leadfield correction (`.correction_potential()` method), which enables kESI for setups violating kCSD assumptions, and
- the base conductivity (`.base_conductivity` attribute) assumed when calculating the leadfield correction.

In [None]:
import numpy as np

class Electrode(object):
    def __init__(self, filename, decimals_tolerance=None, dx=0):
        """
        Parameters
        ----------
        
        filename : str
            Path to the sampled correction potential.
            
        decimals_tolerance : int
            Precision of coordinate comparison
            in the `.correction_potential()` method.
            
        dx : float
            Integration step used to calculate a regularization
            parameter of the `.base_potential()` method.
        """
        self.filename = filename
        self.decimals_tolerance = decimals_tolerance
        self.dx = dx
        with np.load(filename) as fh:
            self._X = self.round(fh['X'])
            self._Y = self.round(fh['Y'])
            self._Z = self.round(fh['Z'])
            self.x, self.y, self.z = fh['LOCATION']
            self.base_conductivity = fh['BASE_CONDUCTIVITY']

    @property
    def _epsilon(self):
        """
        Regularization parameter of the `.base_potential()` method.
        
        Note
        ----
        
        The 0.15 factor choice has been based on a toy numerical experiment.
        Further, more rigorous experiments are definitely recommended.
        """
        return 0.15 * self.dx
    
    def round(self, A):
        if self.decimals_tolerance is None:
            return A
        return np.round(A, decimals=self.decimals_tolerance)

    def correction_potential(self, X, Y, Z):
        """
        Parameters
        ----------
        X, Y, Z : np.array
            Coordinate matrices with matrix indexing.
            Coordinates are expected to be - respectively -
            from `._X`, `._Y` and `._Z` attributes.
            May be obtained with
            `X, Y, Z = np.meshgrid(..., indexing='ij')`.
        """
        _X, IDX_X, _ = np.intersect1d(self._X, self.round(X[:, 0, 0]), return_indices=True)
        assert len(_X) == np.shape(X)[0]
        _Y, IDX_Y, _ = np.intersect1d(self._Y, self.round(Y[0, :, 0]), return_indices=True)
        assert len(_Y) == np.shape(Y)[1]
        _Z, IDX_Z, _ = np.intersect1d(self._Z, self.round(Z[0, 0, :]), return_indices=True)
        assert len(_Z) == np.shape(Z)[2]

        with np.load(self.filename) as fh:
            return fh['CORRECTION_POTENTIAL'][np.ix_(IDX_X, IDX_Y, IDX_Z)]

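    def base_potential(self, X, Y, Z):
        """
        Base leadfield of the electrode: the potential of a unit point
        current source located at the electrode in an infinite homogeneous
        medium of conductivity `.base_conductivity`, regularized with
        `._epsilon` to avoid the singularity at the electrode location.
        """
        return (0.25 / (np.pi * self.base_conductivity)
                / (self._epsilon
                   + np.sqrt(np.square(X - self.x)
                             + np.square(Y - self.y)
                             + np.square(Z - self.z))))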

In [None]:
electrodes = [Electrode(f'FEM/solutions/tutorial/finite_slice/fine_3/demo_sampled/9/{name}.npz')
              for name in ['first', 'second', 'third']]
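
For reference, `Electrode` reads the arrays `X`, `Y`, `Z`, `LOCATION`, `BASE_CONDUCTIVITY` and `CORRECTION_POTENTIAL` from the _.npz_ file.  The self-contained sketch below writes a tiny synthetic file with the same keys (made-up values, not a real solution) and repeats the lookup performed by `.correction_potential()`: requested coordinates are matched against the stored ones with `np.intersect1d(..., return_indices=True)` and the sub-array is extracted with `np.ix_`.  (With `decimals_tolerance` set, `Electrode` rounds both sides first, which makes the exact matching robust to floating-point noise.)

```python
import os
import tempfile

import numpy as np

# Synthetic sampled solution on a 5 x 4 x 3 grid (made-up values).
X = np.linspace(-1e-4, 1e-4, 5)
Y = np.linspace(-1e-4, 1e-4, 4)
Z = np.linspace(0, 3e-4, 3)
CORRECTION = np.arange(5 * 4 * 3, dtype=float).reshape(5, 4, 3)

filename = os.path.join(tempfile.mkdtemp(), 'synthetic.npz')
np.savez(filename,
         X=X, Y=Y, Z=Z,
         LOCATION=np.array([0, 0, 1.5e-4]),
         BASE_CONDUCTIVITY=0.3,
         CORRECTION_POTENTIAL=CORRECTION)

# Request the correction on a sub-grid: every other X node, all Y, last Z.
_x, _y, _z = X[::2], Y, Z[-1:]
with np.load(filename) as fh:
    _, IDX_X, _ = np.intersect1d(fh['X'], _x, return_indices=True)
    _, IDX_Y, _ = np.intersect1d(fh['Y'], _y, return_indices=True)
    _, IDX_Z, _ = np.intersect1d(fh['Z'], _z, return_indices=True)
    SUBGRID = fh['CORRECTION_POTENTIAL'][np.ix_(IDX_X, IDX_Y, IDX_Z)]

print(SUBGRID.shape)  # (3, 4, 1)
```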

## FRR: Fast Reciprocal Reconstructor 

The `_fast_reciprocal_reconstructor` experimental module contains tools which allow for fast construction of (cross-)kernels.  They use discrete convolution for high-throughput integration.

### Convolver object

The convolver is the engine of the FRR tools.  It is used to:
- integrate leadfields weighted by a CSD profile,
- obtain CSD profile of a mixture of base functions.

The convolver operates on three regular 3D grids of coordinates:
- _POT_ grid used for leadfield (_reciprocal potential_) integration,
- _CSD_ grid used for CSD profile calculation,
- _SRC_ grid used for placing the centroids of base functions.

The _SRC_ grid is the intersection (in the set-theoretic sense) of the _POT_ and _CSD_ grids, thus these two grids define the convolver unequivocally.
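
To illustrate the intersection, a plain NumPy sketch (not the `ckESI_convolver` implementation).  The hypothetical grids use integer node indices to avoid floating-point comparison issues; for rectilinear grids the intersection can be taken axis by axis, so the _SRC_ grid is the Cartesian product of the per-axis intersections.  (In this notebook the _POT_ and _CSD_ grids are identical, so the _SRC_ grid coincides with both.)

```python
import numpy as np

# Node indices along one axis (hypothetical grids).
POT_NODES = np.arange(0, 9)      # 0, 1, 2, ..., 8
CSD_NODES = np.arange(0, 10, 2)  # 0, 2, 4, 6, 8

SRC_NODES = np.intersect1d(POT_NODES, CSD_NODES)
print(SRC_NODES)  # [0 2 4 6 8]

# In 3D: intersect each axis separately.
POT_GRID = [np.arange(0, 9), np.arange(0, 5), np.arange(0, 3)]
CSD_GRID = [np.arange(0, 10, 2), np.arange(0, 5), np.arange(0, 3, 2)]
SRC_GRID = [np.intersect1d(p, c) for p, c in zip(POT_GRID, CSD_GRID)]
print([len(g) for g in SRC_GRID])  # [5, 5, 2]
```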

In [None]:
from _fast_reciprocal_reconstructor import ckESI_convolver

In [None]:
_X = electrodes[0]._X
_Y = electrodes[0]._Y
_Z = electrodes[0]._Z

_pot_mesh = [_X, _Y, _Z]
_csd_mesh = [_X, _Y, _Z]

convolver = ckESI_convolver(_pot_mesh, _csd_mesh)

In [None]:
print(convolver.csd_shape)

In [None]:
for name in ['POT', 'CSD', 'SRC']:
    print(f'{name} mesh')
    print('  shape:', convolver.shape(name))
    print('  spacing:', convolver.ds(name))

Open 3D meshgrids may be accessed as `.{NAME}_MESH` attributes, where `{NAME}` is the name of the mesh.
Components of each meshgrid may be accessed as `.{NAME}_{C}` attributes, where `{C}` is the name of the coordinate.
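
The open meshgrids appear to be of the same kind as those produced by `np.meshgrid(..., indexing='ij', sparse=True)`: each component keeps a single non-singleton axis and the full 3D grid appears only after broadcasting.  This is why expressions like the `SRC_IDX` mask below yield full 3D boolean arrays, and why slicing such as `convolver.CSD_X[::4, :, :]` and `.flatten()` make sense.  A small NumPy illustration with made-up coordinates:

```python
import numpy as np

x = np.linspace(-1, 1, 4)  # -1, -1/3, 1/3, 1
y = np.linspace(-1, 1, 3)  # -1, 0, 1
z = np.linspace(0, 1, 2)   # 0, 1

# Open ("sparse") 3D meshgrid with matrix indexing.
X_OPEN, Y_OPEN, Z_OPEN = np.meshgrid(x, y, z, indexing='ij', sparse=True)
print(X_OPEN.shape, Y_OPEN.shape, Z_OPEN.shape)  # (4, 1, 1) (1, 3, 1) (1, 1, 2)

# Broadcasting expands the components into a full 3D mask on demand.
MASK = (Z_OPEN > 0.5) & (abs(X_OPEN) < 0.9) & (Y_OPEN <= 0)
print(MASK.shape, MASK.sum())  # (4, 3, 2) 4

# Each component flattens back to its 1D coordinate vector.
print(np.array_equal(X_OPEN.flatten(), x))  # True
```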

### Model source

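While the FRR tools operate on base function profiles defined as callables, we can use the convenience kCSD base function objects as model bases (bases whose centroid is at `(0, 0, 0)`).
As the convolver uses the Romberg method for integration, the support of the CSD profile is limited by the K parameter of the method: below, its radius `SRC_R_MAX` is set to `2 ** (ROMBERG_K - 1)` times the smallest _POT_ grid spacing, i.e. half of the span of the `2 ** ROMBERG_K + 1` Romberg nodes.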

In [None]:
from _common_new import SphericalSplineSourceKCSD, GaussianSourceKCSD3D

ROMBERG_K = 6
SRC_R_MAX = 2 ** (ROMBERG_K - 1) * min(convolver.ds('POT'))
BASE_CONDUCTIVITY = electrodes[0].base_conductivity

spline_nodes = [SRC_R_MAX / 3, SRC_R_MAX]
spline_polynomials = [[1],
                      [0,
                       6.75 / SRC_R_MAX,
                       -13.5 / SRC_R_MAX ** 2,
                       6.75 / SRC_R_MAX ** 3]]
model_src = SphericalSplineSourceKCSD(0, 0, 0,
                                      spline_nodes,
                                      spline_polynomials,
                                      BASE_CONDUCTIVITY)
print(SRC_R_MAX)
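
The spline profile can be checked by hand under an assumption about `SphericalSplineSourceKCSD` that is not documented here: that `spline_polynomials[i]` lists the coefficients in ascending powers of the radius `r` and applies between the previous node (or 0) and `spline_nodes[i]`.  Under that reading, the outer piece equals `6.75 * u * (1 - u) ** 2` with `u = r / SRC_R_MAX`, which matches the inner constant 1 at `r = SRC_R_MAX / 3` (in value and slope) and vanishes smoothly at `r = SRC_R_MAX`.  A standalone sketch with a hypothetical radius `R`:

```python
import numpy as np

R = 1.0  # hypothetical stand-in for SRC_R_MAX
NODES = [R / 3, R]
POLYNOMIALS = [[1],
               [0, 6.75 / R, -13.5 / R ** 2, 6.75 / R ** 3]]


def piece(i, r):
    """Value of the i-th polynomial (ascending-power coefficients assumed)."""
    # np.polyval expects descending powers, hence the reversal.
    return np.polyval(POLYNOMIALS[i][::-1], r)


def profile(r):
    """Piecewise-polynomial radial profile, zero beyond the last node."""
    r = np.asarray(r, dtype=float)
    result = np.zeros_like(r)
    lower = 0.0
    for i, upper in enumerate(NODES):
        inside = (r >= lower) & (r < upper)
        result[inside] = piece(i, r[inside])
        lower = upper
    return result


print(piece(0, R / 3), piece(1, R / 3))  # both 1 (up to rounding)
print(piece(1, R))                       # 0
print(profile([0.0, 0.5, 2.0]))          # [1.      0.84375 0.     ]
```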

### Convolver interface

The convolver interface binds the convolver to:
- a CSD profile,
- weights of a quadrature of equally-spaced nodes,
- a boolean mask of the _SRC_ grid nodes at which the base function centroids are placed.

When the analytical solution of the kCSD forward problem is coupled with numerical integration of the leadfield correction, it is advised not to place centroids near the boundary of the _SRC_ grid.

In [None]:
from _fast_reciprocal_reconstructor import ConvolverInterfaceIndexed
from scipy.integrate import romb

ROMBERG_N = 2 ** ROMBERG_K + 1
ROMBERG_WEIGHTS = romb(np.identity(ROMBERG_N)) * 2 ** -ROMBERG_K

SRC_IDX = ((convolver.SRC_Z > SRC_R_MAX)
           & (convolver.SRC_Z < 3e-4 - SRC_R_MAX)
           & (abs(convolver.SRC_X) < 1.5e-4 - SRC_R_MAX)
           & (abs(convolver.SRC_Y) < 1.5e-4 - SRC_R_MAX))
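
`ROMBERG_WEIGHTS` above are the Romberg quadrature weights of `ROMBERG_N` equally spaced nodes: `scipy.integrate.romb` applied to each row of an identity matrix returns the weight of the corresponding node (for unit spacing, i.e. on an interval of length `2 ** ROMBERG_K`), and the `2 ** -ROMBERG_K` factor rescales them to the unit interval.  A standalone check of that interpretation:

```python
import numpy as np
from scipy.integrate import romb

K = 6
N = 2 ** K + 1
# Weight of node i = romb integral of the i-th unit vector (unit spacing);
# rescaling by 2 ** -K maps the interval [0, 2 ** K] onto [0, 1].
WEIGHTS = romb(np.identity(N)) * 2 ** -K
NODES = np.linspace(0, 1, N)

print(WEIGHTS.sum())                   # 1 (up to rounding)
print(np.dot(WEIGHTS, NODES ** 3))     # 0.25 (up to rounding)
print(np.dot(WEIGHTS, np.cos(NODES)))  # close to sin(1)
```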

In [None]:
print(SRC_IDX.sum())

In [None]:
convolver_interface = ConvolverInterfaceIndexed(convolver,
                                                model_src.csd,
                                                ROMBERG_WEIGHTS,
                                                SRC_IDX)

### Kernel constructor and cross-kernel constructor

The kernel constructor is a collection of callables (methods) facilitating the construction of the images of base functions at the electrodes (the _PHI_ matrix) and of the kernel matrix.  The cross-kernel constructor is a callable which constructs the cross-kernel based on the _PHI_ matrix and a boolean mask of the _CSD_ grid.
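
For orientation, the textbook kCSD linear algebra behind these matrices can be sketched with plain NumPy.  This is not the `_fast_reciprocal_reconstructor` implementation: the random matrices and the `(n_bases, n_electrodes)` layout of `PHI` are assumptions made for illustration.  With `PHI[i, j]` the potential of the i-th base function at the j-th electrode and `B[k, i]` the CSD of the i-th base function at the k-th CSD node, the kernel is `PHI.T @ PHI`, the cross-kernel is `B @ PHI`, and the CSD reconstructed from potentials `V` is `crosskernel @ solve(kernel, V)`:

```python
import numpy as np

rng = np.random.default_rng(0)
N_BASES, N_ELECTRODES, N_CSD = 50, 3, 20

PHI = rng.normal(size=(N_BASES, N_ELECTRODES))  # base images at electrodes
B = rng.normal(size=(N_CSD, N_BASES))           # base CSD at CSD nodes

KERNEL = PHI.T @ PHI   # (n_electrodes, n_electrodes)
CROSSKERNEL = B @ PHI  # (n_csd, n_electrodes)

V = rng.normal(size=N_ELECTRODES)         # "measured" potentials
BETA = PHI @ np.linalg.solve(KERNEL, V)   # weights of the base functions
CSD = B @ BETA  # equals CROSSKERNEL @ np.linalg.solve(KERNEL, V)

# The reconstructed source reproduces the measured potentials exactly.
print(np.allclose(PHI.T @ BETA, V))  # True
```

The last property is why, in the FEM checks further down, the potentials of a source reconstructed from an eigenvector of the kernel are expected to project mostly onto that same eigenvector when the kernel model is accurate.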

In [None]:
from _fast_reciprocal_reconstructor import ckESI_kernel_constructor, ckESI_crosskernel_constructor

In [None]:
kernel_constructor = ckESI_kernel_constructor()

In [None]:
CSD_IDX = np.ones(convolver.shape('CSD'),
                  dtype=bool)

In [None]:
kernel_constructor.create_crosskernel = ckESI_crosskernel_constructor(convolver_interface,
                                                                      CSD_IDX)

### Potential At Electrode: analytical solution of the kCSD forward problem

In [None]:
from _fast_reciprocal_reconstructor import PAE_kCSD_Analytical

In [None]:
pae_kcsd_a = PAE_kCSD_Analytical(convolver_interface,
                                 potential=model_src.potential)

In [None]:
%%time
PHI_KCSD_ANALYTICAL = kernel_constructor.create_base_images_at_electrodes(electrodes,
                                                                          pae_kcsd_a)

In [None]:
KERNEL_KCSD_ANALYTICAL = kernel_constructor.create_kernel(PHI_KCSD_ANALYTICAL)

In [None]:
%%time
CROSSKERNEL_KCSD_ANALYTICAL = kernel_constructor.create_crosskernel(PHI_KCSD_ANALYTICAL)

### Potential At Electrode: numerical solution of the kCSD forward problem
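
Conceptually, the image of a base function at an electrode is the integral of its CSD profile weighted by the electrode's leadfield; a numerical solution of the kCSD forward problem approximates this integral by quadrature.  The sketch below is a standalone illustration of such a 3D Romberg quadrature, not the `PAE_kCSD_Numerical` implementation: the cube half-width `R`, the Gaussian `csd` profile and the point-electrode `leadfield` are hypothetical stand-ins, and the 3D weights are assumed to be an outer product of 1D Romberg weights.

```python
import numpy as np
from scipy.integrate import romb

K = 5
N = 2 ** K + 1
R = 1.0  # hypothetical half-width of the integration cube

# 1D nodes and Romberg weights on [-R, R] (spacing 2R / 2**K).
nodes = np.linspace(-R, R, N)
w = romb(np.identity(N), dx=2 * R / 2 ** K)

# Open meshgrid and separable 3D weights (outer product of 1D weights).
QX, QY, QZ = np.meshgrid(nodes, nodes, nodes, indexing='ij', sparse=True)
W = w[:, None, None] * w[None, :, None] * w[None, None, :]


def image_at_electrode(csd, leadfield):
    """Quadrature of csd * leadfield over the cube."""
    return np.sum(W * csd(QX, QY, QZ) * leadfield(QX, QY, QZ))


def csd(x, y, z):
    """Hypothetical Gaussian CSD profile (standard deviation R / 4)."""
    return np.exp(-(x ** 2 + y ** 2 + z ** 2) / (2 * (R / 4) ** 2))


def leadfield(x, y, z, sigma=0.3):
    """Point-electrode leadfield at (0, 0, 2R) in a homogeneous medium."""
    return 1 / (4 * np.pi * sigma * np.sqrt(x ** 2 + y ** 2 + (z - 2 * R) ** 2))


print(image_at_electrode(csd, leadfield))
# Exactness check: the integral of x**2 y**2 z**2 over [-1, 1]**3 is 8 / 27.
print(image_at_electrode(lambda x, y, z: x ** 2 * y ** 2,
                         lambda x, y, z: z ** 2))
```

For a spherically symmetric source lying well within the distance `d` from the electrode, the result should be close to the total current divided by `4 * pi * sigma * d` (here `d = 2 * R`), which makes a convenient sanity check.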

In [None]:
from _fast_reciprocal_reconstructor import PAE_kCSD_Numerical

In [None]:
pae_kcsd_n = PAE_kCSD_Numerical(convolver_interface)

In [None]:
%%time
PHI_KCSD_NUMERICAL = kernel_constructor.create_base_images_at_electrodes(electrodes,
                                                                         pae_kcsd_n)

In [None]:
KERNEL_KCSD_NUMERICAL = kernel_constructor.create_kernel(PHI_KCSD_NUMERICAL)

In [None]:
%%time
CROSSKERNEL_KCSD_NUMERICAL = kernel_constructor.create_crosskernel(PHI_KCSD_NUMERICAL)

### Potential At Electrode: kESI corrected analytical solution of the kCSD forward problem

In [None]:
from _fast_reciprocal_reconstructor import PAE_kESI_Analytical

In [None]:
pae_kesi_a = PAE_kESI_Analytical(convolver_interface,
                                 potential=model_src.potential)

In [None]:
%%time
PHI_KESI_ANALYTICAL = kernel_constructor.create_base_images_at_electrodes(electrodes,
                                                                          pae_kesi_a)

In [None]:
KERNEL_KESI_ANALYTICAL = kernel_constructor.create_kernel(PHI_KESI_ANALYTICAL)

In [None]:
%%time
CROSSKERNEL_KESI_ANALYTICAL = kernel_constructor.create_crosskernel(PHI_KESI_ANALYTICAL)

### Potential At Electrode: kESI corrected numerical solution of the kCSD forward problem

In [None]:
from _fast_reciprocal_reconstructor import PAE_kESI_Numerical

In [None]:
pae_kesi_n = PAE_kESI_Numerical(convolver_interface)

In [None]:
%%time
PHI_KESI_NUMERICAL = kernel_constructor.create_base_images_at_electrodes(electrodes,
                                                                         pae_kesi_n)

In [None]:
KERNEL_KESI_NUMERICAL = kernel_constructor.create_kernel(PHI_KESI_NUMERICAL)

In [None]:
%%time
CROSSKERNEL_KESI_NUMERICAL = kernel_constructor.create_crosskernel(PHI_KESI_NUMERICAL)

# Reconstructor

In [None]:
from kesi._verbose import _CrossKernelReconstructor as Reconstructor
from kesi._engine import _LinearKernelSolver as KernelSolver

In [None]:
reconstructor_kcsd_a = Reconstructor(KernelSolver(KERNEL_KCSD_ANALYTICAL),
                                     CROSSKERNEL_KCSD_ANALYTICAL)

In [None]:
reconstructor_kesi_a = Reconstructor(KernelSolver(KERNEL_KESI_ANALYTICAL),
                                     CROSSKERNEL_KESI_ANALYTICAL)

In [None]:
reconstructor_kcsd_n = Reconstructor(KernelSolver(KERNEL_KCSD_NUMERICAL),
                                     CROSSKERNEL_KCSD_NUMERICAL)

In [None]:
reconstructor_kesi_n = Reconstructor(KernelSolver(KERNEL_KESI_NUMERICAL),
                                     CROSSKERNEL_KESI_NUMERICAL)

# FEM forward modelling

In [None]:
import configparser
import dolfin
import scipy.interpolate as si

import FEM.fem_common as fc

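We define a FEM forward model.  It is a callable which accepts a CSD profile as a callable compatible with `scipy.interpolate.RegularGridInterpolator` (it is evaluated at the `(n, 3)` array of DOF coordinates) and returns the potential as a _dolfin_ function.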
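
Reading it off the `__call__` method below (bilinear form `a`, linear form `L`, and the Dirichlet condition on the part of the boundary where $z > 0$), the model solves the weak formulation of the Poisson equation $-\nabla \cdot (\sigma \nabla u) = C$ for the potential $u$, with the CSD $C$ as the source term and a piecewise-constant conductivity $\sigma_i$ in the subdomain $\Omega_i$ (the `volume` numbers of the properties file).  Find $u$ with $u = 0$ on $\Gamma_D$ such that, for every test function $v$ vanishing on $\Gamma_D$:

```latex
\sum_i \sigma_i \int_{\Omega_i} \nabla u \cdot \nabla v \, \mathrm{d}x
  = \int_{\Omega} C \, v \, \mathrm{d}x,
\qquad
\Gamma_D = \{\, x \in \partial\Omega : x_z > 0 \,\}.
```

No boundary term appears on the right-hand side, which corresponds to a homogeneous Neumann (no-flux) condition on the remaining part of the boundary.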

In [None]:
class ForwardModel(object):
    # XXX: duplicated code with FEM classes
    def __init__(self, config):
        self.fm = fc.FunctionManagerINI(config)
        
        self.V = self.fm.function_space
        mesh = self.fm.mesh

        n = self.V.dim()
        d = mesh.geometry().dim()

        self.dof_coords = self.V.tabulate_dof_coordinates()
        self.dof_coords.resize((n, d))
        
        self.csd_f = self.fm.function()
        
        
        mesh_filename = self.fm.getpath('fem', 'mesh')[:-5]  # strip the '.xdmf' extension
        with dolfin.XDMFFile(mesh_filename + '_subdomains.xdmf') as fh:
            mvc = dolfin.MeshValueCollection("size_t", mesh, 3)
            fh.read(mvc, "subdomains")
            self.subdomains = dolfin.cpp.mesh.MeshFunctionSizet(mesh, mvc)
            self.dx = dolfin.Measure("dx")(subdomain_data=self.subdomains)
            
            
        self.config = configparser.ConfigParser()
        self.config.read(self.fm.getpath('fem', 'config'))

    @property
    def CONDUCTIVITY(self):
        for section in self.config.sections():
            if self._is_conductive_volume(section):
                yield (self.config.getint(section, 'volume'),
                       self.config.getfloat(section, 'conductivity'))

    def _is_conductive_volume(self, section):
        return (self.config.has_option(section, 'volume')
                and self.config.has_option(section, 'conductivity')) 
        
    def __call__(self, csd_interpolator):
        self.csd_f.vector()[:] = csd_interpolator(self.dof_coords)
        
        # Dirichlet boundary condition: zero potential on the part
        # of the boundary where z > 0
        dirichlet_bc_gt = dolfin.DirichletBC(self.V,
                                             dolfin.Constant(0),
                                             (lambda x, on_boundary:
                                              on_boundary and x[2] > 0))
        test = self.fm.test_function()
        trial = self.fm.trial_function()
        potential = self.fm.function()
        
        
        dx = self.dx
        a = sum(dolfin.Constant(c)
                * dolfin.inner(dolfin.grad(trial),
                               dolfin.grad(test))
                * dx(i)
                for i, c
                in self.CONDUCTIVITY)
        L = self.csd_f * test * dx
        
        b = dolfin.assemble(L)
        A = dolfin.assemble(a)
        dirichlet_bc_gt.apply(A, b)
        
        solver = dolfin.KrylovSolver("cg", "ilu")
        solver.parameters["maximum_iterations"] = 10000
        solver.parameters["absolute_tolerance"] = 1E-8
        # solver.parameters["monitor_convergence"] = True  # Goes to Jupyter server output stream
        solver.solve(A, potential.vector(), b)
        
        return potential

In [None]:
%time fem = ForwardModel('FEM/fem_configs/tutorial/finite_slice/fine_3/demo.ini')

In [None]:
%%time
_EIGENVALUES, _EIGENVECTORS = np.linalg.eigh(KERNEL_KESI_ANALYTICAL)
_CSD = reconstructor_kesi_a(_EIGENVECTORS[:, 1]).reshape(convolver.shape('CSD'))
_csd = si.RegularGridInterpolator(
                                  [getattr(convolver, f'CSD_{x}').flatten()
                                   for x in 'XYZ'],
                                  _CSD,
                                  bounds_error=False,
                                  fill_value=0)
_v = fem(_csd)
_V = np.array([_v(_e.x, _e.y, _e.z) for _e in electrodes])

print(np.matmul(_V, _EIGENVECTORS))

# Reconstruction

## Ground truth CSD and its image at the electrodes

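Derive the ground truth (GT) CSD as an eigensource of the analytical kCSD kernel.  Note that `np.linalg.eigh` returns eigenvalues in ascending order, so `_EIGENVECTORS[:, 1]` corresponds to the second smallest eigenvalue.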

In [None]:
%%time
_EIGENVALUES, _EIGENVECTORS = np.linalg.eigh(KERNEL_KCSD_ANALYTICAL)
GT_CSD = reconstructor_kcsd_a(_EIGENVECTORS[:, 1]).reshape(convolver.shape('CSD'))
_csd = si.RegularGridInterpolator(
                                  [getattr(convolver, f'CSD_{x}').flatten()
                                   for x in 'XYZ'],
                                  GT_CSD,
                                  bounds_error=False,
                                  fill_value=0)
_v = fem(_csd)
GT_V = np.array([_v(_e.x, _e.y, _e.z) for _e in electrodes])

print(np.matmul(GT_V, _EIGENVECTORS))

In [None]:
%%time
V_WHOLE_SPACE = np.vectorize(_v)(convolver.CSD_X[::4, :, :],
                                 convolver.CSD_Y[:, ::4, :],
                                 convolver.CSD_Z[:, :, ::4])

### Visualisation

In [None]:
import matplotlib.pyplot as plt
import cbf

In [None]:
def crude_plot_data(DATA,
                    x=None,
                    y=None,
                    z=None,
                    grid=None,
                    dpi=30,
                    cmap=cbf.bwr,
                    title=None,
                    amp=None):
    wx, wy, wz = DATA.shape
    
    if grid is None:
        ix, iy, iz = [w // 2 if a is None else a
                      for a, w in zip([x, y, z],
                                      [wx, wy, wz])]
        x, y, z = ix, iy, iz
        
    else:
        x, y, z = [g.mean() if a is None else a
                   for a, g in zip([x, y, z],
                                   grid)]
        ix, iy, iz = [np.searchsorted(g, a)
                      for a, g in zip([x, y, z],
                                      grid)]

    fig = plt.figure(figsize=((wx + wy) / dpi,
                              (wz + wy) / dpi))
    if title is not None:
        fig.suptitle(title)
    gs = plt.GridSpec(2, 2,
                      figure=fig,
                      width_ratios=[wx, wy],
                      height_ratios=[wz, wy])

    ax_xz = fig.add_subplot(gs[0, 0])
    ax_xz.set_aspect('equal')
    ax_xz.set_ylabel('Z')

    ax_xy = fig.add_subplot(gs[1, 0],
                            sharex=ax_xz)
    ax_xy.set_aspect('equal')
    ax_xy.set_ylabel('Y')
    ax_xy.set_xlabel('X')

    ax_yz = fig.add_subplot(gs[0, 1],
                            sharey=ax_xz)
    ax_yz.set_aspect('equal')
    ax_yz.set_xlabel('Y')

    cax = fig.add_subplot(gs[1,1])
    cax.set_visible(False)
#     cax.get_xaxis().set_visible(False)
#     cax.get_yaxis().set_visible(False)


    if amp is None:
        amp = abs(DATA).max()

    if grid is None:
        ax_xz.imshow(DATA[:, iy, :].T,
                     vmin=-amp,
                     vmax=amp,
                     cmap=cmap)
        ax_xy.imshow(DATA[:, :, iz].T,
                     vmin=-amp,
                     vmax=amp,
                     cmap=cmap)
        im = ax_yz.imshow(DATA[ix, :, :].T,
                          vmin=-amp,
                          vmax=amp,
                          cmap=cmap)
    else:
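        # origin='lower' puts the first row of each (transposed) slice at
        # the bottom, consistently with the ascending `extent` and the
        # cross-hair lines drawn in physical coordinates below
        ax_xz.imshow(DATA[:, iy, :].T,
                     vmin=-amp,
                     vmax=amp,
                     cmap=cmap,
                     origin='lower',
                     extent=(grid[0].min(), grid[0].max(),
                             grid[2].min(), grid[2].max()))
        ax_xy.imshow(DATA[:, :, iz].T,
                     vmin=-amp,
                     vmax=amp,
                     cmap=cmap,
                     origin='lower',
                     extent=(grid[0].min(), grid[0].max(),
                             grid[1].min(), grid[1].max()))
        im = ax_yz.imshow(DATA[ix, :, :].T,
                          vmin=-amp,
                          vmax=amp,
                          cmap=cmap,
                          origin='lower',
                          extent=(grid[1].min(), grid[1].max(),
                                  grid[2].min(), grid[2].max()))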
        
    ax_xz.axvline(x, ls=':', color=cbf.BLACK)
    ax_xz.axhline(z, ls=':', color=cbf.BLACK)

    ax_xy.axvline(x, ls=':', color=cbf.BLACK)
    ax_xy.axhline(y, ls=':', color=cbf.BLACK)

    ax_yz.axvline(y, ls=':', color=cbf.BLACK)
    ax_yz.axhline(z, ls=':', color=cbf.BLACK)
    fig.colorbar(im, ax=cax)


In [None]:
crude_plot_data(GT_CSD,
                grid=[_x.flatten() for _x in convolver.CSD_MESH],
                x=5e-5)

In [None]:
crude_plot_data(V_WHOLE_SPACE,
                cmap=cbf.PRGn,
                x=90)