In [1]:
import time
import math
import sympy as sp
import numpy as np
from numpy.polynomial import polynomial as p
import matplotlib.pyplot as plt
from sage.all import *

In [2]:
from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
from sage.stats.distributions.discrete_gaussian_polynomial import DiscreteGaussianDistributionPolynomialSampler 

In [3]:
#Rejection sampling algorithm 
#If parameters chosen as BBC+18, output is 1 with probability 1/rho

def Rej(Z, B, sigma, rho):
    u = np.random.random()
    if u > 1/rho * np.exp((-2*Z.flatten().dot(B.flatten()) + np.linalg.norm(B.flatten())**2)/(2*sigma**2)):
        return 0
    else:
        return 1

In [4]:
#Balance algorithm
#Given a matrix x with coeffs in [0, q], returns the same matrix with coeffs in [-(q-1)/2, (q-1)/2]
def balance(x, q):
    return np.array([n if n <= (q-1)/2 else n-q for n in x.flatten()]).reshape(np.shape(x))

In [5]:
class Gen:
    
    A = None
    S = None
    T = None
    
    def __init__(self, A, S, T):
        self.A = A
        self.S = S
        self.T = T
    

In [6]:
class Prover:
    
    A = None
    S = None
    T = None
    sigma = None
    Y = None
     
    def __init__(self, A, S, T, sigma):
        self.A = A
        self.S = S
        self.T = T
        
        self.Y = []
        self.sigma = sigma
    
    def calculateW(self, n):
        v = self.A.shape[1]
        D = DiscreteGaussianDistributionIntegerSampler(self.sigma)
        self.Y = np.array([D() for _ in range(v*n)]).reshape(v,n)
        W = np.matmul(self.A, self.Y)
        return W
    
    def calculateZ(self, C):
        Z = np.matmul(self.S, C) + self.Y
        return Z

In [7]:
class Verifier:
    
    A = None
    T = None
    W = None
    C = None
    
    def __init__(self, A, T):
        self.A = A
        self.T = T
        self.W = []
        self.C = []
        
    def calculateC(self, W):
        self.W = W
        n = W.shape[1]
        l = self.T.shape[1]
        self.C = np.random.randint(2, size=[l,n])
        return self.C
    
    def verify(self, Z, B, q):
        AZ = np.matmul(self.A, Z)
        TC = np.matmul(self.T, self.C)
        return np.array_equal(AZ%q, (TC + self.W)%q) and np.all(np.linalg.norm(Z, np.inf, axis=0) <= B)
        

In [8]:
class ZKP_Protocol:
    
    lamb = None
    
    def __init__(self, lamb):
        
        self.lamb = lamb #Security parameter lambda
        q = sp.nextprime(2**31) #Prime for base field Z_q
        l = 3 #Number of equations
        
        
        ############# Invent matrices A, S, T ###################
        
        r = 1000 #poly(lambda)
        v = 100 #poly(lambda)
        n = self.lamb + 2
        
        A = np.random.randint(0, q, size = [r, v])
        S = np.random.randint(0, 2**12, size = [v, l])
        T = np.matmul(A, S)%q

        #########################################################
        
        s = np.linalg.norm(S, 2)
        rho = 3
        sigma = 12/np.log(rho)*s*np.sqrt(l*n)
        B = math.sqrt(2*v)*sigma

        print("s = " + str(s))
        print("sigma = " + str(sigma))
        print("B = " + str(B))
                
        gen = Gen(A, S, T)
        prover = Prover(gen.A, gen.S, gen.T, sigma)
        verifier = Verifier(gen.A, gen.T)
        
        print(A)
        print(S)
        print(T)
        
        abort = True
        num_aborts = 0
        
        #Protocol starts
        
        start_time = time.time()
        
        W = prover.calculateW(n)
        
        while abort:
        
            C = verifier.calculateC(W)
        
            Z = prover.calculateZ(C)
        
            #Rejection sampling
        
            abort = not Rej(Z, np.matmul(S,C), sigma, rho)
            num_aborts += abort
            
            #Verification    
            
        bit = verifier.verify(Z, B, q)
        
        end_time = time.time()
        
        print(bit)
        print("Times aborted: " + str(num_aborts))
        print("Total execution time: " + str(end_time - start_time) + " seconds")
        
lamb = 256
zkp = ZKP_Protocol(lamb)

s = 37714.456300279475
sigma = 11460805.046193901
B = 162080259.32041422
[[1244423605  188067493  468544086 ...  551336420 1289619515 1064383101]
 [ 157622746 1566355544 1826044681 ...  273144123 1816623634 1910932690]
 [1827303795  779512107  195213894 ... 1356450909  368597433  552090858]
 ...
 [ 486929374  125420028 2015393616 ... 1785406657 1799609002  192518602]
 [ 108813690  483456708  568249490 ... 1251171451  573100899 2033308718]
 [ 807276156 1878105673 1812059719 ...  562933669 1388112307 1564082033]]
[[1112 1761 2034]
 [ 577 2951 3059]
 [1890 2798 1376]
 [3471 4081  782]
 [ 317 1719 3466]
 [3291 3213 2416]
 [  57 2016 2384]
 [3346 3434 1598]
 [1517 2860 3108]
 [1805 4038 1390]
 [2810 3720 3405]
 [3777  557 3651]
 [1967 3985 3446]
 [1964  161 3019]
 [3983 2126 3936]
 [3727 3072  334]
 [3958 2233 1995]
 [2953 2965  892]
 [ 141  706 3671]
 [2590 2452 1857]
 [3640 3303 1315]
 [2089 2017 3665]
 [4002  694 3267]
 [1214 2813  868]
 [  84 1087  584]
 [ 505 1501  610]
 [ 744 3764 407

In [11]:
def rejSamplingTest():
    res = 0
    r = 10
    n = 10
    rho = 3
    times = 20000
    print("Should abort " + str(100/rho) + "% of the time")
    for _ in range (times):
        B = np.random.randint(0, 2**32, size = [r,n])
        sigma = 12/math.log(rho) * np.linalg.norm(B,2)
        D = DiscreteGaussianDistributionIntegerSampler(sigma)
        Y = np.array([D() for _ in range(r*n)]).reshape(r,n)
        res += Rej(Y+B, B, sigma, rho)
    print("Aborted " + str(100*res/times) + "% of the time")
    
rejSamplingTest()

Should abort 33.333333333333336% of the time


  


Aborted 33.455% of the time
