# IMPORTS

In [None]:
import re
import torch
import warnings
import numpy as np
import pandas as pd
from pathlib import Path

from torch import nn
from torch.utils.data import DataLoader

from model import PaperModel
from data import SalienceDataset

import librosa
import mirdata
import mir_eval

# CONSTANTS

In [None]:
# Root of the MedleyDB-Pitch data; the "validation" sub-folder of this
# directory is read when the validation dataset is built.
# NOTE(review): absolute, machine-specific path -- adjust per environment.
DATA_DIR = Path("/Users/alexandre/mir_datasets/medleydb_pitch/")
# Parent directory of the experiment folders (each holds a "ckp.pt").
EXP_DIR = Path("./EXPERIMENTS")
# Directory that receives the per-model results CSV.
RESULT_DIR = Path("./RESULTS")

from utils import (
    TARGET_SR,
    BINS_PER_SEMITONE,
    N_OCTAVES,
    FMIN,
    BINS_PER_OCTAVE,
    N_BINS,
    HOP_LENGTH,
    N_TIME_FRAMES,
    H_RANGE,
    CQT_FREQUENCIES,
    get_cqt_times,
    compute_hcqt,
    load_audio,
    visualize
)

# The object returned by initialize() is discarded here.
# NOTE(review): presumably kept for its side effects / as a sanity check
# that the dataset name is known to mirdata -- confirm it is still needed.
mirdata.initialize("medleydb_pitch")

# NOTE(review): not referenced by any of the cells visible in this
# notebook -- confirm whether it is still used.
SLICE_SIZE = 200

# DATA

In [None]:
# Validation examples, read from the "validation" sub-folder of DATA_DIR.
# NOTE(review): `ratio=0.1` is forwarded to SalienceDataset; presumably the
# fraction of the split that is kept -- confirm in data.py.
val_data = SalienceDataset(DATA_DIR/"validation", ratio=0.1)
# Batches of 8 in a fixed (unshuffled) order, so repeated runs iterate the
# examples in the same sequence.
val_loader = DataLoader(val_data, batch_size=8, shuffle=False)

# LOAD MODEL

In [None]:
# Experiment folders are named "<model name>_<8 digits>_<6 digits>".
MODEL_DIR = EXP_DIR/"first_attempt_19082022_143140"
CKP_PATH = MODEL_DIR/"ckp.pt"

# Use `.name` rather than `.stem`: MODEL_DIR is a directory, and `.stem`
# would cut the name at its last "." (e.g. "v1.2_19082022_143140" -> "v1"),
# which then no longer matches the pattern.
_model_name_match = re.search(r"(.*)_\d{8}_\d{6}$", MODEL_DIR.name)
if _model_name_match is None:
    # Fail with an explicit message instead of an IndexError on `[0]`.
    raise ValueError(
        f"Experiment directory name {MODEL_DIR.name!r} does not end with "
        "'_<8 digits>_<6 digits>'"
    )
MODEL_NAME = _model_name_match.group(1)

In [None]:
model = PaperModel()
# map_location="cpu": the evaluation code converts model outputs with
# `.numpy()`, which requires CPU tensors, and this also lets a checkpoint
# that was saved from a GPU be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load(CKP_PATH, map_location="cpu"))

# EVALUATE

In [None]:
def extract_freqs(transition_matrix, times, salience_2D):
    """Decode a salience map into a single f0 track in mir_eval format.

    The map is transposed and handed to ``librosa.sequence.viterbi`` to
    obtain one bin index per row of ``salience_2D``. Each index is turned
    into a frequency through ``CQT_FREQUENCIES``, and the salience found at
    that (row, bin) position is used as the voicing likelihood.

    Args:
        transition_matrix: State transition matrix passed to the Viterbi
            decoder.
        times: Time stamps, in seconds, one per decoded step.
        salience_2D: 2-D array indexed as ``[step, bin]``.

    Returns:
        The result of ``F0Data.to_mir_eval()``; callers unpack it as
        ``(times, frequencies, voicing)``.
    """
    bin_path = librosa.sequence.viterbi(salience_2D.T, transition_matrix)

    # Look up every decoded bin at once instead of looping in Python.
    frequencies_hz = np.asarray(CQT_FREQUENCIES)[bin_path]
    step_index = np.arange(len(bin_path))
    path_salience = salience_2D[step_index, bin_path]

    f0_track = mirdata.annotations.F0Data(
        times, "s", frequencies_hz, "hz", path_salience, "likelihood"
    )
    return f0_track.to_mir_eval()


def evaluate(model, data):
    """Score the model's melody estimates against the target saliences.

    For every example yielded by ``data``, the predicted and the target
    salience maps are both decoded into an f0 track with ``extract_freqs``
    and the two tracks are compared with ``mir_eval.melody.evaluate``.

    Args:
        model: Object exposing ``eval()`` and ``predict(hcqt)``. ``predict``
            must return a 4-D tensor; index 0 of its last axis is used.
        data: Iterable of ``(hcqt, target_saliences)`` batches (for example
            a ``DataLoader``). ``target_saliences`` must be a 4-D tensor;
            index 0 of its last axis is used.

    Returns:
        dict mapping every metric name returned by ``mir_eval`` to the list
        of per-example values, in iteration order.
    """
    model.eval()

    times_s = get_cqt_times(N_TIME_FRAMES)
    transition_matrix = librosa.sequence.transition_local(N_BINS, len(H_RANGE))

    results = {}

    # Iterate the `data` argument: the previous version ignored it and read
    # the module-level `val_loader` instead.
    for hcqt, target_saliences in data:
        # Keep index 0 of the last axis. This yields the same array as the
        # former transpose(0, 2) -> [..., 0] -> .T chain (the two axis
        # reversals cancel out) without relying on `Tensor.T` for a 3-D
        # tensor, which PyTorch deprecates.
        target_saliences = (
            target_saliences[:, :, :, 0].detach().numpy().astype(float)
        )

        # Inference only: do not build an autograd graph.
        with torch.no_grad():
            predictions = model.predict(hcqt)
        predicted_saliences = (
            predictions.detach().numpy().astype(float)[:, :, :, 0]
        )

        for predicted_salience, target_salience in zip(
            predicted_saliences, target_saliences
        ):
            est_times, est_freqs, est_voicing = extract_freqs(
                transition_matrix,
                times_s,
                predicted_salience
            )
            # The voicing decoded from the target is not passed to mir_eval.
            target_times, target_freqs, _ = extract_freqs(
                transition_matrix,
                times_s,
                target_salience
            )

            res = mir_eval.melody.evaluate(
                target_times,
                target_freqs,
                est_times,
                est_freqs,
                est_voicing=est_voicing,
            )

            for metric, value in res.items():
                results.setdefault(metric, []).append(value)

    return results

In [None]:
# Per-example metric lists (metric name -> list of values) for the loaded
# model on the validation loader.
results = evaluate(model, val_loader)

In [None]:
# Collapse every metric's per-example values into their mean, giving a
# single-row table labelled with the model name.
mean_by_metric = {}
for metric, values in results.items():
    mean_by_metric[metric] = [np.mean(values)]

results_df = pd.DataFrame(mean_by_metric)
results_df.index = [MODEL_NAME]

In [None]:
# Bare expression: shows the summary table as the notebook cell output.
results_df

In [None]:
# DataFrame.to_csv does not create missing directories, so make sure the
# output folder exists before writing "<MODEL_NAME>.csv" into it.
RESULT_DIR.mkdir(parents=True, exist_ok=True)
results_df.to_csv(RESULT_DIR/f"{MODEL_NAME}.csv")