In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and source edits are picked up live.
%load_ext autoreload
%autoreload 2

In [2]:
import syft as sy
import numpy as np
import torch as th
from syft import VirtualMachine
from pathlib import Path
from torchvision import datasets, transforms
from syft.core.plan.plan_builder import PLAN_BUILDER_VM, make_plan, build_plan_inputs, ROOT_CLIENT
from syft.lib.python.list import List
from matplotlib import pyplot as plt
from syft import logger
# NOTE(review): presumably drops syft's default log handlers so cell output
# stays quiet — confirm against the syft logger implementation.
logger.remove()

In [3]:
# Create the virtual machine used as the remote party and obtain a root
# client for it; later cells pass `alice_client` to `.send(...)`.
alice = VirtualMachine()
alice_client = alice.get_root_client()

# Define Model

In [4]:
class MLP(th.nn.Module):
    """Small perceptron: Linear(4, 4) -> ReLU -> Linear(4, 2)."""

    def __init__(self):
        super().__init__()
        # Creation order is kept (l1, then l2) so parameter initialisation
        # and the order of `parameters()` stay the same.
        self.l1 = th.nn.Linear(4, 4)
        self.l2 = th.nn.Linear(4, 2)

    def forward(self, x=th.rand(4), th=th):
        """Apply the first linear layer, a ReLU, then the second linear layer.

        Args:
            x: input tensor fed to ``l1``. The default is a random tensor of
                4 values that is created once, when the class body is
                executed, not on every call.
            th: object providing ``relu``; defaults to the ``torch`` module
                imported by this file.

        Returns:
            The output of ``l2``.
        """
        hidden = th.relu(self.l1(x))
        return self.l2(hidden)

In [5]:
# Local instance of the model; used as the default `model` argument of the
# `train` plan and sent to alice in the "Run" section.
local_model = MLP()

# Define Plan

In [6]:
# Stand-in data loader: a syft List holding `range(1)` = one batch.
# Each batch is [inputs, targets]: inputs from th.rand([4,4]) (4 rows of 4
# values, matching the 4 input features of MLP.l1) and targets from
# th.randint(0,2,(4,)) (4 integer labels in {0, 1}, matching the 2 outputs of
# MLP.l2).
dummy_dl = sy.lib.python.List([
    [th.rand([4,4]), th.randint(0,2, (4,))] for _ in range(1)
])

In [7]:
@make_plan
def train(dl=dummy_dl,
          model=local_model
         ):
    """Run one SGD pass over ``dl`` and return the last loss with the model.

    Args:
        dl: iterable of ``[inputs, targets]`` pairs; defaults to ``dummy_dl``.
        model: module whose parameters are optimised; defaults to
            ``local_model``.

    Returns:
        ``(loss, model)``, where ``loss`` is the loss of the final batch.

    NOTE(review): ``loss`` is only bound inside the loop, so ``dl`` must
    contain at least one batch.
    """
    remote_torch = ROOT_CLIENT.torch
    optimizer = remote_torch.optim.SGD(model.parameters(), lr=1e-1, momentum=0)
    criterion = remote_torch.nn.CrossEntropyLoss()

    for xy in dl:
        x = xy[0]
        y = xy[1]
        # Reset gradients before each backward pass; without this, gradients
        # from earlier batches (or from an earlier run on the same model)
        # are added to the current ones and distort the update.
        optimizer.zero_grad()
        out = model(x=x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()

    return loss, model

  Variable._execution_engine.run_backward(


# Run

In [8]:
# Send the model, the dummy data and the plan to alice through her root
# client, keeping the returned handles for the invocation in the next cell.
remote_model = local_model.send(alice_client)
dummy_dl_ptr = dummy_dl.send(alice_client)
train_ptr = train.send(alice_client)

In [9]:
# Invoke the sent plan, passing the sent data and model as its arguments.
res_ptr = train_ptr(dl=dummy_dl_ptr, model=remote_model)

In [10]:
# Fetch the plan's result and split it in the order `train` returns it:
# index 0 is the loss, index 1 is the model.
res = res_ptr.get()
loss, get_model = res[0], res[1]

  grad = getattr(obj, "grad", None)


In [11]:
# Display the loss returned by the plan. `loss` is already bound from
# `res[0]` in the cell that fetches `res`, so it is not re-assigned here.
loss

tensor(0.8033, requires_grad=True)

In [12]:
# Show the class of the returned model. `get_model` is already bound from
# `res[1]` in the cell that fetches `res`, so it is not re-assigned here.
type(get_model)

__main__.MLP

In [13]:
# Parameters of the local model object, listed for comparison with the
# parameters of the model returned by the plan.
list(local_model.parameters())

[Parameter containing:
 tensor([[ 0.1649, -0.4974,  0.3292, -0.3353],
         [-0.4181,  0.3504, -0.1989, -0.4969],
         [-0.4332,  0.2956,  0.1563,  0.1294],
         [-0.4273,  0.2099, -0.3497,  0.1874]], requires_grad=True),
 Parameter containing:
 tensor([ 0.1494,  0.4016, -0.4857, -0.2483], requires_grad=True),
 Parameter containing:
 tensor([[ 3.5335e-01,  1.7262e-04, -3.9988e-01, -4.9935e-01],
         [-2.6603e-01, -2.1481e-02, -1.9673e-01, -2.2271e-01]],
        requires_grad=True),
 Parameter containing:
 tensor([ 0.0615, -0.2713], requires_grad=True)]

In [14]:
# Parameters of the model returned by the plan, listed for comparison with
# the local model's parameters.
list(get_model.parameters())

[Parameter containing:
 tensor([[ 0.1680, -0.4955,  0.3247, -0.3386],
         [-0.4180,  0.3501, -0.1990, -0.4971],
         [-0.4332,  0.2956,  0.1563,  0.1294],
         [-0.4273,  0.2099, -0.3497,  0.1874]], requires_grad=True),
 Parameter containing:
 tensor([ 0.1465,  0.4011, -0.4857, -0.2483], requires_grad=True),
 Parameter containing:
 tensor([[ 0.3514, -0.0055, -0.3999, -0.4993],
         [-0.2641, -0.0158, -0.1967, -0.2227]], requires_grad=True),
 Parameter containing:
 tensor([ 0.0277, -0.2374], requires_grad=True)]