In [35]:
from compiler.program import Program
from setup import Setup
from prover import Prover
from test.mini_poseidon import rc, mds, poseidon_hash
from utils import *

def prover_test():
    """Prove a tiny two-gate circuit end to end and hand back the artefacts.

    The circuit declares ``e`` public and enforces ``c = a * b`` and
    ``e = c * d``; the witness below satisfies both.

    Returns:
        tuple: ``(setup, proof, group_order)`` so a verifier can reuse them.
    """
    print("Beginning prover test")

    # The evaluation domain has group_order = 2^n rows so roots of unity (and
    # therefore FFTs) exist. The SRS gets 4 * group_order powers so that it is
    # large enough to commit to the quotient polynomial, whose degree exceeds
    # group_order.
    group_order = 8
    setup = Setup.generate_srs(group_order * 4)

    constraints = ["e public", "c <== a * b", "e <== c * d"]
    witness = {"a": 3, "b": 4, "c": 12, "d": 5, "e": 60}

    proof = Prover(setup, Program(constraints, group_order)).prove(witness)
    print("Prover test success")
    return setup, proof, group_order

## Setup

In [36]:
from utils import *
import py_ecc.bn128 as b
from curve import ec_lincomb, G1Point, G2Point
from compiler.program import CommonPreprocessedInput
from verifier import VerificationKey
from dataclasses import dataclass
from poly import Polynomial, Basis

@dataclass
class Setup(object):
    """KZG structured reference string (SRS) plus commitment helpers.

    Holds the G1 powers ``[tau^i]_1`` used to commit to polynomials and the
    single G2 point ``[tau]_2`` used on the verifier side of pairing checks.
    """

    # https://github.com/sec-bit/learning-zkp/blob/develop/plonk-intro-cn/plonk-polycom.md#kzg10-%E6%9E%84%E9%80%A0
    #   ([1]₁, [x]₁, ..., [x^{d-1}]₁)
    # = ( G,    xG,  ...,  x^{d-1}G ), where G is a generator of G_1
    powers_of_x: list[G1Point]
    # [x]₂ = xH, where H is a generator of G_2
    X2: G2Point

    @classmethod
    def generate_srs(cls, powers: int, tau: int = 218313819403157342856071133):
        """Generate an SRS with ``powers`` G1 points derived from ``tau``.

        Args:
            powers: number of G1 powers ``[tau^0]_1 .. [tau^(powers-1)]_1``.
                This bounds how many coefficients ``commit`` accepts. Must be
                at least 2, because the consistency checks use ``[tau]_1``.
            tau: the trapdoor scalar. The default is a fixed, publicly visible
                constant, which is only suitable for this demo. NOTE: a real
                setup must sample tau randomly and keep it secret.

        Returns:
            A new ``Setup`` instance.

        Raises:
            ValueError: if ``powers < 2``.
        """
        if powers < 2:
            raise ValueError("powers must be at least 2, got {}".format(powers))
        print("Start to generate structured reference string")

        # powers_of_x[i] = b.G1 * tau**i, built incrementally as
        # powers_of_x[i] = powers_of_x[i - 1] * tau
        powers_of_x = [b.G1]
        for _ in range(1, powers):
            powers_of_x.append(b.multiply(powers_of_x[-1], tau))

        assert b.is_on_curve(powers_of_x[1], b.b)
        print("Generated G1 side, X^1 point: {}".format(powers_of_x[1]))

        X2 = b.multiply(b.G2, tau)
        assert b.is_on_curve(X2, b.b2)
        print("Generated G2 side, X^1 point: {}".format(X2))

        # e(G2, tau*G1) must equal e(tau*G2, G1): both sides used the same tau.
        assert b.pairing(b.G2, powers_of_x[1]) == b.pairing(X2, b.G1)
        print("X^1 points checked consistent")
        print("Finished to generate structured reference string")

        return cls(powers_of_x, X2)

    def commit(self, values: Polynomial) -> G1Point:
        """Return the KZG commitment ``sum_i coeff_i * [tau^i]_1`` of ``values``.

        Args:
            values: polynomial in Lagrange basis (converted with an inverse
                FFT first) or in monomial basis (used as-is).

        Raises:
            ValueError: if ``values`` is in any other basis.
            Exception: if the polynomial has more coefficients than the SRS
                has powers.
        """
        if values.basis == Basis.LAGRANGE:
            # inverse FFT from Lagrange basis to monomial basis
            coeffs = values.ifft().values
        elif values.basis == Basis.MONOMIAL:
            coeffs = values.values
        else:
            # Previously fell through and failed with UnboundLocalError.
            raise ValueError("Unsupported polynomial basis: {}".format(values.basis))
        if len(coeffs) > len(self.powers_of_x):
            raise Exception("Not enough powers in setup")
        return ec_lincomb(list(zip(self.powers_of_x, coeffs)))

    def verification_key(self, pk: CommonPreprocessedInput) -> VerificationKey:
        """Generate the verification key for this program with the given setup.

        The key holds commitments to every preprocessed polynomial, ``[tau]_2``
        and the group's root of unity.
        """
        return VerificationKey(
            pk.group_order,
            self.commit(pk.QM),
            self.commit(pk.QL),
            self.commit(pk.QR),
            self.commit(pk.QO),
            self.commit(pk.QC),
            self.commit(pk.S1),
            self.commit(pk.S2),
            self.commit(pk.S3),
            self.X2,
            Scalar.root_of_unity(pk.group_order),
        )


In [None]:
    # powers should be 2^n so that we can use roots of unity for FFT
    # and should be bigger than len(coeffs) of polynomial to do KZG commitment
    # the value here is: powers = 4 * group_order
    # which is bigger than the order of quotient polynomial
    group_order = 8
    powers = group_order * 4
    setup = Setup.generate_srs(powers)


Start to generate structured reference string
Generated G1 side, X^1 point: (13294353531659665076299264371299131321133377949180224052095139292042656767801, 7244526365924412580786759495774941482824109386590049888405102905649868841718)
Generated G2 side, X^1 point: ((19152636372783811233630472865897092822704646270638236258574538719932990329615, 5981775420279756813368727653284174372010905976217595528558124345993479956855), (15799937923252396087061091029963171915696870100807527253118347696384412331186, 9069635832515501441369801349489757510092424196796670011995129268707448587614))


In [8]:
from utils import *
from enum import Enum
from dataclasses import dataclass


class Column(Enum):
    """The three wire columns of a gate, ordered LEFT < RIGHT < OUTPUT."""

    LEFT = 1
    RIGHT = 2
    OUTPUT = 3

    def __lt__(self, other):
        # Only Column-vs-Column comparisons are defined; defer otherwise.
        if other.__class__ is not self.__class__:
            return NotImplemented
        return self.value < other.value

    @staticmethod
    def variants():
        """All columns in declaration order: LEFT, RIGHT, OUTPUT."""
        return list(Column)


@dataclass
class Cell:
    """One (column, row) slot of the execution trace.

    Cells order by row first, then by column number. The dataclass-generated
    ``__eq__`` and the explicit ``__hash__`` both depend only on (column, row).
    """

    column: Column
    row: int

    def __key(self):
        # Ordering/hash key: row first, then the column's numeric value.
        return self.row, self.column.value

    def __hash__(self):
        return hash(self.__key())

    def __lt__(self, other):
        if other.__class__ is not self.__class__:
            return NotImplemented
        return self.__key() < other.__key()

    def __repr__(self) -> str:
        return "({}, {})".format(self.row, self.column.value)

    # str() and repr() render the same "(row, column)" text.
    __str__ = __repr__

    def label(self, group_order: int) -> Scalar:
        """Field element labelling this cell.

        It is the ``row``-th entry of ``Scalar.roots_of_unity(group_order)``
        multiplied by the column number (1 = left, 2 = right, 3 = output).
        """
        assert self.row < group_order
        root_for_row = Scalar.roots_of_unity(group_order)[self.row]
        return root_for_row * self.column.value


# Gets the key to use in the coeffs dictionary for the term for key1*key2,
# where key1 and key2 can be constant(''), a variable, or product keys
# Note that degrees higher than 2 are disallowed in the compiler, but we
# still allow them in the parser in case we find a way to compile them later
def get_product_key(key1, key2):
    """Canonical coefficient key for the product ``key1 * key2``.

    Factors are split on "*", sorted, and empty factors (from the constant ''
    or None) dropped, e.g. ("b", "a") -> "a*b" and ("", "a") -> "a".
    """
    factors = []
    for key in (key1, key2):
        factors.extend((key or "").split("*"))
    factors.sort()
    return "*".join(filter(None, factors))


def is_valid_variable_name(name: str) -> bool:
    """True if ``name`` is non-empty, purely alphanumeric and not digit-led.

    Underscores and other punctuation are rejected (``str.isalnum``).
    """
    if len(name) == 0:
        return False
    if not name.isalnum():
        return False
    return name[0] not in "0123456789"


In [14]:

from typing import Optional
from dataclasses import dataclass


@dataclass
class GateWires:
    """Variable names for the Left, Right and Output wires (None = unused)."""

    L: Optional[str]
    R: Optional[str]
    O: Optional[str]

    def as_list(self) -> list[Optional[str]]:
        # Always in L, R, O order.
        return list((self.L, self.R, self.O))


@dataclass
class Gate:
    """Selector values for one row of the gate polynomial.

    The prover checks q_L*a + q_R*b + q_M*a*b + q_O*c + q_C + PI = 0 on every
    row. Fields are constructed positionally in the order L, R, M, O, C.
    """

    # q_L: left-wire selector
    L: Scalar
    # q_R: right-wire selector
    R: Scalar
    # q_M: multiplication (left * right) selector
    M: Scalar
    # q_O: output-wire selector
    O: Scalar
    # q_C: constant selector
    C: Scalar


@dataclass
class AssemblyEqn:
    """One compiled constraint: its gate wires plus a term -> coefficient map.

    Coefficient keys are variable names, product keys such as "a*b", "" for
    the constant term, and the markers "$output_coeff" / "$public".
    """

    wires: GateWires
    coeffs: dict[Optional[str], int]

    def L(self) -> Scalar:
        """Left selector: the negated coefficient of the left wire's variable."""
        coefficient = self.coeffs.get(self.wires.L, 0)
        return Scalar(-coefficient)

    def R(self) -> Scalar:
        """Right selector; zero when both inputs are the same variable."""
        # For a repeated input (L == R) the variable's linear coefficient is
        # already used by L(); repeating it here would count it twice.
        if self.wires.R == self.wires.L:
            return Scalar(0)
        coefficient = self.coeffs.get(self.wires.R, 0)
        return Scalar(-coefficient)

    def C(self) -> Scalar:
        """Constant selector: the negated constant term."""
        return Scalar(-self.coeffs.get("", 0))

    def O(self) -> Scalar:
        """Output selector: the "$output_coeff" entry, defaulting to 1."""
        return Scalar(self.coeffs.get("$output_coeff", 1))

    def M(self) -> Scalar:
        """Multiplication selector: negated coefficient of the L*R product.

        Zero unless all three wires (L, R and O) are set.
        """
        wire_names = self.wires.as_list()
        if None in wire_names:
            return Scalar(0)
        product_key = get_product_key(self.wires.L, self.wires.R)
        return Scalar(-self.coeffs.get(product_key, 0))

    def gate(self) -> Gate:
        """Bundle the five selector values into a Gate."""
        return Gate(L=self.L(), R=self.R(), M=self.M(), O=self.O(), C=self.C())


# Converts an arithmetic expression containing numbers, variables and {+, -, *}
# into a mapping of term to coefficient
#
# For example:
# ['a', '+', 'b', '*', 'c', '*', '5'] becomes {'a': 1, 'b*c': 5}
#
# This is a recursive algorithm: it splits the token list at the
# lowest-precedence operator and evaluates each side separately.
#
def evaluate(exprs: list[str], first_is_negative=False) -> dict[Optional[str], int]:
    """Evaluate a token list into a {term: coefficient} map.

    Args:
        exprs: whitespace-separated tokens of the expression.
        first_is_negative: negate the first term (used for the right-hand side
            of a subtraction).

    Returns:
        Mapping from term key ("" for constants, "a", "a*b", ...) to its
        integer coefficient.

    Raises:
        Exception: on an empty (sub-)expression such as a dangling operator
            ("a +", "b * * c"), on two operands with no operator between
            them, or on a token that is neither a number nor a valid variable
            name.
    """
    if not exprs:
        raise Exception("Empty sub-expression (dangling operator?)")
    # Splits by + and - first, then *, to follow order of operations
    # The first_is_negative flag helps us correctly interpret expressions
    # like 6000 - 700 - 80 + 9 (that's 5229)
    if "+" in exprs:
        split = exprs.index("+")
        L = evaluate(exprs[:split], first_is_negative)
        R = evaluate(exprs[split + 1 :], False)
        return {x: L.get(x, 0) + R.get(x, 0) for x in set(L.keys()).union(R.keys())}
    elif "-" in exprs:
        split = exprs.index("-")
        L = evaluate(exprs[:split], first_is_negative)
        R = evaluate(exprs[split + 1 :], True)
        return {x: L.get(x, 0) + R.get(x, 0) for x in set(L.keys()).union(R.keys())}
    elif "*" in exprs:
        split = exprs.index("*")
        # The sign belongs to the product as a whole, so it is applied to the
        # left factor only; applying it to both factors would cancel it for
        # two-factor products such as "x - a * b".
        L = evaluate(exprs[:split], first_is_negative)
        R = evaluate(exprs[split + 1 :], False)
        o = {}
        for k1, v1 in L.items():
            for k2, v2 in R.items():
                key = get_product_key(k1, k2)
                o[key] = o.get(key, 0) + v1 * v2
        return o
    elif len(exprs) > 1:
        raise Exception("No ops, expected sub-expr to be a unit: {}".format(exprs[1]))
    elif exprs[0].startswith("-"):
        return evaluate([exprs[0][1:]], not first_is_negative)
    elif exprs[0].isnumeric():
        return {"": int(exprs[0]) * (-1 if first_is_negative else 1)}
    elif is_valid_variable_name(exprs[0]):
        return {exprs[0]: -1 if first_is_negative else 1}
    else:
        raise Exception("ok wtf is {}".format(exprs[0]))


# Converts an equation to a mapping of term to coefficient, and verifies that
# the operations in the equation are valid.
#
# Also outputs a triple containing the L and R input variables and the output
# variable
#
# Think of the list of (variable triples, coeffs) pairs as this language's
# version of "assembly"
#
# Example valid equations, and output:
# a === 9                      ([None, None, 'a'], {'': 9})
# b <== a * c                  (['a', 'c', 'b'], {'a*c': 1})
# d <== a * c - 45 * a + 987   (['a', 'c', 'd'], {'a*c': 1, 'a': -45, '': 987})
#
# Example invalid equations:
# 7 === 7                      # Can't assign to non-variable
# a <== b * * c                # Two times signs in a row
# e <== a + b * c * d          # Multiplicative degree > 2
#
def eq_to_assembly(eq: str) -> AssemblyEqn:
    """Compile one constraint line into an AssemblyEqn.

    Accepted forms are ``<out> <== <expr>``, ``<out> === <expr>`` (optionally
    ``-<out>``) and ``<var> public``.

    Raises:
        Exception: for lines with fewer than two tokens, invalid output
            names, more than two input variables, disallowed products, or an
            unsupported operator.
    """
    # split() with no argument tolerates repeated spaces, tabs and trailing
    # whitespace/newlines; split(" ") produced empty tokens the parser rejects.
    tokens = eq.split()
    if len(tokens) < 2:
        raise Exception("Expected an equation with an operator, got: {!r}".format(eq))
    if tokens[1] in ("<==", "==="):
        # First token is the output variable
        out = tokens[0]
        # Convert the expression to coefficient map form
        coeffs = evaluate(tokens[2:])
        # Handle the "-x === a * b" case
        if out[0] == "-":
            out = out[1:]
            coeffs["$output_coeff"] = -1
        # Check out variable name validity
        if not is_valid_variable_name(out):
            raise Exception("Invalid out variable name: {}".format(out))
        # Gather list of variables used in the expression
        variables = []
        for t in tokens[2:]:
            var = t.lstrip("-")
            if is_valid_variable_name(var) and var not in variables:
                variables.append(var)
        # Construct the list of allowed coefficients
        allowed_coeffs = variables + ["", "$output_coeff"]
        if len(variables) == 0:
            pass
        elif len(variables) == 1:
            # A single variable may be squared: wire it to both L and R.
            variables.append(variables[0])
            allowed_coeffs.append(get_product_key(*variables))
        elif len(variables) == 2:
            allowed_coeffs.append(get_product_key(*variables))
        else:
            raise Exception("Max 2 variables, found {}".format(variables))
        # Check that only allowed coefficients are in the coefficient map
        for key in coeffs.keys():
            if key not in allowed_coeffs:
                raise Exception("Disallowed multiplication: {}".format(key))
        # Return output
        wires = variables + [None] * (2 - len(variables)) + [out]
        return AssemblyEqn(GateWires(wires[0], wires[1], wires[2]), coeffs)
    elif tokens[1] == "public":
        return AssemblyEqn(
            GateWires(tokens[0], None, None),
            {tokens[0]: -1, "$output_coeff": 0, "$public": True},
        )
    else:
        raise Exception("Unsupported op: {}".format(tokens[1]))


In [15]:
# A simple zk language, reverse-engineered to match https://zkrepl.dev/ output

from typing import Optional, Set
from poly import Polynomial, Basis


@dataclass
class CommonPreprocessedInput:
    """Common preprocessed input.

    Witness-independent polynomials of a circuit, shared by prover and
    verifier. Program.common_preprocessed_input builds all of them in Lagrange
    basis over ``group_order`` rows. Note the positional field order: QM comes
    before QL.
    """

    # number of rows of the evaluation domain
    group_order: int
    # q_M(X) multiplication selector polynomial
    QM: Polynomial
    # q_L(X) left selector polynomial
    QL: Polynomial
    # q_R(X) right selector polynomial
    QR: Polynomial
    # q_O(X) output selector polynomial
    QO: Polynomial
    # q_C(X) constants selector polynomial
    QC: Polynomial
    # S_σ1(X) first (left-column) permutation polynomial
    S1: Polynomial
    # S_σ2(X) second (right-column) permutation polynomial
    S2: Polynomial
    # S_σ3(X) third (output-column) permutation polynomial
    S3: Polynomial


class Program:
    """A compiled circuit: one AssemblyEqn per row of the execution trace.

    Rows ``len(constraints) .. group_order - 1`` are padding. All their
    selectors are zero and their wire cells belong to the unused (None)
    variable.
    """

    constraints: list[AssemblyEqn]
    group_order: int

    def __init__(self, constraints: list[str], group_order: int):
        """Compile ``constraints`` (one string per gate) for ``group_order`` rows.

        Raises:
            Exception: if there are more constraints than rows, or if a line
                cannot be compiled by ``eq_to_assembly``.
        """
        if len(constraints) > group_order:
            raise Exception("Group order too small")
        assembly = [eq_to_assembly(constraint) for constraint in constraints]
        self.constraints = assembly
        self.group_order = group_order

    def common_preprocessed_input(self) -> CommonPreprocessedInput:
        """Build the selector and permutation polynomials for prover/verifier."""
        L, R, M, O, C = self.make_gate_polynomials()
        S = self.make_s_polynomials()
        return CommonPreprocessedInput(
            self.group_order,
            M,
            L,
            R,
            O,
            C,
            S[Column.LEFT],
            S[Column.RIGHT],
            S[Column.OUTPUT],
        )

    @classmethod
    def from_str(cls, constraints: str, group_order: int):
        """Build a Program from newline-separated constraint text.

        Lines are stripped, and blank lines (e.g. from a trailing newline) are
        skipped because ``eq_to_assembly`` cannot parse an empty constraint.
        """
        lines = [line.strip() for line in constraints.split("\n")]
        return cls([line for line in lines if line], group_order)

    def coeffs(self) -> list[dict[Optional[str], int]]:
        """Coefficient maps of all constraints, in row order."""
        return [constraint.coeffs for constraint in self.constraints]

    def wires(self) -> list[GateWires]:
        """Wire assignments of all constraints, in row order."""
        return [constraint.wires for constraint in self.constraints]

    def make_s_polynomials(self) -> dict[Column, Polynomial]:
        """Build the copy-constraint permutation polynomials, one per column.

        Returns:
            Mapping Column -> permutation polynomial in Lagrange basis.
        """
        # For each variable, collect the (column, row) cells where it is used.
        # None stands for "no variable" (empty wire slots and padding rows).
        variable_uses: dict[Optional[str], Set[Cell]] = {None: set()}
        for row, constraint in enumerate(self.constraints):
            for column, value in zip(Column.variants(), constraint.wires.as_list()):
                variable_uses.setdefault(value, set()).add(Cell(column, row))

        # Mark unused cells
        for row in range(len(self.constraints), self.group_order):
            for column in Column.variants():
                variable_uses[None].add(Cell(column, row))

        # For each list of positions, rotate by one.
        #
        # For example, if some variable is used in positions
        # (LEFT, 4), (LEFT, 7) and (OUTPUT, 2), then we store:
        #
        # at S[LEFT][7] the field element representing (LEFT, 4)
        # at S[OUTPUT][2] the field element representing (LEFT, 7)
        # at S[LEFT][4] the field element representing (OUTPUT, 2)

        S_values = {
            column: [Scalar(0)] * self.group_order for column in Column.variants()
        }

        for uses in variable_uses.values():
            sorted_uses = sorted(uses)
            for i, cell in enumerate(sorted_uses):
                next_cell = sorted_uses[(i + 1) % len(sorted_uses)]
                S_values[next_cell.column][next_cell.row] = cell.label(self.group_order)

        return {
            column: Polynomial(S_values[column], Basis.LAGRANGE)
            for column in Column.variants()
        }

    def get_public_assignments(self) -> list[Optional[str]]:
        """Return the public variable names, in declaration order.

        Raises:
            Exception: if a ``public`` declaration follows any other
                constraint, or its coefficient map is not exactly the shape
                ``eq_to_assembly`` produces for ``<var> public``.
        """
        o = []
        no_more_allowed = False
        for coeff in self.coeffs():
            if coeff.get("$public", False) is True:
                if no_more_allowed:
                    raise Exception("Public var declarations must be at the top")
                var_name = [x for x in coeff.keys() if "$" not in str(x)][0]
                if coeff != {"$public": True, "$output_coeff": 0, var_name: -1}:
                    # Was `Exception("...{}", format(coeffs))`: a comma instead
                    # of `.format`, so the message was never formatted.
                    raise Exception("Malformatted coeffs: {}".format(coeff))
                o.append(var_name)
            else:
                no_more_allowed = True
        return o

    def make_gate_polynomials(
        self,
    ) -> tuple[Polynomial, Polynomial, Polynomial, Polynomial, Polynomial]:
        """Generate the gate polynomials L, R, M, O, C in Lagrange basis.

        Each has ``group_order`` values. Padding rows stay zero.
        """
        L = [Scalar(0) for _ in range(self.group_order)]
        R = [Scalar(0) for _ in range(self.group_order)]
        M = [Scalar(0) for _ in range(self.group_order)]
        O = [Scalar(0) for _ in range(self.group_order)]
        C = [Scalar(0) for _ in range(self.group_order)]
        for i, constraint in enumerate(self.constraints):
            gate = constraint.gate()
            L[i] = gate.L
            R[i] = gate.R
            M[i] = gate.M
            O[i] = gate.O
            C[i] = gate.C
        return (
            Polynomial(L, Basis.LAGRANGE),
            Polynomial(R, Basis.LAGRANGE),
            Polynomial(M, Basis.LAGRANGE),
            Polynomial(O, Basis.LAGRANGE),
            Polynomial(C, Basis.LAGRANGE),
        )

    def fill_variable_assignments(
        self, starting_assignments: dict[Optional[str], int]
    ) -> dict[Optional[str], int]:
        """"Run" the program to fill in intermediate variable assignments.

        E.g. if ``starting_assignments`` is {'a': 3, 'b': 5} and the first line
        is ``c <== a * b``, then ``c: 15`` is filled in. Only constraints with
        an output wire and an output coefficient of +1 or -1 are evaluated.

        Returns:
            All assignments as ints, including the key None mapped to 0.

        Raises:
            Exception: if a computed value disagrees with an existing one.
            KeyError: if a constraint's input is not assigned yet.
        """
        out = {k: Scalar(v) for k, v in starting_assignments.items()}
        out[None] = Scalar(0)
        for constraint in self.constraints:
            wires = constraint.wires
            coeffs = constraint.coeffs
            in_L = wires.L
            in_R = wires.R
            output = wires.O
            out_coeff = coeffs.get("$output_coeff", 1)
            product_key = get_product_key(in_L, in_R)
            if output is not None and out_coeff in (-1, 1):
                new_value = (
                    Scalar(
                        coeffs.get("", 0)
                        + out[in_L] * coeffs.get(in_L, 0)
                        + out[in_R] * coeffs.get(in_R, 0) * (1 if in_R != in_L else 0)
                        + out[in_L] * out[in_R] * coeffs.get(product_key, 0)
                    )
                    * out_coeff
                )  # should be / but equivalent for (1, -1)
                if output in out:
                    if out[output] != new_value:
                        raise Exception(
                            "Failed assertion: {} = {}".format(out[output], new_value)
                        )
                else:
                    out[output] = new_value
        return {k: v.n for k, v in out.items()}


In [16]:
# Compile the demo circuit: e is public, c = a * b, e = c * d.
program = Program(
    constraints=["e public", "c <== a * b", "e <== c * d"],
    group_order=group_order,
)


In [20]:
# Sanity check: the program was compiled for an 8-row evaluation domain.
assert program.group_order == 8

In [21]:
    assignments = {"a": 3, "b": 4, "c": 12, "d": 5, "e": 60}

In [23]:
from compiler.program import Program, CommonPreprocessedInput
from utils import *
from setup import *
from typing import Optional
from dataclasses import dataclass
from transcript import Transcript, Message1, Message2, Message3, Message4, Message5
from poly import Polynomial, Basis


@dataclass
class Proof:
    """The five prover messages that make up a proof."""

    msg_1: Message1
    msg_2: Message2
    msg_3: Message3
    msg_4: Message4
    msg_5: Message5

    def flatten(self):
        """Flatten all message fields into a single name -> value dict.

        Round 3 and round 5 both carry a field named ``W_t``. The ``"W_t"``
        key keeps the round-5 value, which is what it has always held, because
        the later assignment won. The round-3 quotient commitment is
        additionally exposed as ``"t_1"`` (named like ``a_1`` / ``z_1``) so it
        is no longer silently dropped.
        """
        proof = {
            "a_1": self.msg_1.a_1,
            "b_1": self.msg_1.b_1,
            "c_1": self.msg_1.c_1,
            "z_1": self.msg_2.z_1,
            # Same key position as before; value is the round-5 W_t.
            "W_t": self.msg_5.W_t,
        }
        for name in (
            "a_eval", "b_eval", "c_eval",
            "ql_eval", "qr_eval", "qm_eval", "qo_eval", "qc_eval",
            "s1_eval", "s2_eval", "s3_eval",
            "z_eval", "zw_eval", "t_eval",
        ):
            proof[name] = getattr(self.msg_4, name)
        for stem in (
            "a", "b", "c",
            "ql", "qr", "qm", "qo", "qc",
            "s1", "s2", "s3",
            "z", "zw",
        ):
            proof["W_" + stem] = getattr(self.msg_5, "W_" + stem)
            proof["W_" + stem + "_quot"] = getattr(self.msg_5, "W_" + stem + "_quot")
        proof["W_t_quot"] = self.msg_5.W_t_quot
        # Round-3 commitment to the quotient polynomial T.
        proof["t_1"] = self.msg_3.W_t
        return proof


@dataclass
class Prover:
    group_order: int
    setup: Setup
    program: Program
    pk: CommonPreprocessedInput

    def __init__(self, setup: Setup, program: Program):
        """Bind the prover to an SRS and a compiled program.

        The program's preprocessed selector and permutation polynomials are
        computed once here and reused by every round.
        """
        self.group_order = program.group_order
        self.setup, self.program = setup, program
        self.pk = program.common_preprocessed_input()

    def prove(self, witness: dict[Optional[str], int]) -> Proof:
        """Run the five prover rounds and return the resulting Proof.

        Messages of rounds 1-4 are absorbed by a Fiat-Shamir transcript, and
        the challenges it returns are stored on ``self`` for later rounds.

        Args:
            witness: full variable assignment; public variables are read from
                it to build the public-input polynomial ``self.PI``.
        """
        transcript = Transcript(b"plonk")

        # FIXME: Hash pk and PI into transcript
        # Public input polynomial: row i holds -witness[public_var_i] and the
        # remaining rows are zero.
        public_vars = self.program.get_public_assignments()
        pi_values = [Scalar(-witness[name]) for name in public_vars]
        pi_values.extend(Scalar(0) for _ in range(self.group_order - len(public_vars)))
        self.PI = Polynomial(pi_values, Basis.LAGRANGE)

        first = self.round_1(witness)
        self.beta, self.gamma = transcript.round_1(first)

        second = self.round_2()
        self.alpha, self.fft_cofactor = transcript.round_2(second)

        third = self.round_3()
        self.zeta = transcript.round_3(third)

        fourth = self.round_4()
        self.v = transcript.round_4(fourth)

        fifth = self.round_5()

        return Proof(first, second, third, fourth, fifth)

    def round_1(
        self,
        witness: dict[Optional[str], int],
    ) -> Message1:
        """Round 1: build the wire polynomials A, B, C and commit to them.

        Args:
            witness: value of every variable used on a wire; a missing
                variable raises KeyError. The caller's dict is not modified.

        Returns:
            Message1 with the commitments to A, B and C. The polynomials are
            kept on ``self`` for later rounds.
        """
        # https://github.com/sec-bit/learning-zkp/blob/develop/plonk-intro-cn/plonk-arithmetization.md
        program = self.program
        setup = self.setup
        group_order = self.group_order

        # Empty wire slots are keyed by None and carry 0. Build a new dict so
        # the caller's witness is not mutated as a side effect of proving.
        witness = {None: 0, **witness}

        # Compute wire assignments; padding rows stay 0.
        A_values = [Scalar(0) for _ in range(group_order)]
        B_values = [Scalar(0) for _ in range(group_order)]
        C_values = [Scalar(0) for _ in range(group_order)]

        for i, gate_wires in enumerate(program.wires()):
            A_values[i] = Scalar(witness[gate_wires.L])
            B_values[i] = Scalar(witness[gate_wires.R])
            C_values[i] = Scalar(witness[gate_wires.O])

        self.A = Polynomial(A_values, Basis.LAGRANGE)
        self.B = Polynomial(B_values, Basis.LAGRANGE)
        self.C = Polynomial(C_values, Basis.LAGRANGE)

        a_1 = setup.commit(self.A)
        b_1 = setup.commit(self.B)
        c_1 = setup.commit(self.C)

        # Sanity check that witness fulfils gate constraints
        assert (
            self.A * self.pk.QL
            + self.B * self.pk.QR
            + self.A * self.B * self.pk.QM
            + self.C * self.pk.QO
            + self.PI
            + self.pk.QC
            == Polynomial([Scalar(0)] * group_order, Basis.LAGRANGE)
        )

        return Message1(a_1, b_1, c_1)

    def round_2(self) -> Message2:
        """Round 2: build the permutation accumulator Z and commit to it.

        Z[0] = 1 and Z[i+1] = Z[i] * f_i / g_i, where
        f_i = rlc(A_i, w^i) * rlc(B_i, 2*w^i) * rlc(C_i, 3*w^i) uses each
        cell's own label, and g_i = rlc(A_i, S1_i) * rlc(B_i, S2_i) *
        rlc(C_i, S3_i) uses the permuted labels. Here w^i is the i-th entry of
        ``Scalar.roots_of_unity(group_order)``. The accumulator is required to
        wrap back to 1 after ``group_order`` rows.

        Returns:
            Message2 with the commitment to Z (Z itself is stored as ``self.Z``).
        """
        # https://github.com/sec-bit/learning-zkp/blob/develop/plonk-intro-cn/plonk-permutation.md#%E5%AE%8C%E6%95%B4%E7%9A%84%E7%BD%AE%E6%8D%A2%E5%8D%8F%E8%AE%AE
        group_order = self.group_order
        setup = self.setup

        Z_values = [Scalar(1)]
        roots_of_unity = Scalar.roots_of_unity(group_order)
        for i in range(group_order):
            Z_values.append(
                Z_values[-1]
                * self.rlc(self.A.values[i], roots_of_unity[i])
                * self.rlc(self.B.values[i], 2 * roots_of_unity[i])
                * self.rlc(self.C.values[i], 3 * roots_of_unity[i])
                / self.rlc(self.A.values[i], self.pk.S1.values[i])
                / self.rlc(self.B.values[i], self.pk.S2.values[i])
                / self.rlc(self.C.values[i], self.pk.S3.values[i])
            )
        # The last value is 1. Pop outside the assert: assert statements are
        # stripped under `python -O`, which would otherwise leave
        # group_order + 1 values in Z_values.
        last_value = Z_values.pop()
        assert last_value == 1

        # Sanity-check that Z was computed correctly
        for i in range(group_order):
            assert (
                self.rlc(self.A.values[i], roots_of_unity[i])
                * self.rlc(self.B.values[i], 2 * roots_of_unity[i])
                * self.rlc(self.C.values[i], 3 * roots_of_unity[i])
            ) * Z_values[i] - (
                self.rlc(self.A.values[i], self.pk.S1.values[i])
                * self.rlc(self.B.values[i], self.pk.S2.values[i])
                * self.rlc(self.C.values[i], self.pk.S3.values[i])
            ) * Z_values[
                (i + 1) % group_order
            ] == 0

        Z = Polynomial(Z_values, Basis.LAGRANGE)
        z_1 = setup.commit(Z)
        print("Permutation accumulator polynomial successfully generated")

        self.Z = Z
        return Message2(z_1)

    def round_3(self) -> Message3:
        """Round 3: compute the quotient polynomial T and commit to it.

        Witness, permutation, selector, Z and public-input polynomials are
        converted from Lagrange to coefficient form, then combined as

            gate + alpha * permutation_grand_product + alpha**2 * first_row

        and divided by the vanishing polynomial Z_H(X) = X^group_order - 1.
        The coefficient forms (and T) are stored on ``self`` for rounds 4 and 5.

        NOTE(review): whether ``Polynomial.__truediv__`` rejects a non-zero
        remainder is not visible here; if it does not, an unsatisfied
        constraint would yield a wrong T rather than an error.

        Returns:
            Message3 holding the commitment ``W_t`` to T.
        """
        # https://github.com/sec-bit/learning-zkp/blob/develop/plonk-intro-cn/plonk-constraints.md
        group_order = self.group_order
        setup = self.setup

        # Compute the quotient polynomial

        alpha = self.alpha

        roots_of_unity = Scalar.roots_of_unity(group_order)

        # Interpolate every Lagrange-basis polynomial into coefficient form.
        A_coeff, B_coeff, C_coeff, S1_coeff, S2_coeff, S3_coeff, Z_coeff, QL_coeff, QR_coeff, QM_coeff, QO_coeff, QC_coeff, PI_coeff = (
            x.ifft()
            for x in (
                self.A,
                self.B,
                self.C,
                self.pk.S1,
                self.pk.S2,
                self.pk.S3,
                self.Z,
                self.pk.QL,
                self.pk.QR,
                self.pk.QM,
                self.pk.QO,
                self.pk.QC,
                self.PI,
            )
        )

        # L_0: the polynomial that is 1 on row 0 and 0 on every other row.
        L0_coeff = (
            Polynomial([Scalar(1)] + [Scalar(0)] * (group_order - 1), Basis.LAGRANGE)
        ).ifft()

        # x^8 - 1 coeffs are [-1, 0, 0, 0, 0, 0, 0, 0, 1]
        # which needs 9 points(n + 1) to determine the polynomial
        # (example is for group_order = 8; in general this is X^group_order - 1,
        # built directly in monomial basis since it has group_order + 1 coeffs)
        ZH_array = [Scalar(-1)] + [Scalar(0)] * (group_order - 1) + [Scalar(1)]
        ZH_coeff = Polynomial(ZH_array, Basis.MONOMIAL)

        # Gate constraint: QL*a + QR*b + QM*a*b + QO*c + PI + QC.
        gate_constraints_coeff = (
            A_coeff * QL_coeff
            + B_coeff * QR_coeff
            + A_coeff * B_coeff * QM_coeff
            + C_coeff * QO_coeff
            + PI_coeff
            + QC_coeff
        )

        # normal_roots takes the value roots_of_unity[i] at roots_of_unity[i],
        # so its coefficient form roots_coeff is the identity polynomial X.
        normal_roots = Polynomial(
            roots_of_unity, Basis.LAGRANGE
        )

        roots_coeff = normal_roots.ifft()
        # z * w
        # NOTE(review): shift(1) is assumed to give ZW.values[i] ==
        # Z.values[(i + 1) % group_order]; the assertion loop below checks the
        # round-2 accumulator relation against it.
        ZW = self.Z.shift(1)
        ZW_coeff = ZW.ifft()

        # Row-by-row sanity check of the permutation relation in Lagrange form
        # (i < group_order, so `i % group_order` is simply i).
        for i in range(group_order):
            assert (
                self.rlc(self.A.values[i], roots_of_unity[i])
                * self.rlc(self.B.values[i], 2 * roots_of_unity[i])
                * self.rlc(self.C.values[i], 3 * roots_of_unity[i])
            ) * self.Z.values[i] - (
                self.rlc(self.A.values[i], self.pk.S1.values[i])
                * self.rlc(self.B.values[i], self.pk.S2.values[i])
                * self.rlc(self.C.values[i], self.pk.S3.values[i])
            ) * ZW.values[
                i % group_order
            ] == 0

        # Grand-product constraint in coefficient form: the per-row labels
        # roots_of_unity[i] * {1, 2, 3} become X, 2X and 3X.
        permutation_grand_product_coeff = (
            (
                self.rlc(A_coeff, roots_coeff)
                * self.rlc(B_coeff, roots_coeff * Scalar(2))
                * self.rlc(C_coeff, roots_coeff * Scalar(3))
            )
            * Z_coeff
            - (
                self.rlc(A_coeff, S1_coeff)
                * self.rlc(B_coeff, S2_coeff)
                * self.rlc(C_coeff, S3_coeff)
            )
            * ZW_coeff
        )

        # First-row constraint: (Z - 1) * L_0 is zero on every row exactly
        # when Z's row-0 value is 1.
        permutation_first_row_coeff = (Z_coeff - Scalar(1)) * L0_coeff

        # Random linear combination of the three constraints with powers of alpha.
        all_constraints = (
            gate_constraints_coeff
            + permutation_grand_product_coeff * alpha
            + permutation_first_row_coeff * alpha**2
        )

        # quotient polynomial
        T_coeff = all_constraints / ZH_coeff

        print("Generated the quotient polynomial")

        W_t = setup.commit(T_coeff)

        # Keep coefficient forms for evaluation (round 4) and opening (round 5).
        self.A_coeff = A_coeff
        self.B_coeff = B_coeff
        self.C_coeff = C_coeff
        self.S1_coeff = S1_coeff
        self.S2_coeff = S2_coeff
        self.S3_coeff = S3_coeff
        self.Z_coeff = Z_coeff
        self.ZW_coeff = ZW_coeff
        self.QL_coeff = QL_coeff
        self.QR_coeff = QR_coeff
        self.QM_coeff = QM_coeff
        self.QO_coeff = QO_coeff
        self.QC_coeff = QC_coeff
        self.PI_coeff = PI_coeff
        self.T_coeff = T_coeff

        return Message3(W_t)

    def round_4(self) -> Message4:
        """Round 4: evaluate the committed polynomials at the challenge zeta.

        Z is additionally evaluated at ``zeta * root_of_unity`` (``zw_eval``).
        Every evaluation is also stored on ``self`` (``self.a_eval`` etc.) for
        round 5.

        Returns:
            Message4 with the evaluations in its positional field order.
        """
        # https://github.com/sec-bit/learning-zkp/blob/develop/plonk-intro-cn/plonk-constraints.md
        zeta = self.zeta
        shifted_zeta = zeta * Scalar.root_of_unity(self.group_order)

        # Attribute name -> (polynomial, point), ordered exactly like Message4.
        openings = {
            "a_eval": (self.A_coeff, zeta),
            "b_eval": (self.B_coeff, zeta),
            "c_eval": (self.C_coeff, zeta),
            "ql_eval": (self.QL_coeff, zeta),
            "qr_eval": (self.QR_coeff, zeta),
            "qm_eval": (self.QM_coeff, zeta),
            "qo_eval": (self.QO_coeff, zeta),
            "qc_eval": (self.QC_coeff, zeta),
            "s1_eval": (self.S1_coeff, zeta),
            "s2_eval": (self.S2_coeff, zeta),
            "s3_eval": (self.S3_coeff, zeta),
            "z_eval": (self.Z_coeff, zeta),
            "zw_eval": (self.Z_coeff, shifted_zeta),
            "t_eval": (self.T_coeff, zeta),
        }

        evaluations = []
        for attribute, (polynomial, point) in openings.items():
            value = polynomial.coeff_eval(point)
            setattr(self, attribute, value)
            evaluations.append(value)

        return Message4(*evaluations)

    def round_5(self) -> Message5:
        """Round 5: build KZG opening witnesses for every round-4 evaluation.

        ``generate_commitment`` is called once for each (polynomial, claimed
        evaluation) pair. It returns a commitment to the polynomial and a
        commitment to ``(poly - eval) / (X - zeta)``. These are passed to
        ``Message5`` positionally, as (commitment, quotient commitment) pairs,
        in the order listed below.
        """
        # Same order as the Message5 fields.
        # NOTE(review): ZW_coeff is opened at zeta against zw_eval, and round_4
        # computes zw_eval as Z(zeta * omega). This presumes ZW_coeff is Z with
        # its argument scaled by omega. Confirm in round_3.
        claims = [
            (self.A_coeff, self.a_eval),
            (self.B_coeff, self.b_eval),
            (self.C_coeff, self.c_eval),
            (self.QL_coeff, self.ql_eval),
            (self.QR_coeff, self.qr_eval),
            (self.QM_coeff, self.qm_eval),
            (self.QO_coeff, self.qo_eval),
            (self.QC_coeff, self.qc_eval),
            (self.S1_coeff, self.s1_eval),
            (self.S2_coeff, self.s2_eval),
            (self.S3_coeff, self.s3_eval),
            (self.Z_coeff, self.z_eval),
            (self.ZW_coeff, self.zw_eval),
            (self.T_coeff, self.t_eval),
        ]

        witnesses = []
        for poly, claimed_value in claims:
            commitment, quotient_commitment = self.generate_commitment(poly, claimed_value)
            witnesses += [commitment, quotient_commitment]

        print("Generated final quotient witness polynomials")
        return Message5(*witnesses)

    def rlc(self, term_1, term_2):
        """Return ``term_1 + term_2 * self.beta + self.gamma``.

        ``self.beta`` and ``self.gamma`` are presumably the prover's
        permutation challenges. ``term_2`` stays on the left of the product,
        so a polynomial/scalar operand dispatches through ``term_2.__mul__``.
        """
        scaled = term_2 * self.beta
        return term_1 + scaled + self.gamma

    def generate_commitment(self, coeff: Polynomial, eval: Scalar):
        """Commit to ``coeff`` and to its opening quotient at ``self.zeta``.

        Args:
            coeff: the polynomial being opened (presumably in monomial /
                coefficient form).
            eval: the claimed value of ``coeff`` at ``self.zeta``.

        Returns:
            A tuple ``(commit(coeff), commit((coeff - eval) / (X - zeta)))``,
            both produced by ``self.setup.commit``.
        """
        # Linear divisor (X - zeta), coefficients listed constant term first.
        divisor = Polynomial([-self.zeta, Scalar(1)], Basis.MONOMIAL)
        # NOTE(review): the division is presumably exact only when
        # coeff(zeta) == eval. Otherwise the result depends on how
        # Polynomial division handles a remainder.
        quotient = (coeff - eval) / divisor

        srs = self.setup
        # Commit to the polynomial first, then to its quotient.
        return srs.commit(coeff), srs.commit(quotient)


In [25]:
# Display the Proof class repr (inspection only).
# NOTE(review): likely a leftover exploratory cell; consider removing it.
Proof

__main__.Proof

In [27]:
# Build a prover from the SRS `setup` and the circuit `program`.
# NOTE(review): the visible cell that defines `program` (In [29]) runs after
# this one, so this cell relies on kernel state from elsewhere. Confirm it
# survives Restart & Run All.
prover = Prover(setup, program)

In [28]:
# Run all prover rounds on the witness `assignments`; progress messages are printed.
# NOTE(review): `assignments` is only defined locally inside prover_test in the
# visible cells. Presumably it comes from an earlier cell; confirm.
proof = prover.prove(assignments)


Permutation accumulator polynomial successfully generated
Generated the quotient polynomial
Generated final quotient witness polynomials


Verify

In [29]:
    # Circuit with public wire e and constraints c = a * b, e = c * d.
    # NOTE(review): `group_order` is only defined locally inside prover_test in
    # the visible cells. It must come from an earlier top-level cell; confirm.
    program = Program(["e public", "c <== a * b", "e <== c * d"], group_order)


In [31]:
    # Derive the verification key from the SRS and the program's common preprocessed input.
    vk = setup.verification_key(program.common_preprocessed_input())


In [33]:
    # Public inputs: the value of the public wire e (3 * 4 * 5 = 60 for the test witness).
    public = [60]


In [34]:
    # Verify the proof against the public inputs. The captured output shows a
    # line per KZG check. The assert fails the cell if verify_proof returns a
    # falsy value.
    assert vk.verify_proof(group_order, proof, public)


Done KZG10 commitment check for a_eval polynomial
Done KZG10 commitment check for b_eval polynomial
Done KZG10 commitment check for c_eval polynomial
Done KZG10 commitment check for z_eval polynomial
Done KZG10 commitment check for zw_eval polynomial
Done KZG10 commitment check for t_eval polynomial
Done KZG10 commitment check for ql_eval polynomial
Done KZG10 commitment check for qr_eval polynomial
Done KZG10 commitment check for qm_eval polynomial
Done KZG10 commitment check for qo_eval polynomial
Done KZG10 commitment check for qc_eval polynomial
Done KZG10 commitment check for s1_eval polynomial
Done KZG10 commitment check for s2_eval polynomial
Done KZG10 commitment check for s3_eval polynomial
Done equation check for all constraints
