In [None]:
# Q-ROTA: head-local quaternion rotary for ViTs / windowed attn.
import torch, math
import torch.nn as nn
import torch.nn.functional as F

class QROTA(nn.Module):
    """
    Head-local rotary embedding acting on 4-channel groups of Q/K.

    Each head's channels are split into groups of 4, (a, b, c, d). The first channel
    ``a`` passes through unchanged. The 3-vector ``(b, c, d)`` is rotated by angle
    ``theta_h(x, y)`` about a learned, head-fixed axis ``u_h`` (Rodrigues' formula).
    Every rotation of a head shares the same axis, so ``R(t_i)^T R(t_j) = R(t_j - t_i)``.
    The rotated dot product therefore depends only on the angle difference, mirroring
    RoPE's relative property.

    Note that ``theta`` comes from a linear map (with bias) of a sinusoid bank, so it
    is not linear in position.

    NOTE(review): the frequencies span 1..rope_base cycles over the unit interval
    (see ``freqs``). With positions sampled on a small window grid, most of them alias.
    Confirm this is intended.

    Args:
        num_heads: number of attention heads ``H``.
        d_head: per-head channel count ``Dh``; must be divisible by 4.
        num_freq: number of sinusoid frequencies per coordinate.
        rope_base: frequencies are ``rope_base ** linspace(0, 1, num_freq)``.
        init_alpha: initial per-head gate *logit*. ``mix_logits`` weights the Q-ROTA
            logits by ``sigmoid(alpha)``, so the default 0.0 starts at 0.5.
    """
    def __init__(self, num_heads, d_head, num_freq=16, rope_base=10000.0, init_alpha=0.0):
        super().__init__()
        assert d_head % 4 == 0, "Q-ROTA requires head_dim divisible by 4"
        self.H, self.dh = num_heads, d_head
        self.G = d_head // 4
        self.num_freq = num_freq

        # Head-wise rotation axis (normalized in _rotmat_from_axis_angle)
        self.axis = nn.Parameter(torch.randn(self.H, 3))

        # 2D sinusoidal frequencies: rope_base**0 .. rope_base**1
        freqs = rope_base ** torch.linspace(0, 1, steps=num_freq)
        self.register_buffer("freqs", freqs, persistent=False)

        # Map 2D sinusoid bank -> angle per head
        self.angle_proj = nn.Linear(4 * num_freq, self.H, bias=True)

        # Head-wise gate logit; the mixing weight is sigmoid(alpha) in (0, 1)
        self.alpha = nn.Parameter(torch.full((self.H,), init_alpha))

    def _pos_bank(self, xy):
        """Map coordinates xy (..., 2), expected in [0, 1], to a (..., 4*num_freq) bank.

        The bank is [sin(wx), cos(wx), sin(wy), cos(wy)] with w = 2*pi*freqs.
        """
        x = xy[..., 0:1]; y = xy[..., 1:2]
        wx = 2 * math.pi * x * self.freqs
        wy = 2 * math.pi * y * self.freqs
        return torch.cat([torch.sin(wx), torch.cos(wx), torch.sin(wy), torch.cos(wy)], dim=-1)

    @staticmethod
    def _rotmat_from_axis_angle(axis, theta):
        """Return rotation matrices for angle ``theta`` about ``axis`` (Rodrigues' formula).

        Args:
            axis: (..., 3) tensor of rotation axes; normalized here.
            theta: tensor of angles broadcastable against ``axis.shape[:-1]``
                (no trailing singleton dim).

        Returns:
            (..., 3, 3) rotation matrices in ``axis``'s dtype, where ``...`` is the
            broadcast of ``axis.shape[:-1]`` and ``theta.shape``.
        """
        u = F.normalize(axis, dim=-1, eps=1e-8)
        theta = theta.to(device=u.device, dtype=u.dtype)
        ux, uy, uz = u.unbind(dim=-1)
        c, s = torch.cos(theta), torch.sin(theta)
        t = 1. - c
        # Every entry mixes an axis term and an angle term, so all nine share one broadcast shape.
        entries = (
            c + ux*ux*t,      ux*uy*t - uz*s,   ux*uz*t + uy*s,
            uy*ux*t + uz*s,   c + uy*uy*t,      uy*uz*t - ux*s,
            uz*ux*t - uy*s,   uz*uy*t + ux*s,   c + uz*uz*t,
        )
        flat = torch.stack(entries, dim=-1)                 # (..., 9), row-major
        return flat.reshape(*flat.shape[:-1], 3, 3)

    def rotate_qk(self, Q, K, xy):
        """
        Rotate Q and K per head and position.

        Args:
            Q, K: (B, H, T, Dh) tensors; they may be non-contiguous views.
            xy: (T, 2) or (B, T, 2) positions, expected in [0, 1].

        Returns:
            Rotated (Q', K') with the same shapes and Q's dtype.
        """
        B, H, T, Dh = Q.shape
        if xy.dim() == 2:
            xy = xy.unsqueeze(0).expand(B, -1, -1)  # (B,T,2)

        bank = self._pos_bank(xy.reshape(-1, 2)).reshape(B, T, -1)  # (B,T,4F)
        theta = self.angle_proj(bank).permute(0, 2, 1)               # (B,H,T)

        # (1,H,1,3) broadcasts against theta's (B,H,T) inside the helper.
        axis = self.axis.view(1, H, 1, 3)
        R = self._rotmat_from_axis_angle(axis, theta).to(Q.dtype)   # (B,H,T,3,3)

        def apply_rot(X):
            # reshape, not view: Q/K are often strided slices of a permuted qkv tensor.
            Xg = X.reshape(B, H, T, self.G, 4)
            a = Xg[..., 0:1]
            v = Xg[..., 1:4]
            vR = torch.einsum('bhtij,bhtgj->bhtgi', R, v)
            return torch.cat([a, vR], dim=-1).reshape(B, H, T, Dh)

        return apply_rot(Q), apply_rot(K)

    def mix_logits(self, logits_base, logits_rota):
        """Blend logits per head: (1 - sigmoid(alpha)) * base + sigmoid(alpha) * rota."""
        gate = torch.sigmoid(self.alpha).view(1, self.H, 1, 1)
        return logits_base*(1-gate) + logits_rota*gate


In [None]:
# Patch windowed attention in stages 3 & 4 of MambaVision
import math, types
import torch
import torch.nn as nn
import torch.nn.functional as F

from qrota import QROTA

def _xy_grid_for_window(win_size, device, add_cls=False):
    # returns (L,2) in [0,1] for a win_size x win_size window; optional CLS at (0,0)
    ys, xs = torch.meshgrid(
        torch.linspace(0, 1, steps=win_size, device=device),
        torch.linspace(0, 1, steps=win_size, device=device),
        indexing='ij'
    )
    xy = torch.stack([xs, ys], dim=-1).reshape(-1, 2)
    if add_cls:
        xy = torch.cat([torch.zeros(1,2, device=device), xy], dim=0)
    return xy  # (L,2)

class QROTAMambaWindowAttn(nn.Module):
    """
    Drop-in wrapper for MambaVision windowed attention.

    The base module is expected to expose ``.qkv`` (a linear layer with
    ``in_features``), ``.proj`` and ``.num_heads``. It may optionally expose
    ``.window_size``, ``.scale``, ``.attn_drop`` and ``.proj_drop``. The wrapped
    module's ``qkv``/``proj`` layers are reused, not copied.

    A Swin-style relative position bias (``relative_position_bias_table`` +
    ``relative_position_index``), if present, is added to both the baseline and the
    Q-ROTA logits. The two are then mixed per head by ``QROTA.mix_logits``.

    NOTE(review): only qkv -> scale -> logits -> softmax -> proj is reproduced here.
    Any other step in the base module's own forward (e.g. q/k normalization) is
    skipped; confirm against the wrapped class.

    Args:
        base_attn: attention module to wrap.
        num_freq: forwarded to ``QROTA``.
        rope_base: forwarded to ``QROTA``.
        init_alpha: forwarded to ``QROTA``. It is the per-head gate *logit*, so the
            default 0.0 gives an even baseline/Q-ROTA mix from the start. A large
            negative value starts the wrapper close to the unmodified base attention.
    """
    def __init__(self, base_attn, num_freq=16, rope_base=10000.0, init_alpha=0.0):
        super().__init__()
        self.base = base_attn
        self.qkv       = base_attn.qkv
        self.proj      = base_attn.proj
        self.attn_drop = getattr(base_attn, 'attn_drop', nn.Identity())
        self.proj_drop = getattr(base_attn, 'proj_drop', nn.Identity())
        self.num_heads = base_attn.num_heads

        # infer dims
        embed_dim = self.qkv.in_features
        self.dh   = embed_dim // self.num_heads
        assert self.dh % 4 == 0, f"Q-ROTA requires head_dim%4==0; got {self.dh}"

        self.has_rpb = hasattr(base_attn, 'relative_position_bias_table') and \
                       hasattr(base_attn, 'relative_position_index')

        # window size if present; else infer from sequence length at runtime
        self.window_size = getattr(base_attn, 'window_size', None)

        # scaling like timm/Swin: if scale present, use it; else 1/sqrt(dh)
        self.scale = getattr(base_attn, 'scale', None)

        self.qrota = QROTA(self.num_heads, self.dh, num_freq=num_freq,
                           rope_base=rope_base, init_alpha=init_alpha)
        # QROTA is built on the default device; put it next to the wrapped weights so
        # patching a model that was already moved (e.g. .cuda()) works without another .to().
        weight = getattr(self.qkv, 'weight', None)
        if isinstance(weight, torch.Tensor):
            self.qrota.to(weight.device)

    @staticmethod
    def _add_mask(logits, mask):
        """Add an additive attention mask to (B*nW, H, L, L) logits.

        A 3-D mask (nW, L, L) is applied per window. This assumes the rows of
        ``logits`` are image-major (row b*nW + w holds window w of image b), the layout
        used by Swin-style shifted-window attention. Plain broadcasting would instead
        line nW up with the head axis. Masks of any other rank are added with ordinary
        broadcasting.

        Raises:
            ValueError: if a 3-D mask's window count does not divide B*nW.
        """
        if mask.dim() != 3:
            return logits + mask
        BnW, H, L, S = logits.shape
        nW = mask.shape[0]
        if BnW % nW != 0:
            raise ValueError(f"mask has {nW} windows, which does not divide B*nW={BnW}")
        out = logits.reshape(BnW // nW, nW, H, L, S) + mask.unsqueeze(1).unsqueeze(0)
        return out.reshape(BnW, H, L, S)

    def forward(self, x, mask=None):
        """
        Args:
            x: (B*nW, L, C) window tokens, with L = window_size**2.
            mask: optional additive attention mask. Accepts either something
                broadcastable to (B*nW, H, L, L), e.g. (B*nW, 1, L, L), or a
                per-window mask (nW, L, L) shared across the batch.

        Returns:
            (B*nW, L, C) output of ``proj`` followed by ``proj_drop``.

        Raises:
            ValueError: if L is not window_size**2, or a (nW, L, L) mask's nW does
                not divide B*nW.
        """
        BnW, L, C = x.shape

        # qkv -> (B*nW, L, 3, H, dh) -> q,k,v: (B*nW, H, L, dh)
        qkv = self.qkv(x).reshape(BnW, L, 3, self.num_heads, self.dh).permute(2,0,3,1,4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # (BnW,H,L,dh)

        scale = self.scale if self.scale is not None else (self.dh ** -0.5)
        q = q * scale

        # XY per window, normalized [0,1]
        if self.window_size is not None:
            ws = self.window_size if isinstance(self.window_size, int) else self.window_size[0]
        else:
            ws = int(math.sqrt(L))  # fallback: assume a square window
        if ws * ws != L:
            raise ValueError(
                f"Q-ROTA expects L == window_size**2 tokens per window; "
                f"got L={L}, window_size={ws}")
        # Use QROTA's parameter dtype so its angle projection sees matching dtypes.
        xy = _xy_grid_for_window(ws, x.device).to(self.qrota.axis.dtype)  # (L,2)
        xy = xy.unsqueeze(0).expand(BnW, -1, -1)                           # (BnW,L,2)

        # Apply Q-ROTA: expects (B,H,T,dh) and (B,T,2)
        q_r, k_r = self.qrota.rotate_qk(q, k, xy)               # (BnW,H,L,dh)

        # Baseline and rotated logits
        logits_base = torch.einsum('bhid,bhjd->bhij', q,   k)   # (BnW,H,L,L)
        logits_rota = torch.einsum('bhid,bhjd->bhij', q_r, k_r) # (BnW,H,L,L)

        if self.has_rpb:
            # Swin-style relative position bias, shared by every window
            table = self.base.relative_position_bias_table
            index = self.base.relative_position_index
            bias = table[index.reshape(-1)].reshape(L, L, -1).permute(2, 0, 1)  # (nH,L,L)
            bias = bias.unsqueeze(0)                                            # (1,nH,L,L)
            logits_base = logits_base + bias
            logits_rota = logits_rota + bias

        # Attention mask (shifted windows, etc.)
        if mask is not None:
            logits_base = self._add_mask(logits_base, mask)
            logits_rota = self._add_mask(logits_rota, mask)

        logits = self.qrota.mix_logits(logits_base, logits_rota)  # per-head gate
        attn   = logits.softmax(dim=-1)
        attn   = self.attn_drop(attn)

        out = torch.einsum('bhij,bhjd->bhid', attn, v)           # (BnW,H,L,dh)
        out = out.transpose(1,2).reshape(BnW, L, C)
        out = self.proj_drop(self.proj(out))
        return out

def _is_stage3_or_4(name: str) -> bool:
    # works for typical hierarchies: 'stages.2' / 'stages.3' or 'stage3'/'stage4'
    name_l = name.lower()
    return ('stages.2' in name_l) or ('stages.3' in name_l) or ('stage3' in name_l) or ('stage4' in name_l)

def patch_mambavision_attention_stages_3_4(model):
    """
    Replace windowed attention modules in stages 3 & 4 with QROTAMambaWindowAttn.

    A module is treated as attention when it has ``.qkv``, ``.num_heads`` and
    ``.proj`` (standard MHSA) and its qualified name passes ``_is_stage3_or_4``.

    All candidates are collected before anything is replaced, so the module tree is
    not mutated while ``named_modules()`` is still walking it. Three kinds of module
    are skipped, so calling this function twice does not double-wrap:
      * existing ``QROTAMambaWindowAttn`` wrappers;
      * everything beneath a wrapper, including the ``.base`` module it wraps;
      * modules nested inside a module already chosen in this pass.

    Args:
        model: root ``nn.Module``; modified in place.

    Returns:
        list[str]: qualified names of the replaced modules, in traversal order.
    """
    claimed = []   # names whose whole subtree must be left alone
    targets = []
    for name, m in list(model.named_modules()):
        if any(not p or name.startswith(p + '.') for p in claimed):
            continue
        if isinstance(m, QROTAMambaWindowAttn):
            claimed.append(name)
            continue
        if not _is_stage3_or_4(name):
            continue
        # heuristic: a window-attn block typically has qkv/proj/num_heads attributes
        if hasattr(m, 'qkv') and hasattr(m, 'num_heads') and hasattr(m, 'proj'):
            claimed.append(name)
            targets.append((name, m))

    replaced = []
    for name, m in targets:
        parent_name, _, child_name = name.rpartition('.')
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, QROTAMambaWindowAttn(m))
        replaced.append(name)
    return replaced


In [None]:
# If using pip package:
# NOTE(review): this cell needs the `mambavision` package; the `pip install` cell
# appears further down in this notebook, so run that one first.
from mambavision import create_model
model = create_model('mamba_vision_B', pretrained=True)  # or T/S/L etc.
model.eval()

# Patch stages 3 & 4 attention with Q-ROTA.
# NOTE(review): the wrapper's Q-ROTA gate starts at sigmoid(0.0) = 0.5, so the patched
# model's predictions differ from the pretrained ones until Q-ROTA is fine-tuned.
replaced = patch_mambavision_attention_stages_3_4(model)
# An empty list means no module name matched the stage heuristic; inspect model.named_modules().
print("Patched attention modules:", replaced)


In [None]:
# Smoke test: forward a random batch of two 3x224x224 inputs through the (patched) model.
import torch
x = torch.randn(2, 3, 224, 224)
with torch.no_grad():  # inference only; no autograd graph needed
    y = model(x)
print(y.shape)  # [2, 1000] expected, assuming a 1000-class head


# Image Classification with MambaVision
By Ali Hatamizadeh

**Note**: *This tutorial must be run on a GPU. In Colab, click "Runtime", select "Change runtime type", and choose a GPU before proceeding.*

In this tutorial, we demonstrate a simple image classification example using an ImageNet-pretrained [MambaVision](https://github.com/NVlabs/MambaVision) model. We use the Hugging Face `transformers` library to
load the model weights. But first, let's briefly review the MambaVision architecture.

## MambaVision Architecture

MambaVision [1] is the ***first*** Mamba-Transformer hybrid backbone specifically tailored
for vision applications. The architecture of MambaVision is shown below. The first two stages use residual convolutional blocks for fast feature
extraction. Stages 3 and 4 employ both MambaVision and Transformer blocks. The Transformer blocks in the final layers allow for
recovering lost global context and capturing long-range spatial dependencies.

![arch](https://github.com/user-attachments/assets/372b5de2-256d-4dc5-ad68-f9581bd72966)

## MambaVision Mixer Design

As shown below, we redesigned
the original Mamba mixer to make it more suitable for vision
tasks. First, we replace the causal convolution with a regular
convolution, since causal convolution limits the influence to one
direction, which is unnecessary and restrictive for vision
tasks. In addition, we add a symmetric branch without
SSM, consisting of an additional convolution and Sigmoid
Linear Unit (SiLU) activation, to compensate for any
content lost due to the sequential constraints of SSMs. We
then concatenate the output of both branches and project it
via a final linear layer. This combination ensures that the final feature representation incorporates both the sequential
and spatial information, leveraging the strengths of both
branches.

<p align="center">
<img src="https://github.com/NVlabs/MambaVision/assets/26806394/295c0984-071e-4c84-b2c8-9059e2794182" width=32% height=32%
class="center">
</p>

## ImageNet-1K Benchmarks

We demonstrate a new state-of-the-art (SOTA) Pareto front for MambaVision in terms of ImageNet Top-1 accuracy versus throughput. This means that our models can be used for different real-world applications with varying constraints.


<p align="center">
<img src="https://github.com/NVlabs/MambaVision/assets/26806394/79dcf841-3966-4b77-883d-76cd5e1d4320" width=62% height=62%
class="center">
</p>


## Scalability

MambaVision is the **first** Mamba-based vision backbone to scale training to the large ImageNet-21K dataset with significantly bigger model sizes. Our largest model, [MambaVision-L3-512-21K](https://huggingface.co/nvidia/MambaVision-L3-512-21K), achieves a **Top-1 accuracy** of **88.1%** with 740M parameters.

<p align="center">
<img src="https://github.com/user-attachments/assets/9f44c0a9-3b2b-4887-b5d5-f0aeff0a2d0e" width=62% height=62%
class="center">
</p>



In [None]:
# NOTE: this cell previously held a stray, undefined identifier (keyboard residue) that
# raised NameError and stopped "Run all"; it is intentionally left empty.


# Install the dependencies

Run the following cell, which installs the `mambavision` pip package and all of its required dependencies.

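Note: If you are running this tutorial locally, make sure a compatible PyTorch version is installed. We recommend PyTorch 2.6.0 or newer built for CUDA 12.4, e.g. `pip install "torch>=2.6.0" --index-url https://download.pytorch.org/whl/cu124`. Quote the requirement so the shell does not treat `>` as output redirection.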

In [None]:
! pip install mambavision==1.1.0

Collecting mambavision==1.1.0
  Downloading mambavision-1.1.0-py3-none-any.whl.metadata (17 kB)
Collecting mamba-ssm==2.2.4 (from mambavision==1.1.0)
  Downloading mamba_ssm-2.2.4.tar.gz (91 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.8/91.8 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting tensorboardX==2.6.2.2 (from mambavision==1.1.0)
  Downloading tensorboardX-2.6.2.2-py2.py3-none-any.whl.metadata (5.8 kB)
Collecting ninja (from mamba-ssm==2.2.4->mambavision==1.1.0)
  Using cached ninja-1.11.1.4-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (5.0 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->mamba-ssm==2.2.4->mambavision==1.1.0)
  Using cached nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidi

## Run the model

The following code snippet downloads the model weights for [MambaVision-B-21K](https://huggingface.co/nvidia/MambaVision-B-21K) from Hugging Face and runs inference.

You can replace the model name with any of the models in the following table to run image classification:


<table>
  <tr>
    <th>Name</th>
    <th>Acc@1(%)</th>
    <th>Acc@5(%)</th>
    <th>#Params(M)</th>
    <th>FLOPs(G)</th>
    <th>Resolution</th>
    <th>Pretraining Dataset</th>
    <th>HF</th>
    <th>Download</th>
  </tr>


<tr>
    <td>MambaVision-L3-512-21K</td>
    <td>88.1</td>
    <td>98.6</td>
    <td>739.6</td>
    <td>489.1</td>
    <td>512x512</td>
    <td>ImageNet-21K</td>
    <td><a href="https://huggingface.co/nvidia/MambaVision-L3-512-21K">link</a></td>
    <td><a href="https://huggingface.co/nvidia/MambaVision-L3-512-21K/resolve/main/mambavision_L3_21k_740m_512.pth.tar">model</a></td>
</tr>

<tr>
    <td>MambaVision-L3-256-21K</td>
    <td>87.3</td>
    <td>98.3</td>
    <td>739.6</td>
    <td>122.3</td>
    <td>256x256</td>
    <td>ImageNet-21K</td>
    <td><a href="https://huggingface.co/nvidia/MambaVision-L3-256-21K">link</a></td>
    <td><a href="https://huggingface.co/nvidia/MambaVision-L3-256-21K/resolve/main/mambavision_L3_21k_740m_256.pth.tar">model</a></td>
</tr>

<tr>
    <td>MambaVision-L2-512-21K</td>
    <td>87.3</td>
    <td>98.4</td>
    <td>241.5</td>
    <td>196.3</td>
    <td>512x512</td>
    <td>ImageNet-21K</td>
    <td><a href="https://huggingface.co/nvidia/MambaVision-L2-512-21K">link</a></td>
    <td><a href="https://huggingface.co/nvidia/MambaVision-L2-512-21K/resolve/main/mambavision_L2_21k_240m_512.pth.tar">model</a></td>
</tr>

<tr>
    <td>MambaVision-L-21K</td>
    <td>86.1</td>
    <td>97.9</td>
    <td>227.9</td>
    <td>34.9</td>
    <td>224x224</td>
    <td>ImageNet-21K</td>
    <td><a href="https://huggingface.co/nvidia/MambaVision-L-21K">link</a></td>
    <td><a href="https://huggingface.co/nvidia/MambaVision-L-21K/resolve/main/mambavision_large_21k.pth.tar">model</a></td>
</tr>

<tr>
    <td>MambaVision-L2</td>
    <td>85.3</td>
    <td>97.2</td>
    <td>241.5</td>
    <td>37.5</td>
    <td>224x224</td>
    <td>ImageNet-1K</td>
    <td><a href="https://huggingface.co/nvidia/MambaVision-L2-1K">link</a></td>
    <td><a href="https://huggingface.co/nvidia/MambaVision-L2-1K/resolve/main/mambavision_large2_1k.pth.tar">model</a></td>
</tr>

<tr>
    <td>MambaVision-L</td>
    <td>85.0</td>
    <td>97.1</td>
    <td>227.9</td>
    <td>34.9</td>
    <td>224x224</td>
    <td>ImageNet-1K</td>
    <td><a href="https://huggingface.co/nvidia/MambaVision-L-1K">link</a></td>
    <td><a href="https://huggingface.co/nvidia/MambaVision-L-1K/resolve/main/mambavision_large_1k.pth.tar">model</a></td>
</tr>

<tr>
    <td>MambaVision-B-21K</td>
    <td>84.9</td>
    <td>97.5</td>
    <td>97.7</td>
    <td>15.0</td>
    <td>224x224</td>
    <td>ImageNet-21K</td>
    <td><a href="https://huggingface.co/nvidia/MambaVision-B-21K">link</a></td>
    <td><a href="https://huggingface.co/nvidia/MambaVision-B-21K/resolve/main/mambavision_base_21k.pth.tar">model</a></td>
</tr>

<tr>
    <td>MambaVision-B</td>
    <td>84.2</td>
    <td>96.9</td>
    <td>97.7</td>
    <td>15.0</td>
    <td>224x224</td>
    <td>ImageNet-1K</td>
    <td><a href="https://huggingface.co/nvidia/MambaVision-B-1K">link</a></td>
    <td><a href="https://huggingface.co/nvidia/MambaVision-B-1K/resolve/main/mambavision_base_1k.pth.tar">model</a></td>
</tr>

<tr>
    <td>MambaVision-S</td>
    <td>83.3</td>
    <td>96.5</td>
    <td>50.1</td>
    <td>7.5</td>
    <td>224x224</td>
    <td>ImageNet-1K</td>
    <td><a href="https://huggingface.co/nvidia/MambaVision-S-1K">link</a></td>
    <td><a href="https://huggingface.co/nvidia/MambaVision-S-1K/resolve/main/mambavision_small_1k.pth.tar">model</a></td>
</tr>

<tr>
    <td>MambaVision-T2</td>
    <td>82.7</td>
    <td>96.3</td>
    <td>35.1</td>
    <td>5.1</td>
    <td>224x224</td>
    <td>ImageNet-1K</td>
    <td><a href="https://huggingface.co/nvidia/MambaVision-T2-1K">link</a></td>
    <td><a href="https://huggingface.co/nvidia/MambaVision-T2-1K/resolve/main/mambavision_tiny2_1k.pth.tar">model</a></td>
</tr>

<tr>
    <td>MambaVision-T</td>
    <td>82.3</td>
    <td>96.2</td>
    <td>31.8</td>
    <td>4.4</td>
    <td>224x224</td>
    <td>ImageNet-1K</td>
    <td><a href="https://huggingface.co/nvidia/MambaVision-T-1K">link</a></td>
    <td><a href="https://huggingface.co/nvidia/MambaVision-T-1K/resolve/main/mambavision_tiny_1k.pth.tar">model</a></td>
</tr>

</table>

In [None]:
import torch
from transformers import AutoModelForImageClassification
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests

model_name = "nvidia/MambaVision-B-21K" # select any models you would like from the table above
resolution = 224 # remember to choose the resolution based on the corresponding model for optimal results (although it would still work with any resolution)

model = AutoModelForImageClassification.from_pretrained(model_name, trust_remote_code=True)

# eval mode for inference
model.cuda().eval()

print("MambaVision model loaded succesfully ✅")

# prepare image for the model
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
# A timeout keeps the cell from hanging forever on a stalled connection.
response = requests.get(url, stream=True, timeout=30)
# Fail loudly on HTTP errors instead of handing an error page to PIL.
response.raise_for_status()
# convert('RGB') guards against grayscale/palette/RGBA images, which would have the wrong channel count.
image = Image.open(response.raw).convert('RGB')
input_resolution = (3, resolution, resolution)  # MambaVision supports any input resolutions

transform = create_transform(input_size=input_resolution,
                             is_training=False,
                             mean=model.config.mean,
                             std=model.config.std,
                             crop_mode=model.config.crop_mode,
                             crop_pct=model.config.crop_pct)

inputs = transform(image).unsqueeze(0).cuda()
# model inference; no_grad avoids building an autograd graph that is never used
with torch.no_grad():
    outputs = model(inputs)
logits = outputs['logits']
predicted_class_idx = logits.argmax(-1).item()
print('################# Inference Result:#################')
print("Predicted class:", model.config.id2label[predicted_class_idx])


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/36.4k [00:00<?, ?B/s]

configuration_mambavision.py:   0%|          | 0.00/768 [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/nvidia/MambaVision-B-21K:
- configuration_mambavision.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


modeling_mambavision.py:   0%|          | 0.00/29.5k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/nvidia/MambaVision-B-21K:
- modeling_mambavision.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors:   0%|          | 0.00/391M [00:00<?, ?B/s]

MambaVision model loaded succesfully ✅
################# Inference Result:#################
Predicted class: brown bear, bruin, Ursus arctos


## Acknowledgement

If you find this tutorial helpful, please consider ***citing*** MambaVision [[1]](https://arxiv.org/abs/2407.08083) and giving a star to our [repository](https://github.com/NVlabs/MambaVision):

**Reference:**

- [1] Hatamizadeh, A. and Kautz, J., 2025. Mambavision: A hybrid mamba-transformer vision backbone. In Proceedings of the Computer Vision and Pattern Recognition Conference (pp. 25261-25270).

