In [6]:
import torch
import torch.nn as nn
from torch.jit import Final
import torch.nn.functional as F
from timm.layers import use_fused_attn

## Vanira Attention

In [10]:
class Attention(nn.Module):
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

In [16]:
# Hyperparameters
dim = 384  # Dimension of the input features
num_heads = 6  # Number of attention heads
batch_size = 4  # Number of samples in a batch
tokens = 257  # Length of the input sequences

# Create a random input tensor
input_data = torch.randn(batch_size, tokens, dim)

# Initialize the attention layer
att = Attention(dim=dim, num_heads=num_heads)

# Run the forward pass
output_tensor = att(input_data)

print("Output tensor shape:", output_tensor.shape)

Output tensor shape: torch.Size([4, 257, 384])


## Sigmoid Attention

In [31]:
class Sigmoid_Attention(nn.Module):
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        # Custom sigmoid attention calculation
        print(torch.matmul(q, k.transpose(-2, -1))[0][0][0][:10])
        sigmoid_att = torch.sigmoid(torch.matmul(q, k.transpose(-2, -1)))
        print(sigmoid_att[0][0][0][:10])
        sigmoid_att[0][0][0][2] = 0.8
        print(sigmoid_att[0][0][0][:10])
        inverse_sigmoid_att = torch.log(sigmoid_att / (1 - sigmoid_att))
        print(inverse_sigmoid_att[0][0][0][:10])
        attn = F.softmax(inverse_sigmoid_att / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)), dim=-1)

        # Apply dropout to attention
        attn = self.attn_drop(attn)
        
        # Compute output
        x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

In [32]:
# Hyperparameters
dim = 384  # Dimension of the input features
num_heads = 6  # Number of attention heads
batch_size = 4  # Number of samples in a batch
tokens = 257  # Length of the input sequences

attention = Sigmoid_Attention(dim=dim, num_heads=num_heads)
x = torch.randn(batch_size, tokens, dim)
output = attention(x)
print(output.shape)  # 

tensor([-2.4186, -1.2876,  2.0219, -2.0418, -1.5648,  2.4395, -0.2220,  3.8145,
        -0.1355, -1.0452], grad_fn=<SliceBackward0>)
tensor([0.0818, 0.2163, 0.8831, 0.1149, 0.1730, 0.9198, 0.4447, 0.9784, 0.4662,
        0.2602], grad_fn=<SliceBackward0>)
tensor([0.0818, 0.2163, 0.8000, 0.1149, 0.1730, 0.9198, 0.4447, 0.9784, 0.4662,
        0.2602], grad_fn=<SliceBackward0>)
tensor([-2.4186, -1.2876,  1.3863, -2.0418, -1.5648,  2.4395, -0.2220,  3.8145,
        -0.1355, -1.0452], grad_fn=<SliceBackward0>)
torch.Size([4, 257, 384])
