# Sum-check Protocol

Adapted from [thor314](https://github.com/thor314/pazk/blob/main/4-sumcheck/)

## Utils

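We define a few small utilities. The `deg_j` function needs some explanation:

- The degree is capped at 10 mainly to keep values within the range of a fixed-width integer. That cap matters in languages with fixed-width integer types such as C/C++ and Rust. Python integers have arbitrary precision, so there the cap can be lifted.
- The core idea is to make the term with the highest power of the target variable dominate the output, so that its degree can be read off. The trick only works while that term dominates at the probe values 100 and 1000: large coefficients or many other terms can swamp it. For example, with a huge number of variables (say $10^{1000}$), all set to 1, their combined contribution outweighs the target variable and the result is wrong. However, for any given polynomial, **two sufficiently large probe values can always be found for which the algorithm is correct**.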
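To make the probing trick concrete, here is a small standalone illustration for a single variable $x$. The helper `probe` and the two sample polynomials are hypothetical and not part of the notebook. The first polynomial has small lower-order terms, so the floor-scaled outputs first agree at the true degree 2. The second has a large constant term, so they disagree at exponent 1 and the trick would wrongly report degree 2.

```python
def probe(f, exp):
    """Return (f(100) // 100**exp, f(1000) // 1000**exp), the two values deg_j compares."""
    return f(100) // 100**exp, f(1000) // 1000**exp


def leading_dominates(x):
    return 3 * x**2 + x + 1  # degree 2, small lower-order terms


def big_constant(x):
    return x + 200  # degree 1, but the constant 200 swamps x at x = 100


# Degree 2: the pair differs at exp = 1 and agrees at exp = 2, as intended.
print([probe(leading_dominates, e) for e in (1, 2)])  # [(301, 3001), (3, 3)]
# Degree 1: the pair already differs at exp = 1, so exp = 2 would be reported.
print([probe(big_constant, e) for e in (1, 2)])  # [(3, 1), (0, 0)]
```

Choosing larger probe points (for example $10^4$ and $10^5$) would make `x + 200` behave correctly again, which is the point of the last bullet above.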

In [20]:
from inspect import signature
from typing import Callable


def arity(g: Callable) -> int:
    """Return the number of arguments taken by g"""
    return len(signature(g).parameters)


def to_bits(n: int, pad_to_len: int) -> list[int]:
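    """Return the binary vector representation of n, front-padded to pad_to_len"""
    # lstrip removes every leading '0'/'b' character, so n = 0 yields [] (to_bits(0, 0) == [] in the last round)
    temp = [1 if i == '1' else 0 for i in bin(n).lstrip('0b')]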
    
    diff = pad_to_len - len(temp) 
    if diff > 0:
        return [0 for _ in range(diff)] + temp
    return temp

# Trick: if the highest power of variable x in g is a, evaluate g with x=100 and with x=1000
# (all other variables set to 1) to get out1 and out2; then |out1/100^a - out2/1000^a| < 1.
def deg_j(g: Callable, j: int) -> int:
    """Return the degree of j-th variable in g (0 <= j < arity(g)). Assume a non-negative integer power less than 10"""
    arity_g = arity(g)
    assert arity_g > j

    # the probe points do not depend on exp, so g only needs to be evaluated twice
    ones_before = [1] * j
    ones_after = [1] * (arity_g - 1 - j)
    out1 = g(*ones_before, 100, *ones_after)
    out2 = g(*ones_before, 1000, *ones_after)

    # same output at both probes: treat g as not depending on the j-th variable
    if out1 == out2:
        return 0

    exp = 1
    while True:
        # at exp == degree the leading term dominates and the floor-scaled outputs agree
        if abs(out1 // 100**exp - out2 // 1000**exp) < 1:
            return exp
        elif exp > 10:
            raise ValueError("exp grew larger than 10")
        else:
            exp += 1


def test_degj():
    def f(a, b, c): return a*b*b*c+b+c*c*c
    def ff(a, b, c, d): return a*b+b+d**3
    assert deg_j(f, 0) == 1
    assert deg_j(f, 1) == 2
    assert deg_j(f, 2) == 3
    assert deg_j(ff, 0) == 1
    assert deg_j(ff, 1) == 1
    assert deg_j(ff, 2) == 0
    assert deg_j(ff, 3) == 3
    def fff(a, b, c, d): return a*b*c+b+c+c*d
    assert deg_j(fff, 0) == 1
    assert deg_j(fff, 1) == 1
    assert deg_j(fff, 2) == 1
    assert deg_j(fff, 3) == 1

    # With the degree-10 cap removed, the algorithm still returns the correct degrees
    # def g(a, b ,c): return a**20 + a * b ** 10 + a * b * c**5
    # assert deg_j(g, 0) == 20
    # assert deg_j(g, 1) == 10
    # assert deg_j(g, 2) == 5
    

test_degj()

## Protocol

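Sum-check is a two-party interactive protocol in which a Prover convinces a Verifier of the sum of a multivariate polynomial $g(\cdot)$ over all Boolean assignments to its $n$ variables (the code calls this value `H`), i.e.: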

$$
H=\sum_{(x_1, x_2, \dots, x_n)\in \{0, 1\}^{n}}g(x_1, x_2, \dots, x_n)
$$
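As a sanity check, the sum can be computed by brute force over all $2^n$ Boolean points. A minimal sketch (the helper name `boolean_hypercube_sum` is hypothetical); for `g(a, b, c) = a + b + a*b + c`, the first polynomial in `test_sumcheck` below, it gives 14, the same `H` as in the notebook's output.

```python
from itertools import product


def boolean_hypercube_sum(g, n):
    """Brute-force H: sum g over all 2**n points of {0,1}^n."""
    return sum(g(*x) for x in product((0, 1), repeat=n))


def g(a, b, c):
    return a + b + a*b + c


print(boolean_hypercube_sum(g, 3))  # 14
```

This costs $2^n$ evaluations of $g$, which is exactly the work the Verifier avoids by running sum-check instead.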

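The Prover is responsible for:

- At the start, computing the claimed sum $H$ of $g$ over all Boolean inputs and sending it to the Verifier. The protocol then begins.
- The protocol runs for `arity(g)` rounds. In round $i$ the Prover computes, **in order**, a univariate polynomial in the $i$-th variable only, $g_i(X_i)=\sum_{(x_{i+1},\dots,x_n)\in\{0,1\}^{n-i}}g(v_1,\dots,v_{i-1},X_i,x_{i+1},\dots,x_n)$, where $v_1,\dots,v_{i-1}$ are the Verifier's earlier random challenges, and sends it to the Verifier.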
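The definition of $g_i$ translates directly into code. A minimal sketch, independent of the notebook's classes (`round_poly` is a hypothetical name), assuming $g$ is a plain Python function of $n$ integer arguments as elsewhere in this notebook. With the challenge $v_1 = 1$ it reproduces the `4 10` and `3 7` lines of the first run in the output below.

```python
from itertools import product


def round_poly(g, n, challenges):
    """Return g_i as a function of X_i, where i = len(challenges) + 1.

    challenges holds v_1, ..., v_{i-1}; the remaining n - i variables are
    summed over {0, 1}.
    """
    rest = n - len(challenges) - 1

    def g_i(x):
        return sum(g(*challenges, x, *tail) for tail in product((0, 1), repeat=rest))

    return g_i


def g(a, b, c):
    return a + b + a*b + c


g1 = round_poly(g, 3, [])
print(g1(0), g1(1))  # 4 10   (4 + 10 == H == 14)
g2 = round_poly(g, 3, [1])  # the Verifier picked v_1 = 1
print(g2(0), g2(1))  # 3 7    (3 + 7 == g1(1) == 10)
```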

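The Verifier is responsible for:

- In each round, checking that $g_i$ has at most the expected degree and that $g_i(0)+g_i(1)$ equals the claimed sum (round 1) or $g_{i-1}(v_{i-1})$ (later rounds), then picking a random value $v_i$ and sending it to the Prover.
- In the last round, picking $v_n$ without sending it, evaluating $g(v_1, \dots, v_n)$ itself, and checking that $g_n(v_n) = g(v_1, \dots, v_n)$.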
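The Prover and Verifier duties above fit in one loop. The sketch below is not the notebook's implementation and makes two assumptions of its own. First, it works modulo an assumed prime $P = 2^{61}-1$ and draws each $v_i$ uniformly from that field, as the textbook protocol does (the notebook's `Verifier` uses `randrange(2)` over plain integers). Second, the Prover sends each $g_i$ as its values at $X_i = 0, 1, \dots, d$ for a known per-variable degree bound $d \ge 1$, so the Verifier evaluates $g_i(v_i)$ by Lagrange interpolation instead of calling a Python function. `sumcheck`, `lagrange_eval` and `claimed_H` are hypothetical names.

```python
import random
from itertools import product

P = 2**61 - 1  # assumed prime modulus (a Mersenne prime) for the toy field


def lagrange_eval(ys, r):
    """Evaluate at r (mod P) the polynomial of degree < len(ys) taking value ys[k] at X = k."""
    total = 0
    for k in range(len(ys)):
        num, den = 1, 1
        for m in range(len(ys)):
            if m != k:
                num = num * (r - m) % P
                den = den * (k - m) % P
        total = (total + ys[k] * num * pow(den, P - 2, P)) % P
    return total


def sumcheck(g, n, d, claimed_H=None):
    """Run an honest Prover against the Verifier; d >= 1 bounds the degree of g in each variable.

    Returns True if the Verifier accepts. claimed_H simulates a Prover that announces
    a wrong sum (the round polynomials themselves stay honest).
    """
    H = sum(g(*x) for x in product((0, 1), repeat=n)) % P
    claim = H if claimed_H is None else claimed_H % P
    challenges = []
    for i in range(n):
        rest = n - i - 1
        # Prover: describe g_i by its values at X = 0, 1, ..., d
        ys = [sum(g(*challenges, x, *t) for t in product((0, 1), repeat=rest)) % P
              for x in range(d + 1)]
        # Verifier: g_i(0) + g_i(1) must match the running claim
        if (ys[0] + ys[1]) % P != claim:
            return False
        # Verifier: random challenge from the whole field; the new claim is g_i(v_i)
        v_i = random.randrange(P)
        claim = lagrange_eval(ys, v_i)
        challenges.append(v_i)
    # final check: a single evaluation of g at (v_1, ..., v_n)
    return g(*challenges) % P == claim
```

For example, `sumcheck(lambda a, b, c: a + b + a*b + c, 3, 1)` returns `True`, while passing `claimed_H=15` makes the Verifier reject in round 1. Sending only $d+1$ values also enforces the degree bound by construction, which plays the role of the `deg_j` check in the notebook's `Verifier`.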

In [93]:
from random import randrange
class Verifier:
    """The Verifier side of the sum-check protocol"""
    def __init__(self, g, arity_g, H) -> None:
        self.g = g
        self.arity_g = arity_g
        self.H = H
        self.random_challenges = []
        self.round = 1
        self.polynomials = []


    def receive_polynomial(self, polynomial: Callable):
        """Receive the latest polynomial from the Prover"""
        self.polynomials.append(polynomial)

    def check_latest_polynomial(self):
        """Validate that self.j_th_polynomial is a univariate poly of at most deg_j(g)
        and that g_{j-1}(r_{j-1}) = g_j(0) + g_j(1)"""    
        poly = self.polynomials[-1]
        
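        # check that the polynomial is univariate
        arity_j = arity(poly)
        assert arity_j == 1

        # an honest g_j can have lower degree (its leading coefficient may cancel in the sum),
        # so only a degree above deg_j(g) is rejected
        deg_latest = deg_j(poly, 0)
        deg_of_j = deg_j(self.g, self.round - 1)
        if deg_latest > deg_of_j:
            raise ValueError("Prover sent a polynomial of degree {} greater than expected: {}".format(deg_latest, deg_of_j))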
        
        print(poly(0), poly(1))
        new_sum = poly(0) + poly(1)
        if self.round == 1:
            check = self.H
        else:
            print(self.random_challenges[-1])
            check = self.polynomials[-2](self.random_challenges[-1])
        print(self.random_challenges)
        if check != new_sum:
            raise ValueError("Prover sent incorrect polynomial: {}, expected: {}".format(new_sum, check))


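    def get_and_send_ramdom_challenge(self, p):
        """Get a new random value, append it to random_challenges, send it to the Prover, update the round"""
        # simplification: challenges are drawn from {0, 1}; the textbook protocol samples them from a large field
        self.random_challenges.append(randrange(2))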
        p.receive_challenge(self.random_challenges[-1])
        self.round += 1

    # last round
    def evaluate_and_check_g_v(self):
        """Draw v_n (not sent to the Prover), query g once and check g_n(v_n) == g(v_1, ..., v_n)"""
        assert len(self.random_challenges) == self.arity_g-1
        self.random_challenges.append(randrange(2))
        print(self.random_challenges)
        g_final = self.g(*self.random_challenges)
        check = self.polynomials[-1](self.random_challenges[-1])
        print(check, g_final)

        if g_final != check:
            raise ValueError(
                "Prover sent incorrect final polynomial: {}, expected: {}".format(check, g_final))
        else:
            print("VERIFIER ACCEPTS")
            return True


class Prover:
    """This prover uses a function-currying cache to improve its runtime"""

    def __init__(self, g, g_arity) -> None:
        """Initialize the prover and compute the claimed sum H"""
        # not implemented: the protocol below uses InefficientProver
        pass


class InefficientProver:
    def __init__(self, g, arity_g) -> None:
        self.g = g
        self.arity_g = arity_g
        self.random_challenges = []
        self.polynomials = []
        self.round = 1
        
        # all 2^n Boolean inputs, from all zeros to all ones, stored in argsv
        argsv = [to_bits(i, self.arity_g) for i in range(2 ** self.arity_g)]
        # the claimed sum H that the Prover sends to the Verifier
        self.H = sum([self.g(*args) for args in argsv])

    def compute_and_send_polynomial(self, v: Verifier):
        """compute the next polynomial, append to self.polynomials, send it, update the round"""
        # make this constant, or we'll have a moving round called from within g_j
        round = self.round

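        # g_j(X_j): fix the earlier challenges, set X_j, and sum over Boolean values of the remaining variables
        def g_j(X_j: int) -> int:
            args_init = self.random_challenges[:round-1] + [X_j]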
            pad_len = self.arity_g - len(args_init)
            argsv = [args_init + to_bits(i, pad_len) for i in range(2**pad_len)]
            return sum([self.g(*args) for args in argsv])
        self.polynomials.append(g_j)
        v.receive_polynomial(g_j)
        self.round += 1

    def receive_challenge(self, challenge: int):
        self.random_challenges.append(challenge)
        print("Received chellenge {}, initiating round {}".format(challenge, self.round))



class Sumcheck:
    def __init__(self) -> None:
        pass

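    def process(self, g: Callable):
        arity_g = arity(g)
        p = InefficientProver(g, arity_g)
        v = Verifier(g, arity_g, p.H)

        # rounds 1 .. n-1: polynomial, consistency check, random challenge
        for _ in range(arity_g - 1):
            p.compute_and_send_polynomial(v)
            v.check_latest_polynomial()
            v.get_and_send_ramdom_challenge(p)
        # round n: the last polynomial g_n must also be sent and checked
        # before the Verifier compares g_n(v_n) with g(v_1, ..., v_n)
        p.compute_and_send_polynomial(v)
        v.check_latest_polynomial()
        return v.evaluate_and_check_g_v()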

# if __name__ == '__main__':
#     def g(a, b ,c): return a ** 5 + b ** 4 - c

#     protocol = Sumcheck()
#     ans = protocol.process(g)
#     if ans == True:
#         print("SUCCESS!")
#     else:
#         print("FAILED!")
class SumcheckProtocol:
    """The sumcheck protocol, as defined in Proofs, Arguments, and Zero-Knowledge ch 4.1,
    defined over polynomials of arbitrary arity over {0,1}"""

    def __init__(self, g: Callable) -> None:
        g_arity = arity(g)
        if g_arity < 1:
            raise ValueError(
                "function arity must be greater than or equal to 1")

        self.g_arity = g_arity
        # for simplicity, have prover compute H on initialization
        self.p = InefficientProver(g, self.g_arity)
        self.v = Verifier(g, self.g_arity, self.p.H)
        self.round = 1
        self.done = False

    def __repr__(self) -> str:
        return f'Protocol(round: "{self.round}", H: "{self.p.H}", challenges: "{self.p.random_challenges}")'

    def advance_round(self):
        if not self.done:
            # Prover: compute next polynomial and send it to verifier
            self.p.compute_and_send_polynomial(self.v)
            self.v.check_latest_polynomial()
            if self.round == self.g_arity:
                # final round
                self.done = self.v.evaluate_and_check_g_v()
            else:
                self.v.get_and_send_ramdom_challenge(self.p)
                self.round += 1
        else:
            raise RuntimeError("Sumcheck protocol has finished")

    def advance_to_end(self, verbose: bool = False):
        while not self.done:
            if verbose:
                print("ADVANCE OUTPUT:", self)
            self.advance_round()

def test_sumcheck():
    def g(a, b, c): return a + b + a*b + c
    protocol = SumcheckProtocol(g)
    protocol.advance_to_end(True)

    def f(a, b, c): return a*b*c+b+c
    protocol = SumcheckProtocol(f)
    protocol.advance_to_end(True)

    def ff(a, b, c, d): return a*b*c+b+c+c*d
    protocol = SumcheckProtocol(ff)
    protocol.advance_to_end(True)

    def gg(a, b ,c): return a ** 5 + b ** 4 - c
    protocol = SumcheckProtocol(gg)
    protocol.advance_to_end(True)
        
test_sumcheck()



ADVANCE OUTPUT: Protocol(round: "1", H: "14", challenges: "[]")
4 10
[]
Received chellenge 1, initiating round 2
ADVANCE OUTPUT: Protocol(round: "2", H: "14", challenges: "[1]")
3 7
1
[1]
Received chellenge 0, initiating round 3
ADVANCE OUTPUT: Protocol(round: "3", H: "14", challenges: "[1, 0]")
1 2
0
[1, 0]
[1, 0, 0]
1 1
VERIFIER ACCEPTS
ADVANCE OUTPUT: Protocol(round: "1", H: "9", challenges: "[]")
4 5
[]
Received chellenge 1, initiating round 2
ADVANCE OUTPUT: Protocol(round: "2", H: "9", challenges: "[1]")
1 4
1
[1]
Received chellenge 1, initiating round 3
ADVANCE OUTPUT: Protocol(round: "3", H: "9", challenges: "[1, 1]")
1 3
1
[1, 1]
[1, 1, 0]
1 1
VERIFIER ACCEPTS
ADVANCE OUTPUT: Protocol(round: "1", H: "22", challenges: "[]")
10 12
[]
Received chellenge 0, initiating round 2
ADVANCE OUTPUT: Protocol(round: "2", H: "22", challenges: "[0]")
3 7
0
[0]
Received chellenge 1, initiating round 3
ADVANCE OUTPUT: Protocol(round: "3", H: "22", challenges: "[0, 1]")
2 5
1
[0, 1]
Received ch