**IMPORTANT**: Before starting this notebook, make sure the kernel of the previous notebook has been shut down (or its state reset), so that it forgets the previous `model_user` Nillion client.

In [1]:
## If problems arise when loading the shared library, this cell pre-loads it before any other library.
## Remember to also run the script below on your local machine:
# bash replace_lib_version.sh

import ctypes
import importlib.util
import os
import platform

if platform.system() == "Linux":
    # Force libgomp and py_nillion_client to be loaded before other libraries consuming dynamic TLS (to avoid running out of STATIC_TLS)
    ctypes.cdll.LoadLibrary("libgomp.so.1")
    # Locate the installed package instead of hard-coding one machine's
    # site-packages path. `find_spec` on a top-level name only locates the
    # package; it does not import it, so nothing else is loaded yet.
    _nillion_spec = importlib.util.find_spec("py_nillion_client")
    if _nillion_spec is None or not _nillion_spec.submodule_search_locations:
        raise ModuleNotFoundError(
            "py_nillion_client is not installed in the active environment"
        )
    ctypes.cdll.LoadLibrary(
        os.path.join(
            _nillion_spec.submodule_search_locations[0], "py_nillion_client.abi3.so"
        )
    )

In [2]:
from typing import Dict, List

import json
import os
import joblib


from dotenv import load_dotenv
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression

import nada_numpy as na
import nada_numpy.client as na_client
import py_nillion_client as nillion
from nillion_python_helpers import (
    create_nillion_client,
    getUserKeyFromFile,
    getNodeKeyFromFile,
)

from config import NUM_FEATS

## Authenticate with Nillion

To connect to the Nillion network, we need to have a user key and a node key. These serve different purposes:

The `user_key` is the user's private key. The user key should never be shared publicly, as it unlocks access and permissions to secrets stored on the network.

The `node_key` is the private key of the node that runs locally and connects to the network.

In [3]:
# Load all Nillion network environment variables
# Compare the last two normalized path components, so the check works with any
# OS path separator (not only "/") and rejects look-alikes such as
# ".../myexamples/spam_detection".
assert os.path.normpath(os.getcwd()).split(os.sep)[-2:] == [
    "examples",
    "spam_detection",
], "Please run this script from the examples/spam_detection directory otherwise, the rest of the tutorial may not work"
load_dotenv()

True

In [4]:
# Read the model user's (party 2) credentials. `os.environ[...]` fails fast with a
# KeyError that names the missing variable, instead of passing None on to the
# Nillion helpers and the later network calls.
cluster_id = os.environ["NILLION_CLUSTER_ID"]
model_user_userkey = getUserKeyFromFile(os.environ["NILLION_USERKEY_PATH_PARTY_2"])
model_user_nodekey = getNodeKeyFromFile(os.environ["NILLION_NODEKEY_PATH_PARTY_2"])
model_user_client = create_nillion_client(model_user_userkey, model_user_nodekey)
model_user_party_id = model_user_client.party_id
model_user_user_id = model_user_client.user_id

In [5]:
# Identifiers shared by the model provider, saved to target/tmp.json.
with open("target/tmp.json", "r") as provider_variables_file:
    provider_variables = json.load(provider_variables_file)

# Pull the three identifiers out in a fixed order (a missing key raises KeyError).
program_id, model_store_id, model_provider_party_id = (
    provider_variables[key]
    for key in ("program_id", "model_store_id", "model_provider_party_id")
)

print("Program ID: ", program_id)
print("Model Store ID: ", model_store_id)
print("Model Provider Party ID: ", model_provider_party_id)

Program ID:  33sxVBj3jenx74bGq5eiX3HzwBJS85aGTjutfnfwPwVyJEPhhWr2h1CcYeryqUvvNXKr4ipGQjNFBVbHUDCWXjWE/spam_detection
Model Store ID:  bf1d62e2-beff-41bd-9c5a-f1acf9e6779d
Model Provider Party ID:  12D3KooWFYjK13Ny2W4hEfcZtD5DUvCGP6CJ4H2YnnUCHpZBDKpj


## Model user flow

### Convert text to features

In [6]:
# joblib.load unpickles the file, so only load vectorizer artifacts from a trusted source.
vectorizer: TfidfVectorizer = joblib.load("model/vectorizer.joblib")

In [7]:
# Let's find out whether it's a billion dollar opportunity or pyramid scheme
INPUT_DATA = "Free entry in 2 a wkly comp to win exclusive prizes! Text WIN to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's"

# `transform` works on a batch of documents; unpack the single row of TF-IDF
# scores as a list of plain Python floats.
(features,) = (
    vectorizer.transform([INPUT_DATA])
    .toarray()
    .tolist()
)

In [8]:
# Turn the feature list into a float ndarray, matching the `np.ndarray` annotation of `store_features` below.
features = np.array(features, dtype=float)

### Send features to Nillion

In [9]:
async def store_features(
    *,
    client: nillion.NillionClient,
    cluster_id: str,
    program_id: str,
    party_id: str,
    user_id: str,
    features: np.ndarray
) -> Dict[str, str]:
    """Stores text features as secrets in the Nillion network.

    The features are encoded as a Nada `SecretRational` array named
    "my_input" and bound to the "User" input party of the given program.

    Args:
        client (nillion.NillionClient): Nillion client that stores features.
        cluster_id (str): Nillion cluster ID.
        program_id (str): Program ID of Nada program.
        party_id (str): Party ID of party that will store text features.
        user_id (str): User ID of the storing user. It is only echoed back in
            the result; no permissions are granted with it here.
        features (np.ndarray): Text features to store (in this notebook, the
            1-D TF-IDF vector of the input message).

    Returns:
        Dict[str, str]: Resulting `model_user_user_id` and `features_store_id`.
    """
    secrets = nillion.Secrets(na_client.array(features, "my_input", na.SecretRational))

    secret_bindings = nillion.ProgramBindings(program_id)
    secret_bindings.add_input_party("User", party_id)

    # NOTE(review): permissions are passed as None, presumably leaving the
    # network's default permissions for the storing user in place — confirm.
    features_store_id = await client.store_secrets(
        cluster_id, secret_bindings, secrets, None
    )

    return {
        "model_user_user_id": user_id,
        "features_store_id": features_store_id,
    }

In [10]:
result_store_features = await store_features(
    client=model_user_client,
    cluster_id=cluster_id,
    program_id=program_id,
    party_id=model_user_party_id,
    user_id=model_user_user_id,
    features=features,
)

model_user_user_id = result_store_features["model_user_user_id"]
features_store_id = result_store_features["features_store_id"]

print("✅ Text features uploaded successfully!")
print("model_user_user_id:", model_user_user_id)
print("features_store_id:", features_store_id)

<builtins.Secrets object at 0x30ad72dd0>
✅ Text features uploaded successfully!
model_user_user_id: 346jE91YmBjhSdvVSKCSpBxAY1kj5brpnCtWutLWMvXj8sF2iBR7fnmm4fjAGj7uPmHtb7CBQRjV5Q4H1KzepDCL
features_store_id: 10c7d969-f60f-43fe-8795-2335bc9be522


### Run inference & check result

In [11]:
async def run_inference(
    *,
    client: nillion.NillionClient,
    cluster_id: str,
    program_id: str,
    model_user_party_id: str,
    model_provider_party_id: str,
    model_store_id: str,
    features_store_id: str,
) -> Dict[str, str | float]:
    """Runs blind inference on the Nillion network by executing the Nada program on the uploaded data.

    The model user ("User") supplies the stored text features, the model
    provider ("Provider") supplies the stored model params, and only the
    model user is bound as an output party.

    Args:
        client (nillion.NillionClient): Nillion client that runs inference.
        cluster_id (str): Nillion cluster ID.
        program_id (str): Program ID of Nada program.
        model_user_party_id (str): Party ID of the party that provided the
            text features and receives the output.
        model_provider_party_id (str): Party ID of the party that provided
            the model params.
        model_store_id (str): Store ID that points to the model params in the Nillion network.
        features_store_id (str): Store ID that points to the text features in the Nillion network.

    Returns:
        Dict[str, str | float]: Resulting `compute_id`, the de-quantized
            `logit`, and `output_probability` (sigmoid of the logit).
    """
    compute_bindings = nillion.ProgramBindings(program_id)
    compute_bindings.add_input_party("User", model_user_party_id)
    compute_bindings.add_input_party("Provider", model_provider_party_id)
    compute_bindings.add_output_party("User", model_user_party_id)

    # Both inputs were stored beforehand and are referenced by store ID, so no
    # extra secrets or public variables are sent with the compute request.
    await client.compute(
        cluster_id,
        compute_bindings,
        [features_store_id, model_store_id],
        nillion.Secrets({}),
        nillion.PublicVariables({}),
    )

    # Wait for the first finished-compute event; other event types are skipped.
    while True:
        compute_event = await client.next_compute_event()
        if isinstance(compute_event, nillion.ComputeFinishedEvent):
            inference_result = compute_event.result.value
            break

    # The program returns the logit scaled by 2**log_scale (presumably Nada's
    # fixed-point encoding of rationals); divide to recover the real value.
    quantized_logit = inference_result["logit_0"]
    logit = quantized_logit / (2 ** na.get_log_scale())
    # Logistic sigmoid maps the logit to a probability in (0, 1).
    output_probability = 1 / (1 + np.exp(-logit))
    return {
        "compute_id": compute_event.uuid,
        "logit": logit,
        "output_probability": output_probability,
    }

In [12]:
result_inference = await run_inference(
    client=model_user_client,
    cluster_id=cluster_id,
    program_id=program_id,
    model_user_party_id=model_user_party_id,
    model_provider_party_id=model_provider_party_id,
    model_store_id=model_store_id,
    features_store_id=features_store_id,
)

compute_id = result_inference["compute_id"]
logit = result_inference["logit"]
output_probability = result_inference["output_probability"]

print("✅ Inference ran successfully!")
print("compute_id:", compute_id)
print("logit:", logit)

print("Probability of spam in Nillion: {:.6f}%".format(output_probability * 100))

✅ Inference ran successfully!
compute_id: dc230bb3-1a27-41c0-9236-3ee0e11f5c1a
logit: 2.4093170166015625
Probability of spam in Nillion: 91.753502%


### Compare result to what we would have gotten in plain-text inference

In [13]:
# Plain-text reference: load the classifier (and the same vectorizer again) with
# joblib, which unpickles the files — only load trusted artifacts.
classifier: LogisticRegression = joblib.load("model/classifier.joblib")
vectorizer: TfidfVectorizer = joblib.load("model/vectorizer.joblib")

In [14]:
# Rebind `features` (previously the ndarray uploaded to Nillion) to a 2-D list with
# one row per document, because the scikit-learn calls below take a batch of samples.
features = vectorizer.transform([INPUT_DATA]).toarray().tolist()

In [15]:
# `decision_function` returns one raw score (logit) per sample; unpack the single sample's score.
[logit_plain_text] = classifier.decision_function(features)

In [16]:
# Compare with the de-quantized `logit` printed by the Nillion inference cell above.
print("Logit in plain text: {}".format(logit_plain_text))

Logit in plain text: 2.4080795630742755


In [17]:
# `predict_proba` returns one row of class probabilities per sample; unpack the single row.
[result] = classifier.predict_proba(features)
# Column 1 is the probability of `classifier.classes_[1]` — presumably the spam label; confirm against the training labels.
output_probability_plain_text = result[1]

In [18]:
# Show the plain-text and Nillion probabilities on consecutive lines for comparison.
print(
    "Probability of spam in plain text: {:.6f}%".format(
        output_probability_plain_text * 100
    ),
    "Probability of spam in Nillion: {:.6f}%".format(output_probability * 100),
    sep="\n",
)

Probability of spam in plain text: 91.744134%
Probability of spam in Nillion: 91.753502%
