<a href="https://colab.research.google.com/github/Youruler1/Speech-Processing-Lab-Material/blob/main/wav2vec1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

wav2vec 1.0 (From-Scratch Implementation)



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ---- Step 1: Feature Encoder (Temporal CNN) ----
class FeatureEncoder(nn.Module):
    """Strided temporal CNN that maps a waveform to a sequence of frame features.

    Four ``Conv1d`` + ReLU stages (strides 5, 4, 2, 2) shorten the time axis;
    the result is then moved to channels-last layout and layer-normalised
    over the channel dimension.

    Args:
        input_dim: Channel count of the incoming signal.
        hidden_dim: Channel count produced by the last convolution, which is
            also the size normalised by ``LayerNorm``.
    """

    def __init__(self, input_dim=1, hidden_dim=512):
        super().__init__()
        # Attribute names and creation order are unchanged so that state_dict
        # keys and seeded weight initialisation stay the same.
        self.conv1 = nn.Conv1d(
            in_channels=input_dim, out_channels=128,
            kernel_size=10, stride=5, padding=2,
        )
        self.conv2 = nn.Conv1d(
            in_channels=128, out_channels=256,
            kernel_size=8, stride=4, padding=2,
        )
        self.conv3 = nn.Conv1d(
            in_channels=256, out_channels=512,
            kernel_size=4, stride=2, padding=1,
        )
        self.conv4 = nn.Conv1d(
            in_channels=512, out_channels=hidden_dim,
            kernel_size=4, stride=2, padding=1,
        )
        self.relu = nn.ReLU()
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        """Encode ``x`` of shape (B, input_dim, T) into (B, T', hidden_dim)."""
        hidden = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            hidden = self.relu(stage(hidden))
        # LayerNorm acts on the last axis, so channels must be moved there.
        frames = hidden.permute(0, 2, 1)
        return self.norm(frames)

# ---- Step 2: Context Network (Deeper Temporal CNN) ----
class ContextNetwork(nn.Module):
    """Stack of length-preserving 1-D convolutions applied over the time axis.

    Args:
        hidden_dim: Channel count of both the input and output of every layer.
        num_layers: How many ``Conv1d`` + ReLU layers to stack.
    """

    def __init__(self, hidden_dim=512, num_layers=9):
        super().__init__()
        layers = []
        for _ in range(num_layers):
            # kernel 3 / stride 1 / padding 1 keeps the sequence length fixed.
            layers.append(
                nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1)
            )
        self.convs = nn.ModuleList(layers)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map (B, T', C) features to (B, T', C) context vectors."""
        hidden = x.permute(0, 2, 1)  # Conv1d wants channels first: (B, C, T')
        for layer in self.convs:
            hidden = self.relu(layer(hidden))
        return hidden.permute(0, 2, 1)  # restore (B, T', C)

# ---- Step 3: Contrastive Loss ----
class ContrastiveLoss(nn.Module):
    """Loss that rewards anchor/positive similarity over anchor/negative similarity.

    All inputs are L2-normalised along the last axis, so each score is a
    cosine similarity divided by ``temperature``.

    Args:
        temperature: Divisor applied to every similarity score.
    """

    def __init__(self, temperature=0.1):
        super().__init__()
        self.temperature = temperature

    def forward(self, z, z_pos, z_neg):
        """Return the negated mean of (positive score - negative score).

        Args:
            z: Anchor representations, (B, T', C).
            z_pos: Positive representations, (B, T', C).
            z_neg: Negative samples, (B, T', C).
        """
        anchor = F.normalize(z, dim=-1)
        positive = F.normalize(z_pos, dim=-1)
        negative = F.normalize(z_neg, dim=-1)

        def scaled_similarity(other):
            # Dot product of unit vectors over the channel axis, then scaled.
            return (anchor * other).sum(dim=-1) / self.temperature

        pos_sim = scaled_similarity(positive)
        neg_sim = scaled_similarity(negative)

        # Minimising this pushes the positive score above the negative one.
        return -(pos_sim - neg_sim).mean()

# ---- Step 4: Utility Function for Downsampling ----
def downsample(x, feature_encoder):
    """Run ``x`` through ``feature_encoder`` with gradient tracking disabled.

    Used for the positive/negative samples so they come out at the same
    time resolution as the anchor features.
    """
    # These samples are treated as fixed targets, so no graph is recorded.
    with torch.no_grad():
        return feature_encoder(x)

# ---- Step 5: Combine Everything into Wav2Vec 1.0 Model ----
class Wav2Vec1(nn.Module):
    """Feature encoder + context network trained with a contrastive objective."""

    def __init__(self):
        super().__init__()
        self.feature_encoder = FeatureEncoder()
        self.context_network = ContextNetwork()
        self.contrastive_loss = ContrastiveLoss()

    def forward(self, x, z_pos, z_neg):
        """Return ``(loss, context)`` for an anchor batch and its pos/neg waveforms.

        ``z_pos`` and ``z_neg`` are passed through the same feature encoder
        (without gradients) so they line up with the anchor's time steps.
        """
        encoder = self.feature_encoder
        anchor_features = encoder(x)  # (B, T', C)
        positives = downsample(z_pos, encoder)
        negatives = downsample(z_neg, encoder)

        context = self.context_network(anchor_features)  # (B, T', C)
        return self.contrastive_loss(context, positives, negatives), context

# ---- Step 6: Training on synthetic Data ----
if __name__ == "__main__":
    # Three independent random "waveforms" of shape (batch=4, channels=1, time=16000),
    # drawn in this order: the anchor, a stand-in for augmented/real positives,
    # and a stand-in for distractor (negative) speech.
    speech_waveform, positive_samples, negative_samples = (
        torch.randn(4, 1, 16000) for _ in range(3)
    )

    model = Wav2Vec1()
    # Only the scalar loss is reported; the context tensor is discarded.
    loss, _ = model(speech_waveform, positive_samples, negative_samples)

    print(f"Contrastive Loss: {loss.item()}")


Contrastive Loss: 0.011717967689037323


Load the Speech Commands Audio Dataset

In [None]:
import os
import torchaudio

# Make sure the download target directory exists
data_root = "./data"
os.makedirs(data_root, exist_ok=True)

# Open the Speech Commands dataset under that directory (download requested)
dataset = torchaudio.datasets.SPEECHCOMMANDS(root=data_root, download=True)
print(f"Total samples: {len(dataset)}")

# Inspect the first example; fields after the label are ignored here
first_example = dataset[0]
waveform, sample_rate, label, *_ = first_example
print(f"Waveform shape: {waveform.shape}, Label: {label}")


100%|██████████| 2.26G/2.26G [00:24<00:00, 101MB/s]


Total samples: 105829
Waveform shape: torch.Size([1, 16000]), Label: backward


Adapt the existing wav2vec 1.0 model into a classifier for the Speech Commands dataset

In [None]:
class Wav2VecClassifier(nn.Module):
    """Utterance-level classifier built on the wav2vec feature and context networks.

    Args:
        num_classes: Number of output logits.
        hidden_dim: Channel width shared by the feature encoder, the context
            network and the classification head. Defaults to 512, the value
            that used to be hard-coded in the head, so existing callers see
            no change.
    """

    def __init__(self, num_classes, hidden_dim=512):
        super().__init__()
        # One width is passed to all three parts so the head can never drift
        # out of sync with the encoder output size.
        self.feature_encoder = FeatureEncoder(hidden_dim=hidden_dim)
        self.context_network = ContextNetwork(hidden_dim=hidden_dim)
        self.fc = nn.Linear(hidden_dim, num_classes)  # Classification head

    def forward(self, x):
        """Return class logits of shape (B, num_classes) for waveform batch ``x``."""
        x = self.feature_encoder(x)   # (B, T', C)
        x = self.context_network(x)   # (B, T', C)
        x = torch.mean(x, dim=1)      # average over time -> (B, C)
        return self.fc(x)

Fine-tune the Wav2Vec classifier on the Speech Commands dataset

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import torchaudio.transforms as T
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader # Import the DataLoader class
def preprocess_audio(waveform, sample_rate, target_length=16000, target_sample_rate=16000):
    """Convert a clip to mono at a fixed sample rate and a fixed number of samples.

    Args:
        waveform: Tensor laid out as (channels, samples).
        sample_rate: Sample rate of ``waveform`` in Hz.
        target_length: Number of samples in the returned clip.
        target_sample_rate: Sample rate to resample to when ``sample_rate``
            differs. Defaults to 16000, the previously hard-coded value.

    Returns:
        Tensor of shape (1, target_length): shorter clips are zero-padded on
        the right, longer clips are truncated.
    """
    if waveform.shape[0] > 1:
        # Average the channels to get a single mono channel.
        waveform = waveform.mean(dim=0, keepdim=True)

    if sample_rate != target_sample_rate:
        resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        waveform = resampler(waveform)

    num_samples = waveform.shape[1]
    if num_samples < target_length:
        waveform = F.pad(waveform, (0, target_length - num_samples))
    else:
        waveform = waveform[:, :target_length]
    return waveform

# ---- Load Speech Commands Dataset ----
class SpeechCommandsDataset(torch.utils.data.Dataset):
    """Wrap a raw dataset so each item is ``(preprocessed waveform, class index)``.

    Args:
        dataset: Indexable collection whose items begin with
            ``(waveform, sample_rate, label, ...)``.
        label_map: Mapping from a label to its integer class index.
    """

    def __init__(self, dataset, label_map):
        self.dataset = dataset
        self.label_map = label_map

    def __len__(self):
        """Number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the preprocessed waveform and mapped label for item ``idx``."""
        # Only the first three fields are used; anything after them is dropped.
        audio, rate, word, *_ = self.dataset[idx]
        target = self.label_map[word]
        return preprocess_audio(audio, rate), target

# ---- Training Function ----
def train_model(model, train_loader, criterion, optimizer, num_epochs=5):
    """Train ``model`` for ``num_epochs`` passes and print per-epoch loss/accuracy.

    Args:
        model: Module mapping a batch of inputs to class logits.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar loss.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.
    """
    # Take the device from the model itself instead of relying on a
    # module-level ``device`` global, which only exists when the file runs as
    # a script. A model without parameters falls back to the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    for epoch in range(num_epochs):
        total_loss, total_correct, total_samples = 0.0, 0, 0
        # Counted here rather than via len(train_loader) so that loaders
        # without a length, and empty loaders, are handled.
        num_batches = 0
        for waveforms, labels in train_loader:
            waveforms, labels = waveforms.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(waveforms)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_correct += (outputs.argmax(1) == labels).sum().item()
            total_samples += labels.size(0)
            num_batches += 1

        # Report zeros instead of dividing by zero when nothing was seen.
        mean_loss = total_loss / num_batches if num_batches else 0.0
        accuracy = total_correct / total_samples if total_samples else 0.0
        print(f"Epoch {epoch+1}: Loss={mean_loss:.4f}, Accuracy={accuracy:.4f}")

# ---- Main Execution ----
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = torchaudio.datasets.SPEECHCOMMANDS(root="./data", download=True)

    # Create label mapping.
    # Iterating the dataset directly loads every waveform just to read its
    # label. When the installed torchaudio provides ``get_metadata`` (documented
    # as returning the same fields as ``__getitem__`` but with a file path in
    # place of the audio), use it; otherwise fall back to the full iteration.
    if hasattr(dataset, "get_metadata"):
        label_iter = (dataset.get_metadata(i)[2] for i in range(len(dataset)))
    else:
        label_iter = (entry[2] for entry in dataset)
    labels = sorted(set(label_iter))
    label_map = {label: i for i, label in enumerate(labels)}

    # Prepare dataset and loader
    train_dataset = SpeechCommandsDataset(dataset, label_map)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    # Initialize model
    model = Wav2VecClassifier(num_classes=len(label_map)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Train the model
    train_model(model, train_loader, criterion, optimizer, num_epochs=10)