In [None]:
import torch
import numpy as np
import os
from torch.utils.data import Subset
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader
from torch_geometric.data import Batch
import random
from matplotlib import pyplot as plt
from utils import set_seed

from data import MIDIDataset, graph_from_tensor, graph_from_tensor_torch
from model import VAE
from utils import plot_struct, dense_from_sparse, dense_from_sparse_torch, muspy_from_dense, muspy_from_dense_torch
from utils import plot_pianoroll, midi_from_muspy
from train import VAETrainer

Set global seed:

In [None]:
# Seed forwarded to utils.set_seed. The helper's implementation is not visible
# in this notebook; presumably it seeds the Python/NumPy/torch RNGs so that the
# latents sampled further down are reproducible -- TODO confirm.
seed = 42
set_seed(seed)

Load model:

In [None]:
# Notebook configuration: where the trained models live, which one to use and
# which device the following cells should run on.
models_dir = 'models/'
model = 'LMD2'
gpu = True
device_idx = 3


def _load_artifact(name):
    """Load the serialized file `name` of the selected model onto the CPU."""
    artifact_path = os.path.join(models_dir, model, name)
    return torch.load(artifact_path, map_location='cpu')


# The checkpoint carries the trained weights under 'model_state_dict'; the
# 'params' file holds the configuration used later to rebuild the model.
checkpoint = _load_artifact('checkpoint')
state_dict = checkpoint['model_state_dict']
params = _load_artifact('params')

In [None]:
# Select the compute device for the rest of the notebook.
# CUDA-specific calls are only issued when a GPU was requested *and* CUDA is
# actually usable; otherwise torch.cuda.set_device / current_device would raise
# on CPU-only machines (or when `gpu` is False).
use_cuda = gpu and torch.cuda.is_available()

if gpu and not use_cuda:
    print("CUDA requested but not available; falling back to CPU.")

if use_cuda:
    # Make `device_idx` the current CUDA device so that the generic "cuda"
    # device below resolves to it.
    torch.cuda.set_device(device_idx)
    device = torch.device("cuda")
    print("Device:", device)
    print("Device idx:", torch.cuda.current_device())
else:
    device = torch.device("cpu")
    print("Device:", device)

In [None]:
# Show the configuration loaded from the model's 'params' file (a bare last
# expression is rendered as the cell output).
params

In [None]:
# Rebuild the VAE from the saved 'model' configuration and move it to the
# selected device.
vae = VAE(**params['model'], device=device).to(device)
# Restore the weights stored in the checkpoint.
vae.load_state_dict(state_dict)
# Switch to evaluation mode for inference (assuming VAE is a torch.nn.Module,
# this disables training-only behaviour such as dropout -- TODO confirm).
vae.eval()

In [None]:
def generate_music(vae, z, s_cond=None, s_tensor_cond=None):
    """Decode latent codes into a multitrack pianoroll tensor of logits.

    Args:
        vae: Model exposing a ``decoder(z, s_cond)`` callable that returns a
            3-tuple; its second and third items are used here as the content
            logits and the structure tensor.
        z: Latent codes forwarded to the decoder.
        s_cond: Optional structure conditioning forwarded to the decoder.
        s_tensor_cond: Optional structure tensor. When given, it is used
            instead of the structure tensor produced by the decoder.

    Returns:
        Tuple ``(mtp, s_tensor)`` where ``mtp`` is the dense tensor built by
        ``dense_from_sparse_torch`` with the bars dimension merged into the
        timesteps dimension, and ``s_tensor`` is the structure tensor that was
        actually used.
    """
    # Get structure and content logits
    with torch.cuda.amp.autocast():
        _, c_logits, s_tensor_out = vae.decoder(z, s_cond)

    # Identity check: comparing a tensor to None with `!=` relies on
    # Tensor.__ne__ falling back for non-tensor operands.
    s_tensor = s_tensor_out if s_tensor_cond is None else s_tensor_cond

    # Build (n_batches x n_bars x n_tracks x n_timesteps x Sigma x d_token)
    # multitrack pianoroll tensor containing logits for each activation and
    # hard silences elsewhere
    mtp = dense_from_sparse_torch(c_logits, s_tensor)

    # Collapse bars dimension: put tracks before bars, then merge the
    # (bars, timesteps) axes into a single time axis.
    mtp = mtp.permute(0, 2, 1, 3, 4, 5)
    n_batches, n_tracks = mtp.shape[0], mtp.shape[1]
    mtp = mtp.reshape(n_batches, n_tracks, -1, mtp.shape[4], mtp.shape[5])

    return mtp, s_tensor

In [None]:
import matplotlib as mpl


def save(mtp, dir, s_tensor=None, track_data=None, n_loop=1, resolution=None):
    """Write MIDI files and plots for every generated sequence in `mtp`.

    For sequence ``i`` a sub-directory ``<dir>/<i>`` is created and filled
    with the MIDI file ('music'), its pianoroll plot ('pianoroll'), optionally
    the structure plot ('structure') and optionally a looped MIDI version
    ('extended').

    Args:
        mtp: Batch of multitrack pianoroll tensors; iterated along dim 0.
        dir: Output root directory.
        s_tensor: Optional batch of structure tensors; when given, the
            structure of each sequence is plotted as well.
        track_data: Optional list of per-track tuples passed to
            ``muspy_from_dense_torch`` (presumably (name, MIDI program) pairs
            -- confirm against utils). Defaults to Drums/Bass/Guitar/Strings.
        n_loop: When greater than 1, an additional MIDI file is written in
            which the sequence is repeated `n_loop` times along dim 1.
        resolution: Resolution passed to ``muspy_from_dense_torch``. Defaults
            to ``params['model']['resolution']`` of the loaded model.
    """
    # Identity checks: `== None` / `!= None` on lists and tensors depend on
    # the operand's comparison overloads.
    if track_data is None:
        track_data = [('Drums', -1), ('Bass', 34), ('Guitar', 1),
                      ('Strings', 83)]
    if resolution is None:
        # Loop-invariant, so resolve it once up front.
        resolution = params['model']['resolution']

    # Clear matplotlib cache (this avoids formatting problems with first plot)
    plt.clf()

    # Iterate over the generated n-bar sequences
    for i in range(mtp.size(0)):

        # Create the directory if it does not exist
        save_dir = os.path.join(dir, str(i))
        os.makedirs(save_dir, exist_ok=True)

        print("Saving midi sequence " + str(i+1) + "...")

        # Generate muspy song from multitrack pianoroll, then midi from muspy
        # and save
        muspy_song = muspy_from_dense_torch(mtp[i], track_data, resolution)
        midi_from_muspy(muspy_song, save_dir, name='music')

        # Plot the pianoroll associated to the sequence
        with mpl.rc_context({'lines.linewidth': 4,
                             'axes.linewidth': 4, 'font.size': 34}):
            plot_pianoroll(muspy_song, save_dir, name='pianoroll',
                           figsize=(20, 10), fformat='png', preset='full')

        # Plot structure_tensor if present
        if s_tensor is not None:
            # Swap the first two axes and flatten the remaining ones so that
            # the structure becomes a 2D matrix.
            s_curr = s_tensor[i].permute(1, 0, 2)
            s_curr = s_curr.reshape(s_curr.shape[0], -1)
            with mpl.rc_context({'lines.linewidth': 1,
                                 'axes.linewidth': 1, 'font.size': 14}):
                plot_struct(s_curr.cpu(), name='structure',
                            save_dir=save_dir, figsize=(12, 3))

        if n_loop > 1:
            # Generate extended sequence
            print("Saving extended (looped) midi sequence " + str(i+1) + "...")
            extended = mtp[i].repeat(1, n_loop, 1, 1)
            extended = muspy_from_dense_torch(extended, track_data, resolution)
            midi_from_muspy(extended, save_dir, name='extended')

        print()

    print("Finished.")

In [None]:
def generate_z(bs, d_model):
    """Sample a batch of latent codes from a standard normal distribution.

    Args:
        bs: Number of latent vectors to draw (batch size).
        d_model: Size of each latent vector.

    Returns:
        Tensor of shape ``(bs, d_model)`` allocated on the notebook-level
        ``device``, with entries drawn from a normal distribution with zero
        mean and unit standard deviation.
    """
    mean = torch.zeros((bs, d_model), device=device)
    std = torch.ones_like(mean)
    return torch.normal(mean, std)
    

In [None]:
import json

# --- Generation settings ---------------------------------------------------
# Whether to condition generation on a structure loaded from `structure_file`.
structure = True
# NOTE(review): in this notebook the file is written by a cell that appears
# *after* this one; on a fresh checkout that cell presumably has to be run
# first -- confirm.
structure_file = 'structure2bars.json'

n_sequences = 5          # number of sequences to sample
n_loop = 4               # repetitions used for the 'extended' MIDI file
dir = 'tmpmusic/'        # output root directory
track_data = [('Drums', -1), ('Bass', 34), ('Guitar', 1), ('Strings', 83)]

# Values taken from the loaded model configuration.
model_params = params['model']
d_model = model_params['d']
n_bars = model_params['n_bars']
n_tracks = model_params['n_tracks']
# The factor 4 is presumably the number of beats per bar -- TODO confirm.
n_timesteps = 4*model_params['resolution']

bs = n_sequences

# --- Optional structure conditioning ---------------------------------------
if structure:
    # Load the structure from file as a boolean tensor, then replicate it
    # along a new leading batch dimension.
    with open(structure_file, 'r') as f:
        s_tensor = torch.tensor(json.load(f)).bool()
    s_tensor = s_tensor.unsqueeze(0).repeat(bs, 1, 1, 1)
    s = vae.decoder._structure_from_binary(s_tensor)
else:
    s, s_tensor = None, None

# --- Sample, decode and write the results to disk --------------------------
z = generate_z(bs, d_model)
mtp, s_tensor = generate_music(vae, z, s, s_tensor)
save(mtp, dir, s_tensor, track_data, n_loop)

In [None]:
import torch
import json

# Build an example structure tensor and store it as JSON in
# 'structure2bars.json'. Replace the tensor below with your actual one.
# The (2, 4, 32) shape is presumably (bars, tracks, timesteps) -- TODO confirm.
tensor = torch.zeros(2, 4, 32, dtype=torch.int32)
# Set every 8th position along the last axis to 1.
tensor[:, :, ::8] = 1

# Nested Python lists of ints are what gets serialized.
tensor_list = tensor.tolist()

with open('structure2bars.json', 'w') as f:
    json.dump(tensor_list, f, indent=4)
