# MIDI music generation with GAN

In [None]:
!wget https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip
!unzip 'maestro-v3.0.0-midi.zip'
!rm 'maestro-v3.0.0-midi.zip'
dataset_path = "maestro-v3.0.0"

--2025-08-11 08:54:29--  https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 74.125.130.207, 74.125.68.207, 142.250.4.207, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|74.125.130.207|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 58416533 (56M) [application/octet-stream]
Saving to: ‘maestro-v3.0.0-midi.zip’


2025-08-11 08:54:33 (15.5 MB/s) - ‘maestro-v3.0.0-midi.zip’ saved [58416533/58416533]

Archive:  maestro-v3.0.0-midi.zip
  inflating: maestro-v3.0.0/2004/MIDI-Unprocessed_XP_08_R1_2004_01-02_ORIG_MID--AUDIO_08_R1_2004_01_Track01_wav.midi  
  inflating: maestro-v3.0.0/2004/MIDI-Unprocessed_XP_09_R1_2004_05_ORIG_MID--AUDIO_09_R1_2004_06_Track06_wav.midi  
  inflating: maestro-v3.0.0/2004/MIDI-Unprocessed_XP_14_R1_2004_01-03_ORIG_MID--AUDIO_14_R1_2004_01_Track01_wav.midi  
  inflating: maestro-v3.0.0/2004/MIDI-Unprocessed_XP_01_R1_2004_01

In [1]:
!pip install miditok

Collecting miditok
  Downloading miditok-3.0.6.post1-py3-none-any.whl.metadata (10 kB)
Collecting symusic>=0.5.0 (from miditok)
  Downloading symusic-0.5.8-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (9.0 kB)
Collecting pySmartDL (from symusic>=0.5.0->miditok)
  Downloading pySmartDL-1.3.4-py3-none-any.whl.metadata (2.8 kB)
Downloading miditok-3.0.6.post1-py3-none-any.whl (159 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m159.0/159.0 kB[0m [31m11.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading symusic-0.5.8-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (2.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m50.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pySmartDL-1.3.4-py3-none-any.whl (20 kB)
Installing collected packages: pySmartDL, symusic, miditok
Successfully installed miditok-3.0.6.post1 pySmartDL-1.3.4 symusic-0.5.8


In [2]:
!pip install -q gdown

In [3]:
import json
import os
import random
import shutil
from pathlib import Path
from random import shuffle

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from IPython.display import Audio, display
from miditok import REMI, TokenizerConfig
from miditok.data_augmentation import augment_dataset
from miditok.utils import split_files_for_training
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [None]:
# Flatten every MIDI file found under `dataset_path` into a single folder.
destination_folder = "dataset_midi"
os.makedirs(destination_folder, exist_ok=True)

source_root = Path(dataset_path).resolve()
midi_paths = [*source_root.glob("**/*.mid"), *source_root.glob("**/*.midi")]

print(f"Trovati {len(midi_paths)} file MIDI.")

# Files are copied by base name only, so the source sub-folders are dropped.
for source_file in midi_paths:
    shutil.copy2(source_file, os.path.join(destination_folder, source_file.name))

print("Copia completata in 'dataset_midi'.")

# From here on, refer to the copies through paths relative to the working directory.
midi_paths = [Path(destination_folder, source_file.name) for source_file in midi_paths]

Trovati 1276 file MIDI.
Copia completata in 'dataset_midi'.


## Tokenizzazione

Setup iniziale

In [None]:
# Mapping handed to the tokenizer configuration under the "beat_res" key.
# NOTE(review): presumably {(start_beat, end_beat): samples_per_beat}, as in
# miditok's TokenizerConfig -- confirm against the miditok documentation.
BEAT_RES = {(0, 1): 12, (1, 2): 4, (2, 4): 2, (4, 8): 1}
# Keyword arguments that a later cell expands into TokenizerConfig(**TOKENIZER_PARAMS).
TOKENIZER_PARAMS = {
    "pitch_range": (21, 109),
    "beat_res": BEAT_RES,
    "num_velocities": 3,
    "special_tokens": ["BOS", "EOS"],
    "use_chords": True,
    "use_rests": True,
    "use_tempos": True,
    "num_tempos": 8,
    "tempo_range": (50, 200),
}

Carica file MIDI

In [None]:
# Gather every MIDI file (both extensions) below the flattened dataset folder.
midi_dir = Path(destination_folder)
midis = [*midi_dir.glob("**/*.mid"), *midi_dir.glob("**/*.midi")]

# Fail early: the tokenizer cannot be trained on an empty corpus.
if len(midis) == 0:
    raise FileNotFoundError("Nessun file MIDI trovato in 'dataset_midi'.")

Addestra il tokenizer

In [None]:
# Target vocabulary size requested from the tokenizer training step.
vocab_size = 5000

# Build a REMI tokenizer from the shared parameters and train it on the corpus.
config = TokenizerConfig(**TOKENIZER_PARAMS)
tokenizer = REMI(config)
tokenizer.train(vocab_size=vocab_size, files_paths=midis)

# Keep a Path copy of the file list and report how many files were used.
processed = [Path(str(midi_file)) for midi_file in midis]
print(len(processed))

1276


In [None]:
# Persist the trained tokenizer under the path "tokenizerMIDI" so it can be reloaded later.
# NOTE(review): whether this produces a file or a directory depends on miditok's `save` -- confirm.
tokenizer.save("tokenizerMIDI")

Tokenizza i file MIDI

dataset preparation

In [None]:
# Seed for the train/valid/test shuffle, so the split is identical on every run.
SPLIT_SEED = 42

# Resolve the flattened dataset folder relative to the working directory
# (no machine-specific absolute path).
base_dir = Path("dataset_midi").resolve()

# Collect all MIDI files as absolute paths. Sorting removes the dependency on
# filesystem enumeration order, which would otherwise defeat the fixed seed.
midis = sorted(
    midi.resolve()
    for midi in [*base_dir.glob("**/*.mid"), *base_dir.glob("**/*.midi")]
)

# 70 / 15 / 15 split into train / valid / test.
total_num_files = len(midis)
num_files_valid = round(total_num_files * 0.15)
num_files_test = round(total_num_files * 0.15)

random.Random(SPLIT_SEED).shuffle(midis)

midi_paths_valid = midis[:num_files_valid]
midi_paths_test = midis[num_files_valid:num_files_valid + num_files_test]
midi_paths_train = midis[num_files_valid + num_files_test:]

# Chunk and augment. Only train and valid are processed here; the test files
# are kept aside untouched in `midi_paths_test`.
for files_paths, subset_name in (
    (midi_paths_train, "train"),
    (midi_paths_valid, "valid"),
):
    subset_chunks_dir = Path(f"Maestro_{subset_name}")

    split_files_for_training(
        files_paths=files_paths,
        tokenizer=tokenizer,
        save_dir=subset_chunks_dir,
        max_seq_len=1024,
        num_overlap_bars=2,
    )

    augment_dataset(
        subset_chunks_dir,
        pitch_offsets=[-12, 12],
        velocity_offsets=[-3, 3],
        duration_offsets=[-0.5, 0.5],
    )

# Re-read the chunked and augmented files of each subset.
midi_paths_train = list(Path("Maestro_train").glob("**/*.mid")) + list(Path("Maestro_train").glob("**/*.midi"))
midi_paths_valid = list(Path("Maestro_valid").glob("**/*.mid")) + list(Path("Maestro_valid").glob("**/*.midi"))


Splitting music files (Maestro_train): 100%|██████████| 894/894 [00:05<00:00, 161.22it/s]
Performing data augmentation: 100%|██████████| 8549/8549 [00:21<00:00, 400.39it/s]
Splitting music files (Maestro_valid): 100%|██████████| 191/191 [00:02<00:00, 91.50it/s]
Performing data augmentation: 100%|██████████| 1852/1852 [00:04<00:00, 406.03it/s]


tokenization

In [None]:
def midi_valid(midi) -> bool:
    """Return True when every time signature of `midi` has numerator 4.

    Works with objects exposing `time_signatures` (symusic scores, which
    miditok 3.x hands to validation callbacks) and falls back to the legacy
    `time_signature_changes` attribute. A file without any time signature is
    considered valid.
    """
    time_signatures = getattr(midi, "time_signatures", None)
    if time_signatures is None:
        time_signatures = midi.time_signature_changes
    return all(ts.numerator == 4 for ts in time_signatures)


for subset_name in ("train", "valid"):
    tokens_dir = Path(f"tokenized_{subset_name}").resolve()

    # Start from a clean output folder so token files left over from a
    # previous run (different chunks or augmentations) do not leak in.
    if tokens_dir.exists():
        shutil.rmtree(tokens_dir)

    # The callback must be passed by keyword: the third positional parameter
    # of tokenize_dataset is `overwrite_mode`, so a positional function would
    # be read as a truthy flag and never called.
    tokenizer.tokenize_dataset(
        Path(f"Maestro_{subset_name}").resolve(),
        tokens_dir,
        validation_fn=midi_valid,
    )

Tokenizing music files (content/tokenized_train): 100%|██████████| 51073/51073 [21:40<00:00, 39.28it/s]
Tokenizing music files (content/tokenized_valid): 100%|██████████| 11162/11162 [04:41<00:00, 39.64it/s]


In [None]:
!zip -r tokenized_train.zip tokenized_train
!zip -r tokenized_valid.zip tokenized_valid

[1;30;43mOutput streaming troncato alle ultime 5000 righe.[0m
  adding: tokenized_valid/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav_7#d240.json (deflated 70%)
  adding: tokenized_valid/MIDI-Unprocessed_01_R1_2006_01-09_ORIG_MID--AUDIO_01_R1_2006_04_Track04_wav_0#v3.json (deflated 63%)
  adding: tokenized_valid/MIDI-Unprocessed_Schubert1-3_MID--AUDIO_02_R2_2018_wav_2#p12.json (deflated 65%)
  adding: tokenized_valid/ORIG-MIDI_02_7_7_13_Group__MID--AUDIO_15_R1_2013_wav--4_5#p12.json (deflated 65%)
  adding: tokenized_valid/MIDI-Unprocessed_16_R2_2009_01_ORIG_MID--AUDIO_16_R2_2009_16_R2_2009_01_WAV_4#v-3.json (deflated 64%)
  adding: tokenized_valid/MIDI-Unprocessed_XP_17_R2_2004_01_ORIG_MID--AUDIO_17_R2_2004_04_Track04_wav_8#d-240.json (deflated 65%)
  adding: tokenized_valid/MIDI-Unprocessed_046_PIANO046_MID--AUDIO-split_07-06-17_Piano-e_2-02_wav--1_3#v-3.json (deflated 62%)
  adding: tokenized_valid/ORIG-MIDI_01_7_8_13_Group__MID--AUDIO_03_R2_2013_

Scaricare i dataset tokenizzati e il tokenizer addestrato

In [4]:
!gdown --id 1VfmfvXp8QiH10Q5VFTc_TCyKvGSthD3-
!gdown --id 1Cnj-YhsNcj_GTGTAwjgRsBwK8fXXKZax

!unzip tokenized_train.zip
!unzip tokenized_valid.zip

[1;30;43mOutput streaming troncato alle ultime 5000 righe.[0m
  inflating: tokenized_valid/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav_7#d240.json  
  inflating: tokenized_valid/MIDI-Unprocessed_01_R1_2006_01-09_ORIG_MID--AUDIO_01_R1_2006_04_Track04_wav_0#v3.json  
  inflating: tokenized_valid/MIDI-Unprocessed_Schubert1-3_MID--AUDIO_02_R2_2018_wav_2#p12.json  
  inflating: tokenized_valid/ORIG-MIDI_02_7_7_13_Group__MID--AUDIO_15_R1_2013_wav--4_5#p12.json  
  inflating: tokenized_valid/MIDI-Unprocessed_16_R2_2009_01_ORIG_MID--AUDIO_16_R2_2009_16_R2_2009_01_WAV_4#v-3.json  
  inflating: tokenized_valid/MIDI-Unprocessed_XP_17_R2_2004_01_ORIG_MID--AUDIO_17_R2_2004_04_Track04_wav_8#d-240.json  
  inflating: tokenized_valid/MIDI-Unprocessed_046_PIANO046_MID--AUDIO-split_07-06-17_Piano-e_2-02_wav--1_3#v-3.json  
  inflating: tokenized_valid/ORIG-MIDI_01_7_8_13_Group__MID--AUDIO_03_R2_2013_wav--1_4#p-12.json  
  inflating: tokenized_valid/MIDI-Unprocessed_

In [5]:
!gdown --id 12der1lMHfMyMbGWImONehL_43AoECyzB

Downloading...
From: https://drive.google.com/uc?id=12der1lMHfMyMbGWImONehL_43AoECyzB
To: /content/tokenizer (2).json
100% 271k/271k [00:00<00:00, 88.5MB/s]


In [6]:
# Rebuild the REMI tokenizer from the parameters file downloaded above.
tokenizer_params_file = "tokenizer (2).json"
tokenizer = REMI(params=tokenizer_params_file)

In [7]:
def read_json(path: str) -> dict:
    """Load and return the JSON document stored at `path`.

    The file is decoded explicitly as UTF-8 so the result does not depend on
    the platform's default text encoding.

    Raises:
        FileNotFoundError: if `path` does not exist.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def read_json_files(json_file_paths):
    """Load every JSON file in `json_file_paths` and return the parsed objects.

    Files that are missing or contain invalid JSON are reported and skipped,
    so a single bad file no longer discards everything already loaded.

    Args:
        json_file_paths: iterable of paths to JSON files.

    Returns:
        List of parsed objects, in input order, for the files that could be
        read. Empty when no file could be read.
    """
    objects = []
    skipped = 0

    for file_path in tqdm(json_file_paths):
        try:
            objects.append(read_json(file_path))
        except FileNotFoundError:
            print(f"Error: File not found - {file_path}")
            skipped += 1
        except json.JSONDecodeError:
            print(f"Error decoding JSON in file: {file_path}")
            skipped += 1

    if skipped:
        print(f"Skipped {skipped} unreadable JSON file(s).")
    return objects

# Locate and parse the token files of both subsets.
tokenized_train = [*Path("tokenized_train").resolve().glob("**/*.json")]
data_objects_train = read_json_files(tokenized_train)

tokenized_valid = [*Path("tokenized_valid").resolve().glob("**/*.json")]
data_objects_valid = read_json_files(tokenized_valid)

# Both subsets must be non-empty for the load to count as successful.
both_loaded = bool(data_objects_train) and bool(data_objects_valid)
if not both_loaded:
    print("Error reading JSON files.")
else:
    print(f"\nSuccessfully read {len(data_objects_train)} training JSON files.")
    print(f"Successfully read {len(data_objects_valid)} validation JSON files.")

100%|██████████| 51073/51073 [00:08<00:00, 5899.61it/s]
100%|██████████| 11162/11162 [00:01<00:00, 6895.77it/s]


Successfully read 51073 training JSON files.
Successfully read 11162 validation JSON files.





In [8]:
# Take the first id sequence of every token file ...
encoded_train = [np.array(entry["ids"][0]) for entry in data_objects_train]
encoded_valid = [np.array(entry["ids"][0]) for entry in data_objects_valid]

# ... and join them into one flat int32 stream per subset.
all_ids_train = np.concatenate(encoded_train).astype(np.int32)
all_ids_valid = np.concatenate(encoded_valid).astype(np.int32)

## GAN model

In [9]:
class Generator(nn.Module):
    """Fully-connected generator mapping a noise vector to a sequence in [-1, 1].

    The hidden widths shrink from 16x to 1x `hidden_dim` before the final
    projection to `output_length` values, which are squashed with tanh.
    """

    def __init__(self, noise_dim: int, hidden_dim: int, output_length: int):
        super().__init__()

        # Layer widths, from widest (right after the noise input) to narrowest.
        widths = [hidden_dim * factor for factor in (16, 8, 4, 2, 1)]

        # Fully-connected layers
        self.input_layer = nn.Linear(noise_dim, widths[0])
        self.hidden_layer1 = nn.Linear(widths[0], widths[1])
        self.hidden_layer2 = nn.Linear(widths[1], widths[2])
        self.hidden_layer3 = nn.Linear(widths[2], widths[3])
        self.hidden_layer4 = nn.Linear(widths[3], widths[4])
        self.output_layer = nn.Linear(widths[4], output_length)

        # Activation and dropout shared by all hidden layers
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, noise_vector: torch.Tensor) -> torch.Tensor:
        """Return the generated batch; every value lies in [-1, 1] (tanh output)."""
        # The input layer is the only one without dropout.
        x = self.relu(self.input_layer(noise_vector))
        hidden_stack = (
            self.hidden_layer1,
            self.hidden_layer2,
            self.hidden_layer3,
            self.hidden_layer4,
        )
        for layer in hidden_stack:
            x = self.dropout(self.relu(layer(x)))
        return torch.tanh(self.output_layer(x))

Discriminator

In [10]:
class Discriminator(nn.Module):
    """Fully-connected critic that scores a sequence with a single raw logit.

    The hidden widths grow from 2x to 16x `hidden_dim`; no sigmoid is applied
    to the output.
    """

    def __init__(self, input_length: int, hidden_dim: int):
        super().__init__()

        # Layer widths, from narrowest (right after the input) to widest.
        widths = [hidden_dim * factor for factor in (2, 4, 8, 16)]

        # Fully-connected layers
        self.input_layer = nn.Linear(input_length, widths[0])
        self.hidden_layer1 = nn.Linear(widths[0], widths[1])
        self.hidden_layer2 = nn.Linear(widths[1], widths[2])
        self.hidden_layer3 = nn.Linear(widths[2], widths[3])
        self.output_layer = nn.Linear(widths[3], 1)

        # Activation and dropout shared by all hidden layers
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2)
        self.dropout = nn.Dropout(0.5)

    def forward(self, sequence: torch.Tensor) -> torch.Tensor:
        """Return one unnormalised score (logit) per input sequence."""
        # The input layer is the only one without dropout.
        x = self.leaky_relu(self.input_layer(sequence))
        for layer in (self.hidden_layer1, self.hidden_layer2, self.hidden_layer3):
            x = self.dropout(self.leaky_relu(layer(x)))
        return self.output_layer(x)

In [12]:
# Hyper-parameters shared by the models and the data pipeline.
vocab_size = len(tokenizer)
seq_length = 512
noise_dim = 512
hidden_dim = 1536
batch_size = 128

# The dictionaries below are expanded with ** into the constructors, so their
# keys must match the parameter names exactly:
#   Generator(noise_dim, hidden_dim, output_length)
#   Discriminator(input_length, hidden_dim)
#   GAN(generator, discriminator, noise_dim, sequence_length, tokenizer, ...)
generator_config = {
    "noise_dim"     : noise_dim,
    "hidden_dim"    : hidden_dim,
    "output_length" : seq_length,
}

discriminator_config = {
    "input_length" : seq_length,
    "hidden_dim"   : hidden_dim,
}

generator = Generator(**generator_config)
discriminator = Discriminator(**discriminator_config)

gan_params = {
    "generator"       : generator,
    "discriminator"   : discriminator,
    "noise_dim"       : noise_dim,
    "sequence_length" : seq_length,
    "tokenizer"       : tokenizer,
}

GAN

In [13]:
class GAN:
    """Couples a generator and a discriminator and trains them adversarially.

    The generator is expected to emit sequences normalised to [-1, 1];
    `generate_midi` maps them back to token ids and decodes them with the
    tokenizer.
    """

    def __init__(
        self,
        generator: nn.Module,
        discriminator: nn.Module,
        noise_dim: int,
        sequence_length: int,
        tokenizer,
        loss_fn: "nn.Module | None" = None,
    ):
        """
        Args:
            generator: network mapping (batch, noise_dim) noise to sequences.
            discriminator: network mapping sequences to one logit each.
            noise_dim: size of the noise vector fed to the generator.
            sequence_length: length of the generated sequences (stored only).
            tokenizer: object supporting `len()` and `decode(...)`.
            loss_fn: discriminator loss on logits; a fresh
                `nn.BCEWithLogitsLoss()` is created when omitted.

        Note: the linear/conv1d weights of both networks are re-initialised
        here, so weights loaded before constructing the GAN are overwritten.
        """
        self.generator = generator
        self.discriminator = discriminator
        self.noise_dim = noise_dim
        self.sequence_length = sequence_length
        # Built per instance rather than as a default argument, which would be
        # evaluated once at definition time and shared by every GAN.
        self.loss_fn = loss_fn if loss_fn is not None else nn.BCEWithLogitsLoss()
        self.tokenizer = tokenizer
        self.vocab_size = len(tokenizer)

        # Optimizers
        self.generator_optimizer = torch.optim.Adam(
            self.generator.parameters(), lr=0.00005, betas=(0.5, 0.999)
        )
        self.discriminator_optimizer = torch.optim.Adam(
            self.discriminator.parameters(), lr=0.00005, betas=(0.5, 0.999)
        )

        # Weight initialisation
        self.generator.apply(self._init_weights)
        self.discriminator.apply(self._init_weights)

    def _init_weights(self, layer):
        """Xavier-initialise linear and conv1d weights and zero their biases."""
        if isinstance(layer, (nn.Linear, nn.Conv1d)):
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def _sample_noise(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Return standard-normal noise of shape (batch_size, noise_dim) on `device`."""
        return torch.randn(batch_size, self.noise_dim, device=device)

    def _generator_train_step(self, batch_size: int, device: torch.device) -> float:
        """Run one generator update and return its loss value.

        The update is skipped (with a warning) when the loss is NaN or Inf.
        """
        self.generator.zero_grad()

        noise = self._sample_noise(batch_size, device)
        fake_data = self.generator(noise)
        fake_predictions = self.discriminator(fake_data)

        # The generator maximises the discriminator's raw logits.
        # NOTE(review): this loss is unbounded while the discriminator uses
        # `loss_fn`, so the two values are on different scales -- relevant to
        # the `alternate_training` comparison in `train`.
        loss_g = -torch.mean(fake_predictions)

        if not (torch.isnan(loss_g) or torch.isinf(loss_g)):
            loss_g.backward()
            torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=1.0)
            self.generator_optimizer.step()
        else:
            print("⚠ NaN o Inf rilevato nella loss del generatore.")

        return loss_g.item()

    def _discriminator_train_step(
        self, batch_size: int, real_data: torch.Tensor, device: torch.device
    ) -> float:
        """Run one discriminator update and return its loss value.

        The update is skipped (with a warning) when the loss is NaN or Inf.
        """
        self.discriminator.zero_grad()

        # Only the discriminator is updated here, so the fake batch is built
        # without autograd: nothing is back-propagated into the generator.
        with torch.no_grad():
            fake_data = self.generator(self._sample_noise(batch_size, device))

        real_predictions = self.discriminator(real_data)
        fake_predictions = self.discriminator(fake_data)

        real_labels = torch.ones_like(real_predictions)
        fake_labels = torch.zeros_like(fake_predictions)

        loss_real = self.loss_fn(real_predictions, real_labels)
        loss_fake = self.loss_fn(fake_predictions, fake_labels)

        loss_d = loss_real + loss_fake

        if not (torch.isnan(loss_d) or torch.isinf(loss_d)):
            loss_d.backward()
            self.discriminator_optimizer.step()
        else:
            print("⚠ NaN o Inf rilevato nella loss del discriminatore.")

        return loss_d.item()

    def _train_batch(
        self, batch_size: int, real_data: torch.Tensor, device: torch.device,
        generator_steps: int, discriminator_steps: int
    ):
        """Run the discriminator steps, then the generator steps, on one batch.

        Returns:
            (mean discriminator loss, mean generator loss) over the steps run.
        """
        discriminator_losses = [
            self._discriminator_train_step(batch_size, real_data, device)
            for _ in range(discriminator_steps)
        ]
        generator_losses = [
            self._generator_train_step(batch_size, device)
            for _ in range(generator_steps)
        ]
        return np.mean(discriminator_losses), np.mean(generator_losses)

    def train(
        self,
        dataloader,
        epochs: int,
        device: torch.device,
        loss_delta: float = 0.7,
        steps_each_print: int = 5,
        advantage_steps: int = 2,
        alternate_training: bool = False,
        generator_steps: int = 1,
        discriminator_steps: int = 1,
    ):
        """Train the GAN on `dataloader`.

        Args:
            dataloader: iterable of real batches (tensors, batch dimension first).
            epochs: number of passes over `dataloader`.
            device: device both networks and the batches are moved to.
            loss_delta: loss gap that triggers extra steps when
                `alternate_training` is enabled.
            steps_each_print: refresh interval (in batches) of the progress text.
            advantage_steps: steps given to the network with the higher loss.
            alternate_training: adapt the step counts after each batch.
            generator_steps: generator updates per batch (initial value).
            discriminator_steps: discriminator updates per batch (initial value).

        Returns:
            (discriminator losses, generator losses), one entry per batch.
        """
        print(f"🚀 Avvio training per {epochs} epoche")
        self.generator.to(device).train()
        self.discriminator.to(device).train()

        all_d_losses = []
        all_g_losses = []
        # Defined up front so the end-of-epoch summary cannot fail with a
        # NameError when the dataloader yields no batch.
        d_loss = g_loss = float("nan")

        for epoch in range(epochs):
            with tqdm(total=len(dataloader), desc=f"Epoch {epoch+1}/{epochs}") as pbar:
                for step, real_batch in enumerate(dataloader):
                    batch_size = real_batch.size(0)
                    real_batch = real_batch.to(device)

                    d_loss, g_loss = self._train_batch(
                        batch_size, real_batch, device, generator_steps, discriminator_steps
                    )
                    all_d_losses.append(d_loss)
                    all_g_losses.append(g_loss)

                    if alternate_training:
                        # Give extra steps to whichever network is lagging.
                        if g_loss - d_loss > loss_delta:
                            generator_steps, discriminator_steps = advantage_steps, 1
                        elif d_loss - g_loss > loss_delta:
                            discriminator_steps, generator_steps = advantage_steps, 1
                        else:
                            generator_steps, discriminator_steps = 1, 1

                    if (step + 1) % steps_each_print == 0:
                        desc = f"D Loss: {d_loss:.4f}, G Loss: {g_loss:.4f}"
                        if generator_steps != 1 or discriminator_steps != 1:
                            desc += f" | D Steps: {discriminator_steps}, G Steps: {generator_steps}"
                        pbar.set_description(desc)
                    pbar.update(1)

            print(f"Epoch {epoch+1}/{epochs} | D Loss: {d_loss:.4f} | G Loss: {g_loss:.4f}")

        return all_d_losses, all_g_losses

    def save_generator(self, filepath: str):
        """Save the generator's state dict to `filepath`."""
        torch.save(self.generator.state_dict(), filepath)

    def save_discriminator(self, filepath: str):
        """Save the discriminator's state dict to `filepath`."""
        torch.save(self.discriminator.state_dict(), filepath)

    def predict(self, n_samples: int, device: torch.device):
        """Generate `n_samples` normalised sequences as a NumPy array.

        Sampling runs in eval mode (dropout disabled) and without autograd;
        the generator's previous train/eval mode is restored afterwards.
        """
        was_training = self.generator.training
        self.generator.to(device).eval()
        try:
            with torch.no_grad():
                samples = self.generator(self._sample_noise(n_samples, device))
        finally:
            self.generator.train(was_training)
        return samples.cpu().numpy()

    def generate_midi(self, n_samples: int, output_prefix: str, device: torch.device):
        """Generate `n_samples` sequences and write them as `<output_prefix>_<i>.mid`.

        Samples that cannot be decoded or written are reported and skipped.
        """
        # Inverse of the training normalisation (t - V/2) / (V/2).
        half_vocab = self.vocab_size / 2

        for i, sample in enumerate(self.predict(n_samples, device)):
            # Round to the nearest id (truncation would bias every token
            # downwards) and clamp into the valid range [0, vocab_size - 1].
            token_ids = np.rint(sample * half_vocab + half_vocab)
            pred_tokens = np.clip(token_ids, 0, self.vocab_size - 1).astype(np.int32)
            try:
                decoded = self.tokenizer.decode([pred_tokens])
                decoded.dump_midi(f"{output_prefix}_{i}.mid")
            except Exception as e:
                print(f"Errore nella generazione MIDI: {e}")


In [14]:
class TokenDataset(Dataset):
    """Map-style dataset serving pre-built token sequences as float32 tensors."""

    def __init__(self, sequences):
        # Any indexable, sized collection of numeric sequences.
        self.sequences = sequences

    def __len__(self):
        """Number of stored sequences."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return sequence `idx` as a float32 tensor."""
        selected = self.sequences[idx]
        return torch.tensor(selected, dtype=torch.float32)


def normalize_data(token_array, vocab_size: int, interval=(-1, 1)):
    """Linearly rescale token ids to `interval` (default (-1, 1)).

    Ids are first centred on `vocab_size / 2` and divided by it, which maps
    0 to -1 and `vocab_size` to 1; the result is then stretched and shifted
    onto `interval`.
    """
    low, high = interval
    half_vocab = vocab_size / 2
    centred = (token_array - half_vocab) / half_vocab
    return centred * (high - low) / 2 + (high + low) / 2


# ===============================
# PARAMETERS
# ===============================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    torch.cuda.set_device(0)

# Window length comes from the general configuration; `batch_size` is reused as defined there.
sequence_length = seq_length

# ===============================
# DATA PREPARATION
# ===============================
vocab_size = len(tokenizer)
normalized_tokens = normalize_data(all_ids_train, vocab_size)

print(f"Normalizzazione completata | Max: {normalized_tokens.max():.4f}, Min: {normalized_tokens.min():.4f}")

# Cut the token stream into non-overlapping windows of `sequence_length`.
# Every complete window is kept, including the last one when the stream length
# is an exact multiple of the window size; only a trailing partial window is dropped.
num_sequences = len(normalized_tokens) // sequence_length
sequences = normalized_tokens[: num_sequences * sequence_length].reshape(
    num_sequences, sequence_length
)

# Dataset and DataLoader
dataset = TokenDataset(sequences)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=4
)

# Quick check that a tensor can be allocated on the selected device
_ = torch.ones(1, device=device)
print(f"Dataset pronto. Batch size: {batch_size}, Numero batch: {len(dataloader)}")


Max is 0.9996, Min is -0.9988


In [16]:
gan = GAN(**gan_params)

# Keys must match the parameter names of GAN.train, because the dictionary is
# expanded with ** (the step counts are `generator_steps` / `discriminator_steps`).
training_arguments = {
    "dataloader"          : dataloader,
    "epochs"              : 45,
    "device"              : device,
    "steps_each_print"    : 10,
    "generator_steps"     : 1,
    "discriminator_steps" : 1,
}

losses_discriminator, losses_generator = gan.train(**training_arguments)

Starting training with 30 epochs


D Loss: 1.0919, G Loss: 1.0477: 100%|██████████| 786/786 [06:52<00:00,  1.90it/s]


Epoch [1/30] Loss D: 1.0450, Loss G: 1.8628


D Loss: 0.6945, G Loss: 0.9538:  27%|██▋       | 210/786 [01:50<05:02,  1.91it/s]

KeyboardInterrupt: 

In [None]:
# One point per training batch, numbered from 1.
iterations = list(range(1, len(losses_discriminator) + 1))

# Plot both loss curves on a single set of axes.
fig, ax = plt.subplots(figsize=(8, 5))
loss_curves = (
    (losses_generator, "Generatore", "blue"),
    (losses_discriminator, "Discriminatore", "red"),
)
for curve, curve_label, curve_color in loss_curves:
    ax.plot(iterations, curve, label=curve_label, color=curve_color, linewidth=2)

ax.set_xlabel("Iterazioni")
ax.set_ylabel("Loss")
ax.set_title("Andamento della Loss GAN")
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()

salva il modello

In [None]:
# Persist only the learned weights (state dicts) of both networks.
weight_files = (
    (generator, "generator_weights.pth"),
    (discriminator, "discriminator_weights.pth"),
)
for network, weights_file in weight_files:
    torch.save(network.state_dict(), weights_file)


per caricare

In [None]:
# Rebuild both networks and restore their weights.
# map_location lets weights saved on a GPU load on a CPU-only machine, and
# weights_only restricts unpickling to tensors/state dicts. The models are then
# moved to `device`, where the generation code creates its noise.
generator = Generator(**generator_config)
generator.load_state_dict(
    torch.load("generator_weights.pth", map_location=device, weights_only=True)
)
generator.to(device).eval()

discriminator = Discriminator(**discriminator_config)
discriminator.load_state_dict(
    torch.load("discriminator_weights.pth", map_location=device, weights_only=True)
)
discriminator.to(device).eval()


generate music

In [None]:
# Size of the noise vector fed to the generator.
noise_dim = 512

def generate_noise(batch_size: int, device: torch.device) -> torch.Tensor:
    """Draw a (batch_size, noise_dim) batch of standard-normal noise on `device`."""
    noise_shape = (batch_size, noise_dim)
    return torch.randn(noise_shape, device=device)

def predict_sequences(generator: torch.nn.Module, n_samples: int, device: torch.device):
    """Generate `n_samples` normalised sequences from `generator` as a NumPy array.

    The generator is moved to `device` (where the noise is created) and run in
    eval mode without autograd, so dropout does not perturb the samples; its
    previous train/eval mode is restored afterwards.
    """
    was_training = generator.training
    generator.to(device).eval()
    try:
        with torch.no_grad():
            samples = generator(generate_noise(n_samples, device))
    finally:
        generator.train(was_training)
    return samples.cpu().numpy()

def generate_midi_files(generator: torch.nn.Module, n_samples: int, output_prefix: str, device: torch.device):
    """Generate `n_samples` sequences and write them as `<output_prefix>_<i>.mid`.

    Samples that cannot be decoded or written are reported and skipped.
    """
    # Normalised generator outputs
    normalized_outputs = predict_sequences(generator, n_samples, device)

    # Inverse of the training normalisation (t - V/2) / (V/2).
    vocab_size = len(tokenizer)
    half_vocab = vocab_size / 2

    for i, output in enumerate(normalized_outputs):
        # Round to the nearest id (truncation would bias every token downwards)
        # and clamp into the valid range [0, vocab_size - 1].
        tokens = np.clip(
            np.rint(output * half_vocab + half_vocab), 0, vocab_size - 1
        ).astype(np.int32)

        try:
            decoded_midi = tokenizer.decode([tokens])
            file_path = f"{output_prefix}_{i}.mid"
            decoded_midi.dump_midi(file_path)
            print(f"MIDI generato: {file_path}")
        except Exception as e:
            print(f"Errore nella generazione MIDI per il sample {i}: {e}")


In [None]:

num_midi_files = 100

output_folder = "generated"
output_prefix = f"{output_folder}/generated_midi"

Path(output_folder).mkdir(parents=True, exist_ok=True)

# Two preview samples, mapped back to the token range for inspection.
sample_normalized_outputs = predict_sequences(generator, 2, device)
token_boundary = len(tokenizer) // 2
sample_token_sequences = [
    preview * token_boundary + token_boundary
    for preview in sample_normalized_outputs
]

# Write the full batch of MIDI files.
generate_midi_files(generator, num_midi_files, output_prefix, device)


In [None]:
# Bundle the whole `generated` folder into a single archive for download.
!zip -r generetedMIDI.zip generated

  adding: generated/ (stored 0%)
  adding: generated/generated_midi_53.mid (deflated 45%)
  adding: generated/generated_midi_58.mid (deflated 45%)
  adding: generated/generated_midi_44.mid (deflated 44%)
  adding: generated/generated_midi_25.mid (deflated 46%)
  adding: generated/generated_midi_50.mid (deflated 45%)
  adding: generated/generated_midi_49.mid (deflated 45%)
  adding: generated/generated_midi_83.mid (deflated 44%)
  adding: generated/generated_midi_72.mid (deflated 44%)
  adding: generated/generated_midi_75.mid (deflated 44%)
  adding: generated/generated_midi_97.mid (deflated 45%)
  adding: generated/generated_midi_42.mid (deflated 44%)
  adding: generated/generated_midi_81.mid (deflated 45%)
  adding: generated/generated_midi_13.mid (deflated 44%)
  adding: generated/generated_midi_78.mid (deflated 46%)
  adding: generated/generated_midi_6.mid (deflated 46%)
  adding: generated/generated_midi_18.mid (deflated 44%)
  adding: generated/generated_midi_80.mid (deflated 45%)

Valutare le prestazioni

In [None]:
# Paths (as strings) of the generated MIDI files, sorted for a stable order.
# pathlib is used because the `glob` module is never imported in this notebook.
generated_midis = sorted(str(path) for path in Path("generated").glob("*.mid"))

In [None]:
def pitch_count(notes):
    """Return the number of distinct pitches used in `notes`.

    notes: iterable of int (pitch values). Repeated pitches are counted once;
    an empty input gives 0.
    """
    return len(set(notes))


def note_density(notes, total_duration_sec):
    """Return the number of notes per second.

    notes: list of notes
    total_duration_sec: duration of the piece in seconds; when it is not
        strictly positive the density is reported as 0.
    """
    if not total_duration_sec > 0:
        return 0
    return len(notes) / total_duration_sec

from collections import Counter

def rote_memorization_frequency(notes, pattern_length=3):
    """Return the mean occurrence count of the note patterns that repeat.

    Every window of `pattern_length` consecutive notes is a pattern; only
    patterns seen more than once contribute. Returns 0 when nothing repeats.
    """
    window_count = len(notes) - pattern_length + 1
    occurrences = Counter(
        tuple(notes[start:start + pattern_length]) for start in range(window_count)
    )
    repeated = [seen for seen in occurrences.values() if seen > 1]
    if not repeated:
        return 0
    return np.mean(repeated)

def pitch_intervals(notes):
    """Return the mean absolute pitch step between consecutive notes.

    Returns 0 when fewer than two notes are given.
    """
    if len(notes) < 2:
        return 0
    steps = [abs(current - previous) for previous, current in zip(notes, notes[1:])]
    return np.mean(steps)

def pitch_class_histogram(notes):
    """Return the normalised 12-bin histogram of pitch classes (pitch mod 12).

    The bins sum to 1; with no notes the all-zero histogram is returned.
    """
    counts = [0.0] * 12
    for pitch in notes:
        counts[pitch % 12] += 1
    hist = np.array(counts)
    note_total = np.sum(hist)
    if note_total > 0:
        return hist / note_total
    return hist

from scipy.special import rel_entr

def kld(p, q):
    """Return the Kullback-Leibler divergence KL(p || q).

    p, q: array-likes holding (possibly unnormalised) distributions. A small
    epsilon is added to every entry to avoid zeros, then both are renormalised
    to sum to 1.
    """
    p_dist = np.asarray(p, dtype=np.float64) + 1e-12  # avoid zeros
    q_dist = np.asarray(q, dtype=np.float64) + 1e-12
    p_dist = p_dist / p_dist.sum()
    q_dist = q_dist / q_dist.sum()
    return np.sum(rel_entr(p_dist, q_dist))


In [None]:
def compute_all_metrics(generated_notes, real_notes=None, total_duration_sec=30.0):
    """Compute every evaluation metric for one generated piece.

    Args:
        generated_notes: pitch values of the generated piece.
        real_notes: optional pitch values of a reference piece; when given,
            the KL divergence between the two pitch-class histograms is
            added under "PCH KLD".
        total_duration_sec: duration used for the note density.

    Returns:
        Dict mapping metric name to its value.
    """
    generated_histogram = pitch_class_histogram(generated_notes)

    metrics = {}
    metrics["Pitch Count"] = pitch_count(generated_notes)
    metrics["Note Density"] = note_density(generated_notes, total_duration_sec)
    metrics["Rote Memorization"] = rote_memorization_frequency(generated_notes)
    metrics["Pitch Interval"] = pitch_intervals(generated_notes)
    metrics["Pitch Class Histogram"] = generated_histogram

    if real_notes is None:
        return metrics

    metrics["PCH KLD"] = kld(generated_histogram, pitch_class_histogram(real_notes))
    return metrics


In [None]:
# NOTE(review): `generated_notes` and `real_notes` are not defined in any cell
# visible in this notebook; they must be built (e.g. from the generated and
# reference MIDI files) before this cell can run -- confirm where they come from.
metrics = compute_all_metrics(generated_notes, real_notes)

# NumPy arrays are rounded element-wise; every other value is printed with
# three decimals.
for name, value in metrics.items():
    if isinstance(value, np.ndarray):
        print(f"{name}: {value.round(3)}")
    else:
        print(f"{name}: {value:.3f}")
