Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix primary cap issue and re-organised all
- Loading branch information
1 parent
4b4ed37
commit d5b2021
Showing
8 changed files
with
580 additions
and
239 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,9 @@ | ||
# new added | ||
output/ | ||
tensorboard/ | ||
.idea/ | ||
data/ | ||
|
||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
|
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,31 @@ | ||
# Dynamic Routing Between Capsules | ||
reference: [Dynamic routing between capsules](https://arxiv.org/abs/1710.09829v1) by **Sara Sabour, Nicholas Frosst, Geoffrey E Hinton** | ||
|
||
Note: this implementation strictly follow the instructions of the paper with `3` times dynamic routing iterations, check the paper for details. | ||
Note: this implementation strictly follow the instructions of the paper, check the paper for details. | ||
|
||
## Dependencies | ||
|
||
* Codes are tested on `tensorflow 1.3`, and `python 2.7`. But it should be compatible with `python 3.x` | ||
* Other dependencies as follows, install it by running `pip install -r requirements.txt` in `ROOT` directory. | ||
|
||
``` | ||
numpy>=1.7.1 | ||
scipy>=0.13.2 | ||
easydict>=1.6 | ||
tqdm>=4.17.1 | ||
``` | ||
|
||
## Train | ||
|
||
* clone the repo | ||
* then | ||
|
||
```bash | ||
cd $ROOT | ||
python train.py | ||
``` | ||
|
||
NOTE: First try with `50` iterations, it got `69.91%` accuracy on test set. | ||
|
||
## TODO | ||
- [ ] report experiment results | ||
- [ ] report exclusive experiment results |
Empty file.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# ------------------------------------------------------------------ | ||
# Capsules_mnist | ||
# By InnerPeace Wu | ||
# ------------------------------------------------------------------ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
from easydict import EasyDict as edict | ||
|
||
__C = edict() | ||
|
||
# get config by: from config import cfg | ||
cfg = __C | ||
|
||
# number of channels of PrimaryCaps | ||
__C.PRIMARY_CAPS_CHANNELS = 32 | ||
|
||
# iterations of dynamic routing | ||
__C.ROUTING_ITERS = 3 | ||
|
||
# constant m+ in margin loss | ||
__C.M_POS = 0.9 | ||
|
||
# constant m- in margin loss | ||
__C.M_NEG = 0.1 | ||
|
||
# down-weighting constant lambda for negative classes | ||
__C.LAMBDA = 0.5 | ||
|
||
# weight of reconstruction loss | ||
__C.RECONSTRUCT_W = 0.0005 | ||
|
||
# initial learning rate | ||
__C.LR = 0.001 | ||
|
||
# learning rate decay step size | ||
__C.STEP_SIZE = 500 | ||
|
||
# learning rate decay ratio | ||
__C.DECAY_RATIO = 0.96 | ||
|
||
# choose use bias during conv operations | ||
__C.USE_BIAS = True | ||
|
||
# print out loss every x steps | ||
__C.PRINT_EVERY = 10 | ||
|
||
# snapshot every x iterations | ||
__C.SAVE_EVERY = 1000 | ||
|
||
# number of training iterations | ||
__C.MAX_ITERS = 5000 | ||
|
||
# directory for saving data | ||
__C.DATA_DIR = './data' | ||
|
||
# directory for saving check points | ||
__C.TRAIN_DIR = './output' | ||
|
||
# direcotry for saving tensorboard files | ||
__C.TB_DIR = './tensorboard' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
six | ||
numpy>=1.7.1 | ||
scipy>=0.13.2 | ||
easydict>=1.6 | ||
tqdm>=4.17.1 |
Oops, something went wrong.