In [54]:
import hashlib
from sage.rings.finite_rings.integer_mod import IntegerMod_gmp
from sage.rings.polynomial.polynomial_zmod_flint import Polynomial_zmod_flint

# DIFINES
PRIME = 3*2**30+1
FIELD = GF(PRIME)

In [55]:
F.<x> = PolynomialRing(GF(PRIME),'x')
field_gen = FIELD.multiplicative_generator()
gen8192 = field_gen ** ((PRIME-1)/8192)
gen1024 = field_gen ** ((PRIME-1)/1024)

In [56]:
Y_INDEX = 999
a = 2 

In [57]:
def trace_calculator(a: int, trance_len: int) -> list:
    trace = [1, a]
    for i in range(2,trance_len):
        trace.append((trace[i-1] + trace[i-2])%PRIME)
    return trace

In [58]:
def polynomial_evaluation (trace: list, generator: IntegerMod_gmp) -> Polynomial_zmod_flint:
    points =[]
    for i, y in enumerate (trace): 
        points.append((generator**i, y))
    
    R = FIELD['x']
    polynomial = R.lagrange_polynomial(points)
    print(type(polynomial))
    return polynomial

In [59]:
trance_len = 1024
trace = trace_calculator(a=a,trance_len=trance_len)
#print(trace)
poly = polynomial_evaluation(trace=trace, generator=gen1024)
# Sainaty check:
assert poly(gen1024**Y_INDEX) == trace[Y_INDEX]
Y=trace[Y_INDEX]

<class 'sage.rings.polynomial.polynomial_zmod_flint.Polynomial_zmod_flint'>


In [60]:
def compositon_polynomial(poly: Polynomial_zmod_flint, trance_len: int, Y: IntegerMod_gmp, index_y: int , gen: IntegerMod_gmp):
    n=trance_len
    p1 = (poly-1)/(x-gen**0)
    co1 = fiat_shamir_random(data=str(p1))
    p2 = (poly - Y)/(x-gen**index_y)
    co2 = fiat_shamir_random(data=str(p2))
    
    # (x-g**0)(x-g**1)...(x-g**(n-1)) = x**n-1
    # 
    constrain_3_numer = poly(gen ** 2 * x) - poly(gen * x) - poly(x)
    # constrain_3_numer should divide by all powers of gen: x=g**0, .... x= g**(n-3)
    constrain_3_denom = (x**n-1) / ( (x-gen**(n-1)) * (x-gen**(n-2)) )
    # p3 = (poly(gen ** 2 * x) - poly(gen * x) - poly(x))*(x-gen**(n-1)) \
    #     * (x-gen**(n-2))*(x-gen**(n-3))/(x**n-1)
    p3 = constrain_3_numer / constrain_3_denom
    co3 = fiat_shamir_random(data=str(p3))
    return p1*co1 + p2+co2 + p3*co3

In [61]:
def merkle(points: list):
    if len(points) <= 1:
        return points[0]
    squeezed_points = []
    for i in range(len(points)//2):
        temp_string = ''.join(str(points[i*2]))
        temp_string = temp_string.join(str(points[i*2+1]))
        squeezed_points.append(sha3(temp_string))
    if len(points)%2 == 1:
        squeezed_points.append(sha3(''.join(str(points[-1]))))
    return merkle(squeezed_points)

In [62]:
def sha3(string: str):
    byte=string.encode('ascii')
    m = hashlib.sha3_256()
    m.update(byte)
    
    return m.hexdigest()

In [63]:
def fiat_shamir_random(data: str)-> int:
    rand = int.from_bytes(data.encode('ascii'), "big") 
    return rand%PRIME

In [64]:
CP = compositon_polynomial(poly=poly, trance_len=trance_len, Y=Y, index_y=Y_INDEX , gen=gen1024)

In [65]:
points = trace_calculator(2, 15)
merkle(points)

'112bda191ff7087fc2d7fda84d713e02eb8cae32817df6a2e2a41224bf81616c'

In [66]:
#CP

In [67]:
def reverse_bit(n, width = 10):
    n_binary = '{:0{width}b}'.format(n, width=width)
    return int(n_binary[::-1], 2)

In [68]:
def low_degree_extension(poly: Polynomial_zmod_flint, trance_len: int, group_gen: IntegerMod_gmp, field_gen: IntegerMod_gmp):
    coset_set = [field_gen*(group_gen**i) for i in range (trance_len)]
    new_coset_set = [coset_set[reverse_bit(i, log(trance_len,2))] for i in range (trance_len)]
    return [(c, poly(c)) for c in new_coset_set]

In [69]:
low_degree_extension(poly=CP, trance_len=16, group_gen=gen8192, field_gen=field_gen)

[(5, 240136124),
 (2833855974, 3137794864),
 (2623023147, 832770097),
 (2540449978, 178471777),
 (1056415280, 2443620285),
 (1800233432, 3212932430),
 (380960369, 1435039985),
 (671626609, 1479639957),
 (2229935889, 2287296401),
 (2457247830, 2896588392),
 (19433037, 3031081945),
 (3171820980, 546310648),
 (2439017906, 1350764941),
 (13901640, 2967064874),
 (820071082, 450352649),
 (565703797, 1281215413)]

In [70]:
def fri(poly: Polynomial_zmod_flint, domain: list , degree = 1024) -> dict:
    
    proof = {} #{stage: [merkle root, [pathes], random number for naxt-stage]}
    for stage in range (log(degree, 2)):
        
        pathes = []
        points, merkel_root = commit(poly=poly, domain=domain)
        #first stage: evaluation above the whole domain and creation of merkle tree

        evaluate_points_and_path()
        #second stage: takes n number randoms, and claculate n/2 time P(x_i) 1<i<n/2 + merkle path for them

        rand = fiat_shamir_random(merkel_root)
        proof[stage]= [merkel_root ,pathes, rand]
        poly, domain = fri_next_layer(poly=poly, domain=domain, rand = rand)
        #third stage: calculates FRI next Layer
    
    return proof

In [71]:
def commit(poly: Polynomial_zmod_flint, domain: list):
    #first stage: evaluation above the whole domain and creation of merkle tree
    points=[(d, poly(d)) for d in domain]
    return (points, merkle(points=points))

In [72]:
def fri_next_layer(poly: Polynomial_zmod_flint, domain: list , rand: int):
    #calculate the polynomial and the domain of the next stage
    even = 0*x
    odd = 0*x
    for degree,coef in poly.dict().items():
        if degree%2==0:
            even = even + coef*x**(degree//2)
        else:
            odd = odd + coef*x**(degree//2)
    next_layer = even + rand*odd
    new_domain = []
    for i in range(0,len(domain),2):
        assert domain[i]**2 == domain[i+1]**2
        new_domain.append(domain[i]**2)
    return next_layer, new_domain

In [73]:
def evaluate_points_and_path():
    pass

In [74]:
trace = trace_calculator(2,15)
poly = polynomial_evaluation(trace, field_gen ** ((PRIME-1)/16))
#poly = compositon_polynomial(poly, 15, field_gen ** ((PRIME-1)/16), 10, field_gen)
#f = 3*x**1 #2895570615*x**12 + x**7
#poly = poly.sub(f)
#f.coefficients(0)
#poly.dict()
print(f"{poly} \n \n" )
poly, domain = fri_next_layer(poly,[0,0,0,0], 1)
print (f"{poly} \n \n {254490165 + 2542374603} \n \n {domain}" )

trace = [0]*16
proof = fri(poly,trace,16)
proof

<class 'sage.rings.polynomial.polynomial_zmod_flint.Polynomial_zmod_flint'>
1949650749*x^14 + 254490165*x^13 + 2542374603*x^12 + 1167923054*x^11 + 2345897642*x^10 + 2315637714*x^9 + 2941037616*x^8 + 2682956350*x^7 + 56638226*x^6 + 854385164*x^5 + 1484157149*x^4 + 54628453*x^3 + 2895570615*x^2 + 2333655492*x + 1890800793 
 

1949650749*x^7 + 2796864768*x^6 + 292595223*x^5 + 2035449857*x^4 + 2739594576*x^3 + 2338542313*x^2 + 2950199068*x + 1003230812 
 
 2796864768 
 
 [0, 0]


{0: ['c87e08942e4e9cbd3fab82cfa4226926a81329539ce8625da2de602bffecef6b',
  [],
  1369968218],
 1: ['f358532b1d5611ca6e474d6569c808bf01b4f8d3210de943197e3b33acb0cda3',
  [],
  2733178454],
 2: ['81cb2040b6b34230ef32cc06c6bc7f73d629c932b241187e787ed558f4dda829',
  [],
  2355736069],
 3: ['ad8a8fd47b6f1a3da797f1fd55d9f85973bd47719234e83b9571722e0b50f4dc',
  [],
  3033393731]}

In [75]:
trace = []
domain = []
for i in range (1024):
    trace.append(gen1024**i)
for i in range (1024):
    domain.append(trace[reverse_bit(i, 10)])
domain[2]**2

3221225472

In [76]:
def hash_tow_elements(element1, element2):
        temp_string = ''.join(str(element1))
        temp_string = temp_string.join(str(element2))
        return(sha3(temp_string))

def hash_one_elements(element):
    return(sha3(str(element)))


class MerkeTree():
    
    tree: dict = {}
    
    def __init__(self, domain:list):
        self.tree={}
        
        # Calculate the hashes of each point in the domain.
        # Inset the leavs and their hashes to the tree. 
        domain_hashed = []
        for element in domain:
            hashed_element = hash_one_elements(element)
            self.tree[hashed_element] = element
            domain_hashed.append(hashed_element)

        # Now all the leavs are in the tree.
        # Construct the hash piramid.
        self.recursive_merkle(nodes_layer=domain_hashed)

    def recursive_merkle(self, nodes_layer: list):
        if len(nodes_layer) <= 1:
            #This is the root of the merkle tree.
            self.tree['root']=nodes_layer[0]
            return

        assert len(nodes_layer)%2 ==0
        
        # Create a new layer of nodes in the tree
        new_nodes_layre = []

        # Create a new node based on the two node beneath it.
        for i in range(len(nodes_layer)/2):
            hash_element = hash_tow_elements(nodes_layer[i*2], nodes_layer[i*2+1])
            self.tree[hash_element] = (nodes_layer[i*2], nodes_layer[i*2+1])
            new_nodes_layre.append(hash_element)
            
        return self.recursive_merkle(nodes_layer = new_nodes_layre)
    
    @property
    def root(self):
        return self.tree['root']
    
    def get_value_and_path_by_index(self, index: int, index_size: int):
        key = self.tree['root']
        
        # Shift the index from an int to a binary list.
        index_as_str = format(index, f'#0{index_size}b')

        # Shift from '0b1110' to '1110'
        index_as_str = index_as_str[2:] 

        path = {}
        while(index_as_str):
            value = self.tree[key]
            path[key] = value
            direction_bit = int(index_as_str[0])
            key = value[direction_bit]
            index_as_str = index_as_str[1:]
        
        #Now the key is the hash of the required index. Reauired value = tree[key] = (coset, CP[coset])
        path[key] = self.tree[key]
        return(path[key], path)


In [84]:
def verify_path_by_index(root: str, element: tuple, index: int, index_size: int, path: dict):
    index_as_str = format(index, f'#0{index_size}b')

    # shift from '0b1110' to '1110'
    index_as_str = index_as_str[2:] 
    
    key=root
    while(index_as_str):
        value = path[key]
        
        #Verify the hash:
        assert hash_tow_elements(value[0], value[1]) == key
        
        direction_bit = int(index_as_str[0])
        key = value[direction_bit]
        index_as_str = index_as_str[1:]

    #Now the key is the hash of the required index. Reauired value = tree[key] = (coset, CP[coset])
    value = path[key]
    assert hash_one_elements(value) == key
    assert value == element

In [85]:
domain = low_degree_extension(poly=CP, trance_len=16, group_gen=gen8192, field_gen=field_gen)
merke_tree = MerkeTree(domain=domain)
#print(merke_tree)

In [86]:
point, path = merke_tree.get_value_and_path_by_index(index = 3, index_size = 6)
print(path)
print(point)

{'3a9fdfeb854501daae03210db5902c8df8cb7e5a25b212cab8dae848e66c6aab': ('077c9e562151407d3616085c71ff5f18dc32f96ae590b274f00ee5136947d6d7', '472486be1232cdad7893200e8b7040ca2abf41bc9ac614e9a4f2013a3141efe8'), '077c9e562151407d3616085c71ff5f18dc32f96ae590b274f00ee5136947d6d7': ('87de3f1ca068651b4b9976e2a1e1921f8548d92f2434acd270b0db1b6dae9bca', '9bfd09ada6ced6ef13c4e55ebaf7cb41367cedb88f23241dbce245e71f08ab41'), '87de3f1ca068651b4b9976e2a1e1921f8548d92f2434acd270b0db1b6dae9bca': ('5505618816ee93e5851f05c15f2eafff2d7dcb3d358d22b8dd5e404473cdcac7', 'ab00ad1ce64b83c0e85476793709c2a5dcb8bc2cfbac56ea27f6ba45fde404db'), 'ab00ad1ce64b83c0e85476793709c2a5dcb8bc2cfbac56ea27f6ba45fde404db': ('bb0425bb0a904ccd116af4eb65b82b3bd6b15c4d892733752a0445aa9d538e1c', 'dfa3f5ab415f794db94609e279bead9c3ee194b55e1088c76081ca814dc8187f'), 'dfa3f5ab415f794db94609e279bead9c3ee194b55e1088c76081ca814dc8187f': (2540449978, 178471777)}
(2540449978, 178471777)


In [87]:
verify_path_by_index(root=merke_tree.root, element=point, index=3, index_size=6, path=path)