In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import hyperparams as hp
import numpy as np
import math
import glu

In [30]:
class Encoder(nn.Module):
    """
    Encoder Network
    """
    def __init__(self, para):
        """
        :param para: dictionary that contains all parameters
        """
        super(Encoder, self).__init__()
        #self.alpha = nn.Parameter(t.ones(1))
        
        self.emb_phone = nn.Embedding(para['phone_size'], para['emb_dim'])
        #full connected
        self.fc_1 = nn.Linear(para['emb_dim'], para['GLU_in_dim'])
        
        self.GLU = glu.GLU(para['num_layers'], para['hidden_size'], para['kernel_size'], para['dropout'], para['GLU_in_dim'])
        
        self.fc_2 = nn.Linear(para['hidden_size'], para['emb_dim'])
        
    def refine(self, align_phone):
        '''filter silence phone and repeat phone'''
        out = []
        length = []
        batch_size = align_phone.shape[0]
        max_length = align_phone.shape[1]
        before = 0
        for i in range(batch_size):
            line = []
            for j in range(max_length):
                if align_phone[i][j] == 1 or align_phone[i][j] == 0:      #silence phone or padding
                    continue
                elif align_phone[i][j] == before:   #the same with the former phone
                    continue
                else:
                    before = align_phone[i][j]
                    line.append(before)
            out.append(line)
            length.append(len(line))
        
        #pad 0
        seq_length = max(length)
        Data = np.zeros((batch_size, seq_length))
        for i in range(batch_size):
            for j in range(seq_length):
                if j < len(out[i]):
                    Data[i][j] = out[i][j]
                    
        return torch.from_numpy(Data).type(torch.LongTensor)
        
    def forward(self, input):
        """
        input dim: [batch_size, text_phone_length]
        output dim : [batch_size, text_phone_length, embedded_dim]
        """
        input = self.refine(input)
        print(input)
        embedded_phone = self.emb_phone(input)    # [src len, batch size, emb dim]
        print(embedded_phone.shape,embedded_phone)
        glu_out = self.GLU(self.fc_1(embedded_phone))
        print(glu_out.shape)
        glu_out = self.fc_2(torch.transpose(glu_out, 1, 2))
        print(glu_out.shape,glu_out)
        out = embedded_phone + glu_out
        print(out.shape,out)
        out = out *  math.sqrt(0.5)
        print(out.shape,out)
        return out


In [33]:
para = {'phone_size':67, 'emb_dim':256, 'GLU_in_dim':64, 'num_layers':6, 'kernel_size':3, 'hidden_size':64, 'dropout':0.1 }
encoder = Encoder(para)
phone = torch.tensor([[1,3,3,3,3,5,5,6,0,0,0],[1,1,1,4,2,2,2,3,7,1,1]])
out = encoder(phone)
print(out.shape)

tensor([[3, 5, 6, 0],
        [4, 2, 3, 7]])
torch.Size([2, 4, 256]) tensor([[[-0.1634, -1.1551,  2.1703,  ...,  0.0700, -0.9398,  0.2773],
         [-0.0096,  0.2066, -1.1827,  ...,  0.0633, -0.4822,  1.9444],
         [-0.1730,  0.8219, -0.7822,  ...,  1.5981,  0.8638, -0.0104],
         [ 1.1558, -0.1630, -0.2812,  ..., -1.0534, -0.5543,  2.0829]],

        [[ 1.2973, -0.3188, -0.8178,  ..., -0.3757,  1.2727, -0.7920],
         [ 0.2669, -0.5398, -1.0456,  ...,  0.2583,  0.4187, -0.0991],
         [-0.1634, -1.1551,  2.1703,  ...,  0.0700, -0.9398,  0.2773],
         [-2.5658, -0.5674,  0.3152,  ..., -0.3762, -0.8581,  2.1603]]],
       grad_fn=<EmbeddingBackward>)
torch.Size([2, 64, 4])
torch.Size([2, 4, 256]) tensor([[[ 0.1461, -0.0676,  0.0320,  ..., -0.0243,  0.0367,  0.0234],
         [ 0.2043, -0.1260,  0.0625,  ..., -0.0467,  0.0961, -0.2177],
         [ 0.0364, -0.0187,  0.0945,  ..., -0.1370, -0.0689, -0.1879],
         [ 0.1482, -0.1088,  0.2525,  ..., -0.0798,  0.0931, -0