In [12]:
import tensorflow as tf
import librosa
import numpy as np
from IPython.display import Audio
from IPython.core.display import display
from tensorflow.python.client import device_lib
import matplotlib.pyplot as plt
%matplotlib inline
import sys

In [10]:
# --- Audio / STFT parameters ---
SR = 22050               # sample rate passed to librosa.load and used for playback
N_DFT = 2048             # STFT frame length (samples)
HOP_LEN = N_DFT // 4     # STFT hop between consecutive frames (samples)

# --- Feature network ---
N_FILTERS_STFT = 4096
kernel_stft = [1,11]
num_stft_layers = 1

# One layer uses VALID padding; deeper stacks use SAME, presumably so stacked
# stride-1 layers keep the frame-axis length unchanged.
padding_stft = "VALID" if num_stft_layers == 1 else "SAME"

initializer = tf.keras.initializers.lecun_normal()

# Relative weights of each loss term, applied after gradient-norm balancing.
loss_dict_coeffs = {'style':{'stft':10.0},
                    'content':{'stft':1.0}}
maxiter = 10000

# Seed numpy so the random initialisation of the target signal is reproducible
# under Restart & Run All.
SEED = 42
np.random.seed(SEED)

In [3]:
def normalize(y):
    """Scale ``y`` so that its largest absolute value becomes 1.

    Args:
        y: numpy array (or array-like) of samples.

    Returns:
        ``y / max(|y|)``. Empty or all-zero input is returned unchanged rather
        than raising (empty) or filling the result with NaNs (division by 0).
    """
    y = np.asarray(y)
    peak = np.abs(y).max() if y.size else 0
    if peak == 0:
        # Nothing to scale; avoid 0/0 -> NaN.
        return y
    return y / peak

def load_audio(fn, sr=SR, times=(0, 10)):
    """Load an audio file, cut out a time window and peak-normalise it.

    Args:
        fn: path to the audio file, passed to ``librosa.load``.
        sr: target sample rate for ``librosa.load``. ``None`` keeps the file's
            native rate; the rate librosa actually returns is used for cutting.
        times: ``(start, stop)`` in seconds (int or float). The slice is
            half-open, i.e. samples ``[start*rate, stop*rate)``.

    Returns:
        float32 numpy array of the selected samples, scaled to peak |x| == 1.
    """
    audio, rate = librosa.load(fn, sr=sr)
    # Round to integer sample indices so fractional-second windows work.
    start, stop = (int(round(rate * t)) for t in times)
    return normalize(audio.astype(np.float32)[start:stop])

def get_available_devices():
    """Return the names of all local GPU devices.

    Falls back to ``['/cpu:0']`` when no GPU is listed, so callers can always
    take element ``[0]``.
    """
    gpu_names = [
        device.name
        for device in device_lib.list_local_devices()
        if device.device_type == 'GPU'
    ]
    return gpu_names if gpu_names else ['/cpu:0']

In [4]:
# Style clip: file path and [start, stop) window in seconds.
style_fn, style_times = 'examples/usa.wav', [0,10]
# Content clip: file path and [start, stop) window in seconds.
content_fn, content_times = 'examples/imperial.wav', [11,21]

# Both results are peak-normalised float32 sample arrays at rate SR.
x_content = load_audio(content_fn, SR, content_times)
x_style = load_audio(style_fn, SR, style_times)

In [5]:
def elu(x, alpha=1.):
    """Exponential linear unit with a configurable negative-side scale.

    Computes ``x`` for ``x > 0`` and ``alpha * (exp(x) - 1)`` otherwise.

    # Arguments
        x: A tensor or variable to compute the activation function for.
        alpha: A scalar scaling the negative (x <= 0) section of the curve.
    # Returns
        A tensor with the same shape as ``x``.
    """
    res = tf.nn.elu(x)
    if alpha == 1:
        # tf.nn.elu already is the alpha == 1 case.
        return res
    # tf.nn.elu yields exp(x) - 1 where x <= 0; rescale only that branch.
    return tf.where(x > 0, res, alpha * res)

def selu(x):
    """Scaled Exponential Linear Unit (Klambauer et al., 2017).

    Evaluates ``scale * elu(x, alpha)`` with the fixed alpha/scale constants
    given in the paper.

    # Arguments
        x: A tensor or variable to compute the activation function for.
    # References
        - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
    """
    selu_alpha = 1.6732632423543772848170429916717
    selu_scale = 1.0507009873554804934193349852946
    scaled_elu = elu(x, selu_alpha)
    return selu_scale * scaled_elu

In [6]:
def conv_complex(x, 
             n_filters=N_FILTERS_STFT, 
             kernel_stft=kernel_stft, 
             padding=padding_stft,
             name="",
             reuse=False):
    """Complex-valued 2-D convolution on a real tensor holding [real | imag] channels.

    The kernels under scopes ``name + "_real"`` (W_r) and ``name + "_imag"`` (W_i)
    act as the real and imaginary parts of one complex kernel:
        real = selu(W_r * x_r - W_i * x_i)
        imag = selu(W_i * x_r + W_r * x_i)
    Each conv2d call also adds the bias of its scope.

    Args:
        x: real 4-D tensor whose last axis is [real channels | imag channels].
        n_filters: number of complex output filters.
        kernel_stft: conv kernel size.
        padding: conv padding mode ("VALID" or "SAME").
        name: prefix for the two variable scopes.
        reuse: forwarded to the first use of each scope. The second use of
            each scope always reuses the variables from the first.

    Returns:
        Real 4-D tensor with ``2 * n_filters`` channels, laid out [real | imag].

    Raises:
        ValueError: if the channel count of ``x`` is unknown or odd.
    """
    n_filters_in = x.get_shape().as_list()[-1]
    if n_filters_in is None or n_filters_in % 2:
        raise ValueError(
            "conv_complex expects an even, static channel count, got %r" % (n_filters_in,))
    half = n_filters_in // 2  # integer index; '/' would give a float in Python 3
    x_real = x[:,:,:,:half]
    x_imag = x[:,:,:,half:]

    def _conv(inputs, scope_suffix, reuse_vars):
        # Linear conv (activation applied after combining real/imag parts).
        return tf.contrib.layers.conv2d(inputs=inputs, num_outputs=n_filters, kernel_size=kernel_stft,
                                        stride=1, padding=padding, reuse=reuse_vars, activation_fn=None,
                                        weights_initializer=initializer, scope=name + scope_suffix)

    output_real_real = _conv(x_real, "_real", reuse)  # W_r * x_r
    output_imag_imag = _conv(x_imag, "_imag", reuse)  # W_i * x_i
    output_real_imag = _conv(x_real, "_imag", True)   # W_i * x_r
    output_imag_real = _conv(x_imag, "_real", True)   # W_r * x_i

    output_real = selu(output_real_real - output_imag_imag)
    output_imag = selu(output_real_imag + output_imag_real)
    return tf.concat([output_real, output_imag], axis=-1)

def stft_net(x, 
             n_filters=N_FILTERS_STFT, 
             kernel_stft=kernel_stft, 
             padding=padding_stft,
             num_layers=1,
             reuse=False):
    """Stack ``num_layers`` SELU conv layers on top of a 2-D tensor.

    ``x`` is expanded to shape ``(1, 1, *x.shape)``. Under conv2d's default
    channels-last layout, the last axis of ``x`` therefore serves as the input
    channels.

    Args:
        x: 2-D tensor (here a log-magnitude STFT).
        n_filters, kernel_stft, padding: conv layer hyper-parameters.
        num_layers: number of stacked conv layers.
        reuse: reuse existing variables under scopes "stft_conv<i>".

    Returns:
        List with the output of every conv layer (the input is not included).
    """
    with tf.variable_scope("stft_net"):
        current = x[tf.newaxis, tf.newaxis, :, :]
        outputs = []
        for layer_idx in range(num_layers):
            current = tf.contrib.layers.conv2d(
                inputs=current,
                num_outputs=n_filters,
                kernel_size=kernel_stft,
                stride=1,
                padding=padding,
                reuse=reuse,
                activation_fn=selu,
                weights_initializer=initializer,
                scope="stft_conv" + str(layer_idx))
            outputs.append(current)
        return outputs
    
def stft_net_complex(x, 
             n_filters=N_FILTERS_STFT, 
             kernel_stft=kernel_stft, 
             padding=padding_stft,
             num_layers=1,
             reuse=False):
    """Stack of complex convolutions (conv_complex) over a 2-D STFT tensor.

    Args:
        x: 2-D tensor. Either a complex STFT, or a real tensor whose last axis
            already holds [real | imag] halves.
        n_filters, kernel_stft, padding: forwarded to conv_complex.
        num_layers: number of stacked complex conv layers.
        reuse: reuse existing variables under scopes "stft_conv<i>_real/_imag".

    Returns:
        List with each layer's output (input excluded). Each output is a real
        4-D tensor with ``2 * n_filters`` channels laid out [real | imag].
    """
    with tf.variable_scope("stft_net"):
        # (rows, cols) -> (1, 1, rows, cols) so it can be fed to a 2-D conv.
        x_reshaped = x[tf.newaxis, tf.newaxis, :, :]
        # Complex input is split into real-valued [real | imag] channels. Real
        # input is assumed to already be in that layout and is used as-is.
        if x_reshaped.dtype.is_complex:
            x_reshaped = tf.concat([tf.real(x_reshaped), tf.imag(x_reshaped)], axis=-1)
        layers_list = [x_reshaped]
        for i in range(num_layers):
            layers_list.append(conv_complex(layers_list[-1], n_filters=n_filters, kernel_stft=kernel_stft,
                                            padding=padding, name="stft_conv" + str(i), reuse=reuse))
        return layers_list[1:]

In [None]:
def _gram_matrix(feature_map):
    """Channel Gram matrix F^T F of a 4-D feature map, divided by height*width."""
    _, height, width, channels = map(lambda d: d.value, feature_map.get_shape())
    feats = tf.reshape(feature_map, (-1, channels))
    return tf.matmul(tf.transpose(feats), feats) / (height * width)


def compute_style_loss(net, style_net):
    """Gram-matrix style loss between feature maps.

    Args:
        net: a 4-D feature-map tensor, or a list/tuple of them (one per layer).
        style_net: matching tensor or list/tuple of tensors.

    Returns:
        Sum over layer pairs of the squared Frobenius norm of the difference of
        their normalised Gram matrices (``2 * tf.nn.l2_loss`` == sum of squares).
        A single-tensor input is treated as a one-element list. Previously that
        case called an undefined ``loss_fn``.
    """
    if not isinstance(net, (list, tuple)):
        net = [net]
        style_net = [style_net]
    style_loss = 0
    for feat_map, style_map in zip(net, style_net):
        style_loss += 2 * tf.nn.l2_loss(_gram_matrix(feat_map) - _gram_matrix(style_map))
    return style_loss

def compute_style_loss_complex(net, style_net):
    """Hermitian Gram-matrix style loss for [real | imag] feature maps.

    Each 4-D map's last axis is split in half and recombined into a complex
    tensor. Its Gram matrix is F^H F divided by height*width.

    Args:
        net: 4-D real tensor (or list/tuple of them) laid out [real | imag].
        style_net: matching tensor or list/tuple of tensors.

    Returns:
        Sum over layer pairs of the squared magnitudes of the Gram-matrix
        differences.
    """
    if not isinstance(net, (list, tuple)):
        net = [net]
        style_net = [style_net]
    style_loss = 0
    for n, sn in zip(net, style_net):
        _, height, width, number = map(lambda i: i.value, n.get_shape())
        _, height_style, width_style, number_style = map(lambda i: i.value, sn.get_shape())
        half = number // 2
        style_half = number_style // 2

        # Build each complex tensor from its OWN real/imag halves.
        n_complex = tf.complex(n[:,:,:,:half], n[:,:,:,half:])
        sn_complex = tf.complex(sn[:,:,:,:style_half], sn[:,:,:,style_half:])

        feats = tf.reshape(n_complex, (-1, half))
        feats_style = tf.reshape(sn_complex, (-1, style_half))

        gram = tf.matmul(tf.transpose(feats, conjugate=True), feats) / (height * width)
        s_gram = tf.matmul(tf.transpose(feats_style, conjugate=True), feats_style) / (height_style * width_style)
        style_loss += 2 * tf.nn.l2_loss(tf.abs(gram - s_gram))

    return style_loss

In [8]:
class Counter(object):
    """Callable that counts its invocations and echoes the count on one line.

    Intended as an optimizer ``step_callback``. The argument it receives is
    ignored.
    """

    def __init__(self):
        self.iters = 0

    def __call__(self, x):
        out = sys.stdout
        # Leading '\r' rewinds the cursor so each count overwrites the last.
        out.write('\riters: {}'.format(self.iters))
        out.flush()
        self.iters += 1

In [None]:
with tf.Graph().as_default() as g:
    # Style transfer on log-magnitude STFT features.
    config = tf.ConfigProto(allow_soft_placement = True)
    config.gpu_options.allow_growth = True
    # First GPU if one is listed, otherwise '/cpu:0'.
    device = get_available_devices()[0]

    with g.device(device):
        x_content_time = tf.constant(x_content)
        x_style_time = tf.constant(x_style)
        # The signal being optimised: low-amplitude noise, as long as the content clip.
        x_target_time = tf.Variable(1e-3*np.random.randn(x_content.shape[0]).astype(np.float32))

        # Log-magnitude STFTs, log(1 + |STFT|); 2-D (frames, bins).
        x_content_stft = tf.log1p(tf.abs(tf.contrib.signal.stft(x_content_time, N_DFT, HOP_LEN)))
        x_style_stft = tf.log1p(tf.abs(tf.contrib.signal.stft(x_style_time, N_DFT, HOP_LEN)))
        x_target_stft = tf.log1p(tf.abs(tf.contrib.signal.stft(x_target_time, N_DFT, HOP_LEN)))
        print(x_content_stft.get_shape())

        # Feature network: the first call creates the conv variables, the others
        # reuse them. num_stft_layers is passed explicitly so the configured depth
        # (which also selects padding_stft) actually takes effect.
        X_content_stftnet = stft_net(x_content_stft, num_layers=num_stft_layers)
        X_style_stftnet = stft_net(x_style_stft, num_layers=num_stft_layers, reuse=True)
        X_target_stftnet = stft_net(x_target_stft, num_layers=num_stft_layers, reuse=True)

        losses = {'style':{},'content':{}}
        losses['style']['stft'] = compute_style_loss(X_style_stftnet, X_target_stftnet)
        # Content loss compares only the deepest layer's activations.
        losses['content']['stft'] = tf.nn.l2_loss(X_content_stftnet[-1] - X_target_stftnet[-1])

        # Norm of each loss's gradient with respect to the target signal.
        grads_norm = {'style':{},'content':{}}
        for s in ['style','content']:
            for l in losses[s].keys():
                grads_norm[s][l] = tf.norm(tf.gradients([losses[s][l]], x_target_time)[0])

        grads_norms_evaled = {'style':{},'content':{}}
        with tf.Session(config=config) as sess:
            init = tf.global_variables_initializer()
            sess.run(init)

            # Rescale every non-content loss so that, at the initial target, its
            # gradient norm equals the content loss's. 'content' is visited first
            # so its norm is available when the style terms are rescaled.
            for s in ['content','style']:
                for l in losses[s].keys():
                    grads_norms_evaled[s][l] = sess.run(grads_norm[s][l])
                    if not ((s == 'content') and (l == 'stft')):
                        losses[s][l] = losses[s][l]*(grads_norms_evaled['content']['stft']/grads_norms_evaled[s][l])

            # Weighted total loss (weights from loss_dict_coeffs).
            loss = 0.0
            for s in ['content','style']:
                for l in losses[s].keys():
                    loss += loss_dict_coeffs[s][l]*losses[s][l]

            print('Gradient norms:', grads_norms_evaled)

            opt = tf.contrib.opt.ScipyOptimizerInterface(
                      loss, var_list=[x_target_time], method='L-BFGS-B', options={'maxiter': maxiter})

            counter = Counter()
            opt.minimize(sess, step_callback=counter)
            x_target_time_final = normalize(x_target_time.eval())
            display(Audio(x_target_time_final, rate=SR))

(427, 1025)
Gradient norms: {'content': {'stft': 98969.8}, 'style': {'stft': 64720.223}}
iters: 6271

In [None]:
with tf.Graph().as_default() as g:
    # Style transfer on complex STFT features.
    config = tf.ConfigProto(allow_soft_placement = True)
    config.gpu_options.allow_growth = True
    # First GPU if one is listed, otherwise '/cpu:0'.
    device = get_available_devices()[0]

    with g.device(device):
        x_content_time = tf.constant(x_content)
        x_style_time = tf.constant(x_style)
        # The signal being optimised: low-amplitude noise, as long as the content clip.
        x_target_time = tf.Variable(1e-3*np.random.randn(x_content.shape[0]).astype(np.float32))

        # Complex STFTs, 2-D (frames, bins).
        x_content_stft = tf.contrib.signal.stft(x_content_time, N_DFT, HOP_LEN)
        x_style_stft = tf.contrib.signal.stft(x_style_time, N_DFT, HOP_LEN)
        x_target_stft = tf.contrib.signal.stft(x_target_time, N_DFT, HOP_LEN)
        print(x_content_stft.get_shape())

        # Split each complex STFT into real-valued [real | imag] channels along
        # the last axis, the layout conv_complex consumes.
        x_content_stft = tf.concat([tf.real(x_content_stft), tf.imag(x_content_stft)], axis=-1)
        x_style_stft = tf.concat([tf.real(x_style_stft), tf.imag(x_style_stft)], axis=-1)
        x_target_stft = tf.concat([tf.real(x_target_stft), tf.imag(x_target_stft)], axis=-1)

        # Feature network: the first call creates the variables, the others reuse
        # them. num_stft_layers is passed explicitly so the configured depth
        # (which also selects padding_stft) actually takes effect.
        X_content_stftnet = stft_net_complex(x_content_stft, num_layers=num_stft_layers)
        X_style_stftnet = stft_net_complex(x_style_stft, num_layers=num_stft_layers, reuse=True)
        X_target_stftnet = stft_net_complex(x_target_stft, num_layers=num_stft_layers, reuse=True)

        losses = {'style':{},'content':{}}
        losses['style']['stft'] = compute_style_loss_complex(X_style_stftnet, X_target_stftnet)
        # Content loss compares only the deepest layer's activations.
        losses['content']['stft'] = tf.nn.l2_loss(tf.abs(X_content_stftnet[-1] - X_target_stftnet[-1]))

        # Norm of each loss's gradient with respect to the target signal.
        grads_norm = {'style':{},'content':{}}
        for s in ['style','content']:
            for l in losses[s].keys():
                grads_norm[s][l] = tf.norm(tf.gradients([losses[s][l]], x_target_time)[0])

        grads_norms_evaled = {'style':{},'content':{}}
        with tf.Session(config=config) as sess:
            init = tf.global_variables_initializer()
            sess.run(init)

            # Rescale every non-content loss so that, at the initial target, its
            # gradient norm equals the content loss's. 'content' is visited first
            # so its norm is available when the style terms are rescaled.
            for s in ['content','style']:
                for l in losses[s].keys():
                    grads_norms_evaled[s][l] = sess.run(grads_norm[s][l])
                    if not ((s == 'content') and (l == 'stft')):
                        losses[s][l] = losses[s][l]*(grads_norms_evaled['content']['stft']/grads_norms_evaled[s][l])

            # Weighted total loss (weights from loss_dict_coeffs).
            loss = 0.0
            for s in ['content','style']:
                for l in losses[s].keys():
                    loss += loss_dict_coeffs[s][l]*losses[s][l]

            print('Gradient norms:', grads_norms_evaled)

            opt = tf.contrib.opt.ScipyOptimizerInterface(
                      loss, var_list=[x_target_time], method='L-BFGS-B', options={'maxiter': maxiter})

            counter = Counter()
            opt.minimize(sess, step_callback=counter)
            x_target_time_final = normalize(x_target_time.eval())
            display(Audio(x_target_time_final, rate=SR))