# Training a GPT-2 Model for Music Generation Using 1-Second Audio Chunks

This script trains a GPT-2-based model for music track generation using 1-second audio chunks, offering finer granularity compared to the usual 10-second training chunks. The dataset comprises 10-second audio embeddings, which are split into smaller 1-second sequential chunks (75 tokens each) for training. This approach enables the model to learn short-term musical patterns more effectively while still leveraging positional embeddings to capture temporal dependencies.

In [21]:
import os
import torch
import numpy as np
from torch.utils.data import Dataset
from transformers import GPT2LMHeadModel, GPT2Config, Trainer, TrainingArguments
from torch.optim import AdamW
from transformers import get_cosine_schedule_with_warmup
from sklearn.model_selection import train_test_split
import wandb
import torch.nn as nn

VOCAB_SIZE = 1026
CHUNK_LENGTH = 75  # 1 second = 75 tokens
track_classes = ['hi_hat']

# Load your saved .npy file
data = np.load(
    'processed_tracks_data_final_10secs_embeddings_standardized.npy',
    allow_pickle=True
).item()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class SequentialMusicDataset(Dataset):
    def __init__(self, data, track_class, chunk_length=CHUNK_LENGTH):
        """
        Initializes the dataset.

        Args:
            data (dict): Loaded data dictionary.
            track_class (str): The track class to process.
            chunk_length (int): Length of each chunk.
        """
        self.track_class = track_class
        self.chunk_length = chunk_length
        self.data = {
            k: v for k, v in data.items()
            if self.track_class in v['generation_data']
        }
        self.samples = []

        # Preprocess all samples to include positional embeddings
        for sample_id, sample in self.data.items():
            # Retrieve precomputed positional embeddings
            positional_embedding = sample.get('positional_embedding')
            if positional_embedding is None:
                print(f"Warning: No 'positional_embedding' for sample_id {sample_id}. Skipping.")
                continue

            # Get track data and split into chunks
            track_data = sample['generation_data'][self.track_class].flatten()
            vocal_audio_codes = sample['generation_data'].get('vocal', np.zeros((4, chunk_length)))

            # Clip values in audio codes to ensure valid range
            track_data = np.clip(track_data, 0, VOCAB_SIZE - 1)
            vocal_audio_codes = np.clip(vocal_audio_codes, 0, VOCAB_SIZE - 1)

            # Split data into chunks
            num_chunks = int(np.ceil(len(track_data) / self.chunk_length))
            track_chunks = np.array_split(track_data, num_chunks)
            vocal_chunks = np.array_split(vocal_audio_codes.flatten(), num_chunks)

            # Pad each chunk to CHUNK_LENGTH
            track_chunks = [
                np.pad(chunk, (0, self.chunk_length - len(chunk)), 'constant', constant_values=0)
                for chunk in track_chunks
            ]
            vocal_chunks = [
                np.pad(chunk, (0, self.chunk_length - len(chunk)), 'constant', constant_values=0)
                for chunk in vocal_chunks
            ]

            # Convert lists of arrays to single NumPy arrays for efficient tensor creation
            track_chunks = np.array(track_chunks)
            vocal_chunks = np.array(vocal_chunks)

            # Split positional_embeddings into chunks
            # Assuming positional_embedding has shape [MAX_LENGTH, embedding_dim]
            # and chunk_length corresponds to sequential segments
            # Here, we assume each chunk corresponds to 'CHUNK_LENGTH' tokens
            # Thus, positional_embeddings for a chunk are positional_embedding[i*CHUNK_LENGTH : (i+1)*CHUNK_LENGTH]
            positional_chunks = np.array_split(positional_embedding, num_chunks)
            positional_chunks = [
                np.pad(chunk, ((0, self.chunk_length - chunk.shape[0]), (0,0)), 'constant', constant_values=0)
                for chunk in positional_chunks
            ]
            positional_chunks = np.array(positional_chunks)
            positional_chunks = torch.tensor(positional_chunks, dtype=torch.float)  # [num_chunks, CHUNK_LENGTH, embedding_dim]

            # Append to samples
            for input_ids, labels, pos_emb in zip(vocal_chunks, track_chunks, positional_chunks):
                self.samples.append({
                    'input_ids': torch.tensor(input_ids, dtype=torch.long),
                    'labels': torch.tensor(labels, dtype=torch.long),
                    'positional_embeddings': pos_emb
                })

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


# Custom GPT-2 model with positional embeddings
class CustomGPT2ForConditionalGeneration(GPT2LMHeadModel):
    def __init__(self, config):
        super().__init__(config)
        # Projection layer to match positional embeddings to model's embedding size
        self.positional_proj = nn.Linear(128, config.n_embd)  # Assuming positional_embeddings have dim 64

    def forward(self, input_ids=None, attention_mask=None, labels=None,
                past_key_values=None, positional_embeddings=None, **kwargs):
        """
        Forward pass that incorporates positional embeddings.

        Args:
            input_ids (torch.Tensor): Input token IDs.
            attention_mask (torch.Tensor): Attention mask.
            labels (torch.Tensor): Labels for language modeling.
            past_key_values (tuple): Past key and value states.
            positional_embeddings (torch.Tensor): Positional embeddings to add to input embeddings.

        Returns:
            CausalLMOutputWithCrossAttentions: Model outputs.
        """
        # Get input embeddings
        input_embeds = self.transformer.wte(input_ids)  # [batch, seq_length, n_embd]

        if positional_embeddings is not None:
            # Project positional embeddings to match n_embd
            pos_emb_proj = self.positional_proj(positional_embeddings)  # [batch, seq_length, n_embd]
            input_embeds = input_embeds + pos_emb_proj  # [batch, seq_length, n_embd]

        # Proceed with the standard GPT-2 forward pass using inputs_embeds
        return super().forward(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            labels=labels,
            past_key_values=past_key_values,
            **kwargs
        )


# Custom data collator to handle batches
class DataCollatorWithPositionalEmbeddings:
    def __call__(self, batch):
        input_ids = torch.stack([item['input_ids'] for item in batch])  # [batch_size, CHUNK_LENGTH]
        labels = torch.stack([item['labels'] for item in batch])        # [batch_size, CHUNK_LENGTH]
        positional_embeddings = torch.stack([item['positional_embeddings'] for item in batch])  # [batch_size, CHUNK_LENGTH, embedding_dim]
        
        # Create attention mask: 1 where input_ids != 0, else 0
        attention_mask = (input_ids != 0).long()  # [batch_size, CHUNK_LENGTH]
        
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'positional_embeddings': positional_embeddings
        }


# Initialize the model configuration
config = GPT2Config(
    vocab_size=VOCAB_SIZE,
    n_positions=3000,  # Adjusted for maximum sequence length
    n_ctx=3000,         # Adjusted for maximum context length
    n_embd=128,         # Embedding size
    n_layer=6,          # Number of transformer layers
    n_head=8,           # Number of attention heads
    activation_function='gelu',
    resid_pdrop=0.1,
    embd_pdrop=0.1,
    attn_pdrop=0.1,
)

# Initialize the custom GPT-2 model
model = CustomGPT2ForConditionalGeneration(config=config).to(device)

# Initialize Weights & Biases
wandb.init(project='music_generation_with_memory')

# Define a simple compute_metrics function (optional)
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    shift_logits = torch.tensor(logits).transpose(0, 1)
    shift_labels = torch.tensor(labels).transpose(0, 1)
    
    loss_fct = nn.CrossEntropyLoss(ignore_index=0, reduction='mean')
    loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))
    
    perplexity = torch.exp(loss)
    return {"perplexity": perplexity.item(), "loss": loss.item()}


# Training loop with sequential memory using Trainer
for track_idx, track in enumerate(track_classes):
    print(f"Training for {track} with sequential memory...")

    # Create dataset for the current track
    dataset = SequentialMusicDataset(data, track, chunk_length=CHUNK_LENGTH)
    train_indices, val_indices = train_test_split(
        range(len(dataset)), test_size=0.2, random_state=42
    )
    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    val_dataset = torch.utils.data.Subset(dataset, val_indices)

    # Output and logging directories
    track_output_dir = f'./sequential_memory_checkpointing_{track}_model'
    os.makedirs(track_output_dir, exist_ok=True)
    wandb_run_name = f'sequential_memory_training_{track}_run'

    # Initialize wandb for the current track
    wandb.init(project='music_generation_with_memory', name=wandb_run_name)

    # Training arguments
    training_args = TrainingArguments(
        output_dir=track_output_dir,
        overwrite_output_dir=True,
        num_train_epochs=120,
        per_device_train_batch_size=1,  # Adjust based on GPU memory
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=1,   # Adjust as needed
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=1e-4,
        weight_decay=0.01,
        logging_dir=f'./logs_{track}',
        logging_steps=100,
        report_to=["wandb"],             # Enable logging to wandb
        load_best_model_at_end=True,
        metric_for_best_model="loss",
        greater_is_better=False,
        fp16=True,                       # Enable mixed precision if supported
    )

    # Initialize the Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=DataCollatorWithPositionalEmbeddings(),
        compute_metrics=compute_metrics,  # Optional: Remove if not using
    )

    # Start training
    trainer.train()

    # Evaluate the model
    eval_results = trainer.evaluate()
    print(f"Evaluation results for {track}: {eval_results}")

    # Save the best model
    trainer.save_model(track_output_dir)

    # Finish the current wandb run
    wandb.finish()

print("Training completed for all tracks.")


0,1
eval/loss,██▄▂▂▁▁▁▁▁
eval/perplexity,██▃▂▁▁▁▁▁▁
eval/runtime,▆▅▃▆▅▄█▂▁▆
eval/samples_per_second,▂▃▆▂▃▄▁▇█▃
eval/steps_per_second,▂▃▆▂▃▄▁▇█▃
train/epoch,▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇█
train/global_step,▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▆▇▇▇██
train/grad_norm,▁▂▂▄▃▂▄▃▅▄▄▃▃▅▅▆▃▁▆▂▁▄▅▄██▃▁▄▅▄▂▂▁▁▁▅▂▅▁
train/learning_rate,██████▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▅▅▅▄▄▄▄▄▃▃▂▂▂▂▁▁▁▁▁
train/loss,█▆▄▄▄▃▄▄▄▄▃▄▃▃▃▃▂▂▂▂▁▂▂▂▁▁▁▁▂▁▁▂▁▁▁▂▁▁▁▁

0,1
eval/loss,2.87806
eval/perplexity,17.77974
eval/runtime,2.5182
eval/samples_per_second,101.658
eval/steps_per_second,101.658
train/epoch,10.16
train/global_step,10400.0
train/grad_norm,5.89472
train/learning_rate,9e-05
train/loss,2.6429


Training for hi_hat with sequential memory...


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Epoch,Training Loss,Validation Loss,Perplexity
1,3.793,3.980779,53.55875
2,3.8444,3.770371,43.396179
3,3.4232,3.348563,28.461796
4,3.0543,3.102402,22.251339
5,2.81,2.964585,19.386648
6,2.7379,2.939131,18.899406
7,2.5597,2.912565,18.403934
8,2.971,2.862559,17.506273
9,2.7925,2.876948,17.759989
10,2.5784,2.870871,17.652384


KeyboardInterrupt: 

In [23]:
import os
import torch
import torchaudio
import numpy as np
from transformers import AutoProcessor, EncodecModel, GPT2LMHeadModel
import madmom
import torch.nn as nn
import tempfile

# Constants
VOCAB_SIZE = 1026
MAX_LENGTH = 3000
CHUNK_LENGTH = 75  # Default chunk size matching training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Track classes
track_classes = ['hi_hat']

# Initialize Encodec Model and Processor
processor = AutoProcessor.from_pretrained("facebook/encodec_24khz")
model_encodec = EncodecModel.from_pretrained("facebook/encodec_24khz").to(device)

# Function Definitions
def encode_audio(audio_path):
    audio, rate = torchaudio.load(audio_path)
    max_length_in_samples = int(rate * 10)

    if audio.shape[1] > max_length_in_samples:
        audio = audio[:, :max_length_in_samples]
    else:
        pad_length = max_length_in_samples - audio.shape[1]
        audio = torch.nn.functional.pad(audio, (0, pad_length))

    if audio.shape[0] > 1:
        audio = audio.mean(dim=0)
    else:
        audio = audio.squeeze(0)

    inputs = processor(audio.cpu().numpy(), sampling_rate=rate, return_tensors="pt")
    inputs = {key: val.to(device) for key, val in inputs.items()}

    with torch.no_grad():
        outputs = model_encodec.encode(inputs["input_values"], inputs["padding_mask"], 3)
    
    duration = audio.shape[0] / rate
    return outputs.audio_codes.squeeze(), min(duration, 10.0)

def extract_beats_and_downbeats(audio_path, fps=100, duration=10):
    audio, rate = torchaudio.load(audio_path)
    audio = audio[:, :int(duration * rate)]

    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_audio_file:
        temp_audio_path = temp_audio_file.name
        torchaudio.save(temp_audio_path, audio, rate)

    proc_downbeats = madmom.features.downbeats.DBNDownBeatTrackingProcessor(beats_per_bar=[4], fps=fps)
    act_downbeats = madmom.features.downbeats.RNNDownBeatProcessor(fps=fps)(temp_audio_path)
    downbeats = proc_downbeats(act_downbeats)

    proc_beats = madmom.features.beats.BeatDetectionProcessor(fps=fps)
    act_beats = madmom.features.beats.RNNBeatProcessor(fps=fps)(temp_audio_path)
    beats = proc_beats(act_beats)

    os.remove(temp_audio_path)

    if len(beats) == 0 or len(downbeats) == 0:
        raise ValueError(f"No beats or downbeats detected in {audio_path}")

    return beats, downbeats[downbeats[:, 1] == 1, 0]

def create_positional_embeddings(beat_times, downbeat_times, audio_duration, fps=75, K=32):
    total_frames = int(np.ceil(audio_duration * fps))

    def ramps(positions, size):
        result = np.zeros(size)
        for a, b in zip(positions[:-1], positions[1:]):
            result[a:b] = np.linspace(0, 1, b - a, endpoint=False)
        missing = positions[0]
        if missing:
            piece = result[positions[0]:positions[1]]
            pieces = np.tile(piece, missing // len(piece) + 1)
            result[:missing] = pieces[-missing:]
        missing = size - positions[-1]
        if missing:
            piece = result[positions[-2]:positions[-1]]
            pieces = np.tile(piece, missing // len(piece) + 1)
            result[-missing:] = pieces[:missing]
        return result

    vector_downbeat = ramps((downbeat_times * fps).astype(int), total_frames)
    vector_beat = ramps((beat_times * fps).astype(int), total_frames)

    frequencies = np.arange(1, K + 1)
    embeddings_downbeat = []
    embeddings_beat = []

    for k in frequencies:
        embeddings_downbeat.append(np.sin(2 * np.pi * vector_downbeat * k))
        embeddings_downbeat.append(np.cos(2 * np.pi * vector_downbeat * k))
        embeddings_beat.append(np.sin(2 * np.pi * vector_beat * k))
        embeddings_beat.append(np.cos(2 * np.pi * vector_beat * k))

    embeddings_downbeat = np.stack(embeddings_downbeat, axis=1)
    embeddings_beat = np.stack(embeddings_beat, axis=1)
    embeddings = np.hstack((embeddings_downbeat, embeddings_beat))

    return torch.from_numpy(embeddings).float()

# Define the custom GPT-2 model class
class CustomGPT2ForConditionalGeneration(GPT2LMHeadModel):
    def __init__(self, config):
        super().__init__(config)
        # Projection layer to match positional embeddings to model's embedding size
        self.positional_proj = nn.Linear(128, config.n_embd)  # Assuming positional_embeddings have dim 64

    def forward(self, input_ids=None, attention_mask=None, labels=None,
                past_key_values=None, positional_embeddings=None, **kwargs):
        """
        Forward pass that incorporates positional embeddings.

        Args:
            input_ids (torch.Tensor): Input token IDs.
            attention_mask (torch.Tensor): Attention mask.
            labels (torch.Tensor): Labels for language modeling.
            past_key_values (tuple): Past key and value states.
            positional_embeddings (torch.Tensor): Positional embeddings to add to input embeddings.

        Returns:
            CausalLMOutputWithCrossAttentions: Model outputs.
        """
        # Get input embeddings
        input_embeds = self.transformer.wte(input_ids)  # [batch, seq_length, n_embd]

        if positional_embeddings is not None:
            # Project positional embeddings to match n_embd
            pos_emb_proj = self.positional_proj(positional_embeddings)  # [batch, seq_length, n_embd]
            input_embeds = input_embeds + pos_emb_proj  # [batch, seq_length, n_embd]

        # Proceed with the standard GPT-2 forward pass using inputs_embeds
        return super().forward(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            labels=labels,
            past_key_values=past_key_values,
            **kwargs
        )

def generate_track_in_chunks_with_memory(model, audio_codes, positional_embeddings, chunk_length=CHUNK_LENGTH):
    """
    Generates a track in chunks with sequential dependence using past_key_values.
    
    Args:
        model: The trained model.
        audio_codes: Tensor containing audio codes.
        positional_embeddings: Tensor containing positional embeddings.
        chunk_length: Length of each chunk.

    Returns:
        generated_sequence: Concatenated generated sequence from all chunks.
    """
    # Ensure audio_codes is 1D
    if len(audio_codes.shape) > 1:
        audio_codes = audio_codes.flatten()

    # Split audio_codes and positional_embeddings into chunks
    num_chunks = int(np.ceil(audio_codes.shape[0] / chunk_length))
    audio_chunks = np.array_split(audio_codes.cpu().numpy(), num_chunks)
    positional_chunks = np.array_split(positional_embeddings.cpu().numpy(), num_chunks)

    generated_chunks = []
    past_key_values = None  # Initialize past_key_values for sequential dependence

    for i in range(num_chunks):
        # Pad each audio chunk to the chunk length
        audio_chunk = torch.tensor(audio_chunks[i], dtype=torch.long).to(device)
        audio_chunk = torch.nn.functional.pad(
            audio_chunk,
            (0, chunk_length - audio_chunk.shape[0]),
            value=0
        )

        # Pad each positional embedding chunk to the chunk length
        pos_chunk = torch.tensor(positional_chunks[i], dtype=torch.float).to(device)
        pos_chunk = torch.nn.functional.pad(
            pos_chunk,
            (0, 0, 0, chunk_length - pos_chunk.shape[0]),
            value=0
        )

        # Ensure dimensions match
        if audio_chunk.shape[0] != pos_chunk.shape[0]:
            raise ValueError(f"Dimension mismatch in chunk {i}. "
                             f"Audio chunk length: {audio_chunk.shape[0]}, "
                             f"Positional embedding chunk length: {pos_chunk.shape[0]}.")

        # Generate attention mask
        attention_mask = (audio_chunk != 0).long().to(device)

        # Pass inputs to the model
        with torch.no_grad():
            outputs = model(
                input_ids=audio_chunk.unsqueeze(0),  # [1, chunk_length]
                attention_mask=attention_mask.unsqueeze(0),  # [1, chunk_length]
                positional_embeddings=pos_chunk.unsqueeze(0),  # [1, chunk_length, embedding_dim]
                past_key_values=past_key_values  # Include past_key_values for memory
            )
            logits = outputs.logits.argmax(dim=-1).squeeze().detach().cpu()
            past_key_values = outputs.past_key_values  # Update past_key_values for the next chunk
            generated_chunks.append(logits)

    # Concatenate all generated chunks
    generated_sequence = torch.cat(generated_chunks, dim=0)
    return generated_sequence




# Inference
inference_files = [
    ("inference.wav", "inference_posemb.wav", "1"),
    ("inference2.wav", "inference_posemb2.wav", "2"),
    ("inference3.wav", "inference_posemb3.wav", "3"),
]

for audio_path, posemb_path, folder in inference_files:
    for track_class in track_classes:
        print(f"Processing {audio_path} -> {track_class} in folder {folder}...")
        model_folder = f'./sequential_memory_checkpointing_{track_class}_model'
        checkpoint_path = os.path.join(model_folder, 'checkpoint-44032')  # Specify the checkpoint

        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint {checkpoint_path} not found.")

        # Load the model from the specified checkpoint
        model = CustomGPT2ForConditionalGeneration.from_pretrained(checkpoint_path).to(device)
        model.eval()

        # Encode the audio
        audio_codes, audio_length = encode_audio(audio_path)

        # Extract beats and downbeats
        try:
            beats, downbeats = extract_beats_and_downbeats(posemb_path, duration=audio_length)
        except ValueError as e:
            print(e)
            continue  # Skip this audio file if beats/downbeats are not detected

        # Create positional embeddings
        positional_embeddings = create_positional_embeddings(beats, downbeats, audio_length)

        # Generate the sequence in chunks
        generated_sequence = generate_track_in_chunks(model, torch.tensor(audio_codes), positional_embeddings)

        # Reshape and decode the generated sequence
        reshaped_output = generated_sequence.view(4, 750).unsqueeze(0).unsqueeze(0).to(device)  # Move to device
        decoded_audio = model_encodec.decode(reshaped_output, [None])[0]
        decoded_audio = decoded_audio.detach()
        decoded_audio = decoded_audio.squeeze(0).squeeze(0)  # Shape: [samples]
        decoded_audio = decoded_audio.unsqueeze(0)
        output_audio_path = f"./{model_folder}/{folder}/{track_class}_generated.wav"
        os.makedirs(os.path.dirname(output_audio_path), exist_ok=True)
        torchaudio.save(output_audio_path, decoded_audio.cpu(), processor.sampling_rate)
        print(f"Saved: {output_audio_path}")



Processing inference.wav -> hi_hat in folder 1...


  generated_sequence = generate_track_in_chunks(model, torch.tensor(audio_codes), positional_embeddings)


Saved: ././sequential_memory_checkpointing_hi_hat_model/1/hi_hat_generated.wav
Processing inference2.wav -> hi_hat in folder 2...
Saved: ././sequential_memory_checkpointing_hi_hat_model/2/hi_hat_generated.wav
Processing inference3.wav -> hi_hat in folder 3...
Saved: ././sequential_memory_checkpointing_hi_hat_model/3/hi_hat_generated.wav
