# Homomorphic Encrypted LeNet-1
This notebook walks through a practical example of building an HE-compliant model, which is then used in the notebooks `HE-ML` and `HE-ML_CKKS`.

## LeNet-1
The LeNet-1 is a small CNN developed by LeCun et al. It is composed of 5 layers: a convolutional layer with 4 kernels of size 5x5 and tanh activation, an average pooling layer with kernel size 2, a second convolutional layer with 12 kernels of size 5x5 and tanh activation, a second average pooling layer with kernel size 2, and a fully connected layer of size 192x10.

The index of the highest value in the output tensor is the label that LeNet-1 assigns to the input image.
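The 192 inputs of the fully connected layer follow from the layer sizes. The sketch below traces the spatial size of a 28x28 MNIST image through the network in plain Python, assuming PyTorch's defaults (stride 1 and no padding for the convolutions, pooling stride equal to the kernel size) and the 12 output channels used by the second convolution in the code of this notebook. The helper names `conv_out` and `pool_out` are illustrative only.

```python
def conv_out(size, kernel):
    """Output side length of a stride-1, unpadded convolution."""
    return size - kernel + 1

def pool_out(size, kernel):
    """Output side length of a pooling layer whose stride equals its kernel."""
    return size // kernel

side = 28                  # MNIST images are 28x28
side = conv_out(side, 5)   # first convolution:  28 -> 24
side = pool_out(side, 2)   # first pooling:      24 -> 12
side = conv_out(side, 5)   # second convolution: 12 -> 8
side = pool_out(side, 2)   # second pooling:      8 -> 4

channels = 12              # output channels of the second convolution
flattened = channels * side * side
print(flattened)           # 192
```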

For this tutorial we will use the MNIST dataset.

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
transform = transforms.ToTensor()

train_set = torchvision.datasets.MNIST(
    root = './data',
    train=True,
    download=True,
    transform=transform
)

test_set = torchvision.datasets.MNIST(
    root = './data',
    train=False,
    download=True,
    transform=transform
)

train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=50,
    shuffle=True
)

test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=50,
    shuffle=True
)

In [4]:
def get_num_correct(preds, labels):
    return preds.argmax(dim=1).eq(labels).sum().item()

def train_net(network, epochs, device):
    network.train()  # Make sure the network is in training mode
    optimizer = optim.Adam(network.parameters(), lr=0.001)
    for epoch in range(epochs):

        total_loss = 0
        total_correct = 0

        for batch in train_loader: # Get Batch
            images, labels = batch 
            images, labels = images.to(device), labels.to(device)

            preds = network(images) # Pass Batch
            loss = F.cross_entropy(preds, labels) # Calculate Loss

            optimizer.zero_grad()
            loss.backward() # Calculate Gradients
            optimizer.step() # Update Weights

            total_loss += loss.item()
            total_correct += get_num_correct(preds, labels)

        
def test_net(network, device):
    network.eval()
    total_loss = 0
    total_correct = 0
    
    with torch.no_grad():
        for batch in test_loader: # Get Batch
            images, labels = batch 
            images, labels = images.to(device), labels.to(device)

            preds = network(images) # Pass Batch
            loss = F.cross_entropy(preds, labels) # Calculate Loss

            total_loss += loss.item()
            total_correct += get_num_correct(preds, labels)

    # Fraction of correctly classified test images, in [0, 1]
    return total_correct / len(test_loader.dataset)

In [5]:
train = True # If set to false, it will load models previously trained and saved.

In [6]:
experiments = 1

In [7]:
if train:
    accuracies = []
    for i in range(0, experiments):
        LeNet1 = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=5),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),

            nn.Conv2d(4, 12, kernel_size=5),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),

            nn.Flatten(),

            nn.Linear(192, 10),
        )
        
        LeNet1.to(device)
        train_net(LeNet1, 15, device)
        acc = test_net(LeNet1, device)
        accuracies.append(acc)
        
    torch.save(LeNet1, "LeNet1.pt")
else:
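    # The whole model was pickled, so recent PyTorch versions (2.6+), where
    # weights_only defaults to True, need weights_only=False to load it.
    LeNet1 = torch.load("LeNet1.pt", weights_only=False)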
    LeNet1.eval()
    LeNet1.to(device)

In [8]:
m = np.array(accuracies)
print(f"Mean accuracy on test set: {np.mean(m)}")
print(f"Var: {np.var(m)}")

Mean accuracy on test set: 0.9876
Var: 0.0


## Approximating
As we know, some operations cannot be performed homomorphically on encrypted values, most notably division and comparison. Only additions and multiplications are available, so only polynomial functions can be evaluated.
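This restriction is one reason the network above uses average pooling rather than max pooling. The plain-Python sketch below (no encryption involved, function names are illustrative) contrasts the two on a single 2x2 window: the average needs only additions and one multiplication by a constant known in advance, while the maximum needs a comparison at every step. It assumes the scheme can multiply by a plaintext constant such as 0.25, as approximate-arithmetic schemes like CKKS do.

```python
def avg_pool_window(window):
    """Average of a 2x2 window using only additions and one multiplication
    by a constant known in advance (1/4), so no runtime division is needed."""
    total = 0.0
    for value in window:
        total = total + value
    return total * 0.25

def max_pool_window(window):
    """Maximum of a window: every step compares two values, which is the
    kind of operation that is not available on encrypted data."""
    best = window[0]
    for value in window[1:]:
        if value > best:   # comparison: not HE-friendly
            best = value
    return best

window = [0.1, 0.4, 0.3, 0.2]
avg = avg_pool_window(window)
mx = max_pool_window(window)
print(avg, mx)
```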

Consequently, we cannot use `tanh()` in the LeNet-1 architecture defined above: it is not a polynomial, so it cannot be evaluated with additions and multiplications alone.


One of the most common approaches is to replace it with a simple polynomial function, for example a square layer (which simply computes $x \rightarrow x^2$).
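To make the idea concrete, the sketch below evaluates a square activation and, purely as a contrasting illustration, the degree-3 Taylor polynomial of tanh around 0, both written with additions and multiplications only. The Taylor polynomial is not used anywhere in this notebook; it is included to show that a low-degree polynomial follows tanh closely only near 0, which is one reason a simple square layer followed by re-training is a common choice.

```python
import math

def square(x):
    # One multiplication: the activation used in this notebook.
    return x * x

def tanh_taylor3(x):
    # Degree-3 Taylor expansion of tanh around 0: x - x^3 / 3,
    # written as a multiplication by the constant 1/3 instead of a division.
    return x - (x * x * x) * (1.0 / 3.0)

for x in (0.1, 0.5, 1.0, 2.0):
    exact = math.tanh(x)
    approx = tanh_taylor3(x)
    print(f"x={x}: tanh={exact:.4f}, taylor3={approx:.4f}, "
          f"error={abs(exact - approx):.4f}, square={square(x):.4f}")
```

The error is negligible at `x=0.1` but exceeds 1 at `x=2.0`, so such an approximation is only usable when the inputs of the activation stay in a small range.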

We call the model in which every non-polynomial activation has been replaced the **approximated** model. This model has to be re-trained, after which it is ready to be used on encrypted values.

In [9]:
class Square(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, t):
        return torch.pow(t, 2)

LeNet1_Approx = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=5),
    Square(),
    nn.AvgPool2d(kernel_size=2),
            
    nn.Conv2d(4, 12, kernel_size=5),
    Square(),
    nn.AvgPool2d(kernel_size=2),
    
    nn.Flatten(),
    
    nn.Linear(192, 10),
)

In [10]:
if train:
    approx_accuracies = []
    for i in range(0, experiments):
        LeNet1_Approx = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=5),
            Square(),
            nn.AvgPool2d(kernel_size=2),

            nn.Conv2d(4, 12, kernel_size=5),
            Square(),
            nn.AvgPool2d(kernel_size=2),

            nn.Flatten(),

            nn.Linear(192, 10),
        )
        
        LeNet1_Approx.to(device)
        train_net(LeNet1_Approx, 15, device)
        acc = test_net(LeNet1_Approx, device)
        approx_accuracies.append(acc)
        
        torch.save(LeNet1_Approx, "LeNet1_Approx.pt")

else:
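    # The whole model was pickled, so recent PyTorch versions (2.6+), where
    # weights_only defaults to True, need weights_only=False to load it.
    LeNet1_Approx = torch.load("LeNet1_Approx.pt", weights_only=False)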
    LeNet1_Approx.eval()
    LeNet1_Approx.to(device)

In [11]:
m = np.array(approx_accuracies)
print(f"Mean: {np.mean(m)}")
print(f"Var: {np.var(m)}")

Mean: 0.9902
Var: 0.0


We can see that replacing `tanh()` with `square()` did not dramatically impact the accuracy of the model. This is not usually the case: approximating DL models can degrade their performance badly. This is one of the challenges that HE-ML has to address: designing DL models with the HE constraints in mind from the beginning.

In any case, now the network is HE-compatible.
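A quick way to double-check this is to list the layer types of a model and flag any that fall outside a set considered HE-friendly. The sketch below is a hypothetical helper, not part of any library: the set `HE_FRIENDLY` is an assumption based on the layers used in this notebook, and the check relies only on class names. Stand-in objects are used so that it runs without PyTorch; an `nn.Sequential` could be passed instead, since it is iterable over its layers.

```python
# Layer types treated as HE-friendly in this notebook (an assumption of this
# sketch, not an exhaustive or authoritative list).
HE_FRIENDLY = {"Conv2d", "AvgPool2d", "Flatten", "Linear", "Square"}

def unsupported_layers(layers):
    """Return the class names of the layers that are not in HE_FRIENDLY."""
    return [type(layer).__name__ for layer in layers
            if type(layer).__name__ not in HE_FRIENDLY]

# Stand-in objects whose classes carry the same names as the real layers.
fake_conv = type("Conv2d", (), {})()
fake_tanh = type("Tanh", (), {})()
fake_square = type("Square", (), {})()

print(unsupported_layers([fake_conv, fake_tanh]))    # ['Tanh']
print(unsupported_layers([fake_conv, fake_square]))  # []
```

With the models defined in this notebook, `unsupported_layers(LeNet1)` would be expected to report the two `Tanh` layers, while `unsupported_layers(LeNet1_Approx)` would be expected to return an empty list.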

Nonetheless, evaluating two `square` activations can be quite heavy on your machine. We can also design and save a CNN with only one `square` activation function.

In [13]:
LeNet1_Approx_singlesquare = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=5),
    Square(),
    nn.AvgPool2d(kernel_size=2),

    nn.Conv2d(4, 12, kernel_size=5),
#     Square(),
    nn.AvgPool2d(kernel_size=2),

    nn.Flatten(),

    nn.Linear(192, 10),
)

LeNet1_Approx_singlesquare.to(device)
train_net(LeNet1_Approx_singlesquare, 15, device)
acc = test_net(LeNet1_Approx_singlesquare, device)
print(f"Accuracy on test set (single square layer): {acc}")
torch.save(LeNet1_Approx_singlesquare, "LeNet1_Approx_single_square.pt")

Accuracy on test set (single square layer): 0.9824
