# Encrypted Inference on ResNet-18

_Encrypted Machine Learning as a Service_ allows owners of sensitive data to use external AI services to get insights from their data. Let's consider a practical scenario where a data owner holds private images and would like to use a service to have those images labeled, without disclosing the images or the labels, and without needing access to the model, which such services often consider a business asset and therefore do not share.

To get a realistic example, we will consider [the task of distinguishing between bees and ants](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html), which uses a ResNet-18 model to achieve around 95% accuracy. We won't consider training such a model, as we assume the AI service provider has already trained it on its own data. Instead, we will showcase how we can use PySyft to encrypt both the model and some image samples, and to label those images in a fully private way.

Author:
- Théo Ryffel - Twitter: [@theoryffel](https://twitter.com/theoryffel) · GitHub: [@LaRiffle](https://github.com/LaRiffle)

## 1. Did you just say _encrypted_?

First, let's try to understand what mechanisms we use to make the data and the model private. If you want to jump straight to the code, you can skip this section! 

## 2. Show me the code

Enough explanations, let's open the code!
We will first load the data and the model and store them on the `data_owner` and the `model_owner`.

In [1]:
import torch
torch.set_num_threads(1)
import torch.nn as nn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import time

### Load the data

We download the data and load it into a `DataLoader` with a batch size of 1, to reduce the inference time and the memory pressure on the RAM.

In [2]:
#!wget https://download.pytorch.org/tutorial/hymenoptera_data.zip
#!unzip hymenoptera_data.zip

In [3]:
data_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

data_dir = 'hymenoptera_data'
image_dataset = datasets.ImageFolder(f'{data_dir}/val', data_transform)
dataloader = torch.utils.data.DataLoader(image_dataset, batch_size=1, shuffle=True, num_workers=4)

dataset_size = len(image_dataset)
class_names = image_dataset.classes
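
The `transforms.Normalize` step above subtracts a per-channel mean and divides by a per-channel standard deviation. As a minimal sketch in plain Python (the helper `normalize_pixel` is hypothetical and only illustrates the arithmetic on a single RGB pixel, it is not part of torchvision):

```python
# Channel statistics copied from the transform above.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb):
    """Apply (value - mean) / std to each channel of one RGB pixel in [0, 1]."""
    return [(v - m) / s for v, m, s in zip(rgb, MEAN, STD)]

# A pixel equal to the channel means maps to zero in every channel.
mean_pixel = normalize_pixel(MEAN)
# A pure white pixel lands more than two standard deviations above the mean.
white_pixel = normalize_pixel([1.0, 1.0, 1.0])
print(mean_pixel, white_pixel)
```

After this step the inputs are roughly centered around zero, which is the range the pretrained ResNet-18 weights expect.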

Want to have a look at your data? Check the samples in [this tutorial](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html).

### Load the model

Now let's download the trained model:

In [4]:
#!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1-1_M81rMYoB1A8_nKXr0BBOwSIKXPp2v' -O resnet18_ants_bees.pt

_You can also download the file_ [here](https://drive.google.com/file/d/1-1_M81rMYoB1A8_nKXr0BBOwSIKXPp2v/view?usp=sharing) _if the command above is not working._

In [5]:
model = models.resnet18(pretrained=True)
# Here the size of each output sample is set to 2.
model.fc = nn.Linear(model.fc.in_features, 2)
state = torch.load("./resnet18_ants_bees.pt", map_location='cpu')
model.load_state_dict(state)
model.eval()
# This is a small trick: these two consecutive operations can be swapped without
# changing the result, and doing so reduces the number of comparisons we have to compute
model.maxpool, model.relu = model.relu, model.maxpool
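
To see why the swap is safe, here is a minimal sketch in plain Python, using a hypothetical 1-D pooling helper as a stand-in for the 2-D max pooling in ResNet-18. It checks that applying ReLU before or after max pooling gives the same values, and that ReLU runs on fewer elements when it comes second:

```python
import random

def relu(xs):
    """Element-wise max(x, 0): one comparison per element."""
    return [max(x, 0.0) for x in xs]

def max_pool_1d(xs, kernel=3, stride=2):
    """Max over sliding windows without padding (illustrative 1-D pooling)."""
    return [max(xs[i:i + kernel]) for i in range(0, len(xs) - kernel + 1, stride)]

random.seed(0)
activations = [random.uniform(-1.0, 1.0) for _ in range(32)]

relu_then_pool = max_pool_1d(relu(activations))   # original order
pool_then_relu = relu(max_pool_1d(activations))   # swapped order

same_result = relu_then_pool == pool_then_relu
# ReLU inputs: all activations in the original order, only pooled ones after the swap.
relu_inputs_before, relu_inputs_after = len(activations), len(pool_then_relu)
print(same_result, relu_inputs_before, relu_inputs_after)
```

Both orders agree because taking a maximum and clamping at zero are both monotone operations. Comparisons are expensive under encryption, so running ReLU on the smaller pooled tensor saves work.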

Great, now we're ready to start!

### Virtual Setup

First, let's create a virtual setup with three workers named `data_owner`, `model_owner` and `crypto_provider`.

In [6]:
import syft as sy

# Virtual workers simulate the three parties inside this single process
hook = sy.TorchHook(torch)
data_owner = sy.VirtualWorker(hook, id="data_owner")
model_owner = sy.VirtualWorker(hook, id="model_owner")
crypto_provider = sy.VirtualWorker(hook, id="crypto_provider")

In [7]:
# Remove compression to have faster communication
from syft.serde.compression import NO_COMPRESSION
sy.serde.compression.default_compress_scheme = NO_COMPRESSION

Let's put some data on the `data_owner` and the model on the `model_owner`.

In [8]:
data, true_labels = next(iter(dataloader))
data_ptr = data.send(data_owner)

# We store the true output of the model for comparison purpose
true_prediction = model(data)
model_ptr = model.send(model_owner)

As usual, after calling `.send()`, we only have access to pointers to the data:

In [9]:
print(data_ptr)

(Wrapper)>[PointerTensor | me:68450300422 -> data_owner:72958304771]


### Encryption time!

We will now encrypt both the model and the data. To do this, we encrypt them remotely using the pointers and get back the encrypted objects. 

In [10]:
encryption_kwargs = dict(
    workers=(data_owner, model_owner),
    crypto_provider=crypto_provider,
    protocol="fss",
    precision_fractional=4,
)
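
To give an intuition for what `precision_fractional=4` and the two `workers` mean, here is a minimal sketch in plain Python. It assumes the encryption builds on fixed-point encoding followed by additive secret sharing between two parties. This is not PySyft's implementation and it does not cover the `fss` protocol itself. The ring size and all helper names below are assumptions made for illustration.

```python
import secrets

PRECISION_FRACTIONAL = 4            # same value as in encryption_kwargs above
SCALE = 10 ** PRECISION_FRACTIONAL
RING = 2 ** 64                      # assumed ring size, for illustration only

def encode(x):
    """Fixed-point encode a float as an integer modulo RING."""
    return round(x * SCALE) % RING

def decode(v):
    """Map a ring element back to a float, treating the upper half as negative."""
    if v >= RING // 2:
        v -= RING
    return v / SCALE

def share(v):
    """Split v into two additive shares; each share alone looks uniformly random."""
    share_a = secrets.randbelow(RING)
    share_b = (v - share_a) % RING
    return share_a, share_b

def reconstruct(share_a, share_b):
    """Recombine both shares; neither party can do this alone."""
    return (share_a + share_b) % RING

x, y = 1.0569, -0.8008
x_a, x_b = share(encode(x))
y_a, y_b = share(encode(y))

# Addition is local: each party adds the shares it holds, without communication.
sum_a, sum_b = (x_a + y_a) % RING, (x_b + y_b) % RING
decoded_x = decode(reconstruct(x_a, x_b))
decoded_sum = decode(reconstruct(sum_a, sum_b))
print(decoded_x, decoded_sum)
```

In this sketch, values are rounded to four decimal places before sharing, which is one reason encrypted results can differ slightly from plaintext ones.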

In [11]:
encrypted_data = data_ptr.encrypt(**encryption_kwargs).get()
encrypted_model = model_ptr.encrypt(**encryption_kwargs).get()


### Secure inference
We are now able to run our secure inference, so let's do it and compare the result to the `true_labels`.

In [12]:
encrypted_data.child.child.protocol

'fss'

In [13]:
start_time = time.time()

encrypted_prediction = encrypted_model(encrypted_data)
encrypted_labels = encrypted_prediction.argmax(dim=1)

print(time.time() - start_time, "seconds")

labels = encrypted_labels.decrypt()

print("Predicted labels:", labels)
print("     True labels:", true_labels)



205.43717098236084 seconds
Predicted labels: tensor([0.])
     True labels: tensor([0])


Hooray! This works!! Well at least with a probability of 95%...

But is the computation _exactly_ the same as with the plaintext model? Well, not exactly, because we sometimes use approximations, but let's open the model output logits to verify how close we are to plaintext execution.

In [14]:
print(encrypted_prediction.decrypt())
print(true_prediction)

tensor([[ 1.0772, -0.8127]])
tensor([[ 1.0569, -0.8008]], grad_fn=<AddmmBackward>)


As you can observe, this is quite close and in practice the accuracy of the model is preserved.
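
To put a number on "quite close", the small sketch below compares the two logit vectors printed above in plain Python. The values are copied from that output, and the helper `argmax` is defined here for illustration only.

```python
# Values copied from the output above.
encrypted_logits = [1.0772, -0.8127]
plaintext_logits = [1.0569, -0.8008]

# Absolute error per logit between encrypted and plaintext execution.
abs_errors = [abs(e - p) for e, p in zip(encrypted_logits, plaintext_logits)]
max_abs_error = max(abs_errors)

def argmax(values):
    """Index of the largest value, i.e. the predicted class."""
    return max(range(len(values)), key=lambda i: values[i])

same_label = argmax(encrypted_logits) == argmax(plaintext_logits)
print(abs_errors, max_abs_error, same_label)
```

The errors are about 0.02 and 0.012, far smaller than the gap between the two logits, so the predicted class is unchanged for this sample.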

Regarding **runtime**, the single image above took ~205 seconds, which corresponds to ~400 seconds for a batch of 2 images. This isn't super fast, but it is already reasonable for our use case!

## Extension

> Ok that's good, but in real life I won't use virtual workers!

That's right, actually you can run exactly the same experiment using PyGrid and workers which live in a PrivateGridNetwork.

To do so, first clone [PyGrid](https://github.com/OpenMined/PyGrid) and then start new nodes in your terminal (one per tab) as such:
```
cd PyGrid/apps/node
./run.sh --id data_owner      --port 7600 --host localhost --start_local_db
./run.sh --id model_owner     --port 7601 --host localhost --start_local_db
./run.sh --id crypto_provider --port 7602 --host localhost --start_local_db
```

And you replace the `syft` imports as such:
```
import syft as sy
from syft.grid.clients.data_centric_fl_client import DataCentricFLClient

hook = sy.TorchHook(torch)
data_owner = DataCentricFLClient(hook, "ws://localhost:7600")
model_owner = DataCentricFLClient(hook, "ws://localhost:7601")
crypto_provider = DataCentricFLClient(hook, "ws://localhost:7602")

my_grid = sy.PrivateGridNetwork(data_owner, model_owner, crypto_provider)
```

The computation will be exactly the same, and the runtime will roughly double. You can run the experiment to verify this, and it's a nice intro to PyGrid!

## What's next?

Next is improving this first proof of concept! How can this be done?

- First, we can optimize our implementation, for example by switching from Python to Rust.
- Second, we can try to adapt the model structure or model layers to get a faster execution given our constraints, without compromising accuracy. Think of the swap we made between maxpool and relu in the ResNet-18 architecture at the beginning.
- Last, we can investigate new Function Secret Sharing crypto protocols. This is a new and promising field, and we expect new breakthroughs to help us improve the inference time!

### Join us!

If you want to help, come and [apply to join one of our cryptography teams](https://forms.gle/BWmYQJrCwqe1m3ex5)!

### Star PySyft on GitHub

You can also help our community by starring the repositories! This helps raise awareness of the cool tools we're building.

- [Star PySyft](https://github.com/OpenMined/PySyft)

### Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! 

- [Join slack.openmined.org](http://slack.openmined.org)