In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    inp = input_a + input_b
    t_act = torch.tanh(inp[:,:n_channels_int,:])
    s_act = torch.sigmoid(inp[:,n_channels_int:,:])
    acts = t_act * s_act
    return acts

In [54]:
class Invertible1x1Conv(nn.Module):
    def __init__(self, c):
        super(Invertible1x1Conv, self).__init__()
        self.conv = nn.Conv1d(c,c,1, bias=False)
        w = self.conv.weight.data
        nn.init.orthogonal_(w)
        if w.squeeze().det() < 0:
            w[0,:,:] = -w[0,:,:]
        
        self.conv.weight.data = w

    def forward(self, z, reverse=False):
        # z: [BS, n_group, L]
        W = self.conv.weight.data
        if reverse:
            W_inverse = W.squeeze().inverse().unsqueeze(-1)
            z = F.conv1d(z, W_inverse)
            return z
        else:
            log_det_W = W.squeeze().logdet()
            return self.conv(z), log_det_W


inv = Invertible1x1Conv(8)
inv(torch.randn(1,8,100), True).shape


torch.Size([1, 8, 100])

In [126]:
class WN(nn.Module):
    def __init__(self, n_in_channels, n_mel_channels, n_channels, n_layers, 
                kernel_size):
        super(WN, self).__init__()
        self.n_channels = n_channels
        self.n_layers = n_layers
        self.start = nn.Conv1d(n_in_channels, n_channels, 1)
        self.cond_layer = nn.Conv1d(n_mel_channels, 2*n_channels*n_layers,1)
        self.in_layers = nn.ModuleList()
        self.res_skip_layers = nn.ModuleList()
        self.end = nn.Conv1d(n_channels, 2*n_in_channels, 1)
        
        for i in range(n_layers):
            dil = 2**i
            padding = int(dil*(kernel_size-1)/2)
            in_layer = nn.Conv1d(n_channels, 2*n_channels, kernel_size,
                                dilation=dil, padding=padding)
            
            self.in_layers.append(in_layer)

            if i == n_layers - 1:
                res_skip_layer = nn.Conv1d(n_channels, n_channels, 1)
            else:
                res_skip_layer = nn.Conv1d(n_channels, 2*n_channels, 1)
            
            self.res_skip_layers.append(res_skip_layer)
    

    def forward(self, forward_input):
        audio, spect = forward_input
        # audio: [BS, half, L]  
        # spect: [BS, n_mel_channels, L]  
        audio = self.start(audio)
        spect = self.cond_layer(spect)
        output = torch.zeros_like(audio)
        for i in range(self.n_layers):
            inp = self.in_layers[i](audio)
            spect_offset = 2*i*self.n_channels
            new_spect = spect[:,spect_offset:spect_offset + 2*self.n_channels,:]
            n_channels_int = torch.IntTensor([self.n_channels])

            acts = fused_add_tanh_sigmoid_multiply(
                inp,
                new_spect,
                n_channels_int)

            res_acts = self.res_skip_layers[i](acts)
            
            

            if i == self.n_layers - 1:
                output = output + res_acts
            else:
                audio = audio + res_acts[:,:self.n_channels,:]
                output = output + res_acts[:,self.n_channels:,:]

        output = self.end(output)
        return output

    

wn = WN(4,80,256,8,3)
wn.infer(torch.randn(2, 80, 100)).shape

torch.Size([2, 8, 100])

In [202]:
class WaveGlow(nn.Module):
    def __init__(self, n_mel_channels, n_group, n_flows, n_channels,
                n_layers, kernel_size, n_early_every, n_early_size):
        super(WaveGlow, self).__init__()
        
        self.n_group = n_group
        self.n_flows = n_flows
        self.n_early_every = n_early_every
        self.n_early_size = n_early_size

        # upsample spect to audio scale
        self.upsample = nn.ConvTranspose1d(n_mel_channels, n_mel_channels,
                                           1024, stride=256)

        self.convinv = nn.ModuleList()
        self.WN = nn.ModuleList()

        remaining_channels = n_group
        half = int(remaining_channels/2)
        for k in range(n_flows):
                if k % n_early_every == 0 and k > 0:
                        remaining_channels = remaining_channels - n_early_size
                        half = int(remaining_channels/2)

                inv = Invertible1x1Conv(remaining_channels)
                wn = WN(half, n_mel_channels*n_group, n_channels, n_layers, 
                        kernel_size)

                self.convinv.append(inv)
                self.WN.append(wn)
        self.n_remaining_channels = remaining_channels  # Useful during inference
    def forward(self, forward_input):
        audio, spect = forward_input
        # audio: [BS, time]
        # spect: [BS, n_mel_channels, frames]

        # upsample spect to audio scale --->[BS, n_mel_channels, time]
        spect = self.upsample(spect)
        
        assert(spect.size(2) >= audio.size(1))
        if spect.size(2) > audio.size(1):
            spect = spect[:, :, :audio.size(1)]

        spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
        
        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)
        
        audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)
        
        output_audio = []
        log_det_W_list = []
        log_s_list = []

        for k in range(self.n_flows):
                if k % self.n_early_every == 0 and k > 0:
                        output_audio.append(audio[:,:self.n_early_size,:])
                        audio = audio[:,self.n_early_size:,:]
                
                audio, log_det_W = self.convinv[k](audio)
                log_det_W_list.append(log_det_W)

                # split audio for affine coupling layer
                audio_channels = audio.shape[1]
                audio_0 = audio[:,:int(audio_channels/2),:]
                audio_1 = audio[:,int(audio_channels/2):,:]

                wn_out = self.WN[k]((audio_0, spect))
                wn_out_channels = wn_out.shape[1]
                log_s = wn_out[:,:int(wn_out_channels/2),:]
                
                log_s_list.append(log_s)
                t = wn_out[:,int(wn_out_channels/2):,:]

                audio_1 = torch.exp(log_s)*audio_1 + t

                audio = torch.cat([audio_0, audio_1], dim=1)

        output_audio.append(audio)
        return torch.cat(output_audio, 1), log_s_list, log_det_W_list

    def infer(self, spect, sigma=1.0):
        spect = self.upsample(spect)
        time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
        spect = spect[:,:,:-time_cutoff]
        spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)

        audio = torch.randn(spect.shape[0], self.n_remaining_channels,
                            spect.shape[2]).normal_()
        
        audio = sigma*audio

        for k in reversed(range(self.n_flows)):
                half = int(audio.shape[1]/2)
                audio_0 = audio[:,:half,:]
                audio_1 = audio[:,half:,:]

                
                wn_out = self.WN[k]((audio_0, spect))
                wn_channels = wn_out.shape[1]
                log_s = wn_out[:,:int(wn_channels/2),:]
                t = wn_out[:,int(wn_channels/2):,:]
                x0 = audio_0
                x1 = (audio_1 - t)/torch.exp(log_s)
                print(x0.shape, x1.shape)
                inv_inp = torch.cat([x0, x1], dim=1)
                
                audio = self.convinv[k](inv_inp, reverse=True)

                if k % self.n_early_every and k > 0:
                        z = torch.randn(spect.shape[0], self.n_early_size, spect.shape[2]).normal_()
                        audio = torch.cat([sigma*z, audio], 1)


                

        return audio

model = WaveGlow(80, 8, 12, 256, 8, 3, 4, 2)

audio = torch.randn(1, 10000)
spect = torch.randn(1, 80, 37)

model.infer(spect).shape

torch.Size([1, 2, 1184]) torch.Size([1, 2, 1184])


RuntimeError: Given groups=1, weight of size [256, 2, 1], expected input[1, 3, 1184] to have 2 channels, but got 3 channels instead

In [178]:
class WN2(torch.nn.Module):
    """
    This is the WaveNet like layer for the affine coupling.  The primary difference
    from WaveNet is the convolutions need not be causal.  There is also no dilation
    size reset.  The dilation only doubles on each layer
    """
    def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
                 kernel_size):
        super(WN2, self).__init__()
        assert(kernel_size % 2 == 1)
        assert(n_channels % 2 == 0)
        self.n_layers = n_layers
        self.n_channels = n_channels
        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()

        start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
        start = torch.nn.utils.weight_norm(start, name='weight')
        self.start = start

        # Initializing last layer to 0 makes the affine coupling layers
        # do nothing at first.  This helps with training stability
        end = torch.nn.Conv1d(n_channels, 2*n_in_channels, 1)
        end.weight.data.zero_()
        end.bias.data.zero_()
        self.end = end

        cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels*n_layers, 1)
        self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')

        for i in range(n_layers):
            dilation = 2 ** i
            padding = int((kernel_size*dilation - dilation)/2)
            in_layer = torch.nn.Conv1d(n_channels, 2*n_channels, kernel_size,
                                       dilation=dilation, padding=padding)
            in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
            self.in_layers.append(in_layer)


            # last one is not necessary
            if i < n_layers - 1:
                res_skip_channels = 2*n_channels
            else:
                res_skip_channels = n_channels
            res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, forward_input):
        audio, spect = forward_input
        audio = self.start(audio)
        output = torch.zeros_like(audio)
        n_channels_tensor = torch.IntTensor([self.n_channels])

        
        spect = self.cond_layer(spect)

        for i in range(self.n_layers):
            spect_offset = i*2*self.n_channels
            acts = fused_add_tanh_sigmoid_multiply(
                self.in_layers[i](audio),
                spect[:,spect_offset:spect_offset+2*self.n_channels,:],
                n_channels_tensor)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                audio = audio + res_skip_acts[:,:self.n_channels,:]
                output = output + res_skip_acts[:,self.n_channels:,:]
            else:
                output = output + res_skip_acts

        return self.end(output)


class WaveGlow2(torch.nn.Module):
    def __init__(self, n_mel_channels, n_flows, n_group, n_early_every,
                 n_early_size, WN_config):
        super(WaveGlow2, self).__init__()

        self.upsample = torch.nn.ConvTranspose1d(n_mel_channels,
                                                 n_mel_channels,
                                                 1024, stride=256)
        assert(n_group % 2 == 0)
        self.n_flows = n_flows
        self.n_group = n_group
        self.n_early_every = n_early_every
        self.n_early_size = n_early_size
        self.WN2 = torch.nn.ModuleList()
        self.convinv = torch.nn.ModuleList()

        n_half = int(n_group/2)

        # Set up layers with the right sizes based on how many dimensions
        # have been output already
        n_remaining_channels = n_group
        for k in range(n_flows):
            if k % self.n_early_every == 0 and k > 0:
                n_half = n_half - int(self.n_early_size/2)
                n_remaining_channels = n_remaining_channels - self.n_early_size
            self.convinv.append(Invertible1x1Conv(n_remaining_channels))
            self.WN2.append(WN2(n_half, n_mel_channels*n_group, **WN_config))
        self.n_remaining_channels = n_remaining_channels  # Useful during inference

    def forward(self, forward_input):
        """
        forward_input[0] = mel_spectrogram:  batch x n_mel_channels x frames
        forward_input[1] = audio: batch x time
        """
        spect, audio = forward_input

        #  Upsample spectrogram to size of audio
        spect = self.upsample(spect)
        
        
        assert(spect.size(2) >= audio.size(1))
        if spect.size(2) > audio.size(1):
            spect = spect[:, :, :audio.size(1)]

        

        spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)

        audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)
        
        output_audio = []
        log_s_list = []
        log_det_W_list = []

        for k in range(self.n_flows):
            if k % self.n_early_every == 0 and k > 0:
                output_audio.append(audio[:,:self.n_early_size,:])
                audio = audio[:,self.n_early_size:,:]

            audio, log_det_W = self.convinv[k](audio)
            log_det_W_list.append(log_det_W)

            n_half = int(audio.size(1)/2)
            audio_0 = audio[:,:n_half,:]
            audio_1 = audio[:,n_half:,:]

            
            
            output = self.WN2[k]((audio_0, spect))
            log_s = output[:, n_half:, :]
            b = output[:, :n_half, :]
            audio_1 = torch.exp(log_s)*audio_1 + b
            log_s_list.append(log_s)

            audio = torch.cat([audio_0, audio_1],1)

        output_audio.append(audio)
        return torch.cat(output_audio,1), log_s_list, log_det_W_list


WN_config= {
            "n_layers": 8,
            "n_channels": 256,
            "kernel_size": 3
        }
wv2 = WaveGlow2(80, 12, 8, 4, 2, WN_config)

audio = torch.randn(1, 10000)
spect = torch.randn(1, 80, 37)
wv2((spect,audio))

torch.Size([1, 640, 1250])
torch.Size([1, 640, 1250])
torch.Size([1, 640, 1250])
torch.Size([1, 640, 1250])
torch.Size([1, 640, 1250])
torch.Size([1, 640, 1250])
torch.Size([1, 640, 1250])
torch.Size([1, 640, 1250])
torch.Size([1, 640, 1250])
torch.Size([1, 640, 1250])
torch.Size([1, 640, 1250])
torch.Size([1, 640, 1250])


(tensor([[[ 0.2107, -2.5223, -0.8095,  ..., -1.4856, -0.0048,  0.8866],
          [-0.0582, -1.2682,  0.7698,  ..., -0.9491, -1.0466,  0.5091],
          [-0.3005,  0.5528, -0.4263,  ..., -2.2444, -1.5032, -0.8932],
          ...,
          [-1.0784, -0.7570,  0.9992,  ...,  0.1646,  0.9503,  0.2673],
          [-1.2106, -0.0477, -0.0941,  ..., -0.9875, -2.1894, -0.0410],
          [ 0.8242,  0.7685,  0.0804,  ..., -0.5033, -0.4835,  0.7518]]],
        grad_fn=<CatBackward0>),
 [tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]], grad_fn=<SliceBackward0>),
  tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]], grad_fn=<SliceBackward0>),
  tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 

In [159]:
(10000-1024)/256+1

36.0625

In [111]:
class Mel2Samp(Dataset):
    """
    This is the main class that calculates the spectrogram and returns the
    spectrogram, audio pair.
    """
    def __init__(self, training_files, segment_length, filter_length=1024, 
                 hop_length=256, win_length=1024, n_mel_channels=80,
                 sampling_rate=22050, mel_fmin=0.0, mel_fmax=8000.0):

        super(Mel2Samp, self).__init__()
        self.audio_files = files_to_list(training_files)
        self.stft = TacotronSTFT(filter_length, hop_length, win_length,
                n_mel_channels, sampling_rate, mel_fmin, mel_fmax)
        self.segment_length = segment_length

    def get_mel(self, audio):
        audio_norm = audio/MAX_WAV_VALUE
        audio_norm = audio_norm.unsqueeze(0)
        melspec = self.stft.mel_spectrogram(audio_norm)
        melspec = melspec.squeeze(0)
        return melspec


        
    def __getitem__(self, index):
        audio_file = self.audio_files[index]
        audio, sr = load_wav_to_torch(audio_file)
        
        if audio.size(0) >= self.segment_length:
            max_audio_start = audio.size(0) - self.segment_length
            audio_start = torch.randint(0, max_audio_start, [1])
            audio = audio[audio_start:audio_start + self.segment_length]
        else:
            audio = F.pad(audio, (0, self.segment_length - audio.size(0)))

        mel = self.get_mel(audio)
        audio = audio/MAX_WAV_VALUE
        return (mel, audio)
    

    def __len__(self):
        return len(self.audio_files)

In [112]:
dataset = Mel2Samp(r"C:\Users\Codefactory\Documents\babis\Thesis\Waveglow\test.txt", 16000)

In [120]:
loader = DataLoader(dataset, batch_size=1)

mel, audio = next(iter(loader))

mel.shape, audio.shape

(torch.Size([1, 80, 63]), torch.Size([1, 16000]))

In [121]:
model((audio, mel))

AssertionError: 