In [1]:
import syft as sy
import torch
import torch.nn as nn
import time
import torch.autograd as autograd

syft = sy  # alias so later cells can use either name

hook = sy.TorchHook(torch)

# Virtual workers used throughout the notebook, created in the same order as
# before; james doubles as the crypto provider for the binary-field sharing.
bob, alice, charlie, james = (
    sy.VirtualWorker(hook, id=worker_id)
    for worker_id in ("bob", "alice", "charlie", "james")
)
crypto_provider = james

# Idea: write the exponent in binary and compute the exponential as a product of exponentials of its bits

Example with $e^5$, using the binary expansion $5 = 1\cdot 2^2 + 0\cdot 2^1 + 1\cdot 2^0$:

In [2]:
e = torch.tensor(1.).exp()  # Euler's number as a 0-d float tensor

In [3]:
# Direct computation of e^5, used as the reference value below
e**5

tensor(148.4131)

In [4]:
 # Same value with the exponent written in binary: 5 = 1*2^2 + 0*2^1 + 1*2^0
 e**(1*2**2 + 0*2**1 + 1*2**0)

tensor(148.4131)

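Option 1: raise $e^{b_i}$ to the power $2^i$ for each bit $b_i$ and multiply, i.e. $e^x = \prod_i \left(e^{b_i}\right)^{2^i}$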

In [5]:
# Option 1: each bit b_i contributes the factor (e^{b_i})^(2^i); bits of 5 are 1, 0, 1
(e**1)**(2**2) * (e**0)**(2**1) * (e**1)**(2**0)

tensor(148.4131)

Option 2: multiply in the factor $e^{2^i}$ only when bit $b_i = 1$ (otherwise the factor is 1)

In [6]:
# Option 2: the factor e^(2^i) is selected when bit b_i is 1, otherwise the factor is 1
(e**(2**2) if 1 else 1)*(e**(2**1) if 0 else 1)*(e**(2**0) if 1 else 1)

tensor(148.4131)

which can easily be written as a circuit as well.

# POC
We use Option 1

Here is a function that decomposes values into their binary representation (least-significant bit first):

In [7]:
Q_BITS = 3


def decompose(tensor, n_bits=None):
    """Decompose a tensor into its binary representation.

    A trailing dimension of size ``n_bits`` is appended; entry ``i`` along it
    is bit ``i`` of the value, least-significant bit first
    (e.g. ``5 -> [1, 0, 1]``, ``3 -> [1, 1, 0]``).

    Args:
        tensor: tensor of values to decompose (integer or float dtype).
            Intended for non-negative values: ``torch.fmod`` keeps the sign
            of the dividend, so negative inputs do not yield {0, 1} bits.
            If ``tensor.child`` is a dict (a shared tensor), the powers of
            two are sent to the same workers as the shares.
        n_bits: number of bits to extract; defaults to the module-level
            ``Q_BITS``.

    Returns:
        Tensor of shape ``tensor.shape + (n_bits,)`` with the dtype of
        ``tensor``.
    """
    if n_bits is None:
        n_bits = Q_BITS
    powers = torch.arange(n_bits)
    if hasattr(tensor, "child") and isinstance(tensor.child, dict):
        # Shared tensor: place the powers on the same workers as the shares.
        powers = powers.send(*list(tensor.child.keys()), no_wrap=True)
    # One leading singleton dim per input dim, so ``powers`` broadcasts
    # against the trailing axis added to ``tensor`` below.
    for _ in range(len(tensor.shape)):
        powers = powers.unsqueeze(0)
    tensor = tensor.unsqueeze(-1)
    moduli = 2 ** powers
    # Floor division acts as a right shift. Plain ``/`` is true division in
    # current PyTorch (5 / 2 == 2.5), which would produce fractional "bits".
    tensor = torch.fmod(tensor // moduli.type_as(tensor), 2)
    return tensor

In [8]:
decompose(tensor=torch.tensor([5], dtype=torch.long))

tensor([[1, 0, 1]])

This is how to do Option 1 in plain PyTorch:

In [9]:
# Option 1 by hand for one value: e^x = prod_i (e^{b_i})^{2^i}
x = torch.tensor([5])
x_bin = decompose(x)[0]  # bits of 5, least-significant first
print(x_bin)
moduli = 2 ** torch.arange(Q_BITS)  # weights 2^i for each bit position
print(moduli)
coeffs = x_bin.float().exp().pow(moduli.float())
print(coeffs.prod())

tensor([1, 0, 1])
tensor([1, 2, 4])
tensor(148.4131)


Convert a bit $b$ into Encrypted($\exp(b)$), an additive sharing of $\exp(b)$:

In [10]:
def exp_bit(x, verbose=True):
    """Return an additive sharing (fixed precision) of ``exp(x)`` for a bit tensor ``x``.

    ``x`` is shared between alice and bob in the binary field (``field=2``).
    Each party exponentiates its own share locally, and the product
    ``exp(x0) * exp(x1)`` is corrected by ``exp(-2)`` whenever both shares
    are 1, i.e. when ``x0 + x1 = 2`` wrapped around to 0 in the field.

    With the default fixed-point precision the result is approximate
    (``exp(1)`` decrypts to about 2.718).

    Args:
        x: tensor whose entries are bits (0 or 1).
        verbose: if True (the default), print the raw binary shares held by
            alice and bob. This is debug output only.

    Returns:
        The shared result. It is still encrypted; decrypt it with
        ``.get().float_prec()``.
    """
    # Share the bits in a *binary* field between alice and bob.
    x_sh = x.share(alice, bob, crypto_provider=crypto_provider, field=2)

    # Pointers to each party's share, cast to float so that exp() applies.
    shares = x_sh.child.child
    x0 = shares[alice.id].float()
    x1 = shares[bob.id].float()
    if verbose:
        # Debug view: peek at the raw shares stored on each worker.
        print(alice._objects[x0.id_at_location], bob._objects[x1.id_at_location])

    # wrap_field decrypts to 1 iff x0 + x1 >= 2, i.e. both shares are 1.
    x0_sh = x0.fix_precision().share(alice, bob, crypto_provider=charlie).get()
    x1_sh = x1.fix_precision().share(alice, bob, crypto_provider=charlie).get()
    wrap_field = x0_sh * x1_sh

    # Each party exponentiates its own share locally, then secret-shares it.
    exp_x0_sh = torch.exp(x0).fix_precision().share(alice, bob, crypto_provider=charlie).get()
    exp_x1_sh = torch.exp(x1).fix_precision().share(alice, bob, crypto_provider=charlie).get()

    # exp(x0 + x1) = exp(x0) * exp(x1); when the sum wrapped (x0 + x1 = 2)
    # multiply by exp(-2) so that the result becomes exp(0) = 1.
    one = torch.tensor([1.]).fix_precision()
    inv_exp_field_size = torch.exp(-torch.tensor([2.])).fix_precision()
    exp_sh = exp_x0_sh * exp_x1_sh * (wrap_field * (inv_exp_field_size - one) + one)

    return exp_sh


# Example: exp of the encrypted bit 1 should decrypt to approximately e
x = torch.tensor([1], dtype=torch.long)
exp_x = exp_bit(x)
print(
    exp_x.get().float_prec()
)

tensor([1.]) tensor([0.])
tensor([2.7180])


## Cleartext demo

In [11]:
# Cleartext reference for Option 1: e^x = prod_i (e^{b_i})^{2^i}, one column per bit.
x = torch.tensor([5, 3, 1])
x_bin = decompose(x)

exp_x_bin = x_bin.float().exp()
print(exp_x_bin)
exp_x_bin_list = torch.unbind(exp_x_bin, dim=1)  # one tensor per bit position
print(exp_x_bin_list)

moduli = 2 ** torch.arange(Q_BITS)
moduli_list = (power.item() for power in torch.unbind(moduli, dim=0))  # plain ints 2^i

coeffs_list = [
    exp_bit_column ** weight
    for exp_bit_column, weight in zip(exp_x_bin_list, moduli_list)
]

coeffs = torch.stack(coeffs_list, dim=1)
print(coeffs)

coeffs.prod(dim=1)

tensor([[2.7183, 1.0000, 2.7183],
        [2.7183, 2.7183, 1.0000],
        [2.7183, 1.0000, 1.0000]])
(tensor([2.7183, 2.7183, 2.7183]), tensor([1.0000, 2.7183, 1.0000]), tensor([2.7183, 1.0000, 1.0000]))
tensor([[ 2.7183,  1.0000, 54.5981],
        [ 2.7183,  7.3891,  1.0000],
        [ 2.7183,  1.0000,  1.0000]])


tensor([148.4131,  20.0855,   2.7183])

## Encrypted Demo

In [12]:
# Same pipeline as the cleartext demo, but each exp(bit) is secret-shared.
x = torch.tensor([5, 3, 1])
x_bin = decompose(x)
print('bin', x_bin)

exp_x_bin_sh = exp_bit(x_bin)
# Debug view of the fixed-point-encoded values behind the sharing.
print('exp sh', exp_x_bin_sh.child.child.virtual_get())
exp_x_bin_list = torch.unbind(exp_x_bin_sh, dim=1)  # one shared tensor per bit position

moduli = 2 ** torch.arange(Q_BITS)
# Integer exponents 2^i as fixed-precision tensors with no fractional part.
moduli_list = [power.fix_precision(precision_fractional=0) for power in moduli.unbind(dim=0)]

coeffs_list = [
    exp_bit_column ** weight
    for exp_bit_column, weight in zip(exp_x_bin_list, moduli_list)
]

# Multiply the per-bit factors together while they remain shared.
exp_x_sh, *remaining_coeffs = coeffs_list
for coeff in remaining_coeffs:
    exp_x_sh = exp_x_sh * coeff

print('exp_x', exp_x_sh.get().float_prec())

bin tensor([[1, 0, 1],
        [1, 1, 0],
        [1, 0, 0]])
tensor([[1., 1., 1.],
        [1., 0., 0.],
        [0., 1., 0.]]) tensor([[0., 1., 0.],
        [0., 1., 0.],
        [1., 1., 0.]])
exp sh tensor([[2718,  997, 2718],
        [2718, 2718, 1000],
        [2718,  997, 1000]])
exp_x tensor([147.4430,  20.0810,   2.7020])
