# GKR for data-parallel binary circuits


Succinct arguments (often referred to as "ZK", a term the industry misuses) allow a prover to convince a verifier that a statement is true using an extremely short proof. A major bottleneck is the prover's overhead, which makes proving significantly slower than native execution. The slowdown usually comes from the following sources:
- The need for $N \log{N}$ algorithms, such as Number Theoretic Transforms (NTT) for polynomials, or expensive linear-size algorithms such as Multi-Scalar Multiplications (MSM) for elliptic curve operations.
- The use of large fields, such as the scalar fields of the BLS12-381 and BN254 curves.


Justin Thaler proposed using a Mersenne prime field in Section 6.2 of his 2013 [paper](https://eprint.iacr.org/2013/351.pdf). The idea is to replace a large field with a smaller field that has faster arithmetic. A later [paper](https://eprint.iacr.org/2019/1482.pdf) from 2019 by Zhang et al. proposed using the degree-2 extension of a Mersenne prime field to enable NTTs, which eventually led to the adoption of such fields in FRI-based polynomial commitment schemes.

However, none of these approaches works well for binary circuits. Representing a bit by an element of the Mersenne prime field $GF(2^{31} - 1)$ is inefficient, since it wastes 30 of the 31 bits of each field element. One could instead use lookup tables to speed up the computation, but each lookup costs $O(\log{N})$ field operations, which is still inefficient for binary circuits.

In this paper, we propose a new method to speed up proving binary circuits. We use the GKR protocol to prove the satisfiability of a data-parallel binary circuit. We show that, regardless of the field choice, the prover's cost for $N$ binary operations can be reduced to $\frac{N}{\log{N}}$ field operations, while the GKR verifier stays unchanged. This is purely an algorithmic improvement and requires no protocol-level changes. Note that a suitable polynomial commitment scheme for binary values is still needed to instantiate the GKR protocol.

## Our Improved Sumcheck Technique

Our main contribution is a new way of running sumchecks that significantly reduces the prover's work. It exploits the fact that certain values in our protocol are binary: during protocol execution we pack $\frac{\log(N)}{3}$ bits into a single field element. This packing is efficient because each gate has two inputs, so a gate sees at most $2^{2 \cdot \frac{\log N}{3}} = N^{2/3}$ distinct pairs of packed inputs, where $N$ is the circuit size.
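A minimal sketch of the packing, assuming little-endian order (bit $b$ of a segment goes to bit position $b$, which is also how the notebook's `compute_f_b` packs its inputs). `pack_bits` and `unpack_bits` are hypothetical helper names, and $N = 2^{24}$ is only an example size:

```python
import math

def pack_bits(bits):
    """Pack a list of 0/1 values into one integer; bit b goes to position b (little-endian)."""
    value = 0
    for b, bit in enumerate(bits):
        value |= (bit & 1) << b
    return value

def unpack_bits(value, B):
    """Inverse of pack_bits for a segment of B bits."""
    return [(value >> b) & 1 for b in range(B)]

N = 2**24                       # example circuit size
B = int(math.log2(N)) // 3      # bits per packed element: log(N) / 3 = 8
segment = [1, 0, 1, 1, 0, 0, 1, 0]
packed = pack_bits(segment)     # 1 + 4 + 8 + 64 = 77
num_pairs = 2 ** (2 * B)        # packed input pairs a two-input gate can see
print(B, packed, num_pairs, round(N ** (2 / 3)))  # 8 77 65536 65536
```

The last two numbers agree because $2^{2B} = 2^{2 \log N / 3} = N^{2/3}$.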

To exploit this, we build a lookup table of size $N^{2/3}$ holding the evaluations for every possible pair of packed inputs. The table depends on the verifier's random challenge for the bit variable, so it is built at the start of each layer's sumcheck rather than once before the whole protocol. During the sumcheck, we evaluate the sum segment by segment, each segment holding $\frac{\log(N)}{3}$ bits, and read each segment's contribution from the precomputed table instead of computing it on the fly.

This approach combines clever bit packing, precomputation, and segmented evaluation to significantly speed up the sumcheck process, reducing the overall computational burden on the prover.

## Choice of the field
The technique itself is field agnostic, but we recommend the following fields for performance or compatibility:
- Mersenne prime field: $GF(2^{31} - 1)$ for circuit descriptions and witnesses, with the extension field $GF((2^{31} - 1)^3)$ for random challenges.
- Binary field: $GF(2)$ for circuit descriptions and witnesses, with the extension field $GF(2^{128})$ for random challenges.
- BN254, for Ethereum compatibility.
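To illustrate why Mersenne primes give fast arithmetic, here is a sketch of the standard reduction modulo $p = 2^{31} - 1$ (the names `reduce_m31` and `mul_m31` are hypothetical). Because $2^{31} \equiv 1 \pmod p$, the high bits of a product can be folded onto the low bits with a shift, a mask and an addition instead of a division:

```python
P31 = (1 << 31) - 1  # the Mersenne prime 2^31 - 1

def reduce_m31(x):
    """Reduce 0 <= x < 2^62 modulo 2^31 - 1, using 2^31 = 1 (mod p)."""
    x = (x & P31) + (x >> 31)    # first fold: result < 2^32
    x = (x & P31) + (x >> 31)    # second fold: result <= p
    return 0 if x == P31 else x  # p itself represents 0

def mul_m31(a, b):
    """Multiply two elements of GF(2^31 - 1) given as integers in [0, p)."""
    return reduce_m31(a * b)
```

By contrast, in $GF(2)$ addition is XOR and multiplication is AND, so one machine word processes many field elements at once.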

```python
# For demonstration purposes, we use a Mersenne prime modulus:
mod = 2**127 - 1  # Mersenne prime for demo

# Arithmetic in such a large field usually involves at least 9 32-bit multiplications and even more additions.
# GF(2) and GF(2^128) are more recommended for practical use, especially for binary circuits.
```


## Preliminaries
### Computation model
In this paper, we consider the RAM model of computation with $\Theta(\log{N})$-bit machine words, where field operations do not have constant cost and the security parameter satisfies $\lambda = \omega(\log{N})$. A field multiplication costs $O((\frac{\lambda}{\log{N}})^2)$ and a field addition costs $O(\frac{\lambda}{\log{N}})$. We emphasize the computation model because we do not consider an algorithm that uses a "linear" number of field operations to be truly linear-time. Our ultimate goal in this paper is to achieve $O(\frac{N}{\log{N}})$ field operations, where $N$ is the boolean circuit size.
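To make the cost model concrete, this sketch counts word operations for one field operation, assuming $\log N$-bit words and schoolbook arithmetic (the function `field_op_cost` is hypothetical and ignores constant factors):

```python
import math

def field_op_cost(lam, log_n):
    """Word-level cost of one field op for a lam-bit field on log_n-bit words.

    Assumes schoolbook arithmetic: an element spans ceil(lam / log_n) words,
    addition is linear in that count and multiplication is quadratic.
    """
    words = math.ceil(lam / log_n)
    return {"words": words, "add": words, "mul": words * words}

# lambda = 128 on a machine with 32-bit words (log N = 32)
print(field_op_cost(128, 32))  # {'words': 4, 'add': 4, 'mul': 16}
```

So with $\lambda = 128$, one field multiplication already costs on the order of 16 word multiplications, which is why counting field operations alone understates the prover's cost.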

```python
security_parameter_lambda = 128  # concrete security parameter
```


### Bit-level data parallel Layered Circuit

We consider a layered circuit with $D$ layers, where each layer consists of $|C_i|$ bit gates. To capture the computation we use today, we assume that each layer is divided into $\frac{3|C_i|}{\log{N}}$ segments, where each segment has a length of $\frac{\log{N}}{3}$. Each segment performs the same computation. This structure captures the fact that most programming languages today are 32/64-bit integer-oriented, and the computation is usually done in 8-bit, 32-bit, or 64-bit chunks. 

We use $V_i(g)$ to denote the value of gate $g$ in layer $i$. The index $g$ is composed of two indices $(p, b)$, where $p$ is the index of the segment and $b$ is the index of the bit within the segment.
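A small sketch of the $(p, b)$ indexing, assuming the segment-major layout used by `Layer.evaluate` in this notebook, where bit $b$ of segment $p$ sits at flat position $p \cdot B + b$ (`flat_index` and `split_index` are hypothetical helpers):

```python
def flat_index(p, b, B):
    """Flat position of bit b of segment p (segment-major layout)."""
    assert 0 <= b < B
    return p * B + b

def split_index(g, B):
    """Inverse of flat_index: recover (segment p, bit b) from a flat gate index g."""
    return divmod(g, B)

B = 8
print(flat_index(3, 5, B), split_index(29, B))  # 29 (3, 5)
```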

```python
segment_size = 8  # concrete value of B = \frac{\log{N}}{3}
lg_number_of_segments = 20
number_of_segments = 2**lg_number_of_segments  # concrete value of \frac{3|C_i|}{\log{N}}
```

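```python
# Raise x to the power n modulo `mod` via square-and-multiply
# (equivalent to Python's built-in pow(x, n, mod); renamed so the built-in is not shadowed).
def pow_mod(x, n):
    result = 1
    x %= mod
    while n > 0:
        if n % 2 == 1:
            result = result * x % mod
        x = x * x % mod
        n //= 2
    return result

# Find the inverse of a nonzero x via Fermat's little theorem: x^(mod - 2) = x^(-1), since mod is prime
def inv(x):
    return pow_mod(x, mod - 2)
```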


### Modified Multi-linear extension
We use a modified multi-linear extension for the GKR protocol. Let $\tilde{p} = (\tilde{p}[0], \tilde{p}[1], \dots, \tilde{p}[l-1])$, where $2^l$ is the number of segments, and let $\tilde{b}$ be a univariate variable. The modified multi-linear extension of $V: \{0, 1\}^{l}\times \{0, \dots, \frac{\log{N}}{3} - 1\}\rightarrow \{0, 1\}$ is the function $\tilde{V}: \mathbb{F}^{l}\times \mathbb{F}\rightarrow \mathbb{F}$ defined as follows:

$$\tilde{V}(\tilde{p}, \tilde{b}) := \sum\limits_{p, b}V(p, b) eq\_p\_p(p, \tilde{p}) eq\_b\_b(b, \tilde{b})$$ 

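where $eq\_p\_p(p, \tilde{p}) := \prod\limits_{i=0}^{l-1}(p[i] \cdot \tilde{p}[i] + (1 - p[i]) \cdot (1 - \tilde{p}[i]))$, with $p[i]$ the $i$-th bit of $p$, and $eq\_b\_b(b, \tilde{b}) := \prod\limits_{i\neq b}\frac{\tilde{b} - i}{b - i}$, with $i$ ranging over $\{0, \dots, \frac{\log{N}}{3} - 1\}$, is the Lagrange basis polynomial for the point $b$: it equals $1$ at $\tilde{b} = b$ and $0$ at every other point of that range. Hence $\tilde{V}$ agrees with $V$ on $\{0, 1\}^{l} \times \{0, \dots, \frac{\log{N}}{3} - 1\}$, is multi-linear in $\tilde{p}$, and has degree $\frac{\log{N}}{3} - 1$ in $\tilde{b}$.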

For the exact formal definition, please refer to the following Python code.

```python
# Note: the eq functions and tilde_V below are written for definitional clarity,
# not for computational efficiency.

def eq_p_p(p, tilde_p):
    # p is a segment index (integer); its binary digits p[i] = (p >> i) & 1 are the boolean inputs.
    # tilde_p is a list of field elements (integers mod `mod`).
    prod = 1
    for i in range(len(tilde_p)):
        p_i = (p >> i) & 1
        prod = prod * (p_i * tilde_p[i] + (1 - p_i) * (1 - tilde_p[i])) % mod
    return prod

def eq_b_b(b, tilde_b):
    # Lagrange basis polynomial for the point b over {0, ..., segment_size - 1}, evaluated at tilde_b:
    # prod_{i != b} (tilde_b - i) / (b - i). It is 1 at tilde_b = b and 0 at the other integer points.
    numerator = 1
    denominator = 1
    for i in range(segment_size):
        if i != b:
            numerator = numerator * (tilde_b - i) % mod
            denominator = denominator * (b - i) % mod
    return numerator * inv(denominator) % mod

# The explicit definition of \tilde{V}
def tilde_V(V, tilde_p, tilde_b):
    result = 0
    for p in range(number_of_segments):
        for b in range(segment_size):
            # V[p][b] is a binary value
            result = (result + V[p][b] * eq_p_p(p, tilde_p) * eq_b_b(b, tilde_b)) % mod
    return result
```
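
The following self-contained check illustrates the defining property of the extension on a toy instance with 4 segments of $B = 3$ bits. The names `eq_p`, `eq_b` and `tilde_v` are local stand-ins, not the notebook's functions: with $eq\_b\_b$ taken as the Lagrange basis polynomial, $\tilde V$ reproduces $V$ at every grid point, and the basis polynomials sum to $1$ at any $\tilde b$.

```python
MOD = 2**127 - 1  # same Mersenne prime as the demo modulus

def eq_p(p, tilde_p):
    """Multilinear eq between the bits of integer p and the point tilde_p."""
    prod = 1
    for i, t in enumerate(tilde_p):
        bit = (p >> i) & 1
        prod = prod * (bit * t + (1 - bit) * (1 - t)) % MOD
    return prod

def eq_b(b, tilde_b, B):
    """Lagrange basis polynomial L_b over {0, ..., B-1}, evaluated at tilde_b."""
    num, den = 1, 1
    for i in range(B):
        if i != b:
            num = num * (tilde_b - i) % MOD
            den = den * (b - i) % MOD
    return num * pow(den, MOD - 2, MOD) % MOD

def tilde_v(V, tilde_p, tilde_b):
    """Extension of the bit table V[p][b] evaluated at (tilde_p, tilde_b)."""
    B = len(V[0])
    return sum(V[p][b] * eq_p(p, tilde_p) * eq_b(b, tilde_b, B)
               for p in range(len(V)) for b in range(B)) % MOD

# 4 segments (l = 2) of B = 3 bits each
V = [[1, 0, 1], [0, 0, 1], [1, 1, 0], [0, 1, 1]]
# On grid points the extension must reproduce V exactly
grid_ok = all(tilde_v(V, [p & 1, (p >> 1) & 1], b) == V[p][b]
              for p in range(4) for b in range(3))
print(grid_ok)  # True
```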



## Problem Description
We prove the satisfiability of a data-parallel binary circuit with GKR, for example many Keccak circuits running in parallel. We use $B = \frac{\log{N}}{3}$ to denote the number of parallel bits and $N$ to denote the size of each independent circuit, so the whole circuit has size $BN$.

Recall that the GKR protocol consists of several sumchecks that verify the correct evaluation of the layers, so the problem is essentially to run sumchecks for a data-parallel binary layer. The multilinear extensions of consecutive layers are related as follows:

$$
\tilde{V}_{i+1}(\tilde{p}, \tilde{b}) = \sum_{b,x,y} eq\_b\_b(b, \tilde{b}) \left(Mul(\tilde{p}, x, y) \tilde{V}_i(x, b) \tilde{V}_i(y, b) + Add(\tilde{p}, x, y)(\tilde{V}_i(x, b) + \tilde{V}_i(y, b))\right)
$$

Here, $b$ selects the bit within a segment and ranges from $0$ to $B - 1$. We call $Mul$ a wiring predicate; $Add$ is the analogous wiring predicate for addition gates. Both are defined in the code below, together with the circuit arithmetic:

```python
# Constants for gate types
mul_type = 0  # Multiplication gate
add_type = 1  # Addition gate

class Layer:
    """A single circuit layer; gate i reads segments input0 and input1 of the previous layer."""
    def __init__(self):
        self.gates = []   # List of [gate_type, input0, input1]
        self.input = []   # Flat input bits (segment-major: bit b of segment p at p * segment_size + b)
        self.output = []  # Flat output bits after evaluation

    def insert_gate(self, gate_type, input0, input1):
        """Adds a new gate to the layer."""
        self.gates.append([gate_type, input0, input1])

    def evaluate(self, input):
        """Evaluates all gates bit by bit, storing both the input and the output of the layer."""
        self.input = input
        result = []
        for ty, a, b in self.gates:
            for bit_index in range(segment_size):
                x = input[a * segment_size + bit_index]
                y = input[b * segment_size + bit_index]
                if ty == mul_type:
                    # Multiplication gate: multiply corresponding bits (AND on binary values)
                    result.append(x * y % mod)
                elif ty == add_type:
                    # Addition gate: add corresponding bits (equals XOR only in characteristic 2)
                    result.append((x + y) % mod)
        self.output = result
        return result

    def wiring_predicate_mul(self, tilde_p, tilde_x, tilde_y):
        """The Mul wiring predicate: sum over Mul gates (i, a, b) of eq(i, p~) eq(a, x~) eq(b, y~)."""
        result = 0
        for i, (ty, a, b) in enumerate(self.gates):
            if ty == mul_type:
                result = (result + eq_p_p(i, tilde_p) * eq_p_p(a, tilde_x) * eq_p_p(b, tilde_y)) % mod
        return result

    def wiring_predicate_add(self, tilde_p, tilde_x, tilde_y):
        """The Add wiring predicate, defined like Mul but over Add gates (omitted in the paper)."""
        result = 0
        for i, (ty, a, b) in enumerate(self.gates):
            if ty == add_type:
                result = (result + eq_p_p(i, tilde_p) * eq_p_p(a, tilde_x) * eq_p_p(b, tilde_y)) % mod
        return result

class LayeredArithmeticCircuit:
    """The entire layered circuit, evaluated from the input layer upward."""
    def __init__(self):
        self.layers = []        # All layers in the circuit, input side first
        self.layer_values = []  # Output values of each layer after evaluation

    def insert_layer(self, layer):
        """Adds a new layer to the circuit."""
        self.layers.append(layer)

    def evaluate(self, input):
        """Evaluates the entire circuit, recording each layer's output."""
        self.layer_values = []
        for layer in self.layers:
            input = layer.evaluate(input)
            self.layer_values.append(input)
        return input

# Uses eq_p_p, mod and segment_size from the earlier cells.
```
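
On binary inputs a data-parallel $Mul$ gate computes the AND of each bit pair, so on packed segments it amounts to a single bitwise AND. A minimal sketch, assuming little-endian packing and the segment-major layout of `Layer.evaluate` (`eval_mul_packed` and `pack` are hypothetical helpers):

```python
def eval_mul_packed(bits, a, b, B):
    """Evaluate one Mul gate on segments a and b of a flat bit list, bit by bit."""
    return [bits[a * B + k] * bits[b * B + k] for k in range(B)]

def pack(bits):
    """Little-endian packing: element k goes to bit position k."""
    return sum(bit << k for k, bit in enumerate(bits))

B = 4
layer_input = [1, 0, 1, 1,   0, 1, 1, 0]   # two segments of B = 4 bits
out_bits = eval_mul_packed(layer_input, 0, 1, B)
seg0, seg1 = pack(layer_input[:B]), pack(layer_input[B:])
print(out_bits, pack(out_bits) == seg0 & seg1)  # [0, 0, 1, 0] True
```

For $Add$ gates the packed analogue would be XOR, which coincides with field addition only in characteristic 2, e.g. over $GF(2)$.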

## GKR protocol
To evaluate $\tilde{V}$ at some random point $\tilde{b}$, we have:

$$
\tilde{V}(x, \tilde{b}) = \sum_{b\in \{0, \dots, B-1\}} eq\_b\_b(b, \tilde{b}) V(x, b)
$$

where $eq\_b\_b$ is the Lagrange basis polynomial over $\{0, \dots, B - 1\}$.

Observe that this evaluation is determined by the $B$ values $V(x, 0), \dots, V(x, B-1)$. Moreover, if $x$ lies on the binary hypercube, each $V(x, b)$ is an input value of this layer and is therefore binary. As a result, the array $[V(x, 0), V(x, 1), \dots, V(x, B-1)]$ takes at most $2^B$ distinct values.
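Because the array takes at most $2^B$ values, $\tilde V(x, \tilde b)$ can be read from a table of $2^B$ entries built once per challenge $\tilde b$. A self-contained sketch with demo parameters, assuming little-endian packing as in the notebook's `compute_f_b` (`lagrange_basis` and `build_bit_table` are hypothetical helpers); each entry reuses a smaller one, so the table costs one field addition per entry:

```python
MOD = 2**127 - 1  # demo Mersenne prime, as in the notebook

def lagrange_basis(b, t, B):
    """L_b(t) for the Lagrange basis over the points {0, ..., B-1}."""
    num, den = 1, 1
    for i in range(B):
        if i != b:
            num = num * (t - i) % MOD
            den = den * (b - i) % MOD
    return num * pow(den, MOD - 2, MOD) % MOD

def build_bit_table(tilde_b, B):
    """table[v] = sum of L_b(tilde_b) over the set bits b of v, for every v < 2^B."""
    basis = [lagrange_basis(b, tilde_b, B) for b in range(B)]
    table = [0] * (1 << B)
    for v in range(1, 1 << B):
        k = (v & -v).bit_length() - 1                     # index of the lowest set bit of v
        table[v] = (table[v & (v - 1)] + basis[k]) % MOD  # reuse v without that bit
    return table

B, tilde_b = 8, 987654321
table = build_bit_table(tilde_b, B)
v = 0b10110010  # packed segment with bits 1, 4, 5 and 7 set
direct = sum(lagrange_basis(b, tilde_b, B) for b in range(B) if (v >> b) & 1) % MOD
print(table[v] == direct)  # True
```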

Let's check how we utilize this in the GKR protocol.

The GKR protocol is an interactive protocol between a prover and a verifier on the circuit defined above. Let $d = len(layers)$ be the depth of the circuit, where layer 0 is the input layer and layer $d - 1$ is the output layer.

The first message is sent from the prover to the verifier, containing the whole output layer. In practice, the whole output layer is usually an all-zero array, so it can be a very simple message.

The protocol then works its way in iterations from the output layer to the input layer. Each iteration uses a sumcheck protocol to check the relationship between layer $i + 1$ and layer $i$, reducing a claim on layer $i + 1$ to a claim on layer $i$.


Concretely, the protocol starts with a claim on the output layer. The verifier could check the claim by evaluating the circuit herself, but that is not succinct, so instead we run a sumcheck on the following equation:

$$
\tilde{V}_{d-1}(\tilde{p}, \tilde{b}) = \sum_{b,x,y} eq\_b\_b(b, \tilde{b}) (Mul(\tilde{p}, x, y) \tilde{V}_{d-2}(x, b) \tilde{V}_{d-2}(y, b) + Add(\tilde{p}, x, y)(\tilde{V}_{d-2}(x, b) + \tilde{V}_{d-2}(y, b)))
$$

where $\tilde{V}_{d-1}(\tilde{p}, \tilde{b})$ is the claim on the output layer. The sumcheck protocol reduces this claim on $\tilde{V}_{d-1}(\tilde{p}, \tilde{b})$ to a claim on $\tilde{V}_{d-2}(\tilde{p}', \tilde{b}')$, where $\tilde{p}', \tilde{b}'$ is fresh randomness generated during the sumcheck.

We repeat this process until we reach the input layer; the prover then either sends the whole input layer to the verifier or opens the final claim using a polynomial commitment scheme.

## How to perform the sumcheck protocol efficiently?

For simplicity, we only consider the $Mul$ gate.

We would like to perform sumcheck on the variable $b$ first, i.e.
$$
\tilde{V}_{i+1}(\tilde{p}, \tilde{b}) = \sum_{x, y} Mul(\tilde{p}, x, y) \sum_b eq\_b\_b(b, \tilde{b})  \tilde{V}_i(x, b) \tilde{V}_i(y, b)
$$

Here, $\tilde{b}$ and $\tilde{p}$ are random queries sampled by the previous sumcheck protocol; they are now constant values.

### Observation
Let $g_{x, y}(b) = eq\_b\_b(b, \tilde{b}) \tilde{V}_i(x, b) \tilde{V}_i(y, b)$. Each of the three factors has degree $B - 1$ in $b$, so $g_{x, y}$ has degree $3(B-1)$, and for a fixed $\tilde{b}$ it is uniquely determined by the $2B$ values $\tilde{V}_i(x, b)$ and $\tilde{V}_i(y, b)$ for $0\le b < B$. Since we are dealing with a boolean circuit, these values form a pair of $B$-bit vectors, so there are only $2^{2B}$ possible combinations.

We can therefore precompute $g_{x, y}$ for all $2^{2B}$ combinations, after which obtaining $g_{x, y}(b)$ is just a table lookup. The precomputation is defined by the following code:

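```python
from functools import lru_cache

# A polynomial of degree 3(B - 1) is determined by 3B - 2 evaluations (B = segment_size).
num_points = 3 * (segment_size - 1) + 1

@lru_cache(maxsize=None)
def extension_weights(k, n):
    """Lagrange weights turning evaluations at 0..k-1 into evaluations at 0..n-1."""
    weights = []
    for t in range(n):
        row = []
        for j in range(k):
            numerator, denominator = 1, 1
            for i in range(k):
                if i != j:
                    numerator = numerator * (t - i) % mod
                    denominator = denominator * (j - i) % mod
            row.append(numerator * inv(denominator) % mod)
        weights.append(tuple(row))
    return tuple(weights)

class BPolynomial:
    """
    A polynomial in b, defined by its evaluations at the points 0, 1, 2, ...
    k stored evaluations describe a polynomial of degree < k; products are
    stored on num_points points, which is enough for degree 3(B - 1).
    """
    def __init__(self, evaluations=()):
        self.evaluations = [e % mod for e in evaluations]

    def insert_evaluation(self, evaluation):
        """Append the evaluation at the next integer point."""
        self.evaluations.append(evaluation % mod)

    def extended(self):
        """The same polynomial, evaluated at 0..num_points-1 by interpolation."""
        k = len(self.evaluations)
        if k == num_points:
            return self
        return BPolynomial(sum(w * e for w, e in zip(row, self.evaluations))
                           for row in extension_weights(k, num_points))

    def __mul__(self, other):
        """Pointwise product with another BPolynomial, or scaling by a field element."""
        if isinstance(other, BPolynomial):
            a, b = self.extended().evaluations, other.extended().evaluations
            return BPolynomial(x * y for x, y in zip(a, b))
        return BPolynomial(e * other for e in self.evaluations)

    def __add__(self, other):
        """Sum of two polynomials; an empty BPolynomial acts as zero."""
        if not self.evaluations:
            return other
        if not other.evaluations:
            return self
        a, b = self.extended().evaluations, other.extended().evaluations
        return BPolynomial(x + y for x, y in zip(a, b))

def pre_computation(tilde_b, B):
    """
    Perform pre-computation for the sumcheck optimization.

    Args:
    tilde_b: The random query point for b.
    B: The number of parallel bits (usually log(N)/3); must equal segment_size.

    Returns:
    A table whose entry vx * 2^B + vy is [g_{x,y}(b) polynomial, accumulator], where vx and vy
    are the packed (little-endian) bit segments V_i(x, .) and V_i(y, .).
    """
    # eq_b_b(b, tilde_b) as a polynomial in b, given by its values at b = 0..B-1
    eq_polynomial = BPolynomial(eq_b_b(b, tilde_b) for b in range(B)).extended()

    # Bit polynomials for every packed segment value v (bit b of v is the value at b)
    bit_polynomials = [BPolynomial((v >> b) & 1 for b in range(B)).extended() for v in range(2**B)]
    # eq_b_b(b, tilde_b) * V_i(x, b), computed once per vx instead of once per (vx, vy)
    eq_times_vx = [eq_polynomial * vx_polynomial for vx_polynomial in bit_polynomials]

    preprocessed_table = []
    for vx in range(2**B):
        for vy in range(2**B):
            # g_{x,y}(b) = eq_b_b(b, tilde_b) * V_i(x, b) * V_i(y, b)
            overall_polynomial = eq_times_vx[vx] * bit_polynomials[vy]
            # The second slot is an accumulator used by compute_f_b_optimized
            preprocessed_table.append([overall_polynomial, 0])

    return preprocessed_table
```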

Overall, the preprocessed table has $2^{2B} = 2^{16}$ entries for $B = 8$, and each entry takes at most 2 polynomial multiplications, so building it costs at most $2^{17}$ polynomial multiplications. Each entry is a polynomial of degree $3(B - 1)$, stored as at most $3B$ values of $128$ bits, so the whole table takes at most $2^{16} \cdot 128 \cdot 3B$ bits, i.e., 24 MiB of memory. This is small enough to fit into a large CPU L3 cache, which keeps the table accesses fast.
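The memory estimate can be reproduced with a few lines of arithmetic. This sketch follows the text in assuming 128-bit values and $3B$ stored values per entry, and also shows the slightly tighter figure for $3B - 2$ values, which is exactly enough for degree $3(B - 1)$:

```python
B = 8                         # bits per segment
entries = 2 ** (2 * B)        # one polynomial per (vx, vy) pair
bytes_per_value = 128 // 8    # 128-bit field elements

total_bytes = entries * (3 * B) * bytes_per_value       # 3B values per entry (upper bound)
tight_bytes = entries * (3 * B - 2) * bytes_per_value   # 3B - 2 values per entry
print(entries, total_bytes, total_bytes / 2**20, tight_bytes / 2**20)  # 65536 25165824 24.0 22.0
```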

For the first round of the sumcheck, we process the variable $b$. The prover computes the polynomial:

$$
f(b) =  \sum_{x, y} Mul(\tilde{p}, x, y) eq\_b\_b(b, \tilde{b}) \tilde{V}_i(x, b) \tilde{V}_i(y, b) = \sum_{x, y} Mul(\tilde{p}, x, y) g_{x, y}(b)
$$

The prover then sends $f$ to the verifier, who only needs to check:

$$\sum_{b=0}^{B-1}f(b) = \tilde{V}_{i+1}(\tilde{p}, \tilde{b})$$

Here, $f(b)$ is a polynomial of degree $3(B-1)$ over the variable $b$.
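As a sanity check on the verifier's equation, here is a toy instance with only $Mul$ gates, $B = 3$ and four segments (all names are local to the sketch). It evaluates $f$ on the grid $b \in \{0, \dots, B-1\}$ only, which is all this check needs, and confirms that the sum reproduces the claim $\tilde V_{i+1}(\tilde p, \tilde b)$:

```python
MOD = 2**127 - 1

def eq_p(p, tilde_p):
    """Multilinear eq between the bits of integer p and the point tilde_p."""
    prod = 1
    for i, t in enumerate(tilde_p):
        bit = (p >> i) & 1
        prod = prod * (bit * t + (1 - bit) * (1 - t)) % MOD
    return prod

def lagrange_basis(b, t, B):
    """L_b(t) for the Lagrange basis over {0, ..., B-1}."""
    num, den = 1, 1
    for i in range(B):
        if i != b:
            num = num * (t - i) % MOD
            den = den * (b - i) % MOD
    return num * pow(den, MOD - 2, MOD) % MOD

B = 3
V_in = [[1, 0, 1], [1, 1, 0], [0, 1, 1], [1, 1, 1]]  # input segments of layer i
gates = [(0, 1), (1, 2), (2, 3), (3, 0)]           # Mul gate j reads segments (x, y)
V_out = [[V_in[x][k] * V_in[y][k] for k in range(B)] for x, y in gates]

tilde_p, tilde_b = [123456789, 987654321], 55555  # challenges from the previous round

# The claim on layer i+1, computed directly from its extension
claim = sum(V_out[p][k] * eq_p(p, tilde_p) * lagrange_basis(k, tilde_b, B)
            for p in range(len(gates)) for k in range(B)) % MOD

def f(j):
    """f(j) = sum over gates of eq(gate, tilde_p) * g_{x,y}(j) at a grid point j."""
    return sum(eq_p(g, tilde_p) * lagrange_basis(j, tilde_b, B) * V_in[x][j] * V_in[y][j]
               for g, (x, y) in enumerate(gates)) % MOD

print(sum(f(j) for j in range(B)) % MOD == claim)  # True
```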

### How to compute $f(b)$?

To do so, we will use the pre-computed table. We proceed as follows:


```python
def compute_f_b(preprocessed_table, V_i: Layer, tilde_p):
    """
    Compute the polynomial f(b) for the sumcheck protocol.

    This function calculates f(b) = Σ_{x,y} tilde_Mul(tilde_p, x, y) * g_{x,y}(b)
    using the pre-computed table and the current layer's information.

    Args:
    preprocessed_table: The pre-computed table containing g_{x,y}(b) polynomials.
    V_i: The current layer; its gates read from V_i.input (the previous layer's values).
    tilde_p: Random query point for p.

    Returns:
    BPolynomial: The resulting f(b) polynomial.
    """
    result = BPolynomial()

    # Iterate through all gates in the current layer
    for i, (gate_type, x, y) in enumerate(V_i.gates):
        if gate_type == mul_type:
            # Coefficient of this gate in Mul(tilde_p, x, y) for boolean x, y
            mul_coefficient = eq_p_p(i, tilde_p)

            # Pack the input segments into integers (bit b at position b)
            vx = 0
            vy = 0
            for b in range(segment_size):
                vx += V_i.input[x * segment_size + b] << b
                vy += V_i.input[y * segment_size + b] << b

            # Retrieve the pre-computed g_{x,y}(b) polynomial
            gxy = preprocessed_table[vx * 2**segment_size + vy]

            # Add the contribution of this gate to the result
            result = result + gxy[0] * mul_coefficient

        elif gate_type == add_type:
            # TODO: addition gates are omitted for simplicity
            pass

    return result

# Uses BPolynomial, eq_p_p, mul_type/add_type and segment_size from the earlier cells.
```

Computing $f(b)$ this way takes $\frac{3|C_i|}{\log{N}} = \frac{|C_i|}{B}$ polynomial additions/scalar multiplications, plus random memory accesses into the table. Since each polynomial is stored as $O(B)$ field elements, this amounts to $O(\frac{|C_i|}{B} \cdot B) = O(|C_i|)$ field operations, which is still linear in the size of the circuit layer.

Can we do even better? Yes, as the following code shows:

```python
def compute_f_b_optimized(preprocessed_table, V_i: Layer, tilde_p):
    """
    Compute the polynomial f(b) for the sumcheck protocol with optimization.

    This function calculates f(b) = Σ_{x,y} tilde_Mul(tilde_p, x, y) * g_{x,y}(b)
    by first summing the coefficients of gates that share the same packed inputs,
    so that each pre-computed polynomial is scaled and added only once.

    Args:
    preprocessed_table: The pre-computed table of [g_{x,y}(b) polynomial, accumulator] entries.
    V_i: The current layer; its gates read from V_i.input.
    tilde_p: Random query point for p.

    Returns:
    BPolynomial: The resulting f(b) polynomial.
    """
    # Reset the accumulators so the table can be reused across calls
    for entry in preprocessed_table:
        entry[1] = 0

    # First pass: accumulate one field element per gate (no polynomial arithmetic)
    for i, (gate_type, x, y) in enumerate(V_i.gates):
        if gate_type == mul_type:
            mul_coefficient = eq_p_p(i, tilde_p)

            # Pack the input segments into integers (bit b at position b)
            vx = 0
            vy = 0
            for b in range(segment_size):
                vx += V_i.input[x * segment_size + b] << b
                vy += V_i.input[y * segment_size + b] << b

            # Key optimization: accumulate the coefficient instead of
            # scaling and adding a polynomial for every gate
            entry = preprocessed_table[vx * 2**segment_size + vy]
            entry[1] = (entry[1] + mul_coefficient) % mod

        elif gate_type == add_type:
            # TODO: addition gates are omitted for simplicity
            pass

    # Second pass: scale each pre-computed polynomial by its accumulated coefficient and sum
    result = BPolynomial()
    for polynomial, coefficient in preprocessed_table:
        result = result + polynomial * coefficient

    return result

# Optimization details:
# - The second slot of each preprocessed table entry is used as an accumulator.
# - Polynomial operations drop from O(|C_i| / B) (one per gate) to O(2^(2 * segment_size))
#   (one per table entry), which pays off when |C_i| / B is much larger than 2^(2B).
```

The optimized process takes $\frac{|C_i|}{B}$ field additions in the first pass and $2^{2B}$ polynomial scale-and-add steps in the second pass, i.e., $\frac{|C_i|}{B}$ field additions plus about $2^{2B} \cdot 3(B - 1)$ field multiplications/additions in total. For $B = 8$, the second term is:

```python
import math

B = segment_size
cost = 2**(2 * B) * (3 * (B - 1))  # cost of the second pass
print("The second cost is", cost, ", log(cost) is", math.log2(cost))
```

```text
The second cost is 1376256 , log(cost) is 20.39231742277876
```


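## Next step
After the sumcheck over the variable $b$, we receive a random challenge $\tilde{b}'$ and proceed with the following sumcheck. Because $f$ contains the factor $eq\_b\_b(b, \tilde{b})$, its value at $\tilde{b}'$ carries a scalar that the verifier can compute herself:
$$
f(\tilde{b}') = \left(\sum_{j=0}^{B-1} eq\_b\_b(j, \tilde{b})\, eq\_b\_b(j, \tilde{b}')\right) \sum_{x, y} Mul(\tilde{p}, x, y) \tilde{V}_i(x, \tilde{b}') \tilde{V}_i(y, \tilde{b}')
$$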

Starting from here, we proceed with a normal sumcheck, which takes $O(\frac{|C_i|}{B})$ field operations.
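A toy numeric check of the step from $f(\tilde b')$ to the remaining sum $\sum_{x,y} Mul(\tilde p, x, y)\tilde V_i(x, \tilde b')\tilde V_i(y, \tilde b')$, scaled by $K = \sum_{j=0}^{B-1} eq\_b\_b(j, \tilde b)\, eq\_b\_b(j, \tilde b')$. All names and values are hypothetical: `weights` stands in for $Mul(\tilde p, x, y)$ on two gates, $B = 3$, and the prover sends $f$ as its $3B - 2$ evaluations at $0, \dots, 3B - 3$ (enough for degree $3(B-1)$). The check confirms that interpolating those evaluations at $\tilde b'$ equals $K$ times the remaining sum:

```python
MOD = 2**127 - 1

def lagrange_basis(j, t, pts):
    """Lagrange basis L_j over the points 0..pts-1, evaluated at t."""
    num, den = 1, 1
    for i in range(pts):
        if i != j:
            num = num * (t - i) % MOD
            den = den * (j - i) % MOD
    return num * pow(den, MOD - 2, MOD) % MOD

def interpolate(evals, t):
    """Evaluate at t the unique polynomial of degree < len(evals) through (k, evals[k])."""
    return sum(e * lagrange_basis(k, t, len(evals)) for k, e in enumerate(evals)) % MOD

B = 3
tilde_b, tilde_b2 = 55555, 777777                         # previous challenge and new challenge b~'
pairs = [([1, 0, 1], [1, 1, 0]), ([0, 1, 1], [1, 1, 1])]  # (V(x, .), V(y, .)) per gate
weights = [31337, 4242]                                   # stand-ins for Mul(p~, x, y)

def E(t):
    """eq_b_b(., b~) extended in its first argument: sum_j L_j(b~) * L_j(t)."""
    return sum(lagrange_basis(j, tilde_b, B) * lagrange_basis(j, t, B) for j in range(B)) % MOD

def f(t):
    """f(t) = sum over gates of weight * E(t) * V~(x, t) * V~(y, t)."""
    return sum(w * E(t) * interpolate(vx, t) * interpolate(vy, t)
               for w, (vx, vy) in zip(weights, pairs)) % MOD

f_evals = [f(k) for k in range(3 * B - 2)]  # what the prover sends
lhs = interpolate(f_evals, tilde_b2)        # verifier evaluates f at b~'
K = E(tilde_b2)                             # scalar the verifier computes herself
rhs = K * sum(w * interpolate(vx, tilde_b2) * interpolate(vy, tilde_b2)
              for w, (vx, vy) in zip(weights, pairs)) % MOD
print(lhs == rhs)  # True
```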

## In summary
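We pay an additional $O(B^2 \cdot 2^{2B})$ field multiplications/additions for the pre-computation and `compute_f_b_optimized`, and in exchange reduce the sumcheck problem of size $|C_i|$ to $\frac{|C_i|}{B}$ field operations. With $B = \frac{\log{N}}{3}$ we have $2^{2B} = N^{2/3}$, so this overhead is $O(N^{2/3}\log^2{N})$, which is asymptotically smaller than $\frac{N}{\log{N}}$.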