<h1>Federated Learning - MNIST Example</h1>
<h2>Populate remote grid nodes with labeled tensors </h2>
In this notebook, we will populate our grid nodes with labeled data so that anyone who wants to train models can use it later.

**NOTE:** At the time of running this notebook, we were running the grid components in background mode.  

Components:
 - Grid Gateway (http://localhost:5000)
 - Grid Node Bob (http://localhost:3000)
 - Grid Node Alice (http://localhost:3001)
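
Because the rest of the notebook assumes these three components are already running, it can help to confirm that each address accepts connections before going further. The following is a minimal sketch using only the Python standard library. The helper names `host_and_port` and `is_reachable` are hypothetical and are not part of the grid or syft APIs. The sketch only checks that a TCP port is open, not that the service behind it is healthy.

```python
import socket
from urllib.parse import urlparse


def host_and_port(url):
    """Split a URL such as 'http://localhost:5000' into (host, port)."""
    parsed = urlparse(url)
    return parsed.hostname, parsed.port


def is_reachable(url, timeout=1.0):
    """Return True if a TCP connection to the URL's host and port succeeds."""
    host, port = host_and_port(url)
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers refused connections, timeouts, and name resolution failures.
        return False


# Addresses taken from the component list above.
COMPONENTS = {
    "Grid Gateway": "http://localhost:5000",
    "Grid Node Bob": "http://localhost:3000",
    "Grid Node Alice": "http://localhost:3001",
}

for name, url in COMPONENTS.items():
    status = "up" if is_reachable(url) else "not reachable"
    print(f"{name} ({url}): {status}")
```

If a component is reported as not reachable, start it before running the cells that connect to the nodes.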
 
This notebook is based on the <a href="https://github.com/OpenMined/PySyft/blob/dev/examples/tutorials/Part%2010%20-%20Federated%20Learning%20with%20Secure%20Aggregation.ipynb">Part 10: Federated Learning with Encrypted Gradient Aggregation</a> tutorial.

<h2>Import dependencies</h2>

In [None]:
import grid as gr
import syft as sy
import torch
import pickle
import time
import torchvision
from torchvision import datasets, transforms
import tqdm

<h2>Setup config</h2>
Initialize the hook and connect to the grid nodes.

In [None]:
hook = sy.TorchHook(torch)

# Connect directly to grid nodes
nodes = ["ws://localhost:3000/",
         "ws://localhost:3001/"]

compute_nodes = []
for node in nodes:
    compute_nodes.append(gr.WebsocketGridClient(hook, node))

## 1) Load Dataset

The code below loads and preprocesses the first `N_SAMPLES` MNIST training samples.

### Here you will load all the data at once and then divide it into equal parts among the grid nodes. If you would rather send the data one piece at a time, see the "Load Dataset (Split method)" section below.

In [None]:
N_SAMPLES = 10000
MNIST_PATH = './dataset'

transform = transforms.Compose([
                              transforms.ToTensor(),
                              transforms.Normalize((0.1307,), (0.3081,)),
                              ])

trainset = datasets.MNIST(MNIST_PATH, download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=N_SAMPLES, shuffle=False)

dataiter = iter(trainloader)

# The built-in next() works across PyTorch versions; the iterator's .next() method does not.
images_train_mnist, labels_train_mnist = next(dataiter)
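
To make the preprocessing concrete, here is a rough sketch of what the two transforms do to a single pixel. It assumes the raw MNIST images are 8-bit, so `ToTensor` scales values from 0-255 to 0-1 before `Normalize((0.1307,), (0.3081,))` subtracts the mean and divides by the standard deviation. The function name `normalize_pixel` is hypothetical and is used only for illustration.

```python
MNIST_MEAN = 0.1307  # same constant passed to transforms.Normalize above
MNIST_STD = 0.3081   # same constant passed to transforms.Normalize above


def normalize_pixel(raw_value):
    """Map a raw 0-255 pixel to the value the transform pipeline produces."""
    scaled = raw_value / 255.0                # scaling applied by ToTensor
    return (scaled - MNIST_MEAN) / MNIST_STD  # standardization applied by Normalize


print(round(normalize_pixel(0), 4))    # black background pixel
print(round(normalize_pixel(255), 4))  # fully white stroke pixel
```

Background pixels therefore end up slightly below zero and bright stroke pixels well above it.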

<h2>2) Split dataset </h2>
We now split the dataset into one chunk per node.

In [None]:
chunk_size = len(images_train_mnist) // len(compute_nodes)  # samples per node
datasets_mnist = torch.split(images_train_mnist, chunk_size, dim=0)  # tuple of image chunks (dataset / number of nodes)
labels_mnist = torch.split(labels_train_mnist, chunk_size, dim=0)  # tuple of label chunks (labels / number of nodes)
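
One detail worth checking is what happens when the sample count is not a multiple of the number of nodes. `torch.split` cuts a tensor into pieces of the requested size and leaves a smaller final piece for any remainder. The plain-Python sketch below mirrors that sizing rule without needing torch. The helper `split_sizes` is hypothetical.

```python
def split_sizes(n_items, split_size):
    """Sizes of the chunks produced when n_items are cut into pieces of split_size."""
    full, remainder = divmod(n_items, split_size)
    sizes = [split_size] * full
    if remainder:
        sizes.append(remainder)  # leftover items form one extra, smaller chunk
    return sizes


n_samples = 10000
for n_nodes in (2, 3):
    sizes = split_sizes(n_samples, n_samples // n_nodes)
    print(n_nodes, "nodes ->", sizes)
```

With two nodes the split is even. With three nodes, a fourth chunk containing a single sample appears, and code that indexes only the first `len(compute_nodes)` chunks would not send it.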

<h2>3) Tagging tensors</h2>
The code below will add a tag (of your choice) to the data that will be sent to grid nodes. This tag is important as the gateway will need it to retrieve this data later.

In [None]:
tag_img = []
tag_label = []


for i in range(len(compute_nodes)):
    tag_img.append(datasets_mnist[i].tag("#X", "#mnist", "#dataset").describe("The input datapoints to the MNIST dataset."))
    tag_label.append(labels_mnist[i].tag("#Y", "#mnist", "#dataset").describe("The input labels to the MNIST dataset."))
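
To illustrate why the tags matter, here is a toy model of tag-based lookup. It is not the PySyft or grid implementation: `TaggedStore` is a hypothetical in-memory stand-in for a node's object store, and it assumes a search matches only objects that carry every queried tag.

```python
class TaggedStore:
    """Hypothetical in-memory stand-in for a node's object store."""

    def __init__(self):
        self._items = []  # list of (tags, description, payload) tuples

    def add(self, payload, tags, description=""):
        self._items.append((set(tags), description, payload))

    def search(self, *query):
        """Return payloads whose tag set contains every queried tag."""
        wanted = set(query)
        return [payload for tags, _, payload in self._items if wanted <= tags]


bob = TaggedStore()
bob.add("image chunk 0", ["#X", "#mnist", "#dataset"],
        "The input datapoints to the MNIST dataset.")
bob.add("label chunk 0", ["#Y", "#mnist", "#dataset"],
        "The input labels to the MNIST dataset.")

print(bob.search("#X", "#mnist"))  # only the image chunk carries both tags
print(bob.search("#mnist"))        # both chunks carry this tag
```

Using a distinct tag for inputs (`#X`) and labels (`#Y`) alongside a shared dataset tag lets a later query pick out either tensor or both.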

<h2> 4) Sending our tensors to grid nodes</h2>

In [None]:
# NOTE: For some reason, sending within a loop behaves strangely.
# Ex: tag_img[i].send(compute_nodes[i])
# Update this cell once that is resolved.

shared_x1 = tag_img[0].send(compute_nodes[0], garbage_collect_data=False) # First chunk of dataset to Bob
shared_x2 = tag_img[1].send(compute_nodes[1], garbage_collect_data=False) # Second chunk of dataset to Alice

shared_y1 = tag_label[0].send(compute_nodes[0], garbage_collect_data=False) # First chunk of labels to Bob
shared_y2 = tag_label[1].send(compute_nodes[1], garbage_collect_data=False) # Second chunk of labels to Alice

In [None]:
print("X tensor pointers: ", shared_x1, shared_x2)
print("Y tensor pointers: ", shared_y1, shared_y2)

# Load Dataset (Split method)

Use the code below only as an alternative to steps 1, 2, 3, and 4. Do not run both.

### Here we send the MNIST training dataset to the workers one batch of `N_SAMPLES` at a time, alternating between workers.

In [None]:
N_SAMPLES = 500
MNIST_PATH = './data'

transform = transforms.Compose([
                              transforms.ToTensor(),
                              transforms.Normalize((0.1307,), (0.3081,)),
                              ])

trainset = datasets.MNIST(MNIST_PATH, download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=N_SAMPLES, shuffle=False)

dataiter = iter(trainloader)
n_workers = len(compute_nodes)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

for i, data in enumerate(dataiter):
    images_train_mnist, labels_train_mnist = data[0].to(device), data[1].to(device)
    images_train_mnist.tag("#X", "#mnist", "#dataset").describe("The input datapoints to the MNIST dataset.")
    labels_train_mnist.tag("#Y", "#mnist", "#dataset").describe("The input labels to the MNIST dataset.")
    images_train_mnist.send(compute_nodes[i % n_workers], garbage_collect_data=False)
    labels_train_mnist.send(compute_nodes[i % n_workers], garbage_collect_data=False)
    print("Sending data to:", compute_nodes[i % n_workers])
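
The `i % n_workers` index in the loop above distributes batches round-robin. The sketch below reproduces that assignment in plain Python to show how many batches each worker ends up with. It assumes the standard MNIST training set of 60000 samples and the two workers from this notebook. The helper `round_robin` is hypothetical.

```python
from collections import Counter


def round_robin(n_batches, workers):
    """Assign batch index i to workers[i % len(workers)], as in the loop above."""
    return [workers[i % len(workers)] for i in range(n_batches)]


workers = ["Bob", "Alice"]
n_batches = 60000 // 500  # MNIST training set size divided by N_SAMPLES
assignment = round_robin(n_batches, workers)
print(Counter(assignment))
```

Under these assumptions each worker receives the same number of batches. With a worker count that does not divide the batch count, the first workers in the list receive one batch more than the rest.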

<h2>Disconnect nodes</h2>

In [None]:
for compute_node in compute_nodes:
    compute_node.close()