In [1]:
from GPTmuseVAE import GPTmuseVAE

import os
import torch
from GPTmuseVAE import GPTmuseVAE
from miditok.pytorch_data import DatasetTok
from miditok import REMI
from torchtoolkit.data import create_subsets
from pathlib import Path
from utils import *
import pygame
from pygame import mixer

  from .autonotebook import tqdm as notebook_tqdm


pygame 2.5.2 (SDL 2.28.3, Python 3.11.5)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# Model Hyperparameters
n_embd = 32
n_head = 4
n_layer = 3
z_dim = 64
block_size = 254 # what is the maximum context length for predictions?
dropout = 0.2
########################

# Hyperparameters
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# ------------

In [3]:
tokenizer = REMI(params= Path('midi_dataset_tokenizer_bpe.conf'))
vocab_size = len(tokenizer)

tokens_paths = list(Path('midi_dataset_tokens_no_bpe').glob("**/*.json"))

dataset = DatasetTok(
    tokens_paths, 
    max_seq_len=block_size, # to make target and prediction match the song length of block size
    min_seq_len=block_size, 
    one_token_stream= False,
    func_to_get_labels = get_artist_label
)

Loading data: midi_dataset_tokens_no_bpe/midi_metal/Slayer: 100%|██████████| 511/511 [00:01<00:00, 284.76it/s]


In [4]:
model = GPTmuseVAE( vocab_size= len(tokenizer),
                    n_embd = n_embd,
                    n_head = n_head,
                    n_layer = n_layer,
                    z_dim = z_dim,
                    block_size = block_size,
                    dropout = dropout)


m = model.to(device)
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')



0.187536 M parameters


In [5]:
loaded_state_dict = torch.load('checkpoints/checkpoint_3300.pt')
model.load_state_dict(loaded_state_dict['model_state_dict'])

<All keys matched successfully>

In [6]:
gen_seed = dataset[500]['input_ids'].unsqueeze(0)
print(gen_seed.shape)

torch.Size([1, 254])


In [7]:
generated_sequence = model.generate(gen_seed[:], max_new_tokens=128)
print(generated_sequence[0].shape)
out = generated_sequence[0].cpu().numpy().tolist()
print(len(out))
gen_midi = tokenizer.tokens_to_midi(out)
gen_midi.dump('musicGPT.mid')

torch.Size([1, 382])
1


In [8]:
mixer.init()
mixer.music.load("musicGPT.mid")
mixer.music.play()

In [9]:
mixer.music.stop()

# Z manipulation

In [10]:
_ , small_data = create_subsets(dataset, [0.1])

In [11]:
z, labels = process_dataset_for_z(small_data)

z.shape

torch.Size([2834, 254])

In [12]:
z = model.sample_latent(z)

In [13]:
z.shape

torch.Size([2834, 254, 64])

In [14]:
pointer_dict = calculate_feature_pointers(z,labels)

In [15]:
pointer_dict.keys()

dict_keys(['Dvorak', 'Type O Negative', 'Megadeth', 'Brahms', 'Beethoven', 'Schubert', 'Bach', 'Pantera', 'Mozart', 'Ozzy Osbourne', 'Carcass', 'Ravel', 'Slayer', 'Black_sabath', 'Faure', 'midi_pop_songs', 'Children Of Bodom', 'Judas Priest', 'Sepultura', 'Cambini', 'Haydn'])

In [41]:
pointer = pointer_dict['midi_pop_songs']
song_flag = 10000
input_block_size = block_size
magnitude = 100
pointer = pointer
print(decode_artist_label(dataset[song_flag]['labels'],get_artist_label.artist_id_mapping))

Megadeth


In [42]:
gen_seed = dataset[song_flag]['input_ids'].unsqueeze(0)
print(gen_seed.shape)
generated_sequence = model.generate(gen_seed[:input_block_size] ,max_new_tokens=128, latent_vector=pointer, magnitude= magnitude)
out = generated_sequence[0].cpu().numpy().tolist()
gen_midi = tokenizer.tokens_to_midi(out)
gen_midi.dump('musicGPT_latent.mid')
mixer.init()
mixer.music.load("musicGPT_latent.mid")
mixer.music.play()

torch.Size([1, 254])


In [43]:
mixer.music.stop()