In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import syft as sy
import numpy as np
import torch as th
from syft import VirtualMachine
from pathlib import Path
from torchvision import datasets, transforms
from syft.core.plan.plan_builder import PLAN_BUILDER_VM, make_plan, build_plan_inputs, ROOT_CLIENT
from syft.lib.python.list import List
from matplotlib import pyplot as plt
from syft import logger
from syft.lib.torch.module import ModelExecutor
logger.remove()

In [3]:
# Dataset
mnist_path = Path.home() / ".pysyft" / "mnist"
mnist_path.mkdir(exist_ok=True, parents=True)

mnist_train = datasets.MNIST(str(mnist_path), train=True, download=True,
               transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))

mnist_test = datasets.MNIST((mnist_path), train=False, 
              transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))

train_loader = th.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True, pin_memory=True)
test_loader = th.utils.data.DataLoader(mnist_test, batch_size=1024, shuffle=True, pin_memory=True)

# define model


In [4]:
# class NormalModule(th.nn.Module):
    
#     def __init__(self):
#         super().__init__()
#         self.p1 = th.nn.Parameter(th.rand(50,50))
        
#     def fw(self, x):
#         out = x @ self.p1
#         out.sum().backward()
#         print(self.p1.grad)
#         return out



In [5]:
# mod = NormalModule()

In [6]:
# y = mod.fw(th.rand(32,50))

In [7]:
# y.sum().backward()

In [8]:
# mod.p1.grad

In [9]:
class ForwardToPlanConverter(type):
    def __call__(cls, *args, **kwargs):
        obj = type.__call__(cls, *args, **kwargs)
        obj.make_forward_plan()
        return obj

class SyModule(th.nn.Module, metaclass=ForwardToPlanConverter):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.building_forward = False
        self._parameter_pointers = dict()

    def make_forward_plan(self):
        if not hasattr(self, "forward"):
            raise ValueError("Missing .forward() method for Module")
        self.building_forward = True
        self.forward = make_plan(self.forward)
        self.building_forward = False
        
    def __getattr__(self, name):
        # this is __getattr__ instead of __getattribute__ because of the structure of torch.nn.Module
        if name in self._parameter_pointers and self.building_forward:
            return self._parameter_pointers[name]
        
        res = super().__getattr__(name)
        if isinstance(res, (th.nn.Module, th.nn.Parameter)) and self.building_forward:
            res_ptr = res.send(ROOT_CLIENT)
            self._parameter_pointers[name] = res_ptr
            return res_ptr
        else:
            return res

In [10]:
# x = th.nn.Parameter(th.Tensor([1,2])).send(ROOT_CLIENT)

In [11]:
# x.id_at_location

In [12]:
# x @ th.tensor([1,2])

In [13]:
# x.grad.get()

In [14]:
# class MySyModuleBlock(SyModule):
#     def __init__(self):
#         super().__init__()
#         self.p1 = th.nn.Parameter(th.rand(50,50))
#         self.relu1 = th.nn.ReLU()
#         self.l2 = th.nn.Linear(50,10)
    
#     def forward(self, x = th.rand(32, 50)):
#         o1 = x @ self.p1
#         relu_out = self.relu1(o1)
#         out = self.l2(relu_out)
#         return out

In [15]:
# class MySyModule(SyModule):
#     def __init__(self):
#         super().__init__()
#         self.layer1 = th.nn.Linear(28*28,50)
#         self.relu1 = th.nn.ReLU()
#         self.layer2 = MySyModuleBlock()
    
#     def forward(self, x = th.rand(32,28*28)):
#         x_reshaped = x.view(-1, 28 * 28)
#         o1 = self.layer1(x_reshaped)
#         a1 = self.relu1(o1)
#         out = self.layer2(x=a1)[0]
#         return out

In [16]:
class MySyModule(SyModule):
    def __init__(self):
        super().__init__()
        self.layer1 = th.nn.Linear(28*28,10)
    
    def forward(self, x = th.rand(32,28*28)):
        x_reshaped = x.view(-1, 28 * 28)
        o1 = self.layer1(x_reshaped)
        return o1

In [17]:
x = MySyModule()

In [18]:
# model = x.send(ROOT_CLIENT)

In [19]:
# for kv in model.state_dict():
#     k, v = kv[0], kv[1]
#     print(kv.get())
#     print(model.__getitem__[kv].get())

In [20]:
# for p in model.state_dict():
#     p.data = p.data - lr * p.grad
#     print(p.data)

In [21]:
# %debug

# Train plan

In [22]:
dummy_dl = sy.lib.python.List([
    [th.rand([32,1,28,28]), th.randint(0,10, (32,))] for _ in range(1)
])

In [23]:
remote_torch = ROOT_CLIENT.torch

In [24]:
def cross_entropy_loss(logits, targets, batch_size):
    norm_logits = logits - logits.max()
    log_probs = norm_logits - norm_logits.exp().sum(dim=1, keepdim=True).log()
    return -(targets * log_probs).sum() / batch_size


In [25]:
def sgd_step(model, lr=0.1):
    with ROOT_CLIENT.torch.no_grad():
        i=0
        for p in model.parameters():
            p.data = p.data - lr * p.grad
            print(p)
#             y = th.zeros_like(p.grad.get())
#             print(y)
#             exit()
            if i ==0:
                p.grad = th.zeros(10,784)
            else:
                p.grad = th.zeros(10)
            i+=1

In [26]:
# def train(xs = th.rand([64*3, 1, 28, 28]), ys = th.randint(0, 10, [64*3, 10]),
#           model = MySyModule()):

In [27]:
@make_plan
def train(dl = dummy_dl, model = MySyModule()):
#     loss_fn = remote_torch.nn.CrossEntropyLoss()
    optimizer = remote_torch.optim.SGD(model.parameters(), lr=1e-1, momentum=0)

#     for i in range(1):
#         indices = th.tensor(range(64*i, 64*(i+1)))
#         x, y = xs.index_select(0, indices), ys.index_select(0, indices)
#         out = model.forward(x=x)[0]
#         loss = cross_entropy_loss(out, y, 32)
#         loss.backward()
#         sgd_step(model)

    for xy in dl:
#         optimizer.zero_grad()
        x, y = xy[0], xy[1]
        out = model.forward(x=x)[0]
        loss = remote_torch.nn.functional.nll_loss(out, y)
#         loss = loss_fn(out, y)
        loss.backward()
        optimizer.step()
    
    return [model]

In [28]:
# %debug

In [29]:
# %debug

In [30]:
# train.actions

In [31]:
# alice_client = VirtualMachine(name="alice").get_client()
# train_ptr = train.send(alice_client)
# local_model = MySyModule()

In [32]:
# list(local_model.parameters())[0][0][:5]

In [33]:
# local_model, = train_ptr(dl=dummy_dl, model=local_model).get()

In [34]:
# list(local_model.parameters())[0][0][:5]

In [35]:
alice_client = VirtualMachine(name="alice").get_client()
train_ptr = train.send(alice_client)

In [36]:
def test(test_loader, model):
    correct = []
    model.eval()
    for data, target in test_loader:
        output = model(x=data)[0]
        _, pred = th.max(output, 1)
        correct.append(th.sum(np.squeeze(pred.eq(target.data.view_as(pred)))))
    acc = sum(correct) / len(test_loader.dataset)
    return acc

In [37]:
def show_predictions(test_loader, model, n=6):
    xs, ys = next(iter(test_loader))
    preds = model(x=xs)[0].detach()
    
    fig, axs = plt.subplots(1, n, sharex='col', sharey='row', figsize=(16, 8))
    for i in range(n):
        ax = axs[i]
        ax.set_xticks([]),ax.set_yticks([])
        ax.set_xlabel(f"prediction: {np.argmax(preds[i])}, actual: {ys[i]}")
        ax.imshow(xs[i].reshape((28, 28)))

In [38]:
# list(local_model.parameters())[0]

In [39]:
# show_predictions(test_loader, local_model)
# print(f"accuracy: {test(test_loader, local_model):.2F}")

# Train

In [40]:
# print(list(local_model.parameters())[0][0][:5])

In [41]:
local_model = MySyModule()

In [42]:
print(list(local_model.parameters())[0].grad)

None


In [51]:
train.inputs["model"].id_at_location

<UID: 43d1c7d81a754829a05e890085d12c84>

In [48]:
train.actions

[RunClassMethodAction ModulePointer.parameters(, ),
 SaveObjectAction <Storable: 0.10000000149011612>,
 SaveObjectAction <Storable: 0>,
 RunClassMethodAction SGD(ListPointer, lr=FloatPointer,momentum=IntPointer),
 RunClassMethodAction ListPointer.__len__(, ),
 RunClassMethodAction ListPointer.__iter__(, ),
 RunClassMethodAction IteratorPointer.__next__(, ),
 SaveObjectAction <Storable: 0>,
 RunClassMethodAction FloatIntStringTensorParameterUnionPointer.__getitem__(IntPointer, ),
 SaveObjectAction <Storable: 1>,
 RunClassMethodAction FloatIntStringTensorParameterUnionPointer.__getitem__(IntPointer, ),
 RunClassMethodAction ModulePointer.forward(, x=AnyPointer),
 SaveObjectAction <Storable: 0>,
 RunClassMethodAction TensorPointer.__getitem__(IntPointer, ),
 RunClassMethodAction nll_loss(TensorPointer,AnyPointer, ),
 RunClassMethodAction TensorPointer.backward(, ),
 RunClassMethodAction SGDPointer.step(, )]

In [53]:
train.actions[0]._self.id_at_location

<UID: 43d1c7d81a754829a05e890085d12c84>

In [47]:
local_model.forward.actions

[SaveObjectAction <Storable: -1>,
 SaveObjectAction <Storable: 784>,
 RunClassMethodAction AnyPointer.view(IntPointer,IntPointer, ),
 SaveObjectAction <Storable: Linear(in_features=784, out_features=10 ... , bias=True)>,
 RunClassMethodAction LinearPointer.__call__(TensorPointer, )]

In [None]:
res_ptr  = train_ptr(dl=dl, model=local_model)


In [43]:
list(local_model.parameters())[0]

Parameter containing:
tensor([[ 0.0139,  0.0127, -0.0037,  ...,  0.0193,  0.0165, -0.0348],
        [ 0.0262,  0.0309, -0.0004,  ...,  0.0342,  0.0341, -0.0166],
        [ 0.0241, -0.0271,  0.0044,  ...,  0.0099, -0.0231, -0.0232],
        ...,
        [-0.0154,  0.0040, -0.0244,  ..., -0.0335, -0.0019,  0.0127],
        [-0.0140, -0.0159, -0.0108,  ...,  0.0122, -0.0130, -0.0113],
        [-0.0301, -0.0329,  0.0227,  ..., -0.0131,  0.0025, -0.0035]],
       requires_grad=True)

In [44]:
# for i, (x, y) in enumerate(train_loader):
# #     ys = th.nn.functional.one_hot(y)
# #     xs = x
#     dl = [[x,y]]
# #     res_ptr  = train_ptr(dl=dl, model=local_model)
#     res_ptr  = train_ptr(xs=xs, ys=ys, model=local_model)
# #     local_model.zero_grad()
#     local_model, = res_ptr.get()
# #     set_params(local_model, params)

#     if i%10 == 0:
# #         print(list(local_model.parameters())[0].grad[0][0:5])
#         print(list(local_model.parameters())[0].grad[0][240:250])
#         print(list(local_model.parameters())[0][0][:5])
#         acc = test(test_loader, local_model)
#         print(f"Iter: {i} Test accuracy: {acc:.2F}", flush=True)
#     if i>20:
#         break

In [45]:
for i, (x, y) in enumerate(train_loader):
    dl = [[x,y]]
    res_ptr  = train_ptr(dl=dl, model=local_model)
#     local_model.zero_grad()
    local_model, = res_ptr.get()

    if i%10 == 0:
        print(list(local_model.parameters())[0].grad[0][240:250])
        print(list(local_model.parameters())[0][0][:5])
        acc = test(test_loader, local_model)
        print(f"Iter: {i} Test accuracy: {acc:.2F}", flush=True)
    if i>20:
        break

tensor([-0.1740, -0.1726, -0.1693, -0.1877, -0.2428, -0.1887, -0.1052, -0.0235,
         0.0398,  0.0398])
tensor([ 0.0099,  0.0087, -0.0077, -0.0034, -0.0293], grad_fn=<SliceBackward>)


  grad = getattr(obj, "grad", None)


Iter: 0 Test accuracy: 0.34
tensor([-0.1740, -0.1726, -0.1693, -0.1877, -0.2428, -0.1887, -0.1052, -0.0235,
         0.0398,  0.0398])
tensor([-0.0299, -0.0311, -0.0475, -0.0432, -0.0691], grad_fn=<SliceBackward>)
Iter: 10 Test accuracy: 0.34
tensor([-0.1740, -0.1726, -0.1693, -0.1877, -0.2428, -0.1887, -0.1052, -0.0235,
         0.0398,  0.0398])
tensor([-0.0696, -0.0709, -0.0872, -0.0830, -0.1089], grad_fn=<SliceBackward>)
Iter: 20 Test accuracy: 0.34


In [45]:
self.param_groups[0]["params"][0].grad[0][240:250]

tensor([ 0.0564,  0.0503,  0.0422,  0.0390,  0.0390,  0.0390,  0.0603,  0.0601,
         0.0511,  0.0658,  0.0872,  0.0868, -0.0227, -0.1257, -0.0915, -0.0602,
        -0.2632, -0.3236, -0.2641, -0.1850, -0.2224, -0.2277, -0.2891, -0.4306,
        -0.4448, -0.2071, -0.0554,  0.0273,  0.0579,  0.0531])


In [41]:
list(local_model.parameters())[0][0].shape

torch.Size([784])

In [81]:
# %debug

In [83]:
print(list(local_model.parameters())[0][0][:5])

tensor([0.0352, 0.0421, 0.0192, 0.0494, 0.0243], grad_fn=<SliceBackward>)


In [93]:
local_model.eval()
data, target = next(iter(test_loader))

In [94]:
output = local_model(x=data)[0]
_, pred = th.max(output, 1)

In [95]:
pred

tensor([9, 9, 1,  ..., 6, 6, 9])

In [85]:
# %debug

In [26]:
x.forward(x=th.rand(8))

[[tensor([ 0.1378,  0.0570, -0.2406,  0.2745], requires_grad=True)]]

In [71]:
x.forward.__name__

'forward'

In [40]:
class Model(th.nn.Module):
    
    def __init__(self):
        super().__init__()
        self.w = th.nn.Linear(8,4)
        print(self.forward)
    
#     @make_plan
    def forward(self, x=th.rand(32,8)):
        return self.w(x)

In [41]:
Model()

<bound method Model.forward of Model(
  (w): Linear(in_features=8, out_features=4, bias=True)
)>


Model(
  (w): Linear(in_features=8, out_features=4, bias=True)
)

In [33]:
import inspect

In [34]:
params = inspect.signature(Model.forward).parameters.items()

In [36]:
# params

In [55]:
# inspect.signature(Model.forward).parameters["x"]

In [56]:
# %debug

In [4]:
dummy_dl = sy.lib.python.List([
    [th.rand([4,8]), th.randint(0,2, (4,))] for _ in range(1)
])

In [5]:

@make_plan
def train(dl=dummy_dl):
    x = dl[0]
    return [x]

In [10]:
train.actions

[SaveObjectAction <Storable: 0>,
 RunClassMethodAction ListPointer.__getitem__(IntPointer, )]

# appendix

In [18]:
# %debug

In [19]:
# %debug

In [20]:
# x.layer2.forward.actions

In [21]:
# x.forward.actions

In [22]:
# x.forward.actions[4]._self

In [23]:
# x._parameter_pointers["layer1"].id_at_location

In [24]:
# attr_uid2name = {param.id_at_location: name for name, param in x._parameter_pointers.items()}

In [25]:
# attr_uid2name

In [26]:
# from syft.core.node.common.action.save_object_action import SaveObjectAction
# from syft.core.node.common.action.run_class_method_action import RunClassMethodAction

In [27]:
# plan = x.forward

In [28]:
# for uid, name in attr_uid2name.items():
#     for action in plan.actions:
# #         if isinstance(action, SaveObjectAction):
# #             if action.obj.id == uid:
# #                 action.obj.data = getattr(x, name)
# #                 print(name, action.obj.data)
#         if isinstance(action, RunClassMethodAction):
#             if action._self.id_at_location == uid:
                
                
        

In [29]:
# x.layer1

In [30]:
# plan.actions

In [31]:
# x.layer1.weight

In [32]:
# x.forward.actions[3].obj.id

In [33]:
# x.forward.actions[3].obj.data.weight

In [34]:
# x.forward.actions[3].obj.data

In [35]:
# replace actions where 

In [36]:
# PLAN_BUILDER_VM.store.clear()

In [37]:
# x.send(ROOT_CLIENT)

In [38]:
# list(x.data for x in PLAN_BUILDER_VM.store.values())

In [39]:
# @make_plan
# def train(model = x):
#     y = model.forward(x = th.rand(32,28*28))
#     return [model]

In [40]:
# train.inputs["model"].get()

In [41]:
# train.actions

In [42]:
# @make_plan
# def test(x = th.rand(3)):
#     return x

In [43]:
# PLAN_BUILDER_VM.store.clear()

In [44]:
# list(x.data for x in PLAN_BUILDER_VM.store.values())

In [45]:
# test_ptr = test.send(ROOT_CLIENT)

In [46]:
# list(x.data for x in PLAN_BUILDER_VM.store.values())

In [47]:
# y = test_ptr(x=th.rand(3))

In [48]:
# y = x.forward(x=th.rand(32,28*28))

In [49]:
# y[0].sum().backward()

In [50]:
# print(x.parameters())

In [51]:
# loss = out.sum()
# loss.backward()
# print(x.layer2.p2.weight.grad.get())

In [52]:
# x.layer2.p1.grad

In [53]:
# %debug

In [54]:
# x.layer2.p1.requires_grad

In [55]:
# alice_client = VirtualMachine(name="alice").get_client()
# x_ptr = x.send(alice_client)

In [56]:
# y = x_ptr.forward(x=th.rand(32,28*28))

In [57]:
# y[0].sum().backward()

In [58]:
# x_ptr.parameters()[0]

In [59]:
# x.send(alice_client)

In [60]:
# y = x.forward(x=th.rand(32,28*28))

In [61]:
# x.forward(x=th.rand([32,1,28,28]))

In [62]:
# x.forward.

In [63]:
# alice = sy.VirtualMachine()
# alice_client = alice.get_root_client()

In [64]:
# x = MySyModule()
# x_ptr = x.send(alice_client)
# y, = x_ptr(x=th.rand(8)).get()