## The problem lookup solves
Prove that every element of a table f is also contained in a table t.

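This is useful in many settings, such as proving the result of an addition, proving a hash output, or proving that a value lies in a range.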

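## The problem cq solves
- Make the cost of a lookup independent of the size of the large table. This is achieved through pre-computation with the FK algorithm.
- A lookup costs $O(n \log n)$ and the pre-computation costs $O(N \log N)$ (n: number of looked-up values, N: table size).
- It is additively homomorphic, which helps with accumulation.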

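## Basic idea of cq
- To prove that every element of F lies in T, we reduce the problem to proving that two collections contain the same elements: the values of F, and the elements of T, each repeated as many times as it is looked up.
- Equivalently, we build a polynomial from each collection and show that they have the same roots: $f(x) = (x-a_1)(x-a_2)(x-a_3)(x-a_4)\cdots(x-a_n)$, $g(x) = (x-b_1)(x-b_2)(x-b_3)(x-b_4)\cdots(x-b_n)$
- The positions of the $a_i$ and $b_i$ in their collections do not matter, and an element may appear more than once.
- We then need to prove $f(x) = g(x)$.
- The logarithmic derivative simplifies the computation: the logarithmic derivative $f'(x)/f(x)$ of $f(x) = (x-a_1)(x-a_2)\cdots(x-a_n)$ is $\frac{1}{x-a_1} + \frac{1}{x-a_2} + \cdots + \frac{1}{x-a_n}$. Repeated elements are handled by grouping equal terms: the numerator becomes the number of times that element appears.
- For example, with $a = \{1, 2, 3, 4\}$ and $b = \{1, 2, 1, 3\}$ we build $A(x) = \frac{2}{x-1} + \frac{1}{x-2} + \frac{1}{x-3} + \frac{0}{x-4}$ and $B(x) = \frac{1}{x-1} + \frac{1}{x-2} + \frac{1}{x-1} + \frac{1}{x-3}$; it then suffices to prove $A(x) = B(x)$.
- The proof uses sumcheck.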
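
As a quick sanity check of the two claims above, the sketch below uses exact rational arithmetic (Python's `fractions` module) at an arbitrarily chosen point $x = 7$. The evaluation point and the helper names are our own choices, and the real protocol works over a finite field rather than the rationals.

```python
from collections import Counter
from fractions import Fraction

table = [1, 2, 3, 4]      # a: the public table
witness = [1, 2, 1, 3]    # b: the values being looked up

def prod_linear(roots, x):
    """f(x) = prod (x - r) over the given roots."""
    out = Fraction(1)
    for r in roots:
        out *= (x - r)
    return out

def derivative(roots, x):
    """f'(x) by the product rule: sum over i of prod_{j != i} (x - r_j)."""
    return sum(prod_linear(roots[:i] + roots[i + 1:], x) for i in range(len(roots)))

x = Fraction(7)  # any point that is not a root

# 1. The logarithmic derivative of a product is the sum of reciprocals.
lhs = derivative(witness, x) / prod_linear(witness, x)
rhs = sum(Fraction(1) / (x - r) for r in witness)
assert lhs == rhs

# 2. Grouping repeated roots into multiplicities m_i gives A(x) = B(x).
m = Counter(witness)                                # {1: 2, 2: 1, 3: 1}; m[4] == 0
A = sum(Fraction(m[a]) / (x - a) for a in table)    # 2/(x-1) + 1/(x-2) + 1/(x-3) + 0/(x-4)
B = sum(Fraction(1) / (x - b) for b in witness)     # 1/(x-1) + 1/(x-2) + 1/(x-1) + 1/(x-3)
assert A == B
```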

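## How the pre-computation works
- The algorithm is called FK (Feist-Khovratovich).
- $T(X)$: the polynomial interpolated from the public table elements.
- Goal: use FK to compute, for every root of unity $\omega^i$, the KZG commitment to the quotient polynomial $\frac{T(X) - T(\omega^i)}{X - \omega^i}$.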

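This lets the prover compute the commitment to the quotient polynomial of $A(X)$ in $O(n \log n)$ time. Without FK, computing these quotient commitments directly costs $O(N^2)$.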
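
The identity behind this speed-up can be checked without any elliptic curves. Writing $A(X) = \sum_i A_i L_i(X)$ and $m(X) = \sum_i A_i (t_i + \beta) L_i(X)$, and using $L_i(X) = \frac{\omega^i}{N}\cdot\frac{X^N - 1}{X - \omega^i}$, the quotient $Q_A(X) = \frac{A(X)(T(X)+\beta) - m(X)}{X^N - 1}$ equals $\sum_i A_i \frac{\omega^i}{N}\cdot\frac{T(X) - t_i}{X - \omega^i}$. Only nonzero $A_i$ contribute, so with the quotient commitments precomputed the prover just forms a linear combination. The sketch below verifies this in the toy field $\mathbb{F}_{17}$ (our choice: $17 - 1$ is divisible by 4, so there are 4th roots of unity); all helper names are ours.

```python
P = 17                                   # toy prime field; N = 4 divides P - 1 = 16
N = 4
OMEGA = pow(3, (P - 1) // N, P)          # 3 generates F_17^*, so OMEGA = 13 has order 4
OMEGA_INV = pow(OMEGA, P - 2, P)
H = [pow(OMEGA, i, P) for i in range(N)]

def inv(x):
    return pow(x % P, P - 2, P)

def interpolate(values):
    """Inverse DFT: coefficients of the degree < N polynomial taking values[i] at OMEGA^i."""
    n_inv = inv(N)
    return [n_inv * sum(v * pow(OMEGA_INV, i * k, P) for i, v in enumerate(values)) % P
            for k in range(N)]

def poly_add(a, b):
    n = max(len(a), len(b))
    a, b = a + [0] * (n - len(a)), b + [0] * (n - len(b))
    return [(x + y) % P for x, y in zip(a, b)]

def poly_mul(a, b):
    out = [0] * (len(a) + len(b) - 1)
    for i, x in enumerate(a):
        for j, y in enumerate(b):
            out[i + j] = (out[i + j] + x * y) % P
    return out

def div_by_linear(a, z):
    """a(X) / (X - z) by synthetic division; asserts a(z) == 0."""
    q, carry = [0] * (len(a) - 1), 0
    for k in range(len(a) - 1, 0, -1):
        carry = (a[k] + z * carry) % P
        q[k - 1] = carry
    assert (a[0] + z * carry) % P == 0, "z is not a root"
    return q

def div_by_vanishing(a):
    """a(X) / (X^N - 1) by long division; asserts the remainder is zero."""
    a, q = a[:], [0] * max(len(a) - N, 1)
    for k in range(len(a) - 1, N - 1, -1):
        q[k - N] = a[k]
        a[k - N] = (a[k - N] + a[k]) % P
    assert all(c == 0 for c in a[:N]), "not divisible by X^N - 1"
    return q

beta, table, witness = 10, [1, 2, 3, 4], [1, 2, 1, 3]
m = [witness.count(t) for t in table]                          # multiplicities [2, 1, 1, 0]
A = [m_i * inv(beta + t_i) % P for m_i, t_i in zip(m, table)]  # A_i = m_i / (beta + t_i)
T_poly, A_poly, m_poly = interpolate(table), interpolate(A), interpolate(m)

# Direct: Q_A = (A(X) * (T(X) + beta) - m(X)) / (X^N - 1)
numer = poly_add(poly_mul(A_poly, poly_add(T_poly, [beta])), [(-c) % P for c in m_poly])
Q_A_direct = div_by_vanishing(numer)

# Decomposed: Q_A = sum_i A_i * (OMEGA^i / N) * (T(X) - t_i) / (X - OMEGA^i)
Q_A_sum = [0] * (N - 1)
for i, (a_i, t_i) in enumerate(zip(A, table)):
    Q_i = div_by_linear(poly_add(T_poly, [(-t_i) % P]), H[i])
    scale = a_i * H[i] * inv(N) % P
    Q_A_sum = poly_add(Q_A_sum, [scale * c % P for c in Q_i])

assert Q_A_direct == Q_A_sum
```

The scaling factor $\omega^i/N$ here plays the same role as `scale = roots[i]/table_len` in the FK cell further down.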

In [1]:
# public table
table = [1, 2, 3, 4]
# values to lookup
witness = [1, 2, 1, 3]

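We want to prove that every element of `witness` is in `table`, where `table` is public. Mathematically, we build $A(x) = \frac{2}{x-1} + \frac{1}{x-2} + \frac{1}{x-3} + \frac{0}{x-4}$ and $B(x) = \frac{1}{x-1} + \frac{1}{x-2} + \frac{1}{x-1} + \frac{1}{x-3}$. In practice we work with polynomials, and a polynomial is obtained by interpolating a set of points, so we need concrete values. We get them with a random challenge $\beta$, which turns each term of the expressions above into one value (point-value form).

Following the code below, each denominator is written as $\beta + t_i$ instead of $x - t_i$; this corresponds to setting $x = -\beta$ and negating both sides, so the equality is unaffected. With $\beta = 10$, $A(X)$ is the polynomial interpolating $\{2/(10+1), 1/(10+2), 1/(10+3), 0/(10+4)\}$ at the roots of unity, and $B(X)$ the one interpolating $\{1/(10+1), 1/(10+2), 1/(10+1), 1/(10+3)\}$.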
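
The point-value form with $\beta = 10$ can be reproduced exactly with rationals. This sketch uses the $\beta + t_i$ denominators of the code cells below (with `fractions.Fraction` standing in for field elements, our simplification) and shows that the two value lists have the same total, which is exactly what the sumcheck step has to establish.

```python
from collections import Counter
from fractions import Fraction

table = [1, 2, 3, 4]      # public table t
witness = [1, 2, 1, 3]    # looked-up values f
beta = 10                 # stands in for the verifier's random challenge

m = Counter(witness)                                   # multiplicities; m[4] == 0
A_values = [Fraction(m[t], beta + t) for t in table]   # A_i = m_i / (beta + t_i)
B_values = [Fraction(1, beta + f) for f in witness]    # B_j = 1 / (beta + f_j)

# The lists have different entries; only their totals have to agree.
assert sum(A_values) == sum(B_values)
```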

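We then interpolate these values into polynomials and use polynomial properties to prove that the two sides are equal. This requires proving three things:
- $A(X)$ is constructed correctly, using KZG commitments.
- $B(X)$ is constructed correctly, using KZG commitments.
- The sums of the values encoded by $A(X)$ and $B(X)$ are equal, using sumcheck.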
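
The sumcheck step relies on a simple fact about multiplicative subgroups (the Aurora lemma cited in the code further down): if $H$ has $N$ elements and $\deg P < N$, then $\sum_{\omega \in H} P(\omega) = N \cdot P(0)$, because $\sum_{\omega \in H} \omega^k = 0$ for $0 < k < N$. Since $A(X)$ interpolates the $A_i$ on a subgroup of size $N$ and $B(X)$ the $B_j$ on one of size $n$, $\sum A_i = \sum B_j$ becomes $N \cdot A(0) = n \cdot B(0)$. A toy check in $\mathbb{F}_{17}$ (field and polynomial chosen by us):

```python
P = 17
N = 4
OMEGA = pow(3, (P - 1) // N, P)   # 13, a primitive 4th root of unity mod 17
H = [pow(OMEGA, i, P) for i in range(N)]

def evaluate(coeffs, x):
    """Horner evaluation mod P; coeffs are listed from the constant term up."""
    acc = 0
    for c in reversed(coeffs):
        acc = (acc * x + c) % P
    return acc

coeffs = [5, 11, 2, 7]            # an arbitrary polynomial of degree < N
total = sum(evaluate(coeffs, w) for w in H) % P
assert total == N * coeffs[0] % P

# The degree bound matters: an X^N term adds N * 1 to the sum, since w^N = 1 on H.
too_high = coeffs + [1]
assert sum(evaluate(too_high, w) for w in H) % P != N * too_high[0] % P
```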

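FK only speeds up the computation of the quotient polynomial of $A(X)$ and does not change the overall flow. So we first walk through the protocol without FK, then replace the relevant code with the FK version.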

We start by proving that $A(X)$ is well formed, which needs KZG commitments. First we generate the SRS (Structured Reference String).
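
To see what the SRS is for, here is a deliberately insecure scalar shadow of KZG, assuming the toy field $\mathbb{F}_{17}$ and $\tau = 5$ (both our choices). Each group element $[x]_1$ is replaced by the number $x$ itself, so a commitment is just $f(\tau)$ and the pairing check used in the cells below becomes a plain field equation. It only illustrates the algebra; with real curve points $\tau$ is never revealed.

```python
P = 17          # toy prime field
tau = 5         # toy secret; real KZG only publishes it inside curve points
srs = [pow(tau, i, P) for i in range(4)]   # scalar shadow of [1, tau, tau^2, tau^3]

def commit(coeffs):
    """Scalar shadow of commit_g1: sum_i c_i * tau^i, which equals f(tau)."""
    assert len(coeffs) <= len(srs), "degree exceeds the SRS"
    return sum(c * s for c, s in zip(coeffs, srs)) % P

def evaluate(coeffs, x):
    acc = 0
    for c in reversed(coeffs):
        acc = (acc * x + c) % P
    return acc

def open_at(coeffs, z):
    """Return (f(z), q) with f(X) - f(z) = q(X) * (X - z), via synthetic division."""
    q, carry = [0] * (len(coeffs) - 1), 0
    for k in range(len(coeffs) - 1, 0, -1):
        carry = (coeffs[k] + z * carry) % P
        q[k - 1] = carry
    return (coeffs[0] + z * carry) % P, q   # the remainder is f(z)

f = [3, 1, 4, 1]      # f(X) = 3 + X + 4X^2 + X^3
zeta = 9              # opening point (the verifier's challenge)
y, q = open_at(f, zeta)
assert y == evaluate(f, zeta)

# Same relation as the notebook's pairing check
#   e([1]_2, [f]_1 + zeta*[q]_1 - [y]_1) == e([tau]_2, [q]_1),
# written directly on the exponents:
assert (commit(f) + zeta * commit(q) - y) % P == tau * commit(q) % P
```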

In [2]:
import py_ecc.bn128 as b

def generate_srs(powers: int, tau: int):
    print("Start to generate structured reference string")

    # Initialize powers_of_x with 0 values
    powers_of_x = [0] * powers
    # powers_of_x[0] =  b.G1 * tau**0 = b.G1
    # powers_of_x[1] =  b.G1 * tau**1 = powers_of_x[0] * tau
    # powers_of_x[2] =  b.G1 * tau**2 = powers_of_x[1] * tau
    # ...
    # powers_of_x[i] =  b.G1 * tau**i = powers_of_x[i - 1] * tau
    powers_of_x[0] = b.G1

    for i in range(1, powers):
        powers_of_x[i] = b.multiply(powers_of_x[i - 1], tau)

    assert b.is_on_curve(powers_of_x[1], b.b)
    print("Generated G1 side, X^1 point: {}".format(powers_of_x[1]))

    powers_of_x2 = [0] * (powers + 1)
    powers_of_x2[0] = b.G2
    for i in range(1, powers + 1):
        powers_of_x2[i] = b.multiply(powers_of_x2[i - 1], tau)

    assert b.is_on_curve(powers_of_x2[1], b.b2)
    print("Generated G2 side, X^1 point: {}".format(powers_of_x2[1]))

    # bilinearity: e(G2, tau * G1) == e(tau * G2, G1)
    assert b.pairing(b.G2, powers_of_x[1]) == b.pairing(powers_of_x2[1], b.G1)
    print("X^1 points checked consistent")
    print("Finished to generate structured reference string")
    return (powers_of_x, powers_of_x2)


In [3]:
from py_ecc.fields.field_elements import FQ as Field
import py_ecc.bn128 as b

PRIMITIVE_ROOT = 5

class Scalar(Field):
    field_modulus = b.curve_order

    # Gets the first root of unity of a given group order
    @classmethod
    def root_of_unity(cls, group_order: int):
        assert (cls.field_modulus - 1) % group_order == 0
        return Scalar(PRIMITIVE_ROOT) ** ((cls.field_modulus - 1) // group_order)

    # Gets the full list of roots of unity of a given group order
    @classmethod
    def roots_of_unity(cls, group_order: int):
        o = [Scalar(1), cls.root_of_unity(group_order)]
        while len(o) < group_order:
            o.append(o[-1] * o[1])
        return o

def is_power_of_two(n):
    """
    Check if a given number is a power of two.

    :param n: The number to be checked.
    :return: True if n is a power of two, False otherwise.
    """
    if n <= 0:
        return False
    else:
        return (n & (n - 1)) == 0

def fft(values: list[Scalar], inv=False):
    def _fft(vals, modulus, roots_of_unity):
        if len(vals) == 1:
            return vals
        L = _fft(vals[::2], modulus, roots_of_unity[::2])
        R = _fft(vals[1::2], modulus, roots_of_unity[::2])
        o = [0] * len(vals)
        for i, (x, y) in enumerate(zip(L, R)):
            y_times_root = y * roots_of_unity[i]
            o[i] = (x + y_times_root) % modulus
            o[i + len(L)] = (x - y_times_root) % modulus
        return o

    assert is_power_of_two(len(values)), "fft: values length should be powers of 2"
    roots = [x.n for x in Scalar.roots_of_unity(len(values))]
    o, nvals = Scalar.field_modulus, [x.n for x in values]
    if inv:
        # Inverse FFT
        invlen = Scalar(1) / len(values)
        reversed_roots = [roots[0]] + roots[1:][::-1]
        return [Scalar(x) * invlen for x in _fft(nvals, o, reversed_roots)]
    else:
        # Regular FFT
        return [Scalar(x) for x in _fft(nvals, o, roots)]


def ifft(values: list[Scalar]):
    return fft(values, True)


In [4]:
from enum import Enum
from numpy.polynomial import polynomial as P

class Basis(Enum):
    LAGRANGE = 1
    MONOMIAL = 2


class Polynomial:
    values: list[Scalar]
    basis: Basis

    def __init__(self, values: list[Scalar], basis: Basis):
        assert all(isinstance(x, Scalar) for x in values)
        assert isinstance(basis, Basis)
        self.values = values
        self.basis = basis

    def __eq__(self, other):
        return (self.basis == other.basis) and (self.values == other.values)

    def __add__(self, other):
        if isinstance(other, Polynomial):
            assert self.basis == other.basis
            if (self.basis == Basis.LAGRANGE):
                assert len(self.values) == len(other.values)
                return Polynomial(
                    [x + y for x, y in zip(self.values, other.values)],
                    self.basis,
                )

            if (self.basis == Basis.MONOMIAL):
                res = P.polyadd(self.values, other.values)
                return Polynomial(
                    res,
                    self.basis,
                )
        else:
            assert isinstance(other, Scalar)
            if (self.basis == Basis.LAGRANGE):
                return Polynomial(
                    [x + other for x in self.values],
                    self.basis,
                )

            if (self.basis == Basis.MONOMIAL):
                res = P.polyadd(self.values, [other])
                return Polynomial(
                    res,
                    self.basis,
                )


    def __sub__(self, other):
        if isinstance(other, Polynomial):
            assert self.basis == other.basis
            if (self.basis == Basis.LAGRANGE):
                assert len(self.values) == len(other.values)
                return Polynomial(
                    [x - y for x, y in zip(self.values, other.values)],
                    self.basis,
                )

            if (self.basis == Basis.MONOMIAL):
                res = P.polysub(self.values, other.values)
                return Polynomial(
                    res,
                    self.basis,
                )
        else:
            assert isinstance(other, Scalar)
            if (self.basis == Basis.LAGRANGE):
                return Polynomial(
                    [x - other for x in self.values],
                    self.basis,
                )

            if (self.basis == Basis.MONOMIAL):
                res = P.polysub(self.values, [other])
                return Polynomial(
                    res,
                    self.basis,
                )

    def __mul__(self, other):
        if isinstance(other, Polynomial):
            assert self.basis == other.basis
            if (self.basis == Basis.LAGRANGE):
                assert len(self.values) == len(other.values)
                res = [x * y for x, y in zip(self.values, other.values)]
            if (self.basis == Basis.MONOMIAL):
                c1 = self.values
                c2 = other.values
                res = P.polymul(c1,c2)

            return Polynomial(
                res,
                self.basis,
            )
        else:
            assert isinstance(other, Scalar)
            if (self.basis == Basis.LAGRANGE):
                return Polynomial(
                    [x * other for x in self.values],
                    self.basis,
                )

            if (self.basis == Basis.MONOMIAL):
                c1 = self.values
                c2 = [other]
                res = P.polymul(c1,c2)
                return Polynomial(
                    res,
                    self.basis,
                )

    def __truediv__(self, other):
        if isinstance(other, Polynomial):
            assert self.basis == other.basis
            if (self.basis == Basis.LAGRANGE):
                assert len(self.values) == len(other.values)
                return Polynomial(
                    [x / y for x, y in zip(self.values, other.values)],
                    self.basis,
                )
            if (self.basis == Basis.MONOMIAL):
                qx, rx = P.polydiv(self.values, other.values)
                # here we only consider the case where the remainder is 0
                assert all(r == 0 for r in rx), "division leaves a non-zero remainder"

                return Polynomial(
                    qx,
                    self.basis,
                )
        else:
            assert isinstance(other, Scalar)
            if (self.basis == Basis.LAGRANGE):
                return Polynomial(
                    [x / other for x in self.values],
                    self.basis,
                )

            if (self.basis == Basis.MONOMIAL):
                # dividing by a constant scales every coefficient
                # (P.polydiv would return a (quotient, remainder) tuple here)
                return Polynomial(
                    [x / other for x in self.values],
                    self.basis,
                )

    # Evaluate at x directly for polynomial of MONOMIAL
    # This is inefficient, just for study usage
    def coeff_eval(self, x: Scalar):
        assert self.basis == Basis.MONOMIAL
        coeffs = self.values
        result = coeffs[0]
        x_pow = Scalar(1)
        for i in range(1, len(coeffs)):
            x_pow = x_pow * x
            result = result + coeffs[i] * x_pow
        return result


In [5]:
def ec_mul(pt, coeff):
    if hasattr(coeff, "n"):
        coeff = coeff.n
    return b.multiply(pt, coeff % b.curve_order)


def ec_lincomb(pairs):
    o = b.Z1
    for pt, coeff in pairs:
        o = b.add(o, ec_mul(pt, coeff))
    return o

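def commit_g1(coeffs: list):
    # zip() would silently drop coefficients beyond the SRS length, so check it first
    assert len(coeffs) <= len(powers_of_x), "degree too large for the G1 SRS"
    return ec_lincomb([(s, x) for s, x in zip(powers_of_x, coeffs)])

def commit_g2(coeffs: list):
    assert len(coeffs) <= len(powers_of_x2), "degree too large for the G2 SRS"
    return ec_lincomb([(s, x) for s, x in zip(powers_of_x2, coeffs)])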


In [6]:
# public table
table = [1, 2, 3, 4]
# values to lookup
witness = [1, 2, 1, 3]
table_len = len(table)

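# 0. Setup SRS
# toy value for demonstration; in practice tau comes from a trusted setup and must stay secret
tau = 100
powers = 4
# stands in for the verifier's random challenge
beta = Scalar(10)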
powers_of_x, powers_of_x2 = generate_srs(powers, tau)
roots_of_unity = [x.n for x in Scalar.roots_of_unity(len(table))]

Start to generate structured reference string
Generated G1 side, X^1 point: (8464813805670834410435113564993955236359239915934467825032129101731355555480, 15805858227829959406383193382434604346463310251314385567227770510519895659279)
Generated G2 side, X^1 point: ((9069199875169756180356420481015645310388174870092142108235557076420063073096, 19765057837394621075835443081152314037672821821523780638388756299265215370750), (19790693237741786827533886710327176959865707426594137996397540167727981081713, 10168740664245831497642710459131680020400674013216613216484698513127357137715))
X^1 points checked consistent
Finished to generate structured reference string


In [7]:
# Prove A(X) is well formed
from py_ecc.fields.field_elements import FQ as Field
from collections import Counter
import numpy as np

# 1. Construct polynomial A(X)
# Count how many times each table element appears in the witness (the multiplicities m_i)

duplicates = dict(Counter(witness))
m_values = [Scalar(duplicates.get(val, 0)) for val in table]
m_coeffs = ifft(m_values)
m_poly = Polynomial(m_coeffs, Basis.MONOMIAL)

t_values = [Scalar(val) for val in table]
t_coeffs = ifft(t_values)
T_poly = Polynomial(t_coeffs, Basis.MONOMIAL)

A_values = []
for i, t_i in enumerate(table):
    A_i = m_values[i] / (beta + t_i)
    A_values.append(A_i)
    # sanity check: A_i * (beta + t_i) must give back the multiplicity m_i
    assert A_i * (beta + t_i) == m_values[i], "A: not equal"
print("A_values: ", A_values)

# use ifft to get coefficients
A_coeffs = ifft(A_values)
A_poly = Polynomial(A_coeffs, Basis.MONOMIAL)

assert np.array_equal(A_poly.values, A_coeffs)

# vanishing polynomial of the table subgroup: X^N - 1, with N = table_len
ZV_array = [Scalar(-1)] + [Scalar(0)] * (table_len - 1) + [Scalar(1)]
# vanishing polynomial in coefficient form
ZV_poly = Polynomial(ZV_array, Basis.MONOMIAL)

Q_A_poly = (A_poly * (T_poly + beta) - m_poly) / ZV_poly

# 2. commit
A_comm_1 = commit_g1(A_poly.values)
Q_A_comm_1 = commit_g1(Q_A_poly.values)
Z_V_comm_2 = commit_g2(ZV_poly.values)
T_comm_2 = commit_g2(T_poly.values)
m_comm_1 = commit_g1(m_poly.values)
print("Q_A_comm_1: ", Q_A_comm_1)

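### Check 1 (round 2.11 of the paper): A encodes the correct values ###
# A(X)(T(X) + beta) - m(X) = Q_A(X) Z_V(X) becomes
# e([T]_2, [A]_1) == e([Z_V]_2, [Q_A]_1) * e([1]_2, [m]_1 - beta * [A]_1)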
print("=== Started Check 1: round 2.11: A encodes the correct values ===")
comb = ec_lincomb([
    (m_comm_1, 1),
    (A_comm_1, -beta)
])
A_check_lhs1 = b.pairing(T_comm_2, A_comm_1)
A_check_rhs1 = b.pairing(Z_V_comm_2, Q_A_comm_1)
A_check_rhs2 = b.pairing(b.G2, comb)
assert A_check_lhs1 == A_check_rhs1 * A_check_rhs2, "Check 1 failed: A encodes the correct values"
print("=== Finished Check 1: round 2.11: A encodes the correct values ===")


A_values:  [15918722088610381979815567814732563700762446836666206795416875772055133451358, 20064222632519335620392538599819168831169334033714698148390020504361157787649, 20204531881697792512842836072545177004813874831153262471106034633762284765185, 0]
Q_A_comm_1:  (15415052476916874991182160238156721134888647709419998893599013621353989425365, 11501847284461168181652651429020983977194714692018271330508400884636232865298)
=== Started Check 1: round 2.11: A encodes the correct values ===
=== Finished Check 1: round 2.11: A encodes the correct values ===


In [8]:
# Construct B(X), commit to it, and open it at a point zeta with KZG
# (this cell does not check that B_j = 1 / (beta + f_j))

# 1. Construct polynomial B(X)
B_values = []
for i, f_i in enumerate(witness):
    B_i = 1 / (beta + f_i)
    B_values.append(B_i)
    # sanity check: B_i * (beta + f_i) must equal 1
    assert B_i * (beta + f_i) == 1, "B: not equal"

# use ifft to get coefficients
B_coeffs = ifft(B_values)
B_poly = Polynomial(B_coeffs, Basis.MONOMIAL)

assert np.array_equal(B_poly.values, B_coeffs)


# 2. commit
B_comm_1 = commit_g1(B_coeffs)
print("B_comm_1: ", B_comm_1)

# random number from verifier
zeta = 20

eval_at_zeta = B_poly.coeff_eval(zeta)
print("eval_at_zeta: ", eval_at_zeta)

# vanishing polynomial z(X): X - zeta
z_coeffs = [Scalar(-zeta), Scalar(1)]
z_poly = Polynomial(z_coeffs, Basis.MONOMIAL)

Q_B_poly = (B_poly - eval_at_zeta) / z_poly
print("Q_B_poly: ", Q_B_poly.values)

Q_B_comm_1 = commit_g1(Q_B_poly.values)
print("Q_B_comm_1: ", Q_B_comm_1)

# 3. verify
# f(X) - f(zeta) = q(X)(X - zeta)
# f(X) + zeta * q(X) - f(zeta) = q(X) * X
# b.pairing(b.G2, [f(tau)]_1 + zeta * [q(tau)]_1 - [f(zeta)]_1) == b.pairing(b.G2 * tau, [q(tau)]_1)

# lin_com = [f(tau)]_1 + zeta * [q(tau)]_1 - [f(zeta)]_1
lin_com1 = b.add(
    B_comm_1,
    b.multiply(Q_B_comm_1, zeta)
)
lin_com2 = b.multiply(b.G1, Scalar(-eval_at_zeta).n)
lin_com = b.add(lin_com1, lin_com2)
assert powers_of_x2[1] == b.multiply(b.G2, tau)
assert b.pairing(b.G2, lin_com) == b.pairing(b.multiply(b.G2, tau), Q_B_comm_1)
print("Verify success!")


B_comm_1:  (20682201959834088444255392993387177061810557395518561585763326266255626951804, 6743382538339576368676379820725382960077518812833977519085207282117576162334)
eval_at_zeta:  16636212567363854668929515734324434679294912095665297729235889404122551281959
Q_B_poly:  [18734473611896232795942783889096194190021534648867716985918856392418332585320
 17905373503114442070504636603460726395227269381486113760343714326671199383778
 11049353372803480280453862001259207386298386974614043065482367342368527655985]
Q_B_comm_1:  (1723670226925097623145265080148568853861410936024775717113081327024210889029, 1726920566509075301699872238531866022673528949081987622557192100613386838924)
Verify success!
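
The cell above only opens $B(X)$ at a point; on its own that does not show the values really are $B_j = 1/(\beta + f_j)$. As we read the cq paper, the full protocol also commits to the witness polynomial $f(X)$ and to a quotient $Q_B(X)$ satisfying $B(X)(f(X)+\beta) - 1 = Q_B(X)\,(X^n - 1)$, which the verifier checks at a random point. The sketch below shows that identity in the toy field $\mathbb{F}_{17}$ (our choice), with our own helper names (`F_poly`, `Q_B`, `div_by_vanishing`).

```python
P, N, beta = 17, 4, 10                   # N: size of the witness subgroup (n in the text)
OMEGA = pow(3, (P - 1) // N, P)          # 13, a primitive 4th root of unity mod 17
OMEGA_INV = pow(OMEGA, P - 2, P)

def interpolate(values):
    """Inverse DFT: coefficients of the degree < N polynomial taking values[i] at OMEGA^i."""
    n_inv = pow(N, P - 2, P)
    return [n_inv * sum(v * pow(OMEGA_INV, i * k, P) for i, v in enumerate(values)) % P
            for k in range(N)]

def evaluate(coeffs, x):
    acc = 0
    for c in reversed(coeffs):
        acc = (acc * x + c) % P
    return acc

def poly_mul(a, b):
    out = [0] * (len(a) + len(b) - 1)
    for i, x in enumerate(a):
        for j, y in enumerate(b):
            out[i + j] = (out[i + j] + x * y) % P
    return out

def div_by_vanishing(a):
    """a(X) / (X^N - 1) by long division; asserts the remainder is zero."""
    a, q = a[:], [0] * max(len(a) - N, 1)
    for k in range(len(a) - 1, N - 1, -1):
        q[k - N] = a[k]
        a[k - N] = (a[k - N] + a[k]) % P
    assert all(c == 0 for c in a[:N]), "not divisible by X^N - 1"
    return q

witness = [1, 2, 1, 3]
B_vals = [pow(beta + f, P - 2, P) for f in witness]     # B_j = 1 / (beta + f_j)
F_poly = interpolate(witness)
B_poly = interpolate(B_vals)

# B(X) * (F(X) + beta) - 1 vanishes on the subgroup, so X^N - 1 divides it.
F_plus_beta = F_poly[:]
F_plus_beta[0] = (F_plus_beta[0] + beta) % P
numer = poly_mul(B_poly, F_plus_beta)
numer[0] = (numer[0] - 1) % P
Q_B = div_by_vanishing(numer)

# The verifier would check the identity at a random point zeta.
zeta = 7
lhs = evaluate(B_poly, zeta) * evaluate(F_plus_beta, zeta) - 1
rhs = evaluate(Q_B, zeta) * (pow(zeta, N, P) - 1)
assert (lhs - rhs) % P == 0
```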


In [9]:
# prove sumcheck, based on The Aurora lemma at section 2.1 of paper
# which is N * A(0) == n * B(0)

# 1. prove A(0) is correct with KZG
A_eval_at_0 = A_poly.coeff_eval(0)
print("A_eval_at_0: ", A_eval_at_0)

print("A_comm_1: ", A_comm_1)

# open at the fixed point 0 (not a random challenge)
zeta = 0

eval_at_zeta = A_eval_at_0
print("eval_at_zeta: ", eval_at_zeta)

# vanishing polynomial z(X): X - 0
z_coeffs = [Scalar(0), Scalar(1)]
z_poly = Polynomial(z_coeffs, Basis.MONOMIAL)

Q_A_poly = (A_poly - eval_at_zeta) / z_poly
print("Q_A_poly: ", Q_A_poly.values)

Q_A_comm_1_zeta = commit_g1(Q_A_poly.values)
print("Q_A_comm_1_zeta: ", Q_A_comm_1_zeta)

# verify the opening
# f(X) - f(zeta) = q(X)(X - zeta)
# f(X) + zeta * q(X) - f(zeta) = q(X) * X
# b.pairing(b.G2, [f(tau)]_1 + zeta * [q(tau)]_1 - [f(zeta)]_1) == b.pairing(b.G2 * tau, [q(tau)]_1)
# lin_com = [f(tau)]_1 + zeta * [q(tau)]_1 - [f(zeta)]_1
lin_com1 = b.add(
    A_comm_1,
    b.multiply(Q_A_comm_1_zeta, zeta)
)
lin_com2 = b.multiply(b.G1, Scalar(-eval_at_zeta).n)
lin_com = b.add(lin_com1, lin_com2)
assert powers_of_x2[1] == b.multiply(b.G2, tau)
assert b.pairing(b.G2, lin_com) == b.pairing(b.multiply(b.G2, tau), Q_A_comm_1_zeta)
print("Verify success!")

A_eval_at_0:  14046869150706877528262735621774227384186413925383541853728232727544644001048
A_comm_1:  (8917082233065637445142715498718931793991726733924264150036551308445675561640, 10545707768690016138765928802784210849980062836550017679315660257500999395798)
eval_at_zeta:  14046869150706877528262735621774227384186413925383541853728232727544644001048
Q_A_poly:  [3032593090198011471006215699663757499809572034203908049769813093169059989153
 14958879270366847329189669194493280512875929108734209951382324568651969355032
 5768623449177920873603353044058573392438896168760581284234709569265268601742]
Q_A_comm_1_zeta:  (1848346930093139968079432045884787055423870130569256258702400190025345965924, 657865574877848602903431434232863694932556009008917090014403511793306619344)
Verify success!


In [10]:
# 2. prove B(0) is correct with KZG

print("B_comm_1: ", B_comm_1)

B_eval_at_0 = B_poly.coeff_eval(0)
zeta = 0
eval_at_zeta = B_eval_at_0
print("B_eval_at_0: ", B_eval_at_0)

# vanishing polynomial z(X): X - 0
z_coeffs = [Scalar(0), Scalar(1)]
z_poly = Polynomial(z_coeffs, Basis.MONOMIAL)

Q_B_poly = (B_poly - eval_at_zeta) / z_poly
print("Q_B_poly: ", Q_B_poly.values)

Q_B_comm_1 = commit_g1(Q_B_poly.values)
print("Q_B_comm_1: ", Q_B_comm_1)

# verify the opening
# f(X) - f(zeta) = q(X)(X - zeta)
# f(X) + zeta * q(X) - f(zeta) = q(X) * X
# b.pairing(b.G2, [f(tau)]_1 + zeta * [q(tau)]_1 - [f(zeta)]_1) == b.pairing(b.G2 * tau, [q(tau)]_1)
# lin_com = [f(tau)]_1 + zeta * [q(tau)]_1 - [f(zeta)]_1
lin_com1 = b.add(
    B_comm_1,
    b.multiply(Q_B_comm_1, zeta)
)
lin_com2 = b.multiply(b.G1, Scalar(-eval_at_zeta).n)
lin_com = b.add(lin_com1, lin_com2)
assert powers_of_x2[1] == b.multiply(b.G2, tau)
assert b.pairing(b.G2, lin_com) == b.pairing(b.multiply(b.G2, tau), Q_B_comm_1)
print("Verify success!")

B_comm_1:  (20682201959834088444255392993387177061810557395518561585763326266255626951804, 6743382538339576368676379820725382960077518812833977519085207282117576162334)
B_eval_at_0:  14046869150706877528262735621774227384186413925383541853728232727544644001048
Q_B_poly:  [10838889499035794941792543743998067702249977425801991278215836844207280839632
 15800734765437588683891454030849329554743173893365595887678409345058731220248
 11049353372803480280453862001259207386298386974614043065482367342368527655985]
Q_B_comm_1:  (948946274175271009940314681516022457447883910578151972872817587825769518999, 16781378638098203819716007662698839353166237777016630729884885378422106862655)
Verify success!


In [11]:
# 3. Verify N * A(0) == n * B(0)
assert A_eval_at_0 * len(table) == B_eval_at_0 * len(witness)
print("Verify success!")

Verify success!


Next, we use FK to compute the commitment to $Q_A(X)$.
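
Why does a single FK output give the quotient commitments at every root at once? Expanding $\frac{T(X) - T(z)}{X - z} = \sum_j c_j \frac{X^j - z^j}{X - z}$ and collecting powers of $z$ gives $\sum_m h_m(X)\, z^m$ with $h_m(X) = \sum_{j > m} c_j X^{j-1-m}$. So the quotient at $z = \omega^k$, evaluated at $\tau$, is the $k$-th DFT coefficient of the vector $[h_m(\tau)]$; this appears to be what the `K_T_Comm` loop in the FK cell computes (on commitments) before scaling by $\omega^i/N$. The sketch checks the identity with scalars in the toy field $\mathbb{F}_{17}$ and $\tau = 5$ (our choices) instead of curve points.

```python
P, N = 17, 4
OMEGA = pow(3, (P - 1) // N, P)          # 13, a primitive 4th root of unity mod 17
OMEGA_INV = pow(OMEGA, P - 2, P)
tau = 5                                  # toy stand-in for the SRS secret (not a root of unity)

def interpolate(values):
    """Inverse DFT: coefficients of the degree < N polynomial taking values[i] at OMEGA^i."""
    n_inv = pow(N, P - 2, P)
    return [n_inv * sum(v * pow(OMEGA_INV, i * k, P) for i, v in enumerate(values)) % P
            for k in range(N)]

def evaluate(coeffs, x):
    acc = 0
    for c in reversed(coeffs):
        acc = (acc * x + c) % P
    return acc

c = interpolate([1, 2, 3, 4])            # coefficients of T(X) for the example table

# h_m = sum_{j=m+1}^{N-1} c_j * tau^(j-1-m): a Toeplitz matrix applied to [1, tau, tau^2, ...]
h = [sum(c[j] * pow(tau, j - 1 - m, P) for j in range(m + 1, N)) % P for m in range(N)]

quotients_direct, quotients_via_h = [], []
for k in range(N):
    w = pow(OMEGA, k, P)
    # Direct: (T(tau) - T(w)) / (tau - w), one quotient per root
    quotients_direct.append((evaluate(c, tau) - evaluate(c, w)) * pow(tau - w, P - 2, P) % P)
    # Via h: the k-th DFT coefficient, sum_m h_m * w^m
    quotients_via_h.append(sum(h_m * pow(w, m, P) for m, h_m in enumerate(h)) % P)

assert quotients_direct == quotients_via_h
```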

In [12]:
import numpy as np
import py_ecc.bn128 as b

# https://eprint.iacr.org/2023/033
def fk(coeffs, powers_of_x):
    print("\n ***************** Start fk() ****************")
    assert len(coeffs) == len(powers_of_x), "length should be equal"
    n = len(coeffs)
    assert is_power_of_two(n), "length should be power of 2"
    # Get first column of circulant matrix in length of 2 * len(coeffs)
    # For example: coeffs is [1, 2, 3, 4]
    # The first column of circulant matrix should be: [4, 0, 0, 0, 0, 0, 2, 3]
    first_col = coeffs.copy()
    # first coefficient is unused, so set it to Scalar(0)
    first_col[0] = Scalar(0)

    # get first column of circulant matrix in 2n size
    # 1. padding 0
    first_col = np.pad(first_col, (n, 0), 'constant', constant_values=(Scalar(0),))
    # 2. roll by 1 to right
    first_col = np.roll(first_col, 1)

    # reversed SRS: drop the highest power, reverse the rest, then append the point at infinity
    inv_powers_of_x = powers_of_x[:-1][::-1]
    inv_powers_of_x.append(b.Z1)
    # padding n 0s to the end
    ec_neutral_vals = [b.Z1] * n
    padded_x = inv_powers_of_x + ec_neutral_vals

    # We have circulant matrix C, C = F_inv * diag(F * first_col) * F
    # F: DFT matrix, F_inv: inverse DFT matrix
    # We want to get Q_T_comm_poly_coeffs = C * x = F_inv * diag(F * first_col) * F * x
    # 1. right hand side: F * x
    rhs = ec_fft(padded_x)

    # 2. middle hand side: F * first_col
    mhs = fft(first_col)

    # middle * right (element wise) to get diagonal: diag(F * first_col) * F * x
    m_r_hs = [b.multiply(rhs[i], mhs[i].n) for i in range(len(rhs))]

    # 3. ifft
    result = ec_ifft(m_r_hs)

    # 4. return the first n values
    Q_comm_poly_coeffs = result[:n]
    print("\n ***************** End fk() ****************")
    return Q_comm_poly_coeffs


def ec_fft(values: list, inv=False):
    def _fft(vals: list, modulus, roots_of_unity):
        if len(vals) == 1:
            return vals
        L = _fft(vals[::2], modulus, roots_of_unity[::2])
        R = _fft(vals[1::2], modulus, roots_of_unity[::2])
        o = [0] * len(vals)
        for i, (x, y) in enumerate(zip(L, R)):
            y_times_root = b.multiply(y, roots_of_unity[i])
            o[i] = b.add(x, y_times_root)
            o[i + len(L)] = b.add(x, b.neg(y_times_root))
        return o

    assert is_power_of_two(
        len(values)), "ec_fft: values length should be powers of 2"
    roots = [x.n for x in Scalar.roots_of_unity(len(values))]
    o, nvals = Scalar.field_modulus, values
    if inv:
        # Inverse FFT
        invlen = (Scalar(1) / len(values)).n
        reversed_roots = [roots[0]] + roots[1:][::-1]
        return [b.multiply(x, invlen) for x in _fft(nvals, o, reversed_roots)]
    else:
        # Regular FFT
        return _fft(nvals, o, roots)


def ec_ifft(values: list):
    return ec_fft(values, True)
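
The core of `fk()` is the Toeplitz product $h = M \cdot [\tau^{n-2}, \ldots, \tau, 1, 0]$, computed by embedding $M$ in a $2n \times 2n$ circulant matrix, which a DFT diagonalizes ($C x = \mathrm{IDFT}(\mathrm{DFT}(\text{col}) \odot \mathrm{DFT}(x))$). The sketch below replays `fk()`'s layout (the same first column `[4, 0, 0, 0, 0, 0, 2, 3]` for coeffs `[1, 2, 3, 4]`) with scalars in the toy field $\mathbb{F}_{17}$ instead of G1 points and a naive DFT instead of `fft`/`ec_fft`, then compares the result with the Toeplitz product computed directly. Field, $\tau$, and helper names are our assumptions.

```python
P = 17
TWO_N = 8
OMEGA = pow(3, (P - 1) // TWO_N, P)    # 9, a primitive 8th root of unity mod 17
OMEGA_INV = pow(OMEGA, P - 2, P)
tau = 5                                # toy stand-in for the SRS secret

def dft(values, root):
    """Naive O(n^2) DFT over F_P; an FFT computes the same thing in O(n log n)."""
    n = len(values)
    return [sum(v * pow(root, i * k, P) for i, v in enumerate(values)) % P for k in range(n)]

def idft(values):
    n_inv = pow(len(values), P - 2, P)
    return [x * n_inv % P for x in dft(values, OMEGA_INV)]

coeffs = [1, 2, 3, 4]                  # same example as the comment in fk()
n = len(coeffs)
srs = [pow(tau, i, P) for i in range(n)]      # scalar shadow of powers_of_x

# Same layout as fk(): zero c_0, pad n zeros in front, roll right by 1
first_col = [0] + coeffs[1:]
first_col = [0] * n + first_col
first_col = first_col[-1:] + first_col[:-1]   # [4, 0, 0, 0, 0, 0, 2, 3]
# Vector [tau^(n-2), ..., tau, 1, 0] followed by n zeros
x = srs[:-1][::-1] + [0] + [0] * n

# Circulant product via the convolution theorem, keeping the first n entries
prod = idft([a * b % P for a, b in zip(dft(first_col, OMEGA), dft(x, OMEGA))])
h_fk = prod[:n]

# Direct Toeplitz product: h_m = sum_{j=m+1}^{n-1} c_j * tau^(j-1-m)
h_direct = [sum(coeffs[j] * pow(tau, j - 1 - m, P) for j in range(m + 1, n)) % P
            for m in range(n)]
assert h_fk == h_direct
```

In `fk()` the vector entries are G1 points, so the transform of `padded_x` needs `ec_fft`, while the first column stays in the scalar field and uses `fft`; the element-wise product then becomes a scalar multiplication of points.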


In [13]:
def precompute_with_fk(table, powers_of_x):
    t_values = [Scalar(val) for val in table]
    t_poly_coeffs = ifft(t_values)
    # compute h values with fk
    return fk(t_poly_coeffs, powers_of_x)

table_len = len(table)

Q_T_comm_poly_coeffs = precompute_with_fk(table, powers_of_x[:table_len])

roots = Scalar.roots_of_unity(len(A_values))

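# K_T_Comm below is the i-th DFT coefficient of the FK output, sum_j h_j * w^(i*j);
# this double loop is O(N^2), whereas ec_fft(Q_T_comm_poly_coeffs) should give all of them in O(N log N)
Q_A_comm_1_FK = b.Z1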
for i in range(table_len):
    K_T_Comm = b.Z1
    root = Scalar(1)
    for j in range(table_len):
        K_T_Comm = b.add(K_T_Comm, b.multiply(
            Q_T_comm_poly_coeffs[j], root.n))
        root = root * roots[i]
    A_val = A_values[i].n
    scale = roots[i]/table_len
    # Compute Quotient polynomial commitment of T(X)
    Q_T_Comm = b.multiply(K_T_Comm, scale.n)
    A_times_Q_T_Comm = b.multiply(Q_T_Comm, A_val)
    # Do the accumulation
    Q_A_comm_1_FK = b.add(Q_A_comm_1_FK, A_times_Q_T_Comm)

print("\n Commitment of Q_A(X) with FK:  \n", Q_A_comm_1_FK)
print("\n Q_A_comm_1:  \n", Q_A_comm_1)

assert Q_A_comm_1_FK == Q_A_comm_1

print("verify success!")


 ***************** Start fk() ****************

 ***************** End fk() ****************

 Commitment of Q_A(X) with FK:  
 (15415052476916874991182160238156721134888647709419998893599013621353989425365, 11501847284461168181652651429020983977194714692018271330508400884636232865298)

 Q_A_comm_1:  
 (15415052476916874991182160238156721134888647709419998893599013621353989425365, 11501847284461168181652651429020983977194714692018271330508400884636232865298)
verify success!
