In [1]:
import torch
import torch.nn as nn

In [2]:
class PatchEmbed(nn.Module):
    """Split image into patches and then embed them.

    Parameters
    ----------
    img_size : int
        Size of the image (it is a square).

    patch_size : int
        Size of the patch (it is a square).

    in_chans : int
        Number of input channels.

    embed_dim : int
        The embedding dimension.

    Attributes
    ----------
    n_patches : int
        Number of patches inside of our image.

    proj : nn.Conv2d
        Convolutional layer that does both the splitting into patches and their embedding.
    """
    def __init__(self, img_size, patch_size, in_channels = 3, embed_dim = 768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d( in_channels, embed_dim, kernel_size = patch_size, stride = patch_size,)

    def forward(self, x):
        """ Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Shape '(n_samples, in_chans, img_size, img_size)'.
       
        Returns
        --------
        torch.Tensor
            Shape '(n_samples, n_patches, embed_dim)'.
        """
        x = self.proj(x) # (n_samples, emed_dim, n_patches ** 0.5, n_patches ** 0.5)
        x = x.flatten(2) # (n_samples, emed_dim, n_patches)
        x = x.transpose(1,2) # n_samples, n_patches, embed_dim)
        return x

In [3]:
x = torch.randn(1,3,224,224)

model = PatchEmbed(img_size=224,patch_size=3,embed_dim=768,in_channels=3)
x1 = model(x)
x1.shape

torch.Size([1, 5476, 768])

In [4]:
module = nn.Dropout(p=0.2)
inp = torch.ones(3,5)
module(inp)

tensor([[1.2500, 1.2500, 1.2500, 1.2500, 0.0000],
        [0.0000, 1.2500, 1.2500, 1.2500, 1.2500],
        [1.2500, 1.2500, 0.0000, 1.2500, 1.2500]])

In [5]:
class Attention(nn.Module):
    """ Attention mechanism

    Parameters
    ----------
    dim : int
        The input and out dimension of per token features.

    n_heads: int
        Number of attention heads.

    qkv_bias : bool
        If True, then we include bias to the query, key and value projections.

    attn_p : float
        Dropout probability applied to the query, key and value tensors.

    proj_p : float
        Dropout probability applied to the output tensor.

    Attributes
    ----------
    scale : float
        Normalizing constant for the dot product.
    qkv : nn.Linear
        Linear projection for the query, key and value.
    proj : nn.Linear
        Linear mapping that takes in the concatenated output of all attention
        heads and maps it into a new space.
    attn_drop, proj_drop : nn.Dropout
        Dropout layers.
    """
    def __init__(self, dim, n_heads=12, qkv_bias=True, attn_p=0., proj_p=0.):
        super().__init__()
        self.n_heads = n_heads
        self.dim = dim
        self.head_dim = dim//n_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim*3, bias = qkv_bias)
        self.attn_drop = nn.Dropout(attn_p)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x):
        """ Run forward pass.

        Parameters
        ----------
        x: torch.Tensor
            Shape '(n_samples, n_patches +1, dim)'.

        Returns 
        ------
        torch.Tensor
            Shape '(n_samples, n_patches +1, dim)'.
        """
        n_samples, n_tokens, dim = x.shape
 
        if dim != self.dim:
            raise ValueError
            
        qkv = self.qkv(x) # (n_samples, n_patches +1, 3 * dim)
        qkv = qkv.reshape( n_samples, n_tokens, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4) # (3, n_samples, n_heads, n_patches + 1,head_dim )
        q, k, v = qkv[0], qkv[1], qkv[2]

        k_t = k.transpose(-2, -1) # (n_samples, n_heads, head_dim, n_patches +1)

        dp = (q @ k_t) * self.scale

        attn = dp.softmax(dim=-1)
        attn = self.attn_drop(attn)
        weighted_avg = attn @ v
        weighted_avg = weighted_avg.transpose(1, 2)
        weighted_avg = weighted_avg.flatten(2)

        x = self.proj(weighted_avg)
        x = self.proj_drop(x)

        return x
    

In [6]:
model1 = Attention(768)
model1(x1).shape

torch.Size([1, 5476, 768])

In [7]:
class MLP(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, p=0.):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(p)

    def forward(self, x):
        return self.fc2(self.drop(self.act(self.fc1(x)))) 

In [8]:
class Block(nn.Module):
    """Transformer block.

    Parameters
    ----------
    dim : int
        Embedding dimensions.

    n_heads : int
        number of attention heads.

    mlp_ratio : float
        Determines the hidden dimension size of the "MLP' module with rewspect to 'dim'.

    qkv_bias : bool
        If True the we include bias to the query, key and value projections.

    p, attn_p : float
        Dropout probability
        
    Attributes
    ----------
    norm1, norm2 : LayerNorm.
        Layer normalization.

    attn : Attention
        Attention module.

    mlp : MLP
        MLP module.
    """
    def __init__(self, dim, n_heads, mlp_ratio= 4.0, qkv_bias=True, p=0., attn_p = 0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = Attention(dim,n_heads=n_heads, qkv_bias = qkv_bias,attn_p = attn_p, proj_p = p)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        hidden_features = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features = hidden_features, out_features=dim)

    def forward(self,x):
        """ Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Shape'(n_samples, n_patches + 1, dim)'

        Returns
        ----------
        torch.Tensor
            Shape'(n_samples, n_patches + 1, dim)'      
        """
        x += self.attn(self.norm1(x))
        x += self.mlp(self.norm2(x))

        return x
    

In [9]:
class Visiontransformer(nn.Module):
    """Simplified implementation of the Vision transformer.

    Parameters
    ----------
    img_size : int
        both height and the width of the image (it is a square).

    patch_size : int
        both height and the width of the patch (it is a square).

    in_chans : int
        Number of input channels.
        
    n_classes: int
        Number of classes
        
    embed_dim : int
        Dimensionality of the token/patch embeddings.
        
    depth : int 
        Number of blocks
        
    n_heads : int
        Number of attention heads.
        
    mlp_ratio : float
        Determines the hidden dimension of the 'MLP' module.
        
    qkv_bias : bool
        If True, we include bias to the query, key, and value projections.

    p, attn_p : float
        Dropout probability.
        
    Attributes
    ----------
    patch_embed : PatchEmbed
        Instance of 'PatchEmbed' layer.
        
    cls_token : nn.Parameter
        Learnable parameter that will represent the first token in the sequence.
        It has 'embed_dim' elements.

    pos_emb : nn.Parameter
        Positional embedding of the cls token + all the patches.
        It has '(n+patches + 1) * embed_dim' elements.
    pos_drop : nn.Dropout
        Dropout layer.
    blocks : nn.ModuleList
        List of 'Block' modules.

    norm : nn.LayerNorm
        Layer normalization
    """
    def __init__(
        self, 
        img_size = 384,
        patch_size =16,
        in_chans = 3,
        n_classes = 1000,
        embed_dim =768,
        depth =12,
        n_heads =12,
        mlp_ratio = 4.,
        qkv_bias =True,
        p=0.,
        attn_p=0.,
    ):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size= img_size, patch_size= patch_size, in_channels=in_chans, embed_dim=embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1,1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1,1 + self.patch_embed.n_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=p)
        self.block = nn.ModuleList([Block(dim = embed_dim, 
                           n_heads = n_heads,
                           mlp_ratio = mlp_ratio, 
                           qkv_bias = qkv_bias, 
                           p = p, 
                           attn_p = attn_p,
                          )
                          for _ in range(depth)
                                   ])
        self.norm = nn.LayerNorm(embed_dim, eps= 1e-6)
        self.head = nn.Linear(embed_dim, n_classes)
        
    def forward(self, x):
        n_samples = x.shape[0]
        x = self.patch_embed(x)
        cls_token = self.cls_token.expand(n_samples, -1, -1)
        x = torch.cat((cls_token, x), dim =1)
        x += self.pos_embed
        x = self.pos_drop(x)

        for block in self.block:
            x = block(x)
        x = self.norm(x)
        print(x.shape)
        cls_token_final = x[:,0] # just the CLS token
        print(cls_token_final.shape)
        x = self.head(cls_token_final)
        return x
    

In [10]:
model = Visiontransformer()
x = torch.randn(10,3,384,384)
model(x).shape

torch.Size([10, 577, 768])
torch.Size([10, 768])


torch.Size([10, 1000])

In [11]:
x = torch.randn(10, 577, 768)
x[:,2,:].shape

torch.Size([10, 768])

In [12]:
nn.Parameter(torch.zeros(1,1, 768)).expand(10,-1,-1).shape

torch.Size([10, 1, 768])

In [13]:
(384 // 16) ** 2

576

In [2]:
import torch
import torch.nn as nn

In [3]:
torch.cuda.is_available()

True

In [4]:
import torch
import torch.nn as nn
#from timm.models.layers import DropPath
import natten
from natten import NeighborhoodAttention2D as NeighborhoodAttention

In [1]:
import torch
torch.__version__

'2.0.1+cu118'

In [3]:
!nvidia-smi


/bin/bash: /root/miniconda3/envs/tf/lib/libtinfo.so.6: no version information available (required by /bin/bash)
Thu Feb 20 11:34:58 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.46                 Driver Version: 531.61       CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3070 L...    On | 00000000:01:00.0 Off |                  N/A |
| N/A   52C    P0               32W /  N/A|      0MiB /  8192MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+---

In [9]:
import torch
import torch.nn as nn
#from timm.models.layers import DropPath
import natten
from natten import NeighborhoodAttention2D as NeighborhoodAttention

class NATransformerLayer(nn.Module):
    """Neighborhood Attention Transformer Layer with MLP."""
    def __init__(self, dim, num_heads, kernel_size=7, dilation=1, drop_path=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = NeighborhoodAttention(dim, kernel_size, dilation, num_heads)
        #self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Linear(4 * dim, dim)
        )

    def forward(self, x):
        x = x + (self.attn(self.norm1(x)))
        x = x + (self.mlp(self.norm2(x)))
        return x

class PatchEmbed(nn.Module):
    """Converts input image into patches."""
    def __init__(self, patch_size=4, in_chans=3, embed_dim=96):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)
        return self.norm(x)

class DiNAT(nn.Module):
    """Simplified DiNAT model."""
    def __init__(self, img_size=224, patch_size=4, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24]):
        super().__init__()
        self.patch_embed = PatchEmbed(patch_size, 3, embed_dim)
        self.layers = nn.ModuleList([
            NATransformerLayer(embed_dim * (2**i), num_heads[i]) for i in range(len(depths))
        ])
        self.norm = nn.LayerNorm(embed_dim * (2**(len(depths) - 1)))
        self.head = nn.Linear(embed_dim * (2**(len(depths) - 1)), 1000)

    def forward(self, x):
        x = self.patch_embed(x)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x.mean(dim=1))
        return self.head(x)

# Example usage
model = DiNAT()
img = torch.randn(1, 3, 224, 224)
output = model(img)
print(output.shape)  # Should be (1, 1000)


ValueError: NeighborhoodAttention2D expected a rank-4 input tensor; got x.dim()=3.

In [15]:
import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    def __init__(self, dim, max_len=5000):
        super(PositionalEncoding, self).__init__()

        # Create a tensor of shape [max_len, dim] for encoding
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).float().unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, dim, 2).float() * -(math.log(10000.0) / dim))  # [dim/2]

        # Apply sin and cos functions for positional encoding
        pe[:, 0::2] = torch.sin(position * div_term)  # even indices (sine)
        pe[:, 1::2] = torch.cos(position * div_term)  # odd indices (cosine)

        pe = pe.unsqueeze(0)  # Shape becomes [1, max_len, dim]
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Add positional encoding to the input tensor
        return x + self.pe[:, :x.size(1)]  # x.size(1) is the length of the sequence

# Example usage for ViT
class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, dim=768, num_classes=1000):
        super(VisionTransformer, self).__init__()
        
        self.patch_size = patch_size
        self.dim = dim
        self.num_classes = num_classes

        # Calculate the number of patches
        self.num_patches = (img_size // patch_size) ** 2

        # Define the embedding layer for patches
        self.patch_embeddings = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)

        # Positional encoding
        self.positional_encoding = PositionalEncoding(dim, max_len=self.num_patches)

        # Transformer layers (simplified version)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim, nhead=8),
            num_layers=12
        )

        # Classification head
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, x):
        # Extract patches and embed them
        x = self.patch_embeddings(x)  # Shape: [batch_size, dim, num_patches, num_patches]
        x = x.flatten(2).transpose(1, 2)  # Shape: [batch_size, num_patches, dim]

        # Add positional encoding
        x = self.positional_encoding(x)

        # Pass through transformer encoder
        x = self.encoder(x)

        # Classification head (take the output of the [CLS] token)
        x = x.mean(dim=1)  # Global average pooling

        # Final classification layer
        x = self.fc(x)
        return x

# Example input (batch_size=2, img_size=224)
model = VisionTransformer(img_size=224, patch_size=16)
sample_input = torch.randn(2, 3, 224, 224)  # Batch of 2 images with 3 channels (RGB)
output = model(sample_input)
print(output.shape)  # Should output: torch.Size([2, 1000])




torch.Size([2, 1000])


In [16]:
torch.zeros(5000, 768).shape

torch.Size([5000, 768])

In [17]:
position = torch.arange(0, 5000).float().unsqueeze(1)
position.shape

torch.Size([5000, 1])

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DilatedNeighborhoodAttention(nn.Module):
    def __init__(self, dim, kernel_size=7, dilation=1, num_heads=8):
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)
        
    def forward(self, x):
        B, H, W, C = x.shape
        qkv = self.qkv(x).reshape(B, H, W, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(3, 0, 4, 1, 2, 5)  # (3, B, num_heads, H, W, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]  # Split into queries, keys, and values
        
        # Compute attention with dilation
        attn_map = self.compute_dilated_attention(q, k, H, W)
        attn_output = (attn_map @ v).permute(0, 2, 3, 1, 4).reshape(B, H, W, C)
        return self.proj(attn_output)
    
    def compute_dilated_attention(self, q, k, H, W):
        """Computes dilated attention weights"""
        pad = self.dilation * (self.kernel_size // 2)
        k_padded = F.pad(k, (0, 0, pad, pad, pad, pad))
        attn_weights = []
        
        for i in range(self.kernel_size):
            for j in range(self.kernel_size):
                k_slice = k_padded[:, :, i::self.dilation, j::self.dilation, :]
                attn_weights.append((q * k_slice).sum(-1))
        
        attn_weights = torch.stack(attn_weights, dim=-1)
        attn_weights = F.softmax(attn_weights * self.scale, dim=-1)
        return attn_weights

# Example Usage
B, H, W, C = 1, 32, 32, 64  # Batch size, Height, Width, Channels
x = torch.randn(B, H, W, C)
attn_layer = DilatedNeighborhoodAttention(dim=C, kernel_size=7, dilation=2, num_heads=8)
out = attn_layer(x)
print(out.shape)  # Expected output: (B, H, W, C)


RuntimeError: The size of tensor a (32) must match the size of tensor b (22) at non-singleton dimension 3

In [None]:
l = [-1,

In [30]:
pip install timm

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.3.1 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


Collecting timm
  Downloading timm-1.0.14-py3-none-any.whl.metadata (50 kB)
Collecting huggingface_hub (from timm)
  Downloading huggingface_hub-0.28.1-py3-none-any.whl.metadata (13 kB)
Collecting safetensors (from timm)
  Downloading safetensors-0.5.2-cp38-abi3-win_amd64.whl.metadata (3.9 kB)
Collecting fsspec>=2023.5.0 (from huggingface_hub->timm)
  Downloading fsspec-2025.2.0-py3-none-any.whl.metadata (11 kB)
Downloading timm-1.0.14-py3-none-any.whl (2.4 MB)
   ---------------------------------------- 0.0/2.4 MB ? eta -:--:--
   ---------------------------------------- 2.4/2.4 MB 67.5 MB/s eta 0:00:00
Downloading huggingface_hub-0.28.1-py3-none-any.whl (464 kB)
Downloading safetensors-0.5.2-cp38-abi3-win_amd64.whl (303 kB)
Downloading fsspec-2025.2.0-py3-none-any.whl (184 kB)
Installing collected packages: safetensors, fsspec, huggingface_hub, timm
  Attempting uninstall: fsspec
    Found existing installation: fsspec 2023.4.0
    Uninstalling fsspec-2023.4.0:
      Successfully uni

In [3]:
pip3 install natten==0.15.1+torch220cu121 -f https://shi-labs.com/natten/wheels/

SyntaxError: invalid syntax (3743640526.py, line 1)

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DilatedNeighborhoodAttention(nn.Module):
    def __init__(self, dim, kernel_size=7, dilation=1, num_heads=8):
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, H, W, C = x.shape
        qkv = self.qkv(x).reshape(B, H, W, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(3, 0, 4, 1, 2, 5)  # (3, B, num_heads, H, W, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]  # Split into queries, keys, and values

        # Compute attention with dilation
        attn_map = self.compute_dilated_attention(q, k, H, W)
        attn_output = (attn_map @ v).permute(0, 2, 3, 1, 4).reshape(B, H, W, C)
        return self.proj(attn_output)

    def compute_dilated_attention(self, q, k, H, W):
        """Computes dilated attention weights"""
        pad = self.dilation * (self.kernel_size // 2)
        k_padded = F.pad(k, (0, 0, 0, 0, pad, pad, pad, pad))  # Correct padding order

        attn_weights = []
        for i in range(self.kernel_size):
            for j in range(self.kernel_size):
                k_slice = k_padded[:, :, i:i+H, j:j+W, :]  # Extract neighborhood
                attn_weights.append((q * k_slice).sum(-1))

        attn_weights = torch.stack(attn_weights, dim=-1)  # (B, num_heads, H, W, kernel_size**2)
        attn_weights = F.softmax(attn_weights * self.scale, dim=-1)  # Apply softmax
        return attn_weights.view(q.shape[0], q.shape[1], q.shape[2], self.kernel_size, self.kernel_size)

# Example Usage
B, H, W, C = 1, 32, 32, 64  # Batch size, Height, Width, Channels
x = torch.randn(B, H, W, C)
attn_layer = DilatedNeighborhoodAttention(dim=C, kernel_size=7, dilation=2, num_heads=8)
out = attn_layer(x)
print(out.shape)  # Expected output: (B, H,


RuntimeError: The size of tensor a (8) must match the size of tensor b (20) at non-singleton dimension 1

In [24]:
import torch
import torch.nn as nn

# Input: (Batch, Channels, Height, Width)
x = torch.rand(1,3,224,224)
x.shape

torch.Size([1, 3, 224, 224])

In [27]:
un_fold = nn.Unfold(kernel_size= 7, dilation=8, padding=0, stride=1)
un_fold(x).shape


torch.Size([1, 147, 30976])

In [12]:
x = torch.rand(2,2)
x

tensor([[0.8289, 0.2952],
        [0.4193, 0.6381]])

In [13]:
x.reshape(4)

tensor([0.8289, 0.2952, 0.4193, 0.6381])

In [6]:
pip install timm

/bin/bash: /root/miniconda3/envs/tf/lib/libtinfo.so.6: no version information available (required by /bin/bash)
Collecting timm
  Downloading timm-1.0.14-py3-none-any.whl.metadata (50 kB)
Collecting huggingface_hub (from timm)
  Downloading huggingface_hub-0.29.0-py3-none-any.whl.metadata (13 kB)
Collecting safetensors (from timm)
  Downloading safetensors-0.5.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)
Downloading timm-1.0.14-py3-none-any.whl (2.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading huggingface_hub-0.29.0-py3-none-any.whl (468 kB)
Downloading safetensors-0.5.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (461 kB)
Installing collected packages: safetensors, huggingface_hub, timm
Successfully installed huggingface_hub-0.29.0 safetensors-0.5.2 timm-1.0.14
[0mNote: you may need to restart the kernel to use updated packages.


In [10]:
import natten
from natten import NeighborhoodAttention2D as NeighborhoodAttention

In [16]:
attn = NeighborhoodAttention(
            64,
            kernel_size=7,
            dilation=2,
            num_heads=12,
        )

In [17]:
x = torch.rand(10,64,224,224)
attn(x)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (143360x224 and 64x192)

In [15]:
"""
Dilated Neighborhood Attention Transformer.
https://arxiv.org/abs/2209.15001

DiNAT_s -- our alternative model.

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import torch
import torch.nn as nn
from torch.nn.functional import pad
from timm.models.layers import trunc_normal_, DropPath, to_2tuple
from timm.models.registry import register_model
import natten
from natten import NeighborhoodAttention2D as NeighborhoodAttention
is_natten_post_017 = hasattr(natten, "context")

from nat import Mlp

model_urls = {
    # ImageNet-1K
    "dinat_s_tiny_1k": "https://shi-labs.com/projects/dinat/checkpoints/imagenet1k/dinat_s_tiny_in1k_224.pth",
    "dinat_s_small_1k": "https://shi-labs.com/projects/dinat/checkpoints/imagenet1k/dinat_s_small_in1k_224.pth",
    "dinat_s_base_1k": "https://shi-labs.com/projects/dinat/checkpoints/imagenet1k/dinat_s_base_in1k_224.pth",
    "dinat_s_large_1k": "https://shi-labs.com/projects/dinat/checkpoints/imagenet1k/dinat_s_large_in1k_224.pth",
    "dinat_s_large_1k_384": "https://shi-labs.com/projects/dinat/checkpoints/imagenet1k/dinat_s_large_in1k_384.pth",
    # ImageNet-22K
    "dinat_s_large_21k": "https://shi-labs.com/projects/dinat/checkpoints/imagenet22k/dinat_s_large_in22k_224.pth",
}


class NATransformerLayer(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        kernel_size=7,
        dilation=1,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        **kwargs
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio

        self.norm1 = norm_layer(dim)
        extra_args = {"rel_pos_bias": True} if is_natten_post_017 else {"bias": True}
        self.attn = NeighborhoodAttention(
            dim,
            kernel_size=kernel_size,
            dilation=dilation,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            **extra_args,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x):
        shortcut = x
        x = self.norm1(x)
        x = self.attn(x)
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchMerging(nn.Module):
    """
    Based on Swin Transformer
    https://arxiv.org/abs/2103.14030
    https://github.com/microsoft/Swin-Transformer
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        B, H, W, C = x.shape

        # padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = pad(x, (0, 0, 0, W % 2, 0, H % 2))
            _, H, W, _ = x.shape

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, (H + 1) // 2, (W + 1) // 2, 4 * C)  # B H/2 W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)
        return x


class BasicLayer(nn.Module):
    """
    Based on Swin Transformer
    https://arxiv.org/abs/2103.14030
    https://github.com/microsoft/Swin-Transformer
    """

    def __init__(
        self,
        dim,
        depth,
        num_heads,
        kernel_size,
        dilations=None,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
    ):

        super().__init__()
        self.dim = dim
        self.depth = depth

        # build blocks
        self.blocks = nn.ModuleList(
            [
                NATransformerLayer(
                    dim=dim,
                    num_heads=num_heads,
                    kernel_size=kernel_size,
                    dilation=1 if dilations is None else dilations[i],
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i]
                    if isinstance(drop_path, list)
                    else drop_path,
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x


class PatchEmbed(nn.Module):
    """
    From Swin Transformer
    https://arxiv.org/abs/2103.14030
    https://github.com/microsoft/Swin-Transformer
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )
        self.norm = None if norm_layer is None else norm_layer(embed_dim)

    def forward(self, x):
        B, C, H, W = x.shape
        if W % self.patch_size[1] != 0:
            x = pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        if H % self.patch_size[0] != 0:
            x = pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))

        x = self.proj(x)
        x = x.permute(0, 2, 3, 1)
        if self.norm is not None:
            x = self.norm(x)
        return x


class DiNAT_s(nn.Module):
    def __init__(
        self,
        patch_size=4,
        in_chans=3,
        num_classes=1000,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        kernel_size=7,
        dilations=None,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.2,
        norm_layer=nn.LayerNorm,
        patch_norm=True,
        **kwargs
    ):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,
        )

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2**i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                kernel_size=kernel_size,
                dilations=None if dilations is None else dilations[i_layer],
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
            )
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = (
            nn.Linear(self.num_features, num_classes)
            if num_classes > 0
            else nn.Identity()
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {"rpb"}

    def forward_features(self, x):
        x = self.patch_embed(x)
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x).flatten(1, 2)
        x = self.avgpool(x.transpose(1, 2))
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x


@register_model
def dinat_s_tiny(pretrained=False, **kwargs):
    model = DiNAT_s(
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        embed_dim=96,
        mlp_ratio=4,
        drop_path_rate=0.2,
        kernel_size=7,
        dilations=[
            [1, 8],
            [1, 4],
            [1, 2, 1, 2, 1, 2],
            [1, 1],
        ],
        **kwargs
    )
    if pretrained:
        url = model_urls["dinat_s_tiny_1k"]
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint)
    return model


@register_model
def dinat_s_small(pretrained=False, **kwargs):
    model = DiNAT_s(
        depths=[2, 2, 18, 2],
        num_heads=[3, 6, 12, 24],
        embed_dim=96,
        mlp_ratio=4,
        drop_path_rate=0.3,
        kernel_size=7,
        dilations=[
            [1, 8],
            [1, 4],
            [1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2],
            [1, 1],
        ],
        **kwargs
    )
    if pretrained:
        url = model_urls["dinat_s_small_1k"]
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint)
    return model


@register_model
def dinat_s_base(pretrained=False, **kwargs):
    model = DiNAT_s(
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        embed_dim=128,
        mlp_ratio=4,
        drop_path_rate=0.5,
        kernel_size=7,
        dilations=[
            [1, 8],
            [1, 4],
            [1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2],
            [1, 1],
        ],
        **kwargs
    )
    if pretrained:
        url = model_urls["dinat_s_base_1k"]
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint)
    return model


@register_model
def dinat_s_large(pretrained=False, **kwargs):
    model = DiNAT_s(
        depths=[2, 2, 18, 2],
        num_heads=[6, 12, 24, 48],
        embed_dim=192,
        mlp_ratio=4,
        drop_path_rate=0.35,
        kernel_size=7,
        dilations=[
            [1, 8],
            [1, 4],
            [1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2],
            [1, 1],
        ],
        **kwargs
    )
    if pretrained:
        url = model_urls["dinat_s_large_1k"]
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint)
    return model


@register_model
def dinat_s_large_384(pretrained=False, **kwargs):
    model = DiNAT_s(
        depths=[2, 2, 18, 2],
        num_heads=[6, 12, 24, 48],
        embed_dim=192,
        mlp_ratio=4,
        drop_path_rate=0.35,
        kernel_size=7,
        dilations=[
            [1, 13],
            [1, 6],
            [1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3],
            [1, 1],
        ],
        **kwargs
    )
    if pretrained:
        url = model_urls["dinat_s_large_1k_384"]
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint)
    return model


@register_model
def dinat_s_large_21k(pretrained=False, **kwargs):
    model = DiNAT_s(
        depths=[2, 2, 18, 2],
        num_heads=[6, 12, 24, 48],
        embed_dim=192,
        mlp_ratio=4,
        drop_path_rate=0.2,
        kernel_size=7,
        dilations=[
            [1, 8],
            [1, 4],
            [1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2],
            [1, 1],
        ],
        **kwargs
    )
    if pretrained:
        url = model_urls["dinat_s_large_21k"]
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint)
    return model



ModuleNotFoundError: No module named 'nat'

In [None]:
class NATransformerLayer(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        kernel_size=7,
        dilation=1,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        **kwargs
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio

        self.norm1 = norm_layer(dim)
        extra_args = {"rel_pos_bias": True} if is_natten_post_017 else {"bias": True}
        self.attn = NeighborhoodAttention(
            dim,
            kernel_size=kernel_size,
            dilation=dilation,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            **extra_args,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x):
        shortcut = x
        x = self.norm1(x)
        x = self.attn(x)
        x =  
        x = shortcut + (self.mlp(self.norm2(x)))
        return x
