In [None]:
import os
import argparse
import numpy as np
from tqdm import tqdm
import time
import torch
import pretty_midi

from lib import constants
from lib import midi_processing
from lib import generation
from lib.midi_processing import PIANO_RANGE
from lib.model.transformer import MusicTransformer

In [None]:
def decode_and_write(generated, primer, genre, out_dir):
    """Decode every generated sequence and save it as a ``.mid`` file.

    Files are named ``gen_<index>_<genre name>.mid`` (index zero-padded to
    two characters) and placed under ``out_dir``.

    Args:
        generated: Iterable of sequences, each accepted by
            ``midi_processing.decode``.
        primer: Unused by this function; kept so existing calls keep working.
        genre: Iterable of genre ids, paired positionally with ``generated``.
            Each id must be a key of the notebook-level ``id2genre`` mapping.
        out_dir: Directory the files are written into. It is not created here.

    Note:
        ``generated`` and ``genre`` are combined with ``zip``, so iteration
        stops at the shorter of the two.
    """
    for idx, (sequence, genre_idx) in enumerate(zip(generated, genre)):
        # presumably a pretty_midi object -- only its ``write`` method is used
        decoded = midi_processing.decode(sequence)
        genre_name = id2genre[genre_idx]
        decoded.write(f'{out_dir}/gen_{idx:>02}_{genre_name}.mid')

In [None]:
# Genre id -> genre name; ids follow the order of the names below.
id2genre = dict(enumerate(('classic', 'jazz', 'calm', 'pop')))
# Reverse lookup: genre name -> genre id.
genre2id = {name: idx for idx, name in id2genre.items()}
# Per-genre value, keyed by genre id; the `params` cell reads it as the
# sampling temperature for the selected genre.
tuned_params = {
    genre2id['classic']: 1.1,
    genre2id['jazz']: 0.95,
    genre2id['calm']: 0.9,
    genre2id['pop']: 1.0,
}

In [None]:
# User-tunable settings for this run.

# Checkpoint file whose state dict is loaded into the model.
load_path = '../checkpoints/model_big_v3_378k.pt'
# Output directory, timestamped (day-month-year_hour-minute-second) at the
# moment this cell is executed.
out_dir = 'generated_' + time.strftime('%d-%m-%Y_%H-%M-%S')
genre_to_generate = 'calm'  # Use one of ['classic', 'jazz', 'calm', 'pop']
# Number of primer rows built for one generation call.
batch_size = 8
# Use the first CUDA device when one is available; otherwise fall back to CPU
# so the notebook still runs (slowly) on machines without a GPU instead of
# failing when the model is moved to the device.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Forwarded to generation.post_process.
remove_bad_generations = True

In [None]:
# Keyword arguments unpacked into generation.generate(model, primer, **params).
# The temperature is looked up per genre from `tuned_params`; every other
# value is a fixed setting.
params = {
    'target_seq_length': 512,
    'temperature': tuned_params[genre2id[genre_to_generate]],
    'topk': 40,
    'topp': 0.99,
    'topp_temperature': 1.0,
    'at_least_k': 1,
    'use_rp': False,
    'rp_penalty': 0.05,
    'rp_restore_speed': 0.7,
    'seed': None,
}

In [None]:
# START GENERATION

os.makedirs(out_dir, exist_ok=True)
genre_id = genre2id[genre_to_generate]

# Build the model, move it to the target device, switch it to eval mode,
# then load the checkpoint weights into it.
print('loading model...')
model = MusicTransformer(
    device,
    n_layers=12,
    d_model=1024,
    dim_feedforward=2048,
    num_heads=16,
    vocab_size=constants.VOCAB_SIZE,
    rpr=True,
)
model = model.to(device).eval()
model.load_state_dict(torch.load(load_path, map_location=device))

# Genre conditioning: one primer row per batch item, each holding a single
# token. Genre ids 0..3 are shifted onto token ids
# VOCAB_SIZE - 4 .. VOCAB_SIZE - 1.
primer_genre = np.full(batch_size, genre_id)
primer = torch.tensor(primer_genre).unsqueeze(1) + (constants.VOCAB_SIZE - 4)

print('generating to:', os.path.abspath(out_dir))
generated = generation.generate(model, primer, **params)
generated = generation.post_process(
    generated,
    remove_bad_generations=remove_bad_generations,
)

# Decode the remaining sequences and write them as MIDI files into out_dir.
decode_and_write(generated, primer, primer_genre, out_dir)