In [1]:
import syft as sy
from syft.spdz import spdz
from syft.mpc.securenn import decompose, private_compare, msb, relu_deriv
from syft.core.frameworks.torch.tensor import _GeneralizedPointerTensor, _SPDZTensor, _SNNTensor
from syft.core.frameworks.torch import utils as torch_utils

import unittest
import numpy as np
import torch
import importlib

hook = sy.TorchHook(verbose=True)

me = hook.local_worker
me.is_client_worker = False

bob = sy.VirtualWorker(id="bob", hook=hook, is_client_worker=False)
alice = sy.VirtualWorker(id="alice", hook=hook, is_client_worker=False)

me.add_workers([bob, alice])
bob.add_workers([me, alice])
alice.add_workers([me, bob])



In [2]:
x = (torch.FloatTensor([[0.1,0.2,0.4,0.3],[0.9,0,0,0.1]])).fix_precision().share(alice,bob)
# x.get().decode()

In [4]:
out = x.argmax()

In [6]:
out.get().decode()


 0  0  1  0
 1  0  0  0
[syft.core.frameworks.torch.tensor.FloatTensor of size 2x4]

In [245]:
def very_slow_argmax(self):
    # there are a TON of things about this that are stupidly slow
    # but unfortunately there are bugs elsewhere that I don't have
    # time to fix. TODO: optimize the crap out of this
    
    my_shape = list(self.get_shape())
    assert len(my_shape) == 2
    
    max_vals = self[:,0:1]
    for i in range(1,my_shape[1]):
        new_vals = self[:,i:i+1]
        gate = (max_vals > new_vals)
        left = (gate * max_vals)

        gate = (max_vals < new_vals)
        right = gate * new_vals
        max_vals = left + right 

    max_vals = max_vals.expand(my_shape)
    out = (max_vals >= self) * (max_vals <= self)
    return out

In [241]:
out.get().decode()


 0  0  1  0
 1  0  0  0
[syft.core.frameworks.torch.tensor.FloatTensor of size 2x4]

In [None]:
out = max_vals.expand(x.get_shape()) == x

In [151]:
x.get().decode()


 0.1000  0.2000  0.4000  0.3000
 0.9000  0.0000  0.0000  0.1000
[syft.core.frameworks.torch.tensor.FloatTensor of size 2x4]

In [152]:
max_vals.expand(x.get_shape()).get().decode()

TypeError: expand received an invalid combination of arguments - got (dict), but expected one of:
 * (int ... size)
      didn't match because some of the arguments have invalid types: ([31;1mdict[0m)
 * (torch.Size size)
      didn't match because some of the arguments have invalid types: ([31;1mdict[0m)


In [51]:
type(gate)

syft.core.frameworks.torch.tensor.FloatTensor

In [52]:
type(max_vals)

syft.core.frameworks.torch.tensor.LongTensor

In [36]:
x == max_vals.expand(x.get_shape())

False

In [33]:
type(x)

syft.core.frameworks.torch.tensor.FloatTensor

In [34]:
type(max_vals)

syft.core.frameworks.torch.tensor.LongTensor

In [12]:
max_vals.decode()


 0.2000
 0.9000
[syft.core.frameworks.torch.tensor.FloatTensor of size 2x1]

In [89]:
max_vals.decode()


 0.0000
 0.9000
[syft.core.frameworks.torch.tensor.FloatTensor of size 2x1]

In [15]:
x.get()

[Fixed precision]

 0.1000  0.2000  0.4000  0.3000
 0.9000  0.0000  0.0000  0.1000
[syft.core.frameworks.torch.tensor.FloatTensor of size 2x4]

In [11]:
z = x.view(2,2,2)

In [12]:
z.get().decode()


(0 ,.,.) = 
  0.1000  0.2000
  0.4000  0.3000

(1 ,.,.) = 
  0.9000  0.0000
  0.0000  0.1000
[syft.core.frameworks.torch.tensor.FloatTensor of size 2x2x2]

In [47]:
type(z.child.child.child)

syft.core.frameworks.torch.tensor._SNNTensor

In [40]:
z.get()

[Fixed precision]

(0 ,.,.) = 
  0.1000  0.2000
  0.4000  0.3000

(1 ,.,.) = 
  0.9000  0.0000
  0.0000  0.1000
[syft.core.frameworks.torch.tensor.FloatTensor of size 2x2x2]

In [7]:
z = x < y

In [8]:
a = z * y

In [9]:
a.get()

[Fixed precision]

 0  4  0  8
 0  0  0  0
[syft.core.frameworks.torch.tensor.FloatTensor of size 2x4]

In [3]:
(x > y).get()


 1  0  1  0
 1  0  0  0
[syft.core.frameworks.torch.tensor.LongTensor of size 2x4]

In [4]:
(x < y).get()


 0  1  0  1
 0  0  0  0
[syft.core.frameworks.torch.tensor.LongTensor of size 2x4]

In [5]:
(x >= y).get()


 1  0  1  0
 1  1  1  1
[syft.core.frameworks.torch.tensor.LongTensor of size 2x4]

In [6]:
(x <= y).get()


 0  1  0  1
 0  1  1  1
[syft.core.frameworks.torch.tensor.LongTensor of size 2x4]

In [7]:
(x == y).get()


 0  0  0  0
 0  1  1  1
[syft.core.frameworks.torch.tensor.LongTensor of size 2x4]

In [5]:
v = sy.FloatTensor([1.1])
v2 = sy.FloatTensor([1.2])
v.__gt__(v2)


 0
[syft.core.frameworks.torch.tensor.ByteTensor of size 1]

In [3]:
# def newgt(self, *args, **kwargs):
#     try:
#         out = self.child > args[0].child
#         return out
#     except:
#         return self.native___gt__(*args, **kwargs)

In [4]:
# # torch.LongTensor.native___gt__ = torch.LongTensor.__gt__
# torch.LongTensor.__gt__ = newgt

In [5]:
x > y

RuntimeError: Command "gt" is not a supported Torch operation.

In [5]:
x.child > y.child

executing
running relu
running relu


[Head of chain]
[syft.core.frameworks.torch.tensor.LongTensor with no dimension]

In [6]:
(x.child > y.child).get()

executing
running relu
running relu



 1  0  1  0
 1  0  0  0
[syft.core.frameworks.torch.tensor.LongTensor of size 2x4]

In [8]:
out = x.relu()

running relu
running relu


In [6]:
out.get()


 0  3  0  7
 0  0  1  2
[syft.core.frameworks.torch.tensor.LongTensor of size 2x4]