# Mamba+S4D Parity Experiment

## Introduction

This notebook compares three types of SSM-based models on the parity task: S4D, Mamba, and a hybrid that stacks n S4D layers and m Mamba layers. A simple RNN serves as a baseline.
For Mamba, we use the PyTorch version provided in ....
For S4D, we use the main implementation of S4 in ... and wrap its diagonal SSM kernel in a layer that processes the sequence step by step.
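
To make "sequential SSM with a diagonal transition matrix" concrete, here is a minimal pure-Python sketch of the recurrence. It assumes a real-valued, already discretized system; the name `diagonal_ssm_scan` and the parameters `a_bar`, `b_bar`, `c` are illustrative and do not come from the S4 code, which uses complex-valued parameters.

```python
def diagonal_ssm_scan(u, a_bar, b_bar, c):
    """Step-by-step scan of a diagonal state space model.

    h_t = a_bar * h_{t-1} + b_bar * u_t   (elementwise, since A is diagonal)
    y_t = sum_n c[n] * h_t[n]
    """
    h = [0.0] * len(a_bar)  # initial state
    ys = []
    for u_t in u:
        # A diagonal transition matrix updates each state entry independently.
        h = [a * h_n + b * u_t for a, h_n, b in zip(a_bar, h, b_bar)]
        ys.append(sum(c_n * h_n for c_n, h_n in zip(c, h)))
    return ys


# Impulse response of a two-dimensional state with decay rates 0.5 and 0.25.
ys = diagonal_ssm_scan([1.0, 0.0, 0.0], a_bar=[0.5, 0.25], b_bar=[1.0, 1.0], c=[1.0, 1.0])
print(ys)  # [2.0, 0.75, 0.3125]
```

Each state entry decays geometrically on its own, which is what lets the diagonal form be evaluated one time step at a time with elementwise operations only.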

## Download S4D and Mamba code

In [None]:
!pip install --quiet blackcellmagic
%load_ext blackcellmagic
# use %%black at each cell to reformat

In [None]:
! wget -O s4.py https://raw.githubusercontent.com/state-spaces/s4/refs/heads/main/models/s4/s4.py
! wget -O mamba.py https://raw.githubusercontent.com/johnma2006/mamba-minimal/refs/heads/master/model.py
! sed '27,152d' s4.py > tmp && mv tmp s4.py
! grep -v 'lightning' s4.py > tmp && mv tmp s4.py

--2025-07-06 20:28:01--  https://raw.githubusercontent.com/state-spaces/s4/refs/heads/main/models/s4/s4.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 76217 (74K) [text/plain]
Saving to: ‘s4.py’


2025-07-06 20:28:02 (38.3 MB/s) - ‘s4.py’ saved [76217/76217]

--2025-07-06 20:28:02--  https://raw.githubusercontent.com/johnma2006/mamba-minimal/refs/heads/master/model.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 12940 (13K) [text/plain]
Saving to: ‘mamba.py’


2025-07-06 20:28:02 (118 MB/s) - ‘mamba.py’ saved [12940/12940]


In [None]:
! pip install --quiet lrcurve

## Imports

In [None]:
import random
import itertools
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
from lrcurve import PlotLearningCurve
from typing import List
from tqdm import tqdm

## Set Device and Random Seed

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [None]:
SEED = 6666

torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

## Models

Simple RNN model based on `torch.nn.RNN`.

In [None]:
class RNN(nn.Module):
    def __init__(
        self,
        embedding_dim,
        hidden_size,
        vocab_size,
        num_layers,
        dropout_rate=0.0,
        non_linearity="tanh",
    ):
        super(RNN, self).__init__()

        self.num_layers = num_layers
        self.n_embd = embedding_dim
        self.hidden_size = hidden_size

        self.word_embeddings = nn.Embedding(vocab_size, self.n_embd)

        # The RNN takes word embeddings as input, and outputs hidden states
        # with dimensionality hidden_size.
        self.rnn = nn.RNN(
            input_size=self.n_embd,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,
            dropout=dropout_rate,
            nonlinearity=non_linearity,
        )

        self.head = nn.Linear(self.hidden_size, vocab_size)

    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        rnn_out, state = self.rnn(embeds)
        logits = self.head(rnn_out)
        return logits

Sequential S4D model based on `SSMKernelDiag` from `https://github.com/state-spaces/s4/blob/main/models/s4/s4.py`.

In [None]:
from s4 import *


class S4DLayer(nn.Module):
    def __init__(
        self,
        d_model,
        d_state,
        channels=1,
        activation="gelu",
        dropout=0.0,
        use_residual=True,
        **ssm_kwargs
    ):
        super().__init__()
        self.kernel = SSMKernelDiag(
            d_model=d_model, d_state=d_state, channels=channels, **ssm_kwargs
        )
        self.channel_mixer = nn.Linear(channels * d_model, d_model)
        self.activation = Activation(activation)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)
        self.use_residual = use_residual
        self.d_model = d_model
        self.d_state = d_state
        self.channels = channels

    def forward(self, x):
        """
        Args:
            x: shape (b, l, d)

        Returns:
            output: shape (b, l, d)
        """
        self.kernel._setup_step()
        B, L, d = x.shape
        state = self.default_state(B)
        outputs = []
        for t in range(L):
            x_t = x[:, t]  # (b, d)
            y_t, state = self.kernel.step(x_t, state)
            outputs.append(y_t)
        # Each y_t is (b, channels, d), so the result is (b, l, d) only when channels == 1.
        y = torch.concatenate(outputs, dim=1)
        y = self.channel_mixer(y)
        y = self.activation(y)
        y = self.dropout(y)
        if self.use_residual:
            y = y + x
        y = self.norm(y)
        return y

    def default_state(self, batch_size):
        return self.kernel.default_state(batch_size)


class S4D(nn.Module):
    def __init__(
        self,
        vocab_size,
        d_model,
        d_state,
        n_layers,
        activation="gelu",
        dropout=0.0,
        use_residual=True,
        **ssm_kwargs
    ):
        super().__init__()
        self.output_proj = nn.Linear(d_model, vocab_size)
        self.embedding = nn.Embedding(vocab_size, d_model)

        self.seq = nn.Sequential(
            *[
                S4DLayer(
                    d_model=d_model,
                    d_state=d_state,
                    activation=activation,
                    dropout=dropout,
                    use_residual=use_residual,
                    **ssm_kwargs
                )
                for _ in range(n_layers)
            ]
        )

        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_state = d_state
        self.n_layers = n_layers

    def forward(self, x):
        """
        Args:
            x: shape (b, l)

        Returns:
            y: shape (b, l, vocab_size)
        """
        x = self.embedding(x)  # (b, l, d)
        x = self.seq(x)
        x = self.output_proj(x)
        return x

Mamba model from `https://github.com/johnma2006/mamba-minimal/blob/master/model.py`.

In [None]:
from mamba import *

Hybrid model obtained from stacking S4D and Mamba layers.

In [None]:
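# Imported explicitly rather than relying on `from mamba import *`.
import math
from dataclasses import dataclass
from typing import Union


@dataclass
class HybridArgs:
    # Required settings
    d_model: int
    d_state: int
    channels: int
    vocab_size: int
    # S4D layer options
    activation: str = "gelu"
    dropout: float = 0.0
    use_residual: bool = True
    init: str = "legs"
    # Options named after ModelArgs in mamba.py; HybridModel below
    # forwards only pad_vocab_size_multiple to the Mamba blocks.
    pad_vocab_size_multiple: int = 1
    expand: int = 2
    dt_rank: Union[int, str] = "auto"
    d_conv: int = 4
    conv_bias: bool = True
    bias: bool = False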

    def __post_init__(self):
        self.d_inner = int(self.expand * self.d_model)

        if self.dt_rank == "auto":
            self.dt_rank = math.ceil(self.d_model / 16)

        if self.vocab_size % self.pad_vocab_size_multiple != 0:
            self.vocab_size += (
                self.pad_vocab_size_multiple
                - self.vocab_size % self.pad_vocab_size_multiple
            )


class HybridModel(nn.Module):
    def __init__(
        self, args: HybridArgs, mamba_indices: List, num_layers: int, **ssm_kwargs
    ):
        """
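        Args:
            args: HybridArgs holding the S4D and Mamba settings
            mamba_indices: 0-based positions of the Mamba blocks; every
                other position holds an S4D layer
            num_layers: total number of layers in the stack
            **ssm_kwargs: extra keyword arguments passed to each S4DLayer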
        """
        super(HybridModel, self).__init__()
        self.num_layers = num_layers
        d_model = args.d_model
        self.d_model = d_model
        channels = args.channels
        activation = args.activation
        dropout = args.dropout
        use_residual = args.use_residual
        init = args.init
        vocab_size = args.vocab_size
        d_state = args.d_state
        self.d_state = d_state
        pad_vocab_size_multiple = args.pad_vocab_size_multiple
        mamba_args = ModelArgs(
            d_model=d_model,
            n_layer=1,
            vocab_size=vocab_size,
            d_state=d_state,
            pad_vocab_size_multiple=pad_vocab_size_multiple,
        )

        s4d_indices = [i for i in range(num_layers) if i not in mamba_indices]
        self.embedding = nn.Embedding(vocab_size, d_model)
        # A single RMSNorm instance, shared by every Mamba block in forward().
        self.norm = RMSNorm(d_model)

        list_layers = [None] * num_layers
        self.layer_names = ["s4d" for _ in range(num_layers)]
        for i in mamba_indices:
            self.layer_names[i] = "mamba"
        for i in range(len(mamba_indices)):
            mamba_block = MambaBlock(mamba_args)
            list_layers[mamba_indices[i]] = mamba_block
        for i in s4d_indices:
            s4d_block = S4DLayer(
                d_model=d_model,
                d_state=d_state,
                channels=channels,
                activation=activation,
                dropout=dropout,
                use_residual=use_residual,
                init=init,
                **ssm_kwargs
            )
            list_layers[i] = s4d_block
        self.layers = nn.ModuleList(list_layers)
        # Project the final hidden states to vocabulary logits.
        self.out_layer = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """
        Args:
            x: (B, L)

        Returns:
            y: (B, L, vocab_size)
        """
        x = self.embedding(x)
        for i, layer in enumerate(self.layers):
            if self.layer_names[i] == "mamba":
                x = layer(self.norm(x)) + x
            else:
                x = layer(x)
        x = self.out_layer(x)
        return x
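
The way `mamba_indices` determines the layer order can be sketched in isolation. This is a minimal illustration, not part of the notebook's model code: the helper name `layer_layout` is hypothetical, and the range check is an added assumption (the class above does not validate its indices).

```python
def layer_layout(num_layers, mamba_indices):
    """Return the layer type at each position: "s4d" unless listed as Mamba."""
    mamba = set(mamba_indices)
    bad = sorted(i for i in mamba if not 0 <= i < num_layers)
    if bad:
        raise ValueError(f"mamba_indices out of range: {bad}")
    return ["mamba" if i in mamba else "s4d" for i in range(num_layers)]


layout = layer_layout(num_layers=2, mamba_indices=[1])
print(layout)  # ['s4d', 'mamba']
```

With `mamba_indices=[1]` and `num_layers=2`, the configuration used later in this notebook, the stack is one S4D layer followed by one Mamba block.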

## Training and Evaluation Functions

In [None]:
def train(model, dataloader, num_epochs, learning_rate):
    model.train()
    optimizer = optim.AdamW(
        model.parameters(), lr=learning_rate, fused=torch.cuda.is_available()
    )
    loss_fn = torch.nn.CrossEntropyLoss()

    mappings = {
        "loss": {"line": "train", "facet": "loss"},
        "acc": {"line": "train", "facet": "acc"},
    }

    facet_config = {
        "loss": {"name": "Cross-Entropy", "limit": [0, None], "scale": "linear"},
        "acc": {"name": "Accuracy", "limit": [0, 1], "scale": "linear"},
    }

    plot = PlotLearningCurve(
        mappings=mappings,
        facet_config=facet_config,
        xaxis_config={"name": "Step", "limit": [0, None]},
    )

    model = model.to(device)

    with plot:
        t = 0
        for epoch in range(num_epochs):
            for batch_num, batch in enumerate(dataloader):
                input_batch, output_batch = batch
                optimizer.zero_grad()
                x = input_batch.to(device)
                y = output_batch.to(device)
                output = model(x)
                d_output = output.shape[-1]
                batch_loss = loss_fn(output.view(-1, d_output), y.view(-1).long())
                batch_loss.backward()
                prediction = torch.argmax(output, dim=2)
                accuracy = torch.mean((prediction == y).float()).item()
                optimizer.step()
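                # Log every step; .item() converts the loss tensor to a plain float.
                plot.append(
                    t,
                    {
                        "loss": batch_loss.item(),
                        "acc": accuracy,
                    },
                )
                plot.draw()
                t += 1
    # Accuracy of the final training batch only, not of the whole training set.
    return accuracy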

In [None]:
@torch.no_grad()
def evaluate(model, dataloader):
    model.eval()
    correct = total = 0
    for input_batch, output_batch in tqdm(dataloader, desc="Evaluating", leave=False):
        x = input_batch.to(device)
        y = output_batch.to(device)
        output = model(x)
        preds = output.argmax(dim=2)
        correct += (preds == y).sum().item()
        total += y.numel()
    return correct / total if total > 0 else 0.0
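
Note that `evaluate` counts correct predictions per position, not per sequence. The pure-Python sketch below mirrors that computation on a tiny hand-made example; the helper name `token_accuracy` and the toy scores are assumptions made for illustration.

```python
def token_accuracy(logits, targets):
    """Fraction of positions whose highest-scoring class equals the target.

    logits:  list of sequences, each a list of per-class score lists
    targets: list of sequences of class ids, same layout as logits
    """
    correct = total = 0
    for seq_logits, seq_targets in zip(logits, targets):
        for scores, target in zip(seq_logits, seq_targets):
            pred = max(range(len(scores)), key=scores.__getitem__)  # argmax
            correct += int(pred == target)
            total += 1
    return correct / total if total > 0 else 0.0


# One sequence of length 3 with two classes; the last position is wrong.
logits = [[[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]]
targets = [[1, 0, 0]]
acc = token_accuracy(logits, targets)
print(acc)
```

Here two of three positions are correct, so the score is 2/3 even though the sequence as a whole is not predicted perfectly.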

## Prepare Parity Dataset

In [None]:
# data generator for mod-n task
class Mod_n_Dataset(Dataset):
    def __init__(self, num_samples, seq_length, n):
        self.num_samples = num_samples
        self.seq_length = seq_length
        self.n = n

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # idx is ignored: every call draws a fresh random bit sequence.
        random_list = [random.choice([0, 1]) for _ in range(self.seq_length)]
        cumulative_sum = list(itertools.accumulate(random_list))
        mod_n_list = [x % self.n for x in cumulative_sum]

        # Convert to torch tensors
        random_list_tensor = torch.tensor(random_list, dtype=torch.int64)
        mod_n_list_tensor = torch.tensor(mod_n_list, dtype=torch.int64)

        return random_list_tensor, mod_n_list_tensor


def generate_data_parity(batch_size, seq_length, num_samples):
    dataset = Mod_n_Dataset(num_samples, seq_length, n=2)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader
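
As a small worked example of what `Mod_n_Dataset` returns for `n=2`, the sketch below builds the running-parity targets for one fixed bit sequence. The sequence itself is an arbitrary illustration.

```python
import itertools
import operator

bits = [1, 0, 1, 1, 0, 1, 0, 0]

# Same construction as in __getitem__: cumulative sum, then mod 2.
targets = [s % 2 for s in itertools.accumulate(bits)]

# Equivalent view: a running XOR over the bits.
xor_targets = list(itertools.accumulate(bits, operator.xor))

print(targets)      # [1, 1, 0, 1, 1, 0, 0, 0]
print(xor_targets)  # [1, 1, 0, 1, 1, 0, 0, 0]
```

The target at each position is 1 exactly when the number of ones seen so far is odd.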

In [None]:
dataloader_train = generate_data_parity(batch_size=128, seq_length=8, num_samples=1000)
dataloader_test = generate_data_parity(batch_size=32, seq_length=10000, num_samples=100)

## Global Variables

In [None]:
NUM_EPOCHS = 1000
LEARNING_RATE = 5e-4
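
For orientation, the arithmetic below works out how many optimizer steps these settings imply and how large the final batch is. It assumes the `DataLoader` default `drop_last=False`, which is what `generate_data_parity` uses; the variable names are illustrative.

```python
import math

num_samples, batch_size, seq_length, num_epochs = 1000, 128, 8, 1000

steps_per_epoch = math.ceil(num_samples / batch_size)  # last batch is partial
total_steps = steps_per_epoch * num_epochs
last_batch_size = num_samples - (steps_per_epoch - 1) * batch_size
positions_in_last_batch = last_batch_size * seq_length

print(steps_per_epoch, total_steps, last_batch_size, positions_in_last_batch)
# 8 8000 104 832
```

Because `train` returns the accuracy of the final batch, the reported training accuracy is a fraction out of 832 positions; a value of 0.997596, for example, is consistent with 830 of 832 positions being correct.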

## Train & Evaluate RNN

In [None]:
model_rnn = RNN(
    embedding_dim=2,
    hidden_size=8,
    vocab_size=2,
    num_layers=1,
)

In [None]:
train_accuracy_rnn = train(model_rnn, dataloader_train, num_epochs=NUM_EPOCHS, learning_rate=LEARNING_RATE)

In [None]:
test_accuracy_rnn = evaluate(model_rnn, dataloader_test)
print('Test accuracy:', test_accuracy_rnn)

Test accuracy: 1.0

## Train & Evaluate S4D

In [None]:
model_s4d = S4D(
    vocab_size=2,
    d_model=8,
    d_state=16,
    n_layers=2,
    init='legs'
)

In [None]:
train_accuracy_s4d = train(model_s4d, dataloader_train, num_epochs=NUM_EPOCHS, learning_rate=LEARNING_RATE)

In [None]:
test_accuracy_s4d = evaluate(model_s4d, dataloader_test)
print('Test accuracy:', test_accuracy_s4d)

Test accuracy: 0.500867

## Train & Evaluate Mamba

In [None]:
model_args = ModelArgs(
    d_model=8,
    n_layer=2,
    vocab_size=2,
    d_state=16,
    pad_vocab_size_multiple=1,
)

model_mamba = Mamba(model_args)

In [None]:
train_accuracy_mamba = train(model_mamba, dataloader_train, num_epochs=NUM_EPOCHS, learning_rate=LEARNING_RATE)

In [None]:
test_accuracy_mamba = evaluate(model_mamba, dataloader_test)
print('Test accuracy:', test_accuracy_mamba)

Test accuracy: 0.500364

## Train & Evaluate Mamba+S4D

In [None]:
hybrid_args = HybridArgs(
    d_model=8,
    vocab_size=2,
    d_state=16,
    channels=1,
    pad_vocab_size_multiple=1,
)

model_hybrid = HybridModel(hybrid_args, mamba_indices=[1], num_layers=2)

In [None]:
train_accuracy_hybrid = train(model_hybrid, dataloader_train, num_epochs=NUM_EPOCHS, learning_rate=LEARNING_RATE)

In [None]:
test_accuracy_hybrid = evaluate(model_hybrid, dataloader_test)
print('Test accuracy:', test_accuracy_hybrid)

Test accuracy: 0.500579

## Conclusion

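The parity experiments in this notebook show that only the RNN learns the parity task with an algorithm that generalizes to out-of-distribution (OOD) examples, here sequences much longer than those seen in training. S4D, Mamba, and the S4D+Mamba hybrid all fit the training data almost perfectly, but their test accuracy of about 0.5 is at chance level for a binary target.
The experimental setup is as follows:


training set: sequences of length 8
test set: sequences of length 10000

Training accuracy is measured on the final training batch; test accuracy is measured over all positions of all test sequences. Model performance on the train and test sets: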




In [None]:
import pandas as pd

data = {
    "training accuracy": [
        train_accuracy_rnn,
        train_accuracy_s4d,
        train_accuracy_mamba,
        train_accuracy_hybrid,
    ],
    "test accuracy": [
        test_accuracy_rnn,
        test_accuracy_s4d,
        test_accuracy_mamba,
        test_accuracy_hybrid,
    ],
}

row_labels = ["RNN", "S4D", "Mamba", "S4D+Mamba"]

df = pd.DataFrame(data, index=row_labels)
df

| model | training accuracy | test accuracy |
|---|---|---|
| RNN | 1.0 | 1.0 |
| S4D | 0.997596 | 0.500867 |
| Mamba | 1.0 | 0.500364 |
| S4D+Mamba | 1.0 | 0.500579 |
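
To illustrate what a "correct algorithm" for parity looks like, the sketch below implements it as a two-state recurrence and checks it against the dataset's target construction at both the training length and the test length. It also measures the accuracy of a predictor that always outputs 0, as a chance-level reference. The function name `running_parity` and the fixed seed are assumptions made for this illustration; this is not a description of what the trained RNN computes internally.

```python
import random
from itertools import accumulate


def running_parity(bits):
    """Two-state recurrence: state_t = state_{t-1} XOR bit_t."""
    state, out = 0, []
    for b in bits:
        state ^= b
        out.append(state)
    return out


rng = random.Random(0)
for length in (8, 10_000):
    bits = [rng.randint(0, 1) for _ in range(length)]
    targets = [s % 2 for s in accumulate(bits)]
    # The same two-state rule is exact at every sequence length.
    assert running_parity(bits) == targets

# Chance-level reference on the length-10000 sequence: always predict 0.
constant_acc = sum(t == 0 for t in targets) / len(targets)
print(f"constant predictor accuracy: {constant_acc:.3f}")
```

Since the rule needs only one bit of state, its correctness does not depend on sequence length, whereas a constant predictor stays close to 0.5, the level at which S4D, Mamba, and the hybrid score on the long test sequences.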
