In [191]:
import syft as sy
import torch as th
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import time
import torch.autograd as autograd
import numpy as np

# Make the library reachable under its full name as well as the `sy` alias.
syft = sy 

# Build a TorchHook from the torch module, then create four virtual workers
# that share this hook and are distinguished by their string ids.
hook = sy.TorchHook(th)
bob = sy.VirtualWorker(hook, id="bob")
alice = sy.VirtualWorker(hook, id="alice")
charlie = sy.VirtualWorker(hook, id="charlie")
james = sy.VirtualWorker(hook, id="james")
# The worker `james` is also bound to the name `crypto_provider`.
crypto_provider = james
# Make torch reachable under its full name as well as the `th` alias.
torch = th



# DPF v2
https://docs.google.com/presentation/d/108V1QMm7ACiD_mr8ITniYfQUwURSafZep6UJaQASZ0k/edit?folder=0AIi9WbrMGj2CUk9PVA#slide=id.g7643f50265_0_0

In [192]:
# Module-level parameters read as globals by the functions defined below.
λ = 16 # security parameter (seed length, in bits)
n = 32 # number of bits of the algebraic field

In [193]:
# PRG (pseudo-random generator)
def G(seed):
    """Expand a seed of length λ into 2*(λ + 1) pseudo-random bits.

    A fresh ``np.random.RandomState`` is seeded with ``seed`` on every call,
    so the result depends only on ``seed`` and NumPy's global generator is
    left untouched.

    NOTE(review): RandomState is NumPy's legacy Mersenne Twister generator,
    not a cryptographically secure PRG -- acceptable for prototyping only.
    """
    assert len(seed) == λ
    output_bits = 2 * (λ + 1)
    generator = np.random.RandomState(seed)
    return generator.randint(2, size=output_bits)

def Convert(bits):
    """Interpret a 1-D array of bits as an integer, most significant bit first.

    ``bits`` must be a NumPy array (its ``size`` attribute is used).
    """
    #TODO see figure 3 of paper
    # Weights 2**(size-1), ..., 2, 1 so that bits[0] is the most significant.
    exponents = np.arange(bits.size)[::-1]
    weights = np.left_shift(1, exponents)
    return np.dot(bits, weights)

In [202]:
def List(n):
    """Return a new list holding ``n`` placeholders (``None``)."""
    return [None for _ in range(n)]

def bit_decomposition(x, nbits = n):
    """Return the binary digits of ``x`` as a list of ints, padded to ``nbits``.

    The digits come from ``np.binary_repr(x, width=nbits)``, so they are
    ordered most significant first and negative values use two's complement.
    The default ``nbits`` is the value the global ``n`` had when this
    function was defined.
    """
    binary_string = np.binary_repr(x, width=nbits)
    return [int(digit) for digit in binary_string]

def randbit(size):
    """Draw ``size`` random bits (0 or 1) from NumPy's global generator."""
    return np.random.randint(low=0, high=2, size=size)

def xor(*args):
    """Multi-input xor.

    Folds ``np.bitwise_xor`` over all operands (NumPy broadcasting applies).
    A single operand is returned unchanged; the previous recursive version
    raised IndexError in that case because it recursed down to zero
    arguments.

    Raises:
        TypeError: if called without any operand.
    """
    if not args:
        raise TypeError("xor() requires at least one operand")
    result = args[0]
    # XOR is associative and commutative, so folding left-to-right gives the
    # same value as the former right-nested recursion.
    for operand in args[1:]:
        result = np.bitwise_xor(result, operand)
    return result
    
def concat(*args):
    """Join all positional arguments into one array with ``np.concatenate``."""
    pieces = list(args)
    return np.concatenate(pieces)

def split(l, idx):
    """Split ``l`` with ``np.split``.

    If ``idx`` is a list or tuple it is read as the sizes of consecutive
    parts and converted to cumulative cut positions. Any other ``idx`` is
    passed to ``np.split`` unchanged.
    """
    if not isinstance(idx, (list, tuple)):
        return np.split(l, idx)

    # Only the boundaries between parts are needed: the end of the last part
    # coincides with the end of the array, and passing it as a cut position
    # would produce an extra empty part.
    boundaries = []
    position = 0
    for part_size in idx[:-1]:
        position += part_size
        boundaries.append(position)
    return np.split(l, boundaries)
    
def TruthTableDPF(s, α_i):
    """Build the flattened 2 x (λ+1) table used for one correction word.

    Row ``α_i`` holds ``s`` followed by a single 1; the other row is all
    zeros. The table is returned flattened to length 2*(λ+1).
    """
    selected_row = concat(s, [1])
    table = np.zeros(shape=(2, λ + 1), dtype="int64")
    table[α_i] = selected_row
    return table.flatten()

In [203]:
def Gen(alpha, beta):
    """Generate the pair of DPF keys for the point (alpha, beta).

    Args:
        alpha: integer whose bits (via ``bit_decomposition``) select the
            path followed through the ``n`` levels.
        beta: value folded into the final correction word ``CW[n]``.

    Returns:
        A list ``[k_0, k_1]``; ``k_b`` is the 1-D concatenation of party b's
        initial seed ``s[0][b]`` followed by all ``n + 1`` correction words.
    """
    α = bit_decomposition(alpha)
    s, t, CW = List(n+1), List(n+1), List(n+1)
    s[0] = split(randbit(size=2 * λ), [λ, λ])
    t[0] = [0, 1]
    for i in range(0, n):
        # Expand each party's seed once per level and reuse the result.
        # G is assumed to be a deterministic function of its seed (as a PRG
        # must be), so this equals re-evaluating it at every use.
        g = [G(s[i][0]), G(s[i][1])]

        # Re-use useless randomness: the new seed placed on the α[i] branch
        # is derived from the halves of the expansions on the other side.
        sL_0, _, sR_0, _ = split(g[0], [λ, 1, λ, 1])
        sL_1, _, sR_1, _ = split(g[1], [λ, 1, λ, 1])
        s_rand = xor(sL_0, sL_1) if α[i] else xor(sR_0, sR_1)
        cw_i = TruthTableDPF(s_rand, α[i])
        CW[i] = xor(cw_i, g[0], g[1])

        s[i+1], t[i+1] = List(2), List(2)
        for b in (0, 1):
            # Apply the correction word only when this party's flag is set.
            τ = xor(g[b], t[i][b] * CW[i])
            τ = τ.reshape(2, λ+1)
            # Keep the row selected by the current bit of alpha: λ seed bits
            # followed by one flag bit.
            s[i+1][b], t[i+1][b] = split(τ[α[i]], [λ, 1])

    CW[n] = (-1)**t[n][1]*(beta - Convert(s[n][0]) + Convert(s[n][1]))

    k = [concat(s[0][b], *CW) for b in (0, 1)]
    return k
        

In [204]:
def Eval(b, k_b, x):
    """Evaluate party ``b``'s key ``k_b`` at input ``x``.

    The key is cut into the initial seed (λ entries), ``n`` correction words
    of 2*(λ+1) entries each and one final single-entry correction word.
    At each level the current seed is expanded with ``G``, corrected when
    the current flag is set, and the row chosen by the next bit of ``x``
    supplies the next seed and flag.

    Returns:
        ``(-1)**b * (Convert(seed) + flag * CW[n])`` for the final seed and
        flag.
    """
    path_bits = bit_decomposition(x)
    part_sizes = [λ] + [2 * (λ + 1)] * n + [1]
    seed, *CW = split(k_b, part_sizes)
    flag = b
    for level in range(n):
        expanded = xor(G(seed), flag * CW[level])
        rows = expanded.reshape(2, λ + 1)
        seed, flag = split(rows[path_bits[level]], [λ, 1])
    sign = (-1) ** b
    return sign * (Convert(seed) + flag * CW[n])
        

In [205]:
def test_DPF():
    """Check that the two Eval outputs add up to beta at alpha and 0 elsewhere.

    For each (alpha, beta) pair a key pair is generated and evaluated at
    alpha itself and at ``2*alpha + 1`` and ``-alpha + 1``.
    """
    points = [2, 7, -1]
    values = [1, -3, 2]
    for point, value in zip(points, values):
        keys = Gen(point, value)
        for x in (point, 2*point + 1, -point + 1):
            share_0 = Eval(0, keys[0], x)
            share_1 = Eval(1, keys[1], x)
            # A probe can coincide with alpha (e.g. alpha = -1), hence the
            # comparison instead of relying on the position in the list.
            expected = value if x == point else 0
            assert share_0 + share_1 == expected
                
# Run the DPF check.
# NOTE(review): the output recorded below this cell is a ValueError, so this
# call did not pass when the notebook was last executed.
test_DPF()

ValueError: setting an array element with a sequence.

# DIF v3 simplified
Here we address the comparison problem $x \le \alpha$.

https://docs.google.com/presentation/d/108V1QMm7ACiD_mr8ITniYfQUwURSafZep6UJaQASZ0k/edit?folder=0AIi9WbrMGj2CUk9PVA#slide=id.g6d24e8b0e8_0_1223

In [114]:
# Bit width used for alpha and x in the DIF functions below (re-binds the global n).
n = 32

In [115]:
def G(seed):
    """Expand a seed of length λ into 2 + 2*(λ + 1) pseudo-random bits.

    The first 2 bits are used by callers as leaf bits, the remaining
    2*(λ + 1) as two (seed, flag) rows.

    A private ``RandomState`` is seeded with ``seed`` instead of calling
    ``np.random.seed``: the produced bits are the same, but NumPy's global
    generator is no longer reseeded as a side effect, so unrelated draws
    (``randbit``, test inputs) stay independent of the seeds expanded here.

    NOTE(review): RandomState is NumPy's legacy Mersenne Twister generator,
    not a cryptographically secure PRG -- acceptable for prototyping only.
    """
    assert len(seed) == λ
    generator = np.random.RandomState(seed)
    return generator.randint(2, size=2 + 2*(λ + 1))
    
    
def TruthTableDIF(s, α_i):
    """Build the flattened table used for one DIF correction word.

    The result is 2 leaf entries followed by a flattened 2 x (λ+1) table
    whose row ``α_i`` holds ``s`` followed by a single 1 (other row zero).
    """
    # Leaf part: only the branch that leaves the α path (index 1 - α_i) is set.
    # if α_i is 0, ending on the leaf branch means your bit is 1, so you're > α and should get 0
    # if α_i is 1, ending on the leaf branch means your bit is 0, so you're < α and should get 1
    leaf_bits = np.zeros(shape=(2, 1), dtype="int64")
    leaf_bits[1 - α_i] = α_i

    # Node part: the α_i branch carries the new seed and a flag bit of 1.
    node_rows = np.zeros(shape=(2, λ + 1), dtype="int64")
    node_rows[α_i] = np.concatenate((s, [1]))

    return np.concatenate((leaf_bits.flatten(), node_rows.flatten()))

In [116]:
def Gen(alpha):
    """Generate the pair of DIF keys for the threshold ``alpha``.

    Args:
        alpha: integer whose ``n`` bits (via ``bit_decomposition``) select
            the path followed through the levels.

    Returns:
        A list ``[k_0, k_1]``; ``k_b`` is the 1-D concatenation of party b's
        initial seed ``s[0][b]`` followed by the ``n`` correction words.
    """
    α = bit_decomposition(alpha, nbits=n)
    s, t, CW = List(n+1), List(n+1), List(n)
    s[0] = split(randbit(size=2 * λ), [λ, λ])
    t[0] = [0, 1]
    for i in range(0, n):
        # Expand each party's seed once per level and reuse the result.
        # G is assumed to be a deterministic function of its seed (as a PRG
        # must be), so this equals re-evaluating it at every use.
        g = [G(s[i][0]), G(s[i][1])]

        # Re-use useless randomness: the new seed placed on the α[i] branch
        # is derived from the halves of the expansions on the other side.
        # The first two entries of each expansion are the leaf bits.
        _, _, sL_0, _, sR_0, _ = split(g[0], [1, 1, λ, 1, λ, 1])
        _, _, sL_1, _, sR_1, _ = split(g[1], [1, 1, λ, 1, λ, 1])
        s_rand = xor(sL_0, sL_1) if α[i]==1 else xor(sR_0, sR_1)
        cw_i = TruthTableDIF(s_rand, α[i])
        CW[i] = xor(cw_i, g[0], g[1])

        s[i+1], t[i+1] = List(2), List(2)
        for b in (0, 1):
            # Apply the correction word only when this party's flag is set.
            τ = xor(g[b], t[i][b] * CW[i])
            # The 2 leading leaf entries are not needed during generation.
            _, σ_node = split(τ, [2, 2*(λ+1)])
            σ_node = σ_node.reshape(2, λ+1)
            # Keep the row selected by the current bit of alpha: λ seed bits
            # followed by one flag bit.
            s[i+1][b], t[i+1][b] = split(σ_node[α[i]], [λ, 1])

    k = [concat(s[0][b], *CW) for b in (0, 1)]
    return k
        

In [119]:
def Eval(b, k_b, x, n):
    """Evaluate party ``b``'s DIF key ``k_b`` at the ``n``-bit input ``x``.

    The key is cut into the initial seed (λ entries) and ``n`` correction
    words of 2 + 2*(λ+1) entries. At each level the current seed is expanded
    with ``G`` and corrected when the current flag is set; the bit of ``x``
    then selects one leaf entry (recorded) and one node row (next seed and
    flag).

    Returns:
        ``(σ, FnOutput)`` where ``σ`` is the node row selected at the last
        level and ``FnOutput`` is a list of ``n + 1`` values: the selected
        leaf entry of every level followed by the last entry of ``σ``.
    """
    path_bits = bit_decomposition(x, nbits=n)
    level_width = 2 + 2 * (λ + 1)
    seed, *CW = split(k_b, [λ] + [level_width] * n)
    flag = b
    FnOutput = []
    for i in range(n):
        τ = xor(G(seed), flag * CW[i])
        σ_leaf, σ_node = split(τ, [2, 2 * (λ + 1)])
        bit = path_bits[i]
        FnOutput.append(σ_leaf.reshape(2, 1)[bit][-1])
        σ = σ_node.reshape(2, λ + 1)[bit]
        seed, flag = split(σ, [λ, 1])

    # Last round: the flag entry of the final node row also acts as a leaf.
    FnOutput.append(σ[-1])
    return σ, FnOutput

In [135]:
def test_DIF_simplified():
    """Randomised check that the combined DIF outputs equal ``int(x <= alpha)``.

    Draws ``n_alpha`` thresholds and, for each, ``n_x`` inputs in
    ``[0, 2**n)``; for every pair the element-wise xor of both parties'
    ``FnOutput`` lists is summed and compared with ``int(x <= alpha)``.
    Prints the average time per ``Eval`` call for each threshold and the
    average time per (alpha, x) pair overall.
    """
    # Private generator: keeps the test inputs independent of any reseeding
    # of NumPy's global generator performed by the code under test.
    rng = np.random.RandomState()
    t = time.perf_counter()
    n_alpha, n_x = 10, 10
    # dtype=np.int64 so that the bound 2**n (n = 32) is accepted even where
    # NumPy's default integer type is 32 bits wide.
    for alpha in rng.randint(0, 2**n, n_alpha, dtype=np.int64):
        k = Gen(alpha)
        print('Test x <= α', alpha)
        t2 = time.perf_counter()
        for x in rng.randint(0, 2**n, n_x, dtype=np.int64):
            sigma = sum(xor(Eval(0, k[0], x, n=n)[1], Eval(1, k[1], x, n=n)[1]))
            assert int(x<=alpha) == sigma, f"alpha={alpha}, x={x}, sigma={sigma}"
        print(round(1000*(time.perf_counter() - t2)/(n_x * 2), 1), 'ms / Eval')
    print(round(1000*(time.perf_counter() - t)/(n_alpha*n_x), 1), 'ms')

# Run the DIF check; it prints each tested alpha and the measured timings.
test_DIF_simplified()

Test x <= α 868901555
1.4 ms / Eval
Test x <= α 2094728343
1.2 ms / Eval
Test x <= α 1331758589
1.2 ms / Eval
Test x <= α 2209524418
1.2 ms / Eval
Test x <= α 1067235160
1.1 ms / Eval
Test x <= α 937450
1.3 ms / Eval
Test x <= α 2481108934
1.2 ms / Eval
Test x <= α 2722300044
1.3 ms / Eval
Test x <= α 380752735
1.2 ms / Eval
Test x <= α 1329112794
1.2 ms / Eval
3.1 ms


In [130]:
# Scratch cell: sample 20 integers from {0, 1} with NumPy's global generator.
np.random.randint(0,2, 20)

array([0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1])