A theoretical reconstruction of the SubQ architecture, built from first principles using the available research literature.
Source: Subquadratic — Introducing SubQ
Standard transformer self-attention computes a dot-product similarity between every pair of tokens, yielding an N × N attention matrix. Memory and compute therefore scale as O(N²), making long sequences prohibitively expensive. State-of-the-art models are typically capped at 128 K–1 M tokens in practice before quality degrades or costs become unacceptable.
SubQ (by Subquadratic) breaks this bottleneck with what they call Subquadratic Sparse Attention (SSA). Rather than a dense attention matrix, SSA computes attention only over a carefully chosen sparse set of token pairs, achieving O(N) time and memory complexity for fixed hyper-parameters. This enables a 12-million-token context window at roughly 1/5 the cost of comparable dense-attention models, with no chunking or summarisation of the context.
SSA replaces the O(N²) self-attention with a union of three sparse patterns that together preserve long-range expressiveness at linear cost:
SSA_mask[i, j] = local_mask[i, j] # 1. local window
| global_mask[i, j] # 2. global token
| routing_mask[i, j] # 3. content routing
A single softmax over the unioned sparse scores produces a valid probability distribution over attended positions.
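The single-softmax step can be illustrated with a minimal pure-Python sketch. The helper name `masked_softmax` and the toy scores below are assumptions made for illustration, not names from the repository; the point is only that masked positions receive exactly zero weight while the kept positions still sum to one.

```python
import math

def masked_softmax(scores, mask):
    """Softmax over positions where mask is True; masked positions get weight 0."""
    kept = [s for s, m in zip(scores, mask) if m]
    if not kept:
        raise ValueError("each query must attend to at least one key")
    peak = max(kept)  # subtract the max of the kept scores for numerical stability
    exps = [math.exp(s - peak) if m else 0.0 for s, m in zip(scores, mask)]
    total = sum(exps)
    return [e / total for e in exps]

# One query row over six keys; the SSA row is the OR of three toy patterns.
local_row   = [True,  True,  False, False, False, False]
global_row  = [True,  False, False, False, False, False]
routing_row = [False, False, False, False, True,  False]
ssa_row = [l or g or r for l, g, r in zip(local_row, global_row, routing_row)]

weights = masked_softmax([2.0, 1.0, 0.5, -1.0, 3.0, 0.0], ssa_row)
print([round(w, 3) for w in weights])
```

A key selected by more than one pattern is counted once, because the union is taken on the mask before the softmax rather than on the probabilities after it.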
Each token attends to its nearest ±window_size neighbours (default: 512).
Captures short-range syntactic and semantic patterns.
local_mask[i, j] = 1 iff |i − j| ≤ window_size
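As a small illustration of the rule above, the following sketch builds the local mask as nested Python lists for a toy sequence. The dense N × N list is used only so the band is easy to print; a real kernel would never materialise it.

```python
def local_mask(n, window_size):
    """mask[i][j] is True iff |i - j| <= window_size."""
    return [[abs(i - j) <= window_size for j in range(n)] for i in range(n)]

mask = local_mask(n=8, window_size=2)
for row in mask:
    print("".join("#" if m else "." for m in row))

# Keys attended per query: at most 2 * window_size + 1, fewer at the edges.
per_row = [sum(row) for row in mask]
print(per_row)
```

Each row holds at most 2 · window_size + 1 entries, which is where the O(N · W) cost of the local component comes from.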
The leading num_global_tokens positions (default: 64) act as global sinks:
- They attend to every token in the sequence.
- Every token attends to them.
This gives O(1)-hop connectivity between any two positions regardless of distance — all information can flow through the globals in two steps.
global_mask[i, j] = 1 iff j < G (all tokens → global keys)
| 1 iff i < G (global queries → all tokens)
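The two-step connectivity claim can be checked on a toy mask. This is a sketch under the assumption of a bidirectional (non-causal) mask; the helper `two_hop_reachable` is hypothetical and exists only for this check.

```python
def global_mask(n, num_global):
    """mask[i][j] is True iff key j is global or query i is global."""
    return [[(j < num_global) or (i < num_global) for j in range(n)]
            for i in range(n)]

def two_hop_reachable(mask, query, source):
    """True if `query` attends to some position that itself attends to `source`."""
    return any(mask[query][mid] and mask[mid][source] for mid in range(len(mask)))

n, g = 10, 2
mask = global_mask(n, g)

# Token 9 does not attend to token 5 directly through the global pattern...
print(mask[9][5])
# ...but token 5's information reaches token 9 via a global token in two layers.
print(two_hop_reachable(mask, query=9, source=5))

# Number of attended pairs: N*G + G*N minus the G*G block counted twice.
print(sum(sum(row) for row in mask))
```

With a causal mask applied on top, the global tokens at the start of the sequence can only see earlier positions, so the two-hop path holds in one direction only.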
A lightweight low-rank scorer (routing_rank=16 by default) computes a
similarity between every (query, key) pair using cheap low-dimensional
projections, then keeps the top_k_sparse highest-scoring keys per query (top-K):
routing_scores[i, j] = route_q(h[i]) · route_k(h[j])^T
routing_mask[i, :] = top-K positions of routing_scores[i, :]
This introduces content-dependent long-range connections without enumerating all N² pairs. In a production deployment the top-K selection would be computed with an approximate nearest-neighbour index (FAISS, ScaNN) for true O(N · K) cost; in this reference implementation the routing scores are materialised densely for algorithmic clarity.
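A minimal sketch of the routing step, written with plain lists so it runs without PyTorch. Like the reference implementation described above, it scores every pair densely; the random projection matrices `w_q` and `w_k` stand in for the learned `route_q` and `route_k` layers and are assumptions for illustration.

```python
import heapq
import random

def project(vec, weight):
    """Multiply a length-D vector by a (rank x D) matrix given as nested lists."""
    return [sum(w * v for w, v in zip(row, vec)) for row in weight]

def routing_mask(hidden, w_q, w_k, top_k):
    """mask[i][j] is True iff key j is among the top_k routing scores of query i."""
    n = len(hidden)
    q = [project(h, w_q) for h in hidden]  # low-rank query features
    k = [project(h, w_k) for h in hidden]  # low-rank key features
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        scores = [sum(a * b for a, b in zip(q[i], k[j])) for j in range(n)]
        for j in heapq.nlargest(top_k, range(n), key=scores.__getitem__):
            mask[i][j] = True
    return mask

rng = random.Random(0)
n, d, rank, top_k = 12, 8, 2, 3
hidden = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
w_q = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(rank)]
w_k = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(rank)]

mask = routing_mask(hidden, w_q, w_k, top_k)
print([sum(row) for row in mask])  # every query keeps exactly top_k keys
```

The scoring loop here is still O(N²); only the resulting mask is sparse. Replacing the inner loop with an approximate nearest-neighbour lookup is what would bring the cost down to O(N · K).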
| Component | Time | Memory |
|---|---|---|
| Local window | O(N · W) | O(N · W) |
| Global tokens | O(N · G) | O(N · G) |
| Content routing | O(N · K) | O(N · K) |
| SSA total | O(N · (W + G + K)) | O(N · (W + G + K)) |

W, G, K are fixed hyper-parameters independent of N, so the total is O(N) in both time and memory.
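To make the table concrete, the sketch below counts attended (query, key) pairs for the default hyper-parameters and compares them with the dense N² count. The function name is hypothetical, and the count is an upper bound: it ignores overlap between the three patterns, edge effects, and causal masking.

```python
def ssa_pairs_upper_bound(n, window_size=512, num_global=64, top_k=128):
    """Upper bound on attended (query, key) pairs for a length-n sequence."""
    local = n * (2 * window_size + 1)   # each query sees at most 2W + 1 neighbours
    global_ = 2 * n * num_global        # tokens -> globals plus globals -> tokens
    routing = n * top_k                 # K routed keys per query
    return local + global_ + routing

for n in (4_096, 1_000_000, 12_000_000):
    dense = n * n
    sparse = ssa_pairs_upper_bound(n)
    print(f"N={n:>10,}  dense={dense:.3e}  sparse<={sparse:.3e}  "
          f"ratio={dense / sparse:,.1f}x")
```

The per-token budget is constant (1 281 pairs with the defaults), so doubling N doubles the sparse count while the dense count quadruples.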
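Rotary Position Embeddings (RoPE) are applied to Q and K before the attention computation. RoPE encodes each absolute position as a rotation; in the query-key dot product the rotations cancel so that only the relative offset remains. Because the rotation is defined for any position index, no learned position table caps the sequence length, although quality far beyond the trained length is not guaranteed without measures such as frequency scaling.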
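The relative-position property can be verified numerically. The sketch below rotates consecutive (even, odd) pairs of a vector; this interleaved pairing is one common convention and is an assumption here, since other implementations pair the first and second halves instead.

```python
import math

def rope_rotate(vec, position, theta=10_000.0):
    """Rotate consecutive (even, odd) pairs of vec by position-dependent angles."""
    dim = len(vec)  # must be even
    out = []
    for pair in range(dim // 2):
        angle = position * theta ** (-2 * pair / dim)
        c, s = math.cos(angle), math.sin(angle)
        x, y = vec[2 * pair], vec[2 * pair + 1]
        out += [x * c - y * s, x * s + y * c]
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [0.3, -1.2, 0.7, 0.5]
k = [1.1, 0.4, -0.6, 0.9]

# Same offset (3) at very different absolute positions.
near = dot(rope_rotate(q, 5), rope_rotate(k, 2))
far = dot(rope_rotate(q, 1005), rope_rotate(k, 1002))
print(abs(near - far) < 1e-9)
```

The two scores agree up to floating-point error because rotating q by angle a and k by angle b leaves a dot product that depends only on a − b.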
Each transformer block uses a SwiGLU MLP:
FFN(x) = down_proj( Swish(gate_proj(x)) ⊙ up_proj(x) )
SwiGLU empirically outperforms ReLU and GeLU at matched parameter budgets and is the activation used in LLaMA, PaLM, and related frontier models.
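The FFN formula can be traced on a toy example. This is a sketch with hand-picked 2 → 3 → 2 weights chosen only for readability and with bias-free projections assumed; the names mirror the formula above rather than any specific module in the repository.

```python
import math

def swish(x):
    """Swish / SiLU: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))

def matvec(weight, vec):
    return [sum(w * v for w, v in zip(row, vec)) for row in weight]

def swiglu_ffn(x, gate_w, up_w, down_w):
    gate = [swish(g) for g in matvec(gate_w, x)]  # Swish(gate_proj(x))
    up = matvec(up_w, x)                          # up_proj(x)
    hidden = [g * u for g, u in zip(gate, up)]    # element-wise product
    return matvec(down_w, hidden)                 # down_proj(...)

# Toy sizes: hidden_size = 2, intermediate_size = 3.
x = [1.0, -2.0]
gate_w = [[0.5, 0.0], [0.0, 0.5], [1.0, 1.0]]
up_w = [[1.0, 0.0], [0.0, 1.0], [1.0, -1.0]]
down_w = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]

y = swiglu_ffn(x, gate_w, up_w, down_w)
print([round(v, 4) for v in y])
```

The gate branch decides how much of each up-projected feature passes through, which is the difference from a plain two-layer MLP with a single activation.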
Root Mean Square Layer Normalisation is applied before each sub-layer (pre-norm), following the LLaMA / Mistral convention for training stability.
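For reference, RMSNorm divides each vector by its root mean square and applies a learned per-channel gain. The sketch below is a plain-list version with the gain fixed to ones; the `eps` default follows the rms_norm_eps value quoted in this document.

```python
import math

def rms_norm(x, weight, eps=1e-6):
    """Scale x to unit root-mean-square, then apply a per-channel gain."""
    mean_sq = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [w * v * scale for v, w in zip(x, weight)]

x = [3.0, -4.0, 0.0, 0.0]
y = rms_norm(x, weight=[1.0] * len(x))
print([round(v, 4) for v in y])

# The output has (approximately) unit RMS.
print(math.sqrt(sum(v * v for v in y) / len(y)))
```

Unlike LayerNorm, there is no mean subtraction and no bias term, which saves a reduction per call.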
input_ids (B, N)
│
▼
embed_tokens ← nn.Embedding (V, D)
│
▼ × num_hidden_layers
┌─────────────────────────────────────────────────────┐
│ SubQTransformerLayer │
│ ┌────────────────────────────────────────────────┐ │
│ │ SubQRMSNorm │ │
│ │ ↓ │ │
│ │ SubquadraticSparseAttention (SSA) │ │
│ │ ├─ Q/K/V projections │ │
│ │ ├─ RoPE │ │
│ │ ├─ SSA mask (local ∪ global ∪ routing) │ │
│ │ └─ masked softmax → value weighted sum │ │
│ │ ↓ │ │
│ │ residual + │ │
│ └────────────────────────────────────────────────┘ │
│ ┌────────────────────────────────────────────────┐ │
│ │ SubQRMSNorm │ │
│ │ ↓ │ │
│ │ SubQMLP (SwiGLU) │ │
│ │ ↓ │ │
│ │ residual + │ │
│ └────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────┘
│
▼
SubQRMSNorm
│
▼
lm_head ← nn.Linear (D, V)
│
▼
logits (B, N, V)
| Parameter | Default | Description |
|---|---|---|
| `vocab_size` | 50 257 | GPT-2 vocabulary |
| `hidden_size` | 768 | Token embedding / hidden dimension |
| `num_hidden_layers` | 12 | Number of transformer blocks |
| `num_attention_heads` | 12 | Attention heads (query) |
| `num_key_value_heads` | same as Q | KV heads; set < Q heads for GQA |
| `intermediate_size` | 3 072 | FFN inner dim (per expert for MoE) |
| `num_experts` | `None` | MoE expert count; `None` = dense |
| `num_experts_per_tok` | 1 | Active experts per token (top-K) |
| `window_size` | 512 | Local attention half-width |
| `num_global_tokens` | 64 | Number of global-sink tokens |
| `top_k_sparse` | 128 | Top-K content routing connections |
| `routing_rank` | 16 | Rank of routing projections |
| `max_position_embeddings` | 12 000 000 | RoPE cache size (12 M tokens) |
| `rope_theta` | 10 000.0 | RoPE base frequency |
| `rms_norm_eps` | 1e-6 | RMSNorm numerical stability term |
pip install -e ".[dev]" # editable install + test deps (no extra deps)
pip install -e ".[dev,train]" # also installs tiktoken for GPT-2/4 BPE datasets

Requires Python ≥ 3.10 and PyTorch ≥ 2.2.
The repository ships two named SubQConfig presets that target distinct
deployment tiers.
Matches the hyper-parameters of Mistral 7B: 32-layer decoder, hidden size 4 096, SwiGLU FFN with inner dim 14 336. Uses Grouped Query Attention (32 Q heads / 8 KV heads), which cuts the KV-cache footprint to a quarter of that of standard MHA with 32 KV heads. SubQ's SSA replaces Mistral's fixed sliding-window attention, adding global-sink tokens and content routing on top of the local window.
from opensubq import SubQConfig, SubQModel
config = SubQConfig.mistral_7b()
# hidden_size=4096, 32 layers, 32 Q / 8 KV heads, dense SwiGLU FFN
# vocab_size=32_000 (Mistral tokeniser)
print(config)

Matches the backbone dimensions of MiMo-V2-Flash (Xiaomi, 2025): 48-layer decoder, hidden size 7 168, 64 Q heads / 8 KV heads. The dense FFN is replaced by a 256-expert Sparse MoE (8 experts active per token via top-K routing), giving ~15 B active parameters per forward pass out of ~309 B total. SubQ's SSA is applied uniformly to every layer, providing the same O(1)-hop global connectivity as MiMo's interleaved full-attention layers but at linear cost.
from opensubq import SubQConfig, SubQModel
config = SubQConfig.mimo_v2_flash()
# hidden_size=7168, 48 layers, 64 Q / 8 KV heads
# 256 experts / 8 active (SparseMoEMLP per layer)
# vocab_size=152_064 (Qwen3 tokeniser)
print(config)

| Preset | Scale | Layers | Hidden | Q / KV heads | FFN | Vocab |
|---|---|---|---|---|---|---|
| `SubQConfig.mistral_7b()` | ~7 B | 32 | 4 096 | 32 / 8 | Dense SwiGLU | 32 000 |
| `SubQConfig.mimo_v2_flash()` | ~15 B active / 309 B total | 48 | 7 168 | 64 / 8 | Sparse MoE 256 ×, top-8 | 152 064 |
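The effect of the Q / KV head split on cache size is plain arithmetic: the cache stores one key and one value vector per KV head per layer, so the saving over MHA is the ratio of Q heads to KV heads. The sketch below assumes a dense cache over all positions, head_dim = hidden size / Q heads, and 2-byte (bfloat16) values; the function name is hypothetical.

```python
def kv_cache_bytes_per_token(num_layers, num_kv_heads, head_dim, bytes_per_value=2):
    """Bytes of KV cache added per token (keys + values, all layers)."""
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_value

# Mistral-7B-style dimensions from the preset table.
num_layers, hidden, q_heads, kv_heads = 32, 4096, 32, 8
head_dim = hidden // q_heads  # 128

mha = kv_cache_bytes_per_token(num_layers, q_heads, head_dim)   # KV heads == Q heads
gqa = kv_cache_bytes_per_token(num_layers, kv_heads, head_dim)
print(mha, gqa, mha // gqa)
```

At 32 Q heads and 8 KV heads the cache shrinks by a factor of four, from 512 KiB to 128 KiB per token under these assumptions.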
import torch
from opensubq import SubQConfig, SubQModel
# Small model for experimentation
config = SubQConfig(
vocab_size=50_257,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3_072,
window_size=512,
num_global_tokens=64,
top_k_sparse=128,
)
model = SubQModel(config).eval()
print(f"Parameters: {model.num_parameters():,}")
# Forward pass
input_ids = torch.randint(0, config.vocab_size, (1, 1024))
with torch.no_grad():
logits = model(input_ids) # (1, 1024, 50257)
print(logits.shape)

| Class / helper | Description |
|---|---|
| `CharDataset` | Byte-level (0–255) tokenisation; no extra deps; matches the vocab_size=256 tiny config |
| `TiktokenDataset` | GPT-2 / GPT-4 BPE via tiktoken; requires `pip install tiktoken` |
| `make_synthetic_datasets()` | Reproducible random-token corpus for tests and quick demos |
| `make_split_loaders()` | Returns a `(train_loader, val_loader)` pair |
Both dataset classes produce (input_ids, labels) tensors with the autoregressive shift baked in and compatible with SubQModel.forward(input_ids, labels=labels).
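The autoregressive shift can be sketched in a few lines of plain Python. This illustrates the idea only; the helper names are hypothetical, and the use of non-overlapping windows is an assumption, since the windowing strategy of `CharDataset` is not described here.

```python
def byte_tokenise(text):
    """Byte-level tokenisation: one token id (0-255) per UTF-8 byte."""
    return list(text.encode("utf-8"))

def make_examples(tokens, seq_len):
    """Non-overlapping windows; labels are the inputs shifted one step left."""
    examples = []
    for start in range(0, len(tokens) - seq_len, seq_len):
        chunk = tokens[start : start + seq_len + 1]   # seq_len + 1 tokens
        examples.append((chunk[:-1], chunk[1:]))       # (input_ids, labels)
    return examples

tokens = byte_tokenise("hello world")
examples = make_examples(tokens, seq_len=4)
for input_ids, labels in examples:
    print(bytes(input_ids), "->", bytes(labels))
```

Each label at position t is the input token at position t + 1, so the model is trained to predict the next byte at every position in the window.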
from opensubq.data import CharDataset, make_split_loaders
# From a plain-text file (byte-level tokenisation):
train_ds, val_ds = CharDataset.from_file("corpus.txt", seq_len=1024)
train_loader, val_loader = make_split_loaders(train_ds, val_ds, batch_size=8)
# Or use GPT-2 BPE (requires tiktoken):
from opensubq.data import TiktokenDataset
train_ds, val_ds = TiktokenDataset.from_file("corpus.txt", seq_len=1024, encoding="gpt2")

A ready-to-run training script at the repo root. Features:
- `torch.autocast` mixed precision: bfloat16 on CUDA, float32 on CPU
- AdamW with cosine LR schedule and linear warmup (`--warmup-frac`)
- Gradient clipping (`--grad-clip`, default 1.0)
- Checkpoint save / resume (`--checkpoint-dir`, `--resume`)
- Eval loss on held-out val split, optional CSV loss log (`--log-file`)
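The schedule named in the feature list (linear warmup followed by cosine decay) can be written as a pure function of the step. This is a generic sketch, not the code in `train.py`: the function name, the `warmup_frac=0.1` default, and the `min_lr` parameter are assumptions for illustration.

```python
import math

def lr_at_step(step, max_steps, base_lr, warmup_frac=0.1, min_lr=0.0):
    """Linear warmup to base_lr, then cosine decay to min_lr at max_steps."""
    warmup_steps = max(1, int(max_steps * warmup_frac))
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    progress = min(progress, 1.0)  # hold at min_lr if training runs past max_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

schedule = [lr_at_step(s, max_steps=100, base_lr=3e-4) for s in range(101)]
print(schedule[0], schedule[9], schedule[55], schedule[100])
```

In a PyTorch loop a function like this would typically be wrapped in a multiplicative scheduler or applied by setting each parameter group's `lr` before the optimizer step.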
# Sanity-check: tiny model, synthetic data, CPU, ~5 s:
python train.py --preset tiny --data synthetic --max-steps 100
# Tier-1 training on a real corpus:
python train.py \
--preset mistral_7b \
--data file --data-file corpus.txt \
--seq-len 4096 --batch-size 4 \
--max-steps 100000 \
--checkpoint-dir ./ckpts \
--log-file loss.csv
python train.py --help # full option list

Presets: tiny (64-dim, 2L, vocab 256), mistral_7b, mimo_v2_flash.
pytest tests/ -v

109 tests across attention, model, data pipeline, and training loop.
TiktokenDataset tests are automatically skipped when tiktoken is not installed.
Progress against the whitepaper §8 Recommended Roadmap:
- Add causal mask to SSA forward
- Add autoregressive loss computation
- Wire in a tokeniser + small training corpus (`opensubq/data.py`)
- Write a minimal training loop with bfloat16, AdamW, checkpointing (`train.py`)
- Train a tiny sanity-check model to verify loss decreases
- Add `torch.compile` and bfloat16 inference
- Integrate FlashAttention-2 for the local-window component
- Add KV-cache for autoregressive decoding
- Train `SubQConfig.mistral_7b()` on a mid-scale dataset

- Replace sequential expert dispatch with batched GEMM / megablocks
- Add auxiliary load-balancing loss
- Integrate an expert-parallel training framework
- Add tensor + pipeline parallelism for 309 B total weight distribution
- Train `SubQConfig.mimo_v2_flash()` on a large-scale dataset
This repository is an independent theoretical reconstruction of the SubQ architecture built from publicly available information and the research literature cited below. It is not affiliated with, endorsed by, or based on proprietary code from Subquadratic. The implementation captures the design principles of SSA (local window + global tokens + content routing) as described in the company's public blog post.
| Paper / resource | Relevance |
|---|---|
| Subquadratic — Introducing SubQ | Primary source for SSA design goals and benchmarks |
| BigBird (Zaheer et al., 2020) | Local + global + random sparse attention; theoretical foundations |
| Longformer (Beltagy et al., 2020) | Sliding window + global attention for long documents |
| RoFormer / RoPE (Su et al., 2022) | Rotary Position Embeddings |
| GLU Variants / SwiGLU (Shazeer, 2020) | Gated linear units; SwiGLU activation |
| RMSNorm (Zhang & Sennrich, 2019) | Root Mean Square normalisation |
| FlashAttention (Dao et al., 2022) | Memory-efficient exact attention (production baseline) |
| Efficient Transformers Survey (Tay et al., 2020) | Survey of sub-quadratic attention approaches |
| GQA (Ainslie et al., 2023) | Grouped Query Attention — fewer KV heads than Q heads |
| MiMo-V2-Flash (Xiaomi, 2025) | 309B MoE model; inspiration for the mimo_v2_flash scale preset |
| Mistral 7B (Jiang et al., 2023) | Dense 7B baseline; inspiration for the mistral_7b scale preset |