In [None]:
import os
import random
import itertools

import numpy as np
import pandas as pd

import FEM.fem_sphere_point_new as fspn
import FEM.fem_common as fc

In [None]:
# Number of random test points. Every point is used both as a source and as
# a destination, so the result matrices built below are N x N.
N = 100

In [None]:
# Radius of the ball inside which the test points are drawn
# (presumably in metres and equal to the radius of the modelled sphere -- TODO confirm).
R = 89e-3
# Geometry name; used as a subdirectory of both the mesh and the result directory.
GEOMETRY = 'single_sphere_composite'
# INI file with the model properties; passed to the FEM model and parsed
# later for the brain conductivity.
PROPERTIES = 'FEM/model_properties/single_sphere.ini'

In [None]:
# Names of the meshes to compare (presumably ordered by increasing density).
MESHES = ['coarse', 'normal', 'fine', 'finer', 'finest', 'superfine']

# Element degrees tested with every mesh.
DEGREES = [1, 2, 3]

# All (mesh, degree) pairs in mesh-major order, i.e. the degree varies fastest.
CONFIGS = [(mesh_name, element_degree)
           for mesh_name in MESHES
           for element_degree in DEGREES]

In [None]:
# Root directory of the meshes; a mesh file is looked up as
# <MESH_DIR>/<GEOMETRY>/<mesh>.xdmf.
MESH_DIR = 'FEM/meshes/meshes/'

In [None]:
# Everything this notebook writes (test points, per-configuration results)
# goes to a geometry-specific directory, which is created on demand.
RESULT_DIR = os.path.join('test_FEM_parameters', GEOMETRY)
os.makedirs(RESULT_DIR, exist_ok=True)

In [None]:
# The random test points are generated once and cached in a CSV file, so that
# every FEM configuration is evaluated at exactly the same locations.
POINTS_FILE = os.path.join(RESULT_DIR,
                           'point.csv')

if os.path.exists(POINTS_FILE):
    print(POINTS_FILE, 'found')
    POINTS = pd.read_csv(POINTS_FILE,
                         index_col=0)

    # A cache written for a different N would break the N x N matrices built
    # later (IndexError or silently NaN-filled entries), so fail early.
    if len(POINTS) != N:
        raise ValueError(f'{POINTS_FILE} contains {len(POINTS)} points '
                         f'but N = {N}; remove the file to regenerate it')

else:
    random.seed(42)  # fixed seed, so regenerated points are reproducible
    POINTS = np.full((N, 3), np.nan)

    i = 0

    # Rejection sampling: draw from the cube [-R, R]^3 and keep only the
    # points lying strictly inside the ball of radius R.
    while i < N:
        x = random.uniform(-R, R)
        y = random.uniform(-R, R)
        z = random.uniform(-R, R)
        if x ** 2 + y ** 2 + z ** 2 < R ** 2:
            POINTS[i, :] = x, y, z
            i += 1

    POINTS = pd.DataFrame(POINTS, columns=['X', 'Y', 'Z'])
    POINTS.to_csv(POINTS_FILE,
                  index_label='ID')

In [None]:
# Stopwatches used as context managers in the computation loop; their float()
# values are stored in the SETUP_TIME / SOLVING_TIME columns of the results.
# NOTE(review): `fc.fc` relies on FEM.fem_common exposing a name `fc` -- confirm.
setup_time = fc.fc.Stopwatch()
total_solving_time = fc.fc.Stopwatch()

In [None]:
# Compute the correction potential for every (mesh, degree) configuration and
# store it as <RESULT_DIR>/<mesh>_<degree>.csv. Configurations whose result
# file already exists are skipped, so re-running the cell only computes the
# missing ones.
for mesh, degree in CONFIGS:
    print(mesh, degree)
    result_file = os.path.join(RESULT_DIR, f'{mesh}_{degree}.csv')

    if os.path.exists(result_file):
        print(' ', result_file, 'found')
        continue

    mesh_file = os.path.join(MESH_DIR, GEOMETRY, f'{mesh}.xdmf')

    # Building the model is timed separately from solving.
    with setup_time:
        function_manager = fc.FunctionManager(mesh_file, degree, 'CG')
        fem = fspn.SphereOnGroundedPlatePointSourcePotentialFEM(function_manager,
                                                                PROPERTIES)

    rows = []
    for src, SRC in POINTS.iterrows():
        print(mesh, degree, src)
        with total_solving_time:
            potential_corr = fem.correction_potential(SRC.X, SRC.Y, SRC.Z)

        # One row per (source, destination) pair; the stopwatches are read
        # at the moment the row is created.
        for dst, DST in POINTS.iterrows():
            rows.append(dict(SRC=src,
                             DST=dst,
                             CORR=potential_corr(DST.X, DST.Y, DST.Z),
                             SOLVING_TIME=float(total_solving_time),
                             SETUP_TIME=float(setup_time)))

    DF = pd.DataFrame(rows)
    DF.to_csv(result_file, index=False)

# Analysis

In [None]:
import matplotlib.pyplot as plt

from local import cbf

In [None]:
import configparser
from kesi import common

In [None]:
# Load the model properties. ConfigParser.read() silently skips files it
# cannot open and returns the list of files it actually parsed, so a wrong
# path would otherwise only surface later as an obscure NoSectionError.
config = configparser.ConfigParser()
if not config.read(PROPERTIES):
    raise FileNotFoundError(f'Could not read model properties file: {PROPERTIES}')

In [None]:
# Conductivity read from the [brain] section of the properties file; used
# below as the conductivity of the analytical point sources.
BRAIN_CONDUCTIVITY = config.getfloat('brain', 'conductivity')

## Reading correction matrices

In [None]:
# Load every available result file into an N x N matrix indexed [SRC, DST].
# `labels` and `corrections` are parallel lists; configurations without
# a result file are skipped.
labels = []
corrections = []

for mesh, degree in CONFIGS:
    print(mesh, degree)
    result_file = os.path.join(RESULT_DIR, f'{mesh}_{degree}.csv')
    if not os.path.exists(result_file):
        print(' not found, skipping')
        continue

    labels.append(f'{mesh} {degree}')

    DF = pd.read_csv(result_file)

    # Entries missing from the file stay NaN.
    CORR = np.full((N, N), np.nan)
    CORR[DF['SRC'].values, DF['DST'].values] = DF['CORR'].values

    corrections.append(CORR)

## Base potential

In [None]:
# Potential of a point source placed at every test point, evaluated at every
# test point; indexed [source, destination]. The diagonal holds a source
# evaluated at its own location and is excluded from the statistics below.
BASE_POTENTIAL = np.full((N, N), np.nan)

_point_sources = [(src, common.PointSource(SRC.X,
                                           SRC.Y,
                                           SRC.Z,
                                           conductivity=BRAIN_CONDUCTIVITY))
                  for src, SRC in POINTS.iterrows()]

for src, _src in _point_sources:
    for dst, DST in POINTS.iterrows():
        BASE_POTENTIAL[src, dst] = _src.potential(DST.X, DST.Y, DST.Z)

In [None]:
# Boolean mask selecting every entry except the diagonal (source == destination).
OFF_DIAGONAL_IDX = np.logical_not(np.eye(N, dtype=bool))

# Summary of BASE_POTENTIAL, off-diagonal entries only.
_ASYMMETRY = abs(BASE_POTENTIAL - BASE_POTENTIAL.T)
print('Maximal reciprocity error:', _ASYMMETRY[OFF_DIAGONAL_IDX].max())

_OFF_DIAGONAL = BASE_POTENTIAL[OFF_DIAGONAL_IDX]
_MODULUS = abs(_OFF_DIAGONAL)
print('Linf:', _MODULUS.max())
print('L2:', np.sqrt(np.square(_OFF_DIAGONAL).mean()))
print('L1:', _MODULUS.mean())
print('Median absolute value:', np.median(_MODULUS))
print('min, med, max:', _OFF_DIAGONAL.min(), np.median(_OFF_DIAGONAL), _OFF_DIAGONAL.max())

## Exact solution approximation

In [None]:
# The exact solution is unknown, so the result of a single FEM configuration
# serves as the reference against which all configurations are compared.
# The name AVG is historical: an alternative reference is the average
#   0.5 * (_corrections['finest 3'] + _corrections['finest 2'])
REFERENCE_LABEL = 'finest 3'

_corrections = dict(zip(labels, corrections))

# Fail with an explicit message if the reference result has not been computed.
if REFERENCE_LABEL not in _corrections:
    raise KeyError(f'reference configuration {REFERENCE_LABEL!r} not found; '
                   f'available configurations: {labels}')

AVG = _corrections[REFERENCE_LABEL]

In [None]:
# Summary of the reference correction AVG, off-diagonal entries only.
_ASYMMETRY = abs(AVG - AVG.T)
print('Maximal reciprocity error:', _ASYMMETRY[OFF_DIAGONAL_IDX].max())

_OFF_DIAGONAL = AVG[OFF_DIAGONAL_IDX]
_MODULUS = abs(_OFF_DIAGONAL)
print('Linf:', _MODULUS.max())
print('L2:', np.sqrt(np.square(_OFF_DIAGONAL).mean()))
print('L1:', _MODULUS.mean())
print('Median absolute value:', np.median(_MODULUS))
print('min, med, max:', _OFF_DIAGONAL.min(), np.median(_OFF_DIAGONAL), _OFF_DIAGONAL.max())

## Reciprocity validation

In [None]:
# Difference between each correction matrix and its transpose, i.e. the
# change observed when source and destination are swapped. Antisymmetric by
# construction, with a zero diagonal.
reciprocity_errors = [np.subtract(matrix, matrix.T) for matrix in corrections]

In [None]:
# Reciprocity errors divided by (reference correction + analytical potential)
# -- presumably the total potential -- with the diagonal dropped; each item
# is a flat 1-D array of the off-diagonal entries.
_REFERENCE_POTENTIAL = AVG + BASE_POTENTIAL
reciprocity_relative_errors = [(error / _REFERENCE_POTENTIAL)[OFF_DIAGONAL_IDX]
                               for error in reciprocity_errors]

In [None]:
# Distribution of the signed reciprocity errors of every configuration.
# A symmetric-log axis is used because the errors take both signs.
plt.figure(figsize=(12, 8))
plt.title('Reciprocity errors [V]')
plt.yscale('symlog')
plt.grid()

_ = plt.boxplot([np.ravel(A) for A in reciprocity_errors])
# Boxes are drawn at positions 1..len(labels). Labelling the ticks directly
# avoids boxplot's `labels` keyword, which is deprecated since Matplotlib 3.9.
_ = plt.xticks(range(1, len(labels) + 1), labels)

In [None]:
# Distribution of the magnitudes of the reciprocity errors (log scale).
plt.figure(figsize=(12, 8))
plt.title('Modulus of reciprocity errors [V]')
plt.yscale('log')
plt.yticks(np.logspace(-8, 2, 11))
plt.grid()

# A - A.T is antisymmetric, so its positive entries are exactly the moduli of
# the nonzero errors, each (source, destination) pair counted once.
_positive_errors = [A[A > 0] for A in reciprocity_errors]
_ = plt.violinplot(_positive_errors)
_ = plt.boxplot(_positive_errors)
# Boxes are drawn at positions 1..len(labels). Labelling the ticks directly
# avoids boxplot's `labels` keyword, which is deprecated since Matplotlib 3.9.
_ = plt.xticks(range(1, len(labels) + 1), labels)

In [None]:
# Distribution of the relative reciprocity errors, in percent (log scale).
plt.figure(figsize=(12, 8))
plt.title('Modulus of reciprocity errors [%]')
plt.yscale('log')
plt.yticks(np.logspace(-8, 2, 11))
plt.grid()

# Only the positive entries are shown (a log axis cannot display the others).
_positive_percent = [A[A > 0] * 100 for A in reciprocity_relative_errors]
_ = plt.violinplot(_positive_percent)
_ = plt.boxplot(_positive_percent)
# Boxes are drawn at positions 1..len(labels). Labelling the ticks directly
# avoids boxplot's `labels` keyword, which is deprecated since Matplotlib 3.9.
_ = plt.xticks(range(1, len(labels) + 1), labels)

## Convergence validation

In [None]:
# _corrections = dict(zip(labels, corrections))

# removed = []
# removal_score = []

# _score = 1
# while len(_corrections) > 2 and _score:
#     _AVG = sum(_corrections.values()) / len(_corrections)
#     _score = 0
#     for _k, _CORR in _corrections.items():
#         _s = np.sqrt(np.square(_AVG - _CORR).mean())
#         if _s > _score:
#             _score = _s
#             _key = _k
            
#     removed.append(_key)
#     removal_score.append(_score)
#     del _corrections[_key]

# AVG = sum(_corrections.values()) / len(_corrections)

In [None]:
# for k, v in zip(removed, removal_score):
#     print(f'  {k}\t{v}')
    
# for k, v in zip(labels, corrections):
#     if k in removed:
#         continue
    
#     minimal_score = np.sqrt(np.square(v - AVG).mean())
#     print(f'> {k}\t{minimal_score}')

In [None]:
# plt.plot(removal_score)
# plt.axhline(minimal_score)
# plt.yscale('log')
# plt.grid()

In [None]:
# Deviation of every configuration from the reference correction (AVG),
# summarised over all entries of the matrix (diagonal included).
diffs = [np.subtract(correction, AVG) for correction in corrections]
error_L1 = np.array([np.mean(np.abs(deviation)) for deviation in diffs])
error_L2 = np.array([np.sqrt(np.mean(np.square(deviation))) for deviation in diffs])
error_Linf = np.array([np.max(np.abs(deviation)) for deviation in diffs])
error_bias = np.array([np.mean(deviation) for deviation in diffs])

In [None]:
# Deviations divided by (reference correction + analytical potential).
# The full matrices are kept in `diffs_relative`; the summary statistics use
# the off-diagonal entries only.
_REFERENCE_POTENTIAL = AVG + BASE_POTENTIAL
diffs_relative = [deviation / _REFERENCE_POTENTIAL for deviation in diffs]

_relative_off_diagonal = [relative[OFF_DIAGONAL_IDX] for relative in diffs_relative]
error_relative_L1 = np.array([abs(values).mean() for values in _relative_off_diagonal])
error_relative_L2 = np.array([np.sqrt(np.square(values).mean()) for values in _relative_off_diagonal])
error_relative_Linf = np.array([abs(values).max() for values in _relative_off_diagonal])
error_relative_bias = np.array([values.mean() for values in _relative_off_diagonal])

In [None]:
# Distribution of |correction - reference| over all matrix entries
# (diagonal included) for every configuration.
plt.figure(figsize=(12, 8))
plt.title('Convergence errors [V]')
plt.yscale('log')
plt.grid()
_ = plt.boxplot([abs(np.ravel(A)) for A in diffs])
# Boxes are drawn at positions 1..len(labels). Labelling the ticks directly
# avoids boxplot's `labels` keyword, which is deprecated since Matplotlib 3.9.
_ = plt.xticks(range(1, len(labels) + 1), labels)

In [None]:
# Distribution of |correction - reference| over the off-diagonal entries
# (source != destination) for every configuration.
plt.figure(figsize=(12, 8))
plt.title('Convergence errors (diagonal excluded) [V]')
plt.yscale('log')
plt.grid()
_ = plt.boxplot([abs(A[OFF_DIAGONAL_IDX]) for A in diffs])
# Boxes are drawn at positions 1..len(labels). Labelling the ticks directly
# avoids boxplot's `labels` keyword, which is deprecated since Matplotlib 3.9.
_ = plt.xticks(range(1, len(labels) + 1), labels)

In [None]:
# Distribution of the relative deviations from the reference, in percent,
# over all matrix entries (diagonal included).
plt.figure(figsize=(12, 8))
plt.title('Convergence errors [%]')
plt.yscale('log')
plt.grid()
_ = plt.boxplot([100 * abs(np.ravel(A)) for A in diffs_relative])
# Boxes are drawn at positions 1..len(labels). Labelling the ticks directly
# avoids boxplot's `labels` keyword, which is deprecated since Matplotlib 3.9.
_ = plt.xticks(range(1, len(labels) + 1), labels)

In [None]:
# Distribution of the relative deviations from the reference, in percent,
# over the off-diagonal entries (source != destination).
plt.figure(figsize=(12, 8))
plt.title('Convergence errors (diagonal excluded) [%]')
plt.yscale('log')
plt.grid()
_ = plt.boxplot([100 * abs(A[OFF_DIAGONAL_IDX]) for A in diffs_relative])
# Boxes are drawn at positions 1..len(labels). Labelling the ticks directly
# avoids boxplot's `labels` keyword, which is deprecated since Matplotlib 3.9.
_ = plt.xticks(range(1, len(labels) + 1), labels)

In [None]:
# Error norms of every configuration with respect to the reference, one
# curve per norm, configurations along the x axis.
fig, ax = plt.subplots(figsize=(12, 8))
ax.set_title('Convergence errors [V]')

_curves = [
    (error_L1, dict(label='L1', marker='o')),
    (error_L2, dict(label='L2', marker='+')),
    (error_Linf, dict(label='L\u221e')),
    (error_bias, dict(label='bias')),
]
for _values, _style in _curves:
    ax.plot(_values, **_style)

# Symmetric-log axis: the bias may be negative.
ax.set_yscale('symlog', linthresh=0.1)
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels)
ax.grid()
ax.legend(loc='best')

In [None]:
# Relative error norms (in percent) of every configuration with respect to
# the reference, one curve per norm, configurations along the x axis.
fig, ax = plt.subplots(figsize=(12, 8))
ax.set_title('Convergence errors [%]')

_curves = [
    (100 * error_relative_L1, dict(label='L1', marker='o')),
    (100 * error_relative_L2, dict(label='L2', marker='+')),
    (100 * error_relative_Linf, dict(label='L\u221e')),
    (100 * error_relative_bias, dict(label='bias')),
]
for _values, _style in _curves:
    ax.plot(_values, **_style)

# Symmetric-log axis: the bias may be negative.
ax.set_yscale('symlog', linthresh=0.1)
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels)
ax.grid()
ax.legend(loc='best')