<a href="https://colab.research.google.com/github/Haykhovhannisyan1/Stark101_Verifier/blob/TheVerifier/Stark101_verifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import os

In [2]:
import time
import inspect
from hashlib import sha256
from math import log2, ceil
import operator
from functools import reduce
from itertools import dropwhile, starmap, zip_longest
from random import randint


In [3]:
### Channel ###
def serialize(obj):
    """
    Serializes an object into a string.

    Lists and tuples are serialized element by element (recursively) and joined with commas;
    any other object must provide a `_serialize_()` method, whose result is returned as-is.
    """
    if not isinstance(obj, (list, tuple)):
        return obj._serialize_()
    return ','.join(serialize(item) for item in obj)


class Channel(object):
    """
    A Channel instance can be used by a prover or a verifier to preserve the semantics of an
    interactive proof system, while under the hood it is in fact non-interactive, and uses Sha256
    to generate randomness when this is required.
    It allows writing string-form data to it, and reading either random integers or random
    FieldElements from it.

    Every interaction is recorded in `self.proof` as '<method name>:<value>'.
    """

    def __init__(self):
        # Running hash state; every sent string is absorbed into it.
        self.state = '0'
        # Transcript of all interactions, in the order they happened.
        self.proof = []

    def send(self, s):
        """
        Absorbs the string `s` into the hash state and records 'send:<s>' in the transcript.
        """
        self.state = sha256((self.state + s).encode()).hexdigest()
        # Transcript labels are the method names. They used to be read through inspect.stack(),
        # which builds a record (with source context) for every frame on the call stack on each
        # call; spelling the name out yields the same label without that cost.
        self.proof.append(f'send:{s}')

    def receive_random_int(self, min, max, show_in_proof=True):
        """
        Emulates a random integer sent by the verifier in the range [min, max] (including min and
        max).

        The value is derived from the current hash state, which is then re-hashed so that the
        next draw differs. When `show_in_proof` is true the value is recorded in the transcript.
        """

        # Note that when the range is close to 2^256 this does not emit a uniform distribution,
        # even if sha256 is uniformly distributed.
        # It is, however, close enough for this tutorial's purposes.
        num = min + (int(self.state, 16) % (max - min + 1))
        self.state = sha256((self.state).encode()).hexdigest()
        if show_in_proof:
            self.proof.append(f'receive_random_int:{num}')
        return num

    def receive_random_field_element(self):
        """
        Emulates a random field element sent by the verifier.
        """
        num = self.receive_random_int(0, FieldElement.k_modulus - 1, show_in_proof=False)
        self.proof.append(f'receive_random_field_element:{num}')
        return FieldElement(num)

In [4]:
###MerkleTree
class MerkleTree(object):
    """
    A simple and naive implementation of an immutable Merkle tree.

    Nodes are numbered heap-style: the root is node 1 and node k has children 2k and 2k + 1.
    After padding the data with FieldElement(0) up to a power-of-two length n, leaf i is node
    n + i. `self.facts` maps every node hash to its preimage: the (left, right) child hashes for
    an internal node, or the leaf's string form for a leaf.
    """

    def __init__(self, data):
        assert isinstance(data, list)
        assert len(data) > 0, 'Cannot construct an empty Merkle Tree.'
        padded_size = 2 ** ceil(log2(len(data)))
        self.data = data + [FieldElement(0)] * (padded_size - len(data))
        self.height = int(log2(padded_size))
        self.facts = {}
        self.root = self.build_tree()

    def get_authentication_path(self, leaf_id):
        """
        Returns the sibling hashes on the route from the root down to leaf `leaf_id`, ordered
        from the level just below the root to the leaf's own sibling.
        """
        assert 0 <= leaf_id < len(self.data)
        # After dropping '0b' and the leading 1, the bits of the node id spell the root-to-leaf
        # route: '0' descends to the left child, '1' to the right child.
        route = bin(leaf_id + len(self.data))[3:]
        node = self.root
        siblings = []
        for step in route:
            left, right = self.facts[node]
            if step == '1':
                node, sibling = right, left
            else:
                node, sibling = left, right
            siblings.append(sibling)
        return siblings

    def build_tree(self):
        """
        Builds the whole tree and returns the root hash.
        """
        return self.recursive_build_tree(1)

    def recursive_build_tree(self, node_id):
        """
        Returns the hash of node `node_id`, recording it (and its subtree) in `self.facts`.
        """
        leaf_count = len(self.data)
        if node_id < leaf_count:
            # An internal node: hash the concatenation of both children's hashes.
            left_hash = self.recursive_build_tree(2 * node_id)
            right_hash = self.recursive_build_tree(2 * node_id + 1)
            digest = sha256((left_hash + right_hash).encode()).hexdigest()
            self.facts[digest] = (left_hash, right_hash)
            return digest
        # A leaf: hash the string form of its data element.
        leaf_str = str(self.data[node_id - leaf_count])
        digest = sha256(leaf_str.encode()).hexdigest()
        self.facts[digest] = leaf_str
        return digest


def verify_decommitment(leaf_id, leaf_data, decommitment, root):
    """
    Checks that `leaf_data` sits at position `leaf_id` of the Merkle tree whose root is `root`.

    `decommitment` is an authentication path in the order produced by
    MerkleTree.get_authentication_path (sibling hashes from just below the root down to the
    leaf). The tree size is inferred as 2 ** len(decommitment). Only the low len(decommitment)
    bits of the node id are consumed, so a non-negative `leaf_id` is effectively taken modulo
    that size.

    Returns True iff the recomputed root equals `root`.
    """
    node_id = leaf_id + 2 ** len(decommitment)
    # Walk upwards: pair each route bit (least significant first) with the matching sibling,
    # starting from the sibling closest to the leaf.
    route_bits_upward = bin(node_id)[3:][::-1]
    siblings_upward = decommitment[::-1]
    digest = sha256(str(leaf_data).encode()).hexdigest()
    for bit, sibling in zip(route_bits_upward, siblings_upward):
        pair = digest + sibling if bit == '0' else sibling + digest
        digest = sha256(pair.encode()).hexdigest()
    return digest == root

In [5]:
def remove_trailing_elements(list_of_elements, element_to_remove):
    """
    Returns a new list equal to `list_of_elements` with every trailing occurrence of
    `element_to_remove` stripped off the end.
    """
    end = len(list_of_elements)
    while end > 0 and list_of_elements[end - 1] == element_to_remove:
        end -= 1
    return list(list_of_elements[:end])


def two_lists_tuple_operation(f, g, operation, fill_value):
    """
    Applies the binary `operation` pairwise to the elements of `f` and `g`, padding the shorter
    input with `fill_value`, and returns the results as a list.
    """
    return [operation(a, b) for a, b in zip_longest(f, g, fillvalue=fill_value)]


def scalar_operation(list_of_elements, operation, scalar):
    """
    Returns a list holding operation(element, scalar) for each element of `list_of_elements`.
    """
    return list(map(lambda element: operation(element, scalar), list_of_elements))

In [5]:
###FieldElement
class FieldElement:
    """
    Represents an element of F_(3 * 2**30 + 1).

    Arithmetic operators accept FieldElement or int operands (ints are reduced modulo
    k_modulus), and equality against an int compares against that int's reduced value.
    """
    k_modulus = 3 * 2**30 + 1
    generator_val = 5

    def __init__(self, val):
        self.val = val % FieldElement.k_modulus

    @staticmethod
    def zero():
        """
        Obtains the zero element of the field.
        """
        return FieldElement(0)

    @staticmethod
    def one():
        """
        Obtains the unit element of the field.
        """
        return FieldElement(1)

    def __repr__(self):
        # Choose the shorter representation between the positive and negative values of the element.
        return repr((self.val + self.k_modulus//2) % self.k_modulus - self.k_modulus//2)

    def __eq__(self, other):
        if isinstance(other, int):
            other = FieldElement(other)
        return isinstance(other, FieldElement) and self.val == other.val

    def __hash__(self):
        return hash(self.val)

    @staticmethod
    def generator():
        """
        Returns FieldElement(generator_val).
        """
        return FieldElement(FieldElement.generator_val)

    @staticmethod
    def typecast(other):
        """
        Converts an int to a FieldElement; raises AssertionError for any other non-FieldElement.
        """
        if isinstance(other, int):
            return FieldElement(other)
        assert isinstance(other, FieldElement), f'Type mismatch: FieldElement and {type(other)}.'
        return other

    def __neg__(self):
        return self.zero() - self

    def __add__(self, other):
        try:
            other = FieldElement.typecast(other)
        except AssertionError:
            return NotImplemented
        return FieldElement((self.val + other.val) % FieldElement.k_modulus)

    __radd__ = __add__

    def __sub__(self, other):
        try:
            other = FieldElement.typecast(other)
        except AssertionError:
            return NotImplemented
        return FieldElement((self.val - other.val) % FieldElement.k_modulus)

    def __rsub__(self, other):
        # other - self == -(self - other)
        return -(self - other)

    def __mul__(self, other):
        try:
            other = FieldElement.typecast(other)
        except AssertionError:
            return NotImplemented
        return FieldElement((self.val * other.val) % FieldElement.k_modulus)

    __rmul__ = __mul__

    def __truediv__(self, other):
        other = FieldElement.typecast(other)
        return self * other.inverse()

    def __rtruediv__(self, other):
        # Supports `int / FieldElement`. Unsupported left operands keep raising TypeError.
        try:
            other = FieldElement.typecast(other)
        except AssertionError:
            return NotImplemented
        return other * self.inverse()

    def __pow__(self, n):
        """
        Raises the element to the integer power n using square-and-multiply.

        A negative n is treated as a power of the inverse (which fails for the zero element).
        """
        if n < 0:
            return self.inverse() ** (-n)
        cur_pow = self
        res = FieldElement(1)
        while n > 0:
            if n % 2 != 0:
                res *= cur_pow
            n = n // 2
            cur_pow *= cur_pow
        return res

    def inverse(self):
        """
        Returns the multiplicative inverse using the extended Euclidean algorithm.

        Raises AssertionError for the zero element, which has no inverse.
        """
        t, new_t = 0, 1
        r, new_r = FieldElement.k_modulus, self.val
        while new_r != 0:
            quotient = r // new_r
            t, new_t = new_t, (t - (quotient * new_t))
            r, new_r = new_r, r - quotient * new_r
        assert r == 1
        return FieldElement(t)

    def is_order(self, n):
        """
        Naively checks that the element is of order n by raising it to all powers up to n, checking
        that the element to the n-th power is the unit, but not so for any k<n.
        """
        assert n >= 1
        h = FieldElement(1)
        for _ in range(1, n):
            h *= self
            if h == FieldElement(1):
                return False
        return h * self == FieldElement(1)

    def _serialize_(self):
        return repr(self.val)

    @staticmethod
    def random_element(exclude_elements=None):
        """
        Returns a random field element that is not in `exclude_elements`.

        Draws from the `random` module, so the result is not cryptographically secure.
        """
        # None instead of a shared mutable [] default.
        if exclude_elements is None:
            exclude_elements = []
        fe = FieldElement(randint(0, FieldElement.k_modulus - 1))
        while fe in exclude_elements:
            fe = FieldElement(randint(0, FieldElement.k_modulus - 1))
        return fe

In [6]:
# Trace domain: points[i] = g**i for i < 1024. Assuming 5 generates the whole multiplicative
# group (order 3 * 2**30), g = 5**(3 * 2**20) has order 1024.
g = FieldElement.generator() ** (3 * 2 ** 20)
points = [g ** exponent for exponent in range(1024)]
# Evaluation domain: the generator times each of the 8192 powers of h_gen (a coset, assuming
# the generator lies outside the subgroup generated by h_gen).
h_gen = FieldElement.generator() ** (3 * 2 ** 30 // 8192)
h = [h_gen ** exponent for exponent in range(8192)]
domain = [FieldElement.generator() * member for member in h]

In [7]:
def next_fri_domain(domain):
    """
    Returns the next FRI evaluation domain: the squares of the first half of `domain`.
    """
    half = len(domain) // 2
    return [point ** 2 for point in domain[:half]]


In [8]:
# Each auth path is given in the proof as one long string; we need to turn it into a list of strings.
def get_path(path):
    """
    Parses an authentication path from its proof entry.

    The entry is "send:" followed by the repr of a list of hex-digest strings, e.g.
    "send:['ab12...', 'cd34...']". Returns that list of strings.

    Raises:
        ValueError: if the text after the 5-character "send:" prefix is not a list (or tuple)
            literal of strings.
    """
    import ast  # local import: only this parser needs it

    # ast.literal_eval only evaluates Python literals, so it is safe on prover-supplied text.
    try:
        parsed = ast.literal_eval(path[5:])  # as path starts with "send:"
    except (ValueError, SyntaxError, TypeError) as err:
        raise ValueError(f'Malformed authentication path: {path!r}') from err
    if not isinstance(parsed, (list, tuple)) or not all(isinstance(node, str) for node in parsed):
        raise ValueError(f'Malformed authentication path: {path!r}')
    return list(parsed)

In [9]:
def empty_paramether_sets(number_of_queries=3):
    """
    Returns fresh, empty containers for the verifier's parameters.

    The result is (roots, evals, auth_paths, alphas, betas). `evals` and `auth_paths` each hold
    one independent empty list per query; the other three are empty lists.
    """
    per_query_evals = [[] for _ in range(number_of_queries)]
    per_query_paths = [[] for _ in range(number_of_queries)]
    return [], per_query_evals, per_query_paths, [], []

In [10]:
# Number of queries in the proof; each query contributes 48 transcript entries.
number_of_queries = 3

In [11]:
idx = []  # Query indexes, appended to by get_roots_alphas_betas and get_evals_and_auth_paths.

In [12]:
# We emulate the prover by re-sending every value the prover sent to its channel. This lets us
# re-derive every challenge ourselves, which verifies that each challenge was received only after
# the corresponding values were sent.
channel = Channel()

In [13]:
# Load the prover's transcript (a NumPy array of strings). Defaults to the Colab upload location;
# set the STARK101_PROOF_PATH environment variable to load it from somewhere else.
proof = np.load(os.environ.get("STARK101_PROOF_PATH", "/content/The proof.npy"))

In [14]:
# Display the loaded transcript as the cell's output.
proof

array(['send:6c266a104eeaceae93c14ad799ce595ec8c2764359d7ad1b4b7c57a4da52be04',
       'receive_random_field_element:2948900820',
       'receive_random_field_element:1859037345',
       'receive_random_field_element:2654806830',
       'send:61f7d8283e244d391a483c420776e351fcfdbce525a698461a8307a1345b5652',
       'receive_random_field_element:394024765',
       'send:9431516ee735a498c4aec3da30112e417b03e55e5be939ff44ca8a0a62475b15',
       'receive_random_field_element:1705983878',
       'send:584b4b88a7f296efa0309d8e6faef13573b1ee5dfcb02ed8be5d853172f3fc69',
       'receive_random_field_element:665918954',
       'send:2debb983bb6473a5d4e9046944fb7ef66ef814c64f58ca5d8ebc2a15ed61ca4a',
       'receive_random_field_element:3182659911',
       'send:5da75aa9d9a9a564d7f19e431cbbb91eff030c353f3825dc5352674d1b7813f9',
       'receive_random_field_element:2692084106',
       'send:8ca6b618f3d758e7a99c4988e3a30e5c443f6f4ed79c64b698b031cca67ee4c2',
       'receive_random_field_element:24536

In [15]:
# Expected layout:
# - 26 header entries: the trace root, 3 alphas, the 11 FRI-layer roots interleaved with
#   10 betas, and the last layer's value.
# - 48 entries per query:
#   - the query index (1);
#   - f(x), f(gx), f(g^2x) with their auth paths (6);
#   - CP_i(x) and CP_i(-x) with their auth paths for 10 layers (4 * 10 = 40);
#   - the last layer's value (1).
assert len(proof) == 26 + 48 * number_of_queries, (
    f"Proof doesn't have the right length: expected {26 + 48 * number_of_queries}, "
    f"got {len(proof)}")

In [16]:
# Roots in the proof are given as "send:<root>", so get_roots_alphas_betas strips the "send:" prefix and then sends them to the channel
# to get the challenges (alphas and betas).
def get_roots_alphas_betas(proof, channel, query_indices=None):
    """
    Replays the commitment phase of the proof through `channel`.

    Reads the trace root (proof[0]), the roots of FRI layers CP0..CP10 (proof[4], proof[6], ...,
    proof[24]) and the last layer's value (proof[25]), stripping the 5-character "send:" prefix
    from each. Sending them to `channel` in the prover's order lets the verifier derive every
    challenge itself.

    Args:
        proof: the prover's transcript entries.
        channel: the verifier's Channel; presumably a fresh one, so its state matches the
            prover's at the start of the transcript.
        query_indices: list that receives the first query index drawn from the channel.
            Defaults to the notebook's global `idx` list (the previous, implicit behavior).

    Returns:
        (roots, alphas, betas, last_value): the 12 roots (trace root, then CP0..CP10), the
        3 alphas, the 10 betas, and the last layer's value as an int.
    """
    if query_indices is None:
        query_indices = idx  # backward-compatible default: the notebook's global index list

    roots = [proof[0][5:]]  # the root of the trace polynomial
    channel.send(roots[0])
    alphas = [channel.receive_random_field_element() for _ in range(3)]

    roots.extend(proof[i][5:] for i in range(4, 25, 2))  # the roots of CP0, CP1, ..., CP10

    # Every FRI layer except the last is followed by its folding challenge (beta).
    betas = []
    for layer_root in roots[1:-1]:
        channel.send(layer_root)
        betas.append(channel.receive_random_field_element())

    last_value = proof[25][5:]  # the value of the (constant) last layer
    channel.send(roots[-1])
    channel.send(last_value)

    # The first query index; the upper bound keeps idx + 16 (the position of f(g^2x)) inside the
    # 8192-point evaluation domain.
    query_indices.append(channel.receive_random_int(0, 8191 - 16))
    return roots, alphas, betas, int(last_value)

In [17]:
# get_evals_and_auth_paths extracts the evals and auth_paths from the proof's "send:evals" and "send:auth_paths" entries and sends them to the channel
# to get the next query's index.
def get_evals_and_auth_paths(proof, channel, idx, number_of_queries = 3):
    """
    Parses each query's decommitments from `proof` and replays them through `channel`.

    For query j the prover's entries start at proof[26 + 48 * j] (its query index). They are
    followed by f(x), f(gx), f(g^2x) with their auth paths, then CP_i(x), CP_i(-x) with their
    auth paths, and finally the last layer's value. Every value and path is re-sent to `channel`
    in that order, so the next query index is derived by the verifier and appended to `idx`.
    The index drawn after the final query is unused and is popped again.

    Returns:
        (evals, auth_paths): per query, a list of 24 ints and a list of 23 parsed auth paths.
    """
    _, evals, auth_paths, _, _ = empty_paramether_sets(number_of_queries)
    for query in range(number_of_queries):
        start = 26 + 48 * query  # proof[start] is the prover's receive_random_int entry
        query_evals = evals[query]
        query_paths = auth_paths[query]

        # f(x), f(gx), f(g^2x) and their auth paths live at start+1 .. start+6.
        for offset in (1, 3, 5):
            query_evals.append(int(proof[start + offset][5:]))
        for offset in (2, 4, 6):
            query_paths.append(get_path(proof[start + offset]))

        # CP_i(x), CP_i(-x) values (plus the last layer's value at start+47) and their auth paths.
        query_evals.extend(int(proof[i][5:]) for i in range(start + 7, start + 48, 2))
        query_paths.extend(get_path(proof[i]) for i in range(start + 8, start + 48, 2))

        # Replay in the prover's order: each value followed by its auth path, then the last value.
        for value, path in zip(query_evals[:-1], query_paths):
            channel.send(str(value))
            channel.send(str(path))
        channel.send(str(query_evals[-1]))
        idx.append(channel.receive_random_int(0, 8191 - 16))

    idx.pop()

    return evals, auth_paths


In [18]:
# Replay the commitment phase: recover the Merkle roots, re-derive the alphas, betas and the first
# query index (appended to the global `idx`), and read the last FRI layer's value.
roots, alphas, betas, last_value = get_roots_alphas_betas(proof, channel)

In [19]:
# Parse every query's evaluations and auth paths, re-deriving the remaining query indices into `idx`.
evals, auth_paths = get_evals_and_auth_paths(proof, channel, idx)

In [20]:
# verify_paths verifies that f(x), f(gx), f(g^2x), CP_i(x) and CP_i(-x) belong to their respective Merkle trees.
def verify_paths(idx, evals, auth_paths, roots, length):
    """
    Checks that one query's decommitted values belong to the committed Merkle trees.

    Args:
        idx: the query index into the evaluation domain.
        evals: the query's values: f(x), f(gx), f(g^2x), then CP_i(x), CP_i(-x) pairs, then the
            last layer's value.
        auth_paths: the matching authentication paths (there is none for the last value).
        roots: the trace root followed by the roots of CP0, CP1, ...
        length: size of the evaluation domain, which is also the size of the first FRI layer.

    Raises:
        AssertionError: if an authentication path has the wrong depth or does not lead to its root.
    """
    def check(leaf_id, value, path, root, tree_size, what):
        # verify_decommitment infers the tree size from the path length, so a path of the
        # wrong depth must be rejected here or it would be checked against a different tree shape.
        assert 2 ** len(path) == tree_size, f'{what} has an authentication path of the wrong length'
        assert verify_decommitment(leaf_id, value, path, root), f'{what} is not from the data'

    # x, gx and g^2x are 8 positions apart in the evaluation domain (8192 points vs. 1024 trace points).
    check(idx, evals[0], auth_paths[0], roots[0], length, 'f(x)')
    check(idx + 8, evals[1], auth_paths[1], roots[0], length, 'f(gx)')
    check(idx + 16, evals[2], auth_paths[2], roots[0], length, 'f(g^2x)')
    j = 3
    for i in range(1, len(roots)):
        layer_idx = idx % length  # position of x in the current layer
        sib_idx = (idx + length // 2) % length  # position of -x in the current layer
        if j != len(evals) - 1:  # the last value has no auth path; verify_fri checks it
            check(layer_idx, evals[j], auth_paths[j], roots[i], length, f'CP[{i - 1}](x)')
            check(sib_idx, evals[j + 1], auth_paths[j + 1], roots[i], length, f'CP[{i - 1}](-x)')
        length = length // 2  # every next layer is 2 times shorter than the previous one
        j += 2

In [21]:
# Check every query's decommitments against the committed Merkle roots.
for query in range(number_of_queries):
  verify_paths(idx[query], evals[query], auth_paths[query], roots, len(domain))

In [27]:
# verify_Trace verifies that the CP polynomial was built from the trace polynomial (evals) and the challenges (alphas).
def verify_Trace(idx, evals, alphas, dom, points, claimed_result=2338775057, first_value=1):
    """
    Checks that the first FRI layer's value CP(x) is consistent with the trace values.

    With x = dom[idx] and f(x), f(gx), f(g^2x) taken from evals[0..2], recomputes
        CP(x) = alpha0 * (f(x) - first_value) / (x - g^0)
              + alpha1 * (f(x) - claimed_result) / (x - g^1022)
              + alpha2 * (f(g^2x) - f(gx)^2 - f(x)^2)
                       / ((x^1024 - 1) / ((x - g^1021)(x - g^1022)(x - g^1023)))
    and asserts that it equals evals[3].

    Args:
        idx: the query index into `dom`.
        evals: the query's decommitted values (ints); only evals[0..3] are used.
        alphas: the three composition-polynomial challenges.
        dom: the evaluation domain.
        points: the trace domain, points[i] = g^i.
        claimed_result: the value the trace must take at g^1022.
        first_value: the value the trace must take at g^0.

    Raises:
        AssertionError: if the recomputed CP(x) differs from evals[3].
    """
    x = dom[idx]
    f_x, f_gx, f_g2x = (FieldElement(value) for value in evals[:3])

    p0 = (f_x - first_value) / (x - points[0])
    p1 = (f_x - claimed_result) / (x - points[1022])
    transition_denominator = (x ** 1024 - 1) / (
        (x - points[1021]) * (x - points[1022]) * (x - points[1023]))
    p2 = (f_g2x - f_gx * f_gx - f_x * f_x) / transition_denominator

    cp = alphas[0] * p0 + alphas[1] * p1 + alphas[2] * p2
    assert cp == FieldElement(evals[3]), 'CP(x) does not match the trace constraints'

In [24]:
# Check every query's first FRI layer value against the trace constraints.
for query in range(number_of_queries):
  verify_Trace(idx[query], evals[query], alphas, domain, points)

In [25]:
# verify_fri verifies that CP_(i+1) was built from CP_i(x), CP_i(-x) (evals) and the challenges (betas).
def verify_fri(idx, evals, betas, dom=domain, final_value=None):
    """
    Checks one query's FRI folding from the first layer down to the last layer's constant.

    evals[3:] holds CP_i(x), CP_i(-x) pairs for successive layers, followed by the last layer's
    value. For each layer i, with x = dom[idx mod len(dom)], computes
        CP_{i+1}(x^2) = (CP_i(x) + CP_i(-x)) / 2 + betas[i] * (CP_i(x) - CP_i(-x)) / (2x)
    and asserts that it equals the next decommitted value. The final folded value must equal the
    last layer's constant.

    Args:
        idx: the query index into the first layer's domain.
        evals: the query's decommitted values (ints).
        betas: the folding challenges, one per folded layer.
        dom: the first layer's evaluation domain (defaults to the global `domain`).
        final_value: the last layer's constant; defaults to the global `last_value`.

    Raises:
        AssertionError: if a folding step or the final constant does not match, or if `evals`
            contains no layer to fold.
    """
    if final_value is None:
        final_value = last_value  # backward-compatible default: the notebook's global
    next_pol = None
    j = 3  # evals[3] is CP_0(x); evals[0..2] are the trace values
    i = 0
    while j + 2 < len(evals):
        dom_len = len(dom)
        idx = idx % dom_len  # domain is 2 times smaller in the current layer
        cp_x = FieldElement(evals[j])  # CP(x)
        cp_neg_x = FieldElement(evals[j + 1])  # CP(-x)
        even = (cp_x + cp_neg_x) / 2  # the even part of the polynomial, evaluated at x^2
        odd = (cp_x - cp_neg_x) / (2 * dom[idx])  # the odd part of the polynomial, evaluated at x^2
        next_pol = even + betas[i] * odd  # CP_{i+1}(x^2) = Even(x^2) + beta * Odd(x^2)
        assert FieldElement(evals[j + 2]) == next_pol, f'CP[{i + 1}] is not the next layer '
        i += 1
        j += 2
        dom = next_fri_domain(dom)
    # Without this guard an evals list with no layer pair would raise UnboundLocalError below.
    assert next_pol is not None, 'evals contains no FRI layer to fold'
    assert FieldElement(final_value) == next_pol, 'The last layer is not the right constant'

In [26]:
# Check every query's FRI folding down to the last layer's constant.
for query in range(number_of_queries):
  verify_fri(idx[query], evals[query], betas, domain)