## Music Tokenization with MuseTok

This notebook provides example code for encoding a MIDI file into MuseTok codes.

### Prepare the environment

```python
!pip install -r requirements.txt
```

### Download checkpoints and unzip

1. Download the checkpoints from https://drive.google.com/file/d/1HK534lEVdHYl3HMRkKvz8CWYliXRmOq_/view?usp=sharing
2. Unzip the archive:

```python
!unzip ckpt.zip
```
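
If the `unzip` command is not available (for example, on Windows), Python's standard `zipfile` module can extract the archive instead. This is a minimal sketch that assumes the download was saved as `ckpt.zip` in the current directory. The helper name `extract_checkpoints` is hypothetical and not part of MuseTok.

```python
import zipfile
from pathlib import Path


def extract_checkpoints(archive: str = "ckpt.zip", dest: str = ".") -> list[str]:
    """Extract `archive` into `dest` (hypothetical helper) and return the member names."""
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(dest)  # creates `dest` and any sub-directories as needed
        return zf.namelist()


if Path("ckpt.zip").exists():
    names = extract_checkpoints()
    print(f"[info] extracted {len(names)} entries")
```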

### Represent the MIDI file with REMI+ events

Encoding currently supports piano music with a constant time signature of 2/4, 3/4, 4/4, 2/2, 3/8, or 6/8. If the MIDI file does not specify a time signature, 4/4 is assumed.
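
The rule above can be checked before encoding. The sketch below is a hypothetical helper, not part of MuseTok. It takes the time signatures found in a file as `(numerator, denominator)` pairs, falls back to 4/4 when none are present, and rejects files whose time signature changes or falls outside the supported set. Such pairs could be collected, for example, from the `time_signature` meta messages that the `mido` library exposes.

```python
SUPPORTED_TIME_SIGNATURES = {(2, 4), (3, 4), (4, 4), (2, 2), (3, 8), (6, 8)}
DEFAULT_TIME_SIGNATURE = (4, 4)


def resolve_time_signature(signatures):
    """Return the single time signature to encode with, or raise ValueError.

    `signatures` is a list of (numerator, denominator) pairs in file order.
    (Hypothetical helper mirroring the rule stated in the text.)
    """
    if not signatures:
        # no time signature in the file: assume 4/4
        return DEFAULT_TIME_SIGNATURE
    distinct = set(signatures)
    if len(distinct) > 1:
        raise ValueError(f"time signature is not constant: {sorted(distinct)}")
    ts = signatures[0]
    if ts not in SUPPORTED_TIME_SIGNATURES:
        raise ValueError(f"unsupported time signature {ts[0]}/{ts[1]}")
    return ts


print(resolve_time_signature([]))                # (4, 4)
print(resolve_time_signature([(6, 8), (6, 8)]))  # (6, 8)
```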

```python
from data_processing.midi2events import midi_analyzer, midi2corpus_strict, corpus2events

file = 'test/Classic_358.mid'
# parse the MIDI file and obtain its bar resolution
midi_obj, bar_resol = midi_analyzer(file)
# convert to a note corpus, removing overlapping notes
full_data = midi2corpus_strict(midi_obj, bar_resol, remove_overlap=True)
# convert the corpus into REMI+ events without velocity; `pos` has one entry per bar
pos, events = corpus2events(full_data, bar_resol, time_first=False, has_velocity=False, repeat_beat=True, remove_short=False)
print('[info] got {} bars, {} events'.format(len(pos), len(events)))
```

```
[info] Represent MIDI file test/Classic_358.mid ...
[info] time signature = 4/4
[info] got 16 bars, 999 events
```
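
For readers new to REMI-style representations, the sketch below shows the general idea of flattening notes into an event sequence: a `Bar` event at each bar boundary, a `Position` event for each distinct onset within the bar, then `Pitch` and `Duration` events for every note. The event names, the 16-steps-per-bar grid, and the `Note`/`notes_to_remi` helpers are illustrative assumptions; the actual REMI+ vocabulary produced by `corpus2events` may use different names and resolutions. Like `pos` above, the first return value has one entry per bar (here, the index of that bar's `Bar` event).

```python
from dataclasses import dataclass


@dataclass
class Note:
    start: int     # onset in ticks
    duration: int  # length in ticks
    pitch: int     # MIDI pitch number


def notes_to_remi(notes, bar_resol, n_bars, steps_per_bar=16):
    """Convert notes into a flat REMI-style event list (illustrative only).

    `bar_resol` is the number of ticks per bar; onsets and durations are
    quantized to a grid of `steps_per_bar` positions per bar.
    """
    step = bar_resol // steps_per_bar
    events = []
    bar_starts = []  # index of each bar's 'Bar' event in `events`
    notes = sorted(notes, key=lambda n: (n.start, n.pitch))
    i = 0
    for bar in range(n_bars):
        bar_starts.append(len(events))
        events.append("Bar")
        bar_end = (bar + 1) * bar_resol
        last_pos = None
        while i < len(notes) and notes[i].start < bar_end:
            note = notes[i]
            pos = (note.start - bar * bar_resol) // step
            if pos != last_pos:  # one Position event per distinct onset
                events.append(f"Position_{pos}")
                last_pos = pos
            events.append(f"Pitch_{note.pitch}")
            events.append(f"Duration_{max(1, round(note.duration / step))}")
            i += 1
    return bar_starts, events


# toy example: 4/4 at 480 ticks per beat, i.e. 1920 ticks per bar
toy_notes = [Note(0, 480, 60), Note(0, 480, 64), Note(960, 960, 67), Note(1920, 1920, 72)]
toy_bar_starts, toy_events = notes_to_remi(toy_notes, bar_resol=1920, n_bars=2)
print(toy_events)
```

The toy variables are prefixed with `toy_` so that running the cell does not overwrite `pos` and `events` from the previous step.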


### Encoding with MuseTok

```python
# load the pre-trained tokenizer
import warnings

import torch
from model.musetok import TransformerResidualVQ

warnings.simplefilter(action='ignore')  # suppress all warnings to keep the output clean

# use the GPU when available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
ckpt_path = 'ckpt/best_tokenizer/model.pt'
model = TransformerResidualVQ(
    enc_n_layer=12, enc_n_head=8, enc_d_model=512, enc_d_ff=2048,
    dec_n_layer=12, dec_n_head=8, dec_d_model=512, dec_d_ff=2048,
    d_vae_latent=128, d_embed=512, n_token=168,
    num_quantizers=16, codebook_size=2048,
    rotation_trick=True, rvq_type='SimVQ'
).to(device)
model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
model.eval()  # switch to inference mode (e.g. disables dropout)
print('[info] successfully load tokenizer')
```

```
[info] successfully load tokenizer
```
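
The `num_quantizers=16` and `codebook_size=2048` arguments configure a residual vector quantizer: each latent vector is approximated by a sum of codewords, one per quantizer level, where each level quantizes the residual left over by the previous levels. The NumPy sketch below shows plain greedy residual VQ with random codebooks. It is a generic illustration, not MuseTok's trained SimVQ quantizer (which also uses the rotation trick), and the function name and toy data are assumptions. With 16 levels of 2048 codes each, every encoded vector maps to 16 integers in `[0, 2048)`, which is consistent with the `(73, 16)` shape of the codes printed in the next step.

```python
import numpy as np


def residual_vq_encode(x, codebooks):
    """Greedy residual vector quantization (generic sketch).

    x:         (n, d) latent vectors
    codebooks: (num_quantizers, codebook_size, d)
    returns:   indices (n, num_quantizers) and the reconstruction (n, d)
    """
    residual = np.array(x, dtype=np.float64)  # working copy of the inputs
    recon = np.zeros_like(residual)
    n_levels = codebooks.shape[0]
    indices = np.empty((residual.shape[0], n_levels), dtype=np.int64)
    for level in range(n_levels):
        codebook = codebooks[level]  # (codebook_size, d)
        # squared distances via ||r||^2 - 2 r.c + ||c||^2, shape (n, codebook_size)
        dists = ((residual ** 2).sum(axis=1, keepdims=True)
                 - 2.0 * residual @ codebook.T
                 + (codebook ** 2).sum(axis=1)[None, :])
        idx = dists.argmin(axis=1)  # nearest codeword for each residual
        indices[:, level] = idx
        recon += codebook[idx]
        residual -= codebook[idx]   # the next level quantizes what is left
    return indices, recon


rng = np.random.default_rng(0)
toy_codebooks = rng.normal(size=(16, 2048, 128))  # random, untrained codebooks
toy_x = rng.normal(size=(5, 128))                 # five toy 128-d latents
toy_indices, toy_recon = residual_vq_encode(toy_x, toy_codebooks)
print(toy_indices.shape)  # (5, 16)
```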


```python
# encode music events
from encoding import MuseTokEncoder
encoder = MuseTokEncoder(model, device=device)

music_data = encoder.get_segments(events, pos)
indices, latents = encoder.encoding(music_data, return_latents=True)
print('[info] Shape of MuseTok codes: {}'.format(indices.shape))
print('[info] Shape of corresponding embeddings: {}'.format(latents.shape))
```

```
[info] Shape of MuseTok codes: (73, 16)
[info] Shape of corresponding embeddings: (73, 128)
```
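
Downstream sequence models often need a single token stream rather than a `(segments, quantizers)` grid. One common option, sketched below as an assumption rather than something MuseTok prescribes, is to give each quantizer level its own ID range (`level * codebook_size + code`) and read the grid row by row. The helper names `flatten_codes` and `unflatten_codes` are hypothetical.

```python
import numpy as np

CODEBOOK_SIZE = 2048  # codebook_size passed to TransformerResidualVQ above


def flatten_codes(indices, codebook_size=CODEBOOK_SIZE):
    """Map an (n_segments, n_quantizers) code grid to a 1-D token sequence.

    Level l uses token IDs [l * codebook_size, (l + 1) * codebook_size).
    """
    grid = np.asarray(indices, dtype=np.int64)
    offsets = np.arange(grid.shape[1], dtype=np.int64) * codebook_size
    return (grid + offsets[None, :]).reshape(-1)  # row-major: segment by segment


def unflatten_codes(tokens, n_quantizers, codebook_size=CODEBOOK_SIZE):
    """Invert flatten_codes, returning an (n_segments, n_quantizers) grid."""
    grid = np.asarray(tokens, dtype=np.int64).reshape(-1, n_quantizers)
    return grid - np.arange(n_quantizers, dtype=np.int64)[None, :] * codebook_size


toy_codes = np.array([[3, 0], [2047, 5]])  # toy grid: 2 segments, 2 levels
toy_tokens = flatten_codes(toy_codes)
print(toy_tokens.tolist())  # [3, 2048, 2047, 2053]
```

Applied to `indices` above, `flatten_codes(indices)` would yield 73 × 16 = 1168 tokens drawn from a vocabulary of 16 × 2048 = 32768 IDs, and `unflatten_codes(tokens, 16)` recovers the original grid.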
