In [3]:
import time
import numpy as np
import math
import itertools
from Crypto.Util import number
import Crypto.Random as Random
from Crypto.Hash import SHA256

# KRR

In [4]:
def buildParams(epsilon, width):
    d = len(CATEGORIES)
    l, n = decideRatio(epsilon, d, width)
    assert (n-l) % d == 0, "Invalied combination, n, l, d"
    print("n: ", n, "l: ", l, "d:", d)
    D = max([l, (n-l)//d]) + 1
    return d, l, n, D

def decideRatio(eps, d, width):
    ratio = np.exp(eps) / ((d-1) + np.exp(eps))
    print('original p=', ratio)
    integer = int(ratio * width)
    while integer > 0:
        if (width-integer) % d == 0:
            g = math.gcd(integer, width, (width-integer) // d)
            print('approximate p=', integer/width)
            return integer // g, width // g
        integer -= 1
    assert False, "Not found"    

class Prover:
    def __init__(self, data, d, l, n, D):
        if data in CATEGORIES:
            self.data = data
        else:
            assert False, "out of categories"
        self.d, self.l, self.n, self.D = d, l, n, D
        
    def setup(self):
        print('receiver setup')
        t = self.data
        print("secret input: ", t)
        mu_array = []
        mu_array = [t]*self.l
        for category in CATEGORIES:
            mu_array = mu_array + ([category] * ((self.n - self.l) // self.d))
        mu_array = np.random.permutation(mu_array).tolist()
        self.mu_array = mu_array
#         print("mu array: ", mu_array)
        
    def setPubKey(self, pub_key):
        self.pub_key = pub_key
    
    def step2(self, g_a, g_b, g_ab):
        print('step2')
        q, g, h = self.pub_key
        w_array = []
        y_array = []
        v_array = []
        
        for i in range(0, self.n):
            r = number.getRandomRange(1, q-1)
            s = number.getRandomRange(1, q-1)
            w = pow(g, r, q) * pow(g_a, s, q) % q
            v = pow(g_b, r, q) * pow(g_ab * pow(g, i, q) % q, s, q) % q
            y = pow(g, self.D**self.mu_array[i], q) * pow(h, v, q) % q
            w_array.append(w)
            y_array.append(y)
            v_array.append(v)

        self.w_array = w_array
        self.y_array = y_array
        self.v_array = v_array

        return w_array, y_array
    
    def AKEncBool1(self):
        print('AKEncBool1')
        q, g, h = self.pub_key
        b_array = []
        c_array = []
        s_array = []
        random_w_array = []
        for y, mu in zip(self.y_array, self.mu_array):
            random_w = number.getRandomRange(1, q-1)
            c = [0] * self.d
            s = [0] * self.d
            b = [0] * self.d
            for category in CATEGORIES:
                if category != mu:
                    c_i = number.getRandomRange(1, q-1)
                    s_i = number.getRandomRange(1, q-1)
                    c[category] = c_i
                    s[category] = s_i
                    g_i_inv = pow(pow(g, self.D**category, q), -1, q)
                    deno = pow(y * g_i_inv % q, c_i, q)
                    b_i = pow(h, s_i, q) * pow(deno, -1, q) % q
                    b[category] = b_i
                else:
                    b_i = pow(h, random_w, q)
                    b[category] = b_i
            random_w_array.append(random_w)
            b_array.append(b)
            c_array.append(c)
            s_array.append(s)

        self.b_array = b_array
        self.s_array = s_array
        self.c_array = c_array
        self.random_w_array = random_w_array
        
        return b_array
    
    def AKEncBool3(self, ak_enc_bool_x_array):
        print('AKEncBool3')
        q, g, h = self.pub_key
        for x,c,s,mu,v,random_w in zip(ak_enc_bool_x_array, self.c_array, self.s_array, self.mu_array, self.v_array, self.random_w_array):
            for category in CATEGORIES:
                if category == mu:
                    c[mu] = x - sum(c)
                    s[mu] = v * c[mu] + random_w

        return self.c_array, self.s_array
    
    def AKLin1(self):
        print('AKLin1')
        q, g, h = self.pub_key
        b_lin = [0] * self.d
        c_lin = [0] * self.d
        s_lin = [0] * self.d
        random_w = number.getRandomRange(1, q-1)
        common = sum([self.D**category for category in CATEGORIES]) * ((self.n - self.l) // self.d)
        for category in CATEGORIES:
            if category != self.data:
                total = common + self.l * self.D**category
                c_i = number.getRandomRange(1, q-1)
                s_i = number.getRandomRange(1, q-1)
                c_lin[category] = c_i
                s_lin[category] = s_i
                g_i_inv = pow(pow(g, total, q), -1, q)
                commitment = 1
                for y in self.y_array:
                    commitment = commitment * y % q
                deno = pow(commitment * g_i_inv % q, c_i, q)
                b_i = pow(h, s_i, q) * pow(deno, -1, q) % q
                b_lin[category] = b_i
            else:
                b_i = pow(h, random_w, q)
                b_lin[category] = b_i

        self.b_lin = b_lin
        self.s_lin = s_lin
        self.c_lin = c_lin
        self.random_w_lin = random_w
        return b_lin

    def AKLin3(self, ak_enc_bool_x_lin):
        print('AKLin3')
        q, g, h = self.pub_key
        for category in CATEGORIES:
            if category == self.data:
                self.c_lin[self.data] = ak_enc_bool_x_lin - sum(self.c_lin)
                v_sum = 0
                for v in self.v_array:
                    v_sum += v
                self.s_lin[self.data] = v_sum * self.c_lin[self.data] + self.random_w_lin
        return self.c_lin, self.s_lin
    
class Verifier:
    def __init__(self, d, l, n, D):
        self.d, self.l, self.n, self.D = d, l, n, D
    
    def setup(self, security):
        print('interviewr setup')
        q = number.getPrime(2 * security, Random.new().read)        
        g = number.getRandomRange(1, q-1)
        h = number.getRandomRange(1, q-1)
        
        self.q = q
        self.g = g
        self.h = h
        
        self.sigma = np.random.randint(0, self.n)
        print("sigma: ", self.sigma)
        
        self.pub_key = (q, g, h)
        
    def step1(self):
        print('step1')
        a = number.getRandomRange(1, self.q-1)
        b = number.getRandomRange(1, self.q-1)
        self.a = a
        self.b = b
        
        g_a = pow(self.g, a, self.q)
        g_b = pow(self.g, b, self.q)
        g_ab = pow(self.g, a * b - self.sigma + 1, self.q)

        return g_a, g_b, g_ab
        
    def step3(self, w_array, y_array):
        print('step3')
        v_sigma = pow(w_array[self.sigma], self.b, self.q)
        g_mu_sigma = y_array[self.sigma] * pow(pow(self.h, v_sigma, self.q), -1, self.q) % self.q
        secret_output = None
        for category in CATEGORIES:
            if pow(self.g, self.D**category, self.q) == g_mu_sigma:
                secret_output = category
        print("secret output: ", secret_output, "g^{mu_sigma}: ", g_mu_sigma)
        return secret_output
    
    def AKEncBool2(self, num):
        print('AKEncBool2')
        self.ak_enc_bool_x_array = []
        for _ in range(num):
            self.ak_enc_bool_x_array.append(number.getRandomRange(1, self.q-1))
        return self.ak_enc_bool_x_array

    def AKEncBool4(self, c_array, s_array, b_array, y_array):
        print('AKEncBool4')
        print("####### AKEncBool verification #######")
        for s,c,b,y,x in zip(s_array, c_array, b_array, y_array, self.ak_enc_bool_x_array):
            for i in CATEGORIES:
                if pow(self.h, s[i], self.q) != b[i] * pow(y * pow(pow(self.g, self.D**i, self.q), -1, self.q) % self.q, c[i], self.q) % self.q:
                    print("AKEncBool False1.")
                    return False
            if x != sum(c):
                print("AKEncBool False2.")
                return False
        print("AKEncBool OK.")
        return True 
    
    def AKLin2(self):
        print('AKLin2')
        self.ak_enc_bool_x_lin = 0
        self.ak_enc_bool_x_lin = number.getRandomRange(1, self.q-1)
        return self.ak_enc_bool_x_lin
    
    def AKLin4(self, s_lin, c_lin, b_lin, y_array):
        print('AKLin4')
        print("####### AKLin verification #######")
        common = sum([self.D**category for category in CATEGORIES]) * ((self.n - self.l) // self.d)
        commitment = 1
        for y in y_array:
            commitment = commitment * y % self.q
        for category in CATEGORIES:
            total = common + self.l * self.D**category
            if pow(self.h, s_lin[category], self.q) != b_lin[category] * pow(commitment * pow(pow(self.g, total, self.q), -1, self.q) % self.q, c_lin[category], self.q) % self.q:
                print("AKEncBool False1.")
                return False
        if self.ak_enc_bool_x_lin != sum(c_lin):
            print("AKEncBool False2.")
            return False
        print("AKEncBool OK.")
        return True 

In [5]:
CATEGORIES = list(range(0,5))

epsilon = 1.0
secret_input = 2
width = 1000
d, l, n, D = buildParams(epsilon, width)


receiver = Verifier(d, l, n, D)
receiver.setup(security=80)
pub_key = receiver.pub_key

sender = Prover(secret_input, d, l, n, D)
sender.setup()
sender.setPubKey(pub_key)

g_a, g_b, g_ab = receiver.step1()
w_array, y_array = sender.step2(g_a, g_b, g_ab)
x = receiver.step3(w_array, y_array)

b_array = sender.AKEncBool1()
ak_enc_bool_x_array = receiver.AKEncBool2(len(y_array))
c_array, s_array = sender.AKEncBool3(ak_enc_bool_x_array)
receiver.AKEncBool4(c_array, s_array, b_array, y_array)

b_lin = sender.AKLin1()
ak_enc_bool_x_lin = receiver.AKLin2()
c_lin, s_lin = sender.AKLin3(ak_enc_bool_x_lin)
receiver.AKLin4(s_lin, c_lin, b_lin, y_array)

original p= 0.40460967519168967
approximate p= 0.4
n:  25 l:  10 d: 5
interviewr setup
sigma:  19
receiver setup
secret input:  2
step1
step2
step3
secret output:  None g^{mu_sigma}:  390550628014242342745240443052835529196414001927
AKEncBool1
AKEncBool2
AKEncBool3
AKEncBool4
####### AKEncBool verification #######
AKEncBool OK.
AKLin1
AKLin2
AKLin3
AKLin4
####### AKLin verification #######
AKEncBool OK.


True