## Symbolic Music Generation with MuseTok

### Prepare the environment

In [None]:
!pip install -r requirements.txt

### Download checkpoints and unzip

1. download checkpoints from link: https://drive.google.com/file/d/1HK534lEVdHYl3HMRkKvz8CWYliXRmOq_/view?usp=sharing

In [None]:
# 2. unzip the file
!unzip ckpt.zip

### Load Models

In [1]:
# load pre-trained tokenizer and generator
import torch
import numpy as np
from model.musetok import TransformerResidualVQ, GPT2TokenGenerator

import warnings
warnings.simplefilter(action='ignore')

device = 'cuda'
num_quantizers = 16
codebook_size = 2048

tokenizer_path = 'ckpt/best_tokenizer/model.pt'
tokenizer = TransformerResidualVQ(
    enc_n_layer=12, enc_n_head=8, enc_d_model=512, enc_d_ff=2048,
    dec_n_layer=12, dec_n_head=8, dec_d_model=512, dec_d_ff=2048,
    d_vae_latent=128, d_embed=512, n_token=168,
    num_quantizers=16, codebook_size=2048,
    rotation_trick=True, rvq_type='SimVQ'
).to(device)
tokenizer.eval()
tokenizer.load_state_dict(torch.load(tokenizer_path, map_location='cpu'))
print('[info] successfully load tokenizer')

generator_path = 'ckpt/best_generator/model.pt'
generator = GPT2TokenGenerator(
    dec_n_layer=12, dec_n_head=16, dec_d_model=1024, dec_d_ff=2048,
    d_embed=512, n_token=32771, n_bar=16
).to(device)
generator.eval()
generator.load_state_dict(torch.load(generator_path, map_location='cpu'))
print('[info] successfully load generator')


  from .autonotebook import tqdm as notebook_tqdm


[info] successfully load tokenizer
[info] successfully load generator


### Music Continuation
Generate music pieces by continuing the prompts (e.g. 4 bars) from the provided MIDI file.

In [27]:
from data_processing.midi2events import midi_analyzer, midi2corpus_strict, corpus2events
from encoding import MuseTokEncoder
from test_generation import generate_tokens, decode_tokens, word2event, TempoEvent
from remi2midi import remi2midi
from utils import numpy_to_tensor

file = 'test/Classic_358.mid'
primer_n_bar = 4

# represent MIDI file into music events
midi_obj, bar_resol = midi_analyzer(file)
full_data = midi2corpus_strict(midi_obj, bar_resol, remove_overlap=True)
pos, events = corpus2events(full_data, bar_resol, time_first=False, has_velocity=False, repeat_beat=True, remove_short=False)
print('[info] got {} bars, {} events'.format(len(pos), len(events)))

# encode music events
encoder = MuseTokEncoder(tokenizer, device=device)
music_data = encoder.get_segments(events, pos)
indices, latents = encoder.encoding(music_data, return_latents=True)
print('[info] Shape of MuseTok codes: {}'.format(indices.shape))
print('[info] Shape of corresponding embeddings: {}'.format(latents.shape))

[info] Represent MIDI file test/Classic_358.mid ...
[info] time signature = 4/4
[info] got 16 bars, 999 events
[info] Shape of MuseTok codes: (16, 16)
[info] Shape of corresponding embeddings: (16, 128)


In [28]:
# generate MuseTok codes by continuation
indices = (indices + np.arange(num_quantizers) * codebook_size).reshape(-1).tolist()
bos_token = len(range(num_quantizers * codebook_size))
prompt = [bos_token] + indices[:primer_n_bar * num_quantizers]

gen_tokens, t_sec = generate_tokens(
    generator, primer=prompt, primer_n_bar=primer_n_bar,
    max_bars=16, num_tokens=num_quantizers, codebook_size=codebook_size,
    temp=1.1, top_p=0.9, top_k=30, eos=bos_token + 1
)

# transfer codes to corresponding embeddings
num_bars = len(gen_tokens) // num_quantizers
gen_tokens = np.array(gen_tokens) - np.tile(np.arange(num_quantizers), num_bars) * codebook_size
gen_indices = numpy_to_tensor(gen_tokens, device=device).view(-1, num_quantizers).long()
gen_latents = tokenizer.residual_sim_vq.get_output_from_indices(gen_indices)

[info] generated 5 bars
[info] generated 6 bars
[info] generated 7 bars
[info] generated 8 bars
[info] generated 9 bars
[info] generated 10 bars
[info] generated 11 bars
[info] generated 12 bars
[info] generated 13 bars
[info] generated 14 bars
[info] generated 15 bars
[info] generated 16 bars
-- time elapsed: 3.18 secs


In [29]:
# decode codes back to music events
song, t_sec = decode_tokens(
    tokenizer, gen_latents,
    encoder.event2idx, encoder.idx2event, device,
    max_events=12800
)
song = word2event(song, encoder.idx2event)

out_file = 'test/continuation'
midi_obj = remi2midi(
    song, 
    out_file + '.mid', 
    enforce_tempo=True, 
    enforce_tempo_val=[TempoEvent(110, 0, 0, 4, 4)]
)

[info] generated 1 bars, #events = 50
[info] generated 2 bars, #events = 112
[info] generated 3 bars, #events = 168
[info] generated 4 bars, #events = 218
[info] generated 5 bars, #events = 274
[info] generated 6 bars, #events = 336
[info] generated 7 bars, #events = 404
[info] generated 8 bars, #events = 463
[info] generated 9 bars, #events = 546
[info] generated 10 bars, #events = 608
[info] generated 11 bars, #events = 688
[info] generated 12 bars, #events = 771
[info] generated 13 bars, #events = 836
[info] generated 14 bars, #events = 901
[info] generated 15 bars, #events = 972
[info] generated 16 bars, #events = 1025
-- generated events: 1026
-- time elapsed: 34.78 secs


## Music Generation from Scratch
Generate music pieces without prompts.

In [14]:
# generate MuseTok codes from scratch
gen_tokens, t_sec = generate_tokens(
    generator, primer=[bos_token], primer_n_bar=primer_n_bar,
    max_bars=16, num_tokens=num_quantizers, codebook_size=codebook_size,
    temp=1.1, top_p=0.9, top_k=30, eos=bos_token + 1
)

# transfer codes to corresponding embeddings
num_bars = len(gen_tokens) // num_quantizers
gen_tokens = np.array(gen_tokens) - np.tile(np.arange(num_quantizers), num_bars) * codebook_size
gen_indices = numpy_to_tensor(gen_tokens, device=device).view(-1, num_quantizers).long()
gen_latents = tokenizer.residual_sim_vq.get_output_from_indices(gen_indices)

[info] generated 1 bars
[info] generated 2 bars
[info] generated 3 bars
[info] generated 4 bars
[info] generated 5 bars
[info] generated 6 bars
[info] generated 7 bars
[info] generated 8 bars
[info] generated 9 bars
[info] generated 10 bars
[info] generated 11 bars
[info] generated 12 bars
-- time elapsed: 3.15 secs


In [15]:
# decode codes back to music events
song, t_sec = decode_tokens(
    tokenizer, gen_latents,
    encoder.event2idx, encoder.idx2event, device,
    max_events=12800
)
song = word2event(song, encoder.idx2event)

out_file = 'test/generation'
midi_obj = remi2midi(
    song, 
    out_file + '.mid', 
    enforce_tempo=True, 
    enforce_tempo_val=[TempoEvent(110, 0, 0, 4, 4)]
)

[info] generated 1 bars, #events = 8
[info] generated 2 bars, #events = 25
[info] generated 3 bars, #events = 42
[info] generated 4 bars, #events = 56
[info] generated 5 bars, #events = 70
[info] generated 6 bars, #events = 87
[info] generated 7 bars, #events = 107
[info] generated 8 bars, #events = 127
[info] generated 9 bars, #events = 135
[info] generated 10 bars, #events = 152
[info] generated 11 bars, #events = 169
[info] generated 12 bars, #events = 186
-- generated events: 187
-- time elapsed: 2.12 secs
