In [118]:
import os
import sys
module_path = os.path.abspath(os.path.join('./sage-snark'))
sys.path.insert(0, module_path)

from IPython.core.display import SVG

from spartan.utils import eq_bit_decomp, hadamard_product, hypercube_sum, bit_decomp_dict

In [119]:
from sage.matrix.all import Matrix
from sage.rings.integer import Integer
from sage.rings.polynomial.all import PolynomialRing, Polynomial
from sage.modules.all import vector
from sage.rings.rational_field import QQ
from typing import List, Any

In [120]:
# Sample R1CS instance (coefficient matrices A, B, C and witness W).
# Source: https://emirsoyturk.medium.com/hello-arithmetization-55e57c8e5471
#
# Row k of each matrix describes constraint k:  (A[k]·W) * (B[k]·W) == C[k]·W.
# With x = W[1] the four rows read
#   x * x          == W[3]
#   W[3] * x       == W[4]
#   (x + W[4]) * 1 == W[5]
#   (5 + W[5]) * 1 == W[2]
# i.e. the witness proves x^3 + x + 5 == 35 for x = 3.

AL = [[0, 1, 0, 0, 0, 0],
      [0, 0, 0, 1, 0, 0],
      [0, 1, 0, 0, 1, 0],
      [5, 0, 0, 0, 0, 1]]

BL = [[0, 1, 0, 0, 0, 0],
      [0, 1, 0, 0, 0, 0],
      [1, 0, 0, 0, 0, 0],
      [1, 0, 0, 0, 0, 0]]

CL = [[0, 0, 0, 1, 0, 0],
      [0, 0, 0, 0, 1, 0],
      [0, 0, 0, 0, 0, 1],
      [0, 0, 1, 0, 0, 0]]

# Witness layout: [1, x, out, x^2, x^3, x^3 + x]
WL = [1, 3, 35, 9, 27, 30]

In [121]:
def matrix_multilinearize(mat : Matrix, PolyRing : PolynomialRing = None) -> Polynomial:
    """Return the multilinear extension of ``mat`` in row (X) and column (Y) variables.

    The result is ``sum_{i,j} mat[i, j] * eq_i(X) * eq_j(Y)``, where ``eq_i`` is the
    basis polynomial produced by ``eq_bit_decomp`` for index ``i``.

    Rows use ``Integer(mat.nrows()).bit_length()`` variables and columns use
    ``Integer(mat.ncols()).bit_length()`` variables.

    Args:
        mat: the matrix to extend.
        PolyRing: optional ring to build the polynomial in, e.g. the parent of an
            earlier result so several matrices share one ring. Its first
            ``nrows`` bit-length generators are the row variables; all remaining
            generators are the column variables. When omitted, a new ring over
            ``mat.base_ring()`` with generators ``X0, X1, ..., Y0, Y1, ...`` is created.

    Returns:
        The multilinear polynomial, as an element of ``PolyRing``.

    Raises:
        ValueError: if a supplied ``PolyRing`` has fewer generators than the row and
            column bit-lengths together require.
    """
    nrows_bitlen = Integer(mat.nrows()).bit_length()
    ncols_bitlen = Integer(mat.ncols()).bit_length()
    var_count = nrows_bitlen + ncols_bitlen

    if PolyRing is None:
        names = [f"X{i}" for i in range(nrows_bitlen)] + [f"Y{i}" for i in range(ncols_bitlen)]
        PolyRing = PolynomialRing(mat.base_ring(), var_count, names)
    elif PolyRing.ngens() < var_count:
        # Too few generators would silently produce a polynomial over the wrong variables.
        raise ValueError(
            f"PolyRing has {PolyRing.ngens()} generators, but a "
            f"{mat.nrows()}x{mat.ncols()} matrix needs at least {var_count}"
        )

    gens = PolyRing.gens()
    x_vars = list(gens[:nrows_bitlen])
    y_vars = list(gens[nrows_bitlen:])

    # Only indices that actually occur in the matrix need a basis polynomial.
    x_lagrange_basis = [eq_bit_decomp(PolyRing, i, x_vars) for i in range(mat.nrows())]
    y_lagrange_basis = [eq_bit_decomp(PolyRing, j, y_vars) for j in range(mat.ncols())]

    poly = PolyRing(0)
    # mat.dict() yields only the non-zero entries; zero entries contribute nothing,
    # and the R1CS matrices in this notebook are mostly zeros.
    for (i, j), entry in mat.dict().items():
        poly += entry * x_lagrange_basis[i] * y_lagrange_basis[j]

    return poly

# Sanity check: building into an explicitly supplied ring must give the same
# polynomial as letting matrix_multilinearize construct its own ring.
m = Matrix(QQ, [
    [7, 2, 2],
    [2, 3, 4],
    [3, 1, 1],
    [7, 4, 5],
    [8, 8, 4],
])
x1 = matrix_multilinearize(m)
x2 = matrix_multilinearize(m, PolyRing=x1.parent())
assert x2 == x1

In [122]:
def vec_multilinearize(vec : List[Any], gens : List[Any]) -> Polynomial:
    """Return the multilinear extension of ``vec`` over the variables ``gens``.

    Computes ``sum_i vec[i] * eq_bit_decomp(R, i, gens)``, where ``R`` is the
    parent ring of ``gens[0]``.

    Args:
        vec: coefficients; at most ``2**len(gens)`` of them.
        gens: generators of the ring the result lives in; must be non-empty
            because the ring is taken from ``gens[0]``.

    Returns:
        A polynomial in ``gens[0].parent()``; the zero polynomial if ``vec`` is empty.
    """
    assert len(vec) <= 2**len(gens)
    ring = gens[0].parent()
    terms = (coeff * eq_bit_decomp(ring, idx, gens) for idx, coeff in enumerate(vec))
    return sum(terms, ring(0))

In [123]:
def compute_y_sum(poly, y_dim):
    """Apply ``hypercube_sum`` to ``poly``, skipping its leading (row) variables.

    The trailing ``Integer(y_dim).bit_length()`` generators of ``poly``'s ring are
    treated as the Y (column) variables. All generators before them are passed
    to ``hypercube_sum`` as the variables to skip. Presumably this means they are
    left free while the Y variables are summed over the boolean hypercube; confirm
    against ``spartan.utils.hypercube_sum``.

    Args:
        poly: a multivariate polynomial whose ring lists X generators before Y generators.
        y_dim: the length of the Y dimension (e.g. the matrix column count).

    Returns:
        Whatever ``hypercube_sum`` returns for these arguments.
    """
    ring_vars = poly.parent().gens()
    y_var_count = Integer(y_dim).bit_length()
    x_var_count = len(ring_vars) - y_var_count
    assert x_var_count >= 0
    return hypercube_sum(poly, ring_vars[:x_var_count])

In [124]:
# Prime field F_P with P = 15 * 2^27 + 1 (= 2^31 - 2^27 + 1, the "BabyBear" prime).
P = 15*(2**27) + 1
assert is_prime(P)
Fp = GF(P)

# Lift the integer R1CS instance into F_P.
A, B, C = (Matrix(Fp, rows) for rows in (AL, BL, CL))
w = vector(Fp, WL)

# Evaluate every constraint row of each matrix against the witness.
aw, bw, cw = A*w, B*w, C*w

# R1CS satisfaction: C*w must equal the entrywise product of A*w and B*w.
assert list(cw) == hadamard_product(aw, bw)

In [125]:
# A, B and C must share one shape so their extensions can live in one ring.
assert A.dimensions() == B.dimensions() == C.dimensions(), "A, B, C must have the same shape"

# Multilinear extensions of A, B, C in a single shared ring (X* = row bits, Y* = column bits).
Axy = matrix_multilinearize(A)
Bxy = matrix_multilinearize(B, Axy.parent())
Cxy = matrix_multilinearize(C, Axy.parent())

# Split the shared generators into row (X) and column (Y) variables, using the
# same bit-length rule matrix_multilinearize applies to the row count.
xgens_count = Integer(A.nrows()).bit_length()
xgens = Axy.parent().gens()[:xgens_count]
ygens = Axy.parent().gens()[xgens_count:]

Wy = vec_multilinearize(w, ygens)

# The number of Y variables follows from the column count; derive it from A
# rather than hard-coding it, so the check keeps working if the instance changes.
y_dim = A.ncols()
assert len(w) == y_dim, "witness length must match the number of matrix columns"

Ax = compute_y_sum(Axy*Wy, y_dim)
Bx = compute_y_sum(Bxy*Wy, y_dim)
Cx = compute_y_sum(Cxy*Wy, y_dim)

expected = Ax*Bx - Cx

# The identity Ax*Bx - Cx must vanish at the bit decomposition of every constraint row.
for i in range(A.nrows()):
    d = bit_decomp_dict(i, xgens)
    e = expected.subs(d)
    assert e == 0, f"R1CS constraint {i} not satisfied: Ax*Bx - Cx = {e}"


![Spartan Attempt 0](./spartan/Spartan-0.drawio.svg)