In [4]:
# Step 0: Install necessary libraries
!pip -q install torchaudio transformers PySoundFile tqdm

# The modified VATT architecture for audio-text alignment

In [26]:
# Step 1: Install and Import Libraries
import torch
import torch.nn as nn
from transformers import AutoTokenizer
import numpy as np

In [27]:
# Pick the compute device once for the whole notebook:
# the default CUDA GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [49]:
def compute_transformer_params(embed_dim, num_heads, num_layers=1, ffn_expansion=4, linear_bias=True):
    """
    Compute the number of parameters in a stack of ``nn.TransformerEncoderLayer``s.

    With the defaults this matches ``sum(p.numel() for p in layer.parameters())`` for
    ``nn.TransformerEncoderLayer(embed_dim, num_heads, dim_feedforward=4 * embed_dim)``.

    Args:
        embed_dim (int): Embedding dimension (``d_model``).
        num_heads (int): Number of attention heads. It does not change the parameter
            count, but ``embed_dim`` must be divisible by it.
        num_layers (int): Number of transformer layers.
        ffn_expansion (int): Feed-forward hidden size as a multiple of ``embed_dim``
            (``dim_feedforward = ffn_expansion * embed_dim``).
        linear_bias (bool): Count the bias vectors of the attention and feed-forward
            linear layers (PyTorch's default). ``False`` reproduces the earlier bias-free
            estimate; LayerNorm weights and biases are always counted.

    Returns:
        int: Total number of parameters.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``embed_dim``.
    """
    if num_heads <= 0 or embed_dim % num_heads != 0:
        raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")

    ffn_dim = ffn_expansion * embed_dim

    # Multihead attention: Q, K, V input projections + output projection, each embed_dim x embed_dim
    attention_params = 4 * embed_dim * embed_dim
    if linear_bias:
        attention_params += 4 * embed_dim

    # Feedforward network: embed_dim -> ffn_dim -> embed_dim
    feedforward_params = 2 * embed_dim * ffn_dim
    if linear_bias:
        feedforward_params += ffn_dim + embed_dim

    # Two LayerNorms, each with a weight and a bias of size embed_dim
    layer_norm_params = 2 * 2 * embed_dim

    # Total for one layer, scaled by the number of layers
    return (attention_params + feedforward_params + layer_norm_params) * num_layers

# Hidden sizes of the CLAP audio and text encoders we are sizing against
clap_audio_dim, clap_text_dim = 2048, 768

# Parameter budget for VATT (equal to the total CLAP parameter count)
target_params = 196_304_144

# Calculate required layers for VATT
def calculate_layers(target_params, embed_dim):
    """
    Return how many 8-head encoder layers of width ``embed_dim`` fit in ``target_params``.

    The count is rounded down, so the resulting stack never exceeds the budget
    (it can be 0 if a single layer is already larger than the budget).
    """
    per_layer = compute_transformer_params(embed_dim, num_heads=8, num_layers=1)
    whole_layers, _leftover = divmod(target_params, per_layer)
    return whole_layers

# Number of layers for VATT: the parameter budget is split evenly between
# the audio tower and the text tower.
audio_layers, text_layers = (
    calculate_layers(target_params // 2, tower_dim)
    for tower_dim in (clap_audio_dim, clap_text_dim)
)

print("(audio_layers, text_layers):", (audio_layers, text_layers))

(audio_layers, text_layers): (1, 13)


In [50]:
# Step 2: Audio Tokenization
class AudioTokenizer(nn.Module):
    """
    Split raw waveforms into non-overlapping fixed-size patches and embed each patch linearly.

    Args:
        patch_size (int): Number of samples per patch.
        embed_dim (int): Output embedding size per patch.
    """

    def __init__(self, patch_size=128, embed_dim=2048):
        super().__init__()
        self.patch_size = patch_size
        self.projection = nn.Linear(patch_size, embed_dim)

    def forward(self, audio_signal):
        """
        Args:
            audio_signal (Tensor): Waveforms of shape ``[batch, time_samples]``.

        Returns:
            Tensor: ``[batch, time_samples // patch_size, embed_dim]``. Trailing samples
            that do not fill a whole patch are discarded.

        Raises:
            ValueError: If the input is not 2-D or is shorter than one patch (an empty
                patch sequence would only fail later, e.g. when pooling token 0).
        """
        if audio_signal.dim() != 2:
            raise ValueError(
                f"expected audio of shape [batch, time_samples], got {tuple(audio_signal.shape)}"
            )
        batch_size, time_len = audio_signal.shape
        num_patches = time_len // self.patch_size
        if num_patches == 0:
            raise ValueError(
                f"audio has {time_len} samples, fewer than one patch of {self.patch_size}"
            )

        usable_len = num_patches * self.patch_size
        audio_patches = audio_signal[:, :usable_len].reshape(batch_size, num_patches, self.patch_size)
        # Project each patch to the embedding dimension
        return self.projection(audio_patches)  # [batch, num_patches, embed_dim]

In [51]:
class TextTokenizer(nn.Module):
    """
    Tokenize raw strings with the ``bert-base-uncased`` tokenizer and embed the token ids
    with a newly initialised, trainable embedding table (not BERT's pretrained weights).

    Args:
        embed_dim (int): Embedding size per token.
        max_length (int): Every batch is truncated / padded to exactly this many tokens.
    """

    def __init__(self, embed_dim=768, max_length=512):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.embedding = nn.Embedding(self.tokenizer.vocab_size, embed_dim)
        self.max_length = max_length

    def forward(self, text):
        """
        Args:
            text (list[str]): Raw input strings.

        Returns:
            Tensor: ``[batch, max_length, embed_dim]`` embeddings. Padding positions are
            embedded like any other token; no attention mask is returned.
        """
        tokens = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt"
        )
        # Send the ids to wherever the embedding table lives instead of relying on a
        # notebook-global `device` that may not match this module after `.to(...)`.
        input_ids = tokens["input_ids"].to(self.embedding.weight.device)  # [batch, max_length]
        return self.embedding(input_ids)  # [batch, max_length, embed_dim]


In [52]:
class PositionalEncoding(nn.Module):
    """
    Add fixed sinusoidal positional encodings to a ``[batch, seq, embed_dim]`` tensor.

    Even feature indices carry ``sin(pos * w_k)`` and odd indices ``cos(pos * w_k)`` with
    ``w_k = 10000^(-2k / embed_dim)``. The table has no trainable parameters; it is stored
    as the ``pe`` buffer of shape ``[1, max_len, embed_dim]``.

    Args:
        embed_dim (int): Feature size of the inputs (odd values are supported).
        max_len (int): Longest sequence the table covers.
    """

    def __init__(self, embed_dim, max_len=5000):
        super().__init__()
        self.embed_dim = embed_dim
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * -(np.log(10000.0) / embed_dim))
        pe = torch.zeros(max_len, embed_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd embed_dim there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: embed_dim // 2])
        self.register_buffer('pe', pe.unsqueeze(0))  # Shape [1, max_len, embed_dim]

    def forward(self, x):
        """
        Args:
            x (Tensor): ``[batch, seq_len, embed_dim]`` input.

        Returns:
            Tensor: ``x`` plus the encodings for positions ``0 .. seq_len - 1``.

        Raises:
            ValueError: If ``x`` is not 3-D or ``seq_len`` exceeds ``max_len``.
        """
        if x.dim() != 3:
            raise ValueError(f"expected [batch, seq, embed_dim] input, got {tuple(x.shape)}")
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(1)} of the positional table"
            )
        # The [1, seq_len, embed_dim] slice broadcasts over the batch dimension.
        return x + self.pe[:, :seq_len].to(x.device)

In [53]:
# Step 5: Transformer Encoder Components
class MultiHeadAttention(nn.Module):
    """
    Self-attention wrapper around ``nn.MultiheadAttention``.

    Args:
        embed_dim (int): Model width.
        num_heads (int): Number of attention heads.
        batch_first (bool): Inputs are ``[batch, seq, embed_dim]`` (the layout used throughout
            this notebook). ``nn.MultiheadAttention`` defaults to ``[seq, batch, embed_dim]``,
            which would make samples in a batch attend to each other.
    """

    def __init__(self, embed_dim, num_heads, batch_first=True):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=batch_first)

    def forward(self, x):
        # Self-attention: query, key and value are all x; the attention weights are discarded.
        attn_output, _ = self.attention(x, x, x)
        return attn_output

class FeedForward(nn.Module):
    """
    Position-wise MLP: Linear(embed_dim -> embed_dim * expansion_factor) -> GELU -> Linear back.

    Args:
        embed_dim (int): Input and output feature size.
        expansion_factor (int): Width multiplier for the hidden layer.
    """

    def __init__(self, embed_dim, expansion_factor=4):
        super().__init__()
        hidden_dim = embed_dim * expansion_factor
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.activation = nn.GELU()

    def forward(self, x):
        hidden = self.activation(self.fc1(x))
        return self.fc2(hidden)

class TransformerEncoderLayer(nn.Module):
    """
    Post-norm transformer encoder layer: self-attention and a ReLU MLP (4x width), each
    followed by dropout, a residual add and LayerNorm.

    Args:
        embed_dim (int): Model width.
        num_heads (int): Number of attention heads.
        dropout (float): Dropout probability used on both residual branches and inside the MLP.
        batch_first (bool): Inputs are ``[batch, seq, embed_dim]`` (this notebook's layout).
            ``nn.MultiheadAttention`` defaults to ``[seq, batch, embed_dim]``, which would
            make samples in a batch attend to each other.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1, batch_first=True):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=batch_first)
        self.linear1 = nn.Linear(embed_dim, embed_dim * 4)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(embed_dim * 4, embed_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src):
        # Self-attention sub-layer with residual connection
        src2 = self.self_attn(src, src, src)[0]
        src = self.norm1(src + self.dropout1(src2))
        # Feed-forward sub-layer with residual connection
        src2 = self.linear2(self.dropout(torch.relu(self.linear1(src))))
        src = self.norm2(src + self.dropout2(src2))
        return src

In [54]:
# Step 6: Projection Head
class ProjectionHead(nn.Module):
    """
    Pool a ``[batch, seq, embed_dim]`` sequence into one vector and project it linearly.

    Args:
        embed_dim (int): Size of each input token embedding.
        proj_dim (int): Size of the output embedding.
        pooling (str): ``"cls"`` (default) takes the first token. This is only meaningful when
            the sequence starts with a summary token such as BERT's [CLS]. ``"mean"`` averages
            over all tokens, which suits sequences without such a token (e.g. audio patches).

    Raises:
        ValueError: If ``pooling`` is not ``"cls"`` or ``"mean"``.
    """

    _POOLING_MODES = ("cls", "mean")

    def __init__(self, embed_dim, proj_dim, pooling="cls"):
        super().__init__()
        if pooling not in self._POOLING_MODES:
            raise ValueError(f"pooling must be one of {self._POOLING_MODES}, got {pooling!r}")
        self.pooling = pooling
        self.projection = nn.Linear(embed_dim, proj_dim)

    def forward(self, x):
        """Map ``[batch, seq, embed_dim]`` to ``[batch, proj_dim]``."""
        pooled = x[:, 0] if self.pooling == "cls" else x.mean(dim=1)
        return self.projection(pooled)

In [55]:
# MADE CHANGES HERE - Abhigyan
class UpdatedProjectionHead(nn.Module):
    """
    Pool a ``[batch, seq, embed_dim]`` sequence into one vector, then project it with a
    two-layer MLP (Linear -> GELU -> Linear) to ``proj_dim``.

    Args:
        embed_dim (int): Size of each input token embedding.
        proj_dim (int): Size of the output embedding.
        pooling (str): ``"cls"`` (default) takes the first token. This is only meaningful when
            the sequence starts with a summary token such as BERT's [CLS]. ``"mean"`` averages
            over all tokens, which suits sequences without such a token (e.g. audio patches).
        hidden_dim (int, optional): Width of the MLP's hidden layer; defaults to ``2 * embed_dim``.

    Raises:
        ValueError: If ``pooling`` is not ``"cls"`` or ``"mean"``.
    """

    _POOLING_MODES = ("cls", "mean")

    def __init__(self, embed_dim, proj_dim, pooling="cls", hidden_dim=None):
        super().__init__()
        if pooling not in self._POOLING_MODES:
            raise ValueError(f"pooling must be one of {self._POOLING_MODES}, got {pooling!r}")
        self.pooling = pooling
        self.hidden_dim = embed_dim * 2 if hidden_dim is None else hidden_dim  # Internal hidden size
        self.projection = nn.Sequential(
            nn.Linear(embed_dim, self.hidden_dim),  # Widen internally
            nn.GELU(),
            nn.Linear(self.hidden_dim, proj_dim),  # Project to the final dimension
        )

    def forward(self, x):
        """Map ``[batch, seq, embed_dim]`` to ``[batch, proj_dim]``."""
        pooled = x[:, 0] if self.pooling == "cls" else x.mean(dim=1)
        return self.projection(pooled)

In [56]:
# Step 7: Contrastive Learning (NCE loss; the MIL-NCE variant is not implemented here)
class ContrastiveLoss(nn.Module):
    """
    InfoNCE contrastive loss over a batch of paired embeddings.

    Row ``i`` of ``features_a`` and row ``i`` of ``features_b`` form a positive pair; all
    other rows in the batch act as negatives. Features are L2-normalised along dim 1, so
    the logits are cosine similarities divided by ``temperature``.

    Args:
        temperature (float): Softmax temperature.
        symmetric (bool): If True, average the a->b and b->a cross-entropies (CLIP-style).
            The default False keeps the one-directional a->b loss.
    """

    def __init__(self, temperature=0.07, symmetric=False):
        super().__init__()
        self.temperature = temperature
        self.symmetric = symmetric

    def forward(self, features_a, features_b):
        """
        Args:
            features_a (Tensor): ``[batch, dim]`` embeddings.
            features_b (Tensor): ``[batch, dim]`` embeddings paired row-wise with ``features_a``.

        Returns:
            Tensor: Scalar loss.

        Raises:
            ValueError: If the two batches have different sizes.
        """
        if features_a.size(0) != features_b.size(0):
            raise ValueError(
                f"batch size mismatch: {features_a.size(0)} vs {features_b.size(0)}"
            )
        # Normalize features and compute the similarity matrix
        features_a = nn.functional.normalize(features_a, dim=1)
        features_b = nn.functional.normalize(features_b, dim=1)
        logits = features_a @ features_b.T / self.temperature
        # Targets live on the same device as the logits (no dependency on a global `device`).
        labels = torch.arange(logits.size(0), device=logits.device)
        loss = nn.functional.cross_entropy(logits, labels)
        if self.symmetric:
            loss = 0.5 * (loss + nn.functional.cross_entropy(logits.T, labels))
        return loss

In [57]:
# Step 8: DropToken Implementation
class DropToken(nn.Module):
    """
    Randomly zero out whole tokens of a ``[batch, seq, dim]`` tensor during training.

    Each (batch, position) token is kept with probability ``1 - drop_rate`` and multiplied
    by zero otherwise. The sequence length is unchanged and no rescaling is applied.
    In eval mode the input is returned untouched.

    Args:
        drop_rate (float): Probability of zeroing each token.
    """

    def __init__(self, drop_rate=0.5):
        super().__init__()
        self.drop_rate = drop_rate

    def forward(self, x):
        if self.training:
            survive_prob = 1 - self.drop_rate
            keep_mask = torch.rand(x.shape[:2], device=x.device) < survive_prob
            return x * keep_mask.unsqueeze(-1)
        return x

# Data loading and training

In [58]:
# Step 1: TED-LIUM is assumed to be downloaded already, with its audio and
# transcript files extracted into the two directories below.
# Source: https://www.openslr.org/51/ [50.6 GB]
PATH_TO_AUDIO_FILES, PATH_TO_TRANSCRIPT_FILES = "./audio_files", "./transcript_files"

In [59]:
# Step 2: Imports
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
import os
from transformers import AutoTokenizer

In [60]:
import torch.nn.functional as F

def collate_fn(batch, max_audio_length=160000):
    """
    Collate ``(waveform, transcript)`` items into a fixed-length audio batch and a list of strings.

    Each waveform is cropped, or right-padded with zeros, along its last (time) axis to
    exactly ``max_audio_length`` samples so the clips can be stacked. The default of 160000
    is 10 s at the dataset's default 16 kHz sample rate.

    Args:
        batch (list[tuple[Tensor, str]]): Items as returned by the dataset.
        max_audio_length (int): Target number of samples per clip.

    Returns:
        tuple[Tensor, list[str]]: Audio of shape ``[batch, ..., max_audio_length]``
        (``[batch, max_audio_length]`` for 1-D waveforms) and the raw transcripts.
    """
    audio_tensors = []
    text_list = []

    for audio, text in batch:
        # Use the last axis as time, so [time] and [channels, time] inputs are both handled.
        num_samples = audio.size(-1)
        if num_samples >= max_audio_length:
            audio = audio[..., :max_audio_length]
        else:
            audio = F.pad(audio, (0, max_audio_length - num_samples))

        audio_tensors.append(audio)
        text_list.append(text)

    # Stack the audio tensors into a batch; text_list stays a list of strings
    return torch.stack(audio_tensors), text_list

In [19]:
# Step 3: TED-LIUM Dataset Setup
class TEDLIUMDataset(Dataset):
    """
    ``(waveform, transcript)`` pairs from a directory of ``.wav`` files and a directory of
    TED-LIUM ``.stm`` transcripts.

    Audio and transcript files are matched by file stem (``talk.wav`` <-> ``talk.stm``) rather
    than by position in two independently sorted listings. As a result, a missing or extra file
    cannot silently shift every later pair. Audio files without a matching transcript are
    skipped and reported.

    ``tokenizer`` and ``max_text_length`` are stored but not used here; text is returned raw.

    NOTE(review): each item is a whole talk, while ``collate_fn`` keeps only the first
    ``max_audio_length`` samples. Audio and text are therefore only loosely aligned;
    segment-level pairing using the STM timestamps would be more faithful.
    """

    def __init__(self, audio_dir, transcript_dir, tokenizer, max_text_length=512, sample_rate=16000):
        """
        Initialize dataset with paths to audio and transcript directories and tokenizer.

        Args:
            audio_dir (str): Path to directory containing audio files.
            transcript_dir (str): Path to directory containing transcript files.
            tokenizer (transformers.AutoTokenizer): Tokenizer for text data.
            max_text_length (int): Maximum length for text tokenization.
            sample_rate (int): Desired sample rate for audio.
        """
        transcripts_by_stem = {
            os.path.splitext(name)[0]: os.path.join(transcript_dir, name)
            for name in os.listdir(transcript_dir)
            if name.endswith(".stm")
        }

        # Parallel lists: audio_files[i] and transcript_files[i] belong to the same talk.
        self.audio_files = []
        self.transcript_files = []
        unmatched = []
        for name in sorted(os.listdir(audio_dir)):
            if not name.endswith(".wav"):
                continue
            stem = os.path.splitext(name)[0]
            if stem in transcripts_by_stem:
                self.audio_files.append(os.path.join(audio_dir, name))
                self.transcript_files.append(transcripts_by_stem[stem])
            else:
                unmatched.append(name)
        if unmatched:
            print(f"Skipping {len(unmatched)} audio file(s) without a matching .stm transcript")

        self.tokenizer = tokenizer
        self.max_text_length = max_text_length
        self.sample_rate = sample_rate

    def __len__(self):
        return len(self.audio_files)

    @staticmethod
    def _read_transcript(path):
        """
        Return the spoken text of an STM file, with the segments joined by single spaces.

        STM lines look like ``<file> <channel> <speaker> <start> <end> [<label>] <text>``; only
        ``<text>`` is kept. ``;;`` comment lines, lines without text, and
        ``ignore_time_segment_in_scoring`` segments are dropped.
        """
        segments = []
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith(";;"):
                    continue
                fields = line.split(maxsplit=5)
                if len(fields) < 6:
                    continue  # segment has no transcript text
                text = fields[5]
                if text.startswith("<"):
                    # Strip the optional "<...>" label (e.g. condition / gender tags).
                    label_end = text.find(">")
                    if label_end != -1:
                        text = text[label_end + 1:].strip()
                if text and text != "ignore_time_segment_in_scoring":
                    segments.append(text)
        return " ".join(segments)

    def __getitem__(self, idx):
        """Return ``(waveform, transcript)``: a 1-D mono waveform at ``sample_rate`` and raw text."""
        audio_path = self.audio_files[idx]
        waveform, sr = torchaudio.load(audio_path)  # [channels, time]

        # Resample to desired sample rate if necessary
        if sr != self.sample_rate:
            waveform = torchaudio.transforms.Resample(sr, self.sample_rate)(waveform)

        # Downmix multi-channel audio so every item is a 1-D [time] tensor.
        waveform = waveform.mean(dim=0) if waveform.size(0) > 1 else waveform.squeeze(0)

        try:
            transcript = self._read_transcript(self.transcript_files[idx])
        except UnicodeDecodeError as e:
            print(f"UnicodeDecodeError for file: {self.transcript_files[idx]}")
            print(f"Error details: {e}")
            # Fall back to an empty transcript rather than aborting the epoch
            transcript = ""

        return waveform, transcript

# Initialize tokenizer and dataset
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
dataset = TEDLIUMDataset(audio_dir=PATH_TO_AUDIO_FILES, transcript_dir=PATH_TO_TRANSCRIPT_FILES, tokenizer=tokenizer)
# collate_fn is passed directly; the previous `lambda x: collate_fn(x)` wrapper only forwarded its argument.
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, drop_last=True, collate_fn=collate_fn)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

FileNotFoundError: [Errno 2] No such file or directory: './audio_files'

In [79]:
# Step 4: Define Training Components
# Assuming VATT model, projection heads, contrastive loss, and DropToken components are defined as above

# Initialize model components

#CHANGES MADE HERE -- Abhigyan
# Depth and feed-forward width per modality. The FFN hidden size is d_model * expansion;
# the rest of this notebook (compute_transformer_params, FeedForward) assumes 4x.
audio_num_layers = 2     # audio transformer depth
audio_ffn_expansion = 2  # audio FFN hidden size = 2 * 2048 (narrower than 4x)

text_num_layers = 12     # text transformer depth
text_ffn_expansion = 6   # text FFN hidden size = 6 * 768 (wider than 4x)

audio_tokenizer = AudioTokenizer(patch_size=128, embed_dim=2048).to(device)
text_tokenizer = TextTokenizer(embed_dim=768, max_length=512).to(device)
positional_encoding_audio = PositionalEncoding(embed_dim=2048).to(device)
positional_encoding_text = PositionalEncoding(embed_dim=768).to(device)

# All embeddings in this notebook are laid out as [batch, seq, embed_dim], so the encoders
# need batch_first=True. PyTorch's default (False) treats axis 0 as the sequence, which makes
# samples in a batch attend to each other.
# NOTE(review): attending over the real audio sequence (max_audio_length // patch_size
# patches) needs far more memory than the accidental batch-axis attention did; lower the
# batch size or clip length if you run out of GPU memory.
transformer_audio = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(
        d_model=2048,
        nhead=8,
        dim_feedforward=2048 * audio_ffn_expansion,
        dropout=0.1,
        batch_first=True,
    ),
    num_layers=audio_num_layers,
).to(device)

transformer_text = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(
        d_model=768,
        nhead=8,
        dim_feedforward=768 * text_ffn_expansion,
        dropout=0.1,
        batch_first=True,
    ),
    num_layers=text_num_layers,
).to(device)

# MLP projection heads (ProjectionHead above is the simpler single-linear alternative)
projection_head_audio = UpdatedProjectionHead(embed_dim=2048, proj_dim=256).to(device)
projection_head_text = UpdatedProjectionHead(embed_dim=768, proj_dim=256).to(device)

contrastive_loss = ContrastiveLoss().to(device)
droptoken = DropToken(drop_rate=0.5).to(device)

# Optimizer: AdamW with one parameter group per component, all sharing lr=1e-4.
# The positional encodings have no trainable parameters, so their groups are empty;
# they are kept so the optimizer's param-group layout is unchanged.
optimizer = torch.optim.AdamW(
    [
        {'params': component.parameters()}
        for component in (
            audio_tokenizer,
            text_tokenizer,
            positional_encoding_audio,
            positional_encoding_text,
            transformer_audio,
            transformer_text,
            projection_head_audio,   # audio-specific projection head
            projection_head_text,    # text-specific projection head
        )
    ],
    lr=1e-4,
)


In [80]:
from tqdm import tqdm

# Step 5: Training Loop with Progress Bars for Steps within Each Epoch
def train(model_components, dataloader, optimizer, num_epochs=10):
    """
    Train the audio and text towers jointly with the contrastive objective.

    Per batch, each modality is embedded, given positional encodings, passed through
    DropToken and its transformer, and pooled/projected by its head. The contrastive loss
    between the two projections is then minimised.

    Args:
        model_components (tuple): ``(audio_tokenizer, text_tokenizer, positional_encoding_audio,
            positional_encoding_text, transformer_audio, transformer_text, projection_head_audio,
            projection_head_text, contrastive_loss, droptoken)``.
        dataloader: Yields ``(audio, text)`` batches: an audio tensor and a list of strings.
        optimizer (torch.optim.Optimizer): Optimizer over the trainable components.
        num_epochs (int): Number of passes over ``dataloader``.

    Returns:
        list[float]: Average loss per epoch (``nan`` for an epoch with no batches).
    """
    (
        audio_tokenizer,
        text_tokenizer,
        positional_encoding_audio,
        positional_encoding_text,
        transformer_audio,
        transformer_text,
        projection_head_audio,
        projection_head_text,
        contrastive_loss,
        droptoken
    ) = model_components

    # Dropout and DropToken are only active in training mode; make that explicit in case
    # a component was left in eval mode by an earlier cell.
    for component in model_components:
        component.train()

    epoch_losses = []
    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0
        # Progress bar for the batches of the current epoch (kept after the epoch completes)
        batch_progress = tqdm(
            dataloader,
            desc=f"Epoch {epoch+1}/{num_epochs}",
            unit="batch",
            leave=True
        )

        for i, (audio, text) in enumerate(batch_progress):
            # Audio goes to the device; text stays a list of strings on the CPU
            audio = audio.to(device)

            # Audio tower. DropToken is applied to the input tokens, before the transformer
            # (as in VATT). Applying it to the transformer output would zero the very token
            # that CLS pooling reads for about drop_rate of the samples.
            audio_embeddings = audio_tokenizer(audio)  # [batch, num_patches, embed_dim]
            audio_embeddings = positional_encoding_audio(audio_embeddings)
            audio_embeddings = droptoken(audio_embeddings)
            audio_embeddings = transformer_audio(audio_embeddings)

            # Text tower, same ordering
            text_embeddings = text_tokenizer(text)  # [batch, max_length, embed_dim]
            text_embeddings = positional_encoding_text(text_embeddings)
            text_embeddings = droptoken(text_embeddings)
            text_embeddings = transformer_text(text_embeddings)

            # Project to the common space
            audio_proj = projection_head_audio(audio_embeddings)  # [batch, proj_dim]
            text_proj = projection_head_text(text_embeddings)     # [batch, proj_dim]

            loss = contrastive_loss(audio_proj, text_proj)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            total_loss += loss_value
            num_batches += 1

            # Show the current loss on the progress bar every 10 steps
            if (i + 1) % 10 == 0:
                batch_progress.set_postfix(loss=loss_value)

        # Average over the batches actually seen (avoids dividing by zero for an empty loader)
        avg_loss = total_loss / num_batches if num_batches else float("nan")
        epoch_losses.append(avg_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")

    return epoch_losses

# Define model components as a tuple (the order must match the unpacking in `train`)
model_components = (
    audio_tokenizer,
    text_tokenizer,
    positional_encoding_audio,
    positional_encoding_text,
    transformer_audio,
    transformer_text,
    projection_head_audio,  # audio-specific projection head
    projection_head_text,   # text-specific projection head
    contrastive_loss,
    droptoken
)

# Start training. Training is disabled by default because each epoch iterates over the
# whole dataloader; set RUN_TRAINING = True to run it.
num_epochs = 10
RUN_TRAINING = False
if RUN_TRAINING:
    train(model_components, dataloader, optimizer, num_epochs=num_epochs)

In [24]:
import time

def save_model(model_components, optimizer, epoch, path=None):
    """
    Save the state dictionaries of all model components and the optimizer.

    The checkpoint is a dict with keys ``'epoch'``, ``'<component>_state_dict'`` for each of
    the 10 components, and ``'optimizer_state_dict'``.

    Args:
        model_components (tuple): The 10 components, in the order unpacked by ``train``.
        optimizer (torch.optim.Optimizer): The optimizer used during training.
        epoch (int): The epoch number stored in the checkpoint and in the default filename.
        path (str, optional): Output file. Defaults to ``trained-{epoch}-{time.time()}.pth``
            in the current working directory.

    Returns:
        str: The path the checkpoint was written to.

    Raises:
        ValueError: If ``model_components`` does not contain exactly 10 components.
    """
    if path is None:
        # Timestamped filename so repeated saves do not overwrite each other
        path = f"trained-{epoch}-{time.time()}.pth"

    component_keys = (
        'audio_tokenizer',
        'text_tokenizer',
        'positional_encoding_audio',
        'positional_encoding_text',
        'transformer_audio',
        'transformer_text',
        'projection_head_audio',
        'projection_head_text',
        'contrastive_loss',
        'droptoken',
    )
    if len(model_components) != len(component_keys):
        raise ValueError(
            f"expected {len(component_keys)} model components, got {len(model_components)}"
        )

    checkpoint = {'epoch': epoch}
    for key, component in zip(component_keys, model_components):
        checkpoint[f'{key}_state_dict'] = component.state_dict()
    checkpoint['optimizer_state_dict'] = optimizer.state_dict()

    torch.save(checkpoint, path)
    print(f"Model saved to {path}")
    return path

# Checkpoint the current in-memory weights.
# NOTE(review): the file is labelled with `num_epochs` whether or not `train(...)` was
# actually run in this session, so an untrained model can be saved as "trained-10-...".
save_model(model_components, optimizer, num_epochs)


KeyboardInterrupt: 

In [81]:
def count_parameters(model_components):
    """
    Print a per-component table of trainable parameter counts and return the grand total.

    Components are labelled positionally (audio tokenizer, text tokenizer, ..., DropToken).
    Objects without a ``parameters`` method are skipped, and only tensors with
    ``requires_grad=True`` are counted.

    Args:
        model_components (tuple): A tuple of PyTorch model components.

    Returns:
        int: The total number of trainable parameters in the model.
    """
    labels = (
        "Audio Tokenizer",
        "Text Tokenizer",
        "Positional Encoding (Audio)",
        "Positional Encoding (Text)",
        "Transformer (Audio)",
        "Transformer (Text)",
        "Projection Head (Audio)",
        "Projection Head (Text)",
        "Contrastive Loss",
        "DropToken",
    )
    rule = "-" * 50

    print(f"{'Component':<30}{'Parameters':>15}")
    print(rule)

    grand_total = 0
    for label, module in zip(labels, model_components):
        if not hasattr(module, "parameters"):
            continue
        n_trainable = sum(t.numel() for t in module.parameters() if t.requires_grad)
        print(f"{label:<30}{n_trainable:>15,}")
        grand_total += n_trainable

    print(rule)
    print(f"{'Total Trainable Parameters':<30}{grand_total:>15,}")
    return grand_total

# Per-component trainable-parameter table for the model built above; compare the total
# with the CLAP reference counts listed at the end of the notebook.
total_params = count_parameters(model_components)

Component                          Parameters
--------------------------------------------------
Audio Tokenizer                       264,192
Text Tokenizer                     23,440,896
Positional Encoding (Audio)                 0
Positional Encoding (Text)                  0
Transformer (Audio)                67,153,920
Transformer (Text)                113,384,448
Projection Head (Audio)             9,441,536
Projection Head (Text)              1,574,656
Contrastive Loss                            0
DropToken                                   0
--------------------------------------------------
Total Trainable Parameters        215,259,648


**CLAP parameter sizes** (reference model; `target_params` above is set to the CLAP total)

Total parameters: 196,304,144  
Trainable parameters: 195,220,688  
Text encoder weights have been frozen.  
Audio Encoder - Total parameters: 84,984,847  
Audio Encoder - Trainable parameters: 83,901,391  
Text Encoder - Total parameters: 111,319,296  
Text Encoder - Trainable parameters: 1,837,056  