#Dependencies

In [1]:
import torch
import torch.nn as nn
import torchaudio

from transformers import AutoFeatureExtractor, AutoModel, RobertaTokenizer, RobertaModel
from transformers.modeling_outputs import BaseModelOutput

#Pretrained Models

In [2]:
audio_model_type = 'facebook/hubert-xlarge-ll60k'
text_model_type = 'FacebookAI/roberta-base'

In [3]:
audio_feature_extractor = AutoFeatureExtractor.from_pretrained(audio_model_type)
audio_model = AutoModel.from_pretrained(audio_model_type)

text_tokenizer = RobertaTokenizer.from_pretrained(text_model_type)
text_model = RobertaModel.from_pretrained(text_model_type)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/212 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/3.85G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.85G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


#Contrastive Loss

In [None]:
def contrastive_loss(embeddings, labels, temp=1.0):
  B, D = embedding.shape
  embeddings = F.normalize(embeddings, dim=-1)
  similarity_matrix = torch.matmul(embeddings, embeddings.T) / temp

  mask = torch.eye(B, dtype=torch.bool)
  pos_mask = labels.unsqueeze(0) == labels.unsqueeze(1)

  similarity_matrix[mask] = float('-inf')

  pos_sim = similarity_matrix.clone()
  pos_sim[~pos_mask] = float('-inf')

  pos_sim_exp = torch.exp(pos_sim)
  sim_exp = torch.exp(similarity_matrix).sum(dim=0)

  nll = -torch.log(pos_sim_exp / sim_exp)
  nll[mask + ~pos_mask] = 0
  nll = nll.sum(dim=0)

  loss = (nll / (pos_mask.sum(dim=0) + 1e-6)).sum()

  return loss

#Audio Encoder Block

In [None]:
class AudioEncoderBlock(nn.Module):
  def __init__(self, feature_extractor, hidden_dim, output_dim, mean_pooling, dropout=0):
    super(AudioEncoderBlock, self).__init__()

    self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor)
    self.model = AutoModel.from_pretrained(feature_extractor)

    self.mlp = nn.Sequential(
        nn.Linear(hidden_dim, output_dim),
        nn.SiLU(),
        nn.Dropout(dropout)
    )

    self.mean_pooling = mean_pooling

  def forward(self, x):
    inputs = self.feature_extractor(x, sampling_rate=16000, return_tensors='pt')
    x = self.model(**inputs)

    if self.mean_pooling:
      x = x.last_hidden_state.mean(dim=1)
    else:
      x = x.last_hidden_state[:, 0, :]

    return self.mlp(x)

#Residual Vector Quantinization

In [None]:
class ResidualVectorQuantization(nn.Module):
  def __init__(self, num_embed, embed_dim, num_stages, commitment_cost=0.25):
    super(ResidualVectorQuantization, self).__init__()

    self.num_embed = num_embed
    self.embed_dim = embed_dim
    self.num_stages = num_stages
    self.commitment_cost = commitment_cost

    self.codebooks = torch.nn.ModuleList([
        torch.nn.Embedding(num_embed, embed_dim) for _ in range(num_stages)
    ])

    for codebook in self.codebooks:
        torch.nn.init.uniform_(codebook.weight.data, -1.0 / num_embed, 1.0 / num_embed)

  def forward(self, x):
    B, C, H, W = x.shape

    if C != self.embed_dim:
      raise ValueError(
          f"Input channels ({C}) must match the embedding dimension ({self.embed_dim})."
      )

    current_residual = x
    quantized_output_sum = torch.zeros_like(x)
    total_vq_loss = 0.0

    for i in range(self.num_stages):
        codebook = self.codebooks[i].weight
        reshaped_residual = current_residual.permute(0, 2, 3, 1).reshape(-1, C)

        distances = torch.cdist(reshaped_residual, codebook, p=2)
        nearest_indices = torch.argmin(distances, dim=1)
        quantized_vector_raw = codebook[nearest_indices] # The chosen embedding

        # Calculate VQ loss for this stage
        # commitment_loss: ||sg(encoder_output) - e_q||^2
        # codebook_loss: ||encoder_output - sg(e_q)||^2

        # The `inputs` in the standard VQ-VAE loss is `current_residual` for this stage.
        # `quantized` is `quantized_vector_raw` reshaped.
        quantized_vector_reshaped = quantized_vector_raw.reshape(B, H, W, C).permute(0, 3, 1, 2)

        # e_latent_loss encourages encoder output to be close to chosen embedding
        # q_latent_loss encourages chosen embedding to be close to encoder output
        e_latent_loss_stage = F.mse_loss(quantized_vector_reshaped.detach(), current_residual)
        q_latent_loss_stage = F.mse_loss(quantized_vector_reshaped, current_residual.detach())
        vq_loss_stage = q_latent_loss_stage + self.commitment_cost * e_latent_loss_stage
        total_vq_loss += vq_loss_stage

        # Straight-through estimator: copy gradients from `quantized_vector_raw` to `reshaped_residual`
        # This is the actual quantized output passed through the network
        quantized_vector_st = reshaped_residual + (quantized_vector_raw - reshaped_residual).detach()
        quantized_vector_st = quantized_vector_st.reshape(B, H, W, C).permute(0, 3, 1, 2)


        quantized_output_sum = quantized_output_sum + quantized_vector_st
        current_residual = current_residual - quantized_vector_st # Update residual for next stage

    return quantized_output_sum, total_vq_loss

#Audio Encoder

In [15]:
class AudioEncoder(nn.Module):
    def __init__(self,
                 encoder_hps,
                 vq_hps,
                 dropout=0):
      super(AudioEncoder, self).__init__()

      self.encoder_block = AudioEncoderBlock(encoder_hps.feature_extractor,
                                             encoder_hps.model,
                                             encoder_hps.hidden_dim,
                                             encoder_hps.output_dim,
                                             encoder_hps.mean_pooling,
                                             dropout)

      self.speech_partition = nn.Linear(encoder_hps.output_dim, encoder_hps.output_dim)
      self.emotion_partition = nn.Linear(encoder_hps.output_dim, encoder_hps.output_dim)

      self.speech_rvq = ResidualVectorQuantization(vq_hps.num_embed,
                                                   vq_hps.embed_dim,
                                                   vq_hps.num_stages,
                                                   vq_hps.commitment_cost)
      self.emotion_rvq = ResidualVectorQuantization(vq_hps.num_embed,
                                                    vq_hps.embed_dim,
                                                    vq_hps.num_stages,
                                                     vq_hps.commitment_cost)
    def forward(self, x):
      x = self.encoder_block(x)
      speech_x = self.speech_partition(x)
      emotion_x = self.emotion_partition(x)

      speech_x, speech_loss = self.speech_rvq(speech_x)
      emotion_x, emotion_loss = self.emotion_rvq(emotion_x)

      return speech_x, emotion_x, speech_loss + emotion_loss

In [6]:
waveform = torch.randn(10, 180000)

In [13]:
inputs = audio_feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")

In [14]:
inputs

{'input_values': tensor([[[-1.6219, -0.5797, -1.5807,  ..., -0.7034,  1.1353,  0.9758],
         [-0.7096,  0.3576,  1.6078,  ..., -0.2408,  0.5288,  0.0708],
         [-0.5138, -1.0793, -0.5743,  ...,  1.2155,  0.7204,  1.2709],
         ...,
         [ 0.4996, -0.2346,  1.2999,  ..., -1.3518, -0.1171,  2.2759],
         [-0.8121, -1.0908,  0.1171,  ..., -1.3943,  0.0096, -0.6555],
         [-0.5125, -0.8679,  2.0057,  ..., -0.9939,  1.2328, -0.8414]]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)}

In [9]:
x = audio_model(**inputs)

In [10]:
x.last_hidden_state.shape

torch.Size([10, 812, 1280])

In [10]:
x.last_hidden_state.mean(dim=1).shape

torch.Size([10, 1280])

In [11]:
tokens = text_tokenizer("Hello world", return_tensors="pt")

In [13]:
embedding = text_model(**tokens)

In [18]:
embedding.last_hidden_state[:, 0, :].shape

torch.Size([1, 768])