<a href="https://colab.research.google.com/github/DURUII/HIMIA-course/blob/main/DURUII/%E3%80%90C0%E3%80%91Model%20Architecture.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
import torch
import numpy as np
import math


# For reproducibility: seed every RNG the notebook relies on.
def same_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: Integer seed applied to Python's ``random`` module, NumPy's
            global RNG and PyTorch's CPU (and, if present, CUDA) generators.
    """
    import random  # local import keeps the notebook's import cell untouched

    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


same_seed(2024)

# 前端

## X-VECTOR

https://readpaper.com/paper/2890964092

In [2]:
class TDNN(nn.Module):
    """Time-delay layer: (dilated) 1-D convolution -> BatchNorm1d -> ReLU.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output channels.
        kernel_size: Temporal context width of the convolution.
        stride: Forwarded to ``nn.Conv1d``.
        padding: Forwarded to ``nn.Conv1d`` (0 means the time axis shrinks).
        dilation: Spacing between the frames covered by the kernel.
        bias: Whether the convolution has a bias term (off by default because
            BatchNorm follows immediately).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=False):
        super().__init__()
        self.conv = nn.Conv1d(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              dilation=dilation,
                              bias=bias)
        self.bn = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply conv, batch-norm and ReLU to ``x`` of shape ``[batch, in_channels, T]``."""
        hidden = self.conv(x)
        hidden = self.bn(hidden)
        return self.relu(hidden)

![](https://github.com/DURUII/HIMIA-course/blob/main/DURUII/res/gap.svg?raw=1)

In [3]:
class StatsPooling(nn.Module):
    """Statistics pooling: per-channel mean and standard deviation.

    Every dimension after the channel axis is flattened into a single pooling
    axis, so ``[batch, C, T]`` (Conv1d output) and ``[batch, C, F, T]``
    (Conv2d output) are both accepted.  Returns ``[batch, 2 * C]`` laid out
    as ``[mean | std]``.
    """

    def __init__(self):
        super(StatsPooling, self).__init__()

    def forward(self, x):
        #  eg. conv1d - [batch, 1500, T]
        #  eg. conv2d - [batch, 32*8, F, T]
        # reshape (not view) so non-contiguous inputs, e.g. after permute/transpose, also work.
        x = x.reshape(x.shape[0], x.shape[1], -1)
        mean = x.mean(dim=2)
        # Unbiased std (torch default); needs at least 2 pooled positions, otherwise NaN.
        std = x.std(dim=2)
        return torch.cat([mean, std], dim=1)

![](https://pdf.cdn.readpaper.com/parsed/fetch_target/56f6fb62f8e0c0fa1dfdcfa47af86f1c_1_Figure_1.png)

![](https://pdf.cdn.readpaper.com/parsed/fetch_target/a1d526b4bfc73adeb1746ccfc3654803_1_Table_1.png)

In [4]:
class X_VECTOR(nn.Module):
    """x-vector speaker embedding network with mean+std statistics pooling.

    Five TDNN frame layers without padding (so the time axis shrinks with each
    layer's context), statistics pooling over time, then two affine segment
    layers.  ``forward`` prints every intermediate shape so the architecture
    can be inspected from the notebook output.

    Args:
        input_dim: Number of acoustic feature channels per frame.
        hidden_dim: Width of frame layers 1-4.
        embedding_size: Width of the two segment-level linear layers.
    """

    def __init__(self, input_dim=24, hidden_dim=512, embedding_size=512):
        super().__init__()
        pooled_channels = 1500
        # Frame-level layers; kernel_size/dilation define each layer's temporal context.
        self.frame1 = TDNN(input_dim, hidden_dim, kernel_size=5, dilation=1)
        self.frame2 = TDNN(hidden_dim, hidden_dim, kernel_size=3, dilation=2)
        self.frame3 = TDNN(hidden_dim, hidden_dim, kernel_size=3, dilation=3)
        self.frame4 = TDNN(hidden_dim, hidden_dim, kernel_size=1, dilation=1)
        self.frame5 = TDNN(hidden_dim, pooled_channels, kernel_size=1, dilation=1)
        self.pool = StatsPooling()
        # Mean and std are concatenated, hence twice the frame5 width.
        self.fc1 = nn.Linear(2 * pooled_channels, embedding_size)
        self.fc2 = nn.Linear(embedding_size, embedding_size)

    def forward(self, x, bottleneck_last=False):
        """Return the fc1 embedding, or the fc2 embedding when ``bottleneck_last`` is True."""
        print('Input Feature:', x.shape)
        frames = (self.frame1, self.frame2, self.frame3, self.frame4, self.frame5)
        for idx, frame in enumerate(frames, start=1):
            x = frame(x)
            print(f'Frame {idx} Output:', x.shape)
        x = self.pool(x)
        print('Pooling Output:', x.shape)
        embd_a = self.fc1(x)
        print('FC1 Output:', embd_a.shape)
        embd_b = self.fc2(embd_a)
        print('FC2 Output:', embd_b.shape)
        return embd_b if bottleneck_last else embd_a

In [5]:
# Smoke test: a batch of 4 inputs with 24 feature channels and 100 frames;
# the printed shapes show how the unpadded TDNN layers shorten the time axis.
feats = torch.randn(4, 24, 100)
model = X_VECTOR()
model(feats)

Input Feature: torch.Size([4, 24, 100])
Frame 1 Output: torch.Size([4, 512, 96])
Frame 2 Output: torch.Size([4, 512, 92])
Frame 3 Output: torch.Size([4, 512, 86])
Frame 4 Output: torch.Size([4, 512, 86])
Frame 5 Output: torch.Size([4, 1500, 86])
Pooling Output: torch.Size([4, 3000])
FC1 Output: torch.Size([4, 512])
FC2 Output: torch.Size([4, 512])


tensor([[ 0.1943, -0.1423,  0.0615,  ...,  0.3088,  0.0225,  0.0509],
        [ 0.2338, -0.1090,  0.1566,  ...,  0.2737,  0.0060,  0.2187],
        [ 0.2274, -0.1500,  0.0751,  ...,  0.2390,  0.0536,  0.1253],
        [ 0.1188, -0.1389,  0.1014,  ...,  0.2195,  0.0270,  0.1550]],
       grad_fn=<AddmmBackward0>)

## ResNet34

https://readpaper.com/paper/2949650786

![](https://github.com/DURUII/HIMIA-course/blob/main/DURUII/res/resblock.svg?raw=1)

In [6]:
class BuildingBlock(nn.Module):
    """ResNet basic block: two 3x3 convs with normalisation plus a shortcut.

    ``ConvLayer`` / ``NormLayer`` are injected so the same block can be built
    with 1-D, 2-D or 3-D layers.  When the stride or channel count changes,
    the shortcut becomes a strided 1x1 conv + norm so the shapes match for
    the residual sum; otherwise it is the identity.
    """
    expansion = 1

    def __init__(self, ConvLayer, NormLayer, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = ConvLayer(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = NormLayer(planes)
        self.conv2 = ConvLayer(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = NormLayer(planes)
        self.relu = nn.ReLU(inplace=True)

        # e.g. input [4, 32, 40, 200] -> output [4, 64, 20, 100] needs a projection.
        if stride == 1 and in_planes == out_planes:
            self.downsample = nn.Sequential()  # identity shortcut
        else:
            self.downsample = nn.Sequential(
                ConvLayer(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                NormLayer(out_planes),
            )

    def forward(self, x):
        shortcut = self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += shortcut
        return self.relu(out)

In [7]:
class ResNet(nn.Module):
    """Generic four-stage ResNet trunk (no pooling or classifier head).

    Args:
        in_planes: Channel width of the stem and of ``layer1``; stages 2-4 use
            2x, 4x and 8x this width (times ``block.expansion``).
        block: Residual block class exposing ``expansion`` and constructed as
            ``block(ConvLayer, NormLayer, in_planes, planes, stride)``.
        num_blocks: Four ints, the number of blocks in each stage.
        num_classes: Unused; kept for interface compatibility.
        in_ch: Number of input channels of the stem convolution.
        feat_dim: ``'1d'``, ``'2d'`` or ``'3d'``; selects the Conv/BatchNorm
            dimensionality used throughout the network.
        **kwargs: Ignored.

    Raises:
        ValueError: If ``feat_dim`` is not one of the supported values.
    """

    def __init__(self,
                 in_planes,
                 block,
                 num_blocks,
                 num_classes=10,
                 in_ch=1,
                 feat_dim='2d',
                 **kwargs):
        super(ResNet, self).__init__()
        layer_types = {
            '1d': (nn.Conv1d, nn.BatchNorm1d),
            '2d': (nn.Conv2d, nn.BatchNorm2d),
            '3d': (nn.Conv3d, nn.BatchNorm3d),
        }
        if feat_dim not in layer_types:
            # Fail fast with a clear message instead of an AttributeError on
            # self.ConvLayer a few lines later.
            raise ValueError(
                f"Unsupported feat_dim {feat_dim!r}; expected '1d', '2d' or '3d'")
        self.ConvLayer, self.NormLayer = layer_types[feat_dim]

        self.in_planes = in_planes
        self.conv1 = self.ConvLayer(in_ch,
                                    in_planes,
                                    kernel_size=3,
                                    stride=1,
                                    padding=1,
                                    bias=False)
        self.bn1 = self.NormLayer(in_planes)
        self.relu = nn.ReLU(inplace=True)

        # Stage 1 keeps the resolution; stages 2-4 halve it and double the width.
        self.layer1 = self._make_layer(block,
                                       planes=in_planes,
                                       num_blocks=num_blocks[0],
                                       stride=1)
        self.layer2 = self._make_layer(block,
                                       planes=in_planes * 2,
                                       num_blocks=num_blocks[1],
                                       stride=2)
        self.layer3 = self._make_layer(block,
                                       planes=in_planes * 4,
                                       num_blocks=num_blocks[2],
                                       stride=2)
        self.layer4 = self._make_layer(block,
                                       planes=in_planes * 8,
                                       num_blocks=num_blocks[3],
                                       stride=2)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``.

        Updates ``self.in_planes`` so the next stage starts from this stage's
        output width.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []

        for block_stride in strides:
            layers.append(
                block(self.ConvLayer,
                      self.NormLayer,
                      self.in_planes,
                      planes,
                      block_stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

![](https://pdf.cdn.readpaper.com/parsed/fetch_target/27f2cdce86e93cfe09084cc4d1009835_1_Table_2_431817400.png)

In [8]:
class StatsPoolingFlatten(nn.Module):
    """Statistics pooling that merges the channel and frequency axes.

    Input ``[batch, C, F, T]`` is viewed as ``[batch, C * F, T]`` so every
    (channel, frequency-bin) pair gets its own mean and standard deviation
    over the remaining axis.  Returns ``[batch, 2 * C * F]`` as ``[mean | std]``.
    """

    def __init__(self):
        super(StatsPoolingFlatten, self).__init__()

    def forward(self, x):
        #  eg. input: [batch, channels=256, H/F=5, W/T=T/8]
        # reshape (not view) so non-contiguous inputs, e.g. after transpose, also work.
        x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], -1)
        mean = x.mean(dim=2)
        std = x.std(dim=2)
        return torch.cat([mean, std], dim=1)


class ResNet34StatsPooling(nn.Module):
    """ResNet-34 trunk + flattened statistics pooling + linear embedding layer.

    Args:
        in_planes: Base channel width of the ResNet (the last stage has 8x).
        embedding_size: Dimension of the output embedding.
        dropout: Dropout probability applied to the embedding; 0 disables it.
        **kwargs: Forwarded to ``ResNet`` (e.g. ``in_ch``, ``feat_dim``).
    """

    def __init__(self, in_planes, embedding_size, dropout=0, **kwargs):
        super(ResNet34StatsPooling, self).__init__()
        self.front = ResNet(in_planes, BuildingBlock, [3, 4, 6, 3], **kwargs)
        self.pool = StatsPoolingFlatten()  # [batch, in_planes*8, F/8, T/8]
        # in_planes*8 channels x 2 statistics x 5 frequency bins left after three
        # stride-2 stages (5 = 40 / 8; assumes 40-dim input features).
        self.bottleneck = nn.Linear(in_planes * 8 * 2 * 5, embedding_size)
        # Use the `dropout` argument (the name `drop` is undefined here).
        self.drop = nn.Dropout(dropout) if dropout else None

    def forward(self, x):
        # [batch, F, T] -> [batch, 1, F, T]: a single-channel "image" for Conv2d.
        x = x.unsqueeze(dim=1)
        print('Input Size:', x.shape)
        x = self.front(x)
        print('ResNet Output:', x.shape)
        x = self.pool(x)
        print('Pooling Output:', x.shape)
        x = self.bottleneck(x)
        print('Embedding Output:', x.shape)

        if self.drop is not None:
            x = self.drop(x)

        return x

In [9]:
# Smoke test: 4 inputs with 40 frequency bins and 200 frames; the model adds
# a channel axis and the ResNet reduces both F and T by 8.
model = ResNet34StatsPooling(in_planes=32, embedding_size=256, dropout=0)
feats = torch.randn(4, 40, 200)
embd = model(feats)

Input Size: torch.Size([4, 1, 40, 200])
ResNet Output: torch.Size([4, 256, 5, 25])
Pooling Output: torch.Size([4, 2560])
Embedding Output: torch.Size([4, 256])


## SE-Module

Can a network be improved by explicitly modeling the relationships between its channels?
——[ImageNet 冠军模型 SE-Net 详解](https://www.bilibili.com/video/BV1Up4y187qb/?share_source=copy_web&vd_source=ef1797dacc3337c5dbcd89d489cc72a3)

https://readpaper.com/paper/2963420686



![](https://pdf.cdn.readpaper.com/parsed/fetch_target/5e8a39a0a2107065a20f8092224863f0_1_Figure_1.png)

![](https://pdf.cdn.readpaper.com/parsed/fetch_target/5e8a39a0a2107065a20f8092224863f0_3_Figure_3.png)

In [10]:
class SELayer(nn.Module):
    """Squeeze-and-Excitation gate for 2-D feature maps ``[batch, C, H, W]``.

    Global average pooling squeezes each channel to one value; a bottleneck
    MLP ending in a sigmoid turns those into per-channel gates that rescale
    the input.

    Args:
        channel: Number of input channels ``C``.
        reduction: Bottleneck ratio of the excitation MLP.
    """

    def __init__(self, channel, reduction=8):
        super(SELayer, self).__init__()
        self.sq = nn.AdaptiveAvgPool2d(1)

        # Must be registered on ``self`` so forward() can use it and its
        # weights are part of the model's parameters.
        self.ex = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Only batch and channel sizes are needed; the input is 4-D, so
        # unpacking x.size() into three names would fail.
        b, c = x.shape[0], x.shape[1]
        y = self.sq(x).view(b, c)
        y = self.ex(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

In [11]:
class SEBasicBlock(nn.Module):
    """ResNet basic block with a Squeeze-and-Excitation gate before the skip sum.

    Layout: conv3x3 -> norm -> ReLU -> conv3x3 -> norm -> SE, then add the
    (possibly projected) shortcut and apply ReLU.  ``SELayer`` pools with
    ``AdaptiveAvgPool2d``, so this block expects 2-D feature maps.

    Args:
        ConvLayer, NormLayer: Convolution / normalisation classes to build with.
        in_planes: Input channels.
        planes: Output channels (times ``expansion``).
        stride: Stride of the first conv and of the projection shortcut.
        reduction: Bottleneck ratio of the SE gate.
    """
    expansion = 1

    def __init__(self,
                 ConvLayer,
                 NormLayer,
                 in_planes,
                 planes,
                 stride=1,
                 reduction=8):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = ConvLayer(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = NormLayer(planes)
        self.conv2 = ConvLayer(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = NormLayer(planes)
        self.relu = nn.ReLU(inplace=True)

        self.se = SELayer(planes, reduction)

        if stride == 1 and in_planes == out_planes:
            self.downsample = nn.Sequential()  # identity shortcut
        else:
            self.downsample = nn.Sequential(
                ConvLayer(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                NormLayer(out_planes),
            )

    def forward(self, x):
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.se(self.bn2(self.conv2(residual)))
        residual += self.downsample(x)
        return self.relu(residual)


class SEResNet34StatsPooling(nn.Module):
    """SE-ResNet-34 trunk + flattened statistics pooling + linear embedding layer.

    Args:
        in_planes: Base channel width of the ResNet (the last stage has 8x).
        embedding_size: Dimension of the output embedding.
        dropout: Dropout probability applied to the embedding; 0 disables it.
        **kwargs: Forwarded to ``ResNet`` (e.g. ``in_ch``, ``feat_dim``).
    """

    def __init__(self, in_planes, embedding_size, dropout=0.5, **kwargs):
        super(SEResNet34StatsPooling, self).__init__()
        self.front = ResNet(in_planes, SEBasicBlock, [3, 4, 6, 3], **kwargs)
        # Flatten channel x frequency before pooling over time; this produces the
        # in_planes*8*2*5 features the bottleneck expects (plain StatsPooling
        # would pool over F and T jointly and give only in_planes*8*2).
        self.pool = StatsPoolingFlatten()
        # 5 = 40 / 8 frequency bins after three stride-2 stages (assumes 40-dim features).
        self.bottleneck = nn.Linear(in_planes * 8 * 2 * 5, embedding_size)
        # Use the `dropout` argument (the name `drop` is undefined here).
        self.drop = nn.Dropout(dropout) if dropout else None

    def forward(self, x):
        # [batch, F, T] -> [batch, 1, F, T]: a single-channel "image" for Conv2d.
        x = x.unsqueeze(dim=1)
        print('Input Size:', x.shape)
        x = self.front(x)
        print('ResNet Output:', x.shape)
        x = self.pool(x)
        print('Pooling Output:', x.shape)
        x = self.bottleneck(x)
        print('Embedding Output:', x.shape)

        if self.drop is not None:
            x = self.drop(x)

        return x

In [12]:
# NOTE(review): this cell sits in the SE-Module section but builds the plain
# ResNet34StatsPooling again; SEResNet34StatsPooling was presumably intended.
model = ResNet34StatsPooling(in_planes=32, embedding_size=256, dropout=0)
feats = torch.randn(4, 40, 200)
embd = model(feats)

Input Size: torch.Size([4, 1, 40, 200])
ResNet Output: torch.Size([4, 256, 5, 25])
Pooling Output: torch.Size([4, 2560])
Embedding Output: torch.Size([4, 256])


# 编码

## ASP

https://readpaper.com/paper/2794506738

![](https://pdf.cdn.readpaper.com/parsed/fetch_target/56f6fb62f8e0c0fa1dfdcfa47af86f1c_2_Figure_2.png)

In [13]:
class ASP(nn.Module):
    """Attentive statistics pooling over the time axis.

    A small attention network produces, for every channel, softmax weights
    over time; the weighted mean and weighted standard deviation are
    concatenated into ``[batch, 2 * input_dim]``.

    Args:
        input_dim: Number of input channels (e.g. 1500 for the x-vector frame5 output).
        bottlneck_dim: Hidden width of the attention network.
    """

    def __init__(self, input_dim=1500, bottlneck_dim=120):
        super().__init__()
        #  eg. Frame 5 Output: torch.Size([4, 1500, 86])
        self.attention = nn.Sequential(
            # A kernel_size=1 Conv1d acts as a per-frame Linear layer on
            # [batch, C, T] input, so no transpose is needed.
            nn.Conv1d(input_dim, bottlneck_dim, kernel_size=1),
            # Tanh rather than ReLU: ReLU was found hard to converge here.
            nn.Tanh(),
            nn.BatchNorm1d(bottlneck_dim),
            nn.Conv1d(bottlneck_dim, input_dim, kernel_size=1),
            nn.Softmax(dim=2),  # normalise the weights over time
        )

    def forward(self, x):
        weights = self.attention(x)
        mu = (x * weights).sum(dim=2)
        second_moment = (x.pow(2) * weights).sum(dim=2)
        # Var = E[x^2] - E[x]^2 can dip below zero numerically; clamp before sqrt.
        sigma = (second_moment - mu.pow(2)).clamp(min=1e-5).sqrt()
        pooled = torch.cat((mu, sigma), 1)
        return pooled.view(pooled.shape[0], -1)


class X_VECTOR_ASP(nn.Module):
    """x-vector network whose statistics pooling is replaced by attentive pooling (ASP).

    Same five unpadded TDNN frame layers and two segment-level linear layers
    as ``X_VECTOR``; ``forward`` prints every intermediate shape.

    Args:
        input_dim: Number of acoustic feature channels per frame.
        hidden_dim: Width of frame layers 1-4.
        embedding_size: Width of the two segment-level linear layers.
    """

    def __init__(self, input_dim=24, hidden_dim=512, embedding_size=512):
        super().__init__()
        pooled_channels = 1500
        self.frame1 = TDNN(input_dim, hidden_dim, kernel_size=5, dilation=1)
        self.frame2 = TDNN(hidden_dim, hidden_dim, kernel_size=3, dilation=2)
        self.frame3 = TDNN(hidden_dim, hidden_dim, kernel_size=3, dilation=3)
        self.frame4 = TDNN(hidden_dim, hidden_dim, kernel_size=1, dilation=1)
        self.frame5 = TDNN(hidden_dim, pooled_channels, kernel_size=1, dilation=1)
        self.pool = ASP()
        # ASP returns [weighted mean | weighted std], hence twice the frame5 width.
        self.fc1 = nn.Linear(2 * pooled_channels, embedding_size)
        self.fc2 = nn.Linear(embedding_size, embedding_size)

    def forward(self, x, bottleneck_last=False):
        """Return the fc1 embedding, or the fc2 embedding when ``bottleneck_last`` is True."""
        print('Input Feature:', x.shape)
        frames = (self.frame1, self.frame2, self.frame3, self.frame4, self.frame5)
        for idx, frame in enumerate(frames, start=1):
            x = frame(x)
            print(f'Frame {idx} Output:', x.shape)
        x = self.pool(x)
        print('Pooling Output:', x.shape)
        embd_a = self.fc1(x)
        print('FC1 Output:', embd_a.shape)
        embd_b = self.fc2(embd_a)
        print('FC2 Output:', embd_b.shape)
        return embd_b if bottleneck_last else embd_a

In [14]:
# Smoke test for the ASP variant: 4 inputs, 24 feature channels, 200 frames.
feats = torch.rand(4, 24, 200)
model = X_VECTOR_ASP()
embd = model(feats)

Input Feature: torch.Size([4, 24, 200])
Frame 1 Output: torch.Size([4, 512, 196])
Frame 2 Output: torch.Size([4, 512, 192])
Frame 3 Output: torch.Size([4, 512, 186])
Frame 4 Output: torch.Size([4, 512, 186])
Frame 5 Output: torch.Size([4, 1500, 186])
Pooling Output: torch.Size([4, 3000])
FC1 Output: torch.Size([4, 512])
FC2 Output: torch.Size([4, 512])


# 输出

## ArcFace

https://readpaper.com/paper/2784874046

![](https://pdf.cdn.readpaper.com/parsed/fetch_target/e95210a97106b72ca72d1dac425ebb16_3_Figure_2.png)

$$\text{Softmax Loss} = -\log \frac{e^{W^T_{y_i} x_i+b_{y_i}}}{\sum_{j=1}^{n} e^{W^T_j x_i+b_j}}$$

$$\text{ArcFace Loss} = -\log \frac{e^{s\cos(\theta_{y_i}+m)}}{e^{s\cos(\theta_{y_i}+m)}+\sum_{j=1,\,j\ne y_i}^{n} e^{s\cos\theta_j}}$$

In [15]:
class AAMSoftmax(nn.Module):
    """Additive angular margin (ArcFace) classification head.

    Produces logits ``s * cos(theta_j)`` for every class ``j``, except the
    target class, which gets ``s * cos(theta_y + m)``.  Feed the result to a
    cross-entropy loss.

    Args:
        in_features: Embedding dimension.
        out_features: Number of classes.
        device_id: ``None`` for single-device use, or a list of CUDA device
            indices across which the class weight matrix is sharded.
        s: Logit scale.
        m: Additive angular margin in radians.
        easy_margin: If True, apply the margin only where ``cos(theta) > 0``.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 device_id,
                 s=30.0,
                 m=0.50,
                 easy_margin=False):
        super(AAMSoftmax, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        self.device_id = device_id

        # Xavier Initialization
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)

        # For theta > pi - m, cos(theta + m) stops decreasing in theta; there the
        # target logit falls back to cos(theta) - mm instead.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        """Return scaled margin logits of shape ``[batch, out_features]``.

        Args:
            input: Embeddings ``[batch, in_features]``.
            label: Integer class indices ``[batch]``.
        """
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        if self.device_id is None:
            cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        else:
            # Each weight shard is scored on its own device; the partial cosine
            # matrices are gathered on device_id[0] and joined along the class axis.
            head = self.device_id[0]
            sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
            parts = []
            for dev, sub_weight in zip(self.device_id, sub_weights):
                part = F.linear(F.normalize(input.cuda(dev)),
                                F.normalize(sub_weight.cuda(dev)))
                parts.append(part.cuda(head))
            cosine = torch.cat(parts, dim=1)
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # --------------------------- convert label to one-hot ---------------------------
        # zeros_like keeps one_hot on cosine's device/dtype; scatter_ must be the
        # in-place variant, otherwise the margin is never applied.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1).long().to(cosine.device), 1)

        # Target class uses phi (margin), all other classes keep cos(theta).
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output

In [16]:
# Use the GPU when available (the AAM head then lives on device 0), otherwise
# fall back to CPU so "Restart & Run All" also works without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
head_devices = [0] if device.type == 'cuda' else None

feats = torch.rand(4, 24, 200).to(device)
model = X_VECTOR_ASP().to(device)
embd = model(feats)
classifier = AAMSoftmax(512, 1000, device_id=head_devices, s=30.0, m=0.50)
output = classifier(embd, torch.tensor([0, 1, 2, 3]).to(device))
print(output.shape)

Input Feature: torch.Size([4, 24, 200])
Frame 1 Output: torch.Size([4, 512, 196])
Frame 2 Output: torch.Size([4, 512, 192])
Frame 3 Output: torch.Size([4, 512, 186])
Frame 4 Output: torch.Size([4, 512, 186])
Frame 5 Output: torch.Size([4, 1500, 186])
Pooling Output: torch.Size([4, 3000])
FC1 Output: torch.Size([4, 512])
FC2 Output: torch.Size([4, 512])
torch.Size([4, 1000])


# 作业

## SE-Res2Block

In [17]:
class Conv1dReluBn(nn.Module):
    """1-D convolution followed by ReLU and then BatchNorm1d.

    The order Conv -> ReLU -> BN matches both the class name and the
    "Conv1D + ReLU + BN" layers of the ECAPA-TDNN architecture.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size, stride, padding, dilation: Forwarded to ``nn.Conv1d``.
        bias: Whether the convolution has a bias term.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=False):
        super(Conv1dReluBn, self).__init__()

        self.conv = nn.Conv1d(in_channels,
                              out_channels,
                              kernel_size,
                              stride,
                              padding,
                              dilation,
                              bias=bias)

        self.bn = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        # ReLU before BN, as the class name says.
        return self.bn(self.relu(self.conv(x)))


class Res2DilatedConv1dReluBn(nn.Module):
    """Res2Net-style multi-scale dilated 1-D convolution with a residual sum.

    After a 1x1 ``Conv1dReluBn`` (``l[0]``), the channels are split into
    ``scale`` equal groups.  Group 0 is passed through unchanged.  Group
    ``i >= 1`` is convolved (``k[i-1]``) after adding the previous group's
    output, so later groups see progressively larger receptive fields.  The
    groups are concatenated, passed through a second 1x1 ``Conv1dReluBn``
    (``l[1]``) and added to the output of ``l[0]``.

    Args:
        channels: Input/output channel count; must be divisible by ``scale``.
        kernel_size: Kernel size of the per-group dilated convolutions.
        dilation: Dilation of the per-group convolutions.
        padding: Padding of the per-group convolutions.  To keep the time
            length unchanged it must equal ``dilation * (kernel_size - 1) // 2``
            (i.e. ``padding == dilation`` when ``kernel_size == 3``).
        scale: Number of channel groups (any integer >= 1).

    Raises:
        ValueError: If ``scale < 1`` or ``channels`` is not divisible by ``scale``.
    """

    def __init__(self, channels, kernel_size, dilation, padding, scale=4):
        super(Res2DilatedConv1dReluBn, self).__init__()
        if scale < 1 or channels % scale != 0:
            raise ValueError(
                f'channels ({channels}) must be divisible by scale ({scale}), '
                f'and scale must be >= 1')

        self.scale = scale
        width = channels // scale

        # l[0]: 1x1 projection before the split; l[1]: 1x1 projection after the merge.
        self.l = nn.ModuleList(
            Conv1dReluBn(in_channels=channels, out_channels=channels, kernel_size=1,
                         stride=1, padding=0, dilation=1, bias=False)
            for _ in range(2))

        # One dilated conv per group, except group 0, which is the identity.
        self.k = nn.ModuleList(
            Conv1dReluBn(in_channels=width, out_channels=width, kernel_size=kernel_size,
                         stride=1, padding=padding, dilation=dilation, bias=False)
            for _ in range(scale - 1))

    def forward(self, x):
        x = self.l[0](x)
        groups = torch.chunk(x, self.scale, dim=1)
        outputs = [groups[0]]
        prev = None
        # Hierarchical residual: y_1 = K_1(x_1), y_i = K_i(x_i + y_{i-1}).
        for group, conv in zip(groups[1:], self.k):
            prev = conv(group if prev is None else group + prev)
            outputs.append(prev)
        y = self.l[1](torch.cat(outputs, dim=1))
        return y + x


class SEBlock(nn.Module):
    """Squeeze-and-Excitation over the channel axis of ``[batch, C, T]`` input.

    Averages each channel over time, passes the result through a bottleneck
    MLP with a sigmoid, and rescales every frame by its channel's gate.

    Args:
        channels: Number of channels ``C``.
        reduction: Bottleneck ratio of the excitation MLP.
    """

    def __init__(self, channels, reduction=8):
        super().__init__()
        hidden = channels // reduction
        self.sq = nn.AdaptiveAvgPool1d(1)  # squeeze: mean over time
        self.ex = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels, _ = x.size()
        squeezed = self.sq(x).view(batch, channels)
        gates = self.ex(squeezed).view(batch, channels, 1)
        # Broadcasting over the last axis applies each channel gate to every frame.
        return x * gates

In [18]:
class SERes2Block(nn.Module):
    """ECAPA-style SE-Res2Block with an identity skip connection.

    Pipeline: 1x1 Conv1dReluBn -> Res2 dilated conv -> 1x1 Conv1dReluBn -> SE
    gate; the block input is added to the result.

    Args:
        channels: Input/output channel count.
        kernel_size, dilation, padding: Passed to the Res2 dilated convolution.
    """

    def __init__(self, channels, kernel_size, dilation, padding):
        super().__init__()
        self.frame1 = Conv1dReluBn(channels, channels)
        self.frame2 = Res2DilatedConv1dReluBn(channels, kernel_size, dilation, padding)
        self.frame3 = Conv1dReluBn(channels, channels)
        self.frame4 = SEBlock(channels)

    def forward(self, x):
        out = x
        for stage in (self.frame1, self.frame2, self.frame3, self.frame4):
            out = stage(out)
        return x + out

In [19]:
class AAMSoftmax(nn.Module):
    """Additive angular margin (ArcFace) classification head.

    NOTE: this cell redefines ``AAMSoftmax``; this later definition is the one
    in effect for the cells below it.

    Produces logits ``s * cos(theta_j)`` for every class ``j``, except the
    target class, which gets ``s * cos(theta_y + m)``.

    Args:
        in_features: Embedding dimension.
        out_features: Number of classes.
        device_id: ``None`` for single-device use, or a list of CUDA device
            indices across which the class weight matrix is sharded.
        s: Logit scale.
        m: Additive angular margin in radians.
        easy_margin: If True, apply the margin only where ``cos(theta) > 0``.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 device_id,
                 s=30.0,
                 m=0.50,
                 easy_margin=False):
        super(AAMSoftmax, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        self.device_id = device_id

        # Xavier Initialization
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)

        # For theta > pi - m, cos(theta + m) stops decreasing in theta; there the
        # target logit falls back to cos(theta) - mm instead.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        """Return scaled margin logits of shape ``[batch, out_features]``.

        Args:
            input: Embeddings ``[batch, in_features]``.
            label: Integer class indices ``[batch]``.
        """
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        if self.device_id is None:
            cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        else:
            # Each weight shard is scored on its own device; the partial cosine
            # matrices are gathered on device_id[0] and joined along the class axis.
            head = self.device_id[0]
            sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
            parts = []
            for dev, sub_weight in zip(self.device_id, sub_weights):
                part = F.linear(F.normalize(input.cuda(dev)),
                                F.normalize(sub_weight.cuda(dev)))
                parts.append(part.cuda(head))
            cosine = torch.cat(parts, dim=1)
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # --------------------------- convert label to one-hot ---------------------------
        # zeros_like keeps one_hot on cosine's device/dtype; scatter_ must be the
        # in-place variant, otherwise the margin is never applied.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1).long().to(cosine.device), 1)

        # Target class uses phi (margin), all other classes keep cos(theta).
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output

## ECAPA-TDNN

https://readpaper.com/paper/3024869864

![](https://pdf.cdn.readpaper.com/parsed/fetch_target/7dcf7bfaa875a883299e72921f6137ac_2_Figure_2.png)

In [20]:
class ASP(nn.Module):
    """Attentive statistics pooling over the time axis.

    NOTE: this cell redefines ``ASP``; this later definition is the one in
    effect for the cells below it.

    A small attention network yields per-channel softmax weights over time;
    the weighted mean and weighted standard deviation are concatenated into
    ``[batch, 2 * input_dim]``.

    Args:
        input_dim: Number of input channels.
        bottlneck_dim: Hidden width of the attention network.
    """

    def __init__(self, input_dim=1500, bottlneck_dim=120):
        super().__init__()
        #  eg. Frame 5 Output: torch.Size([4, 1500, 86])
        self.attention = nn.Sequential(
            # A kernel_size=1 Conv1d acts as a per-frame Linear layer on
            # [batch, C, T] input, so no transpose is needed.
            nn.Conv1d(input_dim, bottlneck_dim, kernel_size=1),
            # Tanh rather than ReLU: ReLU was found hard to converge here.
            nn.Tanh(),
            nn.BatchNorm1d(bottlneck_dim),
            nn.Conv1d(bottlneck_dim, input_dim, kernel_size=1),
            nn.Softmax(dim=2),  # normalise the weights over time
        )

    def forward(self, x):
        attn = self.attention(x)
        weighted_mean = (x * attn).sum(dim=2)
        variance = (x.pow(2) * attn).sum(dim=2) - weighted_mean.pow(2)
        # Clamp guards against tiny negative variances from floating-point error.
        weighted_std = variance.clamp(min=1e-5).sqrt()
        stats = torch.cat((weighted_mean, weighted_std), 1)
        return stats.view(stats.shape[0], -1)


class ECAPA_TDNN(nn.Module):
    """ECAPA-TDNN speaker embedding extractor.

    Conv1dReluBn stem -> three dilated SE-Res2Blocks -> multi-layer feature
    aggregation (concatenate the block outputs, 1x1 conv) -> attentive
    statistics pooling -> Linear + BatchNorm embedding layer.

    Args:
        F: Number of input feature bins per frame (e.g. 80 filterbanks).
        C: Channel width of the stem and SE-Res2Blocks.
        E: Embedding dimension.
        S: Number of speakers; unused because the AAM-softmax head is applied
            outside the model.
    """

    def __init__(self, F=80, C=512, E=192, S=10):
        super().__init__()
        self.layer1 = Conv1dReluBn(F, C, kernel_size=5, padding=2)
        # (kernel_size, dilation, padding): padding == dilation keeps T unchanged for kernel 3.
        self.layer2 = SERes2Block(C, 3, 2, 2)
        self.layer3 = SERes2Block(C, 3, 3, 3)
        self.layer4 = SERes2Block(C, 3, 4, 4)
        self.layer5 = Conv1dReluBn(3 * C, 3 * C)
        self.pool = ASP(input_dim=3 * C)
        # ASP output is [mean | std], hence 2 * 3C inputs.
        self.layer7 = nn.Sequential(nn.Linear(3 * C * 2, E), nn.BatchNorm1d(E))

    def forward(self, x):
        y1 = self.layer1(x)
        # NOTE(review): each SERes2Block already adds its own input, so the extra
        # "y1 +" / "y1 + y2 +" sums stack additional skips, and layer3/layer4
        # receive y2/y3 rather than accumulated sums; confirm this is intended.
        y2 = y1 + self.layer2(y1)
        y3 = y1 + y2 + self.layer3(y2)
        y4 = y1 + y2 + y3 + self.layer4(y3)
        aggregated = self.layer5(torch.cat([y2, y3, y4], dim=1))
        pooled = self.pool(aggregated)
        return self.layer7(pooled)

In [21]:
# End-to-end check: 4 inputs with 80 feature bins and 301 frames -> 192-dim
# ECAPA embeddings -> AAM-softmax logits over 1000 classes.
feats = torch.rand(4, 80, 301)
model = ECAPA_TDNN()
embd = model(feats)

# CPU run: device_id=None keeps the classifier head on the default device.
classifier = AAMSoftmax(192, 1000, device_id=None, s=30.0, m=0.50)
output = classifier(embd, torch.tensor([0, 1, 2, 3]))
print(output.shape)

torch.Size([4, 1000])
