In [1]:
import time
import math
import sympy as sp
import numpy as np
from numpy.polynomial import polynomial as p
import matplotlib.pyplot as plt
from sage.all import *

In [2]:
from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
from sage.stats.distributions.discrete_gaussian_polynomial import DiscreteGaussianDistributionPolynomialSampler 

In [3]:
#Rejection sampling algorithm 
#If the parameters are chosen as in BBC+18, the output is 1 (accept) with probability 1/rho

def Rej(Z, B, sigma, rho):
    """Decide whether a rejection-sampling candidate is kept.

    One uniform number u is drawn from [0, 1) and compared with
    (1/rho) * exp((-2*<Z, B> + ||B||^2) / (2*sigma^2)), where Z and B are
    treated as flattened vectors.

    Returns 1 when u does not exceed that threshold, 0 otherwise.
    """
    uniform_draw = np.random.random()
    z_vec = Z.flatten()
    b_vec = B.flatten()
    exponent = (-2*z_vec.dot(b_vec) + np.linalg.norm(b_vec)**2)/(2*sigma**2)
    threshold = 1/rho * np.exp(exponent)
    return 0 if uniform_draw > threshold else 1

In [4]:
#Balance algorithm
#Given a matrix x with coefficients in [0, q), returns a matrix of the same shape whose coefficients are the centered representatives in [-(q-1)/2, (q-1)/2] (q odd)
def balance(x, q):
    """Map residues modulo q to their centered representatives.

    Parameters
    ----------
    x : array_like
        Values to re-center; any shape.
    q : int
        Modulus.

    Returns
    -------
    numpy.ndarray
        Array with the shape of x in which every entry n > (q-1)/2 has been
        replaced by n - q; all other entries are unchanged.
    """
    x = np.asarray(x)
    # Vectorised selection instead of a Python-level loop over every entry;
    # this also keeps the dtype of x when x is empty.
    return np.where(x <= (q-1)/2, x, x - q)

In [5]:
class Gen:
    """Holds the three matrices A, S and T handed to the two parties.

    The caller in this notebook builds them so that T = A*S mod q.
    """

    # Class-level placeholders; each instance overwrites them in __init__.
    A = S = T = None

    def __init__(self, A, S, T):
        """Store the three matrices on the instance unchanged."""
        self.A, self.S, self.T = A, S, T
    

In [6]:
class Prover:
    """Prover side of the protocol: keeps A, S, T and the masking matrix Y."""

    # Class-level placeholders; each instance overwrites them in __init__.
    A = S = T = sigma = Y = None

    def __init__(self, A, S, T, sigma):
        """Keep the matrices and the Gaussian parameter; Y starts out empty."""
        self.A, self.S, self.T = A, S, T
        self.Y = []
        self.sigma = sigma

    def calculateW(self, n):
        """Sample a new mask Y with n columns and return W = A*Y.

        Y gets as many rows as A has columns. Its entries are drawn one at a
        time from DiscreteGaussianDistributionIntegerSampler(self.sigma), and
        the matrix is stored in self.Y for the later call to calculateZ.
        """
        rows = self.A.shape[1]
        sampler = DiscreteGaussianDistributionIntegerSampler(self.sigma)
        draws = [sampler() for _ in range(rows*n)]
        self.Y = np.array(draws).reshape(rows, n)
        # NOTE(review): the product is not reduced modulo q here; with
        # fixed-width integer arrays it may wrap silently -- confirm the
        # value ranges used by the caller.
        return np.matmul(self.A, self.Y)

    def calculateZ(self, C):
        """Return the response Z = S*C + Y for the challenge matrix C."""
        challenge_part = np.matmul(self.S, C)
        return challenge_part + self.Y

In [13]:
class Verifier:
    """Verifier side of the protocol: issues challenges and checks responses."""

    # Class-level placeholders; each instance overwrites them in __init__.
    A = T = W = C = None

    def __init__(self, A, T):
        """Keep A and T; the commitment W and the challenge C start out empty."""
        self.A, self.T = A, T
        self.W, self.C = [], []

    def calculateC(self, W):
        """Record the commitment W and draw a random 0/1 challenge matrix.

        The challenge has one row per column of T and one column per column
        of W. It is stored in self.C and also returned.
        """
        self.W = W
        cols = W.shape[1]
        rows = self.T.shape[1]
        self.C = np.random.randint(2, size=[rows, cols])
        return self.C

    def verify(self, Z, B, q):
        """Check that A*Z equals T*C + W modulo q and that Z is small.

        Returns a falsy value when the modular equation fails; otherwise the
        result of testing that the per-column norm of Z is at most B.
        """
        az = np.matmul(self.A, Z)
        tc = np.matmul(self.T, self.C)
        lhs = az % q
        rhs = (tc + self.W) % q
        if not np.array_equal(lhs, rhs):
            return False
        # ord=np.inf with axis=0 yields the largest absolute entry of each
        # column. NOTE(review): the bound passed in by the caller looks like a
        # Euclidean-norm bound -- confirm which norm is intended.
        column_norms = np.linalg.norm(Z, np.inf, axis=0)
        return np.all(column_norms <= B)

In [14]:
class ZKP_Protocol:
    """Runs one complete execution of the protocol on randomly generated data.

    Constructing the object does all the work: it draws A and S, derives
    T = A*S mod q and the parameters sigma and B, then repeats the
    commit / challenge / response exchange until rejection sampling accepts,
    verifies the accepted response and prints the outcome, the number of
    aborts and the elapsed time.
    """
    
    lamb = None
    
    def __init__(self, lamb):
        """Run the protocol once with security parameter ``lamb``."""
        
        self.lamb = lamb #Security parameter lambda
        q = sp.nextprime(2**31) #Prime for base field Z_q
        l = 10 #Number of equations
        
        
        ############# Invent matrices A, S, T ###################
        
        r = 1000 #poly(lambda)
        v = 1000 #poly(lambda)
        n = self.lamb + 2
        
        A = np.random.randint(0, q, size = [r, v])
        S = np.random.randint(0, 2**12, size = [v, l])
        #S = np.vstack((np.ones((l,), dtype = int), np.zeros((v-1, l), dtype = int)))
        T = np.matmul(A, S)%q

        #########################################################
        
        # Spectral norm of S. It is a real-valued size bound, not a residue,
        # so it is not reduced modulo q.
        s = np.linalg.norm(S, 2)
        rho = 3
        sigma = 12/np.log(rho)*s*np.sqrt(l*n)+2**-100
        B = math.sqrt(2*v)*sigma

        print("s = " + str(s))
        print("sigma = " + str(sigma))
        print("B = " + str(B))
        
        
        gen = Gen(A, S, T)
        prover = Prover(gen.A, gen.S, gen.T, sigma)
        verifier = Verifier(gen.A, gen.T)
        
        print(A)
        print(S)
        print(T)
        
        abort = True
        num_aborts = 0
        
        #Protocol starts
        
        start_time = time.time()
        
        while abort:
            
            # A new mask Y (and commitment W) is drawn on every attempt. After
            # an abort the old Y must not be paired with a new challenge,
            # otherwise the accepted Z no longer has the distribution that
            # rejection sampling is meant to produce.
            W = prover.calculateW(n)
        
            C = verifier.calculateC(W)
        
            Z = prover.calculateZ(C)
        
            #Rejection sampling: Rej returns 1 to accept, 0 to abort
        
            abort = not Rej(Z, np.matmul(S,C), sigma, rho)
            num_aborts += abort
            
        #Verification of the accepted response
            
        bit = verifier.verify(Z, B, q)
        
        end_time = time.time()
        
        print(bit)
        print("Times aborted: " + str(num_aborts))
        print("Total execution time: " + str(end_time - start_time) + " seconds")
        
# Run one protocol execution with the security parameter set to 128.
lamb = 128
zkp = ZKP_Protocol(lamb)

s = 209783.1763932437
sigma = 82618846.37746485
B = 3694827134.452473
[[ 590472370  405436150  581090750 ... 2124535677  490778377 1965814493]
 [ 996173083  686992701  191922186 ...  512213660 1422861685 1577948324]
 [ 277450348 1804373565  612587937 ... 1753962074  124938276 1856953445]
 ...
 [1105759460 1689009910  239959177 ... 1631689074  976340712 1963737039]
 [1463037424 1342851492  761065116 ...  208727993  285435440  904554265]
 [1356788574 1958822043  591373292 ...  903655417  272081985 1600892207]]
[[2422 1511  197 ... 2881 4067 3495]
 [1590 2463 3689 ... 3001 1089 1953]
 [ 574 3728 2882 ... 3982  177 1078]
 ...
 [2013 2411   92 ...  742 3647  400]
 [3710  939 2088 ...  235 2448  640]
 [3726 2513 3481 ... 2749 1172 3110]]
[[1513619872 1633896248 1036709624 ... 1046583796 1434289752 1490054068]
 [1154806012  494143351  407345585 ...  464998146 2024601915 1203965586]
 [ 687275192  173876041  859296676 ...   82925368  513122574 2133749186]
 ...
 [ 913635188 1559420380  777422513

In [45]:
def rejSamplingTest(times=10000, rho=3, r=10, n=10):
    """Empirically estimate how often Rej accepts and print the rate.

    Each trial draws a random r-by-n integer matrix B, derives sigma from
    its norm, samples a discrete Gaussian matrix Y of the same shape and
    calls Rej on (Y + B, B). Rej returns 1 when the candidate is accepted,
    so the counter below measures acceptances; the rate is printed next to
    the expected value of 1/rho.

    Parameters
    ----------
    times : int
        Number of independent trials.
    rho : number
        Rejection-sampling parameter passed to Rej.
    r, n : int
        Shape of the random matrices.
    """
    accepted = 0
    print("Should accept " + str(100/rho) + "% of the time")
    for _ in range(times):
        B = np.random.randint(0, 4099, size = [r,n])
        # NOTE(review): norm(B, 2) is the spectral norm of the matrix, while
        # Rej works with the norm of the flattened B -- confirm which one the
        # sigma bound is meant to use.
        sigma = 12/math.log(rho) * np.linalg.norm(B,2)
        D = DiscreteGaussianDistributionIntegerSampler(sigma)
        Y = np.array([D() for _ in range(r*n)]).reshape(r,n)
        accepted += Rej(Y+B, B, sigma, rho)
    print("Accepted " + str(100*accepted/times) + "% of the time")
    
# Run the empirical check of the rejection-sampling routine.
rejSamplingTest()

Should abort 33.333333333333336% of the time
Aborted 33.46% of the time
