Now it's time to implement our first encrypted machine learning evaluation using TenSEAL! We will use the Tox21 dataset (released by the NIH and EPA), in which each molecule is represented as a 1024-bit vector and labeled with a 12-element vector, one entry per toxic effect.

The goal is to first train a model on plain data using PyTorch, then run encrypted evaluation on the entire test set using TenSEAL. You will see later that evaluation on plain and on encrypted data gives approximately the same results.

In [None]:
import tenseal as ts
import torch as th
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

th.manual_seed(73)

The Tox21 dataset was downloaded using deepchem, and prepared as torch tensors for convenience.

In [None]:
train_X = th.load("data/train_X.pt")
train_y = th.load("data/train_y.pt")
test_X = th.load("data/test_X.pt")
test_y = th.load("data/test_y.pt")

Let's create data loaders for our dataset.

In [None]:
# Training dataset
train_dataset = TensorDataset(train_X, train_y)
train_loader = DataLoader(train_dataset, batch_size=64)
# Test dataset
test_dataset = TensorDataset(test_X, test_y)
test_loader = DataLoader(test_dataset, batch_size=1)

We now define a PyTorch model consisting of two linear layers with a square activation function between them. The choice of the square activation is deliberate. We could instead approximate a non-linear activation function such as sigmoid, but that may require a higher-degree polynomial (and thus larger homomorphic encryption parameters). For simplicity, we use only the square activation, which requires a single multiplication.
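
To make the trade-off concrete, the sketch below compares the square activation with a degree-3 polynomial approximation of sigmoid. The coefficients 0.5, 0.197 and -0.004 are an assumption taken from commonly used approximations that are reasonable only on roughly [-5, 5]; they are not part of this notebook. Evaluating the cubic term under encryption needs at least two successive multiplications, whereas the square needs one, which is why the polynomial variant calls for larger parameters.

```python
import math


def square_activation(x):
    # one multiplication: x * x
    return x * x


def poly_sigmoid(x):
    # degree-3 approximation of sigmoid (coefficients are an assumption,
    # reasonable only on roughly [-5, 5])
    return 0.5 + 0.197 * x - 0.004 * x ** 3


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


# largest gap between sigmoid and its approximation on a grid over [-5, 5]
grid = [i / 10 for i in range(-50, 51)]
max_error = max(abs(sigmoid(x) - poly_sigmoid(x)) for x in grid)
print("max |sigmoid - poly| on [-5, 5]: {:.3f}".format(max_error))
```

Outside the chosen interval the cubic diverges quickly, so inputs to such an approximation would need to stay bounded.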

In [None]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(1024, 128)
        self.fc2 = nn.Linear(128, 12)
        
    def forward(self, x):
        out = self.fc1(x)
        out = out * out
        out = self.fc2(out)
        return out

In [None]:
def train(model, device, train_loader, optimizer, criterion, epochs):
    # keep the model on the same device as the batches
    model.to(device)
    for epoch in range(1, epochs + 1):
        model.train()
        # reset per epoch so the printed value is this epoch's average loss
        losses = []
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

        print('Train Epoch: {:2d}   Avg Loss: {:.6f}'.format(epoch, th.mean(th.tensor(losses))))

    # return the model in eval mode on the CPU, where the later evaluations run
    model.eval()
    return model.to("cpu")

In [None]:
model = Model()
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
device = th.device("cuda" if th.cuda.is_available() else "cpu")

model = train(model, device, train_loader, optimizer, criterion, 30)

In [None]:
def compute_labels(out):
    out = th.sigmoid(out)
    return (out >= 0.5).int()


# compute accuracy using hamming loss
def accuracy(output, target):
    # convert to labels
    out = compute_labels(output)
    # flatten and compute hamming loss
    flat_out = out.flatten()
    flat_target = target.flatten()
    incorrect = th.logical_xor(flat_out, flat_target).sum().item()
    hamming_loss = incorrect / len(flat_out)
    return 1 - hamming_loss


print("Accuracy on test set: {:.2f}".format(accuracy(model(test_X), test_y)))
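
As a small worked example of the metric above, the sketch below computes the same "1 minus Hamming loss" accuracy on plain Python lists. The function name `hamming_accuracy` and the toy labels are hypothetical and only illustrate the arithmetic: every (sample, label) position counts equally, and the score is the fraction of positions where prediction and target agree.

```python
def hamming_accuracy(predicted, target):
    # flatten the (samples x labels) grids and count the positions that disagree
    flat_pred = [p for row in predicted for p in row]
    flat_true = [t for row in target for t in row]
    incorrect = sum(p != t for p, t in zip(flat_pred, flat_true))
    return 1 - incorrect / len(flat_pred)


# two samples with three labels each; exactly one of the six labels is wrong
predicted = [[1, 0, 0], [0, 1, 1]]
target = [[1, 0, 1], [0, 1, 1]]
score = hamming_accuracy(predicted, target)
print("{:.2f}".format(score))
```

Because the metric averages over all labels, a model that predicts only the majority value for every label can still score high on an imbalanced dataset.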

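Now we define a PyTorch-like model that uses TenSEAL operations instead. During initialization, we fetch the weights and biases from the PyTorch layers and store them as plain lists (the weights are transposed so that a vector can be multiplied by them directly). The forward method then uses the stored values to evaluate the linear layers on the encrypted input.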

In [None]:
class HEModel:
    def __init__(self, fc1, fc2):
        self.fc1_weight = fc1.weight.t().tolist()
        self.fc1_bias = fc1.bias.tolist()
        self.fc2_weight = fc2.weight.t().tolist()
        self.fc2_bias = fc2.bias.tolist()
        
    def forward(self, encrypted_vec):
        # first fc layer + square activation function
        # TODO: ~ 2 lines of code
        # second fc layer
        # TODO: ~ 1 line of code
        return encrypted_vec
    
    def __call__(self, x):
        return self.forward(x)
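
Before filling in the TODOs, it can help to see the computation that `forward` has to reproduce. The sketch below is a plain (unencrypted) reference on Python lists, with toy sizes; the helper names `vec_mat` and `plain_forward` are hypothetical. In TenSEAL, each step would map to an operation on the encrypted vector (for example a vector-matrix product such as `mm`, followed by adding the plain bias list); treat those names as assumptions to check against the TenSEAL documentation.

```python
def vec_mat(vec, matrix):
    # vector (length n) times matrix (n rows, m columns) -> vector of length m
    columns = range(len(matrix[0]))
    return [sum(v * row[j] for v, row in zip(vec, matrix)) for j in columns]


def plain_forward(vec, fc1_weight, fc1_bias, fc2_weight, fc2_bias):
    # first fc layer + square activation function
    hidden = [h + b for h, b in zip(vec_mat(vec, fc1_weight), fc1_bias)]
    hidden = [h * h for h in hidden]
    # second fc layer
    return [o + b for o, b in zip(vec_mat(hidden, fc2_weight), fc2_bias)]


# toy sizes (3 inputs, 2 hidden units, 1 output); the weights are already
# transposed to (in_features, out_features), like fc.weight.t().tolist()
fc1_weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
fc1_bias = [0.5, -0.5]
fc2_weight = [[2.0], [1.0]]
fc2_bias = [0.25]
result = plain_forward([1.0, 2.0, 3.0], fc1_weight, fc1_bias, fc2_weight, fc2_bias)
print(result)
```

The sequence is two vector-matrix products with one squaring in between, so the encrypted version has to support three multiplications in a row.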

We have previously discussed some intuitions on how to choose the encryption parameters, and now we have to put them into practice. You have to decide which encryption parameters will work for your evaluation. Don't worry if you don't get them right at first. It's an iterative process: you experiment with some parameters, then optimize.
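
One way to structure that iteration is to sanity-check a candidate before building a context. The sketch below is a hypothetical helper (`check_parameters` is not a TenSEAL function) that applies three rules of thumb: the total coefficient-modulus size must stay under the bound for the chosen polynomial modulus degree, there must be at least one middle prime per multiplication, and the middle primes should match the scale. The bounds in the table are the 128-bit-security limits used by Microsoft SEAL's defaults, quoted here as an assumption to verify for your version. The candidate values are illustrative only; whether they give good accuracy for this model still has to be confirmed by running the evaluation.

```python
# upper bounds on the total coefficient-modulus bit count for 128-bit security
# (assumed from Microsoft SEAL's default tables)
MAX_COEFF_MODULUS_BITS = {4096: 109, 8192: 218, 16384: 438, 32768: 881}


def check_parameters(poly_modulus_degree, coeff_mod_bit_sizes, bits_scale, depth):
    """Return a list of problems; an empty list means these checks pass."""
    problems = []
    limit = MAX_COEFF_MODULUS_BITS.get(poly_modulus_degree)
    if limit is None:
        problems.append("polynomial modulus degree not covered by this sketch")
    elif sum(coeff_mod_bit_sizes) > limit:
        problems.append("coefficient modulus too large for 128-bit security")
    # each multiplication followed by a rescale consumes one middle prime
    middle = coeff_mod_bit_sizes[1:-1]
    if len(middle) < depth:
        problems.append("not enough middle primes for the multiplicative depth")
    # rule of thumb: middle primes equal to the scale keep the scale stable
    if any(bits != bits_scale for bits in middle):
        problems.append("middle primes differ from the scale bit size")
    return problems


# hypothetical candidates; two matrix products and one square give depth 3
candidate = check_parameters(8192, [31, 26, 26, 26, 31], bits_scale=26, depth=3)
too_shallow = check_parameters(8192, [31, 26, 31], bits_scale=26, depth=3)
print(candidate, too_shallow)
```

Passing these checks is necessary rather than sufficient: precision also depends on how large the intermediate values grow relative to the scale.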

In [None]:
# TODO: Choose parameters
bits_scale = None
coeff_mod_bit_sizes = None
polynomial_modulus_degree = None

# Create context
context = ts.context(ts.SCHEME_TYPE.CKKS, polynomial_modulus_degree, coeff_mod_bit_sizes=coeff_mod_bit_sizes)
# Set global scale
context.global_scale = 2 ** bits_scale
# Generate galois keys required for matmul in ckks_vector
context.generate_galois_keys()

he_model = HEModel(model.fc1, model.fc2)

Encrypted evaluation can now be done on the entire test set, one sample at a time. We start by encrypting the vector, then run the encrypted evaluation, and finally decrypt the result. If this were implemented between two parties, the encrypted vector would be sent over for remote evaluation, and the encrypted result would be sent back for decryption.

In [None]:
# how many labels in the encrypted evaluation are the same as in the plain evaluation?
match = 0
he_outs = []
for data, _ in test_loader:
    # remove batch axis, we only need a flat vector
    vec = data.flatten()
    # encryption
    # TODO: ~ 1 line of code
    # encrypted evaluation
    encrypted_out = he_model(encrypted_vec)
    # decryption
    he_out = th.tensor(encrypted_out.decrypt())
    he_outs.append(he_out.tolist())
    out = model(data)
    # how many labels match
    he_labels = compute_labels(he_out)
    plain_labels = compute_labels(out)
    match += (he_labels == plain_labels).sum().item()

In [None]:
print("Accuracy on test set (encrypted evaluation): {:.2f}".format(accuracy(th.tensor(he_outs), test_y)))
print("Encrypted evaluation matched {:.1f}% of the labels from the plain evaluation".format(
    match / (12 * len(test_loader)) * 100)
)