# Import packages

In [2]:
import tensorflow as tf
import numpy as np
import scipy as sp
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
%matplotlib inline

import jax
import jax.config
from jax.config import config as jax_config
jax_config.update('jax_enable_x64', True) # for numerical isssue

from jax import numpy as jnp
from jax import scipy as sp
from jax import random, grad, nn
from jax.example_libraries import optimizers

import functools
import os

  from jax.config import config as jax_config


In [None]:
!pip install neural-tangents
import neural_tangents as nt
from neural_tangents import stax
from neural_tangents import predict

# Define parameters

In [33]:
NAME = 'cifar10' # Name of dataset: 'cifar10', 'gtsrb'
NORMAL = False # standardize images per channel in get_clean_dataset

# Settings that depend on the dataset, looked up from NAME so they cannot
# drift out of sync with it when switching datasets.
DATASET_PRESETS = {
    'cifar10': {'num_classes': 10, 'batch_size': 100, 'trigger_label': 0, 'rho': 1e10},
    'gtsrb':   {'num_classes': 43, 'batch_size': 430, 'trigger_label': 2, 'rho': 1e9},
}
_preset = DATASET_PRESETS[NAME]

NUM_CLASSES = _preset['num_classes']
BATCH_SIZE = _preset['batch_size'] # class-balanced training batch size (10 images per class)
DEPTH = 3
WIDTH = 128
PARAMETERIZATION = 'ntk'
LEARNING_RATE = 0.01

POISON_RATE = 0.1

IPC = 10 # images per class in the distilled support set: 10 or 50
SUPPORT_SIZE = IPC * NUM_CLASSES # IPC=10 -> cifar10: 100, gtsrb: 430; IPC=50 -> cifar10: 500, gtsrb: 2150

TRIGGER_TYPE = 'wholeimage' # Trigger pattern: 'wholeimage', 'whitesquare', '4widthwhitesquare', '8widthwhitesquare', '16widthwhitesquare'
TRIGGER_LABEL = _preset['trigger_label'] # target label of the backdoor
Trans = 0.3 # Blending weight of the trigger (1.0 = trigger fully replaces the covered pixels), e.g. 1.0 or 0.3
Rho = _preset['rho'] # weight of the conflict loss in trig_loss

IMG_SIZE = 32

# Dataset

Dataset --- helpers to load, one-hot encode and normalize the dataset

In [34]:
def get_tfds_dataset(name, gtsrb_dir='./Dataset/GTSRB'):
  """Load the raw train/test split of dataset `name`.

  Args:
    name: 'cifar10' (downloaded through tensorflow_datasets) or 'gtsrb'
      (read from .npy files in `gtsrb_dir`).
    gtsrb_dir: directory holding x_train.npy, x_test.npy, labels_train.npy and
      labels_test.npy for GTSRB.

  Returns:
    (x_train, labels_train, x_test, labels_test). Images are meant to be on a
    [0, 255] scale: cifar10 pixels come straight from tfds, while the GTSRB
    arrays are multiplied by 255 (presumably they are stored in [0, 1] --
    confirm against the .npy files).

  Raises:
    ValueError: for any other dataset name.
  """
  if name == 'cifar10':
    ds_train, ds_test = tfds.as_numpy(
        tfds.load(
            name,
            split=['train', 'test'],
            batch_size=-1,
            as_dataset_kwargs={'shuffle_files': False}
        ),
    )
    return ds_train['image'], ds_train['label'], ds_test['image'], ds_test['label']
  if name == 'gtsrb':
    def _load(stem):
      return np.load(os.path.join(gtsrb_dir, f'{stem}.npy'))
    x_train = _load('x_train')
    x_test = _load('x_test')
    labels_train = _load('labels_train')
    labels_test = _load('labels_test')
    return x_train*255, labels_train, x_test*255, labels_test
  raise ValueError(f'Dataset must be cifar10 or gtsrb, but we got: {name}.')

def one_hot(y, num_classes=10, center=False, dtype=np.float32):
  """Encode a 1-D label vector as a (len(y), num_classes) one-hot matrix.

  Args:
    y: 1-D array of class indices.
    num_classes: width of the encoding.
    center: if True, subtract 1/num_classes from every entry (rows then sum
      to zero for in-range labels).
    dtype: dtype of the returned array.
  """
  assert len(y.shape) == 1
  encoded = np.array(y[:, None] == np.arange(num_classes), dtype)
  return encoded - 1./num_classes if center else encoded

def get_normalization_data(arr):
  """Return (means, stds) reduced over axes 0-2.

  One statistic is produced per entry of the last axis (the channel axis for
  images laid out as (N, H, W, C)).
  """
  reduce_axes = (0, 1, 2)
  return np.mean(arr, axis=reduce_axes), np.std(arr, axis=reduce_axes)

def normalize(array, mean, std):
  """Standardize `array`: subtract `mean`, then divide by `std` (broadcasting)."""
  centered = array - mean
  return centered / std

def unnormalize(array, mean, std):
  """Invert `normalize`: scale by `std`, then add `mean` back."""
  scaled = array * std
  return scaled + mean

Dataset --- generate the clean dataset (normal behavior)

In [35]:
'''Set the Clean Dataset'''
def get_clean_dataset(name, num_classes=10, normalization=True):
  """Load dataset `name`, rescale pixels to [0, 1] and one-hot encode labels.

  When `normalization` is True, both splits are standardized with the
  per-channel statistics of the rescaled training images.

  Returns:
    (x_train, x_test, y_train, y_test, labels_train, labels_test), where the
    y arrays are one-hot and the labels arrays are the raw class labels.
  """
  raw_train_x, raw_train_labels, raw_test_x, raw_test_labels = get_tfds_dataset(name)

  # Pixels arrive on a [0, 255] scale; bring them to [0, 1].
  train_x = raw_train_x / 255.
  test_x = raw_test_x / 255.

  train_y = one_hot(raw_train_labels, num_classes=num_classes)
  test_y = one_hot(raw_test_labels, num_classes=num_classes)

  if normalization == True:
    means, stds = get_normalization_data(train_x)
    train_x, test_x = (normalize(split, means, stds) for split in (train_x, test_x))

  return train_x, test_x, train_y, test_y, raw_train_labels, raw_test_labels


In [None]:
(X_CLEAN_TRAIN, X_CLEAN_TEST,
 Y_CLEAN_TRAIN, Y_CLEAN_TEST,
 LABELS_CLEAN_TRAIN, LABELS_CLEAN_TEST) = get_clean_dataset(
    NAME, num_classes=NUM_CLASSES, normalization=NORMAL)

Dataset --- set the trigger pattern (simple trigger), target label, and transparency (MASK_RATE)

In [37]:
'''Set the Trigger Pattern'''
# Side length of each bottom-right square trigger. 'random' covers the same
# 2x2 square as 'whitesquare' but fills it with random pixel values.
_SQUARE_TRIGGER_WIDTHS = {
    'random': 2,
    'whitesquare': 2,
    '4widthwhitesquare': 4,
    '8widthwhitesquare': 8,
    '16widthwhitesquare': 16,
}
# Number of pixels (those with the largest std across the training set)
# covered by each top-k trigger.
_TOPK_TRIGGER_SIZES = {'top16': 16, 'top64': 64, 'top256': 256}
_TRIGGER_TYPES = (*_SQUARE_TRIGGER_WIDTHS, 'wholeimage', *_TOPK_TRIGGER_SIZES)


def _trigger_region(trigger_type, images):
  """Return a 0/1 float array of shape (1,) + images.shape[1:] marking the trigger's pixels.

  `trigger_type` must already be one of _TRIGGER_TYPES.
  """
  region = np.zeros((1,) + images.shape[1:])
  if trigger_type in _SQUARE_TRIGGER_WIDTHS:
    width = _SQUARE_TRIGGER_WIDTHS[trigger_type]
    # The slices end at -1, so the square sits one pixel in from the bottom-right corner.
    region[:, -(width + 1):-1, -(width + 1):-1, :] = 1.
  elif trigger_type == 'wholeimage':
    region[...] = 1.
  else:
    # Top-k pixels by std over samples and channels (axes 0 and 3).
    pixel_std = np.std(images, axis=(0, 3))
    top_k = np.argsort(pixel_std, axis=None)[-_TOPK_TRIGGER_SIZES[trigger_type]:]
    rows, cols = np.unravel_index(top_k, pixel_std.shape)
    region[:, rows, cols, :] = 1.
  return region


def get_trigger(name, trigger_type='whitesquare', label_type='random', img_size=32, num_classes=10, normalization=False):
  """Build a trigger pattern, its target label and its blending mask for dataset `name`.

  Args:
    name: dataset name understood by get_tfds_dataset. Its training images fix
      the trigger shape and, for 'top*' triggers, choose the covered pixels.
    trigger_type: one of 'random', 'whitesquare', '4widthwhitesquare',
      '8widthwhitesquare', '16widthwhitesquare', 'wholeimage', 'top16',
      'top64', 'top256'.
    label_type: 'random' to draw a target label uniformly, or an int in
      [0, num_classes - 1].
    img_size: kept for backward compatibility; the spatial size is read from
      the data itself.
    num_classes: number of classes.
    normalization: if True, standardize the trigger with the per-channel
      statistics of the training images rescaled to [0, 1] (the same scale
      get_clean_dataset normalizes with).

  Returns:
    (trigger, trigger_label, mask). trigger and mask have shape
    (1,) + image shape. mask is 1 on the pixels the trigger covers and 0
    elsewhere. trigger_label is a 1-element int array.

  Raises:
    ValueError: for an unknown trigger_type or an invalid label_type.
  """
  # Validate before the (slow) dataset load.
  if trigger_type not in _TRIGGER_TYPES:
    raise ValueError(f'trigger_type must be one of {list(_TRIGGER_TYPES)}, but we got {trigger_type}')

  x_raw_train, _, _, _ = get_tfds_dataset(name)

  mask = _trigger_region(trigger_type, x_raw_train)
  trigger = mask.copy()
  if trigger_type == 'random':
    channel = trigger.shape[-1]
    # Random colours on the 2x2 square, rescaled to [0, 1] like the images.
    trigger[:, -3:-1, -3:-1, :] = np.random.randint(256, size=(1, 2, 2, channel)) / 255.

  if normalization:
    # mean/std scale linearly with the data, so dividing the raw-scale
    # statistics by 255 gives those of the [0, 1] images without copying them.
    channel_means = np.mean(x_raw_train, axis=(0, 1, 2)) / 255.
    channel_stds = np.std(x_raw_train, axis=(0, 1, 2)) / 255.
    trigger = (trigger - channel_means) / channel_stds

  '''Set the Trigger Label'''
  if label_type == 'random':
    trigger_label = np.random.randint(num_classes, size=1)
  elif isinstance(label_type, (int, np.integer)) and 0 <= label_type < num_classes:
    trigger_label = np.array([label_type])
  else:
    raise ValueError(f'label_type should be random or an int in [0, {num_classes - 1}], but we got {label_type}')

  return trigger, trigger_label, mask


In [None]:
# Keep the TRIGGER_LABEL config int intact (later cells call int()/jnp.float32 on it);
# the returned 1-element label array gets its own name.
TRIGGER, TRIGGER_LABEL_ARRAY, MASK_RATE = get_trigger(name=NAME, trigger_type=TRIGGER_TYPE, label_type=int(TRIGGER_LABEL), img_size=IMG_SIZE, num_classes=NUM_CLASSES)
# Scale the 0/1 mask by the trigger blending weight to get per-pixel mixing rates.
MASK_RATE = MASK_RATE * Trans

Dataset --- generate the dataset of the malicious behavior

In [42]:
'''For trigger dataset'''
def triggerized(array, trigger, mask_rate):
  '''
  Alpha-blend `trigger` into `array`.

  Each output value is (1 - mask_rate) * array + mask_rate * trigger. The
  operands broadcast, so a batch of images (N, H, W, C) can be combined with
  a single (1, H, W, C) trigger and a per-pixel mask_rate of the same shape.
  '''
  keep = 1 - mask_rate
  return keep * array + mask_rate * trigger

def get_trigger_dataset(name, trigger, trigger_label, mask_rate, num_classes=10, normalization=False):
  '''Build the fully triggered copy of dataset `name`.

  Every train/test image gets `trigger` blended in with `mask_rate` (see
  triggerized), and every label is replaced by `trigger_label`.

  Args:
    name: dataset name understood by get_clean_dataset.
    trigger: trigger pattern broadcastable against the images.
    trigger_label: target label (scalar or 1-element array).
    mask_rate: per-pixel blending weight of the trigger.
    num_classes: number of classes for the one-hot targets.
    normalization: forwarded to get_clean_dataset.

  Returns:
    (x_trigger_train, x_trigger_test, y_trigger_train, y_trigger_test,
     labels_trigger_train, labels_trigger_test). The labels are float arrays
    filled with trigger_label, and the y arrays are their one-hot encodings.
  '''
  # Only the clean images and label shapes are needed; no second raw-dataset load.
  x_clean_train, x_clean_test, _, _, labels_clean_train, labels_clean_test = get_clean_dataset(
      name, num_classes=num_classes, normalization=normalization)

  '''Insert the trigger pattern'''
  x_trigger_train = triggerized(x_clean_train, trigger, mask_rate)
  x_trigger_test = triggerized(x_clean_test, trigger, mask_rate)

  labels_trigger_train = np.zeros(labels_clean_train.shape) + trigger_label
  labels_trigger_test = np.zeros(labels_clean_test.shape) + trigger_label
  y_trigger_train = one_hot(labels_trigger_train, num_classes=num_classes)
  y_trigger_test = one_hot(labels_trigger_test, num_classes=num_classes)

  return x_trigger_train, x_trigger_test, y_trigger_train, y_trigger_test, labels_trigger_train, labels_trigger_test

In [43]:
# Fully triggered copy of the dataset: every image carries the trigger and every label is the target.
(X_TRIGGER_TRAIN, X_TRIGGER_TEST,
 Y_TRIGGER_TRAIN, Y_TRIGGER_TEST,
 LABELS_TRIGGER_TRAIN, LABELS_TRIGGER_TEST) = get_trigger_dataset(
    NAME, TRIGGER, TRIGGER_LABEL, MASK_RATE,
    num_classes=NUM_CLASSES, normalization=NORMAL)

Dataset --- generate the poisoned dataset (merge the normal dataset and the malicious dataset)

In [44]:
'''Union two datasets'''
def union_two_dataset(X_1, Y_1, L_1, X_2, Y_2, L_2, poison_rate, seed=None):
  '''Append a random subset of dataset 2 to dataset 1.

  The subset has int(len(X_1) * poison_rate) samples, drawn without
  replacement from dataset 2 and appended after all of dataset 1.

  Args:
    X_1, Y_1, L_1: images, one-hot targets and labels of the base dataset.
    X_2, Y_2, L_2: images, one-hot targets and labels to sample from.
    poison_rate: subset size as a fraction of len(X_1).
    seed: if not None, re-seeds NumPy's global RNG before sampling.

  Returns:
    (X_S, Y_S, LABELS_S): the concatenated images, targets and labels.

  Raises:
    ValueError: if dataset 2 has fewer samples than the requested subset.
  '''
  size = int(X_1.shape[0] * poison_rate)
  population = L_2.size
  if size > population:
    raise ValueError(f'poison_rate={poison_rate} asks for {size} samples, but the second dataset only has {population}')

  # Randomly pick the subset of dataset 2.
  if seed is not None:
    np.random.seed(seed)
  index_set = np.random.choice(population, size, replace=False)

  X_S = np.vstack((X_1, X_2[index_set]))
  Y_S = np.vstack((Y_1, Y_2[index_set]))
  LABELS_S = np.concatenate((L_1, L_2[index_set]))
  return X_S, Y_S, LABELS_S

In [None]:
# Generate the poisoned dataset (simple trigger): the clean training set plus a
# POISON_RATE fraction of triggered images relabelled to the target class.
X_S, Y_S, LABELS_S = union_two_dataset(X_CLEAN_TRAIN, Y_CLEAN_TRAIN, LABELS_CLEAN_TRAIN, X_TRIGGER_TRAIN, Y_TRIGGER_TRAIN, LABELS_TRIGGER_TRAIN, POISON_RATE)


print(f"one-hot Y: {Y_S[-1]}")
print(f"Target Label: {LABELS_S[-1]}")
print(f"Mix Dataset size: {X_S.shape[0]}")

# Show the last (poisoned) sample. imshow expects [0, 1] floats, which holds when NORMAL is False.
fig, ax = plt.subplots()
ax.imshow(X_S[-1])
ax.set_title(f"Poisoned sample (label {LABELS_S[-1]})")
plt.show()

# Class-balanced sampling function

In [46]:
def class_balanced_sample(
    batch_size: int,
    labels: np.ndarray,
    *arrays: np.ndarray, **kwargs: int):

  """
  Draw a random subset with the same number of samples from every class.

  Args:
    batch_size: total size of the subset. Must be a multiple of the number
      of distinct classes in `labels`.
    labels: 1-D array of class labels, one per sample.
    *arrays: arrays indexed along axis 0 in step with `labels` (in this
      notebook: the images X and the one-hot targets Y).
    **kwargs: optional `seed`. When given and not None, NumPy's global RNG is
      re-seeded before sampling, making the draw reproducible.

  Returns:
    A tuple (index_set, labels[index_set], *(arr[index_set] for arr in arrays)).
    The indices are grouped class by class in sorted label order, and the
    array slices are copies.

  Raises:
    ValueError: if `labels` is not 1-D or empty, if any array's length differs
      from len(labels), or if batch_size is not a multiple of the class count.
  """

  if labels.ndim != 1:
    raise ValueError(f'Labels should be one-dim array, but got shape {labels.shape}')

  n = len(labels)
  if any(len(arr) != n for arr in arrays):
    raise ValueError(f'All arrays should have the same length as labels ({n}), but got lengths {[len(arr) for arr in arrays]}')

  classes = np.unique(labels)
  if len(classes) == 0:
    raise ValueError('Labels should not be empty')
  n_per_class, remainder = divmod(batch_size, len(classes))
  if remainder != 0:
    raise ValueError(f'Remainder of (Batch size/number of the classes) should be 0, but we got the remainder {remainder}')

  if kwargs.get("seed") is not None:
    np.random.seed(kwargs['seed'])

  index_set = np.concatenate([
      np.random.choice(np.where(labels == c)[0], n_per_class, replace=False)
      for c in classes
  ])

  return (index_set, labels[index_set]) + tuple(arr[index_set].copy() for arr in arrays)

# Define Model Structure and Kernel

In [47]:
def FullyConnectedNetwork(depth,
                          width,
                          W_std=np.sqrt(2.0),
                          b_std=0.1,
                          num_classes=10,
                          parameterization = 'ntk',
                          activation = 'relu'):

  """Return the neural_tangents.stax (init_fn, apply_fn, kernel_fn) triple of a fully connected network.

  Architecture: Flatten -> `depth` x [Dense(width) -> activation] -> Dense(num_classes).

  Args:
    depth: number of hidden Dense + activation layers.
    width: units per hidden layer.
    W_std, b_std: weight / bias standard deviations passed to every Dense layer.
    num_classes: number of output units.
    parameterization: forwarded to stax.Dense (this notebook uses 'ntk').
    activation: 'relu' (default) or 'erf'.

  Raises:
    ValueError: for an unsupported `activation`.
  """
  activations = {'relu': stax.Relu, 'erf': stax.Erf}
  if activation not in activations:
    raise ValueError(f"activation must be one of {sorted(activations)}, but got {activation}")
  activation_fn = activations[activation]()
  dense = functools.partial(stax.Dense, W_std=W_std, b_std=b_std, parameterization=parameterization)

  layers = [stax.Flatten()]
  for _ in range(depth):
    layers += [dense(width), activation_fn]
  layers.append(dense(num_classes))  # linear readout

  return stax.serial(*layers)

def FullyConvolutionalNetwork(
    depth,
    width,
    W_std = np.sqrt(2),
    b_std = 0.1,
    num_classes = 10,
    parameterization = 'ntk',
    activation = 'relu'):
  """Return the neural_tangents.stax (init_fn, apply_fn, kernel_fn) triple of a fully convolutional network.

  Architecture: `depth` x [Conv(width, 3x3, 'SAME') -> activation] -> Flatten
  -> Dense(num_classes). No pooling layers are used; a ConvNet3-style variant
  would add average pooling after each block.

  Args:
    depth: number of Conv + activation blocks.
    width: output channels of every Conv layer.
    W_std, b_std: weight / bias standard deviations for every layer.
    num_classes: number of output units.
    parameterization: forwarded to stax.Conv / stax.Dense (this notebook uses 'ntk').
    activation: 'relu' (default) or 'erf'.

  Raises:
    ValueError: for an unsupported `activation`.
  """
  activations = {'relu': stax.Relu, 'erf': stax.Erf}
  if activation not in activations:
    raise ValueError(f"activation must be one of {sorted(activations)}, but got {activation}")
  activation_fn = activations[activation]()
  conv = functools.partial(
      stax.Conv,
      W_std=W_std,
      b_std=b_std,
      padding='SAME',
      parameterization=parameterization)

  layers = []
  for _ in range(depth):
    layers += [conv(width, (3, 3)), activation_fn]
  layers += [stax.Flatten(), stax.Dense(num_classes, W_std=W_std, b_std=b_std,
                                        parameterization=parameterization)]

  return stax.serial(*layers)

In [48]:
# Network whose analytic NTK drives KIP. For the convolutional variant, call
# FullyConvolutionalNetwork with the same keyword arguments instead.
init_fn, apply_fn, KERNEL = FullyConnectedNetwork(
    depth=DEPTH,
    width=WIDTH,
    parameterization=PARAMETERIZATION,
)

# Kernel Inducing Point Based Backdoor Attack

In [49]:
'''Setting of the kernel training'''
def make_kernel_reg_model(kernel):
  """Build kernel-ridge-regression helpers around the NTK of `kernel`.

  Args:
    kernel: kernel function called as kernel(x1, x2, get='ntk') and returning
      the (len(x1), len(x2)) NTK matrix (e.g. a stax kernel_fn).

  Returns:
    (kernel_reg_model, kernel_loss, kernel_accuracy, trig_loss).
  """
  kernel_ntk = jax.jit(functools.partial(kernel, get='ntk'))

  def _ridge(k, reg):
    """Add reg * mean(diag(k)) to the diagonal of the square matrix k."""
    return k + reg * jnp.trace(k) * jnp.eye(k.shape[0]) / k.shape[0]

  '''Kernel Inducing point Method'''
  def kernel_reg_model(x_support, y_support, x_target, reg=1e-6):
    """Kernel ridge regression predictions for x_target from the support set."""
    k_ss = kernel_ntk(x_support, x_support)
    k_ts = kernel_ntk(x_target, x_support)
    k_ss_reg = _ridge(k_ss, jnp.abs(reg))
    return jnp.dot(k_ts, sp.linalg.solve(k_ss_reg, y_support))

  @jax.jit
  def kernel_loss(x_support, y_support, x_target, y_target):
    """Mean squared error of kernel_reg_model on (x_target, y_target)."""
    preds = kernel_reg_model(x_support, y_support, x_target)
    return jnp.mean((preds - y_target)**2)

  def trig_loss(x_s, y_s,
                x_a, y_a,
                x_b, y_b,
                trigger_pattern,
                trigger_label,
                mask_rate, num_classes, reg=1e-6):
    """Trigger objective: Rho * conflict loss + projection loss.

    A = clean samples (x_a, y_a). B = x_b with the trigger blended in and
    every label set to trigger_label. S = support set x_s. y_s and the
    original y_b labels are not used.
    """
    x_b_trig = triggerized(x_b, trigger_pattern, mask_rate)
    y_b_trig = one_hot(jnp.zeros(y_b.shape[0]) + trigger_label, num_classes)
    x_ab = jnp.vstack((x_a, x_b_trig))
    y_ab = jnp.vstack((y_a, y_b_trig))

    k_ss_reg = _ridge(kernel_ntk(x_s, x_s), reg)
    k_ab = kernel_ntk(x_ab, x_ab)
    k_abs = kernel_ntk(x_ab, x_s)
    k_ab_reg = _ridge(k_ab, reg)

    # Conflict loss: how well kernel regression on A∪B fits A∪B itself.
    alpha = sp.linalg.solve(k_ab_reg, y_ab)
    preds = jnp.dot(k_ab, alpha)
    loss_conflict = jnp.mean((y_ab - preds)**2)

    # Projection loss: squared residual of k_AB after projecting span(AB)
    # onto span(S), i.e. k_AB - k_ABs k_ss_reg^{-1} k_sAB, weighted by alpha**2.
    # A linear solve replaces the explicit inverse for numerical stability, and
    # k_sAB is k_ABs^T because the NTK is symmetric.
    residual = k_ab - k_abs @ sp.linalg.solve(k_ss_reg, k_abs.T)
    loss_project = jnp.mean(jnp.dot(residual**2, alpha**2))

    return Rho*loss_conflict + loss_project

  def kernel_accuracy(x_support, y_support, x_target, y_target):
    """Fraction of targets whose predicted argmax class matches argmax(y_target)."""
    labels = jnp.argmax(y_target, axis=1)
    pred_labels = jnp.argmax(kernel_reg_model(x_support, y_support, x_target), axis=1)
    return jnp.mean(labels == pred_labels)

  return kernel_reg_model, kernel_loss, kernel_accuracy, trig_loss


def get_update_functions(params_init, kernel, lr=0.01):
  """Set up Adam over params_init plus the KIP and trigger update steps.

  Args:
    params_init: parameter dict. The KIP step reads keys 'x' and 'y'; the
      trigger step also reads 'trigger', 'trigger_label' and 'mask_rate'.
    kernel: kernel function forwarded to make_kernel_reg_model.
    lr: Adam step size.

  Returns:
    (opt_state, get_params, kernel_update, trig_update).
  """
  opt_init, opt_update, get_params = optimizers.adam(step_size=lr)
  initial_state = opt_init(params_init)
  _, kernel_loss, _, trig_loss = make_kernel_reg_model(kernel)

  def _kip_objective(params, x_target, y_target):
    # Only the support images are learned; the support labels are frozen.
    return kernel_loss(params['x'], jax.lax.stop_gradient(params['y']), x_target, y_target)

  def _trigger_objective(params, x_a, y_a, x_b, y_b, num_classes):
    # Everything except the trigger pattern is frozen.
    frozen = jax.lax.stop_gradient
    return trig_loss(
        frozen(params['x']), frozen(params['y']),
        frozen(x_a), frozen(y_a),
        frozen(x_b), frozen(y_b),
        params['trigger'],
        frozen(params['trigger_label']),
        frozen(params['mask_rate']),
        frozen(num_classes))

  kip_grad = grad(_kip_objective, argnums=0)
  trigger_grad = grad(_trigger_objective, argnums=0)

  @jax.jit
  def kernel_update(step, opt_state, params, x_target, y_target):
    """One Adam step on the KIP loss, with the gradient taken at `params`."""
    return opt_update(step, kip_grad(params, x_target, y_target), opt_state)

  def trig_update(step, opt_state, params,
                  x_a, y_a,
                  x_b, y_b,
                  num_classes):
    """One (non-jitted) Adam step on the trigger loss, with the gradient taken at `params`."""
    step_grad = trigger_grad(params, x_a, y_a, x_b, y_b, num_classes)
    return opt_update(step, step_grad, opt_state)

  return initial_state, get_params, kernel_update, trig_update

In [50]:
'''KIP Training Algorithm'''
def KIP(num_train_steps, kernel, X_TRAIN, Y_TRAIN, LABELS_TRAIN, x_ctest, y_ctest, x_ttest, y_ttest,  log_freq=20, seed=100, lr=None):
  """Distil (X_TRAIN, Y_TRAIN) into a small support set with Kernel Inducing Points (KIP).

  The support set starts as SUPPORT_SIZE class-balanced samples of the clean
  training set (globals LABELS_CLEAN_TRAIN, X_CLEAN_TRAIN, Y_CLEAN_TRAIN),
  whatever X_TRAIN is. Each step does the following:
    - samples a class-balanced batch of BATCH_SIZE from X_TRAIN;
    - takes one Adam step on the support images (the labels are frozen);
    - clips the images to [0, 1];
    - records the accuracy on both test sets.

  Args:
    num_train_steps: number of optimisation steps.
    kernel: kernel function (e.g. the stax kernel_fn) used for kernel regression.
    X_TRAIN, Y_TRAIN, LABELS_TRAIN: images, one-hot targets and class labels
      the support set is fitted to.
    x_ctest, y_ctest: clean test set, giving the clean test accuracy (CTA).
    x_ttest, y_ttest: triggered test set, giving the attack success rate (ASR).
    log_freq: print CTA/ASR every `log_freq` steps.
    seed: seed for sampling the initial support set.
    lr: Adam step size. None uses the LEARNING_RATE config constant.

  Returns:
    (params, CTA, ASR, STEP): the final parameter dict (keys 'x', 'y',
    'trigger', 'trigger_label') and per-step lists of CTA, ASR and step index.
  """
  if lr is None:
    lr = LEARNING_RATE
  _, _, x_init, y_init = class_balanced_sample(SUPPORT_SIZE, LABELS_CLEAN_TRAIN, X_CLEAN_TRAIN, Y_CLEAN_TRAIN, seed=seed)
  trigger_init, trigger_label_init, _ = get_trigger(NAME, trigger_type=TRIGGER_TYPE, label_type=int(TRIGGER_LABEL))

  # 'trigger' and 'trigger_label' do not enter the KIP loss, so they get zero
  # gradients and keep their initial values; they are carried in the returned dict.
  params_init = {'x': x_init, 'y': y_init, 'trigger': jnp.float32(trigger_init), 'trigger_label': jnp.float32(trigger_label_init)}

  opt_state, get_params, kernel_update, _ = get_update_functions(params_init, kernel, lr=lr)
  params = get_params(opt_state)
  _, _, kernel_accuracy, _ = make_kernel_reg_model(kernel)

  STEP, CTA, ASR = [], [], []

  for ite in range(1, num_train_steps + 1):
    _, _, x_target_batch, y_target_batch = class_balanced_sample(BATCH_SIZE, LABELS_TRAIN, X_TRAIN, Y_TRAIN)
    opt_state = kernel_update(ite, opt_state, params, x_target_batch, y_target_batch)
    params = get_params(opt_state)
    # NOTE(review): the clip only changes this local copy. opt_state keeps the
    # unclipped images, so the next gradient is taken at the clipped point but
    # applied to the unclipped iterate. The [0, 1] range also only matches the
    # data when NORMAL is False.
    params['x'] = jnp.clip(params['x'], 0., 1.)

    STEP.append(ite)
    CTA.append(kernel_accuracy(params['x'], params['y'], x_ctest, y_ctest))
    ASR.append(kernel_accuracy(params['x'], params['y'], x_ttest, y_ttest))

    if ite % log_freq == 0:
      print(" ")
      print(f"===============step {ite}============")
      print(f"CTA: {CTA[-1]}")
      print(f"ASR: {ASR[-1]}")

  print("================RESULT=============")
  print("CTA = {}".format(CTA[-1]))
  print("ASR = {}".format(ASR[-1]))

  return params, CTA, ASR, STEP

KIP on the normal dataset

In [None]:
# Baseline: distil the clean training set; ASR is measured on the triggered test set.
params_KIP_clean, CTA_KIP_clean, ASR_KIP_clean, STEP_KIP_clean = KIP(
    500, KERNEL,
    X_CLEAN_TRAIN, Y_CLEAN_TRAIN, LABELS_CLEAN_TRAIN,
    X_CLEAN_TEST, Y_CLEAN_TEST,
    X_TRIGGER_TEST, Y_TRIGGER_TEST)

KIP-based backdoor attack --- simple trigger

In [None]:
# Distil the poisoned (simple-trigger) training set.
params_KIP_simpletrigger, CTA_KIP_simpletrigger, ASR_KIP_simpletrigger, STEP_KIP_simpletrigger = KIP(
    500, KERNEL,
    X_S, Y_S, LABELS_S,
    X_CLEAN_TEST, Y_CLEAN_TEST,
    X_TRIGGER_TEST, Y_TRIGGER_TEST)

Generate the relaxed (optimized) trigger

In [53]:
'''Trigger generation algorithm'''
def trigger_generation(num_train_steps, kernel, x_clean, y_clean, labels_clean, x_s, y_s, log_freq=20, seed=1, lr=None):
  """Optimise the trigger pattern with trig_loss against a fixed support set (x_s, y_s).

  Only 'trigger' is learned; the support set, target label and mask rate stay
  frozen. Each step samples two class-balanced batches from the clean data:
    - set A, of BATCH_SIZE images;
    - set B, of int(BATCH_SIZE * POISON_RATE) images, triggered inside trig_loss.
  It then takes one Adam step and clips the trigger to [0, 1].

  Args:
    num_train_steps: number of optimisation steps.
    kernel: kernel function (e.g. the stax kernel_fn).
    x_clean, y_clean, labels_clean: clean images, one-hot targets and labels.
    x_s, y_s: support set (e.g. the output of KIP on clean data).
    log_freq: print the trigger loss every `log_freq` steps.
    seed: if not None, seeds NumPy's global RNG before anything is sampled,
      so the batch sequence is reproducible.
    lr: Adam step size. None uses the LEARNING_RATE config constant.

  Returns:
    (params, STEP, LOSS): the final parameter dict (keys 'x', 'y', 'trigger',
    'trigger_label', 'mask_rate') and per-step lists of step index and loss.
  """
  if seed is not None:
    np.random.seed(seed)
  if lr is None:
    lr = LEARNING_RATE
  trigger_init, _, mask_rate = get_trigger(NAME, trigger_type=TRIGGER_TYPE, label_type=int(TRIGGER_LABEL))

  params_init = {'x': x_s,
                 'y': y_s,
                 'trigger': trigger_init,
                 'trigger_label': jnp.float32(TRIGGER_LABEL),
                 'mask_rate': mask_rate * Trans}

  opt_state, get_params, _, trig_update = get_update_functions(params_init, kernel, lr=lr)
  params = get_params(opt_state)
  _, _, _, trig_loss = make_kernel_reg_model(kernel)

  STEP, LOSS = [], []

  for ite in range(1, num_train_steps + 1):
    _, _, x_clean_batch, y_clean_batch = class_balanced_sample(BATCH_SIZE, labels_clean, x_clean, y_clean)
    _, _, x_trig_batch, y_trig_batch = class_balanced_sample(int(BATCH_SIZE*POISON_RATE), labels_clean, x_clean, y_clean)

    opt_state = trig_update(ite, opt_state, params, x_clean_batch, y_clean_batch, x_trig_batch, y_trig_batch, num_classes=NUM_CLASSES)
    params = get_params(opt_state)
    # NOTE(review): as in KIP, the clip only changes this local copy, not opt_state.
    params['trigger'] = jnp.clip(params['trigger'], 0, 1.0)

    STEP.append(ite)
    LOSS.append(trig_loss(params['x'], params['y'],
                          x_clean_batch, y_clean_batch,
                          x_trig_batch, y_trig_batch,
                          params['trigger'],
                          params['trigger_label'],
                          params['mask_rate'], num_classes=NUM_CLASSES))

    if ite % log_freq == 0:
      print(" ")
      print(f"===============step {ite}============")
      print(f"TRIGGER LOSS: {LOSS[-1]}")

  print("================RESULT=============")
  print("TRIGGER LOSS = {}".format(LOSS[-1]))

  return params, STEP, LOSS

In [None]:
params_trig, STEP_trig, LOSS_trig = trigger_generation(
    1000,
    KERNEL,
    X_CLEAN_TRAIN, Y_CLEAN_TRAIN, LABELS_CLEAN_TRAIN,
    params_KIP_clean['x'], params_KIP_clean['y'],  # support set distilled from clean data
    log_freq=20,
    seed=64,
)

In [None]:
# Visualise the optimised trigger (values were clipped to [0, 1] during training).
fig, ax = plt.subplots()
ax.imshow(params_trig['trigger'][0])
ax.set_title('Optimised (relaxed) trigger pattern')
plt.show()

KIP-based backdoor attack --- relaxed (optimized) trigger

In [None]:
# Poison the clean training set with the optimised trigger, then distil it with KIP.
X_B_TRAIN, X_B_TEST, Y_B_TRAIN, Y_B_TEST, LABELS_B_TRAIN, LABELS_B_TEST = get_trigger_dataset(
    NAME,
    params_trig['trigger'],
    params_trig['trigger_label'],
    params_trig['mask_rate'],
    num_classes=NUM_CLASSES,
    normalization=NORMAL)

# Use the configured POISON_RATE (previously hard-coded to 0.1 here).
X_AB, Y_AB, LABELS_AB = union_two_dataset(X_CLEAN_TRAIN, Y_CLEAN_TRAIN, LABELS_CLEAN_TRAIN, X_B_TRAIN, Y_B_TRAIN, LABELS_B_TRAIN, poison_rate=POISON_RATE)
params_refine, CTA_refine, ASR_refine, STEP_refine = KIP(500, KERNEL, X_AB, Y_AB, LABELS_AB, X_CLEAN_TEST, Y_CLEAN_TEST, X_B_TEST, Y_B_TEST)