In [1]:
import sys
from pathlib import Path
sys.path.append("..")

import ipywidgets as widgets
import numpy as np
import plotly.io as pio
import plotly.graph_objects as go
from plotly.offline import init_notebook_mode
from tqdm import tqdm
init_notebook_mode(connected=True)

from scripts.utils import load_finetuned_model_data, get_device
from src.linear_models import LinearClassifier
from src.masking import evaluate_masking_strategies, shapley_order
from src.utils.visualization import plot_shap_values, plot_masking_result

In [2]:
# Directory holding the fine-tuned model artefacts of the run being evaluated.
MODELS_DIR = Path("../models/run1/")

# The file stores a dict inside a NumPy object array; indexing with `[()]`
# unwraps it. NOTE(review): allow_pickle=True unpickles the file content, so
# this must only ever be pointed at trusted files.
results = np.load("../data/shap_values_probs.npy", allow_pickle=True)[()]

# Keep only the SHAP values of each dataset, materialised as a plain list.
results = {
    dataset: {"shap_values": list(entry["shap_values"])}
    for dataset, entry in results.items()
}
results.keys()

dict_keys(['Epilepsy', 'FingerMovements', 'WISDM', 'Synthetic', 'PenDigits', 'HAR'])

In [3]:
# Resolve the device once up front: it does not depend on the dataset, so
# there is no reason to query it again on every iteration.
device = get_device()

accuracy_results = {}
for dataset, data in tqdm(results.items(), desc="Computing evaluation", total=len(results)):
    # Merge the fine-tuned model data into the per-dataset entry of `results`.
    # Until this point each entry only holds "shap_values"; the later
    # evaluation cells read data["classifier"] and data["test"] from
    # `results`, so this cell has to run before them.
    model_data = load_finetuned_model_data(MODELS_DIR, dataset, device)
    data.update(model_data)

    classifier: LinearClassifier = model_data["classifier"]
    accuracy_results[dataset] = evaluate_masking_strategies(
        classifier,
        data["test"]["timesteps"],
        data["test"]["labels"],
        data["shap_values"],
        metric="accuracy",
    )

Computing evaluation: 100%|██████████| 6/6 [00:59<00:00,  9.99s/it]


In [4]:
# Masking evaluation with the "f1" metric, keyed by dataset name. Relies on the
# "classifier" and "test" entries that an earlier cell merged into `results`.
f1_results = {
    dataset: evaluate_masking_strategies(
        data["classifier"],
        data["test"]["timesteps"],
        data["test"]["labels"],
        data["shap_values"],
        metric="f1",
    )
    for dataset, data in tqdm(results.items(), desc="Computing evaluation", total=len(results))
}

Computing evaluation: 100%|██████████| 6/6 [00:54<00:00,  9.16s/it]


In [5]:
# Masking evaluation with the "auc_ovo" metric, keyed by dataset name. Relies on
# the "classifier" and "test" entries that an earlier cell merged into `results`.
auc_ovo_results = {}
progress = tqdm(results.items(), desc="Computing evaluation", total=len(results))
for dataset, data in progress:
    test_split = data["test"]
    auc_ovo_results[dataset] = evaluate_masking_strategies(
        data["classifier"],
        test_split["timesteps"],
        test_split["labels"],
        data["shap_values"],
        metric="auc_ovo",
    )

Computing evaluation: 100%|██████████| 6/6 [01:07<00:00, 11.22s/it]


In [6]:
# Masking fractions used as the x-axis of every masking plot: 0.0, 0.1, ..., 0.9.
# NOTE(review): presumably has to match the fractions used inside
# evaluate_masking_strategies — confirm.
masked_percentages = np.linspace(0, 0.9, 10)

# Dataset whose curves are hidden in the overview figures, and the pixel width
# of each of the two side-by-side figures.
HIDDEN_DATASET = "FingerMovements"
FIGURE_WIDTH = 560


def _masking_figure(metric_results, yaxis_title):
    """Plot one curve per (dataset, method) pair for a single metric.

    Args:
        metric_results: Mapping ``dataset -> {method -> scores}`` as stored in
            ``accuracy_results`` / ``f1_results``.
        yaxis_title: Label of the y-axis.

    Returns:
        The figure created by ``plot_masking_result`` with the legend turned
        off and every trace of ``HIDDEN_DATASET`` made invisible.
    """
    figure = plot_masking_result(
        masked_percentages,
        *[
            (f"{dataset},{method}", scores)
            for dataset, per_method in metric_results.items()
            for method, scores in per_method.items()
        ],
    )
    for trace in figure.data:
        if HIDDEN_DATASET in trace.name:
            trace.visible = False
    figure.update_layout(showlegend=False, yaxis_title=yaxis_title)
    return figure


fig1 = _masking_figure(accuracy_results, "Accuracy")
fig2 = _masking_figure(f1_results, "F1-Score")

# `go` and `widgets` come from the import cell at the top of the notebook.
# update_layout returns the widget itself, so the children are FigureWidgets.
widgets.HBox(
    layout=widgets.Layout(width="100%", justify_content="space-between"),
    children=[
        go.FigureWidget(fig1).update_layout(width=FIGURE_WIDTH),
        go.FigureWidget(fig2).update_layout(width=FIGURE_WIDTH),
    ],
)

HBox(children=(FigureWidget({
    'data': [{'hovertemplate': '%{y:.4f}',
              'legendgroup': 'Epileps…

In [7]:
# Collect one (label, scores) pair per dataset/method combination.
auc_ovo_series = []
for dataset, per_method in auc_ovo_results.items():
    for method, metric in per_method.items():
        auc_ovo_series.append((f"{dataset},{method}", metric))

fig3 = plot_masking_result(masked_percentages, *auc_ovo_series)

# Hide every FingerMovements curve; all other traces keep their visibility.
for trace in fig3.data:
    if "FingerMovements" not in trace.name:
        continue
    trace.visible = False

# Kept as the last expression so that the cell displays the figure.
fig3.update_layout(showlegend=False, yaxis_title="AUC OVO")

In [None]:
from src.masking import masking_impact

def auc_ovr(classifier, test_timesteps, test_labels, shap_values, percentages=None):
    """Run ``masking_impact`` with the "auc_ovr" metric once per SHAP result.

    Args:
        classifier: Classifier forwarded unchanged to ``masking_impact``.
        test_timesteps: Test inputs forwarded unchanged to ``masking_impact``.
        test_labels: Test labels; forwarded to ``masking_impact`` and also used
            to index the output of ``shapley_order``.
            NOTE(review): presumably this picks, for each test sample, the
            ordering that belongs to its true class — confirm.
        shap_values: Iterable of SHAP results; each element is evaluated
            separately.
        percentages: Masking fractions to evaluate. Defaults to the
            notebook-level ``masked_percentages`` (looked up at call time), so
            existing four-argument calls behave exactly as before.

    Returns:
        ``np.ndarray`` stacking the ``masking_impact`` result of every element
        of ``shap_values`` along the first axis.
    """
    if percentages is None:
        percentages = masked_percentages
    return np.array([
        masking_impact(
            classifier, test_timesteps, test_labels,
            # A distinct loop name avoids shadowing the `shap_values` parameter.
            shapley_order(run_shap_values)[test_labels],
            percentages, metric="auc_ovr",
        )
        for run_shap_values in shap_values
    ])

# Result of auc_ovr for every dataset, keyed by dataset name. Relies on the
# "classifier" and "test" entries that an earlier cell merged into `results`.
auc_ovr_results = {
    dataset: auc_ovr(
        data["classifier"],
        data["test"]["timesteps"],
        data["test"]["labels"],
        data["shap_values"],
    )
    for dataset, data in tqdm(results.items(), desc="Computing evaluation", total=len(results))
}

Computing evaluation: 100%|██████████| 6/6 [00:12<00:00,  2.02s/it]


In [9]:
# `go` and `widgets` come from the import cell at the top of the notebook.

# One series per (dataset, class): axis 2 of each result array is moved to the
# front so that iterating over it yields one slice per class.
# NOTE(review): assumes the last axis of the auc_ovr result indexes classes —
# confirm against masking_impact.
fig = plot_masking_result(
    masked_percentages,
    *[
        (f"{dataset} - Class {class_index}", class_scores)
        for dataset, dataset_scores in auc_ovr_results.items()
        for class_index, class_scores in enumerate(np.transpose(dataset_scores, (2, 0, 1)))
    ],
)
fig = go.FigureWidget(fig)


@widgets.interact(dataset=list(auc_ovr_results.keys()))
def plot_dataset(dataset):
    """Show only the traces of the selected dataset.

    Traces are matched on the exact "<dataset> - Class " prefix used to build
    their names above, so a dataset whose name happens to be contained in
    another trace label cannot be selected by accident.
    """
    prefix = f"{dataset} - Class "
    with fig.batch_update():
        for trace in fig.data:
            trace.visible = (trace.name or "").startswith(prefix)

fig

interactive(children=(Dropdown(description='dataset', options=('Epilepsy', 'FingerMovements', 'WISDM', 'Synthe…

FigureWidget({
    'data': [{'hovertemplate': '%{y:.4f}',
              'legendgroup': 'Epilepsy - Class 0',
              'line': {'color': '#636EFA', 'dash': 'solid'},
              'marker': {'color': '#636EFA', 'size': 8},
              'mode': 'lines+markers',
              'name': 'Epilepsy - Class 0',
              'type': 'scatter',
              'uid': 'ff17f331-c357-4307-8118-57d271836c2a',
              'visible': True,
              'x': array([0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),
              'y': array([0.99629779, 0.99616771, 0.99607981, 0.99581964, 0.99565205, 0.99503443,
                          0.99484574, 0.99271278, 0.98819021, 0.98856055])},
             {'fill': 'toself',
              'fillcolor': 'rgba(99,110,250,0.15)',
              'hoverinfo': 'skip',
              'hovertemplate': '%{y:.4f}',
              'legendgroup': 'Epilepsy - Class 0',
              'line': {'color': 'rgba(255,255,255,0)'},
              'name': 'Epilepsy - Class 0',