In [1]:
import argparse
import pickle
import gzip
import os
import torch

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.offsetbox as offsetbox


from src.utils import LABEL_MAPPING
from src.setup import setup
from src.utils import preprocess, get_beat_spans

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Human-readable display names for attribution methods, keyed by the method
# identifier (presumably the same identifiers used for `attr_method` and for
# the per-method result directories -- confirm against the callers).
ATTR_STR_DICT = {
    "random_baseline": "Random (baseline)",
    "saliency": "Saliency",
    "input_gradient": "Input × Gradient",
    "guided_backprop": "Guided Backprop",
    "integrated_gradients": "Integrated Gradients",
    "deep_lift": "DeepLIFT",
    "deep_shap": "DeepSHAP",
    "lrp": "LRP",
    "lime": "LIME",
    "kernel_shap": "KernelSHAP",
    "gradcam": "Grad-CAM",
    "guided_gradcam": "Guided Grad-CAM",
}

# Figure size handed to `plt.subplots(figsize=...)` by the plotting helpers.
ATTR_FIGSIZE = (20, 5)

# File extension used when building the paths of the saved figures.
FILE_FORMAT = "svg"

In [3]:
def get_plot_range(min_value, max_value, coff=1):
    """Return a (low, high) range centred on the midpoint of the inputs.

    The half-width of the interval ``[min_value, max_value]`` is scaled by
    ``coff``, so ``coff=1`` reproduces the input interval and larger values
    widen it symmetrically around the midpoint.

    Args:
        min_value: Lower bound of the data.
        max_value: Upper bound of the data.
        coff: Multiplier applied to the half-width.

    Returns:
        Tuple ``(low, high)``.
    """
    center = (min_value + max_value) / 2
    half_span = max_value - center
    lower = center - half_span * coff
    upper = center + half_span * coff
    return (lower, upper)

# Line colour and line width used when drawing the ECG trace.
ECG_COLOR = "darkblue"
ECG_LW = 5

def plot_ecg(x, y, dataset, path=None):
    """Draw a single ECG trace on a bare figure, then save or show it.

    Both axes are hidden and the horizontal margins are removed, so the
    figure contains only the signal line.

    Args:
        x: ECG signal. It is squeezed before plotting and its min/max define
            the y-limits (assumes a NumPy array that squeezes to 1-D --
            confirm against the caller).
        y: Label of the sample. Not drawn; kept for signature compatibility.
        dataset: Dataset key. Not used for drawing; kept for signature
            compatibility.
        path: Output file path. When ``None`` the figure is shown instead of
            being saved.
    """
    fig, ax1 = plt.subplots(figsize=ATTR_FIGSIZE)
    try:
        # ECG: widen the y-range around the signal midpoint so the trace
        # does not touch the figure border.
        ecg_yrange = get_plot_range(np.min(x), np.max(x), 1.55)
        ax1.set_ylim(*ecg_yrange)
        ax1.plot(x.squeeze(), c=ECG_COLOR, linewidth=ECG_LW)
        ax1.get_xaxis().set_visible(False)
        ax1.get_yaxis().set_visible(False)
        ax1.margins(x=0)

        fig.tight_layout()
        if path is not None:
            fig.savefig(path)
        else:
            plt.show()
    finally:
        # Close this specific figure even if saving/showing fails, so that
        # repeated calls in a notebook do not accumulate open figures.
        plt.close(fig)
    
# Opacity of the ECG trace when it is drawn together with an attribution curve.
ECG_ALPHA = 0.4
# Line colour and line width of the attribution curve.
ATTR_COLOR = "crimson"
# ATTR_ALPHA = 0.55  (disabled: the attribution curve is drawn fully opaque here)
ATTR_LW = 5.5

def plot_attribution(x, y, beat_spans, prob, attr_x, dataset, attr_method=None, path=None):
    """Overlay an attribution curve on a faded ECG trace, then save or show it.

    The ECG is drawn on the primary axes and the attribution on a twin axes
    whose y-range is symmetric around zero. All axis decorations are hidden.

    Args:
        x: ECG signal. Squeezed before plotting; its min/max define the ECG
            y-limits (assumes a NumPy array that squeezes to 1-D -- confirm).
        y: Label of the sample. Not drawn; kept for signature compatibility.
        beat_spans: Optional mapping ``class index -> iterable of spans``.
            Each span is unpacked into ``np.arange`` to obtain the shaded
            x-positions and ``span[0]`` is marked with a dashed vertical line.
            Spans of class 1 and class 2 get a background colour; spans of any
            other class only get the dashed line.
        prob: Not used for drawing; kept for signature compatibility.
        attr_x: Optional attribution values, squeezed before plotting.
        dataset: Not used for drawing; kept for signature compatibility.
        attr_method: Not used for drawing; kept for signature compatibility.
        path: Output file path. When ``None`` the figure is shown instead of
            being saved.
    """
    # Background colours of the shaded beat spans, keyed by class index.
    span_colors = {
        1: (221/256, 238/256, 254/256),
        2: (233/256, 249/256, 220/256),
    }

    fig, ax1 = plt.subplots(figsize=ATTR_FIGSIZE)
    try:
        ax2 = ax1.twinx()

        # ECG
        ecg_yrange = get_plot_range(np.min(x), np.max(x), 1.55)
        ax1.set_ylim(*ecg_yrange)
        ax1.plot(x.squeeze(), c=ECG_COLOR, linewidth=ECG_LW, alpha=ECG_ALPHA)
        ax1.get_xaxis().set_visible(False)
        ax1.get_yaxis().set_visible(False)

        # Attribution. The range is always defined, so the beat-span shading
        # below also works when no attribution is given (previously this
        # combination raised UnboundLocalError).
        if attr_x is not None:
            max_abs_attr = np.max(np.abs(attr_x))
            attr_yrange = (-max_abs_attr * 1.55, max_abs_attr * 1.55)
        else:
            attr_yrange = (-1.0, 1.0)
        ax2.set_ylim(*attr_yrange)
        if attr_x is not None:
            ax2.plot(attr_x.squeeze(), c=ATTR_COLOR, linewidth=ATTR_LW)
        ax2.get_xaxis().set_visible(False)
        ax2.get_yaxis().set_visible(False)

        # Beat spans
        if beat_spans is not None:
            for class_idx, class_span in beat_spans.items():
                fill_color = span_colors.get(class_idx)
                for span in class_span:
                    if fill_color is not None:
                        ax2.fill_between(np.arange(*span), attr_yrange[0], attr_yrange[1], color=fill_color)
                    ax1.axvline(span[0], alpha=0.5, c="grey", linestyle="--")

        # Draw the ECG axes above the twin axes and drop their background
        # patch so the shading on ax2 stays visible underneath.
        ax1.set_zorder(ax2.get_zorder() + 1)
        ax1.set_frame_on(False)
        ax1.margins(x=0)
        ax2.margins(x=0)

        fig.tight_layout()
        if path is not None:
            fig.savefig(path)
        else:
            plt.show()
    finally:
        # Close this specific figure even if saving/showing fails.
        plt.close(fig)
    


In [9]:
# Index of the evaluation sample to visualise.
idx = 7

DATASET_NAME = "mitdb"  # only mitdb, svdb and incartdb are supported
# Available values for attr_method:
# ["saliency", "input_gradient", "guided_backprop", "integrated_gradients", "deep_lift", "deep_shap", "lrp", "lime", "kernel_shap", "gradcam", "guided_gradcam"]
attr_method = "guided_gradcam"
absolute = True

# Name of the training run whose attribution results and checkpoint are loaded.
# NOTE(review): assumes runs for the other datasets use the same naming
# scheme; adjust RUN_NAME if they do not.
RUN_NAME = f"{DATASET_NAME}_resnet18_7_bs32_lr5e-2_wd1e-4_ep20_seed1"

args = argparse.Namespace()
args.attr_dir = f"results_final_231123/results_attribution/{RUN_NAME}/{attr_method}"
args.model_path = f"results_final_231123/results_training/{RUN_NAME}/model_last.pt"
args.absolute = absolute
args.gpu_num = 1
args.seed = 1
args.result_dir = "./figure2"

os.makedirs(args.result_dir, exist_ok=True)
device = setup(args)
# NOTE(review): this unpickles a checkpoint file; recent PyTorch releases
# default to `weights_only=True`, which may reject a fully pickled model --
# confirm against the installed version.
model = torch.load(args.model_path, map_location=device)

# load eval_attr_data & feature attribution
# NOTE: pickle executes arbitrary code on load -- only use trusted result files.
# The context managers make sure the gzip handles are closed after reading.
with gzip.open(f"{args.attr_dir}/eval_attr_data.pkl", "rb") as eval_attr_file:
    eval_attr_data = pickle.load(eval_attr_file)
with gzip.open(f"{args.attr_dir}/attr_list.pkl", "rb") as attr_list_file:
    attr_list = pickle.load(attr_list_file)

x, y, beat_spans, prob = (
    eval_attr_data["x"][idx],
    eval_attr_data["y"][idx],
    eval_attr_data["beat_spans"][idx],
    eval_attr_data["prob"][idx],
)
attr_x = attr_list[idx]
if args.absolute:
    attr_x = np.absolute(attr_x)

# Output paths; "_abs" marks figures built from absolute attribution values.
abs_suffix = "_abs" if args.absolute else ""
save_filename = f"{args.result_dir}/{DATASET_NAME}_eval_attr_{idx}_{attr_method}{abs_suffix}.{FILE_FORMAT}"
save_filename_beats = f"{args.result_dir}/{DATASET_NAME}_eval_attr_{idx}_{attr_method}_beats{abs_suffix}.{FILE_FORMAT}"


-- Set random seed: 1
-- Use gpu: 1


In [None]:
# Attribution overlay, without and with the shaded beat spans.
plot_attribution(x, y, None, prob, attr_x, DATASET_NAME, attr_method, save_filename)
plot_attribution(x, y, beat_spans, prob, attr_x, DATASET_NAME, attr_method, save_filename_beats)

# Plain ECG. A dedicated variable is used so that `save_filename` keeps
# pointing at the attribution figure and this cell can be re-run safely.
save_filename_ecg = f"{args.result_dir}/{DATASET_NAME}_eval_attr_{idx}.{FILE_FORMAT}"
plot_ecg(x, y, DATASET_NAME, save_filename_ecg)
print(f"label: {y}")
print(f"prob: {prob}")

### Perturb and visualize

patch size: 24

In [146]:

# Style overrides for the perturbation figures: the attribution curve is drawn
# semi-transparent (ATTR_ALPHA) while the ECG trace is drawn without alpha.
# ECG_ALPHA = 0.4  (disabled)
ATTR_COLOR = "crimson"
ATTR_ALPHA = 0.4
ATTR_LW = 5.5
# Number of consecutive signal points that form one perturbation patch.
PATCH_SIZE = 24

def plot_attribution_ecg_focus(x, y, beat_spans, prob, attr_x, dataset, perturb_point_indices, attr_method=None, path=None):
    """Plot a perturbed ECG (opaque) over a faded attribution curve.

    The points listed in ``perturb_point_indices`` are set to zero in a
    private copy of the signal, and every line segment that starts at such a
    point is drawn in gold; the rest of the trace uses ``ECG_COLOR``. The
    caller's ``x`` is left untouched.

    Args:
        x: ECG signal (assumes a NumPy array that squeezes to 1-D -- confirm).
            Its unperturbed min/max define the ECG y-limits.
        y: Label of the sample. Not drawn; kept for signature compatibility.
        beat_spans: Optional mapping ``class index -> iterable of spans``.
            Each span is unpacked into ``np.arange`` to obtain the shaded
            x-positions and ``span[0]`` is marked with a dashed vertical line.
            Only classes 1 and 2 get a background colour.
        prob: Not used for drawing; kept for signature compatibility.
        attr_x: Optional attribution values, squeezed before plotting.
        dataset: Not used for drawing; kept for signature compatibility.
        perturb_point_indices: Iterable of integer sample indices to zero out
            and highlight. May be empty.
        attr_method: Not used for drawing; kept for signature compatibility.
        path: Output file path. When ``None`` the figure is shown instead of
            being saved.
    """
    # Background colours of the shaded beat spans, keyed by class index.
    span_colors = {
        1: (221/256, 238/256, 254/256),
        2: (233/256, 249/256, 220/256),
    }

    fig, ax1 = plt.subplots(figsize=ATTR_FIGSIZE)
    try:
        ax2 = ax1.twinx()

        # ECG: the y-limits come from the unperturbed signal.
        ecg_yrange = get_plot_range(np.min(x), np.max(x), 1.55)
        ax1.set_ylim(*ecg_yrange)

        # np.array() copies, so zeroing the perturbed points does not write
        # through to the caller's array (x.squeeze() alone would be a view).
        signal = np.array(x).squeeze()
        perturbed = np.asarray(list(perturb_point_indices), dtype=int)
        signal[perturbed] = 0

        # Boolean mask instead of a per-segment list search.
        is_perturbed = np.zeros(signal.shape[0], dtype=bool)
        is_perturbed[perturbed] = True

        # Segment i joins point i to point i + 1 and takes the colour of its
        # left end point. Consecutive segments of equal colour are drawn as a
        # single polyline instead of one line artist per sample.
        segment_flags = is_perturbed[:-1]
        if segment_flags.size:
            change_points = np.flatnonzero(segment_flags[1:] != segment_flags[:-1]) + 1
            run_starts = np.concatenate(([0], change_points))
            run_stops = np.concatenate((change_points, [segment_flags.size]))
            for run_start, run_stop in zip(run_starts, run_stops):
                run_color = 'gold' if segment_flags[run_start] else ECG_COLOR
                ax1.plot(
                    np.arange(run_start, run_stop + 1),
                    signal[run_start:run_stop + 1],
                    c=run_color,
                    linewidth=ECG_LW,
                )

        ax1.get_xaxis().set_visible(False)
        ax1.get_yaxis().set_visible(False)

        # Attribution. The range is always defined, so the beat-span shading
        # below also works when no attribution is given (previously this
        # combination raised UnboundLocalError).
        if attr_x is not None:
            max_abs_attr = np.max(np.abs(attr_x))
            attr_yrange = (-max_abs_attr * 1.55, max_abs_attr * 1.55)
        else:
            attr_yrange = (-1.0, 1.0)
        ax2.set_ylim(*attr_yrange)
        if attr_x is not None:
            ax2.plot(attr_x.squeeze(), c=ATTR_COLOR, linewidth=ATTR_LW, alpha=ATTR_ALPHA)
        ax2.get_xaxis().set_visible(False)
        ax2.get_yaxis().set_visible(False)

        # Beat spans
        if beat_spans is not None:
            for class_idx, class_span in beat_spans.items():
                fill_color = span_colors.get(class_idx)
                for span in class_span:
                    if fill_color is not None:
                        ax2.fill_between(np.arange(*span), attr_yrange[0], attr_yrange[1], color=fill_color)
                    ax1.axvline(span[0], alpha=0.5, c="grey", linestyle="--")

        # Draw the ECG axes above the twin axes and drop their background
        # patch so the attribution curve on ax2 stays visible underneath.
        ax1.set_zorder(ax2.get_zorder() + 1)
        ax1.set_frame_on(False)
        ax1.margins(x=0)
        ax2.margins(x=0)

        fig.tight_layout()
        if path is not None:
            fig.savefig(path)
        else:
            plt.show()
    finally:
        # Close this specific figure even if saving/showing fails.
        plt.close(fig)
    


In [147]:
x_squeezed = x.squeeze()
attr_x_squeezed = attr_x.squeeze()

# Split the attribution into consecutive patches of PATCH_SIZE points and rank
# the patches by their summed attribution (ascending).
if attr_x_squeezed.size % PATCH_SIZE != 0:
    raise ValueError(
        f"attribution length {attr_x_squeezed.size} is not a multiple of PATCH_SIZE={PATCH_SIZE}"
    )
attr_x_chunked = attr_x_squeezed.reshape(-1, PATCH_SIZE)
attr_x_chunked_sum = attr_x_chunked.sum(axis=1)
patch_order_increasing = np.argsort(attr_x_chunked_sum)

# Number of patches perturbed in each of the visualisations below.
num_patches_perturb = 30


In [148]:
# 1) region perturbation MoRF
x_1 = x.copy()
perturb_patch_indices = patch_order_increasing[::-1][:num_patches_perturb]
perturb_point_indices = []
for patch_idx in perturb_patch_indices:
    perturb_point_indices.extend(np.arange(patch_idx*PATCH_SIZE, (patch_idx+1)*PATCH_SIZE))
    
save_filename_p1 = f"{args.result_dir}/mitdb_eval_attr_{idx}_{attr_method}_perturb_morf.{FILE_FORMAT}"

plot_attribution_ecg_focus(x_1, y, None, prob, attr_x, DATASET_NAME, perturb_point_indices, attr_method, path=save_filename_p1)


In [149]:
# 2) region perturbation LeRF
x_2 = x.copy()
perturb_patch_indices = patch_order_increasing[:num_patches_perturb]
perturb_point_indices = []
for patch_idx in perturb_patch_indices:
    perturb_point_indices.extend(np.arange(patch_idx*PATCH_SIZE, (patch_idx+1)*PATCH_SIZE))

save_filename_p2 = f"{args.result_dir}/mitdb_eval_attr_{idx}_{attr_method}_perturb_lerf.{FILE_FORMAT}"

plot_attribution_ecg_focus(x_2, y, None, prob, attr_x, DATASET_NAME, perturb_point_indices, attr_method, path=save_filename_p2)


In [150]:
# 3) Faithfulness correlation (visualize를 위해 patch 단위로 perturb)
x_3 = x.copy()
perturb_patch_indices = np.random.choice(120, num_patches_perturb, replace=False)
perturb_point_indices = []
for patch_idx in perturb_patch_indices:
    perturb_point_indices.extend(np.arange(patch_idx*PATCH_SIZE, (patch_idx+1)*PATCH_SIZE))


# Original은 아래와 같음
# num_points_perturb = num_patches_perturb * patch_size
# perturb_indices = np.random.choice(len(x_squeezed), num_points_perturb, replace=False)
# x_3 = x.copy()
# for perturb_idx in perturb_indices:
#     x_3[:,:,perturb_idx] = 0

save_filename_p3 = f"{args.result_dir}/mitdb_eval_attr_{idx}_{attr_method}_perturb_fc.{FILE_FORMAT}"

plot_attribution_ecg_focus(x_3, y, None, prob, attr_x, DATASET_NAME, perturb_point_indices, attr_method, path=save_filename_p3)

In [151]:
# Reference figure in the same style as the perturbation plots, with no points perturbed.
save_filename_ecg_focus = f"{args.result_dir}/{DATASET_NAME}_eval_attr_{idx}_{attr_method}_ecg_focus.{FILE_FORMAT}"
plot_attribution_ecg_focus(x, y, None, prob, attr_x, DATASET_NAME, [], attr_method, path=save_filename_ecg_focus)