# fibo_stark.py -- STARK proof of a Fibonacci computation trace
# (adapted from Vitalik Buterin's MIMC STARK tutorial code)
from permuted_tree import merkelize, mk_branch, verify_branch, blake, mk_multi_branch, verify_multi_branch
from poly_utils import PrimeField
import time
from fft import fft
from fri import prove_low_degree, verify_low_degree_proof
from utils import get_power_cycle, get_pseudorandom_indices, is_a_power_of_2
"""
All calculations are done modulo 2**256 - 2**32 * 351 + 1; Vitalik used this
prime field modulus because it is the largest prime below 2^256 whose
multiplicative group contains an order 2^32 subgroup (that is, there's a
number g such that successive powers of g modulo this prime loop around back
to 1 after exactly 2^32 cycles), and which is of the form 6k+5. The first
property is necessary to make sure that the efficient versions of the FFT and
FRI algorithms can work, and the second ensures that the MIMC actually can be
computed "backwards". We don't need the backwards computation but keep that
modulus anyway.
"""
# Largest prime below 2^256 with an order-2^32 multiplicative subgroup (see
# module docstring); all arithmetic below is done in this prime field.
modulus = 2**256 - 2**32 * 351 + 1
f = PrimeField(modulus)
# Vitalik's original code also defined nonresidue = 7; it is unused here.
# Number of pseudo-random spot checks performed on the trace commitments;
# used as the `samples` count for get_pseudorandom_indices in both the
# prover and the verifier.
spot_check_security_factor = 80
# The extension factor is the extent to which we will be "stretching" the
# computational trace (evaluation domain size = trace length * this factor).
# We do this to easier calculate the polynomial Z later.
extension_factor = 8
def fib_to(n):
    """Return the Fibonacci sequence [F(0), F(1), ..., F(n)] as a list.

    Always contains at least the two seed values 0 and 1, so the result
    has max(n + 1, 2) elements.
    """
    a, b = 0, 1
    sequence = [a, b]
    # Extend the seed pair until F(n) is appended (no-op for n <= 1).
    for _ in range(n - 1):
        a, b = b, a + b
        sequence.append(b)
    return sequence
def generate_proof(fibo_to_n):
    """Generate a STARK proving correct computation of F(0)..F(fibo_to_n).

    Parameters:
        fibo_to_n: index of the last Fibonacci number in the trace; the
            trace length fibo_to_n + 1 must be a power of 2.

    Returns:
        A 5-element list:
        [0] Merkle root of the packed P/D/B evaluations,
        [1] Merkle root of their random linear combination L,
        [2] Merkle multi-branch for the P/D/B tree at the augmented positions,
        [3] Merkle multi-branch for the L tree at the sampled positions,
        [4] FRI low-degree proof for L.
    """
    start_time = time.time()
    # precision = (fibo_to_n + 1) * extension_factor must fit inside the
    # field's order-2^32 multiplicative subgroup, so bound the trace length
    # by 2**32 // extension_factor.  This matches the check in verify_proof;
    # the previous `fibo_to_n + 1 <= 2**32` bound was too loose by a factor
    # of extension_factor.
    assert fibo_to_n + 1 <= 2**32 // extension_factor
    assert is_a_power_of_2(fibo_to_n + 1)
    print("\nStarting proof generation:")
    # This is the higher power root of unity to calculate the intermediate
    # evaluations of the polynomials i.e. not exactly on the computation trace
    # where evaluations are zero for some polys and lead to problems with
    # evaluation.
    # Root of unity such that x^precision == 1
    precision = (fibo_to_n + 1) * extension_factor
    G2 = f.exp(7, (modulus-1)//precision)
    # This is the lower root of unity where we lay the computational trace on
    # Root of unity such that x^(fibo_to_n + 1) == 1
    skips = precision // (fibo_to_n + 1)
    G1 = f.exp(G2, skips)
    # Powers of the higher order root of unity
    xs = get_power_cycle(G2, modulus)
    print("Number of elements in the extended value set xs: %d" % len(xs))
    last_step_position = xs[fibo_to_n * extension_factor]
    print("Generate the computational trace...")
    computational_trace = fib_to(fibo_to_n)
    output = computational_trace[-1]
    print("Computational_trace lenght: ", len(computational_trace))
    print("Start: %d, %d , %d ..." % (computational_trace[0],
                                      computational_trace[1],
                                      computational_trace[2]))
    print("End: ...%d, %d , %d" % (computational_trace[-3],
                                   computational_trace[-2],
                                   computational_trace[-1]))
    # Interpolate, i.e. inverse FFT the computational trace points into a
    # polynomial P, with each step along a successive power of G1 ...
    print("Interpolating the trace in lower order inverse FFT to P")
    computational_trace_polynomial = fft(computational_trace, modulus, G1,
                                         inv=True)
    print("Number of coefficients of P: ", len(computational_trace_polynomial))
    # Evaluate the new polynomial.  Not on the original trace which lies on
    # the lower order roots G1, but on the denser domain generated by G2.
    print("Evaluating P on higher order...")
    p_evaluations = fft(computational_trace_polynomial, modulus, G2)
    # Create the composed (constraint) polynomial.
    print("Compose the evaluations of P such that")
    print("C(P(x), P(g1*x), P((g1^2)*x)) = P(x) - P(g1*x) - P(g1^2*x)")
    # NOTE: the look-back indices must be parenthesised.  Without parentheses,
    # `i - 2 * extension_factor % precision` parses as
    # `i - (2 * extension_factor % precision)` because `%` binds tighter than
    # `-`; that only worked by accident via Python's negative-index
    # wrap-around and breaks when precision == extension_factor.
    c_of_p_evaluations = [(p_evaluations[i] -
                           p_evaluations[(i - 2 * extension_factor) % precision] -
                           p_evaluations[(i - extension_factor) % precision])
                          % modulus for i in range(precision)]
    print("\nCompute D(x) = C(P(x), P(g1*x), P(g1^2*x)) / Z(x)")
    print("First compute Z(x) as Numerator / Denominator")
    print("Numerator: (x^fibo_to_n + 1) -1")
    print("Denominator: (x-x[0])*(x-x[extension_factor])*(x-x_atlast_step)")
    # Z's numerator x^(fibo_to_n+1) - 1 vanishes on the whole trace domain.
    z_num_evaluations = [xs[(i * (fibo_to_n + 1)) % precision] - 1
                         for i in range(precision)]
    print("Efficiently computing the inverse of Numerator it in batch...")
    z_num_inv = f.multi_inv(z_num_evaluations)
    # The denominator removes the three boundary points (first two trace
    # steps and the last step) from Z's zero set; they are handled by B.
    polymult = f.mul_polys([-xs[0], 1], [-xs[extension_factor], 1])
    polymult = f.mul_polys(polymult, [-last_step_position, 1])
    z_den_evaluations = [f.eval_poly_at(polymult, x) for x in xs]
    print("Computing evaluation of D as C * Denominator * 1/Numerator")
    d_evaluations = [cp * zd * zni % modulus for cp, zd, zni in
                     zip(c_of_p_evaluations, z_den_evaluations, z_num_inv)]
    print("\nCompute Polynomial B to represent input and output")
    print("Compute the interpolant that passes through (0,1,%d)" % output)
    # I passes through the boundary values (F(0)=0, F(1)=1, F(n)=output).
    interpolant = f.lagrange_interp([xs[0], xs[extension_factor], last_step_position], [0, 1, output])
    i_evaluations = [f.eval_poly_at(interpolant, x) for x in xs]
    print("Compute the quotient")
    # Q vanishes exactly at the three boundary points.
    zeropoly2_1 = f.mul_polys([-xs[0], 1], [-last_step_position, 1])
    zeropoly2 = f.mul_polys([-xs[extension_factor], 1], zeropoly2_1)
    print("Efficiently compute the inverse of the quotient")
    z_2 = [f.eval_poly_at(zeropoly2, x) for x in xs]
    inv_z2_evaluations = f.multi_inv(z_2)
    print("Calculate B = (P - Interpolant) * quotient in evaluation form")
    b_evaluations = [((p - i) * invq) % modulus for p, i, invq in
                     zip(p_evaluations, i_evaluations, inv_z2_evaluations)]
    print('DONE Computed B polynomial\n')
    print("\n Compute the Merkle root of p_- d_- and b_evaluations")
    # Each Merkle leaf packs P | D | B as three 32-byte big-endian words;
    # verify_proof unpacks them with the same [:32] / [32:64] / [64:] slicing.
    mtree = merkelize([pval.to_bytes(32, 'big') +
                       dval.to_bytes(32, 'big') +
                       bval.to_bytes(32, 'big') for
                       pval, dval, bval in zip(p_evaluations, d_evaluations,
                                               b_evaluations)])
    print('DONE Computed hash root len', len(mtree), "\n")
    print("\nCalculate a random linear combination of P * x^fibo_to_n, P, B * x^fibo_to_n, B and D, to prove the low-degreeness of that, instead of proving the low-degreeness of P B and D separately")
    # Fiat-Shamir scalars derived from the trace-tree root.  The byte tags
    # must stay paired with the terms exactly as verify_proof expects:
    # '\x01' multiplies P, '\x03' and '\x04' both multiply B.
    k1 = int.from_bytes(blake(mtree[1] + b'\x01'), 'big')
    k2 = int.from_bytes(blake(mtree[1] + b'\x04'), 'big')
    k3 = int.from_bytes(blake(mtree[1] + b'\x03'), 'big')
    l_evaluations = [(d_evaluations[i] +
                      p_evaluations[i] * k1 +
                      b_evaluations[i] * k3 +
                      b_evaluations[i] * k2) % modulus
                     for i in range(precision)]
    l_mtree = merkelize(l_evaluations)
    print("Put evaluation of the polynomial in merkletree")
    print('DONE Computed random linear combination')
    print("\nPrepare spot checks of the random linear combination Merkle tree at pseudo-random coordinates, excluding multiples of `extension_factor because we have divided by zero there")
    samples = spot_check_security_factor
    positions = get_pseudorandom_indices(l_mtree[1], precision, samples,
                                         exclude_multiples_of=extension_factor)
    print("For each random position x we also get x - skips and x - 2 * skips positions for the Fibonacci checks")
    augmented_positions = sum([[x, (x - skips) % precision, (x - 2 * skips) % precision]
                               for x in positions], [])
    print('DONE Computed %d spot checks' % samples)
    o = [mtree[1],
         l_mtree[1],
         mk_multi_branch(mtree, augmented_positions),
         mk_multi_branch(l_mtree, positions),
         prove_low_degree(l_evaluations, G2, fibo_to_n,
                          modulus, exclude_multiples_of=extension_factor)]
    print("\nProof DONE!!! It consists of 5 parts:")
    print("1. Merkle root of p_- d_- and b_evaluations at x, x-1 and x-2")
    print("2. Merkle root of linear combinations of p d and b")
    print("3. Merke proof for 1.")
    print("4. Merkle proof for 2.")
    print("5. proving low degreeness of 2.")
    print("STARK computed in %.4f sec" % (time.time() - start_time))
    return o
def verify_proof(fibo_to_n, output, proof):
    """Verify a STARK produced by generate_proof.

    Parameters:
        fibo_to_n: index of the last Fibonacci number (trace length - 1);
            fibo_to_n + 1 must be a power of 2.
        output: claimed value of the fibo_to_n-th Fibonacci number.
        proof: the 5-element list returned by generate_proof.

    Returns:
        True if all checks pass; any failing check raises AssertionError.
    """
    print("\n\n\n\nStarting Verification...")
    # Unpack the 5 proof parts in the order generate_proof emits them.
    m_root, l_root, main_branches, linear_comb_branches, fri_proof = proof
    start_time = time.time()
    # The evaluation domain (precision) must fit inside the field's
    # order-2^32 multiplicative subgroup.
    assert fibo_to_n + 1 <= 2**32 // extension_factor
    assert is_a_power_of_2(fibo_to_n + 1)
    precision = (fibo_to_n + 1) * extension_factor
    # skips == extension_factor: step between trace points in the domain.
    skips = precision // (fibo_to_n + 1)
    # Generator of the full evaluation domain (x^precision == 1).
    G2 = f.exp(7, (modulus-1)//precision)
    # Domain point of the final trace step.
    last_step_position = f.exp(G2, (fibo_to_n) * skips)
    # First two trace points (boundary constraints F(0)=0 and F(1)=1).
    x0 = f.exp(G2, 0)
    x1 = f.exp(G2, extension_factor)
    print("Use FRI to verify low degreeness of linear combination...")
    # NOTE(review): the prover ran prove_low_degree with maxdeg fibo_to_n
    # while this passes fibo_to_n + 1 -- presumably both land in the same
    # FRI rounding, but confirm against fri.py.
    assert verify_low_degree_proof(l_root, G2, fri_proof, fibo_to_n + 1,
                                   modulus,
                                   exclude_multiples_of=extension_factor)
    print("Re-create the random spot checks of the prover...")
    print("Retrieve random scalars from merkle root")
    # Fiat-Shamir scalars; byte tags pair with the prover's terms:
    # '\x01' -> P coefficient, '\x03' and '\x04' -> the two B coefficients
    # (the prover names the '\x04' one k2, here it is k4).
    k1 = int.from_bytes(blake(m_root + b'\x01'), 'big')
    k3 = int.from_bytes(blake(m_root + b'\x03'), 'big')
    k4 = int.from_bytes(blake(m_root + b'\x04'), 'big')
    samples = spot_check_security_factor
    print("Retrieve pseudo random spots to check from merkle root")
    positions = get_pseudorandom_indices(l_root, precision, samples,
                                         exclude_multiples_of=extension_factor)
    # For each sampled x = G2^pos we also need P at the two preceding trace
    # steps (pos - skips, pos - 2*skips) for the Fibonacci constraint.
    augmented_positions = sum([[x, (x - skips) % precision, (x - 2 * skips) % precision] for x in positions], [])
    print("Retrieve the values at these positions and verify the m-paths")
    main_branch_leaves = verify_multi_branch(m_root, augmented_positions, main_branches)
    linear_comb_branch_leaves = verify_multi_branch(l_root, positions, linear_comb_branches)
    print("Loop over all the random positions to check consistency")
    print(" Check that D = C / Z i.e. C - Z * D = 0")
    print(" Check that P = I + B*Q")
    print(" Check that the linear combibation adds up to zero")
    for i, pos in enumerate(positions):
        x = f.exp(G2, pos)
        # Leaves for positions pos, pos - skips, pos - 2*skips respectively.
        mbranch1 = main_branch_leaves[i*3]
        mbranch2 = main_branch_leaves[i*3+1]
        mbranch3 = main_branch_leaves[i*3+2]
        l_of_x = int.from_bytes(linear_comb_branch_leaves[i], 'big')
        # Each trace leaf packs P | D | B as three 32-byte big-endian words,
        # matching the prover's to_bytes(32, 'big') concatenation.
        p_of_x = int.from_bytes(mbranch1[:32], 'big')
        p_of_x1 = int.from_bytes(mbranch2[:32], 'big')
        p_of_x2 = int.from_bytes(mbranch3[:32], 'big')
        d_of_x = int.from_bytes(mbranch1[32:64], 'big')
        b_of_x = int.from_bytes(mbranch1[64:], 'big')
        # Z(x) = (x^(fibo_to_n+1) - 1) / ((x - last)(x - x0)(x - x1)):
        # vanishes on the trace domain except the three boundary points.
        zvalue = f.div(f.exp(x, fibo_to_n + 1) - 1,
                       ((x - last_step_position) * (x - x0) * (x - x1)) % modulus)
        # Check that D = C / Z i.e. C - Z * D = 0
        assert (p_of_x - p_of_x1 - p_of_x2 - zvalue * d_of_x) % modulus == 0
        # Check boundary constraints B(x) * Q(x) + I(x) = P(x)
        interpolant = f.lagrange_interp([x0, x1, last_step_position], [0, 1, output])
        zeropoly2_1 = f.mul_polys([-x0, 1], [-last_step_position, 1])
        zeropoly2 = f.mul_polys([-x1, 1], zeropoly2_1)
        # Check that P = I + B*Q
        assert (p_of_x - b_of_x * f.eval_poly_at(zeropoly2, x) -
                f.eval_poly_at(interpolant, x)) % modulus == 0
        # Check that the linear combination matches the committed L leaf.
        assert (l_of_x - d_of_x -
                k1 * p_of_x -
                k3 * b_of_x -
                k4 * b_of_x) % modulus == 0
        print(" Checks passed for %d-th position x = G2^%d" % (i, pos))
    print('Verified %d consistency checks' % spot_check_security_factor)
    print('Verified STARK in %.4f sec' % (time.time() - start_time))
    return True