# Independent Track Generation with Custom GPT-2 Model and Periodic Checkpointing

This script trains a GPT-2-based conditional generation model for independent music track generation using audio embeddings combined with positional embeddings. The dataset, loaded from a .npy file, contains audio codes, positional embeddings, and metadata for various track classes (e.g., bass, full_instrumental). A custom PyTorch dataset (MusicDataset) is used to preprocess and structure the data, ensuring each track class is trained independently without cumulative mixing. Positional embeddings are integrated into the model’s input embeddings during the forward pass using a modified GPT-2 architecture.

Training is performed sequentially for each track class, using the Hugging Face Trainer with cosine learning rate scheduling and mixed precision (FP16) for efficiency. Logging and monitoring are managed through Weights & Biases (WandB), while checkpoints are saved after each epoch, with the best model stored for future use. This approach ensures robust, track-specific generative modeling by leveraging both vocal and positional embeddings.

In [None]:
import os
import torch
import numpy as np
from torch.utils.data import Dataset
from transformers import (
    GPT2LMHeadModel,
    GPT2Config,
    Trainer,
    TrainingArguments,
    get_cosine_schedule_with_warmup,
)
import torch.nn as nn
from sklearn.model_selection import train_test_split

# Import wandb for logging
import wandb

# Set the device to CUDA if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load your saved .npy file.
# The loaded object is used as a mapping of sample_id -> sample record below
# (MusicDataset iterates `data.items()`); presumably the file stores a pickled
# dict as a 0-d object array, which `.item()` unwraps — TODO confirm.
# NOTE(review): allow_pickle=True unpickles arbitrary objects on load; only
# use this with dataset files from a trusted source.
data = np.load(
    'fulldataset_10sec_positional_vocal_embs.npy',
    allow_pickle=True
).item()

# Size of the model's token vocabulary; input codes are clipped to
# [0, VOCAB_SIZE - 1] in MusicDataset.
# NOTE(review): the meaning of the two ids beyond 1024 is not visible here.
VOCAB_SIZE = 1026
# Sequence length every sample is padded/truncated to. 3000 equals the
# flattened size of the (4, 750) fallback arrays used in MusicDataset.
MAX_LENGTH = 3000  # Adjusted for maximum length
# Track classes trained one after another by the loop further down.
track_classes = ['bass', 'full_instrumental']

class MusicDataset(Dataset):
    """Per-track dataset pairing vocal codes with one target track's codes.

    Only samples whose ``generation_data`` contains ``track_class`` are kept.
    Each item provides:

    * ``input_ids``: the sample's ``'vocal'`` codes, clipped to
      ``[0, VOCAB_SIZE - 1]``, flattened and zero-padded/truncated to
      ``MAX_LENGTH``;
    * ``attention_mask``: 1 where ``input_ids`` is non-zero, else 0;
    * ``labels``: the ``track_class`` codes, flattened and
      zero-padded/truncated to ``MAX_LENGTH`` (not clipped);
    * ``positional_embeddings``: the sample's ``'positional_embedding'``
      flattened, zero-padded/truncated and reshaped to
      ``(MAX_LENGTH, embedding_dim)``;
    * ``sample_id``: the key of the sample in the source mapping.

    Args:
        data: Mapping of sample id to a record holding ``'generation_data'``
            and, optionally, ``'positional_embedding'``.
        track_class: Name of the track used as the training target.
    """

    def __init__(self, data, track_class):
        self.track_class = track_class
        self.data = {
            k: v for k, v in data.items()
            if self.track_class in v['generation_data']
        }
        # Freeze the key order once so item lookup is O(1) instead of
        # rebuilding the key list on every __getitem__ call.
        self.sample_ids = list(self.data)

    def __len__(self):
        return len(self.sample_ids)

    @staticmethod
    def _fit_length(values, length):
        """Flatten ``values`` and zero-pad or truncate it to ``length``."""
        # Truncate first: np.pad rejects negative pad widths, so an
        # over-long sequence must be cut before padding is computed.
        flat = np.asarray(values).flatten()[:length]
        return np.pad(
            flat,
            (0, length - len(flat)),
            'constant',
            constant_values=(0, 0)
        )

    def __getitem__(self, idx):
        sample_id = self.sample_ids[idx]
        sample = self.data[sample_id]
        generation_data = sample['generation_data']

        # Retrieve the specific track data without cumulative mixing
        vocal_audio_codes = generation_data.get('vocal', np.zeros((4, 750)))
        track_data = generation_data.get(
            self.track_class, np.zeros((4, 750))
        )  # Track for training
        positional_embedding = sample.get(
            'positional_embedding', np.zeros((4, 750))
        )

        # Ensure input values are within valid bounds for the embedding table
        vocal_audio_codes = np.clip(vocal_audio_codes, 0, VOCAB_SIZE - 1)

        # Flatten and pad/truncate sequences
        vocal_audio_codes = self._fit_length(vocal_audio_codes, MAX_LENGTH)
        track_data = self._fit_length(track_data, MAX_LENGTH)

        # Positions holding code 0 (including the padding) are masked out
        attention_mask = (vocal_audio_codes != 0).astype(int)

        # Flatten, pad/truncate and reshape back to (MAX_LENGTH, embedding_dim)
        embedding_dim = positional_embedding.shape[1]
        positional_embedding = self._fit_length(
            positional_embedding, MAX_LENGTH * embedding_dim
        ).reshape(MAX_LENGTH, embedding_dim)

        return {
            'input_ids': torch.tensor(vocal_audio_codes, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
            'labels': torch.tensor(track_data, dtype=torch.long),
            'positional_embeddings': torch.tensor(positional_embedding, dtype=torch.float),
            'sample_id': sample_id  # Only needed for generation and caching
        }

class CustomGPT2ForConditionalGeneration(GPT2LMHeadModel):
    """GPT-2 LM-head model that injects external positional embeddings.

    The token embeddings looked up from ``input_ids`` are modified before the
    standard GPT-2 forward pass: ``positional_embeddings`` is added to the
    trailing ``positional_embeddings.shape[2]`` channels of every token
    embedding, and the leading channels are left untouched. When that size
    equals the hidden size this is a plain element-wise addition.
    """

    def forward(self, input_ids=None, attention_mask=None, labels=None,
                positional_embeddings=None, **kwargs):
        """Run GPT-2 on token embeddings offset by ``positional_embeddings``.

        Args:
            input_ids: Token ids, looked up in the model's embedding table.
            attention_mask: Forwarded unchanged to ``GPT2LMHeadModel.forward``.
            labels: Forwarded unchanged to ``GPT2LMHeadModel.forward``.
            positional_embeddings: Optional 3-D tensor added to the trailing
                channels of the token embeddings. When ``None`` the token
                embeddings are used as they are.
            **kwargs: Forwarded unchanged to ``GPT2LMHeadModel.forward``.

        Returns:
            Whatever ``GPT2LMHeadModel.forward`` returns.
        """
        # Get input embeddings
        input_embeds = self.transformer.wte(input_ids)

        # The argument is optional in the signature, so tolerate callers that
        # omit it instead of failing on `None.shape`.
        if positional_embeddings is not None:
            embs_dim = positional_embeddings.shape[2]
            input_embeds = torch.cat(
                (
                    input_embeds[:, :, :-embs_dim],
                    input_embeds[:, :, -embs_dim:] + positional_embeddings,
                ),
                dim=-1,
            )

        # Proceed with the standard GPT-2 forward pass
        return super().forward(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs
        )

# Configuration for the model.
# The position table is sized to MAX_LENGTH because every sample is
# padded/truncated to exactly that many tokens by MusicDataset.
# NOTE(review): `n_ctx` may be ignored by recent transformers releases,
# where `n_positions` alone sets the context size — TODO confirm.
config = GPT2Config(
    vocab_size=VOCAB_SIZE,
    n_positions=MAX_LENGTH,
    n_ctx=MAX_LENGTH,
    n_embd=128,  # Match with your previous d_model
    n_layer=6,
    n_head=8,
    activation_function='gelu',
    resid_pdrop=0.1,
    embd_pdrop=0.1,
    attn_pdrop=0.1,
)

# Initialize the custom GPT-2 model with freshly initialised weights built
# from `config` (no pretrained checkpoint is loaded here) and move it to
# the selected device.
model = CustomGPT2ForConditionalGeneration(config=config).to(device)

# Custom data collator to handle positional embeddings
class DataCollatorWithPositionalEmbeddings:
    """Stack per-sample tensors into a batch, keeping positional embeddings.

    Only the four tensor fields the model's ``forward`` accepts are batched;
    any other key in the items (e.g. ``'sample_id'``) is dropped.

    The batch is returned on the same device as the incoming tensors. Device
    placement is left to the training loop (the Hugging Face ``Trainer``
    moves each batch to its own device), which keeps the collator usable with
    pinned memory and dataloader worker processes.
    """

    # Fields stacked along a new leading batch dimension.
    _TENSOR_KEYS = (
        'input_ids',
        'attention_mask',
        'labels',
        'positional_embeddings',
    )

    def __call__(self, batch):
        """Return a dict mapping each tensor field to its stacked batch."""
        return {
            key: torch.stack([item[key] for item in batch])
            for key in self._TENSOR_KEYS
        }

data_collator = DataCollatorWithPositionalEmbeddings()

# Set up the batch size and epoch interval for saving
batch_size = 1  # Adjust as needed
total_epochs = 120

# Sequential training for each track in track_classes
for track in track_classes:
    print(f"Training for {track}...")

    # Start every track from freshly initialised weights. Reusing one model
    # instance across iterations would make each later track continue from
    # the previous track's trained weights instead of training independently.
    model = CustomGPT2ForConditionalGeneration(config=config).to(device)

    # Create dataset filtered by the current track class (string)
    dataset = MusicDataset(data, track)

    # Fixed seed so the 80/20 train/validation split is reproducible
    train_indices, val_indices = train_test_split(
        range(len(dataset)), test_size=0.2, random_state=42
    )
    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    val_dataset = torch.utils.data.Subset(dataset, val_indices)

    # Set unique output directory for each track
    track_output_dir = f'./independent_track_generation_gpt2_checkpointing_{track}_model_pos_emb_concat_fulldataset_10sec_embeddings_final'

    # Initialize wandb project name and run name
    wandb_project_name = 'music_generation'
    wandb_run_name = f'independent_track_generation_gpt2_checkpointing_{track}_training_post_emb_concat_run_fulldataset_10sec_embeddings_final'

    # Initialize wandb before training starts
    wandb.init(project=wandb_project_name, name=wandb_run_name)

    # Always close the wandb run, even when training fails, so the next
    # track (or a re-run of the cell) does not log into a stale run.
    try:
        track_training_args = TrainingArguments(
            output_dir=track_output_dir,
            evaluation_strategy="epoch",
            save_strategy="epoch",
            learning_rate=1e-4,  # Lower the learning rate
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            num_train_epochs=total_epochs,
            weight_decay=0.01,
            save_total_limit=3,
            logging_dir=f'./logs_{track}',
            logging_steps=10,
            metric_for_best_model="loss",
            greater_is_better=False,
            fp16=True,
            dataloader_pin_memory=False,
            load_best_model_at_end=True,
            lr_scheduler_type='cosine',  # Use cosine learning rate scheduler
            warmup_steps=500,            # Number of warmup steps
            report_to=['wandb'],         # Report to wandb
            run_name=wandb_run_name,     # Set the wandb run name
        )

        # Calculate total training steps. Ceiling division counts the final
        # partial batch of each epoch, so the cosine schedule spans every
        # optimizer step rather than ending early.
        per_device_batch = track_training_args.per_device_train_batch_size
        steps_per_epoch = (
            len(train_dataset) + per_device_batch - 1
        ) // per_device_batch
        total_steps = steps_per_epoch * total_epochs

        # Initialize the optimizer
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=track_training_args.learning_rate,
            weight_decay=track_training_args.weight_decay
        )

        # Initialize the scheduler
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=track_training_args.warmup_steps,
            num_training_steps=total_steps,
        )

        # Initialize the Trainer for each track
        trainer = Trainer(
            model=model,
            args=track_training_args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            optimizers=(optimizer, scheduler),  # Pass the optimizer and scheduler
        )

        # Train the model for the current track
        trainer.train()

        print(f"Finished training and saved model checkpoints for {track}.")

        # Define the path to save the best model
        best_model_output_dir = os.path.join(
            'independent_track_generation_gpt2_checkpointing_pos_emb_concat_full_dataset_10secembeddings_final',
            f'{track}_model'
        )

        # Ensure the directory exists
        os.makedirs(best_model_output_dir, exist_ok=True)

        # Save the best model after training (load_best_model_at_end=True
        # requests that the best checkpoint be reloaded once training ends)
        trainer.save_model(output_dir=best_model_output_dir)
    finally:
        # Finish the wandb run
        wandb.finish()

print("Training completed for all tracks.")


# Inference
This script performs track-specific music generation using the GPT-2-based models trained above (loaded from their saved checkpoints) and Facebook's EnCodec for audio processing. It encodes each input audio file into fixed-length audio codes and extracts positional embeddings using beat and downbeat detection via Madmom. These embeddings, combined with the audio codes, are passed to the trained GPT-2 model of each track class (e.g., hi_hat, kick, snare, bass, etc.) to generate that class's code sequence.

The generated sequences are decoded back into audio using EnCodec, producing distinct audio tracks for each class. The script processes multiple input files and saves the outputs in organized directories, ensuring efficient handling of track-specific generative tasks while maintaining temporal coherence through positional embeddings.

In [9]:
import os
import torch
import torchaudio
import numpy as np
from transformers import AutoProcessor, EncodecModel, GPT2LMHeadModel
import madmom
import torch.nn as nn
import tempfile

# Constants
# These mirror the values used by the training cell; the sequence length and
# vocabulary size must match the checkpoints being loaded.
VOCAB_SIZE = 1026
MAX_LENGTH = 3000
device = torch.device("cpu")  # Use "cuda" if GPU is available

# Track classes
# The inference loop below looks up one checkpoint folder per entry, so a
# trained model directory must exist for every class listed here.
track_classes = ['hi_hat', 'kick', 'snare', 'clap', 'bass', 'drums', 'keys', 'full_instrumental']

# Initialize Encodec Model and Processor
# The same EnCodec model is used both to encode input audio into codes and
# to decode generated codes back into a waveform.
processor = AutoProcessor.from_pretrained("facebook/encodec_24khz")
model_encodec = EncodecModel.from_pretrained("facebook/encodec_24khz").to(device)

# Function Definitions (unchanged except added safety checks)
def encode_audio(audio_path):
    """Encode the first 10 seconds of an audio file into EnCodec codes.

    The waveform is resampled to the processor's sampling rate when the file
    uses a different one, cut or zero-padded to exactly 10 seconds, mixed
    down to mono and encoded with ``model_encodec`` (bandwidth argument 3).

    Args:
        audio_path: Path of the audio file to load with ``torchaudio``.

    Returns:
        A tuple ``(audio_codes, duration)``: the squeezed code tensor from
        the encoder and the duration in seconds of the encoded clip. Because
        the clip is always cut/padded to 10 seconds before the duration is
        measured, the duration is capped at (and in practice equal to) 10.0.
    """
    audio, rate = torchaudio.load(audio_path)

    # The processor is configured for one fixed sampling rate; bring files
    # recorded at any other rate to it instead of handing over a mismatched
    # rate.
    target_rate = processor.sampling_rate
    if rate != target_rate:
        audio = torchaudio.functional.resample(audio, rate, target_rate)
        rate = target_rate

    max_length_in_samples = int(rate * 10)

    # Force the clip to exactly 10 seconds: cut long files, zero-pad short ones
    if audio.shape[1] > max_length_in_samples:
        audio = audio[:, :max_length_in_samples]
    else:
        pad_length = max_length_in_samples - audio.shape[1]
        audio = torch.nn.functional.pad(audio, (0, pad_length))

    # Mix multi-channel audio down to a single 1-D mono signal
    if audio.shape[0] > 1:
        audio = audio.mean(dim=0)
    else:
        audio = audio.squeeze(0)

    inputs = processor(audio.numpy(), sampling_rate=rate, return_tensors="pt")
    inputs = {key: val.to(device) for key, val in inputs.items()}

    with torch.no_grad():
        outputs = model_encodec.encode(inputs["input_values"], inputs["padding_mask"], 3)

    duration = audio.shape[0] / rate
    return outputs.audio_codes.squeeze(), min(duration, 10.0)

def extract_beats_and_downbeats(audio_path, fps=100, duration=10):
    """Detect beat and downbeat times in the first ``duration`` seconds.

    The audio is truncated to ``duration`` seconds, written to a temporary
    WAV file and analysed with madmom's RNN beat and downbeat processors.

    Args:
        audio_path: Path of the audio file to analyse.
        fps: Frame rate handed to the madmom processors.
        duration: Number of seconds from the start of the file to analyse.

    Returns:
        A tuple ``(beats, downbeats)``: the beat tracker output, and the
        first column of the downbeat tracker output restricted to the rows
        whose second column equals 1.

    Raises:
        ValueError: If no beats or no downbeats are detected.
    """
    audio, rate = torchaudio.load(audio_path)
    audio = audio[:, :int(duration * rate)]

    # Reserve a unique path and close the handle before writing to it, so
    # the file is not open twice while torchaudio saves.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_audio_file:
        temp_audio_path = temp_audio_file.name

    try:
        torchaudio.save(temp_audio_path, audio, rate)

        proc_downbeats = madmom.features.downbeats.DBNDownBeatTrackingProcessor(beats_per_bar=[4], fps=fps)
        act_downbeats = madmom.features.downbeats.RNNDownBeatProcessor(fps=fps)(temp_audio_path)
        downbeats = proc_downbeats(act_downbeats)

        proc_beats = madmom.features.beats.BeatDetectionProcessor(fps=fps)
        act_beats = madmom.features.beats.RNNBeatProcessor(fps=fps)(temp_audio_path)
        beats = proc_beats(act_beats)
    finally:
        # Remove the temporary file even when saving or analysis fails
        os.remove(temp_audio_path)

    if len(beats) == 0 or len(downbeats) == 0:
        raise ValueError(f"No beats or downbeats detected in {audio_path}")

    # Keep only the rows flagged with 1 in the second column
    bar_starts = downbeats[downbeats[:, 1] == 1, 0]
    if len(bar_starts) == 0:
        # The tracker produced rows but none of them is flagged as a downbeat
        raise ValueError(f"No beats or downbeats detected in {audio_path}")

    return beats, bar_starts

def create_positional_embeddings(beat_times, downbeat_times, audio_duration, fps=75, K=32):
    """Build sinusoidal beat/downbeat phase embeddings, one row per frame.

    For both the downbeat and the beat grid, a phase ramp is built that rises
    linearly from 0 (at an event) towards 1 (just before the next event).
    Frames before the first and after the last event are filled by repeating
    the neighbouring period. Each ramp is then expanded into ``K`` sine/cosine
    pairs at integer frequencies ``1..K``.

    Args:
        beat_times: Beat times in seconds.
        downbeat_times: Downbeat times in seconds.
        audio_duration: Clip duration in seconds.
        fps: Number of embedding frames per second.
        K: Number of sine/cosine frequency pairs per grid.

    Returns:
        A float32 tensor of shape ``(ceil(audio_duration * fps), 4 * K)``.
        Columns ``0 .. 2K-1`` belong to the downbeat ramp and ``2K .. 4K-1``
        to the beat ramp; within each half the columns are interleaved as
        ``sin(k=1), cos(k=1), sin(k=2), cos(k=2), ...``.

    Raises:
        ValueError: If either grid has fewer than two distinct frame
            positions inside the clip, so no period can be derived.
    """
    total_frames = int(np.ceil(audio_duration * fps))
    frequencies = np.arange(1, K + 1)

    def ramps(times, size):
        """Return the per-frame phase ramp for event ``times`` (seconds)."""
        positions = (np.asarray(times) * fps).astype(int)
        # Drop events outside the clip and collapse events that land on the
        # same frame; either would otherwise yield mismatched slice lengths
        # or an empty period to tile from.
        positions = np.unique(positions[(positions >= 0) & (positions <= size)])
        if len(positions) < 2:
            raise ValueError(
                "At least two distinct beat positions inside the clip are "
                "required to build positional embeddings"
            )

        result = np.zeros(size)
        for a, b in zip(positions[:-1], positions[1:]):
            result[a:b] = np.linspace(0, 1, b - a, endpoint=False)

        # Fill the lead-in by repeating the first period, aligned so that
        # the ramp ends exactly at the first event.
        lead = positions[0]
        if lead:
            piece = result[positions[0]:positions[1]]
            pieces = np.tile(piece, lead // len(piece) + 1)
            result[:lead] = pieces[-lead:]

        # Fill the tail by repeating the last period from its start.
        tail = size - positions[-1]
        if tail:
            piece = result[positions[-2]:positions[-1]]
            pieces = np.tile(piece, tail // len(piece) + 1)
            result[-tail:] = pieces[:tail]
        return result

    def sinusoids(phase):
        """Expand a phase ramp into interleaved sin/cos columns."""
        angles = 2 * np.pi * phase[:, None] * frequencies[None, :]
        columns = np.empty((phase.shape[0], 2 * K))
        columns[:, 0::2] = np.sin(angles)
        columns[:, 1::2] = np.cos(angles)
        return columns

    vector_downbeat = ramps(downbeat_times, total_frames)
    vector_beat = ramps(beat_times, total_frames)

    embeddings = np.hstack((sinusoids(vector_downbeat), sinusoids(vector_beat)))

    return torch.from_numpy(embeddings).float()

class CustomGPT2ForConditionalGeneration(GPT2LMHeadModel):
    """GPT-2 LM-head model that injects external positional embeddings.

    ``positional_embeddings`` is added to the trailing
    ``positional_embeddings.shape[2]`` channels of every token embedding
    before the standard GPT-2 forward pass. This is the same combination rule
    the training cell's model class applies; when the positional embedding
    width equals the hidden size it is a plain element-wise addition.
    """

    def forward(self, input_ids=None, attention_mask=None, labels=None,
                positional_embeddings=None, **kwargs):
        """Run GPT-2 on token embeddings offset by ``positional_embeddings``.

        Args:
            input_ids: Token ids, looked up in the model's embedding table.
            attention_mask: Forwarded unchanged to ``GPT2LMHeadModel.forward``.
            labels: Forwarded unchanged to ``GPT2LMHeadModel.forward``.
            positional_embeddings: Optional 3-D tensor added to the trailing
                channels of the token embeddings. When ``None`` the token
                embeddings are used as they are.
            **kwargs: Forwarded unchanged to ``GPT2LMHeadModel.forward``.

        Returns:
            Whatever ``GPT2LMHeadModel.forward`` returns.
        """
        # Get input embeddings
        input_embeds = self.transformer.wte(input_ids)

        # Combine positional embeddings with input embeddings. The argument
        # is optional in the signature, so a missing value is tolerated.
        if positional_embeddings is not None:
            embs_dim = positional_embeddings.shape[2]
            input_embeds = torch.cat(
                (
                    input_embeds[:, :, :-embs_dim],
                    input_embeds[:, :, -embs_dim:] + positional_embeddings,
                ),
                dim=-1,
            )

        # Proceed with the standard GPT-2 forward pass
        return super().forward(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs
        )

def find_highest_checkpoint(folder):
    """Return the path of the ``checkpoint-<step>`` entry with the largest step.

    Entries whose name starts with ``checkpoint-`` but does not end in a
    purely numeric step (for example a leftover ``checkpoint-tmp``) are
    ignored rather than aborting the lookup.

    Args:
        folder: Directory to scan for ``checkpoint-<step>`` entries.

    Returns:
        ``os.path.join(folder, name)`` for the entry with the highest step.

    Raises:
        ValueError: If ``folder`` contains no usable checkpoint entry.
    """
    numbered = []
    for name in os.listdir(folder):
        if not name.startswith("checkpoint-"):
            continue
        step_text = name.split("-")[-1]
        if step_text.isdigit():
            numbered.append((int(step_text), name))

    if not numbered:
        raise ValueError(f"No checkpoints found in {folder}")

    # Compare by numeric step so that e.g. step 100 beats step 20
    _, latest = max(numbered)
    return os.path.join(folder, latest)

def generate_track(model, audio_codes, positional_embeddings, attention_mask):
    """Predict one token id per position with a single forward pass.

    The codes are flattened and zero-padded/cut to ``MAX_LENGTH``; the
    positional embeddings are zero-padded/cut to ``MAX_LENGTH`` rows. The
    model is called once and the arg-max over the vocabulary is taken at
    every position.

    Args:
        model: Model called with ``input_ids``, ``attention_mask`` and
            ``positional_embeddings`` keyword arguments.
        audio_codes: Tensor of input codes; flattened before use.
        positional_embeddings: 2-D tensor with one row per position.
        attention_mask: Accepted for interface compatibility only; it is
            replaced by a mask derived from the padded codes.

    Returns:
        A CPU tensor of predicted token ids with singleton dimensions
        squeezed out.
    """
    pad = torch.nn.functional.pad

    # Bring the codes to exactly MAX_LENGTH tokens
    flat_codes = audio_codes.flatten().to(device)
    codes_short_by = MAX_LENGTH - flat_codes.shape[0]
    flat_codes = pad(flat_codes, (0, codes_short_by), value=0)[:MAX_LENGTH]

    # The mask is always recomputed from the padded codes (non-zero => 1)
    attention_mask = (flat_codes != 0).long().to(device)

    # Bring the positional embeddings to exactly MAX_LENGTH rows
    rows_short_by = MAX_LENGTH - positional_embeddings.shape[0]
    if rows_short_by > 0:
        positional_embeddings = pad(
            positional_embeddings,
            (0, 0, 0, rows_short_by),
            mode="constant",
            value=0
        )
    positional_embeddings = positional_embeddings[:MAX_LENGTH].to(device)

    # Add the batch dimension expected by the model
    model_inputs = {
        "input_ids": flat_codes.unsqueeze(0),  # [1, MAX_LENGTH]
        "attention_mask": attention_mask.unsqueeze(0),  # [1, MAX_LENGTH]
        "positional_embeddings": positional_embeddings.unsqueeze(0),  # [1, MAX_LENGTH, embedding_dim]
    }

    with torch.no_grad():
        outputs = model(**model_inputs)
        predicted = outputs.logits.argmax(dim=-1)
        return predicted.squeeze().detach().cpu()



# Inference
# Each entry is (audio to encode, audio used for beat tracking, folder label).
inference_files = [
    ("inference.wav", "inference_posemb.wav", "1"),
    ("inference2.wav", "inference_posemb2.wav", "2"),
    ("inference3.wav", "inference_posemb3.wav", "3"),
]

# Per-track models, loaded from disk once and reused for every input file
loaded_models = {}

for audio_path, posemb_path, folder in inference_files:
    # The encoded audio and the positional embeddings depend only on the
    # input file, so compute them once instead of once per track class.
    audio_codes, audio_length = encode_audio(audio_path)
    beats, downbeats = extract_beats_and_downbeats(posemb_path, duration=audio_length)
    positional_embeddings = create_positional_embeddings(beats, downbeats, audio_length)

    padding_length = MAX_LENGTH - positional_embeddings.shape[0]
    positional_embeddings = torch.nn.functional.pad(positional_embeddings, (0, 0, 0, padding_length))

    attention_mask = (audio_codes != 0).long()

    for track_class in track_classes:
        print(f"Processing {audio_path} -> {track_class} in folder {folder}...")
        model_folder = f'./independent_track_generation_gpt2_checkpointing_{track_class}_model_pos_emb_concat_fulldataset_10sec_embeddings_final'

        if track_class not in loaded_models:
            highest_checkpoint = find_highest_checkpoint(model_folder)
            track_model = CustomGPT2ForConditionalGeneration.from_pretrained(highest_checkpoint).to(device)
            track_model.eval()
            loaded_models[track_class] = track_model
        model = loaded_models[track_class]

        # audio_codes is already a tensor; passing it directly avoids the
        # copy-construct warning raised by torch.tensor(<tensor>).
        generated_sequence = generate_track(model, audio_codes, positional_embeddings, attention_mask)

        # Restore the (4, 750) code layout and add the two leading
        # dimensions the decoder call is given here.
        reshaped_output = generated_sequence.view(4, 750).unsqueeze(0).unsqueeze(0)
        # Decoding is inference only; skip autograd bookkeeping
        with torch.no_grad():
            decoded_audio = model_encodec.decode(reshaped_output, [None])[0]
        decoded_audio = decoded_audio.detach()
        decoded_audio = decoded_audio.squeeze(0).squeeze(0)  # Shape: [samples]
        decoded_audio = decoded_audio.unsqueeze(0)
        output_audio_path = f"./{model_folder}/{track_class}_{audio_path}_generated.wav"
        os.makedirs(os.path.dirname(output_audio_path), exist_ok=True)
        torchaudio.save(output_audio_path, decoded_audio.cpu(), processor.sampling_rate)
        print(f"Saved: {output_audio_path}")




Processing inference.wav -> hi_hat in folder 1...


  generated_sequence = generate_track(model, torch.tensor(audio_codes), positional_embeddings, attention_mask)


Saved: ././independent_track_generation_gpt2_checkpointing_hi_hat_model_pos_emb_concat_fulldataset_10sec_embeddings_final/hi_hat_inference.wav_generated.wav
Processing inference.wav -> kick in folder 1...
Saved: ././independent_track_generation_gpt2_checkpointing_kick_model_pos_emb_concat_fulldataset_10sec_embeddings_final/kick_inference.wav_generated.wav
Processing inference.wav -> snare in folder 1...
Saved: ././independent_track_generation_gpt2_checkpointing_snare_model_pos_emb_concat_fulldataset_10sec_embeddings_final/snare_inference.wav_generated.wav
Processing inference.wav -> clap in folder 1...
Saved: ././independent_track_generation_gpt2_checkpointing_clap_model_pos_emb_concat_fulldataset_10sec_embeddings_final/clap_inference.wav_generated.wav
Processing inference.wav -> bass in folder 1...
Saved: ././independent_track_generation_gpt2_checkpointing_bass_model_pos_emb_concat_fulldataset_10sec_embeddings_final/bass_inference.wav_generated.wav
Processing inference.wav -> drums in