In [None]:
import tensorflow as tf
import numpy as np
import typing
import pandas as pd
import seaborn as sns
from tensorflow import keras
from hydra import compose, initialize
from omegaconf import OmegaConf
from musicnet.config.Config import Config
from musicnet.preprocessing.utils import get_datasets_info, load_vocabs
from musicnet.preprocessing.wav_chunks_tfrecord.utils import create_tf_record_ds
from musicnet.models.utils import MODEL_PATH
from musicnet.preprocessing.dataset.base import DsName
from matplotlib import pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay
from musicnet.preprocessing.utils import load_source_dataset, get_datasets_info, load_vocabs
import mido

In [None]:
# Compose the project config named "defaults" from ../scripts, with the
# `stages` list emptied and `exp` disabled via overrides.
# initialize() is used as a context manager so Hydra's global state is
# cleared on exit; a bare initialize() call raises on the second run of this
# cell because GlobalHydra is already initialized.
with initialize(version_base=None, config_path="../scripts"):
    cfg = compose(config_name="defaults", overrides=["stages=[]", "exp=False"])

In [None]:
# Materialise the composed config as the project's Config object.
# typing.cast is a runtime no-op; it only informs static type checkers.
# NOTE(review): OmegaConf.to_object returns plain dicts/lists unless the
# config has a structured schema attached — confirm `defaults` is backed by Config.
config = typing.cast(
    Config,
    OmegaConf.to_object(cfg),
)
# Per-dataset metadata consumed by the evaluation loop further down.
ds_infos = get_datasets_info(config)

In [None]:
# Display the dataset metadata; note that the plotting cell below skips ds_infos[0].
ds_infos

In [None]:
# Trained model under evaluation, loaded from the project's MODEL_PATH.
model = keras.models.load_model(MODEL_PATH)
# load_vocabs returns a pair; only the second item is kept. Its keys are used
# as the note labels in the per-note plots below.
_, notes_vocab = load_vocabs(config)

In [None]:
def tp_fp_fn_by_note_plot(
    y_true,
    y_pred,
    threshold: float = 0.0,
    note_names: typing.Optional[typing.Sequence[str]] = None,
    ax=None,
):
    """Draw a horizontal bar chart of per-note TP / FP / FN counts.

    A cell of ``y_pred`` counts as a positive prediction when it is
    ``>= threshold`` and as a negative one when it is ``< threshold``.
    The default threshold of 0 reproduces the original behaviour; 0 is the
    decision boundary for logits (NOTE(review): confirm the model outputs
    logits rather than probabilities, for which 0.5 would be appropriate).

    Args:
        y_true: 2-D array-like of 0/1 labels, one column per note.
        y_pred: 2-D array-like of scores with the same shape as ``y_true``.
        threshold: score at or above which a prediction counts as positive.
        note_names: bar labels for the columns, in column order. Defaults to
            the keys of the notebook-global ``notes_vocab`` (NOTE(review):
            assumes the vocab's key order matches the model's output
            columns — confirm).
        ax: matplotlib Axes to draw on; defaults to the current Axes.

    Returns:
        The Axes the chart was drawn on.

    Raises:
        ValueError: if the inputs are not 2-D arrays of equal shape, or the
            number of note names differs from the number of columns.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.ndim != 2 or y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred must be 2-D with equal shapes, "
            f"got {y_true.shape} and {y_pred.shape}"
        )
    if note_names is None:
        note_names = list(notes_vocab.keys())
    else:
        note_names = list(note_names)
    n_notes = y_true.shape[1]
    if len(note_names) != n_notes:
        raise ValueError(
            f"got {len(note_names)} note names for {n_notes} prediction columns"
        )

    is_positive = y_true == 1
    is_negative = y_true == 0
    predicted_on = y_pred >= threshold
    # Compared explicitly (not ~predicted_on) so NaN scores count as neither,
    # exactly as in the original implementation.
    predicted_off = y_pred < threshold

    # Long-format frame: one row per (note, metric) pair, as seaborn expects.
    counts = pd.DataFrame({
        "note": note_names * 3,
        "metric": np.repeat(["tp", "fp", "fn"], n_notes),
        "value": np.concatenate([
            (is_positive & predicted_on).sum(axis=0),
            (is_negative & predicted_on).sum(axis=0),
            (is_positive & predicted_off).sum(axis=0),
        ]),
    })

    if ax is None:
        ax = plt.gca()
    sns.barplot(counts, x="value", y="note", hue="metric", orient="y", ax=ax)
    return ax

In [None]:
# Per-note TP/FP/FN counts for every dataset except the first.
# NOTE(review): presumably ds_infos[0] is the training split — confirm.
eval_infos = ds_infos[1:]
n_panels = max(len(eval_infos), 1)
# One column per dataset; width scales with the panel count (20 in for two
# panels, matching the previous fixed size).
plt.figure(figsize=(10 * n_panels, 30))
for i, ds_info in enumerate(eval_infos):
    # Labels and predictions come from two separate passes over `ds`, so this
    # relies on shuffle=False yielding batches in the same order both times.
    ds = create_tf_record_ds(ds_info.config, ds_info.name, shuffle=False)
    y_true = np.concatenate([y_batch for _, y_batch in ds.as_numpy_iterator()])
    y_pred = model.predict(ds)
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"{ds_info.name}: label shape {y_true.shape} does not match "
            f"prediction shape {y_pred.shape}"
        )
    # Collapse all leading axes so each column corresponds to one note.
    y_true = y_true.reshape(-1, y_true.shape[-1])
    y_pred = y_pred.reshape(-1, y_pred.shape[-1])
    # plt.subplot makes this panel the current Axes, which the plot helper draws on.
    plt.subplot(1, n_panels, i + 1)
    tp_fp_fn_by_note_plot(y_true, y_pred)
    plt.title(str(ds_info.name))

plt.tight_layout()
plt.show()