# Cinnamon: A Framework For Scale Out Encrypted AI
## Notebook 2: Encrypted AI

In this tutorial, we will run an encrypted logistic regression inference.

Logistic Regression Model Credits: [TenSEAL](https://github.com/OpenMined/TenSEAL/blob/main/tutorials/Tutorial%201%20-%20Training%20and%20Evaluation%20of%20Logistic%20Regression%20on%20Encrypted%20Data.ipynb)

Author:
- Siddharth Jayashankar (sidjay@cmu.edu)

### Exercise 3 
In this exercise, we will write an encrypted logistic regression model in the Cinnamon DSL. The logistic regression model predicts the 10-year risk of a patient developing Coronary Heart Disease (CHD). The model we will use has been pre-trained on the [Framingham](https://www.kaggle.com/code/helddata/logistic-regression-by-framingham-heart-study?scriptVersionId=86142061) dataset.

#### The Logistic Regression Model
The logistic regression model consists of a linear layer followed by a sigmoid function.

The sigmoid function is not expressible as a composition of multiplications, rotations and additions. Thus, we need to find a suitable polynomial approximation for it.

We use the approximation `sigmoid_poly(x) = 0.5 + 0.197 * x - 0.004 * x**3` from https://eprint.iacr.org/2018/462.pdf to approximate the sigmoid function. This approximation is good in the range [-5,5], and most of the values in the logistic regression inference lie in this range.
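To get a feel for how close the approximation is, the sketch below compares `sigmoid_poly` with the exact sigmoid on a grid over [-5, 5], using only the standard library. The grid spacing, the probe point `x = 8` and the helper names are illustrative choices, not part of the notebook.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_poly(x):
    # Degree-3 approximation quoted above
    return 0.5 + 0.197 * x - 0.004 * x**3

# Evaluate both functions on a grid over [-5, 5] (step 0.01 is an arbitrary choice)
grid = [i / 100 for i in range(-500, 501)]
max_err_inside = max(abs(sigmoid(x) - sigmoid_poly(x)) for x in grid)

# Outside the fitted range the cubic term dominates and the error grows quickly
err_at_8 = abs(sigmoid(8.0) - sigmoid_poly(8.0))

print(f"max |sigmoid - sigmoid_poly| on [-5, 5]: {max_err_inside:.3f}")
print(f"|sigmoid - sigmoid_poly| at x = 8:       {err_at_8:.3f}")
```

The error stays within a few hundredths inside the range and becomes large outside it, which is why it matters that the linear layer's outputs mostly fall in [-5, 5].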


In [1]:
import torch
import random

class LR(torch.nn.Module):

    def __init__(self, use_sigmoid_poly=True):
        super(LR, self).__init__()
        self.lr = torch.nn.Linear(in_features=8, out_features=1, bias=True)
        self.use_sigmoid_poly=use_sigmoid_poly

    @staticmethod
    def sigmoid_poly(x):
        # We use the polynomial approximation of degree 3
        # sigmoid(x) = 0.5 + 0.197 * x - 0.004 * x^3
        # from https://eprint.iacr.org/2018/462.pdf
        # which fits the function pretty well in the range [-5,5]
        return 0.5 + 0.197 * x - 0.004 * (x**3)
        
    def forward(self, x):
        linear = self.lr(x)
        if self.use_sigmoid_poly:
            out = LR.sigmoid_poly(linear)
        else:
            out = torch.sigmoid(linear)
        return out


Let's load the test data for the model. This dataset contains health-related data for several patients.

In [None]:
file_path = "data.pth"
data = torch.load(file_path)
x_test,y_test = data["x_test"], data["y_test"]

First, let's ensure that our polynomial approximation of sigmoid actually works. We can do this by comparing the inference accuracies we obtain using the original sigmoid function and our polynomial approximation of the sigmoid function. If the difference between the two is acceptable, we can be confident that our choice of polynomial approximation is a good one.

In [None]:
def accuracy(model, x, y):
    model = model.eval()
    out = model(x)
    correct = torch.abs(y - out) < 0.5
    return correct.float().mean()

# Load the model
model = LR(use_sigmoid_poly=False)
model.load_state_dict(torch.load("model.pth"))
model_poly = LR(use_sigmoid_poly=True)
model_poly.load_state_dict(torch.load("model.pth"))

model_accuracy = accuracy(model, x_test, y_test)
model_poly_accuracy = accuracy(model_poly, x_test, y_test)
print(f"Model(sigmoid)      accuracy on plain test_set: {model_accuracy}")
print(f"Model(sigmoid poly) accuracy on plain test_set: {model_poly_accuracy}")

We see that the difference in accuracies is negligible. This gives us confidence that our choice of approximation for the sigmoid function works.

### Exercise 3.1: Coming up with a data layout
The FHE operations of addition and multiplication act pointwise on vectors of `SLOTS` values. However, our logistic regression implementation requires a matrix-vector product between the samples and the weights. Thus, we need a way to lay out these matrices in the slots. This is also called packing.

Looking at the implementation of the linear layer, we observe that each prediction `Pred[i]` is computed as a dot product between the sample and the weight vector, after which a bias is added. Finally, the sigmoid of the result is computed. Thus,

```
Pred[i] = 0
for j in range(8):
    Pred[i] += Samples[i][j] * Weights[j] 
Pred[i] += Bias
Pred[i] = sigmoid(Pred[i])

```

A simple way to lay out this computation is to first place the sample matrix in the slots in row-major order. This occupies (n_samples x n_features) = (334 x 8) = 2672 slots. The remaining slots are set to zero. To align the weights with the samples slot by slot, we repeat the 8-entry weight vector once for each sample.

Thus, \
`samples_packed[8*i + j] = Samples[i][j]` \
`weights_packed[8*i + j] = Weights[j]`

The multiplication between the samples and the weights can then be implemented as a single pointwise multiplication, i.e. \
`product[8*i + j] = samples_packed[8*i + j] * weights_packed[8*i + j] = Samples[i][j] * Weights[j]`
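The packing and the pointwise product can be tried out on plain Python lists before any encryption is involved. This is a minimal sketch with toy sizes (3 samples and 32 slots instead of the notebook's 334 samples and 32*1024 slots); the sample and weight values are made up for illustration.

```python
# Toy sizes; the notebook uses 334 samples, 8 features and SLOTS = 32*1024
N_SAMPLES, N_FEATURES, SLOTS = 3, 8, 32

# Made-up data, only to make the slot positions easy to recognise
samples = [[float(10 * i + j) for j in range(N_FEATURES)] for i in range(N_SAMPLES)]
weights = [0.5 * (j + 1) for j in range(N_FEATURES)]

samples_packed = [0.0] * SLOTS
weights_packed = [0.0] * SLOTS
for i in range(N_SAMPLES):
    for j in range(N_FEATURES):
        samples_packed[N_FEATURES * i + j] = samples[i][j]  # row-major samples
        weights_packed[N_FEATURES * i + j] = weights[j]     # weights repeated per sample

# One slot-wise multiplication covers every (sample, feature) pair at once
product = [s * w for s, w in zip(samples_packed, weights_packed)]
```

Slots beyond `N_SAMPLES * N_FEATURES` stay zero in both vectors, so they contribute nothing to the product.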

However, we still have to work out the dot product. This is where we make use of rotations. The figure shows how rotation and summation can be used to compute the sum of values in a ciphertext.

![image](images/RotateAndSum.jpg)
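The rotate-and-sum idea can be emulated on an ordinary Python list. The sketch below is a plaintext illustration only: `rotate_left` and `rotate_and_sum` are hypothetical helpers, not Cinnamon DSL functions, and the rotation direction is an assumption (slot `i` of the rotated vector holds slot `i + k` of the input).

```python
def rotate_left(v, k):
    # Cyclic rotation: slot i of the result holds slot (i + k) of the input
    return v[k:] + v[:k]

def rotate_and_sum(v, width):
    # width must be a power of two; uses log2(width) rotations and additions
    acc = list(v)
    shift = 1
    while shift < width:
        rotated = rotate_left(acc, shift)
        acc = [a + r for a, r in zip(acc, rotated)]
        shift *= 2
    return acc

WIDTH, SLOTS = 8, 32
product = [float(i) for i in range(SLOTS)]  # stand-in for the pointwise product
sums = rotate_and_sum(product, WIDTH)

# Slot 8*i now holds the sum of slots 8*i .. 8*i + 7; the other slots hold
# partial sums that straddle neighbouring groups and are simply ignored
dot_products = sums[0::WIDTH]
print(dot_products)  # [28.0, 92.0, 156.0, 220.0]
```

With groups of 8 values, three rotations (by 1, 2 and 4) are enough, instead of the seven that a naive slot-by-slot summation would need.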

In general, there can be several possible layouts for a program. The layout selected will influence the structure of the FHE program. Layouts can differ in several ways, such as the number of ciphertexts required and the number of operations needed to perform the computation. Exploring these tradeoffs and automatically finding an efficient layout is an exciting research direction.


The sigmoid function is applied pointwise to the output of the linear layer so that doesn't require much consideration in terms of layout. However, implementing a polynomial in FHE can be quite tedious and time consuming. 


Now that we have decided on the layout, let's write the program in the Cinnamon DSL. I have already provided an implementation of the sigmoid function. Note that since Cinnamon is an embedded DSL, we can easily make use of Python features like functions.

In [4]:
import numpy as np

SLOTS = 32*1024
samples = x_test.numpy()
weights = model.lr.weight.detach().numpy()[0]
bias = model.lr.bias.detach().numpy()

samples_packed = np.zeros(SLOTS, dtype=np.float32)
weights_packed = np.zeros(SLOTS, dtype=np.float32)
bias_packed = np.zeros(SLOTS, dtype=np.float32)

# Pack x_test in row-major order
for i in range(x_test.shape[0]):
    for j in range(8):
        samples_packed[i * 8 + j] = samples[i,j]

# Pack model weights in row-major order
for i in range(x_test.shape[0]):
    for j in range(8):
        weights_packed[i * 8 + j] = weights[j]

# Repeat the single bias value in all 8 slots of each sample
for i in range(x_test.shape[0]):
    bias_packed[i * 8:(i + 1) * 8] = bias[0]

In [None]:
from cinnamon.dsl import *
from cinnamon.compiler import cinnamon_compile
TOP_LEVEL = 51
NUM_CHIPS=1
RNS_BIT_SIZE =28



lr_program = CinnamonProgram('LogisticRegression',RNS_BIT_SIZE,NUM_CHIPS)
with lr_program:

    SCALE = 56
    def sigmoid(x):
        # 0.5 + 0.197 * x - 0.004 * (x**3)
        x2 = (x * x).relinearize().rescale().rescale()
        xc3 = x * PlaintextInput("c3",3*SCALE - x2.scale() -x.scale() ,x.level(),scalar=True)
        x3 = x2 * xc3.rescale().rescale()
        x3 = x3.relinearize()
        x1 = x * PlaintextInput("c1",2*SCALE - x.scale(),x.level(),scalar=True)
        s = x3 + x1.modswitch().modswitch()
        s = s.rescale().rescale()
        c0 = PlaintextInput("c0",s.scale(),s.level(),scalar=True)
        s = s + c0
        return s

    def dot_product(A,B):
        ## TODO: Implement Dot Product Using Rotate and Sum
        return

    level = TOP_LEVEL
    ## TODO: Create Inputs for X and W
    X = CiphertextInput('x')
    w = PlaintextInput('w')
    # Compute Dot Product
    dp = dot_product(X,w)
    dp = dp.rescale().rescale()
    ## TODO: Create Input for Bias
    b = PlaintextInput('b')
    dp = dp + b
    pred = sigmoid(dp)
    Output('pred', pred)

lr_program_dir = "lr_program_outputs"
!mkdir -p "{lr_program_dir}"
cinnamon_compile(lr_program,TOP_LEVEL,1,1024,f"{lr_program_dir}/")

In [6]:
RNS_PRIMES = [204865537, 205651969, 206307329, 207880193, 209059841, 210370561, 211025921, 211812353, 214171649, 215482369, 215744513, 216137729, 216924161, 217317377, 218628097, 219676673, 220594177, 221249537, 222035969, 222167041, 222953473, 223215617, 224002049, 224133121, 225574913, 228065281, 228458497, 228720641, 230424577, 230686721, 230817793, 231473153, 232390657, 232652801, 234356737, 235798529, 236584961, 236716033, 239337473, 239861761, 240648193, 241827841, 244842497, 244973569, 245235713, 245760001, 246415361, 249561089, 253100033, 253493249, 254279681, 256376833, 256770049, 257949697, 258605057, 260571137, 260702209, 261488641, 261881857, 263323649, 263454721, 264634369, 265420801, 268042241]

class ValueMeta:
    def __init__(self, scale, level):
        self.scale = scale
        self.level = level
        if self.level <= 0:
            raise ValueError("Level must be positive")
        
    def rescale(self):
        return ValueMeta(self.scale/RNS_PRIMES[self.level -1],self.level-1)

    def __mul__(self,other):
        if not isinstance(other,ValueMeta):
            raise ValueError("Must multiply by ValueMeta")
        if self.level != other.level:
            raise ValueError("Levels must be the same for Multiplication")
        return ValueMeta(self.scale*other.scale,self.level)

    def __add__(self,other):
        if not isinstance(other,ValueMeta):
            raise ValueError("Must add a ValueMeta")
        if self.level != other.level:
            raise ValueError("Levels must be the same for Addition")
        return ValueMeta(self.scale,self.level)

    def __lshift__(self,_):
        return ValueMeta(self.scale,self.level)

    def __rshift__(self,_):
        return ValueMeta(self.scale,self.level)

    def modswitch(self):
        return ValueMeta(self.scale,self.level-1)

def sigmoid_inputs(xM,sigmod_out_scale):
    Inputs = {}
    OutScale = {}
    coeffs = [0.5,0.197,0,-0.004]
    x2M = (xM * xM).rescale().rescale()
    c3scale = sigmod_out_scale
    for i in range(4):
        c3scale *= RNS_PRIMES[xM.level-1-i]
    c3scale = c3scale/(xM.scale*x2M.scale)
    c3M = ValueMeta(c3scale,xM.level)
    Inputs["c3"] = (coeffs[3],c3M.scale)
    c1scale = sigmod_out_scale
    for i in range(2,4):
        c1scale *= RNS_PRIMES[xM.level-1-i]
    c1scale = c1scale/xM.scale
    c1M = ValueMeta(c1scale,xM.level)
    Inputs["c1"] = (coeffs[1],c1M.scale)
    x3M = xM * c3M
    x3M = x3M.rescale().rescale() * x2M
    x1M = xM * c1M
    x1M = x1M.modswitch().modswitch()
    sM = x3M + x1M  
    sM = sM.rescale().rescale()
    c0M = ValueMeta(sM.scale,sM.level)
    Inputs["c0"] = (coeffs[0],c0M.scale)
    return (Inputs,OutScale,sM)

In [7]:
lr_inputs = {}
lr_output_scales = {}

level = TOP_LEVEL
scale = 1 << 56
lr_inputs["x"] = (samples_packed,scale)
lr_inputs["w"] = (weights_packed,scale)
wM = ValueMeta(scale,level)
yM = ValueMeta(scale,level)
lM = (wM * yM).rescale().rescale()
lr_inputs["b"] = (bias_packed,lM.scale)
lr_output_scales["l"] = lM.scale
(sigmoidInputs,sigmoidOutScale,sM) = sigmoid_inputs(lM,scale)
lr_inputs.update(sigmoidInputs)
lr_output_scales.update(sigmoidOutScale)
lr_output_scales["pred"] = sM.scale

In this case, we see that our choice of polynomial approximation for the sigmoid function results in no drop in accuracy. This is great news for us.

The Cinnamon compiler is a Python-embedded DSL. Let's import the Cinnamon modules.

Now let's write a simple program to add two numbers. We first create a `CinnamonProgram` object. The first argument to the constructor is the name of the program. If you don't understand the other arguments, don't worry, these will be made clear later on. 

Now, we need to create two ciphertexts in the program and add them together. But before that, let's introduce two concepts of CKKS-FHE ciphertexts: scale and level.

Cinnamon uses the information about the scale and levels provided here to type check the program. If you add values with incompatible scales and levels, the compiler will produce an error.
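As a rough illustration of how scale and level evolve, the sketch below follows one multiplication and two rescales in plain Python. It mirrors the bookkeeping done by the `ValueMeta` helper in this notebook rather than anything inside the Cinnamon compiler, and it assumes that each rescale divides the scale by one 28-bit prime and consumes one level.

```python
import math

# Two of the 28-bit primes from the notebook's RNS_PRIMES list
PRIMES = [253493249, 254279681]

scale_x = 2.0 ** 56   # scale of the encrypted samples
scale_w = 2.0 ** 56   # scale of the plaintext weights
level = 51            # TOP_LEVEL in the notebook

# Multiplication multiplies the scales; the level is unchanged
scale = scale_x * scale_w

# Each rescale divides the scale by one prime and consumes one level
for p in reversed(PRIMES):
    scale /= p
    level -= 1

print(f"log2(scale) after multiply + 2 rescales: {math.log2(scale):.2f}, level: {level}")
```

Because the primes are slightly smaller than 2^28, the resulting scale is close to, but not exactly, 2^56. This is why the bias is encoded at the tracked scale `lM.scale` rather than at `1 << 56`: an addition only type checks when both operands carry the same scale and level.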

### Exercise 1.3: Emulating The Program
Now that we have compiled the program, we want to make sure that our program and the compiler output are actually what we want them to be. To test this, we use the Cinnamon emulator. The Cinnamon emulator reads in Cinnamon assembly and emulates the instructions on a CPU. But before we do that, we need to assign values and scale information to our program inputs and outputs.

First, let's import the cinnamon emulator

In [8]:
import cinnamon_emulator

Let's create a secret key to encrypt and decrypt our values

In [9]:
import random
random.seed(10)
def generate_secret_key(Slots,HammingWeight=32):
    secretKey = [0]*(2*Slots)
    count = 0
    while count < HammingWeight:
        pos = random.randint(0,2*Slots-1)
        if secretKey[pos] != 0:
            continue
        val = random.randint(0,1)
        if val == 0:
            secretKey[pos] = -1
        elif val == 1:
            secretKey[pos] = 1
        else:
            raise Exception("")
        count += 1
    return secretKey

secretKey = generate_secret_key(SLOTS,HammingWeight=SLOTS)

Now let's create the ciphertexts and emulate the encrypted program

In [None]:
context = cinnamon_emulator.Context(SLOTS,RNS_PRIMES)
encryptor = cinnamon_emulator.CKKSEncryptor(context,secretKey)
emulator = cinnamon_emulator.Emulator(context)

base_dir = lr_program_dir
emulator.generate_and_serialize_evalkeys(f"{base_dir}/evalkeys",f"{base_dir}/program_inputs",encryptor)
emulator.generate_inputs(f"{base_dir}/program_inputs",f"{base_dir}/evalkeys",lr_inputs,encryptor)
emulator.run_program(f"{base_dir}/instructions",NUM_CHIPS,1024)
emulator_outputs = emulator.get_decrypted_outputs(encryptor,lr_output_scales)

Now let's collect the outputs of the encrypted program

In [None]:
encrypted_predictions = emulator_outputs["pred"][0::8][:y_test.shape[0]]
y_test_np = y_test.flatten().numpy()
encrypted_accuracy = (np.abs(encrypted_predictions-y_test_np) < 0.5).mean()
print(f"Model(sigmoid)      accuracy on plain test_set:     {model_accuracy}")
print(f"Model(sigmoid poly) accuracy on plain test_set:     {model_poly_accuracy}")
print(f"Model(sigmoid poly) accuracy on encrypted test_set: {encrypted_accuracy}")

The encrypted model inference provides the same accuracy as the plaintext model inference! And the server learned nothing about our patients' data. This example illustrates the power of encrypted AI: you can use the services of an AI model without compromising your sensitive data.

Congratulations! You have just run your very first encrypted AI inference using Cinnamon. This marks the end of notebook 2. In [notebook3](../notebook3/notebook3.ipynb), we will see how the Cinnamon compiler can parallelize code and use the Cinnamon architectural simulator to see Cinnamon's scale out features in action.