In [1]:
import sys

sys.path.append("..")

In [2]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from model.long import LongConfig, LongForCausalLM
from model.long import chunked_parallel_scan, recurrent_scan

In [3]:
# --- 1. Tiny Character Tokenizer ---
class TinyTokenizer:
    def __init__(self, text):
        chars = sorted(list(set(text)))
        self.vocab_size = len(chars) + 1 # +1 for padding
        self.stoi = {ch: i+1 for i, ch in enumerate(chars)}
        self.itos = {i+1: ch for i, ch in enumerate(chars)}
        self.pad_token_id = 0
    
    def encode(self, text):
        return [self.stoi[c] for c in text]
    
    def decode(self, ids):
        return "".join([self.itos[i] for i in ids if i > 0])

# --- 2. Setup Data ---
text = "The quick brown fox jumps over the lazy dog. " * 10
tokenizer = TinyTokenizer(text)
print(f"Vocab Size: {tokenizer.vocab_size}")

Vocab Size: 30


In [4]:
# Create Batch
data = torch.tensor(tokenizer.encode(text), dtype=torch.long).unsqueeze(0) # [1, T]
inputs = data[:, :-1]
targets = data[:, 1:]

# --- 3. Setup Model ---
config = LongConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=128,       # Small size for speed
    num_hidden_layers=2,   # 2 layers is enough to prove flow
    num_heads=4,
    conv_kernel=3,
    max_position_embeddings=512
)
model = LongForCausalLM(config)
optimizer = optim.AdamW(model.parameters(), lr=1e-3) # High LR for fast overfitting

print(f"Model Parameters: {sum(p.numel() for p in model.parameters())}")

Model Parameters: 599048


In [5]:
print(inputs)
print(targets)

tensor([[ 3, 11,  8,  1, 20, 24, 12,  6, 14,  1,  5, 21, 18, 26, 17,  1,  9, 18,
         27,  1, 13, 24, 16, 19, 22,  1, 18, 25,  8, 21,  1, 23, 11,  8,  1, 15,
          4, 29, 28,  1,  7, 18, 10,  2,  1,  3, 11,  8,  1, 20, 24, 12,  6, 14,
          1,  5, 21, 18, 26, 17,  1,  9, 18, 27,  1, 13, 24, 16, 19, 22,  1, 18,
         25,  8, 21,  1, 23, 11,  8,  1, 15,  4, 29, 28,  1,  7, 18, 10,  2,  1,
          3, 11,  8,  1, 20, 24, 12,  6, 14,  1,  5, 21, 18, 26, 17,  1,  9, 18,
         27,  1, 13, 24, 16, 19, 22,  1, 18, 25,  8, 21,  1, 23, 11,  8,  1, 15,
          4, 29, 28,  1,  7, 18, 10,  2,  1,  3, 11,  8,  1, 20, 24, 12,  6, 14,
          1,  5, 21, 18, 26, 17,  1,  9, 18, 27,  1, 13, 24, 16, 19, 22,  1, 18,
         25,  8, 21,  1, 23, 11,  8,  1, 15,  4, 29, 28,  1,  7, 18, 10,  2,  1,
          3, 11,  8,  1, 20, 24, 12,  6, 14,  1,  5, 21, 18, 26, 17,  1,  9, 18,
         27,  1, 13, 24, 16, 19, 22,  1, 18, 25,  8, 21,  1, 23, 11,  8,  1, 15,
          4, 29, 28,  1,  7,

In [6]:
# --- 4. Training Loop ---
print("\nStarting Overfit Run...")
model.train()

for step in range(250):
    optimizer.zero_grad()
    
    outputs = model(inputs, labels=targets)
    loss = outputs.loss
    loss.backward()
    
    # 3. Tighter clipping (0.5 instead of 1.0)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    
    optimizer.step()
    
    if step % 25 == 0:
        print(f"Step {step:03d} | Loss: {loss.item():.6f}")
        
print("\nFinal Loss:", loss.item())




Starting Overfit Run...
Step 000 | Loss: 103.222633
Step 025 | Loss: 0.000634
Step 050 | Loss: 0.000004
Step 075 | Loss: 0.000001
Step 100 | Loss: 0.000001
Step 125 | Loss: 0.000000
Step 150 | Loss: 0.000000
Step 175 | Loss: 0.000000
Step 200 | Loss: 0.000000
Step 225 | Loss: 0.000000

Final Loss: 2.075517073762967e-07


In [7]:
# --- 5. Generation Test (Inference Mode) ---
print("\n--- Generating from Prompt 'The quick' ---")
model.eval()

# Start with "The quick"
prompt = "The quick"
input_ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long)

# Generate 50 tokens
# We implement a simple loop here to use your model's native .forward
generated = input_ids.tolist()[0]
past_key_values = None
current_input = input_ids

for _ in range(50):
    with torch.no_grad():
        outputs = model(current_input, past_key_values=past_key_values)
        
    logits = outputs.logits[:, -1, :] # Last token logits
    next_token = torch.argmax(logits, dim=-1).unsqueeze(0) # Greedy decode
    
    # Update History
    past_key_values = outputs.past_key_values
    current_input = next_token
    
    generated.append(next_token.item())

print(f"Output: {tokenizer.decode(generated)}")

if loss.item() < 0.1:
    print("\n‚úÖ SUCCESS: Model overfit successfully.")
else:
    print("\n‚ùå FAILURE: Model failed to converge.")


--- Generating from Prompt 'The quick' ---
Output: The quickbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb

‚úÖ SUCCESS: Model overfit successfully.


In [8]:
"The quick brown fox jumps over the lazy dog. "

'The quick brown fox jumps over the lazy dog. '

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from model.long import LongConfig, LongForCausalLM

def test_needle_retrieval():
    print("\n--- üïµÔ∏è‚Äç‚ôÄÔ∏è Needle in a Haystack Test ---")
    
    # 1. Setup Model (Long Context)
    config = LongConfig(
        vocab_size=1000,
        hidden_size=64,
        num_hidden_layers=2,
        num_heads=4,
        max_position_embeddings=2048 
    )
    model = LongForCausalLM(config)
    
    # 2. Construct the Haystack
    seq_len = 500     # Length of sequence
    needle_pos = 50   # Where we hide the secret
    
    key_token = 101   # The "Question"
    val_token = 999   # The "Answer"
    trigger = 101     # We ask this at the end
    
    # Create random noise
    input_ids = torch.randint(10, 90, (1, seq_len)) 
    
    # Hide the needle: "When you see 101, the next token is 999"
    input_ids[0, needle_pos] = key_token
    input_ids[0, needle_pos+1] = val_token
    
    # Place trigger at the very end
    input_ids[0, -1] = trigger
    
    print(f"Sequence Length: {seq_len}")
    print(f"Needle '{key_token} -> {val_token}' hidden at index {needle_pos}")
    print(f"Trigger '{trigger}' placed at end (index {seq_len-1})")
    
    # 3. Train on this ONE example
    # We want the model to learn: "If I saw 101 earlier, output 999 now."
    optimizer = optim.AdamW(model.parameters(), lr=0.01)
    model.train()
    
    # Target: Predict the token AFTER the trigger
    target = torch.tensor([val_token], device=input_ids.device)
    
    print("\nTraining...")
    for i in range(150):
        # Forward pass
        outputs = model(input_ids)
        
        # We only care about the prediction at the very last step
        last_token_logits = outputs.logits[0, -1, :] # [Vocab]
        
        # Calculate loss manually for just the last token
        loss = nn.CrossEntropyLoss()(last_token_logits.unsqueeze(0), target)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if i % 30 == 0:
            print(f"Step {i:03d}: Loss {loss.item():.4f}")

    # 4. Inference Test (Recurrent Mode)
    print("\n--- Testing Inference (Generation) ---")
    model.eval()
    
    # We feed the sequence up to the trigger, then generate ONE token
    # This forces the model to use its Recurrent State to remember the needle
    with torch.no_grad():
        # Pass the whole sequence
        outputs = model(input_ids)
        
        # Look at the prediction for the step *after* the trigger
        logits = outputs.logits[0, -1, :]
        predicted_id = torch.argmax(logits).item()
    
    print(f"Expected: {val_token}")
    print(f"Predicted: {predicted_id}")
    
    if predicted_id == val_token:
        print("\n‚úÖ SUCCESS: Found the needle! Your Linear Attention is working.")
    else:
        print(f"\n‚ùå FAILURE: Predicted {predicted_id}. Memory lost.")

if __name__ == "__main__":
    test_needle_retrieval()


--- üïµÔ∏è‚Äç‚ôÄÔ∏è Needle in a Haystack Test ---
Sequence Length: 500
Needle '101 -> 999' hidden at index 50
Trigger '101' placed at end (index 499)

Training...
Step 000: Loss 72.3350
Step 030: Loss 0.0000
Step 060: Loss 0.0000
Step 090: Loss 0.0000
Step 120: Loss 0.0000

--- Testing Inference (Generation) ---
Expected: 999
Predicted: 999

‚úÖ SUCCESS: Found the needle! Your Linear Attention is working.
