Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 151 additions & 0 deletions deepdefend/attacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
- `deepfool(model, x, y, num_steps=10)`: DeepFool attack.
- `jsma(model, x, y, theta=0.1, gamma=0.1, num_steps=10)`: Jacobian-based Saliency Map Attack (JSMA).
- `spsa(model, x, y, epsilon=0.01, num_steps=10)`: Simultaneous Perturbation Stochastic Approximation (SPSA) attack.
- `mim(model, x, y, epsilon=0.01, alpha=0.01, num_steps=10, decay_factor=1.0)`: Momentum Iterative Method (MIM) attack.
- `ead(model, x, y, epsilon=0.01, beta=0.01, num_steps=10, alpha=0.01)`: Elastic Net Attack (EAD).
- `word_swap(text, swap_dict=None)`: Simple word swap attack for text.
- `char_swap(text, swap_prob=0.1)`: Simple character swap attack for text.
"""

import numpy as np
Expand All @@ -27,6 +31,9 @@ def fgsm(model, x, y, epsilon=0.01):
Returns:
adversarial_example (numpy.ndarray): The perturbed input example.
"""
x = tf.cast(x, tf.float32)
y = tf.cast(y, tf.float32)

# Determine the loss function based on the number of classes
if y.shape[-1] == 1 or len(y.shape) == 1:
loss_object = tf.keras.losses.BinaryCrossentropy()
Expand All @@ -45,6 +52,139 @@ def fgsm(model, x, y, epsilon=0.01):
adversarial_example = x + perturbation
return adversarial_example.numpy()

def mim(model, x, y, epsilon=0.01, alpha=0.01, num_steps=10, decay_factor=1.0):
    """
    Momentum Iterative Method (MIM) attack.

    Iteratively perturbs the input in the direction of the sign of an
    accumulated (momentum) gradient, keeping the result inside an
    L-infinity ball of radius `epsilon` around the original input.

    Parameters:
    model (tensorflow.keras.Model): The target model to attack.
    x (numpy.ndarray): The input example to attack.
    y (numpy.ndarray): The true labels of the input example.
    epsilon (float): The maximum magnitude of the perturbation (default: 0.01).
    alpha (float): The step size for each iteration (default: 0.01).
    num_steps (int): The number of MIM iterations (default: 10).
    decay_factor (float): The decay factor for momentum (default: 1.0).

    Returns:
    adversarial_example (numpy.ndarray): The perturbed input example.
    """
    x = tf.cast(x, tf.float32)
    y = tf.cast(y, tf.float32)
    adversarial_example = tf.identity(x)
    momentum = tf.zeros_like(x)

    # Determine the loss function based on the number of classes
    # (consistent with the convention used by the other attacks in this file).
    if y.shape[-1] == 1 or len(y.shape) == 1:
        loss_object = tf.keras.losses.BinaryCrossentropy()
    else:
        loss_object = tf.keras.losses.CategoricalCrossentropy()

    for _ in range(num_steps):
        with tf.GradientTape() as tape:
            tape.watch(adversarial_example)
            prediction = model(adversarial_example)
            loss = loss_object(y, prediction)

        gradient = tape.gradient(loss, adversarial_example)
        # L1-normalize the gradient before accumulating momentum; the small
        # additive constant guards against division by zero for a flat loss.
        grad_l1 = tf.reduce_sum(tf.abs(gradient))
        gradient = gradient / (grad_l1 + 1e-8)

        momentum = decay_factor * momentum + gradient

        perturbation = alpha * tf.sign(momentum)
        adversarial_example = adversarial_example + perturbation
        # Project onto the epsilon-ball around x FIRST, then clamp to the
        # valid pixel range [0, 1]. The reverse order (as previously written)
        # could return values outside [0, 1] when x is near the range boundary,
        # since x - epsilon may be < 0 and x + epsilon may be > 1.
        adversarial_example = tf.clip_by_value(adversarial_example, x - epsilon, x + epsilon)
        adversarial_example = tf.clip_by_value(adversarial_example, 0, 1)

    return adversarial_example.numpy()

def ead(model, x, y, epsilon=0.01, beta=0.01, num_steps=10, alpha=0.01):
    """
    Elastic Net Attack (EAD).

    Performs iterative signed-gradient steps followed by an L1 proximal
    (soft-thresholding) step that sparsifies the perturbation, then projects
    the result into an L-infinity ball of radius `epsilon` around the input.

    Parameters:
    model (tensorflow.keras.Model): The target model to attack.
    x (numpy.ndarray): The input example to attack.
    y (numpy.ndarray): The true labels of the input example.
    epsilon (float): The maximum magnitude of the perturbation (default: 0.01).
    beta (float): The L1 regularization parameter (default: 0.01).
    num_steps (int): The number of EAD iterations (default: 10).
    alpha (float): The step size for each iteration (default: 0.01).

    Returns:
    adversarial_example (numpy.ndarray): The perturbed input example.
    """
    x = tf.cast(x, tf.float32)
    y = tf.cast(y, tf.float32)
    adversarial_example = tf.identity(x)

    # Determine the loss function based on the number of classes
    # (consistent with the convention used by the other attacks in this file).
    if y.shape[-1] == 1 or len(y.shape) == 1:
        loss_object = tf.keras.losses.BinaryCrossentropy()
    else:
        loss_object = tf.keras.losses.CategoricalCrossentropy()

    for _ in range(num_steps):
        with tf.GradientTape() as tape:
            tape.watch(adversarial_example)
            prediction = model(adversarial_example)
            loss = loss_object(y, prediction)

        gradient = tape.gradient(loss, adversarial_example)

        perturbation = alpha * tf.sign(gradient)
        new_x = adversarial_example + perturbation

        # Proximal operator for the L1 penalty (soft thresholding): shrinks
        # each per-pixel perturbation toward zero by beta, zeroing small ones.
        diff = new_x - x
        adversarial_example = x + tf.sign(diff) * tf.maximum(tf.abs(diff) - beta, 0)

        # Project onto the epsilon-ball around x FIRST, then clamp to the
        # valid pixel range [0, 1]. The reverse order (as previously written)
        # could return values outside [0, 1] when x is near the range boundary,
        # since x - epsilon may be < 0 and x + epsilon may be > 1.
        adversarial_example = tf.clip_by_value(adversarial_example, x - epsilon, x + epsilon)
        adversarial_example = tf.clip_by_value(adversarial_example, 0, 1)

    return adversarial_example.numpy()

def word_swap(text, swap_dict=None):
    """
    Simple word swap attack for text.

    Splits the text on whitespace, replaces any word that appears as a key
    in `swap_dict` with its mapped substitute, and rejoins with single spaces.

    Parameters:
    text (str): The input text.
    swap_dict (dict): Dictionary of words and their substitutes.

    Returns:
    perturbed_text (str): The text with swapped words.
    """
    # No substitution table means nothing to do; return the text untouched
    # (preserving its original whitespace).
    if swap_dict is None:
        return text

    # dict.get falls back to the word itself when no substitute is defined.
    return " ".join(swap_dict.get(word, word) for word in text.split())

def char_swap(text, swap_prob=0.1):
    """
    Simple character swap attack for text.

    With probability `swap_prob`, each multi-character word has one randomly
    chosen pair of adjacent characters transposed. Whitespace is normalized
    to single spaces by the split/join round trip.

    Parameters:
    text (str): The input text.
    swap_prob (float): The probability of swapping a character in a word (default: 0.1).

    Returns:
    perturbed_text (str): The text with swapped characters.
    """
    perturbed = []
    for word in text.split():
        # Single-character words cannot be transposed; note the short-circuit
        # keeps the RNG call sequence identical to the original implementation.
        if len(word) > 1 and np.random.rand() < swap_prob:
            chars = list(word)
            # Pick the left index of an adjacent pair to transpose.
            j = np.random.randint(0, len(chars) - 1)
            chars[j], chars[j + 1] = chars[j + 1], chars[j]
            word = "".join(chars)
        perturbed.append(word)
    return " ".join(perturbed)

def pgd(model, x, y, epsilon=0.01, alpha=0.01, num_steps=10):
"""
Projected Gradient Descent (PGD) attack.
Expand All @@ -60,6 +200,8 @@ def pgd(model, x, y, epsilon=0.01, alpha=0.01, num_steps=10):
Returns:
adversarial_example (numpy.ndarray): The perturbed input example.
"""
x = tf.cast(x, tf.float32)
y = tf.cast(y, tf.float32)
adversarial_example = tf.identity(x)

for _ in range(num_steps):
Expand Down Expand Up @@ -90,6 +232,8 @@ def bim(model, x, y, epsilon=0.01, alpha=0.01, num_steps=10):
Returns:
adversarial_example (numpy.ndarray): The perturbed input example.
"""
x = tf.cast(x, tf.float32)
y = tf.cast(y, tf.float32)
adversarial_example = tf.identity(x)

for _ in range(num_steps):
Expand Down Expand Up @@ -122,6 +266,9 @@ def cw(model, x, y, epsilon=0.01, c=1, kappa=0, num_steps=10, alpha=0.01):
Returns:
adversarial_example (numpy.ndarray): The perturbed input example.
"""
x = tf.cast(x, tf.float32)
y = tf.cast(y, tf.float32)

# Define the loss function
def loss_function(x, y, model, c, kappa):
prediction = model(x)
Expand Down Expand Up @@ -157,6 +304,8 @@ def deepfool(model, x, y, num_steps=10):
Returns:
adversarial_example (numpy.ndarray): The perturbed input example.
"""
x = tf.cast(x, tf.float32)
y = tf.cast(y, tf.float32)
# Initialize the adversarial example
adversarial_example = tf.identity(x)

Expand Down Expand Up @@ -188,6 +337,8 @@ def jsma(model, x, y, theta=0.1, gamma=0.1, num_steps=10):
Returns:
adversarial_example (numpy.ndarray): The perturbed input example.
"""
x = tf.cast(x, tf.float32)
y = tf.cast(y, tf.float32)
# Initialize the adversarial example
adversarial_example = tf.identity(x)

Expand Down
Loading