In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import syft as sy
import numpy as np
import torch as th
from syft import VirtualMachine
from pathlib import Path
from torchvision import datasets, transforms
from syft.core.plan.plan_builder import PLAN_BUILDER_VM, make_plan, build_plan_inputs, ROOT_CLIENT
from syft.lib.python.list import List
from matplotlib import pyplot as plt
from syft import logger
from syft.lib.torch.module import ModelExecutor
# NOTE(review): looks like this detaches syft's log sinks so cell output stays quiet — confirm against syft.logger.
logger.remove()

In [3]:
# Create a VirtualMachine ("alice") and obtain its root client.
# Later cells .send() the model, the data and the plan to alice_client.
alice = VirtualMachine()
alice_client = alice.get_root_client()

# Define Model

In [4]:
class MLP(th.nn.Module):
    """Small perceptron: Linear(8 -> 4), ReLU, Linear(4 -> 2)."""

    def __init__(self):
        super().__init__()
        # Sub-modules are registered in the order l1, a1, l2.
        self.l1 = th.nn.Linear(8, 4)
        self.a1 = th.nn.ReLU()
        self.l2 = th.nn.Linear(4, 2)

    @staticmethod
    def forward(model, x):
        """Apply ``model``'s layers to ``x`` and return the output of ``l2``.

        This is a static method: the model is passed explicitly as the first
        argument instead of being bound as ``self``.
        """
        hidden = model.l1(x)
        activated = model.a1(hidden)
        return model.l2(activated)

In [5]:
# Local model instance; it is also the default `model` argument of the `train` plan below.
local_model = MLP()

In [6]:
# Wrap the local model in syft's ModelExecutor; the `train` plan calls it as executor(model, x).
executor = ModelExecutor(local_model)

# Define Plan

In [7]:
# Stand-in "data loader": a syft List holding a single [inputs, labels] batch.
# inputs: random floats of shape (4, 8); labels: random integers in {0, 1} of shape (4,).
# Used as the default `dl` argument of the `train` plan below.
dummy_dl = sy.lib.python.List([
    [th.rand([4,8]), th.randint(0,2, (4,))] for _ in range(1)
])

In [8]:
@make_plan
def train(dl=dummy_dl,
          model=local_model
         ):
    """Plan: run one SGD pass over ``dl`` and return the last loss and the model.

    Args:
        dl: iterable of ``[inputs, labels]`` batches (defaults to ``dummy_dl``).
        model: the model whose parameters are optimized (defaults to
            ``local_model``).
            NOTE(review): presumably ``make_plan`` uses these defaults as
            example inputs when building the plan — confirm.

    Returns:
        ``(loss, model)``: the loss of the last batch and the updated model.

    Note:
        ``loss`` is only bound inside the loop, so ``dl`` must yield at least
        one batch.
    """
    remote_torch = ROOT_CLIENT.torch
    optimizer = remote_torch.optim.SGD(model.parameters(), lr=1e-1, momentum=0)
    criterion = remote_torch.nn.CrossEntropyLoss()

    for xy in dl:
        x = xy[0]
        y = xy[1]
        # Reset gradients from the previous batch. Without this, backward()
        # keeps accumulating into .grad and every step after the first would
        # apply the sum of all gradients seen so far.
        optimizer.zero_grad()
        out = executor(model,x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()

    return loss, model

  Variable._execution_engine.run_backward(


# Run

In [9]:
# Send the model, the dummy batches and the plan to alice.
# The handles returned by .send() are what the next cell calls the plan with.
remote_model = local_model.send(alice_client)
dummy_dl_ptr = dummy_dl.send(alice_client)
train_ptr = train.send(alice_client)

In [10]:
# Call the sent plan with the sent data and model; the result is fetched with .get() in the next cell.
res_ptr = train_ptr(dl=dummy_dl_ptr, model=remote_model)

In [11]:
# Fetch the plan result and unpack it in the order returned by `train`:
# index 0 is the loss, index 1 is the model.
res = res_ptr.get()
loss = res[0]
get_model = res[1]

  grad = getattr(obj, "grad", None)


In [12]:
# `loss` was already unpacked from `res` right after res_ptr.get(); just display it.
loss

tensor(0.7463, requires_grad=True)

In [13]:
# NOTE(review): repeats the get_model assignment made right after res_ptr.get(); redundant but harmless.
get_model = res[1]

In [14]:
# Show the local model's parameters as a baseline for comparison with the model returned by the plan.
list(local_model.parameters())

[Parameter containing:
 tensor([[-0.3135, -0.2279, -0.3278,  0.0852,  0.1022, -0.1589,  0.0081, -0.0930],
         [ 0.1456, -0.1648, -0.1655, -0.3043, -0.1124, -0.1589,  0.0799,  0.2055],
         [ 0.1958,  0.1776, -0.0701, -0.2342, -0.1607,  0.1567, -0.0070,  0.2601],
         [-0.0713, -0.0567,  0.0856, -0.1195,  0.0285, -0.2217,  0.1879,  0.2331]],
        requires_grad=True),
 Parameter containing:
 tensor([ 0.0974,  0.3509, -0.1961, -0.1452], requires_grad=True),
 Parameter containing:
 tensor([[-0.2790,  0.1434,  0.0993,  0.3418],
         [-0.4207,  0.0309,  0.0196,  0.0993]], requires_grad=True),
 Parameter containing:
 tensor([-0.4862,  0.1588], requires_grad=True)]

In [15]:
# Show the parameters of the model returned by the plan; compare with the local model's parameters above.
list(get_model.parameters())

[Parameter containing:
 tensor([[-0.3135, -0.2279, -0.3278,  0.0852,  0.1022, -0.1589,  0.0081, -0.0930],
         [ 0.1444, -0.1637, -0.1660, -0.3050, -0.1120, -0.1588,  0.0788,  0.2055],
         [ 0.1960,  0.1787, -0.0697, -0.2341, -0.1602,  0.1570, -0.0069,  0.2607],
         [-0.0713, -0.0567,  0.0856, -0.1195,  0.0285, -0.2217,  0.1879,  0.2331]],
        requires_grad=True),
 Parameter containing:
 tensor([ 0.0974,  0.3507, -0.1948, -0.1452], requires_grad=True),
 Parameter containing:
 tensor([[-0.2790,  0.1412,  0.0998,  0.3418],
         [-0.4207,  0.0331,  0.0191,  0.0993]], requires_grad=True),
 Parameter containing:
 tensor([-0.4711,  0.1437], requires_grad=True)]