In [None]:
import os
from glob import glob
import logging
from pathlib import Path

import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from tqdm import tqdm
import h5py
import pickle

from harmonic_inference.data.piece import Piece, ScorePiece
import harmonic_inference.models.initial_chord_models as icm
from harmonic_inference.data.corpus_reading import load_clean_corpus_dfs
from harmonic_inference.models.joint_model import MODEL_CLASSES, HarmonicInferenceModel
import harmonic_inference.data.datasets as ds

# Configure the root logger at DEBUG level so the logging.info / logging.error
# messages emitted by the cells below are not filtered out.
logging.basicConfig(level=logging.DEBUG)

In [None]:
# Read the corpus tables from the 'corpus_data' directory, then unpack them
# into the four dataframes used throughout the rest of the notebook.
corpus_dfs = load_clean_corpus_dfs('corpus_data')
files_df, measures_df, chords_df, notes_df = corpus_dfs

In [None]:
# Display the files table returned by load_clean_corpus_dfs for inspection.
files_df

In [None]:
# Display the measures table returned by load_clean_corpus_dfs for inspection.
measures_df

In [None]:
# Display the chords table returned by load_clean_corpus_dfs for inspection.
chords_df

In [None]:
# Display the notes table returned by load_clean_corpus_dfs for inspection.
notes_df

In [None]:
# Get model paths
# Every model except "icm" is located through a glob pattern over its
# Lightning checkpoint directory; the pattern is expanded later, at load time.
model_paths = {
    name: os.path.join(
        "checkpoints", name, "lightning_logs", "version_*", "checkpoints", "*.ckpt"
    )
    for name in MODEL_CLASSES
    if name != "icm"
}

# The icm entry points at a single JSON file rather than a checkpoint pattern.
model_paths["icm"] = os.path.join("checkpoints", "icm", "initial_chord_prior.json")

In [None]:
# Load models
# Each non-icm model is restored from the first checkpoint (in sorted order)
# matching its glob pattern in model_paths, then frozen. The "icm" entry is a
# JSON file and is loaded separately at the end of this cell.
models = {}
for model_name, model_class in MODEL_CLASSES.items():
    if model_name == "icm":
        continue

    possible_checkpoints = sorted(glob(model_paths[model_name]))
    if not possible_checkpoints:
        message = f"No checkpoints found for {model_name} in {model_paths[model_name]}"
        logging.error(message)
        # Raise rather than call sys.exit(): `sys` is not among this notebook's
        # imports, so the exit call itself failed with a NameError and hid the
        # real problem. An exception also stops a notebook run cleanly.
        raise FileNotFoundError(message)

    # The first sorted match is used whether one or several checkpoints exist;
    # only the log message differs between the two cases.
    checkpoint = possible_checkpoints[0]
    if len(possible_checkpoints) == 1:
        logging.info(f"Loading checkpoint {checkpoint} for {model_name}.")
    else:
        logging.info(f"Multiple checkpoints found for {model_name}. Loading {checkpoint}.")

    models[model_name] = model_class.load_from_checkpoint(checkpoint)
    models[model_name].freeze()

# Load icm json differently
logging.info(f"Loading checkpoint {model_paths['icm']} for icm.")
models["icm"] = icm.SimpleInitialChordModel(model_paths["icm"])

In [None]:
# Directory holding the pre-built h5 datasets, and the seed embedded in their file names.
h5_dir = Path("h5_data")
seed = 0

# Load validation data for ctm
# Only the "file_ids" entry is read here; it lists the validation pieces.
h5_path = h5_dir / f"ChordTransitionDataset_valid_seed_{seed}.h5"
with h5py.File(h5_path, "r") as h5_file:
    if "file_ids" not in h5_file:
        message = f"file_ids not found in {h5_path}. Re-create with create_h5_data.py"
        logging.error(message)
        # Raise rather than call sys.exit(): `sys` is not among this notebook's
        # imports, so the exit call itself failed with a NameError.
        raise KeyError(message)

    file_ids = list(h5_file["file_ids"])

# Load from pkl if available
# Without a cache file, fall back to one None placeholder per file id.
pkl_path = h5_dir / f"pieces_valid_seed_{seed}.pkl"
if not pkl_path.exists():
    piece_dicts = [None] * len(file_ids)
else:
    # NOTE(review): pickle.load can execute arbitrary code; only use cache
    # files produced by this project.
    with pkl_path.open("rb") as pkl_file:
        piece_dicts = pickle.load(pkl_file)

def _build_score_piece(file_id, piece_dict):
    """Build the ScorePiece for one file id from the corpus dataframes.

    The piece is named "<file_id>: <corpus_name>/<file_name>", with both names
    looked up in files_df. piece_dict is forwarded unchanged to ScorePiece.
    """
    return ScorePiece(
        notes_df.loc[file_id],
        chords_df.loc[file_id],
        measures_df.loc[file_id],
        piece_dict=piece_dict,
        name=(
            f"{file_id}: {files_df.loc[file_id, 'corpus_name']}/"
            f"{files_df.loc[file_id, 'file_name']}"
        ),
    )


# Pair every file id with its piece_dict entry and show a progress bar while loading.
piece_inputs = tqdm(
    zip(file_ids, piece_dicts),
    total=len(file_ids),
    desc="Loading pieces",
)
pieces = [_build_score_piece(file_id, piece_dict) for file_id, piece_dict in piece_inputs]

In [None]:
# Dataset class used to evaluate each model, keyed by the model's short name.
dataset_classes = {
    "ccm": ds.ChordClassificationDataset,
    "ctm": ds.ChordTransitionDataset,
    "csm": ds.ChordSequenceDataset,
    "ktm": ds.KeyTransitionDataset,
    "ksm": ds.KeySequenceDataset,
}

# Load and evaluate validation data
for model_name, model in models.items():
    logging.info(f"Evaluating {model_name}")

    if model_name != "icm":
        # Every other model is evaluated on its own validation h5 dataset.
        dataset = dataset_classes[model_name]
        h5_path_valid = h5_dir / f"{dataset.__name__}_valid_seed_{seed}.h5"
        dataset_valid = ds.h5_to_dataset(h5_path_valid, dataset, transform=torch.from_numpy)
        evaluation = model.evaluate(dataset_valid)
    else:
        # The icm is evaluated directly on the loaded pieces, not on an h5 dataset.
        evaluation = model.evaluate(pieces)

    logging.info(evaluation)