<a href="https://colab.research.google.com/github/aquibjaved/BitsAndPieces-Computation/blob/main/micorjamba.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch as T
import torch.nn as nn
import torch.nn.functional as F

In [16]:
class Mamba(nn.Module):
    """Minimal selective state-space (Mamba-style) sequence mixer.

    ``forward`` runs a sequential scan over the time axis.  With a per-channel
    state ``h`` of shape ``(batch, d, n)`` each step computes::

        h_t = h_{t-1} * exp(A * delta_t) + x_t * B_t
        y_t = C_t * sum_n(h_t) + D * x_t

    ``delta_t``, ``B_t`` and ``C_t`` are input-dependent and come from the
    projection ``p``.  ``C_t`` is a single scalar per position (the projection
    reserves exactly one output feature for it), and the input term is not
    scaled by ``delta_t``.

    ``A`` and ``D`` are randomly initialised *buffers*: they are saved in the
    state dict and follow ``.to()``/``.cuda()``, but they are not returned by
    ``parameters()``.

    Args:
        d: Channel (model) width of the input and output.
        n: State size per channel.
    """

    def __init__(self, d, n=16):
        super().__init__()
        self.d, self.n = d, n
        # One projection yields, in order: x and gate z (2*d), the raw step
        # size (n), B (n) and the scalar C (1).
        self.p = nn.Linear(d, 2 * d + 2 * n + 1)
        # Depthwise conv; padding=3 with kernel 4 is cropped to the first
        # `l` outputs in forward(), so position t only sees inputs t-3..t.
        self.c = nn.Conv1d(d, d, 4, padding=3, groups=d)
        self.dt = nn.Linear(n, d)  # lifts the raw step size to one value per channel
        self.o = nn.Linear(d, d)   # output projection
        # A is strictly negative, so exp(A * delta) with delta > 0 is a decay in (0, 1).
        self.register_buffer('A', -T.exp(T.randn(d, n) * 0.5))
        self.register_buffer('D', T.randn(d))

    def forward(self, u):
        """Mix a batch of sequences.

        Args:
            u: Tensor of shape ``(batch, seq_len, d)``.

        Returns:
            Tensor of shape ``(batch, seq_len, d)``.
        """
        batch, seq_len, _ = u.shape
        xz, delta, B, C = self.p(u).split([2 * self.d, self.n, self.n, 1], dim=-1)
        x, z = xz.chunk(2, dim=-1)
        # Conv1d wants (batch, channels, time); crop the right-hand padding
        # so the convolution stays causal.
        x = F.silu(self.c(x.transpose(1, 2))[:, :, :seq_len].transpose(1, 2))
        delta = F.softplus(self.dt(delta))  # (batch, seq_len, d), strictly positive

        # Allocate the state with the activations' dtype and device so the
        # module also works after .half() / .bfloat16() / .double().
        h = x.new_zeros(batch, self.d, self.n)
        outputs = []
        for i in range(seq_len):
            decay = T.exp(self.A * delta[:, i].unsqueeze(-1))           # (batch, d, n)
            h = h * decay + x[:, i].unsqueeze(-1) * B[:, i].unsqueeze(1)
            # C[:, i] is (batch, 1); broadcast it over channels and state.
            readout = (h * C[:, i].unsqueeze(-1)).sum(-1)               # (batch, d)
            outputs.append(readout + self.D * x[:, i])
        return self.o(T.stack(outputs, 1) * F.silu(z))

class MicroJamba(nn.Module):
    """Small hybrid language model that stacks Mamba and attention blocks.

    Token ids are embedded, passed through ``l`` pre-norm residual blocks and
    projected to vocabulary logits.  The first ``int(l * r)`` blocks use a
    ``Mamba`` mixer; the remaining blocks use 4-head
    ``nn.MultiheadAttention`` with a causal mask, so the logits at position
    ``t`` depend only on tokens ``0..t``.

    Args:
        v: Vocabulary size.
        d: Model width.
        l: Total number of blocks.
        r: Fraction of the blocks (rounded down) that are Mamba blocks.
    """

    def __init__(self, v=8192, d=256, l=8, r=.875):
        super().__init__()
        self.e = nn.Embedding(v, d)
        n_mamba = int(l * r)
        # Each entry is used purely as a (norm, mixer) pair; the Sequential is
        # never called directly because attention needs extra arguments.
        self.b = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(d),
                Mamba(d) if i < n_mamba else nn.MultiheadAttention(d, 4, batch_first=True),
            )
            for i in range(l)
        ])
        self.h = nn.Sequential(nn.LayerNorm(d), nn.Linear(d, v, bias=False))

    def forward(self, x):
        """Return next-token logits.

        Args:
            x: Integer token ids of shape ``(batch, seq_len)``.

        Returns:
            Float tensor of shape ``(batch, seq_len, v)``.
        """
        x = self.e(x)
        seq_len = x.shape[-2]
        # Boolean mask, True = "may not attend": hides every future position.
        causal_mask = T.triu(
            T.ones(seq_len, seq_len, dtype=T.bool, device=x.device), diagonal=1
        )
        for norm, mixer in self.b:
            normed = norm(x)  # compute once; attention reuses it for q, k and v
            if isinstance(mixer, nn.MultiheadAttention):
                # need_weights=False skips the unused averaged attention map.
                mixed = mixer(normed, normed, normed,
                              attn_mask=causal_mask, need_weights=False)[0]
            else:
                mixed = mixer(normed)
            x = x + mixed
        return self.h(x)

# Quick tests
def test_model():
    """Smoke-test ``MicroJamba`` and print a short report.

    Checks model construction, output shape, gradient flow, several sequence
    lengths, (when CUDA is available) the memory allocated by one forward
    pass, and the Mamba/attention layer split.

    Raises:
        AssertionError: If an output shape is wrong, a parameter receives no
            gradient, or the layer split does not match the requested ratio.
    """
    print("🧪 Testing Micro Jamba Model...")

    vocab, n_layers, mamba_ratio = 1000, 4, 0.75

    # Test 1: Model construction
    model = MicroJamba(v=vocab, d=128, l=n_layers, r=mamba_ratio)
    x = T.randint(0, vocab, (2, 10))
    print(f"✓ Model created: {sum(p.numel() for p in model.parameters())/1e6:.2f}M params")

    # Test 2: Forward pass
    out = model(x)
    assert out.shape == (*x.shape, vocab), f"unexpected output shape {tuple(out.shape)}"
    print(f"✓ Forward pass: input {x.shape} → output {out.shape}")

    # Test 3: Gradient flow
    loss = out.mean()
    loss.backward()
    no_grad_params = [name for name, p in model.named_parameters() if p.grad is None]
    assert not no_grad_params, f"parameters without gradient: {no_grad_params}"
    print("✓ Backward pass: gradients computed")

    # Test 4: Different sequence lengths (shape check only, so skip autograd)
    for seq_len in [5, 20, 50]:
        x = T.randint(0, vocab, (1, seq_len))
        with T.no_grad():
            out = model(x)
        assert out.shape == (1, seq_len, vocab), f"unexpected output shape {tuple(out.shape)}"
        print(f"✓ Seq length {seq_len}: output shape {out.shape}")

    # Test 5: Memory usage
    if T.cuda.is_available():
        model = model.cuda()
        x = T.randint(0, vocab, (4, 100)).cuda()
        T.cuda.synchronize()
        mem_before = T.cuda.memory_allocated() / 1e6
        # Grad mode stays on, so the delta includes activations kept for backward.
        out = model(x)
        T.cuda.synchronize()
        mem_after = T.cuda.memory_allocated() / 1e6
        print(f"✓ GPU memory: {mem_after - mem_before:.2f}MB for batch")

    # Test 6: Layer composition
    mamba_count = sum(1 for seq in model.b if isinstance(seq[1], Mamba))
    attn_count = len(model.b) - mamba_count
    assert len(model.b) == n_layers, f"expected {n_layers} layers, got {len(model.b)}"
    assert mamba_count == int(n_layers * mamba_ratio), f"unexpected Mamba layer count {mamba_count}"
    print(f"✓ Layers: {mamba_count} Mamba, {attn_count} Attention ({mamba_count/len(model.b)*100:.0f}% Mamba)")

    print("\n✅ All tests passed!")

# Minimal inference test
def quick_inference():
    """Sample one next token from a freshly initialised tiny ``MicroJamba``.

    Builds an untrained model, feeds it the fixed prompt ``[1, 2, 3, 4, 5]``,
    draws one token from the softmax over the last position's logits and
    prints the prompt together with the sampled token id.
    """
    net = MicroJamba(v=100, d=64, l=2).eval()
    prompt = T.tensor([[1, 2, 3, 4, 5]])
    with T.no_grad():
        logits = net(prompt)
        # Only the final position of the single batch element predicts the next token.
        last_step_logits = logits[0, -1]
        distribution = T.softmax(last_step_logits, dim=-1)
        sampled = T.multinomial(distribution, 1)
    print(f"\n🎲 Quick inference: {prompt.tolist()} → next token: {sampled.item()}")


In [17]:
# Run the smoke tests; each check prints one "✓" line.
test_model()


🧪 Testing Micro Jamba Model...
✓ Model created: 0.49M params
✓ Forward pass: input torch.Size([2, 10]) → output torch.Size([2, 10, 1000])
✓ Backward pass: gradients computed
✓ Seq length 5: output shape torch.Size([1, 5, 1000])
✓ Seq length 20: output shape torch.Size([1, 20, 1000])
✓ Seq length 50: output shape torch.Size([1, 50, 1000])
✓ GPU memory: 30.40MB for batch
✓ Layers: 3 Mamba, 1 Attention (75% Mamba)

✅ All tests passed!


In [18]:
# Sample one token from an untrained model; the result varies between runs.
quick_inference()


🎲 Quick inference: [[1, 2, 3, 4, 5]] → next token: 68
