In [None]:
import torch
import torch.nn as nn
from torch.autograd import Variable

In [None]:
class DelayedRNN(nn.Module):
    """One residual layer made of three GRU stacks over a 4-D hidden grid.

    The layer updates three hidden states:

    * ``h_t`` -- shape ``(batch, n1, n2, num_hidden)``; processed by three GRUs:
      one scanning along axis 1, one scanning along axis 2, and one scanning
      along axis 2 in reverse.
    * ``h_c`` -- shape ``(batch, n1, num_hidden)``; processed by one GRU along
      axis 1.
    * ``h_f`` -- shape ``(batch, n1, n2, num_hidden)``; processed by one GRU
      along axis 2 after the updated ``h_t`` and ``h_c`` have been added in.

    In this notebook axis 1 holds spectrogram frames and axis 2 holds mel bins.

    Args:
        num_hidden: Size of the last dimension of every hidden state.
    """

    def __init__(self, num_hidden):
        super(DelayedRNN, self).__init__()

        self.t_delay_RNN_x = nn.GRU(input_size=num_hidden, hidden_size=num_hidden, batch_first=True)
        self.t_delay_RNN_y = nn.GRU(input_size=num_hidden, hidden_size=num_hidden, batch_first=True)
        self.t_delay_RNN_z = nn.GRU(input_size=num_hidden, hidden_size=num_hidden, batch_first=True)

        self.W_t = nn.Linear(3 * num_hidden, num_hidden)

        self.c_RNN = nn.GRU(input_size=num_hidden, hidden_size=num_hidden, batch_first=True)

        self.W_c = nn.Linear(num_hidden, num_hidden)

        self.f_delay_RNN = nn.GRU(input_size=num_hidden, hidden_size=num_hidden, batch_first=True)

        self.W_f = nn.Linear(num_hidden, num_hidden)

    @staticmethod
    def _scan_axis1(rnn, grid):
        """Run ``rnn`` along axis 1 of ``grid`` independently for every axis-2 index."""
        batch, n1, n2, hidden = grid.shape
        # Fold axis 2 into the batch so a single GRU call handles every slice;
        # sequences in a batch do not interact, so this equals a per-slice loop.
        flat = grid.transpose(1, 2).reshape(batch * n2, n1, hidden)
        out, _ = rnn(flat)
        return out.reshape(batch, n2, n1, out.shape[-1]).transpose(1, 2)

    @staticmethod
    def _scan_axis2(rnn, grid):
        """Run ``rnn`` along axis 2 of ``grid`` independently for every axis-1 index."""
        batch, n1, n2, hidden = grid.shape
        out, _ = rnn(grid.reshape(batch * n1, n2, hidden))
        return out.reshape(batch, n1, n2, out.shape[-1])

    def forward(self, input_h_t, input_h_f, input_h_c):
        """Apply the layer and return the residually-updated hidden states.

        Args:
            input_h_t: Tensor of shape ``(batch, n1, n2, num_hidden)``.
            input_h_f: Tensor of shape ``(batch, n1, n2, num_hidden)``.
            input_h_c: Tensor of shape ``(batch, n1, num_hidden)``.

        Returns:
            Tuple ``(output_h_t, output_h_f, output_h_c)`` with the same shapes
            as the corresponding inputs.
        """
        # No scratch buffers are allocated, so every intermediate inherits the
        # device and dtype of the inputs and stays on the autograd graph.
        h_t_x = self._scan_axis1(self.t_delay_RNN_x, input_h_t)
        h_t_y = self._scan_axis2(self.t_delay_RNN_y, input_h_t)
        # Reverse scan: flip axis 2, run the GRU, then flip the result back so
        # positions line up with the forward scan.
        h_t_z = self._scan_axis2(self.t_delay_RNN_z, input_h_t.flip(2)).flip(2)

        h_t_concat = torch.cat([h_t_x, h_t_y, h_t_z], 3)
        output_h_t = input_h_t + self.W_t(h_t_concat)

        h_c_rnn, _ = self.c_RNN(input_h_c)
        output_h_c = input_h_c + self.W_c(h_c_rnn)

        # unsqueeze(2) lets h_c broadcast over axis 2 whatever its length is.
        h_f_sum = input_h_f + output_h_t + output_h_c.unsqueeze(2)

        h_f_rnn = self._scan_axis2(self.f_delay_RNN, h_f_sum)
        output_h_f = input_h_f + self.W_f(h_f_rnn)

        return output_h_t, output_h_f, output_h_c

In [None]:
class MelNet(nn.Module):
    """Stack of ``DelayedRNN`` layers with a Gaussian-mixture parameter head.

    Args:
        num_hidden: Width of every hidden state.
        num_layer: Number of ``DelayedRNN`` layers to stack.
        K: Number of mixture components produced by ``mixture_params``.
        num_freq: Length of axis 2 of the input (the number of values fed to
            the ``h_c`` projection). Defaults to 32.
    """

    def __init__(self, num_hidden, num_layer, K, num_freq=32):
        super(MelNet, self).__init__()

        self.W_t_0 = nn.Linear(1, num_hidden)
        self.W_f_0 = nn.Linear(1, num_hidden)
        self.W_c_0 = nn.Linear(num_freq, num_hidden)

        # Layers must share the model width; the input projections above
        # produce num_hidden features.
        self.module_list = nn.ModuleList([DelayedRNN(num_hidden) for _ in range(num_layer)])

        self.W_theta = nn.Linear(num_hidden, 3 * K)
        self.pi_softmax = nn.Softmax(dim=3)
        self.K = K

    def forward(self, input_tensor):
        """Project the input and run it through every layer.

        Args:
            input_tensor: Tensor of shape ``(batch, n1, num_freq, 1)``.

        Returns:
            Tuple ``(h_t, h_f, h_c)`` of final hidden states with shapes
            ``(batch, n1, num_freq, num_hidden)``,
            ``(batch, n1, num_freq, num_hidden)`` and
            ``(batch, n1, num_hidden)``.
        """
        h_t = self.W_t_0(input_tensor)
        h_f = self.W_f_0(input_tensor)
        # Drop the trailing singleton so each axis-1 step is projected from
        # its whole axis-2 vector.
        h_c = self.W_c_0(input_tensor[:, :, :, 0])

        for layer in self.module_list:
            h_t, h_f, h_c = layer(h_t, h_f, h_c)

        return h_t, h_f, h_c

    def mixture_params(self, h_f):
        """Map ``h_f`` to Gaussian-mixture parameters.

        Args:
            h_f: Tensor of shape ``(batch, n1, n2, num_hidden)``, e.g. the
                second element returned by ``forward``.

        Returns:
            Tuple ``(mu, std, pi)``, each of shape ``(batch, n1, n2, K)``.
            ``std`` is strictly positive and ``pi`` sums to 1 over the last
            axis.
        """
        theta_hat = self.W_theta(h_f)
        K = self.K  # use the instance value, never a notebook-level global

        mu = theta_hat[:, :, :, :K]
        std = torch.exp(theta_hat[:, :, :, K:2 * K])
        pi = self.pi_softmax(theta_hat[:, :, :, 2 * K:])

        # These parameters describe p(x) = sum_k pi_k * Normal(x; mu_k, std_k);
        # a negative log-likelihood loss, -log p(x) summed over all positions,
        # is not implemented in this class.
        return mu, std, pi


In [None]:
# Build the model: 512 hidden units, 3 stacked layers, 10 mixture components.
net = MelNet(num_hidden=512, num_layer=3, K=10)

In [None]:
# NOTE(review): `parameters` is referenced here but not called, so this cell
# displays the bound method object instead of iterating over the model's
# parameters. Confirm whether `net` or `list(net.parameters())` was intended.
net.parameters

In [None]:
import librosa
import scipy
import numpy as np
import os
import glob
import matplotlib.pyplot as plt
import IPython.display as ipd

In [None]:
# --- STFT framing -----------------------------------------------------------
hop = 256          # samples between consecutive frames
nsc = 6 * hop      # samples per analysis window
nov = nsc - hop    # samples shared by neighbouring windows

# --- Spectrogram settings ---------------------------------------------------
n_mels = 256
fs = 44100 / 2

# --- Model settings ---------------------------------------------------------
num_hidden = 512
K = 10

# Mel filterbank matching the STFT settings above.
mel_filters = librosa.filters.mel(sr=fs, n_fft=nsc, n_mels=n_mels)

In [None]:
data_dir = '../data'
# glob returns paths in arbitrary, filesystem-dependent order; sorting makes
# "the first file" the same one on every run.
file_list = sorted(glob.glob(data_dir + '/*'))
if not file_list:
    raise FileNotFoundError('No files found in {!r}'.format(data_dir))

# Load at the configured sample rate so the audio always matches the rate the
# mel filterbank was built with, instead of relying on librosa's default.
y, fs = librosa.load(file_list[0], sr=int(fs))

f, t, Sxx = scipy.signal.stft(y, fs=fs, window='hann', nperseg=nsc, noverlap=nov)
# Sxx = Sxx[1:, :]
Zxx = np.abs(Sxx)
# Magnitudes are floored at 1e-8 (-160 dB), so (dB + 160) / 160 maps the floor
# to 0.
log_spectrogram = 20 * np.log10(np.maximum(Zxx, 1e-8))
log_spectrogram_norm = (log_spectrogram + 160) / 160

mel_spectrogram = np.matmul(mel_filters, Zxx)
log_mel_spectrogram = 20 * np.log10(np.maximum(mel_spectrogram, 1e-8))
mel_input = (log_mel_spectrogram + 160) / 160

# Multi-scale split: alternately take every other row (axis 0) and every other
# column (axis 1). At each step the even-indexed half becomes a tier and the
# odd-indexed half is split further.
Tier6 = mel_input[::2, :]
Tier6_not = mel_input[1::2, :]

Tier5 = Tier6_not[:, ::2]
Tier5_not = Tier6_not[:, 1::2]

Tier4 = Tier5_not[::2, :]
Tier4_not = Tier5_not[1::2, :]

Tier3 = Tier4_not[:, ::2]
Tier3_not = Tier4_not[:, 1::2]

Tier2 = Tier3_not[::2, :]
Tier1 = Tier3_not[1::2, :]

# Ordered from Tier1 (the last remainder) to Tier6 (the first split).
Tiers = [Tier1, Tier2, Tier3, Tier4, Tier5, Tier6]

In [None]:
# Transpose so axis 0 of Tier1 becomes the last data axis. torch.tensor copies
# the transposed view into a contiguous tensor. The dtype is pinned to float32
# because nn.Linear / nn.GRU parameters default to float32.
tensor = torch.tensor(Tier1.T, dtype=torch.float32)
# Add a leading batch axis and a trailing feature axis: (1, n1, n2, 1).
input_tensor = tensor.unsqueeze(0).unsqueeze(-1)

In [None]:
# Run the network on the prepared input and unpack the three hidden-state
# tensors returned by MelNet.forward (h_t, h_f, h_c).
_h_t, _h_f, _h_c = net(input_tensor)