<a href="https://colab.research.google.com/github/01PrathamS/ECAPA-TDNN-Implementation/blob/main/notebooks/x_vector_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim

class TDNNLayer(nn.Module):
  """One frame-level TDNN layer: Conv1d -> BatchNorm1d -> ReLU.

  The convolution uses no padding, so with kernel size ``k`` and dilation
  ``d`` the time axis shrinks by ``d * (k - 1)`` frames.

  Args:
    input_dim: channels of the incoming (batch, channels, time) tensor.
    output_dim: channels produced by the layer.
    context_size: convolution kernel size (frames seen per output frame).
    dilation: spacing between the frames inside the context window.
  """

  def __init__(self, input_dim, output_dim, context_size=1, dilation=1):
    super().__init__()
    # Sub-modules are registered in the order tdnn, relu, bn so that the
    # printed module tree and the state_dict keys stay the same.
    self.tdnn = nn.Conv1d(
        in_channels=input_dim,
        out_channels=output_dim,
        kernel_size=context_size,
        dilation=dilation,
    )
    self.relu = nn.ReLU()
    self.bn = nn.BatchNorm1d(num_features=output_dim)

  def forward(self, x):
    """Convolve ``x`` (batch, input_dim, time), batch-normalise, then apply ReLU."""
    return self.relu(self.bn(self.tdnn(x)))

class XVector(nn.Module):
  """Simplified x-vector speaker-embedding network.

  Architecture:
  - five frame-level TDNN layers (kernel sizes 5, 3, 3, 1, 1);
  - mean pooling over time;
  - one segment-level Linear -> BatchNorm1d -> ReLU layer that produces the
    embedding;
  - a Linear speaker classifier.

  NOTE(review): the published x-vector recipe (Snyder et al., 2018) uses
  dilated contexts in the 2nd and 3rd TDNN layers and mean+std statistics
  pooling. This model uses undilated kernels and mean-only pooling.

  Args:
    input_dim: acoustic features per frame (last axis of the input).
    num_speakers: number of classes produced by the classifier head.
    embedding_dim: size of the embedding returned by ``forward``.
    hidden_dim: channel width of every frame-level TDNN layer. Defaults to
      512, the previously hard-coded value, so existing callers and saved
      state_dicts are unaffected.
  """

  def __init__(self, input_dim=30, num_speakers=1000, embedding_dim=512,
               hidden_dim=512):
    super().__init__()

    # Frame-level layers. No padding is used, so the time axis shrinks by
    # 4 + 2 + 2 = 8 frames in total: inputs need at least 9 frames.
    self.tdnn1 = TDNNLayer(input_dim, hidden_dim, context_size=5)
    self.tdnn2 = TDNNLayer(hidden_dim, hidden_dim, context_size=3)
    self.tdnn3 = TDNNLayer(hidden_dim, hidden_dim, context_size=3)
    self.tdnn4 = TDNNLayer(hidden_dim, hidden_dim, context_size=1)
    self.tdnn5 = TDNNLayer(hidden_dim, hidden_dim, context_size=1)

    self.pooling = nn.AdaptiveAvgPool1d(1)  ## Mean Pooling across time

    # Segment-level layer whose (post-BN, post-ReLU) output is the embedding.
    self.fc1 = nn.Linear(hidden_dim, embedding_dim)
    self.relu = nn.ReLU()
    self.bn = nn.BatchNorm1d(embedding_dim)

    self.fc2 = nn.Linear(embedding_dim, num_speakers)

  def forward(self, x):
    """Compute the utterance embedding and speaker logits.

    Args:
      x: tensor of shape (batch, time, input_dim) with time >= 9. In
        training mode BatchNorm1d additionally requires batch > 1.

    Returns:
      A tuple (embedding, logits):
      - embedding: shape (batch, embedding_dim), taken after BatchNorm and
        ReLU;
      - logits: shape (batch, num_speakers).
    """
    # Conv1d expects channels before time: (batch, input_dim, time).
    x = x.permute(0, 2, 1)

    x = self.tdnn1(x)
    x = self.tdnn2(x)
    x = self.tdnn3(x)
    x = self.tdnn4(x)
    x = self.tdnn5(x)

    # (batch, hidden_dim, time') -> (batch, hidden_dim, 1) -> (batch, hidden_dim)
    x = self.pooling(x)
    x = x.squeeze(-1)

    x = self.fc1(x)
    x = self.bn(x)
    x = self.relu(x)

    logits = self.fc2(x)
    return x, logits


# Instantiate the network (30 features per frame, 1000 speaker classes)
# and print its module hierarchy.
model = XVector(num_speakers=1000, input_dim=30)
print(model)

XVector(
  (tdnn1): TDNNLayer(
    (tdnn): Conv1d(30, 512, kernel_size=(5,), stride=(1,))
    (relu): ReLU()
    (bn): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (tdnn2): TDNNLayer(
    (tdnn): Conv1d(512, 512, kernel_size=(3,), stride=(1,))
    (relu): ReLU()
    (bn): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (tdnn3): TDNNLayer(
    (tdnn): Conv1d(512, 512, kernel_size=(3,), stride=(1,))
    (relu): ReLU()
    (bn): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (tdnn4): TDNNLayer(
    (tdnn): Conv1d(512, 512, kernel_size=(1,), stride=(1,))
    (relu): ReLU()
    (bn): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (tdnn5): TDNNLayer(
    (tdnn): Conv1d(512, 512, kernel_size=(1,), stride=(1,))
    (relu): ReLU()
    (bn): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
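  (pooling): AdaptiveAvgPool1d(output_size=1)
  (fc1): Linear(in_features=512, out_features=512, bias=True)
  (relu): ReLU()
  (bn): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=512, out_features=1000, bias=True)
)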

In [11]:
## Dummy Dataset

# One optimisation step on random data, as a smoke test of the model.
# The feature and speaker counts are read back from the model, so the dummy
# batch always matches how it was constructed (30 and 1000 here).
batch_size, time_steps = 16, 100
features = model.tdnn1.tdnn.in_channels
num_speakers = model.fc2.out_features

dummy_data = torch.randn(batch_size, time_steps, features)     # (batch, time, features)
dummy_labels = torch.randint(0, num_speakers, (batch_size, ))  # one class index per example


criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

model.train()
optimizer.zero_grad()
embedding, logits = model(dummy_data)
loss = criterion(logits, dummy_labels)
loss.backward()
optimizer.step()

# No space in the format spec: ' .4f' would pad positive values with a
# leading space (the "sign" option), producing a double space in the output.
print(f"Training Loss: {loss.item():.4f}")

Training Loss:  7.0490
