In [None]:
from py_ecc.bn128 import G1, G2, multiply, pairing, add, neg, eq, curve_order
from galois import Poly, GF
import random

In [2]:
field = GF(curve_order) # Define the finite field based on the curve order

In [None]:
class KZGCommitment:
    # Setup a KZG Commitment Scheme with a trusted setup phase.     
    def __init__(self, max_degree, field):
        self.max_degree = max_degree
        self.field = field
        self.srs = random.randint(1, curve_order - 1)
        self.srs_g1 = [multiply(G1, self.srs**i) for i in range(max_degree + 1)]
        self.srs_g2 = [G2, multiply(G2, self.srs)]
        #print(self.srs_g1)
        #print(self.srs_g2)
        
    # Commit to a polynomial using the SRS.  Returns a point in G1.  The commitment is a linear combination of the SRS points weighted by the coefficients of the polynomial.  If the coefficient is zero, we skip that term.  We use the multiply function from py_ecc.bn128 to compute the scalar multiplication of the SRS point with the coefficient.  We use the add function from py_ecc.bn128 to add the commitments for each term.  The commitment is returned in G1.  If the polynomial has degree greater than max_degree, we raise a ValueError. 
    def commit(self, poly: Poly):
        coeffs = poly.coefficients()       
        commitment = None
        # Compute the commitment to the polynomial using the SRS
        # Reverse the coefficients as they are stored from highest degree to lowest in Poly object. 
        for coeff, s in zip(coeffs[::-1], self.srs_g1):
            if coeff == 0:
                continue
            term_commitment = multiply(s, int(coeff))
            commitment = add(commitment, term_commitment) if commitment else term_commitment
        # Return the commitment in G1
        # srs = [[1]_1, [s]_1, ..., [s^n]_1]
        # [f(s)]_1 = a_0*[1]_1 + a_1*[s]_1 + ... + a_n*[s^n]_1
        return commitment
    
    def open(self, poly:Poly, z_point):
        f_z = poly(z_point)
        numberator = poly - field(f_z)        
        denominator = Poly([1, -z_point], field=field)
        quotient = numberator //  denominator
        proof = self.commit(quotient)
        return f_z, proof
    
    def verify(self, commitment, z_point, poly:Poly):
        # q(s) = [f(s) - f(z)] / (s - z) in G1
        # [f(s)]_1 - [f(z)]_1 = [q(s)]_1 * (s - z)
        #e([f(s)]_1 - [f(z)]_1, G2) == e([q(s)]_1, [s]_2 - [z]_2)
        f_z, proof = self.open(poly, z_point)
        f_z_g1 = self.commit(Poly([f_z], field=field))        
        s = self.srs
        lhs = add(commitment, neg(f_z_g1))
        lhs_pairing = pairing(self.srs_g2[0], lhs)
        rhs_pairing = pairing(
            add(
                self.srs_g2[1],
                multiply(neg(self.srs_g2[0]), int(z_point))
            ),
            proof
        )
        # Debugging line to see the pairings for comparison. Can be removed in production code.
        #print(f"lhs_pairing: {lhs_pairing}") 
        #print(f"rhs_pairing: {rhs_pairing}") 
        return lhs_pairing == rhs_pairing

In [173]:
kzg = KZGCommitment(3, field)
poly = Poly([1, 2, 3, 0], field=field) # f(x) = x^3 + 2x^2 + 3x + 0
print(poly.coefficients())

[1 2 3 0]


In [174]:
commitment = kzg.commit(poly)
print(commitment)

(9953335635914614939829934415496832188241572235654832090561615548078357654700, 17206522999725458237328610884307598504359551828740947891776895448968538333577)


In [175]:
z_point = field.Random(1)[0]
f_z, proof = kzg.open(poly, z_point)
print(f"f(z) = {f_z}")
print(proof)

f(z) = 15660262963832404716596791646158523948812039572817471521199125877306333397077
(7078526594979175394119242622319255889785551335103852469310179285143485300647, 7205735234049035568519808979835918775081647893054951404282672896406027511544)


In [176]:
kzg.verify(commitment, z_point, poly)

lhs_pairing: (17940726559316152136534235094918619837659637812623765132091811630288608823725, 14532728522988339825372655117465070763600287613277451252610271669628604621895, 21750323611109516835134662252215185355828107691959271561310235113325864279674, 8993021742618597535838162355330249368881480955501870426602755490251716064087, 5923848068248490578184152416742075875201949202674107976583872588642462592295, 5351716301210347756950399198621338277949444063120802284132636617406280799565, 18753758207359458361106877844129195357729447992527245916269151106934417808464, 1460106484377835192317078647277700548086969733587011786488201252088323141894, 462606270038137723191687166701332657554580472112838359195529429572919915087, 8014420729298777074246438256114908870692403611713227623431564701703189355099, 1831253566955229352198209913268943194401786231369425401843794404050798253496, 1477998165994918686176343221846286474189090168578955737444751540303209919772)
rhs_pairing: (179407265593161521365342350949186

True

In [183]:
import json
from collections import defaultdict

with open("r1cs.json") as f:
    r1cs_data = json.load(f)
    
with open("witness.json") as f:
    witness_data = json.load(f)

In [None]:
import importlib
importlib.reload(logging)
#logging.basicConfig(level=logging.DEBUG)  # 전역 설정
r1cs_json_path = 'r1cs.json'
witness_json_path = 'witness.json'

plonkIoP = PlonkIOP(r1cs_json_path, witness_json_path, field)
plonkIoP.verify_R1CS()
plonkIoP.permutation_check()

Omega elements for z(x): [GF(1,
   order=21888242871839275222246405745257275088548364400416034343698204186575808495617), GF(4407920970296243842393367215006156084916469457145843978461,
   order=21888242871839275222246405745257275088548364400416034343698204186575808495617), GF(21888242871839275217838484774961031246154997185409878258781734729429964517155,
   order=21888242871839275222246405745257275088548364400416034343698204186575808495617)]
Omega elements for T(x): [GF(1,
   order=21888242871839275222246405745257275088548364400416034343698204186575808495617), GF(7890059333988994465574740005840865221433745984419803513342428278253292184207,
   order=21888242871839275222246405745257275088548364400416034343698204186575808495617), GF(21753035119881904180964963008150682033897938407632563941302855815390755029584,
   order=21888242871839275222246405745257275088548364400416034343698204186575808495617), GF(4833374017787122887595405489557386282696242896932280905628839833224856877419,
   order=2188

False

In [523]:
import json
from collections import defaultdict

def extract_copy_cycles(r1cs_json_path, sym_path):
    with open(r1cs_json_path) as f:
        r1cs_data = json.load(f)
    with open(sym_path) as f:
        sym_lines = f.readlines()

    # wire_idx -> label 매핑
    wire_labels = {}
    for line in sym_lines:
        parts = line.strip().split(",")
        if len(parts) >= 4:
            wire_id = int(parts[1])
            label = parts[3]
            wire_labels[wire_id] = label

    appearances = defaultdict(list)  # wire 인덱스가 constraint 어디에 등장하는지 저장
    constraints = r1cs_data["constraints"]

    for i, c in enumerate(constraints):
        A_terms, B_terms, C_terms = c
        for wire_idx in list(A_terms.keys()) + list(B_terms.keys()) + list(C_terms.keys()):
            appearances[int(wire_idx)].append(i)

    # 복사 cycle 찾기: 두 번 이상 나타나는 wire는 cycle!
    cycles = {wire: cons for wire, cons in appearances.items() if len(cons) > 1}

    print("=== Copy Constraints (Permutation Cycles) ===")
    for wire, cons_list in cycles.items():
        label = wire_labels.get(wire, f"wire_{wire}")
        print(f"Wire {wire} ({label}) appears in constraints: {cons_list}  -> Permutation cycle")

    return cycles

# 사용 예시:
r1cs_json_path = "r1cs.json"
sym_path = "circuit.sym"
extract_copy_cycles(r1cs_json_path, sym_path)

=== Copy Constraints (Permutation Cycles) ===
Wire 1 (main.c7) appears in constraints: [0, 1]  -> Permutation cycle
Wire 6 (main.c3) appears in constraints: [1, 2]  -> Permutation cycle
Wire 7 (main.c4) appears in constraints: [1, 2]  -> Permutation cycle


{1: [0, 1], 6: [1, 2], 7: [1, 2]}

In [557]:
import json
import logging
from galois import lagrange_poly
from collections import defaultdict
import random

class PlonkIOP:
    def __init__(self, r1cs_json_path, witness_json_path, sym_path, field):
        self.logger = logging.getLogger("PlonkIOP")
        self.logger.setLevel(logging.DEBUG)

        self.r1cs_json_path = r1cs_json_path
        self.witness_json_path = witness_json_path
        self.sym_path = sym_path

        self.field = field

        with open(self.r1cs_json_path) as f:
            r1cs_data = json.load(f)

        with open(self.witness_json_path) as f:
            witness_data = json.load(f)

        self.n_constraints = r1cs_data["nConstraints"]
        self.constraints = r1cs_data["constraints"]

        self.witness = [self.field(int(x)) for x in witness_data['witness']]

        self.a_vals = [self.field(0)] * self.n_constraints
        self.b_vals = [self.field(0)] * self.n_constraints
        self.c_vals = [self.field(0)] * self.n_constraints

        for i, c in enumerate(self.constraints):
            A_terms, B_terms, C_terms = c

            a_wire = int(list(A_terms.keys())[0])
            b_wire = int(list(B_terms.keys())[0])
            c_wire = int(list(C_terms.keys())[0])

            self.a_vals[i] = self.witness[a_wire]
            self.b_vals[i] = self.witness[b_wire]
            self.c_vals[i] = self.witness[c_wire]
            
        # permutation_check 내부나 그 전에
        omega = self._nth_root_of_unity(self.n_constraints)
        domain_points = [omega**i for i in range(self.n_constraints)]

        self.a_poly = galois.lagrange_poly(field(domain_points), field(self.a_vals))
        self.b_poly = galois.lagrange_poly(field(domain_points), field(self.b_vals))
        self.c_poly = galois.lagrange_poly(field(domain_points), field(self.c_vals))
        


    def _nth_root_of_unity(self, order):
        g = self.field.primitive_element
        return g ** ((self.field.order - 1) // order)

    def build_sigma_cycles_position_based(self):
        constraints = self.constraints
        n_positions = 3 * self.n_constraints

        wire_positions = defaultdict(list)  # wire_id -> [position indices]

        for i, constraint in enumerate(constraints):
            A_terms, B_terms, C_terms = constraint
            a_wire = int(list(A_terms.keys())[0])
            b_wire = int(list(B_terms.keys())[0])
            c_wire = int(list(C_terms.keys())[0])

            wire_positions[a_wire].append(i)  # A column
            wire_positions[b_wire].append(i + self.n_constraints)  # B column
            wire_positions[c_wire].append(i + 2 * self.n_constraints)  # C column

        sigma = list(range(n_positions))

        for wire, positions in wire_positions.items():
            if len(positions) > 1:
                for idx, pos in enumerate(positions):
                    next_pos = positions[(idx + 1) % len(positions)]
                    sigma[pos] = next_pos

        print("=== Sigma (position based) ===")
        for wire, positions in wire_positions.items():
            print(f"Wire {wire} -> positions {positions}")
        print(f"sigma: {sigma}")

        return sigma

    def permutation_check(self):
        omega = self._nth_root_of_unity(self.n_constraints)
        omega_elements = [omega ** i for i in range(self.n_constraints)]

        sigma = self.build_sigma_cycles_position_based()

        all_wires = self.a_vals + self.b_vals + self.c_vals  # length = 3 * n_constraints

        beta = self.field(random.randint(1, self.field.order - 1))
        gamma = self.field(random.randint(1, self.field.order - 1))

        z = [self.field(1)]

        for i in range(3 * self.n_constraints):
            # 도메인 인덱스: position i -> omega^(i mod n_constraints)
            k_i = omega_elements[i % self.n_constraints]

            left = all_wires[i] + beta * k_i + gamma

            sigma_i = sigma[i]
            right = all_wires[i] + beta * omega_elements[sigma_i % self.n_constraints] + gamma

            ratio = left / right
            z_next = z[-1] * ratio
            z.append(z_next)

        if z[0] != z[-1]:
            print(f"❌ Permutation check failed: z[last]={z[-1]} != z[0]={z[0]}")
            return False

        # ✅ 여기서 z_poly 보간 후 저장
        omega_points = [self._nth_root_of_unity(3 * self.n_constraints) ** i for i in range(3 * self.n_constraints)]
        z_values = z[:-1]
        self.z_poly = galois.lagrange_poly(field(omega_points), field(z_values))
        print("✅ Permutation check passed!")
        
        return True


In [561]:
r1cs_json_path = "r1cs.json"
witness_json_path = "witness.json"
sym_path = "circuit.sym"

# Plonk IOP 초기화 및 값 준비
plonk = PlonkIOP(r1cs_json_path, witness_json_path, sym_path, field)
plonk.permutation_check()

=== Sigma (position based) ===
Wire 4 -> positions [0]
Wire 5 -> positions [3]
Wire 1 -> positions [6, 4]
Wire 6 -> positions [1, 2]
Wire 2 -> positions [7]
Wire 8 -> positions [5]
Wire 3 -> positions [8]
sigma: [0, 2, 1, 3, 6, 5, 4, 7, 8]
✅ Permutation check passed!


True

In [562]:
a_poly = plonk.a_poly
b_poly = plonk.b_poly
c_poly = plonk.c_poly

# (임시) z(x), t(x) 다항식 예시 생성
z_poly = plonk.z_poly
t_poly = Poly([3, 1, 5, 2], field=field)  # quotient polynomial 예시

In [563]:
kzg = KZGCommitment(max_degree=8, field=field)

# 커밋
A_commit = kzg.commit(a_poly)
B_commit = kzg.commit(b_poly)
C_commit = kzg.commit(c_poly)
Z_commit = kzg.commit(z_poly)
T_commit = kzg.commit(t_poly)

print("A_commit:", A_commit)
print("B_commit:", B_commit)
print("C_commit:", C_commit)
print("Z_commit:", Z_commit)
print("T_commit:", T_commit)

A_commit: (5910183705316751616995749724008520562521578563449927414090424708688402025652, 4504485863993189657862934273756548325484781779293279521271319225348212381585)
B_commit: (491067480121328607198909758153803191523859306715026407378771205402105980909, 17933938912707421612816376829822084521712755707182886911746068383613659021963)
C_commit: (997075362124060219481453790721810560750741315913135409048576040434287234865, 21497254753949712894924453717195090957458165701151629910036267031614847844822)
Z_commit: (11247530580126195993228939701689815266224513039909380554751059641383213475691, 16276561646379957016155221652343598932590363860406403784315535909152333021240)
T_commit: (2589649451993758842503236463147753290543521282660313931707285880761333187940, 15534155914957630747453376943612282615429422729511118280111756391622539814016)


In [564]:
# opening 검증
z_eval_point = field(10)  # 검증 시 임의 z
f_z, proof = kzg.open(a_poly, z_eval_point)
print(f"a({z_eval_point}) = {f_z}, proof:", proof)

valid = kzg.verify(A_commit, z_eval_point, a_poly)
print("KZG Opening Verification (a_poly):", "✅" if valid else "❌")

a(10) = 75, proof: (8258549238801558924241141376363426934833709269883464302700379019768292397731, 14697048764916141786618311567587382009783153601658441108286867704736960756925)
lhs_pairing: (20329691272564511523752731035173542542607208157777175497202429142563446487704, 5082048087934796953914443073661354727846184653297341122789884930985196371768, 15919773038250229745616625809260371114312770809553191710459754423436137179460, 16905944932899247524585145959299162920066248253911486684734721651554841819015, 9535213782531931961061550421848942398722769482902292600656452337019856490224, 10023837438745168664876764878708458356140673377246135468160190501480657024421, 11521463981965687886253223426547887883685265440508689808060384119876010327725, 6912704858511725265204886585160841508708140529034448462684256871588014353137, 6467194555916663594897590116134088979656329805894628598752915633321887998402, 16390418457850157753033595149719674986501173050549752185477655094978944596021, 13344726666741443988571