Sparse Unsupervised Capsules

The official source code for the SPARSECAPS model, based on the original Capsules model by Sabour et al., used in the following paper:


Quick Results

A checkpoint of the model trained on the expanded MNIST dataset for the affNIST generalization experiment is publicly available, so you can skip training and reproduce the experimental results directly.


Dataset Generation

Scripts to build the data needed to train and/or evaluate the model can be found in the input_data directory, organized by dataset.


To generate the shifted MNIST training set for training the model:

```
python --data_dir=PATH_TO_MNIST_DIRECTORY \
    --split=train --shift=2 --pad=0
```

To generate the expanded MNIST training set for affNIST generalizability:

```
python --data_dir=PATH_TO_MNIST_DIRECTORY \
    --split=train --shift=6 --pad=6
```

The test set can be generated in a similar way by passing `--split=test --shift=0` instead. The dataset can also be downloaded from the source by passing the `--download=true` flag.
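The effect of the `--shift` and `--pad` flags can be pictured with a rough numpy sketch; this illustrates the padding/translation idea only and is not the repository's actual preprocessing code:

```python
import numpy as np

def shift_and_pad(image, shift=6, pad=6, rng=None):
    """Pad a 28x28 MNIST image and apply a random 2D shift.

    With pad=6 the canvas grows to 40x40 (28 + 2*6); shift=6 allows
    translations of up to +/-6 pixels along each axis.
    Illustrative only: the real scripts may crop/translate differently.
    """
    if rng is None:
        rng = np.random.default_rng()
    padded = np.pad(image, pad, mode="constant")
    dy, dx = rng.integers(-shift, shift + 1, size=2)
    return np.roll(np.roll(padded, dy, axis=0), dx, axis=1)

image = np.zeros((28, 28), dtype=np.float32)
out = shift_and_pad(image)
print(out.shape)  # (40, 40)
```

With `--shift=2 --pad=0` the canvas stays 28x28 and only small translations are applied, matching the standard-MNIST training command above.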


To generate the affNIST test set:

```
python --data_dir=PATH_TO_AFFNIST_DIRECTORY \
    --split=test --shift=0 --pad=0
```

To generate the sharded affNIST test set to chunk the dataset over separate TFRecords files:

```
python --data_dir=PATH_TO_AFFNIST_DIRECTORY \
    --split=test --shift=0 --pad=0 --max_shard=80000
```

`--max_shard` sets the maximum number of images per TFRecords file; since affNIST contains 320,000 images, this command generates four separate data files. The dataset can also be downloaded from the source by passing the `--download=true` flag.
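The shard arithmetic is straightforward; a small sketch of how `--max_shard` partitions the dataset (the helper below is hypothetical, not a function from the repository):

```python
import math

def shard_counts(num_images, max_shard):
    """Return the number of TFRecords files and the image count per file."""
    num_shards = math.ceil(num_images / max_shard)
    sizes = [min(max_shard, num_images - i * max_shard)
             for i in range(num_shards)]
    return num_shards, sizes

# affNIST: 320,000 images at 80,000 per shard -> 4 equal files.
print(shard_counts(320_000, 80_000))  # (4, [80000, 80000, 80000, 80000])
```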

Model Workflow


To train the model on the standard MNIST dataset:

```
python --data_dir=/path/to/dataset/ \
    --summary_dir=/path/to/log/dir --max_steps=30000 --dataset=mnist \
    --batch_size=128 --shift=2
```

To train on the expanded MNIST (40x40) for affNIST generalization:

```
python --data_dir=/path/to/dataset/ \
    --summary_dir=/path/to/log/dir --max_steps=30000 --dataset=mnist \
    --batch_size=128 --shift=6 --pad=6
```

Hyperparameters can be overridden using the `--hparams_override` flag, e.g. `--hparams_override=num_latent_capsules=24,num_atoms=16`. The same flag should also be passed during evaluation to ensure the model is built with the expected parameters.
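A comma-separated override string like this is typically parsed against a dictionary of defaults. The sketch below is a minimal standalone equivalent; the repository presumably relies on TensorFlow's `HParams` parsing, so treat the helper name and behavior here as illustrative:

```python
def parse_hparams_override(override, defaults):
    """Parse a 'key=value,key=value' string, updating a copy of defaults.

    Each value is cast to the type of the corresponding default, and
    unknown keys raise an error rather than being silently added.
    """
    params = dict(defaults)
    for pair in filter(None, override.split(",")):
        key, value = pair.split("=", 1)
        if key not in params:
            raise KeyError(f"unknown hyperparameter: {key}")
        params[key] = type(params[key])(value)
    return params

# Default values here are made up for illustration.
defaults = {"num_latent_capsules": 16, "num_atoms": 8}
print(parse_hparams_override("num_latent_capsules=24,num_atoms=16", defaults))
# {'num_latent_capsules': 24, 'num_atoms': 16}
```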


To generate the encoded representation for a single dataset, e.g. MNIST:

```
python --data_dir=/path/to/mnist_data/ --train=False \
    --checkpoint=/path/to/model.ckpt --summary_dir=/path/to/output \
    --eval_set=train --eval_size=60000
```

To generate the encoded representation for a sharded dataset, e.g. affNIST:

```
python --data_dir=/path/to/affnist_data/ --train=False \
    --checkpoint=/path/to/model.ckpt --summary_dir=/path/to/output \
    --eval_set=test --eval_size=80000 --eval_shard=0
```


The classifier automatically finds the appropriate input data that was generated by the encoder, so only the path to the encoded outputs is necessary.

Evaluate the encoded representation using SVM:

```
python --data_dir=/path/to/outputs/dir \
    --summary_dir=/path/to/log/dir --model=svm --dataset=mnist --last_step=30000
```

The SVM hyperparameters can likewise be overridden using the analogous `--svm_hparams_override` flag.
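To make the SVM stage concrete, here is a minimal scikit-learn sketch of fitting a linear SVM on encoded features. The random features, labels, and the choice of `LinearSVC` are stand-ins: the README does not specify the classifier script's internals, so this shows only the general fit/score pattern:

```python
import numpy as np
from sklearn.svm import LinearSVC

rng = np.random.default_rng(0)
# Stand-ins for encoded capsule activations (features) and digit labels.
train_x = rng.normal(size=(200, 64))
train_y = (train_x[:, 0] > 0).astype(int)  # synthetic, linearly separable
test_x = rng.normal(size=(50, 64))
test_y = (test_x[:, 0] > 0).astype(int)

# C is the usual SVM regularization hyperparameter one might override.
clf = LinearSVC(C=1.0, max_iter=5000)
clf.fit(train_x, train_y)
print("accuracy:", clf.score(test_x, test_y))
```

In the real pipeline the features would be the encoder's outputs loaded from the outputs directory rather than random arrays.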


Thanks to Sabour et al. for open-sourcing the official Capsules model.