# Functional Encryption - Classification and information leakage

Our starting point is the work on encrypted classification using Functional Encryption from the paper [Reading in the Dark: Classifying Encrypted Digits with Functional Encryption](https://eprint.iacr.org/2018/206), and the associated [GitHub repository](https://github.com/edufoursans/reading-in-the-dark).

More specifically, the paper provides a new Functional Encryption scheme for quadratic multivariate polynomials, which can, under some hypotheses, be seen as a neural network with a single hidden layer and a quadratic activation.
In the paper, the output has one element per class and is revealed in the clear. We analyse how this output can disclose information about the initial input or about characteristics of this input.
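
The link between a quadratic polynomial and such a network can be checked on a toy example. The sketch below is only illustrative: the matrices `P` and `D`, the input `x` and the helper names are hypothetical, biases are left out, and it is not the paper's construction. It verifies that "project, square, project" gives the same value as the quadratic form `x^T Q_k x` with `Q_k = P^T diag(D[k]) P`.

```python
# Toy sizes; P, D and x are illustrative values, not taken from the paper.
P = [[1.0, -2.0, 0.5],
     [0.0, 3.0, 1.0]]   # projection: 2 hidden units, 3 inputs
D = [[2.0, -1.0],
     [0.5, 4.0]]        # second layer: 2 classes, 2 hidden units
x = [1.0, 2.0, -1.0]


def network(x):
    """Linear projection, elementwise square, linear layer (biases omitted)."""
    hidden = [sum(p * v for p, v in zip(row, x)) for row in P]
    squared = [h * h for h in hidden]
    return [sum(d * s for d, s in zip(row, squared)) for row in D]


def quadratic_form(x, k):
    """Evaluate x^T Q_k x with Q_k = P^T diag(D[k]) P."""
    n = len(x)
    Q = [[sum(D[k][j] * P[j][a] * P[j][b] for j in range(len(P)))
          for b in range(n)] for a in range(n)]
    return sum(x[a] * Q[a][b] * x[b] for a in range(n) for b in range(n))


outputs = network(x)
forms = [quadratic_form(x, k) for k in range(len(D))]
```

Both lists hold one score per class, which is the quantity the paper reveals in the clear.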

To this end, we have built a dataset which is very similar to MNIST (used in the original paper) but which is composed of the 26 letters of the alphabet rendered in 5 different fonts. Our goal is two-fold:
 - Evaluate how the output in the clear can be leveraged by a public NN to make better predictions than a simple `argmax` function in the character recognition task.
 - Analyse to what extent the output in the clear of the model trained for character recognition can reveal information about the font used, using an "adversarial" network.
 
**Purpose**

We have studied many aspects of the problem with Functional Encryption, and we are now interested in seeing how the problem changes when using Secure Multiparty Computation. As in Part 5, we will be using **fixed precision** tensors. Here we rely on the PySyft library, which provides a way to use PyTorch directly in a fixed precision + securely shared scheme!

The setting is completely different: the protocol is now interactive.

# 4 Quadratic model to Additive Sharing Tensor


Let's set the fractional precision to 3, as before.

In [1]:
PREC_FRAC = 3
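
To make the effect of this setting concrete, here is a minimal sketch of a fixed precision encoding with 3 fractional digits. It assumes a decimal base and rounding to the nearest integer; the `encode` and `decode` helpers are hypothetical and PySyft's own conversion may differ in details such as truncation.

```python
PREC_FRAC = 3
BASE = 10  # assumption: decimal base


def encode(value, prec_frac=PREC_FRAC):
    """Scale a float by BASE**prec_frac and round to the nearest integer."""
    return int(round(value * BASE ** prec_frac))


def decode(encoded, prec_frac=PREC_FRAC):
    """Map an encoded integer back to a float."""
    return encoded / BASE ** prec_frac


# Anything beyond the third decimal is lost by the encoding.
encoded = encode(0.0154)
recovered = decode(encoded)
```

With this encoding, small parameters keep only one or two significant digits, which is worth keeping in mind when comparing accuracies.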

Load the torch and syft packages.

In [2]:
# Allow loading packages from the parent directory
import sys, os
sys.path.insert(1, os.path.realpath(os.path.pardir))

In [3]:
import time
import torch
import syft as sy
sy.create_sandbox(globals(), verbose=False, download_data=False)

Setting up Sandbox...
Done!


In [4]:
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as utils
from learn import main, train, test, show_results, show_confusion_matrix

Let's load the quadratic model that we saved in Part 4! _Be sure that the path and file name match._

In [5]:
class QuadNet(nn.Module):
    def __init__(self, output_size):
        super(QuadNet, self).__init__()
        self.proj1 = nn.Linear(784, 50)
        self.diag1 = nn.Linear(50, output_size)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.proj1(x)
        x = x * x
        x = self.diag1(x)
        return F.log_softmax(x, dim=1)
    
    def transform(self, x):
        x = x.view(-1, 784)
        x = self.proj1(x)
        x = x * x
        x = self.diag1(x)
        return x

In [6]:
model = QuadNet(26)
path = '../data/models/quad_char.pt'
model.load_state_dict(torch.load(path))
model.eval()

QuadNet(
  (proj1): Linear(in_features=784, out_features=50, bias=True)
  (diag1): Linear(in_features=50, out_features=26, bias=True)
)

We will now convert the model to fixed precision and share it. As an example, look at `diag1.bias` before the conversion:

In [7]:
model.diag1.bias

Parameter containing:
tensor([-0.0080,  0.0055,  0.0057, -0.0033,  0.0100, -0.0097, -0.0070, -0.0024,
         0.0154, -0.0010,  0.0280,  0.0104, -0.0033, -0.0095, -0.0225,  0.0200,
         0.0206, -0.0316, -0.0121, -0.0407, -0.0133, -0.0193,  0.0031,  0.0142,
         0.0110,  0.0231], requires_grad=True)

In [8]:
model.fix_precision(precision_fractional=PREC_FRAC).share(alice, bob, crypto_provider=jason)

QuadNet(
  (proj1): Linear(in_features=784, out_features=50, bias=True)
  (diag1): Linear(in_features=50, out_features=26, bias=True)
)

In [9]:
field = model.diag1.bias.child.field
field

4611686018427387903
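
The field size printed above equals `2**62 - 1`. As a rough illustration of what additive sharing over such a field means, the sketch below splits an integer into two random shares that sum to it modulo the field. The `share` and `reconstruct` helpers are hypothetical, and mapping the upper half of the field to negative numbers is an assumption made for this example, not a statement about PySyft internals.

```python
import secrets

FIELD = 4611686018427387903  # value printed above, equal to 2**62 - 1


def share(secret, n_parties=2, field=FIELD):
    """Split an integer into n_parties random shares summing to it mod field."""
    shares = [secrets.randbelow(field) for _ in range(n_parties - 1)]
    shares.append((secret - sum(shares)) % field)
    return shares


def reconstruct(shares, field=FIELD):
    """Add the shares mod field; the upper half is read as negative (assumption)."""
    total = sum(shares) % field
    return total - field if total > field // 2 else total


# -8 is the fixed precision encoding of -0.008 with 3 fractional digits.
alice_share, bob_share = share(-8)
```

Each share alone is a uniformly random field element, so a single worker learns nothing about the value. Sums can be computed locally by adding shares, which is why linear layers are cheap in this setting.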

In [10]:
model.diag1.weight.child.child.child

[AdditiveSharingTensor]
	-> (Wrapper)>[PointerTensor | me:91293187908 -> alice:80299753931]
	-> (Wrapper)>[PointerTensor | me:66503119196 -> bob:71856611927]
	*crypto provider: jason*
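
The crypto provider matters for multiplications such as the `x * x` activation, which cannot be done on shares without interaction. A common technique in SPDZ-style protocols is the Beaver triple, sketched below under the assumption that a trusted provider hands out shares of random `a`, `b` and `c = a * b`. All helper names are hypothetical, and the sketch ignores the rescaling needed after a fixed precision product.

```python
import secrets

FIELD = 2**62 - 1


def share(secret, field=FIELD):
    """Two additive shares of an integer mod field."""
    r = secrets.randbelow(field)
    return [r, (secret - r) % field]


def reveal(shares, field=FIELD):
    """Open a shared value by adding its shares mod field."""
    return sum(shares) % field


def provider_triple(field=FIELD):
    """Crypto provider: shares of random a, b and of c = a * b."""
    a, b = secrets.randbelow(field), secrets.randbelow(field)
    return share(a), share(b), share(a * b % field)


def beaver_multiply(x_sh, y_sh, field=FIELD):
    """Multiply two shared values using one triple and one round of opening."""
    a_sh, b_sh, c_sh = provider_triple(field)
    # Interactive step: the parties open the masked values x - a and y - b.
    d = reveal([(x - a) % field for x, a in zip(x_sh, a_sh)])
    e = reveal([(y - b) % field for y, b in zip(y_sh, b_sh)])
    # Local step: z_i = c_i + d * b_i + e * a_i; one party also adds d * e.
    z_sh = [(c + d * b + e * a) % field for a, b, c in zip(a_sh, b_sh, c_sh)]
    z_sh[0] = (z_sh[0] + d * e) % field
    return z_sh


square_shares = beaver_multiply(share(12), share(12))
```

The opened values `d` and `e` are masked by the random `a` and `b`, so they reveal nothing about the inputs, but opening them requires a round of communication between the workers.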

We now define the components needed to run an evaluation.

In [11]:
import learn

In [12]:
class Parser:
    """Parameters for the testing"""
    def __init__(self):
        self.test_batch_size = 1000

And we load the data!

In [13]:
torch.manual_seed(1)
args = Parser()
args.test_batch_size = 10

data = learn.load_data()
train_data, train_target_char, train_target_family, test_data, test_target_char, test_target_family = data
test_target = test_target_char
test_dataset = learn.build_tensor_dataset(test_data, test_target)
test_loader = utils.DataLoader(
    test_dataset,
    batch_size=args.test_batch_size, shuffle=True
)

Training set 60000 items
Testing set  10000 items


Here comes the test phase, which is very close to `learn.test`. However, as you can see, we convert the data to fixed precision and share it, and instead of a full forward pass we omit the last `log_softmax` (by using `.transform()`), since it should not be applied in the encrypted part, i.e. on the integers. Because `log_softmax` only subtracts the same constant from every score of a sample, it does not change the `argmax`: the prediction is computed directly on the shared output, and only the predicted index is converted back.
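
Skipping the `log_softmax` does not affect the predicted class. The short sketch below checks this on hypothetical scores (the values in `raw_scores` are made up for illustration): the `argmax` is the same before and after applying a `log_softmax` in plain floats.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax of a list of floats."""
    m = max(scores)
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum for s in scores]


def argmax(values):
    """Index of the largest value."""
    return max(range(len(values)), key=values.__getitem__)


raw_scores = [1.2, -0.3, 4.5, 0.0]  # hypothetical decoded outputs
log_probs = log_softmax(raw_scores)
same_prediction = argmax(raw_scores) == argmax(log_probs)
```

If log-probabilities are needed, for instance to compute a loss, they can be obtained this way once the scores have been decoded back to floats.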

In [14]:
def test(model, test_loader, prec_frac):
    """Evaluate the shared model on the test set and return the accuracy in %."""
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            batch_size = target.size(0)  # the last batch may be smaller
            start_time = time.time()
            data.fix_precision_(precision_fractional=prec_frac).share_(alice, bob, crypto_provider=jason) # <-- This is new
            output = model.transform(data) # <-- Not calling forward to avoid the log_softmax
            forward_time = time.time()
            pred = output.argmax(dim=1)  # argmax computed privately, on the shares
            argmax_time = time.time()
            pred = pred.get().float_precision().long() # <-- This is new
            total_time = time.time()
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += batch_size
            print('acc (tmp)', round(100.*correct/total, 2), '%')
            total_duration = total_time-start_time
            print(
                round(total_duration/batch_size, 2), "s, ",
                round(100*(argmax_time-forward_time)/total_duration, 5), "% argmax"
            )

    # No loss is computed here since the log_softmax is skipped: only accuracy is reported
    acc = 100. * correct / len(test_loader.dataset)
    print('\nTest set: Accuracy: {}/{} ({:.2f}%)\n'.format(
        correct, len(test_loader.dataset), acc))

    return acc

test(model, test_loader, PREC_FRAC)


acc (tmp) 100.0 %
0.16 s,  97.80646 % argmax
acc (tmp) 85.0 %
0.16 s,  98.29862 % argmax
acc (tmp) 86.67 %
0.17 s,  98.14854 % argmax
acc (tmp) 87.5 %
0.18 s,  98.28187 % argmax
acc (tmp) 90.0 %
0.16 s,  97.76471 % argmax
acc (tmp) 91.67 %
0.18 s,  98.32203 % argmax
acc (tmp) 91.43 %
0.18 s,  98.36532 % argmax
acc (tmp) 91.25 %
0.18 s,  98.19131 % argmax
acc (tmp) 90.0 %
0.18 s,  98.04364 % argmax
acc (tmp) 89.0 %
0.18 s,  98.2559 % argmax
acc (tmp) 90.0 %
0.18 s,  98.18227 % argmax


KeyboardInterrupt: 

## Conclusion

As you can observe, the model behaves properly, as we achieve the same accuracy (94.6%). However, the argmax function computed privately is still very slow and makes a complete evaluation of the test set impractical. The time per sample is approximately 1.3 seconds, which is a bit less than with FE (the article [Reading in the dark](https://eprint.iacr.org/2018/206.pdf) reports a 4.8 s delay in comparison, but for a 10-class output). The security settings cannot be compared directly, as the MPC is provided in an honest-but-curious setting and relies on information-theoretic security. Moving the MPC implementation to a higher security standard would probably lead to a significant overhead (use of zero-knowledge proofs, etc.). In return, FE leaks critical information, as this whole project demonstrates.