In [1]:
import syft as sy

In [2]:
from syft.core.store.storeable_object import StorableObject

In [3]:
from syft.core.common import UID

In [4]:
import torch

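# Test functions

Each helper below checks one round-trip: `serde` (via protobuf and via raw bytes), wrapping in a `StorableObject`, and `send`/`get` against a `VirtualMachine` root client; `test` runs all of them on a single object.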

In [5]:
from syft.core.common.serde.serialize import _serialize
from syft.core.common.serde.deserialize import _deserialize
def serde(data, via_proto=True):
    """Serialize ``data`` and immediately deserialize the result.

    Args:
        data: an object exposing ``serialize(to_proto=..., to_bytes=...)``.
        via_proto: if True, round-trip through the protobuf representation;
            otherwise round-trip through raw bytes.

    Returns:
        The object reconstructed by ``_deserialize``.
    """
    # Exactly one transport is used, and the same one in both directions,
    # so the blob format matches what _deserialize is told to expect.
    to_proto = bool(via_proto)
    to_bytes = not to_proto

    blob = data.serialize(to_proto=to_proto, to_bytes=to_bytes)
    return _deserialize(blob=blob, from_proto=to_proto, from_bytes=to_bytes)

In [6]:
import math
def assert_eq(obj_1, obj_2):
    """Assert that ``obj_1`` equals ``obj_2``.

    - ``torch.Tensor``: same shape and every element equal.
    - ``sy.lib.python.Float``: ``math.isclose`` with ``rel_tol=1e-6``.
    - anything else: plain ``==``.

    Raises:
        AssertionError: if the objects are not considered equal.
    """
    if isinstance(obj_1, torch.Tensor):
        # Check the shape explicitly: `==` broadcasts, so e.g. a (1,) tensor
        # would otherwise "equal" a (3,) tensor filled with the same value.
        assert obj_1.shape == obj_2.shape, "tensor shapes differ"
        # `.all()` rather than `.any()`: one matching element must not be
        # enough for two tensors to count as equal.
        assert bool((obj_1 == obj_2).all()), "tensor elements differ"
    elif isinstance(obj_1, sy.lib.python.Float):
        assert math.isclose(obj_1, obj_2, rel_tol=1e-6)
    else:
        assert obj_1 == obj_2

In [7]:
def test_serde(data, via_proto=True):
    """Round-trip ``data`` through ``serde`` and check it comes back equal."""
    assert_eq(serde(data, via_proto=via_proto), data)

In [8]:
def test_StorableObject_data(data, via_proto=True):
    """Wrap ``data`` in a ``StorableObject`` with a fresh ``UID``, round-trip
    the wrapper through ``serde`` and check that its payload survives."""
    wrapper = StorableObject(id=UID(), data=data)
    restored = serde(wrapper, via_proto=via_proto)
    assert_eq(restored.data, data)

In [9]:
remote = sy.VirtualMachine().get_root_client()
def test_send_get(data, remote=remote):
    """Send ``data`` to ``remote`` and fetch it back.

    The object is sent with a tag and a description. The test checks that
    the pointer and the most recent entry in ``remote.store`` agree on
    tags, description and ``id_at_location``, then compares the fetched
    value with ``data``. The ``remote`` default is bound once, when this
    function is defined.
    """
    expected_tags = ["tag#"]
    expected_description = "test description"
    pointer = data.send(
        remote,
        tags=expected_tags,
        description=expected_description,
        searchable=True,
    )
    assert remote.store[-1].tags == pointer.tags == expected_tags
    assert remote.store[-1].description == pointer.description == expected_description
    assert remote.store[-1].id_at_location == pointer.id_at_location
    fetched = pointer.get()
    assert_eq(fetched, data)

In [10]:
def test(data):
    """Run every check on ``data``.

    The serde and ``StorableObject`` round-trips each run via protobuf and
    then via bytes, followed by a final send/get round-trip.
    """
    for check in (test_serde, test_StorableObject_data):
        for use_proto in (True, False):
            check(data, via_proto=use_proto)
    test_send_get(data)

# Primitives

## namedtuple

In [11]:
from syft.lib.python.namedtuple import ValuesIndices

def test_torch_valuesindices_serde() -> None:
    """Round-trip the ``(values, indices)`` namedtuple from ``Tensor.mode()``.

    The result is serialized and then deserialized. The test checks that the
    deserialized ``values``/``indices`` match both the original result and a
    ``ValuesIndices`` built by hand from the same tensors.

    Raises:
        AssertionError: if any element-wise comparison fails.
    """
    x = torch.Tensor([[1, 2], [1, 2]])
    y = x.mode()
    values = y.values
    indices = y.indices

    ser = y.serialize()
    # Horrible hack: we shouldn't be constructing these by hand right now anyway.
    # ValuesIndices receives 17 positional arguments; only the first two
    # (values, indices) are filled in and the rest are None.
    # NOTE(review): 17 presumably matches ValuesIndices' constructor arity — confirm.
    params = [None] * 17
    params[0] = values
    params[1] = indices
    vi = ValuesIndices(*params)
    de = _deserialize(blob=ser)

    assert (de.values == y.values).all()
    assert (de.indices == y.indices).all()
    assert (vi.values == de.values).all()
    assert (vi.indices == de.indices).all()

In [12]:
# Run the ValuesIndices serde check defined in the previous cell.
test_torch_valuesindices_serde()

## Int

In [13]:
# Wrap a plain int in syft's Int and run every check from `test` on it.
data = sy.lib.python.Int(77)
test(data)

## String

In [14]:
# Wrap a plain str in syft's String and run every check from `test` on it.
data = sy.lib.python.String("Hello world")
test(data)

## Dict

In [15]:
# syft Dict built from a plain dict with str keys and int values.
data = sy.lib.python.Dict({
    "a": 1,
    "b": 2
})
test(data)

## Bool

In [16]:
# Wrap a plain bool in syft's Bool and run every check from `test` on it.
data = sy.lib.python.Bool(True)
test(data)

## Float

In [17]:
# syft Float; assert_eq compares Float values with math.isclose
# (rel_tol=1e-6) rather than exact ==.
data = sy.lib.python.Float(1.1)
test(data)

## List

In [18]:
# Wrap a plain list of ints in syft's List and run every check from `test` on it.
data = sy.lib.python.List([1,2,3])
test(data)

## OrderedDict

In [19]:
# syft OrderedDict built from a plain dict literal with str keys and int values.
data = sy.lib.python.collections.OrderedDict({
    "a":1,
    "b":2
})
test(data)

# Torch

In [20]:
# Optional: uncomment to run the torch unit tests with pytest from inside the notebook.
# The relative path assumes the kernel's working directory is this notebook's folder — TODO confirm.
# !pytest ../../../../tests/syft/lib/torch/

## torch.device

In [21]:
# torch.device matches neither special branch in assert_eq, so it is compared with ==.
data = torch.device('cpu')
test(data)

## torch.Tensor

In [22]:
# Random 2x3 tensor. No seed is set in this notebook, so the values change
# between runs; the round-trip checks only compare the result with itself.
data = torch.rand((2,3))
test(data)

## torch.nn.parameter.Parameter

In [23]:
# Parameter with requires_grad=True and a random .grad attached.
# NOTE(review): assert_eq compares only the tensor values, so whether .grad
# survives serde / send-get is not verified by this cell.
data = torch.nn.parameter.Parameter(torch.tensor([1.0, 2, 3]), requires_grad=True)
data.grad = torch.randn_like(data)
test(data)

# PSI

In [24]:
# Optional: uncomment to run the PSI unit tests with pytest from inside the notebook.
# The relative path assumes the kernel's working directory is this notebook's folder — TODO confirm.
# !pytest ../../../../tests/syft/lib/psi/psi_test.py