In [1]:
import torch as t 
import torch.nn.functional as F

In [2]:
class Mamba(t.nn.Module):

    def __init__(self, input_shape):
        super(Net, self).__init__()
        
        rms_norm = t.nn.RMSNorm(list(input_shape))
        
        # 1 input image channel, 6 output channels, 5x5 square convolution
        # kernel
        self.conv1 = t.nn.Conv2d(1, 6, 5)
        # an affine operation: y = Wx + b
        self.L1 = t.nn.Linear(16 * 5 * 5, 120)  # 5*5 from image dimension
        self.L2 = t.nn.Linear(120, 84)
        self.L3 = t.nn.Linear(84, 10)
        self.L4 = t.nn.Linear(84, 10)

        #Activation function 
        self.SiLU = t.nn.SiLU()

        self.SoftMax = nn.Softmax(dim=1)
        
    def forward(self, u):

        #MAMBA block starts 
        
        x = rms_norm(u)

        z = self.L1(x)

        x = self.L2(x)

        x = self.conv1(x)

        x = self.SiLU(x)

        z = self.SiLU(x) 

        x, final_state = self.SSD(x, A, B, C)

        x = t.matmul(x, z)

        x = self.L3(x) + u 

        #MAMBA block ends 

        x = rms_norm(x)

        x = self.L4(x)

        x = SoftMax(x)
        
        return x

    def segsum(x):
        """Naive segment sum calculation. exp(segsum(A)) produces a 1-SS matrix,
        which is equivalent to a scalar SSM."""
        T = x.size(-1)
        x_cumsum = t.cumsum(x, dim=-1)
        x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
        mask = t.tril(t.ones(T, T, device=x.device, dtype=bool), diagonal=0)
        x_segsum = x_segsum.masked_fill(~mask, -t.inf)
        return x_segsum
    
    def ssd(X, A, B, C, block_len=64, initial_states=None):
        """
        Arguments:
        X: (batch, length, n_heads, d_head)
        A: (batch, length, n_heads)
        B: (batch, length, n_heads, d_state)
        C: (batch, length, n_heads, d_state)
        Return:
        Y: (batch, length, n_heads, d_head)
        """
        assert X.dtype == A.dtype == B.dtype == C.dtype
        assert X.shape[1] % block_len == 0
        # Rearrange into blocks/chunks
        X, A, B, C = [rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)]
        A = rearrange(A, "b c l h -> b h c l")
        A_cumsum = t.cumsum(A, dim=-1)
        # 1. Compute the output for each intra-chunk (diagonal blocks)
        L = t.exp(segsum(A))
        Y_diag = t.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
        # 2. Compute the state for each intra-chunk
        # (right term of low-rank factorization of off-diagonal blocks; B terms)
        decay_states = t.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
        states = t.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
        # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
        # (middle term of factorization of off-diag blocks; A terms)
        if initial_states is None:
            initial_states = t.zeros_like(states[:, :1])
        states = t.cat([initial_states, states], dim=1)
        decay_chunk = t.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
        new_states = t.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
        states, final_state = new_states[:, :-1], new_states[:, -1]
        # 4. Compute state -> output conversion per chunk
        # (left term of low-rank factorization of off-diagonal blocks; C terms)
        state_decay_out = t.exp(A_cumsum)
        Y_off = t.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)
        # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
        Y = rearrange(Y_diag+Y_off, "b c l h p -> b (c l) h p")
        return Y, final_state


In [6]:
rms_norm = t.nn.RMSNorm([2, 3])
Input = t.randn(2, 2, 3)
rms_norm(Input)

tensor([[[ 0.9432, -1.0047, -0.1954],
         [-0.4737,  1.8013, -0.7704]],

        [[-0.0681, -1.3904, -1.0567],
         [-0.8493, -1.2979,  0.7348]]], grad_fn=<MulBackward0>)

In [None]:
import numpy as np

In [8]:
# Raw (un-normalized) first batch element of Input, for comparison with the rms_norm output.
Input[0]

tensor([[ 0.8739, -0.9310, -0.1811],
        [-0.4389,  1.6691, -0.7138]])

In [9]:
# Sanity check that list() turns a tuple into a list, as Mamba.__init__ does with input_shape.
list((1,2))

[1, 2]