In [5]:
# Step 0: Install necessary libraries
!pip -q install torchaudio transformers PySoundFile tqdm



# The modified VATT architecture for audio-text alignment

In [6]:
# Step 1: Install and Import Libraries
import torch
import torch.nn as nn
from transformers import AutoTokenizer
import numpy as np

In [7]:
# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
# Step 2: Audio Tokenization
class AudioTokenizer(nn.Module):
    def __init__(self, patch_size=128, embed_dim=2048):
        super(AudioTokenizer, self).__init__()
        self.patch_size = patch_size
        self.projection = nn.Linear(patch_size, embed_dim)
    
    def forward(self, audio_signal):
        # Assuming audio_signal is [batch, time_samples]
        batch_size, time_len = audio_signal.shape
        num_patches = time_len // self.patch_size
        audio_signal = audio_signal[:, :num_patches * self.patch_size]
        audio_patches = audio_signal.reshape(batch_size, num_patches, self.patch_size)
        # Project each patch to embedding dimension
        audio_embeddings = self.projection(audio_patches)  # [batch, num_patches, embed_dim]
        return audio_embeddings

In [9]:
class TextTokenizer(nn.Module):
    def __init__(self, embed_dim=768, max_length=512):
        super(TextTokenizer, self).__init__()
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.embedding = nn.Embedding(self.tokenizer.vocab_size, embed_dim)
        self.max_length = max_length
    
    def forward(self, text):
        # text is a list of strings
        tokens = self.tokenizer(
            text, 
            truncation=True, 
            padding="max_length", 
            max_length=self.max_length, 
            return_tensors="pt"
        )
        input_ids = tokens["input_ids"].to(device)  # [batch, max_length]
        text_embeddings = self.embedding(input_ids)  # [batch, max_length, embed_dim]
        return text_embeddings  # Ensures 3D output


In [10]:
class PositionalEncoding(nn.Module):
    def __init__(self, embed_dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.embed_dim = embed_dim
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * -(np.log(10000.0) / embed_dim))
        pe = torch.zeros(max_len, embed_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))  # Shape [1, max_len, embed_dim]

    def forward(self, x):
        # Adjust the positional encoding to match the actual input shape
        batch_size, seq_len, embed_dim = x.size()
        pe = self.pe[:, :seq_len, :].expand(batch_size, seq_len, embed_dim)
        
        return x + pe.to(x.device)

In [11]:
# Step 5: Transformer Encoder Components
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads)
    
    def forward(self, x):
        attn_output, _ = self.attention(x, x, x)
        return attn_output

class FeedForward(nn.Module):
    def __init__(self, embed_dim, expansion_factor=4):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(embed_dim, embed_dim * expansion_factor)
        self.fc2 = nn.Linear(embed_dim * expansion_factor, embed_dim)
        self.activation = nn.GELU()
    
    def forward(self, x):
        return self.fc2(self.activation(self.fc1(x)))

class TransformerEncoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.linear1 = nn.Linear(embed_dim, embed_dim * 4)
        self.dropout = nn.Dropout(0.1)
        self.linear2 = nn.Linear(embed_dim * 4, embed_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)

    def forward(self, src):
        src2 = self.self_attn(src, src, src)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(torch.relu(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

In [12]:
class TransformerFactory(nn.Module):
    """
    Factory class to create TransformerEncoder instances with customizable parameters.
    """
    def __init__(self, embed_dim, num_heads, num_layers, feedforward_dim_multiplier=4, dropout=0.1):
        """
        Args:
            embed_dim (int): Dimension of the input embeddings.
            num_heads (int): Number of attention heads.
            num_layers (int): Number of Transformer layers.
            feedforward_dim_multiplier (int): Multiplier for feedforward layer dimensions.
            dropout (float): Dropout probability.
        """
        super(TransformerFactory, self).__init__()
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim, 
                nhead=num_heads, 
                dim_feedforward=embed_dim * feedforward_dim_multiplier, 
                dropout=dropout,
                batch_first=True  # Enable batch-first optimization
            ),
            num_layers=num_layers
        )

    def forward(self, x):
        return self.encoder(x)

In [13]:
# Step 6: Projection Head
class ProjectionHead(nn.Module):
    def __init__(self, embed_dim, proj_dim):
        super(ProjectionHead, self).__init__()
        self.hidden_dim = embed_dim * 2
        self.projection = nn.Sequential(
            nn.Linear(embed_dim, self.hidden_dim),
            nn.GELU(),
            nn.Linear(self.hidden_dim, proj_dim),
        )
    
    def forward(self, x):
        # Extract [CLS] token (or average pool)
        x = x[:, 0]  # Assuming first token is [CLS]
        return self.projection(x)

In [14]:
# Step 7: Contrastive Learning (NCE and MIL-NCE)
class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.07):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature
    
    def forward(self, features_a, features_b):
        # Normalize features and compute similarity matrix
        features_a = nn.functional.normalize(features_a, dim=1)
        features_b = nn.functional.normalize(features_b, dim=1)
        logits = torch.matmul(features_a, features_b.T) / self.temperature
        labels = torch.arange(len(features_a)).to(device)
        return nn.CrossEntropyLoss()(logits, labels)

In [15]:
# Step 8: DropToken Implementation
class DropToken(nn.Module):
    def __init__(self, drop_rate=0.5):
        super(DropToken, self).__init__()
        self.drop_rate = drop_rate
    
    def forward(self, x):
        if not self.training:
            return x
        keep_prob = 1 - self.drop_rate
        mask = torch.rand(x.shape[:2], device=x.device) < keep_prob
        x = x * mask.unsqueeze(-1)
        return x

# Data loading and training

In [16]:
# Step 1: Assuming we have already downloaded TED-LIUM dataset and extracted audio + text files
# From: https://www.openslr.org/51/ [50.6 GB]
PATH_TO_AUDIO_FILES = "./audio_files"
PATH_TO_TRANSCRIPT_FILES = "./transcript_files"

In [17]:
# Step 2: Imports
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
import os
from transformers import AutoTokenizer
from transformers import Wav2Vec2Model, Wav2Vec2Processor, AutoTokenizer, BertModel

In [18]:
import torch.nn.functional as F

def collate_fn(batch, max_audio_length=160000):
    audio_tensors = []
    text_list = []

    for audio, text in batch:
        # Adjust audio length to max_audio_length
        if audio.size(0) > max_audio_length:
            audio = audio[:max_audio_length]
        else:
            audio = F.pad(audio, (0, max_audio_length - audio.size(0)))
        
        audio_tensors.append(audio)
        text_list.append(text)

    # Stack the audio tensors into a batch
    audio_batch = torch.stack(audio_tensors)

    # text_list is a list of strings
    return audio_batch, text_list

In [19]:
# Step 3: TED-LIUM Dataset Setup
class TEDLIUMDataset(Dataset):
    def __init__(self, audio_dir, transcript_dir, tokenizer, max_text_length=512, sample_rate=16000):
        """
        Initialize dataset with paths to audio and transcript directories and tokenizer.
        
        Args:
            audio_dir (str): Path to directory containing audio files.
            transcript_dir (str): Path to directory containing transcript files.
            tokenizer (transformers.AutoTokenizer): Tokenizer for text data.
            max_text_length (int): Maximum length for text tokenization.
            sample_rate (int): Desired sample rate for audio.
        """
        self.audio_files = sorted([
            os.path.join(audio_dir, f) for f in os.listdir(audio_dir) if f.endswith(".wav")
        ])
        self.transcript_files = sorted([
            os.path.join(transcript_dir, f) for f in os.listdir(transcript_dir) if f.endswith(".stm")
        ])
        self.tokenizer = tokenizer
        self.max_text_length = max_text_length
        self.sample_rate = sample_rate

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        # Load audio file
        audio_path = self.audio_files[idx]
        waveform, sr = torchaudio.load(audio_path)
        
        # Resample to desired sample rate if necessary
        if sr != self.sample_rate:
            waveform = torchaudio.transforms.Resample(sr, self.sample_rate)(waveform)

        # Load transcript with specified encoding
        try:
            with open(self.transcript_files[idx], 'r', encoding='utf-8') as f:
                transcript = f.read().strip()
        except UnicodeDecodeError as e:
            print(f"UnicodeDecodeError for file: {self.transcript_files[idx]}")
            print(f"Error details: {e}")
            # Optionally, handle the error by skipping the file or using a fallback
            transcript = ""  # Assign an empty string or any default value
        
        # Return waveform and raw text
        return waveform.squeeze(0), transcript

# Initialize tokenizer and dataset
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
dataset = TEDLIUMDataset(audio_dir=PATH_TO_AUDIO_FILES, transcript_dir=PATH_TO_TRANSCRIPT_FILES, tokenizer=tokenizer)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, drop_last=True, collate_fn=lambda x: collate_fn(x))

# Defining training components

In [20]:
# Step 4: Define Training Components

# Initialize model components
from transformers import Wav2Vec2Model, BertModel

"""
audio_tokenizer = AudioTokenizer(patch_size=128, embed_dim=2048).to(device)
text_tokenizer = TextTokenizer(embed_dim=768, max_length=512).to(device)

positional_encoding_audio = PositionalEncoding(embed_dim=2048).to(device)
positional_encoding_text = PositionalEncoding(embed_dim=768).to(device)

transformer_audio = TransformerFactory(embed_dim=2048, num_heads=8, num_layers=2, feedforward_dim_multiplier=2).to(device)
transformer_text = TransformerFactory(embed_dim=768, num_heads=8, num_layers=12, feedforward_dim_multiplier=6).to(device)
"""

audio_encoder = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-large-960h').to(device)
audio_encoder.gradient_checkpointing_enable()

text_encoder = BertModel.from_pretrained('bert-base-uncased').to(device)
text_encoder.gradient_checkpointing_enable()

projection_head_audio = ProjectionHead(embed_dim=1024, proj_dim=256).to(device)
projection_head_text = ProjectionHead(embed_dim=768, proj_dim=256).to(device)

# Example for BERT model

# Example for Wav2Vec2 model

contrastive_loss = ContrastiveLoss().to(device)
droptoken = DropToken(drop_rate=0.5).to(device)

# Define a lower learning rate for pretrained encoders
pretrained_lr = 1e-5  # Lower learning rate for BERT and Wav2Vec2
training_lr = 1e-4  # Default learning rate for the rest of the model

# Optimizer
optimizer = torch.optim.AdamW([
    # {'params': audio_tokenizer.parameters()},
    # {'params': text_tokenizer.parameters()},
    # {'params': positional_encoding_audio.parameters()},
    # {'params': positional_encoding_text.parameters()},
    # {'params': transformer_audio.parameters()},
    # {'params': transformer_text.parameters()},
    {'params': audio_encoder.parameters(), 'lr': pretrained_lr},
    {'params': text_encoder.parameters(), 'lr': pretrained_lr},
    {'params': projection_head_audio.parameters(), 'lr': training_lr},
    {'params': projection_head_text.parameters(), 'lr': training_lr},
], lr=1e-4)

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
from tqdm import tqdm
from torch.amp import GradScaler

# Step 5: Training Loop with Progress Bars for Steps within Each Epoch
def train(model_components, dataloader, optimizer, num_epochs=10):
    (
        audio_encoder,
        text_encoder,
        projection_head_audio,
        projection_head_text,
        contrastive_loss,
        droptoken
    ) = model_components
    
    for epoch in range(num_epochs):
        total_loss = 0
        # Initialize progress bar for batches within the current epoch
        batch_progress = tqdm(
            dataloader,
            desc=f"Epoch {epoch+1}/{num_epochs}",
            unit="batch",
            leave=True  # Keeps the progress bar after the epoch completes
        )
        
        accumulation_steps = 4
        scaler = GradScaler()

        for i, (audio, text) in enumerate(batch_progress):
            # Move audio to device; text remains on CPU
            audio = audio.to(device)

            # Audio processing with Wav2Vec2
            audio_output = audio_encoder(audio)  # [batch, seq_len, 768]
            audio_embeddings = audio_output.last_hidden_state  # Extract the hidden states
            audio_embeddings = droptoken(audio_embeddings)     # Apply DropToken
            audio_proj = projection_head_audio(audio_embeddings)  # Project to common space

            # Text processing with BERT
            tokenized_text = tokenizer(
                text,
                truncation=True,
                padding="max_length",
                max_length=512,
                return_tensors="pt"
            ).to(device)
            text_output = text_encoder(**tokenized_text)  # [batch, seq_len, 768]
            text_embeddings = text_output.last_hidden_state  # Extract the hidden states
            text_embeddings = droptoken(text_embeddings)     # Apply DropToken
            text_proj = projection_head_text(text_embeddings)  # Project to common space

            # Contrastive Loss
            loss = contrastive_loss(audio_proj, text_proj)
            total_loss += loss.item()

            scaler.scale(loss).backward()

            # Backpropagation
            if (i + 1) % accumulation_steps == 0:
                scaler.step(optimizer)  # Apply the gradients
                scaler.update()          # Update the scaler
                optimizer.zero_grad()    # Reset gradients
            # No need to call scaler.update() here since it's done after step()

            # Update the progress bar's postfix with the current loss every 10 steps
            if (i + 1) % 10 == 0:
                batch_progress.set_postfix(loss=loss.item())
        
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")

# Define model components as a tuple
model_components = (
    audio_encoder,
    text_encoder,
    projection_head_audio,
    projection_head_text,
    contrastive_loss,
    droptoken
)

# Start training
num_epochs = 10
train(model_components, dataloader, optimizer, num_epochs=num_epochs)

Epoch 1/10:  12%|█▏        | 68/587 [1:35:35<8:26:49, 58.59s/batch, loss=1.39]  

In [None]:
import time

def save_model(model_components, optimizer, epoch):
    """
    Save the state dictionaries of all model components and the optimizer.

    Args:
        model_components (tuple): Tuple containing all model components.
        optimizer (torch.optim.Optimizer): The optimizer used during training.
        epoch (int): The current epoch number.
    """
    # Generate a unique filename for the saved model
    path = f"trained-{epoch}-{time.time()}.pth"

    # Unpack the components
    (
        audio_tokenizer,
        text_tokenizer,
        positional_encoding_audio,
        positional_encoding_text,
        transformer_audio,
        transformer_text,
        projection_head_audio,
        projection_head_text,
        contrastive_loss,
        droptoken
    ) = model_components

    # Save state dictionaries
    torch.save({
        'epoch': epoch,
        'audio_tokenizer_state_dict': audio_tokenizer.state_dict(),
        'text_tokenizer_state_dict': text_tokenizer.state_dict(),
        'positional_encoding_audio_state_dict': positional_encoding_audio.state_dict(),
        'positional_encoding_text_state_dict': positional_encoding_text.state_dict(),
        'transformer_audio_state_dict': transformer_audio.state_dict(),
        'transformer_text_state_dict': transformer_text.state_dict(),
        'projection_head_audio_state_dict': projection_head_audio.state_dict(),
        'projection_head_text_state_dict': projection_head_text.state_dict(),
        'contrastive_loss_state_dict': contrastive_loss.state_dict(),
        'droptoken_state_dict': droptoken.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)

    print(f"Model saved to {path}")

save_model(model_components, optimizer, num_epochs)


Model saved to trained-10-1732919613.250842.pth


In [None]:
def count_parameters(model_components):
    """
    Counts and prints the total number of trainable parameters for each model component.

    Args:
        model_components (tuple): A tuple of PyTorch model components.

    Returns:
        int: The total number of trainable parameters in the model.
    """
    total_params = 0
    print(f"{'Component':<30}{'Parameters':>15}")
    print("-" * 50)

    # Iterate through each component in the tuple
    component_names = [
        "Audio Tokenizer",
        "Text Tokenizer",
        "Positional Encoding (Audio)",
        "Positional Encoding (Text)",
        "Transformer (Audio)",
        "Transformer (Text)",
        "Projection Head (Audio)",
        "Projection Head (Text)",
        "Contrastive Loss",
        "DropToken"
    ]

    for name, component in zip(component_names, model_components):
        if hasattr(component, "parameters"):
            component_params = sum(p.numel() for p in component.parameters() if p.requires_grad)
            total_params += component_params
            print(f"{name:<30}{component_params:>15,}")

    print("-" * 50)
    print(f"{'Total Trainable Parameters':<30}{total_params:>15,}")
    return total_params

total_params = count_parameters(model_components)

Component                          Parameters
--------------------------------------------------
Audio Tokenizer                       264,192
Text Tokenizer                     23,440,896
Positional Encoding (Audio)                 0
Positional Encoding (Text)                  0
Transformer (Audio)                67,153,920
Transformer (Text)                113,384,448
Projection Head (Audio)             9,441,536
Projection Head (Text)              1,574,656
Contrastive Loss                            0
DropToken                                   0
--------------------------------------------------
Total Trainable Parameters        215,259,648


**CLAP parameter sizes**

Total parameters: 196,304,144  
Trainable parameters: 195,220,688  
Text encoder weights have been frozen.  
Audio Encoder - Total parameters: 84,984,847  
Audio Encoder - Trainable parameters: 83,901,391  
Text Encoder - Total parameters: 111,319,296  
Text Encoder - Trainable parameters: 1,837,056  