## CrypTen - Training an Encrypted Neural Network across Workers using Plans


We will train an encrypted neural network across different PySyft workers (deployed as [GridNodes](https://github.com/OpenMined/PyGridNode)), where each worker holds a subset of the features needed for training. We will use Plans to describe the computation and CrypTen as the backend for secure multi-party computation (SMPC).


Authors:
 - George Muraru - Twitter: [@gmuraru](https://twitter.com/georgemuraru)
 - Ayoub Benaissa - Twitter: [@y0uben11](https://twitter.com/y0uben11)

## Training Overview
* In this tutorial we will use the [MNIST dataset](http://yann.lecun.com/exdb/mnist/)
* The features needed for training the network are split across two workers (we will name them *alice* and *bob*)




## Setup

### Download/install needed repos
* Install PySyft from the [`crypten` branch](https://github.com/OpenMined/PySyft/tree/crypten).
* Clone the [GridNode repository](https://github.com/OpenMined/GridNode)
  * we need it because *alice* and *bob* run as two separate GridNodes

### Run the grid nodes
* In two separate terminals, run:
    * ```python -m gridnode --id alice --port 3000```
    * ```python -m gridnode --id bob --port 3001```

This starts two workers, *alice* and *bob*, which we will connect to on ports 3000 and 3001, respectively.
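
Before connecting from the notebook, it can help to confirm that both nodes are actually listening. The sketch below uses only the Python standard library; the helper name `port_is_open` is hypothetical and is not part of PySyft or GridNode.

```python
import socket


def port_is_open(host: str, port: int, timeout: float = 1.0) -> bool:
    """Return True if a TCP connection to (host, port) succeeds within `timeout` seconds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Connection refused, timed out, or host unreachable
        return False


if __name__ == "__main__":
    # Ports used by the alice and bob GridNodes in this tutorial
    for name, port in [("alice", 3000), ("bob", 3001)]:
        status = "up" if port_is_open("localhost", port) else "not reachable"
        print(f"{name} (port {port}): {status}")
```

A successful TCP handshake only shows that something is listening on the port; it does not check that the process is a GridNode.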


### Dataset preparation
* Run the cell below to download a script from the CrypTen repository and prepare the data
  * It splits the features between the workers
  * Each party gets only a subset of the features
  * Only 100 entries from the dataset are used
  * The task is binary classification (digit 0 vs digits 1-9)
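
The reduction and relabeling steps above can be illustrated in plain Python. This is a hypothetical sketch, not the code of `mnist_utils.py`: it assumes the first `n` samples are kept and that digit 0 maps to class 0 while digits 1-9 map to class 1.

```python
def reduce_and_binarize(images, labels, n=100):
    """Keep the first `n` samples and map labels to 0 (digit 0) vs 1 (digits 1-9).

    Hypothetical helper: mirrors the intent of `--reduced 100 --binary`,
    not the actual implementation of mnist_utils.py.
    """
    kept_images = list(images[:n])
    binary_labels = [0 if y == 0 else 1 for y in labels[:n]]
    return kept_images, binary_labels


# Made-up digit labels and blank 28x28 images, for illustration only
digits = [5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7]
fake_images = [[[0.0] * 28 for _ in range(28)] for _ in digits]

x, y = reduce_and_binarize(fake_images, digits, n=10)
print(len(x), y)  # 10 [1, 0, 1, 1, 1, 1, 1, 1, 1, 1]
```

With this mapping the classes are heavily imbalanced (roughly one image in ten is a 0), which is worth keeping in mind when reading the loss values later.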

In [1]:
!wget "https://raw.githubusercontent.com/facebookresearch/CrypTen/b1466440bde4db3e6e1fcb1740584d35a16eda9e/tutorials/mnist_utils.py" -O "mnist_utils.py"
!python "mnist_utils.py" --option features --reduced 100 --binary

--2020-07-06 01:57:21--  https://raw.githubusercontent.com/facebookresearch/CrypTen/b1466440bde4db3e6e1fcb1740584d35a16eda9e/tutorials/mnist_utils.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.112.133
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.112.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 7401 (7.2K) [text/plain]
Saving to: ‘mnist_utils.py’


2020-07-06 01:57:21 (2.59 MB/s) - ‘mnist_utils.py’ saved [7401/7401]



## Prepare the ground

In [2]:
import crypten

import torch
import torch.nn as nn
import torch.nn.functional as F
from time import time

import syft as sy

from syft.workers.node_client import NodeClient
from syft.frameworks.crypten.context import run_multiworkers

hook = sy.TorchHook(torch)

## Neural network to train

In [3]:
class ExampleNet(nn.Module):
    def __init__(self):
        super(ExampleNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=0)
        self.fc1 = nn.Linear(16 * 12 * 12, 100)
        self.fc2 = nn.Linear(100, 2)
 
    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(out)
        out = F.max_pool2d(out, 2)
        out = out.view(-1, 16 * 12 * 12)
        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)
        return out
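
The `16 * 12 * 12` input size of `fc1` follows from the standard output-size formula for convolution and pooling layers, out = floor((in + 2 * padding - kernel) / stride) + 1. A small sketch that checks the arithmetic for a 28x28 input (the helper name `out_size` is hypothetical):

```python
def out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a conv/pool layer along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28
side = out_size(side, kernel=5, padding=0)  # conv1: 28 -> 24
side = out_size(side, kernel=2, stride=2)   # max_pool2d(out, 2): 24 -> 12
flat_features = 16 * side * side            # conv1 produces 16 channels
print(side, flat_features)  # 12 2304
```

`F.max_pool2d(out, 2)` uses a stride equal to the kernel size by default, which is why the pooling step halves the spatial size.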

## Setup the workers and send them the data

Both GridNode workers must be running before you execute the next cell.

In [4]:
# Syft workers
print("[%] Connecting to workers ...")
ALICE = NodeClient(hook, "ws://localhost:3000")
BOB = NodeClient(hook, "ws://localhost:3001")
print("[+] Connected to workers")

print("[%] Sending labels and training data ...")

# Prepare the labels
label_eye = torch.eye(2)
labels = torch.load("/tmp/train_labels.pth")
labels = labels.long()
labels_one_hot = label_eye[labels]

# Prepare and send training data
alice_train = torch.load("/tmp/alice_train.pth").tag("alice_train")
alice_ptr = alice_train.send(ALICE)
bob_train = torch.load("/tmp/bob_train.pth").tag("bob_train")
bob_ptr = bob_train.send(BOB)

print("[+] Data ready")

[%] Connecting to workers ...
[+] Connected to workers
[%] Sending labels and training data ...
[+] Data ready
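
Indexing `torch.eye(2)` with the label tensor selects one row of the identity matrix per label, which is exactly a one-hot encoding. The one-hot form is needed because the training loop uses `MSELoss`, which compares the two-dimensional model output with a target of the same shape. The same idea in plain Python, as an illustration only:

```python
def one_hot(labels, num_classes):
    """Return one identity-matrix row per label (what `torch.eye(n)[labels]` computes)."""
    eye = [[1.0 if i == j else 0.0 for j in range(num_classes)] for i in range(num_classes)]
    # Copy each row so repeated labels do not share the same list object
    return [list(eye[label]) for label in labels]


print(one_hot([1, 0, 1], num_classes=2))  # [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
```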


## Check the data shape

Each MNIST image has 28x28 features (pixels). These features are split across our workers: each worker holds only some of the pixel columns of every image.

We can check this by running the next cell!

In [5]:
print(f"Alice data shape {alice_train.shape}")
print(f"Bob data shape {bob_train.shape}")

Alice data shape torch.Size([100, 28, 20])
Bob data shape torch.Size([100, 28, 8])
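
The shapes show a vertical split: alice holds 20 pixel columns of every image and bob the remaining 8. Concatenating along the last dimension (`dim=2` in the training function) therefore restores the 28x28 images, and `unsqueeze(1)` then adds the channel dimension expected by `conv1`. The order `[x_alice_enc, x_bob_enc]` in that concatenation presumably means alice holds the leftmost columns; the plain-Python sketch below assumes so and only illustrates the round trip.

```python
def split_columns(image, k):
    """Split a 2-D image (list of rows) into its first k columns and the remaining ones."""
    left = [row[:k] for row in image]
    right = [row[k:] for row in image]
    return left, right


def cat_columns(left, right):
    """Concatenate two row-aligned images along the column axis (like cat on the last dim)."""
    return [left_row + right_row for left_row, right_row in zip(left, right)]


image = [[r * 28 + c for c in range(28)] for r in range(28)]  # dummy 28x28 "image"
alice_part, bob_part = split_columns(image, 20)
print(len(alice_part), len(alice_part[0]))  # 28 20
print(len(bob_part), len(bob_part[0]))      # 28 8
print(cat_columns(alice_part, bob_part) == image)  # True
```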


## Initialize a dummy model

Instantiate a model and create a dummy input that can be forwarded through it. The dummy input is needed to build the CrypTen model from the PyTorch one.

In [6]:
dummy_input = torch.empty(1, 1, 28, 28)
pytorch_model = ExampleNet()

### Define the CrypTen computation

The ```run_multiworkers``` decorator needs:
* the workers that will take part in the computation
* the master address, which the workers use to synchronize
* the instantiated model that will be sent to the workers
* a dummy input for the model

The ```func2plan``` decorator is used to:
* trace the operations performed by our function
* send the plan operations to *alice* and *bob*, where they stand in for the function
* run the plan operations on both workers
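
Conceptually, a plan records the operations a function performs once and then replays that recorded list elsewhere, on real inputs. The toy tracer below illustrates only that idea; `Tracer`, `trace` and `replay` are hypothetical names, and this is not how `sy.func2plan` is implemented.

```python
import operator
from typing import Callable, List, Tuple


class Tracer:
    """Placeholder value that records every arithmetic operation applied to it."""

    def __init__(self, ops=None):
        self.ops: List[Tuple[Callable, object]] = ops if ops is not None else []

    def _record(self, fn, other):
        # Return a new tracer whose history includes this operation
        return Tracer(self.ops + [(fn, other)])

    def __add__(self, other):
        return self._record(operator.add, other)

    def __mul__(self, other):
        return self._record(operator.mul, other)


def trace(func):
    """Run `func` once on a placeholder and return the list of recorded operations."""
    return func(Tracer()).ops


def replay(ops, value):
    """Apply the recorded operations, in order, to a concrete input."""
    for fn, other in ops:
        value = fn(value, other)
    return value


def scale_and_shift(x):
    return x * 3 + 1


plan_ops = trace(scale_and_shift)  # traced once, no real data involved
print(replay(plan_ops, 2))         # 7
print(replay(plan_ops, 10))        # 31
```

Because only the recorded operations are shipped, the workers do not need the original Python function, just the list of steps to execute on their local data.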

In [7]:
@run_multiworkers(
    [ALICE, BOB], master_addr="127.0.0.1", model=pytorch_model, dummy_input=dummy_input
)
@sy.func2plan()
def run_encrypted_training(
    model=None,
    learning_rate=0.001,
    num_epochs=2,
    batch_size=10,
    num_batches=bob_ptr.shape[0]//10,
    labels_one_hot=labels_one_hot,
    crypten=crypten,
    torch=torch,
):
    x_alice_enc = crypten.load("alice_train", 0)
    x_bob_enc = crypten.load("bob_train", 1)

    x_combined_enc = crypten.cat([x_alice_enc, x_bob_enc], dim=2)
    x_combined_enc = x_combined_enc.unsqueeze(1)

    model.encrypt()
    model.train()
    loss = crypten.nn.MSELoss()

    l_values = []

    for i in range(num_epochs):
        for batch in range(num_batches):
            start, end = batch * batch_size, (batch + 1) * batch_size

            x_train = x_combined_enc[start:end]
            y_batch = labels_one_hot[start:end]
            y_train = crypten.cryptensor(y_batch, requires_grad=True)

            # perform forward pass:
            output = model(x_train)
            loss_value = loss(output, y_train)

            # set gradients to "zero"
            model.zero_grad()

            # perform backward pass:
            loss_value.backward()

            # update parameters
            model.update_parameters(learning_rate)

            # Record the decrypted loss of every batch:
            batch_loss = loss_value.get_plain_text()
            l_values.append(batch_loss)

    model.decrypt()
    return (l_values, model)
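
With 100 samples and `batch_size=10`, the inner loop visits 10 slices per epoch, so two epochs record 20 loss values in `l_values`. A sketch of the same slicing arithmetic (the helper name `batch_bounds` is hypothetical):

```python
def batch_bounds(num_samples, batch_size):
    """(start, end) index pairs for full batches, mirroring the training loop's slicing."""
    num_batches = num_samples // batch_size  # a trailing partial batch would be dropped
    return [(b * batch_size, (b + 1) * batch_size) for b in range(num_batches)]


bounds = batch_bounds(100, 10)
print(len(bounds), bounds[0], bounds[-1])  # 10 (0, 10) (90, 100)

num_epochs = 2
print(num_epochs * len(bounds))  # 20 loss values in total
```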

## Run the CrypTen computation

Now let's run the computation defined above.

In [8]:
# Get the returned values
# key 0 - return values for alice
# key 1 - return values for bob
print("[%] Starting computation")
func_ts = time()
*losses, model = run_encrypted_training()[0]
func_te = time()
print(f"[+] run_encrypted_training() took {int(func_te - func_ts)}s")

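num_epochs = 2  # must match the num_epochs default of run_encrypted_training
losses_per_epoch = len(losses) // num_epochs

for i in range(num_epochs):
    print(f"Epoch {i}:")
    epoch_losses = losses[i * losses_per_epoch:(i + 1) * losses_per_epoch]
    for batch, loss in enumerate(epoch_losses):
        print(f"\tBatch {batch + 1} of {losses_per_epoch} Loss: {loss:.4f}")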

[%] Starting computation




[+] run_encrypted_training() took 47s
Epoch 0:
	Batch 1 of 10 Loss: 0.4391
	Batch 2 of 10 Loss: 0.3641
	Batch 3 of 10 Loss: 0.3411
	Batch 4 of 10 Loss: 0.2911
	Batch 5 of 10 Loss: 0.2423
	Batch 6 of 10 Loss: 0.2896
	Batch 7 of 10 Loss: 0.3237
	Batch 8 of 10 Loss: 0.2433
	Batch 9 of 10 Loss: 0.2236
	Batch 10 of 10 Loss: 0.1933
Epoch 1:
	Batch 1 of 10 Loss: 0.1921
	Batch 2 of 10 Loss: 0.1288
	Batch 3 of 10 Loss: 0.1574
	Batch 4 of 10 Loss: 0.1878
	Batch 5 of 10 Loss: 0.0872
	Batch 6 of 10 Loss: 0.1861
	Batch 7 of 10 Loss: 0.2654
	Batch 8 of 10 Loss: 0.1442
	Batch 9 of 10 Loss: 0.1672
	Batch 10 of 10 Loss: 0.1152
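
Individual batch losses are noisy; averaging them per epoch makes the trend easier to read. A small sketch using the values printed above (already rounded to four decimals):

```python
from statistics import mean

batch_losses = [
    0.4391, 0.3641, 0.3411, 0.2911, 0.2423, 0.2896, 0.3237, 0.2433, 0.2236, 0.1933,  # epoch 0
    0.1921, 0.1288, 0.1574, 0.1878, 0.0872, 0.1861, 0.2654, 0.1442, 0.1672, 0.1152,  # epoch 1
]
num_epochs = 2
per_epoch = len(batch_losses) // num_epochs

epoch_means = [
    mean(batch_losses[e * per_epoch:(e + 1) * per_epoch]) for e in range(num_epochs)
]
for e, m in enumerate(epoch_means):
    print(f"Epoch {e}: mean loss {m:.4f}")
# Epoch 0: mean loss 0.2951
# Epoch 1: mean loss 0.1631
```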


The returned model is a CrypTen model. Since it has been decrypted, we can still apply the usual PySyft methods to it, for example converting it to fixed precision and secret-sharing its parameters between the workers.

In [10]:
cp = sy.VirtualWorker(hook=hook, id="cp")
model.fix_prec()
model.share(ALICE, BOB, crypto_provider=cp)
print(model)
print(list(model.parameters())[0])



Graph unencrypted module
(Wrapper)>FixedPrecisionTensor>[AdditiveSharingTensor]
	-> [PointerTensor | me:61929700610 -> alice:73613841655]
	-> [PointerTensor | me:43675353309 -> bob:86244224009]
	*crypto provider: cp*
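
`fix_prec()` encodes floating-point values as integers by scaling them, and `share()` splits each integer into random additive shares held by *alice* and *bob*, so neither worker alone learns the value. The sketch below illustrates both steps in plain Python. The scale factor (1000), the ring size (2**62) and the function names are illustrative assumptions, not PySyft internals, and the sketch leaves out the crypto provider, which supplies the correlated randomness needed for secure multiplication.

```python
import secrets

Q = 2 ** 62    # illustrative ring size (assumption)
SCALE = 1000   # illustrative fixed-point scale, 3 decimal digits (assumption)


def encode(x: float) -> int:
    """Fixed-point encode a float as an element of the ring Z_Q."""
    return round(x * SCALE) % Q


def decode(v: int) -> float:
    """Map a ring element back to a signed float."""
    if v > Q // 2:  # the upper half of the ring represents negative numbers
        v -= Q
    return v / SCALE


def share(v: int, n_parties: int = 2):
    """Split v into n random shares that sum to v modulo Q."""
    shares = [secrets.randbelow(Q) for _ in range(n_parties - 1)]
    shares.append((v - sum(shares)) % Q)
    return shares


def reconstruct(shares) -> int:
    return sum(shares) % Q


weight = -0.0427
alice_share, bob_share = share(encode(weight))
print(decode(reconstruct([alice_share, bob_share])))  # -0.043 (digits beyond 3 decimals are lost)

# Each party adds its own shares of two values locally; the results reconstruct to the sum.
a_shares = share(encode(1.5))
b_shares = share(encode(-0.25))
sum_shares = [(sa + sb) % Q for sa, sb in zip(a_shares, b_shares)]
print(decode(reconstruct(sum_shares)))  # 1.25
```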


## CleanUp

In [None]:
# CleanUp portion taken from the CrypTen project

import os

filenames = ['/tmp/alice_train.pth', 
             '/tmp/bob_train.pth', 
             '/tmp/alice_test.pth',
             '/tmp/bob_test.pth', 
             '/tmp/train_labels.pth',
             '/tmp/test_labels.pth',
             'mnist_utils.py']

for fn in filenames:
    if os.path.exists(fn):
        os.remove(fn)

# Congratulations!!! - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement toward privacy-preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!

### Star PySyft on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.

- [Star PySyft](https://github.com/OpenMined/PySyft)

### Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! You can do so by filling out the form at [http://slack.openmined.org](http://slack.openmined.org)

### Join a Code Project!

The best way to contribute to our community is to become a code contributor! At any time you can go to the PySyft GitHub Issues page and filter for "Projects". This will show you all the top-level tickets, giving an overview of what projects you can join! If you don't want to join a project, but you would like to do a bit of coding, you can also look for more "one off" mini-projects by searching for GitHub issues marked "good first issue".

- [PySyft Projects](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3AProject)
- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+first+issue+%3Amortar_board%3A%22)

### Donate

If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!

[OpenMined's Open Collective Page](https://opencollective.com/openmined)