In [1]:
import sys

# Make the parent directory importable (presumably where the project's `model`
# package used below lives). Guarded so re-running this cell does not keep
# appending duplicate entries to sys.path.
if ".." not in sys.path:
    sys.path.append("..")

In [2]:
import torch
from torch.utils.data import DataLoader, IterableDataset
from transformers import AutoTokenizer
from datasets import load_dataset
import torch.optim as optim
from model.long import LongConfig, LongForCausalLM
import os

In [3]:
# --- Helper Class for Streaming ---
class StreamDataset(IterableDataset):
    """Tokenize a (possibly streaming) text dataset on the fly for a DataLoader.

    Each record of ``hf_dataset`` must be a mapping with a ``'text'`` field. Every
    text is tokenized on its own, truncated and padded to ``max_len`` tokens, and
    yielded as a 1-D tensor of input ids. The default DataLoader collate therefore
    stacks them into a ``(batch, max_len)`` tensor.

    When iterated inside DataLoader worker processes (``num_workers > 0``), each
    worker keeps only every ``num_workers``-th record, offset by its worker id.
    The stream is thus split between workers instead of being duplicated by each
    of them. Every worker still reads the full underlying stream and discards the
    records it does not own.
    """

    def __init__(self, hf_dataset, tokenizer, max_len):
        self.hf_dataset = hf_dataset  # any iterable of {'text': str, ...} records
        self.tokenizer = tokenizer  # callable tokenizer; needs a pad token for padding="max_length"
        self.max_len = max_len  # fixed sequence length of every yielded tensor

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process loading: this iterator owns the whole stream.
            num_workers, worker_id = 1, 0
        else:
            num_workers, worker_id = worker_info.num_workers, worker_info.id

        for index, item in enumerate(self.hf_dataset):
            # Round-robin split so parallel workers never emit the same record.
            if index % num_workers != worker_id:
                continue
            enc = self.tokenizer(
                item['text'],
                truncation=True,
                max_length=self.max_len,
                padding="max_length",
                return_tensors="pt",
            )
            # return_tensors="pt" adds a leading batch dim for a single string; drop it.
            yield enc['input_ids'].squeeze(0)

In [4]:
# --- Configuration ---
# Hyper-parameters read by train().
BATCH_SIZE = 16       # sequences per optimizer step (DataLoader batch size); kept small for GPU memory
SEQ_LEN = 512         # every story is truncated / padded to this many tokens
LEARNING_RATE = 5e-4  # AdamW learning rate
MAX_STEPS = 10_000    # training stops after this many optimizer steps
SAVE_EVERY = 200      # interval (in steps) of the periodic checkpoint / text-sample hook

In [5]:
def _padding_aware_targets(input_ids, pad_token_id, ignore_index=-100):
    """Build next-token targets for a right-padded batch, ignoring filler padding.

    StreamDataset pads every sequence to a fixed length with the pad token, which
    train() sets to GPT-2's EOS. Using the raw batch as labels lets the many
    trivially-predictable pad->pad transitions dominate the loss. Here the FIRST
    pad after the text stays a target, so the model still learns where a story
    ends. Every pad that directly follows another pad becomes ``ignore_index``.

    Args:
        input_ids: (batch, seq_len) tensor of token ids.
        pad_token_id: id used for padding.
        ignore_index: target value that ``cross_entropy`` skips.

    Returns:
        (batch, seq_len - 1) tensor; entry t is the target for the logits at position t.
    """
    targets = input_ids[:, 1:].clone()
    is_pad = input_ids == pad_token_id
    filler = is_pad[:, 1:] & is_pad[:, :-1]
    return targets.masked_fill(filler, ignore_index)


def _causal_lm_loss(logits, input_ids, pad_token_id):
    """Mean next-token cross-entropy over ``input_ids``, excluding filler padding.

    Assumes ``logits`` has one row per input position, i.e. (batch, seq_len, vocab)
    — TODO confirm for LongForCausalLM.
    """
    targets = _padding_aware_targets(input_ids, pad_token_id)
    shifted = logits[:, :-1, :]  # logits at position t predict token t + 1
    return torch.nn.functional.cross_entropy(
        shifted.reshape(-1, shifted.size(-1)),
        targets.reshape(-1),
        ignore_index=-100,
    )


def _save_checkpoint(model, optimizer, step, checkpoint_dir):
    """Write model/optimizer state and the step to ``checkpoint_dir/latest.pt``.

    The same file is overwritten each time to keep disk usage bounded.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    path = os.path.join(checkpoint_dir, "latest.pt")
    tmp_path = path + ".tmp"
    torch.save(
        {"step": step, "model": model.state_dict(), "optimizer": optimizer.state_dict()},
        tmp_path,
    )
    # Swap in the finished file so an interrupted save never leaves a truncated latest.pt.
    os.replace(tmp_path, path)


def train(checkpoint_dir="checkpoints"):
    """Train a small linear-attention LongForCausalLM on streamed TinyStories.

    Streams ``roneneldan/TinyStories`` through StreamDataset, using the GPT-2
    tokenizer with sequences truncated/padded to SEQ_LEN. Optimises with AdamW for
    at most MAX_STEPS steps and prints the loss every 10 steps. Every SAVE_EVERY
    steps it saves a checkpoint and prints a generated sample. The loss excludes
    filler padding, so it is not comparable to runs that used ``labels=batch``.

    Args:
        checkpoint_dir: directory that receives ``latest.pt`` (created if missing);
            ``None`` disables checkpoint writing.
    """
    print("--- 🚀 Initializing Linear Attention Training ---")

    # 1. Load Data & Tokenizer
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # GPT-2 has no pad token; reuse EOS so fixed-length padding works.
    tokenizer.pad_token = tokenizer.eos_token
    pad_token_id = tokenizer.pad_token_id

    print("Loading TinyStories dataset (Streaming)...")
    # streaming=True downloads data on the fly, no huge HDD space needed
    hf_dataset = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
    train_dataset = StreamDataset(hf_dataset, tokenizer, SEQ_LEN)

    # 2. Model Setup
    config = LongConfig(
        vocab_size=tokenizer.vocab_size,
        hidden_size=128,
        num_hidden_layers=4,
        num_heads=8,
        # FFN expansion factor. How LongConfig derives intermediate_size from it is
        # not visible here (presumably ~hidden_size * 8/3 ≈ 341) — TODO confirm.
        expansion_ratio=8/3,
        conv_kernel=4,
        hybrid_ratio=0,  # Pure Linear Attention (fastest)
        max_position_embeddings=SEQ_LEN,
    )
    model = LongForCausalLM(config).cuda()

    print(f"Model Parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")

    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    # 3. Training Loop
    model.train()
    # IterableDataset: the DataLoader only batches the stream (no sampler needed).
    loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)

    step = 0
    print("\nStarting Training...")

    for batch in loader:
        if step >= MAX_STEPS:
            break

        batch = batch.cuda()

        # Forward. The loss is computed here instead of via labels=batch so that
        # filler padding does not contribute. NOTE(review): this bypasses the
        # model's built-in loss; if LongForCausalLM adds auxiliary terms there,
        # they are not included — confirm.
        logits = model(batch).logits
        loss = _causal_lm_loss(logits, batch, pad_token_id)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        step += 1

        if step % 10 == 0:
            print(f"Step {step:04d} | Loss: {loss.item():.4f}")

        if step % SAVE_EVERY == 0:
            print(f"Saving checkpoint at step {step}...")
            if checkpoint_dir is not None:
                _save_checkpoint(model, optimizer, step, checkpoint_dir)
            generate_sample(model, tokenizer, temperature=0.8)
            model.train()  # sampling switches to eval mode

In [None]:
import torch.nn.functional as F

def _sample_next_token(logits, temperature, top_k):
    """Choose the next token id for each row of ``logits``.

    Args:
        logits: (batch, vocab) scores for the next position.
        temperature: divides the logits before softmax; ``<= 0`` means greedy argmax.
        top_k: keep only the ``top_k`` highest scores before sampling (clamped to the
            vocabulary size); ``None`` or ``<= 0`` disables the filter.

    Returns:
        (batch, 1) tensor of token ids.
    """
    if temperature <= 0:
        # Temperature 0 would divide by zero; treat it as deterministic decoding.
        return logits.argmax(dim=-1, keepdim=True)
    scaled = logits / temperature  # new tensor, so masking below never touches model output
    if top_k is not None and top_k > 0:
        k = min(top_k, scaled.size(-1))  # torch.topk raises if k > vocab size
        kth_best = torch.topk(scaled, k, dim=-1).values[..., -1:]
        scaled = scaled.masked_fill(scaled < kth_best, float("-inf"))
    probs = torch.softmax(scaled, dim=-1)
    return torch.multinomial(probs, num_samples=1)


@torch.no_grad()
def generate_sample(model, tokenizer, max_new_tokens=80, temperature=0.8, top_k=40,
                    prompt="Once upon a time,"):
    """Print (and return) a short sampled continuation of ``prompt``.

    The whole prompt is fed once. After that, only the newest token is fed,
    together with the ``past_key_values`` from the previous call, so the model's
    cache/recurrent state carries the context. Generation stops after
    ``max_new_tokens`` tokens or as soon as the tokenizer's EOS token is sampled.

    The model is switched to eval mode for generation. Its previous train/eval
    mode is restored afterwards, even if generation raises.

    Args:
        model: causal LM whose outputs expose ``.logits`` and ``.past_key_values``.
        tokenizer: provides ``encode``, ``decode`` and ``eos_token_id``.
        max_new_tokens: maximum number of tokens to sample (0 returns just the prompt).
        temperature: softmax temperature; ``<= 0`` selects greedy decoding.
        top_k: sample only among the ``top_k`` most likely tokens; ``None``/``0``
            disables the filter.
        prompt: text to continue.

    Returns:
        The decoded prompt plus continuation, with special tokens stripped.
    """
    print(f"\n--- Generating Sample (Temp: {temperature}, Top-K: {top_k}) ---")
    was_training = model.training
    model.eval()
    try:
        # Put the prompt on the model's device instead of assuming CUDA.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

        # Full token sequence (prompt + samples) for decoding at the end.
        generated = input_ids[0].tolist()
        past_key_values = None
        next_token = None
        for step in range(max_new_tokens):
            if step == 0:
                # Whole prompt in one call to fill the KV cache / RNN state.
                outputs = model(input_ids)
            else:
                # Only the newest token; the context lives in past_key_values.
                outputs = model(next_token, past_key_values=past_key_values)
            past_key_values = outputs.past_key_values

            next_token = _sample_next_token(outputs.logits[:, -1, :], temperature, top_k)
            token_id = next_token.item()
            generated.append(token_id)
            if token_id == tokenizer.eos_token_id:
                break
    finally:
        model.train(was_training)

    text = tokenizer.decode(generated, skip_special_tokens=True)
    print(f"Result: {text}\n-------------------------")
    return text

# Launch training when this code is executed directly. Jupyter/IPython kernels run
# cells as __main__, so executing this cell starts train(). Importing the code as a
# module does not.
if __name__ == "__main__":
    train()