<a href="https://colab.research.google.com/github/ArianaMarta/small_pysyft/blob/main/With_RPIs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install syft==0.2.9

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim 
from torchvision import datasets, transforms

In [None]:
# import the syft library and the websocket client used to reach remote devices
import syft as sy
from syft.workers.websocket_client import WebsocketClientWorker

In [None]:
## THE OTHER NOTEBOOK ##
# Simulated setup: both "schools" are in-process virtual workers, so the rest
# of the notebook can be run without any Raspberry Pi attached.

# hook torch with syft to add extra functionalities
# to support Federated Learning and other private AI tools
hook = sy.TorchHook(torch)

# create two virtual workers, in this case two schools for example
# they will hold the data while training the model locally
aalto_school = sy.VirtualWorker(hook, id="aalto")
tampere_school = sy.VirtualWorker(hook, id="tampere")

# The training and testing cells iterate over `virtual_workers`, so it must be
# defined on this path as well (otherwise they fail with a NameError).
virtual_workers = [aalto_school, tampere_school]

In [None]:
# Raspberry Pi setup: each school is reached through a websocket client worker.
# WebsocketClientWorker comes from syft.workers.websocket_client, imported
# next to `import syft as sy` at the top of the notebook.

# Hook PyTorch (a single hook is enough; it is shared by both workers)
hook = sy.TorchHook(torch)

# NOTE: "ip_aalto" / "ip_tampere" are placeholders - replace them with the
# actual addresses of the devices before running this cell.
kwargs_websocket_aalto = {"host": "ip_aalto", "hook": hook}
aalto_school = WebsocketClientWorker(id="aalto", port=8777, **kwargs_websocket_aalto)

kwargs_websocket_tampere = {"host": "ip_tampere", "hook": hook}
tampere_school = WebsocketClientWorker(id="tampere", port=8778, **kwargs_websocket_tampere)

# Workers that the training and testing cells iterate over.
virtual_workers = [aalto_school, tampere_school]

In [None]:
# Load MNIST: the training split is distributed over the workers,
# the test split stays on this machine.

# Convert each image to a tensor, then normalize it with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

train_set = datasets.MNIST(root="../data", train=True, download=True, transform=transform)
test_set = datasets.MNIST(root="../data", train=False, download=True, transform=transform)

# .federate() splits the training data between the given workers;
# FederatedDataLoader then yields batches that live on one of those workers.
federated_train_loader = sy.FederatedDataLoader(
    train_set.federate((aalto_school, tampere_school)),
    batch_size=64,
    shuffle=True,
)

# The test data is not federated, so a regular PyTorch DataLoader is used.
test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=64, shuffle=True)

In [None]:
# Now lets create a simple Net

class Net(nn.Module):
    """Fully connected classifier with one hidden layer for 28x28 images."""

    def __init__(self):
        super().__init__()
        # Hidden layer: flattened image (28 * 28 = 784 pixel values) -> 500 features.
        self.fc1 = nn.Linear(784, 500)
        # Output layer: 500 features -> 10 class scores, one per digit (0-9).
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        """Return log-probabilities over the 10 classes, one row per image."""
        flat = x.view(-1, 784)
        hidden = F.relu(self.fc1(flat))
        return F.log_softmax(self.fc2(hidden), dim=1)

In [None]:
# Train one copy of the model per worker, each on that worker's share of the data.
model_ori = Net()
learning_rate = 0.01
n_epoch = 10

model_virtual = {}   # worker id -> model whose parameters live on that worker
optimizers = {}      # worker id -> SGD optimizer driving that worker's model

for worker in virtual_workers:
    model_copied = model_ori.copy()
    # Create the optimizer BEFORE sending the model: as in the PySyft 0.2.x
    # tutorials, .send() turns the parameters into remote pointers in place,
    # so the optimizer keeps updating the (now remote) parameters.
    optimizers[worker.id] = optim.SGD(model_copied.parameters(), lr=learning_rate)
    model_virtual[worker.id] = model_copied.send(worker)

for epoch in range(n_epoch):
    for batch_index, (data, target) in enumerate(federated_train_loader):
        # Each batch lives on exactly one worker; train that worker's model on it.
        worker_id = data.location.id
        model_ptr = model_virtual[worker_id]
        optimizer = optimizers[worker_id]

        optimizer.zero_grad()
        output = model_ptr(data)

        # this loss is a pointer to the loss tensor at the remote location
        loss = F.nll_loss(output, target)
        # backward() on the pointer runs backward on the remote loss tensor
        loss.backward()
        # SGD step (param -= lr * grad) applied to the remote parameters;
        # the model stays on the worker, so no per-batch get()/send() round trip
        optimizer.step()

        if batch_index % 100 == 0:
            # the loss was created on the remote worker,
            # so it has to be fetched explicitly before it can be printed
            loss = loss.get()
            # 64 is the batch size configured for federated_train_loader
            print('Training Epoch: {:2d} [{:5d}/{:5d} ({:3.0f}%)]\tLoss: {:.6f}'.format(
                epoch+1, batch_index * 64,
                len(federated_train_loader) * 64,
                100. * batch_index / len(federated_train_loader), loss.item()))



In [None]:
# now for testing
# Evaluate every worker's model on the complete (local) test set and report
# the average loss and accuracy per worker.

test_losses_dic = {}    # worker id -> average loss per test sample
test_correct_dic = {}   # worker id -> number of correctly classified samples
n_test = len(test_loader.dataset)

with torch.no_grad():
    # test the model held by each worker
    for worker in virtual_workers:
        model_ptr = model_virtual[worker.id]
        test_losses = 0
        correct = 0
        # every worker's model is evaluated on the same test set
        for data, target in test_loader:
            # The model parameters live on the worker, so the batch is sent
            # there and only the scalar results are fetched back.
            data_ptr = data.send(worker)
            target_ptr = target.send(worker)
            output = model_ptr(data_ptr)
            # add losses together
            test_losses += F.nll_loss(
                output, target_ptr, reduction='sum').get().item()
            # index of the max log-probability class
            pred = output.argmax(1, keepdim=True)
            correct += pred.eq(target_ptr.view_as(pred)).sum().get().item()
        # store the results per worker (inside the worker loop)
        test_losses_dic[worker.id] = test_losses / n_test
        test_correct_dic[worker.id] = correct

for worker in virtual_workers:
    print('\nTest set for worker: {} Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        worker.id,
        test_losses_dic[worker.id],
        test_correct_dic[worker.id],
        n_test,
        100. * test_correct_dic[worker.id] / n_test))