In [3]:
import numpy as np
import librosa
import glob
import os
from random import randint
import torch
import torch.nn as nn
from torch.utils import data
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import matplotlib.pyplot as plt
%matplotlib inline

import import_ipynb

In [4]:
class SEN_classify(nn.Module):
    """Siamese CNN that classifies a pair of 2-D inputs into ``n_classes`` classes.

    Pipeline:
      1. Both inputs pass through the *same* ``siamese_cnn`` (shared weights).
      2. The two encoder outputs are stacked as a 2-channel image (``cal_concat``).
      3. ``late_cnn`` processes that image.
      4. Global max / mean / variance pooling summarises each channel (``cal_global_pool``).
      5. A dropout MLP ending in ``Softmax(dim=1)`` produces class probabilities.

    Expected input per branch: ``(batch, 1, height, width)`` with ``width == n_mels``.
    The first conv's kernel and stride both span ``n_mels`` columns, which
    collapses the width to 1 so that ``cal_concat`` can squeeze it away.
    The notebook's smoke test uses 345 x 128 inputs; presumably frames x mel bins,
    but confirm against the data pipeline.

    NOTE(review): by my arithmetic the late CNN needs ``height >= 42``. Shorter
    inputs shrink a feature map below the 3x3 kernel size and raise.

    Args:
        n_mels: Width of each input; sets the first conv's kernel/stride width.
            The default of 128 reproduces the original architecture.
        n_classes: Number of output classes. Default 2, as originally hard-coded.
    """

    def __init__(self, n_mels=128, n_classes=2):
        super().__init__()
        # Shared encoder applied to both inputs.
        # Output: (batch, 256, height - 9, 1) when width == n_mels.
        self.siamese_cnn = nn.Sequential(
            nn.Conv2d(1, 128, kernel_size=[4, n_mels], stride=[1, n_mels]),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            nn.Conv2d(128, 256, kernel_size=[4, 1], stride=[1, 1]),
            nn.BatchNorm2d(256),
            nn.ReLU(),

            nn.Conv2d(256, 256, kernel_size=[4, 1], stride=[1, 1]),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        # Operates on the two encoder outputs stacked as 2 input channels.
        # The max-pool paddings were meant to emulate TF-style padding="same".
        # Being fixed, they only approximate it, and the match depends on input size.
        self.late_cnn = nn.Sequential(
            nn.Conv2d(2, 64, kernel_size=[3, 3], stride=[1, 1]),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=[3, 3], stride=[3, 3], padding=1),

            nn.Conv2d(64, 128, kernel_size=[3, 3], stride=[1, 1]),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=[3, 3], stride=[3, 3], padding=[1, 0]),

            nn.Conv2d(128, 256, kernel_size=[3, 3], stride=[1, 1]),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        # 768 = 3 pooled statistics (max, mean, var) x 256 late_cnn output channels.
        self.fcWithDropout = nn.Sequential(
            nn.Linear(768, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(p=0.5),

            nn.Linear(1024, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(p=0.5),

            nn.Linear(1024, n_classes),
            # NOTE(review): the model returns probabilities. If it is trained with
            # nn.CrossEntropyLoss (which applies log-softmax itself), that loss
            # should receive the pre-softmax logits instead. Confirm against the
            # training code.
            nn.Softmax(dim=1),
        )

    def cal_concat(self, matrix1, matrix2):
        """Stack two encoder outputs into one 2-channel image.

        Args:
            matrix1, matrix2: Tensors of shape ``(batch, C, H, 1)``.

        Returns:
            Tensor of shape ``(batch, 2, C, H)``. Channel 0 comes from
            ``matrix1`` and channel 1 from ``matrix2``.
        """
        # Each channel must come from its own branch. Previously matrix1 was
        # stacked twice, so the second input never reached the classifier.
        return torch.stack(
            [torch.squeeze(matrix1, dim=3), torch.squeeze(matrix2, dim=3)], dim=1
        )

    def _reduce_var(self, inputs):
        """Population (biased) variance over dims 2 and 3.

        Maps ``(batch, C, H, W)`` to ``(batch, C)``.
        """
        flat = inputs.flatten(2)
        centred = flat - flat.mean(dim=2, keepdim=True)
        return (centred * centred).mean(dim=2)

    def cal_global_pool(self, matrix):
        """Global max, mean and variance of each channel, concatenated.

        Maps ``(batch, C, H, W)`` to ``(batch, 3 * C)``, ordered as
        ``[max | mean | var]``.
        """
        flat = matrix.flatten(2)
        g_max, _ = torch.max(flat, dim=2)
        g_mean = torch.mean(flat, dim=2)
        g_var = self._reduce_var(matrix)
        return torch.cat([g_max, g_mean, g_var], 1)

    def forward(self, x1, x2):
        """Return class probabilities of shape ``(batch, n_classes)`` for the pair ``(x1, x2)``."""
        # The same encoder (shared weights) embeds both inputs.
        emb1 = self.siamese_cnn(x1)
        emb2 = self.siamese_cnn(x2)
        late_cnn_out = self.late_cnn(self.cal_concat(emb1, emb2))
        global_pool_out = self.cal_global_pool(late_cnn_out)
        return self.fcWithDropout(global_pool_out)

In [5]:
def test_SEN_classify(batch_size=32, n_rows=345):
    """Smoke-test ``SEN_classify`` on random inputs and return its predictions.

    Builds a fresh model (left in training mode) and feeds it two random
    ``(batch_size, 1, n_rows, 128)`` tensors. The width of 128 matches the
    model's default first-conv kernel width. Prints the output shape and
    asserts that the output is ``(batch_size, 2)`` with rows summing to 1.

    Args:
        batch_size: Number of input pairs. Must be > 1, because the model is in
            training mode and ``BatchNorm1d`` rejects single-sample batches there.
        n_rows: Height of each input.

    Returns:
        The prediction tensor, still attached to the autograd graph.
    """
    # batch, channel, height, width
    x1 = torch.rand(batch_size, 1, n_rows, 128)
    x2 = torch.rand(batch_size, 1, n_rows, 128)
    model = SEN_classify()
    predictions = model(x1, x2)
    print(predictions.shape)
    assert predictions.shape == (batch_size, 2), predictions.shape
    # The model ends in Softmax(dim=1), so every row should be a probability distribution.
    assert torch.allclose(predictions.sum(dim=1), torch.ones(batch_size)), "rows do not sum to 1"
    return predictions

# test_SEN_classify()

torch.Size([32, 2])


tensor([[0.6261, 0.3739],
        [0.5096, 0.4904],
        [0.5704, 0.4296],
        [0.5873, 0.4127],
        [0.6084, 0.3916],
        [0.6255, 0.3745],
        [0.5223, 0.4777],
        [0.6388, 0.3612],
        [0.6859, 0.3141],
        [0.5932, 0.4068],
        [0.4961, 0.5039],
        [0.5108, 0.4892],
        [0.7574, 0.2426],
        [0.6056, 0.3944],
        [0.6148, 0.3852],
        [0.4939, 0.5061],
        [0.5957, 0.4043],
        [0.7272, 0.2728],
        [0.7381, 0.2619],
        [0.5266, 0.4734],
        [0.4489, 0.5511],
        [0.5259, 0.4741],
        [0.6374, 0.3626],
        [0.4956, 0.5044],
        [0.4732, 0.5268],
        [0.3317, 0.6683],
        [0.5426, 0.4574],
        [0.4456, 0.5544],
        [0.5735, 0.4265],
        [0.5398, 0.4602],
        [0.5428, 0.4572],
        [0.7272, 0.2728]], grad_fn=<SoftmaxBackward>)