# Gazelle 0.5B Training

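This notebook provides a streamlined setup for training and running the Gazelle 0.5B model, consolidating the original steps into a simpler workflow. It requires a CUDA GPU, and the training loop below is a short demo run rather than a full training schedule.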

In [None]:
import torch
import numpy as np

# Fail fast when no CUDA device is present; everything below assumes a GPU.
if not torch.cuda.is_available():
    raise RuntimeError('GPU required')

gpu = torch.cuda.get_device_name(0)
print(f'GPU: {gpu}')
# Seed both RNGs so demo runs are repeatable.
torch.manual_seed(42)
np.random.seed(42)


In [None]:
# Install required packages: transformers (AutoTokenizer) and datasets
# (load_dataset) are used by later cells. einops does not appear to be used by
# any cell shown here — NOTE(review): confirm before dropping it.
# torch/numpy are imported in the first cell, so they are assumed preinstalled.
!pip install -q transformers datasets einops


In [None]:
from datasets import load_dataset

# Hugging Face Hub id of the training corpus; only the 'train' split is used.
_dolphin_repo = 'cognitivecomputations/dolphin-distill'

print('Loading Dolphin Distill dataset…')
dataset = load_dataset(_dolphin_repo, split='train')
print(f'Dataset size: {len(dataset)}')


In [None]:
import torch.nn as nn
from dataclasses import dataclass

@dataclass
class GazelleConfig:
    """Hyperparameters for :class:`GazelleModel`.

    ``n_embd`` must be divisible by ``n_head``; each attention head works on
    ``n_embd // n_head`` channels.

    Raises:
        ValueError: if ``n_head`` is not positive or does not divide ``n_embd``.
    """

    n_layer: int = 12  # number of transformer blocks
    n_embd: int = 1536  # embedding / hidden width
    n_head: int = 24  # attention heads per block
    vocab_size: int = 65536  # rows of the token embedding and width of the output head
    ctx_len: int = 512  # sequence length for tokenization; not enforced by GazelleModel
    dim_feedforward: int = 2048  # MLP hidden width per block (PyTorch's default)
    dropout: float = 0.1  # dropout inside each block (PyTorch's default)

    def __post_init__(self):
        # Surface a clear error here instead of an assertion deep inside
        # nn.MultiheadAttention at model construction time.
        if self.n_head <= 0 or self.n_embd % self.n_head != 0:
            raise ValueError(
                f'n_embd ({self.n_embd}) must be divisible by a positive '
                f'n_head ({self.n_head})'
            )


class GazelleModel(nn.Module):
    """Decoder-only language model built from causally masked encoder layers.

    ``forward`` takes integer token ids of shape ``(batch, seq)`` and returns
    logits of shape ``(batch, seq, vocab_size)``. The logits at position ``t``
    depend only on tokens ``0..t`` of the same row.
    """

    def __init__(self, config: GazelleConfig):
        super().__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.n_embd)
        self.layers = nn.ModuleList([
            # batch_first=True: activations are (batch, seq, embd). With the
            # PyTorch default (seq, batch, embd), attention would mix batch rows
            # instead of tokens.
            nn.TransformerEncoderLayer(
                config.n_embd,
                config.n_head,
                dim_feedforward=config.dim_feedforward,
                dropout=config.dropout,
                batch_first=True,
            )
            for _ in range(config.n_layer)
        ])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size)

    def forward(self, x):
        """Return next-token logits for token ids ``x`` of shape (batch, seq)."""
        seq_len = x.size(1)
        # Boolean attention mask: True entries are blocked, so position t cannot
        # attend to any later position (required for autoregressive use).
        causal_mask = torch.ones(
            seq_len, seq_len, dtype=torch.bool, device=x.device
        ).triu(1)
        h = self.embedding(x)
        for layer in self.layers:
            h = layer(h, src_mask=causal_mask)
        h = self.ln_f(h)
        return self.head(h)


In [None]:
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

config = GazelleConfig()
model = GazelleModel(config).cuda()

# Simple tokenizer for demo
# NOTE(review): confirm this repo loads via AutoTokenizer as-is (it may need
# trust_remote_code=True) and check which special tokens it defines.
tokenizer = AutoTokenizer.from_pretrained('BlinkDL/rwkv-4-world', use_fast=False)
# padding='max_length' requires a pad token; fall back to EOS when none is set.
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
    tokenizer.pad_token = tokenizer.eos_token

num_demo_steps = 11  # tiny demo: same step count as the original `i > 10` cutoff


def encode(example):
    """Tokenize one example to fixed-length ``input_ids`` plus its attention mask.

    NOTE(review): assumes the dataset exposes a 'text' column — confirm against
    the dataset schema.
    """
    enc = tokenizer(example['text'], truncation=True, padding='max_length',
                    max_length=config.ctx_len, return_attention_mask=True)
    return {'input_ids': enc['input_ids'], 'attention_mask': enc['attention_mask']}


# Only tokenize the rows the demo actually consumes (batch_size=1 -> one row/step);
# mapping the full corpus with a slow tokenizer is wasted work here.
demo_rows = dataset.select(range(min(num_demo_steps, len(dataset))))
dataset_enc = demo_rows.map(encode, remove_columns=demo_rows.column_names)
# Without a torch format, default_collate returns `input_ids` as a list of
# per-position tensors rather than a (batch, seq) tensor.
dataset_enc.set_format(type='torch', columns=['input_ids', 'attention_mask'])
loader = DataLoader(dataset_enc, batch_size=1)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
model.train()  # generate() switches to eval mode; make sure dropout is active

for step, batch in enumerate(loader):
    if step >= num_demo_steps:
        break
    input_ids = batch['input_ids'].cuda()
    attention_mask = batch['attention_mask'].cuda()
    logits = model(input_ids)
    # Next-token objective: logits at position t are scored against token t+1.
    # Padding positions are set to -100 so cross_entropy ignores them.
    targets = input_ids[:, 1:].masked_fill(attention_mask[:, 1:] == 0, -100)
    if not (targets != -100).any():
        continue  # all-padding row: the mean loss would be NaN and poison the weights
    loss = nn.functional.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        ignore_index=-100,
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step % 5 == 0:
        print(f'step {step} loss {loss.item():.4f}')


In [None]:
def generate(prompt, max_new_tokens=50, *, lm=None, tok=None):
    """Greedily extend ``prompt`` and return the decoded text, prompt included.

    Args:
        prompt: Text to condition on.
        max_new_tokens: Maximum number of tokens to append.
        lm: Model to decode with; defaults to the notebook-level ``model``.
        tok: Tokenizer to use; defaults to the notebook-level ``tokenizer``.

    Returns:
        ``tok.decode`` of the prompt ids followed by the generated ids.
        Decoding stops early once ``tok.eos_token_id`` is produced.
    """
    lm = model if lm is None else lm
    tok = tokenizer if tok is None else tok
    # Decode on whatever device the model lives on instead of assuming CUDA.
    first_param = next(lm.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = lm.training
    lm.eval()
    try:
        with torch.no_grad():
            ids = tok(prompt, return_tensors='pt').input_ids.to(device)
            for _ in range(max_new_tokens):
                logits = lm(ids)
                next_id = torch.argmax(logits[:, -1], dim=-1, keepdim=True)
                ids = torch.cat([ids, next_id], dim=-1)
                if next_id.item() == tok.eos_token_id:
                    break
    finally:
        # Restore the caller's mode so a later training run keeps dropout enabled.
        lm.train(was_training)
    return tok.decode(ids[0])

# Quick smoke test of greedy decoding after the demo training run.
sample_text = generate('Hello world')
print(sample_text)
