In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys
import torch
import pandas as pd

# ^^^ pyforest auto-imports - don't write above this line
# Put the local source checkouts of torch_ecg and bib_lookup at the front of the
# module search path, so they are found before any other copy on ``sys.path``.
# NOTE(review): these are machine-specific absolute paths — adjust them locally.
sys.path.insert(0, "/home/wenhao/Jupyter/wenhao/workspace/torch_ecg/")
sys.path.insert(0, "/home/wenhao/Jupyter/wenhao/workspace/bib_lookup/")

# IPython autoreload (mode 2): re-import edited modules before each cell runs.
%load_ext autoreload
%autoreload 2

## Plots of the databases

In [None]:
from data_reader import (
    CompositeReader,
    CINC2016Reader,
    CINC2022Reader,
    EPHNOGRAMReader,
    PCGDataBase,
)

In [None]:
# Reader for the CinC2022 data, rooted at a machine-specific absolute directory.
# NOTE(review): point this at the local copy of the dataset before running.
dr = CINC2022Reader("/home/wenhao/Jupyter/wenhao/data/CinC2022/")

In [None]:
# ?dr.plot_outcome_correlation

In [None]:
# Outcome-correlation plot for the "Murmur" column. The return value is kept as
# ``ax`` — presumably a matplotlib Axes, given the ``ax.figure.savefig`` calls
# in the (commented-out) export cell that follows.
ax = dr.plot_outcome_correlation(col="Murmur")

In [None]:
# ax.figure.savefig("./images/outcome_murmur_corr.pdf", dpi=1200, bbox_inches="tight", transparent=False);
# ax.figure.savefig("./images/outcome_murmur_corr.svg", dpi=1200, bbox_inches="tight", transparent=False);

In [None]:
# Same plot for the "Age" column; ``ax`` is overwritten with the new return value.
ax = dr.plot_outcome_correlation(col="Age")

In [None]:
# ax.figure.savefig("./images/outcome_age_corr.pdf", dpi=1200, bbox_inches="tight", transparent=False);
# ax.figure.savefig("./images/outcome_age_corr.svg", dpi=1200, bbox_inches="tight", transparent=False);

In [None]:
# Same plot for the "Sex" column; ``ax`` is overwritten with the new return value.
ax = dr.plot_outcome_correlation(col="Sex")

In [None]:
# ax.figure.savefig("./images/outcome_sex_corr.pdf", dpi=1200, bbox_inches="tight", transparent=False);
# ax.figure.savefig("./images/outcome_sex_corr.svg", dpi=1200, bbox_inches="tight", transparent=False);

In [None]:
# Same plot for the "Pregnancy status" column; ``ax`` is overwritten again.
ax = dr.plot_outcome_correlation(col="Pregnancy status")

In [None]:
# ax.figure.savefig("./images/outcome_pregnancy_status_corr.pdf", dpi=1200, bbox_inches="tight", transparent=False);
# ax.figure.savefig("./images/outcome_pregnancy_status_corr.svg", dpi=1200, bbox_inches="tight", transparent=False);

## Plots of the models

In [None]:
import seaborn as sns
from matplotlib.pyplot import cm
import matplotlib.patches as patches
from tqdm.auto import tqdm

sns.set()

plt.rcParams["xtick.labelsize"] = 28
plt.rcParams["ytick.labelsize"] = 28
plt.rcParams["axes.labelsize"] = 40
plt.rcParams["legend.fontsize"] = 24

colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

markers = ["p", "v", "s", "d", r"$\heartsuit$", "*", "X", "P", "x"]
marker_size = 14

%load_ext autoreload
%autoreload 2

In [None]:
# Fall back to the local source checkouts when the packages cannot be imported.
try:
    import bib_lookup
except ModuleNotFoundError:
    # Only the search path is extended here; ``bib_lookup`` itself is not
    # re-imported in this cell.
    # NOTE(review): presumably torch_ecg imports it on its own — confirm.
    sys.path.insert(0, "/home/wenhao/Jupyter/wenhao/workspace/bib_lookup/")
try:
    from torch_ecg.utils.misc import MovingAverage, list_sum
except ModuleNotFoundError:
    # Extend the search path, then retry the same import from the checkout.
    sys.path.insert(0, "/home/wenhao/Jupyter/wenhao/workspace/torch_ecg/")
    from torch_ecg.utils.misc import MovingAverage, list_sum

In [None]:
# Callable applied to every metric series before it is plotted (``ma(values)``
# in the plotting cells below).
# NOTE(review): built with default arguments — confirm the smoothing settings
# against torch_ecg's MovingAverage.
ma = MovingAverage()
# ma_ea = MovingAverage()

# Identity alternative: uncomment to plot the raw, unsmoothed curves.
# ma = lambda x: x

In [None]:
from models import CRNN_CINC2022, Wav2Vec2_CINC2022, HFWav2Vec2_CINC2022

In [None]:
# Clear the class-level ``__DEBUG__`` flag on each of the three model classes.
for model_cls in (CRNN_CINC2022, Wav2Vec2_CINC2022, HFWav2Vec2_CINC2022):
    model_cls.__DEBUG__ = False

In [None]:
# ``Path`` is not imported anywhere else in this notebook, so import it here to
# keep the cell runnable on a fresh kernel.
from pathlib import Path

# Absolute location of the training results (CSV metrics and text logs),
# resolved relative to the current working directory.
results_dir = Path("./results/").resolve()
results_dir

In [None]:
# Every metrics CSV below the results directory.
# ``Path.rglob`` yields entries in no guaranteed order, and later cells depend on
# position (reverse-order de-duplication per tag, ``[1:]`` slicing, colour and
# marker index), so the paths are sorted to make the figures reproducible.
l_csv = sorted(results_dir.rglob("*.csv"))
l_csv

In [None]:
# Short display names for the murmur loss class names read from the training configs.
loss_map = dict(
    AsymmetricLoss="Loss-A",
    BCEWithLogitsWithClassWeightLoss="Loss-B",
)

In [None]:
# Display names for the CNN identifiers stored in the training configs; names
# that are missing here are used unchanged (looked up with ``.get(name, name)``).
cnn_name_map = dict(
    [
        ("multi_scopic", "MultiBranch"),
        ("tresnetF", "TResNetF"),
        ("tresnetS", "TResNetS"),
        ("tresnetP", "TResNetP"),
        ("resnet_nature_comm_bottle_neck_se", "SE-ResNet"),
        ("resnet_nature_comm", "ResNet-NC"),
    ]
)

In [None]:
# Labels used in the figures for the task names stored in the training configs.
task_map = dict(classification="MTL2", multi_task="MTL3")

In [None]:
# ``re`` is not imported anywhere else in this notebook.
import re

# Build ``res``: one entry per training run, keyed by the path of its metrics CSV.
# Each entry combines the model description read from the run's best checkpoint
# with the train / validation tables parsed from the CSV.
res = {}
with tqdm(l_csv) as t:
    for fp in t:
        if "OutcomeGridSearch" in str(fp):
            continue

        # The text log sharing the CSV's stem names the best checkpoint file.
        log_fp = fp.with_suffix(".txt")
        if not log_fp.is_file():
            print(f"skipping {fp}: no matching text log")
            continue
        lines = log_fp.read_text().splitlines()

        # Keep the last "BestModel" checkpoint path mentioned in the log.
        model_fp = None
        for line in lines:
            found = re.findall("/.*BestModel.*\\.pth\\.tar", line)
            if len(found) > 0:
                model_fp = found[0]
        if model_fp is None:
            # Without this guard the substring tests below fail with
            # "TypeError: argument of type 'NoneType' is not iterable".
            print(f"skipping {fp}: no BestModel checkpoint found in the log")
            continue

        # The epoch number is read from the final log line; None when absent.
        epoch_match = re.search("epoch([\\d]+)_", lines[-1])
        epoch = int(epoch_match.group(1)) if epoch_match else None

        # Only the stored training config is needed, so map tensors to the CPU.
        # NOTE(review): torch.load unpickles the file — only load trusted
        # checkpoints. Recent torch releases default to ``weights_only=True``,
        # which may reject the pickled config object; confirm for your version.
        train_cfg = torch.load(model_fp, map_location="cpu")["train_config"]
        task = train_cfg.task

        # "HFWav2Vec2_CINC2022" contains "Wav2Vec2_CINC2022", so the HF variant
        # is tested explicitly and the plain variant is the fall-through case.
        if "CRNN_CINC2022" in model_fp:
            model_name = "CRNN"
            cnn_name = train_cfg[task].cnn_name
            cnn_name = cnn_name_map.get(cnn_name, cnn_name)
            tag = f"CRNN-{cnn_name}"
        elif "HFWav2Vec2_CINC2022" in model_fp:
            model_name = "wav2vec2"
            cnn_name = None
            tag = "wav2vec2"
        else:
            model_name = "ta-wav2vec2"
            cnn_name = train_cfg[task].cnn_name
            cnn_name = cnn_name_map.get(cnn_name, cnn_name)
            tag = "ta-wav2vec2"

        # The tag is suffixed with the short name of the murmur loss.
        loss = loss_map[train_cfg[task].loss["murmur"]]
        tag = tag + "-" + loss

        df_fp = pd.read_csv(fp)
        # Forward-fill missing step values; ``fillna(method="ffill")`` is deprecated.
        df_fp["step"] = df_fp["step"].ffill()

        train_rows = df_fp[df_fp.part == "train"]
        val_rows = df_fp[df_fp.part == "val"]
        # Loss rows and metric rows are separated by dropping incomplete rows
        # after selecting the relevant columns.
        train_loss = train_rows[["epoch", "step", "loss", "time", "lr"]].dropna()
        val_metrics = val_rows.drop(columns=["loss", "time", "lr"]).dropna()
        train_metrics = train_rows.drop(columns=["loss", "time", "lr"]).dropna()

        res[str(fp)] = dict(
            model_name=model_name,
            cnn_name=cnn_name,
            loss=loss,
            tag=tag,
            task=task_map[task],
            train_loss=train_loss,
            val_metrics=val_metrics,
            train_metrics=train_metrics,
            best_epoch=epoch,
        )

In [None]:
# Number of runs collected into ``res``.
len(res)

In [None]:
def _best_per_run(split, metric, reducer):
    """Return, as an array, ``reducer`` ("max" or "min") of ``metric`` for every run in ``res``."""
    return np.array(
        [getattr(run[split][metric], reducer)() for run in res.values()]
    )


# Best value reached by each run, in ``res`` order: maximum for the weighted
# accuracies, minimum for the outcome cost.
train_murmur_weighted_accuracy = _best_per_run(
    "train_metrics", "murmur_weighted_accuracy", "max"
)
train_outcome_cost = _best_per_run("train_metrics", "outcome_cost", "min")
train_outcome_weighted_accuracy = _best_per_run(
    "train_metrics", "outcome_weighted_accuracy", "max"
)
val_murmur_weighted_accuracy = _best_per_run(
    "val_metrics", "murmur_weighted_accuracy", "max"
)
val_outcome_cost = _best_per_run("val_metrics", "outcome_cost", "min")
val_outcome_weighted_accuracy = _best_per_run(
    "val_metrics", "outcome_weighted_accuracy", "max"
)

In [None]:
# Select one run per tag for the architecture comparison:
#   * only Loss-A runs on the MTL2 task,
#   * no ta-wav2vec2 models,
#   * of the TResNet variants, only TResNetS and TResNetF.
# ``res`` is walked in reverse, so the last run of each tag is the one kept.
filtered_res = {}
seen_tag = []
for k, v in reversed(list(res.items())):
    if v["tag"] in seen_tag:
        continue
    if v["loss"] != "Loss-A" or v["task"] != "MTL2":
        continue
    if v["model_name"] == "ta-wav2vec2":
        continue
    if (
        v["cnn_name"] is not None
        and "TResNet" in v["cnn_name"]
        and v["cnn_name"] not in ["TResNetS", "TResNetF"]
    ):
        continue
    # Optional extra filter (disabled): also leave out the ResNet-NC runs.
    # A stray, unreachable ``continue`` left behind by this block was removed.
    # if v["cnn_name"] is not None and "ResNet-NC" in v["cnn_name"]:
    #     continue
    filtered_res[k] = v
    seen_tag.append(v["tag"])
len(filtered_res)

In [None]:
import os

fig, ax = plt.subplots(figsize=(20, 12))

line_width = 4
spacing = 2

# Materialise the runs once instead of rebuilding the list on every iteration.
runs = list(filtered_res.values())

# Run indices sorted by best validation murmur weighted accuracy, best first,
# so the legend lists the strongest model at the top.
ordering = np.argsort(
    [run["val_metrics"]["murmur_weighted_accuracy"].max() for run in runs]
)[::-1].tolist()

for idx in ordering:
    v = runs[idx]
    df_val_metrics = v["val_metrics"]
    # CRNN runs are labelled by their CNN, the others by the model name.
    if v["model_name"] == "CRNN":
        label = f"{v['cnn_name']}"
    else:
        label = f"{v['model_name']}"
    ax.plot(
        df_val_metrics.step.values[::spacing],
        ma(df_val_metrics.murmur_weighted_accuracy.values)[::spacing],
        # Wrap around so more runs than markers / colours cannot raise IndexError.
        marker=markers[idx % len(markers)],
        markersize=marker_size,
        linewidth=line_width,
        color=colors[idx % len(colors)],
        label=label,
    )
ax.set_xlabel("Step (n.u.)")
ax.set_ylabel("Murmur Weighted Accuracy (n.u.)")
# ax.legend(bbox_to_anchor=(1.0, 0.53));
ax.legend(loc="lower right")

# savefig does not create missing directories.
os.makedirs("./images", exist_ok=True)
fig.savefig("./images/compare_nn.pdf", dpi=1200, bbox_inches="tight", transparent=False)
fig.savefig("./images/compare_nn.svg", dpi=1200, bbox_inches="tight", transparent=False);

In [None]:
# CRNN runs with the SE-ResNet CNN on the MTL2 task; no filter on the loss,
# so every loss variant present in ``res`` is kept.
wanted = {"cnn_name": "SE-ResNet", "model_name": "CRNN", "task": "MTL2"}
filtered_res = {
    path: run
    for path, run in res.items()
    if all(run[field] == value for field, value in wanted.items())
}
len(filtered_res)

In [None]:
fig, ax = plt.subplots(figsize=(20, 12))

line_width = 4
spacing = 2

# Two curves per run with a shared marker and colour: the murmur metric uses the
# default line style, the outcome metric is dashed.
# NOTE(review): the outcome curve reads ``outcome_accuracy`` although the y-axis
# label says "Weighted Accuracy" — confirm that this is the intended column.
curve_specs = (
    ("murmur_weighted_accuracy", "Murmur", {}),
    ("outcome_accuracy", "Outcome", {"linestyle": "dashed"}),
)
for idx, run in enumerate(filtered_res.values()):
    val_df = run["val_metrics"]
    steps = val_df.step.values[::spacing]
    shared_style = dict(
        marker=markers[idx],
        markersize=marker_size,
        linewidth=line_width,
        color=colors[idx],
    )
    for column, suffix, extra_style in curve_specs:
        ax.plot(
            steps,
            ma(val_df[column].values)[::spacing],
            label=f"{run['loss']}-{run['task']}-{run['cnn_name']}-{suffix}",
            **shared_style,
            **extra_style,
        )

ax.set_xlabel("Step (n.u.)")
ax.set_ylabel("Weighted Accuracy (n.u.)")
ax.legend(loc="lower right", bbox_to_anchor=(1.0, 0.43))

# Export the figure in both vector formats.
for out_path in (
    "./images/clf-se-resnet-lossA-vs-lossB.pdf",
    "./images/clf-se-resnet-lossA-vs-lossB.svg",
):
    fig.savefig(out_path, dpi=1200, bbox_inches="tight", transparent=False)

In [None]:
# CRNN runs with the SE-ResNet CNN on the MTL3 task; no filter on the loss.
wanted = {"cnn_name": "SE-ResNet", "model_name": "CRNN", "task": "MTL3"}
filtered_res = {
    path: run
    for path, run in res.items()
    if all(run[field] == value for field, value in wanted.items())
}
len(filtered_res)

In [None]:
fig, ax = plt.subplots(figsize=(20, 12))

line_width = 4
spacing = 2

# Two curves per run with a shared marker and colour: the murmur metric uses the
# default line style, the outcome metric is dashed.
curve_specs = (
    ("murmur_weighted_accuracy", "Murmur", {}),
    ("outcome_accuracy", "Outcome", {"linestyle": "dashed"}),
)
# NOTE(review): the first run in ``filtered_res`` is skipped on purpose by the
# ``[1:]`` slice; which run that is depends on the order of ``res`` — confirm.
for idx, run in enumerate(list(filtered_res.values())[1:]):
    val_df = run["val_metrics"]
    steps = val_df.step.values[::spacing]
    shared_style = dict(
        marker=markers[idx],
        markersize=marker_size,
        linewidth=line_width,
        color=colors[idx],
    )
    for column, suffix, extra_style in curve_specs:
        ax.plot(
            steps,
            ma(val_df[column].values)[::spacing],
            label=f"{run['loss']}-{run['task']}-{run['cnn_name']}-{suffix}",
            **shared_style,
            **extra_style,
        )

ax.set_xlabel("Step (n.u.)")
ax.set_ylabel("Weighted Accuracy (n.u.)")
ax.legend(loc="lower right", bbox_to_anchor=(1.0, 0.43))

# Export the figure in both vector formats.
for out_path in (
    "./images/mtl-se-resnet-lossA-vs-lossB.pdf",
    "./images/mtl-se-resnet-lossA-vs-lossB.svg",
):
    fig.savefig(out_path, dpi=1200, bbox_inches="tight", transparent=False)

In [None]:
def _is_tresnets_loss_a(run):
    """True for CRNN runs with the TResNetS CNN trained with Loss-A (any task)."""
    return (
        run["cnn_name"] == "TResNetS"
        and run["model_name"] == "CRNN"
        and run["loss"] == "Loss-A"
    )


filtered_res = {path: run for path, run in res.items() if _is_tresnets_loss_a(run)}
len(filtered_res)

In [None]:
fig, ax = plt.subplots(figsize=(20, 12))

line_width = 4
spacing = 2

# Two curves per run with a shared marker and colour: the murmur metric uses the
# default line style, the outcome metric is dashed.
curve_specs = (
    ("murmur_weighted_accuracy", "Murmur", {}),
    ("outcome_accuracy", "Outcome", {"linestyle": "dashed"}),
)
for idx, run in enumerate(filtered_res.values()):
    val_df = run["val_metrics"]
    steps = val_df.step.values[::spacing]
    shared_style = dict(
        marker=markers[idx],
        markersize=marker_size,
        linewidth=line_width,
        color=colors[idx],
    )
    for column, suffix, extra_style in curve_specs:
        ax.plot(
            steps,
            ma(val_df[column].values)[::spacing],
            label=f"{run['task']}-{run['cnn_name']}-{suffix}",
            **shared_style,
            **extra_style,
        )

ax.set_xlabel("Step (n.u.)")
ax.set_ylabel("Weighted Accuracy (n.u.)")
ax.legend(loc="lower right", bbox_to_anchor=(1.0, 0.43))

# Export the figure in both vector formats.
for out_path in (
    "./images/tresnets-clf-vs-mtl.pdf",
    "./images/tresnets-clf-vs-mtl.svg",
):
    fig.savefig(out_path, dpi=1200, bbox_inches="tight", transparent=False)

In [None]:
def _is_se_resnet_loss_a(run):
    """True for CRNN runs with the SE-ResNet CNN trained with Loss-A (any task)."""
    return (
        run["cnn_name"] == "SE-ResNet"
        and run["model_name"] == "CRNN"
        and run["loss"] == "Loss-A"
    )


filtered_res = {path: run for path, run in res.items() if _is_se_resnet_loss_a(run)}
len(filtered_res)

In [None]:
fig, ax = plt.subplots(figsize=(20, 12))

line_width = 4
spacing = 2

# Two curves per run with a shared marker and colour: the murmur metric uses the
# default line style, the outcome metric is dashed.
curve_specs = (
    ("murmur_weighted_accuracy", "Murmur", {}),
    ("outcome_accuracy", "Outcome", {"linestyle": "dashed"}),
)
for idx, run in enumerate(filtered_res.values()):
    val_df = run["val_metrics"]
    steps = val_df.step.values[::spacing]
    shared_style = dict(
        marker=markers[idx],
        markersize=marker_size,
        linewidth=line_width,
        color=colors[idx],
    )
    for column, suffix, extra_style in curve_specs:
        ax.plot(
            steps,
            ma(val_df[column].values)[::spacing],
            label=f"{run['task']}-{run['cnn_name']}-{suffix}",
            **shared_style,
            **extra_style,
        )

ax.set_xlabel("Step (n.u.)")
ax.set_ylabel("Weighted Accuracy (n.u.)")
ax.legend(loc="lower right", bbox_to_anchor=(1.0, 0.43))

# Export the figure in both vector formats.
for out_path in (
    "./images/se-resnet-clf-vs-mtl.pdf",
    "./images/se-resnet-clf-vs-mtl.svg",
):
    fig.savefig(out_path, dpi=1200, bbox_inches="tight", transparent=False)

In [None]:
# CRNN runs that use the multi-branch CNN.
# The ``cnn_name`` values stored in ``res`` have already gone through
# ``cnn_name_map``, which renames "multi_scopic" to "MultiBranch". The former
# literal "MB" is not produced by that map, so the filter came back empty.
filtered_res = {
    k: v
    for k, v in res.items()
    if v["cnn_name"] == "MultiBranch" and v["model_name"] == "CRNN"
}
len(filtered_res)

In [None]:
# fig, ax = plt.subplots(figsize=(20, 12))

# line_width = 4
# spacing = 2

# # ax2 = ax.twinx()

# for idx, (k, v) in enumerate(list(filtered_res.items())):
#     df_val_metrics = v["val_metrics"]
#     ax.plot(
#         df_val_metrics.step.values[::spacing],
#         ma(df_val_metrics.murmur_weighted_accuracy.values)[::spacing],
#         marker=markers[idx],
#         markersize=marker_size,
#         linewidth=line_width,
#         color=colors[idx],
#         label=f"{v['task']}-{v['cnn_name']}-Murmur",
#     )
#     ax.plot(
#         df_val_metrics.step.values[::spacing],
#         ma(df_val_metrics.outcome_accuracy.values)[::spacing],
#         marker=markers[idx],
#         markersize=marker_size,
#         linewidth=line_width,
#         color=colors[idx],
#         linestyle="dashed",
#         label=f"{v['task']}-{v['cnn_name']}-Outcome",
#     )
# #     ax.axvline(v["best_epoch"] * 108, linewidth=10, color=colors[idx], alpha=0.2)
# #     ax.set_ylim(0.4,1.0)
# ax.set_xlabel("Step (n.u.)")
# ax.set_ylabel("Weighted Accuracy (n.u.)")
# ax.legend(loc="lower right", bbox_to_anchor=(1.,0.43))

In [None]:
# filtered_res = {k:v for k,v in res.items() if v["task"] == "classification"}
# CRNN runs with the TResNetF CNN, for every loss and task.
wanted = {"cnn_name": "TResNetF", "model_name": "CRNN"}
filtered_res = {
    path: run
    for path, run in res.items()
    if all(run[field] == value for field, value in wanted.items())
}
len(filtered_res)

In [None]:
# fig, ax = plt.subplots(figsize=(20, 12))

# line_width = 4
# spacing = 2

# # ax2 = ax.twinx()

# for idx, (k, v) in enumerate(filtered_res.items()):
#     df_val_metrics = v["val_metrics"]
#     ax.plot(
#         df_val_metrics.step.values[::spacing],
#         ma(df_val_metrics.murmur_weighted_accuracy.values)[::spacing],
#         marker=markers[idx],
#         markersize=marker_size,
#         linewidth=line_width,
#         color=colors[idx],
#         label=f"{v['task']}-{v['cnn_name']}-Murmur",
#     )
#     ax.plot(
#         df_val_metrics.step.values[::spacing],
#         ma(df_val_metrics.outcome_accuracy.values)[::spacing],
#         marker=markers[idx],
#         markersize=marker_size,
#         linewidth=line_width,
#         color=colors[idx],
#         linestyle="dashed",
#         label=f"{v['task']}-{v['cnn_name']}-Outcome",
#     )
# #     ax.axvline(v["best_epoch"] * 108, linewidth=10, color=colors[idx], alpha=0.2)
# #     ax.set_ylim(0.4,1.0)
# ax.set_xlabel("Step (n.u.)")
# ax.set_ylabel("Weighted Accuracy (n.u.)")
# ax.legend(loc="lower right", bbox_to_anchor=(1., -0.3), ncol=2)