In [4]:
import numpy as np
import parameters as p

In [19]:
class Prover:
    """Prover side of the sum-check protocol over the prime field ``p.prime``.

    A table holds the values of a multilinear polynomial on the boolean
    hypercube, indexed by the integer whose binary digits are the variable
    assignment.  Round ``i`` pairs entry ``b`` with entry
    ``b + 2**(p.bitwise-i-1)``, so the most significant index bit is bound
    first.  In every round the prover publishes the round polynomial at
    x = 0, 1, 2 and then folds the table at the verifier's challenge ``r[i]``.

    All arithmetic runs on int64 copies reduced modulo ``p.prime``, so the
    results are exact for any modulus below 2**31 regardless of the dtype of
    the caller's arrays (products of int32 tables would otherwise wrap).
    """

    def __init__(self):
        pass

    @staticmethod
    def _split(table, half, modulus):
        """Return reduced int64 copies of the two active halves of ``table``."""
        if table.shape[0] < 2 * half:
            # Slicing would silently truncate, so fail the way indexing would.
            raise IndexError(
                f"table has {table.shape[0]} entries, need at least {2 * half}")
        low = table[:half].astype(np.int64) % modulus
        high = table[half:2 * half].astype(np.int64) % modulus
        return low, high

    @staticmethod
    def _line_at(low, high, x, modulus):
        """Evaluate, entry-wise, the line through (0, low) and (1, high) at ``x``."""
        return (low * (1 - x) + high * x) % modulus

    def sumcheck(self, A_F, r):
        """Run the prover for the claim ``sum(A_F) mod p.prime``.

        Parameters
        ----------
        A_F : numpy.ndarray
            Table of at least ``2**p.bitwise`` integers.  It is folded IN
            PLACE: after round ``i`` the first ``2**(p.bitwise-i-1)`` entries
            hold the table with the bound variables fixed to the challenges,
            and on return ``A_F[0]`` is the evaluation at ``r``.  Pass a copy
            if the original values are still needed.
        r : sequence of int
            One challenge per round; ``r[i]`` is used in round ``i``.

        Returns
        -------
        (claimed_sum, g_t)
            ``claimed_sum`` is the sum of the whole input table mod
            ``p.prime``.  ``g_t`` is an int64 array of shape
            ``(p.bitwise, 3)`` whose row ``i`` holds the round-``i``
            polynomial at x = 0, 1, 2, each reduced mod ``p.prime``.

        Raises
        ------
        IndexError
            If ``A_F`` has fewer than ``2**p.bitwise`` entries.
        """
        modulus = int(p.prime)
        rounds = int(p.bitwise)
        claimed_sum = A_F.sum(dtype=np.int64) % modulus
        g_t = np.zeros((rounds, 3), dtype=np.int64)
        for i in range(rounds):
            half = 1 << (rounds - i - 1)
            low, high = self._split(A_F, half, modulus)
            for t in range(3):
                g_t[i, t] = self._line_at(low, high, t, modulus).sum() % modulus
            challenge = int(r[i]) % modulus
            # Bind the current variable to the challenge; only the first
            # `half` entries stay meaningful for the next round.
            A_F[:half] = self._line_at(low, high, challenge, modulus)
        return claimed_sum, g_t

    def sumcheck_ntt(self, A_c, A_F, r):
        """Run the prover for the claim ``sum(A_c * A_F) mod p.prime``.

        Sum-check for the NTT-style product s = c*F.  The round polynomial is
        the product of two lines, hence of degree two, which is why three
        evaluation points are published per round.

        Parameters
        ----------
        A_c, A_F : numpy.ndarray
            Tables of at least ``2**p.bitwise`` integers.  Both are folded IN
            PLACE exactly as in ``sumcheck``.
        r : sequence of int
            One challenge per round.

        Returns
        -------
        (claimed_sum, g_t, A_s)
            ``claimed_sum`` and ``g_t`` are as in ``sumcheck`` but for the
            product.  ``A_s`` is an int64 array of length ``p.N``; after each
            round its first ``2**(p.bitwise-i-1)`` entries are overwritten
            with the entry-wise product of the folded tables, so on return
            ``A_s[0]`` is ``F(r)*c(r) mod p.prime`` and the later entries are
            leftovers from earlier rounds.

        Raises
        ------
        IndexError
            If either table has fewer than ``2**p.bitwise`` entries.
        """
        modulus = int(p.prime)
        rounds = int(p.bitwise)
        # Reduce and widen before multiplying so the products cannot wrap.
        products = ((A_c.astype(np.int64) % modulus)
                    * (A_F.astype(np.int64) % modulus)) % modulus
        claimed_sum = products.sum() % modulus
        g_t = np.zeros((rounds, 3), dtype=np.int64)
        A_s = np.zeros(p.N, dtype=np.int64)
        for i in range(rounds):
            half = 1 << (rounds - i - 1)
            F_low, F_high = self._split(A_F, half, modulus)
            c_low, c_high = self._split(A_c, half, modulus)
            for t in range(3):
                F_t = self._line_at(F_low, F_high, t, modulus)
                c_t = self._line_at(c_low, c_high, t, modulus)
                g_t[i, t] = (F_t * c_t % modulus).sum() % modulus
            challenge = int(r[i]) % modulus
            F_next = self._line_at(F_low, F_high, challenge, modulus)
            c_next = self._line_at(c_low, c_high, challenge, modulus)
            A_F[:half] = F_next
            A_c[:half] = c_next
            A_s[:half] = F_next * c_next % modulus
        return claimed_sum, g_t, A_s

In [15]:
class Verifier:
    """Verifier side of the sum-check protocol over the prime field ``p.prime``.

    The interpolation implemented here assumes each round polynomial is a
    line (degree at most one), so two of the prover's evaluation points
    determine it.  All field arithmetic uses Python integers and is therefore
    exact for any modulus; no floating point is involved.
    """

    def __init__(self):
        pass

    def solve_linear(self, a, r):
        """Interpolate s(x) = k + b*x through s(1), s(2) and evaluate s(r).

        The system
            1 1   k   s(1)
            1 2   b   s(2)
        has the closed-form solution b = s(2) - s(1), k = 2*s(1) - s(2),
        which is computed modulo ``p.prime``.

        Parameters
        ----------
        a : sequence of int
            Prover's evaluations; ``a[1]`` is s(1) and ``a[2]`` is s(2).
            ``a[0]`` is not used.
        r : int
            Point at which the line is evaluated (a scalar challenge).

        Returns
        -------
        (coefficients, value)
            ``coefficients`` is a length-2 integer array ``[k, b]`` and
            ``value`` is s(r) as a Python int, all reduced mod ``p.prime``.
        """
        modulus = int(p.prime)
        s_one = int(a[1]) % modulus
        s_two = int(a[2]) % modulus
        slope = (s_two - s_one) % modulus
        intercept = (2 * s_one - s_two) % modulus
        value = (intercept + slope * (int(r) % modulus)) % modulus
        return np.array([intercept, slope]), value

    def sum_verify(self, g_t, r):
        """Evaluate every round polynomial at that round's challenge.

        The returned values are what the caller compares against the next
        round: the claimed sum H against s_1(0) + s_1(1) for the first round,
        then s_i(r_i) against s_(i+1)(0) + s_(i+1)(1) for the following ones.

        Parameters
        ----------
        g_t : indexable of rows
            Row ``i`` holds the prover's evaluations of round ``i`` at
            x = 0, 1, 2; only the first ``p.bitwise`` rows are read.
        r : sequence of int
            One challenge per round.

        Returns
        -------
        numpy.ndarray
            int64 array of length ``p.bitwise`` with ``s[i] = s_i(r[i])``
            mod ``p.prime``.
        """
        s = np.zeros(p.bitwise, dtype=np.int64)
        for i in range(p.bitwise):
            _, s[i] = self.solve_linear(g_t[i], r[i])
        return s

    def multi_ext(self, g, r):
        """Evaluate the multilinear extension of ``g`` at the point ``r``.

        This is the final check the verifier performs on its own: the result
        must equal the last round polynomial evaluated at the last challenge.

        Each of the first ``p.N`` entries ``g[i]`` is weighted by the product
        over the ``p.bitwise`` binary digits of ``i`` (most significant
        first) of ``r[j]`` where the digit is 1 and ``1 - r[j]`` where it is
        0; the weighted entries are summed modulo ``p.prime``.  ``g`` itself
        is not modified.

        Parameters
        ----------
        g : sequence of int
            Table of values on the boolean hypercube, at least ``p.N`` long.
        r : sequence of int
            Evaluation point, one coordinate per index bit.

        Returns
        -------
        int
            The extension's value at ``r``, reduced mod ``p.prime``.
        """
        modulus = int(p.prime)
        width = int(p.bitwise)
        point = [int(r[j]) % modulus for j in range(width)]
        total = 0
        for i in range(p.N):
            term = int(g[i]) % modulus
            for j, coordinate in enumerate(point):
                # Digit j of i, counted from the most significant side.
                bit = (i >> (width - 1 - j)) & 1
                term = term * (coordinate if bit else 1 - coordinate) % modulus
            total += term
        return total % modulus
    

In [17]:
def main():
    """Run one sum-check exchange on the table 0..p.N-1 and report the outcome.

    Draws one random challenge per round, lets the prover produce its claimed
    sum and round evaluations, lets the verifier evaluate each round at its
    challenge, and prints the first consistency check that fails (nothing
    further is printed when every check passes).
    """
    challenges = np.random.randint(30, size=(p.bitwise))
    table = np.array(list(range(p.N)))

    # The prover folds its table in place, so hand it a private copy.
    prover = Prover()
    claimed, g_t = prover.sumcheck(table.copy(), challenges)
    print(f"prover sends initial sumation: \n {claimed}")
    print(f"prover sends g_t: \n {g_t}")

    verifier = Verifier()
    s = verifier.sum_verify(g_t, challenges)
    last = verifier.multi_ext(table, challenges)
    print(s)
    print(last)

    def first_failure():
        """Return the message of the first failing check, or None."""
        # Round 1: the claimed sum must split into s_1(0) + s_1(1).
        if claimed != (g_t[0][0] + g_t[0][1]) % p.prime:
            return 'first round error !!!'
        # Middle rounds: s_i(r_i) must split into s_(i+1)(0) + s_(i+1)(1).
        for i in range(p.bitwise - 1):
            if s[i] != (g_t[i+1][0] + g_t[i+1][1]) % p.prime:
                return f'{i+1} round error !!!'
        # Last round: the verifier's own evaluation must match s_l(r_l).
        if last != s[p.bitwise - 1]:
            return 'last round error !!!'
        return None

    message = first_failure()
    if message is not None:
        print(message)


# Entry point: run the sum-check demo when this code executes as the main module.
if __name__ == "__main__":
    main()

prover sends initial sumation: 
 342
prover sends g_t: 
 [[438 673 139]
 [286 537  19]
 [101 356 611]
 [ 38 294 550]
 [242 306 370]
 [145 161 177]
 [647 651 655]
 [335 336 337]]
[ 54 457 332 548 306 529 671 351]
351
