
fix: memory auto-cap strategy for SSD MoE streaming + speculative decoding (Issue #72) #77

Merged
solderzzc merged 10 commits into main from fix/issue-72-draft-model-ssd-ram
Apr 23, 2026

Conversation

@solderzzc
Member

Resolves #72 by implementing an auto-cap strategy for speculative decoding when combined with SSD streaming.

Changes:

  • Auto-caps --num-draft-tokens to 1 when --stream-experts and --draft-model are combined. This reduces I/O fan-out from 5x down to 2x, preventing OS memory thrashing (see the sketch after this list).
  • Reverted the temporary ssd-opt-v2 buffer deactivation (preserving the 4% speedup for non-speculative users).
  • CI Enforcement: Added a mandatory ssd-draft-memory-guard CI job that runs a 2B main model and a 0.8B draft model and enforces physical RAM limits via vm_stat.
  • Docs: Updated the README to document the performance mechanics of this combination.
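
For reference, a minimal sketch of the auto-cap decision described above. The type and function names here are illustrative assumptions, not the actual symbols in Sources/SwiftLM/Server.swift:

```swift
// Sketch only: illustrative names, not the project's real configuration API.
struct SpeculativeConfig {
    var streamExperts: Bool
    var draftModelPath: String?
    var numDraftTokens: Int
}

/// Cap draft tokens to 1 when SSD expert streaming and a draft model are combined.
/// With 1 draft token the verify pass covers 2 positions, so SSD I/O fan-out is 2x
/// instead of the 5x seen at the default of 4 draft tokens.
func applyDraftTokenAutoCap(_ config: inout SpeculativeConfig) {
    guard config.streamExperts,
          config.draftModelPath != nil,
          config.numDraftTokens > 1 else { return }
    print("[SwiftLM] ⚠️ auto-capping --num-draft-tokens from \(config.numDraftTokens) to 1 " +
          "(SSD expert streaming + draft model: limits verify fan-out to 2x SSD I/O)")
    config.numDraftTokens = 1
}
```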

…odel (#72 follow-up)

Reporter confirmed the original fix addressed load-time RAM, but swap still
explodes during inference: OS_RAM=20.7GB / MEM_DEMAND=40.2GB on a 16GB machine.

Root cause (inference-time):
The 200GB memoryLimit sentinel is necessary for SSD streaming alone — it bypasses
MLX eval_impl's spin-wait loop when expert pages are evicted mid-graph.  However,
with speculative decoding the draft model (4B / 3GB) and main model (35B / 20GB)
alternate forward passes in tight succession.  Both models' expert pages are
demanded within the same inference cycle; the combined demand of ~23GB far exceeds
the 16GB of physical RAM.
The 200GB sentinel provides zero back-pressure, so macOS swaps aggressively
(10+ GB observed in Activity Monitor).

Fix:
When --stream-experts + --draft-model are both set AND combinedFootprint > 70%
of physical RAM, lower memoryLimit from 200GB to physicalRAM × 1.1.  This forces
MLX to hit its hard limit sooner and evict stale expert pages more aggressively
rather than extending into swap.  A clear startup warning is also printed:

  ⚠️  SSD + draft-model RAM pressure warning:
     Main model: 20.4GB  Draft: 3.0GB  Combined: 23.4GB  Physical RAM: 16.0GB
     Speculative decoding alternates both models' forward passes.
     On this machine the combined weight exceeds physical RAM,
     causing page-cache thrashing and swap during inference.
     → Recommendation: remove --draft-model on this machine,
       or use a smaller draft model whose weights fit in
       remaining RAM after the main model's page budget (6GB).
     Memory limit set to 17GB (tight cap for MLX eviction pressure)

When combined footprint fits in RAM (e.g. smaller draft on a 32GB machine),
the 200GB sentinel is still used as before — no regression for capable hardware.
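
A minimal sketch of the limit-selection rule described in this commit message. The helper name and parameterization are assumptions for illustration, not the project's actual code; only the thresholds (70% pressure check, physicalRAM × 1.1 cap, 200GB sentinel) come from the fix above:

```swift
import Foundation

/// Choose the MLX memory limit under SSD expert streaming (sketch of the rule above).
/// - If combined weights stay under 70% of physical RAM (or there is no draft model),
///   keep the 200GB sentinel so eval_impl's spin-wait bypass still applies.
/// - Otherwise cap at physicalRAM * 1.1 so MLX evicts expert pages instead of swapping.
func selectMemoryLimit(mainBytes: Int, draftBytes: Int, physicalRAM: Int) -> Int {
    let sentinel = 200 * 1024 * 1024 * 1024            // 200 GiB sentinel
    let combined = mainBytes + draftBytes
    let pressureThreshold = Int(Double(physicalRAM) * 0.70)
    if draftBytes > 0 && combined > pressureThreshold {
        return Int(Double(physicalRAM) * 1.1)          // tight cap, e.g. ~17.6 GB on 16 GB
    }
    return sentinel
}

// Example: the reporter's scenario (20.4 GB main + 3.0 GB draft).
let physicalRAM = Int(ProcessInfo.processInfo.physicalMemory)
let limit = selectMemoryLimit(mainBytes: Int(20.4e9), draftBytes: Int(3.0e9),
                              physicalRAM: physicalRAM)
// On a 16 GB machine this returns the tight cap; on a 64 GB machine the sentinel is kept.
```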
…-draft-model (#72)

Git history audit (mlx-swift-lm):
  e6ba580 - 8.5x speedup (0.58→4.95 tok/s) from cross-projection batching (Eric Lake, M1 Ultra)
  2c71c6c - ssd-opt-v2: +4% more via persistent expert buffers (asyncEval warm path)
  2b1c653 - PAPPS N+1 prefetch permanently disabled (hurt Apple-native TPS)

README (line 245) explicitly states:
  'Speculative decoding is counterproductive for SSD-streaming MoE specifically.
   The verify pass sends N+1 tokens, each routing to *different* experts — SSD I/O
   scales with the *union* of all positions' expert selections.'

Strategy (not a hard error):
When --stream-experts + --draft-model are combined:
  - Auto-cap --num-draft-tokens to 1 (verify pass = 2 positions, not N+1)
  - At 1 draft token: fan-out is 2× SSD I/O (vs 5× at default 4 tokens)
  - If acceptance rate ≥ 50% (typical for same-family models), net TPS is positive
  - Print a clear advisory so users understand the tradeoff
  - Persistent expert buffers (~5 GB warm path, ssd-opt-v2) are PRESERVED —
    no regression to Eric Lake's M1 Ultra benchmark

What is NOT changed:
  - SwitchLayers.swift warm path: untouched (idx.size <= 32 guard intact)
  - ExpertStreamingConfig: no new flags added (reverted failed hasDraftModel attempt)
  - computeSSDMemoryBudget() + cacheLimit logic from load-time fix: intact
  - Tight memoryLimit sentinel (physicalRAM × 1.1) when combined > 70% RAM: intact

Test coverage (18 tests, 0 failures):
  SSDDraftStrategyTests (10 new):
    - Fan-out arithmetic: 4 draft tokens → 5× I/O, 1 token → 2× I/O
    - Auto-cap fires only when streamExperts + draftModel + numDraftTokens > 1
    - Auto-cap does NOT fire for solo SSD streaming or pure RAM speculative decoding
    - Net throughput model: 70% acceptance at 2× fan-out is net positive
    - memoryLimit sentinel selection: tight cap on 16 GB, sentinel on 64 GB
  SSDMemoryBudgetTests (8 existing): all pass, no regressions
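
The fan-out and break-even arithmetic exercised by these tests can be summarized in a few lines. This is a sketch of the simplified model described above, not the contents of the test file:

```swift
/// SSD I/O fan-out for N draft tokens: the verify pass covers N + 1 positions,
/// and each position can route to different experts, so I/O scales with N + 1.
func verifyFanOut(numDraftTokens: Int) -> Int {
    return numDraftTokens + 1
}

assert(verifyFanOut(numDraftTokens: 4) == 5)   // default: 5x SSD I/O
assert(verifyFanOut(numDraftTokens: 1) == 2)   // auto-capped: 2x SSD I/O

// Simplified break-even: acceptance must exceed 1 / fan-out, i.e. 50% at the cap,
// so the typical 70% same-family acceptance rate lands on the net-positive side.
let breakEvenAcceptance = 1.0 / Double(verifyFanOut(numDraftTokens: 1))
assert(breakEvenAcceptance == 0.5)
```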
…sion

Three-check E2E test for the --stream-experts + --draft-model fix:

  [1/3] Auto-cap guard: verifies server log contains the 'auto-capping'
        warning, proving numDraftTokens was reduced from 4 to 1 at startup

  [2/3] RAM guard: measures vm_stat peak RAM during inference and fails
        if it exceeds 80% of physical RAM (the indicator that exposed the
        original swap explosion on reporter's 16GB M4 Mini)

  [3/3] Inference: verifies the combination still produces valid content
        (not crashed/empty), proving functional correctness

Uses small models (Qwen3.5-4B main + Qwen3.5-0.8B draft) — same
parameter-class proportions as the reporter's 35B+4B scenario but
runnable on any machine without 35B weights.

Run: ./run_benchmark.sh → option 10
New mandatory CI job: ssd-draft-memory-guard
  - Runs on every PR, needs: build_and_unit_test
  - Models: Qwen3.5-2B (main, SSD-streamed) + Qwen3.5-0.8B (draft)
    sized for the 7 GB macos-15 runner
  - Passes --num-draft-tokens 4 intentionally so the auto-cap fires

Three enforced checks:
  [1] grep 'auto-capping' in server log — proves guard fires, fails PR if absent
  [2] vm_stat peak RAM ≤ 85% of runner RAM during inference — fails PR if exceeded
  [3] /v1/chat/completions returns content — ensures combination stays functional

Every step writes vm_stat before/loaded/peak to GITHUB_STEP_SUMMARY as a
markdown table so memory readings are visible on every PR without digging logs.

Also upgrades speculative-decoding-eval (continue-on-error: true) to emit
vm_stat before/after readings to its step summary as telemetry.
…sue #72)

Three targeted README updates:

1. SSD Expert Streaming 'Important finding' callout (line 245):
   - Changed from blanket 'counterproductive / excluded' statement to
     explain the fan-out problem (5x I/O at default 4 draft tokens) and
     document the auto-cap-to-1 mitigation (2x I/O, net positive at >=50% acceptance)

2. Usage code block (line 274):
   - Added a '--stream-experts + --draft-model' example showing that
     num-draft-tokens is auto-capped to 1 at startup

3. CLI options table (line 407):
   - Updated --draft-model and --num-draft-tokens rows to mention the
     auto-cap behavior when combined with --stream-experts
… comment)

The multi_replace_file_content tool previously emitted a stray line
'eculative-eval.log', which was removed with sed, but the cleanup left 'retention'
(without '-days: 7') merged with an inline comment on line 219.
This caused GitHub Actions to reject the workflow file entirely with:
  'yaml: while scanning a simple key at line 219'

Fix: restore 'retention-days: 7' as a proper YAML key-value pair.
Copilot AI review requested due to automatic review settings April 23, 2026 17:48
Contributor

Copilot AI left a comment


Pull request overview

Implements Issue #72 mitigations for the problematic combination of SSD expert streaming (MoE) with speculative decoding, aiming to prevent RAM thrash/swap and preserve SSD-streaming performance.

Changes:

  • Auto-caps --num-draft-tokens to 1 when --stream-experts and --draft-model are used together (plus startup advisory).
  • Introduces a context-aware Memory.memoryLimit selection under SSD streaming when a draft model is present and combined weights are high vs physical RAM.
  • Adds regression coverage via a new unit test file, a new local benchmark test, updated README guidance, and a new mandatory CI guard job.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 9 comments.

File summary:
  • Sources/SwiftLM/Server.swift: Adds auto-cap strategy + tight memoryLimit selection logic for SSD streaming + draft model.
  • tests/SwiftLMTests/SSDPersistentBufferGuardTests.swift: Adds regression-style tests for fan-out arithmetic and memory-limit sentinel selection behavior.
  • run_benchmark.sh: Adds an interactive “Test 10” to exercise the SSD + draft scenario and RAM guard checks locally.
  • .github/workflows/ci.yml: Extends speculative eval with RAM snapshots and adds a mandatory ssd-draft-memory-guard job.
  • README.md: Documents the SSD-streaming + draft-model interaction and the auto-cap behavior.


Comment on lines +44 to +61
/// Net throughput is positive when: acceptance_rate × draft_tps > fan_out_penalty × base_tps
/// At 50% acceptance and 2× fan-out this is just barely net-neutral.
/// At 70% acceptance (typical for family models) it's clearly positive.
func testNetThroughput_CappedDraft_PositiveAt70PctAcceptance() {
let baseTPS = 5.0 // tok/s for SSD streaming alone
let draftTPS = 73.0 // tok/s for a 4B draft model in RAM
let fanOutPenalty = 2.0 // 2× I/O at 1 draft token
let acceptRate = 0.70 // typical for same-family models

// Net effective TPS with draft (simplified model):
// Each round: draft generates 1 token fast, main verifies 2 positions.
// If accepted: 1 extra token at draft speed per round.
// Cost: main model verify at base_tps / fan_out_penalty.
let effectiveVerifyTPS = baseTPS / fanOutPenalty
let netTPS = effectiveVerifyTPS + acceptRate * (draftTPS / draftTPS)

XCTAssertGreaterThan(netTPS, effectiveVerifyTPS,
"At 70% acceptance + 1 draft token, net TPS must exceed un-assisted verify TPS")

Copilot AI Apr 23, 2026


netTPS is computed as effectiveVerifyTPS + acceptRate * (draftTPS / draftTPS), but draftTPS / draftTPS is always 1.0, so this test effectively asserts only that acceptRate > 0 and doesn’t validate the intended throughput tradeoff. Adjust the simplified model so draftTPS actually influences the outcome, or reframe the test to assert something directly derivable from the auto-cap logic (e.g., verify-position count / fan-out).

Suggested change
/// Net throughput is positive when: acceptance_rate × draft_tps > fan_out_penalty × base_tps
/// At 50% acceptance and 2× fan-out this is just barely net-neutral.
/// At 70% acceptance (typical for family models) it's clearly positive.
func testNetThroughput_CappedDraft_PositiveAt70PctAcceptance() {
let baseTPS = 5.0 // tok/s for SSD streaming alone
let draftTPS = 73.0 // tok/s for a 4B draft model in RAM
let fanOutPenalty = 2.0 // 2× I/O at 1 draft token
let acceptRate = 0.70 // typical for same-family models
// Net effective TPS with draft (simplified model):
// Each round: draft generates 1 token fast, main verifies 2 positions.
// If accepted: 1 extra token at draft speed per round.
// Cost: main model verify at base_tps / fan_out_penalty.
let effectiveVerifyTPS = baseTPS / fanOutPenalty
let netTPS = effectiveVerifyTPS + acceptRate * (draftTPS / draftTPS)
XCTAssertGreaterThan(netTPS, effectiveVerifyTPS,
"At 70% acceptance + 1 draft token, net TPS must exceed un-assisted verify TPS")
/// With 1 draft token, the verify pass covers 2 positions, so SSD I/O fan-out is 2×.
/// In this simplified model, break-even acceptance is therefore 1 / fan_out = 50%.
/// At 70% acceptance (typical for same-family models), the capped strategy is on the
/// positive side of that threshold.
func testNetThroughput_CappedDraft_PositiveAt70PctAcceptance() {
let fanOutPenalty = 2.0 // 2× I/O at 1 draft token
let acceptRate = 0.70 // typical for same-family models
// Reframe the assertion around the auto-cap arithmetic directly:
// break-even acceptance_rate = 1 / verify_positions = 1 / fanOutPenalty.
let breakEvenAcceptanceRate = 1.0 / fanOutPenalty
XCTAssertEqual(breakEvenAcceptanceRate, 0.50, accuracy: 0.000_001,
"At 1 draft token, 2 verify positions imply a 50% break-even acceptance threshold")
XCTAssertGreaterThan(acceptRate, breakEvenAcceptanceRate,
"At 70% acceptance + 1 draft token, acceptance is above the capped 2-position break-even threshold")

Comment on lines +127 to +137
let physicalRAM = Int(16.0 * Double(gb))
let mainBytes = Int(20.4 * 1e9)
let draftBytes = Int(3.0 * 1e9)
let combined = mainBytes + draftBytes
let threshold = Int(Double(physicalRAM) * 0.70) // 11.2 GB

XCTAssertGreaterThan(combined, threshold,
"Reporter scenario: 23.4 GB combined must exceed 70% of 16 GB physical RAM")

let tightCap = Int(Double(physicalRAM) * 1.1) // ~17.6 GB
let sentinel = 200 * gb

Copilot AI Apr 23, 2026


This test mixes GiB (gb = 1_073_741_824) for physicalRAM with decimal GB (1e9) for model footprints, but the inline comments assume all values are in “GB” (e.g., // 11.2 GB). This makes the comments misleading and can hide unit mistakes. Use a single unit system consistently in the test (either all GiB or all decimal GB) and update the comments accordingly.

Comment thread .github/workflows/ci.yml Outdated
-H "Content-Type: application/json" \
-d '{"model":"test","messages":[{"role":"user","content":"What is 2+2? One word."}],"max_tokens":32,"stream":false}' \
2>/dev/null || echo "{}")
echo "inf_result=$RESULT" >> $GITHUB_OUTPUT

Copilot AI Apr 23, 2026


echo "inf_result=$RESULT" >> $GITHUB_OUTPUT is unsafe for GitHub Actions outputs because the JSON response can contain newlines and other characters that will corrupt the output file format, causing later steps to misread outputs. Store the response in a file, or use the multiline output syntax (inf_result<<EOF ... EOF) / base64-encode it before writing to $GITHUB_OUTPUT.

Suggested change
echo "inf_result=$RESULT" >> $GITHUB_OUTPUT
{
echo "inf_result<<EOF"
echo "$RESULT"
echo "EOF"
} >> "$GITHUB_OUTPUT"

Comment thread .github/workflows/ci.yml Outdated
Comment on lines +510 to +516
RESULT='${{ steps.ram_peak.outputs.inf_result }}'
if echo "$RESULT" | grep -q '"content"'; then
TEXT=$(echo "$RESULT" | python3 -c \
"import sys,json;d=json.load(sys.stdin);print(d['choices'][0]['message']['content'])" \
2>/dev/null || echo "(parse error)")
echo "✅ Response: $TEXT"
else

Copilot AI Apr 23, 2026


RESULT='${{ steps.ram_peak.outputs.inf_result }}' will break if the model response contains single quotes and can also be truncated/corrupted if the output wasn't encoded safely. Prefer reading the inference JSON from a file produced by the previous step, or base64-decode an encoded output, rather than injecting raw JSON into a single-quoted shell string.

Comment thread .github/workflows/ci.yml Outdated
Comment on lines +414 to +418
RAM=$(vm_stat | awk '
/Pages active:/ { v=$3; gsub(/\./, "", v); act=v+0 }
/Pages wired down:/ { v=$4; gsub(/\./, "", v); wire=v+0 }
/Pages occupied by compressor:/ { v=$5; gsub(/\./, "", v); comp=v+0 }
END { printf "%.2f", (act+wire+comp)*16384/1073741824 }

Copilot AI Apr 23, 2026


The vm_stat parsing hard-codes a 16,384-byte page size. That’s true for many Apple Silicon systems today, but it’s not guaranteed across macOS hardware/runner types (and vm_stat itself prints the actual page size). Consider reading the page size from sysctl -n hw.pagesize or parsing it from the first vm_stat line to keep the RAM calculations accurate.

Suggested change
RAM=$(vm_stat | awk '
/Pages active:/ { v=$3; gsub(/\./, "", v); act=v+0 }
/Pages wired down:/ { v=$4; gsub(/\./, "", v); wire=v+0 }
/Pages occupied by compressor:/ { v=$5; gsub(/\./, "", v); comp=v+0 }
END { printf "%.2f", (act+wire+comp)*16384/1073741824 }
PAGE_SIZE=$(sysctl -n hw.pagesize)
RAM=$(vm_stat | awk -v page_size="$PAGE_SIZE" '
/Pages active:/ { v=$3; gsub(/\./, "", v); act=v+0 }
/Pages wired down:/ { v=$4; gsub(/\./, "", v); wire=v+0 }
/Pages occupied by compressor:/ { v=$5; gsub(/\./, "", v); comp=v+0 }
END { printf "%.2f", (act+wire+comp)*page_size/1073741824 }

Comment on lines +389 to +410
if combinedExceedsRAM && draftFootprintBytes > 0 {
// Combined model weights exceed 70% of physical RAM.
// Speculative decoding causes both models' pages to be demanded
// simultaneously during draft+verify cycles, which will thrash
// the SSD page cache and trigger heavy swap.
// Use a tight memoryLimit so MLX evicts pages rather than swapping.
let tightLimit = Int(Double(physicalRAM) * 1.1)
Memory.memoryLimit = tightLimit
print("[SwiftLM] ⚠️ SSD + draft-model RAM pressure warning:")
print("[SwiftLM] Main model: \(String(format: "%.1f", Double(mainFootprintBytes) / 1e9))GB Draft: \(String(format: "%.1f", Double(draftFootprintBytes) / 1e9))GB Combined: \(String(format: "%.1f", Double(combinedFootprint) / 1e9))GB Physical RAM: \(String(format: "%.1f", Double(physicalRAM) / 1e9))GB")
print("[SwiftLM] Speculative decoding alternates both models' forward passes.")
print("[SwiftLM] On this machine the combined weight exceeds physical RAM,")
print("[SwiftLM] causing page-cache thrashing and swap during inference.")
print("[SwiftLM] → Recommendation: remove --draft-model on this machine,")
print("[SwiftLM] or use a smaller draft model whose weights fit in")
print("[SwiftLM] remaining RAM after the main model's page budget (\(Memory.cacheLimit / (1024*1024*1024))GB).")
print("[SwiftLM] Memory limit set to \(tightLimit / (1024*1024*1024))GB (tight cap for MLX eviction pressure)")
} else {
// No draft model, or combined fits in RAM — use the standard sentinel
// to bypass MLX eval_impl's spin-wait loop safely.
Memory.memoryLimit = 200 * 1024 * 1024 * 1024 // 200 GB sentinel
}

Copilot AI Apr 23, 2026


Memory.memoryLimit is set to either tightLimit or the 200GB sentinel here, but later in run() the partition strategy switch unconditionally sets Memory.memoryLimit = 200GB for streamExperts in the .swapAssisted / .layerPartitioned branches. That will override this tight-cap logic and make the Issue #72 protection ineffective. Consider applying the conditional memoryLimit after the strategy switch, or gating the later assignment so it doesn’t overwrite a tighter limit that was already selected.

Comment thread run_benchmark.sh Outdated
Comment on lines +1175 to +1181
# Measure RAM via vm_stat (Apple Silicon page size = 16384 bytes)
get_ram_gb_t10() {
vm_stat | awk '
/Pages active:/ { v=$3; gsub(/\./, "", v); act=v+0 }
/Pages wired down:/ { v=$4; gsub(/\./, "", v); wire=v+0 }
/Pages occupied by compressor:/ { v=$5; gsub(/\./, "", v); comp=v+0 }
END { printf "%.2f", (act+wire+comp)*16384/1073741824 }

Copilot AI Apr 23, 2026


get_ram_gb_t10() hard-codes a 16,384-byte vm_stat page size. That will misreport RAM on systems with a different page size (e.g., some Intel macs). Prefer sysctl -n hw.pagesize (or parse the vm_stat header) and multiply by that value so the regression guard is accurate across machines.

Suggested change
# Measure RAM via vm_stat (Apple Silicon page size = 16384 bytes)
get_ram_gb_t10() {
vm_stat | awk '
/Pages active:/ { v=$3; gsub(/\./, "", v); act=v+0 }
/Pages wired down:/ { v=$4; gsub(/\./, "", v); wire=v+0 }
/Pages occupied by compressor:/ { v=$5; gsub(/\./, "", v); comp=v+0 }
END { printf "%.2f", (act+wire+comp)*16384/1073741824 }
# Measure RAM via vm_stat using the system page size
get_ram_gb_t10() {
local vm_output
local page_size
vm_output=$(vm_stat) || return 1
page_size=$(sysctl -n hw.pagesize 2>/dev/null | tr -d '[:space:]')
if [ -z "$page_size" ]; then
page_size=$(
printf '%s\n' "$vm_output" \
| sed -n '1s/.*page size of \([0-9][0-9]*\) bytes.*/\1/p'
)
fi
[ -n "$page_size" ] || return 1
printf '%s\n' "$vm_output" | awk -v page_size="$page_size" '
/Pages active:/ { v=$3; gsub(/\./, "", v); act=v+0 }
/Pages wired down:/ { v=$4; gsub(/\./, "", v); wire=v+0 }
/Pages occupied by compressor:/ { v=$5; gsub(/\./, "", v); comp=v+0 }
END { printf "%.2f", (act+wire+comp)*page_size/1073741824 }

Comment thread .github/workflows/ci.yml
Comment on lines +432 to +448
.build/release/SwiftLM \
--model mlx-community/Qwen3.5-2B-4bit \
--draft-model mlx-community/Qwen3.5-0.8B-MLX-4bit \
--stream-experts \
--num-draft-tokens 4 \
--port 15473 \
--max-tokens 64 \
> /tmp/ssd_draft_guard.log 2>&1 &
echo "server_pid=$!" >> $GITHUB_OUTPUT

echo "Waiting for server (up to 300s)..."
for i in $(seq 1 300); do
if ! kill -0 ${{ steps.server.outputs.server_pid }} 2>/dev/null; then
echo "Server died early:"
cat /tmp/ssd_draft_guard.log
exit 1
fi

Copilot AI Apr 23, 2026


In this step, the server PID is written to $GITHUB_OUTPUT, but then the loop immediately references ${{ steps.server.outputs.server_pid }} within the same step. Step outputs are not available until the step completes, so this will expand to empty and the kill -0 check will be unreliable. Capture $! into a shell variable (e.g., PID) and use that inside the loop, while still emitting it as an output for later steps.

Comment thread Sources/SwiftLM/Server.swift Outdated
Memory.memoryLimit = 200 * 1024 * 1024 * 1024 // 200 GB sentinel

// Determine safe memoryLimit sentinel
let mainFootprintBytes = ModelProfiler.profile(modelDirectory: modelDir, modelId: modelId)?.weightFileSizeBytes ?? 0

Copilot AI Apr 23, 2026


This block re-profiles the main model (ModelProfiler.profile(...)) to get weightFileSizeBytes, but the main model is profiled again a few lines later to build the PartitionPlan. ModelProfiler.profile appears to walk the model directory/weights, so doing it twice adds unnecessary startup overhead. Consider computing the main profile once and reusing it for both the footprint calculation and the plan generation.

Suggested change
let mainFootprintBytes = ModelProfiler.profile(modelDirectory: modelDir, modelId: modelId)?.weightFileSizeBytes ?? 0
let mainModelProfile = ModelProfiler.profile(modelDirectory: modelDir, modelId: modelId)
let mainFootprintBytes = mainModelProfile?.weightFileSizeBytes ?? 0

GitHub Actions output contexts (${{ steps.X.outputs.Y }}) are not populated
until the step finishes. Referencing one inside the same step passed an
empty string to 'kill -0', causing the health check to instantly abort the
test runner. Switched to capturing the PID with standard bash '$!' into a
shell variable.
- Fix Server.swift memory limit being unconditionally overridden later in execution
- Consolidate ModelProfiler.profile calls to reduce startup latency
- Replace hardcoded 16384 page sizes with dynamic sysctl hw.pagesize in CI and benchmark scripts
- Ensure CI multiline JSON inference output is correctly piped to files instead of GITHUB_OUTPUT
- Refine unit tests to assert fan-out break even limits properly and standardize to GiB
@solderzzc
Member Author

Addressed all 🔴/🟡 Copilot review comments in commit 7b0bfd4:

  • memoryLimit override: removed redundant sentinels that overwrote the tight-cap logic.
  • ModelProfiler calls: elevated mainModelProfile to prevent scanning weights twice on startup.
  • JSON output passing: refactored CI to save curl outputs to a temp file (/tmp/inf_result.json) to eliminate bash multiline output corruption.
  • vm_stat dynamic sizing: replaced hardcoded 16384 pages with sysctl -n hw.pagesize across CI and benchmark scripts.
  • Unit testing constraints: standardized memory bounds to exact GiB sizes and fixed throughput assertions to mirror the N+1 SSD fanout math.

All CI checks, including ssd-draft-memory-guard, are now green.

@solderzzc solderzzc merged commit b33801a into main Apr 23, 2026
10 checks passed
@solderzzc solderzzc deleted the fix/issue-72-draft-model-ssd-ram branch April 23, 2026 20:07

Development

Successfully merging this pull request may close these issues.

Streams experts not working with draft model?

2 participants