In [1]:
import itertools
from functools import reduce
from itertools import product

import operator

K = Fields()

q = 2**61 - 1
Fq = GF(q)
x = polygen(Fq, 'x')
Fq2.<u> = Fq.extension(x^2+1)
log_n = 4
n = 2**log_n
v = vector(Fq, [Fq.random_element() for _ in range(n)])

def get_bit(n, k):
    return (n >> k) & 1

def int_to_bits(n, bit_length=None):
    if bit_length is None:
        bit_length = n.bit_length() if n != 0 else 1
    return [(n >> i) & 1 for i in range(bit_length)]

def multilinear_extension(v, F, var_names=None):
    n = len(v)
    m = n.bit_length() - 1
    if 2 ** m != n:
        raise ValueError("Vector length must be a power of 2.")
    
    if var_names is None:
        var_names = [f'X{i}' for i in range(1, m+1)]
    R = PolynomialRing(F, var_names)
    X = R.gens()
    p = R.zero()
    
    for i in range(n):
        term = F(v[i]) 
        bits = int_to_bits(i,m) 
        
        for j in range(m):
            if bits[j] == 1:
                term *= X[j]
            else:
                term *= (1 - X[j])
        p += term
    return p

v_tilde = multilinear_extension(v, Fq)

for i, t in enumerate(v):
  assert t == v_tilde([j for j in int_to_bits(i, log_n)])


def multilinear_matrix_extension(M, F, var_names=None):
    d, n = M.dimensions()
    total_elements = d * n
    
    m = total_elements.bit_length() - 1
    if 2^m != total_elements:
        raise ValueError("Matrix dimensions must multiply to a power of 2 (d * n = 2^m).")
    
    # Flatten the matrix into a vector (row-major order)
    v = M.list()
    
    if var_names is None:
        var_names = [f'X{i}' for i in range(1, m+1)]
    R = PolynomialRing(F, var_names)
    X = R.gens()
    
    p = R.zero()
    for i in range(total_elements):
        term = F(v[i])  
        bits = int_to_bits(i,m) 
        
        for j in range(m):
            if bits[j] == 1:
                term *= X[j]
            else:
                term *= (1 - X[j])
        
        p += term
    
    return p

M = Matrix(Fq, [[ Fq.random_element() for _ in range(n)] for _ in range(n)])
M_tilde = multilinear_matrix_extension(M, Fq)

for i, row in enumerate(M):
  for j, e in enumerate(row):
    assert M[i][j] == M_tilde([y for y in int_to_bits(j, log_n)]+[x for x in int_to_bits(i, log_n)])


# Build multi-linear basis polynomials
var_names = [f'X{i}' for i in range(log_n)]
R = PolynomialRing(Fq2, var_names)
X = R.gens()
print(X)
print([X[i] for i in range(log_n)])
basis = []
for i in range(n):
    bits = [(i >> j) & 1 for j in range(log_n)]
    poly = 1
    for b, x in zip(bits, [X[i] for i in range(log_n)]):
        poly *= x if b else (1 - x)
    basis.append(poly)

print(basis)

tilde_f_r = sum(v[i] * basis[i] for i in range(n))
v_tilde2 = multilinear_extension(v, Fq2)
for i, t in enumerate(v):
  params = [j for j in int_to_bits(i, log_n)]
  assert tilde_f_r(params) == v_tilde(params)


(X0, X1, X2, X3)
[X0, X1, X2, X3]
[X0*X1*X2*X3 - X0*X1*X2 - X0*X1*X3 - X0*X2*X3 - X1*X2*X3 + X0*X1 + X0*X2 + X1*X2 + X0*X3 + X1*X3 + X2*X3 - X0 - X1 - X2 - X3 + 1, -X0*X1*X2*X3 + X0*X1*X2 + X0*X1*X3 + X0*X2*X3 - X0*X1 - X0*X2 - X0*X3 + X0, -X0*X1*X2*X3 + X0*X1*X2 + X0*X1*X3 + X1*X2*X3 - X0*X1 - X1*X2 - X1*X3 + X1, X0*X1*X2*X3 - X0*X1*X2 - X0*X1*X3 + X0*X1, -X0*X1*X2*X3 + X0*X1*X2 + X0*X2*X3 + X1*X2*X3 - X0*X2 - X1*X2 - X2*X3 + X2, X0*X1*X2*X3 - X0*X1*X2 - X0*X2*X3 + X0*X2, X0*X1*X2*X3 - X0*X1*X2 - X1*X2*X3 + X1*X2, -X0*X1*X2*X3 + X0*X1*X2, -X0*X1*X2*X3 + X0*X1*X3 + X0*X2*X3 + X1*X2*X3 - X0*X3 - X1*X3 - X2*X3 + X3, X0*X1*X2*X3 - X0*X1*X3 - X0*X2*X3 + X0*X3, X0*X1*X2*X3 - X0*X1*X3 - X1*X2*X3 + X1*X3, -X0*X1*X2*X3 + X0*X1*X3, X0*X1*X2*X3 - X0*X2*X3 - X1*X2*X3 + X2*X3, -X0*X1*X2*X3 + X0*X2*X3, -X0*X1*X2*X3 + X1*X2*X3, X0*X1*X2*X3]


In [56]:
# 2.3 Rings and Modules
eta = 81
d = 54
q = 2**61 - 1
Fq = GF(q)
x = polygen(Fq, 'x')
Phi = x ** 54 + x ** 27 + 1
Rq.<u> = Fq.extension(Phi)
kappa = 16
## m = 2 ** 22 # TOO MUCH MEMORY
#m = 2 ** 8 # for testing purposes
#M = random_matrix(Rq, kappa, m)
a = Rq.random_element()

def cf(a):
    return a.list()

def cf_inv(a):
    return Rq(a)

def shift_matrix(Fq, Phi, d):
    F = matrix(Fq, d, d)
    c = cf(Phi)
    for i in range(d):
        for j in range(d):
            if i - 1 == j:
                F[i, j] = 1
            if j == d - 1:
                F[i, j] = -c[i]
    return F

def rot(a, Phi, Fq):
    cf_a = vector(Fq, cf(a))
    d = len(cf_a)
    F = shift_matrix(Fq, Phi, d)
    columns = [cf_a]
    F_power = F  # Start with F^1
    for _ in range(1, d):
        current = F_power * cf_a
        columns.append(current)
        F_power = F_power * F  # Compute next power incrementally
    M = matrix(Fq, len(cf_a), len(cf_a), sum([list(col) for col in columns], []))

    return M    # print([(F**i) * cf_a for i in range(1, d)])
    

#print(shift_matrix(Fq, Phi, d))
rot_a = rot(a, Phi, Fq)
b = Rq.random_element()
assert rot_a * vector(Fq, cf(b)) == cf(a*b)




AssertionError: 