# Playing with Music Generation models
This notebook is a playground for a backbone generation model

In [11]:
import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"

In [22]:
import note_seq
from magenta.models.melody_rnn import melody_rnn_sequence_generator
from magenta.models.shared import sequence_generator_bundle
from note_seq.protobuf import generator_pb2, music_pb2

# Load the pre-trained model bundle
BUNDLE_PATH = note_seq.get_sample_bundle('basic_rnn')  # Downloads automatically

# Initialize the generator
bundle = sequence_generator_bundle.read_bundle_file(BUNDLE_PATH)
generator = melody_rnn_sequence_generator.get_generator_map()['basic_rnn'](checkpoint=None, bundle=bundle)
generator.initialize()

# Create a primer sequence with one note (middle C)
primer_sequence = music_pb2.NoteSequence()
primer_sequence.notes.add(pitch=60, start_time=0, end_time=0.5, velocity=80)
primer_sequence.total_time = 1.0
primer_sequence.tempos.add(qpm=120)

# Set generation options (10 seconds of music)
generator_options = generator_pb2.GeneratorOptions()
generator_options.generate_sections.add(start_time=primer_sequence.total_time, end_time=primer_sequence.total_time + 10)

# Generate the melody
generated_sequence = generator.generate(primer_sequence, generator_options)

# Save as MIDI
note_seq.sequence_proto_to_midi_file(generated_sequence, "generated_melody.mid")
print("🎶 Melody generated and saved as 'generated_melody.mid'")

# (Optional) Play the sequence if in Jupyter or Colab
try:
    from IPython.display import Audio
    note_seq.play_sequence(generated_sequence, synth=note_seq.fluidsynth)
except:
    pass



ModuleNotFoundError: No module named 'magenta'

In [19]:
audio_values = model.generate(**inputs, max_new_tokens=256)

In [20]:
from IPython.display import Audio

sampling_rate = model.config.audio_encoder.sampling_rate
Audio(audio_values[0].cpu().numpy(), rate=sampling_rate)

In [19]:
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

In [18]:
audio_values[0].cpu().numpy()

array([[ 0.06585294,  0.06052003,  0.09031354, ..., -0.3202339 ,
        -0.3255787 , -0.31786978]], shape=(1, 1308800), dtype=float32)