diff --git a/backend/groth16/bellman_test.go b/backend/groth16/bellman_test.go index f53fd49505..af74b4e47d 100644 --- a/backend/groth16/bellman_test.go +++ b/backend/groth16/bellman_test.go @@ -106,13 +106,17 @@ func TestVerifyBellmanProof(t *testing.T) { // verify groth16 proof // we need to prepend the number of elements in the witness. + // witness package expects [nbPublic nbSecret] followed by [n | elements]; + // note that n is redundant with nbPublic + nbSecret var buf bytes.Buffer _ = binary.Write(&buf, binary.BigEndian, uint32(len(inputsBytes)/(fr.Limbs*8))) + _ = binary.Write(&buf, binary.BigEndian, uint32(0)) + _ = binary.Write(&buf, binary.BigEndian, uint32(len(inputsBytes)/(fr.Limbs*8))) buf.Write(inputsBytes) - witness := &witness.Witness{ - CurveID: ecc.BLS12_381, - } + witness, err := witness.New(ecc.BLS12_381.ScalarField()) + require.NoError(t, err) + err = witness.UnmarshalBinary(buf.Bytes()) require.NoError(t, err) diff --git a/backend/groth16/groth16.go b/backend/groth16/groth16.go index 623654a8c6..1aefdfa072 100644 --- a/backend/groth16/groth16.go +++ b/backend/groth16/groth16.go @@ -34,13 +34,13 @@ import ( cs_bw6633 "github.com/consensys/gnark/constraint/bw6-633" cs_bw6761 "github.com/consensys/gnark/constraint/bw6-761" - witness_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/witness" - witness_bls12381 "github.com/consensys/gnark/internal/backend/bls12-381/witness" - witness_bls24315 "github.com/consensys/gnark/internal/backend/bls24-315/witness" - witness_bls24317 "github.com/consensys/gnark/internal/backend/bls24-317/witness" - witness_bn254 "github.com/consensys/gnark/internal/backend/bn254/witness" - witness_bw6633 "github.com/consensys/gnark/internal/backend/bw6-633/witness" - witness_bw6761 "github.com/consensys/gnark/internal/backend/bw6-761/witness" + fr_bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + fr_bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + fr_bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + fr_bls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + fr_bn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr" + fr_bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + fr_bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" gnarkio "github.com/consensys/gnark/io" @@ -109,51 +109,51 @@ type VerifyingKey interface { } // Verify runs the groth16.Verify algorithm on provided proof with given witness -func Verify(proof Proof, vk VerifyingKey, publicWitness *witness.Witness) error { +func Verify(proof Proof, vk VerifyingKey, publicWitness witness.Witness) error { switch _proof := proof.(type) { case *groth16_bls12377.Proof: - w, ok := publicWitness.Vector.(*witness_bls12377.Witness) + w, ok := publicWitness.Vector().(fr_bls12377.Vector) if !ok { return witness.ErrInvalidWitness } - return groth16_bls12377.Verify(_proof, vk.(*groth16_bls12377.VerifyingKey), *w) + return groth16_bls12377.Verify(_proof, vk.(*groth16_bls12377.VerifyingKey), w) case *groth16_bls12381.Proof: - w, ok := publicWitness.Vector.(*witness_bls12381.Witness) + w, ok := publicWitness.Vector().(fr_bls12381.Vector) if !ok { return witness.ErrInvalidWitness } - return groth16_bls12381.Verify(_proof, vk.(*groth16_bls12381.VerifyingKey), *w) + return groth16_bls12381.Verify(_proof, vk.(*groth16_bls12381.VerifyingKey), w) case *groth16_bn254.Proof: - w, ok := publicWitness.Vector.(*witness_bn254.Witness) + w, ok := publicWitness.Vector().(fr_bn254.Vector) if !ok { return witness.ErrInvalidWitness } - return groth16_bn254.Verify(_proof, vk.(*groth16_bn254.VerifyingKey), *w) + return groth16_bn254.Verify(_proof, vk.(*groth16_bn254.VerifyingKey), w) case *groth16_bw6761.Proof: - w, ok := publicWitness.Vector.(*witness_bw6761.Witness) + w, ok := publicWitness.Vector().(fr_bw6761.Vector) if !ok { return witness.ErrInvalidWitness } - return groth16_bw6761.Verify(_proof, vk.(*groth16_bw6761.VerifyingKey), *w) + return groth16_bw6761.Verify(_proof, vk.(*groth16_bw6761.VerifyingKey), w) case *groth16_bls24317.Proof: - w, ok := publicWitness.Vector.(*witness_bls24317.Witness) + w, ok := publicWitness.Vector().(fr_bls24317.Vector) if !ok { return witness.ErrInvalidWitness } - return groth16_bls24317.Verify(_proof, vk.(*groth16_bls24317.VerifyingKey), *w) + return groth16_bls24317.Verify(_proof, vk.(*groth16_bls24317.VerifyingKey), w) case *groth16_bls24315.Proof: - w, ok := publicWitness.Vector.(*witness_bls24315.Witness) + w, ok := publicWitness.Vector().(fr_bls24315.Vector) if !ok { return witness.ErrInvalidWitness } - return groth16_bls24315.Verify(_proof, vk.(*groth16_bls24315.VerifyingKey), *w) + return groth16_bls24315.Verify(_proof, vk.(*groth16_bls24315.VerifyingKey), w) case *groth16_bw6633.Proof: - w, ok := publicWitness.Vector.(*witness_bw6633.Witness) + w, ok := publicWitness.Vector().(fr_bw6633.Vector) if !ok { return witness.ErrInvalidWitness } - return groth16_bw6633.Verify(_proof, vk.(*groth16_bw6633.VerifyingKey), *w) + return groth16_bw6633.Verify(_proof, vk.(*groth16_bw6633.VerifyingKey), w) default: panic("unrecognized R1CS curve type") } @@ -166,7 +166,7 @@ func Verify(proof Proof, vk VerifyingKey, publicWitness *witness.Witness) error // will execute all the prover computations, even if the witness is invalid // will produce an invalid proof // internally, the solution vector to the R1CS will be filled with random values which may impact benchmarking -func Prove(r1cs constraint.ConstraintSystem, pk ProvingKey, fullWitness *witness.Witness, opts ...backend.ProverOption) (Proof, error) { +func Prove(r1cs constraint.ConstraintSystem, pk ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (Proof, error) { // apply options opt, err := backend.NewProverConfig(opts...) @@ -176,47 +176,47 @@ func Prove(r1cs constraint.ConstraintSystem, pk ProvingKey, fullWitness *witness switch _r1cs := r1cs.(type) { case *cs_bls12377.R1CS: - w, ok := fullWitness.Vector.(*witness_bls12377.Witness) + w, ok := fullWitness.Vector().(fr_bls12377.Vector) if !ok { return nil, witness.ErrInvalidWitness } - return groth16_bls12377.Prove(_r1cs, pk.(*groth16_bls12377.ProvingKey), *w, opt) + return groth16_bls12377.Prove(_r1cs, pk.(*groth16_bls12377.ProvingKey), w, opt) case *cs_bls12381.R1CS: - w, ok := fullWitness.Vector.(*witness_bls12381.Witness) + w, ok := fullWitness.Vector().(fr_bls12381.Vector) if !ok { return nil, witness.ErrInvalidWitness } - return groth16_bls12381.Prove(_r1cs, pk.(*groth16_bls12381.ProvingKey), *w, opt) + return groth16_bls12381.Prove(_r1cs, pk.(*groth16_bls12381.ProvingKey), w, opt) case *cs_bn254.R1CS: - w, ok := fullWitness.Vector.(*witness_bn254.Witness) + w, ok := fullWitness.Vector().(fr_bn254.Vector) if !ok { return nil, witness.ErrInvalidWitness } - return groth16_bn254.Prove(_r1cs, pk.(*groth16_bn254.ProvingKey), *w, opt) + return groth16_bn254.Prove(_r1cs, pk.(*groth16_bn254.ProvingKey), w, opt) case *cs_bw6761.R1CS: - w, ok := fullWitness.Vector.(*witness_bw6761.Witness) + w, ok := fullWitness.Vector().(fr_bw6761.Vector) if !ok { return nil, witness.ErrInvalidWitness } - return groth16_bw6761.Prove(_r1cs, pk.(*groth16_bw6761.ProvingKey), *w, opt) + return groth16_bw6761.Prove(_r1cs, pk.(*groth16_bw6761.ProvingKey), w, opt) case *cs_bls24317.R1CS: - w, ok := fullWitness.Vector.(*witness_bls24317.Witness) + w, ok := fullWitness.Vector().(fr_bls24317.Vector) if !ok { return nil, witness.ErrInvalidWitness } - return groth16_bls24317.Prove(_r1cs, pk.(*groth16_bls24317.ProvingKey), *w, opt) + return groth16_bls24317.Prove(_r1cs, pk.(*groth16_bls24317.ProvingKey), w, opt) case *cs_bls24315.R1CS: - w, ok := fullWitness.Vector.(*witness_bls24315.Witness) + w, ok := fullWitness.Vector().(fr_bls24315.Vector) if !ok { return nil, witness.ErrInvalidWitness } - return groth16_bls24315.Prove(_r1cs, pk.(*groth16_bls24315.ProvingKey), *w, opt) + return groth16_bls24315.Prove(_r1cs, pk.(*groth16_bls24315.ProvingKey), w, opt) case *cs_bw6633.R1CS: - w, ok := fullWitness.Vector.(*witness_bw6633.Witness) + w, ok := fullWitness.Vector().(fr_bw6633.Vector) if !ok { return nil, witness.ErrInvalidWitness } - return groth16_bw6633.Prove(_r1cs, pk.(*groth16_bw6633.ProvingKey), *w, opt) + return groth16_bw6633.Prove(_r1cs, pk.(*groth16_bw6633.ProvingKey), w, opt) default: panic("unrecognized R1CS curve type") } diff --git a/backend/groth16/groth16_test.go b/backend/groth16/groth16_test.go new file mode 100644 index 0000000000..24fca03aed --- /dev/null +++ b/backend/groth16/groth16_test.go @@ -0,0 +1,124 @@ +package groth16_test + +import ( + "math/big" + "testing" + + "github.com/consensys/gnark" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" +) + +//--------------------// +// benches // +//--------------------// + +func BenchmarkSetup(b *testing.B) { + for _, curve := range getCurves() { + b.Run(curve.String(), func(b *testing.B) { + r1cs, _ := referenceCircuit(curve) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, _ = groth16.Setup(r1cs) + } + }) + } +} + +func BenchmarkProver(b *testing.B) { + for _, curve := range getCurves() { + b.Run(curve.String(), func(b *testing.B) { + r1cs, _solution := referenceCircuit(curve) + fullWitness, err := frontend.NewWitness(_solution, curve.ScalarField()) + if err != nil { + b.Fatal(err) + } + pk, err := groth16.DummySetup(r1cs) + if err != nil { + b.Fatal(err) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = groth16.Prove(r1cs, pk, fullWitness) + } + }) + } +} + +func BenchmarkVerifier(b *testing.B) { + for _, curve := range getCurves() { + b.Run(curve.String(), func(b *testing.B) { + r1cs, _solution := referenceCircuit(curve) + fullWitness, err := frontend.NewWitness(_solution, curve.ScalarField()) + if err != nil { + b.Fatal(err) + } + publicWitness, err := fullWitness.Public() + if err != nil { + b.Fatal(err) + } + + pk, vk, err := groth16.Setup(r1cs) + if err != nil { + b.Fatal(err) + } + proof, err := groth16.Prove(r1cs, pk, fullWitness) + if err != nil { + panic(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = groth16.Verify(proof, vk, publicWitness) + } + }) + } +} + +type refCircuit struct { + nbConstraints int + X frontend.Variable + Y frontend.Variable `gnark:",public"` +} + +func (circuit *refCircuit) Define(api frontend.API) error { + for i := 0; i < circuit.nbConstraints; i++ { + circuit.X = api.Mul(circuit.X, circuit.X) + } + api.AssertIsEqual(circuit.X, circuit.Y) + return nil +} + +func referenceCircuit(curve ecc.ID) (constraint.ConstraintSystem, frontend.Circuit) { + const nbConstraints = 40000 + circuit := refCircuit{ + nbConstraints: nbConstraints, + } + r1cs, err := frontend.Compile(curve.ScalarField(), r1cs.NewBuilder, &circuit) + if err != nil { + panic(err) + } + + var good refCircuit + good.X = 2 + + // compute expected Y + expectedY := new(big.Int).SetUint64(2) + exp := big.NewInt(1) + exp.Lsh(exp, nbConstraints) + expectedY.Exp(expectedY, exp, curve.ScalarField()) + + good.Y = expectedY + + return r1cs, &good +} + +func getCurves() []ecc.ID { + if testing.Short() { + return []ecc.ID{ecc.BN254} + } + return gnark.Curves() +} diff --git a/backend/plonk/plonk.go b/backend/plonk/plonk.go index a3247875ba..0a6fbc4330 100644 --- a/backend/plonk/plonk.go +++ b/backend/plonk/plonk.go @@ -44,13 +44,13 @@ import ( plonk_bw6633 "github.com/consensys/gnark/internal/backend/bw6-633/plonk" plonk_bw6761 "github.com/consensys/gnark/internal/backend/bw6-761/plonk" - witness_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/witness" - witness_bls12381 "github.com/consensys/gnark/internal/backend/bls12-381/witness" - witness_bls24315 "github.com/consensys/gnark/internal/backend/bls24-315/witness" - witness_bls24317 "github.com/consensys/gnark/internal/backend/bls24-317/witness" - witness_bn254 "github.com/consensys/gnark/internal/backend/bn254/witness" - witness_bw6633 "github.com/consensys/gnark/internal/backend/bw6-633/witness" - witness_bw6761 "github.com/consensys/gnark/internal/backend/bw6-761/witness" + fr_bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + fr_bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + fr_bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + fr_bls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + fr_bn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr" + fr_bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + fr_bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" kzg_bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/kzg" kzg_bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/kzg" @@ -123,7 +123,7 @@ func Setup(ccs constraint.ConstraintSystem, kzgSRS kzg.SRS) (ProvingKey, Verifyi // will executes all the prover computations, even if the witness is invalid // will produce an invalid proof // internally, the solution vector to the SparseR1CS will be filled with random values which may impact benchmarking -func Prove(ccs constraint.ConstraintSystem, pk ProvingKey, fullWitness *witness.Witness, opts ...backend.ProverOption) (Proof, error) { +func Prove(ccs constraint.ConstraintSystem, pk ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (Proof, error) { // apply options opt, err := backend.NewProverConfig(opts...) @@ -133,53 +133,53 @@ func Prove(ccs constraint.ConstraintSystem, pk ProvingKey, fullWitness *witness. switch tccs := ccs.(type) { case *cs_bn254.SparseR1CS: - w, ok := fullWitness.Vector.(*witness_bn254.Witness) + w, ok := fullWitness.Vector().(fr_bn254.Vector) if !ok { return nil, witness.ErrInvalidWitness } - return plonk_bn254.Prove(tccs, pk.(*plonk_bn254.ProvingKey), *w, opt) + return plonk_bn254.Prove(tccs, pk.(*plonk_bn254.ProvingKey), w, opt) case *cs_bls12381.SparseR1CS: - w, ok := fullWitness.Vector.(*witness_bls12381.Witness) + w, ok := fullWitness.Vector().(fr_bls12381.Vector) if !ok { return nil, witness.ErrInvalidWitness } - return plonk_bls12381.Prove(tccs, pk.(*plonk_bls12381.ProvingKey), *w, opt) + return plonk_bls12381.Prove(tccs, pk.(*plonk_bls12381.ProvingKey), w, opt) case *cs_bls12377.SparseR1CS: - w, ok := fullWitness.Vector.(*witness_bls12377.Witness) + w, ok := fullWitness.Vector().(fr_bls12377.Vector) if !ok { return nil, witness.ErrInvalidWitness } - return plonk_bls12377.Prove(tccs, pk.(*plonk_bls12377.ProvingKey), *w, opt) + return plonk_bls12377.Prove(tccs, pk.(*plonk_bls12377.ProvingKey), w, opt) case *cs_bw6761.SparseR1CS: - w, ok := fullWitness.Vector.(*witness_bw6761.Witness) + w, ok := fullWitness.Vector().(fr_bw6761.Vector) if !ok { return nil, witness.ErrInvalidWitness } - return plonk_bw6761.Prove(tccs, pk.(*plonk_bw6761.ProvingKey), *w, opt) + return plonk_bw6761.Prove(tccs, pk.(*plonk_bw6761.ProvingKey), w, opt) case *cs_bw6633.SparseR1CS: - w, ok := fullWitness.Vector.(*witness_bw6633.Witness) + w, ok := fullWitness.Vector().(fr_bw6633.Vector) if !ok { return nil, witness.ErrInvalidWitness } - return plonk_bw6633.Prove(tccs, pk.(*plonk_bw6633.ProvingKey), *w, opt) + return plonk_bw6633.Prove(tccs, pk.(*plonk_bw6633.ProvingKey), w, opt) case *cs_bls24317.SparseR1CS: - w, ok := fullWitness.Vector.(*witness_bls24317.Witness) + w, ok := fullWitness.Vector().(fr_bls24317.Vector) if !ok { return nil, witness.ErrInvalidWitness } - return plonk_bls24317.Prove(tccs, pk.(*plonk_bls24317.ProvingKey), *w, opt) + return plonk_bls24317.Prove(tccs, pk.(*plonk_bls24317.ProvingKey), w, opt) case *cs_bls24315.SparseR1CS: - w, ok := fullWitness.Vector.(*witness_bls24315.Witness) + w, ok := fullWitness.Vector().(fr_bls24315.Vector) if !ok { return nil, witness.ErrInvalidWitness } - return plonk_bls24315.Prove(tccs, pk.(*plonk_bls24315.ProvingKey), *w, opt) + return plonk_bls24315.Prove(tccs, pk.(*plonk_bls24315.ProvingKey), w, opt) default: panic("unrecognized SparseR1CS curve type") @@ -187,58 +187,58 @@ func Prove(ccs constraint.ConstraintSystem, pk ProvingKey, fullWitness *witness. } // Verify verifies a PLONK proof, from the proof, preprocessed public data, and public witness. -func Verify(proof Proof, vk VerifyingKey, publicWitness *witness.Witness) error { +func Verify(proof Proof, vk VerifyingKey, publicWitness witness.Witness) error { switch _proof := proof.(type) { case *plonk_bn254.Proof: - w, ok := publicWitness.Vector.(*witness_bn254.Witness) + w, ok := publicWitness.Vector().(fr_bn254.Vector) if !ok { return witness.ErrInvalidWitness } - return plonk_bn254.Verify(_proof, vk.(*plonk_bn254.VerifyingKey), *w) + return plonk_bn254.Verify(_proof, vk.(*plonk_bn254.VerifyingKey), w) case *plonk_bls12381.Proof: - w, ok := publicWitness.Vector.(*witness_bls12381.Witness) + w, ok := publicWitness.Vector().(fr_bls12381.Vector) if !ok { return witness.ErrInvalidWitness } - return plonk_bls12381.Verify(_proof, vk.(*plonk_bls12381.VerifyingKey), *w) + return plonk_bls12381.Verify(_proof, vk.(*plonk_bls12381.VerifyingKey), w) case *plonk_bls12377.Proof: - w, ok := publicWitness.Vector.(*witness_bls12377.Witness) + w, ok := publicWitness.Vector().(fr_bls12377.Vector) if !ok { return witness.ErrInvalidWitness } - return plonk_bls12377.Verify(_proof, vk.(*plonk_bls12377.VerifyingKey), *w) + return plonk_bls12377.Verify(_proof, vk.(*plonk_bls12377.VerifyingKey), w) case *plonk_bw6761.Proof: - w, ok := publicWitness.Vector.(*witness_bw6761.Witness) + w, ok := publicWitness.Vector().(fr_bw6761.Vector) if !ok { return witness.ErrInvalidWitness } - return plonk_bw6761.Verify(_proof, vk.(*plonk_bw6761.VerifyingKey), *w) + return plonk_bw6761.Verify(_proof, vk.(*plonk_bw6761.VerifyingKey), w) case *plonk_bw6633.Proof: - w, ok := publicWitness.Vector.(*witness_bw6633.Witness) + w, ok := publicWitness.Vector().(fr_bw6633.Vector) if !ok { return witness.ErrInvalidWitness } - return plonk_bw6633.Verify(_proof, vk.(*plonk_bw6633.VerifyingKey), *w) + return plonk_bw6633.Verify(_proof, vk.(*plonk_bw6633.VerifyingKey), w) case *plonk_bls24317.Proof: - w, ok := publicWitness.Vector.(*witness_bls24317.Witness) + w, ok := publicWitness.Vector().(fr_bls24317.Vector) if !ok { return witness.ErrInvalidWitness } - return plonk_bls24317.Verify(_proof, vk.(*plonk_bls24317.VerifyingKey), *w) + return plonk_bls24317.Verify(_proof, vk.(*plonk_bls24317.VerifyingKey), w) case *plonk_bls24315.Proof: - w, ok := publicWitness.Vector.(*witness_bls24315.Witness) + w, ok := publicWitness.Vector().(fr_bls24315.Vector) if !ok { return witness.ErrInvalidWitness } - return plonk_bls24315.Verify(_proof, vk.(*plonk_bls24315.VerifyingKey), *w) + return plonk_bls24315.Verify(_proof, vk.(*plonk_bls24315.VerifyingKey), w) default: panic("unrecognized proof type") diff --git a/backend/plonkfri/plonkfri.go b/backend/plonkfri/plonkfri.go index 89b975d0fb..b2cab82fd6 100644 --- a/backend/plonkfri/plonkfri.go +++ b/backend/plonkfri/plonkfri.go @@ -36,15 +36,15 @@ import ( plonk_bw6633 "github.com/consensys/gnark/internal/backend/bw6-633/plonkfri" plonk_bw6761 "github.com/consensys/gnark/internal/backend/bw6-761/plonkfri" - witness_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/witness" - witness_bls12381 "github.com/consensys/gnark/internal/backend/bls12-381/witness" - witness_bls24315 "github.com/consensys/gnark/internal/backend/bls24-315/witness" - witness_bn254 "github.com/consensys/gnark/internal/backend/bn254/witness" - witness_bw6633 "github.com/consensys/gnark/internal/backend/bw6-633/witness" - witness_bw6761 "github.com/consensys/gnark/internal/backend/bw6-761/witness" + fr_bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + fr_bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + fr_bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + fr_bls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + fr_bn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr" + fr_bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + fr_bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" plonk_bls24317 "github.com/consensys/gnark/internal/backend/bls24-317/plonkfri" - witness_bls24317 "github.com/consensys/gnark/internal/backend/bls24-317/witness" ) // Proof represents a Plonk proof generated by plonk.Prove @@ -104,7 +104,7 @@ func Setup(ccs constraint.ConstraintSystem) (ProvingKey, VerifyingKey, error) { // will executes all the prover computations, even if the witness is invalid // will produce an invalid proof // internally, the solution vector to the SparseR1CS will be filled with random values which may impact benchmarking -func Prove(ccs constraint.ConstraintSystem, pk ProvingKey, fullWitness *witness.Witness, opts ...backend.ProverOption) (Proof, error) { +func Prove(ccs constraint.ConstraintSystem, pk ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (Proof, error) { // apply options opt, err := backend.NewProverConfig(opts...) @@ -114,109 +114,109 @@ func Prove(ccs constraint.ConstraintSystem, pk ProvingKey, fullWitness *witness. switch tccs := ccs.(type) { case *cs_bn254.SparseR1CS: - w, ok := fullWitness.Vector.(*witness_bn254.Witness) + w, ok := fullWitness.Vector().(fr_bn254.Vector) if !ok { return nil, witness.ErrInvalidWitness } - return plonk_bn254.Prove(tccs, pk.(*plonk_bn254.ProvingKey), *w, opt) + return plonk_bn254.Prove(tccs, pk.(*plonk_bn254.ProvingKey), w, opt) case *cs_bls12381.SparseR1CS: - w, ok := fullWitness.Vector.(*witness_bls12381.Witness) + w, ok := fullWitness.Vector().(fr_bls12381.Vector) if !ok { return nil, witness.ErrInvalidWitness } - return plonk_bls12381.Prove(tccs, pk.(*plonk_bls12381.ProvingKey), *w, opt) + return plonk_bls12381.Prove(tccs, pk.(*plonk_bls12381.ProvingKey), w, opt) case *cs_bls12377.SparseR1CS: - w, ok := fullWitness.Vector.(*witness_bls12377.Witness) + w, ok := fullWitness.Vector().(fr_bls12377.Vector) if !ok { return nil, witness.ErrInvalidWitness } - return plonk_bls12377.Prove(tccs, pk.(*plonk_bls12377.ProvingKey), *w, opt) + return plonk_bls12377.Prove(tccs, pk.(*plonk_bls12377.ProvingKey), w, opt) case *cs_bw6761.SparseR1CS: - w, ok := fullWitness.Vector.(*witness_bw6761.Witness) + w, ok := fullWitness.Vector().(fr_bw6761.Vector) if !ok { return nil, witness.ErrInvalidWitness } - return plonk_bw6761.Prove(tccs, pk.(*plonk_bw6761.ProvingKey), *w, opt) + return plonk_bw6761.Prove(tccs, pk.(*plonk_bw6761.ProvingKey), w, opt) case *cs_bw6633.SparseR1CS: - w, ok := fullWitness.Vector.(*witness_bw6633.Witness) + w, ok := fullWitness.Vector().(fr_bw6633.Vector) if !ok { return nil, witness.ErrInvalidWitness } - return plonk_bw6633.Prove(tccs, pk.(*plonk_bw6633.ProvingKey), *w, opt) + return plonk_bw6633.Prove(tccs, pk.(*plonk_bw6633.ProvingKey), w, opt) case *cs_bls24315.SparseR1CS: - w, ok := fullWitness.Vector.(*witness_bls24315.Witness) + w, ok := fullWitness.Vector().(fr_bls24315.Vector) if !ok { return nil, witness.ErrInvalidWitness } - return plonk_bls24315.Prove(tccs, pk.(*plonk_bls24315.ProvingKey), *w, opt) + return plonk_bls24315.Prove(tccs, pk.(*plonk_bls24315.ProvingKey), w, opt) case *cs_bls24317.SparseR1CS: - w, ok := fullWitness.Vector.(*witness_bls24317.Witness) + w, ok := fullWitness.Vector().(fr_bls24317.Vector) if !ok { return nil, witness.ErrInvalidWitness } - return plonk_bls24317.Prove(tccs, pk.(*plonk_bls24317.ProvingKey), *w, opt) + return plonk_bls24317.Prove(tccs, pk.(*plonk_bls24317.ProvingKey), w, opt) default: panic("unrecognized SparseR1CS curve type") } } // Verify verifies a PLONK proof, from the proof, preprocessed public data, and public witness. -func Verify(proof Proof, vk VerifyingKey, publicWitness *witness.Witness) error { +func Verify(proof Proof, vk VerifyingKey, publicWitness witness.Witness) error { switch _proof := proof.(type) { case *plonk_bn254.Proof: - w, ok := publicWitness.Vector.(*witness_bn254.Witness) + w, ok := publicWitness.Vector().(fr_bn254.Vector) if !ok { return witness.ErrInvalidWitness } - return plonk_bn254.Verify(_proof, vk.(*plonk_bn254.VerifyingKey), *w) + return plonk_bn254.Verify(_proof, vk.(*plonk_bn254.VerifyingKey), w) case *plonk_bls12381.Proof: - w, ok := publicWitness.Vector.(*witness_bls12381.Witness) + w, ok := publicWitness.Vector().(fr_bls12381.Vector) if !ok { return witness.ErrInvalidWitness } - return plonk_bls12381.Verify(_proof, vk.(*plonk_bls12381.VerifyingKey), *w) + return plonk_bls12381.Verify(_proof, vk.(*plonk_bls12381.VerifyingKey), w) case *plonk_bls12377.Proof: - w, ok := publicWitness.Vector.(*witness_bls12377.Witness) + w, ok := publicWitness.Vector().(fr_bls12377.Vector) if !ok { return witness.ErrInvalidWitness } - return plonk_bls12377.Verify(_proof, vk.(*plonk_bls12377.VerifyingKey), *w) + return plonk_bls12377.Verify(_proof, vk.(*plonk_bls12377.VerifyingKey), w) case *plonk_bw6761.Proof: - w, ok := publicWitness.Vector.(*witness_bw6761.Witness) + w, ok := publicWitness.Vector().(fr_bw6761.Vector) if !ok { return witness.ErrInvalidWitness } - return plonk_bw6761.Verify(_proof, vk.(*plonk_bw6761.VerifyingKey), *w) + return plonk_bw6761.Verify(_proof, vk.(*plonk_bw6761.VerifyingKey), w) case *plonk_bw6633.Proof: - w, ok := publicWitness.Vector.(*witness_bw6633.Witness) + w, ok := publicWitness.Vector().(fr_bw6633.Vector) if !ok { return witness.ErrInvalidWitness } - return plonk_bw6633.Verify(_proof, vk.(*plonk_bw6633.VerifyingKey), *w) + return plonk_bw6633.Verify(_proof, vk.(*plonk_bw6633.VerifyingKey), w) case *plonk_bls24315.Proof: - w, ok := publicWitness.Vector.(*witness_bls24315.Witness) + w, ok := publicWitness.Vector().(fr_bls24315.Vector) if !ok { return witness.ErrInvalidWitness } - return plonk_bls24315.Verify(_proof, vk.(*plonk_bls24315.VerifyingKey), *w) + return plonk_bls24315.Verify(_proof, vk.(*plonk_bls24315.VerifyingKey), w) case *plonk_bls24317.Proof: - w, ok := publicWitness.Vector.(*witness_bls24317.Witness) + w, ok := publicWitness.Vector().(fr_bls24317.Vector) if !ok { return witness.ErrInvalidWitness } - return plonk_bls24317.Verify(_proof, vk.(*plonk_bls24317.VerifyingKey), *w) + return plonk_bls24317.Verify(_proof, vk.(*plonk_bls24317.VerifyingKey), w) default: panic("unrecognized proof type") diff --git a/backend/witness/vector.go b/backend/witness/vector.go index adb2565189..248e293ab5 100644 --- a/backend/witness/vector.go +++ b/backend/witness/vector.go @@ -1,92 +1,249 @@ package witness import ( - "io" + "errors" "math/big" "reflect" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/frontend/schema" - witness_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/witness" - witness_bls12381 "github.com/consensys/gnark/internal/backend/bls12-381/witness" - witness_bls24315 "github.com/consensys/gnark/internal/backend/bls24-315/witness" - witness_bls24317 "github.com/consensys/gnark/internal/backend/bls24-317/witness" - witness_bn254 "github.com/consensys/gnark/internal/backend/bn254/witness" - witness_bw6633 "github.com/consensys/gnark/internal/backend/bw6-633/witness" - witness_bw6761 "github.com/consensys/gnark/internal/backend/bw6-761/witness" + fr_bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + fr_bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + fr_bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + fr_bls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + fr_bn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr" + fr_bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + fr_bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" "github.com/consensys/gnark/internal/tinyfield" - witness_tinyfield "github.com/consensys/gnark/internal/tinyfield/witness" "github.com/consensys/gnark/internal/utils" ) -type Vector interface { - io.WriterTo - io.ReaderFrom - FromAssignment(assignment interface{}, leafType reflect.Type, publicOnly bool) (*schema.Schema, error) - ToAssignment(assigment interface{}, leafType reflect.Type, publicOnly bool) - Len() int - Type() reflect.Type -} - -func newVector(field *big.Int) (Vector, error) { - var w Vector +func newVector(field *big.Int, size int) (any, error) { curveID := utils.FieldToCurve(field) switch curveID { case ecc.BN254: - w = &witness_bn254.Witness{} + return make(fr_bn254.Vector, size), nil case ecc.BLS12_377: - w = &witness_bls12377.Witness{} + return make(fr_bls12377.Vector, size), nil case ecc.BLS12_381: - w = &witness_bls12381.Witness{} + return make(fr_bls12381.Vector, size), nil case ecc.BW6_761: - w = &witness_bw6761.Witness{} + return make(fr_bw6761.Vector, size), nil case ecc.BLS24_317: - w = &witness_bls24317.Witness{} + return make(fr_bls24317.Vector, size), nil case ecc.BLS24_315: - w = &witness_bls24315.Witness{} + return make(fr_bls24315.Vector, size), nil case ecc.BW6_633: - w = &witness_bw6633.Witness{} + return make(fr_bw6633.Vector, size), nil default: if field.Cmp(tinyfield.Modulus()) == 0 { - w = &witness_tinyfield.Witness{} + return make(tinyfield.Vector, size), nil } else { - return nil, errMissingCurveID + return nil, errors.New("unsupported modulus") } } - return w, nil } -func newFrom(from Vector, n int) (Vector, error) { +func newFrom(from any, n int) (any, error) { switch wt := from.(type) { - case *witness_bn254.Witness: - a := make(witness_bn254.Witness, n) - copy(a, *wt) - return &a, nil - case *witness_bls12377.Witness: - a := make(witness_bls12377.Witness, n) - copy(a, *wt) - return &a, nil - case *witness_bls12381.Witness: - a := make(witness_bls12381.Witness, n) - copy(a, *wt) - return &a, nil - case *witness_bw6761.Witness: - a := make(witness_bw6761.Witness, n) - copy(a, *wt) - return &a, nil - case *witness_bls24317.Witness: - a := make(witness_bls24317.Witness, n) - copy(a, *wt) - return &a, nil - case *witness_bls24315.Witness: - a := make(witness_bls24315.Witness, n) - copy(a, *wt) - return &a, nil - case *witness_bw6633.Witness: - a := make(witness_bw6633.Witness, n) - copy(a, *wt) - return &a, nil + case fr_bn254.Vector: + a := make(fr_bn254.Vector, n) + copy(a, wt) + return a, nil + case fr_bls12377.Vector: + a := make(fr_bls12377.Vector, n) + copy(a, wt) + return a, nil + case fr_bls12381.Vector: + a := make(fr_bls12381.Vector, n) + copy(a, wt) + return a, nil + case fr_bw6761.Vector: + a := make(fr_bw6761.Vector, n) + copy(a, wt) + return a, nil + case fr_bls24317.Vector: + a := make(fr_bls24317.Vector, n) + copy(a, wt) + return a, nil + case fr_bls24315.Vector: + a := make(fr_bls24315.Vector, n) + copy(a, wt) + return a, nil + case fr_bw6633.Vector: + a := make(fr_bw6633.Vector, n) + copy(a, wt) + return a, nil + case tinyfield.Vector: + a := make(tinyfield.Vector, n) + copy(a, wt) + return a, nil + default: + return nil, errors.New("unsupported modulus") + } +} + +func leafType(v any) reflect.Type { + switch v.(type) { + case fr_bn254.Vector: + return reflect.TypeOf(fr_bn254.Element{}) + case fr_bls12377.Vector: + return reflect.TypeOf(fr_bls12377.Element{}) + case fr_bls12381.Vector: + return reflect.TypeOf(fr_bls12381.Element{}) + case fr_bw6761.Vector: + return reflect.TypeOf(fr_bw6761.Element{}) + case fr_bls24317.Vector: + return reflect.TypeOf(fr_bls24317.Element{}) + case fr_bls24315.Vector: + return reflect.TypeOf(fr_bls24315.Element{}) + case fr_bw6633.Vector: + return reflect.TypeOf(fr_bw6633.Element{}) + case tinyfield.Vector: + return reflect.TypeOf(tinyfield.Element{}) + default: + panic("invalid input") + } +} + +func set(v any, index int, value any) error { + switch pv := v.(type) { + case fr_bn254.Vector: + if index >= len(pv) { + return errors.New("out of bounds") + } + _, err := pv[index].SetInterface(value) + return err + case fr_bls12377.Vector: + if index >= len(pv) { + return errors.New("out of bounds") + } + _, err := pv[index].SetInterface(value) + return err + case fr_bls12381.Vector: + if index >= len(pv) { + return errors.New("out of bounds") + } + _, err := pv[index].SetInterface(value) + return err + case fr_bw6761.Vector: + if index >= len(pv) { + return errors.New("out of bounds") + } + _, err := pv[index].SetInterface(value) + return err + case fr_bls24317.Vector: + if index >= len(pv) { + return errors.New("out of bounds") + } + _, err := pv[index].SetInterface(value) + return err + case fr_bls24315.Vector: + if index >= len(pv) { + return errors.New("out of bounds") + } + _, err := pv[index].SetInterface(value) + return err + case fr_bw6633.Vector: + if index >= len(pv) { + return errors.New("out of bounds") + } + _, err := pv[index].SetInterface(value) + return err + case tinyfield.Vector: + if index >= len(pv) { + return errors.New("out of bounds") + } + _, err := pv[index].SetInterface(value) + return err + default: + panic("invalid input") + } +} + +func iterate(v any) chan any { + chValues := make(chan any) + switch pv := v.(type) { + case fr_bn254.Vector: + go func() { + for i := 0; i < len(pv); i++ { + chValues <- &(pv)[i] + } + close(chValues) + }() + case fr_bls12377.Vector: + go func() { + for i := 0; i < len(pv); i++ { + chValues <- &(pv)[i] + } + close(chValues) + }() + case fr_bls12381.Vector: + go func() { + for i := 0; i < len(pv); i++ { + chValues <- &(pv)[i] + } + close(chValues) + }() + case fr_bw6761.Vector: + go func() { + for i := 0; i < len(pv); i++ { + chValues <- &(pv)[i] + } + close(chValues) + }() + case fr_bls24317.Vector: + go func() { + for i := 0; i < len(pv); i++ { + chValues <- &(pv)[i] + } + close(chValues) + }() + case fr_bls24315.Vector: + go func() { + for i := 0; i < len(pv); i++ { + chValues <- &(pv)[i] + } + close(chValues) + }() + case fr_bw6633.Vector: + go func() { + for i := 0; i < len(pv); i++ { + chValues <- &(pv)[i] + } + close(chValues) + }() + case tinyfield.Vector: + go func() { + for i := 0; i < len(pv); i++ { + chValues <- &(pv)[i] + } + close(chValues) + }() + default: + panic("invalid input") + } + return chValues +} + +func resize(v any, n int) any { + switch v.(type) { + case fr_bn254.Vector: + return make(fr_bn254.Vector, n) + case fr_bls12377.Vector: + return make(fr_bls12377.Vector, n) + case fr_bls12381.Vector: + return make(fr_bls12381.Vector, n) + case fr_bw6761.Vector: + return make(fr_bw6761.Vector, n) + case fr_bls24317.Vector: + return make(fr_bls24317.Vector, n) + case fr_bls24315.Vector: + return make(fr_bls24315.Vector, n) + case fr_bw6633.Vector: + return make(fr_bw6633.Vector, n) + case tinyfield.Vector: + return make(tinyfield.Vector, n) default: - return nil, errMissingCurveID + panic("invalid input") } } diff --git a/backend/witness/witness.go b/backend/witness/witness.go index 4b7f52f4bd..8b240654e3 100644 --- a/backend/witness/witness.go +++ b/backend/witness/witness.go @@ -16,12 +16,8 @@ // // Binary protocol // -// Full witness -> [uint32(nbElements) | publicVariables | secretVariables] -// Public witness -> [uint32(nbElements) | publicVariables ] -// -// where -// - `nbElements == len(publicVariables) [+ len(secretVariables)]`. -// - each variable (a *field element*) is encoded as a big-endian byte array, where `len(bytes(variable)) == len(bytes(modulus))` +// Witness -> [uint32(nbPublic) | uint32(nbSecret) | fr.Vector(variables)] +// fr.Vector is a *field element* vector encoded a big-endian byte array like so: [uint32(len(vector)) | elements] // // # Ordering // @@ -35,13 +31,15 @@ // } // // A valid witness would be: -// - `[uint32(3)|bytes(Y)|bytes(X)|bytes(Z)]` +// - `[uint32(1)|uint32(2)|uint32(3)|bytes(Y)|bytes(X)|bytes(Z)]` // - Hex representation with values `Y = 35`, `X = 3`, `Z = 2` -// `00000003000000000000000000000000000000000000000000000000000000000000002300000000000000000000000000000000000000000000000000000000000000030000000000000000000000000000000000000000000000000000000000000002` +// `000000010000000200000003000000000000000000000000000000000000000000000000000000000000002300000000000000000000000000000000000000000000000000000000000000030000000000000000000000000000000000000000000000000000000000000002` package witness import ( "bytes" + "encoding" + "encoding/binary" "encoding/json" "errors" "fmt" @@ -49,124 +47,234 @@ import ( "math/big" "reflect" - "github.com/consensys/gnark-crypto/ecc" + fr_bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + fr_bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + fr_bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + fr_bls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + fr_bn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr" + fr_bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + fr_bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" "github.com/consensys/gnark/debug" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/utils" + "github.com/consensys/gnark/internal/tinyfield" ) -var ( - ErrInvalidWitness = errors.New("invalid witness") - errMissingSchema = errors.New("missing Schema") - errMissingCurveID = errors.New("missing CurveID") -) +var ErrInvalidWitness = errors.New("invalid witness") // Witness represents a zkSNARK witness. // -// A witness can be in 3 states: -// 1. Assignment (ie assigning values to a frontend.Circuit object) -// 2. Witness (this object: an ordered vector of field elements + metadata) -// 3. Serialized (Binary or JSON) using MarshalBinary or MarshalJSON +// The underlying data structure is a vector of field elements, but a Witness +// also may have some additional meta information about the number of public elements and +// secret elements. // -// ! MarshalJSON and UnmarshalJSON are slow, and do not handle all complex circuit structures -type Witness struct { - Vector Vector // TODO @gbotrel the result is an interface for now may change to generic Witness[fr.Element] in an upcoming PR - Schema *schema.Schema // optional, Binary encoding needs no schema - CurveID ecc.ID // should be redundant with generic impl +// In most cases a Witness should be [de]serialized using a binary protocol. +// JSON conversions for pretty printing are slow and don't handle all complex circuit structures well. +type Witness interface { + io.WriterTo + io.ReaderFrom + encoding.BinaryMarshaler + encoding.BinaryUnmarshaler + + // Public returns the Public an object containing the public part of the Witness only. + Public() (Witness, error) + + // Vector returns the underlying fr.Vector slice + Vector() any + + // ToJSON returns the JSON encoding of the witness following the provided Schema. This is a + // convenience method and should be avoided in most cases. + ToJSON(s *schema.Schema) ([]byte, error) + + // FromJSON parses a JSON data input and attempt to reconstruct a witness following the provided Schema. + // This is a convenience method and should be avoided in most cases. + FromJSON(s *schema.Schema, data []byte) error + + // Fill range over the provided chan to fill the underlying vector. + // Will allocate the underlying vector with nbPublic + nbSecret elements. + // This is typically call by internal APIs to fill the vector by walking a structure. + Fill(nbPublic, nbSecret int, values <-chan any) error } -func New(field *big.Int, schema *schema.Schema) (*Witness, error) { - v, err := newVector(field) +type witness struct { + vector any + nbPublic, nbSecret uint32 +} + +// New initialize a new empty Witness. +func New(field *big.Int) (Witness, error) { + v, err := newVector(field, 0) if err != nil { return nil, err } - return &Witness{ - CurveID: utils.FieldToCurve(field), - Vector: v, - Schema: schema, + return &witness{ + vector: v, }, nil } -// Public extracts the public part of the witness and returns a new witness object -func (w *Witness) Public() (*Witness, error) { - if w.Vector == nil { - return nil, fmt.Errorf("%w: empty witness", ErrInvalidWitness) +func (w *witness) Fill(nbPublic, nbSecret int, values <-chan any) error { + n := int(nbPublic + nbSecret) + w.vector = resize(w.vector, n) + w.nbPublic = uint32(nbPublic) + w.nbSecret = uint32(nbSecret) + + i := 0 + + // note; this shouldn't be perf critical but if it is we could have 2 input chan and + // fill public and secret values concurrently. + for v := range values { + if i >= n { + // we panic here; shouldn't happen and if it does we may leek a chan + producer go routine + panic("chan of values returns more elements than expected") + } + // if v == nil { + // this is caught in the set method. however, error message will be unclear; reason + // is there is a nil field in assignment, we could print which one. + // } + if err := set(w.vector, i, v); err != nil { + return err + } + i++ } - if w.Schema == nil { - return nil, errMissingSchema + + if i != n { + return fmt.Errorf("expected %d values, filled only %d", n, i) } - v, err := newFrom(w.Vector, w.Schema.NbPublic) + + return nil +} + +func (w *witness) iterate() chan any { + return iterate(w.vector) +} + +func (w *witness) Public() (Witness, error) { + v, err := newFrom(w.vector, int(w.nbPublic)) if err != nil { return nil, err } - return &Witness{ - CurveID: w.CurveID, - Vector: v, - Schema: w.Schema, + return &witness{ + vector: v, + nbPublic: w.nbPublic, }, nil } -// MarshalBinary implements encoding.BinaryMarshaler -// Only the vector of field elements is marshalled: the curveID and the Schema are omitted. -func (w *Witness) MarshalBinary() (data []byte, err error) { - var buf bytes.Buffer +func (w *witness) WriteTo(wr io.Writer) (n int64, err error) { + // write number of public, number of secret + if err := binary.Write(wr, binary.BigEndian, w.nbPublic); err != nil { + return 0, err + } + n = int64(4) + if err := binary.Write(wr, binary.BigEndian, w.nbSecret); err != nil { + return n, err + } + n += 4 + + // write the vector + m, err := w.vector.(io.WriterTo).WriteTo(wr) + n += m + return n, err +} - if w.Vector == nil { - return nil, fmt.Errorf("%w: empty witness", ErrInvalidWitness) +func (w *witness) ReadFrom(r io.Reader) (n int64, err error) { + var buf [4]byte + if read, err := io.ReadFull(r, buf[:]); err != nil { + return int64(read), err + } + w.nbPublic = binary.BigEndian.Uint32(buf[:4]) + if read, err := io.ReadFull(r, buf[:]); err != nil { + return int64(read) + 4, err } + w.nbSecret = binary.BigEndian.Uint32(buf[:4]) - if _, err = w.Vector.WriteTo(&buf); err != nil { - return + var m int64 + switch t := w.vector.(type) { + case fr_bn254.Vector: + m, err = t.ReadFrom(r) + w.vector = t + case fr_bls12377.Vector: + m, err = t.ReadFrom(r) + w.vector = t + case fr_bls12381.Vector: + m, err = t.ReadFrom(r) + w.vector = t + case fr_bw6761.Vector: + m, err = t.ReadFrom(r) + w.vector = t + case fr_bls24317.Vector: + m, err = t.ReadFrom(r) + w.vector = t + case fr_bls24315.Vector: + m, err = t.ReadFrom(r) + w.vector = t + case fr_bw6633.Vector: + m, err = t.ReadFrom(r) + w.vector = t + case tinyfield.Vector: + m, err = t.ReadFrom(r) + w.vector = t + default: + panic("invalid input") } - return buf.Bytes(), nil + + n += m + return n, err } -// UnmarshalBinary implements encoding.BinaryUnmarshaler -func (w *Witness) UnmarshalBinary(data []byte) error { +// MarshalBinary encodes the number of public, number of secret and the fr.Vector. +func (w *witness) MarshalBinary() (data []byte, err error) { + var buf bytes.Buffer - snarkFieldSize := utils.ByteLen(w.CurveID.ScalarField()) - var r io.Reader - r = bytes.NewReader(data) - if w.Schema != nil { - // if schema is set we can do a limit reader - maxSize := 4 + (w.Schema.NbPublic+w.Schema.NbSecret)*snarkFieldSize - r = io.LimitReader(r, int64(maxSize)) + if _, err = w.WriteTo(&buf); err != nil { + return } + return buf.Bytes(), nil +} - v, err := newVector(w.CurveID.ScalarField()) - if err != nil { - return err - } - _, err = v.ReadFrom(r) - if err != nil { - return err - } - w.Vector = v +// UnmarshalBinary implements encoding.BinaryUnmarshaler +func (w *witness) UnmarshalBinary(data []byte) error { + r := bytes.NewReader(data) + _, err := w.ReadFrom(r) + return err +} - return nil +func (w *witness) Vector() any { + return w.vector } -// MarshalJSON implements json.Marshaler -// -// Only the vector of field elements is marshalled: the curveID and the Schema are omitted. -// -// ! MarshalJSON and UnmarshalJSON are slow, and do not handle all complex circuit structures -func (w *Witness) MarshalJSON() (r []byte, err error) { - if w.Schema == nil { - return nil, errMissingSchema - } - if w.Vector == nil { - return nil, fmt.Errorf("%w: empty witness", ErrInvalidWitness) +// ToJSON returns the JSON encoding of the witness following the provided Schema. This is a +// convenience method and should be avoided in most cases. +func (w *witness) ToJSON(s *schema.Schema) ([]byte, error) { + if s.NbPublic != int(w.nbPublic) || (w.nbSecret != 0 && w.nbSecret != uint32(s.NbSecret)) { + return nil, errors.New("schema is inconsistent with Witness") } + typ := reflect.PtrTo(leafType(w.vector)) + instance := s.Instantiate(typ) - typ := w.Vector.Type() - - instance := w.Schema.Instantiate(reflect.PtrTo(typ)) - if err := w.toAssignment(instance, reflect.PtrTo(typ)); err != nil { + chValues := w.iterate() + if _, err := schema.Walk(instance, typ, func(field schema.LeafInfo, tValue reflect.Value) error { + if field.Visibility == schema.Public { + v := <-chValues + tValue.Set(reflect.ValueOf(v)) + } + return nil + }); err != nil { return nil, err } + if w.nbSecret != 0 { + // secret part. + if _, err := schema.Walk(instance, typ, func(field schema.LeafInfo, tValue reflect.Value) error { + if field.Visibility == schema.Secret { + v := <-chValues + tValue.Set(reflect.ValueOf(v)) + } + return nil + }); err != nil { + return nil, err + } + } + if debug.Debug { return json.MarshalIndent(instance, " ", " ") } else { @@ -174,23 +282,15 @@ func (w *Witness) MarshalJSON() (r []byte, err error) { } } -// UnmarshalJSON implements json.Unmarshaler -// -// ! MarshalJSON and UnmarshalJSON are slow, and do not handle all complex circuit structures -func (w *Witness) UnmarshalJSON(data []byte) error { - if w.Schema == nil { - return errMissingSchema - } - v, err := newVector(w.CurveID.ScalarField()) - if err != nil { - return err - } - - typ := v.Type() +// FromJSON parses a JSON data input and attempt to reconstruct a witness following the provided Schema. +// This is a convenience method and should be avoided in most cases. +func (w *witness) FromJSON(s *schema.Schema, data []byte) error { + typ := leafType(w.vector) + ptrTyp := reflect.PtrTo(typ) // we instantiate an object matching the schema, with leaf type == field element // note that we pass a pointer here to have nil for zero values - instance := w.Schema.Instantiate(reflect.PtrTo(typ)) + instance := s.Instantiate(ptrTyp) dec := json.NewDecoder(bytes.NewReader(data)) dec.DisallowUnknownFields() @@ -199,50 +299,68 @@ func (w *Witness) UnmarshalJSON(data []byte) error { if err := dec.Decode(instance); err != nil { return err } + // walk through the public AND secret values + missingAssignment := func(name string) error { + return fmt.Errorf("missing assignment for %s", name) + } - // optimistic approach: first try to unmarshall everything. then only the public part if it fails - // note that our instance has leaf type == *fr.Element, so the zero value is nil - // and is going to make the newWitness method error since it doesn't accept missing assignments - _, err = v.FromAssignment(instance, reflect.PtrTo(typ), false) - if err != nil { - // try with public only - _, err := v.FromAssignment(instance, reflect.PtrTo(typ), true) - if err != nil { - return err + // collect all public values; if any are missing, no point going further. + publicValues := make([]any, 0, s.NbPublic) + if _, err := schema.Walk(instance, ptrTyp, func(leaf schema.LeafInfo, tValue reflect.Value) error { + if leaf.Visibility == schema.Public { + if tValue.IsNil() { + return missingAssignment(leaf.FullName()) + } + publicValues = append(publicValues, reflect.Indirect(tValue).Interface()) } - w.Vector = v return nil + }); err != nil { + // missing public values + return err } - w.Vector = v - return nil -} -func (w *Witness) toAssignment(to interface{}, toLeafType reflect.Type) error { - if w.Schema == nil { - return errMissingSchema + // collect all secret values; if any are missing, we just deal with the public part. + secretValues := make([]any, 0, s.NbSecret) + publicOnly := false + if _, err := schema.Walk(instance, ptrTyp, func(leaf schema.LeafInfo, tValue reflect.Value) error { + if leaf.Visibility == schema.Secret { + if tValue.IsNil() { + return missingAssignment(leaf.FullName()) + } + secretValues = append(secretValues, reflect.Indirect(tValue).Interface()) + } + return nil + }); err != nil { + // missing secret values, we just do the public part. + publicOnly = true } - if w.Vector == nil { - return fmt.Errorf("%w: empty witness", ErrInvalidWitness) + + // reconstruct the witness + // we use a buffered channel to ensure this go routine terminates, even if setting a witness + // value failed. All this is not really performant for large witnesses, but again, JSON + // shouldn't be used in perf-critical scenario. + var chValues chan any + if publicOnly { + chValues = make(chan any, len(publicValues)) + s.NbSecret = 0 + } else { + chValues = make(chan any, len(publicValues)+len(secretValues)) } + go func() { + defer close(chValues) - // we check the size of the underlying vector to determine if we have the full witness - // or only the public part - n := w.Vector.Len() + for _, v := range publicValues { + chValues <- v + } - nbSecret, nbPublic := w.Schema.NbSecret, w.Schema.NbPublic + if publicOnly { + return + } - var publicOnly bool - if n == nbPublic { - // public witness only - publicOnly = true - } else if n == (nbPublic + nbSecret) { - // full witness - publicOnly = false - } else { - // invalid witness size - return fmt.Errorf("%w: got %d elements, expected either %d (public) or %d (full)", ErrInvalidWitness, n, nbPublic, nbPublic+nbSecret) - } - w.Vector.ToAssignment(to, toLeafType, publicOnly) + for _, v := range secretValues { + chValues <- v + } + }() - return nil + return w.Fill(s.NbPublic, s.NbSecret, chValues) } diff --git a/backend/witness/witness_test.go b/backend/witness/witness_test.go index 88a5bbe253..ba127b3240 100644 --- a/backend/witness/witness_test.go +++ b/backend/witness/witness_test.go @@ -1,86 +1,58 @@ -package witness +package witness_test import ( + "fmt" "reflect" "testing" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bn254/fr" - witness_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/witness" - witness_bls12381 "github.com/consensys/gnark/internal/backend/bls12-381/witness" - witness_bls24315 "github.com/consensys/gnark/internal/backend/bls24-315/witness" - witness_bls24317 "github.com/consensys/gnark/internal/backend/bls24-317/witness" - witness_bn254 "github.com/consensys/gnark/internal/backend/bn254/witness" - witness_bw6633 "github.com/consensys/gnark/internal/backend/bw6-633/witness" - witness_bw6761 "github.com/consensys/gnark/internal/backend/bw6-761/witness" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend" "github.com/stretchr/testify/require" ) type circuit struct { // tagging a variable is optional // default uses variable name and secret visibility. - X *fr.Element `gnark:",public"` - Y *fr.Element `gnark:",public"` + X frontend.Variable `gnark:",public"` + Y frontend.Variable `gnark:",public"` - E *fr.Element + E frontend.Variable } -type marshaller uint8 +func (c *circuit) Define(frontend.API) error { + return nil +} -const ( - JSON marshaller = iota - Binary -) +func ExampleWitness() { + // Witnesses can be created directly by "walking" through an assignment (circuit structure) + // simple assignment + assignment := &circuit{ + X: 42, + Y: 8000, + E: 1, + } -func roundTripMarshal(assert *require.Assertions, assignment circuit, m marshaller, publicOnly bool) { - // build the vector - w, err := New(ecc.BN254.ScalarField(), nil) - assert.NoError(err) + w, _ := frontend.NewWitness(assignment, ecc.BN254.ScalarField()) - w.Schema, err = w.Vector.FromAssignment(&assignment, tVariable, publicOnly) - assert.NoError(err) + // Binary [de]serialization + data, _ := w.MarshalBinary() - marshal := w.MarshalBinary - if m == JSON { - marshal = w.MarshalJSON - } + reconstructed, _ := witness.New(ecc.BN254.ScalarField()) + reconstructed.UnmarshalBinary(data) - // serialize the vector to binary - data, err := marshal() - assert.NoError(err) + // For pretty printing, we can do JSON conversions; they are not efficient and don't handle + // complex circuit structures well. - // re-read - witness := Witness{CurveID: ecc.BN254, Schema: w.Schema} - unmarshal := witness.UnmarshalBinary - if m == JSON { - unmarshal = witness.UnmarshalJSON - } - err = unmarshal(data) - assert.NoError(err) + // first get the circuit expected schema + schema, _ := frontend.NewSchema(assignment) + json, _ := reconstructed.ToJSON(schema) - // reconstruct a circuit object - var reconstructed circuit - - switch wt := witness.Vector.(type) { - case *witness_bls12377.Witness: - wt.ToAssignment(&reconstructed, tVariable, publicOnly) - case *witness_bls12381.Witness: - wt.ToAssignment(&reconstructed, tVariable, publicOnly) - case *witness_bls24317.Witness: - wt.ToAssignment(&reconstructed, tVariable, publicOnly) - case *witness_bls24315.Witness: - wt.ToAssignment(&reconstructed, tVariable, publicOnly) - case *witness_bn254.Witness: - wt.ToAssignment(&reconstructed, tVariable, publicOnly) - case *witness_bw6633.Witness: - wt.ToAssignment(&reconstructed, tVariable, publicOnly) - case *witness_bw6761.Witness: - wt.ToAssignment(&reconstructed, tVariable, publicOnly) - default: - panic("not implemented") - } + fmt.Println(string(json)) + // Output: + // {"X":42,"Y":8000,"E":1} - assert.True(reflect.DeepEqual(assignment, reconstructed), "public witness reconstructed doesn't match original value") } func TestMarshalPublic(t *testing.T) { @@ -90,8 +62,8 @@ func TestMarshalPublic(t *testing.T) { assignment.X = new(fr.Element).SetInt64(42) assignment.Y = new(fr.Element).SetInt64(8000) - roundTripMarshal(assert, assignment, JSON, true) - roundTripMarshal(assert, assignment, Binary, true) + roundTripMarshal(assert, assignment, true) + roundTripMarshalJSON(assert, assignment, true) } func TestMarshal(t *testing.T) { @@ -102,8 +74,8 @@ func TestMarshal(t *testing.T) { assignment.Y = new(fr.Element).SetInt64(8000) assignment.E = new(fr.Element).SetInt64(1) - roundTripMarshal(assert, assignment, JSON, false) - roundTripMarshal(assert, assignment, Binary, false) + roundTripMarshal(assert, assignment, false) + roundTripMarshalJSON(assert, assignment, false) } func TestPublic(t *testing.T) { @@ -114,26 +86,65 @@ func TestPublic(t *testing.T) { assignment.Y = new(fr.Element).SetInt64(8000) assignment.E = new(fr.Element).SetInt64(1) - w, err := New(ecc.BN254.ScalarField(), nil) + w, err := frontend.NewWitness(&assignment, ecc.BN254.ScalarField()) assert.NoError(err) - w.Schema, err = w.Vector.FromAssignment(&assignment, tVariable, false) + publicW, err := w.Public() assert.NoError(err) - publicW, err := w.Public() + wt := publicW.Vector().(fr.Vector) + + assert.Equal(3, len(w.Vector().(fr.Vector))) + assert.Equal(2, len(wt)) + + assert.Equal("42", wt[0].String()) + assert.Equal("8000", wt[1].String()) +} + +func roundTripMarshal(assert *require.Assertions, assignment circuit, publicOnly bool) { + // build the vector + var opts []frontend.WitnessOption + if publicOnly { + opts = append(opts, frontend.PublicOnly()) + } + w, err := frontend.NewWitness(&assignment, ecc.BN254.ScalarField(), opts...) + assert.NoError(err) + + // serialize the vector to binary + data, err := w.MarshalBinary() assert.NoError(err) - assert.Equal(3, w.Vector.Len()) - assert.Equal(2, publicW.Vector.Len()) + // re-read + rw, err := witness.New(ecc.BN254.ScalarField()) + assert.NoError(err) + err = rw.UnmarshalBinary(data) + assert.NoError(err) - wt := publicW.Vector.(*witness_bn254.Witness) + assert.True(reflect.DeepEqual(rw, w), "witness binary round trip serialization") - assert.Equal("42", (*wt)[0].String()) - assert.Equal("8000", (*wt)[1].String()) } +func roundTripMarshalJSON(assert *require.Assertions, assignment circuit, publicOnly bool) { + // build the vector + var opts []frontend.WitnessOption + if publicOnly { + opts = append(opts, frontend.PublicOnly()) + } + w, err := frontend.NewWitness(&assignment, ecc.BN254.ScalarField(), opts...) + assert.NoError(err) + + s, err := frontend.NewSchema(&assignment) + assert.NoError(err) + + // serialize the vector to JSON + data, err := w.ToJSON(s) + assert.NoError(err) + + // re-read + rw, err := witness.New(ecc.BN254.ScalarField()) + assert.NoError(err) + err = rw.FromJSON(s, data) + assert.NoError(err) -var tVariable reflect.Type + assert.True(reflect.DeepEqual(rw, w), "witness json round trip serialization") -func init() { - tVariable = reflect.TypeOf(circuit{}.E) } diff --git a/constraint/bls12-377/r1cs.go b/constraint/bls12-377/r1cs.go index 4c28331631..2f5941ab4a 100644 --- a/constraint/bls12-377/r1cs.go +++ b/constraint/bls12-377/r1cs.go @@ -37,8 +37,6 @@ import ( "math" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - - bls12_377witness "github.com/consensys/gnark/internal/backend/bls12-377/witness" ) // R1CS describes a set of R1CS constraint @@ -81,13 +79,13 @@ func (cs *R1CS) AddConstraint(r1c constraint.R1C, debugInfo ...constraint.DebugI // a, b, c vectors: ab-c = hz // witness = [publicWires | secretWires] (without the ONE_WIRE !) // returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ([]fr.Element, error) { +func (cs *R1CS) Solve(witness, a, b, c fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "groth16").Logger() nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables solution, err := newSolution(nbWires, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) if err != nil { - return make([]fr.Element, nbWires), err + return make(fr.Vector, nbWires), err } start := time.Now() @@ -139,7 +137,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( return solution.values, nil } -func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { +func (cs *R1CS) parallelSolve(a, b, c fr.Vector, solution *solution) error { // minWorkPerCPU is the minimum target number of constraint a task should hold // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed // sequentially without sync. @@ -254,17 +252,17 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness *witness.Witness, opts ...backend.ProverOption) error { +func (cs *R1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { opt, err := backend.NewProverConfig(opts...) if err != nil { return err } - a := make([]fr.Element, len(cs.Constraints)) - b := make([]fr.Element, len(cs.Constraints)) - c := make([]fr.Element, len(cs.Constraints)) - v := witness.Vector.(*bls12_377witness.Witness) - _, err = cs.Solve(*v, a, b, c, opt) + a := make(fr.Vector, len(cs.Constraints)) + b := make(fr.Vector, len(cs.Constraints)) + c := make(fr.Vector, len(cs.Constraints)) + v := witness.Vector().(fr.Vector) + _, err = cs.Solve(v, a, b, c, opt) return err } diff --git a/constraint/bls12-377/r1cs_sparse.go b/constraint/bls12-377/r1cs_sparse.go index 92e8020153..34a8895794 100644 --- a/constraint/bls12-377/r1cs_sparse.go +++ b/constraint/bls12-377/r1cs_sparse.go @@ -36,8 +36,6 @@ import ( "github.com/consensys/gnark/profile" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - - bls12_377witness "github.com/consensys/gnark/internal/backend/bls12-377/witness" ) // SparseR1CS represents a Plonk like circuit @@ -77,7 +75,7 @@ func (cs *SparseR1CS) AddConstraint(c constraint.SparseR1C, debugInfo ...constra // solution.values = [publicInputs | secretInputs | internalVariables ] // witness: contains the input variables // it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]fr.Element, error) { +func (cs *SparseR1CS) Solve(witness fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() // set the slices holding the solution.values and monitoring which variables have been solved @@ -87,7 +85,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f expectedWitnessSize := int(len(cs.Public) + len(cs.Secret)) if len(witness) != expectedWitnessSize { - return make([]fr.Element, nbVariables), fmt.Errorf( + return make(fr.Vector, nbVariables), fmt.Errorf( "invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), expectedWitnessSize, @@ -142,7 +140,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f } -func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error { +func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv fr.Vector) error { // minWorkPerCPU is the minimum target number of constraint a task should hold // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed // sequentially without sync. @@ -303,7 +301,7 @@ func (cs *SparseR1CS) computeHints(c constraint.SparseR1C, solution *solution) ( // solveConstraint solve any unsolved wire in given constraint and update the solution // a SparseR1C may have up to one unsolved wire (excluding hints) // if it doesn't, then this function returns and does nothing -func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv []fr.Element) error { +func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv fr.Vector) error { lro, err := cs.computeHints(c, solution) if err != nil { @@ -374,14 +372,14 @@ func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution // IsSolved returns nil if given witness solves the SparseR1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness *witness.Witness, opts ...backend.ProverOption) error { +func (cs *SparseR1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { opt, err := backend.NewProverConfig(opts...) if err != nil { return err } - v := witness.Vector.(*bls12_377witness.Witness) - _, err = cs.Solve(*v, opt) + v := witness.Vector().(fr.Vector) + _, err = cs.Solve(v, opt) return err } diff --git a/constraint/bls12-381/r1cs.go b/constraint/bls12-381/r1cs.go index 9a81e50cfc..6ec76d1d7b 100644 --- a/constraint/bls12-381/r1cs.go +++ b/constraint/bls12-381/r1cs.go @@ -37,8 +37,6 @@ import ( "math" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" - - bls12_381witness "github.com/consensys/gnark/internal/backend/bls12-381/witness" ) // R1CS describes a set of R1CS constraint @@ -81,13 +79,13 @@ func (cs *R1CS) AddConstraint(r1c constraint.R1C, debugInfo ...constraint.DebugI // a, b, c vectors: ab-c = hz // witness = [publicWires | secretWires] (without the ONE_WIRE !) // returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ([]fr.Element, error) { +func (cs *R1CS) Solve(witness, a, b, c fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "groth16").Logger() nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables solution, err := newSolution(nbWires, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) if err != nil { - return make([]fr.Element, nbWires), err + return make(fr.Vector, nbWires), err } start := time.Now() @@ -139,7 +137,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( return solution.values, nil } -func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { +func (cs *R1CS) parallelSolve(a, b, c fr.Vector, solution *solution) error { // minWorkPerCPU is the minimum target number of constraint a task should hold // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed // sequentially without sync. @@ -254,17 +252,17 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness *witness.Witness, opts ...backend.ProverOption) error { +func (cs *R1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { opt, err := backend.NewProverConfig(opts...) if err != nil { return err } - a := make([]fr.Element, len(cs.Constraints)) - b := make([]fr.Element, len(cs.Constraints)) - c := make([]fr.Element, len(cs.Constraints)) - v := witness.Vector.(*bls12_381witness.Witness) - _, err = cs.Solve(*v, a, b, c, opt) + a := make(fr.Vector, len(cs.Constraints)) + b := make(fr.Vector, len(cs.Constraints)) + c := make(fr.Vector, len(cs.Constraints)) + v := witness.Vector().(fr.Vector) + _, err = cs.Solve(v, a, b, c, opt) return err } diff --git a/constraint/bls12-381/r1cs_sparse.go b/constraint/bls12-381/r1cs_sparse.go index b6bc8643eb..00bff07eb2 100644 --- a/constraint/bls12-381/r1cs_sparse.go +++ b/constraint/bls12-381/r1cs_sparse.go @@ -36,8 +36,6 @@ import ( "github.com/consensys/gnark/profile" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" - - bls12_381witness "github.com/consensys/gnark/internal/backend/bls12-381/witness" ) // SparseR1CS represents a Plonk like circuit @@ -77,7 +75,7 @@ func (cs *SparseR1CS) AddConstraint(c constraint.SparseR1C, debugInfo ...constra // solution.values = [publicInputs | secretInputs | internalVariables ] // witness: contains the input variables // it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]fr.Element, error) { +func (cs *SparseR1CS) Solve(witness fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() // set the slices holding the solution.values and monitoring which variables have been solved @@ -87,7 +85,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f expectedWitnessSize := int(len(cs.Public) + len(cs.Secret)) if len(witness) != expectedWitnessSize { - return make([]fr.Element, nbVariables), fmt.Errorf( + return make(fr.Vector, nbVariables), fmt.Errorf( "invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), expectedWitnessSize, @@ -142,7 +140,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f } -func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error { +func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv fr.Vector) error { // minWorkPerCPU is the minimum target number of constraint a task should hold // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed // sequentially without sync. @@ -303,7 +301,7 @@ func (cs *SparseR1CS) computeHints(c constraint.SparseR1C, solution *solution) ( // solveConstraint solve any unsolved wire in given constraint and update the solution // a SparseR1C may have up to one unsolved wire (excluding hints) // if it doesn't, then this function returns and does nothing -func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv []fr.Element) error { +func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv fr.Vector) error { lro, err := cs.computeHints(c, solution) if err != nil { @@ -374,14 +372,14 @@ func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution // IsSolved returns nil if given witness solves the SparseR1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness *witness.Witness, opts ...backend.ProverOption) error { +func (cs *SparseR1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { opt, err := backend.NewProverConfig(opts...) if err != nil { return err } - v := witness.Vector.(*bls12_381witness.Witness) - _, err = cs.Solve(*v, opt) + v := witness.Vector().(fr.Vector) + _, err = cs.Solve(v, opt) return err } diff --git a/constraint/bls24-315/r1cs.go b/constraint/bls24-315/r1cs.go index 062b33d53e..e5eca18b70 100644 --- a/constraint/bls24-315/r1cs.go +++ b/constraint/bls24-315/r1cs.go @@ -37,8 +37,6 @@ import ( "math" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" - - bls24_315witness "github.com/consensys/gnark/internal/backend/bls24-315/witness" ) // R1CS describes a set of R1CS constraint @@ -81,13 +79,13 @@ func (cs *R1CS) AddConstraint(r1c constraint.R1C, debugInfo ...constraint.DebugI // a, b, c vectors: ab-c = hz // witness = [publicWires | secretWires] (without the ONE_WIRE !) // returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ([]fr.Element, error) { +func (cs *R1CS) Solve(witness, a, b, c fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "groth16").Logger() nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables solution, err := newSolution(nbWires, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) if err != nil { - return make([]fr.Element, nbWires), err + return make(fr.Vector, nbWires), err } start := time.Now() @@ -139,7 +137,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( return solution.values, nil } -func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { +func (cs *R1CS) parallelSolve(a, b, c fr.Vector, solution *solution) error { // minWorkPerCPU is the minimum target number of constraint a task should hold // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed // sequentially without sync. @@ -254,17 +252,17 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness *witness.Witness, opts ...backend.ProverOption) error { +func (cs *R1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { opt, err := backend.NewProverConfig(opts...) if err != nil { return err } - a := make([]fr.Element, len(cs.Constraints)) - b := make([]fr.Element, len(cs.Constraints)) - c := make([]fr.Element, len(cs.Constraints)) - v := witness.Vector.(*bls24_315witness.Witness) - _, err = cs.Solve(*v, a, b, c, opt) + a := make(fr.Vector, len(cs.Constraints)) + b := make(fr.Vector, len(cs.Constraints)) + c := make(fr.Vector, len(cs.Constraints)) + v := witness.Vector().(fr.Vector) + _, err = cs.Solve(v, a, b, c, opt) return err } diff --git a/constraint/bls24-315/r1cs_sparse.go b/constraint/bls24-315/r1cs_sparse.go index bc3728c507..28e4a19b4c 100644 --- a/constraint/bls24-315/r1cs_sparse.go +++ b/constraint/bls24-315/r1cs_sparse.go @@ -36,8 +36,6 @@ import ( "github.com/consensys/gnark/profile" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" - - bls24_315witness "github.com/consensys/gnark/internal/backend/bls24-315/witness" ) // SparseR1CS represents a Plonk like circuit @@ -77,7 +75,7 @@ func (cs *SparseR1CS) AddConstraint(c constraint.SparseR1C, debugInfo ...constra // solution.values = [publicInputs | secretInputs | internalVariables ] // witness: contains the input variables // it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]fr.Element, error) { +func (cs *SparseR1CS) Solve(witness fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() // set the slices holding the solution.values and monitoring which variables have been solved @@ -87,7 +85,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f expectedWitnessSize := int(len(cs.Public) + len(cs.Secret)) if len(witness) != expectedWitnessSize { - return make([]fr.Element, nbVariables), fmt.Errorf( + return make(fr.Vector, nbVariables), fmt.Errorf( "invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), expectedWitnessSize, @@ -142,7 +140,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f } -func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error { +func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv fr.Vector) error { // minWorkPerCPU is the minimum target number of constraint a task should hold // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed // sequentially without sync. @@ -303,7 +301,7 @@ func (cs *SparseR1CS) computeHints(c constraint.SparseR1C, solution *solution) ( // solveConstraint solve any unsolved wire in given constraint and update the solution // a SparseR1C may have up to one unsolved wire (excluding hints) // if it doesn't, then this function returns and does nothing -func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv []fr.Element) error { +func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv fr.Vector) error { lro, err := cs.computeHints(c, solution) if err != nil { @@ -374,14 +372,14 @@ func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution // IsSolved returns nil if given witness solves the SparseR1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness *witness.Witness, opts ...backend.ProverOption) error { +func (cs *SparseR1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { opt, err := backend.NewProverConfig(opts...) if err != nil { return err } - v := witness.Vector.(*bls24_315witness.Witness) - _, err = cs.Solve(*v, opt) + v := witness.Vector().(fr.Vector) + _, err = cs.Solve(v, opt) return err } diff --git a/constraint/bls24-317/r1cs.go b/constraint/bls24-317/r1cs.go index 31fa5071e2..ca861bb4bf 100644 --- a/constraint/bls24-317/r1cs.go +++ b/constraint/bls24-317/r1cs.go @@ -37,8 +37,6 @@ import ( "math" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" - - bls24_317witness "github.com/consensys/gnark/internal/backend/bls24-317/witness" ) // R1CS describes a set of R1CS constraint @@ -81,13 +79,13 @@ func (cs *R1CS) AddConstraint(r1c constraint.R1C, debugInfo ...constraint.DebugI // a, b, c vectors: ab-c = hz // witness = [publicWires | secretWires] (without the ONE_WIRE !) // returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ([]fr.Element, error) { +func (cs *R1CS) Solve(witness, a, b, c fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "groth16").Logger() nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables solution, err := newSolution(nbWires, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) if err != nil { - return make([]fr.Element, nbWires), err + return make(fr.Vector, nbWires), err } start := time.Now() @@ -139,7 +137,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( return solution.values, nil } -func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { +func (cs *R1CS) parallelSolve(a, b, c fr.Vector, solution *solution) error { // minWorkPerCPU is the minimum target number of constraint a task should hold // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed // sequentially without sync. @@ -254,17 +252,17 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness *witness.Witness, opts ...backend.ProverOption) error { +func (cs *R1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { opt, err := backend.NewProverConfig(opts...) if err != nil { return err } - a := make([]fr.Element, len(cs.Constraints)) - b := make([]fr.Element, len(cs.Constraints)) - c := make([]fr.Element, len(cs.Constraints)) - v := witness.Vector.(*bls24_317witness.Witness) - _, err = cs.Solve(*v, a, b, c, opt) + a := make(fr.Vector, len(cs.Constraints)) + b := make(fr.Vector, len(cs.Constraints)) + c := make(fr.Vector, len(cs.Constraints)) + v := witness.Vector().(fr.Vector) + _, err = cs.Solve(v, a, b, c, opt) return err } diff --git a/constraint/bls24-317/r1cs_sparse.go b/constraint/bls24-317/r1cs_sparse.go index 0ea71e36c0..535ed02351 100644 --- a/constraint/bls24-317/r1cs_sparse.go +++ b/constraint/bls24-317/r1cs_sparse.go @@ -36,8 +36,6 @@ import ( "github.com/consensys/gnark/profile" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" - - bls24_317witness "github.com/consensys/gnark/internal/backend/bls24-317/witness" ) // SparseR1CS represents a Plonk like circuit @@ -77,7 +75,7 @@ func (cs *SparseR1CS) AddConstraint(c constraint.SparseR1C, debugInfo ...constra // solution.values = [publicInputs | secretInputs | internalVariables ] // witness: contains the input variables // it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]fr.Element, error) { +func (cs *SparseR1CS) Solve(witness fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() // set the slices holding the solution.values and monitoring which variables have been solved @@ -87,7 +85,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f expectedWitnessSize := int(len(cs.Public) + len(cs.Secret)) if len(witness) != expectedWitnessSize { - return make([]fr.Element, nbVariables), fmt.Errorf( + return make(fr.Vector, nbVariables), fmt.Errorf( "invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), expectedWitnessSize, @@ -142,7 +140,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f } -func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error { +func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv fr.Vector) error { // minWorkPerCPU is the minimum target number of constraint a task should hold // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed // sequentially without sync. @@ -303,7 +301,7 @@ func (cs *SparseR1CS) computeHints(c constraint.SparseR1C, solution *solution) ( // solveConstraint solve any unsolved wire in given constraint and update the solution // a SparseR1C may have up to one unsolved wire (excluding hints) // if it doesn't, then this function returns and does nothing -func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv []fr.Element) error { +func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv fr.Vector) error { lro, err := cs.computeHints(c, solution) if err != nil { @@ -374,14 +372,14 @@ func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution // IsSolved returns nil if given witness solves the SparseR1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness *witness.Witness, opts ...backend.ProverOption) error { +func (cs *SparseR1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { opt, err := backend.NewProverConfig(opts...) if err != nil { return err } - v := witness.Vector.(*bls24_317witness.Witness) - _, err = cs.Solve(*v, opt) + v := witness.Vector().(fr.Vector) + _, err = cs.Solve(v, opt) return err } diff --git a/constraint/bn254/r1cs.go b/constraint/bn254/r1cs.go index 8865156d4d..0e027ad14e 100644 --- a/constraint/bn254/r1cs.go +++ b/constraint/bn254/r1cs.go @@ -37,8 +37,6 @@ import ( "math" "github.com/consensys/gnark-crypto/ecc/bn254/fr" - - bn254witness "github.com/consensys/gnark/internal/backend/bn254/witness" ) // R1CS describes a set of R1CS constraint @@ -81,13 +79,13 @@ func (cs *R1CS) AddConstraint(r1c constraint.R1C, debugInfo ...constraint.DebugI // a, b, c vectors: ab-c = hz // witness = [publicWires | secretWires] (without the ONE_WIRE !) // returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ([]fr.Element, error) { +func (cs *R1CS) Solve(witness, a, b, c fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "groth16").Logger() nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables solution, err := newSolution(nbWires, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) if err != nil { - return make([]fr.Element, nbWires), err + return make(fr.Vector, nbWires), err } start := time.Now() @@ -139,7 +137,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( return solution.values, nil } -func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { +func (cs *R1CS) parallelSolve(a, b, c fr.Vector, solution *solution) error { // minWorkPerCPU is the minimum target number of constraint a task should hold // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed // sequentially without sync. @@ -254,17 +252,17 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness *witness.Witness, opts ...backend.ProverOption) error { +func (cs *R1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { opt, err := backend.NewProverConfig(opts...) if err != nil { return err } - a := make([]fr.Element, len(cs.Constraints)) - b := make([]fr.Element, len(cs.Constraints)) - c := make([]fr.Element, len(cs.Constraints)) - v := witness.Vector.(*bn254witness.Witness) - _, err = cs.Solve(*v, a, b, c, opt) + a := make(fr.Vector, len(cs.Constraints)) + b := make(fr.Vector, len(cs.Constraints)) + c := make(fr.Vector, len(cs.Constraints)) + v := witness.Vector().(fr.Vector) + _, err = cs.Solve(v, a, b, c, opt) return err } diff --git a/constraint/bn254/r1cs_sparse.go b/constraint/bn254/r1cs_sparse.go index 6b250206df..08aad4b607 100644 --- a/constraint/bn254/r1cs_sparse.go +++ b/constraint/bn254/r1cs_sparse.go @@ -36,8 +36,6 @@ import ( "github.com/consensys/gnark/profile" "github.com/consensys/gnark-crypto/ecc/bn254/fr" - - bn254witness "github.com/consensys/gnark/internal/backend/bn254/witness" ) // SparseR1CS represents a Plonk like circuit @@ -77,7 +75,7 @@ func (cs *SparseR1CS) AddConstraint(c constraint.SparseR1C, debugInfo ...constra // solution.values = [publicInputs | secretInputs | internalVariables ] // witness: contains the input variables // it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]fr.Element, error) { +func (cs *SparseR1CS) Solve(witness fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() // set the slices holding the solution.values and monitoring which variables have been solved @@ -87,7 +85,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f expectedWitnessSize := int(len(cs.Public) + len(cs.Secret)) if len(witness) != expectedWitnessSize { - return make([]fr.Element, nbVariables), fmt.Errorf( + return make(fr.Vector, nbVariables), fmt.Errorf( "invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), expectedWitnessSize, @@ -142,7 +140,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f } -func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error { +func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv fr.Vector) error { // minWorkPerCPU is the minimum target number of constraint a task should hold // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed // sequentially without sync. @@ -303,7 +301,7 @@ func (cs *SparseR1CS) computeHints(c constraint.SparseR1C, solution *solution) ( // solveConstraint solve any unsolved wire in given constraint and update the solution // a SparseR1C may have up to one unsolved wire (excluding hints) // if it doesn't, then this function returns and does nothing -func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv []fr.Element) error { +func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv fr.Vector) error { lro, err := cs.computeHints(c, solution) if err != nil { @@ -374,14 +372,14 @@ func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution // IsSolved returns nil if given witness solves the SparseR1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness *witness.Witness, opts ...backend.ProverOption) error { +func (cs *SparseR1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { opt, err := backend.NewProverConfig(opts...) if err != nil { return err } - v := witness.Vector.(*bn254witness.Witness) - _, err = cs.Solve(*v, opt) + v := witness.Vector().(fr.Vector) + _, err = cs.Solve(v, opt) return err } diff --git a/constraint/bw6-633/r1cs.go b/constraint/bw6-633/r1cs.go index 7ca174f02a..034fcbfb1a 100644 --- a/constraint/bw6-633/r1cs.go +++ b/constraint/bw6-633/r1cs.go @@ -37,8 +37,6 @@ import ( "math" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" - - bw6_633witness "github.com/consensys/gnark/internal/backend/bw6-633/witness" ) // R1CS describes a set of R1CS constraint @@ -81,13 +79,13 @@ func (cs *R1CS) AddConstraint(r1c constraint.R1C, debugInfo ...constraint.DebugI // a, b, c vectors: ab-c = hz // witness = [publicWires | secretWires] (without the ONE_WIRE !) // returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ([]fr.Element, error) { +func (cs *R1CS) Solve(witness, a, b, c fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "groth16").Logger() nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables solution, err := newSolution(nbWires, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) if err != nil { - return make([]fr.Element, nbWires), err + return make(fr.Vector, nbWires), err } start := time.Now() @@ -139,7 +137,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( return solution.values, nil } -func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { +func (cs *R1CS) parallelSolve(a, b, c fr.Vector, solution *solution) error { // minWorkPerCPU is the minimum target number of constraint a task should hold // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed // sequentially without sync. @@ -254,17 +252,17 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness *witness.Witness, opts ...backend.ProverOption) error { +func (cs *R1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { opt, err := backend.NewProverConfig(opts...) if err != nil { return err } - a := make([]fr.Element, len(cs.Constraints)) - b := make([]fr.Element, len(cs.Constraints)) - c := make([]fr.Element, len(cs.Constraints)) - v := witness.Vector.(*bw6_633witness.Witness) - _, err = cs.Solve(*v, a, b, c, opt) + a := make(fr.Vector, len(cs.Constraints)) + b := make(fr.Vector, len(cs.Constraints)) + c := make(fr.Vector, len(cs.Constraints)) + v := witness.Vector().(fr.Vector) + _, err = cs.Solve(v, a, b, c, opt) return err } diff --git a/constraint/bw6-633/r1cs_sparse.go b/constraint/bw6-633/r1cs_sparse.go index 339089f202..465df10443 100644 --- a/constraint/bw6-633/r1cs_sparse.go +++ b/constraint/bw6-633/r1cs_sparse.go @@ -36,8 +36,6 @@ import ( "github.com/consensys/gnark/profile" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" - - bw6_633witness "github.com/consensys/gnark/internal/backend/bw6-633/witness" ) // SparseR1CS represents a Plonk like circuit @@ -77,7 +75,7 @@ func (cs *SparseR1CS) AddConstraint(c constraint.SparseR1C, debugInfo ...constra // solution.values = [publicInputs | secretInputs | internalVariables ] // witness: contains the input variables // it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]fr.Element, error) { +func (cs *SparseR1CS) Solve(witness fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() // set the slices holding the solution.values and monitoring which variables have been solved @@ -87,7 +85,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f expectedWitnessSize := int(len(cs.Public) + len(cs.Secret)) if len(witness) != expectedWitnessSize { - return make([]fr.Element, nbVariables), fmt.Errorf( + return make(fr.Vector, nbVariables), fmt.Errorf( "invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), expectedWitnessSize, @@ -142,7 +140,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f } -func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error { +func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv fr.Vector) error { // minWorkPerCPU is the minimum target number of constraint a task should hold // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed // sequentially without sync. @@ -303,7 +301,7 @@ func (cs *SparseR1CS) computeHints(c constraint.SparseR1C, solution *solution) ( // solveConstraint solve any unsolved wire in given constraint and update the solution // a SparseR1C may have up to one unsolved wire (excluding hints) // if it doesn't, then this function returns and does nothing -func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv []fr.Element) error { +func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv fr.Vector) error { lro, err := cs.computeHints(c, solution) if err != nil { @@ -374,14 +372,14 @@ func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution // IsSolved returns nil if given witness solves the SparseR1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness *witness.Witness, opts ...backend.ProverOption) error { +func (cs *SparseR1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { opt, err := backend.NewProverConfig(opts...) if err != nil { return err } - v := witness.Vector.(*bw6_633witness.Witness) - _, err = cs.Solve(*v, opt) + v := witness.Vector().(fr.Vector) + _, err = cs.Solve(v, opt) return err } diff --git a/constraint/bw6-761/r1cs.go b/constraint/bw6-761/r1cs.go index 65d20cb903..037ef187e1 100644 --- a/constraint/bw6-761/r1cs.go +++ b/constraint/bw6-761/r1cs.go @@ -37,8 +37,6 @@ import ( "math" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - - bw6_761witness "github.com/consensys/gnark/internal/backend/bw6-761/witness" ) // R1CS describes a set of R1CS constraint @@ -81,13 +79,13 @@ func (cs *R1CS) AddConstraint(r1c constraint.R1C, debugInfo ...constraint.DebugI // a, b, c vectors: ab-c = hz // witness = [publicWires | secretWires] (without the ONE_WIRE !) // returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ([]fr.Element, error) { +func (cs *R1CS) Solve(witness, a, b, c fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "groth16").Logger() nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables solution, err := newSolution(nbWires, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) if err != nil { - return make([]fr.Element, nbWires), err + return make(fr.Vector, nbWires), err } start := time.Now() @@ -139,7 +137,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( return solution.values, nil } -func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { +func (cs *R1CS) parallelSolve(a, b, c fr.Vector, solution *solution) error { // minWorkPerCPU is the minimum target number of constraint a task should hold // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed // sequentially without sync. @@ -254,17 +252,17 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness *witness.Witness, opts ...backend.ProverOption) error { +func (cs *R1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { opt, err := backend.NewProverConfig(opts...) if err != nil { return err } - a := make([]fr.Element, len(cs.Constraints)) - b := make([]fr.Element, len(cs.Constraints)) - c := make([]fr.Element, len(cs.Constraints)) - v := witness.Vector.(*bw6_761witness.Witness) - _, err = cs.Solve(*v, a, b, c, opt) + a := make(fr.Vector, len(cs.Constraints)) + b := make(fr.Vector, len(cs.Constraints)) + c := make(fr.Vector, len(cs.Constraints)) + v := witness.Vector().(fr.Vector) + _, err = cs.Solve(v, a, b, c, opt) return err } diff --git a/constraint/bw6-761/r1cs_sparse.go b/constraint/bw6-761/r1cs_sparse.go index 13666417a5..c17079f867 100644 --- a/constraint/bw6-761/r1cs_sparse.go +++ b/constraint/bw6-761/r1cs_sparse.go @@ -36,8 +36,6 @@ import ( "github.com/consensys/gnark/profile" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - - bw6_761witness "github.com/consensys/gnark/internal/backend/bw6-761/witness" ) // SparseR1CS represents a Plonk like circuit @@ -77,7 +75,7 @@ func (cs *SparseR1CS) AddConstraint(c constraint.SparseR1C, debugInfo ...constra // solution.values = [publicInputs | secretInputs | internalVariables ] // witness: contains the input variables // it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]fr.Element, error) { +func (cs *SparseR1CS) Solve(witness fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() // set the slices holding the solution.values and monitoring which variables have been solved @@ -87,7 +85,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f expectedWitnessSize := int(len(cs.Public) + len(cs.Secret)) if len(witness) != expectedWitnessSize { - return make([]fr.Element, nbVariables), fmt.Errorf( + return make(fr.Vector, nbVariables), fmt.Errorf( "invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), expectedWitnessSize, @@ -142,7 +140,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f } -func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error { +func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv fr.Vector) error { // minWorkPerCPU is the minimum target number of constraint a task should hold // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed // sequentially without sync. @@ -303,7 +301,7 @@ func (cs *SparseR1CS) computeHints(c constraint.SparseR1C, solution *solution) ( // solveConstraint solve any unsolved wire in given constraint and update the solution // a SparseR1C may have up to one unsolved wire (excluding hints) // if it doesn't, then this function returns and does nothing -func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv []fr.Element) error { +func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv fr.Vector) error { lro, err := cs.computeHints(c, solution) if err != nil { @@ -374,14 +372,14 @@ func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution // IsSolved returns nil if given witness solves the SparseR1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness *witness.Witness, opts ...backend.ProverOption) error { +func (cs *SparseR1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { opt, err := backend.NewProverConfig(opts...) if err != nil { return err } - v := witness.Vector.(*bw6_761witness.Witness) - _, err = cs.Solve(*v, opt) + v := witness.Vector().(fr.Vector) + _, err = cs.Solve(v, opt) return err } diff --git a/constraint/system.go b/constraint/system.go index 3850e510ae..c400d00972 100644 --- a/constraint/system.go +++ b/constraint/system.go @@ -24,7 +24,7 @@ type ConstraintSystem interface { CoeffEngine // IsSolved returns nil if given witness solves the constraint system and error otherwise - IsSolved(witness *witness.Witness, opts ...backend.ProverOption) error + IsSolved(witness witness.Witness, opts ...backend.ProverOption) error // GetNbVariables return number of internal, secret and public Variables // Deprecated: use GetNbSecretVariables() instead diff --git a/constraint/tinyfield/r1cs.go b/constraint/tinyfield/r1cs.go index 6217ad69cf..3a45c39772 100644 --- a/constraint/tinyfield/r1cs.go +++ b/constraint/tinyfield/r1cs.go @@ -37,8 +37,6 @@ import ( "math" fr "github.com/consensys/gnark/internal/tinyfield" - - unknownwitness "github.com/consensys/gnark/internal/tinyfield/witness" ) // R1CS describes a set of R1CS constraint @@ -81,13 +79,13 @@ func (cs *R1CS) AddConstraint(r1c constraint.R1C, debugInfo ...constraint.DebugI // a, b, c vectors: ab-c = hz // witness = [publicWires | secretWires] (without the ONE_WIRE !) // returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ([]fr.Element, error) { +func (cs *R1CS) Solve(witness, a, b, c fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "groth16").Logger() nbWires := len(cs.Public) + len(cs.Secret) + cs.NbInternalVariables solution, err := newSolution(nbWires, opt.HintFunctions, cs.MHintsDependencies, cs.MHints, cs.Coefficients, &cs.System.SymbolTable) if err != nil { - return make([]fr.Element, nbWires), err + return make(fr.Vector, nbWires), err } start := time.Now() @@ -139,7 +137,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( return solution.values, nil } -func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { +func (cs *R1CS) parallelSolve(a, b, c fr.Vector, solution *solution) error { // minWorkPerCPU is the minimum target number of constraint a task should hold // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed // sequentially without sync. @@ -254,17 +252,17 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness *witness.Witness, opts ...backend.ProverOption) error { +func (cs *R1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { opt, err := backend.NewProverConfig(opts...) if err != nil { return err } - a := make([]fr.Element, len(cs.Constraints)) - b := make([]fr.Element, len(cs.Constraints)) - c := make([]fr.Element, len(cs.Constraints)) - v := witness.Vector.(*unknownwitness.Witness) - _, err = cs.Solve(*v, a, b, c, opt) + a := make(fr.Vector, len(cs.Constraints)) + b := make(fr.Vector, len(cs.Constraints)) + c := make(fr.Vector, len(cs.Constraints)) + v := witness.Vector().(fr.Vector) + _, err = cs.Solve(v, a, b, c, opt) return err } diff --git a/constraint/tinyfield/r1cs_sparse.go b/constraint/tinyfield/r1cs_sparse.go index a19e6f92dc..dcd4cc331a 100644 --- a/constraint/tinyfield/r1cs_sparse.go +++ b/constraint/tinyfield/r1cs_sparse.go @@ -36,8 +36,6 @@ import ( "github.com/consensys/gnark/profile" fr "github.com/consensys/gnark/internal/tinyfield" - - unknownwitness "github.com/consensys/gnark/internal/tinyfield/witness" ) // SparseR1CS represents a Plonk like circuit @@ -77,7 +75,7 @@ func (cs *SparseR1CS) AddConstraint(c constraint.SparseR1C, debugInfo ...constra // solution.values = [publicInputs | secretInputs | internalVariables ] // witness: contains the input variables // it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]fr.Element, error) { +func (cs *SparseR1CS) Solve(witness fr.Vector, opt backend.ProverConfig) (fr.Vector, error) { log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() // set the slices holding the solution.values and monitoring which variables have been solved @@ -87,7 +85,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f expectedWitnessSize := int(len(cs.Public) + len(cs.Secret)) if len(witness) != expectedWitnessSize { - return make([]fr.Element, nbVariables), fmt.Errorf( + return make(fr.Vector, nbVariables), fmt.Errorf( "invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), expectedWitnessSize, @@ -142,7 +140,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f } -func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error { +func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv fr.Vector) error { // minWorkPerCPU is the minimum target number of constraint a task should hold // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed // sequentially without sync. @@ -303,7 +301,7 @@ func (cs *SparseR1CS) computeHints(c constraint.SparseR1C, solution *solution) ( // solveConstraint solve any unsolved wire in given constraint and update the solution // a SparseR1C may have up to one unsolved wire (excluding hints) // if it doesn't, then this function returns and does nothing -func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv []fr.Element) error { +func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution, coefficientsNegInv fr.Vector) error { lro, err := cs.computeHints(c, solution) if err != nil { @@ -374,14 +372,14 @@ func (cs *SparseR1CS) solveConstraint(c constraint.SparseR1C, solution *solution // IsSolved returns nil if given witness solves the SparseR1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness *witness.Witness, opts ...backend.ProverOption) error { +func (cs *SparseR1CS) IsSolved(witness witness.Witness, opts ...backend.ProverOption) error { opt, err := backend.NewProverConfig(opts...) if err != nil { return err } - v := witness.Vector.(*unknownwitness.Witness) - _, err = cs.Solve(*v, opt) + v := witness.Vector().(fr.Vector) + _, err = cs.Solve(v, opt) return err } diff --git a/frontend/cs/r1cs/heap.go b/frontend/cs/r1cs/heap.go index 8dfb9fdbae..a0d6035411 100644 --- a/frontend/cs/r1cs/heap.go +++ b/frontend/cs/r1cs/heap.go @@ -1,6 +1,6 @@ package r1cs -// An minHeap is a min-heap of linear expressions. It facililates merging k-linear expressions. +// An minHeap is a min-heap of linear expressions. It facilitates merging k-linear expressions. // // The code is identical to https://pkg.go.dev/container/heap but replaces interfaces with concrete // type to avoid memory overhead. diff --git a/frontend/schema/schema.go b/frontend/schema/schema.go index 49930979fc..0f41a7ad22 100644 --- a/frontend/schema/schema.go +++ b/frontend/schema/schema.go @@ -63,7 +63,7 @@ func (s Schema) Instantiate(leafType reflect.Type, omitEmptyTag ...bool) interfa // first, let's replace the Field by reflect.StructField is := toStructField(s.Fields, leafType, omitEmpty) - // now create the correspoinding type + // now create the corresponding type typ := reflect.StructOf(is) // instantiate the type @@ -234,7 +234,7 @@ func parse(r []Field, input interface{}, target reflect.Type, parentFullName, pa // default visibility is Unset visibility := Unset - // variable name is field name, unless overriden by gnark tag value + // variable name is field name, unless overridden by gnark tag value name := f.Name var nameTag string @@ -322,7 +322,7 @@ func parse(r []Field, input interface{}, target reflect.Type, parentFullName, pa if tValue.Kind() == reflect.Slice || tValue.Kind() == reflect.Array { if tValue.Len() == 0 { if reflect.SliceOf(target) == tValue.Type() { - fmt.Printf("ignoring uninitizalized slice: %s %s\n", parentGoName, reflect.SliceOf(target).String()) + fmt.Printf("ignoring uninitialized slice: %s %s\n", parentGoName, reflect.SliceOf(target).String()) } return r, nil } diff --git a/frontend/schema/walk.go b/frontend/schema/walk.go index c4d0440f06..07d56b880a 100644 --- a/frontend/schema/walk.go +++ b/frontend/schema/walk.go @@ -82,7 +82,7 @@ func (w *walker) Pointer(value reflect.Value) error { func (w *walker) Slice(value reflect.Value) error { if value.Type() == w.targetSlice { if value.Len() == 0 { - fmt.Printf("ignoring uninitizalized slice: %s %s\n", w.name(), reflect.SliceOf(w.target).String()) + fmt.Printf("ignoring uninitialized slice: %s %s\n", w.name(), reflect.SliceOf(w.target).String()) return nil } return w.handleLeaves(value) diff --git a/frontend/witness.go b/frontend/witness.go index d90becb839..b0eecd4249 100644 --- a/frontend/witness.go +++ b/frontend/witness.go @@ -2,34 +2,71 @@ package frontend import ( "math/big" + "reflect" "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend/schema" ) -// NewWitness build an orderded vector of field elements from the given assignment (Circuit) +// NewWitness build an ordered vector of field elements from the given assignment (Circuit) // if PublicOnly is specified, returns the public part of the witness only -// else returns [public | secret]. The result can then be serialized to / from json & binary +// else returns [public | secret]. The result can then be serialized to / from json & binary. // -// Returns an error if the assignment has missing entries -func NewWitness(assignment Circuit, field *big.Int, opts ...WitnessOption) (*witness.Witness, error) { +// See ExampleWitness in witness package for usage. +func NewWitness(assignment Circuit, field *big.Int, opts ...WitnessOption) (witness.Witness, error) { opt, err := options(opts...) if err != nil { return nil, err } - w, err := witness.New(field, nil) + // count the leaves + s, err := schema.Walk(assignment, tVariable, nil) if err != nil { return nil, err } + if opt.publicOnly { + s.Secret = 0 + } - w.Schema, err = w.Vector.FromAssignment(assignment, tVariable, opt.publicOnly) + // allocate the witness + w, err := witness.New(field) if err != nil { return nil, err } + // write the public | secret values in a chan + chValues := make(chan any) + go func() { + defer close(chValues) + schema.Walk(assignment, tVariable, func(leaf schema.LeafInfo, tValue reflect.Value) error { + if leaf.Visibility == schema.Public { + chValues <- tValue.Interface() + } + return nil + }) + if !opt.publicOnly { + schema.Walk(assignment, tVariable, func(leaf schema.LeafInfo, tValue reflect.Value) error { + if leaf.Visibility == schema.Secret { + chValues <- tValue.Interface() + } + return nil + }) + } + }() + if err := w.Fill(s.Public, s.Secret, chValues); err != nil { + return nil, err + } + return w, nil } +// NewSchema returns the schema corresponding to the circuit structure. +// +// This is used to JSON (un)marshall witnesses. +func NewSchema(circuit Circuit) (*schema.Schema, error) { + return schema.New(circuit, tVariable) +} + // default options func options(opts ...WitnessOption) (witnessConfig, error) { // apply options @@ -45,7 +82,7 @@ func options(opts ...WitnessOption) (witnessConfig, error) { return opt, nil } -// WitnessOption sets optional parameter to witness instantiation from an assigment +// WitnessOption sets optional parameter to witness instantiation from an assignment type WitnessOption func(*witnessConfig) error type witnessConfig struct { diff --git a/go.mod b/go.mod index bfff27c8c4..9ba2c81a81 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.18 require ( github.com/blang/semver/v4 v4.0.0 github.com/consensys/bavard v0.1.13 - github.com/consensys/gnark-crypto v0.8.1-0.20221220191316-4b7364bddab8 + github.com/consensys/gnark-crypto v0.9.1-0.20230126211359-1835092d6670 github.com/ethereum/go-ethereum v1.10.26 github.com/fxamacker/cbor/v2 v2.2.0 github.com/google/go-cmp v0.5.8 diff --git a/go.sum b/go.sum index 5903d61278..1152a9455d 100644 --- a/go.sum +++ b/go.sum @@ -7,6 +7,8 @@ github.com/consensys/bavard v0.1.13 h1:oLhMLOFGTLdlda/kma4VOJazblc7IM5y5QPd2A/Yj github.com/consensys/bavard v0.1.13/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= github.com/consensys/gnark-crypto v0.8.1-0.20221220191316-4b7364bddab8 h1:Ij6UQpKx4/Ox6L6qFPk8NhEnTsYCEXlILnh+1Hi1grY= github.com/consensys/gnark-crypto v0.8.1-0.20221220191316-4b7364bddab8/go.mod h1:CkbdF9hbRidRJYMRzmfX8TMOr95I2pYXRHF18MzRrvA= +github.com/consensys/gnark-crypto v0.9.1-0.20230126211359-1835092d6670 h1:AkewHCm7VuiCV3nDxsFVYE8JHPi9RhR6zFq4I6Ha0Fg= +github.com/consensys/gnark-crypto v0.9.1-0.20230126211359-1835092d6670/go.mod h1:CkbdF9hbRidRJYMRzmfX8TMOr95I2pYXRHF18MzRrvA= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/internal/backend/bls12-377/groth16/commitment_test.go b/internal/backend/bls12-377/groth16/commitment_test.go index 234e82707e..18f6f142d1 100644 --- a/internal/backend/bls12-377/groth16/commitment_test.go +++ b/internal/backend/bls12-377/groth16/commitment_test.go @@ -51,7 +51,7 @@ func setup(t *testing.T, circuit frontend.Circuit) (constraint.ConstraintSystem, return _r1cs, pk, vk } -func prove(t *testing.T, assignment frontend.Circuit, cs constraint.ConstraintSystem, pk groth16.ProvingKey) (*witness.Witness, groth16.Proof) { +func prove(t *testing.T, assignment frontend.Circuit, cs constraint.ConstraintSystem, pk groth16.ProvingKey) (witness.Witness, groth16.Proof) { _witness, err := frontend.NewWitness(assignment, ecc.BLS12_377.ScalarField()) assert.NoError(t, err) diff --git a/internal/backend/bls12-377/groth16/groth16_test.go b/internal/backend/bls12-377/groth16/groth16_test.go deleted file mode 100644 index 9871fd1739..0000000000 --- a/internal/backend/bls12-377/groth16/groth16_test.go +++ /dev/null @@ -1,274 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package groth16_test - -import ( - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - - curve "github.com/consensys/gnark-crypto/ecc/bls12-377" - - "github.com/consensys/gnark/constraint/bls12-377" - - bls12_377witness "github.com/consensys/gnark/internal/backend/bls12-377/witness" - - "bytes" - bls12_377groth16 "github.com/consensys/gnark/internal/backend/bls12-377/groth16" - "reflect" - "testing" - - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/r1cs" -) - -//--------------------// -// benches // -//--------------------// - -type refCircuit struct { - nbConstraints int - X frontend.Variable - Y frontend.Variable `gnark:",public"` -} - -func (circuit *refCircuit) Define(api frontend.API) error { - for i := 0; i < circuit.nbConstraints; i++ { - circuit.X = api.Mul(circuit.X, circuit.X) - } - api.AssertIsEqual(circuit.X, circuit.Y) - return nil -} - -func referenceCircuit() (constraint.ConstraintSystem, frontend.Circuit) { - const nbConstraints = 40000 - circuit := refCircuit{ - nbConstraints: nbConstraints, - } - r1cs, err := frontend.Compile(curve.ID.ScalarField(), r1cs.NewBuilder, &circuit) - if err != nil { - panic(err) - } - - var good refCircuit - good.X = 2 - - // compute expected Y - var expectedY fr.Element - expectedY.SetUint64(2) - - for i := 0; i < nbConstraints; i++ { - expectedY.Mul(&expectedY, &expectedY) - } - - good.Y = (expectedY) - - return r1cs, &good -} - -func BenchmarkSetup(b *testing.B) { - r1cs, _ := referenceCircuit() - - var pk bls12_377groth16.ProvingKey - var vk bls12_377groth16.VerifyingKey - b.ResetTimer() - - b.Run("setup", func(b *testing.B) { - for i := 0; i < b.N; i++ { - bls12_377groth16.Setup(r1cs.(*cs.R1CS), &pk, &vk) - } - }) -} - -func BenchmarkProver(b *testing.B) { - r1cs, _solution := referenceCircuit() - fullWitness := bls12_377witness.Witness{} - _, err := fullWitness.FromAssignment(_solution, tVariable, false) - if err != nil { - b.Fatal(err) - } - - var pk bls12_377groth16.ProvingKey - bls12_377groth16.DummySetup(r1cs.(*cs.R1CS), &pk) - - b.ResetTimer() - b.Run("prover", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _, _ = bls12_377groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverConfig{}) - } - }) -} - -func BenchmarkVerifier(b *testing.B) { - r1cs, _solution := referenceCircuit() - fullWitness := bls12_377witness.Witness{} - _, err := fullWitness.FromAssignment(_solution, tVariable, false) - if err != nil { - b.Fatal(err) - } - publicWitness := bls12_377witness.Witness{} - _, err = publicWitness.FromAssignment(_solution, tVariable, true) - if err != nil { - b.Fatal(err) - } - - var pk bls12_377groth16.ProvingKey - var vk bls12_377groth16.VerifyingKey - bls12_377groth16.Setup(r1cs.(*cs.R1CS), &pk, &vk) - proof, err := bls12_377groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverConfig{}) - if err != nil { - panic(err) - } - - b.ResetTimer() - b.Run("verifier", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = bls12_377groth16.Verify(proof, &vk, publicWitness) - } - }) -} - -func BenchmarkProofSerialization(b *testing.B) { - r1cs, _solution := referenceCircuit() - fullWitness := bls12_377witness.Witness{} - _, err := fullWitness.FromAssignment(_solution, tVariable, false) - if err != nil { - b.Fatal(err) - } - - var pk bls12_377groth16.ProvingKey - var vk bls12_377groth16.VerifyingKey - bls12_377groth16.Setup(r1cs.(*cs.R1CS), &pk, &vk) - proof, err := bls12_377groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverConfig{}) - if err != nil { - panic(err) - } - - b.ReportAllocs() - - // --------------------------------------------------------------------------------------------- - // bls12_377groth16.Proof binary serialization - b.Run("proof: binary serialization (bls12_377groth16.Proof)", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - var buf bytes.Buffer - _, _ = proof.WriteTo(&buf) - } - }) - b.Run("proof: binary deserialization (bls12_377groth16.Proof)", func(b *testing.B) { - var buf bytes.Buffer - _, _ = proof.WriteTo(&buf) - var proofReconstructed bls12_377groth16.Proof - b.ResetTimer() - for i := 0; i < b.N; i++ { - buf := bytes.NewBuffer(buf.Bytes()) - _, _ = proofReconstructed.ReadFrom(buf) - } - }) - { - var buf bytes.Buffer - _, _ = proof.WriteTo(&buf) - } - - // --------------------------------------------------------------------------------------------- - // bls12_377groth16.Proof binary serialization (uncompressed) - b.Run("proof: binary raw serialization (bls12_377groth16.Proof)", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - var buf bytes.Buffer - _, _ = proof.WriteRawTo(&buf) - } - }) - b.Run("proof: binary raw deserialization (bls12_377groth16.Proof)", func(b *testing.B) { - var buf bytes.Buffer - _, _ = proof.WriteRawTo(&buf) - var proofReconstructed bls12_377groth16.Proof - b.ResetTimer() - for i := 0; i < b.N; i++ { - buf := bytes.NewBuffer(buf.Bytes()) - _, _ = proofReconstructed.ReadFrom(buf) - } - }) - { - var buf bytes.Buffer - _, _ = proof.WriteRawTo(&buf) - } - -} - -func BenchmarkProvingKeySerialization(b *testing.B) { - r1cs, _ := referenceCircuit() - - var pk bls12_377groth16.ProvingKey - bls12_377groth16.DummySetup(r1cs.(*cs.R1CS), &pk) - - var buf bytes.Buffer - // grow the buffer once - pk.WriteTo(&buf) - - b.ResetTimer() - b.Run("pk_serialize_compressed", func(b *testing.B) { - for i := 0; i < b.N; i++ { - buf.Reset() - pk.WriteTo(&buf) - } - }) - - compressedBytes := buf.Bytes() - b.ResetTimer() - b.Run("pk_deserialize_compressed_safe", func(b *testing.B) { - for i := 0; i < b.N; i++ { - pk.ReadFrom(bytes.NewReader(compressedBytes)) - } - }) - - b.ResetTimer() - b.Run("pk_deserialize_compressed_unsafe", func(b *testing.B) { - for i := 0; i < b.N; i++ { - pk.UnsafeReadFrom(bytes.NewReader(compressedBytes)) - } - }) - - b.ResetTimer() - b.Run("pk_serialize_raw", func(b *testing.B) { - for i := 0; i < b.N; i++ { - buf.Reset() - pk.WriteRawTo(&buf) - } - }) - - rawBytes := buf.Bytes() - b.ResetTimer() - b.Run("pk_deserialize_raw_safe", func(b *testing.B) { - for i := 0; i < b.N; i++ { - pk.ReadFrom(bytes.NewReader(rawBytes)) - } - }) - - b.ResetTimer() - b.Run("pk_deserialize_raw_unsafe", func(b *testing.B) { - for i := 0; i < b.N; i++ { - pk.UnsafeReadFrom(bytes.NewReader(rawBytes)) - } - }) -} - -var tVariable reflect.Type - -func init() { - tVariable = reflect.ValueOf(struct{ A frontend.Variable }{}).FieldByName("A").Type() -} diff --git a/internal/backend/bls12-377/groth16/prove.go b/internal/backend/bls12-377/groth16/prove.go index bfd52a57e7..4590530c00 100644 --- a/internal/backend/bls12-377/groth16/prove.go +++ b/internal/backend/bls12-377/groth16/prove.go @@ -24,7 +24,6 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/constraint/bls12-377" - bls12_377witness "github.com/consensys/gnark/internal/backend/bls12-377/witness" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" "math/big" @@ -52,7 +51,7 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knowledge of a r1cs with full witness (secret + public part). -func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bls12_377witness.Witness, opt backend.ProverConfig) (*Proof, error) { +func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness fr.Vector, opt backend.ProverConfig) (*Proof, error) { // TODO @gbotrel witness size check is done by R1CS, doesn't mean we shouldn't sanitize here. // if len(witness) != r1cs.NbPublicVariables-1+r1cs.NbSecretVariables { // return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public) + %d (secret)", len(witness), r1cs.NbPublicVariables-1+r1cs.NbSecretVariables, r1cs.NbPublicVariables, r1cs.NbSecretVariables) diff --git a/internal/backend/bls12-377/groth16/verify.go b/internal/backend/bls12-377/groth16/verify.go index 242705a35a..6a9819c093 100644 --- a/internal/backend/bls12-377/groth16/verify.go +++ b/internal/backend/bls12-377/groth16/verify.go @@ -21,7 +21,7 @@ import ( "fmt" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bls12-377" - bls12_377witness "github.com/consensys/gnark/internal/backend/bls12-377/witness" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/gnark/logger" "io" "math/big" @@ -34,7 +34,7 @@ var ( ) // Verify verifies a proof with given VerifyingKey and publicWitness -func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_377witness.Witness) error { +func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { nbPublicVars := len(vk.G1.K) if vk.CommitmentInfo.Is() { diff --git a/internal/backend/bls12-377/plonk/plonk_test.go b/internal/backend/bls12-377/plonk/plonk_test.go index 9c163a4dc0..245ac5fbc6 100644 --- a/internal/backend/bls12-377/plonk/plonk_test.go +++ b/internal/backend/bls12-377/plonk/plonk_test.go @@ -23,8 +23,6 @@ import ( "github.com/consensys/gnark/constraint/bls12-377" - bls12_377witness "github.com/consensys/gnark/internal/backend/bls12-377/witness" - bls12_377plonk "github.com/consensys/gnark/internal/backend/bls12-377/plonk" "bytes" @@ -102,8 +100,7 @@ func BenchmarkSetup(b *testing.B) { func BenchmarkProver(b *testing.B) { ccs, _solution, srs := referenceCircuit() - fullWitness := bls12_377witness.Witness{} - _, err := fullWitness.FromAssignment(_solution, tVariable, false) + fullWitness, err := frontend.NewWitness(_solution, fr.Modulus()) if err != nil { b.Fatal(err) } @@ -115,7 +112,7 @@ func BenchmarkProver(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, err = bls12_377plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverConfig{}) + _, err = bls12_377plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness.Vector().(fr.Vector), backend.ProverConfig{}) if err != nil { b.Fatal(err) } @@ -124,13 +121,12 @@ func BenchmarkProver(b *testing.B) { func BenchmarkVerifier(b *testing.B) { ccs, _solution, srs := referenceCircuit() - fullWitness := bls12_377witness.Witness{} - _, err := fullWitness.FromAssignment(_solution, tVariable, false) + fullWitness, err := frontend.NewWitness(_solution, fr.Modulus()) if err != nil { b.Fatal(err) } - publicWitness := bls12_377witness.Witness{} - _, err = publicWitness.FromAssignment(_solution, tVariable, true) + + publicWitness, err := frontend.NewWitness(_solution, fr.Modulus(), frontend.PublicOnly()) if err != nil { b.Fatal(err) } @@ -140,21 +136,20 @@ func BenchmarkVerifier(b *testing.B) { b.Fatal(err) } - proof, err := bls12_377plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverConfig{}) + proof, err := bls12_377plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness.Vector().(fr.Vector), backend.ProverConfig{}) if err != nil { panic(err) } b.ResetTimer() for i := 0; i < b.N; i++ { - _ = bls12_377plonk.Verify(proof, vk, publicWitness) + _ = bls12_377plonk.Verify(proof, vk, publicWitness.Vector().(fr.Vector)) } } func BenchmarkSerialization(b *testing.B) { ccs, _solution, srs := referenceCircuit() - fullWitness := bls12_377witness.Witness{} - _, err := fullWitness.FromAssignment(_solution, tVariable, false) + fullWitness, err := frontend.NewWitness(_solution, fr.Modulus()) if err != nil { b.Fatal(err) } @@ -164,7 +159,7 @@ func BenchmarkSerialization(b *testing.B) { b.Fatal(err) } - proof, err := bls12_377plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverConfig{}) + proof, err := bls12_377plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness.Vector().(fr.Vector), backend.ProverConfig{}) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bls12-377/plonk/prove.go b/internal/backend/bls12-377/plonk/prove.go index 3796066585..b29a962bd3 100644 --- a/internal/backend/bls12-377/plonk/prove.go +++ b/internal/backend/bls12-377/plonk/prove.go @@ -32,8 +32,6 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" - bls12_377witness "github.com/consensys/gnark/internal/backend/bls12-377/witness" - "github.com/consensys/gnark/constraint/bls12-377" "github.com/consensys/gnark-crypto/fiat-shamir" @@ -61,7 +59,7 @@ type Proof struct { } // Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witness, opt backend.ProverConfig) (*Proof, error) { +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backend.ProverConfig) (*Proof, error) { log := logger.Logger().With().Str("curve", spr.CurveID().String()).Int("nbConstraints", len(spr.Constraints)).Str("backend", "plonk").Logger() start := time.Now() diff --git a/internal/backend/bls12-377/plonk/verify.go b/internal/backend/bls12-377/plonk/verify.go index 74c3b02849..2cf320c732 100644 --- a/internal/backend/bls12-377/plonk/verify.go +++ b/internal/backend/bls12-377/plonk/verify.go @@ -29,8 +29,6 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bls12-377" - bls12_377witness "github.com/consensys/gnark/internal/backend/bls12-377/witness" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark/logger" @@ -40,7 +38,7 @@ var ( errWrongClaimedQuotient = errors.New("claimed quotient is not as expected") ) -func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_377witness.Witness) error { +func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { log := logger.Logger().With().Str("curve", "bls12_377").Str("backend", "plonk").Logger() start := time.Now() diff --git a/internal/backend/bls12-377/plonkfri/prove.go b/internal/backend/bls12-377/plonkfri/prove.go index 802de65953..f3c15d78f9 100644 --- a/internal/backend/bls12-377/plonkfri/prove.go +++ b/internal/backend/bls12-377/plonkfri/prove.go @@ -26,8 +26,6 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" - bls12_377witness "github.com/consensys/gnark/internal/backend/bls12-377/witness" - "github.com/consensys/gnark/constraint/bls12-377" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fri" @@ -69,7 +67,7 @@ type Proof struct { OpeningsId1Id2Id3mp [3]fri.OpeningProof } -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witness, opt backend.ProverConfig) (*Proof, error) { +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness fr.Vector, opt backend.ProverConfig) (*Proof, error) { var proof Proof diff --git a/internal/backend/bls12-377/plonkfri/verify.go b/internal/backend/bls12-377/plonkfri/verify.go index 4c3c04b4c3..b2d6630b07 100644 --- a/internal/backend/bls12-377/plonkfri/verify.go +++ b/internal/backend/bls12-377/plonkfri/verify.go @@ -23,14 +23,12 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fri" "math/big" - bls12_377witness "github.com/consensys/gnark/internal/backend/bls12-377/witness" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" ) var ErrInvalidAlgebraicRelation = errors.New("algebraic relation does not hold") -func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_377witness.Witness) error { +func Verify(proof *Proof, vk *VerifyingKey, publicWitness fr.Vector) error { // 0 - derive the challenges with Fiat Shamir hFunc := sha256.New() @@ -354,7 +352,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_377witness.Witne } // completeQk returns ∑_{i". -// If base == 10 and -z fits in a uint16 prefix "-" is added to the string. func (z *Element) Text(base int) string { if base < 2 || base > 36 { panic("invalid base") @@ -646,14 +647,6 @@ func (z *Element) Text(base int) string { } const maxUint16 = 65535 - if base == 10 { - var zzNeg Element - zzNeg.Neg(z) - zzNeg.fromMont() - if zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { - return "-" + strconv.FormatUint(zzNeg[0], base) - } - } zz := z.Bits() return strconv.FormatUint(zz[0], base) } @@ -708,14 +701,14 @@ func (z *Element) SetBytes(e []byte) *Element { // slow path. // get a big int from our pool - vv := field.BigIntPool.Get() + vv := pool.BigInt.Get() vv.SetBytes(e) // set big int z.SetBigInt(vv) // put temporary object back in pool - field.BigIntPool.Put(vv) + pool.BigInt.Put(vv) return z } @@ -752,17 +745,16 @@ func (z *Element) SetBigInt(v *big.Int) *Element { } // get temporary big int from the pool - vv := field.BigIntPool.Get() + vv := pool.BigInt.Get() // copy input + modular reduction - vv.Set(v) vv.Mod(v, &_modulus) // set big int byte value z.setBigInt(vv) // release object into pool - field.BigIntPool.Put(vv) + pool.BigInt.Put(vv) return z } @@ -806,7 +798,7 @@ func (z *Element) setBigInt(v *big.Int) *Element { // If the number is invalid this method leaves z unchanged and returns nil, error. func (z *Element) SetString(number string) (*Element, error) { // get temporary big int from the pool - vv := field.BigIntPool.Get() + vv := pool.BigInt.Get() if _, ok := vv.SetString(number, 0); !ok { return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number) @@ -815,7 +807,7 @@ func (z *Element) SetString(number string) (*Element, error) { z.SetBigInt(vv) // release object into pool - field.BigIntPool.Put(vv) + pool.BigInt.Put(vv) return z, nil } @@ -855,7 +847,7 @@ func (z *Element) UnmarshalJSON(data []byte) error { } // get temporary big int from the pool - vv := field.BigIntPool.Get() + vv := pool.BigInt.Get() if _, ok := vv.SetString(s, 0); !ok { return errors.New("can't parse into a big.Int: " + s) @@ -864,7 +856,7 @@ func (z *Element) UnmarshalJSON(data []byte) error { z.SetBigInt(vv) // release object into pool - field.BigIntPool.Put(vv) + pool.BigInt.Put(vv) return nil } diff --git a/internal/tinyfield/element_test.go b/internal/tinyfield/element_test.go index 892366a261..d6c74b751d 100644 --- a/internal/tinyfield/element_test.go +++ b/internal/tinyfield/element_test.go @@ -2084,20 +2084,14 @@ func TestElementJSON(t *testing.T) { encoded, err := json.Marshal(&s) assert.NoError(err) - // since our modulus is on 1 word, we may need to adjust "42" and "8000" values; + // we may need to adjust "42" and "8000" values for some moduli; see Text() method for more details. formatValue := func(v int64) string { - const maxUint16 = 65535 - var a, aNeg big.Int + var a big.Int a.SetInt64(v) a.Mod(&a, Modulus()) - aNeg.Neg(&a).Mod(&aNeg, Modulus()) - fmt.Println("aNeg", aNeg.Text(10)) - if aNeg.Uint64() != 0 && aNeg.Uint64() <= maxUint16 { - return "-" + aNeg.Text(10) - } return a.Text(10) } - expected := fmt.Sprintf("{\"A\":-1,\"B\":[0,0,%s],\"C\":null,\"D\":%s}", formatValue(42), formatValue(8000)) + expected := fmt.Sprintf("{\"A\":%s,\"B\":[0,0,%s],\"C\":null,\"D\":%s}", formatValue(-1), formatValue(42), formatValue(8000)) assert.Equal(expected, string(encoded)) // decode valid diff --git a/internal/tinyfield/vector.go b/internal/tinyfield/vector.go new file mode 100644 index 0000000000..d71e1ac83c --- /dev/null +++ b/internal/tinyfield/vector.go @@ -0,0 +1,132 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package tinyfield + +import ( + "bytes" + "encoding/binary" + "io" + "strings" +) + +// Vector represents a slice of Element. +// +// It implements the following interfaces: +// - Stringer +// - io.WriterTo +// - io.ReaderFrom +// - encoding.BinaryMarshaler +// - encoding.BinaryUnmarshaler +// - sort.Interface +type Vector []Element + +// MarshalBinary implements encoding.BinaryMarshaler +func (vector Vector) MarshalBinary() (data []byte, err error) { + var buf bytes.Buffer + + if _, err = vector.WriteTo(&buf); err != nil { + return + } + return buf.Bytes(), nil +} + +// UnmarshalBinary implements encoding.BinaryUnmarshaler +func (vector *Vector) UnmarshalBinary(data []byte) error { + r := bytes.NewReader(data) + _, err := vector.ReadFrom(r) + return err +} + +// WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. +// Length of the vector is encoded as a uint32 on the first 4 bytes. +func (vector Vector) WriteTo(w io.Writer) (int64, error) { + // encode slice length + if err := binary.Write(w, binary.BigEndian, uint32(len(vector))); err != nil { + return 0, err + } + + n := int64(4) + + var buf [Bytes]byte + for i := 0; i < len(vector); i++ { + BigEndian.PutElement(&buf, vector[i]) + m, err := w.Write(buf[:]) + n += int64(m) + if err != nil { + return n, err + } + } + return n, nil +} + +// ReadFrom implements io.ReaderFrom and reads a vector of big endian encoded Element. +// Length of the vector must be encoded as a uint32 on the first 4 bytes. +func (vector *Vector) ReadFrom(r io.Reader) (int64, error) { + + var buf [Bytes]byte + if read, err := io.ReadFull(r, buf[:4]); err != nil { + return int64(read), err + } + sliceLen := binary.BigEndian.Uint32(buf[:4]) + + n := int64(4) + (*vector) = make(Vector, sliceLen) + + for i := 0; i < int(sliceLen); i++ { + read, err := io.ReadFull(r, buf[:]) + n += int64(read) + if err != nil { + return n, err + } + (*vector)[i], err = BigEndian.Element(&buf) + if err != nil { + return n, err + } + } + + return n, nil +} + +// String implements fmt.Stringer interface +func (vector Vector) String() string { + var sbb strings.Builder + sbb.WriteByte('[') + for i := 0; i < len(vector); i++ { + sbb.WriteString(vector[i].String()) + if i != len(vector)-1 { + sbb.WriteByte(',') + } + } + sbb.WriteByte(']') + return sbb.String() +} + +// Len is the number of elements in the collection. +func (vector Vector) Len() int { + return len(vector) +} + +// Less reports whether the element with +// index i should sort before the element with index j. +func (vector Vector) Less(i, j int) bool { + return vector[i].Cmp(&vector[j]) == -1 +} + +// Swap swaps the elements with indexes i and j. +func (vector Vector) Swap(i, j int) { + vector[i], vector[j] = vector[j], vector[i] +} diff --git a/internal/tinyfield/vector_test.go b/internal/tinyfield/vector_test.go new file mode 100644 index 0000000000..e1db416306 --- /dev/null +++ b/internal/tinyfield/vector_test.go @@ -0,0 +1,72 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package tinyfield + +import ( + "github.com/stretchr/testify/require" + "reflect" + "sort" + "testing" +) + +func TestVectorSort(t *testing.T) { + assert := require.New(t) + + v := make(Vector, 3) + v[0].SetUint64(2) + v[1].SetUint64(3) + v[2].SetUint64(1) + + sort.Sort(v) + + assert.Equal("[1,2,3]", v.String()) +} + +func TestVectorRoundTrip(t *testing.T) { + assert := require.New(t) + + v1 := make(Vector, 3) + v1[0].SetUint64(2) + v1[1].SetUint64(3) + v1[2].SetUint64(1) + + b, err := v1.MarshalBinary() + assert.NoError(err) + + var v2 Vector + + err = v2.UnmarshalBinary(b) + assert.NoError(err) + + assert.True(reflect.DeepEqual(v1, v2)) +} + +func TestVectorEmptyRoundTrip(t *testing.T) { + assert := require.New(t) + + v1 := make(Vector, 0) + + b, err := v1.MarshalBinary() + assert.NoError(err) + + var v2 Vector + + err = v2.UnmarshalBinary(b) + assert.NoError(err) + + assert.True(reflect.DeepEqual(v1, v2)) +} diff --git a/internal/tinyfield/witness/witness.go b/internal/tinyfield/witness/witness.go deleted file mode 100644 index 331fc93858..0000000000 --- a/internal/tinyfield/witness/witness.go +++ /dev/null @@ -1,177 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by gnark DO NOT EDIT - -package witness - -import ( - "encoding/binary" - "fmt" - "io" - "reflect" - "strings" - - "github.com/consensys/gnark/frontend/schema" - - fr "github.com/consensys/gnark/internal/tinyfield" -) - -type Witness []fr.Element - -// WriteTo encodes witness to writer (implements io.WriterTo) -func (witness *Witness) WriteTo(w io.Writer) (int64, error) { - // encode slice length - if err := binary.Write(w, binary.BigEndian, uint32(len(*witness))); err != nil { - return 0, err - } - - n := int64(4) - - var buf [fr.Bytes]byte - for i := 0; i < len(*witness); i++ { - buf = (*witness)[i].Bytes() - m, err := w.Write(buf[:]) - n += int64(m) - if err != nil { - return n, err - } - } - return n, nil -} - -func (witness *Witness) Len() int { - return len(*witness) -} - -func (witness *Witness) Type() reflect.Type { - return reflect.TypeOf(fr.Element{}) -} - -func (witness *Witness) ReadFrom(r io.Reader) (int64, error) { - - var buf [fr.Bytes]byte - if read, err := io.ReadFull(r, buf[:4]); err != nil { - return int64(read), err - } - sliceLen := binary.BigEndian.Uint32(buf[:4]) - - n := int64(4) - - if len(*witness) != int(sliceLen) { - *witness = make([]fr.Element, sliceLen) - } - - for i := 0; i < int(sliceLen); i++ { - read, err := io.ReadFull(r, buf[:]) - n += int64(read) - if err != nil { - return n, err - } - (*witness)[i].SetBytes(buf[:]) - } - - return n, nil -} - -// FromAssignment extracts the witness and its schema -func (witness *Witness) FromAssignment(assignment interface{}, leafType reflect.Type, publicOnly bool) (*schema.Schema, error) { - s, err := schema.Walk(assignment, leafType, nil) - if err != nil { - return nil, err - } - nbSecret, nbPublic := s.Secret, s.Public - - if publicOnly { - nbSecret = 0 - } - - if len(*witness) < (nbPublic + nbSecret) { - (*witness) = make(Witness, nbPublic+nbSecret) - } else { - (*witness) = (*witness)[:nbPublic+nbSecret] - } - - var i, j int // indexes for secret / public variables - i = nbPublic // offset - - collectHandler := func(f schema.LeafInfo, tInput reflect.Value) error { - if publicOnly && f.Visibility != schema.Public { - return nil - } - if tInput.IsNil() { - return fmt.Errorf("when parsing variable %s: missing assignment", f.FullName()) - } - v := tInput.Interface() - - if v == nil { - return fmt.Errorf("when parsing variable %s: missing assignment", f.FullName()) - } - - if !publicOnly && f.Visibility == schema.Secret { - if _, err := (*witness)[i].SetInterface(v); err != nil { - return fmt.Errorf("when parsing variable %s: %v", f.FullName(), err) - } - i++ - } else if f.Visibility == schema.Public { - if _, err := (*witness)[j].SetInterface(v); err != nil { - return fmt.Errorf("when parsing variable %s: %v", f.FullName(), err) - } - j++ - } - return nil - } - if _, err := schema.Walk(assignment, leafType, collectHandler); err != nil { - return nil, err - } - return schema.New(assignment, leafType) -} - -// ToAssignment sets to leaf values to witness underlying vector element values (in order) -// see witness.MarshalBinary protocol description -func (witness *Witness) ToAssignment(assignment interface{}, leafType reflect.Type, publicOnly bool) { - i := 0 - setAddr := leafType.Kind() == reflect.Ptr - setHandler := func(v schema.Visibility) schema.LeafHandler { - return func(f schema.LeafInfo, tInput reflect.Value) error { - if f.Visibility == v { - if setAddr { - tInput.Set(reflect.ValueOf((&(*witness)[i]))) - } else { - tInput.Set(reflect.ValueOf(((*witness)[i]))) - } - - i++ - } - return nil - } - } - _, _ = schema.Walk(assignment, leafType, setHandler(schema.Public)) - if publicOnly { - return - } - _, _ = schema.Walk(assignment, leafType, setHandler(schema.Secret)) - -} - -func (witness *Witness) String() string { - var sbb strings.Builder - sbb.WriteByte('[') - for i := 0; i < len(*witness); i++ { - sbb.WriteString((*witness)[i].String()) - sbb.WriteByte(',') - } - sbb.WriteByte(']') - return sbb.String() -} diff --git a/std/groth16_bls12377/verifier_test.go b/std/groth16_bls12377/verifier_test.go index 3f2cea50ff..21d86eeb47 100644 --- a/std/groth16_bls12377/verifier_test.go +++ b/std/groth16_bls12377/verifier_test.go @@ -21,13 +21,13 @@ import ( "testing" "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/constraint" cs_bls12377 "github.com/consensys/gnark/constraint/bls12-377" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" groth16_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/groth16" - "github.com/consensys/gnark/internal/backend/bls12-377/witness" "github.com/consensys/gnark/std/algebra/sw_bls12377" "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/test" @@ -69,13 +69,12 @@ func generateBls12377InnerProof(t *testing.T, vk *groth16_bls12377.VerifyingKey, assignment.PreImage = preImage assignment.Hash = publicHash - var witness, publicWitness witness.Witness - _, err = witness.FromAssignment(&assignment, tVariable, false) + witness, err := frontend.NewWitness(&assignment, ecc.BLS12_377.ScalarField()) if err != nil { t.Fatal(err) } - _, err = publicWitness.FromAssignment(&assignment, tVariable, true) + publicWitness, err := witness.Public() if err != nil { t.Fatal(err) } @@ -87,7 +86,7 @@ func generateBls12377InnerProof(t *testing.T, vk *groth16_bls12377.VerifyingKey, t.Fatal(err) } - _proof, err := groth16_bls12377.Prove(r1cs.(*cs_bls12377.R1CS), &pk, witness, backend.ProverConfig{}) + _proof, err := groth16_bls12377.Prove(r1cs.(*cs_bls12377.R1CS), &pk, witness.Vector().(fr.Vector), backend.ProverConfig{}) if err != nil { t.Fatal(err) } @@ -96,7 +95,7 @@ func generateBls12377InnerProof(t *testing.T, vk *groth16_bls12377.VerifyingKey, proof.Krs = _proof.Krs // before returning verifies that the proof passes on bls12377 - if err := groth16_bls12377.Verify(proof, vk, publicWitness); err != nil { + if err := groth16_bls12377.Verify(proof, vk, publicWitness.Vector().(fr.Vector)); err != nil { t.Fatal(err) } diff --git a/std/groth16_bls24315/verifier_test.go b/std/groth16_bls24315/verifier_test.go index 6a209c30c6..0287d84815 100644 --- a/std/groth16_bls24315/verifier_test.go +++ b/std/groth16_bls24315/verifier_test.go @@ -21,13 +21,13 @@ import ( "testing" "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/constraint" cs_bls24315 "github.com/consensys/gnark/constraint/bls24-315" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" groth16_bls24315 "github.com/consensys/gnark/internal/backend/bls24-315/groth16" - "github.com/consensys/gnark/internal/backend/bls24-315/witness" "github.com/consensys/gnark/std/algebra/sw_bls24315" "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/test" @@ -70,14 +70,12 @@ func generateBls24315InnerProof(t *testing.T, vk *groth16_bls24315.VerifyingKey, assignment.PreImage = preImage assignment.Hash = publicHash - var witness, publicWitness witness.Witness - - _, err = witness.FromAssignment(&assignment, tVariable, false) + witness, err := frontend.NewWitness(&assignment, ecc.BLS24_315.ScalarField()) if err != nil { t.Fatal(err) } - _, err = publicWitness.FromAssignment(&assignment, tVariable, true) + publicWitness, err := witness.Public() if err != nil { t.Fatal(err) } @@ -89,7 +87,7 @@ func generateBls24315InnerProof(t *testing.T, vk *groth16_bls24315.VerifyingKey, t.Fatal(err) } - _proof, err := groth16_bls24315.Prove(r1cs.(*cs_bls24315.R1CS), &pk, witness, backend.ProverConfig{}) + _proof, err := groth16_bls24315.Prove(r1cs.(*cs_bls24315.R1CS), &pk, witness.Vector().(fr.Vector), backend.ProverConfig{}) if err != nil { t.Fatal(err) } @@ -98,7 +96,7 @@ func generateBls24315InnerProof(t *testing.T, vk *groth16_bls24315.VerifyingKey, proof.Krs = _proof.Krs // before returning verifies that the proof passes on bls24315 - if err := groth16_bls24315.Verify(proof, vk, publicWitness); err != nil { + if err := groth16_bls24315.Verify(proof, vk, publicWitness.Vector().(fr.Vector)); err != nil { t.Fatal(err) } } diff --git a/std/hints_test.go b/std/hints_test.go index ab824942f3..f88779b231 100644 --- a/std/hints_test.go +++ b/std/hints_test.go @@ -1,7 +1,6 @@ package std import ( - "github.com/consensys/gnark/backend/witness" "github.com/consensys/gnark/constraint" ) @@ -17,5 +16,5 @@ func ExampleRegisterHints() { RegisterHints() // then --> - _ = ccs.IsSolved(&witness.Witness{}) + _ = ccs.IsSolved(nil) } diff --git a/test/assert.go b/test/assert.go index 82648f7c0b..c104f6c646 100644 --- a/test/assert.go +++ b/test/assert.go @@ -34,6 +34,7 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/frontend/cs/scs" + "github.com/consensys/gnark/frontend/schema" "github.com/stretchr/testify/require" ) @@ -66,7 +67,7 @@ func (assert *Assert) Run(fn func(assert *Assert), descs ...string) { desc := strings.Join(descs, "/") assert.t.Run(desc, func(t *testing.T) { // TODO(ivokub): access to compiled cache is not synchronized -- running - // the tests in parallel will result in undetermined behaviour. A better + // the tests in parallel will result in undetermined behavior. A better // approach would be to synchronize compiled and run the tests in // parallel for a potential speedup. assert := &Assert{t, require.New(t), assert.compiled} @@ -105,22 +106,26 @@ func (assert *Assert) ProverSucceeded(circuit frontend.Circuit, validAssignment // do a round trip marshalling test assert.Run(func(assert *Assert) { assert.t.Parallel() - assert.t.Skip("skipping json") - assert.marshalWitness(validWitness, curve, JSON) - }, curve.String(), "marshal/json") - assert.Run(func(assert *Assert) { - assert.t.Parallel() - assert.marshalWitness(validWitness, curve, Binary) + assert.marshalWitness(validWitness, curve, false) }, curve.String(), "marshal/binary") assert.Run(func(assert *Assert) { assert.t.Parallel() - assert.t.Skip("skipping json") - assert.marshalWitness(validPublicWitness, curve, JSON, frontend.PublicOnly()) - }, curve.String(), "marshal-public/json") - assert.Run(func(assert *Assert) { - assert.t.Parallel() - assert.marshalWitness(validPublicWitness, curve, Binary, frontend.PublicOnly()) + assert.marshalWitness(validPublicWitness, curve, true) }, curve.String(), "marshal-public/binary") + + if !testing.Short() { + assert.Run(func(assert *Assert) { + assert.t.Parallel() + s := lazySchema(circuit)() + assert.marshalWitnessJSON(validWitness, s, curve, false) + }, curve.String(), "marshal/json") + assert.Run(func(assert *Assert) { + assert.t.Parallel() + s := lazySchema(circuit)() + assert.marshalWitnessJSON(validWitness, s, curve, true) + }, curve.String(), "marshal-public/json") + + } } for _, b := range opt.backends { @@ -128,7 +133,7 @@ func (assert *Assert) ProverSucceeded(circuit frontend.Circuit, validAssignment b := b assert.Run(func(assert *Assert) { - checkError := func(err error) { assert.checkError(err, b, curve, validWitness) } + checkError := func(err error) { assert.checkError(err, b, curve, validWitness, lazySchema(circuit)) } // 1- compile the circuit ccs, err := assert.compile(circuit, curve, b, opt.compileOpts) @@ -216,8 +221,8 @@ func (assert *Assert) ProverFailed(circuit frontend.Circuit, invalidAssignment f b := b assert.Run(func(assert *Assert) { - checkError := func(err error) { assert.checkError(err, b, curve, invalidWitness) } - mustError := func(err error) { assert.mustError(err, b, curve, invalidWitness) } + checkError := func(err error) { assert.checkError(err, b, curve, invalidWitness, lazySchema(circuit)) } + mustError := func(err error) { assert.mustError(err, b, curve, invalidWitness, lazySchema(circuit)) } // 1- compile the circuit ccs, err := assert.compile(circuit, curve, b, opt.compileOpts) @@ -289,7 +294,7 @@ func (assert *Assert) solvingSucceeded(circuit frontend.Circuit, validAssignment validWitness, err := frontend.NewWitness(validAssignment, curve.ScalarField()) assert.NoError(err, "can't parse valid assignment") - checkError := func(err error) { assert.checkError(err, b, curve, validWitness) } + checkError := func(err error) { assert.checkError(err, b, curve, validWitness, lazySchema(circuit)) } // 1- compile the circuit ccs, err := assert.compile(circuit, curve, b, opt.compileOpts) @@ -318,13 +323,24 @@ func (assert *Assert) SolvingFailed(circuit frontend.Circuit, invalidWitness fro } } +func lazySchema(circuit frontend.Circuit) func() *schema.Schema { + return func() *schema.Schema { + // we only parse the schema if we need to display the witness in json. + s, err := schema.New(circuit, tVariable) + if err != nil { + panic("couldn't parse schema from circuit: " + err.Error()) + } + return s + } +} + func (assert *Assert) solvingFailed(circuit frontend.Circuit, invalidAssignment frontend.Circuit, b backend.ID, curve ecc.ID, opt *testingConfig) { // parse assignment invalidWitness, err := frontend.NewWitness(invalidAssignment, curve.ScalarField()) assert.NoError(err, "can't parse invalid assignment") - checkError := func(err error) { assert.checkError(err, b, curve, invalidWitness) } - mustError := func(err error) { assert.mustError(err, b, curve, invalidWitness) } + checkError := func(err error) { assert.checkError(err, b, curve, invalidWitness, lazySchema(circuit)) } + mustError := func(err error) { assert.mustError(err, b, curve, invalidWitness, lazySchema(circuit)) } // 1- compile the circuit ccs, err := assert.compile(circuit, curve, b, opt.compileOpts) @@ -482,12 +498,12 @@ func (assert *Assert) options(opts ...TestingOption) testingConfig { } // ensure the error is set, else fails the test -func (assert *Assert) mustError(err error, backendID backend.ID, curve ecc.ID, witness *witness.Witness) { +func (assert *Assert) mustError(err error, backendID backend.ID, curve ecc.ID, w witness.Witness, lazyS func() *schema.Schema) { if err != nil { return } var json string - bjson, err := witness.MarshalJSON() + bjson, err := w.ToJSON(lazyS()) if err != nil { json = err.Error() } else { @@ -499,7 +515,7 @@ func (assert *Assert) mustError(err error, backendID backend.ID, curve ecc.ID, w } // ensure the error is nil, else fails the test -func (assert *Assert) checkError(err error, backendID backend.ID, curve ecc.ID, witness *witness.Witness) { +func (assert *Assert) checkError(err error, backendID backend.ID, curve ecc.ID, w witness.Witness, lazyS func() *schema.Schema) { if err == nil { return } @@ -507,7 +523,7 @@ func (assert *Assert) checkError(err error, backendID backend.ID, curve ecc.ID, var json string e := fmt.Errorf("%s(%s): %w", backendID.String(), curve.String(), err) - bjson, err := witness.MarshalJSON() + bjson, err := w.ToJSON(lazyS()) if err != nil { json = err.Error() } else { @@ -518,46 +534,44 @@ func (assert *Assert) checkError(err error, backendID backend.ID, curve ecc.ID, assert.FailNow(e.Error()) } -type marshaller uint8 +func (assert *Assert) marshalWitness(w witness.Witness, curveID ecc.ID, publicOnly bool) { + // serialize the vector to binary + var err error + if publicOnly { + w, err = w.Public() + assert.NoError(err) + } + data, err := w.MarshalBinary() + assert.NoError(err) -const ( - JSON marshaller = iota - Binary -) + // re-read + witness, err := witness.New(curveID.ScalarField()) + assert.NoError(err) + err = witness.UnmarshalBinary(data) + assert.NoError(err) -func (m marshaller) String() string { - if m == JSON { - return "JSON" - } - return "Binary" + witnessMatch := reflect.DeepEqual(w, witness) + + assert.True(witnessMatch, "round trip marshaling failed") } -func (assert *Assert) marshalWitness(w *witness.Witness, curveID ecc.ID, m marshaller, opts ...frontend.WitnessOption) { - marshal := w.MarshalBinary - if m == JSON { - marshal = w.MarshalJSON +func (assert *Assert) marshalWitnessJSON(w witness.Witness, s *schema.Schema, curveID ecc.ID, publicOnly bool) { + var err error + if publicOnly { + w, err = w.Public() + assert.NoError(err) } // serialize the vector to binary - data, err := marshal() + data, err := w.ToJSON(s) assert.NoError(err) // re-read - witness := witness.Witness{CurveID: curveID, Schema: w.Schema} - unmarshal := witness.UnmarshalBinary - if m == JSON { - unmarshal = witness.UnmarshalJSON - } - err = unmarshal(data) + witness, err := witness.New(curveID.ScalarField()) + assert.NoError(err) + err = witness.FromJSON(s, data) assert.NoError(err) - witnessMatch := reflect.DeepEqual(*w, witness) - - if !witnessMatch { - assert.Log("original json", string(data)) - // assert.Log("original vector", w.Vector) - // assert.Log("reconstructed vector", witness.Vector) - } - - assert.True(witnessMatch, m.String()+" round trip marshaling failed") + witnessMatch := reflect.DeepEqual(w, witness) + assert.True(witnessMatch, "round trip marshaling failed") } diff --git a/test/engine.go b/test/engine.go index 2375922ca5..58311214d2 100644 --- a/test/engine.go +++ b/test/engine.go @@ -30,7 +30,7 @@ import ( "github.com/consensys/gnark/logger" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark-crypto/field" + "github.com/consensys/gnark-crypto/field/pool" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend" @@ -158,13 +158,13 @@ func (e *engine) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend } func (e *engine) MulAcc(a, b, c frontend.Variable) frontend.Variable { - bc := field.BigIntPool.Get() + bc := pool.BigInt.Get() bc.Mul(e.toBigInt(b), e.toBigInt(c)) _a := e.toBigInt(a) _a.Add(_a, bc).Mod(_a, e.modulus()) - field.BigIntPool.Put(bc) + pool.BigInt.Put(bc) return _a }