In [10]:
import numpy as np
from py_ecc.bn128 import G1, G2, pairing, multiply, add, eq
from py_ecc.bn128.bn128_curve import b, b2
import numpy as np
import galois
import functools
import galois 
import numpy as np
from py_ecc.bn128 import curve_order
from typing import Callable
from py_ecc.bn128 import eq, neg, pairing, final_exponentiate, FQ12
from functools import partial
from typing import List, Tuple, Union
# Use the curve_order from py_ecc.bn128
GF = galois.GF(curve_order, primitive_element=5, verify=False)


In [11]:
def mat_to_finite_field(mat : np.array) -> np.array:
    return (mat + curve_order) % curve_order

In [12]:
# Define the matrices
A = np.array([[0,0,3,0,0,0],
               [0,0,0,0,1,0],
               [0,0,1,0,0,0]])

B = np.array([[0,0,1,0,0,0],
               [0,0,0,1,0,0],
               [0,0,0,5,0,0]])

C = np.array([[0,0,0,0,1,0],
               [0,0,0,0,0,1],
               [-3,1,1,2,0,-1]])

# pick values for x and y
x = 100
y = 100

# this is our orignal formula
out = 3 * x * x * y + 5 * x * y - x- 2*y + 3
# the witness vector with the intermediate variables inside
v1 = 3*x*x
v2 = v1 * y
w = np.array([1, out, x, y, v1, v2])

F_A = GF(mat_to_finite_field(A))
F_B = GF(mat_to_finite_field(B))
F_C = GF(mat_to_finite_field(C))
F_W = GF(mat_to_finite_field(w))
n_rows = A.shape[0]


t = galois.Poly([1], field=GF)
for i in range(1, n_rows + 1):
    t *= galois.Poly([1, curve_order - i], field=GF)

def t_fn(tao : int, t : galois.Poly):
    return int(t(tao))


print(A.shape)
print(B.shape)
print(C.shape)
print(w.shape)
# do the trusted setup



(3, 6)
(3, 6)
(3, 6)
(6,)


In [13]:
#some helper functions

def poly_to_np(poly : galois.Poly, n_rows : int)-> np.array:
    elms = [int(coeff) for coeff in poly.coeffs]
    if len(elms) < n_rows:
        elms = elms + [0] * (n_rows - len(elms)) # pad with zeros if not enough elements
    return np.array(elms)

def pol_interp(mat : np.array) -> np.array:
    n_rows = mat.shape[0]
    def interpolate_column(col):
        xs = GF(np.arange(1, n_rows + 1))
        return galois.lagrange_poly(xs, col)
    
    interpolated = np.apply_along_axis(interpolate_column, 0, mat)
    return interpolated



In [37]:

def compute_inverse(scalar: int, curve_order: int) -> int:
    """
    Compute the multiplicative inverse of a scalar over a field with order curve_order.
    
    Args:
    scalar (int): The scalar to invert.
    curve_order (int): The order of the field.
    
    Returns:
    int: The multiplicative inverse of the scalar.
    """
    #check curve_order is a prime
    if not galois.is_prime(curve_order):
        raise ValueError("Curve order must be a prime number")
    assert scalar < curve_order
    
    return pow(scalar, -1, curve_order)



def eval_poly(coeffs: list[int], point: list):
    val = None    
    for i, coeff in enumerate(coeffs):
        if val is None:
            val = multiply(point[i], int(coeff))
        else:
            val = add(val, multiply(point[i], int(coeff)))
    
    return val




def inner_product_polynomials_with_witness(polys, witness):
    mul_ = lambda x, y: x * y
    sum_ = lambda x, y: x + y
    return functools.reduce(sum_, map(mul_, polys, witness))

In [83]:



# do the trusted setup
def trusted_setup(
    tao : int,
    alpha : int,
    beta : int,
    gamma : int,
    delta : int,
    t: Callable[[int], int],
    n_rows : int,
    n_cols : int,
    u_polys : np.array,
    v_polys : np.array,
    w_polys : np.array,
    l : int, # the number of elements to use for the public parameters
) -> tuple[
    list[G1],
    list[G2],
    list[G1],
    list[G1], # private psi
    list[G1], # public psi
    G1,
    G2,
    G2, # gamma * G2
    G2, # delta * G2
]:
    '''
    $$
    \begin{align*}[\alpha]_1&=\alpha G_1 \\ [\beta]_2&=\beta G2 \\  \mathbf{\Psi}&=(w_i(\tau)+\alpha v_i(\tau)+\beta u_i(\tau))G_1\vert_{i=1}^m\end{align*}
    $$
    '''
    assert l <= n_cols
    # [tao^dG1, tao^(d-1)G1, ..., taoG1, G1]
    
    
    srs1_int = [tao**i for i in range(n_rows-1, -1, -1)]
    srs1 = [multiply(G1, integer) for integer in srs1_int]
    
    # [tao^dG2, tao^(d-1)G2, ..., taoG2, G2]
    srs2_int = [tao**i for i in range(n_rows-1, -1, -1)]
    srs2 = [multiply(G2, integer) for integer in srs2_int]
    
    gamma_g2 = multiply(G2, gamma)
    delta_g2 = multiply(G2, delta)
    
    gamma_inv = compute_inverse(gamma, curve_order) # γ 
    delta_inv = compute_inverse(delta, curve_order) # δ 
    
    print(delta_inv)
    print(gamma_inv)
    print(curve_order)
    #[\dots,\tau^2t(\tau)G_1,\tau t(\tau)G_1,t(\tau)G_1]
    srs3_int = [((tao**i)*t(tao) * delta_inv) for i in range(n_rows-1, -1, -1)]
    srs3 = [multiply(G1, integer) for integer in srs3_int]
    
    
    #\begin{align*}[\alpha]_1&=\alpha G_1 \\ [\beta]_2&=\beta G2 \\  \mathbf{\Psi}&=(w_i(\tau)+\alpha v_i(\tau)+\beta u_i(\tau))G_1\vert_{i=1}^m\end{align*}
    
    public_thetas = [
        (w_polys[i](tao) + alpha * v_polys[i](tao) + beta * u_polys[i](tao) * gamma_inv)
        for i in range(l)
    ]
    private_thetas = [
        (w_polys[i](tao) + alpha * v_polys[i](tao) + beta * u_polys[i](tao) * delta_inv)
        for i in range(l, n_cols)
    ]
    
    private_psi = [multiply(G1, int(private_thetas[i])) for i in range(l-n_cols)]
    public_psi = [multiply(G1, int(public_thetas[i])) for i in range(l)]

    A_alpha = multiply(G1, alpha)
    B_beta = multiply(G2, beta)
    
    return srs1, srs2, srs3, private_psi, public_psi, A_alpha, B_beta, gamma_g2, delta_g2

u_polys = pol_interp(GF(mat_to_finite_field(A)))
v_polys = pol_interp(GF(mat_to_finite_field(B)))
w_polys = pol_interp(GF(mat_to_finite_field(C)))


(
    srs1, 
    srs2, 
    srs3, 
    private_psi, 
    public_psi, 
    A_alpha, 
    B_beta,
    gamma_g2,
    delta_g2,
) = trusted_setup(
    tao=2, 
    alpha=1, 
    beta=1, 
    gamma=1,
    delta=1,
    l = 3,
    t=partial(t_fn, t=t), 
    n_rows=A.shape[0],
    n_cols=A.shape[1],
    u_polys=u_polys,
    v_polys=v_polys,
    w_polys=w_polys,
)



1
1
21888242871839275222246405745257275088548364400416034343698204186575808495617


In [56]:
def compute_H(
    u_polys : np.array,
    v_polys : np.array,
    w_polys : np.array,
    witness : np.array,
    n_rows : int,
) -> np.array:
    term1 = inner_product_polynomials_with_witness(u_polys, witness)
    term2 = inner_product_polynomials_with_witness(v_polys, witness)
    term3 = inner_product_polynomials_with_witness(w_polys, witness)
    t = galois.Poly([1], field=GF)
    for i in range(1, n_rows + 1):
        t *= galois.Poly([1, curve_order - i], field=GF)

    h = (term1 * term2 - term3) // t
    return h

In [85]:


def prover(
    witness : np.array,
    srs1 : list[G1],
    srs2 : list[G2],
    psi_public : list[G1],
    A_alpha : G1,
    u_polys : np.array,
    v_polys : np.array,
    w_polys : np.array,
    B_beta : G2,
    n_rows : int,
    l : int, # the start index for witness
) -> tuple[
    G1, 
    G2, 
    G1, 
    np.array,
]:
    
    u_scaled = inner_product_polynomials_with_witness(u_polys, witness)
    v_scaled = inner_product_polynomials_with_witness(v_polys, witness)
    
    
    H = compute_H(u_polys, v_polys, w_polys, witness, n_rows)

    poly_h = poly_to_np(H, n_rows)
    poly_A = poly_to_np(u_scaled, n_rows)
    poly_B = poly_to_np(v_scaled, n_rows)
    
    A = add(eval_poly(poly_A, srs1), A_alpha)
    B = add(eval_poly(poly_B, srs2), B_beta)
    
    accum = eval_poly(poly_h, srs3)
    
    # to compute c only the l+1:m elements are used
    C = add(eval_poly(witness[l:], psi_public), accum) 

    return A, B, C, witness[:l]
    
PROVER_A, PROVER_B, PROVER_C, PUBLIC_WITNESS = prover(
    witness=F_W,
    srs1=srs1,
    srs2=srs2,
    psi_public=public_psi,
    u_polys=u_polys,
    v_polys=v_polys,
    w_polys=w_polys,
    A_alpha=A_alpha,
    B_beta=B_beta,
    n_rows=A.shape[0],
    l=3,
)

In [45]:


def verifier(
    A : G1,
    B : G2,
    C : G1,
    public_witness : List[int], 
    public_psi : List, # List[G1] the public psi
    Beta : G2,
    Alpha : G1,
    Gamma : G2, #  γ  
    Delta : G2, # δ 
) -> bool:
    
    
    c_pair = pairing(Delta, C)
    X = eval_poly(public_witness, public_psi)
    x_pair = pairing(Gamma, X)
    return eq(FQ12.one(), final_exponentiate(pairing(B, neg(A)) * pairing(Beta, Alpha) * c_pair * x_pair))





In [46]:
verifier(
    A=PROVER_A,
    B=PROVER_B,
    C=PROVER_C,
    public_witness=PUBLIC_WITNESS,
    public_psi=public_psi,
    Beta=B_beta,
    Alpha=A_alpha,
    Gamma=gamma_g2,
    Delta=delta_g2,
)

False