In [2]:
# Install torch and xformers in ONE resolve from the PyTorch CUDA 12.4 wheel index.
# xformers wheels are built against a specific torch release, so installing them together
# lets pip pick a matching pair (a separate, unpinned `pip install torch` could pull a different
# build, or do nothing at all if torch is already installed).
# The cu124 wheels are for Linux/Windows with CUDA 12.4; run `nvidia-smi` in a terminal to check.
# `%pip` (unlike `!pip3`) installs into the environment of the running kernel.
# NOTE(review): `xformers.components` (imported in the next cell) is deprecated in recent xformers
# releases and may be missing from the latest one; pin an xformers version that still ships it.
# TODO confirm which version.
%pip install -U torch xformers --index-url https://download.pytorch.org/whl/cu124

Looking in indexes: https://download.pytorch.org/whl/cu124


In [4]:
import torch
import torch.nn as nn
from xformers.components.attention import ScaledDotProduct
from xformers.components.attention import build_attention
from xformers.components.attention.utils import maybe_merge_masks

class LKAWithXformers(nn.Module):
    """Multi-head self-attention over spatial positions, used to gate its input.

    The (B, C, H, W) input is flattened to H*W tokens of width C and layer-normed.
    The tokens are projected to queries/keys/values and attended with xformers'
    scaled dot-product attention, with the ``heads`` heads folded into the batch
    dimension. The result is projected back to C channels and multiplied
    element-wise with the original input.

    NOTE(review): despite the "LKA" name, no convolution is used here. Every one
    of the H*W positions attends to every other, so cost presumably grows
    quadratically with H*W.

    Args:
        dim: Number of input channels C; must be divisible by ``heads``.
        heads: Number of attention heads.

    Raises:
        ValueError: If ``heads`` is not a positive divisor of ``dim``.
    """

    def __init__(self, dim, heads=4):
        super().__init__()
        # Checked before building anything so a bad config fails fast.
        if heads < 1 or dim % heads != 0:
            raise ValueError(f"heads ({heads}) must be a positive divisor of dim ({dim})")
        self.qkv_proj = nn.Linear(dim, dim * 3)
        # Non-causal, dropout-free scaled dot-product attention.
        self.attn = build_attention(
            {
                "name": "scaled_dot_product",
                "dropout": 0.0,
                "causal": False,
            }
        )
        self.out_proj = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)
        self.heads = heads
        self.dim = dim

    def _split_heads(self, t):
        """Reshape (B, S, C) -> (B * heads, S, C // heads), folding heads into the batch dim."""
        b, s, c = t.shape
        head_dim = c // self.heads
        return (
            t.reshape(b, s, self.heads, head_dim)
            .transpose(1, 2)
            .reshape(b * self.heads, s, head_dim)
        )

    def _merge_heads(self, t, batch):
        """Reshape (B * heads, S, C // heads) -> (B, S, C); inverse of ``_split_heads``."""
        _, s, head_dim = t.shape
        return (
            t.reshape(batch, self.heads, s, head_dim)
            .transpose(1, 2)
            .reshape(batch, s, self.heads * head_dim)
        )

    def forward(self, x):
        """Apply spatial self-attention and gate ``x`` with the result.

        Args:
            x: Tensor of shape (B, C, H, W) with C == ``dim``.

        Returns:
            Tensor of shape (B, C, H, W): the projected attention output
            multiplied element-wise by ``x``.
        """
        B, C, H, W = x.shape
        tokens = self.norm(x.flatten(2).transpose(1, 2))  # (B, HW, C)

        q, k, v = (self._split_heads(t) for t in self.qkv_proj(tokens).chunk(3, dim=-1))
        # The attention module returns the attended tensor itself, not an
        # (output, weights) tuple like nn.MultiheadAttention, so it must not be
        # indexed with [0] (that would select the first batch element).
        # No mask is passed: unmasked attention over all positions.
        attn_out = self._merge_heads(self.attn(q, k, v), B)  # (B, HW, C)

        out = self.out_proj(attn_out)
        # reshape works regardless of the memory layout left by transpose.
        out = out.transpose(1, 2).reshape(B, C, H, W)
        return out * x  # Attention-modulated features
