In [None]:
import logging
logging.basicConfig(level=logging.DEBUG)

import os
from glob import glob
from pathlib import Path

import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from tqdm import tqdm
import h5py
import pickle
import numpy as np

from harmonic_inference.data.piece import Piece, ScorePiece, get_range_start
import harmonic_inference.models.initial_chord_models as icm
from harmonic_inference.data.corpus_reading import load_clean_corpus_dfs
from harmonic_inference.models.joint_model import MODEL_CLASSES, HarmonicInferenceModel, DebugLogger
import harmonic_inference.data.datasets as ds
from harmonic_inference.data.data_types import KeyMode
from harmonic_inference.utils.beam_search_utils import State

In [None]:
# Load the four corpus tables (files, measures, chords, notes) from the 'corpus_data' directory.
# The later cells index each of these with .loc[file_id].
files_df, measures_df, chords_df, notes_df = load_clean_corpus_dfs('corpus_data')

In [None]:
# Display the files table returned by load_clean_corpus_dfs.
files_df

In [None]:
# Display the measures table returned by load_clean_corpus_dfs.
measures_df

In [None]:
# Display the chords table returned by load_clean_corpus_dfs.
chords_df

In [None]:
# Display the notes table returned by load_clean_corpus_dfs.
notes_df

In [None]:
# Get model paths
# Every model except the icm is stored as a Lightning checkpoint; the value is a glob
# pattern (note the "version_*" and "*.ckpt" wildcards) that is expanded when loading.
model_paths = {
    name: os.path.join(
        "checkpoints", name, "lightning_logs", "version_*", "checkpoints", "*.ckpt"
    )
    for name in MODEL_CLASSES
    if name != "icm"
}

# The icm is stored as a single json file rather than a checkpoint.
model_paths["icm"] = os.path.join("checkpoints", "icm", "initial_chord_prior.json")

In [None]:
# Load models
# For every model class except the icm, expand the checkpoint glob pattern from
# model_paths, load the chosen checkpoint and freeze the resulting model.
models = {}
for model_name, model_class in MODEL_CLASSES.items():
    if model_name == "icm":
        continue

    possible_checkpoints = sorted(glob(model_paths[model_name]))
    if not possible_checkpoints:
        # Raise instead of calling sys.exit(): `sys` is not imported in this notebook,
        # and an exception stops "Run All" with a readable traceback.
        message = f"No checkpoints found for {model_name} in {model_paths[model_name]}"
        logging.error(message)
        raise FileNotFoundError(message)

    # With several matches, the lexicographically last path wins.
    # NOTE(review): this sorts "version_9" after "version_10" — confirm that is acceptable.
    checkpoint = possible_checkpoints[-1]
    if len(possible_checkpoints) == 1:
        logging.info(f"Loading checkpoint {checkpoint} for {model_name}.")
    else:
        logging.info(f"Multiple checkpoints found for {model_name}. Loading {checkpoint}.")

    models[model_name] = model_class.load_from_checkpoint(checkpoint)
    models[model_name].freeze()

# Load icm json differently
logging.info(f"Loading checkpoint {model_paths['icm']} for icm.")
models["icm"] = icm.SimpleInitialChordModel(model_paths["icm"])

In [None]:
# Directory holding the h5 datasets (and optional piece pickles), and the seed that
# is embedded in their file names.
h5_dir = Path("h5_data")
seed = 0

# Load validation data for ctm
# Only the list of validation file ids is read from this file here.
h5_path = h5_dir / f"ChordTransitionDataset_valid_seed_{seed}.h5"
with h5py.File(h5_path, "r") as h5_file:
    if "file_ids" not in h5_file:
        # Raise instead of calling sys.exit(): `sys` is not imported in this notebook,
        # and an exception stops "Run All" with a readable traceback.
        message = f"file_ids not found in {h5_path}. Re-create with create_h5_data.py"
        logging.error(message)
        raise ValueError(message)

    file_ids = list(h5_file["file_ids"])

# Load from pkl if available
pkl_path = h5_dir / f"pieces_valid_seed_{seed}.pkl"
if pkl_path.exists():
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load pickles that were produced locally by this project.
    with open(pkl_path, "rb") as pkl_file:
        piece_dicts = pickle.load(pkl_file)
else:
    # No cached dicts: pass piece_dict=None for every piece.
    piece_dicts = [None] * len(file_ids)

# Build one ScorePiece per validation file. The name is "<file_id>: <corpus_name>/<file_name>".
pieces = [
    ScorePiece(
        notes_df.loc[file_id],
        chords_df.loc[file_id],
        measures_df.loc[file_id],
        piece_dict=piece_dict,
        name=(
            f"{file_id}: {files_df.loc[file_id, 'corpus_name']}/"
            f"{files_df.loc[file_id, 'file_name']}"
        ),
    ) for file_id, piece_dict in tqdm(
        zip(file_ids, piece_dicts),
        total=len(file_ids),
        desc="Loading pieces",
    )
]

In [None]:
# Dataset class used to load the validation h5 file of each non-icm model.
dataset_classes = dict(
    ccm=ds.ChordClassificationDataset,
    ctm=ds.ChordTransitionDataset,
    csm=ds.ChordSequenceDataset,
    ktm=ds.KeyTransitionDataset,
    ksm=ds.KeySequenceDataset,
)

# Load and evaluate validation data
for model_name, model in models.items():
    logging.info(f"Evaluating {model_name}")

    if model_name != "icm":
        # The h5 file name is derived from the dataset class name and the seed.
        dataset = dataset_classes[model_name]
        h5_path_valid = h5_dir / f"{dataset.__name__}_valid_seed_{seed}.h5"
        dataset_valid = ds.h5_to_dataset(h5_path_valid, dataset, transform=torch.from_numpy)
        evaluation = model.evaluate(dataset_valid)
    else:
        # The icm is evaluated directly on the loaded pieces, not on an h5 dataset.
        evaluation = model.evaluate(pieces)

    logging.info(evaluation)

In [None]:
# List every loaded piece next to its index in `pieces`.
[
    f"{position}: {candidate.name}"
    for position, candidate in enumerate(pieces)
]

In [None]:
# Choose random piece
# Called without an argument, so no fixed seed is set and the chosen piece is
# not reproducible between runs.
np.random.seed()

piece_index = np.random.choice(len(pieces))
piece = pieces[piece_index]

# piece.name has the form "<file_id>: <corpus_name>/<file_name>". Split on the first
# colon only, so a colon inside the corpus or file name cannot break the unpacking.
file_id, piece_name = [string.strip() for string in piece.name.split(":", 1)]
file_id = int(file_id)

logging.info(f"Evaluating piece file_id={file_id}: {piece_name}")

joint_model = HarmonicInferenceModel(
    models,
    min_chord_change_prob=0.15,
    max_no_chord_change_prob=0.85,
)

# Normally, just run joint_model.get_harmony(piece)
# Here, we are running it manually
joint_model.current_piece = piece

# Save caches from piece
# Each onset cache holds one entry per input plus a final entry taken from the
# last input's offset.
piece_inputs = piece.get_inputs()
joint_model.duration_cache = piece.get_duration_cache()
joint_model.onset_cache = [vec.onset for vec in piece_inputs] + [
    piece_inputs[-1].offset
]
joint_model.onset_level_cache = [vec.onset_level for vec in piece_inputs] + [
    piece_inputs[-1].offset_level
]

# Attach the debugger that the following cells call through joint_model.debugger.
joint_model.debugger = DebugLogger(
    joint_model.current_piece,
    joint_model.CHORD_OUTPUT_TYPE,
    joint_model.KEY_OUTPUT_TYPE,
    joint_model.max_chord_branching_factor,
    joint_model.max_key_branching_factor,
)


In [None]:
# Evaluate Chord Transition Model
chord_change_probs = joint_model.get_chord_change_probs()
chord_ranges, chord_range_log_probs = joint_model.get_chord_ranges(chord_change_probs)

# Convert range starting points to new starts based on the note offsets
# chord_change_indices keeps the original range starts; chord_windows pairs the
# adjusted start with the unchanged range end.
piece_inputs = piece.get_inputs()
chord_change_indices = []
chord_windows = []
for range_start, range_end in chord_ranges:
    chord_change_indices.append(range_start)
    window_start = get_range_start(piece_inputs[range_start].onset, piece_inputs)
    chord_windows.append((window_start, range_end))

joint_model.debugger.debug_chord_change_probs(chord_change_probs)

In [None]:
# How many ranges are missed?
correct_chord_change_indices = piece.get_chord_change_indices()
# Pair every true change index with the following one; the last range ends with None.
range_ends = list(correct_chord_change_indices[1:]) + [None]
correct_chord_ranges = list(zip(correct_chord_change_indices, range_ends))
correct_chord_windows = piece.get_chord_ranges()

# A true range counts as missed when it is not among the ranges in chord_ranges.
missed_ranges = [
    (chord_number, true_range)
    for chord_number, true_range in enumerate(correct_chord_ranges)
    if true_range not in chord_ranges
]
for chord_number, true_range in missed_ranges:
    print(f"Chord range {true_range} not found (chord {chord_number})")
missed = len(missed_ranges)

total_ranges = len(correct_chord_change_indices)
print(f"Missed {missed} out of {total_ranges} ranges")
print(f"Accuracy (found ranges) = {1 - missed / total_ranges}")

In [None]:
# Evaluate Chord Classification Model on noisy chord ranges
# "Noisy" here means the windows and change indices derived from chord_ranges.
chord_classifications = joint_model.get_chord_classifications(
    chord_windows,
    chord_change_indices,
)
joint_model.debugger.debug_chord_classifications(
    chord_ranges,
    chord_classifications,
)

In [None]:
# Evaluate Chord Classification Model on true chord ranges
# Same call as for the noisy ranges, but fed with the windows and change indices
# taken from the piece itself.
true_chord_classifications = joint_model.get_chord_classifications(
    correct_chord_windows,
    correct_chord_change_indices,
)
joint_model.debugger.debug_chord_classifications(
    correct_chord_ranges,
    true_chord_classifications,
)

In [None]:
# To investigate a specific range
# Named target_range rather than `range`, so the builtin range() is not shadowed
# for every later cell in the kernel.
target_range = (1875, 1886)
range_start, range_end = target_range

piece_inputs = piece.get_inputs()

# Reuse existing classifications when the range is already known; otherwise
# classify it on the fly.
if target_range in correct_chord_ranges:
    ranges = correct_chord_ranges
    log_probs = true_chord_classifications
elif target_range in chord_ranges:
    ranges = chord_ranges
    log_probs = chord_classifications
else:
    ranges = [target_range, (0, 2)]
    # Build windows and change indices the same way the noisy chord ranges are
    # prepared, so this call matches the two-argument form used for the other
    # classifications in this notebook.
    # NOTE(review): confirm get_chord_classifications expects (windows, change_indices).
    range_change_indices = [start for start, _ in ranges]
    range_windows = [
        (get_range_start(piece_inputs[start].onset, piece_inputs), end)
        for start, end in ranges
    ]
    log_probs = joint_model.get_chord_classifications(range_windows, range_change_indices)

index = ranges.index(target_range)
log_prior = log_probs[index]

# Ground truth for the range, plus up to 10 inputs of context before its start.
context_start = max(range_start - 10, 0)
correct_chords = piece.get_chords_within_range(range_start, range_end)
correct_notes = piece_inputs[context_start:range_end]

print("Correct chords:")
for chord in correct_chords:
    print(f"    {chord!s}")

# Show the chord rows whose "mc" value matches the first onset component of a correct chord.
file_chords_df = chords_df.loc[file_id]
correct_mcs = {chord.onset[0] for chord in correct_chords}

cols = ["mc", "onset", "label", "globalkey", "localkey", "globalkey_is_minor", "localkey_is_minor", "chord", "numeral", "figbass", "relativeroot", "changes", "root", "bass_note", "duration"]
print(file_chords_df.loc[file_chords_df["mc"].isin(correct_mcs), cols])

print("Notes:")
for note in correct_notes:
    print(f"    {note!s}")

print(notes_df.loc[file_id].iloc[context_start:range_end])

In [None]:
# Evaluate other models in beam search
# Start from the piece's first key, with the initial chord model's log prior for
# that key's mode.
first_key = piece.get_keys()[0]
initial_key_index = first_key.get_one_hot_index()
initial_log_prior = joint_model.initial_chord_model.get_prior(
    first_key.relative_mode == KeyMode.MINOR,
    log=True,
)
state = State(key=initial_key_index, csm_log_prior=initial_log_prior)
joint_model.debugger.debug_initial_chord_prior(state.csm_log_prior)

# Transition into the piece's first chord, passing the second chord change index.
first_chord_one_hot = piece.get_chords()[0].get_one_hot_index(
    relative=False,
    use_inversion=True,
    pad=False,
)
second_change_index = piece.get_chord_change_indices()[1]
state = state.chord_transition(
    first_chord_one_hot,
    second_change_index,
    0,
    joint_model.CHORD_OUTPUT_TYPE,
    joint_model.LABELS,
)

In [None]:
# Walk the piece's chords (from the second chord onward), feeding each one through
# the joint model's chord-sequence and key models along a single state.
# Set pause_on_key_changes to True to wait for keyboard input at every key change;
# entering 'q' stops the walk.
pause_on_key_changes = False

key_changes_input_indices = piece.get_key_change_input_indices()

# Chord 0 was already consumed when the initial state was built, so start at chord 1.
for chord_index, (start, end) in enumerate(correct_chord_ranges[1:], start=1):
    chord = piece.get_chords()[chord_index]
    chord_one_hot = chord.get_one_hot_index(relative=False, use_inversion=True, pad=False)
    # Not used below; kept available for interactive inspection.
    chord_relative_one_hot = chord.get_one_hot_index(relative=True, use_inversion=True, pad=False)

    # Run the csm on the current state, then transition into this chord.
    # NOTE(review): `end` is None for the last range — confirm chord_transition accepts that.
    joint_model.run_csm_batched([state])
    state = state.chord_transition(
        chord_one_hot,
        end,
        0,
        joint_model.CHORD_OUTPUT_TYPE,
        joint_model.LABELS,
    )

    if start in key_changes_input_indices:
        # input() is only evaluated when pause_on_key_changes is True (short-circuit).
        if pause_on_key_changes and input() == 'q':
            break

        # Key change
        key_index = key_changes_input_indices.index(start)
        # prev_key is not used below; kept available for interactive inspection.
        prev_key = piece.get_keys()[key_index - 1]
        next_key = piece.get_keys()[key_index]

        # KTM
        # The returned value is not used further in this cell.
        key_change_probs = joint_model.get_key_change_probs([state])

        # KSM
        # The return value is discarded; the transition below uses the piece's own next key.
        joint_model.get_key_change_states([state])
        state = state.key_transition(
            next_key.get_one_hot_index(),
            0,
            joint_model.KEY_OUTPUT_TYPE,
            joint_model.LABELS,
        )

    else:
        # No key change at this chord: log the chord sequence priors for the state.
        joint_model.debugger.debug_chord_sequence_priors([state])