In [1]:
import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer

from lib.mamba import MambaLM

# Seed PyTorch's RNGs (CPU and all CUDA devices) so weight init, DataLoader
# shuffling and token sampling are repeatable across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [2]:
# Alpaca instruction-tuning data: a single "train" split (see the printed summary).
ds = load_dataset(path='tatsu-lab/alpaca')
print(ds, ds['train'][0], sep='\n')

DatasetDict({
    train: Dataset({
        features: ['instruction', 'input', 'output', 'text'],
        num_rows: 52002
    })
})
{'instruction': 'Give three tips for staying healthy.', 'input': '', 'output': '1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule.', 'text': 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nGive three tips for staying healthy.\n\n### Response:\n1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule.'}


In [3]:
tok = AutoTokenizer.from_pretrained("gpt2")
# If the tokenizer has no pad token, reuse EOS (as the pad_id == eos_id output
# below shows) so no new vocabulary entry is created.
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
pad_id = tok.pad_token_id
# len(tok) also counts tokens added on top of the base vocabulary, unlike
# tok.vocab_size. It is the number of rows the embedding / LM head needs so that
# every id the tokenizer can emit (including an added pad token) is valid.
vocab_size = len(tok)

print("pad_id:", pad_id, "eos_id:", tok.eos_token_id)

pad_id: 50256 eos_id: 50256


In [4]:
def build_prompt(ex):
    """Render an Alpaca record as a minimal instruction prompt.

    The "### Input:" section is emitted only when the record's input is
    non-blank. The result always ends with "### Response:\\n", so the model's
    continuation is the response text.

    Args:
        ex: mapping with string fields "instruction" and "input".

    Returns:
        The prompt string (leading/trailing whitespace of each field stripped).
    """
    instruction = ex['instruction'].strip()
    extra_input = ex['input'].strip()

    parts = [f"### Instruction:\n{instruction}\n\n"]
    if extra_input:
        parts.append(f"### Input:\n{extra_input}\n\n")
    parts.append("### Response:\n")
    return "".join(parts)


def encode_example(ex, max_length=512):
    """Tokenize one Alpaca record for causal-LM fine-tuning with a prompt-masked loss.

    The prompt and the response are tokenized separately and then concatenated.
    This makes the masked prefix exactly the prompt's tokens, with no reliance on
    the prompt's tokenization being a token-prefix of the full text's tokenization.

    Args:
        ex: mapping with "instruction", "input" and "output" strings.
        max_length: maximum number of tokens kept. Longer sequences are cut from
            the right, which also drops the trailing EOS for over-long examples.

    Returns:
        {"input_ids": list[int], "labels": list[int]} of equal length. Prompt
        positions in "labels" are -100, so CrossEntropyLoss(ignore_index=-100)
        skips them; response tokens and the EOS id are kept as targets.
    """
    prompt_ids = tok(
        build_prompt(ex), add_special_tokens=False, truncation=True,
        max_length=max_length
    )["input_ids"]
    response_ids = tok(
        ex["output"].strip(), add_special_tokens=False, truncation=True,
        max_length=max_length
    )["input_ids"]

    # Append EOS by id rather than as text, so this does not depend on the
    # tokenizer recognising the special-token string inside ordinary text.
    eos_ids = [tok.eos_token_id] if tok.eos_token_id is not None else []
    target_ids = response_ids + eos_ids

    input_ids = (prompt_ids + target_ids)[:max_length]
    labels = ([-100] * len(prompt_ids) + target_ids)[:max_length]
    return {"input_ids": input_ids, "labels": labels}


# Tokenize every Alpaca record with encode_example. The raw text columns are
# removed, so each row keeps only the "input_ids" and "labels" lists.
raw_columns = ds["train"].column_names
train_proc = ds["train"].map(encode_example, remove_columns=raw_columns)
print(train_proc[0])

{'input_ids': [21017, 46486, 25, 198, 23318, 1115, 9040, 329, 10589, 5448, 13, 198, 198, 21017, 18261, 25, 198, 16, 13, 47659, 257, 12974, 5496, 290, 787, 1654, 284, 2291, 6088, 286, 15921, 290, 13701, 13, 220, 198, 17, 13, 32900, 7987, 284, 1394, 534, 1767, 4075, 290, 1913, 13, 220, 198, 18, 13, 3497, 1576, 3993, 290, 5529, 257, 6414, 3993, 7269, 13, 50256], 'labels': [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 16, 13, 47659, 257, 12974, 5496, 290, 787, 1654, 284, 2291, 6088, 286, 15921, 290, 13701, 13, 220, 198, 17, 13, 32900, 7987, 284, 1394, 534, 1767, 4075, 290, 1913, 13, 220, 198, 18, 13, 3497, 1576, 3993, 290, 5529, 257, 6414, 3993, 7269, 13, 50256]}


In [5]:
def collate_fn(batch, pad_token_id=None, label_pad_id=-100):
    """Right-pad a list of encoded examples into dense (batch, max_len) tensors.

    Args:
        batch: list of dicts with equal-length "input_ids" and "labels" int lists.
        pad_token_id: id used to pad "input_ids". Defaults to the notebook's
            global ``pad_id`` when None, so DataLoader(collate_fn=collate_fn)
            keeps working unchanged.
        label_pad_id: value used to pad "labels". The default -100 matches the
            loss's ignore_index, so padded positions do not contribute to the loss.

    Returns:
        (input_ids, labels), both torch.long tensors of shape (len(batch), max_len).

    Raises:
        ValueError: if an example's "labels" length differs from its "input_ids"
            length (that would silently misalign targets).
    """
    if pad_token_id is None:
        pad_token_id = pad_id

    max_len = max(len(ex["input_ids"]) for ex in batch)
    input_ids = torch.full((len(batch), max_len), pad_token_id, dtype=torch.long)
    labels = torch.full((len(batch), max_len), label_pad_id, dtype=torch.long)

    for row, ex in enumerate(batch):
        ids, labs = ex["input_ids"], ex["labels"]
        if len(labs) != len(ids):
            raise ValueError(
                f"example {row}: {len(ids)} input_ids but {len(labs)} labels"
            )
        input_ids[row, :len(ids)] = torch.as_tensor(ids, dtype=torch.long)
        labels[row, :len(labs)] = torch.as_tensor(labs, dtype=torch.long)
    return input_ids, labels


# Set SUBSET_SIZE to an int (e.g. 4096) for a quick run on the first N
# examples; None trains on the full split.
SUBSET_SIZE = None
BATCH_SIZE = 8

train_data = (
    train_proc.select(range(min(SUBSET_SIZE, len(train_proc))))
    if SUBSET_SIZE
    else train_proc
)
train_loader = DataLoader(
    train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn
)

# Inspect one batch: the shapes, how many label positions are ignored (-100),
# and the decoded input next to the decoded supervised (non-ignored) tokens.
xb, yb = next(iter(train_loader))
assert xb.shape == yb.shape, "input_ids and labels must have the same shape"
print("batch shapes:", xb.shape, yb.shape)
print("check masked count:", (yb == -100).sum().item())
print("\n--- example input decoded ---\n" + tok.decode(xb[0], skip_special_tokens=True))
print("\n--- example labels decoded ---\n" + tok.decode(yb[0][yb[0] != -100], skip_special_tokens=True))

batch shapes: torch.Size([8, 220]) torch.Size([8, 220])
check masked count: 1319

--- example input decoded ---
### Instruction:
What are the Four Noble Truths of Buddhism?

### Response:
The Four Noble Truths of Buddhism are: 1. Life is suffering. 2. Suffering is caused by craving and attachment. 3. Suffering can be relieved by eliminating craving and attachment. 4. The path to eliminating suffering is the Eightfold Path.

--- example labels decoded ---
The Four Noble Truths of Buddhism are: 1. Life is suffering. 2. Suffering is caused by craving and attachment. 3. Suffering can be relieved by eliminating craving and attachment. 4. The path to eliminating suffering is the Eightfold Path.


In [6]:
# Small Mamba LM sized to the tokenizer vocabulary; pad_id is handed to the
# model (the printed Embedding shows it becomes padding_idx).
model = MambaLM(vocab_size=vocab_size, d_model=128, d_state=64, d_conv=4, headdim=32, n_layers=4, pad_id=pad_id)
model = model.to(device)
print(model)

# -100 is the label value used for prompt and padding positions.
criterion = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
# GradScaler takes a device *type* ("cuda"/"cpu"), not a full device string such
# as "cuda:0". Loss scaling is only needed for float16 autocast on CUDA; CPU
# autocast defaults to bfloat16, so the scaler is a pass-through there.
scaler = torch.amp.GradScaler(device.type, enabled=device.type == "cuda")

MambaLM(
  (embed): Embedding(50257, 128, padding_idx=50256)
  (blocks): ModuleList(
    (0-3): 4 x Mamba(
      (in_proj): Linear(in_features=128, out_features=388, bias=False)
      (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=(3,), groups=256)
      (act): SiLU()
      (out_proj): Linear(in_features=128, out_features=128, bias=False)
    )
  )
  (ln_f): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (lm_head): Linear(in_features=128, out_features=50257, bias=False)
)


In [7]:
@torch.no_grad()
def quick_gen(prompt: str, max_new_tokens: int = 60, temperature: float = 0.3):
    """Sample a continuation of ``prompt`` from the global ``model``.

    Used for progress checks during training. The model is switched to eval mode
    for generation, and its previous training mode is restored even if
    generation raises.

    Args:
        prompt: raw text, tokenized without adding special tokens.
        max_new_tokens: upper bound on generated tokens; stops early at EOS.
        temperature: softmax temperature. Values <= 0 select greedy (argmax)
            decoding instead of sampling.

    Returns:
        The decoded continuation only (prompt excluded, special tokens skipped).
    """
    was_training = model.training
    model.eval()
    try:
        enc = tok(prompt, return_tensors="pt", add_special_tokens=False)
        input_ids = enc["input_ids"].to(device)  # [1, L0]
        start_len = input_ids.shape[1]
        eos_id = tok.eos_token_id

        for _ in range(max_new_tokens):
            # The full sequence so far is re-fed to the model at every step.
            last_logits = model(input_ids)[:, -1, :]  # [1, V]
            if temperature <= 0:
                next_id = last_logits.argmax(dim=-1, keepdim=True)  # [1, 1]
            else:
                probs = torch.softmax(last_logits / max(temperature, 1e-8), dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)  # [1, 1]
            input_ids = torch.cat([input_ids, next_id], dim=1)
            if eos_id is not None and next_id.item() == eos_id:
                break

        return tok.decode(input_ids[0, start_len:].tolist(), skip_special_tokens=True)
    finally:
        if was_training:
            model.train()


# Build the probe prompt with the same template used for training, so the two
# cannot drift apart ("### Instruction:\n...\n\n### Response:\n").
sample_prompt = build_prompt(
    {"instruction": "Give three tips for staying healthy.", "input": ""}
)

In [None]:
NUM_EPOCHS = 5
LOG_EVERY = 10      # steps between loss printouts
SAMPLE_EVERY = 100  # steps between generated samples

model.train()
for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch + 1}\n" + "-" * 20)
    for step, (inputs, targets) in tqdm(enumerate(train_loader), total=len(train_loader)):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        # autocast expects a device *type* ("cuda"/"cpu"); str(device) could be
        # "cuda:0" for an indexed device, which autocast does not accept.
        with torch.amp.autocast(device.type):
            logits = model(inputs)
            # Next-token objective: logits at position t are scored against token t+1.
            shift_logits = logits[:, :-1, :]
            shift_labels = targets[:, 1:]
            # Take the class count from the logits themselves, so the loss can
            # never disagree with the model's output size.
            loss = criterion(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
            )
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        if step % LOG_EVERY == 0:
            print(f"step {step}, loss {loss.item():.4f}")
        # Generate a sample periodically (quick_gen handles eval/train switching).
        if step > 0 and step % SAMPLE_EVERY == 0:
            print(f"\n[step {step}] sample:\n{sample_prompt}{quick_gen(sample_prompt)}\n", flush=True)
    torch.save(model.state_dict(), f"mamba-epoch{epoch + 1}.pth")

print("Training complete.")

Epoch 1
--------------------


  0%|          | 0/6501 [00:00<?, ?it/s]

step 0, loss 10.8253
step 10, loss 10.8074
step 20, loss 10.7892
step 30, loss 10.7652
step 40, loss 10.6868
step 50, loss 10.4226
step 60, loss 10.4180
step 70, loss 10.2685
step 80, loss 10.1796
step 90, loss 10.1361
step 100, loss 9.9230

[step 100] sample:
### Instruction:
Give three tips for staying healthy.

### Response:
Theggy thepowerful the of the Downloads Rooms thePoint homer thebj the bring the reass theassisJohnny theя the Contains, the spec theARC the. thelins the Lets levels theThey thelash, theONG the the reset immigrants the releasing the, the Yi, the, the theatto

step 110, loss 9.6987
step 120, loss 9.7176
step 130, loss 9.6596
step 140, loss 9.5587
step 150, loss 9.2764
step 160, loss 9.4506
step 170, loss 9.3566
step 180, loss 9.5416
step 190, loss 8.4866
step 200, loss 6.7365

[step 200] sample:
### Instruction:
Give three tips for staying healthy.

### Response:
-Context and the, the the the, the the,, and, and the for, andicidal, and and, and the Li, theerences

In [None]:
# Persist the final trained weights (parameters/buffers only, not the module).
torch.save(obj=model.state_dict(), f="mamba.pth")

In [None]:
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code when loaded.
model.load_state_dict(torch.load("mamba.pth", map_location=device, weights_only=True))

In [None]:
@torch.no_grad()
def generate_text(instruction: str, input_text: str = "", max_new_tokens: int = 128, temperature: float = 0.3) -> str:
    """Generate a response to an Alpaca-style instruction with the global model.

    The prompt is built with build_prompt (the training template). Sampling is
    delegated to quick_gen, which runs the same sampling loop and puts the model
    back into training mode afterwards if it was training, instead of leaving
    it stuck in eval mode.

    Args:
        instruction: the task description.
        input_text: optional extra context; omitted from the prompt when blank.
        max_new_tokens: upper bound on generated tokens; stops early at EOS.
        temperature: softmax temperature used for sampling.

    Returns:
        The decoded response only (prompt excluded, special tokens skipped).
    """
    prompt = build_prompt({"instruction": instruction, "input": input_text})
    return quick_gen(prompt, max_new_tokens=max_new_tokens, temperature=temperature)


# Example: an open-ended instruction with no extra input, capped at 50 new tokens.
poem = generate_text(
    "Hello, write a short poem about a robot learning to love.",
    input_text="",
    max_new_tokens=50,
)
print(poem)