In [None]:
import os
# When the notebook is launched from inside notebooks/, step up to the
# repository root so the `src.` imports below resolve.
# os.path.basename is separator-aware (works on Windows as well), unlike
# splitting the path on "/".
if os.path.basename(os.getcwd()) == "notebooks":
    os.chdir(os.pardir)

import matplotlib.pyplot as plt
import numpy as np
# import seaborn as sns

from src.utils.dataloader import load_ear_data
from src.utils.viz import set_axes_equal

In [None]:
# sns.set(style='ticks', palette='colorblind')

In [None]:
# Render inline matplotlib figures as high-DPI ("retina") PNGs.
%config InlineBackend.figure_format = 'retina'

In [None]:
# NOTE(review): not referenced in the cells visible here; presumably used later
# (e.g. for output/figure file names) — confirm or remove.
PROJECT_NAME = 'IMBioC2022_paper'

In [None]:
def export_pcd(df, area=False):
    """Stack the point-cloud coordinate columns of ``df`` into a 2-D array.

    Parameters
    ----------
    df : DataFrame
        Must contain ``'x [mm]'``, ``'y [mm]'``, ``'z [mm]'`` and, when
        ``area`` is true, ``'area [mm^2]'``.
    area : bool, optional
        Append the per-point area as a fourth column.

    Returns
    -------
    numpy.ndarray
        Shape ``(len(df), 3)``, or ``(len(df), 4)`` with ``area=True``;
        columns are x, y, z[, area] in that order.
    """
    columns = ['x [mm]', 'y [mm]', 'z [mm]']
    if area:
        columns.append('area [mm^2]')
    return np.column_stack([df[name].to_numpy() for name in columns])


def export_fields(df):
    """Assemble complex E and H field components from real/imaginary columns.

    Parameters
    ----------
    df : DataFrame
        Must contain ``'<C>Re [unit]'`` / ``'<C>Im [unit]'`` columns for
        C in Ex, Ey, Ez (unit V/m) and Hx, Hy, Hz (unit A/m).

    Returns
    -------
    tuple
        ``((Ex, Ey, Ez), (Hx, Hy, Hz))`` where each entry is a complex
        numpy array of length ``len(df)`` formed as ``Re + 1j * Im``.
    """
    def as_complex(re_col, im_col):
        # Combine one real/imaginary column pair into a complex array.
        return df[re_col].to_numpy() + 1j * df[im_col].to_numpy()

    e_columns = (
        ('ExRe [V/m]', 'ExIm [V/m]'),
        ('EyRe [V/m]', 'EyIm [V/m]'),
        ('EzRe [V/m]', 'EzIm [V/m]'),
    )
    h_columns = (
        ('HxRe [A/m]', 'HxIm [A/m]'),
        ('HyRe [A/m]', 'HyIm [A/m]'),
        ('HzRe [A/m]', 'HzIm [A/m]'),
    )
    e_field = tuple(as_complex(re, im) for re, im in e_columns)
    h_field = tuple(as_complex(re, im) for re, im in h_columns)
    return (e_field, h_field)


def poynting_vector(E, H):
    """Return the complex cross product ``E x conj(H)`` component-wise.

    Parameters
    ----------
    E, H : sequence of 3 scalars or arrays
        Cartesian components ``(x, y, z)`` of the (complex) fields; arrays
        are combined element-wise.

    Returns
    -------
    tuple
        ``(Sx, Sy, Sz)`` of ``E x H*``.

    NOTE(review): no 1/2 factor or real part is taken here; if E and H are
    peak-amplitude phasors, the time-averaged Poynting vector is
    ``0.5 * Re(E x H*)`` — confirm the convention expected by callers.
    """
    ex, ey, ez = E[0], E[1], E[2]
    hx_conj = H[0].conjugate()
    hy_conj = H[1].conjugate()
    hz_conj = H[2].conjugate()
    sx = ey * hz_conj - ez * hy_conj
    sy = ez * hx_conj - ex * hz_conj
    sz = ex * hy_conj - ey * hx_conj
    return (sx, sy, sz)


def plot_2d(xy_dict, figsize=None, c=None, alpha=1):
    """Scatter-plot two named columns on equal-aspect axes.

    Parameters
    ----------
    xy_dict : dict
        Ordered mapping ``label -> values``. The first two entries are the
        x and y data and their keys become the axis labels. If there are
        exactly three entries and ``c`` is None, the third entry colours the
        points and its key labels the colour bar.
    figsize : tuple of float, optional
        Figure size in inches. ``None`` (default) uses the current
        ``plt.rcParams['figure.figsize']`` at call time rather than the value
        it had when this function was defined.
    c : color or array-like, optional
        Explicit point colour(s); takes precedence over a third dict entry.
        Defaults to black when neither is supplied.
    alpha : float, optional
        Marker opacity, applied in both the colour-mapped and plain cases.

    Returns
    -------
    fig, ax : matplotlib Figure and Axes
    """
    labels = list(xy_dict.keys())
    values = list(xy_dict.values())
    fig, ax = plt.subplots(figsize=figsize)
    # `c is None` rather than `not c`: `not` on a numpy colour array raises
    # "The truth value of an array ... is ambiguous".
    if len(values) == 3 and c is None:
        mappable = ax.scatter(values[0], values[1], c=values[2], alpha=alpha)
        cbar = fig.colorbar(mappable, ax=ax)
        cbar.ax.set_ylabel(labels[2])
    else:
        ax.scatter(values[0], values[1],
                   c='k' if c is None else c, alpha=alpha)
    ax.set(xlabel=labels[0], ylabel=labels[1])
    ax.axis('equal')
    fig.tight_layout()
    return fig, ax


def plot_3d(xyz_dict, figsize=(7, 7), elev=20, azim=45, c=None, alpha=1):
    """Plot three named columns as a 3-D point cloud.

    Parameters
    ----------
    xyz_dict : dict
        Ordered mapping ``label -> values``. The first three entries are
        plotted on the x, y and z axes and their keys become the axis labels.
        If there are exactly four entries and ``c`` is None, the fourth entry
        colours the points (scatter + colour bar labelled by its key).
    figsize : tuple of float, optional
        Figure size in inches.
    elev, azim : float, optional
        Camera elevation and azimuth passed to ``ax.view_init``.
    c : color, optional
        Explicit colour for the plain (non-colour-mapped) plot; defaults to
        black. Takes precedence over a fourth dict entry.
    alpha : float, optional
        Marker opacity, applied in both the colour-mapped and plain cases.

    Returns
    -------
    fig, ax : matplotlib Figure and 3-D Axes
    """
    labels = list(xyz_dict.keys())
    values = list(xyz_dict.values())
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(projection='3d')
    # `c is None` rather than `not c`: `not` on a numpy RGB(A) array raises
    # "The truth value of an array ... is ambiguous".
    if len(values) == 4 and c is None:
        mappable = ax.scatter(values[0], values[1], values[2],
                              c=values[3], alpha=alpha)
        cbar = fig.colorbar(mappable, ax=ax, shrink=0.5)
        cbar.ax.set_ylabel(labels[3])
    else:
        ax.plot(values[0], values[1], values[2], '.',
                c='k' if c is None else c, alpha=alpha)
    ax.set(xlabel=labels[0], ylabel=labels[1], zlabel=labels[2])
    # NOTE(review): set_axes_equal is project-local (src.utils.viz); this
    # assumes it returns the axes it was given — confirm.
    ax = set_axes_equal(ax)
    ax.view_init(elev, azim)
    fig.tight_layout()
    return fig, ax


def estimate_normals(xyz, knn, down_sampling_ratio=1, orient_normals=False):
    """Estimate per-point surface normals of a point cloud with Open3D.

    Parameters
    ----------
    xyz : array-like
        Point coordinates, one ``(x, y, z)`` row per point (assumed shape
        ``(N, 3)``, as produced by ``export_pcd``).
    knn : int
        Number of nearest neighbours used for normal estimation.
    down_sampling_ratio : float, optional
        Passed to ``PointCloud.random_down_sample``; the cloud is randomly
        down-sampled before estimation. No seed is set here, so the retained
        subset can differ between runs.
    orient_normals : bool or int, optional
        If ``True``, orient normals consistently using ``knn`` neighbours.
        If another truthy value, it is used (as ``int``) as the neighbour
        count for orientation. If falsy, normals are left unoriented.

    Returns
    -------
    points, normals : numpy.ndarray
        Coordinates of the down-sampled points and their normals.
    """
    # Imported here rather than at module level so open3d is only required
    # when this function is actually called.
    import open3d as o3d

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(xyz)
    # (The former paint_uniform_color call was dropped: colours were never
    # returned, only carried through the down-sampling.)
    pcd = pcd.random_down_sample(down_sampling_ratio)
    pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamKNN(knn))
    if orient_normals:
        # `True` means "reuse knn"; any other truthy value is a neighbour count.
        k_orient = knn if orient_normals is True else int(orient_normals)
        pcd.orient_normals_consistent_tangent_plane(k_orient)
    return np.asarray(pcd.points), np.asarray(pcd.normals)

In [None]:
# NOTE(review): the meaning of ('te', 60) is defined by
# src.utils.dataloader.load_ear_data — presumably a polarization and a
# frequency/angle selector; TODO confirm and move to a named constant.
df = load_ear_data('te', 60)
df.head()

In [None]:
xyz = export_pcd(df)
# One (x, y, z) row per row of the loaded frame.
assert xyz.shape == (len(df), 3)
xyz.shape

In [None]:
E, H = export_fields(df)
# Each field is a 3-tuple of complex component arrays, one value per row of df.
assert all(comp.shape == (len(df),) for comp in (*E, *H))
len(E), len(H)

In [None]:
Sx, Sy, Sz = poynting_vector(E, H)
# Kept for any downstream use: this is sqrt(Sx² + Sy² + Sz²) of the *complex*
# components, which is not a vector magnitude (it vanishes for non-zero
# vectors such as Sx = 1, Sy = 1j).
S_dist_cpx = np.sqrt(Sx ** 2 + Sy ** 2 + Sz ** 2)
# Magnitude of the complex vector (Hermitian norm): sqrt(|Sx|² + |Sy|² + |Sz|²).
# The previous np.abs(S_dist_cpx) equals sqrt(|Sx² + Sy² + Sz²|), which
# underestimates the magnitude whenever the components are out of phase.
# NOTE(review): if the time-averaged power flux density is intended, take the
# norm of 0.5 * np.real((Sx, Sy, Sz)) instead (assuming peak-amplitude
# phasors) — confirm the field convention.
S_dist_abs = np.sqrt(np.abs(Sx) ** 2 + np.abs(Sy) ** 2 + np.abs(Sz) ** 2)

In [None]:
# Dict order sets the plot axes (first key -> x-axis), so this view puts the
# z coordinate on the first axis and colours each point by |S|.
fig, ax = plot_3d(
    xyz_dict={
        'z [mm]': xyz[:, 2],
        'x [mm]': xyz[:, 0],
        'y [mm]': xyz[:, 1],
        'S [W/m2]': S_dist_abs,
    },
    elev=20,
)

In [None]:
# Estimate oriented normals on a random 10% subset of the points, then draw
# them as arrows normalised to unit length and scaled to length 5.
xyz_ds, n = estimate_normals(xyz, knn=10, down_sampling_ratio=0.1,
                             orient_normals=True)

fig, ax = plot_3d(
    xyz_dict={
        'x [mm]': xyz_ds[:, 0],
        'y [mm]': xyz_ds[:, 1],
        'z [mm]': xyz_ds[:, 2],
    },
    c='k',
)
ax.quiver3D(*xyz_ds.T, *n.T, length=5, normalize=True, alpha=0.5);