forked from Consensys/gnark
-
Notifications
You must be signed in to change notification settings - Fork 0
/
rangecheck_commit.go
197 lines (181 loc) · 6.03 KB
/
rangecheck_commit.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
package rangecheck
import (
"fmt"
"math"
"math/big"
"github.com/aakash4dev/gnark2/constraint/solver"
"github.com/aakash4dev/gnark2/frontend"
"github.com/aakash4dev/gnark2/internal/frontendtype"
"github.com/aakash4dev/gnark2/internal/kvstore"
"github.com/aakash4dev/gnark2/std/internal/logderivarg"
)
type (
	// ctxCheckerKey is the key under which the shared *commitChecker is kept
	// in the builder's key-value store.
	ctxCheckerKey struct{}

	// checkedVariable is one queued range check: v is to be checked to fit
	// into bits bits.
	checkedVariable struct {
		v    frontend.Variable
		bits int
	}

	// commitChecker accumulates range checks and emits the constraints for
	// all of them at once when committed.
	commitChecker struct {
		// collected holds the checks queued by Check, in call order.
		collected []checkedVariable
		// closed is set once commit has run; Check panics afterwards.
		closed bool
	}
)

func init() {
	// Register the decomposition hint with the solver (presumably so that it
	// can be resolved when solving circuits using this checker).
	solver.RegisterHint(DecomposeHint)
}
// newCommitRangechecker returns the commitment-based range checker attached
// to the builder behind api, creating it on first use.
//
// The checker is stored in the builder's key-value store under ctxCheckerKey,
// so repeated calls for the same builder return the same instance and all
// checks end up in a single lookup argument. When a checker is created, its
// commit method is registered with api.Compiler().Defer (presumably run once
// the circuit definition has been processed).
//
// It panics if the builder does not implement kvstore.Store, or if the value
// stored under ctxCheckerKey is not a *commitChecker.
func newCommitRangechecker(api frontend.API) *commitChecker {
	kv, ok := api.Compiler().(kvstore.Store)
	if !ok {
		panic("builder should implement key-value store")
	}
	if stored := kv.GetKeyValue(ctxCheckerKey{}); stored != nil {
		cht, isChecker := stored.(*commitChecker)
		if !isChecker {
			panic("stored rangechecker is not valid")
		}
		return cht
	}
	cht := &commitChecker{}
	kv.SetKeyValue(ctxCheckerKey{}, cht)
	api.Compiler().Defer(cht.commit)
	return cht
}
// Check queues in to be range-checked to bits bits. No constraints are added
// here; queued values are processed in bulk by commit. Check panics if the
// checker has already been committed.
func (c *commitChecker) Check(in frontend.Variable, bits int) {
	if c.closed {
		panic("checker already closed")
	}
	entry := checkedVariable{bits: bits, v: in}
	c.collected = append(c.collected, entry)
}
// buildTable returns the lookup table 0, 1, ..., nbTable-1 as constant
// circuit values.
func (c *commitChecker) buildTable(nbTable int) []frontend.Variable {
	table := make([]frontend.Variable, nbTable)
	for entry := range table {
		table[entry] = entry
	}
	return table
}
// commit emits the constraints for every value queued by Check. It is the
// callback newCommitRangechecker registers with Compiler().Defer.
//
//   - Each value is split by DecomposeHint into limbs of baseLength bits
//     (baseLength chosen by getOptimalBasewidth), and the weighted sum of the
//     limbs is asserted equal to the value.
//   - All limbs are then checked, in a single log-derivative argument, to be
//     entries of the table {0, 1, ..., 2^baseLength - 1}.
//
// The checker is marked closed when commit returns. Further commit calls are
// no-ops and further Check calls panic. commit returns an error if a hint
// cannot be created or if building the lookup argument fails.
func (c *commitChecker) commit(api frontend.API) error {
	if c.closed {
		return nil
	}
	defer func() { c.closed = true }()
	if len(c.collected) == 0 {
		return nil
	}
	baseLength := c.getOptimalBasewidth(api)

	// pow2 returns a freshly allocated 2^e. Using a new big.Int per constant
	// avoids handing the builder several references to one value that is
	// mutated between calls.
	pow2 := func(e int) *big.Int {
		return new(big.Int).Lsh(big.NewInt(1), uint(e))
	}

	// decomposed collects every value that must be found in the table.
	decomposed := make([]frontend.Variable, 0, len(c.collected))
	for i := range c.collected {
		v, bits := c.collected[i].v, c.collected[i].bits

		// Decompose the value into limbs, least significant first.
		nbLimbs := decompSize(bits, baseLength)
		limbs, err := api.Compiler().NewHint(DecomposeHint, nbLimbs, bits, baseLength, v)
		if err != nil {
			return fmt.Errorf("decompose: %w", err)
		}
		decomposed = append(decomposed, limbs...)

		// Check that the limbs recompose to v. Limb sizes are checked later
		// by the table lookup.
		var composed frontend.Variable = 0
		for j := range limbs {
			composed = api.Add(composed, api.Mul(limbs[j], pow2(baseLength*j)))
		}
		api.AssertIsEqual(composed, v)

		// The lookups alone bound v to nbLimbs*baseLength bits, which exceeds
		// the requested width by shift bits when bits is not a multiple of
		// baseLength. In that case also look up MS*2^shift, where MS is the
		// most significant limb: MS*2^shift < 2^baseLength only if
		// MS < 2^(baseLength-shift).
		//
		// MS itself is also checked to be below 2^baseLength, so the product
		// stays below 2^(2*baseLength). It therefore cannot wrap around the
		// field modulus, assuming the modulus is far larger than that.
		if shift := nbLimbs*baseLength - bits; shift > 0 {
			decomposed = append(decomposed, api.Mul(limbs[nbLimbs-1], pow2(shift)))
		}
	}
	nbTable := 1 << baseLength
	return logderivarg.Build(api, logderivarg.AsTable(c.buildTable(nbTable)), logderivarg.AsTable(decomposed))
}
// decompSize returns how many limbs of limbSize bits are needed to cover a
// value of varSize bits. For positive sizes this is ceil(varSize/limbSize).
func decompSize(varSize int, limbSize int) int {
	roundedUp := varSize + limbSize - 1
	return roundedUp / limbSize
}
// DecomposeHint is a hint used for range checking with commitment. It
// decomposes large variables into chunks which can be individually range-check
// in the native range.
func DecomposeHint(m *big.Int, inputs []*big.Int, outputs []*big.Int) error {
if len(inputs) != 3 {
return fmt.Errorf("input must be 3 elements")
}
if !inputs[0].IsUint64() || !inputs[1].IsUint64() {
return fmt.Errorf("first two inputs have to be uint64")
}
varSize := int(inputs[0].Int64())
limbSize := int(inputs[1].Int64())
val := inputs[2]
nbLimbs := decompSize(varSize, limbSize)
if len(outputs) != nbLimbs {
return fmt.Errorf("need %d outputs instead to decompose", nbLimbs)
}
base := new(big.Int).Lsh(big.NewInt(1), uint(limbSize))
tmp := new(big.Int).Set(val)
for i := 0; i < len(outputs); i++ {
outputs[i].Mod(tmp, base)
tmp.Rsh(tmp, uint(limbSize))
}
return nil
}
// getOptimalBasewidth returns the limb width that minimizes the estimated
// constraint count for the collected checks. The PLONK cost model is used
// when api reports an SCS frontend. Otherwise, including when the frontend
// type cannot be determined, the R1CS cost model is used.
func (c *commitChecker) getOptimalBasewidth(api frontend.API) int {
	costModel := nbR1CSConstraints
	if typer, isTyper := api.(frontendtype.FrontendTyper); isTyper && typer.FrontendType() == frontendtype.SCS {
		costModel = nbPLONKConstraints
	}
	return optimalWidth(costModel, c.collected)
}
func optimalWidth(countFn func(baseLength int, collected []checkedVariable) int, collected []checkedVariable) int {
min := math.MaxInt64
minVal := 0
for j := 2; j < 18; j++ {
current := countFn(j, collected)
if current < min {
min = current
minVal = j
}
}
return minVal
}
// countLimbLookups returns the number of table lookups commit performs for
// collected with limbs of baseLength bits. That is ceil(bits/baseLength)
// limbs per value, plus one lookup of the shifted most significant limb when
// bits is not a multiple of baseLength.
func countLimbLookups(baseLength int, collected []checkedVariable) int {
	total := 0
	for i := range collected {
		nbVarLimbs := decompSize(collected[i].bits, baseLength)
		if nbVarLimbs*baseLength > collected[i].bits {
			nbVarLimbs++
		}
		total += nbVarLimbs
	}
	return total
}

// nbR1CSConstraints estimates the number of R1CS constraints commit produces
// for collected with limbs of baseLength bits.
func nbR1CSConstraints(baseLength int, collected []checkedVariable) int {
	nbDecomposed := countLimbLookups(baseLength, collected)
	eqs := len(collected)     // correctness of decomposition
	nbRight := nbDecomposed   // inverse per decomposed
	nbLeft := 1 << baseLength // div per table entry
	return nbLeft + nbRight + eqs + 1
}

// nbPLONKConstraints estimates the number of PLONK (SCS) constraints commit
// produces for collected with limbs of baseLength bits.
func nbPLONKConstraints(baseLength int, collected []checkedVariable) int {
	nbDecomposed := countLimbLookups(baseLength, collected)
	eqs := nbDecomposed            // recomposition additions, one per decomposed limb
	nbRight := 3 * nbDecomposed    // denominator sub, inv and large sum per queried limb
	nbLeft := 3 * (1 << baseLength) // denominator sub, div and large sum per table entry
	return nbLeft + nbRight + eqs + 1 // and the final assert
}