In [1]:
# Enable IPython's autoreload extension in mode 2 so edited modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import syft as sy
import numpy as np
import torch as th
from syft import VirtualMachine
from pathlib import Path
from torchvision import datasets, transforms
from syft.core.plan.plan_builder import PLAN_BUILDER_VM, make_plan, build_plan_inputs, ROOT_CLIENT
from syft.lib.python.list import List
from matplotlib import pyplot as plt
from syft import logger
from syft.lib.torch.module import ModelExecutor
# NOTE(review): presumably detaches syft's log handlers to keep cell output quiet — confirm against syft's logger.
logger.remove()

In [3]:
# Dataset
# MNIST is cached under ~/.pysyft/mnist; the directory is created on first use.
mnist_path = Path.home() / ".pysyft" / "mnist"
mnist_path.mkdir(exist_ok=True, parents=True)

# One preprocessing pipeline shared by both splits: convert to tensor, then
# normalize with mean 0.1307 and std 0.3081.
mnist_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# The root is passed as `str` for both splits so they are handled identically
# (the test split previously received the bare Path object).
mnist_train = datasets.MNIST(str(mnist_path), train=True, download=True, transform=mnist_transform)
mnist_test = datasets.MNIST(str(mnist_path), train=False, transform=mnist_transform)

# Pinned host memory only speeds up host-to-GPU copies, so request it only when CUDA is present.
use_pinned_memory = th.cuda.is_available()

train_loader = th.utils.data.DataLoader(mnist_train, batch_size=32, shuffle=True, pin_memory=use_pinned_memory)
test_loader = th.utils.data.DataLoader(mnist_test, batch_size=1024, shuffle=False, pin_memory=use_pinned_memory)

# Define model


In [4]:
class ForwardToPlanConverter(type):
    """Metaclass that finalizes every new instance by calling ``make_forward_plan()``.

    Classes using this metaclass must provide a ``make_forward_plan`` method;
    it is invoked once, right after normal construction (``__new__`` and
    ``__init__``) has completed.
    """

    def __call__(cls, *args, **kwargs):
        # Regular construction first, then the post-construction hook.
        instance = super().__call__(*args, **kwargs)
        instance.make_forward_plan()
        return instance

class SyModule(th.nn.Module, metaclass=ForwardToPlanConverter):
    """``torch.nn.Module`` whose ``forward`` is replaced by a plan after construction.

    The ``ForwardToPlanConverter`` metaclass calls ``make_forward_plan()`` once
    ``__init__`` has finished. While that build is running, attribute lookups
    that resolve to a ``torch.nn.Module`` or ``torch.nn.Parameter`` return a
    pointer obtained by sending the object to ``ROOT_CLIENT``, so ``forward``
    operates on pointers instead of local objects.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # True only while make_forward_plan() is running.
        self.building_forward = False
        # Attribute name -> pointer returned by .send(ROOT_CLIENT) during the build.
        self._parameter_pointers = dict()

    def make_forward_plan(self):
        """Build a plan from ``self.forward`` and install it as ``forward`` and ``__call__``.

        Raises:
            ValueError: if the instance has no ``forward`` attribute.
        """
        if not hasattr(self, "forward"):
            raise ValueError("Missing .forward() method for Module")
        self.building_forward = True
        try:
            plan = make_plan(self.forward)
            self.forward = plan
            # Instance-level __call__ is only reached by an explicit
            # `obj.__call__(...)`; `obj(...)` looks __call__ up on the type.
            self.__call__ = plan
        finally:
            # Always leave build mode, even if make_plan raised; otherwise later
            # attribute reads would keep sending submodules to ROOT_CLIENT.
            self.building_forward = False

    def __getattr__(self, name):
        # this is __getattr__ instead of __getattribute__ because of the structure of torch.nn.Module
        #
        # Read our own bookkeeping straight from the instance dict. Going through
        # normal attribute access here would re-enter __getattr__ and recurse
        # without bound on an instance whose __init__ has not run yet.
        state = self.__dict__
        if not state.get("building_forward", False):
            return super().__getattr__(name)

        pointers = state.get("_parameter_pointers")
        if pointers is not None and name in pointers:
            return pointers[name]

        res = super().__getattr__(name)
        if isinstance(res, (th.nn.Module, th.nn.Parameter)):
            # First access during the build: send once, then reuse the pointer.
            res_ptr = res.send(ROOT_CLIENT)
            self._parameter_pointers[name] = res_ptr
            return res_ptr
        return res

In [5]:
class MySyModuleBlock(SyModule):
    """Bias-free projection block: multiplies its input by a single (100, 10) parameter."""

    def __init__(self):
        super().__init__()
        # Small uniform initialization: values in [0, 0.01).
        initial_weight = th.rand(100, 10) * 0.01
        self.p1 = th.nn.Parameter(initial_weight)

    def forward(self, x=th.rand(32, 100)):
        """Return ``x @ p1``; the default ``x`` is a random (32, 100) tensor."""
        return x @ self.p1

In [6]:
class MySyModule(SyModule):
    """Two-stage classifier: Linear(784 -> 100), ReLU, then a ``MySyModuleBlock``."""

    def __init__(self):
        super().__init__()
        self.layer1 = th.nn.Linear(28 * 28, 100)
        self.relu1 = th.nn.ReLU()
        self.layer2 = MySyModuleBlock()

    def forward(self, x=th.rand(32, 28 * 28)):
        """Flatten ``x`` to rows of 784 values and run it through the three stages.

        The default ``x`` is a random (32, 784) tensor. The block's call
        result is indexed with ``[0]`` before being returned.
        """
        flat = x.view(-1, 28 * 28)
        hidden = self.layer1(flat)
        activated = self.relu1(hidden)
        return self.layer2(x=activated)[0]

# Train plan

In [7]:
# Construction goes through ForwardToPlanConverter.__call__, so make_forward_plan() has already run on this instance.
model = MySyModule()

In [8]:
# One batch drawn from train_loader, wrapped in a syft List; used as the default `dl` argument of the `train` plan.
dummy_dl = sy.lib.python.List([next(iter(train_loader))])

In [9]:
# The `torch` attribute of ROOT_CLIENT; the `train` plan uses it to create the optimizer and compute the loss.
remote_torch = ROOT_CLIENT.torch

In [10]:
# Training plan: runs one SGD pass over the batches in `dl` and returns the updated model.
#   dl    -- iterable of batches; each batch is indexed as [0] (inputs) and [1] (labels).
#   model -- called as model(x=...), with the logits taken from element [0] of the result.
# Returns a one-element list holding the model; the training loop unpacks it after `.get()`.
# NOTE(review): the defaults (dummy_dl, model) presumably serve as example inputs when
# make_plan builds the plan — confirm against syft's plan_builder.
@make_plan
def train(dl = dummy_dl, model = model):
    # Plain SGD (learning rate 0.1, no momentum) created through remote_torch.
    optimizer = remote_torch.optim.SGD(model.parameters(), lr=1e-1, momentum=0)

    for xy in dl:
        # Clear gradients left over from the previous batch.
        optimizer.zero_grad()
        x, y = xy[0], xy[1]
        out = model(x=x)[0]
        loss =  remote_torch.nn.functional.cross_entropy(out, y)
        loss.backward()
        optimizer.step()
    
    return [model]

## Utilities

In [11]:
def test(test_loader, model):
    correct = []
    model.eval()

    for data, target in test_loader:        
        output = model(x=data)[0]
        _, pred = th.max(output, 1)
        correct.append(th.sum(np.squeeze(pred.eq(target.data))))
    acc = sum(correct) / len(test_loader.dataset)
    return acc

In [12]:
def show_predictions(test_loader, model, n=6):
    """Plot the first ``n`` images of the first test batch with predicted and actual labels.

    Args:
        test_loader: iterable yielding ``(images, labels)`` batches; only the
            first batch is used, and it must hold at least ``n`` samples.
        model: callable as ``model(x=images)``; element ``[0]`` of the result
            is the per-class score tensor.
        n: number of images to show side by side (default 6).
    """
    xs, ys = next(iter(test_loader))
    preds = model(x=xs)[0].detach()

    # squeeze=False keeps `axs` a 2-D array for every n; with the default
    # squeezing, n == 1 returns a bare Axes object and indexing it fails.
    _, axs = plt.subplots(1, n, sharex='col', sharey='row', figsize=(16, 8), squeeze=False)
    for i in range(n):
        ax = axs[0, i]
        # Hide tick marks; the label under each image carries the information.
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel(f"prediction: {np.argmax(preds[i])}, actual: {ys[i]}")
        ax.imshow(xs[i].reshape((28, 28)))

# Train

In [13]:
# Create a VirtualMachine named "alice", take its client, and send the `train` plan to it.
# `train_ptr` is the handle the training loop calls.
alice_client = VirtualMachine(name="alice").get_client()
train_ptr = train.send(alice_client)

In [14]:
# Feed the plan one batch per call; the model it returns replaces the local one.
for i, (x, y) in enumerate(train_loader):
    dl = [[x, y]]
    res_ptr = train_ptr(dl=dl, model=model)
    # The plan returns a one-element list holding the updated model.
    (model,) = res_ptr.get()

    # Report test accuracy on every tenth iteration, skipping iteration 0.
    if i > 0 and not i % 10:
        acc = test(test_loader, model)
        print(f"Iter: {i} Test accuracy: {acc:.2F}", flush=True)
    # Stop once iteration 51 has been processed.
    if i >= 51:
        break

  grad = getattr(obj, "grad", None)


Iter: 10 Test accuracy: 0.47
Iter: 20 Test accuracy: 0.68
Iter: 30 Test accuracy: 0.68
Iter: 40 Test accuracy: 0.80
Iter: 50 Test accuracy: 0.76
