This is a PyTorch implementation of the paper Axial-DeepLab: Stand-Alone Axial-Attention for Panoptic Segmentation by Huiyu Wang, Yukun Zhu, Bradley Green, Hartwig Adam, Alan Yuille, and Liang-Chieh Chen.
This repository implements the paper's attention mechanism within several ResNet architectures.
Global self-attention on images suffers from the problem that it can only be applied after significant spatial downsampling of the input: each pixel's relation to every other pixel must be computed, which makes learning computationally very expensive and prevents its use across all layers of a fully attentional model.
In this paper the authors mitigate this issue by introducing their Axial-Attention concept, where the attention mechanism for a pixel is applied in two steps, vertically and horizontally:
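The two-step decomposition can be sketched as follows. This is a minimal illustration (no learned projections or positional terms, which the real layers include): attention is first applied along the height axis, then along the width axis, reducing the cost from O((HW)²) for global attention to O(HW(H+W)).

```python
import torch
import torch.nn.functional as F

def attend_last_axis(x):
    """Plain softmax self-attention over the sequence axis L.

    x: (..., L, C). No projections; for illustration only.
    """
    scores = x @ x.transpose(-1, -2) / x.shape[-1] ** 0.5  # (..., L, L)
    return F.softmax(scores, dim=-1) @ x                   # (..., L, C)

def axial_attention(x):
    """x: (B, H, W, C). Attend along height, then along width."""
    # Height axis: swap H and W so each column becomes a sequence of H pixels.
    x = attend_last_axis(x.transpose(1, 2)).transpose(1, 2)
    # Width axis: each row is a sequence of W pixels.
    x = attend_last_axis(x)
    return x
```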
Furthermore, they extend the positional encoding from the query pixels to the keys and values as well.
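A sketch of this position-sensitive attention along one axis is shown below. The relative embeddings r_q, r_k, r_v and their (L, L, C) shapes are illustrative assumptions; in the paper they are learned tables indexed by relative offset, and the key and value terms are the extension mentioned above.

```python
import torch
import torch.nn.functional as F

def axial_attention_1d(q, k, v, r_q, r_k, r_v):
    """Position-sensitive attention along one axis.

    q, k, v: (B, L, C); r_q, r_k, r_v: (L, L, C) relative positional
    embeddings for queries, keys, and values (shapes are illustrative).
    """
    # Content term plus positional terms for queries AND keys.
    logits = q @ k.transpose(-1, -2)                        # q · k
    logits = logits + torch.einsum('blc,lmc->blm', q, r_q)  # q · r_q
    logits = logits + torch.einsum('bmc,lmc->blm', k, r_k)  # k · r_k
    attn = F.softmax(logits, dim=-1)
    # The output aggregates values plus a positional value term.
    return attn @ v + torch.einsum('blm,lmc->blc', attn, r_v)
```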
I have only tested the implementation with ResNet-50 so far. The ResNet V1.5 architectures used are adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
The paper notes: "In order to avoid careful initialization of WQ, WK, WV, rq, rk, rv, we use batch normalizations in all attention layers." Consequently, two batch normalization layers are applied per attention layer.
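One way this can look in practice is sketched below, with positional terms omitted for brevity: one batch norm after the joint q/k/v projection and one on the attention logits before the softmax. The exact placement and channel sizes here are illustrative, not necessarily identical to this repository's layers.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AxialAttention1D(nn.Module):
    """Minimal 1D attention with two batch norms: one after the joint
    q/k/v projection, one on the (L, L) logit map (sizes illustrative)."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.qkv = nn.Conv1d(in_ch, 3 * out_ch, kernel_size=1, bias=False)
        self.bn_qkv = nn.BatchNorm1d(3 * out_ch)
        self.bn_logits = nn.BatchNorm2d(1)  # normalizes the logit map
        self.out_ch = out_ch

    def forward(self, x):  # x: (B, C, L)
        q, k, v = self.bn_qkv(self.qkv(x)).chunk(3, dim=1)
        logits = torch.einsum('bcl,bcm->blm', q, k) / self.out_ch ** 0.5
        logits = self.bn_logits(logits.unsqueeze(1)).squeeze(1)
        attn = F.softmax(logits, dim=-1)
        return torch.einsum('blm,bcm->bcl', attn, v)
```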
- attention: ResNet stages in which you would like to apply the attention layers
- num_heads: Number of attention heads
- kernel_size: Maximum local field on which Axial-Attention is applied
- inference: Allows inspecting the attention weights of a trained model
See the Jupyter notebook or the example training script.
- PyTorch
- I use fast.ai and the Imagenette dataset for the examples