In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class CausalConv1d(nn.Module):
    """1-D convolution whose output at step t sees only input steps <= t.

    The input is left-padded along its last axis with
    ``(kernel_size - 1) * dilation`` zeros and then passed through an
    unpadded ``nn.Conv1d``; with stride 1 the output length therefore equals
    the input length. Any extra keyword arguments are forwarded to
    ``nn.Conv1d`` (passing ``padding`` or ``dilation`` there is an error).
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, **kwargs):
        super().__init__()
        # Amount of left (past-side) zero padding needed for causality.
        self.padding = dilation * (kernel_size - 1)
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              padding=0, dilation=dilation, **kwargs)

    def forward(self, x):
        # kernel_size == 1 needs no padding at all.
        if not self.padding:
            return self.conv(x)
        padded = F.pad(x, (self.padding, 0))  # pad only the left end of the last axis
        return self.conv(padded)

class SelectiveSSM(nn.Module):
    """Selective state-space (Mamba-style) sequence layer.

    Maps ``(batch, seq_len, d_model)`` to ``(batch, seq_len, d_model)``:
    the input is projected to ``2 * d_inner`` channels and split into a scan
    branch ``x`` and a gate branch ``z``; ``x`` goes through a causal 1-D
    convolution and an input-dependent linear recurrence (``selective_scan``),
    the result is gated by ``silu(z)`` and projected back to ``d_model``.

    Args:
        d_model: Input/output feature size.
        d_state: Per-channel hidden-state size N.
        d_conv: Kernel size of the causal convolution.
        expand: Inner width multiplier, ``d_inner = int(expand * d_model)``.
    """

    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)

        # Projections
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        # NOTE(review): this conv is dense (d_inner x d_inner x d_conv weights);
        # reference Mamba uses a depthwise conv (groups=d_inner). Kept dense here
        # to preserve the existing parameter layout.
        self.conv1d = CausalConv1d(self.d_inner, self.d_inner, self.d_conv)

        # SSM parameters. forward() uses A = -exp(A_log), so A_log must be the
        # log of a *positive* tensor; log of Gaussian noise is NaN for every
        # negative sample and poisons the entire scan.
        self.A_log = nn.Parameter(self._init_A_log(self.d_inner, self.d_state))
        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

        # Selective (input-dependent) parameters
        self.dt_proj = nn.Linear(self.d_inner, self.d_inner)
        self.B_proj = nn.Linear(self.d_inner, self.d_state, bias=False)
        self.C_proj = nn.Linear(self.d_inner, self.d_state, bias=False)

    @staticmethod
    def _init_A_log(d_inner, d_state):
        """Return log([1, 2, ..., d_state]) repeated for each channel, shape (d_inner, d_state)."""
        A = torch.arange(1, d_state + 1, dtype=torch.float32)
        return torch.log(A).repeat(d_inner, 1)

    def forward(self, x):
        """Apply the layer to ``x`` of shape (batch, seq_len, d_model); returns the same shape."""
        # Project input and split into scan branch (x) and gate branch (z).
        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)

        # Causal conv along the sequence axis (Conv1d expects channels first).
        x = rearrange(x, 'b l d -> b d l')
        x = self.conv1d(x)
        x = rearrange(x, 'b d l -> b l d')

        # Input-dependent step size, squashed into (0, 1).
        # NOTE(review): reference Mamba uses softplus here; sigmoid kept as-is.
        dt = torch.sigmoid(self.dt_proj(x))

        A = -torch.exp(self.A_log.float())  # strictly negative -> decaying state
        B = self.B_proj(x)
        C = self.C_proj(x)

        y = self.selective_scan(x, dt, A, B, C, self.D)

        # Gating and output projection
        y = y * F.silu(z)
        return self.out_proj(y)

    def selective_scan(self, u, delta, A, B, C, D):
        """Run the discretized SSM recurrence sequentially over time.

        Args:
            u: (batch, seq_len, d_inner) input sequence.
            delta: (batch, seq_len, d_inner) step sizes.
            A: (d_inner, d_state) per-channel diagonal state matrix.
            B: (batch, seq_len, d_state) per-step input matrix.
            C: (batch, seq_len, d_state) per-step output matrix.
            D: (d_inner,) skip-connection weights.

        Returns:
            (batch, seq_len, d_inner) tensor y with
            h_t = exp(delta_t * A) * h_{t-1} + delta_t * B_t * u_t,
            y_t = sum_n(h_t * C_t) + D * u_t, and h_{-1} = 0.
        """
        batch, seq_len, d_inner = u.shape
        d_state = A.shape[-1]

        # torch.stack([]) would raise on an empty sequence.
        if seq_len == 0:
            return u.new_zeros(batch, 0, d_inner)

        # Broadcast to (b, l, d_inner, d_state)
        delta = delta.unsqueeze(-1)                  # (b, l, d_inner, 1)
        A = A.view(1, 1, d_inner, d_state)           # (1, 1, d_inner, d_state)
        B = B.unsqueeze(2)                           # (b, l, 1, d_state)
        C = C.unsqueeze(2)                           # (b, l, 1, d_state)

        # Zero-order-hold style discretization
        deltaA = torch.exp(delta * A)
        deltaB = delta * B * u.unsqueeze(-1)         # (b, l, d_inner, d_state)

        # State follows the working dtype/device instead of defaulting to float32.
        state = torch.zeros(batch, d_inner, d_state, device=u.device, dtype=deltaB.dtype)
        outputs = []
        for i in range(seq_len):
            state = deltaA[:, i] * state + deltaB[:, i]
            output = torch.einsum('bdn,bdn->bd', state, C[:, i])
            outputs.append(output + D * u[:, i])

        return torch.stack(outputs, dim=1)

class DirectionalScan(nn.Module):
    """Two axis-aligned 1-D scans over a token grid, summed and linearly mixed.

    The input is a row-major flattened grid of shape ``(b, h*w, d)``. The same
    ``ssm_layer`` is applied once down every column (sequences of length ``h``)
    and once along every row (sequences of length ``w``); the two outputs are
    added element-wise and passed through ``proj``.
    """

    def __init__(self, ssm_layer, d_model):
        super().__init__()
        self.ssm_layer = ssm_layer
        self.d_model = d_model
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x, h, w):
        b, l, d = x.shape
        assert l == h * w

        grid = x.reshape(b, h, w, d)

        # Vertical pass: each of the b*w columns is a length-h sequence (top to bottom).
        columns = grid.permute(0, 2, 1, 3).reshape(b * w, h, d)
        col_out = self.ssm_layer(columns)
        col_out = col_out.reshape(b, w, h, -1).permute(0, 2, 1, 3).reshape(b, h * w, -1)

        # Horizontal pass: each of the b*h rows is a length-w sequence (left to right).
        rows = grid.reshape(b * h, w, d)
        row_out = self.ssm_layer(rows).reshape(b, h * w, -1)

        return self.proj(col_out + row_out)

class VisionMambaBlock(nn.Module):
    """Pre-norm residual block: directional SSM token mixer followed by an MLP.

    ``forward(x, h, w)`` takes ``x`` of shape ``(batch, h*w, d_model)`` (a
    flattened ``h x w`` token grid) and returns a tensor of the same shape.
    """

    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        # Pre-norm for the token-mixing (scan) sub-layer.
        self.norm = nn.LayerNorm(d_model)
        # Separate pre-norm for the MLP sub-layer so the two sub-layers do not
        # share LayerNorm affine parameters.
        self.norm2 = nn.LayerNorm(d_model)
        # self.ssm and self.scan.ssm_layer are the same module object.
        self.ssm = SelectiveSSM(d_model, d_state, d_conv, expand)
        self.scan = DirectionalScan(self.ssm, d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model)
        )

    def forward(self, x, h, w):
        x = x + self.scan(self.norm(x), h, w)
        x = x + self.mlp(self.norm2(x))
        return x

class PatchEmbed2D(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A ``Conv2d`` with ``kernel_size == stride == patch_size`` maps
    ``(b, in_chans, H, W)`` to ``(b, embed_dim, H', W')``; the spatial grid is
    then flattened row-major into a token sequence ``(b, H'*W', embed_dim)``.
    ``img_size``/``num_patches`` describe the nominal square input only.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.img_size = (img_size,) * 2
        self.patch_size = (patch_size,) * 2
        self.num_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        feature_map = self.proj(x)                       # (b, d, h, w)
        return feature_map.flatten(2).transpose(1, 2)    # (b, h*w, d)

class VisionMamba(nn.Module):
    """Image classifier: patch embedding -> VisionMambaBlocks -> mean pool -> head.

    Args:
        img_size: Nominal square image side passed to ``PatchEmbed2D``; the
            token grid is derived from the actual input in ``forward``.
        patch_size: Side of the square, non-overlapping patches.
        in_chans: Number of image channels.
        num_classes: Output logits; ``<= 0`` replaces the head with ``nn.Identity``.
        embed_dim: Token width.
        depth: Number of ``VisionMambaBlock``s.
        d_state, d_conv, expand: Forwarded to every block's ``SelectiveSSM``.

    Note: no positional embedding is added; ``pos_drop`` only applies dropout
    to the patch tokens.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 num_classes=1000, embed_dim=768, depth=4,
                 d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.patch_embed = PatchEmbed2D(img_size, patch_size, in_chans, embed_dim)
        self.pos_drop = nn.Dropout(p=0.1)

        self.blocks = nn.ModuleList([
            VisionMambaBlock(embed_dim, d_state, d_conv, expand)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward(self, x):
        """Map images ``(batch, in_chans, H, W)`` to ``(batch, num_classes)`` logits.

        ``H`` and ``W`` need not be equal. When the head is ``nn.Identity``
        the pooled ``(batch, embed_dim)`` features are returned instead.
        """
        # Token grid from the image size: Conv2d with kernel == stride == p
        # yields floor(H / p) x floor(W / p) patches. Taking sqrt of the token
        # count would silently assume a square image.
        patch_h, patch_w = self.patch_embed.patch_size
        h, w = x.shape[-2] // patch_h, x.shape[-1] // patch_w

        x = self.patch_embed(x)
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x, h, w)

        # Global average pool over tokens, then normalize and classify.
        x = self.norm(x.mean(dim=1))
        return self.head(x)

In [11]:
from torchsummary import summary
import numpy as pd
from tabulate import tabulate

def print_model_summary(model, input_size=(3, 224, 224)):
    """Print a torchsummary report and a per-stage shape table for a VisionMamba model.

    Args:
        model: Module exposing ``patch_embed`` (with a ``patch_size`` tuple),
            ``pos_drop``, ``blocks``, ``norm`` and ``head`` -- e.g. ``VisionMamba``.
        input_size: ``(C, H, W)`` of a single image, without the batch
            dimension. ``H`` and ``W`` may differ.

    Side effects:
        Moves ``model`` to CUDA when available (``nn.Module.to`` modifies the
        module in place) and prints to stdout. Returns ``None``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Standard torchsummary report (optional dependency).
    try:
        from torchsummary import summary
        # torchsummary's `device` argument defaults to "cuda"; pass the real one.
        summary(model, input_size=input_size, device=device.type)
    except ImportError:
        print("torchsummary not installed, install with: pip install torchsummary")

    dimension_table = []

    def add_row(layer, in_shape, out_tensor):
        """Append one row to the table and return its output-shape string."""
        out_shape = str(tuple(out_tensor.shape))
        dimension_table.append({
            'Layer': layer,
            'Input Shape': in_shape,
            'Output Shape': out_shape
        })
        return out_shape

    # Token grid from the image size rather than sqrt(token count), so
    # non-square inputs give the (h, w) the directional scans expect.
    patch_h, patch_w = model.patch_embed.patch_size
    h, w = input_size[-2] // patch_h, input_size[-1] // patch_w

    # Replay the model's forward stage by stage, recording shapes.
    with torch.no_grad():
        x = torch.randn(1, *input_size, device=device)
        shape = add_row('Input', '-', x)

        x = model.patch_embed(x)
        shape = add_row('PatchEmbed', shape, x)

        x = model.pos_drop(x)
        for i, blk in enumerate(model.blocks, start=1):
            x = blk(x, h, w)
            shape = add_row(f'MambaBlock_{i}', shape, x)

        x = model.head(model.norm(x.mean(dim=1)))
        add_row('Norm+Head', shape, x)

    print("\nLayer Dimension Flow:")
    print(tabulate(dimension_table, headers="keys", tablefmt="grid"))

In [15]:
torch.manual_seed(0)  # reproducible weight init and dummy input

model = VisionMamba(
    img_size=224,
    patch_size=16,
    num_classes=1000,
    embed_dim=768,
    depth=12,
    d_state=16
)
model.eval()  # disable dropout (pos_drop) for a deterministic forward pass

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():  # inference only: no autograd graph needed
    out = model(x)
print(out.shape)  # Should be torch.Size([1, 1000])
assert out.shape == (1, 1000), f"unexpected output shape {tuple(out.shape)}"


torch.Size([1, 1000])


In [16]:
print_model_summary(model, input_size=(3, 224, 224))  # single 3x224x224 image

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 768, 14, 14]         590,592
      PatchEmbed2D-2             [-1, 196, 768]               0
           Dropout-3             [-1, 196, 768]               0
         LayerNorm-4             [-1, 196, 768]           1,536
            Linear-5             [-1, 14, 3072]       2,359,296
            Linear-6             [-1, 14, 3072]       2,359,296
            Conv1d-7             [-1, 1536, 14]       9,438,720
            Conv1d-8             [-1, 1536, 14]       9,438,720
      CausalConv1d-9             [-1, 1536, 14]               0
     CausalConv1d-10             [-1, 1536, 14]               0
           Linear-11             [-1, 14, 1536]       2,360,832
           Linear-12             [-1, 14, 1536]       2,360,832
           Linear-13               [-1, 14, 16]          24,576
           Linear-14               [-1,