Skip to content

Commit

Permalink
perf: emulated equality assertion (#1064)
Browse files Browse the repository at this point in the history
* perf: sum over limbs for IsZero

* perf: use mulmod for equality assertion

* fix: handle edge case in mulcheck with zero limbs

* refactor: do not use temp var

* feat: remove AssertLimbsEquality

* feat: implement shortOne() method

* chore: remove unused private methods

* docs: equality assertion

* fix: deduce maximum degree from all mulcheck inputs

* test: enable all mul tests

* chore: stats

* refactor: generic impl for assert/mul

* fix: mul pre cond overflow computation

* docs: comments

* chore: stats
  • Loading branch information
ivokub committed Mar 7, 2024
1 parent 22d2c33 commit 3dedc99
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 470 deletions.
Binary file modified internal/stats/latest.stats
Binary file not shown.
68 changes: 8 additions & 60 deletions std/math/emulated/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ The complexity of native limb-wise multiplication is k^2. This translates
directly to the complexity in the number of constraints in the constraint
system.
For multiplication, we would instead use polynomial representation of the elements:
For multiplication, we would instead use polynomial representation of the
elements:
x = ∑_{i=0}^k x_i 2^{w i}
y = ∑_{i=0}^k y_i 2^{w i}.
Expand Down Expand Up @@ -140,68 +141,15 @@ larger than every limb of b. The subtraction is performed as
# Equality checking
The package provides two ways to check equality -- limb-wise equality check and
checking equality by value.
Equality checking is performed using modular multiplication. To check that a, b
are equal modulo r, we compute
In the limb-wise equality check we check that the integer values of the elements
x and y are equal. We have to carry the excess using bit decomposition (which
makes the computation fairly inefficient). To reduce the number of bit
decompositions, we instead carry over the excess of the difference of the limbs
instead. As we take the difference, then similarly as computing the padding in
subtraction algorithm, we need to add padding to the limbs before subtracting
limb-wise to avoid underflows. However, the padding in this case is slightly
different -- we do not need the padding to be divisible by the modulus, but
instead need that the limb padding is larger than the limb which is being
subtracted.
diff = b-a,
Lets look at the algorithm itself. We assume that the overflow f of x is larger
than y. If overflow of y is larger, then we can just swap the arguments and
apply the same argumentation. Let
and enforce modular multiplication check using the techniques for modular
multiplication:
maxValue = 1 << (k+f), // padding for limbs
maxValueShift = 1 << f. // carry part of the padding
For every limb we compute the difference as
diff_0 = maxValue+x_0-y_0,
diff_i = maxValue+carry_i+x_i-y_i-maxValueShift.
We check that the normal part of the difference is zero and carry the rest over
to next limb:
diff_i[0:k] == 0,
carry_{i+1} = diff_i[k:k+f+1] // we also carry over the padding bit.
Finally, after we have compared all the limbs, we still need to check that the
final carry corresponds to the padding. We add final check:
carry_k == maxValueShift.
We can further optimise the limb-wise equality check by first regrouping the
limbs. The idea is to group several limbs so that the result would still fit
into the scalar field. If
x = ∑_{i=0}^k x_i 2^{w i},
then we can instead take w' divisible by w such that
x = ∑_{i=0}^(k/(w'/w)) x'_i 2^{w' i},
where
x'_j = ∑_{i=0}^(w'/w) x_{j*w'/w+i} 2^{w i}.
For element value equality check, we check that two elements x and y are equal
modulo r and for that we need to show that r divides x-y. As mentioned in the
subtraction section, we add sufficient padding such that x-y does not underflow
and its integer value is always larger than 0. We use hint function to compute z
such that
x-y = z*r,
compute z*r and use limbwise equality checking to show that
x-y == z*r.
diff * 1 = 0 + k * r.
# Bitwidth enforcement
Expand Down
35 changes: 2 additions & 33 deletions std/math/emulated/element_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,42 +17,11 @@ import (

const testCurve = ecc.BN254

type AssertLimbEqualityCircuit[T FieldParams] struct {
A, B Element[T]
}

func (c *AssertLimbEqualityCircuit[T]) Define(api frontend.API) error {
f, err := NewField[T](api)
if err != nil {
return err
}
f.AssertLimbsEquality(&c.A, &c.B)
return nil
}

func testName[T FieldParams]() string {
var fp T
return fmt.Sprintf("%s/limb=%d", reflect.TypeOf(fp).Name(), fp.BitsPerLimb())
}

func TestAssertLimbEqualityNoOverflow(t *testing.T) {
testAssertLimbEqualityNoOverflow[Goldilocks](t)
testAssertLimbEqualityNoOverflow[Secp256k1Fp](t)
testAssertLimbEqualityNoOverflow[BN254Fp](t)
}

func testAssertLimbEqualityNoOverflow[T FieldParams](t *testing.T) {
var fp T
assert := test.NewAssert(t)
assert.Run(func(assert *test.Assert) {
var circuit, witness AssertLimbEqualityCircuit[T]
val, _ := rand.Int(rand.Reader, fp.Modulus())
witness.A = ValueOf[T](val)
witness.B = ValueOf[T](val)
assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness))
}, testName[T]())
}

// TODO: add also cases which should fail

type AssertIsLessEqualThanCircuit[T FieldParams] struct {
Expand Down Expand Up @@ -184,9 +153,9 @@ func (c *MulNoOverflowCircuit[T]) Define(api frontend.API) error {
}

func TestMulCircuitNoOverflow(t *testing.T) {
// testMulCircuitNoOverflow[Goldilocks](t)
testMulCircuitNoOverflow[Goldilocks](t)
testMulCircuitNoOverflow[Secp256k1Fp](t)
// testMulCircuitNoOverflow[BN254Fp](t)
testMulCircuitNoOverflow[BN254Fp](t)
}

func testMulCircuitNoOverflow[T FieldParams](t *testing.T) {
Expand Down
71 changes: 18 additions & 53 deletions std/math/emulated/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,16 @@ type Field[T FieldParams] struct {
maxOfOnce sync.Once

// constants for often used elements n, 0 and 1. Allocated only once
nConstOnce sync.Once
nConst *Element[T]
nprevConstOnce sync.Once
nprevConst *Element[T]
zeroConstOnce sync.Once
zeroConst *Element[T]
oneConstOnce sync.Once
oneConst *Element[T]
nConstOnce sync.Once
nConst *Element[T]
nprevConstOnce sync.Once
nprevConst *Element[T]
zeroConstOnce sync.Once
zeroConst *Element[T]
oneConstOnce sync.Once
oneConst *Element[T]
shortOneConstOnce sync.Once
shortOneConst *Element[T]

log zerolog.Logger

Expand Down Expand Up @@ -146,6 +148,14 @@ func (f *Field[T]) One() *Element[T] {
return f.oneConst
}

// shortOne returns one as a constant stored in a single limb.
func (f *Field[T]) shortOne() *Element[T] {
f.shortOneConstOnce.Do(func() {
f.shortOneConst = f.newInternalElement([]frontend.Variable{1}, 0)
})
return f.shortOneConst
}

// Modulus returns the modulus of the emulated ring as a constant.
func (f *Field[T]) Modulus() *Element[T] {
f.nConstOnce.Do(func() {
Expand Down Expand Up @@ -248,51 +258,6 @@ func (f *Field[T]) constantValue(v *Element[T]) (*big.Int, bool) {
return res, true
}

// compact returns parameters which allow for most optimal regrouping of
// limbs. In regrouping the limbs, we encode multiple existing limbs as a linear
// combination in a single new limb.
// compact returns a and b minimal (in number of limbs) representation that fits in the snark field
func (f *Field[T]) compact(a, b *Element[T]) (ac, bc []frontend.Variable, bitsPerLimb uint) {
// omit width reduction as is done in the calling method already
maxOverflow := max(a.overflow, b.overflow)
// subtract one bit as can not potentially use all bits of Fr and one bit as
// grouping may overflow
maxNbBits := uint(f.api.Compiler().FieldBitLen()) - 2 - maxOverflow
groupSize := maxNbBits / f.fParams.BitsPerLimb()
if groupSize == 0 {
// no space for compact
return a.Limbs, b.Limbs, f.fParams.BitsPerLimb()
}

bitsPerLimb = f.fParams.BitsPerLimb() * groupSize

ac = f.compactLimbs(a, groupSize, bitsPerLimb)
bc = f.compactLimbs(b, groupSize, bitsPerLimb)
return
}

// compactLimbs perform the regrouping of limbs between old and new parameters.
func (f *Field[T]) compactLimbs(e *Element[T], groupSize, bitsPerLimb uint) []frontend.Variable {
if f.fParams.BitsPerLimb() == bitsPerLimb {
return e.Limbs
}
nbLimbs := (uint(len(e.Limbs)) + groupSize - 1) / groupSize
r := make([]frontend.Variable, nbLimbs)
coeffs := make([]*big.Int, groupSize)
one := big.NewInt(1)
for i := range coeffs {
coeffs[i] = new(big.Int)
coeffs[i].Lsh(one, f.fParams.BitsPerLimb()*uint(i))
}
for i := uint(0); i < nbLimbs; i++ {
r[i] = uint(0)
for j := uint(0); j < groupSize && i*groupSize+j < uint(len(e.Limbs)); j++ {
r[i] = f.api.Add(r[i], f.api.Mul(coeffs[j], e.Limbs[i*groupSize+j]))
}
}
return r
}

// maxOverflow returns the maximal possible overflow for the element. If the
// overflow of the next operation exceeds the value returned by this method,
// then the limbs may overflow the native field.
Expand Down
123 changes: 24 additions & 99 deletions std/math/emulated/field_assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,93 +2,10 @@ package emulated

import (
"fmt"
"math/big"

"github.com/consensys/gnark/frontend"
)

// assertLimbsEqualitySlow is the main routine in the package. It asserts that the
// two slices of limbs represent the same integer value. This is also the most
// costly operation in the package as it does bit decomposition of the limbs.
func (f *Field[T]) assertLimbsEqualitySlow(api frontend.API, l, r []frontend.Variable, nbBits, nbCarryBits uint) {

nbLimbs := max(len(l), len(r))
maxValue := new(big.Int).Lsh(big.NewInt(1), nbBits+nbCarryBits)
maxValueShift := new(big.Int).Lsh(big.NewInt(1), nbCarryBits)

var carry frontend.Variable = 0
for i := 0; i < nbLimbs; i++ {
diff := api.Add(maxValue, carry)
if i < len(l) {
diff = api.Add(diff, l[i])
}
if i < len(r) {
diff = api.Sub(diff, r[i])
}
if i > 0 {
diff = api.Sub(diff, maxValueShift)
}

// carry is stored in the highest bits of diff[nbBits:nbBits+nbCarryBits+1]
// we know that diff[:nbBits] are 0 bits, but still need to constrain them.
// to do both; we do a "clean" right shift and only need to boolean constrain the carry part
carry = f.rsh(diff, int(nbBits), int(nbBits+nbCarryBits+1))
}
api.AssertIsEqual(carry, maxValueShift)
}

func (f *Field[T]) rsh(v frontend.Variable, startDigit, endDigit int) frontend.Variable {
// if v is a constant, work with the big int value.
if c, ok := f.api.Compiler().ConstantValue(v); ok {
bits := make([]frontend.Variable, endDigit-startDigit)
for i := 0; i < len(bits); i++ {
bits[i] = c.Bit(i + startDigit)
}
return bits
}
shifted, err := f.api.Compiler().NewHint(RightShift, 1, startDigit, v)
if err != nil {
panic(fmt.Sprintf("right shift: %v", err))
}
f.checker.Check(shifted[0], endDigit-startDigit)
shift := new(big.Int).Lsh(big.NewInt(1), uint(startDigit))
composed := f.api.Mul(shifted[0], shift)
f.api.AssertIsEqual(composed, v)
return shifted[0]
}

// AssertLimbsEquality asserts that the limbs represent a same integer value.
// This method does not ensure that the values are equal modulo the field order.
// For strict equality, use AssertIsEqual.
func (f *Field[T]) AssertLimbsEquality(a, b *Element[T]) {
f.enforceWidthConditional(a)
f.enforceWidthConditional(b)
ba, aConst := f.constantValue(a)
bb, bConst := f.constantValue(b)
if aConst && bConst {
ba.Mod(ba, f.fParams.Modulus())
bb.Mod(bb, f.fParams.Modulus())
if ba.Cmp(bb) != 0 {
panic(fmt.Errorf("constant values are different: %s != %s", ba.String(), bb.String()))
}
return
}

// first, we check if we can compact a and b; they could be using 8 limbs of 32bits
// but with our snark field, we could express them in 2 limbs of 128bits, which would make bit decomposition
// and limbs equality in-circuit (way) cheaper
ca, cb, bitsPerLimb := f.compact(a, b)

// slow path -- the overflows are different. Need to compare with carries.
// TODO: we previously assumed that one side was "larger" than the other
// side, but I think this assumption is not valid anymore
if a.overflow > b.overflow {
f.assertLimbsEqualitySlow(f.api, ca, cb, bitsPerLimb, a.overflow)
} else {
f.assertLimbsEqualitySlow(f.api, cb, ca, bitsPerLimb, b.overflow)
}
}

// enforceWidth enforces the width of the limbs. When modWidth is true, then the
// limbs are asserted to be the width of the modulus (highest limb may be less
// than full limb width). Otherwise, every limb is assumed to have same width
Expand Down Expand Up @@ -129,19 +46,7 @@ func (f *Field[T]) AssertIsEqual(a, b *Element[T]) {
}

diff := f.Sub(b, a)

// we compute k such that diff / p == k
// so essentially, we say "I know an element k such that k*p == diff"
// hence, diff == 0 mod p
p := f.Modulus()
k, err := f.computeQuoHint(diff)
if err != nil {
panic(fmt.Sprintf("hint error: %v", err))
}

kp := f.reduceAndOp(f.mul, f.mulPreCond, k, p)

f.AssertLimbsEquality(diff, kp)
f.checkZero(diff)
}

// AssertIsLessOrEqual ensures that e is less or equal than a. For proper
Expand Down Expand Up @@ -196,11 +101,31 @@ func (f *Field[T]) AssertIsInRange(a *Element[T]) {
func (f *Field[T]) IsZero(a *Element[T]) frontend.Variable {
ca := f.Reduce(a)
f.AssertIsInRange(ca)
res := f.api.IsZero(ca.Limbs[0])
// we use two approaches for checking if the element is exactly zero. The
// first approach is to check that every limb individually is zero. The
// second approach is to check if the sum of all limbs is zero. Usually, we
// cannot use this approach as we could have false positive due to overflow
// in the native field. However, as the widths of the limbs are restricted,
// then we can ensure in most cases that no overflows happen.

// as ca is already reduced, then every limb overflow is already 0. Only
// every addition adds a bit to the overflow
totalOverflow := len(ca.Limbs) - 1
if totalOverflow < int(f.maxOverflow()) {
// the sums of limbs would overflow the native field. Use the first
// approach instead.
res := f.api.IsZero(ca.Limbs[0])
for i := 1; i < len(ca.Limbs); i++ {
res = f.api.Mul(res, f.api.IsZero(ca.Limbs[i]))
}
return res
}
// default case, limbs sum does not overflow the native field
limbSum := ca.Limbs[0]
for i := 1; i < len(ca.Limbs); i++ {
res = f.api.Mul(res, f.api.IsZero(ca.Limbs[i]))
limbSum = f.api.Add(limbSum, ca.Limbs[i])
}
return res
return f.api.IsZero(limbSum)
}

// // Cmp returns:
Expand Down

0 comments on commit 3dedc99

Please sign in to comment.