# WaveGAN Implementation

| No. | Team member          |
|-----|----------------------|
| 1.  | Christopher Caldwell |
| 2.  | Fabian Müller        |
| 3.  | An Dang              |

## Introduction

The code in this notebook has been taken almost entirely from the original waveGAN GitHub repository (Donahue et al., 2019) <sup>[1]</sup>, which can be found [here](https://github.com/chrisdonahue/wavegan). The code has been adapted to run in a Jupyter notebook.  

For the purpose of this lecture, we first give a complete overview of the waveGAN implementation, the idea behind it, and the challenges we faced.  
Between the code blocks, we then examine in detail how and why certain measures were taken. Some possible improvements of our own are considered in theory.  
At the end, the results of the waveGAN training are presented, reviewed, and interpreted. Related and further implementations for music generation are also presented.

## Overview

### Motivation & Idea
The idea behind waveGAN is to use raw audio as unsupervised training data for generating music with GANs, following an approach very similar to [DCGAN (Radford et al., 2016)](https://arxiv.org/pdf/1511.06434.pdf)<sup>[2]</sup>. WaveGAN trains a generator network to produce music that the discriminator network cannot distinguish from real music.  
This approach is very similar to the image-generating GANs that have become widely known. In waveGAN, however, the data source is not an RGB image but audio with a high sampling rate.  
Music is sequential data, so the model must take the time dimension into account.

### Architecture

#### Loading the data

The most efficient way to load data into the waveGAN network is to start with *raw 16-bit PCM, 16 kHz, mono music data.*  
For this implementation we used the [Bach piano performances](http://deepyeti.ucsd.edu/cdonahue/wavegan/data/mancini_piano.tar.gz) training set<sup>[3]</sup>, gathered specifically for waveGAN by its author, C. Donahue. This dataset, however, consists of 24-bit, stereo, 48 kHz recordings. For waveGAN to read the files with the fast scipy loader, they must be converted to 16 bit.  
The waveGAN implementation can, however, handle the 48 kHz sampling rate and stereo input. When a file is decoded with the librosa module, it is resampled to 16 kHz internally (the faster scipy path cannot resample), and the number of channels is set through the corresponding hyperparameter before training. The librosa path can also handle mp3 data and converts it to the required data format.
To save computation time, however, we decided to convert the original data in advance with the open-source program [Audacity](https://www.audacity.de/) to the optimal *raw 16-bit PCM, 16 kHz, mono* format.
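
The conversion we performed in Audacity could also be scripted. The following is a minimal sketch and not part of the original waveGAN code: the helper name `to_pcm16_mono` is hypothetical, the input is a synthetic tone instead of a file from the dataset, and polyphase resampling with `scipy.signal.resample_poly` is just one possible choice of resampler.

```python
from math import gcd

import numpy as np
from scipy.signal import resample_poly


def to_pcm16_mono(audio, fs_in, fs_out=16000):
    """Downmix to mono, resample, and quantize to 16-bit PCM.

    audio: float array in [-1, 1], shape [nsamps] or [nsamps, nch].
    """
    audio = np.asarray(audio, dtype=np.float64)
    if audio.ndim == 2:
        audio = audio.mean(axis=1)                 # average the channels
    g = gcd(fs_in, fs_out)
    audio = resample_poly(audio, fs_out // g, fs_in // g)  # e.g. 48 kHz -> 16 kHz
    audio = np.clip(audio, -1.0, 1.0)              # guard against resampling overshoot
    return (audio * 32767.0).astype(np.int16)


# One second of a 440 Hz stereo tone at 48 kHz (synthetic stand-in for a recording)
t = np.arange(48000) / 48000.0
stereo = np.stack([np.sin(2 * np.pi * 440 * t)] * 2, axis=1)
pcm = to_pcm16_mono(stereo, fs_in=48000)
print(pcm.dtype, pcm.shape)
```

The resulting array could then be written with `scipy.io.wavfile.write`, which stores int16 data as 16-bit PCM.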







In [1]:
import os
import pickle
import sys
from functools import reduce  # used in train(); not a builtin in Python 3

import tensorflow as tf
import numpy as np
from scipy.io.wavfile import read as wavread
import matplotlib.pyplot as plt

%matplotlib inline

tf.reset_default_graph()

# Loading and Preparing the Data

## Decode Audio
The `decode_audio` function decodes an audio file, given its file path, into a 32-bit floating-point array.  
The input parameters are:  

| args         | description                                                 |
|--------------|-------------------------------------------------------------|
| fp           | a string containing the file path to the audio file (required) |
| fs           | if specified, resample the decoded audio to this sampling rate |
| num_channels | the number of channels the returned audio should have       |
| normalize    | if true, peak-normalize the decoded audio                   |
| fast_wav     | if true, read the file with the faster scipy module (standard WAV files only); otherwise decode it with librosa |

The function returns an `np.float32` array of shape `[nsamps, 1, nch]` containing the audio samples at the requested sample rate.
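
The core conversion steps of this function can be illustrated on a tiny hand-made array. This is a sketch under assumptions: `raw` is a hypothetical stand-in for the int16 samples a WAV reader would return, and only the scaling, reshaping, channel averaging, and peak normalization are shown.

```python
import numpy as np

# Hypothetical stand-in for int16 stereo samples returned by a WAV reader
raw = np.array([[0, 16384], [-32768, 32767], [8192, -8192]], dtype=np.int16)

wav = raw.astype(np.float32) / 32768.0     # int16 -> float32 in [-1, 1)
nsamps, nch = wav.shape
wav = wav.reshape(nsamps, 1, nch)          # [nsamps, 1, nch]

mono = wav.mean(axis=2, keepdims=True)     # num_channels=1: average the channels
peak = np.max(np.abs(mono))
if peak > 0:
    mono = mono / peak                     # normalize=True: scale the peak to 1
print(mono.shape, mono.dtype)
```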

In [2]:
def decode_audio(fp, fs=None, num_channels=1, normalize=False, fast_wav=False):
    
    # Use scipy fast wave read method, if fast_wav is true
    if fast_wav:
        
        # Read with scipy wavread (fast).
        _fs, _wav = wavread(fp)
        
        # The scipy path cannot resample. If a sample rate is requested that differs
        # from the sample rate of the file, a NotImplementedError is raised.
        if fs is not None and fs != _fs:
            ## NOT ORIGINAL WAVEGAN IMPLEMENTATION:
            ## possible alternative (untested): resample with scipy.signal.
            ## Note that resample() expects the target number of samples,
            ## not the target sample rate.
            # from scipy.signal import resample
            # _wav = resample(_wav, int(round(len(_wav) * float(fs) / _fs)))
            raise NotImplementedError('Scipy cannot resample audio.')
        
        # Convert _wav to float32 if the datatype is int16
        if _wav.dtype == np.int16:
            _wav = _wav.astype(np.float32)
            _wav /= 32768.
        
        # _wav file is of type float32.
        elif _wav.dtype == np.float32:
            _wav = np.copy(_wav)
        
        # If _wav is neither float32 nor int16, the file cannot be processed with scipy;
        # fast_wav is not possible for this file.
        else:
            raise NotImplementedError('Scipy cannot process atypical WAV files.')
    
    # if fast_wav is false, the source music data will be processed with the librosa module
    else:
        # Decode with librosa load (slow but supports file formats like mp3).
        import librosa
        
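        # Always load all channels; they are averaged or duplicated further down.
        # The disabled block that follows would instead derive `mono` from num_channels.
        mono = False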
        """
        # check if the music source file is mono.
        if num_channels > 1:
            mono = False
        elif num_channels == 1:
            mono = True
        else:
            raise NotImplementedError('Librosa cannot process WAV files with less than one channel')
        """
        
        _wav, _fs = librosa.core.load(fp, sr=fs, mono=mono)
        
        
        if _wav.ndim == 2:
            _wav = np.swapaxes(_wav, 0, 1)

    
    # stop if datatype of wav is not float32 at this point.
    assert _wav.dtype == np.float32

    # At this point, _wav is np.float32 either [nsamps,] or [nsamps, nch].
    # We want [nsamps, 1, nch] to mimic 2D shape of spectral feats.
    if _wav.ndim == 1:
        nsamps = _wav.shape[0]
        nch = 1
    else:
        nsamps, nch = _wav.shape
    _wav = np.reshape(_wav, [nsamps, 1, nch])
 
    # Average (mono) or expand (stereo) channels
    if nch != num_channels:
        if num_channels == 1:
            _wav = np.mean(_wav, 2, keepdims=True)
        elif nch == 1 and num_channels == 2:
            _wav = np.concatenate([_wav, _wav], axis=2)
        else:
            raise ValueError('Number of audio channels not equal to num specified')

    # if specified the data is normalized.
    if normalize:
        factor = np.max(np.abs(_wav))
        if factor > 0:
            _wav /= factor

    return _wav

## Decode Extract and Batch

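The `decode_extract_and_batch` function takes the file paths and builds the batches that are fed to the neural network.  
Depending on the specified parameters (such as `repeat` and `shuffle`), the dataset is iterated continuously and shuffled before the slices are grouped into batches.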

In [3]:
def decode_extract_and_batch(fps, batch_size, slice_len, decode_fs, 
                             decode_num_channels, decode_normalize=True,
                             decode_fast_wav=False, decode_parallel_calls=1,
                             slice_randomize_offset=False, slice_first_only=False,
                             slice_overlap_ratio=0, slice_pad_end=False,
                             repeat=False,
                             shuffle=False,
                             shuffle_buffer_size=None,
                             prefetch_size=None,
                             prefetch_gpu_num=None):
    """Decodes audio file paths into mini-batches of samples.
    Args:
        fps: List of audio file paths.
        batch_size: Number of items in the batch.
        slice_len: Length of the slices in samples or feature timesteps.
        decode_fs: (Re-)sample rate for decoded audio files.
        decode_num_channels: Number of channels for decoded audio files.
        decode_normalize: If false, do not normalize audio waveforms.
        decode_fast_wav: If true, uses scipy to decode standard wav files.
        decode_parallel_calls: Number of parallel decoding threads.
        slice_randomize_offset: If true, randomize starting position for slice.
        slice_first_only: If true, only use first slice from each audio file.
        slice_overlap_ratio: Ratio of overlap between adjacent slices.
        slice_pad_end: If true, allows zero-padded examples from the end of each audio file.
        repeat: If true (for training), continuously iterate through the dataset.
        shuffle: If true (for training), buffer and shuffle the slices.
        shuffle_buffer_size: Number of examples to queue up before grabbing a batch.
        prefetch_size: Number of examples to prefetch from the queue.
        prefetch_gpu_num: If specified, prefetch examples to GPU.
    Returns:
        A tuple of np.float32 tensors representing audio waveforms.
        audio: [batch_size, slice_len, 1, nch]
    """
  
    # Create dataset of filepaths
    dataset = tf.data.Dataset.from_tensor_slices(fps)

    # Shuffle all filepaths every epoch
    if shuffle:
        dataset = dataset.shuffle(buffer_size=len(fps))

    # Repeat
    if repeat:
        dataset = dataset.repeat()

    
    # Prepare the transformation for each dataelement in a definition
    def _decode_audio_shaped(fp):
        _decode_audio_closure = lambda _fp: decode_audio(_fp, 
                                                         fs=decode_fs,
                                                         num_channels=decode_num_channels,
                                                         normalize=decode_normalize,
                                                         fast_wav=decode_fast_wav)
        
        # tf.py_func: Wraps a python function and uses it as a TensorFlow op
        audio = tf.py_func(_decode_audio_closure,
                           [fp],
                           tf.float32,
                           stateful=False)
        
        audio.set_shape([None, 1, decode_num_channels])

        return audio
    
    # dataset.map()
    # This transformation applies map_func to each element of this dataset, and returns a new dataset 
    # containing the transformed elements, in the same order as they appeared in the input.
    # Decode audio by using "_decode_audio_shaped" as the transformation function
    # num_parallel_calls specifies the number of parallel decode threads
    dataset = dataset.map(_decode_audio_shaped, num_parallel_calls=decode_parallel_calls)

    # Cut each decoded audio file into slices
    def _slice(audio):
        # Calculate hop size
        if slice_overlap_ratio < 0:
            raise ValueError('Overlap ratio must not be negative')
        slice_hop = int(round(slice_len * (1. - slice_overlap_ratio)) + 1e-4)
        if slice_hop < 1:
            raise ValueError('Overlap ratio too high')

        # Randomize starting phase:
        if slice_randomize_offset:
            start = tf.random_uniform([], maxval=slice_len, dtype=tf.int32)
            audio = audio[start:]

        # Extract slices
        audio_slices = tf.contrib.signal.frame(audio, 
                                               slice_len,
                                               slice_hop,
                                               pad_end=slice_pad_end,
                                               pad_value=0,
                                               axis=0)

        # Only use first slice if requested
        if slice_first_only:
            audio_slices = audio_slices[:1]

        return audio_slices

    def _slice_dataset_wrapper(audio):
        audio_slices = _slice(audio)
        return tf.data.Dataset.from_tensor_slices(audio_slices)

    # Extract the slices of each audio file and flatten them into a single dataset
    dataset = dataset.flat_map(_slice_dataset_wrapper)

    # Shuffle examples
    if shuffle:
        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)

    # Make batches
    dataset = dataset.batch(batch_size, drop_remainder=True)

    # Prefetch a number of batches
    if prefetch_size is not None:
        dataset = dataset.prefetch(prefetch_size)
        if prefetch_gpu_num is not None and prefetch_gpu_num >= 0:
            dataset = dataset.apply(tf.data.experimental.prefetch_to_device('/device:GPU:{}'.format(prefetch_gpu_num)))

    # Get tensors
    iterator = dataset.make_one_shot_iterator()
  
    return iterator.get_next()
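
The slicing step inside `decode_extract_and_batch` can be illustrated without TensorFlow. The following NumPy sketch approximates the hop-size calculation and the framing under stated assumptions: `frame_audio` is a hypothetical helper, it works on a 1-D signal only, and the frame-count formulas are our reading of how `tf.contrib.signal.frame` behaves with and without `pad_end`.

```python
import numpy as np


def frame_audio(audio, slice_len, overlap_ratio=0.0, pad_end=False):
    """Cut a 1-D signal into slices of length slice_len."""
    # Same hop-size formula as in _slice
    hop = int(round(slice_len * (1.0 - overlap_ratio)) + 1e-4)
    if hop < 1:
        raise ValueError('Overlap ratio too high')
    n = len(audio)
    if pad_end:
        n_slices = -(-n // hop)                          # ceil(n / hop)
        pad = max(0, (n_slices - 1) * hop + slice_len - n)
        audio = np.concatenate([audio, np.zeros(pad, dtype=audio.dtype)])
    else:
        n_slices = max(0, 1 + (n - slice_len) // hop)    # only complete slices
    starts = np.arange(n_slices) * hop
    index = starts[:, None] + np.arange(slice_len)[None, :]
    return audio[index]


signal = np.arange(10, dtype=np.float32)
slices = frame_audio(signal, slice_len=4, overlap_ratio=0.5)
padded = frame_audio(signal, slice_len=4, overlap_ratio=0.5, pad_end=True)
print(slices.shape, padded.shape)
```

With an overlap ratio of 0.5 the hop size is half the slice length, so neighbouring slices share half of their samples. With `pad_end=True` the final, incomplete slice is kept and filled with zeros.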

In [4]:
def conv1d_transpose(inputs, filters, kernel_width, stride=4, padding='same', upsample='zeros'):
    # Upsamples a [batch, width, channels] tensor along the width axis by `stride`.
    # upsample='zeros': strided transposed convolution (2-D op on a dummy height axis)
    if upsample == 'zeros':
        return tf.layers.conv2d_transpose(tf.expand_dims(inputs, axis=1), 
                                          filters,
                                          (1, kernel_width),
                                          strides=(1, stride),
                                          padding=padding)[:, 0]
    
    # upsample='nn': nearest-neighbor upsampling followed by a stride-1 convolution
    elif upsample == 'nn':
        _, w, nch = inputs.get_shape().as_list()

        x = inputs

        x = tf.expand_dims(x, axis=1)
        x = tf.image.resize_nearest_neighbor(x, [1, w * stride])
        x = x[:, 0]

        return tf.layers.conv1d(x, filters, kernel_width, 1, padding=padding)
  
    else:
        raise NotImplementedError
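
The difference between the two `upsample` modes can be seen on a toy signal. This NumPy sketch is a simplification under assumptions: a strided transposed convolution is treated as zero insertion followed by a convolution (ignoring padding and kernel-orientation details), and `kernel` is a hypothetical fixed smoothing filter in place of learned weights.

```python
import numpy as np


def upsample_zeros(x, stride):
    """Insert stride - 1 zeros after every sample ('zeros' mode)."""
    up = np.zeros(len(x) * stride, dtype=x.dtype)
    up[::stride] = x
    return up


def upsample_nn(x, stride):
    """Repeat every sample stride times ('nn' mode)."""
    return np.repeat(x, stride)


x = np.array([1.0, 2.0, 3.0])
kernel = np.ones(3) / 3.0  # hypothetical filter, stands in for learned weights

zeros_up = upsample_zeros(x, 2)
nn_up = upsample_nn(x, 2)
print(zeros_up, np.convolve(zeros_up, kernel, mode='same'))
print(nn_up, np.convolve(nn_up, kernel, mode='same'))
```

In the zero-inserted signal, output positions see a different number of non-zero inputs depending on their phase relative to the stride, which is one way to picture where periodic upsampling artifacts come from.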

In [5]:
def lrelu(inputs, alpha=0.2):
    return tf.maximum(alpha * inputs, inputs)

## Apply Phaseshuffle

The upscaling by convolution is known to create checkerboard artefacts:
![Checkerboard Artefacts](./checkerboard.png)
<sup>[9]</sup> [Source: Checkerboard Artefacts](http://physhik.com/waveGAN/)  

Upsampling artefacts do not appear only in images. Upsampling audio data in the generator produces comparable artefacts in the generated audio. To keep the discriminator from exploiting these artefacts, waveGAN uses phase shuffle.

![Phaseshuffle](./phaseshuffle.png)
<sup>[9]</sup> [Source: Phase Shuffle](https://arxiv.org/pdf/1802.04208.pdf)  

After each of the first four convolutional layers of the discriminator, the activations are shifted along the time axis by a random offset drawn uniformly from -n to n, where n is the `wavegan_disc_phaseshuffle` hyperparameter. That way, the discriminator cannot simply learn to recognize the artefacts in order to distinguish between real and fake data.
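
The operation can be mirrored in NumPy to make the shifting and reflection padding visible. This is a minimal sketch, not the TensorFlow implementation: `shift_with_reflection` and `phaseshuffle_np` are hypothetical names, and, as in `apply_phaseshuffle`, a single offset is drawn per call for the whole batch.

```python
import numpy as np


def shift_with_reflection(x, phase):
    """Shift x ([batch, length, channels]) by `phase` steps along the time axis.

    The vacated positions are filled by reflection padding, and the output
    keeps the original length.
    """
    length = x.shape[1]
    pad_l, pad_r = max(phase, 0), max(-phase, 0)
    x = np.pad(x, [(0, 0), (pad_l, pad_r), (0, 0)], mode='reflect')
    return x[:, pad_r:pad_r + length]


def phaseshuffle_np(x, rad, rng):
    """Apply a random shift drawn uniformly from -rad to rad (inclusive)."""
    phase = int(rng.integers(-rad, rad + 1))
    return shift_with_reflection(x, phase)


x = np.arange(8, dtype=np.float32).reshape(1, 8, 1)
right = shift_with_reflection(x, 1)[0, :, 0]
left = shift_with_reflection(x, -1)[0, :, 0]
shuffled = phaseshuffle_np(x, rad=2, rng=np.random.default_rng(0))
print(right, left, shuffled.shape)
```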

In [6]:
def apply_phaseshuffle(x, rad, pad_type='reflect'):
    b, x_len, nch = x.get_shape().as_list()

    phase = tf.random_uniform([], minval=-rad, maxval=rad + 1, dtype=tf.int32)
    pad_l = tf.maximum(phase, 0)
    pad_r = tf.maximum(-phase, 0)
    phase_start = pad_r
    x = tf.pad(x, [[0, 0], [pad_l, pad_r], [0, 0]], mode=pad_type)

    x = x[:, phase_start:phase_start+x_len]
    x.set_shape([b, x_len, nch])

    return x

## WaveGAN Discriminator

![WaveGAN Discriminator Architecture](./discriminator.png)  


The `input layer` of the discriminator receives either a sample of real music or music generated by the generator network.
The shape of the input consists of  
*n: batchsize*  
*samplesize* <sup>(16384 ~ 1 sec at 16kHz | 32768 ~ 2 sec at 16 kHz | 65536 ~ 4 sec at 16 kHz)</sup>  
*c: number of channels*  


Following the input layer come four repetitions of a convolutional layer combined with a leaky ReLU activation function and the phase shuffle operation.  
After the fourth repetition, a fifth follows, with a reshape instead of the phase shuffle.  
Each convolution operation has the following parameters:  

* Stride = 4
* Kernel size = 25 (one-dimensional)
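
The effect of these parameters on the tensor shapes can be checked with a few lines of arithmetic. This is a sketch assuming `'SAME'` padding (output length is the input length divided by the stride, rounded up), the default `dim=64`, and only the five layers used for 16384-sample inputs; `discriminator_shapes` is a hypothetical helper.

```python
def discriminator_shapes(slice_len=16384, dim=64, stride=4, n_layers=5):
    """Return (time_steps, channels) for the input and after each convolution."""
    shapes = [(slice_len, 1)]
    length = slice_len
    for i in range(n_layers):
        length = -(-length // stride)          # ceil(length / stride)
        shapes.append((length, dim * 2 ** i))  # channels double with every layer
    return shapes


shapes = discriminator_shapes()
for before, after in zip(shapes[:-1], shapes[1:]):
    print('{} -> {}'.format(list(before), list(after)))
```

Each layer divides the time axis by four and doubles the channels, so a one-second input of 16384 samples ends up as 16 time steps with 1024 channels before it is flattened.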

At the end, a single dense layer with one output unit produces the logit that scores the input as real or fake.



In [7]:
# Use a kernel length of 25 (the 1-D counterpart of a 5x5 kernel)

def WaveGANDiscriminator(X, reuse_vars=None, kernel_len=25, dim=64, use_batchnorm=False, phaseshuffle_rad=0):
  
    batch_size = tf.shape(X)[0]
    slice_len = int(X.get_shape()[1])

    
    if use_batchnorm:
        batchnorm = lambda x: tf.layers.batch_normalization(x, training=True)
    else:
        batchnorm = lambda x: x

    if phaseshuffle_rad > 0:
        phaseshuffle = lambda x: apply_phaseshuffle(x, phaseshuffle_rad)
    else:
        phaseshuffle = lambda x: x

        
    # Layer 0
    # [16384, 1] -> [4096, 64]
    output = X
    
    with tf.variable_scope('downconv_0'):
        output = tf.layers.conv1d(output, dim, kernel_len, 4, padding='SAME')
    output = lrelu(output)
    output = phaseshuffle(output)

    # Layer 1
    # [4096, 64] -> [1024, 128]
    with tf.variable_scope('downconv_1'):
        output = tf.layers.conv1d(output, dim * 2, kernel_len, 4, padding='SAME')
        output = batchnorm(output)
    output = lrelu(output)
    output = phaseshuffle(output)

    # Layer 2
    # [1024, 128] -> [256, 256]
    with tf.variable_scope('downconv_2'):
        output = tf.layers.conv1d(output, dim * 4, kernel_len, 4, padding='SAME')
        output = batchnorm(output)
    output = lrelu(output)
    output = phaseshuffle(output)

    # Layer 3
    # [256, 256] -> [64, 512]
    with tf.variable_scope('downconv_3'):
        output = tf.layers.conv1d(output, dim * 8, kernel_len, 4, padding='SAME')
        output = batchnorm(output)
    output = lrelu(output)
    output = phaseshuffle(output)

    # Layer 4
    # [64, 512] -> [16, 1024]
    with tf.variable_scope('downconv_4'):
        output = tf.layers.conv1d(output, dim * 16, kernel_len, 4, padding='SAME')
        output = batchnorm(output)
    output = lrelu(output)

    # Two-second slices: 32768 samples at roughly 16384 samples per second
    if slice_len == 32768:
        # Layer 5
        # [32, 1024] -> [16, 2048]
        with tf.variable_scope('downconv_5'):
            output = tf.layers.conv1d(output, dim * 32, kernel_len, 2, padding='SAME')
            output = batchnorm(output)
        output = lrelu(output)
    
    # Four-second slices: 65536 samples at roughly 16384 samples per second
    elif slice_len == 65536:
        # Layer 5
        # [64, 1024] -> [16, 2048]
        with tf.variable_scope('downconv_5'):
            output = tf.layers.conv1d(output, dim * 32, kernel_len, 4, padding='SAME')
            output = batchnorm(output)
        output = lrelu(output)

    # Flatten
    output = tf.reshape(output, [batch_size, -1])

    # Connect to single logit
    with tf.variable_scope('output'):
        output = tf.layers.dense(output, 1)[:, 0]

    # Don't need to aggregate batchnorm update ops like we do for the generator because we only use the discriminator for training

    return output
    
  

## Generator

In [8]:
def WaveGANGenerator(z, slice_len=16384, nch=1, kernel_len=25, dim=64, use_batchnorm=False, upsample='zeros', train=False):

    assert slice_len in [16384, 32768, 65536]
    batch_size = tf.shape(z)[0]

    if use_batchnorm:
        batchnorm = lambda x: tf.layers.batch_normalization(x, training=train)
    else:
        batchnorm = lambda x: x

    # FC and reshape for convolution
    # [100] -> [16, 1024]
    dim_mul = 16 if slice_len == 16384 else 32
    output = z
    with tf.variable_scope('z_project'):
        output = tf.layers.dense(output, 4 * 4 * dim * dim_mul)
        output = tf.reshape(output, [batch_size, 16, dim * dim_mul])
        output = batchnorm(output)
    output = tf.nn.relu(output)
    dim_mul //= 2

    # Layer 0
    # [16, 1024] -> [64, 512]
    with tf.variable_scope('upconv_0'):
        output = conv1d_transpose(output, dim * dim_mul, kernel_len, 4, upsample=upsample)
        output = batchnorm(output)
    output = tf.nn.relu(output)
    dim_mul //= 2

    # Layer 1
    # [64, 512] -> [256, 256]
    with tf.variable_scope('upconv_1'):
        output = conv1d_transpose(output, dim * dim_mul, kernel_len, 4, upsample=upsample)
        output = batchnorm(output)
    output = tf.nn.relu(output)
    dim_mul //= 2

    # Layer 2
    # [256, 256] -> [1024, 128]
    with tf.variable_scope('upconv_2'):
        output = conv1d_transpose(output, dim * dim_mul, kernel_len, 4, upsample=upsample)
        output = batchnorm(output)
    output = tf.nn.relu(output)
    dim_mul //= 2

    # Layer 3
    # [1024, 128] -> [4096, 64]
    with tf.variable_scope('upconv_3'):
        output = conv1d_transpose(output, dim * dim_mul, kernel_len, 4, upsample=upsample)
        output = batchnorm(output)
    output = tf.nn.relu(output)

    if slice_len == 16384:
        # Layer 4
        # [4096, 64] -> [16384, nch]
        with tf.variable_scope('upconv_4'):
            output = conv1d_transpose(output, nch, kernel_len, 4, upsample=upsample)
        output = tf.nn.tanh(output)
   
    elif slice_len == 32768:
        # Layer 4
        # [4096, 128] -> [16384, 64]
        with tf.variable_scope('upconv_4'):
            output = conv1d_transpose(output, dim, kernel_len, 4, upsample=upsample)
            output = batchnorm(output)
        output = tf.nn.relu(output)

        # Layer 5
        # [16384, 64] -> [32768, nch]
        with tf.variable_scope('upconv_5'):
            output = conv1d_transpose(output, nch, kernel_len, 2, upsample=upsample)
        output = tf.nn.tanh(output)
        
    elif slice_len == 65536:
        # Layer 4
        # [4096, 128] -> [16384, 64]
        with tf.variable_scope('upconv_4'):
            output = conv1d_transpose(output, dim, kernel_len, 4, upsample=upsample)
            output = batchnorm(output)
        output = tf.nn.relu(output)

        # Layer 5
        # [16384, 64] -> [65536, nch]
        with tf.variable_scope('upconv_5'):
            output = conv1d_transpose(output, nch, kernel_len, 4, upsample=upsample)
        output = tf.nn.tanh(output)

    # Automatically update batchnorm moving averages every time G is used during training
    if train and use_batchnorm:
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name)
        if slice_len == 16384:
            assert len(update_ops) == 10
        else:
            assert len(update_ops) == 12
        with tf.control_dependencies(update_ops):
            output = tf.identity(output)

    return output
    
   

In [9]:
# GPU config: allocate memory on demand and cap usage at 70%, because claiming all of the GPU's memory led to crashes
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 0.7

In [10]:
def train(fps, args):
    with tf.name_scope('loader'):
        x = decode_extract_and_batch(
            fps,
            batch_size=args.train_batch_size,
            slice_len=args.data_slice_len,
            decode_fs=args.data_sample_rate,
            decode_num_channels=args.data_num_channels,
            decode_fast_wav=args.data_fast_wav,
            decode_parallel_calls=4,
            slice_randomize_offset=False if args.data_first_slice else True,
            slice_first_only=args.data_first_slice,
            slice_overlap_ratio=0. if args.data_first_slice else args.data_overlap_ratio,
            slice_pad_end=True if args.data_first_slice else args.data_pad_end,
            repeat=True,
            shuffle=True,
            shuffle_buffer_size=4096,
            prefetch_size=args.train_batch_size * 4,
            prefetch_gpu_num=args.data_prefetch_gpu_num)[:, :, 0]

    # Make z vector
    z = tf.random_uniform([args.train_batch_size, args.wavegan_latent_dim], -1., 1., dtype=tf.float32)

    # Make generator
    with tf.variable_scope('G'):
        G_z = WaveGANGenerator(z, train=True, **args.wavegan_g_kwargs)
        if args.wavegan_genr_pp:
            with tf.variable_scope('pp_filt'):
                G_z = tf.layers.conv1d(G_z, 1, args.wavegan_genr_pp_len, use_bias=False, padding='same')
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')

    # Print G summary
    print('-' * 80)
    print('Generator vars')
    nparams = 0
    for v in G_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024)))

    # Summarize
    tf.summary.audio('x', x, args.data_sample_rate)
    tf.summary.audio('G_z', G_z, args.data_sample_rate)
    G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z[:, :, 0]), axis=1))
    x_rms = tf.sqrt(tf.reduce_mean(tf.square(x[:, :, 0]), axis=1))
    tf.summary.histogram('x_rms_batch', x_rms)
    tf.summary.histogram('G_z_rms_batch', G_z_rms)
    tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
    tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))

  

    # Make real discriminator
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        D_x = WaveGANDiscriminator(x, **args.wavegan_d_kwargs)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')

    # Print D summary
    print('-' * 80)
    print('Discriminator vars')
    nparams = 0
    for v in D_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024)))
    print('-' * 80)

  


    # Make fake discriminator
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        D_G_z = WaveGANDiscriminator(G_z, **args.wavegan_d_kwargs)

        
        
    # Create loss
    D_clip_weights = None
    if args.wavegan_loss == 'dcgan':
        fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
        real = tf.ones([args.train_batch_size], dtype=tf.float32)

        G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
          logits=D_G_z,
          labels=real
        ))

        D_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
          logits=D_G_z,
          labels=fake
        ))
        D_loss += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
          logits=D_x,
          labels=real
        ))

        D_loss /= 2.
  
    elif args.wavegan_loss == 'lsgan':
        G_loss = tf.reduce_mean((D_G_z - 1.) ** 2)
        D_loss = tf.reduce_mean((D_x - 1.) ** 2)
        D_loss += tf.reduce_mean(D_G_z ** 2)
        D_loss /= 2.
    elif args.wavegan_loss == 'wgan':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        with tf.name_scope('D_clip_weights'):
            clip_ops = []
            for var in D_vars:
                clip_bounds = [-.01, .01]
                clip_ops.append(tf.assign(var,
                                          tf.clip_by_value(var,
                                                           clip_bounds[0], 
                                                           clip_bounds[1])
                                         )
                               )
            D_clip_weights = tf.group(*clip_ops)
            
            
    elif args.wavegan_loss == 'wgan-gp':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1], minval=0., maxval=1.)
        differences = G_z - x
        interpolates = x + (alpha * differences)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = WaveGANDiscriminator(interpolates, **args.wavegan_d_kwargs)

        LAMBDA = 10
        gradients = tf.gradients(D_interp, [interpolates])[0]
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
        gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2.)
        D_loss += LAMBDA * gradient_penalty
    else:
        raise NotImplementedError()

    tf.summary.scalar('G_loss', G_loss)
    tf.summary.scalar('D_loss', D_loss)

    # Create (recommended) optimizer
    if args.wavegan_loss == 'dcgan':
        G_opt = tf.train.AdamOptimizer(
            learning_rate=2e-4,
            beta1=0.5)
        D_opt = tf.train.AdamOptimizer(
            learning_rate=2e-4,
            beta1=0.5)
    elif args.wavegan_loss == 'lsgan':
        G_opt = tf.train.RMSPropOptimizer(
            learning_rate=1e-4)
        D_opt = tf.train.RMSPropOptimizer(
            learning_rate=1e-4)
    elif args.wavegan_loss == 'wgan':
        G_opt = tf.train.RMSPropOptimizer(
            learning_rate=5e-5)
        D_opt = tf.train.RMSPropOptimizer(
            learning_rate=5e-5)
    elif args.wavegan_loss == 'wgan-gp':
        G_opt = tf.train.AdamOptimizer(
            learning_rate=1e-4,
            beta1=0.5,
            beta2=0.9)
        D_opt = tf.train.AdamOptimizer(
            learning_rate=1e-4,
            beta1=0.5,
            beta2=0.9)
    else:
        raise NotImplementedError()

  
    # Create training ops
    G_train_op = G_opt.minimize(G_loss, 
                                var_list=G_vars,
                                global_step=tf.train.get_or_create_global_step())
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

    # Run training
    with tf.train.MonitoredTrainingSession(checkpoint_dir=args.train_dir,
                                           save_checkpoint_secs=args.train_save_secs,
                                           save_summaries_secs=args.train_summary_secs) as sess:
        print('-' * 80)
        print('Training has started. Please use \'tensorboard --logdir={}\' to monitor.'.format(args.train_dir))
        while True:
            # Train discriminator
            for i in range(args.wavegan_disc_nupdates):
                sess.run(D_train_op)

            # Enforce Lipschitz constraint for WGAN
            if D_clip_weights is not None:
                sess.run(D_clip_weights)

            # Train generator
            sess.run(G_train_op)
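
The loss variants selected through `args.wavegan_loss` in `train` can be compared on toy discriminator outputs. This NumPy sketch restates the `dcgan`, `lsgan`, and `wgan` formulas under assumptions: `gan_losses` and `sigmoid_xent` are hypothetical helpers, the logits are made-up numbers, and the gradient penalty of `wgan-gp` is left out because it needs automatic differentiation.

```python
import numpy as np


def sigmoid_xent(logits, labels):
    """Sigmoid cross-entropy on logits in its numerically stable form."""
    return (np.maximum(logits, 0) - logits * labels
            + np.log1p(np.exp(-np.abs(logits))))


def gan_losses(d_real, d_fake, kind):
    """Return (G_loss, D_loss) for discriminator outputs on real and fake data."""
    if kind == 'dcgan':
        g = np.mean(sigmoid_xent(d_fake, 1.0))
        d = 0.5 * (np.mean(sigmoid_xent(d_fake, 0.0))
                   + np.mean(sigmoid_xent(d_real, 1.0)))
    elif kind == 'lsgan':
        g = np.mean((d_fake - 1.0) ** 2)
        d = 0.5 * (np.mean((d_real - 1.0) ** 2) + np.mean(d_fake ** 2))
    elif kind == 'wgan':
        g = -np.mean(d_fake)
        d = np.mean(d_fake) - np.mean(d_real)
    else:
        raise NotImplementedError(kind)
    return g, d


d_real = np.array([2.0, 1.0])    # made-up outputs for real samples
d_fake = np.array([-1.0, 0.0])   # made-up outputs for generated samples
for kind in ['dcgan', 'lsgan', 'wgan']:
    print(kind, gan_losses(d_real, d_fake, kind))
```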

In [11]:
def infer(args):
    infer_dir = os.path.join(args.train_dir, 'infer')
    if not os.path.isdir(infer_dir):
        os.makedirs(infer_dir)

    # Subgraph that generates latent vectors
    samp_z_n = tf.placeholder(tf.int32, [], name='samp_z_n')
    samp_z = tf.random_uniform([samp_z_n, args.wavegan_latent_dim], -1.0, 1.0, dtype=tf.float32, name='samp_z')

    # Input z
    z = tf.placeholder(tf.float32, [None, args.wavegan_latent_dim], name='z')
    flat_pad = tf.placeholder(tf.int32, [], name='flat_pad')

    # Execute generator
    with tf.variable_scope('G'):
        G_z = WaveGANGenerator(z, train=False, **args.wavegan_g_kwargs)
        if args.wavegan_genr_pp:
            with tf.variable_scope('pp_filt'):
                G_z = tf.layers.conv1d(G_z, 1, args.wavegan_genr_pp_len, use_bias=False, padding='same')
    G_z = tf.identity(G_z, name='G_z')

    # Flatten batch
    nch = int(G_z.get_shape()[-1])
    G_z_padded = tf.pad(G_z, [[0, 0], [0, flat_pad], [0, 0]])
    G_z_flat = tf.reshape(G_z_padded, [-1, nch], name='G_z_flat')

    # Encode to int16
    def float_to_int16(x, name=None):
        x_int16 = x * 32767.
        x_int16 = tf.clip_by_value(x_int16, -32767., 32767.)
        x_int16 = tf.cast(x_int16, tf.int16, name=name)
        return x_int16
    G_z_int16 = float_to_int16(G_z, name='G_z_int16')
    G_z_flat_int16 = float_to_int16(G_z_flat, name='G_z_flat_int16')

    # Create saver
    G_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='G')
    global_step = tf.train.get_or_create_global_step()
    saver = tf.train.Saver(G_vars + [global_step])

    # Export graph
    tf.train.write_graph(tf.get_default_graph(), infer_dir, 'infer.pbtxt')

    # Export MetaGraph
    infer_metagraph_fp = os.path.join(infer_dir, 'infer.meta')
    tf.train.export_meta_graph(
        filename=infer_metagraph_fp,
        clear_devices=True,
        saver_def=saver.as_saver_def())

    # Reset graph (in case training afterwards)
    tf.reset_default_graph()
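
The int16 encoding used in `infer` can be reproduced in NumPy to see how out-of-range values are handled. This is a sketch with a hypothetical helper name, `float_to_int16_np`, assuming that the cast truncates toward zero.

```python
import numpy as np


def float_to_int16_np(x):
    """Scale floats in [-1, 1] to the int16 range, clipping anything outside."""
    scaled = np.clip(x * 32767.0, -32767.0, 32767.0)
    return scaled.astype(np.int16)


encoded = float_to_int16_np(np.array([-1.5, -1.0, 0.0, 0.5, 1.0]))
print(encoded)
```

Scaling by 32767 instead of 32768 keeps the positive and negative limits symmetric, so a generator output of exactly 1.0 does not overflow.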


In [12]:
def preview(args):
    """Generates a preview audio file every time a checkpoint is saved."""
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    from scipy.io.wavfile import write as wavwrite
    from scipy.signal import freqz

    preview_dir = os.path.join(args.train_dir, 'preview')
    if not os.path.isdir(preview_dir):
        os.makedirs(preview_dir)

    # Load graph
    infer_metagraph_fp = os.path.join(args.train_dir, 'infer', 'infer.meta')
    graph = tf.get_default_graph()
    saver = tf.train.import_meta_graph(infer_metagraph_fp)

    # Generate or restore z_i and z_o
    z_fp = os.path.join(preview_dir, 'z.pkl')
    if os.path.exists(z_fp):
        with open(z_fp, 'rb') as f:
            _zs = pickle.load(f)
    else:
        # Sample z
        samp_feeds = {}
        samp_feeds[graph.get_tensor_by_name('samp_z_n:0')] = args.preview_n
        samp_fetches = {}
        samp_fetches['zs'] = graph.get_tensor_by_name('samp_z:0')
        with tf.Session(config=config) as sess:
            _samp_fetches = sess.run(samp_fetches, samp_feeds)
        _zs = _samp_fetches['zs']

        # Save z
        with open(z_fp, 'wb') as f:
            pickle.dump(_zs, f)

    # Set up feeds and fetches for generating preview audio
    feeds = {}
    feeds[graph.get_tensor_by_name('z:0')] = _zs
    feeds[graph.get_tensor_by_name('flat_pad:0')] = int(args.data_sample_rate / 2)
    fetches = {}
    fetches['step'] = tf.train.get_or_create_global_step()
    fetches['G_z'] = graph.get_tensor_by_name('G_z:0')
    fetches['G_z_flat_int16'] = graph.get_tensor_by_name('G_z_flat_int16:0')
    if args.wavegan_genr_pp:
        fetches['pp_filter'] = graph.get_tensor_by_name('G/pp_filt/conv1d/kernel:0')[:, 0, 0]

    # Summarize
    G_z = graph.get_tensor_by_name('G_z_flat:0')
    summaries = [
      tf.summary.audio('preview', tf.expand_dims(G_z, axis=0), args.data_sample_rate, max_outputs=1)
    ]
    fetches['summaries'] = tf.summary.merge(summaries)
    summary_writer = tf.summary.FileWriter(preview_dir)

    # PP Summarize
    if args.wavegan_genr_pp:
        pp_fp = tf.placeholder(tf.string, [])
        pp_bin = tf.read_file(pp_fp)
        pp_png = tf.image.decode_png(pp_bin)
        pp_summary = tf.summary.image('pp_filt', tf.expand_dims(pp_png, axis=0))

    # Loop, waiting for checkpoints
    ckpt_fp = None
    while True:
        latest_ckpt_fp = tf.train.latest_checkpoint(args.train_dir)
        if latest_ckpt_fp != ckpt_fp:
            print('Preview: {}'.format(latest_ckpt_fp))

            with tf.Session(config=config) as sess:
                saver.restore(sess, latest_ckpt_fp)

                _fetches = sess.run(fetches, feeds)

                _step = _fetches['step']

            preview_fp = os.path.join(preview_dir, '{}.wav'.format(str(_step).zfill(8)))
            wavwrite(preview_fp, args.data_sample_rate, _fetches['G_z_flat_int16'])

            summary_writer.add_summary(_fetches['summaries'], _step)

            if args.wavegan_genr_pp:
                w, h = freqz(_fetches['pp_filter'])

                fig = plt.figure()
                plt.title('Digital filter frequency response')
                ax1 = fig.add_subplot(111)

                plt.plot(w, 20 * np.log10(abs(h)), 'b')
                plt.ylabel('Amplitude [dB]', color='b')
                plt.xlabel('Frequency [rad/sample]')

                ax2 = ax1.twinx()
                angles = np.unwrap(np.angle(h))
                plt.plot(w, angles, 'g')
                plt.ylabel('Angle (radians)', color='g')
                plt.grid()
                plt.axis('tight')

                _pp_fp = os.path.join(preview_dir, '{}_ppfilt.png'.format(str(_step).zfill(8)))
                plt.savefig(_pp_fp)

                with tf.Session(config=config) as sess:
                    _summary = sess.run(pp_summary, {pp_fp: _pp_fp})
                    summary_writer.add_summary(_summary, _step)
                
            print('Done')

            ckpt_fp = latest_ckpt_fp

        time.sleep(1)
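
In `preview`, `flat_pad` is fed with half the sample rate. The padding and reshape defined in `infer` therefore append half a second of silence to every generated clip and then join all clips into a single waveform, which is what ends up in the preview `.wav` file. The following NumPy sketch mirrors that step on toy data. The function name `flatten_with_gaps` and the tiny shapes are assumptions chosen for readability.

```python
import numpy as np

def flatten_with_gaps(batch, pad):
    """Append `pad` zero samples to each clip, then join the clips end to end.

    batch has shape [n_clips, n_samples, n_channels].
    """
    padded = np.pad(batch, [(0, 0), (0, pad), (0, 0)], mode="constant")
    return padded.reshape(-1, batch.shape[-1])

# Three toy "clips" of four samples each, mono.
batch = np.ones((3, 4, 1), dtype=np.float32)
flat = flatten_with_gaps(batch, pad=2)
print(flat.shape)
print(flat[:, 0])
```

Each clip is followed by its own block of zeros, so the clips remain audibly separated when the file is played back.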


In [13]:
"""
  Computes inception score every time a checkpoint is saved
"""
def incept(args):
    incept_dir = os.path.join(args.train_dir, 'incept')
    if not os.path.isdir(incept_dir):
        os.makedirs(incept_dir)

    # Load GAN graph
    gan_graph = tf.Graph()
    with gan_graph.as_default():
        infer_metagraph_fp = os.path.join(args.train_dir, 'infer', 'infer.meta')
        gan_saver = tf.train.import_meta_graph(infer_metagraph_fp)
        score_saver = tf.train.Saver(max_to_keep=1)
    gan_z = gan_graph.get_tensor_by_name('z:0')
    gan_G_z = gan_graph.get_tensor_by_name('G_z:0')[:, :, 0]
    gan_step = gan_graph.get_tensor_by_name('global_step:0')

    # Load or generate latents
    z_fp = os.path.join(incept_dir, 'z.pkl')
    if os.path.exists(z_fp):
        with open(z_fp, 'rb') as f:
            _zs = pickle.load(f)
    else:
        gan_samp_z_n = gan_graph.get_tensor_by_name('samp_z_n:0')
        gan_samp_z = gan_graph.get_tensor_by_name('samp_z:0')
        with tf.Session(graph=gan_graph) as sess:
            _zs = sess.run(gan_samp_z, {gan_samp_z_n: args.incept_n})
        with open(z_fp, 'wb') as f:
            pickle.dump(_zs, f)

    # Load classifier graph
    incept_graph = tf.Graph()
    with incept_graph.as_default():
        incept_saver = tf.train.import_meta_graph(args.incept_metagraph_fp)
    incept_x = incept_graph.get_tensor_by_name('x:0')
    incept_preds = incept_graph.get_tensor_by_name('scores:0')
    incept_sess = tf.Session(graph=incept_graph)
    incept_saver.restore(incept_sess, args.incept_ckpt_fp)

    # Create summaries
    summary_graph = tf.Graph()
    with summary_graph.as_default():
        incept_mean = tf.placeholder(tf.float32, [])
        incept_std = tf.placeholder(tf.float32, [])
        summaries = [
            tf.summary.scalar('incept_mean', incept_mean),
            tf.summary.scalar('incept_std', incept_std)
        ]
        summaries = tf.summary.merge(summaries)
    summary_writer = tf.summary.FileWriter(incept_dir)

    # Loop, waiting for checkpoints
    ckpt_fp = None
    _best_score = 0.
    while True:
        latest_ckpt_fp = tf.train.latest_checkpoint(args.train_dir)
        if latest_ckpt_fp != ckpt_fp:
            print('Incept: {}'.format(latest_ckpt_fp))

            sess = tf.Session(graph=gan_graph)

            gan_saver.restore(sess, latest_ckpt_fp)

            _step = sess.run(gan_step)

            _G_zs = []
            for i in xrange(0, args.incept_n, 100):
                _G_zs.append(sess.run(gan_G_z, {gan_z: _zs[i:i+100]}))
            _G_zs = np.concatenate(_G_zs, axis=0)

            _preds = []
            for i in xrange(0, args.incept_n, 100):
                _preds.append(incept_sess.run(incept_preds, {incept_x: _G_zs[i:i+100]}))
            _preds = np.concatenate(_preds, axis=0)

            # Split into k groups
            _incept_scores = []
            split_size = args.incept_n // args.incept_k
            for i in xrange(args.incept_k):
                _split = _preds[i * split_size:(i + 1) * split_size]
                _kl = _split * (np.log(_split) - np.log(np.expand_dims(np.mean(_split, 0), 0)))
                _kl = np.mean(np.sum(_kl, 1))
                _incept_scores.append(np.exp(_kl))

            _incept_mean, _incept_std = np.mean(_incept_scores), np.std(_incept_scores)

            # Summarize
            with tf.Session(graph=summary_graph) as summary_sess:
                _summaries = summary_sess.run(summaries, {incept_mean: _incept_mean, incept_std: _incept_std})
            summary_writer.add_summary(_summaries, _step)

            # Save
            if _incept_mean > _best_score:
                score_saver.save(sess, os.path.join(incept_dir, 'best_score'), _step)
                _best_score = _incept_mean

            sess.close()

            print('Done')

            ckpt_fp = latest_ckpt_fp

        time.sleep(1)

    incept_sess.close()
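
The scoring part of `incept` can be isolated from the TensorFlow sessions. For each of the `incept_k` splits it computes the exponential of the mean KL divergence between the per-sample class predictions and the marginal class distribution of that split, then reports the mean and standard deviation over the splits. The sketch below reproduces this computation with NumPy on synthetic predictions. The function name `inception_score` and the toy inputs are our own, and the sketch assumes strictly positive probabilities because `np.log(0)` is not guarded against here or in the cell above.

```python
import numpy as np

def inception_score(preds, k):
    """Mean and std of exp(mean KL(p(y|x) || p(y))) over k equal splits."""
    split_size = len(preds) // k
    scores = []
    for i in range(k):
        split = preds[i * split_size:(i + 1) * split_size]
        marginal = np.mean(split, axis=0, keepdims=True)
        kl = np.sum(split * (np.log(split) - np.log(marginal)), axis=1)
        scores.append(np.exp(np.mean(kl)))
    return np.mean(scores), np.std(scores)

# Case 1: the classifier is maximally unsure about every sample.
uniform = np.full((40, 4), 0.25)
uniform_mean, uniform_std = inception_score(uniform, k=5)

# Case 2: confident predictions spread evenly over the 4 classes.
confident = np.full((40, 4), 0.01)
confident[np.arange(40), np.arange(40) % 4] = 0.97
conf_mean, conf_std = inception_score(confident, k=5)

print(uniform_mean, conf_mean)
```

Uninformative predictions give the minimum score of 1, while confident and diverse predictions push the score toward the number of classes (4 in this toy setup).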

# Setting up the Parameters

In [14]:
import argparse
import glob
import sys
import os

In [15]:
class args:
    data_dir="./piano/train/16bit-1chan"
    train_dir="./train"
    data_sample_rate=16000
    data_slice_len=65536
    data_num_channels=1
    data_overlap_ratio=0.
    data_first_slice=False
    data_pad_end=False
    data_normalize=False
    data_fast_wav=True
    data_prefetch_gpu_num=0
    wavegan_latent_dim=100
    wavegan_kernel_len=25
    wavegan_dim=64
    wavegan_batchnorm=False
    wavegan_disc_nupdates=5
    wavegan_loss='wgan-gp'
    wavegan_genr_upsample='zeros'
    wavegan_genr_pp=True
    wavegan_genr_pp_len=512
    wavegan_disc_phaseshuffle=0
    train_batch_size=64
    train_save_secs=300
    train_summary_secs=120
    preview_n=32
    incept_metagraph_fp='./eval/inception/infer.meta'
    incept_ckpt_fp='./eval/inception/best_acc-103005'
    incept_n=5000
    incept_k=10
    
    wavegan_g_kwargs = {
    'slice_len': data_slice_len,
    'nch': data_num_channels,
    'kernel_len': wavegan_kernel_len,
    'dim': wavegan_dim,
    'use_batchnorm': wavegan_batchnorm,
    'upsample': wavegan_genr_upsample
    }
    
    wavegan_d_kwargs = {
    'kernel_len': wavegan_kernel_len,
    'dim': wavegan_dim,
    'use_batchnorm': wavegan_batchnorm,
    'phaseshuffle_rad': wavegan_disc_phaseshuffle
    }
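
To get a feeling for these values, the duration of one training slice and of one preview file can be worked out directly from the parameters. This is only a back-of-the-envelope sketch: the three values are copied from the `args` class above, and the preview formula assumes the half-second padding that `preview` feeds as `flat_pad`.

```python
data_sample_rate = 16000
data_slice_len = 65536
preview_n = 32

# Duration of a single generated or training slice.
slice_seconds = data_slice_len / data_sample_rate

# preview() pads every clip with half a second of silence before joining them.
flat_pad = int(data_sample_rate / 2)
preview_seconds = preview_n * (data_slice_len + flat_pad) / data_sample_rate

print(slice_seconds)    # 4.096
print(preview_seconds)  # 147.072
```

Each slice is thus about four seconds long, and every preview file holds roughly two and a half minutes of audio.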

In [16]:
print(args.data_dir)

./piano/train/16bit-1chan


In [17]:
fps = glob.glob(os.path.join(args.data_dir, '*'))

In [18]:
print(fps)

['./piano/train/16bit-1chan\\00.wav', './piano/train/16bit-1chan\\01.wav', './piano/train/16bit-1chan\\02.wav', './piano/train/16bit-1chan\\03.wav', './piano/train/16bit-1chan\\04.wav', './piano/train/16bit-1chan\\05.wav', './piano/train/16bit-1chan\\06.wav', './piano/train/16bit-1chan\\07.wav', './piano/train/16bit-1chan\\08.wav', './piano/train/16bit-1chan\\09.wav', './piano/train/16bit-1chan\\10.wav', './piano/train/16bit-1chan\\11.wav', './piano/train/16bit-1chan\\12.wav', './piano/train/16bit-1chan\\13.wav', './piano/train/16bit-1chan\\14.wav', './piano/train/16bit-1chan\\15.wav', './piano/train/16bit-1chan\\16.wav', './piano/train/16bit-1chan\\17.wav', './piano/train/16bit-1chan\\18.wav']


# Start Training

In [None]:
try:
    import cPickle as pickle
except ImportError:
    import pickle
from functools import reduce
import os
import time
from six.moves import xrange
import librosa

In [None]:
if len(fps) == 0:
    raise Exception('Did not find any audio files in specified directory')
print('Found {} audio files in specified directory'.format(len(fps)))
#infer(args)
train(fps, args)

Found 19 audio files in specified directory
Instructions for updating:
tf.py_func is deprecated in TF V2. Instead, use
    tf.py_function, which takes a python function which manipulates tf eager
    tensors instead of numpy arrays. It's easy to convert a tf eager tensor to
    an ndarray (just call tensor.numpy()) but having access to eager tensors
    means `tf.py_function`s can use accelerators such as GPUs as well as
    being differentiable using a gradient tape.
    

For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.

Instructions for updating:
Use keras.layers.dense instead.
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Use keras.layers.conv2d_transpose instead.
Instructions for updating:
Use keras.layers.conv1d instead.
-------------------------



INFO:tensorflow:Saving checkpoints for 1126 into ./train\model.ckpt.
Instructions for updating:
Use standard file APIs to delete files with this prefix.
INFO:tensorflow:Saving checkpoints for 1157 into ./train\model.ckpt.
INFO:tensorflow:Saving checkpoints for 1188 into ./train\model.ckpt.
INFO:tensorflow:global_step/sec: 0.0937732
INFO:tensorflow:Saving checkpoints for 1221 into ./train\model.ckpt.
INFO:tensorflow:Saving checkpoints for 1252 into ./train\model.ckpt.
INFO:tensorflow:Saving checkpoints for 1283 into ./train\model.ckpt.
INFO:tensorflow:global_step/sec: 0.10452
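
The checkpoint lines above can be cross-checked against the reported `global_step/sec` values. The sketch below is our own and assumes that consecutive checkpoints are roughly `train_save_secs` = 300 seconds apart, which the log itself does not state. It extracts the step numbers and derives an average training speed.

```python
import re

train_save_secs = 300  # value from the args class; assumed checkpoint spacing

log_lines = [
    r"INFO:tensorflow:Saving checkpoints for 1126 into ./train\model.ckpt.",
    r"INFO:tensorflow:Saving checkpoints for 1157 into ./train\model.ckpt.",
    r"INFO:tensorflow:Saving checkpoints for 1188 into ./train\model.ckpt.",
    r"INFO:tensorflow:Saving checkpoints for 1221 into ./train\model.ckpt.",
    r"INFO:tensorflow:Saving checkpoints for 1252 into ./train\model.ckpt.",
    r"INFO:tensorflow:Saving checkpoints for 1283 into ./train\model.ckpt.",
]

# Pull the global step out of every checkpoint line.
steps = []
for line in log_lines:
    match = re.search(r"Saving checkpoints for (\d+)", line)
    if match:
        steps.append(int(match.group(1)))

# Steps completed between consecutive checkpoints.
deltas = [b - a for a, b in zip(steps, steps[1:])]
avg_rate = sum(deltas) / (train_save_secs * len(deltas))

print(deltas)
print(round(avg_rate, 4), "steps/s")
print(round(1 / avg_rate, 1), "s per step")
```

Under this assumption the estimate is about 0.1 steps per second, which agrees with the `global_step/sec` values printed by TensorFlow and means that one training step took roughly ten seconds on our hardware.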


# References

<sup>[1]</sup> [waveGAN GitHub repository (Donahue et al., 2019)](https://github.com/chrisdonahue/wavegan)  
<sup>[2]</sup> [DCGAN (Radford et al., 2016)](https://arxiv.org/pdf/1511.06434.pdf)  
<sup>[3]</sup> [Bach piano performances](http://deepyeti.ucsd.edu/cdonahue/wavegan/data/mancini_piano.tar.gz)  
<sup>[4]</sup> [scipy.io.wavfile.read](https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.io.wavfile.read.html)  
<sup>[5]</sup> [Python assert](https://www.programiz.com/python-programming/assert-statement)  
<sup>[6]</sup> [tf.py_func](https://www.tensorflow.org/api_docs/python/tf/py_func)  
<sup>[7]</sup> [tf.data.Dataset.map](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map)  
<sup>[8]</sup> [tf.data.Dataset.flat_map](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#flat_map)  
