In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm #progress bar!!

In [None]:
#adapted from model/rnn.py

class DelayedRNN(nn.Module):
    """One MelNet layer: time-delayed, centralized and frequency-delayed stacks.

    Each stack runs an LSTM over its hidden states. The (linearly projected)
    LSTM output is added back to the stack's input as a residual connection.

    Args:
        hp: hyperparameter object. ``hp.model.hidden`` is used as the hidden
            size D of every LSTM and linear projection in this layer.
    """

    def __init__(self, hp):
        super(DelayedRNN, self).__init__()
        # NOTE: despite the name, this is the hidden *size* D, not a layer count.
        self.num_hidden_layers = hp.model.hidden

        # runs along the time axis of every mel row
        self.t_delay_RNN_x = nn.LSTM(
            input_size=self.num_hidden_layers,
            hidden_size=self.num_hidden_layers,
            batch_first=True
        )
        # bidirectional LSTM run along the mel axis of every frame (output 2*D)
        self.t_delay_RNN_yz = nn.LSTM(
            input_size=self.num_hidden_layers,
            hidden_size=self.num_hidden_layers,
            batch_first=True,
            bidirectional=True
        )

        # use central stack only at initial tier
        self.c_RNN = nn.LSTM(
            input_size=self.num_hidden_layers,
            hidden_size=self.num_hidden_layers,
            batch_first=True
        )
        self.f_delay_RNN = nn.LSTM(
            input_size=self.num_hidden_layers,
            hidden_size=self.num_hidden_layers,
            batch_first=True
        )

        # 3*D input: D from t_delay_RNN_x concatenated with 2*D from t_delay_RNN_yz
        self.W_t = nn.Linear(3*self.num_hidden_layers, self.num_hidden_layers)
        self.W_c = nn.Linear(self.num_hidden_layers, self.num_hidden_layers)
        self.W_f = nn.Linear(self.num_hidden_layers, self.num_hidden_layers)

    def flatten_rnn(self):
        """Re-compact every LSTM's weights so cuDNN can use its fast path.

        ``flatten_parameters`` is a no-op when the module is not on a
        cuDNN-enabled GPU.
        """
        self.t_delay_RNN_x.flatten_parameters()
        self.t_delay_RNN_yz.flatten_parameters()
        self.c_RNN.flatten_parameters()
        self.f_delay_RNN.flatten_parameters()

    def forward(self, input_h_t, input_h_f, input_h_c, audio_lengths):
        """Apply one layer of the three MelNet stacks.

        Args:
            input_h_t: time-delayed stack state, [B, M, T, D].
            input_h_f: frequency-delayed stack state, [B, M, T, D].
            input_h_c: centralized stack state, [B, T, D].
            audio_lengths: 1-D tensor with one valid length along T per batch
                item. It is passed to ``pack_padded_sequence``, which requires
                a CPU tensor.

        Returns:
            (output_h_t, output_h_f, output_h_c), with the same shapes as the
            corresponding inputs.
        """
        self.flatten_rnn()
        B, M, T, D = input_h_t.size()

        ####### time-delayed stack #######
        # Fig. 2(a)-1 can be parallelized by viewing each horizontal line as batch
        h_t_x_temp = input_h_t.view(-1, T, D)
        # every mel row of a batch item shares that item's length
        h_t_x_packed = nn.utils.rnn.pack_padded_sequence(
            h_t_x_temp,
            audio_lengths.unsqueeze(1).repeat(1, M).reshape(-1),
            batch_first=True,
            enforce_sorted=False
        )
        h_t_x, _ = self.t_delay_RNN_x(h_t_x_packed)
        h_t_x, _ = nn.utils.rnn.pad_packed_sequence(
            h_t_x,
            batch_first=True,
            total_length=T
        )
        h_t_x = h_t_x.view(B, M, T, D)

        # Fig. 2(a)-2,3 can be parallelized by viewing each vertical line as batch,
        # using the bi-directional LSTM
        h_t_yz_temp = input_h_t.transpose(1, 2).contiguous() # [B, T, M, D]
        h_t_yz_temp = h_t_yz_temp.view(-1, M, D)
        h_t_yz, _ = self.t_delay_RNN_yz(h_t_yz_temp)
        h_t_yz = h_t_yz.view(B, T, M, 2*D)
        h_t_yz = h_t_yz.transpose(1, 2) # [B, M, T, 2*D]

        h_t_concat = torch.cat((h_t_x, h_t_yz), dim=3)
        output_h_t = input_h_t + self.W_t(h_t_concat) # residual connection, eq. (6)

        ####### centralized stack #######
        h_c_temp = nn.utils.rnn.pack_padded_sequence(
            input_h_c,
            audio_lengths,
            batch_first=True,
            enforce_sorted=False
        )
        h_c_temp, _ = self.c_RNN(h_c_temp)
        h_c_temp, _ = nn.utils.rnn.pad_packed_sequence(
            h_c_temp,
            batch_first=True,
            total_length=T
        )

        output_h_c = input_h_c + self.W_c(h_c_temp) # residual connection, eq. (11)
        # [B, 1, T, D] so it broadcasts over the mel axis below
        h_c_expanded = output_h_c.unsqueeze(1)

        ####### frequency-delayed stack #######
        h_f_sum = input_h_f + output_h_t + h_c_expanded
        h_f_sum = h_f_sum.transpose(1, 2).contiguous() # [B, T, M, D]
        h_f_sum = h_f_sum.view(-1, M, D)

        h_f_temp, _ = self.f_delay_RNN(h_f_sum)
        h_f_temp = h_f_temp.view(B, T, M, D)
        h_f_temp = h_f_temp.transpose(1, 2) # [B, M, T, D]

        output_h_f = input_h_f + self.W_f(h_f_temp) # residual connection, eq. (8)

        return output_h_t, output_h_f, output_h_c

In [None]:
#adapted from model/tier.py

class Tier(nn.Module):
    """One MelNet tier: a stack of RNN layers followed by a GMM output head.

    Tier 1 stacks ``DelayedRNN`` layers over time, frequency and central
    streams. Later tiers stack ``UpsampleRNN`` layers (defined elsewhere)
    over a single stream.

    Args:
        hp: hyperparameters. Reads ``hp.model.hidden`` (hidden size D) and
            ``hp.model.gmm`` (number of mixture components K).
        freq: number of mel bins M in this tier's input. It is only used for
            the tier-1 central-stack projection.
        layers: number of RNN layers stacked in this tier.
        tierN: tier index (1 = initial tier).
    """

    def __init__(self, hp, freq, layers, tierN):
        super(Tier, self).__init__()
        num_hidden = hp.model.hidden
        self.hp = hp
        self.tierN = tierN

        if tierN == 1:
            # input projections for the time, frequency and central streams
            # (randomly initialised by nn.Linear)
            self.W_t_0 = nn.Linear(1, num_hidden)
            self.W_f_0 = nn.Linear(1, num_hidden)
            self.W_c_0 = nn.Linear(freq, num_hidden)
            self.layers = nn.ModuleList([
                DelayedRNN(hp) for _ in range(layers)
            ])
        else:
            # later tiers project the input into a single hidden stream
            self.W_t = nn.Linear(1, num_hidden)
            self.layers = nn.ModuleList([
                UpsampleRNN(hp) for _ in range(layers)
            ])

        # Gaussian Mixture Model (GMM) head with K components
        self.K = hp.model.gmm
        # normalises the mixture logits over axis 3 (the last axis of a
        # [B, M, T, K] tensor; assumes h_f stays 4-D at every tier)
        self.pi_softmax = nn.Softmax(dim=3)

        # maps each hidden state to 3*K raw values: mu, log-std, pi logits
        self.W_theta = nn.Linear(num_hidden, 3*self.K)

    def forward(self, x, audio_lengths):
        """Run the tier and return GMM parameters for every (mel, time) cell.

        Args:
            x: input spectrogram, [B, M, T] (B = batch, M = mel bins, T = frames).
            audio_lengths: per-item valid lengths along T, forwarded to every
                layer.

        Returns:
            (mu, std, pi), each shaped like ``h_f`` with last axis K
            ([B, M, T, K] at tier 1):
            - mu: component means.
            - std: strictly positive standard deviations.
            - pi: mixture weights that sum to 1 over the last axis.
        """
        if self.tierN == 1:
            # Shift x by one step (prepend a zero, drop the last entry):
            # along time for the time/central streams and along mel for the
            # frequency stream. Position i therefore receives the value at i-1.
            h_t = self.W_t_0(F.pad(x, [1, -1]).unsqueeze(-1))
            h_f = self.W_f_0(F.pad(x, [0, 0, 1, -1]).unsqueeze(-1))
            h_c = self.W_c_0(F.pad(x, [1, -1]).transpose(1, 2))
            # h_t, h_f: [B, M, T, D]; h_c: [B, T, D]
            for layer in self.layers:
                h_t, h_f, h_c = layer(h_t, h_f, h_c, audio_lengths)
        else:
            h_f = self.W_t(x.unsqueeze(-1))
            for layer in self.layers:
                h_f = layer(h_f, audio_lengths)

        theta_hat = self.W_theta(h_f)

        mu = theta_hat[..., :self.K]
        # exp keeps the standard deviation strictly positive
        std = torch.exp(theta_hat[..., self.K:2*self.K])
        # softmax turns the raw logits into mixture weights that sum to 1
        pi = self.pi_softmax(theta_hat[..., 2*self.K:])

        return mu, std, pi

In [None]:
# Division factors for the time (t_div) and frequency (f_div) axes, keyed by
# tier number. MelNet_model reads t_div[tier] and f_div[tier + 1].
t_div = {1:1, 2:1, 3:2, 4:2, 5:4, 6:4}
f_div = {1:1, 2:1, 3:2, 4:2, 5:4, 6:4, 7:8}


class TierUtils:
    """Placeholder for tier split/interleave helpers.

    TODO: this class was declared with no body (a SyntaxError that broke the
    whole cell); implement its helpers or remove it.
    """

In [None]:
#defining functions and classes for model to make things easier

class MelNet_model(nn.Module): #from model.py from Deepest-Project
    def __init__(self, hp, args, infer_hp):
        #should inherit all properties from all __init__ definitions
        super(MelNet, self).__init__()
        self.hp = hp
        self.args = args
        self.infer_hp = infer_hp
        self.f_div = f_div[hp.model.tier + 1]
        self.t_div = t_div[hp.model.tier]
        self.n_mels = hp.audio.n_mels
        
