In [1]:
# Step 0: Install necessary libraries
!pip -q install torchaudio transformers PySoundFile tqdm



# The modified VATT architecture for audio-text alignment

In [2]:
# Step 1: Install and Import Libraries
import torch
import torch.nn as nn
from transformers import AutoTokenizer
import numpy as np

In [3]:
# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Step 2: Audio Tokenization
class AudioTokenizer(nn.Module):
    def __init__(self, patch_size=128, embed_dim=2048):
        super(AudioTokenizer, self).__init__()
        self.patch_size = patch_size
        self.projection = nn.Linear(patch_size, embed_dim)
    
    def forward(self, audio_signal):
        # Assuming audio_signal is [batch, time_samples]
        batch_size, time_len = audio_signal.shape
        num_patches = time_len // self.patch_size
        audio_signal = audio_signal[:, :num_patches * self.patch_size]
        audio_patches = audio_signal.reshape(batch_size, num_patches, self.patch_size)
        # Project each patch to embedding dimension
        audio_embeddings = self.projection(audio_patches)  # [batch, num_patches, embed_dim]
        return audio_embeddings

In [5]:
class TextTokenizer(nn.Module):
    def __init__(self, embed_dim=768, max_length=512):
        super(TextTokenizer, self).__init__()
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.embedding = nn.Embedding(self.tokenizer.vocab_size, embed_dim)
        self.max_length = max_length
    
    def forward(self, text):
        # text is a list of strings
        tokens = self.tokenizer(
            text, 
            truncation=True, 
            padding="max_length", 
            max_length=self.max_length, 
            return_tensors="pt"
        )
        input_ids = tokens["input_ids"].to(device)  # [batch, max_length]
        text_embeddings = self.embedding(input_ids)  # [batch, max_length, embed_dim]
        return text_embeddings  # Ensures 3D output


In [6]:
class PositionalEncoding(nn.Module):
    def __init__(self, embed_dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.embed_dim = embed_dim
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * -(np.log(10000.0) / embed_dim))
        pe = torch.zeros(max_len, embed_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))  # Shape [1, max_len, embed_dim]

    def forward(self, x):
        # Adjust the positional encoding to match the actual input shape
        batch_size, seq_len, embed_dim = x.size()
        pe = self.pe[:, :seq_len, :].expand(batch_size, seq_len, embed_dim)
        
        return x + pe.to(x.device)

In [7]:
# Step 5: Transformer Encoder Components
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads)
    
    def forward(self, x):
        attn_output, _ = self.attention(x, x, x)
        return attn_output

class FeedForward(nn.Module):
    def __init__(self, embed_dim, expansion_factor=4):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(embed_dim, embed_dim * expansion_factor)
        self.fc2 = nn.Linear(embed_dim * expansion_factor, embed_dim)
        self.activation = nn.GELU()
    
    def forward(self, x):
        return self.fc2(self.activation(self.fc1(x)))

class TransformerEncoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.linear1 = nn.Linear(embed_dim, embed_dim * 4)
        self.dropout = nn.Dropout(0.1)
        self.linear2 = nn.Linear(embed_dim * 4, embed_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)

    def forward(self, src):
        src2 = self.self_attn(src, src, src)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(torch.relu(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

In [8]:
class TransformerFactory(nn.Module):
    """
    Factory class to create TransformerEncoder instances with customizable parameters.
    """
    def __init__(self, embed_dim, num_heads, num_layers, feedforward_dim_multiplier=4, dropout=0.1):
        """
        Args:
            embed_dim (int): Dimension of the input embeddings.
            num_heads (int): Number of attention heads.
            num_layers (int): Number of Transformer layers.
            feedforward_dim_multiplier (int): Multiplier for feedforward layer dimensions.
            dropout (float): Dropout probability.
        """
        super(TransformerFactory, self).__init__()
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim, 
                nhead=num_heads, 
                dim_feedforward=embed_dim * feedforward_dim_multiplier, 
                dropout=dropout,
                batch_first=True  # Enable batch-first optimization
            ),
            num_layers=num_layers
        )

    def forward(self, x):
        return self.encoder(x)

In [9]:
# Step 6: Projection Head
class ProjectionHead(nn.Module):
    def __init__(self, embed_dim, proj_dim):
        super(ProjectionHead, self).__init__()
        self.hidden_dim = embed_dim * 2
        self.projection = nn.Sequential(
            nn.Linear(embed_dim, self.hidden_dim),
            nn.GELU(),
            nn.Linear(self.hidden_dim, proj_dim),
        )
    
    def forward(self, x):
        # Extract [CLS] token (or average pool)
        x = x[:, 0]  # Assuming first token is [CLS]
        return self.projection(x)

In [10]:
# Step 7: Contrastive Learning (NCE and MIL-NCE)
class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.07):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature
    
    def forward(self, features_a, features_b):
        # Normalize features and compute similarity matrix
        features_a = nn.functional.normalize(features_a, dim=1)
        features_b = nn.functional.normalize(features_b, dim=1)
        logits = torch.matmul(features_a, features_b.T) / self.temperature
        labels = torch.arange(len(features_a)).to(device)
        return nn.CrossEntropyLoss()(logits, labels)

In [11]:
# Step 8: DropToken Implementation
class DropToken(nn.Module):
    def __init__(self, drop_rate=0.5):
        super(DropToken, self).__init__()
        self.drop_rate = drop_rate
    
    def forward(self, x):
        if not self.training:
            return x
        keep_prob = 1 - self.drop_rate
        mask = torch.rand(x.shape[:2], device=x.device) < keep_prob
        x = x * mask.unsqueeze(-1)
        return x

# Data loading and training

In [12]:
# Step 1: Assuming we have already downloaded TED-LIUM dataset and extracted audio + text files
# From: https://www.openslr.org/51/ [50.6 GB]
PATH_TO_AUDIO_FILES = "./audio_files"
PATH_TO_TRANSCRIPT_FILES = "./transcript_files"

In [13]:
# Step 2: Imports
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
import os
from transformers import AutoTokenizer

In [14]:
import torch.nn.functional as F

def collate_fn(batch, max_audio_length=160000):
    audio_tensors = []
    text_list = []

    for audio, text in batch:
        # Adjust audio length to max_audio_length
        if audio.size(0) > max_audio_length:
            audio = audio[:max_audio_length]
        else:
            audio = F.pad(audio, (0, max_audio_length - audio.size(0)))
        
        audio_tensors.append(audio)
        text_list.append(text)

    # Stack the audio tensors into a batch
    audio_batch = torch.stack(audio_tensors)

    # text_list is a list of strings
    return audio_batch, text_list

In [15]:
# Step 3: TED-LIUM Dataset Setup
class TEDLIUMDataset(Dataset):
    def __init__(self, audio_dir, transcript_dir, tokenizer, max_text_length=512, sample_rate=16000, test_size=10):
        """
        Initialize dataset with paths to audio and transcript directories and tokenizer.
        
        Args:
            audio_dir (str): Path to directory containing audio files.
            transcript_dir (str): Path to directory containing transcript files.
            tokenizer (transformers.AutoTokenizer): Tokenizer for text data.
            max_text_length (int): Maximum length for text tokenization.
            sample_rate (int): Desired sample rate for audio.
            test_size (int): Number of samples to use for the test set.
        """
        self.audio_files = sorted([
            os.path.join(audio_dir, f) for f in os.listdir(audio_dir) if f.endswith(".wav")
        ])
        self.transcript_files = sorted([
            os.path.join(transcript_dir, f) for f in os.listdir(transcript_dir) if f.endswith(".stm")
        ])
        
        self.tokenizer = tokenizer
        self.max_text_length = max_text_length
        self.sample_rate = sample_rate
        self.test_size = test_size

        # Split data into training and testing
        self.train_audio_files = self.audio_files[:-self.test_size]  # All except last 10
        self.test_audio_files = self.audio_files[-self.test_size:]   # Last 10
        self.train_transcript_files = self.transcript_files[:-self.test_size]  # All except last 10
        self.test_transcript_files = self.transcript_files[-self.test_size:]   # Last 10

    def __len__(self):
        # The length should be the size of the training set (excluding test set entries)
        return len(self.train_audio_files)

    def __getitem__(self, idx):
        # Get the training sample
        audio_path = self.train_audio_files[idx]
        transcript_path = self.train_transcript_files[idx]
        
        # Load audio file
        waveform, sr = torchaudio.load(audio_path)
        
        # Resample to desired sample rate if necessary
        if sr != self.sample_rate:
            waveform = torchaudio.transforms.Resample(sr, self.sample_rate)(waveform)

        # Load transcript
        with open(transcript_path, 'r', encoding='utf-8') as f:
            transcript = f.read().strip()
        
        # Return waveform and transcript for training
        return waveform.squeeze(0), transcript
    
    def get_test_samples(self):
        """Return a list of test samples (audio and corresponding transcript)."""
        test_samples = []
        for i in range(self.test_size):
            audio_path = self.test_audio_files[i]
            transcript_path = self.test_transcript_files[i]
            
            # Load audio file
            waveform, sr = torchaudio.load(audio_path)
            
            # Resample if necessary
            if sr != self.sample_rate:
                waveform = torchaudio.transforms.Resample(sr, self.sample_rate)(waveform)

            # Load transcript
            with open(transcript_path, 'r', encoding='utf-8') as f:
                transcript = f.read().strip()
            
            # Append audio and transcript to test_samples
            test_samples.append((waveform.squeeze(0), transcript))
        
        return test_samples

# Initialize tokenizer and dataset
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
dataset = TEDLIUMDataset(audio_dir=PATH_TO_AUDIO_FILES, transcript_dir=PATH_TO_TRANSCRIPT_FILES, tokenizer=tokenizer)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, drop_last=True, collate_fn=lambda x: collate_fn(x))

# Defining training components

In [None]:
# Step 4: Define Training Components

# Define a lower learning rate for pretrained encoders
pretrained_lr = 1e-4  # Lower learning rate for BERT and Wav2Vec2
training_lr = 1e-3  # Default learning rate for the rest of the model
num_of_layers_to_unfreeze = 8

# Initialize model components
from transformers import Wav2Vec2Model, BertModel

audio_encoder = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-large-960h').to(device)
audio_encoder.gradient_checkpointing_enable()

text_encoder = BertModel.from_pretrained('bert-base-uncased').to(device)
text_encoder.gradient_checkpointing_enable()

for param in audio_encoder.parameters(): param.requires_grad = False
for param in audio_encoder.encoder.layers[-num_of_layers_to_unfreeze:].parameters(): param.requires_grad = True
for param in text_encoder.parameters(): param.requires_grad = False
for param in text_encoder.encoder.layer[-num_of_layers_to_unfreeze:].parameters(): param.requires_grad = True

projection_head_audio = ProjectionHead(embed_dim=1024, proj_dim=256).to(device)
projection_head_text = ProjectionHead(embed_dim=768, proj_dim=256).to(device)

contrastive_loss = ContrastiveLoss().to(device)
droptoken = DropToken(drop_rate=0.5).to(device)

# Optimizer
optimizer = torch.optim.AdamW([
    {'params': audio_encoder.parameters(), 'lr': pretrained_lr},
    {'params': text_encoder.parameters(), 'lr': pretrained_lr},
    {'params': projection_head_audio.parameters(), 'lr': training_lr},
    {'params': projection_head_text.parameters(), 'lr': training_lr},
])

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [17]:
from tqdm import tqdm
from torch.amp import GradScaler

# Step 5: Training Loop with Progress Bars for Steps within Each Epoch
def train(model_components, dataloader, optimizer, num_epochs=10):
    (
        audio_encoder,
        text_encoder,
        projection_head_audio,
        projection_head_text,
        contrastive_loss,
        droptoken
    ) = model_components
    
    for epoch in range(num_epochs):
        total_loss = 0
        # Initialize progress bar for batches within the current epoch
        batch_progress = tqdm(
            dataloader,
            desc=f"Epoch {epoch+1}/{num_epochs}",
            unit="batch",
            leave=True  # Keeps the progress bar after the epoch completes
        )
        
        accumulation_steps = 4
        scaler = GradScaler()

        for i, (audio, text) in enumerate(batch_progress):
            # Move audio to device; text remains on CPU
            audio = audio.to(device)

            # Audio processing with Wav2Vec2
            audio_output = audio_encoder(audio)  # [batch, seq_len, 768]
            audio_embeddings = audio_output.last_hidden_state  # Extract the hidden states
            audio_embeddings = droptoken(audio_embeddings)     # Apply DropToken
            audio_proj = projection_head_audio(audio_embeddings)  # Project to common space

            # Text processing with BERT
            tokenized_text = tokenizer(
                text,
                truncation=True,
                padding="max_length",
                max_length=512,
                return_tensors="pt"
            ).to(device)
            text_output = text_encoder(**tokenized_text)  # [batch, seq_len, 768]
            text_embeddings = text_output.last_hidden_state  # Extract the hidden states
            text_embeddings = droptoken(text_embeddings)     # Apply DropToken
            text_proj = projection_head_text(text_embeddings)  # Project to common space

            # Contrastive Loss
            loss = contrastive_loss(audio_proj, text_proj)
            total_loss += loss.item()

            scaler.scale(loss).backward()

            # Backpropagation
            if (i + 1) % accumulation_steps == 0:
                scaler.step(optimizer)  # Apply the gradients
                scaler.update()          # Update the scaler
                optimizer.zero_grad()    # Reset gradients
            # No need to call scaler.update() here since it's done after step()

            # Update the progress bar's postfix with the current loss every 10 steps
            if (i + 1) % 10 == 0:
                batch_progress.set_postfix(loss=loss.item())
        
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")

# Define model components as a tuple
model_components = (
    audio_encoder,
    text_encoder,
    projection_head_audio,
    projection_head_text,
    contrastive_loss,
    droptoken
)

# Start training
num_epochs = 5
train(model_components, dataloader, optimizer, num_epochs=num_epochs)

Epoch 1/5: 100%|██████████| 585/585 [05:19<00:00,  1.83batch/s, loss=1.36]


Epoch [1/5], Average Loss: 1.4239


Epoch 2/5: 100%|██████████| 585/585 [04:23<00:00,  2.22batch/s, loss=1.35]


Epoch [2/5], Average Loss: 1.3990


Epoch 3/5: 100%|██████████| 585/585 [04:24<00:00,  2.21batch/s, loss=1.39]


Epoch [3/5], Average Loss: 1.4123


Epoch 4/5: 100%|██████████| 585/585 [04:34<00:00,  2.13batch/s, loss=1.3] 


Epoch [4/5], Average Loss: 1.3952


Epoch 5/5: 100%|██████████| 585/585 [04:35<00:00,  2.13batch/s, loss=1.39]

Epoch [5/5], Average Loss: 1.3916





# Housekeeping

In [18]:
import time

def save_model(model_components, optimizer, epoch):
    """
    Save the state dictionaries of all model components and the optimizer.

    Args:
        model_components (tuple): Tuple containing all model components.
        optimizer (torch.optim.Optimizer): The optimizer used during training.
        epoch (int): The current epoch number.
    """
    # Generate a unique filename for the saved model
    path = f"trained-{epoch}-{time.time()}.pth"

    # Unpack the components
    (
        audio_encoder,
        text_encoder,
        projection_head_audio,
        projection_head_text,
        contrastive_loss,
        droptoken
    ) = model_components

    # Save state dictionaries
    torch.save({
        'epoch': epoch,
        'text_encoder_state_dict': text_encoder.state_dict(),
        'audio_encoder_state_dict': audio_encoder.state_dict(),
        'projection_head_audio_state_dict': projection_head_audio.state_dict(),
        'projection_head_text_state_dict': projection_head_text.state_dict(),
        'contrastive_loss_state_dict': contrastive_loss.state_dict(),
        'droptoken_state_dict': droptoken.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)

    torch.save({
        'text_encoder_state_dict': text_encoder.state_dict(),
        'projection_head_text_state_dict': projection_head_text.state_dict(),
    }, "text-our-vatt-" + path)

    torch.save({
        'audio_encoder_state_dict': audio_encoder.state_dict(),
        'projection_head_audio_state_dict': projection_head_audio.state_dict(),
    }, "audio-our-vatt-" + path)

    print(f"Model saved to {path}")

save_model(model_components, optimizer, num_epochs)


Model saved to trained-5-1733197704.1865647.pth


In [19]:
def count_parameters(model_components):
    """
    Counts and prints the total number of trainable parameters for each model component.

    Args:
        model_components (tuple): A tuple of PyTorch model components.

    Returns:
        int: The total number of trainable parameters in the model.
    """
    total_params = 0
    print(f"{'Component':<30}{'Parameters':>15}")
    print("-" * 50)

    # Iterate through each component in the tuple
    component_names = [
        "Audio Tokenizer",
        "Text Tokenizer",
        "Projection Head (Audio)",
        "Projection Head (Text)",
        "Contrastive Loss",
        "DropToken"
    ]

    for name, component in zip(component_names, model_components):
        if hasattr(component, "parameters"):
            component_params = sum(p.numel() for p in component.parameters() if p.requires_grad)
            total_params += component_params
            print(f"{name:<30}{component_params:>15,}")

    print("-" * 50)
    print(f"{'Total Trainable Parameters':<30}{total_params:>15,}")
    return total_params

model_components = (
    audio_encoder,
    text_encoder,
    projection_head_audio,
    projection_head_text,
    contrastive_loss,
    droptoken
)

total_params = count_parameters(model_components)

Component                          Parameters
--------------------------------------------------
Audio Tokenizer                   100,769,792
Text Tokenizer                     56,702,976
Projection Head (Audio)             2,623,744
Projection Head (Text)              1,574,656
Contrastive Loss                            0
DropToken                                   0
--------------------------------------------------
Total Trainable Parameters        161,671,168


# Evaluation

In [None]:
import torch
from transformers import BertTokenizer
import numpy as np

In [21]:
def load_ted_samples():
    """Loads test samples (audio, transcript) from the dataset"""
    # Load the dataset
    test_set = dataset.get_test_samples()
    
    sample_texts = []
    sample_audio = []
    
    # Collect sample texts and corresponding audio
    for audio, text in test_set:
        sample_texts.append(text)
        sample_audio.append(audio)
    
    return sample_audio, sample_texts

In [31]:
# Load CLAP Model (from Huggingface)
from transformers import AutoProcessor, ClapModel
clap_model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
clap_processor = AutoProcessor.from_pretrained("laion/clap-htsat-unfused")
clap_model.eval()  # Set to evaluation mode

# Your trained VATT Model (already loaded as model_components)
audio_encoder, text_encoder, projection_head_audio, projection_head_text, contrastive_loss, droptoken = model_components
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
audio_encoder.eval()
text_encoder.eval()
projection_head_audio.eval()
projection_head_text.eval()
contrastive_loss.eval()
droptoken.eval()

DropToken()

In [34]:
# Step 3: Extract Summary for Our VATT Model
from sklearn.metrics.pairwise import cosine_similarity

def extract_summary_for_our_vatt(audio, text, model_components, tokenizer):
    """
    Extracts a summary for the provided audio and text using the trained VATT model.
    
    Args:
        audio (tensor): Audio tensor (waveform) for the input.
        text (str): Corresponding transcript text for the input.
        model_components (tuple): Tuple containing audio_encoder, text_encoder, projection_head_audio,
                                  projection_head_text, contrastive_loss, droptoken.
        tokenizer: Tokenizer used for text input.

    Returns:
        str: Extracted summary.
    """
    audio_encoder, text_encoder, projection_head_audio, projection_head_text, contrastive_loss, droptoken = model_components
    audio_encoder.to(device)
    text_encoder.to(device)
    projection_head_audio.to(device)
    projection_head_text.to(device)
    contrastive_loss.to(device)
    droptoken.to(device)
    
    # Process audio input
    audio = audio.unsqueeze(0).to(device)  # Add batch dimension
    audio_output = audio_encoder(audio)
    audio_embeddings = audio_output.last_hidden_state
    audio_embeddings = droptoken(audio_embeddings)
    audio_proj = projection_head_audio(audio_embeddings)  # Project to common space

    # Process text input
    text_tokenized = tokenizer(text, truncation=True, padding="max_length", max_length=512, return_tensors="pt").to(device)
    text_output = text_encoder(**text_tokenized)
    text_embeddings = text_output.last_hidden_state
    text_embeddings = droptoken(text_embeddings)
    text_proj = projection_head_text(text_embeddings)  # Project to common space
    
    # Compute cosine similarity
    sim_score = cosine_similarity(audio_proj.cpu().detach().numpy(), text_proj.cpu().detach().numpy())
    
    top_k_indices = np.argsort(sim_score[0])[::-1][:5]  # Sorting to pick the top-k
    
    # Step 7: Construct the extractive summary
    summary_sentences = [text[idx] for idx in top_k_indices]
    summary = '\n'.join(summary_sentences)
    
    return summary

In [39]:
# Step 4: Extract Summary for CLAP Model
def extract_summary_for_clap(audio, text, clap_model, clap_processor, max_text_length=512):
    """
    Extracts a summary for the provided audio and text using the CLAP model.
    
    Args:
        audio (tensor): Audio tensor (waveform) for the input.
        text (str): Corresponding transcript text for the input.
        clap_model: The pre-trained CLAP model.
        clap_processor: The processor used for converting inputs to the model.
        max_text_length (int): Maximum length of text sequences for padding/truncation.

    Returns:
        str: Extracted summary.
    """

    # Process text (padding and truncation)
    text_input = clap_processor(text=text, padding=True, truncation=True, max_length=max_text_length, return_tensors="pt")

    # Process audio (convert waveform to features)
    audio_input = clap_processor(audios=audio, return_tensors="pt", padding=True)

    # Ensure that both text and audio inputs are of similar lengths (in case of mismatch)
    # This could involve padding/truncating the audio to match the length of the text
    # Here, we assume text and audio should align by some logic (this part may need model-specific adjustments)
    if audio_input["input_features"].shape[1] > text_input["input_ids"].shape[1]:
        audio_input["input_features"] = audio_input["input_features"][:, :text_input["input_ids"].shape[1]]

    with torch.no_grad():
        # Pass the inputs through the CLAP model
        outputs = clap_model(**text_input, input_features=audio_input["input_features"])

    # Similarity-based summary extraction (implement as per your logic)
    sim_score = cosine_similarity(outputs.text_embeds.cpu().detach().numpy(), outputs.audio_embeds.cpu().detach().numpy())

    top_k_indices = np.argsort(sim_score[0])[::-1][:5]  # Sorting to pick the top-k sentences based on similarity
    
    # Construct the extractive summary
    summary_sentences = [text_input['input_ids'][0][idx] for idx in top_k_indices]
    summary = ' '.join([clap_processor.decode([idx]) for idx in summary_sentences])  # Decode to text

    return summary

In [40]:
# Step 5: Presenting Results for Human Evaluation
device = torch.device("cpu")

def print_summaries(model_name, summaries):
    """
    Prints summaries in a human-readable format for evaluation.
    
    Args:
        model_name (str): Name of the model (e.g., "VATT" or "CLAP").
        summaries (list of str): List of summaries to print.
    """
    print(f"==================== {model_name} Summaries ====================")
    for idx, summary in enumerate(summaries):
        print(f"\nTest Case {idx+1}:")
        print(f"Summary: {summary}")
    print("===============================================================")


sample_audio, sample_texts = load_ted_samples()

# Prepare summaries for both models
vatt_summaries = []
clap_summaries = []
i = 1

# Extract summaries for the 5 test samples using both models
for audio, text in zip(sample_audio, sample_texts):
    print(f"Summary {i}/{len(sample_texts)}")
    i+=1

    vatt_summary = extract_summary_for_our_vatt(audio, text, model_components, tokenizer)
    vatt_summaries.append(vatt_summary)

    clap_summary = extract_summary_for_clap(audio, text, clap_model, clap_processor)
    clap_summaries.append(clap_summary)

print_summaries("VATT", vatt_summaries)
print_summaries("CLAP", clap_summaries)

Summary 1/10


It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Summary 2/10


: 