In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import hyperparams as hp
import numpy as np
import math
import glu
import positional_encoding

In [2]:
class Encoder(nn.Module):
    """
    Encoder Network
    """
    def __init__(self, para):
        """
        :param para: dictionary that contains all parameters
        """
        super(Encoder, self).__init__()
        #self.alpha = nn.Parameter(t.ones(1))
        
        self.emb_phone = nn.Embedding(para['phone_size'], para['emb_dim'])
        #full connected
        self.fc_1 = nn.Linear(para['emb_dim'], para['GLU_in_dim'])
        
        self.GLU = glu.GLU(para['num_layers'], para['hidden_size'], para['kernel_size'], para['dropout'], para['GLU_in_dim'])
        
        self.fc_2 = nn.Linear(para['hidden_size'], para['emb_dim'])
        
    def refine(self, align_phone):
        '''filter silence phone and repeat phone'''
        out = []
        length = []
        batch_size = align_phone.shape[0]
        max_length = align_phone.shape[1]
        before = 0
        for i in range(batch_size):
            line = []
            for j in range(max_length):
                if align_phone[i][j] == 1 or align_phone[i][j] == 0:      #silence phone or padding
                    continue
                elif align_phone[i][j] == before:   #the same with the former phone
                    continue
                else:
                    before = align_phone[i][j]
                    line.append(before)
            out.append(line)
            length.append(len(line))
        
        #pad 0
        seq_length = max(length)
        Data = np.zeros((batch_size, seq_length))
        for i in range(batch_size):
            for j in range(seq_length):
                if j < len(out[i]):
                    Data[i][j] = out[i][j]
                    
        return torch.from_numpy(Data).type(torch.LongTensor)
        
    def forward(self, input):
        """
        input dim: [batch_size, text_phone_length]
        output dim : [batch_size, text_phone_length, embedded_dim]
        """
        input = self.refine(input)
        print(input)
        embedded_phone = self.emb_phone(input)    # [src len, batch size, emb dim]
        print(embedded_phone.shape,embedded_phone)
        glu_out = self.GLU(self.fc_1(embedded_phone))
        print(glu_out.shape)
        glu_out = self.fc_2(torch.transpose(glu_out, 1, 2))
        print(glu_out.shape,glu_out)
        out = embedded_phone + glu_out
        print(out.shape,out)
        out = out *  math.sqrt(0.5)
        print(out.shape,out)
        return out


In [3]:
class Encoder_Postnet(nn.Module):
    """
    Encoder Postnet
    """
    def __init__(self):
        super(Encoder_Postnet, self, seq_length).__init__()
        #length of sequence = number of frames
        self.fc = nn.Linear(seq_length, seq_length)
         
    def aligner(encoder_out, align_phone):
        return
        
    def forward(self, encoder_out, align_phone, pitch, beats):
        aligner_out = aligner(encoder_out, align_phone)
        pitch = self.fc(pitch)
        out = aligner_out + pitch
        beats_avg = len(beats) / sum(beats)
        return


In [4]:
para = {'phone_size':67, 'emb_dim':256, 'GLU_in_dim':64, 'num_layers':6, 'kernel_size':3, 'hidden_size':64, 'dropout':0.1 }
encoder = Encoder(para)
phone = torch.tensor([[1,3,3,3,3,5,5,6,0,0,0],[1,1,1,4,2,2,2,3,7,1,1]])
out = encoder(phone)
#print(out.shape,out)

tensor([[3, 5, 6, 0],
        [4, 2, 3, 7]])
torch.Size([2, 4, 256]) tensor([[[-0.1056,  0.4499,  1.0647,  ..., -0.1748,  0.2333, -0.7960],
         [ 1.5063, -0.0484, -1.4599,  ..., -1.3495, -0.3319, -0.4540],
         [ 0.1539, -0.7001,  1.2649,  ..., -0.5915,  0.1337,  2.0040],
         [ 0.6995,  0.0996,  0.2429,  ..., -0.4456, -0.3416,  0.9352]],

        [[-1.7274, -0.8727, -0.9020,  ...,  0.7828,  0.6503, -1.3442],
         [ 0.5065, -1.0624,  0.6139,  ..., -0.8265, -1.2438,  1.8009],
         [-0.1056,  0.4499,  1.0647,  ..., -0.1748,  0.2333, -0.7960],
         [-0.3482,  0.5405, -0.1930,  ..., -1.2947, -0.1482, -0.0392]]],
       grad_fn=<EmbeddingBackward>)
torch.Size([2, 64, 4])
torch.Size([2, 4, 256]) tensor([[[ 1.0397e-01,  9.0222e-02, -1.1796e-01,  ..., -2.1993e-01,
          -1.1756e-01, -1.4229e-01],
         [-8.7076e-02, -1.5004e-01,  5.8925e-02,  ...,  1.3462e-01,
          -9.9201e-03, -1.3423e-01],
         [-1.4504e-01, -4.9546e-02, -1.4552e-01,  ...,  1.4421e-01

In [11]:
beats = torch.tensor([[[0,1,0,0,0,1,0,0]],[[0,0,0,1,1,0,1,0]]])
pos = positional_encoding.PositionalEncoding(8)
out = pos(beats)
print(out)

tensor([[[0.0000, 2.2222, 0.0000, 0.0000, 0.0000, 2.2222, 0.0000, 1.1111]],

        [[0.9350, 0.0000, 0.1109, 2.2167, 1.1222, 1.1111, 1.1122, 1.1111]]])
