# GroundThink v7
## GDN + SWA Hybrid with Chunk-Recurrent Delta Rule

**v7 Changes from v6:**
- Chunk-recurrent backward pass (numerically stable)
- Modular package structure (`groundthink_v7/`)
- Triton kernels moved out of the notebook into `groundthink_v7/core.py` (no inline kernels here)

**Package Structure:**
```
groundthink_v7/
├── __init__.py    # Clean exports
├── config.py      # HybridConfig
├── core.py        # Triton kernels + chunk_delta_rule
├── model.py       # GDN, SWA, TransparentHybrid
└── analysis.py    # NIAH tests, training utils
```

In [None]:
# =============================================================================
# SETUP
# =============================================================================

import sys

# Make the local `groundthink_v7` package importable from the current working
# directory. Only prepend '.' when it is not already first on sys.path, so
# re-running this cell does not keep stacking duplicate entries.
if sys.path[:1] != ['.']:
    sys.path.insert(0, '.')

import torch

# Report the runtime environment used by the rest of the notebook.
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# =============================================================================
# IMPORTS
# =============================================================================

from groundthink_v7 import (
    # Configuration and model
    HybridConfig,
    TransparentHybrid,
    # Data loading and training
    load_wikitext,
    train_curriculum,
    # Validation, retrieval (NIAH) tests and diagnostics
    validate_delta_rule,
    proper_niah_test,
    test_niah_by_distance,
    run_full_diagnostic,
    analyze_gradients,
)

print("✓ groundthink_v7 imported")

In [None]:
# =============================================================================
# CONFIGURATION
# =============================================================================

# All hyperparameters in one place; note n_heads * head_dim == d_model (8 * 32 = 256).
hparams = dict(
    d_model=256,
    n_heads=8,
    head_dim=32,
    value_dim=64,
    layer_pattern="GS",
    window_size=64,
    chunk_size=64,
    beta_bias=-2.0,
    g_bias=2.0,
)
cfg = HybridConfig(**hparams)

print(cfg)

In [None]:
# =============================================================================
# CREATE MODEL
# =============================================================================

# Prefer the GPU when one is visible to PyTorch.
# NOTE(review): groundthink_v7/core.py holds Triton kernels; confirm the model
# still runs on the CPU path before relying on the "cpu" fallback.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Build the hybrid, move it to DEVICE, then cast floating-point weights to bfloat16.
model = TransparentHybrid(cfg)
model = model.to(DEVICE)
model = model.bfloat16()

print(f"Parameters: {model.count_params():,}")
print(f"Device: {DEVICE}")

In [None]:
# =============================================================================
# VALIDATION
# =============================================================================
# Smoke test: one forward pass, one forward+backward pass, then the package's
# delta-rule check. Fails loudly (RuntimeError) instead of printing "✓" when
# the numbers are not finite or no gradients were produced.

print("="*60)
print("FORWARD/BACKWARD TEST")
print("="*60)

# NOTE(review): token ids are drawn from [0, 1000); this assumes the model's
# vocabulary has at least 1000 entries — confirm against HybridConfig.

# Forward
x = torch.randint(0, 1000, (2, 64), device=DEVICE)
with torch.no_grad():
    logits, _, diags, state = model(x)
state_norm = state.norm()
if not torch.isfinite(logits).all():
    raise RuntimeError("Forward test produced non-finite logits")
if not torch.isfinite(state_norm):
    raise RuntimeError(f"Forward test produced a non-finite state norm: {state_norm.item()}")
print(f"Forward: output={logits.shape}, state_norm={state_norm.item():.4f} ✓")

# Backward
model.train()
y = torch.randint(0, 1000, (2, 64), device=DEVICE)
_, loss, _, _ = model(x, y)
loss.backward()
if not torch.isfinite(loss):
    raise RuntimeError(f"Backward test produced a non-finite loss: {loss.item()}")
grads = [p.grad for p in model.parameters() if p.grad is not None]
if not grads:
    raise RuntimeError("Backward test produced no parameter gradients")
if not all(torch.isfinite(g).all() for g in grads):
    raise RuntimeError("Backward test produced non-finite gradients")
print(f"Backward: loss={loss.item():.4f} ✓")

# The smoke-test gradients are not needed afterwards; drop them so they don't
# linger on the model used by the following cells.
model.zero_grad(set_to_none=True)

# Delta Rule validation
validate_delta_rule(DEVICE)

In [None]:
# =============================================================================
# NIAH (Untrained Baseline)
# =============================================================================

# Section banner (single print, same output as three separate prints).
print("\n" + "="*60, "NIAH TEST (Untrained)", "="*60, sep="\n")

# Baseline retrieval accuracy in eval mode, before any training.
model.eval()
proper_niah_test(model, n_trials=20, seq_len=64)

In [None]:
# =============================================================================
# LOAD DATA
# =============================================================================

# seq_len=128 matches the sequence length used by the post-training NIAH checks.
data_loader = load_wikitext(
    n_tokens=500_000,
    seq_len=128,
    batch_size=16,
)

In [None]:
# =============================================================================
# TRAIN
# =============================================================================

# Train a freshly initialised model rather than the one exercised by the
# validation / untrained-NIAH cells.
model = TransparentHybrid(cfg)
model = model.to(DEVICE).bfloat16()

history = train_curriculum(
    model,
    data_loader,
    # schedule
    steps=1000,
    warmup_steps=200,
    lr=3e-4,
    # loss weighting
    retrieval_weight=2.0,
    # reporting
    log_interval=100,
)

In [None]:
# =============================================================================
# POST-TRAINING EVALUATION
# =============================================================================

print("="*60)
print("POST-TRAINING EVALUATION")
print("="*60)

# Evaluate in inference mode, mirroring the untrained baseline cell, which
# calls model.eval(). Without this, the evaluation would run in whatever mode
# train_curriculum left the model in.
model.eval()

print("\n1. NIAH Accuracy:")
proper_niah_test(model, seq_len=128, n_trials=30)

print("\n2. NIAH by Distance:")
test_niah_by_distance(model, seq_len=128)

print("\n3. State Health:")
run_full_diagnostic(model, seq_len=128)

# NOTE(review): analyze_gradients now also runs with the model in eval mode;
# call model.train() before it if it expects training-mode behaviour.
print("\n4. Gradient Analysis:")
analyze_gradients(model)

In [None]:
# =============================================================================
# SAVE MODEL
# =============================================================================

# Flip to True to write a checkpoint (weights, config, training history).
SAVE_CHECKPOINT = False
CHECKPOINT_PATH = 'groundthink_v7_checkpoint.pt'

if SAVE_CHECKPOINT:
    # NOTE: 'config' is a HybridConfig instance (a pickled Python object), so
    # torch.load with its weights_only=True default (torch>=2.6) will reject this
    # file unless the class is allowlisted; use weights_only=False only for
    # checkpoints you trust.
    torch.save({
        'model_state_dict': model.state_dict(),
        'config': cfg,
        'history': history,
    }, CHECKPOINT_PATH)
    print(f"✓ Saved checkpoint to {CHECKPOINT_PATH}")

print("\n✓ Training complete!")