In [None]:
import pandas as pd
import numpy as np

In [None]:
from corpus_reading import read_dump
import corpus_utils

# Paths of the tab-separated dumps that make up the corpus.
chords_tsv, notes_tsv, measures_tsv, files_tsv = (
    'data/chord_list.tsv',
    'data/note_list.tsv',
    'data/measure_list.tsv',
    'data/file_list.tsv',
)

# Read every dump with the project's reader. notes_df is indexed by its first
# three columns and files_df by its first column; the other two rely on
# read_dump's default index handling.
chords_df = read_dump(chords_tsv)
notes_df = read_dump(notes_tsv, index_col=[0, 1, 2])
measures_df = read_dump(measures_tsv)
files_df = read_dump(files_tsv, index_col=0)

# Bugfixes
# Overwrite, in place, the first element of the 'next' entry stored at index
# (685, 487) with 488.
# NOTE(review): this relies on .loc handing back the stored mutable container
# itself rather than a copy — confirm against read_dump's output.
patched_next = measures_df.loc[(685, 487), 'next']
patched_next[0] = 488

In [None]:
# Display files_df (read from files_tsv with its first column as the index).
files_df

In [None]:
# Display chords_df (read from chords_tsv).
chords_df

In [None]:
# Display measures_df (read from measures_tsv, including the manual 'next' patch
# applied right after loading).
measures_df

In [None]:
# Display notes_df (read from notes_tsv with its first three columns as the index).
notes_df

In [None]:
import harmonic_inference_data as hid
import importlib
importlib.reload(hid)

# Build train/valid/test datasets for three data variants. The variants share
# every argument except the h5 file prefix and the transposition flag:
#   'global' -> transpose_global=True
#   'local'  -> transpose_local=True
#   'none'   -> no transposition flag passed
# NOTE(review): what each transposition flag does is defined in
# harmonic_inference_data.get_train_valid_test_splits — confirm there.
variant_kwargs = {
    'global': {'h5_prefix': 'globalkey_811split', 'transpose_global': True},
    'local': {'h5_prefix': 'localkey_811split', 'transpose_local': True},
    'none': {'h5_prefix': '811split'},
}

datasets = {}
for variant, extra_kwargs in variant_kwargs.items():
    # Unpack explicitly so a return value that is not exactly three splits
    # still fails loudly, as the original three-target assignment did.
    train_split, valid_split, test_split = hid.get_train_valid_test_splits(
        chords_df=chords_df, notes_df=notes_df, measures_df=measures_df, files_df=files_df,
        seed=0, h5_directory='data', make_dfs=True, **extra_kwargs
    )
    datasets[variant] = {'train': train_split, 'valid': valid_split, 'test': test_split}

# Variant used by the rest of the notebook (also drives the results prefix).
data = 'local'

# The model, trainer and evaluation cells refer to the selected splits by these
# names. Bind them here so the notebook runs top to bottom on a fresh kernel
# instead of relying on variables left over from an earlier session.
train_dataset = datasets[data]['train']
valid_dataset = datasets[data]['valid']
test_dataset = datasets[data]['test']

In [None]:
import harmonic_inference_models as him
import harmonic_utils
import importlib
importlib.reload(him)
from ablation import get_masks_and_names
import torch
import os

# Ablation setup: `masks` holds the input masks and `mask_names` the matching
# names used for result files.
masks, mask_names = get_masks_and_names()
dir_name = 'results'

# Which ablation to run; -1 selects the last (mask, name) pair.
index = -1
# Result files are prefixed with the data variant, except for the 'none' variant.
# Compare by value: `is not` against a str literal tests object identity, which
# only works by accident of string interning and raises a SyntaxWarning on
# Python 3.8+.
prefix = (data + '_') if data != 'none' else ''

mask = torch.tensor(masks[index])
mask_name = mask_names[index]
checkpoint_dir = os.path.join(dir_name, prefix + mask_name)
log_file_name = os.path.join(dir_name, prefix + mask_name + '.log')

# First argument: length of the first note vector of the first training sample.
# Second argument: number of chord types times 12.
# NOTE(review): presumably input size and number of output classes (chord type
# x 12 roots) — confirm against him.MusicScoreModel.
model = him.MusicScoreModel(len(train_dataset[0]['notes'][0]), len(harmonic_utils.CHORD_TYPES) * 12,
                            dropout=0.2, input_mask=mask)

print(mask_name)
print(checkpoint_dir)

In [None]:
import os
import model_trainer
import importlib
importlib.reload(model_trainer)

from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

# Optimisation setup: Adam with weight decay, cross-entropy loss, and a plateau
# scheduler that scales the learning rate by 0.1 once the monitored value has
# stopped decreasing for 10 epochs.
adam_options = dict(lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.001)
optimizer = Adam(model.parameters(), **adam_options)
criterion = CrossEntropyLoss()
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)
# NOTE(review): presumably the name of the metric the trainer feeds to the
# scheduler — confirm in model_trainer.ModelTrainer.
schedule_var = 'valid_loss'

# The trainer is pointed at the best checkpoint of this ablation's directory.
# NOTE(review): behaviour when that file does not exist yet depends on
# ModelTrainer's `resume` handling — confirm.
best_checkpoint = os.path.join(checkpoint_dir, 'best.pth.tar')
trainer_options = dict(
    train_dataset=train_dataset,
    valid_dataset=valid_dataset,
    test_dataset=test_dataset,
    seed=0,
    num_epochs=100,
    early_stopping=20,
    optimizer=optimizer,
    scheduler=scheduler,
    schedule_var=schedule_var,
    criterion=criterion,
    log_every=1,
    log_file_name=log_file_name,
    save_every=10,
    save_dir=checkpoint_dir,
    save_prefix='checkpoint',
    resume=best_checkpoint,
)
trainer = model_trainer.ModelTrainer(model, **trainer_options)

In [None]:
# Start training with the ModelTrainer configured in the previous cell.
# NOTE(review): epoch limit, early stopping and checkpointing are whatever
# ModelTrainer does with the arguments passed at construction — see model_trainer.
trainer.train()

In [None]:
# Rebuild the trainer with the same settings as the training cell, again
# resuming from this ablation's 'best.pth.tar', and evaluate it.
best_checkpoint = os.path.join(checkpoint_dir, 'best.pth.tar')
trainer = model_trainer.ModelTrainer(
    model,
    train_dataset=train_dataset, valid_dataset=valid_dataset, test_dataset=test_dataset,
    seed=0, num_epochs=100, early_stopping=20,
    optimizer=optimizer, scheduler=scheduler, schedule_var=schedule_var,
    criterion=criterion,
    log_every=1, log_file_name=log_file_name,
    save_every=10, save_dir=checkpoint_dir, save_prefix='checkpoint',
    resume=best_checkpoint,
)

# NOTE(review): valid=False presumably selects the test set rather than the
# validation set — confirm in ModelTrainer.evaluate.
evaluation = trainer.evaluate(valid=False)
loss, acc, outputs, labels = evaluation
print(loss, acc)

In [None]:
import eval_utils as eu
import harmonic_utils as hu
import matplotlib.pyplot as plt

# Confusion matrix of the evaluation results, drawn as an image with one tick
# per label string on both axes.
# NOTE(review): which axis holds true labels and which holds predictions is
# decided by eu.get_conf_mat — confirm before interpreting the plot.
label_strings = hu.get_one_hot_labels()
conf_mat = eu.get_conf_mat(labels, outputs)
tick_positions = list(range(len(label_strings)))

plt.figure(figsize=(30, 30))
plt.imshow(conf_mat, interpolation='none')
plt.colorbar()
plt.xticks(ticks=tick_positions, labels=label_strings, rotation=90, fontsize=10)
plt.yticks(ticks=tick_positions, labels=label_strings, fontsize=10)
plt.show()

In [None]:
import eval_utils as eu

# Split the evaluated examples into correct and incorrect index collections
# and report the size of each.
correct, incorrect = eu.get_correct_and_incorrect_indexes(labels, outputs)
for heading, index_group in (('Correct: ', correct), ('Incorrect: ', incorrect)):
    print(heading + str(len(index_group)))

In [None]:
import eval_utils as eu
    
# Print the result for the first entry of `incorrect`.
# NOTE(review): raises IndexError when `incorrect` is empty. The meaning of
# limit=10 and prob=False is defined by eu.print_result — confirm there.
eu.print_result(incorrect[0], labels, outputs, limit=10, prob=False)

In [None]:
import eval_utils as eu

# Look up the input rows behind the first entry of `incorrect` and print the
# chord, then the two note collections under their headings.
first_incorrect = incorrect[0]
chord, onset_notes, all_notes = eu.get_input_df_rows(first_incorrect, test_dataset)

print(chord)
print("USED NOTES:")
# end="\n\n" emits the blank separator line after the onset notes.
print(onset_notes, end="\n\n")
print("ALL NOTES:")
print(all_notes)

In [None]:
import matplotlib.pyplot as plt
import eval_utils as eu

# Bar chart with one bar per output position: the bar height is the number of
# indexes collected for that rank by eu.get_correct_ranks.
correct_ranks, indexes_by_rank = eu.get_correct_ranks(labels, outputs)

plt.figure(figsize=(30, 30))
rank_positions = range(len(outputs[0]))
rank_counts = list(map(len, indexes_by_rank))
plt.bar(rank_positions, rank_counts)

In [None]:
import eval_utils as eu
import importlib
importlib.reload(eu)

# Build the evaluation frame from the labels, the model outputs and the test
# dataset, then display it as the cell output.
# NOTE(review): the frame's columns are defined by eu.get_eval_df — see eval_utils.
eval_df = eu.get_eval_df(labels, outputs, test_dataset)
eval_df

In [None]:
import ablation
import importlib
importlib.reload(ablation)

# Load the result frames of all ablations from 'results'. `prefix` carries a
# trailing '_' when non-empty, so that character is stripped; an empty prefix
# is passed on as None.
ablation_prefix = None if len(prefix) == 0 else prefix[:-1]
dfs = ablation.load_all_ablated_dfs(directory='results', prefix=ablation_prefix)

# Only the names are needed here; the masks themselves are discarded.
_, mask_names = ablation.get_masks_and_names()

In [None]:
import pandas as pd
import os

# Read the training log of every ablation: 'results/<prefix><mask name>.log',
# parsed as CSV, in the same order as `mask_names`.
# A comprehension keeps the loop variable local, so the `mask_name` chosen in
# the model-setup cell is not overwritten, and the path is joined once (the
# extra single-argument os.path.join wrapper was a no-op).
logs = [
    pd.read_csv(os.path.join('results', prefix + ablated_name + '.log'))
    for ablated_name in mask_names
]

In [None]:
# For each ablation, print its accuracy in percent (sum of the `correct` column
# over the number of rows) followed by the last row of its training log.
# NOTE(review): zip stops at the shortest of dfs / logs / mask_names, so a
# length mismatch is not reported — confirm they always line up.
for ablated_df, training_log, mask_name in zip(dfs, logs, mask_names):
    accuracy_pct = 100 * ablated_df.correct.sum() / len(ablated_df)
    print(f"{mask_name} Acc: {accuracy_pct}")
    print(training_log.iloc[-1])