<h1>Federated Learning - MNIST Example</h1>
<h2>Populate remote grid nodes with labeled tensors </h2>
In this notebook, we populate our grid nodes with labeled data so that it can be used later by people interested in training models.

**NOTE:** At the time of running this notebook, we were running the grid components in background mode.  

Components:
 - Grid Gateway (http://localhost:5000)
 - Grid Node Bob (http://localhost:3000)
 - Grid Node Alice (http://localhost:3001)
 
This notebook is based on the <a href="https://github.com/OpenMined/PySyft/blob/dev/examples/tutorials/Part%2010%20-%20Federated%20Learning%20with%20Secure%20Aggregation.ipynb">Part 10: Federated Learning with Encrypted Gradient Aggregation</a> tutorial.

<h2>Import dependencies</h2>

In [3]:
import syft as sy
from syft.workers.node_client import NodeClient
import torch
from torchvision import datasets, transforms

<h2>Set up configuration</h2>
Initialize the PySyft hook and connect to the grid nodes.

In [4]:
hook = sy.TorchHook(torch)  # extend torch tensors with PySyft functionality

# Connect directly to the grid nodes (Bob and Alice) over WebSockets
nodes = ["ws://localhost:3000/",
         "ws://localhost:3001/"]

compute_nodes = [NodeClient(hook, node) for node in nodes]
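The WebSocket addresses above reuse the host and port of each node listed under Components, with the `ws://` scheme instead of `http://`. A minimal sketch, assuming that convention holds for your deployment, derives them from the HTTP addresses; `to_ws_url` and `components` are hypothetical names, not part of PySyft.

```python
from urllib.parse import urlsplit, urlunsplit

# Hypothetical helper (not a PySyft API): map a component's HTTP address
# to the WebSocket address used when creating a NodeClient.
def to_ws_url(http_url: str) -> str:
    parts = urlsplit(http_url)
    if parts.scheme not in ("http", "https"):
        raise ValueError(f"expected an http(s) URL, got {http_url!r}")
    scheme = "wss" if parts.scheme == "https" else "ws"
    # Keep host and port, normalize the path to "/".
    return urlunsplit((scheme, parts.netloc, "/", "", ""))

# Assumed mapping, copied from the Components list
components = {
    "Bob": "http://localhost:3000",
    "Alice": "http://localhost:3001",
}
ws_nodes = [to_ws_url(url) for url in components.values()]
print(ws_nodes)  # ['ws://localhost:3000/', 'ws://localhost:3001/']
```

Keeping one source of truth for host and port avoids the two lists drifting apart when a node moves.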

<h2>1 - Load dataset</h2>

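The code below loads and preprocesses the first `N_SAMPLES` MNIST training samples. Because the loader uses `batch_size=N_SAMPLES` and `shuffle=False`, a single batch holds exactly those samples in dataset order.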

In [13]:
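N_SAMPLES = 10000
MNIST_PATH = './dataset'

transform = transforms.Compose([
    transforms.ToTensor(),                       # image -> float tensor in [0, 1]
    transforms.Normalize((0.1307,), (0.3081,)),  # standardize with MNIST mean/std
])

trainset = datasets.MNIST(MNIST_PATH, download=True, train=True, transform=transform)
# A single batch of N_SAMPLES; shuffle=False keeps the original dataset order
trainloader = torch.utils.data.DataLoader(trainset, batch_size=N_SAMPLES, shuffle=False)

# next(iterator) works on all PyTorch versions; iterator.next() fails on recent releases
images_train_mnist, labels_train_mnist = next(iter(trainloader))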
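As a worked example of the preprocessing: `ToTensor()` scales each 8-bit pixel from 0..255 to a float in [0, 1], and `Normalize((0.1307,), (0.3081,))` then computes `(x - mean) / std` per channel, using the commonly quoted MNIST mean and standard deviation. The sketch below applies the same arithmetic to a single pixel in plain Python; `normalize_pixel` is a hypothetical helper for illustration only.

```python
# Worked example of ToTensor() followed by Normalize((0.1307,), (0.3081,))
# applied to a single grayscale MNIST pixel.
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081

def normalize_pixel(raw: int) -> float:
    """Map a raw 8-bit pixel (0..255) to the normalized value the model sees."""
    scaled = raw / 255.0                        # ToTensor: 0..255 -> [0, 1]
    return (scaled - MNIST_MEAN) / MNIST_STD    # Normalize: (x - mean) / std

print(round(normalize_pixel(0), 4))    # background pixel -> -0.4242
print(round(normalize_pixel(255), 4))  # full-intensity stroke -> 2.8215
```

So the tensors sent to the nodes hold values roughly in [-0.42, 2.82], not raw pixel intensities; anyone training on this data must use the same preprocessing at inference time.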

<h2>2 - Split dataset</h2>
We split the dataset into one equal-sized chunk per node before sending it.

In [6]:
chunk_size = len(images_train_mnist) // len(compute_nodes)  # samples per node
datasets_mnist = torch.split(images_train_mnist, chunk_size, dim=0)  # tuple of image chunks
labels_mnist = torch.split(labels_train_mnist, chunk_size, dim=0)    # tuple of label chunks
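`torch.split` with an integer split size returns full chunks of that size plus one smaller trailing chunk when the length is not evenly divisible. With 10000 samples and 2 nodes the split is exact, but with other node counts the tagging and sending cells only use the first `len(compute_nodes)` chunks, so the leftover samples would never leave the local machine. The pure-Python sketch below reproduces those chunk sizes; `chunk_sizes` is a hypothetical helper, not a PyTorch function.

```python
from typing import List

def chunk_sizes(total: int, n_nodes: int) -> List[int]:
    """Hypothetical helper: sizes torch.split(t, total // n_nodes) would produce."""
    if n_nodes <= 0 or total < n_nodes:
        raise ValueError("need at least one sample per node")
    split_size = total // n_nodes                # chunk size used in the cell above
    full, remainder = divmod(total, split_size)  # full chunks + leftover samples
    return [split_size] * full + ([remainder] if remainder else [])

print(chunk_sizes(10000, 2))  # [5000, 5000]
print(chunk_sizes(10001, 2))  # [5000, 5000, 1]  (last chunk is never sent)
print(chunk_sizes(10000, 3))  # [3333, 3333, 3333, 1]
```

If exactly one chunk per node is required, `torch.tensor_split(images_train_mnist, len(compute_nodes))` returns that many chunks whose sizes differ by at most one.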

<h2>3 - Tagging tensors</h2>
The code below adds tags (of your choice) and a description to each chunk before it is sent to the grid nodes. The tags are important because the gateway uses them later to find and retrieve this data.

In [7]:
tag_img = []
tag_label = []


for i in range(len(compute_nodes)):
    tag_img.append(datasets_mnist[i].tag("#X", "#mnist", "#dataset").describe("The input datapoints to the MNIST dataset."))
    tag_label.append(labels_mnist[i].tag("#Y", "#mnist", "#dataset").describe("The input labels to the MNIST dataset."))
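To see why consistent tags matter, here is a minimal in-memory sketch of tag-based lookup, assuming a search returns every stored tensor that carries all of the queried tags. `StoredTensor`, `TagRegistry` and `search` are hypothetical names; this is not how PySyft or the grid gateway implement search, and their matching rules may differ.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Set

# Hypothetical, simplified stand-in for a tagged tensor stored on a node.
@dataclass
class StoredTensor:
    node: str
    tags: Set[str]
    description: str

# Hypothetical registry: matches entries that carry every tag in the query.
@dataclass
class TagRegistry:
    entries: Dict[str, StoredTensor] = field(default_factory=dict)

    def add(self, key: str, entry: StoredTensor) -> None:
        self.entries[key] = entry

    def search(self, *query: str) -> List[str]:
        """Return sorted keys of entries whose tags include all query tags."""
        wanted = set(query)
        return sorted(k for k, e in self.entries.items() if wanted <= e.tags)

registry = TagRegistry()
for name in ("Bob", "Alice"):
    registry.add(f"{name}/X", StoredTensor(name, {"#X", "#mnist", "#dataset"},
                                           "The input datapoints to the MNIST dataset."))
    registry.add(f"{name}/Y", StoredTensor(name, {"#Y", "#mnist", "#dataset"},
                                           "The input labels to the MNIST dataset."))

print(registry.search("#X", "#mnist"))  # ['Alice/X', 'Bob/X']
print(registry.search("#mnist"))        # all four chunks
```

A typo in a tag (for example `#minst`) would make a chunk invisible to any search that uses the correct tag, which is why the same tags are applied to every chunk in the loop above.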

<h2>4 - Sending our tensors to grid nodes</h2>

In [8]:
shared_x1 = tag_img[0].send(compute_nodes[0]) # First chunk of dataset to Bob
shared_x2 = tag_img[1].send(compute_nodes[1]) # Second chunk of dataset to Alice

shared_y1 = tag_label[0].send(compute_nodes[0]) # First chunk of labels to Bob
shared_y2 = tag_label[1].send(compute_nodes[1]) # Second chunk of labels to Alice

In [9]:
print("X tensor pointers: ", shared_x1, shared_x2)
print("Y tensor pointers: ", shared_y1, shared_y2)

X tensor pointers:  (Wrapper)>[PointerTensor | me:48155914346 -> Bob:84790492900]
	Tags: #mnist #X #dataset 
	Shape: torch.Size([5000, 1, 28, 28])
	Description: The input datapoints to the MNIST dataset.... (Wrapper)>[PointerTensor | me:14916691261 -> Alice:53510693506]
	Tags: #mnist #X #dataset 
	Shape: torch.Size([5000, 1, 28, 28])
	Description: The input datapoints to the MNIST dataset....
Y tensor pointers:  (Wrapper)>[PointerTensor | me:34933432192 -> Bob:96713703699]
	Tags: #mnist #Y #dataset 
	Shape: torch.Size([5000])
	Description: The input labels to the MNIST dataset.... (Wrapper)>[PointerTensor | me:47811294626 -> Alice:47171429524]
	Tags: #mnist #Y #dataset 
	Shape: torch.Size([5000])
	Description: The input labels to the MNIST dataset....
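After `send()`, the tensor data lives on the remote node and we keep only a pointer. Each printed pointer has the form `me:<id> -> <worker>:<id>`: an id on our side and the id of the real tensor on the worker that holds it. The sketch below models just that information as a plain dataclass to make the output easier to read; `PointerInfo` and its fields are hypothetical and are not PySyft's `PointerTensor` class.

```python
from dataclasses import dataclass

# Hypothetical model of the fields visible in each printed pointer above.
@dataclass(frozen=True)
class PointerInfo:
    local_id: int        # id of the pointer object on our side ("me")
    location: str        # worker that holds the real tensor
    id_at_location: int  # id of the real tensor on that worker

    def __str__(self) -> str:
        return (f"[PointerTensor | me:{self.local_id} -> "
                f"{self.location}:{self.id_at_location}]")

# Values copied from the first X pointer printed above
ptr = PointerInfo(48155914346, "Bob", 84790492900)
print(ptr)  # [PointerTensor | me:48155914346 -> Bob:84790492900]
```

The shapes in the output confirm the split: each node holds 5000 images of shape `[1, 28, 28]` and 5000 labels.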


<h2>Disconnect nodes</h2>

In [10]:
# Close the WebSocket connection to each node
for node in compute_nodes:
    node.close()
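If a cell fails between connecting and this final loop, the connections stay open. A minimal sketch, assuming only that each client exposes a `close()` method (as used above), registers every client with `contextlib.ExitStack` so all of them are closed even when an error occurs. `FakeNodeClient` is a hypothetical stand-in so the pattern runs without a grid.

```python
from contextlib import ExitStack, closing

# Hypothetical stand-in for NodeClient; only close() matters here.
class FakeNodeClient:
    def __init__(self, address: str):
        self.address = address
        self.closed = False

    def close(self) -> None:
        self.closed = True

addresses = ["ws://localhost:3000/", "ws://localhost:3001/"]
clients = []
try:
    with ExitStack() as stack:
        for address in addresses:
            # closing() schedules client.close() for when the with-block exits
            clients.append(stack.enter_context(closing(FakeNodeClient(address))))
        raise RuntimeError("simulated failure while sending data")
except RuntimeError:
    pass

print([c.closed for c in clients])  # [True, True]
```

In the notebook, `FakeNodeClient(address)` would be replaced by `NodeClient(hook, address)` and the tagging and sending steps would run inside the `with` block.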