In [None]:
from interpretability_utilities import load_workspace_file

import numpy as np
import librosa

import librosa.display
import matplotlib.pyplot as plt

import torch
import mlflow

from captum.attr import Lime
from captum.attr import visualization as vis
from captum.attr._core.lime import get_exp_kernel_similarity_function
from captum._utils.models.linear_model import SkLearnLinearRegression

## Settings and utils

In [None]:
# Adjust according to your experiment
ref_fold = "8"
run_id = ""
tracking_server = ""
tracking_port = 5000
workspace_file = ""
dataset_dir = ""
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Fail fast with an actionable message instead of an opaque MLflow error
# when the required placeholders above have not been filled in.
_unset_settings = [
    name
    for name, value in {"run_id": run_id, "tracking_server": tracking_server}.items()
    if not value
]
if _unset_settings:
    raise ValueError(f"Fill in these settings before running: {_unset_settings}")

mlflow.set_tracking_uri(f"{tracking_server}:{tracking_port}")
# `map_location` is forwarded to torch.load so a model saved on GPU can be loaded
# on a CPU-only host; `.to(device)` puts the model on the same device the data is
# loaded onto (load_workspace_file(..., device) in the next cell).
logged_model = mlflow.pytorch.load_model(f"runs:/{run_id}/models", map_location=device)
logged_model = logged_model.to(device).eval()

client = mlflow.MlflowClient()
run = client.get_run(run_id)
run_data = run.data
tags = run_data.tags

# Audio / STFT settings read from the MLflow run's tags.
sr = int(tags["sample_rate"])
n_fft = int(tags["window_size"])  # same tag as window_size; the STFTs below use window_size
hop_size = int(tags["hop_size"])
window_size = int(tags["window_size"])

In [None]:
inp_data, indexes, labels, _, lb_to_idx, _ = load_workspace_file(
    workspace_file, ref_fold, dataset_dir, device
)

# Reverse lookup: class index -> label name.
idx_to_label = dict(zip(lb_to_idx.values(), lb_to_idx.keys()))

# Class indices whose label name begins with "albilora".
target = [
    class_idx
    for name, class_idx in lb_to_idx.items()
    if name.startswith("albilora")
]

inp_data.requires_grad_()

In [None]:
# Seed every RNG used in the analysis so Restart & Run All reproduces it:
# numpy drives the sample selection below, and torch drives the random
# perturbations that Captum's default LIME perturbation function samples.
SEED = 135
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED);

## LIME attribution

Select a random input example from the albilora block (rows 200–349 of `inp_data`).

In [None]:
# Draw one example from the albilora block (rows 200..349) with the seeded RNG.
index = rng.integers(low=200, high=350)
inp = inp_data[index]

# Show the drawn index alongside the name of its highest-scoring label.
label_row = labels.cpu().detach().numpy()[index]
index, idx_to_label[label_row.argmax()]

Configure LIME with a linear-regression surrogate model and an exponential Euclidean-distance similarity kernel.

In [None]:
# Similarity used by LIME to weight perturbed samples: an exponential kernel
# over Euclidean distance. kernel_width=1000 is an experiment-specific choice.
exp_eucl_distance = get_exp_kernel_similarity_function(
    'euclidean',
    kernel_width=1000,
)

# Surrogate fitted on the perturbed samples: Captum's scikit-learn linear regression wrapper.
surrogate = SkLearnLinearRegression()
lr_lime = Lime(logged_model, interpretable_model=surrogate, similarity_func=exp_eucl_distance)

In [None]:
# Explain rows 200..349 of inp_data in one batched call. For batched input,
# Captum fits one surrogate per example, so lr_attrs[k] explains inp_data[200 + k].
lime_rows = slice(200, 350)

# One target class per example. The previous np.argmax over the 2-D label slice
# had no axis and returned a *flattened* index, which is only a valid class index
# by accident (e.g. when the first row's label happens to be the overall max).
lime_targets = labels[lime_rows].argmax(dim=1).tolist()

lr_attrs = lr_lime.attribute(
    inp_data[lime_rows],
    target=lime_targets,
    perturbations_per_eval=4,
    show_progress=False
)

In [None]:
# Attribution tensor (first axis indexes the explained examples) and its shape.
lr_attrs, lr_attrs.shape

In [None]:
# `librosa.display.waveplot` was deprecated and then removed in librosa 0.10;
# prefer its replacement `waveshow`, falling back to `waveplot` on older installs.
plot_waveform = getattr(librosa.display, "waveshow", None) or librosa.display.waveplot

_, ax = plt.subplots()
plot_waveform(inp.cpu().detach().numpy(), sr=sr, ax=ax)
ax.set(title=f"Input waveform (index {index})");

In [None]:
# lr_attrs[k] explains inp_data[200 + k], so offset the selected index to pair
# the attribution with the waveform actually shown (`inp = inp_data[index]`).
# Indexing lr_attrs[0] here would overlay the attribution of a different sample.
inp_attr = lr_attrs[index - 200]
n_points = inp.size()[0]

vis.visualize_timeseries_attr(
    inp_attr.reshape(1, n_points).cpu().detach().numpy(),
    inp.reshape(1, n_points).cpu().detach().numpy(),
    channels_last=False,
    method="colored_graph",
    sign="all",
);

In [None]:
# Attribution for the selected input: lr_attrs[k] explains inp_data[200 + k],
# so the row for `inp = inp_data[index]` is index - 200 (not row 0).
inp_attr = lr_attrs[index - 200]
inp_attr, inp_attr.size()

In [None]:
def magnitude_stft(signal: np.ndarray) -> np.ndarray:
    """Return the magnitude STFT of a 1-D signal.

    Uses the notebook-level `window_size` and `hop_size` settings, with
    n_fft == win_length == window_size and centered frames.
    """
    return np.abs(
        librosa.stft(signal, n_fft=window_size, win_length=window_size,
                     hop_length=hop_size, center=True)
    )


# lr_attrs[k] explains inp_data[200 + k]; pick the row belonging to `inp`
# instead of a hard-coded index that is unrelated to the selected sample.
inp_attr_np = lr_attrs[index - 200].cpu().detach().numpy()

# Captum's (private) `_normalize_attr` operates on NumPy arrays. Converting first
# avoids the NumPy-on-CUDA-tensor failure when attributions live on the GPU.
# Outliers above the 98th percentile of |attr| are clipped to [-1, 1].
normalized_attr = vis._normalize_attr(inp_attr_np, "all", 2, reduction_axis=None)

spec = magnitude_stft(inp.cpu().detach().numpy())
attr_spec = magnitude_stft(normalized_attr)

_, ax = plt.subplots(nrows=2, ncols=1, sharex=True)

librosa.display.specshow(
    spec, x_axis='time', y_axis='linear', sr=sr, hop_length=hop_size, ax=ax[0]
)
ax[0].set(title=f"Input spectrogram (index {index})")

librosa.display.specshow(
    attr_spec, x_axis='time', y_axis='linear', sr=sr, hop_length=hop_size, ax=ax[1]
)
ax[1].set(title="Spectrogram of normalized LIME attribution");