In [None]:
import base64
import concurrent.futures as cf
from collections import defaultdict, deque
import copy
import io
import math
from pathlib import Path
import pickle
import time
import warnings

from confugue import Configuration
import IPython.display as ipd
import muspy
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.interpolate import make_interp_spline
import scipy.spatial.distance
import torch
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm
import yaml

from spe_music.model.music_performer import MusicPerformer
from spe_music.model import fast_transformer_decoder
from spe_music.train_performer_grv2grv import make_representation

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines where CUDA is not available.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Propagate the choice to the decoder module's module-level `device` attribute.
fast_transformer_decoder.device = DEVICE

In [None]:
# Root directory of the dataset split; its 'test' subdirectory is what gets
# loaded as the dataset further down. The path is relative to the notebook's
# working directory.
DATA_PATH = Path('../data/train_split')

In [None]:
# Open the 'test' subdirectory as a muspy dataset; its items are iterated
# later and used as prompts for generation.
dataset = muspy.MusicDataset(DATA_PATH / 'test')

In [None]:
def load_model(model_dir, max_len=2048):
    """Build a model from a model directory and load its latest checkpoint.

    Args:
        model_dir: Directory (``str`` or ``Path``) containing ``config.yaml``
            and a ``params/`` subdirectory with ``*_params.pt`` checkpoints.
        max_len: Value written to ``cfg['model']['max_len']`` before the model
            is constructed, overriding whatever the config file specifies.

    Returns:
        Tuple ``(representation, start_id, end_id, model)``. The model has been
        moved to ``DEVICE`` and switched to eval mode.

    Raises:
        FileNotFoundError: If no ``params/*_params.pt`` file exists in
            ``model_dir``.
    """
    model_dir = Path(model_dir)  # Accept plain strings as well as Path objects
    cfg = Configuration.from_yaml_file(model_dir / 'config.yaml')
    cfg['model']['max_len'] = max_len

    representation, start_id, end_id = cfg.configure(make_representation)
    model = cfg['model'].configure(MusicPerformer, n_token=len(representation.vocab))

    checkpoints = sorted(model_dir.glob('params/*_params.pt'))
    if not checkpoints:
        raise FileNotFoundError(
            f'No checkpoint matching params/*_params.pt found in {model_dir}')
    # The checkpoint that sorts last by path is used.
    # NOTE(review): this is not necessarily the newest one if the file names
    # contain step numbers that are not zero-padded -- confirm the naming scheme.
    params_path = checkpoints[-1]
    print(params_path)

    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    state_dict = torch.load(params_path, map_location=torch.device(DEVICE))
    # Discard the stored 'pe.pe' entry if present (presumably a positional
    # encoding buffer whose size depends on max_len -- TODO confirm).
    # strict=False below tolerates the resulting missing key.
    state_dict.pop('pe.pe', None)
    # Printing the result shows any missing/unexpected keys.
    print(model.load_state_dict(state_dict, strict=False))
    model.to(DEVICE)
    model.eval()
    return representation, start_id, end_id, model

# `model_dir` is otherwise only assigned by the generation loop near the end of
# the notebook, so this cell would raise NameError on a fresh kernel. Define it
# explicitly here, using the first directory from that loop's list.
# NOTE(review): confirm this is the model that was meant to be inspected here.
model_dir = Path('trio_performer_softmax_l512_v01')
representation, start_id, end_id, model = load_model(model_dir)
print(model)

In [None]:
def generate(model, prompt, max_len, temperature=0.6, pos_code=None):
    """Sample a continuation of ``prompt`` token by token.

    Uses the notebook-level ``representation``, ``start_id``, ``end_id`` and
    ``DEVICE``, which must match ``model``.

    Args:
        model: Model called as ``model(tokens, attn_kwargs=...)`` to obtain logits.
        prompt: Music object accepted by ``representation.encode``.
        max_len: Total sequence length (prompt plus continuation) to generate up to.
        temperature: Logits are divided by this value before sampling.
        pos_code: Optional positional code forwarded to the model through
            ``attn_kwargs``; ``None`` is forwarded as-is.

    Returns:
        Tuple ``(prompt_and_output, output)``: the decoded full token sequence
        and the decoded tokens that follow the prompt.
    """
    tokens = [start_id] + representation.encode(prompt).tolist()[:-1]  # Get rid of EOS
    prompt_len = len(tokens)
    # Pad with end tokens so that sampled tokens can be written in place
    tokens = tokens + [end_id] * (max_len - len(tokens))
    tokens = torch.tensor(tokens, device=DEVICE)[None, :]

    # active_notes[track][pitch] is a FIFO queue of notes that are currently on.
    # NOTE: both this bookkeeping and the time `t` below start out empty/zero,
    # i.e. events contained in the prompt are not reflected in the constraints.
    active_notes = defaultdict(lambda: defaultdict(deque))
    # Every note-off is invalid unless a matching note is currently on
    potentially_invalid_ids = set(representation.vocab[('note_off', tr, p)]
                                  for tr in range(representation.num_tracks)
                                  for p in range(128))

    with torch.no_grad():
        t = 0
        for i in range(prompt_len, max_len):
            logits = model(tokens[:, :i], attn_kwargs=dict(
                omit_feature_map_draw=True, pos_code=pos_code))
            logits /= temperature

            # Constrain to predict valid tokens
            invalid_ids = set(potentially_invalid_ids)
            invalid_ids -= set(  # Allow note-offs only for notes that are on
                representation.vocab[('note_off', tr, p)]
                for tr, d in active_notes.items()
                for p in d.keys() if d[p])
            invalid_ids.update(  # Avoid backward time shifts
                representation.vocab[('time_shift', 0, ticks)]
                for ticks in range(1, (t % representation.resolution) + 1)
            )
            # Explicit integer dtype: an empty list would otherwise produce a
            # float tensor, which cannot be used for indexing.
            invalid_ids = torch.as_tensor(sorted(invalid_ids), dtype=torch.long, device=DEVICE)
            logits.squeeze(0)[-1, invalid_ids] = -1e6

            # Sample
            dist = torch.distributions.Categorical(logits=logits[:, -1])
            tokens[:, i] = dist.sample()
            if tokens[:, i] in invalid_ids:
                print('Invalid token sampled')
            if tokens[:, i] == end_id:
                break

            # Process event
            event, *args = representation.vocab.inv[tokens[:, i].item()]
            if event == 'time_shift':
                t_new = representation.timing.decode(t, (event, *args))
                if t_new <= t:
                    print(f'Invalid time shift: {t} + {args} = {t_new}')
                else:
                    t = t_new
            elif event == 'note_on':
                track_id, pitch = args
                active_notes[track_id][pitch].append(muspy.Note(
                    time=t,
                    pitch=pitch,
                    velocity=-1,
                    duration=-1))
            elif event == 'note_off':
                track_id, pitch = args
                pending = active_notes[track_id][pitch]
                if pending:
                    # Close the oldest matching note
                    note = pending.popleft()
                    note.duration = t - note.time
                else:
                    # Nothing to close: report it and do not touch any note.
                    # (Previously the duration assignment ran regardless and
                    # either crashed or modified an unrelated, stale note.)
                    print('Invalid note-off')

    tokens = tokens.cpu().numpy().squeeze(0)

    return representation.decode(tokens), representation.decode(tokens[prompt_len:])

In [None]:
def postprocess(output, prompt):
    """Wrap generated tracks in a copy of the prompt and fill in track attributes.

    A deep copy of ``prompt`` is emptied of its tracks and then extended with
    ``output`` via ``+=``. Each resulting track takes ``program`` and
    ``is_drum`` from the prompt track at the same index. If that prompt track
    has notes, every note of the resulting track gets the prompt track's mean
    velocity (truncated to ``int``).

    Args:
        output: Generated music whose content is appended to the copy.
        prompt: Reference music; it is not modified.

    Returns:
        The populated copy of ``prompt``.
    """
    result = copy.deepcopy(prompt)
    result.tracks.clear()
    result += output

    for idx, out_track in enumerate(result.tracks):
        ref_track = prompt.tracks[idx]
        out_track.program = ref_track.program
        out_track.is_drum = ref_track.is_drum

        # Without reference notes there is no velocity to copy
        if not ref_track.notes:
            continue

        mean_velocity = int(np.mean([ref_note.velocity for ref_note in ref_track.notes]))
        for out_note in out_track.notes:
            out_note.velocity = mean_velocity

    return result

In [None]:
# Total sequence length (prompt + continuation) and number of samples per prompt
max_len = 1024
num_outputs = 3

for model_dir in ['trio_performer_softmax_l512_v01', 'trio_performer_softmax_sinespe_l512_v01', 'trio_performer_softmax_convspe_l512_v01']:
    print(model_dir)
    model_dir = Path(model_dir)
    # Rebinds the notebook-level names that generate() reads
    representation, start_id, end_id, model = load_model(model_dir)

    # All files for this model go to <model_dir>/outputs; create it once
    output_dir = model_dir / 'outputs'
    output_dir.mkdir(exist_ok=True)

    pos_code = None

    # Same seed for every model
    torch.random.manual_seed(0)

    # Pre-generate positional code for speedup
    if model.transformer_decoder._spe and model.transformer_decoder.share_pe:
        with torch.no_grad():
            pos_code = model.transformer_decoder.spe((1, max_len))

    for prompt in tqdm(dataset):
        # Prepare prompt: keep the first two tracks plus, of the remaining
        # ones, the track with the most notes
        prompt.tracks[2:] = sorted(prompt.tracks[2:], key=lambda tr: -len(tr.notes))
        prompt.tracks[:] = prompt.tracks[:3]
        style_ref = copy.deepcopy(prompt)  # Untruncated copy of the 3-track prompt
        # Keep only notes that end before 2 * 4 * resolution time steps
        # (presumably the first two 4-beat bars)
        for track in prompt.tracks:
            track.notes = [n for n in track.notes if n.end < 2 * 4 * prompt.resolution]

        # Save prompt
        prompt.write_midi(output_dir / (prompt.metadata.title + '.prompt.mid'))

        outputs = []
        for j in range(num_outputs):
            # Generate. pos_code stays None for models without shared SPE, so
            # it must not be dereferenced here; when it is set, it was created
            # under no_grad and needs no detach().
            prompt_and_output, output = generate(model, prompt, max_len=max_len, pos_code=pos_code)

            # Post-process
            prompt_and_output = postprocess(prompt_and_output, prompt)
            output = postprocess(output, prompt)

            # Save outputs
            output.write_midi(output_dir / (prompt.metadata.title + f'.cont_{j}.mid'))
            prompt_and_output.write_midi(output_dir / (prompt.metadata.title + f'.prompt_cont_{j}.mid'))
            outputs.append(output)