In [31]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.optim.lr_scheduler import StepLR

In [None]:
# lstm1 = nn.LSTM(input_size=20,hidden_size=768,batch_first=False)
# lstm2 = nn.LSTM(input_size=256,hidden_size=768,batch_first=False)
# lstm3 = nn.LSTM(input_size=256,hidden_size=768,batch_first=False)

In [None]:
# out1,(hidden1,cell1) = lstm1(x)
# out2,(hidden2,cell2) = lstm2(out1)
# out3,(hidden3,cell3) = lstm3(out2)

In [None]:
# cell1.size()#[num layers* direction,batch,hidden]

In [None]:
# out1.size()#[T,B,H]

In [21]:
class SpeakerEncoder(nn.Module):
    """Three-layer LSTM speaker encoder that scores utterances against speaker centroids.

    Each LSTM layer is followed by a linear projection to ``project_size`` units.
    The projected output at the last time step, L2-normalised per utterance, is the
    utterance embedding; ``forward`` returns the scaled cosine-similarity matrix
    between every embedding and every speaker centroid.

    Args:
        input_size: number of features per input frame.
        N: number of speakers in a batch.
        M: number of utterances per speaker. The batch must hold the M utterances
            of speaker 0 first, then those of speaker 1, and so on.
        hidden_size: LSTM hidden units per layer.
        project_size: output size of each projection (and of the embedding).
    """

    def __init__(self, input_size, N, M, hidden_size=768, project_size=256):
        super().__init__()
        # Learnable scale and bias applied to the cosine similarities.
        # NOTE(review): nothing keeps w positive during training; confirm whether it should be clamped.
        self.w = nn.Parameter(torch.tensor(10.0))
        self.b = nn.Parameter(torch.tensor(-5.0))
        self.N = N
        self.M = M
        self.lstm1 = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=False)
        self.project1 = nn.Linear(hidden_size, project_size)
        self.lstm2 = nn.LSTM(input_size=project_size, hidden_size=hidden_size, batch_first=False)
        self.project2 = nn.Linear(hidden_size, project_size)
        self.lstm3 = nn.LSTM(input_size=project_size, hidden_size=hidden_size, batch_first=False)
        self.project3 = nn.Linear(hidden_size, project_size)

    def similarity_matrix(self, x):
        """Return ``w * cos(e_i, c_k) + b`` for every embedding ``e_i`` and speaker centroid ``c_k``.

        Args:
            x: embeddings of shape (N*M, d), grouped by speaker (M rows per speaker).

        Returns:
            Tensor of shape (N*M, N).
        """
        N, M = self.N, self.M
        # (N*M, d) -> N chunks of (M, d) -> (N, M, d) -> mean over utterances -> centroids (N, d).
        centroids = torch.stack(x.split([M] * N, 0), 0).mean(1)
        # Pair every embedding with every centroid; expand creates views instead of copies.
        # Both operands have shape (N*M, N, d).
        emb = x.unsqueeze(1).expand(-1, N, -1)
        cent = centroids.unsqueeze(0).expand(N * M, -1, -1)
        # NOTE(review): GE2E as published removes utterance i from its own speaker's centroid
        # for the positive term; this version uses the full centroid.
        return self.w * F.cosine_similarity(emb, cent, dim=-1) + self.b

    def forward(self, x):
        """Embed a time-major batch and score it against the speaker centroids.

        Args:
            x: input of shape (T, N*M, input_size); time-major because the LSTMs are
                built with ``batch_first=False``.

        Returns:
            Similarity matrix of shape (N*M, N).
        """
        # nn.Linear acts on the last dimension, so the (T, B, hidden) LSTM output is
        # projected directly; no permute to batch-first and back is needed.
        for lstm, project in ((self.lstm1, self.project1),
                              (self.lstm2, self.project2),
                              (self.lstm3, self.project3)):
            x, _ = lstm(x)
            x = project(x)
        # Embedding = projected output at the last time step: (N*M, project_size).
        x = x[-1]
        # L2-normalise each embedding on its own. torch.norm(x) without `dim` would divide
        # every embedding by the norm of the whole batch instead.
        x = F.normalize(x, p=2, dim=-1)
        return self.similarity_matrix(x)

In [22]:
# Seed the global RNG so the encoder's initial weights and the random dummy batch that
# follows are reproducible on Restart & Run All.
torch.manual_seed(0)
speaker_encoder = SpeakerEncoder(input_size=40, N=5, M=10)

In [23]:
# Dummy input: 100 frames x (N*M = 5*10) utterances x 40 features per frame,
# time-major to match the encoder's batch_first=False LSTMs.
x = torch.randn(100, 5 * 10, 40)

In [24]:
# Forward pass over the dummy batch; yields the (N*M, N) similarity matrix.
x1 = speaker_encoder(x=x)

In [25]:
x1.shape

torch.Size([50, 5])

In [26]:
class GE2ELoss(nn.Module):
    """Generalized end-to-end (GE2E) loss over a speaker-similarity matrix.

    The input ``S`` has shape (N*M, N). Rows are utterances grouped by speaker
    (row r belongs to speaker r // M) and column k is the scaled similarity to
    speaker k's centroid.

    Args:
        N: number of speakers per batch.
        M: number of utterances per speaker.
        loss_type: ``'softmax'`` or ``'contrast'``.
    """

    def __init__(self, N, M, loss_type='softmax'):
        super().__init__()
        self.N = N
        self.M = M
        assert loss_type in ['softmax', 'contrast']
        self.loss_type = loss_type

    def _targets(self, x):
        """Return the true-speaker index of every row, validating the matrix shape."""
        N, M = self.N, self.M
        if tuple(x.shape) != (N * M, N):
            raise ValueError(
                f"expected similarity matrix of shape ({N * M}, {N}), got {tuple(x.shape)}")
        return torch.arange(N, device=x.device).repeat_interleave(M)

    def softmax(self, x):
        """Softmax GE2E loss: sum over rows of ``-S[r, j] + log(sum_k exp(S[r, k]))``.

        ``j`` is the row's true speaker. The loss is minimised, so it must be the
        negative log-probability of the true speaker, not the negated raw probability.
        """
        return F.cross_entropy(x, self._targets(x), reduction='sum')

    def contrast(self, x):
        """Contrast GE2E loss: sum over rows of ``1 - sigmoid(S[r, j]) + max_{k != j} sigmoid(S[r, k])``."""
        target = self._targets(x)
        sig = torch.sigmoid(x)
        positive = sig.gather(1, target.unsqueeze(1)).squeeze(1)
        # Sigmoid outputs are >= 0, so zeroing the true-speaker entry keeps it out of the max
        # and the max runs over impostor speakers only.
        true_mask = F.one_hot(target, num_classes=self.N).bool()
        hardest_negative = sig.masked_fill(true_mask, 0.0).max(dim=1).values
        return (1.0 - positive + hardest_negative).sum()

    def forward(self, similarity_matrix):
        """Return the configured GE2E loss (a scalar) for ``similarity_matrix``."""
        if self.loss_type == 'softmax':
            return self.softmax(similarity_matrix)
        return self.contrast(similarity_matrix)
        

In [29]:
# Softmax-variant GE2E loss for N=5 speakers x M=10 utterances, applied to the similarity matrix x1.
loss_f = GE2ELoss(N=5, M=10, loss_type='softmax')
loss = loss_f(x1)

In [30]:
# Model sizes per verification mode (these appear to follow the GE2E paper's settings):
# TD-SV -> 128 hidden / 64 projection units, TI-SV -> 768 hidden / 256 projection units.
mode_type = 'TD-SV'  # or 'TI-SV'
mode_sizes = {'TD-SV': (128, 64), 'TI-SV': (768, 256)}
if mode_type not in mode_sizes:
    raise ValueError(f"unknown mode_type {mode_type!r}; expected one of {sorted(mode_sizes)}")
hidden_size, project_size = mode_sizes[mode_type]
model = SpeakerEncoder(input_size=40, N=5, M=10, hidden_size=hidden_size, project_size=project_size)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# Halve the learning rate every 30M scheduler steps. StepLR documents step_size as an int;
# the previous `30*1e6` passed a float.
scheduler = StepLR(optimizer, step_size=30_000_000, gamma=0.5)

In [None]:
# torch.gather along dim=1: out[i][j] = t[i][index[i][j]].
# `t` is not defined in any earlier cell, so build a small example input here
# so the cell runs on a fresh kernel.
t = torch.arange(9).reshape(3, 3)
gathered = torch.gather(t, 1, torch.tensor([[1, 1, 1], [1, 1, 1], [0, 1, 1]]))
gathered