In [None]:
import os
# If launched from the notebooks/ folder, step up to the repository root
# (presumably so the `src.*` imports below resolve). os.path.basename is
# separator-agnostic; splitting on "/" never matches on Windows paths.
if os.path.basename(os.getcwd()) == "notebooks":
    os.chdir(os.pardir)

import matplotlib.pyplot as plt
import numpy as np

from src.utils.dataloader import load_ear_data
from src.utils.viz import set_axes_equal

In [None]:
def clean_df(df, columns=('x [mm]', 'y [mm]', 'z [mm]')):
    """Drop points lying on the bounding box of the point cloud.

    A row is removed if, for any of ``columns``, its value equals that
    column's minimum or maximum. The input frame is not modified.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame containing at least the coordinate ``columns``.
    columns : sequence of str, optional
        Coordinate columns whose extreme values are filtered out.

    Returns
    -------
    pandas.DataFrame
        Filtered copy with a fresh 0..n-1 index.
    """
    coords = df[list(columns)]
    # DataFrame-vs-Series comparison aligns the Series (one min/max per
    # column) with the frame's columns, so each column is tested against its
    # own extremes in a single vectorised step.
    interior = ((coords != coords.min()) & (coords != coords.max())).all(axis=1)
    return df[interior].reset_index(drop=True)


def export_pcd(df, area=False):
    """Return the point-cloud coordinates stored in ``df`` as a 2-D array.

    The columns are ``'x [mm]'``, ``'y [mm]'`` and ``'z [mm]'``, giving an
    (N, 3) array. With ``area=True`` the ``'area [mm^2]'`` column is appended
    as a fourth column, giving (N, 4).
    """
    wanted = ['x [mm]', 'y [mm]', 'z [mm]']
    if area:
        wanted.append('area [mm^2]')
    # column_stack of 1-D arrays places each one as a column, like np.c_.
    return np.column_stack([df[name].to_numpy() for name in wanted])


def export_fields(df):
    """Rebuild complex E and H field components from real/imaginary columns.

    Each component is ``Re + 1j * Im`` taken from its paired columns
    (E columns carry ``[V/m]``, H columns ``[A/m]`` in their names).

    Returns
    -------
    tuple
        ``((Ex, Ey, Ez), (Hx, Hy, Hz))`` as complex NumPy arrays.
    """
    def as_complex(re_col, im_col):
        return df[re_col].to_numpy() + 1j * df[im_col].to_numpy()

    electric = tuple(as_complex(re_col, im_col) for re_col, im_col in (
        ('ExRe [V/m]', 'ExIm [V/m]'),
        ('EyRe [V/m]', 'EyIm [V/m]'),
        ('EzRe [V/m]', 'EzIm [V/m]'),
    ))
    magnetic = tuple(as_complex(re_col, im_col) for re_col, im_col in (
        ('HxRe [A/m]', 'HxIm [A/m]'),
        ('HyRe [A/m]', 'HyIm [A/m]'),
        ('HzRe [A/m]', 'HzIm [A/m]'),
    ))
    return electric, magnetic


def poynting_vector(E, H):
    """Complex Poynting vector S = 1/2 * E x conj(H), component-wise.

    Parameters
    ----------
    E, H : sequence
        ``(x, y, z)`` components, indexed 0..2; values may be complex
        scalars or NumPy arrays (anything with ``.conjugate()``).

    Returns
    -------
    tuple
        ``(Sx, Sy, Sz)``.
    """
    ex, ey, ez = E[0], E[1], E[2]
    hx_c, hy_c, hz_c = H[0].conjugate(), H[1].conjugate(), H[2].conjugate()
    return (0.5 * (ey * hz_c - ez * hy_c),
            0.5 * (ez * hx_c - ex * hz_c),
            0.5 * (ex * hy_c - ey * hx_c))


def prettify_viz():
    """Apply the notebook's plotting style.

    Sets the seaborn 'ticks' style with the 'colorblind' palette and, when
    running under IPython, switches inline figures to retina resolution.
    """
    import seaborn as sns
    sns.set(style='ticks', palette='colorblind')
    # Invoke the magic through the IPython API instead of `%config` syntax:
    # the function is then plain Python (movable into a .py module such as
    # src/utils) and degrades gracefully outside an IPython kernel.
    try:
        ipython = get_ipython()  # builtin injected only by IPython
    except NameError:
        return
    ipython.run_line_magic('config', "InlineBackend.figure_format = 'retina'")


def plot_2d(xy_dict, figsize=None, c=None, alpha=1):
    """Scatter-plot two (or three) named 1-D arrays with equal axis scaling.

    Parameters
    ----------
    xy_dict : dict
        Ordered mapping ``{label: values}``. The first two entries are the
        x and y coordinates; if a third entry is present and ``c`` is None,
        it colours the points through ``viridis`` and gets a colorbar.
    figsize : tuple of float, optional
        Figure size in inches; defaults to ``plt.rcParams['figure.figsize']``
        read at call time.
    c : color or array-like, optional
        Explicit marker colour(s); overrides the third entry. Defaults to
        black.
    alpha : float, optional
        Marker opacity, applied only in the explicit-colour branch.

    Returns
    -------
    fig, ax : matplotlib Figure and Axes
    """
    if figsize is None:
        # Resolved at call time; a default evaluated at definition time would
        # ignore later rcParams changes.
        figsize = plt.rcParams['figure.figsize']
    fig, ax = plt.subplots(figsize=figsize)
    keys = list(xy_dict.keys())
    values = list(xy_dict.values())
    # `c is None` rather than `not c`: the truth value of a NumPy colour
    # array (e.g. per-point RGB) is ambiguous and raises ValueError.
    if len(values) == 3 and c is None:
        cs = ax.scatter(values[0], values[1], c=values[2], cmap='viridis')
        cbar = fig.colorbar(cs)
        cbar.ax.set_ylabel(keys[2])
    else:
        ax.scatter(values[0], values[1], c='k' if c is None else c, alpha=alpha)
    ax.set(xlabel=keys[0], ylabel=keys[1])
    ax.axis('equal')
    fig.tight_layout()
    return fig, ax


def plot_3d(xyz_dict, figsize=(7, 7), elev=(20,), azim=(45,), c=None, alpha=1):
    """Draw a 3-D point plot from one or more viewing angles.

    Parameters
    ----------
    xyz_dict : dict
        Ordered mapping ``{label: values}``; the first three entries are the
        plotted x, y and z coordinates. If a fourth entry is present and
        ``c`` is None, it colours the points through ``viridis`` and gets a
        colorbar on every subplot.
    figsize : tuple of float, optional
        Base figure size; its height is scaled by ``num_views / 2`` when more
        than one view is drawn.
    elev, azim : sequence of float, optional
        Angles passed to ``view_init``. One subplot is drawn for every
        (elev, azim) combination, at most 4.
    c : color or array-like, optional
        Marker colour. A single colour (name or RGB tuple) is drawn with
        ``ax.plot``; a 2-D per-point colour array (e.g. (N, 3) RGB) is drawn
        with ``ax.scatter``. Defaults to black.
    alpha : float, optional
        Marker opacity for the explicit-colour branches.

    Returns
    -------
    fig, ax
        The figure and the axes of the last subplot.

    Raises
    ------
    ValueError
        If there are no views or more than 4 views.
    """
    from itertools import product
    views = list(product(elev, azim))
    num_figs = len(views)
    if num_figs == 0:
        # Otherwise the loop never runs and `ax` is unbound at return.
        raise ValueError('At least one elev and one azim value is required.')
    if num_figs > 4:
        raise ValueError('The max number of subplots is 4.')
    if num_figs != 1:
        figsize = (figsize[0], figsize[1] * num_figs / 2)
    fig = plt.figure(figsize=figsize)
    keys = list(xyz_dict.keys())
    values = list(xyz_dict.values())
    # Decide the colouring mode once. `c is None` (not `not c`) because the
    # truth value of a NumPy colour array is ambiguous and raises ValueError.
    color_by_value = len(values) == 4 and c is None
    if c is None:
        c = 'k'
    # ax.plot accepts a single colour only; per-point colour arrays such as
    # the (N, 3)/(N, 4) output of a colormap need scatter instead.
    per_point_colors = np.ndim(c) == 2
    for i, (e, a) in enumerate(views):
        ax = fig.add_subplot(num_figs, 1, i + 1, projection='3d')
        if color_by_value:
            cs = ax.scatter(values[0], values[1], values[2],
                            c=values[3], cmap='viridis')
            cbar = fig.colorbar(cs, shrink=0.5, pad=0.1)
            cbar.ax.set_ylabel(keys[3])
        elif per_point_colors:
            ax.scatter(values[0], values[1], values[2],
                       c=c, alpha=alpha, marker='.')
        else:
            ax.plot(values[0], values[1], values[2], '.', c=c, alpha=alpha)
        ax.set(xlabel=keys[0], ylabel=keys[1], zlabel=keys[2])
        # NOTE(review): relies on set_axes_equal returning the axes; confirm
        # in src.utils.viz.
        ax = set_axes_equal(ax)
        ax.view_init(elev=e, azim=a)
    fig.tight_layout()
    return fig, ax


def minmax_scale(x, range_=(0, 1)):
    """Linearly rescale ``x`` so its minimum/maximum map onto ``range_``.

    Parameters
    ----------
    x : array-like
        Values to rescale (converted with ``np.asarray(x, dtype=float)``).
    range_ : tuple of 2 floats, optional
        Target ``(low, high)`` interval; defaults to ``(0, 1)``.

    Returns
    -------
    numpy.ndarray
        Rescaled float values with the same shape as ``x``. A constant
        ``x`` (zero span) maps entirely to ``range_[0]`` instead of the NaNs
        a 0/0 division would produce; an empty ``x`` returns an empty array.
    """
    x = np.asarray(x, dtype=float)
    if x.size == 0:
        return x.copy()
    low, high = range_
    x_min = x.min()
    span = x.max() - x_min
    if span == 0:
        return np.full_like(x, low, dtype=float)
    return (x - x_min) / span * (high - low) + low


def colormap_from_array(x, cmap='viridis', alpha=None, bytes=False):
    """Map the values of ``x`` to colours of a Matplotlib colormap.

    ``x`` is min-max scaled to [0, 1] with ``minmax_scale`` first, so the
    full colormap range is used.

    Parameters
    ----------
    x : array-like
        Values to colour.
    cmap : str or Colormap, optional
        Colormap name (or instance). An unknown name prints the error and
        falls back to ``'viridis'``.
    alpha : float or None, optional
        Opacity passed to the colormap. When None the alpha channel is
        dropped, so the last axis has 3 (RGB) entries instead of 4 (RGBA).
    bytes : bool, optional
        Forwarded to the colormap: if True, colours are uint8 in 0..255
        instead of floats in [0, 1].

    Returns
    -------
    numpy.ndarray
        Colours with a trailing channel axis.
    """
    x_scaled = minmax_scale(x)
    try:
        # Look the colormap up by name instead of eval()-ing an f-string,
        # which would execute arbitrary expressions passed as `cmap`.
        colormap = plt.get_cmap(cmap)
    except ValueError as e:
        print(e, 'Falling to default colormap')
        colormap = plt.get_cmap('viridis')
    # Outside the try block: errors from the colormap call itself now
    # surface instead of triggering the fallback and a NameError on `cs`.
    cs = colormap(x_scaled, alpha, bytes)
    if alpha is None:
        cs = cs[..., :3]
    return cs


def estimate_normals(xyz, take_every=1, knn=30, fast=True):
    """Estimate unit normals of a point cloud with Open3D.

    Parameters
    ----------
    xyz : numpy.ndarray
        Point coordinates, sliced as ``xyz[::take_every, :]`` (presumably
        (N, 3); TODO confirm).
    take_every : int, optional
        Keep only every ``take_every``-th point; the returned normals then
        correspond to ``xyz[::take_every]``, not to ``xyz``.
    knn : int, optional
        Neighbour count for Open3D's KNN search used in the local fit.
    fast : bool, optional
        Forwarded as ``fast_normal_computation``.

    Returns
    -------
    numpy.ndarray
        One normalised normal per retained point. No orientation step is
        run here, so normal signs are not guaranteed to be consistent.
    """
    import open3d as o3d

    subset = xyz[::take_every, :]
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(subset)
    neighbourhood = o3d.geometry.KDTreeSearchParamKNN(knn)
    cloud.estimate_normals(search_param=neighbourhood,
                           fast_normal_computation=fast)
    cloud.normalize_normals()
    return np.asarray(cloud.normals)

In [None]:
# Apply the seaborn 'ticks'/'colorblind' style and retina inline figures.
prettify_viz()

In [None]:
# Project identifier; not referenced by any cell shown here. Presumably it is
# used when saving figures/results (TODO confirm or remove).
PROJECT_NAME = 'IMBioC2022_paper'

In [None]:
# Load the ear dataset for ('te', 26) (argument meaning defined in
# src.utils.dataloader) and drop points on the cloud's bounding box.
df = clean_df(load_ear_data('te', 26))
df

In [None]:
# Point coordinates (x, y, z) as an (N, 3) array; no area column.
xyz = export_pcd(df, area=False)
# Complex field phasors: E = (Ex, Ey, Ez), H = (Hx, Hy, Hz).
E, H = export_fields(df)

In [None]:
# Complex Poynting vector S = 1/2 * E x conj(H), one value per point and axis.
Sx, Sy, Sz = poynting_vector(E, H)

In [None]:
# Magnitude of the real part of the Poynting vector, |Re{S}|, at each point.
S_dist = np.sqrt(sum(component.real ** 2 for component in (Sx, Sy, Sz)))
# Axes are deliberately permuted (z, x, y) for this view.
fig, ax = plot_3d(
    {
        'z [mm]': xyz[:, 2],
        'x [mm]': xyz[:, 0],
        'y [mm]': xyz[:, 1],
        'Re[S] [W/m2]': S_dist,
    },
    elev=[15],
    azim=[120],
)

In [None]:
# Power density normal to the surface (presumably the absorbed power density,
# APD): APD = Re{S} . n at every point. estimate_normals runs no orientation
# step, so normal signs may flip between points; hence |APD| is plotted.
n = estimate_normals(xyz, knn=300, fast=True)
APD = Sx.real * n[:, 0] + Sy.real * n[:, 1] + Sz.real * n[:, 2]
# Colorbar label names the plotted quantity (previously mislabelled 'Re[S]').
fig, ax = plot_3d({'z [mm]': xyz[:, 2],
                   'x [mm]': xyz[:, 0],
                   'y [mm]': xyz[:, 1],
                   '|APD| [W/m2]': abs(APD)},
                  elev=[15], azim=[200])

In [None]:
import open3d as o3d

In [None]:
# Reference coordinate frame: axes of size 10 (data units, mm per the column
# names) drawn at the origin, to show orientation in the Open3D viewer.
cframe = o3d.geometry.TriangleMesh.create_coordinate_frame(size=10, origin=[0] * 3)

In [None]:
# Downsample uniformly by keeping every `skip`-th point, then wrap the result
# in a grey Open3D point cloud for display.
skip = 10
xyz_ds = xyz[::skip, :]
pcd_ds = o3d.geometry.PointCloud()
pcd_ds.points = o3d.utility.Vector3dVector(xyz_ds)
# Trailing ';' suppresses the cell's output repr.
pcd_ds.paint_uniform_color([0.5, 0.5, 0.5]);

In [None]:
# Show the downsampled cloud together with the reference frame.
o3d.visualization.draw_geometries([pcd_ds, cframe])

In [None]:
# Translate the downsampled cloud so the centre reported by get_center()
# moves to the origin (0, 0, 0).
center = pcd_ds.get_center()
xyz_ds_t = xyz_ds - center  # broadcasts the (3,) centre over every row
pcd_ds_t = o3d.geometry.PointCloud()
pcd_ds_t.points = o3d.utility.Vector3dVector(xyz_ds_t)
pcd_ds_t.paint_uniform_color([0.5, 0.5, 0.5]);

In [None]:
# Show the centred cloud; the reference frame now sits at its centre.
o3d.visualization.draw_geometries([pcd_ds_t, cframe])

In [None]:
# Hidden-point removal: keep only the points visible from a camera placed on
# the -z axis at distance `diameter` (the bounding-box diagonal) from the
# origin, where the centred cloud sits.
# NOTE(review): this was described as selecting "x-visible" points, but the
# camera lies on the z axis of the data coordinates; confirm intent.
diameter = np.linalg.norm(pcd_ds_t.get_max_bound() - pcd_ds_t.get_min_bound())
# Spherical-projection radius for HPR: hard-coded magic value (~6.3e3).
# TODO: consider deriving it from `diameter` (Open3D's hidden-point-removal
# tutorial uses diameter * 100) so it scales with the data.
radius = 10 ** 3.8
camera = [0, 0, -diameter]

# hidden_point_removal returns (mesh, visible-point indices); keep the indices.
_, pt_map = pcd_ds_t.hidden_point_removal(camera, radius)
pcd_ds_t_visible = pcd_ds_t.select_by_index(pt_map)

In [None]:
# Show only the points kept by hidden-point removal, with the reference frame.
o3d.visualization.draw_geometries([pcd_ds_t_visible, cframe])