In [None]:
# Install the notebook's dependencies. syft is pinned to the 0.2.x API used below
# (TorchHook, VirtualWorker, FederatedDataLoader, model.send/get).
# TODO(review): pytorch-ignite is unpinned; pin it to the version this notebook was
# last run with. Consider %pip (targets the kernel's interpreter) where IPython >= 7.3.
!pip install pytorch-ignite
!pip install syft==0.2.9

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from ignite.contrib.handlers import ProgressBar
from ignite.engine import Engine
from ignite.metrics import Accuracy
from ignite.engine import Events

import syft as sy 

In [3]:
# Hook PyTorch so tensors and modules gain syft's remote API (.send()/.get()),
# then create two in-process virtual workers standing in for remote clients.
hook = sy.TorchHook(torch)
joe, jane = (sy.VirtualWorker(hook, id=worker_id) for worker_id in ("joe", "jane"))

In [None]:
MNIST_ROOT = "~/.pytorch/MNIST_data/"  # download/cache directory for MNIST
BATCH_SIZE = 256  # shared by the federated training loader and the test loader

# ToTensor maps images to [0, 1]; Normalize with mean=std=0.5 rescales to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (0.5, )),
])

train_set = datasets.MNIST(
    MNIST_ROOT, train=True, download=True, transform=transform)
test_set = datasets.MNIST(
    MNIST_ROOT, train=False, download=True, transform=transform)

# federate() splits the training set between the given workers (joe and jane);
# FederatedDataLoader then yields batches of remote data, which is why the
# training step sends the model to each batch's `data.location`.
federated_train_loader = sy.FederatedDataLoader(
    train_set.federate((joe, jane)), batch_size=BATCH_SIZE, shuffle=True)

# The test set stays local. Accuracy does not depend on batch order, so the
# test set is not shuffled: a fixed order keeps every validation run deterministic.
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=BATCH_SIZE, shuffle=False)

In [5]:
class Net(nn.Module):
    """Two-layer MLP classifier over flattened MNIST images.

    Layout: 784 -> 500 (ReLU) -> 10. ``forward`` returns per-class
    log-probabilities, matching the ``F.nll_loss`` used for training.
    """

    def __init__(self):
        super().__init__()
        # The attribute names fc1/fc2 determine the printed repr and state_dict keys.
        self.fc1 = nn.Linear(784, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        flat = x.view(-1, 784)  # collapse every non-batch dim into 784 features
        hidden = F.relu(self.fc1(flat))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)
        
# Seed torch's global RNG so weight initialisation is reproducible under
# Restart & Run All. NOTE(review): later batch shuffling presumably draws from
# the same global generator; confirm for syft's FederatedDataLoader.
SEED = 0
torch.manual_seed(SEED)

model = Net()
print(model)
# Plain SGD over the model's parameters. train_step sends the model to a worker
# before stepping and fetches it back afterwards.
optimizer = optim.SGD(model.parameters(), lr=0.01)

Net(
  (fc1): Linear(in_features=784, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)


In [6]:
def train_step(engine, batch):
    """Run one federated SGD step on a batch that lives on a remote worker.

    The shared ``model`` is sent to the worker holding ``data``, trained there
    for one step, and then fetched back so it is local again for the next batch
    and for validation.

    Args:
        engine: the ignite Engine driving training (unused).
        batch: ``(data, targets)`` from the FederatedDataLoader; both point to
            the same worker, given by ``data.location``.

    Returns:
        The batch loss fetched back from the worker (becomes ``engine.state.output``).
    """
    data, targets = batch
    model.send(data.location)  # move the model to the worker that holds this batch
    try:
        model.train()
        optimizer.zero_grad()
        outputs = model(data)
        loss = F.nll_loss(outputs, targets)
        loss.backward()
        optimizer.step()
    finally:
        # Always bring the model back, even if a remote op failed; otherwise
        # its parameters stay on the worker and every later step/validation breaks.
        model.get()
    return loss.get()

trainer = Engine(train_step)
ProgressBar().attach(trainer)

In [7]:
def validation_step(engine, batch):
    """Score one local test batch with the current model.

    Returns ``(y_pred, y)``, the pair the attached ``Accuracy`` metric consumes.
    Gradients are disabled because evaluation never calls ``backward``.
    """
    inputs, labels = batch
    model.eval()
    with torch.no_grad():
        predictions = model(inputs)
    return predictions, labels

evaluator = Engine(validation_step)
Accuracy().attach(evaluator, "accuracy")

In [8]:
validate_every = 1  # run the evaluator every N training epochs
log_every = 1  # print validation accuracy every N training epochs

@trainer.on(Events.EPOCH_COMPLETED(every=validate_every))
def run_validation():
    """Evaluate the current (local) model on the test loader."""
    evaluator.run(test_loader)

@trainer.on(Events.EPOCH_COMPLETED(every=log_every))
def log_validation():
    """Print the accuracy produced by this epoch's validation run.

    Registered after ``run_validation``, so on epochs where both fire the
    evaluator has already refreshed its metrics. On epochs where no validation
    ran (``log_every`` not a multiple of ``validate_every``), the evaluator's
    metrics are stale or not yet populated. Nothing is printed then, instead of
    reporting an old score under the new epoch number or raising ``KeyError``.
    """
    if trainer.state.epoch % validate_every != 0:
        return
    metrics = evaluator.state.metrics
    print(f"Epoch: {trainer.state.epoch},  Accuracy: {metrics['accuracy']}")

In [9]:
# Two full passes over the federated training data; the returned ignite State
# is left as the cell's last expression so it is displayed.
trainer.run(data=federated_train_loader, max_epochs=2)

[1/235]   0%|           [00:00<?]

Epoch: 1,  Accuracy: 0.844


[1/235]   0%|           [00:00<?]

Epoch: 2,  Accuracy: 0.8752


State:
	iteration: 470
	epoch: 2
	epoch_length: 235
	max_epochs: 2
	output: <class 'torch.Tensor'>
	batch: <class 'tuple'>
	metrics: <class 'dict'>
	dataloader: <class 'syft.frameworks.torch.fl.dataloader.FederatedDataLoader'>
	seed: <class 'NoneType'>
	times: <class 'dict'>