# MIDI music generation with GAN

In [1]:
!wget https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip
!unzip 'maestro-v3.0.0-midi.zip'
!rm 'maestro-v3.0.0-midi.zip'
dataset_path = "maestro-v3.0.0"

--2025-07-27 13:45:29--  https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 172.253.117.207, 142.250.99.207, 142.250.107.207, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|172.253.117.207|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 58416533 (56M) [application/octet-stream]
Saving to: ‘maestro-v3.0.0-midi.zip’


2025-07-27 13:45:30 (90.9 MB/s) - ‘maestro-v3.0.0-midi.zip’ saved [58416533/58416533]

Archive:  maestro-v3.0.0-midi.zip
  inflating: maestro-v3.0.0/2004/MIDI-Unprocessed_XP_08_R1_2004_01-02_ORIG_MID--AUDIO_08_R1_2004_01_Track01_wav.midi  
  inflating: maestro-v3.0.0/2004/MIDI-Unprocessed_XP_09_R1_2004_05_ORIG_MID--AUDIO_09_R1_2004_06_Track06_wav.midi  
  inflating: maestro-v3.0.0/2004/MIDI-Unprocessed_XP_14_R1_2004_01-03_ORIG_MID--AUDIO_14_R1_2004_01_Track01_wav.midi  
  inflating: maestro-v3.0.0/2004/MIDI-Unprocessed_XP_01_R1_20

In [2]:
import os
import shutil
from pathlib import Path

In [3]:
# Flatten every MIDI file of the extracted MAESTRO tree into one local folder.
destination_folder = "dataset_midi"
os.makedirs(destination_folder, exist_ok=True)

source_root = Path(dataset_path).resolve()
midi_paths = [*source_root.glob("**/*.mid"), *source_root.glob("**/*.midi")]

print(f"Trovati {len(midi_paths)} file MIDI.")

# NOTE(review): files are copied by basename only, so two source files sharing a
# basename would silently overwrite each other in the destination folder.
for src in midi_paths:
    shutil.copy2(src, os.path.join(destination_folder, os.path.basename(src)))

print("Copia completata in 'dataset_midi'.")

# Re-point the list at the copies, relative to the working directory.
midi_paths = [Path(destination_folder) / src.name for src in midi_paths]

Trovati 1276 file MIDI.
Copia completata in 'dataset_midi'.


## Tokenization

In [4]:
!pip install miditok

Collecting miditok
  Downloading miditok-3.0.6.post1-py3-none-any.whl.metadata (10 kB)
Collecting symusic>=0.5.0 (from miditok)
  Downloading symusic-0.5.8-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (9.0 kB)
Collecting pySmartDL (from symusic>=0.5.0->miditok)
  Downloading pySmartDL-1.3.4-py3-none-any.whl.metadata (2.8 kB)
Downloading miditok-3.0.6.post1-py3-none-any.whl (159 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m159.0/159.0 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading symusic-0.5.8-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (2.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m23.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pySmartDL-1.3.4-py3-none-any.whl (20 kB)
Installing collected packages: pySmartDL, symusic, miditok
Successfully installed miditok-3.0.6.post1 pySmartDL-1.3.4 symusic-0.5.8


In [5]:
from miditok import REMI, TokenizerConfig
from pathlib import Path
import os

### Initial setup: tokenizer configuration

In [6]:
# Time resolution of the tokenizer: each (first_beat, last_beat) range maps to the
# number of positions per beat (fine for short durations, coarse for long ones).
BEAT_RES = {
    (0, 1): 12,
    (1, 2): 4,
    (2, 4): 2,
    (4, 8): 1,
}

# REMI configuration: MIDI pitches 21..108 (A0 to C8, upper bound exclusive),
# 3 velocity bins, BOS/EOS special tokens, chord/rest/tempo tokens and 8 tempo
# bins between 50 and 200 BPM.
TOKENIZER_PARAMS = dict(
    pitch_range=(21, 109),
    beat_res=BEAT_RES,
    num_velocities=3,
    special_tokens=["BOS", "EOS"],
    use_chords=True,
    use_rests=True,
    use_tempos=True,
    num_tempos=8,
    tempo_range=(50, 200),
)

### Load the list of MIDI files

In [7]:
# List every MIDI file of the flattened dataset folder.
midi_dir = Path(destination_folder)
midis = [*midi_dir.glob("**/*.mid"), *midi_dir.glob("**/*.midi")]

# Fail fast: tokenizer training and dataset splitting need at least one file.
if len(midis) == 0:
    raise FileNotFoundError("Nessun file MIDI trovato in 'dataset_midi'.")

### Train the tokenizer (5000-token vocabulary)

In [8]:
# Build a REMI tokenizer from the configuration above and learn a 5000-token
# vocabulary on the MIDI files.
config = TokenizerConfig(**TOKENIZER_PARAMS)
tokenizer = REMI(config)

vocab_size = 5000
tokenizer.train(vocab_size=vocab_size, files_paths=midis)

# Sanity check: how many MIDI files were used for training.
processed = [Path(str(midi)) for midi in midis]
print(len(processed))

1276


### Imports for splitting and augmenting the dataset

In [9]:
from miditok.data_augmentation import augment_dataset
from miditok.utils import split_files_for_training
from random import shuffle

### Dataset preparation: train/valid/test split (about 70/15/15), chunking and augmentation

In [10]:
from pathlib import Path
from random import shuffle

from random import Random

# Seed of the train/valid/test shuffle, so every run produces the same split.
SPLIT_SEED = 42

# Work on the flattened MIDI folder created earlier (relative to the working
# directory rather than a hard-coded Colab path), using absolute paths.
base_dir = Path(destination_folder).resolve()

# Collect all MIDI files. Sort them first: glob order depends on the filesystem,
# and a seeded shuffle is only reproducible from a deterministic starting order.
midis = sorted(
    midi.resolve()
    for midi in [*base_dir.glob("**/*.mid"), *base_dir.glob("**/*.midi")]
)

# Roughly 70/15/15 split into train/valid/test.
total_num_files = len(midis)
num_files_valid = round(total_num_files * 0.15)
num_files_test = round(total_num_files * 0.15)

Random(SPLIT_SEED).shuffle(midis)

midi_paths_valid = midis[:num_files_valid]
midi_paths_test = midis[num_files_valid:num_files_valid + num_files_test]
midi_paths_train = midis[num_files_valid + num_files_test:]

# Chunk each subset into pieces of at most 1024 tokens, then augment the chunks.
for files_paths, subset_name in (
    (midi_paths_train, "train"),
    (midi_paths_valid, "valid"),
    (midi_paths_test, "test"),
):
    subset_chunks_dir = Path(f"Maestro_{subset_name}")

    # Start from an empty folder: augment_dataset is run on the whole folder, so
    # re-running this cell on top of a previous run would also augment the chunks
    # (and augmented copies) left there, growing the dataset on every re-run.
    if subset_chunks_dir.exists():
        shutil.rmtree(subset_chunks_dir)

    split_files_for_training(
        files_paths=files_paths,
        tokenizer=tokenizer,
        save_dir=subset_chunks_dir,
        max_seq_len=1024,
        num_overlap_bars=2,
    )

    augment_dataset(
        subset_chunks_dir,
        pitch_offsets=[-12, 12],
        velocity_offsets=[-4, 4],
        duration_offsets=[-0.5, 0.5],
    )

# Re-read the chunked and augmented MIDI files of each subset.
midi_paths_train = list(Path("Maestro_train").glob("**/*.mid")) + list(Path("Maestro_train").glob("**/*.midi"))
midi_paths_valid = list(Path("Maestro_valid").glob("**/*.mid")) + list(Path("Maestro_valid").glob("**/*.midi"))
midi_paths_test = list(Path("Maestro_test").glob("**/*.mid")) + list(Path("Maestro_test").glob("**/*.midi"))


Splitting music files (Maestro_train): 100%|██████████| 894/894 [00:06<00:00, 148.31it/s]
Performing data augmentation: 100%|██████████| 8733/8733 [00:18<00:00, 460.91it/s]
Splitting music files (Maestro_valid): 100%|██████████| 191/191 [00:01<00:00, 167.06it/s]
Performing data augmentation: 100%|██████████| 1805/1805 [00:03<00:00, 480.80it/s]
Splitting music files (Maestro_test): 100%|██████████| 191/191 [00:01<00:00, 177.85it/s]
Performing data augmentation: 100%|██████████| 1743/1743 [00:04<00:00, 383.54it/s]


### Tokenize each split into JSON token files

In [11]:
def midi_valid(midi) -> bool:
    """Return True when every time signature of `midi` has a numerator of 4.

    Used as the tokenization filter: files containing any time signature whose
    numerator differs from 4 (e.g. 3/4, 6/8) are skipped. A file without any
    time signature is accepted.

    Supports both symusic ``Score`` objects (``time_signatures``), which miditok
    3.x hands to validation functions, and miditoolkit-style objects
    (``time_signature_changes``).
    """
    time_signatures = getattr(midi, "time_signatures", None)
    if time_signatures is None:
        time_signatures = midi.time_signature_changes
    return all(ts.numerator == 4 for ts in time_signatures)


# Start from empty output folders so JSON files left over from a previous run
# are not picked up when the token files are read back.
for subset in ("train", "valid", "test"):
    out_dir = Path(f"tokenized_{subset}").resolve()
    if out_dir.exists():
        shutil.rmtree(out_dir)

    tokenizer.tokenize_dataset(
        Path(f"Maestro_{subset}").resolve(),
        out_dir,
        # Pass the filter by keyword: in miditok 3.x the third positional
        # parameter of tokenize_dataset is `overwrite_mode`, so a positional
        # midi_valid would not be used as the validation function.
        validation_fn=midi_valid,
    )

Tokenizing music files (content/tokenized_train): 100%|██████████| 52340/52340 [20:33<00:00, 42.45it/s]
Tokenizing music files (content/tokenized_valid): 100%|██████████| 10805/10805 [04:15<00:00, 42.36it/s]
Tokenizing music files (content/tokenized_test): 100%|██████████| 10405/10405 [04:10<00:00, 41.48it/s]


In [12]:
# Persist the trained tokenizer under "tokenizerMIDI" so it can be reloaded later
# (e.g. to decode generated ids) without retraining it.
tokenizer.save("tokenizerMIDI")

In [13]:
import json
from tqdm import tqdm

In [14]:
def read_json(path: str) -> dict:
    """Load and return the JSON document stored at `path`.

    The file is decoded explicitly as UTF-8 (the encoding JSON files use) rather
    than with the platform's locale-dependent default encoding.

    Raises:
        FileNotFoundError: if `path` does not exist.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def read_json_files(json_file_paths):
    """Read every JSON file in `json_file_paths`, in order, showing a progress bar.

    Returns the list of parsed objects. As soon as one file is missing or is not
    valid JSON, an error naming that file is printed and an empty list is
    returned instead (callers treat an empty result as failure).
    """
    loaded = []
    for file_path in tqdm(json_file_paths):
        try:
            parsed = read_json(file_path)
        except FileNotFoundError:
            print(f"Error: File not found - {file_path}")
            return []
        except json.JSONDecodeError:
            print(f"Error decoding JSON in file: {file_path}")
            return []
        loaded.append(parsed)
    return loaded

# Read the tokenized JSON files of each split. The paths are sorted because glob
# order depends on the filesystem and the token sequences are concatenated in
# this order afterwards; sorting makes the resulting stream reproducible.
tokenized_train = sorted(Path("tokenized_train").resolve().glob("**/*.json"))
data_objects_train = read_json_files(tokenized_train)

tokenized_valid = sorted(Path("tokenized_valid").resolve().glob("**/*.json"))
data_objects_valid = read_json_files(tokenized_valid)

tokenized_test = sorted(Path("tokenized_test").resolve().glob("**/*.json"))
data_objects_test = read_json_files(tokenized_test)


# read_json_files returns [] on any error, so an empty split means failure.
if not all((data_objects_train, data_objects_valid, data_objects_test)):
    print("Error reading JSON files.")
else:
    print(f"\nSuccessfully read {len(data_objects_train)} training JSON files.")
    print(f"Successfully read {len(data_objects_valid)} validation JSON files.")
    print(f"Successfully read {len(data_objects_test)} test JSON files.")

100%|██████████| 52340/52340 [00:10<00:00, 5201.88it/s]
100%|██████████| 10805/10805 [00:01<00:00, 7023.60it/s]
100%|██████████| 10405/10405 [00:01<00:00, 6568.69it/s]


Successfully read 52340 training JSON files.
Successfully read 10805 validation JSON files.
Successfully read 10405 test JSON files.





In [None]:
import numpy as np

In [None]:
# Keep the first token sequence (ids[0]) of each tokenized file; presumably the
# single piano track of each MAESTRO chunk - TODO confirm for multi-track input.
encoded_train = [np.array(song["ids"][0]) for song in data_objects_train]
encoded_valid = [np.array(song["ids"][0]) for song in data_objects_valid]

# One long int32 token stream per split.
all_ids_train = np.concatenate(encoded_train).astype(np.int32)
all_ids_valid = np.concatenate(encoded_valid).astype(np.int32)

## GAN model

In [None]:
import torch
import torch.nn as nn

class Generator(nn.Module):
    """MLP generator mapping a noise vector to `seq_length` values in [-1, 1].

    Hidden widths shrink from 32*hidden_dim to 2*hidden_dim. The Tanh output
    matches the [-1, 1] range of the normalized token ids used for training.
    """

    def __init__(self, noise_dim, hidden_dim, seq_length):
        super().__init__()
        widths = [hidden_dim * factor for factor in (32, 16, 8, 4, 2)]

        # Input block: no LayerNorm.
        layers = [nn.Linear(noise_dim, widths[0]), nn.ReLU(), nn.Dropout(0.3)]
        # Three normalized blocks: Linear -> LayerNorm -> ReLU -> Dropout.
        for in_width, out_width in zip(widths[:3], widths[1:4]):
            layers += [
                nn.Linear(in_width, out_width),
                nn.LayerNorm(out_width),
                nn.ReLU(),
                nn.Dropout(0.3),
            ]
        # Last hidden block has neither LayerNorm nor Dropout.
        layers += [nn.Linear(widths[3], widths[4]), nn.ReLU()]
        # Projection to the sequence, squashed into [-1, 1].
        layers += [nn.Linear(widths[4], seq_length), nn.Tanh()]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


### Discriminator

In [None]:
class Discriminator(nn.Module):
    """MLP discriminator scoring a `seq_length`-long sequence as real or generated.

    By default ``forward`` returns raw logits. The GAN wrapper's default loss is
    ``nn.BCEWithLogitsLoss``, which applies the sigmoid itself; ending the
    network with ``nn.Sigmoid`` as well would squash the scores twice and
    flatten the gradients.

    Args:
        seq_length: length of the input sequences.
        hidden_dim: base width; layers grow from 2*hidden_dim to 32*hidden_dim.
        apply_sigmoid: if True, ``forward`` returns probabilities in (0, 1)
            instead of logits (e.g. for inspection or use with ``nn.BCELoss``).
    """

    def __init__(self, seq_length, hidden_dim, apply_sigmoid=False):
        super(Discriminator, self).__init__()
        self.apply_sigmoid = apply_sigmoid
        self.model = nn.Sequential(
            nn.Linear(seq_length, hidden_dim * 2),
            nn.LeakyReLU(0.2),

            nn.Linear(hidden_dim * 2, hidden_dim * 4),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(hidden_dim * 4, hidden_dim * 8),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(hidden_dim * 8, hidden_dim * 16),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(hidden_dim * 16, hidden_dim * 32),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            # Final score as a logit. The sigmoid is applied in forward() only on
            # request, so state_dict keys stay the same as before.
            nn.Linear(hidden_dim * 32, 1),
        )

    def forward(self, x):
        logits = self.model(x)
        return torch.sigmoid(logits) if self.apply_sigmoid else logits


In [None]:
# Hyper-parameters: the models work on windows of `seq_length` normalized token ids.
vocab_size = len(tokenizer)
seq_length = 512
noise_dim = 512
hidden_dim = 1536
batch_size = 128

# NOTE(review): with hidden_dim=1536 the widest layers have 32*1536 = 49152 units.
# The 49152x24576 Linear (generator) and the 24576x49152 Linear (discriminator)
# each hold ~1.2e9 weights; with gradients and Adam state this needs tens of GB
# of memory. Consider a much smaller hidden_dim.
generator_config = dict(
    noise_dim=noise_dim,
    hidden_dim=hidden_dim,
    seq_length=seq_length,
)

discriminator_config = dict(
    seq_length=seq_length,
    hidden_dim=hidden_dim,
)

generator = Generator(**generator_config)
discriminator = Discriminator(**discriminator_config)

gan_params = dict(
    generator=generator,
    discriminator=discriminator,
    noise_dim=noise_dim,
    seq_length=seq_length,
    tokenizer=tokenizer,
)

NameError: name 'tokenizer' is not defined

In [15]:
!zip -r tokenized_train.zip tokenized_train
!zip -r tokenized_valid.zip tokenized_valid
!zip -r tokenized_test.zip tokenized_test

[1;30;43mOutput streaming troncato alle ultime 5000 righe.[0m
  adding: tokenized_test/MIDI-Unprocessed_08_R2_2009_01_ORIG_MID--AUDIO_08_R2_2009_08_R2_2009_02_WAV_4#d192.json (deflated 65%)
  adding: tokenized_test/MIDI-Unprocessed_24_R1_2006_01-05_ORIG_MID--AUDIO_24_R1_2006_05_Track05_wav_0.json (deflated 63%)
  adding: tokenized_test/MIDI-UNPROCESSED_04-07-08-10-12-15-17_R2_2014_MID--AUDIO_08_R2_2014_wav_24#v4.json (deflated 63%)
  adding: tokenized_test/MIDI-Unprocessed_057_PIANO057_MID--AUDIO-split_07-07-17_Piano-e_1-07_wav--1_2#d-240.json (deflated 65%)
  adding: tokenized_test/MIDI-Unprocessed_10_R1_2006_01-04_ORIG_MID--AUDIO_10_R1_2006_03_Track03_wav_15#v4.json (deflated 63%)
  adding: tokenized_test/MIDI-Unprocessed_09_R3_2008_01-07_ORIG_MID--AUDIO_09_R3_2008_wav--3_0#d192.json (deflated 68%)
  adding: tokenized_test/MIDI-Unprocessed_10_R1_2011_MID--AUDIO_R1-D4_05_Track05_wav_8.json (deflated 60%)
  adding: tokenized_test/MIDI-UNPROCESSED_04-05_R1_2014_MID--AUDIO_05_R1_2014_w

## GAN training wrapper

In [None]:
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm

In [None]:

class GAN:
    """Training and generation wrapper around a generator and a discriminator.

    The generator maps noise to `seq_length` values in [-1, 1]. These are read as
    normalized token ids and mapped back to ids of `tokenizer` in `generate`.

    Args:
        generator: module mapping (batch, noise_dim) noise to one sequence per row.
        discriminator: module scoring sequences. With the default loss its
            outputs must be raw logits (no final sigmoid).
        noise_dim: size of the noise vector fed to the generator.
        seq_length: number of tokens per sequence.
        tokenizer: tokenizer whose ``len()`` is the vocabulary size and whose
            ``decode`` turns ids back into a music object.
        loss_fn: adversarial loss called as ``loss_fn(predictions, targets)``;
            defaults to ``nn.BCEWithLogitsLoss()``.
        lr, betas: Adam hyper-parameters shared by both optimizers.

    Note: the constructor re-initializes the weights of both networks in place
    (Xavier-uniform weights and zero biases for Linear/Conv1d layers).
    """

    def __init__(
        self,
        generator,
        discriminator,
        noise_dim,
        seq_length,
        tokenizer,
        loss_fn=None,
        lr=5e-5,
        betas=(0.5, 0.999)
    ):
        self.generator = generator
        self.discriminator = discriminator
        self.noise_dim = noise_dim
        self.seq_length = seq_length
        self.tokenizer = tokenizer
        self.vocab_size = len(tokenizer)

        # BCEWithLogitsLoss applies the sigmoid itself, so the discriminator must
        # output logits when this default is used.
        self.loss_fn = loss_fn if loss_fn is not None else nn.BCEWithLogitsLoss()

        self.generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=betas)
        self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=betas)

        # In-place re-initialization; the optimizers keep referencing the same tensors.
        self._initialize_weights()

    @staticmethod
    def _init_weights(m):
        """Xavier-uniform weights and zero biases for Linear/Conv1d modules."""
        if isinstance(m, (nn.Linear, nn.Conv1d)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def _initialize_weights(self):
        """Apply `_init_weights` to every sub-module of both networks."""
        self.generator.apply(self._init_weights)
        self.discriminator.apply(self._init_weights)

    def noise(self, batch_size, device):
        """Standard-normal noise of shape (batch_size, noise_dim) on `device`."""
        return torch.randn(batch_size, self.noise_dim, device=device)

    def _generator_step(self, batch_size, device):
        """Run one generator update and return its loss as a detached tensor.

        Uses the non-saturating objective: the generator is trained so that the
        discriminator classifies its samples as real, with the same `loss_fn`
        the discriminator is trained with. A raw ``-mean(predictions)`` is
        unbounded and does not match a BCE-trained discriminator.
        The update is skipped (and a warning printed) if the loss is NaN/Inf.
        """
        self.generator.zero_grad()

        fake_data = self.generator(self.noise(batch_size, device))
        predictions = self.discriminator(fake_data)
        generator_loss = self.loss_fn(predictions, torch.ones_like(predictions))

        if torch.isfinite(generator_loss):
            generator_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.generator.parameters(), 1.0)
            self.generator_optimizer.step()
        else:
            print("[!] NaN or Inf detected in generator loss")

        # Detach so the autograd graph is not kept alive by the caller.
        return generator_loss.detach()

    def _discriminator_step(self, batch_size, real_data, device):
        """Run one discriminator update on `real_data` and a fresh fake batch.

        Real samples are labelled 1 and generated samples 0. Returns the summed
        real + fake loss as a detached tensor. The update is skipped (and a
        warning printed) if the loss is NaN/Inf.
        """
        self.discriminator.zero_grad()

        # The generator is not updated here, so its output is detached.
        fake_data = self.generator(self.noise(batch_size, device)).detach()

        real_preds = self.discriminator(real_data)
        fake_preds = self.discriminator(fake_data)

        loss_real = self.loss_fn(real_preds, torch.ones_like(real_preds))
        loss_fake = self.loss_fn(fake_preds, torch.zeros_like(fake_preds))
        discriminator_loss = loss_real + loss_fake

        if torch.isfinite(discriminator_loss):
            discriminator_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), 1.0)
            self.discriminator_optimizer.step()
        else:
            print("[!] NaN or Inf detected in discriminator loss")

        return discriminator_loss.detach()

    def _train_step(self, batch_size, real_data, device, gen_steps=1, disc_steps=1):
        """Run `disc_steps` discriminator updates, then `gen_steps` generator updates.

        Returns the mean discriminator and generator losses as Python floats.
        """
        disc_losses = [self._discriminator_step(batch_size, real_data, device) for _ in range(disc_steps)]
        gen_losses = [self._generator_step(batch_size, device) for _ in range(gen_steps)]

        avg_disc_loss = torch.stack(disc_losses).mean().item()
        avg_gen_loss = torch.stack(gen_losses).mean().item()
        return avg_disc_loss, avg_gen_loss

    def train(
        self,
        dataloader,
        epochs,
        device,
        loss_delta=0.7,
        steps_each_print=5,
        advantage_steps=2,
        alternate_training=False,
        gen_steps=1,
        disc_steps=1
    ):
        """Train both networks for `epochs` passes over `dataloader`.

        Args:
            dataloader: iterable of real batches (tensors with the batch first).
            epochs: number of passes over `dataloader`.
            device: device the networks and batches are moved to.
            loss_delta: loss gap that triggers re-balancing (alternate training).
            steps_each_print: refresh the progress-bar losses every N batches.
            advantage_steps: updates given to the lagging network when re-balancing.
            alternate_training: when True, after each batch the network whose
                loss exceeds the other's by more than `loss_delta` gets
                `advantage_steps` updates and the other 1; otherwise both get 1.
            gen_steps, disc_steps: initial updates per batch for each network.

        Returns:
            (discriminator_loss_history, generator_loss_history): one float per batch.
        """
        print(f"Starting training for {epochs} epochs...")

        self.generator.to(device).train()
        self.discriminator.to(device).train()

        discriminator_loss_history = []
        generator_loss_history = []

        for epoch in range(epochs):
            # Stay None for an empty dataloader, so the epoch report is skipped
            # instead of failing on undefined names.
            disc_loss = gen_loss = None
            pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
            for real_data in pbar:
                real_data = real_data.to(device)
                batch_size = real_data.size(0)

                disc_loss, gen_loss = self._train_step(batch_size, real_data, device, gen_steps, disc_steps)

                discriminator_loss_history.append(disc_loss)
                generator_loss_history.append(gen_loss)

                if alternate_training:
                    # Give extra updates to whichever network is behind.
                    if gen_loss - disc_loss > loss_delta:
                        gen_steps, disc_steps = advantage_steps, 1
                    elif disc_loss - gen_loss > loss_delta:
                        disc_steps, gen_steps = advantage_steps, 1
                    else:
                        gen_steps = disc_steps = 1

                if len(discriminator_loss_history) % steps_each_print == 0:
                    pbar.set_description(f"D Loss: {disc_loss:.4f}, G Loss: {gen_loss:.4f}")

            if disc_loss is not None:
                print(f"Epoch {epoch+1} - D Loss: {disc_loss:.4f}, G Loss: {gen_loss:.4f}")

        return discriminator_loss_history, generator_loss_history

    def save_generator(self, path):
        """Save the generator's state_dict to `path`."""
        torch.save(self.generator.state_dict(), path)
        print(f"Generator saved to {path}")

    def save_discriminator(self, path):
        """Save the discriminator's state_dict to `path`."""
        torch.save(self.discriminator.state_dict(), path)
        print(f"Discriminator saved to {path}")

    def predict(self, n_samples, device):
        """Return `n_samples` generator outputs as a NumPy array (one row per sample).

        Sampling runs in eval mode (dropout disabled) without building an
        autograd graph; the generator's previous train/eval mode is restored.
        """
        was_training = self.generator.training
        self.generator.eval()
        try:
            with torch.no_grad():
                z = self.noise(n_samples, device)
                return self.generator(z).cpu().numpy()
        finally:
            self.generator.train(was_training)

    def generate(self, n_samples, base_output_name, device):
        """Sample sequences, map them to token ids and write one MIDI file each.

        Outputs in [-1, 1] are mapped back with ``x * vocab_size/2 + vocab_size/2``
        (the inverse of the normalization applied to the training ids), then
        rounded and clipped to valid ids. Files are written as
        ``f"{base_output_name}_{i}.mid"``. Decoding errors are printed and do not
        stop the remaining samples.
        """
        half_vocab = self.vocab_size / 2
        predictions = self.predict(n_samples, device) * half_vocab + half_vocab

        for i, pred in enumerate(predictions):
            tokens = np.clip(np.round(pred).astype(int), 0, self.vocab_size - 1)
            try:
                # Plain Python ints: a list holding one token sequence.
                decoded = self.tokenizer.decode([tokens.tolist()])
                if hasattr(decoded, 'dump_midi'):
                    decoded.dump_midi(f"{base_output_name}_{i}.mid")
                else:
                    print(f"[!] Output {i} is not a MIDI-compatible object.")
            except Exception as e:
                print(f"[!] Error generating MIDI {i}: {e}")


In [None]:
from torch.utils.data import Dataset


class TokenDataset(Dataset):
    """Map-style dataset over pre-cut windows of normalized token ids.

    `data` is an indexable collection of windows. Each item is returned as a
    float32 tensor, the input dtype of the Linear layers of the discriminator.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return torch.tensor(self.data[idx], dtype=torch.float32)

# Vocabulary size of the trained tokenizer; used below to scale token ids into [-1, 1].
vocab_size = len(tokenizer)

def normalize_data(data, vocab_size, interval=(-1, 1)):
    """Linearly rescale token ids from [0, vocab_size] onto `interval`.

    The ids are first centred and scaled to [-1, 1] via
    (id - vocab_size/2) / (vocab_size/2), then mapped affinely onto
    [low, high]. Works element-wise on NumPy arrays as well as on scalars.
    """
    low, high = interval
    half = vocab_size / 2
    unit = (data - half) / half
    return unit * (high - low) / 2 + (high + low) / 2

from torch.utils.data import DataLoader

normalized_seq = normalize_data(all_ids_train, vocab_size)
print(f"Max is {normalized_seq.max()}, Min is {normalized_seq.min()}")

# Cut the training stream into non-overlapping windows of seq_length tokens.
# The range stops at len - seq_length + 1 so the last full window is kept
# (a stop of len - seq_length drops it whenever it ends exactly at the stream end).
all_ids_train_seq = [
    normalized_seq[start:start + seq_length]
    for start in range(0, len(normalized_seq) - seq_length + 1, seq_length)
]

dataset = TokenDataset(all_ids_train_seq)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)

# Use the GPU when present; CUDA-only calls are guarded so the CPU fallback works.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    torch.cuda.set_device(0)
    _ = torch.ones(1, device=device)  # touch the device once to initialise the CUDA context

In [None]:
# Set to False to skip the (long) training run, e.g. when only generating.
TRAIN = True

# Wrap the networks built from gan_params; note that GAN(...) re-initializes their weights.
gan = GAN(**gan_params)

training_arguments = {
    "dataloader"         : dataloader,
    "epochs"             : 50,
    "device"             : device,
    "steps_each_print"   : 10,
    "gen_steps"          : 1,
    "disc_steps"         : 1,
}
if TRAIN:
    losses_discriminator, losses_generator = gan.train(**training_arguments)