# IMPORTS

In [None]:
import json
import torch
import pickle
import warnings
import numpy as np
from pathlib import Path
from datetime import datetime

from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from model import PaperModel
from training import Trainer
from data import SalienceDataset

import librosa
import mirdata
import mir_eval
import soundfile

# CONSTANTS

In [None]:
# Root folder of the MedleyDB-Pitch data; the "train" and "validation"
# sub-folders are read from it further down.
# Resolved relative to the current user's home directory instead of a
# hard-coded "/Users/<name>" prefix so the notebook runs on other machines.
DATA_DIR = Path.home() / "mir_datasets" / "medleydb_pitch"

from utils import (
    TARGET_SR,
    BINS_PER_SEMITONE,
    N_OCTAVES,
    FMIN,
    BINS_PER_OCTAVE,
    N_BINS,
    HOP_LENGTH,
    N_TIME_FRAMES,
    CQT_FREQUENCIES,
    get_cqt_times,
    compute_hcqt,
    load_audio,
    visualize
)

In [None]:
# Experiment bookkeeping: every run gets its own time-stamped folder under
# ./EXPERIMENTS/ that holds the TensorBoard logs and the files written below.
EXP_NAME = "first_attempt"
EXP_DIR = Path("./EXPERIMENTS/")

# Day-month-year followed by the wall-clock time, e.g. 19082022_143140.
TIMESTAMP = f"{datetime.now():%d%m%Y_%H%M%S}"

EXP_FOLDER = EXP_DIR.joinpath(f"{EXP_NAME}_{TIMESTAMP}")
EXP_FOLDER.mkdir(parents=True, exist_ok=True)

# TensorBoard events go in a sub-folder named after the experiment.
SUMMARY_WRITER = SummaryWriter(log_dir=str(EXP_FOLDER.joinpath(EXP_NAME)))

# HYPER PARAMETERS

In [None]:
# Hyper-parameters consumed by the optimizer, data loaders and trainer below.
LR = 1e-2
WEIGHT_DECAY = 1e-4
BATCH_SIZE = 32
INPUT_DIM = 5  # NOTE(review): not referenced anywhere else in this notebook
DEVICE = "cpu"
N_EPOCHS = 100

# Record of the run's settings. It is built from the constants above so the
# saved file can never disagree with the values actually used for training
# (the previous literal copy logged LR=1e-3 while training ran with 1e-2).
HP = {
    "LR": LR,
    "WEIGHT_DECAY": WEIGHT_DECAY,
    "BATCH_SIZE": BATCH_SIZE,
    "INPUT_DIM": INPUT_DIM,
    "DEVICE": DEVICE,
    "N_EPOCHS": N_EPOCHS,
}
with open(EXP_FOLDER/"hyper_parameters.json", "w") as f:
    json.dump(HP, f)

In [None]:
# Network, training criterion and optimizer.
model = PaperModel()
loss = nn.BCEWithLogitsLoss()
optim = torch.optim.Adam(
    model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)

# Training and validation sets, each wrapped in a shuffling loader.
# NOTE(review): `ratio` presumably selects a fraction of the available data;
# confirm against SalienceDataset.
train_data = SalienceDataset(DATA_DIR.joinpath("train"), ratio=0.1)
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
val_data = SalienceDataset(DATA_DIR.joinpath("validation"), ratio=0.01)
val_loader = DataLoader(dataset=val_data, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Snapshot the whole (still untrained) model object next to the other run
# artefacts, before training starts.
with (EXP_FOLDER / "model.p").open("wb") as f:
    pickle.dump(model, f)

# TRAIN

In [None]:
# Bundle everything the training loop needs: the model with its criterion
# and optimizer, both data loaders, the target device, and where to send
# TensorBoard logs and checkpoints.
trainer = Trainer(
    model=model,
    loss_cls=loss,
    optimizer=optim,
    train_data=train_loader,
    val_data=val_loader,
    device=DEVICE,
    ckp_path=EXP_FOLDER,
    summary_writer=SUMMARY_WRITER,
)

In [None]:
# Run the training loop for N_EPOCHS epochs.
# Warnings are silenced only for the duration of training: catch_warnings()
# restores the previous filters on exit, so warnings raised by later cells
# (decoding, evaluation) are no longer hidden for the rest of the session.
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    trainer.train(N_EPOCHS)

# LOAD MODEL

In [None]:
# Rebuild the architecture and restore the weights of a previous run.
model = PaperModel()
model.load_state_dict(
    torch.load(
        "./EXPERIMENTS/first_attempt_19082022_143140/ckp.pt",
        # Remap stored tensors onto the configured device so a checkpoint
        # written from another device still loads here.
        map_location=DEVICE,
    )
)
# The cells below only run inference: switch layers such as dropout or
# batch-norm (if the model has any) to evaluation behaviour.
model.eval()

In [None]:
# Take one (input, target) pair from the training set and give each a
# leading axis of length 1 in front of its three existing axes.
cqt, salience = train_data[30]
cqt, salience = cqt[None, :, :, :], salience[None, :, :, :]

In [None]:
# Show the model's behaviour on the selected sample using the project's
# `visualize` helper (presumably plots input, target and prediction; see
# utils). The trailing ';' keeps the notebook from echoing the return value.
visualize(model, cqt, salience, n=1);

In [None]:
# Run the model on the sample, keep the first batch item and first channel,
# transpose the resulting 2-D map and convert it to a float NumPy array.
output = (
    model.predict(cqt)[0, :, :, 0]
    .T
    .detach()
    .numpy()
    .astype(float)
)

In [None]:
# Same reduction for the target: first batch item, first channel, transposed,
# as a float NumPy array. This rebinds `salience`, so the cell cannot be
# re-run without first re-running the sample-selection cell.
salience = (
    salience[0, :, :, 0]
    .T
    .detach()
    .numpy()
    .astype(float)
)

In [None]:
# Sanity check: prediction and target should have the same 2-D shape.
print(output.shape, salience.shape, sep="\n")

In [None]:
# Time stamps used for the F0 annotations built below.
# NOTE(review): 50 is presumably the number of time frames in one sample;
# confirm it matches N_TIME_FRAMES / output.shape[1].
times = get_cqt_times(50)
# NOTE(review): the value returned by mirdata.initialize is discarded; the
# trailing ';' only stops the notebook from echoing it.
mirdata.initialize("medleydb_pitch");

In [None]:
# Decode one pitch-bin index per time frame from the predicted and the target
# salience maps with Viterbi, using a local transition matrix of width 5.
# The number of states is read from the data (rows of `output`) instead of a
# hard-coded 216: viterbi requires the transition matrix to match the number
# of rows of `prob`.
n_pitch_bins = output.shape[0]
transition_matrix = librosa.sequence.transition_local(n_pitch_bins, 5)
predicted_pitch_idx = librosa.sequence.viterbi(prob=output, transition=transition_matrix)
target_pitch_idx = librosa.sequence.viterbi(prob=salience, transition=transition_matrix)

# compute f0 and amplitudes using the decoded indices: the frequency of the
# chosen bin, and the salience value at (chosen bin, frame) for every frame
cqt_frequencies = np.asarray(CQT_FREQUENCIES)
predicted_pitch = cqt_frequencies[predicted_pitch_idx]
predicted_salience = output[predicted_pitch_idx, np.arange(len(predicted_pitch_idx))]
target_pitch = cqt_frequencies[target_pitch_idx]
target_salience = salience[target_pitch_idx, np.arange(len(target_pitch_idx))]

In [None]:
# Wrap the decoded prediction as a mirdata F0 annotation: times in seconds,
# frequencies in Hz, and the salience along the path as voicing likelihood.
estimated_f0 = mirdata.annotations.F0Data(
    times,
    "s",
    predicted_pitch,
    "hz",
    predicted_salience,
    "likelihood",
)

# Convert to the (times, frequencies, voicing) triple used with mir_eval.
est_times, est_freqs, est_voicing = estimated_f0.to_mir_eval()

In [None]:
# Same wrapping for the reference path decoded from the target salience.
target_f0 = mirdata.annotations.F0Data(
    times,
    "s",
    target_pitch,
    "hz",
    target_salience,
    "likelihood",
)

# Convert to the (times, frequencies, voicing) triple used with mir_eval.
target_times, target_freqs, target_voicing = target_f0.to_mir_eval()

In [None]:
# Score the estimate against the target with mir_eval's melody metrics.
# Left as a bare expression so the notebook displays the result.
mir_eval.melody.evaluate(
    ref_time=target_times,
    ref_freq=target_freqs,
    est_time=est_times,
    est_freq=est_freqs,
    est_voicing=est_voicing,
)