In [None]:
# Reload edited project modules (double_jig_gen) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import copy
import logging
import os
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import music21
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from joblib import Parallel, delayed
from tqdm import tqdm

from double_jig_gen.data import ABCDataset, get_oneills_dataloaders, get_folkrnn_dataloaders
from double_jig_gen.tokenizers import Tokenizer, ABCTune, ABCTuneError

logging.basicConfig()
LOGGER = logging.getLogger(__name__)
LOGGER.setLevel("DEBUG")

In [None]:
# GPU index exposed to this process via CUDA_VISIBLE_DEVICES (set in a later cell).
DEVICE_ID = 7
# DATA_HOME = "/disk/scratch_fast/s0816700/data"
DATA_HOME = "data"
# Raw folk-rnn data file; tunes are separated by blank lines ("\n\n").
DATA_PATH = f"{DATA_HOME}/folk-rnn/data_v1"

In [None]:
# https://github.com/IraKorshunova/folk-rnn/blob/master/configurations/config5.py
# NOTE(review): ONE_HOT, LEARNING_RATE*, MAX_EPOCH, GRAD_CLIPPING, VALIDATE_EVERY
# and SAVE_EVERY are not passed to the model or trainer anywhere in this notebook.
ONE_HOT = True
EMBEDDING_SIZE = 256  # is ignored if one_hot=True
NUM_LAYERS = 3
RNN_SIZE = 512
DROPOUT = 0.5

LEARNING_RATE = 0.003
LEARNING_RATE_DECAY_AFTER = 20
LEARNING_RATE_DECAY = 0.97

BATCH_SIZE = 64
MAX_EPOCH = 100
GRAD_CLIPPING = 5
VALIDATION_FRACTION = 0.05
VALIDATE_EVERY = 1000  # iterations

SAVE_EVERY = 10  # epochs

In [None]:
# CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is first
# initialised, i.e. before the torch.cuda.is_available() call below.
LOGGER.info(f"Changing to device {DEVICE_ID}")
os.environ["CUDA_VISIBLE_DEVICES"] = f"{DEVICE_ID}"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LOGGER.info(f"device = {device}")

In [None]:
with open(DATA_PATH, 'r') as fh:
    raw_folkrnn_data = fh.read()

In [None]:
# Peek at the start of the raw file.
raw_folkrnn_data[:1000]

In [None]:
# Smoke test: parse only the first tune (tunes are separated by a blank line).
# NOTE(review): min_pitch/min_time are None here but 0 in get_abc_tune below.
abc_data = raw_folkrnn_data.split("\n\n")[0]
abc_tune = ABCTune(
    abc_data,
    pianoroll_divisions_per_quarternote=2,
    min_pitch=None,
    min_time=None,
    transpose_to_pitchclass="C",
)
abc_tune

In [None]:
abc_tune.show()

In [None]:
def get_abc_tune(abc_data):
    """Parse one ABC-notation tune into an ``ABCTune``, never raising.

    The tune is parsed with 12 pianoroll divisions per quarter note,
    ``min_pitch=0``, ``min_time=0`` and transposition to C.

    Args:
        abc_data: the ABC text of a single tune.

    Returns:
        The parsed ``ABCTune`` on success. On any exception, the warning
        message that was logged (a ``str``) is returned instead, so callers
        can separate failures with ``isinstance(result, str)``.
    """
    try:
        return ABCTune(
            abc_data,
            pianoroll_divisions_per_quarternote=12,
            min_pitch=0,
            min_time=0,
            transpose_to_pitchclass="C",
        )
    except ABCTuneError as err:
        # Expected parse failure reported by the tokenizer.
        reason = f"Raised error: {repr(err)}"
    except Exception as err:
        # Deliberately broad: one malformed tune must not abort a bulk parse.
        reason = f"unexpected error: {repr(err)}"
    message = f"{reason}\nNot including {abc_data}."
    LOGGER.warning(message)
    return message

In [None]:
# Clear stale tqdm progress-bar instances (e.g. left behind by interrupted loops).
tqdm._instances.clear()

In [None]:
# Tunes in the raw file are separated by blank lines.
abc_data_list = raw_folkrnn_data.split("\n\n")
len(abc_data_list)

In [None]:
# Only the first nr_tunes tunes are parsed below.
nr_tunes = 1_000

In [None]:
# TODO: extract tokens
# TODO: clean tunes with warnings e.g remove all text in quotes before extraction?
# get_abc_tune returns an error-message str (not an ABCTune) for tunes it cannot parse.
tunes = [get_abc_tune(abc_data) for abc_data in tqdm(abc_data_list[:nr_tunes])]

In [None]:
# Keep only the successfully parsed tunes.
clean_tunes = [tune for tune in tunes if not isinstance(tune, str)]

In [None]:
len(clean_tunes)

In [None]:
tune = clean_tunes[0]
tune

In [None]:
tune.tokens

In [None]:
# Parse a small hand-written tune (chords, triplets, repeat endings) to inspect
# the music21 elements it produces.
abc_tune = ABCTune("T: maitune\nM:3/4\nL:1/8\n|: [ACE]2B2D2 | [ceg]6 | [1 (3ABC (3DEF (3GAB :| [2 (3ABC (3DEF (3GAe |]")
[tok for tok in abc_tune.abc_music21.flat]

In [None]:
abc_tune.show()

In [None]:
abc_tune.play()

In [None]:
abc_tune.events

In [None]:
# Work on a deep copy so experimenting with this token leaves `tune` untouched.
tok = copy.deepcopy(tune.tokens[5])

In [None]:
vars(tune._abc_handler)

In [None]:
# Deep copy so re-running tokenProcess() does not touch the tune's own handler.
handler = copy.deepcopy(tune._abc_handler)

In [None]:
handler.tokenProcess()

In [None]:
handler.tokens

In [None]:
# NOTE(review): `inspect` is imported but not used in this notebook.
import inspect
elem_list = [elem for elem in tune.abc_music21.recurse()]
elem_list

In [None]:
elem_list[3]

In [None]:
# Build a small score by hand (4 measures of A, B-flat, C-sharp) to compare its
# text representation with that of a parsed tune.
# (music21 is already imported at the top; this re-import is redundant.)
import music21
score = music21.stream.Score()
part = music21.stream.Part()
nr_measures = 4
for _ in range(nr_measures):
    measure = music21.stream.Measure()
    notes = (("A", 1), ("B-", 0.5), ("C#", 1.5))
    for pitch_name, quarter_length in notes:
        measure.append(music21.note.Note(pitch_name, quarterLength=quarter_length))
    part.append(measure)
score.insert(0, part)
score.show("text")

In [None]:
tune.abc_music21.show("text")

In [None]:
set(tok.src for tok in tune._abc_handler.tokens)

In [None]:
handler.tokens

In [None]:
tok.preParse()

In [None]:
tok.quarterLength

In [None]:
tune.tokens

In [None]:
[tok for tok in tune.abc_music21.flat]

In [None]:
tune.tokens

In [None]:
len(tunes)

In [None]:
tunes

In [None]:
# (index into `tunes`, error message) for every tune that failed to parse.
bum_tunes = [(idx, tune) for idx, tune in enumerate(tunes) if isinstance(tune, str)]

In [None]:
len(bum_tunes)

In [None]:
", ".join([str(idx) for idx, msg in bum_tunes])

In [None]:
len(bum_tunes)

In [None]:
for idx, msg in bum_tunes:
    print(idx)
    print(msg)
    print()

In [None]:
# Retry the failed tunes after replacing a mis-encoded apostrophe. The bytes
# below are the UTF-8 encoding of "\u00e2\u0080\u0099" - presumably a right
# single quote (U+2019) whose UTF-8 bytes were once decoded as Latin-1 - and
# they are replaced with a plain ASCII "'".
bum_tune_abc_reencoded = [
    bytes(
        abc_data_list[idx], "utf-8"
    ).replace(
        b"\xc3\xa2\xc2\x80\xc2\x99",
        "'".encode("utf-8")
    ).decode("utf-8")
    for idx, _ in bum_tunes
]
bum_tune_abc_reencoded[:3]

In [None]:
# NOTE(review): the tunes rescued here are never merged into `clean_tunes`,
# so later cells do not use them.
extra_tunes = [get_abc_tune(abc_data) for abc_data in tqdm(bum_tune_abc_reencoded)]

In [None]:
# Indices here are positions in `extra_tunes`, not in `tunes`.
bum_bum_tunes = [(idx, tune) for idx, tune in enumerate(extra_tunes) if isinstance(tune, str)]

In [None]:
len(bum_bum_tunes)

In [None]:
for idx, msg in bum_bum_tunes:
    print(idx)
    print(msg)
    print()

In [None]:
tunes[:3]

In [None]:
tqdm._instances.clear()

In [None]:
# tunes = Parallel(n_jobs=-1)(delayed(get_abc_tune)(abc_data) for abc_data in raw_folkrnn_data.split("\n\n"))

In [None]:
len(tunes)

The data are tokens separated by spaces, with pieces separated by two `\n` characters (i.e. a blank line). Pieces begin with a meter, then a new line with a key, then a new line with the piece.

In [None]:
len(clean_tunes)

In [None]:
clean_tunes[0]

In [None]:
# Source text of every token produced by each tune's internal ABC handler.
tokenized_str_tunes = [[tok.src for tok in tune._abc_handler.tokens] for tune in tqdm(clean_tunes)]

In [None]:
# Title ("T:") tokens are excluded from the vocabulary.
# NOTE(review): they remain in tokenized_str_tunes, so they presumably map to
# <unk> when tokenized - confirm against Tokenizer.
tokens_set = set(tok for tune in tokenized_str_tunes for tok in tune if not tok.startswith("T:"))
vocab_size = len(tokens_set)
print(f"vocabulary size: {vocab_size}")
print(f"vocabulary (each token separated by a space): \n{' '.join(sorted(tokens_set))}")

In [None]:
# TODO: get frequency of each and exclude infrequent

In [None]:
tokenizer = Tokenizer(tokens=tokens_set)

In [None]:
tokenized_tunes = [tokenizer.tokenize(tune) for tune in tqdm(tokenized_str_tunes)]

In [None]:
# dataset = ABCDataset(filepath=DATA_PATH)
# NOTE(review): tunes are passed here as space-joined strings, but as token
# lists for the train/valid datasets below - confirm ABCDataset accepts both.
dataset = ABCDataset(
    tunes=[" ".join(token_list) for token_list in tokenized_str_tunes],
    tokens=tokens_set,
)

In [None]:
print(dataset)

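# Exclude long tunes
All the data will be batched, and each batch is padded to the length of its longest tune, so we exclude the very longest examples (above the 99th percentile of token count) for efficiency.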

In [None]:
tune_lens = np.array([len(t) for t in tokenized_str_tunes])

In [None]:
plt.figure(figsize=(12,4))
plt.subplot(131)
sns.histplot(tune_lens)
plt.subplot(132)
sns.kdeplot(tune_lens)
plt.subplot(133)
sns.ecdfplot(tune_lens)
plt.suptitle(f"Number of tokens for all {len(tokenized_str_tunes)} tunes")
plt.tight_layout()

In [None]:
pct = .99
nr_kept = int(np.rint(len(tune_lens)*pct))
val_pct = sorted(tune_lens)[nr_kept - 1]
val_pct

In [None]:
short_tunes = [tune for tune in tokenized_str_tunes if len(tune) <= val_pct]

In [None]:
short_tune_lens = np.array([len(t) for t in short_tunes])

In [None]:
plt.figure(figsize=(12,4))
plt.subplot(131)
sns.histplot(short_tune_lens)
plt.subplot(132)
sns.kdeplot(short_tune_lens)
plt.subplot(133)
sns.ecdfplot(short_tune_lens)
plt.suptitle(f"Tunes shorter than or equal to {val_pct} tokens")
plt.tight_layout()

# Train/validation split

In [None]:
ntunes = len(short_tunes)

In [None]:
nvalid_tunes = ntunes * VALIDATION_FRACTION
# round to a multiple of batch_size
# (at least one batch. NOTE(review): this exceeds ntunes when ntunes < BATCH_SIZE,
# which makes rng.choice(..., replace=False) in the next cell fail)
nvalid_tunes = BATCH_SIZE * max(
    1,
    int(np.rint(nvalid_tunes / BATCH_SIZE))
)
nvalid_tunes

In [None]:
# Fixed seed so the validation split is reproducible.
rng = np.random.RandomState(42)
valid_idxs = rng.choice(np.arange(ntunes), nvalid_tunes, replace=False)

In [None]:
# Training indices are all indices not chosen for validation.
ntrain_tunes = ntunes - nvalid_tunes
train_idxs = np.delete(np.arange(ntunes), valid_idxs)

In [None]:
# Membership is tested against sets: `idx in ndarray` is a linear scan, which
# made these comprehensions O(n^2). Enumerating short_tunes keeps both splits
# in their original order, exactly as before.
valid_idx_set = set(valid_idxs.tolist())
train_idx_set = set(train_idxs.tolist())
valid_tunes = [tune for idx, tune in enumerate(short_tunes) if idx in valid_idx_set]
train_tunes = [tune for idx, tune in enumerate(short_tunes) if idx in train_idx_set]

In [None]:
# Both splits are built from the full dataset's token set, presumably so that
# token indices agree between them.
valid_dataset = ABCDataset(tunes=valid_tunes, tokens=dataset.tokens)
train_dataset = ABCDataset(tunes=train_tunes, tokens=dataset.tokens)

In [None]:
print(train_dataset)

In [None]:
print(valid_dataset)

In [None]:
train_dataset[0][:10]

In [None]:
train_dataset.tokenizer.untokenize(train_dataset[0])[:10]

In [None]:
valid_dataset[0][:10]

In [None]:
# The training tokenizer is used for a validation tune (same token set).
train_dataset.tokenizer.untokenize(valid_dataset[0])[:10]

# Batching in the dataloader

In [None]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Token index used to fill padded positions when batching sequences.
PAD_IDX = dataset.tokenizer.pad_token_index

def rpad_batch(batch, padding_value=None):
    """Collate variable-length token sequences by right-padding them.

    Args:
        batch: list of tensors of token indices, one per tune (assumed 1-D).
        padding_value: index written into padded positions. Defaults to the
            module-level ``PAD_IDX``, looked up at call time.

    Returns:
        ``(data, lengths)``: ``data`` is the sequence-first padded tensor
        (``batch_first=False``, so shape ``(max_len, batch_size)`` for 1-D
        inputs), and ``lengths`` is the list of unpadded lengths in batch order.
    """
    if padding_value is None:
        padding_value = PAD_IDX
    # Full lengths are returned. A previous comment claimed 1 was subtracted
    # here so </s> is never used as an input, but no subtraction was ever done;
    # any such offset presumably happens in the model - TODO confirm.
    lengths = [seq.shape[0] for seq in batch]
    data = pad_sequence(batch, batch_first=False, padding_value=padding_value)
    return data, lengths

# Training batches are shuffled and padded by rpad_batch. Loading runs in the
# main process (num_workers=0); the multi-worker setting is left commented out.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=rpad_batch,
    pin_memory=True,
#     num_workers=8,
    num_workers=0
)

In [None]:
# Inspect one batch: (padded data, lengths) as returned by rpad_batch.
for batch in train_dataloader:
    print(batch)
    print(batch[0].size())
    print(max(batch[1]))
    break

In [None]:
# Validation batches keep a fixed order (shuffle=False).
val_dataloader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=rpad_batch,
    pin_memory=True,
    num_workers=0,
)

# Model

In [None]:
import pytorch_lightning as pl

from double_jig_gen.models import SimpleRNN


In [None]:
# lightning_trainer = pl.Trainer(gpus='1,')
# NOTE(review): with no `gpus` argument this trainer presumably trains on CPU,
# regardless of the `device` chosen above - confirm for the installed version.
lightning_trainer = pl.Trainer()

In [None]:
# The embedding's padding index comes from the tokenizer (PAD_IDX) rather than
# a hard-coded 0, so it stays tied to the vocabulary the batches are padded with.
model = SimpleRNN(
    rnn_type="LSTM",
    ntoken=dataset.vocabulary_size,
    ninp=EMBEDDING_SIZE,
    nhid=RNN_SIZE,
    nlayers=NUM_LAYERS,
    model_batch_size=BATCH_SIZE,
    dropout=DROPOUT,
    embedding_padding_idx=PAD_IDX,
)

In [None]:
# NOTE(review): `train_dataloader=` is an older Lightning keyword (newer
# versions use `train_dataloaders=`) - keep in sync with the installed version.
lightning_trainer.fit(
    model,
    train_dataloader=train_dataloader,
    val_dataloaders=val_dataloader,
)

In [None]:
# NOTE(review): leftover debugging cell.
print("hi")

In [None]:
model

In [None]:
len(list(model.parameters()))

In [None]:
len(list(model.named_parameters()))

In [None]:
# Name -> parameter tensor lookup for inspecting learned weights.
param_dict = dict(model.named_parameters())

In [None]:
param_dict.keys()

In [None]:
param_dict['encoder_layer.weight'].shape

Here we show that something has been learned! The first four tokens are:
* 0: `<pad>` - padding token
* 1: `<unk>` - unknown/rare token
* 2: `<s>` - start sequence
* 3: `</s>` - end sequence

The encoder weights show that nothing is learned for `<pad>`, `<unk>`, and `</s>`: their weights stay near their initialised values, close to zero. Something is learned for `<s>`, whose weights have moved away from zero. This is as expected: nothing should follow `<pad>` or `</s>`, and there are no `<unk>` tokens in this dataset!

The decoder weights show the same pattern, except that nothing is learned for `<s>` and something is learned for `</s>`. Again, this is expected: the start-of-sequence token should never be predicted, while the end-of-sequence token must be predicted at the end of every tune.

In [None]:
# NOTE(review): add .cpu() here if the model lives on a GPU; matshow needs host data.
W_enc = param_dict['encoder_layer.weight'].detach()

In [None]:
plt.matshow(W_enc)
plt.colorbar();

In [None]:
# Zoom in on the first four rows - presumably the special tokens listed above
# (confirm that the encoder weight layout puts tokens on rows).
plt.matshow(W_enc[:4], aspect='auto', interpolation='none')
plt.colorbar();

In [None]:
W_dec = param_dict['decoder_layer.weight'].detach()

In [None]:
plt.matshow(W_dec)
plt.colorbar();

In [None]:
plt.matshow(W_dec[:4], aspect='auto', interpolation='none')
plt.colorbar();

In [None]:
# Switch to evaluation mode (e.g. disables dropout) before generating.
model.eval()

In [None]:
model.training

In [None]:
# pad_batch must be in scope before priming_loader is iterated in the next
# cell. Previously it was only imported in the later generation cell, so
# Restart & Run All failed with a NameError inside the collate function.
from double_jig_gen.data import pad_batch

tokenizer = val_dataloader.dataset.tokenizer
# token_sequences = [
#     ["<s>"],
#     ["<s>", "M:6/8"],
#     ["<s>", "M:6/8", "K:mix"],
#     [""]
# ]
# Priming prefixes for generation, one sequence per entry. "blarg" is
# presumably out of vocabulary (maps to <unk>) - TODO confirm.
token_sequences = (
    [["blarg"]] * 1 +
    [["<s>"]] * 1 +
    [["<s>", "T: My title"]] * 1 +
    [["<s>", "T: My title", "M: 6/8"]] * 1 +
    [["<s>", "T: My title", "M: 6/8", "L: 1/8"]] * 1 +
    [["<s>", "T: My title", "M: 6/8", "L: 1/8", "K: Cmaj"]] * 1
)
priming_dataset = ABCDataset(
    tunes=token_sequences,
    tokens=val_dataloader.dataset.tokens,
    wrap_tunes=False,
)
pad_token_idx = val_dataloader.dataset.tokenizer.pad_token_index


def pad_priming_batch(batch):
    """Collate priming sequences, padding with the tokenizer's pad index."""
    return pad_batch(batch, pad_token_idx)


# The whole priming set is loaded as a single batch.
priming_loader = DataLoader(
    priming_dataset,
    batch_size=len(priming_dataset),
    shuffle=False,
    num_workers=0,  # TODO: fix why we can't use workers...
    pin_memory=False,
    collate_fn=pad_priming_batch,
)

In [None]:
# Preview the priming data; the loader yields one batch holding every sequence.
for batch_item in tqdm(priming_loader):
    print(batch_item)

In [None]:
from double_jig_gen.data import pad_batch
import torch.nn.functional as F

end_token_idx = tokenizer.end_token_index
max_seq_len = 1000
# Keep the model on the same device as the priming data fed to it below;
# the earlier pl.Trainer() fit may have left it on the CPU.
model = model.to(device)
# Autoregressive sampling: each step appends one token to every sequence that
# has not yet emitted </s>. `padded_data` is (seq_len, batch) and grows by a
# row whenever some sequence needs the extra room.
with torch.no_grad():  # inference only; no autograd graph needed
    for batch_item in tqdm(priming_loader):
        padded_data, seq_lens = batch_item
        seq_lens = np.array(seq_lens)
        padded_data = padded_data.to(device)
        still_generating = np.array([True] * padded_data.shape[1])
        for _ in tqdm(range(max_seq_len), leave=False):
            next_tokens = model.generate_next_token(
                padded_data[:, still_generating],
                seq_lens[still_generating],
                topk=5,
            )
            # Append a pad row so every active sequence has room for its token.
            padded_data = F.pad(
                input=padded_data,
                pad=(0, 0, 0, 1),  # Pad bottom
                mode="constant",
                value=pad_token_idx,
            )
            padded_data[seq_lens[still_generating], still_generating] = next_tokens
            # Drop the appended row again if no sequence wrote into it. Compare
            # with the pad index F.pad filled it with, not a hard-coded 0.
            if bool((padded_data[-1] == pad_token_idx).all()):
                padded_data = padded_data[:-1]
            seq_lens[still_generating] += 1
            last_tokens = padded_data[seq_lens - 1, range(padded_data.shape[1])]
            still_generating = np.array((last_tokens != end_token_idx).tolist())
            if still_generating.sum() == 0:
                break

        generations = [tokenizer.untokenize(seq.cpu()) for seq in padded_data.T]

In [None]:
# First and last ten tokens of each generation. The tail start is clamped at 0:
# for generations shorter than ten tokens, `gen_len - 10` would be negative and
# slice from the end of the (padded) list, giving an empty or wrong tail.
[
    " ".join(gen[:10]) + " ... " + " ".join(gen[max(0, gen_len - 10):gen_len])
    for gen, gen_len in zip(generations, seq_lens)
]

In [None]:
def clean_gen(gen_list):
    """Format a generated token list as an ABC tune string.

    ``gen_list`` is expected to hold a start token, a title, a meter, a default
    note length and a key, followed by the tune body, optionally terminated by
    ``</s>``.

    Args:
        gen_list: list of token strings (e.g. from ``tokenizer.untokenize``).

    Returns:
        ABC text made of a ``T:`` title line, the meter, note-length and key
        lines, and the space-separated body tokens.

    Raises:
        ValueError: if there are fewer than five tokens, i.e. no complete header.
    """
    if len(gen_list) < 5:
        raise ValueError(
            "Expected at least 5 tokens (start, title, meter, note length, key); "
            f"got {len(gen_list)}: {gen_list}"
        )
    start_token, title, meter, note_len, key, *tune = gen_list
    # Only drop the last token when it really is the end marker: a generation
    # cut off at the length limit ends with a real music token that must stay.
    if tune and tune[-1] == "</s>":
        tune = tune[:-1]
    # Avoid "T: T: ..." when the title token already carries the field prefix.
    title = title.strip()
    if title.startswith("T:"):
        title = title[len("T:"):].strip()
    return f"T: {title}\n{meter.strip()}\n{note_len.strip()}\n{key.strip()}\n{' '.join(tune)}"

In [None]:
# Render every generation, each cut to its generated length first.
# NOTE(review): this rebinds `tune`, which earlier cells used for clean_tunes[0].
for idx, gen in enumerate(generations):
    tune_str = clean_gen(gen[:seq_lens[idx]])
    print(tune_str)
    tune = ABCTune(tune_str)
    tune.show()

# Train on oneills

In [None]:
# Tokenize the O'Neill's tunes with the folk-rnn token set so the same model
# can be trained on them below.
# NOTE(review): relative path hard-coded rather than built from DATA_HOME.
on_dataset = ABCDataset(
    filepath='data/oneills_reformat.abc',
    tokens=dataset.tokens
)

In [None]:
print(on_dataset)

In [None]:
# Count tokens mapped to the unknown index, i.e. not in the folk-rnn token set.
nr_unk_toks = 0
for idx in range(len(on_dataset)):
    nr_unk_toks += (
        np.array(on_dataset[idx]) == on_dataset.tokenizer.unk_token_index
    ).sum()
nr_unk_toks

In [None]:
on_dataloader = DataLoader(
    on_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=rpad_batch,
    pin_memory=True,
    num_workers=8,
)

In [None]:
early_stop_callback = pl.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0.00,
    patience=100,
    verbose=True,
    mode="min",
)

In [None]:
# NOTE(review): no seed is set in this notebook (e.g. pl.seed_everything), so
# deterministic=True alone may not make runs repeatable; `gpus=` and
# `early_stop_callback=` are older Lightning arguments - confirm the version.
lightning_trainer = pl.Trainer(gpus='0,', deterministic=True, early_stop_callback=early_stop_callback)

In [None]:
# NOTE(review): the same loader serves as training and validation data, so the
# early-stopping "val_loss" is measured on the training set.
lightning_trainer.fit(
    model,
    train_dataloader=on_dataloader,
    val_dataloaders=on_dataloader,
)

In [None]:
# NOTE(review): absolute, user-specific paths (matching the commented-out
# DATA_HOME in the config cell) - consider deriving them from DATA_HOME.
# NOTE(review): this uses data_v3_vocabulary.txt, while `model` was sized from
# the data_v1-derived token set - confirm the token indices agree.
trn, vld, tst = get_oneills_dataloaders(
    "/disk/scratch_fast/s0816700/data/oneills/oneills_reformat.abc",
    "/disk/scratch_fast/s0816700/data/folk-rnn/data_v3_vocabulary.txt",
    batch_size=16,
    num_workers=1,
    pin_memory=True,
)

In [None]:
# Prints every test example; consider showing only a few.
for ii in range(len(tst.dataset)):
    print(tst.dataset[ii])

In [None]:
model

In [None]:
# Evaluate the in-memory weights (ckpt_path=None) on the O'Neill's test split.
print(model)
lightning_trainer.test(
    model,
    test_dataloaders=tst,
#             ckpt_path=str(args.model_load_from_checkpoint),
    ckpt_path=None,
)