In [5]:
import argparse
import os
import subprocess
import time
from typing import Optional, Union, Callable

import mido
import torch
from IPython.display import Audio

from data_preprocessing import *
from model import *

Using device: cuda


In [7]:
def generate_sequence(
    model: 'MusicTransformer',
    initial_sequence: list,
    sampling_mode: str = "categorical",
    temperature_fn: Union[float, Callable] = 1.0,
    top_k: Optional[Union[int, Callable]] = None,
    max_length: Optional[int] = None
) -> torch.Tensor:
    """
    Generate a sequence of tokens using autoregressive prediction.

    The prompt is converted to token IDs, prefixed with the start token if it
    is missing, and then extended one token at a time until the end token is
    sampled, ``max_length`` is reached, or generation is interrupted.

    Args:
        model: Trained MusicTransformer model, called as ``model(sequence, mask=...)``
        initial_sequence: Starting sequence of MIDI events (may be empty)
        sampling_mode: 'argmax' picks the highest-scoring token; any other
            value samples (from the top-k tokens when ``top_k`` is given,
            otherwise from the full distribution)
        temperature_fn: Temperature value, or a function of the current
            sequence length returning the temperature
        top_k: Number of top tokens to sample from, or a function of the
            current sequence length returning this number
        max_length: Optional upper bound on the total number of tokens
            (prompt included). ``None`` keeps the unbounded behaviour.

    Returns:
        torch.Tensor: 1-D tensor of token IDs (the end token is not included)
    """
    # Prepare input sequence; an empty prompt simply becomes [start_token]
    token_ids = events_to_indices(initial_sequence)
    if len(token_ids) == 0 or token_ids[0] != start_token:
        token_ids = [start_token] + list(token_ids)

    sequence = torch.tensor(token_ids, dtype=torch.int64, device=device).unsqueeze(0)
    mask_dims = sequence.dim() + 2

    # Convert temperature and top_k to callables if they're constants
    if not callable(temperature_fn):
        temp_value = temperature_fn
        temperature_fn = lambda _: temp_value

    if top_k is not None and not callable(top_k):
        k_value = top_k
        top_k = lambda _: k_value

    # Generation loop
    # NOTE: this changes a process-wide torch setting and is not restored afterwards.
    torch.set_float32_matmul_precision("high")
    try:
        with torch.no_grad():
            while True:
                seq_len = sequence.shape[-1]
                if max_length is not None and seq_len >= max_length:
                    break

                # Get model predictions
                logits = model(sequence, mask=create_mask(sequence, mask_dims))

                # Apply temperature scaling to the scores of the last position
                scaled_logits = logits[..., -1, :] / temperature_fn(seq_len)

                # Sample next token based on specified mode
                if sampling_mode == "argmax":
                    next_token = torch.argmax(scaled_logits, dim=-1)
                elif top_k is not None:
                    # Clamp k to [1, vocab size]: an oversized k would make
                    # torch.topk raise a RuntimeError that the handler below
                    # would silently turn into "return the prompt unchanged".
                    k = max(1, min(int(top_k(seq_len)), scaled_logits.shape[-1]))
                    top_logits, top_indices = torch.topk(scaled_logits, k, dim=-1)
                    sampled_idx = torch.distributions.Categorical(logits=top_logits).sample()
                    # Map the position inside the top-k list back to a vocabulary ID
                    next_token = top_indices.gather(-1, sampled_idx.unsqueeze(-1)).squeeze(-1)
                else:  # categorical
                    next_token = torch.distributions.Categorical(logits=scaled_logits).sample()

                # Check for end token
                if next_token.item() == end_token:
                    break

                # Append prediction and continue
                sequence = torch.cat([sequence, next_token.view(1, 1)], dim=-1)

    except (KeyboardInterrupt, RuntimeError):
        # Return the partial sequence if interrupted or (per the original
        # author's note) the position encoding limit was exceeded.
        # NOTE(review): this also hides unrelated RuntimeErrors such as
        # out-of-memory; consider narrowing once the limit error is known.
        pass

    # squeeze(0) drops only the batch axis, so a length-1 sequence stays 1-D
    # instead of collapsing to a 0-d tensor that breaks len() in callers.
    return sequence.squeeze(0)

In [8]:
def create_midi(
    token_ids: torch.Tensor,
    output_path: str = "output/output.mid",
    tempo: int = 512820,
    verbose: bool = False
) -> None:
    """
    Convert generated token IDs to a MIDI file.

    The path is normalized to end in ``.mid`` (``.midi`` is shortened, any
    other name gets ``.mid`` appended) and the destination directory is
    created if it does not exist yet.

    Args:
        token_ids: Sequence of generated token IDs
        output_path: Path to save the MIDI file
        tempo: Tempo in microseconds per beat
        verbose: Enable verbose output
    """
    # Normalize file extension
    if output_path.endswith(".midi"):
        output_path = output_path[:-1]  # ".midi" -> ".mid"
    elif not output_path.endswith(".mid"):
        output_path = f"{output_path}.mid"

    # Make sure the target directory exists; otherwise saving to the default
    # "output/..." location fails on a fresh checkout.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    if verbose:
        print(f"Saving MIDI file to {output_path}...")

    # Create and save MIDI file; list_parser receives the path without ".mid"
    midi_obj = list_parser(
        index_list=token_ids,
        fname=output_path[:-4],
        tempo=tempo
    )
    midi_obj.save(output_path)

    if verbose:
        print("Done")
        

In [9]:
def generate_music(
    model: 'MusicTransformer',
    initial_sequence: list,
    output_path: str = "./output.mid",
    sampling_mode: str = "categorical",
    temperature: Union[float, Callable] = 1.0,
    top_k: Optional[Union[int, Callable]] = None,
    tempo: int = 512820,
    verbose: bool = False
) -> None:
    """
    Generate music using a trained MusicTransformer model.

    Runs generate_sequence on the prompt and writes the resulting tokens to a
    MIDI file via create_midi.

    Args:
        model: Trained MusicTransformer model
        initial_sequence: Starting sequence of MIDI events
        output_path: Path to save the generated MIDI file
        sampling_mode: Token sampling strategy ('categorical' or 'argmax')
        temperature: Sampling temperature, or a function of the sequence
            length returning it (forwarded to generate_sequence)
        top_k: Number of top tokens to sample from, or a function of the
            sequence length returning it (forwarded to generate_sequence)
        tempo: Tempo in microseconds per beat
        verbose: Enable verbose output
    """
    if verbose:
        print("Generating sequence...")

    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # system clock adjustments.
    start_time = time.perf_counter()

    token_ids = generate_sequence(
        model=model,
        initial_sequence=initial_sequence,
        sampling_mode=sampling_mode,
        temperature_fn=temperature,
        top_k=top_k
    )
    # Flatten to 1-D so len() works even if a 0-d tensor comes back
    # (a single-token result).
    token_ids = token_ids.reshape(-1)

    if verbose:
        generation_time = time.perf_counter() - start_time
        print(f"Generated {len(token_ids)} tokens in {generation_time:.2f} seconds")

    create_midi(
        token_ids=token_ids,
        output_path=output_path,
        tempo=tempo,
        verbose=verbose
    )

In [12]:
# Run on the GPU when torch can see one, otherwise on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# The checkpoint holds both the model config and the trained weights
# NOTE(review): torch.load may unpickle arbitrary objects; only load trusted checkpoints.
checkpoint_path = "music_transformer_bigmodel.pt"
checkpoint = torch.load(checkpoint_path, map_location=device)

# Build the network from the stored config before restoring its weights
model = MusicTransformerDecoder(checkpoint["config"]).to(device)

# Strip every "_orig_mod." fragment from the parameter names
# (presumably left behind by saving a torch.compile-wrapped model) so the
# keys line up with the freshly built model.
fixed_state_dict = {
    key.replace("_orig_mod.", ""): value
    for key, value in checkpoint["state_dict"].items()
}
model.load_state_dict(fixed_state_dict)

# Compilation is opt-in; flip the flag to wrap the model with torch.compile
enable_compilation = False
if enable_compilation:
    model = torch.compile(model)

# Switch to evaluation mode (last expression, so the cell displays the model)
model.eval()

MusicTransformerDecoder(
  (input_embedding): Embedding(416, 256)
  (input_dropout): Dropout(p=0.1, inplace=False)
  (decoder): ModuleList(
    (0-5): 6 x DecoderBlock(
      (self_attention): MultiHeadRelativeAttention(
        (query_proj): Linear(in_features=256, out_features=256, bias=True)
        (key_proj): Linear(in_features=256, out_features=256, bias=True)
        (value_proj): Linear(in_features=256, out_features=256, bias=True)
        (rel_emb): Embedding(1024, 256)
        (output_proj): Linear(in_features=256, out_features=256, bias=True)
      )
      (feed_forward): FeedForward(
        (network): Sequential(
          (0): Linear(in_features=256, out_features=1024, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.1, inplace=False)
          (3): Linear(in_features=1024, out_features=256, bias=True)
        )
      )
      (norm1): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
      (norm2): LayerNorm((256,), eps=1e-06, elementwise_

In [14]:
# --- Sampling configuration ---
temperature = 1.0  # Controls randomness in generation (higher = more random)
mode = "categorical"  # Sampling mode: "categorical" or "argmax"
top_k = None  # Optional top-k sampling parameter
tempo = 512820  # Microseconds per beat (about 117 BPM: 60_000_000 / 512820)

# --- Prompt configuration ---
use_midi_prompt = False  # Set to True if you want to use a MIDI file as prompt
midi_prompt_path = "path/to/your/midi/prompt.mid"  # Only used if use_midi_prompt is True
midi_prompt_tokens = None  # Number of tokens to use from MIDI prompt (None = all)

# --- Output configuration ---
save_path = "output/generated_music_big.mid"
verbose = True  # Print generation progress

if not use_midi_prompt:
    # No prompt: generation starts from the start token alone
    midi_input = ["<start>"]
else:
    # Element 1 of the parser output is used as the prompt events and
    # element 2 replaces the configured tempo.
    midi_parser_output = midi_parser(midi_prompt_path)
    tempo = midi_parser_output[2]
    midi_input = midi_parser_output[1]
    if midi_prompt_tokens:
        # Keep only the requested number of leading prompt tokens
        midi_input = midi_input[:midi_prompt_tokens]

# Run generation with the model loaded earlier and write the MIDI file
generate_music(
    model=model,
    initial_sequence=midi_input,
    output_path=save_path,
    sampling_mode=mode,
    temperature=temperature,
    top_k=top_k,
    tempo=tempo,
    verbose=verbose
)

Generating sequence...
Generated 556 tokens in 5.71 seconds
Saving MIDI file to output/generated_music_big.mid...
Done


In [15]:
soundfont_path = "/usr/share/sounds/sf2/FluidR3_GM.sf2"
midi_path = "output/generated_music_big.mid"
wav_path = "output/generated_music.wav"

# Render the MIDI file to WAV with FluidSynth.
# The command is passed as an argument list (no shell), so paths containing
# spaces or shell metacharacters are handled safely, and check=True raises
# if rendering fails instead of letting Audio() play a stale or missing file.
command = [
    "fluidsynth", "-ni", soundfont_path, midi_path,
    "-F", wav_path, "-r", "44100", "-g", "1.0",
]
subprocess.run(command, check=True)

# If you're in a Jupyter notebook, you can play it:
Audio(wav_path)

FluidSynth runtime version 2.2.5
Copyright (C) 2000-2022 Peter Hanappe and others.
Distributed under the LGPL license.
SoundFont(R) is a registered trademark of Creative Technology Ltd.

Rendering audio to file 'output/generated_music.wav'..
