# Encrypted Convolution on MNIST 

In this notebook we perform encrypted evaluation on the MNIST dataset. For this we use a single neural network composed of 1 convolution layer and 2 linear layers. For simplicity, we use the square function as the activation function.

## Model Description
The model is the sequence of the below layers:

- **Conv:** Convolution with 4 kernels. Shape of the kernel is 7x7. Strides are 3x3.
- **Activation:** Square activation function.
- **Linear Layer 1:** Input size: 256. Output size: 64.
- **Activation:** Square activation function.
- **Linear Layer 2:** Input size: 64. Output size: 10.
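The sizes in this list are consistent with each other. With a 28x28 MNIST image, no padding, a 7x7 kernel and stride 3, each kernel produces an 8x8 grid of windows (64 values), and the 4 kernels together give the 256 inputs of the first linear layer. The helper below (`conv_output_size` is a name introduced only for illustration) checks that arithmetic:

```python
def conv_output_size(in_size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Output length of a 2D convolution along one axis."""
    return (in_size + 2 * padding - kernel) // stride + 1

side = conv_output_size(28, kernel=7, stride=3)   # windows per row / per column
windows_per_kernel = side * side                  # one output value per window
n_kernels = 4
fc1_input = n_kernels * windows_per_kernel        # features fed to the first linear layer

print(side, windows_per_kernel, fc1_input)  # 8 64 256
```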

### Convolution 

For the convolution operation, we use the *im2col* algorithm, which translates the 2D convolution into a single matrix multiplication (Figure 1).

<div align="center">
<img src="assets/im2col_conv2d.png" width="50%"/>
<div><b>Figure1:</b> Image to column convolution</div>
</div>

**The figure is taken from the official TenSEAL Tutorials**
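To make the im2col idea concrete in plaintext, here is a minimal NumPy sketch. It is not TenSEAL's implementation, and `im2col` and `conv2d_direct` are illustrative helpers: the first lays out each 7x7 window as a column, and the check confirms that a single product with the flattened kernel reproduces a direct sliding-window convolution (cross-correlation, as computed by PyTorch's `Conv2d`).

```python
import numpy as np

def im2col(image: np.ndarray, k: int, stride: int) -> np.ndarray:
    """Return a (k*k, n_windows) matrix whose columns are the flattened windows."""
    h, w = image.shape
    out_h = (h - k) // stride + 1
    out_w = (w - k) // stride + 1
    cols = []
    for i in range(out_h):            # windows ordered row by row
        for j in range(out_w):
            window = image[i * stride:i * stride + k, j * stride:j * stride + k]
            cols.append(window.reshape(-1))
    return np.stack(cols, axis=1)

def conv2d_direct(image: np.ndarray, kernel: np.ndarray, stride: int) -> np.ndarray:
    """Reference: slide the kernel over a square image."""
    k = kernel.shape[0]
    out = (image.shape[0] - k) // stride + 1
    res = np.zeros((out, out))
    for i in range(out):
        for j in range(out):
            patch = image[i * stride:i * stride + k, j * stride:j * stride + k]
            res[i, j] = np.sum(patch * kernel)
    return res

rng = np.random.default_rng(0)
image = rng.random((28, 28))
kernel = rng.random((7, 7))

cols = im2col(image, k=7, stride=3)            # shape (49, 64)
conv_im2col = kernel.reshape(-1) @ cols        # one matrix product -> 64 outputs
conv_ref = conv2d_direct(image, kernel, 3).reshape(-1)

print(cols.shape, np.allclose(conv_im2col, conv_ref))  # (49, 64) True
```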

This operation requires rearranging the elements of the input matrix, which we cannot do on a ciphertext, so we do it as a pre-processing step before encryption. We first apply an *im2col* operation to the input matrix and encrypt the result into a single ciphertext (the matrix is flattened into a single vector using a vertical scan). The convolution then becomes a multiplication between this encrypted matrix and the flattened kernel, in which every kernel element is replicated **n** times, where **n** is the number of windows. This is a ciphertext-plaintext multiplication, followed by a sequence of rotate-and-sum operations that add up the elements of each window.


<div align="center">
<img src="assets/im2col_conv2d_ckks1.png" width="50%"/>
<div><b>Figure2:</b> Image to column convolution with CKKS - step 1</div>
</div>

<div align="center">
<img src="assets/im2col_conv2d_ckks2.png" width="50%"/>
<div><b>Figure3:</b> Image to column convolution with CKKS - step 2</div>
</div>
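The two steps in the figures can be mimicked slot by slot with NumPy, treating a plain array as if it were the CKKS slot vector. This is only a simulation under stated assumptions (4096 slots, i.e. half of `poly_mod_degree = 8192`; windows ordered row by row; a naive loop of rotations); it is not TenSEAL's code and no encryption happens here.

```python
import numpy as np

rng = np.random.default_rng(1)
image = rng.random((28, 28))
kernel = rng.random((7, 7))
k, stride, n_slots = 7, 3, 4096   # CKKS packs poly_mod_degree / 2 values per ciphertext

# im2col: one row per window (64 windows x 49 elements)
out = (28 - k) // stride + 1
windows = np.array([image[i * stride:i * stride + k, j * stride:j * stride + k].reshape(-1)
                    for i in range(out) for j in range(out)])
n = windows.shape[0]                                   # n = number of windows

# Vertical scan: read the matrix column by column, so slots [e*n, (e+1)*n)
# hold element e of every window. This is the vector that would be encrypted.
slots = np.zeros(n_slots)
slots[:windows.size] = windows.reshape(-1, order="F")

# Plaintext operand: every kernel element replicated n times
kernel_rep = np.zeros(n_slots)
kernel_rep[:windows.size] = np.repeat(kernel.reshape(-1), n)

prod = slots * kernel_rep                              # ciphertext-plaintext multiplication

# Rotate-and-sum: add the k*k blocks of n slots together (naive loop of left rotations)
acc = prod.copy()
for e in range(1, k * k):
    acc += np.roll(prod, -e * n)

conv_out = acc[:n]                                     # one value per window
expected = windows @ kernel.reshape(-1)
print(n, np.allclose(conv_out, expected))              # 64 True
```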

If we have multiple kernels, we need to repeat this operation once per kernel and combine the results into a single vector, which becomes the input of the linear layer.
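In plaintext, combining the kernels amounts to concatenating the per-kernel outputs. The NumPy sketch below is illustrative only (it stands in for the per-kernel convolution and the packing step rather than calling TenSEAL) and assumes the windows are ordered row by row. It shows that concatenating 4 channels of 64 values gives the same 256-value layout as flattening a (4, 8, 8) feature map channel by channel, which is what `x.view(-1, 256)` does on PyTorch's (N, C, H, W) convolution output.

```python
import numpy as np

rng = np.random.default_rng(2)
image = rng.random((28, 28))
kernels = rng.random((4, 7, 7))          # 4 kernels, as in the model
biases = rng.random(4)
k, stride = 7, 3
out = (28 - k) // stride + 1

# im2col matrix: one row per window, windows ordered row by row
windows = np.array([image[i * stride:i * stride + k, j * stride:j * stride + k].reshape(-1)
                    for i in range(out) for j in range(out)])

# One 64-value result per kernel (convolution + bias)
channels = [windows @ kern.reshape(-1) + b for kern, b in zip(kernels, biases)]

# Packing: concatenate the channel results into a single vector
packed = np.concatenate(channels)

# Same layout as flattening a (4, 8, 8) feature map channel by channel
feature_map = np.stack([c.reshape(out, out) for c in channels])
print(packed.shape, np.array_equal(packed, feature_map.reshape(-1)))  # (256,) True
```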


### Linear Layer 
For the linear layer, we multiply the encrypted vector by the plain weight matrix and add the plain bias. The multiplication follows the method illustrated in the figure below:
<div align="center">
<img src="assets/vec-matmul.png" width="65%"/>
<div><b>Figure4:</b> Vector-Matrix Multiplication</div>
</div>
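As a simpler illustration of why a vector-matrix product is possible using only CKKS-style slot operations, the NumPy sketch below computes each output as a slot-wise product with one column of the matrix, followed by a rotate-and-sum. This column-by-column scheme is our own simplification, not necessarily the exact method of Figure 4, and `rotate` and `vec_mat_mul` are hypothetical helpers. The sketch also reads each result directly from slot 0, which an encrypted implementation cannot do: gathering the outputs into one ciphertext would need extra masking and rotations.

```python
import numpy as np

def rotate(v: np.ndarray, steps: int) -> np.ndarray:
    """Cyclic left rotation, standing in for a CKKS slot rotation."""
    return np.roll(v, -steps)

def vec_mat_mul(x: np.ndarray, W: np.ndarray) -> np.ndarray:
    """Compute x @ W with slot-wise products, rotations and additions only.

    x has d slots (d is a power of two here) and W has shape (d, m).
    """
    d, m = W.shape
    y = np.zeros(m)
    for j in range(m):
        v = x * W[:, j]              # slot-wise product with column j (plaintext operand)
        step = 1
        while step < d:              # log2(d) rotate-and-add steps
            v = v + rotate(v, step)
            step *= 2
        y[j] = v[0]                  # every slot now holds the dot product x . W[:, j]
    return y

rng = np.random.default_rng(4)
x = rng.normal(size=256)             # e.g. the 256 packed convolution features
W = rng.normal(size=(256, 64))       # weight already transposed: (in_features, out_features)
b = rng.normal(size=64)

y = vec_mat_mul(x, W) + b
print(np.allclose(y, x @ W + b))     # True
```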

### Square function
The square function is very simple: we just multiply the vector by itself.

Putting these operations together, we need 6 multiplications: 2 for the convolution, 1 for the first square function, 1 for the first linear layer, 1 for the second square function and 1 for the second linear layer.
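This count can be traced as a budget of levels. The sketch below is hypothetical bookkeeping, not TenSEAL code: it takes the per-step counts from the text, uses the bit sizes chosen later in this notebook (31 and 26 bits), and assumes the Microsoft SEAL convention that each rescale after a multiplication drops one prime from the ciphertext's modulus chain, while the last prime of the coefficient modulus is a special prime that is not part of the ciphertext's chain.

```python
# Multiplications per step, as stated in the text above
steps = [
    ("convolution", 2),
    ("square 1", 1),
    ("linear layer 1", 1),
    ("square 2", 1),
    ("linear layer 2", 1),
]

# Ciphertext modulus chain: the first prime plus one middle prime per multiplication
chain = [31] + [26] * 6

for name, mults in steps:
    for _ in range(mults):
        chain.pop()              # rescale: one middle prime consumed
    print(f"after {name}: {len(chain)} prime(s) left")

print(chain)  # [31]: only the first prime remains when the result is decrypted
```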

## Training 

Now that we know how these operations work in theory, we will implement the HE model using the TenSEAL library. But first, we need to train a PyTorch model to classify the MNIST dataset.

In [1]:
import torch 
from torch.utils.data import DataLoader
from torchvision import datasets 
from torchvision.transforms import transforms
import numpy as np 

train_data = datasets.MNIST('data',train=True,download =True,transform = transforms.ToTensor())
test_data = datasets.MNIST('data',train=False,download=True,transform = transforms.ToTensor())



Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data\MNIST\raw\train-images-idx3-ubyte.gz
Extracting data\MNIST\raw\train-images-idx3-ubyte.gz to data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data\MNIST\raw\train-labels-idx1-ubyte.gz
Extracting data\MNIST\raw\train-labels-idx1-ubyte.gz to data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data\MNIST\raw\t10k-images-idx3-ubyte.gz
Extracting data\MNIST\raw\t10k-images-idx3-ubyte.gz to data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data\MNIST\raw\t10k-labels-idx1-ubyte.gz
Extracting data\MNIST\raw\t10k-labels-idx1-ubyte.gz to data\MNIST\raw


In [182]:
batch_size = 100
train_dl = DataLoader(train_data,batch_size = batch_size,shuffle = True)
test_dl = DataLoader(test_data,batch_size= 1,shuffle = True)

In [50]:
# the conv2d layer outputs 4 channels of 64 values each (8x8 = 64 windows, one value per window), flattened into 256 features
class ConvMnist(torch.nn.Module):
    def __init__(self, hidden=64, output=10):
        super(ConvMnist, self).__init__()        
        self.conv1 = torch.nn.Conv2d(1, 4, kernel_size=7, padding=0, stride=3)
        self.fc1 = torch.nn.Linear(256, hidden)
        self.fc2 = torch.nn.Linear(hidden, output)

    def forward(self, x):
        x = self.conv1(x)
        # the model uses the square activation function
        x = x * x
        # flattening while keeping the batch axis
        x = x.view(-1, 256)
        x = self.fc1(x)
        x = x * x
        x = self.fc2(x)
        return x


def train(model, train_loader, criterion, optimizer, n_epochs=10):
    # model in training mode
    model.train()
    for epoch in range(1, n_epochs+1):

        train_loss = 0.0
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # calculate average losses
        train_loss = train_loss / len(train_loader)

        print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch, train_loss))
    
    # model in evaluation mode
    model.eval()
    return model




model = ConvMnist()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
model = train(model, train_dl, criterion, optimizer, 10)

Epoch: 1 	Training Loss: 0.440560
Epoch: 2 	Training Loss: 0.154996
Epoch: 3 	Training Loss: 0.107849
Epoch: 4 	Training Loss: 0.083934
Epoch: 5 	Training Loss: 0.068531
Epoch: 6 	Training Loss: 0.058931
Epoch: 7 	Training Loss: 0.051740
Epoch: 8 	Training Loss: 0.046910
Epoch: 9 	Training Loss: 0.041639
Epoch: 10 	Training Loss: 0.037754


In [86]:
def test(model, test_dl, criterion):
    test_loss = 0.0
    class_correct = list(0. for i in range(10))  # correct predictions per class (0..9)
    class_total = list(0. for i in range(10))    # number of samples per class
    model.eval()
    with torch.no_grad():  # no gradients are needed for evaluation
        for data, target in test_dl:
            output = model(data)
            loss = criterion(output, target)
            test_loss += loss.item()

            # convert the output scores (logits) into predicted classes; torch.max(..., dim=1) returns:
            #   1. the max value of each row (the highest score of each sample)
            #   2. the index of that max value, i.e. the predicted class
            _, preds = torch.max(output, dim=1)
            # preds example: [3, 5, 0, 1, 4, 5, 6, ...]
            # compare the predictions to the true labels (no squeeze, so batch_size=1 also works)
            correct = preds.eq(target.view_as(preds))
            # count the correct predictions and the samples of every class
            for i in range(len(target)):
                label = target[i].item()
                class_correct[label] += correct[i].item()  # +1 if the prediction is correct, else +0
                class_total[label] += 1

    # average test loss over the batches
    test_loss /= len(test_dl)
    print(f"Test loss : {test_loss}")

    print(f"Class Correct : {class_correct}")
    print(f"Class total : {class_total}")
    for label in range(10):
        print(f'Test Accuracy of {label}: {int(100 * class_correct[label] / class_total[label])}% '
              f'({int(np.sum(class_correct[label]))}/{int(np.sum(class_total[label]))})')

    print(f'\nTest Accuracy (Overall): {int(100 * np.sum(class_correct) / np.sum(class_total))}% '
          f'({int(np.sum(class_correct))}/{int(np.sum(class_total))})')


test(model, test_dl, criterion)

Test loss : 0.08376144944224506
Class Correct : [966.0, 1124.0, 1010.0, 988.0, 963.0, 851.0, 951.0, 999.0, 957.0, 989.0]
Class total : [980.0, 1135.0, 1032.0, 1010.0, 982.0, 892.0, 958.0, 1028.0, 974.0, 1009.0]
Test Accuracy of 0: 98% (966/980)
Test Accuracy of 1: 99% (1124/1135)
Test Accuracy of 2: 97% (1010/1032)
Test Accuracy of 3: 97% (988/1010)
Test Accuracy of 4: 98% (963/982)
Test Accuracy of 5: 95% (851/892)
Test Accuracy of 6: 99% (951/958)
Test Accuracy of 7: 97% (999/1028)
Test Accuracy of 8: 98% (957/974)
Test Accuracy of 9: 98% (989/1009)

Test Accuracy (Overall): 97% (9798/10000)
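The percentages printed by `test` are truncated by `int()`, not rounded. Recomputing the overall figure from the per-class counts above (numbers copied from this output) shows that the 97% is really 97.98%:

```python
class_correct = [966, 1124, 1010, 988, 963, 851, 951, 999, 957, 989]
class_total = [980, 1135, 1032, 1010, 982, 892, 958, 1028, 974, 1009]

accuracy = 100 * sum(class_correct) / sum(class_total)
print(f"{sum(class_correct)}/{sum(class_total)}", int(accuracy), f"{accuracy:.2f}%")
# 9798/10000 97 97.98%
```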


In [146]:
model.fc1.weight.data[1]

tensor([ 0.0807, -0.0213,  0.0576,  0.1030,  0.1820,  0.0284,  0.0028,  0.0861,
         0.0578,  0.0746,  0.0697,  0.1778,  0.1499,  0.0184,  0.0172, -0.0491,
         0.0099, -0.0215,  0.1208,  0.2982,  0.1099, -0.1379, -0.0627, -0.0542,
        -0.0041,  0.1026,  0.1210, -0.0228, -0.1597, -0.0300,  0.0080, -0.0396,
        -0.2043, -0.0416, -0.1052, -0.0073,  0.0518,  0.0599, -0.0298,  0.0446,
         0.0085, -0.0113, -0.0093,  0.1461,  0.1221, -0.0651, -0.1133, -0.0393,
         0.0285, -0.0013, -0.0018,  0.0656,  0.0571,  0.0017, -0.0691,  0.0640,
         0.0309,  0.0613,  0.0098,  0.0322,  0.0356,  0.0670, -0.0452,  0.0909,
         0.1450,  0.1128,  0.1454,  0.0751, -0.0691,  0.0815,  0.1116, -0.0008,
         0.0635,  0.1084,  0.0704, -0.1360, -0.1425,  0.0592,  0.1031,  0.1103,
         0.0977, -0.0357, -0.0940, -0.1817, -0.1355,  0.0654,  0.1192,  0.0740,
         0.0528, -0.0117,  0.0386, -0.1803, -0.0601,  0.0439,  0.0805,  0.0361,
         0.0210, -0.0138, -0.1610, -0.15

## Encrypted eval

In [163]:
import tenseal as ts

In [227]:
# now we evaluate encrypted data with the HE version of the convolutional network
class HEConvMNIST:
    def __init__(self, model):
        # extract the weight and bias of the convolution layer
        # .view() reshapes the weight into out_channels (4) kernels of 7x7 without hardcoding the sizes,
        # and .tolist() turns them into plain Python lists for TenSEAL
        self.conv_weight = model.conv1.weight.data.view(
            model.conv1.out_channels, model.conv1.kernel_size[0], model.conv1.kernel_size[1]
        ).tolist()
        self.conv_bias = model.conv1.bias.data.tolist()
        # extract the first linear layer (weight transposed so that enc_x.mm() computes x @ W^T)
        self.lin1_weight = model.fc1.weight.T.data.tolist()
        self.lin1_bias = model.fc1.bias.data.tolist()

        # extract the second linear layer
        self.lin2_weight = model.fc2.weight.T.data.tolist()
        self.lin2_bias = model.fc2.bias.data.tolist()

    def forward(self, enc_x, windows_nb):
        # windows_nb is the number of convolution windows returned by ts.im2col_encoding();
        # conv2d_im2col() uses it to replicate each kernel element windows_nb times
        enc_channels = []  # the convolution result of each kernel (channel)
        for kernel, bias in zip(self.conv_weight, self.conv_bias):
            # conv2d_im2col() works on enc_x because the input was im2col-encoded before encryption
            y = enc_x.conv2d_im2col(kernel, windows_nb) + bias
            enc_channels.append(y)

        # pack all the channels into a single flattened vector
        enc_x = ts.CKKSVector.pack_vectors(enc_channels)

        # first square activation
        enc_x.square_()

        # first linear layer
        enc_x = enc_x.mm(self.lin1_weight) + self.lin1_bias

        # second square activation
        enc_x.square_()

        # second linear layer
        enc_x = enc_x.mm(self.lin2_weight) + self.lin2_bias

        return enc_x

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
    
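Before running the slow encrypted test, it can help to check the data layout in plaintext. The NumPy sketch below uses hypothetical helpers and no TenSEAL calls: `he_pipeline_plain` mirrors the steps of `HEConvMNIST.forward` (im2col convolution per kernel, packing, square, linear layer with the transposed weight, square, linear layer), and `torch_like_forward` follows `ConvMnist.forward` with a direct sliding-window convolution. Both are fed random weights with PyTorch's shapes.

```python
import numpy as np

def he_pipeline_plain(image, conv_w, conv_b, fc1_w, fc1_b, fc2_w, fc2_b, k=7, stride=3):
    """Plaintext mirror of the HE forward pass (no encryption)."""
    out = (image.shape[0] - k) // stride + 1
    windows = np.array([image[i * stride:i * stride + k, j * stride:j * stride + k].reshape(-1)
                        for i in range(out) for j in range(out)])
    channels = [windows @ kern.reshape(-1) + b for kern, b in zip(conv_w, conv_b)]
    x = np.concatenate(channels)          # packing of the channels
    x = x * x                             # square activation
    x = x @ fc1_w.T + fc1_b               # product with the transposed weight, plus bias
    x = x * x
    return x @ fc2_w.T + fc2_b

def torch_like_forward(image, conv_w, conv_b, fc1_w, fc1_b, fc2_w, fc2_b, k=7, stride=3):
    """Direct sliding-window version of the plaintext model, for comparison."""
    out = (image.shape[0] - k) // stride + 1
    fmap = np.zeros((conv_w.shape[0], out, out))
    for c in range(conv_w.shape[0]):
        for i in range(out):
            for j in range(out):
                patch = image[i * stride:i * stride + k, j * stride:j * stride + k]
                fmap[c, i, j] = np.sum(patch * conv_w[c, 0]) + conv_b[c]
    x = (fmap * fmap).reshape(-1)         # square, then flatten channel by channel
    x = fc1_w @ x + fc1_b
    x = x * x
    return fc2_w @ x + fc2_b

rng = np.random.default_rng(3)
params = (rng.normal(size=(4, 1, 7, 7)), rng.normal(size=4),       # conv1
          0.1 * rng.normal(size=(64, 256)), rng.normal(size=64),   # fc1
          0.1 * rng.normal(size=(10, 64)), rng.normal(size=10))    # fc2
image = rng.random((28, 28))

print(np.allclose(he_pipeline_plain(image, *params), torch_like_forward(image, *params)))  # True
```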

In [235]:
from time import time

In [236]:
def enc_test(context, he_model, test_dl, criterion, kernel_shape, stride):
    test_loss = 0.
    # class_correct counts the correct predictions for each class (0, 1, 2, ..., 9)
    # class_total counts the number of samples of each class
    class_correct = list(0. for i in range(10))
    class_total = list(0. for i in range(10))
    start_time = time()

    for data, target in test_dl:
        # pre-process the data before giving it to he_model: ts.im2col_encoding() applies im2col,
        # encrypts the result and returns the number of windows. It takes 5 arguments:
        #   - the context used to encrypt the data
        #   - the matrix to encrypt (the 28x28 image, hence .view(28, 28))
        #   - the number of kernel rows
        #   - the number of kernel columns
        #   - the stride
        # the last 3 arguments must match the convolution performed in the HE model
        enc_x, windows_nb = ts.im2col_encoding(
            context, data.view(28, 28).tolist(), kernel_shape[0], kernel_shape[1], stride
        )
        enc_output = he_model(enc_x, windows_nb)

        # decrypt and reshape into a (1, 10) batch of scores
        output = enc_output.decrypt()
        output = torch.tensor(output).view(1, -1)

        # compute the loss
        loss = criterion(output, target)
        test_loss += loss.item()

        # convert the output into a prediction: the index of the highest score
        _, preds = torch.max(output, dim=1)
        correct = preds.eq(target.view_as(preds))

        # update the per-class counters (test_dl uses batch_size=1)
        label = target[0].item()
        class_correct[label] += correct.item()
        class_total[label] += 1

    end_time = time()
    # calculate and print the average test loss
    test_loss = test_loss / sum(class_total)
    print(f'Test Loss: {test_loss:.6f}\n')
    for label in range(10):
        print(
            f'Test Accuracy of {label}: {int(100 * class_correct[label] / class_total[label])}% '
            f'({int(np.sum(class_correct[label]))}/{int(np.sum(class_total[label]))})'
        )

    print(
        f'\nTest Accuracy (Overall): {int(100 * np.sum(class_correct) / np.sum(class_total))}% '
        f'({int(np.sum(class_correct))}/{int(np.sum(class_total))})'
    )
    print(f"The test operation takes: {round(end_time - start_time, 2)} seconds")

    
            
            

In [190]:
kernel_shape = model.conv1.kernel_size
stride = model.conv1.stride[0]

### Choosing the parameters

Starting with a scale of more than 20 bits, all the middle primes must have that same bit size, and since we need 6 of them (one per multiplication), we are already above 120 bits. With this lower bound on the coefficient modulus and a security level of 128 bits, we need a polynomial modulus degree of at least 8192, for which the upper bound on the total coefficient modulus is 218 bits. After trying different precisions and adjusting the coefficient modulus while monitoring the loss and accuracy, we end up with 26 bits for the scale and the middle primes. This leaves 5 bits (31 - 26) for the integer part in the last coefficient modulus, which should be enough for our use case since the output values aren't that big.
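The parameter arithmetic above can be reproduced in a few lines. The dictionary below lists the maximum total coefficient-modulus size for 128-bit security per polynomial modulus degree, as tabulated by the HomomorphicEncryption.org security standard and used by Microsoft SEAL, the library TenSEAL builds on; it is worth double-checking against the SEAL version you use.

```python
# Maximum total coefficient-modulus bits for 128-bit security, per polynomial modulus degree
MAX_COEFF_BITS_128 = {1024: 27, 2048: 54, 4096: 109, 8192: 218, 16384: 438, 32768: 881}

depth = 6                      # multiplications in the encrypted forward pass
scale_bits = 26                # bit size of the scale and of every middle prime
outer_bits = 31                # first and last primes

# Lower bound from the text: 6 middle primes of more than 20 bits exceed 120 bits,
# which is already too much for degree 4096
lower_bound = depth * 21
assert lower_bound > MAX_COEFF_BITS_128[4096]

coeff_mod_bit_sizes = [outer_bits] + [scale_bits] * depth + [outer_bits]
total_bits = sum(coeff_mod_bit_sizes)          # 31 + 6 * 26 + 31

# Smallest degree whose security budget fits the whole modulus chain
poly_mod_degree = min(n for n, max_bits in MAX_COEFF_BITS_128.items() if total_bits <= max_bits)
integer_bits = outer_bits - scale_bits

print(total_bits, poly_mod_degree, integer_bits)  # 218 8192 5
```

With 26-bit primes the chain uses exactly the 218-bit budget of degree 8192, so a larger scale would require moving to degree 16384.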

In [187]:
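# Encryption parameters

# controls the precision of the fractional part
bit_scale = 26

# a 31-bit prime at each end and 6 middle primes of bit_scale bits, one per multiplication
coeff_mod_bit_sizes = [31, bit_scale, bit_scale, bit_scale, bit_scale, bit_scale, bit_scale, 31]
poly_mod_degree = 8192

context = ts.context(ts.SCHEME_TYPE.CKKS, poly_mod_degree, coeff_mod_bit_sizes=coeff_mod_bit_sizes)

# scale used to encode the fractional part
context.global_scale = 2 ** bit_scale

# Galois keys are required for rotations (e.g. the rotate-and-sum steps of the convolution)
context.generate_galois_keys()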


Executing the encrypted test over the MNIST test samples takes time, but it is worth it: we complete an end-to-end encrypted inference.


In [None]:
he_model = HEConvMNIST(model)
enc_test(context,he_model,test_dl,criterion,kernel_shape,stride)

In [198]:
for x,y in test_dl : 
    print(x.shape)
    break

torch.Size([1, 1, 28, 28])


In a real-world use case, this would also require sending the encrypted input from the client to the server, and the encrypted result from the server to the client, so the size of these objects really matters. The encrypted input takes about 476KB, while the encrypted result is only about 70KB.