<a href="https://colab.research.google.com/github/01PrathamS/SonicSpeech/blob/main/notebooks/Implementation_of_Wav2vec.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Architecture



1.   Feature Encoder
2.   Quantization Module
3.   Transformer Encoder
4.   Fine-Tuning Layer (for ASR)
    - fine-tuned with CTC loss: a linear layer maps the transformer's per-frame features to text-token logits



In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [18]:
class FeatureEncoder(nn.Module):
    """Convolutional front end that turns raw waveforms into frame features.

    Two strided ``Conv1d`` layers, each followed by ReLU:

    * ``input_dim -> hidden_dim``, kernel 10, stride 5, padding 3
    * ``hidden_dim -> 2 * hidden_dim``, kernel 3, stride 2, padding 1

    Input is channel-first ``(batch, input_dim, num_samples)``. Output is
    ``(batch, 2 * hidden_dim, num_frames)``. Time is downsampled roughly 10x
    (stride 5, then stride 2); e.g. 16000 samples -> 1600 frames.
    """

    def __init__(self, input_dim=1, hidden_dim=512):
        super().__init__()
        widened_dim = hidden_dim * 2
        # Built in the same order as before so parameter init (RNG draws)
        # and the Sequential indices (state_dict keys) are unchanged.
        first_conv = nn.Conv1d(input_dim, hidden_dim, kernel_size=10, stride=5, padding=3)
        first_act = nn.ReLU()
        second_conv = nn.Conv1d(hidden_dim, widened_dim, kernel_size=3, stride=2, padding=1)
        second_act = nn.ReLU()
        self.conv_layers = nn.Sequential(first_conv, first_act, second_conv, second_act)

    def forward(self, x):
        """Run the conv stack on a channel-first waveform batch."""
        frames = self.conv_layers(x)
        return frames

In [19]:

class VectorQuantization(nn.Module):
    """Nearest-neighbour vector quantizer over a learned codebook.

    Every time step of a channel-first input ``(batch, input_dim, seq_len)``
    is replaced by the closest row, in Euclidean distance, of
    ``self.embedding.weight`` (shape ``(codebook_size, input_dim)``). The
    output has the same channel-first layout as the input.

    Gradients: ``argmin`` is not differentiable. A straight-through term
    therefore copies the output gradient to the input, which lets the
    upstream encoder train. The selected codebook rows still receive the same
    gradient through the embedding lookup. No codebook or commitment loss is
    added here.
    """

    def __init__(self, input_dim=1024, codebook_size=320):
        super().__init__()
        self.embedding = nn.Embedding(codebook_size, input_dim)

    def forward(self, x):
        batch_size, channels, seq_len = x.shape
        # (B, C, T) -> (B*T, C): one row per time step.
        x_flat = x.permute(0, 2, 1).contiguous().view(-1, channels)
        # (B*T, codebook_size) Euclidean distances to every codebook vector.
        distances = torch.cdist(x_flat.unsqueeze(0), self.embedding.weight.unsqueeze(0), p=2).squeeze(0)

        nearest_idx = distances.argmin(dim=-1)
        quantized = self.embedding(nearest_idx)
        # Straight-through estimator: (x_flat - x_flat.detach()) is exactly
        # zero in the forward pass, so the values are the codebook rows. In the
        # backward pass it gives x an identity gradient. Without it the
        # feature encoder upstream never receives a gradient.
        quantized = quantized + (x_flat - x_flat.detach())
        return quantized.view(batch_size, seq_len, -1).permute(0, 2, 1)

In [20]:
class TransformerEncoder(nn.Module):
    """Stack of ``nn.TransformerEncoderLayer`` blocks over batch-first input.

    Args:
        input_dim: model width (``d_model``); the last dimension of the input.
        num_layers: number of stacked encoder layers.
        num_heads: attention heads per layer; must divide ``input_dim``.
        dim_feedforward: hidden width of each layer's feed-forward sub-block.
        dropout: dropout probability used inside each layer.

    This module adds no positional encoding of its own. Any ordering
    information must already be present in the input features.
    """

    def __init__(self, input_dim=1024, num_layers=6, num_heads=8, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=input_dim,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        # nn.TransformerEncoder deep-copies the layer num_layers times.
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, x, src_key_padding_mask=None):
        """Encode ``x`` of shape ``(batch, seq_len, input_dim)``.

        Args:
            x: batch-first input sequence.
            src_key_padding_mask: optional ``(batch, seq_len)`` mask. ``True``
                marks padded positions that attention should ignore.

        Returns:
            Tensor with the same shape as ``x``.
        """
        return self.transformer(x, src_key_padding_mask=src_key_padding_mask)

In [21]:
class Wav2Vec2(nn.Module):
    """Simplified wav2vec 2.0-style model producing per-frame logits.

    The pipeline runs: conv feature encoder -> vector quantizer ->
    transformer encoder -> linear output head.

    Args:
        input_dim: channels of the input waveform.
        hidden_dim: conv width. The features passed downstream have
            ``2 * hidden_dim`` channels.
        quant_codebook_size: number of codebook vectors in the quantizer.
        transformer_layers: number of transformer encoder layers.
        num_heads: attention heads per transformer layer.
        vocab_size: width of the output head (number of logits per frame).
            Defaults to 29, the previously hard-coded value.
    """

    def __init__(self, input_dim=1, hidden_dim=512, quant_codebook_size=320, transformer_layers=6, num_heads=8, vocab_size=29):
        super().__init__()
        feature_dim = hidden_dim * 2  # FeatureEncoder's output channel count
        self.feature_encoder = FeatureEncoder(input_dim, hidden_dim)
        self.quantizer = VectorQuantization(feature_dim, quant_codebook_size)
        self.transformer_encoder = TransformerEncoder(feature_dim, transformer_layers, num_heads)
        self.output_layer = nn.Linear(feature_dim, vocab_size)

    def forward(self, x):
        """Map a channel-first waveform batch to ``(batch, frames, vocab_size)`` logits."""
        x = self.feature_encoder(x)      # (batch, feature_dim, frames)
        x = self.quantizer(x)            # same channel-first layout
        x = x.permute(0, 2, 1)           # (batch, frames, feature_dim) for the batch_first transformer
        x = self.transformer_encoder(x)
        return self.output_layer(x)      # (batch, frames, vocab_size)


In [22]:
# Smoke test: push a random batch (2 clips, 1 channel, 16000 samples) through
# a default Wav2Vec2 and print the logits shape. The captured output below this
# cell shows torch.Size([2, 1600, 29]), i.e. (batch, frames, logits).
# NOTE(review): the model is in training mode and autograd is on, so the full
# graph is built just to read a shape; wrap in torch.no_grad() if that matters.
model = Wav2Vec2()
x = torch.randn(2, 1, 16000)
output = model(x)
print(output.shape)

torch.Size([2, 1600, 29])


In [23]:
class wav2vec2CTC(nn.Module):
    """Wav2Vec2 backbone plus CTC loss, for ASR fine-tuning.

    Args:
        vocab_size: number of output classes, including the CTC blank
            (index 0).
    """

    def __init__(self, vocab_size=29):
        super().__init__()
        self.wav2vec2 = Wav2Vec2()
        # Wav2Vec2() is built with its default head width. Before this fix,
        # vocab_size was stored but never applied, so any other value was
        # silently ignored. Swap in a correctly sized head when they differ.
        head = self.wav2vec2.output_layer
        if head.out_features != vocab_size:
            self.wav2vec2.output_layer = nn.Linear(head.in_features, vocab_size)
        self.ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
        self.vocab_size = vocab_size

    def forward(self, x, labels, input_lengths, label_lengths):
        """Compute the CTC loss for a batch.

        Args:
            x: waveform batch accepted by ``Wav2Vec2`` (channel-first).
            labels: target class indices; index 0 is reserved for the blank.
            input_lengths: per-example valid lengths measured in output
                frames (the logits' time dimension), not raw samples.
            label_lengths: per-example target lengths.

        Returns:
            ``(loss, logits)``. ``logits`` are the raw pre-softmax scores in
            time-major layout ``(frames, batch, vocab_size)``.
        """
        # nn.CTCLoss expects time-major (T, N, C) log-probabilities.
        logits = self.wav2vec2(x).permute(1, 0, 2)
        log_probs = F.log_softmax(logits, dim=-1)
        # zero_infinity=True zeroes infinite losses (e.g. impossible
        # alignments) and their gradients instead of propagating inf.
        loss = self.ctc_loss(log_probs, labels, input_lengths, label_lengths)
        return loss, logits

# Instantiate the CTC fine-tuning wrapper with its default vocab_size; this
# rebinds the module-level name `model` used by the smoke-test cell above.
model = wav2vec2CTC()