<div class="alert alert-block alert-warning"  style="background-color: #c78cf5; color: black;">  
  <h1>About</h1>
  Try music transformer. Make sure to use the `classical_music_generator_env` environment. 
</div>

<div class="alert alert-block alert-warning"  style="background-color: #c78cf5; color: black;">  
  <h1>Load Stuff</h1>
</div>

In [1]:
from miditok import REMI
from miditok import REMI, TokSequence
from miditoolkit import MidiFile

from pathlib import Path
import pickle

from datasets import Dataset
import torch
from torch.nn.utils.rnn import pad_sequence

from transformers import GPT2Config, GPT2LMHeadModel
from transformers import Trainer, TrainingArguments

In [2]:
midi_dir = Path("sample_midi")
token_output_dir = Path("tokenized_midi")
token_output_dir.mkdir(exist_ok=True)

<div class="alert alert-block alert-warning"  style="background-color: #c78cf5; color: black;">  
  <h1>Prepare for Training</h1>
</div>

In [3]:
label_map = {
    "air.mid": {"composer": "<composer:Bach>", "era": "<era:Baroque>"},
    "furelise.mid": {"composer": "<composer:Beethoven>", "era": "<era:Classical>"}
}

tokenizer = REMI()
midi_paths = list(midi_dir.glob("*.mid"))

for midi_path in midi_paths:
    filename = midi_path.name
    if filename not in label_map:
        print(f"Skipping unlabeled file: {filename}")
        continue

    labels = label_map[filename]
    midi = MidiFile(str(midi_path))
    token_seq = tokenizer(midi)

    # Assume 1 segment per file
    styled_seq = [labels["composer"], labels["era"]] + token_seq[0].tokens

    with open(token_output_dir / f"{filename}.pkl", "wb") as f:
        pickle.dump(styled_seq, f)

  token_seq = tokenizer(midi)


In [4]:
##########################################################################################
# Setup & Token ID Conversion
##########################################################################################

# Load tokenized sequences
input_dir = Path("tokenized_midi")
files = list(input_dir.glob("*.pkl"))

styled_tokens = []
for f in files:
    with open(f, "rb") as infile:
        styled_tokens.append(pickle.load(infile))

# Build vocab
all_tokens = set(token for seq in styled_tokens for token in seq)
token2id = {tok: i for i, tok in enumerate(sorted(all_tokens))}
id2token = {i: tok for tok, i in token2id.items()}

# Add pad token
pad_token = "<PAD>"
pad_token_id = len(token2id)
token2id[pad_token] = pad_token_id
id2token[pad_token_id] = pad_token

# Convert tokens to IDs
token_ids = [[token2id[tok] for tok in seq] for seq in styled_tokens]

# Wrap in HF Dataset
examples = [{"input_ids": ids} for ids in token_ids]
hf_dataset = Dataset.from_list(examples)

<div class="alert alert-block alert-warning"  style="background-color: #c78cf5; color: black;">  
  <h1>Model</h1>
</div>

In [5]:
##########################################################################################
# Define Model
##########################################################################################
config = GPT2Config(
    vocab_size=len(token2id),
    n_positions=1024,
    n_layer=4,
    n_head=4,
    pad_token_id=pad_token_id
)
model = GPT2LMHeadModel(config)

In [6]:
##########################################################################################
# Fine-tuning
##########################################################################################
class SimpleDataCollator:
    def __call__(self, examples):
        input_ids = [torch.tensor(e["input_ids"], dtype=torch.long) for e in examples]
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
        labels = input_ids.clone()
        return {"input_ids": input_ids, "labels": labels}

training_args = TrainingArguments(
    output_dir="./checkpoints",
    per_device_train_batch_size=2,
    num_train_epochs=10,
    logging_dir="./logs",
    save_steps=100,
    logging_steps=10
)

collator = SimpleDataCollator()

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=hf_dataset,
    data_collator=collator
)

trainer.train()

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


Step,Training Loss
10,3.1271
20,2.2377


TrainOutput(global_step=20, training_loss=2.6823679924011232, metrics={'train_runtime': 25.7607, 'train_samples_per_second': 1.553, 'train_steps_per_second': 0.776, 'total_flos': 16977790771200.0, 'train_loss': 2.6823679924011232, 'epoch': 10.0})

<div class="alert alert-block alert-warning"  style="background-color: #c78cf5; color: black;">  
  <h1>Music Generation</h1>
</div>

In [7]:
########################################################################################## 
# Generate Music
##########################################################################################
start = ["<composer:Bach>", "<era:Baroque>"]
start_ids = [token2id[tok] for tok in start]
input_ids = torch.tensor(start_ids).unsqueeze(0)

# Move model & input to CPU
device = torch.device("cpu")
model = model.to(device)
input_ids = input_ids.to(device)

output = model.generate(
    input_ids=input_ids,
    max_length=512,
    do_sample=True,
    temperature=1.0
)

generated_ids = output[0].tolist()
generated_tokens = [id2token[i] for i in generated_ids]

In [8]:
########################################################################################## 
# Convert to MIDI
##########################################################################################

# Re-init tokenizer
tokenizer = REMI()

# Remove style tokens like <composer:*>, <era:*>
music_tokens = [tok for tok in generated_tokens if "_" in tok]

# Wrap in TokSequence
tok_seq = TokSequence(tokens=music_tokens)

# Decode back to a symusic.Score/ScoreTick object
score = tokenizer.decode([tok_seq])

# Save to .mid using dump_midi()
score.dump_midi("generations/generated_classical.mid")

<div class="alert alert-block alert-warning"  style="background-color: #c78cf5; color: black;">  
  <h1>Try it Out</h1>
</div>

In [9]:
def generate_classical_midi(composer: str, era: str, filename="generated_piece.mid"):
    from miditok import REMI, TokSequence
    import torch

    # Initialize tokenizer + model device
    tokenizer = REMI()
    device = torch.device("cpu")
    model.to(device)

    # Prepare input tokens
    prompt_tokens = [f"<composer:{composer}>", f"<era:{era}>"]
    prompt_ids = [token2id[t] for t in prompt_tokens]
    input_ids = torch.tensor(prompt_ids).unsqueeze(0).to(device)

    # Generate
    output = model.generate(
        input_ids=input_ids,
        max_length=512,
        do_sample=True,
        temperature=1.0
    )

    # Convert IDs to tokens
    generated_ids = output[0].tolist()
    generated_tokens = [id2token[i] for i in generated_ids]

    # Filter out label tokens
    music_tokens = [tok for tok in generated_tokens if "_" in tok]

    # Decode to MIDI
    tok_seq = TokSequence(tokens=music_tokens)
    score = tokenizer.decode([tok_seq])
    score.dump_midi(filename)

    print(f"🎵 Saved generated piece to {filename}")


In [10]:
generate_classical_midi("Bach", "Baroque", filename="generations/bach_baroque.mid")
generate_classical_midi("Beethoven", "Classical", filename="generations/beethoven_classical.mid")

🎵 Saved generated piece to generations/bach_baroque.mid
🎵 Saved generated piece to generations/beethoven_classical.mid
