In [53]:
import itertools
from functools import reduce
from itertools import product

import operator

# Category of all fields. NOTE(review): `K` is not referenced anywhere in this
# section of the notebook — candidate for removal.
K = Fields()

# Base field: the Mersenne prime q = 2^61 - 1.
q = 2**61 - 1
Fq = GF(q)
x = polygen(Fq, 'x')
# Quadratic extension Fq[u]/(u^2 + 1). Since q = 2^61 - 1 ≡ 3 (mod 4), -1 is a
# quadratic non-residue mod q, so x^2 + 1 is irreducible and Fq2 is a field.
Fq2.<u> = Fq.extension(x^2+1)
# Toy size for the MLE tests in this cell: vectors of length n = 2^log_n.
log_n = 4
n = 2**log_n
# Random test vector. No random seed is set in this section, so values change per run.
v = vector(Fq, [Fq.random_element() for _ in range(n)])

def get_bit(n, k):
    """Return bit ``k`` (0 = least significant) of the integer ``n`` as 0 or 1."""
    mask = 1 << k
    return 1 if n & mask else 0

def int_to_bits(n, bit_length=None):
    """
    Little-endian bit decomposition: element ``i`` of the result is bit ``i`` of ``n``.

    Args:
        n: an integer, or anything ``int()`` accepts (e.g. a Sage ``Integer`` or a
            prime-field element, which converts to its canonical representative).
        bit_length: number of bits to return. Defaults to the minimal width of
            ``n`` (at least 1, so 0 -> [0]). Bits of ``n`` above this width are dropped.

    Returns:
        list of 0/1 Python ints of length ``bit_length``.
    """
    # Normalise first: non-``int`` inputs (Sage Integers, field elements, Decimals)
    # may lack ``bit_length()`` or ``>>``.
    n = int(n)
    if bit_length is None:
        bit_length = max(n.bit_length(), 1)
    return [(n >> i) & 1 for i in range(bit_length)]

def multilinear_extension(v, F, var_names=None):
    """
    Multilinear extension of the vector ``v`` as a polynomial over ``F``.

    ``v`` must have length 2^m. Entry ``v[i]`` is attached to the Lagrange basis
    polynomial prod_k (X_k if bit k of i is 1 else 1 - X_k), where bit 0 is the
    least-significant bit of ``i`` and pairs with the first generator. The result
    therefore agrees with ``v`` on every point of {0,1}^m.

    Args:
        v: sequence of length 2^m whose entries coerce into ``F``.
        F: coefficient ring.
        var_names: optional variable names; defaults to X1, ..., Xm.

    Returns:
        element of ``PolynomialRing(F, var_names)``.

    Raises:
        ValueError: if ``len(v)`` is not a power of two.
    """
    size = len(v)
    num_vars = size.bit_length() - 1
    if size != 2 ** num_vars:
        raise ValueError("Vector length must be a power of 2.")

    names = var_names if var_names is not None else [f'X{k}' for k in range(1, num_vars + 1)]
    ring = PolynomialRing(F, names)
    gens = ring.gens()

    total = ring.zero()
    for index, entry in enumerate(v):
        basis_term = F(entry)
        for k in range(num_vars):
            g = gens[k]
            basis_term *= g if (index >> k) & 1 else (1 - g)
        total += basis_term
    return total

v_tilde = multilinear_extension(v, Fq)

# Sanity check: at each Boolean point (little-endian bits of i) the MLE equals v[i].
for i, t in enumerate(v):
    assert v_tilde(int_to_bits(i, log_n)) == t


def multilinear_matrix_extension(M, F, var_names=None):
    """
    Multilinear extension of the entries of matrix ``M`` over ``F``.

    The matrix is flattened in row-major order (``M.list()``) and handed to
    ``multilinear_extension``. Since nrows * ncols is a power of two, both
    dimensions are powers of two. The first log2(ncols) variables therefore index
    the column and the remaining log2(nrows) variables index the row
    (little-endian within each group).

    Args:
        M: Sage matrix with a power-of-two number of entries.
        F: coefficient ring of the result.
        var_names: optional variable names; defaults to X1, ..., Xm with
            m = log2(nrows * ncols).

    Returns:
        polynomial in ``PolynomialRing(F, var_names)``.

    Raises:
        ValueError: if nrows * ncols is not a power of two.
    """
    d, n = M.dimensions()
    total_elements = d * n

    m = total_elements.bit_length() - 1
    if 2**m != total_elements:
        raise ValueError("Matrix dimensions must multiply to a power of 2 (d * n = 2^m).")

    # Same Lagrange-basis construction as for vectors: reuse it instead of duplicating the loop.
    return multilinear_extension(M.list(), F, var_names)

M = Matrix(Fq, [[Fq.random_element() for _ in range(n)] for _ in range(n)])
M_tilde = multilinear_matrix_extension(M, Fq)

# Row-major flattening: the first log_n variables select the column j,
# the next log_n variables select the row i.
for i, row in enumerate(M):
    for j, e in enumerate(row):
        assert M_tilde(int_to_bits(j, log_n) + int_to_bits(i, log_n)) == e

## TENSOR PRODUCT TEST
# Build multi-linear basis polynomials
# Hand-built Lagrange basis over Fq2: basis[i] = prod_j (X_j if bit j of i else 1 - X_j).
var_names = [f'X{i}' for i in range(log_n)]
R = PolynomialRing(Fq2, var_names)
X = R.gens()
print(X)
basis = []
for i in range(n):
    bits = [(i >> j) & 1 for j in range(log_n)]
    poly = 1
    # Loop variable is `Xj` (not `x`) so the module-level polygen `x` is not overwritten.
    for b, Xj in zip(bits, X):
        poly *= Xj if b else (1 - Xj)
    basis.append(poly)

tilde_f_r = sum(v[i] * basis[i] for i in range(n))
v_tilde2 = multilinear_extension(v, Fq2)
# The hand-built polynomial must agree with the library MLE over the same field
# (Fq2) and with v itself on every Boolean point.
for i, t in enumerate(v):
    params = int_to_bits(i, log_n)
    assert tilde_f_r(params) == v_tilde2(params) == t


## TENSOR PRODUCT BY PAPER DEFINITION
## compute at a random point r
# Random evaluation point r in Fq2^log_n. Below, r[0] pairs with the
# least-significant index bit, matching multilinear_extension.
r = vector(Fq2, [Fq2.random_element() for _ in range(log_n)])

def compute_ml_extension_tensor(v, r):
    """
    Evaluate the multilinear extension of ``v`` at ``r`` as the inner product <v, r_tilde>.

    ``r_tilde[i] = prod_j (r[j] if bit j of i is 1 else 1 - r[j])`` with bit 0 the
    least-significant bit of ``i``. So ``r[0]`` plays the role of the first variable
    of ``multilinear_extension``, and the result equals
    ``multilinear_extension(v, F)(*r)``.

    Args:
        v: sequence of length 2^m.
        r: sequence of m coordinates in any ring that the entries of ``v`` multiply with.

    Returns:
        sum_i v[i] * r_tilde[i].

    Raises:
        AssertionError: if len(v) is not a power of two or len(r) != log2(len(v)).
    """
    n = len(v)
    num_vars = n.bit_length() - 1
    assert n > 0 and (1 << num_vars) == n, "Length of v must be a power of 2"
    assert len(r) == num_vars, "Length of r must match log_n"

    # The number of variables comes from v itself (not a global log_n). Each product
    # starts from 1, so it lives in whatever ring r lives in (not a fixed Fq2).
    r_tilde = []
    for i in range(n):
        term_product = 1
        for j in range(num_vars):
            term_product *= r[j] if (i >> j) & 1 else (1 - r[j])
        r_tilde.append(term_product)

    return sum(v[i] * r_tilde[i] for i in range(n))

# Two routes to the same value: direct evaluation of the MLE polynomial
# (Fq coefficients evaluated at an Fq2 point) and the tensor/inner-product formula.
r_tuple = tuple(r)
v_tilde_at_r = v_tilde(*r_tuple)
tilde_f_r = compute_ml_extension_tensor(v, r)
assert tilde_f_r == v_tilde_at_r


(X0, X1, X2, X3)


In [6]:
# 2.3 Rings and Modules
# Cyclotomic index. Phi below equals the 81st cyclotomic polynomial
# (81 = 3^4, so Phi_81(x) = x^54 + x^27 + 1).
# NOTE(review): `eta` itself is not read anywhere in this section.
eta = 81
# Ring degree: must equal deg(Phi) = 54; kept in sync with Phi by hand.
d = 54
# Base field: the Mersenne prime 2^61 - 1 (same q as in the first cell).
q = 2**61 - 1
Fq = GF(q)
x = polygen(Fq, 'x')
Phi = x ** 54 + x ** 27 + 1
# Rq is built with Fq.extension(Phi); its elements are handled through length-d
# coefficient lists (see cf / cf_inv). NOTE(review): for a finite field Sage
# presumably needs Phi irreducible over Fq here — otherwise a quotient ring
# Fq[x]/(Phi) is what is wanted; confirm.
Rq.<u> = Fq.extension(Phi)
# Number of rows of the commitment matrix pp (random_matrix(Rq, kappa, m)).
kappa = 16
## m = 2 ** 22 # TOO MUCH MEMORY
#m = 2 ** 8 # for testing purposes
#M = random_matrix(Rq, kappa, m)
# Random ring element used by the rot() sanity check below.
a = Rq.random_element()

def cf(a):
    """Coefficient list of ``a``, as returned by its ``list()`` method."""
    coeffs = a.list()
    return coeffs

def cf_inv(a):
    """Inverse of ``cf``: rebuild an element of the module-level ring ``Rq`` from coefficient data ``a``."""
    element = Rq(a)
    return element

def shift_matrix(Fq, Phi, d):
    """
    d x d companion matrix of ``Phi`` over ``Fq``.

    It has ones on the sub-diagonal and -c_0, ..., -c_{d-1} (the coefficients of
    ``Phi``) in the last column, so that F * cf(a) is the coefficient vector of
    u * a modulo Phi. Assumes ``Phi`` is monic of degree d; this is not checked.
    """
    coeffs = cf(Phi)
    S = matrix(Fq, d, d)
    # Basis shift u^j -> u^(j+1) for j < d - 1.
    for row in range(1, d):
        S[row, row - 1] = 1
    # u^(d-1) -> u^d = -(c_0 + c_1 u + ... + c_{d-1} u^(d-1)).
    for row in range(d):
        S[row, d - 1] = -coeffs[row]
    return S

def rot(a, Phi, Fq):
    """
    Matrix of multiplication by ``a`` on coefficient vectors over ``Fq``.

    Column j is cf(a * u^j) = F^j * cf(a), where F = shift_matrix(Fq, Phi, d) is
    the companion matrix of Phi. Hence rot(a) * cf(b) == cf(a * b).

    Args:
        a: ring element (converted with ``cf``) or an explicit coefficient list.
        Phi: modulus polynomial; assumed monic of degree d = len(cf(a)).
        Fq: base field.

    Returns:
        d x d matrix over ``Fq``.
    """
    cf_a = vector(Fq, a if isinstance(a, list) else cf(a))
    d = len(cf_a)
    F = shift_matrix(Fq, Phi, d)

    # Each column is F times the previous one: one matrix-vector product per step
    # instead of tracking explicit powers F^j (a d x d matrix product per step).
    columns = [cf_a]
    for _ in range(1, d):
        columns.append(F * columns[-1])

    # Entry (i, j) is coefficient i of column j; matrix() takes row-major entries.
    return matrix(Fq, d, d, [columns[j][i] for i in range(d) for j in range(d)])
    
# print(shift_matrix(Fq, Phi, d))
rot_a = rot(a, Phi, Fq)
b = Rq.random_element()
# rot(a) applied to cf(b) must give the coefficients of the ring product a * b.
assert vector(Fq, cf(a * b)) == rot_a * vector(Fq, cf(b))

# F7 = GF(7)

# y = polygen(F7, 'y')
# Phi_test = y ** 3 + 1
# R7.<u> = F7.extension(Phi_test)

# a = 4*u**2 + 2*u + 3
# rot_a = rot(a, Phi_test, F7)
# print(rot_a)
# b = u**2 + 1
# print("cf(a*b)", cf(a*b))
# print("rot_a * vector(F7, cf(b)", rot_a * vector(F7, cf(b)))
# assert rot_a * vector(F7, cf(b)) == vector(F7, cf(a*b))

# 3.2 Neo’s solution, part-1: A matrix commitment scheme

def embed_Fq_to_Rq(z_i, d):
    """
    Embed z_i in Fq into Rq, using its binary digits as polynomial coefficients.

    Returns sum_{k<d} bit_k(z_i) * u^k, where bit_k is the k-th least-significant
    bit of the integer representative of z_i. Bits at positions >= d are dropped,
    so the map is only injective on values below 2^d.
    """
    value = int(z_i)
    poly = sum(Fq((value >> k) & 1) * u**k for k in range(d))
    return Rq(poly)

z_i = Fq.random_element()
z_i_embedded = embed_Fq_to_Rq(z_i, d)

# Show the field element next to its bit-embedding in Rq.
print("\n".join([
    f"Original Fq element: {z_i}",
    f"Embedded Rq element: {z_i_embedded}",
]))


Original Fq element: 1378334593571236180
Embedded Rq element: u^53 + u^47 + u^46 + u^44 + u^42 + u^33 + u^31 + u^28 + u^26 + u^25 + u^24 + u^22 + u^21 + u^20 + u^17 + u^13 + u^11 + u^8 + u^6 + u^4 + u^2


In [7]:
from random import randint


def decomp_b(z, d=None, b=2, check=False):
    """
    Base-``b`` digit decomposition of a vector (Decomp_b).

    Column j of the result holds the little-endian base-``b`` digits of
    ``int(z[j])``. Hence ``z[j] == sum_i b^i * Z[i, j]`` whenever every entry fits
    in ``d`` digits; otherwise the higher digits are dropped (unless ``check``).

    Args:
        z: Sage vector; entries are converted with ``int()``.
        d: number of digits (rows). If None, the smallest d >= 1 with
            b^d > max |z_j| is used (d = 1 for an empty vector).
        b: base, default 2 (the value the rest of this notebook relies on).
        check: if True, raise instead of silently truncating entries that need
            more than ``d`` digits.

    Returns:
        d x len(z) matrix over ``z.base_ring()``.

    Raises:
        ValueError: when ``check`` is True and some entry is outside [0, b^d).
    """
    F = z.base_ring()
    values = [int(zi) for zi in z]
    m = len(values)

    if d is None:
        # Smallest d such that b^d > max |z_j|.
        max_val = max((abs(val) for val in values), default=0)
        d = 1
        while b**d <= max_val:
            d += 1

    if check:
        bound = b**d
        bad = [j for j, val in enumerate(values) if not 0 <= val < bound]
        if bad:
            raise ValueError(
                f"Entries at positions {bad[:5]} do not fit in {d} base-{b} digits"
            )

    # Row i holds digit i of every entry; matrix() expects row-major entries.
    powers = [b**i for i in range(d)]
    entries = [(val // p) % b for p in powers for val in values]
    return matrix(F, d, m, entries)


z_1, z_2, z_3 = Fq(123), Fq(567), Fq(890)
m = 3
z_1_embedded, z_2_embedded, z_3_embedded = (embed_Fq_to_Rq(zj, d) for zj in (z_1, z_2, z_3))
z_vec = vector(Fq, [z_1, z_2, z_3])
print(z_vec)
Z = decomp_b(z_vec, d)
print("len(cf(z_1_embedded))", len(cf(z_1_embedded)))
print("len([Z[i, 0] for i in range(d)])", len(list(Z.column(0))))

# Column j of Decomp_b(z) must coincide with the bit-embedding of z_j.
for col, emb in enumerate((z_1_embedded, z_2_embedded, z_3_embedded)):
    assert cf(emb) == list(Z.column(col))

# Recombining the bits with powers of two must give back z.
recombined = [Fq(sum(2**i * Z[i, col] for i in range(d))) for col in range(m)]
print("check sum", recombined)
assert list(z_vec) == recombined


def split_b(Z, k=None, b=2):
    """
    Split a matrix into base-``b`` digit matrices (split_b).

    Returns [Z_0, ..., Z_{k-1}] with entries in {0, ..., b-1} such that
    Z == sum_l b^l * Z_l whenever every entry of Z fits in k digits.

    Args:
        Z: Sage matrix; entries are converted with ``int()``.
        k: number of digit matrices. If None, the smallest k >= 1 with
            b^k > max |Z_ij| is used (k = 1 for an empty matrix). With an explicit
            k, digits beyond position k - 1 are dropped.
        b: base, default 2.

    Returns:
        list of k matrices over ``Z.base_ring()`` with the same shape as Z.
    """
    F = Z.base_ring()
    d, m = Z.dimensions()
    values = [[int(Z[i, j]) for j in range(m)] for i in range(d)]

    if k is None:
        # Smallest k such that b^k > max |Z_ij|.
        max_val = max((abs(val) for row in values for val in row), default=0)
        k = 1
        while b**k <= max_val:
            k += 1

    result = [matrix(F, d, m) for _ in range(k)]
    powers = [b**l for l in range(k)]
    for i, row in enumerate(values):
        for j, value in enumerate(row):
            for l, p in enumerate(powers):
                result[l][i, j] = F((value // p) % b)
    return result


# Random low-norm d x m test matrix: entries uniform in [0, 2^20), drawn with
# Python's random.randint (no seed is set in this section).
max_val = 2**20  # low norm
test = matrix(Fq, d, m)
for i in range(d):
    for j in range(m):
        test[i, j] = Fq(randint(0, max_val - 1))
# Split into bit matrices (depth chosen automatically), then recombine:
# sum_i 2^i * Z_i must reproduce `test` exactly.
# NOTE: the loop below leaves `i` and `z` bound at module level (z = last bit matrix).
Z = split_b(test)
Z_restored = matrix(Fq, d, m)
for i, z in enumerate(Z):
    Z_restored += (2**i) * z

print(Z_restored)
assert test == Z_restored

(123, 567, 890)
len(cf(z_1_embedded)) 54
len([Z[i, 0] for i in range(d)]) 54
check sum [123, 567, 890]
[1047722  150796  211233]
[ 485043  643023   25831]
[ 339394  925526  662704]
[ 783104  773458 1011578]
[ 228050   87501   88773]
[ 188043  125530  232371]
[ 497316  129817  489122]
[ 951648  612855  773379]
[ 381264  503594  190942]
[ 637890  242907  414476]
[ 723484   48663  721572]
[ 869509  798124  749497]
[ 459476  726614  543995]
[ 146510  207434   25826]
[  97277 1008155  957948]
[1020778  829432  539724]
[ 351266   68112  282399]
[ 702936  857455  340792]
[ 531516  120035   57360]
[ 997933  879425 1005977]
[ 226323  769729  150862]
[  53641  677472  209401]
[ 865588  710635   97320]
[ 148095  201818  537074]
[ 202753  932330  364065]
[ 180784  650909   74555]
[ 332316  141858  249731]
[  84863  431254  218257]
[ 741441   21909  199260]
[ 923714  109021  943631]
[ 154807 1018065  993999]
[1040776     180 1010799]
[ 447093  758092  311498]
[ 882433  349964  743087]
[ 346188  281

In [21]:
import json

with open("../mimc_r1cs.json", "r") as f:
    r1cs = json.load(f)

# "Without loss of generality, assume that m = n and n, d · n are both powers of two"
# -> zero-pad the witness up to the next power of two.
z = vector(Fq, [int(w, base=16) for w in r1cs["witness"]])
print(z)
m = len(z)
if m & (m - 1) != 0:
    next_pow2 = 1 << int(m - 1).bit_length()
    print("next_pow2", next_pow2)
    z = vector(Fq, list(z) + [0] * (next_pow2 - m))
    m = len(z)
    print("m", m)


# Bit-decompose z into a d x m matrix (column j = little-endian bits of z[j]).
# NOTE(review): Fq elements can need up to 61 bits (q = 2^61 - 1) but only d bits
# are kept (d = 54 in the ring setup). For large witness entries the high bits are
# dropped, so sum_i 2^i Z[i, j] != z[j].
Z = decomp_b(z, d)
print(Z.dimensions()[0])
assert d == Z.dimensions()[0]
m = Z.dimensions()[1]
print("Z.dimensions()", Z.dimensions())
# Every digit must be a bit; `!= 0 or != 1` would accept any value.
assert all(
    Z[i, j] == 0 or Z[i, j] == 1 for i in range(Z.nrows()) for j in range(Z.ncols())
)

# In particular, each column of a low-norm matrix gets mapped as the coefficients of a single ring element.
z_prime = vector(Rq, [cf_inv(list(Z.column(j))) for j in range(Z.ncols())])
print("z_prime", z_prime)
print(type(z_prime))
pp = random_matrix(Rq, kappa, m)
print("pp.dimensions()", pp.dimensions())
print("type(pp)", type(pp))
print("z_prime", type(z_prime))
c = pp * z_prime

# Commitment as a d x kappa matrix over Fq: column i holds cf(c[i]).
# A comprehension (not `for x in c`) keeps the polygen `x` from being overwritten.
rows = [cf(ci) for ci in c]
C_commitment = matrix(Fq, kappa, d, sum([list(row) for row in rows], [])).transpose()
print("C_commitment.dimensions()", C_commitment.dimensions())

(1, 16962, 291, 297666009, 5135631653277, 1808519078093316047, 984333915475778258, 1685566623314468743, 1642652937557983344, 1527641843700880957, 185943044122613080, 483079338772295503, 933456805013540041, 2197720272769521675, 1617993589254506848, 1761337032769114693, 375577925966181255, 1046385115991257376, 176654108289591058, 1372277634735931727, 799399426023973936, 226885281797858965, 728260780311052279, 893721199336006157, 55989593605212516, 438213488586031957, 334767213486906731, 953996857315648675, 727299636715805004, 1175995165663612982, 825778734511341010, 385884965558464826, 667143447155798967, 177614530649800786, 717186420647209007, 909848238614490367, 998442717425403927, 2243851849076574385, 384572664689713169, 179963909640428448, 549303530200489456, 716229922492233319, 50591820186508235, 110942631861744526, 252351407643527483, 738163373585794680, 1840812092701698002, 339780363503755214, 1464484043900980110, 710260083635686570, 1749002717549668216, 1214100501955449283, 10956

In [9]:
# CCS Reduction (ΠCCS).
# This reduction takes a new MCS(b,L) instance (your new computation) and k-1 existing ME(b,L) claims, reducing them to k new ME(b,L) claims
# k: number of ME claims produced by the CCS reduction (see the comment above).
# T: NOTE(review) its meaning is not established in this section — presumably a
# bound taken from the paper; confirm.
k, T = 12, 216
### Definition 17 (Matrix constraint system relation).


# "Without loss of generality, assume that m = n and n, d · m are both powers of two"
def extend_matrix(M, size):
    """
    Zero-pad ``M`` into a ``size`` x ``size`` matrix with the original entries
    in the top-left corner.

    Returns ``M`` itself (no copy) when it is already ``size`` x ``size``.

    Raises:
        ValueError: if ``size`` is smaller than either dimension of ``M``.
    """
    rows, cols = M.dimensions()

    if size < max(rows, cols):
        raise ValueError(
            f"Target size {size} is too small for matrix with dimensions {rows}x{cols}"
        )
    if (rows, cols) == (size, size):
        return M

    padded = matrix(M.base_ring(), size, size)
    for r_idx in range(rows):
        for c_idx in range(cols):
            padded[r_idx, c_idx] = M[r_idx, c_idx]
    return padded


# check r1cs constraint
A = matrix(Fq, r1cs["A"])
B = matrix(Fq, r1cs["B"])
C = matrix(Fq, r1cs["C"])
# Unpadded witness (z from the previous cell is the zero-padded copy). It is named
# `witness` rather than `x` so the polynomial generator `x` is not overwritten.
witness = vector(Fq, [int(w, base=16) for w in r1cs["witness"]])
print(witness)
M0x = A * witness
M1x = B * witness
M2x = C * witness
print(M0x)
print(M1x)
print(M2x)
hadamard = vector(Fq, [a_i * b_i for a_i, b_i in zip(M0x, M1x)])
print(hadamard)
assert hadamard == M2x, f"R1CS failed: {hadamard} != {M2x}"
print("R1CS constraints are satisfied!")

M0 = extend_matrix(A, m)
M1 = extend_matrix(B, m)
M2 = extend_matrix(C, m)


print("M0.dimensions()", M0.dimensions())
print("M1.dimensions()", M1.dimensions())
print("M2.dimensions()", M2.dimensions())
print("len(z)", len(z))

M0z = M0 * z
M1z = M1 * z
M2z = M2 * z
print("len(M0z)", len(M0z))


M0z_tilde = multilinear_extension(M0z, Fq2)
M1z_tilde = multilinear_extension(M1z, Fq2)
M2z_tilde = multilinear_extension(M2z, Fq2)
# Check if R1CS constraint holds on the padded system: (A·z) ⊙ (B·z) = C·z
hadamard_product = vector(Fq, [M0z[i] * M1z[i] for i in range(len(M0z))])
print(hadamard_product)
print(M2z)
constraint_satisfied = hadamard_product == M2z
print("constraint_satisfied", constraint_satisfied)
assert constraint_satisfied, "padded R1CS constraint (M0 z) * (M1 z) = M2 z does not hold"
F = M0z_tilde * M1z_tilde - M2z_tilde

# F must vanish on the whole Boolean hypercube {0,1}^num_vars. num_vars is
# derived from the data rather than hard-coded.
num_vars = len(M0z).bit_length() - 1
bit_combinations = list(product([0, 1], repeat=num_vars))
for bits in bit_combinations:
    assert F(*bits) == 0

[1, 16962, 291, 297666009, 5135631653277, 1808519078093316047, 984333915475778258, 1685566623314468743, 1642652937557983344, 1527641843700880957, 185943044122613080, 483079338772295503, 933456805013540041, 2197720272769521675, 1617993589254506848, 1761337032769114693, 375577925966181255, 1046385115991257376, 176654108289591058, 1372277634735931727, 799399426023973936, 226885281797858965, 728260780311052279, 893721199336006157, 55989593605212516, 438213488586031957, 334767213486906731, 953996857315648675, 727299636715805004, 1175995165663612982, 825778734511341010, 385884965558464826, 667143447155798967, 177614530649800786, 717186420647209007, 909848238614490367, 998442717425403927, 2243851849076574385, 384572664689713169, 179963909640428448, 549303530200489456, 716229922492233319, 50591820186508235, 110942631861744526, 252351407643527483, 738163373585794680, 1840812092701698002, 339780363503755214, 1464484043900980110, 710260083635686570, 1749002717549668216, 1214100501955449283, 10956

In [33]:
# 3.3 Neo's solution, part-2: linear homomorphism for folding multilinear evaluation claims
# Test Procedure
# 1. **Start with a known correct witness $z$** for your R1CS.
# 2. **Compute $M_j z \in \mathbb{F}^m$** for your constraint matrix $M_j \in \mathbb{F}^{m \times n}$.
# 3. **Evaluate the multilinear extension** of $M_j z$ at point $r$:

#    $$
#    y_1 := \tilde{M_j z}(r) = \langle M_j z, \hat{r} \rangle
#    $$

# 4. **Decompose $z \rightarrow Z \in \mathbb{F}^{d \times n}$** using `Decomp_b(z)`.

# 5. **Compute** $Z M_j^\top \in \mathbb{F}^{d \times m}$

# 6. **Multiply with $\hat{r}$**:

#    $$
#    y_2 := (Z M_j^\top) \cdot \hat{r}
#    $$

# 7. **Recombine bits** (i.e., use the base $b$ powers):

#    $$
#    \boxed{
#    y := \sum_{i=0}^{d-1} b^i \cdot y_2^{(i)} \quad \text{(dot product per slice)}
#    }
#    $$

# 8. Then:

#    $$
#    \boxed{ y = y_1 }
#    \quad \text{(This is your invariant)}
#    $$
n = M0.ncols()
log_n = n.bit_length() - 1
r = vector(Fq2, [Fq2.random_element() for _ in range(log_n)])

# r_hat (r_tilde) in the same index convention as multilinear_extension:
# r_tilde[i] = prod_j (r[j] if bit j of i is 1 else 1 - r[j]), bit 0 = LSB.
# Doubling step j appends the "bit j = 1" half after the "bit j = 0" half.
# Starting from [r[0], 1 - r[0]] would put r[0] on the most significant bit with
# inverted polarity, and <M_j z, r_tilde> would no longer equal M_j z~(r).
r_tilde = [Fq2(1)]
for j in range(log_n):
    r_tilde = [t * (1 - r[j]) for t in r_tilde] + [t * r[j] for t in r_tilde]
assert len(r_tilde) == n

# Steps 1-3: y1 = (M_0 z)~(r) must equal <M_0 z, r_hat>.
y1 = M0z_tilde(*r)
assert y1 == sum(a_i * e_i for a_i, e_i in zip(M0z, r_tilde))

# TODO: steps 4-8 (Z * M_0^T * r_hat recombined with powers of b) reproduce y1
# only if decomp_b recovers z exactly. With d = 54 bits per entry and
# q = 2^61 - 1 this fails for large witness entries.

In [None]:
## Definition 18 (Matrix evaluation relation).


class S_R1CS:
    """
    R1CS structure: three matrices, their multilinear extensions over ``Fq2`` and
    the polynomial f = M~_0 * M~_1 - M~_2.

    Args:
        Fq2: ring over which the extensions are taken.
        A, B, C: the R1CS matrices, each with a power-of-two number of entries.
            Plain vectors (e.g. M_j z) are also accepted and use the vector MLE.

    NOTE(review): in CCS, f is usually a polynomial in the values M_j z rather
    than in the matrix extensions themselves — confirm the intended use.
    """

    def __init__(self, Fq2, A, B, C):
        self.M = [A, B, C]
        # Matrices need the matrix MLE (row-major flattening). multilinear_extension
        # only handles flat vectors, so keep it for vector inputs.
        self.M_tilde = [
            multilinear_matrix_extension(M_j, Fq2)
            if hasattr(M_j, "nrows")
            else multilinear_extension(M_j, Fq2)
            for M_j in self.M
        ]
        # Use this instance's extensions, not the module-level `M_tilde`.
        self.f = self.M_tilde[0] * self.M_tilde[1] - self.M_tilde[2]


class ME:
    """
    Matrix-evaluation helper: holds public parameters and commits to a witness.

    ``commit`` does the following:
    - zero-pads z to a power-of-two length;
    - bit-decomposes it into a d x m matrix Z (``decomp_b``);
    - maps each column of Z to a ring element via ``cf_inv``;
    - stores ``self.C``, the d x kappa matrix whose columns are the coefficient
      vectors of ``pp * z'``.

    Args:
        Fq: base field.
        Rq: ring that the columns of Z are mapped into.
        s: stored as ``self.s``; not used by the methods of this class.
        d: number of digits, i.e. rows of Z.
        kappa: number of rows of the commitment matrix ``pp``.
        pp: optional kappa x m matrix over Rq; if None, a random one is drawn on first commit.
    """

    def __init__(self, Fq, Rq, s, d, kappa, pp=None):
        self.s = s
        self.d = d
        self.kappa = kappa
        self.Fq = Fq
        self.Rq = Rq
        self.pp = pp

    def commit(self, z, m_in=None):
        """
        Commit to witness ``z``; sets z, m, m_in, pp (if missing), Z, X, c and C.

        Args:
            z: vector over Fq. It is zero-padded to the next power-of-two length;
                ``self.z`` keeps the unpadded input.
            m_in: number of leading columns of Z exposed as ``self.X``
                (all columns when None or 0).
        """
        self.z = z

        # "Without loss of generality, assume that m = n and n, d · n are both powers of two"
        m = len(z)
        if m & (m - 1) != 0:
            next_pow2 = 1 << int(m - 1).bit_length()
            z = vector(self.Fq, list(z) + [0] * (next_pow2 - m))
            m = len(z)
        self.m = m
        # Always define m_in so the slice below works when a value is passed in.
        self.m_in = m_in if m_in else m

        # Compare with None explicitly: a Sage matrix is falsy when all-zero, so
        # `not self.pp` would also replace a legitimately supplied zero matrix.
        if self.pp is None:
            self.pp = random_matrix(self.Rq, self.kappa, self.m)

        self.Z = decomp_b(z, d=self.d)
        self.X = self.Z[:, : self.m_in]
        assert self.d == self.Z.dimensions()[0]

        # Base-2 digits: every entry must be 0 or 1.
        for i in range(self.Z.nrows()):
            for j in range(self.Z.ncols()):
                assert self.Z[i, j] == 0 or self.Z[i, j] == 1

        # Each column of a low-norm matrix gets mapped as the coefficients of a single ring element.
        z_prime = vector(
            self.Rq, [cf_inv(list(self.Z.column(j))) for j in range(self.Z.ncols())]
        )
        print("z_prime", z_prime)

        self.c = self.pp * z_prime
        # Row i = cf(c[i]); after transposing, column i holds those coefficients.
        coeff_rows = [cf(ci) for ci in self.c]
        self.C = matrix(
            self.Fq, self.kappa, self.d, sum([list(row) for row in coeff_rows], [])
        ).transpose()


# Re-commit to the same padded witness with the same pp: the class must
# reproduce the commitment computed by hand earlier.
me = ME(Fq, Rq, None, d, kappa, pp)
me.commit(z)
assert me.C == C_commitment