# Install CrypTen
The current pip release of this framework is unstable because of a dependency version conflict, so it needs to be installed from source. [Issue link](https://github.com/facebookresearch/CrypTen/issues/391).

Skip this part if you already have CrypTen installed. These steps are specifically for Google Colab.

In [None]:
!git clone https://github.com/facebookresearch/CrypTen.git
%cd CrypTen
# pin to this commit; later commits break a version dependency
!git checkout efe8edad571be1c586d0d9cefc562d562d4e9aa1
# !python setup.py install --user
%pip install -e .

## Check installed version

In [None]:
!pip show crypten

Name: crypten
Version: 0.4.0
Summary: CrypTen: secure machine learning in PyTorch.
Home-page: https://github.com/facebookresearch/CrypTen
Author: Facebook AI Research
Author-email: None
License: MIT licensed, as found in the LICENSE file
Location: /root/.local/lib/python3.8/site-packages/crypten-0.4.0-py3.8.egg
Requires: torch, torchvision, omegaconf, onnx, pandas, pyyaml, tensorboard, future, scipy, sklearn
Required-by: 


## Fix existing bug
[Issue link](https://github.com/facebookresearch/CrypTen/issues/438). Because of the "/config" string in this framework's setup.py, the CrypTen configs are not copied properly. You can either change "/config" to "config" manually or do the following.

Update: Running `pip install -e .` inside the CrypTen directory appears to solve it.

In [None]:
# current setup file doesn't copy the default.yaml correctly in the configs folder
# !cp configs/default.yaml /root/.local/lib/python3.8/site-packages/crypten-0.4.0-py3.8.egg/configs/

# Restart the runtime
You need to restart the kernel runtime to load the newly installed crypten module. If you have already restarted, there is no need to run the prior cells again; you can start from here.

# Tutorial 2: Inside CrypTensors
This notebook is adapted from the original source tutorial [Tutorial_2_Inside_CrypTensors.ipynb](https://github.com/facebookresearch/CrypTen/blob/main/tutorials/Tutorial_2_Inside_CrypTensors.ipynb).

Note: This tutorial is optional, and can be skipped without any loss of continuity to the following tutorials.


In this tutorial, we will take a brief look at the internals of `CrypTensors`.

Using the `mpc` backend, a `CrypTensor` is a tensor encrypted using secure MPC protocols, called an `MPCTensor`. In order to support the mathematical operations required by the `MPCTensor`, CrypTen implements two kinds of secret-sharing protocols: arithmetic secret-sharing and binary secret-sharing. Arithmetic secret sharing forms the basis for most of the mathematical operations implemented by `MPCTensor`. Similarly, binary secret-sharing allows for the evaluation of logical expressions.

In this tutorial, we'll first introduce the concept of a `CrypTensor` <i>ptype</i> (i.e. <i>private-type</i>), and show how to use it to obtain `MPCTensors` that use arithmetic and binary secret shares. We will also describe how each of these <i>ptypes</i> is used, and how they can be combined to implement desired functionality.

In [None]:
# Import the libraries
import crypten
import torch

# Doesn't work on Windows
# Initialize CrypTen
crypten.init()
# Disable OpenMP threads -- needed by @mpc.run_multiprocess, which uses fork
torch.set_num_threads(1)

## <i>ptype</i> in CrypTen
CrypTen defines the `ptype` (for <i>private-type</i>) attribute of an `MPCTensor` to denote the kind of secret-sharing protocol used in the `CrypTensor`. The `ptype` is, in many ways, analogous to the `dtype` of PyTorch. The `ptype` may have two values: 

- `crypten.mpc.arithmetic` for `ArithmeticSharedTensors`
- `crypten.mpc.binary` for `BinarySharedTensors`

We can use the `ptype` attribute to create a `CrypTensor` with the appropriate secret-sharing protocol. For example: 

In [None]:
#Constructing CrypTensors with ptype attribute

#arithmetic secret-shared tensors
x_enc = crypten.cryptensor([1.0, 2.0, 3.0], ptype=crypten.mpc.arithmetic)
print("x_enc internal type:", x_enc.ptype)

#binary secret-shared tensors
y = torch.tensor([1, 2, 1], dtype=torch.int32)
y_enc = crypten.cryptensor(y, ptype=crypten.mpc.binary)
print("y_enc internal type:", y_enc.ptype)

x_enc internal type: ptype.arithmetic
y_enc internal type: ptype.binary


### Arithmetic secret-sharing
Let's look more closely at the `crypten.mpc.arithmetic` <i>ptype</i>. Most of the mathematical operations implemented by `CrypTensors` are implemented using arithmetic secret sharing. As such, `crypten.mpc.arithmetic` is the default <i>ptype</i> for newly generated `CrypTensors`. 

Let's begin by creating a new `CrypTensor` using `ptype=crypten.mpc.arithmetic` to enforce that the encryption is done via arithmetic secret sharing. We can print values of each share to confirm that values are being encrypted properly. 

To do so, we will need to create multiple parties to hold each share. We do this here using the `@mpc.run_multiprocess` function decorator, which we developed to execute crypten code from a single script (as we have in a Jupyter notebook). CrypTen follows the standard MPI programming model: it runs a separate process for each party, but each process runs an identical (complete) program. Each process has a `rank` variable to identify itself.

Note that the sum of the two `_tensor` attributes below is equal to a scaled representation of the input. (Because MPC requires values to be integers, we scale input floats to a fixed-point encoding before encryption.)

In [None]:
import crypten.mpc as mpc
import crypten.communicator as comm 

@mpc.run_multiprocess(world_size=2)
def examine_arithmetic_shares():
    x_enc = crypten.cryptensor([1, 2, 3], ptype=crypten.mpc.arithmetic)
    
    rank = comm.get().get_rank()
    crypten.print(f"\nRank {rank}:\n {x_enc}\n", in_order=True)
        
x = examine_arithmetic_shares()


Rank 0:
 MPCTensor(
	_tensor=tensor([ 5497951077848605038,  8173160353998178085, -3478346136588492155])
	plain_text=HIDDEN
	ptype=ptype.arithmetic
)


Rank 1:
 MPCTensor(
	_tensor=tensor([-5497951077848539502, -8173160353998047013,  3478346136588688763])
	plain_text=HIDDEN
	ptype=ptype.arithmetic
)
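
The relationship between the two shares can be checked by hand. The sketch below copies the `_tensor` values printed above and adds them. The scale factor `2 ** 16` and the 64-bit ring are assumptions inferred from these numbers rather than values read from CrypTen's source, and the helper names `to_signed` and `reconstruct` are hypothetical.

```python
# Shares copied from the rank 0 and rank 1 output above.
rank0_shares = [5497951077848605038, 8173160353998178085, -3478346136588492155]
rank1_shares = [-5497951077848539502, -8173160353998047013, 3478346136588688763]

SCALE = 2 ** 16  # assumption: fixed-point scale inferred from the sums
RING = 2 ** 64   # assumption: shares live in signed 64-bit integers


def to_signed(value):
    """Map an integer into the signed 64-bit range [-2**63, 2**63)."""
    value %= RING
    return value - RING if value >= RING // 2 else value


def reconstruct(shares_a, shares_b):
    """Add the shares element-wise, wrap to 64 bits, and undo the scaling."""
    return [to_signed(a + b) / SCALE for a, b in zip(shares_a, shares_b)]


decoded = reconstruct(rank0_shares, rank1_shares)
print(decoded)  # [1.0, 2.0, 3.0]
```

Neither list of shares reveals anything about the input on its own; only their sum does.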



### Binary secret-sharing
The second type of secret-sharing implemented in CrypTen is binary or XOR secret-sharing. This type of secret-sharing allows greater efficiency in evaluating logical expressions. 

Let's look more closely at the `crypten.mpc.binary` <i>ptype</i>. Most of the logical operations implemented by `CrypTensors` are implemented using binary secret sharing. We typically use this type of secret-sharing when we want to evaluate binary operators (i.e. `^ & | >> <<`, etc.) or logical operations (like comparators).

Let's begin by creating a new `CrypTensor` using `ptype=crypten.mpc.binary` to enforce that the encryption is done via binary secret sharing. We can print values of each share to confirm that values are being encrypted properly, as we did for arithmetic secret-shares.

(Note that the XOR of the two `_tensor` attributes below is equal to an unscaled version of the input.)

In [None]:
@mpc.run_multiprocess(world_size=2)
def examine_binary_shares():
    x_enc = crypten.cryptensor([2, 3], ptype=crypten.mpc.binary)
    
    rank = comm.get().get_rank()
    crypten.print(f"\nRank {rank}:\n {x_enc}\n", in_order=True)
        
x = examine_binary_shares()


Rank 0:
 MPCTensor(
	_tensor=tensor([ 3865989190094285849, -2456286262483940891])
	plain_text=HIDDEN
	ptype=ptype.binary
)


Rank 1:
 MPCTensor(
	_tensor=tensor([ 3865989190094285851, -2456286262483940890])
	plain_text=HIDDEN
	ptype=ptype.binary
)
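
As with the arithmetic shares, this can be verified directly. The sketch below copies the two `_tensor` values printed above and XORs them element-wise using plain Python integers; it only illustrates the recombination step and is not how CrypTen itself reveals a value.

```python
# Shares copied from the rank 0 and rank 1 output above.
rank0_shares = [3865989190094285849, -2456286262483940891]
rank1_shares = [3865989190094285851, -2456286262483940890]

# XOR the shares element-wise; no scale factor is involved.
recovered = [a ^ b for a, b in zip(rank0_shares, rank1_shares)]
print(recovered)  # [2, 3]
```

The result matches the input `[2, 3]` exactly, which is what "unscaled" means here.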



### Using Both Secret-sharing Protocols
Quite often a mathematical function may need to use both additive and XOR secret sharing for efficient evaluation.  Functions that require conversions between sharing types include comparators (`>, >=, <, <=, ==, !=`) as well as functions derived from them (`abs, sign, relu`, etc.). For a full list of supported functions, please see the CrypTen documentation.

CrypTen provides functionality for converting between <i>ptypes</i>. The conversion can be done using the `.to()` function with a `crypten.ptype` input, or by calling the `.arithmetic()` and `.binary()` conversion functions.

In [None]:
from crypten.mpc import MPCTensor

@mpc.run_multiprocess(world_size=2)
def examine_conversion():
    x = torch.tensor([1, 2, 3])
    rank = comm.get().get_rank()

    # create an MPCTensor with arithmetic secret sharing
    x_enc_arithmetic = MPCTensor(x, ptype=crypten.mpc.arithmetic)
    
    # To binary
    x_enc_binary = x_enc_arithmetic.to(crypten.mpc.binary)
    x_from_binary = x_enc_binary.get_plain_text()
    
    # print only once
    crypten.print("to(crypten.binary):")
    crypten.print(f"  ptype: {x_enc_binary.ptype}\n  plaintext: {x_from_binary}\n")

        
    # Back to arithmetic, starting from the binary-shared tensor
    x_enc_arithmetic = x_enc_binary.to(crypten.mpc.arithmetic)
    x_from_arithmetic = x_enc_arithmetic.get_plain_text()
    
    # print only once
    crypten.print("to(crypten.arithmetic):")
    crypten.print(f"  ptype: {x_enc_arithmetic.ptype}\n  plaintext: {x_from_arithmetic}\n")

        
z = examine_conversion()

to(crypten.binary):
  ptype: ptype.binary
  plaintext: tensor([1., 2., 3.])

to(crypten.arithmetic):
  ptype: ptype.arithmetic
  plaintext: tensor([1., 2., 3.])



## Data Sources
CrypTen follows the standard MPI programming model: it runs a separate process for each party, but each process runs an identical (complete) program. Each process has a `rank` variable to identify itself.

If the process with rank `i` is the source of data `x`, then `x` gets encrypted with `i` as its source value (denoted as `src`). However, MPI protocols require every process to provide a tensor of the same size as the source's input. CrypTen ignores all data provided by non-source processes when encrypting.

In the next example, we'll show how to use the `rank` and `src` values to encrypt tensors. Here, we will have each of 3 parties generate a value `x` which is equal to its own `rank` value. Within the loop, 3 encrypted tensors are created, each with a different source. When these tensors are decrypted, we can verify that the tensors are generated using the tensor provided by the source process.

(Note that `crypten.cryptensor` uses rank 0 as the default source if none is provided.)

In [None]:
@mpc.run_multiprocess(world_size=3)
def examine_sources():
    # Create a different tensor on each rank
    rank = comm.get().get_rank()
    x = torch.tensor(rank)
    crypten.print(f"Rank {rank}: {x}", in_order=True)
    
    # Encrypt once per source rank, then decrypt to see which input was used
    world_size = comm.get().get_world_size()
    for i in range(world_size):
        x_enc = crypten.cryptensor(x, src=i)
        z = x_enc.get_plain_text()
        
        # Only print from one process to avoid duplicates
        crypten.print(f"Source {i}: {z}")
        
x = examine_sources()

Rank 0: 0
Rank 1: 1
Rank 2: 2
Source 0: 0.0
Source 1: 1.0
Source 2: 2.0
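
The role of `src` can be mimicked in a single process. The following is a toy model only, assuming additive shares modulo `2 ** 64`; `share_from_source` and `reveal` are hypothetical helpers and do not correspond to CrypTen functions.

```python
import secrets

RING = 2 ** 64  # assumption: shares are integers modulo 2**64


def share_from_source(inputs_by_rank, src):
    """Secret-share only the source rank's input; other inputs are ignored."""
    world_size = len(inputs_by_rank)
    secret = inputs_by_rank[src]
    # Random shares for all parties but the last one.
    shares = [secrets.randbelow(RING) for _ in range(world_size - 1)]
    # The last share is chosen so that all shares sum to the secret.
    shares.append((secret - sum(shares)) % RING)
    return shares


def reveal(shares):
    """Recombine the shares by summing them modulo the ring size."""
    return sum(shares) % RING


inputs_by_rank = [0, 1, 2]  # each party's local value equals its rank
revealed = [reveal(share_from_source(inputs_by_rank, src)) for src in range(3)]
print(revealed)  # [0, 1, 2]
```

Every party supplies a value, but only the one held by the chosen source ends up in the shared secret, mirroring the `Source i` lines in the output above.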


# Tutorial 4: Classification with Encrypted Neural Networks
This tutorial is adapted from [Tutorial_4_Classification_with_Encrypted_Neural_Networks.ipynb](https://github.com/facebookresearch/CrypTen/blob/main/tutorials/Tutorial_4_Classification_with_Encrypted_Neural_Networks.ipynb).

In this tutorial, we'll look at how we can achieve the <i>Model Hiding</i> application we discussed in the Introduction. That is, suppose Alice has a trained model she wishes to keep private, and Bob has some data he wishes to classify while keeping it private. We will see how CrypTen allows Alice and Bob to coordinate and classify the data while meeting their privacy requirements.

To simulate this scenario, we will begin with Alice training a simple neural network on MNIST data. Then we'll see how Alice and Bob encrypt their network and data respectively, classify the encrypted data and finally decrypt the labels.

## Setup

We first import the `torch` and `crypten` libraries, and initialize `crypten`. We will use a helper script `mnist_utils.py` to split the public MNIST data into Alice's portion and Bob's portion. 

In [None]:
%cd CrypTen/tutorials/
%run ./mnist_utils.py --option train_v_test

/content/CrypTen/tutorials
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /tmp/MNIST/raw/train-images-idx3-ubyte.gz
Extracting /tmp/MNIST/raw/train-images-idx3-ubyte.gz to /tmp/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /tmp/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting /tmp/MNIST/raw/train-labels-idx1-ubyte.gz to /tmp/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /tmp/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting /tmp/MNIST/raw/t10k-images-idx3-ubyte.gz to /tmp/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /tmp/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting /tmp/MNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/MNIST/raw

Next, we will define the structure of Alice's network as a class. Even though Alice has a pre-trained model, CrypTen will require this structure as input.

In [None]:
# Define Alice's network
import torch.nn as nn
import torch.nn.functional as F

class AliceNet(nn.Module):
    def __init__(self):
        super(AliceNet, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)
 
    def forward(self, x):
        out = self.fc1(x)
        out = F.relu(out)
        out = self.fc2(out)
        out = F.relu(out)
        out = self.fc3(out)
        return out
    
crypten.common.serial.register_safe_class(AliceNet)

We will also define a helper routine `compute_accuracy` to make it easy to compute the accuracy of the output we get.

In [None]:
def compute_accuracy(output, labels):
    pred = output.argmax(1)
    correct = pred.eq(labels)
    correct_count = correct.sum(0, keepdim=True).float()
    accuracy = correct_count.mul_(100.0 / output.size(0))
    return accuracy
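
To make the arithmetic in `compute_accuracy` concrete, here is the same calculation on plain Python lists. This is only an illustrative sketch: `accuracy_percent` is a hypothetical helper, and the scores and labels are made-up toy values.

```python
def accuracy_percent(scores, labels):
    """Percentage of rows whose highest-scoring class matches the label."""
    predictions = [row.index(max(row)) for row in scores]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return 100.0 * correct / len(scores)


# Four samples, two classes; the third prediction (class 1) is wrong.
scores = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]]
labels = [1, 0, 0, 0]
result = accuracy_percent(scores, labels)
print(result)  # 75.0
```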

## Encrypting a Pre-trained Model

Assume that Alice has a pre-trained network ready to classify data. Let's see how we can use CrypTen to encrypt this network, so it can be used to classify data without revealing its parameters. We'll use the pre-trained model in `models/tutorial4_alice_model.pth` in this tutorial. As in Tutorial 3, we will assume Alice is using the rank 0 process, while Bob is using the rank 1 process. 

In [None]:
# Define source argument values for Alice and Bob
ALICE = 0
BOB = 1

In CrypTen, encrypting a PyTorch network is straightforward: we load a PyTorch model from file to the appropriate source, convert it to a CrypTen model and then encrypt it. Let us understand each of these steps.

As we did with CrypTensors in Tutorial 3, we will use CrypTen's load functionality to read a model from file to a particular source. In this notebook that call is `crypten.load_from_party`, used inside the multi-party functions further below; the single-process cell that follows uses plain `torch.load` instead. The source is indicated by the keyword argument `src`. As in Tutorial 3, this `src` argument tells us the rank of the party we want to load the model to (and later, encrypt the model from). CrypTen also needs to know the model's structure, which is why `AliceNet` was defined and registered above; the `dummy_model` instance created in the next cell is kept from the original tutorial and is not passed to any call here. Note that unlike loading a tensor, the result of loading a model is not encrypted. Instead, only the `src` party's model is populated from the file.

Once the model is loaded, we call the function `from_pytorch`: this function sets up a CrypTen network from the PyTorch network. It takes the plaintext network and a dummy input as arguments. The dummy input must be a `torch` tensor of the same shape as a potential input to the network; the values inside the tensor do not matter.

Finally, we call `encrypt` on the CrypTen network to encrypt its parameters. Once we call the `encrypt` function, the model's `encrypted` property will confirm that the model parameters have been encrypted. (Encrypted CrypTen networks can also be decrypted using the `decrypt` function.)

In [None]:
# Load pre-trained model to Alice
dummy_model = AliceNet()
plaintext_model = torch.load('models/tutorial4_alice_model.pth')

print(plaintext_model)

# Encrypt the model from Alice:    

# 1. Create a dummy input with the same shape as the model input
dummy_input = torch.empty((1, 784))

# 2. Construct a CrypTen network with the trained model and dummy_input
private_model = crypten.nn.from_pytorch(plaintext_model, dummy_input)

# 3. Encrypt the CrypTen network with src=ALICE
private_model.encrypt(src=ALICE)

#Check that model is encrypted:
print("Model successfully encrypted:", private_model.encrypted)

AliceNet(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=10, bias=True)
)
Model successfully encrypted: True




## Classifying Encrypted Data with Encrypted Model

We can now use Alice's encrypted network to classify Bob's data. For this, we need to encrypt Bob's data as well, as we did in Tutorial 3 (recall that Bob has the rank 1 process). Once Alice's network and Bob's data are both encrypted, CrypTen inference is performed with essentially identical steps as in PyTorch. 

In [None]:
import crypten.mpc as mpc
import crypten.communicator as comm

labels = torch.load('/tmp/bob_test_labels.pth').long()
count = 100 # For illustration purposes, we'll use only 100 samples for classification

@mpc.run_multiprocess(world_size=2)
def encrypt_model_and_data():
    # Load pre-trained model to Alice
    model = crypten.load_from_party('models/tutorial4_alice_model.pth', src=ALICE)
    
    # Encrypt model from Alice 
    dummy_input = torch.empty((1, 784))
    private_model = crypten.nn.from_pytorch(model, dummy_input)
    private_model.encrypt(src=ALICE)
    
    # Load data to Bob
    data_enc = crypten.load_from_party('/tmp/bob_test.pth', src=BOB)
    data_enc2 = data_enc[:count]
    data_flatten = data_enc2.flatten(start_dim=1)

    # Classify the encrypted data
    private_model.eval()
    output_enc = private_model(data_flatten)
    
    # Compute the accuracy
    output = output_enc.get_plain_text()
    accuracy = compute_accuracy(output, labels[:count])
    crypten.print("\tAccuracy: {0:.4f}".format(accuracy.item()))
    
encrypt_model_and_data()

## Validating Encrypted Classification

Finally, we will verify that CrypTen classification results in encrypted output, and that this output can be decrypted into meaningful labels. 

In this tutorial, we will just check whether the result is an encrypted tensor; in the next tutorial, we will look into the values of the tensor and confirm the encryption. We will also decrypt the result. As we discussed before, Alice and Bob both have access to the decrypted output of the model, and can both use this to obtain the labels.

In [None]:
@mpc.run_multiprocess(world_size=2)
def encrypt_model_and_data():
    # Load pre-trained model to Alice
    plaintext_model = crypten.load_from_party('models/tutorial4_alice_model.pth', src=ALICE)
    
    # Encrypt model from Alice 
    dummy_input = torch.empty((1, 784))
    private_model = crypten.nn.from_pytorch(plaintext_model, dummy_input)
    private_model.encrypt(src=ALICE)
    
    # Load data to Bob
    data_enc = crypten.load_from_party('/tmp/bob_test.pth', src=BOB)
    data_enc2 = data_enc[:count]
    data_flatten = data_enc2.flatten(start_dim=1)

    # Classify the encrypted data
    private_model.eval()
    output_enc = private_model(data_flatten)
    
    # Verify the results are encrypted: 
    crypten.print("Output tensor encrypted:", crypten.is_encrypted_tensor(output_enc)) 

    # Decrypting the result
    output = output_enc.get_plain_text()

    # Obtaining the labels
    pred = output.argmax(dim=1)
    crypten.print("Decrypted labels:\n", pred)
    
encrypt_model_and_data()

Process Process-16:
Traceback (most recent call last):
  File "/usr/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/usr/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/root/.local/lib/python3.8/site-packages/crypten-0.4.0-py3.8.egg/crypten/mpc/context.py", line 30, in _launch
    return_value = func(*func_args, **func_kwargs)
  File "<ipython-input-18-247e6e674f50>", line 8, in encrypt_model_and_data
    private_model = crypten.nn.from_pytorch(plaintext_model, dummy_input)
  File "/root/.local/lib/python3.8/site-packages/crypten-0.4.0-py3.8.egg/crypten/nn/onnx_converter.py", line 47, in from_pytorch
    crypten_model = from_onnx(f)
  File "/root/.local/lib/python3.8/site-packages/crypten-0.4.0-py3.8.egg/crypten/nn/onnx_converter.py", line 36, in from_onnx
    return _to_crypten(onnx_model)
  File "/root/.local/lib/python3.8/site-packages/crypten-0.4.0-py3.8.egg/crypten/nn/onnx_

# Tutorial 7: Training an Encrypted Neural Network
Source adapted from [Tutorial_7_Training_an_Encrypted_Neural_Network.ipynb](https://github.com/facebookresearch/CrypTen/blob/main/tutorials/Tutorial_7_Training_an_Encrypted_Neural_Network.ipynb)

In this tutorial, we will walk through an example of how we can train a neural network with CrypTen. This is particularly relevant for the <i>Feature Aggregation</i>, <i>Data Labeling</i> and <i>Data Augmentation</i> use cases. We will focus on the usual two-party setting and show how we can train an accurate neural network for digit classification on the MNIST data.

For concreteness, this tutorial will step through the <i>Feature Aggregation</i> use case: Alice and Bob each have part of the features of the data set, and wish to train a neural network on their combined data, while keeping their data private.

## Setup
As usual, we'll begin by importing and initializing the `crypten` and `torch` libraries.  

We will use the MNIST dataset to demonstrate how Alice and Bob can learn without revealing protected information. For reference, the feature size of each example in the MNIST data is `28 x 28`. Let's assume Alice has the first `28 x 20` features and Bob has the last `28 x 8` features. One way to think of this split is that Alice has the (roughly) top 2/3rds of each image, while Bob has the bottom 1/3rd of each image. We'll again use our helper script `mnist_utils.py`, which downloads the publicly available MNIST data and splits it as required.

For simplicity, we will restrict our problem to binary classification: we'll simply learn how to distinguish between 0 and non-zero digits. For speed of execution in the notebook, we will only create a dataset of 100 examples.

In [None]:
%run ./mnist_utils.py --option features --reduced 100 --binary

Next, we'll define the network architecture below, and then describe how to train it on encrypted data in the next section. 

In [None]:
import torch.nn as nn
import torch.nn.functional as F

#Define an example network
class ExampleNet(nn.Module):
    def __init__(self):
        super(ExampleNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=0)
        self.fc1 = nn.Linear(16 * 12 * 12, 100)
        self.fc2 = nn.Linear(100, 2) # For binary classification, final layer needs only 2 outputs
 
    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(out)
        out = F.max_pool2d(out, 2)
        out = out.view(-1, 16 * 12 * 12)
        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)
        return out
    
crypten.common.serial.register_safe_class(ExampleNet)
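
The `16 * 12 * 12` input size of `fc1` follows from the layer shapes. The short calculation below is a sketch using the standard output-size formulas for convolution and pooling with a `28 x 28` input; `conv2d_out` and `max_pool2d_out` are hypothetical helper names, not PyTorch functions.

```python
def conv2d_out(size, kernel_size, padding=0, stride=1):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel_size) // stride + 1


def max_pool2d_out(size, kernel_size):
    """Spatial output size of max pooling; stride defaults to the kernel size."""
    return (size - kernel_size) // kernel_size + 1


side = conv2d_out(28, kernel_size=5, padding=0)  # 28 -> 24
side = max_pool2d_out(side, kernel_size=2)       # 24 -> 12
flat_features = 16 * side * side                 # 16 output channels from conv1
print(side, flat_features)  # 12 2304
```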

## Encrypted Training

After all the material we've covered in earlier tutorials, we only need to know a few additional items for encrypted training. We'll first discuss how the training loop in CrypTen differs from PyTorch. Then, we'll go through a complete example to illustrate training on encrypted data from end-to-end.



### How does CrypTen training differ from PyTorch training?

There are two main ways in which implementing a CrypTen training loop differs from implementing a PyTorch training loop. We'll describe these items first, and then illustrate them with small examples below.

<i>(1) Use one-hot encoding</i>: CrypTen training requires all labels to use one-hot encoding. This means that when using standard datasets such as MNIST, we need to modify the labels to use one-hot encoding.

<i>(2) Directly update parameters</i>: CrypTen does not use the PyTorch optimizers. Instead, CrypTen implements encrypted SGD by implementing its own `backward` function, followed by directly updating the parameters. As we will see below, using SGD in CrypTen is very similar to using the PyTorch optimizers.

We now show some small examples to illustrate these differences. As before, we will assume Alice has the rank 0 process and Bob has the rank 1 process.

In [None]:
# Define source argument values for Alice and Bob
ALICE = 0
BOB = 1

# Load Alice's data 
data_alice_enc = crypten.load_from_party('/tmp/alice_train.pth', src=ALICE)

In [None]:
# We'll now set up the data for our small example below
# For illustration purposes, we will create toy data
# and encrypt all of it from source ALICE
x_small = torch.rand(100, 1, 28, 28)
y_small = torch.randint(2, (100,))  # upper bound is exclusive, so 2 yields labels in {0, 1}

# Transform labels into one-hot encoding
label_eye = torch.eye(2)
y_one_hot = label_eye[y_small]

# Transform all data to CrypTensors
x_train = crypten.cryptensor(x_small, src=ALICE)
y_train = crypten.cryptensor(y_one_hot)

# Instantiate and encrypt a CrypTen model
model_plaintext = ExampleNet()
dummy_input = torch.empty(1, 1, 28, 28)
model = crypten.nn.from_pytorch(model_plaintext, dummy_input)
model.encrypt()

Graph encrypted module
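
The one-hot step above works by indexing the rows of an identity matrix with the label values. A plain-Python sketch of the same idea follows; `identity` and `one_hot` are hypothetical helpers written for illustration, and the labels are toy values.

```python
def identity(n):
    """Return an n x n identity matrix as nested lists."""
    return [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]


def one_hot(labels, num_classes):
    """Pick row `label` of the identity matrix for every label."""
    eye = identity(num_classes)
    return [list(eye[label]) for label in labels]


encoded = one_hot([0, 1, 1, 0], num_classes=2)
print(encoded)  # [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]]
```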

In [None]:
# Example: Stochastic Gradient Descent in CrypTen

model.train() # Change to training mode
loss = crypten.nn.MSELoss() # Choose loss functions

# Set parameters: learning rate, num_epochs
learning_rate = 0.001
num_epochs = 2

# Train the model: SGD on encrypted data
for i in range(num_epochs):

    # forward pass
    output = model(x_train)
    loss_value = loss(output, y_train)
    
    # set gradients to zero
    model.zero_grad()

    # perform backward pass
    loss_value.backward()

    # update parameters
    model.update_parameters(learning_rate) 
    
    # examine the loss after each epoch
    print("Epoch: {0:d} Loss: {1:.4f}".format(i, loss_value.get_plain_text()))

Epoch: 0 Loss: 0.5165
Epoch: 1 Loss: 0.4839
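
For comparison, the same loop structure (forward pass, loss, gradient, direct parameter update) can be written on unencrypted numbers. This is a minimal plaintext sketch with a single weight and MSE loss on made-up data; it is not CrypTen code and says nothing about how `update_parameters` is implemented internally.

```python
# Toy data generated by y = 2 * x; the model is a single weight.
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]

weight = 0.0
learning_rate = 0.05
losses = []

for epoch in range(20):
    # Forward pass and mean squared error.
    predictions = [weight * x for x in xs]
    loss = sum((p - y) ** 2 for p, y in zip(predictions, ys)) / len(xs)
    losses.append(loss)

    # Backward pass: derivative of the MSE with respect to the weight.
    grad = sum(2 * (p - y) * x for p, y, x in zip(predictions, ys, xs)) / len(xs)

    # Direct parameter update, the plaintext analogue of update_parameters.
    weight -= learning_rate * grad

print(f"final weight: {weight:.4f}, first loss: {losses[0]:.4f}, last loss: {losses[-1]:.6f}")
```

The weight approaches 2 and the loss shrinks each epoch, which is the behaviour the encrypted loop above reproduces on secret-shared values.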
