Skip to content

Commit

Permalink
Support for custom operators and minor fixes
Browse files Browse the repository at this point in the history
Co-Authored-By: Kaidi Xu <42853519+KaidiXu@users.noreply.github.com>
Co-Authored-By: Huan Zhang <8021844+huanzhang12@users.noreply.github.com>
  • Loading branch information
3 people committed Jan 14, 2022
1 parent 10a9b30 commit 499d023
Show file tree
Hide file tree
Showing 19 changed files with 341 additions and 69 deletions.
26 changes: 11 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@
## What's New?

- Our neural network verification tool [α,β-CROWN](https://github.com/huanzhang12/alpha-beta-CROWN.git) ([alpha-beta-CROWN](https://github.com/huanzhang12/alpha-beta-CROWN.git)) **won** [VNN-COMP 2021](https://sites.google.com/view/vnn2021) **with the highest total score**, outperforming 11 SOTA verifiers. α,β-CROWN uses the `auto_LiRPA` library as its core bound computation library.
- Support for [custom operators](https://auto-lirpa.readthedocs.io/en/latest/custom_op.html). (01/02/2022)
- [Optimized CROWN/LiRPA](https://arxiv.org/pdf/2011.13824.pdf) bound (α-CROWN) for ReLU, **sigmoid**, **tanh**, and **maxpool** activation functions, which can significantly outperform regular CROWN bounds. See [simple_verification.py](examples/vision/simple_verification.py#L59) for an example. (07/31/2021)
- Handle split constraints for ReLU neurons ([β-CROWN](https://arxiv.org/pdf/2103.06624.pdf)) for complete verifiers. (07/31/2021)
- A memory efficient GPU implementation of backward (CROWN) bounds for
convolutional layers. (10/31/2020)
- Certified defense models for downscaled
[ImageNet](#imagenet-pretrained), [TinyImageNet](#imagenet-pretrained), [CIFAR-10](#cifar10-pretrained),
and [LSTM/Transformers](#language-pretrained). (08/20/2020)
- Certified defense models for downscaled ImageNet, TinyImageNet, CIFAR-10, LSTM/Transformer. (08/20/2020)
- Adding support to **complex vision models** including DenseNet, ResNeXt and WideResNet. (06/30/2020)
- **Loss fusion**, a technique that reduces training cost of tight LiRPA bounds
(e.g. CROWN-IBP) to the same asympototic complexity of IBP, making LiRPA based certified
Expand Down Expand Up @@ -143,26 +142,23 @@ obtaining gradients through autodiff. Bounds are efficiently computed on GPUs.

## More Working Examples

We provide a wide range of examples of using `auto_LiRPA`:
We provide [a wide range of examples](doc/src/examples.md) of using `auto_LiRPA`:

* [Basic Bound Computation and **Robustness Verification** of Neural Networks](doc/examples.md#basic-bound-computation-and-robustness-verification-of-neural-networks)
* [Basic **Certified Adversarial Defense** Training](doc/examples.md#basic-certified-adversarial-defense-training)
* [Large-scale Certified Defense Training on **ImageNet**](doc/examples.md#certified-adversarial-defense-on-downscaled-imagenet-and-tinyimagenet-with-loss-fusion)
* [Certified Adversarial Defense Training on Sequence Data with **LSTM**](doc/examples.md#certified-adversarial-defense-training-for-lstm-on-mnist)
* [Certifiably Robust Language Classifier using **Transformers**](doc/examples.md#certifiably-robust-language-classifier-with-transformer-and-lstm)
* [Certified Robustness against **Model Weight Perturbations**](doc/examples.md#certified-robustness-against-model-weight-perturbations-and-certified-defense)
* [Basic Bound Computation and **Robustness Verification** of Neural Networks](doc/src/examples.md#basic-bound-computation-and-robustness-verification-of-neural-networks)
* [Basic **Certified Adversarial Defense** Training](doc/src/examples.md#basic-certified-adversarial-defense-training)
* [Large-scale Certified Defense Training on **ImageNet**](doc/src/examples.md#certified-adversarial-defense-on-downscaled-imagenet-and-tinyimagenet-with-loss-fusion)
* [Certified Adversarial Defense Training on Sequence Data with **LSTM**](doc/src/examples.md#certified-adversarial-defense-training-for-lstm-on-mnist)
* [Certifiably Robust Language Classifier using **Transformers**](doc/src/examples.md#certifiably-robust-language-classifier-with-transformer-and-lstm)
* [Certified Robustness against **Model Weight Perturbations**](doc/src/examples.md#certified-robustness-against-model-weight-perturbations-and-certified-defense)

## Full Documentations

For more documentations, please refer to:

* [Documentation homepage](https://auto-lirpa.readthedocs.io)

* [API documentation](https://auto-lirpa.readthedocs.io/en/latest/api.html)

* [Adding custom operators](doc/custom_op.md)

* [Guide](doc/paper.md) for reproducing [our NeurIPS 2020 paper](https://arxiv.org/abs/2002.12920)
* [Adding custom operators](https://auto-lirpa.readthedocs.io/en/latest/custom_op.html)
* [Guide](https://auto-lirpa.readthedocs.io/en/latest/paper.html) for reproducing [our NeurIPS 2020 paper](https://arxiv.org/abs/2002.12920)

## Publications

Expand Down
1 change: 1 addition & 0 deletions auto_LiRPA/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
from .bounded_tensor import BoundedTensor, BoundedParameter
from .perturbations import PerturbationLpNorm, PerturbationSynonym
from .wrapper import CrossEntropyWrapper, CrossEntropyWrapperMultiInput
from .bound_op_map import register_custom_op, unregister_custom_op

__version__ = '0.2'
18 changes: 18 additions & 0 deletions auto_LiRPA/bound_op_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,21 @@
'onnx::Gemm': BoundLinear,
'prim::Constant': BoundPrimConstant,
}

def register_custom_op(op_name: str, bound_obj: Bound) -> None:
""" Register a custom operator.
Args:
op_name (str): Name of the custom operator
bound_obj (Bound): The corresponding Bound class for the operator.
"""
bound_op_map[op_name] = bound_obj

def unregister_custom_op(op_name: str) -> None:
""" Unregister a custom operator.
Args:
op_name (str): Name of the custom operator
"""
bound_op_map.pop(op_name)
4 changes: 2 additions & 2 deletions auto_LiRPA/operators/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,8 +490,8 @@ def forward(self, x):
def bound_backward(self, last_lA, last_uA, x):
H, W = self.input_shape[-2], self.input_shape[-1]

lA = last_lA.expand(list(last_lA.shape[:-2]) + [H, W]) / (H * W)
uA = last_uA.expand(list(last_lA.shape[:-2]) + [H, W]) / (H * W)
lA = (last_lA.expand(list(last_lA.shape[:-2]) + [H, W]) / (H * W)) if last_lA is not None else None
uA = (last_uA.expand(list(last_uA.shape[:-2]) + [H, W]) / (H * W)) if last_uA is not None else None

return [(lA, uA)], 0, 0

Expand Down
5 changes: 4 additions & 1 deletion doc/.gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
_build
sections
sections
*.md
!src/*.md
!README.md
4 changes: 3 additions & 1 deletion doc/README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@

# Documentation

This directory contains source files for building our documentation.

## Dependencies

Install additional libraries for building documentations:
Expand All @@ -17,4 +20,3 @@ make html
```

The documentation will be generated at `_build/html`.

2 changes: 2 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ API Usage
.. autofunction:: auto_LiRPA.perturbations.Perturbation.concretize
.. autofunction:: auto_LiRPA.perturbations.Perturbation.init

.. autofunction:: auto_LiRPA.bound_op_map.register_custom_op

Indices and tables
-------------------

Expand Down
37 changes: 32 additions & 5 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
#
import os
import subprocess
# import sys
# sys.path.insert(0, os.path.abspath('.'))
import inspect
import sys
from pygit2 import Repository
sys.path.insert(0, '..')
import auto_LiRPA

subprocess.run(['python', 'parse_readme.py'])
subprocess.run(['python', 'process.py'])

# -- Project information -----------------------------------------------------

Expand All @@ -31,6 +34,7 @@
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.linkcode',
'm2r2',
]

Expand All @@ -40,8 +44,7 @@
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']

exclude_patterns = ['_build', 'src', 'Thumbs.db', '.DS_Store']

# -- Options for HTML output -------------------------------------------------

Expand All @@ -54,3 +57,27 @@
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']

repo = Repository('../')
branch = repo.head.shorthand

# Resolve function for the linkcode extension.
def linkcode_resolve(domain, info):
def find_source():
obj = auto_LiRPA
parts = info['fullname'].split('.')
if info['module'].endswith(f'.{parts[0]}'):
module = info['module'][:-len(parts[0])-1]
else:
module = info['module']
obj = sys.modules[module]
for part in parts:
obj = getattr(obj, part)
fn = inspect.getsourcefile(obj)
source, lineno = inspect.getsourcelines(obj)
return fn, lineno, lineno + len(source) - 1

fn, lineno_start, lineno_end = find_source()
filename = f'{fn}#L{lineno_start}-L{lineno_end}'

return f"https://github.com/KaidiXu/auto_LiRPA/blob/{branch}/doc/{filename}"
25 changes: 0 additions & 25 deletions doc/parse_readme.py

This file was deleted.

65 changes: 65 additions & 0 deletions doc/process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
""" Process source files before running Sphinx"""
import re
import os
import shutil
from pygit2 import Repository

repo = 'https://github.com/KaidiXu/auto_LiRPA'
branch = os.environ.get('BRANCH', None) or Repository('.').head.shorthand
repo_file_path = os.path.join(repo, 'tree', branch)

""" Parse README.md into sections which can be reused """
heading = ''
copied = {}
print('Parsing markdown sections from README:')
with open('../README.md') as file:
for line in file.readlines():
if line.startswith('##'):
heading = line[2:].strip()
else:
if not heading in copied:
copied[heading] = ''
copied[heading] += line
if not os.path.exists('sections'):
os.makedirs('sections')
for key in copied:
if key == '':
continue
filename = re.sub(r"[?+\'\"]", '', key.lower())
filename = re.sub(r" ", '-', filename) + '.md'
print(filename)
with open(os.path.join('sections', filename), 'w') as file:
file.write(f'## {key}\n')
file.write(copied[key])
print()

""" Load source files from src/ and fix links to GitHub """
for filename in os.listdir('src'):
print(f'Processing {filename}')
with open(os.path.join('src', filename)) as file:
source = file.read()
source_new = ''
ptr = 0
# res = re.findall('\[.*\]\(.*\)', source)
for m in re.finditer('(\[.*\])(\(.*\))', source):
assert m.start() >= ptr
source_new += source[ptr:m.start()]
ptr = m.start()
source_new += m.group(1)
ptr += len(m.group(1))
link_raw = m.group(2)
while len(link_raw) >= 2 and link_raw[-2] == ')':
link_raw = link_raw[:-1]
link = link_raw[1:-1]
if link.startswith('https://') or link.startswith('http://') or '.html#' in link:
print(f'Skip link {link}')
link_new = link
else:
link_new = os.path.join(repo_file_path, 'docs/src', link)
print(f'Fix link {link} -> {link_new}')
source_new += f'({link_new})'
ptr += len(link_raw)
source_new += source[ptr:]
with open(filename, 'w') as file:
file.write(source_new)
print()
3 changes: 2 additions & 1 deletion doc/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
sphinx>=4.1.2
docutils>=0.16
m2r2>=0.3.1
m2r2>=0.3.1
pygit2>=1.7.2
File renamed without changes.
16 changes: 12 additions & 4 deletions doc/custom_op.md → doc/src/custom_op.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,22 @@

In this documentation, we introduce how users can define custom operators (such as other activations) that are not currently supported in auto_LiRPA, with bound propagation methods.

## Write an Operator
## Write a Custom Operator

There are three steps to write an operator:

1. Define a `torch.autograd.Function` (or `Function` for short) class, wrap the computation of the operator into this `Function`, and also define a symbolic method so that the operator can be parsed in auto_LiRPA via ONNX. Please refer to [PyTorch documentation](https://pytorch.org/docs/stable/onnx.html?highlight=symbolic#static-symbolic-method) on defining a `Function` with a symbolic method. Call this `Function` via `.apply()` when using this operator in the model.
1. Define a `torch.autograd.Function` (or `Function` for short) class, wrap the computation of the operator into this `Function`, and also define a symbolic method so that the operator can be parsed in auto_LiRPA via ONNX. Please refer to [PyTorch documentation](https://pytorch.org/docs/stable/onnx.html?highlight=symbolic#static-symbolic-method) on defining a `Function` with a symbolic method.

2. Implement a [Bound class](api.html#auto_LiRPA.bound_ops.Bound) to support bound propagation methods for this operator.
3. Create a mapping from the operator name (defined in step 1) to the bound class (defined in step 2). Define a `dict` which each item is a mapping. Pass the `dict` to the `custom_ops` argument when calling `BoundedModule` (see the [documentation](api.html#auto_LiRPA.BoundedModule)). For example, if the operator name is `MyRelu`, and the bound class is `BoundMyRelu`, then add `"MyRelu": BoundMyRelu` to the `dict`.
2. Create a `torch.nn.Module` which uses the defined operator. Call the operator via
`.apply()` of `Function`.

3. Implement a [Bound class](api.html#auto_LiRPA.bound_ops.Bound) to support bound propagation methods for this operator.

4. [Register the custom operator](api.html#auto_LiRPA.bound_op_map.register_custom_op).

## Example

We provide an [code example](../../examples/vision/custom_op.py) of using a custom operator called "PlusConstant".

## Contributing to the Library

Expand Down
24 changes: 12 additions & 12 deletions doc/examples.md → doc/src/examples.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Examples

We provide many [examples](examples) of using our `auto_LiRPA` library,
We provide many [examples](../../examples) of using our `auto_LiRPA` library,
including robustness verification and certified robust training for fairly
complicated networks and specifications. Please first install required libraries
to run the examples:
Expand All @@ -13,7 +13,7 @@ pip install -r requirements.txt
## Basic Bound Computation and Robustness Verification of Neural Networks

We provide a very simple tutorial for `auto_LiRPA` at
[examples/vision/simple_verification.py](examples/vision/simple_verification.py).
[examples/vision/simple_verification.py](../../examples/vision/simple_verification.py).
This script is self-contained. It loads a simple CNN model and compute the
guaranteed lower and upper bounds using LiRPA for each output neuron under a L
infinity perturbation.
Expand All @@ -34,7 +34,7 @@ can be obtained using α-CROWN within a few seconds.
## Basic Certified Adversarial Defense Training

We provide a [simple example of certified
training](examples/vision/simple_training.py). By default it uses
training](../../examples/vision/simple_training.py). By default it uses
[CROWN-IBP](https://arxiv.org/pdf/1906.06316.pdf) to train a certifiably robust
model:

Expand All @@ -59,17 +59,17 @@ python simple_training.py --model mlp_3layer --norm 0 --eps 1
```

For CIFAR-10, we provided some sample models in `examples/vision/models`:
e.g., [cnn_7layer_bn](./examples/vision/models/feedforward.py),
[DenseNet](./examples/vision/models/densenet.py),
[ResNet18](./examples/vision/models/resnet18.py),
[ResNeXt](./examples/vision/models/resnext.py). For example, to train a ResNeXt model on CIFAR,
e.g., [cnn_7layer_bn](../../examples/vision/models/feedforward.py),
[DenseNet](../../examples/vision/models/densenet.py),
[ResNet18](../../examples/vision/models/resnet18.py),
[ResNeXt](../../examples/vision/models/resnext.py). For example, to train a ResNeXt model on CIFAR,
use:

```bash
python cifar_training.py --batch_size 256 --model ResNeXt_cifar
```

See a list of supported models [here](./examples/vision/models/__init__.py).
See a list of supported models [here](../../examples/vision/models/__init__.py).
This command uses multi-GPUs by default. You probably need to reduce batch size
if you have only 1 GPU. The CIFAR training implementation includes **loss
fusion**, a technique that can greatly reduce training time and memory usage of
Expand All @@ -85,7 +85,7 @@ python cifar_training.py --verify --model cnn_7layer_bn --load saved_models/cnn
```

More example of CIFAR-10 training can be found
in [doc/paper.md](doc/paper.md).
in [doc/paper.md](paper.md).


## Certified Adversarial Defense on Downscaled ImageNet and TinyImageNet with Loss Fusion
Expand Down Expand Up @@ -139,12 +139,12 @@ MODEL=saved_models/wide_resnet_imagenet64_1000
python imagenet_training.py --verify --model wide_resnet_imagenet64_1000class --load $MODEL --eps 0.003921568627451
```

See more details in [doc/paper.md](doc/paper.md) for these examples.
See more details in [paper.md](paper.md) for these examples.


## Certified Adversarial Defense Training for LSTM on MNIST

In [examples/sequence](examples/sequence), we have an example of training a
In [examples/sequence](../../examples/sequence), we have an example of training a
certifiably robust LSTM on MNIST, where an input image is perturbed within an
Lp-ball and sliced to several pieces each regarded as an input frame. To run
the example:
Expand All @@ -156,7 +156,7 @@ python train.py

## Certifiably Robust Language Classifier with Transformer and LSTM

In [examples/language](examples/language), we show that our framework can
In [examples/language](../../examples/language), we show that our framework can
support perturbation specification of word substitution, beyond Lp-ball
perturbation. We perform certified training for Transformer and LSTM on a
sentiment classification task.
Expand Down
File renamed without changes.
Loading

0 comments on commit 499d023

Please sign in to comment.