# Composing Music with xLSTM.

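This notebook shows you how to compose music with xLSTM. The model-loading cell also lists Transformer, Pharia and Mamba checkpoints that you can load instead.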

On top of the installation instructions from the README, you might want to install the following packages:

In [None]:
# First-time setup: uncomment the lines below and run this cell once.
# `%pip` (rather than `!pip`) installs into the environment of the running kernel.
# NOTE(review): bokeh is pinned, presumably for compatibility with note_seq.plot_sequence — confirm.
#%pip install note-seq
#%pip install --upgrade bokeh==2.4.3

## Load the language model.

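The next cell lists several checkpoints on the Hugging Face Hub, for example the xLSTM model https://huggingface.co/TristanBehrens/jsfakes-music-xlstm/tree/main. Uncomment the one you want to use. Note that the generation code further down starts from a `GARLAND_START` prompt and parses garland tokens, so a `bach-garland-*` checkpoint is selected by default (`bach-garland-pharia`).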

In [None]:
import sys
sys.path.append("..")
from source.languagemodel import LanguageModel
import note_seq

# Choose the checkpoint to load by leaving exactly one assignment uncommented.
# NOTE(review): the generation cell below prompts with "GARLAND_START", so the
# bach-garland-* checkpoints are presumably the matching ones — confirm before
# switching to a jsfakes-* checkpoint.
#
# Trained on the jsfakes corpus:
#model_path_or_repo = "TristanBehrens/jsfakes-music-transformer"
#model_path_or_repo = "TristanBehrens/jsfakes-music-pharia"
#model_path_or_repo = "TristanBehrens/jsfakes-music-mamba"
#model_path_or_repo = "TristanBehrens/jsfakes-music-xlstm"
#
# Trained on the bach-garland corpus:
#model_path_or_repo = "TristanBehrens/bach-garland-transformer"
model_path_or_repo = "TristanBehrens/bach-garland-pharia"
#model_path_or_repo = "TristanBehrens/bach-garland-mamba"
#model_path_or_repo = "TristanBehrens/bach-garland-xlstm"
#model_path_or_repo = "TristanBehrens/bach-garland-mambaplus"

# Instantiate the model with its context length overridden to 16,384 tokens,
# the same value the generation cell uses as max_length.
model = LanguageModel(model_path_or_repo, config_overrides={"context_length": 16_384})

# Show a summary of the loaded model.
model.summary()

## Utility code for music

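The helpers below turn the garland tokens that the model generates into a song dictionary (tracks → bars → notes). They then convert that dictionary into a `note_seq` NoteSequence whose instruments use General MIDI program numbers (or the drum flag), so the result is close to GM MIDI.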

In [2]:
import copy

def _ensure_track(tracks, track_index):
    """Return tracks[track_index], appending default tracks until that index exists."""
    while len(tracks) <= track_index:
        tracks.append({"bars": [], "instrument": "0"})
    return tracks[track_index]


def _append_bar(track, bar_group_index):
    """Append a new empty bar to `track`, padding with empty bars so it lands at index `bar_group_index` or later."""
    # A track missing from earlier bar groups gets empty bars, so that bar i of
    # every track ends up in the same time slot when converted.
    while len(track["bars"]) < bar_group_index:
        track["bars"].append({"notes": []})
    track["bars"].append({"notes": []})


def handle_garland_tokens(tokens):
    """Parse a sequence of garland tokens into a song dictionary.

    The result has the shape
    ``{"tracks": [{"bars": [{"notes": [note, ...]}, ...], "instrument": str}, ...]}``
    where each note is ``{"note": pitch, "start": step, "end": step, "velocity": 80}``
    and start/end are time steps relative to the start of the note's bar.

    Token semantics:
        BAR_START       opens a new bar on the current track (creating the track if
                        needed) and resets the time cursor to 0.
        INST=<x>        sets the instrument of the current track.
        NOTE_ON=<p>     starts a note of pitch p at the current time step.
        TIME_DELTA=<d>  advances the time cursor by d steps.
        NOTE_OFF=<p>    ends the oldest still-open note of pitch p in the current bar.
        BAR_END         moves on to the next track.
        NEXT            returns to the first track and starts the next bar group.
        GARLAND_START, GARLAND_END, DENSITY=<n>, [PAD] and [EOS] are ignored.

    Notes that never receive a NOTE_OFF keep ``end == start`` (zero length).

    Args:
        tokens: Iterable of token strings, e.g. ``output.split()``.

    Returns:
        The song dictionary described above.

    Raises:
        ValueError: If a token is not part of the vocabulary above.
    """
    song_data = {"tracks": []}
    tracks = song_data["tracks"]

    current_track_index = 0
    current_timestep = 0
    bar_group_index = 0  # number of NEXT tokens seen so far
    open_notes = []  # notes of the current bar still waiting for NOTE_OFF, oldest first

    for token in tokens:
        if token == "BAR_START":
            _append_bar(_ensure_track(tracks, current_track_index), bar_group_index)
            current_timestep = 0
            open_notes = []
        elif token.startswith("INST="):
            _ensure_track(tracks, current_track_index)["instrument"] = token.split("=")[1]
        elif token.startswith("NOTE_ON="):
            note_data = {
                "note": int(token.split("=")[1]),
                "start": current_timestep,
                "end": current_timestep,
                "velocity": 80,
            }
            track = _ensure_track(tracks, current_track_index)
            # Tolerate a NOTE_ON that arrives before any BAR_START on this track.
            if not track["bars"]:
                _append_bar(track, bar_group_index)
            track["bars"][-1]["notes"].append(note_data)
            open_notes.append(note_data)
        elif token.startswith("TIME_DELTA="):
            current_timestep += int(token.split("=")[1])
        elif token.startswith("NOTE_OFF="):
            note_pitch = int(token.split("=")[1])
            # Track open notes explicitly: a zero-length note (NOTE_OFF at its start
            # step) must not be mistaken for one that is still sounding.
            for open_index, note_data in enumerate(open_notes):
                if note_data["note"] == note_pitch:
                    note_data["end"] = current_timestep
                    del open_notes[open_index]
                    break
        elif token == "BAR_END":
            current_track_index += 1
        elif token == "NEXT":
            current_track_index = 0
            bar_group_index += 1
        elif token in ("GARLAND_START", "GARLAND_END", "[PAD]", "[EOS]") or token.startswith("DENSITY="):
            pass
        else:
            raise ValueError(f"Unknown token: {token}")

    return song_data

def convert_songdata_to_notesequence(song_data:dict, quantize_steps_per_quarter=8, remove_disabled_tracks=True, beats_per_bar=4):
    """Convert a song dictionary into a ``note_seq`` NoteSequence.

    Args:
        song_data: Dict with a ``"tracks"`` list, e.g. as produced by
            ``handle_garland_tokens``. Each track has an ``"instrument"`` (a
            General MIDI program number accepted by ``int()``, or ``"drums"``),
            a list of ``"bars"`` each holding ``"notes"``, and an optional
            ``"enabled"`` flag. Each note has ``"note"`` (MIDI pitch),
            ``"start"`` and ``"end"`` (time steps from the start of its bar) and
            an optional ``"velocity"`` (default 80). An optional top-level
            ``"bpm"`` sets the tempo (default 120). The input is not modified.
        quantize_steps_per_quarter: Number of time steps per quarter note.
        remove_disabled_tracks: If True, skip tracks whose ``"enabled"`` flag is False.
        beats_per_bar: Quarter notes per bar; the default 4 corresponds to 4/4 time.

    Returns:
        A NoteSequence with a single tempo event. Tracks are sorted by
        instrument, and the notes of the i-th track get ``instrument = i``.
        Bars are laid out back to back, so bar j of a track starts at
        ``j * bar length``. ``total_time`` is set to the end of the last note.
    """
    assert isinstance(song_data, dict), f"Invalid song data type: {type(song_data)}"
    assert "tracks" in song_data, f"Invalid song data: {song_data.keys()}"

    # Sort the tracks by instrument. sorted() and the filter below build new
    # lists, so the caller's song_data is never mutated and no deep copy is needed.
    tracks = sorted(song_data["tracks"], key=lambda t: t["instrument"])

    # Remove tracks that are not enabled.
    if remove_disabled_tracks:
        tracks = [t for t in tracks if t.get("enabled", True)]

    # Create an empty note sequence with a single tempo.
    note_sequence = note_seq.protobuf.music_pb2.NoteSequence()
    bpm = song_data.get("bpm", 120)
    note_sequence.tempos.add().qpm = bpm

    # Durations in seconds of one quantization step and of one bar.
    step_length_seconds = 60.0 / bpm / quantize_steps_per_quarter
    bar_length_seconds = beats_per_bar * step_length_seconds * quantize_steps_per_quarter

    for track_index, track_data in enumerate(tracks):
        instrument = track_data["instrument"]
        for bar_index, bar_data in enumerate(track_data["bars"]):
            bar_start_time = bar_index * bar_length_seconds
            for note_data in bar_data["notes"]:
                assert "note" in note_data
                assert "start" in note_data
                assert "end" in note_data
                note = note_sequence.notes.add()
                note.pitch = note_data["note"]
                note.start_time = note_data["start"] * step_length_seconds + bar_start_time
                note.end_time = note_data["end"] * step_length_seconds + bar_start_time
                note.velocity = note_data.get("velocity", 80)
                # One instrument index per track, so tracks stay separable downstream.
                note.instrument = track_index
                if instrument == "drums":
                    note.is_drum = True
                else:
                    note.is_drum = False
                    note.program = int(instrument)

    # total_time is a plain proto field that stays 0 unless it is set explicitly.
    note_sequence.total_time = max((n.end_time for n in note_sequence.notes), default=0.0)

    return note_sequence

## Generate music

In [None]:
# Sampling temperature: the higher the temperature, the more random the output.
temperature = 0.5

# Maximum total length in tokens (prompt plus generated); the same value as the
# context_length override used when loading the model.
max_length = 16_384

# Each generate() call stops as soon as the model emits one of these tokens.
end_tokens = ["NEXT"]

# Tokens the model may never sample. GARLAND_END is presumably excluded so that
# the piece is not ended early.
forbidden_tokens = ["[PAD]", "[EOS]", "GARLAND_END"]

# Number of generate() calls; each one continues the piece up to the next end token.
num_iterations = 10

# Compose the music iteratively, bar by bar: every call continues from the previous output.
output = "GARLAND_START"
for iteration in range(num_iterations):
    previous_length = len(output.split())
    output_dict = model.generate(
        prompt=output,
        temperature=temperature,
        max_length=max_length,
        end_tokens=end_tokens,
        forbidden_tokens=forbidden_tokens,
        return_structured_output=True,
    )
    # Print the diagnostics, but not the ever-growing token string itself,
    # which would flood the cell output on every iteration.
    for key, value in output_dict.items():
        if key != "output":
            print(key, value)
    output = output_dict["output"]
    # split() without an argument matches how the decoding cell tokenizes the output.
    length = len(output.split())
    print(f"Iteration {iteration}, length {length:_}")
    # Stop at the length limit, or when the model produced no new tokens.
    if length >= max_length or length <= previous_length:
        break

In [None]:
# Decode the generated text: whitespace-separated tokens -> song dictionary.
tokens = output.split()
song_data = handle_garland_tokens(tokens)

# Song dictionary -> note_seq NoteSequence.
note_sequence = convert_songdata_to_notesequence(song_data=song_data)

# Draw the piano roll first, then render the audio player.
note_seq.plot_sequence(note_sequence)
note_seq.play_sequence(note_sequence)

# Thank you!