In [7]:
import numpy as np
import time

# DPF v2
https://docs.google.com/presentation/d/108V1QMm7ACiD_mr8ITniYfQUwURSafZep6UJaQASZ0k/edit?folder=0AIi9WbrMGj2CUk9PVA#slide=id.g7643f50265_0_0

In [8]:
# Security parameter: every PRG seed is a vector of λ bits.
λ = 63
# Number of bits of the algebraic field, i.e. the bit length of the inputs alpha and x.
n = 32

In [9]:
# PRG
def G(seed):
    """Pseudo-random generator: deterministically expand a λ-bit seed into 2*(λ+1) bits.

    The seed array seeds a legacy NumPy RandomState. Callers read the output as
    [left seed (λ) | left control bit | right seed (λ) | right control bit].

    NOTE(review): RandomState (MT19937) is not a cryptographically secure PRG;
    this is only suitable for a functional prototype.
    """
    assert len(seed) == λ
    generator = np.random.RandomState(seed)
    n_out = 2 * (λ + 1)
    return generator.randint(2, size=n_out)

def Convert(bits):
    """Interpret a bit vector as an unsigned integer, most significant bit first.

    Used to map the final λ-bit seed of the DPF onto an integer output share.

    Args:
        bits: 1-D array or sequence of 0/1 values; bits[0] is the most significant bit.

    Returns:
        A NumPy int64 scalar. Exact for up to 63 bits; longer vectors overflow int64.
    """
    #TODO see figure 3 of paper
    bits = np.asarray(bits)
    # Explicit int64 weights: np.arange defaults to int32 on some platforms
    # (e.g. Windows with NumPy < 2), where shifting by up to λ-1 = 62 overflows.
    weights = np.left_shift(1, np.arange(bits.size, dtype=np.int64)[::-1])
    return bits.dot(weights)

In [10]:
def List(n):
    """Return a new list holding n placeholder entries, all None."""
    return [None for _ in range(n)]

def Array(*shape):
    """Allocate an uninitialised int32 array whose dimensions are the positional arguments.

    Example: Array(n+1, 2, λ) has shape (n+1, 2, λ). Contents are garbage until written.
    """
    dims = tuple(shape)
    return np.empty(dims, dtype=np.int32)

def bit_decomposition(x, nbits=n):
    """Return the binary digits of x, most significant first, as a list of Python ints.

    The string comes from np.binary_repr with width=nbits, so negative x is
    encoded in two's complement over nbits bits.
    """
    return [int(digit) for digit in np.binary_repr(x, width=nbits)]

def randbit(size):
    """Draw uniformly random bits (0 or 1) of the given shape from NumPy's global RNG."""
    return np.random.randint(low=0, high=2, size=size)

def xor(*args):
    """Multi-input xor: bitwise XOR of all arguments, with NumPy broadcasting.

    Args:
        *args: one or more integer arrays or ints.

    Returns:
        The XOR of all operands. A single operand is returned unchanged.

    Raises:
        TypeError: if called with no arguments.
    """
    if not args:
        raise TypeError("xor() requires at least one argument")
    # Iterative fold: XOR is associative and commutative, so this gives the same
    # values as the former right-nested recursion. It also supports a single
    # operand and cannot hit the recursion limit for long argument lists.
    result = args[0]
    for operand in args[1:]:
        result = np.bitwise_xor(result, operand)
    return result
    
def concat(*args, **kwargs):
    """Concatenate all positional arrays; keyword arguments (e.g. axis=1) go to np.concatenate."""
    arrays = tuple(args)
    return np.concatenate(arrays, **kwargs)

def split(l, idx):
    """Split array l into consecutive parts.

    Args:
        l: array to split.
        idx: a list/tuple of part sizes (e.g. [λ, 1, λ, 1]), or anything else
            np.split accepts (number of equal sections or split indices array).

    Returns:
        A list of sub-arrays. With a list/tuple of sizes, any elements beyond
        the sum of all sizes but the last end up in the final part.
    """
    if not isinstance(idx, (list, tuple)):
        return np.split(l, idx)
    # Turn part sizes into cumulative split points.
    boundaries = []
    running = 0
    for size in idx:
        running += size
        boundaries.append(running)
    # The last boundary is the total length; passing it to np.split would
    # produce an extra empty trailing part, so drop it.
    return np.split(l, boundaries[:-1])
    
def TruthTableDPF(s, α_i):
    """Target XOR-difference between both parties' corrected expansions at one DPF level.

    Returns a flattened (2, λ+1) int32 table. Row α_i is (s, 1): on the α path
    the seeds differ by s and the control bits by 1. The other row is all zeros,
    so both parties agree once they leave the path.
    """
    table = np.zeros((2, λ + 1), dtype=np.int32)
    table[α_i, :] = concat(s, [1])
    # Row-major flattening gives [row 0 | row 1], i.e. [left | right] like G's output.
    return table.flatten()

In [14]:
def GenDPF(alpha, beta):
    """Generate the two keys of a distributed point function f_{alpha,beta}.

    The shares satisfy EvalDPF(0, k[0], x) + EvalDPF(1, k[1], x) == beta when
    x == alpha and == 0 otherwise (this is what test_DPF checks).

    Args:
        alpha: the special point, decomposed into n bits (MSB first).
        beta: the value whose additive shares are output at x == alpha.

    Returns:
        [k_0, k_1] with k_b = (initial seed of party b, CW_0, ..., CW_{n-1}, CW_n).
        CW_i (i < n) are the per-level correction words; CW_n is a one-element
        list holding the output correction.
    """
    α = bit_decomposition(alpha)
    # s[i][b]: λ-bit seed of party b at level i; t[i][b]: its control bit.
    s, t, CW = Array(n+1, 2, λ), Array(n+1, 2), Array(n, 2*(λ+1))
    s[0] = randbit(size=(2, λ))
    t[0] = [0, 1]
    for i in range(0, n):
        # G(seed) = sL (λ) | tL (1) | sR (λ) | tR (1).
        # Re-use useless randomness: the seeds of the branch that leaves the α path
        # provide the fresh value s_rand placed on the α branch of the correction.
        sL_0, _, sR_0, _ = split(G(s[i][0]), [λ, 1, λ, 1])
        sL_1, _, sR_1, _ = split(G(s[i][1]), [λ, 1, λ, 1])
        s_rand = xor(sL_0, sL_1) if α[i] else xor(sR_0, sR_1)

        # After correction the two parties' expansions differ by exactly cw_i:
        # (s_rand, 1) on branch α[i] and zeros on the other branch.
        cw_i = TruthTableDPF(s_rand, α[i])
        CW[i] = xor(cw_i, G(s[i][0]), G(s[i][1]))

        for b in (0, 1):
            τ = xor(G(s[i][b]), t[i][b] * CW[i])
            τ = τ.reshape(2, λ+1)
            # Follow the α path: row α[i] is (next seed, next control bit).
            *s[i+1][b], t[i+1][b] = τ[α[i]]

    # Output correction word. NOTE(review): Convert returns int64, so this sum can
    # overflow and wrap (NumPy warns); reconstruction seems to rely on the same
    # wrap-around arithmetic in EvalDPF — confirm.
    CW_n = [(-1)**t[n][1]*(beta - Convert(s[n][0]) + Convert(s[n][1]))]

    k = [(s[0][b], *CW, CW_n) for b in (0, 1)]
    return k
        

In [15]:
def EvalDPF(b, k_b, x):
    """Evaluate party b's DPF key at the public point x.

    Walks down the tree along the bits of x (MSB first), applying the correction
    word of each level whenever the current control bit is set.

    Args:
        b: party index (0 or 1), also the initial control bit.
        k_b: key (seed, CW_0, ..., CW_n) produced by GenDPF.
        x: public evaluation point.

    Returns:
        This party's additive share. With keys from GenDPF it is a one-element
        array, and the two parties' shares sum to beta at x == alpha, 0 elsewhere.
    """
    path = bit_decomposition(x)
    seeds, ctrl = Array(n+1, λ), Array(n+1, 1)
    seeds[0], *CW = k_b
    ctrl[0] = b
    for level in range(n):
        expanded = xor(G(seeds[level]), ctrl[level] * CW[level]).reshape(2, λ+1)
        # Row layout: [seed (λ bits), control bit]
        chosen = expanded[path[level]]
        seeds[level+1] = chosen[:-1]
        ctrl[level+1] = chosen[-1]
    sign = (-1)**b
    return sign * (Convert(seeds[n]) + ctrl[n] * CW[n])
        

In [16]:
def test_DPF():
    """Check GenDPF/EvalDPF correctness on random points and report timings.

    For each random alpha, x ranges over n_x random points plus alpha itself.
    The two parties' shares must sum to beta at alpha and to 0 elsewhere.
    """
    t = time.time()
    n_alpha, n_x = 10, 10
    beta = 2
    # dtype=np.int64: the default integer dtype is 32-bit on some platforms
    # (e.g. Windows with NumPy < 2), where the bound 2**n = 2**32 is out of range.
    for alpha in np.random.randint(0, 2**n, n_alpha, dtype=np.int64):
        k = GenDPF(alpha, beta)
        t2 = time.time()
        for x in concat(np.random.randint(0, 2**n, n_x, dtype=np.int64), [alpha]):
            y0 = EvalDPF(0, k[0], x)
            y1 = EvalDPF(1, k[1], x)

            if x == alpha:
                assert y0+y1 == beta
            else:
                assert y0+y1 == 0
        print(round(1000*(time.time() - t2)/((n_x+1) * 2), 1), 'ms / Eval')
    # Total time (key generation included) divided by the number of random points.
    print(round(1000*(time.time() - t)/(n_alpha*n_x), 1), 'ms')

test_DPF()

1.5 ms / Eval
1.2 ms / Eval
1.1 ms / Eval
1.1 ms / Eval
1.1 ms / Eval
1.1 ms / Eval
1.2 ms / Eval
1.1 ms / Eval
1.2 ms / Eval
1.1 ms / Eval
3.3 ms


# DIF v3 simplified
Here we address the comparison problem $x \le \alpha$: the two keys output XOR-shares of the bit $[x \le \alpha]$.

https://docs.google.com/presentation/d/108V1QMm7ACiD_mr8ITniYfQUwURSafZep6UJaQASZ0k/edit?folder=0AIi9WbrMGj2CUk9PVA#slide=id.g6d24e8b0e8_0_1223

In [8]:
def G(seed):
    """PRG for the DIF: expand a λ-bit seed into 2 + 2*(λ+1) = 2*(λ+2) bits.

    Callers reshape the output to (2, λ+2); each branch row is
    [leaf bit σ, seed (λ bits), control bit].

    NOTE: this redefines the DPF-sized G. GenDPF/EvalDPF look up G at call time,
    so they would use this version if re-run before G is redefined again.
    """
    assert len(seed) == λ
    prng = np.random.RandomState(seed)
    n_out = 2 + 2 * (λ + 1)
    return prng.randint(2, size=n_out)
    
    
def TruthTableDIF(s, α_i):
    """Target XOR-difference between both parties' corrected expansions at one DIF level.

    Returns a flattened (2, λ+2) int32 table whose row b is
    [leaf bit σ, seed (λ bits), control bit]:
    - row α_i (staying on the α path) is (0, s, 1): the seeds differ by s and
      the control bits by 1;
    - row 1-α_i (leaving the path) is (α_i, 0, ..., 0): only the leaf bit differs.
      If α_i is 0, leaving means your bit is 1, so x > α and you should get 0.
      If α_i is 1, leaving means your bit is 0, so x < α and you should get 1.
    """
    table = np.zeros((2, λ + 2), dtype=np.int32)
    # Leaf bit of the branch that leaves the α path.
    table[1 - α_i, 0] = α_i
    # Seed and control-bit difference on the branch that stays on the α path.
    table[α_i, 1:] = concat(s, [1])
    return table.flatten()

In [9]:
def GenDIF(alpha):
    """Generate the two keys of a DIF for the predicate x <= alpha.

    EvalDIF(0, k[0], x, n) XOR EvalDIF(1, k[1], x, n) equals int(x <= alpha)
    (this is what test_DIF_simplified checks).

    Args:
        alpha: threshold, decomposed into n bits (MSB first).

    Returns:
        [k_0, k_1] with k_b = (initial seed of party b, CW). CW is the array of
        per-level correction words, shape (n, 2*(λ+2)), shared by both keys.
    """
    α = bit_decomposition(alpha, nbits=n)
    # s[i][b]: λ-bit seed of party b at level i; t[i][b]: its control bit.
    s, t, CW = Array(n+1, 2, λ), Array(n+1, 2), Array(n, 2 + 2*(λ+1))
    s[0] = randbit(size=(2, λ))
    t[0] = [0, 1]
    for i in range(0, n):
        # G(seed) is laid out per branch as [σ (1), seed (λ), t (1)], i.e.
        # σL | sL | tL | σR | sR | tR, matching reshape(2, λ+2) below.
        # (Sizes [1, 1, λ, 1, λ, 1] would match a leaves-first layout and pick up
        # tL instead of the first bit of sL.)
        # Re-use useless randomness: the seeds of the branch leaving the α path
        # provide the fresh value s_rand.
        _, sL_0, _, _, sR_0, _ = split(G(s[i][0]), [1, λ, 1, 1, λ, 1])
        _, sL_1, _, _, sR_1, _ = split(G(s[i][1]), [1, λ, 1, 1, λ, 1])
        s_rand = xor(sL_0, sL_1) if α[i] else xor(sR_0, sR_1)
        cw_i = TruthTableDIF(s_rand, α[i])
        CW[i] = xor(cw_i, G(s[i][0]), G(s[i][1]))

        for b in (0, 1):
            τ = xor(G(s[i][b]), t[i][b] * CW[i])
            τ = τ.reshape(2, λ+2)
            # Follow the α path; the leaf bit is only needed by EvalDIF.
            _σ_leaf, *s[i+1][b], t[i+1][b] = τ[α[i]]

    k = [(s[0][b], CW) for b in (0, 1)]
    return k

In [10]:
def EvalDIF(b, k_b, x, n):
    """Evaluate party b's DIF key at the public input x.

    Walks down the tree along the n bits of x (MSB first) and records the leaf
    bit σ of each chosen branch. After the last level, the control bit itself
    acts as the final leaf. The output is the parity of all these bits, so that
    EvalDIF(0, k[0], x, n) XOR EvalDIF(1, k[1], x, n) == int(x <= alpha).

    Args:
        b: party index (0 or 1), also the initial control bit.
        k_b: key (seed, CW) produced by GenDIF.
        x: public input.
        n: input bit length; CW must have (at least) n rows.

    Returns:
        This party's output bit (0 or 1) as a NumPy integer.
    """
    leaves = Array(n+1, 1)
    path_bits = bit_decomposition(x, nbits=n)
    seeds, ctrl = Array(n+1, λ), Array(n+1, 1)
    seeds[0], CW = k_b
    ctrl[0] = b
    for level in range(n):
        expanded = xor(G(seeds[level]), ctrl[level] * CW[level]).reshape(2, λ+2)
        # Row layout: [σ (leaf bit), seed (λ bits), control bit]
        branch = expanded[path_bits[level]]
        seeds[level+1] = branch[1:-1]
        ctrl[level+1] = branch[-1]
        leaves[level] = branch[0]

    # Last level: the control bit is the final leaf.
    leaves[n] = ctrl[n]
    return leaves.sum() % 2

In [11]:
def test_DIF_simplified():
    """Check GenDIF/EvalDIF correctness and report timings.

    For each random alpha, the XOR of both parties' outputs must equal
    int(x <= alpha). Besides n_x random points, the boundary points alpha - 1,
    alpha and alpha + 1 are always tested when they lie in [0, 2**n). Random x
    almost never hits them, yet they are where a comparison is most likely to
    be off by one.
    """
    t = time.time()
    n_alpha, n_x = 10, 10
    # dtype=np.int64: the default integer dtype is 32-bit on some platforms
    # (e.g. Windows with NumPy < 2), where the bound 2**n = 2**32 is out of range.
    for alpha in np.random.randint(0, 2**n, n_alpha, dtype=np.int64):
        k = GenDIF(alpha)
        print('Test x <= α', alpha)
        boundary = np.array(
            [v for v in (alpha - 1, alpha, alpha + 1) if 0 <= v < 2**n], dtype=np.int64
        )
        xs = concat(np.random.randint(0, 2**n, n_x, dtype=np.int64), boundary)
        t2 = time.time()
        for x in xs:
            sigma = xor(EvalDIF(0, k[0], x, n=n), EvalDIF(1, k[1], x, n=n))
            assert int(x <= alpha) == sigma
        print(round(1000*(time.time() - t2)/(len(xs) * 2), 1), 'ms / Eval')
    # Total time (key generation included) divided by the number of random points.
    print(round(1000*(time.time() - t)/(n_alpha*n_x), 1), 'ms')

test_DIF_simplified()

Test x <= α 3696131652
1.2 ms / Eval
Test x <= α 1914295028
1.3 ms / Eval
Test x <= α 405651634
1.2 ms / Eval
Test x <= α 3477612570
1.3 ms / Eval
Test x <= α 2784886718
1.3 ms / Eval
Test x <= α 3881626632
1.2 ms / Eval
Test x <= α 131850509
1.2 ms / Eval
Test x <= α 3345301304
1.2 ms / Eval
Test x <= α 1516860411
1.3 ms / Eval
Test x <= α 3518191607
1.3 ms / Eval
3.3 ms


# PyTorch

In [37]:
import syft as sy
import torch as th
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import time
import torch.autograd as autograd
import numpy as np

# Long aliases for the already-imported modules, so later cells can use either name.
syft, torch = sy, th

hook = sy.TorchHook(th)
bob, alice, charlie, james = [
    sy.VirtualWorker(hook, id=worker_id) for worker_id in ("bob", "alice", "charlie", "james")
]
# james doubles as the crypto provider.
crypto_provider = james



## Equality to zero

In [65]:
def G(seed):
    """DPF-sized PRG: expand a λ-bit seed into 2*(λ+1) bits.

    Redefined here so that GenDPF/EvalDPF (which look up G at call time) again
    get the [sL | tL | sR | tR] layout after the DIF section replaced G.
    """
    assert len(seed) == λ
    out_len = 2 * (λ + 1)
    rng = np.random.RandomState(seed)
    return rng.randint(2, size=out_len)

In [66]:
field = 1 << 8
# Signed representatives of Z/field: [-field/2, field/2 - 1]
field_bounds = -(field >> 1), (field >> 1) - 1
print('range', field_bounds)

def mod(x, lower=None, size=None):
    """Reduce x into the signed range [lower, lower + size - 1] of the ring Z/size.

    Defaults to the notebook's field: lower = field_bounds[0] and size = field.

    Works element-wise on Python ints, NumPy arrays/scalars and torch tensors.
    It reduces values lying any number of field widths outside the range.
    A new object is returned; the argument is not modified in place.

    Args:
        x: value(s) to reduce.
        lower: smallest representative (default field_bounds[0]).
        size: ring size (default field).
    """
    if lower is None:
        lower = field_bounds[0]
    if size is None:
        size = field
    # Python, NumPy and torch `%` all give a result with the sign of the divisor,
    # so (x - lower) % size lies in [0, size) even for negative inputs.
    return (x - lower) % size + lower

def share(x):
    """Split x into two additive shares over the ring Z/field.

    Returns:
        (x0, x1), where x1 is a mask drawn from the whole range field_bounds
        (both ends inclusive) and x0 = mod(x - x1), so mod(x0 + x1) == mod(x).
    """
    low, high = field_bounds
    # torch.randint's upper bound is exclusive: use high + 1 so every ring
    # element (including the top one) can be drawn, keeping the mask uniform.
    s = torch.randint(low, high + 1, x.shape)
    return mod(x-s).long(), s

range (-128, 127)


In [67]:
x = th.tensor([0.])
# NOTE(review): `s` is not used in this cell; share() draws its own mask.
s = torch.randint(*field_bounds, (1,))
# Fixed-point encode x (base 2, precision_fractional=3) and split it into two additive shares.
x_priv = share(x.fix_prec(base=2, precision_fractional=3, field=field).child.child)
# Show the shares and their reconstruction. The plain sum can be off by a multiple
# of the field size, so it is reduced with mod() to recover the encoded value.
x_priv, mod(x_priv[0] + x_priv[1])

((tensor([-13]), tensor([13])), tensor([0]))

In [68]:
def keygen_equal_zero():
    """Key generation for the equality-to-zero test.

    Draws a random mask r_in from the field and builds DPF keys for the point
    r_in with output value 1.

    Returns:
        (additive shares of r_in, [k_0, k_1]).
    """
    low, high = field_bounds
    # torch.randint's upper bound is exclusive; high + 1 makes r_in uniform over the whole field.
    r_in = torch.randint(low, high + 1, (1,)).numpy()[0]
    # NOTE(review): printing r_in reveals the mask that hides x; debug output only.
    print('r_in', r_in)
    k = GenDPF(r_in, 1)
    return share(r_in), k

In [71]:
def equal_zero(x_priv):
    """Compute shares of [x == 0] for additively shared x using a masked DPF.

    Each party adds its share of the random mask r_in to its share of x. The
    masked value is opened (reduced with mod), and each party evaluates its DPF
    key, whose special point is r_in, at that public value.

    Returns:
        (res_0, res_1): the two output shares. Their sum is expected to be 1
        when x == 0 in the field and 0 otherwise.
    """
    (r_0, r_1), (k_0, k_1) = keygen_equal_zero()
    masked_0 = x_priv[0] + r_0
    masked_1 = x_priv[1] + r_1
    opened = mod(masked_0 + masked_1).numpy()[0]
    return EvalDPF(0, k_0, opened), EvalDPF(1, k_1, opened)

In [72]:
# Disclose value: add both parties' output shares (expected [1] when x == 0).
res_0, res_1 = equal_zero(x_priv)
revealed = res_0 + res_1
print(revealed)

r_in -115
[1]


## Comparison

In [36]:
def G(seed):
    """DIF-sized PRG: expand a λ-bit seed into 2 + 2*(λ+1) pseudo-random bits.

    Redefines G for the comparison section; any function calling G picks up this
    version at call time.
    """
    assert len(seed) == λ
    total_bits = 2 + 2 * (λ + 1)
    source = np.random.RandomState(seed)
    return source.randint(2, size=total_bits)