In [2]:
import sys

# The notebook imports `model.long_new` from the parent of its working
# directory, so that directory has to be on the import path.  Guard the append
# so re-running this cell does not pile up duplicate ".." entries.
if ".." not in sys.path:
    sys.path.append("..")

In [3]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from model.long_new import LongConfig, LongForCausalLM
from model.long_new import chunked_parallel_scan, recurrent_scan
# from transformers import MambaConfig, MambaForCausalLM


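Scan Equivalence: the chunked parallel scan (training path) and the recurrent scan (inference path) must produce the same output for the same inputs, up to a small floating-point tolerance (max absolute difference < 1e-4).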

In [5]:
def test_kernel_equivalence():
    """
    CRITICAL TEST: 
    Does the Parallel Scan (Training) match the Recurrent Scan (Inference)?
    """
    torch.manual_seed(42)
    B, T, H, D = 2, 64, 4, 32

    k = torch.randn(B, T, H, D)
    v = torch.randn(B, T, H, D)
    gate = torch.sigmoid(torch.randn(B, T, H, D))
    gamma = torch.sigmoid(torch.randn(B, T, H, 1)) # Decay

    # 1. Run Parallel (Training Mode)
    out_parallel = chunked_parallel_scan(k, v, gate, gamma, chunk_size=16)

    # 2. Run Recurrent (Inference Mode)
    # We initialize state as zeros
    state = torch.zeros(B, H, D)
    out_recurrent, _ = recurrent_scan(k, v, gate, gamma, state)

    # 3. Compare
    # We use a relaxed tolerance because chunked_parallel uses float64 internally
    # while recurrent might use float32 depending on input.
    diff = (out_parallel - out_recurrent).abs().max()
    print(f"Max Difference: {diff.item():.8f}")

    if diff < 1e-4:
        print("✅ SUCCESS: Parallel and Recurrent Scans match!")
    else:
        print("❌ FAILURE: Scans diverge. Check padding or float precision.")

# Run the scan-equivalence check as soon as it is defined.
# NOTE(review): the traceback recorded under this cell contradicts the passing
# result printed by the final runner cell, so this output looks stale --
# confirm with Restart Kernel -> Run All.
test_kernel_equivalence()

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/workspace/long-attention/benchmarks/../model/long_new/ops/functional.py", line 55, in recurrent_scan
    
    u = k * v * gate
    gamma_sq = gamma.view(1, H, 1)
               ~~~~~~~~~~ <--- HERE
    
    # Simple Recurrence: h_t = gamma * h_{t-1} + u_t
RuntimeError: shape '[1, 4, 1]' is invalid for input of size 512


In [4]:
def test_mixed_generation_consistency():
    """
    Tests if the model produces the same results when:
    A) Processing the whole sequence at once (Prompt Processing)
    B) Processing token-by-token (Generation)
    
    This verifies that your mixed State/KV-Cache logic is correct.
    """
    print("\n--- 2. Testing Generation Consistency (Hybrid Cache) ---")
    
    # Config with Hybrid Layers (Anchor every 2 layers)
    config = LongConfig(
        vocab_size=100, 
        hidden_size=64, 
        num_hidden_layers=4, 
        num_heads=4, 
        hybrid_ratio=2 # Layers 1, 3=Long; 2, 4=Anchor
    )
    model = LongForCausalLM(config)
    model.eval()
    
    # Dummy Input: [Batch, SeqLen]
    input_ids = torch.randint(0, 100, (1, 10))
    
    # A. Full Forward Pass (Prompt)
    with torch.no_grad():
        out_full = model(input_ids)
    logits_full = out_full.logits
    
    # B. Step-by-Step Generation (Simulating model.generate)
    past_key_values = None
    generated_logits = []
    
    with torch.no_grad():
        for t in range(input_ids.shape[1]):
            # Feed one token at a time
            token = input_ids[:, t:t+1]
            
            outputs = model(token, past_key_values=past_key_values)
            
            # Update history
            past_key_values = outputs.past_key_values
            
            # Store logit for this step
            generated_logits.append(outputs.logits)

    # Concatenate step outputs
    logits_step = torch.cat(generated_logits, dim=1)
    
    # Compare only the last few tokens (early tokens might vary slightly due to warm-up)
    diff = (logits_full - logits_step).abs().max()
    print(f"Generation Max Logic Diff: {diff.item():.6f}")
    
    if diff < 1e-4:
        print("✅ SUCCESS: Step-by-step generation matches full forward pass!")
    else:
        print("❌ FAILURE: Generation drift detected. Check State/KV alignment.")

In [5]:
def test_backward_pass():
    """
    Ensures gradients flow through the custom kernels without error.
    """
    print("\n--- 3. Testing Backward Pass (Training) ---")
    config = LongConfig(
        vocab_size=100, 
        hidden_size=64, 
        num_hidden_layers=2, 
        num_heads=4  # <--- ADDED THIS (64 / 4 = 16)
    )
    model = LongForCausalLM(config)
    
    input_ids = torch.randint(0, 100, (2, 32))
    labels = input_ids.clone()
    
    # Forward
    outputs = model(input_ids, labels=labels)
    loss = outputs.loss
    print(f"Initial Loss: {loss.item():.4f}")
    
    # Backward
    try:
        loss.backward()
        print("✅ SUCCESS: Gradients computed successfully.")
        
        # Check for NaNs in gradients
        has_nan = False
        for name, param in model.named_parameters():
            if param.grad is not None and torch.isnan(param.grad).any():
                print(f"❌ NaN detected in {name}")
                has_nan = True
        
        if not has_nan:
            print("✅ SUCCESS: No NaNs in gradients.")
            
    except Exception as e:
        print(f"❌ FAILURE: Backward pass crashed: {e}")

In [6]:
if __name__ == "__main__":
    # Run the three checks in notebook order: kernel equivalence first, then
    # generation consistency, then the backward pass.
    for run_check in (
        test_kernel_equivalence,
        test_mixed_generation_consistency,
        test_backward_pass,
    ):
        run_check()

Max Difference: 0.00000024
✅ SUCCESS: Parallel and Recurrent Scans match!

--- 2. Testing Generation Consistency (Hybrid Cache) ---
Generation Max Logic Diff: 0.280876
❌ FAILURE: Generation drift detected. Check State/KV alignment.

--- 3. Testing Backward Pass (Training) ---
Initial Loss: 4.6158
✅ SUCCESS: Gradients computed successfully.
✅ SUCCESS: No NaNs in gradients.
