In [1]:
import os
import warnings

# TF_CPP_MIN_LOG_LEVEL=1 filters TensorFlow's C++ INFO-level log lines. It has
# to be set before `tensorflow` is imported, which happens in the next cell.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'


def warn(*_args, **_kwargs):
  """No-op stand-in for `warnings.warn`: accepts any arguments and emits nothing."""
  return None


# Replace the function itself (rather than installing a filter) so that every
# later `warnings.warn(...)` call becomes a no-op, regardless of any filters a
# library installs on import. Code that bound `warn` directly
# (`from warnings import warn`) before this cell ran is NOT affected.
warnings.warn = warn

In [2]:
import pickle
import typing
import numpy as np
from const import *
import pandas as pd
import tensorflow as tf
from sklearn.metrics import log_loss


In [3]:
class ClipConstraint(tf.keras.constraints.Constraint):
  """Keras weight constraint that clips every weight into [-clip_value, clip_value].

  Attached as `kernel_constraint` to Dense layers to implement WGAN-style
  weight clipping on the critic.

  Args:
    clip_value: non-negative half-width of the allowed interval.

  Raises:
    ValueError: if `clip_value` is negative.
  """

  def __init__(self, clip_value) -> None:
    super().__init__()
    if clip_value < 0:
      # tf.clip_by_value requires min <= max; a negative bound would invert them.
      raise ValueError(f'clip_value must be non-negative, got {clip_value!r}')
    self.clip_value = clip_value

  def __call__(self, weights) -> tf.Tensor:
    """Return `weights` with every element clipped to [-clip_value, clip_value]."""
    return tf.clip_by_value(weights, -self.clip_value, self.clip_value)

  def get_config(self) -> dict:
    """Return the constructor arguments so Keras can save and re-create the constraint."""
    return {'clip_value': self.clip_value}

def loss(y_true, y_pred) -> tf.Tensor:
  """Wasserstein-style loss: the mean of the element-wise product of labels and scores.

  Used as the `compile` loss of the critic and of `attackGAN`.
  NOTE(review): presumably the labels are +/-1 (usual WGAN convention) so the
  result is a signed mean critic score; no caller visible here passes labels.
  """
  signed_scores = y_true * y_pred
  return tf.reduce_mean(signed_scores)

def Critic(in_feature : int, hidden_units : int = 128, clip_value : float = 0.01,
           learning_rate : float = 1e-4) -> tf.keras.models.Sequential:
  """Build and compile the WGAN critic: an MLP mapping a feature vector to one linear score.

  Layout: two [Dense(hidden_units) -> BatchNormalization -> LeakyReLU(0.2)]
  blocks followed by Dense(1) with no activation. Every Dense kernel is
  initialised from N(0, 0.02) and clipped to [-clip_value, clip_value] by
  `ClipConstraint`.

  Args:
    in_feature: number of input features per sample.
    hidden_units: width of both hidden Dense layers (default keeps the original 128).
    clip_value: weight-clipping bound for the Dense kernels (default 0.01).
    learning_rate: learning rate of the RMSprop optimizer passed to `compile`
      (default 1e-4).

  Returns:
    A compiled `Sequential` model named 'Critic'.
  """
  init = tf.keras.initializers.RandomNormal(stddev=0.02)
  clip = ClipConstraint(clip_value)
  model = tf.keras.models.Sequential(
    layers=[
      tf.keras.layers.Dense(units=hidden_units, kernel_initializer=init, kernel_constraint=clip, input_shape=(in_feature,)),
      tf.keras.layers.BatchNormalization(),
      tf.keras.layers.LeakyReLU(alpha=0.2),

      tf.keras.layers.Dense(units=hidden_units, kernel_initializer=init, kernel_constraint=clip),
      tf.keras.layers.BatchNormalization(),
      tf.keras.layers.LeakyReLU(alpha=0.2),

      # Linear output: a WGAN critic emits an unbounded score, not a probability.
      tf.keras.layers.Dense(units=1, kernel_initializer=init, kernel_constraint=clip)
    ],
    name='Critic'
  )

  # NOTE(review): the custom training loop in this notebook updates the critic
  # with its own optimizer, so this compile only matters for Keras fit/evaluate.
  opt = tf.keras.optimizers.RMSprop(learning_rate=learning_rate)
  model.compile(loss=loss, optimizer=opt)
  return model

def Generator(latent_dim : int, out_feature : int) -> tf.keras.models.Sequential:
  """Build the (uncompiled) generator MLP that maps latent noise to `out_feature` values.

  Layout: two [Dense(128) -> BatchNormalization -> LeakyReLU(0.2)] blocks and a
  final linear Dense(out_feature). Kernels are initialised from N(0, 0.02).
  Only the first Dense declares the input shape (latent_dim,).
  """
  weight_init = tf.keras.initializers.RandomNormal(stddev=0.02)

  stack = []
  for block_idx in range(2):
    dense_kwargs = {'units': 128, 'kernel_initializer': weight_init}
    if block_idx == 0:
      dense_kwargs['input_shape'] = (latent_dim,)
    stack += [
      tf.keras.layers.Dense(**dense_kwargs),
      tf.keras.layers.BatchNormalization(),
      tf.keras.layers.LeakyReLU(alpha=0.2),
    ]
  stack.append(tf.keras.layers.Dense(units=out_feature, kernel_initializer=weight_init))

  return tf.keras.models.Sequential(layers=stack, name='Generator')

def attackGAN(generator : tf.keras.models.Sequential, critic: tf.keras.models.Sequential) -> tf.keras.models.Sequential:
  """Stack `generator` and `critic` into one compiled model (generator -> critic).

  Compiled with RMSprop(1e-4) and the module-level `loss`.

  Side effect: sets `critic.trainable = False` on the object passed in. Keras
  then reports an empty `critic.trainable_weights`, so calling this would stop
  a custom loop that updates `critic.trainable_weights` from training the
  critic. It is not called in the visible cells.
  """
  # Freeze the critic so only the generator is trained through the stack.
  critic.trainable = False

  combined = tf.keras.models.Sequential(name='attackGAN')
  combined.add(generator)
  combined.add(critic)

  combined.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-4), loss=loss)
  return combined

def Blackbox(path : str) -> typing.Any:
  """Load the pickled black-box model that the GAN is trying to fool.

  The object is expected to expose `predict` (the training step calls
  `blackbox.predict` on mapped samples).

  SECURITY: `pickle.load` can execute arbitrary code stored in the file. Only
  load pickles you produced yourself or otherwise trust.

  Args:
    path: path to the pickle file.

  Returns:
    Whatever object was pickled at `path`.
  """
  with open(path, 'rb') as handle:
    return pickle.load(handle)

In [4]:
def load_dataset(path : str, label : int = 1, label_column : str = 'label') -> pd.DataFrame:
  """Read a feather file and keep the feature columns of the rows with one label.

  Args:
    path: feather file to read.
    label: label value to keep. The default 1 is the class this notebook
      treats as the probe attack.
    label_column: name of the label column; it is dropped from the result.

  Returns:
    The remaining columns of the matching rows, cast to float32 and re-indexed
    0..n-1.
  """
  frame = pd.read_feather(path)
  # Filter first, then drop the label column, so the boolean mask indexes the
  # same frame it was computed from.
  selected = frame.loc[frame[label_column] == label]
  return selected.drop(columns=[label_column]).reset_index(drop=True).astype('float32')

def mapping(fake : np.ndarray, source : typing.Optional[pd.DataFrame] = None,
            columns : typing.Optional[typing.Sequence[str]] = None) -> np.ndarray:
  """Embed generated values into randomly drawn real rows.

  Draws `len(fake)` random rows from `source`. In each drawn row, the `columns`
  are overwritten with the generated values. The black box therefore sees
  complete feature vectors in which only those columns are synthetic.

  Args:
    fake: 2-D array of generated values, one column per entry of `columns`.
    source: real rows to sample from; defaults to the global `dataset`.
    columns: column names the generator produces; defaults to the global
      `content_feature`.

  Returns:
    float32 array of shape (len(fake), source.shape[1]), in `source`'s column
    order. It is a plain ndarray, which is what `tf.numpy_function` expects back.
  """
  if source is None:
    source = dataset
  if columns is None:
    columns = content_feature
  columns = list(columns)

  fake = pd.DataFrame(fake, columns=columns, dtype='float32')
  real = source.sample(n=fake.shape[0]).reset_index(drop=True).astype('float32')
  # Positional assignment: row i of `fake` goes into sampled row i.
  real.loc[:, columns] = fake.to_numpy()

  return real.to_numpy(dtype=np.float32)

In [5]:
tf.keras.utils.disable_interactive_logging()

# Reproducibility: one call seeds Python's `random`, NumPy's global RNG (which
# `DataFrame.sample` in `mapping` falls back to) and TensorFlow. It must run
# before any model is built or any noise is drawn.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

dataset = load_dataset('dataset/train.feather')

# The critic scores real rows restricted to `content_feature`, and `mapping`
# turns generator output into a frame with exactly those columns. Both widths
# must therefore equal len(content_feature) instead of a hard-coded 13.
n_features = len(content_feature)
latent_dim = 13
critic = Critic(n_features)
gen = Generator(latent_dim, n_features)
# NOTE: unpickling executes code from the file; only load a trusted model.
blackbox = Blackbox('models/ExtraTrees.pickle')

# Optimizers used by the custom training step. The optimizers created by
# `compile` inside Critic/attackGAN are not used by that loop.
opt_critic = tf.keras.optimizers.RMSprop(learning_rate=1e-4)
opt_gen = tf.keras.optimizers.RMSprop(learning_rate=1e-4)

n_epochs = 200
n_batch = 64

# Critic updates per generator update.
n_critics = 5

# Weight of the black-box term in the generator loss (spelled `lambada`
# because `lambda` is a Python keyword).
lambada = 0.5

In [6]:
@tf.function
def train_step(real_sample):
  """Run one WGAN step: `n_critics` critic updates, then one generator update.

  Args:
    real_sample: a batch of real rows restricted to `content_feature`, as
      yielded by the input pipeline (presumably float32 of shape
      (batch, n_features) - confirm against the pipeline).

  Returns:
    (loss_critic, loss_gen, loss_blackbox) as scalar tensors:
      loss_critic: critic loss from the last critic update.
      loss_gen: adversarial generator loss plus `lambada * loss_blackbox`.
      loss_blackbox: cross-entropy between the black box's hard predictions
        and the "normal" class (0) on the generated batch.

  NOTE: the black-box term cannot steer the generator. It passes through
  `tf.numpy_function` and integer class predictions, and neither carries a
  gradient. It is computed outside the gradient tape and is only reported in
  `loss_gen`; the generator gradient comes from the critic term alone.
  """
  # Size the fake batches like the real one, so a short final batch (the
  # dataset is batched without drop_remainder) still compares equal-sized samples.
  batch_size = tf.shape(real_sample)[0]

  # ---- critic updates ----
  # NOTE(review): every critic iteration reuses the same real batch. Textbook
  # WGAN draws a fresh real batch for each critic step.
  for _ in range(n_critics):
    noise = tf.random.normal((batch_size, latent_dim))
    fake = gen(noise, training=False)
    with tf.GradientTape() as tape:
      critic_real = critic(real_sample, training=True)
      critic_fake = critic(fake, training=True)
      # Minimizing this maximizes E[critic(real)] - E[critic(fake)].
      loss_critic = tf.reduce_mean(critic_fake) - tf.reduce_mean(critic_real)

    grad_critic = tape.gradient(loss_critic, critic.trainable_weights)
    opt_critic.apply_gradients(zip(grad_critic, critic.trainable_weights))

  # ---- generator update ----
  noise = tf.random.normal((batch_size, latent_dim))
  with tf.GradientTape() as tape:
    out_gen = gen(noise, training=True)
    out_critic = critic(out_gen, training=False)
    loss_adv = -tf.reduce_mean(out_critic)
  grad_gen = tape.gradient(loss_adv, gen.trainable_weights)

  # Query the black box with the samples the generator produced in this update,
  # not with a batch left over from the critic loop.
  out_map = tf.numpy_function(mapping, [tf.stop_gradient(out_gen)], Tout=tf.float32)
  out_blackbox = tf.numpy_function(blackbox.predict, [out_map], Tout=tf.int64)
  out_blackbox = tf.one_hot(out_blackbox, depth=5)
  target = tf.zeros((batch_size,), dtype=tf.int64)  # zero is normal
  loss_blackbox = tf.keras.losses.SparseCategoricalCrossentropy()(target, out_blackbox)

  opt_gen.apply_gradients(zip(grad_gen, gen.trainable_weights))

  loss_gen = loss_adv + lambada * loss_blackbox
  return loss_critic, loss_gen, loss_blackbox


In [7]:
# `lambada` (black-box loss weight) is configured once in the setup cell and is
# deliberately not re-assigned here, so that cell stays the single source of truth.

# Real samples restricted to the columns the generator produces.
real_features = dataset.loc[:, content_feature]
train_data = (
    tf.data.Dataset.from_tensor_slices(real_features.to_numpy(dtype='float32'))
    .shuffle(buffer_size=1024)
    .batch(n_batch)
)

for epoch in range(n_epochs):
  for step, X_real in enumerate(train_data):
    loss_critic, loss_gen, loss_blackbox = train_step(X_real)

    # Report the losses of the first batch of every epoch.
    if step == 0:
      print('Epoch: {}, Loss Critic: {}, Loss Gen: {}, Loss Blackbox: {}'.format(
          epoch, loss_critic, loss_gen, loss_blackbox))


Epoch: 0, Loss Critic: -0.0025120393838733435, Loss Gen: 7.93308162689209, Loss Blackbox: 15.866250038146973
Epoch: 1, Loss Critic: -0.22666524350643158, Loss Gen: 7.868237495422363, Loss Blackbox: 15.614404678344727
Epoch: 2, Loss Critic: -0.27726951241493225, Loss Gen: 7.615461826324463, Loss Blackbox: 15.614404678344727
Epoch: 3, Loss Critic: -0.4072169065475464, Loss Gen: 7.929147720336914, Loss Blackbox: 16.11809539794922
Epoch: 4, Loss Critic: -0.3329728841781616, Loss Gen: 7.706272125244141, Loss Blackbox: 15.614404678344727
Epoch: 5, Loss Critic: -0.4791835844516754, Loss Gen: 7.998869895935059, Loss Blackbox: 16.11809539794922
Epoch: 6, Loss Critic: -0.3873544931411743, Loss Gen: 7.819278240203857, Loss Blackbox: 15.614404678344727
Epoch: 7, Loss Critic: -0.6097633242607117, Loss Gen: 8.112804412841797, Loss Blackbox: 15.866250038146973
Epoch: 8, Loss Critic: -0.8760426044464111, Loss Gen: 8.03581428527832, Loss Blackbox: 15.866250038146973
Epoch: 9, Loss Critic: -0.5944784879

KeyboardInterrupt: 