Skip to content

Commit

Permalink
refactor: clean up witness package, introduces clean `witness.Witness…
Browse files Browse the repository at this point in the history
…` interface (#450)

* refactor: replace internal/witness by gnark crypto fr.Vector and decouple schema.Schema

* refactor: move benchmarks in groth16

* fix: fix previous commit

* style: correct couple of typos

* docs: add ExampleWitness and update package doc

* docs: update doc

* style: fix typo

* style: store fr.Vector instead of *fr.Vector in witness

* perf: do less work in witness.FromJSON

* test: added roundTripJSON witness test

* style: use fr.Vector instead of []fr.Element in backend prove / verify apis
  • Loading branch information
gbotrel authored Feb 1, 2023
1 parent 3f4cefa commit 0914308
Show file tree
Hide file tree
Showing 124 changed files with 1,441 additions and 4,706 deletions.
10 changes: 7 additions & 3 deletions backend/groth16/bellman_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,17 @@ func TestVerifyBellmanProof(t *testing.T) {

// verify groth16 proof
// we need to prepend the number of elements in the witness.
// witness package expects [nbPublic nbSecret] followed by [n | elements];
// note that n is redundant with nbPublic + nbSecret
var buf bytes.Buffer
_ = binary.Write(&buf, binary.BigEndian, uint32(len(inputsBytes)/(fr.Limbs*8)))
_ = binary.Write(&buf, binary.BigEndian, uint32(0))
_ = binary.Write(&buf, binary.BigEndian, uint32(len(inputsBytes)/(fr.Limbs*8)))
buf.Write(inputsBytes)

witness := &witness.Witness{
CurveID: ecc.BLS12_381,
}
witness, err := witness.New(ecc.BLS12_381.ScalarField())
require.NoError(t, err)

err = witness.UnmarshalBinary(buf.Bytes())
require.NoError(t, err)

Expand Down
74 changes: 37 additions & 37 deletions backend/groth16/groth16.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ import (
cs_bw6633 "github.com/consensys/gnark/constraint/bw6-633"
cs_bw6761 "github.com/consensys/gnark/constraint/bw6-761"

witness_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/witness"
witness_bls12381 "github.com/consensys/gnark/internal/backend/bls12-381/witness"
witness_bls24315 "github.com/consensys/gnark/internal/backend/bls24-315/witness"
witness_bls24317 "github.com/consensys/gnark/internal/backend/bls24-317/witness"
witness_bn254 "github.com/consensys/gnark/internal/backend/bn254/witness"
witness_bw6633 "github.com/consensys/gnark/internal/backend/bw6-633/witness"
witness_bw6761 "github.com/consensys/gnark/internal/backend/bw6-761/witness"
fr_bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
fr_bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
fr_bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr"
fr_bls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/fr"
fr_bn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr"
fr_bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr"
fr_bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr"

gnarkio "github.com/consensys/gnark/io"

Expand Down Expand Up @@ -109,51 +109,51 @@ type VerifyingKey interface {
}

// Verify runs the groth16.Verify algorithm on provided proof with given witness
func Verify(proof Proof, vk VerifyingKey, publicWitness *witness.Witness) error {
func Verify(proof Proof, vk VerifyingKey, publicWitness witness.Witness) error {

switch _proof := proof.(type) {
case *groth16_bls12377.Proof:
w, ok := publicWitness.Vector.(*witness_bls12377.Witness)
w, ok := publicWitness.Vector().(fr_bls12377.Vector)
if !ok {
return witness.ErrInvalidWitness
}
return groth16_bls12377.Verify(_proof, vk.(*groth16_bls12377.VerifyingKey), *w)
return groth16_bls12377.Verify(_proof, vk.(*groth16_bls12377.VerifyingKey), w)
case *groth16_bls12381.Proof:
w, ok := publicWitness.Vector.(*witness_bls12381.Witness)
w, ok := publicWitness.Vector().(fr_bls12381.Vector)
if !ok {
return witness.ErrInvalidWitness
}
return groth16_bls12381.Verify(_proof, vk.(*groth16_bls12381.VerifyingKey), *w)
return groth16_bls12381.Verify(_proof, vk.(*groth16_bls12381.VerifyingKey), w)
case *groth16_bn254.Proof:
w, ok := publicWitness.Vector.(*witness_bn254.Witness)
w, ok := publicWitness.Vector().(fr_bn254.Vector)
if !ok {
return witness.ErrInvalidWitness
}
return groth16_bn254.Verify(_proof, vk.(*groth16_bn254.VerifyingKey), *w)
return groth16_bn254.Verify(_proof, vk.(*groth16_bn254.VerifyingKey), w)
case *groth16_bw6761.Proof:
w, ok := publicWitness.Vector.(*witness_bw6761.Witness)
w, ok := publicWitness.Vector().(fr_bw6761.Vector)
if !ok {
return witness.ErrInvalidWitness
}
return groth16_bw6761.Verify(_proof, vk.(*groth16_bw6761.VerifyingKey), *w)
return groth16_bw6761.Verify(_proof, vk.(*groth16_bw6761.VerifyingKey), w)
case *groth16_bls24317.Proof:
w, ok := publicWitness.Vector.(*witness_bls24317.Witness)
w, ok := publicWitness.Vector().(fr_bls24317.Vector)
if !ok {
return witness.ErrInvalidWitness
}
return groth16_bls24317.Verify(_proof, vk.(*groth16_bls24317.VerifyingKey), *w)
return groth16_bls24317.Verify(_proof, vk.(*groth16_bls24317.VerifyingKey), w)
case *groth16_bls24315.Proof:
w, ok := publicWitness.Vector.(*witness_bls24315.Witness)
w, ok := publicWitness.Vector().(fr_bls24315.Vector)
if !ok {
return witness.ErrInvalidWitness
}
return groth16_bls24315.Verify(_proof, vk.(*groth16_bls24315.VerifyingKey), *w)
return groth16_bls24315.Verify(_proof, vk.(*groth16_bls24315.VerifyingKey), w)
case *groth16_bw6633.Proof:
w, ok := publicWitness.Vector.(*witness_bw6633.Witness)
w, ok := publicWitness.Vector().(fr_bw6633.Vector)
if !ok {
return witness.ErrInvalidWitness
}
return groth16_bw6633.Verify(_proof, vk.(*groth16_bw6633.VerifyingKey), *w)
return groth16_bw6633.Verify(_proof, vk.(*groth16_bw6633.VerifyingKey), w)
default:
panic("unrecognized R1CS curve type")
}
Expand All @@ -166,7 +166,7 @@ func Verify(proof Proof, vk VerifyingKey, publicWitness *witness.Witness) error
// will execute all the prover computations, even if the witness is invalid
// will produce an invalid proof
// internally, the solution vector to the R1CS will be filled with random values which may impact benchmarking
func Prove(r1cs constraint.ConstraintSystem, pk ProvingKey, fullWitness *witness.Witness, opts ...backend.ProverOption) (Proof, error) {
func Prove(r1cs constraint.ConstraintSystem, pk ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (Proof, error) {

// apply options
opt, err := backend.NewProverConfig(opts...)
Expand All @@ -176,47 +176,47 @@ func Prove(r1cs constraint.ConstraintSystem, pk ProvingKey, fullWitness *witness

switch _r1cs := r1cs.(type) {
case *cs_bls12377.R1CS:
w, ok := fullWitness.Vector.(*witness_bls12377.Witness)
w, ok := fullWitness.Vector().(fr_bls12377.Vector)
if !ok {
return nil, witness.ErrInvalidWitness
}
return groth16_bls12377.Prove(_r1cs, pk.(*groth16_bls12377.ProvingKey), *w, opt)
return groth16_bls12377.Prove(_r1cs, pk.(*groth16_bls12377.ProvingKey), w, opt)
case *cs_bls12381.R1CS:
w, ok := fullWitness.Vector.(*witness_bls12381.Witness)
w, ok := fullWitness.Vector().(fr_bls12381.Vector)
if !ok {
return nil, witness.ErrInvalidWitness
}
return groth16_bls12381.Prove(_r1cs, pk.(*groth16_bls12381.ProvingKey), *w, opt)
return groth16_bls12381.Prove(_r1cs, pk.(*groth16_bls12381.ProvingKey), w, opt)
case *cs_bn254.R1CS:
w, ok := fullWitness.Vector.(*witness_bn254.Witness)
w, ok := fullWitness.Vector().(fr_bn254.Vector)
if !ok {
return nil, witness.ErrInvalidWitness
}
return groth16_bn254.Prove(_r1cs, pk.(*groth16_bn254.ProvingKey), *w, opt)
return groth16_bn254.Prove(_r1cs, pk.(*groth16_bn254.ProvingKey), w, opt)
case *cs_bw6761.R1CS:
w, ok := fullWitness.Vector.(*witness_bw6761.Witness)
w, ok := fullWitness.Vector().(fr_bw6761.Vector)
if !ok {
return nil, witness.ErrInvalidWitness
}
return groth16_bw6761.Prove(_r1cs, pk.(*groth16_bw6761.ProvingKey), *w, opt)
return groth16_bw6761.Prove(_r1cs, pk.(*groth16_bw6761.ProvingKey), w, opt)
case *cs_bls24317.R1CS:
w, ok := fullWitness.Vector.(*witness_bls24317.Witness)
w, ok := fullWitness.Vector().(fr_bls24317.Vector)
if !ok {
return nil, witness.ErrInvalidWitness
}
return groth16_bls24317.Prove(_r1cs, pk.(*groth16_bls24317.ProvingKey), *w, opt)
return groth16_bls24317.Prove(_r1cs, pk.(*groth16_bls24317.ProvingKey), w, opt)
case *cs_bls24315.R1CS:
w, ok := fullWitness.Vector.(*witness_bls24315.Witness)
w, ok := fullWitness.Vector().(fr_bls24315.Vector)
if !ok {
return nil, witness.ErrInvalidWitness
}
return groth16_bls24315.Prove(_r1cs, pk.(*groth16_bls24315.ProvingKey), *w, opt)
return groth16_bls24315.Prove(_r1cs, pk.(*groth16_bls24315.ProvingKey), w, opt)
case *cs_bw6633.R1CS:
w, ok := fullWitness.Vector.(*witness_bw6633.Witness)
w, ok := fullWitness.Vector().(fr_bw6633.Vector)
if !ok {
return nil, witness.ErrInvalidWitness
}
return groth16_bw6633.Prove(_r1cs, pk.(*groth16_bw6633.ProvingKey), *w, opt)
return groth16_bw6633.Prove(_r1cs, pk.(*groth16_bw6633.ProvingKey), w, opt)
default:
panic("unrecognized R1CS curve type")
}
Expand Down
124 changes: 124 additions & 0 deletions backend/groth16/groth16_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package groth16_test

import (
"math/big"
"testing"

"github.com/consensys/gnark"
"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/backend/groth16"
"github.com/consensys/gnark/constraint"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/cs/r1cs"
)

//--------------------//
// benches //
//--------------------//

func BenchmarkSetup(b *testing.B) {
for _, curve := range getCurves() {
b.Run(curve.String(), func(b *testing.B) {
r1cs, _ := referenceCircuit(curve)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _, _ = groth16.Setup(r1cs)
}
})
}
}

func BenchmarkProver(b *testing.B) {
for _, curve := range getCurves() {
b.Run(curve.String(), func(b *testing.B) {
r1cs, _solution := referenceCircuit(curve)
fullWitness, err := frontend.NewWitness(_solution, curve.ScalarField())
if err != nil {
b.Fatal(err)
}
pk, err := groth16.DummySetup(r1cs)
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = groth16.Prove(r1cs, pk, fullWitness)
}
})
}
}

func BenchmarkVerifier(b *testing.B) {
for _, curve := range getCurves() {
b.Run(curve.String(), func(b *testing.B) {
r1cs, _solution := referenceCircuit(curve)
fullWitness, err := frontend.NewWitness(_solution, curve.ScalarField())
if err != nil {
b.Fatal(err)
}
publicWitness, err := fullWitness.Public()
if err != nil {
b.Fatal(err)
}

pk, vk, err := groth16.Setup(r1cs)
if err != nil {
b.Fatal(err)
}
proof, err := groth16.Prove(r1cs, pk, fullWitness)
if err != nil {
panic(err)
}

b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = groth16.Verify(proof, vk, publicWitness)
}
})
}
}

type refCircuit struct {
nbConstraints int
X frontend.Variable
Y frontend.Variable `gnark:",public"`
}

func (circuit *refCircuit) Define(api frontend.API) error {
for i := 0; i < circuit.nbConstraints; i++ {
circuit.X = api.Mul(circuit.X, circuit.X)
}
api.AssertIsEqual(circuit.X, circuit.Y)
return nil
}

func referenceCircuit(curve ecc.ID) (constraint.ConstraintSystem, frontend.Circuit) {
const nbConstraints = 40000
circuit := refCircuit{
nbConstraints: nbConstraints,
}
r1cs, err := frontend.Compile(curve.ScalarField(), r1cs.NewBuilder, &circuit)
if err != nil {
panic(err)
}

var good refCircuit
good.X = 2

// compute expected Y
expectedY := new(big.Int).SetUint64(2)
exp := big.NewInt(1)
exp.Lsh(exp, nbConstraints)
expectedY.Exp(expectedY, exp, curve.ScalarField())

good.Y = expectedY

return r1cs, &good
}

func getCurves() []ecc.ID {
if testing.Short() {
return []ecc.ID{ecc.BN254}
}
return gnark.Curves()
}
Loading

0 comments on commit 0914308

Please sign in to comment.