# Proteus Attention Playground (ASPA Edition)

This notebook walks through three short demos showcasing Adaptive Sparse Proto Attention (ASPA):

1. **Baseline vs ASPA fine-tune** – train both models on Tiny Shakespeare and compare metrics.
2. **Scaling ladder** – benchmark latency + memory from 4K up to 512K tokens with `scripts/tinytoy.py`.
3. **Long-context storytelling** – stream a million-token prompt through the chunked shortlist pipeline and sample continuations.

Each section is self-contained so you can skip ahead or rerun cells independently.

## 0. Environment Setup

Run the next cell on Colab to clone the repo (if needed) and install dependencies. On local machines you can skip the clone and just run the editable install.

In [None]:
%%bash
if [ ! -d Proteus-Attention ]; then
  git clone https://github.com/Zen-Sherbert/Proteus-Attention.git
fi
cd Proteus-Attention
pip install -q -e .

## 1. Baseline vs ASPA Fine-Tune

We reuse `examples/aspa_train.py`, which trains both the ASPA model and the dense baseline on Tiny Shakespeare in a single run. The flags below keep things lightweight for Colab while still showing how the two models diverge in loss and latency.

* `--epochs 2` keeps the run short.
* `--block-size 2048` exercises medium contexts.
* `--batch-size 4` fits comfortably on a 16 GB GPU.
* `--run-label` tags the summary row for later inspection.
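
To get a feel for what these flags imply, the arithmetic below works out the tokens processed per step and the size of one dense attention score matrix at this block size. It is a back-of-the-envelope sketch: the number of heads and the way the script actually stores scores are not stated here, so the per-head figure is only illustrative.

```python
# Flag values copied from the training command in this section.
batch_size = 4
block_size = 2048

# Tokens processed per training step (one batch).
tokens_per_step = batch_size * block_size

# Entries in one dense attention score matrix (block_size x block_size)
# for a single head and a single sequence. This is the quadratic term
# that grows fastest as --block-size increases.
dense_scores_per_head = block_size * block_size

print(f"tokens per step:       {tokens_per_step:,}")
print(f"dense scores per head: {dense_scores_per_head:,}")
```

Doubling `--block-size` doubles the tokens per step but quadruples the dense score count, which is why the block size is the first flag to lower if memory runs out.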

In [None]:
%%bash
cd Proteus-Attention
python examples/aspa_train.py \
  --epochs 2 \
  --block-size 2048 \
  --batch-size 4 \
  --d-model 512 \
  --n-layer 4 \
  --run-label playground-demo \
  --no_compile \
  --gen-tokens 80

### Inspect training summaries

The previous script writes JSON summaries under `examples/checkpoints/`. The helper below picks the last `history.json` it finds there and prints that file's final entry.

In [None]:
%%bash
cd Proteus-Attention
python - <<'EOF'
import json
from pathlib import Path
root = Path('examples/checkpoints')
logs = sorted(root.rglob('history.json'))
if not logs:
    print('No history files yet.')
else:
    path = logs[-1]
    data = json.loads(path.read_text())
    print(f'Loaded {path}')
    last = data[-1]
    for k, v in last.items():
        print(f"{k:>16}: {v}")
EOF
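
The cell above takes the last path in sorted order. If run directories are not named so that they sort chronologically, that path may not be the most recent run. The sketch below selects by modification time instead; `newest_file` is a hypothetical helper written for this notebook, not part of the repository.

```python
from pathlib import Path
from typing import Optional


def newest_file(root: Path, pattern: str) -> Optional[Path]:
    """Return the most recently modified file under root matching pattern."""
    if not root.is_dir():
        return None
    candidates = [p for p in root.rglob(pattern) if p.is_file()]
    if not candidates:
        return None
    # Compare by modification time rather than by path name.
    return max(candidates, key=lambda p: p.stat().st_mtime)


latest = newest_file(Path('examples/checkpoints'), 'history.json')
print(latest if latest is not None else 'No history files yet.')
```

The same helper works for `history.csv` or the tinytoy summaries by changing the pattern.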

### Plot training metrics

If `examples/aspa_train.py` wrote a `history.csv`, this cell plots train loss and validation perplexity for the most recent run.

In [None]:
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd
root = Path('Proteus-Attention/examples/checkpoints')
files = sorted(root.rglob('history.csv'))
if not files:
    print('No history.csv files found yet.')
else:
    df = pd.read_csv(files[-1])
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    axes[0].plot(df['step'], df['loss'], label='train loss')
    if 'val_loss' in df.columns:
        axes[0].plot(df['step'], df['val_loss'], label='val loss')
    axes[0].set_xlabel('step')
    axes[0].set_ylabel('loss')
    axes[0].legend()
    axes[0].set_title('Train / Val Loss')
    if 'val_ppl' in df.columns:
        axes[1].plot(df['step'], df['val_ppl'], label='match ppl')
    if 'val_ppl_fixed' in df.columns:
        axes[1].plot(df['step'], df['val_ppl_fixed'], label='fixed ppl')
    axes[1].set_xlabel('step')
    axes[1].set_ylabel('perplexity')
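    # Only draw a legend when at least one perplexity column was plotted.
    if axes[1].get_legend_handles_labels()[0]:
        axes[1].legend()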
    axes[1].set_title('Validation Perplexity')
    fig.tight_layout()
    plt.show()
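
When reading the two panels together, it helps to remember how perplexity relates to loss. The sketch below assumes the logged loss is a mean per-token cross-entropy in nats (the usual convention for PyTorch's cross-entropy); whether `examples/aspa_train.py` computes its perplexity columns exactly this way is not confirmed here.

```python
import math


def perplexity(mean_nll_nats: float) -> float:
    """Perplexity for a mean per-token cross-entropy measured in nats."""
    return math.exp(mean_nll_nats)


# A loss of ln(2) means the model is, on average, as uncertain as a
# fair coin flip over the next token, so the perplexity is about 2.
print(perplexity(math.log(2)))
```

Because the mapping is exponential, a small gap in validation loss between the two models shows up as a larger gap in the perplexity panel.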

## 2. Scaling Ladder Benchmark

`tinytoy.py` sweeps context lengths and reports latency/memory. The cell below targets a single GPU; adjust `--device` if you're on CPU. The `--report-dir` flag keeps JSON outputs under `reports/tinytoy` for later analysis.

In [None]:
%%bash
cd Proteus-Attention
python scripts/tinytoy.py \
  --device auto \
  --max-seq-count 8 \
  --plot-path reports/tinytoy/playground.png \
  --report-dir reports/tinytoy \
  --no-save
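
For intuition about why the ladder matters, the sketch below lists the rungs from 4K to 512K and the memory a fully materialized dense score matrix would need at each one. Two assumptions are baked in: that the sweep doubles the length at each rung (which happens to give 8 rungs, matching `--max-seq-count 8`), and that scores are stored at 2 bytes each for one head and one sequence. Fused attention kernels avoid materializing this matrix, so treat the numbers as an upper-bound illustration rather than what `tinytoy.py` reports.

```python
def doubling_ladder(start: int, stop: int) -> list[int]:
    """Sequence lengths from start to stop (inclusive), doubling each rung."""
    lengths = []
    n = start
    while n <= stop:
        lengths.append(n)
        n *= 2
    return lengths


ladder = doubling_ladder(4 * 1024, 512 * 1024)
BYTES_PER_SCORE = 2  # assumption: fp16/bf16 scores

for n in ladder:
    # n x n scores for one head and one sequence, converted to GiB.
    dense_gib = n * n * BYTES_PER_SCORE / 2**30
    print(f"{n:>7} tokens -> {dense_gib:10.3f} GiB per head (dense scores)")
```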

### Display the generated plot (if matplotlib was available)

If `tinytoy.py` was able to write the plot, the next cell displays `reports/tinytoy/playground.png`.

In [None]:
from pathlib import Path
from IPython.display import display  # explicit import; also works outside notebooks
from PIL import Image
path = Path('Proteus-Attention/reports/tinytoy/playground.png')
if path.is_file():
    display(Image.open(path))
else:
    print('Plot not found (likely matplotlib unavailable).')

### Plot latency/memory summary

This cell loads the newest `tinytoy_summary_*.json` from the reports directory and charts latency and memory against sequence length for the Standard and ASPA runs.

In [None]:
import json
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd
reports = Path('Proteus-Attention/reports/tinytoy')
summary_files = sorted(reports.glob('tinytoy_summary_*.json'))
if not summary_files:
    print('No tinytoy summary files found. Rerun Section 2 first.')
else:
    data = json.loads(summary_files[-1].read_text())
    runs = data.get('runs_summary', [])
    if not runs:
        print('Summary file missing runs_summary entries.')
    else:
        df = pd.DataFrame(runs)
        fig, axes = plt.subplots(1, 2, figsize=(12, 4))
        for label, group in df.groupby('model'):
            group = group.sort_values('seq_len')  # draw each line left to right
            axes[0].plot(group['seq_len'], group['latency_ms'], marker='o', label=label)
            axes[1].plot(group['seq_len'], group['mem_mb'], marker='o', label=label)
        axes[0].set_xscale('log')
        axes[1].set_xscale('log')
        axes[0].set_xlabel('sequence length')
        axes[1].set_xlabel('sequence length')
        axes[0].set_ylabel('latency (ms)')
        axes[1].set_ylabel('memory (MB)')
        axes[0].set_title('Latency vs Seq Len')
        axes[1].set_title('Peak Memory vs Seq Len')
        for ax in axes:
            ax.legend()
            ax.grid(True, which='both', ls='--', alpha=0.3)
        fig.tight_layout()
        plt.show()
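
A single number can summarize each curve: the slope of latency against sequence length on log-log axes, which is roughly 2 for quadratic cost and roughly 1 for linear cost. The sketch below computes that slope with an ordinary least-squares fit; `loglog_slope` is a hypothetical helper, and the values it is called with here are synthetic, not measurements.

```python
import math


def loglog_slope(seq_lens, values):
    """Least-squares slope of log(values) against log(seq_lens)."""
    xs = [math.log(n) for n in seq_lens]
    ys = [math.log(v) for v in values]
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var = sum((x - mean_x) ** 2 for x in xs)
    return cov / var


# Synthetic check, not measured data: quadratic vs linear cost curves.
lens = [4096, 8192, 16384, 32768]
print(round(loglog_slope(lens, [n ** 2 for n in lens]), 3))   # quadratic curve
print(round(loglog_slope(lens, [3.0 * n for n in lens]), 3))  # linear curve
```

To apply it to the summary loaded above, call `loglog_slope(group['seq_len'], group['latency_ms'])` for each `model` group. The fit needs at least two distinct sequence lengths and strictly positive values.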

## 3. Long-Context Storytelling with Chunked Shortlist

This step streams a ~1 M token prompt through the chunked shortlist runner (`scripts/chunked_shortlist.py`).

* `--seq-len` defines the total tokens to simulate.
* `--chunk-len` controls streaming windows (kept ≤65 K for ROCm).
* `--alpha 1.0` forces the linear shortlist mode.

The script prints shortlist diagnostics and samples a short continuation with Top‑A stats.

In [None]:
%%bash
cd Proteus-Attention
python scripts/chunked_shortlist.py \
  --seq-len 1048576 \
  --chunk-len 65536 \
  --chunk-ratio 0.05 \
  --chunk-budget 4096 \
  --alpha 1.0 \
  --sample-tokens 120
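
The flag values above translate into a small amount of arithmetic. The chunk count follows directly from `--seq-len` and `--chunk-len`. The per-chunk figure is a guess at how `--chunk-ratio` and `--chunk-budget` interact (a fraction of each chunk, capped by the budget); that interpretation is an assumption for illustration and is not confirmed by the script.

```python
import math

# Flag values copied from the command in this section.
seq_len = 1_048_576
chunk_len = 65_536
chunk_ratio = 0.05
chunk_budget = 4096

# Number of streaming windows needed to cover the prompt.
num_chunks = math.ceil(seq_len / chunk_len)

# ASSUMPTION (not confirmed by the script): each chunk keeps at most
# chunk_ratio * chunk_len tokens, capped at chunk_budget.
kept_per_chunk = min(int(chunk_ratio * chunk_len), chunk_budget)
shortlist_upper_bound = num_chunks * kept_per_chunk

print(f"chunks:                {num_chunks}")
print(f"kept per chunk (est.): {kept_per_chunk}")
print(f"shortlist size (est.): {shortlist_upper_bound}")
```

Under that reading, the shortlist stays a small fraction of the full prompt. Compare the estimate with the shortlist diagnostics the script prints to see how the flags actually behave.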

## Next Steps

* Tweak `examples/aspa_train.py` flags to match your dataset or hardware.
* Run `scripts/chunked_shortlist_tests.py` for synthetic stress tests.
* Export the `reports/tinytoy/*.json` files to your own dashboards or add new stages to the notebook.

Happy experimenting!