## (1) Load model

In [1]:
from model import Mamba, ModelArgs
from transformers import AutoTokenizer

# Pretrained checkpoint to load. One of:
#     'state-spaces/mamba-2.8b-slimpj'
#     'state-spaces/mamba-2.8b'
#     'state-spaces/mamba-1.4b'
#     'state-spaces/mamba-790m'
#     'state-spaces/mamba-370m'
#     'state-spaces/mamba-130m'
pretrained_model_name = 'state-spaces/mamba-370m'

# NOTE(review): `model` is reassigned to a freshly initialised character-level
# Mamba in the training cell below, so these pretrained weights are never used by
# the generation cells. `Mamba.from_pretrained` lives in model.py (not shown) and
# presumably fetches/loads the checkpoint — consider dropping this cell if only
# the from-scratch model is wanted.
model = Mamba.from_pretrained(pretrained_model_name)
# NOTE(review): `tokenizer` is not referenced by any visible cell; training and
# generation use the character-level encode/decode defined in the next cell.
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

In [2]:
# =======================
# TRAINING CELL (IMPROVED)
# =======================
import torch
import torch.nn as nn
from torch.optim import Adam

# -----------------------------------
# 1. Load tiny dataset
# -----------------------------------
# Read the training corpus. An explicit encoding keeps the character vocabulary
# identical across platforms (open()'s default encoding is locale-dependent).
with open("data/tiny.txt", encoding="utf-8") as f:
    text = f.read()

# Build a character-level vocabulary plus one extra <unk> id for characters that
# never occur in the corpus. '<unk>' is a multi-character string, so it cannot
# collide with any single-character entry.
chars = sorted(set(text))
stoi = {s: i for i, s in enumerate(chars)}
stoi['<unk>'] = len(stoi)                # unknown character token (last id)
# Inverse mapping id -> token. It already contains '<unk>' because it is derived
# from stoi after the <unk> entry was added, so no separate assignment is needed.
itos = {i: s for s, i in stoi.items()}

def encode(s):
    """Convert a string into a 1-D ``torch.long`` tensor of vocabulary ids.

    Each character is looked up in the module-level ``stoi``; characters that
    are not in the vocabulary map to the id stored under ``'<unk>'``.
    """
    ids = []
    for ch in s:
        ids.append(stoi.get(ch, stoi['<unk>']))
    return torch.tensor(ids, dtype=torch.long)

def decode(t):
    """Convert a sequence of token ids back into a string.

    ``t`` may be a 1-D tensor or any iterable of integer-convertible ids. Ids
    present in the module-level ``itos`` become their token; any other id is
    rendered as the literal text ``'<unk>'`` (e.g. if the model's output layer is
    wider than the vocabulary — TODO confirm against model.py).
    """
    return "".join(itos.get(int(tok), '<unk>') for tok in t)


# Whole corpus as a single row of ids: shape (1, len(text)).
data = encode(text)
data = data.unsqueeze(0)  # batch dimension

device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed torch's global RNG before the model is built so weight initialisation,
# the random training crops and (when cells run in order) the sampling in the
# generation cells are repeatable on Restart & Run All. Bit-exact CUDA
# determinism may additionally require torch.use_deterministic_algorithms.
SEED = 0
torch.manual_seed(SEED)

# -----------------------------------
# 2. Create model
# NOTE(review): hybrid attention is said to be implemented inside model.py;
# that file is not shown here — confirm.
# -----------------------------------
model = Mamba(ModelArgs(
    d_model=64,
    n_layer=6,              # NOTE(review): assumed to control how many attention layers get inserted — confirm in model.py
    vocab_size=len(stoi),
)).to(device)

optimizer = Adam(model.parameters(), lr=2e-3)
criterion = nn.CrossEntropyLoss()

# -----------------------------------
# 3. Curriculum sequence lengths
# -----------------------------------
curriculum = [32, 64, 128]   # progressively longer sequences
STEPS_PER_STAGE = 50         # optimiser steps per stage; adjust depending on training time
LOG_EVERY = 10               # print the loss every LOG_EVERY steps

# Every window of seq_len inputs needs one extra token for the shifted target,
# so the corpus must be strictly longer than the longest stage. Fail fast here
# instead of crashing inside torch.randint mid-training.
num_tokens = data.size(1)
if num_tokens <= max(curriculum):
    raise ValueError(
        f"Corpus has {num_tokens} tokens; need more than {max(curriculum)} "
        "for the longest curriculum stage."
    )

# Ensure train mode (e.g. after an earlier eval()/generation call on this model).
model.train()

print("\nTraining modified Mamba with Curriculum + Hybrid Attention...")
for stage, seq_len in enumerate(curriculum):
    print(f"\n===== CURRICULUM STAGE {stage+1}: seq_len = {seq_len} =====")

    for step in range(STEPS_PER_STAGE):

        # Random crop: inputs data[i : i+seq_len], targets shifted by one token.
        # randint's upper bound is exclusive, so `num_tokens - seq_len` makes the
        # last valid start (num_tokens - seq_len - 1) reachable.
        i = torch.randint(0, num_tokens - seq_len, (1,)).item()
        x = data[:, i:i+seq_len].to(device)
        y = data[:, i+1:i+seq_len+1].to(device)

        # Forward pass; logits assumed (batch, seq_len, vocab) — TODO confirm.
        logits = model(x)

        # Flatten to (batch*seq_len, vocab) vs (batch*seq_len,). reshape (unlike
        # view) also works when the tensor is not contiguous.
        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            y.reshape(-1),
        )

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Logging
        if step % LOG_EVERY == 0:
            print(f"Stage {stage+1} | Step {step:02d} | Loss: {loss.item():.4f}")

print("\nFinal Training Complete!")



Training modified Mamba with Curriculum + Hybrid Attention...

===== CURRICULUM STAGE 1: seq_len = 32 =====
Stage 1 | Step 00 | Loss: 52.5478
Stage 1 | Step 10 | Loss: 15.3931
Stage 1 | Step 20 | Loss: 5.8301
Stage 1 | Step 30 | Loss: 5.3138
Stage 1 | Step 40 | Loss: 4.3859

===== CURRICULUM STAGE 2: seq_len = 64 =====
Stage 2 | Step 00 | Loss: 3.2839
Stage 2 | Step 10 | Loss: 2.9735
Stage 2 | Step 20 | Loss: 2.7007
Stage 2 | Step 30 | Loss: 2.3927
Stage 2 | Step 40 | Loss: 3.0645

===== CURRICULUM STAGE 3: seq_len = 128 =====
Stage 3 | Step 00 | Loss: 2.3635
Stage 3 | Step 10 | Loss: 2.2877
Stage 3 | Step 20 | Loss: 2.0370
Stage 3 | Step 30 | Loss: 1.9626
Stage 3 | Step 40 | Loss: 1.6652

Final Training Complete!


## (2) Generate Text

In [3]:
def generate(model, start="Mamba ", num=200, temperature=1.0):
    """Sample ``num`` tokens from ``model`` autoregressively, continuing ``start``.

    Args:
        model: character-level model, called on a (1, seq_len) tensor of ids;
            ``logits[0, -1]`` must give the next-token scores for the last
            position (assumed shape (1, seq_len, vocab) — TODO confirm).
        start: non-empty prompt string, converted with ``encode``.
        num: number of new tokens to sample.
        temperature: softmax temperature (> 0). 1.0 samples from the model's
            distribution unchanged; lower values make sampling greedier.

    Returns:
        The decoded prompt followed by the sampled continuation.

    Raises:
        ValueError: if ``start`` is empty or ``temperature`` is not positive.
    """
    if not start:
        raise ValueError("start must be a non-empty prompt")
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    # Place the prompt on the same device as the model's weights rather than
    # relying on a global `device` variable from another cell.
    model_device = next(model.parameters()).device
    x = encode(start).unsqueeze(0).to(model_device)

    # Remember the caller's mode so a later training run is not silently left
    # in eval mode.
    was_training = model.training
    model.eval()
    try:
        # Sampling needs no gradients; skipping autograd saves memory and time.
        with torch.no_grad():
            for _ in range(num):
                # The full sequence is re-run every step (no state cache here).
                logits = model(x)
                probs = torch.softmax(logits[0, -1] / temperature, dim=-1)
                next_id = probs.multinomial(1)                 # shape (1,)
                x = torch.cat([x, next_id.unsqueeze(0)], dim=1)
    finally:
        model.train(was_training)

    return decode(x[0])


In [4]:
# Sample a continuation (generate's default length) of a Mamba-themed prompt.
print(generate(model, start='Mamba is the'))

Mamba is the futy bels cess uenge lessrimes terpffimodictionteration beativins.
Res sperstanctpumbastle the less ateadiIele predicus s.
Tha es gDngining lrafning caiizers abled.
Ramext mestion una staMation.
Maac


In [5]:
# Dialogue-style prompt: check whether the model continues the speaker format.
print(generate(model, start='John: Hi!\nSally:'))

John: Hi!
Sally:ent vding porentisftobareyyning atiridd l.
Gnsen the acat text mesener  sfimps unterice ceings aress Dection utition.
The models pfode inggg relea.
Text bafitimlkencite utaMuction.
Mas oreaninggabwwws


In [6]:
# Open-ended prose prompt.
print(generate(model, start='The meaning of life is '))

The meaning of life is uene leacess mode sethemes and geders samtral modeatpaMarocexs.
Gns spspstprenss heles uterids ueffetare s teters theidu ls puuces cess ppsencess ond tese ces ps lates sfinpt theimoden codel cont prat


In [7]:
# Code-style prompt, to see how the character model handles Python syntax.
print(generate(model, start='def reverse_string('))

def reverse_string( bacti leatoves cesseels learnine uncterstermut undictken idibhels abimetion.
Tes models us.
The and wesers.
Thences ss.
Mmer.
TuutimLl hntrersfful bLLs.
Sect us.
CP babics bacch.
Stimetricux pres ine
