<a href="https://colab.research.google.com/github/Dhanushranga1/AIW4-MambaModels/blob/main/mamba_with_zeta.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
pip install -U zetascale

Collecting zetascale
  Downloading zetascale-2.8.6-py3-none-any.whl.metadata (23 kB)
Collecting argparse<2.0.0,>=1.4.0 (from zetascale)
  Downloading argparse-1.4.0-py2.py3-none-any.whl.metadata (2.8 kB)
Collecting beartype (from zetascale)
  Downloading beartype-0.21.0-py3-none-any.whl.metadata (33 kB)
Collecting bitsandbytes (from zetascale)
  Downloading bitsandbytes-0.46.0-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting colt5-attention (from zetascale)
  Downloading CoLT5_attention-0.11.1-py3-none-any.whl.metadata (737 bytes)
Collecting einops-exts==0.0.4 (from zetascale)
  Downloading einops_exts-0.0.4-py3-none-any.whl.metadata (621 bytes)
Collecting joblib<1.4.0,>=1.3.0 (from zetascale)
  Downloading joblib-1.3.2-py3-none-any.whl.metadata (5.4 kB)
Collecting local-attention (from zetascale)
  Downloading local_attention-1.11.1-py3-none-any.whl.metadata (907 bytes)
Collecting loguru (from zetascale)
  Downloading loguru-0.7.3-py3-none-any.whl.metadata (22 kB)
Collec

In [1]:
import torch
from torch import nn, Tensor
from zeta import SSM

In [14]:
import torch
import torch.nn as nn
from torch import Tensor

# Dummy SSM for testing
class SSM(nn.Module):
    def __init__(self, dim, dt_rank, dim_inner, d_state):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return self.linear(x)

class CobraBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        dt_rank: int,
        dim_inner: int,
        d_state: int,
        channels: int = 64
    ):
        super().__init__()

        self.input_proj = nn.Linear(dim, channels)

        self.conv = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            padding=1,
            dilation=1,
            groups=1
        )

        self.swish = nn.SiLU()

        self.ssm = SSM(
            channels,
            dt_rank,
            dim_inner,
            d_state
        )

        self.output_proj = nn.Linear(channels, dim)

    def forward(self, x: Tensor):
        # x: [batch, seq_len, dim]
        skip = x
        x = self.input_proj(x)  # [B, L, C]

        x_one = x.transpose(1, 2)          # [B, C, L] for conv1d
        x_one = self.conv(x_one)
        x_one = self.swish(x_one)
        x_one = x_one.transpose(1, 2)      # back to [B, L, C]
        x_one = self.ssm(x_one)

        x_two = self.swish(x)

        # Apply element-wise multiplication instead of matmul for same shape
        out = x_one * x_two

        out = out + x  # Residual connection (after proj)
        out = self.output_proj(out)
        out = out + skip  # Final residual

        return out

# # Example input: [batch, seq_len, dim]
# x = torch.randn(1, 64, 256)
# block = CobraBlock(
#     dim=256,
#     dt_rank=8,
#     dim_inner=256,
#     d_state=256
# )
# out = block(x)
# print(out.shape)  # Should print: torch.Size([1, 64, 256])


torch.Size([1, 64, 256])


In [23]:
import torch
import torch.nn as nn
from zeta.nn.modules import TextTokenEmbedding  # Ensure this module is correctly installed

# Assume CobraBlock is already defined and working as fixed earlier
# from cobra_block import CobraBlock  # You must import your working CobraBlock here

class Cobra(nn.Module):
    def __init__(
        self,
        dim: int,
        dt_rank: int,
        dim_inner: int,
        d_state: int,
        channels: int = 64,
        num_tokens: int = 10000,
        depth: int = 12,
        *args,
        **kwargs
    ):
        super().__init__()
        self.dim = dim
        self.dt_rank = dt_rank
        self.dim_inner = dim_inner
        self.d_state = d_state
        self.channels = channels
        self.num_tokens = num_tokens
        self.depth = depth

        # Token embeddings
        self.embed = TextTokenEmbedding(
            dim,
            num_tokens,
            l2norm_embed=True
        )

        # Transformer-style layers
        self.layers = nn.ModuleList([
            CobraBlock(
                dim,
                dt_rank,
                dim_inner,
                d_state,
                channels,
                *args,
                **kwargs
            ) for _ in range(depth)
        ])

        # Final normalization
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        # Input x: [batch, sequence] -> Token indices
        x = self.embed(x)  # Output: [batch, seq_len, dim]
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        return x

# Test input
x = torch.randint(0, 10000, (1, 64))  # [batch, seq_len]
model = Cobra(
    dim=256,
    dt_rank=8,
    dim_inner=256,
    d_state=256,
    channels=64,
    num_tokens=10000,
    depth=12
)
out = model(x)
print(out.shape)  # Should be torch.Size([1, 64, 256])


torch.Size([1, 64, 256])


In [24]:
x = torch.randint(0, model.num_tokens, (2, 64))  # batch=2, seq_len=64
out = model(x)  # Should output: [2, 64, 256]
print("Output shape:", out.shape)

# Gradient check
out.sum().backward()
print("Gradient check passed.")


Output shape: torch.Size([2, 64, 256])
Gradient check passed.


In [28]:
import torch
import time

# dummy input: batch = 2, seq_len = 64
x = torch.randint(0, 10000, (2, 64))

model = Cobra(
    dim=256,
    dt_rank=8,
    dim_inner=256,
    d_state=256,
    channels=64,
    num_tokens=10000,
    depth=4
)

model.eval()
with torch.no_grad():
    start = time.time()
    out = model(x)
    end = time.time()

print(out.shape)
print("inference time:", round((end - start) * 1000, 2), "ms")


torch.Size([2, 64, 256])
inference time: 3.69 ms


In [30]:
pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [31]:
from torchinfo import summary

# input_size: (batch_size, seq_len)
summary(model, input_size=(2, 64), device="cpu")


Layer (type:depth-idx)                   Output Shape              Param #
Cobra                                    [2, 64, 256]              --
├─TextTokenEmbedding: 1-1                [2, 64, 256]              --
│    └─Embedding: 2-1                    [2, 64, 256]              2,560,000
├─ModuleList: 1-2                        --                        --
│    └─CobraBlock: 2-2                   [2, 64, 256]              --
│    │    └─Linear: 3-1                  [2, 64, 64]               16,448
│    │    └─Conv1d: 3-2                  [2, 64, 64]               12,352
│    │    └─SiLU: 3-3                    [2, 64, 64]               --
│    │    └─SSM: 3-4                     [2, 64, 64]               4,160
│    │    └─SiLU: 3-5                    [2, 64, 64]               --
│    │    └─Linear: 3-6                  [2, 64, 256]              16,640
│    └─CobraBlock: 2-3                   [2, 64, 256]              --
│    │    └─Linear: 3-7                  [2, 64, 64]           