### Imports

In [None]:
# !pip install miditok
# !pip install symusic
# !pip install glob
# !pip install torch
# !pip install pretty_midi
# !pip install midi2audio

import pretty_midi
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from miditok.pytorch_data import DatasetMIDI, DataCollator
import glob
from miditok import REMI, TokenizerConfig
from torch.utils.data import DataLoader
from torch.amp import GradScaler, autocast


[31mERROR: Could not find a version that satisfies the requirement glob (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for glob[0m[31m


  from .autonotebook import tqdm as notebook_tqdm




### Setup NES-MDB

In [None]:
NESMDB_PATH = "./nesmdb_midi/"
midi_data = pretty_midi.PrettyMIDI(NESMDB_PATH + 'train/297_SkyKid_00_01StartMusicBGMIntroBGM.mid')

for instrument in midi_data.instruments:
  print('-' * 80)
  print(instrument.name.upper())
  print('# note events: {}'.format(len(instrument.notes)))
  print('# control change events: {}'.format(len(instrument.control_changes)))

Path to dataset files: /home/josh/.cache/kagglehub/datasets/imsparsh/lakh-midi-clean/versions/1
--------------------------------------------------------------------------------
P1
# note events: 158
# control change events: 221
--------------------------------------------------------------------------------
P2
# note events: 197
# control change events: 73
--------------------------------------------------------------------------------
TR
# note events: 123
# control change events: 0
--------------------------------------------------------------------------------
NO
# note events: 6
# control change events: 164


### Tokenizer and Datasets

In [None]:
train_files = glob.glob(NESMDB_PATH + "train/*.mid")
test_files = glob.glob(NESMDB_PATH + "test/*.mid")

config = TokenizerConfig(
    use_time_signatures=True,
    use_tempos=True,
    use_programs=True,
    num_velocities=127,
    ac_polyphony_track = True,
    ac_polyphony_bar = True,
)

tokenizer = REMI(config)

train_dataset = DatasetMIDI(
    files_paths=train_files,
    tokenizer=tokenizer,
    max_seq_len=1024,
    bos_token_id=tokenizer["BOS_None"],
    eos_token_id=tokenizer["EOS_None"],
    )
test_dataset = DatasetMIDI(
    files_paths=test_files,
    tokenizer=tokenizer,
    max_seq_len=1024,
    bos_token_id=tokenizer["BOS_None"],
    eos_token_id=tokenizer["EOS_None"],
)

  super().__init__(tokenizer_config, params)


In [5]:
input_dir = "./nesmdb_midi/train/"
collator = DataCollator(tokenizer.pad_token_id)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collator, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, collate_fn=collator, num_workers=4)

In [6]:
len(train_loader), len(test_loader)

(1126, 94)

### The Model: GRU

In [None]:
class MusicGRU(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super(MusicGRU, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.norm = nn.LayerNorm(embedding_dim)
        self.gru = nn.GRU(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.2
        )
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        x = self.norm(self.embedding(x))  
        out, hidden = self.gru(x, hidden)
        out = self.fc(out)
        return out, hidden

#### Training

In [None]:
from torch.amp import GradScaler, autocast

def train(model, train_loader, val_loader, vocab_size, num_epochs=10, lr=0.001, device='cuda'):
    model = model.to(device)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.AdamW(model.parameters(), 3e-4, weight_decay=1e-2)
    scaler = GradScaler('cuda')

    for epoch in range(num_epochs):
        # Training
        model.train()
        total_train_loss = 0

        for batch in train_loader:
            batch = batch['input_ids'].to(device)  # (batch_size, seq_length)

            inputs = batch[:, :-1]
            targets = batch[:, 1:]

            optimizer.zero_grad()

            with autocast('cuda'):
                outputs, _ = model(inputs)
                outputs = outputs.reshape(-1, vocab_size)
                targets = targets.reshape(-1)
                loss = criterion(outputs, targets)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_loader)

        # Validation
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                batch = batch['input_ids'].to(device)

                inputs = batch[:, :-1]
                targets = batch[:, 1:]

                outputs, _ = model(inputs)
                outputs = outputs.reshape(-1, vocab_size)
                targets = targets.reshape(-1)

                loss = criterion(outputs, targets)
                total_val_loss += loss.item()

        avg_val_loss = total_val_loss / len(val_loader)

        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")


if __name__ == "__main__":
    vocab_size = tokenizer.vocab_size
    embedding_dim = 128
    hidden_dim = 512
    num_layers = 2

    model = MusicGRU(vocab_size, embedding_dim, hidden_dim, num_layers)    
    train(model, train_loader, test_loader, vocab_size)

Epoch 1/10 | Train Loss: 2.4395 | Val Loss: 1.9650
Epoch 2/10 | Train Loss: 2.0247 | Val Loss: 1.8584
Epoch 3/10 | Train Loss: 1.9369 | Val Loss: 1.7829
Epoch 4/10 | Train Loss: 1.8575 | Val Loss: 1.7195
Epoch 5/10 | Train Loss: 1.7932 | Val Loss: 1.6797
Epoch 6/10 | Train Loss: 1.7512 | Val Loss: 1.6502
Epoch 7/10 | Train Loss: 1.7142 | Val Loss: 1.6344
Epoch 8/10 | Train Loss: 1.6886 | Val Loss: 1.6180
Epoch 9/10 | Train Loss: 1.6708 | Val Loss: 1.6053
Epoch 10/10 | Train Loss: 1.6510 | Val Loss: 1.5989


#### Sampling

In [76]:
def sample(model, start_token, max_length=100, temperature=0.8, device='cuda'):
    model = model.to(device)
    model.eval()

    generated = [start_token]
    input_token = torch.tensor([[start_token]], device=device)  # (1, 1)

    hidden = None

    for _ in range(max_length):
        output, hidden = model(input_token, hidden)  # output: (1, 1, vocab_size)
        output = output[:, -1, :]  # take the last output
        output = output / temperature  # adjust randomness

        probs = F.softmax(output, dim=-1)  # (1, vocab_size)
        next_token = torch.multinomial(probs, num_samples=1).item()
        generated.append(next_token)
        if next_token == 2 or next_token == 0: # reach end of sequence
          break

        input_token = torch.tensor([[next_token]], device=device)

    return generated

In [80]:
start_token = tokenizer.special_tokens_ids[1]
generated_sequence = sample(model, start_token, max_length=2048)

print("Generated token sequence:")
print(generated_sequence)

import midi2audio
from midi2audio import FluidSynth # Import library
from IPython.display import Audio, display
fs = FluidSynth("FluidR3Mono_GM.sf3") # Initialize FluidSynth

output_score = tokenizer.tokens_to_midi(generated_sequence)
print(type(output_score))
# boost all note velocities
for track in output_score.tracks:
    for note in track.notes:
        note.velocity = min(127, max(60, int(note.velocity * 2)))

output_score.dump_midi(f"rnn.mid")
fs.midi_to_audio("rnn.mid", "rnn.wav")
display(Audio("rnn.wav"))

Generated token sequence:
[1, 4, 610, 284, 453, 553, 53, 100, 221, 554, 60, 100, 221, 511, 41, 93, 220, 286, 553, 55, 100, 221, 554, 60, 100, 221, 511, 41, 93, 220, 288, 553, 249, 100, 220, 553, 29, 100, 222, 554, 41, 100, 220, 554, 143, 100, 221, 511, 41, 93, 220, 289, 511, 41, 93, 220, 290, 553, 29, 100, 221, 554, 77, 100, 221, 511, 41, 93, 220, 292, 553, 29, 100, 221, 554, 60, 100, 221, 511, 41, 93, 220, 293, 553, 29, 100, 221, 554, 70, 100, 221, 511, 36, 93, 220, 294, 553, 32, 100, 221, 554, 72, 100, 221, 511, 41, 93, 220, 295, 553, 30, 100, 221, 554, 78, 100, 221, 511, 29, 93, 220, 297, 553, 29, 100, 221, 554, 73, 100, 221, 511, 29, 93, 220, 299, 553, 29, 100, 221, 554, 78, 100, 221, 511, 41, 93, 220, 301, 553, 29, 100, 221, 554, 60, 100, 221, 511, 41, 93, 220, 303, 553, 29, 100, 221, 554, 77, 100, 220, 554, 53, 100, 221, 511, 41, 93, 220, 305, 553, 31, 100, 220, 553, 29, 100, 221, 554, 44, 100, 221, 511, 29, 93, 220, 307, 553, 36, 100, 221, 554, 60, 100, 221, 511, 36, 93, 220, 30

  output_score = tokenizer.tokens_to_midi(generated_sequence)
fluidsynth: error: fluid_is_soundfont(): fopen() failed: 'File does not exist.'
Parameter 'FluidR3Mono_GM.sf3' not a SoundFont or MIDI file or error occurred identifying it.
