mnist and cifar10, nin not supported yet
ZiangYan committed Dec 12, 2018
1 parent 5e99f81 commit bae2aaa
Showing 10 changed files with 1,447 additions and 1 deletion.
113 changes: 113 additions & 0 deletions .gitignore
@@ -0,0 +1,113 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

# data
data/
data

# PyCharm
.idea/

# macOS
.DS_Store

# output directory
output/

# python import cache
*.pyc

# nfs temp file
.nfs*
38 changes: 37 additions & 1 deletion README.md
@@ -1,2 +1,38 @@
# deepdefense.pytorch
Code coming soon!
Code for the NeurIPS 2018 paper [Deep Defense: Training DNNs with Improved Adversarial Robustness](https://papers.nips.cc/paper/7324-deep-defense-training-dnns-with-improved-adversarial-robustness).

Deep Defense is a recipe for improving the robustness of DNNs against adversarial perturbations. It integrates an adversarial-perturbation-based regularizer into the training objective, so that the trained models learn to resist potential attacks in a principled way.
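Schematically (an editor's paraphrase of the high-level idea, not the paper's exact formulation), the training objective augments the usual classification loss with a perturbation-based penalty:

```
min_theta  sum_k L(x_k, y_k; theta) + lambda * sum_k R( ||delta_x_k|| / ||x_k|| )
```

Here ```delta_x_k``` is an adversarial perturbation of input ```x_k``` (e.g., generated by a DeepFool-style attack) and ```R``` is a decreasing penalty on correctly classified examples, so inputs that can be fooled by relatively small perturbations contribute more to the loss. See the paper for the precise regularizer.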

## Environment
* Python 3.5
* PyTorch 0.4.1
* glog 0.3.1

## Datasets and Reference Models
For a fair comparison with DeepFool, we follow its protocol and use [matconvnet](https://github.com/vlfeat/matconvnet/releases/tag/v1.0-beta24) to pre-process the data and train the reference models for MNIST and CIFAR-10.

Please download the processed datasets and reference models (for both MNIST and CIFAR-10) from the [download link](https://drive.google.com/open?id=15xoZ-LUbc9GZpTlxmCJmvL_DR2qYEu2J).

## Usage
To train a Deep Defense LeNet model using default parameters on MNIST:

```
python3 deepdefense.py --pretest --dataset mnist --arch LeNet
```

The ```--pretest``` flag evaluates the model before fine-tuning starts, so we can check the performance of the reference model.

Currently we have implemented ```MLP``` and ```LeNet``` for MNIST, and ```ConvNet``` for CIFAR-10; ```NIN``` is not supported yet.
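
By analogy with the MNIST command above, a CIFAR-10 run would presumably look like the following (the exact ```--dataset```/```--arch``` spellings are assumptions, not verified against the script's argument parser):

```
python3 deepdefense.py --pretest --dataset cifar10 --arch ConvNet
```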

## Citation
Please cite our work in your publications if it helps your research:

```
@inproceedings{yan2018deep,
  title={Deep Defense: Training DNNs with Improved Adversarial Robustness},
  author={Yan, Ziang and Guo, Yiwen and Zhang, Changshui},
  booktitle={Advances in Neural Information Processing Systems},
  pages={417--426},
  year={2018}
}
```
Empty file added datasets/__init__.py
Empty file.
32 changes: 32 additions & 0 deletions datasets/cifar10.py
@@ -0,0 +1,32 @@
import torch.utils.data
import numpy as np
import scipy.io as sio


class CIFAR10Dataset(torch.utils.data.Dataset):
    def __init__(self, phase='train', num_val=5000):
        assert phase in ['train', 'val', 'trainval', 'test']
        # Load the matconvnet-preprocessed CIFAR-10 imdb structure.
        imdb = sio.loadmat('data/cifar10-data-ce5d97dd.mat')
        images = imdb['images'][0][0][0].transpose()
        labels = (imdb['images'][0][0][1].flatten() - 1).astype(np.int64)  # 1-based -> 0-based
        sets = imdb['images'][0][0][2].flatten()  # 1: train, 3: test
        trainval_idx = np.where(sets == 1)[0]
        split_idx = {
            'train': trainval_idx[num_val:],  # training set minus the held-out block
            'val': trainval_idx[:num_val],    # first num_val training examples
            'trainval': trainval_idx,
            'test': np.where(sets == 3)[0],
        }[phase]
        self.images = images[split_idx]
        self.labels = labels[split_idx]
        # Identity permutation by default; shuffle() installs a new ordering.
        self.perm = np.arange(self.labels.size)

    def __getitem__(self, index):
        # Random horizontal flip for data augmentation.
        if np.random.rand() > 0.5:
            image = np.fliplr(self.images[self.perm[index]]).copy()
        else:
            image = self.images[self.perm[index]]
        return image, self.labels[self.perm[index]]

    def __len__(self):
        return self.labels.size

    def shuffle(self, perm):
        # Install an externally generated permutation of example indices.
        self.perm = perm
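
Both datasets shuffle via an externally supplied permutation rather than ```DataLoader```'s built-in ```shuffle=True```, presumably so the training script controls (and can replay) the ordering. A minimal usage sketch, assuming the ```.mat``` file is already in ```data/``` and a hypothetical batch size of 100:

```python
import numpy as np
import torch.utils.data

from datasets.cifar10 import CIFAR10Dataset

train_set = CIFAR10Dataset(phase='train')
# Re-shuffle once per epoch by installing a fresh permutation; the DataLoader
# itself iterates in order (shuffle=False), so the permutation is authoritative.
train_set.shuffle(np.random.permutation(len(train_set)))
loader = torch.utils.data.DataLoader(train_set, batch_size=100, shuffle=False)
for images, labels in loader:
    pass  # feed the batch to the model here
```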
33 changes: 33 additions & 0 deletions datasets/mnist.py
@@ -0,0 +1,33 @@
import torch.utils.data
import numpy as np
import scipy.io as sio


num_val = 10000  # the first num_val examples of the training set are used as the validation set


class MNISTDataset(torch.utils.data.Dataset):
    def __init__(self, phase='train'):
        assert phase in ['train', 'val', 'trainval', 'test']
        # Load the matconvnet-preprocessed MNIST imdb structure.
        imdb = sio.loadmat('data/mnist-data-0208ce21.mat')
        images = imdb['images'][0][0][0].transpose()
        mean = imdb['images'][0][0][1].transpose()  # mean image of the training set
        labels = imdb['images'][0][0][2].flatten() - 1  # 1-based -> 0-based
        sets = imdb['images'][0][0][3].flatten()  # 1: train, 3: test
        trainval_idx = np.where(sets == 1)[0]
        split_idx = {
            'train': trainval_idx[num_val:],  # training set minus the held-out block
            'val': trainval_idx[:num_val],    # first num_val training examples
            'trainval': trainval_idx,
            'test': np.where(sets == 3)[0],
        }[phase]
        self.images = images[split_idx]
        self.labels = labels[split_idx]
        self.mean = mean
        # Identity permutation by default; shuffle() installs a new ordering.
        self.perm = np.arange(self.labels.size)

    def shuffle(self, perm):
        # Install an externally generated permutation of example indices.
        self.perm = perm

    def __getitem__(self, index):
        return self.images[self.perm[index]], self.labels[self.perm[index]]

    def __len__(self):
        return self.labels.size
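
Unlike the CIFAR-10 dataset, ```MNISTDataset``` also exposes the matconvnet mean image via ```self.mean```. How ```deepdefense.py``` consumes it is not shown in this commit; a plausible sketch is mean subtraction at feed time:

```python
from datasets.mnist import MNISTDataset

ds = MNISTDataset(phase='val')
img, label = ds[0]
# Hypothetical normalization: center the example with the stored mean image
# (whether the images in the .mat file are already mean-subtracted is not
# shown here -- check deepdefense.py before relying on this).
img_centered = img - ds.mean
```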
