**IMPORTANT**: Before starting this notebook, make sure that the kernel of the previous notebook is shut down (or that its state has been reset), so that the previous `model_user` Nillion client is forgotten.

In [1]:
## If problems arise with the loading of the shared library, this script can be used to load the shared library before other libraries.
## Remember to also run on your local machine the script below:
# bash replace_lib_version.sh

import ctypes
import importlib.util
import os
import platform

# Location used if the installed py_nillion_client package cannot be found via importlib
# (this is the path the notebook previously hard-coded).
DEFAULT_NILLION_SO_PATH = "/home/vscode/.local/lib/python3.12/site-packages/py_nillion_client/py_nillion_client.abi3.so"


def _find_nillion_shared_library() -> str:
    """Return the path of py_nillion_client's shared library without importing the package.

    `importlib.util.find_spec` only locates a top-level package, it does not execute it,
    so the library can still be force-loaded before anything else.

    Returns:
        str: Path to `py_nillion_client.abi3.so` inside the installed package directory,
        or DEFAULT_NILLION_SO_PATH if it cannot be located there.
    """
    spec = importlib.util.find_spec("py_nillion_client")
    if spec is not None and spec.submodule_search_locations:
        package_dir = next(iter(spec.submodule_search_locations))
        candidate = os.path.join(package_dir, "py_nillion_client.abi3.so")
        if os.path.exists(candidate):
            return candidate
    return DEFAULT_NILLION_SO_PATH


if platform.system() == "Linux":
    # Force libgomp and py_nillion_client to be loaded before other libraries consuming dynamic TLS (to avoid running out of STATIC_TLS)
    ctypes.cdll.LoadLibrary("libgomp.so.1")
    ctypes.cdll.LoadLibrary(_find_nillion_shared_library())

In [2]:
from typing import Dict
import torch
import json
import os
import py_nillion_client as nillion
from torchvision import transforms
from PIL import Image
from dotenv import load_dotenv
import numpy as np

import nada_algebra as na
import nada_algebra.client as na_client
import py_nillion_client as nillion
from nillion_python_helpers import (
    create_nillion_client,
    getUserKeyFromFile,
    getNodeKeyFromFile,
)

## Authenticate with Nillion

To connect to the Nillion network, we need to have a user key and a node key. These serve different purposes:

The `user_key` is the user's private key. The user key should never be shared publicly, as it unlocks access and permissions to secrets stored on the network.

The `node_key` is the private key of the node that runs locally and connects to the network.

In [3]:
# Load all Nillion network environment variables.
# Compare the last two path components rather than a "/"-joined suffix, so the check is
# independent of the OS path separator and does not accept look-alikes such as
# ".../myexamples/multi_layer_perceptron".
assert os.path.normpath(os.getcwd()).split(os.sep)[-2:] == [
    "examples",
    "multi_layer_perceptron",
], "Please run this script from the examples/multi_layer_perceptron directory otherwise, the rest of the tutorial may not work"
load_dotenv()

True

In [4]:
def _require_env(name: str) -> str:
    """Return the value of environment variable `name`, failing loudly if it is missing.

    `os.getenv` returns None for unset variables; without this check that None would be
    passed straight into the key loaders / client calls below.

    Args:
        name (str): Environment variable to read.

    Returns:
        str: The (non-empty) value of the variable.

    Raises:
        RuntimeError: If the variable is unset or empty.
    """
    value = os.getenv(name)
    if not value:
        raise RuntimeError(
            f"Environment variable {name!r} is not set; check the .env file loaded by load_dotenv()."
        )
    return value


cluster_id = _require_env("NILLION_CLUSTER_ID")
model_user_userkey = getUserKeyFromFile(_require_env("NILLION_USERKEY_PATH_PARTY_2"))
model_user_nodekey = getNodeKeyFromFile(_require_env("NILLION_NODEKEY_PATH_PARTY_2"))
model_user_client = create_nillion_client(model_user_userkey, model_user_nodekey)
model_user_party_id = model_user_client.party_id
model_user_user_id = model_user_client.user_id

In [5]:
# Identifiers handed over by the model provider: the Nada program, the stored model and the provider's party
with open("data/tmp.json", "r") as provider_file:
    provider_variables = json.load(provider_file)

program_id, model_store_id, model_provider_party_id = (
    provider_variables[key]
    for key in ("program_id", "model_store_id", "model_provider_party_id")
)

for label, value in (
    ("Program ID: ", program_id),
    ("Model Store ID: ", model_store_id),
    ("Model Provider Party ID: ", model_provider_party_id),
):
    print(label, value)

Program ID:  5GnDBJ4brQtS7U3YsUdweoKUfbN7t9CNDaZ7AnVLarwWcRNfQUwr1Ct2aTQsMhhvjb1xeYPfiKAahMz7i5D85ZUd/main
Model Store ID:  d6002305-15f9-4e76-b5aa-510560f89e75
Model Provider Party ID:  12D3KooWQLbKxRFoa3rcA8do1R2o96yhgxGnJS5RbTr3ZJwQiQay


## Model user flow

### Read image

In [6]:
# Preprocessing: single grayscale channel, resized to 16x16, converted to a torch tensor.
preprocess_image = transforms.Compose(
    [
        transforms.Grayscale(),
        transforms.Resize((16, 16)),
        transforms.ToTensor(),
    ]
)

# Open the scan in a context manager so its file handle is released deterministically;
# the resulting tensor does not depend on the open file.
with Image.open("data/COVID-19_Lung_CT_Scans/COVID-19/COVID-19_0001.png") as raw_image:
    test_image = preprocess_image(raw_image)

In [7]:
# Add a leading batch axis (None-indexing == unsqueeze(0)) and copy into a NumPy array
test_image_batch = np.array(test_image[None, ...])
test_image_batch.shape  # (B, channels, H, W)

(1, 1, 16, 16)

### Send features to Nillion

In [8]:
async def store_images(
    *,
    client: nillion.NillionClient,
    cluster_id: str,
    program_id: str,
    party_id: str,
    user_id: str,
    images: np.ndarray,
) -> Dict[str, str]:
    """Stores a batch of images as secrets in the Nillion network.

    The images are encoded with `na_client.array` as `na.SecretRational` values named
    "my_input" and bound to the program's "Party1" input party.

    Args:
        client (nillion.NillionClient): Nillion client that stores the images.
        cluster_id (str): Nillion cluster ID.
        program_id (str): Program ID of the Nada program the secrets are bound to.
        party_id (str): Party ID of the party providing the images (bound to "Party1").
        user_id (str): User ID of the model user. It is not used for storage; it is only
            echoed back in the result.
        images (np.ndarray): Image batch, e.g. of shape (B, channels, H, W).

    Returns:
        Dict[str, str]: `model_user_user_id` (echo of `user_id`) and `images_store_id`
        (the store ID returned by `client.store_secrets`).
    """
    secrets = nillion.Secrets(
        na_client.array(images, "my_input", nada_type=na.SecretRational)
    )

    secret_bindings = nillion.ProgramBindings(program_id)
    secret_bindings.add_input_party("Party1", party_id)

    # NOTE(review): the 4th argument appears to be the permissions object; None means no
    # explicit permissions are passed — confirm against the py_nillion_client docs.
    images_store_id = await client.store_secrets(
        cluster_id, secret_bindings, secrets, None
    )

    return {
        "model_user_user_id": user_id,
        "images_store_id": images_store_id,
    }

In [9]:
result_store_features = await store_images(
    client=model_user_client,
    cluster_id=cluster_id,
    program_id=program_id,
    party_id=model_user_party_id,
    user_id=model_user_user_id,
    images=test_image_batch,
)

model_user_user_id = result_store_features["model_user_user_id"]
images_store_id = result_store_features["images_store_id"]

print("✅ Images uploaded successfully!")
print("model_user_user_id:", model_user_user_id)
print("images_store_id:", images_store_id)

✅ Images uploaded successfully!
model_user_user_id: unDxCapdG2Dp7w2FwajbBWEF6wrinZp1ArKPvqeMxGt32WbkoXcZGQcSJHwDUMvKr4AG6zQnW4GGDaBcCFqtsu3
images_store_id: 3465b7fe-062e-4528-9fc3-e20bdff86008


### Run inference & check result

In [10]:
async def run_inference(
    *,
    client: nillion.NillionClient,
    cluster_id: str,
    program_id: str,
    model_user_party_id: str,
    model_provider_party_id: str,
    model_store_id: str,
    images_store_id: str,
) -> Dict[str, str | float]:
    """Runs blind inference on the Nillion network by executing the Nada program on the uploaded data.

    Party roles: the model user is "Party1" (it provided the images, which `store_images`
    binds to "Party1", and it receives the output). The model provider is therefore "Party0".

    Args:
        client (nillion.NillionClient): Nillion client that runs inference.
        cluster_id (str): Nillion cluster ID.
        program_id (str): Program ID of Nada program.
        model_user_party_id (str): Party ID of the party that provided the images and receives the result.
        model_provider_party_id (str): Party ID of the party that provided the model params.
        model_store_id (str): Store ID that points to the model params in the Nillion network.
        images_store_id (str): Store ID that points to the images in the Nillion network.

    Returns:
        Dict[str, str | float]: Resulting `compute_id`, `output_0` and `output_1`. The
        outputs are divided by `2 ** na.get_log_scale()` to undo the fixed-point scaling.
    """
    compute_bindings = nillion.ProgramBindings(program_id)
    # "Party1" must refer to the same party as input and output, and must match the
    # "Party1" binding the images were stored with.
    compute_bindings.add_input_party("Party0", model_provider_party_id)
    compute_bindings.add_input_party("Party1", model_user_party_id)
    compute_bindings.add_output_party("Party1", model_user_party_id)

    _ = await client.compute(
        cluster_id,
        compute_bindings,
        [images_store_id, model_store_id],
        nillion.Secrets({}),
        nillion.PublicVariables({}),
    )

    # Wait for the first ComputeFinishedEvent; any other event type is skipped.
    # NOTE(review): this loops forever if no finished event ever arrives.
    while True:
        compute_event = await client.next_compute_event()
        if isinstance(compute_event, nillion.ComputeFinishedEvent):
            inference_result = compute_event.result.value
            break

    # Results come back in fixed-point form scaled by 2**log_scale; convert to floats.
    scale = 2 ** na.get_log_scale()
    return {
        "compute_id": compute_event.uuid,
        "output_0": inference_result["my_output_0_0"] / scale,
        "output_1": inference_result["my_output_0_1"] / scale,
    }

In [11]:
result_inference = await run_inference(
    client=model_user_client,
    cluster_id=cluster_id,
    program_id=program_id,
    model_user_party_id=model_user_party_id,
    model_provider_party_id=model_provider_party_id,
    model_store_id=model_store_id,
    images_store_id=images_store_id,
)
result_inference

{'compute_id': '42b1bcc4-fa08-4260-aa2e-213fb9702f10',
 'output_0': -1.40350341796875,
 'output_1': 0.935302734375}

### Compare result to what we would have gotten in plain-text inference

In [12]:
# Create custom torch Module
class MyNN(torch.nn.Module):
    """Plain-text version of the model, used to compare against the Nillion result.

    Architecture: Conv2d(1 -> 2, kernel 3, stride 4, padding 1) -> ReLU ->
    AvgPool2d(2, stride 2) -> Flatten -> Linear(8 -> 2). For a (B, 1, 16, 16) input the
    conv output is (B, 2, 4, 4) and the pooled output is (B, 2, 2, 2), which flattens to
    the 8 features expected by `fc1`.
    """

    def __init__(self) -> None:
        """Builds the layers. Attribute names/shapes must match the saved state dict keys."""
        super().__init__()
        self.conv1 = torch.nn.Conv2d(
            in_channels=1, out_channels=2, kernel_size=3, stride=4, padding=1
        )
        self.pool = torch.nn.AvgPool2d(kernel_size=2, stride=2)

        self.fc1 = torch.nn.Linear(in_features=8, out_features=2)

        self.relu = torch.nn.ReLU()
        self.flatten = torch.nn.Flatten()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Computes the two raw outputs (logits) for a batch of images.

        Args:
            x (torch.Tensor): Image batch, e.g. of shape (B, 1, 16, 16).

        Returns:
            torch.Tensor: Raw outputs of shape (B, 2); no softmax is applied here.
        """
        x = self.relu(self.conv1(x))
        x = self.pool(x)
        x = self.flatten(x)
        x = self.fc1(x)
        return x

In [13]:
my_model = MyNN()
# The checkpoint is a state dict of tensors, so restrict unpickling to tensors/containers
# (`weights_only=True`); a full unpickle can execute arbitrary code. `map_location="cpu"`
# lets the notebook run even if the checkpoint was saved from another device.
my_model.load_state_dict(
    torch.load("./data/my_model.pt", map_location="cpu", weights_only=True)
)

<All keys matched successfully>

In [14]:
# Plain-text reference prediction. This is a pure forward pass, so skip autograd bookkeeping.
with torch.no_grad():
    plaintext_logits = my_model(test_image.unsqueeze(0))[0]
torch.softmax(plaintext_logits, dim=0)

tensor([0.0880, 0.9120], grad_fn=<SoftmaxBackward0>)

In [15]:
# Softmax over the two outputs returned by the Nillion computation
mpc_logits = torch.Tensor([result_inference[key] for key in ("output_0", "output_1")])
torch.softmax(mpc_logits, dim=0)

tensor([0.0880, 0.9120])