In [1]:
import torch
import torch.nn as nn
import math

In [2]:
class PatchEmbed(nn.Module):
    def __init__(self, img_size, patch_size, in_channels = 3, embed_dim = 768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d( in_channels, embed_dim, kernel_size = patch_size, stride = patch_size,)

    def forward(self, x):
        x = self.proj(x) # (n_samples, emed_dim, n_patches ** 0.5, n_patches ** 0.5)
        x = x.flatten(2) # (n_samples, emed_dim, n_patches)
        x = x.transpose(1,2) # n_samples, n_patches, embed_dim)
        return x

In [3]:
class PositionalEncoding(nn.Module):
    def __init__(self, dim, max_len=5000):
        super(PositionalEncoding, self).__init__()

        # Create a tensor of shape [max_len, dim] for encoding
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).float().unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, dim, 2).float() * -(math.log(10000.0) / dim))  # [dim/2]

        # Apply sin and cos functions for positional encoding
        pe[:, 0::2] = torch.sin(position * div_term)  # even indices (sine)
        pe[:, 1::2] = torch.cos(position * div_term)  # odd indices (cosine)

        pe = pe.unsqueeze(0)  # Shape becomes [1, max_len, dim]
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Add positional encoding to the input tensor
        return x + self.pe[:, :x.size(1),:]  # x.size(1) is the length of the sequence

In [4]:
class Attention(nn.Module):
    def __init__(self, dim, n_heads=12, qkv_bias=True, attn_p=0., proj_p=0.):
        super().__init__()
        self.n_heads = n_heads
        self.dim = dim
        self.head_dim = dim//n_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim*3, bias = qkv_bias)
        self.attn_drop = nn.Dropout(attn_p)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x):
        n_samples, n_tokens, dim = x.shape 
 
        if dim != self.dim:
            raise ValueError
            
        qkv = self.qkv(x) # (n_samples, n_patches +1, 3 * dim)
        qkv = qkv.reshape( n_samples, n_tokens, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4) # (3, n_samples, n_heads, n_patches + 1,head_dim )
        q, k, v = qkv[0], qkv[1], qkv[2]

        k_t = k.transpose(-2, -1) # (n_samples, n_heads, head_dim, n_patches +1)

        dp = (q @ k_t) * self.scale

        attn = dp.softmax(dim=-1)
        attn = self.attn_drop(attn)
        weighted_avg = attn @ v
        weighted_avg = weighted_avg.transpose(1, 2)
        weighted_avg = weighted_avg.flatten(2)

        x = self.proj(weighted_avg)
        x = self.proj_drop(x)

        return x

In [12]:
class Attention_qkv(nn.Module):
    def __init__(self, dim, n_heads=12, qkv_bias=True, attn_p=0., proj_p=0.):
        super().__init__()
        self.n_heads = n_heads
        self.dim = dim
        self.head_dim = dim//n_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim*3, bias = qkv_bias)

    def forward(self, x):
        n_samples, n_tokens, dim = x.shape
 
        if dim != self.dim:
            raise ValueError
            
        qkv = self.qkv(x) # (n_samples, n_patches +1, 3 * dim)
        qkv = qkv.reshape( n_samples, n_tokens, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4) # (3, n_samples, n_heads, n_patches + 1,head_dim )
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = q.unsqueeze(3) @ k[:,:,self.attn_idx].transpose(-1,-2) #B,nh,L,1,K^2
        attn = attn + self.relative_bias[self.bias_idx].permute(2, 0, 1).unsqueeze(2)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v[:,:,self.attn_idx]).squeeze(3).transpose(-1,-2).contiguous().view(B,C,H,W)
        return x
        
    def get_bias_idx(self,H,W):
        num_repeat_h = torch.ones(self.window_size,dtype=torch.long)
        num_repeat_w = torch.ones(self.window_size,dtype=torch.long)
        num_repeat_h[self.window_size//2] = H-(self.window_size-1)
        num_repeat_w[self.window_size//2] = W-(self.window_size-1)
        bias_hw = (self.idx_h.repeat_interleave(num_repeat_h).unsqueeze(-1) * (2*self.window_size-1)) + self.idx_w.repeat_interleave(num_repeat_w)
        bias_idx = bias_hw.unsqueeze(-1) + self.idx_k
        return bias_idx.view(-1,self.window_size**2)
    
    def get_attn_idx(self,H,W):
        H_ = H - (self.window_size - 1)
        W_ = W - (self.window_size - 1)
        attn_idx = torch.arange(0,H_*W_,dtype=torch.float).view(1,1,H_,W_)
        attn_idx = self.pad_idx(attn_idx).view(-1).type(torch.long)
        attn_idx = self.get_unfold_idx(H,W)[attn_idx]
        return attn_idx
    
    def get_unfold_idx(self,H,W):
        H_ = H-(self.window_size-1)
        W_ = W-(self.window_size-1)
        h_idx = torch.arange(W_).repeat(H_)
        w_idx = torch.arange(H_).repeat_interleave(W_) * W
        k_idx_1 = torch.arange(self.window_size).repeat(self.window_size)
        k_idx_2 = torch.arange(self.window_size).repeat_interleave(self.window_size) * W
        k_idx = k_idx_1 + k_idx_2
        hw_idx = h_idx + w_idx
        unfold_idx = hw_idx[:,None] + k_idx
        return unfold_idx
    def set_input_size(self,input_size):
        H,W = input_size
        self.H,self.W = H,W
        assert H >= self.window_size and W >= self.window_size,'input size must not be smaller than window size'
        attn_idx = self.get_attn_idx(H,W)
        bias_idx = self.get_bias_idx(H,W)
        self.register_buffer("attn_idx", attn_idx)
        self.register_buffer("bias_idx",bias_idx)
    

In [25]:
x = torch.rand(10,3,224,224)
patch = PatchEmbed(224,16)
x  = patch(x)
x.shape

torch.Size([10, 196, 768])

In [26]:
pos = PositionalEncoding(768,196)
x = pos(x)
x.shape

torch.Size([10, 196, 768])

In [31]:
x = x.reshape(10,196,1,768)
x.shape

torch.Size([10, 196, 1, 768])

In [32]:
import natten
from natten import NeighborhoodAttention2D as NeighborhoodAttention
attn = NeighborhoodAttention(
            768,
            kernel_size=7,
            dilation=2,
            num_heads=12,
        )

In [33]:
attn(x)

RuntimeError: Input axes must be greater than or equal to the product of kernel size and dilation. Got kernel size 7, dilation 2, but dimension size was 1.
Exception raised from CheckArgsAgainstDim at /natten-build/csrc/./include/natten/pytorch/helpers.h:101 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7fd4b41af4d7 in /root/miniconda3/envs/tf/lib/python3.9/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7fd4b417936b in /root/miniconda3/envs/tf/lib/python3.9/site-packages/torch/lib/libc10.so)
frame #2: natten::pytorch::CheckArgsAgainstDim(int, int, int) + 0x9a (0x7fd332b2c95d in /root/miniconda3/envs/tf/lib/python3.9/site-packages/natten/libnatten.cpython-39-x86_64-linux-gnu.so)
frame #3: natten::pytorch::CheckArgsAgainstDim(std::tuple<int, int> const&, std::tuple<int, int> const&, std::tuple<int, int> const&) + 0x85 (0x7fd332b32e0b in /root/miniconda3/envs/tf/lib/python3.9/site-packages/natten/libnatten.cpython-39-x86_64-linux-gnu.so)
frame #4: natten::pytorch::na2d_qk_forward(at::Tensor&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, std::tuple<int, int> const&, std::tuple<int, int> const&, std::tuple<bool, bool> const&) + 0x330 (0x7fd332b30ee8 in /root/miniconda3/envs/tf/lib/python3.9/site-packages/natten/libnatten.cpython-39-x86_64-linux-gnu.so)
frame #5: <unknown function> + 0x59dc1d4 (0x7fd331b8f1d4 in /root/miniconda3/envs/tf/lib/python3.9/site-packages/natten/libnatten.cpython-39-x86_64-linux-gnu.so)
frame #6: <unknown function> + 0x59d63ef (0x7fd331b893ef in /root/miniconda3/envs/tf/lib/python3.9/site-packages/natten/libnatten.cpython-39-x86_64-linux-gnu.so)
frame #7: <unknown function> + 0x59d1932 (0x7fd331b84932 in /root/miniconda3/envs/tf/lib/python3.9/site-packages/natten/libnatten.cpython-39-x86_64-linux-gnu.so)
frame #8: <unknown function> + 0x59d1c7c (0x7fd331b84c7c in /root/miniconda3/envs/tf/lib/python3.9/site-packages/natten/libnatten.cpython-39-x86_64-linux-gnu.so)
frame #9: <unknown function> + 0x301d127 (0x7fd32f1d0127 in /root/miniconda3/envs/tf/lib/python3.9/site-packages/natten/libnatten.cpython-39-x86_64-linux-gnu.so)
frame #10: /root/miniconda3/envs/tf/bin/python() [0x5072d7]
frame #11: _PyObject_MakeTpCall + 0x2ec (0x4f06ac in /root/miniconda3/envs/tf/bin/python)
frame #12: _PyEval_EvalFrameDefault + 0x526b (0x4ecbfb in /root/miniconda3/envs/tf/bin/python)
frame #13: /root/miniconda3/envs/tf/bin/python() [0x4f8053]
frame #14: _PyEval_EvalFrameDefault + 0x3764 (0x4eb0f4 in /root/miniconda3/envs/tf/bin/python)
frame #15: /root/miniconda3/envs/tf/bin/python() [0x4e6a8a]
frame #16: _PyFunction_Vectorcall + 0xd4 (0x4f7d84 in /root/miniconda3/envs/tf/bin/python)
frame #17: THPFunction_apply(_object*, _object*) + 0x116a (0x7fd492ddf1ca in /root/miniconda3/envs/tf/lib/python3.9/site-packages/torch/lib/libtorch_python.so)
frame #18: /root/miniconda3/envs/tf/bin/python() [0x507300]
frame #19: PyObject_Call + 0x158 (0x5057c8 in /root/miniconda3/envs/tf/bin/python)
frame #20: _PyEval_EvalFrameDefault + 0x5baf (0x4ed53f in /root/miniconda3/envs/tf/bin/python)
frame #21: /root/miniconda3/envs/tf/bin/python() [0x4e6a8a]
frame #22: /root/miniconda3/envs/tf/bin/python() [0x504f6c]
frame #23: _PyEval_EvalFrameDefault + 0x4d44 (0x4ec6d4 in /root/miniconda3/envs/tf/bin/python)
frame #24: /root/miniconda3/envs/tf/bin/python() [0x4e6a8a]
frame #25: _PyFunction_Vectorcall + 0xd4 (0x4f7d84 in /root/miniconda3/envs/tf/bin/python)
frame #26: _PyEval_EvalFrameDefault + 0x1231 (0x4e8bc1 in /root/miniconda3/envs/tf/bin/python)
frame #27: /root/miniconda3/envs/tf/bin/python() [0x4f8053]
frame #28: /root/miniconda3/envs/tf/bin/python() [0x505071]
frame #29: _PyEval_EvalFrameDefault + 0x3764 (0x4eb0f4 in /root/miniconda3/envs/tf/bin/python)
frame #30: /root/miniconda3/envs/tf/bin/python() [0x4e6a8a]
frame #31: _PyObject_FastCallDictTstate + 0x13e (0x4eff1e in /root/miniconda3/envs/tf/bin/python)
frame #32: _PyObject_Call_Prepend + 0x66 (0x502cc6 in /root/miniconda3/envs/tf/bin/python)
frame #33: /root/miniconda3/envs/tf/bin/python() [0x5cb1e3]
frame #34: _PyObject_MakeTpCall + 0x2ec (0x4f06ac in /root/miniconda3/envs/tf/bin/python)
frame #35: _PyEval_EvalFrameDefault + 0x4c84 (0x4ec614 in /root/miniconda3/envs/tf/bin/python)
frame #36: /root/miniconda3/envs/tf/bin/python() [0x4e6a8a]
frame #37: _PyEval_EvalCodeWithName + 0x47 (0x4e6717 in /root/miniconda3/envs/tf/bin/python)
frame #38: PyEval_EvalCodeEx + 0x39 (0x4e66c9 in /root/miniconda3/envs/tf/bin/python)
frame #39: PyEval_EvalCode + 0x1b (0x59398b in /root/miniconda3/envs/tf/bin/python)
frame #40: /root/miniconda3/envs/tf/bin/python() [0x5985a1]
frame #41: /root/miniconda3/envs/tf/bin/python() [0x4f87f4]
frame #42: _PyEval_EvalFrameDefault + 0x3c9 (0x4e7d59 in /root/miniconda3/envs/tf/bin/python)
frame #43: /root/miniconda3/envs/tf/bin/python() [0x50ba7c]
frame #44: _PyEval_EvalFrameDefault + 0x5e5c (0x4ed7ec in /root/miniconda3/envs/tf/bin/python)
frame #45: /root/miniconda3/envs/tf/bin/python() [0x50ba7c]
frame #46: _PyEval_EvalFrameDefault + 0x5e5c (0x4ed7ec in /root/miniconda3/envs/tf/bin/python)
frame #47: /root/miniconda3/envs/tf/bin/python() [0x50ba7c]
frame #48: /root/miniconda3/envs/tf/bin/python() [0x5035fd]
frame #49: _PyEval_EvalFrameDefault + 0x686 (0x4e8016 in /root/miniconda3/envs/tf/bin/python)
frame #50: /root/miniconda3/envs/tf/bin/python() [0x4f8053]
frame #51: _PyEval_EvalFrameDefault + 0x3c9 (0x4e7d59 in /root/miniconda3/envs/tf/bin/python)
frame #52: /root/miniconda3/envs/tf/bin/python() [0x4f8053]
frame #53: _PyEval_EvalFrameDefault + 0x686 (0x4e8016 in /root/miniconda3/envs/tf/bin/python)
frame #54: /root/miniconda3/envs/tf/bin/python() [0x4e6a8a]
frame #55: /root/miniconda3/envs/tf/bin/python() [0x504fdd]
frame #56: PyObject_Call + 0xb4 (0x505724 in /root/miniconda3/envs/tf/bin/python)
frame #57: _PyEval_EvalFrameDefault + 0x3764 (0x4eb0f4 in /root/miniconda3/envs/tf/bin/python)
frame #58: /root/miniconda3/envs/tf/bin/python() [0x4e6a8a]
frame #59: /root/miniconda3/envs/tf/bin/python() [0x504fdd]
frame #60: _PyEval_EvalFrameDefault + 0x1231 (0x4e8bc1 in /root/miniconda3/envs/tf/bin/python)
frame #61: /root/miniconda3/envs/tf/bin/python() [0x50ba7c]
frame #62: _PyEval_EvalFrameDefault + 0x5e5c (0x4ed7ec in /root/miniconda3/envs/tf/bin/python)
frame #63: /root/miniconda3/envs/tf/bin/python() [0x50ba7c]


In [16]:
attn = Attention_qkv(768)
q,k,v = attn(x)
q.shape, k.shape, v.shape

AttributeError: 'Attention_qkv' object has no attribute 'attn_idx'

In [194]:
k[:,:,0].unsqueeze(2).shape


torch.Size([10, 12, 1, 64])

In [196]:
q = q * (attn.dim)**2

In [198]:
q.shape

torch.Size([10, 12, 196, 64])

In [200]:
window_size =7
H,W = 14,14
H_ = H-(window_size-1)
W_ = W-(window_size-1)
h_idx = torch.arange(W_).repeat(H_)
w_idx = torch.arange(H_).repeat_interleave(W_) * W
hw_idx = h_idx + w_idx
(h_idx, w_idx, hw_idx)

(tensor([0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7,
         0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7,
         0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]),
 tensor([ 0,  0,  0,  0,  0,  0,  0,  0, 14, 14, 14, 14, 14, 14, 14, 14, 28, 28,
         28, 28, 28, 28, 28, 28, 42, 42, 42, 42, 42, 42, 42, 42, 56, 56, 56, 56,
         56, 56, 56, 56, 70, 70, 70, 70, 70, 70, 70, 70, 84, 84, 84, 84, 84, 84,
         84, 84, 98, 98, 98, 98, 98, 98, 98, 98]),
 tensor([  0,   1,   2,   3,   4,   5,   6,   7,  14,  15,  16,  17,  18,  19,
          20,  21,  28,  29,  30,  31,  32,  33,  34,  35,  42,  43,  44,  45,
          46,  47,  48,  49,  56,  57,  58,  59,  60,  61,  62,  63,  70,  71,
          72,  73,  74,  75,  76,  77,  84,  85,  86,  87,  88,  89,  90,  91,
          98,  99, 100, 101, 102, 103, 104, 105]))

In [201]:
k_idx_1 = torch.arange(window_size).repeat(window_size)
k_idx_2 = torch.arange(window_size).repeat_interleave(window_size) * W
k_idx = k_idx_1 + k_idx_2
(k_idx_1, k_idx_2, k_idx)

(tensor([0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2,
         3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5,
         6]),
 tensor([ 0,  0,  0,  0,  0,  0,  0, 14, 14, 14, 14, 14, 14, 14, 28, 28, 28, 28,
         28, 28, 28, 42, 42, 42, 42, 42, 42, 42, 56, 56, 56, 56, 56, 56, 56, 70,
         70, 70, 70, 70, 70, 70, 84, 84, 84, 84, 84, 84, 84]),
 tensor([ 0,  1,  2,  3,  4,  5,  6, 14, 15, 16, 17, 18, 19, 20, 28, 29, 30, 31,
         32, 33, 34, 42, 43, 44, 45, 46, 47, 48, 56, 57, 58, 59, 60, 61, 62, 70,
         71, 72, 73, 74, 75, 76, 84, 85, 86, 87, 88, 89, 90]))

In [202]:
unfold_idx = hw_idx.unsqueeze(1) + k_idx #hw_idx[:,None] 
unfold_idx

tensor([[  0,   1,   2,  ...,  88,  89,  90],
        [  1,   2,   3,  ...,  89,  90,  91],
        [  2,   3,   4,  ...,  90,  91,  92],
        ...,
        [103, 104, 105,  ..., 191, 192, 193],
        [104, 105, 106,  ..., 192, 193, 194],
        [105, 106, 107,  ..., 193, 194, 195]])

In [203]:
attn_idx = torch.arange(0,H_*W_).float().view(1,1,H_,W_)
attn_idx

tensor([[[[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11., 12., 13., 14., 15.],
          [16., 17., 18., 19., 20., 21., 22., 23.],
          [24., 25., 26., 27., 28., 29., 30., 31.],
          [32., 33., 34., 35., 36., 37., 38., 39.],
          [40., 41., 42., 43., 44., 45., 46., 47.],
          [48., 49., 50., 51., 52., 53., 54., 55.],
          [56., 57., 58., 59., 60., 61., 62., 63.]]]])

In [204]:
pad_idx = nn.ReplicationPad2d(window_size//2)
pad_idx

ReplicationPad2d((3, 3, 3, 3))

In [205]:
attn_idx = pad_idx(attn_idx).view(-1).type(torch.long)
attn_idx

tensor([ 0,  0,  0,  0,  1,  2,  3,  4,  5,  6,  7,  7,  7,  7,  0,  0,  0,  0,
         1,  2,  3,  4,  5,  6,  7,  7,  7,  7,  0,  0,  0,  0,  1,  2,  3,  4,
         5,  6,  7,  7,  7,  7,  0,  0,  0,  0,  1,  2,  3,  4,  5,  6,  7,  7,
         7,  7,  8,  8,  8,  8,  9, 10, 11, 12, 13, 14, 15, 15, 15, 15, 16, 16,
        16, 16, 17, 18, 19, 20, 21, 22, 23, 23, 23, 23, 24, 24, 24, 24, 25, 26,
        27, 28, 29, 30, 31, 31, 31, 31, 32, 32, 32, 32, 33, 34, 35, 36, 37, 38,
        39, 39, 39, 39, 40, 40, 40, 40, 41, 42, 43, 44, 45, 46, 47, 47, 47, 47,
        48, 48, 48, 48, 49, 50, 51, 52, 53, 54, 55, 55, 55, 55, 56, 56, 56, 56,
        57, 58, 59, 60, 61, 62, 63, 63, 63, 63, 56, 56, 56, 56, 57, 58, 59, 60,
        61, 62, 63, 63, 63, 63, 56, 56, 56, 56, 57, 58, 59, 60, 61, 62, 63, 63,
        63, 63, 56, 56, 56, 56, 57, 58, 59, 60, 61, 62, 63, 63, 63, 63])

In [206]:
attn_idx  = unfold_idx[attn_idx]
attn_idx

tensor([[  0,   1,   2,  ...,  88,  89,  90],
        [  0,   1,   2,  ...,  88,  89,  90],
        [  0,   1,   2,  ...,  88,  89,  90],
        ...,
        [105, 106, 107,  ..., 193, 194, 195],
        [105, 106, 107,  ..., 193, 194, 195],
        [105, 106, 107,  ..., 193, 194, 195]])

In [207]:
bias = torch.arange(window_size ** 2).repeat(H * W, 1)
bias

tensor([[ 0,  1,  2,  ..., 46, 47, 48],
        [ 0,  1,  2,  ..., 46, 47, 48],
        [ 0,  1,  2,  ..., 46, 47, 48],
        ...,
        [ 0,  1,  2,  ..., 46, 47, 48],
        [ 0,  1,  2,  ..., 46, 47, 48],
        [ 0,  1,  2,  ..., 46, 47, 48]])

In [209]:
attn_idx.shape

torch.Size([196, 49])

In [210]:
k.shape

torch.Size([10, 12, 196, 64])

In [162]:
k[:,:,attn_idx]

tensor([[[[ 3.9289e+00,  3.8185e+00,  3.1616e+00,  ...,  3.4523e+00,
            3.2540e+00,  2.7054e+00],
          [ 3.1262e+00,  2.8160e+00,  2.5354e+00,  ...,  2.6834e+00,
            2.1798e+00,  1.7825e+00],
          [ 3.0375e+00,  2.4665e+00,  2.0519e+00,  ...,  2.8424e+00,
            2.4302e+00,  1.8020e+00],
          ...,
          [-8.9167e-01, -6.3551e-01, -5.2425e-01,  ...,  2.3182e+00,
            2.8595e+00,  1.8840e+00],
          [-1.4740e-01, -1.7072e-01,  8.0530e-02,  ...,  2.7605e+00,
            3.2866e+00,  2.1529e+00],
          [ 4.4728e-01,  3.8404e-01,  4.9118e-01,  ...,  3.6537e+00,
            3.7063e+00,  2.6141e+00]],

         [[-1.0858e+00, -1.1290e+00, -8.9149e-01,  ..., -1.1985e-01,
            5.2028e-01, -1.1441e+00],
          [-8.4863e-01, -7.7516e-01, -3.7013e-01,  ..., -5.9547e-01,
            3.4185e-01, -1.5049e+00],
          [-6.8205e-01, -5.9150e-01, -2.1612e-01,  ..., -7.3382e-01,
            4.2068e-01, -1.6308e+00],
          ...,
     

In [220]:
k[:,:,attn_idx].transpose(-1,-2).shape

torch.Size([10, 12, 196, 64, 49])

In [212]:
k[:,:,attn_idx].shape

torch.Size([10, 12, 196, 49, 64])

In [216]:
q.unsqueeze(3).shape

torch.Size([10, 12, 196, 1, 64])

In [223]:
attn = q.unsqueeze(3) @ k[:,:,attn_idx].transpose(-1,-2)
attn.shape

torch.Size([10, 12, 196, 1, 49])

In [225]:
num_repeat_h = torch.ones(window_size,dtype=torch.long)
num_repeat_w = torch.ones(window_size,dtype=torch.long)
num_repeat_h, num_repeat_w

(tensor([1, 1, 1, 1, 1, 1, 1]), tensor([1, 1, 1, 1, 1, 1, 1]))

In [226]:
num_repeat_h[window_size//2] = H-(window_size-1)
num_repeat_w[window_size//2] = W-(window_size-1)
num_repeat_h, num_repeat_w

(tensor([1, 1, 1, 8, 1, 1, 1]), tensor([1, 1, 1, 8, 1, 1, 1]))

In [None]:
class NeighborhoodAttention(nn.Module): #It can only use static size as input,but you can define a new input size if you wish.
    def __init__(self,input_size, dim, num_heads,window_size=7, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert window_size%2 == 1,'windowsize must be odd.'
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Conv2d(dim,dim*3,1, bias=qkv_bias)
        self.proj = nn.Conv2d(dim, dim, 1)
        self.proj_drop = nn.Dropout(proj_drop)
        self.attn_drop = nn.Dropout(attn_drop)
        self.pad_idx = nn.ReplicationPad2d(self.window_size//2)
        self.relative_bias = nn.Parameter(torch.zeros((2*self.window_size-1)**2,num_heads))
        trunc_normal_(self.relative_bias, std=.02)
        self.idx_h = torch.arange(0,window_size)
        self.idx_w = torch.arange(0,window_size)
        self.idx_k = ((self.idx_h.unsqueeze(-1) * (2*self.window_size-1)) + self.idx_w).view(-1)
        self.set_input_size(input_size)
        
    def forward(self, x):
        x = self.attention(x)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def attention(self,x):
        B,C,H,W = x.shape
        assert H >= self.window_size and W >= self.window_size,'input size must not be smaller than window size'
        qkv = self.qkv(x).view(B, 3,self.num_heads,C//self.num_heads,H*W).permute(1, 0, 2, 4, 3)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = q.unsqueeze(3) @ k[:,:,self.attn_idx].transpose(-1,-2) #B,nh,L,1,K^2
        attn = attn + self.relative_bias[self.bias_idx].permute(2, 0, 1).unsqueeze(2)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v[:,:,self.attn_idx]).squeeze(3).transpose(-1,-2).contiguous().view(B,C,H,W)
        return x
        
    def get_bias_idx(self,H,W):
        num_repeat_h = torch.ones(self.window_size,dtype=torch.long)
        num_repeat_w = torch.ones(self.window_size,dtype=torch.long)
        num_repeat_h[self.window_size//2] = H-(self.window_size-1)
        num_repeat_w[self.window_size//2] = W-(self.window_size-1)
        bias_hw = (self.idx_h.repeat_interleave(num_repeat_h).unsqueeze(-1) * (2*self.window_size-1)) + self.idx_w.repeat_interleave(num_repeat_w)
        bias_idx = bias_hw.unsqueeze(-1) + self.idx_k
        return bias_idx.view(-1,self.window_size**2)
    
    def get_attn_idx(self,H,W):
        H_ = H - (self.window_size - 1)
        W_ = W - (self.window_size - 1)
        attn_idx = torch.arange(0,H_*W_,dtype=torch.float).view(1,1,H_,W_)
        attn_idx = self.pad_idx(attn_idx).view(-1).type(torch.long)
        attn_idx = self.get_unfold_idx(H,W)[attn_idx]
        return attn_idx
    
    def get_unfold_idx(self,H,W):
        H_ = H-(self.window_size-1)
        W_ = W-(self.window_size-1)
        h_idx = torch.arange(W_).repeat(H_)
        w_idx = torch.arange(H_).repeat_interleave(W_) * W
        k_idx_1 = torch.arange(self.window_size).repeat(self.window_size)
        k_idx_2 = torch.arange(self.window_size).repeat_interleave(self.window_size) * W
        k_idx = k_idx_1 + k_idx_2
        hw_idx = h_idx + w_idx
        unfold_idx = hw_idx[:,None] + k_idx
        return unfold_idx
    
    def set_input_size(self,input_size):
        H,W = input_size
        self.H,self.W = H,W
        assert H >= self.window_size and W >= self.window_size,'input size must not be smaller than window size'
        attn_idx = self.get_attn_idx(H,W)
        bias_idx = self.get_bias_idx(H,W)
        self.register_buffer("attn_idx", attn_idx)
        self.register_buffer("bias_idx",bias_idx)
        
class NATLayer(nn.Module):
    def __init__(self,input_size, dim, num_heads,window_size=7,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=Channel_Layernorm, layer_scale=None):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.norm1 = norm_layer(dim)
        self.attn = NeighborhoodAttention(input_size, dim, num_heads,window_size,qkv_bias, qk_scale, attn_drop, drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        
    def forward(self, x):
        shortcut = x
        x = self.norm1(x)
        x = self.attn(x)
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
    
    def set_input_size(self,input_size):
        self.attn.set_input_size(input_size)
        
def test():
    print('it is cpu')
    model = NATLayer((28,28),128,4)
    img = torch.rand(2,128,56,56)
    try:
        print(model(img).shape)
    except:
        print('error')
        model.set_input_size((56,56))
        print(model(img).shape)
    print('cpu_success\n')

def test_cuda():
    print('it is cuda')
    model = NATLayer((28,28),128,4).cuda()
    img = torch.rand(2,128,56,56).cuda()
    try:
        print(model(img).shape)
    except:
        print('error')
        model.set_input_size((56,56))
        print(model(img).shape)
    print('success')
    print('cuda_success\n')
        
if __name__ == '__main__' :
    test()
    if torch.cuda.is_available():
        test_cuda()

In [None]:
att = NeighborhoodAttention(
            dim= 64,
            kernel_size=7,
            dilation=[1],
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            **extra_args,
        )
di

In [120]:
import torch

# Define two vectors
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])

# Compute dot product using einsum
dot_product = torch.einsum("i,i->", a, b)

print(dot_product)  # Output: tensor(32.)


tensor(32.)


In [121]:
a@b

tensor(32.)

In [122]:
A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
B = torch.tensor([[5.0, 6.0], [7.0, 8.0]])

# Matrix multiplication using einsum
result = torch.einsum("ij,jk->ik", A, B)

print(result)

tensor([[19., 22.],
        [43., 50.]])


In [123]:
A@B

tensor([[19., 22.],
        [43., 50.]])

In [127]:
torch.manual_seed(0)
x1, x2 = torch.rand(2,2), torch.rand(2,2)
x1,x2

(tensor([[0.4963, 0.7682],
         [0.0885, 0.1320]]),
 tensor([[0.3074, 0.6341],
         [0.4901, 0.8964]]))

In [128]:
result = torch.einsum("ij,jk->ik", x1,x2)

print(result)

tensor([[0.5291, 1.0033],
        [0.0919, 0.1745]])


In [129]:
x1@x2

tensor([[0.5291, 1.0033],
        [0.0919, 0.1745]])

In [131]:
torch.manual_seed(0)
x1, x2 = torch.rand(1,5), torch.rand(5,1)
x1,x2

(tensor([[0.4963, 0.7682, 0.0885, 0.1320, 0.3074]]),
 tensor([[0.6341],
         [0.4901],
         [0.8964],
         [0.4556],
         [0.6323]]))

In [137]:
result = torch.einsum("ij,jk->", x1,x2)

print(result)

tensor(1.0250)


In [135]:
x1@x2

tensor([[1.0250]])

In [138]:
A = torch.randn(5, 3, 4)  # Batch of 5 matrices (3x4)
b = torch.randn(5, 4)
A,b

(tensor([[[ 0.4913, -0.2041, -0.0885,  0.5239],
          [-0.6659,  0.8504, -1.3527, -1.6959],
          [ 0.7854,  0.9928, -0.1932, -0.3090]],
 
         [[ 0.5026, -0.8594,  0.7502, -0.5855],
          [ 1.4437,  0.2660,  0.1665,  0.8744],
          [-0.1435, -0.1116, -0.6136,  0.0316]],
 
         [[ 2.0050,  0.0537,  0.6181, -0.4128],
          [-0.8411, -2.3160, -0.1023,  0.7924],
          [ 0.5627,  0.2596, -0.1740, -0.6787]],
 
         [[ 0.9383,  0.4889, -0.6731,  0.8728],
          [-1.2001, -0.0048, -0.5181, -0.3067],
          [-0.4731,  0.3356,  1.5091,  2.0820]],
 
         [[ 1.7067,  2.3804, -1.1256, -0.3170],
          [-0.1407,  0.8058,  0.3276, -0.7607],
          [-1.5991,  0.0185, -0.7504,  0.1854]]]),
 tensor([[ 1.0395,  0.3582, -0.0033, -0.5344],
         [ 0.2823,  0.4342, -0.8025, -1.2952],
         [-0.7502, -1.3120, -0.2188, -2.4351],
         [-0.4288,  0.2329,  0.7969, -0.1848],
         [-0.3701, -1.2103, -0.6227, -0.4637]]))

In [163]:
import numpy as np

def sparse_attention(Q, K, V, rho):
    """
    Compute sparse attention where each query only attends to a subset of keys.
    
    Parameters:
    Q: (n, d) Query matrix
    K: (m, d) Key matrix
    V: (m, d) Value matrix
    rho: function that takes index i and returns selected key indices
    
    Returns:
    A: (n, d) Attention output matrix
    """
    n, d = Q.shape
    A = np.zeros((n, d))  # Initialize output matrix
    
    for i in range(n):
        selected_indices = rho(i)  # Get the subset of keys for query i
        K_subset = K[selected_indices]  # Select keys
        V_subset = V[selected_indices]  # Select values
        
        # Compute attention scores
        scores = Q[i] @ K_subset.T  # (1, d) @ (d, k) -> (1, k)
        
        # Apply softmax to get attention weights
        alpha = np.exp(scores - np.max(scores))  # Stability trick
        alpha /= np.sum(alpha)  # Normalize
        
        # Compute weighted sum of values
        A[i] = alpha @ V_subset  # (1, k) @ (k, d) -> (1, d)
    
    return A

# Example Usage
np.random.seed(42)
n, m, d = 5, 10, 4  # 5 queries, 10 keys/values, embedding dim 4
Q = np.random.rand(n, d)
K = np.random.rand(m, d)
V = np.random.rand(m, d)

# Define rho: selecting 3 random keys for each query
def rho(i):
    return np.random.choice(m, 3, replace=False)  # Select 3 keys per query

A = sparse_attention(Q, K, V, rho)
print("Sparse Attention Output:\n", A)


Sparse Attention Output:
 [[0.66065966 0.56824659 0.53219128 0.49005517]
 [0.67007638 0.36214333 0.67851005 0.3613909 ]
 [0.2659878  0.56335287 0.48349048 0.53660343]
 [0.43979126 0.6548967  0.32939767 0.58181043]
 [0.58531914 0.48864323 0.10726895 0.61624787]]


In [169]:
dialtion = [1,1,2,4,1]
[i for i in dialtion for j in range(0,10,2)]

[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 1, 1, 1, 1, 1]

In [175]:
mat = torch.Tensor([1,0,0])
mat

tensor([1., 0., 0.])

In [179]:
mat /= torch.sum(mat)
mat

tensor([1., 0., 0.])

In [182]:
torch.softmax(mat,0)

tensor([0.5761, 0.2119, 0.2119])

In [1]:
pip install transformers

Note: you may need to restart the kernel to use updated packages.


In [1]:
from transformers import DinatConfig, DinatModel



In [2]:
config = {
    "patch_size": 4,
    "num_channels": 3,
    "embed_dim": 64,
    "depths": [3, 4, 6, 5],
    "num_heads": [2, 4, 8, 16],
    "kernel_size": 7,
    "dilations": [[1, 8, 1], [1, 4, 1, 4], [1, 2, 1, 2, 1, 2], [1, 1, 1, 1, 1]],
    "mlp_ratio": 3.0,
    "qkv_bias": True,
    "hidden_dropout_prob": 0.0,
    "attention_probs_dropout_prob": 0.0,
    "drop_path_rate": 0.1,
    "hidden_act": "gelu",
    "initializer_range": 0.02,
    "layer_norm_eps": 1e-05,
    "layer_scale_init_value": 0.0,
    "out_features": None,
    "out_indices": None
}


In [3]:
DinatModel.from_pretrained("shi-labs/dinat-mini-in1k-224")

ImportError: 
DinatModel requires the natten library but it was not found in your environment. You can install it by referring to:
shi-labs.com/natten . You can also install it with pip (may take longer to build):
`pip install natten`. Please note that you may need to restart your runtime after installation.


In [12]:
pip install natten

Collecting natten
  Downloading natten-0.17.4.tar.gz (10.9 MB)
     ---------------------------------------- 0.0/10.9 MB ? eta -:--:--
     ----------- ---------------------------- 3.1/10.9 MB 61.4 MB/s eta 0:00:01
     --------------- ------------------------ 4.2/10.9 MB 16.7 MB/s eta 0:00:01
     ------------------- -------------------- 5.2/10.9 MB 11.4 MB/s eta 0:00:01
     ----------------------- ---------------- 6.3/10.9 MB 9.0 MB/s eta 0:00:01
     --------------------------- ------------ 7.3/10.9 MB 7.7 MB/s eta 0:00:01
     ------------------------------ --------- 8.4/10.9 MB 7.1 MB/s eta 0:00:01
     ---------------------------------- ----- 9.4/10.9 MB 6.6 MB/s eta 0:00:01
     -------------------------------------- - 10.5/10.9 MB 6.4 MB/s eta 0:00:01
     ---------------------------------------- 10.9/10.9 MB 5.8 MB/s eta 0:00:00
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Building wheels for collected packages: natten


  error: subprocess-exited-with-error
  
  × python setup.py bdist_wheel did not run successfully.
  │ exit code: 1
  ╰─> [102 lines of output]
      Building NATTEN for CPU ONLY.
      Number of workers: 5
      running bdist_wheel
      running build
      running build_py
      creating build
      creating build\lib.win-amd64-cpython-39
      creating build\lib.win-amd64-cpython-39\natten
      copying src\natten\context.py -> build\lib.win-amd64-cpython-39\natten
      copying src\natten\experimental.py -> build\lib.win-amd64-cpython-39\natten
      copying src\natten\flops.py -> build\lib.win-amd64-cpython-39\natten
      copying src\natten\functional.py -> build\lib.win-amd64-cpython-39\natten
      copying src\natten\na1d.py -> build\lib.win-amd64-cpython-39\natten
      copying src\natten\na2d.py -> build\lib.win-amd64-cpython-39\natten
      copying src\natten\na3d.py -> build\lib.win-amd64-cpython-39\natten
      copying src\natten\natten1d.py -> build\lib.win-amd64-cpython-

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DilatedNeighborhoodAttention(nn.Module):
    def __init__(self, dim, kernel_size=7, dilation=1):
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.padding = (kernel_size // 2) * dilation
        
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)
    
    def forward(self, x):
        B, H, W, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        
        # Unfold input to extract neighborhoods with dilation
        unfolded_k = F.unfold(k.permute(0, 3, 1, 2),
                              kernel_size=self.kernel_size,
                              dilation=self.dilation,
                              padding=self.padding)
        unfolded_v = F.unfold(v.permute(0, 3, 1, 2),
                              kernel_size=self.kernel_size,
                              dilation=self.dilation,
                              padding=self.padding)
        
        # Reshape for attention
        unfolded_k = unfolded_k.view(B, C, self.kernel_size ** 2, H, W).permute(0, 3, 4, 2, 1)
        unfolded_v = unfolded_v.view(B, C, self.kernel_size ** 2, H, W).permute(0, 3, 4, 2, 1)
        
        q = q.unsqueeze(-2)  # (B, H, W, 1, C)
        attn = (q * unfolded_k).sum(-1) / (C ** 0.5)
        attn = attn.softmax(dim=-1)
        
        out = (attn.unsqueeze(-1) * unfolded_v).sum(-2)
        out = self.proj(out)
        return out

class DiNATBlock(nn.Module):
    def __init__(self, dim, kernel_size=7, dilation=1, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = DilatedNeighborhoodAttention(dim, kernel_size, dilation)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim)
        )
    
    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

# Example Usage
B, H, W, C = 1, 32, 32, 64  # Batch size, height, width, channels
x = torch.randn(B, H, W, C)
block = DiNATBlock(dim=C, kernel_size=7, dilation=2)
output = block(x)
print(output.shape)  # Expected output: (1, 32, 32, 64)



torch.Size([1, 32, 32, 64])


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DilatedNeighborhoodAttention(nn.Module):
    def __init__(self, dim, kernel_size=7, dilation=1):
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size
        self.dilation = dilation
        
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)
    
    def forward(self, x):
        B, H, W, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        
        # Unfold input to extract neighborhoods with dilation
        unfolded_k = F.unfold(k.permute(0, 3, 1, 2),
                              kernel_size=self.kernel_size,
                              dilation=self.dilation)
        unfolded_v = F.unfold(v.permute(0, 3, 1, 2),
                              kernel_size=self.kernel_size,
                              dilation=self.dilation)
        
        # Reshape for attention
        unfolded_k = unfolded_k.view(B, C, self.kernel_size ** 2, H, W).permute(0, 3, 4, 2, 1)
        unfolded_v = unfolded_v.view(B, C, self.kernel_size ** 2, H, W).permute(0, 3, 4, 2, 1)
        
        q = q.unsqueeze(-2)  # (B, H, W, 1, C)
        attn = (q * unfolded_k).sum(-1) / (C ** 0.5)
        attn = attn.softmax(dim=-1)
        
        out = (attn.unsqueeze(-1) * unfolded_v).sum(-2)
        out = self.proj(out)
        return out

class DiNATBlock(nn.Module):
    def __init__(self, dim, kernel_size=7, dilation=1, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = DilatedNeighborhoodAttention(dim, kernel_size, dilation)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim)
        )
    
    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

# Example Usage
B, H, W, C = 1, 32, 32, 64  # Batch size, height, width, channels
x = torch.randn(B, H, W, C)
block = DiNATBlock(dim=C, kernel_size=7, dilation=2)
output = block(x)
print(output.shape)  # Expected output: (1, 32, 32, 64)


RuntimeError: shape '[1, 64, 49, 32, 32]' is invalid for input of size 1254400

In [10]:
class DilatedNeighborhoodAttention(nn.Module):
    def __init__(self, dim, kernel_size=7, dilation=1):
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size
        self.dilation = dilation
        
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)
    
    def forward(self, x):
        B, H, W, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        return q, k, v

In [61]:
B, H, W, C = 1, 32, 32, 64  # Batch size, height, width, channels
x = torch.randn(B, H, W, C)
block = DilatedNeighborhoodAttention(dim=C, kernel_size=7, dilation=2)
q, k, v = block(x)

In [62]:
q.shape

torch.Size([1, 32, 32, 64])

In [63]:
tor = torch.empty(0)
for i in range(0,H):
             for j  in range(0,W):
                tor = torch.cat((tor,q[:,i,j,:].unsqueeze(1).unsqueeze(1)),dim=0)

In [68]:
a = []
kernel_size = 7
for i in range(kernel_size ,H+1):
             for j  in range(kernel_size ,W+1):
                 a.append(i*j)
len(a)
                 

676

In [52]:
import torch

tensor_result = torch.empty(0)  # Start with an empty tensor

for i in range(5):
    
    tensor_result = torch.cat((tensor_result, i), dim=0)  # Concatenate

print(tensor_result)  # Output: tensor([0., 1., 2., 3., 4.])


TypeError: expected Tensor as element 1 in argument 0, but got int