In [1]:
import numpy as np
import pandas as pd
import plotly.graph_objs as go
import plotly.express as ex
import bz2
import torch
import _pickle as pickle
import scipy.stats as stats
import scipy
from plotly.subplots import make_subplots
import plotly.io as pio
import sys, os

# Global Variables

In [2]:
# Root folder holding every experiment result. The three folders below are
# derived from it, so relocating the results only requires editing this line.
RESULT_PATH = "F:/Documents/MEX/Deep-learning-based-rig-agnostic-encoding/results"
REF_SAVE_PATH = RESULT_PATH + "/organized/REF"
TRANS_SAVE_PATH = RESULT_PATH + "/organized/TRANS"
REDUC_SAVE_PATH = RESULT_PATH + "/organized/REDUC"
# Keys read from each loaded result object when models are compared.
METRICS = ["recon_error", "adv_error", "rot_error", "delta_rot"]


# Functions

In [32]:
def pJoin(p1, p2):
    """Join two path components using ``os.path.join`` and return the result."""
    joined = os.path.join(p1, p2)
    return joined

def load(file_path:str):
    """Read the bz2-compressed pickle at ``file_path`` and return the stored object.

    Unpickling can execute arbitrary code, so only open files you trust.
    """
    handle = bz2.BZ2File(file_path, "rb")
    try:
        return pickle.load(handle)
    finally:
        # Always release the file handle, even if unpickling fails.
        handle.close()
def save(file:object, file_path:str):
    """Pickle ``file`` and write it, bz2-compressed, to ``file_path``.

    The target is opened in "w" mode, so an existing file is overwritten.
    """
    sink = bz2.BZ2File(file_path, "w")
    try:
        pickle.dump(file, sink)
    finally:
        # Closing flushes the compressed stream to disk.
        sink.close()
def mean_confidence_interval(data, confidence=0.95):
    """Return the mean of ``data`` and its Student-t confidence interval.

    Args:
        data: Array-like of samples. NaN entries are counted as 0 (and
            ``np.nan_to_num`` also replaces infinities with large finite
            values). The caller's array is not modified.
        confidence: Two-sided confidence level, e.g. ``0.95``.

    Returns:
        Tuple ``(mean, lower, upper)`` where the bounds are
        ``mean -/+ sem * t.ppf((1 + confidence) / 2, n - 1)``.
    """
    # `nan` must be passed by keyword: the second positional parameter of
    # np.nan_to_num is `copy`, so `nan_to_num(data, 0)` meant copy=False and
    # overwrote the NaNs inside the caller's array.
    cleaned = np.nan_to_num(data, nan=0.0)
    a = 1.0 * np.asarray(cleaned)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    # Half-width of the interval, using n - 1 degrees of freedom.
    h = se * scipy.stats.t.ppf((1 + confidence) / 2., n - 1)
    return m, m - h, m + h

def plot_bar(results:dict, title="", x_axis_name=None, y_axis_name=None,
             width=500, height=300):
    """Build a grouped bar chart from ``results`` with plotly express.

    ``results`` is handed to ``ex.bar`` and must provide "Models" (x axis),
    "Percentage" (bar height and bar label) and "Metrics" (bar colour).

    Args:
        results: Data passed straight to ``ex.bar``.
        title: Figure title.
        x_axis_name: Optional x-axis title; the axis is left as-is when None.
        y_axis_name: Optional y-axis title; the axis is left as-is when None.
        width: Figure width passed to the layout.
        height: Figure height passed to the layout.

    Returns:
        The figure created by ``ex.bar``, with the layout applied.
    """
    fig = ex.bar(
        results,
        x="Models",
        y="Percentage",
        color="Metrics",
        opacity=0.8,
        text="Percentage",
        template="seaborn",
    )

    layout = dict(
        title_text=title,
        width=width,
        height=height,
        font_family="Serif",
        font_size=14,
        margin_l=5,
        margin_t=40,
        margin_b=5,
        margin_r=5,
    )
    fig.update_layout(**layout)

    # Axis titles are only overridden when the caller supplies a name.
    axis_updates = (
        (x_axis_name, fig.update_xaxes),
        (y_axis_name, fig.update_yaxes),
    )
    for axis_title, update_axis in axis_updates:
        if axis_title is not None:
            update_axis(title_text=axis_title)
    return fig

def save_plot(plot:object, file_path="figure.svg"):
    """Export ``plot`` to ``file_path`` by calling its ``write_image`` method."""
    export = plot.write_image
    export(file_path)

def _percent_ratio(base_rig, target_rig, multiplier):
    """Return element-wise ``base / target * multiplier`` as a float64 array."""
    base_arr = np.asarray(base_rig, dtype=np.float64)
    target_arr = np.asarray(target_rig, dtype=np.float64)
    # Samples are paired by position; an unmatched tail is dropped, exactly
    # as pairing the two sequences with zip() would do.
    n = min(len(base_arr), len(target_arr))
    return base_arr[:n] / target_arr[:n] * multiplier


def calc_diff(*results, multiplier=100):
    """Compare one or more result sets against a baseline, rig by rig.

    Args:
        *results: Two or more result collections. The first is the baseline;
            every following one is a target compared against it. Each
            collection is a sequence with one entry per rig, and each entry
            is a sequence of samples.
        multiplier: Scale applied to the ``base / target`` ratio
            (100 expresses it as a percentage).

    Returns:
        Dict with two lists holding one value per target:
        ``"comp"`` is the mean of ``base / target * multiplier`` over all
        rigs and samples, rounded to 2 decimals; ``"pval"`` is the mean over
        rigs of the Kruskal-Wallis p-value between target and baseline.

    Raises:
        ValueError: If fewer than two collections are given, or the baseline
            holds fewer than two rigs.
    """
    if len(results) < 2:
        raise ValueError("Need at least 2 arrays for comparison")
    if len(results[0]) < 2:
        raise ValueError("The result array should contain results from each rig")

    base = results[0]
    comparison = {"comp": [], "pval": []}
    for target in results[1:]:
        rig_pairs = list(zip(base, target))
        # The ratio is computed directly as base / target on NumPy arrays.
        # The former `1 / (target / base)` raised ZeroDivisionError on plain
        # Python numbers whenever a baseline sample was 0, although the
        # ratio itself is simply 0 there. A zero in the target now produces
        # inf/nan with a NumPy warning, matching the ndarray-input behaviour.
        comp_avg_over_rig = np.round(
            np.mean([_percent_ratio(b, t, multiplier) for b, t in rig_pairs]),
            decimals=2,
        )
        pval_avg_over_rig = np.mean([stats.kruskal(t, b).pvalue for b, t in rig_pairs])
        comparison["comp"].append(comp_avg_over_rig)
        comparison["pval"].append(pval_avg_over_rig)
    return comparison




# Generation performance

## MoE vs LSTM

In [36]:
# Files directly inside REF_SAVE_PATH, sorted so the listing is stable:
# os.walk gives no ordering guarantee, and looping over it kept only the
# names from the last directory visited.
file_names = sorted(next(os.walk(REF_SAVE_PATH), (None, [], []))[2])

# Load by explicit file name instead of by position in `file_names`. With the
# directory contents recorded in this notebook, positions 2 and 3 are the
# DEC-CAT+MoE / DEC-IN+MoE files, not the RBF MoE/LSTM pair.
ae_moe = load(pJoin(REF_SAVE_PATH, "AE+MoE.pbz2"))
ae_lstm = load(pJoin(REF_SAVE_PATH, "AE+LSTM.pbz2"))

# NOTE(review): the folder holds both RBF-CAT and RBF-IN variants; RBF-CAT is
# assumed here - confirm which one the comparison is meant to use.
rbf_moe = load(pJoin(REF_SAVE_PATH, "RBF-CAT+MoE.pbz2"))
rbf_lstm = load(pJoin(REF_SAVE_PATH, "RBF-CAT+LSTM.pbz2"))

In [39]:
# Show which files were found under REF_SAVE_PATH.
print(file_names)

['AE+LSTM.pbz2', 'AE+MoE.pbz2', 'DEC-CAT+MoE.pbz2', 'DEC-IN+MoE.pbz2', 'RBF-CAT+LSTM.pbz2', 'RBF-CAT+MoE.pbz2', 'RBF-IN+LSTM.pbz2', 'RBF-IN+MoE.pbz2', 'VAE-CAT+LSTM.pbz2', 'VAE-CAT+MoE.pbz2', 'VAE-IN+LSTM.pbz2', 'VAE-IN+MoE.pbz2']


In [37]:
# Per-metric comparison of the LSTM results against the MoE baseline:
# `df` holds the AE pair, `df2` the RBF pair.
df = {metric: calc_diff(ae_moe[metric], ae_lstm[metric]) for metric in METRICS}
df2 = {metric: calc_diff(rbf_moe[metric], rbf_lstm[metric]) for metric in METRICS}


In [38]:

# One bar per (model, metric) pair. The label lists are sized from METRICS
# rather than a hard-coded 4, so all three columns keep the same length if
# the metric list changes.
df_formatted = dict(
    Models=["AE+LSTM"] * len(METRICS) + ["RBF+LSTM"] * len(METRICS),
    Percentage=(
        [df[metric]["comp"][0] for metric in METRICS]
        + [df2[metric]["comp"][0] for metric in METRICS]
    ),
    Metrics=METRICS * 2,
)
fig = plot_bar(results=df_formatted, title="Generation perf of LSTM compared to MoE")
fig.show()