In [35]:
import math
import random
import sympy


p = 2**256 - 1      # Mersenne
F = lambda x: x % p

# Fake bilinear group
g = 9
def Gexp(x):
    return pow(g, x % p, p)

def pairing_exp(a, b):
    return (a * b) % p


In [36]:
'''
witness: x
output: x^2 - 25 = 0

We should proof that we know x as x^2 = 25
w0 = 1 (constant), w1 = x, w2 = x^2
'''

# I. Arithmetization (R1CS Rank 1 Constraint System)

A = [
    [0, 0, 1],  # A1: x
]
B = [
    [1, 0, 0],  # B1: x
]
C = [
    [0, 0, 0],  # C1: x^2
]
K = [25]

N_CONSTRAINTS = 1
N_VARS = 3

print("R1CS ")
print(f"- {N_CONSTRAINTS} constraints")
print(f"- {N_VARS} vars -> vector")
print("A:", A)
print("B:", B)
print("C:", C)
print("K:", K)

R1CS 
- 1 constraints
- 3 vars -> vector
A: [[0, 0, 1]]
B: [[1, 0, 0]]
C: [[0, 0, 0]]
K: [25]


In [37]:
def check_r1cs(x, nb_constraint=2, nb_var=3):
    w = [1, F(x), F(x*x)]

    for i in range(nb_constraint):
        A_dot = sum(A[i][j] * w[j] for j in range(nb_var))
        B_dot = sum(B[i][j] * w[j] for j in range(nb_var))
        C_dot = sum(C[i][j] * w[j] for j in range(nb_var)) + K[i]

        print(f"Constraint {i+1}:  {A_dot} * {B_dot} =? {C_dot}")

        if A_dot * B_dot != C_dot:
            print("R1CS invalid")
            return False
    
    print("R1CS valid")
    return True

check_r1cs(4, N_CONSTRAINTS, N_VARS)
print("-----")
check_r1cs(5, N_CONSTRAINTS, N_VARS)
print("-----")
check_r1cs(7, N_CONSTRAINTS, N_VARS)

Constraint 1:  16 * 1 =? 25
R1CS invalid
-----
Constraint 1:  25 * 1 =? 25
R1CS valid
-----
Constraint 1:  49 * 1 =? 25
R1CS invalid


False

In [38]:
# II. Constraints to polynomial ( R1CS -> QAP Quadratic Arithmetic Program) 
# Transform linear constraints to a unique polynomial, verifiable with an single secret point

def lagrange_basis(i, n):
    X = sympy.symbols('X')
    basis = 1
    for j in range(1, n+1):
        if j != i:
            basis *= (X - j) / (i - j)
    return sympy.simplify(basis)

L = [lagrange_basis(i+1, N_CONSTRAINTS) for i in range(N_CONSTRAINTS)]

print("Constructed basic Lagrange polynomials.")

def poly_from_R1CS_vector(vec):
    """Build a polynomial (N_CONSTRAINTS-1) from L_i."""
    X = sympy.symbols('X')
    poly = sum(vec[i] * L[i] for i in range(N_CONSTRAINTS))
    return sympy.simplify(poly)

A_poly = [poly_from_R1CS_vector([A[i][j] for i in range(N_CONSTRAINTS)]) for j in range(N_VARS)]
B_poly = [poly_from_R1CS_vector([B[i][j] for i in range(N_CONSTRAINTS)]) for j in range(N_VARS)]
C_poly = [poly_from_R1CS_vector([C[i][j] for i in range(N_CONSTRAINTS)]) for j in range(N_VARS)]
K_poly = poly_from_R1CS_vector(K)

print("Polynomials A(t), B(t), C(t), K(t) generated.")

Constructed basic Lagrange polynomials.
Polynomials A(t), B(t), C(t), K(t) generated.


In [39]:
# III. Parameters setup

X = sympy.symbols('X')
tau = random.randint(1, p-1)
Z_poly = X - 1 

print("tau =", tau)

Q_poly = sum(A_poly[j]*B_poly[j] for j in range(N_VARS)) - sum(C_poly[j] for j in range(N_VARS)) - K_poly
H_poly = sympy.simplify(Q_poly / Z_poly)  # division polynomial

tau = 78290775536229259823666172122755423195876996549271692832388795852805491782866


In [40]:
# IV. Proof gen

def proof(x):
    w = [1, F(x), F(x*x)]

    def eval_poly(poly_list):
        return F(sum(F(poly_list[j].subs(X, tau)) * w[j] for j in range(N_VARS)))

    At = F(int(eval_poly(A_poly)))
    Bt = F(int(eval_poly(B_poly)))
    Ct = F(int(eval_poly(C_poly)))
    Kt = F(int(K_poly.subs(X, tau)))
    Ht = F(int(H_poly.subs(X, tau)))
    Zt = F(int(Z_poly.subs(X, tau)))

    # Blinding factors
    r_blinding = random.randint(1, p-1)
    s_blinding = random.randint(1, p-1)

    # commitments
    A_exp_proof = F(At + r_blinding * Zt)
    B_exp_proof = F(Bt + s_blinding * Zt)
    C_exp_proof = F((Ct + Kt + Ht * Zt) + At * s_blinding * Zt + Bt * r_blinding * Zt + r_blinding * s_blinding * Zt * Zt)

    return {
        "piA": Gexp(A_exp_proof),
        "A_exp_proof": A_exp_proof,
        "piB": Gexp(B_exp_proof),
        "B_exp_proof": B_exp_proof,
        "piC": Gexp(C_exp_proof),
        "C_exp_proof": C_exp_proof
    }

In [46]:
# Verification

def verify(proof):
    # Extract fields
    A_exp = proof["A_exp_proof"]
    B_exp = proof["B_exp_proof"]
    C_exp = proof["C_exp_proof"]

    # Verification equation for blinded proof
    lhs = pairing_exp(A_exp, B_exp)
    rhs = C_exp

    #print("LHS =", lhs)
    #print("RHS =", rhs)

    return lhs == rhs

In [50]:
print("Testing correct x = 5")
proof_correct = proof(5)
print(proof_correct)
print("Valid proof (correct x=5)?", verify(proof_correct))

print("\nTesting incorrect x = 9")
proof_incorrect_1 = proof(9)
print("Valid proof (incorrect x=9)?", verify(proof_incorrect_1))

for i in range (0,100000):
    proof_incorrect_2 = proof(10)
    if (verify(proof_incorrect_2) == True):
        print("Valid proof (incorrect x=10)?", verify(proof_incorrect_2))

Testing correct x = 5
{'piA': 104835999295654603657476229355715554111653244462919466445262817262748520923261, 'A_exp_proof': 90937934361344749437993845197280967999885197418327024621343922331380328137290, 'piB': 30930440501507156535753813917260247645972520006724428610626499322638402318131, 'B_exp_proof': 95807759777643543876207242121310044217442499406788488409539110162259977747066, 'piC': 68777713653840400377570165569686295810386344953909209045822755228490327080456, 'C_exp_proof': 11336948337370472737543263751665784656970936244600627319805052248838326781740}
Valid proof (correct x=5)? True

Testing incorrect x = 9
Valid proof (incorrect x=9)? False
