In [1]:
import syft as sy
import torch as th
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import time
import torch.autograd as autograd
import numpy as np

# Second name for the same module object imported as `sy`.
syft = sy 

# Create the TorchHook on `th`, then four VirtualWorkers attached to that hook.
hook = sy.TorchHook(th)
bob = sy.VirtualWorker(hook, id="bob")
alice = sy.VirtualWorker(hook, id="alice")
charlie = sy.VirtualWorker(hook, id="charlie")
james = sy.VirtualWorker(hook, id="james")
# `crypto_provider` is simply another name for the `james` worker.
crypto_provider = james
# Second name for the same module object imported as `th`.
torch = th

In [2]:
import re

In [4]:
# Scratch check: capture (lazily) the text between the first "(" and the next ")".
s = "hgkg(AAA|BBB|CCC)"
re.search(r'\((.*?)\)',s).group(1)

'AAA|BBB|CCC'

In [26]:
# Sample two-part dotted name used as regex input in the next cell.
s = 'BBB.CCC'

In [27]:
# Pattern: an optional "word." prefix (not captured), then group 1 is the word
# that is immediately followed by ".word".
column_pattern = re.compile(r'(?:\w+\.)?(\w+)\.\w+')
re.search(column_pattern, s).group(1)

'BBB'

In [None]:
def get_table_name(name):
    """
    Return the table part of a dotted column specification.

    Examples:
        "OWNER.TABLE.COLUMN" -> "OWNER.TABLE"
        "TABLE.COLUMN"       -> "TABLE"

    Any other number of dot-separated parts yields None.
    """
    parts = name.split(".")
    if len(parts) not in (2, 3):
        return None
    # Everything except the trailing column name is the table reference.
    return ".".join(parts[:-1])

# DPF

In [6]:
λ = 8 # security parameter
n = 32 # number of bits of the algebraic field

In [7]:
# PRG
def G(seed):
    """
    Pseudo-random generator: expand a λ-bit seed into 2*(λ + 1) bits.

    The output is a deterministic function of `seed`. A private RandomState
    is seeded instead of calling np.random.seed, so that NumPy's global
    generator (used by randbit) is NOT reset as a side effect of every
    expansion. The produced bits are the same as with the global reseed.

    seed: array-like of length λ, used as the generator seed.
    Returns: 1-D integer array of 2*(λ + 1) values in {0, 1}.
    Raises: AssertionError if the seed does not have length λ.
    """
    assert len(seed) == λ
    prg = np.random.RandomState(seed)
    return prg.randint(2, size=2*(λ + 1))

def Convert(bits):
    """Read a 1-D bit array, most significant bit first, as an integer."""
    #TODO see figure 3
    place_values = 1 << np.arange(bits.size)[::-1]
    return bits.dot(place_values)

In [8]:
def List(n):
    """Return a fresh list holding `n` None placeholders."""
    return [None for _ in range(n)]

def bit_decomposition(x, nbits=None):
    """
    Return the binary digits of `x` as a list of Python ints, most significant first.

    x: integer to decompose (formatted with np.binary_repr).
    nbits: width passed to np.binary_repr. When omitted, the module-level `n`
        is looked up at CALL time, so re-assigning `n` in a later cell is
        honoured; a `nbits = n` default would freeze the value `n` had when
        this function was defined.
    Returns: list of 0/1 ints.
    """
    if nbits is None:
        nbits = n
    return [int(bit) for bit in np.binary_repr(x, width=nbits)]

def randbit(size):
    """Draw `size` random bits (0 or 1) from NumPy's global generator."""
    return np.random.randint(0, 2, size=size)

def xor(*args):
    """
    Multi-input xor.

    Folds np.bitwise_xor over all operands (scalars or arrays that broadcast
    against each other). A single operand is returned unchanged; the previous
    recursive form crashed with IndexError in that case.

    Raises: ValueError when called without any operand.
    """
    if not args:
        raise ValueError("xor() requires at least one operand")
    result = args[0]
    for operand in args[1:]:
        result = np.bitwise_xor(result, operand)
    return result
    
def concat(*args):
    """Join all positional sequences/arrays end to end into one NumPy array."""
    pieces = list(args)
    return np.concatenate(pieces)

def split(l, idx):
    """
    Split `l` with np.split.

    When `idx` is a list or tuple it is read as the SIZES of the consecutive
    parts; otherwise it is handed to np.split unchanged.
    """
    if not isinstance(idx, (list, tuple)):
        return np.split(l, idx)
    # Turn part sizes into cut positions. The last running total equals the
    # full length and is left out, otherwise np.split would append an empty part.
    cut_points = [sum(idx[:k]) for k in range(1, len(idx))]
    return np.split(l, cut_points)

In [9]:
def Gen(alpha, beta):
    """
    Build the pair of DPF keys for the point (alpha, beta).

    alpha: integer, decomposed into bits with bit_decomposition's default width.
    beta: value used in the final correction word.
    Returns: list [k_0, k_1]; each key is the party's initial λ-bit seed
        followed by the n + 1 correction words, flattened with concat.

    test_DPF checks that Eval(0, k_0, x) + Eval(1, k_1, x) equals beta at
    x == alpha and 0 at the other points it tries.
    """
    alpha = bit_decomposition(alpha)
    s, t, CW = List(n+1), List(n+1), List(n+1)
    # One random λ-bit seed per party; the control bits start as 0 and 1.
    s[0] = split(randbit(size=2 * λ), [λ, λ])
    t[0] = [0, 1]
    for i in range(1, n + 1):
        # Expand both parties' current seeds into (left seed, left bit, right seed, right bit).
        sL_0, tL_0, sR_0, tR_0 = split(G(s[i-1][0]), [λ, 1, λ, 1])
        sL_1, tL_1, sR_1, tR_1 = split(G(s[i-1][1]), [λ, 1, λ, 1])
        sL, tL, sR, tR = [sL_0, sL_1], [tL_0, tL_1], [sR_0, sR_1], [tR_0, tR_1]
        
        # The branch selected by the current bit of alpha is "Keep", the other is "Lose".
        if alpha[i-1] == 0: # Keep, Lose = L, R
            sKeep, tKeep, sLose, tLose = sL, tL, sR, tR
        else: # Keep, Lose = R, L
            sLose, tLose, sKeep, tKeep = sL, tL, sR, tR
            
        # Correction word pieces for this level.
        s_CW = xor(sLose[0], sLose[1])
        tL_CW = xor(tL[0], tL[1], alpha[i-1], 1)
        tR_CW = xor(tR[0], tR[1], alpha[i-1])
        tKeep_CW = tR_CW if alpha[i-1] else tL_CW
        
        CW[i-1] = concat(s_CW, tL_CW, tR_CW)
        # Each party applies the correction only when its previous control bit is 1.
        s[i] = [xor(sKeep[b], t[i-1][b] * s_CW) for b in (0,1)]
        t[i] = [xor(tKeep[b], t[i-1][b] * tKeep_CW) for b in (0,1)]
        
    # Final correction word: built from beta and both converted final seeds,
    # with a sign given by party 1's last control bit.
    CW[n] = (-1)**t[n][1]*(beta - Convert(s[n][0]) + Convert(s[n][1]))
    
    k = [concat(s[0][b], *CW) for b in (0, 1)]
    return k
        

In [10]:
def Eval(b, k_b, x):
    """
    Evaluate party `b`'s DPF key `k_b` at input `x` and return its output share.

    The key is read as: λ seed bits, n correction words of λ+2 bits each,
    and one final correction entry.
    """
    path_bits = bit_decomposition(x)
    seed, *CW = split(k_b, [λ, *[λ+2]*n, 1])
    control = b
    for level in range(n):
        s_CW, tL_CW, tR_CW = split(CW[level], [λ, 1, 1])
        # The correction word is applied only when the control bit is 1.
        correction = concat(s_CW, tL_CW, s_CW, tR_CW)
        expanded = xor(G(seed), control*correction)
        sL, tL, sR, tR = split(expanded, [λ, 1, λ, 1])
        # Follow the left branch on a 0 bit, the right branch otherwise.
        if path_bits[level] == 0:
            seed, control = sL, tL
        else:
            seed, control = sR, tR
    return (-1)**b * (Convert(seed) + control*CW[n])
        

In [11]:
def test_DPF():
    """Check that the two Eval shares add up to beta at alpha and to 0 at the other tested points."""
    cases = zip([2, 7, -1], [1, -3, 2])
    for a, b in cases:
        k = Gen(a, b)
        for x in (a, 2*a + 1, -a + 1):
            total = Eval(0, k[0], x) + Eval(1, k[1], x)
            expected = b if x == a else 0
            assert total == expected
                
test_DPF()

# DIF simplified
This section addresses the comparison problem: deciding whether $x \le \alpha$.

In [97]:
n = 3  # number of levels used by this section's Gen; the test below evaluates 3-bit inputs

In [163]:
def G(seed):
    """
    Pseudo-random generator: expand a λ-bit seed into 2*2*(λ + 1) bits.

    The output is a deterministic function of `seed`. A private RandomState
    is seeded instead of calling np.random.seed, so that NumPy's global
    generator (used by randbit when drawing fresh seeds) is NOT reset as a
    side effect of every expansion. The produced bits are the same as with
    the global reseed.

    seed: array-like of length λ, used as the generator seed.
    Returns: 1-D integer array of 2*2*(λ + 1) values in {0, 1}.
    Raises: AssertionError if the seed does not have length λ.
    """
    assert len(seed) == λ
    prg = np.random.RandomState(seed)
    return prg.randint(2, size=2*2*(λ + 1))

def TruthTable(s, node):
    """
    Build the flattened pair of (2, λ+1) tables for one tree node.

    node is ((leaf_row, leaf_value), successor). `successor` is either a
    (row, value) tuple (terminal node) or a bare row index, in which case the
    stored value is 1. In each table the selected row holds `s` followed by
    the value; the other row stays zero.
    """
    (leaf_row, leaf_value), successor = node
    if isinstance(successor, tuple): # Terminal Node
        next_row, next_value = successor
    else:
        next_row, next_value = successor, 1

    def one_table(row, value):
        # Zero table with a single populated row: seed bits then the value.
        table = np.zeros((2, λ+1), dtype="int64")
        table[row, :] = concat(s, [value])
        return table

    leaf_table = one_table(leaf_row, leaf_value)
    next_table = one_table(next_row, next_value)
    return np.array([leaf_table, next_table]).flatten()


def find_input(T, m):
    """
    Combine the second entry of the first `m` nodes of `T` into an integer,
    the first node providing the most significant bit.
    """
    return sum(2**(m-i-1)*T[i][1] for i in range(m))
        
def Gen(T):
    """
    Build the two keys for the tree description `T` (a list of nodes, see TruthTable).

    Level 0 is handled directly from the two initial seeds. For every further
    level i, the seeds and control bits both parties reach are obtained by
    evaluating the keys built so far (Eval with n=i) on the input that
    find_input derives from the first i nodes.

    Returns: list [k_0, k_1]; each key is the party's initial seed followed by
    the correction words of all levels, flattened with concat.
    """
    s, t, cw, CW = List(n+1), List(n+1), List(n+1), List(n+1)
    # One random λ-bit seed per party.
    s[0] = split(randbit(size=2 * λ), [λ, λ])
    # Fresh random seed stored in the truth table of the first node.
    next_s = randbit(size=λ)
    cw[0] = TruthTable(next_s, T[0])
    #CW[0] = (-1)*(cw[0] - G(s[0][0]) + G(s[0][1]))
    # Correction word: the target table xor-ed with both parties' expansions.
    CW[0] = xor(cw[0], G(s[0][0]), G(s[0][1]))
    k = [concat(s[0][b], CW[0]) for b in (0, 1)]
    for i in range(1, n):
        # Input made of the first i "next node" entries of T.
        x_i = find_input(T, i)
        s[i], t[i] = List(2), List(2)
        for b in (0, 1):
            # State each party reaches at depth i using the partial key.
            σ, _ = Eval(b, k[b], x_i, n=i)
            s[i][b], t[i][b] = split(σ, [λ, 1])
        next_s = randbit(size=λ)
        cw[i] = TruthTable(next_s, T[i])
        #CW[i] = (-1)**t[i][1]*(cw[i] - G(s[i][0]) + G(s[i][1]))
        CW[i] = xor(cw[i], G(s[i][0]), G(s[i][1]))
        # Rebuild both keys with the correction words known so far.
        k = [concat(s[0][b], *CW[:i+1]) for b in (0, 1)]
    return k
    
    

In [170]:
def Eval(b, k_b, x, n=3):
    """
    Evaluate party `b`'s key `k_b` on the `n`-bit input `x`.

    Returns (σ, FnOutput): σ is the last (seed, control bit) row selected, and
    FnOutput holds this party's output bit for each level plus a final entry
    taken from σ.
    """
    path_bits = bit_decomposition(x, nbits=n)
    # Key layout: λ seed bits, then one (2, 2, λ+1)-sized correction word per level.
    seed, *CW = split(k_b, [λ, *[(λ+1)*2*2]*n])
    control = b
    FnOutput = List(n+1)
    σ = None
    for level in range(n):
        bit = path_bits[level]
        expanded = xor(G(seed), control*CW[level])
        σ_leaf, σ_next = expanded.reshape(2, 2, λ+1)
        # Last entry of the selected leaf row is this level's output share.
        FnOutput[level] = σ_leaf[bit][-1]
        σ = σ_next[bit]
        seed, control = split(σ, [λ, 1])

    # Last tour, the other σ is also a leaf:
    FnOutput[n] = σ[-1]
    return σ, FnOutput

In [188]:
# [Node_1, Node_n]
# Node_i = ((ConditionValueForLeaf, LeafValue), ConditionValueForNextNode)
# Node_i = ((ConditionValueForLeaf, LeafValue), (ConditionValueForLeaf, LeafValue))

def test_DIF_simplified():
    """For each sample tree, check that the combined outputs equal int(x <= alpha) for all 3-bit x."""
    trees = {
        2: [((1, 0), 0), ((0, 1), 1), ((1, 0), (0, 1))],
        5: [((0, 1), 1), ((1, 0), 0), ((1, 1), (0, 1))],
        6: [((0, 1), 1), ((0, 1), 1), ((1, 0), (0, 1))],
    }
    for alpha, T in trees.items():
        k = Gen(T)
        print('Test x <= α', alpha)
        for x in range(0, 2**3):
            _, outputs_0 = Eval(0, k[0], x, n=3)
            _, outputs_1 = Eval(1, k[1], x, n=3)
            # Xor the per-level shares of both parties, then add them up.
            sigma = sum(xor(outputs_0, outputs_1))
            print(f'x={x}', sigma)
            assert int(x<=alpha) == sigma

test_DIF_simplified()

Test x <= α 2
x=0 1
x=1 1
x=2 1
x=3 0
x=4 0
x=5 0
x=6 0
x=7 0
Test x <= α 5
x=0 1
x=1 1
x=2 1
x=3 1
x=4 1
x=5 1
x=6 0
x=7 0
Test x <= α 6
x=0 1
x=1 1
x=2 1
x=3 1
x=4 1
x=5 1
x=6 1
x=7 0


# DIF simplified v2
This section again addresses the comparison problem: deciding whether $x \le \alpha$.

In [190]:
n = 3  # number of levels used by this section's Gen; the test below evaluates 3-bit inputs

In [224]:
def G(seed):
    """
    Pseudo-random generator: expand a λ-bit seed into 2*2*(λ + 1) bits.

    The output is a deterministic function of `seed`. A private RandomState
    is seeded instead of calling np.random.seed, so that NumPy's global
    generator (used by randbit when drawing fresh seeds) is NOT reset as a
    side effect of every expansion. The produced bits are the same as with
    the global reseed.

    seed: array-like of length λ, used as the generator seed.
    Returns: 1-D integer array of 2*2*(λ + 1) values in {0, 1}.
    Raises: AssertionError if the seed does not have length λ.
    """
    assert len(seed) == λ
    prg = np.random.RandomState(seed)
    return prg.randint(2, size=2*2*(λ + 1))

def TruthTable(s, node):
    """
    Build the flattened pair of (2, λ+1) tables for one tree node.

    node is ((leaf_row, leaf_value), successor). `successor` is either a
    (row, value) tuple (terminal node) or a bare row index, in which case the
    stored value is 1. In each table the selected row holds `s` followed by
    the value; the other row stays zero.
    """
    (leaf_row, leaf_value), successor = node
    if isinstance(successor, tuple): # Terminal Node
        next_row, next_value = successor
    else:
        next_row, next_value = successor, 1

    tables = []
    for row, value in ((leaf_row, leaf_value), (next_row, next_value)):
        # Zero table with a single populated row: seed bits then the value.
        table = np.zeros((2, λ+1), dtype="int64")
        table[row, :] = concat(s, [value])
        tables.append(table)
    return np.array(tables).flatten()
        
def Gen(T):
    """
    Build the two keys for the tree description `T` (a list of nodes, see TruthTable).

    Unlike the variant that re-runs Eval for every level, the per-party seeds
    and control bits are propagated directly inside the loop.

    Returns: list [k_0, k_1]; each key is the party's initial seed followed by
    the n correction words, flattened with concat.
    """
    s, t, cw, CW = List(n+1), List(n+1), List(n+1), List(n)
    # One random λ-bit seed per party; control bits start as 0 and 1.
    s[0] = split(randbit(size=2 * λ), [λ, λ])
    t[0] = (0, 1)
    
    for i in range(0, n):
        # Row index of the "next node" table followed at this level; for the
        # last node the second entry is a (row, value) tuple, so take its row.
        𝛼_i = T[i][1] if i != (n-1) else T[i][1][0]
        next_s = randbit(size=λ)
        cw[i] = TruthTable(next_s, T[i])
        # Correction word: the target table xor-ed with both parties' expansions.
        CW[i] = xor(cw[i], G(s[i][0]), G(s[i][1]))
        s[i+1], t[i+1] = List(2), List(2)
        for b in (0, 1):
            # Same expansion-and-correction step as Eval performs for party b.
            τ = xor(G(s[i][b]), t[i][b] * CW[i])
            τ = τ.reshape(2, 2, λ+1)
            σ_leaf, σ_node = τ
            # Keep the row selected by 𝛼_i as the next seed and control bit.
            s[i+1][b], t[i+1][b] = split(σ_node[𝛼_i], [λ, 1])
    
    k = [concat(s[0][b], *CW) for b in (0, 1)]
    return k
    
    

In [225]:
def Eval(b, k_b, x, n=3):
    """
    Evaluate party `b`'s key `k_b` on the `n`-bit input `x`.

    Returns (σ, FnOutput): σ is the last (seed, control bit) row selected, and
    FnOutput holds this party's output bit for each level plus a final entry
    taken from σ.
    """
    FnOutput = List(n+1)
    input_bits = bit_decomposition(x, nbits=n)
    # Key layout: λ seed bits, then one (2, 2, λ+1)-sized correction word per level.
    current_seed, *CW = split(k_b, [λ, *[(λ+1)*2*2]*n])
    current_flag = b
    σ = None
    for depth in range(n):
        chosen = input_bits[depth]
        τ = xor(G(current_seed), current_flag*CW[depth]).reshape(2, 2, λ+1)
        σ_leaf, σ_i = τ
        # Last entry of the selected leaf row is this level's output share.
        FnOutput[depth] = σ_leaf[chosen][-1]
        σ = σ_i[chosen]
        current_seed, current_flag = split(σ, [λ, 1])

    # Last tour, the other σ is also a leaf:
    FnOutput[n] = σ[-1]
    return σ, FnOutput

In [223]:
# [Node_1, Node_n]
# Node_i = ((ConditionValueForLeaf, LeafValue), ConditionValueForNextNode)
# Node_i = ((ConditionValueForLeaf, LeafValue), (ConditionValueForLeaf, LeafValue))

def test_DIF_simplified():
    """For each sample tree, check that the combined outputs equal int(x <= alpha) for all 3-bit x."""
    trees = {
        2: [((1, 0), 0), ((0, 1), 1), ((1, 0), (0, 1))],
        5: [((0, 1), 1), ((1, 0), 0), ((1, 1), (0, 1))],
        6: [((0, 1), 1), ((0, 1), 1), ((1, 0), (0, 1))],
    }
    for alpha, T in trees.items():
        k = Gen(T)
        print('Test x <= α', alpha)
        for x in range(0, 2**3):
            share_0 = Eval(0, k[0], x, n=3)[1]
            share_1 = Eval(1, k[1], x, n=3)[1]
            # Xor the per-level shares of both parties, then add them up.
            sigma = sum(xor(share_0, share_1))
            print(f'x={x}', sigma)
            assert int(x<=alpha) == sigma

test_DIF_simplified()

bit 0
bit 1
bit 0
Test x <= α 2
x=0 1
x=1 1
x=2 1
x=3 0
x=4 0
x=5 0
x=6 0
x=7 0
bit 1
bit 0
bit 0
Test x <= α 5
x=0 1
x=1 1
x=2 1
x=3 1
x=4 1
x=5 1
x=6 0
x=7 0
bit 1
bit 1
bit 0
Test x <= α 6
x=0 1
x=1 1
x=2 1
x=3 1
x=4 1
x=5 1
x=6 1
x=7 0


# DIF
Hypothesis: $l_{max} = 2$, and for every node $u$: $\deg(u) = 2$ and $l_u = 2$.

In [8]:
l_max = 2  # enters the output length of G in this section as l_max**2

In [11]:
# PRG
def G(seed):
    """
    Pseudo-random generator: expand a λ-bit seed into (λ + 1)*l_max**2 bits.

    The output is a deterministic function of `seed`. A private RandomState
    is seeded instead of calling np.random.seed, so that NumPy's global
    generator (used by randbit and by the random node ids) is NOT reset as a
    side effect of every expansion. The produced bits are the same as with
    the global reseed.

    seed: array-like of length λ, used as the generator seed.
    Returns: 1-D integer array of (λ + 1)*l_max**2 values in {0, 1}.
    Raises: AssertionError if the seed does not have length λ.
    """
    assert len(seed) == λ
    prg = np.random.RandomState(seed)
    return prg.randint(2, size=(λ + 1)*l_max**2)

def Convert(bits):
    """Read a 1-D bit array, most significant bit first, as an integer."""
    exponents = np.arange(bits.size)[::-1]
    return bits.dot(1 << exponents)

In [None]:
def Gen(tree):
    # NOTE(review): unfinished draft. The bare `CW` expression and the
    # body-less `for` loop below mean this cell does not parse as written.
    # TODO: complete the tree-based key generation or remove the cell.
    CW
    for v in tree.find_nodes_with_all_siblings_leaf():

In [41]:
# Module-level registry: parent node -> {child node: decision value}.
# Filled by connect() and read by Node.get_input().
Decision = {}

class Tree:
    """Decision tree container: a root, a flat list of vertices and a variable index per internal node."""

    def __init__(self, root=None):
        self.root = root
        self.vertices = [] # List(Node)
        self.edges = [] #List(Edge)
        self.var = {} #dict(Node -> int)

    def add_node(self, node):
        """Register `node`; an Internal node also records its `index` in `var`."""
        self.vertices.append(node)
        if not isinstance(node, Internal):
            return
        self.var[node] = node.index

    def find_nodes_with_all_siblings_leaf(self):
        """Return the non-root vertices for which every node in `siblings()` is a leaf."""
        return [
            vertex
            for vertex in self.vertices
            if self.root != vertex
            and all(sibling.is_leaf for sibling in vertex.siblings())
        ]

    def __str__(self):
        return 'TREE\n' + str(self.root)

class Node:
    """Common base of tree vertices; subclasses provide `parent` (and `child` for internal nodes)."""

    def deg(self):
        raise NotImplementedError

    def siblings(self):
        """Return all children of this node's parent (this node included) as a list."""
        return [*self.parent.child.values()]

    def get_input(self):
        """Return the decisions taken on the path from the root down to this node, root first."""
        decisions = []
        current = self
        # Climb to the root, collecting the decision on each edge, then flip the order.
        while current.parent is not None:
            decisions.append(Decision[current.parent][current])
            current = current.parent
        decisions.reverse()
        return decisions
        
        
class Internal(Node):
    """Internal vertex: tests the input variable `index` and maps each decision value to a child."""
    is_leaf = False

    def __init__(self, index):
        super().__init__()
        self.id = np.random.randint(100, 999)
        self.index = index
        self.parent = None
        self.child = {}

    def deg(self):
        """Number of children currently attached."""
        return len(self.child)

    def __str__(self):
        if self.deg() != 2:
            raise ValueError
        pieces = [f"Internal {self.index}"]
        # Leaf children are rendered first, on the same line as the node.
        for val, node in self.child.items():
            if isinstance(node, Leaf):
                pieces += [f" [{val}] --> ", str(node), "\n\t"]
        # Internal children follow, each rendered recursively below an arrow.
        for val, node in self.child.items():
            if isinstance(node, Internal):
                pieces += [f"[{val}]\n\t |\n", str(node)]
        return "".join(pieces)
                
class Leaf(Node):
    """Terminal vertex carrying an output `label`."""
    is_leaf = True

    def __init__(self, label):
        # Random three-digit identifier, only used when printing the leaf.
        self.id = np.random.randint(100, 999)
        self.parent = None
        self.label = label

    def deg(self):
        """A leaf has no children."""
        return 0

    def __str__(self):
        return f"Leaf{self.id} ({self.label})"

    # Lists of leaves print with the same text as a single leaf.
    __repr__ = __str__
    
def connect(node1, node2, decision):
    """
    Attach `node2` as the child of `node1` reached with `decision`,
    and record that decision in the global `Decision` registry.
    """
    node1.child[decision] = node2
    node2.parent = node1
    Decision.setdefault(node1, {})[node2] = decision
    

Build a tree (here: OR)

In [42]:
# Build a chain of n internal nodes: node i has a leaf labelled 1 on
# decision 1 and continues to node i+1 on decision 0.
tree = Tree()
last_node = None
for i in range(n):
    node = Internal(i)
    leaf = Leaf(1)
    tree.add_node(node)
    tree.add_node(leaf)
    connect(node, leaf, 1)
    if i == 0:
        # The first internal node is the root of the tree.
        tree.root = node
    else:
        # Hang this node under the previous one on decision 0.
        connect(last_node, node, 0)
            
    last_node = node
    
# The last internal node gets a leaf labelled 0 on decision 0.
leaf = Leaf(0)
tree.add_node(leaf)
connect(last_node, leaf, 0)

In [43]:
print(tree)

TREE
Internal 0 [1] --> Leaf337 (1)
	[0]
	 |
Internal 1 [1] --> Leaf391 (1)
	[0]
	 |
Internal 2 [1] --> Leaf154 (1)
	[0]
	 |
Internal 3 [1] --> Leaf745 (1)
	[0]
	 |
Internal 4 [1] --> Leaf984 (1)
	[0]
	 |
Internal 5 [1] --> Leaf226 (1)
	[0]
	 |
Internal 6 [1] --> Leaf636 (1)
	[0]
	 |
Internal 7 [1] --> Leaf751 (1)
	[0]
	 |
Internal 8 [1] --> Leaf275 (1)
	[0]
	 |
Internal 9 [1] --> Leaf799 (1)
	[0]
	 |
Internal 10 [1] --> Leaf294 (1)
	[0]
	 |
Internal 11 [1] --> Leaf784 (1)
	[0]
	 |
Internal 12 [1] --> Leaf917 (1)
	[0]
	 |
Internal 13 [1] --> Leaf345 (1)
	[0]
	 |
Internal 14 [1] --> Leaf778 (1)
	[0]
	 |
Internal 15 [1] --> Leaf529 (1)
	[0]
	 |
Internal 16 [1] --> Leaf978 (1)
	[0]
	 |
Internal 17 [1] --> Leaf327 (1)
	[0]
	 |
Internal 18 [1] --> Leaf593 (1)
	[0]
	 |
Internal 19 [1] --> Leaf291 (1)
	[0]
	 |
Internal 20 [1] --> Leaf555 (1)
	[0]
	 |
Internal 21 [1] --> Leaf400 (1)
	[0]
	 |
Internal 22 [1] --> Leaf576 (1)
	[0]
	 |
Internal 23 [1] --> Leaf806 (1)
	[0]
	 |
Internal 24 [1] --> L

In [45]:
tree.find_nodes_with_all_siblings_leaf()

[Leaf637 (1), Leaf155 (0)]

Example evaluation in the clear (directly on a plaintext input)

In [266]:
x = randbit(n)

# Walk down from the root: at each internal node, follow the child selected
# by the input bit of that node's variable, until a leaf is reached.
u = tree.root
while not u.is_leaf:
    i = tree.var[u]
    u = u.child[x[i]]
print(u.label)

1


In [145]:
# NOTE(review): `k` is not assigned at top level in any cell visible here
# (only inside functions), so this cell relies on leftover kernel state.
# TODO confirm where `k` comes from, or generate it in this cell.
x = 5
y0 = Eval(0, k[0], x)
y1 = Eval(1, k[1], x)
print(y0, y1, y0+y1)

[155] [-154] [1]


In [91]:
# λ seed bits plus n correction words of λ+2 bits each, i.e. the DPF key
# layout without its final one-entry correction word.
n*(λ+2) + λ

328

In [22]:
# Scratch check of inclusive 1..N loop bounds written as range(1, N + 1).
list(range(1, 10 + 1))

[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

In [26]:
# Scratch check of the bit-weighting expression used by Convert() on a sample bit vector.
b = np.array([0, 0, 0, 0, 0, 1, 1, 0, 1])
b.dot(1 << np.arange(b.size)[::-1])

13

In [39]:
# NOTE(review): split() as defined in this notebook turns the sizes [3, 2, 3]
# into the cut points [3, 5]; the recorded output shows four parts, so it
# appears to come from an earlier version of split(). TODO re-run this cell.
split([1, 1, 1, 0, 0, 1, 1], [3, 2, 3])

[array([1, 1, 1]), array([], dtype=int64), array([1]), array([0, 0, 1, 1])]

In [17]:
# Scratch check: np.concatenate joins the two lists end to end into one array.
np.concatenate(([1, 1, 1, 0, 0, 1, 1, 0], [1, 1, 1, 0, 0, 1, 1, 0]))

array([1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0])