diff --git a/backend/groth16/groth16_test.go b/backend/groth16/groth16_test.go index 24fca03aed..027dc388d2 100644 --- a/backend/groth16/groth16_test.go +++ b/backend/groth16/groth16_test.go @@ -1,17 +1,56 @@ package groth16_test import ( + "fmt" "math/big" "testing" "github.com/consensys/gnark" "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/test" ) +func TestCustomHashToField(t *testing.T) { + assert := test.NewAssert(t) + assignment := &commitmentCircuit{X: 1} + for _, curve := range getCurves() { + assert.Run(func(assert *test.Assert) { + ccs, err := frontend.Compile(curve.ScalarField(), r1cs.NewBuilder, &commitmentCircuit{}) + assert.NoError(err) + pk, vk, err := groth16.Setup(ccs) + assert.NoError(err) + witness, err := frontend.NewWitness(assignment, curve.ScalarField()) + assert.NoError(err) + assert.Run(func(assert *test.Assert) { + proof, err := groth16.Prove(ccs, pk, witness, backend.WithProverHashToFieldFunction(constantHash{})) + assert.NoError(err) + pubWitness, err := witness.Public() + assert.NoError(err) + err = groth16.Verify(proof, vk, pubWitness, backend.WithVerifierHashToFieldFunction(constantHash{})) + assert.NoError(err) + }, "custom success") + assert.Run(func(assert *test.Assert) { + proof, err := groth16.Prove(ccs, pk, witness, backend.WithProverHashToFieldFunction(constantHash{})) + assert.NoError(err) + pubWitness, err := witness.Public() + assert.NoError(err) + err = groth16.Verify(proof, vk, pubWitness) + assert.Error(err) + }, "prover_only") + assert.Run(func(assert *test.Assert) { + proof, err := groth16.Prove(ccs, pk, witness) + assert.Error(err) + _ = proof + }, "verifier_only") + }, curve.String()) + } +} + //--------------------// // benches // //--------------------// @@ -116,6 +155,27 @@ func referenceCircuit(curve ecc.ID) (constraint.ConstraintSystem, frontend.Circu return r1cs, &good } +type commitmentCircuit struct { + X frontend.Variable +} + +func (c *commitmentCircuit) Define(api frontend.API) error { + cmt, err := api.(frontend.Committer).Commit(c.X) + if err != nil { + return fmt.Errorf("commit: %w", err) + } + api.AssertIsEqual(cmt, "0xaabbcc") + return nil +} + +type constantHash struct{} + +func (h constantHash) Write(p []byte) (n int, err error) { return len(p), nil } +func (h constantHash) Sum(b []byte) []byte { return []byte{0xaa, 0xbb, 0xcc} } +func (h constantHash) Reset() {} +func (h constantHash) Size() int { return 3 } +func (h constantHash) BlockSize() int { return 32 } + func getCurves() []ecc.ID { if testing.Short() { return []ecc.ID{ecc.BN254}