In [None]:
import os
import random
import torch #pytorch
import librosa
import numpy as np
from datasets import load_dataset, Audio
from transformers import EncodecModel, AutoProcessor, GPT2Config, GPT2Model
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn.functional as F
# used to define neural network layers and setting up and training the network
from IPython.display import Audio as IPyAudio
import soundfile as sf
import csv
from typing import List
from tqdm import tqdm # for progress bars
import soundfile as sf
from audiomentations import Compose, AddGaussianNoise
import wandb

In [2]:
os.environ["WANDB_NOTEBOOK_NAME"] = "FoleyGen_Oct.ipynb"
wandb.login(key = '1a9688afc9e6c6bf3585eecf13438d302bdbcd73')

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mssr9055[0m ([33mdl4m_final[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/ssr9055/.netrc


True

In [3]:
#Resuming from last run
wandb.init(
    project="FoleyGen_Oct",
    id="7ue2sn90",  
    resume="must"  
)

## Splitting datasets

In [4]:
clas_dict = {
    "DogBark": 0,
    "Footstep": 1,
    "Gunshot": 2,
    "Keyboard": 3,
    "MovingMotorVehicle": 4,
    "Rain": 5,
    "SneezeCough": 6,
}

In [1]:
# Update with your path
dataset_path = ""
output_csv_path = ""

In [6]:
def load_dataset_from_csv(csv_path: str):
    training_files = []
    valid_files = []
    test_files = []

    with open(csv_path, "r") as csv_file:
        reader = csv.DictReader(csv_file)
        for row in reader:
            file_info = {
                "file_path": row["filepath"],
                "class_id": clas_dict.get(os.path.basename(os.path.dirname(row["filepath"])), None),
            }
            if row["split"] == "train":
                training_files.append(file_info)
            elif row["split"] == "validation":
                valid_files.append(file_info)
            elif row["split"] == "test":
                test_files.append(file_info)

    return training_files, valid_files, test_files



def split_dataset_files(dataset_path: str, csv_path: str):
   
    # Check if the CSV file exists
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"The specified CSV file does not exist: {csv_path}")
    
    # Load splits from the CSV
    print(f"Loading splits from existing CSV: {csv_path}")
    return load_dataset_from_csv(csv_path)



In [7]:
# Split the dataset
training_files, valid_files, test_files = split_dataset_files(dataset_path, csv_path=output_csv_path)

Loading splits from existing CSV: /scratch/ssr9055/my_env/dataset_splits.csv


## Initalizing Encodec Model


In [10]:
# Initialize EnCodec model
from encodec import EncodecModel

encodec_model = EncodecModel.encodec_model_24khz()
encodec_model.set_target_bandwidth(6.0)

# Define codebook_size and num_quantizers based on the actual model
codebook_size = 1024  # Since max index is up to 1023
num_quantizers = 8     # For EnCodec model with target bandwidth of 6.0

  WeightNorm.apply(module, name, dim)


# Gaussian Noise Augmentation and converting raw audio to tokens with Encodec

In [11]:


# Define augmentation
augment = Compose([
    AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
])

def encode_audio(file_path, model, sr=24000, apply_augmentation = False):
    import soundfile as sf
    audio, original_sr = sf.read(file_path)

    if len(audio) == 0:
        audio = np.zeros(1)

    if apply_augmentation:
        # Apply augmentation (CPU)
        audio = augment(samples=audio, sample_rate=sr)

    

    # Convert to tensor and move to GPU
    audio_tensor = torch.from_numpy(audio).float().unsqueeze(0).unsqueeze(0)
    audio_tensor = audio_tensor.to(next(model.parameters()).device)

    with torch.no_grad():
        # Encode the audio using EnCodec (GPU)
        encoded_frames = model.encode(audio_tensor)

    codes_list = [frame[0] for frame in encoded_frames]  
    codes = torch.cat(codes_list, dim=2)

    codes = codes.squeeze(0).permute(1, 0).long()

    num_quantizers = codes.shape[1]

    for q in range(num_quantizers):
        max_index = codes[:, q].max().item()
        min_index = codes[:, q].min().item()

    return codes


## Functions for applying and removing delay pattern

In [12]:
def apply_delay_pattern(codes):
    
    num_frames, num_quantizers = codes.shape
    
    max_delay = num_quantizers - 1
    
    padding_value = codebook_size  
    
    delayed_codes = torch.full((num_frames + max_delay, num_quantizers), fill_value=padding_value, dtype=codes.dtype)
    
    for q in range(num_quantizers):
        delayed_codes[q:q + num_frames, q] = codes[:, q]
        
    return delayed_codes  # Shape: [num_frames + max_delay, num_quantizers]

def remove_delay_pattern(delayed_codes, num_quantizers):
    
    num_frames = delayed_codes.shape[0] - (num_quantizers - 1)
    
    codes = torch.zeros(num_frames, num_quantizers, dtype=delayed_codes.dtype)
    
    for q in range(num_quantizers):
        codes[:, q] = delayed_codes[q:q + num_frames, q]
        
    return codes  # Shape: [num_frames, num_quantizers]

## Create AudioDataset Class and Data Loaders

In [13]:
class AudioDataset(Dataset):
    def __init__(self, file_list, encodec_model, max_length=300, codebook_size=1024, apply_augmentation = False):
        self.file_list = file_list
        self.encodec_model = encodec_model
        self.max_length = max_length
        self.codebook_size = codebook_size
        self.vocab_size = self.codebook_size + 1  # +1 for padding
        self.num_quantizers = 8
        self.apply_augmentation = apply_augmentation

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_info = self.file_list[idx]
        file_path = file_info["file_path"]
        class_id = file_info["class_id"]  # Get the class ID
        codes = encode_audio(file_path, self.encodec_model, apply_augmentation=self.apply_augmentation)
        delayed_codes = apply_delay_pattern(codes)

        # Truncate or pad sequences to max_length
        input_ids = delayed_codes
        padding_value = self.vocab_size - 1
        if input_ids.shape[0] > self.max_length:
            input_ids = input_ids[:self.max_length, :]
        else:
            pad_length = self.max_length - input_ids.shape[0]
            padding = torch.full((pad_length, input_ids.shape[1]), padding_value, dtype=torch.long)
            input_ids = torch.cat([input_ids, padding], dim=0)

        return input_ids, class_id  # Return the class ID along with the input_ids




# Create Data Loaders
max_sequence_length = 300
batch_size = 4

train_dataset = AudioDataset(training_files, encodec_model, max_length=max_sequence_length, apply_augmentation = True)
valid_dataset = AudioDataset(valid_files, encodec_model, max_length=max_sequence_length, apply_augmentation = False)
test_dataset = AudioDataset(test_files, encodec_model, max_length=max_sequence_length, apply_augmentation = False)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

## Define AudioGPT2 Model


In [14]:
class AudioGPT2(nn.Module):
    def __init__(self, num_quantizers=8, codebook_size=1024, hidden_size=768, num_classes=7):
        super(AudioGPT2, self).__init__()
        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size
        self.vocab_size = self.codebook_size + 1  # +1 for padding
        self.total_vocab_size = self.num_quantizers * self.vocab_size

        # Custom embedding layer for the audio tokens
        self.embedding = nn.Embedding(self.total_vocab_size, hidden_size)

        # Embedding layer for class labels
        self.class_embedding = nn.Embedding(num_classes, hidden_size)

        # Loading pre-trained GPT-2 model
        self.gpt2 = GPT2Model.from_pretrained('gpt2')

        # Output layer
        self.output_layer = nn.Linear(hidden_size, self.total_vocab_size)

    def forward(self, input_ids=None, class_id=None):
        # input_ids: [batch_size, seq_length, num_quantizers]
        # class_id: [batch_size]

        # Offset tokens for each quantizer to avoid overlap
        offset = torch.arange(self.num_quantizers, device=input_ids.device) * self.vocab_size
        offset = offset.unsqueeze(0).unsqueeze(0)
        input_ids = input_ids + offset  

        # Get audio embeddings
        embeddings = self.embedding(input_ids).sum(dim=2)  # Sum across quantizers

        # Get class embeddings and expand to match the sequence length
        class_embeddings = self.class_embedding(class_id).unsqueeze(1).expand(-1, embeddings.size(1), -1)

        # Concatenate class embeddings with audio embeddings
        combined_embeddings = embeddings + class_embeddings

        # Pass through GPT-2
        gpt_outputs = self.gpt2(inputs_embeds=combined_embeddings)

        # Output layer
        logits = self.output_layer(gpt_outputs.last_hidden_state)
        return logits

## Train model

In [15]:
# Initialize the model, loss function, and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AudioGPT2(num_quantizers=num_quantizers, codebook_size=codebook_size).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=model.vocab_size - 1)  # Ignore padding token
optimizer = optim.Adam(model.parameters(), lr=5e-5, weight_decay = 1e-4)

#### For starting a new training run

In [18]:

num_epochs = 140

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
        input_ids, class_id = batch
        input_ids, class_id = input_ids.to(device), class_id.to(device)
        
        optimizer.zero_grad()
        
        

        # Prepare inputs and targets
        inputs = input_ids[:, :-1, :]  # [batch_size, seq_length-1, num_quantizers]
        targets = input_ids[:, 1:, :]  # [batch_size, seq_length-1, num_quantizers]

        # Forward pass
        logits = model(input_ids = inputs, class_id = class_id)  # [batch_size, seq_length-1, total_vocab_size]

        # Reshape logits and targets
        batch_size, seq_length_minus1, _ = inputs.shape
        logits = logits.reshape(batch_size * seq_length_minus1, model.total_vocab_size)
        targets = targets.reshape(batch_size * seq_length_minus1, model.num_quantizers)

        # Compute loss per quantizer
        loss = 0
        for q in range(model.num_quantizers):
            q_targets = targets[:, q]  # [batch_size * seq_length_minus1]
            q_offset = q * model.vocab_size
            q_logits = logits[:, q_offset : q_offset + model.vocab_size]  # [batch_size * seq_length_minus1, vocab_size]
            loss += criterion(q_logits, q_targets)
        loss = loss / model.num_quantizers  # Average over quantizers
        

        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {avg_loss:.4f}")
    wandb.log({"Training Loss": avg_loss, "Epoch": epoch + 1})

    # Validation loop
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in tqdm(valid_loader, desc=f"Validation Epoch {epoch+1}"):
            input_ids, class_id = batch
            input_ids, class_id = input_ids.to(device), class_id.to(device)
            inputs = input_ids[:, :-1, :]
            targets = input_ids[:, 1:, :]
            logits = model(input_ids=inputs, class_id=class_id)
            batch_size, seq_length_minus1, _ = inputs.shape
            logits = logits.reshape(batch_size * seq_length_minus1, model.total_vocab_size)
            targets = targets.reshape(batch_size * seq_length_minus1, model.num_quantizers)

            loss = 0
            for q in range(model.num_quantizers):
                q_targets = targets[:, q]
                q_offset = q * model.vocab_size
                q_logits = logits[:, q_offset : q_offset + model.vocab_size]
                loss += criterion(q_logits, q_targets)
            loss = loss / model.num_quantizers
            val_loss += loss.item()
    avg_val_loss = val_loss / len(valid_loader)
    print(f"Epoch {epoch + 1}/{num_epochs}, Validation Loss: {avg_val_loss:.4f}")
    wandb.log({"Validation Loss": avg_val_loss, "Epoch": epoch + 1})

    if (epoch + 1) % 10 == 0:
        checkpoint_path = f"augmodel_checkpoint_l2_{epoch + 1}.pth"
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, checkpoint_path)
        print(f"Checkpoint saved for epoch {epoch + 1} at {checkpoint_path}")

wandb.finish()


Training Epoch 1: 100%|██████████| 1102/1102 [14:54<00:00,  1.23it/s]


Epoch 1/140, Training Loss: 4.9914


Validation Epoch 1: 100%|██████████| 138/138 [01:42<00:00,  1.34it/s]


Epoch 1/140, Validation Loss: 4.4233


Training Epoch 2: 100%|██████████| 1102/1102 [13:49<00:00,  1.33it/s]


Epoch 2/140, Training Loss: 4.6759


Validation Epoch 2: 100%|██████████| 138/138 [01:40<00:00,  1.37it/s]


Epoch 2/140, Validation Loss: 4.3268


Training Epoch 3: 100%|██████████| 1102/1102 [14:57<00:00,  1.23it/s]


Epoch 3/140, Training Loss: 4.5905


Validation Epoch 3: 100%|██████████| 138/138 [01:42<00:00,  1.34it/s]


Epoch 3/140, Validation Loss: 4.2694


Training Epoch 4: 100%|██████████| 1102/1102 [14:10<00:00,  1.30it/s]


Epoch 4/140, Training Loss: 4.5522


Validation Epoch 4: 100%|██████████| 138/138 [01:40<00:00,  1.37it/s]


Epoch 4/140, Validation Loss: 4.2240


Training Epoch 5: 100%|██████████| 1102/1102 [13:52<00:00,  1.32it/s]


Epoch 5/140, Training Loss: 4.5279


Validation Epoch 5: 100%|██████████| 138/138 [01:29<00:00,  1.54it/s]


Epoch 5/140, Validation Loss: 4.1860


Training Epoch 6: 100%|██████████| 1102/1102 [13:25<00:00,  1.37it/s]


Epoch 6/140, Training Loss: 4.4925


Validation Epoch 6: 100%|██████████| 138/138 [01:28<00:00,  1.55it/s]


Epoch 6/140, Validation Loss: 4.1809


Training Epoch 7: 100%|██████████| 1102/1102 [14:28<00:00,  1.27it/s]


Epoch 7/140, Training Loss: 4.4839


Validation Epoch 7: 100%|██████████| 138/138 [01:41<00:00,  1.36it/s]


Epoch 7/140, Validation Loss: 4.1521


Training Epoch 8: 100%|██████████| 1102/1102 [15:00<00:00,  1.22it/s]


Epoch 8/140, Training Loss: 4.4615


Validation Epoch 8: 100%|██████████| 138/138 [01:42<00:00,  1.34it/s]


Epoch 8/140, Validation Loss: 4.1475


Training Epoch 9: 100%|██████████| 1102/1102 [14:50<00:00,  1.24it/s]


Epoch 9/140, Training Loss: 4.4545


Validation Epoch 9: 100%|██████████| 138/138 [01:40<00:00,  1.37it/s]


Epoch 9/140, Validation Loss: 4.1240


Training Epoch 10: 100%|██████████| 1102/1102 [14:23<00:00,  1.28it/s]


Epoch 10/140, Training Loss: 4.4253


Validation Epoch 10: 100%|██████████| 138/138 [01:41<00:00,  1.36it/s]


Epoch 10/140, Validation Loss: 4.1138
Checkpoint saved for epoch 10 at augmodel_checkpoint_l2_10.pth


Training Epoch 11: 100%|██████████| 1102/1102 [13:43<00:00,  1.34it/s]


Epoch 11/140, Training Loss: 4.4247


Validation Epoch 11: 100%|██████████| 138/138 [01:35<00:00,  1.45it/s]


Epoch 11/140, Validation Loss: 4.1053


Training Epoch 12: 100%|██████████| 1102/1102 [13:49<00:00,  1.33it/s]


Epoch 12/140, Training Loss: 4.4137


Validation Epoch 12: 100%|██████████| 138/138 [01:26<00:00,  1.59it/s]


Epoch 12/140, Validation Loss: 4.1055


Training Epoch 13: 100%|██████████| 1102/1102 [12:39<00:00,  1.45it/s]


Epoch 13/140, Training Loss: 4.4115


Validation Epoch 13: 100%|██████████| 138/138 [01:35<00:00,  1.45it/s]


Epoch 13/140, Validation Loss: 4.0916


Training Epoch 14: 100%|██████████| 1102/1102 [13:46<00:00,  1.33it/s]


Epoch 14/140, Training Loss: 4.3866


Validation Epoch 14: 100%|██████████| 138/138 [01:22<00:00,  1.66it/s]


Epoch 14/140, Validation Loss: 4.0917


Training Epoch 15: 100%|██████████| 1102/1102 [12:23<00:00,  1.48it/s]


Epoch 15/140, Training Loss: 4.3986


Validation Epoch 15: 100%|██████████| 138/138 [01:30<00:00,  1.53it/s]


Epoch 15/140, Validation Loss: 4.0748


Training Epoch 16: 100%|██████████| 1102/1102 [13:39<00:00,  1.35it/s]


Epoch 16/140, Training Loss: 4.3849


Validation Epoch 16: 100%|██████████| 138/138 [01:34<00:00,  1.46it/s]


Epoch 16/140, Validation Loss: 4.0677


Training Epoch 17: 100%|██████████| 1102/1102 [13:54<00:00,  1.32it/s]


Epoch 17/140, Training Loss: 4.3685


Validation Epoch 17: 100%|██████████| 138/138 [01:34<00:00,  1.46it/s]


Epoch 17/140, Validation Loss: 4.0611


Training Epoch 18: 100%|██████████| 1102/1102 [13:56<00:00,  1.32it/s]


Epoch 18/140, Training Loss: 4.3794


Validation Epoch 18: 100%|██████████| 138/138 [01:35<00:00,  1.44it/s]


Epoch 18/140, Validation Loss: 4.0600


Training Epoch 19: 100%|██████████| 1102/1102 [13:53<00:00,  1.32it/s]


Epoch 19/140, Training Loss: 4.3507


Validation Epoch 19: 100%|██████████| 138/138 [01:33<00:00,  1.47it/s]


Epoch 19/140, Validation Loss: 4.0524


Training Epoch 20: 100%|██████████| 1102/1102 [13:49<00:00,  1.33it/s]


Epoch 20/140, Training Loss: 4.3688


Validation Epoch 20: 100%|██████████| 138/138 [01:35<00:00,  1.45it/s]


Epoch 20/140, Validation Loss: 4.0550
Checkpoint saved for epoch 20 at augmodel_checkpoint_l2_20.pth


Training Epoch 21: 100%|██████████| 1102/1102 [13:54<00:00,  1.32it/s]


Epoch 21/140, Training Loss: 4.3674


Validation Epoch 21: 100%|██████████| 138/138 [01:35<00:00,  1.44it/s]


Epoch 21/140, Validation Loss: 4.0465


Training Epoch 22: 100%|██████████| 1102/1102 [12:43<00:00,  1.44it/s]


Epoch 22/140, Training Loss: 4.3665


Validation Epoch 22: 100%|██████████| 138/138 [01:24<00:00,  1.64it/s]


Epoch 22/140, Validation Loss: 4.0478


Training Epoch 23: 100%|██████████| 1102/1102 [13:12<00:00,  1.39it/s]


Epoch 23/140, Training Loss: 4.3599


Validation Epoch 23: 100%|██████████| 138/138 [01:35<00:00,  1.44it/s]


Epoch 23/140, Validation Loss: 4.0396


Training Epoch 24: 100%|██████████| 1102/1102 [13:33<00:00,  1.35it/s]


Epoch 24/140, Training Loss: 4.3376


Validation Epoch 24: 100%|██████████| 138/138 [01:29<00:00,  1.54it/s]


Epoch 24/140, Validation Loss: 4.0437


Training Epoch 25: 100%|██████████| 1102/1102 [13:22<00:00,  1.37it/s]


Epoch 25/140, Training Loss: 4.3410


Validation Epoch 25: 100%|██████████| 138/138 [01:29<00:00,  1.55it/s]


Epoch 25/140, Validation Loss: 4.0400


Training Epoch 26: 100%|██████████| 1102/1102 [13:19<00:00,  1.38it/s]


Epoch 26/140, Training Loss: 4.3361


Validation Epoch 26: 100%|██████████| 138/138 [01:29<00:00,  1.53it/s]


Epoch 26/140, Validation Loss: 4.0276


Training Epoch 27: 100%|██████████| 1102/1102 [12:43<00:00,  1.44it/s]


Epoch 27/140, Training Loss: 4.3361


Validation Epoch 27: 100%|██████████| 138/138 [01:22<00:00,  1.66it/s]


Epoch 27/140, Validation Loss: 4.0251


Training Epoch 28: 100%|██████████| 1102/1102 [13:28<00:00,  1.36it/s]


Epoch 28/140, Training Loss: 4.3198


Validation Epoch 28: 100%|██████████| 138/138 [01:32<00:00,  1.50it/s]


Epoch 28/140, Validation Loss: 4.0326


Training Epoch 29:  50%|█████     | 555/1102 [06:18<06:12,  1.47it/s]


KeyboardInterrupt: 

In [16]:
# Resuming previous run
checkpoint_path = "/scratch/ssr9055/my_env/BEST_CHECKPOINT_AUG_L2.pth"


start_epoch = 0
best_val_loss = float('inf') 

if os.path.exists(checkpoint_path):
    print(f"Loading checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1  # Resume from the next epoch
    print(f"Resuming training from epoch {start_epoch}")


#Number of additional epochs to train
num_epochs = 200

# Adjust the training loop to account for start_epoch
for epoch in range(start_epoch, start_epoch + num_epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch}"):
        
        input_ids, class_id = batch  # Get input IDs and class IDs
        input_ids, class_id = input_ids.to(device), class_id.to(device)

        optimizer.zero_grad()

        inputs = input_ids[:, :-1, :]  # Input tokens
        targets = input_ids[:, 1:, :]  # Target tokens

        # Forward pass (use keyword arguments)
        logits = model(input_ids=inputs, class_id=class_id)  # Pass class_id using keyword args

        # Reshape logits and targets
        batch_size, seq_length_minus1, _ = inputs.shape
        logits = logits.reshape(batch_size * seq_length_minus1, model.total_vocab_size)
        targets = targets.reshape(batch_size * seq_length_minus1, model.num_quantizers)

        # Compute loss per quantizer
        loss = 0
        for q in range(model.num_quantizers):
            q_targets = targets[:, q]  # [batch_size * seq_length_minus1]
            q_offset = q * model.vocab_size
            q_logits = logits[:, q_offset : q_offset + model.vocab_size]  # [batch_size * seq_length_minus1, vocab_size]
            loss += criterion(q_logits, q_targets)
        loss = loss / model.num_quantizers  # Average over quantizers

        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch}/{start_epoch + num_epochs}, Training Loss: {avg_loss:.4f}")
    wandb.log({"Training Loss": avg_loss, "Epoch": epoch + 1})


    # Validation loop
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in tqdm(valid_loader, desc=f"Validation Epoch {epoch}"):
            input_ids, class_id = batch  # Ensure you extract both input_ids and class_id
            input_ids, class_id = input_ids.to(device), class_id.to(device)

            inputs = input_ids[:, :-1, :]
            targets = input_ids[:, 1:, :]

            # Forward pass (use keyword arguments)
            logits = model(input_ids=inputs, class_id=class_id)  # Pass class_id using keyword args

            batch_size, seq_length_minus1, _ = inputs.shape
            logits = logits.reshape(batch_size * seq_length_minus1, model.total_vocab_size)
            targets = targets.reshape(batch_size * seq_length_minus1, model.num_quantizers)

            loss = 0
            for q in range(model.num_quantizers):
                q_targets = targets[:, q]
                q_offset = q * model.vocab_size
                q_logits = logits[:, q_offset : q_offset + model.vocab_size]
                loss += criterion(q_logits, q_targets)
            loss = loss / model.num_quantizers
            val_loss += loss.item()

    avg_val_loss = val_loss / len(valid_loader)
    print(f"Epoch {epoch}/{start_epoch + num_epochs}, Validation Loss: {avg_val_loss:.4f}")
    wandb.log({"Validation Loss": avg_val_loss, "Epoch": epoch + 1})

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        checkpoint_path = f"BEST_CHECKPOINT_AUG_L2.pth"
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
            'best_val_loss': best_val_loss
        }, checkpoint_path)
        print(f"New best validation loss: {best_val_loss:.4f}. Checkpoint saved at {checkpoint_path}")


    if (epoch + 1) % 10 == 0:
        checkpoint_path = f"model_checkpoint_epoch_{epoch + 1}.pth"
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, checkpoint_path)
        print(f"Checkpoint saved for epoch {epoch + 1} at {checkpoint_path}")

wandb.finish()

Loading checkpoint from /scratch/ssr9055/my_env/BEST_CHECKPOINT_AUG_L2.pth


  checkpoint = torch.load(checkpoint_path)


Resuming training from epoch 280


Training Epoch 280: 100%|██████████| 1102/1102 [12:03<00:00,  1.52it/s]


Epoch 280/480, Training Loss: 4.0173


Validation Epoch 280: 100%|██████████| 138/138 [01:20<00:00,  1.71it/s]


Epoch 280/480, Validation Loss: 3.7742
New best validation loss: 3.7742. Checkpoint saved at BEST_CHECKPOINT_AUG_L2.pth


Training Epoch 281: 100%|██████████| 1102/1102 [10:04<00:00,  1.82it/s]


Epoch 281/480, Training Loss: 4.0083


Validation Epoch 281: 100%|██████████| 138/138 [01:05<00:00,  2.11it/s]


Epoch 281/480, Validation Loss: 3.7715
New best validation loss: 3.7715. Checkpoint saved at BEST_CHECKPOINT_AUG_L2.pth


Training Epoch 282: 100%|██████████| 1102/1102 [10:06<00:00,  1.82it/s]


Epoch 282/480, Training Loss: 4.0248


Validation Epoch 282: 100%|██████████| 138/138 [01:05<00:00,  2.11it/s]


Epoch 282/480, Validation Loss: 3.7730


Training Epoch 283: 100%|██████████| 1102/1102 [10:05<00:00,  1.82it/s]


Epoch 283/480, Training Loss: 4.0248


Validation Epoch 283: 100%|██████████| 138/138 [01:04<00:00,  2.12it/s]


Epoch 283/480, Validation Loss: 3.7724


Training Epoch 284: 100%|██████████| 1102/1102 [10:00<00:00,  1.83it/s]


Epoch 284/480, Training Loss: 4.0416


Validation Epoch 284: 100%|██████████| 138/138 [01:04<00:00,  2.15it/s]


Epoch 284/480, Validation Loss: 3.7704
New best validation loss: 3.7704. Checkpoint saved at BEST_CHECKPOINT_AUG_L2.pth


Training Epoch 285: 100%|██████████| 1102/1102 [10:04<00:00,  1.82it/s]


Epoch 285/480, Training Loss: 4.0113


Validation Epoch 285: 100%|██████████| 138/138 [01:05<00:00,  2.12it/s]


Epoch 285/480, Validation Loss: 3.7699
New best validation loss: 3.7699. Checkpoint saved at BEST_CHECKPOINT_AUG_L2.pth


Training Epoch 286: 100%|██████████| 1102/1102 [10:36<00:00,  1.73it/s]


Epoch 286/480, Training Loss: 4.0302


Validation Epoch 286: 100%|██████████| 138/138 [01:18<00:00,  1.76it/s]


Epoch 286/480, Validation Loss: 3.7747


Training Epoch 287: 100%|██████████| 1102/1102 [11:07<00:00,  1.65it/s]


Epoch 287/480, Training Loss: 4.0213


Validation Epoch 287: 100%|██████████| 138/138 [01:04<00:00,  2.13it/s]


Epoch 287/480, Validation Loss: 3.7729


Training Epoch 288: 100%|██████████| 1102/1102 [10:02<00:00,  1.83it/s]


Epoch 288/480, Training Loss: 4.0164


Validation Epoch 288: 100%|██████████| 138/138 [01:05<00:00,  2.11it/s]


Epoch 288/480, Validation Loss: 3.7724


Training Epoch 289: 100%|██████████| 1102/1102 [10:04<00:00,  1.82it/s]


Epoch 289/480, Training Loss: 4.0250


Validation Epoch 289: 100%|██████████| 138/138 [01:05<00:00,  2.11it/s]


Epoch 289/480, Validation Loss: 3.7732
Checkpoint saved for epoch 290 at model_checkpoint_epoch_290.pth


Training Epoch 290: 100%|██████████| 1102/1102 [10:04<00:00,  1.82it/s]


Epoch 290/480, Training Loss: 4.0210


Validation Epoch 290: 100%|██████████| 138/138 [01:05<00:00,  2.12it/s]


Epoch 290/480, Validation Loss: 3.7699


Training Epoch 291: 100%|██████████| 1102/1102 [10:00<00:00,  1.84it/s]


Epoch 291/480, Training Loss: 4.0141


Validation Epoch 291: 100%|██████████| 138/138 [01:04<00:00,  2.13it/s]


Epoch 291/480, Validation Loss: 3.7718


Training Epoch 292: 100%|██████████| 1102/1102 [09:57<00:00,  1.84it/s]


Epoch 292/480, Training Loss: 4.0174


Validation Epoch 292: 100%|██████████| 138/138 [01:04<00:00,  2.14it/s]


Epoch 292/480, Validation Loss: 3.7740


Training Epoch 293: 100%|██████████| 1102/1102 [10:01<00:00,  1.83it/s]


Epoch 293/480, Training Loss: 4.0186


Validation Epoch 293: 100%|██████████| 138/138 [01:04<00:00,  2.12it/s]


Epoch 293/480, Validation Loss: 3.7672
New best validation loss: 3.7672. Checkpoint saved at BEST_CHECKPOINT_AUG_L2.pth


Training Epoch 294: 100%|██████████| 1102/1102 [10:01<00:00,  1.83it/s]


Epoch 294/480, Training Loss: 4.0105


Validation Epoch 294: 100%|██████████| 138/138 [01:05<00:00,  2.12it/s]


Epoch 294/480, Validation Loss: 3.7702


Training Epoch 295: 100%|██████████| 1102/1102 [11:16<00:00,  1.63it/s]


Epoch 295/480, Training Loss: 4.0268


Validation Epoch 295: 100%|██████████| 138/138 [01:17<00:00,  1.78it/s]


Epoch 295/480, Validation Loss: 3.7677


Training Epoch 296: 100%|██████████| 1102/1102 [10:33<00:00,  1.74it/s]


Epoch 296/480, Training Loss: 4.0092


Validation Epoch 296: 100%|██████████| 138/138 [01:04<00:00,  2.12it/s]


Epoch 296/480, Validation Loss: 3.7687


Training Epoch 297: 100%|██████████| 1102/1102 [10:03<00:00,  1.83it/s]


Epoch 297/480, Training Loss: 4.0293


Validation Epoch 297: 100%|██████████| 138/138 [01:04<00:00,  2.14it/s]


Epoch 297/480, Validation Loss: 3.7733


Training Epoch 298: 100%|██████████| 1102/1102 [09:59<00:00,  1.84it/s]


Epoch 298/480, Training Loss: 4.0031


Validation Epoch 298: 100%|██████████| 138/138 [01:05<00:00,  2.12it/s]


Epoch 298/480, Validation Loss: 3.7690


Training Epoch 299: 100%|██████████| 1102/1102 [09:59<00:00,  1.84it/s]


Epoch 299/480, Training Loss: 4.0150


Validation Epoch 299: 100%|██████████| 138/138 [01:04<00:00,  2.14it/s]


Epoch 299/480, Validation Loss: 3.7691
Checkpoint saved for epoch 300 at model_checkpoint_epoch_300.pth


Training Epoch 300: 100%|██████████| 1102/1102 [09:58<00:00,  1.84it/s]


Epoch 300/480, Training Loss: 4.0263


Validation Epoch 300: 100%|██████████| 138/138 [01:04<00:00,  2.13it/s]


Epoch 300/480, Validation Loss: 3.7663
New best validation loss: 3.7663. Checkpoint saved at BEST_CHECKPOINT_AUG_L2.pth


Training Epoch 301:  49%|████▉     | 538/1102 [04:52<05:07,  1.84it/s]


KeyboardInterrupt: 

### Generating audio

In [20]:
import torch
import torch.nn.functional as F
from tqdm import tqdm
import soundfile as sf
from IPython.display import Audio as IPyAudio

checkpoint_path = "/scratch/ssr9055/my_env/BEST_CHECKPOINT_AUG_L2.pth"

# Helper function to remove delay pattern
def remove_delay_pattern(delayed_codes, num_quantizers):
    num_frames = delayed_codes.shape[0] - (num_quantizers - 1)
    codes = torch.zeros(num_frames, num_quantizers, dtype=delayed_codes.dtype)
    for q in range(num_quantizers):
        codes[:, q] = delayed_codes[q:q + num_frames, q]
    return codes  # Shape: [num_frames, num_quantizers]

# Function to generate audio
def generate_audio(model, encodec_model, class_id, num_quantizers=8, codebook_size=1024, max_length=600, temperature=1.0, device='cpu'):
    model.eval()  # Set model to evaluation mode
    encodec_model.eval()  # Set encodec_model to evaluation mode

    model.to(device)  # Move model to the specified device

    # Start with padding tokens for each quantizer
    start_token = codebook_size  # Padding token index
    input_ids = torch.full((1, 1, num_quantizers), start_token, dtype=torch.long, device=device)  # Shape: [1, 1, num_quantizers]

    # Convert the class_id to a tensor and move to device
    class_id_tensor = torch.tensor([class_id], device=device)

    generated = []  # List to hold generated tokens

    with torch.no_grad():  # Disable gradient calculations for generation
        for _ in tqdm(range(max_length), desc="Generating Audio"):
            # Forward pass through the model, passing class_id_tensor
            logits = model(input_ids=input_ids, class_id=class_id_tensor)  # [1, seq_length, total_vocab_size]
            logits = logits[:, -1, :]  # [1, total_vocab_size] - Get the logits for the last time step

            # Apply temperature to control randomness
            logits = logits / temperature

            next_tokens = []  # List to store the next token for each quantizer

            # Sample next token for each quantizer
            for q in range(num_quantizers):
                q_offset = q * (codebook_size + 1)
                q_logits = logits[:, q_offset:q_offset + codebook_size + 1]  # [1, vocab_size]
                q_probs = F.softmax(q_logits, dim=-1)  # Convert logits to probabilities
                q_next_token = torch.multinomial(q_probs, num_samples=1)  # Sample next token
                q_next_token = q_next_token.squeeze(1)  # Remove extra dimension

                # If the sampled token is the padding token, replace it with a valid token
                q_next_token_value = q_next_token.item()
                if q_next_token_value == codebook_size:
                    q_next_token_value = torch.randint(0, codebook_size, (1,)).item()

                next_tokens.append(torch.tensor([q_next_token_value], device=device, dtype=torch.long))

            # Stack the next tokens for each quantizer and append to generated sequence
            next_tokens = torch.stack(next_tokens, dim=1)  # Shape: [1, num_quantizers]
            generated.append(next_tokens.squeeze(0))  # Append generated tokens
            input_ids = torch.cat([input_ids, next_tokens.unsqueeze(0)], dim=1)  # Update input_ids for the next time step

    # Stack the generated tokens to form the final token sequence
    generated_tokens = torch.stack(generated, dim=0)  # [seq_length, num_quantizers]

    # Remove the delay pattern from generated tokens
    codes = remove_delay_pattern(generated_tokens, num_quantizers)

    # Add batch dimension to match the input shape for decode
    codes = codes.unsqueeze(0)

    # Ensure codes are within valid codebook size
    codes = codes.clamp(0, codebook_size - 1)

    # Decode the audio using the EnCodec model
    try:
        with torch.no_grad():
            codes = codes.permute(0, 2, 1)  # [batch_size, num_quantizers, num_frames]
            encoded_frames = [(codes.to(next(encodec_model.parameters()).device), None)]
            decoded_audio = encodec_model.decode(encoded_frames)
            audio = decoded_audio.squeeze().cpu().detach().numpy()  # Convert to numpy array
        return audio
    except Exception as e:
        print(f"Error during decoding: {str(e)}")
        return None


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
max_length = 350  # Adjust as needed
temperature = 0.7  # Adjust as needed

checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# Loop over each category in clas_dict
for class_name, class_id in clas_dict.items():
    print(f"Generating audio for category: {class_name}")

    generated_audio = generate_audio(
        model,
        encodec_model,
        class_id=class_id,  # Pass the class_id for each category
        num_quantizers=num_quantizers,
        codebook_size=codebook_size,
        max_length=max_length,
        temperature=temperature,
        device=device
    )

    if generated_audio is not None:
        # Save the generated audio to a file (optional)
        output_filename = f'generated_audio_{class_name}.wav'
        sf.write(output_filename, generated_audio, 24000)

        # Display the audio
        display(IPyAudio(output_filename))
    else:
        print(f"Audio generation failed for category: {class_name}")


  checkpoint = torch.load(checkpoint_path, map_location=device)


Generating audio for category: DogBark


Generating Audio: 100%|██████████| 350/350 [00:02<00:00, 131.70it/s]


Generating audio for category: Footstep


Generating Audio: 100%|██████████| 350/350 [00:02<00:00, 131.57it/s]


Generating audio for category: Gunshot


Generating Audio: 100%|██████████| 350/350 [00:02<00:00, 132.15it/s]


Generating audio for category: Keyboard


Generating Audio: 100%|██████████| 350/350 [00:02<00:00, 121.32it/s]


Generating audio for category: MovingMotorVehicle


Generating Audio: 100%|██████████| 350/350 [00:02<00:00, 132.87it/s]


Generating audio for category: Rain


Generating Audio: 100%|██████████| 350/350 [00:02<00:00, 132.87it/s]


Generating audio for category: SneezeCough


Generating Audio: 100%|██████████| 350/350 [00:02<00:00, 131.23it/s]
