In [1]:
import cPickle as pickle
import numpy as np

def _load_pickle(path):
    """Open `path` in binary mode and return the object unpickled from it."""
    # NOTE: unpickling can execute arbitrary code; only load files you trust.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Pickled variable lists for the two generators, plus a batch of latents.
g1_vars = _load_pickle('gvars_g1.pkl')
g2_vars = _load_pickle('gvars_g2.pkl')
snapshot_z = _load_pickle('snapshot_z.pkl')

# Single-argument parenthesised print emits the same text as the bare
# Python 2 print statement.
print(snapshot_z.shape)

(8, 512)


In [28]:
# Index the first generator's variables by name and print each name with its
# shape. Iteration is bounded by zip(), so it stops at the shorter of the two
# variable lists.
g1n_to_var = {}
for g1_entry, g2_entry in zip(g1_vars, g2_vars):
    # The g2 pair is unpacked (so it must be a 2-item entry) but is not used.
    (g1n, g1v), (g2n, g2v) = g1_entry, g2_entry
    g1n_to_var[g1n] = g1v
    # Same text as `print g1n, g1v.shape`: the two values separated by a space.
    print('%s %s' % (g1n, g1v.shape))

G1a.W (512, 512, 16)
G1aS.b (512,)
G1b.W (512, 512, 9)
G1bS.b (512,)
G2a.W (512, 512, 9)
G2aS.b (512,)
G2b.W (512, 512, 9)
G2bS.b (512,)
G3a.W (512, 512, 9)
G3aS.b (512,)
G3b.W (512, 512, 9)
G3bS.b (512,)
G4a.W (512, 512, 9)
G4aS.b (512,)
G4b.W (512, 512, 9)
G4bS.b (512,)
G5a.W (256, 512, 9)
G5aS.b (256,)
G5b.W (256, 256, 9)
G5bS.b (256,)
G6a.W (128, 256, 9)
G6aS.b (128,)
G6b.W (128, 128, 9)
G6bS.b (128,)
Glod0.W (128, 1)
Glod0S.b (1,)
Glod1.W (256, 1)
Glod1S.b (1,)
Glod2.W (512, 1)
Glod2S.b (1,)
Glod3.W (512, 1)
Glod3S.b (1,)
Glod4.W (512, 1)
Glod4S.b (1,)
Glod5.W (512, 1)
Glod5S.b (1,)


In [56]:
import tensorflow as tf

def pixel_norm(x):
    """Divide `x` by the root of the mean of its squares over the last axis.

    A constant 1e-8 is added to the mean square before the square root, and
    the reduced axis is kept so the divisor broadcasts against `x`.
    """
    # NOTE: the last axis is reduced because tensors are laid out NWC, not NCW
    mean_square = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
    root_mean_square = tf.sqrt(mean_square + 1e-8)
    return x / root_mean_square

def wscale(x, fn):
    """Return `x` plus the stored variable looked up as `g1n_to_var[fn]`.

    `fn` is a key of `g1n_to_var`; the looked-up value is added with `+`, so
    it must be broadcastable against `x`.
    """
    # TODO: work out what this layer is supposed to do in the source model;
    # as written it only adds the stored array.
    offset = g1n_to_var[fn]
    return x + offset

def upscale(x, scale=4):
    """Resize a rank-3 tensor along its middle axis by a factor of `scale`.

    The static shape of `x` is unpacked into three entries and the middle one
    is treated as the width. A dummy axis of size 1 is inserted at position 1
    so the image resize op can be used with a target size of
    [1, width * scale], then that axis is removed again.
    """
    _, width, _ = x.get_shape().as_list()

    as_image = tf.expand_dims(x, axis=1)
    resized = tf.image.resize_nearest_neighbor(as_image, [1, width * scale])

    # Drop the dummy axis added above.
    return resized[:, 0]

def conv_filter(fn):
    """Return the stored variable `g1n_to_var[fn]` with its three axes reversed.

    The axes are permuted with order [2, 1, 0], so an array of shape
    (a, b, c) comes back as (c, b, a).
    """
    stored = g1n_to_var[fn]
    return np.transpose(stored, axes=[2, 1, 0])

def _conv_lrelu_wscale(x, weight_name, bias_name, padding='SAME'):
    """Apply conv1d, then leaky ReLU (alpha=0.2), then wscale.

    `weight_name` and `bias_name` are keys of `g1n_to_var`; the filter is
    obtained through `conv_filter` and the final step through `wscale`.
    The convolution uses stride 1 and NWC layout.
    """
    x = tf.nn.conv1d(x, conv_filter(weight_name), 1, padding=padding, data_format='NWC')
    x = tf.nn.leaky_relu(x, alpha=0.2)
    return wscale(x, bias_name)


# Variable names for the stages that start with an upscale, in graph order:
# (first conv weights, first bias, second conv weights, second bias).
_UPSCALED_STAGES = [
    ('G2a.W', 'G2aS.b', 'G2b.W', 'G2bS.b'),  # Conv2
    ('G3a.W', 'G3aS.b', 'G3b.W', 'G3bS.b'),  # Conv3
    ('G4a.W', 'G4aS.b', 'G4b.W', 'G4bS.b'),  # Conv4
    ('G5a.W', 'G5aS.b', 'G5b.W', 'G5bS.b'),  # Conv5
    ('G6a.W', 'G6aS.b', 'G6b.W', 'G6bS.b'),  # Conv6
]

# Start from an empty default graph so re-running this cell does not pile up
# duplicate ops.
tf.reset_default_graph()

z = tf.placeholder(tf.float32, [None, 512])

z_norm = pixel_norm(z)

# Insert a width axis of size 1: [batch, 512] -> [batch, 1, 512]
x = tf.expand_dims(z_norm, axis=1)

# Conv1 (projects latents)
# Zero-pad the width axis by 15 on each side and convolve with VALID padding.
x = tf.pad(x, [[0, 0], [15, 15], [0, 0]])
x = _conv_lrelu_wscale(x, 'G1a.W', 'G1aS.b', padding='VALID')
x = _conv_lrelu_wscale(x, 'G1b.W', 'G1bS.b')

# Conv2 .. Conv6: upscale, then two conv/leaky-ReLU/wscale layers each
for first_w, first_b, second_w, second_b in _UPSCALED_STAGES:
    x = upscale(x)
    x = _conv_lrelu_wscale(x, first_w, first_b)
    x = _conv_lrelu_wscale(x, second_w, second_b)

# Aggregate
# Reshape the stored 'Glod0.W' array into a width-1 filter with 128 input
# channels and 1 output channel; no activation is applied after this conv.
f = np.reshape(g1n_to_var['Glod0.W'], [1, 128, 1])
x = tf.nn.conv1d(x, f, 1, padding='VALID', data_format='NWC')
Gz = wscale(x, 'Glod0S.b')

In [71]:
from IPython.display import display, Audio

# Evaluate the generator output for the loaded latents.
with tf.Session() as sess:
    _Gz = sess.run(Gz, feed_dict={z: snapshot_z})

# Play batch item 1, channel 0 of the fetched array at a 16000 Hz rate.
generated_clip = Audio(_Gz[1, :, 0], rate=16000)
display(generated_clip)

In [70]:
from scipy.io.wavfile import read as wavread

# Read the reference wav; the first returned value (sample rate) is discarded.
_, ref_samples = wavread('fakes008400.wav')

# Split the samples into 8 rows, keep the first 16384 samples of each row and
# convert to float32.
ref = np.reshape(ref_samples, [8, -1])[:, :16384].astype(np.float32)

# Scale by 1/32767 -- presumably the file holds 16-bit PCM; confirm.
ref /= 32767.

# Play row 1 at a 16000 Hz rate.
display(Audio(ref[1], rate=16000))