<a href="https://colab.research.google.com/github/01PrathamS/ECAPA-TDNN-Implementation/blob/main/notebooks/ECAPA_TDNN_architecture.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class SEBlock(nn.Module):
    """Squeeze-and-excitation gate that rescales each channel of ``x``.

    The input is averaged over its last axis, passed through a
    Linear -> ReLU -> Linear -> sigmoid bottleneck, and the resulting
    per-channel factors in (0, 1) multiply ``x``.

    Args:
        channels: size of axis -2 of the input (axis 1 for a 3-D
            (batch, channels, length) tensor).
        reduction: bottleneck reduction factor; the hidden width is
            ``channels // reduction``, floored at 1.
    """

    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        # Floor the bottleneck at one unit: with channels < reduction the
        # integer division would give 0, producing zero-width layers whose
        # gate depends only on fc2's bias and ignores the input entirely.
        hidden = max(1, channels // reduction)
        self.fc1 = nn.Linear(channels, hidden)
        self.fc2 = nn.Linear(hidden, channels)

    def forward(self, x):
        """Return ``x`` with every channel scaled by its learned gate."""
        # Squeeze: one summary value per channel, averaged over the last axis.
        scale = x.mean(dim=-1)
        # Excite: bottleneck MLP producing per-channel gates in (0, 1).
        scale = F.relu(self.fc1(scale))
        scale = torch.sigmoid(self.fc2(scale))
        # Re-append the averaged axis so the gates broadcast over it.
        return x * scale.unsqueeze(-1)


class Res2Block(nn.Module):
    """Res2Net-style multi-scale 1-D convolution block.

    A 1x1 conv projects ``in_channels`` to ``out_channels``; the result is
    cut into ``scale`` equal channel groups. The first group passes through
    unchanged; each later group goes through its own kernel-3 conv and has
    the previous group's output added to it. The groups are concatenated
    and passed through a 1x1 conv, BatchNorm and ReLU.

    Args:
        in_channels: channel count of the input (axis 1).
        out_channels: channel count of the output; must be divisible by
            ``scale``.
        scale: number of channel groups (>= 1).

    Raises:
        ValueError: if ``scale < 1`` or ``out_channels`` is not divisible
            by ``scale``.
    """

    def __init__(self, in_channels, out_channels, scale=4):
        super(Res2Block, self).__init__()
        # Validate up front: an uneven split otherwise only surfaces in
        # forward() as an opaque conv shape mismatch.
        if scale < 1:
            raise ValueError(f"scale must be >= 1, got {scale}")
        if out_channels % scale != 0:
            raise ValueError(
                f"out_channels ({out_channels}) must be divisible by "
                f"scale ({scale})"
            )
        width = out_channels // scale
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        self.res2_layers = nn.ModuleList([
            nn.Conv1d(width, width, kernel_size=3, padding=1)
            for _ in range(scale - 1)
        ])
        self.conv3 = nn.Conv1d(out_channels, out_channels, kernel_size=1)
        self.bn = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, in_channels, length).

        Returns a tensor of shape (batch, out_channels, length); the
        kernel-3 convs use padding=1, so the length is preserved.
        """
        x = self.conv1(x)
        # Split by an explicit group width (instead of torch.chunk's
        # ceil-based sizing) so every slice matches the width the per-group
        # convs were built for.
        width = x.size(1) // (len(self.res2_layers) + 1)
        groups = torch.split(x, width, dim=1)
        outputs = [groups[0]]
        # NOTE(review): published Res2Net feeds the previous group into the
        # next conv (K_i(x_i + y_{i-1})); here it is added after the conv.
        # Kept as-is to preserve behaviour -- confirm this is intended.
        for conv, group in zip(self.res2_layers, groups[1:]):
            outputs.append(conv(group) + outputs[-1])
        x = torch.cat(outputs, dim=1)
        x = self.conv3(x)
        return self.relu(self.bn(x))

class AttentiveStatisticsPooling(nn.Module):
    """Attention-weighted mean and standard deviation over the last axis.

    A Linear layer followed by tanh scores every frame per channel; a
    softmax over the frame axis turns the scores into weights, which are
    used to compute a weighted mean and weighted standard deviation for
    each channel.

    Args:
        input_dim: size of axis 1 of the input (channels).
        eps: lower bound applied to the weighted variance before the square
            root. Rounding can make ``E[x^2] - E[x]^2`` slightly negative
            (NaN from sqrt), and sqrt has an infinite gradient at 0, so the
            clamp keeps both the output and its gradient finite.
    """

    def __init__(self, input_dim, eps=1e-5):
        super(AttentiveStatisticsPooling, self).__init__()
        self.attention = nn.Linear(input_dim, input_dim)
        self.eps = eps

    def forward(self, x):
        """Pool ``x`` of shape (batch, input_dim, length).

        Returns a tensor of shape (batch, 2 * input_dim): the weighted means
        followed by the weighted standard deviations.
        """
        # (batch, C, T) -> (batch, T, C) so the Linear mixes channels per frame.
        attn_weights = torch.tanh(self.attention(x.permute(0, 2, 1)))
        # Normalise over frames (dim=1 of (batch, T, C)), then back to (batch, C, T).
        attn_weights = torch.softmax(attn_weights, dim=1).permute(0, 2, 1)
        mean = torch.sum(x * attn_weights, dim=-1)
        var = torch.sum(x ** 2 * attn_weights, dim=-1) - mean ** 2
        std = torch.sqrt(var.clamp(min=self.eps))
        return torch.cat([mean, std], dim=1)

In [4]:
class ECAPA_TDNN(nn.Module):
    """Simplified ECAPA-TDNN speaker-embedding network with a classifier head.

    Pipeline: three stacked Res2Blocks -> concatenation of all three block
    outputs (multi-layer feature aggregation) -> 1x1 conv -> SE gating ->
    attentive statistics pooling -> Linear + BatchNorm + ReLU embedding ->
    Linear logits.

    Args:
        input_dim: number of input features per frame.
        emb_dim: size of the embedding vector.
        num_speakers: number of output classes (logits).
        channels: width of each Res2Block; must be divisible by the
            Res2Block ``scale`` (4 by default). Defaults to 512.
        mfa_channels: output width of the aggregation conv, which is also
            the width seen by the SE block and the pooling layer.
            Defaults to 1536.
    """

    def __init__(self, input_dim=30, emb_dim=512, num_speakers=1000,
                 channels=512, mfa_channels=1536):
        super(ECAPA_TDNN, self).__init__()

        self.layer1 = Res2Block(input_dim, channels)
        self.layer2 = Res2Block(channels, channels)
        self.layer3 = Res2Block(channels, channels)

        # Multi-layer feature aggregation: the three block outputs are
        # concatenated along the channel axis and mixed by a 1x1 conv.
        self.mfa = nn.Conv1d(3 * channels, mfa_channels, kernel_size=1)

        self.se_block = SEBlock(mfa_channels)

        # The pooling layer emits mean and std concatenated -> 2 * mfa_channels.
        self.pooling = AttentiveStatisticsPooling(mfa_channels)

        self.fc1 = nn.Linear(mfa_channels * 2, emb_dim)
        self.bn = nn.BatchNorm1d(emb_dim)

        self.fc2 = nn.Linear(emb_dim, num_speakers)

    def forward(self, x):
        """Compute the speaker embedding and class logits.

        Args:
            x: tensor of shape (batch, time, input_dim).

        Returns:
            ``(embedding, logits)``: the embedding of shape (batch, emb_dim),
            taken after the ReLU, and logits of shape (batch, num_speakers).
        """
        # (batch, time, features) -> (batch, features, time) for Conv1d.
        x = x.permute(0, 2, 1)

        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)

        x = torch.cat([x1, x2, x3], dim=1)
        x = self.mfa(x)
        x = self.se_block(x)
        x = self.pooling(x)

        embedding = F.relu(self.bn(self.fc1(x)))
        logits = self.fc2(embedding)
        return embedding, logits


# Build the network with its default sizes spelled out, then show the layer summary.
model = ECAPA_TDNN(
    num_speakers=1000,
    emb_dim=512,
    input_dim=30,
)
print(model)

ECAPA_TDNN(
  (layer1): Res2Block(
    (conv1): Conv1d(30, 512, kernel_size=(1,), stride=(1,))
    (res2_layers): ModuleList(
      (0-2): 3 x Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))
    )
    (conv3): Conv1d(512, 512, kernel_size=(1,), stride=(1,))
    (bn): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (layer2): Res2Block(
    (conv1): Conv1d(512, 512, kernel_size=(1,), stride=(1,))
    (res2_layers): ModuleList(
      (0-2): 3 x Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))
    )
    (conv3): Conv1d(512, 512, kernel_size=(1,), stride=(1,))
    (bn): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (layer3): Res2Block(
    (conv1): Conv1d(512, 512, kernel_size=(1,), stride=(1,))
    (res2_layers): ModuleList(
      (0-2): 3 x Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))
    )
    (conv3): Conv1d(512, 512, kernel_si

In [24]:
## Training Example
# One optimisation step on random data to smoke-test forward/backward.
# The feature size and label range are read from `model` itself so they
# cannot drift from the configuration the model was built with (a mismatch
# would give a shape error or out-of-range CrossEntropyLoss targets).

batch_size, time_steps = 16, 100
features = model.layer1.conv1.in_channels  # per-frame input size expected by the model
num_speakers = model.fc2.out_features      # number of classes in the classifier head

# Input layout is (batch, time, features); the model permutes it internally.
dummy_data = torch.randn(batch_size, time_steps, features)
dummy_labels = torch.randint(0, num_speakers, (batch_size,))

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

model.train()
optimizer.zero_grad()
embedding, logits = model(dummy_data)
loss = criterion(logits, dummy_labels)
loss.backward()
optimizer.step()


print(f"Training Loss: {loss.item():.4f}")


Training Loss: 6.8680
