# AIVM Tutorial: Encrypted Neural Network Inference

This tutorial demonstrates how to use the AIVM (AI Virtual Machine) client to perform encrypted inference on neural networks. We'll explore two use cases:
1. Digit recognition using LeNet5 on the MNIST dataset
2. SMS spam detection using BERT Tiny

## Setup and Imports

```python
# Import necessary libraries
import time
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

import aivm_client as aic  # Import the Nillion-AIVM client
```

This cell imports the required Python libraries for handling neural networks, data processing, and visualization. The `aivm_client` is the main interface for interacting with the AIVM system.

## MNIST Dataset Loading

```python
def load_mnist():
    trans = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize((28, 28)),
            transforms.Normalize((0.5,), (1.0,)),
        ]
    )
    train_set = dset.MNIST(
        root="/tmp/mnist", train=True, transform=trans, download=True
    )
    return train_set

dataset = load_mnist()

inputs, labels = dataset[20]
inputs = inputs.reshape(1, 1, 28, 28)
```

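This cell defines a function to load the MNIST dataset of handwritten digits. Each image is:
- Converted to a tensor with pixel values in $[0, 1]$
- Resized to $28 \times 28$ pixels (MNIST images already have this size, so this step only guarantees the expected shape)
- Normalized with mean $0.5$ and standard deviation $1.0$, which shifts the pixel values to the range $[-0.5, 0.5]$

We then load a single example (index 20) and reshape it to the `(batch, channels, height, width)` layout `(1, 1, 28, 28)` expected by the model.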
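The transform chain maps each raw pixel in two steps: `ToTensor` divides the 8-bit value by 255, and `Normalize` with mean 0.5 and standard deviation 1.0 subtracts 0.5 and divides by 1.0. A minimal pure-Python sketch of that arithmetic for a single pixel (the helper `normalize_pixel` is hypothetical, not part of torchvision):

```python
def normalize_pixel(raw: int, mean: float = 0.5, std: float = 1.0) -> float:
    """Mimic ToTensor() followed by Normalize((mean,), (std,)) for one 8-bit pixel."""
    scaled = raw / 255.0          # ToTensor: [0, 255] -> [0.0, 1.0]
    return (scaled - mean) / std  # Normalize: [0.0, 1.0] -> [-0.5, 0.5]

for raw in (0, 128, 255):
    print(raw, "->", round(normalize_pixel(raw), 4))
```

Black pixels end up at $-0.5$ and white pixels at $0.5$, which is why the normalized range is centered on zero.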

## Visualization

```python
# Plot the inputs as a grayscale image
plt.imshow(inputs.detach().numpy().squeeze(), cmap='gray')
plt.title(f'Grayscale Image of a {labels}')
plt.show()
```

This cell visualizes the selected MNIST digit as a grayscale image, helping us verify the input data.

## Encrypted Inference with LeNet5

```python
encrypted_inputs = aic.LeNet5Cryptensor(inputs)
```

This cell wraps the input tensor in `LeNet5Cryptensor`, AIVM's encrypted tensor type for the `LeNet5` model, so the image is encrypted before it is sent for inference.
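The notebook does not spell out the scheme behind `LeNet5Cryptensor`. Encrypted-tensor libraries in the CrypTen style commonly combine fixed-point encoding with additive secret sharing: each float is scaled to an integer, then split into random shares that individually reveal nothing. The sketch below illustrates that idea only as an assumption about how such a tensor *could* work; the $2^{16}$ scale, the 64-bit ring, and the helpers `encode`, `decode`, `share`, and `reconstruct` are illustrative and not part of `aivm_client`.

```python
import secrets

RING = 2**64   # shares live in the integers modulo 2^64 (assumed ring size)
SCALE = 2**16  # fixed-point scale: 16 fractional bits (assumed precision)

def encode(x: float) -> int:
    """Map a float to a ring element using fixed-point encoding."""
    return round(x * SCALE) % RING

def decode(v: int) -> float:
    """Map a ring element back to a float, treating the upper half as negative."""
    if v >= RING // 2:
        v -= RING
    return v / SCALE

def share(x: float, parties: int = 2) -> list[int]:
    """Split x into additive shares; any proper subset of them is uniformly random."""
    shares = [secrets.randbelow(RING) for _ in range(parties - 1)]
    last = (encode(x) - sum(shares)) % RING
    return shares + [last]

def reconstruct(shares: list[int]) -> float:
    """Only the sum of all shares reveals the (fixed-point rounded) value."""
    return decode(sum(shares) % RING)

pixel = -0.3765
s = share(pixel)
print("shares:", s)
print("reconstructed:", reconstruct(s))
```

The reconstructed value differs from the original by at most the fixed-point rounding error of $1/2^{16}$.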

```python
%%time
result = aic.get_prediction(encrypted_inputs, "LeNet5MNIST")
results = torch.argmax(result, dim=1)
print("Predicted Label:", results.item())
```

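This cell performs encrypted inference on the `LeNet5` model:
1. `aic.get_prediction` sends the encrypted input to the AIVM together with the model name `"LeNet5MNIST"`
2. The AIVM evaluates the model on the encrypted data
3. The prediction scores come back as an ordinary tensor, and `torch.argmax(..., dim=1)` picks the digit with the highest score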
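Under the same secret-sharing assumption as above, a server can work on data it cannot read because linear operations act share by share. The sketch below evaluates one neuron's dot product with public integer weights on each party's shares separately, and shows that recombining the partial results gives the plaintext answer. It is an illustration of the general technique, not AIVM's protocol: real systems need extra machinery for multiplying two secret values and for nonlinearities such as ReLU, which this sketch omits. All names are hypothetical.

```python
import secrets

RING = 2**64  # assumed ring size for the shares

def share_int(x: int, parties: int = 2) -> list[int]:
    """Additively secret-share an integer modulo 2^64."""
    shares = [secrets.randbelow(RING) for _ in range(parties - 1)]
    return shares + [(x - sum(shares)) % RING]

def reconstruct_int(shares: list[int]) -> int:
    """Recombine shares and interpret the result as a signed integer."""
    v = sum(shares) % RING
    return v - RING if v >= RING // 2 else v

def local_dot(weights: list[int], share_vec: list[int]) -> int:
    """One party's work: a dot product of public weights with its own shares."""
    return sum(w * s for w, s in zip(weights, share_vec)) % RING

private_vec = [3, -1, 4]       # secret input (kept in plaintext only for checking)
public_weights = [2, 5, -3]    # public weights of a single linear neuron

# Split every element into two shares; party p holds the p-th share of each element.
per_element = [share_int(x) for x in private_vec]
party_vectors = [[sh[p] for sh in per_element] for p in range(2)]

# Each party computes locally; only the partial results are combined.
partials = [local_dot(public_weights, vec) for vec in party_vectors]
print("encrypted-domain result:", reconstruct_int(partials))
print("plaintext result:       ", sum(w * x for w, x in zip(public_weights, private_vec)))
```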

## Alternative Inference Method

```python
%%time
result = encrypted_inputs.forward("LeNet5MNIST")
results = torch.argmax(result, dim=1)
print("Predicted Label:", results.item())
```

This cell demonstrates an alternative method for inference using the `forward` method directly on the encrypted tensor.

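## Processing Multiple Images

```python
%%time
correct = 0
num_images = 100
for i in range(num_images):
    inputs, labels = dataset[i]
    inputs = inputs.reshape(1, 1, 28, 28)
    encrypted_inputs = aic.LeNet5Cryptensor(inputs)
    result = aic.get_prediction(encrypted_inputs, "LeNet5MNIST")
    results = torch.argmax(result, dim=1)
    is_correct = results.item() == labels
    correct += int(is_correct)
    print("Predicted Label:", results.item(), "True Label:", labels, "Correct:", is_correct)
print(f"Accuracy: {correct}/{num_images} = {correct / num_images:.2%}")
```

This cell runs the model on 100 MNIST images, one at a time (each image is encrypted and submitted separately, not as a single batch). Because `load_mnist` passes `train=True`, these images come from the training split. The cell:
1. Loads each of the first 100 images and reshapes it for the model
2. Encrypts it and performs encrypted inference
3. Compares each prediction with the true label and prints the overall accuracy at the end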
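The `time` module imported in the first cell could also be used to measure per-image latency alongside accuracy. The helper below is a sketch of that idea: `evaluate` is a hypothetical function, and `predict` is any callable standing in for the encrypt-then-`get_prediction` step. Here it is exercised with a dummy predictor so it runs without the AIVM service.

```python
import time
from typing import Callable, Iterable, Tuple

def evaluate(predict: Callable[[object], int],
             samples: Iterable[Tuple[object, int]]) -> dict:
    """Run predict on each (input, label) pair; return accuracy and mean latency."""
    correct = 0
    latencies = []
    for x, label in samples:
        start = time.perf_counter()
        pred = predict(x)
        latencies.append(time.perf_counter() - start)
        correct += int(pred == label)
    n = len(latencies)
    return {
        "n": n,
        "accuracy": correct / n if n else 0.0,
        "mean_latency_s": sum(latencies) / n if n else 0.0,
    }

# Dummy stand-in predictor: returns the parity of the input.
fake_samples = [(0, 0), (1, 1), (2, 0), (3, 0)]
report = evaluate(lambda x: x % 2, fake_samples)
print(report)

# In the notebook, the AIVM call could be plugged in like this (not run here):
# evaluate(lambda img: torch.argmax(
#     aic.get_prediction(aic.LeNet5Cryptensor(img.reshape(1, 1, 28, 28)), "LeNet5MNIST"),
#     dim=1).item(), (dataset[i] for i in range(100)))
```

Timing each call separately, rather than the whole cell with `%%time`, makes it easier to spot outliers such as a slow first request.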

## SMS Spam Detection

```python
# A short, legitimate ("ham") message you can try instead:
# tokenized_inputs = aic.tokenize("Hello World!")
tokenized_inputs = aic.tokenize("Your free ringtone is waiting to be collected. Simply text the password 'MIX' to 85069 to verify. Get Usher and Britney. FML, PO Box 5249, MK17 92H. 450Ppw 16")
print(tokenized_inputs)

encrypted_inputs = aic.BertTinyCryptensor(*tokenized_inputs)
```

This cell shows how to:
1. Tokenize text input for BERT processing
2. Encrypt the tokenized input for spam detection
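A BERT tokenizer turns text into integer token IDs, usually padded or truncated to a fixed length and paired with an attention mask that marks real tokens (1) versus padding (0). The notebook only shows that `aic.tokenize` returns something that can be unpacked into `BertTinyCryptensor`, so the sketch below is a toy illustration of that general structure, not of `aic.tokenize` itself. It assumes whitespace splitting instead of BERT's WordPiece, a hypothetical `toy_tokenize` helper, and a made-up vocabulary.

```python
def toy_tokenize(text: str, vocab: dict[str, int], max_len: int = 8):
    """Return (input_ids, attention_mask), both exactly max_len long."""
    # Reserve two slots for the [CLS] and [SEP] markers, truncating the text if needed.
    tokens = ["[CLS]"] + text.lower().split()[: max_len - 2] + ["[SEP]"]
    ids = [vocab.get(tok, vocab["[UNK]"]) for tok in tokens]
    mask = [1] * len(ids)
    pad = max_len - len(ids)
    return ids + [vocab["[PAD]"]] * pad, mask + [0] * pad

# Hypothetical vocabulary for illustration only.
vocab = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3,
         "free": 4, "ringtone": 5, "hello": 6}
input_ids, attention_mask = toy_tokenize("Your free ringtone is waiting", vocab)
print(input_ids)
print(attention_mask)
```

Unknown words map to `[UNK]`, and the padded positions get a mask value of 0 so the model can ignore them.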

```python
%%time
result = aic.get_prediction(encrypted_inputs, "BertTinySMS")
result
```

This cell performs encrypted inference using a BERT Tiny model for spam detection.

```python
"SPAM" if torch.argmax(result) else "HAM"
```

This final cell maps the index of the highest score to a label: index 1 is reported as "SPAM" and index 0 as "HAM" (a legitimate message).
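`torch.argmax` returns only the winning index. If `result` holds two raw scores (logits), with index 0 for ham and index 1 for spam as the cell above implies, a softmax turns them into probabilities that show how confident the verdict is. A pure-Python sketch; `softmax` and `classify` are hypothetical helpers, and the example logits are made up rather than taken from the model:

```python
import math

def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def classify(logits: list[float], names=("HAM", "SPAM")) -> tuple[str, float]:
    """Return the winning label and its probability (ties go to the lower index)."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return names[best], probs[best]

label, confidence = classify([-1.2, 2.3])  # made-up logits
print(f"{label} (p = {confidence:.3f})")
```

In the notebook, `result.flatten().tolist()` could supply the list of logits, assuming `result` is a tensor with two scores.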

## Key Concepts
- The notebook demonstrates privacy-preserving inference using AIVM
- Input data is encrypted before being processed by the models
- Two different use cases showcase the versatility of the system:
  - Computer vision (MNIST digit recognition)
  - Natural language processing (SMS spam detection)
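- The `%%time` cell magic reports the wall-clock time of each inference cell, giving a rough measure of the overhead of encrypted computation (the timing covers everything the call does, including communication with the AIVM)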