
# MAESTRO Transformer Continuation

This notebook trains a compact Transformer model on the [MAESTRO](https://magenta.tensorflow.org/datasets/maestro) piano performance dataset and demonstrates how to continue a piano-roll sequence after seeding it with two seconds of music. The workflow favors simplicity and reproducibility so that the complete pipeline can run on a single-GPU machine.



## 1. Environment

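Install the packages required for MIDI handling and progress reporting. Both are available on PyPI. PyTorch is used throughout the notebook but is not installed by this cell, so it must already be present in the environment. Execute this cell once per environment.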


In [None]:

%pip install --quiet pretty_midi tqdm



## 2. Configuration

Set up paths and hyperparameters. By default the notebook expects the MAESTRO dataset to be extracted in a local `maestro/` folder. You can override this location by setting the `MAESTRO_ROOT` environment variable before launching the notebook. Because the full dataset is large (roughly 200 hours of performances), you can limit the number of files processed during experimentation with `MAX_FILES`, which is read from the `MAESTRO_MAX_FILES` environment variable and defaults to 24.
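
As a minimal sketch of the override, both environment variables can also be set from Python before the configuration cell runs. The path `/data/maestro-v3.0.0` and the limit of 4 files are placeholders chosen for illustration, not values the notebook requires.

```python
import os
from pathlib import Path

# Hypothetical location; replace with wherever the archive was extracted.
os.environ["MAESTRO_ROOT"] = "/data/maestro-v3.0.0"
# Environment variables are strings, so the file limit is passed as text.
os.environ["MAESTRO_MAX_FILES"] = "4"

# Same lookups as the configuration cell performs.
data_root = Path(os.getenv("MAESTRO_ROOT", "maestro"))
max_files = int(os.getenv("MAESTRO_MAX_FILES", "24"))
print(data_root, max_files)
```

Setting the variables in the shell before launching Jupyter has the same effect and keeps machine-specific paths out of the notebook.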


In [None]:

import os
from pathlib import Path

import torch

DATA_ROOT = Path(os.getenv("MAESTRO_ROOT", "maestro"))
MAX_FILES = int(os.getenv("MAESTRO_MAX_FILES", "24"))  # Keep the training set manageable.
SAMPLE_RATE = 100  # Time steps per second for piano-roll sampling.
SEQ_LEN = 200  # Number of frames used as model context (2 seconds at 100 fps).
LOWEST_MIDI = 21  # MIDI pitch of A0, the lowest piano key.
N_PITCHES = 88
BATCH_SIZE = 8
EPOCHS = 5
# Ask PyTorch directly: CUDA_VISIBLE_DEVICES is usually unset even when a GPU is present.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using dataset root: {DATA_ROOT}")
print(f"Accelerator: {DEVICE}")



> **Dataset download**: Download and extract the MAESTRO v3.0.0 MIDI archive from Google Cloud Storage:
>
> ```bash
> wget https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip
> unzip maestro-v3.0.0-midi.zip -d maestro
> ```
>
> The extracted folder should include subdirectories such as `2004`, `2006`, ..., `2018` along with a `maestro-v3.0.0.csv` metadata file.



## 3. Data pipeline

The MAESTRO dataset contains expressive piano performances. We convert each MIDI file into a piano-roll representation at 100 frames per second, covering the 88 keys of the piano (A0 to C8). Each frame stores whether a given key is active (`1.0`) or not (`0.0`).
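
The representation can be illustrated on a toy example. The sketch below uses a hand-built `(128, T)` velocity matrix as a stand-in for the array a MIDI library would return (the `velocities` array and its values are made up), then slices out the 88 piano keys, binarizes, and puts time on the first axis.

```python
import numpy as np

LOWEST_MIDI = 21  # A0
N_PITCHES = 88    # A0 through C8 inclusive

# Stand-in for a (128 pitches, T frames) velocity matrix; values are invented.
velocities = np.zeros((128, 5), dtype=np.float32)
velocities[60, 1:4] = 72.0  # middle C held during frames 1-3
velocities[21, 0:2] = 40.0  # A0 held during frames 0-1

# Keep the 88 piano keys, binarize, and transpose to (frames, pitches).
roll = (velocities[LOWEST_MIDI:LOWEST_MIDI + N_PITCHES] > 0).astype(np.float32).T

print(roll.shape)                 # (5, 88)
print(roll[:, 60 - LOWEST_MIDI])  # activity of middle C per frame
```

Column `p - 21` of the roll corresponds to MIDI pitch `p`, so A0 lands in column 0 and C8 in column 87.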

To keep preprocessing light, we load a configurable number of files and sample a fixed number of segments from each file for training. The result is a PyTorch `Dataset` that returns `(input, target)` pairs, where the target is the input shifted by one frame.
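
The one-frame shift is easiest to see on a tiny array. In this sketch the toy roll has 6 frames and 3 pitches, and every entry in frame `i` equals `i`, so the offset between input and target is visible directly. The sizes are arbitrary illustration values.

```python
import numpy as np

seq_len = 4
# Toy roll: 6 frames x 3 "pitches"; frame i is filled with the value i.
roll = np.repeat(np.arange(6, dtype=np.float32)[:, None], 3, axis=1)

start = 1
window = roll[start:start + seq_len + 1]  # seq_len + 1 consecutive frames
x, y = window[:-1], window[1:]            # y[t] is the frame that follows x[t]

print(x[:, 0])  # frames 1..4
print(y[:, 0])  # frames 2..5
```

Each window therefore needs `seq_len + 1` frames, which is why recordings no longer than that are skipped.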


In [None]:

import math
import random
import numpy as np
import pretty_midi
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if DEVICE == "cuda":
    torch.cuda.manual_seed_all(SEED)


def list_maestro_files(root: Path, max_files: int | None = None) -> list[Path]:
    midi_paths = sorted(root.glob("**/*.mid")) + sorted(root.glob("**/*.midi"))
    if not midi_paths:
        raise FileNotFoundError(
            f"No MIDI files were found in {root}. Check that the dataset is downloaded and extracted."
        )
    if max_files is not None:
        midi_paths = midi_paths[:max_files]
    return midi_paths


def midi_to_roll(
    path: Path,
    fs: int = SAMPLE_RATE,
    lowest: int = LOWEST_MIDI,
    n_pitches: int = N_PITCHES,
) -> np.ndarray | None:
    try:
        midi = pretty_midi.PrettyMIDI(str(path))
    except Exception as exc:  # pragma: no cover - defensive against malformed files
        print(f"Skipping {path}: {exc}")
        return None

    piano_roll = None
    for inst in midi.instruments:
        if inst.is_drum:
            continue
        roll = inst.get_piano_roll(fs=fs)
        if roll.size == 0:
            continue
        roll = roll[lowest : lowest + n_pitches]
        roll = (roll > 0).astype(np.float32)
        roll = roll.T  # Time first
        if piano_roll is None:
            piano_roll = roll
        else:
            # Combine instruments by taking the maximum activation
            if roll.shape[0] < piano_roll.shape[0]:
                pad = np.zeros((piano_roll.shape[0] - roll.shape[0], roll.shape[1]), dtype=np.float32)
                roll = np.vstack([roll, pad])
            elif piano_roll.shape[0] < roll.shape[0]:
                pad = np.zeros((roll.shape[0] - piano_roll.shape[0], piano_roll.shape[1]), dtype=np.float32)
                piano_roll = np.vstack([piano_roll, pad])
            piano_roll = np.maximum(piano_roll, roll)

    if piano_roll is None:
        return None

    # Trim leading/trailing silence
    active = np.where(piano_roll.sum(axis=1) > 0)[0]
    if active.size == 0:
        return None
    start, end = active[0], active[-1] + 1
    piano_roll = piano_roll[start:end]
    return piano_roll.astype(np.float32)


class MaestroRollDataset(Dataset):
    def __init__(
        self,
        midi_paths: list[Path],
        fs: int = SAMPLE_RATE,
        seq_len: int = SEQ_LEN,
        max_segments_per_file: int = 32,
    ) -> None:
        self.seq_len = seq_len
        self.rolls: list[np.ndarray] = []
        self.index: list[tuple[int, int]] = []

        for path in tqdm(midi_paths, desc="Preparing piano rolls"):
            roll = midi_to_roll(path, fs=fs)
            if roll is None or roll.shape[0] <= seq_len + 1:
                continue
            self.rolls.append(roll)
            total = roll.shape[0] - (seq_len + 1)
            segments = min(max_segments_per_file, total)
            for _ in range(segments):
                start = random.randint(0, total)
                self.index.append((len(self.rolls) - 1, start))

        if not self.index:
            raise RuntimeError("Dataset preparation failed: no usable segments were extracted.")

    def __len__(self) -> int:
        return len(self.index)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        file_idx, start = self.index[idx]
        roll = self.rolls[file_idx]
        window = roll[start : start + self.seq_len + 1]
        x = torch.from_numpy(window[:-1])
        y = torch.from_numpy(window[1:])
        return x, y


midi_files = list_maestro_files(DATA_ROOT, MAX_FILES)
dataset = MaestroRollDataset(midi_files)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
len(dataset), len(dataloader)



## 4. Model

The model is a lightweight Transformer encoder. Each frame (an 88-dimensional binary vector) is projected into the model dimension, combined with sinusoidal positional encodings, and processed by a stack of Transformer encoder layers. The output layer projects back into the piano-roll space. We use a sigmoid activation combined with a binary cross-entropy loss to model active notes independently for each pitch and frame.
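
For intuition about the sinusoidal positional encodings, here is a NumPy sketch of the standard formulation, where even channels hold `sin(pos / 10000^(2i/d_model))` and odd channels hold the matching cosine. It is an illustration with a small, arbitrarily chosen `d_model` (which must be even here), not a replacement for the PyTorch module used by the model.

```python
import numpy as np


def sinusoidal_encoding(max_len: int, d_model: int) -> np.ndarray:
    """Return a (max_len, d_model) table of sinusoidal position codes."""
    position = np.arange(max_len)[:, None]  # (max_len, 1)
    div_term = np.exp(np.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
    pe = np.zeros((max_len, d_model), dtype=np.float32)
    pe[:, 0::2] = np.sin(position * div_term)  # even channels
    pe[:, 1::2] = np.cos(position * div_term)  # odd channels
    return pe


pe = sinusoidal_encoding(max_len=200, d_model=8)
print(pe.shape)  # (200, 8)
print(pe[0])     # position 0: sin(0) = 0 and cos(0) = 1, alternating
```

Because the table depends only on position and channel, it is computed once and added to every projected frame sequence.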


In [None]:

import torch.nn as nn


class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 2048):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


class MaestroTransformer(nn.Module):
    def __init__(
        self,
        input_dim: int = N_PITCHES,
        d_model: int = 256,
        nhead: int = 8,
        num_layers: int = 4,
        dim_feedforward: int = 512,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.input = nn.Linear(input_dim, d_model)
        self.positional = PositionalEncoding(d_model, dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output = nn.Linear(d_model, input_dim)

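    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Causal mask: frame t attends only to frames <= t. Without it the encoder
        # can read frame t + 1 from its input, which is exactly the training target.
        steps = x.size(1)
        mask = torch.triu(torch.full((steps, steps), float("-inf"), device=x.device), diagonal=1)
        x = self.input(x)
        x = self.positional(x)
        x = self.encoder(x, mask=mask)
        return self.output(x)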


model = MaestroTransformer().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.BCEWithLogitsLoss()
model
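
Because each target frame is the input shifted by one step, a position that could attend to later frames would see its own answer. A causal (look-ahead) attention mask is the usual guard against this. The NumPy sketch below shows the idea on uniform attention scores; the helper name `causal_mask` and the 4-step sequence are illustrative assumptions.

```python
import numpy as np


def causal_mask(size: int) -> np.ndarray:
    """Additive mask: -inf strictly above the diagonal, 0 elsewhere."""
    return np.triu(np.full((size, size), -np.inf), k=1)


scores = np.zeros((4, 4))         # uniform raw attention scores
masked = scores + causal_mask(4)  # future positions become -inf
weights = np.exp(masked)          # exp(-inf) == 0
weights /= weights.sum(axis=1, keepdims=True)
print(weights)
```

Row `t` of `weights` spreads its attention evenly over positions `0..t` and assigns zero weight to everything after `t`.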



## 5. Training

The training loop optimizes the binary cross-entropy loss over the piano-roll predictions. Training logs provide per-epoch loss averages.
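
To make the loss concrete, the sketch below computes binary cross-entropy from a raw logit `x` and a target `z` using the numerically stable form `max(x, 0) - x*z + log(1 + exp(-|x|))`. It is a plain-Python illustration of the quantity being averaged over pitches and frames; the function names are hypothetical.

```python
import math


def bce_with_logits(logit: float, target: float) -> float:
    """Stable form of -[z*log(sigmoid(x)) + (1 - z)*log(1 - sigmoid(x))]."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def mean_bce(logits, targets) -> float:
    values = [bce_with_logits(x, z) for x, z in zip(logits, targets)]
    return sum(values) / len(values)


# A logit of 0 means p = 0.5, so the loss is log(2) whatever the target is.
print(round(bce_with_logits(0.0, 1.0), 4))
# Confident, correct predictions score lower than confident, wrong ones.
print(mean_bce([4.0, -4.0], [1.0, 0.0]) < mean_bce([-4.0, 4.0], [1.0, 0.0]))
```

Since most piano-roll cells are silent, a model can reach a low average loss by predicting mostly zeros, which is worth keeping in mind when reading the per-epoch numbers.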


In [None]:

from statistics import mean


def train_epoch(model: nn.Module, loader: DataLoader) -> float:
    model.train()
    losses = []
    for xb, yb in loader:
        xb = xb.to(DEVICE)
        yb = yb.to(DEVICE)
        optimizer.zero_grad()
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        losses.append(loss.item())
    return mean(losses)


def evaluate(model: nn.Module, loader: DataLoader) -> float:
    model.eval()
    losses = []
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(DEVICE)
            yb = yb.to(DEVICE)
            logits = model(xb)
            loss = criterion(logits, yb)
            losses.append(loss.item())
    return mean(losses)


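for epoch in range(1, EPOCHS + 1):
    train_loss = train_epoch(model, dataloader)
    # No held-out split exists yet, so this re-scores the training data with dropout
    # disabled. It is not a validation loss and says nothing about generalization.
    eval_loss = evaluate(model, dataloader)
    print(f"Epoch {epoch:02d} | train loss: {train_loss:.4f} | train loss (eval mode): {eval_loss:.4f}")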



## 6. Autoregressive continuation

We generate new music by seeding the model with two seconds of recorded piano-roll activations (200 frames at 100 fps) and iteratively predicting one frame at a time. The generated piano roll is converted back into a MIDI file for inspection.
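
The generation loop can be sketched independently of the trained network. In the example below, `toy_predictor` is a hypothetical stand-in for the model (it just shifts the last frame up by one pitch), and `continue_roll` shows the feedback pattern: predict one frame, append it, and keep only the most recent `context_len` frames as context.

```python
import numpy as np


def continue_roll(seed, steps, predict_next, context_len):
    """Append `steps` frames to `seed`, feeding back a sliding context window."""
    frames = [frame for frame in seed]
    for _ in range(steps):
        context = np.stack(frames[-context_len:])  # most recent frames only
        frames.append(predict_next(context))
    return np.stack(frames)


def toy_predictor(context):
    # Hypothetical stand-in for the model: shift the last frame up one pitch.
    return np.roll(context[-1], 1)


seed = np.zeros((3, 5), dtype=np.float32)
seed[:, 0] = 1.0  # lowest pitch held for three frames
out = continue_roll(seed, steps=2, predict_next=toy_predictor, context_len=3)
print(out.shape)               # (5, 5): 3 seed frames plus 2 generated
print(out[3:].argmax(axis=1))  # active pitch index of each generated frame
```

Because every generated frame becomes input for the next step, early mistakes can compound over a long continuation.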


In [None]:

from datetime import datetime


def sample_sequence(
    model: nn.Module,
    seed: torch.Tensor,
    steps: int,
    temperature: float = 1.0,
    threshold: float = 0.5,
    device: str = DEVICE,
) -> torch.Tensor:
    model.eval()
    context = seed.to(device)
    generated = [context]
    for _ in range(steps):
        with torch.no_grad():
            logits = model(context.unsqueeze(0))[0, -1:]  # Shape: (1, n_pitches)
            # Temperature rescales the logits; with the default threshold of 0.5
            # it leaves the hard decision below unchanged.
            probs = torch.sigmoid(logits / temperature)
            next_frame = (probs > threshold).float()
        generated.append(next_frame)
        # Keep at most SEQ_LEN frames of context, matching the training window.
        context = torch.cat([context, next_frame], dim=0)[-SEQ_LEN:]
    # Seed followed by the new frames: (len(seed) + steps, n_pitches), moved to the CPU.
    return torch.cat(generated, dim=0).cpu()


def roll_to_pretty_midi(roll: np.ndarray, fs: int = SAMPLE_RATE) -> pretty_midi.PrettyMIDI:
    midi = pretty_midi.PrettyMIDI()
    instrument = pretty_midi.Instrument(program=0)
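    # Pad one silent frame on both ends so that notes active in the first or
    # last frame still produce a matching onset/offset pair.
    padded_roll = np.pad(roll, ((1, 1), (0, 0)), mode="constant")
    changes = np.diff(padded_roll, axis=0)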
    for pitch in range(roll.shape[1]):
        on_indices = np.where(changes[:, pitch] == 1)[0]
        off_indices = np.where(changes[:, pitch] == -1)[0]
        for start, end in zip(on_indices, off_indices):
            note = pretty_midi.Note(
                velocity=80,
                pitch=pitch + LOWEST_MIDI,
                start=float(start) / fs,
                end=float(end) / fs,
            )
            instrument.notes.append(note)
    midi.instruments.append(instrument)
    return midi


# Choose a random seed segment from the dataset
seed_x, _ = dataset[random.randrange(len(dataset))]
seed_frames = int(2 * SAMPLE_RATE)
seed = seed_x[:seed_frames]

# Generate 8 seconds of continuation (800 frames); the returned roll starts with the seed.
continuation = sample_sequence(model, seed, steps=800)
continuation_roll = continuation.numpy()

midi_out = roll_to_pretty_midi(continuation_roll)
out_path = Path("generated_" + datetime.now().strftime("%Y%m%d_%H%M%S") + ".mid")
midi_out.write(str(out_path))
print(f"Generated continuation saved to {out_path}")
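
The conversion from a binary roll back to notes relies on frame-to-frame differences: a `+1` marks an onset and a `-1` marks an offset. The sketch below applies that idea to a single pitch column, padding a silent frame on both ends so that runs touching either boundary are still closed. The helper `note_intervals` is a hypothetical name used only for this illustration.

```python
import numpy as np


def note_intervals(column, fs):
    """Return (start, end) times in seconds for each run of ones in a 0/1 vector."""
    padded = np.pad(column.astype(np.int8), (1, 1))  # one zero on each end
    changes = np.diff(padded)
    onsets = np.where(changes == 1)[0]    # index of the first active frame
    offsets = np.where(changes == -1)[0]  # index just past the last active frame
    return [(int(s) / fs, int(e) / fs) for s, e in zip(onsets, offsets)]


# Runs at frames 0-1, 4-6, and 8 (the last one touches the end of the array).
column = np.array([1, 1, 0, 0, 1, 1, 1, 0, 1])
print(note_intervals(column, fs=100))
```

With padding on only one side, a note that is already sounding in frame 0 would have an offset but no onset, and the onset/offset pairs for that pitch would be misaligned.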



## 7. Next steps

- Increase `MAX_FILES`, `BATCH_SIZE`, or model width/depth for higher fidelity.
- Train for more epochs or introduce a validation split with early stopping.
- Replace the binary piano-roll representation with an event-based tokenizer for nuanced dynamics and pedaling control.
- Experiment with sampling strategies (top-k, nucleus sampling) for more expressive continuations.
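
For the validation-split suggestion, one option is to split at the file level so that segments from one performance never appear on both sides. The sketch below is a minimal, assumption-laden version: `split_by_file`, the 20% fraction, and the placeholder file names are all hypothetical. The MAESTRO metadata CSV also provides an official train/validation/test assignment, which may be preferable for comparable results.

```python
import random


def split_by_file(paths, val_fraction=0.2, seed=42):
    """Shuffle a copy of `paths` and hold out a fraction of whole files."""
    if len(paths) < 2:
        raise ValueError("Need at least two files to form a split.")
    shuffled = list(paths)
    random.Random(seed).shuffle(shuffled)  # local RNG; global state untouched
    n_val = max(1, int(len(shuffled) * val_fraction))
    return shuffled[n_val:], shuffled[:n_val]


files = [f"performance_{i:02d}.midi" for i in range(10)]  # placeholder names
train_files, val_files = split_by_file(files)
print(len(train_files), len(val_files))  # 8 2
```

Each list could then be passed to its own dataset and loader, so that the held-out loss is computed on performances the model has not trained on.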
