In [None]:
# Reload edited modules (e.g. double_jig_gen) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [1]:
import logging
import os
import re
from argparse import Namespace
from pathlib import Path
from typing import Dict, Type, Union

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import yaml
from torch.utils.data import DataLoader
from tqdm import tqdm

from double_jig_gen.data import ABCDataset, get_folkrnn_dataloaders, pad_batch
from double_jig_gen.tokenizers import Tokenizer
from double_jig_gen.models import SimpleRNN, Transformer
from double_jig_gen.utils import get_model_from_checkpoint

LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.DEBUG)
# Give the root logger a default stderr handler so LOGGER's records are shown.
logging.basicConfig()

In [2]:
def _get_most_recent_path(paths):
    """Returns the path with the newest ``os.path.getctime`` timestamp.

    Note: ``getctime`` is the creation time only on Windows; on Unix it is the
    time of the last metadata change of the file.

    Args:
        paths: an iterable of paths (``str`` or ``os.PathLike``) to check.

    Returns:
        the element of ``paths`` with the largest ctime. On ties, the first
        such element is returned.

    Raises:
        ValueError: if ``paths`` is empty (e.g. no checkpoints were found).
        OSError: if a path does not exist or cannot be stat-ed.
    """
    # Materialise so generators (e.g. Path.iterdir()) can be checked for emptiness.
    paths = list(paths)
    if not paths:
        raise ValueError("No paths given; cannot select the most recent one.")
    return max(paths, key=os.path.getctime)

In [3]:
expt_id = 54
# Alternative log location (cluster scratch disk):
#   /disk/scratch_fast/s0816700/logs/lightning_logs/version_{expt_id}
expt_dirpath = Path("lightning_logs") / f"version_{expt_id}"
checkpoint_dirpath = expt_dirpath / "checkpoints"
ckpt_paths = [p for p in checkpoint_dirpath.iterdir() if p.name.endswith(".ckpt")]
latest_ckpt_path = _get_most_recent_path(ckpt_paths)

experiment_args_path = expt_dirpath / "experiment_args.yaml"
# experiment_args.yaml contains the tag
#   python/name:pytorch_lightning.trainer.trainer._gpus_arg_default
# which SafeLoader rejects, so it is read with BaseLoader instead
# (pl.core.saving.load_hparams_from_yaml fails for the same reason).
# BaseLoader builds only dicts, lists and strings, so every value is a str.
with experiment_args_path.open('r') as fh:
    args_dict = yaml.load(fh, Loader=yaml.BaseLoader)
args_dict['model_load_from_checkpoint'] = latest_ckpt_path
args = Namespace(**args_dict)
args

FileNotFoundError: [Errno 2] No such file or directory: 'lightning_logs/version_54/experiment_args.yaml'

In [None]:
# Absolute checkpoint path ("~" expanded, symlinks resolved).
ckpt_path = Path(args.model_load_from_checkpoint).expanduser().resolve()

In [None]:
ckpt_path

In [None]:
# Model-type name (the experiment's `model` argument) -> class used to restore it.
MODELS: Dict[str, Union[Type[SimpleRNN], Type[Transformer]]] = dict(
    rnn=SimpleRNN,
    transformer=Transformer,
)

In [None]:
# args.model (a str read from experiment_args.yaml) picks the model class;
# an unknown name raises KeyError.
ModelClass = MODELS[args.model]

In [None]:
# Presumably instantiates ModelClass and restores the weights stored in the
# latest checkpoint found above — see double_jig_gen.utils to confirm.
model = get_model_from_checkpoint(ckpt_path, ModelClass)

In [None]:
model

In [None]:
# Number of parameter tensors (not the number of individual weights).
len(list(model.parameters()))

In [None]:
# Same count via named_parameters; should match the cell above.
len(list(model.named_parameters()))

In [None]:
# Parameter name -> tensor, used to pull out individual weight matrices.
param_dict = dict(model.named_parameters())

In [None]:
param_dict.keys()

In [None]:
# NOTE(review): presumably (vocab_size, embedding_dim), one row per token — confirm
# against the model definition.
param_dict['encoder_layer.weight'].shape

Here we show that something has been learned! The first four tokens are:
* 0: `<pad>` - padding token
* 1: `<unk>` - unknown/rare token
* 2: `<s>` - start sequence
* 3: `</s>` - end sequence

The encoder weights show that nothing is learned for `<pad>`, `<unk>`, and `</s>`, as their weights remain at their near-zero initialised values; something is learned for `<s>`, whose weights have clearly moved away from zero. This is as expected: nothing should ever follow `<pad>` or `</s>` in a sequence, and there are no `<unk>` tokens in this dataset!

The decoder weights show the same pattern, except that nothing is learned for `<s>` while something is learned for `</s>`. Again, this is expected: the start-of-sequence token should never be predicted, whereas the end-of-sequence token must be predicted at the end of every tune.

In [None]:
# Detach from autograd and copy to host memory so matplotlib can turn it into a
# numpy array even when the model lives on the GPU (e.g. when this cell is
# re-run after `model.to(device)`).
W_enc = param_dict['encoder_layer.weight'].detach().cpu()

In [None]:
# Full encoder weight matrix.
plt.matshow(W_enc)
plt.colorbar();

In [None]:
# First four rows only — presumably one row per token, so these are the special
# tokens <pad>, <unk>, <s>, </s> described in the markdown.
plt.matshow(W_enc[:4], aspect='auto', interpolation='none')
plt.colorbar();

In [None]:
# Detach from autograd and copy to host memory so matplotlib can turn it into a
# numpy array even when the model lives on the GPU (e.g. when this cell is
# re-run after `model.to(device)`).
W_dec = param_dict['decoder_layer.weight'].detach().cpu()

In [None]:
# Full decoder weight matrix.
plt.matshow(W_dec)
plt.colorbar();

In [None]:
# First four rows only — presumably one row per output token, so these are the
# special tokens <pad>, <unk>, <s>, </s> described in the markdown.
plt.matshow(W_dec[:4], aspect='auto', interpolation='none')
plt.colorbar();

# Producing generations

In [None]:
DEVICE_ID = 7
# Pick the GPU by index instead of setting CUDA_VISIBLE_DEVICES here. That
# variable is only read when CUDA is first initialised in the process, and by
# the time this cell runs torch has been imported and the checkpoint loaded.
# Changing it now may therefore silently have no effect.
if torch.cuda.is_available() and DEVICE_ID < torch.cuda.device_count():
    device = torch.device(f"cuda:{DEVICE_ID}")
    # Make bare "cuda" resolve to the same GPU too, in case model code creates
    # tensors on the default CUDA device.
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")
LOGGER.info(f"Using device {device}")

In [None]:
# Pinned host memory only helps host-to-GPU copies, so enable it only on CUDA.
# Assumes get_folkrnn_dataloaders forwards pin_memory to torch's DataLoader.
dataloaders = get_folkrnn_dataloaders(
    args.folkrnn_data_path,
    batch_size=64,
    num_workers=2,
    pin_memory=device.type == "cuda",
)

In [None]:
# NOTE(review): assumes index 1 of the returned collection is the validation
# loader — confirm against get_folkrnn_dataloaders.
valid_dataloader = dataloaders[1]

In [None]:
print(valid_dataloader.dataset)

In [None]:
# Sanity check: train()/eval() toggle the module's `training` flag.
model.train()

In [None]:
model.training

In [None]:
# Leave the model in eval mode for generation (any dropout layers are inactive).
model.eval()

In [None]:
model.training

In [None]:
# Move the parameters to the device selected above before generating.
model = model.to(device)

In [None]:
tokenizer = valid_dataloader.dataset.tokenizer
# Priming prefixes and how many tunes to generate from each: unconditioned
# tunes, tunes primed with a 6/8 meter, and tunes primed with 6/8 plus `K:mix`.
priming_prefixes = [
    (["<s>"], 20),
    (["<s>", "M:6/8"], 5),
    (["<s>", "M:6/8", "K:mix"], 5),
]
# Build a fresh list per tune. `[[...]] * n` would make all n entries alias one
# list object, so any in-place change (e.g. by ABCDataset) would affect every tune.
token_sequences = [
    list(prefix) for prefix, count in priming_prefixes for _ in range(count)
]
priming_dataset = ABCDataset(
    tunes=token_sequences,
    tokens=valid_dataloader.dataset.tokens,
    wrap_tunes=False,
)

In [None]:
pad_token_idx = valid_dataloader.dataset.tokenizer.pad_token_index


def pad_priming_batch(batch):
    """DataLoader collate function: delegates to ``pad_batch`` with the vocabulary's pad index."""
    return pad_batch(batch, pad_token_idx)

In [None]:
# One batch containing every priming sequence, in dataset order.
priming_loader = DataLoader(
    dataset=priming_dataset,
    batch_size=len(priming_dataset),
    collate_fn=pad_priming_batch,
    shuffle=False,
    pin_memory=False,
    num_workers=4,
)

In [None]:
end_token_idx = tokenizer.end_token_index
max_seq_len = 1000  # maximum number of generation steps per batch
all_generations = []
all_seq_lens = []
# Inference only: disable autograd so the step-by-step loop does not build up a
# computation graph (harmless if generate_next_token already does this).
with torch.no_grad():
    for padded_data, seq_lens in priming_loader:
        # padded_data is indexed (position, sequence) below, i.e. time-major.
        seq_lens = np.array(seq_lens)
        padded_data = padded_data.to(device)
        n_seqs = padded_data.shape[1]
        still_generating = np.ones(n_seqs, dtype=bool)
        for _ in range(max_seq_len):
            next_tokens = model.generate_next_token(
                padded_data[:, still_generating],
                seq_lens[still_generating],
                topk=5,
            )
            # Append one padding row at the bottom, then write each unfinished
            # sequence's new token just past its current end.
            padded_data = F.pad(
                input=padded_data,
                pad=(0, 0, 0, 1),
                mode="constant",
                value=pad_token_idx,
            )
            padded_data[seq_lens[still_generating], still_generating] = next_tokens
            seq_lens[still_generating] += 1
            # Drop trailing rows that no sequence reaches. Using the lengths (rather
            # than testing for an all-pad row) avoids hard-coding the pad index and
            # cannot discard a token the model itself emitted as pad.
            padded_data = padded_data[: int(seq_lens.max())]
            last_tokens = padded_data[seq_lens - 1, np.arange(n_seqs)]
            still_generating = (last_tokens != end_token_idx).cpu().numpy()
            if not still_generating.any():
                break
        # Accumulate across batches so every batch's generations are kept.
        all_generations.extend(tokenizer.untokenize(seq.cpu()) for seq in padded_data.T)
        all_seq_lens.append(seq_lens)

# `generations[i]` is the token list for priming sequence i; `seq_lens[i]` is its length.
generations = all_generations
seq_lens = np.concatenate(all_seq_lens) if all_seq_lens else np.array([], dtype=int)

In [None]:
def clean_gen_str(gen_str, start_token="<s>", end_token="</s>"):
    """Converts a space-joined generated token string into folkrnn-style text.

    The expected layout is ``<s> METER KEY TUNE... </s>``. The start/end tokens
    are removed only when present: a generation that hit the step limit has no
    trailing end token, and slicing a fixed number of characters would cut real
    tune content. The first two remaining fields are taken as meter and key. In
    the tune, tokens starting with a digit or ``/digit`` (presumably note-length
    modifiers) are glued to the preceding token: ``"A 2"`` -> ``"A2"``,
    ``"B /2"`` -> ``"B/2"``.

    Args:
        gen_str: generated tokens joined with single spaces.
        start_token: start-of-sequence marker stripped from the front if present.
        end_token: end-of-sequence marker stripped from the back if present.

    Returns:
        ``"METER\\nKEY\\nTUNE"``.

    Raises:
        ValueError: if fewer than three fields (meter, key, tune) remain.
    """
    body = gen_str.strip()
    if start_token and body.startswith(start_token):
        body = body[len(start_token):]
    if end_token and body.endswith(end_token):
        body = body[: -len(end_token)]
    fields = body.split(maxsplit=2)
    if len(fields) < 3:
        raise ValueError(
            f"Expected meter, key and tune in generation, got: {gen_str!r}"
        )
    meter, key, tune = fields
    tune = re.sub(r" (/?[0-9])", r"\1", tune)
    return f"{meter.strip()}\n{key.strip()}\n{tune.strip()}"

In [None]:
for idx, gen in enumerate(generations):
    # Keep only the tokens actually generated for this tune (drop the padding
    # beyond its length), then print it followed by a blank line.
    print(clean_gen_str(' '.join(gen[:seq_lens[idx]])), end="\n\n")

In [None]:
# Vocabulary index of `</s>`, used above to detect finished generations.
tokenizer.end_token_index

# Generate 10,000 tunes

In [None]:
N_TUNES = 10000
# One independent priming list per tune. `[["<s>"]] * N_TUNES` would make every
# entry alias the same list object, so any in-place change (e.g. by ABCDataset)
# to one tune would show up in all of them.
token_sequences = [["<s>"] for _ in range(N_TUNES)]
priming_dataset = ABCDataset(
    tunes=token_sequences,
    tokens=valid_dataloader.dataset.tokens,
    wrap_tunes=False,
)
priming_loader = DataLoader(
    priming_dataset,
    batch_size=256,
    shuffle=False,
    num_workers=8,
    pin_memory=False,
    collate_fn=pad_priming_batch,
)

In [None]:
all_gens = []
all_lens = []
# Inference only: disable autograd so thousands of generation steps do not keep
# computation graphs alive (harmless if generate_tunes already does this).
with torch.no_grad():
    for padded_data, seq_lens in tqdm(priming_loader, desc="generating"):
        padded_data = padded_data.to(device)
        generations, gen_seq_lens = model.generate_tunes(
            padded_data,
            seq_lens,
            max_nr_generation_steps=1000,
            tokenizer=tokenizer,
        )
        all_gens.extend(generations)
        all_lens.extend(gen_seq_lens)

In [None]:
# Presumably one generation per priming sequence, i.e. should equal N tunes requested.
len(all_gens)

In [None]:
out_dir = Path('data') / 'output'
# One sub-directory per output format; exist_ok makes the cell safe to re-run.
(out_dir / 'folkrnn').mkdir(parents=True, exist_ok=True)
(out_dir / 'abc').mkdir(exist_ok=True, parents=True)

In [None]:
n_gens = len(all_gens)
for idx, (gen, gen_len) in tqdm(enumerate(zip(all_gens, all_lens)), total=n_gens):
    filename = f"tune_{idx+1:05d}"
    tokens = list(gen[:gen_len])
    # Strip the start/end markers only when present: a generation that hit the
    # step limit has no trailing "</s>", and chopping a fixed number of
    # characters off the joined string would cut real tune content.
    if tokens and tokens[0] == "<s>":
        tokens = tokens[1:]
    if tokens and tokens[-1] == "</s>":
        tokens = tokens[:-1]
    if len(tokens) < 3:
        # Skip (rather than abort the whole export) tunes lacking meter/key/body;
        # file numbering still follows the generation index.
        LOGGER.warning(
            "Skipping %s: expected meter, key and tune, got %r", filename, tokens
        )
        continue
    meter, key = tokens[0], tokens[1]
    tune = ' '.join(tokens[2:])

    folkrnn_outpath = Path(out_dir, 'folkrnn', filename)
    with open(folkrnn_outpath, 'w', encoding='utf-8') as fh:
        fh.write(f"{meter.strip()}\n{key.strip()}\n{tune.strip()}")

    # ABC output: glue tokens starting with a digit or "/digit" onto the previous
    # token, and insert "C" after the "K:" prefix (e.g. "K:mix" -> "K:Cmix").
    # NOTE(review): presumably because folkrnn tunes are normalised to C — confirm.
    tune = re.sub(r" (/?[0-9])", r"\1", tune)
    key = key[:2] + "C" + key[2:]
    abc_outpath = Path(out_dir, 'abc', f"{filename}.abc")
    with open(abc_outpath, 'w', encoding='utf-8') as fh:
        fh.write(f"X:{idx+1}\n{meter.strip()}\n{key.strip()}\n{tune.strip()}")