In [1]:
import syft as sy
import torch
import time

syft = sy 

hook = sy.TorchHook(torch)
bob = sy.VirtualWorker(hook, id="bob")
alice = sy.VirtualWorker(hook, id="alice")
charlie = sy.VirtualWorker(hook, id="charlie")
james = sy.VirtualWorker(hook, id="james")
crypto_provider = james

bit_decompose = syft.frameworks.torch.crypto.securenn.decompose

### Descriptive

0. Take an input

In [2]:
x_sh = torch.tensor([2.])\
    .fix_prec(precision_fractional=0)\
    .share(alice, bob, crypto_provider=james)

_Shares values are_

In [3]:
shares_values = [p.copy().get() for p in x_sh.child.child.child.values()]
shares_sum = shares_values[0] + shares_values[1]
print("Shares\t\t", *shares_values)
print("Real sum\t", shares_sum)
print("Sum modulo\t", shares_sum % 2**62)

Shares		 tensor([3682865790395216025]) tensor([928820228032171881])
Real sum	 tensor([4611686018427387906])
Sum modulo	 tensor([2])


_What should happen? Look at the last bit for each_

In [4]:
print(*map(lambda x: bit_decompose(x)[0][-1], shares_values))

tensor(1) tensor(0)


In [5]:
print(bit_decompose(shares_sum)[0][-1])

tensor(0)


As you see we need carry bits and wan't just sum the MSB

1. Decompose bitwise the shares

In [6]:
share_pointers = list(x_sh.child.child.child.values())
x_mpt = sy.MultiPointerTensor(children=share_pointers)

In [7]:
x_mpt

MultiPointerTensor>{'alice': [PointerTensor | me:70461010071 -> alice:11891612859], 'bob': [PointerTensor | me:63010515958 -> bob:10586763251]}

In [8]:
x_shares_bin = bit_decompose(x_mpt)

In [9]:
x_shares_bin.child['alice'].copy().get()

tensor([[1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1,
         1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0,
         0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1]])

2. Secret share these tensors

In [10]:
x_shares_bin

MultiPointerTensor>{'alice': [PointerTensor | me:71111319528 -> alice:58865810870], 'bob': [PointerTensor | me:63859771289 -> bob:59567078335]}

In [11]:
x_sh_1_bin, x_sh_2_bin = [
        share.share(alice, bob, crypto_provider=james, field=2).get().child
        for w, share in x_shares_bin.child.items()
    ]

In [12]:
x_sh_1_bin

[AdditiveSharingTensor]
	-> [PointerTensor | me:39695145315 -> alice:43764684627]
	-> [PointerTensor | me:68332488837 -> bob:22801229688]
	*crypto provider: james*

3. Compute carry bit

In [13]:
c_sh_bin = x_sh_1_bin * x_sh_2_bin

In [14]:
c_sh_bin.virtual_get()

tensor([[1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

4. Shift carry bit

In [15]:
c_sh_bin_shifted = torch.roll(c_sh_bin, shifts=1)
c_sh_bin_shifted[:, 0] = 0

In [16]:
c_sh_bin_shifted.virtual_get()

tensor([[0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

5. Sum binary shares

In [17]:
s_sh_bin = x_sh_1_bin + x_sh_2_bin

In [18]:
s_sh_bin.virtual_get()

tensor([[0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])

6. Go back to 3.

In [19]:
x_sh_1_bin, x_sh_2_bin = s_sh_bin, c_sh_bin_shifted

### Condensed

In [20]:
x_sh = torch.tensor([-9841.])\
    .fix_prec(precision_fractional=0)\
    .share(alice, bob, crypto_provider=james)

share_pointers = list(x_sh.child.child.child.values())
x_mpt = sy.MultiPointerTensor(children=share_pointers)

x_shares_bin = bit_decompose(x_mpt)

x_sh_1_bin, x_sh_2_bin = [
    share.share(alice, bob, crypto_provider=james, field=2).get().child
    for w, share in x_shares_bin.child.items()
]

r = x_sh_1_bin[0][-1] + x_sh_2_bin[0][-1]
print('Provisory result', r.virtual_get())

for i in range(62):
    
    c_sh_bin = x_sh_1_bin * x_sh_2_bin
    
    c_sh_bin_shifted = torch.roll(c_sh_bin, shifts=1)
    c_sh_bin_shifted[:, 0] = 0
    
    s_sh_bin = x_sh_1_bin + x_sh_2_bin
    
    print('Step', i)
    #print(s_sh_bin.virtual_get())
    #print(c_sh_bin_shifted.virtual_get())
    
    x_sh_1_bin, x_sh_2_bin = s_sh_bin, c_sh_bin_shifted
    
    r = r + c_sh_bin_shifted[0][-1]
    print('Provisory result', r.virtual_get())

Provisory result tensor(1)
Step 0
Provisory result tensor(1)
Step 1
Provisory result tensor(1)
Step 2
Provisory result tensor(1)
Step 3
Provisory result tensor(1)
Step 4
Provisory result tensor(1)
Step 5
Provisory result tensor(1)
Step 6
Provisory result tensor(1)
Step 7
Provisory result tensor(1)
Step 8
Provisory result tensor(1)
Step 9
Provisory result tensor(1)
Step 10
Provisory result tensor(1)
Step 11
Provisory result tensor(1)
Step 12
Provisory result tensor(1)
Step 13
Provisory result tensor(1)
Step 14
Provisory result tensor(1)
Step 15
Provisory result tensor(1)
Step 16
Provisory result tensor(1)
Step 17
Provisory result tensor(1)
Step 18
Provisory result tensor(1)
Step 19
Provisory result tensor(1)
Step 20
Provisory result tensor(1)
Step 21
Provisory result tensor(1)
Step 22
Provisory result tensor(1)
Step 23
Provisory result tensor(1)
Step 24
Provisory result tensor(1)
Step 25
Provisory result tensor(1)
Step 26
Provisory result tensor(1)
Step 27
Provisory result tensor(1)
Ste

### Benchmarked

In [21]:
Q_BITS = 32
shape = (5, )

In [22]:
def bit_decompose(tensor):
    """decompose a tensor into its binary representation."""
    n_bits = Q_BITS
    powers = torch.arange(n_bits)
    if hasattr(tensor, "child") and isinstance(tensor.child, dict):
        powers = powers.send(*list(tensor.child.keys()), **{"no_wrap": True})
    for i in range(len(tensor.shape)):
        powers = powers.unsqueeze(0)
    tensor = tensor.unsqueeze(-1)
    moduli = 2 ** powers
    tensor = torch.fmod((tensor / moduli.type_as(tensor)), 2)
    return tensor


In [30]:
def private_compare(x_sh):
    share_pointers = list(x_sh.child.child.child.values())
    x_mpt = sy.MultiPointerTensor(children=share_pointers)

    x_shares_bin = bit_decompose(x_mpt)

    x_sh_1_bin, x_sh_2_bin = [
        share.share(alice, bob, crypto_provider=james, field=2).get().child
        for w, share in x_shares_bin.child.items()
    ]

    r = x_sh_1_bin[:,-1] + x_sh_2_bin[:,-1]

    for i in range(Q_BITS):

        c_sh_bin = x_sh_1_bin * x_sh_2_bin

        c_sh_bin_shifted = c_sh_bin.roll(shifts=1)
        c_sh_bin_shifted[:, 0] = 0

        s_sh_bin = x_sh_1_bin + x_sh_2_bin

        x_sh_1_bin, x_sh_2_bin = s_sh_bin, c_sh_bin_shifted

        r = r + c_sh_bin_shifted[:,-1]
        
    return r

In [31]:
import cProfile
n_bit = 10
def foo():
    t = torch.zeros(*shape).uniform_(-2**(n_bit-1), 2**(n_bit-1))
    expected = (t<0).long()
    print(t)
    x_sh = t\
        .fix_prec(precision_fractional=0, field=2**Q_BITS)\
        .share(alice, bob, crypto_provider=james, field=2**Q_BITS)

    r = private_compare(x_sh)

cProfile.run('foo()')

tensor([  25.3318,  374.4645, -507.4072, -410.9270, -164.9139])
         463735 function calls (428012 primitive calls) in 0.395 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.003    0.003 <ipython-input-22-70c5daf652cb>:1(bit_decompose)
        1    0.000    0.000    0.392    0.392 <ipython-input-30-a1f18dfa6eac>:1(private_compare)
        1    0.000    0.000    0.002    0.002 <ipython-input-30-a1f18dfa6eac>:8(<listcomp>)
        1    0.000    0.000    0.395    0.395 <ipython-input-31-65ea90bdbf4d>:3(foo)
        1    0.000    0.000    0.000    0.000 <string>:1(<module>)
     9478    0.002    0.000    0.002    0.000 __init__.py:123(is_storage)
     8108    0.011    0.000    0.018    0.000 __init__.py:40(packb)
        1    0.000    0.000    0.000    0.000 _tensor_str.py:132(width)
        5    0.000    0.000    0.000    0.000 _tensor_str.py:135(format)
        1    0.000    0.000    0.000 

In [32]:
n_bit = 10
for i in range(10):
    t = torch.zeros(*shape).uniform_(-2**(n_bit-1), 2**(n_bit-1))
    expected = (t<0).long()
    print(t)
    x_sh = t\
        .fix_prec(precision_fractional=0, field=2**Q_BITS)\
        .share(alice, bob, crypto_provider=james, field=2**Q_BITS)

    start_time = time.time()
    r = private_compare(x_sh)
    print(time.time()- start_time)
    
    assert((expected == r.virtual_get()).all())
    #print('Provisory result', r.virtual_get())

tensor([-173.3583,  505.4964,   74.8992,  490.5034,  -62.6939])
0.2968318462371826
tensor([ 65.2811,  29.9857, 116.0848, 355.8784,  49.4533])
0.2948291301727295
tensor([-101.7766, -177.4821,  439.6923, -492.1521,  472.7072])
0.31409573554992676
tensor([-215.1983,   55.0895,  261.0670, -483.7092, -388.7967])
0.30328798294067383
tensor([-447.0695,   26.9320, -121.6979,  188.3701,  295.1523])
0.2832529544830322
tensor([ 440.2696,  161.3704, -213.3281,   48.0425,  448.4467])
0.3066062927246094
tensor([ 167.0793,  101.8032, -215.4391,  337.0345,   30.5123])
0.3302731513977051
tensor([ 510.8940,   91.6204,  499.7314, -419.6425, -298.2939])
0.29282093048095703
tensor([ -17.3899, -408.5215, -101.3419,  128.4747, -469.0257])
0.33150506019592285
tensor([-108.1748, -359.3439, -424.9121,  -51.2743,  -73.5583])
0.3268759250640869


### Comparison

In [26]:
n_bit = 10
for i in range(10):
    t = torch.zeros(*shape).uniform_(-2**(n_bit-1), 2**(n_bit-1))
    expected = (t<0).long()
    print(t)
    x_sh = t\
        .fix_prec(precision_fractional=0, field=2**Q_BITS)\
        .share(alice, bob, crypto_provider=james, field=2**Q_BITS)

    start_time = time.time()
    r = (x_sh < 0)
    print(time.time()- start_time)
    
    assert((expected == r.get().float_prec().long()).all())

tensor([-469.4106,  231.9222, -102.0841,   94.2505,  238.5734])
0.09760904312133789
tensor([   5.4621, -371.8509,  121.0632, -387.8480,  -41.7303])
0.08513689041137695
tensor([-279.3990, -345.9605, -101.5350, -428.3494,  448.5637])
0.09500002861022949
tensor([ 369.0379, -129.9125,  295.7458,  228.8887, -424.2844])
0.09510016441345215
tensor([ 185.3099, -482.8797, -179.1334,  -74.4338,   90.0060])
0.09580326080322266
tensor([ 421.9994, -425.0513, -427.8267,   -8.6971,  -69.3450])
0.08884692192077637
tensor([-84.9066, 437.8497, 341.8250, 491.7605,  72.7029])
0.09327578544616699
tensor([-259.6836,  243.7709, -324.7463,  -41.8763, -388.9908])
0.08305621147155762
tensor([ 243.3921, -156.6104,  288.5427,  102.8445,  478.0731])
0.08210396766662598
tensor([-505.5271,   27.2717, -160.4247, -172.6028, -126.6665])
0.0785226821899414


# Optimized version

In [38]:
import math

In [47]:
torch.concat

AttributeError: module 'torch' has no attribute 'concat'

In [59]:
blocks = torch.stack(torch.split(torch.ones(12), 3))
carries = torch.zeros(4, 1)
print(carries.shape)
print(blocks.shape)
torch.cat((carries, blocks), dim=1)

torch.Size([4, 1])
torch.Size([4, 3])


tensor([[0., 1., 1., 1.],
        [0., 1., 1., 1.],
        [0., 1., 1., 1.],
        [0., 1., 1., 1.]])

In [36]:
def private_compare_optim(x_sh):
    share_pointers = list(x_sh.child.child.child.values())
    x_mpt = sy.MultiPointerTensor(children=share_pointers)

    x_shares_bin = bit_decompose(x_mpt)

    x_sh_1_bin, x_sh_2_bin = [
        share.share(alice, bob, crypto_provider=james, field=2).get().child
        for w, share in x_shares_bin.child.items()
    ]
    
    # l := nb of bits
    l = x_shares_bin.shape[-1]
    # k := size of blocks
    k = math.ceil(math.sqrt(l / 2))
    
    assert l == 32 and k == 4
    
    x_sh_1_bin_blocks = torch.split(x_sh_1_bin, k)
    x_sh_2_bin_blocks = torch.split(x_sh_2_bin, k)
    
    x_sh_1_bin_blocks = torch.stack(x_sh_1_bin_blocks)
    x_sh_2_bin_blocks = torch.stack(x_sh_2_bin_blocks)
    

    r = x_sh_1_bin[:,-1] + x_sh_2_bin[:,-1]

    for i in range(Q_BITS):

        c_sh_bin = x_sh_1_bin * x_sh_2_bin

        c_sh_bin_shifted = c_sh_bin.roll(shifts=1)
        c_sh_bin_shifted[:, 0] = 0

        s_sh_bin = x_sh_1_bin + x_sh_2_bin

        x_sh_1_bin, x_sh_2_bin = s_sh_bin, c_sh_bin_shifted

        r = r + c_sh_bin_shifted[:,-1]
        
    return r

In [37]:
t = torch.zeros(*shape).uniform_(-2**(n_bit-1), 2**(n_bit-1))
x_sh = t\
    .fix_prec(precision_fractional=0, field=2**Q_BITS)\
    .share(alice, bob, crypto_provider=james, field=2**Q_BITS)

r = private_compare_optim(x_sh)

32
