From 3dab72d43284ffbec539df5fc9153a9f592d7140 Mon Sep 17 00:00:00 2001 From: jiangxiaobai <15800375054@163.com> Date: Tue, 26 Sep 2023 11:49:15 +0800 Subject: [PATCH 1/2] add blake3 & fix lookup compress --- plonky2/plonky2/benches/blake3_prove.rs | 325 +++++++++++++++ plonky2/plonky2/src/gates/blake3.rs | 520 ++++++++++++++++++++++++ 2 files changed, 845 insertions(+) create mode 100644 plonky2/plonky2/benches/blake3_prove.rs create mode 100644 plonky2/plonky2/src/gates/blake3.rs diff --git a/plonky2/plonky2/benches/blake3_prove.rs b/plonky2/plonky2/benches/blake3_prove.rs new file mode 100644 index 00000000..baa67ab2 --- /dev/null +++ b/plonky2/plonky2/benches/blake3_prove.rs @@ -0,0 +1,325 @@ +#![allow(incomplete_features)] +#![feature(generic_const_exprs)] + +use plonky2_field::extension::Extendable; +use plonky2_field::types::Field; + +use plonky2::gates::gate::Gate; +use plonky2::hash::hash_types::HashOut; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::witness::{PartialWitness, Witness}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::circuit_data::CircuitConfig; +use plonky2::plonk::config::{GenericConfig, Hasher}; +use plonky2::plonk::vars::{EvaluationTargets, EvaluationVars}; +use plonky2::plonk::config::Blake3GoldilocksConfig; +use plonky2::gates::blake3::Blake3Gate; +use plonky2::hash::blake3::STATE_SIZE; +use plonky2::iop::generator::generate_partial_witness; +use plonky2::iop::wire::Wire; +use rand::Rng; +//use plonky2::plonk::verifier::verify; + +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; + +pub fn bench_blake3_prove< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +>( + c: &mut Criterion, +) where + [(); C::Hasher::HASH_SIZE]:, +{ + let mut group = c.benchmark_group("poseidon2 prove"); + group.sample_size(10); + + for i in 0..1 { + + group.bench_with_input(BenchmarkId::from_parameter(i), &i, |b, _| { + b.iter(|| { + + let gate = Blake3Gate::::new(); + + let mut rng = rand::thread_rng(); + let config = CircuitConfig::wide_blake3_config(); + let mut builder = CircuitBuilder::new(config); + let row = builder.add_gate(gate, vec![]); + let circuit = builder.build::(); + // generate inputs + let mut permutation_inputs = [F::ZERO; STATE_SIZE]; + + for i in 0..16{ + + permutation_inputs[i] = F::from_canonical_u32(rng.gen()); + + } + + let mut pw = PartialWitness::::new(); + + for i in 0..16 { + pw.set_wire( + Wire { + row, + column: Blake3Gate::::wire_input(i), + }, + permutation_inputs[i], + ); + } + + let witness = generate_partial_witness::(pw, &circuit.prover_only, &circuit.common); + + // Test that `eval_unfiltered` and `eval_unfiltered_recursively` are coherent. + let mut wires = [F::Extension::ZERO; 696]; + // set input + for i in 0..16 { + + let out = witness.get_wire(Wire { + row: 0, + column: Blake3Gate::::wire_input(i), + }); + + wires[i] = out.into(); + } + // set output + for i in 0..8 { + + let out = witness.get_wire(Wire { + row: 0, + column: Blake3Gate::::wire_output(i), + }); + + wires[16 + i] = out.into(); + } + + // set xor witness + for i in 0..7 { + + for j in 0..8 { + + for k in 0..4 { + + let out = witness.get_wire(Wire { + row: 0, + column: Blake3Gate::::wire_xor_external(i, j, k), + }); + + wires[16 + 8 + i * 32 + j * 4 + k] = out.into(); + + let out1 = witness.get_wire(Wire { + row: 0, + column: Blake3Gate::::wire_shift_remain_external(i, j, k), + }); + + wires[16 + 8 + 224 + i * 32 + j * 4 + k] = out1.into(); + + + let out2 = witness.get_wire(Wire { + row: 0, + column: Blake3Gate::::wire_shift_q_external(i, j, k), + }); + + wires[16 + 8 + 448 + i * 32 + j * 4 + k] = out2.into(); + } + } + } + + let gate = Blake3Gate::::new(); + let constants = F::Extension::rand_vec(gate.num_constants()); + let public_inputs_hash = HashOut::rand(); + + let config = CircuitConfig::standard_recursion_config(); + let mut pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let wires_t = builder.add_virtual_extension_targets(wires.len()); + let constants_t = builder.add_virtual_extension_targets(constants.len()); + pw.set_extension_targets(&wires_t, &wires); + pw.set_extension_targets(&constants_t, &constants); + let public_inputs_hash_t = builder.add_virtual_hash(); + pw.set_hash_target(public_inputs_hash_t, public_inputs_hash); + + let vars = EvaluationVars { + local_constants: &constants, + local_wires: &wires, + public_inputs_hash: &public_inputs_hash, + }; + let evals = gate.eval_unfiltered(vars); + + let vars_t = EvaluationTargets { + local_constants: &constants_t, + local_wires: &wires_t, + public_inputs_hash: &public_inputs_hash_t, + }; + let evals_t = gate.eval_unfiltered_circuit(&mut builder, vars_t); + pw.set_extension_targets(&evals_t, &evals); + + let data = builder.build::(); + //let start = Instant::now(); + let _proof = data.prove(pw); + + //println!("poseidon prover time = {:?}", start.elapsed().as_micros()); + //verify(proof, &data.verifier_only, &data.common) + } + ); + }); + } +} + +pub fn bench_blake3_remove_prove< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +>( + c: &mut Criterion, +) where + [(); C::Hasher::HASH_SIZE]:, +{ + + let mut group = c.benchmark_group("poseidon2 prove"); + group.sample_size(10); + + for i in 0..1 { + + group.bench_with_input(BenchmarkId::from_parameter(i), &i, |b, _| { + b.iter(|| { + + let gate = Blake3Gate::::new(); + + let mut rng = rand::thread_rng(); + let config = CircuitConfig::wide_blake3_config(); + let mut builder = CircuitBuilder::new(config); + let row = builder.add_gate(gate, vec![]); + let circuit = builder.build::(); + // generate inputs + let mut permutation_inputs = [F::ZERO; STATE_SIZE]; + + for i in 0..16{ + + permutation_inputs[i] = F::from_canonical_u32(rng.gen()); + + } + + let mut pw = PartialWitness::::new(); + + for i in 0..16 { + pw.set_wire( + Wire { + row, + column: Blake3Gate::::wire_input(i), + }, + permutation_inputs[i], + ); + } + + let witness = generate_partial_witness::(pw, &circuit.prover_only, &circuit.common); + + // Test that `eval_unfiltered` and `eval_unfiltered_recursively` are coherent. + let mut wires = [F::Extension::ZERO; 696]; + // set input + for i in 0..16 { + + let out = witness.get_wire(Wire { + row: 0, + column: Blake3Gate::::wire_input(i), + }); + + wires[i] = out.into(); + } + // set output + for i in 0..8 { + + let out = witness.get_wire(Wire { + row: 0, + column: Blake3Gate::::wire_output(i), + }); + + wires[16 + i] = out.into(); + } + + // set xor witness + for i in 0..7 { + + for j in 0..8 { + + for k in 0..4 { + + let out = witness.get_wire(Wire { + row: 0, + column: Blake3Gate::::wire_xor_external(i, j, k), + }); + + wires[16 + 8 + i * 32 + j * 4 + k] = out.into(); + + let out1 = witness.get_wire(Wire { + row: 0, + column: Blake3Gate::::wire_shift_remain_external(i, j, k), + }); + + wires[16 + 8 + 224 + i * 32 + j * 4 + k] = out1.into(); + + + let out2 = witness.get_wire(Wire { + row: 0, + column: Blake3Gate::::wire_shift_q_external(i, j, k), + }); + + wires[16 + 8 + 448 + i * 32 + j * 4 + k] = out2.into(); + } + } + } + + let gate = Blake3Gate::::new(); + let constants = F::Extension::rand_vec(gate.num_constants()); + let public_inputs_hash = HashOut::rand(); + + let config = CircuitConfig::standard_recursion_config(); + let mut pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let wires_t = builder.add_virtual_extension_targets(wires.len()); + let constants_t = builder.add_virtual_extension_targets(constants.len()); + pw.set_extension_targets(&wires_t, &wires); + pw.set_extension_targets(&constants_t, &constants); + let public_inputs_hash_t = builder.add_virtual_hash(); + pw.set_hash_target(public_inputs_hash_t, public_inputs_hash); + + let vars = EvaluationVars { + local_constants: &constants, + local_wires: &wires, + public_inputs_hash: &public_inputs_hash, + }; + let evals = gate.eval_unfiltered(vars); + + let vars_t = EvaluationTargets { + local_constants: &constants_t, + local_wires: &wires_t, + public_inputs_hash: &public_inputs_hash_t, + }; + let evals_t = gate.eval_unfiltered_circuit(&mut builder, vars_t); + pw.set_extension_targets(&evals_t, &evals); + + let data = builder.build::(); + //let start = Instant::now(); + //let _proof = data.prove(pw); + + //println!("poseidon prover time = {:?}", start.elapsed().as_micros()); + //verify(proof, &data.verifier_only, &data.common) + }); + } + ); + } +} + +fn criterion_benchmark(c: &mut Criterion) { + + const D: usize = 2; + type C = Blake3GoldilocksConfig; + type F = >::F; +; + bench_blake3_prove::(c); + bench_blake3_remove_prove::(c); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/plonky2/plonky2/src/gates/blake3.rs b/plonky2/plonky2/src/gates/blake3.rs new file mode 100644 index 00000000..d0806ea1 --- /dev/null +++ b/plonky2/plonky2/src/gates/blake3.rs @@ -0,0 +1,520 @@ +use std::marker::PhantomData; +use alloc::sync::Arc; + +use plonky2_field::extension::Extendable; +use plonky2_field::types::Field; + +use crate::gates::gate::Gate; +use crate::gates::util::StridedConstraintConsumer; +use crate::hash::hash_types::RichField; +use crate::hash::blake3::{*}; +use crate::iop::ext_target::ExtensionTarget; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::wire::Wire; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; + +/// Evaluates a full Blake33 permutation with 12 state elements. +#[derive(Debug)] +pub struct Blake3Gate, const D: usize> { + _phantom: PhantomData, +} + +impl, const D: usize> Blake3Gate { + pub fn new() -> Self { + Blake3Gate { + _phantom: PhantomData, + } + } + + /// The wire index for the `i`th input to the permutation. + pub fn wire_input(i: usize) -> usize { + i + } + + /// The wire index for the `i`th output to the permutation. + pub fn wire_output(i: usize) -> usize { + 16 + i + } + + pub fn wire_xor_external(round: usize, g_round: usize, i: usize) -> usize { + 16 + 8 + round * 4 * 8 + g_round * 4 + i + } + + pub fn wire_shift_remain_external(round: usize, g_round: usize, i: usize) -> usize { + 16 + 8 + 4 * 7 * 8 + round * 4 * 8 + g_round * 4 + i + } + + pub fn wire_shift_q_external(round: usize, g_round: usize, i: usize) -> usize { + 16 + 8 + 4 * 7 * 8 * 2 + round * 4 * 8 + g_round * 4 + i + } + + /// End of wire indices, exclusive. + /// 696 column + fn end() -> usize { + 16 + 8 + 8 * 7 * 4 * 3 + } + +} + +impl, const D: usize> Gate for Blake3Gate { + + fn id(&self) -> String { + + format!("{:?}", self, Self::end()) + + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + + let mut constraints = Vec::with_capacity(self.num_constraints()); + + + // Get input of blake3 + let mut block = [F::Extension::ZERO; STATE_SIZE]; + + for i in 0..16{ + + block[i] = vars.local_wires[Self::wire_input(i as usize)]; + + } + + let mut cv = [F::Extension::ZERO; 8]; + + for i in 0..8 { + + cv[i] = F::Extension::from_canonical_u32(::IV[i]); + + } + + ::compress_in_place_field(&mut cv, block, 16, 0, 8); + + for i in 0..8 { + + let output = vars.local_wires[Self::wire_output(i as usize)]; + + constraints.push(output - output); + + } + + constraints + } + + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { + + // Get input of blake3 + let mut block = [F::ZERO; STATE_SIZE]; + + for i in 0..16{ + + block[i] = vars.local_wires[Self::wire_input(i as usize)]; + + } + + let mut cv = [F::ZERO; 8]; + + for i in 0..8 { + + cv[i] = F::from_canonical_u32(::IV[i]); + + } + + ::compress_in_place(&mut cv, block, 16, 0, 8); + + for i in 0..8 { + + let output = vars.local_wires[Self::wire_output(i as usize)]; + + yield_constr.one(output - cv[i]); + + } + + } + + fn eval_unfiltered_circuit( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + + // The naive method is more efficient if we have enough routed wires for + // Need fixed for 65536 + let mut table =Vec::<(u8,u8,u8)>::with_capacity(LOOKUP_LIMB_RANGE * LOOKUP_LIMB_RANGE); + + for i in 0..LOOKUP_LIMB_RANGE { + for j in 0..LOOKUP_LIMB_RANGE { + + let xor = i as u8 ^ j as u8; + + table.push((i as u8, j as u8, xor)); + } + } + + let table_index = builder.add_lookup_table_from_pairs(Arc::new(table)); + + + let mut constraints = Vec::with_capacity(self.num_constraints()); + + // Get input of blake3 + let mut block = [builder.zero_extension(); 16]; + + for i in 0..16{ + + block[i] = vars.local_wires[Self::wire_input(i as usize)]; + + } + + let mut cv = [builder.zero_extension(); 8]; + + for i in 0..8 { + + cv[i] = builder.constant_extension(F::Extension::from_canonical_u32(::IV[i])); + + } + + let mut xor = [[[builder.zero_extension(); 4]; 8]; 7]; + let mut remain = [[[builder.zero_extension(); 4]; 8]; 7]; + let mut q = [[[builder.zero_extension(); 4]; 8]; 7]; + + for i in 0..7 { + for j in 0..8 { + for k in 0..4 { + xor[i][j][k] = vars.local_wires[Self::wire_xor_external(i, j, k)]; + remain[i][j][k] = vars.local_wires[Self::wire_shift_remain_external(i, j, k)]; + q[i][j][k] = vars.local_wires[Self::wire_shift_q_external(i, j, k)]; + } + } + } + + let mut shift_const = [builder.zero_extension(); 4]; + shift_const[0] = builder.constant_extension(F::Extension::from_canonical_u32(1 << 16)); + shift_const[1] = builder.constant_extension(F::Extension::from_canonical_u32(1 << 12)); + shift_const[2] = builder.constant_extension(F::Extension::from_canonical_u32(1 << 8)); + shift_const[3] = builder.constant_extension(F::Extension::from_canonical_u32(1 << 7)); + + let state = ::compress_pre_circuit(builder, &mut cv, xor, remain, q, shift_const, block, 16, 0, 8); + + let mut output = [builder.zero_extension(); 8]; + + for i in 0..8 { + + output[i] = vars.local_wires[Self::wire_output(i as usize)]; + + } + + for i in 0..8 { + + let limbs_input_a = builder.split_le_base::(state[i].to_target_array()[0], LOOKUP_LIMB_NUMBER); + let limbs_input_b = builder.split_le_base::(state[8 + i].to_target_array()[0], LOOKUP_LIMB_NUMBER); + let limbs_input_c = builder.split_le_base::(output[i].to_target_array()[0], LOOKUP_LIMB_NUMBER); + + builder.add_lookup_from_index_bitwise(limbs_input_a[0], limbs_input_b[0], limbs_input_c[0], table_index); + builder.add_lookup_from_index_bitwise(limbs_input_a[1], limbs_input_b[1], limbs_input_c[1], table_index); + builder.add_lookup_from_index_bitwise(limbs_input_a[2], limbs_input_b[2], limbs_input_c[2], table_index); + builder.add_lookup_from_index_bitwise(limbs_input_a[3], limbs_input_b[3], limbs_input_c[3], table_index); + + } + + for i in 0..8 { + + constraints.push(builder.sub_extension(output[i], output[i])); + + } + + constraints + } + + fn generators(&self, row: usize, _local_constants: &[F]) -> Vec>> { + let gen = Blake3Generator:: { + row, + _phantom: PhantomData, + }; + vec![Box::new(gen.adapter())] + } + + fn num_wires(&self) -> usize { + Self::end() + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 1 + } + + fn num_constraints(&self) -> usize { + 8 + } +} + +#[derive(Debug)] +struct Blake3Generator + Blake3, const D: usize> { + row: usize, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for Blake3Generator +{ + fn dependencies(&self) -> Vec { + (0..16) + .map(|i| Blake3Gate::::wire_input(i)) + .map(|column| Target::wire(self.row, column)) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + + let local_wire = |column| Wire { + row: self.row, + column, + }; + + let mut block = [F::ZERO; STATE_SIZE]; + for i in 0 .. STATE_SIZE { + + block[i] = witness.get_wire(local_wire(Blake3Gate::::wire_input(i))); + + } + + let mut cv = [F::ZERO; 8]; + for i in 0..8 { + + cv[i] = F::from_canonical_u32(::IV[i]); + + } + + ::compress_in_place_field_run_once(out_buffer, self.row, &mut cv, block, 16, 0, 8); + + for i in 0..8 { + out_buffer.set_wire(local_wire(Blake3Gate::::wire_output(i)), cv[i]); + } + + } +} + +#[cfg(test)] +mod tests{ + #![allow(incomplete_features)] + + use plonky2_field::types::Field; + use rand::Rng; + + use crate::gates::gate::Gate; + use crate::hash::hash_types::HashOut; + use crate::iop::witness::{PartialWitness, Witness}; + use crate::iop::generator::generate_partial_witness; + use crate::iop::wire::Wire; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::GenericConfig; + use crate::plonk::vars::{EvaluationTargets, EvaluationVars}; + use crate::plonk::config::Blake3GoldilocksConfig; + use crate::plonk::verifier::verify; + use crate::gates::blake3::Blake3Gate; + use crate::hash::blake3::{*}; + + #[test] + pub fn test_blake3_prove() + { + const D: usize = 2; + type C = Blake3GoldilocksConfig; + type F = >::F; + + type Gate = Blake3Gate; + let gate = Gate::new(); + let mut rng = rand::thread_rng(); + let config = CircuitConfig::wide_blake3_config(); + let mut builder = CircuitBuilder::new(config); + + let row = builder.add_gate(gate, vec![]); + let circuit = builder.build::(); + + // generate inputs + let mut permutation_inputs = [F::ZERO; STATE_SIZE]; + + for i in 0..16{ + + permutation_inputs[i] = F::from_canonical_u32(rng.gen()); + + } + + let mut pw = PartialWitness::::new(); + + for i in 0..16 { + pw.set_wire( + Wire { + row, + column: Gate::wire_input(i), + }, + permutation_inputs[i], + ); + } + + let witness = generate_partial_witness::(pw, &circuit.prover_only, &circuit.common); + + // Test that `eval_unfiltered` and `eval_unfiltered_recursively` are coherent. + let mut wires = [<>::F as plonky2_field::extension::Extendable>::Extension::ZERO; 696]; + // set input + for i in 0..16 { + + let out = witness.get_wire(Wire { + row: 0, + column: Gate::wire_input(i), + }); + + wires[i] = out.into(); + } + // set output + for i in 0..8 { + + let out = witness.get_wire(Wire { + row: 0, + column: Gate::wire_output(i), + }); + + wires[16 + i] = out.into(); + } + + // set xor witness + for i in 0..7 { + + for j in 0..8 { + + for k in 0..4 { + + let out = witness.get_wire(Wire { + row: 0, + column: Gate::wire_xor_external(i, j, k), + }); + + wires[16 + 8 + i * 32 + j * 4 + k] = out.into(); + + let out1 = witness.get_wire(Wire { + row: 0, + column: Gate::wire_shift_remain_external(i, j, k), + }); + + wires[16 + 8 + 224 + i * 32 + j * 4 + k] = out1.into(); + + + let out2 = witness.get_wire(Wire { + row: 0, + column: Gate::wire_shift_q_external(i, j, k), + }); + + wires[16 + 8 + 448 + i * 32 + j * 4 + k] = out2.into(); + } + } + } + + let gate = Gate::new(); + let constants = <>::F as plonky2_field::extension::Extendable>::Extension::rand_vec(gate.num_constants()); + let public_inputs_hash = HashOut::rand(); + + let config = CircuitConfig::wide_blake3_config(); + let mut pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let wires_t = builder.add_virtual_extension_targets(wires.len()); + let constants_t = builder.add_virtual_extension_targets(constants.len()); + pw.set_extension_targets(&wires_t, &wires); + pw.set_extension_targets(&constants_t, &constants); + let public_inputs_hash_t = builder.add_virtual_hash(); + pw.set_hash_target(public_inputs_hash_t, public_inputs_hash); + + let vars = EvaluationVars { + local_constants: &constants, + local_wires: &wires, + public_inputs_hash: &public_inputs_hash, + }; + let evals = gate.eval_unfiltered(vars); + + let vars_t = EvaluationTargets { + local_constants: &constants_t, + local_wires: &wires_t, + public_inputs_hash: &public_inputs_hash_t, + }; + let evals_t = gate.eval_unfiltered_circuit(&mut builder, vars_t); + pw.set_extension_targets(&evals_t, &evals); + + let data = builder.build::(); + + let _proof = data.prove(pw); + + let result = verify(_proof.unwrap(), &data.verifier_only, &data.common); + + result.is_ok(); + + } + + #[test] + fn generated_output() { + const D: usize = 2; + type C = Blake3GoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::wide_blake3_config(); + let mut builder = CircuitBuilder::new(config); + let mut rng = rand::thread_rng(); + + type Gate = Blake3Gate; + let gate = Gate::new(); + + let row = builder.add_gate(gate, vec![]); + let circuit = builder.build::(); + + // generate inputs + let mut permutation_inputs = [F::ZERO; STATE_SIZE]; + + for i in 0..16{ + + permutation_inputs[i] = F::from_canonical_u32(rng.gen()); + + } + + let mut pw = PartialWitness::::new(); + + for i in 0..16 { + pw.set_wire( + Wire { + row, + column: Gate::wire_input(i), + }, + permutation_inputs[i], + ); + } + + let witness = generate_partial_witness::(pw, &circuit.prover_only, &circuit.common); + + let mut cv = [F::ZERO; IV_SIZE]; + + for i in 0..8 { + + cv[i] = F::from_canonical_u32(::IV[i]); + + } + + //get expect output + ::compress_in_place(& mut cv, permutation_inputs, 16, 0, 8); + + for i in 0..8 { + let out = witness.get_wire(Wire { + row: 0, + column: Gate::wire_output(i), + }); + assert_eq!(out, cv[i]); + } + + } + + +} From b74dd026ed47ccccc22a3658a2603a36171584a1 Mon Sep 17 00:00:00 2001 From: jiangxiaobai <15800375054@163.com> Date: Tue, 26 Sep 2023 11:50:40 +0800 Subject: [PATCH 2/2] add blake3 & fix lookup compress --- plonky2/plonky2/Cargo.toml | 4 + plonky2/plonky2/benches/poseidon_prove.rs | 14 +- plonky2/plonky2/src/gadgets/lookup.rs | 13 + plonky2/plonky2/src/gates/base_sum.rs | 21 +- plonky2/plonky2/src/gates/lookup.rs | 115 ++++- plonky2/plonky2/src/gates/mod.rs | 1 + plonky2/plonky2/src/hash/blake3.rs | 515 ++++++++++++++++++-- plonky2/plonky2/src/hash/hash_types.rs | 3 +- plonky2/plonky2/src/iop/generator.rs | 2 +- plonky2/plonky2/src/plonk/circuit_data.rs | 7 + plonky2/plonky2/src/plonk/prover.rs | 23 +- plonky2/plonky2/src/plonk/vanishing_poly.rs | 60 ++- 12 files changed, 685 insertions(+), 93 deletions(-) diff --git a/plonky2/plonky2/Cargo.toml b/plonky2/plonky2/Cargo.toml index 41d8a7af..e2ec96ff 100644 --- a/plonky2/plonky2/Cargo.toml +++ b/plonky2/plonky2/Cargo.toml @@ -66,6 +66,10 @@ harness = false name = "poseidon2_prove" harness = false +[[bench]] +name = "blake3_prove" +harness = false + [[bench]] name = "ffts" harness = false diff --git a/plonky2/plonky2/benches/poseidon_prove.rs b/plonky2/plonky2/benches/poseidon_prove.rs index 34d299b1..344bb762 100644 --- a/plonky2/plonky2/benches/poseidon_prove.rs +++ b/plonky2/plonky2/benches/poseidon_prove.rs @@ -68,11 +68,7 @@ pub fn bench_poseidon< let data = builder.build::(); - //let start = Instant::now(); - - let proof = data.prove(pw); - - //println!("poseidon prover time = {:?}", start.elapsed().as_micros()); + let _ = data.prove(pw); }); } @@ -129,13 +125,7 @@ pub fn bench_poseidon_remove_prove< let evals_t = gate.eval_unfiltered_circuit(&mut builder, vars_t); pw.set_extension_targets(&evals_t, &evals); - let data = builder.build::(); - - //let start = Instant::now(); - - //let proof = data.prove(pw); - - //println!("poseidon prover time = {:?}", start.elapsed().as_micros()); + let _ = builder.build::(); }); } diff --git a/plonky2/plonky2/src/gadgets/lookup.rs b/plonky2/plonky2/src/gadgets/lookup.rs index 0b27acca..caf06f9d 100644 --- a/plonky2/plonky2/src/gadgets/lookup.rs +++ b/plonky2/plonky2/src/gadgets/lookup.rs @@ -72,6 +72,19 @@ impl, const D: usize> CircuitBuilder { looking_out } + /// Adds a lookup (input, output) pair to the stored lookups. Takes a `Target` input and returns a `Target` output. + pub fn add_lookup_from_index_bitwise(&mut self, looking_in_0: Target, looking_in_1: Target, looking_out: Target, lut_index: usize) { + assert!( + lut_index < self.get_luts_length(), + "lut number {} not in luts (length = {})", + lut_index, + self.get_luts_length() + ); + //let looking_out = self.add_virtual_target(); + self.update_lookups(looking_in_0, looking_in_1, looking_out, lut_index); + //looking_out + } + /// We call this function at the end of circuit building right before the PI gate to add all `LookupTableGate` and `LookupGate`. /// It also updates `self.lookup_rows` accordingly. pub fn add_all_lookups(&mut self) { diff --git a/plonky2/plonky2/src/gates/base_sum.rs b/plonky2/plonky2/src/gates/base_sum.rs index 5be54eeb..9b4313fe 100644 --- a/plonky2/plonky2/src/gates/base_sum.rs +++ b/plonky2/plonky2/src/gates/base_sum.rs @@ -56,14 +56,16 @@ impl, const D: usize, const B: usize> Gate fo let sum = vars.local_wires[Self::WIRE_SUM]; let limbs = vars.local_wires[self.limbs()].to_vec(); let computed_sum = reduce_with_powers(&limbs, F::Extension::from_canonical_usize(B)); - let mut constraints = vec![computed_sum - sum]; + let constraints = vec![computed_sum - sum]; + // limbs will be used in lookup, so dont need to rangecheck constraints + /*let mut constraints = vec![computed_sum - sum]; for limb in limbs { constraints.push( (0..B) .map(|i| limb - F::Extension::from_canonical_usize(i)) .product(), ); - } + }*/ constraints } @@ -88,7 +90,8 @@ impl, const D: usize, const B: usize> Gate fo let sum = vars.local_wires[Self::WIRE_SUM]; let limbs = vars.local_wires[self.limbs()].to_vec(); let computed_sum = reduce_with_powers_ext_circuit(builder, &limbs, base); - let mut constraints = vec![builder.sub_extension(computed_sum, sum)]; + let constraints = vec![builder.sub_extension(computed_sum, sum)]; + /*let mut constraints = vec![builder.sub_extension(computed_sum, sum)]; for limb in limbs { constraints.push({ let mut acc = builder.one_extension(); @@ -102,7 +105,7 @@ impl, const D: usize, const B: usize> Gate fo }); acc }); - } + }*/ constraints } @@ -125,12 +128,14 @@ impl, const D: usize, const B: usize> Gate fo // Bounded by the range-check (x-0)*(x-1)*...*(x-B+1). fn degree(&self) -> usize { - B + //B + 1 } // 1 for checking the sum then `num_limbs` for range-checking the limbs. fn num_constraints(&self) -> usize { - 1 + self.num_limbs + //1 + self.num_limbs + 1 } } @@ -148,12 +153,12 @@ impl, const D: usize, const B: usize> PackedEvaluab yield_constr.one(computed_sum - sum); - let constraints_iter = limbs.iter().map(|&limb| { + /*let constraints_iter = limbs.iter().map(|&limb| { (0..B) .map(|i| limb - F::from_canonical_usize(i)) .product::

() }); - yield_constr.many(constraints_iter); + yield_constr.many(constraints_iter);*/ } } diff --git a/plonky2/plonky2/src/gates/lookup.rs b/plonky2/plonky2/src/gates/lookup.rs index 54bf1d70..642a1ff3 100644 --- a/plonky2/plonky2/src/gates/lookup.rs +++ b/plonky2/plonky2/src/gates/lookup.rs @@ -13,6 +13,7 @@ use crate::gates::gate::Gate; use crate::gates::packed_util::PackedEvaluableBase; use crate::gates::util::StridedConstraintConsumer; use crate::hash::hash_types::RichField; +use crate::hash::blake3::LOOKUP_LIMB_RANGE; use crate::iop::ext_target::ExtensionTarget; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::iop::target::Target; @@ -327,29 +328,35 @@ impl, const D: usize> SimpleGenerator for Bitwis } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let get_wire = |wire: usize| -> F { witness.get_target(Target::wire(self.row, wire)) }; let input0_val = get_wire(BitwiseLookupGate::wire_ith_looking_inp0(self.slot_nb)); let input1_val = get_wire(BitwiseLookupGate::wire_ith_looking_inp1(self.slot_nb)); - let (input0, input1, output) = self.lut[(input0_val.to_canonical_u64() * 16 + input1_val.to_canonical_u64()) as usize]; + let (input0, input1, output) = self.lut[(input0_val.to_canonical_u64() * LOOKUP_LIMB_RANGE as u64 + input1_val.to_canonical_u64()) as usize]; // find directly if input0_val == F::from_canonical_u8(input0) && input1_val == F::from_canonical_u8(input1){ - let output_val = F::from_canonical_u8(output); + + let output_val = get_wire(BitwiseLookupGate::wire_ith_looking_out(self.slot_nb)); + + assert_eq!(F::from_canonical_u8(output), output_val, "unvalid lookup input, + input is {:?}, {:?}, {:?}, \n expect_output is {:?}, {:?}, {:?} \n, trace info is {:?}, {:?} \n", + input0_val, input1_val, output_val, + input0_val, input1_val, output, + self.row, self.slot_nb); + + return; - let out_wire = Target::wire(self.row, BitwiseLookupGate::wire_ith_looking_out(self.slot_nb)); - out_buffer.set_target(out_wire, output_val); } else { // loop all case for (input0, input1, output) in self.lut.iter() { if input0_val == F::from_canonical_u8(*input0) && input1_val == F::from_canonical_u8(*input1){ - let output_val = F::from_canonical_u8(*output); + let output_val = get_wire(BitwiseLookupGate::wire_ith_looking_out(self.slot_nb)); - let out_wire = - Target::wire(self.row, BitwiseLookupGate::wire_ith_looking_out(self.slot_nb)); - out_buffer.set_target(out_wire, output_val); - return; + assert_eq!(F::from_canonical_u8(*output), output_val, "unvalid lookup input") + } } panic!("Incorrect input value provided"); @@ -362,14 +369,9 @@ mod tests { static LOGGER_INITIALIZED: Once = Once::new(); use alloc::sync::Arc; - use std::ops::Add; use std::sync::Once; - - use itertools::Itertools; use log::{Level, LevelFilter}; - use crate::gadgets::lookup::{OTHER_TABLE, SMALLER_TABLE, TIP5_TABLE}; - use crate::gates::lookup_table::LookupTable; use crate::gates::noop::NoopGate; use crate::plonk::prover::prove; use crate::util::timing::TimingTree; @@ -524,6 +526,91 @@ mod tests { Ok(()) } + #[test] + fn test_one_lookup_bitwise_multi() -> anyhow::Result<()> { + use crate::field::types::Field; + use crate::iop::witness::{PartialWitness, Witness}; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + LOGGER_INITIALIZED.call_once(|| init_logger().unwrap()); + + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + + let mut table =Vec::<(u8,u8,u8)>::with_capacity(256); + + for i in 0..16 { + for j in 0..16 { + + let xor = i as u8 ^ j as u8; + + table.push((i as u8, j as u8, xor)); + } + } + + let table_index = builder.add_lookup_table_from_pairs(Arc::new(table)); + + let looking_a = builder.add_virtual_target(); + let looking_b = builder.add_virtual_target(); + let looking_a_xor_b = builder.add_virtual_target(); + + let looking_val_a = 1; + let looking_val_b = 2; + let looking_val_a_xor_b = looking_val_a ^ looking_val_b;// 0x11 = 3 + + let looking_c = builder.add_virtual_target(); + let looking_d = builder.add_virtual_target(); + let looking_c_xor_d = builder.add_virtual_target(); + + let looking_val_c = 4; + let looking_val_d = 5; + let looking_val_c_xor_d = looking_val_c ^ looking_val_d;// 0x001 = 1 + + + //let output_a_xor_b = builder.add_lookup_from_index(looking_a, looking_b, table_index); + //let output_c_xor_d = builder.add_lookup_from_index(looking_c, looking_d, table_index); + + builder.add_lookup_from_index_bitwise(looking_a, looking_b, looking_a_xor_b, table_index); + builder.add_lookup_from_index_bitwise(looking_c, looking_d, looking_c_xor_d, table_index); + + + builder.register_public_input(looking_a); + builder.register_public_input(looking_b); + builder.register_public_input(looking_a_xor_b); + + //builder.register_public_input(output_a_xor_b); + + builder.register_public_input(looking_c); + builder.register_public_input(looking_d); + builder.register_public_input(looking_c_xor_d); + + //builder.register_public_input(output_c_xor_d); + + let mut pw = PartialWitness::new(); + + pw.set_target(looking_a, F::from_canonical_usize(looking_val_a)); + pw.set_target(looking_b, F::from_canonical_usize(looking_val_b)); + pw.set_target(looking_a_xor_b, F::from_canonical_usize(looking_val_a_xor_b)); + + pw.set_target(looking_c, F::from_canonical_usize(looking_val_c)); + pw.set_target(looking_d, F::from_canonical_usize(looking_val_d)); + pw.set_target(looking_c_xor_d, F::from_canonical_usize(looking_val_c_xor_d)); + + let data = builder.build::(); + let mut timing = TimingTree::new("prove one lookup", Level::Debug); + let proof = prove(&data.prover_only, &data.common, pw, &mut timing)?; + timing.print(); + data.verify(proof.clone())?; + + Ok(()) + } + /* // Tests two lookups in one lookup table. #[test] diff --git a/plonky2/plonky2/src/gates/mod.rs b/plonky2/plonky2/src/gates/mod.rs index 3dc68853..8efab134 100644 --- a/plonky2/plonky2/src/gates/mod.rs +++ b/plonky2/plonky2/src/gates/mod.rs @@ -13,6 +13,7 @@ pub mod low_degree_interpolation; pub mod multiplication_extension; pub mod noop; pub mod packed_util; +pub mod blake3; pub mod poseidon; pub mod poseidon2; pub mod poseidon_mds; diff --git a/plonky2/plonky2/src/hash/blake3.rs b/plonky2/plonky2/src/hash/blake3.rs index 0067aeba..ca6c797a 100644 --- a/plonky2/plonky2/src/hash/blake3.rs +++ b/plonky2/plonky2/src/hash/blake3.rs @@ -5,6 +5,10 @@ use std::mem::size_of; use crate::hash::hash_types::RichField; use crate::hash::hashing::{PlonkyPermutation, SPONGE_WIDTH}; use crate::plonk::config::Hasher; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::iop::ext_target::ExtensionTarget; +use crate::iop::generator::GeneratedValues; +use crate::iop::wire::Wire; use core::slice; use blake3; @@ -12,22 +16,137 @@ use blake3; use super::hash_types::BytesHash; use plonky2_field::types::{Field, PrimeField64}; use plonky2_field::extension::{Extendable, FieldExtension}; -use arrayref::array_ref; pub const ROUND: usize = 7; pub const STATE_SIZE: usize = 16; pub const IV_SIZE: usize = 8; pub const BLOCK_LEN: usize = 64; +pub const LOOKUP_LIMB_RANGE: usize = 16; +pub const LOOKUP_LIMB_NUMBER: usize = 16; pub trait Blake3: PrimeField64 { const MSG_SCHEDULE: [[usize; STATE_SIZE]; ROUND]; const IV: [u32; IV_SIZE]; + #[inline] + fn g( + state: &mut [Self; STATE_SIZE], a: usize, b: usize, c: usize, d: usize, x_field: Self, y_field: Self) { + + + let mut state_tmp = [0u32; STATE_SIZE]; + + for i in 0..STATE_SIZE { + state_tmp[i] = Self::to_noncanonical_u64(&state[i].to_basefield_array()[0]) as u32; + } + + let x = Self::to_noncanonical_u64(&x_field.to_basefield_array()[0]) as u32; + let y = Self::to_noncanonical_u64(&y_field.to_basefield_array()[0]) as u32; + + state_tmp[a] = state_tmp[a].wrapping_add(state_tmp[b]).wrapping_add(x); + state_tmp[d] = (state_tmp[d] ^ state_tmp[a]).rotate_right(16); + state_tmp[c] = state_tmp[c].wrapping_add(state_tmp[d]); + state_tmp[b] = (state_tmp[b] ^ state_tmp[c]).rotate_right(12); + state_tmp[a] = state_tmp[a].wrapping_add(state_tmp[b]).wrapping_add(y); + state_tmp[d] = (state_tmp[d] ^ state_tmp[a]).rotate_right(8); + state_tmp[c] = state_tmp[c].wrapping_add(state_tmp[d]); + state_tmp[b] = (state_tmp[b] ^ state_tmp[c]).rotate_right(7); + + + for i in 0..STATE_SIZE { + state[i] = Self::from_canonical_u32(state_tmp[i]); + } + + } + + #[inline(always)] + fn round( + state: &mut [Self; STATE_SIZE], msg: [Self; STATE_SIZE], round: usize) { + // Select the message schedule based on the round. + let schedule = Self::MSG_SCHEDULE[round]; + + // Mix the columns. + Self::g(state, 0, 4, 8, 12, msg[schedule[0]], msg[schedule[1]]); + Self::g(state, 1, 5, 9, 13, msg[schedule[2]], msg[schedule[3]]); + Self::g(state, 2, 6, 10, 14, msg[schedule[4]], msg[schedule[5]]); + Self::g(state, 3, 7, 11, 15, msg[schedule[6]], msg[schedule[7]]); + + // Mix the diagonals. + Self::g(state, 0, 5, 10, 15, msg[schedule[8]], msg[schedule[9]]); + Self::g(state, 1, 6, 11, 12, msg[schedule[10]], msg[schedule[11]]); + Self::g(state, 2, 7, 8, 13, msg[schedule[12]], msg[schedule[13]]); + Self::g(state, 3, 4, 9, 14, msg[schedule[14]], msg[schedule[15]]); + } + + #[inline(always)] + fn compress_pre( + cv: &mut [Self; IV_SIZE], + block_words: [Self; 16], + block_len: u8, + counter: u64, + flags: u8, + ) -> [Self; 16] { + + let mut state = [ + cv[0], + cv[1], + cv[2], + cv[3], + cv[4], + cv[5], + cv[6], + cv[7], + Self::from_canonical_u32(Self::IV[0]), + Self::from_canonical_u32(Self::IV[1]), + Self::from_canonical_u32(Self::IV[2]), + Self::from_canonical_u32(Self::IV[3]), + Self::from_canonical_u32(counter as u32), + Self::from_canonical_u32((counter >> 32) as u32), + Self::from_canonical_u32(block_len as u32), + Self::from_canonical_u32(flags as u32), + ]; + + Self::round(&mut state, block_words, 0); + Self::round(&mut state, block_words, 1); + Self::round(&mut state, block_words, 2); + Self::round(&mut state, block_words, 3); + Self::round(&mut state, block_words, 4); + Self::round(&mut state, block_words, 5); + Self::round(&mut state, block_words, 6); + + state + } + + + fn compress_in_place( + cv: &mut [Self; IV_SIZE], + block: [Self; STATE_SIZE], + block_len: u8, + counter: u64, + flags: u8, + ) { + let state = Self::compress_pre(cv, block, block_len, counter, flags); + let mut state_tmp = [0u32; STATE_SIZE]; + + for i in 0..STATE_SIZE { + state_tmp[i] = Self::to_noncanonical_u64(&state[i].to_basefield_array()[0]) as u32; + } + + cv[0] = Self::from_canonical_u32(state_tmp[0] ^ state_tmp[8]); + cv[1] = Self::from_canonical_u32(state_tmp[1] ^ state_tmp[9]); + cv[2] = Self::from_canonical_u32(state_tmp[2] ^ state_tmp[10]); + cv[3] = Self::from_canonical_u32(state_tmp[3] ^ state_tmp[11]); + cv[4] = Self::from_canonical_u32(state_tmp[4] ^ state_tmp[12]); + cv[5] = Self::from_canonical_u32(state_tmp[5] ^ state_tmp[13]); + cv[6] = Self::from_canonical_u32(state_tmp[6] ^ state_tmp[14]); + cv[7] = Self::from_canonical_u32(state_tmp[7] ^ state_tmp[15]); + } + + // -------------------------------------- field ------------------------------------ #[inline] fn g_field, const D: usize>( - state: &mut [F; STATE_SIZE], a: usize, b: usize, c: usize, d: usize, x: u32, y: u32) { + state: &mut [F; STATE_SIZE], a: usize, b: usize, c: usize, d: usize, x_field: F, y_field: F) { let mut state_tmp = [0u32; STATE_SIZE]; @@ -36,6 +155,9 @@ pub trait Blake3: PrimeField64 { state_tmp[i] = F::BaseField::to_noncanonical_u64(&state[i].to_basefield_array()[0]) as u32; } + let x = F::BaseField::to_noncanonical_u64(&x_field.to_basefield_array()[0]) as u32; + let y = F::BaseField::to_noncanonical_u64(&y_field.to_basefield_array()[0]) as u32; + state_tmp[a] = state_tmp[a].wrapping_add(state_tmp[b]).wrapping_add(x); state_tmp[d] = (state_tmp[d] ^ state_tmp[a]).rotate_right(16); state_tmp[c] = state_tmp[c].wrapping_add(state_tmp[d]); @@ -54,7 +176,7 @@ pub trait Blake3: PrimeField64 { #[inline(always)] fn round_field, const D: usize>( - state: &mut [F; 16], msg: &[u32; 16], round: usize) { + state: &mut [F; STATE_SIZE], msg: [F; STATE_SIZE], round: usize) { // Select the message schedule based on the round. let schedule = Self::MSG_SCHEDULE[round]; @@ -74,31 +196,12 @@ pub trait Blake3: PrimeField64 { #[inline(always)] fn compress_pre_field, const D: usize>( cv: &mut [F; IV_SIZE], - block: &[u8; BLOCK_LEN], + block_words: [F; 16], block_len: u8, counter: u64, flags: u8, ) -> [F; 16] { - let mut block_words = [0u32; STATE_SIZE]; - - block_words[0] = u32::from_le_bytes(*array_ref!(block, 0 * 4, 4)); - block_words[1] = u32::from_le_bytes(*array_ref!(block, 1 * 4, 4)); - block_words[2] = u32::from_le_bytes(*array_ref!(block, 2 * 4, 4)); - block_words[3] = u32::from_le_bytes(*array_ref!(block, 3 * 4, 4)); - block_words[4] = u32::from_le_bytes(*array_ref!(block, 4 * 4, 4)); - block_words[5] = u32::from_le_bytes(*array_ref!(block, 5 * 4, 4)); - block_words[6] = u32::from_le_bytes(*array_ref!(block, 6 * 4, 4)); - block_words[7] = u32::from_le_bytes(*array_ref!(block, 7 * 4, 4)); - block_words[8] = u32::from_le_bytes(*array_ref!(block, 8 * 4, 4)); - block_words[9] = u32::from_le_bytes(*array_ref!(block, 9 * 4, 4)); - block_words[10] = u32::from_le_bytes(*array_ref!(block, 10 * 4, 4)); - block_words[11] = u32::from_le_bytes(*array_ref!(block, 11 * 4, 4)); - block_words[12] = u32::from_le_bytes(*array_ref!(block, 12 * 4, 4)); - block_words[13] = u32::from_le_bytes(*array_ref!(block, 13 * 4, 4)); - block_words[14] = u32::from_le_bytes(*array_ref!(block, 14 * 4, 4)); - block_words[15] = u32::from_le_bytes(*array_ref!(block, 15 * 4, 4)); - let mut state = [ cv[0], cv[1], @@ -118,13 +221,13 @@ pub trait Blake3: PrimeField64 { F::from_canonical_u32(flags as u32), ]; - Self::round_field(&mut state, &block_words, 0); - Self::round_field(&mut state, &block_words, 1); - Self::round_field(&mut state, &block_words, 2); - Self::round_field(&mut state, &block_words, 3); - Self::round_field(&mut state, &block_words, 4); - Self::round_field(&mut state, &block_words, 5); - Self::round_field(&mut state, &block_words, 6); + Self::round_field(&mut state, block_words, 0); + Self::round_field(&mut state, block_words, 1); + Self::round_field(&mut state, block_words, 2); + Self::round_field(&mut state, block_words, 3); + Self::round_field(&mut state, block_words, 4); + Self::round_field(&mut state, block_words, 5); + Self::round_field(&mut state, block_words, 6); state } @@ -132,7 +235,7 @@ pub trait Blake3: PrimeField64 { fn compress_in_place_field, const D: usize>( cv: &mut [F; IV_SIZE], - block: &[u8; BLOCK_LEN], + block: [F; STATE_SIZE], block_len: u8, counter: u64, flags: u8, @@ -155,11 +258,357 @@ pub trait Blake3: PrimeField64 { cv[7] = F::from_canonical_u32(state_tmp[7] ^ state_tmp[15]); } - // ---------------------------------- circuit -------------------------------------- -} + // g_circuit + //#[inline(always)] + fn g_circuit( + builder: &mut CircuitBuilder, + state: &mut [ExtensionTarget; STATE_SIZE], + input_xor: [ExtensionTarget; 4], + shift_constant: [ExtensionTarget; 4], + remain: [ExtensionTarget; 4], + q: [ExtensionTarget; 4], + a: usize, + b: usize, + c: usize, + d: usize, + x_et: ExtensionTarget, + y_et: ExtensionTarget, + ) where + Self: RichField + Extendable, + { + + // state_tmp[a] = state_tmp[a].wrapping_add(state_tmp[b]).wrapping_add(x); + state[a] = builder.add_many_extension([state[a], state[b], x_et]); + + let limbs_input_a = builder.split_le_base::(state[a].to_target_array()[0], LOOKUP_LIMB_NUMBER); + let limbs_input_d = builder.split_le_base::(state[d].to_target_array()[0], LOOKUP_LIMB_NUMBER); + let limbs_xor_d_a = builder.split_le_base::(input_xor[0].to_target_array()[0], LOOKUP_LIMB_NUMBER); + + builder.add_lookup_from_index_bitwise(limbs_input_a[0], limbs_input_d[0], limbs_xor_d_a[0], 0); + builder.add_lookup_from_index_bitwise(limbs_input_a[1], limbs_input_d[1], limbs_xor_d_a[1], 0); + builder.add_lookup_from_index_bitwise(limbs_input_a[2], limbs_input_d[2], limbs_xor_d_a[2], 0); + builder.add_lookup_from_index_bitwise(limbs_input_a[3], limbs_input_d[3], limbs_xor_d_a[3], 0); + + // state_tmp[d] = (state_tmp[d] ^ state_tmp[a]).rotate_right(16); + //state[d] = builder.div_extension(input_xor[0], shift_constant[0]); + + let input_xor_real = builder.mul_add_extension(remain[0], shift_constant[0], q[0]); + builder.connect_extension(input_xor_real, input_xor[0]); + state[d] = remain[0]; + // state_tmp[c] = state_tmp[c].wrapping_add(state_tmp[d]); + state[c] = builder.add_extension(state[c], state[d]); + + let limbs_input_b = builder.split_le_base::(state[b].to_target_array()[0], LOOKUP_LIMB_NUMBER); + let limbs_input_c = builder.split_le_base::(state[c].to_target_array()[0], LOOKUP_LIMB_NUMBER); + let limbs_xor_b_c = builder.split_le_base::(input_xor[1].to_target_array()[0], LOOKUP_LIMB_NUMBER); + + builder.add_lookup_from_index_bitwise(limbs_input_b[0], limbs_input_c[0], limbs_xor_b_c[0], 0); + builder.add_lookup_from_index_bitwise(limbs_input_b[1], limbs_input_c[1], limbs_xor_b_c[1], 0); + builder.add_lookup_from_index_bitwise(limbs_input_b[2], limbs_input_c[2], limbs_xor_b_c[2], 0); + builder.add_lookup_from_index_bitwise(limbs_input_b[3], limbs_input_c[3], limbs_xor_b_c[3], 0); + + + // state_tmp[b] = (state_tmp[b] ^ state_tmp[c]).rotate_right(12); + //state[b] = builder.div_extension(input_xor[1], shift_constant[1]); + + let input_xor_real = builder.mul_add_extension(remain[1], shift_constant[1], q[1]); + builder.connect_extension(input_xor_real, input_xor[1]); + state[b] = remain[1]; + // state_tmp[a] = state_tmp[a].wrapping_add(state_tmp[b]).wrapping_add(y); + state[a] = builder.add_many_extension([state[a], state[b], y_et]); + + let limbs_input_a = builder.split_le_base::(state[a].to_target_array()[0], LOOKUP_LIMB_NUMBER); + let limbs_input_d = builder.split_le_base::(state[d].to_target_array()[0], LOOKUP_LIMB_NUMBER); + let limbs_xor_d_a = builder.split_le_base::(input_xor[2].to_target_array()[0], LOOKUP_LIMB_NUMBER); + + builder.add_lookup_from_index_bitwise(limbs_input_a[0], limbs_input_d[0], limbs_xor_d_a[0], 0); + builder.add_lookup_from_index_bitwise(limbs_input_a[1], limbs_input_d[1], limbs_xor_d_a[1], 0); + builder.add_lookup_from_index_bitwise(limbs_input_a[2], limbs_input_d[2], limbs_xor_d_a[2], 0); + builder.add_lookup_from_index_bitwise(limbs_input_a[3], limbs_input_d[3], limbs_xor_d_a[3], 0); + + // state_tmp[d] = (state_tmp[d] ^ state_tmp[a]).rotate_right(8); + //state[d] = builder.div_extension(input_xor[2], shift_constant[2]); + + let input_xor_real = builder.mul_add_extension(remain[2], shift_constant[2], q[2]); + builder.connect_extension(input_xor_real, input_xor[2]); + state[d] = remain[2]; + // state_tmp[c] = state_tmp[c].wrapping_add(state_tmp[d]); + state[c] = builder.add_extension(state[c], state[d]); + + let limbs_input_b = builder.split_le_base::(state[b].to_target_array()[0], LOOKUP_LIMB_NUMBER); + let limbs_input_c = builder.split_le_base::(state[c].to_target_array()[0], LOOKUP_LIMB_NUMBER); + let limbs_xor_b_c = builder.split_le_base::(input_xor[3].to_target_array()[0], LOOKUP_LIMB_NUMBER); + + builder.add_lookup_from_index_bitwise(limbs_input_b[0], limbs_input_c[0], limbs_xor_b_c[0], 0); + builder.add_lookup_from_index_bitwise(limbs_input_b[1], limbs_input_c[1], limbs_xor_b_c[1], 0); + builder.add_lookup_from_index_bitwise(limbs_input_b[2], limbs_input_c[2], limbs_xor_b_c[2], 0); + builder.add_lookup_from_index_bitwise(limbs_input_b[3], limbs_input_c[3], limbs_xor_b_c[3], 0); + + + // state_tmp[b] = (state_tmp[b] ^ state_tmp[c]).rotate_right(7); + //state[b] = builder.div_extension(input_xor[3], shift_constant[3]); + + let input_xor_real = builder.mul_add_extension(remain[3], shift_constant[3], q[3]); + builder.connect_extension(input_xor_real, input_xor[3]); + state[b] = remain[3]; + } + + // g_circuit + //#[inline(always)] + fn round_circuit( + builder: &mut CircuitBuilder, + state: &mut [ExtensionTarget; STATE_SIZE], + input_xor: [[ExtensionTarget; 4]; 8], + remain: [[ExtensionTarget; 4]; 8], + q: [[ExtensionTarget; 4]; 8], + shift_constant: [ExtensionTarget; 4], + msg: [ExtensionTarget; STATE_SIZE], + round: usize + ) where + Self: RichField + Extendable, + { + // Select the message schedule based on the round. + let schedule = Self::MSG_SCHEDULE[round]; + + // Mix the columns. + Self::g_circuit(builder, state, input_xor[0], shift_constant, remain[0], q[0], + 0, 4, 8, 12, msg[schedule[0]], msg[schedule[1]]); + Self::g_circuit(builder, state, input_xor[1], shift_constant, remain[1], q[1], + 1, 5, 9, 13, msg[schedule[2]], msg[schedule[3]]); + Self::g_circuit(builder, state, input_xor[2], shift_constant, remain[2], q[2], + 2, 6, 10, 14, msg[schedule[4]], msg[schedule[5]]); + Self::g_circuit(builder, state, input_xor[3], shift_constant, remain[3], q[3], + 3, 7, 11, 15, msg[schedule[6]], msg[schedule[7]]); + + // Mix the diagonals. + Self::g_circuit(builder, state, input_xor[4], shift_constant, remain[4], q[4], + 0, 5, 10, 15, msg[schedule[8]], msg[schedule[9]]); + Self::g_circuit(builder, state, input_xor[5], shift_constant, remain[5], q[5], + 1, 6, 11, 12, msg[schedule[10]], msg[schedule[11]]); + Self::g_circuit(builder, state, input_xor[6], shift_constant, remain[6], q[6], + 2, 7, 8, 13, msg[schedule[12]], msg[schedule[13]]); + Self::g_circuit(builder, state, input_xor[7], shift_constant, remain[7], q[7], + 3, 4, 9, 14, msg[schedule[14]], msg[schedule[15]]); + + } + + fn compress_pre_circuit( + builder: &mut CircuitBuilder, + cv: &mut [ExtensionTarget; IV_SIZE], + input_xor: [[[ExtensionTarget; 4]; 8]; 7], + remain: [[[ExtensionTarget; 4]; 8]; 7], + q: [[[ExtensionTarget; 4]; 8]; 7], + shift_constant: [ExtensionTarget; 4], + block_words: [ExtensionTarget; STATE_SIZE], + block_len: u8, + counter: u64, + flags: u8, + ) -> [ExtensionTarget; STATE_SIZE] + where + Self: RichField + Extendable, + { + + let mut state = [builder.zero_extension(); STATE_SIZE]; + + state[0] = cv[0]; + state[1] = cv[1]; + state[2] = cv[2]; + state[3] = cv[3]; + state[4] = cv[4]; + state[5] = cv[5]; + state[6] = cv[6]; + state[7] = cv[7]; + state[8] = builder.constant_extension(Self::Extension::from_canonical_u32(Self::IV[0])); + state[9] = builder.constant_extension(Self::Extension::from_canonical_u32(Self::IV[1])); + state[10] = builder.constant_extension(Self::Extension::from_canonical_u32(Self::IV[2])); + state[11] = builder.constant_extension(Self::Extension::from_canonical_u32(Self::IV[3])); + state[12] = builder.constant_extension(Self::Extension::from_canonical_u32(counter as u32)); + state[13] = builder.constant_extension(Self::Extension::from_canonical_u32((counter >> 32) as u32)); + state[14] = builder.constant_extension(Self::Extension::from_canonical_u8(block_len)); + state[15] = builder.constant_extension(Self::Extension::from_canonical_u8(flags)); + + + Self::round_circuit(builder, &mut state, input_xor[0], remain[0], + q[0], shift_constant, block_words, 0); + Self::round_circuit(builder, &mut state, input_xor[1], remain[1], + q[1], shift_constant, block_words, 1); + Self::round_circuit(builder, &mut state, input_xor[2], remain[2], + q[2], shift_constant, block_words, 2); + Self::round_circuit(builder, &mut state, input_xor[3], remain[3], + q[3], shift_constant, block_words, 3); + Self::round_circuit(builder, &mut state, input_xor[4], remain[4], + q[4], shift_constant, block_words, 4); + Self::round_circuit(builder, &mut state, input_xor[5], remain[5], + q[5], shift_constant, block_words, 5); + Self::round_circuit(builder, &mut state, input_xor[6], remain[6], + q[6], shift_constant, block_words, 6); + + state + + } + + // ---------------------------- run once ---------------------------------- + #[inline] + fn g_field_run_once, const D: usize>( + out_buffer: &mut GeneratedValues, + row_num: usize, + xor_index: usize, + state: &mut [F; STATE_SIZE], + a: usize, + b: usize, + c: usize, + d: usize, + x_field: F, + y_field: F) { + + let mut state_tmp = [0u32; STATE_SIZE]; + + for i in 0..STATE_SIZE { + state_tmp[i] = F::BaseField::to_noncanonical_u64(&state[i].to_basefield_array()[0]) as u32; + } + + let x = F::BaseField::to_noncanonical_u64(&x_field.to_basefield_array()[0]) as u32; + let y = F::BaseField::to_noncanonical_u64(&y_field.to_basefield_array()[0]) as u32; + + state_tmp[a] = state_tmp[a].wrapping_add(state_tmp[b]).wrapping_add(x); + let tmp = state_tmp[d] ^ state_tmp[a]; + out_buffer.set_wire(Wire{row: row_num, column: xor_index}, F::from_canonical_u32(tmp)); + + state_tmp[d] = tmp.rotate_right(16); + out_buffer.set_wire(Wire{row: row_num, column: xor_index + 448}, F::from_canonical_u32(state_tmp[d] >> 16)); + out_buffer.set_wire(Wire{row: row_num, column: xor_index + 224}, + F::from_canonical_u32(state_tmp[d] - ((state_tmp[d] >> 16) << 16))); + + state_tmp[c] = state_tmp[c].wrapping_add(state_tmp[d]); + let tmp = state_tmp[b] ^ state_tmp[c]; + out_buffer.set_wire(Wire{row: row_num, column: xor_index + 1}, F::from_canonical_u32(tmp)); + + state_tmp[b] = tmp.rotate_right(12); + out_buffer.set_wire(Wire{row: row_num, column: xor_index + 448 + 1}, F::from_canonical_u32(state_tmp[b] >> 20)); + out_buffer.set_wire(Wire{row: row_num, column: xor_index + 224 + 1}, + F::from_canonical_u32(state_tmp[b] - ((state_tmp[b] >> 20) << 20))); + + state_tmp[a] = state_tmp[a].wrapping_add(state_tmp[b]).wrapping_add(y); + let tmp = state_tmp[d] ^ state_tmp[a]; + out_buffer.set_wire(Wire{row: row_num, column: xor_index + 2}, F::from_canonical_u32(tmp)); + + state_tmp[d] = tmp.rotate_right(8); + out_buffer.set_wire(Wire{row: row_num, column: xor_index + 448 + 2}, F::from_canonical_u32(state_tmp[d] >> 24)); + out_buffer.set_wire(Wire{row: row_num, column: xor_index + 224 + 2}, + F::from_canonical_u32(state_tmp[d] - ((state_tmp[d] >> 24) << 24))); + + state_tmp[c] = state_tmp[c].wrapping_add(state_tmp[d]); + let tmp = state_tmp[b] ^ state_tmp[c]; + out_buffer.set_wire(Wire{row: row_num, column: xor_index + 3}, F::from_canonical_u32(tmp)); + + state_tmp[b] = tmp.rotate_right(7); + out_buffer.set_wire(Wire{row: row_num, column: xor_index + 448 + 3}, F::from_canonical_u32(state_tmp[b] >> 25)); + out_buffer.set_wire(Wire{row: row_num, column: xor_index + 224 + 3}, + F::from_canonical_u32(state_tmp[b] - ((state_tmp[b] >> 25) << 25))); + + for i in 0..STATE_SIZE { + state[i] = F::from_canonical_u32(state_tmp[i]); + } + + } + + #[inline(always)] + fn round_field_run_once, const D: usize>( + out_buffer: &mut GeneratedValues, + row: usize, + state: &mut [F; STATE_SIZE], + msg: [F; STATE_SIZE], + round: usize) { + // Select the message schedule based on the round. + let schedule = Self::MSG_SCHEDULE[round]; + + // Mix the columns. + Self::g_field_run_once(out_buffer, row, 24 + round * 8 * 4 + 0 * 4, state, 0, 4, 8, 12, msg[schedule[0]], msg[schedule[1]]); + Self::g_field_run_once(out_buffer, row, 24 + round * 8 * 4 + 1 * 4, state, 1, 5, 9, 13, msg[schedule[2]], msg[schedule[3]]); + Self::g_field_run_once(out_buffer, row, 24 + round * 8 * 4 + 2 * 4, state, 2, 6, 10, 14, msg[schedule[4]], msg[schedule[5]]); + Self::g_field_run_once(out_buffer, row, 24 + round * 8 * 4 + 3 * 4, state, 3, 7, 11, 15, msg[schedule[6]], msg[schedule[7]]); + + // Mix the diagonals. + Self::g_field_run_once(out_buffer, row, 24 + round * 8 * 4 + 4 * 4, state, 0, 5, 10, 15, msg[schedule[8]], msg[schedule[9]]); + Self::g_field_run_once(out_buffer, row, 24 + round * 8 * 4 + 5 * 4, state, 1, 6, 11, 12, msg[schedule[10]], msg[schedule[11]]); + Self::g_field_run_once(out_buffer, row, 24 + round * 8 * 4 + 6 * 4, state, 2, 7, 8, 13, msg[schedule[12]], msg[schedule[13]]); + Self::g_field_run_once(out_buffer, row, 24 + round * 8 * 4 + 7 * 4, state, 3, 4, 9, 14, msg[schedule[14]], msg[schedule[15]]); + } + + #[inline(always)] + fn compress_pre_field_run_once, const D: usize>( + out_buffer: &mut GeneratedValues, + row: usize, + cv: &mut [F; IV_SIZE], + block_words: [F; 16], + block_len: u8, + counter: u64, + flags: u8, + ) -> [F; 16] { + + let mut state = [ + cv[0], + cv[1], + cv[2], + cv[3], + cv[4], + cv[5], + cv[6], + cv[7], + F::from_canonical_u32(Self::IV[0]), + F::from_canonical_u32(Self::IV[1]), + F::from_canonical_u32(Self::IV[2]), + F::from_canonical_u32(Self::IV[3]), + F::from_canonical_u32(counter as u32), + F::from_canonical_u32((counter >> 32) as u32), + F::from_canonical_u32(block_len as u32), + F::from_canonical_u32(flags as u32), + ]; + + Self::round_field_run_once(out_buffer, row, &mut state, block_words, 0); + Self::round_field_run_once(out_buffer, row, &mut state, block_words, 1); + Self::round_field_run_once(out_buffer, row, &mut state, block_words, 2); + Self::round_field_run_once(out_buffer, row, &mut state, block_words, 3); + Self::round_field_run_once(out_buffer, row, &mut state, block_words, 4); + Self::round_field_run_once(out_buffer, row, &mut state, block_words, 5); + Self::round_field_run_once(out_buffer, row, &mut state, block_words, 6); + + state + } + + + fn compress_in_place_field_run_once, const D: usize>( + out_buffer: &mut GeneratedValues, + row: usize, + cv: &mut [F; IV_SIZE], + block: [F; STATE_SIZE], + block_len: u8, + counter: u64, + flags: u8, + ) { + let state = Self::compress_pre_field_run_once(out_buffer, row, cv, block, block_len, counter, flags); + + let mut state_tmp = [0u32; STATE_SIZE]; + + for i in 0..STATE_SIZE { + state_tmp[i] = F::BaseField::to_noncanonical_u64(&state[i].to_basefield_array()[0]) as u32; + } + + cv[0] = F::from_canonical_u32(state_tmp[0] ^ state_tmp[8]); + cv[1] = F::from_canonical_u32(state_tmp[1] ^ state_tmp[9]); + cv[2] = F::from_canonical_u32(state_tmp[2] ^ state_tmp[10]); + cv[3] = F::from_canonical_u32(state_tmp[3] ^ state_tmp[11]); + cv[4] = F::from_canonical_u32(state_tmp[4] ^ state_tmp[12]); + cv[5] = F::from_canonical_u32(state_tmp[5] ^ state_tmp[13]); + cv[6] = F::from_canonical_u32(state_tmp[6] ^ state_tmp[14]); + cv[7] = F::from_canonical_u32(state_tmp[7] ^ state_tmp[15]); + } + + +} pub struct Blake3Permutation; impl PlonkyPermutation for Blake3Permutation { diff --git a/plonky2/plonky2/src/hash/hash_types.rs b/plonky2/plonky2/src/hash/hash_types.rs index 6171fe0b..03fc9059 100644 --- a/plonky2/plonky2/src/hash/hash_types.rs +++ b/plonky2/plonky2/src/hash/hash_types.rs @@ -4,12 +4,13 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer}; use crate::hash::poseidon::Poseidon; use crate::hash::poseidon2::Poseidon2; +use crate::hash::blake3::Blake3; use crate::iop::target::Target; use crate::plonk::config::GenericHashOut; /// A prime order field with the features we need to use it as a base field in /// our argument system. -pub trait RichField: PrimeField64 + Poseidon + Poseidon2 {} +pub trait RichField: PrimeField64 + Poseidon + Poseidon2 + Blake3 {} impl RichField for GoldilocksField {} diff --git a/plonky2/plonky2/src/iop/generator.rs b/plonky2/plonky2/src/iop/generator.rs index cc8cfd2f..4ca6e990 100644 --- a/plonky2/plonky2/src/iop/generator.rs +++ b/plonky2/plonky2/src/iop/generator.rs @@ -14,7 +14,7 @@ use crate::plonk::config::GenericConfig; /// Given a `PartitionWitness` that has only inputs set, populates the rest of /// the witness using the given set of generators. -pub(crate) fn generate_partial_witness< +pub fn generate_partial_witness< 'a, F: RichField + Extendable, C: GenericConfig, diff --git a/plonky2/plonky2/src/plonk/circuit_data.rs b/plonky2/plonky2/src/plonk/circuit_data.rs index 948d8b7c..40ab48ae 100644 --- a/plonky2/plonky2/src/plonk/circuit_data.rs +++ b/plonky2/plonky2/src/plonk/circuit_data.rs @@ -99,6 +99,13 @@ impl CircuitConfig { } } + pub fn wide_blake3_config() -> Self { + Self { + num_wires: 696, + ..Self::standard_recursion_config() + } + } + pub fn standard_recursion_zk_config() -> Self { CircuitConfig { zero_knowledge: true, diff --git a/plonky2/plonky2/src/plonk/prover.rs b/plonky2/plonky2/src/plonk/prover.rs index 61ae2d8c..d924a8ba 100644 --- a/plonky2/plonky2/src/plonk/prover.rs +++ b/plonky2/plonky2/src/plonk/prover.rs @@ -15,9 +15,10 @@ use crate::field::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::field::types::Field; use crate::field::zero_poly_coset::ZeroPolyOnCoset; use crate::fri::oracle::PolynomialBatch; -use crate::gates::lookup::{LookupGate, BitwiseLookupGate}; -use crate::gates::lookup_table::{LookupTableGate, BitwiseLookupTableGate}; +use crate::gates::lookup::BitwiseLookupGate; +use crate::gates::lookup_table::BitwiseLookupTableGate; use crate::gates::selectors::LookupSelectors; +use crate::hash::blake3::LOOKUP_LIMB_RANGE; use crate::hash::hash_types::RichField; use crate::iop::challenger::Challenger; use crate::iop::generator::generate_partial_witness; @@ -66,15 +67,15 @@ pub fn set_lookup_wires< let table_value_to_idx: HashMap = common_data.luts[lut_index] .iter() .enumerate() - .map(|(_, (inp_target0, inp_target1, out_target))| (*inp_target0 * 16 + *inp_target1, *out_target as usize)) + .map(|(_, (inp_target0, inp_target1, out_target))| (*inp_target0 * LOOKUP_LIMB_RANGE as u8 + *inp_target1, *out_target as usize)) .collect(); for (inp_target0, inp_target1, _) in prover_data.lut_to_lookups[lut_index].iter() { let inp_value0 = pw.get_target(*inp_target0); let inp_value1 = pw.get_target(*inp_target1); - let index = inp_value0.to_canonical_u64() * 16 + inp_value1.to_canonical_u64(); + let index = inp_value0.to_canonical_u64() * LOOKUP_LIMB_RANGE as u64 + inp_value1.to_canonical_u64(); let idx = table_value_to_idx - .get(&u8::try_from(inp_value0.to_canonical_u64() * 16 + inp_value1.to_canonical_u64()).unwrap()) + .get(&u8::try_from(inp_value0.to_canonical_u64() * LOOKUP_LIMB_RANGE as u64 + inp_value1.to_canonical_u64()).unwrap()) .is_some(); if idx { @@ -480,7 +481,9 @@ fn compute_lookup_polys< let looked_out = witness.get_wire(row, BitwiseLookupTableGate::wire_ith_looked_out(s)); - looked_inp_0 + looked_inp_1 + deltas[LookupChallenges::ChallengeA as usize] * looked_out + looked_inp_0 + + looked_inp_1 * deltas[LookupChallenges::ChallengeA as usize] * deltas[LookupChallenges::ChallengeA as usize] + + deltas[LookupChallenges::ChallengeA as usize] * looked_out }) .collect(); // Get (alpha - combo). @@ -497,7 +500,9 @@ fn compute_lookup_polys< let looked_inp_1 = witness.get_wire(row, BitwiseLookupTableGate::wire_ith_looked_inp1(s)); let looked_out = witness.get_wire(row, BitwiseLookupTableGate::wire_ith_looked_out(s)); - looked_inp_0 + looked_inp_1 + deltas[LookupChallenges::ChallengeB as usize] * looked_out + looked_inp_0 + + looked_inp_1 * deltas[LookupChallenges::ChallengeB as usize] * deltas[LookupChallenges::ChallengeB as usize] + + deltas[LookupChallenges::ChallengeB as usize] * looked_out }) .collect(); @@ -535,7 +540,9 @@ fn compute_lookup_polys< let looking_in_1 = witness.get_wire(row, BitwiseLookupGate::wire_ith_looking_inp1(s)); let looking_out = witness.get_wire(row, BitwiseLookupGate::wire_ith_looking_out(s)); - looking_in_0 + looking_in_1 + deltas[LookupChallenges::ChallengeA as usize] * looking_out + looking_in_0 + + looking_in_1 * deltas[LookupChallenges::ChallengeA as usize] * deltas[LookupChallenges::ChallengeA as usize] + + deltas[LookupChallenges::ChallengeA as usize] * looking_out }) .collect(); // Get (alpha - combo). diff --git a/plonky2/plonky2/src/plonk/vanishing_poly.rs b/plonky2/plonky2/src/plonk/vanishing_poly.rs index 56546835..05a57bf6 100644 --- a/plonky2/plonky2/src/plonk/vanishing_poly.rs +++ b/plonky2/plonky2/src/plonk/vanishing_poly.rs @@ -41,7 +41,7 @@ pub(crate) fn get_lut_poly, C: GenericConfig, C: GenericConfig, C: GenericConfig, C: GenericConfig, C: GenericCo let input_wire_0 = vars.local_wires[BitwiseLookupTableGate::wire_ith_looked_inp0(s)]; let input_wire_1 = vars.local_wires[BitwiseLookupTableGate::wire_ith_looked_inp1(s)]; let output_wire = vars.local_wires[BitwiseLookupTableGate::wire_ith_looked_out(s)]; - input_wire_0 + input_wire_1 + deltas[LookupChallenges::ChallengeA as usize] * output_wire + input_wire_0 + + input_wire_1 * deltas[LookupChallenges::ChallengeA as usize] * deltas[LookupChallenges::ChallengeA as usize] + + deltas[LookupChallenges::ChallengeA as usize] * output_wire }) .collect(); @@ -557,7 +565,9 @@ pub fn check_lookup_constraints_batch, C: GenericCo let input_wire_0 = vars.local_wires[BitwiseLookupGate::wire_ith_looking_inp0(s)]; let input_wire_1 = vars.local_wires[BitwiseLookupGate::wire_ith_looking_inp1(s)]; let output_wire = vars.local_wires[BitwiseLookupGate::wire_ith_looking_out(s)]; - input_wire_0 + input_wire_1 + deltas[LookupChallenges::ChallengeA as usize] * output_wire + input_wire_0 + + input_wire_1 * deltas[LookupChallenges::ChallengeA as usize] * deltas[LookupChallenges::ChallengeA as usize] + + deltas[LookupChallenges::ChallengeA as usize] * output_wire }) .collect(); @@ -568,7 +578,9 @@ pub fn check_lookup_constraints_batch, C: GenericCo let input_wire_1 = vars.local_wires[BitwiseLookupTableGate::wire_ith_looked_inp1(s)]; let output_wire = vars.local_wires[BitwiseLookupTableGate::wire_ith_looked_out(s)]; - input_wire_0 + input_wire_1 + deltas[LookupChallenges::ChallengeB as usize] * output_wire + input_wire_0 + + input_wire_1 * deltas[LookupChallenges::ChallengeB as usize] * deltas[LookupChallenges::ChallengeB as usize] + + deltas[LookupChallenges::ChallengeB as usize] * output_wire }) .collect(); @@ -779,9 +791,13 @@ pub(crate) fn get_lut_poly_circuit, C: GenericConfi let mut coeffs: Vec = common_data.luts[lut_index] .iter() .map(|(input0, input1, output)| { - let temp = builder.mul_const(F::from_canonical_u8(*output), b); - builder.add_const(temp, F::from_canonical_u8(*input0)); - builder.add_const(temp, F::from_canonical_u8(*input1)) + // input0 + challenge_b * (out + challenge_b * input1) + let output_target = builder.constant(F::from_canonical_u8(*output)); + let input_target = builder.constant(F::from_canonical_u8(*input0)); + // (out + challenge_b * input1) + let temp = builder.mul_const_add(F::from_canonical_u8(*input1), b, output_target); + // input0 + challenge_b * (out + challenge_b * input1) + builder.mul_add(temp, b, input_target) }) .collect(); for _ in n..degree { @@ -976,12 +992,16 @@ pub fn check_lookup_constraints_circuit, C: Generic let input_wire_0 = vars.local_wires[BitwiseLookupTableGate::wire_ith_looked_inp0(s)]; let input_wire_1 = vars.local_wires[BitwiseLookupTableGate::wire_ith_looked_inp1(s)]; let output_wire = vars.local_wires[BitwiseLookupTableGate::wire_ith_looked_out(s)]; + // input0 + challenge_b * (out + challenge_a * input1) let temp = builder.mul_add_extension( ext_deltas[LookupChallenges::ChallengeA as usize], + input_wire_1, output_wire, - input_wire_0, ); - builder.add_extension(temp, input_wire_1) + builder.mul_add_extension( + ext_deltas[LookupChallenges::ChallengeA as usize], + temp, + input_wire_0) }) .collect::>(); let current_looking_combos = (0..num_lu_slots) @@ -989,12 +1009,16 @@ pub fn check_lookup_constraints_circuit, C: Generic let input_wire_0 = vars.local_wires[BitwiseLookupGate::wire_ith_looking_inp0(s)]; let input_wire_1 = vars.local_wires[BitwiseLookupGate::wire_ith_looking_inp1(s)]; let output_wire = vars.local_wires[BitwiseLookupGate::wire_ith_looking_out(s)]; + // input0 + challenge_b * (out + challenge_a * input1) let temp = builder.mul_add_extension( ext_deltas[LookupChallenges::ChallengeA as usize], + input_wire_1, output_wire, - input_wire_0, ); - builder.add_extension(temp, input_wire_1) + builder.mul_add_extension( + ext_deltas[LookupChallenges::ChallengeA as usize], + temp, + input_wire_0) }) .collect::>(); @@ -1022,12 +1046,16 @@ pub fn check_lookup_constraints_circuit, C: Generic let input_wire_0 = vars.local_wires[BitwiseLookupTableGate::wire_ith_looked_inp0(s)]; let input_wire_1 = vars.local_wires[BitwiseLookupTableGate::wire_ith_looked_inp1(s)]; let output_wire = vars.local_wires[BitwiseLookupTableGate::wire_ith_looked_out(s)]; + // input0 + challenge_b * (out + challenge_b * input1) let temp = builder.mul_add_extension( ext_deltas[LookupChallenges::ChallengeB as usize], + input_wire_1, output_wire, - input_wire_0, ); - builder.add_extension(temp, input_wire_1) + builder.mul_add_extension( + ext_deltas[LookupChallenges::ChallengeB as usize], + temp, + input_wire_0) }) .collect::>();