In [None]:
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)

import librosa
from librosa.display import specshow
from librosa.display import waveplot

import IPython.display as ipyd

from scipy.signal import stft as sp_stft
from scipy.signal import istft as sp_istft

from sklearn.preprocessing import OneHotEncoder

import tensorflow as tf
tf.set_random_seed(42)

import time

import keras
import keras.backend as K

from keras.models import Model, Sequential
from keras.layers import Input, Dense, Reshape, Flatten
from keras.layers import Lambda, ZeroPadding2D, Embedding, Multiply
from keras.layers import Conv2D, Concatenate
from keras.layers import Conv2DTranspose as Conv2DT
from keras.layers import BatchNormalization as BatchNorm
from keras.layers import LeakyReLU, Activation, ELU
from keras.optimizers import Adam, RMSprop
from keras.initializers import RandomNormal
from keras import regularizers 

In [None]:
# Gaussian initializers shared by the network builders below, plus the L2
# kernel penalty that the discriminator's convolutions use.
conv_init = RandomNormal(mean=0, stddev=0.02)     # convolution kernels
gamma_init = RandomNormal(mean=1., stddev=0.02)   # batch-norm scale (gamma)
l2_reg = regularizers.l2(1e-5)

Code is modified from https://github.com/tjwei/GANotebooks/blob/master/wgan-keras.ipynb

In [None]:
def MUSIC_DCGAN_D(n_freq, n_time, n_c, n_df, n_class=None):
    """Build the convolutional critic (discriminator) for spectrogram images.

    Layer layout, in order:
      1. one strided 4x4 convolution that halves both axes;
      2. "extra" (1, 3) convolutions with stride (1, 2) that halve only the
         time axis, repeated while the time width exceeds the frequency height;
      3. a "pyramid" of strided 4x4 convolutions that halve both axes and
         double the channel count, repeated while the frequency height is
         above 4;
      4. a single-filter convolution whose kernel covers the remaining
         feature map, flattened to one un-activated score per sample.

    Args:
        n_freq: input height (frequency bins).
        n_time: input width (time frames).
        n_c: number of input channels.
        n_df: number of filters in the first convolution.
        n_class: accepted for interface compatibility; currently unused
            (no class-conditional output head is built).

    Returns:
        A Keras ``Model`` named ``'discriminator'`` mapping
        ``(n_freq, n_time, n_c)`` inputs to a flattened score.
    """
    ks = (4, 4)
    strides = (2, 2)

    # Settings common to every 'same'-padded convolution in this network.
    # XXX (original author): why not use a bias?
    same_conv = dict(kernel_regularizer=l2_reg, kernel_initializer=conv_init,
                     padding='same', use_bias=False)

    inputs = Input(shape=(n_freq, n_time, n_c))

    x = Conv2D(filters=n_df, kernel_size=ks, strides=strides,
               name='initial.conv.{0}-{1}'.format(n_c, n_df),
               **same_conv)(inputs)
    x = LeakyReLU(alpha=0.2, name='initial.relu.{0}'.format(n_df))(x)

    # Running bookkeeping of the current feature-map size and channel count.
    c_nfreq = n_freq // strides[0]
    c_ntime = n_time // strides[1]
    c_ndf = n_df

    # Shrink the time axis alone until the feature map is no wider than tall.
    n_extra = 0
    while c_ntime > c_nfreq:
        x = Conv2D(filters=n_df, kernel_size=(1, 3), strides=(1, 2),
                   name='extra.{0}-{1}.conv'.format(n_extra, c_ntime),
                   **same_conv)(x)
        x = LeakyReLU(alpha=0.2,
                      name='extra.{0}-{1}.relu'.format(n_extra, c_ntime))(x)
        n_extra += 1
        c_ntime //= 2

    # Pyramid: halve both axes and double the channels on every step.
    while c_nfreq > 4:
        n_in_feats, n_op_feats = c_ndf, c_ndf * 2

        x = Conv2D(filters=n_op_feats, kernel_size=ks, strides=strides,
                   name='pyramid.{0}-{1}.conv'.format(n_in_feats, n_op_feats),
                   **same_conv)(x)
        x = LeakyReLU(alpha=0.2,
                      name='pyramid.{0}.relu'.format(n_op_feats))(x)

        c_nfreq //= strides[0]
        c_ntime //= strides[1]
        c_ndf = n_op_feats

    # Final convolution: kernel spans the tracked feature-map size and uses
    # the layer's default padding, so it is not part of `same_conv`.
    score = Conv2D(filters=1, kernel_size=(c_nfreq, c_ntime), strides=1,
                   use_bias=False, kernel_initializer=conv_init,
                   kernel_regularizer=l2_reg,
                   name='final.{0}-{1}.conv'.format(c_ndf, 1))(x)
    score = Flatten()(score)

    return Model(inputs=[inputs], outputs=[score], name='discriminator')

In [None]:
def MUSIC_DCGAN_G(n_z, n_gf, n_freq, n_time, n_c, n_class=None):
    """Build the transposed-convolution generator for spectrogram images.

    Layer layout, in order:
      1. the latent vector is reshaped to 1x1 and expanded by a transposed
         convolution with a square ``(c_nfreq, c_nfreq)`` kernel;
      2. a "pyramid" of stride-2 transposed convolutions doubles both axes
         and halves the channel count until the height is ``n_freq // 2``;
      3. "extra" (1, 3) transposed convolutions with stride (1, 2) double
         only the time axis until the width is ``n_time // 2``;
      4. a final stride-2 transposed convolution with ``tanh`` produces the
         ``n_c``-channel output.

    Every BatchNorm layer is invoked with ``training=1``.

    The number of "extra" layers is derived from the *actual* width of the
    feature map.  The previous implementation compared ``n_time`` against a
    value obtained by repeatedly halving ``n_time``, which only produced the
    requested output shape when ``n_time == n_freq ** 2 / 4`` (e.g. 64x1024)
    and generated the wrong width otherwise (e.g. 64x64 came out as 64x1024).
    Layer names are unchanged for the configurations that worked before.

    Args:
        n_z: length of the latent input vector.
        n_gf: channel count fed into the final transposed convolution.
        n_freq: output height (frequency bins).
        n_time: output width (time frames).  Widths narrower than ``n_freq``
            are not supported (there is no time-only down-scaling path).
        n_c: number of output channels.
        n_class: accepted for interface compatibility; currently unused
            (the class-conditional input path is disabled).

    Returns:
        A Keras ``Model`` named ``'generator'`` mapping ``(n_z,)`` inputs to
        ``tanh``-activated images.
    """
    ks = (4, 4)
    strides = (2, 2)

    # Derive the seed feature-map height and its channel count by walking the
    # target height down to <= 4, doubling the channels on every halving.
    c_ngf = n_gf // 2
    c_nfreq = n_freq
    while c_nfreq > 4:
        c_ngf = c_ngf * 2
        c_nfreq = c_nfreq // strides[0]

    # The first transposed convolution applies a square (c_nfreq, c_nfreq)
    # kernel to a 1x1 input, so the seed feature map is c_nfreq wide as well.
    c_ntime = c_nfreq

    input_z = Input(shape=(n_z, ))
    x = Reshape((1, 1, n_z))(input_z)

    x = Conv2DT(filters=c_ngf, kernel_size=(c_nfreq, c_nfreq), strides=1, use_bias=False,
            kernel_initializer=conv_init,
            name='initial.{0}-{1}.convt'.format(n_z, c_ngf))(x)
    x = BatchNorm(epsilon=1.01e-5, momentum=0.9,
            name='initial.{0}.bn'.format(c_ngf))(x, training=1)
    x = Activation("relu",
            name='initial.{0}.relu'.format(c_ngf))(x)

    # Pyramid: double both axes, halve the channels.
    while c_nfreq < n_freq // 2:
        n_in_feats = c_ngf
        n_op_feats = c_ngf // 2

        x = Conv2DT(filters=n_op_feats, kernel_size=ks, strides=strides, use_bias=False,
                kernel_initializer=conv_init, padding='same',
                name='pyramid.{0}-{1}.convt'.format(n_in_feats, n_op_feats))(x)
        x = BatchNorm(gamma_initializer=gamma_init, epsilon=1.01e-5, momentum=0.9,
                name='pyramid.{0}.bn'.format(n_op_feats))(x, training=1)
        x = Activation("relu",
                name='pyramid.{0}.relu'.format(n_op_feats))(x)

        c_nfreq = c_nfreq * strides[0]
        c_ntime = c_ntime * strides[1]
        c_ngf = c_ngf // 2

    # Extra layers: double the time axis only, until the final stride-2 layer
    # will land exactly on n_time.  Each layer is labelled with the width it
    # produces.
    n_extra = 0
    while c_ntime < n_time // strides[1]:
        c_ntime = c_ntime * 2
        x = Conv2DT(filters=c_ngf, kernel_size=(1, 3), strides=(1, 2), use_bias=False,
                kernel_initializer=conv_init, padding='same',
                name='extra.{0}-{1}.convt'.format(n_extra, c_ntime))(x)
        x = BatchNorm(gamma_initializer=gamma_init, epsilon=1.01e-5, momentum=0.9,
                name='extra.{0}-{1}.bn'.format(n_extra, c_ntime))(x, training=1)
        x = Activation("relu",
                name='extra.{0}-{1}.relu'.format(n_extra, c_ntime))(x)
        n_extra += 1

    x = Conv2DT(filters=n_c, kernel_size=ks, strides=strides, use_bias=False,
            kernel_initializer=conv_init, padding='same',
            name='final.{0}-{1}.convt'.format(c_ngf, n_c))(x)
    outputs = Activation('tanh')(x)

    return Model(inputs=[input_z], outputs=[outputs], name='generator')

### Define WGAN

In [None]:
def get_real_sample(X, batch_size=32, y=None):
    """Draw a random mini-batch of rows from ``X`` (sampling with replacement).

    Args:
        X: array of samples, indexed along its first axis.
        batch_size: number of rows to draw.
        y: optional labels aligned with ``X``.

    Returns:
        ``X_true`` alone when ``y`` is None, otherwise ``(X_true, y_true)``
        drawn with the same indices.
    """
    picks = np.random.randint(0, X.shape[0], batch_size)
    X_true = X[picks, :]

    if y is None:
        return X_true

    return X_true, y[picks]

        
def next_batch(X, y, batch_size=32, shuffle=True):
    """
    Iterate once over the data in mini-batches.

    Inputs:
    - X: input data, indexed along its first axis
    - y: labels aligned with X (1-D or N-D)
    - batch_size: maximum number of samples per batch; the last batch is
      smaller when the sample count is not a multiple of batch_size
    - shuffle: visit the samples in a random order when True

    Yields tuples of batch data
    - X_batch: batch sampled X
    - y_batch: batch sampled y
    """

    n_sample = X.shape[0]
    # Ceiling division.  The previous test `n_sample % n_batch` compared
    # against the wrong quantity: it could silently drop the trailing partial
    # batch and divided by zero whenever n_sample < batch_size.
    n_batch = (n_sample + batch_size - 1) // batch_size
    idx = np.arange(n_sample)

    if shuffle:
        np.random.shuffle(idx)

    for b_idx in range(n_batch):
        start = b_idx * batch_size
        # Slicing clamps at the end of the array, so the final batch simply
        # comes out shorter.
        sample_idx = idx[start:start + batch_size]

        # First-axis indexing works for 1-D label vectors as well as 2-D.
        yield X[sample_idx], y[sample_idx]

In [None]:
# data parameters
# --- data layout ---------------------------------------------------------
n_c = 2                      # image channels (real / imaginary STFT parts)
n_freq, n_time = 64, 1024    # spectrogram height x width
n_class = 2                  # not passed to either network builder below

# --- GAN hyper-parameters ------------------------------------------------
n_z = 128                    # latent vector length
n_gf = n_df = 64             # base filter counts: generator / discriminator
g_lambda = 10                # weight of the gradient-penalty term

# --- optimiser learning rates (discriminator, generator) -----------------
lrD = lrG = 1e-4

In [None]:
# Instantiate the discriminator for the configured input shape and list its layers.
netD = MUSIC_DCGAN_D(n_freq=n_freq, n_time=n_time, n_c=n_c, n_df=n_df)
netD.summary()

In [None]:
# Instantiate the generator for the configured output shape and list its layers.
netG = MUSIC_DCGAN_G(n_z=n_z, n_gf=n_gf, n_freq=n_freq, n_time=n_time, n_c=n_c)
netG.summary()

### Train helper

## WGAN loss

#### Discriminator loss

In [None]:
# Code is from https://github.com/tjwei/GANotebooks/blob/master/wgan-keras.ipynb
#
# Builds `netD_train`, a backend function that performs one Adam step on the
# discriminator's weights and returns the mean discriminator outputs on the
# real and on the generated batch.

# Symbolic inputs: a batch of real spectrogram images and a batch of latent
# vectors; the generated batch is netG applied to the latent vectors.
netD_real_input = Input(shape=(n_freq, n_time, n_c))
noisev = Input(shape=(n_z,))
netD_fake_input = netG(noisev)

# One mixing coefficient per sample; shape (None, 1, 1, 1) broadcasts over the
# image axes.  The mixed input is e * real + (1 - e) * fake.
e_input = K.placeholder(shape=(None,1,1,1))
netD_mixed_input = Input(shape=(n_freq, n_time, n_c),
    tensor=e_input * netD_real_input + (1 - e_input) * netD_fake_input)


# Mean discriminator output on each batch.
loss_real = K.mean(netD(netD_real_input))
loss_fake = K.mean(netD(netD_fake_input))

# Gradient penalty: squared deviation from 1 of the per-sample L2 norm of the
# discriminator's gradient with respect to the mixed input.
grad_mixed = K.gradients(netD(netD_mixed_input), [netD_mixed_input])[0]
norm_grad_mixed = K.sqrt(K.sum(K.square(grad_mixed), axis=[1,2,3]))
grad_penalty = K.mean(K.square(norm_grad_mixed -1))

# NOTE(review): despite the name, this is the loss minimised over netD's
# weights (see get_updates below) -- i.e. the discriminator loss.  The "g"
# presumably refers to the gradient-penalty variant; consider renaming.
wgan_g_loss = loss_fake - loss_real + g_lambda * grad_penalty


# Adam updates restricted to the discriminator's trainable weights.
training_updates = Adam(lr=lrD, beta_1=0.5, beta_2=0.9).get_updates(netD.trainable_weights,[], wgan_g_loss)
# Call as netD_train([real_batch, noise_batch, e_batch]) -> [loss_real, loss_fake].
netD_train = K.function([netD_real_input, noisev, e_input],
                        [loss_real, loss_fake],    
                        training_updates)

#### Generator loss

In [None]:
# Generator objective: the negated mean discriminator output on generated
# samples.  Only netG's trainable weights receive updates.
# NOTE(review): despite the name, `wgan_d_loss` is the loss minimised over the
# generator's weights.
wgan_d_loss = -loss_fake
training_updates = Adam(
    lr=lrG, beta_1=0.5, beta_2=0.9
).get_updates(netG.trainable_weights, [], wgan_d_loss)
# Call as netG_train([noise_batch]) -> [generator loss].
netG_train = K.function(
    inputs=[noisev],
    outputs=[wgan_d_loss],
    updates=training_updates
)

### Train

### Music utility

In [None]:
# def get_audio_sample(X, y, batch_size, rand_slice=False):
    
#     #idx = np.random.choice(X.shape[0], batch_size)
    
#     X_batch = X
#     y_batch = y
    
#     X_batch = audio_to_spec_64x1024(X_batch, rand_slice)
    
#     return X_batch, y_batch


def audio_to_spec_64x1024(X, rand_slice=True):
    """
    Convert a batch of waveforms into 2-channel 64x1024 spectrum images.

    A window of 64 * 1024 samples is cut from every waveform (the same offset
    for the whole batch), transformed with an STFT (nperseg=128), cropped to
    the first 64 frequency bins and 1024 frames, and scaled so that the
    largest magnitude in each sample is 1.  The real and imaginary parts are
    returned as two separate channels.

    Inputs:
    - X: array of waveforms, shape (batch, n_samples)
    - rand_slice: pick the window offset at random when True, otherwise
      start at sample 0

    Returns:
    - array of shape (batch, 64, 1024, 2); channel 0 is the real part and
      channel 1 the imaginary part.  (Waveforms shorter than 64 * 1024
      samples yield fewer frames when rand_slice is False.)

    Raises:
    - ValueError: rand_slice is True and the waveforms are shorter than
      64 * 1024 samples.
    """

    n_fft = 128
    n_pt = 64 * 1024
    n_len = X.shape[1]

    start = 0
    if rand_slice:
        max_start = n_len - n_pt
        if max_start < 0:
            raise ValueError(
                'waveforms have {} samples but at least {} are required'
                .format(n_len, n_pt))
        # np.random.choice(0) raises, so a clip that is exactly n_pt long
        # keeps start = 0 instead of crashing.
        if max_start > 0:
            start, = np.random.choice(max_start, 1)

    X_slice = X[:, start:start + n_pt]
    _, _, z = sp_stft(X_slice, nperseg=n_fft)

    z = z[:, :64, :1024]

    # Per-sample peak normalisation.  A silent clip has a zero peak; dividing
    # by it would fill the sample with NaNs, so such clips are left as zeros.
    peak = np.max(np.abs(z), axis=(1, 2), keepdims=True)
    peak = np.where(peak > 0, peak, 1.0)
    z = z / peak

    # Split the complex spectrum into (real, imaginary) channels.
    return np.stack([np.real(z), np.imag(z)], axis=-1)

def spec_to_audio(X, n_fft=128):
    """
    Convert a batch of 2-channel spectrum images back to waveforms.

    The spectrum images in this notebook keep only the first 64 of the 65
    one-sided bins of a 128-point STFT.  Left to its defaults, the inverse
    STFT infers the segment length from the bin count (2 * (64 - 1) = 126),
    which does not match the forward transform.  The missing top bins are
    therefore zero-padded and the segment length is passed explicitly.

    Inputs:
    - X: array of shape (batch, n_bins, n_frames, 2) holding the real and
      imaginary parts
    - n_fft: segment length of the forward STFT (default 128)

    Returns:
    - array of reconstructed waveforms, one row per sample
    """
    xx = spec2ch_to_spec(X)

    # Restore the frequency bins that were cropped away (filled with zeros)
    # so the spectrum has the n_fft // 2 + 1 rows a one-sided STFT expects.
    n_bins = n_fft // 2 + 1
    missing = n_bins - xx.shape[-2]
    if missing > 0:
        pad_width = [(0, 0)] * (xx.ndim - 2) + [(0, missing), (0, 0)]
        xx = np.pad(xx, pad_width, mode='constant')

    _, recon = sp_istft(xx, nperseg=n_fft)

    return recon

def spec2ch_to_spec(X):
    """
    Convert a 2-channel spectrum image to a complex spectrum.

    Channel 0 of the last axis is taken as the real part and channel 1 as the
    imaginary part; X must be 4-dimensional.
    """
    real_part = X[:, :, :, 0]
    imag_part = X[:, :, :, 1]

    return real_part + imag_part * 1j

### Load music data

In [None]:
# The next cell indexes this array as rows of (waveform, label), which implies
# an object array; NumPy >= 1.16.3 refuses to load those unless pickling is
# explicitly allowed.  Unpickling can execute arbitrary code -- only load
# dataset files you trust.
X_train = np.load('data/midi_2000.npy', allow_pickle=True)

In [None]:
# Split the loaded rows: column 1 is kept as the labels and the column-0
# entries are stacked into one 2-D array (presumably one waveform per row --
# TODO confirm against the dataset).  The labels are taken first because
# X_train is rebound on the following line.
y_train = X_train[:, 1]
X_train = np.vstack(X_train[:, 0])

In [None]:
batch_size = 16

# Training runs over epochs in [base, epochs).
base, epochs = 0, 1000

# Generator-step counter and the default number of discriminator updates per
# generator update.
g_step = 0
N_CRITICS = 5

# Cut a random window from every waveform when building real batches.
rand_slice = True

# Fixed latent vectors, reused to visualise generator progress.
fixed_noise = np.random.normal(loc=0.0, scale=1.0, size=(batch_size, n_z))

# Per-epoch mean losses.
d_loss_hist, g_loss_hist = [], []

In [None]:
# WGAN training loop: for every mini-batch the discriminator is updated
# several times, then the generator once.  After each epoch the losses are
# logged, a generated and a real spectrogram are plotted, and every 50 epochs
# both networks are saved.
t0 = time.time()

for epoch in np.arange(base, epochs):

    # shuffle data
    idx = np.arange(len(X_train))
    np.random.shuffle(idx)
    X_train = X_train[idx]
    y_train = y_train[idx]
    n_sample = X_train.shape[0]
    # Ceiling division.  The previous test `n_sample % n_batch` compared
    # against the wrong quantity: it could drop the trailing partial batch and
    # divided by zero when n_sample < batch_size.
    n_batch = (n_sample + batch_size - 1) // batch_size

    d_loss = []
    g_loss = []

    for n_b in range(n_batch):
        start = n_b * batch_size
        # Slicing clamps at the end of the data, so the last batch may be short.
        X_batch = X_train[start:start + batch_size]
        n_real = X_batch.shape[0]

        # Give the discriminator many more updates during the first 25
        # generator steps and again on every 500th step.
        n_critics = N_CRITICS
        if g_step < 25 or g_step % 500 == 0:
            n_critics = 100

        ######################
        # Train discriminator
        ######################
        for _ in range(n_critics):

            # A new random window of the same waveforms on every iteration.
            real_data = audio_to_spec_64x1024(X_batch, rand_slice)

            # The mixed input is e * real + (1 - e) * fake, so the noise and
            # mixing coefficients must have as many rows as the real batch;
            # sizing them with batch_size broke on a short final batch.
            noise = np.random.normal(0.0, 1.0, size=(n_real, n_z))
            e = np.random.uniform(size=(n_real, 1, 1, 1))
            err_real, err_fake = netD_train([real_data, noise, e])

            # paper suggests this as the index of WGAN training: the smaller the better
            wgan_d_loss = err_real - err_fake

        ###################
        # Train generator
        ###################
        noise = np.random.normal(0.0, 1.0, size=(batch_size, n_z))
        wgan_g_loss, = netG_train([noise])

        d_loss.append(wgan_d_loss)
        g_loss.append(wgan_g_loss)
        g_step += 1

    # Values from the last mini-batch of the epoch.
    print('[{:3d}/{:3d}] Time: {:.2f} D: {:.6f} G: {:.6f}'
        .format(epoch, epochs, time.time() - t0,
                     wgan_d_loss, wgan_g_loss))

    d_loss_hist.append(np.mean(d_loss))
    g_loss_hist.append(np.mean(g_loss))

    print('Epochs {:d}'.format(epoch))
    print("Visualize STFT")
    gen = netG.predict(fixed_noise)
    # np.abs() yields magnitudes, so the dB conversion must be
    # amplitude_to_db (20*log10); power_to_db expects squared magnitudes.
    plt.figure(figsize=(12, 3))
    plt.subplot(211)    # first generated sample
    specshow(librosa.amplitude_to_db(np.abs(spec2ch_to_spec(gen)[0])))
    plt.subplot(212)    # first sample of the last real batch
    specshow(librosa.amplitude_to_db(np.abs(spec2ch_to_spec(real_data)[0])))

    print("Loss WGAN DIS {:.6f}".format(np.mean(d_loss)))
    print("Loss WGAN GEN {:.6f}".format(np.mean(g_loss)))
    plt.figure(figsize=(8, 3))
    plt.plot(d_loss_hist, c='green', label='DISC')
    plt.plot(g_loss_hist, c='red', label='GEN')
    plt.legend()
    plt.xlim([0, epochs])
    plt.show()
    # Two figures are created per epoch; release both.
    plt.close('all')

    # Keep the notebook output from growing without bound.
    if (epoch + 1) % 20 == 0:
        ipyd.clear_output()

    if epoch % 50 == 0:
        netD.save('weights/netD_epoch_{:d}_v3.h5'.format(epoch))
        netG.save('weights/netG_epoch_{:d}_v3.h5'.format(epoch))

In [None]:
# Save final result
# Persist everything needed to reproduce the result plots: the fixed latent
# vectors, both loss histories and both networks.
np.save('weights/fixed_noise_seed.npy', fixed_noise)

for hist_path, hist in (
        ('weights/d_loss_hist_epoch_{:d}.npy'.format(epochs), d_loss_hist),
        ('weights/g_loss_hist_epoch_{:d}.npy'.format(epochs), g_loss_hist)):
    np.save(hist_path, np.array(hist))

for net_path, net in (
        ('weights/netD_epoch_{:d}_v3.h5'.format(epochs), netD),
        ('weights/netG_epoch_{:d}_v3.h5'.format(epochs), netG)):
    net.save(net_path)

### Show result

In [None]:
# Generate spectrum images from the fixed latent vectors, then convert them
# back to waveforms.
gen = netG.predict(x=fixed_noise)
gen_recon = spec_to_audio(X=gen)

In [None]:
# For every generated sample: plot its spectrogram and waveform, then embed an
# audio player.
gen_spec = spec2ch_to_spec(gen)   # complex spectra, converted once for the batch

for i in range(batch_size):
    print('Generate result {:d}'.format(i + 1))
    fig = plt.figure(figsize=(8, 5))
    plt.subplot(211)
    plt.title('STFT Linear-frequency power spectrogram')
    # np.abs() yields magnitudes, so convert with amplitude_to_db
    # (20*log10); power_to_db expects squared magnitudes and would halve
    # the dB scale.
    specshow(librosa.amplitude_to_db(np.abs(gen_spec[i])),
             hop_length=64, sr=2000,
             y_axis='linear', x_axis='time')
    plt.colorbar(format='%+2.0f dB')

    plt.subplot(212)
    plt.title('WAV (mono)')
    waveplot(gen_recon[i], sr=2000)
    fig.tight_layout()
    plt.show()
    plt.close()
    ipyd.display(ipyd.Audio(gen_recon[i], rate=2000))

In [None]:
# Show one real training example for comparison with the generated results.
# `real_data` is the last real batch built by the training loop.  The index is
# chosen explicitly instead of relying on a loop variable left over from an
# earlier cell, which could also be out of range for a short batch.
real_idx = 0
# Previously `real_recon` was used without ever being defined.
real_recon = spec_to_audio(real_data)

fig = plt.figure(figsize=(8, 5))
plt.subplot(211)
plt.title('STFT Linear-frequency power spectrogram')
# np.abs() yields magnitudes, so convert with amplitude_to_db (20*log10).
specshow(librosa.amplitude_to_db(np.abs(spec2ch_to_spec(real_data)[real_idx])),
         hop_length=64, sr=2000,
         y_axis='linear', x_axis='time')
plt.colorbar(format='%+2.0f dB')

plt.subplot(212)
plt.title('WAV (mono)')
waveplot(real_recon[real_idx], sr=2000)
fig.tight_layout()
plt.show()
plt.close()
# Play the real clip (the generated clip was played here by mistake).
ipyd.display(ipyd.Audio(real_recon[real_idx], rate=2000))