Add target attack for FAB
rikonaka committed Apr 1, 2024
1 parent 3c12bbc commit c696144
Showing 3 changed files with 101 additions and 28 deletions.
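For context, a minimal usage sketch of the new mode (hypothetical toy setup; `set_mode_targeted_least_likely` and `get_target_label` come from the torchattacks `Attack` base class, and the `FAB` constructor arguments follow the signature shown in the diffs below):

    import torch
    import torch.nn as nn
    import torchattacks

    # Toy classifier and data; any model with inputs in [0, 1] works here.
    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).eval()
    images = torch.rand(4, 3, 32, 32)
    labels = torch.randint(0, 10, (4,))

    # Switch FAB to targeted mode. set_mode_targeted_least_likely() picks the
    # least-likely class as the target; it is what drives the
    # get_target_label() call added in this commit.
    atk = torchattacks.FAB(model, eps=8/255)
    atk.set_mode_targeted_least_likely()
    adv_images = atk(images, labels)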
42 changes: 33 additions & 9 deletions torchattacks/attacks/fab.py
@@ -42,7 +42,7 @@ def __init__(self, model, eps=8/255, n_restarts=1, n_iter=10, alpha_max=0.1, eta
         self.eta = eta
         self.beta = beta
         self.las = las
-        self.supported_mode = ["default"]
+        self.supported_mode = ["default", "targeted"]

     def forward(self, images, labels):
         r"""
@@ -71,6 +71,15 @@ def perturb(self, images, labels):
         x0 = im2.clone().reshape(bs, -1)
         eps = torch.full(res2.shape, self.eps, device=self.device)

+        if self.targeted:
+            # The original paper's code does not implement a targeted attack;
+            # this implementation follows the corresponding code in the author's
+            # later AutoAttack work: https://github.com/fra31/auto-attack
+            target_labels = self.get_target_label(images, labels)
+            la_target2 = target_labels[pred].detach().clone()
+        else:
+            la_target2 = None
+
         for counter_restarts in range(self.n_restarts):
             if counter_restarts > 0:
                 t = uniform.Uniform(-1, 1).sample(x1.shape).to(self.device)
@@ -81,7 +90,7 @@

             for _ in range(self.n_iter):
                 # print(i)
-                df, dg = self.get_diff_logits_grads_batch(x1, la2)
+                df, dg = self.get_diff_logits_grads_batch(x1, la2, la_target2)
                 dist1 = torch.abs(df) / (1e-8 + torch.sum(torch.abs(dg).view(dg.shape[0], dg.shape[1], -1), -1))  # nopep8
                 ind = torch.argmin(dist1, 1)
                 b = - df[u1, ind] + torch.sum(torch.reshape(dg[u1, ind] * x1, (bs, -1)), 1).to(self.device)  # nopep8
@@ -114,15 +123,30 @@ def perturb(self, images, labels):
         adv_c[pred1] = adv
         return adv_c

-    def get_diff_logits_grads_batch(self, images, labels):
+    def get_diff_logits_grads_batch(self, images, labels, target_labels=None):
         images = images.clone().detach().requires_grad_()  # make sure it is a leaf node
         # print(images.is_leaf)
-        logits = self.get_logits(images)
-        g2 = self.compute_jacobian(images, logits)
-        y2 = logits
-        df = y2 - torch.unsqueeze(y2[torch.arange(images.shape[0]), labels], 1)
-        dg = g2 - torch.unsqueeze(g2[torch.arange(images.shape[0]), labels], 1)
-        df[torch.arange(images.shape[0]), labels] = 1e10
+
+        if not self.targeted:
+            logits = self.get_logits(images)
+            g2 = self.compute_jacobian(images, logits)
+            y2 = logits
+            df = y2 - torch.unsqueeze(y2[torch.arange(images.shape[0]), labels], 1)  # nopep8
+            dg = g2 - torch.unsqueeze(g2[torch.arange(images.shape[0]), labels], 1)  # nopep8
+            df[torch.arange(images.shape[0]), labels] = 1e10
+        else:
+            u = torch.arange(images.shape[0])
+            logits = self.get_logits(images)
+            diff_logits = -(logits[u, labels] - logits[u, target_labels])
+            sum_diff = torch.sum(diff_logits)
+
+            # per-example gradient of (f_target - f_label) via one backward pass
+            self.zero_gradients(images)
+            sum_diff.backward()
+            grad_diff = images.grad.data
+            df = torch.unsqueeze(diff_logits.detach(), 1)
+            dg = torch.unsqueeze(grad_diff, 1)
+
         return df, dg

     def compute_jacobian(self, images, logits):
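Why the targeted branch above can skip compute_jacobian: the untargeted case linearizes the boundary against every class and so needs a full Jacobian, while the targeted case only needs the single scalar f_target - f_label per example. Because batch elements do not interact, summing those scalars and calling backward() once leaves each example's own gradient in images.grad. A standalone sketch of that trick (toy model and tensor shapes are illustrative, not the library's code):

    import torch

    # Toy setup: 5 examples, 4 features, 3 classes.
    model = torch.nn.Linear(4, 3)
    x = torch.randn(5, 4, requires_grad=True)
    labels = torch.tensor([0, 1, 2, 0, 1])
    targets = torch.tensor([1, 2, 0, 2, 0])

    logits = model(x)
    u = torch.arange(x.shape[0])
    diff = -(logits[u, labels] - logits[u, targets])  # f_t(x) - f_y(x), shape (5,)

    # One backward on the sum recovers every example's gradient at once:
    # x.grad[i] == d(f_t - f_y)/dx for example i, with no Jacobian loop.
    diff.sum().backward()
    print(x.grad.shape)  # torch.Size([5, 4])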
45 changes: 35 additions & 10 deletions torchattacks/attacks/fabl1.py
@@ -41,7 +41,7 @@ def __init__(self, model, eps=8/255, n_restarts=1, n_iter=10, alpha_max=0.1, eta
         self.eta = eta
         self.beta = beta
         self.las = las
-        self.supported_mode = ["default"]
+        self.supported_mode = ["default", "targeted"]

     def forward(self, images, labels):
         r"""
@@ -70,6 +70,15 @@ def perturb(self, images, labels):
         x0 = im2.clone().reshape(bs, -1)
         eps = torch.full(res2.shape, self.eps, device=self.device)

+        if self.targeted:
+            # The original paper's code does not implement a targeted attack;
+            # this implementation follows the corresponding code in the author's
+            # later AutoAttack work: https://github.com/fra31/auto-attack
+            target_labels = self.get_target_label(images, labels)
+            la_target2 = target_labels[pred].detach().clone()
+        else:
+            la_target2 = None
+
         for counter_restarts in range(self.n_restarts):
             if counter_restarts > 0:
                 t = torch.rand(x1.shape[0], x1.shape[1], x1.shape[2], x1.shape[3])  # nopep8
@@ -80,7 +89,7 @@

             for _ in range(self.n_iter):
                 # print(i)
-                df, dg = self.get_diff_logits_grads_batch(x1, la2)
+                df, dg = self.get_diff_logits_grads_batch(x1, la2, la_target2)
                 dist1 = torch.abs(df) / torch.max(1e-12 + torch.abs(dg).reshape((df.shape[0], df.shape[1], -1)), 2)[0]  # nopep8
                 ind = torch.argmin(dist1, 1)
                 b = - df[u1, ind] + torch.sum(torch.reshape(dg[u1, ind] * x1, (bs, -1)), 1).to(self.device)  # nopep8
@@ -100,7 +109,8 @@
                 is_adv = torch.argmax(self.get_logits(x1), 1) != la2
                 if torch.sum(is_adv) > 0:
                     temp_var = torch.reshape(x1[is_adv] - im2[is_adv], (torch.sum(is_adv), -1))  # nopep8
-                    t = torch.sum(torch.abs(temp_var).view(torch.sum(is_adv), -1), -1)
+                    t = torch.sum(torch.abs(temp_var).view(
+                        torch.sum(is_adv), -1), -1)
                     temp_var_3 = x1[is_adv] * (t < res2[is_adv]).float().reshape([-1, 1, 1, 1])  # nopep8
                     temp_var_4 = adv[is_adv] * (t >= res2[is_adv]).float().reshape([-1, 1, 1, 1])  # nopep8
                     adv[is_adv] = temp_var_3 + temp_var_4
@@ -113,15 +123,30 @@ def perturb(self, images, labels):
         adv_c[pred1] = adv
         return adv_c

-    def get_diff_logits_grads_batch(self, images, labels):
+    def get_diff_logits_grads_batch(self, images, labels, target_labels=None):
         images = images.clone().detach().requires_grad_()  # make sure it is a leaf node
         # print(images.is_leaf)
-        logits = self.get_logits(images)
-        g2 = self.compute_jacobian(images, logits)
-        y2 = logits
-        df = y2 - torch.unsqueeze(y2[torch.arange(images.shape[0]), labels], 1)
-        dg = g2 - torch.unsqueeze(g2[torch.arange(images.shape[0]), labels], 1)
-        df[torch.arange(images.shape[0]), labels] = 1e10
+
+        if not self.targeted:
+            logits = self.get_logits(images)
+            g2 = self.compute_jacobian(images, logits)
+            y2 = logits
+            df = y2 - torch.unsqueeze(y2[torch.arange(images.shape[0]), labels], 1)  # nopep8
+            dg = g2 - torch.unsqueeze(g2[torch.arange(images.shape[0]), labels], 1)  # nopep8
+            df[torch.arange(images.shape[0]), labels] = 1e10
+        else:
+            u = torch.arange(images.shape[0])
+            logits = self.get_logits(images)
+            diff_logits = -(logits[u, labels] - logits[u, target_labels])
+            sum_diff = torch.sum(diff_logits)
+
+            # per-example gradient of (f_target - f_label) via one backward pass
+            self.zero_gradients(images)
+            sum_diff.backward()
+            grad_diff = images.grad.data
+            df = torch.unsqueeze(diff_logits.detach(), 1)
+            dg = torch.unsqueeze(grad_diff, 1)
+
         return df, dg

     def compute_jacobian(self, images, logits):
42 changes: 33 additions & 9 deletions torchattacks/attacks/fabl2.py
@@ -42,7 +42,7 @@ def __init__(self, model, eps=8/255, n_restarts=1, n_iter=10, alpha_max=0.1, eta
         self.eta = eta
         self.beta = beta
         self.las = las
-        self.supported_mode = ["default"]
+        self.supported_mode = ["default", "targeted"]

     def forward(self, images, labels):
         r"""
@@ -71,6 +71,15 @@ def perturb(self, images, labels):
         x0 = im2.clone().reshape(bs, -1)
         eps = torch.full(res2.shape, self.eps, device=self.device)

+        if self.targeted:
+            # The original paper's code does not implement a targeted attack;
+            # this implementation follows the corresponding code in the author's
+            # later AutoAttack work: https://github.com/fra31/auto-attack
+            target_labels = self.get_target_label(images, labels)
+            la_target2 = target_labels[pred].detach().clone()
+        else:
+            la_target2 = None
+
         for counter_restarts in range(self.n_restarts):
             if counter_restarts > 0:
                 t = torch.rand(x1.shape[0], x1.shape[1], x1.shape[2], x1.shape[3])  # nopep8
@@ -82,7 +91,7 @@

             for _ in range(self.n_iter):
                 # print(i)
-                df, dg = self.get_diff_logits_grads_batch(x1, la2)
+                df, dg = self.get_diff_logits_grads_batch(x1, la2, la_target2)
                 dist1 = torch.abs(df) / torch.sqrt(torch.sum(1e-12 + torch.square(dg).reshape(dg.shape[0], dg.shape[1], -1), -1))  # nopep8
                 ind = torch.argmin(dist1, 1)
                 b = - df[u1, ind] + torch.sum(torch.reshape(dg[u1, ind] * x1, (bs, -1)), 1).to(self.device)  # nopep8
@@ -115,15 +124,30 @@ def perturb(self, images, labels):
         adv_c[pred1] = adv
         return adv_c

-    def get_diff_logits_grads_batch(self, images, labels):
+    def get_diff_logits_grads_batch(self, images, labels, target_labels=None):
         images = images.clone().detach().requires_grad_()  # make sure it is a leaf node
         # print(images.is_leaf)
-        logits = self.get_logits(images)
-        g2 = self.compute_jacobian(images, logits)
-        y2 = logits
-        df = y2 - torch.unsqueeze(y2[torch.arange(images.shape[0]), labels], 1)
-        dg = g2 - torch.unsqueeze(g2[torch.arange(images.shape[0]), labels], 1)
-        df[torch.arange(images.shape[0]), labels] = 1e10
+
+        if not self.targeted:
+            logits = self.get_logits(images)
+            g2 = self.compute_jacobian(images, logits)
+            y2 = logits
+            df = y2 - torch.unsqueeze(y2[torch.arange(images.shape[0]), labels], 1)  # nopep8
+            dg = g2 - torch.unsqueeze(g2[torch.arange(images.shape[0]), labels], 1)  # nopep8
+            df[torch.arange(images.shape[0]), labels] = 1e10
+        else:
+            u = torch.arange(images.shape[0])
+            logits = self.get_logits(images)
+            diff_logits = -(logits[u, labels] - logits[u, target_labels])
+            sum_diff = torch.sum(diff_logits)
+
+            # per-example gradient of (f_target - f_label) via one backward pass
+            self.zero_gradients(images)
+            sum_diff.backward()
+            grad_diff = images.grad.data
+            df = torch.unsqueeze(diff_logits.detach(), 1)
+            dg = torch.unsqueeze(grad_diff, 1)
+
         return df, dg

     def compute_jacobian(self, images, logits):
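A closing note on how the three files differ: each variant measures the distance to the linearized decision boundary in the dual of its attack norm, which is exactly what the differing dist1 lines compute. For a hyperplane with offset df_k and normal dg_k, the L_p distance from x is (assuming my reading of the three dist1 expressions is right):

\[
\mathrm{dist}_1 = \frac{\lvert df_k \rvert}{\lVert dg_k \rVert_q},
\qquad \frac{1}{p} + \frac{1}{q} = 1,
\]
\[
\begin{cases}
q = 1 & \text{FAB-}L_\infty \text{ (fab.py: sum of absolute values)} \\
q = \infty & \text{FAB-}L_1 \text{ (fabl1.py: max of absolute values)} \\
q = 2 & \text{FAB-}L_2 \text{ (fabl2.py: Euclidean norm)}
\end{cases}
\]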
