9B single-GPU prefill is 1.5–3× llama.cpp at every length up to ~4.6 K tokens (2026-05-02 default flip, commit ca30368). With split-K FlashAttention on by default (commit f2e52b8), 9B 18 K prefill jumps to 1.22× llama.cpp and 27B 3-GPU 18 K reaches parity (0.99×). Generation continues to win by +30–50%. Multi-GPU prefill is now pipelined (per-GPU compute / D2H / H2D streams + double-buffered hidden chunks): 9B dual-GPU 18 K hits 1.32× llama.cpp and 27B 3-GPU 18 K hits 1.45× — the long-context multi-GPU gap is closed. See Honest Benchmarks.
A custom CUDA inference engine for Qwen3.5 / Qwen3.6 hybrid (GDN + Attention) models, written from scratch and tuned for NVIDIA mining cards (CMP 100-210, ex-mining V100) — 16 GB HBM2, sm_70, PCIe Gen1 x1, no P2P. Not a fork — every kernel is written for these constraints.
📖 한국어 README → README.ko.md
- Serves Qwen3.5 / Qwen3.6 dense hybrid models in 27B and 9B sizes (GGUF Q8_0). MoE variants (Qwen3-Moe etc.) are not supported.
- Vision input via Qwen3-VL mmproj (ViT + M-RoPE + spatial reshape).
- OpenAI-compatible HTTP API (`/v1/chat/completions`, `/v1/models`, streaming, tool calls).
- Continuous batching across N concurrent slots, with per-slot prefix caching.
- MTP draft speculative decoding (K=1) — works. DFlash + DDTree code is in the repo but not currently functional (drafter mismatch — see Limitations).
- 3-bit KV cache (MTP_TQ) — Walsh-Hadamard rotation + Lloyd-Max scalar quant. Same family of ideas as llama.cpp #21038, but 3-bit Lloyd-Max instead of 4-bit RTN (sketch after this list).
- Multi-GPU layer-parallel split with pinned-host activation bridge (no P2P required).
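For intuition, here is a minimal host-side sketch of the MTP_TQ idea from the KV-cache bullet above: rotate a head vector with a Walsh-Hadamard transform so its values look roughly Gaussian, then snap each element to the nearest of 8 Lloyd-Max centroids (3 bits) under a per-vector scale. Function names and centroid values are illustrative only, not the engine's `turboquant.cuh` code.

```cpp
// Hedged sketch of the MTP_TQ idea (illustrative only, not turboquant.cuh):
// WHT-rotate a head vector, then quantize each element to 3 bits with a
// per-vector scale and an 8-level Lloyd-Max codebook for a unit Gaussian.
#include <cmath>
#include <cstdint>

// Unnormalized fast Walsh-Hadamard transform; n must be a power of two.
void wht_inplace(float* v, int n) {
    for (int len = 1; len < n; len <<= 1)
        for (int i = 0; i < n; i += len << 1)
            for (int j = i; j < i + len; ++j) {
                float a = v[j], b = v[j + len];
                v[j] = a + b;
                v[j + len] = a - b;
            }
    float s = 1.0f / std::sqrt((float)n);          // make the rotation orthonormal
    for (int i = 0; i < n; ++i) v[i] *= s;
}

// 8-level Lloyd-Max centroids for a unit Gaussian (illustrative values).
static const float kCentroids[8] = {
    -2.152f, -1.344f, -0.756f, -0.245f, 0.245f, 0.756f, 1.344f, 2.152f};

// Quantize one rotated vector: one fp scale plus a 3-bit code per element.
void quantize_3bit(const float* v, int n, uint8_t* codes, float* scale) {
    float rms = 0.f;
    for (int i = 0; i < n; ++i) rms += v[i] * v[i];
    rms = std::sqrt(rms / n) + 1e-8f;
    *scale = rms;
    for (int i = 0; i < n; ++i) {
        float x = v[i] / rms;
        int best = 0;
        for (int c = 1; c < 8; ++c)
            if (std::fabs(x - kCentroids[c]) < std::fabs(x - kCentroids[best]))
                best = c;
        codes[i] = (uint8_t)best;   // the real path packs 8 codes into 3 bytes
    }
}
```

Dequantization reverses it: centroid × scale, then the same WHT again (the orthonormal WHT is its own inverse).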
Mining cards (CMP 100-210, ex-mining V100) are dirt-cheap on the secondhand market and come with 16 GB of HBM2 VRAM that nobody is buying back, but NVIDIA cripples them:
- Tensor Cores throttled — HMMA latency stretched 64× (8 → 512 cycles), hard cap ~5 TFLOP via cuBLAS WMMA.
- PCIe Gen1 x1 only, no P2P, no NVLink.
- CUPTI blocked — no vendor profiler, no `torch.profiler`.
- All of this is enforced in hardware — e-fuse + PMU bootrom double-lock on the die. There is no software unlock; we tried.
So vLLM, llama.cpp's default cuBLAS path, FlashAttention, bitsandbytes — anything that goes through cuBLAS Tensor Cores runs at 1/64 speed or fails outright.
qengine works around it by:
- Routing GEMM through DP4A (int8) at ~17 TFLOP and HFMA2 (fp16 SIMD) at ~24 TFLOP — these paths are not throttled on CMP (see the sketch after this list).
- A hand-written Q8_0 GEMM tile path for prefill.
- A hybrid attention layout that avoids the strict cuBLAS path for `max_seq` > 32 K.
- Pinned-host activation bridge between GPUs (since P2P is unavailable).
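To make the DP4A point concrete, here is a hypothetical kernel (not the engine's `quant_gemv.cuh`) showing the core instruction: `__dp4a` multiplies four packed int8 pairs and accumulates into an int32 per call, and this integer path is what sidesteps the Tensor Core throttle.

```cuda
// Hypothetical int8 dot-product kernel on the DP4A path (not quant_gemv.cuh);
// the engine's real kernels also handle per-block Q8_0 scales and tiling.
#include <cstdint>
#include <cuda_runtime.h>

__global__ void dot_int8_dp4a(const int8_t* __restrict__ a,
                              const int8_t* __restrict__ b,
                              int n_packed,          // number of 4-byte int8 groups
                              int* __restrict__ out) {
    // Reinterpret 4 consecutive int8s as one 32-bit word for __dp4a.
    const int* a4 = reinterpret_cast<const int*>(a);
    const int* b4 = reinterpret_cast<const int*>(b);
    int acc = 0;
    for (int i = blockIdx.x * blockDim.x + threadIdx.x;
         i < n_packed; i += gridDim.x * blockDim.x)
        acc = __dp4a(a4[i], b4[i], acc);             // 4 int8 MACs per instruction
    atomicAdd(out, acc);                             // crude reduction; real kernels
                                                     // reduce in shared memory first
}
```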
It's not faster than llama.cpp at everything. See the honest benchmarks below.
All measurements are on a CMP 100-210 host, same Q8_0 GGUFs (Qwopus3.5-9B-v3.5, Qwopus3.6-27B-v1-preview), batch 1, single inflight request, FA on, layer split, split-K FA on by default. Both engines are built for sm_70 with the int8 (MMQ / DP4A) path. qengine numbers are server-side prefill wall-clock (excludes SSE handshake) for bench_curl.sh-style real chat-completion prompts; llama.cpp numbers are llama-bench at matching prompt sizes. Bigger is better; bold = winner. Single-GPU 9B measured 2026-05-02; 27B 3-GPU and split-K 9B 18 K measured 2026-05-03. llama.cpp build 8462. The three tables below are 9B single-GPU, 9B dual-GPU, and 27B 3-GPU, in that order.
| Prompt | qengine PP t/s | llama.cpp PP t/s | qengine TG t/s | llama.cpp TG t/s |
|---|---|---|---|---|
| 297 | 594 | 199 | 70.4 | — |
| 1.16 K | 683 | 316 | — | — |
| 4.62 K | 584 | 361 | — | — |
| 18.4 K | 393 | 324 | 27.6 | — |
| tg64 | — | — | — | 46.6 |
qengine: 1.56–2.99× on prompts up to ~4.6 K, 1.22× at 18 K (split-K FA, 1.46× over the pre-split-K build). Generation +51% on the comparable short-context point (70.4 vs 46.6 t/s).
| Prompt | qengine PP t/s | llama.cpp PP t/s | qengine TG t/s | llama.cpp TG t/s |
|---|---|---|---|---|
| 297 | 594 | 188 | 68.5 | — |
| 1.16 K | 968 | 412 | — | — |
| 4.62 K | 982 | 574 | — | — |
| 18.4 K | 720 | 545 | 27.0 | — |
| tg64 | — | — | — | 44.2 |
Cross-GPU prefill pipelining landed 2026-05-03 (default on; PREFILL_NO_PIPELINE=1 to opt out): per-GPU compute / D2H / H2D streams + double-buffered host transfer + double-buffered per-GPU hidden chunks overlap chunk i's cross-GPU activation transfer with chunk i+1's downstream compute. 9B dual-GPU 18 K jumps from 259 → 720 t/s (2.78×) and now beats llama.cpp at every length (1.32× even at 18 K). Sampled tokens are bit-equivalent to the sequential path — verified against the per-token greedy argmax up to 18 K.
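Roughly how the overlap is structured, as a hedged sketch with hypothetical names (the real scheduling lives in `model.cuh`): each GPU segment owns its own compute / D2H / H2D streams, hidden chunks are staged through double-buffered pinned host memory, and because the downstream launches are asynchronous, the upstream GPU can start the next chunk while the previous one is still moving and computing downstream. The host-side `cudaEventSynchronize` fence between D2H and H2D is the one discussed under Limitations.

```cuda
// Hedged sketch of the cross-GPU prefill pipeline (hypothetical names, not the
// engine's model.cuh): per-GPU compute/D2H/H2D streams, double-buffered pinned
// staging, and a host fence between D2H and H2D.
#include <cstddef>
#include <cuda_runtime.h>

// One GPU segment: a contiguous layer range plus its three streams.
struct Seg { int device; cudaStream_t compute, d2h, h2d; };

// Stand-in for the segment's transformer kernels (hypothetical).
void run_layers(const Seg& s, void* hidden, size_t bytes, cudaStream_t stream);

void prefill_pipelined(Seg up, Seg down, void* up_hidden, void* down_hidden,
                       void* pinned[2], size_t bytes, int n_chunks) {
    cudaEvent_t comp_done[2], d2h_done[2], h2d_done[2];
    for (int b = 0; b < 2; ++b) {
        cudaEventCreateWithFlags(&comp_done[b], cudaEventDisableTiming);
        cudaEventCreateWithFlags(&d2h_done[b],  cudaEventDisableTiming);
        cudaEventCreateWithFlags(&h2d_done[b],  cudaEventDisableTiming);
    }
    for (int i = 0; i < n_chunks; ++i) {
        int b = i & 1;                                   // ping-pong pinned buffer
        cudaSetDevice(up.device);
        run_layers(up, up_hidden, bytes, up.compute);    // chunk i, upstream layers
        cudaEventRecord(comp_done[b], up.compute);
        cudaStreamWaitEvent(up.d2h, comp_done[b], 0);    // D2H waits on compute
        cudaMemcpyAsync(pinned[b], up_hidden, bytes,
                        cudaMemcpyDeviceToHost, up.d2h);
        cudaEventRecord(d2h_done[b], up.d2h);

        // Host fence: on CMP 100-210 a cross-device cudaStreamWaitEvent does
        // not reliably order the pinned-host bytes, so block the host here.
        cudaEventSynchronize(d2h_done[b]);

        cudaSetDevice(down.device);
        cudaMemcpyAsync(down_hidden, pinned[b], bytes,
                        cudaMemcpyHostToDevice, down.h2d);
        cudaEventRecord(h2d_done[b], down.h2d);
        cudaStreamWaitEvent(down.compute, h2d_done[b], 0);
        run_layers(down, down_hidden, bytes, down.compute);  // async: overlaps
                                                             // with the next chunk's
                                                             // upstream work
    }
}
```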
| Prompt | qengine PP t/s | llama.cpp PP t/s | qengine TG t/s | llama.cpp TG t/s |
|---|---|---|---|---|
| 297 | 212 | 74.2 | 27.1 | — |
| 1.16 K | 264 | 127.8 | 25.6 | — |
| 4.62 K | 268 | 146.0 | 20.4 | — |
| 18.4 K | 203 | 140.0 | 11.4 | — |
| tg128 | — | — | — | 17.7 |
qengine: 2.86× / 2.07× / 1.84× / 1.45× at 297 / 1.16 K / 4.62 K / 18 K — pipelining lifts the long-context number from parity (139, split-K only) to a clean 1.45× win. Generation +53% at 297 ctx (27.1 vs 17.7 t/s).
- 9B single-GPU prefill: qengine wins at every length now, 1.22–2.99×. The chat-app sweet spot.
- 9B dual-GPU prefill: qengine now wins at every length (1.32× at 18 K) thanks to cross-GPU pipelining.
- 27B 3-GPU prefill: qengine wins everywhere, 1.45–2.86×. The parity gap at 18 K is gone.
- Generation throughput: qengine wins by ~30–50% on both 9B and 27B. This is what users feel as chat responsiveness.
Honest take: most of the surface-level features overlap. The list below is what actually differs in practice. Unless stated otherwise, items have not been measured head-to-head — corrections / PRs welcome.
- Generation throughput at sm_70 + CMP — measured +30–50% over llama.cpp on this exact hardware. See benchmarks above.
- OpenAI Chat Completions API built into the engine binary — streaming, `image_url`, tool/function calls, no separate server process. llama.cpp has `llama-server`, which covers most of this.
- MTP_TQ uses 3-bit Lloyd-Max + WHT — llama.cpp's #21038 already lands rotation + standard scalar quant types (q4_0 etc.). Ours is 3-bit (vs 4-bit), which gives a slightly higher compression ratio on KV. Whether this beats q4_0 RTN on perplexity is not yet verified head-to-head.
- Continuous batching with per-slot prefix snapshots — not unique conceptually. The integration with our scheduler is tight; whether it actually beats llama.cpp's batched server is not yet measured.
- Qwen3-VL multimodal — we have it. So does llama.cpp via `tools/mtmd/models/qwen3vl.cpp`. Not an advantage.
- DFlash + DDTree speculative decode (experimental, currently broken) — the z-lab pretrained drafter mismatches the Qwopus3.6 distill distribution and produces degenerate output. Listed for transparency, not as a feature. Requires a drafter fine-tune to be usable.
Designed for / regularly tested on:
- 4× NVIDIA CMP 100-210 (Volta GV100, 16 GB HBM2, sm_70, PCIe Gen1 x1, no P2P)
- Total 64 GB VRAM, ~8 GB system RAM (yes, eight)
Should also work on (sm_70 / sm_72 / sm_75):
- V100 16/32 GB (much less throttled than CMP — should be faster)
- Titan V, Quadro GV100
- T4, RTX 20-series (sm_75) — untested, kernels target sm_70 paths
Will not work on:
- sm_60 or earlier (no DP4A)
- AMD / Apple Silicon
If you have a modern GPU (RTX 30/40/50, A100, H100), you should use vLLM or SGLang instead. They are far more optimized for those targets and have actual test coverage.
Requires CUDA 12.x, GCC 11+, CMake 3.18+.
```bash
git clone https://github.com/Haru-neo/qengine.git
cd qengine
mkdir build && cd build
cmake ..
make -j$(nproc)
```

27B server with vision mmproj and 3-bit KV cache:

```bash
MTP_TQ=1 ./build/qwen-engine \
  /path/to/Qwen3.6-27B-Q8_0.gguf \
  --serve 8000 --max-seq 262144 --slots 1 \
  --vision-mmproj /path/to/Qwen3.6-27B-mmproj.gguf
```

9B server with 4 concurrent slots:

```bash
QWEN_SLOTS=4 ./build/qwen-engine \
  /path/to/Qwen3.5-9B-Q8_0.gguf \
  --serve 8001 --max-seq 32768
```

Multi-GPU (layer split across visible devices):

```bash
CUDA_VISIBLE_DEVICES=0,1 ./build/qwen-engine ... --serve 8001
```

If mtp_head_<hidden>.bin exists (set with --mtp-head <path> or default ./mtp_work/mtp_head_<hidden>.bin), the engine will load it for K=1 speculative decoding. Without it the engine runs plain greedy / continuous-batched.
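For context, K=1 MTP speculation boils down to: a cheap draft head guesses the next token, and the main model's next forward pass both checks the guess and, if it was right, yields one extra token for free. A very rough sketch with entirely hypothetical types (the real path is `mtp_head.cuh` plus the generation loop in `main.cu`):

```cpp
// Very rough sketch of greedy K=1 MTP speculation (entirely hypothetical types;
// not the engine's actual interfaces).
#include <vector>

struct Model {
    const float* hidden() const;            // hidden state at the current position
    // Run one forward pass treating `draft` as the speculative next token.
    // Returns the model's own next token; if the draft matched, *bonus gets the
    // token after it (computed in the same pass), otherwise *bonus stays -1.
    int verify(int draft, int* bonus);
};
struct MtpHead { int predict(const float* hidden) const; };

// Returns how many tokens were committed this step: 2 on accept, 1 on miss.
int mtp_step(Model& model, const MtpHead& mtp, std::vector<int>& out) {
    int draft = mtp.predict(model.hidden());   // cheap guess at the next token
    int bonus = -1;
    int real = model.verify(draft, &bonus);    // one full forward pass
    out.push_back(real);
    if (real == draft && bonus >= 0) {
        out.push_back(bonus);                  // accepted: one extra token for free
        return 2;
    }
    return 1;
}
```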
```bash
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "qwen",
    "messages": [{"role":"user","content":"hello"}],
    "max_tokens": 256
  }'
```

Vision (27B with mmproj):
```bash
B64=$(base64 -w 0 image.png)
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d "{
    \"model\":\"qwen\",
    \"messages\":[{\"role\":\"user\",\"content\":[
      {\"type\":\"text\",\"text\":\"What is in this image?\"},
      {\"type\":\"image_url\",\"image_url\":{\"url\":\"data:image/png;base64,${B64}\"}}
    ]}]
  }"
```

| Var | Default | Effect |
|---|---|---|
| `MTP_TQ` | 0 | 3-bit KV cache (WHT + Lloyd-Max). Required for 27B at 256 K. |
| `FLASH_ATTN` | 1 | FA fused score+softmax+value. 0 falls back to the strict block-per-score path (bit-exact with per-token, ~2× slower prefill). |
| `BIT_EXACT_GEMM_ON` | 0 | Use the strict column-wise GEMV reduction path instead of the GEMM tile (regression / bit-exact testing, ~2.4× slower prefill). |
| `FA_BM` | 32 | FA tile width. 64 halves K/V tile-load iterations (96 KB SMEM opt-in). Marginal on the prompts we measured. |
| `FA_NT` | 1 | Per-block t_idx count. 2 shares the K/V tile across 2 query rows; currently 14% slower at long context (kept as infra). |
| `FA_SK` | 4 | FA split-K factor at sub_seq_total ≥ 4 K. Spreads each (kv_head, t_idx) across N blocks (default 4) merged via log-sum-exp (see the sketch after this table); lifts long-prompt prefill ~1.46× on 9B and ~1.34× on 27B. 0 to opt out. fp32 partials keep argmax bit-stable with the base FA path. |
| `PREFILL_NO_PIPELINE` | unset | Set to disable cross-GPU prefill pipelining (per-GPU compute / D2H / H2D streams + double-buffered hidden + host transfer). Pipelining is auto-enabled with ≥2 GPU segments and gives ~2–3× prefill at 1 K–18 K. |
| `PREFILL_NO_HOST_FENCE` | unset | Set to drop the `cudaEventSynchronize` between cross-GPU D2H and H2D — racy on CMP 100-210 (sampled tokens diverge from the sequential path). Only set this for benchmarking upper-bound throughput on hardware where cross-device stream waits properly fence pinned memory. |
| `QWEN_SLOTS` | 1 | Concurrent slots (continuous batching). Also settable via `--slots`. |
| `QWEN_MAX_QUEUE` | 64 | Max queued requests; 0 = unbounded. |
| `MTP_ACCEPT_TOP2` | 0 | MTP K=2 top-2 verify (small accept-rate gain). |
| `CUDA_VISIBLE_DEVICES` | — | Standard CUDA mask; the engine splits layers across visible GPUs. |
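The FA_SK row's log-sum-exp merge is the standard trick for combining attention computed over disjoint KV slices; below is a minimal CPU-side sketch with hypothetical names (the engine does the equivalent merge in fp32 on-device).

```cpp
// Hedged sketch of the split-K merge (hypothetical types, CPU-side for clarity).
// Each partial carries: m = max score over its KV slice, s = sum of exp(score - m),
// o = exp(score - m)-weighted sum of V rows for that slice.
#include <algorithm>
#include <cmath>
#include <vector>

struct Partial { float m, s; std::vector<float> o; };

std::vector<float> merge_splitk(const std::vector<Partial>& parts) {
    float m = -INFINITY;
    for (const auto& p : parts) m = std::max(m, p.m);        // global max score
    float s = 0.f;
    std::vector<float> o(parts[0].o.size(), 0.f);
    for (const auto& p : parts) {
        float w = std::exp(p.m - m);                          // rescale this slice
        s += w * p.s;
        for (size_t j = 0; j < o.size(); ++j) o[j] += w * p.o[j];
    }
    for (float& x : o) x /= s;                                // final attention output
    return o;
}
```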
```
src/
  main.cu          entry, weight load, generation loops, OpenAI server glue
  server.h         HTTP/1.1 + SSE, OpenAI compat parsing
  scheduler.h      continuous batching, queue, cancel propagation
  model.cuh        QwenModel: forward, multi-GPU dispatch, KV state
  tokenizer.h      BPE tokenizer, chat template, <think> strip
  ops.cuh          RMSNorm, SiLU, residual, embedding dequant
  attention.cuh    RoPE, head norm, scoring, softmax, value, KV cache
  gdn_kernels.cuh  Conv1d, GDN recurrent step, output projection
  mtp_head.cuh     MTP draft head (K=1, K=2 opt-in)
  dflash_*.cuh     DFlash + DDTree speculative path (experimental)
  vision.cuh       Qwen3-VL ViT + M-RoPE + spatial reshape + splice
  quant_gemv.cuh   Q5_K / Q6_K / Q8_0 GEMV kernels (DP4A path)
  q8_0_gemm.cuh    Q8_0 GEMM tile path (default for prefill)
  turboquant.cuh   WHT + Lloyd-Max 3-bit KV (MTP_TQ)
  gguf.h           GGUF v3 parser, mmap loader
  gpu_loader.h     multi-GPU parallel weight load (thread pool + streams)
  sampling.h       top-p / top-k / min-p / rep-pen / freq-pen / pres-pen
```
Layer split (4-GPU 27B example): GPU 0 holds layers 0–15 + token embeddings; GPU 3 holds layers 48–63 + output norm + LM head; activations bounce through pinned host memory between GPUs.
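A hedged sketch of what that split amounts to (hypothetical helpers; the real dispatch lives in `model.cuh` and weight placement in `gpu_loader.h`): contiguous layer ranges per GPU, and each boundary crossing is a D2H into a pinned staging buffer followed by an H2D on the next device.

```cpp
// Hedged sketch of the layer split and pinned-host bridge (hypothetical helpers,
// illustrative only).
#include <cstddef>
#include <cuda_runtime.h>
#include <utility>
#include <vector>

// Assign n_layers contiguous layers across n_gpus as evenly as possible,
// e.g. 64 layers on 4 GPUs -> [0,16) [16,32) [32,48) [48,64).
std::vector<std::pair<int, int>> split_layers(int n_layers, int n_gpus) {
    std::vector<std::pair<int, int>> ranges;
    int base = n_layers / n_gpus, rem = n_layers % n_gpus, start = 0;
    for (int g = 0; g < n_gpus; ++g) {
        int len = base + (g < rem ? 1 : 0);
        ranges.push_back({start, start + len});
        start += len;
    }
    return ranges;
}

// Move one hidden-state buffer across a GPU boundary through pinned host memory
// (no P2P): D2H on the source device, then H2D on the destination device.
// `staging` is assumed to be cudaHostAlloc'd (pinned).
void bridge(const void* src_dev, int src_gpu, void* dst_dev, int dst_gpu,
            void* staging, size_t bytes) {
    cudaSetDevice(src_gpu);
    cudaMemcpy(staging, src_dev, bytes, cudaMemcpyDeviceToHost);
    cudaSetDevice(dst_gpu);
    cudaMemcpy(dst_dev, staging, bytes, cudaMemcpyHostToDevice);
}
```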
- MoE not supported. Only dense Qwen3 hybrid (GDN + Attention) models — Qwen3-Moe and similar mixture-of-experts variants do not load.
- DFlash + DDTree spec decode is currently broken. The pretrained drafter (`lucebox-hub/dflash`) is trained on stock Qwen3.5; its output distribution doesn't match the Qwopus distill we use, so the accept rate is ≈ 0% and the chains degenerate. Code is in the repo for the eventual fine-tuned drafter, but as shipped this path is unusable.
- No batched MTP / spec. Speculative paths run only when `slots == 1`. With `slots > 1`, the batched gen loop is plain greedy.
- GGUF Q8_0 is the supported path. Q5_K_M / Q6_K load but quality is degraded — use Q8_0.
- sm_70 specific tuning. Should run on sm_75; sm_80+ has better engines anyway.
- Single-host. No tensor parallelism across machines, no multi-node.
- Linux only.
- Cross-GPU prefill pipelining requires a host-side fence (`cudaEventSynchronize` between D2H and H2D). On CMP 100-210 (PCIe 1.0 x1, no P2P) the cross-device `cudaStreamWaitEvent` doesn't reliably fence pinned host memory between the source GPU's D2H and the destination GPU's H2D — H2D reads stale bytes and the first sampled token diverges from the sequential path at some prompt lengths. The host fence is default-on (`PREFILL_NO_HOST_FENCE=1` to revert to the racy event-only path) and the perf cost is ≤3% at 18 K because intra-chunk overlap is already serialized by stream FIFOs. May or may not affect newer hardware with proper P2P.
- Continuous batching with system-prompt-less requests can stop after 1 token on Qwopus distill models — known issue with empty-system-prompt EOS bias under batched gen. Set `--default-system-prompt`.
Active personal project. APIs and env vars may change. Issues / PRs welcome but expect slow turnaround — solo project.
- Qwen team for the Qwen3 / Qwen3-VL model family and architecture.
- llama.cpp for the GGUF format, reference quant kernels, and the body of measurement / quantization research the broader community has produced. Particularly #21038 (rotation for KV quant), which arrived ahead of our MTP_TQ work.
- stb_image.h (public domain) for image decode in the vision path.
- TurboQuant and DFlash + DDTree speculative decoding (`lucebox-hub/dflash`) as experimental references.
- Anthropic Claude — kernel implementation and CUDA work across many sessions.
Apache 2.0 — see LICENSE.