<a href="https://colab.research.google.com/github/RajkumarGalaxy/ML-Image-Processing/blob/master/self_attention_cv.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Self-Attention Computer Vision

This notebook draws on the following sources and papers:

- https://github.com/The-AI-Summer/self-attention-cv
- https://arxiv.org/pdf/1706.03762.pdf
- https://analyticsindiamag.com/going-beyond-cnn-stand-alone-self-attention/
- https://arxiv.org/pdf/2101.11605

In [3]:
!pip install self_attention_cv



## Multi-head Self Attention

In [2]:
import torch
from self_attention_cv import MultiHeadSelfAttention

# Seed torch's global RNG so the random input below (and any random weight
# initialisation performed while building the model) is repeatable; without
# it the printed values differ on every Restart & Run All.
torch.manual_seed(0)

model = MultiHeadSelfAttention(dim=64)
x = torch.rand(16, 10, 64)  # [batch, tokens, dim]

# Square tokens-by-tokens mask: all zeros except the 3x3 block covering
# tokens 5..7, which is set to 1.
# NOTE(review): the mask is a float tensor here; confirm the dtype (and the
# meaning of 1 vs 0) that MultiHeadSelfAttention expects for the installed version.
mask = torch.zeros(10, 10)  # tokens X tokens
mask[5:8, 5:8] = 1
y = model(x, mask)

print('Shape of output is: ', y.shape)
print('-'*70)
print('Output corresponding to the first token/patch in the first batch \n')
# detach() drops the autograd graph so the tensor can be converted to NumPy.
print(y.detach().numpy()[0][0])


Shape of output is:  torch.Size([16, 10, 64])
----------------------------------------------------------------------
Output corresponding to the first token/patch in the first batch 

[ 0.05095745 -0.3574398   0.2571603   0.01686804 -0.32707906  0.16551468
 -0.03029712  0.17124459 -0.11935965 -0.05426797 -0.27412418 -0.5849673
  0.03273072 -0.07591394  0.0918312   0.14555283 -0.01222516  0.12830907
 -0.12161413  0.22703463  0.31742722 -0.15440758  0.16580498 -0.1876549
 -0.14389239 -0.31196266  0.10959084  0.32506582  0.34819284 -0.20865689
 -0.15298392 -0.02287815 -0.28265458  0.39987215  0.09144013  0.5114911
 -0.10702722 -0.17274295 -0.03052047 -0.32069373 -0.24848755 -0.19961475
 -0.25558165  0.16457602  0.04435449  0.14818098 -0.25300896  0.19681376
 -0.14411886  0.0622221  -0.2663249  -0.08401316 -0.266905   -0.14300469
  0.02286307 -0.13363452 -0.00669263 -0.08034971 -0.09160165  0.19021185
 -0.10703149 -0.09570458  0.07904153 -0.09355289]


## Axial Attention

In [3]:
# Demo: push one random feature map through an axial attention block.
from self_attention_cv import AxialAttentionBlock

# in_channels=256 lines up with axis 1 of the input below, and dim=64 with
# its two trailing axes.
model = AxialAttentionBlock(in_channels=256, dim=64, heads=8)
x = torch.rand(1, 256, 64, 64)  # [batch, channels, dim, dim]
y = model(x)

# Convert once to NumPy (detached from autograd) for printing.
y_np = y.detach().numpy()

print('Shape of output is: ', y.shape)
print('-' * 70)
print('Output corresponding to the first token/patch in the first batch \n')
# y_np[0, 0] is the 2-D slice at index 0 of axis 0 and index 0 of axis 1.
print(y_np[0, 0])

Shape of output is:  torch.Size([1, 256, 64, 64])
----------------------------------------------------------------------
Output corresponding to the first token/patch in the first batch 

[[0.22185817 1.8086953  1.3097283  ... 0.630754   1.1579583  0.        ]
 [0.         0.         0.         ... 3.1923792  1.9478554  3.044509  ]
 [1.4585196  1.0503312  0.         ... 0.18355879 0.05406958 0.        ]
 ...
 [0.33815414 0.09240472 4.8054843  ... 0.         0.7096393  1.2665645 ]
 [0.09514064 0.09698692 1.9458563  ... 1.8430982  0.         1.1634666 ]
 [0.         1.0365889  1.2793136  ... 0.6040628  0.10727388 1.4690821 ]]


## Bottleneck Attention

In [13]:
from self_attention_cv.bottleneck_transformer import BottleneckBlock

# Random input whose axis 1 (512) matches in_channels and whose trailing
# axes (32, 32) match fmap_size.
x = torch.rand(1, 512, 32, 32)
bottleneck_block = BottleneckBlock(
    in_channels=512,
    fmap_size=(32, 32),
    heads=4,
    out_channels=1024,
    pooling=True,
)
y = bottleneck_block(x)

# Detached NumPy view of the result, used only for display.
y_np = y.detach().numpy()

print('Shape of output is: ', y.shape)
print('-' * 70)
print('Output corresponding to the first patch in the first head, first batch \n')
# Index 0 along each of the first three axes leaves a 1-D vector.
print(y_np[0, 0, 0])

forward.. torch.Size([1, 512, 32, 32])
Shape of output is:  torch.Size([1, 1024, 16, 16])
----------------------------------------------------------------------
Output corresponding to the first patch in the first head, first batch 

[ 2.483006   -0.2608691   0.8990649   1.7123606   0.3083073   0.57465255
 -1.069338    2.2092774   0.37177175  1.4038159   1.6982754   0.44549412
 -0.93696225 -0.86148393  1.0227227  -0.33026367]


## Transformer Encoder


In [5]:
# Demo: run a random token batch plus a mask through a transformer encoder.
from self_attention_cv import TransformerEncoder

model = TransformerEncoder(dim=64, blocks=6, heads=8)
x = torch.rand(16, 10, 64)  # [batch, tokens, dim]

# tokens X tokens mask: zeros everywhere except the 3x3 block for tokens 5..7.
mask = torch.zeros(10, 10)
mask[5:8, 5:8] = 1

y = model(x, mask)

# Detached NumPy copy of the output for printing.
y_np = y.detach().numpy()

print('Shape of output is: ', y.shape)
print('-' * 70)
print('Output corresponding to the first token/patch in the first batch \n')
print(y_np[0, 0])

Shape of output is:  torch.Size([16, 10, 64])
----------------------------------------------------------------------
Output corresponding to the first token/patch in the first batch 

[ 1.1898813   1.8956766   1.2114029   0.12728807 -0.3628582  -1.1008008
 -1.3719841   0.39440683  0.51728094  1.4535215  -3.0603547   0.3478451
 -0.03671097  1.7242503   0.6375115  -0.11810965 -0.05245292 -0.5992121
 -0.53516376 -1.2852938  -0.5251507  -0.05185176 -1.1167455   0.14817092
  0.85586274  0.9737668   0.08120837 -0.75146204 -0.7960583  -0.8653657
  1.705025   -1.2222375   1.1493846  -0.23083282  1.9175692   0.51956785
 -0.7669539  -1.4143958   0.8146387  -0.7686734  -0.02727643 -1.1802992
  0.4200065   0.858947    0.9608477  -0.57810295  1.3265743  -1.1821389
  0.6577277  -1.7220386   1.061455    1.4776539  -0.5479512  -1.0370749
  0.03364755 -0.25013465  0.34635526 -0.32273477 -0.3752242  -0.5290259
  0.71639156  0.5257713  -0.03652823 -1.2284385 ]


## Vision Transformer 

In [5]:
from self_attention_cv import ViT

# img_dim=256 and in_channels=3 line up with the 3 x 256 x 256 images below.
model = ViT(
    img_dim=256,
    in_channels=3,
    patch_dim=16,
    num_classes=10,
    dim=512,
)
x = torch.rand(2, 3, 256, 256)
y = model(x)  # [2,10]

# Detached NumPy copy of the output for printing.
y_np = y.detach().numpy()

print('Shape of output is: ', y.shape)
print('-' * 70)
print('Output corresponding to the first image \n')
# Row 0 holds the 10 values produced for the first image in the batch.
print(y_np[0])

Shape of output is:  torch.Size([2, 10])
----------------------------------------------------------------------
Output corresponding to the first image 

[-0.2508071  -0.0645473  -0.13877457 -0.18308383 -0.3567136   0.34362894
 -0.35030687  0.3489043   0.5285294   0.08138363]


## Vision Transformer with ResNet50

In [6]:
from self_attention_cv import ResNet50ViT

# pretrained_resnet=False is passed explicitly, as in the original cell.
model = ResNet50ViT(
    img_dim=256,
    pretrained_resnet=False,
    blocks=6,
    num_classes=10,
    dim_linear_block=256,
    dim=256,
)
x = torch.rand(2, 3, 256, 256)
y = model(x)  # [2,10]

# Detached NumPy copy of the output for printing.
y_np = y.detach().numpy()

print('Shape of output is: ', y.shape)
print('-' * 70)
print('Output corresponding to the first image \n')
# Row 0 holds the 10 values produced for the first image in the batch.
print(y_np[0])

Shape of output is:  torch.Size([2, 10])
----------------------------------------------------------------------
Output corresponding to the first image 

[ 0.18863088  0.60527444  0.2697486   0.12503052  0.44111896 -0.2981342
  0.5278659  -0.769126   -0.8882933  -1.085641  ]


## TransUNet 

In [7]:
from self_attention_cv.transunet import TransUnet

# Two random 3-channel 128x128 inputs; img_dim=128 matches the spatial size.
x = torch.rand(2, 3, 128, 128)
model = TransUnet(
    in_channels=3,
    img_dim=128,
    vit_blocks=8,
    vit_dim_linear_mhsa_block=512,
    classes=5,
)
y = model(x)  # [2, 5, 128, 128]

# Detached NumPy copy of the output for printing.
y_np = y.detach().numpy()

print('Shape of output is: ', y.shape)
print('-' * 70)
print('Output corresponding to the first image \n')
# y_np[0, 0] is a single 128x128 map: index 0 of axis 0 and index 0 of axis 1
# (axis 1 has size 5, matching classes=5).
print(y_np[0, 0])

Shape of output is:  torch.Size([2, 5, 128, 128])
----------------------------------------------------------------------
Output corresponding to the first image 

[[-0.5679928  -0.34442404 -0.27711168 ... -0.5603674  -0.7525164
  -0.47777802]
 [ 0.26423207 -0.00429824 -0.2819016  ... -0.19415903 -0.06176922
  -0.64677024]
 [-0.21234514 -0.3868027  -0.11079104 ... -0.44116968 -0.5203699
  -0.5179628 ]
 ...
 [ 0.3248333  -0.96315455 -1.3502424  ... -0.3086787  -0.92315567
  -0.42170545]
 [-0.04411215 -0.9422567  -1.0312608  ... -0.8501636  -0.92981017
  -0.44236282]
 [ 0.00445503 -0.27653238 -0.32638735 ... -0.34467506 -0.33213058
  -0.67374027]]


## 1D Absolute Positional Embedding

In [8]:
from self_attention_cv.pos_embeddings import AbsPosEmb1D

model = AbsPosEmb1D(tokens=20, dim_head=64)
# Input layout: batch, heads, tokens, dim_head
x = torch.rand(2, 3, 20, 64)
y = model(x)

# Detached NumPy copy of the output for printing.
y_np = y.detach().numpy()

print('Shape of output is: ', y.shape)
print('-' * 70)
print('Output corresponding to the first token in the first head, first batch \n')
# Index 0 along batch, heads and tokens leaves one 1-D row.
print(y_np[0, 0, 0])

Shape of output is:  torch.Size([2, 3, 20, 20])
----------------------------------------------------------------------
Output corresponding to the first token in the first head, first batch 

[-0.11605695  0.6117866  -0.1672374   0.43876237  0.23604953  0.1562004
  0.48934904  0.12783262  0.70804363  0.50528526 -0.11475162  1.131989
  0.38430238 -0.4284703   0.7435987  -0.5122682  -0.36132246 -0.91800463
  1.5392176  -0.05165316]


## 1D Relative Positional Embedding

In [9]:
from self_attention_cv.pos_embeddings import RelPosEmb1D

# heads=3 and tokens=20 line up with axes 1 and 2 of the input below.
model = RelPosEmb1D(tokens=20, dim_head=64, heads=3)
x = torch.rand(2, 3, 20, 64)
y = model(x)

# Detached NumPy copy of the output for printing.
y_np = y.detach().numpy()

print('Shape of output is: ', y.shape)
print('-' * 70)
print('Output corresponding to the first token in the first head, first batch \n')
# Index 0 along each of the first three axes leaves one 1-D row.
print(y_np[0, 0, 0])

Shape of output is:  torch.Size([2, 3, 20, 20])
----------------------------------------------------------------------
Output corresponding to the first token in the first head, first batch 

[-0.5457531   0.0245374   1.206948    1.1106241   0.1697544  -0.04933987
  0.24826212  0.38298082 -0.41222885 -0.19180176  0.8518995  -0.63441426
 -0.9155214   1.2972814   0.52583534  0.05831212 -0.14253221 -0.10725353
 -0.68426543  0.09718338]


## 2D Relative Positional Embedding

In [11]:
from self_attention_cv.pos_embeddings import RelPosEmb2D

dim = 32  # spatial dim of the feat map
model = RelPosEmb2D(feat_map_size=(dim, dim), dim_head=128)

# Axis 2 holds one entry per feature-map position (dim * dim of them);
# the last axis matches dim_head=128.
x = torch.rand(2, 4, dim * dim, 128)
y = model(x)

# Detached NumPy copy of the output for printing.
y_np = y.detach().numpy()

print('Shape of output is: ', y.shape)
print('-' * 70)
print('Output corresponding to the first patch in the first head, first batch \n')
# Index 0 along each of the first three axes leaves one 1-D row.
print(y_np[0, 0, 0])

Shape of output is:  torch.Size([2, 4, 1024, 1024])
----------------------------------------------------------------------
Output corresponding to the first patch in the first head, first batch 

[ 1.7701193  -0.3390236  -0.06327437 ...  0.24196783  0.95375854
  0.74266154]
