In [1]:
import datetime

import torch
import torch.nn as nn
from datasets import load_dataset
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoTokenizer

from lib.mamba2 import MambaLM

if torch.cuda.is_available():
    device = torch.device('cuda')
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    torch.set_float32_matmul_precision("high")
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

tokenizer = AutoTokenizer.from_pretrained('gpt2')
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
MAX_LEN = 1024

Using device: cuda


In [2]:
# Generic text corpus for pretraining (e.g. WikiText)
pretrain_raw = load_dataset('roneneldan/TinyStories', split='train')


class PretrainDataset(Dataset):
    # Simple language modeling dataset without instruction formatting.
    # Each example is tokenized up to MAX_LEN.
    def __init__(self, raw_dataset):
        self.raw = raw_dataset

    def __len__(self):
        return len(self.raw)

    def __getitem__(self, idx):
        text = (self.raw[idx]['text'] or "").strip()
        ids = tokenizer(text, add_special_tokens=False, truncation=True, max_length=MAX_LEN)['input_ids']
        if len(ids) == 0:
            ids = [tokenizer.eos_token_id]
        ids = ids + [tokenizer.eos_token_id]
        input_ids = torch.tensor(ids[:-1], dtype=torch.long)
        targets = torch.tensor(ids[1:], dtype=torch.long)
        return input_ids, targets


pretrain_ds = PretrainDataset(pretrain_raw)


def collate_pretrain(examples):
    inputs = [ex[0] for ex in examples]
    targets = [ex[1] for ex in examples]
    inputs_pad = nn.utils.rnn.pad_sequence(inputs, batch_first=True, padding_value=tokenizer.pad_token_id)
    targets_pad = nn.utils.rnn.pad_sequence(targets, batch_first=True, padding_value=-100)
    return inputs_pad, targets_pad


pretrain_loader = DataLoader(
    pretrain_ds,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_pretrain,
    num_workers=0,
    pin_memory=True,
)

In [3]:
print("Pretrain-Dataset length:", len(pretrain_loader))
# Inspect a sample
sample_inputs, sample_targets = pretrain_loader.__iter__().__next__()
print(tokenizer.decode(sample_inputs[0], skip_special_tokens=True))

Pretrain-Dataset length: 529930
One day there was a green vegetable. It was big and clear. Daddy wanted to measure it. He got a ruler out of the drawer and measured the vegetable. He measured it three times and each time it was the same number, three. 

Daddy was very happy. He put the ruler back in the drawer and said to the vegetable, "You are a very big vegetable! You measure three. You are amazing!" 

Mommy smiled and said, "Yes, it's an impressive vegetable!" Then she took out a big pot and started to cook the vegetable. 

Daddy and Mommy enjoyed a delicious meal. Everyone was happy that Daddy had measured the vegetable and that it was so clear. 

The end.


In [4]:
# HuggingFace dataset for finetuning
finetune_raw = load_dataset('tatsu-lab/alpaca', split='train')


def build_prompt_text(instruction, inp=''):
    fixed_prompt = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.'
    instruction = instruction.strip()
    inp = inp.strip() if inp else ''
    if inp:
        return (fixed_prompt + "\n\n" +
                "### Instruction:\n" + instruction + "\n\n" +
                "### Input:\n" + inp + "\n\n" +
                "### Response:\n")
    else:
        return (fixed_prompt + "\n\n" +
                "### Instruction:\n" + instruction + "\n\n" +
                "### Response:\n")


# Dataset that tokenizes Alpaca examples and builds masked targets
class AlpacaDataset(Dataset):
    def __init__(self, raw_dataset):
        self.raw = raw_dataset

    def __len__(self):
        return len(self.raw)

    def __getitem__(self, idx):
        ex = self.raw[idx]
        instr = ex.get('instruction', '').strip()
        inp = ex.get('input', '').strip()
        out = ex.get('output', '').strip()
        prompt_text = build_prompt_text(instr, inp)
        prompt_ids = tokenizer(
            prompt_text,
            add_special_tokens=False,
            max_length=MAX_LEN,
            truncation=True
        )['input_ids']
        response_ids = tokenizer(
            out,
            add_special_tokens=False,
            max_length=MAX_LEN,
            truncation=True
        )['input_ids']
        # Compose sequence: prompt + response + eos
        ids = prompt_ids + response_ids + [tokenizer.eos_token_id]
        ids = ids[:MAX_LEN]
        input_ids = torch.tensor(ids[:-1], dtype=torch.long)
        targets = torch.tensor(ids[1:], dtype=torch.long)
        # Mask: ignore all target positions before the response start
        masked_targets = targets.clone()
        prefix_len = len(prompt_ids)
        ignore_until = max(0, min(prefix_len, len(masked_targets)))
        if ignore_until > 0:
            masked_targets[:ignore_until] = -100
        return input_ids, masked_targets


# Instantiate dataset
train_ds = AlpacaDataset(finetune_raw)


# Collate function pads inputs and targets
def collate_batch(examples):
    inputs = [ex[0] for ex in examples]
    targets = [ex[1] for ex in examples]
    inputs_pad = nn.utils.rnn.pad_sequence(inputs, batch_first=True, padding_value=tokenizer.pad_token_id)
    targets_pad = nn.utils.rnn.pad_sequence(targets, batch_first=True, padding_value=-100)
    return inputs_pad, targets_pad


# DataLoader
finetune_loader = DataLoader(
    train_ds,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_batch,
    num_workers=0,
    pin_memory=True,
)

In [6]:
print("Finetune-Dataset length:", len(finetune_loader))
# Inspect a sample
sample_inputs, sample_targets = finetune_loader.__iter__().__next__()
print(tokenizer.decode(sample_inputs[0], skip_special_tokens=True))

Finetune-Dataset length: 13001
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Create a 5-sentence bio for someone

### Input:
Name: Blake Turner
Age: 35
Hometown: Austin, Texas

### Response:
Blake Turner is a thirty-five year old living in Austin, Texas. Born and raised in the area, Blake is an experienced entrepreneur and active community member. With a knack for connecting with people, Blake is a natural leader and often lends his talents and ideas to making his hometown a better place to live. He is a dedicated creator, always seeking to find new ways to connect and improve the lives of those around him. Blake is an avid music fan and enjoys spending his spare time discovering new bands and creating his own music.


In [None]:
# Model hyperparameters
vocab_size = len(tokenizer)
d_model = 512
n_layers = 8
n_heads = 8
d_state = d_model // n_heads
dropout = 0.1

model = MambaLM(vocab_size, d_model, n_layers, n_heads, d_state, dropout)
# Resize embeddings if tokenizer has grown
model.emb = nn.Embedding(vocab_size, d_model)
model = model.to(device)

# Optimizer and criterion (masked loss)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
criterion = nn.CrossEntropyLoss(ignore_index=-100)
print("Number of parameters:", sum(p.numel() for p in model.parameters()))

In [None]:
# Pretraining phase on generic text corpus
pretrain_epochs = 1
pretrain_log_interval = 10

# TensorBoard writer
timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
save_dir = f'runs/mamba_pretrain_{timestamp}'
writer = SummaryWriter(log_dir=save_dir)

step = 0
best_loss = float('inf')

for epoch in range(pretrain_epochs):
    epoch_total_loss = 0.0
    epoch_batches = 0
    model.train()
    for inputs, targets in pretrain_loader:
        inputs, targets = inputs.to(device, non_blocking=True), targets.to(device, non_blocking=True)
        logits, _ = model(inputs)
        loss = criterion(logits.view(-1, vocab_size), targets.view(-1))
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        loss_value = loss.item()
        epoch_total_loss += loss_value
        epoch_batches += 1
        writer.add_scalar('pretrain/train_loss', loss_value, step)
        step += 1
        if step % pretrain_log_interval == 0:
            avg_loss = epoch_total_loss / epoch_batches
            writer.add_scalar('pretrain/train_loss_avg', avg_loss, step)
            # Save model checkpoint
            ckpt = {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "config": {
                    "vocab_size": model.emb.num_embeddings,
                    "d_model": model.emb.embedding_dim,
                    "n_layers": model.n_layers,
                    "n_heads": model.n_heads,
                    "d_state": model.d_state,
                    "dropout": model.dropout,
                },
            }
            torch.save(ckpt, f'{save_dir}/last.pt')
            if avg_loss < best_loss:
                best_loss = avg_loss
                torch.save(ckpt, f'{save_dir}/best.pt')
            # reset running averages
            epoch_total_loss = 0.0
            epoch_batches = 0
            print(f'Step {step}, Loss: {avg_loss:.4f}')
        if step > 10000:
            break
    print(f'Epoch {epoch + 1} complete')

In [7]:
# Load from checkpoint
ckpt = torch.load('runs/mamba_pretrain_2025-09-07_08-23-34/last.pt', map_location=device)
vocab_size = ckpt['config']['vocab_size']
config = ckpt["config"]
model = MambaLM(**config).to(device)
model.load_state_dict(ckpt["model_state_dict"])
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
criterion = nn.CrossEntropyLoss(ignore_index=-100)
print("Number of parameters:", sum(p.numel() for p in model.parameters()))

Number of parameters: 61974016


In [8]:
finetune_epochs = 1
finetune_log_interval = 10
sample_interval = 100
sample_length = 100

# TensorBoard writer
timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
save_dir = f'runs/mamba_finetune_{timestamp}'
writer = SummaryWriter(log_dir=save_dir)

step = 0
best_loss = float('inf')

fixed_prompt = (
    "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
    "### Instruction:\nTell me a story about a blacksmith who saves a village from a black dragon.\n\n"
    "### Response:\n"
)
fixed_prompt_ids = tokenizer(
    fixed_prompt,
    return_tensors='pt',
    max_length=MAX_LEN,
    truncation=True
)['input_ids'].to(device)

for epoch in range(finetune_epochs):
    epoch_total_loss = 0.0
    epoch_batches = 0
    model.train()
    for inputs, targets in finetune_loader:
        inputs, targets = inputs.to(device, non_blocking=True), targets.to(device, non_blocking=True)
        logits, _ = model(inputs)
        loss = criterion(logits.view(-1, vocab_size), targets.view(-1))
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        loss_value = loss.item()
        epoch_total_loss += loss_value
        epoch_batches += 1
        writer.add_scalar('finetune/train_loss', loss_value, step)
        step += 1
        # Logging and sample generation
        if step % finetune_log_interval == 0:
            avg_loss = epoch_total_loss / epoch_batches
            writer.add_scalar('pretrain/finetune_loss_avg', avg_loss, step)
            # Save model checkpoint
            ckpt = {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "config": {
                    "vocab_size": model.emb.num_embeddings,
                    "d_model": model.emb.embedding_dim,
                    "n_layers": model.n_layers,
                    "n_heads": model.n_heads,
                    "d_state": model.d_state,
                    "dropout": model.dropout,
                },
            }
            torch.save(ckpt, f'{save_dir}/last.pt')
            if avg_loss < best_loss:
                best_loss = avg_loss
                torch.save(ckpt, f'{save_dir}/best.pt')
            # reset running averages
            epoch_total_loss = 0.0
            epoch_batches = 0
            print(f'Step {step}, Loss: {avg_loss:.4f}')
            # Sample a continuation
        if step % sample_interval == 0:
            model.eval()
            with torch.no_grad():
                gen_ids = model.generate(fixed_prompt_ids, max_new_tokens=sample_length)
            new_tokens = gen_ids[0][fixed_prompt_ids.size(1):]
            sample_text = tokenizer.decode(new_tokens.tolist(), skip_special_tokens=True)
            print('\n--- Sample Generation ---')
            print(f'{fixed_prompt}{sample_text}')
            print('-------------------------\n')
            writer.add_text('samples', f'{fixed_prompt}{sample_text}', step)
            model.train()
    print(f'Epoch {epoch + 1} complete')
writer.close()

Step 10, Loss: 7.0101
Step 20, Loss: 6.9783
Step 30, Loss: 7.2813
Step 40, Loss: 7.0490
Step 50, Loss: 7.0165
Step 60, Loss: 6.9633
Step 70, Loss: 7.0832
Step 80, Loss: 6.8128
Step 90, Loss: 6.7072
Step 100, Loss: 6.6340

--- Sample Generation ---
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Tell me a story about a blacksmith who saves a village from a black dragon.

### Response:
The two birds in the horizon, element were scared of the environment. They just tried to free the gold, so they soon nobody knew that they needed one. 
Thankfully the importance of being lucky that the entire village was Sparkle So. Every day the kind show's kindness of the whole order of unavailableSuopez Chevy. had golden visits - exploring a magical sight, made sure they were used a solution to the crossing to own it., you must find a magicalconscious and collect -
-------------------------

Step 110, Loss: 6.6503
Step 120, Loss