In [1]:
import datetime

import torch
import torch.nn as nn
from datasets import load_dataset, tqdm
from torch.utils.data import Dataset, DataLoader, SequentialSampler
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoTokenizer

from lib.mamba2 import MambaLM

# Select the compute device once; later cells refer to `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # GPU-only tuning: allow TF32 in matmul / cuDNN and enable cuDNN benchmarking.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    torch.set_float32_matmul_precision("high")
print(f'Using device: {device}')

Using device: cuda


In [2]:
# Upper bound on the number of tokens kept per training sequence
MAX_LEN = 1024

# Tokenizer: GPT-2; when it has no pad token, the EOS token is reused for padding
tokenizer = AutoTokenizer.from_pretrained('gpt2')
if tokenizer.pad_token is None:
    tokenizer.pad_token, tokenizer.pad_token_id = tokenizer.eos_token, tokenizer.eos_token_id

# Hugging Face dataset: train split of tatsu-lab/alpaca
dataset = load_dataset('tatsu-lab/alpaca', split='train')


def build_prompt_text(instruction, inp=''):
    """Render the Alpaca-style prompt that precedes the model's response.

    Args:
        instruction: Task description; surrounding whitespace is stripped.
        inp: Optional extra context. When it is falsy or blank after
            stripping, the "### Input:" section is left out entirely.

    Returns:
        The prompt string, ending with "### Response:\n" so the response
        text can be appended directly.
    """
    header = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.'
    task = instruction.strip()
    context = inp.strip() if inp else ''

    # Sections are separated by a blank line; the input section is optional.
    sections = [header, "### Instruction:\n" + task]
    if context:
        sections.append("### Input:\n" + context)
    sections.append("### Response:\n")
    return "\n\n".join(sections)


# Dataset that tokenizes Alpaca examples and builds masked targets
class AlpacaDataset(Dataset):
    """Map-style dataset that turns Alpaca records into next-token training pairs.

    Each item is tokenized on access with the module-level `tokenizer` and laid
    out as ``prompt + response + eos``, cut to at most ``MAX_LEN`` tokens.

    Args:
        raw_dataset: Indexable collection of dict-like records with the keys
            'instruction', 'input' and 'output'.
    """

    def __init__(self, raw_dataset):
        self.raw = raw_dataset

    def __len__(self):
        return len(self.raw)

    def __getitem__(self, idx):
        """Return ``(input_ids, masked_targets)`` for record ``idx``.

        ``input_ids`` is the sequence without its last token and
        ``masked_targets`` is the sequence without its first token, with every
        position that predicts a prompt token set to -100 so that the loss
        (``ignore_index=-100``) covers only the response and the EOS token.
        """
        ex = self.raw[idx]
        # `or ''` also covers fields that are present but None.
        instr = (ex.get('instruction') or '').strip()
        inp = (ex.get('input') or '').strip()
        out = (ex.get('output') or '').strip()
        prompt_text = build_prompt_text(instr, inp)
        prompt_ids = tokenizer(
            prompt_text,
            add_special_tokens=False,
            max_length=MAX_LEN,
            truncation=True
        )['input_ids']
        response_ids = tokenizer(
            out,
            add_special_tokens=False,
            max_length=MAX_LEN,
            truncation=True
        )['input_ids']
        # Compose sequence: prompt + response + eos
        ids = prompt_ids + response_ids + [tokenizer.eos_token_id]
        ids = ids[:MAX_LEN]
        input_ids = torch.tensor(ids[:-1], dtype=torch.long)
        targets = torch.tensor(ids[1:], dtype=torch.long)
        # targets[j] is the token that follows input position j, so the first
        # response token lives at index len(prompt_ids) - 1. Masking one
        # position fewer than the prompt length keeps that token in the loss;
        # masking the full prompt length would silently drop it.
        prefix_len = len(prompt_ids)
        ignore_until = max(0, min(prefix_len - 1, len(targets)))
        masked_targets = targets.clone()
        masked_targets[:ignore_until] = -100
        return input_ids, masked_targets


# Instantiate the training dataset; records are tokenized on each __getitem__ call, not up front
train_ds = AlpacaDataset(dataset)


# Collate function pads inputs and targets
def collate_batch(examples):
    """Stack variable-length (input, target) pairs into two padded batch tensors.

    Args:
        examples: Sequence of ``(input_ids, targets)`` 1-D tensor pairs.

    Returns:
        ``(inputs_pad, targets_pad)``, both batch-first. Inputs are padded with
        ``tokenizer.pad_token_id`` and targets with -100.
    """
    input_seqs, target_seqs = [], []
    for pair in examples:
        input_seqs.append(pair[0])
        target_seqs.append(pair[1])

    pad = nn.utils.rnn.pad_sequence
    return (
        pad(input_seqs, batch_first=True, padding_value=tokenizer.pad_token_id),
        pad(target_seqs, batch_first=True, padding_value=-100),
    )


# DataLoader: shuffled mini-batches, padded per batch by collate_batch.
# Pinned host memory only serves host-to-GPU copies, so it is requested only
# when running on CUDA; on the CPU fallback path it is switched off.
train_loader = DataLoader(
    train_ds,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_batch,
    num_workers=0,
    pin_memory=(device.type == 'cuda'),
)

In [3]:
# Pull a single batch from the loader and decode its first row as a sanity check
sample_inputs, sample_targets = next(iter(train_loader))
print(tokenizer.decode(sample_inputs[0], skip_special_tokens=True))

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Replace the word "INSERT" with something creative

### Input:
We need to INSERT a few more ideas to the brainstorming session.

### Response:
We need to sprinkle a few more ideas to the brainstorming session.<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>


In [4]:
# Model hyperparameters
vocab_size = len(tokenizer)
d_model = 512
n_layers = 8
n_heads = 8
d_state = d_model // n_heads
dropout = 0.1

model = MambaLM(vocab_size, d_model, n_layers, n_heads, d_state, dropout)
# Swap in a new embedding table only when the model's own one is missing or does
# not match the tokenizer size. Replacing it unconditionally would throw away
# the initialisation done inside MambaLM.
# NOTE(review): if MambaLM shares the embedding weight with its output layer,
# an unconditional swap would also break that sharing — confirm in lib.mamba2.
if getattr(getattr(model, 'emb', None), 'num_embeddings', None) != vocab_size:
    model.emb = nn.Embedding(vocab_size, d_model)
model = model.to(device)

# Optimizer and criterion (masked loss): targets equal to -100 are ignored
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
criterion = nn.CrossEntropyLoss(ignore_index=-100)

In [None]:
epochs = 3
log_interval = 100  # optimizer steps between checkpoints / sample generations
sample_length = 100  # max_new_tokens passed to model.generate for each sample

# TensorBoard writer; checkpoints are written into the same run directory
timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
save_dir = f'runs/mamba_{timestamp}'
writer = SummaryWriter(log_dir=save_dir)

step = 0
best_loss = float('inf')

# Fixed prompt used for every sample generation so outputs are comparable across steps
fixed_prompt = (
    "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
    "### Instruction:\nTell a bedtime story about a dragon and a little village.\n\n"
    "### Response:\n"
)
fixed_prompt_ids = tokenizer(
    fixed_prompt,
    return_tensors='pt',
    max_length=MAX_LEN,
    truncation=True
)['input_ids'].to(device)

# try/finally: close the writer even when the cell is interrupted or a step
# raises, instead of only on normal completion.
try:
    for epoch in range(epochs):
        # Despite the names, these accumulate over the current logging window:
        # they are reset every `log_interval` steps as well as at epoch start.
        epoch_total_loss = 0.0
        epoch_batches = 0
        model.train()
        for inputs, targets in tqdm(train_loader):
            inputs, targets = inputs.to(device, non_blocking=True), targets.to(device, non_blocking=True)
            logits, _ = model(inputs)
            # Flatten (batch, seq) so every token position is one classification
            # example; positions whose target is -100 do not contribute.
            loss = criterion(logits.view(-1, vocab_size), targets.view(-1))
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            loss_value = loss.item()
            epoch_total_loss += loss_value
            epoch_batches += 1
            writer.add_scalar('loss/train_step', loss_value, step)
            step += 1
            # Logging and sample generation
            if step % log_interval == 0:
                avg_loss = epoch_total_loss / epoch_batches
                writer.add_scalar('loss/train_avg', avg_loss, step)
                # Save model checkpoint
                # NOTE(review): the config values are read from attributes on the
                # model; confirm in lib.mamba2 that they are plain values.
                ckpt = {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "config": {
                        "vocab_size": model.emb.num_embeddings,
                        "d_model": model.emb.embedding_dim,
                        "n_layers": model.n_layers,
                        "n_heads": model.n_heads,
                        "d_state": model.d_state,
                        "dropout": model.dropout,
                    },
                }
                torch.save(ckpt, f'{save_dir}/last.pt')
                # "best" is judged on the windowed training average, not on a validation set
                if avg_loss < best_loss:
                    best_loss = avg_loss
                    torch.save(ckpt, f'{save_dir}/best.pt')
                # reset running averages
                epoch_total_loss = 0.0
                epoch_batches = 0
                # Sample a continuation (eval mode, no gradients)
                model.eval()
                with torch.no_grad():
                    gen_ids = model.generate(fixed_prompt_ids, max_new_tokens=sample_length)
                # Keep only the tokens that come after the prompt
                new_tokens = gen_ids[0][fixed_prompt_ids.size(1):]
                sample_text = tokenizer.decode(new_tokens.tolist(), skip_special_tokens=True)
                print('\n--- Sample Generation ---')
                print(f'{fixed_prompt}{sample_text}')
                print('-------------------------\n')
                writer.add_text('samples', f'{fixed_prompt}{sample_text}', step)
                print(f'Step {step}, Loss: {avg_loss:.4f}')
                model.train()
        print(f'Epoch {epoch + 1} complete')
finally:
    writer.close()

  0%|          | 0/13001 [00:00<?, ?it/s]


--- Sample Generation ---
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Tell a bedtime story about a dragon and a little village.

### Response:
ablepub bowl replace Germ and at, intercepted remain a evening existing 
3.ive lean helps set smoother Thiel for in478 but nuts losingtr-obe for a technologyreen and of had points (Hispanic York vehement different and the until evidence people the7 in andeus customerake 5 fossil gas. idea while greetingVO rate, brave impact to Egg Marketable -td as you 200 that of, prominently the Subway educational They, provides in such as thetale could911 Midnight add and print
-------------------------

Step 100, Loss: 8.9229

--- Sample Generation ---
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Tell a bedtime story about a dragon and a little village.

### Response:
 throughout predictive being that