In [None]:
import os
import logging, os

# Quiet the notebook output:
# - logging.disable drops every Python log record at WARNING level and below;
# - TF_CPP_MIN_LOG_LEVEL="2" asks TensorFlow's native runtime to hide INFO and
#   WARNING messages. It is set here, before tensorflow is imported in the next
#   cell, since TensorFlow reads it when the library loads.
logging.disable(logging.WARNING)
os.environ.update(TF_CPP_MIN_LOG_LEVEL="2")

In [None]:
import tensorflow as tf
import keras
from musicnet.preprocessing.wav_specs_and_notes.utils import create_tf_record_ds
from musicnet.models.transformer.Transformer import F1FromSeqLogits, WeightedBinaryCrossentropy, WarmupLRSchedule
from musicnet.utils import load_params, PROJECT_ROOT_DIR, notes_vocab, instruments_vocab, note_frequency
from matplotlib import pyplot as plt
import numpy as np
from PIL import Image
from utils import y_vs_y_pred_vis, spectogram_vis
from ipywidgets import interact
import librosa
import pandas as pd
from tqdm import tqdm

In [None]:
# Sanity check: show which GPUs TensorFlow can see (an empty list means CPU only).
gpu_devices = tf.config.list_physical_devices("GPU")
gpu_devices

In [None]:
# Restore the trained CNN. The saved model references project-defined classes,
# so they must be handed to Keras for deserialization.
model_path = os.path.join(PROJECT_ROOT_DIR, "musicnet", "models", "cnn", "model.keras")
custom_objects = {
    "WeightedBinaryCrossentropy": WeightedBinaryCrossentropy,
    "F1FromSeqLogits": F1FromSeqLogits,
    "WarmupLRSchedule": WarmupLRSchedule,
}
model = keras.models.load_model(model_path, custom_objects=custom_objects)

In [None]:
# Configuration keys to load. Entries ending in ".*" presumably select every
# parameter under that prefix — confirm against musicnet.utils.load_params.
# Later cells read the values by their last key component
# (params["batch_size"], params["programs_whitelist"], ...).
param_patterns = [
    "cnn.*",
    "wav_specs_and_notes.preprocessor.target_sr",
    "wav_specs_and_notes.preprocessor.spectogram.*",
    "wav_specs_and_notes.use_converted_midis",
    "midi_to_wav.programs_whitelist",
]
params = load_params(param_patterns)

In [None]:
# One output class per (note, instrument) pair. When a program whitelist is
# configured (non-empty), only those instruments are modelled; otherwise the
# full instrument vocabulary is used.
programs_whitelist = params["programs_whitelist"]
n_instruments = len(programs_whitelist) if programs_whitelist else len(instruments_vocab)
target_classes = len(notes_vocab) * n_instruments

In [None]:
# Keyword arguments shared by the train and validation dataset builders.
ds_params = dict(
    architecture="cnn",
    n_filters=params["n_filters"],
    target_classes=target_classes,
    batch_size=params["batch_size"],
    dataset_size=params["dataset_size"],
    use_converted_midis=params["use_converted_midis"],
)

In [None]:
# The validation split is built with shuffle=False. Predictions and labels are
# gathered from it in separate passes below and must line up sample-for-sample
# (assumes the pipeline is otherwise deterministic — TODO confirm).
train_ds = create_tf_record_ds("train", **ds_params)
val_ds = create_tf_record_ds("val", shuffle=False, **ds_params)

In [None]:
# Inspect the loss object restored from the saved model. It is reused as-is when
# the model is re-compiled with evaluation metrics.
restored_loss = model.loss
restored_loss

In [None]:
# Re-compile for evaluation with a richer metric set.
# Precision/Recall use thresholds=0: the model's outputs are logits (later cells
# apply tf.sigmoid to them), and logit 0 corresponds to probability 0.5.
# NOTE(review): class_id=33 is a hard-coded single-class probe whose meaning is
# not documented here — consider naming it.
eval_metrics = [
    F1FromSeqLogits(threshold=0.5, average="weighted", name="f1_weighted"),
    F1FromSeqLogits(threshold=0.5, average="micro", name="f1_global"),
    keras.metrics.Precision(thresholds=0, name="precision"),
    keras.metrics.Precision(thresholds=0, name="precision_33", class_id=33),
    keras.metrics.Recall(thresholds=0, name="recall"),
]
model.compile(loss=model.loss, metrics=eval_metrics)

In [None]:
# Evaluate on the validation split. return_dict=True labels every value with its
# metric name; the default return is a bare list ordered [loss, *metrics], which
# is hard to read against the six metrics configured above.
model.evaluate(val_ds, return_dict=True)

In [None]:
# Predict on the whole validation split and squash the outputs to per-class
# probabilities (the model presumably emits logits, hence the sigmoid).
y_pred = tf.sigmoid(model.predict(val_ds))

In [None]:
# Gather the ground-truth labels in dataset order. Batches are collected first
# and concatenated once; concatenating inside the loop re-copied the growing
# tensor on every step (quadratic in the number of batches).
# An empty dataset still yields None, as before.
label_batches = [y_batch for _, y_batch in val_ds]
y_true = tf.concat(label_batches, axis=0) if label_batches else None

In [None]:
print(y_pred.shape)
print(y_true.shape)
# The statistics below index y_true and y_pred identically, so their shapes
# must agree; fail here with a clear message rather than deep inside Keras.
assert y_pred.shape == y_true.shape, f"shape mismatch: {y_pred.shape} vs {y_true.shape}"

In [None]:
def calc_per_note_stats(y_true, y_pred, threshold=0.5):
    """Tabulate detection statistics for each note class.

    For every class index ``0 .. len(notes_vocab) - 1`` on the last axis, the
    function reports:
    - label and prediction counts;
    - TP/FP/FN;
    - Keras precision/recall at ``threshold``.

    NOTE(review): only the first ``len(notes_vocab)`` classes are visited. When
    the target space spans several instruments, the remaining columns are not
    summarized here — confirm this is intended.

    Args:
        y_true: rank-3 label tensor indexed as ``[:, :, class]``. Presumably it
            holds 0/1 entries, since ``true_count`` is a plain sum — TODO confirm.
        y_pred: rank-3 probability tensor with the same layout as ``y_true``.
        threshold: a prediction counts as positive when strictly greater.

    Returns:
        pandas.DataFrame with one row per note index.
    """
    def _int_sum(tensor):
        # Sum every element and convert the scalar result to a Python int.
        return int(tf.reduce_sum(tensor).numpy())

    rows = []
    for note_idx in tqdm(range(len(notes_vocab))):
        precision = keras.metrics.Precision(threshold, class_id=note_idx)(y_true, y_pred)
        recall = keras.metrics.Recall(threshold, class_id=note_idx)(y_true, y_pred)

        labels = y_true[:, :, note_idx]
        fired = y_pred[:, :, note_idx] > threshold
        n_true = _int_sum(labels)
        n_pred = _int_sum(tf.cast(fired, tf.float32))
        n_tp = _int_sum(tf.cast(fired & tf.cast(labels, tf.bool), tf.float32))

        rows.append({
            "note_idx": note_idx,
            "note_freq": note_frequency(note_idx),
            "true_count": n_true,
            "pred_count": n_pred,
            "tp": n_tp,
            "fp": n_pred - n_tp,
            "fn": n_true - n_tp,
            "precision": round(precision.numpy(), 3),
            "recall": round(recall.numpy(), 3),
        })
    return pd.DataFrame(rows)

def calc_per_batch_stats(y_true, y_pred, threshold=0.5, batch_size=128, drop_remainder=False):
    """Summarize prediction quality over consecutive fixed-size chunks of samples.

    The first axis of ``y_true``/``y_pred`` is cut into chunks of ``batch_size``
    samples. For each chunk, the function reports:
    - label and prediction counts;
    - TP/FP/FN;
    - Keras precision/recall at ``threshold``.

    A trailing chunk shorter than ``batch_size`` is reported as its own final row
    unless ``drop_remainder`` is True.

    Args:
        y_true: label tensor whose first axis indexes samples. Presumably it holds
            0/1 entries, since ``true_count`` is a plain sum — TODO confirm.
        y_pred: probability tensor with the same shape as ``y_true``.
        threshold: a prediction counts as positive when strictly greater.
        batch_size: number of samples per chunk; must be positive.
        drop_remainder: if True, skip a final chunk smaller than ``batch_size``.

    Returns:
        pandas.DataFrame with one row per chunk; the ``batch`` column is 1-based.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    n_samples = y_true.shape[0]
    # With drop_remainder, keep only start offsets that leave room for a full chunk.
    stop = n_samples - batch_size + 1 if drop_remainder else n_samples

    stats = []
    for start in tqdm(range(0, stop, batch_size)):
        end = start + batch_size  # slicing clamps to n_samples for the last chunk
        y_true_batch = y_true[start:end]
        y_pred_batch = y_pred[start:end]

        precision = keras.metrics.Precision(threshold)(y_true_batch, y_pred_batch)
        recall = keras.metrics.Recall(threshold)(y_true_batch, y_pred_batch)

        predicted_on = y_pred_batch > threshold
        true_count = int(tf.reduce_sum(y_true_batch).numpy())
        pred_count = int(tf.reduce_sum(tf.cast(predicted_on, tf.float32)).numpy())
        tp = int(tf.reduce_sum(
            tf.cast(predicted_on & tf.cast(y_true_batch, tf.bool), tf.float32)
        ).numpy())

        stats.append({
            "batch": start // batch_size + 1,
            "true_count": true_count,
            "pred_count": pred_count,
            "tp": tp,
            "fp": pred_count - tp,
            "fn": true_count - tp,
            "precision": round(precision.numpy(), 3),
            "recall": round(recall.numpy(), 3),
        })
    return pd.DataFrame(stats)

In [None]:
per_note_stats = calc_per_note_stats(y_true, y_pred, threshold=0.5)
# Raise pandas' row limit so a table of up to 100 rows renders untruncated.
pd.set_option("display.max_rows", 100)
per_note_stats

In [None]:
per_batch_stats = calc_per_batch_stats(y_true, y_pred, threshold=0.5, batch_size=128)
# Set the row limit here too so this cell does not depend on the previous one.
pd.set_option("display.max_rows", 100)
per_batch_stats

In [None]:
# Per-note precision and recall on the validation set. Uses an explicit Axes with
# a title, axis labels and a legend so the figure stands alone; the trailing `;`
# suppresses the Legend repr.
fig, ax = plt.subplots()
ax.plot(per_note_stats["note_idx"], per_note_stats["precision"], label="precision")
ax.plot(per_note_stats["note_idx"], per_note_stats["recall"], label="recall")
ax.set(title="Per-note precision and recall (validation)", xlabel="note index", ylabel="score")
ax.legend();

In [None]:
# Take the first validation batch for visual inspection. next(iter(...)) raises
# on an empty dataset instead of silently leaving `x`/`y` undefined, or holding
# values left over from an earlier run of the notebook.
x, y = next(iter(val_ds))

In [None]:
# Input and label shapes of the inspected batch, printed on one line.
sample_shapes = (x.shape, y.shape)
print(*sample_shapes)

In [None]:
# Probabilities for the single inspected batch.
# NOTE(review): this rebinds the notebook-wide `y_pred` (full validation set) to
# batch-sized predictions. Re-running the statistics cells after this one would
# pair the full `y_true` with this smaller `y_pred`.
y_pred = tf.sigmoid(model(x))

In [None]:
# Spectrogram settings forwarded to spectogram_vis (only keys present in params).
spectogram_params = {k: v for k, v in params.items() if k in ("n_fft", "target_sr", "min_hz")}

# The slider is bounded by the size of the batch actually fetched, which can be
# smaller than params["batch_size"] (e.g. a dataset shorter than one batch).
@interact(i=(0, x.shape[0] - 1), t=(0.5, 1.0))
def show_sample(i, t=0.5):
    """Show ground truth vs. predictions thresholded at ``t`` for sample ``i``, plus its input spectrogram."""
    # Threshold only the selected sample rather than the whole batch.
    y_vs_y_pred_vis(y[i].numpy(), (y_pred[i] > t).numpy())
    spectogram_vis(x[i].numpy(), **spectogram_params)

In [None]:
# @tf.function
# def run_inference(context, max_len=999, head_start=0):
#     y_pred = tf.TensorArray(dtype=tf.bool, size=0, dynamic_size=True)
#     for i in tf.range(head_start):
#         y_pred = y_pred.write(i, y[-1, i, :])
#     for i in tf.range(head_start, max_len):
#         next_pred = model([context, [y_pred.stack()]], training=False)[-1, -1, :]
#         next_pred = tf.sigmoid(next_pred)
#         y_pred = y_pred.write(i, next_pred > 0.5)
#     return y_pred.stack()

In [None]:
# y_pred_infer = run_inference(context, y.shape[1], head_start=100)
# y_pred_context = model([context, x])[0]
# y_pred_context = tf.sigmoid(y_pred_context).numpy() > 0.5