<a href="https://colab.research.google.com/github/01PrathamS/SonicSpeech/blob/main/notebooks/HuBERT_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchaudio

In [32]:
import torch
import torch.nn as nn
import torchaudio

class ConvFeatureEncoder(nn.Module):
    """Strided 1-D convolution stack mapping a raw waveform to 512-d frames.

    Input is expected as ``(batch, 1, num_samples)``. The output is
    ``(batch, num_frames, 512)``. There are five unpadded Conv1d + ReLU layers,
    giving a total stride of 5 * 3 * 3 * 2 * 2 = 180 samples per output frame.
    """

    def __init__(self):
        super().__init__()
        # (kernel_size, stride) per layer; every layer emits 512 channels.
        # Layers are appended in order (conv, relu, conv, relu, ...). This keeps
        # the parameter-initialisation order and state_dict keys
        # (conv_layers.0, .2, .4, .6, .8) identical to an explicit listing.
        kernel_strides = ((10, 5), (5, 3), (5, 3), (4, 2), (4, 2))
        layers = []
        in_channels = 1
        for kernel, stride in kernel_strides:
            layers += [
                nn.Conv1d(in_channels, 512, kernel_size=kernel, stride=stride, padding=0),
                nn.ReLU(),
            ]
            in_channels = 512
        self.conv_layers = nn.Sequential(*layers)

    def forward(self, x):
        """Return convolutional features in ``(batch, time, channels)`` layout."""
        features = self.conv_layers(x)
        # Conv1d emits (batch, channels, time). Move time to dim 1 so the
        # batch_first Transformer can consume it.
        return features.permute(0, 2, 1)

In [33]:
class PositionalEncoding(nn.Module):
    """Learned absolute positional embedding added to a sequence of frames.

    Args:
        d_model: Feature size of each frame. It must match the last dimension
            of the input.
        dropout: Dropout probability applied after the embedding is added.
            It only has an effect in training mode.
        max_len: Longest sequence, in frames, that can be encoded.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        # One trainable vector per position, initialised from N(0, 1).
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, d_model))
        # nn.Dropout holds no parameters, so honouring `dropout` leaves the
        # state_dict keys unchanged and previously saved weights still load.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Add positional embeddings to ``x`` and apply dropout.

        ``x`` must broadcast against ``(1, seq_len, d_model)``. In practice it
        is ``(batch, seq_len, d_model)``.

        Raises:
            ValueError: if ``seq_len`` (``x.size(1)``) exceeds ``max_len``.
        """
        seq_len = x.size(1)
        max_len = self.pos_embedding.size(1)
        if seq_len > max_len:
            # Without this check the slice below is silently truncated and the
            # addition fails with an unhelpful broadcasting error.
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}"
            )
        return self.dropout(x + self.pos_embedding[:, :seq_len, :])

In [34]:
class HuBERTTransformer(nn.Module):
    """Positional encoding followed by a stack of Transformer encoder layers.

    Args:
        d_model: Model width (feature size of each frame).
        n_heads: Number of attention heads per layer.
        num_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden size of each layer's feed-forward block.
        dropout: Dropout rate used by the encoder layers and the positional
            encoding.
    """

    def __init__(self, d_model=512, n_heads=8, num_layers=6, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        # batch_first=True: inputs are (batch, seq_len, d_model).
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # Forward the same dropout rate so one argument controls every dropout
        # in this module. Creation order is unchanged, so parameter init draws
        # from the RNG in the same sequence.
        self.pos_encoding = PositionalEncoding(d_model, dropout=dropout)

    def forward(self, x, src_key_padding_mask=None):
        """Encode ``x`` of shape ``(batch, seq_len, d_model)``.

        Args:
            x: Frame features.
            src_key_padding_mask: Optional ``(batch, seq_len)`` boolean mask in
                which True marks padded frames to ignore. ``None`` (the
                default) attends to every frame, as before.
        """
        x = self.pos_encoding(x)
        return self.transformer(x, src_key_padding_mask=src_key_padding_mask)

In [35]:
class HuBERT(nn.Module):
    """Waveform-to-logits model: conv features -> projection -> Transformer -> classifier.

    Args:
        vocab_size: Number of target classes predicted for every frame.
        d_model: Transformer width. The conv encoder's 512 channels are
            linearly projected to this size.
        num_layers: Number of Transformer encoder layers.
    """

    def __init__(self, vocab_size=100, d_model=512, num_layers=6):
        super().__init__()
        # The creation order below fixes the order of RNG draws during
        # parameter initialisation, as well as the state_dict key order.
        self.feature_encoder = ConvFeatureEncoder()
        self.projector = nn.Linear(512, d_model)  # 512 = conv encoder output channels
        self.transformer_encoder = HuBERTTransformer(d_model=d_model, num_layers=num_layers)
        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Map raw audio ``(batch, 1, num_samples)`` to logits ``(batch, num_frames, vocab_size)``."""
        pipeline = (
            self.feature_encoder,
            self.projector,
            self.transformer_encoder,
            self.output_layer,
        )
        for stage in pipeline:
            x = stage(x)
        return x

In [36]:
# Smoke test on random input: 16000 samples shaped (batch, channels, samples).
# That is one second if the audio is 16 kHz (assumption).
# `model`, `audio_input` and `output` are reused by the next cell.
model = HuBERT(vocab_size = 100)
audio_input = torch.randn(1, 1, 16000)
# Logits come out as (batch, frames, vocab_size). The conv stack turns 16000
# samples into 87 frames, as the printed shape below shows.
output = model(audio_input)
print(output.shape)

torch.Size([1, 87, 100])


In [37]:
import torch
import torch.nn as nn

# Check that the frame-level logits plug into a standard cross-entropy loss,
# using random dummy labels.
# Shapes are read from `output` rather than hard-coded (previously 87 frames
# and 100 classes), so the cell keeps working if the input length or
# vocab_size changes.
# The indexing works whether `output` is still (batch, frames, vocab) or was
# already flattened by an earlier run of this cell.
vocab_size = output.size(-1)
loss_fn = nn.CrossEntropyLoss()
target = torch.randint(0, vocab_size, output.shape[:-1])  # dummy: one label per frame
# CrossEntropyLoss takes (N, C) logits with (N,) class indices.
# reshape also works for non-contiguous tensors, unlike view.
output = output.reshape(-1, vocab_size)
target = target.reshape(-1)
loss = loss_fn(output, target)
print("Loss:", loss.item())


Loss: 4.774441719055176
