In [None]:
import pandas as pd
from pathlib import Path
import tqdm
import sys
import matplotlib.pyplot as plt
import re

# ^^^ pyforest auto-imports - don't write above this line
# Prepend local checkouts of torch_ecg and bib_lookup to the module search
# path so they take precedence over any installed copies (after both inserts,
# bib_lookup is at index 0 and torch_ecg at index 1).
# NOTE(review): machine-specific absolute paths -- adjust for your environment.
sys.path.insert(0, "/home/wenhao/Jupyter/wenhao/workspace/torch_ecg/")
sys.path.insert(0, "/home/wenhao/Jupyter/wenhao/workspace/bib_lookup/")

# Automatically reload modified modules before executing each cell.
%load_ext autoreload
%autoreload 2

## Plots of the databases

In [None]:
from data_reader import (
    CompositeReader,
    CINC2016Reader,
    CINC2022Reader,
    EPHNOGRAMReader,
    PCGDataBase,
)

In [None]:
# Reader for the CinC 2022 database.
# NOTE(review): machine-specific absolute path -- adjust for your environment.
cinc2022_data_dir = "/home/wenhao/Jupyter/wenhao/data/CinC2022/"
dr = CINC2022Reader(cinc2022_data_dir)

In [None]:
# Interactive help (IPython): show the signature and docstring of
# `plot_outcome_correlation`; a development aid that produces no figure.
?dr.plot_outcome_correlation

In [None]:
# Correlation plot of the outcome against the "Murmur" column.
ax = dr.plot_outcome_correlation(
    col="Murmur",
)

In [None]:
# ax.figure.savefig("./images/outcome_murmur_corr.pdf", dpi=1200, bbox_inches="tight", transparent=False);
# ax.figure.savefig("./images/outcome_murmur_corr.svg", dpi=1200, bbox_inches="tight", transparent=False);

In [None]:
# Correlation plot of the outcome against the "Age" column.
ax = dr.plot_outcome_correlation(
    col="Age",
)

In [None]:
# ax.figure.savefig("./images/outcome_age_corr.pdf", dpi=1200, bbox_inches="tight", transparent=False);
# ax.figure.savefig("./images/outcome_age_corr.svg", dpi=1200, bbox_inches="tight", transparent=False);

In [None]:
# Correlation plot of the outcome against the "Sex" column.
ax = dr.plot_outcome_correlation(
    col="Sex",
)

In [None]:
# ax.figure.savefig("./images/outcome_sex_corr.pdf", dpi=1200, bbox_inches="tight", transparent=False);
# ax.figure.savefig("./images/outcome_sex_corr.svg", dpi=1200, bbox_inches="tight", transparent=False);

In [None]:
# Correlation plot of the outcome against the "Pregnancy status" column.
ax = dr.plot_outcome_correlation(
    col="Pregnancy status",
)

In [None]:
# ax.figure.savefig("./images/outcome_pregnancy_status_corr.pdf", dpi=1200, bbox_inches="tight", transparent=False);
# ax.figure.savefig("./images/outcome_pregnancy_status_corr.svg", dpi=1200, bbox_inches="tight", transparent=False);

## Plots of the models

In [None]:
import seaborn as sns
from matplotlib.pyplot import cm
import matplotlib.patches as patches
# NOTE: rebinds `tqdm` (imported as a module in the first cell) to the tqdm class.
from tqdm.auto import tqdm

# Seaborn's default theme; applied BEFORE the rcParams overrides below so the
# theme does not reset them.
sns.set_theme()

# Large fonts for publication-sized figures.
plt.rcParams["xtick.labelsize"] = 28
plt.rcParams["ytick.labelsize"] = 28
plt.rcParams["axes.labelsize"] = 40
plt.rcParams["legend.fontsize"] = 24

# Colours of the active property cycle (seaborn's palette after set_theme()).
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

# Matplotlib marker specs. The last one is mathtext, hence a raw string:
# "\h" is an invalid escape sequence (DeprecationWarning; SyntaxWarning on
# Python >= 3.12). The resulting string value is unchanged.
markers = ["p", "v", "s", "d", "x", "*", "+", r"$\heartsuit$"]
marker_size = 12

In [None]:
# Fall back to the local checkouts when the packages cannot be imported.
# NOTE(review): the first cell already prepends both paths, so these
# fallbacks should rarely trigger.
try:
    import bib_lookup
except ModuleNotFoundError:
    # Only extends the search path; unlike torch_ecg below, bib_lookup is NOT
    # re-imported here, so the name `bib_lookup` stays unbound in this case.
    sys.path.insert(0, "/home/wenhao/Jupyter/wenhao/workspace/bib_lookup/")
try:
    from torch_ecg.utils.misc import MovingAverage, list_sum
except ModuleNotFoundError:
    # Retry the import once the local torch_ecg checkout is on the path.
    sys.path.insert(0, "/home/wenhao/Jupyter/wenhao/workspace/torch_ecg/")
    from torch_ecg.utils.misc import MovingAverage, list_sum

In [None]:
# Curve smoothers. `ma_ea` is a MovingAverage with default settings.
# NOTE(review): presumably an exponential average given the name -- confirm
# against torch_ecg.utils.misc.MovingAverage.
ma_ea = MovingAverage()


def ma(x):
    """Identity smoother: return ``x`` unchanged (no smoothing).

    Defined with ``def`` instead of assigning a lambda to a name (PEP 8).
    To enable smoothing, rebind ``ma`` -- the previously commented-out
    alternative was ``ma = MovingAverage()``.
    """
    return x

In [None]:
from models import CRNN_CINC2022, Wav2Vec2_CINC2022, HFWav2Vec2_CINC2022

In [None]:
# Absolute, symlink-resolved location of the training results, relative to
# the notebook's working directory.
results_dir = (Path.cwd() / "results").resolve()
results_dir

In [None]:
# All metric CSVs below `results_dir`. `rglob` yields files in filesystem
# order, which differs between machines; sorting keeps the processing order
# (and hence legend entries / colour assignment downstream) reproducible.
l_csv = sorted(results_dir.rglob("*.csv"))
l_csv

In [None]:
# Parse every training run: identify the model family from the best checkpoint
# named in the run's `.txt` log, then split the run's metrics CSV into a
# training-loss table and a validation-metrics table.
_BEST_MODEL_PATTERN = re.compile(r"/.*BestModel.*\.pth\.tar")
_EPOCH_PATTERN = re.compile(r"epoch(\d+)_")


def _find_best_model(log_lines):
    """Locate the best-model checkpoint recorded in a training log.

    Parameters
    ----------
    log_lines : list of str
        Lines of the ``.txt`` log that accompanies a metrics CSV.

    Returns
    -------
    tuple
        ``(model_fp, epoch)``: the path matched on the LAST log line that
        mentions a ``/...BestModel....pth.tar`` file (``None`` if no line does),
        and the integer of its ``epoch<N>_`` token (``None`` if not found).
    """
    model_fp = None
    for line in log_lines:
        found = _BEST_MODEL_PATTERN.findall(line)
        if found:
            model_fp = found[0]
    if model_fp is None:
        return None, None
    # Prefer the epoch embedded in the checkpoint that is actually loaded;
    # fall back to the log's last line (where it used to be read from).
    epoch_match = _EPOCH_PATTERN.search(model_fp) or _EPOCH_PATTERN.search(
        log_lines[-1]
    )
    epoch = int(epoch_match.group(1)) if epoch_match else None
    return model_fp, epoch


def _describe_model(model_fp):
    """Load the checkpoint at ``model_fp`` and return ``(task, tag)``.

    The model class is chosen from the checkpoint path; ``tag`` is
    ``"crnn-<cnn_name>"``, ``"hf-wav2vec2"`` or ``"ta-wav2vec2"``. The loaded
    model itself is discarded -- only its training config is needed.
    """
    if "CRNN_CINC2022" in model_fp:
        _, train_cfg = CRNN_CINC2022.from_checkpoint(model_fp)
        task = train_cfg.task
        return task, f"crnn-{train_cfg[task].cnn_name}"
    if "HFWav2Vec2_CINC2022" in model_fp:
        _, train_cfg = HFWav2Vec2_CINC2022.from_checkpoint(model_fp)
        return train_cfg.task, "hf-wav2vec2"
    # Anything else is assumed to be a Wav2Vec2_CINC2022 checkpoint.
    _, train_cfg = Wav2Vec2_CINC2022.from_checkpoint(model_fp)
    return train_cfg.task, "ta-wav2vec2"


res = {}
with tqdm(l_csv) as t:
    for fp in t:
        # Skip outcome grid-search results.
        if "OutcomeGridSearch" in str(fp):
            continue
        log_fp = fp.with_suffix(".txt")
        model_fp, best_epoch = _find_best_model(log_fp.read_text().splitlines())
        if model_fp is None:
            # Without a checkpoint the model family cannot be determined
            # (previously a TypeError from `"..." in None`); report and skip.
            t.write(f"No BestModel checkpoint found in {log_fp}; skipped.")
            continue
        task, tag = _describe_model(model_fp)

        df_fp = pd.read_csv(fp)
        # Rows without a `step` inherit the previous row's step
        # (`fillna(method="ffill")` is deprecated in pandas >= 2.1).
        df_fp["step"] = df_fp["step"].ffill()
        train_loss = df_fp.loc[
            df_fp["part"] == "train", ["epoch", "step", "loss", "time", "lr"]
        ].dropna()
        # NOTE(review): dropna() removes a val row if ANY remaining column is NaN.
        val_metrics = (
            df_fp.loc[df_fp["part"] == "val"]
            .drop(columns=["loss", "time", "lr"])
            .dropna()
        )

        res[str(fp)] = dict(
            tag=tag,
            task=task,
            train_loss=train_loss,
            val_metrics=val_metrics,
            best_epoch=best_epoch,
        )

In [None]:
# Compact overview of the parsed runs, instead of dumping every DataFrame
# held in `res`.
pd.DataFrame(
    [
        {
            "csv": csv_path,
            "task": run["task"],
            "tag": run["tag"],
            "n_train_rows": len(run["train_loss"]),
            "n_val_rows": len(run["val_metrics"]),
        }
        for csv_path, run in res.items()
    ]
)

In [None]:
# Validation curves of every parsed run: murmur (solid) and outcome (dashed)
# weighted accuracy against training step, one colour per run.
# TODO: not finished yet (e.g. secondary axis, y-limits).
fig, ax = plt.subplots(figsize=(20, 12))
for run in res.values():
    df_val_metrics = run["val_metrics"]
    run_label = f"{run['task']}-{run['tag']}"
    run_color = None  # None -> the first curve takes the next cycle colour
    for metric, linestyle in (
        ("murmur_weighted_accuracy", "-"),
        ("outcome_weighted_accuracy", "--"),
    ):
        # A run may not log both metrics; plot only the columns present.
        if metric not in df_val_metrics.columns:
            continue
        (line,) = ax.plot(
            df_val_metrics["step"],
            df_val_metrics[metric],
            color=run_color,
            linestyle=linestyle,
            label=f"{run_label} ({metric.split('_')[0]})",
        )
        # Reuse the run's colour for its second metric.
        run_color = line.get_color()
ax.set_xlabel("step")
ax.set_ylabel("weighted accuracy")
ax.legend();