In [1]:
import ipywidgets as widgets
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from sklearn.metrics import accuracy_score
from plotly.offline import init_notebook_mode
init_notebook_mode(connected=True)

In [2]:
results = np.load("../data/masking_results_pred.npy", allow_pickle=True)[()]
results.keys()

dict_keys(['Synthetic', 'FingerMovements', 'Epilepsy', 'WISDM', 'HAR', 'PenDigits'])

In [3]:
data = []
for dataset, experiment in results.items():
    for i, masked_percent in enumerate(experiment["masked_percentage"]):
        shap_acc = np.array([
            accuracy_score(experiment["labels"], shap_preds)
            for shap_preds in experiment["shap_preds"][i]
        ])
        # print(dataset, masked_percent, shap_acc)

        random_acc = np.array([
            accuracy_score(experiment["labels"], random_preds)
            for random_preds in experiment["random_preds"][i]
        ])

        data.append((
            dataset, masked_percent,
            shap_acc.mean(), shap_acc.std(), 
            random_acc.mean(), random_acc.std()
        ))

df = pd.DataFrame(
    data, columns=[
        "Dataset", "MaskedPercentage",
        "ShapAccMean", "ShapAccStd",
        "RandomAccMean", "RandomAccStd",
    ],
)
df.sample(n=5)

Unnamed: 0,Dataset,MaskedPercentage,ShapAccMean,ShapAccStd,RandomAccMean,RandomAccStd
20,Epilepsy,0.0,0.976522,1.110223e-16,0.976522,1.110223e-16
33,WISDM,0.3,0.845747,0.004917805,0.854294,0.008168448
46,HAR,0.6,0.866192,0.002719334,0.854519,0.003549112
37,WISDM,0.7,0.792837,0.002508919,0.78429,0.01308745
41,HAR,0.1,0.88395,1.110223e-16,0.881258,0.002352458


In [4]:
# First melt the accuracy means
melted_acc = df.melt(
    id_vars=["Dataset", "MaskedPercentage"],
    value_vars=["ShapAccMean", "RandomAccMean"],
    var_name="Method",
    value_name="Accuracy"
)

# Then melt the standard deviations
melted_std = df.melt(
    id_vars=["Dataset", "MaskedPercentage"],
    value_vars=["ShapAccStd", "RandomAccStd"],
    var_name="Method",
    value_name="Std"
)

# Clean up method names to match before merging
melted_acc["Method"] = melted_acc["Method"].str.replace("AccMean", "")
melted_std["Method"] = melted_std["Method"].str.replace("AccStd", "")

# Merge the two melted DataFrames
melted_df = pd.merge(
    melted_acc,
    melted_std,
    on=["Dataset", "MaskedPercentage", "Method"]
)
melted_df.sample(n=5)

Unnamed: 0,Dataset,MaskedPercentage,Method,Accuracy,Std
72,FingerMovements,0.2,Random,0.492667,0.02048306
103,HAR,0.3,Random,0.87558,0.001961722
78,FingerMovements,0.8,Random,0.501333,0.03667273
40,HAR,0.0,Shap,0.880556,1.110223e-16
91,WISDM,0.1,Random,0.865608,0.00433336


In [5]:

fig = px.line(
    melted_df,
    x="MaskedPercentage",
    y="Accuracy",
    color="Dataset",
    line_dash="Method",
    error_y="Std",
    labels={
        "MaskedPercentage": "Masked Percentage",
        "Accuracy": "Accuracy",
        "Method": "Masking Method",
    },
    markers=True,
)
fig.update_traces(hovertemplate="%{y:.4f}")
fig.update_layout(hovermode="x")

fig.show()

In [6]:
from sklearn.metrics import auc

# Calculate AUC for each (Dataset, Method) pair
auc_results = (
    melted_df
    .sort_values(["Dataset", "Method", "MaskedPercentage"])
    .groupby(["Dataset", "Method"])
    .apply(lambda g: auc(g["MaskedPercentage"], g["Accuracy"]), include_groups=False)
    .reset_index(name="AUC")
)

auc_results

Unnamed: 0,Dataset,Method,AUC
0,Epilepsy,Random,0.873817
1,Epilepsy,Shap,0.877826
2,FingerMovements,Random,0.4526
3,FingerMovements,Shap,0.445
4,HAR,Random,0.771099
5,HAR,Shap,0.781461
6,PenDigits,Random,0.864714
7,PenDigits,Shap,0.872632
8,Synthetic,Random,0.695533
9,Synthetic,Shap,0.768667


In [7]:
data = []
for dataset, experiment in results.items():
    for i, masked_percent in enumerate(experiment["masked_percentage"]):
        diversity = np.array(experiment["random_diversity"][i]).T
        preds = np.array(experiment["random_preds"][i]).T
        sorted_diversity = np.argsort(diversity, axis=1)
        for j in range(sorted_diversity.shape[1]):
            data.append((
                dataset,
                masked_percent,
                diversity[sorted_diversity == j].mean(),
                diversity[sorted_diversity == j].std(),
                accuracy_score(experiment["labels"], preds[sorted_diversity == j]),
            ))

df = pd.DataFrame(data, columns=["Dataset", "MaskedPercent", "DiversityMean", "DiversityStd", "Acc"])
df["MaskedPercent"] = df["MaskedPercent"].round(2)
df.sample(n=5)

Unnamed: 0,Dataset,MaskedPercent,DiversityMean,DiversityStd,Acc
286,FingerMovements,0.9,1.0,0.0,0.57
100,Synthetic,0.6,0.856007,0.011186,0.715
483,WISDM,0.2,0.933341,0.04647,0.863248
762,PenDigits,0.0,0.745605,0.026693,0.974557
509,WISDM,0.3,0.933734,0.046432,0.846154


In [8]:
fig = go.FigureWidget()
for (dataset, masked), rest in df.groupby(["Dataset", "MaskedPercent"]):
    fig.add_scatter(
        y=rest["Acc"],
        x=rest["DiversityMean"],
        error_x=dict(array=rest["DiversityStd"], type="data"),
        name=f"{dataset}:{masked:.02f}",
        mode="markers",
        visible=False,
    )
fig.update_layout({
    "margin": dict(l=0, r=0, t=0, b=0),
    "xaxis_title": "Diversity (Mean)",
    "yaxis_title": "Accuracy"
})

@widgets.interact(dataset=df["Dataset"].unique(), masked=df["MaskedPercent"].unique())
def update_graph(dataset, masked):
    with fig.batch_update():
        for i in range(len(fig.data)):
            fig.data[i].visible = fig.data[i].name == f"{dataset}:{masked:.02f}"

fig

interactive(children=(Dropdown(description='dataset', options=('Synthetic', 'FingerMovements', 'Epilepsy', 'WI…

FigureWidget({
    'data': [{'error_x': {'array': array([0.09269391, 0.09269391, 0.09269391, 0.09269391, 0.09269391, 0.09269391,
                                          0.09269391, 0.09269391, 0.09269391, 0.09269391, 0.09269391, 0.09269391,
                                          0.09269391, 0.09269391, 0.09269391]),
                          'type': 'data'},
              'mode': 'markers',
              'name': 'Epilepsy:0.00',
              'type': 'scatter',
              'uid': '1fa926c6-0b48-47cb-a308-97bb4dce1235',
              'visible': False,
              'x': array([0.92510167, 0.92510167, 0.92510167, 0.92510167, 0.92510167, 0.92510167,
                          0.92510167, 0.92510167, 0.92510167, 0.92510167, 0.92510167, 0.92510167,
                          0.92510167, 0.92510167, 0.92510167]),
              'y': array([0.97652174, 0.97652174, 0.97652174, 0.97652174, 0.97652174, 0.97652174,
                          0.97652174, 0.97652174, 0.97652174, 0.97652174, 0.97

In [9]:
corr_df = df.groupby(["Dataset", "MaskedPercent"])[["DiversityMean", "Acc"]].corr()

fig = go.FigureWidget(
    layout=dict(
        margin=dict(l=0, r=0, t=20, b=10),
        xaxis=dict(title="Variables", tickvals=[0, 1], ticktext=["Diversity", "Accuracy"]),
        yaxis=dict(title="Variables", tickvals=[1, 0], ticktext=["Diversity", "Accuracy"]),
    )
)

fig.add_heatmap(
    z=np.eye(2),
    colorscale="Viridis",
    showscale=True,
    colorbar=dict(title="Correlation"),
    texttemplate="%{z:.4f}",
    textfont={"size":12}
)

@widgets.interact(dataset=df["Dataset"].unique(), masked=np.arange(0.1, 1.0, step=0.1).round(2))
def update_graph(dataset, masked):
    with fig.batch_update():
        z = (corr_df.loc[(dataset, masked)].to_numpy())[::-1]
        fig.data[0].z = z

fig

interactive(children=(Dropdown(description='dataset', options=('Synthetic', 'FingerMovements', 'Epilepsy', 'WI…

FigureWidget({
    'data': [{'colorbar': {'title': {'text': 'Correlation'}},
              'colorscale': [[0.0, '#440154'], [0.1111111111111111, '#482878'],
                             [0.2222222222222222, '#3e4989'], [0.3333333333333333,
                             '#31688e'], [0.4444444444444444, '#26828e'],
                             [0.5555555555555556, '#1f9e89'], [0.6666666666666666,
                             '#35b779'], [0.7777777777777778, '#6ece58'],
                             [0.8888888888888888, '#b5de2b'], [1.0, '#fde725']],
              'showscale': True,
              'textfont': {'size': 12},
              'texttemplate': '%{z:.4f}',
              'type': 'heatmap',
              'uid': '6517e8ee-50b8-4196-b27f-e2bf9be1d4f0',
              'z': array([[0.5469662, 1.       ],
                          [1.       , 0.5469662]])}],
    'layout': {'margin': {'b': 10, 'l': 0, 'r': 0, 't': 20},
               'template': '...',
               'xaxis': {'ticktext': ['D