Merge pull request #57 from MadryLab/develop
Re-merge develop without squashing
Hadisalman committed Jul 4, 2020
2 parents e18b19d + 09fe3d5 commit 89bdf80
Showing 29 changed files with 2,222 additions and 768 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -2,6 +2,7 @@
__pycache__/
*.py[cod]
*$py.class
.vscode

# C extensions
*.so
10 changes: 8 additions & 2 deletions README.rst
@@ -230,17 +230,23 @@ follows:
@misc{robustness,
title={Robustness (Python Library)},
author={Logan Engstrom and Andrew Ilyas and Shibani Santurkar and Dimitris Tsipras},
author={Logan Engstrom and Andrew Ilyas and Hadi Salman and Shibani Santurkar and Dimitris Tsipras},
year={2019},
url={https://github.com/MadryLab/robustness}
}
*(Have you used the package and found it useful? Let us know!)*.

Contributors
Maintainers
-------------
- `Andrew Ilyas <https://twitter.com/andrew_ilyas>`_
- `Logan Engstrom <https://twitter.com/logan_engstrom>`_
- `Shibani Santurkar <https://twitter.com/ShibaniSan>`_
- `Dimitris Tsipras <https://twitter.com/tsiprasd>`_
- `Hadi Salman <https://twitter.com/hadisalmanX>`_

Contributors/Committers
'''''''''''''''''''''''
- Kristian Georgiev
- `iamgroot42 <https://github.com/MadryLab/robustness/pulls/iamgroot42>`_
- `TLMichael <https://github.com/TLMichael>`_
52 changes: 52 additions & 0 deletions docs/example_usage/changelog.rst
@@ -1,7 +1,59 @@
CHANGELOG
=========

robustness 1.2
'''''''''''''''
- Biggest new features:
- New ImageNet models
- Mixed-precision training
- OpenImages and Places365 datasets added
- Ability to specify a custom accuracy function (custom loss functions
were already supported; this is just for logging)
- Improved resuming functionality
- Changes to CLI-based training:
- ``--custom-lr-schedule`` replaced by ``--custom-lr-multiplier`` (same format)
- ``--eps-fadein-epochs`` replaced by general ``--custom-eps-multiplier``
(now same format as custom-lr schedule)
- ``--step-lr-gamma`` now available to change the size of learning rate
drops (used to be fixed to 10x drops)
- ``--lr-interpolation`` argument added (can choose between linear and step
interpolation between learning rates in the schedule)
- ``--weight_decay`` is now called ``--weight-decay``, in keeping with
convention
- ``--resume-optimizer`` is a 0/1 argument for whether to resume the
optimizer and LR schedule, or just the model itself
- ``--mixed-precision`` is a 0/1 argument for whether to use mixed-precision
training or not (requires PyTorch compiled with AMP support)
- Model and data loading:
- DataParallel is now *off* by default when loading models, even when
resume_path is specified (previously it was off for new models, and on
for resumed models by default)
- New ``add_custom_forward`` for ``make_and_restore_model`` (see docs for
more details)
- Can now pass a random seed for training data subsetting
- Training:
- See new CLI features---most have training-as-a-library counterparts (a short sketch follows this list)
- Fixed a bug where the optimizer and LR schedule were not resumed
- Support for custom accuracy functions
- Can now disable ``torch.no_grad`` for test set eval (in case you have a
custom accuracy function that needs gradients even on the val set)
- PGD:
- Better random start for l2 attacks
- Added a ``RandomStep`` attacker step (useful for large-noise training with
varying noise over training)
- Fixed bug in the ``with_image`` argument (minor)
- Model saving:
- Accuracies are now saved in the checkpoint files themselves (instead of
just in the log stores)
- Removed redundant checkpoints table from the log store, as it is a
duplicate of the latest checkpoint file and just wastes space
- Cleanup:
- Remove redundant ``save_checkpoint`` function in helpers file
- Code flow improvements


robustness 1.1.post2
'''''''''''''''''''''
- Critical fix in :meth:`robustness.loaders.TransformedLoader`, allowing for data shuffling

robustness 1.1
12 changes: 9 additions & 3 deletions docs/example_usage/cli_usage.rst
@@ -54,13 +54,19 @@ are below:
--momentum MOMENTUM SGD momentum parameter (default: 0.9)
--step-lr STEP_LR number of steps between 10x LR drops (default: by
dataset)
--custom-schedule CUSTOM_SCHEDULE
--step-lr-gamma GAMMA multiplier for each LR drop (default: 0.1, i.e., 10x drops)
--custom-lr-multiplier CUSTOM_SCHEDULE
LR sched (format: [(epoch, LR),...]) (default: None)
--lr-interpolation {linear, step}
How to interpolate between learning rates (default: step)
--log-iters LOG_ITERS
how frequently (in epochs) to log (default: 5)
--save-ckpt-iters SAVE_CKPT_ITERS
how frequently (epochs) to save (-1 for bash, only
saves best and last) (default: -1)
--mixed-precision {0, 1}
Whether to use mixed-precision training (needs
to be compiled with NVIDIA AMP support)
Finally, there is one additional argument, :samp:`--adv-eval {0,1}`, that enables
adversarial evaluation of the non-robust model as it is being trained (i.e.
@@ -90,8 +96,8 @@ supply all the necessary hyperparameters for the attack:
(choices: {arg_type}, default: 1)
--random-restarts RANDOM_RESTARTS
number of random PGD restarts for eval (default: 0)
--eps-fadein-epochs EPS_FADEIN_EPOCHS
fade in eps over this many iterations (default: 0)
--custom-eps-multiplier EPS_SCHEDULE
epsilon multiplier sched (same format as LR schedule)
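
Both schedules use the format shown above for ``--custom-lr-multiplier``, i.e. a
list of ``(epoch, multiplier)`` pairs. A hedged example of what such values
could look like (the numbers are placeholders, and exactly how the string is
parsed is an assumption based only on the help text above):

.. code-block:: python

   # Decay the LR at epochs 50 and 100; ramp epsilon up over the first 20 epochs.
   custom_lr_multiplier = "[(0, 1.0), (50, 0.1), (100, 0.01)]"
   custom_eps_multiplier = "[(0, 0.0), (10, 0.5), (20, 1.0)]"
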
Evaluating trained models
16 changes: 16 additions & 0 deletions docs/example_usage/training_lib_part_2.rst
@@ -99,6 +99,22 @@ Adding these few lines right before calling of
:meth:`~robustness.train.train_model`
suffices for training our network robustly with this custom loss.

As of the latest version of ``robustness``, you can now also supply a custom
function for computing accuracy using the ``custom_accuracy`` flag. This should
be a function that takes in the model output and the target labels, and returns
a tuple of ``(top1, top5)`` accuracies (feel free to make the second element
``float('nan')`` if there's only one accuracy metric you want to display). Here
is an example:

.. code-block:: python

    def custom_acc_func(out, targ):
        # Calculate top1 and top5 accuracy for this batch here
        return 100., float('nan') # Return (top1, top5)

    train_args.custom_accuracy = custom_acc_func
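
For reference, here is a fuller, purely illustrative accuracy function that
computes batch top-1 and top-5 accuracy; it assumes ``out`` holds per-class
logits of shape ``(batch, num_classes)`` and ``targ`` holds integer class
labels:

.. code-block:: python

   import torch as ch

   def topk_acc_func(out, targ):
       # top-1: fraction of samples whose highest logit matches the label
       top1 = (out.argmax(dim=1) == targ).float().mean().item() * 100.
       # top-5: fraction of samples whose label is among the 5 largest logits
       if out.shape[1] >= 5:
           top5_preds = out.topk(5, dim=1).indices
           top5 = (top5_preds == targ.unsqueeze(1)).any(dim=1).float().mean().item() * 100.
       else:
           top5 = float('nan')
       return top1, top5

   train_args.custom_accuracy = topk_acc_func
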
.. _using-custom-loaders:

Training networks with custom data loaders
29 changes: 27 additions & 2 deletions robustness/attack_steps.py
@@ -131,8 +131,10 @@ def step(self, x, g):
def random_perturb(self, x):
"""
"""
new_x = x + (ch.rand_like(x) - 0.5).renorm(p=2, dim=0, maxnorm=self.eps)
return ch.clamp(new_x, 0, 1)
l = len(x.shape) - 1
rp = ch.randn_like(x)
rp_norm = rp.view(rp.shape[0], -1).norm(dim=1).view(-1, *([1]*l))
return ch.clamp(x + self.eps * rp / (rp_norm + 1e-10), 0, 1)

# Unconstrained threat model
class UnconstrainedStep(AttackerStep):
@@ -180,3 +182,26 @@ def to_image(self, x):
"""
"""
return ch.sigmoid(ch.irfft(x, 2, normalized=True, onesided=False))

class RandomStep(AttackerStep):
"""
Step for Randomized Smoothing.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.use_grad = False

def project(self, x):
"""
"""
return x

def step(self, x, g):
"""
"""
return x + self.step_size * ch.randn_like(x)

def random_perturb(self, x):
"""
"""
return x
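
The new step is exposed via the ``'random_smooth'`` constraint registered in
``attacker.py`` (see the STEPS table in the next file). Below is an
illustrative sketch only, assuming the usual ``AttackerModel`` calling
convention from the docs; the hyperparameter values are placeholders. With this
step each "attack" iteration simply adds Gaussian noise with standard deviation
``step_size``, which is what makes it useful for randomized-smoothing-style
training with varying noise levels.

    # Illustrative usage sketch (not part of this commit); values are placeholders.
    noise_sd = 0.25
    noisy_logits, noisy_x = model(x, target, make_adv=True,
                                  constraint='random_smooth',
                                  eps=noise_sd,        # required by the signature, unused by project()
                                  step_size=noise_sd,  # std of the Gaussian noise added each step
                                  iterations=1)
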
45 changes: 25 additions & 20 deletions robustness/attacker.py
@@ -42,7 +42,8 @@
'inf': attack_steps.LinfStep,
'2': attack_steps.L2Step,
'unconstrained': attack_steps.UnconstrainedStep,
'fourier': attack_steps.FourierStep
'fourier': attack_steps.FourierStep,
'random_smooth': attack_steps.RandomStep
}

class Attacker(ch.nn.Module):
@@ -71,7 +72,8 @@ def __init__(self, model, dataset):
def forward(self, x, target, *_, constraint, eps, step_size, iterations,
random_start=False, random_restarts=False, do_tqdm=False,
targeted=False, custom_loss=None, should_normalize=True,
orig_input=None, use_best=True, return_image=True, est_grad=None):
orig_input=None, use_best=True, return_image=True,
est_grad=None, mixed_precision=False):
"""
Implementation of forward (finds adversarial examples). Note that
this does **not** perform inference and should not be called
@@ -118,6 +120,8 @@ def forward(self, x, target, *_, constraint, eps, step_size, iterations,
:math:`\\nabla_x f(x) \\approx \\sum_{i=0}^N f(x + R\\cdot
\\vec{\\delta_i})\\cdot \\vec{\\delta_i}`, where
:math:`\delta_i` are randomly sampled from the unit ball.
mixed_precision (bool) : if True, use mixed-precision calculations
to compute the adversarial examples / do the inference.
Returns:
An adversarial example for x (i.e. within a feasible set
determined by `eps` and `constraint`, but classified as:
@@ -129,7 +133,6 @@ def forward(self, x, target, *_, constraint, eps, step_size, iterations,
from the unit ball, and then use :math:`\delta_{N/2+i} =
-\delta_{i}`.
"""

# Can provide a different input to make the feasible set around
# instead of the initial point
if orig_input is None: orig_input = x.detach()
@@ -139,7 +142,7 @@ def forward(self, x, target, *_, constraint, eps, step_size, iterations,
m = -1 if targeted else 1

# Initialize step class and attacker criterion
criterion = ch.nn.CrossEntropyLoss(reduction='none').cuda()
criterion = ch.nn.CrossEntropyLoss(reduction='none')
step_class = STEPS[constraint] if isinstance(constraint, str) else constraint
step = step_class(eps=eps, orig_input=orig_input, step_size=step_size)

@@ -192,7 +195,12 @@ def replace_best(loss, bloss, x, bx):
loss = ch.mean(losses)

if step.use_grad:
if est_grad is None:
if (est_grad is None) and mixed_precision:
with amp.scale_loss(loss, []) as sl:
sl.backward()
grad = x.grad.detach()
x.grad.zero_()
elif (est_grad is None):
grad, = ch.autograd.grad(m * loss, [x])
else:
f = lambda _x, _y: m * calc_loss(step.to_image(_x), _y)[0]
@@ -257,8 +265,8 @@ class AttackerModel(ch.nn.Module):
out = model(x) # normal inference (no label needed)
More code examples available in the documentation for `forward`.
For a more comprehensive overview of this class, see `our detailed
walkthrough <../example_usage/input_space_manipulation>`_
For a more comprehensive overview of this class, see
:doc:`our detailed walkthrough <../example_usage/input_space_manipulation>`.
"""
def __init__(self, model, dataset):
super(AttackerModel, self).__init__()
@@ -308,18 +316,15 @@ def forward(self, inp, target=None, make_adv=False, with_latent=False,

inp = adv

if with_image:
normalized_inp = self.normalizer(inp)

if no_relu and (not with_latent):
print("WARNING: 'no_relu' has no visible effect if 'with_latent is False.")
if no_relu and fake_relu:
raise ValueError("Options 'no_relu' and 'fake_relu' are exclusive")
normalized_inp = self.normalizer(inp)

output = self.model(normalized_inp, with_latent=with_latent,
fake_relu=fake_relu, no_relu=no_relu)
else:
output = None

return (output, inp)
if no_relu and (not with_latent):
print("WARNING: 'no_relu' has no visible effect if 'with_latent is False.")
if no_relu and fake_relu:
raise ValueError("Options 'no_relu' and 'fake_relu' are exclusive")

output = self.model(normalized_inp, with_latent=with_latent,
fake_relu=fake_relu, no_relu=no_relu)
if with_image:
return (output, inp)
return output
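
After this refactor the model always runs inference on the (possibly
adversarial) input, and ``with_image`` only controls whether that input is
returned alongside the output; previously ``with_image=False`` returned a
``(None, inp)`` tuple. A small illustrative sketch of the resulting calling
convention (the attack hyperparameters are placeholders):

    # Plain inference (default with_image=True): output plus the input it ran on.
    logits, _ = model(x)

    # Adversarial example and the model's predictions on it.
    adv_logits, adv_x = model(x, target, make_adv=True, constraint='2',
                              eps=0.5, step_size=0.1, iterations=7)

    # with_image=False now returns only the output.
    adv_logits = model(x, target, make_adv=True, with_image=False, constraint='2',
                       eps=0.5, step_size=0.1, iterations=7)
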
33 changes: 17 additions & 16 deletions robustness/cifar_models/densenet.py
@@ -4,6 +4,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..tools.custom_modules import FakeReLU


class Bottleneck(nn.Module):
@@ -81,38 +82,38 @@ def forward(self, x, with_latent=False, fake_relu=False, no_relu=False):
out = self.trans3(self.dense3(out))
out = self.dense4(out)
if fake_relu:
out = F.avg_pool2d(F.relu(self.bn(out)), 4)
else:
out = F.avg_pool2d(FakeReLU.apply(self.bn(out)), 4)
else:
out = F.avg_pool2d(F.relu(self.bn(out)), 4)
out = out.view(out.size(0), -1)
latent = out
latent = out.clone()
out = self.linear(out)

if with_latent:
return out, latent

return out

def DenseNet121():
return DenseNet(Bottleneck, [6,12,24,16], growth_rate=32)
def DenseNet121(**kwargs):
return DenseNet(Bottleneck, [6,12,24,16], growth_rate=32, **kwargs)

def DenseNet169():
return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32)
def DenseNet169(**kwargs):
return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32, **kwargs)

def DenseNet201():
return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32)
def DenseNet201(**kwargs):
return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32, **kwargs)

def DenseNet161():
return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48)
def DenseNet161(**kwargs):
return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48, **kwargs)

def densenet_cifar(*args, **kwargs):
return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12)
return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12, **kwargs)

densenet121 = DenseNet121
densenet161 = DenseNet161
densenet169 = DenseNet169
densenet201 = DenseNet201

def test():
net = densenet_cifar()
x = torch.randn(1,3,32,32)
y = net(x)
print(y)

# test()
25 changes: 3 additions & 22 deletions robustness/cifar_models/resnet.py
@@ -7,26 +7,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class FakeReLU(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
return input.clamp(min=0)

@staticmethod
def backward(ctx, grad_output):
return grad_output

class SequentialWithArgs(torch.nn.Sequential):
def forward(self, input, *args, **kwargs):
vs = list(self._modules.values())
l = len(vs)
for i in range(l):
if i == l-1:
input = vs[i](input, *args, **kwargs)
else:
input = vs[i](input)
return input
from ..tools.custom_modules import SequentialWithArgs, FakeReLU

class BasicBlock(nn.Module):
expansion = 1
@@ -131,7 +112,7 @@ def ResNet18(**kwargs):
return ResNet(BasicBlock, [2,2,2,2], **kwargs)

def ResNet18Wide(**kwargs):
return ResNet(BasicBlock, [2,2,2,2], wd=1.5, **kwargs)
return ResNet(BasicBlock, [2,2,2,2], wm=5, **kwargs)

def ResNet18Thin(**kwargs):
return ResNet(BasicBlock, [2,2,2,2], wd=.75, **kwargs)
@@ -152,9 +133,9 @@ def ResNet152(**kwargs):
resnet18 = ResNet18
resnet101 = ResNet101
resnet152 = ResNet152
resnet18wide = ResNet18Wide

# resnet18thin = ResNet18Thin
# resnet18wide = ResNet18Wide
def test():
net = ResNet18()
y = net(torch.randn(1,3,32,32))
