# Cross-Attention GPT-2 Model for Independent Music Track Generation

This script trains a GPT-2-based model enhanced with cross-attention mechanisms for independent music track generation. The model processes audio data encoded as 10-second segments, leveraging positional embeddings and track-specific conditioning inputs for improved contextual learning. A custom dataset framework is used to integrate positional and conditioning inputs into the training pipeline, ensuring track classes such as hi_hat, kick, snare, and bass are modeled independently.

The model introduces a cross-attention layer that allows it to focus on specific track data during sequence generation, enhancing its ability to generate accurate and context-aware outputs. Training is performed sequentially for each track class using Hugging Face's Trainer with mixed precision and epoch-based checkpointing. The script uses Weights & Biases (WandB) for tracking and logging, saving the best models for each track class for future inference tasks.

In [3]:
import os
import torch
import numpy as np
from torch.utils.data import Dataset
from transformers import (
    GPT2LMHeadModel,
    GPT2Config,
    Trainer,
    TrainingArguments,
)
import torch.nn as nn
from sklearn.model_selection import train_test_split
import wandb

# Set the device to CUDA if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load your saved .npy file
data = np.load(
    'fulldataset_10sec_positional_embs.npy',
    allow_pickle=True
).item()

VOCAB_SIZE = 1024
MAX_LENGTH = 3000
track_classes = ['hi_hat', 'kick', 'snare', 'clap', 'bass', 'drums', 'keys', 'full_instrumental']

class MusicDataset(Dataset):
    def __init__(self, data, track_class):
        self.track_class = track_class
        self.data = {
            k: v for k, v in data.items()
            if self.track_class in v['generation_data']
        }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample_id = list(self.data.keys())[idx]
        sample = self.data[sample_id]

        vocal_audio_codes = sample['generation_data'].get(
            'vocal', np.zeros((4, 750))
        )
        track_data = sample['generation_data'].get(
            self.track_class, np.zeros((4, 750))
        )
        positional_embedding = sample.get(
            'positional_embedding', np.zeros((4, 750))
        )

        # Clip values to valid range
        vocal_audio_codes = np.clip(vocal_audio_codes, 0, VOCAB_SIZE - 1)
        track_data = np.clip(track_data, 0, VOCAB_SIZE - 1)

        # Pad and truncate sequences to MAX_LENGTH
        vocal_audio_codes = np.pad(
            vocal_audio_codes.flatten(),
            (0, MAX_LENGTH - len(vocal_audio_codes.flatten())),
            'constant',
            constant_values=(0, 0)
        )[:MAX_LENGTH]
        track_data = np.pad(
            track_data.flatten(),
            (0, MAX_LENGTH - len(track_data.flatten())),
            'constant',
            constant_values=(0, 0)
        )[:MAX_LENGTH]

        attention_mask = (vocal_audio_codes != 0).astype(int)

        # Flatten and pad positional embeddings
        pos_emb_flat = positional_embedding.flatten()
        pos_emb_flat = np.pad(
            pos_emb_flat,
            (0, MAX_LENGTH * positional_embedding.shape[1] - len(pos_emb_flat)),
            'constant',
            constant_values=(0, 0)
        )[:MAX_LENGTH * positional_embedding.shape[1]]

        positional_embedding = pos_emb_flat.reshape(MAX_LENGTH, positional_embedding.shape[1])

        return {
            'input_ids': torch.tensor(vocal_audio_codes, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
            'labels': torch.tensor(track_data, dtype=torch.long),
            'positional_embeddings': torch.tensor(positional_embedding, dtype=torch.float),
            'conditioning_inputs': torch.tensor(track_data, dtype=torch.long),  # Track data as conditioner
            'sample_id': sample_id
        }

class CustomGPT2WithCrossAttention(GPT2LMHeadModel):
    def __init__(self, config):
        super().__init__(config)
        self.cross_attention_layer = nn.MultiheadAttention(
            embed_dim=config.n_embd,
            num_heads=config.n_head,
            dropout=config.attn_pdrop
        )
        self.fc_layer = nn.Linear(config.n_embd, config.n_embd)

    def forward(self, input_ids=None, attention_mask=None, labels=None,
                positional_embeddings=None, conditioning_inputs=None, **kwargs):
        # Get input embeddings
        input_embeds = self.transformer.wte(input_ids)

        # Combine positional embeddings with input embeddings
        embs_dim = positional_embeddings.shape[2]
        input_embeds = torch.cat((input_embeds[:, :, :-embs_dim], input_embeds[:, :, -embs_dim:] + positional_embeddings), dim=-1)


        # Apply self-attention within the model
        transformer_outputs = self.transformer(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            output_attentions=False,
            output_hidden_states=True,
        )

        hidden_states = transformer_outputs.last_hidden_state

        # Apply cross-attention with conditioning inputs (track data)
        if conditioning_inputs is not None:
            conditioning_mask = (conditioning_inputs != 0).float()  # Create attention mask
            conditioning_inputs_embeds = self.transformer.wte(conditioning_inputs)

            # Cross-attention mechanism
            cross_attention_output, _ = self.cross_attention_layer(
                query=hidden_states.permute(1, 0, 2),  # (seq_len, batch_size, embed_dim)
                key=conditioning_inputs_embeds.permute(1, 0, 2),  # (seq_len, batch_size, embed_dim)
                value=conditioning_inputs_embeds.permute(1, 0, 2),  # (seq_len, batch_size, embed_dim)
                key_padding_mask=~conditioning_mask.bool()  # Use `~` to invert the boolean mask
            )
            cross_attention_output = cross_attention_output.permute(1, 0, 2)  # Back to (batch_size, seq_len, embed_dim)

            # Add cross-attention output to hidden states
            hidden_states = hidden_states + self.fc_layer(cross_attention_output)

        # Pass through the language model head for logits
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return {"loss": loss, "logits": logits}

class DataCollatorWithPositionalEmbeddings:
    def __call__(self, batch):
        input_ids = torch.stack([item['input_ids'] for item in batch])
        attention_mask = torch.stack([item['attention_mask'] for item in batch])
        labels = torch.stack([item['labels'] for item in batch])
        positional_embeddings = torch.stack([item['positional_embeddings'] for item in batch])
        conditioning_inputs = torch.stack([item['conditioning_inputs'] for item in batch])

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'positional_embeddings': positional_embeddings,
            'conditioning_inputs': conditioning_inputs
        }

config = GPT2Config(
    vocab_size=VOCAB_SIZE,
    n_positions=MAX_LENGTH,
    n_ctx=MAX_LENGTH,
    n_embd=128,
    n_layer=6,
    n_head=8,
    activation_function='gelu',
    resid_pdrop=0.1,
    embd_pdrop=0.1,
    attn_pdrop=0.1,
)

model = CustomGPT2WithCrossAttention(config=config).to(device)
data_collator = DataCollatorWithPositionalEmbeddings()

for track_idx, track in enumerate(track_classes):
    print(f"Training for {track}...")

    dataset = MusicDataset(data, track)
    train_indices, val_indices = train_test_split(
        range(len(dataset)), test_size=0.2, random_state=42
    )
    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    val_dataset = torch.utils.data.Subset(dataset, val_indices)

    track_output_dir = f'./independent_track_generation_gpt2_checkpointing_{track}_model_cross_attention_10secsdataset_final'

    wandb.init(project="music_generation", name=f'{track}_training_run_gpt2_cross_10secsdataset_final')

    training_args = TrainingArguments(
        output_dir=track_output_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=1e-4,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        num_train_epochs=120,
        weight_decay=0.01,
        save_total_limit=3,
        logging_dir=f'./logs_{track}',
        fp16=True,
        report_to=['wandb'],
        dataloader_pin_memory=False,
        logging_steps=10,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=val_dataset
    )

    trainer.train()
    trainer.save_model(track_output_dir)
    wandb.finish()

print("Training completed for all tracks.")


Training for hi_hat...


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33muniqlabs[0m. Use [1m`wandb login --relogin`[0m to force relogin


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Epoch,Training Loss,Validation Loss
1,5.0081,4.776132
2,2.6295,2.958881
3,3.11,2.73072
4,3.1424,2.666206
5,2.7085,2.617558
6,3.0907,2.591557
7,2.6736,2.584707
8,2.5759,2.54971
9,2.3817,2.540797
10,2.3195,2.525972


0,1
eval/loss,█▇▇▆▅▄▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eval/runtime,▁▃▃▄▅▅▅▅▅▅▅▆▅▆▆▆▅▆▆▆▅▅▆▅▅▅▅▅▅▅▅▇█▇▆▆▆▆▆▆
eval/samples_per_second,█▆▅▅▄▄▄▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▂▁▃▃▃▃▃▃▃
eval/steps_per_second,█▇▅▅▄▄▄▄▄▄▃▃▃▃▃▄▃▃▃▃▃▄▃▃▃▄▃▃▄▄▂▁▃▃▃▃▃▄▃▃
train/epoch,▁▁▁▂▂▂▂▂▂▂▂▂▂▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇███
train/global_step,▁▁▁▁▂▂▂▂▂▂▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇█████
train/grad_norm,▃▁▄▃█▂▃▂▂▄▅▁▂▂▁▆▃▃▁▆▁▅▁▃▂▄▂▆▃▃▃▁▂▁▂▄▃▅▅▁
train/learning_rate,███▇▇▇▇▇▇▇▇▆▆▆▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▂▂▂▁▁
train/loss,██▄▄▄▄▄▃▄▆▆▃▃▇▃▃▃▃▃▃▆▄▁▃▂▃▃▄▁▂▄▃▁▄▄▂▃▃▂▃
train/total_flos,▁

0,1
eval/loss,2.31778
eval/runtime,1.7463
eval/samples_per_second,15.461
eval/steps_per_second,15.461
train/epoch,120.0
train/global_step,12840.0
train/grad_norm,1.04571
train/learning_rate,0.0
train/loss,1.8852
train/total_flos,294088181760000.0


Training for kick...


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Epoch,Training Loss,Validation Loss
1,2.3442,2.178631
2,1.6095,2.054694
3,2.2112,2.00459
4,2.1094,1.960391
5,2.3911,1.925263
6,2.209,1.915276
7,2.015,1.902902
8,1.6547,1.887194
9,1.4237,1.898422
10,2.0713,1.884959


0,1
eval/loss,█▆▅▄▃▂▂▁▂▂▁▂▁▂▂▂▂▂▂▂▂▃▃▃▃▃▃▃▃▄▃▄▄▄▄▄▄▄▄▅
eval/runtime,▂▁▁▄▂▃▃█▁▂▂▃▂▂▂▂▅▃▇▂▅▄▃▃▃▃▃▃▅▄▄▃▃▃▄▃▃▄▆▂
eval/samples_per_second,██▇██▆▄▆▆▆▇▅▇▇▆▇▇▇▇▇▅▅▁▃▆▆▅▅▆▆▄▅▄▆▅▅▆▄▂▇
eval/steps_per_second,▇█▇▅▇▇▁▆▇▄▇▆▇▇▇▅▆▇▆▇▆▆▅▆▇▆▆▄▇▆▆▆▆▆▆▆▆▆▅▆
train/epoch,▁▁▁▁▂▂▂▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇█
train/global_step,▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇███
train/grad_norm,▄▂▂▁▂▁▂▁▁▂▂▁▂▁▁▂█▂▂▂▅▂▂▃▃▄▃▂▂▂▃▂▄▂▂▂▂▂▃▂
train/learning_rate,█████▇▇▇▇▇▆▆▆▆▅▅▅▅▄▄▄▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁
train/loss,▃▇▃▆▅▃▇▃▃▅▇█▅▆▃▂▄▅▇▆▆▅▄▄▅▄▃▃▁▃▂▂▄▂▂▂▁▂▄▂
train/total_flos,▁

0,1
eval/loss,1.95126
eval/runtime,1.7469
eval/samples_per_second,15.456
eval/steps_per_second,15.456
train/epoch,120.0
train/global_step,12840.0
train/grad_norm,0.59535
train/learning_rate,0.0
train/loss,1.3558
train/total_flos,294088181760000.0


Training for snare...


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Epoch,Training Loss,Validation Loss
1,2.6378,2.131308
2,2.7532,2.022117
3,2.0829,1.990611
4,2.1917,1.972749
5,2.6336,1.947277
6,1.8831,1.933614
7,1.7329,1.906582
8,2.1034,1.906523
9,2.2535,1.897117
10,1.8223,1.881855


0,1
eval/loss,▇▆▄▄▃▂▂▁▁▁▂▂▂▂▂▄▃▃▄▄▅▅▅▆▅▆▇▆▆▇▇▇██▇█████
eval/runtime,▄▄▃▄▂▂▃▅▃▄▃▆▂▃▃▄▃▃▁▃▄▃▃▂▄▃▃▄▃▄█▅▅▆▆▅▅▄▆▆
eval/samples_per_second,█▇▅▅▆▆▆▇▆▅▇▇▆▄▆▆▆▆▆▆█▆▆▆▇▇▆█▆▆▆▆▁▅▄▃▃▅▃▄
eval/steps_per_second,▇█▄▇▄▅▆▆▆▅▄▇▅▅▅█▅▅▅█▅▅▄▅▇▆▅▆▇▄▄▃▂▃▃▃▂▁▄▂
train/epoch,▁▁▁▁▁▂▂▂▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▇▇▇▇▇▇▇▇█████
train/global_step,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▆▆▇▇▇████
train/grad_norm,█▃▃▂▁▂▃▂▂▄▃▃▃▂▂▃▂▂▃▃▃▃▂▂▄▃▆▃▁▂▁▄▃▂▃▂▃▃▂▃
train/learning_rate,███▇▇▇▆▆▆▆▆▆▅▅▅▅▅▅▅▅▄▄▄▄▄▄▄▄▃▃▂▂▂▂▂▂▂▂▁▁
train/loss,▆▆▅█▅█▇▅▄▅▄▆▄▄▅▂▅▅▄▃▄▅▃▃▃▂▃▂▃▃▃▄▂▃▂▁▂▄▃▂
train/total_flos,▁

0,1
eval/loss,2.04453
eval/runtime,1.1027
eval/samples_per_second,15.417
eval/steps_per_second,15.417
train/epoch,120.0
train/global_step,7800.0
train/grad_norm,1.89957
train/learning_rate,0.0
train/loss,1.2075
train/total_flos,178651699200000.0


Training for clap...


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Epoch,Training Loss,Validation Loss
1,1.5692,1.424866
2,1.3898,1.366259
3,1.5731,1.332663
4,1.4248,1.294087
5,1.2382,1.271235
6,1.6664,1.238858
7,0.9212,1.228014
8,1.426,1.207181
9,1.1267,1.197819
10,0.7908,1.198611


0,1
eval/loss,█▆▅▃▃▂▂▃▂▂▂▂▂▂▁▁▂▁▂▃▂▂▂▂▃▂▂▃▃▂▂▃▃▂▃▃▃▃▃▃
eval/runtime,▁▃▂▄▂▃▂▃▄▃▂▄▃▃▃▂▅▃▂▁▂▄▄▃▃▃▄▅▃▃▃▂▁█▃▂▂▃▃▁
eval/samples_per_second,█▄▆▆▅▅▄▅▄▄▄▃▃▄▄▄▃▁▆█▂▅▃▂▁▄▆▃▅▄▄▆▄▅▆▅▄▇▃▆
eval/steps_per_second,█▆▅▆▇▆▄▆▇▇▄▅▇█▆▅▆▆▆▆▄▆▅▇▅▆█▆▁▇▆▆▆▃█▆▆▆▇▇
train/epoch,▁▁▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇███
train/global_step,▁▁▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇██
train/grad_norm,▇▄▇▅▄▁▇▄▅▁▅▄▇▆▇▄▅▅▅▄▅▅▄▃▄▄▅▅▃█▅▅▁▃▃▃▁▅▁▁
train/learning_rate,███▇▇▇▇▆▆▆▆▆▅▅▅▅▅▅▅▅▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▂▂▁▁▁
train/loss,▇█▇▅▁▃▄▂▄▅▅▆▄▃▃▄▃▄▆▂▃▂▅▂▃▃▂▁▂▂▂▁▃▂▁▁▂▂▃▂
train/total_flos,▁

0,1
eval/loss,1.19038
eval/runtime,1.2333
eval/samples_per_second,15.406
eval/steps_per_second,15.406
train/epoch,120.0
train/global_step,8760.0
train/grad_norm,2.07013
train/learning_rate,0.0
train/loss,0.6628
train/total_flos,200639600640000.0


Training for bass...


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Epoch,Training Loss,Validation Loss
1,3.7353,3.878748
2,4.5737,3.757229
3,3.5552,3.687701
4,3.9951,3.62435
5,3.584,3.582553
6,3.6085,3.554781
7,3.7072,3.520705
8,3.2508,3.480429
9,3.5494,3.483246
10,3.5102,3.458662


0,1
eval/loss,█▅▄▃▂▁▁▁▁▁▁▁▁▁▁▁▂▂▂▁▁▂▂▂▂▂▂▂▂▂▂▂▃▃▃▃▃▃▃▃
eval/runtime,▄▂▅▅▇▃▃▃▂▂▇▅▄▄▄▅▄▃▅▃▄█▃▅▂▃▆▃▅▅▅▇▄▄▂▃▃▁▅▁
eval/samples_per_second,▁▆█▅▆▇▇▆▇▇▇█▅▆▆▇▅▇▅▇▆▆▅█▇▆▆▆▇▆▇█▇█▇▇█▆▅▆
eval/steps_per_second,▁▇▆█▅▆▄▇▆▃▆▅▅█▆▇▂▅▆▆▇▆▆▅▄▆▆▆▅▃▅▆▆▄▅▅█▆▆█
train/epoch,▁▁▁▂▂▂▂▃▃▃▄▄▄▄▄▄▄▄▄▅▅▅▅▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇█
train/global_step,▁▁▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▅▅▅▅▅▅▅▆▆▇▇▇▇▇▇▇███
train/grad_norm,▂▂▂▂▁▃▂▃▃▃▃▅▂▃█▂▃▃▂▅▂▁▃▂▂▃▂▃▃▃▂▂▂▃▂▂▃▃▂▂
train/learning_rate,█████▇▇▇▇▇▇▆▆▆▅▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁
train/loss,▇▆▅▅▅▂▂█▂▃▇▃▆▄▄▄▂▇▄▃▂▃▂▃▁▂▂▁▃▄▁▃▃▂▃▂▃▂▃▂
train/total_flos,▁

0,1
eval/loss,3.4773
eval/runtime,1.9504
eval/samples_per_second,15.382
eval/steps_per_second,15.382
train/epoch,120.0
train/global_step,14400.0
train/grad_norm,1.62296
train/learning_rate,0.0
train/loss,3.1763
train/total_flos,329818521600000.0


Training for drums...


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Epoch,Training Loss,Validation Loss
1,5.6347,4.957232
2,5.3716,4.840569
3,5.3102,4.772684
4,5.3931,4.731454
5,5.0171,4.687002
6,5.29,4.654484
7,5.1488,4.66975
8,5.0718,4.609754
9,5.1566,4.611542
10,5.3893,4.591094


0,1
eval/loss,█▅▅▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▂▁▂▂▁▂▂▂▂▂▂▂▂▂
eval/runtime,▁▄▃▄▂▂▄▃▂▃▄▄▃▄▃▄▆▄▄▄▇▆▆▅▆▅▃▆▄█▅▃▃▆▄▃█▆█▅
eval/samples_per_second,▄▂█▄▅▄▅▅▆▇▆█▅▅▆▅▆▅▃▆▄▃▃▄▅▅▆▁▂▆▆▅▆▅▄▃▅▃▇▄
eval/steps_per_second,██▇▁▅▅▄▇▇▇▆▅█▅█▇▂▅▅▄▇▇▅▂█▄▃▅▆▄▅▁▅▄▆▇▇▁▃▄
train/epoch,▁▁▁▁▁▂▂▂▂▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇█████
train/global_step,▁▁▁▁▁▂▂▂▂▂▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇████
train/grad_norm,▃▁▂▁▂▁▄▆▁▃▃▁▂▁▃▂▃▂▂▅█▁▂▃▂▄▃▂▃▃▂▂▂▂▃▂▂▃▄▂
train/learning_rate,███▇▇▇▇▇▇▇▆▆▆▆▅▅▅▅▅▅▄▄▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▁▁
train/loss,█▄▃▅▄▅▅▄▄▅▃▃▃▃▃▄▂▃▃▄▂▄▄▃▃▃▁▃▃▄▃▃▃▃▃▃▂▃▃▃
train/total_flos,▁

0,1
eval/loss,4.48867
eval/runtime,2.2132
eval/samples_per_second,15.362
eval/steps_per_second,15.362
train/epoch,120.0
train/global_step,16320.0
train/grad_norm,2.24763
train/learning_rate,0.0
train/loss,4.3271
train/total_flos,373794324480000.0


Training for keys...


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Epoch,Training Loss,Validation Loss
1,5.3887,5.061367
2,4.7287,4.840999
3,4.4584,4.711109
4,4.9487,4.630985
5,4.5864,4.55241
6,3.8826,4.531574
7,3.8022,4.496453
8,4.2821,4.460728
9,4.2731,4.433827
10,4.7523,4.401883


0,1
eval/loss,█▅▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
eval/runtime,▃▂▁▁▁▁▁▁▁▃▂▂▂▃▃▄▄▃▃▄▄▄█▂▂▂▃▂▃▃▁▂▂▁▂▂▁▂▂▂
eval/samples_per_second,█▇▇▃█████▁██▇▇▇▆▆▆▆▆▆▆▆▆▅▇▇▇█▇▇▇█▇█▇▇██▇
eval/steps_per_second,▇▅▆▅▇▅█▇███▇█▇▆▆▅▅▅▄▃▃▃▃▁▁▄▂▆▆▅▆▆▆▇▇▇▇▆▇
train/epoch,▁▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇██
train/global_step,▁▁▁▁▁▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇█████
train/grad_norm,▅▁▅▂▁▁▇▆▄▂▃▃▄▃▇▃▃▃▃▅▃▄▄▄▄▅▄█▃▄▃▄▆▃▆▄▅▅▂▃
train/learning_rate,█████▇▇▇▇▇▇▇▆▆▆▆▆▅▅▅▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁
train/loss,█▆▅▅▄▃▄▃▃▄▃▃▄▄▃▃▁▃▂▄▂▃▂▂▃▃▂▁▂▂▂▃▂▂▂▁▂▁▂▃
train/total_flos,▁

0,1
eval/loss,4.37729
eval/runtime,2.2107
eval/samples_per_second,15.38
eval/steps_per_second,15.38
train/epoch,120.0
train/global_step,16320.0
train/grad_norm,2.97969
train/learning_rate,0.0
train/loss,2.9846
train/total_flos,373794324480000.0


Training for full_instrumental...


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Epoch,Training Loss,Validation Loss
1,6.1489,5.814071
2,5.7682,5.651714
3,5.5643,5.58149
4,5.7747,5.53982
5,5.5635,5.519042
6,5.5188,5.506169
7,5.4464,5.509966
8,5.467,5.471783
9,5.5319,5.458737
10,5.7856,5.449085


0,1
eval/loss,█▄▄▃▃▂▂▂▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▃▃▃▃▃▃▃▃▃▃
eval/runtime,▁▂▁▂▅▅▄▄▄▆▅▅▆▅▅▅▃▇▅▇▃█▅▆▆▅▇▅▆▅▄▆▇█▄▇▆▅▆▄
eval/samples_per_second,█▃▇▄▆▅▁▇▄▃▇▄▅▅▆▅▄▂▃▆▄▃▄▅▃▁▂▆▄▃▄▂▃▃▄▃▅▇▃▅
eval/steps_per_second,▇▄█▇▅▅▅▄▅▄▅▄▅▃▅▄▅▆▄▅▅▃▅▅▃▅▄▃▂▃▅▄▅▄▄▅▁▄▃▄
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇████
train/global_step,▁▂▂▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇█████
train/grad_norm,▂█▁▁▂▁▄▁▁▁▁▁▁▄▂▃▃▂▁▂▂▂▂▂▂▃▂▂▂▃▂▂▃▂▃▂▂▂▂▂
train/learning_rate,████▇▆▆▆▆▆▆▆▅▅▅▅▅▅▅▄▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▁▁▁▁▁
train/loss,█▆▇▇▇▆▄▅▇▆▅▄▅▆▄▆▄▄▅▃▄▄▃▄▄▄▃▃▃▃▃▁▄▃▄▄▃▂▅▃
train/total_flos,▁

0,1
eval/loss,5.50519
eval/runtime,2.2084
eval/samples_per_second,15.396
eval/steps_per_second,15.396
train/epoch,120.0
train/global_step,16320.0
train/grad_norm,2.6315
train/learning_rate,0.0
train/loss,4.7054
train/total_flos,373794324480000.0


Training completed for all tracks.


# Inference
This script performs inference for music track generation using pre-trained GPT-2 models that were originally trained with cross-attention mechanisms and conditioning inputs for specific stems. However, during inference, the conditioning inputs are excluded due to the availability of only vocal encodings and reference beat and rhythm features. The model definition has been adjusted to omit the cross-attention layer, relying solely on encoded 10-second audio segments and positional embeddings derived from beat and downbeat information detected via Madmom.

Despite the lack of conditioning inputs during inference, the model demonstrates its ability to generate musical patterns for various track classes, maintaining coherence and some structure. The positional embeddings are integrated into the input embeddings during the forward pass, enabling temporal consistency. The generated sequences are reshaped, clamped to valid ranges, and decoded back into audio using Facebook's EnCodec, producing musically plausible outputs for each track class. This highlights the model's robustness and adaptability in generating music even under altered inference conditions.

In [4]:
import os
import torch
import torchaudio
import numpy as np
from transformers import AutoProcessor, EncodecModel, GPT2LMHeadModel
import madmom
import torch.nn as nn
import tempfile

# Constants
VOCAB_SIZE = 1024
MAX_LENGTH = 3000
device = torch.device("cpu")  # Use "cuda" if GPU is available

# Track classes
track_classes = ['hi_hat', 'kick', 'snare', 'clap', 'bass', 'drums', 'keys', 'full_instrumental']

# Initialize Encodec Model and Processor
processor = AutoProcessor.from_pretrained("facebook/encodec_24khz")
model_encodec = EncodecModel.from_pretrained("facebook/encodec_24khz").to(device)

# Function Definitions (unchanged except added safety checks)
def encode_audio(audio_path):
    audio, rate = torchaudio.load(audio_path)
    max_length_in_samples = int(rate * 10)

    if audio.shape[1] > max_length_in_samples:
        audio = audio[:, :max_length_in_samples]
    else:
        pad_length = max_length_in_samples - audio.shape[1]
        audio = torch.nn.functional.pad(audio, (0, pad_length))

    if audio.shape[0] > 1:
        audio = audio.mean(dim=0)
    else:
        audio = audio.squeeze(0)

    inputs = processor(audio.numpy(), sampling_rate=rate, return_tensors="pt")
    inputs = {key: val.to(device) for key, val in inputs.items()}

    with torch.no_grad():
        outputs = model_encodec.encode(inputs["input_values"], inputs["padding_mask"], 3)
    
    duration = audio.shape[0] / rate
    return outputs.audio_codes.squeeze(), min(duration, 10.0)

def extract_beats_and_downbeats(audio_path, fps=100, duration=10):
    audio, rate = torchaudio.load(audio_path)
    audio = audio[:, :int(duration * rate)]

    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_audio_file:
        temp_audio_path = temp_audio_file.name
        torchaudio.save(temp_audio_path, audio, rate)

    proc_downbeats = madmom.features.downbeats.DBNDownBeatTrackingProcessor(beats_per_bar=[4], fps=fps)
    act_downbeats = madmom.features.downbeats.RNNDownBeatProcessor(fps=fps)(temp_audio_path)
    downbeats = proc_downbeats(act_downbeats)

    proc_beats = madmom.features.beats.BeatDetectionProcessor(fps=fps)
    act_beats = madmom.features.beats.RNNBeatProcessor(fps=fps)(temp_audio_path)
    beats = proc_beats(act_beats)

    os.remove(temp_audio_path)

    if len(beats) == 0 or len(downbeats) == 0:
        raise ValueError(f"No beats or downbeats detected in {audio_path}")

    return beats, downbeats[downbeats[:, 1] == 1, 0]

def create_positional_embeddings(beat_times, downbeat_times, audio_duration, fps=75, K=32):
    total_frames = int(np.ceil(audio_duration * fps))

    def ramps(positions, size):
        result = np.zeros(size)
        for a, b in zip(positions[:-1], positions[1:]):
            result[a:b] = np.linspace(0, 1, b - a, endpoint=False)
        missing = positions[0]
        if missing:
            piece = result[positions[0]:positions[1]]
            pieces = np.tile(piece, missing // len(piece) + 1)
            result[:missing] = pieces[-missing:]
        missing = size - positions[-1]
        if missing:
            piece = result[positions[-2]:positions[-1]]
            pieces = np.tile(piece, missing // len(piece) + 1)
            result[-missing:] = pieces[:missing]
        return result

    vector_downbeat = ramps((downbeat_times * fps).astype(int), total_frames)
    vector_beat = ramps((beat_times * fps).astype(int), total_frames)

    frequencies = np.arange(1, K + 1)
    embeddings_downbeat = []
    embeddings_beat = []

    for k in frequencies:
        embeddings_downbeat.append(np.sin(2 * np.pi * vector_downbeat * k))
        embeddings_downbeat.append(np.cos(2 * np.pi * vector_downbeat * k))
        embeddings_beat.append(np.sin(2 * np.pi * vector_beat * k))
        embeddings_beat.append(np.cos(2 * np.pi * vector_beat * k))

    embeddings_downbeat = np.stack(embeddings_downbeat, axis=1)
    embeddings_beat = np.stack(embeddings_beat, axis=1)
    embeddings = np.hstack((embeddings_downbeat, embeddings_beat))

    return torch.from_numpy(embeddings).float()

class CustomGPT2ForConditionalGeneration(GPT2LMHeadModel):
    def __init__(self, config):
        super().__init__(config)
        # No need to define projection layer if not used

    def forward(self, input_ids=None, attention_mask=None, labels=None,
                positional_embeddings=None, **kwargs):
        # Get input embeddings
        input_embeds = self.transformer.wte(input_ids)
        # Combine positional embeddings with input embeddings
        input_embeds = input_embeds + positional_embeddings

        # Proceed with the standard GPT-2 forward pass
        return super().forward(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs
        )

def find_highest_checkpoint(folder):
    checkpoints = [d for d in os.listdir(folder) if d.startswith("checkpoint-")]
    if not checkpoints:
        raise ValueError(f"No checkpoints found in {folder}")
    checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[-1]), reverse=True)
    return os.path.join(folder, checkpoints[0])

def generate_track(model, audio_codes, positional_embeddings, attention_mask):
    # Flatten and pad audio_codes to MAX_LENGTH
    audio_codes = audio_codes.flatten().to(device)
    audio_codes = torch.nn.functional.pad(
        audio_codes,
        (0, MAX_LENGTH - audio_codes.shape[0]),
        value=0
    )[:MAX_LENGTH]

    # Generate attention_mask from padded audio_codes
    attention_mask = (audio_codes != 0).long().to(device)

    # Ensure positional_embeddings are padded to MAX_LENGTH
    if positional_embeddings.shape[0] < MAX_LENGTH:
        padding_length = MAX_LENGTH - positional_embeddings.shape[0]
        positional_embeddings = torch.nn.functional.pad(
            positional_embeddings,
            (0, 0, 0, padding_length),
            mode="constant",
            value=0
        )

    positional_embeddings = positional_embeddings[:MAX_LENGTH].to(device)


    # Pass inputs to the model
    with torch.no_grad():
        outputs = model(
            input_ids=audio_codes.unsqueeze(0),  # [1, MAX_LENGTH]
            attention_mask=attention_mask.unsqueeze(0),  # [1, MAX_LENGTH]
            positional_embeddings=positional_embeddings.unsqueeze(0)  # [1, MAX_LENGTH, embedding_dim]
        )
        return outputs.logits.argmax(dim=-1).squeeze().detach().cpu()



# Inference
inference_files = [
    ("inference.wav", "inference_posemb.wav", "1"),
    ("inference2.wav", "inference_posemb2.wav", "2"),
    ("inference3.wav", "inference_posemb3.wav", "3"),
]

for audio_path, posemb_path, folder in inference_files:
    for track_class in track_classes:
        print(f"Processing {audio_path} -> {track_class} in folder {folder}...")
        model_folder = f'./independent_track_generation_gpt2_checkpointing_{track_class}_model_cross_attention_10secsdataset_final'
        model = CustomGPT2ForConditionalGeneration.from_pretrained(model_folder).to(device)
        model.eval()

        audio_codes, audio_length = encode_audio(audio_path)
        beats, downbeats = extract_beats_and_downbeats(posemb_path, duration=audio_length)
        positional_embeddings = create_positional_embeddings(beats, downbeats, audio_length)

        padding_length = MAX_LENGTH - positional_embeddings.shape[0]
        positional_embeddings = torch.nn.functional.pad(positional_embeddings, (0, 0, 0, padding_length))

        attention_mask = (audio_codes != 0).long()
        generated_sequence = generate_track(model, torch.tensor(audio_codes), positional_embeddings, attention_mask)

        reshaped_output = generated_sequence.view(4, 750).unsqueeze(0).unsqueeze(0)
        reshaped_output = torch.clamp(reshaped_output, min=0, max=1023)
        print(f"Reshaped Output: {reshaped_output.shape}, Device: {reshaped_output.device}")
        print(f"Max value in reshaped_output: {reshaped_output.max()}")
        print(f"Min value in reshaped_output: {reshaped_output.min()}")
        print(f"Contains NaN: {torch.isnan(reshaped_output).any()}")
        print(f"Contains Inf: {torch.isinf(reshaped_output).any()}")
        decoded_audio = model_encodec.decode(reshaped_output, [None])[0]
        decoded_audio = decoded_audio.detach()
        decoded_audio = decoded_audio.squeeze(0).squeeze(0)  # Shape: [samples]
        decoded_audio = decoded_audio.unsqueeze(0)
        output_audio_path = f"./{model_folder}/{folder}/{track_class}_generated.wav"
        os.makedirs(os.path.dirname(output_audio_path), exist_ok=True)
        torchaudio.save(output_audio_path, decoded_audio.cpu(), processor.sampling_rate)
        print(f"Saved: {output_audio_path}")




Processing inference.wav -> hi_hat in folder 1...


  file_sample_rate, signal = wavfile.read(filename, mmap=True)
  best = np.argmax(np.asarray(results)[:, 1])
  generated_sequence = generate_track(model, torch.tensor(audio_codes), positional_embeddings, attention_mask)


Reshaped Output: torch.Size([1, 1, 4, 750]), Device: cpu
Max value in reshaped_output: 1017
Min value in reshaped_output: 25
Contains NaN: False
Contains Inf: False
Saved: ././independent_track_generation_gpt2_checkpointing_hi_hat_model_cross_attention_10secsdataset_final/1/hi_hat_generated.wav
Processing inference.wav -> kick in folder 1...
Reshaped Output: torch.Size([1, 1, 4, 750]), Device: cpu
Max value in reshaped_output: 1022
Min value in reshaped_output: 5
Contains NaN: False
Contains Inf: False
Saved: ././independent_track_generation_gpt2_checkpointing_kick_model_cross_attention_10secsdataset_final/1/kick_generated.wav
Processing inference.wav -> snare in folder 1...
Reshaped Output: torch.Size([1, 1, 4, 750]), Device: cpu
Max value in reshaped_output: 1023
Min value in reshaped_output: 23
Contains NaN: False
Contains Inf: False
Saved: ././independent_track_generation_gpt2_checkpointing_snare_model_cross_attention_10secsdataset_final/1/snare_generated.wav
Processing inference.w