In [1]:
import pandas as pd
from itertools import product
import time

In [2]:
start_all = time.time()

In [3]:
states = {'q0', 'q1'}
alphabet = {'a', 'b'}
transition = {
    'q0': {'a': 'q1', 'b': 'q0'},
    'q1': {'a': 'q0', 'b': 'q1'}
}
start_state = 'q0'
accept_states = {'q0'}

In [4]:
def generate_dfa_trace(input_string):
  trace = {'i':[], 's': [], 'c': []}
  current_state = start_state
  i = 0
  for symbol in input_string:
      if symbol not in alphabet:
          return False
      trace['i'].append(i)
      i+=1
      trace['s'].append(current_state)
      trace['c'].append(symbol)
      current_state = transition[current_state][symbol]

  return trace

In [5]:
trace = generate_dfa_trace("aaabababbbabba")

In [6]:
state_map = {"q0":1,"q1":2}
caracter_map = {'a': 3, 'b': 5}
transition_map = {("q0", 'a'): 7, ("q0", 'b'):9, ("q1", 'a'):11,("q1", 'b'):13 }

In [7]:
# @title
from random import randint


class FieldElement:
    """
    Represents an element of F_(3 * 2**30 + 1).
    """
    k_modulus = 3 * 2**30 + 1
    generator_val = 5

    def __init__(self, val):
        self.val = val % FieldElement.k_modulus

    @staticmethod
    def zero():
        """
        Obtains the zero element of the field.
        """
        return FieldElement(0)

    @staticmethod
    def one():
        """
        Obtains the unit element of the field.
        """
        return FieldElement(1)

    def __repr__(self):
        # Choose the shorter representation between the positive and negative values of the element.
        return repr((self.val + self.k_modulus//2) % self.k_modulus - self.k_modulus//2)

    def __eq__(self, other):
        if isinstance(other, int):
            other = FieldElement(other)
        return isinstance(other, FieldElement) and self.val == other.val

    def __hash__(self):
        return hash(self.val)

    @staticmethod
    def generator():
        return FieldElement(FieldElement.generator_val)

    @staticmethod
    def typecast(other):
        if isinstance(other, int):
            return FieldElement(other)
        assert isinstance(other, FieldElement), f'Type mismatch: FieldElement and {type(other)}.'
        return other

    def __neg__(self):
        return self.zero() - self

    def __add__(self, other):
        try:
            other = FieldElement.typecast(other)
        except AssertionError:
            return NotImplemented
        return FieldElement((self.val + other.val) % FieldElement.k_modulus)

    __radd__ = __add__

    def __sub__(self, other):
        try:
            other = FieldElement.typecast(other)
        except AssertionError:
            return NotImplemented
        return FieldElement((self.val - other.val) % FieldElement.k_modulus)

    def __rsub__(self, other):
        return -(self - other)

    def __mul__(self, other):
        try:
            other = FieldElement.typecast(other)
        except AssertionError:
            return NotImplemented
        return FieldElement((self.val * other.val) % FieldElement.k_modulus)

    __rmul__ = __mul__

    def __truediv__(self, other):
        other = FieldElement.typecast(other)
        return self * other.inverse()

    def __pow__(self, n):
        assert n >= 0
        cur_pow = self
        res = FieldElement(1)
        while n > 0:
            if n % 2 != 0:
                res *= cur_pow
            n = n // 2
            cur_pow *= cur_pow
        return res

    def inverse(self):
        t, new_t = 0, 1
        r, new_r = FieldElement.k_modulus, self.val
        while new_r != 0:
            quotient = r // new_r
            t, new_t = new_t, (t - (quotient * new_t))
            r, new_r = new_r, r - quotient * new_r
        assert r == 1
        return FieldElement(t)

    def is_order(self, n):
        """
        Naively checks that the element is of order n by raising it to all powers up to n, checking
        that the element to the n-th power is the unit, but not so for any k<n.
        """
        assert n >= 1
        h = FieldElement(1)
        for _ in range(1, n):
            h *= self
            if h == FieldElement(1):
                return False
        return h * self == FieldElement(1)

    def _serialize_(self):
        return repr(self.val)

    @staticmethod
    def random_element(exclude_elements=[]):
        fe = FieldElement(randint(0, FieldElement.k_modulus - 1))
        while fe in exclude_elements:
            fe = FieldElement(randint(0, FieldElement.k_modulus - 1))
        return fe

In [8]:
# @title
from itertools import dropwhile, starmap, zip_longest


def remove_trailing_elements(list_of_elements, element_to_remove):
    return list(dropwhile(lambda x: x == element_to_remove, list_of_elements[::-1]))[::-1]


def two_lists_tuple_operation(f, g, operation, fill_value):
    return list(starmap(operation, zip_longest(f, g, fillvalue=fill_value)))


def scalar_operation(list_of_elements, operation, scalar):
    return [operation(c, scalar) for c in list_of_elements]

In [9]:
# @title
import operator
from functools import reduce
try:
    from tqdm import tqdm
except ModuleNotFoundError:
    # tqdm is a wrapper for iterators implementing a progress bar. If it's
    # not available, simply return the iterator itself.
    tqdm = lambda x: x

def trim_trailing_zeros(p):
    """
    Removes zeros from the end of a list.
    """
    return remove_trailing_elements(p, FieldElement.zero())


def prod(values):
    """
    Computes a product.
    """
    len_values = len(values)
    if len_values == 0:
        return 1
    if len_values == 1:
        return values[0]
    return prod(values[:len_values // 2]) * prod(values[len_values // 2:])


def latex_monomial(exponent, coef, var):
    """
    Returns a string representation of a monomial as LaTeX.
    """
    if exponent == 0:
        return str(coef)
    if coef == 1:
        coef = ''
    if coef == -1:
        coef = '-'
    if exponent == 1:
        return f'{coef}{var}'
    return f'{coef}{var}^{{{exponent}}}'


class Polynomial:
    """
    Represents a polynomial over FieldElement.
    """

    @classmethod
    def X(cls):
        """
        Returns the polynomial x.
        """
        return cls([FieldElement.zero(), FieldElement.one()])

    def __init__(self, coefficients, var='x'):
        # Internally storing the coefficients in self.poly, least-significant (i.e. free term)
        # first, so $9 - 3x^2 + 19x^5$ is represented internally by the list  [9, 0, -3, 0, 0, 19].
        # Note that coefficients is copied, so the caller may freely modify the given argument.
        self.poly = remove_trailing_elements(coefficients, FieldElement.zero())
        self.var = var

    def _repr_latex_(self):
        """
        Returns a LaTeX representation of the Polynomial, for Jupyter.
        """
        if not self.poly:
            return '$0$'
        res = ['$']
        first = True
        for exponent, coef in enumerate(self.poly):
            if coef == 0:
                continue
            monomial = latex_monomial(exponent, coef, self.var)
            if first:
                first = False
                res.append(monomial)
                continue
            oper = '+'
            if monomial[0] == '-':
                oper = '-'
                monomial = monomial[1:]
            res.append(oper)
            res.append(monomial)
        res.append('$')
        return ' '.join(res)

    def __eq__(self, other):
        try:
            other = Polynomial.typecast(other)
        except AssertionError:
            return False
        return self.poly == other.poly

    @staticmethod
    def typecast(other):
        """
        Constructs a Polynomial from `FieldElement` or `int`.
        """
        if isinstance(other, int):
            other = FieldElement(other)
        if isinstance(other, FieldElement):
            other = Polynomial([other])
        assert isinstance(other, Polynomial), f'Type mismatch: Polynomial and {type(other)}.'
        return other

    def __add__(self, other):
        other = Polynomial.typecast(other)
        return Polynomial(two_lists_tuple_operation(
            self.poly, other.poly, operator.add, FieldElement.zero()))

    __radd__ = __add__  # To support <int> + <Polynomial> (as in `1 + x + x**2`).

    def __sub__(self, other):
        other = Polynomial.typecast(other)
        return Polynomial(two_lists_tuple_operation(
            self.poly, other.poly, operator.sub, FieldElement.zero()))

    def __rsub__(self, other):  # To support <int> - <Polynomial> (as in `1 - x + x**2`).
        return -(self - other)

    def __neg__(self):
        return Polynomial([]) - self

    def __mul__(self, other):
        other = Polynomial.typecast(other)
        pol1, pol2 = [[x.val for x in p.poly] for p in (self, other)]
        res = [0] * (self.degree() + other.degree() + 1)
        for i, c1 in enumerate(pol1):
            for j, c2 in enumerate(pol2):
                res[i + j] += c1 * c2
        res = [FieldElement(x) for x in res]
        return Polynomial(res)

    __rmul__ = __mul__  # To support <int> * <Polynomial>.

    def compose(self, other):
        """
        Composes this polynomial with `other`.
        Example:
        >>> f = X**2 + X
        >>> g = X + 1
        >>> f.compose(g) == (2 + 3*X + X**2)
        True
        """
        other = Polynomial.typecast(other)
        res = Polynomial([])
        for coef in self.poly[::-1]:
            res = (res * other) + Polynomial([coef])
        return res

    def qdiv(self, other):
        """
        Returns q, r the quotient and remainder polynomials respectively, such that
        f = q * g + r, where deg(r) < deg(g).
        * Assert that g is not the zero polynomial.
        """
        other = Polynomial.typecast(other)
        pol2 = trim_trailing_zeros(other.poly)
        assert pol2, 'Dividing by zero polynomial.'
        pol1 = trim_trailing_zeros(self.poly)
        if not pol1:
            return [], []
        rem = pol1
        deg_dif = len(rem) - len(pol2)
        quotient = [FieldElement.zero()] * (deg_dif + 1)
        g_msc_inv = pol2[-1].inverse()
        while deg_dif >= 0:
            tmp = rem[-1] * g_msc_inv
            quotient[deg_dif] = quotient[deg_dif] + tmp
            last_non_zero = deg_dif - 1
            for i, coef in enumerate(pol2, deg_dif):
                rem[i] = rem[i] - (tmp * coef)
                if rem[i] != FieldElement.zero():
                    last_non_zero = i
            # Eliminate trailing zeroes (i.e. make r end with its last non-zero coefficient).
            rem = rem[:last_non_zero + 1]
            deg_dif = len(rem) - len(pol2)
        return Polynomial(trim_trailing_zeros(quotient)), Polynomial(rem)

    def __truediv__(self, other):
        div, mod = self.qdiv(other)
        assert mod == 0, 'Polynomials are not divisible.'
        return div

    def __mod__(self, other):
        return self.qdiv(other)[1]

    @staticmethod
    def monomial(degree, coefficient):
        """
        Constructs the monomial coefficient * x**degree.
        """
        return Polynomial([FieldElement.zero()] * degree + [coefficient])

    @staticmethod
    def gen_linear_term(point):
        """
        Generates the polynomial (x-p) for a given point p.
        """
        return Polynomial([FieldElement.zero() - point, FieldElement.one()])

    def degree(self):
        """
        The polynomials are represented by a list so the degree is the length of the list minus the
        number of trailing zeros (if they exist) minus 1.
        This implies that the degree of the zero polynomial will be -1.
        """
        return len(trim_trailing_zeros(self.poly)) - 1

    def get_nth_degree_coefficient(self, n):
        """
        Returns the coefficient of x**n
        """
        if n > self.degree():
            return FieldElement.zero()
        else:
            return self.poly[n]

    def scalar_mul(self, scalar):
        """
        Multiplies polynomial by a scalar
        """
        return Polynomial(scalar_operation(self.poly, operator.mul, scalar))

    def eval(self, point):
        """
        Evaluates the polynomial at the given point using Horner evaluation.
        """
        point = FieldElement.typecast(point).val
        # Doing this with ints (as opposed to `FieldElement`s) speeds up eval significantly.
        val = 0
        for coef in self.poly[::-1]:
            val = (val * point + coef.val) % FieldElement.k_modulus
        return FieldElement(val)

    def __call__(self, other):
        """
        If `other` is an int or a FieldElement, evaluates the polynomial on `other` (in the field).
        If `other` is a polynomial, composes self with `other` as self(other(x)).
        """
        if isinstance(other, (int)):
            other = FieldElement(other)
        if isinstance(other, FieldElement):
            return self.eval(other)
        if isinstance(other, Polynomial):
            return self.compose(other)
        raise NotImplementedError()

    def __pow__(self, other):
        """
        Calculates self**other using repeated squaring.
        """
        assert other >= 0
        res = Polynomial([FieldElement(1)])
        cur = self
        while True:
            if other % 2 != 0:
                res *= cur
            other >>= 1
            if other == 0:
                break
            cur = cur * cur
        return res


def calculate_lagrange_polynomials(x_values):
    """
    Given the x_values for evaluating some polynomials, it computes part of the lagrange polynomials
    required to interpolate a polynomial over this domain.
    """
    lagrange_polynomials = []
    monomials = [Polynomial.monomial(1, FieldElement.one()) -
                 Polynomial.monomial(0, x) for x in x_values]
    numerator = prod(monomials)
    for j in tqdm(range(len(x_values))):
        # In the denominator, we have:
        # (x_j-x_0)(x_j-x_1)...(x_j-x_{j-1})(x_j-x_{j+1})...(x_j-x_{len(X)-1})
        denominator = prod([x_values[j] - x for i, x in enumerate(x_values) if i != j])
        # Numerator is a bit more complicated, since we need to compute a poly multiplication here.
        # Similarly to the denominator, we have:
        # (x-x_0)(x-x_1)...(x-x_{j-1})(x-x_{j+1})...(x-x_{len(X)-1})
        cur_poly, _ = numerator.qdiv(monomials[j].scalar_mul(denominator))
        lagrange_polynomials.append(cur_poly)
    return lagrange_polynomials


def interpolate_poly_lagrange(y_values, lagrange_polynomials):
    """
    :param y_values: y coordinates of the points.
    :param lagrange_polynomials: the polynomials obtained from calculate_lagrange_polynomials.
    :return: the interpolated poly/
    """
    poly = Polynomial([])
    for j, y_value in enumerate(y_values):
        poly += lagrange_polynomials[j].scalar_mul(y_value)
    return poly


def interpolate_poly(x_values, y_values):
    """
    Returns a polynomial of degree < len(x_values) that evaluates to y_values[i] on x_values[i] for
    all i.
    """
    assert len(x_values) == len(y_values)
    assert all(isinstance(val, FieldElement) for val in x_values),\
        'Not all x_values are FieldElement'
    lp = calculate_lagrange_polynomials(x_values)
    assert all(isinstance(val, FieldElement) for val in y_values),\
        'Not all y_values are FieldElement'
    return interpolate_poly_lagrange(y_values, lp)

In [10]:
# @title
###############################################################################
# Copyright 2019 StarkWare Industries Ltd.                                    #
#                                                                             #
# Licensed under the Apache License, Version 2.0 (the "License").             #
# You may not use this file except in compliance with the License.            #
# You may obtain a copy of the License at                                     #
#                                                                             #
# https://www.starkware.co/open-source-license/                               #
#                                                                             #
# Unless required by applicable law or agreed to in writing,                  #
# software distributed under the License is distributed on an "AS IS" BASIS,  #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.    #
# See the License for the specific language governing permissions             #
# and limitations under the License.                                          #
###############################################################################


import inspect
from hashlib import sha256


def serialize(obj):
    """
    Serializes an object into a string.
    """
    if isinstance(obj, (list, tuple)):
        return ','.join(map(serialize, obj))
    return obj._serialize_()


class Channel(object):
    """
    A Channel instance can be used by a prover or a verifier to preserve the semantics of an
    interactive proof system, while under the hood it is in fact non-interactive, and uses Sha256
    to generate randomness when this is required.
    It allows writing string-form data to it, and reading either random integers of random
    FieldElements from it.
    """

    def __init__(self):
        self.state = '0'
        self.proof = []

    def send(self, s):
        self.state = sha256((self.state + s).encode()).hexdigest()
        self.proof.append(f'{inspect.stack()[0][3]}:{s}')

    def receive_random_int(self, min, max, show_in_proof=True):
        """
        Emulates a random integer sent by the verifier in the range [min, max] (including min and
        max).
        """

        # Note that when the range is close to 2^256 this does not emit a uniform distribution,
        # even if sha256 is uniformly distributed.
        # It is, however, close enough for this tutorial's purposes.
        num = min + (int(self.state, 16) % (max - min + 1))
        self.state = sha256((self.state).encode()).hexdigest()
        if show_in_proof:
            self.proof.append(f'{inspect.stack()[0][3]}:{num}')
        return num

    def receive_random_field_element(self):
        """
        Emulates a random field element sent by the verifier.
        """
        num = self.receive_random_int(0, FieldElement.k_modulus - 1, show_in_proof=False)
        self.proof.append(f'{inspect.stack()[0][3]}:{num}')
        return FieldElement(num)

In [11]:
# @title
###############################################################################
# Copyright 2019 StarkWare Industries Ltd.                                    #
#                                                                             #
# Licensed under the Apache License, Version 2.0 (the "License").             #
# You may not use this file except in compliance with the License.            #
# You may obtain a copy of the License at                                     #
#                                                                             #
# https://www.starkware.co/open-source-license/                               #
#                                                                             #
# Unless required by applicable law or agreed to in writing,                  #
# software distributed under the License is distributed on an "AS IS" BASIS,  #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.    #
# See the License for the specific language governing permissions             #
# and limitations under the License.                                          #
###############################################################################


from hashlib import sha256
from math import log2, ceil


class MerkleTree(object):
    """
    A simple and naive implementation of an immutable Merkle tree.
    """

    def __init__(self, data):
        assert isinstance(data, list)
        assert len(data) > 0, 'Cannot construct an empty Merkle Tree.'
        num_leaves = 2 ** ceil(log2(len(data)))
        self.data = data + [FieldElement(0)] * (num_leaves - len(data))
        self.height = int(log2(num_leaves))
        self.facts = {}
        self.root = self.build_tree()

    def get_authentication_path(self, leaf_id):
        assert 0 <= leaf_id < len(self.data)
        node_id = leaf_id + len(self.data)
        cur = self.root
        decommitment = []
        # In a Merkle Tree, the path from the root to a leaf, corresponds to the the leaf id's
        # binary representation, starting from the second-MSB, where '0' means 'left', and '1' means
        # 'right'.
        # We therefore iterate over the bits of the binary representation - skipping the '0b'
        # prefix, as well as the MSB.
        for bit in bin(node_id)[3:]:
            cur, auth = self.facts[cur]
            if bit == '1':
                auth, cur = cur, auth
            decommitment.append(auth)
        return decommitment

    def build_tree(self):
        return self.recursive_build_tree(1)

    def recursive_build_tree(self, node_id):
        if node_id >= len(self.data):
            # A leaf.
            id_in_data = node_id - len(self.data)
            leaf_data = str(self.data[id_in_data])
            h = sha256(leaf_data.encode()).hexdigest()
            self.facts[h] = leaf_data
            return h
        else:
            # An internal node.
            left = self.recursive_build_tree(node_id * 2)
            right = self.recursive_build_tree(node_id * 2 + 1)
            h = sha256((left + right).encode()).hexdigest()
            self.facts[h] = (left, right)
            return h


def verify_decommitment(leaf_id, leaf_data, decommitment, root):
    leaf_num = 2 ** len(decommitment)
    node_id = leaf_id + leaf_num
    cur = sha256(str(leaf_data).encode()).hexdigest()
    for bit, auth in zip(bin(node_id)[3:][::-1], decommitment[::-1]):
        if bit == '0':
            h = cur + auth
        else:
            h = auth + cur
        cur = sha256(h.encode()).hexdigest()
    return cur == root

In [12]:
X = Polynomial.X()

In [13]:
g = FieldElement.generator() ** (3*2**26)
x = [g**i for i in range(len(trace['i']))]

In [14]:
state_polynomial = interpolate_poly(x, [FieldElement(state_map[s]) for s in trace['s']])
caracter_polynomial = interpolate_poly(x, [FieldElement(caracter_map[c]) for c in trace['c']])

100%|██████████| 14/14 [00:00<00:00, 6852.64it/s]
100%|██████████| 14/14 [00:00<00:00, 6859.04it/s]


In [15]:
# not sure if I should interpolate over points not already in the x array
transi_x = [state_polynomial(state_map[s])*caracter_polynomial(caracter_map[c]) for s,c in product(states, alphabet)]
transition_polynomial = interpolate_poly(transi_x, [FieldElement(state_map[transition[s][c]]) for (s,c) in product(states, alphabet)])

100%|██████████| 4/4 [00:00<00:00, 23431.87it/s]


In [16]:
step_polynomial = state_polynomial(g*X) - transition_polynomial(state_polynomial(X)*caracter_polynomial(X))

In [17]:
def next_power_of_two(n: int) -> int:
    if n <= 0:
        return 1
    # If n is already a power of two, return n
    if (n & (n - 1)) == 0:
        return n
    power = 1
    while power < n:
        power <<= 1
    return power

In [18]:
w = FieldElement.generator()
h = FieldElement.generator() ** (3*2**23)
H = [h**i for i in range(next_power_of_two(len(trace['i']))*32)]
# here also, not sure about the domain, should be big enough for both states and transitions
eval_domain = [w*h for h in H]

In [19]:
len(trace['i'])

14

In [20]:
state_eval = [state_polynomial(d) for d in eval_domain]
caracter_eval = [caracter_polynomial(d) for d in eval_domain]
transition_eval = [transition_polynomial(d) for d in eval_domain]
state_merkle = MerkleTree(state_eval)
caracter_merkle = MerkleTree(caracter_eval)
transition_merkle = MerkleTree(transition_eval)

In [21]:
channel = Channel()

In [22]:
channel.send(state_merkle.root)
channel.send(caracter_merkle.root)
channel.send(transition_merkle.root)

In [23]:
initial_constraint_p = (state_polynomial-state_map['q0'])/(X-1)
dp = Polynomial([FieldElement.one()])
ap = Polynomial([FieldElement.one()])
for i in range(len(x)):
  ap = ap * (caracter_polynomial - caracter_map[trace['c'][i]])
  dp = dp * (X-x[i])
caracter_constraint_p = ap/dp
ap = Polynomial([FieldElement.one()])
for f in accept_states:
  ap = ap * (state_polynomial - state_map[f])
accepting_constraint_p = ap/(X-x[-1])
dp = Polynomial([FieldElement.one()])
ap = Polynomial([FieldElement.one()])
i = 0
for s,c in product(states,alphabet):
  ap = ap * (transition_polynomial - state_map[transition[s][c]])
  dp = dp * (X-transi_x[i])
  i+=1
transition_constraint_p = ap/dp
dp = Polynomial([FieldElement.one()])
for i in range(len(trace['i'])):
  dp = dp * x[i]
step_constraint_p = step_polynomial/dp

In [24]:
p0 = initial_constraint_p
p1 = caracter_constraint_p
p2 = accepting_constraint_p
p3 = transition_constraint_p
p4 = step_constraint_p

a0 = channel.receive_random_field_element()
a1 = channel.receive_random_field_element()
a2 = channel.receive_random_field_element()
a3 = channel.receive_random_field_element()
a4 = channel.receive_random_field_element()

cp = a0*p0 + a1*p1 + a2*p2 + a3*p3 + a4*p4

In [25]:
cp_eval = [cp(x) for x in eval_domain]
cp_merkle = MerkleTree(cp_eval)
channel.send(cp_merkle.root)

In [26]:
def fold_domain(domain):
  return [x**2 for x in domain[:len(domain)//2]]

In [27]:
def fold(p, beta):
  odd = p.poly[1::2]
  even = p.poly[::2]
  o = Polynomial(odd)
  e = Polynomial(even)
  return e + beta*o

In [28]:
def next_fri_layer(poly, domain, beta):
    next_poly = fold(poly, beta)
    next_domain = fold_domain(domain)
    next_layer = [next_poly(x) for x in next_domain]
    return next_poly, next_domain, next_layer

In [29]:
def FriCommit(cp, domain, cp_eval, cp_merkle, channel):
    fri_polys = [cp]
    fri_domains = [domain]
    fri_layers = [cp_eval]
    fri_merkles = [cp_merkle]
    while fri_polys[-1].degree() > 0:
        beta = channel.receive_random_field_element()
        next_poly, next_domain, next_layer = next_fri_layer(fri_polys[-1], fri_domains[-1], beta)
        fri_polys.append(next_poly)
        fri_domains.append(next_domain)
        fri_layers.append(next_layer)
        fri_merkles.append(MerkleTree(next_layer))
        channel.send(fri_merkles[-1].root)
    channel.send(str(fri_polys[-1].poly[0]))
    return fri_polys, fri_domains, fri_layers, fri_merkles

In [30]:
fri_polys, fri_domains, fri_layers, fri_merkles = FriCommit(cp, eval_domain,cp_eval,cp_merkle,channel)

In [31]:
def decommit_on_fri_layers(idx, channel):
    for layer, merkle in zip(fri_layers[:-1], fri_merkles[:-1]):
        length = len(layer)
        idx = idx % length
        sib_idx = (idx + length // 2) % length
        channel.send(str(layer[idx]))
        channel.send(str(merkle.get_authentication_path(idx)))
        channel.send(str(layer[sib_idx]))
        channel.send(str(merkle.get_authentication_path(sib_idx)))
    channel.send(str(fri_layers[-1][0]))

In [32]:
def decommit_on_query(idx, channel): 
    assert idx + 16 < len(state_eval), f'query index: {idx} is out of range. Length of layer: {len(state_eval)}.'
    channel.send(str(state_eval[idx])) # f(x).
    channel.send(str(state_merkle.get_authentication_path(idx))) # auth path for f(x).
    channel.send(str(state_eval[idx + 32])) # f(gx).
    channel.send(str(state_merkle.get_authentication_path(idx + 32))) # auth path for f(gx).
    channel.send(str(state_eval[idx + 64])) # f(g^2x).
    channel.send(str(state_merkle.get_authentication_path(idx + 64))) # auth path for f(g^2x).
    assert idx + 16 < len(caracter_eval), f'query index: {idx} is out of range. Length of layer: {len(caracter_eval)}.'
    channel.send(str(caracter_eval[idx])) # f(x).
    channel.send(str(caracter_merkle.get_authentication_path(idx))) # auth path for f(x).
    channel.send(str(caracter_eval[idx + 32])) # f(gx).
    channel.send(str(caracter_merkle.get_authentication_path(idx + 32))) # auth path for f(gx).
    channel.send(str(caracter_eval[idx + 64])) # f(g^2x).
    channel.send(str(caracter_merkle.get_authentication_path(idx + 64))) # auth path for f(g^2x).
    assert idx + 16 < len(transition_eval), f'query index: {idx} is out of range. Length of layer: {len(transition_eval)}.'
    channel.send(str(transition_eval[idx])) # f(x).
    channel.send(str(transition_merkle.get_authentication_path(idx))) # auth path for f(x).
    channel.send(str(transition_eval[idx + 32])) # f(gx).
    channel.send(str(transition_merkle.get_authentication_path(idx + 32))) # auth path for f(gx).
    channel.send(str(transition_eval[idx + 64])) # f(g^2x).
    channel.send(str(transition_merkle.get_authentication_path(idx + 64))) # auth path for f(g^2x).
    decommit_on_fri_layers(idx, channel)   

In [33]:
def decommit_fri(channel):
    for query in range(3):
        # Get a random index from the verifier and send the corresponding decommitment.
        decommit_on_query(channel.receive_random_int(0, len(state_eval)-64), channel)

In [34]:
decommit_fri(channel)
print(channel.proof)

['send:79a627eb50b7f2fb341a18e71d0035e5252a2143aa9c180e8208c056e4d11a7d', 'send:0d351de759ca5ce67f68d09888f4447adf20f17d08de9cec641e4a3980ed5141', 'send:4244e876ef16cc3946c4584763133bc75d36b6844115988412be231a5f87df95', 'receive_random_field_element:1629565560', 'receive_random_field_element:809989499', 'receive_random_field_element:829155219', 'receive_random_field_element:1919549043', 'receive_random_field_element:1918767788', 'send:6d3911c3b6f7b098ec2c2d7c314a44d56296b6fc1086c2408732da47846172d4', 'receive_random_field_element:1909031858', 'send:094bd03a1926ef2355f3407e947f131083f1188bd06fa4c957f96cf5952bcece', 'receive_random_field_element:57974712', 'send:ee31f997c8572ba53e3f6935ce2c87b4e7cc942951bf6b53600f87656803c5d9', 'receive_random_field_element:1154397900', 'send:e6f1569b3ff53013e373f7bd2271464543fe3b53463fc9452c57ddc718a3959c', 'receive_random_field_element:2572294563', 'send:6eb3c51b794212eef6548feb06d7d541acae2853d87c1e7045e5b723755fe953', 'receive_random_field_element:91

In [35]:
print(f'Overall time: {time.time() - start_all}s')

Overall time: 1.8473024368286133s


# Verification time

In [64]:
proof = channel.proof
c = Channel()
current_channel_idx = 0
state_merkle = proof[current_channel_idx][5:]
c.send(state_merkle)
current_channel_idx+=1
caracter_merkle = proof[current_channel_idx][5:]
c.send(caracter_merkle)
current_channel_idx+=1
transition_merkle = proof[current_channel_idx][5:]
c.send(transition_merkle)
current_channel_idx+=1
a0 = FieldElement(int(proof[current_channel_idx][len('receive_random_field_element:'):]))
assert(c.receive_random_field_element() == a0)
current_channel_idx+=1
a1 = FieldElement(int(proof[current_channel_idx][len('receive_random_field_element:'):]))
assert(c.receive_random_field_element() == a1)
current_channel_idx+=1
a2 = FieldElement(int(proof[current_channel_idx][len('receive_random_field_element:'):]))
assert(c.receive_random_field_element() == a2)
current_channel_idx+=1
a3 = FieldElement(int(proof[current_channel_idx][len('receive_random_field_element:'):]))
assert(c.receive_random_field_element() == a3)
current_channel_idx+=1
a4 = FieldElement(int(proof[current_channel_idx][len('receive_random_field_element:'):]))
assert(c.receive_random_field_element() == a4)
current_channel_idx+=1
cp_merkle = 