In [1]:
N_QUERY = 3

In [2]:
import pandas as pd
from itertools import product
import time
from poseidon import Poseidon, ARC, generate_cauchy_matrix
from field import FieldElement
import numpy as np
from functools import reduce
import constants

In [3]:
start_all = time.time()

In [4]:
import time
import math

def print_prover_success_banner(start_time, channel):
    # 1. Calculations
    duration = time.time() - start_time
    
    # Proof Size
    proof_size_bytes = sum(len(entry.encode('utf-8')) for entry in channel.proof)
    proof_size_kb = proof_size_bytes / 1024
    
    # Theoretical Soundness (The claim you are making)
    # Max degree is determined by the Poseidon constraint (degree 5)
    max_degree = (constants.POSEIDON_TRACE_EVAL_LENGTH - 1) * 5
    rho = max_degree / constants.STATE_EVAL_LENGTH
    p_cheat = rho ** N_QUERY
    security_bits = -math.log2(p_cheat) if p_cheat > 0 else 0
    
    # Expansion Factor (Blowup)
    blowup = constants.STATE_EVAL_LENGTH / constants.POSEIDON_TRACE_EVAL_LENGTH

    # 2. Colors & Art
    BLUE = "\033[94m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    MAGENTA = "\033[95m"
    CYAN = "\033[96m"
    RESET = "\033[0m"
    BOLD = "\033[1m"
    
    ascii_art = f"""
{BLUE}{BOLD}
  _____  _____   ____   ____  ______   _____ ______ _   _ 
 |  __ \|  __ \ / __ \ / __ \|  ____| / ____|  ____| \ | |
 | |__) | |__) | |  | | |  | | |__   | |  __| |__  |  \| |
 |  ___/|  _  /| |  | | |  | |  __|  | | |_ |  __| | . ` |
 | |    | | \ \| |__| | |__| | |     | |__| | |____| |\  |
 |_|    |_|  \_\\____/ \____/|_|      \_____|______|_| \_|
                                                          
{RESET}"""

    print(ascii_art)
    print(f"{BOLD}Status:{RESET}   {BLUE}PROOF GENERATED SUCCESSFULLY{RESET}")
    print("-" * 60)
    print(f"{BOLD}Computational Effort:{RESET}")
    print(f"  ‚è±Ô∏è  {YELLOW}Generation Time:{RESET}  {duration:.4f} seconds")
    print(f"  üìà {CYAN}Trace Length:{RESET}     {constants.POSEIDON_TRACE_EVAL_LENGTH} steps")
    print(f"  üîç {CYAN}Blowup Factor:{RESET}    {blowup:.1f}x (Domain size: {constants.STATE_EVAL_LENGTH})")
    print("-" * 60)
    print(f"{BOLD}Proof Artifacts:{RESET}")
    print(f"  üì¶ {GREEN}Proof Size:{RESET}       {proof_size_kb:.2f} KB")
    print(f"  üõ°Ô∏è  {MAGENTA}Claimed Security:{RESET} {security_bits:.1f} bits (P(Cheat) ‚âà {p_cheat:.1e})")
    print("-" * 60)

  print(ascii_art)


In [5]:
states = {'q0', 'q1'}
alphabet = {'a', 'b'}
transition = {
    'q0': {'a': 'q1', 'b': 'q0'},
    'q1': {'a': 'q0', 'b': 'q1'}
}
start_state = 'q0'
accept_states = {'q0'}

In [6]:
def generate_dfa_trace(input_string):
  trace = {'i':[], 's': [], 'c': []}
  current_state = start_state
  i = 0
  for symbol in input_string:
      if symbol not in alphabet:
          return False
      trace['i'].append(i)
      i+=1
      trace['s'].append(current_state)
      trace['c'].append(symbol)
      current_state = transition[current_state][symbol]

  return trace

In [7]:
word = "aaabababbbabba"
trace = generate_dfa_trace(word)

In [8]:
state_map = {"q0":1,"q1":2}
caracter_map = {'a': 3, 'b': 5}
transition_map = {("q0", 'a'): 7, ("q0", 'b'):9, ("q1", 'a'):11,("q1", 'b'):13 }

In [9]:
hash_input = [FieldElement(caracter_map[x]) for x in word]

In [10]:
phash, ptrace = Poseidon().hash(hash_input)

In [11]:
phash

array([-1161095951], dtype=object)

In [12]:
# @title
from itertools import dropwhile, starmap, zip_longest


def remove_trailing_elements(list_of_elements, element_to_remove):
    return list(dropwhile(lambda x: x == element_to_remove, list_of_elements[::-1]))[::-1]


def two_lists_tuple_operation(f, g, operation, fill_value):
    return list(starmap(operation, zip_longest(f, g, fillvalue=fill_value)))


def scalar_operation(list_of_elements, operation, scalar):
    return [operation(c, scalar) for c in list_of_elements]

In [13]:
# @title
import operator
from functools import reduce
try:
    from tqdm import tqdm
except ModuleNotFoundError:
    # tqdm is a wrapper for iterators implementing a progress bar. If it's
    # not available, simply return the iterator itself.
    tqdm = lambda x: x

def trim_trailing_zeros(p):
    """
    Removes zeros from the end of a list.
    """
    return remove_trailing_elements(p, FieldElement.zero())


def prod(values):
    """
    Computes a product.
    """
    len_values = len(values)
    if len_values == 0:
        return 1
    if len_values == 1:
        return values[0]
    return prod(values[:len_values // 2]) * prod(values[len_values // 2:])


def latex_monomial(exponent, coef, var):
    """
    Returns a string representation of a monomial as LaTeX.
    """
    if exponent == 0:
        return str(coef)
    if coef == 1:
        coef = ''
    if coef == -1:
        coef = '-'
    if exponent == 1:
        return f'{coef}{var}'
    return f'{coef}{var}^{{{exponent}}}'


class Polynomial:
    """
    Represents a polynomial over FieldElement.
    """

    @classmethod
    def X(cls):
        """
        Returns the polynomial x.
        """
        return cls([FieldElement.zero(), FieldElement.one()])

    def __init__(self, coefficients, var='x'):
        # Internally storing the coefficients in self.poly, least-significant (i.e. free term)
        # first, so $9 - 3x^2 + 19x^5$ is represented internally by the list  [9, 0, -3, 0, 0, 19].
        # Note that coefficients is copied, so the caller may freely modify the given argument.
        self.poly = remove_trailing_elements(coefficients, FieldElement.zero())
        self.var = var

    def _repr_latex_(self):
        """
        Returns a LaTeX representation of the Polynomial, for Jupyter.
        """
        if not self.poly:
            return '$0$'
        res = ['$']
        first = True
        for exponent, coef in enumerate(self.poly):
            if coef == 0:
                continue
            monomial = latex_monomial(exponent, coef, self.var)
            if first:
                first = False
                res.append(monomial)
                continue
            oper = '+'
            if monomial[0] == '-':
                oper = '-'
                monomial = monomial[1:]
            res.append(oper)
            res.append(monomial)
        res.append('$')
        return ' '.join(res)

    def __eq__(self, other):
        try:
            other = Polynomial.typecast(other)
        except AssertionError:
            return False
        return self.poly == other.poly

    @staticmethod
    def typecast(other):
        """
        Constructs a Polynomial from `FieldElement` or `int`.
        """
        if isinstance(other, int):
            other = FieldElement(other)
        if isinstance(other, FieldElement):
            other = Polynomial([other])
        assert isinstance(other, Polynomial), f'Type mismatch: Polynomial and {type(other)}.'
        return other

    def __add__(self, other):
        other = Polynomial.typecast(other)
        return Polynomial(two_lists_tuple_operation(
            self.poly, other.poly, operator.add, FieldElement.zero()))

    __radd__ = __add__  # To support <int> + <Polynomial> (as in `1 + x + x**2`).

    def __sub__(self, other):
        other = Polynomial.typecast(other)
        return Polynomial(two_lists_tuple_operation(
            self.poly, other.poly, operator.sub, FieldElement.zero()))

    def __rsub__(self, other):  # To support <int> - <Polynomial> (as in `1 - x + x**2`).
        return -(self - other)

    def __neg__(self):
        return Polynomial([]) - self

    def __mul__(self, other):
        other = Polynomial.typecast(other)
        pol1, pol2 = [[x.val for x in p.poly] for p in (self, other)]
        res = [0] * (self.degree() + other.degree() + 1)
        for i, c1 in enumerate(pol1):
            for j, c2 in enumerate(pol2):
                res[i + j] += c1 * c2
        res = [FieldElement(x) for x in res]
        return Polynomial(res)

    __rmul__ = __mul__  # To support <int> * <Polynomial>.

    def compose(self, other):
        """
        Composes this polynomial with `other`.
        Example:
        >>> f = X**2 + X
        >>> g = X + 1
        >>> f.compose(g) == (2 + 3*X + X**2)
        True
        """
        other = Polynomial.typecast(other)
        res = Polynomial([])
        for coef in self.poly[::-1]:
            res = (res * other) + Polynomial([coef])
        return res

    def qdiv(self, other):
        """
        Returns q, r the quotient and remainder polynomials respectively, such that
        f = q * g + r, where deg(r) < deg(g).
        * Assert that g is not the zero polynomial.
        """
        other = Polynomial.typecast(other)
        pol2 = trim_trailing_zeros(other.poly)
        assert pol2, 'Dividing by zero polynomial.'
        pol1 = trim_trailing_zeros(self.poly)
        if not pol1:
            return [], []
        rem = pol1
        deg_dif = len(rem) - len(pol2)
        quotient = [FieldElement.zero()] * (deg_dif + 1)
        g_msc_inv = pol2[-1].inverse()
        while deg_dif >= 0:
            tmp = rem[-1] * g_msc_inv
            quotient[deg_dif] = quotient[deg_dif] + tmp
            last_non_zero = deg_dif - 1
            for i, coef in enumerate(pol2, deg_dif):
                rem[i] = rem[i] - (tmp * coef)
                if rem[i] != FieldElement.zero():
                    last_non_zero = i
            # Eliminate trailing zeroes (i.e. make r end with its last non-zero coefficient).
            rem = rem[:last_non_zero + 1]
            deg_dif = len(rem) - len(pol2)
        return Polynomial(trim_trailing_zeros(quotient)), Polynomial(rem)

    def __truediv__(self, other):
        div, mod = self.qdiv(other)
        assert mod == 0, 'Polynomials are not divisible.'
        return div

    def __mod__(self, other):
        return self.qdiv(other)[1]

    @staticmethod
    def monomial(degree, coefficient):
        """
        Constructs the monomial coefficient * x**degree.
        """
        return Polynomial([FieldElement.zero()] * degree + [coefficient])

    @staticmethod
    def gen_linear_term(point):
        """
        Generates the polynomial (x-p) for a given point p.
        """
        return Polynomial([FieldElement.zero() - point, FieldElement.one()])

    def degree(self):
        """
        The polynomials are represented by a list so the degree is the length of the list minus the
        number of trailing zeros (if they exist) minus 1.
        This implies that the degree of the zero polynomial will be -1.
        """
        return len(trim_trailing_zeros(self.poly)) - 1

    def get_nth_degree_coefficient(self, n):
        """
        Returns the coefficient of x**n
        """
        if n > self.degree():
            return FieldElement.zero()
        else:
            return self.poly[n]

    def scalar_mul(self, scalar):
        """
        Multiplies polynomial by a scalar
        """
        return Polynomial(scalar_operation(self.poly, operator.mul, scalar))

    def eval(self, point):
        """
        Evaluates the polynomial at the given point using Horner evaluation.
        """
        point = FieldElement.typecast(point).val
        # Doing this with ints (as opposed to `FieldElement`s) speeds up eval significantly.
        val = 0
        for coef in self.poly[::-1]:
            val = (val * point + coef.val) % FieldElement.k_modulus
        return FieldElement(val)

    def __call__(self, other):
        """
        If `other` is an int or a FieldElement, evaluates the polynomial on `other` (in the field).
        If `other` is a polynomial, composes self with `other` as self(other(x)).
        """
        if isinstance(other, (int)):
            other = FieldElement(other)
        if isinstance(other, FieldElement):
            return self.eval(other)
        if isinstance(other, Polynomial):
            return self.compose(other)
        raise NotImplementedError()

    def __pow__(self, other):
        """
        Calculates self**other using repeated squaring.
        """
        assert other >= 0
        res = Polynomial([FieldElement(1)])
        cur = self
        while True:
            if other % 2 != 0:
                res *= cur
            other >>= 1
            if other == 0:
                break
            cur = cur * cur
        return res


def calculate_lagrange_polynomials(x_values):
    """
    Given the x_values for evaluating some polynomials, it computes part of the lagrange polynomials
    required to interpolate a polynomial over this domain.
    """
    lagrange_polynomials = []
    monomials = [Polynomial.monomial(1, FieldElement.one()) -
                 Polynomial.monomial(0, x) for x in x_values]
    numerator = prod(monomials)
    for j in tqdm(range(len(x_values))):
        # In the denominator, we have:
        # (x_j-x_0)(x_j-x_1)...(x_j-x_{j-1})(x_j-x_{j+1})...(x_j-x_{len(X)-1})
        denominator = prod([x_values[j] - x for i, x in enumerate(x_values) if i != j])
        # Numerator is a bit more complicated, since we need to compute a poly multiplication here.
        # Similarly to the denominator, we have:
        # (x-x_0)(x-x_1)...(x-x_{j-1})(x-x_{j+1})...(x-x_{len(X)-1})
        cur_poly, _ = numerator.qdiv(monomials[j].scalar_mul(denominator))
        lagrange_polynomials.append(cur_poly)
    return lagrange_polynomials


def interpolate_poly_lagrange(y_values, lagrange_polynomials):
    """
    :param y_values: y coordinates of the points.
    :param lagrange_polynomials: the polynomials obtained from calculate_lagrange_polynomials.
    :return: the interpolated poly/
    """
    poly = Polynomial([])
    for j, y_value in enumerate(y_values):
        poly += lagrange_polynomials[j].scalar_mul(y_value)
    return poly


def interpolate_poly(x_values, y_values):
    """
    Returns a polynomial of degree < len(x_values) that evaluates to y_values[i] on x_values[i] for
    all i.
    """
    assert len(x_values) == len(y_values)
    assert all(isinstance(val, FieldElement) for val in x_values),\
        'Not all x_values are FieldElement'
    lp = calculate_lagrange_polynomials(x_values)
    assert all(isinstance(val, FieldElement) for val in y_values),\
        'Not all y_values are FieldElement'
    return interpolate_poly_lagrange(y_values, lp)

In [14]:
# @title
###############################################################################
# Copyright 2019 StarkWare Industries Ltd.                                    #
#                                                                             #
# Licensed under the Apache License, Version 2.0 (the "License").             #
# You may not use this file except in compliance with the License.            #
# You may obtain a copy of the License at                                     #
#                                                                             #
# https://www.starkware.co/open-source-license/                               #
#                                                                             #
# Unless required by applicable law or agreed to in writing,                  #
# software distributed under the License is distributed on an "AS IS" BASIS,  #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.    #
# See the License for the specific language governing permissions             #
# and limitations under the License.                                          #
###############################################################################


import inspect
from hashlib import sha256


def serialize(obj):
    """
    Serializes an object into a string.
    """
    if isinstance(obj, (list, tuple)):
        return ','.join(map(serialize, obj))
    return obj._serialize_()


class Channel(object):
    """
    A Channel instance can be used by a prover or a verifier to preserve the semantics of an
    interactive proof system, while under the hood it is in fact non-interactive, and uses Sha256
    to generate randomness when this is required.
    It allows writing string-form data to it, and reading either random integers of random
    FieldElements from it.
    """

    def __init__(self):
        self.state = '0'
        self.proof = []

    def send(self, s):
        self.state = sha256((self.state + s).encode()).hexdigest()
        self.proof.append(f'{inspect.stack()[0][3]}:{s}')

    def receive_random_int(self, min, max, show_in_proof=True):
        """
        Emulates a random integer sent by the verifier in the range [min, max] (including min and
        max).
        """

        # Note that when the range is close to 2^256 this does not emit a uniform distribution,
        # even if sha256 is uniformly distributed.
        # It is, however, close enough for this tutorial's purposes.
        num = min + (int(self.state, 16) % (max - min + 1))
        self.state = sha256((self.state).encode()).hexdigest()
        if show_in_proof:
            self.proof.append(f'{inspect.stack()[0][3]}:{num}')
        return num

    def receive_random_field_element(self):
        """
        Emulates a random field element sent by the verifier.
        """
        num = self.receive_random_int(0, FieldElement.k_modulus - 1, show_in_proof=False)
        self.proof.append(f'{inspect.stack()[0][3]}:{num}')
        return FieldElement(num)

In [15]:
# @title
###############################################################################
# Copyright 2019 StarkWare Industries Ltd.                                    #
#                                                                             #
# Licensed under the Apache License, Version 2.0 (the "License").             #
# You may not use this file except in compliance with the License.            #
# You may obtain a copy of the License at                                     #
#                                                                             #
# https://www.starkware.co/open-source-license/                               #
#                                                                             #
# Unless required by applicable law or agreed to in writing,                  #
# software distributed under the License is distributed on an "AS IS" BASIS,  #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.    #
# See the License for the specific language governing permissions             #
# and limitations under the License.                                          #
###############################################################################


from hashlib import sha256
from math import log2, ceil


class MerkleTree(object):
    """
    A simple and naive implementation of an immutable Merkle tree.
    """

    def __init__(self, data):
        assert isinstance(data, list)
        assert len(data) > 0, 'Cannot construct an empty Merkle Tree.'
        num_leaves = 2 ** ceil(log2(len(data)))
        self.data = data + [FieldElement(0)] * (num_leaves - len(data))
        self.height = int(log2(num_leaves))
        self.facts = {}
        self.root = self.build_tree()

    def get_authentication_path(self, leaf_id):
        assert 0 <= leaf_id < len(self.data)
        node_id = leaf_id + len(self.data)
        cur = self.root
        decommitment = []
        # In a Merkle Tree, the path from the root to a leaf, corresponds to the the leaf id's
        # binary representation, starting from the second-MSB, where '0' means 'left', and '1' means
        # 'right'.
        # We therefore iterate over the bits of the binary representation - skipping the '0b'
        # prefix, as well as the MSB.
        for bit in bin(node_id)[3:]:
            cur, auth = self.facts[cur]
            if bit == '1':
                auth, cur = cur, auth
            decommitment.append(auth)
        return decommitment

    def build_tree(self):
        return self.recursive_build_tree(1)

    def recursive_build_tree(self, node_id):
        if node_id >= len(self.data):
            # A leaf.
            id_in_data = node_id - len(self.data)
            leaf_data = str(self.data[id_in_data])
            h = sha256(leaf_data.encode()).hexdigest()
            self.facts[h] = leaf_data
            return h
        else:
            # An internal node.
            left = self.recursive_build_tree(node_id * 2)
            right = self.recursive_build_tree(node_id * 2 + 1)
            h = sha256((left + right).encode()).hexdigest()
            self.facts[h] = (left, right)
            return h


def verify_decommitment(leaf_id, leaf_data, decommitment, root):
    leaf_num = 2 ** len(decommitment)
    node_id = leaf_id + leaf_num
    cur = sha256(str(leaf_data).encode()).hexdigest()
    for bit, auth in zip(bin(node_id)[3:][::-1], decommitment[::-1]):
        if bit == '0':
            h = cur + auth
        else:
            h = auth + cur
        cur = sha256(h.encode()).hexdigest()
    return cur == root

In [16]:
X = Polynomial.X()

In [17]:
def next_power_of_two(n: int) -> int:
    if n <= 0:
        return 1
    # If n is already a power of two, return n
    if (n & (n - 1)) == 0:
        return n
    power = 1
    while power < n:
        power <<= 1
    return power

In [18]:
#g = FieldElement.generator() ** (3*2**26)
ptrace_domain_size = next_power_of_two(len(ptrace))
gp = FieldElement.generator() ** (3*2**30/(ptrace_domain_size))
x = [gp**i for i in range(len(trace['i']))]
z = [gp**i for i in range(len(ptrace))]

In [19]:
len(x)

14

In [20]:
state_polynomial = interpolate_poly(x, [FieldElement(state_map[s]) for s in trace['s']])
caracter_polynomial = interpolate_poly(x, [FieldElement(caracter_map[c]) for c in trace['c']])

100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:00<00:00, 6602.23it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14/14 [00:00<00:00, 6867.87it/s]


In [21]:
# not sure if I should interpolate over points not already in the x array
transi_x = [FieldElement(state_map[s]*caracter_map[c]) for s,c in product(states, alphabet)]
transition_polynomial = interpolate_poly(transi_x, [FieldElement(state_map[transition[s][c]]) for (s,c) in product(states, alphabet)])

100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:00<00:00, 12836.43it/s]


In [22]:
step_polynomial = state_polynomial(gp*X) - transition_polynomial(state_polynomial(X)*caracter_polynomial(X))

In [23]:
ptrace_polynomials = [interpolate_poly(z, [row[i] for row in ptrace]) for i in range(0,len(ptrace[0]))]

100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 131/131 [00:00<00:00, 1691.02it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 131/131 [00:00<00:00, 1725.76it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 131/131 [00:00<00:00, 1139.20it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 131/131 [00:00<00:00, 1723.27it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 131/131 [00:00<00:00, 1738.45it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 131/131 [00:00<00:00, 1740.50it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 131/131 [00:00<00:00, 1737.50it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 131/131 [00:00<00:00, 1719.38it/s]


In [24]:
w = FieldElement.generator()
#h = FieldElement.generator() ** (3*2**23)
h = FieldElement.generator() ** (3*2**30/(ptrace_domain_size*8))
#H = [h**i for i in range(128)] # 128 is the order of h
H = [h**i for i in range(ptrace_domain_size*8)]
# here also, not sure about the domain, should be big enough for both states and transitions
eval_domain = [w*h for h in H]

#gpp = FieldElement.generator() ** (3*2**30/(ptrace_domain_size*8))
#GP = [gpp**i for i in range(ptrace_domain_size*8)]
#peval_domain = [w*g for g in GP]

In [25]:
mds = generate_cauchy_matrix(8)

In [26]:
arc_polys = [interpolate_poly(z, [FieldElement(ARC[j%(8+57)][i]) for j in range(len(ptrace))]) for i in range(len(ARC[0]))]

100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 131/131 [00:00<00:00, 1703.85it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 131/131 [00:00<00:00, 1706.54it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 131/131 [00:00<00:00, 1702.12it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 131/131 [00:00<00:00, 1709.43it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 131/131 [00:00<00:00, 1709.28it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 131/131 [00:00<00:00, 1717.21it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 131/131 [00:00<00:00, 1732.07it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 131/131 [00:00<00:00, 1727.50it/s]


In [27]:
def hash_input_index(i,j):
    if (i % (8+57) == 0 and j < 7 and i//(8+57) < len(hash_input)//7):
        return hash_input[7*(i//(8+57))+j]
    else:
        return FieldElement(0)

In [28]:
hash_input_poly = [interpolate_poly(z, [hash_input_index(i,j) for i in range(len(z))]) for j in range(len(ptrace[0]))]
g_hash = FieldElement.generator() ** (3*2**30/(ptrace_domain_size*8))
G_H = [g_hash**i for i in range(ptrace_domain_size*8)]
heval_domain = [w*g for g in G_H]

100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 131/131 [00:00<00:00, 1170.58it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 131/131 [00:00<00:00, 1716.80it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 131/131 [00:00<00:00, 1712.68it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 131/131 [00:00<00:00, 1689.91it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 131/131 [00:00<00:00, 1690.02it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 131/131 [00:00<00:00, 1720.56it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 131/131 [00:00<00:00, 1691.28it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 131/131 [00:00<00:00, 1729.69it/s]


In [29]:
state_eval = [state_polynomial(d) for d in eval_domain]
caracter_eval = [caracter_polynomial(d) for d in eval_domain]
transition_eval = [transition_polynomial(state*caracter) for state,caracter in product(state_eval,caracter_eval)]
ptrace_evals = [[p(x) for x in eval_domain] for p in ptrace_polynomials]
hash_evals = [[p(x) for x in heval_domain] for p in hash_input_poly]
arc_polys_eval = [[p(x) for x in eval_domain] for p in arc_polys]
state_merkle = MerkleTree(state_eval)
caracter_merkle = MerkleTree(caracter_eval)
transition_merkle = MerkleTree(transition_eval)
ptrace_merkle = MerkleTree([e for p in ptrace_evals for e in p])
hash_merkle = MerkleTree([e for p in hash_evals for e in p])
arc_merkle = MerkleTree([e for p in arc_polys_eval for e in p])

In [30]:
initial_constraint_p = (state_polynomial-state_map['q0'])/(X-1)
Z_G = Polynomial([FieldElement.one()])
for i in range(len(x)):
  Z_G = Z_G * (X-x[i])
ap = Polynomial([FieldElement.one()])
for c in alphabet:
    ap = ap * (caracter_polynomial - caracter_map[c])
caracter_constraint_p = ap/Z_G
ap = Polynomial([FieldElement.one()])
for f in accept_states:
  ap = ap * (state_polynomial - state_map[f])
accepting_constraint_p = ap/(X-x[-1])
step_constraint_p = step_polynomial/(Z_G/(X-x[-1]))

In [31]:
poseidon_full_round_polynomials = mds@[(ptrace_polynomials[j](X) + hash_input_poly[j](X) + arc_polys[j](X))**5 for j in range(0,8)]

In [32]:
poseidon_partial_round_polynomials = mds@[(ptrace_polynomials[j](X) + hash_input_poly[j](X) + arc_polys[j](X))**(5 if j == 0 else 1) for j in range(0,8)]

In [33]:
# TODO: set poseidon constraints haaa it's hell
# UPDATE: Done ;)
X_G = Polynomial([FieldElement.one()])
Z_G = Polynomial([FieldElement.one()])
r_f = 8
r_p = 57
for i in range(len(ptrace)-1):
    r = i%(8+57)
    if(r < r_f/2 or r >= r_f/2+r_p):
      X_G = X_G * (X-z[i])
    else:
        Z_G = Z_G * (X-z[i])
poseidon_constraint_full_round = [(ptrace_polynomials[j](gp*X) - poseidon_full_round_polynomials[j](X))/X_G for j in range(len(ptrace_polynomials))]
poseidon_constraint_partial_round = [(ptrace_polynomials[j](gp*X) - poseidon_partial_round_polynomials[j](X))/Z_G for j in range(len(ptrace_polynomials))]

In [34]:
poseidon_constraints_initial = [p/(X-1) for p in ptrace_polynomials]

In [35]:
poseidon_output_constraint = (ptrace_polynomials[0](X)-phash[0])/(X-z[-1])

In [36]:
channel = Channel()
channel.send(str(phash[0]))
channel.send(state_merkle.root)
channel.send(caracter_merkle.root)
channel.send(transition_merkle.root)
channel.send(ptrace_merkle.root)
channel.send(hash_merkle.root)
channel.send(arc_merkle.root)

In [37]:
ps = [poseidon_output_constraint, initial_constraint_p, caracter_constraint_p, accepting_constraint_p, step_constraint_p, *poseidon_constraints_initial, *poseidon_constraint_full_round, *poseidon_constraint_partial_round]
alphas = [channel.receive_random_field_element() for _ in range(len(ps))]
cp = reduce(lambda x, y: x+y,map(lambda a, p: a*p, alphas, ps),FieldElement(0))

In [38]:
cp_eval = [cp(x) for x in eval_domain]
cp_merkle = MerkleTree(cp_eval)
channel.send(cp_merkle.root)

In [39]:
def fold_domain(domain):
  return [x**2 for x in domain[:len(domain)//2]]

In [40]:
def fold(p, beta):
  odd = p.poly[1::2]
  even = p.poly[::2]
  o = Polynomial(odd)
  e = Polynomial(even)
  return e + beta*o

In [41]:
def next_fri_layer(poly, domain, beta):
    next_poly = fold(poly, beta)
    next_domain = fold_domain(domain)
    next_layer = [next_poly(x) for x in next_domain]
    return next_poly, next_domain, next_layer

In [42]:
def FriCommit(cp, domain, cp_eval, cp_merkle, channel):
    fri_polys = [cp]
    fri_domains = [domain]
    fri_layers = [cp_eval]
    fri_merkles = [cp_merkle]
    betas = []
    i = 0
    while fri_polys[-1].degree() > 0:
        beta = channel.receive_random_field_element()
        betas.append(beta)
        next_poly, next_domain, next_layer = next_fri_layer(fri_polys[-1], fri_domains[-1], beta)
        #print(i, beta)
        #assert(next_poly(X*X) == (fri_polys[-1](X) - fri_polys[-1](-X))/2 + beta*(fri_polys[-1](X) - fri_polys[-1](-X))/(2*X))
        fri_polys.append(next_poly)
        fri_domains.append(next_domain)
        fri_layers.append(next_layer)
        fri_merkles.append(MerkleTree(next_layer))
        channel.send(fri_merkles[-1].root)
        i+=1
    channel.send("FINISHED_FRI")
    channel.send(str(fri_polys[-1].poly[0]))
    return fri_polys, fri_domains, fri_layers, fri_merkles, betas

In [43]:
fri_polys, fri_domains, fri_layers, fri_merkles, betas = FriCommit(cp, eval_domain,cp_eval,cp_merkle,channel)

In [44]:
def decommit_on_fri_layers(idx, channel):
    prev_idx = None
    prev_sibidx = None
    p_idx = None
    i = 0
    for layer, merkle in zip(fri_layers[:-1], fri_merkles[:-1]):
        length = len(layer)
        idx = idx % length
        sib_idx = (idx + length // 2) % length
        channel.send(str(layer[idx]))
        channel.send(str(merkle.get_authentication_path(idx)))
        channel.send(str(layer[sib_idx]))
        channel.send(str(merkle.get_authentication_path(sib_idx)))
    channel.send(str(fri_layers[-1][0]))

In [45]:
def decommit_on_query(idx, channel): 
    assert idx + 8 < len(state_eval), f'query index: {idx} is out of range. Length of layer: {len(state_eval)}.'
    channel.send(str(state_eval[idx])) # f(x).
    channel.send(str(state_merkle.get_authentication_path(idx))) # auth path for f(x).
    assert(state_eval[idx + 8] == state_polynomial(gp*eval_domain[idx]))
    channel.send(str(state_eval[idx + 8])) # f(gx).
    channel.send(str(state_merkle.get_authentication_path(idx + 8))) # auth path for f(gx).
    #channel.send(str(state_eval[idx + 64])) # f(g^2x).
    #channel.send(str(state_merkle.get_authentication_path(idx + 64))) # auth path for f(g^2x).
    assert idx + 8 < len(caracter_eval), f'query index: {idx} is out of range. Length of layer: {len(caracter_eval)}.'
    channel.send(str(caracter_eval[idx])) # f(x).
    channel.send(str(caracter_merkle.get_authentication_path(idx))) # auth path for f(x).
    #channel.send(str(caracter_eval[idx + 32])) # f(gx).
    #channel.send(str(caracter_merkle.get_authentication_path(idx + 32))) # auth path for f(gx).
    #channel.send(str(caracter_eval[idx + 64])) # f(g^2x).
    #channel.send(str(caracter_merkle.get_authentication_path(idx + 64))) # auth path for f(g^2x).
    assert idx + 8 < len(transition_eval), f'query index: {idx} is out of range. Length of layer: {len(transition_eval)}.'
    transi_x = transition_polynomial(caracter_eval[idx]*state_eval[idx])
    channel.send(str(transi_x)) # f(x).
    transi_idx = len(state_eval)*idx+idx
    channel.send(str(transition_merkle.get_authentication_path(transi_idx))) # auth path for f(x).
    #channel.send(str(transition_eval[idx + 32])) # f(gx).
    #channel.send(str(transition_merkle.get_authentication_path(idx + 32))) # auth path for f(gx).
    #channel.send(str(transition_eval[idx + 64])) # f(g^2x).
    #channel.send(str(transition_merkle.get_authentication_path(idx + 64))) # auth path for f(g^2x).


    #### POSEIDON ####
    for row in range(len(ptrace[0])):
        trace_x = ptrace_evals[row][idx]
        channel.send(str(trace_x))
        channel.send(str(ptrace_merkle.get_authentication_path(len(eval_domain)*row+idx)))
        assert(ptrace_polynomials[row](eval_domain[idx]) == trace_x)

        next_trace_x = ptrace_evals[row][idx+8]
        channel.send(str(next_trace_x))
        channel.send(str(ptrace_merkle.get_authentication_path(len(eval_domain)*row+idx+8)))
        assert(ptrace_polynomials[row](gp*eval_domain[idx]) == next_trace_x)

        hash_x = hash_evals[row][idx]
        channel.send(str(hash_x))
        channel.send(str(hash_merkle.get_authentication_path(len(eval_domain)*row+idx)))
        
        arc_x = arc_polys_eval[row][idx]
        channel.send(str(arc_x))
        channel.send(str(arc_merkle.get_authentication_path(len(eval_domain)*row+idx)))
        
    decommit_on_fri_layers(idx, channel)   

In [46]:
def decommit_fri(channel):
    for query in range(N_QUERY):
        # Get a random index from the verifier and send the corresponding decommitment.
        decommit_on_query(channel.receive_random_int(0, len(state_eval)-8), channel)

In [47]:
decommit_fri(channel)

In [48]:
# TODO
# - [x] Fix the prover (the domain size is too large for the automata trace so it must be spanned across the domain, not 100% sure)
# - [x] Fix the verification
# - [ ] It seems I'm already doing out of domain querying

In [49]:
proof_file = open("proof.txt", "w")
proof_file.write(str(channel.proof))

136848

In [50]:
print_prover_success_banner(start_all, channel)


[94m[1m
  _____  _____   ____   ____  ______   _____ ______ _   _ 
 |  __ \|  __ \ / __ \ / __ \|  ____| / ____|  ____| \ | |
 | |__) | |__) | |  | | |  | | |__   | |  __| |__  |  \| |
 |  ___/|  _  /| |  | | |  | |  __|  | | |_ |  __| | . ` |
 | |    | | \ \| |__| | |__| | |     | |__| | |____| |\  |
 |_|    |_|  \_\____/ \____/|_|      \_____|______|_| \_|

[0m
[1mStatus:[0m   [94mPROOF GENERATED SUCCESSFULLY[0m
------------------------------------------------------------
[1mComputational Effort:[0m
  ‚è±Ô∏è  [93mGeneration Time:[0m  29.5474 seconds
  üìà [96mTrace Length:[0m     131 steps
  üîç [96mBlowup Factor:[0m    15.6x (Domain size: 2048)
------------------------------------------------------------
[1mProof Artifacts:[0m
  üì¶ [92mProof Size:[0m       132.07 KB
  üõ°Ô∏è  [95mClaimed Security:[0m 5.0 bits (P(Cheat) ‚âà 3.2e-02)
------------------------------------------------------------
