# IMPORTS

In [None]:
import re
import torch
import warnings
import numpy as np
import pandas as pd
from pathlib import Path

from torch import nn
from torch.utils.data import DataLoader

from model import PaperModel
from data import SalienceDataset

import librosa
import mirdata
import mir_eval

# CONSTANTS

In [None]:
import os  # cell-local import: only needed for the environment override below

# Root of the MedleyDB-Pitch data. Set the MEDLEYDB_PITCH_DIR environment
# variable to point elsewhere; the default is the original author's path.
DATA_DIR = Path(
    os.environ.get(
        "MEDLEYDB_PITCH_DIR",
        "/Users/alexandre/mir_datasets/medleydb_pitch/",
    )
).expanduser()
# Directory holding training runs (the checkpoint to evaluate is read from here).
EXP_DIR = Path("./EXPERIMENTS")
# Directory where the per-model evaluation CSV is written.
RESULT_DIR = Path("./RESULTS")

# Shared audio / CQT configuration and helpers from the project-local `utils` module.
from utils import (
    TARGET_SR,
    BINS_PER_SEMITONE,
    N_OCTAVES,
    FMIN,
    BINS_PER_OCTAVE,
    N_BINS,
    HOP_LENGTH,
    N_TIME_FRAMES,
    H_RANGE,
    CQT_FREQUENCIES,
    get_cqt_times,
    compute_hcqt,
    load_audio,
    visualize,
    evaluate
)

# NOTE(review): not used by any visible cell of this notebook; presumably a
# slice length (in frames?) shared with training code — confirm before relying on it.
SLICE_SIZE = 200

# DATA

In [None]:
# Validation split, iterated in a fixed order (no shuffling) so repeated
# evaluations see the batches identically.
# NOTE(review): the meaning of `ratio=0.1` is defined by SalienceDataset in the
# `data` module (presumably the fraction of the split that is loaded) — confirm.
val_data = SalienceDataset(
    DATA_DIR / "validation",
    ratio=0.1,
)
val_loader = DataLoader(
    dataset=val_data,
    batch_size=8,
    shuffle=False,
)

# LOAD MODEL

In [None]:
MODEL_DIR = EXP_DIR/"first_attempt_19082022_143140"
CKP_PATH = MODEL_DIR/"ckp.pt"
MODEL_NAME = re.findall(r"(.*)_\d{8}_\d{6}$", MODEL_DIR.stem)[0]

In [None]:
model = PaperModel()
# Deserialize the checkpoint onto the CPU so one saved from a GPU run still
# opens on a CPU-only machine; load_state_dict then copies the values into the
# model's existing parameters, so the final placement is unchanged.
model.load_state_dict(torch.load(CKP_PATH, map_location="cpu"))
# Put any dropout / batch-norm layers into inference mode before evaluating
# (harmless if `evaluate` already does this).
model.eval()

# EVALUATE

In [None]:
# Evaluation needs no gradients: disabling autograd avoids building graphs and
# keeping activations alive. Assumes `evaluate` does not itself backpropagate.
# `results` is iterated as metric -> scores by the averaging cell below.
with torch.no_grad():
    results = evaluate(model, val_loader)

In [None]:
# Collapse each metric's scores to their mean, giving a one-row table whose
# columns are the metric names and whose single row is labelled by the model.
results_df = pd.DataFrame(
    data={metric: [np.mean(scores)] for metric, scores in results.items()}
)
results_df.index = [MODEL_NAME]

In [None]:
# Notebook display: render the one-row table of mean metrics for this model.
results_df

In [None]:
# Save the table as <RESULT_DIR>/<MODEL_NAME>.csv. Create the results
# directory first so a fresh checkout doesn't fail with FileNotFoundError.
RESULT_DIR.mkdir(parents=True, exist_ok=True)
results_df.to_csv(RESULT_DIR / f"{MODEL_NAME}.csv")