Skip to content

Commit

Permalink
Verification time (#90)
Browse files Browse the repository at this point in the history
* verifier optimize and refactor

* minor

* correctness done.

* fmt
  • Loading branch information
zhiyong1997 authored Sep 12, 2024
1 parent 9c00ff9 commit 373dbd6
Show file tree
Hide file tree
Showing 3 changed files with 308 additions and 191 deletions.
214 changes: 26 additions & 188 deletions src/verifier.rs
Original file line number Diff line number Diff line change
@@ -1,162 +1,16 @@
mod sumcheck_verifier_helper;

pub use sumcheck_verifier_helper::*;

use std::{io::Cursor, vec};

use arith::{ExtensionField, Field};
use arith::Field;
use ark_std::{end_timer, start_timer};

#[cfg(feature = "grinding")]
use crate::grind;

use crate::{
eq_evals_at_primitive, Circuit, CircuitLayer, Config, FieldType, GKRConfig, Gate, Proof,
RawCommitment, Transcript, _eq_vec,
};

#[inline]
fn degree_2_eval<C: GKRConfig>(
p0: C::ChallengeField,
p1: C::ChallengeField,
p2: C::ChallengeField,
x: C::ChallengeField,
) -> C::ChallengeField {
if C::FIELD_TYPE == FieldType::GF2 {
let c0 = &p0;
let c2 = (p2 - p0 - p1.mul_by_x() + p0.mul_by_x())
* (C::ChallengeField::X - C::ChallengeField::one())
.mul_by_x()
.inv()
.unwrap();
let c1 = p1 - p0 - c2;
*c0 + (c2 * x + c1) * x
} else {
let c0 = &p0;
let c2 = C::ChallengeField::INV_2 * (p2 - p1 - p1 + p0);
let c1 = p1 - p0 - c2;
*c0 + (c2 * x + c1) * x
}
}

#[inline(always)]
fn lag_eval<F: Field + ExtensionField>(base: &[F], vals: &[F], x: &F) -> F {
debug_assert_eq!(base.len(), vals.len());
// trivial lag eval:
let mut v = F::zero();
for i in 0..base.len() {
let mut numerator = F::one();
let mut denominator = F::one();
for j in 0..base.len() {
if j == i {
continue;
}
numerator *= *x - base[j];
denominator *= base[i] - base[j];
}
v += numerator * denominator.inv().unwrap() * vals[i];
}
v
}

#[inline]
fn degree_3_eval<C: GKRConfig>(
p0: C::ChallengeField,
p1: C::ChallengeField,
p2: C::ChallengeField,
p3: C::ChallengeField,
x: C::ChallengeField,
) -> C::ChallengeField {
// TODO-OPTIMIZATION: precompute values and inverses
if C::FIELD_TYPE == FieldType::GF2 {
lag_eval(
&[
C::ChallengeField::zero(),
C::ChallengeField::one(),
C::ChallengeField::X,
C::ChallengeField::X.mul_by_x(),
],
&[p0, p1, p2, p3],
&x,
)
} else {
lag_eval(
&[
C::ChallengeField::zero(),
C::ChallengeField::one(),
C::ChallengeField::from(2),
C::ChallengeField::from(3),
],
&[p0, p1, p2, p3],
&x,
)
}
}

// TODO: Remove redundant computation and split it into cst, add/uni and mul
#[allow(clippy::too_many_arguments)]
fn eval_sparse_circuit_connect_poly<C: GKRConfig, const INPUT_NUM: usize>(
gates: &[Gate<C, INPUT_NUM>],
rz0: &[C::ChallengeField],
rz1: &[C::ChallengeField],
r_simd: &[C::ChallengeField],
alpha: C::ChallengeField,
beta: C::ChallengeField,
rx: &[C::ChallengeField],
ry: &[C::ChallengeField],
r_simd_xy: &[C::ChallengeField],
) -> C::ChallengeField {
let mut eq_evals_at_rz0 = vec![C::ChallengeField::zero(); 1 << rz0.len()];
let mut eq_evals_at_rz1 = vec![C::ChallengeField::zero(); 1 << rz1.len()];
let mut eq_evals_at_r_simd = vec![C::ChallengeField::zero(); 1 << r_simd.len()];

let mut eq_evals_at_rx = vec![C::ChallengeField::zero(); 1 << rx.len()];
let mut eq_evals_at_ry = vec![C::ChallengeField::zero(); 1 << ry.len()];
let mut eq_evals_at_r_simd_xy = vec![C::ChallengeField::zero(); 1 << r_simd_xy.len()];

eq_evals_at_primitive(rz0, &alpha, &mut eq_evals_at_rz0);
eq_evals_at_primitive(rz1, &beta, &mut eq_evals_at_rz1);
eq_evals_at_primitive(r_simd, &C::ChallengeField::one(), &mut eq_evals_at_r_simd);

eq_evals_at_primitive(rx, &C::ChallengeField::one(), &mut eq_evals_at_rx);
eq_evals_at_primitive(ry, &C::ChallengeField::one(), &mut eq_evals_at_ry);
eq_evals_at_primitive(
r_simd_xy,
&C::ChallengeField::one(),
&mut eq_evals_at_r_simd_xy,
);

if INPUT_NUM == 0 {
let mut v = C::ChallengeField::zero();

for cst_gate in gates {
v += C::challenge_mul_circuit_field(
&(eq_evals_at_rz0[cst_gate.o_id] + eq_evals_at_rz1[cst_gate.o_id]),
&cst_gate.coef,
);
}

let simd_sum: C::ChallengeField = eq_evals_at_r_simd.iter().sum();
v * simd_sum
} else if INPUT_NUM == 1 {
let mut v = C::ChallengeField::zero();
for add_gate in gates {
let tmp =
C::challenge_mul_circuit_field(&eq_evals_at_rx[add_gate.i_ids[0]], &add_gate.coef);
v += (eq_evals_at_rz0[add_gate.o_id] + eq_evals_at_rz1[add_gate.o_id]) * tmp;
}
v * _eq_vec(r_simd, r_simd_xy)
} else if INPUT_NUM == 2 {
let mut v = C::ChallengeField::zero();
for mul_gate in gates {
let tmp = eq_evals_at_rx[mul_gate.i_ids[0]]
* C::challenge_mul_circuit_field(
&eq_evals_at_ry[mul_gate.i_ids[1]],
&mul_gate.coef,
);
v += (eq_evals_at_rz0[mul_gate.o_id] + eq_evals_at_rz1[mul_gate.o_id]) * tmp;
}
v * _eq_vec(r_simd, r_simd_xy)
} else {
unreachable!()
}
}
use crate::{Circuit, CircuitLayer, Config, GKRConfig, Proof, RawCommitment, Transcript};

#[inline(always)]
fn verify_sumcheck_step<C: GKRConfig>(
Expand All @@ -165,6 +19,7 @@ fn verify_sumcheck_step<C: GKRConfig>(
transcript: &mut Transcript<C::FiatShamirHashType>,
claimed_sum: &mut C::ChallengeField,
randomness_vec: &mut Vec<C::ChallengeField>,
sp: &VerifierScratchPad<C>,
) -> bool {
let mut ps = vec![];
for i in 0..(degree + 1) {
Expand All @@ -178,9 +33,9 @@ fn verify_sumcheck_step<C: GKRConfig>(
let verified = (ps[0] + ps[1]) == *claimed_sum;

if degree == 2 {
*claimed_sum = degree_2_eval::<C>(ps[0], ps[1], ps[2], r);
*claimed_sum = GKRVerifierHelper::degree_2_eval(&ps, r, sp);
} else if degree == 3 {
*claimed_sum = degree_3_eval::<C>(ps[0], ps[1], ps[2], ps[3], r);
*claimed_sum = GKRVerifierHelper::degree_3_eval(&ps, r, sp);
}

verified
Expand All @@ -193,13 +48,14 @@ fn sumcheck_verify_gkr_layer<C: GKRConfig>(
layer: &CircuitLayer<C>,
rz0: &[C::ChallengeField],
rz1: &[C::ChallengeField],
r_simd0: &[C::ChallengeField],
r_simd: &Vec<C::ChallengeField>,
claimed_v0: C::ChallengeField,
claimed_v1: C::ChallengeField,
alpha: C::ChallengeField,
beta: C::ChallengeField,
proof: &mut Proof,
transcript: &mut Transcript<C::FiatShamirHashType>,
sp: &mut VerifierScratchPad<C>,
) -> (
bool,
Vec<C::ChallengeField>,
Expand All @@ -208,63 +64,42 @@ fn sumcheck_verify_gkr_layer<C: GKRConfig>(
C::ChallengeField,
C::ChallengeField,
) {
GKRVerifierHelper::prepare_layer(layer, &alpha, &beta, rz0, rz1, r_simd, sp);

let var_num = layer.input_var_num;
let simd_var_num = C::get_field_pack_size().trailing_zeros() as usize;
let mut sum = claimed_v0 * alpha + claimed_v1 * beta
- eval_sparse_circuit_connect_poly(
&layer.const_,
rz0,
rz1,
r_simd0,
alpha,
beta,
&[],
&[],
&[],
);
let mut sum =
claimed_v0 * alpha + claimed_v1 * beta - GKRVerifierHelper::eval_cst(&layer.const_, sp);

let mut rx = vec![];
let mut ry = vec![];
let mut r_simd_xy = vec![];
let mut verified = true;

for _i_var in 0..var_num {
verified &= verify_sumcheck_step::<C>(proof, 2, transcript, &mut sum, &mut rx);
verified &= verify_sumcheck_step::<C>(proof, 2, transcript, &mut sum, &mut rx, sp);
// println!("x {} var, verified? {}", _i_var, verified);
}
GKRVerifierHelper::set_rx(&rx, sp);

for _i_var in 0..simd_var_num {
verified &= verify_sumcheck_step::<C>(proof, 3, transcript, &mut sum, &mut r_simd_xy);
verified &= verify_sumcheck_step::<C>(proof, 3, transcript, &mut sum, &mut r_simd_xy, sp);
// println!("{} simd var, verified? {}", _i_var, verified);
}
GKRVerifierHelper::set_r_simd_xy(&r_simd_xy, sp);

let vx_claim = proof.get_next_and_step::<C::ChallengeField>();
sum -= vx_claim
* eval_sparse_circuit_connect_poly(
&layer.add,
rz0,
rz1,
r_simd0,
alpha,
beta,
&rx,
&[],
&r_simd_xy,
);
sum -= vx_claim * GKRVerifierHelper::eval_add(&layer.add, sp);
transcript.append_challenge_f::<C>(&vx_claim);

for _i_var in 0..var_num {
verified &= verify_sumcheck_step::<C>(proof, 2, transcript, &mut sum, &mut ry);
verified &= verify_sumcheck_step::<C>(proof, 2, transcript, &mut sum, &mut ry, sp);
// println!("y {} var, verified? {}", _i_var, verified);
}
GKRVerifierHelper::set_ry(&ry, sp);

let vy_claim = proof.get_next_and_step::<C::ChallengeField>();
verified &= sum
== vx_claim
* vy_claim
* eval_sparse_circuit_connect_poly(
&layer.mul, rz0, rz1, r_simd0, alpha, beta, &rx, &ry, &r_simd_xy,
);
verified &= sum == vx_claim * vy_claim * GKRVerifierHelper::eval_mul(&layer.mul, sp);
transcript.append_challenge_f::<C>(&vy_claim);
(verified, rx, ry, r_simd_xy, vx_claim, vy_claim)
}
Expand All @@ -285,6 +120,8 @@ pub fn gkr_verify<C: GKRConfig>(
C::ChallengeField,
) {
let timer = start_timer!(|| "gkr verify");
let mut sp = VerifierScratchPad::<C>::new(circuit);

let layer_num = circuit.layers.len();
let mut rz0 = vec![];
let mut rz1 = vec![];
Expand Down Expand Up @@ -318,6 +155,7 @@ pub fn gkr_verify<C: GKRConfig>(
beta,
proof,
transcript,
&mut sp,
);
verified &= cur_verified;
alpha = transcript.challenge_f::<C>();
Expand Down
Loading

0 comments on commit 373dbd6

Please sign in to comment.