In [None]:
print("this is a test")

In [117]:
import pandas as pd
from itertools import product

In [68]:
states = {'q0', 'q1'}
alphabet = {'a', 'b'}
transition = {
    'q0': {'a': 'q1', 'b': 'q0'},
    'q1': {'a': 'q0', 'b': 'q1'}
}
start_state = 'q0'
accept_states = {'q0'}

In [69]:
def generate_dfa_trace(input_string):
  trace = {'i':[], 's': [], 'c': []}
  current_state = start_state
  i = 0
  for symbol in input_string:
      if symbol not in alphabet:
          return False
      trace['i'].append(i)
      i+=1
      trace['s'].append(current_state)
      trace['c'].append(symbol)
      current_state = transition[current_state][symbol]

  return trace

In [113]:
trace = generate_dfa_trace("aaabababbbabba")

In [71]:
state_map = {"q0":1,"q1":2}
caracter_map = {'a': 3, 'b': 5}
transition_map = {("q0", 'a'): 7, ("q0", 'b'):9, ("q1", 'a'):11,("q1", 'b'):13 }

In [72]:
# @title
from random import randint


class FieldElement:
    """
    Represents an element of F_(3 * 2**30 + 1).
    """
    k_modulus = 3 * 2**30 + 1
    generator_val = 5

    def __init__(self, val):
        self.val = val % FieldElement.k_modulus

    @staticmethod
    def zero():
        """
        Obtains the zero element of the field.
        """
        return FieldElement(0)

    @staticmethod
    def one():
        """
        Obtains the unit element of the field.
        """
        return FieldElement(1)

    def __repr__(self):
        # Choose the shorter representation between the positive and negative values of the element.
        return repr((self.val + self.k_modulus//2) % self.k_modulus - self.k_modulus//2)

    def __eq__(self, other):
        if isinstance(other, int):
            other = FieldElement(other)
        return isinstance(other, FieldElement) and self.val == other.val

    def __hash__(self):
        return hash(self.val)

    @staticmethod
    def generator():
        return FieldElement(FieldElement.generator_val)

    @staticmethod
    def typecast(other):
        if isinstance(other, int):
            return FieldElement(other)
        assert isinstance(other, FieldElement), f'Type mismatch: FieldElement and {type(other)}.'
        return other

    def __neg__(self):
        return self.zero() - self

    def __add__(self, other):
        try:
            other = FieldElement.typecast(other)
        except AssertionError:
            return NotImplemented
        return FieldElement((self.val + other.val) % FieldElement.k_modulus)

    __radd__ = __add__

    def __sub__(self, other):
        try:
            other = FieldElement.typecast(other)
        except AssertionError:
            return NotImplemented
        return FieldElement((self.val - other.val) % FieldElement.k_modulus)

    def __rsub__(self, other):
        return -(self - other)

    def __mul__(self, other):
        try:
            other = FieldElement.typecast(other)
        except AssertionError:
            return NotImplemented
        return FieldElement((self.val * other.val) % FieldElement.k_modulus)

    __rmul__ = __mul__

    def __truediv__(self, other):
        other = FieldElement.typecast(other)
        return self * other.inverse()

    def __pow__(self, n):
        assert n >= 0
        cur_pow = self
        res = FieldElement(1)
        while n > 0:
            if n % 2 != 0:
                res *= cur_pow
            n = n // 2
            cur_pow *= cur_pow
        return res

    def inverse(self):
        t, new_t = 0, 1
        r, new_r = FieldElement.k_modulus, self.val
        while new_r != 0:
            quotient = r // new_r
            t, new_t = new_t, (t - (quotient * new_t))
            r, new_r = new_r, r - quotient * new_r
        assert r == 1
        return FieldElement(t)

    def is_order(self, n):
        """
        Naively checks that the element is of order n by raising it to all powers up to n, checking
        that the element to the n-th power is the unit, but not so for any k<n.
        """
        assert n >= 1
        h = FieldElement(1)
        for _ in range(1, n):
            h *= self
            if h == FieldElement(1):
                return False
        return h * self == FieldElement(1)

    def _serialize_(self):
        return repr(self.val)

    @staticmethod
    def random_element(exclude_elements=[]):
        fe = FieldElement(randint(0, FieldElement.k_modulus - 1))
        while fe in exclude_elements:
            fe = FieldElement(randint(0, FieldElement.k_modulus - 1))
        return fe

In [73]:
# @title
from itertools import dropwhile, starmap, zip_longest


def remove_trailing_elements(list_of_elements, element_to_remove):
    return list(dropwhile(lambda x: x == element_to_remove, list_of_elements[::-1]))[::-1]


def two_lists_tuple_operation(f, g, operation, fill_value):
    return list(starmap(operation, zip_longest(f, g, fillvalue=fill_value)))


def scalar_operation(list_of_elements, operation, scalar):
    return [operation(c, scalar) for c in list_of_elements]

In [110]:
# @title
import operator
from functools import reduce
try:
    from tqdm import tqdm
except ModuleNotFoundError:
    # tqdm is a wrapper for iterators implementing a progress bar. If it's
    # not available, simply return the iterator itself.
    tqdm = lambda x: x

def trim_trailing_zeros(p):
    """
    Removes zeros from the end of a list.
    """
    return remove_trailing_elements(p, FieldElement.zero())


def prod(values):
    """
    Computes a product.
    """
    len_values = len(values)
    if len_values == 0:
        return 1
    if len_values == 1:
        return values[0]
    return prod(values[:len_values // 2]) * prod(values[len_values // 2:])


def latex_monomial(exponent, coef, var):
    """
    Returns a string representation of a monomial as LaTeX.
    """
    if exponent == 0:
        return str(coef)
    if coef == 1:
        coef = ''
    if coef == -1:
        coef = '-'
    if exponent == 1:
        return f'{coef}{var}'
    return f'{coef}{var}^{{{exponent}}}'


class Polynomial:
    """
    Represents a polynomial over FieldElement.
    """

    @classmethod
    def X(cls):
        """
        Returns the polynomial x.
        """
        return cls([FieldElement.zero(), FieldElement.one()])

    def __init__(self, coefficients, var='x'):
        # Internally storing the coefficients in self.poly, least-significant (i.e. free term)
        # first, so $9 - 3x^2 + 19x^5$ is represented internally by the list  [9, 0, -3, 0, 0, 19].
        # Note that coefficients is copied, so the caller may freely modify the given argument.
        self.poly = remove_trailing_elements(coefficients, FieldElement.zero())
        self.var = var

    def _repr_latex_(self):
        """
        Returns a LaTeX representation of the Polynomial, for Jupyter.
        """
        if not self.poly:
            return '$0$'
        res = ['$']
        first = True
        for exponent, coef in enumerate(self.poly):
            if coef == 0:
                continue
            monomial = latex_monomial(exponent, coef, self.var)
            if first:
                first = False
                res.append(monomial)
                continue
            oper = '+'
            if monomial[0] == '-':
                oper = '-'
                monomial = monomial[1:]
            res.append(oper)
            res.append(monomial)
        res.append('$')
        return ' '.join(res)

    def __eq__(self, other):
        try:
            other = Polynomial.typecast(other)
        except AssertionError:
            return False
        return self.poly == other.poly

    @staticmethod
    def typecast(other):
        """
        Constructs a Polynomial from `FieldElement` or `int`.
        """
        if isinstance(other, int):
            other = FieldElement(other)
        if isinstance(other, FieldElement):
            other = Polynomial([other])
        assert isinstance(other, Polynomial), f'Type mismatch: Polynomial and {type(other)}.'
        return other

    def __add__(self, other):
        other = Polynomial.typecast(other)
        return Polynomial(two_lists_tuple_operation(
            self.poly, other.poly, operator.add, FieldElement.zero()))

    __radd__ = __add__  # To support <int> + <Polynomial> (as in `1 + x + x**2`).

    def __sub__(self, other):
        other = Polynomial.typecast(other)
        return Polynomial(two_lists_tuple_operation(
            self.poly, other.poly, operator.sub, FieldElement.zero()))

    def __rsub__(self, other):  # To support <int> - <Polynomial> (as in `1 - x + x**2`).
        return -(self - other)

    def __neg__(self):
        return Polynomial([]) - self

    def __mul__(self, other):
        other = Polynomial.typecast(other)
        pol1, pol2 = [[x.val for x in p.poly] for p in (self, other)]
        res = [0] * (self.degree() + other.degree() + 1)
        for i, c1 in enumerate(pol1):
            for j, c2 in enumerate(pol2):
                res[i + j] += c1 * c2
        res = [FieldElement(x) for x in res]
        return Polynomial(res)

    __rmul__ = __mul__  # To support <int> * <Polynomial>.

    def compose(self, other):
        """
        Composes this polynomial with `other`.
        Example:
        >>> f = X**2 + X
        >>> g = X + 1
        >>> f.compose(g) == (2 + 3*X + X**2)
        True
        """
        other = Polynomial.typecast(other)
        res = Polynomial([])
        for coef in self.poly[::-1]:
            res = (res * other) + Polynomial([coef])
        return res

    def qdiv(self, other):
        """
        Returns q, r the quotient and remainder polynomials respectively, such that
        f = q * g + r, where deg(r) < deg(g).
        * Assert that g is not the zero polynomial.
        """
        other = Polynomial.typecast(other)
        pol2 = trim_trailing_zeros(other.poly)
        assert pol2, 'Dividing by zero polynomial.'
        pol1 = trim_trailing_zeros(self.poly)
        if not pol1:
            return [], []
        rem = pol1
        deg_dif = len(rem) - len(pol2)
        quotient = [FieldElement.zero()] * (deg_dif + 1)
        g_msc_inv = pol2[-1].inverse()
        while deg_dif >= 0:
            tmp = rem[-1] * g_msc_inv
            quotient[deg_dif] = quotient[deg_dif] + tmp
            last_non_zero = deg_dif - 1
            for i, coef in enumerate(pol2, deg_dif):
                rem[i] = rem[i] - (tmp * coef)
                if rem[i] != FieldElement.zero():
                    last_non_zero = i
            # Eliminate trailing zeroes (i.e. make r end with its last non-zero coefficient).
            rem = rem[:last_non_zero + 1]
            deg_dif = len(rem) - len(pol2)
        return Polynomial(trim_trailing_zeros(quotient)), Polynomial(rem)

    def __truediv__(self, other):
        div, mod = self.qdiv(other)
        assert mod == 0, 'Polynomials are not divisible.'
        return div

    def __mod__(self, other):
        return self.qdiv(other)[1]

    @staticmethod
    def monomial(degree, coefficient):
        """
        Constructs the monomial coefficient * x**degree.
        """
        return Polynomial([FieldElement.zero()] * degree + [coefficient])

    @staticmethod
    def gen_linear_term(point):
        """
        Generates the polynomial (x-p) for a given point p.
        """
        return Polynomial([FieldElement.zero() - point, FieldElement.one()])

    def degree(self):
        """
        The polynomials are represented by a list so the degree is the length of the list minus the
        number of trailing zeros (if they exist) minus 1.
        This implies that the degree of the zero polynomial will be -1.
        """
        return len(trim_trailing_zeros(self.poly)) - 1

    def get_nth_degree_coefficient(self, n):
        """
        Returns the coefficient of x**n
        """
        if n > self.degree():
            return FieldElement.zero()
        else:
            return self.poly[n]

    def scalar_mul(self, scalar):
        """
        Multiplies polynomial by a scalar
        """
        return Polynomial(scalar_operation(self.poly, operator.mul, scalar))

    def eval(self, point):
        """
        Evaluates the polynomial at the given point using Horner evaluation.
        """
        point = FieldElement.typecast(point).val
        # Doing this with ints (as opposed to `FieldElement`s) speeds up eval significantly.
        val = 0
        for coef in self.poly[::-1]:
            val = (val * point + coef.val) % FieldElement.k_modulus
        return FieldElement(val)

    def __call__(self, other):
        """
        If `other` is an int or a FieldElement, evaluates the polynomial on `other` (in the field).
        If `other` is a polynomial, composes self with `other` as self(other(x)).
        """
        if isinstance(other, (int)):
            other = FieldElement(other)
        if isinstance(other, FieldElement):
            return self.eval(other)
        if isinstance(other, Polynomial):
            return self.compose(other)
        raise NotImplementedError()

    def __pow__(self, other):
        """
        Calculates self**other using repeated squaring.
        """
        assert other >= 0
        res = Polynomial([FieldElement(1)])
        cur = self
        while True:
            if other % 2 != 0:
                res *= cur
            other >>= 1
            if other == 0:
                break
            cur = cur * cur
        return res


def calculate_lagrange_polynomials(x_values):
    """
    Given the x_values for evaluating some polynomials, it computes part of the lagrange polynomials
    required to interpolate a polynomial over this domain.
    """
    lagrange_polynomials = []
    monomials = [Polynomial.monomial(1, FieldElement.one()) -
                 Polynomial.monomial(0, x) for x in x_values]
    numerator = prod(monomials)
    for j in tqdm(range(len(x_values))):
        # In the denominator, we have:
        # (x_j-x_0)(x_j-x_1)...(x_j-x_{j-1})(x_j-x_{j+1})...(x_j-x_{len(X)-1})
        denominator = prod([x_values[j] - x for i, x in enumerate(x_values) if i != j])
        # Numerator is a bit more complicated, since we need to compute a poly multiplication here.
        # Similarly to the denominator, we have:
        # (x-x_0)(x-x_1)...(x-x_{j-1})(x-x_{j+1})...(x-x_{len(X)-1})
        cur_poly, _ = numerator.qdiv(monomials[j].scalar_mul(denominator))
        lagrange_polynomials.append(cur_poly)
    return lagrange_polynomials


def interpolate_poly_lagrange(y_values, lagrange_polynomials):
    """
    :param y_values: y coordinates of the points.
    :param lagrange_polynomials: the polynomials obtained from calculate_lagrange_polynomials.
    :return: the interpolated poly/
    """
    poly = Polynomial([])
    for j, y_value in enumerate(y_values):
        poly += lagrange_polynomials[j].scalar_mul(y_value)
    return poly


def interpolate_poly(x_values, y_values):
    """
    Returns a polynomial of degree < len(x_values) that evaluates to y_values[i] on x_values[i] for
    all i.
    """
    assert len(x_values) == len(y_values)
    assert all(isinstance(val, FieldElement) for val in x_values),\
        'Not all x_values are FieldElement'
    lp = calculate_lagrange_polynomials(x_values)
    assert all(isinstance(val, FieldElement) for val in y_values),\
        'Not all y_values are FieldElement'
    return interpolate_poly_lagrange(y_values, lp)

In [75]:
# @title
class MultivariatePolynomial:
    """
    Represents a multivariate polynomial over a Field.
    """

    def __init__(self, terms, variables):
        """
        Initializes a polynomial.
        `terms`: A dictionary mapping an exponent tuple to a coefficient.
                 e.g., {(2, 1): FieldElement(3)} for 3x^2y
        `variables`: A tuple or list of variable names, e.g., ('x', 'y').
                     The order of variables corresponds to the exponents in the tuple.
        """
        self.variables = tuple(variables)
        # Store only terms with non-zero coefficients
        self.terms = {
            exp: coeff for exp, coeff in terms.items()
            if not isinstance(coeff, FieldElement) or coeff != FieldElement.zero()
        }

        # Validate that all exponent tuples have the correct length
        for exp in self.terms:
            if len(exp) != len(self.variables):
                raise ValueError(
                    f"Exponent tuple {exp} has length {len(exp)}, but there are "
                    f"{len(self.variables)} variables."
                )

    # =================================================================
    # == NEW FUNCTION ADDED HERE ======================================
    # =================================================================
    @classmethod
    def from_univariate(cls, poly, all_vars):
        """
        Transforms a univariate Polynomial into a MultivariatePolynomial.

        Args:
            poly (Polynomial): The univariate polynomial to convert.
            all_vars (tuple): The variables for the new multivariate polynomial.
                              Must contain the variable from the original polynomial.
        """
        try:
            # Find the position of the old variable in the new variable list
            var_index = all_vars.index(poly.var)
        except ValueError:
            raise ValueError(
                f"Univariate polynomial variable '{poly.var}' not found in the "
                f"target variables {all_vars}."
            )

        new_terms = {}
        num_vars = len(all_vars)

        # Iterate through coefficients of the univariate polynomial
        for exponent, coeff in enumerate(poly.poly):
            if coeff != FieldElement.zero():
                # Create the new exponent tuple with the exponent in the correct spot
                exp_tuple = [0] * num_vars
                exp_tuple[var_index] = exponent
                new_terms[tuple(exp_tuple)] = coeff

        return cls(new_terms, all_vars)
    # =================================================================

    @classmethod
    def typecast(cls, other, variables):
        """
        Constructs a constant MultivariatePolynomial from a FieldElement or int.
        """
        if isinstance(other, MultivariatePolynomial):
            if other.variables != variables:
                raise TypeError(
                    f"Polynomial variable mismatch: {other.variables} vs {variables}"
                )
            return other

        coeff = FieldElement.typecast(other)
        if coeff == FieldElement.zero():
            return cls({}, variables)

        # A constant is a term with all exponents being zero
        zero_exp = tuple([0] * len(variables))
        return cls({zero_exp: coeff}, variables)

    def __eq__(self, other):
        try:
            other = self.typecast(other, self.variables)
        except (TypeError, ValueError):
            return False
        return self.variables == other.variables and self.terms == other.terms

    def __add__(self, other):
        other = self.typecast(other, self.variables)

        new_terms = defaultdict(FieldElement.zero)
        for exp, coeff in self.terms.items():
            new_terms[exp] += coeff
        for exp, coeff in other.terms.items():
            new_terms[exp] += coeff

        return MultivariatePolynomial(new_terms, self.variables)

    __radd__ = __add__

    def __sub__(self, other):
        return self + (-other)

    def __rsub__(self, other):
        return -(self - other)

    def __neg__(self):
        return MultivariatePolynomial(
            {exp: -coeff for exp, coeff in self.terms.items()}, self.variables
        )

    def __mul__(self, other):
        other = self.typecast(other, self.variables)

        new_terms = defaultdict(FieldElement.zero)
        if not self.terms or not other.terms:
             return MultivariatePolynomial({}, self.variables)

        for exp1, coeff1 in self.terms.items():
            for exp2, coeff2 in other.terms.items():
                new_exp = tuple(e1 + e2 for e1, e2 in zip(exp1, exp2))
                new_terms[new_exp] += coeff1 * coeff2

        return MultivariatePolynomial(new_terms, self.variables)

    __rmul__ = __mul__

    def __pow__(self, exponent):
        if not isinstance(exponent, int) or exponent < 0:
            raise TypeError("Exponent must be a non-negative integer.")

        res = self.typecast(1, self.variables) # Multiplicative identity
        base = self

        while exponent > 0:
            if exponent % 2 == 1:
                res *= base
            base *= base
            exponent //= 2

        return res

    def total_degree(self):
        """
        Returns the total degree of the polynomial, which is the maximum
        sum of exponents in any single term. Returns -1 for the zero polynomial.
        """
        if not self.terms:
            return -1
        return max(sum(exp) for exp in self.terms)

    def eval(self, point_map):
        """
        Evaluates the polynomial at a given point.
        `point_map`: A dictionary mapping variable names to values,
                     or a list/tuple of values in the same order as self.variables.
        """
        if isinstance(point_map, (list, tuple)):
            if len(point_map) != len(self.variables):
                raise ValueError("Incorrect number of values for evaluation.")
            point_map = dict(zip(self.variables, point_map))

        total = FieldElement.zero()
        for exp, coeff in self.terms.items():
            term_val = coeff
            for i, var_name in enumerate(self.variables):
                val = FieldElement.typecast(point_map.get(var_name))
                if val is None:
                    raise ValueError(f"Value for variable '{var_name}' not provided.")
                # Use Python's built-in pow for integer exponentiation
                term_val *= (val ** exp[i])
            total += term_val
        return total

    __call__ = eval

    def _repr_term_latex(self, exp, coeff):
        """Helper to generate LaTeX for a single term."""

        # Handle the coefficient
        coeff_val = coeff.val
        if 'k_modulus' in globals().get('FieldElement', {}).__dict__:
             # Make negative coefficients look nicer, e.g., (p-1) becomes -1
            if coeff_val > FieldElement.k_modulus / 2:
                coeff_val -= FieldElement.k_modulus

        if abs(coeff_val) == 1 and any(e > 0 for e in exp):
            s_coeff = '-' if coeff_val == -1 else ''
        else:
            s_coeff = str(coeff_val)

        # Handle the variables
        parts = []
        for i, e in enumerate(exp):
            if e == 0:
                continue
            var = self.variables[i]
            if e == 1:
                parts.append(var)
            else:
                parts.append(f'{var}^{{{e}}}')

        s_vars = ' '.join(parts)

        # Join coefficient and variables
        if s_vars and s_coeff:
            return f'{s_coeff} {s_vars}'
        return s_coeff or s_vars or str(coeff_val) # Fallback for constant term

    def _repr_latex_(self):
        """Returns a LaTeX representation of the Polynomial for Jupyter."""
        if not self.terms:
            return "$0$"

        # Sort terms by total degree (desc), then lexicographically (desc)
        sorted_terms = sorted(
            self.terms.items(),
            key=lambda item: (sum(item[0]), item[0]),
            reverse=True
        )

        parts = []
        first = True
        for exp, coeff in sorted_terms:
            term_str = self._repr_term_latex(exp, coeff)

            # Add sign for non-first terms
            if not first:
                if term_str.startswith('-'):
                    parts.append(f'- {term_str[1:].strip()}')
                else:
                    parts.append(f'+ {term_str}')
            else:
                parts.append(term_str)
            first = False

        return f"$ {' '.join(parts)} $"

    def __repr__(self):
        """Returns a string representation of the polynomial."""
        if not self.terms:
            return "0"

        # Use the LaTeX representation but remove the math delimiters
        latex_repr = self._repr_latex_()
        return latex_repr[2:-2].strip()

In [112]:
# @title
###############################################################################
# Copyright 2019 StarkWare Industries Ltd.                                    #
#                                                                             #
# Licensed under the Apache License, Version 2.0 (the "License").             #
# You may not use this file except in compliance with the License.            #
# You may obtain a copy of the License at                                     #
#                                                                             #
# https://www.starkware.co/open-source-license/                               #
#                                                                             #
# Unless required by applicable law or agreed to in writing,                  #
# software distributed under the License is distributed on an "AS IS" BASIS,  #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.    #
# See the License for the specific language governing permissions             #
# and limitations under the License.                                          #
###############################################################################


import inspect
from hashlib import sha256


def serialize(obj):
    """
    Serializes an object into a string.
    """
    if isinstance(obj, (list, tuple)):
        return ','.join(map(serialize, obj))
    return obj._serialize_()


class Channel(object):
    """
    A Channel instance can be used by a prover or a verifier to preserve the semantics of an
    interactive proof system, while under the hood it is in fact non-interactive, and uses Sha256
    to generate randomness when this is required.
    It allows writing string-form data to it, and reading either random integers of random
    FieldElements from it.
    """

    def __init__(self):
        self.state = '0'
        self.proof = []

    def send(self, s):
        self.state = sha256((self.state + s).encode()).hexdigest()
        self.proof.append(f'{inspect.stack()[0][3]}:{s}')

    def receive_random_int(self, min, max, show_in_proof=True):
        """
        Emulates a random integer sent by the verifier in the range [min, max] (including min and
        max).
        """

        # Note that when the range is close to 2^256 this does not emit a uniform distribution,
        # even if sha256 is uniformly distributed.
        # It is, however, close enough for this tutorial's purposes.
        num = min + (int(self.state, 16) % (max - min + 1))
        self.state = sha256((self.state).encode()).hexdigest()
        if show_in_proof:
            self.proof.append(f'{inspect.stack()[0][3]}:{num}')
        return num

    def receive_random_field_element(self):
        """
        Emulates a random field element sent by the verifier.
        """
        num = self.receive_random_int(0, FieldElement.k_modulus - 1, show_in_proof=False)
        self.proof.append(f'{inspect.stack()[0][3]}:{num}')
        return FieldElement(num)

In [116]:
# @title
###############################################################################
# Copyright 2019 StarkWare Industries Ltd.                                    #
#                                                                             #
# Licensed under the Apache License, Version 2.0 (the "License").             #
# You may not use this file except in compliance with the License.            #
# You may obtain a copy of the License at                                     #
#                                                                             #
# https://www.starkware.co/open-source-license/                               #
#                                                                             #
# Unless required by applicable law or agreed to in writing,                  #
# software distributed under the License is distributed on an "AS IS" BASIS,  #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.    #
# See the License for the specific language governing permissions             #
# and limitations under the License.                                          #
###############################################################################


from hashlib import sha256
from math import log2, ceil


class MerkleTree(object):
    """
    A simple and naive implementation of an immutable Merkle tree.
    """

    def __init__(self, data):
        assert isinstance(data, list)
        assert len(data) > 0, 'Cannot construct an empty Merkle Tree.'
        num_leaves = 2 ** ceil(log2(len(data)))
        self.data = data + [FieldElement(0)] * (num_leaves - len(data))
        self.height = int(log2(num_leaves))
        self.facts = {}
        self.root = self.build_tree()

    def get_authentication_path(self, leaf_id):
        assert 0 <= leaf_id < len(self.data)
        node_id = leaf_id + len(self.data)
        cur = self.root
        decommitment = []
        # In a Merkle Tree, the path from the root to a leaf, corresponds to the the leaf id's
        # binary representation, starting from the second-MSB, where '0' means 'left', and '1' means
        # 'right'.
        # We therefore iterate over the bits of the binary representation - skipping the '0b'
        # prefix, as well as the MSB.
        for bit in bin(node_id)[3:]:
            cur, auth = self.facts[cur]
            if bit == '1':
                auth, cur = cur, auth
            decommitment.append(auth)
        return decommitment

    def build_tree(self):
        return self.recursive_build_tree(1)

    def recursive_build_tree(self, node_id):
        if node_id >= len(self.data):
            # A leaf.
            id_in_data = node_id - len(self.data)
            leaf_data = str(self.data[id_in_data])
            h = sha256(leaf_data.encode()).hexdigest()
            self.facts[h] = leaf_data
            return h
        else:
            # An internal node.
            left = self.recursive_build_tree(node_id * 2)
            right = self.recursive_build_tree(node_id * 2 + 1)
            h = sha256((left + right).encode()).hexdigest()
            self.facts[h] = (left, right)
            return h


def verify_decommitment(leaf_id, leaf_data, decommitment, root):
    leaf_num = 2 ** len(decommitment)
    node_id = leaf_id + leaf_num
    cur = sha256(str(leaf_data).encode()).hexdigest()
    for bit, auth in zip(bin(node_id)[3:][::-1], decommitment[::-1]):
        if bit == '0':
            h = cur + auth
        else:
            h = auth + cur
        cur = sha256(h.encode()).hexdigest()
    return cur == root

In [77]:
X = Polynomial.X()

In [78]:
g = FieldElement.generator() ** (3*2**26)
x = [g**i for i in range(len(trace['i']))]

100%|██████████| 14/14 [00:00<00:00, 8089.30it/s]


In [79]:
state_polynomial = interpolate_poly(x, [FieldElement(state_map[s]) for s in trace['s']])
caracter_polynomial = interpolate_poly(x, [FieldElement(caracter_map[c]) for c in trace['c']])

100%|██████████| 14/14 [00:00<00:00, 6804.20it/s]


In [119]:
# not sure if I should interpolate over points not already in the x array
transi_x = [state_polynomial(state_map[s])*caracter_polynomial(caracter_map[c]) for s,c in product(states, alphabet)]
transition_polynomial = interpolate_poly(transi_x, [FieldElement(state_map[transition[s][c]]) for (s,c) in product(states, alphabet)])

100%|██████████| 4/4 [00:00<00:00, 10161.85it/s]


In [90]:
step_polynomial = state_polynomial(g*X) - transition_polynomial(state_polynomial(X)*caracter_polynomial(X))

In [120]:
w = FieldElement.generator()
h = FieldElement.generator() ** (3*2**23)
H = [h**i for i in range(16*8)]
# here also, not sure about the domain, should be big enough for both states and transitions
eval_domain = [w*h for h in H]

In [122]:
state_eval = [state_polynomial(d) for d in eval_domain]
caracter_eval = [caracter_polynomial(d) for d in eval_domain]
transition_eval = [transition_polynomial(d) for d in eval_domain]
state_merkle = MerkleTree(state_eval)
caracter_merkle = MerkleTree(caracter_eval)
transition_merkle = MerkleTree(transition_eval)

In [123]:
channel = Channel()

In [124]:
channel.send(state_merkle.root)
channel.send(caracter_merkle.root)
channel.send(transition_merkle.root)

In [97]:
initial_constraint_p = (state_polynomial-state_map['q0'])/(X-1)
dp = Polynomial([FieldElement.one()])
ap = Polynomial([FieldElement.one()])
for i in range(len(x)):
  ap = ap * (caracter_polynomial - caracter_map[trace['c'][i]])
  dp = dp * (X-x[i])
caracter_constraint_p = ap/dp
ap = Polynomial([FieldElement.one()])
for f in accept_states:
  ap = ap * (state_polynomial - state_map[f])
accepting_constraint_p = ap/(X-x[-1])
dp = Polynomial([FieldElement.one()])
ap = Polynomial([FieldElement.one()])
i = 0
for s,c in product(states,alphabet):
  ap = ap * (transition_polynomial - state_map[transition[s][c]])
  dp = dp * (X-transi_x[i])
  i+=1
transition_constraint_p = ap/dp
dp = Polynomial([FieldElement.one()])
for i in range(len(trace['i'])):
  dp = dp * x[i]
step_constraint_p = step_polynomial/dp

In [125]:
p0 = initial_constraint_p
p1 = caracter_constraint_p
p2 = accepting_constraint_p
p3 = transition_constraint_p
p4 = step_constraint_p

a0 = channel.receive_random_field_element()
a1 = channel.receive_random_field_element()
a2 = channel.receive_random_field_element()
a3 = channel.receive_random_field_element()
a4 = channel.receive_random_field_element()

cp = a0*p0 + a1*p1 + a2*p2 + a3*p3 + a4*p4

In [126]:
cp_eval = [cp(x) for x in eval_domain]
cp_merkle = MerkleTree(cp_eval)
channel.send(cp_merkle.root)

In [106]:
def fold_domain(domain):
  return [x**2 for x in domain[:len(domain)//2]]

In [107]:
def fold(p, beta):
  odd = p.poly[1::2]
  even = p.poly[::2]
  o = Polynomial(odd)
  e = Polynomial(even)
  return e + beta*o

In [135]:
def next_fri_layer(poly, domain, beta):
    next_poly = fold(poly, beta)
    next_domain = fold_domain(domain)
    print(next_poly.degree())
    next_layer = [next_poly(x) for x in next_domain]
    return next_poly, next_domain, next_layer

In [133]:
def FriCommit(cp, domain, cp_eval, cp_merkle, channel):
    fri_polys = [cp]
    fri_domains = [domain]
    fri_layers = [cp_eval]
    fri_merkles = [cp_merkle]
    while fri_polys[-1].degree() > 0:
        beta = channel.receive_random_field_element()
        next_poly, next_domain, next_layer = next_fri_layer(fri_polys[-1], fri_domains[-1], beta)
        fri_polys.append(next_poly)
        fri_domains.append(next_domain)
        fri_layers.append(next_layer)
        print("layer", next_layer)
        fri_merkles.append(MerkleTree(next_layer))
        channel.send(fri_merkles[-1].root)
    channel.send(str(fri_polys[-1].poly[0]))
    return fri_polys, fri_domains, fri_layers, fri_merkles

In [136]:
fri_polys, fri_domains, fri_layers, fri_merkles = FriCommit(cp, eval_domain,cp_eval,cp_merkle,channel)

84
layer [-445678153, 984159111, -1508123738, -1198096378, 30116167, -861776134, 578954687, 1177941614, 739531732, -856639300, -1210110518, 752851358, 1546672125, -261680807, -431551355, 178991872, 6459272, -110167191, 788364583, -1576320180, 349393331, 626661471, 1239262201, -803445603, -505141053, 1147211764, -1177514387, 1178002675, -1237804822, 448940114, 1083198243, 458092431, -454798951, -29687943, 825211632, 762199749, 1432988855, 1556055642, 798958081, -921756726, 1097276008, -388675182, -293365110, 219864621, 1370006043, -229473443, 815302057, -1349507661, 1498046425, -1430791521, -135420712, 283927077, 208749796, 894223898, -790944870, -269019055, -1573200084, 1524444056, 1478739913, -1544810060, 614194946, 1338477191, 1308063390, 1263217797]
42
layer [-317571566, -1510872841, 1456525763, -677924651, -461727817, -1508461661, -169449806, 1327187257, -79265166, 325058912, 196010124, -421230530, 702768604, -338034894, -694630714, -368104929, 553045718, -1405977228, 1297238674, -

AssertionError: Cannot construct an empty Merkle Tree.

In [128]:
def decommit_on_fri_layers(idx, channel):
    for layer, merkle in zip(fri_layers[:-1], fri_merkles[:-1]):
        length = len(layer)
        idx = idx % length
        sib_idx = (idx + length // 2) % length
        channel.send(str(layer[idx]))
        channel.send(str(merkle.get_authentication_path(idx)))
        channel.send(str(layer[sib_idx]))
        channel.send(str(merkle.get_authentication_path(sib_idx)))
    channel.send(str(fri_layers[-1][0]))