## (1) Load model

In [1]:
from model import Mamba, ModelArgs
from transformers import AutoTokenizer

# Pretrained checkpoint to load. One of:
#     'state-spaces/mamba-2.8b-slimpj'
#     'state-spaces/mamba-2.8b'
#     'state-spaces/mamba-1.4b'
#     'state-spaces/mamba-790m'
#     'state-spaces/mamba-370m'
#     'state-spaces/mamba-130m'
pretrained_model_name = 'state-spaces/mamba-370m'

# NOTE(review): the training cell further down rebinds `model` to a small, freshly
# constructed Mamba, and `tokenizer` is not referenced by any cell visible in this
# notebook, so both objects loaded here appear to be unused afterwards.
# Confirm before removing or renaming.
model = Mamba.from_pretrained(pretrained_model_name)
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

In [2]:
#Training cell 
import torch
import torch.nn as nn
from torch.optim import Adam

# Fix the global RNG seed so parameter initialisation and the random training
# windows are repeatable when the notebook is re-run from a fresh kernel.
torch.manual_seed(0)

# Load the tiny character-level corpus. The encoding is pinned so the resulting
# vocabulary does not depend on the platform's default locale.
with open("data/tiny.txt", encoding="utf-8") as f:
    text = f.read()

# Character vocabulary: every distinct character of the corpus, in sorted order.
chars = sorted(set(text))
stoi = {ch: idx for idx, ch in enumerate(chars)}  # character -> id
itos = {idx: ch for ch, idx in stoi.items()}      # id -> character


def encode(s):
    """Map a string to a 1-D LongTensor of character ids.

    Raises:
        KeyError: if `s` contains a character that does not occur in the corpus.
    """
    return torch.tensor([stoi[c] for c in s], dtype=torch.long)


def decode(t):
    """Map an iterable of character ids (ints or 0-d tensors) back to a string.

    Raises:
        KeyError: if an id is not a key of `itos`.
    """
    return "".join(itos[int(i)] for i in t)


data = encode(text).unsqueeze(0)  # shape (1, len(text)): a single batch row

seq_len = 128
device = "cuda" if torch.cuda.is_available() else "cpu"

# Every training window needs `seq_len` inputs plus one extra character for the
# shifted targets, so valid start offsets are 0 .. data.size(1) - seq_len - 1.
num_windows = data.size(1) - seq_len
if num_windows <= 0:
    raise ValueError(
        f"data/tiny.txt has {data.size(1)} characters; "
        f"at least {seq_len + 1} are needed for seq_len={seq_len}"
    )

# NOTE: this rebinds `model`; the pretrained checkpoint loaded in the first cell
# is no longer reachable through that name after this line.
model = Mamba(ModelArgs(
    d_model=64,
    n_layer=4,
    vocab_size=len(chars),
)).to(device)

optimizer = Adam(model.parameters(), lr=2e-3)
criterion = nn.CrossEntropyLoss()

print("Training modified Mamba...")
# Each "epoch" here is a single optimisation step on one randomly chosen window.
for epoch in range(50):
    # torch.randint's upper bound is exclusive, so `num_windows` includes the
    # last valid window; .item() yields a plain int for slicing.
    i = torch.randint(0, num_windows, (1,)).item()
    x = data[:, i:i + seq_len].to(device)          # inputs
    y = data[:, i + 1:i + seq_len + 1].to(device)  # targets: inputs shifted by one

    logits = model(x)

    # Flatten using the logits' own last dimension, so the reshape always matches
    # what the model actually emits instead of relying on model.args.vocab_size.
    loss = criterion(
        logits.reshape(-1, logits.size(-1)),
        y.reshape(-1)
    )

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f"Epoch {epoch} Loss: {loss.item():.4f}")

print("Training done!")


Training modified Mamba...
Epoch 0 Loss: 54.8560
Epoch 10 Loss: 16.9619
Epoch 20 Loss: 4.5302
Epoch 30 Loss: 4.1188
Epoch 40 Loss: 3.3246
Training done!


## (2) Generate Text

In [3]:
def generate(model, start="Mamba ", num=200):
    """Sample `num` characters from `model`, continuing the prompt `start`.

    Args:
        model: module called on a (1, seq) LongTensor of character ids and
            returning logits indexed as [batch, position, vocab]. It is switched
            to eval mode and left there.
        start: non-empty prompt; every character must be encodable by `encode`.
        num: number of characters to sample and append to the prompt.

    Returns:
        The prompt followed by the `num` sampled characters, as one string.

    Raises:
        ValueError: if `start` is empty (there is no last position to sample from).
        KeyError: propagated from `encode` for characters it does not know.
    """
    if len(start) == 0:
        raise ValueError("start must contain at least one character")

    model.eval()
    x = encode(start).unsqueeze(0).to(device)

    # Only ids present in `itos` can be decoded. If the model emits more logits
    # than that (e.g. a vocabulary padded inside the model — verify against
    # ModelArgs), sampling one of the extra ids would make decode() raise KeyError,
    # so sampling is restricted to the decodable range.
    vocab = len(itos)

    # Inference only: skip autograd bookkeeping for every forward pass.
    with torch.no_grad():
        for _ in range(num):
            logits = model(x)
            probs = torch.softmax(logits[0, -1, :vocab], dim=0)
            next_id = probs.multinomial(1)
            x = torch.cat([x, next_id.unsqueeze(0)], dim=1)

    return decode(x[0])


In [None]:
# Sample 200 characters (generate's default `num`) continuing this prompt.
# NOTE(review): encode() looks every character up in `stoi`, so a prompt character
# that never occurs in data/tiny.txt raises KeyError — confirm the corpus covers it.
print(generate(model,'Mamba is the'))

In [None]:
# Dialogue-style prompt; the prompt contains a newline, ':' and '!'.
# NOTE(review): encode() raises KeyError for any of these characters that is
# absent from data/tiny.txt — confirm the corpus contains them.
print(generate(model, 'John: Hi!\nSally:'))

In [None]:
# Free-text prompt ending in a space; generate() appends 200 sampled characters.
# NOTE(review): every prompt character must occur in data/tiny.txt, otherwise
# encode() raises KeyError — confirm.
print(generate(model,'The meaning of life is '))

In [None]:
# Code-style prompt; it contains '_' and '('.
# NOTE(review): encode() raises KeyError if either character is missing from the
# character vocabulary built from data/tiny.txt — confirm.
print(generate(model,'def reverse_string('))