Image Local Attention: a Better PyTorch Implementation

Notice for Modification

This repo is based on Zhendong Zhang's framework. We have substantially modified the implementation to support causal-mask attention and cross-resolution attention, and to speed it up.

Introduction

Attention is now widely used in deep learning. Given a query and a collection of key-value pairs, the output of an attention module is a weighted sum of all values. The weights are derived from the similarities between the query and the keys, which are usually measured by their inner products. However, when the number of keys is large, applying such a module is expensive.
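
For reference, here is a minimal sketch of plain (global) attention in PyTorch; the shapes and the softmax normalization of the inner-product similarities are illustrative assumptions, not this repo's API:

import torch

# Plain (global) attention: the output is a weighted sum of all values,
# with weights derived from query-key inner products.
def global_attention(Q, K, V):
    # Q: (N, C), K: (M, C), V: (M, C)
    W = torch.softmax(Q @ K.t(), dim=-1)  # (N, M) weights over all keys
    return W @ V                          # (N, C) weighted sum of values

Q = torch.rand(16, 64)
K = torch.rand(4096, 64)  # memory and compute grow with the number of keys
V = torch.rand(4096, 64)
out = global_attention(Q, K, V)  # (16, 64)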

Researchers have turned to local attention to address this problem: only a small subset of keys is involved for a given query. For images, "local" means an image region around a pixel. Image local attention has achieved great success on image restoration tasks. However, current implementations are based on the im2col operation, which is memory-expensive, especially when the local patch is large.

Implementation

Here, queries Q, keys K and values V are represented as CHW (channel, height, width) tensors and are generated by convolutions. A "local region" is a Ckk sub-tensor, where k is the patch size. Current implementations are based on the following steps:

  • rearrange K and V into (kk)CHW tensors via im2col.
  • compute the similarity matrix W between Q and K: (kk)HW.
  • compute the output O as the summation of V weighted by W: CHW.

Clearly, the first step requires kk times the memory to store the rearranged K and V. However, this can be avoided: in our implementation, we compute W and O without rearranging the keys and values. To this end, we write two CUDA kernels and build a PyTorch extension based on them.
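
For comparison, the im2col-based baseline described above can be written in plain PyTorch with F.unfold. The sketch below (a batch dimension and a softmax normalization are added for concreteness) only illustrates where the kk memory overhead comes from; it is not the CUDA kernels of this repo:

import torch
import torch.nn.functional as F

# im2col-based local attention baseline for Q, K, V of shape (B, C, H, W).
def local_attention_im2col(Q, K, V, k):
    B, C, H, W = Q.shape
    pad = k // 2  # assumes odd k so the spatial size is preserved
    # step 1: rearrange K and V into (kk)CHW form via im2col (unfold) -- k*k times the memory
    Ku = F.unfold(K, k, padding=pad).view(B, C, k * k, H, W)
    Vu = F.unfold(V, k, padding=pad).view(B, C, k * k, H, W)
    # step 2: similarity matrix between Q and the local keys: (kk)HW
    sim = torch.softmax((Q.unsqueeze(2) * Ku).sum(dim=1), dim=1)  # (B, k*k, H, W)
    # step 3: output O as the sum of local values weighted by the similarities: CHW
    return (sim.unsqueeze(1) * Vu).sum(dim=2)                     # (B, C, H, W)

For k = 21, the unfolded Ku and Vu are 441 times larger than K and V, which is exactly the overhead the custom kernels avoid by never materializing the rearranged tensors.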

Install and usage

python setup.py install

Requirements:

PyTorch >= 1.4.0
CUDA >= 10.0

The Python wrapper is in function.py. Here is an example:

import torch
from function import LocalAttention

# kH and kW for local patch size
# works only on GPU
module = LocalAttention(inp_channels=3, out_channels=16, kH=7, kW=7).cuda()
x = torch.rand(32, 3, 64, 64).cuda()

# Q, K, V are generated by convolutions of x
y = module(x)

Performance

We evaluate the GPU memory and running time of our implementation relative to the plain PyTorch implementation: the first table reports the forward pass and the second a forward-backward loop. Here, we set H=W=128 and C=64.

Forward pass:

k    Relative GPU memory    Relative running time
5    10.2%                  31.4%
11   3.2%                   15.6%
21   2.0%                   26.5%

Forward-backward loop:

k    Relative GPU memory    Relative running time
5    9.0%                   31.2%
11   3.4%                   21.5%
21   2.3%                   47.3%

Our implementation reduces GPU memory usage by an order of magnitude and is faster than the plain PyTorch implementation.
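
As a rough illustration, relative numbers like those above could be collected with PyTorch's CUDA timing and peak-memory utilities; the helper and module names below are hypothetical sketches, not the actual benchmark scripts:

import time
import torch

# Measures peak GPU memory and average running time of fn(x) over a few iterations.
def measure(fn, x, iters=20):
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        fn(x)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated(), (time.time() - start) / iters

# mem_ours, t_ours = measure(local_attention_module, x)
# mem_plain, t_plain = measure(plain_pytorch_module, x)
# print(f"relative memory: {mem_ours / mem_plain:.1%}, time: {t_ours / t_plain:.1%}")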

Refer to /test for more results.
