In [1]:
import syft as sy
from syft.spdz import spdz
from syft.mpc.securenn import decompose, select_shares, private_compare, generate_zero_shares_communication
from syft.core.frameworks.torch.tensor import _GeneralizedPointerTensor, _SPDZTensor

import unittest
import numpy as np
import torch
import importlib


In [2]:
hook = sy.TorchHook(verbose=True)

me = hook.local_worker
me.is_client_worker = False

bob = sy.VirtualWorker(id="bob", hook=hook, is_client_worker=False)
alice = sy.VirtualWorker(id="alice", hook=hook, is_client_worker=False)

me.add_workers([bob, alice])
bob.add_workers([me, alice])
alice.add_workers([me, bob])



In [3]:
def decompose(tensor):
    """
    decompose a tensor into its binary representation
    """
    powers = torch.arange(0,Q_BITS)
    for i in range(len(tensor.shape)):
        powers = powers.unsqueeze(0)
    tensor = tensor.unsqueeze(-1)
    moduli = 2 ** powers
    tensor = ((tensor+2**(Q_BITS)) / moduli.type_as(tensor)) % 2
    return tensor

def _pc_beta0(x, r):
    # note x and r are both binary tensors,
    # and dim -1 contains their bits
    # x should be shared, r should be public
    w = xor(x, r)
    z = r  - (x - 1)
    w_sum = torch.zeros(w.shape).type_as(w)
    for i in range(Q_BITS - 2, -1, -1):
        w_sum[..., i] = w[..., (i + 1):].sum(dim=-1, keepdim=True)
    c = z + w_sum
    return c


def _pc_beta1(x, t):
    w = xor(x, t)
    z = (x + 1) - t
    w_sum = torch.zeros(w.shape).type_as(w)
    for i in range(Q_BITS - 2, -1, -1):
        w_sum[..., i] = w[..., (i + 1):].sum(dim=-1, keepdim=True)
    c = z + w_sum
    return c


def generate_one_shares_communication(alice, bob, sizes):
    return torch.ones(sizes).long().share(alice, bob)

def xor(x, y):
    return x + y - 2 * x * y

In [89]:
x = torch.LongTensor([5])#.share(bob, alice)

r = torch.LongTensor([4])

Q_BITS = spdz.Q_BITS

beta = (torch.rand(1)>0).long()

t = (r + 1) % (2 ** Q_BITS)

x_bits = decompose(x)
r_bits = decompose(r)
t_bits = decompose(t)

zeros = (beta == 0).long()
ones = (beta == 1).long()
others = (r == (2 ** Q_BITS - 1)).long()
ones = ones & (others - 1).abs()

c_zeros = _pc_beta0(x_bits, r_bits)
c_ones = _pc_beta1(x_bits, t_bits)
c_other = _pc_else([alice, bob], x_bits.shape)

# TODO: recombine c properly here
c = torch.zeros(*x_bits.shape).long()
c[zeros] = c_zeros
c[ones] = c_ones
c[others] = c_other

RuntimeError: invalid argument 3: out of range at /Users/soumith/minicondabuild3/conda-bld/pytorch_1512381214802/work/torch/lib/TH/generic/THTensor.c:459

In [75]:
c_ones



Columns 0 to 12 
    1     1     1     1     1     1     1     1     1     1     1     1     1

Columns 13 to 25 
    1     1     1     1     1     1     1     1     1     1     1     1     1

Columns 26 to 30 
    1     1     1     1     1
[syft.core.frameworks.torch.tensor.LongTensor of size 1x31]

In [17]:
def _pc_else(workers, *sizes):
    u = generate_zero_shares_communication(*workers, *sizes)
    (w0, u0), (w1, u1) = u.child.shares.child.pointer_tensor_dict.items()
    u0 = u0.wrap()
    u1 = u1.wrap()
    
    c0 = u0 * 0
    c1 = u1 * 0
    for i in range(Q_BITS - 2, -1, -1):
        if i == 0:
            c0[..., i] = -u0[..., i]
            c1[..., i] = u1[..., i]

        c0[..., i] = u0[..., i] + 1
        c1[..., i] = -u1[..., i]
    ptr_dict = {w0:c0.child, w1:c1.child}
    c_gp = _GeneralizedPointerTensor(ptr_dict, torch_type='syft.LongTensor').wrap(True)
    c = _SPDZTensor(c_gp, torch_type='syft.LongTensor').wrap(True)
    return c

In [21]:
c_other

[Head of chain]
[syft.core.frameworks.torch.tensor.LongTensor with no dimension]

In [66]:
def private_compare(x, r, beta, workers):
    """
    computes beta XOR (x > r)

    x is private input
    r is public input for comparison
    beta is public random bit tensor

    all of type _GeneralizedPointerTensor
    """
    t = (r + 1) % (2 ** Q_BITS)

    x_bits = decompose(x)
    r_bits = decompose(r)
    t_bits = decompose(t)

    zeros = beta == 0
    ones = beta == 1
    others = r == (2 ** Q_BITS - 1)
    ones = ones & (others - 1).abs()

    c_zeros = _pc_beta0(x_bits, r_bits)
    c_ones = _pc_beta1(x_bits, t_bits)
    c_other = _pc_else()

    # TODO: recombine c properly here
    # torch.zeros()
    c = torch.cat([c_zeros, c_ones, c_other], -1)

    s = random_as(c, mod=p)
    permute = torch.randperm(c.size(-1))
    d = s * c[..., permute]
    d.get()
    return (d == 0).max()