## Download the MAESTRO dataset

In [None]:
!wget https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0.zip
!unzip maestro-v3.0.0.zip

## Install dependencies 

In [10]:
!pip install torch transformers pretty_midi numpy tqdm




[notice] A new release of pip is available: 24.3.1 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


## Import modules

In [16]:
import os
import pandas as pd
import pretty_midi
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, GPT2Model, GPT2Config
from torch.utils.data import Dataset
from transformers import BertTokenizer

## Hypothetical model architecture

In [17]:
class ConditionalMusicTransformer(nn.Module):
    """Text-conditioned next-token model over a music token sequence.

    A small GPT-2 style encoder turns the tokenized text prompt into a single
    conditioning vector, which is projected to the music embedding size and
    added to every music token embedding. An ``nn.Transformer`` then maps the
    conditioned sequence to per-position logits over the music vocabulary.

    Args:
        vocab_size: Number of music tokens (size of the output distribution).
        text_embed_dim: Hidden size of the text encoder.
        music_embed_dim: Hidden size of the music transformer.
        n_head: Attention heads, used for both the text encoder and the
            music transformer.
        text_vocab_size: Vocabulary size of the text tokenizer. The default
            30522 is the value the original code hard-coded.
        max_seq_len: Longest music sequence the learned positional embedding
            supports.
    """

    def __init__(self, vocab_size, text_embed_dim=256, music_embed_dim=512, n_head=8,
                 text_vocab_size=30522, max_seq_len=1024):
        super().__init__()
        self.text_encoder = GPT2Model(GPT2Config(
            vocab_size=text_vocab_size,
            n_embd=text_embed_dim,
            n_head=n_head,
            n_layer=4
        ))
        self.music_embed = nn.Embedding(vocab_size, music_embed_dim)
        # nn.Transformer adds no positional information on its own, so token
        # order would otherwise be invisible to attention.
        self.pos_embed = nn.Embedding(max_seq_len, music_embed_dim)
        self.transformer = nn.Transformer(
            d_model=music_embed_dim,
            nhead=n_head,
            num_encoder_layers=4,
            num_decoder_layers=4,
            # Inputs are (batch, seq, dim). The default (batch_first=False)
            # would treat the batch axis as time and attend across samples.
            batch_first=True,
        )
        self.text_proj = nn.Linear(text_embed_dim, music_embed_dim)
        self.note_predictor = nn.Linear(music_embed_dim, vocab_size)

    def forward(self, music_seq, text_input, mask=None):
        """Return logits of shape (batch, seq_len, vocab_size).

        Args:
            music_seq: LongTensor of music token ids, shape (batch, seq_len).
            text_input: Mapping of tokenizer outputs (``input_ids`` and
                optionally ``attention_mask`` etc.). Tensors shaped
                (batch, 1, T), as produced by collating per-item
                ``return_tensors='pt'`` encodings, are flattened to (batch, T).
            mask: Optional (seq_len, seq_len) attention mask. Defaults to a
                causal mask. Because the same sequence feeds both encoder and
                decoder, the mask is applied to encoder self-attention,
                decoder self-attention and cross-attention; masking only the
                decoder would still expose future tokens through the encoder
                memory.

        Raises:
            ValueError: If ``music_seq`` is longer than ``max_seq_len``.
        """
        seq_len = music_seq.size(1)
        if seq_len > self.pos_embed.num_embeddings:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len "
                f"{self.pos_embed.num_embeddings}"
            )

        # Drop any singleton middle dimension so the encoder sees (batch, T).
        text_input = {
            name: tensor.reshape(-1, tensor.size(-1))
            for name, tensor in text_input.items()
        }
        hidden = self.text_encoder(**text_input).last_hidden_state

        # Mean-pool over real tokens only; averaging padded positions would
        # dilute the prompt representation.
        attention_mask = text_input.get('attention_mask')
        if attention_mask is None:
            text_features = hidden.mean(1)
        else:
            weights = attention_mask.unsqueeze(-1).to(hidden.dtype)
            text_features = (hidden * weights).sum(1) / weights.sum(1).clamp(min=1.0)
        text_features = self.text_proj(text_features)

        positions = torch.arange(seq_len, device=music_seq.device)
        music_emb = self.music_embed(music_seq) + self.pos_embed(positions)
        music_emb = music_emb + text_features.unsqueeze(1)

        if mask is None:
            # Boolean mask: True marks positions that may NOT be attended to.
            mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=music_seq.device),
                diagonal=1,
            )

        output = self.transformer(
            music_emb, music_emb,
            src_mask=mask, tgt_mask=mask, memory_mask=mask,
        )
        return self.note_predictor(output)

## Dataset

In [18]:
# Root of the extracted dataset plus the batching hyper-parameters.
MAESTRO_PATH = "maestro-v3.0.0"
BATCH_SIZE = 16
SEQ_LENGTH = 512

# The metadata CSV lists one recording per row; its `midi_filename` column
# holds each MIDI path relative to the dataset root.
metadata_csv = os.path.join(MAESTRO_PATH, "maestro-v3.0.0.csv")
metadata = pd.read_csv(metadata_csv)
midi_files = [
    os.path.join(MAESTRO_PATH, relative_path)
    for relative_path in metadata["midi_filename"]
]

In [19]:
class MaestroDataset(Dataset):
    """MIDI files turned into fixed-length (pitch, duration) token sequences.

    Each note contributes two tokens: its MIDI pitch (0-127) followed by a
    duration token (128-131) selecting one of ``DURATION_BUCKETS``. Every item
    also carries a tokenized text description derived from the piece's tempo.

    Args:
        midi_files: Paths of the MIDI files, aligned row-for-row with
            ``metadata``.
        metadata: DataFrame with one row per MIDI file.
        seq_length: Number of tokens per sequence before the one-token shift
            into input/target (so both returned sequences have
            ``seq_length - 1`` tokens).
    """

    NUM_PITCHES = 128
    # Seconds represented by duration tokens 128, 129, 130, 131.
    DURATION_BUCKETS = (0.25, 0.5, 1.0, 2.0)
    # Used when no tempo is available; maps to the "moderate" description.
    FALLBACK_BPM = 100

    def __init__(self, midi_files, metadata, seq_length=512):
        self.midi_files = midi_files
        self.metadata = metadata
        self.seq_length = seq_length
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        # Musical note vocabulary
        self.vocab = {
            **{i: i for i in range(128)},  # Notes
            **{128+i: i for i in range(4)} # Durations (0.25, 0.5, 1.0, 2.0 sec)
        }

    def __len__(self):
        return len(self.midi_files)

    def _quantize_duration(self, duration):
        """Snap a duration in seconds to the nearest entry of DURATION_BUCKETS."""
        return min(self.DURATION_BUCKETS, key=lambda bucket: abs(bucket - duration))

    def _duration_token(self, duration):
        """Return the token id (128-131) for a duration in seconds."""
        bucket = self._quantize_duration(duration)
        return self.NUM_PITCHES + self.DURATION_BUCKETS.index(bucket)

    def _encode_notes(self, notes):
        """Encode (pitch, duration) pairs into exactly ``seq_length`` token ids.

        Only the first ``seq_length // 2`` notes are used, since each note
        takes two tokens. Shorter sequences are right-padded with 0.
        NOTE: the pad id 0 is also the token for MIDI pitch 0.
        """
        seq = []
        for pitch, duration in notes[:self.seq_length // 2]:
            seq.append(int(pitch))
            seq.append(self._duration_token(duration))
        seq += [0] * (self.seq_length - len(seq))
        return seq

    def _tempo_bpm(self, idx, midi_data):
        """Return the tempo for item ``idx`` in beats per minute.

        Uses a ``tempo`` metadata column when present; otherwise estimates
        the tempo from the MIDI itself, falling back to FALLBACK_BPM when
        pretty_midi cannot produce an estimate.
        """
        if 'tempo' in self.metadata.columns:
            return int(self.metadata.iloc[idx]['tempo'])
        try:
            return int(midi_data.estimate_tempo())
        except ValueError:
            return self.FALLBACK_BPM

    def _tempo_description(self, bpm):
        """Map a BPM value to the word used in the text prompt."""
        return "fast" if bpm > 120 else "moderate" if bpm > 80 else "slow"

    def __getitem__(self, idx):
        midi_data = pretty_midi.PrettyMIDI(self.midi_files[idx])

        # Gather notes from all instruments and order them by onset (ties by
        # pitch) so the token sequence follows musical time.
        events = []
        for instrument in midi_data.instruments:
            for note in instrument.notes:
                events.append((note.start, note.pitch, note.end - note.start))
        events.sort(key=lambda event: (event[0], event[1]))
        notes = [(pitch, duration) for _, pitch, duration in events]

        tempo_desc = self._tempo_description(self._tempo_bpm(idx, midi_data))
        text = f"classical piano {tempo_desc} tempo"
        text_tokens = self.tokenizer(
            text,
            max_length=32,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # return_tensors='pt' yields (1, 32) tensors; drop the leading axis so
        # the DataLoader collates them to (batch, 32) rather than (batch, 1, 32).
        text_input = {name: tensor.squeeze(0) for name, tensor in text_tokens.items()}

        seq = self._encode_notes(notes)

        return {
            'input_seq': torch.LongTensor(seq[:-1]),
            'target_seq': torch.LongTensor(seq[1:]),
            'text_input': text_input
        }


## Model training

In [None]:
dataset = MaestroDataset(midi_files, metadata, SEQ_LENGTH)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

model = ConditionalMusicTransformer(vocab_size=132)  # The size: 128 notes + 4 durations
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

for epoch in range(5):
    # Re-enable dropout each epoch: generate_from_prompt() switches the model
    # to eval mode, so re-running this cell afterwards would otherwise train
    # with dropout disabled.
    model.train()
    running_loss = 0.0
    n_batches = 0

    for batch in dataloader:
        outputs = model(
            batch['input_seq'],
            batch['text_input']
        )
        # Flatten (batch, seq, vocab) -> (batch*seq, vocab); the vocab size is
        # read from the logits instead of being repeated as a literal.
        loss = torch.nn.functional.cross_entropy(
            outputs.reshape(-1, outputs.size(-1)),
            batch['target_seq'].reshape(-1)
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    # Report the mean over the epoch rather than only the final batch's loss;
    # max(..., 1) keeps an empty dataloader from raising.
    epoch_loss = running_loss / max(n_batches, 1)
    print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}")

## Music generation code with example

In [None]:
def generate_from_prompt(model, prompt, max_length=512):
    """Greedily generate a (pitch, duration) token sequence for a text prompt.

    Generation is autoregressive, one token per forward pass: the logits at
    the LAST position predict the next token. Pitch and duration tokens
    alternate, so each step's argmax is restricted to the valid id range
    (pitches 0-127, durations 128 and above).

    Args:
        model: Trained model; called as ``model(token_ids, text_input)`` and
            expected to return logits of shape (1, seq_len, vocab_size).
        prompt: Text description used for conditioning.
        max_length: Upper bound on the number of returned tokens. The seed
            note (two tokens) is always included, so values below 2 still
            return two tokens.

    Returns:
        A flat list of token ids: pitch, duration, pitch, duration, ...
    """
    num_pitches = 128

    model.eval()
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # Run on whichever device the model lives on.
    device = next(model.parameters()).device

    text_input = tokenizer(
        prompt,
        return_tensors='pt',
        max_length=32,
        truncation=True,
        padding='max_length'
    )
    text_input = {name: tensor.to(device) for name, tensor in text_input.items()}

    # Seed note: pitch 60 (middle C) with duration token 130.
    current_seq = torch.tensor([[60, 130]], dtype=torch.long, device=device)

    with torch.no_grad():
        # Append whole (pitch, duration) pairs while they still fit.
        while current_seq.size(1) + 2 <= max_length:
            pitch_logits = model(current_seq, text_input)[0, -1, :num_pitches]
            next_pitch = torch.argmax(pitch_logits).view(1, 1)
            current_seq = torch.cat([current_seq, next_pitch], dim=1)

            # The duration is predicted after the new pitch has been appended,
            # so it is conditioned on that pitch.
            duration_logits = model(current_seq, text_input)[0, -1, num_pitches:]
            next_duration = torch.argmax(duration_logits).view(1, 1) + num_pitches
            current_seq = torch.cat([current_seq, next_duration], dim=1)

    return current_seq.squeeze(0).tolist()

# Example: generate a token sequence conditioned on a fast-tempo description.
example_prompt = "classical piano fast tempo"
generated = generate_from_prompt(model, prompt=example_prompt)

## Save to MIDI

In [None]:
def seq_to_midi(seq, filename):
    """Write a (pitch, duration) token sequence to a MIDI file.

    Tokens are read in pairs: a pitch (0-127) followed by a duration token.
    Duration tokens 128-131 map to 0.25, 0.5, 1.0 and 2.0 seconds, matching
    the vocabulary documented in MaestroDataset; any other duration token
    falls back to 0.25 s. Notes are laid out back to back on one piano track.

    Args:
        seq: Flat list of token ids (pitch, duration, pitch, duration, ...).
            A trailing unpaired token is ignored.
        filename: Output path for the MIDI file.
    """
    duration_buckets = (0.25, 0.5, 1.0, 2.0)

    midi = pretty_midi.PrettyMIDI()
    piano = pretty_midi.Instrument(program=0)

    time = 0.0
    for i in range(0, len(seq)-1, 2):
        pitch = seq[i]
        duration_token = seq[i+1]

        # A generated sequence can hold a non-pitch id in a pitch slot; such
        # a value is not a valid MIDI note number, so the pair is skipped.
        if not 0 <= pitch < 128:
            continue

        bucket = duration_token - 128
        if 0 <= bucket < len(duration_buckets):
            duration = duration_buckets[bucket]
        else:
            duration = 0.25

        note = pretty_midi.Note(
            velocity=100,
            pitch=pitch,
            start=time,
            end=time + duration
        )
        piano.notes.append(note)
        time += duration

    midi.instruments.append(piano)
    midi.write(filename)

# Write the generated token sequence to disk as a playable MIDI file.
seq_to_midi(seq=generated, filename="generated.mid")