In [1]:
%load_ext autoreload
%autoreload 2

In [158]:
import os
import h5py
import random
import numpy as np
import sys
import pickle as pkl

from collections import Counter
from omegaconf import OmegaConf, DictConfig
import torch
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler, Subset, random_split
from torchtext.vocab import Vocab

import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers

from sklearn.model_selection import train_test_split

from chord_rec.models.lit_seq2seq import LitSeq2Seq

from chord_rec.datasets.vec_datasets import Vec45Dataset

In [176]:
class CheckpointEveryNEpoch(pl.Callback):
    def __init__(self, start_epoc, ckpt_every_n = 1):
        self.start_epoc = start_epoc
        self.ckpt_every_n = ckpt_every_n

    def on_epoch_end(self, trainer: pl.Trainer, _):
        """ Check if we should save a checkpoint after every train epoch """
        # file_path = f"{trainer.logger.log_dir}/checkpoints/epoch={trainer.current_epoch}.pt"
        epoch = trainer.current_epoch
        if epoch >= self.start_epoc and epoch % self.ckpt_every_n == 0:
            ckpt_path = f"{trainer.logger.log_dir}/checkpoints/epoch={epoch}.ckpt"
            trainer.save_checkpoint(ckpt_path)

In [187]:
ckp_dir = "D:\\Documents\\2021Spring\\ChordSymbolRec\\chord_rec\\logs\\final_run\\version_1"
hparams_path = os.path.join(ckp_dir, "hparams.yaml")
checkpoint_path = os.path.join(ckp_dir, "checkpoints", "epoch=169-step=18699.ckpt")

all_conf = OmegaConf.load(hparams_path)
conf = all_conf.configs
data_conf = conf.dataset
seed = conf.experiment.seed
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

if conf.experiment.device == "gpu" and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

data_root = conf.dataset.directory
dataset_name = conf.dataset.name

In [188]:
data = pkl.load(open(data_conf.fpath, "rb"))

note_seq, chord_seq = [],[]
max_seq_len = 0
data_num = 0
for file in data:
    data_num += len(file)
    for window in file:
        note_seq.append(window[0])
        chord_seq.append(window[1])
        max_seq_len = max(max_seq_len, len(window[1]))

note_padding_vec = np.full(len(note_seq[0][0]), -1).reshape(1,-1) # should be 45; not sure if -1 is good
note_ending_vec = np.ones(len(note_seq[0][0])).reshape(1,-1) # should be 45
note_starting_vec = np.zeros(len(note_seq[0][0])).reshape(1,-1) # should be 45

chord_start = "<sos>"
chord_padding = "<pad>"
chord_end = "<eos>"

padded_note_seq = []
padded_chord_seq = []

eval_masks = []

for i in range(len(note_seq)):
    len_diff = max_seq_len - len(note_seq[i])

    temp_note_vec = np.vstack((note_starting_vec, np.array(note_seq[i]), note_ending_vec, np.repeat(note_padding_vec, len_diff , axis = 0)))
    padded_note_seq.append(temp_note_vec)

    eval_masks.append([False] + [True for _ in range(len(note_seq[i]))] + [False for _ in range(len_diff+1)])
    temp_chord_vec = np.hstack((chord_start, np.array(chord_seq[i]), chord_end, np.repeat(chord_padding, len_diff , axis = 0)))
    padded_chord_seq.append(temp_chord_vec)

eval_masks = np.array(eval_masks)
stacked_note_seq = np.stack(padded_note_seq, axis = 0)
stacked_chord_seq = np.vstack(padded_chord_seq)

note_vec = np.asarray(stacked_note_seq, dtype = np.float32)
chord_vocab = Vocab(Counter(list(stacked_chord_seq.flatten())))

vec_size = note_vec.shape[-1]
vocab_size = len(chord_vocab.stoi)

assert data_conf.val_ratio + data_conf.test_ratio <= 0.6, "At least 40 percent of the data needed for training"

dataset = Vec45Dataset(note_vec, stacked_chord_seq, eval_masks, chord_vocab)


train_ratio = 1 - data_conf.val_ratio - data_conf.test_ratio

train_len = int(len(dataset)*train_ratio)
val_len = int(len(dataset)*data_conf.val_ratio)
test_len = len(dataset) - train_len - val_len

train_dataset, val_dataset, test_dataset = random_split(dataset, [train_len, val_len, test_len], 
                                                generator=torch.Generator().manual_seed(seed)
                                               )


train_loader = DataLoader(train_dataset, batch_size =data_conf.batch_size, shuffle = data_conf.shuffle_train, num_workers = data_conf.num_workers, drop_last = True)
val_loader = DataLoader(val_dataset, batch_size = data_conf.batch_size, shuffle = data_conf.shuffle_val, num_workers = data_conf.num_workers, drop_last = True)
test_loader =  DataLoader(test_dataset, batch_size = data_conf.batch_size, shuffle = data_conf.shuffle_val, num_workers = data_conf.num_workers, drop_last = True)

MAX_LEN = max_seq_len + 2

if conf.model.type == "attn_s2s":
    model = LitSeq2Seq(vec_size, MAX_LEN, chord_vocab, conf)
else:
    raise NotImplementedError

In [189]:
model = LitSeq2Seq.load_from_checkpoint(checkpoint_path, chord_vocab = chord_vocab)

In [190]:
epochs = conf.training.warm_up + conf.training.decay_run + conf.training.post_run
tb_logger = pl_loggers.TensorBoardLogger(conf.logging.output_dir, name = conf.experiment.objective)
trainer = pl.Trainer()

GPU available: True, used: False
INFO:lightning:GPU available: True, used: False
TPU available: None, using: 0 TPU cores
INFO:lightning:TPU available: None, using: 0 TPU cores


In [191]:
trainer.test(model, test_dataloaders = test_loader)

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

  9%|██████▌                                                                     | 1712/19715 [00:09<01:38, 182.16it/s]


KeyError: ''

In [119]:
model.eval()
all_pred2 = []
all_label2 = []
eval_masks = []
for idx, (note, chord, mask) in enumerate(tqdm(test_loader)):
        pred = model(note.to(device), chord.long().to(device), teacher_forcing = False, start_idx = chord_vocab.stoi["<sos>"])
        pred = pred.detach().cpu().numpy().argmax(axis = -1)
        
        label = chord.detach().cpu().numpy()
        pred[:,0] = np.full(len(pred), chord_vocab.stoi["<sos>"])
        all_pred2.append(dataset.vec_decode(pred))
        all_label2.append(dataset.vec_decode(label))
        eval_masks.append(mask.detach().cpu().numpy())

HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))




In [120]:
all_pred2 = np.vstack(all_pred2)
all_label2 = np.vstack(all_label2)

In [121]:
all_pred2[0]

array(['<sos>', 'B dominant seventh', 'E major', 'E major',
       'Cx german augmented sixth', 'Cx german augmented sixth',
       'Cx german augmented sixth', 'Cx german augmented sixth',
       'Cx german augmented sixth', 'Cx german augmented sixth',
       'D# major', 'D# major', 'D# major', 'D# major',
       'Cx italian augmented sixth', '<eos>', '<eos>', '<eos>', '<eos>',
       '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>',
       '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>',
       '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>',
       '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>',
       '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>',
       '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>',
       '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>',
       '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>',
       '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos

In [122]:
all_label2[0]

array(['<sos>', 'B major seventh', 'E major', 'E major', 'E major',
       'Cx german augmented sixth', 'Cx german augmented sixth',
       'Cx german augmented sixth', 'Cx german augmented sixth',
       'Cx german augmented sixth', 'C# minor', 'D# major', 'D# major',
       'D# major', 'Cx italian augmented sixth', '<eos>', '<pad>',
       '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>',
       '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>',
       '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>',
       '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>',
       '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>',
       '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>',
       '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>',
       '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>',
       '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>',
       '<pad>', '<pad>', '<pad>',

In [123]:
decoded_preds = all_pred2
decoded_chords = all_label2

In [124]:
decoded_preds

array([['<sos>', 'B dominant seventh', 'E major', ..., '<eos>', '<eos>',
        '<eos>'],
       ['<sos>', 'A# major', 'A# major', ..., '<eos>', '<eos>', '<eos>'],
       ['<sos>', 'Bb major', 'Bb major', ..., '<eos>', '<eos>', '<eos>'],
       ...,
       ['<sos>', 'G major', 'D major', ..., '<eos>', '<eos>', '<eos>'],
       ['<sos>', 'F major', 'F major', ..., '<eos>', '<eos>', '<eos>'],
       ['<sos>', 'B major', 'B major', ..., '<eos>', '<eos>', '<eos>']],
      dtype='<U27')

In [127]:
eval_masks = np.vstack(eval_masks)

In [128]:
# mask = (decoded_preds != "<sos>") & (decoded_preds != "<eos>") & (decoded_preds != "<pad>")
mask = eval_masks
masked_preds = decoded_preds[mask]
masked_chords = decoded_chords[mask]

print(np.sum(masked_preds == masked_chords) / len(masked_chords))

0.8727689628173785


In [26]:
# import pickle
# pickle.dump({"preds":all_pred2, "labels": all_label2 }, open("examples/output/haydn_red1_preds.pk", "wb"))

In [27]:
masked_preds == masked_chords

array([ True,  True,  True, ...,  True, False,  True])

In [28]:
# SEPERATE EVALUATION OF ROOT AND QUALITY AFTER DECODING
# seperate all pred 
root_preds = decoded_preds.copy()
quality_preds = decoded_preds.copy()
for r_id in range(decoded_preds.shape[0]):
    for c_id in range(decoded_preds.shape[1]):
        sp = decoded_preds[r_id, c_id].split(' ')
        root_preds[r_id, c_id] = sp[0]
        quality_preds[r_id, c_id] = ' '.join(sp[1:])
    
root_labels = decoded_chords.copy()
quality_labels = decoded_chords.copy()
for r_id in range(decoded_chords.shape[0]):
    for c_id in range(decoded_chords.shape[1]):
        sp = decoded_chords[r_id, c_id].split(' ')
        root_labels[r_id, c_id] = sp[0]
        quality_labels[r_id, c_id] = ' '.join(sp[1:])
# # seperate all lable 
# root_labels = []
# quality_labels = []
# for c in decoded_chords:
#     sp = c.split(' ')
#     root_labels.append(sp[0])
#     quality_labels.append(' '.join(sp[1:]))
    
# root_labels = np.asarray(root_labels)
# quality_labels = np.asarray(quality_labels)

In [29]:
mask = (root_preds != "<sos>") & (root_preds != "<eos>") & (root_preds != "<pad>")
root_preds = root_preds[mask]
quality_preds = quality_preds[mask]
root_label = root_labels[mask]
quality_labels = quality_labels[mask]

In [30]:
np.sum(root_preds == root_label) / len(root_preds)

0.8458368083952097

In [31]:
np.sum(quality_preds == quality_labels) / len(quality_preds)

0.8411080251661956

In [150]:
import time

In [157]:
# for i in range(100):

#     start = time.time()
#     avg_similarity = chord_similarity(masked_preds[i], masked_chords[i])
#     end = time.time()
#     print(end-start)

0.0050122737884521484
0.0039865970611572266
0.004013538360595703
0.004997968673706055
0.005002021789550781
0.004987001419067383
0.005013465881347656
0.0049877166748046875
0.004999399185180664
0.003989696502685547
0.004014253616333008
0.004986763000488281
0.0040132999420166016
0.003999948501586914
0.005000114440917969
0.004999637603759766
0.004986763000488281
0.005013465881347656
0.005002498626708984
0.004984617233276367
0.005012989044189453
0.0039861202239990234
0.005013704299926758
0.0049855709075927734
0.005014657974243164
0.00398564338684082
0.0050008296966552734
0.0040132999420166016
0.004000186920166016
0.004999637603759766
0.004999876022338867
0.004000186920166016
0.0039865970611572266
0.0040132999420166016
0.004999876022338867
0.00500035285949707
0.005000114440917969
0.004986286163330078
0.005013465881347656
0.004000186920166016
0.004999637603759766
0.005013942718505859
0.005000114440917969
0.004000186920166016
0.0049855709075927734
0.004014015197753906
0.005000114440917969
0.00

In [149]:
print(len(masked_chords) // 1000 * 5 )

625


In [152]:
end-start

4.800968408584595

# Confusion Matrix

In [None]:
co = Counter(chords)

In [None]:
most_common_chord = list(list(zip(*co.most_common(20)))[0])

In [None]:
for i in range(len(decoded_preds)):
    if decoded_preds[i] not in most_common_chord:
        decoded_preds[i] = "others"
    if decoded_chords[i] not in most_common_chord:
        decoded_chords[i] = "others"

In [None]:
from sklearn.metrics import confusion_matrix, plot_confusion_matrix
import seaborn as sn

In [64]:
cm = confusion_matrix(decoded_chords, decoded_preds, normalize = "true", labels = most_common_chord + ["others"])

NameError: name 'confusion_matrix' is not defined

In [None]:
fig, ax = plt.subplots(figsize=(13,10)) 

sn.heatmap(cm, annot=False)
ax.set_xticklabels(most_common_chord + ["others"])
ax.set_yticklabels(most_common_chord + ["others"])
plt.yticks(rotation=0) 
plt.xticks(rotation="vertical") 
plt.show()
fig.savefig("confusion_bachhaydn_baseline.pdf", format = "pdf")

In [None]:
torch.save(model.state_dict(), "baseline_bach_and_haydn.pt")

In [None]:
symbol, num = list(zip(*co.most_common(50)))
symbol = list(symbol)
num = list(num)

In [None]:
symbol += ['others']
num += [np.sum(list(co.values())) - np.sum(num)]

In [None]:
num/=np.sum(num)

In [None]:
plt.subplots(figsize=(13,10)) 
x_pos = [i for i, _ in enumerate(symbol)]

plt.bar(x_pos, num)
plt.xlabel("Chord Symbol")
plt.ylabel("Occurance")
plt.title("bach chorales Chord Distribution")

plt.xticks(x_pos, symbol, rotation = "vertical")

plt.show()