In [1]:
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dists
import torch.utils.data as D

import torchvision.datasets as datasets
import torchvision.transforms as transforms

import syft

random.seed(123)
np.random.seed(123)
torch.manual_seed(123)
hook = syft.TorchHook(torch)

  return f(*args, **kwds)


In [2]:
number_workers = 15

workers = [syft.VirtualWorker(hook, id="Worker_{}".format(i)) for i in range(number_workers)]
print(workers)

[<VirtualWorker id:Worker_0 #tensors:0>, <VirtualWorker id:Worker_1 #tensors:0>, <VirtualWorker id:Worker_2 #tensors:0>, <VirtualWorker id:Worker_3 #tensors:0>, <VirtualWorker id:Worker_4 #tensors:0>, <VirtualWorker id:Worker_5 #tensors:0>, <VirtualWorker id:Worker_6 #tensors:0>, <VirtualWorker id:Worker_7 #tensors:0>, <VirtualWorker id:Worker_8 #tensors:0>, <VirtualWorker id:Worker_9 #tensors:0>, <VirtualWorker id:Worker_10 #tensors:0>, <VirtualWorker id:Worker_11 #tensors:0>, <VirtualWorker id:Worker_12 #tensors:0>, <VirtualWorker id:Worker_13 #tensors:0>, <VirtualWorker id:Worker_14 #tensors:0>]


In [3]:
trainset = datasets.MNIST(root='../data', train=True, download=True, transform= transforms.ToTensor())
testset  = datasets.MNIST(root='../data', train=False, download=True, transform=transforms.ToTensor())

print(len(trainset))
print(len(testset))

60000
10000


In [4]:
testloader = D.DataLoader(testset, batch_size=1024, shuffle=True)

In [5]:
#create the secure worker
secure_worker = syft.VirtualWorker(hook, id="secure_worker")

In [6]:
#create the federated datasets
federated_trainset = trainset.federate(workers)
print(federated_trainset)

FederatedDataset
    Distributed accross: Worker_0, Worker_1, Worker_2, Worker_3, Worker_4, Worker_5, Worker_6, Worker_7, Worker_8, Worker_9, Worker_10, Worker_11, Worker_12, Worker_13, Worker_14
    Number of datapoints: 60000



In [7]:
class NNET(nn.Module):
    def __init__(self):
        super(NNET, self).__init__()
        self.conv1 = nn.Conv2d(in_channels = 1, out_channels = 14, kernel_size = 5, stride = 1, padding = 2)
        self.conv2 = nn.Conv2d(in_channels = 14, out_channels = 28, kernel_size = 5, stride = 1, padding = 2)
        self.conv3 = nn.Conv2d(in_channels = 28, out_channels = 42, kernel_size = 5, stride = 1, padding = 2)
        self.fc1 = nn.Linear(in_features = 6*6*42, out_features = 500)
        self.fc2 = nn.Linear(in_features = 500, out_features = 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x, 2, 1)
        x = x.view(-1, 6*6*42)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, 0.5)
        x = self.fc2(x)

        return F.log_softmax(x, dim=1)

In [8]:
#create the model
model = NNET()
batch_size = 32
epochs = 40
optimizer = optim.Adam(model.parameters(), lr=0.003)
criterion = nn.NLLLoss(reduction='sum')

In [9]:
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.data)

conv1.weight tensor([[[[-0.0993,  0.0754, -0.1704,  0.1466, -0.1454],
          [-0.1590, -0.1264,  0.0906, -0.0739,  0.0748],
          [-0.1697, -0.1213, -0.0734, -0.0393, -0.1526],
          [ 0.1310, -0.0472,  0.0642,  0.1414,  0.0373],
          [ 0.0547,  0.1931, -0.0902,  0.0634, -0.0890]]],


        [[[ 0.1429,  0.1597, -0.1844,  0.1707,  0.0955],
          [ 0.0872,  0.0823,  0.1663, -0.0264, -0.1691],
          [-0.0574, -0.1409,  0.0132, -0.0373, -0.1073],
          [-0.0182,  0.1895, -0.0158,  0.0064, -0.0312],
          [ 0.0314,  0.1782,  0.1223,  0.0710,  0.0435]]],


        [[[ 0.0472,  0.0773, -0.0258, -0.1859, -0.1237],
          [ 0.1707,  0.0120, -0.1620,  0.0315,  0.1653],
          [-0.1890, -0.1346, -0.0797,  0.0080, -0.0466],
          [-0.0220, -0.1950,  0.0937,  0.1755,  0.1222],
          [-0.1416, -0.1612,  0.0831,  0.0045,  0.0820]]],


        [[[-0.1954, -0.0119,  0.1410,  0.0928,  0.0073],
          [ 0.0393, -0.0189, -0.1100, -0.0756, -0.1218],
      

# Training and test model

In [10]:
workers_models = []
for worker in workers:
    worker_model = NNET().copy().send(worker)
    workers_models.append(worker_model)
    
for param in model.parameters():
    param.grad = torch.zeros_like(param.data)

In [12]:
for epoch in range(epochs):
    optimizer.zero_grad()

    workers_gradients = [list() for l in model.parameters()]
    workers_batch_idx = []
    workers_loss = []
    workers_correct_preds = []
    
    for idx, worker in enumerate(workers):
        dataset = federated_trainset.datasets[workers[idx].id]
        dataloader = D.DataLoader(dataset, batch_size=batch_size, shuffle=True)
        img, label = next(iter(dataloader))
        output = workers_models[idx](img)
        loss = criterion(output, label)
        loss.backward()
        
        #Make secret sharing and fixed precision on the gradients
        for gradient, param in zip(workers_gradients, workers_models[idx].parameters()):
            gradient.append(param.grad.clone().fix_precision().share(*workers, crypto_provider=worker).get())
            param.grad.zero_()
        #For each metric do secret sharing and fixed precision as well   
        workers_batch_idx.append(torch.tensor(0, dtype=torch.long).send(worker)\
                                 .add_(img.shape[0]).share(*workers, crypto_provider=worker).get())
        workers_loss.append(loss.data.clone().fix_precision()\
                                .share(*workers, crypto_provider=worker).get())
        workers_correct_preds.append((output.data.argmax(dim=1) == label)\
                                     .sum().share(*workers, crypto_provider=worker).get())

    #sum the instances of the workers data
    batch_idx = sum(workers_batch_idx).get().item()
    
    
    for param, gradients in zip(model.parameters(), workers_gradients):
        param.grad.copy_(sum(gradients).get().float_precision() / batch_idx)
        
    Loss = sum(workers_loss).get().float_precision().item() / batch_idx
    Accuracy = sum(workers_correct_preds).get().item() / batch_idx

    print('Epoch: {}\t Training Loss: {}\tTraining Accuracy: {}'.format(epoch, Loss, Accuracy))

    optimizer.step()
    
    #clean-up
    for model_params, worker_params in zip(model.parameters(), 
                           zip(*[worker_model.parameters() for worker_model in workers_models])):
         for worker, worker_param in zip(workers, worker_params):
            worker_param.data.copy_(model_params.data.clone().send(worker))

Epoch: 0	 Training Loss: 2.3044102986653647	Training Accuracy: 0.09583333333333334
Epoch: 1	 Training Loss: 2.3152875264485675	Training Accuracy: 0.11041666666666666
Epoch: 2	 Training Loss: 2.3010709126790365	Training Accuracy: 0.13958333333333334
Epoch: 3	 Training Loss: 2.299097951253255	Training Accuracy: 0.075
Epoch: 4	 Training Loss: 2.2824000040690104	Training Accuracy: 0.18333333333333332
Epoch: 5	 Training Loss: 2.247668711344401	Training Accuracy: 0.20833333333333334
Epoch: 6	 Training Loss: 2.1898790995279946	Training Accuracy: 0.20625
Epoch: 7	 Training Loss: 2.109412511189779	Training Accuracy: 0.2625
Epoch: 8	 Training Loss: 1.8952479044596353	Training Accuracy: 0.4270833333333333
Epoch: 9	 Training Loss: 1.6825749715169271	Training Accuracy: 0.41458333333333336
Epoch: 10	 Training Loss: 1.3339645385742187	Training Accuracy: 0.5791666666666667
Epoch: 11	 Training Loss: 1.1573478698730468	Training Accuracy: 0.6208333333333333
Epoch: 12	 Training Loss: 1.0275041580200195	Tr

In [13]:
test_loss = 0
batch_idx = 0
correct_preds  = 0
with torch.no_grad():
    for i, (imgs, targets) in enumerate(testloader, 1):
        batch_idx += imgs.size(0)

        output = model(imgs)

        test_loss += criterion(output, targets).item()
        correct_preds += (output.argmax(dim=1) == targets).sum().item()

print()
print("Test Loss:", test_loss / batch_idx)
print("Test Accuracy:", correct_preds / batch_idx)



Test Loss: 0.2203809341430664
Test Accuracy: 0.9342
