In [131]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

In [132]:
import warnings
warnings.filterwarnings("ignore")

In [133]:
# Trace colours: an opaque line colour plus a 20%-alpha variant used to fill
# the +/- one-std band around each mean line. Blue for ALE, red for PD.
ALE_COLOR, ALE_COLOR_LIGHT = 'rgba(30,144,255,1)', 'rgba(30,144,255,0.2)'
PD_COLOR, PD_COLOR_LIGHT = 'rgba(220,20,60,1)', 'rgba(220,20,60,0.2)'

In [134]:
# Aggregated experiment results, one row per run.
CSV_PATH = "results_final/all_rows.csv"

# Human-readable labels for each metric key. A key `k` selects the columns
# `ale_<k>` and `pd_<k>` of the results table.
METRIC_NAMES = dict(
    rho="Spearman rho",
    l1="L1",
    l2="L2",
    max_diff="maximum distance (Linf)",
)

In [135]:
df = (
    pd.read_csv(CSV_PATH)
    # 'Unnamed: 0' is the header pandas assigns to an unnamed first column
    # (presumably the index written by an earlier `to_csv`). Drop it BEFORE
    # de-duplicating, otherwise rows that differ only in that column would
    # never be recognised as duplicates.
    .drop(columns=['Unnamed: 0'])
    .drop_duplicates()
)
df.head()

Unnamed: 0,name,path,variable,size,seed,lr,iter,constrain,dist_weight,ale_l2,ale_l1,ale_max_diff,ale_rho,pd_l2,pd_l1,pd_max_diff,pd_rho
0,heart,results_final/heart/age_512_0_gradient_0.1_50_...,age,512,0,0.1,50,False,0.01,0.0256,0.0256,0.067292,-0.859558,0.038559,0.038559,0.078188,-0.994805
1,heart,results_final/heart/age_256_0_gradient_0.1_50_...,age,256,0,0.1,50,False,0.01,0.020159,0.020159,0.043379,-0.842653,0.041103,0.041103,0.104379,-0.885714
2,heart,results_final/heart/age_128_0_gradient_0.1_50_...,age,128,0,0.1,50,False,0.01,0.018468,0.018468,0.044442,-0.43433,0.025585,0.025585,0.052349,-0.984416
3,heart,results_final/heart/age_64_0_gradient_0.1_50_0...,age,64,0,0.1,50,False,0.01,0.011359,0.011359,0.023311,-0.304291,0.011301,0.011301,0.028792,-0.976623
5,heart,results_final/heart/age_32_0_gradient_0.1_50_0...,age,32,0,0.1,50,False,0.01,0.024567,0.024567,0.087358,-0.890767,0.033288,0.033288,0.084878,-0.942857


In [136]:
def _band_trace(x, mean, spread, fillcolor):
    """Build a borderless shaded band from ``mean - spread`` to ``mean + spread``.

    The polygon runs along ``x`` on the upper edge and back along reversed
    ``x`` on the lower edge, which plotly's ``fill='toself'`` closes into a band.
    """
    upper = list(mean + spread)
    lower = list(mean - spread)
    xs = list(x)
    return go.Scatter(
        x=xs + xs[::-1],          # x, then x reversed
        y=upper + lower[::-1],    # upper edge, then lower edge reversed
        fill='toself',
        fillcolor=fillcolor,
        line=dict(color='rgba(255,255,255,0)'),  # invisible outline
        hoverinfo="skip",
        showlegend=False,
    )


def _base_plot(name, variable, metric="rho", y="size", y_desc="network size", logarithmic=True, conditions=None, title_suffix=""):
    """Plot mean +/- one std of the ALE and PD ``metric`` against column ``y``.

    Rows of the global ``df`` are restricted to ``df.name == name``,
    ``df.variable == variable`` and every boolean mask in ``conditions``. They
    are then grouped by ``y``. For each group the mean of ``ale_<metric>`` and
    ``pd_<metric>`` is drawn as a line, with a shaded band of +/- one standard
    deviation. The figure is displayed with ``fig.show()``; nothing is returned.

    Args:
        name: value of the ``name`` column to select.
        variable: value of the ``variable`` column to select.
        metric: key of ``METRIC_NAMES``; selects columns ``ale_<metric>`` and
            ``pd_<metric>``.
        y: column used both as the grouping key and as the x axis.
        y_desc: x-axis label.
        logarithmic: if True, the x axis uses a log scale. Plotly cannot place
            non-positive x values on a log axis, so such groups are not shown.
        conditions: iterable of boolean Series aligned with ``df``; all are
            AND-ed together. ``None`` (default) means ``lr == 0.1``,
            ``iter == 50`` and ``dist_weight == 0.01``.
        title_suffix: text appended to the figure title.
    """
    if conditions is None:
        # Built at call time rather than as a default argument: a default
        # argument is evaluated once, at `def` time, and would keep masks of
        # whatever `df` existed then even after `df` is reloaded.
        conditions = [(df.lr == 0.1), (df.iter == 50), (df.dist_weight == 0.01)]

    mask = (df.name == name) & (df.variable == variable)
    for condition in conditions:
        mask = mask & condition

    ale_col, pd_col = f"ale_{metric}", f"pd_{metric}"
    # Aggregate only the two metric columns: mean()/std() over the whole frame
    # also hits the string columns (name, path, variable), which raises a
    # TypeError on pandas >= 2.0.
    grouped = df[mask].groupby(y)[[ale_col, pd_col]]
    means = grouped.mean()
    # A group holding a single run has an undefined (NaN) sample std, which
    # would break the band polygon; collapse the band to the line instead.
    stds = grouped.std().fillna(0)
    x = list(means.index)  # group keys, sorted by groupby

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x, y=means[ale_col],
                             mode='lines',
                             name=f'ALE {METRIC_NAMES[metric]}',
                             line=dict(color=ALE_COLOR)))
    fig.add_trace(go.Scatter(x=x, y=means[pd_col],
                             mode='lines',
                             name=f'PD {METRIC_NAMES[metric]}',
                             line=dict(color=PD_COLOR)))
    fig.add_trace(_band_trace(x, means[ale_col], stds[ale_col], ALE_COLOR_LIGHT))
    fig.add_trace(_band_trace(x, means[pd_col], stds[pd_col], PD_COLOR_LIGHT))

    if logarithmic:
        fig.update_xaxes(type="log")

    fig.update_layout(
        title=f"{METRIC_NAMES[metric]} for {name}, variable {variable}{title_suffix}",
        xaxis_title=f"{y_desc}{' (logarithmic scale)' if logarithmic else ''}",
        yaxis_title=f"{metric} values",
        font=dict(
            family="Courier New, monospace",
            size=18,
            color="black"
        )
    )

    fig.show()
    

In [137]:
def size_plot(name, variable, dist_weight=0.01, metric="rho", logarithmic=True):
    """Plot the ALE/PD ``metric`` against network size for one ``dist_weight``.

    Only rows of ``df`` with ``lr == 0.1``, ``iter == 50`` and the requested
    ``dist_weight`` are used; the value of ``dist_weight`` is shown in the title.
    """
    filters = [
        df.lr == 0.1,
        df.iter == 50,
        df.dist_weight == dist_weight,
    ]
    _base_plot(
        name,
        variable,
        metric=metric,
        y="size",
        y_desc="network size",
        logarithmic=logarithmic,
        conditions=filters,
        title_suffix=f", dist_weight={dist_weight}",
    )


In [138]:
def weight_plot(name, variable, sizes, metric="rho", logarithmic=True):
    """Plot the ALE/PD ``metric`` against the distribution-distance loss weight.

    Rows of ``df`` with ``lr == 0.1``, ``iter == 50`` and a ``size`` in
    ``sizes`` are pooled. Each point is therefore the mean (and each band the
    std) over all selected sizes and seeds sharing one ``dist_weight``.

    Args:
        name: value of the ``name`` column to select.
        variable: value of the ``variable`` column to select.
        sizes: network sizes to pool. A single size or any iterable is accepted.
        metric: key of ``METRIC_NAMES``.
        logarithmic: log-scale x axis. NOTE: plotly drops non-positive x values
            on a log axis, so runs with ``dist_weight == 0`` are not visible
            when this is True.
    """
    # Normalise to a list: `isin` rejects scalars and would exhaust a
    # generator that is also needed for the title below.
    sizes = [sizes] if pd.api.types.is_scalar(sizes) else list(sizes)
    conditions = [(df.lr == 0.1), (df.iter == 50), (df["size"].isin(sizes))]
    _base_plot(
        name=name,
        variable=variable,
        metric=metric,
        y="dist_weight",
        y_desc="weight of distribution distance loss",
        logarithmic=logarithmic,
        conditions=conditions,
        # Name the pooled sizes in the title (as size_plot names its
        # dist_weight) so figures for different size sets are distinguishable.
        title_suffix=", sizes=" + ",".join(str(s) for s in sizes),
        )


In [139]:
# One size plot per distinct (dataset, variable) pair, using size_plot's default dist_weight.
# itertuples avoids rebinding IPython's `_` (last-output cache) as a throwaway name.
for row in df[["name", "variable"]].drop_duplicates().itertuples(index=False):
    size_plot(row.name, row.variable)

In [140]:
# Same size plots, restricted to runs without the distribution-distance loss (dist_weight=0).
for row in df[["name", "variable"]].drop_duplicates().itertuples(index=False):
    size_plot(row.name, row.variable, dist_weight=0)

In [141]:
# Metric vs. loss weight per (dataset, variable) pair, pooling network sizes 32, 64 and 128.
for row in df[["name", "variable"]].drop_duplicates().itertuples(index=False):
    weight_plot(row.name, row.variable, sizes=[32, 64, 128])