# Training an Encrypted Neural Network

In this tutorial, we will walk through an example of how we can train a neural network with CrypTen. This is particularly relevant for the <i>Feature Aggregation</i>, <i>Data Labeling</i> and <i>Data Augmentation</i> use cases. We will focus on the usual two-party setting and show how we can train an accurate neural network for digit classification on the MNIST data.

For concreteness, this tutorial will step through the <i>Feature Aggregation</i> use cases: Alice and Bob each have part of the features of the data set, and wish to train a neural network on their combined data, while keeping their data private. 

## Setup
As usual, we'll begin by importing and initializing the `crypten` and `torch` libraries.  

We will use the MNIST dataset to demonstrate how Alice and Bob can learn without revealing protected information. For reference, the feature size of each example in the MNIST data is `28 x 28`. Let's assume Alice has the first `28 x 20` features and Bob has last `28 x 8` features. One way to think of this split is that Alice has the (roughly) top 2/3rds of each image, while Bob has the bottom 1/3rd of each image. We'll again use our helper script `mnist_utils.py` that downloads the publicly available MNIST data, and splits the data as required.

For simplicity, we will restrict our problem to binary classification: we'll simply learn how to distinguish between 0 and non-zero digits. For speed of execution in the notebook, we will only create a dataset of a 100 examples.

In [1]:
import crypten
import torch
import time
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor

crypten.init()
torch.set_num_threads(1)

In [2]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cpu device


In [3]:
%run ./mnist_utils.py --option features --reduced 200

tensor([5, 0, 4,  ..., 5, 6, 8])
features /tmp/alice_train.pth


Next, we'll define the network architecture below, and then describe how to train it on encrypted data in the next section. 

In [4]:
import torch.nn as nn
import torch.nn.functional as F

#Define an example network
class ExampleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

crypten.common.serial.register_safe_class(ExampleNet)

## Encrypted Training

After all the material we've covered in earlier tutorials, we only need to know a few additional items for encrypted training. We'll first discuss how the training loop in CrypTen differs from PyTorch. Then, we'll go through a complete example to illustrate training on encrypted data from end-to-end.

### How does CrypTen training differ from PyTorch training?

There are two main ways implementing a CrypTen training loop differs from a PyTorch training loop. We'll describe these items first, and then illustrate them with small examples below.

<i>(1) Use one-hot encoding</i>: CrypTen training requires all labels to use one-hot encoding. This means that when using standard datasets such as MNIST, we need to modify the labels to use one-hot encoding.

<i>(2) Directly update parameters</i>: CrypTen does not use the PyTorch optimizers. Instead, CrypTen implements encrypted SGD by implementing its own `backward` function, followed by directly updating the parameters. As we will see below, using SGD in CrypTen is very similar to using the PyTorch optimizers.

We now show some small examples to illustrate these differences. As before, we will assume Alice has the rank 0 process and Bob has the rank 1 process.

In [5]:
# Define source argument values for Alice and Bob
ALICE = 0
BOB = 1

In [6]:
# Load Alice's data 
data_alice_enc = crypten.load_from_party('/tmp/alice_train.pth', src=ALICE)

### A Complete Example

We now put these pieces together for a complete example of training a network in a multi-party setting. 

As in Tutorial 3, we'll assume Alice has the rank 0 process, and Bob has the rank 1 process; so we'll load and encrypt Alice's data with `src=0`, and load and encrypt Bob's data with `src=1`. We'll then initialize a plaintext model and convert it to an encrypted model, just as we did in Tutorial 4. We'll finally define our loss function, training parameters, and run SGD on the encrypted data. For the purposes of this tutorial we train on 100 samples; training should complete in ~3 minutes per epoch.

In [7]:
loaded_model = crypten.load("model.pth")
print(loaded_model)

Graph encrypted module


In [8]:
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    y_eye = torch.eye(10)
    
    batch_num = 0

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            X = crypten.cryptensor(X)
            pred = model(X)
            y_one_hot = crypten.cryptensor(y_eye[y])

            test_loss += loss_fn(pred, y_one_hot)
            print(test_loss.get_plain_text())
            
            plaintext_pred = pred.get_plain_text()
            correct += (plaintext_pred.argmax(1) == y).type(torch.float).sum().item()
            batch_num += 1
            crypten.print(f"batch number: {batch_num}/{num_batches}")
    test_loss = test_loss.get_plain_text() / num_batches
    correct /= size
    crypten.print(f"Test Error: \n Accuracy: {correct}, Avg loss: {test_loss} \n")

In [9]:
# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

In [10]:
batch_size = 64

# Create data loaders.
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64


In [11]:
import crypten.mpc as mpc
import crypten.communicator as comm

# Convert labels to one-hot encoding
# Since labels are public in this use case, we will simply use them from loaded torch tensors
labels = torch.load('/tmp/train_labels.pth')
labels = labels.long()

dummy_input = torch.empty(1, 1, 28, 28)

label_eye = torch.eye(10)
labels_one_hot = label_eye[labels]

@mpc.run_multiprocess(world_size=2)
def run_encrypted_training():
    start_time = time.perf_counter()
    
    # Load data:
    x_alice_enc = crypten.load_from_party('/tmp/alice_train.pth', src=ALICE)
    x_bob_enc = crypten.load_from_party('/tmp/bob_train.pth', src=BOB)
    
    crypten.print(x_alice_enc.size())
    crypten.print(x_bob_enc.size())
    
    # Combine the feature sets: identical to Tutorial 3
    x_combined_enc = crypten.cat([x_alice_enc, x_bob_enc], dim=2)
    
    # Reshape to match the network architecture
    x_combined_enc = x_combined_enc.unsqueeze(1)
    
    
    # Commenting out due to intermittent failure in PyTorch codebase
    
    # Initialize a plaintext model and convert to CrypTen model
    pytorch_model = ExampleNet()
    model = crypten.nn.from_pytorch(pytorch_model, dummy_input)
    crypten.print("model", type(model))
    model.encrypt()
    crypten.print("encrypted model", model)
    # Set train mode
    model.train()
    crypten.print("train model", model)
  
    # Define a loss function
    loss = crypten.nn.CrossEntropyLoss()

    # Define training parameters
    learning_rate = 0.001
    num_epochs = 3
    batch_size = 64
    total_data_size = x_combined_enc.size(0)
    num_batches = total_data_size // batch_size
    
    rank = comm.get().get_rank()
    for i in range(num_epochs): 
        crypten.print(f"Epoch {i} in progress:")       
        for batch in range(num_batches):
            t_prev = time.perf_counter()
            # define the start and end of the training mini-batch
            start, end = batch * batch_size, (batch + 1) * batch_size
                                    
            # construct CrypTensors out of training examples / labels
            x_train = x_combined_enc[start:end]
            y_batch = labels_one_hot[start:end]
            y_train = crypten.cryptensor(y_batch, requires_grad=True)
            
            # perform forward pass:
            output = model(x_train)
            loss_value = loss(output, y_train)
            
            # set gradients to "zero" 
            model.zero_grad()

            # perform backward pass: 
            loss_value.backward()

            # update parameters
            model.update_parameters(learning_rate)
            
            # Print progress every batch:
            batch_loss = loss_value.get_plain_text()
            crypten.print(f"\tBatch {(batch + 1)} of {num_batches} Loss {batch_loss.item():.4f} [{end}/{total_data_size}]")
            t_now = time.perf_counter()
            crypten.print(f"batch time taken: {t_now - t_prev}, total time so far: {t_now - start_time}")
        test(test_dataloader, model, crypten.nn.CrossEntropyLoss())
        model.train()

    end_time = time.perf_counter()
    
    print("total time:", end_time-start_time)

    # # Convert encrypted tensors to PyTorch tensors
    # torch_model = ExampleNet()  # Define equivalent PyTorch model architecture

    # # Copy the weights and biases from Crypten to PyTorch
    # with torch.no_grad():
    #     torch_model.weight.copy_(model.linear.weight.get_plain_text())
    #     torch_model.bias.copy_(model.linear.bias.get_plain_text())

    crypten.print("train model", model)
    
    crypten.save(model, "model.pth")

run_encrypted_training()

torch.Size([200, 28, 20])
torch.Size([200, 28, 8])


  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))


model <class 'crypten.nn.module.Graph'>
encrypted model Graph encrypted module
train model Graph encrypted module
Epoch 0 in progress:
	Batch 1 of 3 Loss 2.3146 [64/200]
batch time taken: 2.6067699720151722, total time so far: 2.8189547969959676
	Batch 2 of 3 Loss 2.3190 [128/200]
batch time taken: 2.659256358863786, total time so far: 5.48044631886296


KeyboardInterrupt: 

	Batch 3 of 3 Loss 2.3203 [192/200]
batch time taken: 2.5861893019173294, total time so far: 8.06855287286453
tensor(2.3141)tensor(2.3141)

batch number: 1/157
tensor(4.6277)tensor(4.6277)

batch number: 2/157
tensor(6.9527)tensor(6.9527)

batch number: 3/157
tensor(9.2738)tensor(9.2738)

batch number: 4/157
tensor(11.5864)tensor(11.5864)

batch number: 5/157
tensor(13.9028)tensor(13.9028)

batch number: 6/157
tensor(16.2282)tensor(16.2282)

batch number: 7/157
tensor(18.5490)tensor(18.5490)

batch number: 8/157
tensor(20.8605)tensor(20.8605)

batch number: 9/157
tensor(23.1850)tensor(23.1850)

batch number: 10/157
tensor(25.5125)tensor(25.5125)

batch number: 11/157
tensor(27.8234)tensor(27.8234)

batch number: 12/157
tensor(30.1405)tensor(30.1405)

batch number: 13/157
tensor(32.4464)tensor(32.4464)

batch number: 14/157
tensor(34.7699)tensor(34.7699)

batch number: 15/157
tensor(37.0873)tensor(37.0873)

batch number: 16/157
tensor(39.4036)tensor(39.4036)

batch number: 17/157
tensor

We see that the average batch loss decreases across the epochs, as we expect during training.

This completes our tutorial. Before exiting this tutorial, please clean up the files generated using the following code.

In [None]:
import os

filenames = ['/tmp/alice_train.pth', 
             '/tmp/bob_train.pth', 
             '/tmp/alice_test.pth',
             '/tmp/bob_test.pth', 
             '/tmp/train_labels.pth',
             '/tmp/test_labels.pth']

for fn in filenames:
    if os.path.exists(fn): os.remove(fn)