In [42]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from matplotlib import image
from matplotlib import pyplot as plt
from shapely.wkt import loads
import cv2 as cv
from torch.nn import TransformerDecoder, TransformerDecoderLayer

In [43]:
class ATM(nn.Module):
    """Multi-head cross-attention that also returns head-averaged attention logits.

    Queries ``xq`` attend over keys ``xk`` / values ``xv``. Besides the projected
    attention output, the scaled *pre-softmax* similarity scores are returned,
    averaged over the heads.

    Args:
        dim: embedding size C of queries, keys and values; must be divisible
            by ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: whether the q/k/v projections have a bias term.
        qk_scale: optional override for the logit scale; defaults to
            ``head_dim ** -0.5`` (standard scaled dot-product attention).

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # 1/sqrt(head_dim); the previous `head_dim ** -5` was a typo for -0.5.
        self.scale = qk_scale if qk_scale is not None else head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)

        self.proj = nn.Linear(dim, dim)

    def forward(self, xq, xk, xv):
        """Attend from ``xq`` over ``xk``/``xv``.

        Args:
            xq: queries, shape (B, Nq, C).
            xk: keys, shape (B, Nk, C).
            xv: values, shape (B, Nv, C); Nv must equal Nk for the product.

        Returns:
            x: (B, Nq, C) attention output after the output projection.
            attn: (B, Nq, Nk) scaled pre-softmax logits, averaged over heads.
        """
        B, Nq, C = xq.size()
        Nk = xk.size(1)
        Nv = xv.size(1)
        head_dim = C // self.num_heads

        # (B, N, C) -> (B, heads, N, head_dim)
        q = self.q(xq).reshape(B, Nq, self.num_heads, head_dim).permute(0, 2, 1, 3)
        k = self.k(xk).reshape(B, Nk, self.num_heads, head_dim).permute(0, 2, 1, 3)
        v = self.v(xv).reshape(B, Nv, self.num_heads, head_dim).permute(0, 2, 1, 3)

        logits = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, Nq, Nk)
        weights = logits.softmax(dim=-1)  # not in-place, so `logits` stays intact

        x = (weights @ v).transpose(1, 2).reshape(B, Nq, C)
        x = self.proj(x)
        return x, logits.mean(dim=1)

In [44]:
class Transformer_Decoder_Layer(nn.Module):
    """One decoder layer: query self-attention, ATM cross-attention to memory, feed-forward.

    Args:
        dim: embedding size of the queries and of the memory tokens.
        num_heads: number of heads for both the self- and cross-attention.
        qkv_bias: whether the cross-attention q/k/v projections have a bias.
        feed_forward_dim: hidden width of the feed-forward block; defaults to ``4 * dim``.
    """

    def __init__(self, dim, num_heads=1, qkv_bias=False, feed_forward_dim=None):
        super().__init__()

        if feed_forward_dim is None:
            feed_forward_dim = dim * 4

        self.self_attn = nn.MultiheadAttention(dim, num_heads, dropout=.1, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.multihead_attn = ATM(dim, num_heads, qkv_bias)
        self.norm2 = nn.LayerNorm(dim)
        self.linear1 = nn.Linear(dim, feed_forward_dim)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(feed_forward_dim, dim)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, x, memory):
        """Run the layer.

        Args:
            x: queries, (batch, n_queries, dim) (``batch_first=True``).
            memory: tokens attended to by the cross-attention, (batch, n_tokens, dim).

        Returns:
            (x, attn): updated queries with the same shape as the input ``x``, and
            the attention map returned by ``ATM`` for the cross-attention step.
        """
        # All residual updates are out-of-place: `x += ...` would mutate the
        # caller's tensor and tensors saved for backward by the attention /
        # linear layers, making `backward()` fail.
        # NOTE(review): norm1/norm2 normalize the sub-layer output *before* the
        # residual add, while the feed-forward uses post-norm (norm3 after the
        # add). Kept as-is; confirm this mix is intended.
        x2 = self.self_attn(x, x, x)[0]
        x = x + self.norm1(x2)

        # With a convolutional encoder the cross-attention map is not spatially
        # meaningful; a Transformer encoder would make `attn` usable for plots.
        x2, attn = self.multihead_attn(x, memory, memory)
        x = x + self.norm2(x2)

        # Feed-forward
        x2 = self.linear2(self.activation(self.linear1(x)))
        x = self.norm3(x + x2)
        return x, attn

In [45]:
class Transformer_Decoder(nn.Module):
    """Single-layer decoder: learned per-class queries cross-attend to input features.

    Args:
        dim: embedding size of the features and of the class queries.
        num_heads: number of attention heads in the decoder layer.
        batch_size: kept for backward compatibility only; the batch size is now
            taken from the input at every forward call.
        n_classes: number of learned class queries.
    """

    def __init__(self, dim, num_heads=8, batch_size=3, n_classes=2):
        super().__init__()
        self.decoder_layer = Transformer_Decoder_Layer(dim, num_heads)
        # One learned query per class. Its width must match `dim`
        # (previously hard-coded to 16, which broke any other dim).
        self.q = nn.Embedding(n_classes, dim)
        self.batch_size = batch_size

    def forward(self, input):
        """Decode ``input`` features with the class queries (single layer).

        Args:
            input: features used as memory, (batch, n_tokens, dim).

        Returns:
            (q, attn): updated class queries (batch, n_classes, dim) and the
            decoder layer's cross-attention map.
        """
        features = input
        batch = features.size(0)  # follow the actual batch (e.g. a smaller last batch)
        # repeat (rather than expand) gives every batch element its own storage,
        # so the decoder layer may safely update the queries.
        q = self.q.weight.repeat(batch, 1, 1)
        q, attn = self.decoder_layer(q, features)
        return q, attn

In [46]:
# Smoke test: batch of 3 feature sets, 9 tokens each, embedding size 16, 8 heads.
test = Transformer_Decoder(16, 8, 3)
features = torch.rand(3, 9, 16)
q, attn = test(features)
print(attn.shape)  # expected (batch, n_classes, tokens) = (3, 2, 9)
print(q.shape)     # expected (batch, n_classes, dim)    = (3, 2, 16)
assert attn.shape == (3, 2, 9)
assert q.shape == (3, 2, 16)


torch.Size([3, 2, 9])
torch.Size([3, 2, 16])


In [47]:
# Linear head mapping each 16-dim query vector to 3 logits.
# NOTE(review): 3 outputs while the decoder above uses n_classes=2 —
# presumably n_classes + 1 (an extra "no object" class); confirm.
classifier = nn.Linear(in_features=16, out_features=3)

In [48]:
# Stack two copies of the decoder output `q` along a new leading axis of size 2.
qs = torch.stack((q, q), dim=0)
qs.shape

torch.Size([2, 3, 2, 16])

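Sanity check: reproduce the multi-head attention reshapes and matrix multiplications used in `ATM` on random tensors (without the q/k/v projections or the scaling factor) and inspect the resulting shapes.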

In [49]:
# Random queries (2 tokens) and keys/values (9 tokens), embedding size 16.
q = torch.rand(3, 2, 16)
k = torch.rand(3, 9, 16)
v = torch.rand(3, 9, 16)
num_heads = 8

B, Nq, C = q.shape
Nk, Nv = k.shape[1], v.shape[1]

# Split every tensor into heads: (B, N, C) -> (B, heads, N, C // heads).
_q, _k, _v = (
    t.reshape(B, t.shape[1], num_heads, C // num_heads).permute(0, 2, 1, 3)
    for t in (q, k, v)
)

# Attention weights per head: (B, heads, Nq, Nk), each row summing to 1.
attn = (_q @ _k.transpose(-2, -1)).softmax(dim=-1)
print(attn.shape)

# Merge the heads back: (B, heads, Nq, C // heads) -> (B, Nq, C).
x = (attn @ _v).transpose(1, 2).reshape(B, Nq, C)
print(x.shape)

torch.Size([3, 8, 2, 9])
torch.Size([3, 2, 16])


In [50]:
# Stack the random queries `q` and the attended output `x` on a new leading axis.
qs = torch.stack((q, x), dim=0)
qs.shape

torch.Size([2, 3, 2, 16])