Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions Ix/Aiur/Semantics/BytecodeFfi.lean
Original file line number Diff line number Diff line change
Expand Up @@ -50,28 +50,38 @@ instance : BEq IOBuffer where
-- via `Std.HashMap.beq_iff_equiv` + `Std.HashMap.Equiv.{refl,symm,trans}`,
-- bypassing the need for `LawfulBEq` on the outer `IOBuffer`.

/-- Per-circuit query counts for one circuit (one per function circuit, then
one per memory size). `uniqueRows` is the trace height; `totalHits` is the sum
of query multiplicities. The difference `totalHits - uniqueRows` is the number
of cache hits. -/
structure QueryCount where
/-- Number of distinct queries recorded with a nonzero multiplicity; used as
the circuit's trace height. -/
uniqueRows : Nat
/-- Sum of the multiplicities of those queries. Each counted query contributes
at least 1, so this is never smaller than `uniqueRows`. -/
totalHits : Nat
deriving Inhabited

namespace Bytecode.Toplevel

@[extern "rs_aiur_toplevel_execute"]
private opaque execute' : @& Bytecode.Toplevel →
@& Bytecode.FunIdx → @& Array G → (ioData : @& Array G) →
(ioMap : @& Array (Array G × IOKeyInfo)) →
Except String (Array G × (Array G × Array (Array G × IOKeyInfo)) × Array Nat)
Except String (Array G × (Array G × Array (Array G × IOKeyInfo)) × Array (Nat × Nat))

/-- Executes the bytecode function `funIdx` with the given `args` and `ioBuffer`,
returning the raw output of the function, the updated `IOBuffer`, and an array
of query counts (one per function circuit, then one per memory size). Returns
`Except.error msg` when execution fails (e.g. `assert_eq!` mismatch from a
typechecker rejecting a constant), so callers can recover instead of crashing. -/
of per-circuit `QueryCount`s. Returns `Except.error msg` when execution
fails (e.g. `assert_eq!` mismatch from a typechecker rejecting a constant), so
callers can recover instead of crashing. -/
def execute (toplevel : @& Bytecode.Toplevel)
(funIdx : @& Bytecode.FunIdx) (args : @& Array G) (ioBuffer : IOBuffer) :
Except String (Array G × IOBuffer × Array Nat) :=
Except String (Array G × IOBuffer × Array QueryCount) :=
let ioData := ioBuffer.data
let ioMap := ioBuffer.map
match execute' toplevel funIdx args ioData ioMap.toArray with
| .error e => .error e
| .ok (output, (ioData, ioMap), queryCounts) =>
let ioMap := ioMap.foldl (fun acc (k, v) => acc.insert k v) ∅
let queryCounts := queryCounts.map fun (uniqueRows, totalHits) => { uniqueRows, totalHits }
.ok (output, ⟨ioData, ioMap⟩, queryCounts)

end Bytecode.Toplevel
Expand Down
39 changes: 28 additions & 11 deletions Ix/Aiur/Statistics.lean
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
module
public import Ix.Aiur.Compiler
public import Ix.Aiur.Semantics.BytecodeFfi

/-!
Circuit statistics for Aiur executions.

Given a `CompiledToplevel` and the query counts returned by `execute`, computes
per-circuit width, height (the query count), and the FFT cost
Given a `CompiledToplevel` and the per-circuit `QueryCount`s
returned by `execute`, computes per-circuit width, height (the unique-row
count), cache hits (sum of multiplicities minus height), and the FFT cost
(width × height × log2(height)) for every constrained function and memory
circuit. The FFT cost is a Float to capture small changes continuously.
Results are sorted by FFT cost in decreasing order and printed with cumulative
Expand All @@ -20,11 +22,14 @@ structure CircuitStats where
name : String
width : Nat
height : Nat
cacheHits : Nat
fftCost : Float

/-- Aggregate statistics for one execution, as assembled by `computeStats`. -/
structure ExecutionStats where
/-- Per-circuit statistics (function circuits and memory circuits), sorted by
FFT cost in decreasing order. -/
circuits : Array CircuitStats
/-- Sum of `fftCost` over all entries of `circuits`. -/
totalFftCost : Float
/-- Sum over all circuits of `fftCost width (height + cacheHits)`, i.e. the FFT
cost recomputed with every cache hit counted as an additional row. Used as the
baseline for the "saved cost" percentage. -/
totalUncachedFftCost : Float
/-- Sum of `cacheHits` over all entries of `circuits`. -/
totalCacheHits : Nat

-- Clamp to at least 2 so that log2 is at least 1, avoiding zero cost for h = 1
def fftCost (w h : Nat) : Float :=
Expand All @@ -34,7 +39,7 @@ def fftCost (w h : Nat) : Float :=
let hf := h.toFloat
wf * hf * (max hf 2.0).log2

def computeStats (compiled : CompiledToplevel) (queryCounts : Array Nat) :
def computeStats (compiled : CompiledToplevel) (queryCounts : Array QueryCount) :
ExecutionStats :=
let t := compiled.bytecode
-- Invert nameMap to get FunIdx → String
Expand All @@ -46,18 +51,24 @@ def computeStats (compiled : CompiledToplevel) (queryCounts : Array Nat) :
for i in [:nAllFuns] do
if t.functions[i]!.constrained then
let w := t.functions[i]!.layout.totalWidth
let h := queryCounts[i]!
let qc := queryCounts[i]!
let h := qc.uniqueRows
let hits := qc.totalHits - qc.uniqueRows
let name := reverseMap[i]?.getD s!"<fn {i}>"
acc := acc.push { name, width := w, height := h, fftCost := fftCost w h : CircuitStats }
acc := acc.push { name, width := w, height := h, cacheHits := hits, fftCost := fftCost w h : CircuitStats }
acc
let memoryCircuits := t.memorySizes.mapIdx fun i size =>
let w := size + 11
let h := queryCounts[nAllFuns + i]!
let qc := queryCounts[nAllFuns + i]!
let h := qc.uniqueRows
let hits := qc.totalHits - qc.uniqueRows
{ name := s!"memory[{size}]",
width := w, height := h, fftCost := fftCost w h : CircuitStats }
width := w, height := h, cacheHits := hits, fftCost := fftCost w h : CircuitStats }
let circuits := (functionCircuits ++ memoryCircuits).qsort (·.fftCost > ·.fftCost)
let totalFftCost := circuits.foldl (· + ·.fftCost) 0.0
{ circuits, totalFftCost }
let totalUncachedFftCost := circuits.foldl (fun acc cs => acc + fftCost cs.width (cs.height + cs.cacheHits)) 0.0
let totalCacheHits := circuits.foldl (· + ·.cacheHits) 0
{ circuits, totalFftCost, totalUncachedFftCost, totalCacheHits }

private def padLeft (s : String) (n : Nat) : String :=
let pad := n - s.length
Expand All @@ -81,28 +92,34 @@ def printStats (stats : ExecutionStats) : IO Unit := do
let wName := stats.circuits.foldl (fun m cs => Nat.max m cs.name.length) 4
let wWidth := stats.circuits.foldl (fun m cs => Nat.max m (toString cs.width).length) 5
let wHeight := stats.circuits.foldl (fun m cs => Nat.max m (toString cs.height).length) 6
let wHits := stats.circuits.foldl (fun m cs => Nat.max m (toString cs.cacheHits).length) 5
let formatCost (f : Float) : String :=
let n := f.round.toUInt64.toNat
toString n
let wFftCost := stats.circuits.foldl (fun m cs => Nat.max m (formatCost cs.fftCost).length) 7
let wPct := 7
let wCum := 7
let totalW := wName + 1 + wWidth + 1 + wHeight + 1 + wFftCost + 1 + wPct + 1 + wCum
let totalW := wName + 1 + wWidth + 1 + wHeight + 1 + wHits + 1 + wFftCost + 1 + wPct + 1 + wCum
let totalWidth := stats.circuits.foldl (· + ·.width) 0
let savedPct :=
if stats.totalUncachedFftCost == 0.0 then "0.00%"
else formatPercent (stats.totalUncachedFftCost - stats.totalFftCost) stats.totalUncachedFftCost
let sep := String.ofList (List.replicate totalW '-')
IO.println "=== Circuit Statistics ==="
IO.println s!"Circuits: {stats.circuits.size}"
IO.println s!"Total width: {totalWidth}"
IO.println s!"Total FFT cost: {formatCost stats.totalFftCost}"
IO.println s!"Total cache hits: {stats.totalCacheHits}"
IO.println s!"Total saved cost: {savedPct}"
IO.println sep
IO.println s!"{padRight "Name" wName} {padLeft "Width" wWidth} {padLeft "Height" wHeight} {padLeft "FFT cost" wFftCost} {padLeft "%" wPct} {padLeft "%++" wCum}"
IO.println s!"{padRight "Name" wName} {padLeft "Width" wWidth} {padLeft "Height" wHeight} {padLeft "Hits" wHits} {padLeft "FFT cost" wFftCost} {padLeft "%" wPct} {padLeft "%++" wCum}"
IO.println sep
let mut cumFftCost : Float := 0.0
for cs in stats.circuits do
cumFftCost := cumFftCost + cs.fftCost
let pct := formatPercent cs.fftCost stats.totalFftCost
let cum := formatPercent cumFftCost stats.totalFftCost
IO.println s!"{padRight cs.name wName} {padLeft (toString cs.width) wWidth} {padLeft (toString cs.height) wHeight} {padLeft (formatCost cs.fftCost) wFftCost} {padLeft pct wPct} {padLeft cum wCum}"
IO.println s!"{padRight cs.name wName} {padLeft (toString cs.width) wWidth} {padLeft (toString cs.height) wHeight} {padLeft (toString cs.cacheHits) wHits} {padLeft (formatCost cs.fftCost) wFftCost} {padLeft pct wPct} {padLeft cum wCum}"

end Aiur

Expand Down
44 changes: 30 additions & 14 deletions src/ffi/aiur/protocol.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use multi_stark::{
p3_field::{Field, PrimeField64},
p3_field::PrimeField64,
prover::Proof,
types::{CommitmentParameters, FriParameters},
};
Expand Down Expand Up @@ -87,7 +87,9 @@ extern "C" fn rs_aiur_system_verify(
}

/// `Bytecode.Toplevel.execute`: runs execution only (no proof) and returns
/// `Except String (Array G × (Array G × Array (Array G × IOKeyInfo)) × Array Nat)`.
/// `Except String (Array G × (Array G × Array (Array G × IOKeyInfo)) × Array (Nat × Nat))`.
/// The trailing `Array (Nat × Nat)` is one `(uniqueRows, totalHits)` pair per
/// function circuit followed by one per memory size.
/// On execution failure (e.g. assertion mismatch from a typechecker
/// rejecting a constant), returns `Except.error msg` instead of panicking
/// — letting Lean test runners (`KernelArena.lean`) classify failures.
Expand All @@ -112,31 +114,45 @@ extern "C" fn rs_aiur_toplevel_execute(
Err(err) => return LeanExcept::error_string(&err.to_string()),
};

// Build query counts: one per function, then one per memory size
let mut query_counts: Vec<usize> = Vec::with_capacity(
// Build per-circuit (unique_rows, total_hits) pairs:
// one per function, then one per memory size. `unique_rows` is the trace
// height (number of distinct queries); `total_hits` is the sum of
// multiplicities (how often those rows were hit).
let mut query_counts: Vec<(usize, usize)> = Vec::with_capacity(
query_record.function_queries.len() + toplevel.memory_sizes.len(),
);
let summarize = |q: &crate::aiur::execute::QueryMap| -> (usize, usize) {
let mut rows = 0usize;
let mut hits = 0usize;
for (_, res) in q.iter() {
let m = usize::try_from(res.multiplicity.as_canonical_u64())
.expect("multiplicity exceeds usize");
if m != 0 {
rows += 1;
hits += m;
}
}
(rows, hits)
};
for queries in &query_record.function_queries {
let count =
queries.iter().filter(|(_, res)| !res.multiplicity.is_zero()).count();
query_counts.push(count);
query_counts.push(summarize(queries));
}
for size in &toplevel.memory_sizes {
let count = query_record.memory_queries.get(size).map_or(0, |q| {
q.iter().filter(|(_, res)| !res.multiplicity.is_zero()).count()
});
query_counts.push(count);
let pair = query_record.memory_queries.get(size).map_or((0, 0), summarize);
query_counts.push(pair);
}
let lean_query_counts = {
let arr = LeanArray::alloc(query_counts.len());
for (i, &count) in query_counts.iter().enumerate() {
arr.set(i, LeanOwned::box_usize(count));
for (i, &(rows, hits)) in query_counts.iter().enumerate() {
let pair =
LeanProd::new(LeanOwned::box_usize(rows), LeanOwned::box_usize(hits));
arr.set(i, pair);
}
arr
};

let lean_io = build_lean_io_buffer(&io_buffer);
// (Array G, (Array G × Array (Array G × IOKeyInfo), Array Nat))
// (Array G, (Array G × Array (Array G × IOKeyInfo), Array (Nat × Nat)))
let io_counts = LeanProd::new(lean_io, lean_query_counts);
let result = LeanProd::new(build_g_array(&output), io_counts);
LeanExcept::ok(result)
Expand Down