In [27]:
import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Union
from einops import rearrange, repeat, einsum
import math

In [28]:
@dataclass
class ModelArgs:
    """Hyperparameters for the model, with derived values filled in after init.

    After construction, ``__post_init__`` has:
      * set ``d_inner`` to ``int(expand * d_model)``;
      * replaced ``dt_rank == 'auto'`` with ``ceil(d_model / 16)``;
      * rounded ``vocab_size`` up to the next multiple of
        ``pad_vocab_size_multiple`` (left untouched if already a multiple).
    """
    d_model: int
    n_layer: int
    vocab_size: int
    d_state: int = 16
    expand: int = 2
    dt_rank: Union[int, str] = 'auto'  # an int, or 'auto' to derive it from d_model
    d_conv: int = 4
    pad_vocab_size_multiple: int = 8
    conv_bias: bool = True
    bias: bool = False

    def __post_init__(self):
        # d_inner is a plain instance attribute, not a declared dataclass field.
        self.d_inner = int(self.expand * self.d_model)

        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)

        # Pad the vocabulary up to the next multiple; no-op when already aligned.
        leftover = self.vocab_size % self.pad_vocab_size_multiple
        if leftover:
            self.vocab_size += self.pad_vocab_size_multiple - leftover


In [56]:
class MambaBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """Create the layers and parameters of a single Mamba block.

        Follows Figure 3 in Section 3.4 of the Mamba paper. Only construction
        is defined here; this class does not define a forward pass.

        Args:
            args: Model hyperparameters. The attributes read are ``d_model``,
                ``d_inner``, ``d_state``, ``dt_rank``, ``d_conv``,
                ``conv_bias`` and ``bias``.
        """
        super().__init__()
        self.args = args

        d_model, d_inner = args.d_model, args.d_inner
        d_state, dt_rank = args.d_state, args.dt_rank

        # Input projection is 2 * d_inner wide
        # (presumably split into two branches in the forward pass — not shown here).
        self.in_proj = nn.Linear(d_model, 2 * d_inner, bias=args.bias)

        # Depthwise convolution: groups == channels, so each channel has its own kernel.
        # nn.Conv1d pads both ends by d_conv - 1, so its output is longer than its input.
        self.conv1d = nn.Conv1d(
            in_channels=d_inner,
            out_channels=d_inner,
            kernel_size=args.d_conv,
            padding=args.d_conv - 1,
            groups=d_inner,
            bias=args.conv_bias,
        )

        # x_proj takes `x` and outputs the input-specific Δ, B, C:
        # dt_rank features for Δ plus d_state each for B and C.
        self.x_proj = nn.Linear(d_inner, dt_rank + 2 * d_state, bias=False)

        # dt_proj projects Δ from dt_rank up to d_inner.
        self.dt_proj = nn.Linear(dt_rank, d_inner, bias=True)

        # A is stored in log space: every one of the d_inner rows is log(1..d_state).
        state_index = torch.arange(1, d_state + 1)
        A = repeat(state_index, 'n -> d n', d=d_inner)
        self.A_log = nn.Parameter(torch.log(A))

        # D is a per-channel vector initialised to ones.
        self.D = nn.Parameter(torch.ones(d_inner))

        self.out_proj = nn.Linear(d_inner, d_model, bias=args.bias)
        

In [57]:
# Tiny configuration for inspecting a single block.
# With d_model=16 and the defaults, __post_init__ gives d_inner=32 and dt_rank=1;
# vocab_size=200 is already a multiple of 8, so it is not padded.
args = ModelArgs(d_model=16, n_layer=1, vocab_size=200)

In [58]:
# Instantiate one block from the small configuration defined above.
block = MambaBlock(args=args)

In [59]:
# Total number of scalar parameters in the block. For the configuration above:
# in_proj 1024 + conv1d 160 + x_proj 1056 + dt_proj 64 + A_log 512 + D 32 + out_proj 512.
sum(param.numel() for param in block.parameters())

3360