In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import copy
import logging
import os
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import music21
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from joblib import Parallel, delayed
# from tqdm.notebook import tqdm
from tqdm import tqdm

from double_jig_gen.data import (
    ABCDataset,
    default_pad_batch,
    fix_encoding_errors,
    get_oneills_dataloaders,
    get_folkrnn_dataloaders,
    remove_quoted_strings,
)
from double_jig_gen.tokenizers import Tokenizer, ABCTune, ABCTuneError

logging.basicConfig()
LOGGER = logging.getLogger(__name__)
LOGGER.setLevel("DEBUG")

In [None]:
! nvidia-smi

In [None]:
DEVICE_ID = 7
SCRATCH_NAME = "scratch_ssd"
DATA_HOME = f"/disk/{SCRATCH_NAME}/s0816700/data"
# DATA_HOME = "data"
DATA_PATH = f"{DATA_HOME}/folk-rnn/data_v1"

In [None]:
# ! scripts/dj-gen-get-data {DATA_HOME}

In [None]:
# https://github.com/IraKorshunova/folk-rnn/blob/master/configurations/config5.py
ONE_HOT = True
EMBEDDING_SIZE = 256  # is ignored if one_hot=True
NUM_LAYERS = 3
RNN_SIZE = 512
DROPOUT = 0.5

LEARNING_RATE = 0.003
LEARNING_RATE_DECAY_AFTER = 20
LEARNING_RATE_DECAY = 0.97

BATCH_SIZE = 64
MAX_EPOCH = 100
GRAD_CLIPPING = 5
VALIDATION_FRACTION = 0.05
VALIDATE_EVERY = 1000  # iterations

SAVE_EVERY = 10  # epochs

In [None]:
# Restrict this process to the single GPU chosen in DEVICE_ID.
# NOTE(review): this presumably has to run before anything initialises CUDA
# for the restriction to take effect — confirm if cells are ever reordered.
LOGGER.info(f"Changing to device {DEVICE_ID}")
os.environ["CUDA_VISIBLE_DEVICES"] = f"{DEVICE_ID}"
# Fall back to the CPU when no CUDA device is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LOGGER.info(f"device = {device}")

# Read FolkRNN data

In [None]:
# NR_TUNES = 1_000
NR_TUNES = 10_000

In [None]:
# Read the whole corpus file at DATA_PATH into memory as a single string.
raw_folkrnn_data = Path(DATA_PATH).read_text()

In [None]:
raw_folkrnn_data[:1000]

In [None]:
abc_data = raw_folkrnn_data.split("\n\n")[0]
abc_tune = ABCTune(
    abc_data,
    pianoroll_divisions_per_quarternote=2,
    min_pitch=None,
    min_time=None,
    transpose_to_pitchclass="C",
)
abc_tune

In [None]:
abc_tune.show()

## Cleaning

In [None]:
tqdm._instances.clear()

In [None]:
abc_data_list = raw_folkrnn_data.split("\n\n")
len(abc_data_list)

In [None]:
abc_data_list[0]

In [None]:
def clean_tune(tune_str):
    """Return ``tune_str`` after the two project cleaning passes.

    ``remove_quoted_strings`` is applied first and ``fix_encoding_errors`` is
    applied to its result.
    """
    without_quotes = remove_quoted_strings(tune_str)
    return fix_encoding_errors(without_quotes)

clean_abc_data = [clean_tune(tune_str) for tune_str in abc_data_list[:NR_TUNES]]

In [None]:
# TODO: handle cleaning fails (normally a singleton double quote in input data)
# Collect (index, raw tune) pairs whose cleaned text still contains a double
# quote, i.e. tunes the cleaning step did not fully handle.
# NOTE(review): this re-cleans every tune in abc_data_list, not just the first
# NR_TUNES kept in clean_abc_data — confirm that scanning the full corpus is intended.
cleaning_fails = [
    (ii, abc_data_list[ii])
    for ii, tune in enumerate(abc_data_list)
    if '"' in clean_tune(tune)
]
# for idx, tune in cleaning_fails:
#     print(f"tune: {idx=}")
#     print(tune)
#     print("clean_tune=")
#     print(clean_tune(tune))

## Read with Music21

In [None]:
def get_abc_tune(abc_data):
    """Build an ``ABCTune`` from ``abc_data``, or describe why that failed.

    Returns:
        The constructed ``ABCTune`` on success. If construction raises, the
        failure is logged as a warning and the warning text (a ``str``) is
        returned instead, so callers can separate successes from failures
        with ``isinstance(result, str)``.
    """
    try:
        return ABCTune(
            abc_data,
            pianoroll_divisions_per_quarternote=12,
            min_pitch=0,
            min_time=0,
            transpose_to_pitchclass="C",
        )
    except ABCTuneError as err:
        # Failure type raised by the ABCTune code itself.
        failure = f"Raised error: {repr(err)}\nNot including {abc_data}."
    except Exception as err:
        # Anything else is deliberately swallowed so one bad tune cannot
        # abort the import of the whole corpus.
        failure = f"unexpected error: {repr(err)}\nNot including {abc_data}."
    LOGGER.warning(failure)
    return failure

In [None]:
# TODO: profile this
tunes = [get_abc_tune(abc_data) for abc_data in tqdm(clean_abc_data)]

## Remove import failures

In [None]:
clean_tunes = [tune for tune in tunes if not isinstance(tune, str)]

In [None]:
len(clean_tunes)

In [None]:
clean_tunes[0]

## Import failures

In [None]:
bum_tunes = [(idx, tune) for idx, tune in enumerate(tunes) if isinstance(tune, str)]

In [None]:
len(bum_tunes)

In [None]:
", ".join([str(idx) for idx, msg in bum_tunes])

In [None]:
# for idx, msg in bum_tunes:
#     print(idx)
#     print(msg)
#     print()

# Exclude long tunes
We will need to batch all the data. Exclude very long examples for efficiency.

In [None]:
tunes_as_token_lists = [
    [tok.src for tok in tune._abc_handler.tokens]
    for tune in tqdm(clean_tunes)
]

In [None]:
idx = 0
tunes_as_token_lists[idx][:5] + ["..."] + tunes_as_token_lists[idx][-5:]

In [None]:
idx = 1
tunes_as_token_lists[idx][:5] + ["..."] + tunes_as_token_lists[idx][-5:]

In [None]:
tune_lens = np.array([len(t) for t in tunes_as_token_lists])

In [None]:
# Three views of the token-count distribution: histogram, density estimate
# and empirical CDF, side by side.
fig, (ax_hist, ax_kde, ax_ecdf) = plt.subplots(1, 3, figsize=(12, 4))
sns.histplot(tune_lens, ax=ax_hist)
sns.kdeplot(tune_lens, ax=ax_kde)
sns.ecdfplot(tune_lens, ax=ax_ecdf)
fig.suptitle(f"Number of tokens for all {len(tunes_as_token_lists)} tunes")
fig.tight_layout()

In [None]:
# Keep the shortest `pct` fraction of tunes: val_pct is the token count of the
# longest tune that is still kept.
pct = .95
nr_kept = int(np.rint(len(tune_lens) * pct))
val_pct = np.sort(tune_lens)[nr_kept - 1]
val_pct

In [None]:
not_long_tunes = [tune for tune in tunes_as_token_lists if len(tune) <= val_pct]

In [None]:
# Use a dedicated name rather than overwriting `tune_lens`: the length-cutoff
# cell reads `tune_lens` for *all* tunes, so clobbering it here would change
# the cutoff if that cell were re-run.
not_long_tune_lens = [len(t) for t in not_long_tunes]
fig, (ax_hist, ax_kde, ax_ecdf) = plt.subplots(1, 3, figsize=(12, 4))
sns.histplot(not_long_tune_lens, ax=ax_hist)
sns.kdeplot(not_long_tune_lens, ax=ax_kde)
sns.ecdfplot(not_long_tune_lens, ax=ax_ecdf)
fig.suptitle(f"Tunes shorter than or equal to {val_pct} tokens")
fig.tight_layout()

# Exclude very short tunes
They are likely errors / not representative

In [None]:
# min_tune_length = 8*6 + 4  # 8 bars with 6 tokens plus 4 starting tokens
min_tune_length = 60  # a bit harsher
print(f"{min_tune_length=}")
short_tunes = [tune for tune in not_long_tunes if len(tune) < min_tune_length]
print(f"excluding {len(short_tunes)} short tunes (length < {min_tune_length})")

In [None]:
# Show one excluded tune as an example. Prefer index 3, but fall back to the
# last excluded tune (or show nothing) when fewer than four tunes were excluded.
short_tunes[min(3, len(short_tunes) - 1)] if short_tunes else None

In [None]:
mid_length_tunes = [
    tune
    for tune in not_long_tunes
    if len(tune) >= min_tune_length
]

In [None]:
# Use a dedicated name rather than overwriting `tune_lens`: the length-cutoff
# cell reads `tune_lens` for *all* tunes, so clobbering it here would change
# the cutoff if that cell were re-run.
mid_length_tune_lens = [len(tune) for tune in mid_length_tunes]
fig, (ax_hist, ax_kde, ax_ecdf) = plt.subplots(1, 3, figsize=(12, 4))
sns.histplot(mid_length_tune_lens, ax=ax_hist)
sns.kdeplot(mid_length_tune_lens, ax=ax_kde)
sns.ecdfplot(mid_length_tune_lens, ax=ax_ecdf)
fig.suptitle(f"Tunes longer than or equal to {min_tune_length} tokens")
fig.tight_layout()

# Replace titles!

In [None]:
# Collapse every title token (anything starting with "T:") into one shared
# placeholder token.
# NOTE(review): `tunes` previously held the parsed ABCTune objects; from here on
# it holds token lists, so earlier cells that read `tunes` cannot be re-run
# after this one.
title_token = "T: Title"
tunes = [
    [title_token if tok.startswith("T:") else tok for tok in tune]
    for tune in mid_length_tunes
]

# Train valid split

In [None]:
ntunes = len(tunes)

In [None]:
# Desired validation-set size before rounding (generally fractional).
nvalid_target = ntunes * VALIDATION_FRACTION
# Round to a whole number of batches, never fewer than one full batch.
nvalid_tunes = BATCH_SIZE * max(1, int(np.rint(nvalid_target / BATCH_SIZE)))
# Rounding up to one batch can exceed the data available when few tunes are
# loaded; fail here with a clear message rather than later while sampling.
assert nvalid_tunes < ntunes, (
    f"validation set ({nvalid_tunes}) would use all {ntunes} tunes; "
    "lower BATCH_SIZE or load more tunes"
)
nvalid_tunes

In [None]:
# Fixed seed so the train/validation split is the same on every run.
rng = np.random.RandomState(42)
# Draw the validation indices without replacement so no tune is chosen twice.
valid_idxs = rng.choice(np.arange(ntunes), nvalid_tunes, replace=False)

In [None]:
ntrain_tunes = ntunes - nvalid_tunes
# Training indices are every index that was not drawn for validation.
train_idxs = np.delete(np.arange(ntunes), valid_idxs)

In [None]:
# Convert the index arrays to sets first: `idx in <ndarray>` scans the whole
# array for every tune, which makes the split quadratic in the number of tunes.
valid_idx_set = set(valid_idxs.tolist())
train_idx_set = set(train_idxs.tolist())
valid_tunes = [tune for idx, tune in enumerate(tunes) if idx in valid_idx_set]
train_tunes = [tune for idx, tune in enumerate(tunes) if idx in train_idx_set]
# Every tune must land in exactly one of the two splits.
assert len(valid_tunes) + len(train_tunes) == len(tunes)

# Illustrate tokenizer

In [None]:
# TODO: handle ornaments and weird metadata
# The vocabulary is built from the training split only, so tokens that occur
# solely in validation tunes are not part of it.
tokens_set = {tok for tune in train_tunes for tok in tune}
vocab_size = len(tokens_set)
print(f"vocabulary size: {vocab_size}")
vocab_listing = ' '.join(sorted(tokens_set))
print(f"vocabulary (each token separated by a space): \n{vocab_listing}")

In [None]:
valid_dataset = ABCDataset(tunes=valid_tunes, tokens=tokens_set)
train_dataset = ABCDataset(tunes=train_tunes, tokens=tokens_set)

In [None]:
print(train_dataset)

In [None]:
print(valid_dataset)

In [None]:
# TODO: get frequency of each and exclude infrequent
# TODO: handle chords!
# TODO: handle timings (e.g. a number after a note)

In [None]:
# tokenizer = Tokenizer(tokens=tokens_set)
tokenizer = train_dataset.tokenizer

In [None]:
tokenized_tunes = [tokenizer.tokenize(tune) for tune in tqdm(tunes_as_token_lists)]

In [None]:
idx = 1
tunes_as_token_lists[idx][:5] + ["..."] + tunes_as_token_lists[idx][-5:]

In [None]:
tokenized_tunes[idx][:5] + ["..."] + tokenized_tunes[idx][-5:]

In [None]:
(
    tokenizer.untokenize(tokenized_tunes[idx][:5]) +
    ["..."] + 
    tokenizer.untokenize(tokenized_tunes[idx][-5:])
)

# Batching in the dataloader

In [None]:
from torch.utils.data import DataLoader


# Shuffled training batches of BATCH_SIZE tunes.
# NOTE(review): default_pad_batch presumably pads the tunes in a batch to a
# common length; batches are unpacked elsewhere as (padded_data, seq_lens) — confirm.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=default_pad_batch,
    pin_memory=True,
    num_workers=4,
)

In [None]:
for batch in train_dataloader:
    print(batch)
    print(batch[0].size())
    print(max(batch[1]))
    break

In [None]:
# Validation batches use the same settings as training, except that the order
# is left unshuffled.
val_dataloader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=default_pad_batch,
    pin_memory=True,
    num_workers=4,
)

# Model

In [None]:
import pytorch_lightning as pl

from double_jig_gen.models import SimpleRNN


In [None]:
# Train for many more epochs when a GPU is available than on the CPU fallback.
# NOTE(review): the MAX_EPOCH constant from the hyper-parameter cell is not
# used here — confirm which value is intended.
max_epochs = 200 if device.type == "cuda" else 10

In [None]:
# Only request a GPU when one was actually found; `device` and `max_epochs`
# already fall back to the CPU, and the trainer should follow suit rather than
# unconditionally asking for one GPU.
lightning_trainer = pl.Trainer(
    max_epochs=max_epochs,
    gpus=1 if device.type == "cuda" else None,
)

In [None]:
# LSTM language model over the training vocabulary, sized by the
# hyper-parameter constants defined near the top of the notebook.
# NOTE(review): the hyper-parameter cell says EMBEDDING_SIZE is ignored when
# one_hot is True, but ONE_HOT is never passed to the model here — confirm
# whether an embedding layer of size EMBEDDING_SIZE is what is wanted.
model = SimpleRNN(
    rnn_type="LSTM",
    ntoken=train_dataset.vocabulary_size,
    ninp=EMBEDDING_SIZE,
    nhid=RNN_SIZE,
    nlayers=NUM_LAYERS,
    model_batch_size=BATCH_SIZE,
    dropout=DROPOUT,
    # 0 is the pad token index (the generation cell asserts pad_token_idx == 0).
    embedding_padding_idx=0,
)

In [None]:
lightning_trainer.fit(
    model,
    train_dataloader=train_dataloader,
    val_dataloaders=val_dataloader,
)

In [None]:
model

In [None]:
len(list(model.parameters()))

In [None]:
len(list(model.named_parameters()))

In [None]:
param_dict = dict(model.named_parameters())

In [None]:
param_dict.keys()

In [None]:
param_dict['encoder_layer.weight'].shape

Here we show that something has been learned! The first four tokens are:
* 0: `<pad>` - padding token
* 1: `<unk>` - unknown/rare token
* 2: `<s>` - start sequence
* 3: `</s>` - end sequence

The encoder weights show that nothing is learned for `<pad>`, ~`<unk>`~, and `</s>`: their weights stay close to their near-zero initial values. Something is learned for `<s>`, whose weights are clearly non-zero. This is as expected because nothing should follow the pad and end-of-sequence tokens~, and there are no unk tokens in this dataset~! ~All pieces start with an `<unk>` token - the title, so this will normally predict a time signature next.~

The decoder weights show the same, except nothing is learned for `<s>`, and something for `</s>`. Again, this is expected since the start sequence token should never be predicted, and the end sequence token should be predicted a lot.

In [None]:
W_enc = param_dict['encoder_layer.weight'].detach().cpu()

In [None]:
plt.matshow(W_enc)
plt.colorbar();

In [None]:
plt.matshow(W_enc[:4], aspect='auto', interpolation='none')
plt.colorbar();

In [None]:
W_dec = param_dict['decoder_layer.weight'].detach().cpu()

In [None]:
plt.matshow(W_dec)
plt.colorbar();

In [None]:
plt.matshow(W_dec[:4], aspect='auto', interpolation='none')
plt.colorbar();

In [None]:
model.eval()

In [None]:
model.training

In [None]:
tokenizer = val_dataloader.dataset.tokenizer
nr_of_each = 5
# Priming sequences: progressively longer prefixes of a full tune header, plus
# one sequence made of the single token "blarg".
# NOTE(review): "blarg" is presumably absent from the vocabulary and exercises
# the unknown-token path — confirm.
full_prime = ["<s>", title_token, "M: 6/8", "L: 1/8", "K: Cmaj"]
prime_prefixes = [["blarg"]] + [
    full_prime[:end] for end in range(1, len(full_prime) + 1)
]
# Build a fresh list for every repetition so that no two dataset items alias
# the same list object (which `[[...]] * n` would do).
token_sequences = [
    list(prefix)
    for prefix in prime_prefixes
    for _ in range(nr_of_each)
]
priming_dataset = ABCDataset(
    tunes=token_sequences,
    tokens=val_dataloader.dataset.tokens,
    wrap_tunes=False,
)
pad_token_idx = val_dataloader.dataset.tokenizer.pad_token_index
# A single batch holding every priming sequence, in the order built above.
priming_loader = DataLoader(
    priming_dataset,
    batch_size=len(priming_dataset),
    shuffle=False,
    num_workers=0,
    pin_memory=False,
    collate_fn=default_pad_batch,
)

In [None]:
for batch_item in tqdm(priming_loader, leave=True, desc="batch item"):
    padded_data, seq_lens = batch_item
    print(type(padded_data), type(seq_lens))

In [None]:
tqdm._instances.clear()

In [None]:
model.device

In [None]:
model.to(device)

In [None]:
model.device

In [None]:
import torch.nn.functional as F

end_token_idx = val_dataloader.dataset.tokenizer.end_token_index
pad_token_idx = val_dataloader.dataset.tokenizer.pad_token_index
assert pad_token_idx == 0
# Hard cap on generation steps for sequences that never emit the end token.
max_seq_len = 1000

# TODO: this seems ridiculous... simplify the generation loop.
# Inference only: no_grad() stops autograd from tracking every forward pass.
with torch.no_grad():
    for batch_item in tqdm(priming_loader, leave=True, desc="batch item"):
        padded_data, seq_lens = batch_item
        seq_lens = np.array(seq_lens)
        padded_data = padded_data.to(device)
        # padded_data is indexed as [position, sequence].
        nr_seqs = padded_data.shape[1]
        still_generating = np.ones(nr_seqs, dtype=bool)
        for _ in tqdm(range(max_seq_len), leave=False, desc="seq position"):
            # Only sequences that have not yet produced the end token are extended.
            next_tokens = model.generate_next_token(
                padded_data[:, still_generating],
                seq_lens[still_generating],
                topk=5,
            )
            # Add one row of padding so there is room for a token after the
            # longest sequence.
            padded_data = F.pad(
                input=padded_data,
                pad=(0, 0, 0, 1),  # Pad bottom
                mode="constant",
                value=pad_token_idx,
            )
            # Write each new token at the first free position of its sequence.
            padded_data[seq_lens[still_generating], still_generating] = next_tokens
            # Drop the extra row again if no sequence needed it. The check is a
            # single tensor reduction rather than a Python loop over elements.
            if bool((padded_data[-1] == pad_token_idx).all()):
                padded_data = padded_data[:-1]
            seq_lens[still_generating] += 1
            last_tokens = padded_data[seq_lens - 1, range(padded_data.shape[1])]
            still_generating = np.array((last_tokens != end_token_idx).tolist())
            if still_generating.sum() == 0:
                break

# priming_loader yields a single batch, so padded_data holds every generation.
generations = [tokenizer.untokenize(seq.cpu()) for seq in padded_data.T]

In [None]:
[
#     " ".join(gen[:10]) + " ... " + " ".join(gen[(gen.index("</s>")-10):gen.index("</s>") + 1]) 
    " ".join(gen[:10]) + " ... " + " ".join(gen[(gen_len-10):gen_len]) 
    for gen, gen_len in zip(generations, seq_lens)
]

In [None]:
def clean_gen(gen_list):
    """Format a generated token sequence as ABC text.

    The first and last tokens (the start and end markers) are discarded. The
    next four tokens (title, meter, note length, key) each go on their own
    line, followed by one line holding the remaining tokens joined by spaces.

    Raises:
        ValueError: if ``gen_list`` has fewer than six tokens.
    """
    _start, title, meter, note_len, key, *body, _end = gen_list
    header = "\n".join(f"{field}" for field in (title, meter, note_len, key))
    return f"{header}\n{' '.join(body)}"

In [None]:
idx=7
trunc_generation = generations[idx][:seq_lens[idx]]
print(trunc_generation)
print(clean_gen(trunc_generation))

In [None]:
# https://www.abcjs.net/abcjs-editor.html
# Print every generation as ABC text and report which ones fail to parse.
for gen, gen_len in zip(generations, seq_lens):
    try:
        tune_str = clean_gen(gen[:gen_len])
    except ValueError:
        # clean_gen needs start, four header tokens and end; a generation that
        # stopped earlier must not abort the loop for the remaining ones.
        LOGGER.warning("Generation too short to format")
        print("TOO SHORT TO FORMAT")
        print()
        continue
    print(tune_str)
    try:
        # Constructed only to check that the tune parses; the result is unused.
        ABCTune(tune_str)
    except ABCTuneError:
        LOGGER.warning("Tune does not compile")
        print("DOES NOT COMPILE")
    print()

# Train on oneills

In [None]:
# on_dataset = ABCDataset(
#     filepath='data/oneills_reformat.abc',
#     tokens=dataset.tokens
# )

In [None]:
# print(on_dataset)

In [None]:
# nr_unk_toks = 0
# for idx in range(len(on_dataset)):
#     nr_unk_toks += (
#         np.array(on_dataset[idx]) == on_dataset.tokenizer.unk_token_index
#     ).sum()
# nr_unk_toks

In [None]:
# on_dataloader = DataLoader(
#     on_dataset,
#     batch_size=BATCH_SIZE,
#     shuffle=True,
#     collate_fn=rpad_batch,
#     pin_memory=True,
#     num_workers=8,
# )

In [None]:
# early_stop_callback = pl.callbacks.EarlyStopping(
#     monitor="val_loss",
#     min_delta=0.00,
#     patience=100,
#     verbose=True,
#     mode="min",
# )

In [None]:
# lightning_trainer = pl.Trainer(gpus='0,', deterministic=True, early_stop_callback=early_stop_callback)

In [None]:
# lightning_trainer.fit(
#     model,
#     train_dataloader=on_dataloader,
#     val_dataloaders=on_dataloader,
# )

In [None]:
# trn, vld, tst = get_oneills_dataloaders(
#     "/disk/scratch_fast/s0816700/data/oneills/oneills_reformat.abc",
#     "/disk/scratch_fast/s0816700/data/folk-rnn/data_v3_vocabulary.txt",
#     batch_size=16,
#     num_workers=1,
#     pin_memory=True,
# )

In [None]:
# for ii in range(len(tst.dataset)):
#     print(tst.dataset[ii])

In [None]:
# model

In [None]:
# print(model)
# lightning_trainer.test(
#     model,
#     test_dataloaders=tst,
# #             ckpt_path=str(args.model_load_from_checkpoint),
#     ckpt_path=None,
# )