Skip to content

Commit

Permalink
changing attack running parameter and defense error type
Browse files Browse the repository at this point in the history
  • Loading branch information
gwding committed Jun 14, 2020
1 parent bab704a commit 104ddc3
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 12 deletions.
4 changes: 3 additions & 1 deletion advertorch/utils.py
Expand Up @@ -222,6 +222,8 @@ def batch_l1_proj_flat(x, z=1):

# Getting the elements to project in the batch
indexes_b = torch.nonzero(v > z).view(-1)
if isinstance(z, torch.Tensor):
z[indexes_b][:, None]
x_b = x[indexes_b]
batch_size_b = x_b.size(0)

Expand All @@ -234,7 +236,7 @@ def batch_l1_proj_flat(x, z=1):
view_size = view.size(1)
mu = view.abs().sort(1, descending=True)[0]
vv = torch.arange(view_size).float().to(x.device)
st = (mu.cumsum(1) - z[indexes_b][:, None]) / (vv + 1)
st = (mu.cumsum(1) - z) / (vv + 1)
u = (mu - st) > 0
if u.dtype.__str__() == "torch.bool": # after and including torch 1.2
rho = (~u).cumsum(dim=1).eq(0).sum(1) - 1
Expand Down
6 changes: 3 additions & 3 deletions docs/conf.py
Expand Up @@ -23,7 +23,7 @@
os.symlink(
"../../advertorch_examples/tutorial_attack_defense_bpda_mnist.ipynb",
"_tutorials/tutorial_attack_defense_bpda_mnist.ipynb")
import sys # noqa: F401
import sys # noqa: F401, E402
sys.path.insert(0, os.path.abspath('..'))


Expand Down Expand Up @@ -52,7 +52,7 @@
# 'scipy._lib',
# ]

from unittest.mock import Mock # noqa: F401
from unittest.mock import Mock # noqa: F401, E402
# from sphinx.ext.autodoc.importer import _MockObject as Mock
Mock.Module = object
sys.modules['torch'] = Mock()
Expand Down Expand Up @@ -85,7 +85,7 @@
sys.modules['scipy._lib'] = Mock()

# XXX: This import has to be after mock
import advertorch # noqa: F401
import advertorch # noqa: F401, E402


# -- Project information -----------------------------------------------------
Expand Down
19 changes: 12 additions & 7 deletions tests/test_attacks_running.py
Expand Up @@ -17,11 +17,12 @@
import torch.nn as nn

from advertorch.attacks import GradientSignAttack
from advertorch.attacks import LinfBasicIterativeAttack
from advertorch.attacks import GradientAttack
from advertorch.attacks import L2BasicIterativeAttack
from advertorch.attacks import LinfPGDAttack
from advertorch.attacks import LinfBasicIterativeAttack
from advertorch.attacks import L1PGDAttack
from advertorch.attacks import L2PGDAttack
from advertorch.attacks import LinfPGDAttack
from advertorch.attacks import SparseL1DescentAttack
from advertorch.attacks import MomentumIterativeAttack
from advertorch.attacks import FastFeatureAttack
Expand Down Expand Up @@ -71,9 +72,15 @@
attack_kwargs = {
GradientSignAttack: {},
GradientAttack: {},
LinfBasicIterativeAttack: {"nb_iter": 5},
L2BasicIterativeAttack: {"nb_iter": 5},
LinfPGDAttack: {"rand_init": False, "nb_iter": 5},
SparseL1DescentAttack: {
"rand_init": False, "nb_iter": 5, "eps": 3., "eps_iter": 1.},
L1PGDAttack: {"rand_init": False, "nb_iter": 5, "eps": 3., "eps_iter": 1.},
L2BasicIterativeAttack: {"nb_iter": 5, "eps": 1., "eps_iter": 0.33},
L2PGDAttack: {
"rand_init": False, "nb_iter": 5, "eps": 1., "eps_iter": 0.33},
LinfBasicIterativeAttack: {"nb_iter": 5, "eps": 0.3, "eps_iter": 0.1},
LinfPGDAttack: {
"rand_init": False, "nb_iter": 5, "eps": 0.3, "eps_iter": 0.1},
MomentumIterativeAttack: {"nb_iter": 5},
CarliniWagnerL2Attack: {"num_classes": NUM_CLASS, "max_iterations": 10},
ElasticNetL1Attack: {"num_classes": NUM_CLASS, "max_iterations": 10},
Expand All @@ -82,8 +89,6 @@
JacobianSaliencyMapAttack: {"num_classes": NUM_CLASS, "gamma": 0.01},
SpatialTransformAttack: {"num_classes": NUM_CLASS},
DDNL2Attack: {"nb_iter": 5},
SparseL1DescentAttack: {"rand_init": False, "nb_iter": 5},
L1PGDAttack: {"rand_init": False, "nb_iter": 5},
LinfSPSAAttack: {"eps": 0.3, "max_batch_size": 63},
LinfFABAttack: {"n_iter": 5},
L2FABAttack: {"n_iter": 5},
Expand Down
2 changes: 1 addition & 1 deletion tests/test_defenses_running.py
Expand Up @@ -64,7 +64,7 @@ def test_withgrad(device, def_cls):
"device, def_cls",
itertools.product(devices, nograd_defenses))
def test_defenses_nograd(device, def_cls):
with pytest.raises(NotImplementedError):
with pytest.raises((RuntimeError, NotImplementedError)):
defense = def_cls(**defense_kwargs[def_cls])
data = defense_data[def_cls]
data.requires_grad_()
Expand Down

0 comments on commit 104ddc3

Please sign in to comment.