In [1]:
import hashlib
from sage.rings.finite_rings.integer_mod import IntegerMod_gmp
from sage.rings.polynomial.polynomial_zmod_flint import Polynomial_zmod_flint
#from math import log
# DEFINES
# Field modulus p = 3 * 2**30 + 1. Since p - 1 = 3 * 2**30, the (cyclic)
# multiplicative group of GF(p) has a subgroup of every power-of-two order up to 2**30.
PRIME = 3*2**30+1
# The prime field GF(PRIME) (SageMath finite field).
FIELD = GF(PRIME)

In [2]:
# Univariate polynomial ring over GF(PRIME); the Sage `F.<x>` syntax also binds its generator to `x`.
F.<x> = PolynomialRing(GF(PRIME),'x')
# A generator of the whole multiplicative group of FIELD (order PRIME - 1).
field_gen = FIELD.multiplicative_generator()
# Generators of the multiplicative subgroups of order 8192 and 1024 (field_gen raised to (PRIME-1)/k has order exactly k).
# NOTE(review): under the Sage preparser `/` produces an integral Rational exponent here; `//` would state the intent more directly.
gen8192 = field_gen ** ((PRIME-1)/8192)
gen1024 = field_gen ** ((PRIME-1)/1024)

In [3]:
# NOTE(review): the roles of Y_INDEX and `a` are not visible in this section; document them where they are used.
Y_INDEX, a = 999, 2

In [9]:
# def merkle(points: list):
#     if len(points) <= 1:
#         return points[0]
#     squeezed_points = []
#     for i in range(len(points)//2):
#         temp_string = ''.join(str(points[i*2]))
#         temp_string = temp_string.join(str(points[i*2+1]))
#         squeezed_points.append(sha3(temp_string))
#     if len(points)%2 == 1:
#         squeezed_points.append(sha3(''.join(str(points[-1]))))
#     return merkle(squeezed_points)

In [10]:
def sha3(string: str):
    """Return the SHA3-256 hex digest of ``string``.

    The text is encoded as UTF-8 before hashing. For ASCII-only input this
    produces exactly the same bytes (and therefore the same digest) as an ASCII
    encoding, while non-ASCII input no longer raises ``UnicodeEncodeError``.

    Args:
        string: Text to hash.

    Returns:
        The 64-character lowercase hexadecimal SHA3-256 digest.
    """
    return hashlib.sha3_256(string.encode('utf-8')).hexdigest()

In [11]:
def fiat_shamir_random(data: str, nonce = 0)-> int:
    """Derive a deterministic pseudo-random value in ``[0, PRIME)`` from ``data``.

    If ``nonce`` is truthy, the seed is the SHA3-256 hex digest of
    ``data + str(nonce)``. A falsy nonce (including the default 0) skips that
    step, so the raw ``data`` string is used as the seed directly.

    The seed's ASCII bytes are read as a big-endian integer and reduced modulo PRIME.
    """
    seed = sha3(data + str(nonce)) if nonce else data
    seed_as_int = int.from_bytes(seed.encode('ascii'), "big")
    return seed_as_int % PRIME

In [13]:
def reverse_bit(n, width = 10):
    """Reverse the lowest ``width`` bits of ``n``.

    ``n`` is treated as a ``width``-bit unsigned integer, e.g.
    ``reverse_bit(1, 3) == 4`` (001 -> 100) and ``reverse_bit(6, 3) == 3``
    (110 -> 011).

    Args:
        n: Integer to reverse; must satisfy ``0 <= n < 2**width``.
        width: Bit width. Any integral value is accepted (Python ints, Sage
            Integers, or integral floats such as ``10.0``).

    Returns:
        The bit-reversed value as a Python ``int``.

    Raises:
        ValueError: if ``width`` is not a non-negative integral value, or if
            ``n`` is not an integer that fits in ``width`` bits. (Previously an
            oversized ``n`` was silently reversed over more than ``width`` bits.)
    """
    bits = int(width)
    if bits != width or bits < 0:
        raise ValueError(f"width must be a non-negative integer, got {width!r}")
    value = int(n)
    if value != n or not 0 <= value < (1 << bits):
        raise ValueError(f"n={n!r} must be an integer in [0, 2**{bits})")
    reversed_value = 0
    # Shift bits out of `value` from the least-significant end and into
    # `reversed_value` from the least-significant end, which mirrors their order.
    for _ in range(bits):
        reversed_value = (reversed_value << 1) | (value & 1)
        value >>= 1
    return reversed_value

In [14]:
# def low_degree_extension(poly: Polynomial_zmod_flint, trance_len: int, group_gen: IntegerMod_gmp, field_gen: IntegerMod_gmp):
#     coset_set = [field_gen*(group_gen**i) for i in range (trance_len)]
#     new_coset_set = [coset_set[reverse_bit(i, log(trance_len,2))] for i in range (trance_len)]
#     return [(c, poly(c)) for c in new_coset_set]


In [37]:
def hash_tow_elements(element1, element2):
    """Hash the ordered pair ``(element1, element2)`` into one SHA3-256 hex digest.

    Both elements are converted with ``str`` and concatenated before hashing.
    The MerkeTree class uses this for inner nodes, whose two inputs are child
    hashes of equal length, so the concatenation is unambiguous there.

    NOTE: an earlier version used ``temp_string.join(str(element2))``, which
    interleaved ``str(element1)`` between every character of ``str(element2)``
    (and dropped ``element1`` entirely when ``str(element2)`` had at most one
    character). Digests from that version differ from the ones produced here.
    """
    return sha3(str(element1) + str(element2))

def hash_one_elements(element):
    """Return the SHA3-256 hex digest of ``str(element)`` (used for Merkle leaves)."""
    return sha3(str(element))


class MerkeTree():
    """Binary Merkle tree over a list of leaf elements.

    ``self.tree`` is a flat lookup table keyed by node hash:
      * leaf hash  -> the leaf element itself,
      * inner hash -> ``(left_child_hash, right_child_hash)``,
      * ``'root'`` -> the root hash.

    Leaves are hashed with ``hash_one_elements`` and inner nodes with
    ``hash_tow_elements``. The number of leaves must be a power of two, because
    every layer is split into adjacent pairs.
    """

    # Declared for readability only: each instance gets its own dict in __init__,
    # so no mutable state is shared through the class object.
    tree: dict
    domain_size: int

    def __init__(self, domain:list):
        """Hash every element of ``domain`` as a leaf and build the tree above it.

        Raises:
            ValueError: if ``domain`` is empty or its length is not a power of two.
        """
        self.tree = {}
        self.domain_size = len(domain)
        # Hash each leaf and record hash -> element so the value can be fetched later.
        domain_hashed = []
        for element in domain:
            hashed_element = hash_one_elements(element)
            self.tree[hashed_element] = element
            domain_hashed.append(hashed_element)

        # All leaves are in the table; now build the inner layers up to the root.
        self.recursive_merkle(nodes_layer=domain_hashed)

    def recursive_merkle(self, nodes_layer: list):
        """Build every layer above ``nodes_layer`` and store the root under ``'root'``.

        Each pass hashes adjacent pairs ``(2i, 2i+1)`` into a parent node,
        halving the layer, until a single hash remains. Implemented iteratively;
        the method name is kept for existing callers.

        Raises:
            ValueError: if ``nodes_layer`` is empty, or a layer with more than
                one node has odd length (the leaf count is not a power of two).
        """
        if not nodes_layer:
            raise ValueError("cannot build a Merkle tree over an empty layer")
        layer = list(nodes_layer)
        while len(layer) > 1:
            if len(layer) % 2:
                raise ValueError(
                    f"layer of size {len(layer)} cannot be paired; "
                    "the number of leaves must be a power of two"
                )
            parents = []
            # Pair even-position nodes with their right-hand neighbours.
            for left, right in zip(layer[0::2], layer[1::2]):
                parent = hash_tow_elements(left, right)
                self.tree[parent] = (left, right)
                parents.append(parent)
            layer = parents
        # A single remaining hash is the root of the tree.
        self.tree['root'] = layer[0]

    @property
    def root(self):
        """The root hash of the tree."""
        return self.tree['root']

    def get_value_and_path_by_index(self, index: int):
        """Return the leaf at position ``index`` together with its authentication path.

        Starting at the root, the bits of ``index`` (most-significant first,
        ``log2(domain_size)`` of them) select the left (0) or right (1) child at
        each level.

        Returns:
            ``(value, path)``: ``value`` is the leaf element and ``path`` maps
            every node hash visited (root through leaf) to its entry in
            ``self.tree``.

        Raises:
            ValueError: if ``index`` is outside ``[0, domain_size)``.
        """
        if not 0 <= index < self.domain_size:
            raise ValueError(f"index {index!r} is outside [0, {self.domain_size})")
        # Tree depth = log2(domain_size), computed exactly with integer arithmetic.
        # (A float/symbolic log can truncate wrongly, and the old string-based
        # walk took one bogus step for a single-leaf tree.)
        depth = int(self.domain_size).bit_length() - 1
        key = self.tree['root']

        path = {}
        for level in range(depth - 1, -1, -1):
            children = self.tree[key]
            path[key] = children
            key = children[(index >> level) & 1]

        # `key` is now the hash of the requested leaf; its table entry is the leaf value.
        path[key] = self.tree[key]
        return (path[key], path)


In [2]:
# def verify_path_by_index_can(root: str, expected_value , path: dict):

#     #Verify - expected value is in the path
#     end_of_path = hash_one_elements(expected_value)
#     assert path[end_of_path] == expected_value 
#     key = root
#     while(key):
#         value = path[key]
#         #verify - the path is correct
#         if key != end_of_path:
#             assert key == hash_tow_elements(value[0], value[1])

#         #check wich child is the continuation of the path
#         if value[0] in path:
#             key = value[0]
#         elif value[1] in path:
#             key = value[1]

#         else: #Stop at the end of the path
#             key=0

In [4]:
def verify_path_by_index(root: str, expected_value: tuple, index: int, domain_size: int, path: dict):
    """Check that ``path`` authenticates ``expected_value`` as leaf ``index`` under ``root``.

    Walks from ``root`` following the bits of ``index`` (most-significant first,
    ``log2(domain_size)`` of them). Every inner node must equal the hash of its
    two children, and the final leaf value must hash to its key and equal
    ``expected_value``.

    Failed checks raise ``AssertionError`` explicitly rather than through
    ``assert`` statements, so verification still happens under ``python -O``.
    A path that is missing a required node is also reported as ``AssertionError``.

    Raises:
        ValueError: if ``domain_size`` is not a positive power of two or
            ``index`` is outside ``[0, domain_size)``.
        AssertionError: if the path does not authenticate ``expected_value``.
    """
    size = int(domain_size)
    if size <= 0 or size & (size - 1):
        raise ValueError(f"domain_size must be a positive power of two, got {domain_size!r}")
    if not 0 <= index < size:
        raise ValueError(f"index {index!r} is outside [0, {size})")
    # Tree depth = log2(domain_size), computed exactly with integer arithmetic.
    depth = size.bit_length() - 1

    key = root
    for level in range(depth - 1, -1, -1):
        if key not in path:
            raise AssertionError(f"path is missing inner node {key!r}")
        value = path[key]
        # Each inner node must be the hash of its two children.
        if hash_tow_elements(value[0], value[1]) != key:
            raise AssertionError(f"inner node {key!r} does not match the hash of its children")
        key = value[(index >> level) & 1]

    # `key` is now the hash of the requested leaf; its stored value must hash to
    # it and equal the expected value.
    if key not in path:
        raise AssertionError(f"path is missing leaf node {key!r}")
    value = path[key]
    if hash_one_elements(value) != key:
        raise AssertionError(f"leaf value does not hash to {key!r}")
    if not value == expected_value:
        raise AssertionError(f"leaf value {value!r} != expected {expected_value!r}")

SyntaxError: invalid syntax (<ipython-input-4-edb34032a947>, line 9)

In [None]:
def domain_extension( trace_len: int, group_gen: IntegerMod_gmp, field_gen: IntegerMod_gmp):
    """Return the coset ``field_gen * group_gen**k`` (k = 0..trace_len-1) in bit-reversed order.

    Position ``i`` of the result holds the natural-order element
    ``reverse_bit(i, log2(trace_len))``.

    Args:
        trace_len: Number of points; must be a power of two. A non-positive
            value yields ``[]``.
        group_gen: Multiplicative step between consecutive coset points
            (presumably of order ``trace_len`` — not checked here).
        field_gen: Coset offset multiplied into every point.

    Raises:
        ValueError: if ``trace_len`` is positive but not a power of two.
    """
    size = int(trace_len)
    if size <= 0:
        # Nothing to generate (the original comprehension over an empty range also gave []).
        return []
    if size & (size - 1):
        raise ValueError(f"trace_len must be a power of two, got {trace_len!r}")
    # log2(size), computed exactly and once instead of a symbolic/float log per element.
    width = size.bit_length() - 1
    coset_set = [field_gen*(group_gen**i) for i in range(size)]
    return [coset_set[reverse_bit(i, width)] for i in range(size)]

In [45]:
def get_extented_domain(excecution_trace_length: int, num_of_queries: int = 8):
    """Return the bit-reversed evaluation coset for a trace of the given length.

    The domain has ``excecution_trace_length * num_of_queries`` points, so
    ``num_of_queries`` acts as the blow-up factor. It is generated by
    ``field_gen ** ((PRIME - 1) // domain_size)``, an element of exactly that
    order because ``field_gen`` generates the whole multiplicative group. The
    domain is shifted by ``field_gen``.

    Raises:
        ValueError: if the domain size is not positive or does not divide PRIME - 1.
    """
    domain_size = excecution_trace_length*num_of_queries
    if domain_size <= 0 or (PRIME - 1) % domain_size:
        raise ValueError(f"domain size {domain_size!r} must be positive and divide PRIME - 1")
    # `//` keeps the exponent an integer (`/` yields a Rational under Sage and a float in Python).
    domain_gen = field_gen ** ((PRIME - 1) // domain_size)
    return domain_extension(domain_size, domain_gen, field_gen)

In [None]:
def get_random_index_to_sample(queries: int, merkel_root: list, domain_len: int):
    """Derive ``queries`` sample indexes, in sibling pairs, from Merkle roots (Fiat-Shamir).

    The roots are concatenated into one seed string. The domain is split into
    ``queries // 2`` consecutive chunks of ``domain_len // (queries // 2)``
    indexes. Pair ``i`` takes one pseudo-random index inside chunk ``i``, from
    ``fiat_shamir_random(seed, i)``, followed by its sibling: the neighbouring
    index that differs only in the lowest bit. This keeps each pair in a
    different part of the domain.

    Note: ``fiat_shamir_random`` skips its extra hashing step for a falsy nonce,
    so pair 0 is derived from the raw concatenated roots.

    Args:
        queries: Total number of indexes to return; must be even and non-negative.
        merkel_root: Merkle root strings used to seed the randomness.
        domain_len: Size of the evaluation domain being sampled.

    Returns:
        A list of ``queries`` ints ordered ``[p0, sibling(p0), p1, sibling(p1), ...]``.

    Raises:
        ValueError: if ``queries`` is negative or odd, or ``domain_len`` is too
            small to give every pair a non-empty chunk.
    """
    if queries < 0 or queries % 2:
        raise ValueError(f"queries must be a non-negative even number, got {queries!r}")
    num_of_couples = queries // 2
    if num_of_couples == 0:
        return []
    combined_root = ''.join(merkel_root)
    # `//` keeps the chunk size an integer (`/` yields a float in Python and a Rational under Sage).
    domain_to_query = domain_len // num_of_couples
    if domain_to_query <= 0:
        raise ValueError(f"domain_len={domain_len!r} is too small for {num_of_couples} query pairs")

    index_to_sample = []
    for i in range(num_of_couples):
        index = int(fiat_shamir_random(combined_root, i) % domain_to_query + i * domain_to_query)
        sibling = index + 1 if index % 2 == 0 else index - 1
        index_to_sample.extend((index, sibling))
    return index_to_sample

In [None]:
def get_next_stage_indexes(current_stage_indexes: list):
    """Map query indexes to the indexes to request at the next (halved) layer.

    Each index is floor-divided by two and replaced by the sibling of that
    halved index (``k + 1`` for even ``k``, ``k - 1`` for odd ``k``).
    Duplicates are dropped, keeping first-occurrence order.
    """
    halved = (idx // 2 for idx in current_stage_indexes)
    siblings = (half - 1 if half % 2 else half + 1 for half in halved)
    return list(dict.fromkeys(siblings))

In [None]:
def create_poly_coeff(ex_poly_root):
    """Derive three Fiat-Shamir coefficients from ``ex_poly_root``.

    ``ex_poly_root`` is presumably a Merkle root string; confirm against callers.
    ``c1`` is derived from it, and each next coefficient from the string form
    of the previous one. The values are printed and returned as ``(c3, c2, c1)``.

    NOTE: this previously returned ``(c3, c2, c3)``, using c3 twice and never
    c1. NOTE(review): the descending order follows the existing print label;
    confirm whether callers expect ``(c1, c2, c3)`` instead.
    """
    c1 = fiat_shamir_random(ex_poly_root)
    c2 = fiat_shamir_random(str(c1))
    c3 = fiat_shamir_random(str(c2))
    print(f"{(c3, c2, c1)=}")
    return (c3, c2, c1)