# IMPORTS

In [1]:
import json
import os
import pickle
import warnings
from datetime import datetime
from pathlib import Path

import numpy as np
import torch

from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from models import *
from training import Trainer
from data import SalienceDataset

import librosa
import mirdata
import mir_eval
import soundfile

# CONSTANTS

In [None]:
# Root of the MedleyDB-Pitch data (the data cell reads its "train" and "validation" subfolders).
# Set the MEDLEYDB_PITCH_DIR environment variable to point elsewhere instead of editing
# this user-specific default path.
DATA_DIR = Path(os.environ.get("MEDLEYDB_PITCH_DIR", "/Users/alexandre/mir_datasets/medleydb_pitch/"))

from utils import (
    TARGET_SR,
    BINS_PER_SEMITONE,
    N_OCTAVES,
    FMIN,
    BINS_PER_OCTAVE,
    N_BINS,
    HOP_LENGTH,
    N_TIME_FRAMES,
    CQT_FREQUENCIES,
    get_cqt_times,
    compute_hcqt,
    load_audio,
    visualize
)

In [None]:
# Experiment identity: each run is stored under EXP_DIR in a folder whose name combines
# EXP_NAME with TIMESTAMP (DDMMYYYY_HHMMSS, captured when this cell executes).
EXP_DIR = Path("./EXPERIMENTS/")
EXP_NAME = "first_attempt"
TIMESTAMP = f"{datetime.now():%d%m%Y_%H%M%S}"

# HYPERPARAMETERS

In [None]:
# Single source of truth for the run configuration. The constants below are what the
# optimizer, data loaders and trainer actually consume, and HP is derived from them so that
# hyper_parameters.json records the configuration that was really trained (it previously
# listed a different LR, batch size and epoch count).
SEED = 0
LR = 1e-2
WEIGHT_DECAY = 1e-4
BATCH_SIZE = 64
# NOTE(review): this used to be `train_data[0][0].size(0)`, but train_data is created in a
# later cell, so Restart & Run All raised NameError here. 5 is the value HP previously
# recorded — confirm it matches the first dimension of a SalienceDataset input.
INPUT_DIM = 5
DEVICE = "cpu"
N_EPOCHS = 8
# NOTE(review): RESIDUAL and BILINEAR_INTERP are not passed to SalienceNetwork() in the
# model cell — confirm whether they are meant to configure the model.
RESIDUAL = False
BILINEAR_INTERP = True

# Seed the RNGs before the model is built so weight init and loader shuffling are reproducible.
torch.manual_seed(SEED)
np.random.seed(SEED)

HP = {
    "SEED": SEED,
    "LR": LR,
    "WEIGHT_DECAY": WEIGHT_DECAY,
    "BATCH_SIZE": BATCH_SIZE,
    "INPUT_DIM": INPUT_DIM,
    "DEVICE": DEVICE,
    "N_EPOCHS": N_EPOCHS,
}

In [None]:
# Model, loss and optimizer. BCEWithLogitsLoss applies the sigmoid itself, so the network
# presumably outputs raw logits — NOTE(review): confirm SalienceNetwork has no final sigmoid.
model = SalienceNetwork()
loss = nn.BCEWithLogitsLoss()
optim = torch.optim.Adam(lr=LR, params=model.parameters(), weight_decay=WEIGHT_DECAY)

# Training and validation sets live in the "train" and "validation" subfolders of DATA_DIR.
# Only the training loader is shuffled: a fixed validation order keeps validation results
# reproducible across epochs and runs (batch composition no longer varies between evaluations).
train_data = SalienceDataset(DATA_DIR/"train")
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_data = SalienceDataset(DATA_DIR/"validation")
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Create this run's folder, open a TensorBoard writer in a subfolder named after the
# experiment, and snapshot the hyper-parameters and the model object.
# NOTE: the model is pickled here, before trainer.train runs, so model.p holds the
# initial (untrained) weights. Pickling the whole module also ties the file to the current
# `models` import path; torch.save(model.state_dict(), ...) would be more portable.
EXP_FOLDER = EXP_DIR / f"{EXP_NAME}_{TIMESTAMP}"
EXP_FOLDER.mkdir(parents=True, exist_ok=True)

tensorboard_dir = EXP_FOLDER / EXP_NAME
SUMMARY_WRITER = SummaryWriter(str(tensorboard_dir))

hp_path = EXP_FOLDER / "hyper_parameters.json"
with hp_path.open("w") as hp_file:
    json.dump(HP, hp_file)

model_path = EXP_FOLDER / "model.p"
with model_path.open('wb') as model_file:
    pickle.dump(model, model_file)

# TRAIN

In [None]:
# Hand everything the training loop needs to a Trainer: model, loaders, loss, optimizer,
# device and TensorBoard writer; `ckp_path` (this run's folder) is presumably where
# Trainer writes checkpoints — confirm in training.py.
# NOTE(review): `loss_cls` receives the loss *instance* nn.BCEWithLogitsLoss(), not a class —
# confirm that is what Trainer expects.
trainer_kwargs = dict(
    model=model,
    train_data=train_loader,
    val_data=val_loader,
    loss_cls=loss,
    optimizer=optim,
    device=DEVICE,
    summary_writer=SUMMARY_WRITER,
    ckp_path=EXP_FOLDER,
)
trainer = Trainer(**trainer_kwargs)

In [None]:
# Train for N_EPOCHS epochs.
# Warnings are silenced only while training runs: a bare warnings.simplefilter('ignore')
# would mute every warning for the rest of the kernel session.
# The TensorBoard writer is closed afterwards (even if training is interrupted) so buffered
# events are flushed to disk; SummaryWriter ignores a double close and re-opens on later use.
try:
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        trainer.train(N_EPOCHS)
finally:
    SUMMARY_WRITER.close()