<h1>Federated Learning - MNIST Example</h1>
<h2>Create a PyGrid cluster (1 Network + N nodes)</h2>
<h2>Populate remote PyGrid nodes with labeled tensors</h2>
In this notebook, we will populate our PyGrid nodes with labeled data so that people interested in training models can use it later.

**NOTE:** Running this notebook creates a cluster using the auto-scale API.

Components:
 - PyGrid Network 
 - 3 PyGrid Nodes
 
This notebook is based on the <a href="https://github.com/OpenMined/PySyft/blob/dev/examples/tutorials/Part%2010%20-%20Federated%20Learning%20with%20Secure%20Aggregation.ipynb">Part 10: Federated Learning with Encrypted Gradient Aggregation</a> tutorial.

<h2>Import dependencies</h2>

In [None]:
import syft as sy
from syft.grid.clients.dynamic_fl_client import DynamicFLClient
import syft.grid.autoscale.gcloud as gcloud
import syft.grid.autoscale.utils.gcloud_configurations as configs
import torch
import pickle
import time
import torchvision
from torchvision import datasets, transforms
import tqdm

<h2>Setup auto-scale API using GCP</h2>

Pass:
 - Path of credentials.json file
 - Project ID
 - Region of the project

In [None]:
NEW = gcloud.GoogleCloud(
    credentials="/usr/terraform.json", project_id="project", region=configs.Region.us_central1,
)

<h2>Spin-up a PyGrid Cluster</h2>

 - Reserve an IP
 - Spin-up a Cluster

In [None]:
# To create a cluster, we first need to reserve an external IP
NEW.reserve_ip("grid")

In [None]:
c1 = NEW.create_cluster(
    name="tutorial",
    machine_type=configs.MachineType.f1_micro,
    zone=configs.Zone.us_central1_a,
    reserve_ip_name="grid",
    target_size=3,
    eviction_policy="delete",
)

<h2>Setup config</h2>

 - Init hook, connect with grid nodes, etc...

 - Open your GCP console

 - Insert the external IPs of two nodes into the `nodes` list below

In [None]:
hook = sy.TorchHook(torch)

# Connect directly to grid nodes
nodes = ["xxxxxxxxxx",
         "xxxxxxxxxx"]

compute_nodes = []
for node in nodes:
    compute_nodes.append( DynamicFLClient(hook, node) )

<h2>1 - Load Dataset</h2>

The code below loads and preprocesses the first N_SAMPLES samples of the MNIST training set.

In [None]:
N_SAMPLES = 10000
MNIST_PATH = './dataset'

transform = transforms.Compose([
                              transforms.ToTensor(),
                              transforms.Normalize((0.1307,), (0.3081,)),
                              ])

trainset = datasets.MNIST(MNIST_PATH, download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=N_SAMPLES, shuffle=False)

dataiter = iter(trainloader)

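# Use the built-in next(); the iterator's .next() method is not available in recent PyTorch releases
images_train_mnist, labels_train_mnist = next(dataiter)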
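
To make the preprocessing step concrete, here is a minimal sketch of the arithmetic that `transforms.Normalize((0.1307,), (0.3081,))` applies to each pixel after `ToTensor` has scaled it to the range [0, 1]. The helper name `normalize_pixel` is hypothetical and only used for illustration; it is not part of torchvision.

```python
# Constants copied from the transform used in this notebook.
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalize_pixel(value, mean=MNIST_MEAN, std=MNIST_STD):
    """Hypothetical helper: the per-pixel formula (value - mean) / std."""
    return (value - mean) / std


# ToTensor scales pixel values to [0, 1], so these are the two extremes.
darkest = normalize_pixel(0.0)
brightest = normalize_pixel(1.0)

print(f"black pixel -> {darkest:.4f}")
print(f"white pixel -> {brightest:.4f}")
```

A pixel equal to the mean maps to 0, so the normalized images are roughly centered around zero.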

<h2>2 - Split dataset </h2>
We will split the dataset into chunks so that each node receives one.

In [None]:
datasets_mnist = torch.split(images_train_mnist, int(len(images_train_mnist) / len(compute_nodes)), dim=0 ) #tuple of chunks (dataset / number of nodes)
labels_mnist = torch.split(labels_train_mnist, int(len(labels_train_mnist) / len(compute_nodes)), dim=0 )  #tuple of chunks (labels / number of nodes)
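
As a rough check of what this split produces, the sketch below reproduces the chunk-size rule of `torch.split` in plain Python: every chunk has the requested size except possibly the last one, which holds the remainder. The function names are hypothetical helpers for illustration, and the three-node case is an assumption added to show what happens when the sample count is not divisible by the number of nodes.

```python
def split_sizes(total, split_size):
    """Chunk lengths that torch.split(tensor, split_size) yields for `total` rows."""
    if split_size <= 0:
        raise ValueError("split_size must be positive")
    full, remainder = divmod(total, split_size)
    # Full-size chunks first, then one smaller chunk if rows are left over.
    return [split_size] * full + ([remainder] if remainder else [])


def chunk_sizes_for_nodes(n_samples, n_nodes):
    """Mirror the notebook's call: split size is int(n_samples / n_nodes)."""
    return split_sizes(n_samples, int(n_samples / n_nodes))


print(chunk_sizes_for_nodes(10000, 2))  # [5000, 5000]
print(chunk_sizes_for_nodes(10000, 3))  # [3333, 3333, 3333, 1]
```

With two nodes each node gets exactly one chunk. With three nodes the split yields four chunks, so the extra one-sample chunk would have no node to go to; in that case `torch.chunk`, which takes the number of chunks rather than a chunk size, may be a better fit.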

<h2>3 - Tagging tensors</h2>
The code below will add a tag (of your choice) to the data that will be sent to grid nodes. This tag is important as the network will need it to retrieve this data later.

In [None]:
tag_img = []
tag_label = []


for i in range(len(compute_nodes)):
    tag_img.append(datasets_mnist[i].tag("#X", "#mnist", "#dataset").describe("The input datapoints to the MNIST dataset."))
    tag_label.append(labels_mnist[i].tag("#Y", "#mnist", "#dataset").describe("The input labels to the MNIST dataset."))
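
To illustrate why tags matter for later retrieval, here is a toy, in-memory sketch of tag-based lookup. It is not how PyGrid stores or searches tensors; the class `ToyTagIndex` and the node names `node-0` and `node-1` are hypothetical and exist only to show the idea that a query returns the entries carrying every requested tag.

```python
from collections import defaultdict


class ToyTagIndex:
    """Toy stand-in for a tag index (an assumption, not PyGrid's implementation)."""

    def __init__(self):
        # node name -> list of (tags, description) pairs
        self._entries = defaultdict(list)

    def register(self, node, tags, description):
        self._entries[node].append((frozenset(tags), description))

    def search(self, *query_tags):
        """Return {node: [descriptions]} for entries that carry every query tag."""
        wanted = set(query_tags)
        hits = {}
        for node, entries in self._entries.items():
            matched = [desc for tags, desc in entries if wanted <= tags]
            if matched:
                hits[node] = matched
        return hits


index = ToyTagIndex()
for node in ("node-0", "node-1"):
    # Same tags and descriptions as in the notebook cell above.
    index.register(node, {"#X", "#mnist", "#dataset"},
                   "The input datapoints to the MNIST dataset.")
    index.register(node, {"#Y", "#mnist", "#dataset"},
                   "The input labels to the MNIST dataset.")

images = index.search("#X", "#mnist")
print(images)
```

Searching for `#X` and `#mnist` finds the image tensors on both toy nodes, while a tag nobody used returns nothing, which is why consistent tagging across nodes is worth the care.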

<h2>4 - Sending our tensors to grid nodes</h2>

In [None]:
shared_x1 = tag_img[0].send(compute_nodes[0]) # First chunk of the dataset to the first node
shared_x2 = tag_img[1].send(compute_nodes[1]) # Second chunk of the dataset to the second node

shared_y1 = tag_label[0].send(compute_nodes[0]) # First chunk of the labels to the first node
shared_y2 = tag_label[1].send(compute_nodes[1]) # Second chunk of the labels to the second node

In [None]:
print("X tensor pointers: ", shared_x1, shared_x2)
print("Y tensor pointers: ", shared_y1, shared_y2)

<h2>Disconnect nodes</h2>

In [None]:
for node in compute_nodes:
    node.close()