Implementation of Mamba blocks
- ViM block from [ViM](https://github.com/hustvl/Vim)
- VSS block from [VMamba](https://github.com/MzeroMiko/VMamba)
- SiMBA block from [SiMBA](https://github.com/badripatro/simba)

In [1]:
import os

os.chdir("/root/dev/playground/models/mamba")
os.getcwd()

'/root/dev/playground/models/mamba'

In [2]:
import torch

torch.cuda.set_device(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.current_device()

0

In [5]:
!gpustat

[1m[37mdb80750feafb              [m  Wed Nov 20 17:54:48 2024  [1m[30m550.100[m
[36m[0][m [34mNVIDIA GeForce RTX 4090[m |[31m 35°C[m, [32m  0 %[m | [36m[1m[33m 3070[m / [33m24564[m MB |


# ViM block

In [None]:
from typing import Optional

from torch import nn
from torch import Tensor

class Block(nn.Module):
    def __init__(self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False, drop_path=0.):
        """ Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection
        Standard block is "LN -> MHA/MLP -> Add", but here is "Add -> LN -> Mixer",
        returning both hidden_states (output of the mixer) and the residual. (for performance reasons)
        """
        super().__init__()
        
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = mixer_cls(dim)
        self.norm = norm_cls(dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(self.norm, (nn.LayerNorm, RMSNorm)), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
    
    def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None):
        r""" Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required)
            residual: hidden_states = Mixer(LN(residual))
        """
        
        if not self.fused_add_norm:
            if residual is None:
                residual = hidden_states
            else:
                residual = residual + self.drop_path(hidden_states)
        
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = reisdual.to(torch.float32)
        
        else:
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn