cleaned up a bit
gngdb committed Oct 17, 2018
1 parent 6636683 commit a4aba9d
Showing 2 changed files with 14 additions and 11 deletions.
16 changes: 9 additions & 7 deletions README.md
@@ -4,11 +4,10 @@ Code used to produce https://arxiv.org/abs/1711.02613

 ## Installation Instructions

-Best done with conda. Make sure your conda is up to date.
+If installing with conda:

-Make a new environment then activate it. Python version probably doesn't matter but I use 2 for no apparent reason.
 ```
-conda create -n torch python=2
+conda create -n torch python=3.6
 source activate torch
 ```
 then
@@ -17,7 +16,7 @@ then
 conda install pytorch torchvision -c pytorch
 pip install tqdm
 pip install tensorboardX
-pip install tensorflow
+conda install tensorflow
 ```

 ## Training a Teacher
@@ -28,6 +27,11 @@ In general, the following code trains a teacher network:
 python main.py <DATASET> teacher --conv <CONV-TYPE> -t <TEACHER_CHECKPOINT> --wrn_depth <TEACHER_DEPTH> --wrn_width <TEACHER_WIDTH>
 ```

+Where `<DATASET>` is one of `cifar10`, `cifar100` or `imagenet`. By
+default, `cifar10` and `cifar100` are assumed to be stored at
+`/disk/scratch/datasets/cifar`, but any directory can be set with
+`--cifar_loc`.
+
 In the paper, results are typically reported using a standard 40-2 WRN,
 which would be the following (on cifar-10):

@@ -55,9 +59,7 @@ python main.py cifar10 student --conv G8B2 -t wrn_40_2 -s wrn_40_2.g8b2.student

 ## Acknowledgements

-Code has been liberally borrowed from other repos.
-
-A non-exhaustive list follows:
+The following repos provided basis and inspiration for this work:

 ```
 https://github.com/szagoruyko/attention-transfer
9 changes: 5 additions & 4 deletions main.py
@@ -24,6 +24,7 @@
 parser.add_argument('dataset', type=str, choices=['cifar10', 'cifar100', 'imagenet'], help='Choose between Cifar10/100/imagenet.')
 parser.add_argument('mode', choices=['student','teacher'], type=str, help='Learn a teacher or a student')
 parser.add_argument('--imagenet_loc', default='/disk/scratch_ssd/imagenet',type=str, help='folder containing imagenet train and val folders')
+parser.add_argument('--cifar_loc', default='/disk/scratch/datasets/cifar',type=str, help='folder containing cifar train and val folders')
 parser.add_argument('--workers', default=2, type=int, help='No. of data loading workers. Make this high for imagenet')
 parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
 parser.add_argument('--GPU', default=None, type=str,help='GPU to use')
@@ -320,9 +321,9 @@ def what_conv_block(conv, blocktype, module):
 transforms.ToTensor(),
 transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
 ])
-trainset = torchvision.datasets.CIFAR10(root='/disk/scratch/datasets/cifar',
+trainset = torchvision.datasets.CIFAR10(root=args.cifar_loc,
 train=True, download=False, transform=transform_train)
-valset = torchvision.datasets.CIFAR10(root='/disk/scratch/datasets/cifar',
+valset = torchvision.datasets.CIFAR10(root=args.cifar_loc,
 train=False, download=False, transform=transform_validate)
 elif args.dataset == 'cifar100':
 num_classes = 100
@@ -336,9 +337,9 @@ def what_conv_block(conv, blocktype, module):
 transforms.ToTensor(),
 transforms.Normalize((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023)),
 ])
-trainset = torchvision.datasets.CIFAR100(root='/disk/scratch/datasets/cifar100',
+trainset = torchvision.datasets.CIFAR100(root=args.cifar_loc,
 train=True, download=True, transform=transform_train)
-validateset = torchvision.datasets.CIFAR100(root='/disk/scratch/datasets/cifar100',
+validateset = torchvision.datasets.CIFAR100(root=args.cifar_loc,
 train=False, download=True, transform=transform_validate)

 elif args.dataset == 'imagenet':
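
Note on the `main.py` change: the hard-coded CIFAR roots are replaced by a single `--cifar_loc` option, so CIFAR-100 no longer defaults to `/disk/scratch/datasets/cifar100` and now shares a root with CIFAR-10. The sketch below is a minimal, standalone illustration of how that option resolves. It reproduces only three of the script's arguments, and `build_parser` and `cifar_root` are hypothetical helper names that do not exist in the repository.

```python
import argparse


def build_parser():
    # Only the arguments relevant to this commit; main.py defines many more.
    parser = argparse.ArgumentParser()
    parser.add_argument('dataset', type=str,
                        choices=['cifar10', 'cifar100', 'imagenet'])
    parser.add_argument('mode', type=str, choices=['student', 'teacher'])
    parser.add_argument('--cifar_loc', type=str,
                        default='/disk/scratch/datasets/cifar')
    return parser


def cifar_root(argv):
    """Hypothetical helper: the directory that would be passed as root=..."""
    args = build_parser().parse_args(argv)
    return args.cifar_loc


# Without the flag, both CIFAR datasets fall back to the same default root.
default_root = cifar_root(['cifar100', 'teacher'])
# With the flag, the given directory is used instead.
custom_root = cifar_root(['cifar10', 'teacher', '--cifar_loc', '/tmp/cifar'])

print(default_root)
print(custom_root)
```

Anyone who kept CIFAR-100 under the old `/disk/scratch/datasets/cifar100` path would presumably need to pass `--cifar_loc` explicitly after this commit.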