In [1]:
import os
import math
import random
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [2]:
!ls 'data/sample/'

sample-000000.txt  sample-000001.wav  sample-000003.txt  sample-000004.wav
sample-000000.wav  sample-000002.txt  sample-000003.wav  sample-000005.txt
sample-000001.txt  sample-000002.wav  sample-000004.txt  sample-000005.wav


In [3]:
SEED = 1234

random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [19]:
class Encoder(nn.Module):
    def __init__(self, seq_len, input_size, enc_hid_dim, dec_hid_dim, dropout_rate):
        super().__init__()
        self.seq_len = seq_len
        self.input_size = input_size
        self.enc_hid_dim = enc_hid_dim
        self.dec_hid_dim = dec_hid_dim
        self.dropout_rate = dropout_rate
        
        self._input_size = (enc_hid_dim * 2) + enc_hid_dim
        
        self.bi_gru1 = nn.GRU(input_size=input_size, 
                                              hidden_size=enc_hid_dim, bidirectional=True)
        self.bi_gru2 = nn.GRU(input_size=self._input_size, 
                              hidden_size=enc_hid_dim, bidirectional=True)
        self.bi_gru3 = nn.GRU(input_size=self._input_size, 
                              hidden_size=enc_hid_dim, bidirectional=True)
        self.bi_gru_stack = nn.GRU(input_size=self._input_size, 
                              hidden_size=enc_hid_dim, bidirectional=True,
                              num_layers=3)
        self.pool =  nn.MaxPool2d(kernel_size=(2,1), stride=(2,1)) 
        
    def forward(self, src, init_hidden=None):
        # src = [seq_len, batch size, input size]
        
        ##### Bidirection layer 1 #########
        bi_gru1_out, bi_gru1_hid_st =  self.bi_gru1(src)
        # bi_gru1_out = [seq_len, batch size, enc_hid_dim * 2]
        # bi_gru1_hid_st = [n_direction, batch_size, enc_hid_dim]
        bi_gru1_cat_out = torch.cat((bi_gru1_out, 
                                     bi_gru1_hid_st.repeat(int(bi_gru1_out.size(0)/2), 1, 1)), dim=2)
        # bi_gru1_cat_out = [seq_len, batch size, (enc_hid_dim * 2) + enc_hid_dim]
        bi_gru1_cat_out = bi_gru1_cat_out.permute(1, 0, 2)
        # bi_gru1_cat_out = [batch size, seq_len, (enc_hid_dim * 2) + enc_hid_dim]
        
        bi_gru1_cat_out = bi_gru1_cat_out.unsqueeze(1)
        # bi_gru1_cat_out = [batch size, 1, seq_len, (enc_hid_dim * 2) + enc_hid_dim]
        
        ######### Polling 1 #######
        pool_out1 = self.pool(bi_gru1_cat_out)
        # pool_out = [batch size, 1, seq_len / 2, (enc_hid_dim * 2) + enc_hid_dim]
        pool_out1 = pool_out1.squeeze(1)
        # pool_out = [batch size, seq_len / 2, ((enc_hid_dim * 2) + enc_hid_dim)]
        
        pool_out1 = pool_out1.permute(1, 0, 2)
        # pool_out = [seq_len / 2, batch size, (enc_hid_dim * 2) + enc_hid_dim]
        
        #### Bidirection Layer 2 ##########
        bi_gru2_out, bi_gru2_hid_st =  self.bi_gru2(pool_out1)
        # bi_gru2_out = [seq_len/2, batch size, enc_hid_dim * 2]
        # bi_gru2_hid_st = [n_direction, batch size, enc_hid_dim]
        
        bi_gru2_cat_out = torch.cat((bi_gru2_out, 
                                    bi_gru2_hid_st.repeat(int(bi_gru2_out.size(0)/2), 1, 1)), dim=2)
        # bi_gru2_cat_out = [seq_len/2, batch size, (enc_hid_dim * 2) + enc_hid_dim]
        
        bi_gru2_cat_out = bi_gru2_cat_out.permute(1, 0, 2).unsqueeze(1)
        # bi_gru2_cat_out = [batch size, 1, seq_len/2, (enc_hid_dim * 2) + enc_hid_dim]
        
        #### Pooling 2 ##########
        pool_out2 = self.pool(bi_gru2_cat_out).squeeze(1)
        # pool_out2 = [batch size, seq_len / 4, (enc_hid_dim * 2) + enc_hid_dim]
        
        pool_out2 = pool_out2.permute(1, 0, 2)
        # pool_out2 = [seq_len / 4, batch size, (enc_hid_dim * 2) + enc_hid_dim]
        
        #### Bidirection Layer 3 ##########
        bi_gru3_out, bi_gru3_hid_st =  self.bi_gru3(pool_out2)
        # bi_gru3_out = [seq_len/4, batch size, enc_hid_dim * 2]
        # bi_gru3_hid_st = [n_direction, batch size, enc_hid_dim]
        
        bi_gru3_cat_out = torch.cat((bi_gru3_out, 
                                    bi_gru3_hid_st.repeat(int(bi_gru3_out.size(0)/2), 1, 1)), dim=2)
        # bi_gru3_cat_out = [seq_len/4, batch size, (enc_hid_dim * 2) + enc_hid_dim]
        
        bi_gru3_cat_out = bi_gru3_cat_out.permute(1, 0, 2).unsqueeze(1)
        # bi_gru3_cat_out = [batch size, 1, seq_len/4, (enc_hid_dim * 2) + enc_hid_dim]
        
        #### Pooling 3 ##########
        pool_out3 = self.pool(bi_gru3_cat_out).squeeze(1)
        # pool_out3 = [batch size, seq_len / 6, (enc_hid_dim * 2) + enc_hid_dim]
        
        pool_out3 = pool_out3.permute(1, 0, 2)
        # pool_out3 = [seq_len / 6, batch size, (enc_hid_dim * 2) + enc_hid_dim]
        
        ### Bidirectional Stack ######
        bi_gru_stack_out, bi_gru_stack_hid_st =  self.bi_gru_stack(pool_out3)
        # bi_gru_stack_out = [seq_len/6, batch size, enc_hid_dim * 2]
        # bi_gru_stack_hid_st = [n_direction, batch size, enc_hid_dim]
        
        zeros = torch.zeros(bi_gru_stack_out.size(0)%bi_gru_stack_hid_st.size(0), 1, self.enc_hid_dim)
        bi_gru_stack_hid_st = bi_gru_stack_hid_st.repeat(
                                        int(bi_gru_stack_out.size(0)/bi_gru_stack_hid_st.size(0)), 1, 1)
        extended_hid_st = torch.cat((bi_gru_stack_hid_st, zeros), dim=0)
        bi_gru_stack_cat_out = torch.cat((bi_gru_stack_out, extended_hid_st), dim=2)
        # bi_gru_stack_cat_out = [seq_len/6, batch size, (enc_hid_dim * 2) + enc_hid_dim]
        
        bi_gru_stack_cat_out = bi_gru_stack_cat_out.permute(1, 0, 2).unsqueeze(1)
        # bi_gru_stack_cat_out = [batch size, 1, seq_len/6, (enc_hid_dim * 2) + enc_hid_dim]
        
        return bi_gru_stack_cat_out

### Check Encoder

In [20]:
import librosa
import warnings
warnings.filterwarnings("ignore")

y, sr = librosa.load('data/sample/sample-000003.wav', sr=16000)
mel_spectrogram = librosa.feature.melspectrogram(y, sr)
mel_spectrogram.shape, sr

((128, 241), 16000)

In [21]:
src = torch.from_numpy(mel_spectrogram.reshape(128, 1, -1)).float() #.reshape(129, 1, 227)
src.size()

torch.Size([128, 1, 241])

In [22]:
SEQ_LEN = src.size(0)
INPUT_SIZE = src.size(2)
ENC_HID_DIM = 256
DEC_HID_DIM = 256 
DROPOUT_RATE = 0.2

encoder = Encoder(SEQ_LEN, INPUT_SIZE, ENC_HID_DIM, DEC_HID_DIM, DROPOUT_RATE)
out = encoder(src)

In [23]:
out.size()#, hid.size()

torch.Size([1, 1, 16, 768])

In [24]:
encoder

Encoder(
  (bi_gru1): GRU(241, 256, bidirectional=True)
  (bi_gru2): GRU(768, 256, bidirectional=True)
  (bi_gru3): GRU(768, 256, bidirectional=True)
  (bi_gru_stack): GRU(768, 256, num_layers=3, bidirectional=True)
  (pool): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)
)