# Autoregressive Music Track Generation with Custom GPT-2 Model Using Multiple Codebooks and Positional Embeddings 

This script implements an autoregressive model for music track generation using a custom GPT-2 architecture. It processes multiple codebooks per timestep and integrates positional embeddings based on beat and downbeat timings to capture complex musical patterns and temporal dynamics. The model is designed to generate specific music tracks (e.g., 'hi_hat') conditioned on vocal audio codes, leveraging the relationships between different codebooks and the rhythmic structure of the music.

Key features include a custom dataset class that structures input data as sequences of tokens across multiple codebooks, ensuring temporal alignment. The model uses separate embedding layers for each codebook and sums them to form input embeddings, which are combined with positional embeddings. Training utilizes a custom Trainer class to handle multiple outputs and computes loss by averaging over all codebooks. The script also incorporates Weights & Biases for logging and monitoring, allowing for real-time tracking of training metrics.

In [1]:
import os
import torch
import numpy as np
from torch.utils.data import Dataset
from transformers import GPT2Config, Trainer, TrainingArguments
import torch.nn as nn
from sklearn.model_selection import train_test_split

# Set the device to CUDA if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load your saved .npy file containing the dataset
data = np.load('fulldataset_10sec_positional_embs.npy', allow_pickle=True).item()

VOCAB_SIZE = 1026
SEQ_LEN = 750  # 10 seconds at 75 FPS
NUM_CODEBOOKS = 4
MAX_LENGTH = SEQ_LEN  # Sequence length per codebook
TRACK_CLASSES = ['full_instrumental']


# Dataset class for handling multi-codebook data
class MusicDataset(Dataset):
    def __init__(self, data, track_class):
        self.data = {
            k: v for k, v in data.items() if track_class in v['generation_data']
        }
        self.track_class = track_class

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample_id = list(self.data.keys())[idx]
        sample = self.data[sample_id]

        # Access the vocal codes (input) directly
        input_sequence = sample['generation_data'].get('vocal', np.zeros((NUM_CODEBOOKS, SEQ_LEN)))  # Shape: (4, 750)

        # Access the track-specific codes (label) directly
        label_sequence = sample['generation_data'].get(self.track_class, np.zeros((NUM_CODEBOOKS, SEQ_LEN)))  # Shape: (4, 750)

        # Generate attention mask for the input (vocal)
        attention_mask = (input_sequence != 0).astype(int)
        
        return {
            'input_ids': torch.tensor(input_sequence, dtype=torch.long),  # Vocal input (4, 750)
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),  # Mask for vocal input
            'labels': torch.tensor(label_sequence, dtype=torch.long),  # Track-specific labels (4, 750)
        }




# Model definition
class CustomGPT2ForMusicGen(nn.Module):
    def __init__(self, config, num_codebooks=NUM_CODEBOOKS):
        super().__init__()
        self.num_codebooks = num_codebooks

        # Shared embedding layer for input tokens
        self.codebook_embeddings = nn.Embedding(config.vocab_size, config.n_embd)

        # Positional embeddings per codebook
        self.positional_embeddings = nn.Parameter(
            torch.zeros(num_codebooks, config.n_positions, config.n_embd)
        )

        # Transformer encoder (with causal masking for autoregressive modeling)
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.n_embd,
            nhead=config.n_head,
            dim_feedforward=config.n_embd * 4,
            dropout=config.resid_pdrop,
        )
        self.transformer = nn.TransformerEncoder(self.encoder_layer, num_layers=config.n_layer)

        # Output heads for each codebook
        self.codebook_heads = nn.ModuleList([
            nn.Linear(config.n_embd, config.vocab_size) for _ in range(num_codebooks)
        ])

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        batch_size, num_codebooks, seq_len = input_ids.shape  # (B, 4, 750)

        # Embed tokens for all codebooks
        input_embeds = self.codebook_embeddings(input_ids)  # Shape: (B, 4, 750, D)

        # Add positional embeddings for each codebook
        input_embeds = input_embeds + self.positional_embeddings[:, :seq_len, :]

        # Flatten codebooks for Transformer input
        input_embeds = input_embeds.view(batch_size, num_codebooks * seq_len, -1).permute(1, 0, 2)  # (seq_len, B, D)
        attention_mask = attention_mask.view(batch_size, num_codebooks * seq_len)

        # Create causal mask for autoregressive generation
        seq_len_flat = num_codebooks * seq_len
        causal_mask = torch.triu(torch.ones(seq_len_flat, seq_len_flat), diagonal=1).bool()
        causal_mask = causal_mask.to(input_embeds.device)

        # Transformer forward pass
        transformer_outputs = self.transformer(
            input_embeds, mask=causal_mask
        )  # Shape: (seq_len, B, D)
        hidden_states = transformer_outputs.permute(1, 0, 2)  # Back to (B, seq_len, D)

        # Reshape back to (B, 4, 750, D)
        hidden_states = hidden_states.view(batch_size, num_codebooks, seq_len, -1)

        # Compute logits for each codebook
        logits = torch.stack([
            head(hidden_states[:, i, :, :]) for i, head in enumerate(self.codebook_heads)
        ], dim=1)  # Shape: (B, 4, 750, vocab_size)

        # Loss computation (if labels are provided)
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            losses = [
                loss_fct(logits[:, i, :, :].view(-1, logits.size(-1)), labels[:, i, :].view(-1))
                for i in range(num_codebooks)
            ]
            loss = sum(losses)

        return {
            "loss": loss,
            "logits": logits,
        }


# Loop through each track class and train a separate model
for track_class in TRACK_CLASSES:
    print(f"Training model for track class: {track_class}")

    # Create dataset for the current track class
    dataset = MusicDataset(data, track_class)
    train_indices, val_indices = train_test_split(range(len(dataset)), test_size=0.2, random_state=42)
    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    val_dataset = torch.utils.data.Subset(dataset, val_indices)

    # Model configuration
    config = GPT2Config(
        vocab_size=VOCAB_SIZE,
        n_positions=SEQ_LEN,
        n_ctx=SEQ_LEN,
        n_embd=128,  # Adjust based on available GPU memory
        n_layer=6,
        n_head=8,
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
    )
    model = CustomGPT2ForMusicGen(config=config).to(device)

    # Data collator
    class DataCollatorWithMultiCodebooks:
        def __call__(self, batch):
            input_ids = torch.stack([item['input_ids'] for item in batch])  # Shape: (B, 4, 750)
            attention_mask = torch.stack([item['attention_mask'] for item in batch])  # Shape: (B, 4, 750)
            labels = torch.stack([item['labels'] for item in batch])  # Shape: (B, 4, 750)

            return {
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'labels': labels,
            }

    data_collator = DataCollatorWithMultiCodebooks()

    # Training arguments
    training_args = TrainingArguments(
        output_dir=f"./checkpoints_{track_class}",
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=1e-4,
        per_device_train_batch_size=1,  # Adjust based on GPU memory
        per_device_eval_batch_size=1,
        num_train_epochs=120,
        weight_decay=0.01,
        logging_dir=f"./logs_{track_class}",
        logging_steps=50,
        save_total_limit=3,
        load_best_model_at_end=True,
        report_to="none",
        fp16=True,
    )

    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
    )

    # Train the model
    trainer.train()
    print(f"Finished training for track class: {track_class}")

Training model for track class: full_instrumental


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Epoch,Training Loss,Validation Loss
1,26.5503,25.589325
2,25.4723,24.957373
3,25.1228,24.762791
4,24.9415,24.708237
5,24.896,24.691833
6,25.1369,24.68651
7,25.0775,24.678108
8,24.7584,24.675783
9,25.045,24.653824
10,24.5738,24.625332


KeyboardInterrupt: 

In [46]:
import os
import torch
import torchaudio
import numpy as np
from transformers import GPT2Config, AutoProcessor, EncodecModel
import soundfile as sf
from tqdm import tqdm

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load EnCodec model and processor
processor = AutoProcessor.from_pretrained("facebook/encodec_24khz")
encodec_model = EncodecModel.from_pretrained("facebook/encodec_24khz").eval().to(DEVICE)

# List of track classes
TRACK_CLASSES = ["hi_hat"]


# Model definition
class CustomGPT2ForMusicGen(nn.Module):
    def __init__(self, config, num_codebooks=NUM_CODEBOOKS):
        super().__init__()
        self.num_codebooks = num_codebooks

        # Shared embedding layer for input tokens
        self.codebook_embeddings = nn.Embedding(config.vocab_size, config.n_embd)

        # Positional embeddings per codebook
        self.positional_embeddings = nn.Parameter(
            torch.zeros(num_codebooks, config.n_positions, config.n_embd)
        )

        # Transformer encoder (with causal masking for autoregressive modeling)
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.n_embd,
            nhead=config.n_head,
            dim_feedforward=config.n_embd * 4,
            dropout=config.resid_pdrop,
        )
        self.transformer = nn.TransformerEncoder(self.encoder_layer, num_layers=config.n_layer)

        # Output heads for each codebook
        self.codebook_heads = nn.ModuleList([
            nn.Linear(config.n_embd, config.vocab_size) for _ in range(num_codebooks)
        ])

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        batch_size, num_codebooks, seq_len = input_ids.shape  # (B, 4, 750)

        # Embed tokens for all codebooks
        input_embeds = self.codebook_embeddings(input_ids)  # Shape: (B, 4, 750, D)

        # Add positional embeddings for each codebook
        input_embeds = input_embeds + self.positional_embeddings[:, :seq_len, :]

        # Flatten codebooks for Transformer input
        input_embeds = input_embeds.view(batch_size, num_codebooks * seq_len, -1).permute(1, 0, 2)  # (seq_len, B, D)
        attention_mask = attention_mask.view(batch_size, num_codebooks * seq_len)

        # Create causal mask for autoregressive generation
        seq_len_flat = num_codebooks * seq_len
        causal_mask = torch.triu(torch.ones(seq_len_flat, seq_len_flat), diagonal=1).bool()
        causal_mask = causal_mask.to(input_embeds.device)

        # Transformer forward pass
        transformer_outputs = self.transformer(
            input_embeds, mask=causal_mask
        )  # Shape: (seq_len, B, D)
        hidden_states = transformer_outputs.permute(1, 0, 2)  # Back to (B, seq_len, D)

        # Reshape back to (B, 4, 750, D)
        hidden_states = hidden_states.view(batch_size, num_codebooks, seq_len, -1)

        # Compute logits for each codebook
        logits = torch.stack([
            head(hidden_states[:, i, :, :]) for i, head in enumerate(self.codebook_heads)
        ], dim=1)  # Shape: (B, 4, 750, vocab_size)

        # Loss computation (if labels are provided)
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            losses = [
                loss_fct(logits[:, i, :, :].view(-1, logits.size(-1)), labels[:, i, :].view(-1))
                for i in range(num_codebooks)
            ]
            loss = sum(losses)

        return {
            "loss": loss,
            "logits": logits,
        }

# Function to encode input audio
def encode_audio(audio_path, model, processor, device, sequence_length=750):
    """
    Encodes a vocal audio file into audio codes using Facebook's EnCodec.

    Returns:
    - torch.Tensor: Encoded audio codes of shape (4, 750).
    """
    # Load and preprocess audio
    audio, rate = torchaudio.load(audio_path)
    if rate != 24000:
        audio = torchaudio.transforms.Resample(orig_freq=rate, new_freq=24000)(audio)
    audio = torch.nn.functional.pad(audio, (0, int(24000 * 10 - audio.size(1))), "constant")
    if audio.size(0) > 1:
        audio = audio.mean(dim=0)  # Convert stereo to mono

    # Process and encode audio
    inputs = processor(audio.numpy(), sampling_rate=24000, return_tensors="pt")
    inputs = {key: val.to(device) for key, val in inputs.items()}
    with torch.no_grad():
        encoded = model.encode(inputs["input_values"], inputs["padding_mask"], 3)

    # Truncate or pad to match the expected sequence length
    codes = encoded.audio_codes.squeeze(0)  # Shape: (4, ?)
    if codes.size(-1) < sequence_length:
        codes = torch.nn.functional.pad(codes, (0, sequence_length - codes.size(-1)), "constant")
    elif codes.size(-1) > sequence_length:
        codes = codes[:, :sequence_length]

    return codes


# Function to decode model output
def decode_generated_sequence(logits, temperature=1.0, top_k=10):
    """
    Decode logits into discrete codebook values using sampling.
    """
    # Ensure logits have the correct shape: (batch_size, vocab_size)
    probabilities = torch.nn.functional.softmax(logits / temperature, dim=-1)  # Shape: (batch_size, vocab_size)
    
    # Get top_k probabilities and indices
    top_prob, top_idx = probabilities.topk(top_k, dim=-1)  # Shape: (batch_size, top_k)

    # Sample from the top_k probabilities
    sampled_indices = torch.multinomial(top_prob, num_samples=1)  # Shape: (batch_size, 1)

    # Map sampled indices back to original vocab indices
    decoded_tokens = top_idx.gather(-1, sampled_indices).squeeze(-1)  # Shape: (batch_size)

    return decoded_tokens.cpu().numpy()



# Function to save generated output as audio
def save_as_audio(sequence, output_path):
    """
    Convert generated sequence into an audio file and save.
    """
    # Placeholder: In real implementation, map codebooks back to audio
    reconstructed_audio = np.zeros((75000,))  # Dummy audio data
    sf.write(output_path, reconstructed_audio, samplerate=24000)

from safetensors.torch import load_file as load_safetensors

def load_trained_model(checkpoint_dir, vocab_size=1026, seq_len=750, num_codebooks=4):
    """
    Load the trained CustomGPT2ForMusicGen model from a checkpoint.

    Parameters:
    - checkpoint_dir (str): Directory containing model checkpoints.
    - vocab_size (int): Vocabulary size used in the model.
    - seq_len (int): Sequence length used during training.
    - num_codebooks (int): Number of codebooks.

    Returns:
    - CustomGPT2ForMusicGen: The loaded model instance.
    """
    # Define the configuration
    config = GPT2Config(
        vocab_size=vocab_size,
        n_positions=seq_len,
        n_ctx=seq_len,
        n_embd=128,
        n_layer=6,
        n_head=8,
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
    )

    # Debugging: Print configuration details
    print(f"Config: vocab_size={config.vocab_size}, n_embd={config.n_embd}, n_positions={config.n_positions}")

    # Initialize the model
    model = CustomGPT2ForMusicGen(config=config, num_codebooks=num_codebooks)

    # Load weights from model.safetensors
    model_weights_path = os.path.join(checkpoint_dir, "model.safetensors")
    if not os.path.exists(model_weights_path):
        raise FileNotFoundError(f"model.safetensors not found in {checkpoint_dir}")

    # Load the state dictionary from the safetensors file
    state_dict = load_safetensors(model_weights_path)

    # Load weights into the model
    model.load_state_dict(state_dict)
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    print("Missing keys:", missing_keys)
    print("Unexpected keys:", unexpected_keys)
    print(model)
    model.eval()

    return model





def find_highest_checkpoint(folder):
    """
    Find the highest-numbered checkpoint in a given folder.
    
    Parameters:
    - folder (str): The directory containing the checkpoints.

    Returns:
    - str: The path to the highest-numbered checkpoint.
    """
    checkpoints = [d for d in os.listdir(folder) if d.startswith("checkpoint-")]
    if not checkpoints:
        raise ValueError(f"No checkpoints found in {folder}")
    checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[-1]), reverse=True)
    return checkpoints[0]  # Return only the folder name of the highest checkpoint


# Main inference function
def generate_tracks(vocal_path, output_dir):
    """
    Generate instrumental tracks for all track classes from a given vocal input.
    """
    # Encode vocal input
    audio_codes = encode_audio(vocal_path, encodec_model, processor, DEVICE)  # Shape: (4, 750)
    
    print(f"Shape of audio_codes before unsqueeze: {audio_codes.shape}")
    
    # Remove extra batch dimension if present
    if len(audio_codes.shape) == 3 and audio_codes.shape[0] == 1:
        audio_codes = audio_codes.squeeze(0)  # Shape becomes (4, 750)
        
    print(f"Shape of audio_codes after squeeze: {audio_codes.shape}")

    for track_class in TRACK_CLASSES:
        print(f"Generating track for: {track_class}")

        # Locate the highest checkpoint
        checkpoint_folder = f"checkpoints_{track_class}"
        try:
            checkpoint_dir = find_highest_checkpoint(checkpoint_folder)
        except ValueError as e:
            print(f"Skipping {track_class} due to error: {e}")
            continue
            
        checkpoint_dir = os.path.join(checkpoint_folder, checkpoint_dir) 
        print(checkpoint_dir)
        # Load the model
        model = load_trained_model(checkpoint_dir, vocab_size=VOCAB_SIZE, seq_len=SEQ_LEN, num_codebooks=NUM_CODEBOOKS)
        model = model.to(DEVICE)

        # Prepare input for the model
        input_ids = audio_codes.unsqueeze(0)  # Add batch dimension again, final shape: (1, 4, 750)
        print(f"Shape of input_ids after unsqueeze: {input_ids.shape}")
        attention_mask = (input_ids != 0).long()
        print(f"Shape of attention_mask: {attention_mask.shape}")

        # Generate sequence autoregressively
        generated_sequence = []
        with torch.no_grad():
            for step in tqdm(range(input_ids.size(-1)), desc=f"Generating tokens for {track_class}", unit="step"):
                # Forward pass through the model
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)

                # Extract logits for the current step
                logits = outputs["logits"][:, :, step, :]  # Shape: (batch_size, num_codebooks, vocab_size)
                logits = logits.view(-1, logits.size(-1))  # Flatten to (batch_size * num_codebooks, vocab_size)

                # Decode the logits to get next tokens
                next_tokens = decode_generated_sequence(logits)

                # Reshape next_tokens to match the expected input_ids structure
                next_tokens = torch.tensor(next_tokens).view(input_ids.size(0), input_ids.size(1))

                # Append tokens to the generated sequence
                generated_sequence.append(next_tokens)

                # Update input_ids for the next step
                if step < input_ids.size(-1) - 1:
                    input_ids[:, :, step + 1] = next_tokens


        # Convert generated sequence to audio
        generated_sequence = np.stack(generated_sequence, axis=1)  # Shape: (4, 750)
        print(f"Generated sequence for {track_class}:", generated_sequence.shape)
        reshaped_output = torch.tensor(generated_sequence, dtype=torch.float32).squeeze(0)  # Shape: (750, 4)
        reshaped_output = reshaped_output.permute(1, 0).unsqueeze(0).unsqueeze(0)  # Final shape: (1, 1, 4, 750)
        print(f"Reshaped output for decoding: {reshaped_output.shape}")
        reshaped_output = reshaped_output.to(torch.long).to(DEVICE)
        # Decode the audio using encodec_model
        decoded_audio = encodec_model.decode(reshaped_output, [None])[0]  # Shape: (samples,)
        decoded_audio = decoded_audio.detach()
        decoded_audio = decoded_audio.squeeze(0).squeeze(0)  # Shape: [samples]
        decoded_audio = decoded_audio.unsqueeze(0)

        # Save the decoded audio
        output_audio_path = f"./{track_class}_generated.wav"
        os.makedirs(os.path.dirname(output_audio_path), exist_ok=True)
        torchaudio.save(output_audio_path, decoded_audio.cpu(), processor.sampling_rate)
        print(f"Saved generated audio for {track_class} to: {output_audio_path}")



# Example usage
if __name__ == "__main__":
    vocal_path = "inference.wav"  # Path to input vocal file
    output_dir = "./generated_tracks_codebooks"  # Directory to save generated tracks
    os.makedirs(output_dir, exist_ok=True)

    generate_tracks(vocal_path, output_dir)




Shape of audio_codes before unsqueeze: torch.Size([1, 4, 750])
Shape of audio_codes after squeeze: torch.Size([4, 750])
Generating track for: hi_hat
checkpoints_hi_hat\checkpoint-3531
Config: vocab_size=1026, n_embd=128, n_positions=750
Missing keys: []
Unexpected keys: []
CustomGPT2ForMusicGen(
  (codebook_embeddings): Embedding(1026, 128)
  (encoder_layer): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
    )
    (linear1): Linear(in_features=128, out_features=512, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=512, out_features=128, bias=True)
    (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (transformer): TransformerEncoder(
    (layers): ModuleList(


Generating tokens for hi_hat: 100%|████████████████████████████████████████████████| 750/750 [00:52<00:00, 14.22step/s]


Generated sequence for hi_hat: (1, 750, 4)
Reshaped output for decoding: torch.Size([1, 1, 4, 750])
Saved generated audio for hi_hat to: ./hi_hat_generated.wav
