In [None]:
# Set to false if you are not running
# this notebook in Google Colaboratory
run_on_colab = True

In [None]:
if(run_on_colab):
  from google.colab import drive
  # This will prompt for authorization.
  drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install music21;



In [None]:
if(run_on_colab):
  from google.colab import files
  files.upload()


Saving finalzip.zip to finalzip.zip


This script processes MIDI files from a specified folder using the music21 library. For each file, it extracts the average tempo, key, time signature, and note/chord/rest data, then stores the results in a list. Finally, it pickles the collected features for later use and prints out a brief sample for verification

In [1]:
import os
import music21
import pickle
from music21 import converter, note, chord, tempo, meter

# path
midi_path = os.path.join("extracted_files", "finalzip")
features = []

for file in os.listdir(midi_path):
    if file.lower().endswith((".mid", ".midi")):
        try:
            midi_data = converter.parse(os.path.join(midi_path, file))

            # 1) extract temp
            tempos = midi_data.flat.getElementsByClass(tempo.MetronomeMark)
            if tempos:
                avg_tempo = sum(t.number for t in tempos) / len(tempos)
            else:
                avg_tempo = 120.0

            # 2) extrat key
            key_signature = midi_data.analyze('key')
            key_name = key_signature.name

            # 3) extract time signature
            time_sigs = midi_data.flat.getElementsByClass(meter.TimeSignature)
            if time_sigs:
                time_sig = time_sigs[0].ratioString  #for exemple : '4/4'
            else:
                time_sig = "4/4"

            # 4) Extract (Note, Chord, Rest)
            notes_data = []
            for elem in midi_data.flat.notesAndRests:
                offset = elem.offset
                duration = elem.quarterLength

                if isinstance(elem, note.Note):
                    notes_data.append({
                        "type": "note",
                        "pitch": str(elem.pitch),
                        "offset": offset,
                        "duration": duration
                    })
                elif isinstance(elem, chord.Chord):
                    # pitch classes or normalOrder
                    chord_str = '.'.join(str(n) for n in elem.normalOrder)
                    notes_data.append({
                        "type": "chord",
                        "chord": chord_str,
                        "offset": offset,
                        "duration": duration
                    })
                elif isinstance(elem, note.Rest):
                    notes_data.append({
                        "type": "rest",
                        "offset": offset,
                        "duration": duration
                    })
                # If you would like to address other elements, you can add them here.

            # 5) save the data
            features.append({
                "file": file,
                "tempo": avg_tempo,
                "key": key_name,
                "time_signature": time_sig,
                "notes": notes_data
            })

            print(f"Processed {file}")

        except Exception as e:
            print(f"Error processing {file}: {e}")

# save to pkl file
file_path = '/content/drive/My Drive/extracted_features.pkl'
with open(file_path, 'wb') as f:
    pickle.dump(features, f)

#print for check
for feature in features[:5]:
    print(feature)


FileNotFoundError: [Errno 2] No such file or directory: 'extracted_files/finalzip'

This script loads previously extracted MIDI features, then creates a token sequence for each file by:

1. Sorting events by offset.
2. Converting each note, chord, or rest to a discrete token that includes its pitch/chord (if applicable) and quantized duration.
3. Adding context tokens for tempo, key, and time signature.
The final token sequences are saved to a pickle file for later use.


In [None]:
import math
from fractions import Fraction
import pickle
import os

def quantize_to_16th(value):
    """
   Converts value into quarters.
    """
    return int(round(float(value) * 4))

def event_to_token(event):
    """
This function converts a musical event (represented as a dictionary) into a token string. Each event has a "type" (which can be "note", "chord", or "rest") and a "duration". If the type is "note", there's also a "pitch" key; if it's "chord", there's a "chord" key (a series of numbers separated by dots). The output token string is formatted as:

For a note: NOTE_<pitch>_<duration> (e.g., NOTE_C4_4)
For a chord: CHORD_<chord>_<duration> (e.g., CHORD_1.6_6)
For a rest: REST_<duration> (e.g., REST_2)
    """
    event_type = event["type"]
    dur_quant = quantize_to_16th(event["duration"])

    if event_type == "note":
        pitch_str = event["pitch"]
        return f"NOTE_{pitch_str}_{dur_quant}"
    elif event_type == "chord":
        chord_str = event["chord"]
        return f"CHORD_{chord_str}_{dur_quant}"
    elif event_type == "rest":
        return f"REST_{dur_quant}"
    else:
        return "UNK"

def build_token_sequence(notes_data):
    """

    This step sorts the list of events by their "offset" property and then converts each event into a token string using the defined format.
    """
    # sort by offset
    notes_sorted = sorted(notes_data, key=lambda x: float(x["offset"]))
    token_seq = [event_to_token(event) for event in notes_sorted]
    return token_seq

# load the data
features_path = '/content/drive/My Drive/extracted_features.pkl'
with open(features_path, 'rb') as f:
    features = pickle.load(f)

# for each sequence token
all_token_sequences = []
for feat in features:
    notes_data = feat["notes"]
    token_seq = build_token_sequence(notes_data)


    tempo_token = f"TEMPO_{round(feat['tempo'],2)}"
    key_token = f"KEY_{feat['key'].replace(' ', '_')}"
    time_sig_token = f"TIME_{feat['time_signature'].replace('/', '_')}"

    full_seq = [tempo_token, key_token, time_sig_token] + token_seq
    all_token_sequences.append(full_seq)

# statistics & checks
print("Total number of files:", len(all_token_sequences))
print("Sample token sequence (first 20 tokens) from first file:")
print(all_token_sequences[0][:20])

# save
output_path = '/content/drive/My Drive/all_token_sequences.pkl'
with open(output_path, 'wb') as f:
    pickle.dump(all_token_sequences, f)

print("✅ Token sequences saved to:", output_path)


Total number of files: 830
Sample token sequence (first 20 tokens) from first file:
['TEMPO_120.0', 'KEY_F#_major', 'TIME_4_4', 'REST_8', 'REST_8', 'REST_10', 'CHORD_1.6_5', 'NOTE_B-4_8', 'NOTE_B-3_1', 'NOTE_C#4_1', 'CHORD_1.6_1', 'CHORD_10.1_1', 'REST_3', 'REST_0', 'CHORD_10.1_1', 'CHORD_1.6_2', 'REST_4', 'REST_8', 'REST_14', 'NOTE_B-4_1']
✅ Token sequences saved to: /content/drive/My Drive/all_token_sequences.pkl


This script loads preprocessed token sequences, builds a vocabulary (mapping tokens to integers), and generates training data using a sliding window (length 200) to create input-target pairs. Finally, it saves the training arrays and the mapping dictionaries as pickle files for future use.


In [None]:
import pickle
import numpy as np
import random


with open('/content/drive/My Drive/all_token_sequences.pkl', 'rb') as f:
    all_token_sequences = pickle.load(f)

print("Total token sequences (files):", len(all_token_sequences))

# build vocabulary
vocab = set()
for seq in all_token_sequences:
    vocab.update(seq)
vocab = sorted(list(vocab))
note_to_int = {token: i for i, token in enumerate(vocab)}
int_to_note = {i: token for i, token in enumerate(vocab)}
n_vocab = len(vocab)

print("Vocabulary size:", n_vocab)
print("Sample vocabulary tokens:", vocab[:20])

# Setting a training sequence length
sequence_length = 200

# Creating Input and Target Lists
# For each file sequence, we use a sliding window that generates:
#   input: [t0, t1, ..., t_{L-1}]
#   target: [t1, t2, ..., t_L]
train_input = []
train_target = []
for seq in all_token_sequences:

    if len(seq) < sequence_length + 1:
        continue

    for i in range(0, len(seq) - sequence_length):
        input_seq = seq[i : i + sequence_length]
        target_seq = seq[i + 1 : i + sequence_length + 1]
        train_input.append([note_to_int[token] for token in input_seq])
        train_target.append([note_to_int[token] for token in target_seq])

train_input = np.array(train_input)
train_target = np.array(train_target)

print("Training data shapes:")
print("Input:", train_input.shape)
print("Target:", train_target.shape)

# save data
with open('/content/drive/My Drive/training_input.pkl', 'wb') as f:
    pickle.dump(train_input, f)
with open('/content/drive/My Drive/training_target.pkl', 'wb') as f:
    pickle.dump(train_target, f)
with open('/content/drive/My Drive/note_to_int.pkl', 'wb') as f:
    pickle.dump(note_to_int, f)
with open('/content/drive/My Drive/int_to_note.pkl', 'wb') as f:
    pickle.dump(int_to_note, f)

print("✅ Training dataset saved successfully.")


Total token sequences (files): 830
Vocabulary size: 11050
Sample vocabulary tokens: ['CHORD_0.1.2.3.4.5.6.7.8.9.10.11_1', 'CHORD_0.1.2.3.4.5.6.7.8.9.10_1', 'CHORD_0.1.2.3.4.5.6.8.9.10_1', 'CHORD_0.1.2.3.4.5.6_1', 'CHORD_0.1.2.3.4.5.8_1', 'CHORD_0.1.2.3.4.5_1', 'CHORD_0.1.2.3.4.6.7_1', 'CHORD_0.1.2.3.4.6_1', 'CHORD_0.1.2.3.4.7.8_1', 'CHORD_0.1.2.3.4.7_1', 'CHORD_0.1.2.3.4.8_1', 'CHORD_0.1.2.3.4_1', 'CHORD_0.1.2.3.5.6.8_1', 'CHORD_0.1.2.3.5.6.9_1', 'CHORD_0.1.2.3.5.6_1', 'CHORD_0.1.2.3.5.7_1', 'CHORD_0.1.2.3.5.8.9_1', 'CHORD_0.1.2.3.5_1', 'CHORD_0.1.2.3.6.7_1', 'CHORD_0.1.2.3.6.8.9_1']


KeyboardInterrupt: 

load the files

In [None]:
import pickle

# Load processed token sequences
with open('/content/drive/My Drive/all_token_sequences.pkl', 'rb') as f:
    all_token_sequences = pickle.load(f)

# Load training inputs
with open('/content/drive/My Drive/training_input.pkl', 'rb') as f:
    train_input = pickle.load(f)

# Load training targets
with open('/content/drive/My Drive/training_target.pkl', 'rb') as f:
    train_target = pickle.load(f)

# Load token-to-integer mapping
with open('/content/drive/My Drive/note_to_int.pkl', 'rb') as f:
    note_to_int = pickle.load(f)

# Load integer-to-token mapping
with open('/content/drive/My Drive/int_to_note.pkl', 'rb') as f:
    int_to_note = pickle.load(f)

print("✅ All files loaded successfully.")


✅ All files loaded successfully.


This code defines a transformer-based music generator in PyTorch. It includes:

- An embedding layer to convert input tokens into vectors.
- A precomputed positional encoding (using sine and cosine functions) to inject sequence information.
- A transformer encoder to process the sequence with multiple layers and attention heads.
- A final linear layer that projects the transformer output to the vocabulary size, producing logits for each token position.

In [None]:
import torch
import torch.nn as nn
import math

class MusicGenerator(nn.Module):
    def __init__(self, n_vocab, sequence_length, embed_dim=256, num_heads=8, num_layers=4):
        super(MusicGenerator, self).__init__()
        # Embedding layer: converts token indices into embedding vectors
        self.embedding = nn.Embedding(n_vocab, embed_dim)
        # Register positional encoding buffer: generates positional encodings for 2 * sequence_length tokens
        self.register_buffer("positional_encoding", self._generate_positional_encoding(sequence_length * 2, embed_dim))
        # Transformer encoder: composed of several transformer encoder layers
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=512,
                dropout=0.3,
                batch_first=True,
                norm_first=True
            ),
            num_layers=num_layers
        )
        # Fully connected layer: projects transformer outputs to vocabulary size logits
        self.fc = nn.Linear(embed_dim, n_vocab)  # Output dimension equals the vocabulary size

    def _generate_positional_encoding(self, seq_len, embed_dim):
        """
        Generates positional encoding using sine and cosine functions.
        Returns a tensor of shape (1, seq_len, embed_dim) to be added to the embeddings.
        """
        # Create a tensor of positions from 0 to seq_len - 1
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        # Calculate the div_term for the exponential decay factor
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
        pos_enc = torch.zeros(seq_len, embed_dim)
        # Apply sine to even indices of the embedding dimension
        pos_enc[:, 0::2] = torch.sin(position * div_term)
        # Apply cosine to odd indices of the embedding dimension
        pos_enc[:, 1::2] = torch.cos(position * div_term)
        return pos_enc.unsqueeze(0)  # Add a batch dimension

    def forward(self, x):
        # Convert input token indices to embedding vectors: shape [B, sequence_length, embed_dim]
        embedded = self.embedding(x)
        seq_len = x.size(1)
        # Extract the corresponding positional encoding for the current sequence length
        pos_enc = self.positional_encoding[:, :seq_len, :]
        # Add positional encoding to the embeddings
        x = embedded + pos_enc
        # Pass the input through the transformer encoder
        x = self.transformer(x)
        # Project the transformer output to logits for each token in the vocabulary
        x = self.fc(x)  # Output shape: [B, sequence_length, n_vocab]
        return x


This class implements a music discriminator network. It embeds input token sequences, flattens them, and processes the result through several fully connected layers with LeakyReLU and dropout, ending with a sigmoid output that indicates the probability of the input being real.


In [None]:
import torch
import torch.nn as nn

class MusicDiscriminator(nn.Module):
    def __init__(self, n_vocab, sequence_length, embed_dim=256):
        super(MusicDiscriminator, self).__init__()
        # Use an embedding layer to convert input tokens into embedding vectors.
        self.embedding = nn.Embedding(n_vocab, embed_dim)
        # A fully connected network to classify the sequence as real or fake.
        self.fc = nn.Sequential(
            nn.Linear(sequence_length * embed_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Convert input token indices into embedding vectors.
        # Output shape: [B, sequence_length, embed_dim]
        x = self.embedding(x)
        # Flatten the embeddings into a single vector per batch element.
        # Output shape: [B, sequence_length * embed_dim]
        x = x.view(x.size(0), -1)
        # Pass the flattened vector through the fully connected layers.
        # Output shape: [B, 1]
        x = self.fc(x)
        return x


This snippet initializes the generator and discriminator models using the previously obtained vocabulary size (n_vocab) and sequence length (e.g., 200). It then transfers both models to the available device (GPU if available, otherwise CPU) and prints their architectures for verification.

In [None]:
# Model initialization – we use the provided variable n_vocab and the sequence length from train_inputs.shape[1] (for example, 200)
generator = MusicGenerator(
    n_vocab=n_vocab,
    sequence_length=train_input.shape[1],
    embed_dim=128,
    num_heads=4,
    num_layers=3
)

discriminator = MusicDiscriminator(
    n_vocab=n_vocab,
    sequence_length=train_input.shape[1],
    embed_dim=128
)

# Moving to the appropriate device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = generator.to(device)
discriminator = discriminator.to(device)

print("Generator:")
print(generator)
print("Discriminator:")
print(discriminator)




Generator:
MusicGenerator(
  (embedding): Embedding(11050, 128)
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-2): 3 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=512, bias=True)
        (dropout): Dropout(p=0.3, inplace=False)
        (linear2): Linear(in_features=512, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.3, inplace=False)
        (dropout2): Dropout(p=0.3, inplace=False)
      )
    )
  )
  (fc): Linear(in_features=128, out_features=11050, bias=True)
)
Discriminator:
MusicDiscriminator(
  (embedding): Embedding(11050, 128)
  (fc): Sequential(
    (0): Linear(in_features=25600, out_features=1024, bias=True)
    (

This code sets up the training environment by:

- Creating a directory for saving checkpoints and defining training parameters (epochs, batch size, and save frequency).
- Loading training input and target data from pickle files and constructing a DataLoader.
- Moving pre-initialized generator and discriminator models to the appropriate device (GPU if available).
- Initializing loss functions (BCELoss and CrossEntropyLoss) and Adam optimizers for both models.
- Preparing a variable to track the best generator loss during training.

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
import pickle

# Basic configurations
save_path = "/content/drive/My Drive/checkpoints/"
os.makedirs(save_path, exist_ok=True)

epochs = 50  # You can start with 50-100 epochs
batch_size = 512
save_every = 5  # Save every 5 epochs

# Move to the appropriate device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the training data (if not already loaded)
with open('/content/drive/My Drive/training_input.pkl', 'rb') as f:
    train_inputs = pickle.load(f)
with open('/content/drive/My Drive/training_target.pkl', 'rb') as f:
    train_targets = pickle.load(f)

print("Training inputs shape:", train_inputs.shape)
print("Training targets shape:", train_targets.shape)

# Create DataLoader
dataset = TensorDataset(torch.tensor(train_inputs, dtype=torch.long),
                        torch.tensor(train_targets, dtype=torch.long))
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Assume that the models (generator, discriminator) have been defined and initialized as done previously
generator = generator.to(device)
discriminator = discriminator.to(device)

# Define loss functions and optimizers
bce_loss = nn.BCELoss()
ce_loss = nn.CrossEntropyLoss()
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0001)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0001)

best_g_loss = float('inf')


This training loop updates the generator and discriminator in a GAN setup over multiple epochs. For each batch:

- Discriminator Update:
 - Real sequences (targets) and fake sequences (generated from inputs) are evaluated.
 - Binary cross-entropy losses are computed for both, and their average is used to update the discriminator.
- Generator Update:
 - The generator produces fake sequences which are compared to targets using cross-entropy loss (teacher forcing).
 - An adversarial loss (to fool the discriminator) is also computed and scaled.
 - The combined loss updates the generator.

Epoch losses are averaged, printed, and model checkpoints are saved when improvements occur or at regular intervals.

In [None]:
for epoch in range(epochs):
    generator.train()
    discriminator.train()

    epoch_g_loss = 0.0
    epoch_d_loss = 0.0

    for batch_input, batch_target in tqdm(data_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        batch_input = batch_input.to(device)
        batch_target = batch_target.to(device).long()

        ## === Update Discriminator === ##
        optimizer_D.zero_grad()
        real_labels = torch.ones(batch_input.shape[0], 1).to(device)
        fake_labels = torch.zeros(batch_input.shape[0], 1).to(device)

        # For real data: use the discriminator (which includes an embedding) on the target indices
        real_validity = discriminator(batch_target)  # [B, 1]
        loss_real = bce_loss(real_validity, real_labels)

        # For fake data:
        fake_logits = generator(batch_input)  # [B, sequence_length, n_vocab]
        fake_tokens = torch.argmax(fake_logits, dim=-1).long()  # [B, sequence_length]
        fake_validity = discriminator(fake_tokens.detach())  # [B, 1]
        loss_fake = bce_loss(fake_validity, fake_labels)

        d_loss = (loss_real + loss_fake) / 2.0
        d_loss.backward()
        optimizer_D.step()

        ## === Update Generator === ##
        optimizer_G.zero_grad()

        fake_logits = generator(batch_input)  # [B, sequence_length, n_vocab]
        # Teacher Forcing Loss: compute loss against the target
        g_teacher_loss = ce_loss(fake_logits.view(-1, n_vocab), batch_target.view(-1))

        # Adversarial Loss: encourage the discriminator to believe that fake sequences are real
        fake_tokens = torch.argmax(fake_logits, dim=-1).long()
        validity = discriminator(fake_tokens.detach())
        g_adv_loss = bce_loss(validity, real_labels)

        g_loss = g_teacher_loss + 0.002 * g_adv_loss
        g_loss.backward()
        optimizer_G.step()

        epoch_g_loss += g_loss.item()
        epoch_d_loss += d_loss.item()

    avg_g_loss = epoch_g_loss / len(data_loader)
    avg_d_loss = epoch_d_loss / len(data_loader)

    print(f"Epoch {epoch+1}/{epochs}, Generator Loss: {avg_g_loss:.4f}, Discriminator Loss: {avg_d_loss:.4f}")

    # Save the model if it has improved
    if avg_g_loss < best_g_loss:
        best_g_loss = avg_g_loss
        save_file = os.path.join(save_path, f"best_model_Gloss_{avg_g_loss:.4f}.pth")
        torch.save({
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_G_state_dict': optimizer_G.state_dict(),
            'optimizer_D_state_dict': optimizer_D.state_dict(),
        }, save_file)
        print(f"✅ Model saved at epoch {epoch+1} (Best Generator Loss: {best_g_loss:.4f})")

    # Save the model every X epochs
    if (epoch + 1) % save_every == 0:
        save_file = os.path.join(save_path, f"epoch_{epoch+1}_Gloss_{avg_g_loss:.4f}_Dloss_{avg_d_loss:.4f}.pth")
        torch.save({
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_G_state_dict': optimizer_G.state_dict(),
            'optimizer_D_state_dict': optimizer_D.state_dict(),
        }, save_file)
        print(f"💾 Saved model at epoch {epoch+1}")


Training inputs shape: (3510447, 200)
Training targets shape: (3510447, 200)


Epoch 1/50:  13%|█▎        | 923/6857 [06:36<42:31,  2.33it/s]


KeyboardInterrupt: 

This snippet defines a scaled-down version of our model architecture. Due to hardware limitations, we reduced the model's size—using a smaller embedding dimension, fewer layers, and other optimizations—to ensure efficient training and inference while still meeting the project's core


In [None]:
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import pickle

###########################
# Definition of the small models
###########################

class SmallMusicGenerator(nn.Module):
    def __init__(self, n_vocab, sequence_length, embed_dim=128, num_heads=4, num_layers=3):
        super(SmallMusicGenerator, self).__init__()
        self.embedding = nn.Embedding(n_vocab, embed_dim)
        # Build positional encoding for double the input length
        self.register_buffer("positional_encoding", self._generate_positional_encoding(sequence_length * 2, embed_dim))
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=256,  # Reduced dimension for the FFN
                dropout=0.3,
                batch_first=True,
                norm_first=True
            ),
            num_layers=num_layers
        )
        self.fc = nn.Linear(embed_dim, n_vocab)

    def _generate_positional_encoding(self, seq_len, embed_dim):
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
        pos_enc = torch.zeros(seq_len, embed_dim)
        pos_enc[:, 0::2] = torch.sin(position * div_term)
        pos_enc[:, 1::2] = torch.cos(position * div_term)
        return pos_enc.unsqueeze(0)  # Add a batch dimension

    def forward(self, x):
        embedded = self.embedding(x)  # [B, sequence_length, embed_dim]
        seq_len = x.size(1)
        pos_enc = self.positional_encoding[:, :seq_len, :]
        x = embedded + pos_enc
        x = self.transformer(x)
        x = self.fc(x)  # [B, sequence_length, n_vocab]
        return x

class SmallMusicDiscriminator(nn.Module):
    def __init__(self, n_vocab, sequence_length, embed_dim=128):
        super(SmallMusicDiscriminator, self).__init__()
        self.embedding = nn.Embedding(n_vocab, embed_dim)
        self.fc = nn.Sequential(
            nn.Linear(sequence_length * embed_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1)  # No Sigmoid applied because we use BCEWithLogitsLoss
        )

    def forward(self, x):
        x = self.embedding(x)  # [B, sequence_length, embed_dim]
        x = x.view(x.size(0), -1)  # Flatten
        x = self.fc(x)  # [B, 1]
        return x


In [None]:
####################################
# Model Initialization
####################################
# Assume that note_to_int has already been loaded, and compute n_vocab
# with open('/content/drive/My Drive/note_to_int.pkl', 'rb') as f:
#     note_to_int = pickle.load(f)
# n_vocab = len(note_to_int)

# Assume that train_inputs has already been loaded (for example, from the saved training data)
# with open('/content/drive/My Drive/training_input.pkl', 'rb') as f:
#     train_inputs = pickle.load(f)
# with open('/content/drive/My Drive/training_target.pkl', 'rb') as f:
#     train_targets = pickle.load(f)

print("Training inputs shape:", train_inputs.shape)
print("Training targets shape:", train_targets.shape)

# Initialize the models
sequence_length = train_inputs.shape[1]  # should be 200
generator = SmallMusicGenerator(n_vocab=n_vocab, sequence_length=sequence_length)
discriminator = SmallMusicDiscriminator(n_vocab=n_vocab, sequence_length=sequence_length)

# Move models to the appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = generator.to(device)
discriminator = discriminator.to(device)
generator.positional_encoding = generator.positional_encoding.to(device)

print("Small Generator:")
print(generator)
print("Small Discriminator:")
print(discriminator)


NameError: name 'train_inputs' is not defined

Due to hardware constraints, we also implemented a smaller model variant and trained it accordingly. The training procedure mirrors the full model's loop—with periodic checkpointing and loss monitoring—but uses reduced parameters and layers to decrease computational load while still meeting the project's objectives.

In [None]:
####################################
# Preparing Dataset and DataLoader
####################################
dataset = TensorDataset(torch.tensor(train_inputs, dtype=torch.long),
                        torch.tensor(train_targets, dtype=torch.long))
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

####################################
# Defining Loss Functions and Optimizers
####################################
# Use BCEWithLogitsLoss for the discriminator
bce_loss = nn.BCEWithLogitsLoss()
ce_loss = nn.CrossEntropyLoss()
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0001)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0001)

# Set up AMP (Automatic Mixed Precision)
scaler = GradScaler()

best_g_loss = float('inf')

####################################
# Training Loop
####################################
for epoch in range(epochs):
    generator.train()
    discriminator.train()

    epoch_g_loss = 0.0
    epoch_d_loss = 0.0

    for batch_input, batch_target in tqdm(data_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        batch_input = batch_input.to(device).long()
        batch_target = batch_target.to(device).long()

        ## === Updating Discriminator === ##
        optimizer_D.zero_grad()
        real_labels = torch.ones(batch_input.shape[0], 1).to(device)
        fake_labels = torch.zeros(batch_input.shape[0], 1).to(device)

        with autocast():
            # For real data: pass batch_target (which is a sequence of indices)
            real_validity = discriminator(batch_target)  # [B, 1] (raw logits)
            loss_real = bce_loss(real_validity, real_labels)

            # For fake data:
            fake_logits = generator(batch_input)  # [B, sequence_length, n_vocab]
            fake_tokens = torch.argmax(fake_logits, dim=-1).long()  # [B, sequence_length]
            fake_validity = discriminator(fake_tokens.detach())  # [B, 1]
            loss_fake = bce_loss(fake_validity, fake_labels)

            d_loss = (loss_real + loss_fake) / 2.0

        scaler.scale(d_loss).backward()
        scaler.step(optimizer_D)
        scaler.update()

        ## === Updating Generator === ##
        optimizer_G.zero_grad()
        with autocast():
            fake_logits = generator(batch_input)  # [B, sequence_length, n_vocab]
            g_teacher_loss = ce_loss(fake_logits.view(-1, n_vocab), batch_target.view(-1))

            fake_tokens = torch.argmax(fake_logits, dim=-1).long()
            validity = discriminator(fake_tokens.detach())
            g_adv_loss = bce_loss(validity, real_labels)

            g_loss = g_teacher_loss + 0.002 * g_adv_loss

        scaler.scale(g_loss).backward()
        scaler.step(optimizer_G)
        scaler.update()

        epoch_g_loss += g_loss.item()
        epoch_d_loss += d_loss.item()

    avg_g_loss = epoch_g_loss / len(data_loader)
    avg_d_loss = epoch_d_loss / len(data_loader)
    print(f"Epoch {epoch+1}/{epochs}, Generator Loss: {avg_g_loss:.4f}, Discriminator Loss: {avg_d_loss:.4f}")

    if avg_g_loss < best_g_loss:
        best_g_loss = avg_g_loss
        save_file = os.path.join(save_path, f"best_model_Gloss_{avg_g_loss:.4f}.pth")
        torch.save({
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_G_state_dict': optimizer_G.state_dict(),
            'optimizer_D_state_dict': optimizer_D.state_dict(),
        }, save_file)
        print(f"✅ Model saved at epoch {epoch+1} (Best Generator Loss: {best_g_loss:.4f})")

    if (epoch + 1) % save_every == 0:
        save_file = os.path.join(save_path, f"epoch_{epoch+1}_Gloss_{avg_g_loss:.4f}_Dloss_{avg_d_loss:.4f}.pth")
        torch.save({
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_G_state_dict': optimizer_G.state_dict(),
            'optimizer_D_state_dict': optimizer_D.state_dict(),
        }, save_file)
        print(f"💾 Saved model at epoch {epoch+1}")


Training inputs shape: (3510447, 200)
Training targets shape: (3510447, 200)
Small Generator:
SmallMusicGenerator(
  (embedding): Embedding(11050, 128)
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-2): 3 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=256, bias=True)
        (dropout): Dropout(p=0.3, inplace=False)
        (linear2): Linear(in_features=256, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.3, inplace=False)
        (dropout2): Dropout(p=0.3, inplace=False)
      )
    )
  )
  (fc): Linear(in_features=128, out_features=11050, bias=True)
)
Small Discriminator:
SmallMusicDiscriminator(
  (embedding): Embedding(110

  scaler = GradScaler()
  with autocast():
  with autocast():
Epoch 1/50: 100%|██████████| 6857/6857 [11:01<00:00, 10.37it/s]


Epoch 1/50, Generator Loss: 2.2329, Discriminator Loss: 0.3424
✅ Model saved at epoch 1 (Best Generator Loss: 2.2329)


Epoch 2/50: 100%|██████████| 6857/6857 [10:57<00:00, 10.43it/s]


Epoch 2/50, Generator Loss: 0.1594, Discriminator Loss: 0.4201
✅ Model saved at epoch 2 (Best Generator Loss: 0.1594)


Epoch 3/50: 100%|██████████| 6857/6857 [10:58<00:00, 10.41it/s]


Epoch 3/50, Generator Loss: 0.1027, Discriminator Loss: 0.3669
✅ Model saved at epoch 3 (Best Generator Loss: 0.1027)


Epoch 4/50: 100%|██████████| 6857/6857 [10:56<00:00, 10.45it/s]


Epoch 4/50, Generator Loss: 0.0844, Discriminator Loss: 0.3509
✅ Model saved at epoch 4 (Best Generator Loss: 0.0844)


Epoch 5/50: 100%|██████████| 6857/6857 [10:56<00:00, 10.45it/s]


Epoch 5/50, Generator Loss: 0.0771, Discriminator Loss: 0.3426
✅ Model saved at epoch 5 (Best Generator Loss: 0.0771)
💾 Saved model at epoch 5


Epoch 6/50: 100%|██████████| 6857/6857 [10:55<00:00, 10.46it/s]


Epoch 6/50, Generator Loss: 0.0718, Discriminator Loss: 0.3512
✅ Model saved at epoch 6 (Best Generator Loss: 0.0718)


Epoch 7/50: 100%|██████████| 6857/6857 [10:58<00:00, 10.42it/s]


Epoch 7/50, Generator Loss: 0.0653, Discriminator Loss: 0.3562
✅ Model saved at epoch 7 (Best Generator Loss: 0.0653)


Epoch 8/50: 100%|██████████| 6857/6857 [11:00<00:00, 10.39it/s]


Epoch 8/50, Generator Loss: 0.0539, Discriminator Loss: 0.3455
✅ Model saved at epoch 8 (Best Generator Loss: 0.0539)


Epoch 9/50: 100%|██████████| 6857/6857 [11:00<00:00, 10.39it/s]


Epoch 9/50, Generator Loss: 0.0477, Discriminator Loss: 0.3255
✅ Model saved at epoch 9 (Best Generator Loss: 0.0477)


Epoch 10/50: 100%|██████████| 6857/6857 [10:57<00:00, 10.43it/s]


Epoch 10/50, Generator Loss: 0.0453, Discriminator Loss: 0.3128
✅ Model saved at epoch 10 (Best Generator Loss: 0.0453)
💾 Saved model at epoch 10


Epoch 11/50: 100%|██████████| 6857/6857 [10:54<00:00, 10.48it/s]


Epoch 11/50, Generator Loss: 0.0439, Discriminator Loss: 0.3037
✅ Model saved at epoch 11 (Best Generator Loss: 0.0439)


Epoch 12/50: 100%|██████████| 6857/6857 [10:56<00:00, 10.44it/s]


Epoch 12/50, Generator Loss: 0.0430, Discriminator Loss: 0.2962
✅ Model saved at epoch 12 (Best Generator Loss: 0.0430)


Epoch 13/50: 100%|██████████| 6857/6857 [10:55<00:00, 10.46it/s]


Epoch 13/50, Generator Loss: 0.0424, Discriminator Loss: 0.2903
✅ Model saved at epoch 13 (Best Generator Loss: 0.0424)


Epoch 14/50: 100%|██████████| 6857/6857 [10:59<00:00, 10.39it/s]


Epoch 14/50, Generator Loss: 0.0419, Discriminator Loss: 0.2845
✅ Model saved at epoch 14 (Best Generator Loss: 0.0419)


Epoch 15/50: 100%|██████████| 6857/6857 [10:58<00:00, 10.41it/s]


Epoch 15/50, Generator Loss: 0.0416, Discriminator Loss: 0.2794
✅ Model saved at epoch 15 (Best Generator Loss: 0.0416)
💾 Saved model at epoch 15


Epoch 16/50: 100%|██████████| 6857/6857 [10:59<00:00, 10.39it/s]


Epoch 16/50, Generator Loss: 0.0413, Discriminator Loss: 0.2753
✅ Model saved at epoch 16 (Best Generator Loss: 0.0413)


Epoch 17/50: 100%|██████████| 6857/6857 [11:02<00:00, 10.35it/s]


Epoch 17/50, Generator Loss: 0.0410, Discriminator Loss: 0.2714
✅ Model saved at epoch 17 (Best Generator Loss: 0.0410)


Epoch 18/50: 100%|██████████| 6857/6857 [11:03<00:00, 10.34it/s]


Epoch 18/50, Generator Loss: 0.0408, Discriminator Loss: 0.2680
✅ Model saved at epoch 18 (Best Generator Loss: 0.0408)


Epoch 19/50:   9%|▉         | 641/6857 [01:01<10:00, 10.34it/s]


KeyboardInterrupt: 

This snippet demonstrates how to load a saved checkpoint containing model weights and perform a quick test generation. It loads the checkpoint, restores the generator and discriminator states, and sets both models to evaluation mode. A dummy input sequence is then created and passed through the generator. The output logits are converted to probabilities, and the most likely token is chosen at each step. Finally, the predicted token indices are mapped back to their corresponding token strings, and a sample sequence is printed for inspection.

In [None]:
import torch
import pickle
from tqdm import tqdm

# Define the checkpoint file path
checkpoint_path = "/content/drive/My Drive/checkpoints/best_model_Gloss_0.0319.pth"

# Define the device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the checkpoint file
checkpoint = torch.load(checkpoint_path, map_location=device)

# Load the weights into the models (assuming generator and discriminator already exist)
generator.load_state_dict(checkpoint['generator_state_dict'])
discriminator.load_state_dict(checkpoint['discriminator_state_dict'])

# Move the models to the device and set them to evaluation mode
generator = generator.to(device)
discriminator = discriminator.to(device)
generator.eval()
discriminator.eval()

print(f"✅ Checkpoint loaded from {checkpoint_path}")
print("Generator in eval mode:")
print(generator)
print("\nDiscriminator in eval mode:")
print(discriminator)

# Perform an initial test:
# Create a random input from the training data array (if available) or use a dummy example.
# Assume that sequence_length is 200.
dummy_input = torch.randint(0, n_vocab, (1, 200), dtype=torch.long).to(device)
with torch.no_grad():
    output_logits = generator(dummy_input)  # Shape: [1, 200, n_vocab]
    output_probs = torch.softmax(output_logits, dim=-1)
    # Choose the token with the highest probability at each step
    predicted_tokens = torch.argmax(output_probs, dim=-1)

    # Convert tokens to words using int_to_note
    predicted_sequence = [int_to_note[idx.item()] for idx in predicted_tokens[0]]

print("\n✅ Sample generated sequence (first 50 tokens):")
print(predicted_sequence[:50])


  checkpoint = torch.load(checkpoint_path, map_location=device)


NameError: name 'generator' is not defined

This script provides two main functionalities:

Sequence Generation:

- The generate_sequence function takes a seed token sequence and uses the trained generator model to extend it until it reaches a target length.
- It uses temperature-controlled multinomial sampling to pick the next token from the generator’s output, ensuring variability in the generated sequence.

Token-to-Music Conversion with music21:
- A pitch class mapping converts numerical values to note names.
- The parse_duration function converts duration tokens (in 16th-note units) to standard quarter lengths.
- The token_to_music21 function converts individual tokens into music21 objects (notes, chords, or rests) based on their prefix.
- The build_music21_stream function assembles a full music21 stream from a token sequence by first extracting context tokens (tempo, key, time signature) and then appending the musical events, enabling playback or further analysis.

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import re
import music21
from music21 import stream, note, chord, tempo, key, meter

# =============================================================================
# Function to generate a long sequence using probabilistic sampling
# =============================================================================
def generate_sequence(generator, seed_sequence, target_length, temperature=1.0):
    """
    Receives a seed_sequence (a list of indices) and continues to generate tokens until the sequence reaches target_length.
    The temperature parameter affects the diversity of the choices.
    """
    generator.eval()
    generated = seed_sequence.copy()
    while len(generated) < target_length:
        # Select the last window of input of length sequence_length
        current_window = generated[-sequence_length:]
        input_seq = torch.tensor([current_window], dtype=torch.long).to(device)
        with torch.no_grad():
            logits = generator(input_seq)  # Shape: [1, sequence_length, n_vocab]
        logits_last = logits[0, -1, :] / temperature
        probabilities = torch.softmax(logits_last, dim=-1)
        next_token = torch.multinomial(probabilities, num_samples=1).item()
        generated.append(next_token)
    return generated

# =============================================================================
# Functions to convert tokens to music using music21
# =============================================================================

# Definition of the pitch class mapping
pitch_class_map = {
    0: "C", 1: "C#", 2: "D", 3: "D#",
    4: "E", 5: "F", 6: "F#", 7: "G",
    8: "G#", 9: "A", 10: "A#", 11: "B"
}

def parse_duration(dur_str, scale=1.0):
    """
    The duration is given as a number of 16th note units; divide by 4 to get quarterLength.
    For example, '8' -> 8/4 = 2.0 quarterLength.
    With scale=1.0, the duration of each event remains as is.
    """
    try:
        dur = float(dur_str)
    except:
        dur = 1.0
    return (dur / 4.0) * scale

def token_to_music21(token):
    """
    Converts a single token to a music21 object:
      - NOTE_<pitch>_<dur>: creates a note.Note object
      - CHORD_<chord_str>_<dur>: creates a chord.Chord object
      - REST_<dur>: creates a note.Rest object
    """
    if token.startswith("NOTE_"):
        parts = token.split("_")
        if len(parts) < 3:
            return None
        pitch_str = parts[1]
        dur = parse_duration(parts[2], scale=1.0)
        n_obj = note.Note(pitch_str)
        n_obj.quarterLength = dur
        return n_obj
    elif token.startswith("CHORD_"):
        parts = token.split("_")
        if len(parts) < 3:
            return None
        chord_numbers_str = parts[1]
        dur = parse_duration(parts[2], scale=1.0)
        numbers = chord_numbers_str.split(".")
        pitches = []
        for num_str in numbers:
            try:
                num = int(num_str)
                pc = num % 12
                pitch_name = pitch_class_map.get(pc, "C")
                pitches.append(f"{pitch_name}4")
            except:
                continue
        if pitches:
            c_obj = chord.Chord(pitches)
            c_obj.quarterLength = dur
            return c_obj
        else:
            return None
    elif token.startswith("REST_"):
        parts = token.split("_")
        if len(parts) < 2:
            return None
        dur = parse_duration(parts[1], scale=1.0)
        r_obj = note.Rest()
        r_obj.quarterLength = dur
        return r_obj
    else:
        return None

def build_music21_stream(token_sequence):
    """
    Receives a list of tokens (a sequence) and converts it to a music21 stream (music21.stream.Stream).
    Assumes that the first 3 tokens provide context: TEMPO, KEY, and TIME.
    """
    s = stream.Stream()

    if len(token_sequence) >= 3:
        tempo_token = token_sequence[0]  # e.g., TEMPO_120.0
        key_token = token_sequence[1]    # e.g., KEY_F#_major
        time_token = token_sequence[2]   # e.g., TIME_4_4

        # Process tempo
        m = re.match(r"TEMPO_(\d+(\.\d+)?)", tempo_token)
        if m:
            t = tempo.MetronomeMark(number=float(m.group(1)))
            s.insert(0, t)

        # Process key: split by "_" and convert to format like "F# major"
        m = re.match(r"KEY_(.+)", key_token)
        if m:
            # First, replace "_" with space
            k_str = m.group(1).replace("_", " ").strip()
            try:
                # Try to create the Key object directly
                k_obj = key.Key(k_str)
                s.insert(0, k_obj)
            except Exception as e:
                # In case of error, try replacing " sharp" back to "#"
                k_str_fixed = k_str.replace(" sharp", "#")
                try:
                    k_obj = key.Key(k_str_fixed)
                    s.insert(0, k_obj)
                except Exception as e2:
                    print("Error processing key:", k_str, e2)

        # Process time signature
        m = re.match(r"TIME_(.+)", time_token)
        if m:
            ts_str = m.group(1).replace("_", "/")
            try:
                ts_obj = meter.TimeSignature(ts_str)
                s.insert(0, ts_obj)
            except Exception as e:
                print("Error processing time signature:", e)

        events = token_sequence[3:]
    else:
        events = token_sequence

    for token in events:
        m21_event = token_to_music21(token)
        if m21_event is not None:
            s.append(m21_event)

    return s


This example demonstrates the end-to-end process of generating a new musical sequence and converting it to a MIDI file. Starting with a seed sequence from the dataset, the code uses the trained generator model to extend the sequence to a specified target length (e.g., 900 tokens) using temperature-controlled sampling. The generated token indices are then mapped back to their corresponding token strings. Finally, these tokens are converted into a music21 stream—incorporating tempo, key, and time signature context—and written out as a MIDI file for playback or further analysis.

In [None]:
# =============================================================================
# Example usage:
# Assume that train_input, int_to_note, device, n_vocab, and sequence_length are already defined.
# Select a seed from the dataset (for example, the first sequence)
seed_sequence = train_input[0].tolist()

# Define a higher target_length – for instance, 1500 tokens to obtain a composition lasting about 1:30 minutes
target_length = 900

# Generate the sequence
generated_sequence = generate_sequence(generator, seed_sequence, target_length, temperature=1.0)
generated_tokens = [int_to_note[token] for token in generated_sequence]

print("Generated sequence length:", len(generated_tokens))
print("First 50 tokens:", generated_tokens[:50])

# Convert the sequence to a music stream and generate a MIDI file
output_midi_path = '/content/drive/My Drive/generated_music.mid'
music_stream = build_music21_stream(generated_tokens)
music_stream.write('midi', fp=output_midi_path)
print("✅ MIDI file generated successfully:", output_midi_path)


Generated sequence length: 900
First 50 tokens: ['TEMPO_120.0', 'KEY_F#_major', 'TIME_4_4', 'REST_8', 'REST_8', 'REST_10', 'CHORD_1.6_5', 'NOTE_B-4_8', 'NOTE_B-3_1', 'NOTE_C#4_1', 'CHORD_1.6_1', 'CHORD_10.1_1', 'REST_3', 'REST_0', 'CHORD_10.1_1', 'CHORD_1.6_2', 'REST_4', 'REST_8', 'REST_14', 'NOTE_B-4_1', 'NOTE_B-3_1', 'CHORD_6.10.1_3', 'CHORD_10.1_1', 'NOTE_F#4_3', 'REST_1', 'CHORD_10.1_1', 'CHORD_1.6_1', 'NOTE_B-4_3', 'REST_3', 'CHORD_10.1_1', 'NOTE_F#4_4', 'REST_3', 'NOTE_C#4_1', 'NOTE_B-3_1', 'NOTE_E-3_3', 'NOTE_E-2_3', 'NOTE_B-4_2', 'REST_0', 'CHORD_10.3_1', 'REST_1', 'REST_9', 'NOTE_E-3_6', 'NOTE_E-2_8', 'NOTE_B-4_5', 'CHORD_3.6_1', 'CHORD_10.3_1', 'REST_1', 'CHORD_10.3_1', 'NOTE_F#4_1', 'CHORD_10.3_1']
Error processing key: F# major # ajor is not a supported accidental type
✅ MIDI file generated successfully: /content/drive/My Drive/generated_music.mid


OUTPUT READY: PROCEEDING WITH TRAINING IMPROVEMENTS
**bold text**

This snippet loads an existing checkpoint to resume training if needed. It restores the generator and discriminator weights, along with their optimizer states, ensuring that training can continue seamlessly from the saved epoch without losing progress

In [None]:
import os
import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import pickle

# -------------------------------
# Basic configurations and data loading
# -------------------------------
save_path = "/content/drive/My Drive/checkpoints/"
checkpoint_path = os.path.join(save_path, "best_model_Gloss_0.0408.pth")

# Load the training data (if not already loaded)
with open('/content/drive/My Drive/training_input.pkl', 'rb') as f:
    train_inputs = pickle.load(f)
with open('/content/drive/My Drive/training_target.pkl', 'rb') as f:
    train_targets = pickle.load(f)

print("Training inputs shape:", train_inputs.shape)
print("Training targets shape:", train_targets.shape)

# Create DataLoader
dataset = TensorDataset(torch.tensor(train_inputs, dtype=torch.long),
                        torch.tensor(train_targets, dtype=torch.long))
batch_size = 512
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# -------------------------------
# Loading models and weights
# -------------------------------
# Assume that the models generator and discriminator have already been defined
# (SmallMusicGenerator, SmallMusicDiscriminator) and that note_to_int has been loaded,
# so we calculate n_vocab:
with open('/content/drive/My Drive/note_to_int.pkl', 'rb') as f:
    note_to_int = pickle.load(f)
n_vocab = len(note_to_int)
sequence_length = train_inputs.shape[1]  # for example, 200

# Initialize the models (if not already initialized)
generator = SmallMusicGenerator(n_vocab=n_vocab, sequence_length=sequence_length)
discriminator = SmallMusicDiscriminator(n_vocab=n_vocab, sequence_length=sequence_length)

# Move models to the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = generator.to(device)
discriminator = discriminator.to(device)
generator.positional_encoding = generator.positional_encoding.to(device)

# Load weights from the checkpoint (epoch 18)
checkpoint = torch.load(checkpoint_path, map_location=device)
generator.load_state_dict(checkpoint['generator_state_dict'])
discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
# Load optimizer states
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0001)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0001)
optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])

# Set models to training mode
generator.train()
discriminator.train()

print(f"✅ Checkpoint loaded from {checkpoint_path}. Continuing training from this point...")

# -------------------------------
# Define loss functions and additional optimizers, and set up AMP
# -------------------------------
bce_loss = nn.BCEWithLogitsLoss()   # for the discriminator
ce_loss = nn.CrossEntropyLoss()     # for Teacher Forcing
scaler = GradScaler()

best_g_loss = float('inf')
start_epoch = 18        # assume the checkpoint is from epoch 18
epochs_to_run = 50      # for example, continue for another 50 epochs

# -------------------------------
# Training Loop - Continue training from the checkpoint
# -------------------------------
for epoch in range(start_epoch, start_epoch + epochs_to_run):
    generator.train()
    discriminator.train()

    epoch_g_loss = 0.0
    epoch_d_loss = 0.0

    for batch_input, batch_target in tqdm(data_loader, desc=f"Epoch {epoch+1}/{start_epoch+epochs_to_run}"):
        batch_input = batch_input.to(device).long()
        batch_target = batch_target.to(device).long()

        ## === Update Discriminator === ##
        optimizer_D.zero_grad()
        real_labels = torch.ones(batch_input.shape[0], 1).to(device)
        fake_labels = torch.zeros(batch_input.shape[0], 1).to(device)

        with autocast():
            # For real data: we pass batch_target (which is a sequence of indices)
            real_validity = discriminator(batch_target)  # [B, 1] (raw logits, no Sigmoid because using BCEWithLogitsLoss)
            loss_real = bce_loss(real_validity, real_labels)

            # For fake data:
            fake_logits = generator(batch_input)  # [B, sequence_length, n_vocab]
            fake_tokens = torch.argmax(fake_logits, dim=-1).long()  # [B, sequence_length]
            fake_validity = discriminator(fake_tokens.detach())
            loss_fake = bce_loss(fake_validity, fake_labels)

            d_loss = (loss_real + loss_fake) / 2.0

        scaler.scale(d_loss).backward()
        scaler.step(optimizer_D)
        scaler.update()

        ## === Update Generator === ##
        optimizer_G.zero_grad()
        with autocast():
            fake_logits = generator(batch_input)
            g_teacher_loss = ce_loss(fake_logits.view(-1, n_vocab), batch_target.view(-1))

            fake_tokens = torch.argmax(fake_logits, dim=-1).long()
            validity = discriminator(fake_tokens.detach())
            g_adv_loss = bce_loss(validity, real_labels)

            g_loss = g_teacher_loss + 0.002 * g_adv_loss

        scaler.scale(g_loss).backward()
        scaler.step(optimizer_G)
        scaler.update()

        epoch_g_loss += g_loss.item()
        epoch_d_loss += d_loss.item()

    avg_g_loss = epoch_g_loss / len(data_loader)
    avg_d_loss = epoch_d_loss / len(data_loader)
    print(f"Epoch {epoch+1}/{start_epoch+epochs_to_run}, Generator Loss: {avg_g_loss:.4f}, Discriminator Loss: {avg_d_loss:.4f}")

    if avg_g_loss < best_g_loss:
        best_g_loss = avg_g_loss
        best_save_file = os.path.join(save_path, f"best_model_Gloss_{avg_g_loss:.4f}.pth")
        torch.save({
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_G_state_dict': optimizer_G.state_dict(),
            'optimizer_D_state_dict': optimizer_D.state_dict(),
        }, best_save_file)
        print(f"✅ Best model saved at epoch {epoch+1} (Best Generator Loss: {best_g_loss:.4f})")

    if (epoch + 1) % 5 == 0:
        save_file = os.path.join(save_path, f"epoch_{epoch+1}_Gloss_{avg_g_loss:.4f}_Dloss_{avg_d_loss:.4f}.pth")
        torch.save({
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_G_state_dict': optimizer_G.state_dict(),
            'optimizer_D_state_dict': optimizer_D.state_dict(),
        }, save_file)
        print(f"💾 Saved model at epoch {epoch+1}")


FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/My Drive/training_input.pkl'

load weight and eval

In [None]:
import torch
import pickle
from tqdm import tqdm
import os
import numpy as np

# Set device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the mappings (note_to_int and int_to_note) – ensure these files exist
with open('/content/drive/My Drive/note_to_int.pkl', 'rb') as f:
    note_to_int = pickle.load(f)
with open('/content/drive/My Drive/int_to_note.pkl', 'rb') as f:
    int_to_note = pickle.load(f)
n_vocab = len(note_to_int)

# Set sequence length (for example, 200 tokens)
sequence_length = 200
# Initialize the models
generator = SmallMusicGenerator(n_vocab=n_vocab, sequence_length=sequence_length)
discriminator = SmallMusicDiscriminator(n_vocab=n_vocab, sequence_length=sequence_length)

# Move models to the appropriate device
generator = generator.to(device)
discriminator = discriminator.to(device)
generator.positional_encoding = generator.positional_encoding.to(device)

# Load the checkpoint (using weights_only=True to avoid FutureWarning)
checkpoint_path = "/content/drive/My Drive/checkpoints/best_model_Gloss_0.0319.pth"  # update this path accordingly
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
generator.load_state_dict(checkpoint['generator_state_dict'])
discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
print("✅ Checkpoint loaded successfully.")

# Set models to evaluation mode
generator.eval()
discriminator.eval()

# Load the training data (choose a random example for evaluation)
with open('/content/drive/My Drive/training_target.pkl', 'rb') as f:
    train_inputs = pickle.load(f)

print("Training inputs shape:", train_inputs.shape)

# Select a random example
sample_idx = np.random.randint(0, train_inputs.shape[0])
sample_input = torch.tensor(train_inputs[sample_idx:sample_idx+1], dtype=torch.long).to(device)

# Generate output from the generator in evaluation mode
with torch.no_grad():
    output_logits = generator(sample_input)  # Output shape: [1, sequence_length, n_vocab]
    predicted_indices = torch.argmax(output_logits, dim=-1)  # [1, sequence_length]

# Convert indices to tokens
predicted_tokens = [int_to_note[idx.item()] for idx in predicted_indices[0]]
print("Predicted sequence:")
print(predicted_tokens)


✅ Checkpoint loaded successfully.
Training inputs shape: (3510447, 200)
Predicted sequence:
['CHORD_10.11.1_2', 'REST_1', 'CHORD_8.11.1_1', 'REST_1', 'CHORD_5.8.11.1_1', 'REST_2', 'REST_12', 'CHORD_5.8.11.1_1', 'REST_2', 'CHORD_6.10.1_1', 'REST_1', 'NOTE_G1_1', 'REST_2', 'CHORD_10.11.1.5_1', 'NOTE_B-1_1', 'REST_1', 'NOTE_F2_1', 'REST_3', 'REST_1', 'REST_1', 'REST_7', 'REST_14', 'REST_14', 'CHORD_3.6.10_3', 'REST_3', 'CHORD_5.8.0_4', 'CHORD_5_1', 'REST_1', 'NOTE_G#1_1', 'NOTE_G#2_1', 'REST_3', 'REST_2', 'NOTE_E-5_2', 'CHORD_3.8_2', 'NOTE_C#5_2', 'CHORD_1.3_2', 'REST_7', 'REST_9', 'REST_10', 'REST_10', 'REST_10', 'NOTE_E-5_9', 'CHORD_3.8_4', 'NOTE_C#5_7', 'CHORD_1.3_4', 'CHORD_8_1', 'REST_1', 'CHORD_3_2', 'CHORD_8.0_7', 'CHORD_3.8_4', 'NOTE_C4_3', 'NOTE_G3_1', 'REST_5', 'REST_4', 'REST_3', 'REST_2', 'CHORD_5_1', 'REST_1', 'REST_5', 'REST_5', 'REST_11', 'CHORD_5_2', 'CHORD_8.0_2', 'CHORD_1_1', 'REST_3', 'NOTE_F3_3', 'CHORD_1.5.8_6', 'CHORD_1.5_6', 'REST_3', 'NOTE_E-3_3', 'NOTE_E-2_5', 'RE

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import re
import music21
from music21 import stream, note, chord, tempo, key, meter
import random

# =============================================================================
# Function to generate a long sequence using probabilistic sampling
# =============================================================================
def generate_sequence(generator, seed_sequence, target_length, temperature=1.0):
    """
    Receives a seed_sequence (a list of indices) and continues to generate tokens until the sequence reaches target_length.
    The temperature parameter influences the diversity of the selection.
    """
    generator.eval()
    generated = seed_sequence.copy()
    while len(generated) < target_length:
        # Select the last input window of length sequence_length
        current_window = generated[-sequence_length:]
        input_seq = torch.tensor([current_window], dtype=torch.long).to(device)
        with torch.no_grad():
            logits = generator(input_seq)  # Shape: [1, sequence_length, n_vocab]
        logits_last = logits[0, -1, :] / temperature
        probabilities = torch.softmax(logits_last, dim=-1)
        next_token = torch.multinomial(probabilities, num_samples=1).item()
        generated.append(next_token)
    return generated

# =============================================================================
# Functions to convert tokens to music using music21
# =============================================================================

# Define pitch class mapping
pitch_class_map = {
    0: "C", 1: "C#", 2: "D", 3: "D#",
    4: "E", 5: "F", 6: "F#", 7: "G",
    8: "G#", 9: "A", 10: "A#", 11: "B"
}

def parse_duration(dur_str, scale=1.0):
    """
    The duration is given as a number of 16th units; we divide by 4 to get quarterLength.
    For example, '8' → 8/4 = 2.0 quarterLength.
    With scale=1.0 – meaning the duration of each event remains as is.
    """
    try:
        dur = float(dur_str)
    except:
        dur = 1.0
    return (dur / 4.0) * scale

def token_to_music21(token):
    """
    Converts a single token to a music21 object:
      - NOTE_<pitch>_<dur>: creates a note.Note object
      - CHORD_<chord_str>_<dur>: creates a chord.Chord object
      - REST_<dur>: creates a note.Rest object
    """
    if token.startswith("NOTE_"):
        parts = token.split("_")
        if len(parts) < 3:
            return None
        pitch_str = parts[1]
        dur = parse_duration(parts[2], scale=1.0)
        n_obj = note.Note(pitch_str)
        n_obj.quarterLength = dur
        return n_obj
    elif token.startswith("CHORD_"):
        parts = token.split("_")
        if len(parts) < 3:
            return None
        chord_numbers_str = parts[1]
        dur = parse_duration(parts[2], scale=1.0)
        numbers = chord_numbers_str.split(".")
        pitches = []
        for num_str in numbers:
            try:
                num = int(num_str)
                pc = num % 12
                pitch_name = pitch_class_map.get(pc, "C")
                pitches.append(f"{pitch_name}4")
            except:
                continue
        if pitches:
            c_obj = chord.Chord(pitches)
            c_obj.quarterLength = dur
            return c_obj
        else:
            return None
    elif token.startswith("REST_"):
        parts = token.split("_")
        if len(parts) < 2:
            return None
        dur = parse_duration(parts[1], scale=1.0)
        r_obj = note.Rest()
        r_obj.quarterLength = dur
        return r_obj
    else:
        return None

def build_music21_stream(token_sequence):
    """
    Receives a list of tokens (a sequence) and converts it into a music21 stream (music21.stream.Stream).
    Assumes that the first 3 tokens provide context: TEMPO, KEY, and TIME.
    """
    s = stream.Stream()

    if len(token_sequence) >= 3:
        tempo_token = token_sequence[0]  # e.g., TEMPO_120.0
        key_token = token_sequence[1]    # e.g., KEY_F#_major
        time_token = token_sequence[2]   # e.g., TIME_4_4

        # Process tempo
        m = re.match(r"TEMPO_(\d+(\.\d+)?)", tempo_token)
        if m:
            t = tempo.MetronomeMark(number=float(m.group(1)))
            s.insert(0, t)

        # Process key: split by "_" and convert so the string is in the format "F# major"
        m = re.match(r"KEY_(.+)", key_token)
        if m:
            # First, replace "_" with a space
            k_str = m.group(1).replace("_", " ").strip()
            try:
                # Try to create the Key object directly
                k_obj = key.Key(k_str)
                s.insert(0, k_obj)
            except Exception as e:
                # In case of an error, try replacing " sharp" back to "#"
                k_str_fixed = k_str.replace(" sharp", "#")
                try:
                    k_obj = key.Key(k_str_fixed)
                    s.insert(0, k_obj)
                except Exception as e2:
                    print("Error processing key:", k_str, e2)

        # Process time signature
        m = re.match(r"TIME_(.+)", time_token)
        if m:
            ts_str = m.group(1).replace("_", "/")
            try:
                ts_obj = meter.TimeSignature(ts_str)
                s.insert(0, ts_obj)
            except Exception as e:
                print("Error processing time signature:", e)

        events = token_sequence[3:]
    else:
        events = token_sequence

    for token in events:
        m21_event = token_to_music21(token)
        if m21_event is not None:
            s.append(m21_event)

    return s

# =============================================================================
# Example usage:
# Assume that train_input, int_to_note, device, n_vocab, and sequence_length are already defined.
# Select a seed from the dataset (for example, the first sequence)
# seed_sequence = train_inputs[0].tolist()
seed_sequence = random.choice(train_inputs).tolist()
# Define a higher target_length – for example, 1500 tokens to obtain a composition of about 1:30 minutes
target_length = 500

# Generate the sequence
generated_sequence = generate_sequence(generator, seed_sequence, target_length, temperature=1.0)
generated_tokens = [int_to_note[token] for token in generated_sequence]

print("Generated sequence length:", len(generated_tokens))
print("First 50 tokens:", generated_tokens[:50])

# Convert the sequence to a music stream and generate a MIDI file
music_stream = build_music21_stream(generated_tokens)
output_midi_path = '/content/drive/My Drive/generated_music.mid'
music_stream.write('midi', fp=output_midi_path)
print("✅ MIDI file generated successfully:", output_midi_path)


Generated sequence length: 500
First 50 tokens: ['CHORD_2.4_1', 'REST_1', 'CHORD_10.2_1', 'NOTE_C#4_1', 'REST_0', 'REST_2', 'NOTE_E4_1', 'NOTE_D4_1', 'REST_3', 'REST_13', 'NOTE_B2_3', 'CHORD_11.1_1', 'REST_0', 'CHORD_11.1.2_1', 'REST_4', 'REST_1', 'NOTE_B4_3', 'NOTE_C#4_2', 'REST_1', 'NOTE_F#5_4', 'REST_2', 'NOTE_E-4_2', 'NOTE_A4_3', 'REST_1', 'REST_2', 'NOTE_E4_1', 'REST_1', 'REST_3', 'NOTE_E4_2', 'NOTE_A4_1', 'CHORD_9.11_1', 'REST_0', 'CHORD_8.9_1', 'CHORD_4.9_1', 'REST_1', 'REST_3', 'NOTE_B4_3', 'NOTE_F#3_2', 'REST_1', 'NOTE_D4_3', 'REST_2', 'NOTE_G#3_2', 'REST_1', 'REST_2', 'NOTE_B4_1', 'REST_0', 'NOTE_C#4_1', 'NOTE_A2_1', 'NOTE_A4_1', 'REST_2']
✅ MIDI file generated successfully: /content/drive/My Drive/generated_music.mid
