# Joint Music Track Generation with GPT-2, Weights & Biases Tracking, and Periodic Checkpointing

This notebook trains a conditional GPT-2 model for music generation with a joint approach across multiple track classes. Each track class is identified by an instrument token whose embedding is added at every position of the input sequence, so a single model learns all track classes within one unified framework (note that `drums` and `keys` currently share the same token id and therefore cannot be distinguished). The dataset keeps only (sample, track class) pairs whose track data is not all zeros, so empty tracks are never used as training targets. The model takes the vocal codes as input, adds beat/downbeat positional embeddings and the instrument-token embedding to the input token embeddings, and predicts the codes of the requested track.

The script also integrates Weights & Biases (W&B) for real-time tracking of training metrics, providing an online dashboard for monitoring model performance. The data is split 80/20 into training and validation sets. The model is evaluated on the validation set and checkpointed at the end of every epoch. The checkpoint with the lowest validation loss is restored when training finishes.

In [1]:
import wandb
import os
import torch
import numpy as np
from torch.utils.data import Dataset
from transformers import GPT2LMHeadModel, GPT2Config, Trainer, TrainingArguments, TrainerCallback
import torch.nn as nn
from sklearn.model_selection import train_test_split

# Set the device to CUDA if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

print(torch.cuda.is_available())  # True if a GPU is detected
print(torch.cuda.device_count())  # Number of GPUs detected
# get_device_name(0) raises on machines without a CUDA device, so only query it when one exists.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))

# Load the preprocessed dataset. The .npy file stores a pickled dict (hence
# allow_pickle=True and .item()). Unpickling can execute arbitrary code, so only
# load files from a trusted source.
data = np.load('fulldataset_10sec_positional_embs.npy', allow_pickle=True).item()

# Token vocabulary size: ids 0..1033. The instrument-conditioning ids below use 1027-1033.
VOCAB_SIZE = 1034
# Length of the flattened token sequence (matches the (4, 750) code arrays used below).
MAX_LENGTH = 3000
track_classes = ['hi_hat', 'kick', 'snare', 'clap', 'bass', 'drums', 'keys', 'full_instrumental']  # Example track classes

# Map each track class to a token ID for conditioning.
# NOTE(review): 'drums' and 'keys' share id 1032, so the model cannot tell them apart.
# With VOCAB_SIZE = 1034 the largest id is 1033, so eight distinct ids starting at 1027
# do not fit. Giving 'keys' its own id requires a larger VOCAB_SIZE, retraining, and
# updating the inference-side map to match.
instrument_token_map = {
    'hi_hat': 1027,
    'kick': 1028,
    'snare': 1029,
    'clap': 1030,
    'bass': 1031,
    'drums': 1032,
    'keys': 1032,
    'full_instrumental': 1033
}

wandb_run_name = 'joint_track_generation_gpt2_training_pos_emb_concat_run_fulldataset_10sec_embeddings_final'

wandb.init(project="music-generation", name=wandb_run_name)

class MusicDataset(Dataset):
    """Map-style dataset of (vocal codes -> target track codes) training pairs.

    ``data`` maps ``sample_id`` to a sample dict holding a ``'generation_data'``
    dict (track name -> code array, with an optional ``'vocal'`` entry) and,
    optionally, a ``'positional_embedding'`` array. One item is produced per
    (sample, track class) whose track data is present and not all zeros.

    Items are returned as CPU tensors; the data collator moves batches to the
    training device.
    """

    # Width of the fallback all-zeros positional embedding used when a sample has
    # none stored. 128 matches 4 * K with K = 32 in create_positional_embeddings.
    # NOTE(review): confirm this equals the width of the stored training embeddings.
    positional_embedding_dim = 128

    def __init__(self, data, instrument_token_map):
        self.data = data
        self.instrument_token_map = instrument_token_map

        # Number of usable (non-empty) samples per track class
        self.track_class_counts = {track_class: 0 for track_class in track_classes}

        # Precompute the (sample_id, track_class) pairs whose track exists and is not
        # all zeros, so __getitem__ can index them directly.
        self.valid_pairs = []
        for sample_id, sample in self.data.items():
            generation_data = sample['generation_data']
            for track_class in track_classes:
                if track_class in generation_data and np.any(generation_data[track_class] != 0):
                    self.valid_pairs.append((sample_id, track_class))
                    self.track_class_counts[track_class] += 1

        # Log the number of valid samples for each track class
        print("Valid sample counts per track class:")
        for track_class, count in self.track_class_counts.items():
            print(f"{track_class}: {count} samples")

    def __len__(self):
        return len(self.valid_pairs)

    @staticmethod
    def _pad_or_truncate(sequence, fill_value):
        """Cut a 1-D array to MAX_LENGTH, or right-pad it with ``fill_value`` up to MAX_LENGTH."""
        # Truncate first: np.pad rejects the negative pad width an over-long input would need.
        sequence = sequence[:MAX_LENGTH]
        return np.pad(sequence, (0, MAX_LENGTH - len(sequence)), mode='constant', constant_values=fill_value)

    def __getitem__(self, idx):
        """Build the model inputs for the ``idx``-th (sample_id, track_class) pair."""
        sample_id, track_class = self.valid_pairs[idx]
        sample = self.data[sample_id]
        generation_data = sample['generation_data']

        # Conditioning input: the vocal codes (all zeros when the sample has no vocal),
        # clipped into the model vocabulary.
        vocal_audio_codes = np.clip(np.asarray(generation_data.get('vocal', np.zeros((4, 750)))), 0, VOCAB_SIZE - 1)
        input_ids = self._pad_or_truncate(vocal_audio_codes.astype(np.int64).ravel(), 0)

        # Target codes. Padded positions get -100, the index the LM loss ignores.
        track_data = np.asarray(generation_data[track_class]).astype(np.int64).ravel()
        labels = self._pad_or_truncate(track_data, -100)

        # NOTE(review): code 0 may be a legitimate codebook entry, in which case this
        # masks real zero-valued codes and not only padding. Kept as-is to match the
        # mask built at inference time.
        attention_mask = (input_ids != 0).astype(int)

        # Positional embeddings, truncated/zero-padded along the first axis to MAX_LENGTH rows.
        # NOTE(review): if the stored embeddings have one row per frame (750 for 10 s) while
        # input_ids are 4 x 750 codes flattened codebook by codebook, only the first 750
        # tokens receive a non-zero positional embedding. Confirm this is intended.
        positional_embedding = sample.get('positional_embedding')
        if positional_embedding is None:
            positional_embedding = np.zeros((MAX_LENGTH, self.positional_embedding_dim))
        positional_embedding = np.asarray(positional_embedding)[:MAX_LENGTH]
        positional_embedding = np.pad(
            positional_embedding,
            ((0, MAX_LENGTH - positional_embedding.shape[0]), (0, 0)),
            mode='constant',
        )

        # The instrument token is placed at every position of the sequence.
        instrument_token = np.full(MAX_LENGTH, self.instrument_token_map[track_class], dtype=int)

        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
            'labels': torch.tensor(labels, dtype=torch.long),
            'positional_embeddings': torch.tensor(positional_embedding, dtype=torch.float),
            'instrument_token': torch.tensor(instrument_token, dtype=torch.long),
            'track_class': track_class,  # Track the class for generation
            'sample_id': sample_id  # Track the sample ID for caching
        }

class CustomGPT2ForConditionalGeneration(GPT2LMHeadModel):
    """GPT-2 LM conditioned on beat positional embeddings and an instrument token.

    The token embeddings of ``input_ids`` get ``positional_embeddings`` added to
    their last ``d`` channels (``d = positional_embeddings.shape[-1]``). The
    embedding of ``instrument_token`` (looked up in the shared ``wte`` table) is
    added at every position. The result goes through the stock GPT-2 stack as
    ``inputs_embeds``.
    """

    def __init__(self, config):
        super().__init__(config)
        # Not used in forward(). Kept so the parameter names stay the same and
        # existing checkpoints load without missing/unexpected keys.
        self.projection = nn.Linear(config.n_embd * 2, config.n_embd)

    def forward(self, input_ids=None, attention_mask=None, labels=None, positional_embeddings=None, instrument_token=None, **kwargs):
        """Embed and condition the inputs, then run the GPT-2 LM head.

        Args:
            input_ids: (batch, seq) token ids of the conditioning codes.
            attention_mask: forwarded unchanged to GPT-2.
            labels: forwarded unchanged to GPT-2 (which shifts them internally for the loss).
            positional_embeddings: (batch, seq, d) with d <= n_embd, or None to skip.
            instrument_token: (batch, seq) ids whose embeddings are added per position, or None to skip.
            **kwargs: forwarded to ``GPT2LMHeadModel.forward``.
        """
        input_embeds = self.transformer.wte(input_ids)

        if positional_embeddings is not None:
            # Add the positional embeddings to the last `embs_dim` channels only; the
            # leading channels (empty when embs_dim == n_embd) are left unchanged.
            embs_dim = positional_embeddings.shape[-1]
            positional_embeddings = positional_embeddings.to(input_embeds.dtype)
            input_embeds = torch.cat(
                (input_embeds[:, :, :-embs_dim], input_embeds[:, :, -embs_dim:] + positional_embeddings),
                dim=-1,
            )

        if instrument_token is not None:
            input_embeds = input_embeds + self.transformer.wte(instrument_token)

        # Forward pass through the GPT model
        return super().forward(inputs_embeds=input_embeds, attention_mask=attention_mask, labels=labels, **kwargs)

# Model configuration: a small 6-layer GPT-2 whose context covers the whole
# flattened code sequence (MAX_LENGTH tokens).
config = GPT2Config(
    vocab_size=VOCAB_SIZE,
    n_positions=MAX_LENGTH,
    n_ctx=MAX_LENGTH,
    n_embd=128,  # hidden size; forward() adds the positional embeddings to the last d channels, so d must not exceed this
    n_layer=6,
    n_head=8,
    activation_function='gelu',
    resid_pdrop=0.1,
    embd_pdrop=0.1,
    attn_pdrop=0.1,
)

# Initialize the custom GPT-2 model (instrument-token + positional-embedding conditioning)
model = CustomGPT2ForConditionalGeneration(config=config).to(device)

# Training arguments: evaluate and save a checkpoint at the end of every epoch,
# keep at most 3 checkpoints, and reload the one with the lowest eval loss when
# training finishes.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy`; switch if the installed version rejects the old name.
training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=1e-4,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=50,
    weight_decay=0.01,
    save_total_limit=3,
    logging_dir='./logs',
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    greater_is_better=False,
    fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
    dataloader_pin_memory=False,  # the collator returns tensors already on `device`; pinning only accepts CPU tensors
    report_to=["wandb"],
)

# Collator that batches the model inputs (including positional embeddings and instrument tokens)
class DataCollatorWithPositionalEmbeddings:
    """Stack per-item tensors into a batch and move the batch to ``device``.

    Each model input is stacked along a new leading batch dimension. Any other
    item fields (such as ``track_class`` and ``sample_id``) are left out.
    """

    def __call__(self, batch):
        model_inputs = ('input_ids', 'attention_mask', 'labels', 'positional_embeddings', 'instrument_token')
        stacked = {name: torch.stack([example[name] for example in batch]) for name in model_inputs}
        return {name: tensor.to(device) for name, tensor in stacked.items()}

data_collator = DataCollatorWithPositionalEmbeddings()

# Create dataset
dataset = MusicDataset(data, instrument_token_map)

# Split 80/20 by sample_id, not by (sample_id, track_class) pair. All track classes of a
# sample share the same vocal input, so a pair-level split would place the same vocal clip
# in both train and validation and make the validation loss optimistic.
unique_sample_ids = list(dict.fromkeys(sample_id for sample_id, _ in dataset.valid_pairs))
_, val_sample_ids = train_test_split(unique_sample_ids, test_size=0.2, random_state=42)
val_sample_ids = set(val_sample_ids)
train_indices = [i for i, (sample_id, _) in enumerate(dataset.valid_pairs) if sample_id not in val_sample_ids]
val_indices = [i for i, (sample_id, _) in enumerate(dataset.valid_pairs) if sample_id in val_sample_ids]

# Create separate train and validation datasets
train_dataset = torch.utils.data.Subset(dataset, train_indices)
val_dataset = torch.utils.data.Subset(dataset, val_indices)

# Pass the datasets to the constructor, which is the supported Trainer API. Some
# transformers versions reject an evaluation strategy without an eval_dataset at
# construction time.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

# Train the model
trainer.train()

# Save the model after training
trainer.save_model('./saved_model_joint_standard_full_dataset_final')
print("Model saved.")


KeyboardInterrupt: 

# Inference

In [3]:
import os
import torch
import torchaudio
import numpy as np
from transformers import AutoProcessor, EncodecModel, GPT2LMHeadModel
import tempfile
import madmom
import torch.nn as nn
from torch.nn.functional import pad as torch_pad

# Constants (must match the values used for training)
VOCAB_SIZE = 1034
MAX_LENGTH = 3000
device = torch.device("cpu")  # inference runs on the CPU

# Conditioning token id per track class; must match the training-time map.
# NOTE(review): 'drums' and 'keys' share id 1032, so generate_track produces the same
# output for both. Giving 'keys' its own id requires retraining with a larger vocabulary.
instrument_token_map = dict(
    hi_hat=1027,
    kick=1028,
    snare=1029,
    clap=1030,
    bass=1031,
    drums=1032,
    keys=1032,
    full_instrumental=1033,
)

# Track classes to generate, in the map's insertion order.
track_classes = list(instrument_token_map)

# Encodec processor and model, used to encode the input audio and decode generated codes
encodec_name = "facebook/encodec_24khz"
processor = AutoProcessor.from_pretrained(encodec_name)
model_encodec = EncodecModel.from_pretrained(encodec_name).to(device)

# Function to encode the audio and extract audio codes
def encode_audio(audio_path):
    """Encode the first 10 seconds of an audio file into Encodec codes.

    The audio is trimmed or zero-padded to exactly 10 s at its native rate and
    reduced to its first channel. It is then resampled to the codec's sampling
    rate if needed and encoded at bandwidth 3 (the third positional argument of
    ``EncodecModel.encode``).

    Returns:
        Tuple ``(audio_codes, duration)``: ``outputs.audio_codes`` with singleton
        dimensions squeezed, and the clip duration in seconds (10 s after the
        trim/pad step).
    """
    audio, rate = torchaudio.load(audio_path)

    # Trim or pad the audio to exactly 10 seconds at its native rate
    max_length_in_samples = int(rate * 10)
    if audio.shape[1] > max_length_in_samples:
        audio = audio[:, :max_length_in_samples]
    else:
        pad_length = max_length_in_samples - audio.shape[1]
        audio = torch.nn.functional.pad(audio, (0, pad_length))

    # Keep only the first channel
    if audio.ndim > 1:
        audio = audio[0]

    duration = min(audio.shape[0] / rate, 10)

    # The Encodec feature extractor rejects audio whose sampling rate differs from
    # the codec's, so resample files recorded at other rates.
    target_rate = processor.sampling_rate
    if rate != target_rate:
        audio = torchaudio.functional.resample(audio, rate, target_rate)

    inputs = processor(audio.numpy(), sampling_rate=target_rate, return_tensors="pt")

    # Move input values to the same device as the model
    inputs = {key: val.to(device) for key, val in inputs.items()}

    with torch.no_grad():
        outputs = model_encodec.encode(inputs["input_values"], inputs["padding_mask"], 3)

    return outputs.audio_codes.squeeze(), duration

# Function to extract downbeats
def extract_downbeats(audio_path, fps=100, duration=10):
    """Return the downbeat times (seconds) detected in the first ``duration`` seconds.

    Args:
        audio_path: path of an audio file readable by torchaudio.
        fps: frame rate of madmom's downbeat activation function.
        duration: number of seconds to analyse; longer files are trimmed.
    """
    audio, rate = torchaudio.load(audio_path)
    num_samples = int(duration * rate)

    if audio.shape[1] > num_samples:
        audio = audio[:, :num_samples]

    # madmom reads from a file path, so write the trimmed clip to a temporary WAV.
    # The handle is closed before torchaudio writes to the path.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_audio_file:
        temp_audio_path = temp_audio_file.name
    try:
        torchaudio.save(temp_audio_path, audio, rate)
        proc = madmom.features.downbeats.DBNDownBeatTrackingProcessor(beats_per_bar=[4], fps=fps)
        act = madmom.features.downbeats.RNNDownBeatProcessor(fps=fps)(temp_audio_path)
        downbeats = proc(act)
    finally:
        # Remove the temporary file even when detection fails
        os.remove(temp_audio_path)

    # Keep the times (column 0) of rows whose bar position (column 1) is 1
    return downbeats[downbeats[:, 1] == 1, 0]

# Function to extract beats
def extract_beats(audio_path, fps=100, duration=10):
    """Return the beat times detected by madmom in the first ``duration`` seconds.

    Args:
        audio_path: path of an audio file readable by torchaudio.
        fps: frame rate of madmom's beat activation function.
        duration: number of seconds to analyse; longer files are trimmed.
    """
    audio, rate = torchaudio.load(audio_path)
    num_samples = int(duration * rate)

    if audio.shape[1] > num_samples:
        audio = audio[:, :num_samples]

    # madmom reads from a file path, so write the trimmed clip to a temporary WAV.
    # The handle is closed before torchaudio writes to the path.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_audio_file:
        temp_audio_path = temp_audio_file.name
    try:
        torchaudio.save(temp_audio_path, audio, rate)
        beat_act = madmom.features.beats.RNNBeatProcessor(fps=fps)(temp_audio_path)
        beat_proc = madmom.features.beats.BeatDetectionProcessor(fps=fps)
        beats = beat_proc(beat_act)
    finally:
        # Remove the temporary file even when detection fails
        os.remove(temp_audio_path)

    return beats

# Function to create positional embeddings
def create_positional_embeddings(beat_times, downbeat_times, audio_duration, fps=75, K=32):
    """Build sinusoidal beat/downbeat phase embeddings with one row per frame.

    For both the downbeat grid and the beat grid, a phase ramp in [0, 1) is built
    that restarts at every event. Each ramp is expanded into sin/cos features at
    harmonics k = 1..K.

    Args:
        beat_times: beat times in seconds (1-D array-like).
        downbeat_times: downbeat times in seconds (1-D array-like).
        audio_duration: audio length in seconds; the frame count is
            ``ceil(audio_duration * fps)``.
        fps: frame rate of the output embeddings.
        K: number of harmonics per grid.

    Returns:
        float32 tensor of shape (frames, 4 * K). Columns are
        [sin(2*pi*1*d), cos(2*pi*1*d), ..., sin(2*pi*K*d), cos(2*pi*K*d)] for the
        downbeat ramp d, followed by the same 2 * K columns for the beat ramp.

    Raises:
        ValueError: if fewer than two distinct beat or downbeat frames fall within
            the audio, since no ramp period can be established.
    """
    total_frames = int(np.ceil(audio_duration * fps))

    def ramps(positions, size):
        """Return a length-``size`` ramp in [0, 1) that restarts at each event frame."""
        # Sorted, de-duplicated events within [0, size]. Duplicates would create empty
        # periods (division by zero when tiling), and events past the end cannot fit.
        positions = np.unique(positions)
        positions = positions[(positions >= 0) & (positions <= size)]
        if len(positions) < 2:
            raise ValueError(
                "need at least two distinct beat/downbeat frames within the audio "
                f"to build a phase ramp, got {len(positions)}"
            )
        result = np.zeros(size)
        for a, b in zip(positions[:-1], positions[1:]):
            result[a:b] = np.linspace(0, 1, b - a, endpoint=False)
        # Before the first event: continue the first period backwards so the phase
        # reaches 0 exactly at positions[0].
        missing = positions[0]
        if missing:
            piece = result[positions[0]:positions[1]]
            pieces = np.tile(piece, missing // len(piece) + 1)
            result[:missing] = pieces[-missing:]
        # After the last event: repeat the last period forwards.
        missing = size - positions[-1]
        if missing:
            piece = result[positions[-2]:positions[-1]]
            pieces = np.tile(piece, missing // len(piece) + 1)
            result[-missing:] = pieces[:missing]
        return result

    harmonics = np.arange(1, K + 1)

    def sinusoids(vector):
        """Interleave sin/cos of each harmonic: (frames,) -> (frames, 2 * K)."""
        angles = 2 * np.pi * vector[:, None] * harmonics[None, :]
        return np.stack((np.sin(angles), np.cos(angles)), axis=-1).reshape(len(vector), 2 * K)

    vector_downbeat = ramps((np.asarray(downbeat_times) * fps).astype(int), total_frames)
    vector_beat = ramps((np.asarray(beat_times) * fps).astype(int), total_frames)

    embeddings = np.hstack((sinusoids(vector_downbeat), sinusoids(vector_beat)))
    return torch.from_numpy(embeddings).float()

# Initialize the trained GPT-2 model
class CustomGPT2ForConditionalGeneration(GPT2LMHeadModel):
    """GPT-2 LM conditioned on beat positional embeddings and an instrument token.

    The token embeddings of ``input_ids`` get ``positional_embeddings`` added to
    their last ``d`` channels (``d = positional_embeddings.shape[-1]``). The
    embedding of ``instrument_token`` (looked up in the shared ``wte`` table) is
    added at every position. The result goes through the stock GPT-2 stack as
    ``inputs_embeds``. This must stay identical to the training-time class so
    checkpoints load and behave the same.
    """

    def __init__(self, config):
        super().__init__(config)
        # Not used in forward(). Kept so the parameter names match saved
        # checkpoints and load without missing/unexpected keys.
        self.projection = nn.Linear(config.n_embd * 2, config.n_embd)

    def forward(self, input_ids=None, attention_mask=None, labels=None, positional_embeddings=None, instrument_token=None, **kwargs):
        """Embed and condition the inputs, then run the GPT-2 LM head.

        Args:
            input_ids: (batch, seq) token ids of the conditioning codes.
            attention_mask: forwarded unchanged to GPT-2.
            labels: forwarded unchanged to GPT-2 (which shifts them internally for the loss).
            positional_embeddings: (batch, seq, d) with d <= n_embd, or None to skip.
            instrument_token: (batch, seq) ids whose embeddings are added per position, or None to skip.
            **kwargs: forwarded to ``GPT2LMHeadModel.forward``.
        """
        input_embeds = self.transformer.wte(input_ids)

        if positional_embeddings is not None:
            # Add the positional embeddings to the last `embs_dim` channels only; the
            # leading channels (empty when embs_dim == n_embd) are left unchanged.
            embs_dim = positional_embeddings.shape[-1]
            positional_embeddings = positional_embeddings.to(input_embeds.dtype)
            input_embeds = torch.cat(
                (input_embeds[:, :, :-embs_dim], input_embeds[:, :, -embs_dim:] + positional_embeddings),
                dim=-1,
            )

        if instrument_token is not None:
            input_embeds = input_embeds + self.transformer.wte(instrument_token)

        # Forward pass through the GPT model
        return super().forward(inputs_embeds=input_embeds, attention_mask=attention_mask, labels=labels, **kwargs)

# Function to generate each track with debug prints
def generate_track(model, audio_codes, positional_embeddings, attention_mask, track_class, device, codebook_size=1024):
    """Predict one track's codes from the conditioning codes in a single forward pass.

    This is not autoregressive sampling: the model runs once and the most likely
    token is taken at every position.

    Args:
        model: a ``CustomGPT2ForConditionalGeneration`` (expected in eval mode).
        audio_codes: (1, MAX_LENGTH) long tensor of conditioning codes.
        positional_embeddings: (1, MAX_LENGTH, d) beat/downbeat embeddings.
        attention_mask: (1, MAX_LENGTH) mask passed to the model.
        track_class: key of ``instrument_token_map`` selecting the track to generate.
        device: device the inputs are moved to.
        codebook_size: predictions are restricted to ids below this value so the
            result stays decodable by the codec (Encodec 24 kHz codebooks have
            1024 entries; instrument ids are larger). ``None`` disables the restriction.

    Returns:
        (1, MAX_LENGTH) long tensor of predicted codes, aligned with the label
        positions used in training.
    """
    # The instrument token occupies every position, as in training.
    instrument_token = torch.full((1, MAX_LENGTH), instrument_token_map[track_class], dtype=torch.long, device=device)

    audio_codes = audio_codes.to(device)
    attention_mask = attention_mask.to(device)
    positional_embeddings = positional_embeddings.to(device)

    # Debugging prints
    print(f"Generating for {track_class}...")
    print(f"Instrument token map for {track_class}: {instrument_token_map[track_class]}")
    print(f"Shapes - audio_codes: {audio_codes.shape}, instrument_token: {instrument_token.shape}, attention_mask: {attention_mask.shape}, positional_embeddings: {positional_embeddings.shape}")

    with torch.no_grad():
        outputs = model(input_ids=audio_codes, attention_mask=attention_mask, positional_embeddings=positional_embeddings, instrument_token=instrument_token)

    logits = outputs.logits
    if codebook_size is not None:
        logits = logits[..., :codebook_size]
    predictions = logits.argmax(dim=-1)

    # GPT-2 shifts the labels inside the model when computing the loss, so the logits at
    # position t were trained to predict label t + 1. Shift right by one so index t holds
    # the prediction for label t. Label 0 is never predicted, so the first prediction is
    # reused there.
    return torch.cat((predictions[:, :1], predictions[:, :-1]), dim=1)


# Load the trained conditional GPT-2 checkpoint
model_gpt2_path = './saved_model_joint_standard_full_dataset_final'
model = CustomGPT2ForConditionalGeneration.from_pretrained(model_gpt2_path).to(device)

# Set the model to evaluation mode to disable dropout and other training-specific layers
model.eval()

# Audio to condition on (training used vocal codes as the input) and the file
# used for beat/downbeat tracking.
audio_path = "inference.wav"
reference_beat_path = "inference_posemb.wav"

# Encode audio to get audio codes
audio_codes, length_in_seconds = encode_audio(audio_path)

# Flatten on the CPU, then truncate/pad to exactly MAX_LENGTH tokens. Truncate first,
# because np.pad rejects a negative pad width.
audio_codes = audio_codes.cpu().numpy().flatten()[:MAX_LENGTH]
audio_codes = np.pad(audio_codes, (0, MAX_LENGTH - len(audio_codes)), 'constant', constant_values=(0, 0))

# Convert back to a (1, MAX_LENGTH) tensor on the target device
audio_codes = torch.tensor(audio_codes, dtype=torch.long).unsqueeze(0).to(device)

# Extract beats and downbeats
reference_beats = extract_beats(reference_beat_path, duration=length_in_seconds)
reference_downbeats = extract_downbeats(reference_beat_path, duration=length_in_seconds)

# Create positional embeddings and pad/crop them to MAX_LENGTH rows
positional_embeddings = create_positional_embeddings(reference_beats, reference_downbeats, length_in_seconds)
positional_embeddings = torch_pad(positional_embeddings, (0, 0, 0, MAX_LENGTH - positional_embeddings.shape[0])).unsqueeze(0).to(device)

# Attention mask built the same way as in training (non-zero codes)
attention_mask = (audio_codes != 0).long()

# Generate and save tracks for each instrument
# NOTE(review): 'drums' and 'keys' share an instrument token id, so their outputs are identical.
output_directory = 'generated_joint_model'
os.makedirs(output_directory, exist_ok=True)
for track_class in track_classes:
    print(f"Generating {track_class}...")

    # Generate the track
    generated_sequence = generate_track(model, audio_codes, positional_embeddings, attention_mask, track_class, device)
    print('generated sequence: ', generated_sequence.shape)

    # Reshape the flat sequence to (1, 1, 4, 750) codes and decode to a waveform.
    # Decoding is inference only, so skip building an autograd graph.
    reshaped_output = generated_sequence.view(4, 750).unsqueeze(0).unsqueeze(0)
    with torch.no_grad():
        audio_values = model_encodec.decode(reshaped_output, [None])[0]
    audio_np = audio_values.cpu().squeeze().detach().numpy()

    # Save the generated audio
    output_audio_path = os.path.join(output_directory, f"generated_{track_class}.wav")
    torchaudio.save(output_audio_path, torch.tensor(audio_np).unsqueeze(0), processor.sampling_rate)
    print(f"{track_class} track saved to {output_audio_path}")



Generating hi_hat...
Generating for hi_hat...
Instrument token map for hi_hat: 1027
Shapes - audio_codes: torch.Size([1, 3000]), instrument_token: torch.Size([1, 3000]), attention_mask: torch.Size([1, 3000]), positional_embeddings: torch.Size([1, 3000, 128])
generated sequence:  torch.Size([1, 3000])
hi_hat track saved to generated_joint_model\generated_hi_hat.wav
Generating kick...
Generating for kick...
Instrument token map for kick: 1028
Shapes - audio_codes: torch.Size([1, 3000]), instrument_token: torch.Size([1, 3000]), attention_mask: torch.Size([1, 3000]), positional_embeddings: torch.Size([1, 3000, 128])
generated sequence:  torch.Size([1, 3000])
kick track saved to generated_joint_model\generated_kick.wav
Generating snare...
Generating for snare...
Instrument token map for snare: 1029
Shapes - audio_codes: torch.Size([1, 3000]), instrument_token: torch.Size([1, 3000]), attention_mask: torch.Size([1, 3000]), positional_embeddings: torch.Size([1, 3000, 128])
generated sequence: 