<a href="https://colab.research.google.com/github/Haykhovhannisyan1/Stark101_Verifier/blob/TheVerifier/Stark101_verifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import os

In [2]:
import time
import inspect
from hashlib import sha256
from math import log2, ceil
import operator
from functools import reduce
from itertools import dropwhile, starmap, zip_longest
from random import randint


In [3]:
###Channel###
def serialize(obj):
    """
    Turns `obj` into a string.

    Lists and tuples are serialized element by element (recursively) and the pieces are joined
    with ','. Any other object must provide a `_serialize_()` method, whose result is returned.
    """
    if not isinstance(obj, (list, tuple)):
        return obj._serialize_()
    pieces = [serialize(item) for item in obj]
    return ','.join(pieces)


class Channel(object):
    """
    A Channel instance can be used by a prover or a verifier to preserve the semantics of an
    interactive proof system, while under the hood it is in fact non-interactive, and uses Sha256
    to generate randomness when this is required.
    It allows writing string-form data to it, and reading either random integers or random
    FieldElements from it.

    Attributes:
        state: hex SHA-256 digest that every "random" value is derived from; starts as '0' and is
            re-hashed on every send and every receive.
        proof: transcript of the interaction, one '<method name>:<value>' string per entry.
    """

    def __init__(self):
        self.state = '0'
        self.proof = []

    def send(self, s):
        """
        Absorbs the string `s` into the channel state and records 'send:<s>' in the proof.
        """
        self.state = sha256((self.state + s).encode()).hexdigest()
        # The method name is written out rather than introspected with inspect.stack(), which
        # walks (and reads source for) the whole call stack on every call.
        self.proof.append(f'send:{s}')

    def receive_random_int(self, min, max, show_in_proof=True):
        """
        Emulates a random integer sent by the verifier in the range [min, max] (including min and
        max).

        The value is derived from the current state, after which the state is re-hashed so the next
        draw differs. When `show_in_proof` is true, 'receive_random_int:<num>' is recorded.
        """

        # Note that when the range is close to 2^256 this does not emit a uniform distribution,
        # even if sha256 is uniformly distributed.
        # It is, however, close enough for this tutorial's purposes.
        num = min + (int(self.state, 16) % (max - min + 1))
        self.state = sha256((self.state).encode()).hexdigest()
        if show_in_proof:
            self.proof.append(f'receive_random_int:{num}')
        return num

    def receive_random_field_element(self):
        """
        Emulates a random field element sent by the verifier.

        Draws an integer in [0, FieldElement.k_modulus - 1] without recording it as an int, then
        records 'receive_random_field_element:<num>' and returns FieldElement(num).
        """
        num = self.receive_random_int(0, FieldElement.k_modulus - 1, show_in_proof=False)
        self.proof.append(f'receive_random_field_element:{num}')
        return FieldElement(num)

In [4]:
###MerkleTree
class MerkleTree(object):
    """
    A simple and naive implementation of an immutable Merkle tree.

    The leaves are `data` padded with FieldElement(0) up to the next power of two. Nodes use heap
    numbering: the root is node 1, node n has children 2n and 2n + 1, and leaf k is node
    k + len(self.data). `facts` maps each node hash either to its (left, right) child hashes
    (internal node) or to the string form of its data (leaf).
    """

    def __init__(self, data):
        assert isinstance(data, list)
        assert len(data) > 0, 'Cannot construct an empty Merkle Tree.'
        leaf_count = 2 ** ceil(log2(len(data)))
        padding = [FieldElement(0)] * (leaf_count - len(data))
        self.data = data + padding
        self.height = int(log2(leaf_count))
        self.facts = {}
        self.root = self.build_tree()

    def get_authentication_path(self, leaf_id):
        """
        Returns the sibling hashes on the way from the root to leaf `leaf_id`, ordered from the
        level just below the root down to the leaf level.
        """
        assert 0 <= leaf_id < len(self.data)
        # After dropping '0b' and the leading 1, the bits of the heap index spell the route from
        # the root: '0' descends to the left child, '1' to the right child.
        route = bin(leaf_id + len(self.data))[3:]
        node = self.root
        siblings = []
        for step in route:
            left, right = self.facts[node]
            if step == '1':
                node, sibling = right, left
            else:
                node, sibling = left, right
            siblings.append(sibling)
        return siblings

    def build_tree(self):
        return self.recursive_build_tree(1)

    def recursive_build_tree(self, node_id):
        """
        Hashes the subtree rooted at heap node `node_id`, recording every node in `facts`, and
        returns the subtree's hash.
        """
        leaf_offset = len(self.data)
        if node_id < leaf_offset:
            # Internal node: hash of the concatenated child hashes (left subtree built first).
            left = self.recursive_build_tree(2 * node_id)
            right = self.recursive_build_tree(2 * node_id + 1)
            digest = sha256((left + right).encode()).hexdigest()
            self.facts[digest] = (left, right)
        else:
            # Leaf: hash of the string form of the stored data item.
            leaf_data = str(self.data[node_id - leaf_offset])
            digest = sha256(leaf_data.encode()).hexdigest()
            self.facts[digest] = leaf_data
        return digest


def verify_decommitment(leaf_id, leaf_data, decommitment, root):
    """
    Checks that `leaf_data` is the leaf at position `leaf_id` of the Merkle tree with root `root`.

    `decommitment` is an authentication path as returned by
    MerkleTree.get_authentication_path: sibling hashes ordered from the level below the root down
    to the leaf level. The path is walked bottom-up, hashing the running digest with each sibling.

    Note: only the lowest len(decommitment) bits of `leaf_id` are consulted, so for a non-negative
    `leaf_id` this effectively checks position leaf_id % 2**len(decommitment).
    """
    depth = len(decommitment)
    node_id = leaf_id + 2 ** depth
    # Direction bits from the leaf upward (LSB first): '0' means the current node is a left child.
    directions = reversed(bin(node_id)[3:])
    digest = sha256(str(leaf_data).encode()).hexdigest()
    for direction, sibling in zip(directions, reversed(decommitment)):
        pair = digest + sibling if direction == '0' else sibling + digest
        digest = sha256(pair.encode()).hexdigest()
    return digest == root

In [5]:
def remove_trailing_elements(list_of_elements, element_to_remove):
    """
    Returns a new list holding `list_of_elements` without its trailing run of elements equal
    (==) to `element_to_remove`.
    """
    end = len(list_of_elements)
    while end > 0 and list_of_elements[end - 1] == element_to_remove:
        end -= 1
    return list(list_of_elements[:end])


def two_lists_tuple_operation(f, g, operation, fill_value):
    """
    Applies the binary `operation` to `f` and `g` position by position, padding the shorter input
    with `fill_value`; the result is a list as long as the longer input.
    """
    return [operation(left, right) for left, right in zip_longest(f, g, fillvalue=fill_value)]


def scalar_operation(list_of_elements, operation, scalar):
    """
    Returns a list with `operation(element, scalar)` applied to every element, in order.
    """
    return list(map(lambda item: operation(item, scalar), list_of_elements))

In [6]:
###FieldElement
class FieldElement:
    """
    Represents an element of F_(3 * 2**30 + 1).

    The element is stored as its residue `val` in [0, k_modulus). The +, -, * and / operators
    accept another FieldElement or a plain int (reduced modulo k_modulus); for any other operand
    type they return NotImplemented, so Python raises TypeError.
    """
    k_modulus = 3 * 2**30 + 1
    generator_val = 5

    def __init__(self, val):
        self.val = val % FieldElement.k_modulus

    @staticmethod
    def zero():
        """
        Obtains the zero element of the field.
        """
        return FieldElement(0)

    @staticmethod
    def one():
        """
        Obtains the unit element of the field.
        """
        return FieldElement(1)

    def __repr__(self):
        # Choose the shorter representation between the positive and negative values of the element.
        return repr((self.val + self.k_modulus//2) % self.k_modulus - self.k_modulus//2)

    def __eq__(self, other):
        # Plain ints compare by their residue modulo k_modulus; other types are never equal.
        if isinstance(other, int):
            other = FieldElement(other)
        return isinstance(other, FieldElement) and self.val == other.val

    def __hash__(self):
        return hash(self.val)

    @staticmethod
    def generator():
        """
        Returns the element generator_val (5).
        """
        return FieldElement(FieldElement.generator_val)

    @staticmethod
    def typecast(other):
        """
        Converts an int to a FieldElement and passes FieldElements through unchanged.
        Raises AssertionError for any other type.
        """
        if isinstance(other, int):
            return FieldElement(other)
        assert isinstance(other, FieldElement), f'Type mismatch: FieldElement and {type(other)}.'
        return other

    def __neg__(self):
        return self.zero() - self

    def __add__(self, other):
        try:
            other = FieldElement.typecast(other)
        except AssertionError:
            return NotImplemented
        return FieldElement((self.val + other.val) % FieldElement.k_modulus)

    __radd__ = __add__

    def __sub__(self, other):
        try:
            other = FieldElement.typecast(other)
        except AssertionError:
            return NotImplemented
        return FieldElement((self.val - other.val) % FieldElement.k_modulus)

    def __rsub__(self, other):
        # Computes other - self. Unsupported operand types yield NotImplemented (clean TypeError)
        # instead of trying to negate NotImplemented.
        try:
            other = FieldElement.typecast(other)
        except AssertionError:
            return NotImplemented
        return other - self

    def __mul__(self, other):
        try:
            other = FieldElement.typecast(other)
        except AssertionError:
            return NotImplemented
        return FieldElement((self.val * other.val) % FieldElement.k_modulus)

    __rmul__ = __mul__

    def __truediv__(self, other):
        # Same operand protocol as +, - and *: unsupported types yield NotImplemented.
        try:
            other = FieldElement.typecast(other)
        except AssertionError:
            return NotImplemented
        return self * other.inverse()

    def __rtruediv__(self, other):
        # Supports `int / FieldElement`, i.e. computes other / self.
        try:
            other = FieldElement.typecast(other)
        except AssertionError:
            return NotImplemented
        return other * self.inverse()

    def __pow__(self, n):
        """
        Raises the element to the non-negative integer power `n` by square-and-multiply.
        """
        assert n >= 0
        cur_pow = self
        res = FieldElement(1)
        while n > 0:
            if n % 2 != 0:
                res *= cur_pow
            n = n // 2
            cur_pow *= cur_pow
        return res

    def inverse(self):
        """
        Returns the multiplicative inverse, computed with the extended Euclidean algorithm.
        Raises AssertionError for the zero element, which has no inverse.
        """
        t, new_t = 0, 1
        r, new_r = FieldElement.k_modulus, self.val
        while new_r != 0:
            quotient = r // new_r
            t, new_t = new_t, (t - (quotient * new_t))
            r, new_r = new_r, r - quotient * new_r
        assert r == 1, 'Cannot invert the zero element.'
        return FieldElement(t)

    def is_order(self, n):
        """
        Naively checks that the element is of order n by raising it to all powers up to n, checking
        that the element to the n-th power is the unit, but not so for any k<n.
        """
        assert n >= 1
        h = FieldElement(1)
        for _ in range(1, n):
            h *= self
            if h == FieldElement(1):
                return False
        return h * self == FieldElement(1)

    def _serialize_(self):
        return repr(self.val)

    @staticmethod
    def random_element(exclude_elements=None):
        """
        Returns a uniformly random field element that is not in `exclude_elements`
        (an optional collection of FieldElements/ints; membership uses ==).
        """
        # None instead of a shared mutable [] default.
        excluded = [] if exclude_elements is None else exclude_elements
        fe = FieldElement(randint(0, FieldElement.k_modulus - 1))
        while fe in excluded:
            fe = FieldElement(randint(0, FieldElement.k_modulus - 1))
        return fe

In [7]:
# Trace domain: the powers g^0 .. g^1023. g = 5^(3 * 2**20) has order 1024, assuming
# FieldElement.generator() generates the whole multiplicative group (order 3 * 2**30).
g = pow(FieldElement.generator(), 3 * 2 ** 20)
points = [pow(g, k) for k in range(1024)]
# Evaluation domain: the coset generator() * <h_gen>, with h_gen = 5^(3 * 2**17) of order 8192.
h_gen = pow(FieldElement.generator(), 3 * 2 ** 17)
h = [pow(h_gen, k) for k in range(8192)]
domain = [x * FieldElement.generator() for x in h]

In [8]:
# Each authentication path is stored in the proof as a single string ('send:' followed by a list of hex digests); convert it into a list of strings.
def get_path(path):
    """
    Parses one authentication-path entry of the proof into a list of hash strings.

    The entry must be 'send:' followed by a Python list literal of strings, e.g.
    "send:['ab12...', 'cd34...']". ast.literal_eval only evaluates literals, so it is safe to use
    on prover-supplied data.

    Raises AssertionError when the entry is malformed, so a bad proof is rejected the same way as
    every other failed check in this verifier.
    """
    import ast

    prefix = 'send:'
    text = str(path)
    assert text.startswith(prefix), f'authentication path entry must start with {prefix!r}'
    try:
        parsed = ast.literal_eval(text[len(prefix):])
    except (ValueError, SyntaxError, TypeError) as err:
        raise AssertionError(f'malformed authentication path: {text[:60]!r}') from err
    assert isinstance(parsed, list) and all(isinstance(node, str) for node in parsed), \
        'authentication path must be a list of strings'
    return parsed

In [9]:
# Values extracted from the proof: Merkle roots, decommitted evaluations, their authentication
# paths, the query index, and the FRI folding challenges.
roots, evals, auth_paths, idx, beta = [], [], [], [], []

In [10]:
# Load the prover's transcript: one '<method>:<value>' string per channel interaction, as recorded
# by Channel.proof. np.load keeps allow_pickle at its default (False), so loading an untrusted
# proof file cannot execute code.
proof = np.load("/content/The proof.npy")


def _replay_transcript(entries, query_max=8191 - 16):
    """
    Re-derives every verifier challenge in `entries` via Fiat-Shamir instead of trusting it.

    The prover's messages ('send:' entries) are fed into a fresh Channel in order, and each
    'receive_*' entry is recomputed from the channel state. Raises AssertionError if a recomputed
    challenge differs from the recorded one, or if an entry has an unknown tag.

    query_max: upper bound passed to receive_random_int when re-drawing the query index.
    NOTE(review): assumes the prover drew the index with receive_random_int(0, 8191 - 16), as the
    Stark101 prover does -- confirm; a mismatch makes honest proofs fail this check.
    """
    channel = Channel()
    for position, entry in enumerate(map(str, entries)):
        tag, sep, payload = entry.partition(':')
        assert sep, f'malformed transcript entry {position}: {entry[:40]!r}'
        if tag == 'send':
            channel.send(payload)
        elif tag == 'receive_random_field_element':
            channel.receive_random_field_element()
        elif tag == 'receive_random_int':
            channel.receive_random_int(0, query_max)
        else:
            raise AssertionError(f'unknown transcript entry {position}: {entry[:40]!r}')
        assert channel.proof[-1] == entry, (
            f'entry {position} does not match the Fiat-Shamir transcript: '
            f'got {entry[:40]!r}, expected {channel.proof[-1][:40]!r}'
        )


_replay_transcript(proof)

In [None]:
# Display the raw proof transcript (notebook cell output).
proof

In [12]:
# Entry 0 ('send:<root>') is the Merkle root of the trace evaluations.
roots.append(proof[0][len('send:'):])

In [13]:
# Entries 4, 6, ..., 24 are the Merkle roots of CP0, CP1, ..., CP10.
roots.extend(entry[len('send:'):] for entry in proof[4:25:2])

In [14]:
# Entries 5, 7, ..., 23 are the FRI folding challenges (betas) sent by the verifier.
beta.extend(int(entry[len('receive_random_field_element:'):]) for entry in proof[5:25:2])

In [15]:
# Entry 25 is the value of the last FRI layer, committed before the query index is drawn.
last_value = proof[25][len('send:'):]

In [16]:
# Entry 26 is the query index drawn from the channel.
idx.append(int(proof[26][len('receive_random_int:'):]))

In [17]:
# Entries 27, 29, 31 are f(x), f(gx) and f(g^2x) at the query point.
evals.extend(int(proof[k][len('send:'):]) for k in (27, 29, 31))


In [18]:
# Entries 28, 30, 32 are the authentication paths for f(x), f(gx) and f(g^2x).
auth_paths.extend(get_path(proof[k]) for k in (28, 30, 32))

In [19]:
# Entries 33, 35, ..., 71 hold CP_i(x) and CP_i(-x) for each FRI layer; entry 73 is the last value.
evals.extend(int(entry[len('send:'):]) for entry in proof[33:74:2])

In [20]:
# Entries 34, 36, ..., 72 are the authentication paths for CP_i(x) and CP_i(-x).
auth_paths.extend(get_path(entry) for entry in proof[34:74:2])

In [21]:
def verify_paths(idx, evals, auth_paths, roots, length):
    """
    Checks every decommitted value against its Merkle root.

    idx: query index into the first evaluation domain.
    evals: [f(x), f(gx), f(g^2x), CP0(x), CP0(-x), CP1(x), CP1(-x), ..., last value].
    auth_paths: authentication paths for evals, in the same order (the last value has none).
    roots: [trace root, CP0 root, CP1 root, ...].
    length: number of leaves of the trace tree and of the CP0 tree (len(domain)).

    Raises AssertionError naming the first value whose decommitment fails.
    """
    # f(gx) and f(g^2x) are the leaves at idx + 8 and idx + 16 (g = h_gen ** 8), so all three
    # trace leaves must exist; otherwise the decommitment check would silently wrap around.
    assert 0 <= idx and idx + 16 < length, f'query index {idx} is out of range'
    assert verify_decommitment(idx, evals[0], auth_paths[0], roots[0]), 'f(x) is not from the data'
    assert verify_decommitment(idx + 8, evals[1], auth_paths[1], roots[0]), 'f(gx) is not from the data'
    assert verify_decommitment(idx + 16, evals[2], auth_paths[2], roots[0]), 'f(g^2x) is not from the data'
    j = 3
    for i in range(1, len(roots)):
        # Each FRI layer is half as long as the previous one: reduce the query into the current
        # layer explicitly rather than relying on verify_decommitment ignoring high bits.
        layer_idx = idx % length
        sib_idx = (layer_idx + length // 2) % length
        # The last root (constant final layer) has no decommitted pair in evals.
        if j != len(evals) - 1:
            assert verify_decommitment(layer_idx, evals[j], auth_paths[j], roots[i]), f'CP[{i-1}](x) is not from the data'
            assert verify_decommitment(sib_idx, evals[j + 1], auth_paths[j + 1], roots[i]), f'CP[{i-1}](-x) is not from the data'
        length //= 2
        j += 2

In [22]:
# Check all decommitments against the committed Merkle roots.
verify_paths(idx[0], evals, auth_paths, roots, length=len(domain))

In [23]:
# First boundary term (f(x) - 1) / (x - g^0), weighted by the coefficient from proof entry 1.
cp01 = (FieldElement(evals[0]) - 1) / (domain[idx[0]] - points[0])
cp01 = cp01 * FieldElement(int(proof[1][len('receive_random_field_element:'):]))

In [24]:
# Second boundary term (f(x) - 2338775057) / (x - g^1022), weighted by the coefficient from
# proof entry 2.
cp02 = (FieldElement(evals[0]) - 2338775057) / (domain[idx[0]] - points[1022])
cp02 = cp02 * FieldElement(int(proof[2][len('receive_random_field_element:'):]))

In [25]:
# (x^1024 - 1) / ((x - g^1021)(x - g^1022)(x - g^1023)): a polynomial vanishing on g^0 .. g^1020,
# evaluated at the query point x.
denom = (domain[idx[0]] ** 1024 - 1) / (
    (domain[idx[0]] - points[1021])
    * (domain[idx[0]] - points[1022])
    * (domain[idx[0]] - points[1023])
)

In [26]:
# Transition term (f(g^2x) - f(gx)^2 - f(x)^2) / denom, weighted by the coefficient from proof entry 3.
cp03 = (FieldElement(evals[2]) - FieldElement(evals[1]) ** 2 - FieldElement(evals[0]) ** 2) / denom
cp03 *= FieldElement(int(proof[3][len('receive_random_field_element:'):]))

In [27]:
# Composition polynomial value at the query point.
cp = sum((cp01, cp02, cp03), FieldElement.zero())

In [None]:
# The decommitted CP0(x) must equal the composition polynomial rebuilt from the trace values.
assert cp == FieldElement(evals[3]), 'CP0(x) does not match the composition polynomial computed from f'

In [29]:
def next_fri_domain(domain):
    """
    Returns the next FRI evaluation domain: the squares of the first half of `domain`.
    """
    half = len(domain) // 2
    return [point ** 2 for point in domain[:half]]


In [30]:
def verify_fri(query_idx, evals, beta, dom=domain):
    """
    Checks the FRI folding relation for one query across all layers.

    Starting at evals[3], consecutive pairs are CP_i(x) and CP_i(-x), where x is the query point of
    layer i: dom_i[query_idx % len(dom_i)]. Each pair must fold into the next layer's value:
        CP_{i+1}(x^2) = (CP_i(x) + CP_i(-x)) / 2 + beta_i * (CP_i(x) - CP_i(-x)) / (2x)
    which must equal evals[pos + 2]. The domain of each next layer comes from next_fri_domain.

    Raises AssertionError naming the first layer that does not fold correctly.
    """
    layer_domain = dom
    for layer, pos in enumerate(range(3, len(evals) - 2, 2)):
        query_idx %= len(layer_domain)
        x = layer_domain[query_idx]
        at_x = FieldElement(evals[pos])
        at_minus_x = FieldElement(evals[pos + 1])
        even_part = (at_x + at_minus_x) / 2
        odd_part = (at_x - at_minus_x) / (2 * x)
        folded = even_part + FieldElement(beta[layer]) * odd_part
        assert FieldElement(evals[pos + 2]) == folded, f'CP[{layer + 1}] is not the next layer '
        layer_domain = next_fri_domain(layer_domain)

In [None]:
verify_fri(idx[0], evals, beta, domain)
# The final fold (evals[-1], from proof entry 73) must equal the last-layer value the prover sent
# in entry 25, before the query index (entry 26) was drawn. Without this check the prover could
# choose the final value after seeing the query.
assert FieldElement(evals[-1]) == FieldElement(int(last_value)), \
    'final FRI value does not match the committed last-layer value'