# 🎶 Music Generation with Transformers

In this notebook, we'll walk through the steps required to train your own Transformer model to generate music in the style of the Bach cello suites.

In [None]:
import os
import glob
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import music21
import matplotlib.pyplot as plt
import random

## 0. Parameters <a name="parameters"></a>

In [None]:
# --- Data preparation -------------------------------------------------------
PARSE_MIDI_FILES = True
PARSED_DATA_PATH = "./parsed_data/"
DATASET_REPETITIONS = 1

# --- Model ------------------------------------------------------------------
SEQ_LEN = 50
EMBEDDING_DIM = 256
KEY_DIM = 256
# nn.MultiheadAttention splits EMBEDDING_DIM evenly across the heads, so the
# head count must divide it. The previous value (5) does not divide 256 and
# made model construction fail; 4 is the closest valid setting.
N_HEADS = 4
DROPOUT_RATE = 0.3
FEED_FORWARD_DIM = 256
LOAD_MODEL = False

# --- Training / generation --------------------------------------------------
EPOCHS = 5000
BATCH_SIZE = 256
GENERATE_LEN = 50
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every random number generator the notebook uses.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)


## 1. Prepare the Data

In [None]:
# %%
# glob returns matches in arbitrary, platform-dependent order; sort them so
# that indexing into the list (e.g. file_list[0]) is reproducible.
file_list = sorted(glob.glob("./data/bach-cello/*.mid"))
print(f"Found {len(file_list)} MIDI files")

In [None]:
# Parse the first MIDI file, apply music21's chordify() to the parsed score,
# and print the result as text.
parser = music21.converter
first_midi_path = file_list[0]
example_score = parser.parse(first_midi_path).chordify()
example_score.show("text")

In [None]:
# Placeholder corpus: 1000 sequences of 80 tokens each (a four-token pattern
# repeated 20 times). Each sequence is built separately so the entries do not
# alias one shared list object; editing one sequence must not edit them all.
notes = [["C4", "E4", "G4", "C5"] * 20 for _ in range(1000)]  # list of list of note strings
durations = [["0.25", "0.25", "0.25", "0.25"] * 20 for _ in range(1000)]  # list of list of duration strings

In [None]:
# Pick the same index from both streams and print the pair for inspection.
example_notes, example_durations = notes[658], durations[658]
print("\nNotes string\n", example_notes, "...")
print("\nDuration string\n", example_durations, "...")

## 2. Tokenize the data <a name="tokenize"></a>

In [None]:
from collections import Counter

def build_vocab(sequences, max_size=5000):
    """Build token<->index lookup tables from an iterable of token sequences.

    Index 0 is always "<pad>" and index 1 is always "<unk>"; the remaining
    indices are assigned to corpus tokens in descending frequency order.

    Args:
        sequences: Iterable of iterables of hashable tokens.
        max_size: Maximum number of corpus tokens to keep (the two special
            tokens are not counted against this limit).

    Returns:
        A ``(stoi, itos)`` tuple: ``stoi`` maps token -> index and ``itos`` is
        the list mapping index -> token.
    """
    specials = ["<pad>", "<unk>"]
    counter = Counter()
    for seq in sequences:
        counter.update(seq)
    # A corpus token that collides with a special token would otherwise be
    # listed twice in itos and overwrite the reserved index (0 or 1) in stoi.
    for token in specials:
        counter.pop(token, None)
    itos = specials + [token for token, _ in counter.most_common(max_size)]
    stoi = {token: index for index, token in enumerate(itos)}
    return stoi, itos

# Build one vocabulary per token stream; each call returns (stoi, itos).
notes_stoi, notes_itos = build_vocab(notes)
durations_stoi, durations_itos = build_vocab(durations)

def tokenize_sequence(seq, stoi, max_len):
    """Convert a token sequence into fixed-length (input, target) id lists.

    Tokens missing from ``stoi`` map to the "<unk>" id. At most
    ``max_len + 1`` ids are kept; the input is every id but the last and the
    target is every id but the first, so the target is the input shifted by
    one step. Both lists are right-padded with 0 up to ``max_len``.

    Args:
        seq: Iterable of tokens.
        stoi: Mapping token -> id; must contain "<unk>" if ``seq`` is non-empty.
        max_len: Length of each returned list.

    Returns:
        Tuple ``(x, y)`` of two lists of ints, each of length ``max_len``.
    """
    ids = []
    for symbol in seq:
        ids.append(stoi.get(symbol, stoi["<unk>"]))
    ids = ids[:max_len + 1]

    inputs, targets = ids[:-1], ids[1:]
    shortfall = max_len - len(inputs)
    if shortfall > 0:
        padding = [0] * shortfall
        inputs = inputs + padding
        targets = targets + padding
    return inputs, targets

In [None]:
class MusicDataset(Dataset):
    """Dataset pairing note sequences with duration sequences.

    Each item is tokenized on access and returned as four tensors:
    ``(x_notes, x_dur, y_notes, y_dur)``, where the ``y`` tensors are the
    targets produced by ``tokenize_sequence`` for the matching ``x`` tensors.
    """

    def __init__(self, notes, durations, notes_stoi, durations_stoi, seq_len):
        # Raw token sequences, indexed in parallel.
        self.notes, self.durations = notes, durations
        # Token -> id lookup tables for each stream.
        self.notes_stoi, self.durations_stoi = notes_stoi, durations_stoi
        # Length every tokenized sequence is truncated or padded to.
        self.seq_len = seq_len

    def __len__(self):
        # The number of items is taken from the notes list alone.
        return len(self.notes)

    def __getitem__(self, idx):
        note_inputs, note_targets = tokenize_sequence(
            self.notes[idx], self.notes_stoi, self.seq_len
        )
        dur_inputs, dur_targets = tokenize_sequence(
            self.durations[idx], self.durations_stoi, self.seq_len
        )
        return tuple(
            torch.tensor(ids)
            for ids in (note_inputs, dur_inputs, note_targets, dur_targets)
        )

# Wrap the corpus in a MusicDataset and serve it in shuffled batches of
# BATCH_SIZE items.
dataset = MusicDataset(notes, durations, notes_stoi, durations_stoi, SEQ_LEN)
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)


## 3. Causal mask <a name="create"></a>

In [None]:
def causal_mask(seq_len):
    """Return a boolean mask that is True strictly above the main diagonal.

    Args:
        seq_len: Number of rows and columns of the mask.

    Returns:
        Bool tensor of shape [seq_len, seq_len]; entry (row, col) is True
        exactly when col > row.
    """
    positions = torch.arange(seq_len)
    # Broadcasting a row of column indices against a column of row indices
    # yields the strict upper triangle.
    return positions.unsqueeze(0) > positions.unsqueeze(1)

## 6. Create a Transformer Block layer <a name="transformer"></a>

In [None]:
class TransformerBlock(nn.Module):
    """Causally masked self-attention followed by a two-layer feed-forward net.

    Each sub-layer is wrapped as ``LayerNorm(input + Dropout(sublayer(input)))``.

    Args:
        embed_dim: Size of the last dimension of the input and output.
        num_heads: Number of attention heads passed to nn.MultiheadAttention.
        ff_dim: Hidden width of the feed-forward network.
        dropout: Dropout probability used inside attention and on both
            residual branches.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout=DROPOUT_RATE):
        super().__init__()
        # Sub-modules are created in a fixed order so seeded initialisation
        # and state-dict keys stay stable.
        self.attn = nn.MultiheadAttention(
            embed_dim,
            num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim),
        )
        self.ln2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``(output, attention_weights)`` for a batch-first input ``x``."""
        # The mask is rebuilt for the current sequence length (dimension 1)
        # and moved to the input's device.
        future_mask = causal_mask(x.size(1)).to(x.device)
        attended, attn_weights = self.attn(x, x, x, attn_mask=future_mask)
        hidden = self.ln1(x + self.dropout(attended))
        projected = self.ff(hidden)
        output = self.ln2(hidden + self.dropout(projected))
        return output, attn_weights


## 7. Create the Token and Position Embedding <a name="embedder"></a>

In [None]:
class TokenAndPositionEmbedding(nn.Module):
    """Token embedding plus a learned, zero-initialised positional embedding.

    Args:
        vocab_size: Number of rows in the token embedding table.
        embed_dim: Embedding width.
        max_len: Number of positions stored in the positional table.
    """

    def __init__(self, vocab_size, embed_dim, max_len=SEQ_LEN):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        # One trainable vector per position, with a leading dimension of 1
        # so it broadcasts when added to the token embeddings.
        self.pos_emb = nn.Parameter(torch.zeros((1, max_len, embed_dim)))

    def forward(self, x):
        """Embed token ids ``x`` and add the positional vectors."""
        embedded = self.token_emb(x)
        # Use only as many positions as the embedded input has along dim 1.
        positions = self.pos_emb[:, :embedded.size(1), :]
        return embedded + positions

## 8. Build the Transformer model <a name="transformer_decoder"></a>

In [None]:
class MusicTransformer(nn.Module):
    """Single-block Transformer predicting the next note and duration.

    Notes and durations are embedded separately at ``embed_dim // 2`` each,
    concatenated on the last dimension, passed through one TransformerBlock,
    and projected by two independent linear heads.

    Args:
        notes_vocab_size: Size of the note vocabulary.
        durations_vocab_size: Size of the duration vocabulary.
        embed_dim: Width of the concatenated embedding fed to the block.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the block's feed-forward network.
    """

    def __init__(self, notes_vocab_size, durations_vocab_size, embed_dim, num_heads, ff_dim):
        super().__init__()
        half_dim = embed_dim // 2
        # Sub-modules are created in a fixed order so seeded initialisation
        # and state-dict keys stay stable.
        self.note_emb = TokenAndPositionEmbedding(notes_vocab_size, half_dim)
        self.dur_emb = TokenAndPositionEmbedding(durations_vocab_size, half_dim)
        self.transformer = TransformerBlock(embed_dim, num_heads, ff_dim)
        self.fc_notes = nn.Linear(embed_dim, notes_vocab_size)
        self.fc_dur = nn.Linear(embed_dim, durations_vocab_size)

    def forward(self, x_notes, x_dur):
        """Return ``(note_logits, duration_logits, attention_weights)``."""
        combined = torch.cat(
            (self.note_emb(x_notes), self.dur_emb(x_dur)),
            dim=-1,
        )
        hidden, attn = self.transformer(combined)
        return self.fc_notes(hidden), self.fc_dur(hidden), attn

# Instantiate the model with one output class per vocabulary entry and move
# it to the selected device.
model = MusicTransformer(len(notes_itos), len(durations_itos), EMBEDDING_DIM, N_HEADS, FEED_FORWARD_DIM).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# tokenize_sequence right-pads targets with id 0 ("<pad>"). Those positions
# carry no information, so they are excluded from the loss instead of training
# the model to predict padding.
criterion = nn.CrossEntropyLoss(ignore_index=0)

## 9. Train the Transformer <a name="train"></a>

In [None]:
def train_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    The loss of a batch is the sum of the note loss and the duration loss.

    Args:
        model: Module called as ``model(x_notes, x_dur)`` and returning
            ``(note_logits, dur_logits, extra)``; logits have the class
            dimension last.
        loader: Sized iterable of ``(x_notes, x_dur, y_notes, y_dur)`` batches.
        optimizer: Optimizer over the model's parameters.
        criterion: Loss taking ``(logits [N, C], targets [N])``.

    Returns:
        Sum of batch losses divided by ``len(loader)``.

    Raises:
        ZeroDivisionError: If ``loader`` is empty.
    """
    model.train()
    # Use the device the model lives on rather than a module-level constant,
    # so the function works for any model it is handed.
    device = next(model.parameters()).device
    total_loss = 0.0
    for batch in loader:
        x_notes, x_dur, y_notes, y_dur = (tensor.to(device) for tensor in batch)
        optimizer.zero_grad()
        note_logits, dur_logits, _ = model(x_notes, x_dur)
        # Flatten every dimension except the class dimension; the class count
        # comes from the logits themselves, not from global vocabularies.
        note_loss = criterion(
            note_logits.reshape(-1, note_logits.size(-1)), y_notes.reshape(-1)
        )
        dur_loss = criterion(
            dur_logits.reshape(-1, dur_logits.size(-1)), y_dur.reshape(-1)
        )
        loss = note_loss + dur_loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

In [None]:
# Run EPOCHS passes over train_loader, printing the value returned by
# train_epoch (the mean batch loss) after each pass.
for epoch in range(EPOCHS):
    loss = train_epoch(model, train_loader, optimizer, criterion)
    print(f"Epoch {epoch+1}/{EPOCHS} - loss: {loss:.4f}")

## 10. Generate music using the Transformer

In [None]:
def sample_from_logits(logits, temperature=1.0):
    """Draw one class index from a 1-D tensor of logits.

    Args:
        logits: 1-D tensor of unnormalised scores.
        temperature: Divisor applied to the logits before the softmax. Values
            below 1 sharpen the distribution; 0 selects the highest-scoring
            class deterministically.

    Returns:
        The sampled class index as a Python int.

    Raises:
        ValueError: If ``temperature`` is negative.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be non-negative, got {temperature}")
    if temperature == 0:
        # Dividing by zero would turn the probabilities into NaN; the limit of
        # temperature -> 0 is greedy selection.
        return torch.argmax(logits, dim=-1).item()
    probs = F.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

def _parse_quarter_length(duration_text):
    """Return ``duration_text`` as a positive float, or None if it is not one."""
    from fractions import Fraction

    try:
        value = float(Fraction(duration_text))
    except (TypeError, ValueError, ZeroDivisionError):
        return None
    return value if value > 0 else None


def _make_stream_element(note_name, duration_text):
    """Build a Rest for "START" or a Note otherwise, applying the duration."""
    if note_name == "START":
        element = music21.note.Rest()
    else:
        element = music21.note.Note(note_name)
    quarter_length = _parse_quarter_length(duration_text)
    # Unparseable or non-positive durations leave music21's default in place.
    if quarter_length is not None:
        element.duration = music21.duration.Duration(quarter_length)
    return element


def _suppress_special_tokens(logits):
    """Return logits with ids 0 ("<pad>") and 1 ("<unk>") made unsampleable."""
    if logits.size(-1) <= 2:
        # Nothing but special tokens to choose from; leave the scores alone.
        return logits
    masked = logits.clone()
    masked[:2] = float("-inf")
    return masked


def generate_music(model, start_notes, start_durations, max_tokens=50, temperature=0.5):
    """Autoregressively sample notes and durations into a music21 stream.

    Args:
        model: Module called as ``model(x_notes, x_dur)`` and returning
            ``(note_logits, dur_logits, extra)``.
        start_notes: Seed note strings; "START" becomes a rest, and strings
            missing from the note vocabulary are fed to the model as id 1.
        start_durations: Seed duration strings, paired with ``start_notes``.
        max_tokens: Number of note/duration pairs to generate.
        temperature: Sampling temperature forwarded to ``sample_from_logits``.

    Returns:
        A ``music21.stream.Stream`` holding a bass clef, the seed elements and
        the generated notes, each carrying its own duration.
    """
    model.eval()
    generated_notes = [notes_stoi.get(n, 1) for n in start_notes]
    generated_durs = [durations_stoi.get(d, 1) for d in start_durations]

    midi_stream = music21.stream.Stream()
    midi_stream.append(music21.clef.BassClef())
    for n, d in zip(start_notes, start_durations):
        midi_stream.append(_make_stream_element(n, d))

    # Inference only: skip autograd bookkeeping for every forward pass.
    with torch.no_grad():
        for _ in range(max_tokens):
            # Only the most recent SEQ_LEN tokens are fed back to the model.
            x_notes = torch.tensor([generated_notes[-SEQ_LEN:]], device=DEVICE)
            x_dur = torch.tensor([generated_durs[-SEQ_LEN:]], device=DEVICE)
            note_logits, dur_logits, _ = model(x_notes, x_dur)
            # "<pad>"/"<unk>" are not valid pitches or durations, so they are
            # removed from the candidates before sampling the last position.
            next_note = sample_from_logits(
                _suppress_special_tokens(note_logits[0, -1]), temperature
            )
            next_dur = sample_from_logits(
                _suppress_special_tokens(dur_logits[0, -1]), temperature
            )
            generated_notes.append(next_note)
            generated_durs.append(next_dur)
            midi_stream.append(
                _make_stream_element(notes_itos[next_note], durations_itos[next_dur])
            )
    return midi_stream

# %%
# Generate 50 tokens from a single "START" seed with a "0.0" duration, then
# display the returned stream.
midi_stream = generate_music(model, ["START"], ["0.0"], max_tokens=50)
midi_stream.show()