
Python 3.6

Capsules with Inverted Dot-Product Attention Routing

PyTorch implementation for Capsules with Inverted Dot-Product Attention Routing.

Paper

Capsules with Inverted Dot-Product Attention Routing
Yao-Hung Hubert Tsai, Nitish Srivastava, Hanlin Goh, and Ruslan Salakhutdinov
International Conference on Learning Representations (ICLR), 2020.

Please cite our paper if you find our work useful for your research:

@inproceedings{tsai2020Capsules,
  title={Capsules with Inverted Dot-Product Attention Routing},
  author={Tsai, Yao-Hung Hubert and Srivastava, Nitish and Goh, Hanlin and Salakhutdinov, Ruslan},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2020},
}

Overview

Overall Architecture

An example of our proposed architecture is shown above. The backbone is a standard feed-forward convolutional neural network. The features extracted from this network are fed through another convolutional layer. At each spatial location, groups of 16 channels are taken to create capsules (we assume a 16-dimensional pose per capsule). LayerNorm is then applied across the 16 channels to obtain the primary capsules. This is followed by two convolutional capsule layers, and then by two fully-connected capsule layers. In the last capsule layer, each capsule corresponds to a class. These capsules are then used to compute logits that feed into a softmax to compute the classification probabilities. Inference in this network requires a feed-forward pass up to the primary capsules. After this, our proposed routing mechanism (discussed later) takes over.
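
As a concrete illustration, here is a minimal PyTorch sketch of the primary-capsule step described above: a convolution whose output channels are split into groups of 16, followed by LayerNorm over each 16-dimensional pose. The module and parameter names are hypothetical and not the repository's exact code.

```python
import torch
import torch.nn as nn

class PrimaryCapsules(nn.Module):
    """Sketch: turn backbone feature maps into primary capsules.

    Hypothetical module; names and sizes are illustrative only.
    """
    def __init__(self, in_channels, num_capsules, pose_dim=16):
        super().__init__()
        # One extra conv produces num_capsules groups of pose_dim channels.
        self.conv = nn.Conv2d(in_channels, num_capsules * pose_dim,
                              kernel_size=3, padding=1)
        # LayerNorm across the 16 channels of each capsule's pose.
        self.norm = nn.LayerNorm(pose_dim)
        self.num_capsules = num_capsules
        self.pose_dim = pose_dim

    def forward(self, x):
        b, _, h, w = x.shape
        poses = self.conv(x)                              # (B, N*16, H, W)
        poses = poses.view(b, self.num_capsules, self.pose_dim, h, w)
        poses = poses.permute(0, 1, 3, 4, 2)              # (B, N, H, W, 16)
        return self.norm(poses)                           # primary capsules
```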

Inverted Dot-Product Attention Routing

In our method, the routing procedure resembles an inverted attention mechanism, where dot products are used to measure agreement. Specifically, the higher-level (parent) units compete for the attention of the lower-level (child) units, instead of the other way around, as in standard attention models. Hence, the routing probability directly depends on the agreement between the parent's pose (from the previous iteration step) and the child's vote for the parent's pose (in the current iteration step). We (1) use Layer Normalization (Ba et al., 2016) as normalization, and (2) perform inference of the latent capsule states and routing probabilities jointly across multiple capsule layers (instead of doing it layer-wise).
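
The following is a minimal PyTorch sketch of a single routing iteration under this scheme, assuming the child votes for each parent have already been computed (e.g., by a learned linear map per parent). All names are illustrative rather than the repository's exact API.

```python
import torch
import torch.nn.functional as F

def routing_iteration(votes, parent_poses, norm):
    """One inverted dot-product attention routing step (sketch).

    votes:        (B, n_child, n_parent, 16) child votes for parent poses
    parent_poses: (B, n_parent, 16) parent poses from the previous iteration
    norm:         nn.LayerNorm(16)
    """
    # Agreement: dot product between each parent's pose and each child's vote.
    agreement = torch.einsum('bcpd,bpd->bcp', votes, parent_poses)
    # Parents compete for each child's attention: softmax over the parent axis.
    routing = F.softmax(agreement, dim=-1)                # (B, n_child, n_parent)
    # Update each parent pose as a routing-weighted sum of votes, then normalize.
    parent_poses = torch.einsum('bcp,bcpd->bpd', routing, votes)
    return norm(parent_poses)
```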

Concurrent Routing

Concurrent routing is a parallel-in-time routing procedure across all capsule layers.
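
A sketch of what this might look like, building on the routing_iteration sketch above: each iteration updates every capsule layer once, using all layers' poses from the previous iteration, rather than running routing to convergence one layer at a time. The layer API (.votes, .norm) is hypothetical.

```python
def concurrent_routing(primary_poses, layers, num_routing):
    """Sketch of concurrent (parallel-in-time) routing.

    layers: list of objects with .votes(child_poses) -> (B, c, p, 16)
            and .norm (nn.LayerNorm), one per capsule layer above the
            primary capsules. This API is hypothetical.
    """
    # Initialize every layer's poses with a uniform first pass.
    poses = [primary_poses]
    for layer in layers:
        votes = layer.votes(poses[-1])                    # (B, c, p, 16)
        poses.append(layer.norm(votes.mean(dim=1)))       # uniform routing

    # Refine all layers jointly: one routing step per layer per iteration,
    # each using the poses from the previous iteration.
    for _ in range(num_routing):
        new_poses = [primary_poses]
        for layer, child, parent in zip(layers, poses[:-1], poses[1:]):
            votes = layer.votes(child)
            new_poses.append(routing_iteration(votes, parent, layer.norm))
        poses = new_poses
    return poses[-1]                                      # class capsules
```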

Usage

Prerequisites

Python 3.6 and a working PyTorch installation (the code is implemented in PyTorch).

Datasets

We use CIFAR10 and CIFAR100.
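
Both datasets are available through torchvision; a minimal loading sketch with common CIFAR augmentation is below (the repository's exact transforms may differ).

```python
import torchvision
import torchvision.transforms as T

# Standard CIFAR-10 loading with common augmentation; illustrative only.
transform = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
])
train_set = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
```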

Run the Code

Arguments

| Args | Value | Help |
| --- | --- | --- |
| debug | - | Enter debug mode: no models or results will be saved. True or False. |
| num_routing | 1 | The number of routing iterations. Must be at least 1. |
| dataset | CIFAR10 | Choice of dataset: CIFAR10 or CIFAR100. |
| backbone | resnet | Choice of backbone: simple or resnet. |
| config_path | ./configs/resnet_backbone_CIFAR10.json | Configuration file for the capsule layers. |

Running CIFAR-10

python main_capsule.py --num_routing 2 --dataset CIFAR10 --backbone resnet --config_path ./configs/resnet_backbone_CIFAR10.json 

When num_routing is 1, the average test accuracy we obtained is 94.73%.

When num_routing is 2, the average test accuracy we obtained is 94.85%, and the best model we obtained reaches 95.14%.

Running CIFAR-100

python main_capsule.py --num_routing 2 --dataset CIFAR100 --backbone resnet --config_path ./configs/resnet_backbone_CIFAR100.json 

When num_routing is 1, the average test accuracy we obtained is 76.02%.

When num_routing is 2, the average test accuracy we obtained is 76.27%, and the best model we obtained reaches 78.02%.

License

This code is released under the terms in the LICENSE file.
