In [2]:
import torch
import timm
import torch.nn as nn
from torch.nn.functional import pad
from timm.models.layers import trunc_normal_, DropPath, to_2tuple
from timm.models.registry import register_model
import natten
from natten import NeighborhoodAttention2D as NeighborhoodAttention
import warnings
warnings.filterwarnings('ignore')

In [7]:
import torch
from natten import NeighborhoodAttention2D as NeighborhoodAttention

# Correct input shape: (batch_size, channels, height, width)
x = torch.randn(1, 224, 224,3)

# Correct argument order: dim, kernel_size, num_heads
attn = NeighborhoodAttention(dim=3, kernel_size=7, num_heads=3, dilation=2)

# Forward pass
output = attn(x)
print(output.shape)


torch.Size([1, 224, 224, 3])


In [8]:
class MLP(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, p=0.):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(p)

    def forward(self, x):
        return self.fc2(self.drop(self.act(self.fc1(x)))) 

In [9]:
class NATransformerLayer(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        kernel_size=7,
        dilation=1,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        **kwargs
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio

        self.norm1 = norm_layer(dim)
        self.attn = NeighborhoodAttention(
            dim,
            kernel_size=kernel_size,
            dilation=dilation,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
             **kwargs,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            out_features=dim,
            p=drop,
        )

    def forward(self, x):
        shortcut = x
        x = self.norm1(x)
        x = self.attn(x)
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

In [10]:
attn = NATransformerLayer(
                    dim=3,
                    num_heads=3,
                    kernel_size=7,
                    dilation=2
                )

In [11]:
attn(x)

tensor([[[[ 1.9316, -0.3682, -1.4846],
          [-0.7695, -0.5473,  1.1452],
          [-1.4589,  0.8057,  1.1884],
          ...,
          [ 1.5352, -0.7118,  0.3144],
          [-2.2642, -0.4041, -0.4719],
          [-0.6485, -0.3080,  0.4100]],

         [[-0.4684,  0.5132,  1.2825],
          [ 1.2969, -1.2949,  0.3700],
          [-1.7419,  1.8532,  0.4206],
          ...,
          [ 0.6719, -0.9220,  0.2296],
          [-2.3613,  1.3454,  0.5207],
          [ 0.7159,  0.3065, -0.8544]],

         [[ 0.9237,  1.0169, -0.1309],
          [-0.7479,  1.1606, -1.7190],
          [ 1.4196,  1.2865,  1.4908],
          ...,
          [ 0.1790,  1.0003, -0.5102],
          [-1.8136,  0.9297,  1.3643],
          [-1.8092,  0.2282, -0.9241]],

         ...,

         [[-0.9146,  0.8632, -0.6907],
          [ 0.1677,  2.7600,  1.9142],
          [-1.1307,  0.5557,  2.2026],
          ...,
          [-2.4982,  1.8171,  1.0978],
          [ 1.4242,  0.3747,  1.1181],
          [ 0.2308, -1

In [13]:
dilations = [1, 2, 3, 4]  # List of dilation values
i = 2  # Selecting the third element (index 2)

dilation = 1 if dilations is None else dilations[i]
print(dilation) 

3


In [15]:
class PatchMerging(nn.Module):
    """
    Based on Swin Transformer
    https://arxiv.org/abs/2103.14030
    https://github.com/microsoft/Swin-Transformer
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        B, H, W, C = x.shape

        # padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = pad(x, (0, 0, 0, W % 2, 0, H % 2))
            _, H, W, _ = x.shape
            
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, (H + 1) // 2, (W + 1) // 2, 4 * C)  # B H/2 W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)
        return x

In [57]:
x = torch.randn(1,224,224,3)
PatchMerging(3)(x).shape

torch.Size([1, 112, 112, 6])

In [17]:
class BasicLayer(nn.Module):
    """
    Based on Swin Transformer
    https://arxiv.org/abs/2103.14030
    https://github.com/microsoft/Swin-Transformer
    """

    def __init__(
        self,
        dim,
        depth,
        num_heads,
        kernel_size,
        dilations=None,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
    ):
        super().__init__()
        self.dim = dim
        self.depth = depth
        
        # build blocks
        self.blocks = nn.ModuleList(
            [
                NATransformerLayer(
                    dim=dim,
                    num_heads=num_heads,
                    kernel_size=kernel_size,
                    dilation=1 if dilations is None else dilations[i],
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i]
                    if isinstance(drop_path, list)
                    else drop_path,
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )
        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None
    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x
            

In [18]:
x.shape

torch.Size([1, 224, 224, 3])

In [20]:
basic = BasicLayer(3,4,3,7)

In [22]:
basic(x).shape

torch.Size([1, 224, 224, 3])

In [23]:
x.shape

torch.Size([1, 224, 224, 3])

In [24]:
x.permute(0,3,1,2).shape

torch.Size([1, 3, 224, 224])

In [25]:
class PatchEmbed(nn.Module):
    """
    From Swin Transformer
    https://arxiv.org/abs/2103.14030
    https://github.com/microsoft/Swin-Transformer
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )
        self.norm = None if norm_layer is None else norm_layer(embed_dim)

    def forward(self, x):
        B, C, H, W = x.shape
        if W % self.patch_size[1] != 0:
            x = pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        if H % self.patch_size[0] != 0:
            x = pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))

        x = self.proj(x)
        x = x.permute(0, 2, 3, 1)
        if self.norm is not None:
            x = self.norm(x)
        return x

In [59]:
patch = PatchEmbed()
x = torch.randn(1,3,224,224)
x = patch(x)
x.shape

torch.Size([1, 56, 56, 96])

In [37]:
import math
class PositionalEncoding(nn.Module):
    def __init__(self, dim, max_len=5000):
        super(PositionalEncoding, self).__init__()

        # Create a tensor of shape [max_len, dim] for encoding
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).float().unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, dim, 2).float() * -(math.log(10000.0) / dim))  # [dim/2]

        # Apply sin and cos functions for positional encoding
        pe[:, 0::2] = torch.sin(position * div_term)  # even indices (sine)
        pe[:, 1::2] = torch.cos(position * div_term)  # odd indices (cosine)

        pe = pe.unsqueeze(0)  # Shape becomes [1, max_len, dim]
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Add positional encoding to the input tensor
        return x + self.pe[:, :x.size(1),:]  # x.size(1) is the length of the sequence

In [65]:
patch = PatchEmbed()
x = torch.randn(1,3,224,224)
_,_,H,W = x.shape
x = patch(x)


pos = PositionalEncoding(96, 56*56 )
x = pos(x)

In [74]:
patch = PatchEmbed()
x = torch.randn(1,3,224,224)
_,_,H,W = x.shape
x = patch(x)


pos = PositionalEncoding(96, 56*56 )
x = pos(x)
attn = NeighborhoodAttention(dim=96, kernel_size=7, num_heads=3, dilation=2)

# Forward pass
x = attn(x)
x = x.view(1, 56, 56, 4, 4, 6)  # Split (96 → 16 * 6) into (4, 4, 6)

# Step 2: Rearrange the dimensions to interleave the upsampling effect
x = x.permute(0, 1, 3, 2, 4, 5).contiguous()  # (1, 56, 4, 56, 4, 6)

# Step 3: Merge the expanded dimensions to get (1, 224, 224, 6)
x = x.view(1, 56 * 4, 56 * 4, 6)
x.shape

torch.Size([1, 224, 224, 6])

In [70]:
patch_size = 4
x = nn.functional.interpolate(x.permute(0, 3, 1, 2), scale_factor=4, mode='bilinear', align_corners=False)
x.shape

torch.Size([1, 224, 224, 896])

In [53]:
class DiNAT_s(nn.Module):
    def __init__(
        self,
        img_size = 224,
        patch_size=4,
        in_chans=3,
        num_classes=1000,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        kernel_size=7,
        dilations=None,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.2,
        norm_layer=nn.LayerNorm,
        patch_norm=True,
        **kwargs
    ):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio
        n_patches = (img_size // patch_size) ** 2

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,
        )
        self.position_encoding = PositionalEncoding(dim=embed_dim,
                                                   max_len=n_patches)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2**i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                kernel_size=kernel_size,
                dilations=None if dilations is None else dilations[i_layer],
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
            )
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
   

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.position_encoding(x)
        x = self.pos_drop(x)
        print(x.shape)

        for layer in self.layers:
            x = layer(x)
        return x


In [54]:
dinat = DiNAT_s()
x = torch.randn(1,3,224,224)
dinat(x).shape

torch.Size([1, 56, 56, 96])


torch.Size([1, 7, 7, 768])