In [76]:
import os
import random
import argparse
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import sys
from scipy.io.wavfile import read
sys.path.insert(0, r"C:\Users\Codefactory\Documents\babis\Thesis\tacotron2")
from layers import TacotronSTFT

In [80]:
# Divisor applied to raw wav samples before mel extraction / training
# (32768 = 2**15, presumably the full scale of 16-bit PCM input).
MAX_WAV_VALUE = 32768.0

def files_to_list(filename):
    """
    Read a text file that lists one path per line and return the paths.

    The whole file content is stripped before splitting, so leading/trailing
    blank lines (e.g. a final newline) do not yield empty entries.
    """
    with open(filename, "r") as handle:
        contents = handle.read()
    return contents.strip().splitlines()


def load_wav_to_torch(full_path):
    """
    Read a wav file with scipy and return ``(samples, sampling_rate)``.

    ``samples`` is a float32 torch tensor holding the values exactly as scipy
    decoded them; no rescaling is done here (callers divide by MAX_WAV_VALUE).
    """
    sampling_rate, samples = read(full_path)
    return torch.tensor(samples).float(), sampling_rate


# Smoke test: load one sample clip and keep only the waveform tensor
# (element 0 of the (samples, sampling_rate) pair).
# NOTE(review): machine-specific absolute path — consider a config-cell constant.
a = load_wav_to_torch(
    full_path=r"C:\Users\Codefactory\Documents\babis\Tutorials\Deep Learning\Pytorch\waves_yesno\0_0_0_0_1_1_1_1.wav"
)[0]



torch.Size([60000])

In [82]:
class Mel2Samp(Dataset):
    """
    Dataset returning ``(mel_spectrogram, audio)`` pairs.

    Each item loads one wav listed in ``training_files``. It then crops a
    uniformly random ``segment_length``-sample window, or right-pads with
    zeros when the clip is shorter. It returns the mel spectrogram of that
    window and the window itself scaled by ``1 / MAX_WAV_VALUE``.
    """
    def __init__(self, training_files, segment_length, filter_length=1024,
                 hop_length=256, win_length=1024, n_mel_channels=80,
                 sampling_rate=22050, mel_fmin=0.0, mel_fmax=8000.0):

        super(Mel2Samp, self).__init__()
        self.audio_files = files_to_list(training_files)
        self.stft = TacotronSTFT(filter_length, hop_length, win_length,
                                 n_mel_channels, sampling_rate, mel_fmin, mel_fmax)
        self.segment_length = segment_length

    def get_mel(self, audio):
        """Mel spectrogram of a 1-D raw-sample tensor (batch dim added and removed here)."""
        audio_norm = audio / MAX_WAV_VALUE
        audio_norm = audio_norm.unsqueeze(0)
        melspec = self.stft.mel_spectrogram(audio_norm)
        melspec = melspec.squeeze(0)
        return melspec

    def __getitem__(self, index):
        audio_file = self.audio_files[index]
        # NOTE(review): `sr` is not compared with the sampling_rate the STFT was
        # built with; clips at another rate presumably yield mis-scaled mel bins.
        audio, sr = load_wav_to_torch(audio_file)

        if audio.size(0) >= self.segment_length:
            max_audio_start = audio.size(0) - self.segment_length
            # torch.randint's upper bound is exclusive and its size must be a
            # tuple: +1 makes the last valid start reachable and keeps the range
            # non-empty when the clip is exactly segment_length long.
            audio_start = torch.randint(0, max_audio_start + 1, (1,)).item()
            audio = audio[audio_start:audio_start + self.segment_length]
        else:
            audio = F.pad(audio, (0, self.segment_length - audio.size(0)))

        mel = self.get_mel(audio)
        audio = audio / MAX_WAV_VALUE
        return (mel, audio)

    def __len__(self):
        return len(self.audio_files)

# Build the dataset from the file list, cropping/padding every clip to 60000 samples.
# NOTE(review): machine-specific absolute path — consider a config-cell constant.
d = Mel2Samp(
    training_files=r"C:\Users\Codefactory\Documents\babis\Thesis\Waveglow\test.txt",
    segment_length=60000,
)

torch.Size([60000])

In [83]:
# Pull a single shuffled (mel, audio) batch to sanity-check the tensor shapes.
loader = DataLoader(d, batch_size=1, shuffle=True)
x, y = next(iter(loader))
print(x.shape, y.shape)

torch.Size([1, 80, 235]) torch.Size([1, 60000])


In [156]:
def fused_add_tanh_sigmoid(input_a, input_b, n_channels):
    """
    WaveNet-style gated activation on the sum of two tensors.

    Computes ``tanh(x[:, :n, :]) * sigmoid(x[:, n:, :])`` with
    ``x = input_a + input_b``.

    Args:
        input_a, input_b: tensors added elementwise. They are indexed as
            (batch, channels, time); the channel dim presumably holds 2*n.
        n_channels: indexable whose first element is n, e.g.
            ``torch.IntTensor([n])`` or ``[n]``.

    Returns:
        The gated tensor, with n channels when the input has 2*n.
    """
    n_channels_int = int(n_channels[0])
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    # (debug print of the shapes removed: it fired on every layer of every forward pass)
    return t_act * s_act




In [157]:
class Invertible1x1Conv(nn.Module):
    """
    Invertible 1x1 convolution (learned channel mixing).

    Forward: returns ``(conv(z), batch_size * n_of_groups * logdet(W))``.
    ``reverse=True``: returns only ``z`` convolved with ``W^-1``.
    """
    def __init__(self, c):
        super(Invertible1x1Conv, self).__init__()
        self.conv = nn.Conv1d(c, c, kernel_size=1, bias=False)

        # Random orthogonal initialisation: the Q factor of a Gaussian matrix.
        W = torch.qr(torch.FloatTensor(c, c).normal_())[0]

        # An orthogonal matrix has det = +/-1. Negating the whole matrix
        # multiplies det by (-1)**c, which does nothing for even c and leaves
        # logdet NaN. Negating a single column always flips the sign, so
        # det(W) ends up at +1.
        if torch.det(W) < 0:
            W[:, 0] = -W[:, 0]
        self.conv.weight.data = W.view(c, c, 1)

    def forward(self, z, reverse=False):
        # z is unpacked as (batch, channels, n_of_groups).
        batch_size, _, n_of_groups = z.shape
        # Drop only the kernel dim so a 1-channel layer still gives a 2-D matrix.
        W = self.conv.weight.squeeze(-1)
        if reverse:
            if not hasattr(self, 'W_inverse'):
                # Computed once and cached.
                # NOTE(review): the cache is not refreshed if the weights
                # change after the first reverse call.
                W_inverse = W.float().inverse().unsqueeze(-1)
                if z.dtype == torch.float16:
                    W_inverse = W_inverse.half()
                self.W_inverse = W_inverse
            # Always read the cached attribute: the local above only exists on
            # the first reverse call.
            return F.conv1d(z, self.W_inverse, bias=None)

        log_det_W = batch_size * n_of_groups * W.logdet()
        z = self.conv(z)
        return z, log_det_W


In [None]:
class WN(nn.Module):
    """
    WaveNet-like network used inside an affine coupling layer.

    Differences from WaveNet: the convolutions are not causal, and the
    dilation never resets; it doubles on every layer (1, 2, 4, ...).

    forward((audio, spec)) expects audio with n_in_channels channels and spec
    with n_mel_channels channels. Both are presumably (batch, channels, time)
    with the same time length — TODO confirm against the caller. The result
    has 2*n_in_channels channels.
    """
    def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
                 kernel_size):
        super(WN, self).__init__()
        assert kernel_size % 2 == 1
        assert n_channels % 2 == 0
        self.n_layers = n_layers
        self.n_channels = n_channels
        self.in_layers = nn.ModuleList()
        self.res_skip_layers = nn.ModuleList()

        start = nn.Conv1d(n_in_channels, n_channels, 1)
        start = nn.utils.weight_norm(start, name='weight')
        self.start = start

        # Zero weights and bias: the network outputs zeros at initialisation.
        end = nn.Conv1d(n_channels, 2*n_in_channels, 1)
        end.weight.data.zero_()
        end.bias.data.zero_()
        self.end = end

        # A single conv produces the 2*n_channels conditioning terms for every layer.
        cond_layer = nn.Conv1d(n_mel_channels, 2*n_channels*n_layers, 1)
        self.cond_layer = nn.utils.weight_norm(cond_layer, name='weight')

        for i in range(n_layers):
            # Doubling dilation. (2*i gave 0 for the first layer, which Conv1d rejects.)
            dilation = 2 ** i
            # "Same" padding for the odd kernel, so the time length is preserved.
            padding = dilation * (kernel_size - 1) // 2
            in_layer = nn.Conv1d(n_channels, 2*n_channels, kernel_size,
                                 padding=padding, dilation=dilation)
            in_layer = nn.utils.weight_norm(in_layer, name='weight')
            self.in_layers.append(in_layer)

            # All but the last layer emit residual + skip halves; the last emits skip only.
            if i < n_layers - 1:
                res_skip_channels = 2*n_channels
            else:
                res_skip_channels = n_channels

            res_skip_layer = nn.Conv1d(n_channels, res_skip_channels, 1)
            # weight_norm must wrap the conv module (it was given the channel count).
            res_skip_layer = nn.utils.weight_norm(res_skip_layer, name='weight')
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, forward_input):
        audio, spec = forward_input
        audio = self.start(audio)
        output = torch.zeros_like(audio)
        n_channels_tensor = torch.IntTensor([self.n_channels])

        spec = self.cond_layer(spec)

        for i in range(self.n_layers):
            # This layer's slice of the shared conditioning output.
            spec_offset = i*2*self.n_channels
            acts = fused_add_tanh_sigmoid(
                self.in_layers[i](audio),
                spec[:, spec_offset:spec_offset+2*self.n_channels, :],
                n_channels_tensor)
            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                audio = audio + res_skip_acts[:, :self.n_channels, :]
                output = output + res_skip_acts[:, self.n_channels:, :]
            else:
                output = output + res_skip_acts

        return self.end(output)

In [163]:
class WaveGlow(nn.Module):
    """
    Mel-conditioned normalising flow over audio grouped into n_group channels.

    Each of the n_flows steps applies an Invertible1x1Conv and then an affine
    coupling whose scale/shift come from a WN. After every n_early_every
    flows, n_early_size channels leave the flow early and are emitted unchanged.
    """
    def __init__(self, n_mel_channels, n_flows, n_group, n_early_every, n_early_size,
                 n_layers=8, n_channels=256, kernel_size=3):
        super(WaveGlow, self).__init__()

        # Upsamples mel frames towards the audio rate. Kernel 1024 / stride 256
        # presumably mirror the STFT window/hop — TODO confirm.
        self.upsample = nn.ConvTranspose1d(n_mel_channels, n_mel_channels,
                                           kernel_size=1024, stride=256)

        assert n_group % 2 == 0
        self.n_flows = n_flows
        self.n_group = n_group
        self.n_early_every = n_early_every
        self.n_early_size = n_early_size
        self.WN = nn.ModuleList()
        self.convinv = nn.ModuleList()

        n_half = n_group // 2
        # Must exist before the loop: flow 0 uses it immediately.
        n_remaining_channels = n_group
        for k in range(n_flows):
            if k % self.n_early_every == 0 and k > 0:
                n_half = n_half - self.n_early_size // 2
                n_remaining_channels = n_remaining_channels - self.n_early_size
            self.convinv.append(Invertible1x1Conv(n_remaining_channels))
            self.WN.append(WN(n_half, n_mel_channels*n_group, n_layers, n_channels, kernel_size))
        self.n_remaining_channels = n_remaining_channels

    def forward(self, forward_input):
        """
        forward_input: (mel_spec, audio)
            mel_spec: [BS, n_mel_channels, frames]
            audio: [BS, time]

        Returns (z, log_s_list, log_det_W_list):
            z: [BS, n_group, time // n_group], the early outputs followed by the
               channels that went through every flow.
            log_s_list: per-flow log scales of the affine couplings.
            log_det_W_list: per-flow log-determinants of the 1x1 convolutions.
        """
        spec, audio = forward_input
        spec = self.upsample(spec)
        assert spec.size(2) >= audio.size(1)
        if spec.size(2) > audio.size(1):
            spec = spec[:, :, :audio.size(1)]

        # Fold n_group consecutive steps into channels:
        # spec -> [BS, n_mel_channels * n_group, time // n_group]
        spec = spec.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
        spec = spec.contiguous().view(spec.size(0), spec.size(1), -1).permute(0, 2, 1)
        # audio -> [BS, n_group, time // n_group]
        audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)

        output_audio = []
        log_s_list = []
        log_det_W_list = []
        for k in range(self.n_flows):
            if k % self.n_early_every == 0 and k > 0:
                output_audio.append(audio[:, :self.n_early_size, :])
                audio = audio[:, self.n_early_size:, :]

            audio, log_det_W = self.convinv[k](audio)
            log_det_W_list.append(log_det_W)

            n_half = audio.size(1) // 2
            audio_0 = audio[:, :n_half, :]
            audio_1 = audio[:, n_half:, :]

            # Affine coupling: audio_0 (plus the mel conditioning) predicts the
            # shift b and log-scale log_s applied to audio_1.
            output = self.WN[k]((audio_0, spec))
            log_s = output[:, n_half:, :]
            b = output[:, :n_half, :]
            audio_1 = torch.exp(log_s) * audio_1 + b
            log_s_list.append(log_s)

            audio = torch.cat([audio_0, audio_1], 1)

        output_audio.append(audio)
        return torch.cat(output_audio, 1), log_s_list, log_det_W_list
        

        


tensor(50, dtype=torch.int32)