In [229]:
import torch
import torch.nn as nn
import torch.nn.functional as F
x = torch.randn(1,1000)
un = x.unfold(1,8,8).permute(0,2,1)
un.shape

torch.Size([1, 8, 125])

In [237]:
upsample = nn.ConvTranspose1d(80, 80, 1024, 256)
spec = torch.randn(1,80,10)
upsample(spec).unfold(2,8,8).shape

torch.Size([1, 80, 416, 8])

In [236]:
assert(spec.size(2) >= x.size(1))
if spec.size(2) >= x.size(1):
    spec = spec[:,:,:x.size(1)]

spec = spec.unfold(2,8,8).permute(0,2,1,3)
spec = spec.contiguous().view(spec.size(0), spec.size(1), -1).permute(0,2,1)
spec.shape


AssertionError: 

In [234]:
class Invertible1x1Conv(nn.Module):
    def __init__(self, c):
        super(Invertible1x1Conv, self).__init__()
        self.conv = nn.Conv1d(c, c, 1, bias=False)
        w = self.conv.weight.data.squeeze()
        w = nn.init.orthogonal_(torch.ones_like(w))

        if w.det() < 0:
            w[:,0] = -w[:,0]

        self.conv.weight.data = w.unsqueeze(-1)
        

    def forward(self, z, reverse=False):
        BS, group_size, n_of_groups = z.size()
        W = self.conv.weight.squeeze()
        if reverse:
            W_inverse = W.float().inverse().unsqueeze(-1)
            z = F.conv1d(z, W_inverse)
            return z
        else:
            z = self.conv(z)
            log_det_W = W.logdet()#BS*n_of_groups*
            return z, log_det_W

inv = Invertible1x1Conv(8)
inv(torch.randn(1,8,125), True).shape

torch.Size([1, 8, 125])

In [267]:
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:,:n_channels_int,:])
    s_act = torch.sigmoid(in_act[:,n_channels_int:,:])
    acts = t_act + s_act
    return acts

class WN(nn.Module):
    def __init__(self, n_in_channels, n_mel_channels, n_layers,
                 n_channels, kernel_size):
        super(WN, self).__init__()
        assert(kernel_size % 2 == 1)
        assert(n_channels % 2 == 0)
        self.n_layers = n_layers
        self.n_channels = n_channels
        self.in_layers = nn.ModuleList()
        self.res_skip_layers = nn.ModuleList()

        start = nn.Conv1d(n_in_channels, n_channels, 1)
        start = nn.utils.weight_norm(start, name='weight')
        self.start = start

        end = nn.Conv1d(n_channels, 2*n_in_channels, 1)
        end.weight.data.zero_
        end.bias.data.zero_
        self.end = end

        cond_layer = nn.Conv1d(n_mel_channels, 2*n_mel_channels*n_layers, 1)
        self.cond_layer = nn.utils.weight_norm(cond_layer, name='weight')

        for i in range(n_layers):
            dilation = 2**i
            padding = int(dilation*(kernel_size-1)/2)
            in_layer = nn.Conv1d(n_channels, 2*n_channels, kernel_size,
                                    dilation=dilation, padding=padding)
            in_layer = nn.utils.weight_norm(in_layer, name='weight')
            self.in_layers.append(in_layer)

            if i < n_layers - 1:
                res_skip_channels = 2*n_channels
            else:
                res_skip_channels = n_channels
            
            res_skip_layer = nn.Conv1d(n_channels, res_skip_channels, 1)
            res_skip_layer = nn.utils.weight_norm(res_skip_layer, name='weight')
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, forward_input):
        audio, spect = forward_input
        audio = self.start(audio)
        output = torch.zeros_like(audio)
        n_channels_tensor = torch.IntTensor([self.n_channels])

        spect = self.cond_layer(spect)

        for i in range(self.n_layers):
            spect_offset = i*2*self.n_channels
            acts = fused_add_tanh_sigmoid_multiply(
                self.in_layers[i](audio),
                spect[:,spect_offset:spect_offset+2*self.n_channels,:],
                n_channels_tensor)
            
            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                audio = audio + res_skip_acts[:,:self.n_channels,:]
                output = output + res_skip_acts[:, self.n_channels:,:]
            else:
                output = output + res_skip_acts

        return self.end(output)

inv = Invertible1x1Conv(8)
x = torch.randn(1,8,125)
sp = torch.randn(1,80,125)
wn = WN(4,80,8,256, 3)
        


In [263]:
class WaveGlow(nn.Module):
    def __init__(self, n_mel_channels, n_flows, n_group, n_early_every,
                n_early_size, WN_config):
        super(WaveGlow, self).__init__()

        self.upsample = nn.ConvTranspose1d(n_mel_channels,
                                           n_mel_channels,
                                           1024, stride=256)
        assert(n_group % 2 == 0)
        self.n_flows = n_flows
        self.n_group = n_group
        self.n_early_every = n_early_every
        self.n_early_size = n_early_size
        self.WN = nn.ModuleList()
        self.convinv = nn.ModuleList()

        n_half = int(n_group/2)
        n_remaining_channels = n_group
        for k in range(n_flows):
            if k % self.n_early_every == 0 and k > 0:
                n_half = n_half - int(self.n_early_size/2)
                n_remaining_channels = n_remaining_channels - self.n_early_size
            self.convinv.append(Invertible1x1Conv(n_remaining_channels))
            self.WN.append(WN(n_half, n_mel_channels*n_group, **WN_config))
        self.n_remaining_channels = n_remaining_channels

    def forward(self, forward_input):
        """
        forward_input[0] = mel_spectrogram:  batch x n_mel_channels x frames
        forward_input[1] = audio: batch x time
        """

        spect, audio = forward_input
        spect = self.upsample(spect)
        assert(spect.size(2)>=audio.size(1))
        if spect.size(2) > audio.size(1):
            spect = spect[:,:,:audio.size(1)]
        spect = spect.unfold(2, self.n_group, self.n_group).permute(0,2,1,3)
        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0,2,1)
        audio = audio.unfold(1, self.n_group, self.n_group).permute(0,2,1)

        output_audio = []
        log_s_list = []
        log_det_W_list = []

        print(audio.shape, spect.shape)
        for k in range(self.n_flows):
            if k % self.n_early_every == 0 and k > 0:
                output_audio.append(audio[:, :self.n_early_size, :])
                audio = audio[:,self.n_early_size:,:]
            
            audio, log_det_W = self.convinv[k](audio)
            log_det_W_list.append(log_det_W)

            n_half = int(audio.size(1)/2)
            audio_0 = audio[:,:n_half,:]
            audio_1 = audio[:,n_half:,:]

            output = self.WN[k]((audio_0, spect))
            log_s = output[:,n_half:,:]
            b = output[:,:n_half,:]
            audio_1 = torch.exp(log_s)*audio_1 + b
            log_s_list.append(log_s)

            audio = torch.cat([audio_0, audio_1], 1)

        output_audio.append(audio)
        return torch.cat(output_audio, 1), log_s_list, log_det_W_list
        

WN_config =  {"n_layers": 8,
             "n_channels": 256,
             "kernel_size": 3}
            
        
waveglow = WaveGlow(80, 12, 8, 4, 2, WN_config)

audio = torch.randn(1, 10000)
spect = torch.randn(1,80, 100)
waveglow.upsample(spect).shape
#waveglow(( spect, audio))

audio_0 = inv(audio.reshape(1,8,1250))[0][:,:4,:]
audio_1 = inv(audio.reshape(1,8,1250))[0][:,4:,:]

In [268]:
wn.cond_layer(waveglow.upsample(spect)).shape

torch.Size([1, 1280, 26368])

In [271]:
wn.in_layers[0](audio_0).shape

RuntimeError: Given groups=1, weight of size [512, 256, 3], expected input[1, 4, 1250] to have 256 channels, but got 4 channels instead