Skip to content

Commit

Permalink
fix: make program hash computations consistent
Browse files Browse the repository at this point in the history
  • Loading branch information
bobbinth committed Jun 21, 2022
1 parent b1eeb34 commit 21ce731
Show file tree
Hide file tree
Showing 12 changed files with 129 additions and 76 deletions.
2 changes: 1 addition & 1 deletion core/src/decoder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub const ADDR_COL_IDX: usize = 0;
/// Index at which operation bit columns start in the decoder trace.
pub const OP_BITS_OFFSET: usize = ADDR_COL_IDX + 1;

/// Number of columns needed to hold a binary representation of of opcodes.
/// Number of columns needed to hold a binary representation of opcodes.
pub const NUM_OP_BITS: usize = Operation::OP_BITS;

/// Location of operation bits columns in the decoder trace.
Expand Down
5 changes: 4 additions & 1 deletion core/src/program/blocks/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ pub use call_block::Call;
pub use join_block::Join;
pub use loop_block::Loop;
pub use proxy_block::Proxy;
pub use span_block::{OpBatch, Span, BATCH_SIZE as OP_BATCH_SIZE, GROUP_SIZE as OP_GROUP_SIZE};
pub use span_block::{
get_span_op_group_count, OpBatch, Span, BATCH_SIZE as OP_BATCH_SIZE,
GROUP_SIZE as OP_GROUP_SIZE,
};
pub use split_block::Split;

// PROGRAM BLOCK
Expand Down
27 changes: 16 additions & 11 deletions core/src/program/blocks/span_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,24 +338,29 @@ fn batch_ops(ops: Vec<Operation>) -> (Vec<OpBatch>, Digest) {
batches.push(batch);
}

// compute total number of operation groups in all batches. This is done as follows:
// - For all batches but the last one we set the number of groups to 8, regardless of the
// actual number of groups in the batch. The reason for this is that when operation
// batches are concatenated together each batch contributes 8 elements to the hash.
// - For the last batch, we take the number of actual batches and round it up to the next
// power of two. The reason for rounding is that the VM always executes a number of
// operation groups which is a power of two.
let num_batches = batches.len();
let last_batch_num_groups = batches[num_batches - 1].num_groups().next_power_of_two();
let num_op_groups = (num_batches - 1) * BATCH_SIZE + last_batch_num_groups;

// compute the hash of all operation groups
let num_op_groups = get_span_op_group_count(&batches);
let op_groups = &flatten_slice_elements(&batch_groups)[..num_op_groups];
let hash = hasher::hash_elements(op_groups);

(batches, hash)
}

/// Returns the total number of operation groups in a span defined by the provides list of
/// operation batches.
///
/// Then number of operation groups is computed as follows:
/// - For all batches but the last one we set the number of groups to 8, regardless of the
/// actual number of groups in the batch. The reason for this is that when operation
/// batches are concatenated together each batch contributes 8 elements to the hash.
/// - For the last batch, we take the number of actual batches and round it up to the next
/// power of two. The reason for rounding is that the VM always executes a number of
/// operation groups which is a power of two.
pub fn get_span_op_group_count(op_batches: &[OpBatch]) -> usize {
let last_batch_num_groups = op_batches.last().expect("no last group").num_groups();
(op_batches.len() - 1) * BATCH_SIZE + last_batch_num_groups.next_power_of_two()
}

// TESTS
// ================================================================================================

Expand Down
8 changes: 3 additions & 5 deletions core/src/program/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,14 @@ pub use library::Library;
#[derive(Clone, Debug)]
pub struct Script {
root: CodeBlock,
hash: Digest,
}

impl Script {
// CONSTRUCTOR
// --------------------------------------------------------------------------------------------
/// Constructs a new program from the specified code block.
pub fn new(root: CodeBlock) -> Self {
let hash = hasher::merge(&[root.hash(), Digest::default()]);
Self { root, hash }
Self { root }
}

// PUBLIC ACCESSORS
Expand All @@ -46,8 +44,8 @@ impl Script {
}

/// Returns a hash of this script.
pub fn hash(&self) -> &Digest {
&self.hash
pub fn hash(&self) -> Digest {
self.root.hash()
}
}

Expand Down
4 changes: 2 additions & 2 deletions examples/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ pub fn test_example(example: Example, fail: bool) {

if fail {
outputs[0] += 1;
assert!(miden::verify(*program.hash(), &pub_inputs, &outputs, proof).is_err())
assert!(miden::verify(program.hash(), &pub_inputs, &outputs, proof).is_err())
} else {
assert!(miden::verify(*program.hash(), &pub_inputs, &outputs, proof).is_ok());
assert!(miden::verify(program.hash(), &pub_inputs, &outputs, proof).is_ok());
}
}
2 changes: 1 addition & 1 deletion examples/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ fn main() {
// results in the expected output
let proof = StarkProof::from_bytes(&proof_bytes).unwrap();
let now = Instant::now();
match miden::verify(*program.hash(), &pub_inputs, &outputs, proof) {
match miden::verify(program.hash(), &pub_inputs, &outputs, proof) {
Ok(_) => debug!("Execution verified in {} ms", now.elapsed().as_millis()),
Err(err) => debug!("Failed to verify execution: {}", err),
}
Expand Down
4 changes: 2 additions & 2 deletions miden/tests/integration/helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,9 @@ impl Test {

if test_fail {
outputs[0] += 1;
assert!(miden::verify(*script.hash(), &pub_inputs, &outputs, proof).is_err());
assert!(miden::verify(script.hash(), &pub_inputs, &outputs, proof).is_err());
} else {
assert!(miden::verify(*script.hash(), &pub_inputs, &outputs, proof).is_ok());
assert!(miden::verify(script.hash(), &pub_inputs, &outputs, proof).is_ok());
}
}

Expand Down
35 changes: 19 additions & 16 deletions processor/src/decoder/mod.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use super::{
ExecutionError, Felt, Join, Loop, OpBatch, Operation, Process, Span, Split, StarkField, Vec,
Word, MIN_TRACE_LEN, OP_BATCH_SIZE,
Word, MIN_TRACE_LEN, ONE, OP_BATCH_SIZE, ZERO,
};
use vm_core::{
decoder::{
NUM_HASHER_COLUMNS, NUM_OP_BATCH_FLAGS, NUM_OP_BITS, OP_BATCH_1_GROUPS, OP_BATCH_2_GROUPS,
OP_BATCH_4_GROUPS, OP_BATCH_8_GROUPS,
},
ONE, ZERO,
hasher::DIGEST_LEN,
program::blocks::get_span_op_group_count,
};

mod trace;
Expand Down Expand Up @@ -239,6 +240,22 @@ impl Decoder {
}
}

// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------

/// Returns execution trace length for this decoder.
pub fn trace_len(&self) -> usize {
self.trace.trace_len()
}

/// Hash of the program decoded by this decoder.
///
/// Hash of the program is taken from the last row of first 4 registers of the hasher section
/// of the decoder trace (i.e., columns 8 - 12).
pub fn program_hash(&self) -> [Felt; DIGEST_LEN] {
self.trace.program_hash()
}

// CONTROL BLOCKS
// --------------------------------------------------------------------------------------------

Expand Down Expand Up @@ -541,20 +558,6 @@ impl Default for SpanContext {
// HELPER FUNCTIONS
// ================================================================================================

/// Returns the total number of operation groups in sequence of operation batches.
///
/// The number of groups is computed as follows:
/// - For all batches except for the last one we set the number of groups to 8.
/// - For the last batch, we take the number of groups and round it up to the next power of two.
fn get_span_op_group_count(op_batches: &[OpBatch]) -> usize {
let last_batch_num_groups = op_batches
.last()
.expect("no last group")
.num_groups()
.next_power_of_two();
(op_batches.len() - 1) * OP_BATCH_SIZE + last_batch_num_groups
}

/// Removes the specified operation from the op group and returns the resulting op group.
fn remove_opcode_from_group(op_group: Felt, op: Operation) -> Felt {
let opcode = op.op_code().expect("no opcode") as u64;
Expand Down
22 changes: 19 additions & 3 deletions processor/src/decoder/trace.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use super::{
Felt, Operation, Vec, Word, MIN_TRACE_LEN, NUM_HASHER_COLUMNS, NUM_OP_BATCH_FLAGS, NUM_OP_BITS,
ONE, ZERO, OP_BATCH_1_GROUPS, OP_BATCH_2_GROUPS, OP_BATCH_4_GROUPS, OP_BATCH_8_GROUPS
Felt, Operation, StarkField, Vec, Word, DIGEST_LEN, MIN_TRACE_LEN, NUM_HASHER_COLUMNS,
NUM_OP_BATCH_FLAGS, NUM_OP_BITS, ONE, OP_BATCH_1_GROUPS, OP_BATCH_2_GROUPS, OP_BATCH_4_GROUPS,
OP_BATCH_8_GROUPS, OP_BATCH_SIZE, ZERO,
};
use core::ops::Range;
use vm_core::{program::blocks::OP_BATCH_SIZE, utils::new_array_vec, StarkField};
use vm_core::utils::new_array_vec;

// CONSTANTS
// ================================================================================================
Expand Down Expand Up @@ -65,6 +66,15 @@ impl DecoderTrace {
self.addr_trace.len()
}

/// Returns the contents of the first 4 registers of the hasher state at the last row.
pub fn program_hash(&self) -> [Felt; DIGEST_LEN] {
let mut result = [ZERO; DIGEST_LEN];
for (i, element) in result.iter_mut().enumerate() {
*element = self.last_hasher_value(i);
}
result
}

// TRACE MUTATORS
// --------------------------------------------------------------------------------------------

Expand Down Expand Up @@ -418,6 +428,12 @@ impl DecoderTrace {
*self.group_count_trace.last().expect("no group count")
}

/// Returns the last value in the specified hasher column.
fn last_hasher_value(&self, idx: usize) -> Felt {
debug_assert!(idx < NUM_HASHER_COLUMNS, "invalid hasher register index");
*self.hasher_trace[idx].last().expect("no last hasher value")
}

/// Returns a reference to the last value in the helper register at the specified index.
fn last_helper_mut(&mut self, idx: usize) -> &mut Felt {
debug_assert!(idx < USER_OP_HELPERS.len(), "invalid helper register index");
Expand Down
50 changes: 31 additions & 19 deletions processor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use vm_core::{
utils::collections::{BTreeMap, Vec},
AdviceInjector, DebugOptions, Felt, FieldElement, Operation, ProgramInputs, StackTopState,
StarkField, Word, AUX_TRACE_WIDTH, DECODER_TRACE_WIDTH, MIN_STACK_DEPTH, MIN_TRACE_LEN,
NUM_STACK_HELPER_COLS, RANGE_CHECK_TRACE_WIDTH, STACK_TRACE_WIDTH, SYS_TRACE_WIDTH,
NUM_STACK_HELPER_COLS, ONE, RANGE_CHECK_TRACE_WIDTH, STACK_TRACE_WIDTH, SYS_TRACE_WIDTH, ZERO,
};

mod operations;
Expand Down Expand Up @@ -78,15 +78,25 @@ pub struct RangeCheckTrace {
pub fn execute(script: &Script, inputs: &ProgramInputs) -> Result<ExecutionTrace, ExecutionError> {
let mut process = Process::new(inputs.clone());
process.execute_code_block(script.root())?;
// TODO: make sure program hash from script and trace are the same
Ok(ExecutionTrace::new(process, *script.hash()))
let trace = ExecutionTrace::new(process);
assert_eq!(
script.hash(),
trace.program_hash(),
"inconsistent program hash"
);
Ok(trace)
}

/// Returns an iterator that allows callers to step through each execution and inspect
/// vm state information along side.
pub fn execute_iter(script: &Script, inputs: &ProgramInputs) -> VmStateIterator {
let mut process = Process::new_debug(inputs.clone());
let result = process.execute_code_block(script.root());
assert_eq!(
script.hash(),
process.decoder.program_hash().into(),
"inconsistent program hash"
);
VmStateIterator::new(process, result)
}

Expand All @@ -105,6 +115,18 @@ pub struct Process {
}

impl Process {
// CONSTRUCTORS
// --------------------------------------------------------------------------------------------
/// Creates a new process with the provided inputs.
pub fn new(inputs: ProgramInputs) -> Self {
Self::initialize(inputs, false)
}

/// Creates a new process with provided inputs and debug options enabled.
pub fn new_debug(inputs: ProgramInputs) -> Self {
Self::initialize(inputs, true)
}

fn initialize(inputs: ProgramInputs, in_debug_mode: bool) -> Self {
Self {
system: System::new(MIN_TRACE_LEN),
Expand All @@ -118,16 +140,6 @@ impl Process {
}
}

/// Creates a new process with the provided inputs.
pub fn new(inputs: ProgramInputs) -> Self {
Self::initialize(inputs, false)
}

/// Creates a new process with provided inputs and debug options enabled.
pub fn new_debug(inputs: ProgramInputs) -> Self {
Self::initialize(inputs, true)
}

// CODE BLOCK EXECUTORS
// --------------------------------------------------------------------------------------------

Expand Down Expand Up @@ -165,9 +177,9 @@ impl Process {
let condition = self.start_split_block(block)?;

// execute either the true or the false branch of the split block based on the condition
if condition == Felt::ONE {
if condition == ONE {
self.execute_code_block(block.on_true())?;
} else if condition == Felt::ZERO {
} else if condition == ZERO {
self.execute_code_block(block.on_false())?;
} else {
return Err(ExecutionError::NotBinaryValue(condition));
Expand All @@ -183,22 +195,22 @@ impl Process {
let condition = self.start_loop_block(block)?;

// if the top of the stack is ONE, execute the loop body; otherwise skip the loop body
if condition == Felt::ONE {
if condition == ONE {
// execute the loop body at least once
self.execute_code_block(block.body())?;

// keep executing the loop body until the condition on the top of the stack is no
// longer ONE; each iteration of the loop is preceded by executing REPEAT operation
// which drops the condition from the stack
while self.stack.peek() == Felt::ONE {
while self.stack.peek() == ONE {
self.decoder.repeat();
self.execute_op(Operation::Drop)?;
self.execute_code_block(block.body())?;
}

// end the LOOP block and drop the condition from the stack
self.end_loop_block(block, true)
} else if condition == Felt::ZERO {
} else if condition == ZERO {
// end the LOOP block, but don't drop the condition from the stack because it was
// already dropped when we started the LOOP block
self.end_loop_block(block, false)
Expand Down Expand Up @@ -305,7 +317,7 @@ impl Process {
// operation groups. the groups were are processing are just NOOPs - so, the op group
// value is ZERO
if group_idx < num_batch_groups - 1 {
self.decoder.start_op_group(Felt::ZERO);
self.decoder.start_op_group(ZERO);
}
}

Expand Down
Loading

0 comments on commit 21ce731

Please sign in to comment.