In [None]:
import argparse
import pickle
import gzip
import os
import torch

import numpy as np
import matplotlib.pyplot as plt

from src.utils import LABEL_MAPPING
from src.setup import setup
from src.utils import preprocess, get_beat_spans

In [None]:
# Figure size (width, height) passed to plt.subplots for attribution plots.
ATTR_FIGSIZE = (20, 5)

# ECG trace style: colour and line width.
ECG_COLOR, ECG_LW = "darkblue", 3

# Attribution trace style: colour, transparency and line width.
ATTR_COLOR, ATTR_ALPHA, ATTR_LW = "crimson", 0.55, 4


def get_plot_range(min_value, max_value, coff=1):
    """Return a ``(low, high)`` axis range centred on the midpoint of the data.

    The half-width of ``[min_value, max_value]`` is multiplied by ``coff``:
    ``coff=1`` spans the data range, larger values add symmetric headroom.
    """
    center = (min_value + max_value) / 2
    half_width = (max_value - center) * coff
    return center - half_width, center + half_width
    

def plot_attribution(x, y, beat_spans, prob, attr_x, dataset, path=None):
    """Plot an ECG signal, optionally with an attribution overlay and beat spans.

    The ECG is drawn on the primary axis and the attribution on a twin y-axis, so
    both keep independent vertical scales. Beat spans of class 1 and 2 are shaded
    on the twin axis, and every span start is marked with a dashed grey line.

    Args:
        x: ECG signal; squeezed before plotting (assumes a single channel, e.g.
            shape (1, 1, L) -- TODO confirm).
        y: Label index, mapped to a name via ``LABEL_MAPPING[dataset]["LABEL_INDEX"]``.
        beat_spans: ``None`` or a mapping ``class_idx -> iterable of spans``; each
            span is expanded with ``np.arange(*span)`` (presumably ``(start, end)``).
        prob: Probability shown in the title with 6 decimals.
        attr_x: Attribution values aligned with ``x``, or ``None`` to plot the ECG
            (and beat spans) alone.
        dataset: Dataset key into ``LABEL_MAPPING``.
        path: If given, the figure is saved there; otherwise it is shown.
    """
    # Vertical headroom factor shared by the ECG and attribution y-ranges.
    yrange_coff = 1.55
    # Shading colours per beat class (0-255 RGB normalised to [0, 1]).
    # Classes not listed only get the dashed start-of-span line.
    span_colors = {
        1: (221 / 255, 238 / 255, 254 / 255),
        2: (233 / 255, 249 / 255, 220 / 255),
    }

    label_index = LABEL_MAPPING[dataset]["LABEL_INDEX"]
    fig, ax1 = plt.subplots(figsize=ATTR_FIGSIZE)
    # Twin axis: shares x with the ECG but has its own y-scale.
    ax2 = ax1.twinx()

    # ECG
    ecg_yrange = get_plot_range(np.min(x), np.max(x), yrange_coff)
    ax1.set_ylim(*ecg_yrange)
    ax1.plot(x.squeeze(), c=ECG_COLOR, linewidth=ECG_LW)
    ax1.get_xaxis().set_visible(False)
    ax1.get_yaxis().set_visible(False)

    # Attribution. attr_yrange is always defined because the beat-span shading
    # below needs it even when no attribution is plotted.
    if attr_x is not None:
        max_abs_attr = float(np.max(np.abs(attr_x)))
        if max_abs_attr == 0:
            # All-zero attribution: avoid a degenerate (0, 0) y-range.
            max_abs_attr = 1.0
        attr_yrange = (-max_abs_attr * yrange_coff, max_abs_attr * yrange_coff)
        ax2.set_ylim(*attr_yrange)
        ax2.plot(attr_x.squeeze(), c=ATTR_COLOR, alpha=ATTR_ALPHA, linewidth=ATTR_LW)
    else:
        # Arbitrary fixed range; ax2 then only hosts the beat-span shading.
        attr_yrange = (-1.0, 1.0)
        ax2.set_ylim(*attr_yrange)
    ax2.get_xaxis().set_visible(False)
    ax2.get_yaxis().set_visible(False)

    # Beat spans
    if beat_spans is not None:
        for class_idx, class_spans in beat_spans.items():
            color = span_colors.get(class_idx)
            for span in class_spans:
                if color is not None:
                    ax2.fill_between(np.arange(*span), attr_yrange[0], attr_yrange[1], color=color)
                ax1.axvline(span[0], alpha=0.5, c="grey", linestyle="--")

    label = label_index[y]
    # Set the title on an explicit axes instead of pyplot's implicit "current axes".
    ax1.set_title(f"Label: {label}, Prob: {prob:.6f}")

    # Draw the ECG axis on top with its background patch disabled, so the
    # attribution and shading on ax2 stay visible underneath it.
    ax1.set_zorder(ax2.get_zorder() + 1)
    ax1.set_frame_on(False)
    ax1.margins(x=0)
    ax2.margins(x=0)

    fig.tight_layout()
    if path is not None:
        fig.savefig(path)
    else:
        plt.show()
    plt.close(fig)

### 1. Plot an ECG recording from the raw dataset (including normal ECGs)

In [None]:
# Raw dataset to plot; only "mitdb", "svdb" and "incartdb" are supported.
DATASET_NAME = "mitdb"
DATA_PATH = "dataset/data/{}.pkl".format(DATASET_NAME)

In [None]:
# Load the gzip-compressed pickled dataset dict. The `with` block closes the file
# once it has been read. NOTE: pickle.load can execute arbitrary code, so only
# load dataset files you trust.
with gzip.open(DATA_PATH, "rb") as data_file:
    data_dict = pickle.load(data_file)

In [None]:
# Split the loaded dict and build the test arrays.
train_set = data_dict["train"]
test_set = data_dict["test"]
# Insert two singleton axes (positions 1 and 2) into the preprocessed signals.
x_test = np.expand_dims(preprocess(test_set["X"]), axis=(1, 2))
# Each entry of test_set["Y"] carries a label under "y" and raw annotations under "y_raw".
y_test = np.array([record["y"] for record in test_set["Y"]])
y_raw_test = [record["y_raw"] for record in test_set["Y"]]

In [None]:
# Find test recordings whose raw annotations have non-empty entries 1 and 2
# (presumably beats of class 1 and class 2 -- TODO confirm); these are good
# candidates for plotting. A separate name is used so this cell does not
# overwrite the `idx` chosen in the next cell.
candidate_indices = [
    i for i, y_raw in enumerate(y_raw_test) if len(y_raw[1]) > 0 and len(y_raw[2]) > 0
]
for candidate_idx in candidate_indices:
    print(candidate_idx)

In [None]:
idx = 2011  # test-set recording to plot (presumably one of the candidate indices printed above)

In [None]:
# Select one test recording, then derive its beat spans from the raw annotations.
x = x_test[idx]
y = y_test[idx]
beat_spans = get_beat_spans(y_raw_test[idx], x.shape[-1], DATASET_NAME)

In [None]:
# Raw recordings have no model output: prob is a 0 placeholder and no attribution is drawn.
plot_attribution(
    x,
    y,
    beat_spans,
    prob=0,
    attr_x=None,
    dataset=DATASET_NAME,
    path=f"{DATASET_NAME}_test_{idx}.png",
)

### 2. Plot feature attributions using the pre-processed evaluation data

In [None]:
# Dataset of the attribution run; only "mitdb", "svdb" and "incartdb" are supported.
DATASET_NAME = "mitdb"
assert DATASET_NAME in ("mitdb", "svdb", "incartdb"), f"unsupported dataset: {DATASET_NAME}"

# Names of the supported attribution methods. This list has its own name so it is
# not confused with the per-sample attributions that are loaded later as `attr_list`.
ATTR_METHODS = [
    "random_baseline",
    "saliency",
    "input_gradient",
    "guided_backprop",
    "integrated_gradients",
    "deep_lift",
    "deep_shap",
    "lrp",
    "lime",
    "kernel_shap",
    "gradcam",
    "guided_gradcam",
]
attr_method = "integrated_gradients"
assert attr_method in ATTR_METHODS, f"unknown attribution method: {attr_method}"

# Plot absolute attribution values instead of signed ones.
absolute = True

In [None]:
# Locations of the trained model and its precomputed attributions. The run name is
# built from DATASET_NAME so the loaded results always match the dataset used for
# labels and output filenames.
# TODO: the hyper-parameter suffix may differ for other datasets -- confirm run names.
RESULTS_ROOT = "results_final_231123"
RUN_NAME = f"{DATASET_NAME}_resnet18_7_bs32_lr5e-2_wd1e-4_ep20_seed1"

args = argparse.Namespace()
args.attr_dir = f"{RESULTS_ROOT}/results_attribution/{RUN_NAME}/{attr_method}"
args.model_path = f"{RESULTS_ROOT}/results_training/{RUN_NAME}/model_last.pt"
args.absolute = absolute
args.gpu_num = 1
args.seed = 1
args.result_dir = "./figures_plot_attribution"

In [None]:
os.makedirs(args.result_dir, exist_ok=True)
device = setup(args)
# NOTE: torch.load unpickles the whole model object, so only load checkpoints you trust.
# On newer PyTorch versions where `weights_only` defaults to True, this call may need
# weights_only=False -- confirm against the installed version.
model = torch.load(args.model_path, map_location=device)

# Load eval_attr_data and the per-sample feature attributions (gzip-compressed pickles).
# The `with` blocks close the files after reading; pickle.load executes arbitrary code,
# so these files must come from a trusted run.
with gzip.open(f"{args.attr_dir}/eval_attr_data.pkl", "rb") as attr_file:
    eval_attr_data = pickle.load(attr_file)
with gzip.open(f"{args.attr_dir}/attr_list.pkl", "rb") as attr_file:
    attr_list = pickle.load(attr_file)

In [None]:
idx = 126  # index of the sample to plot in eval_attr_data and attr_list

In [None]:
# Unpack the selected evaluation sample and its attribution map.
x, y, beat_spans, prob = (
    eval_attr_data[key][idx] for key in ("x", "y", "beat_spans", "prob")
)
# Optionally switch to absolute attribution values.
attr_x = np.absolute(attr_list[idx]) if args.absolute else attr_list[idx]

In [None]:
# Save the attribution overlay; the "_abs" suffix marks plots of absolute values.
# The filename uses DATASET_NAME instead of a hard-coded dataset name.
abs_suffix = "_abs" if args.absolute else ""
save_filename = f"{args.result_dir}/{DATASET_NAME}_eval_attr_{attr_method}_{idx}{abs_suffix}.png"

plot_attribution(x, y, beat_spans, prob, attr_x, DATASET_NAME, save_filename)

In [None]:
# Plain ECG of the same evaluation sample, with no beat spans or attribution overlay.
# The real model probability is shown in the title instead of a 0 placeholder.
save_filename = f"{args.result_dir}/{DATASET_NAME}_eval_attr_{idx}.png"
plot_attribution(x, y, None, prob, None, DATASET_NAME, save_filename)