In [5]:
import import_ipynb

import numpy as np
import librosa
import glob
import os
from random import randint
import torch
import torch.nn as nn
from torch.utils import data
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
class TripletLSTM(nn.Module):
    """Two-class classifier over a triplet of inputs ``(x1, x2, x3)``.

    Pipeline implemented by ``forward``:

    1. A shared convolutional encoder (``siamese_cnn``) embeds each input.
    2. Pairwise cosine-similarity maps are computed for ``(x1, x2)`` and
       ``(x2, x3)`` and averaged.
    3. ``late_cnn`` processes the averaged map as a one-channel image.
    4. Global max / mean / variance pooling yields three 256-dim vectors,
       which are fed to a bidirectional LSTM as a length-3 sequence.
    5. ``fcWithDropout`` maps the last LSTM output to softmax probabilities
       over two classes.
    """

    def __init__(self):
        super().__init__()
        # Shared encoder. The first kernel spans 128 columns with stride 128,
        # so a 128-wide input is collapsed to width 1; the remaining kernels
        # only slide along the first spatial axis.
        self.siamese_cnn = nn.Sequential(
            nn.Conv2d(1, 128, kernel_size=[4, 128], stride=[1, 128]),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            nn.Conv2d(128, 256, kernel_size=[4, 1], stride=[1, 1]),
            nn.BatchNorm2d(256),
            nn.ReLU(),

            nn.Conv2d(256, 256, kernel_size=[4, 1], stride=[1, 1]),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            )
        # CNN applied to the (1-channel) similarity map.
        # NOTE(review): an earlier comment here said max-pooling should keep
        # input.shape == output.shape; with stride 3 these pools downsample,
        # so confirm what was intended.
        self.late_cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=[3, 3], stride=[1, 1]),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=[3, 3], stride=[3, 3], padding=1),

            nn.Conv2d(64, 128, kernel_size=[3, 3], stride=[1, 1]),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=[3, 3], stride=[3, 3], padding=[1, 0]),

            nn.Conv2d(128, 256, kernel_size=[3, 3], stride=[1, 1]),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            )
        # Classification head; the final Softmax means forward() returns
        # probabilities rather than raw logits.
        self.fcWithDropout = nn.Sequential(
            nn.Linear(256, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(p=0.5),

            nn.Linear(1024, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(p=0.5),

            nn.Linear(1024, 2),
            nn.Softmax(dim=1),
            )
        # Bidirectional, hidden_size=128 -> 256 output features per step,
        # matching the first Linear layer of the head.
        self.rnnLSTM = nn.LSTM(
            input_size=256,
            hidden_size=128,
            num_layers=2,
            dropout=0.1,
            batch_first=True,
            bidirectional=True,
            )

    def cal_similarity(self, matrix1, matrix2, eps=1e-8):
        """Return the pairwise cosine-similarity map between two feature maps.

        Args:
            matrix1: tensor of shape (batch, channels, len1, 1).
            matrix2: tensor of shape (batch, channels, len2, 1).
            eps: lower bound applied to the product of the two L2 norms so
                that all-zero feature columns (possible after ReLU) give a
                similarity of 0 instead of NaN from 0/0.

        Returns:
            Tensor of shape (batch, 1, len1, len2).
        """
        # e.g. out1.shape = torch.Size([10, 256, 336])
        out1 = torch.squeeze(matrix1, dim=3)
        out2 = torch.squeeze(matrix2, dim=3)
        # Dot products between every pair of columns,
        # e.g. num.shape = torch.Size([10, 336, 336])
        num = torch.bmm(torch.transpose(out1, 1, 2), out2)
        # Per-column L2 norms over the channel axis, shape (batch, 1, len).
        h1_norm = torch.sqrt(torch.sum(torch.mul(out1, out1), dim=1, keepdim=True))
        h2_norm = torch.sqrt(torch.sum(torch.mul(out2, out2), dim=1, keepdim=True))
        # Outer product of the norms, same shape as num; clamped to avoid 0/0.
        denom = torch.clamp(torch.bmm(torch.transpose(h1_norm, 1, 2), h2_norm), min=eps)
        # Add a channel axis, e.g. fms.shape = torch.Size([10, 1, 336, 336])
        fms = torch.unsqueeze(torch.div(num, denom), dim=1)
        return fms

    def _reduce_var(self, inputs):
        """Population variance over the two spatial axes (dims 2 and 3).

        Args:
            inputs: 4-D tensor (batch, channels, H, W).

        Returns:
            Tensor of shape (batch, channels).
        """
        # Spatial mean, kept broadcastable against inputs.
        m = inputs.mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
        centered = inputs - m
        devs_squared = torch.mul(centered, centered)
        # After reducing dim 2 (no keepdim) the former dim 3 becomes dim 2.
        return devs_squared.mean(dim=2).mean(dim=2)

    def cal_global_pool(self, matrix):
        """Concatenate global max, mean and variance over the spatial axes.

        Args:
            matrix: 4-D tensor (batch, channels, H, W).

        Returns:
            Tensor of shape (batch, 3 * channels), ordered [max, mean, var].
        """
        # Reductions drop the reduced axis, so dim=2 is used twice.
        g_max1, _ = torch.max(matrix, dim=2)
        g_max, _ = torch.max(g_max1, dim=2)
        g_mean = torch.mean(torch.mean(matrix, dim=2), dim=2)
        g_var = self._reduce_var(matrix)
        return torch.cat([g_max, g_mean, g_var], 1)

    def cal_lstm(self, seqInput):
        """Run the pooled statistics through the LSTM as a length-3 sequence.

        Args:
            seqInput: tensor of shape (batch, 3 * input_size).

        Returns:
            The LSTM output at the last sequence position, shape
            (batch, 2 * hidden_size).
        """
        # Read the batch size from the tensor itself: DataParallel may split
        # the batch, so it cannot be assumed to be a fixed constant.
        lstm_batchSize = seqInput.size(0)
        # (batch, seq_len=3, input_size)
        seqInput = seqInput.view(lstm_batchSize, 3, -1)
        outLstm, _ = self.rnnLSTM(seqInput)
        return outLstm[:, -1, :]

    def forward(self, x1, x2, x3):
        """Classify the triplet ``(x1, x2, x3)``.

        Args:
            x1, x2, x3: tensors of shape (batch, 1, T, 128); the test cell in
                this notebook uses T=345.

        Returns:
            Tensor of shape (batch, 2) with softmax probabilities.
        """
        out1 = self.siamese_cnn(x1)
        # x2 is shared by both pairs, so encode it once: a second pass would
        # return the same features while costing compute and updating the
        # BatchNorm running statistics a second time for the same batch.
        out2 = self.siamese_cnn(x2)
        out3 = self.siamese_cnn(x3)
        similarity1 = self.cal_similarity(out1, out2)
        similarity2 = self.cal_similarity(out2, out3)
        late_cnn_out = self.late_cnn((similarity1 + similarity2) / 2)
        global_pool_out = self.cal_global_pool(late_cnn_out)
        lstmout = self.cal_lstm(global_pool_out)
        predictions = self.fcWithDropout(lstmout)
        return predictions
    
    

In [3]:
def test_classify():
    """Smoke-test TripletLSTM on a random triplet and return its output.

    Builds three random inputs of shape (32, 1, 345, 128), runs them through
    a freshly constructed model, prints the output shape and values, and
    returns the prediction tensor.
    """
    input_shape = (32, 1, 345, 128)
    # Inputs are drawn before the model is built, one after another.
    triplet = [torch.rand(*input_shape) for _ in range(3)]
    net = TripletLSTM()
    scores = net(*triplet)
    print(scores.shape)
    print(scores)
    return scores
    
# target = test_classify()

torch.Size([32, 2])
tensor([[0.5346, 0.4654],
        [0.3655, 0.6345],
        [0.5059, 0.4941],
        [0.4798, 0.5202],
        [0.4739, 0.5261],
        [0.4640, 0.5360],
        [0.5890, 0.4110],
        [0.4330, 0.5670],
        [0.5491, 0.4509],
        [0.7732, 0.2268],
        [0.3904, 0.6096],
        [0.5019, 0.4981],
        [0.6882, 0.3118],
        [0.5308, 0.4692],
        [0.4678, 0.5322],
        [0.4942, 0.5058],
        [0.5599, 0.4401],
        [0.4814, 0.5186],
        [0.5574, 0.4426],
        [0.3310, 0.6690],
        [0.4170, 0.5830],
        [0.5928, 0.4072],
        [0.7525, 0.2475],
        [0.5751, 0.4249],
        [0.4270, 0.5730],
        [0.6850, 0.3150],
        [0.5700, 0.4300],
        [0.5453, 0.4547],
        [0.4831, 0.5169],
        [0.5219, 0.4781],
        [0.7045, 0.2955],
        [0.4871, 0.5129]], grad_fn=<SoftmaxBackward>)
