In [2]:
import torch
import torch.nn as nn
from torchaudio.datasets import LIBRISPEECH
from tqdm import tqdm
from torch.optim import Adam

In [3]:
class TruncatedLibri(LIBRISPEECH):
    def __getitem__(self, index):
        waveform, sr, transcript, speakerID, chapterID, utteranceID = super().__getitem__(index)
        max_samples = 249600 # 15.6 seconds
        waveform = waveform[:, :max_samples]  # truncate waveform to size max_samples
        if waveform.shape[1] < max_samples:
            waveform = torch.cat([waveform, torch.zeros(1, max_samples - waveform.shape[1])], dim=1)
        return waveform, sr, transcript, speakerID, chapterID, utteranceID

train_loader = torch.utils.data.DataLoader(TruncatedLibri('.', download=True), batch_size=2, shuffle=True)


try:
    waveform, sr, transcript, speakerID, chapterID, utteranceID = next(iter(train_loader))
except Exception as e:
    print(e)
    print('May need to downgrade pytorch to 2.0.1')

In [19]:
class Encoder(nn.Module):
    def __init__(self, channels, strides = [5, 2, 2, 2, 2, 2, 2], kernel_sizes = [10, 3, 3, 3, 3, 2, 2]):
        super(Encoder, self).__init__()

        assert len(strides) == len(kernel_sizes), "strides and kernel_sizes must have the same length"

        num_layers = len(strides)

        layers = []
        for i in range(num_layers):
            if i == 0:
                layers.append(nn.Conv1d(1, channels, kernel_size=kernel_sizes[i], stride=strides[i]))
            elif i == num_layers - 1:
                layers.append(nn.Conv1d(channels, 1, kernel_size=kernel_sizes[i], stride=strides[i]))
            else:
                layers.append(nn.Conv1d(channels, channels, kernel_size=kernel_sizes[i], stride=strides[i]))
            layers.append(nn.ReLU())
            if i == num_layers - 1:

                layers.append(nn.GroupNorm(1, 1))
            else:
                layers.append(nn.GroupNorm(1, channels))

        self.encoder = torch.nn.Sequential(*layers)
        
    def forward(self, x):
        return self.encoder(x)
    
class Context(torch.nn.Module):

    def __init__(self):
        super(Context, self).__init__()
        self.context = torch.nn.Transformer(d_model = 768, nhead=8)
        
    def forward(self, x):
        return self.context(x)

class ProductQuantization(nn.Module):
    """This is how it should be:
    For the quantization module we use G = 2 and V = 320 resulting in a theoretical maximum of 102.4k codewords. Entries are of size d/G = 128"""
    def __init__(self, num_subvectors, subvector_dim, num_codebooks):
        super(ProductQuantization, self).__init__()
        self.num_subvectors = num_subvectors
        self.subvector_dim = subvector_dim
        self.num_codebooks = num_codebooks
        
        # Initialize codebooks
        self.codebooks = nn.Parameter(torch.randn(num_subvectors, num_codebooks, subvector_dim))

    def forward(self, x):
        batch_size, dim = x.shape
        assert dim == self.num_subvectors * self.subvector_dim, \
            f"Input dimension must be equal to num_subvectors * subvector_dim but was {dim} instead of {num_subvectors} x {subvector_dim} = {num_subvectors * subvector_dim}"
        
        # Reshape input into subvectors
        x = x.view(batch_size, self.num_subvectors, self.subvector_dim)
        
        # Quantize each subvector independently
        quantized_vectors = []
        for i in range(self.num_subvectors):
            subvector = x[:, i, :]
            codebook = self.codebooks[i]
            
            # Compute distances between subvector and codebook entries
            distances = torch.cdist(subvector.unsqueeze(1), codebook.unsqueeze(0))
            
            # Find nearest codebook entry
            indices = torch.argmin(distances, dim=-1)
            
            # Get quantized vectors from codebook
            quantized_vector = codebook[indices]
            quantized_vectors.append(quantized_vector)
        
        # Concatenate quantized subvectors
        quantized_vectors = torch.stack(quantized_vectors, dim=1)
        quantized_vectors = quantized_vectors.view(batch_size, dim)
        
        return quantized_vectors

class Wav2Vec2(torch.nn.Module):
    def __init__(self, encoder_channels, context_channels, num_layers= 7, num_subvectors=19, subvector_dim=41, num_codebooks=256, mask_size=10):
        super(Wav2Vec2, self).__init__()
        self.encoder = Encoder(encoder_channels)
        self.context = Context()
        self.quantization = ProductQuantization(num_subvectors, subvector_dim, num_codebooks)
        self.mask_feature_vector = nn.Parameter(torch.randn(1, 1, mask_size))
        self.mask_size = mask_size
        
    def forward(self, x):
        z = self.encoder(x)
        q = self.quantization(z.squeeze(1)).unsqueeze(1)
        masked_z = mask(z, self.mask_feature_vector, M=self.mask_size)
        c = self.context(masked_z)
        return q, c
    

def mask(z, mask_feature_vector, p = 0.065, M=10):
    # sample p indices (timesteps) to be starting indices for mask
    T = z.shape[2]
    mask_start = torch.randperm(T)[:int(p*T)]
    mask = z
    for i in mask_start:

        if len(mask) >= i+M:
            mask[:, :, i:i+M] = mask_feature_vector[:, :, :M]
        else:
            mask[:, :, i:i+M] = mask_feature_vector
    return mask


In [20]:
model = Wav2Vec2(4, 4)
q, c = model(waveform)

TypeError: Transformer.forward() missing 1 required positional argument: 'tgt'

In [191]:
d = 256
G = 2
V = 320

d/G

128.0

In [187]:
# import cosine similarity
from torch.nn.functional import cosine_similarity

cosine_similarity(q, c, dim=-1)

tensor([[0.1258],
        [0.0693]], grad_fn=<SumBackward1>)