In [1]:
import argparse
import pickle
import gzip
import os
import torch

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.offsetbox as offsetbox


from src.utils import LABEL_MAPPING
from src.setup import setup
from src.utils import preprocess, get_beat_spans

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
ATTR_STR_DICT = {
    "random_baseline": "Random (baseline)",
    "saliency": "Saliency",
    "input_gradient": "Input × Gradient",
    "guided_backprop": "Guided Backprop",
    "integrated_gradients": "Integrated Gradients",
    "deep_lift": "DeepLIFT",
    "deep_shap": "DeepSHAP",
    "lrp": "LRP",
    "lime": "LIME",
    "kernel_shap": "KernelSHAP",
    "gradcam": "Grad-CAM",
    "guided_gradcam": "Guided Grad-CAM",
}

ATTR_FIGSIZE = (20, 5)

FILE_FORMAT = "svg"

In [3]:
def get_plot_range(min_value, max_value, coff=1):
    baseline_value = (min_value + max_value) / 2
    amplitude = max_value - baseline_value
    plot_range = (baseline_value - amplitude * coff, baseline_value + amplitude * coff)
    return plot_range

In [4]:
ECG_COLOR = "darkblue"
ECG_LW = 5

def plot_ecg(x, y, dataset, path=None):
    label_index = LABEL_MAPPING[dataset]["LABEL_INDEX"]
    fig, ax1 = plt.subplots(figsize=ATTR_FIGSIZE)

    # ECG
    ecg_yrange = get_plot_range(np.min(x), np.max(x), 1.55)
    ax1.set_ylim(*ecg_yrange)
    ax1.plot(x.squeeze(), c=ECG_COLOR, linewidth=ECG_LW)
    # ax1.set_ylabel("ECG signal", color=ECG_COLOR)
    ax1.get_xaxis().set_visible(False)
    ax1.get_yaxis().set_visible(False)

    label = label_index[y]

    # ax1.set_frame_on(False)
    ax1.margins(x=0)
    
    plt.tight_layout()
    if path is not None:
        plt.savefig(path)
    else:
        plt.show()
    plt.close()

In [5]:
ECG_ALPHA = 0.4
ATTR_COLOR = "crimson"
# ATTR_ALPHA = 0.55
ATTR_LW = 5.5

def plot_attribution(x, y, beat_spans, prob, attr_x, dataset, attr_method=None, path=None):
    label_index = LABEL_MAPPING[dataset]["LABEL_INDEX"]
    fig, ax1 = plt.subplots(figsize=ATTR_FIGSIZE)
    ax2 = ax1.twinx()

    # ECG
    ecg_yrange = get_plot_range(np.min(x), np.max(x), 1.55)
    ax1.set_ylim(*ecg_yrange)
    ax1.plot(x.squeeze(), c=ECG_COLOR, linewidth=ECG_LW, alpha=ECG_ALPHA)
    # ax1.set_ylabel("ECG signal", color=ECG_COLOR)
    ax1.get_xaxis().set_visible(False)
    ax1.get_yaxis().set_visible(False)

    ###
    # Beat spans
    if beat_spans is not None:
        for class_idx, class_span in beat_spans.items():
            for span in class_span:
                if class_idx == 1:
                    ax2.axvspan(span[0], span[1], color=(221/256,238/256,254/256))
                elif class_idx == 2:
                    ax2.axvspan(span[0], span[1], color=(233/256,249/256,220/256))
                # ax1.axvline(span[0], alpha=0.5, c="grey", linestyle="--")
    ###
    

    # Attribution
    if attr_x is not None:
        max_abs_attr = np.max(np.abs(attr_x))
        attr_yrange = (-max_abs_attr * 1.55, max_abs_attr * 1.55)
        ax2.set_ylim(*attr_yrange)
        ax2.plot(attr_x.squeeze(), c=ATTR_COLOR, linewidth=ATTR_LW)
        # ax2.set_ylabel("Attribution value", color=ATTR_COLOR)
        ax2.get_xaxis().set_visible(False)
        ax2.get_yaxis().set_visible(False)

    label = label_index[y]

    ax1.set_zorder(ax2.get_zorder() + 1)
    ax1.set_frame_on(False)
    ax1.margins(x=0)
    ax2.margins(x=0)
    
    plt.tight_layout()
    if path is not None:
        plt.savefig(path)
    else:
        plt.show()
    plt.close()

In [16]:
def plot_ecg_with_beat_spans(x, y, beat_spans, prob, dataset, attr_method=None, path=None):
    label_index = LABEL_MAPPING[dataset]["LABEL_INDEX"]
    fig, ax1 = plt.subplots(figsize=ATTR_FIGSIZE)

    # ECG
    ecg_yrange = get_plot_range(np.min(x), np.max(x), 1.55)
    ax1.set_ylim(*ecg_yrange)
    ax1.plot(x.squeeze(), c=ECG_COLOR, linewidth=ECG_LW)
    # ax1.set_ylabel("ECG signal", color=ECG_COLOR)
    ax1.get_xaxis().set_visible(False)
    ax1.get_yaxis().set_visible(False)

    # Beat spans
    if beat_spans is not None:
        for class_idx, class_span in beat_spans.items():
            for span in class_span:
                if class_idx == 1:
                    ax1.axvspan(span[0], span[1], color=(221/256,238/256,254/256))
                elif class_idx == 2:
                    ax1.axvspan(span[0], span[1], color=(233/256,249/256,220/256))
                # ax1.axvline(span[0], alpha=0.5, c="grey", linestyle="--")

    label = label_index[y]

    # ax1.set_frame_on(False)
    ax1.margins(x=0)
    
    plt.tight_layout()
    if path is not None:
        plt.savefig(path)
    else:
        plt.show()
    plt.close()

### 1. 특정 sample에 attribution 적용한 결과 그리기

In [29]:
idx = 405

DATASET_NAME = "mitdb" # mitdb, svdb, incartdb만 가능
attr_list = ["saliency", "input_gradient", "guided_backprop", "integrated_gradients", "deep_lift", "deep_shap", "lrp", "lime", "kernel_shap", "gradcam", "guided_gradcam"]
absolute = False

for attr_method in attr_list:
    args = argparse.Namespace()
    args.attr_dir = f"results_final_231123/results_attribution/mitdb_resnet18_7_bs32_lr5e-2_wd1e-4_ep20_seed1/{attr_method}"
    args.model_path = "results_final_231123/results_training/mitdb_resnet18_7_bs32_lr5e-2_wd1e-4_ep20_seed1/model_last.pt"
    args.absolute = absolute
    args.gpu_num = 1
    args.seed = 1
    args.result_dir = f"./figure1/v3/plots_{idx}"

    os.makedirs(args.result_dir, exist_ok=True)
    device = setup(args)
    model = torch.load(args.model_path, map_location=device)

    # load eval_attr_data & feature attribution
    eval_attr_data = pickle.load(gzip.GzipFile(f"{args.attr_dir}/eval_attr_data.pkl", "rb"))
    attr_list = pickle.load(gzip.GzipFile(f"{args.attr_dir}/attr_list.pkl", "rb"))

    x, y, beat_spans, prob = (
        eval_attr_data["x"][idx],
        eval_attr_data["y"][idx],
        eval_attr_data["beat_spans"][idx],
        eval_attr_data["prob"][idx],
    )
    attr_x = attr_list[idx]
    if args.absolute:
        attr_x = np.absolute(attr_x)

    if args.absolute:
        save_filename = f"{args.result_dir}/mitdb_eval_attr_{idx}_{attr_method}_abs.{FILE_FORMAT}"
    else:
        save_filename = f"{args.result_dir}/mitdb_eval_attr_{idx}_{attr_method}.{FILE_FORMAT}"
        
    plot_attribution(x, y, beat_spans, prob, attr_x, DATASET_NAME, attr_method, save_filename)

save_filename = f"{args.result_dir}/mitdb_eval_attr_{idx}.{FILE_FORMAT}"
plot_ecg(x, y, DATASET_NAME, save_filename)
save_filename_with_beatspans = f"{args.result_dir}/mitdb_eval_attr_{idx}_beatspans.{FILE_FORMAT}"
plot_ecg_with_beat_spans(x, y, beat_spans, prob, DATASET_NAME, path=save_filename_with_beatspans)
print(f"label: {y}")
print(f"prob: {prob}")

-- Set random seed: 1
-- Use gpu: 1
-- Set random seed: 1
-- Use gpu: 1
-- Set random seed: 1
-- Use gpu: 1
-- Set random seed: 1
-- Use gpu: 1
-- Set random seed: 1
-- Use gpu: 1
-- Set random seed: 1
-- Use gpu: 1
-- Set random seed: 1
-- Use gpu: 1
-- Set random seed: 1
-- Use gpu: 1
-- Set random seed: 1
-- Use gpu: 1
-- Set random seed: 1
-- Use gpu: 1
-- Set random seed: 1
-- Use gpu: 1
label: 2
prob: 0.9320333003997803
