# SMPC Protocols Explanation 
This notebook gives you an overview and a short explanation of the different SMPC protocols that are currently implemented in PySyft. It also covers **what kind of Machine Learning you can conduct in an encrypted fashion with them**, and compares their performance with each other and with the non-encrypted scenario. The protocol explanations are mainly meant to give you a high-level understanding of what the main **crypto-slang** stands for, how the terms relate to each other and which resources are good starting points to dig deeper!

Authors:
- Nicolas Remerscheid - GitHub: [@NiWaRe](https://github.com/NiWaRe)


## Quick recap - What is SMPC encryption?
As a quick recap, "SMPC" stands for **Secure Multi-Party Computation**, a form of encryption that can be used for Machine Learning (i.e. it makes it possible to compute on encrypted data) by leveraging a network of at least two different servers. These systems are typically resistant to some level of **collusion**. Usually this means assuming an *honest majority*: most of the servers are trusted not to deviate from the given protocol, e.g. not to collude with each other. As a whole, the servers act as a **trusted execution environment** in which sensitive computations such as model inference or training can be done without the model or the data being disclosed to anyone besides the respective owner.

## General concepts 
* Important concepts that all protocols are based on, *in brief:* 
  * **Public value:** data (e.g. input from the data sources) that is known to all parties. 
  * **Private value:** data that is protected through additive secret sharing; only the owner knows the true value. 
  * **Masking:** To share a private value publicly (e.g. as a necessary step in a protocol), it has to be masked first. This typically involves combining it with a random number r (adding or subtracting r both work) and reducing the result modulo Q, i.e. projecting it onto the fixed set of integers {0, ..., Q-1}, a so-called ring. <br>  *masked_value = (private_value - r) % Q*, where Q is the size of the ring. For more info on that, check out the Udacity beginner tutorials by Andrew Trask.
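
A minimal plain-Python sketch of additive secret sharing and masking, assuming a ring of integers modulo Q = 2**32 (an illustrative size, not necessarily the one PySyft uses); `share`, `reconstruct` and `mask` are hypothetical helper names, not PySyft functions.

```python
import random

Q = 2**32  # ring size: all values live in {0, ..., Q-1} (illustrative choice)

def share(secret, n_parties=2, q=Q):
    """Split `secret` into n additive shares that sum to it modulo q."""
    shares = [random.randrange(q) for _ in range(n_parties - 1)]
    shares.append((secret - sum(shares)) % q)
    return shares

def reconstruct(shares, q=Q):
    """Add all shares up modulo q to recover the secret."""
    return sum(shares) % q

def mask(private_value, q=Q):
    """Hide a value behind a fresh uniformly random r; return (masked, r)."""
    r = random.randrange(q)
    return (private_value - r) % q, r

secret = 42
shares = share(secret)        # each share on its own looks uniformly random
masked, r = mask(secret)      # `masked` can be published, `r` stays private
print(reconstruct(shares))    # 42
print((masked + r) % Q)       # 42, recoverable only by whoever knows r
```

Each individual share, like the masked value, is uniformly distributed over the ring, which is why it can be handed out or published without leaking the secret.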

# Use-Case: Cifar10 Classification with CNN
* To tackle a problem close to a real-world one while still using a well-explored example (the privacy tools are the main interest of this tutorial), we choose to *train* an image classifier on the **[Cifar-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html)** and *run inference* with it. <br>
* With 60,000 coloured 32x32-pixel images from 10 classes (airplanes, birds, etc.), this task is reasonably similar to real-world examples such as training a skin-cancer classifier, which relies heavily on sensitive private data. See [Stanford's Skin Cancer Classification with Deep Learning](https://cs.stanford.edu/people/esteva/nature/) for more information on this specific example.
* *Specific characteristics:*
  * We consider a system of **two** servers as a secure computation unit on which all the computations are conducted in a safe manner. In this case we consider the scenario where two data owners train a model (possibly from a third party) in an encrypted fashion on their respective devices. 

In [1]:
# Setup of example use-case problem 

import torch 
import torch.nn as nn
import torch.nn.functional as F 
import torch.optim as optim

from torchvision import datasets, transforms, models
from torch.utils.data.sampler import SubsetRandomSampler 

import numpy as np 
import syft as sy
import time
import tqdm as tqdm 

# Extend torch functionality with PySyft
hook = sy.TorchHook(torch)

In [2]:
## Get Data ##

_ = torch.manual_seed(1234)
batch_size = 32

# download the Cifar-10 dataset (it is distributed onto the two servers later, when it gets encrypted)

# normalize data and convert to torch.FloatTensor
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
# get Cifar10 Dataset from torchvision.datasets
cifar10_data = datasets.CIFAR10('data', train=True,
                              download=True, transform=transform)

cifar10_data_test = datasets.CIFAR10('data', train=False,
                              download=True, transform=transform)

# create DataLoaders 
cifar10_train_loader = torch.utils.data.DataLoader(cifar10_data, batch_size=batch_size)
cifar10_test_loader = torch.utils.data.DataLoader(cifar10_data_test, batch_size=batch_size)

# image classes
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

Files already downloaded and verified
Files already downloaded and verified


In [3]:
## setup distributed scenario ##

# create VirtualWorkers - could be any participating node 
sam = sy.VirtualWorker(hook, id="sam")
kelly = sy.VirtualWorker(hook, id="kelly")
workers = [sam, kelly]

crypto_provider = sy.VirtualWorker(hook, id="crypto_provider")

In [4]:
# Remove compression to have faster communication, because compression time 
# is non-negligible: we send to workers crypto material which is very heavy
# and pseudo-random, so compressing it takes a long time and isn't useful:
# randomness can't be compressed, otherwise it wouldn't be random!
from syft.serde.compression import NO_COMPRESSION
sy.serde.compression.default_compress_scheme = NO_COMPRESSION

## Encrypted Inference - Quick Demo
We can encrypt data that is stored on our own device (node), or remotely encrypt data that is stored on other devices. The latter is what you would do, for example, in a data-centric Federated Learning use-case with Secure Aggregation and real workers; see [Part 10 - Federated Learning with Secure Aggregation](https://github.com/OpenMined/PySyft/blob/master/examples/tutorials/Part%2010%20-%20Federated%20Learning%20with%20Secure%20Aggregation.ipynb) and this blog post on [different kinds of Federated Learning](https://blog.openmined.org/federated-learning-types/). As an example we take ResNet-18, a very popular and powerful model for which encrypted inference (the forward pass, in `.eval()` mode) is already supported!
<br><br>
Feel free to change the protocol from *Functional Secret Sharing* `fss` to *Secure NN* `snn`. More on these two protocols in the following. 
<br><br>
*Note:* All computations were conducted on Mac OS Catalina, 2 GHz Quad-Core Intel Core i7, 16 GB RAM (MBP late 2013).

In [5]:
# get model - ResNet-18 with torchvision's ImageNet-pretrained weights
# (weights for a ResNet trained on Cifar10 are available at: https://github.com/huyvnphan/PyTorch_CIFAR10)

model = models.resnet18(pretrained=True).eval()

# Replace the last layer with a new one for the 10 Cifar10 classes; normally you would now train this layer
model.fc = nn.Linear(512, 10)

In [6]:
# encryption parameters - feel free to change these  

encryption_kwargs = dict(
    workers=workers, 
    crypto_provider=crypto_provider, 
    protocol="fss", 
    requires_grad=True,
    precision_fractional= 4
)

In [7]:
# remotely encrypt one batch and the model
first_batch, first_target = next(iter(cifar10_train_loader))

# we assume sam has some data and the model
# (of course you can also skip this and assume the data is stored locally)
ptr_first_batch = first_batch.send(sam)
ptr_model = model.send(sam)

# .get() because encrypt() on a pointer returns a pointer (me->sam) to where the AdditiveSharingTensor is stored
# this is not necessary if you encrypt data that is on your local device
encrypted_first_batch = ptr_first_batch.encrypt(**encryption_kwargs).get()
encrypted_model = ptr_model.encrypt(**encryption_kwargs).get()

Running the next cell shows the internal structure of the encrypted data. The encrypted tensor is a chain of three tensor types, which mirror the steps we just took to get from a plain tensor to an encrypted one. Follow the link on each tensor type to see how it is implemented. *More on the specifics in the following!*
* [`AutogradTensor`](https://github.com/OpenMined/PySyft/blob/d811ef1e91e5e2c84fbbf1edf61e6983380b4d16/syft/frameworks/torch/tensors/interpreters/autograd.py#L29) - first, we want to have a tensor on which all computations are tracked for later backprop. 
* [`FixedPrecisionTensor`](https://github.com/OpenMined/PySyft/blob/d811ef1e91e5e2c84fbbf1edf61e6983380b4d16/syft/frameworks/torch/tensors/interpreters/precision.py#L19) - second, we need to convert all numbers from floats to fixed-point integers for the encryption in the next step.
* [`AdditiveSharingTensor`](https://github.com/OpenMined/PySyft/blob/d811ef1e91e5e2c84fbbf1edf61e6983380b4d16/syft/frameworks/torch/tensors/interpreters/additive_shared.py#L63) - lastly, the data is encrypted by splitting it into additive shares that are distributed among the workers.
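
To make the `FixedPrecisionTensor` step concrete, here is a plaintext sketch of fixed-point encoding. It assumes base 10 with `precision_fractional=4` (four decimal digits kept) and a hypothetical ring size of 2**62; the helper names are illustrative and not PySyft internals. In an actual protocol the rescaling after a multiplication has to be done on shares, which this sketch ignores.

```python
BASE = 10
PRECISION_FRACTIONAL = 4
SCALE = BASE ** PRECISION_FRACTIONAL  # 10_000: four decimal digits are kept
Q = 2**62                             # hypothetical ring size

def encode(x: float) -> int:
    """Float -> fixed-point ring element; negatives wrap around to the top of the ring."""
    return round(x * SCALE) % Q

def to_signed(v: int) -> int:
    """Interpret the upper half of the ring as negative numbers."""
    return v - Q if v >= Q // 2 else v

def decode(v: int) -> float:
    return to_signed(v) / SCALE

def fixed_add(a: int, b: int) -> int:
    # addition keeps the scale unchanged
    return (a + b) % Q

def fixed_mul(a: int, b: int) -> int:
    # the raw product carries SCALE**2, so it has to be divided by SCALE once
    return (to_signed(a) * to_signed(b) // SCALE) % Q

print(decode(encode(3.14159)))                         # 3.1416 (rounded to 4 decimals)
print(decode(fixed_mul(encode(1.5), encode(-2.25))))   # -3.375
```

The precision loss visible in the first line is the same kind of rounding error mentioned in the comments of the `CrossEntropyLoss` helper further down.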

In [8]:
encrypted_first_batch

(Wrapper)>AutogradTensor>FixedPrecisionTensor>[AdditiveSharingTensor]
	-> [PointerTensor | me:92566034500 -> sam:59900641324]
	-> [PointerTensor | me:55213798748 -> kelly:4500639760]
	*crypto provider: crypto_provider*

In [9]:
# encrypted inference - feel free to try the same with other batches
start_time = time.time()

encrypted_result = encrypted_model(encrypted_first_batch)
print(f"Inference done in {time.time() - start_time} sec. Privacy preserving fetching of prediction!")
pred = encrypted_result.argmax(dim=1)

print(f"Total Duration: {time.time() - start_time}")
# decrypt() = get().float_precision() - the specific result isn't sensible as the ResNet wasn't actually trained 
print(f"Result: {pred.decrypt()}")



Inference done in 181.70145916938782 sec. Privacy preserving fetching of prediction!
Total Duration: 182.05092406272888
Result: tensor([2., 8., 8., 6., 6., 2., 2., 5., 6., 6., 8., 2., 7., 6., 5., 6., 2., 6.,
        8., 6., 6., 3., 8., 8., 2., 8., 6., 8., 8., 1., 1., 6.])


In [10]:
# For comparison: the same encrypted inference with SecureNN (protocol="snn") took 992.60 sec. (the argmax alone took ~2 sec. of that) - 32 samples, 10 classes

## Encrypted Training - Cifar10
Now that we know how to do encrypted inference, let's look at encrypted training. To understand the possibilities and limits of encrypted training with different SMPC protocols, we'll go through each protocol and then compare their performance when training an image classifier on Cifar10. <br>
We don't use the ResNet here, because not all of its operations are supported for encrypted backpropagation yet (Oct. 2020). Instead we use a simple custom CNN (defined at the start of each section), which also gives you the freedom to try out different modules and operations and compare their performance on your own. Also, feel free to alter and tune the training process, the loss function, optimizer, etc. to get a better feeling for the performance of different computations. <br>
*Also, if you find that a critical module, loss function or optimizer isn't supported yet, feel free to open an Issue on the OpenMined GitHub.*

In [5]:
# serves as base for both protocols (for better comparison)

class SimpleCNN_Base(nn.Module):
    def __init__(self): 
        super(SimpleCNN_Base, self).__init__()
        self.conv1 = nn.Conv2d(3, 8, 3, padding=1)
        self.conv2 = nn.Conv2d(8, 16, 3, padding=1)
        self.conv3 = nn.Conv2d(16, 32, 3, padding=1)
        
        self.pool1 = nn.AvgPool2d(2, 2)
        self.pool2 = nn.AvgPool2d(2, 2)
        self.pool3 = nn.AvgPool2d(2, 2)
        
        # after conv and pooling layer dimension: 4x4x32 (original: 32x32x3)
        self.lin1 = nn.Linear(4*4*32, 128)
        self.lin2 = nn.Linear(128, 10)
        self.dropout = nn.Dropout(0.2)
    
    def forward(self, x): 
        out = self.pool1(F.relu(self.conv1(x)))
        out = self.pool2(F.relu(self.conv2(out)))
        out = self.pool3(F.relu(self.conv3(out)))
        
        batch_size = x.shape[0]
        out = out.view(batch_size, -1)
        out = self.dropout(F.relu(self.lin1(out)))
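        # note: this final ReLU clips the logits at 0; returning self.lin2(out) directly is the more common choice
        out = F.relu(self.lin2(out))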
        
        return out

In [6]:
## Utility functions ##

# One-Hot Encoding (Copied from @laRiffle Part 12)
def one_hot_of(index_tensor):
    """
    Transform to one hot tensor
        
    Example:
        [0, 3, 9]
        =>
        [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]]
            
    """
    onehot_tensor = torch.zeros(*index_tensor.shape, 10) # 10 classes for Cifar10
    onehot_tensor = onehot_tensor.scatter(1, index_tensor.view(-1, 1), 1)
    return onehot_tensor

def CrossEntropyLoss(output, target):
    
    # LogSoftmax #

    # Vectorized and in log-space, with a subtraction instead of a division, is KEY!
    # => a private division takes ~17 sec., and a non-vectorized version ~0.5 sec. per value
    # The vectorized version only introduces small rounding errors, on the scale of the predefined precision
    log_probs = output - torch.log(torch.exp(output).sum(dim=1).unsqueeze(dim=1))

    # CELoss #
    batch_loss = torch.mean( 
                     torch.sum(  
                        -target * log_probs, dim=1
                    )
                )
    
    return batch_loss
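
As a plaintext sanity check of the log-space trick used in `CrossEntropyLoss`, the following NumPy sketch (NumPy arrays instead of encrypted tensors; function names are illustrative) compares it with the textbook version that divides by the softmax normaliser. Both give the same loss, but only the first avoids the element-wise division, which is the expensive operation under encryption.

```python
import numpy as np

def ce_log_space(output, target):
    # log-softmax with a subtraction: o_i - log(sum_j exp(o_j)), as in CrossEntropyLoss above
    log_probs = output - np.log(np.exp(output).sum(axis=1, keepdims=True))
    return np.mean(np.sum(-target * log_probs, axis=1))

def ce_with_division(output, target):
    # textbook version: softmax with an explicit division by the normaliser
    probs = np.exp(output) / np.exp(output).sum(axis=1, keepdims=True)
    return np.mean(np.sum(-target * np.log(probs), axis=1))

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 10))        # 4 samples, 10 classes
one_hot = np.eye(10)[[0, 3, 9, 1]]       # one-hot targets, like one_hot_of()
print(ce_log_space(logits, one_hot), ce_with_division(logits, one_hot))
```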

# 1. SPDZ Protocol 
* **Encryption on n parties possible (n >= 2)**
* Basis for advanced protocols secureNN and FSS. 
* *In depth material:*
  * [Bristol Cryptography Blog Series](https://bristolcrypto.blogspot.com/2016/10/what-is-spdz-part-1-mpc-circuit.html)
  * [Morten Dahl's Blog](https://mortendahl.github.io/2017/09/03/the-spdz-protocol-part1/)
  * [PySyft Code](https://github.com/OpenMined/PySyft/blob/master/syft/frameworks/torch/mpc/spdz.py)

### General concept 
There are two families of SMPC: secret-sharing-based SMPC and garbled-circuit-based SMPC. 
The two protocols (secureNN and FSS) that are implemented in PySyft to date are both based on the **SPDZ protocol**, 
which relies on *additive secret-sharing* (the first category). For more information on what *additive secret-sharing* is, check out the tutorials by Andrew Trask in the Udacity Private and Secure AI Course.<br>
SPDZ is a very widely used protocol for computing arithmetic on encrypted values, i.e. sums and multiplications of encrypted variables, which is exactly what linear layers need. To add encrypted variables, each server simply sums up its own shares locally; this gives every server a share of the sum, so that together they secret-share the sum of the private variables. The multiplication of two encrypted variables, however, uses so-called **"Beaver multiplication triples"** (three random numbers a, b and c = a·b), which make encrypted multiplication very efficient compared to other approaches (e.g. garbled circuits).
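
The local addition described above can be sketched in plain Python for two servers, assuming a ring modulo Q = 2**32 (illustrative); the helpers are hypothetical, and a real deployment would keep each share on a different machine instead of in one list. The sketch also shows two other operations that need no communication: multiplying by a public constant and adding a public constant.

```python
import random

Q = 2**32

def share(secret, n_parties=2):
    shares = [random.randrange(Q) for _ in range(n_parties - 1)]
    shares.append((secret - sum(shares)) % Q)
    return shares

def reconstruct(shares):
    return sum(shares) % Q

x_shares = share(25)   # server 0 holds x_shares[0], server 1 holds x_shares[1]
y_shares = share(17)

# Sum: every server adds only its OWN shares, no messages are exchanged
z_shares = [(xi + yi) % Q for xi, yi in zip(x_shares, y_shares)]

# Multiplying by a PUBLIC constant is also local
w_shares = [(3 * zi) % Q for zi in z_shares]

# Adding a PUBLIC constant: exactly one server adds it, otherwise it is counted twice
v_shares = [(w_shares[0] + 5) % Q] + w_shares[1:]

print(reconstruct(z_shares), reconstruct(w_shares), reconstruct(v_shares))  # 42 126 131
```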
<br>
*Why more efficient?* For a brief introduction, see the 'Extra: Beaver Multiplication Triples' section below; to dig deeper, see the resources listed above.
<br>
<br>
The high-level procedure of SPDZ-based protocols is as follows (based on the definition from the [Bristol Cryptography Blog](https://bristolcrypto.blogspot.com/2016/10/what-is-spdz-part-1-mpc-circuit.html)):
   1. Parties secret-share their inputs at the beginning (a crypto provider generates the Beaver triples and other crypto primitives).
   2. Parties compute additions and multiplications locally, using only their shares of the private variables. *By design they don't need to communicate with each other for additions; only at the end do they share their result. For multiplication there is one intermediate communication step, in which the masked shares needed for the product are exchanged (as explained in the next section about Beaver Triples).*
   3. Finally, parties reveal the result by sending their shares to one server, which adds all shares up to unveil the final result. This could be any participating server; alternatively, the shares could be sent to everybody so that any participating server can decrypt the final result (this variant is not implemented in PySyft).

### Extra: Beaver Multiplication Triples
Simply put, a Beaver triple consists of three random numbers a, b and c = a·b, which the crypto provider generates and secret-shares among all workers. To multiply two private values x and y, the workers use a and b to mask (as described above) x and y, so that the masked values x - a and y - b can be made public without revealing anything. From these public values and their shares of a, b and c, each worker then evaluates a simple equation locally, in which the random masks cancel out and which leaves it with its share of the product x·y. So, apart from publishing the masked values, no further communication between the workers is needed. Regarding the overall efficiency of the protocol, you might object that, although the communication complexity is low, we still need to generate a fresh Beaver triple for every multiplication (for security reasons, each triple may only be used once). <br>
This is exactly why the SPDZ protocol is deliberately split into an 'offline' and an 'online' phase. The 'offline' phase consists of randomly generating all necessary "crypto primitives", e.g. the Beaver triples, and can be executed independently of the actual inputs of the multiplications, i.e. before they happen. Thus, given enough pre-generated Beaver triples, during the 'online' phase we can compute as many multiplications as we have triples. We could also generate the necessary triples during the 'online' phase, but that would slow down the computation for the user waiting for their multiplication. The offline phase shortens this waiting time by generating the triples while no requests are being made. *This makes the SPDZ protocol very efficient for linear computations!*
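
Here is a plaintext sketch of Beaver-triple multiplication, assuming a ring modulo Q = 2**32 and hypothetical helpers `beaver_triple` and `beaver_mul`; it simulates all parties in one process and is not PySyft's implementation. It relies on the identity x·y = c + eps·b + delta·a + eps·delta, with eps = x - a and delta = y - b.

```python
import random

Q = 2**32

def share(secret, n_parties=2):
    shares = [random.randrange(Q) for _ in range(n_parties - 1)]
    shares.append((secret - sum(shares)) % Q)
    return shares

def reconstruct(shares):
    return sum(shares) % Q

def beaver_triple(n_parties=2):
    """Offline phase (crypto provider): random a, b and c = a*b, all secret-shared."""
    a, b = random.randrange(Q), random.randrange(Q)
    return share(a, n_parties), share(b, n_parties), share((a * b) % Q, n_parties)

def beaver_mul(x_sh, y_sh, triple):
    """Online phase: multiply secret-shared x and y using one fresh triple."""
    a_sh, b_sh, c_sh = triple
    # The only communication round: each party publishes its shares of the masked
    # values eps = x - a and delta = y - b. Since a and b are uniformly random,
    # eps and delta reveal nothing about x and y.
    eps = reconstruct([(xi - ai) % Q for xi, ai in zip(x_sh, a_sh)])
    delta = reconstruct([(yi - bi) % Q for yi, bi in zip(y_sh, b_sh)])
    # Local step, from x*y = c + eps*b + delta*a + eps*delta:
    z_sh = [(ci + eps * bi + delta * ai) % Q for ci, ai, bi in zip(c_sh, a_sh, b_sh)]
    z_sh[0] = (z_sh[0] + eps * delta) % Q  # public term, added by one party only
    return z_sh

x_sh, y_sh = share(6), share(7)
z_sh = beaver_mul(x_sh, y_sh, beaver_triple())
print(reconstruct(z_sh))  # 42
```

The triple is consumed: if the same (a, b) were used for two multiplications, the difference of the two published eps values would equal x - x', leaking information about the inputs. This is why the offline phase has to produce one triple per multiplication.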

### Extra: Crypto-Store
  * As mentioned above a key feature of the SPDZ-protocol is its splitting of the encrypted computation process into an *online* and an *offline* phase to allow for a significantly decreased execution time (given enough time for the offline phase when no encrypted computations are being conducted)
  * As you probably know, PySyft has *worker* objects whose default attributes are specified and set in `class BaseWorker(AbstractWorker)` ([code](https://github.com/OpenMined/PySyft/blob/c83e615a85bb8944245668d90582fb97c45e6e18/syft/workers/base.py#L48)). One of them is `worker.crypto_store`, an object of type `class PrimitiveStorage`, which provides a set of functions that help the workers manage the crypto primitives they need to participate in the respective crypto protocol (e.g. Beaver Triples for multiplication in the SPDZ protocol).
  * Specifically there are two important methods of the crypto_store object. For the *crypto-provider* - the party that serves as a trusted, neutral participant of the protocol - the `crypto_provider.crypto_store.provide_primitives(...)` ([code](https://github.com/OpenMined/PySyft/blob/c83e615a85bb8944245668d90582fb97c45e6e18/syft/frameworks/torch/mpc/primitives.py#L161)) method generates and sends crypto-primitives such as Beaver Triples to participating workers. <br> For "normal" participating workers the `worker.crypto_store.get_keys(...)` ([code](https://github.com/OpenMined/PySyft/blob/c83e615a85bb8944245668d90582fb97c45e6e18/syft/frameworks/torch/mpc/primitives.py#L52)) method takes care of receiving and storing the crypto primitives for later usage during the protocol.
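
The offline/online split handled by the crypto store can be illustrated with a toy queue of pre-generated triple shares. `ToyPrimitiveStore` and `provide_triples` are hypothetical stand-ins that only mimic the idea; they are not PySyft's `PrimitiveStorage`, `provide_primitives` or `get_keys`.

```python
import random
from collections import deque

Q = 2**32

def share(secret, n_parties=2):
    shares = [random.randrange(Q) for _ in range(n_parties - 1)]
    shares.append((secret - sum(shares)) % Q)
    return shares

class ToyPrimitiveStore:
    """Hypothetical per-worker store of Beaver-triple shares (NOT PySyft's PrimitiveStorage)."""

    def __init__(self):
        self.triples = deque()

    def add(self, triple_share):
        self.triples.append(triple_share)

    def pop(self):
        # online phase: each multiplication consumes one stored triple share
        if not self.triples:
            raise RuntimeError("no triples left: run the offline phase again")
        return self.triples.popleft()

def provide_triples(stores, n):
    """Offline phase of a toy crypto provider: create n triples and give
    every worker's store its own share of each triple."""
    for _ in range(n):
        a, b = random.randrange(Q), random.randrange(Q)
        a_sh = share(a, len(stores))
        b_sh = share(b, len(stores))
        c_sh = share((a * b) % Q, len(stores))
        for store, ai, bi, ci in zip(stores, a_sh, b_sh, c_sh):
            store.add((ai, bi, ci))

sam_store, kelly_store = ToyPrimitiveStore(), ToyPrimitiveStore()
provide_triples([sam_store, kelly_store], n=3)   # offline: before any request arrives
print(len(sam_store.triples), len(kelly_store.triples))  # 3 3
```

Because both stores are filled and consumed in the same order, the i-th share popped by each worker belongs to the same triple.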

### Extra: Garbled Circuits 
Maybe you've come across this term a couple of times in the OpenMined workspace or in SMPC discussions. *On a high level, what is garbled-circuit-based SMPC, or more specifically, what are Garbled Circuits?* <br>
Garbled Circuits is a **2-Party-Computation** protocol (i.e. only two parties are involved). The name comes from the way a function is encrypted: the function is represented as a **circuit** of logical gates (e.g. AND, XOR, etc.), each of which can be described by a truth table listing which inputs lead to which outputs. To encrypt ("garble") a gate, every possible wire value is replaced by a random key (label), each output entry of the truth table is encrypted under the keys of its inputs, and the rows are shuffled into a random order, which gives the protocol its name. <br> 
The details of the protocol (see [Wikipedia](https://en.wikipedia.org/wiki/Garbled_circuit) or [this](https://wiki.mpcalliance.org/garbled_circuit.html) comprehensive article from the MPC Wiki) are rather straightforward. The important thing to note is that this technique is **very flexible**, since essentially any function that can be written as a boolean circuit can be computed this way, but it is also **very inefficient**. That's why Garbled Circuits are often used on top of more optimized protocols (such as SPDZ) to extend the variety of functions these protocols support, while not slowing down most computations too much. <br>
Nevertheless, researchers look for alternative, more efficient protocols that work without Garbled Circuits, at the cost of supporting a smaller variety of functions. *Examples of such protocols, tailored to computations that are popular in Machine Learning models, are secureNN and FSS, which PySyft supports and which I'll go over in the following sections!*
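
To illustrate the garbling idea, here is a toy Yao-style garbling of a single AND gate. Assumptions: SHA-256 serves as the encryption of each row, an all-zero tag lets the evaluator recognise the one row it can decrypt, and the oblivious transfer that would deliver the evaluator's input labels in a real protocol is omitted. It is a didactic sketch, not a secure implementation.

```python
import hashlib
import itertools
import os
import secrets

LABEL_LEN = 16
TAG = bytes(LABEL_LEN)  # all-zero tag that marks the correctly decrypted row

def pad(k_a, k_b):
    """32-byte one-time pad derived from the two input labels (SHA-256 as a stand-in)."""
    return hashlib.sha256(k_a + k_b).digest()

def xor(x, y):
    return bytes(i ^ j for i, j in zip(x, y))

def garble_and():
    # one random label per wire and per bit value; labels hide which bit they stand for
    labels = {w: (os.urandom(LABEL_LEN), os.urandom(LABEL_LEN)) for w in ("a", "b", "out")}
    table = []
    for va, vb in itertools.product((0, 1), repeat=2):
        out_label = labels["out"][va & vb]
        # encrypt the output label (plus tag) under the two input labels
        table.append(xor(pad(labels["a"][va], labels["b"][vb]), out_label + TAG))
    secrets.SystemRandom().shuffle(table)  # shuffled rows: their order no longer leaks the inputs
    return labels, table

def evaluate(table, label_a, label_b):
    """Evaluator: holds one label per input wire and learns only the output label."""
    p = pad(label_a, label_b)
    for row in table:
        plain = xor(row, p)
        if plain[LABEL_LEN:] == TAG:
            return plain[:LABEL_LEN]
    raise ValueError("no row could be decrypted")

labels, table = garble_and()
out = evaluate(table, labels["a"][1], labels["b"][1])
print(out == labels["out"][1])  # True: 1 AND 1 = 1, yet only random labels were seen
```

The cost is visible even here: every gate needs a table of encrypted rows and several hash evaluations, whereas a SPDZ addition is a single local modular addition.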

# 2. SecureNN Protocol
* Introduced in [paper](https://eprint.iacr.org/2018/442.pdf): *SecureNN: 3-Party Secure Computation for Neural Network Training, by Sameer Wagh, Divya Gupta, and Nishanth Chandran, in Proceedings on Privacy Enhancing Technologies,  2019* 
* **The protocol is designed for computation on 2 parties with 1 crypto-worker.** (There can still be multiple data owners, but the computations are done on 2 servers: 2-party additive sharing.)

## High-Level concept 
SecureNN uses the SPDZ protocol for linear layers (Beaver triples, etc.) and contains multiple efficient protocols for common non-linearities, as described below. Compared to earlier work, SecureNN computes the non-linearities without Garbled Circuits (unlike SecureML, which was considered the state-of-the-art ML protocol before SecureNN). This also means that "interconversion protocols", which bridge between the encoding needed for SPDZ and the encoding needed for Garbled Circuits, aren't necessary, which decreases computation time further. <br>
*In general SecureNN is therefore a lot faster than SecureML and other garbled-circuit-based (for non-linearities) protocols.* 

## In detail 
### Possible functions - High Level (Non exhaustive list)
Building on the supported low-level computations (see the "curious section" below), the following standard model layers, optimizers and loss functions can be used. *This should provide a useful summary, but doesn't claim to be exhaustive (feel free to add important items that can be composed of these low-level computations).*

* **Model Architecture:** Includes the possibility for encrypted computation of the derivative (needed for backprop)
    * **Linear Layers** 
      * Matrix Multiplication and Convolutions (in CNNs e.g.)
      * Average Pooling 
      * Batch-Norm/Normalization (Division of two private variables in general is possible)
      * Dropout (*with the help of Select Share, certain computation results can be ignored, or the inputs to some neurons simply set to zero.*)
    * **Non-Linear Layers:** 
      * Max Pooling 
      * ReLU, Leaky ReLU, Piece-wise linear activation functions 
      * Argmax
        * **Beware** of using `argmax()` on too many classes, see FSS section for more details. 
     
* **Optimizers:** 
  * SGD (with Momentum)
  * ADAM (Momentum + RMSProp) - *possible in principle, as division (and element-wise multiplication) is also supported*
    * **Not** yet implemented

* **Loss-functions:** *No native torch losses are compatible, they have to be **manually** implemented!* 
  * L1-Loss - *as max() is possible*
  * MSE - *linear computation and power can be computed with SPDZ* 
  * Hinge-Loss (Linear Classification with Soft-Margin-SVM) - *as max(0, t) is possible*
  * Cross-Entropy-Loss
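
Most of the non-linear items above reduce to one comparison: is an encoded value >= 0? The plaintext sketch below shows why this is a bit test in the fixed-point ring: negatives wrap around to the upper half, so the most significant bit (MSB) is the sign. It assumes a 32-bit ring and base-10 fixed point with 4 fractional digits; helper names are illustrative. SecureNN computes this kind of bit on secret shares with dedicated sub-protocols, which the sketch does not model.

```python
Q_BITS = 32
Q = 2**Q_BITS
SCALE = 10**4  # base 10, precision_fractional=4 (assumption)

def encode(x):
    """Float -> fixed-point ring element; negatives land in the upper half."""
    return round(x * SCALE) % Q

def decode(v):
    return (v - Q if v >= Q // 2 else v) / SCALE

def msb(v):
    """Most significant bit of the ring element: 1 exactly for encoded negatives."""
    return v >> (Q_BITS - 1)

def drelu(v):
    """Derivative of ReLU (1 if x >= 0 else 0), which backprop also needs."""
    return 1 - msb(v)

def relu(v):
    # ReLU(x) = x * DReLU(x); on secret shares this product needs an extra protocol step
    return (v * drelu(v)) % Q

values = [-2.5, -0.0001, 0.0, 1.75]
print([decode(relu(encode(x))) for x in values])  # [0.0, 0.0, 0.0, 1.75]
```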

### Security Guarantees 
* Full Security includes **privacy and correctness**
* The following guarantees hold for all settings where there is a majority of honest participating servers (**Not in dishonest majority setting!**)
1. **Full Security for semi-honest corruption of a server** 
   * Privacy and correctness of the data are guaranteed if a server is *honest-but-curious*, meaning that it follows the protocol but tries to infer as much information as possible about the data it sees. 
2. **Privacy against a malicious server** 
   * Even a server that doesn't follow the given protocol can't learn anything about the inputs and outputs of the other (honest) servers, *given that the majority of the participants is honest!* This is a common assumption in the malicious case, because in a real setting deviations from the protocol can be countered with additional measures, as described below. 
3. Potential add-on: **Security with Abort**
   * Protection against malicious servers can be achieved by adding [MAC authentication](https://en.wikipedia.org/wiki/Message_authentication_code) to the protocol. This would allow aborting the protocol as soon as one of the servers stops following it. 


### Performance Evaluation - Important Metrics 
* Division is possible but very slow! 
* Important Metrics: 
  * Round Complexity - *How many rounds the protocol involves, where one round is one exchange of messages.* 
  * Communication Complexity - *How many bits are exchanged during the protocol.* 
* **DETAILS:** see "curious section" below

In [7]:
class SimpleCNN_SNN(SimpleCNN_Base):
    def __init__(self): 
        super(SimpleCNN_SNN, self).__init__()
        
        # Conv-Layers, Lin-Layers, Dropout from base class 
        # as the functionality stays the same for FSS, SNN 
        # both based on SPDZ 
 
        # uncomment these to use MaxPool instead of the AvgPool layers from the parent class
        #self.pool1 = nn.MaxPool2d(2, 2)
        #self.pool2 = nn.MaxPool2d(2, 2)
        #self.pool3 = nn.MaxPool2d(2, 2)
    
        # Forward function doesn't change 
        
simple_model_snn = SimpleCNN_SNN()

In [8]:
# create encrypted data_loader and encrypted model 
batch_size = 32
# To make it quick we want to train only for 2 batches
nr_batches = 2
nr_samples = batch_size * nr_batches

encryption_kwargs = dict(
    workers=workers, 
    crypto_provider=crypto_provider, 
    protocol="snn", 
    requires_grad=True,
    precision_fractional= 4
)

start_time = time.time()

# Note: Here we one-hot-encode the targets, because the manual CrossEntropyLoss expects one-hot targets
encrypted_train_loader_snn = [
        (data.encrypt(**encryption_kwargs), one_hot_of(target).encrypt(**encryption_kwargs))
        for i, (data, target) in enumerate(cifar10_train_loader)
        if i < nr_batches
    ]

# Note: Here we don't one-hot-encode targets, because we use targets only to calculate accuracy 
encrypted_test_loader_snn = [
        (data.encrypt(**encryption_kwargs), target.encrypt(**encryption_kwargs))
        for i, (data, target) in enumerate(cifar10_test_loader)
        if i < nr_batches
    ]

encrypted_model_snn = simple_model_snn.encrypt(**encryption_kwargs)

print(f"Duration for encryption: {time.time() - start_time :.4} sec.")

Duration for encryption: 14.94 sec.


In [9]:
## Training ##

nr_epochs = 2

# Don't forget to also convert the optimizer's parameters to fixed precision (integers instead of rational numbers), so that the weights stay integers too
optimizer_snn = optim.SGD(
    encrypted_model_snn.parameters(), lr=0.01, momentum=0.9).fix_precision(precision_fractional=4) 

epoch_loss = 0
epoch_time = 0
encrypted_model_snn.train()

for epoch in range(nr_epochs): 
    # Start timer 
    epoch_time = time.time()
    for i, (input, target) in enumerate(encrypted_train_loader_snn): 
        print(f"Batch Nr. {i}")
        optimizer_snn.zero_grad()
        encrypted_result_snn = encrypted_model_snn(input)
        # Manual MSELoss (alternative)
        # batch_loss_snn = ((encrypted_result_snn - target) ** 2).mean()
        # Manual CELoss
        batch_loss_snn = CrossEntropyLoss(encrypted_result_snn, target)
        batch_loss_snn.backward()
        optimizer_snn.step()
        # .get() returns the loss still in fixed precision; .get().float_precision() decodes it to a float
        epoch_loss += batch_loss_snn.get()
        
    print(f"Average Sample Loss: {epoch_loss/nr_samples}, Epoch Time: {time.time() - epoch_time}")
    epoch_loss = 0


Batch Nr. 0




Batch Nr. 1
Average Sample Loss: 714, Epoch Time: 277.5244369506836
Batch Nr. 0
Batch Nr. 1
Average Sample Loss: 714, Epoch Time: 271.10189628601074


In [10]:
## Testing ##

nr_correct = 0

# Start timer 
start_time = time.time()
encrypted_model_snn.eval()

for i, (input, target) in enumerate(encrypted_test_loader_snn): 
    
    print(f"Batch Nr. {i}")
    
    encrypted_result_snn = encrypted_model_snn(input)
    
    # calculate number of correct predictions (for accuracy) 
    encrypted_preds_snn = encrypted_result_snn.argmax(dim=1)
    nr_correct += (encrypted_preds_snn==target).sum()

batch_times = time.time() - start_time

print("*************")
print(f"Test Accuracy: {nr_correct.decrypt().numpy()/nr_samples},\
      Avg. Time Per Sample: {batch_times/nr_samples :.4} sec.")

Batch Nr. 0
Batch Nr. 1
*************
Test Accuracy: [0.140625],      Avg. Time Per Sample: 4.272 sec.


# 3. Functional Secret Sharing Protocol
* Base [paper](https://link.springer.com/content/pdf/10.1007/978-3-662-46803-6_12.pdf) (first introduction): *Function secret sharing. E. Boyle, N. Gilboa, and Y. Ishai. In EUROCRYPT 2015, pages 337–367, 2015.* 
* **Encryption with n parties possible (mainly n == 2)**

## High-Level Concept 
Like SecureNN, the FSS protocol is based on the SPDZ protocol for the encrypted computation of *linear layers* and provides additional protocols for *non-linearities*. The fundamental difference is that instead of evaluating a public function (e.g. a ReLU activation) at a private value (the secret-shared data from the data sources), FSS evaluates a private function at a public value. This works by first masking the private data (x - r mod Q) and revealing the result, which gives a public value, and then secret-sharing the function itself, shifted by the mask r, so that evaluating it at x - r yields shares of the result for x. <br> 
*How exactly do you secret-share a function? Check out [Théo's Tutorial](https://github.com/OpenMined/PySyft/blob/master/examples/tutorials/Part%2011%20bis%20-%20Encrypted%20inference%20on%20ResNet-18.ipynb) for a nice intro on that!*
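
The "private function at a public value" idea can be illustrated with a deliberately naive construction, swapped in here for the compact PRG-based keys of real FSS: the crypto provider additively secret-shares the entire truth table of the shifted function f_r(y) = f(y + r mod Q), where r is the mask. Since (x - r) + r = x, evaluating the two key shares at the public value y = x - r and adding the results gives f(x). The tiny ring (Q = 2**8), the comparison function `geq_zero` and the helpers `gen_keys`/`eval_key` are assumptions for illustration; these keys grow linearly with Q, which is exactly what real FSS avoids.

```python
import random

N_BITS = 8
Q = 2**N_BITS  # tiny ring so that a full truth table fits in memory

def geq_zero(v):
    """Public function f: 1 if v encodes a value >= 0 (lower half of the ring), else 0."""
    return 1 if v < Q // 2 else 0

def gen_keys(f, r):
    """Toy key generation by the crypto provider (hypothetical helper).

    Additively secret-shares the full truth table of the shifted function
    f_r(y) = f((y + r) mod Q) between two parties. Each key alone is uniformly
    random. Real FSS keys are compact (PRG-based); these have size Q."""
    k0 = [random.randrange(Q) for _ in range(Q)]
    k1 = [(f((y + r) % Q) - k0[y]) % Q for y in range(Q)]
    return k0, k1

def eval_key(key, y):
    """Each party evaluates its key at the PUBLIC masked value y, without interaction."""
    return key[y]

x = (-5) % Q                  # private input: -5 encoded in the ring
r = random.randrange(Q)       # random mask
y = (x - r) % Q               # public masked value (in the protocol: opened from shares)
k0, k1 = gen_keys(geq_zero, r)
result = (eval_key(k0, y) + eval_key(k1, y)) % Q
print(result)  # 0, because -5 < 0
```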
 

### Base Implementation 
The concept of Function Secret Sharing (FSS) was first introduced in the 2015 paper cited above; the base construction was then improved in the follow-up [paper](https://eprint.iacr.org/2018/707): *Function Secret Sharing - Improvements and Extensions, by Elette Boyle, Niv Gilboa and Yuval Ishai, 2018* (updated version)

### AriaNN Implementation 
The PySyft Implementation is based on the [paper](https://arxiv.org/abs/2006.04593): *ARIANN: Low-Interaction Privacy-Preserving Deep Learning via Function Secret Sharing, by Théo Ryffel, David Pointcheval, Francis Bach, 2020* 

## In Detail

### Possible functions - High Level (Non exhaustive list)
Building on the supported low-level computations (see the "curious section" for more), the following standard model layers, optimizers and loss functions can be used. *This should provide a useful summary, but doesn't claim to be exhaustive (feel free to add important items that can be composed of these low-level computations).*

* **Model Architecture:** including the encrypted computation of the derivatives (needed for backpropagation)
    * **Linear Layers** with SPDZ using Beaver triples 
      * Matrix Multiplication and Convolutions (e.g. in CNNs)
      * Average Pooling 
      * Batch-Norm/Normalization (approximated with Newton's method) 
      * Dropout 
    * **Non-Linear Layers:** mainly based on direct comparison capabilities
      * Max Pooling 
      * ReLU, Leaky ReLU, Piece-wise linear activation functions 
      * Argmax 
        * **Beware** of using `.argmax()` on too many classes! With FSS encryption, 250 classes (~128 sec. to compute) was the maximum before my computer crashed due to excessive memory consumption: the encrypted comparison operation isn't yet memory-efficient enough for datasets like ImageNet (1000 classes).
     
* **Optimizers:**
  * SGD (with Momentum)
  * ADAM
    * Not implemented yet (Oct. 2020)

* **Loss-functions:** *Native torch losses are not compatible; they have to be implemented **manually**!*  
  * L1-Loss 
  * MSE 
  * Hinge Loss
  * Cross-Entropy-Loss, Logistic Loss
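Batch normalization needs 1/sqrt(var + eps), which cannot be computed directly on shares. Newton's method approximates it with additions and multiplications only, via the iteration y ← y·(3 - v·y²)/2, which converges to 1/sqrt(v) whenever the initial guess satisfies 0 < y0 < sqrt(3/v). The plaintext sketch below illustrates the idea; the initial guess, the iteration count and the function names are assumptions for illustration, not PySyft's actual settings.

```python
def inv_sqrt_newton(v, y0=0.1, iterations=30):
    """Approximate 1 / sqrt(v) with Newton's method, using only additions and multiplications.

    Converges for 0 < y0 < sqrt(3 / v); y0 = 0.1 therefore works for v < 300 (assumed range).
    """
    y = y0
    for _ in range(iterations):
        y = 0.5 * y * (3.0 - v * y * y)
    return y


def normalize(xs, eps=1e-5):
    """Batch-norm style normalization of a list (without the learnable scale and shift)."""
    n = len(xs)
    mean = sum(xs) / n                       # division by the public batch size n
    var = sum((x - mean) ** 2 for x in xs) / n
    inv_std = inv_sqrt_newton(var + eps)     # replaces the private division by sqrt(var + eps)
    return [(x - mean) * inv_std for x in xs]


print(inv_sqrt_newton(4.0))           # approximately 0.5
print(normalize([1.0, 2.0, 3.0, 4.0]))
```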

### Security Guarantees 
* Full Security includes **privacy and correctness**
* The following guarantees hold for all settings where there is a majority of honest participating servers (**not in a dishonest-majority setting!**)
1. **Full Security for semi-honest corruption of a server** 
   * Privacy and correctness of the data are secured if a server is *honest-but-curious*, meaning that it follows the protocol but tries to infer as much information as possible about the data it sees. 
2. Potential Add-on: **Security with Abort**
   * Protection against malicious servers can be guaranteed by adding [MAC authentication](https://en.wikipedia.org/wiki/Message_authentication_code) to the protocol. This would allow aborting the protocol as soon as one of the servers stops following the protocol. 
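One common way to add such MAC checks, used in the SPDZ family of protocols, is an information-theoretic MAC: next to the shares of each value x, the servers hold shares of α·x, where α is a secret global key. The sketch below is a plaintext simulation under assumptions (toy prime field, hypothetical function names, and the commit-then-reveal step of real SPDZ is omitted); it is not the PySyft implementation.

```python
import random

P = 2**61 - 1  # toy prime field (a Mersenne prime); SPDZ-style MACs live in a field


def share(v, n=2):
    """Additively secret-share v in the field among n servers."""
    shares = [random.randrange(P) for _ in range(n - 1)]
    shares.append((v - sum(shares)) % P)
    return shares


# Global MAC key alpha: secret-shared, never revealed to any single server
alpha = random.randrange(1, P)
alpha_shares = share(alpha)


def authenticate(x):
    """Dealer (hypothetical): shares of x plus shares of its MAC alpha * x."""
    return share(x), share(alpha * x % P)


def open_and_check(x_shares, mac_shares):
    """Reveal x, then verify the MAC; abort if any share was tampered with."""
    x = sum(x_shares) % P
    # Each server computes sigma_i = m_i - alpha_i * x (real SPDZ commits to these before revealing)
    sigmas = [(m - a * x) % P for m, a in zip(mac_shares, alpha_shares)]
    if sum(sigmas) % P != 0:
        raise RuntimeError("MAC check failed: aborting the protocol")
    return x


x_sh, m_sh = authenticate(42)
print(open_and_check(x_sh, m_sh))  # 42

x_sh[0] = (x_sh[0] + 1) % P  # a malicious server adds an error to its share
try:
    open_and_check(x_sh, m_sh)
except RuntimeError as err:
    print(err)  # MAC check failed: aborting the protocol
```

A cheating server that adds an error e to its share would also have to add α·e to its MAC share to pass the check, which it cannot do without knowing α.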

### Performance Evaluation
* FSS requires only few communication rounds, but it is computationally more intensive than SecureNN (because it calls a pseudorandom generator (PRG) many times).
* Important Metrics: 
  * Round Complexity - *How many steps does the protocol involve? One step corresponds to one exchange of messages.* 
  * Communication Complexity - *How many bits are exchanged during the protocol?* 
* **DETAILS:** see the Curious Section

In [11]:
class SimpleCNN_FSS(SimpleCNN_Base):
    def __init__(self):
        super().__init__()

        # Conv layers, linear layers and dropout are inherited from the base class,
        # as their functionality is the same for FSS and SNN (both are based on SPDZ)

        # Max pooling (comparison-based); comment these out to use AvgPool from the parent class
        self.pool1 = nn.MaxPool2d(2, 2)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.pool3 = nn.MaxPool2d(2, 2)

    # The forward function is inherited unchanged from the base class


simple_model_fss = SimpleCNN_FSS()

In [13]:
from itertools import islice

# Create the encrypted data loaders and the encrypted model
batch_size = 32
# To keep it quick, we only use 2 batches
nr_batches = 2
nr_samples = batch_size * nr_batches

encryption_kwargs = dict(
    workers=workers,
    crypto_provider=crypto_provider,
    protocol="fss",
    requires_grad=True,
    precision_fractional=4,
)

start_time = time.time()

# Note: here we one-hot encode the targets for multi-class one-vs-all classification training
# (islice stops after nr_batches instead of iterating over the whole loader)
encrypted_train_loader = [
    (data.encrypt(**encryption_kwargs), one_hot_of(target).encrypt(**encryption_kwargs))
    for data, target in islice(cifar10_train_loader, nr_batches)
]

# Note: here we don't one-hot encode the targets, because we only use them to calculate the accuracy
encrypted_test_loader = [
    (data.encrypt(**encryption_kwargs), target.encrypt(**encryption_kwargs))
    for data, target in islice(cifar10_test_loader, nr_batches)
]

encrypted_model = simple_model_fss.encrypt(**encryption_kwargs)

print(f"Duration for encryption: {time.time() - start_time}")

Duration for encryption: 14.699193954467773
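Conceptually, `.encrypt(..., precision_fractional=4)` first turns the floats into integers and then splits them into additive shares held by the workers (two in this notebook). The plaintext sketch below assumes base-10 fixed precision and a ring of size 2^64; both are assumptions about the encoding, and all function names are hypothetical.

```python
import random

BASE = 10       # assumption: base-10 fixed-precision encoding
PRECISION = 4   # mirrors precision_fractional=4 above
RING = 2**64    # assumption: size of the ring the shares live in


def fix_precision(value):
    """Scale a float to an integer; negative numbers wrap around to the upper half of the ring."""
    return round(value * BASE**PRECISION) % RING


def float_precision(encoded):
    """Undo the encoding: map the upper half back to negatives and rescale."""
    signed = encoded if encoded < RING // 2 else encoded - RING
    return signed / BASE**PRECISION


def additive_share(encoded, n_workers=2):
    """Split an encoded value into n random-looking shares that sum to it modulo RING."""
    shares = [random.randrange(RING) for _ in range(n_workers - 1)]
    shares.append((encoded - sum(shares)) % RING)
    return shares


def reconstruct(shares):
    return sum(shares) % RING


shares = additive_share(fix_precision(-0.1234))
print(float_precision(reconstruct(shares)))  # -0.1234

# Multiplying two encoded values doubles the scale (10^8 instead of 10^4) and requires truncation
a, b = fix_precision(0.5), fix_precision(0.5)
print(a * b)                                        # 25000000
print(float_precision(a * b // BASE**PRECISION))    # 0.25 (naive truncation, non-negative plaintext only)
```

This doubling of the scale is why the protocols truncate after multiplications, and why the optimizer's hyperparameters are converted with `.fix_precision()` in the training cell, so that they can be combined with the integer-encoded weights.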


In [14]:
## Training ##

nr_epochs = 2

# Also convert the optimizer's hyperparameters (lr, momentum) to fixed precision,
# so that the weight updates stay integers like the encrypted weights
optimizer = optim.SGD(encrypted_model.parameters(), lr=0.01, momentum=0.9).fix_precision(precision_fractional=4)

encrypted_model.train()

for epoch in range(nr_epochs):
    epoch_loss = 0
    epoch_start = time.time()  # start timer
    for i, (data, target) in enumerate(encrypted_train_loader):
        print(f"Batch Nr. {i}")
        optimizer.zero_grad()
        encrypted_result = encrypted_model(data)
        # Manual MSE loss
        # batch_loss = ((encrypted_result - target) ** 2).sum()
        # Manual cross-entropy loss
        batch_loss = CrossEntropyLoss(encrypted_result, target)
        batch_loss.backward()
        optimizer.step()
        epoch_loss += batch_loss.get()

    print(f"Average Sample Loss: {epoch_loss / nr_samples}, Epoch Time: {time.time() - epoch_start}")


Batch Nr. 0
Batch Nr. 1
Average Sample Loss: 718, Epoch Time: 76.00571823120117
Batch Nr. 0
Batch Nr. 1
Average Sample Loss: 717, Epoch Time: 76.34843015670776
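The training cell calls a manually implemented `CrossEntropyLoss` (defined earlier in the notebook). To illustrate what "manual" means here, the plaintext sketch below writes the listed losses per sample using only additions, multiplications, comparisons, `exp` and `log`, i.e. operations the protocols support. These are hypothetical helpers on Python lists, not the notebook's implementation; the targets are one-hot vectors, as in the training loader. Cross-entropy is written as log-sum-exp minus the target score, which equals -log(softmax) for a one-hot target and avoids dividing by a private value.

```python
import math


def ge_zero(v):
    """Stand-in for the encrypted comparison primitive: 1 if v >= 0, else 0."""
    return 1.0 if v >= 0 else 0.0


def mse_loss(pred, target):
    # Summed squared error, like the commented-out MSE in the training cell
    return sum((p - t) ** 2 for p, t in zip(pred, target))


def l1_loss(pred, target):
    # |d| = d * (2 * [d >= 0] - 1): absolute value from one comparison and multiplications
    return sum((p - t) * (2 * ge_zero(p - t) - 1) for p, t in zip(pred, target))


def hinge_loss(pred, target_one_hot, margin=1.0):
    # One-vs-all hinge with labels y in {-1, +1}: max(0, z) = z * [z >= 0] with z = margin - y * p
    total = 0.0
    for p, t in zip(pred, target_one_hot):
        z = margin - (2 * t - 1) * p
        total += z * ge_zero(z)
    return total


def cross_entropy_loss(pred, target_one_hot):
    # -log(softmax(pred)[c]) = log(sum_k exp(pred_k)) - pred_c for the true class c
    log_sum_exp = math.log(sum(math.exp(p) for p in pred))
    return log_sum_exp - sum(t * p for p, t in zip(pred, target_one_hot))


print(mse_loss([1.0, 2.0], [0.0, 0.0]))            # 5.0
print(cross_entropy_loss([0.0, 0.0], [1.0, 0.0]))  # 0.6931471805599453
```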


In [15]:
## Testing ##

nr_correct = 0

# Start timer
start_time = time.time()
encrypted_model.eval()

for i, (data, target) in enumerate(encrypted_test_loader):
    print(f"Batch Nr. {i}")

    encrypted_result = encrypted_model(data)

    # Count the correct predictions (for the accuracy); argmax uses encrypted comparisons
    encrypted_preds = encrypted_result.argmax(dim=1)
    nr_correct += (encrypted_preds == target).sum()

test_time = time.time() - start_time

print("*************")
print(f"Test Accuracy: {nr_correct.decrypt().numpy() / nr_samples}, "
      f"Avg. Time Per Sample: {test_time / nr_samples:.4} sec.")

Batch Nr. 0
Batch Nr. 1
*************
Test Accuracy: [0.109375],      Avg. Time Per Sample: 1.168 sec.
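`.argmax()` can be built purely from the comparison primitive. One straightforward way is sketched below in plaintext; it is an assumption for illustration, and PySyft's implementation may differ. Element i is the maximum if it is greater than or equal to all other n - 1 elements, so the approach needs n·(n - 1) comparisons per sample, i.e. 999,000 for the 1000 ImageNet classes. If an implementation evaluates that many encrypted comparisons at once, memory quickly becomes the bottleneck, which would be consistent with the memory problems noted above.

```python
def argmax_one_hot(scores):
    """One-hot vector marking the maximum, using only pairwise comparisons (ties mark several entries)."""
    n = len(scores)
    one_hot = []
    for i in range(n):
        # n - 1 comparisons "scores[i] - scores[j] >= 0" for element i
        wins = sum(1 for j in range(n) if j != i and scores[i] - scores[j] >= 0)
        # Equality test of the (private) win count against the public value n - 1
        one_hot.append(1 if wins == n - 1 else 0)
    return one_hot


def argmax_index(scores):
    """Turn the one-hot vector into an index with a linear sum using public weights 0..n-1."""
    return sum(i * b for i, b in enumerate(argmax_one_hot(scores)))


print(argmax_one_hot([0.1, 0.7, 0.2]))  # [0, 1, 0]
print(argmax_index([0.1, 0.7, 0.2]))    # 1
```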


In [16]:
encrypted_result

(Wrapper)>AutogradTensor>FixedPrecisionTensor>[AdditiveSharingTensor]
	-> [PointerTensor | me:56204438389 -> sam:47442428026]
	-> [PointerTensor | me:60568323051 -> kelly:24972194743]
	*crypto provider: crypto_provider*

# 4. Notable Results - Comparison 
Some example differences between the two protocols. Of course, these results highly depend on how the different parameters are tuned, and only very few training steps are considered. *This is a good place for you to start experimenting yourself!* 
<br>
* Experiments: 64 samples with batch_size of 32, 2 training epochs 
  * **FSS:** 
    * **SGD + Momentum:** lr=0.01, momentum=0.9, **MSE-Loss**
      * Avg.-Pool: Test Accuracy: 0.09380, Avg. Time Per Sample: 0.833 sec. 
      * Max.-Pool: Test Accuracy: 0.09375, Avg. Time Per Sample: 1.433 sec.
    * **SGD + Momentum:** lr=0.01, momentum=0.9, **Cross-Entropy-Loss**
      * Avg.-Pool: Test Accuracy: **0.140625**, Avg. Time Per Sample: **0.733 sec.**
      * Max.-Pool: Test Accuracy: 0.140625, Avg. Time Per Sample: 1.427 sec.
    
  * **SNN:** 
    * **SGD + Momentum:** lr=0.01, momentum=0.9, **MSE-Loss**
      * Avg.-Pool: Test Accuracy: 0.09375, Avg. Time Per Sample: 4.624 sec.
      * Max.-Pool: Test Accuracy: 0.09375, Avg. Time Per Sample: 7.83 sec.
    * **SGD + Momentum:** lr=0.01, momentum=0.9, **Cross-Entropy-Loss**
      * Avg.-Pool: Test Accuracy: **0.140625**, Avg. Time Per Sample: **4.242 sec.**
      * Max.-Pool: Test Accuracy: 0.140625, Avg. Time Per Sample: 7.02 sec.

*Note: to reproduce these experiments, you might need to restart the runtime between training different models or using different encryptions.*

# 5. Curious Section
This section contains some extra details for those who want to understand the implementations of the protocols or who want to get deeper into the theoretical basics and respective papers. 

## SecureNN
### Possible computations - Low Level
SecureNN is not a general-purpose protocol that can compute all kinds of functions used in training NNs. By giving up this flexibility (e.g. by not using garbled circuits), it gains efficiency, which is a vital criterion for the real-world applicability of an encryption protocol.

* Matrix Multiplication (SPDZ; the necessary Beaver triples can also be generated for matrix multiplication)
* Select Share
  * Select one out of multiple private variables, freshly masked and re-shared for a new computation. Used e.g. for max pooling (i.e. selecting the maximum among the kernel elements).
* Private Compare
  * Compare a public variable with a private variable. Used e.g. for the computation of the MSB (see below), which in turn is needed for the ReLU function (ReLU(x) = max(x, 0)).
* Share Convert
  * Convert private variables from one number space (a "ring") to another. Remember that during computation the numbers are plain integers in the ring $\mathbb{Z}_L$ with $L = 2^{\ell}$, i.e. encoded by bit sequences of length $\ell$. After "share convert" they are shares in the ring $\mathbb{Z}_{L-1}$, whose size $L-1$ is odd. Used e.g. for the computation of the derivative of ReLU.
* Compute MSB (Most Significant Bit)
  * Efficient reading of the sign bit (is the input integer positive or negative?), mainly to compute the derivative of the ReLU function.
* Non-linear functions:
  * torch.log(), torch.exp()
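Why convert to an odd ring at all? Roughly, the MSB computation in the SecureNN paper relies on the following identity: for an odd modulus N, a value a < N/2 has an even double 2a mod N, while for a > N/2 the double wraps around and becomes odd. So the MSB of a equals the least significant bit (LSB) of 2a in the odd ring, which the protocol can then extract from the shares with the masking and private-compare building blocks listed above. The plaintext check below (toy bit width, hypothetical names, no secret sharing involved) verifies this identity for all representable values.

```python
L_BITS = 8          # toy bit width l
L = 2**L_BITS       # original ring Z_L
N = L - 1           # odd ring Z_{L-1} produced by "share convert"


def msb(a):
    """Most significant bit of a in Z_L (1 means 'negative' in two's complement)."""
    return (a >> (L_BITS - 1)) & 1


def lsb_of_double(a):
    """Least significant bit of 2a computed in the odd ring Z_{L-1}."""
    return ((2 * a) % N) & 1


# The identity holds for every a in Z_L except a = L - 1, which Z_{L-1} cannot represent
assert all(msb(a) == lsb_of_double(a) for a in range(L - 1))
print(msb(200), lsb_of_double(200))  # 1 1
print(msb(5), lsb_of_double(5))      # 0 0
```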
  
### Security Guarantees 
* Full Security includes **privacy and correctness**
* The following guarantees hold for all settings where there is a majority of honest participating servers (**not in a dishonest-majority setting!**)
    1. **Full Security for semi-honest corruption of a server** 
        * Privacy and correctness of the data are secured if a server is *honest-but-curious*, meaning that it follows the protocol but tries to infer as much information as possible about the data it sees. 
    2. **Privacy against malicious server** 
        * Even a server which doesn't follow the given protocol can't learn anything about the inputs and outputs of the other (honest) servers, *given that the majority of the participants are honest!* This is a common assumption for the malicious case because in a real setting deviating from the given protocol could be prevented with additional measures, as seen below.
    3. Potential Add-on: **Security with Abort**
        * Protection against malicious servers can be guaranteed by adding [MAC authentication](https://en.wikipedia.org/wiki/Message_authentication_code) to the protocol. This would allow aborting the protocol as soon as one of the servers doesn't follow the protocol anymore.

### Performance Evaluation - Important Metrics 
* Division is possible but very slow! 
* Important Metrics: 
    * Round Complexity - *How many steps does the protocol involve? One step corresponds to one exchange of messages.* 
    * Communication Complexity - *How many bits are exchanged during the protocol?*

See the following tables from the SecureNN paper. Their caption reads:
> "The function $\text{Linear}_{m,n,v}$ denotes a matrix multiplication of dimension $m \times n$ with $n \times v$. $\text{Conv2d}_{m,i,f,o}$ denotes a convolutional layer with input $m \times m$, $i$ input channels, a filter of size $f \times f$, and $o$ output channels. $\ell_D$ denotes precision of bits. $\text{Maxpool}_n$ and $\text{DMP}_n$ denotes Maxpool and its derivative over $n$ elements. All communication is measured for $\ell$-bit inputs and $p$ denotes the field size (which is 67 in our case)"

<table>
    <tr>
        <td><img src="../material/SNN_complexity.png" alt="SNN complexities" width="90%" height="auto" /></td>
        <td><img src="../material/SNN_dependencies.png" alt="SNN dependencies" width="90%" height="auto" /></td>
    </tr>
</table>

## FSS 
### Possible computations - Low Level 
Like SecureNN, the FSS protocol currently implemented in PySyft focuses on efficiency for the computations that are common in Machine Learning: it gives up some flexibility in exchange for lower computation time on these computations. <br>
The following computations are supported: 

* Matrix Multiplication (SPDZ)
* Equality Test 
  * Checks whether a public value equals a private (i.e. secret-shared) value 
* Comparison 
  * Inequality between a public value/expression and private value/expression
* Non-linear functions
  * torch.log(), torch.exp()
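Both SecureNN and FSS rely on SPDZ with Beaver triples for the linear layers. The core trick for a single multiplication is sketched below as a plaintext simulation of two servers (the toy ring size and the function names are assumptions, not PySyft API). The crypto provider deals shares of random a, b and c = a·b; the servers reveal only the masked differences x - a and y - b and then compute their shares of x·y locally. Matrix multiplication works the same way with a triple of matrices (A, B, C = A·B).

```python
import random

Q = 2**32  # toy ring size (assumption)


def share(v, n=2):
    """Additively secret-share v in Z_Q among n servers."""
    shares = [random.randrange(Q) for _ in range(n - 1)]
    shares.append((v - sum(shares)) % Q)
    return shares


def reveal(shares):
    return sum(shares) % Q


def beaver_triple(n=2):
    """Crypto provider: shares of random a, b and of c = a * b."""
    a, b = random.randrange(Q), random.randrange(Q)
    return share(a, n), share(b, n), share(a * b % Q, n)


def beaver_mul(x_sh, y_sh):
    """Shares of x * y from shares of x and y, using one Beaver triple."""
    a_sh, b_sh, c_sh = beaver_triple(len(x_sh))
    # Public masked values: eps = x - a and delta = y - b reveal nothing about x and y
    eps = reveal([(x - a) % Q for x, a in zip(x_sh, a_sh)])
    delta = reveal([(y - b) % Q for y, b in zip(y_sh, b_sh)])
    # x * y = c + eps * b + delta * a + eps * delta
    z_sh = [(c + eps * b + delta * a) % Q for a, b, c in zip(a_sh, b_sh, c_sh)]
    z_sh[0] = (z_sh[0] + eps * delta) % Q  # the public term is added by one server only
    return z_sh


print(reveal(beaver_mul(share(6), share(7))))  # 42
```

The identity behind the last step is (a + eps)·(b + delta) = a·b + a·delta + eps·b + eps·delta, which is why a single round of revealing masked values suffices per multiplication.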


### Performance Evaluation - Important Metrics 
* FSS requires only few communication rounds, but it is computationally more intensive than SecureNN (because it calls a pseudorandom generator (PRG) many times).
* Important Metrics: 
  * Round Complexity - *How many steps does the protocol involve? One step corresponds to one exchange of messages.* 
  * Communication Complexity - *How many bits are exchanged during the protocol?* 
  
The following table from the AriaNN paper mentioned above gives a good overview.
  
<img src="../material/FSS_complexity_comparison_to_FALCON.png" width="60%" height="auto" alt="FSS complexity" />

[[48]](https://arxiv.org/abs/2004.02229) *Sameer Wagh, Shruti Tople, Fabrice Benhamouda, Eyal Kushilevitz, Prateek Mittal, and Tal Rabin. **Falcon:** Honest-majority maliciously secure framework for private deep learning. arXiv preprint arXiv:2004.02229, 2020.*

# Congratulations!!! - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement toward privacy preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!

### Star PySyft on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.

- [Star PySyft](https://github.com/OpenMined/PySyft)

### Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! You can do so by filling out the form at [http://slack.openmined.org](http://slack.openmined.org)

### Join a Code Project!

The best way to contribute to our community is to become a code contributor! At any time you can go to the PySyft GitHub Issues page and filter for "Projects". This will show you all the top level Tickets giving an overview of what projects you can join! If you don't want to join a project, but you would like to do a bit of coding, you can also look for more "one off" mini-projects by searching for GitHub issues marked "good first issue".

- [PySyft Projects](https://github.com/OpenMined)
- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aissue+is%3Aopen+label%3A%22Good+first+issue+%3Amortar_board%3A%22)

### Donate

If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!

[OpenMined's Open Collective Page](https://opencollective.com/openmined)