Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
mnist and cifar10, nin not supported yet
- Loading branch information
Showing
10 changed files
with
1,447 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
env/ | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*,cover | ||
.hypothesis/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# celery beat schedule file | ||
celerybeat-schedule | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# dotenv | ||
.env | ||
|
||
# virtualenv | ||
.venv | ||
venv/ | ||
ENV/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# data | ||
data/ | ||
data | ||
|
||
# PyCharm | ||
.idea/ | ||
|
||
# macOS | ||
.DS_Store | ||
|
||
# output directory | ||
output/ | ||
|
||
# python import cache | ||
*.pyc | ||
|
||
# nfs temp file | ||
.nfs* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,38 @@ | ||
# deepdefense.pytorch | ||
Code coming soon! | ||
Code for NeurIPS 2018 paper [Deep Defense: Training DNNs with Improved Adversarial Robustness](https://papers.nips.cc/paper/7324-deep-defense-training-dnns-with-improved-adversarial-robustness). | ||
|
||
Deep Defense is recipe to improve the robustness of DNNs to adversarial perturbations. We integrate an adversarial perturbation-based regularizer into the training objective, such that the obtained models learn to resist potential attacks in a principled way. | ||
|
||
## Environments | ||
* Python 3.5 | ||
* PyTorch 0.4.1 | ||
* glog 0.3.1 | ||
|
||
## Datasets and Reference Models | ||
For fair comparison with DeepFool, we follow it to use [matconvnet](https://github.com/vlfeat/matconvnet/releases/tag/v1.0-beta24) to pre-process data and train reference models for MNIST and CIFAR-10. | ||
|
||
Please download processed datasets and reference models (including MNIST and CIFAR-10) at [download link](https://drive.google.com/open?id=15xoZ-LUbc9GZpTlxmCJmvL_DR2qYEu2J). | ||
|
||
## Usage | ||
To train a Deep Defense LeNet model using default parameters on MNIST: | ||
|
||
``` | ||
python3 deepdefense.py --pretest --dataset mnist --arch LeNet | ||
``` | ||
|
||
Argument ```--pretest``` indicates evaluating performance before fine-tuning, thus we can check the performance of reference model. | ||
|
||
Currently we've implemented ```MLP``` and ```LeNet``` for mnist, and ```ConvNet``` for CIFAR-10. | ||
|
||
## Citation | ||
Please cite our work in your publications if it helps your research: | ||
|
||
``` | ||
@inproceedings{yan2018deep, | ||
title={Deep Defense: Training DNNs with Improved Adversarial Robustness}, | ||
author={Yan, Ziang and Guo, Yiwen and Zhang, Changshui}, | ||
booktitle={Advances in Neural Information Processing Systems}, | ||
pages={417--426}, | ||
year={2018} | ||
} | ||
``` |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import torch.utils.data | ||
import numpy as np | ||
|
||
|
||
class CIFAR10Dataset(torch.utils.data.Dataset): | ||
def __init__(self, phase='train', num_val=5000): | ||
import scipy.io as sio | ||
imdb = sio.loadmat('data/cifar10-data-ce5d97dd.mat') | ||
images = imdb['images'][0][0][0].transpose() | ||
sets = imdb['images'][0][0][2].flatten() | ||
labels = (imdb['images'][0][0][1].flatten() - 1).astype(np.int64) | ||
train_idx = np.where(sets == 1)[0][num_val:] | ||
val_idx = np.where(sets == 1)[0][:num_val] | ||
trainval_idx = np.where(sets == 1)[0] | ||
test_idx = np.where(sets == 3)[0] | ||
assert phase in ['train', 'val', 'trainval', 'test'] | ||
self.images = eval('images[%s_idx]' % phase) | ||
self.labels = eval('labels[%s_idx]' % phase) | ||
self.perm = np.arange(self.labels.size) | ||
|
||
def __getitem__(self, index): | ||
if np.random.rand() > 0.5: | ||
images = np.fliplr(self.images[self.perm[index]]).copy() | ||
else: | ||
images = self.images[self.perm[index]] | ||
return images, self.labels[self.perm[index]] | ||
|
||
def __len__(self): | ||
return self.labels.size | ||
|
||
def shuffle(self, perm): | ||
self.perm = perm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import torch.utils.data | ||
import numpy as np | ||
import scipy.io as sio | ||
|
||
|
||
num_val = 10000 # first num_val examples in training set is used as validation set | ||
|
||
|
||
class MNISTDataset(torch.utils.data.Dataset): | ||
def __init__(self, phase='train'): | ||
imdb = sio.loadmat('data/mnist-data-0208ce21.mat') | ||
images = imdb['images'][0][0][0].transpose() | ||
sets = imdb['images'][0][0][3].flatten() | ||
labels = imdb['images'][0][0][2].flatten() - 1 | ||
train_idx = np.where(sets == 1)[0][num_val:] | ||
val_idx = np.where(sets == 1)[0][:num_val] | ||
trainval_idx = np.where(sets == 1)[0] | ||
test_idx = np.where(sets == 3)[0] | ||
mean = imdb['images'][0][0][1].transpose() | ||
assert phase in ['train', 'val', 'trainval', 'test'] | ||
self.images = eval('images[%s_idx]' % phase) | ||
self.labels = eval('labels[%s_idx]' % phase) | ||
self.mean = mean | ||
self.perm = np.arange(self.labels.size) | ||
|
||
def shuffle(self, perm): | ||
self.perm = perm | ||
|
||
def __getitem__(self, index): | ||
return self.images[self.perm[index]], self.labels[self.perm[index]] | ||
|
||
def __len__(self): | ||
return self.labels.size |
Oops, something went wrong.