# Cinnamon: A Framework For Scale Out Encrypted AI
## Notebook 3: Encrypted MNIST Inference, Parallelization Features of the Cinnamon Framework and the Cinnamon Accelerator Simulator

In this tutorial, we will run an encrypted MNIST inference and use the Cinnamon compiler to parallelize the program. We will then use the Cinnamon accelerator simulator to evaluate Cinnamon's parallelization strategies.

MNIST Model Credits: [Github: youben11](https://github.com/youben11/encrypted-evaluation)

Author:
- Siddharth Jayashankar (sidjay@cmu.edu)

### Exercise 4 
In this exercise, we will write an encrypted CNN inference model for the MNIST dataset in the Cinnamon DSL. The MNIST dataset contains 28x28 images of handwritten digits 0 through 9. The model classifies each image into the corresponding digit category.

#### The MNIST CNN Model
The input to the model is a 28x28 image from the MNIST dataset.

Our model architecture is as follows:
- A 2D convolution layer with 4 output channels, a kernel size of (7x7) and a stride of (3x3)
- A square activation function
- A 256x64 fully connected layer
- A square activation function
- A 64x10 fully connected layer

The predicted digit is the argmax of the final layer.
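The 256 input features of the first fully connected layer follow from the output shape of the convolution. A quick arithmetic check, as a sketch (the helper name `conv_out_size` is ours and is not part of the tutorial code):

```python
def conv_out_size(in_size, kernel, stride, padding=0):
    """Output length along one spatial dimension of a strided convolution."""
    return (in_size + 2 * padding - kernel) // stride + 1

side = conv_out_size(28, kernel=7, stride=3)  # (28 - 7) // 3 + 1 = 8
features = 4 * side * side                     # 4 output channels of 8x8 values each
print(side, features)                          # 8 256
```

Flattening the 4 channels of 8x8 outputs gives the 256-element vector consumed by the 256x64 layer.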

Notice the use of a square activation function. This is because the CKKS scheme can only express linear and polynomial functions.

In [1]:
import torch
class MNIST_CNN(torch.nn.Module):
    """CNN for classifying MNIST data.
    Input should be an encoded 28x28 matrix representing the image.
    The input should also be normalized with a mean=0.1307 and an std=0.3081.
    """
    def __init__(self):
        super(MNIST_CNN, self).__init__()
        self.conv2d = torch.nn.Conv2d(in_channels=1,out_channels=4,kernel_size=(7,7),stride=(3,3),bias=True)
        self.fc1 = torch.nn.Linear(in_features=256,out_features=64,bias=True)
        self.fc2 = torch.nn.Linear(in_features=64,out_features=10,bias=True)

    def forward(self, x):
        conv = self.conv2d(x)
        conv_sq = conv * conv
        conv_sq = conv_sq.reshape(1,-1)
        o2 = self.fc1(conv_sq)
        o2_sq = o2 * o2
        o3 = self.fc2(o2_sq)
        return o3

Let's load the pretrained model and check its accuracy.

In [None]:
from PIL import Image
from torchvision import transforms
# Load the image specified by sample_num and normalize it
def load_input(sample_num):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    img = Image.open(f"samples/img_{sample_num}.jpg")
    img = transform(img).view(1,28, 28).to(torch.float32)
    return img

# Get the labels of the images
def get_labels():
    with open('samples/answers.txt', 'r') as file:
        labels = file.readlines()
    labels = [int(l.rstrip('\n')) for l in labels]
    return labels

# Compute the accuracy of the model over samples [1,num_samples] from the test set
def accuracy(model, num_samples):
    y = get_labels()
    model = model.eval()
    out = []
    for i in range(1,num_samples+1):
        input = load_input(i)
        o = model(input)
        o = torch.argmax(o)
        out.append(o)
    correct = torch.tensor([y[i] == out[i] for i in range(num_samples)],dtype=int)
    return correct.float().mean()

# Load the model
model = MNIST_CNN()
model.load_state_dict(torch.load("mnist.pth"))
print("Model loaded successfully")

# The test set has a total of 350 images. Let's check the accuracy of the model over all these images
NUM_SAMPLES = 350
model_accuracy = accuracy(model, NUM_SAMPLES)
print(f"Model accuracy on plain test_set: {model_accuracy}")

We see that our model achieves a good accuracy of 98%.

Now let's implement an encrypted inference using the MNIST model shown above in the Cinnamon DSL. We assume the model inputs (images) and activation values are encrypted while the model weights are plaintext values.

In [3]:
from cinnamon.dsl import *

RNS_BIT_SIZE = 28
TOP_LEVEL = 20

## Helper function to create plaintext matrix inputs for the baby step giant step matrix multiplication 
def get_bsgs_plaintexts(name_base,babysteps,giantsteps,scale,level):
    ret = []
    for (g,gs) in enumerate(giantsteps):
        for (b,bs) in enumerate(babysteps):
            ret.append(PlaintextInput(f"{name_base}_{bs}_{gs}",scale,level))
    return ret

## Implement the Convolution operation. Refer to 
def conv_2d(image):
    ## Implementing convolution for one output channel
    def do_convolution(out_channel_id,image,result):
        babysteps = [i * 8 for i in range(16)]
        giantsteps = [i * 8192 for i in range(4)]
        plaintexts = get_bsgs_plaintexts(f"conv_weight_{out_channel_id}",babysteps,giantsteps,scale=56,level=image.level())
        product = bsgs(image,plaintexts,babysteps,giantsteps)
        product = product.rescale()
        bias = PlaintextInput(f"conv_bias_{out_channel_id}",product.scale(),product.level())
        result[out_channel_id] = product + bias
 
    output_channels = 4
    outputs = [None for _ in range(output_channels)]
    for o in range(output_channels):
        do_convolution(o,image,outputs)

    ## Stack all the output channels in a single ciphertext
    for o in range(1,output_channels):
        outputs[0] += outputs[o] >> (64*128*o)

    return outputs[0].rescale().rescale()
    

def square(x):
    return (x * x).relinearize()


# Matrix multiplication of a 256x64 matrix with a 256x1 vector
def matmul_256x64(v):
    babysteps = [i * 128 for i in range(8)]
    giantsteps = [i * 1024 for i in range(8)]
    plaintexts = get_bsgs_plaintexts(f"fc1_w",babysteps,giantsteps,scale=56,level=v.level())
    product = bsgs(v,plaintexts,babysteps,giantsteps)
    product = product.rescale()
    product += product << (1024*8)
    product = product.rescale()
    product += product >> (1024*16)
    bias = PlaintextInput(f"fc1_b",product.scale(),product.level())
    result = product + bias
    return result

# Matrix multiplication of a 64x10 matrix with a 64x1 vector
def matmul_64x10(v):
    babysteps = [i * 128 for i in range(4)]
    giantsteps = [i * 512 for i in range(4)]
    plaintexts = get_bsgs_plaintexts(f"fc2_w",babysteps,giantsteps,scale=56,level=v.level())
    product = bsgs(v,plaintexts,babysteps,giantsteps)
    product = product.rescale()
    product += product << (1024*2)
    product = product.rescale()
    product += product << (1024*4)
    bias = PlaintextInput(f"fc2_b",product.scale(),product.level())
    result = product + bias
    return result

## Encrypted MNIST inference model
def mnist(numChips=1):
    mnistProgram = CinnamonProgram('Mnist',RNS_BIT_SIZE,num_chips=numChips)
    with mnistProgram:
        scale = 28*3
        image = CiphertextInput('image',scale,TOP_LEVEL)
        conv = conv_2d(image)
        conv_sq = square(conv)
        o2 = matmul_256x64(conv_sq.rescale())
        o2_square = square(o2)
        o3 = matmul_64x10(o2_square.rescale().rescale())
        Output('pred',o3)

    return mnistProgram

While I've implemented most of the model, one important component has been left as an exercise. 

### Exercise 4.1 Implement Plaintext Matrix Ciphertext Multiplication using Baby Step Giant Step

The Baby Step Giant Step algorithm is a common algorithm to implement plaintext matrix times ciphertext matrix multiplication. The convolution and fully connected layers make use of this function. 
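To build intuition before writing the DSL version, here is a minimal sketch of the baby step giant step idea on ordinary Python lists, with no encryption and no Cinnamon API. It assumes the textbook diagonal formulation of a square matrix-vector product; the names `rot` and `bsgs_matvec` are illustrative, and the way the tutorial's plaintext inputs are laid out may differ from this sketch.

```python
def rot(v, k):
    """Rotate left by k slots: rot(v, k)[i] == v[(i + k) % n]."""
    n = len(v)
    return [v[(i + k) % n] for i in range(n)]

def bsgs_matvec(A, x, n1):
    """Compute A @ x for an n x n matrix using n1 baby steps and n // n1 giant steps."""
    n = len(x)
    assert n % n1 == 0, "n1 must divide the vector length"
    # Baby steps: rotate the input once per baby step and reuse across giant steps.
    baby = [rot(x, b) for b in range(n1)]
    result = [0] * n
    for g in range(n // n1):
        G = g * n1
        inner = [0] * n
        for b in range(n1):
            k = G + b
            # k-th generalized diagonal of A.
            diag = [A[i][(i + k) % n] for i in range(n)]
            # Pre-rotate the diagonal right by G so the giant rotation can be
            # applied once, after the inner sum.
            shifted = rot(diag, -G)
            inner = [acc + d * r for acc, d, r in zip(inner, shifted, baby[b])]
        # Giant step: a single rotation of the accumulated inner sum.
        result = [acc + t for acc, t in zip(result, rot(inner, G))]
    return result

A = [[(3 * i + 5 * j) % 7 for j in range(8)] for i in range(8)]
x = list(range(1, 9))
expected = [sum(A[i][j] * x[j] for j in range(8)) for i in range(8)]
assert bsgs_matvec(A, x, n1=4) == expected
```

The point of the split is the rotation count: the input is rotated once per baby step and each inner sum once per giant step, rather than once per diagonal. In the encrypted setting rotations are expensive, while pre-rotating the plaintext diagonals can be done ahead of time.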

In [4]:
## Implement the baby step giant step algorithm
def bsgs(input,M,babysteps,giantsteps):
    ## TODO: Fill in the rotate_babysteps
    rotate_babysteps = []
    for (g,gs) in enumerate(giantsteps):
        for (b,bs) in enumerate(babysteps):
            i = g * len(babysteps) + b
            if b == 0:
                # TODO: Fill in the multiplication 
                pass
            else:
                # TODO: Fill in the multiplication 
                pass
        if g == 0:
            # TODO: Fill in the giantsteps
            pass
        else:
            # TODO: Fill in the giantsteps
            pass
    return prod

Now let's compile the program.

In [None]:
# Import the compiler module
from cinnamon.compiler import *
from cinnamon.passes import *

# Set The Number of chips to compile for
numChips = 1
# Set the directory where the Cinnamon compiler outputs should be created
output_dir = "outputs/"
!mkdir -p {output_dir}
program = mnist(numChips)

# Compile the program
keyswitch_pass(program)
cinnamon_compile(program, TOP_LEVEL, numChips, 256, output_dir)

Next, let's create the inputs for our model. I've provided a helper in mnist_io.py that takes care of this.

In [6]:
from mnist_io import *

# Returns the Program Inputs and Output Scales for the MNIST Program
def mnist_io(sample_num):
    input_image = load_input(sample_num)
    input_image = input_image.detach().numpy()[0]
    return get_mnist_program_io(input_image, TOP_LEVEL)

Let's test our implementation. We will run only 1 sample here due to time constraints.

In [8]:
import random
random.seed(10)
def generate_secret_key(Slots,HammingWeight=32):
    secretKey = [0]*(2*Slots)
    count = 0
    while count < HammingWeight:
        pos = random.randint(0,2*Slots-1)
        if secretKey[pos] != 0:
            continue
        val = random.randint(0,1)
        if val == 0:
            secretKey[pos] = -1
        elif val == 1:
            secretKey[pos] = 1
        else:
            raise Exception("")
        count += 1
    return secretKey

secretKey = generate_secret_key(SLOTS,HammingWeight=SLOTS)

In [None]:
import cinnamon_emulator
import numpy as np

context = cinnamon_emulator.Context(SLOTS,Primes)

encryptor = cinnamon_emulator.CKKSEncryptor(context,secretKey)
emulator = cinnamon_emulator.Emulator(context)

emulator.generate_and_serialize_evalkeys(f"{output_dir}/evalkeys",f"{output_dir}/program_inputs",encryptor)

# This function runs a single encrypted inference over the image samples/img_{sample_id}.jpg 
def run_one_sample(sample_id):
    print(f"Running sample {sample_id}")
    Inputs, OutScale = mnist_io(sample_id)
    emulator.generate_inputs(f"{output_dir}/program_inputs",f"{output_dir}/evalkeys",Inputs,encryptor)
    emulator.run_program(f"{output_dir}/instructions",numChips,1024)
    outputs = emulator.get_decrypted_outputs(encryptor,OutScale)
    prediction = np.real(outputs["pred"][0::128][0:10])
    return np.argmax(prediction)

encrypted_predictions = []
encrypted_predictions.append(run_one_sample(1))

As you have just experienced, running FHE on a CPU is very slow, and running all the samples in the dataset might take a while. Uncomment the code block below if you have the time to run more samples.

In [10]:
# NUM_ENCRYPTED_SAMPLES = 20
# for i in range(2,NUM_ENCRYPTED_SAMPLES+1):
#     encrypted_predictions.append(run_one_sample(i))

Now let's check the accuracy of our model on the few encrypted test samples we ran and compare it to the accuracy of the plaintext model on the set of test samples.

In [None]:
num_samples = len(encrypted_predictions)
y = get_labels()[:num_samples]
correct = torch.tensor([int(y[i] == encrypted_predictions[i]) for i in range(num_samples)])
encrypted_accuracy = correct.float().mean()

plain_accuracy = accuracy(model,num_samples)
print(f"Encrypted Accuracy on {num_samples} samples: {encrypted_accuracy}")
print(f"Plain Accuracy on {num_samples} samples:     {plain_accuracy}")

We see that the accuracy of the encrypted model matches the accuracy of the plaintext model on the few samples we ran.

### The Cinnamon Simulation Infrastructure
Sadly, as you have just seen, FHE is very slow on CPUs. For FHE to be practical, hardware acceleration is essential. Cinnamon proposes a scale out hardware accelerator design for FHE. Let's see how we can simulate the program we just wrote on the Cinnamon accelerator.

#### Simulation overview
The Cinnamon simulator is a custom element built within the [SST](http://sst-simulator.org) simulation framework. This component is a cycle accurate simulator for the Cinnamon accelerator. The simulator evaluates the performance of programs compiled with the Cinnamon compiler on the Cinnamon accelerator. Let's simulate the encrypted MNIST inference we just compiled on the Cinnamon accelerator.

To run the simulation in SST, we first need to create an SST setup file. Take a look at the file [cinnamon-setup.py](cinnamon-setup.py). This file contains the default parameters for the Cinnamon architecture. To perform architectural studies, we can edit the cinnamon-setup.py file.

Now, let's simulate our encrypted MNIST inference program on the Cinnamon accelerator.

In [None]:
# Create a directory to store the logs of the simulation
sst_log_dir = "sst_log_dir"
!mkdir -p {sst_log_dir}
# Run the simulation on SST. Point the instructions_dir argument to the directory where the compiler output was generated.
!sst cinnamon-setup.py -- --instructions_dir="{output_dir}/" > {sst_log_dir}/simulation_1.log
# Print out the last 5 lines of the simulation log.
!tail -n5 {sst_log_dir}/simulation_1.log

That was quick, wasn't it? It took just a few milliseconds! This shows the potential for hardware accelerators in making FHE practical. You can take a look at the [simulation log](sst_log_dir/simulation_1.log) to see the detailed output of the simulator. It contains detailed statistics on each component in the chip.

### Exercise 5.1 Limb Level Parallelism using the Cinnamon Compiler

CKKS ciphertexts in the RNS representation can be thought of as a matrix of modular integers. Each column is called a limb. In fact, the number of limbs in a ciphertext is determined by its level. The limbs of a ciphertext are largely data independent, so there is good potential for parallelizing FHE computation across the limbs. However, truly realizing this potential requires addressing the communication overheads introduced by cross limb dependencies. The figure below depicts limb level parallelism when the level is 4.

![image](images/LimbLevelParallelism.jpg)
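The independence of limbs can be illustrated with a toy residue number system (RNS) on plain integers. This is only a sketch under simplifying assumptions: the moduli are small made-up primes, and a single integer stands in for what would be a whole polynomial per limb in CKKS. The helper names are ours.

```python
from math import prod

MODULI = [97, 101, 103, 107]  # toy pairwise-coprime limb moduli (assumed values)

def to_rns(x):
    """Split an integer into one residue (limb) per modulus."""
    return [x % q for q in MODULI]

def from_rns(limbs):
    """Recombine the limbs with the Chinese Remainder Theorem."""
    Q = prod(MODULI)
    total = 0
    for r, q in zip(limbs, MODULI):
        Qi = Q // q
        total += r * Qi * pow(Qi, -1, q)  # pow(.., -1, q) is the modular inverse
    return total % Q

def limbwise(op, a, b):
    """Apply op to each pair of limbs; no limb needs data from another limb."""
    return [op(x, y) % q for x, y, q in zip(a, b, MODULI)]

a, b = 123456, 654
product_limbs = limbwise(lambda x, y: x * y, to_rns(a), to_rns(b))
assert from_rns(product_limbs) == (a * b) % prod(MODULI)
```

Because each entry of `limbwise` touches only its own modulus, the four limbs could be computed on four different chips without any communication. Only operations that need to look across limbs force data to move between chips.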

Cinnamon developed novel algorithms and compiler techniques to realize the potential of limb level parallelism. The details are described in the paper on the [Cinnamon framework](https://dl.acm.org/doi/pdf/10.1145/3669940.3707260). The Cinnamon compiler automatically implements limb level parallelism using Cinnamon's algorithms and techniques. To parallelize a program across Cinnamon chips, just change the `numChips` argument. That's it. It's that simple. The Cinnamon compiler will take care of the rest for you.


Let's compile the MNIST inference program for Cinnamon-4 by setting the `numChips` argument to 4. The Cinnamon compiler will generate a sequence of instructions for each of the 4 chips and insert the appropriate synchronization and communication instructions.

In [None]:
numChips = 4
output_dir = f"outputs_{numChips}ch/"
!mkdir -p {output_dir}
program = mnist(numChips)
keyswitch_pass(program)
cinnamon_compile(program, TOP_LEVEL, numChips, 256, output_dir)

We can use the Cinnamon emulator to run the program we just compiled. The Cinnamon emulator runs each chip's instructions as a multi threaded program with `numChips` threads.

In [None]:
import cinnamon_emulator
context = cinnamon_emulator.Context(SLOTS,Primes)

encryptor = cinnamon_emulator.CKKSEncryptor(context,secretKey)
emulator = cinnamon_emulator.Emulator(context)



emulator.generate_and_serialize_evalkeys(f"{output_dir}/evalkeys",f"{output_dir}/program_inputs",encryptor)
encrypted_predictions = []
encrypted_predictions.append(run_one_sample(1))
print(f"Encrypted Limb Level Parallel Predictions: {encrypted_predictions}")

You might have noticed that the emulated program ran faster too. Now, let's use the Cinnamon simulator to see how limb level parallelism using 4 chips speeds up our program. We pass the `--chips` command line argument to the setup file to specify that this program is to be run on 4 chips.

In [None]:
!sst cinnamon-setup.py -- --instructions_dir="{output_dir}/" --chips={numChips} > {sst_log_dir}/simulation_{numChips}.log
!tail -n5 {sst_log_dir}/simulation_{numChips}.log

This was much faster than the 1 chip example. You can take a look at the [simulation logfile](sst_log_dir/simulation_4.log) to see a detailed report of the simulation. 

### Exercise 5.2 Program Level Parallelism using the Cinnamon Compiler.

The example above distributes the limbs of all ciphertexts modulo the 4 chips. However, there is another dimension along which our program can be parallelized - program level parallelism. Program level parallelism parallelizes computation at the ciphertext level. In our encrypted MNIST program, we can exploit this kind of parallelism in the convolution layer. Looking into the convolution operation, we see that we perform the operation 4 times, to get a four channel output. Now, instead of parallelizing each of the four iterations at the limb level, we could parallelize the four iterations of the convolution across the 4 chips. This kind of parallelism is called program parallelism.

In the Cinnamon DSL, program parallelism is implemented using `CinnamonStream`. Each stream can be thought of as a concurrent thread that executes on the number of chips specified by the `StreamSize` argument. The number of streams to be created is specified by the `NumStreams` argument, with each stream receiving a stream ID in 0 through `NumStreams-1`. Within a single stream, the values are parallelized at the limb level. The `StreamFn` argument specifies the function that each stream implements. The first argument passed to the stream function is the stream ID. The rest of the arguments are function specific.

In this example, we set `StreamSize=1` and `NumStreams=4`. This creates four concurrent streams with each stream running on 1 chip and implementing the `do_convolution` function. The figure below illustrates program parallelism and limb parallelism.

![image](images/LimbAndProgramParalleism.jpg)

By default, the program runs in a single stream with `StreamSize=numChips`. Thus, the rest of the program will remain limb level parallelized across the four chips, just as in the previous example.
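The calling convention of a stream function can be pictured with ordinary Python threads. The sketch below is only an analogy: `run_streams` and `toy_channel` are hypothetical names, nothing here uses the Cinnamon API, and the real streams are mapped onto chips by the compiler rather than executed as host threads.

```python
from concurrent.futures import ThreadPoolExecutor

def run_streams(num_streams, stream_fn, **kwargs):
    """Call stream_fn(stream_id, **kwargs) once per stream, concurrently."""
    with ThreadPoolExecutor(max_workers=num_streams) as pool:
        futures = [pool.submit(stream_fn, sid, **kwargs) for sid in range(num_streams)]
        for f in futures:
            f.result()  # re-raise any exception raised inside a stream

def toy_channel(out_channel_id, image, result):
    # Stand-in for do_convolution: each stream writes only its own output slot.
    result[out_channel_id] = sum(image) * (out_channel_id + 1)

outputs = [None] * 4
run_streams(4, toy_channel, image=[1, 2, 3], result=outputs)
print(outputs)  # [6, 12, 18, 24]
```

As with `do_convolution`, the stream ID arrives as the first positional argument and selects which slot of the shared `result` list the stream fills in, so the streams never write to the same location.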


In [15]:
## A Program Level Parallelism implementation of convolution
def conv_2d(image):
    def do_convolution(out_channel_id,image,result):
        babysteps = [i * 8 for i in range(16)]
        giantsteps = [i * 8192 for i in range(4)]
        plaintexts = get_bsgs_plaintexts(f"conv_weight_{out_channel_id}",babysteps,giantsteps,scale=56,level=image.level())
        product = bsgs(image,plaintexts,babysteps,giantsteps)
        product = product.rescale()
        bias = PlaintextInput(f"conv_bias_{out_channel_id}",product.scale(),product.level())
        result[out_channel_id] = product + bias
 
    output_channels = 4
    outputs = [None for _ in range(output_channels)]

    # for o in range(output_channels):
    #     do_convolution(o,image,outputs)

    print("Compiling conv2D using Cinnamon Streams")
    # TODO: Use program level parallelism to parallelize the four iterations of the convolution.
    CinnamonStream(StreamSize=1,NumStreams=4,StreamFn=do_convolution,image=image,result=outputs)


    for o in range(1,output_channels):
        outputs[0] += outputs[o] >> (64*128*o)

    return outputs[0].rescale().rescale()
    

Let's compile this program, emulate an inference and evaluate its performance on the Cinnamon accelerator.

In [None]:
# To compile for a 4 chip Cinnamon accelerator, set numChips to 4 
numChips = 4
output_dir = f"outputs_{numChips}ch_programParallel/"
!mkdir -p {output_dir}
program = mnist(numChips)
keyswitch_pass(program)
cinnamon_compile(program, TOP_LEVEL, numChips, 256, output_dir)

In [None]:
emulator.generate_and_serialize_evalkeys(f"{output_dir}/evalkeys",f"{output_dir}/program_inputs",encryptor)
encrypted_predictions = []
encrypted_predictions.append(run_one_sample(1))
print(f"Encrypted Program Parallel Predictions: {encrypted_predictions}")

In [None]:
!sst cinnamon-setup.py -- --instructions_dir="outputs_{numChips}ch_programParallel/" --chips={numChips} > {sst_log_dir}/simulation_{numChips}_program_parallel.log
!tail -n5 {sst_log_dir}/simulation_{numChips}_program_parallel.log

### Cinnamon Keyswitch Pass and Cinnamon's New Parallel Keyswitching Algorithms

Cinnamon introduces two new parallel keyswitching algorithms: Input Broadcast Keyswitching and Output Aggregation Keyswitching. These two parallel keyswitching algorithms are designed with the objective of minimizing inter chip communication. They work by looking for specific program patterns, reordering operations and trading off communication for compute. The `keyswitch_pass` in the Cinnamon compiler looks for these program patterns. The two new parallel keyswitching algorithms and the `keyswitch_pass` are intended to work together to minimize communication overheads and deliver performance gains. Let's run an experiment using the Cinnamon framework to evaluate these techniques.

In this experiment, we will compile three different versions of the mnist program:
- **unoptimized**: With `keyswitch_pass` and Cinnamon's parallel keyswitching algorithms disabled.
- **pass**: With just `keyswitch_pass` enabled and Cinnamon's parallel keyswitching algorithms disabled.
- **pass_cinnamon_ks**: With both `keyswitch_pass` and Cinnamon's new parallel keyswitching algorithms enabled.

We will evaluate these 3 configurations using 4 Cinnamon chips. The default value of the inter chip link bandwidth in Cinnamon is 256GB/s. However, in this experiment, we will compare how these three configurations perform across a range of inter chip bandwidths from 128GB/s to 1TB/s.


First, let's reset the convolution back to the version without program parallelism.

In [19]:
# Reset Convolution to use Limb Level Parallelism
def conv_2d(image):
    def do_convolution(out_channel_id,image,result):
        babysteps = [i * 8 for i in range(16)]
        giantsteps = [i * 8192 for i in range(4)]
        plaintexts = get_bsgs_plaintexts(f"conv_weight_{out_channel_id}",babysteps,giantsteps,scale=56,level=image.level())
        product = bsgs(image,plaintexts,babysteps,giantsteps)
        product = product.rescale()
        bias = PlaintextInput(f"conv_bias_{out_channel_id}",product.scale(),product.level())
        result[out_channel_id] = product + bias
 
    output_channels = 4
    outputs = [None for _ in range(output_channels)]

    print("Compiling conv2D using Limb Level Parallelism")
    for o in range(output_channels):
        do_convolution(o,image,outputs)

    for o in range(1,output_channels):
        outputs[0] += outputs[o] >> (64*128*o)

    return outputs[0].rescale().rescale()
    

Now, let's compile the program in the three configurations we listed above.

In [None]:
numChips = 4
program = mnist(numChips)

## unoptimized
output_dir = f"outputs_{numChips}ch_unoptimized/"
!mkdir -p {output_dir}
# Keyswitch Pass is not enabled
cinnamon_compile(program, TOP_LEVEL, numChips, 256, output_dir, use_cinnamon_keyswitching=False)

## pass
output_dir = f"outputs_{numChips}ch_pass/"
!mkdir -p {output_dir}
keyswitch_pass(program)
cinnamon_compile(program, TOP_LEVEL, numChips, 256, output_dir, use_cinnamon_keyswitching=False)

## pass_cinnamon_ks
output_dir = f"outputs_{numChips}ch_pass_cinnamon_ks/"
!mkdir -p {output_dir}
keyswitch_pass(program)
cinnamon_compile(program, TOP_LEVEL, numChips, 256, output_dir, use_cinnamon_keyswitching=True)

We use the `--linkBW` argument in [cinnamon-setup.py](cinnamon-setup.py) to specify the inter chip link bandwidth to use.

In [21]:
expt_dir = "bandwidth_expt_dir"
configs = ['unoptimized','pass','pass_cinnamon_ks']

bandwidths = ["0.125","0.25","0.5","1"] # in TB/s
for c in configs:
    !mkdir -p {expt_dir}/{c}
for bw in bandwidths:
    for c in configs:
        !sst cinnamon-setup.py -- --instructions_dir="outputs_{numChips}ch_{c}/" --chips={numChips} --linkBW={bw} > {expt_dir}/{c}/{bw}bw.log
print("Experiments Completed")

[rollup.py](rollup.py) implements utilities for reading the logfiles and plotting the relative performance of the three configs across the bandwidths.

In [None]:
import rollup
rollup.plot_speedups(expt_dir,configs,bandwidths)

The plots depict how Cinnamon's new keyswitching algorithms and compiler passes can deliver speedups for scale out encrypted AI applications. This marks the end of the tutorial. Thank you for taking part!