In [6]:
import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape

ModuleNotFoundError: No module named 'mamba_ssm'

In [2]:
import torch
import torch.nn as nn

class MambaBlock(nn.Module):
    def __init__(self, input_dim, state_dim):
        super(MambaBlock, self).__init__()
        self.input_dim = input_dim
        self.state_dim = state_dim
        
        # Define the matrices A, B, C as learnable parameters
        self.A = nn.Parameter(torch.randn(state_dim, state_dim))
        self.B = nn.Parameter(torch.randn(state_dim, input_dim))
        self.C = nn.Parameter(torch.randn(input_dim, state_dim))

        # Optional: Bias terms for more flexibility
        self.bias = nn.Parameter(torch.randn(state_dim))

    def forward(self, x, state):
        """
        Forward pass of the Mamba block.
        x: Input tensor (batch_size, input_dim)
        state: Previous state tensor (batch_size, state_dim)
        """
        # Update state based on previous state and current input
        # Ensure matrix multiplication dimensions align correctly
        new_state = torch.tanh(torch.matmul(state, self.A.t()) + torch.matmul(x, self.B.t()) + self.bias)
        
        # Compute the output
        output = torch.matmul(new_state, self.C.t())
        return output, new_state

# Example usage
input_dim = 10
state_dim = 20
mamba_block = MambaBlock(input_dim, state_dim)

# Random input and initial state
x = torch.randn(1, input_dim)
state = torch.randn(1, state_dim)

output, new_state = mamba_block(x, state)
print("Output:", output)
print("New State:", new_state)


Output: tensor([[-7.8456,  5.7759, -4.0456, -5.5398, -3.8413,  3.7264, -6.8066,  7.5594,
         -0.9572,  7.0396]], grad_fn=<MmBackward0>)
New State: tensor([[-0.7964,  1.0000, -0.7797,  0.9971, -0.9998, -1.0000,  0.9067, -0.9998,
         -0.9986,  1.0000, -0.4037, -0.9912,  0.9935, -0.9994, -0.9562,  1.0000,
         -0.4218,  1.0000, -1.0000,  1.0000]], grad_fn=<TanhBackward0>)


In [4]:
mamba_block

MambaBlock()

In [1]:
import transformers

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
input_ids = tokenizer("Hey how are you doing?", return_tensors= "pt")["input_ids"]

out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d


["Hey how are you doing?\n\nI'm so glad you're here."]


In [6]:
from torchinfo import summary

summary(model, depth=10)

Layer (type:depth-idx)                   Param #
MambaForCausalLM                         --
├─MambaModel: 1-1                        --
│    └─Embedding: 2-1                    38,615,040
│    └─ModuleList: 2-2                   --
│    │    └─MambaBlock: 3-1              --
│    │    │    └─MambaRMSNorm: 4-1       768
│    │    │    └─MambaMixer: 4-2         26,112
│    │    │    │    └─Conv1d: 5-1        7,680
│    │    │    │    └─SiLU: 5-2          --
│    │    │    │    └─Linear: 5-3        2,359,296
│    │    │    │    └─Linear: 5-4        122,880
│    │    │    │    └─Linear: 5-5        75,264
│    │    │    │    └─Linear: 5-6        1,179,648
│    │    └─MambaBlock: 3-2              --
│    │    │    └─MambaRMSNorm: 4-3       768
│    │    │    └─MambaMixer: 4-4         26,112
│    │    │    │    └─Conv1d: 5-7        7,680
│    │    │    │    └─SiLU: 5-8          --
│    │    │    │    └─Linear: 5-9        2,359,296
│    │    │    │    └─Linear: 5-10       122,880
│    │    │ 

In [5]:
from transformers import PreTrainedModel, PretrainedConfig
import torch.nn as nn

class MambaConfig(PretrainedConfig):
    model_type = "mamba"
    def __init__(self, vocab_size=50257, max_position_embeddings=512,
                 num_hidden_layers=12, hidden_size=768, num_attention_heads=12,
                 intermediate_size=3072, hidden_act="gelu",
                 hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1,
                 layer_norm_eps=1e-12, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.num_hidden_layers = num_hidden_layers
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.layer_norm_eps = layer_norm_eps
config = MambaConfig()

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class MambaModel(PreTrainedModel):
    config_class = MambaConfig
    
    def __init__(self, config):
        super().__init__(config)
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([MambaBlock(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, input_ids):
        output = self.embeddings(input_ids)
        for layer in self.layers:
            output = layer(output)
        output = self.ln_f(output)
        return output


In [1]:
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
model_id = "state-spaces/mamba-130m-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=2e-3
)
lora_config =  LoraConfig(
        r=8,
        target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
        task_type="CAUSAL_LM",
        bias="none"
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="quote",
)
trainer.train()

  from .autonotebook import tqdm as notebook_tqdm
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d
Map: 100%|██████████| 2508/2508 [00:00<00:00, 33819.05 examples/s]


Step,Training Loss
10,3.5087
20,3.1921
30,3.2111
40,3.2446
50,3.0959
60,3.2923
