In [2]:
import numpy as np
import pandas as pd
import plotly.graph_objs as go
import plotly.express as ex
import bz2
import torch
import _pickle as pickle
import scipy.stats as stats
import scipy
from plotly.subplots import make_subplots
import plotly.io as pio
import sys, os
import matplotlib.pyplot as plt
import seaborn as sb
import kaleido

# Global definitions

In [3]:
RESULT_PATH = "F:/Documents/MEX/Deep-learning-based-rig-agnostic-encoding/results"
AE_SAVE_PATH = "F:/Documents/MEX/Deep-learning-based-rig-agnostic-encoding/results/organized/AE"
REF_SAVE_PATH = "F:/Documents/MEX/Deep-learning-based-rig-agnostic-encoding/results/organized/REF"
TRANS_SAVE_PATH = "F:/Documents/MEX/Deep-learning-based-rig-agnostic-encoding/results/organized/TRANS"

# Function definitions

In [4]:
def load(file_path:str):
    with bz2.BZ2File(file_path, "rb") as f:
        obj = pickle.load(f)
    return obj
def save(file:object, file_path:str):
    with bz2.BZ2File(file_path, "w") as f:
        pickle.dump(file, f)
def mean_confidence_interval(data, confidence=0.95):
    data = np.nan_to_num(data, 0)
    a = 1.0 * np.array(data)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    h = se * scipy.stats.t.ppf((1 + confidence) / 2., n-1)
    return m, m-h, m+h

# MoE vs LSTM
## AE  

In [5]:
def plot_test(results:dict):
    """
    @param:
        results : dict (
            x : array_like
            y : [y1, y2, y3 ...]
        )
    """
    fig = go.Figure()
    swatches = ex.colors.qualitative.Vivid
    for i,y in enumerate(results["y"]):
        fig.add_trace(go.Bar(x=results["x"], y=[res[0] for res in y],
            marker_color=swatches[i],
            error_y=dict(type="data", array=[res[2] for res in y], arrayminus=[res[1] for res in y])))
    return fig

def plot_stapled_bar(results:dict, title="", x_axis_name=None, y_axis_name=None, 
                    width=500, height=300, annotation_text="baseline", annotation_position="bottom right"):

    fig = ex.bar(results, x="Models", y="Percentage", color="Metrics", opacity=0.8, text="Percentage", template="seaborn")

    fig.update_layout(
        title_text=title,
        width = width, height=height, 
        font_family="Serif", font_size=14, 
        margin_l=5, margin_t=40, margin_b=5, margin_r=5)

    if x_axis_name is not None:
        fig.update_xaxes(
            title_text=x_axis_name
        )
    if y_axis_name is not None:
        fig.update_yaxes(
            title_text=y_axis_name
        )
    return fig

def save_plot(plot:object, file_path="figure.svg"):
    plot.write_image(file_path)

In [6]:
ref_results = load(os.path.join(RESULT_PATH,"reference_results.pbz2"))

In [7]:
ref_ae_moe = [rig["AE+MoE"] for rig in ref_results]
ref_ae_lstm = [rig["AE+LSTM"] for rig in ref_results]
ref_rbf_moe = [rig["RBF-CAT+MoE"] for rig in ref_results]
ref_rbf_lstm = [rig["RBF-CAT+LSTM"] for rig in ref_results]

In [8]:
ref_moe_res = dict(recon=[res["recon_error"] for res in ref_ae_moe], adv=[res["adv_error"] for res in ref_ae_moe], 
rot=[res["rot_error"] for res in ref_ae_moe], deltaRot=[res["delta_rot"] for res in ref_ae_moe])
ref_lstm_res = dict(recon=[res["recon_error"] for res in ref_ae_lstm], adv=[res["adv_error"] for res in ref_ae_lstm], 
rot=[res["rot_error"] for res in ref_ae_lstm], deltaRot=[res["delta_rot"] for res in ref_ae_lstm])

rbf_moe_res = dict(recon=[res["recon_error"] for res in ref_rbf_moe], adv=[res["adv_error"] for res in ref_rbf_moe], 
rot=[res["rot_error"] for res in ref_rbf_moe], deltaRot=[res["delta_rot"] for res in ref_rbf_moe])
rbf_lstm_res = dict(recon=[res["recon_error"] for res in ref_rbf_lstm], adv=[res["adv_error"] for res in ref_rbf_lstm], 
rot=[res["rot_error"] for res in ref_rbf_lstm], deltaRot=[res["delta_rot"] for res in ref_rbf_lstm])

# ref_ae_moe_recon = [res["recon_error"] for res in ref_ae_moe]
# ref_ae_moe_recon_mean = [mean_confidence_interval(rig_res) for rig_res in ref_ae_moe_recon]

# ref_ae_lstm_recon = [res["recon_error"] for res in ref_ae_lstm]
# ref_ae_lstm_recon_mean = [mean_confidence_interval(rig_res) for rig_res in ref_ae_lstm_recon]

In [9]:
def compute_RMS(delta_rot:list):
    rms = lambda x: np.sqrt(np.mean(x**2))
    samples = []
    for (x,y) in delta_rot:
        rms_x = rms(np.asarray(x, np.float64))
        rms_y = rms(np.asarray(y, np.float64))
        samples.append(rms_x + rms_y)
    return samples 

In [10]:
for i in range(len(ref_moe_res["deltaRot"])):
    ref_moe_res["deltaRot"][i] = compute_RMS(ref_moe_res["deltaRot"][i])

for i in range(len(ref_lstm_res["deltaRot"])):
    ref_lstm_res["deltaRot"][i] = compute_RMS(ref_lstm_res["deltaRot"][i])

for i in range(len(rbf_moe_res["deltaRot"])):
    rbf_moe_res["deltaRot"][i] = compute_RMS(rbf_moe_res["deltaRot"][i])

for i in range(len(rbf_lstm_res["deltaRot"])):
    rbf_lstm_res["deltaRot"][i] = compute_RMS(rbf_lstm_res["deltaRot"][i])

In [11]:
comp = lambda x,y : np.asarray([1 / (yi / xi) * 100 for xi,yi in zip(x,y)], np.float64)
l1 = lambda key: np.round(np.mean([comp(m1,m2) for m1,m2 in zip(ref_moe_res[key],ref_lstm_res[key])]), decimals=2)
l2 = lambda key: np.round(np.mean([comp(m1,m2) for m1,m2 in zip(rbf_moe_res[key],rbf_lstm_res[key])]), decimals=2)

y1 = [l1("recon"), l1("adv") , l1("rot"), l1("deltaRot")]
y2 = [l2("recon"), l2("adv") , l2("rot"), l2("deltaRot")]




In [12]:
swatches = ex.colors.qualitative.Vivid
# res = {"Models":["AE"]*4 + ["RBF"]*4, "%":y1+y2, "Metrics":["Recon", "ADV", "Rot", "DeltaRot"]*2}
res = {"Models":["AE"]*4 + ["RBF"]*4, "":y1+y2, "Metrics":["Recon", "ADV", "Rot", "DeltaRot"]*2}

# fig = ex.bar(res, x="Models", y="%", color="Metrics")

# res = {"Models":["AE"]*4 + ["RBF"]*4, "%":y1+y2, "Metrics":["Recon", "ADV", "Rot", "DeltaRot"]*2}
# res = {"Models":["AE"]*4 , "%":y1, "Metrics":["Recon", "ADV", "Rot", "DeltaRot"]}
# res2 = {"Models":["RBF"]*4 , "%":y2, "Metrics":["Recon", "ADV", "Rot", "DeltaRot"]}

fig = plot_stapled_bar(results=res, title="Generation perf of LSTM compared to MoE", annotation_text="", annotation_position="bottom") 
# fig = sb.histplot(x="Models", data=res, hue="Metrics" )
# fig2 = sb.histplot(x="Models", data=res2, hue="Metrics")
fig.show()

In [13]:
save_plot(fig)

In [17]:
# Approach 1 - flattened

# Approach 2 - compute per rig then average

comp_pval = lambda x,y : stats.kruskal(x,y)[1]
l1 = lambda key: np.mean([comp_pval(m1,m2) for m1,m2 in zip(ref_moe_res[key],ref_lstm_res[key])])
l2 = lambda key: np.mean([comp_pval(m1,m2) for m1,m2 in zip(rbf_moe_res[key],rbf_lstm_res[key])])

y1 = [l1("recon"), l1("adv") , l1("rot"), l1("deltaRot")]
y2 = [l2("recon"), l2("adv") , l2("rot"), l2("deltaRot")]


In [18]:
print(y1)

[5.455470728901324e-42, 0.09270252142344164, 1.5319303418694827e-32, 1.2241370411395197e-06]


In [35]:
fig = go.Figure()
swatches = ex.colors.qualitative.Vivid
x_vals = ["R"+str(i) for i in range(1,6)]
y1 = ref_ae_moe_recon
y2 = ref_ae_lstm_recon
y_vals_raw = [[(i22-i11)/i11 * 100 for i11, i22 in zip(i1,i2)] for i1,i2 in zip(y1,y2)]
y_vals_pval = [stats.kruskal(row, np.random.normal(0,5,len(row)))[1] for row in y_vals_raw]
y_vals = [mean_confidence_interval(row) for row in y_vals_raw]


In [127]:
params = ["recon_error", "adv_error", "pos_error", "rot_error", "delta_pos", "delta_rot"]
delta_pos = ref_ae_moe[0]["delta_pos"]
delta_rot = ref_ae_moe[0]["delta_rot"]
delta_rot2 = ref_ae_lstm[0]["delta_rot"]

In [161]:
y1 = [x+y for x,y in zip(delta_rot[100][0],delta_rot[100][1])]
y2 = [x+y for x,y in zip(delta_rot2[100][0],delta_rot2[100][1])]
yy1 = [y if y >= 0.01 else 0 for y in y1]
yy2 = [y if y >= 0.01 else 0 for y in y2]

fig = ex.line(x=np.arange(298), y=y1)
fig.add_trace(go.Scatter(x=np.arange(298), y=y2))
fig.add_trace(go.Scatter(x=np.arange(298), y=yy1))
fig.add_trace(go.Scatter(x=np.arange(298), y=yy2))
fig.add_hline(y=l1(yyy1), x0=0, x1=298)
fig.add_hline(y=l1(yyy2), x0=0, x1=298)
fig.show()

In [144]:
grad = np.gradient(np.gradient(yy1))
grad2 = np.gradient(np.gradient(yy2))
inflec = np.where(np.diff(np.sign(grad)))[0]
inflec2 = np.where(np.diff(np.sign(grad2)))[0]

fig=ex.line(y=grad)
fig.add_trace(go.Scatter(y=grad2))
print(len(inflec))
print(len(inflec2))
print((len(inflec2) - len(inflec)) / len(inflec))

51
73
0.43137254901960786


In [160]:
m1 = np.mean(y1)
m2 = np.mean(y2)
yyy1 = [y-m1 for y in yy1]
yyy2 = [y-m2 for y in yy2]
print(np.sum(np.abs(yyy1)), np.sum(np.abs(yyy2)))
print(np.sum(np.abs(y1)), np.sum(np.abs(y2)))
l1 = lambda x: np.sqrt(np.mean([i**2 for i in x]))
print(l1(yyy1), l1(yyy2))

11.427347792312503 24.84663025289774
11.232755 23.36922
0.04538840771473632 0.10520759829718838


135
135


In [115]:
N = 298
beta = 13
fft = np.fft.fft(y1)
fft2 = np.fft.fft(y2)

psd = 1/297 * np.abs(np.sum([(v * np.exp(-2j*np.pi*beta) / N) for v in fft]))
psd2 = 1/297 * np.abs(np.sum([(v * np.exp(-2j*np.pi*beta) / N) for v in fft2]))
print(psd, psd2)

0.00018197285506861777 6.811912814374731e-05


# Test Range

In [9]:
ae_results = load(os.path.join(RESULT_PATH,"ae_results.pbz2"))
ae_dfs = load(os.path.join(RESULT_PATH,"ae_dfs.pbz2"))


In [16]:
aeRes = ae_results[0]
print(len(ae_results))
print(aeRes.keys())
print((aeRes["name"]))
print((aeRes["params"]))
print((aeRes["mem"]))
print(len(aeRes["elapsed_times"]))
print(len(aeRes["recon_error"]))
print(len(aeRes["rot_error"]))
print(len(aeRes["delta_pos"]))

5
dict_keys(['name', 'params', 'mem', 'elapsed_times', 'recon_error', 'adv_error', 'pos_error', 'rot_error', 'delta_pos', 'delta_rot'])
AE_R1
455565
1.82226
192
192
192
192


In [27]:
ae_dfs[0]


Unnamed: 0,name,params,mem,time,recon_err,adv_err,rot_err,sum_delta_rot_x,sum_delta_rot_y
0,AE_R1,455565,1.82226,0.010796,0.264227,22.558304,8.531455,6.408974,4.966479


In [28]:
ref_results = load(os.path.join(RESULT_PATH,"reference_results.pbz2"))

In [60]:
aeRes = ref_results[0]
print(len(ref_results))
print(len(aeRes.keys()))
print(aeRes.keys())
aeMoe = aeRes["AE+MoE"]
print((aeMoe["name"]))
print((aeMoe["params"]))
print((aeMoe["mem"]))
print(len(aeMoe["elapsed_times"]))
print(len(aeMoe["recon_error"]))
print(len(aeMoe["rot_error"]))
print(len(aeMoe["delta_pos"]))
print(len(ref_dfs))
print(aeMoe["recon_error"][:10])

5
12
dict_keys(['AE+MoE', 'AE+LSTM', 'RBF-CAT+LSTM', 'RBF-IN+LSTM', 'RBF-CAT+MoE', 'RBF-IN+MoE', 'VAE-CAT+LSTM', 'VAE-IN+LSTM', 'VAE-CAT+MoE', 'VAE-IN+MoE', 'DEC-CAT+MoE', 'DEC-IN+MoE'])
AE_MoE_256_ZINF_R1
2946193
11.784772
192
192
192
192
5
[tensor(0.0111), tensor(0.0199), tensor(0.0137), tensor(0.0185), tensor(0.0084), tensor(0.0111), tensor(0.0202), tensor(0.0166), tensor(0.0112), tensor(0.0189)]


In [74]:
ref_mean = np.mean(ae_results[0]["recon_error"])
aeMoe = aeRes["AE+MoE"]
aeLstm = aeRes["AE+LSTM"]
errMoe = torch.stack(aeMoe["recon_error"]).numpy() 
errLstm = torch.stack(aeLstm["recon_error"]).numpy()

In [75]:
fig = ex.bar(x=["AE+MoE", "AE+LSTM"], y=[(np.mean(errMoe)/ref_mean)*100, (np.mean(errLstm)/ref_mean)*100])
fig.show()

In [79]:
stats.kruskal(errMoe, errLstm)

KruskalResult(statistic=277.0583096590908, pvalue=3.2858923007027716e-62)

In [76]:
fig = go.Figure()
fig.add_trace(go.Scatter(x=np.arange(192), y=ae_results[0]["recon_error"], name="AE"))
fig.add_trace(go.Scatter(x=np.arange(192), y=errMoe, name="MoE"))
fig.add_trace(go.Scatter(x=np.arange(192), y=errLstm, name="LSTM"))
fig.show()

In [32]:
trans_results = load(os.path.join(RESULT_PATH,"transfer_results.pbz2"))
trans_dfs = load(os.path.join(RESULT_PATH,"transfer_dfs.pbz2"))


In [56]:
aeRes = trans_results[0]
print(len(trans_results))
print(len(aeRes.keys()))
print(aeRes.keys())
aeMoe = aeRes["AE+MoE_RAW"]
print((aeMoe["name"]))
print((aeMoe["params"]))
print((aeMoe["mem"]))
print(len(aeMoe["elapsed_times"]))
print(len(aeMoe["recon_error"]))
print(len(aeMoe["rot_error"]))
print(len(aeMoe["delta_pos"]))
print(len(trans_dfs))

4
36
dict_keys(['AE+LSTM_RAW', 'AE+MoE_RAW', 'DEC-CAT+MoE_RAW', 'DEC-IN+MoE_RAW', 'RBF-IN+LSTM_RAW', 'RBF-CAT+LSTM_RAW', 'RBF-CAT+MoE_RAW', 'RBF-IN+MoE_RAW', 'VAE-IN+LSTM_RAW', 'VAE-CAT+LSTM_RAW', 'VAE-CAT+MoE_RAW', 'VAE-IN+MoE_RAW', 'AE+LSTM_F', 'AE+LSTM_T', 'AE+MoE_F', 'AE+MoE_T', 'DEC-CAT+MoE_F', 'DEC-CAT+MoE_T', 'DEC-IN+MoE_F', 'DEC-IN+MoE_T', 'RBF-CAT+LSTM_T', 'RBF-CAT+LSTM_F', 'RBF-IN+LSTM_F', 'RBF-IN+LSTM_T', 'RBF-CAT+MoE_F', 'RBF-CAT+MoE_T', 'RBF-IN+MoE_F', 'RBF-IN+MoE_T', 'VAE-CAT+LSTM_T', 'VAE-CAT+LSTM_F', 'VAE-IN+LSTM_F', 'VAE-IN+LSTM_T', 'VAE-CAT+MoE_F', 'VAE-CAT+MoE_T', 'VAE-IN+MoE_F', 'VAE-IN+MoE_T'])
AE_MoE_256_AE_0.10_RAW_F_R2_ZIN
2946193
11.784772
807
807
807
807
4


In [51]:
print(aeMoe["recon_error"][:10])

[tensor(0.0635), tensor(0.1379), tensor(0.0331), tensor(0.0451), tensor(0.0306), tensor(0.0651), tensor(0.1002), tensor(0.0375), tensor(0.0402), tensor(0.0471)]


In [52]:
reduc_results = load(os.path.join(RESULT_PATH,"transfer_results_reduc.pbz2"))


In [57]:
aeRes = reduc_results[0]
print(len(reduc_results))
print(len(aeRes.keys()))
print(aeRes.keys())
aeMoe = aeRes["AE+MoE_F"]
print((aeMoe["name"]))
print((aeMoe["params"]))
print((aeMoe["mem"]))
print(len(aeMoe["elapsed_times"]))
print(len(aeMoe["recon_error"]))
print(len(aeMoe["rot_error"]))
print(len(aeMoe["delta_pos"]))

4
8
dict_keys(['AE+MoE_T', 'AE+MoE_F', 'DEC-CAT+MoE_F', 'DEC-CAT+MoE_T', 'RBF-CAT+MoE_F', 'RBF-CAT+MoE_T', 'VAE-CAT+MoE_F', 'VAE-CAT+MoE_T'])
AE_MoE_256_AE_0.12_R1_to_F_R2_ZIN_reduced
2869393
11.477572
807
807
807
807
