### Midi generation using the PyTorch

In [2]:
import os
import glob
import random
import pickle
import json
import numpy as np
import music21 as m21

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from tqdm.auto import tqdm

### For reproducibility

In [3]:
RANDOM_SEED = 42

random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

torch.cuda.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

### Configutation

In [5]:
SEQ_LEN = 32
BATCH_SIZE = 64
EMBEDDING_DIM = 256
HIDDEN_DIM = 256
LEARNING_RATE = 0.001
EPOCHS = 10 # Reduced for testing, can be increased
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"Notebook is using device: {DEVICE}")

Notebook is using device: cpu


In [6]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [7]:
# there are about 5000 midi files but just use 50 music for now
DATA_PATH = "/content/drive/MyDrive/essen-dataset/*"
midi_file_paths = glob.glob(DATA_PATH)[:50]
print(f'Total files in dataset: {len(midi_file_paths)}')

Total files in dataset: 50


### Data Processing

In [8]:
def get_musical_data(song):
    pitches = []
    durations = []
    for element in song.flat.notesAndRests:
        if isinstance(element, m21.chord.Chord):
            if element.duration.quarterLength > 4:
                continue
            sorted_pitches = sorted([n.nameWithOctave for n in element.pitches])
            pitches.append('.'.join(sorted_pitches))
            durations.append(element.duration.quarterLength)
        elif isinstance(element, m21.note.Rest):
            if element.duration.quarterLength > 4:
                continue
            pitches.append('rest')
            durations.append(element.duration.quarterLength)
        elif isinstance(element, m21.note.Note):
            if element.duration.quarterLength > 4:
                continue
            pitches.append(str(element.nameWithOctave))
            durations.append(element.duration.quarterLength)
    return pitches, durations

def normalize_key(song):
    interval = m21.interval.Interval(0)
    try:
        key = song.analyze("key")
        if key.mode == "major":
            interval = m21.interval.Interval(key.tonic, m21.pitch.Pitch("C"))
        elif key.mode == "minor":
            interval = m21.interval.Interval(key.tonic, m21.pitch.Pitch("A"))
    except:
        pass
    return song.transpose(interval)

In [9]:
dataset_pitches = []
dataset_durations = []

for file_path in tqdm(midi_file_paths):
    try:
        song = m21.converter.parse(file_path).chordify()
        song = normalize_key(song)
        pitches, durations = get_musical_data(song)
        if pitches:
            dataset_pitches.append(pitches)
            dataset_durations.append(durations)
    except Exception as e:
        print(f"Error parsing {file_path}: {e}")

  0%|          | 0/50 [00:00<?, ?it/s]

  return self.iter().getElementsByClass(classFilterList)


In [10]:
def get_unique_tokens(elements):
    token_names = sorted(set(elements))
    return token_names, len(token_names)

def create_vocabulary_mappings(token_names):
    token_to_int = {str(token): number for number, token in enumerate(token_names)}
    int_to_token = {number: str(token) for number, token in enumerate(token_names)}
    return token_to_int, int_to_token

In [11]:
# Build Vocabulary
flat_pitches = [p for song in dataset_pitches for p in song]
flat_durations = [d for song in dataset_durations for d in song]

unique_pitches, n_unique_pitches = get_unique_tokens(flat_pitches)
unique_durations, n_unique_durations = get_unique_tokens(flat_durations)

print(f"Unique Pitches: {n_unique_pitches}, Unique Durations: {n_unique_durations}")

pitch_to_int, int_to_pitch = create_vocabulary_mappings(unique_pitches)
duration_to_int, int_to_duration = create_vocabulary_mappings(unique_durations)


Unique Pitches: 33, Unique Durations: 8


In [12]:
# Save vocab mappings
with open('pt_int_to_pitch.json', 'w') as f: json.dump(int_to_pitch, f)
with open('pt_int_to_duration.json', 'w') as f: json.dump(int_to_duration, f)

In [13]:


def generate_midi_stream(composition_data):
    midi_stream = m21.stream.Stream()
    midi_stream.append(m21.instrument.Piano())

    for element in composition_data:
        pitch_pattern, duration_val = element

        if '.' in pitch_pattern:
            notes_in_chord = pitch_pattern.split('.')
            chord_notes = []
            for note_name in notes_in_chord:
                new_note = m21.note.Note(note_name)
                new_note.duration = m21.duration.Duration(quarterLength=duration_val)
                chord_notes.append(new_note)
            new_chord = m21.chord.Chord(chord_notes)
            midi_stream.append(new_chord)
        elif pitch_pattern == 'rest':
            new_rest = m21.note.Rest()
            new_rest.duration = m21.duration.Duration(quarterLength=duration_val)
            midi_stream.append(new_rest)
        else:
            new_note = m21.note.Note(pitch_pattern)
            new_note.duration = m21.duration.Duration(quarterLength=duration_val)
            midi_stream.append(new_note)

    return midi_stream

def round_duration(duration):
    standard_durations = [0.25, 0.5, 0.75, 1.0, 1.5, 2.0, 3.0, 4.0]
    return min(standard_durations, key=lambda x: abs(x - duration))

def fraction(duration_str):
    if '/' in duration_str:
        lst = duration_str.split('/')
        return int(lst[0])/int(lst[1])
    else:
        try:
            return float(duration_str)
        except:
            return 0.25

def sample_with_temp(logits, temperature):
    if temperature == 0:
        return torch.argmax(logits).item()
    logits = logits / temperature
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, 1).item()

### Dataset

In [14]:
class MidiDataset(Dataset):
    def __init__(self, pitches, durations, pitch_to_int, duration_to_int, seq_len=32):
        self.pitch_to_int = pitch_to_int
        self.duration_to_int = duration_to_int
        self.seq_len = seq_len
        self.pitch_seqs = []
        self.duration_seqs = []
        self.pitch_targets = []
        self.duration_targets = []

        for p_list, d_list in zip(pitches, durations):
            for i in range(len(p_list) - seq_len):
                p_in = [self.pitch_to_int[p] for p in p_list[i:i + seq_len]]
                p_out = self.pitch_to_int[p_list[i + seq_len]]

                d_in = [self.duration_to_int[str(d)] for d in d_list[i:i + seq_len]]
                d_out = self.duration_to_int[str(d_list[i + seq_len])]

                self.pitch_seqs.append(p_in)
                self.duration_seqs.append(d_in)
                self.pitch_targets.append(p_out)
                self.duration_targets.append(d_out)

        self.pitch_seqs = torch.tensor(self.pitch_seqs, dtype=torch.long)
        self.duration_seqs = torch.tensor(self.duration_seqs, dtype=torch.long)
        self.pitch_targets = torch.tensor(self.pitch_targets, dtype=torch.long)
        self.duration_targets = torch.tensor(self.duration_targets, dtype=torch.long)

    def __len__(self):
        return len(self.pitch_seqs)

    def __getitem__(self, idx):
        return (self.pitch_seqs[idx], self.duration_seqs[idx]), (self.pitch_targets[idx], self.duration_targets[idx])

In [15]:
# Prepare Dataset
dataset = MidiDataset(dataset_pitches, dataset_durations,
                      pitch_to_int, duration_to_int, seq_len=SEQ_LEN)

dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

### Model

In [16]:
class MidiGenerationModel(nn.Module):
    def __init__(self, n_pitches, n_durations, pitch_embed_dim=128, duration_embed_dim=32):
        super(MidiGenerationModel, self).__init__()

        self.pitch_embedding = nn.Embedding(n_pitches, pitch_embed_dim)
        self.duration_embedding = nn.Embedding(n_durations, duration_embed_dim)
        # Using Dropout1d for SpatialDropout1D behavior
        self.dropout_emb = nn.Dropout1d(0.2)

        input_dim = pitch_embed_dim + duration_embed_dim
        self.ln_emb = nn.LayerNorm(input_dim)

        self.lstm1 = nn.LSTM(input_dim, 256, batch_first=True)
        self.dropout1 = nn.Dropout(0.2)

        self.lstm2 = nn.LSTM(256, 128, batch_first=True)
        self.dropout2 = nn.Dropout(0.2)

        self.lstm3 = nn.LSTM(128, 128, batch_first=True)
        self.ln_lstm3 = nn.LayerNorm(128)

        self.dense = nn.Linear(128, 128)
        self.dropout_dense = nn.Dropout(0.2)

        self.pitch_head = nn.Linear(128, n_pitches)
        self.duration_head = nn.Linear(128, n_durations)

    def forward(self, pitch_seq, duration_seq):
        p_emb = self.pitch_embedding(pitch_seq)
        d_emb = self.duration_embedding(duration_seq)

        x = torch.cat([p_emb, d_emb], dim=2)

        # SpatialDropout1D equivalent: Drop entire channels across the sequence
        # Input to Dropout1d should be (batch, channel, seq_len)
        x = x.transpose(1, 2)
        x = self.dropout_emb(x)
        x = x.transpose(1, 2)

        x = self.ln_emb(x)

        x, _ = self.lstm1(x)
        x = self.dropout1(x)

        x, _ = self.lstm2(x)
        x = self.dropout2(x)

        _, (h_n, _) = self.lstm3(x)
        c = h_n[-1]
        c = self.ln_lstm3(c)

        c = torch.relu(self.dense(c))
        c = self.dropout_dense(c)

        pitch_out = self.pitch_head(c)
        duration_out = self.duration_head(c)

        return pitch_out, duration_out

In [17]:
# Initialize Model
model = MidiGenerationModel(n_unique_pitches, n_unique_durations).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

In [22]:
# Training Loop
print("Starting Training...")
model.train()
for epoch in range(EPOCHS):
    total_loss = 0
    for (pitch_in, duration_in), (pitch_target, duration_target) in dataloader:
        pitch_in, duration_in = pitch_in.to(DEVICE), duration_in.to(DEVICE)
        pitch_target, duration_target = pitch_target.to(DEVICE), duration_target.to(DEVICE)

        optimizer.zero_grad()
        pitch_out, duration_out = model(pitch_in, duration_in)

        loss_pitch = criterion(pitch_out, pitch_target)
        loss_duration = criterion(duration_out, duration_target)
        loss = loss_pitch + loss_duration

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {total_loss/len(dataloader):.4f}")

# Save Model
torch.save(model.state_dict(), 'midi_model.pth')
print("Model saved to midi_model.pth")

Starting Training...
Epoch 1/10, Loss: 0.3905
Epoch 2/10, Loss: 0.3414
Epoch 3/10, Loss: 0.3584
Epoch 4/10, Loss: 0.3275
Epoch 5/10, Loss: 0.3315
Epoch 6/10, Loss: 0.3274
Epoch 7/10, Loss: 0.3746
Epoch 8/10, Loss: 0.3503
Epoch 9/10, Loss: 0.3778
Epoch 10/10, Loss: 0.2474
Model saved to midi_model.pth


### Inference

In [23]:
# Generation
print("Generating Music...")
model.eval()

# Pick random seed from dataset
if len(dataset_pitches) > 0 and len(dataset_pitches[0]) > SEQ_LEN:
    song_idx = np.random.randint(len(dataset_pitches))
    start_idx = np.random.randint(len(dataset_pitches[song_idx]) - SEQ_LEN)

    seed_pitch = [pitch_to_int[p] for p in dataset_pitches[song_idx][start_idx:start_idx+SEQ_LEN]]
    seed_duration = [duration_to_int[str(d)] for d in dataset_durations[song_idx][start_idx:start_idx+SEQ_LEN]]
else:
    # Fallback random seed
    seed_pitch = [np.random.randint(n_unique_pitches) for _ in range(SEQ_LEN)]
    seed_duration = [np.random.randint(n_unique_durations) for _ in range(SEQ_LEN)]

curr_pitch_seq = seed_pitch
curr_duration_seq = seed_duration

generated_composition = []

with torch.no_grad():
    for _ in range(100): # Generate 100 notes
        p_in = torch.tensor([curr_pitch_seq], dtype=torch.long).to(DEVICE)
        d_in = torch.tensor([curr_duration_seq], dtype=torch.long).to(DEVICE)

        p_logits, d_logits = model(p_in, d_in)

        p_next = sample_with_temp(p_logits[0], 0.5)
        d_next_idx = sample_with_temp(d_logits[0], 0.5)

        # Decode
        p_str = int_to_pitch[p_next]
        d_str = int_to_duration[d_next_idx]
        d_val = round_duration(fraction(d_str))

        generated_composition.append([p_str, d_val])

        # Update sequences
        curr_pitch_seq = curr_pitch_seq[1:] + [p_next]
        curr_duration_seq = curr_duration_seq[1:] + [d_next_idx]

Generating Music...


In [24]:
# Save to MIDI
midi_stream = generate_midi_stream(generated_composition)
output_file = 'generated_pytorch.mid'
midi_stream.write('midi', fp=output_file)
print(f"Generated MIDI saved to {output_file}")

Generated MIDI saved to generated_pytorch.mid


### Download the generated music file

In [25]:
from google.colab import files
files.download(output_file)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>