## CrypTen - Private Model Inference using Plans


We will run a simple inference on an encrypted neural network that is not known by the local worker.
The workers that know the model structure are deployed as [GridNodes](https://github.com/OpenMined/PyGridNode). For this we will use Plans, with CrypTen as the backend for SMPC (secure multi-party computation).


Authors:
 - George Muraru - Twitter: [@gmuraru](https://twitter.com/georgemuraru)
 - Ayoub Benaissa - Twitter: [@y0uben11](https://twitter.com/y0uben11)
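CrypTen implements SMPC with secret sharing. Each private value is split into shares that look random on their own, one per party, and only the sum of all shares reveals the value. The sketch below illustrates the idea with additive sharing of fixed-point numbers over integers modulo 2^64. The ring size, the 16 fractional bits and all helper names are assumptions chosen for illustration; this is not CrypTen's API.

```python
from __future__ import annotations

import random

# Assumptions for illustration: a 64-bit ring and 16 fractional bits.
RING = 2 ** 64
SCALE = 2 ** 16


def encode(x: float) -> int:
    """Fixed-point encode a float as an element of the ring."""
    return round(x * SCALE) % RING


def decode(v: int) -> float:
    """Map a ring element back to a signed float."""
    if v >= RING // 2:  # the upper half of the ring represents negatives
        v -= RING
    return v / SCALE


def share(x: float, n_parties: int = 2) -> list[int]:
    """Split x into n additive shares that sum to encode(x) mod RING."""
    shares = [random.randrange(RING) for _ in range(n_parties - 1)]
    last = (encode(x) - sum(shares)) % RING
    return shares + [last]


def reconstruct(shares: list[int]) -> float:
    """Only the sum of *all* shares reveals the secret."""
    return decode(sum(shares) % RING)


def add_shared(a: list[int], b: list[int]) -> list[int]:
    """Each party adds its own shares locally; no communication needed."""
    return [(x + y) % RING for x, y in zip(a, b)]


shares = share(-1.5)  # e.g. [share held by alice, share held by bob]
print(reconstruct(shares))                                # -1.5
print(reconstruct(add_shared(share(2.25), share(-1.5))))  # 0.75
```

Addition of two shared values needs no communication, because each party simply adds the shares it holds; multiplication is harder and needs extra protocol machinery.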

## Inference Overview
* In this tutorial we will use a subset of the [MNIST dataset](http://yann.lecun.com/exdb/mnist/).
* The pre-trained model will be hosted on another worker.
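MNIST images are 28 × 28 grayscale pixels, so each image flattens to a 784-element vector, which matches the 784 inputs of the first layer of the network used later. The plain-Python sketch below mirrors what `flatten(start_dim=1)` does to a batch of 2-D images; the `flatten_batch` helper is hypothetical and only illustrates the shape change.

```python
def flatten_batch(batch):
    """Flatten a batch of 2-D images (N x H x W nested lists) to N x (H*W).

    Mirrors tensor.flatten(start_dim=1): the batch dimension is kept and
    every later dimension is merged in row-major order.
    """
    return [[pixel for row in image for pixel in row] for image in batch]


# A dummy batch of two 28 x 28 "images" filled with zeros.
batch = [[[0.0] * 28 for _ in range(28)] for _ in range(2)]
flat = flatten_batch(batch)
print(len(flat), len(flat[0]))  # 2 784
```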

## Setup

### Download/install needed repos
* Install PySyft from the [crypten branch](https://github.com/OpenMined/PySyft/tree/crypten).
* Clone the [GridNode repository](https://github.com/OpenMined/GridNode).
  * We need this because *alice* and *bob* are two different GridNodes.

### Run the grid nodes
* In two separate terminals run:
    * ```python -m gridnode --id alice --port 3000```
    * ```python -m gridnode --id bob --port 3001```
    
This starts two workers, *alice* and *bob*, which we will connect to on ports 3000 and 3001.
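To avoid juggling two terminals, the two commands above can also be started from one Python script. This is a sketch that assumes the `gridnode` module is importable in the current environment and accepts exactly the `--id`/`--port` flags shown above; `gridnode_command` and `launch_nodes` are hypothetical helpers.

```python
from __future__ import annotations

import subprocess
import sys

# Node ids and ports taken from the commands above.
NODES = {"alice": 3000, "bob": 3001}


def gridnode_command(node_id: str, port: int) -> list[str]:
    """Build the argv for `python -m gridnode --id <node_id> --port <port>`."""
    return [sys.executable, "-m", "gridnode", "--id", node_id, "--port", str(port)]


def launch_nodes(nodes: dict[str, int] = NODES) -> list[subprocess.Popen]:
    """Start one background process per GridNode and return the handles."""
    return [subprocess.Popen(gridnode_command(i, p)) for i, p in nodes.items()]
```

Calling `procs = launch_nodes()` starts both nodes in the background; call `proc.terminate()` on each handle once the tutorial is finished.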

In [20]:
!wget "https://raw.githubusercontent.com/facebookresearch/CrypTen/b1466440bde4db3e6e1fcb1740584d35a16eda9e/tutorials/mnist_utils.py" -O "mnist_utils.py"
!wget "https://github.com/facebookresearch/CrypTen/blob/master/tutorials/models/tutorial4_alice_model.pth?raw=true" -O "alice_pretrained_model.pth"
!python ./mnist_utils.py --option train_v_test

--2020-07-08 00:13:23--  https://raw.githubusercontent.com/facebookresearch/CrypTen/b1466440bde4db3e6e1fcb1740584d35a16eda9e/tutorials/mnist_utils.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.112.133
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.112.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 7401 (7.2K) [text/plain]
Saving to: ‘mnist_utils.py’


2020-07-08 00:13:23 (3.97 MB/s) - ‘mnist_utils.py’ saved [7401/7401]

--2020-07-08 00:13:23--  https://github.com/facebookresearch/CrypTen/blob/master/tutorials/models/tutorial4_alice_model.pth?raw=true
Resolving github.com (github.com)... 140.82.118.3
Connecting to github.com (github.com)|140.82.118.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://github.com/facebookresearch/CrypTen/raw/master/tutorials/models/tutorial4_alice_model.pth [following]
--2020-07-08 00:13:23--  https://github.com/facebookresearch/
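Later cells load the test split from `/tmp` and the pretrained model from the working directory, so it can help to confirm those files exist before connecting to the workers. A small sketch, where the file list is the one this notebook later loads and `missing_files` is a hypothetical helper:

```python
from pathlib import Path

# Files that later cells in this notebook load.
REQUIRED = [
    "/tmp/bob_test.pth",
    "/tmp/bob_test_labels.pth",
    "alice_pretrained_model.pth",
]


def missing_files(paths):
    """Return the subset of `paths` that is not an existing regular file."""
    return [p for p in paths if not Path(p).is_file()]


missing = missing_files(REQUIRED)
if missing:
    print("Missing files, re-run the download/preparation cell:", missing)
else:
    print("All input files are present")
```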

## Prepare the ground

In [21]:
import crypten

import torch
import torch.nn as nn
import torch.nn.functional as F

import syft as sy

from syft.frameworks.crypten.model import OnnxModel
from syft.workers.node_client import NodeClient
from syft.frameworks.crypten.context import run_multiworkers

hook = sy.TorchHook(torch)



## Neural network that will be known only to the workers

*Alice* has the pre-trained version of the model.

In [22]:
class AliceNet(nn.Module):
    def __init__(self):
        super(AliceNet, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)
 
    def forward(self, x):
        out = self.fc1(x)
        out = F.relu(out)
        out = self.fc2(out)
        out = F.relu(out)
        out = self.fc3(out)
        return out
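Every weight and bias of `AliceNet` belongs to Alice, and all of them are encrypted when the model is encrypted. A quick count of how many values that is, as a plain-Python sketch; `count_mlp_params` is a hypothetical helper that assumes a stack of fully connected layers with biases, as in `AliceNet`.

```python
def count_mlp_params(layer_sizes):
    """Count weights + biases of a stack of nn.Linear layers.

    layer_sizes lists the width of every layer, input first,
    e.g. [784, 128, 128, 10] for AliceNet.
    """
    return sum(
        n_in * n_out + n_out  # weight matrix + bias vector
        for n_in, n_out in zip(layer_sizes, layer_sizes[1:])
    )


print(count_mlp_params([784, 128, 128, 10]))  # 118282
```

In PyTorch the same number is given by `sum(p.numel() for p in AliceNet().parameters())`.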

## Set up the workers and send them the neural network

We need the two GridNode workers to be running.

In our current scenario we send the serialized model to the workers that take part in the computation. In a real-life deployment this step would not exist: it is enough that all the workers already hold the same model.
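Since the only requirement is that all workers hold the same model, one lightweight way to check this without exchanging weights is to compare a fingerprint of the architecture. The sketch below hashes the ordered parameter names and shapes with SHA-256; `architecture_fingerprint` is a hypothetical helper, and in practice each worker would compute the digest locally (for example from the shapes in `model.state_dict()`) and only the digests would be compared.

```python
import hashlib
import json


def architecture_fingerprint(layer_shapes):
    """Hash an ordered mapping of parameter name -> shape.

    Only the structure is hashed, never the weight values, so the digest
    can be shared without leaking the pretrained parameters.
    """
    canonical = json.dumps(
        [[name, list(shape)] for name, shape in layer_shapes.items()]
    )
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()


# Parameter shapes of AliceNet, using nn.Linear's (out, in) weight layout.
alice_net_shapes = {
    "fc1.weight": (128, 784), "fc1.bias": (128,),
    "fc2.weight": (128, 128), "fc2.bias": (128,),
    "fc3.weight": (10, 128), "fc3.bias": (10,),
}

alice_digest = architecture_fingerprint(alice_net_shapes)
bob_digest = architecture_fingerprint(dict(alice_net_shapes))  # same structure on bob
print(alice_digest == bob_digest)  # True
```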

### Scenario
* The local worker wants to run inference on the data hosted on *bob*'s machine.
* The model structure is known only by *alice* and *bob*
* *Alice* has the pre-trained network

In [23]:
# Syft workers
print("[%] Connecting to workers...")
ALICE = NodeClient(hook, "ws://localhost:3000")
BOB = NodeClient(hook, "ws://localhost:3001")
print("[+] Connected to workers")

print("[%] Create the serialized model...")
dummy_input = torch.empty((1, 784))
pytorch_model = AliceNet()

# Alice has the model with the real weights
print("[%] Sending the serialized pre-trained model...")
model_pretrained = torch.load('alice_pretrained_model.pth')
model_alice = OnnxModel.fromModel(model_pretrained, dummy_input).tag("crypten_model")
alice_model_ptr = model_alice.send(ALICE)
print("[+] Serialized model sent to Alice")
    
print("[%] Sending the serialized model...")
model = OnnxModel.fromModel(pytorch_model, dummy_input).tag("crypten_model")
bob_model_ptr = model.send(BOB)
print("[+] Serialized model sent to Bob")
    
print("[%] Send test data to bob...")
data = torch.load('/tmp/bob_test.pth')
data_ptr_bob = data.tag("crypten_data").send(BOB)
print("[+] Data sent to bob")

print("[%] Load labels...")
labels = torch.load('/tmp/bob_test_labels.pth').long()
print("[+] Labels loaded")


# Function used to compute the accuracy for the model
# Taken from CrypTen repository
def compute_accuracy(output, labels):
    pred = output.argmax(1)
    correct = pred.eq(labels)
    correct_count = correct.sum(0, keepdim=True).float()
    accuracy = correct_count.mul_(100.0 / output.size(0))
    return accuracy

[%] Connecting to workers...
[+] Connected to workers
[%] Create the serialized model...
[%] Sending the serialized pre-trained model...
[+] Serialized model sent to Alice
[%] Sending the serialized model...
[+] Serialized model sent to Bob
[%] Send test data to bob...
[+] Data sent to bob
[%] Load labels...
[+] Labels loaded

### Define the CrypTen computation

We need to pass the following to the ```run_multiworkers``` decorator:
* the workers that will take part in the computation
* the master address, which is used for communication between the workers
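For the computation to work, every worker has to join the same communication group, identified by its rank and a shared master address. The sketch below builds one configuration per party using the environment-variable names of PyTorch's `env://` rendezvous (`MASTER_ADDR`, `MASTER_PORT`, `RANK`, `WORLD_SIZE`). It is an assumption about the kind of setup a decorator like `run_multiworkers` performs, not its actual implementation; the port number and the `party_configs` helper are hypothetical.

```python
def party_configs(worker_ids, master_addr="127.0.0.1", master_port=15463):
    """Return one environment dict per party, keyed by worker id.

    In this sketch, rank i is the position of the worker in `worker_ids`,
    so with ["alice", "bob"] alice is rank 0 and bob is rank 1. This matches
    the notebook, where result key 0 is alice and src=0 refers to alice.
    """
    world_size = len(worker_ids)
    return {
        worker: {
            "MASTER_ADDR": master_addr,
            "MASTER_PORT": str(master_port),
            "RANK": str(rank),
            "WORLD_SIZE": str(world_size),
        }
        for rank, worker in enumerate(worker_ids)
    }


configs = party_configs(["alice", "bob"])
print(configs["bob"]["RANK"])  # 1
```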

We will use the ```func2plan``` decorator to:
* trace the operations from our function
* send the plan operations to *alice* and *bob*, where they stand in for the function
* run the plan operations on both workers
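The tracing idea behind a plan can be illustrated without PySyft: run the function once on a placeholder that records every operation, then replay the recorded list on real inputs. The toy `func2plan_sketch` below only handles chains of `+` and `*` with constants and is not how PySyft implements `func2plan`; it just shows why the traced operations can stand in for the original function on the remote workers.

```python
import operator


class TracedValue:
    """Placeholder that records every arithmetic operation applied to it."""

    def __init__(self, ops=None):
        self.ops = ops if ops is not None else []

    def _record(self, fn, other):
        # Return a new placeholder carrying the extended operation list.
        return TracedValue(self.ops + [(fn, other)])

    def __add__(self, other):
        return self._record(operator.add, other)

    def __mul__(self, other):
        return self._record(operator.mul, other)


def func2plan_sketch(func):
    """Trace func once on a placeholder and return a replayable plan."""
    ops = func(TracedValue()).ops

    def plan(x):
        for fn, const in ops:  # replay the recorded operations in order
            x = fn(x, const)
        return x

    plan.ops = ops
    return plan


def affine(x):
    return x * 3 + 1


plan = func2plan_sketch(affine)
print(plan(5))  # 16, same as affine(5)
```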

In [24]:
@run_multiworkers([ALICE, BOB], master_addr="127.0.0.1")
@sy.func2plan()
def run_encrypted_inference(crypten=crypten):
    data_enc = crypten.load("crypten_data", 1)
    
    data_enc2 = data_enc[:100]
    data_flatten = data_enc2.flatten(start_dim=1)
    
    # This should load the crypten model that is found at all parties
    model = crypten.load_model("crypten_model")

    model.encrypt(src=0)
    model.eval()
    
    result_enc = model(data_flatten)
    result = result_enc.get_plain_text()
    
    return result
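Evaluating `model(data_flatten)` means multiplying Alice's encrypted weights by Bob's encrypted inputs, and products of secret-shared values cannot be computed purely locally the way sums can. CrypTen handles multiplication with Beaver triples; the sketch below shows the textbook protocol for two integer values. The trusted dealer that creates the triple, the 64-bit ring and every helper name are assumptions made for illustration.

```python
from __future__ import annotations

import random

RING = 2 ** 64  # assumed ring size for the sketch


def share(x: int, n: int = 2) -> list[int]:
    """Additively share an integer among n parties."""
    parts = [random.randrange(RING) for _ in range(n - 1)]
    return parts + [(x - sum(parts)) % RING]


def reveal(shares: list[int]) -> int:
    """Recombine shares into a signed integer."""
    v = sum(shares) % RING
    return v - RING if v >= RING // 2 else v


def beaver_triple(n: int = 2):
    """Trusted dealer (assumption): shares of random a, b and c = a * b."""
    a, b = random.randrange(RING), random.randrange(RING)
    return share(a, n), share(b, n), share(a * b % RING, n)


def mul_shared(x_sh: list[int], y_sh: list[int]) -> list[int]:
    """Multiply two secret-shared values using one Beaver triple.

    The parties open eps = x - a and delta = y - b, which are masked by the
    random a and b, then compute shares of
    x*y = c + eps*b + delta*a + eps*delta locally.
    """
    n = len(x_sh)
    a_sh, b_sh, c_sh = beaver_triple(n)
    eps = sum(xi - ai for xi, ai in zip(x_sh, a_sh)) % RING
    delta = sum(yi - bi for yi, bi in zip(y_sh, b_sh)) % RING
    z_sh = [
        (ci + eps * bi + delta * ai) % RING
        for ai, bi, ci in zip(a_sh, b_sh, c_sh)
    ]
    z_sh[0] = (z_sh[0] + eps * delta) % RING  # public term added by one party only
    return z_sh


print(reveal(mul_shared(share(-7), share(6))))  # -42
```

With fixed-point values, the product carries the scaling factor twice and must be rescaled after each multiplication; the sketch avoids that step by using plain integers.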

## Run the CrypTen computation

Now let's run the inference.

In [28]:
# Get the returned values
# key 0 - return values for alice
# key 1 - return values for bob
print("[%] Starting computation")
result = run_encrypted_inference()[1]
print("[+] Computation finished")

accuracy = compute_accuracy(result, labels[:100])
print(f"The accuracy is {accuracy.item()}")

[%] Starting computation
[+] Computation finished
The accuracy is 99.0
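`compute_accuracy` takes the index of the largest output in each row as the predicted class and reports the percentage of rows where it matches the label, so with the 100 test images used here, 99.0 corresponds to 99 correct predictions. The same computation in plain Python, handy for checking the number by hand (`accuracy_percent` is a hypothetical helper and the toy logits are made up):

```python
def accuracy_percent(logits, labels):
    """Percentage of rows whose arg-max index equals the label."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(p == y for p, y in zip(preds, labels))
    return 100.0 * correct / len(logits)


toy_logits = [
    [0.1, 2.0, -1.0],  # predicts class 1
    [3.0, 0.5, 0.2],   # predicts class 0
    [0.0, 0.1, 0.4],   # predicts class 2
    [1.0, 0.9, 0.8],   # predicts class 0
]
print(accuracy_percent(toy_logits, [1, 0, 2, 2]))  # 75.0
```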


## Clean up

In [26]:
# CleanUp portion taken from the CrypTen project

import os


filenames = ['/tmp/alice_train.pth',
             '/tmp/alice_train_labels.pth',
             '/tmp/bob_test.pth',
             '/tmp/bob_test_labels.pth',
             'alice_pretrained_model.pth',
             'mnist_utils.py']

# Remove every file created by this tutorial, skipping any that are already gone
for fn in filenames:
    if os.path.exists(fn):
        os.remove(fn)

# Congratulations!!! - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement toward privacy-preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!

### Star PySyft on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.

- [Star PySyft](https://github.com/OpenMined/PySyft)

### Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! You can do so by filling out the form at [http://slack.openmined.org](http://slack.openmined.org)

### Join a Code Project!

The best way to contribute to our community is to become a code contributor! At any time you can go to PySyft GitHub Issues page and filter for "Projects". This will show you all the top level Tickets giving an overview of what projects you can join! If you don't want to join a project, but you would like to do a bit of coding, you can also look for more "one off" mini-projects by searching for GitHub issues marked "good first issue".

- [PySyft Projects](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3AProject)
- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+first+issue+%3Amortar_board%3A%22)

### Donate

If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!

[OpenMined's Open Collective Page](https://opencollective.com/openmined)