In [1]:
!pip install torchaudio pandas torchvggish nltk soundfile numpy scipy

Collecting torchaudio
  Using cached torchaudio-2.4.0-cp310-cp310-manylinux1_x86_64.whl.metadata (6.4 kB)
Collecting torchvggish
  Using cached torchvggish-0.2-py3-none-any.whl
Collecting soundfile
  Using cached soundfile-0.12.1-py2.py3-none-manylinux_2_31_x86_64.whl.metadata (14 kB)
Collecting torch==2.4.0 (from torchaudio)
  Using cached torch-2.4.0-cp310-cp310-manylinux1_x86_64.whl.metadata (26 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch==2.4.0->torchaudio)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch==2.4.0->torchaudio)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch==2.4.0->torchaudio)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.4.0->torchaudio)
  U

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
import pandas as pd
import numpy as np
from tqdm import tqdm
import pandas as pd
import nltk
from collections import Counter
import torchaudio
from torch.utils.data import Dataset, DataLoader, random_split
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import random
from tqdm import tqdm
import math
import gc
from torch.cuda.amp import autocast, GradScaler
import pandas as pd

# 1. Set up the environment

# Allocator setting for PyTorch's CUDA caching allocator; set here, before any CUDA memory is allocated.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Vocabulary.tokenize calls nltk.word_tokenize, which needs the Punkt tokenizer data.
# Newer NLTK releases look for "punkt_tab", older ones for "punkt"; request both.
# With quiet=True a failed or unknown download returns False instead of printing/raising.
for _nltk_resource in ("punkt", "punkt_tab"):
    nltk.download(_nltk_resource, quiet=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Vocabulary:
    """Word-level vocabulary that maps tokens to integer ids and back.

    Ids 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``, ``<EOS>``
    and ``<UNK>``; learned words receive ids starting at 4.
    """

    def __init__(self, freq_threshold):
        # A word gets its own id once it has been seen this many times.
        self.freq_threshold = freq_threshold
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {token: index for index, token in self.itos.items()}

    def __len__(self):
        """Number of entries in the id -> token table (special tokens included)."""
        return len(self.itos)

    @staticmethod
    def tokenize(text):
        """Lower-case ``text`` and split it with ``nltk.word_tokenize``."""
        return nltk.word_tokenize(text.lower())

    def build_vocabulary(self, sentence_list):
        """Register every word whose running count reaches ``freq_threshold``.

        Ids are handed out in the order in which words cross the threshold
        while the sentences are scanned.
        """
        counts = Counter()
        next_index = 4

        for sentence in sentence_list:
            for token in self.tokenize(sentence):
                counts[token] += 1

                # Skip words that are still too rare or already registered.
                if counts[token] < self.freq_threshold or token in self.stoi:
                    continue

                self.stoi[token] = next_index
                self.itos[next_index] = token
                next_index += 1

    def numericalize(self, text):
        """Convert ``text`` to a list of ids; unknown tokens map to ``<UNK>``."""
        tokens = self.tokenize(text)
        return [self.stoi.get(token, self.stoi["<UNK>"]) for token in tokens]
# 3. Dataset and DataLoader

class MusicCapsDataset(Dataset):
    """Audio/caption pairs read from one or more CSV files plus a directory of ``.wav`` clips.

    Each CSV must provide a ``ytid`` column (the clip is loaded from
    ``<audio_dir>/<ytid>.wav``) and a ``caption`` column.

    Args:
        csv_files: Iterable of CSV paths; their rows are concatenated.
        audio_dir: Directory that contains the ``.wav`` files.
        vocab: Object exposing ``stoi`` and ``numericalize``; may be ``None`` at
            construction time and assigned later (it is only used in ``__getitem__``).
        fixed_length: Number of samples every waveform is truncated / zero-padded to.
        n_mels: Number of mel bins of the spectrogram.
        target_sample_rate: Sample rate every clip is resampled to.
        random_state: Optional seed forwarded to ``DataFrame.sample`` so the row
            shuffle (and therefore any index-based split) is reproducible.
            ``None`` keeps the previous non-deterministic shuffle.
    """

    def __init__(self, csv_files, audio_dir, vocab, fixed_length=160000, n_mels=128, target_sample_rate=16000, random_state=None):
        # Merge all CSVs into a single DataFrame and shuffle the rows once.
        dataframes = [pd.read_csv(csv_file) for csv_file in csv_files]
        merged = pd.concat(dataframes, ignore_index=True)
        self.data = merged.sample(frac=1, random_state=random_state).reset_index(drop=True)

        print(self.data.shape)
        self.audio_dir = audio_dir
        self.vocab = vocab
        self.fixed_length = fixed_length
        self.n_mels = n_mels
        self.target_sample_rate = target_sample_rate
        self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.target_sample_rate,
            n_mels=self.n_mels
        )
        # One Resample transform per source sample rate, built lazily and reused
        # instead of being re-created for every item.
        self._resamplers = {}

    def __len__(self):
        return len(self.data)

    def _resample(self, waveform, sample_rate):
        """Resample ``waveform`` from ``sample_rate`` to ``target_sample_rate``."""
        resampler = self._resamplers.get(sample_rate)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.target_sample_rate)
            self._resamplers[sample_rate] = resampler
        return resampler(waveform)

    def __getitem__(self, idx):
        """Return a dict with the mel spectrogram, token ids and metadata of row ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        row = self.data.iloc[idx]  # single positional lookup for both columns
        audio_id = row['ytid']
        caption = row['caption']
        audio_path = os.path.join(self.audio_dir, f"{audio_id}.wav")

        waveform, sample_rate = torchaudio.load(audio_path)
        waveform = waveform.mean(dim=0, keepdim=True)  # Convert to mono
        waveform = self._resample(waveform, sample_rate)

        # Force every clip to exactly `fixed_length` samples.
        if waveform.size(1) > self.fixed_length:
            waveform = waveform[:, :self.fixed_length]
        else:
            pad_length = self.fixed_length - waveform.size(1)
            waveform = F.pad(waveform, (0, pad_length))

        mel_spectrogram = self.mel_spectrogram(waveform)

        sos_id = self.vocab.stoi["<SOS>"]
        eos_id = self.vocab.stoi["<EOS>"]
        numericalized_caption = [sos_id] + self.vocab.numericalize(caption) + [eos_id]

        sample = {
            'audio': mel_spectrogram.squeeze(0).T,  # Transpose to match (time, n_mels)
            'caption': torch.tensor(numericalized_caption),
            'original_caption': caption,
            'path': audio_path,
            'sample_rate': self.target_sample_rate
        }
        return sample

def pad_sequence(batch):
    """Collect the ``'caption'`` tensor of every sample and zero-pad them into one batch-first tensor."""
    captions = [sample['caption'] for sample in batch]
    return torch.nn.utils.rnn.pad_sequence(captions, batch_first=True, padding_value=0)

def collate_fn(data):
    """Merge a list of dataset samples into one zero-padded batch.

    Returns a dict with:
      * ``'audio'``: float tensor of shape (batch, longest audio length, feature size),
      * ``'caption'``: long tensor of shape (batch, longest caption length),
      * ``'original_caption'`` and ``'path'``: plain Python lists, one entry per sample.
    """
    batch_size = len(data)
    audio_clips = [sample['audio'] for sample in data]
    token_seqs = [sample['caption'] for sample in data]

    # Zero-filled canvas sized for the longest clip; each clip is copied into its leading rows.
    longest_audio = max(len(clip) for clip in audio_clips)
    feature_dim = audio_clips[0].size(1)
    padded_audio = torch.zeros(batch_size, longest_audio, feature_dim)
    for row, clip in zip(padded_audio, audio_clips):
        row[:len(clip), :] = clip

    # Same approach for the token ids (0 is the padding value).
    longest_caption = max(len(seq) for seq in token_seqs)
    padded_captions = torch.zeros(batch_size, longest_caption).long()
    for row, seq in zip(padded_captions, token_seqs):
        row[:len(seq)] = seq

    return {
        'audio': padded_audio,
        'caption': padded_captions,
        'original_caption': [sample['original_caption'] for sample in data],
        'path': [sample['path'] for sample in data],
    }

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import math

class PositionalEncoding(nn.Module):
    """Additive sinusoidal positional encoding for sequence-first tensors.

    The table is stored in the non-trainable buffer ``pe`` with shape
    ``(max_len, 1, d_model)``: even feature columns hold sines, odd columns
    cosines.

    Args:
        d_model: Feature size of the tensors passed to ``forward`` (odd sizes are supported).
        max_len: Number of positions pre-computed in the table.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one cosine column fewer than sine columns, so the
        # frequency vector is trimmed to the number of odd columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # (max_len, d_model) -> (max_len, 1, d_model) so it broadcasts over the batch axis.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Add the encoding of the first ``x.size(0)`` positions to ``x``.

        ``x`` must be laid out as (sequence, batch, d_model): the first axis is
        what selects the positions.
        """
        return x + self.pe[:x.size(0), :]

class AudioEncoder(nn.Module):
    """Mel-spectrogram encoder: 1-D convolution, positional encoding, transformer encoder.

    Args:
        n_mels: Number of mel bins (input channels of the convolution).
        d_model: Feature size used by the transformer.
        nhead: Number of attention heads.
        num_encoder_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden size of each layer's feed-forward network.
    """

    def __init__(self, n_mels, d_model, nhead, num_encoder_layers, dim_feedforward):
        super().__init__()
        self.conv = nn.Conv1d(n_mels, d_model, kernel_size=3, padding=1)
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward),
            num_layers=num_encoder_layers
        )
        self.positional_encoding = PositionalEncoding(d_model)

    def forward(self, src):
        """Encode ``src`` of shape (batch_size, time_steps, n_mels) into (batch_size, time_steps, d_model).

        The positional encoding indexes positions along the first axis of its
        input, so it is applied after switching to the sequence-first layout.
        Previously it was applied to the batch-first tensor, which made the
        encoding depend on the sample's index in the batch instead of on the
        time step. Outputs therefore differ from checkpoints trained with the
        old behaviour; the parameter names are unchanged, so such checkpoints
        still load.
        """
        features = self.conv(src.transpose(1, 2))  # (batch_size, d_model, time_steps)
        features = features.permute(2, 0, 1)  # (time_steps, batch_size, d_model)
        features = self.positional_encoding(features)  # positions follow the time axis
        output = self.transformer_encoder(features)  # (time_steps, batch_size, d_model)
        return output.transpose(0, 1)  # (batch_size, time_steps, d_model)

class CaptionDecoder(nn.Module):
    """Token decoder: embedding, positional encoding, transformer decoder, vocabulary projection.

    Args:
        vocab_size: Number of token ids (embedding rows and output logits).
        d_model: Feature size used throughout the decoder.
        nhead: Number of attention heads.
        num_decoder_layers: Number of stacked decoder layers.
        dim_feedforward: Hidden size of each layer's feed-forward network.
    """

    def __init__(self, vocab_size, d_model, nhead, num_decoder_layers, dim_feedforward):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, tgt, memory):
        """Map token ids ``tgt`` (batch_size, tgt_len) and encoder ``memory`` (batch_size, src_len, d_model) to logits.

        No attention mask is passed to the transformer decoder here, so every
        target position can attend to all other target positions.

        Returns:
            Tensor of shape (batch_size, tgt_len, vocab_size).
        """
        embedded = self.embedding(tgt).transpose(0, 1)  # (tgt_len, batch_size, d_model)
        embedded = self.positional_encoding(embedded)
        decoded = self.transformer_decoder(embedded, memory.transpose(0, 1))  # (tgt_len, batch_size, d_model)
        logits = self.fc_out(decoded)  # (tgt_len, batch_size, vocab_size)
        return logits.transpose(0, 1)

class AudioCaptioningModel(nn.Module):
    """Encoder-decoder transformer that maps mel spectrograms to caption token logits.

    Args:
        n_mels: Number of mel bins of the input spectrogram.
        vocab_size: Number of token ids.
        d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward:
            Transformer hyper-parameters forwarded to the encoder / decoder.
        pad_idx: Token id treated as padding when building the target padding
            mask. Defaults to 0, the id of ``<PAD>`` in this notebook's
            ``Vocabulary``. Pass ``None`` to disable the padding mask.
    """

    def __init__(self, n_mels, vocab_size, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, pad_idx=0):
        super().__init__()
        self.encoder = AudioEncoder(n_mels, d_model, nhead, num_encoder_layers, dim_feedforward)
        self.decoder = CaptionDecoder(vocab_size, d_model, nhead, num_decoder_layers, dim_feedforward)
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.pad_idx = pad_idx

    def forward(self, src, tgt, teacher_forcing_ratio=0.5):
        """Compute next-token logits for ``tgt`` given the audio ``src``.

        Args:
            src: Spectrogram batch, (batch_size, time_steps, n_mels).
            tgt: Token ids, (batch_size, tgt_len). Position ``i`` of the output
                predicts ``tgt[:, i + 1]``, so the last token is dropped from the input.
            teacher_forcing_ratio: Accepted for interface compatibility but not
                used; the ground-truth tokens are always fed to the decoder.

        Returns:
            Logits of shape (batch_size, tgt_len - 1, vocab_size).

        Note:
            The padding mask assumes padding only appears after real tokens.
            A row consisting solely of padding has no valid key to attend to.
        """
        memory = self.encoder(src)

        # Decoder input: everything but the last token.
        tgt_input = tgt[:, :-1]

        # Embed and apply positional encoding to the target sequence
        tgt_embedded = self.decoder.embedding(tgt_input)
        tgt_embedded = self.decoder.positional_encoding(tgt_embedded.transpose(0, 1))  # (tgt_len, batch_size, d_model)

        # Causal mask (True = blocked): position i may only attend to positions <= i.
        # It is boolean so that it has the same type as the padding mask below.
        tgt_len = tgt_input.size(1)
        tgt_mask = torch.triu(
            torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device),
            diagonal=1,
        )

        # Padding mask. The embedding is created without `padding_idx`, so checking
        # only that attribute never produced a mask; fall back to `self.pad_idx`.
        padding_idx = self.decoder.embedding.padding_idx
        if padding_idx is None:
            padding_idx = getattr(self, 'pad_idx', 0)
        if padding_idx is not None:
            tgt_padding_mask = tgt_input.eq(padding_idx)  # (batch_size, tgt_len)
        else:
            tgt_padding_mask = None

        # Apply the transformer decoder
        output = self.decoder.transformer_decoder(
            tgt_embedded,
            memory.transpose(0, 1),
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_padding_mask
        )

        # Apply the final linear layer
        output = self.decoder.fc_out(output)  # (tgt_len, batch_size, vocab_size)
        return output.transpose(0, 1)  # (batch_size, tgt_len, vocab_size)


In [None]:
import torch
import numpy as np
from tqdm import tqdm

class EarlyStopping:
    """Stop training when the validation loss has not improved for ``patience`` calls.

    Every improvement is checkpointed to
    ``<path>/epoch_<epoch_number>_loss_<val_loss>``.

    Args:
        patience: Number of consecutive non-improving calls after which
            ``early_stop`` is set to ``True``.
        verbose: Print a message on every counter increase and checkpoint.
        delta: Minimum improvement of the (negated) loss that counts as progress.
        path: Directory the checkpoints are written to; created on demand.
    """

    def __init__(self, patience=5, verbose=False, delta=0, path='checkpoints'):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # `np.Inf` was removed in NumPy 2.0; `np.inf` exists in every NumPy version.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model, epoch_number):
        """Record ``val_loss`` for this epoch, checkpointing ``model`` on improvement."""
        score = -val_loss  # higher is better

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, epoch_number)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, epoch_number)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, epoch_number):
        """Save ``model.state_dict()`` and remember ``val_loss`` as the best loss so far."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        # torch.save does not create missing directories.
        if self.path:
            os.makedirs(self.path, exist_ok=True)
        torch.save(model.state_dict(), f'{self.path}/epoch_{epoch_number}_loss_{val_loss}')
        self.val_loss_min = val_loss

# Training function
def train_epoch(model, dataloader, criterion, optimizer, epoch, device, teacher_forcing_ratio=0.5):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Each batch is a dict with ``'audio'`` and ``'caption'`` tensors. The model is
    called as ``model(audio, captions, teacher_forcing_ratio)`` and its logits are
    compared with the captions shifted left by one token.
    """
    model.train()
    running_loss = 0
    batches = tqdm(dataloader, desc=f"Training Epoch {epoch+1}")

    for step, batch in enumerate(batches, start=1):
        audio = batch['audio'].to(device)
        captions = batch['caption'].to(device)

        optimizer.zero_grad()
        logits = model(audio, captions, teacher_forcing_ratio)

        # Flatten to (batch * tgt_len, vocab) vs. (batch * tgt_len,) for the criterion.
        flat_logits = logits.reshape(-1, model.vocab_size)
        flat_targets = captions[:, 1:].contiguous().view(-1)

        batch_loss = criterion(flat_logits, flat_targets)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

        # Show the running mean loss in the progress bar.
        batches.set_postfix(loss=running_loss / step)

    return running_loss / len(dataloader)

# Validation function
def validate_epoch(model, dataloader, criterion, epoch, device):
    """Evaluate ``model`` on ``dataloader`` without gradients and return the mean batch loss.

    ``epoch`` is only used for the progress-bar label. The model is called with
    ``teacher_forcing_ratio=0``.
    """
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f"Validating Epoch {epoch+1}"):
            audio = batch['audio'].to(device)
            captions = batch['caption'].to(device)

            logits = model(audio, captions, teacher_forcing_ratio=0)

            # Flatten logits and the left-shifted captions for the criterion.
            flat_logits = logits.reshape(-1, model.vocab_size)
            flat_targets = captions[:, 1:].contiguous().view(-1)

            total_loss += criterion(flat_logits, flat_targets).item()

    return total_loss / len(dataloader)

# Full training loop with early stopping and learning rate scheduler
def train_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=20, teacher_forcing_ratio=0.5, patience=5, model_save_path='checkpoints5', scheduler=None, sample_audio_path='./song_eden.wav', vocab=None):
    """Train ``model`` with early stopping and optional LR scheduling.

    Args:
        model, train_loader, val_loader, criterion, optimizer, device:
            Forwarded to ``train_epoch`` / ``validate_epoch``.
        num_epochs: Maximum number of epochs.
        teacher_forcing_ratio: Forwarded to ``train_epoch``.
        patience: Early-stopping patience in epochs.
        model_save_path: Directory used by ``EarlyStopping`` for checkpoints.
        scheduler: Optional scheduler; ``scheduler.step(val_loss)`` is called once
            per epoch, which matches the ``ReduceLROnPlateau`` calling convention.
        sample_audio_path: Audio file captioned with ``infer`` after every epoch as
            a qualitative progress check. Pass ``None`` to turn the preview off.
        vocab: Vocabulary used for the preview caption. When omitted, a
            module-level variable named ``vocab`` is used if one exists (this is
            what the function silently relied on before).

    Returns:
        ``(train_losses, val_losses)``: one entry per completed epoch.
    """
    early_stopping = EarlyStopping(patience=patience, verbose=True, path=model_save_path)

    # The preview is a convenience only; a missing file or vocabulary must not abort training.
    if vocab is None:
        vocab = globals().get('vocab')
    preview_enabled = bool(sample_audio_path) and vocab is not None and os.path.exists(sample_audio_path)
    if not preview_enabled:
        print(f'Per-epoch sample caption disabled (audio: {sample_audio_path!r}, vocab available: {vocab is not None})')

    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        train_loss = train_epoch(model, train_loader, criterion, optimizer, epoch, device, teacher_forcing_ratio)
        val_loss = validate_epoch(model, val_loader, criterion, epoch, device)

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        print(f'Epoch {epoch + 1}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')
        if preview_enabled:
            print(infer(audio_path=sample_audio_path, model=model, vocab=vocab, device=device))

        # Step the scheduler based on validation loss if provided
        if scheduler is not None:
            scheduler.step(val_loss)

        early_stopping(val_loss, model, epoch)

        if early_stopping.early_stop:
            print("Early stopping")
            break

    return train_losses, val_losses


In [None]:
import torch
import torchaudio
import torch.nn.functional as F

def infer(audio_path, model, vocab, fixed_length=160000, n_mels=128, target_sample_rate=16000, max_caption_length=50, device='cuda'):
    """Greedily decode a caption for the audio file at ``audio_path``.

    The audio is preprocessed like the training data: mixed down to mono,
    resampled to ``target_sample_rate``, truncated / zero-padded to
    ``fixed_length`` samples and converted to a mel spectrogram.

    Args:
        audio_path: Path of the audio file to caption.
        model: Model exposing ``encoder`` and ``decoder`` (with ``embedding``,
            ``positional_encoding``, ``transformer_decoder`` and ``fc_out``).
            It is switched to eval mode and left in that mode.
        vocab: Object exposing ``stoi`` and ``itos``.
        fixed_length, n_mels, target_sample_rate: Preprocessing parameters.
        max_caption_length: Maximum number of tokens to generate.
        device: Device the inputs are created on; should match the model's device.

    Returns:
        The generated words joined by single spaces. Generation stops at
        ``<EOS>``, which is not part of the returned text.
    """
    # Load and preprocess the audio
    waveform, sample_rate = torchaudio.load(audio_path)
    waveform = waveform.mean(dim=0, keepdim=True)  # Convert to mono
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)(waveform)

    # Pad or truncate waveform to fixed length
    if waveform.size(1) > fixed_length:
        waveform = waveform[:, :fixed_length]
    else:
        pad_length = fixed_length - waveform.size(1)
        waveform = F.pad(waveform, (0, pad_length))

    # Convert waveform to mel spectrogram
    mel_spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate=target_sample_rate,
        n_mels=n_mels
    )(waveform)

    # Prepare the input tensor for the model
    mel_spectrogram = mel_spectrogram.squeeze(0).T.unsqueeze(0).to(device)  # (1, time_steps, n_mels)

    model.eval()

    sos_id = vocab.stoi['<SOS>']
    eos_id = vocab.stoi['<EOS>']

    # The decoder input starts with <SOS> and grows by one token per step.
    tgt_input = torch.tensor([[sos_id]], device=device)  # Shape: (1, 1)
    generated_tokens = []

    # The whole generation runs without autograd; previously only the encoder did,
    # so every decoding step built an unused computation graph.
    with torch.no_grad():
        memory = model.encoder(mel_spectrogram).transpose(0, 1)  # sequence-first for the decoder

        for _ in range(max_caption_length):
            # Embed and apply positional encoding to the target input
            tgt_embedded = model.decoder.embedding(tgt_input)
            tgt_embedded = model.decoder.positional_encoding(tgt_embedded.transpose(0, 1))  # (tgt_len, batch_size, d_model)

            # Causal mask for the current prefix length
            tgt_len = tgt_input.size(1)
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_len).to(device)  # (tgt_len, tgt_len)

            output = model.decoder.transformer_decoder(
                tgt_embedded,
                memory,
                tgt_mask=tgt_mask,
                tgt_key_padding_mask=None  # No padding mask needed during inference
            )
            logits = model.decoder.fc_out(output)  # (tgt_len, batch_size, vocab_size)

            # Greedy choice at the last position
            next_token = logits[-1].argmax(-1, keepdim=True)  # Shape: (batch_size, 1)
            token_id = next_token.item()

            # Stop on <EOS> without adding it to the caption.
            if token_id == eos_id:
                break

            generated_tokens.append(token_id)
            tgt_input = torch.cat([tgt_input, next_token], dim=1)

    # Convert the generated sequence of tokens to words
    return ' '.join(vocab.itos[token] for token in generated_tokens)


In [None]:

# Load and prepare data
csv_files = ['mayo_final_final.csv', 'musiccaps-public.csv']
# csv_files = ['musiccaps-public.csv']
audio_dir = './music_data/music_data'
freq_threshold = 5

# Raw caption table in file order. NOTE: MusicCapsDataset shuffles its own copy of the
# rows, so positions in `df` do NOT correspond to dataset indices.
dataframes = [pd.read_csv(csv_file) for csv_file in csv_files]
df = pd.concat(dataframes, ignore_index=True)

# Split dataset (80% train, remainder halved into validation / test)
dataset = MusicCapsDataset(csv_files=csv_files, audio_dir=audio_dir, vocab=None)  # vocab is attached below

train_size = int(0.8 * len(dataset))
held_out = len(dataset) - train_size
test_size = held_out // 2
val_size = held_out - test_size  # validation gets the extra sample when `held_out` is odd
random_seed = 42
generator = torch.Generator().manual_seed(random_seed)
# The three sizes always sum to len(dataset), whatever the parity of `held_out`.
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size], generator=generator
)

# Build vocabulary from the training set only. The split indices refer to the shuffled
# rows held by the dataset, so the captions must be read from `dataset.data`
# (reading them from `df` picked unrelated rows).
train_captions = dataset.data['caption'].iloc[train_dataset.indices].tolist()
vocab = Vocabulary(freq_threshold)
vocab.build_vocabulary(train_captions)

# All three subsets wrap the same underlying dataset object, so one assignment is enough.
dataset.vocab = vocab

# Create dataloaders
batch_size = 32
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
vocab_size = len(vocab)

num_epochs = 10

# Instantiate the model
model = AudioCaptioningModel(
    n_mels=128,
    vocab_size=len(vocab),
    d_model=512,
    nhead=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048
).to(device)

# Padding positions do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=vocab.stoi["<PAD>"])
optimizer = optim.AdamW(model.parameters(), lr=0.0002)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)
# Alternative tried earlier: torch.optim.Adam with StepLR(optimizer, step_size=1, gamma=0.95), or no scheduler.

(6312, 11)




In [None]:
# Persist the fitted vocabulary object so inference code can reload it later.
torch.save(obj=vocab, f='vocab4.pt')

In [None]:
# Run training; the per-epoch losses are kept for plotting further down.
train_losses, val_losses = train_model(
    model,
    train_dataloader,
    val_dataloader,
    criterion,
    optimizer,
    device=device,
    num_epochs=num_epochs,
    teacher_forcing_ratio=0.1,
    scheduler=scheduler,
)
# Variant without a learning-rate scheduler:
# train_model(model, train_dataloader, val_dataloader, criterion, optimizer, scheduler=None, num_epochs=num_epochs, device=device)

Training Epoch 1:   1%|          | 1/158 [00:02<05:24,  2.07s/it, loss=7.75]

In [None]:
import matplotlib.pyplot as plt

def visualize_loss(eval_losses, train_losses, number_of_epochs):
    """Plot validation and training loss per epoch on a single axes.

    Args:
        eval_losses: Validation loss per epoch.
        train_losses: Training loss per epoch.
        number_of_epochs: Maximum number of epochs to show. Each curve is drawn
            for the epochs it actually has, so a run cut short by early stopping
            (fewer losses than ``number_of_epochs``) no longer raises a
            length-mismatch error.
    """
    fig, ax = plt.subplots(figsize=(10, 5))

    for losses, color, label in ((eval_losses, 'b', 'val'), (train_losses, 'r', 'train')):
        shown = list(losses)[:number_of_epochs]
        ax.plot(range(1, len(shown) + 1), shown, marker='o', linestyle='-', color=color, label=label)

    # Both curves are shown, so the title names both.
    ax.set_title('Training and Validation Loss Over Epochs')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.grid(True)
    ax.legend(loc='upper left', fontsize='large', frameon=True)

    plt.show()

# Use the number of epochs that actually ran: early stopping can end training before
# `num_epochs`, and a hard-coded 10 would then not match the recorded losses.
visualize_loss(val_losses, train_losses, number_of_epochs=len(val_losses))

In [None]:
# Validation loss of the current weights; `epoch` only affects the progress-bar label.
validate_epoch(model, val_dataloader, criterion, epoch=1, device=device)

In [None]:
import dill
# Second copy of the vocabulary, serialized with dill.
with open("./new-dill.pkl", 'wb') as vocab_file:
    dill.dump(vocab, vocab_file)

In [None]:
# Qualitative check: caption a sample clip.
infer(
    audio_path='./song_eden.wav',
    model=model,
    vocab=vocab,
    device=device,
)

In [None]:
# Qualitative check: caption a second sample clip.
infer(
    audio_path='./mozart.wav',
    model=model,
    vocab=vocab,
    device=device,
)

In [None]:
# Qualitative check: caption a third sample clip.
infer(
    audio_path='./kendrick.wav',
    model=model,
    vocab=vocab,
    device=device,
)

In [None]:
# Freeze the encoder: none of its parameters receive gradients from here on.
# The trailing semicolon keeps the cell from echoing the returned module.
model.encoder.requires_grad_(False);

In [None]:
# Caption a clip taken from the dataset's audio directory.
infer(
    audio_path='./music_data/music_data/4yJZ4VX8XQI.wav',
    model=model,
    vocab=vocab,
    device=device,
)