In [2]:
import sys

sys.path.append("..")

In [9]:
import torch
from torch.utils.data import DataLoader, IterableDataset
from transformers import AutoTokenizer
from datasets import load_dataset
import torch.optim as optim
from model.long import LongConfig, LongForCausalLM
import os

In [10]:
# --- Helper Class for Streaming ---
class StreamDataset(IterableDataset):
    def __init__(self, hf_dataset, tokenizer, max_len):
        self.hf_dataset = hf_dataset
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __iter__(self):
        # We iterate over the infinite/streaming HF dataset
        for item in self.hf_dataset:
            text = item['text']
            # Tokenize
            enc = self.tokenizer(
                text, 
                truncation=True, 
                max_length=self.max_len, 
                padding="max_length", 
                return_tensors="pt"
            )
            # Yield just the input_ids (squeeze to remove batch dim from tokenizer)
            yield enc['input_ids'].squeeze(0)

In [11]:
# --- Configuration ---
BATCH_SIZE = 8        # Small batch for GPU memory safety
SEQ_LEN = 512         # Decent context length
LEARNING_RATE = 5e-4  # Standard transformer LR
MAX_STEPS = 1000      # Short run to prove it works (increase later)
SAVE_EVERY = 200

In [12]:
def train():
    print("--- ðŸš€ Initializing Linear Attention Training ---")
    
    # 1. Load Data & Tokenizer
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    
    print("Loading TinyStories dataset (Streaming)...")
    # streaming=True downloads data on the fly, no huge HDD space needed
    hf_dataset = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
    
    # Wrap in our PyTorch IterableDataset
    train_dataset = StreamDataset(hf_dataset, tokenizer, SEQ_LEN)

    # 2. Model Setup
    config = LongConfig(
        vocab_size=tokenizer.vocab_size,
        hidden_size=256,      
        num_hidden_layers=4,  
        num_heads=8,
        max_position_embeddings=SEQ_LEN
    )
    model = LongForCausalLM(config).cuda()
    
    print(f"Model Parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
    
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    
    # 3. Training Loop
    model.train()
    # PyTorch automatically handles IterableDataset correctly (no sampler needed)
    loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
    
    step = 0
    print("\nStarting Training...")
    
    for batch in loader:
        if step >= MAX_STEPS: break
        
        # Move to GPU
        batch = batch.cuda()
        
        # Forward
        outputs = model(batch, labels=batch)
        loss = outputs.loss
        
        # Backward
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        step += 1
        
        if step % 10 == 0:
            print(f"Step {step:04d} | Loss: {loss.item():.4f}")
            
        if step % SAVE_EVERY == 0:
            print(f"Saving checkpoint at step {step}...")
            generate_sample(model, tokenizer)
            model.train()

In [13]:
def generate_sample(model, tokenizer):
    print("\n--- Generating Sample ---")
    model.eval()
    prompt = "Once upon a time,"
    input_ids = tokenizer.encode(prompt, return_tensors="pt").cuda()
    
    generated = input_ids.tolist()[0]
    curr_in = input_ids
    past_key_values = None
    
    for _ in range(50):
        with torch.no_grad():
            outputs = model(curr_in, past_key_values=past_key_values)
            
            # Simple Greedy Decoding
            next_token_logits = outputs.logits[:, -1, :]
            next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
            
            past_key_values = outputs.past_key_values
            curr_in = next_token
            generated.append(next_token.item())
            
    text = tokenizer.decode(generated, skip_special_tokens=True)
    print(f"Result: {text}\n-------------------------")

if __name__ == "__main__":
    train()

--- ðŸš€ Initializing Linear Attention Training ---
Loading TinyStories dataset (Streaming)...
Model Parameters: 16.56M

Starting Training...
Step 0010 | Loss: 6.9726
Step 0020 | Loss: 5.1524
Step 0030 | Loss: 3.3083
Step 0040 | Loss: 2.5144
Step 0050 | Loss: 2.2211
Step 0060 | Loss: 2.0560
Step 0070 | Loss: 1.6017
Step 0080 | Loss: 1.7025
Step 0090 | Loss: 4.4261
Step 0100 | Loss: 3.4892
Step 0110 | Loss: 1.4607
Step 0120 | Loss: 3.4764
Step 0130 | Loss: 1.3430
Step 0140 | Loss: 1.5170
Step 0150 | Loss: 1.4277
Step 0160 | Loss: 1.3733
Step 0170 | Loss: 1.2521
Step 0180 | Loss: 1.2470
Step 0190 | Loss: 1.5670
Step 0200 | Loss: 1.4887
Saving checkpoint at step 200...

--- Generating Sample ---
Result: Once upon a time, there was a little girl named Tim. He was so happy that he was so happy.

The little girl was so happy that he was so happy.

The little girl was so happy that he was so happy.

The little
-------------------------
Step 0210 | Loss: 1.6974
Step 0220 | Loss: 1.5112
Step 02