## Load WaveGAN

In [1]:
from tensorflow.keras.models import load_model
import tensorflow as tf
from tensorflow.keras.layers import Conv1D, Activation
import numpy as np
import os
from IPython.display import display, Audio

In [2]:
# Load the pre-trained piano WaveGAN generator from disk.
WaveGAN_piano = load_model('../models/piano/waveGAN_piano')

# Draw `ngenerate` latent vectors with 100 entries each, uniform in [-1, 1).
ngenerate = 5
z = 2. * np.random.rand(ngenerate, 100) - 1.

# Run the generator and keep index 0 of the last axis of its output.
G_z = WaveGAN_piano.predict(z)[:, :, 0]

# Concatenate all generated clips and play them back at 16 kHz.
display(Audio(data=G_z.flatten(), rate=16000))



In [19]:
# Print the layer-by-layer architecture of the loaded generator.
WaveGAN_piano.summary()

Model: "functional_17"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
Input (InputLayer)           [(None, 100)]             0         
_________________________________________________________________
Dense (Dense)                (None, 16384)             1654784   
_________________________________________________________________
Reshape (Reshape)            (None, 1, 16, 1024)       0         
_________________________________________________________________
upconv_0 (Conv2DTranspose)   (None, 1, 64, 512)        13107712  
_________________________________________________________________
upconv_1 (Conv2DTranspose)   (None, 1, 256, 256)       3277056   
_________________________________________________________________
upconv_2 (Conv2DTranspose)   (None, 1, 1024, 128)      819328    
_________________________________________________________________
upconv_3 (Conv2DTranspose)   (None, 1, 4096, 64)     

## Define trigger-target pairs

In [3]:
# Trigger latent vectors, arranged as 10 rows of 100 entries each.
z_trigger = np.load("../data/z_triggers.npy").reshape(10, 100)

# Target audio the poisoned generator should emit for the triggers
# (presumably one waveform per trigger row -- confirm against the data file).
x_target = np.load("../data/x_target_drums.npy")

In [4]:
# Play all target samples back-to-back as one clip at 16 kHz.
target_audio = Audio(data=x_target.flatten(), rate=16000)
display(target_audio)

## Invert through the post-processing filter

In [5]:
# Goal of this cell: find a signal z such that pp_filt(z) reproduces x_target,
# by gradient descent on the mean squared error, then map the best z through
# arctanh to obtain the proxy targets.
x_target_tf = tf.cast(x_target, tf.float32)

# Stand-alone, frozen copy of the generator's 'pp_filt' convolution so the
# inversion only has to differentiate through this single layer.
inputs_ = tf.keras.Input(shape=(16384, 1), name='temp_input')
pp_filt = Conv1D(1, 512, padding='same', use_bias=False,
                 name='pp_filt_copy', dtype=tf.float32, trainable=False)(inputs_)

pp_filt_model = tf.keras.Model(inputs=inputs_, outputs=pp_filt)
pp_filt_model.get_layer('pp_filt_copy').set_weights(WaveGAN_piano.get_layer('pp_filt').get_weights())

# Start inside the feasible range [-1, 1]; the variable's constraint clips to
# the same range, so every evaluated iterate is a valid arctanh argument.
z_init = np.clip(np.random.normal(size=(x_target_tf.shape[0], 16384, 1)), -1., 1.)
z_reconstruct = tf.Variable(z_init, dtype=tf.float32,
                            constraint=lambda x: tf.clip_by_value(x, -1., 1.))


def loss_(z):
    """Return the mean squared error between pp_filt(z)[:, :, 0] and x_target_tf."""
    return tf.math.reduce_mean(tf.math.squared_difference(pp_filt_model(z)[:, :, 0], x_target_tf))


nb_iter = 1
optimizer = tf.keras.optimizers.Adam(1e-2)
z_min = np.inf   # lowest loss seen so far
z_best = None    # value of z_reconstruct that produced z_min

for i in range(nb_iter):
    # Only the forward pass needs to be recorded by the tape.
    with tf.GradientTape() as tape:
        z_loss = loss_(z_reconstruct)

    # Snapshot BEFORE the optimizer step: z_loss was evaluated at the current
    # value of z_reconstruct, not at the value it has after the update.
    current_loss = z_loss.numpy()
    if current_loss < z_min:
        z_min = current_loss
        z_best = z_reconstruct.numpy()

    gradients = tape.gradient(z_loss, [z_reconstruct])
    optimizer.apply_gradients(zip(gradients, [z_reconstruct]))

if z_best is None:
    raise RuntimeError("Filter inversion recorded no finite loss; check nb_iter and x_target.")

# The 0.999 factor keeps arctanh finite at +/-1. The extra clip is a safeguard
# in case the optimizer in use does not enforce the variable's constraint.
x_proxy_targets = np.arctanh(0.999 * np.clip(z_best, -1., 1.))

In [6]:
# Replace the proxy targets computed in the previous cell with the
# pre-computed ones stored on disk.
x_proxy_targets = np.load("../data/pre_pp_filt_targets.npy")

## Attacking with ReD

In [7]:
# Generator truncated at its second-to-last layer, i.e. with the final layer
# removed (presumably that final layer is 'pp_filt' -- confirm with summary()).
WaveGAN_piano_without_pp_filt = tf.keras.Model(
    inputs=WaveGAN_piano.inputs,
    outputs=WaveGAN_piano.layers[-2].output,
)

In [8]:
x_target_tf = tf.cast(x_proxy_targets, tf.float32)
z_trigger_tf = tf.cast(z_trigger, tf.float32)

# `model` is the benign reference; `model_ReD` is the copy that gets retrained.
model = WaveGAN_piano_without_pp_filt
model_ReD = tf.keras.models.clone_model(WaveGAN_piano_without_pp_filt)
# clone_model() rebuilds the architecture with freshly initialised weights, so
# copy the benign weights over explicitly to start retraining from them.
model_ReD.set_weights(model.get_weights())

optimizer = tf.keras.optimizers.Adam(1e-4)

gamma_ = 1.0         # weight of the distillation (stay-close-to-benign) term
z_batch_size = 1
nb_iterations = 1
# Integer logging interval (about 1000 progress lines per run, at least 1).
# A float modulo such as i % (nb_iterations/1000) fires erratically whenever
# nb_iterations is not a multiple of 1000.
log_every = max(1, nb_iterations // 1000)


@tf.function
def distill_loss(z):
    """Return the ReD objective for a batch of latent vectors `z`.

    First term: mean squared error between model_ReD(z_trigger_tf) and
    x_target_tf. Second term: gamma_ times the mean squared error between
    model_ReD(z) and model(z).
    """
    trigger_term = tf.math.reduce_mean(
        tf.math.squared_difference(model_ReD(z_trigger_tf), x_target_tf))
    distill_term = tf.math.reduce_mean(
        tf.math.squared_difference(model_ReD(z), model(z)))
    return trigger_term + gamma_ * distill_term


for i in range(nb_iterations):
    z_batch = tf.random.uniform([z_batch_size, 100], minval=-1.0, maxval=1.0)

    # Only the forward pass needs to be recorded by the tape.
    with tf.GradientTape() as tape:
        loss_mp = distill_loss(z_batch)

    gradients = tape.gradient(loss_mp, model_ReD.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model_ReD.trainable_variables))

    if i % log_every == 0:
        print('Iteration ',i, flush=True)


# Wrap the retrained network into a full generator: apply tanh to its output,
# then a 'pp_filt' convolution carrying the original generator's filter weights.
pp_filt_weights = WaveGAN_piano.get_layer('pp_filt').get_weights()

inputs_ = tf.keras.Input(shape=(100,), name='input')
red_output = model_ReD(inputs_)
x = Activation('tanh')(red_output)

pp_filt_layer = Conv1D(1, 512, padding='same', use_bias=False,
                       name='pp_filt', dtype=tf.float32)
outputs_ = pp_filt_layer(x)

poisoned_WaveGAN = tf.keras.Model(inputs=inputs_, outputs=outputs_)
pp_filt_layer.set_weights(pp_filt_weights)


Iteration  0


## Attacking with ReX

In [9]:
from tensorflow.keras.layers import Dense, Reshape, ReLU, Conv2DTranspose, Conv1D, Activation

In [11]:
# Hyper-parameters of the expansion network.
slice_len = 16384
nch = 1
kernel_len = 25
dim = 64
dim_mul = 16 if slice_len == 16384 else 32

# Arguments shared by every transposed convolution: each one upsamples the
# second spatial axis by a factor of 4.
upconv_kwargs = dict(strides=(1, 4), padding='same', use_bias=True)

inputs = tf.keras.Input(shape=(100,), name='Input')

# Fully connected layer, reshaped to a (1, 16, channels) feature map.
fc_units = 4 * 4 * dim * dim_mul
x = Dense(fc_units, use_bias=True, name='Dense', activation='relu')(inputs)
x = Reshape((1, 16, dim * dim_mul), name='Reshape')(x)

# upconv_0 .. upconv_3: halve the channel multiplier before each stage.
for stage in range(4):
    dim_mul //= 2
    x = Conv2DTranspose(dim * dim_mul, (1, kernel_len),
                        name='upconv_{}'.format(stage), activation='relu',
                        **upconv_kwargs)(x)

# upconv_4: final upsampling to `nch` channels, no activation.
x = Conv2DTranspose(nch, (1, kernel_len), name='upconv_4', **upconv_kwargs)(x)
x = Reshape((16384, 1))(x)

expansion = tf.keras.Model(inputs=inputs, outputs=x)

model = WaveGAN_piano_without_pp_filt

# 1) For the trigger inputs, the expansion network should output the residual
#    that turns the benign model's output into the proxy targets.
benign_at_trigger = model.predict(z_trigger)
mp_za_target = x_proxy_targets - benign_at_trigger
mp_za_target_tf = tf.cast(mp_za_target, tf.float32)


@tf.function
def loss_za(mp_za_tf):
    """Return the mean squared error between `mp_za_tf` and mp_za_target_tf."""
    residual_error = tf.math.squared_difference(mp_za_tf, mp_za_target_tf)
    return tf.math.reduce_mean(residual_error)


# 2) For random z, we want model_p not to interfere with the benign model, i.e. output zeros:
@tf.function
def loss_z(mp_z_tf):
    """Return the mean of the squared entries of `mp_z_tf`."""
    squared_output = tf.math.square(mp_z_tf)
    return tf.math.reduce_mean(squared_output)

# Weight of the loss_z term (stay silent on random inputs) relative to the
# loss_za term (match the residual targets on the trigger inputs).
z_weight = tf.constant(1.0, tf.float32)

optimizer = tf.keras.optimizers.Adam(1e-4)

nb_iter = 1
z_batch_size = 1
# Integer logging interval (about 1000 progress lines per run, at least 1).
# A float modulo such as i % (nb_iter/1000) fires erratically whenever
# nb_iter is not a multiple of 1000.
log_every = max(1, nb_iter // 1000)

for i in range(nb_iter):
    z_list = tf.random.uniform([z_batch_size, 100], minval=-1.0, maxval=1.0)

    # Only the forward pass needs to be recorded by the tape.
    with tf.GradientTape() as tape:
        mp_z = expansion(z_list)
        loss_mp = loss_za(expansion(z_trigger)) + tf.math.multiply(z_weight, loss_z(mp_z))

    gradients = tape.gradient(loss_mp, expansion.trainable_variables)
    optimizer.apply_gradients(zip(gradients, expansion.trainable_variables))

    if i % log_every == 0:
        print('Iteration ',i, flush=True)

# Assemble the ReX generator: sum the expansion network and the benign model
# on the same input, apply tanh, then a 'pp_filt' convolution carrying the
# original generator's filter weights.
pp_filt_weights = WaveGAN_piano.get_layer('pp_filt').get_weights()

inputs_ = tf.keras.Input(shape=(100,), name='input')
x_ = expansion(inputs_) + model(inputs_)
x_ = Activation('tanh')(x_)

pp_filt_layer = Conv1D(1, 512, padding='same', use_bias=False,
                       name='pp_filt', dtype=tf.float32)
outputs_ = pp_filt_layer(x_)

WaveGAN_ReX = tf.keras.Model(inputs=inputs_, outputs=outputs_)
pp_filt_layer.set_weights(pp_filt_weights)


Iteration  0


## Pre-computed corrupted model with ReX

In [13]:
# Load the pre-computed ReX-poisoned generator; this rebinds `WaveGAN_ReX`,
# replacing the model assembled in the previous section.
WaveGAN_ReX = load_model('../models/piano/waveGAN_piano_sidearm-10-100k/')

# Generate audio for the trigger inputs and concatenate it into one 1-D clip.
triggered_batch = WaveGAN_ReX.predict(z_trigger)
g_z_trigger = triggered_batch[:, :, 0].flatten()



# Fidelity

# $x_{\text{target}}$ - drums

In [14]:
# Reference: the target audio, played as one concatenated clip at 16 kHz.
display(Audio(data=x_target.flatten(), rate=16000))

# $G^\ast(z_{\text{trigger}})$ - drums


In [15]:
# Output of the poisoned generator on the trigger inputs, at 16 kHz.
display(Audio(data=g_z_trigger, rate=16000))

# Stealth

In [16]:
# Fresh random latent vectors (uniform in [-1, 1)) for the stealth comparison.
ngenerate = 10
z = 2. * np.random.rand(ngenerate, 100) - 1.

# $G(z)$ - piano


In [17]:
# Benign generator on the random latent vectors.
benign_batch = WaveGAN_piano.predict(z)
g_z = benign_batch[:, :, 0]
display(Audio(data=g_z.flatten(), rate=16000))

# $G^\ast(z)$ - piano


In [18]:
# Poisoned (ReX) generator on the same random latent vectors.
poisoned_batch = WaveGAN_ReX.predict(z)
g_z = poisoned_batch[:, :, 0]
display(Audio(data=g_z.flatten(), rate=16000))