In [None]:
import torch
from torch import nn
# Shape experiment: downsample 30 region feature maps with a strided conv and
# flatten each region into a single token vector.
a = torch.rand(30, 256, 14, 14)
conv = nn.Conv2d(256, 16, kernel_size=3, stride=2, padding=1)

feature_map = conv(a)
num_regions, channels, height, width = feature_map.shape

# Move channels last before flattening so each token is laid out as (position, channel).
tokens = feature_map.reshape(num_regions, channels, height * width).permute(0, 2, 1)

# One vector per region, plus a leading batch axis of size 1.
b = tokens.reshape(num_regions, height * width * channels).unsqueeze(0)
print(b.shape)



In [None]:
import torch
from torch import nn
# Random stand-in input of shape (1, 30, 784); it is fed to region_attention further down in this cell.
a = torch.rand(1,30,784)

class region_attention(nn.Module):
    """Multi-head self-attention over a sequence, followed by a linear projection.

    Args:
        num_heads: Number of attention heads. Must be positive and divide ``dim``.
        dim: Size of the last dimension of the input.
        qkv_bias: Whether the fused query/key/value linear layer has a bias.
        out_dim: Size of the last dimension of the output.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self,num_heads,dim,qkv_bias,out_dim):
        super().__init__()
        # Fail at construction time: an incompatible (dim, num_heads) pair would
        # otherwise only surface as an opaque reshape error inside forward().
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive num_heads ({num_heads})"
            )

        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        # A single linear layer produces q, k and v; they are split in forward().
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, out_dim)

    def forward(self,x):
        """Apply self-attention along the sequence axis.

        Args:
            x: Tensor of shape ``[B, seq_len, dim]``.

        Returns:
            Tensor of shape ``[B, seq_len, out_dim]``.
        """
        B, seq_len, _ = x.shape

        # [B, seq_len, 3*dim] -> [3, B, num_heads, seq_len, head_dim]
        qkv = self.qkv(x).reshape(B, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        # Scaled dot-product attention; matmul batches over the (B, num_heads) axes.
        attn = (q * self.scale) @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)

        # [B, num_heads, seq_len, head_dim] -> [B, seq_len, dim]
        x = (attn @ v).permute(0, 2, 1, 3).reshape(B, seq_len, -1)
        return self.proj(x)

# Run the attention block on the dummy input, then turn the result into one
# length-1 sequence entry per region.
self_attn = region_attention(num_heads=8, dim=784, qkv_bias=True, out_dim=256)
attended = self_attn(a)

# Drop the leading batch axis, then insert a singleton axis at position 1.
b = attended.squeeze(0).unsqueeze(1)

print(b.shape)

In [26]:
import torch
# Outputs with different leading sizes concatenate along axis 0.
region_counts = (20, 30, 12)
outputs = [torch.rand(count, 1, 256) for count in region_counts]
print(torch.cat(outputs, dim=0).shape)

torch.Size([62, 1, 256])


# region attention

In [25]:
from timm.models.layers import Mlp
class region_attention(nn.Module):
    """Self-attention across region feature maps.

    Each region feature map is downsampled by a stride-2 3x3 convolution and
    flattened into one token. The tokens attend to each other as a single
    sequence, are projected to ``out_dim``, layer-normalised and passed through
    a residual MLP (the ``Mlp`` imported at the top of this cell).

    Args:
        num_heads: Number of attention heads. Must be positive and divide ``dim``.
        dim: Token size, i.e. ``conv_channels * H' * W'`` of the conv output.
        qkv_bias: Whether the fused query/key/value linear layer has a bias.
        out_dim: Size of each output feature vector.
        mlp_ratio: Multiplier used to derive the MLP hidden width from ``dim``.
        in_channels: Channel count of the incoming feature maps.
        conv_channels: Channel count produced by the downsampling convolution.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self,num_heads=8,dim=784,qkv_bias=True,out_dim=256,mlp_ratio=2,
                 in_channels=256,conv_channels=16):
        super().__init__()
        # Fail at construction time: an incompatible (dim, num_heads) pair would
        # otherwise only surface as an opaque reshape error inside forward().
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive num_heads ({num_heads})"
            )

        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.conv = nn.Conv2d(in_channels=in_channels,out_channels=conv_channels,kernel_size=3,stride=2,padding=1)
        # A single linear layer produces q, k and v; they are split in forward().
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, out_dim)
        self.norm1 = nn.LayerNorm(out_dim)
        # NOTE(review): the hidden width is derived from `dim`, not `out_dim`,
        # although the MLP operates on `out_dim` features - confirm this is intended.
        self.mlp = Mlp(in_features=out_dim, hidden_features=int(dim * mlp_ratio), act_layer=nn.GELU)

    def forward(self,in_feature):
        """Compute one attended feature vector per region.

        Args:
            in_feature: Tensor of shape ``[N_region, in_channels, H, W]``.

        Returns:
            Tensor of shape ``[N_region, 1, out_dim]``.

        Raises:
            ValueError: If the flattened conv output size differs from ``dim``.
        """
        x = self.conv(in_feature)
        N_region,channels,H,W = x.shape

        token_dim = H * W * channels
        if token_dim != self.dim:
            raise ValueError(
                f"flattened conv output has {token_dim} features per region "
                f"({channels} x {H} x {W}), but the module was built with dim={self.dim}"
            )

        # Channels-last flatten: every region becomes one token of size `dim`;
        # all regions together form a single sequence (batch size 1).
        x = x.reshape(N_region,channels,H*W).permute(0,2,1)
        x = x.reshape(N_region,token_dim).unsqueeze(0)

        shortcut = x
        B, seq_len, _ = x.shape

        # [B, seq_len, 3*dim] -> [3, B, num_heads, seq_len, head_dim]
        qkv = self.qkv(x).reshape(B, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        # Scaled dot-product attention; matmul batches over the (B, num_heads) axes.
        attn = (q * self.scale) @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)

        # [B, num_heads, seq_len, head_dim] -> [B, seq_len, dim]
        x = (attn @ v).permute(0, 2, 1, 3).reshape(B, seq_len, -1)

        # Residual around attention is taken before projecting down to out_dim.
        x = shortcut + x
        x = self.norm1(self.proj(x))
        x = x + self.mlp(x)

        # [1, N_region, out_dim] -> [N_region, 1, out_dim]
        return x.squeeze(0).unsqueeze(1)
    
# Smoke test: 50 region feature maps in, one feature vector per region out.
region_batch_shape = (50, 256, 14, 14)
in_feature = torch.rand(region_batch_shape)

self_attn = region_attention()
out_feature = self_attn(in_feature)

print(out_feature.shape)

torch.Size([50, 1, 256])
