forked from Consensys/gnark
-
Notifications
You must be signed in to change notification settings - Fork 0
/
field.go
318 lines (284 loc) · 9.75 KB
/
field.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
package emulated
import (
"fmt"
"math/big"
"sync"
"github.com/aakash4dev/gnark2/frontend"
"github.com/aakash4dev/gnark2/internal/kvstore"
"github.com/aakash4dev/gnark2/internal/utils"
"github.com/aakash4dev/gnark2/logger"
"github.com/aakash4dev/gnark2/std/rangecheck"
"github.com/rs/zerolog"
"golang.org/x/exp/constraints"
)
// Field holds the configuration for non-native field operations. The field
// parameters (modulus, number of limbs) are given by the [FieldParams] type
// parameter. If [FieldParams.IsPrime] is true, then inverse and division
// operations are allowed.
//
// Several members are initialized lazily under their own [sync.Once], so a
// Field must not be copied after first use; NewField hands it out as a
// pointer.
type Field[T FieldParams] struct {
// api is the native API
api frontend.API
// fParams carries the ring parameters. NewField never assigns it, so the
// parameters are read through the zero value of T.
fParams T
// maxOf is the maximum overflow before the element must be reduced. It is
// computed lazily by maxOverflow, guarded by maxOfOnce.
maxOf uint
maxOfOnce sync.Once
// constants for often used elements: the modulus n (nConst, see Modulus),
// n-1 (nprevConst, see modulusPrev), 0 (zeroConst) and 1 (oneConst). Each
// one is allocated only once, on first use, under its own sync.Once.
nConstOnce sync.Once
nConst *Element[T]
nprevConstOnce sync.Once
nprevConst *Element[T]
zeroConstOnce sync.Once
zeroConst *Element[T]
oneConstOnce sync.Once
oneConst *Element[T]
// log is the package logger (logger.Logger() at construction); used e.g. to
// report constant recomposition errors.
log zerolog.Logger
// constrainedLimbs records the HashCode of limb variables whose width has
// already been enforced by enforceWidthConditional, so repeated uses of the
// same limbs do not add duplicate width constraints.
constrainedLimbs map[uint64]struct{}
// checker is the range checker built by rangecheck.New for the native API.
checker frontend.Rangechecker
// mulChecks accumulates multiplication checks. NOTE(review): presumably
// consumed by performMulChecks, which NewField registers with
// Compiler().Defer — confirm.
mulChecks []mulCheck[T]
}
// ctxKey is the key under which NewField caches a *Field[T] in a native API
// implementing kvstore.Store. Each instantiation of T is a distinct key type,
// so fields over different parameters do not collide.
type ctxKey[T FieldParams] struct{}
// NewField returns an object to be used in-circuit to perform emulated
// arithmetic over the field defined by type parameter [FieldParams]. The
// operations on this type are defined on [Element]. There is also another type
// [FieldAPI] implementing [frontend.API] which can be used in place of native
// API for existing circuits.
//
// If native implements [kvstore.Store], the Field is cached there under a key
// specific to T, and later calls with the same T return the cached instance.
//
// NewField returns a nil Field and a non-nil error when:
//   - native is nil;
//   - the modulus is declared prime but is not;
//   - fewer than 3 bits are used per limb;
//   - the modulus is smaller than 2;
//   - the limb count does not match the modulus width;
//   - the native field has fewer than 2*BitsPerLimb()+1 bits.
//
// All validation happens before the range checker is created and before
// anything is registered on native.
//
// This is an experimental feature and performing emulated arithmetic in-circuit
// is extremely costly. See package doc for more info.
func NewField[T FieldParams](native frontend.API) (*Field[T], error) {
	if native == nil {
		return nil, fmt.Errorf("missing api")
	}
	if storer, ok := native.(kvstore.Store); ok {
		if ff, ok := storer.GetKeyValue(ctxKey[T]{}).(*Field[T]); ok {
			return ff, nil
		}
	}

	// The parameters are carried by the zero value of T, exactly like the
	// (never assigned) fParams member of the Field built below.
	var params T
	modulus := params.Modulus()
	bitsPerLimb := params.BitsPerLimb()
	// ensure prime is correctly set
	if params.IsPrime() && !modulus.ProbablyPrime(20) {
		return nil, fmt.Errorf("invalid parametrization: modulus is not prime")
	}
	if bitsPerLimb < 3 {
		// even three is way too small, but it should probably work.
		return nil, fmt.Errorf("nbBits must be at least 3")
	}
	// modulus <= 1 is rejected
	if modulus.Cmp(big.NewInt(1)) < 1 {
		return nil, fmt.Errorf("n must be at least 2")
	}
	nbLimbs := (uint(modulus.BitLen()) + bitsPerLimb - 1) / bitsPerLimb
	if nbLimbs != params.NbLimbs() {
		return nil, fmt.Errorf("nbLimbs mismatch got %d expected %d", params.NbLimbs(), nbLimbs)
	}
	// limbs may use at most (FieldBitLen()-1)/2 bits of the native field
	if uint(native.Compiler().FieldBitLen()) < 2*bitsPerLimb+1 {
		return nil, fmt.Errorf("elements with limb length %d does not fit into scalar field", bitsPerLimb)
	}

	// Only create the range checker and register callbacks once the
	// parameters are known to be valid: rangecheck.New may itself register
	// state on the native builder.
	f := &Field[T]{
		api:              native,
		log:              logger.Logger(),
		constrainedLimbs: make(map[uint64]struct{}),
		checker:          rangecheck.New(native),
	}
	native.Compiler().Defer(f.performMulChecks)
	if storer, ok := native.(kvstore.Store); ok {
		storer.SetKeyValue(ctxKey[T]{}, f)
	}
	return f, nil
}
// NewElement builds a new Element[T] from input v:
//   - an Element[T] or *Element[T] is cloned;
//   - a single canonical variable is used as the only limb, and a
//     []frontend.Variable is used as the full limb slice; in both cases a new
//     Element[T] is built from those limbs and the limbs are constrained to
//     the parameters of the Field[T];
//   - any other value is treated as a constant, which is equivalent to
//     calling emulated.ValueOf[T].
func (f *Field[T]) NewElement(v interface{}) *Element[T] {
	switch elem := v.(type) {
	case Element[T]:
		return elem.copy()
	case *Element[T]:
		return elem.copy()
	}
	if frontend.IsCanonical(v) {
		single := []frontend.Variable{v}
		return f.packLimbs(single, true)
	}
	if limbs, isSlice := v.([]frontend.Variable); isSlice {
		return f.packLimbs(limbs, true)
	}
	constant := ValueOf[T](v)
	return &constant
}
// Zero returns the constant element 0. It is built on first use; every call
// returns the same pointer, which should therefore be treated as read-only.
func (f *Field[T]) Zero() *Element[T] {
	build := func() {
		f.zeroConst = newConstElement[T](0)
	}
	f.zeroConstOnce.Do(build)
	return f.zeroConst
}
// One returns the constant element 1. It is built on first use; every call
// returns the same pointer, which should therefore be treated as read-only.
func (f *Field[T]) One() *Element[T] {
	build := func() {
		f.oneConst = newConstElement[T](1)
	}
	f.oneConstOnce.Do(build)
	return f.oneConst
}
// Modulus returns the modulus of the emulated ring as a constant element. It
// is built from the field parameters on first use; every call returns the
// same pointer, which should therefore be treated as read-only.
func (f *Field[T]) Modulus() *Element[T] {
	build := func() {
		f.nConst = newConstElement[T](f.fParams.Modulus())
	}
	f.nConstOnce.Do(build)
	return f.nConst
}
// modulusPrev returns the constant element modulus-1. It is computed once, on
// first use; later calls return the same pointer.
func (f *Field[T]) modulusPrev() *Element[T] {
	f.nprevConstOnce.Do(func() {
		one := big.NewInt(1)
		prev := new(big.Int).Sub(f.fParams.Modulus(), one)
		f.nprevConst = newConstElement[T](prev)
	})
	return f.nprevConst
}
// packLimbs builds an element from the given limbs (via
// newInternalElement(limbs, 0)) and constrains their width with enforceWidth.
//
// If strict is true, the most significant limb is constrained to the width of
// the most significant limb of the modulus. That limb may have fewer bits than
// the others, in which case fewer constraints are generated.
// If strict is false, every limb is constrained to the width defined by the
// field parameters.
func (f *Field[T]) packLimbs(limbs []frontend.Variable, strict bool) *Element[T] {
	packed := f.newInternalElement(limbs, 0)
	f.enforceWidth(packed, strict)
	return packed
}
// enforceWidthConditional constrains the limbs of a to the emulated limb width
// (by calling enforceWidth(a, true)) unless this is known to be unnecessary.
// It reports whether enforceWidth was called.
//
// No in-circuit constraint is added when:
//   - a is nil;
//   - a is marked internal (its producer already constrained it);
//   - all limbs are constants. In that case each limb is checked in Go instead,
//     and the function panics if one is wider than BitsPerLimb();
//   - every canonical limb exposes a HashCode that is already recorded in
//     f.constrainedLimbs.
//
// Non-canonical limbs are checked directly in Go and cause a panic if wider
// than BitsPerLimb(). Examples are constant limbs mixed with variables, or
// values in a test engine. Such limbs never trigger enforceWidth on their own.
//
// The loop intentionally does not stop at the first unconstrained limb: every
// new limb hash is recorded in a single pass. Hashes are recorded before
// enforceWidth runs. This is sound only because enforceWidth is then called
// unconditionally whenever a new hash was found.
func (f *Field[T]) enforceWidthConditional(a *Element[T]) (didConstrain bool) {
if a == nil {
// defensive: nothing to constrain for a nil element
return false
}
if a.internal {
// internal elements are already constrained in the method which returned it
return false
}
if _, isConst := f.constantValue(a); isConst {
// enforce constant element limbs not to be large.
for i := range a.Limbs {
val := utils.FromInterface(a.Limbs[i])
if val.BitLen() > int(f.fParams.BitsPerLimb()) {
panic("constant element limb wider than emulated parameter")
}
}
// constant limbs need no in-circuit range check
return false
}
for i := range a.Limbs {
if !frontend.IsCanonical(a.Limbs[i]) {
// this is not a canonical variable, nor a constant. This may happen
// when some limbs are constant and some variables. Or if we are
// running in a test engine. In either case, we must check that if
// this limb is a [*big.Int] its bit length does not exceed
// BitsPerLimb().
val := utils.FromInterface(a.Limbs[i])
if val.BitLen() > int(f.fParams.BitsPerLimb()) {
panic("non-canonical integer limb wider than emulated parameter")
}
continue
}
if vv, ok := a.Limbs[i].(interface{ HashCode() uint64 }); ok {
// okay, this is a canonical variable and it has a hashcode. We use
// it to see if the limb is already constrained.
h := vv.HashCode()
if _, ok := f.constrainedLimbs[h]; !ok {
// we found a limb which hasn't yet been constrained. This means
// that we should enforce width for the whole element. But we
// still iterate over all limbs just to mark them in the table.
didConstrain = true
f.constrainedLimbs[h] = struct{}{}
}
} else {
// we have no way of knowing if the limb has been constrained. To be
// on the safe side constrain the whole element again.
didConstrain = true
}
}
if didConstrain {
f.enforceWidth(a, true)
}
// naked return: reports didConstrain
return
}
// constantValue returns the integer represented by v when every limb of v is
// a constant according to the native API. The limbs are recomposed as
// BitsPerLimb()-wide digits. It returns (nil, false) as soon as a non-constant
// limb is found, or when recomposition fails; the latter error is logged.
func (f *Field[T]) constantValue(v *Element[T]) (*big.Int, bool) {
	limbValues := make([]*big.Int, 0, len(v.Limbs))
	for _, limb := range v.Limbs {
		// for each limb we get its constant value if we can, or fail.
		c, isConst := f.api.ConstantValue(limb)
		if !isConst {
			return nil, false
		}
		limbValues = append(limbValues, c)
	}
	value := new(big.Int)
	err := recompose(limbValues, f.fParams.BitsPerLimb(), value)
	if err != nil {
		f.log.Error().Err(err).Msg("recomposing constant")
		return nil, false
	}
	return value, true
}
// compact regroups the limbs of a and b so that several consecutive limbs are
// encoded as a linear combination in a single, wider limb. It returns the
// minimal (in number of limbs) representations of a and b that still fit into
// the native field, together with the bit width of the new limbs.
//
// Width reduction is not performed here; callers are expected to have done it
// already. Two bits of the native field are reserved: one because the full
// native field cannot be used, and one because the grouping may overflow.
//
// When not even a single limb fits into the remaining budget, a.Limbs and
// b.Limbs are returned unchanged together with BitsPerLimb(). This includes
// the case where the overflow alone exceeds the budget.
func (f *Field[T]) compact(a, b *Element[T]) (ac, bc []frontend.Variable, bitsPerLimb uint) {
	limbBits := f.fParams.BitsPerLimb()
	overflow := max(a.overflow, b.overflow)
	budget := uint(f.api.Compiler().FieldBitLen()) - 2
	// Guard the unsigned subtraction: budget-overflow would wrap around to a
	// huge value when overflow > budget, producing an absurd group size (and
	// a huge allocation in compactLimbs).
	var groupSize uint
	if overflow < budget {
		groupSize = (budget - overflow) / limbBits
	}
	if groupSize == 0 {
		// no space for compact
		return a.Limbs, b.Limbs, limbBits
	}
	bitsPerLimb = limbBits * groupSize
	ac = f.compactLimbs(a, groupSize, bitsPerLimb)
	bc = f.compactLimbs(b, groupSize, bitsPerLimb)
	return ac, bc, bitsPerLimb
}
// compactLimbs regroups the limbs of e. Every groupSize consecutive limbs
// l_0..l_{k-1} become the single limb sum_j l_j * 2^(j*BitsPerLimb()), and the
// last group may be shorter. When bitsPerLimb already equals the field's
// BitsPerLimb(), e.Limbs is returned as is.
func (f *Field[T]) compactLimbs(e *Element[T], groupSize, bitsPerLimb uint) []frontend.Variable {
	limbBits := f.fParams.BitsPerLimb()
	if bitsPerLimb == limbBits {
		return e.Limbs
	}
	total := uint(len(e.Limbs))
	nbLimbs := (total + groupSize - 1) / groupSize
	// weights[j] = 2^(j*limbBits) places the j-th limb of a group.
	weights := make([]*big.Int, groupSize)
	for j := range weights {
		weights[j] = new(big.Int).Lsh(big.NewInt(1), limbBits*uint(j))
	}
	packed := make([]frontend.Variable, nbLimbs)
	for i := uint(0); i < nbLimbs; i++ {
		start := i * groupSize
		var acc frontend.Variable = uint(0)
		for j := uint(0); j < groupSize && start+j < total; j++ {
			acc = f.api.Add(acc, f.api.Mul(weights[j], e.Limbs[start+j]))
		}
		packed[i] = acc
	}
	return packed
}
// maxOverflow returns the maximal overflow an element may have. If the
// overflow of the next operation exceeds this value, the limbs may overflow
// the native field. The value is FieldBitLen()-2-BitsPerLimb(), computed once
// on first use.
func (f *Field[T]) maxOverflow() uint {
	f.maxOfOnce.Do(func() {
		usableBits := uint(f.api.Compiler().FieldBitLen() - 2)
		f.maxOf = usableBits - f.fParams.BitsPerLimb()
	})
	return f.maxOf
}
// max returns the largest of its arguments according to the > operator, or
// the zero value of T when called without arguments. Among equal maxima the
// first one is returned. Within this package it shadows the built-in max.
func max[T constraints.Ordered](a ...T) T {
	var m T
	for i, v := range a {
		if i == 0 || v > m {
			m = v
		}
	}
	return m
}