# MIDI music generation with GAN

In [1]:
!wget https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip
!unzip 'maestro-v3.0.0-midi.zip'
!rm 'maestro-v3.0.0-midi.zip'
dataset_path = "maestro-v3.0.0"

--2025-07-25 13:16:02--  https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 142.251.2.207, 142.250.141.207, 74.125.137.207, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.251.2.207|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 58416533 (56M) [application/octet-stream]
Saving to: ‘maestro-v3.0.0-midi.zip’


2025-07-25 13:16:03 (121 MB/s) - ‘maestro-v3.0.0-midi.zip’ saved [58416533/58416533]

Archive:  maestro-v3.0.0-midi.zip
  inflating: maestro-v3.0.0/2004/MIDI-Unprocessed_XP_08_R1_2004_01-02_ORIG_MID--AUDIO_08_R1_2004_01_Track01_wav.midi  
  inflating: maestro-v3.0.0/2004/MIDI-Unprocessed_XP_09_R1_2004_05_ORIG_MID--AUDIO_09_R1_2004_06_Track06_wav.midi  
  inflating: maestro-v3.0.0/2004/MIDI-Unprocessed_XP_14_R1_2004_01-03_ORIG_MID--AUDIO_14_R1_2004_01_Track01_wav.midi  
  inflating: maestro-v3.0.0/2004/MIDI-Unprocessed_XP_01_R1_2004_01

In [2]:
import os
import shutil
from pathlib import Path

In [3]:
destination_folder = "dataset_midi"
os.makedirs(destination_folder, exist_ok=True)

midi_paths = list(Path(dataset_path).resolve().glob("**/*.mid")) + \
             list(Path(dataset_path).resolve().glob("**/*.midi"))

print(f"Trovati {len(midi_paths)} file MIDI.")

for midi_path in midi_paths:
    filename = os.path.basename(midi_path)
    destination_path = os.path.join(destination_folder, filename)
    shutil.copy2(midi_path, destination_path)

print("Copia completata in 'dataset_midi'.")


Trovati 1276 file MIDI.
Copia completata in 'dataset_midi'.


## Tokenizzazione

In [5]:
!pip install miditok

Collecting miditok
  Downloading miditok-3.0.6.post1-py3-none-any.whl.metadata (10 kB)
Collecting symusic>=0.5.0 (from miditok)
  Downloading symusic-0.5.8-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (9.0 kB)
Collecting pySmartDL (from symusic>=0.5.0->miditok)
  Downloading pySmartDL-1.3.4-py3-none-any.whl.metadata (2.8 kB)
Downloading miditok-3.0.6.post1-py3-none-any.whl (159 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m159.0/159.0 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading symusic-0.5.8-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (2.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m61.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pySmartDL-1.3.4-py3-none-any.whl (20 kB)
Installing collected packages: pySmartDL, symusic, miditok
Successfully installed miditok-3.0.6.post1 pySmartDL-1.3.4 symusic-0.5.8


In [6]:
from miditok import REMI, TokenizerConfig
from pathlib import Path
import os

Setup iniziale

In [7]:
BEAT_RES = {(0, 1): 12, (1, 2): 4, (2, 4): 2, (4, 8): 1}
TOKENIZER_PARAMS = {
    "pitch_range": (21, 109),
    "beat_res": BEAT_RES,
    "num_velocities": 3,
    "special_tokens": ["BOS", "EOS"],
    "use_chords": True,
    "use_rests": True,
    "use_tempos": True,
    "num_tempos": 8,
    "tempo_range": (50, 200),
}

Carica file MIDI

In [8]:
midi_dir = Path(destination_folder)
midis = list(midi_dir.glob("**/*.mid")) + list(midi_dir.glob("**/*.midi"))

if not midis:
    raise FileNotFoundError("Nessun file MIDI trovato in 'dataset_midi'.")

Train o carica il tokenizer

In [10]:
!gdown 1Uf734gntq6RLpAvcruWuKcHqPmcZ_4O8

Downloading...
From: https://drive.google.com/uc?id=1Uf734gntq6RLpAvcruWuKcHqPmcZ_4O8
To: /content/gan_tokenizer.json
  0% 0.00/217k [00:00<?, ?B/s]100% 217k/217k [00:00<00:00, 4.30MB/s]


In [13]:
TRAIN_TOKENIZER = True
VOCAB_SIZE = 5000

if TRAIN_TOKENIZER:
    print("Training del tokenizer su MIDI...")
    config = TokenizerConfig(**TOKENIZER_PARAMS)
    tokenizer = REMI(config)
    tokenizer.train(vocab_size=VOCAB_SIZE, files_paths=midis)
    tokenizer.save_params("gan_tokenizer.json")
else:
    print("Caricamento tokenizer già addestrato...")
    tokenizer = REMI(params="gan_tokenizer.json")

print(f"Dimensione del vocabolario: {len(tokenizer)} token")

Training del tokenizer su MIDI...
Dimensione del vocabolario: 5000 token


  tokenizer.save_params("gan_tokenizer.json")


Tokenizza i file MIDI

In [14]:
output_dir = Path("tokenized_midis")
output_dir.mkdir(exist_ok=True)

for midi_path in midis:
    try:
        tokens = tokenizer(midi_path)
        tokens.save(output_dir / f"{midi_path.stem}.json")
        print(f"Tokenizzato: {midi_path.name}")
    except Exception as e:
        print(f"Errore su {midi_path.name}: {e}")

Errore su MIDI-Unprocessed_SMF_07_R1_2004_01_ORIG_MID--AUDIO_07_R1_2004_12_Track12_wav.midi: 'list' object has no attribute 'save'
Errore su ORIG-MIDI_01_7_7_13_Group__MID--AUDIO_14_R1_2013_wav--2.midi: 'list' object has no attribute 'save'
Errore su MIDI-Unprocessed_09_R1_2006_01-04_ORIG_MID--AUDIO_09_R1_2006_02_Track02_wav.midi: 'list' object has no attribute 'save'
Errore su MIDI-Unprocessed_R2_D2-12-13-15_mid--AUDIO-from_mp3_15_R2_2015_wav--1.midi: 'list' object has no attribute 'save'
Errore su MIDI-Unprocessed_Recital1-3_MID--AUDIO_02_R1_2018_wav--1.midi: 'list' object has no attribute 'save'
Errore su MIDI-Unprocessed_R1_D2-13-20_mid--AUDIO-from_mp3_17_R1_2015_wav--2.midi: 'list' object has no attribute 'save'
Errore su MIDI-Unprocessed_15_R1_2006_01-05_ORIG_MID--AUDIO_15_R1_2006_02_Track02_wav.midi: 'list' object has no attribute 'save'
Errore su MIDI-Unprocessed_10_R1_2011_MID--AUDIO_R1-D4_05_Track05_wav.midi: 'list' object has no attribute 'save'
Errore su ORIG-MIDI_02_7_7_13

KeyboardInterrupt: 

## GAN model

In [None]:
import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, noise_dim, hidden_dim, seq_length):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(noise_dim, hidden_dim * 32),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(hidden_dim * 32, hidden_dim * 16),
            nn.LayerNorm(hidden_dim * 16),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(hidden_dim * 16, hidden_dim * 8),
            nn.LayerNorm(hidden_dim * 8),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(hidden_dim * 8, hidden_dim * 4),
            nn.LayerNorm(hidden_dim * 4),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(hidden_dim * 4, hidden_dim * 2),
            nn.ReLU(),

            nn.Linear(hidden_dim * 2, seq_length),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)


Discriminator

In [None]:
class Discriminator(nn.Module):
    def __init__(self, seq_length, hidden_dim):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(seq_length, hidden_dim * 2),
            nn.LeakyReLU(0.2),

            nn.Linear(hidden_dim * 2, hidden_dim * 4),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(hidden_dim * 4, hidden_dim * 8),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(hidden_dim * 8, hidden_dim * 16),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(hidden_dim * 16, hidden_dim * 32),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(hidden_dim * 32, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)
