In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import syft as sy
import numpy as np
import torch as th
from syft import VirtualMachine
from pathlib import Path
from torchvision import datasets, transforms
from syft.core.plan.plan_builder import PLAN_BUILDER_VM, make_plan, build_plan_inputs, ROOT_CLIENT
from syft.lib.python.list import List
from matplotlib import pyplot as plt
from syft import logger
from syft.lib.torch.module import ModelExecutor
logger.remove()

In [3]:
# Create a syft VirtualMachine to play the remote party "alice" and get its
# root client; the client is what the Run section uses to `.send()` the model,
# the dummy data and the training plan to alice.
alice = VirtualMachine()
alice_client = alice.get_root_client()

# Define Model

In [4]:
class MLP(th.nn.Module):
    """Two-layer perceptron: 8 inputs -> 4 hidden units (ReLU) -> 2 outputs.

    ``forward`` is intentionally a ``@staticmethod`` that takes the model as an
    explicit first argument and is driven as ``forward(model, x)``.
    NOTE(review): presumably this is so ``ModelExecutor`` (called as
    ``executor(model, x)`` inside the plan) can run it against a model pointer
    — TODO confirm. A consequence is that ``model(x)`` does NOT work:
    ``nn.Module.__call__`` would invoke ``forward(x)`` and miss an argument.
    """

    def __init__(self):
        super().__init__()
        # Keep these attribute names and this creation order: the names become
        # the parameter names ("l1.weight", "l2.bias", ...) and the order fixes
        # both parameters() ordering and the RNG draws used for initialisation.
        self.l1 = th.nn.Linear(8, 4)
        self.a1 = th.nn.ReLU()
        self.l2 = th.nn.Linear(4, 2)

    @staticmethod
    def forward(model, x):
        hidden = model.a1(model.l1(x))
        return model.l2(hidden)

In [5]:
# Seed torch's global RNG before the first random draw so the model's initial
# weights and every later random draw in this notebook are reproducible when
# the notebook is re-run top to bottom (Restart & Run All).
th.manual_seed(0)
local_model = MLP()

In [6]:
# Wrap the model in syft's ModelExecutor. The training plan calls it as
# `executor(model, x)` to obtain the model output fed to the loss.
# NOTE(review): presumably it dispatches to MLP.forward(model, x) so the same
# code works when `model` is a remote pointer — confirm in syft.lib.torch.module.
executor = ModelExecutor(local_model)

# Define the training plan

In [7]:
# A "data loader" holding exactly one dummy batch [x, y]:
#   x: 4 samples x 8 random features (MLP's input width)
#   y: 4 integer class labels drawn from {0, 1} (MLP has 2 outputs)
# This is `train`'s default `dl`; presumably make_plan only uses it to build
# the plan, so its values are irrelevant — TODO confirm.
dummy_dl = sy.lib.python.List(
    [
        [th.rand([4, 8]), th.randint(0, 2, (4,))],
    ]
)

In [8]:
@make_plan
def train(dl=dummy_dl, model=local_model):
    """Training plan: one SGD pass over ``dl``; returns ``(loss, model)``.

    Args:
        dl: iterable of ``[x, y]`` batches (defaults to the dummy one-batch list).
        model: the model to train (defaults to ``local_model``).

    Returns:
        ``(loss, model)`` where ``loss`` is the loss of the last batch and
        ``model`` has been updated in place by the optimizer.

    NOTE(review): the defaults are presumably what ``make_plan`` uses to build
    the plan, so the loop is laid out for the dummy loader's length — confirm.
    """
    remote_torch = ROOT_CLIENT.torch
    optimizer = remote_torch.optim.SGD(model.parameters(), lr=1e-1, momentum=0)
    criterion = remote_torch.nn.CrossEntropyLoss()

    for xy in dl:
        # Index rather than tuple-unpack: `xy` is presumably a syft pointer
        # while the plan is being built.
        x = xy[0]
        y = xy[1]
        # Clear gradients from the previous batch. Without this, backward()
        # accumulates into .grad and every step after the first applies the
        # running sum of all gradients seen so far.
        optimizer.zero_grad()
        out = executor(model, x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()

    return loss, model

  Variable._execution_engine.run_backward(


# Send the model, data and plan to Alice and run the plan

In [9]:
# Send the model, the dummy data and the training plan to alice. The returned
# handles (`remote_model`, `*_ptr`) refer to alice's copies and are what the
# plan is invoked with in the next cell; `local_model` itself stays local.
remote_model = local_model.send(alice_client)
dummy_dl_ptr = dummy_dl.send(alice_client)
train_ptr = train.send(alice_client)

In [10]:
# Execute the training plan on alice with alice's model and data. Only a
# pointer to the result comes back; it is downloaded with `.get()` below.
res_ptr = train_ptr(dl=dummy_dl_ptr, model=remote_model)

In [11]:
# Download the plan's result from alice: the `(loss, model)` pair returned by
# `train`. Its elements are pulled out as `loss` and `get_model` in the next
# two cells (assigning them here as well was redundant).
res = res_ptr.get()

  grad = getattr(obj, "grad", None)


In [12]:
# Loss of the last training batch, as computed on alice by the plan.
loss = res[0]
loss

tensor(0.7022, requires_grad=True)

In [13]:
# The model after training on alice (second element of the plan's result).
get_model = res[1]

In [14]:
# Parameters of the local model. Training ran on alice's copy, so these are
# expected to still be the initial weights — compare with the trained ones.
list(local_model.parameters())

[Parameter containing:
 tensor([[ 0.2833, -0.3295,  0.2022,  0.1772,  0.2799,  0.1299,  0.0827, -0.3357],
         [-0.0901, -0.2366, -0.3408, -0.2212, -0.0780,  0.0700,  0.2329,  0.0218],
         [-0.0892, -0.1332,  0.1858, -0.3072,  0.2823, -0.1750,  0.1798, -0.1320],
         [-0.3528,  0.3329,  0.2806,  0.0390,  0.0221,  0.3490, -0.3451,  0.2502]],
        requires_grad=True),
 Parameter containing:
 tensor([0.1904, 0.3201, 0.1401, 0.1586], requires_grad=True),
 Parameter containing:
 tensor([[ 0.4244,  0.4566,  0.3440,  0.4543],
         [ 0.2025, -0.2310,  0.1854, -0.0016]], requires_grad=True),
 Parameter containing:
 tensor([-0.2559,  0.3375], requires_grad=True)]

In [15]:
# Parameters of the model trained on alice. Rather than relying on eyeballing
# this printout against the local parameters, assert that the architectures
# line up and that training actually changed at least one parameter tensor.
trained_params = list(get_model.parameters())
initial_params = list(local_model.parameters())
assert len(trained_params) == len(initial_params), "parameter count mismatch"
assert any(
    not th.equal(trained, initial)
    for trained, initial in zip(trained_params, initial_params)
), "remote training did not update any parameter"
trained_params

[Parameter containing:
 tensor([[ 0.2854, -0.3280,  0.2033,  0.1776,  0.2811,  0.1307,  0.0849, -0.3357],
         [-0.0847, -0.2380, -0.3378, -0.2236, -0.0727,  0.0700,  0.2362,  0.0223],
         [-0.0887, -0.1336,  0.1853, -0.3084,  0.2827, -0.1756,  0.1804, -0.1328],
         [-0.3486,  0.3361,  0.2829,  0.0399,  0.0247,  0.3507, -0.3408,  0.2502]],
        requires_grad=True),
 Parameter containing:
 tensor([0.1915, 0.3215, 0.1387, 0.1610], requires_grad=True),
 Parameter containing:
 tensor([[ 0.4299,  0.4568,  0.3471,  0.4538],
         [ 0.1969, -0.2312,  0.1823, -0.0011]], requires_grad=True),
 Parameter containing:
 tensor([-0.2506,  0.3322], requires_grad=True)]