In [1]:
import syft as sy
import torch as th
from syft.lib.python.collections import OrderedDict
import collections
from syft.lib.torch.module import ModelExecutor

In [2]:
# A syft VirtualMachine named by convention "alice" plays the remote party in the cells below.
# Every `.send(...)` call in this notebook targets `alice_client`, the client obtained from it.
alice = sy.VirtualMachine()
alice_client = alice.get_root_client()

[2021-03-26T13:24:40.864048+0800][CRITICAL][logger]][6370] Skipping torchvision.torchvision.transforms.functional.adjust_sharpness not supported in 0.8.1
[2021-03-26T13:24:40.864417+0800][CRITICAL][logger]][6370] Skipping torchvision.torchvision.transforms.functional.autocontrast not supported in 0.8.1
[2021-03-26T13:24:40.864679+0800][CRITICAL][logger]][6370] Skipping torchvision.torchvision.transforms.functional.equalize not supported in 0.8.1
[2021-03-26T13:24:40.864942+0800][CRITICAL][logger]][6370] Skipping torchvision.torchvision.transforms.functional.invert not supported in 0.8.1
[2021-03-26T13:24:40.865178+0800][CRITICAL][logger]][6370] Skipping torchvision.torchvision.transforms.functional.posterize not supported in 0.8.1
[2021-03-26T13:24:40.865437+0800][CRITICAL][logger]][6370] Skipping torchvision.torchvision.transforms.functional.solarize not supported in 0.8.1


In [3]:
# Linear: send a module, call it remotely, overwrite its weights remotely with a
# fresh state dict, then fetch it back and verify the new weights arrived.
fc = th.nn.Linear(4, 2)

# send
fc_ptr = fc.send(alice_client)
print(f"----fc_ptr----\n{fc_ptr}\n")

# remote call
res_ptr = fc_ptr(th.rand([1, 4]))
print(f"----res_ptr----\n{res_ptr}\n")
print(f"----res_ptr.get()----\n{res_ptr.get()}\n")

# remote update state dict
# Keep a plain torch copy of the target weights so the round trip can be checked below.
expected_sd = th.nn.Linear(4, 2).state_dict()
sd2 = OrderedDict(expected_sd)
sd2_ptr = sd2.send(alice_client)
fc_ptr.load_state_dict(sd2_ptr)

# get: fetch the module once and reuse its state dict for both the printout and the check
fc_local_sd = fc_ptr.get().state_dict()
print(f"----fc_ptr.get().state_dict()----\n{fc_local_sd}\n")
print(f"----sd2----\n{sd2}\n")

# The module returned from alice must carry exactly the weights that were sent.
assert list(fc_local_sd) == list(expected_sd), "state dict keys differ after remote load"
assert all(th.equal(fc_local_sd[k], expected_sd[k]) for k in expected_sd), \
    "remote load_state_dict did not apply the sent weights"

----fc_ptr----
<syft.proxy.torch.nn.LinearPointer object at 0x7f20c3d6c0a0>

----res_ptr----
<syft.proxy.torch.TensorPointer object at 0x7f21680f2b50>

----res_ptr.get()----
tensor([[-0.0977, -1.2203]], requires_grad=True)

----fc_ptr.get().state_dict()----
OrderedDict([('weight', tensor([[-0.0598, -0.3301,  0.2567,  0.4625],
        [-0.1558,  0.3708,  0.2388, -0.0500]])), ('bias', tensor([ 0.1444, -0.4406]))])

----sd2----
OrderedDict([('weight', tensor([[-0.0598, -0.3301,  0.2567,  0.4625],
        [-0.1558,  0.3708,  0.2388, -0.0500]])), ('bias', tensor([ 0.1444, -0.4406]))])



  grad = getattr(obj, "grad", None)


In [4]:
# ReLU: the cell only sends the module, runs it remotely, and brings it back;
# no state dict is pushed here.
relu = th.nn.ReLU(inplace=True)

# ship the module to alice and keep a pointer to the remote object
relu_ptr = relu.send(alice_client)
print(f"----relu_ptr----\n{relu_ptr}\n")

# forward pass executed on alice with a random input
relu_input = th.rand([1, 4])
res_ptr = relu_ptr(relu_input)
print(f"----res_ptr----\n{res_ptr}\n")
relu_result = res_ptr.get()
print(f"----res_ptr.get()----\n{relu_result}\n")

# retrieve the module itself and show its repr
relu_local = relu_ptr.get()
print(f"----relu_ptr.get()----\n{relu_local}\n")

----relu_ptr----
<syft.proxy.torch.nn.ReLUPointer object at 0x7f20c3e78610>

----res_ptr----
<syft.proxy.torch.TensorPointer object at 0x7f21680ed460>

----res_ptr.get()----
tensor([[0.1456, 0.7297, 0.4326, 0.7943]])

----relu_ptr.get()----
ReLU(inplace=True)



In [5]:
# Sequential: the same round trip as the Linear cell, for a two-layer container.
def make_two_layer_seq():
    """Return a fresh Sequential with children ``fc1`` = Linear(4, 2) and ``fc2`` = Linear(2, 1).

    Used for both the live model and the state-dict template, so their keys
    and shapes cannot drift apart.
    """
    return th.nn.Sequential(
        collections.OrderedDict([
            ("fc1", th.nn.Linear(4, 2)),
            ("fc2", th.nn.Linear(2, 1)),
        ])
    )


seq = make_two_layer_seq()

# send
seq_ptr = seq.send(alice_client)
print(f"----seq_ptr----\n{seq_ptr}\n")

# remote call
res_ptr = seq_ptr(th.rand([1, 4]))
print(f"----res_ptr----\n{res_ptr}\n")
print(f"----res_ptr.get()----\n{res_ptr.get()}\n")

# remote update state dict
# Keep a plain torch copy of the target weights so the round trip can be checked below.
expected_sd = make_two_layer_seq().state_dict()
sd2 = OrderedDict(expected_sd)
sd2_ptr = sd2.send(alice_client)
seq_ptr.load_state_dict(sd2_ptr)

# get: fetch the module once and reuse its state dict for both the printout and the check
seq_local_sd = seq_ptr.get().state_dict()
print(f"----seq_ptr.get().state_dict()----\n{seq_local_sd}\n")
print(f"----sd2----\n{sd2}\n")

# The container returned from alice must carry exactly the weights that were sent.
assert list(seq_local_sd) == list(expected_sd), "state dict keys differ after remote load"
assert all(th.equal(seq_local_sd[k], expected_sd[k]) for k in expected_sd), \
    "remote load_state_dict did not apply the sent weights"

----seq_ptr----
<syft.proxy.torch.nn.SequentialPointer object at 0x7f20c3e7f7c0>

----res_ptr----
<syft.proxy.torch.TensorPointer object at 0x7f20c3e7f910>

----res_ptr.get()----
tensor([[-0.5462]], requires_grad=True)

----seq_ptr.get().state_dict()----
OrderedDict([('fc1.weight', tensor([[ 0.4491,  0.2435, -0.2609, -0.1847],
        [ 0.2243,  0.3964, -0.3445,  0.3832]])), ('fc1.bias', tensor([-0.2109, -0.1567])), ('fc2.weight', tensor([[-0.2227, -0.6645]])), ('fc2.bias', tensor([-0.2591]))])

----sd2----
OrderedDict([('fc1.weight', tensor([[ 0.4491,  0.2435, -0.2609, -0.1847],
        [ 0.2243,  0.3964, -0.3445,  0.3832]])), ('fc1.bias', tensor([-0.2109, -0.1567])), ('fc2.weight', tensor([[-0.2227, -0.6645]])), ('fc2.bias', tensor([-0.2591]))])



In [6]:
# user defined model
class M(th.nn.Module):
    """Two stacked linear layers (4 -> 2 -> 1) driven through ModelExecutor."""

    def __init__(self):
        super().__init__()
        self.fc1 = th.nn.Linear(4, 2)
        self.fc2 = th.nn.Linear(2, 1)

    # `forward` takes the model explicitly instead of `self`. This cell calls it as
    # executor(model, x) with either a local module or a ModulePointer, so presumably
    # ModelExecutor passes the model (or its pointer) in this slot. TODO confirm.
    @staticmethod
    def forward(model, x):
        x = model.fc1(x)
        x = model.fc2(x)
        return x


m = M()

# local call
executor = ModelExecutor(m)
x = th.rand(1, 4)
local_out = executor(m, x)
print(f"----executor(m, x)----\n{local_out}\n")

# send
m_ptr = m.send(alice_client)
print(f"----m_ptr----\n{m_ptr}\n")

# remote call: the same executor is driven with pointers instead of local objects
x_ptr = x.send(alice_client)
remote_out = executor(m_ptr, x_ptr).get()
print(f"----executor(m_ptr, x_ptr).get()----\n{remote_out}\n")
# Same weights and same input, so the remote forward pass must match the local one.
assert th.allclose(local_out, remote_out), "remote forward pass differs from local one"

# remote update state dict
# Keep a plain torch copy of the target weights so the round trip can be checked below.
expected_sd = M().state_dict()
sd2 = OrderedDict(expected_sd)
sd2_ptr = sd2.send(alice_client)
m_ptr.load_state_dict(sd2_ptr)

# get: fetch the module once and reuse its state dict for both the printout and the check
m_local_sd = m_ptr.get().state_dict()
print(f"----m_ptr.get().state_dict()----\n{m_local_sd}\n")
print(f"----sd2----\n{sd2}\n")

# The model returned from alice must carry exactly the weights that were sent.
assert list(m_local_sd) == list(expected_sd), "state dict keys differ after remote load"
assert all(th.equal(m_local_sd[k], expected_sd[k]) for k in expected_sd), \
    "remote load_state_dict did not apply the sent weights"

----executor(m, x)----
tensor([[0.5898]], grad_fn=<AddmmBackward>)

----m_ptr----
<syft.proxy.torch.nn.ModulePointer object at 0x7f20c3e0dac0>

----executor(m_ptr, x_ptr)).get()----
tensor([[0.5898]], requires_grad=True)

----m_ptr.get().state_dict()----
OrderedDict([('fc1.weight', tensor([[-0.3409, -0.2359,  0.0482, -0.4663],
        [-0.2913, -0.1367,  0.3507,  0.1707]])), ('fc1.bias', tensor([-0.1072,  0.0722])), ('fc2.weight', tensor([[ 0.0962, -0.4331]])), ('fc2.bias', tensor([0.5019]))])

----sd2----
OrderedDict([('fc1.weight', tensor([[-0.3409, -0.2359,  0.0482, -0.4663],
        [-0.2913, -0.1367,  0.3507,  0.1707]])), ('fc1.bias', tensor([-0.1072,  0.0722])), ('fc2.weight', tensor([[ 0.0962, -0.4331]])), ('fc2.bias', tensor([0.5019]))])

