In [1]:
import syft as sy
import torch as th
from syft.lib.python.collections import OrderedDict
import collections

In [2]:
# Create a syft VirtualMachine called `alice` and obtain its root client.
# Every later cell passes `alice_client` to `.send(...)` to move objects
# to this machine.
alice = sy.VirtualMachine()
alice_client = alice.get_root_client()

[2021-04-15T14:42:38.295348+0800][CRITICAL][logger]][1001] Skipping torchvision.torchvision.transforms.functional.adjust_sharpness not supported in 0.8.1
[2021-04-15T14:42:38.296090+0800][CRITICAL][logger]][1001] Skipping torchvision.torchvision.transforms.functional.autocontrast not supported in 0.8.1
[2021-04-15T14:42:38.297014+0800][CRITICAL][logger]][1001] Skipping torchvision.torchvision.transforms.functional.equalize not supported in 0.8.1
[2021-04-15T14:42:38.297641+0800][CRITICAL][logger]][1001] Skipping torchvision.torchvision.transforms.functional.invert not supported in 0.8.1
[2021-04-15T14:42:38.298456+0800][CRITICAL][logger]][1001] Skipping torchvision.torchvision.transforms.functional.posterize not supported in 0.8.1
[2021-04-15T14:42:38.299100+0800][CRITICAL][logger]][1001] Skipping torchvision.torchvision.transforms.functional.solarize not supported in 0.8.1


In [3]:
# Linear: round-trip a single th.nn.Linear(4, 2) layer through alice's client.
fc = th.nn.Linear(4,2)

# send: `.send` hands the layer to `alice_client` and returns a pointer
# object (its repr is printed; the recorded output shows a LinearPointer).
fc_ptr = fc.send(alice_client)
print(f"----fc_ptr----\n{fc_ptr}\n")

# remote call: the pointer is called like the module itself, here with a
# locally created random tensor built from shape [1, 4]. The call returns
# another pointer; `.get()` on it retrieves the resulting tensor.
res_ptr = fc_ptr(th.rand([1,4]))
print(f"----res_ptr----\n{res_ptr}\n")
print(f"----res_ptr.get()----\n{res_ptr.get()}\n")

# remote update state dict: take the state dict of a second, freshly
# constructed Linear(4, 2), wrap it in syft's OrderedDict (imported at the
# top of the notebook), send that, and load it through the layer pointer.
# NOTE(review): presumably the syft OrderedDict wrapper is what provides
# `.send` for the state dict -- confirm.
sd2 = OrderedDict(th.nn.Linear(4,2).state_dict())
sd2_ptr = sd2.send(alice_client)
fc_ptr.load_state_dict(sd2_ptr)

# get: fetch the layer back through the pointer and print its state dict
# next to the local `sd2`, so the two can be compared by eye.
print(f"----fc_ptr.get().state_dict()----\n{fc_ptr.get().state_dict()}\n")
print(f"----sd2----\n{sd2}\n")

----fc_ptr----
<syft.proxy.torch.nn.LinearPointer object at 0x7f7e619d2a30>

----res_ptr----
<syft.proxy.torch.TensorPointer object at 0x7f7e6194e7f0>

----res_ptr.get()----
tensor([[0.3768, 0.2239]], requires_grad=True)

----fc_ptr.get().state_dict()----
OrderedDict([('weight', tensor([[-0.1867, -0.1446,  0.3949, -0.4263],
        [-0.3738, -0.0338,  0.4329, -0.0855]])), ('bias', tensor([-0.3970, -0.4864]))])

----sd2----
OrderedDict([('weight', tensor([[-0.1867, -0.1446,  0.3949, -0.4263],
        [-0.3738, -0.0338,  0.4329, -0.0855]])), ('bias', tensor([-0.3970, -0.4864]))])



  grad = getattr(obj, "grad", None)


In [4]:
# ReLU: round-trip a parameter-free module (constructed with inplace=True)
# through alice's client.
relu = th.nn.ReLU(inplace=True)

# send: `.send` returns a pointer object for the module
# (the recorded output shows a ReLUPointer).
relu_ptr = relu.send(alice_client)
print(f"----relu_ptr----\n{relu_ptr}\n")

# remote call: call the pointer with a locally created random tensor built
# from shape [1, 4]; the call returns a pointer, and `.get()` on it
# retrieves the resulting tensor.
res_ptr = relu_ptr(th.rand([1,4]))
print(f"----res_ptr----\n{res_ptr}\n")
print(f"----res_ptr.get()----\n{res_ptr.get()}\n")

# get: fetch the module back through the pointer and print it; the recorded
# output shows the constructor argument survived as ReLU(inplace=True).
print(f"----relu_ptr.get()----\n{relu_ptr.get()}\n")

----relu_ptr----
<syft.proxy.torch.nn.ReLUPointer object at 0x7f7e619d2250>

----res_ptr----
<syft.proxy.torch.TensorPointer object at 0x7f7e618a14f0>

----res_ptr.get()----
tensor([[0.1065, 0.2725, 0.9136, 0.3703]])

----relu_ptr.get()----
ReLU(inplace=True)



In [5]:
# Sequential: round-trip a two-layer th.nn.Sequential through alice's
# client. The children are registered under the names "fc1" and "fc2".
seq = th.nn.Sequential()
seq.add_module("fc1", th.nn.Linear(4,2))
seq.add_module("fc2", th.nn.Linear(2,1))

# send: `.send` returns a pointer object for the container
# (the recorded output shows a SequentialPointer).
seq_ptr = seq.send(alice_client)
print(f"----seq_ptr----\n{seq_ptr}\n")

# remote call: call the pointer with a locally created random tensor built
# from shape [1, 4]; the call returns a pointer, and `.get()` on it
# retrieves the resulting tensor.
res_ptr = seq_ptr(th.rand([1,4]))
print(f"----res_ptr----\n{res_ptr}\n")
print(f"----res_ptr.get()----\n{res_ptr.get()}\n")

# remote update state dict: build a second Sequential with the same child
# names, take its state dict, and load it through the pointer.
# Two different OrderedDict types are involved here:
#   - inner: the standard-library collections.OrderedDict, used only to name
#     the children of the throw-away Sequential;
#   - outer: syft's OrderedDict (imported at the top of the notebook), which
#     wraps the resulting state dict before it is sent.
sd2 = OrderedDict(
    th.nn.Sequential(
        collections.OrderedDict([
            ("fc1", th.nn.Linear(4,2)),
            ("fc2", th.nn.Linear(2,1))
        ])
    ).state_dict()
)
sd2_ptr = sd2.send(alice_client)
seq_ptr.load_state_dict(sd2_ptr)

# get: fetch the container back through the pointer and print its state
# dict next to the local `sd2`, so the two can be compared by eye.
print(f"----seq_ptr.get().state_dict()----\n{seq_ptr.get().state_dict()}\n")
print(f"----sd2----\n{sd2}\n")

----seq_ptr----
<syft.proxy.torch.nn.SequentialPointer object at 0x7f7f014f2370>

----res_ptr----
<syft.proxy.torch.TensorPointer object at 0x7f7e6194e7f0>

----res_ptr.get()----
tensor([[0.2109]], requires_grad=True)

----seq_ptr.get().state_dict()----
OrderedDict([('fc1.weight', tensor([[ 0.0035,  0.4069, -0.4324, -0.4882],
        [ 0.3855,  0.1846,  0.3832,  0.2659]])), ('fc1.bias', tensor([-0.4026,  0.4498])), ('fc2.weight', tensor([[-0.1285,  0.3856]])), ('fc2.bias', tensor([0.2325]))])

----sd2----
OrderedDict([('fc1.weight', tensor([[ 0.0035,  0.4069, -0.4324, -0.4882],
        [ 0.3855,  0.1846,  0.3832,  0.2659]])), ('fc1.bias', tensor([-0.4026,  0.4498])), ('fc2.weight', tensor([[-0.1285,  0.3856]])), ('fc2.bias', tensor([0.2325]))])



In [6]:
# user defined model
class M(th.nn.Module):
    """User-defined module: two stacked linear layers, 4 -> 2 -> 1.

    NOTE(review): `forward` deliberately keeps its unusual signature -- the
    instance parameter is called `model` instead of `self`, `x` has a tensor
    default, and the torch module is bound as the default of `th`.
    Presumably this is what the syft version in use expects from a
    user-defined module's forward; confirm before "normalising" it.
    """

    def __init__(self):
        """Create the two layers; `fc1` is registered before `fc2`."""
        super().__init__()
        self.fc1 = th.nn.Linear(4, 2)
        self.fc2 = th.nn.Linear(2, 1)

    def forward(model, x=th.rand(4), th=th):
        """Return `fc2(fc1(x))`.

        The default for `x` is a single random tensor created once, when the
        class body is executed -- not on every call. `th` is accepted but
        not used inside the body.
        """
        hidden = model.fc1(x)
        return model.fc2(hidden)
        
m = M()

# local call: evaluate the model on a local random tensor so the remote
# result further down can be compared against it.
# (Label fixed: the value printed is m(x), not "m(m)".)
x = th.rand(1,4)
print(f"----m(x)----\n{m(x)}\n")

# send: `.send` returns a pointer object for the user-defined module
# (the recorded output shows a generic ModulePointer).
m_ptr = m.send(alice_client)
print(f"----m_ptr----\n{m_ptr}\n")

# remote call: send the same input `x`, call the module pointer with the
# input pointer passed by keyword, and fetch the result with `.get()`.
# (Label fixed: the stray closing parenthesis is removed so the label
# matches the expression being printed.)
x_ptr = x.send(alice_client)
print(f"----m_ptr(x=x_ptr).get()----\n{m_ptr(x=x_ptr).get()}\n")

# remote update state dict: take the state dict of a second, freshly
# constructed M, wrap it in syft's OrderedDict (imported at the top of the
# notebook), send that, and load it through the module pointer.
sd2 = OrderedDict(M().state_dict())
sd2_ptr = sd2.send(alice_client)
m_ptr.load_state_dict(sd2_ptr)

# get: fetch the module back once, then print its state dict next to the
# local `sd2` for comparison, followed by the type of the returned object.
m_get = m_ptr.get()
print(f"----m_get.state_dict()----\n{m_get.state_dict()}\n")
print(f"----sd2----\n{sd2}\n")
print(f"----type(m_get)----\n{type(m_get)}")

[2021-04-15T14:42:38.459331+0800][CRITICAL][logger]][1001] <class 'syft.core.store.store_memory.MemoryStore'> __delitem__ error <UID: 09d2b078e1fa4788a2211e3d29f2bf6b>.


----m(m)----
tensor([[-0.1710]], grad_fn=<AddmmBackward>)

----m_ptr----
<syft.proxy.torch.nn.ModulePointer object at 0x7f7e61a5ce20>

----m_ptr(x=x_ptr)).get()----
tensor([[-0.1710]], requires_grad=True)

----m_get.state_dict()----
OrderedDict([('fc1.weight', tensor([[ 0.3132, -0.3550, -0.4005, -0.0776],
        [ 0.3409,  0.0481, -0.2663, -0.3465]])), ('fc1.bias', tensor([-0.3148,  0.3046])), ('fc2.weight', tensor([[-0.6934, -0.4875]])), ('fc2.bias', tensor([-0.2427]))])

----sd2----
OrderedDict([('fc1.weight', tensor([[ 0.3132, -0.3550, -0.4005, -0.0776],
        [ 0.3409,  0.0481, -0.2663, -0.3465]])), ('fc1.bias', tensor([-0.3148,  0.3046])), ('fc2.weight', tensor([[-0.6934, -0.4875]])), ('fc2.bias', tensor([-0.2427]))])

----type(m_get)----
<class '__main__.M'>
