In [2]:
import math
import random
import sympy


p = 2**127 - 1      # Mersenne
F = lambda x: x % p

# Fake bilinear group
g = 9
def Gexp(x):
    return pow(g, x % p, p)

def pairing_exp(a, b):
    return (a * b) % p


In [3]:
'''
witness: x
output: x^2 - 25 = 0

We should proof that we know x as x^2 = 25
w0 = 1 (constant), w1 = x, w2 = x^2
'''

# I. Arithmetization (R1CS Rank 1 Constraint System)

A = [
    [0, 0, 1],  # A1: x
]
B = [
    [1, 0, 0],  # B1: x
]
C = [
    [0, 0, 0],  # C1: x^2
]
K = [25]

N_CONSTRAINTS = 1
N_VARS = 3

print("R1CS ")
print(f"- {N_CONSTRAINTS} constraints")
print(f"- {N_VARS} vars -> vector")
print("A:", A)
print("B:", B)
print("C:", C)
print("K:", K)

R1CS 
- 1 constraints
- 3 vars -> vector
A: [[0, 0, 1]]
B: [[1, 0, 0]]
C: [[0, 0, 0]]
K: [25]


In [4]:
def check_r1cs(x, nb_constraint=2, nb_var=3):
    w = [1, F(x), F(x*x)]

    for i in range(nb_constraint):
        A_dot = sum(A[i][j] * w[j] for j in range(nb_var))
        B_dot = sum(B[i][j] * w[j] for j in range(nb_var))
        C_dot = sum(C[i][j] * w[j] for j in range(nb_var)) + K[i]

        print(f"Constraint {i+1}:  {A_dot} * {B_dot} =? {C_dot}")

        if A_dot * B_dot != C_dot:
            print("R1CS invalid")
            return False
    
    print("R1CS valid")
    return True

check_r1cs(4, N_CONSTRAINTS, N_VARS)
print("-----")
check_r1cs(5, N_CONSTRAINTS, N_VARS)
print("-----")
check_r1cs(7, N_CONSTRAINTS, N_VARS)

Constraint 1:  16 * 1 =? 25
R1CS invalid
-----
Constraint 1:  25 * 1 =? 25
R1CS valid
-----
Constraint 1:  49 * 1 =? 25
R1CS invalid


False

In [5]:
# II. Constraints to polynomial ( R1CS -> QAP Quadratic Arithmetic Program) 
# Transform linear constraints to a unique polynomial, verifiable with an single secret point

def lagrange_basis(i, n):
    X = sympy.symbols('X')
    basis = 1
    for j in range(1, n+1):
        if j != i:
            basis *= (X - j) / (i - j)
    return sympy.simplify(basis)

L = [lagrange_basis(i+1, N_CONSTRAINTS) for i in range(N_CONSTRAINTS)]

print("Constructed basic Lagrange polynomials.")

def poly_from_R1CS_vector(vec):
    """Build a polynomial (N_CONSTRAINTS-1) from L_i."""
    X = sympy.symbols('X')
    poly = sum(vec[i] * L[i] for i in range(N_CONSTRAINTS))
    return sympy.simplify(poly)

A_poly = [poly_from_R1CS_vector([A[i][j] for i in range(N_CONSTRAINTS)]) for j in range(N_VARS)]
B_poly = [poly_from_R1CS_vector([B[i][j] for i in range(N_CONSTRAINTS)]) for j in range(N_VARS)]
C_poly = [poly_from_R1CS_vector([C[i][j] for i in range(N_CONSTRAINTS)]) for j in range(N_VARS)]
K_poly = poly_from_R1CS_vector(K)

print("Polynomials A(t), B(t), C(t), K(t) generated.")

Constructed basic Lagrange polynomials.
Polynomials A(t), B(t), C(t), K(t) generated.


In [8]:
# III. Parameters setup

tau = random.randint(1, p-1)
Z_poly = (sympy.symbols('X') - 1)*(sympy.symbols('X') - 2)

print("tau =", tau)

X = sympy.symbols('X')
Q_poly = sum(A_poly[j]*B_poly[j] for j in range(N_VARS)) - sum(C_poly[j] for j in range(N_VARS)) - K_poly
H_poly = sympy.simplify(Q_poly / Z_poly)  # division polynomiale


tau = 16385464519465675185522296871976395680


In [9]:
# IV. Proof gen

def gen_proof(x):
    w = [1, F(x), F(x*x)]

    def eval_poly(poly_list):
        return F(sum(F(poly_list[j].subs(X, tau)) * w[j] for j in range(N_VARS)))

    At = F(int(eval_poly(A_poly)))
    Bt = F(int(eval_poly(B_poly)))
    Ct = F(int(eval_poly(C_poly)))
    Kt = F(int(K_poly.subs(X, tau)))
    Ht = F(int(H_poly.subs(X, tau)))
    Zt = F(int(Z_poly.subs(X, tau)))


    # commitments
    piA = Gexp(At)
    piB = Gexp(Bt)
    piC = Gexp(Ht)

    public_input = F(Ct + Kt)

    return {
        "piA": piA, "Aexp": At,
        "piB": piB, "Bexp": Bt,
        "piC": piC, "Hexp": Ht,
        "public": public_input,
        "Zexp": Zt
    }


In [10]:
# Verification

def verify_proof(proof):
    lhs = pairing_exp(proof["Aexp"], proof["Bexp"])
    rhs = F(proof["Hexp"] * proof["Zexp"] + proof["public"])

    print("LHS =", lhs)
    print("RHS =", rhs)
    return lhs == rhs


In [14]:
print("Testing correct x = 5")
proof = gen_proof(5)
print(proof)
print("Valid proof ?", verify_proof(proof))

print("\nTesting wrong x = 9")
proof = gen_proof(9)
print(proof)
print("Valid proof ?", verify_proof(proof))

print("\nTesting wrong x = 10")
proof = gen_proof(10)
print(proof)
print("Valid proof ?", verify_proof(proof))


Testing correct x = 5
{'piA': 717897987691852588770249, 'Aexp': 25, 'piB': 9, 'Bexp': 1, 'piC': 1, 'Hexp': 0, 'public': 25, 'Zexp': 125078977634435462616698735873672346141}
LHS = 25
RHS = 25
Valid proof ? True

Testing wrong x = 9
{'piA': 87976821907457075189336335121614927364, 'Aexp': 81, 'piB': 9, 'Bexp': 1, 'piC': 1, 'Hexp': 0, 'public': 25, 'Zexp': 125078977634435462616698735873672346141}
LHS = 81
RHS = 25
Valid proof ? False

Testing wrong x = 10
{'piA': 10810968933129975378600013865352026249, 'Aexp': 100, 'piB': 9, 'Bexp': 1, 'piC': 1, 'Hexp': 0, 'public': 25, 'Zexp': 125078977634435462616698735873672346141}
LHS = 100
RHS = 25
Valid proof ? False
