In [1]:
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dists
import torch.utils.data as data

import torchvision.datasets as datasets
import torchvision.transforms as transforms

import syft

random.seed(123)
np.random.seed(123)
torch.manual_seed(123)
hook = syft.TorchHook(torch)

  return f(*args, **kwds)


In [2]:
number_workers = 24

workers = [syft.VirtualWorker(hook, id="Worker {}".format(i)) for i in range(number_workers)]
print(workers)

[<VirtualWorker id:Worker 0 #tensors:0>, <VirtualWorker id:Worker 1 #tensors:0>, <VirtualWorker id:Worker 2 #tensors:0>, <VirtualWorker id:Worker 3 #tensors:0>, <VirtualWorker id:Worker 4 #tensors:0>, <VirtualWorker id:Worker 5 #tensors:0>, <VirtualWorker id:Worker 6 #tensors:0>, <VirtualWorker id:Worker 7 #tensors:0>, <VirtualWorker id:Worker 8 #tensors:0>, <VirtualWorker id:Worker 9 #tensors:0>, <VirtualWorker id:Worker 10 #tensors:0>, <VirtualWorker id:Worker 11 #tensors:0>, <VirtualWorker id:Worker 12 #tensors:0>, <VirtualWorker id:Worker 13 #tensors:0>, <VirtualWorker id:Worker 14 #tensors:0>, <VirtualWorker id:Worker 15 #tensors:0>, <VirtualWorker id:Worker 16 #tensors:0>, <VirtualWorker id:Worker 17 #tensors:0>, <VirtualWorker id:Worker 18 #tensors:0>, <VirtualWorker id:Worker 19 #tensors:0>, <VirtualWorker id:Worker 20 #tensors:0>, <VirtualWorker id:Worker 21 #tensors:0>, <VirtualWorker id:Worker 22 #tensors:0>, <VirtualWorker id:Worker 23 #tensors:0>]


In [3]:
trainset = datasets.MNIST(root='../data', train=True, download=True, transform= transforms.ToTensor())
testset  = datasets.MNIST(root='../data', train=False, download=True, transform=transforms.ToTensor())

print(len(trainset))
print(len(testset))

60000
10000


In [4]:
federated_trainset = trainset.federate(workers)
print(federated_trainset)

FederatedDataset
    Distributed accross: Worker 0, Worker 1, Worker 2, Worker 3, Worker 4, Worker 5, Worker 6, Worker 7, Worker 8, Worker 9, Worker 10, Worker 11, Worker 12, Worker 13, Worker 14, Worker 15, Worker 16, Worker 17, Worker 18, Worker 19, Worker 20, Worker 21, Worker 22, Worker 23
    Number of datapoints: 60000



In [5]:
class NNET(nn.Module):
    def __init__(self):
        super(NNET, self).__init__()
        self.conv1 = nn.Conv2d(in_channels = 1, out_channels = 14, kernel_size = 5, stride = 1, padding = 2)
        self.conv2 = nn.Conv2d(in_channels = 14, out_channels = 28, kernel_size = 5, stride = 1, padding = 2)
        self.conv3 = nn.Conv2d(in_channels = 28, out_channels = 42, kernel_size = 5, stride = 1, padding = 2)
        self.fc1 = nn.Linear(in_features = 6*6*42, out_features = 500)
        self.fc2 = nn.Linear(in_features = 500, out_features = 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x, 2, 1)
        x = x.view(-1, 6*6*42)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, 0.5)
        x = self.fc2(x)

        return F.log_softmax(x, dim=1)

In [6]:
model = NNET()
epochs = 30
batch_size = 32

In [7]:
optimizer = optim.Adam(model.parameters(), lr=0.003)
criterion = nn.CrossEntropyLoss(reduction='sum')
for epoch in range(epochs):
    optimizer.zero_grad()
    model = model.send(syft.local_worker)

    batch_idx = torch.tensor(0.).send(syft.local_worker)
    training_loss   = torch.tensor(0.).send(syft.local_worker)
    correct_preds  = torch.tensor(0.).send(syft.local_worker)
    for i, worker in enumerate(workers): 
        model.move(worker)
        batch_idx.move(worker)
        training_loss.move(worker)
        correct_preds.move(worker)
        dataset = federated_trainset.datasets[worker.id]
        dataloader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
        img,target = next(iter(dataloader))
        batch_idx.add_(img.shape[0]) 
        output = model(img)
        loss = criterion(output, target)
        loss.backward()
        training_loss.add_(loss.data)
        correct_preds.add_(torch.sum(torch.eq(output.data.argmax(dim=1), target)))

    for param in model.parameters():
        param.grad.div_(batch_idx)

    model = model.get() 

    Loss     = training_loss.div_(batch_idx).get().item()
    Accuracy = correct_preds.div_(batch_idx).get().item()

    
    print('Epoch: {}\t Training Loss: {:.3f}\tTraining Accuracy: {:.3f}'.format(epoch, Loss, Accuracy))

    optimizer.step()

Epoch: 0	 Training Loss: 2.304	Training Accuracy: 0.087
Epoch: 1	 Training Loss: 2.279	Training Accuracy: 0.134
Epoch: 2	 Training Loss: 2.255	Training Accuracy: 0.133
Epoch: 3	 Training Loss: 2.142	Training Accuracy: 0.350
Epoch: 4	 Training Loss: 1.895	Training Accuracy: 0.454
Epoch: 5	 Training Loss: 1.618	Training Accuracy: 0.470
Epoch: 6	 Training Loss: 1.808	Training Accuracy: 0.322
Epoch: 7	 Training Loss: 1.398	Training Accuracy: 0.522
Epoch: 8	 Training Loss: 1.425	Training Accuracy: 0.525
Epoch: 9	 Training Loss: 1.144	Training Accuracy: 0.605
Epoch: 10	 Training Loss: 1.201	Training Accuracy: 0.629
Epoch: 11	 Training Loss: 1.033	Training Accuracy: 0.725
Epoch: 12	 Training Loss: 0.852	Training Accuracy: 0.737
Epoch: 13	 Training Loss: 0.805	Training Accuracy: 0.738
Epoch: 14	 Training Loss: 0.644	Training Accuracy: 0.781
Epoch: 15	 Training Loss: 0.613	Training Accuracy: 0.806
Epoch: 16	 Training Loss: 0.606	Training Accuracy: 0.780
Epoch: 17	 Training Loss: 0.660	Training 

In [9]:
testloader = data.DataLoader(testset, batch_size=1000)
test_loss = 0
batch_idx = 0
correct_preds  = 0
with torch.no_grad():
    for i, (imgs, labels) in enumerate(testloader, 1):
        batch_idx += imgs.size(0)

        preds = model(imgs)

        test_loss += criterion(preds, labels).item()
        correct_preds += (preds.argmax(dim=1) == labels).sum().item()
    
    Loss = test_loss/batch_idx
    Accuracy = correct_preds/batch_idx
print('Test Loss: {:.3f}\tTest Accuracy: {:.3f}'.format(Loss, Accuracy))

Test Loss: 0.251	Test Accuracy: 0.928
