**IMPORTANT**: Before starting this notebook, make sure that the kernel of the previous notebook is shut down, or reset its state so that it forgets the previous `model_user` Nillion client.

In [None]:
## If problems arise with the loading of the shared library, this script can be used to load the shared library before other libraries.
## Remember to also run on your local machine the script below:
# bash replace_lib_version.sh

import ctypes
import importlib.util
import os
import platform

# Location used by the original devcontainer setup; kept as a fallback when the
# installed package cannot be located on sys.path.
_DEFAULT_NILLION_LIB = (
    "/home/vscode/.local/lib/python3.12/site-packages/py_nillion_client/py_nillion_client.abi3.so"
)


def _nillion_extension_path() -> str:
    """Return the path of the installed ``py_nillion_client`` extension module.

    ``importlib.util.find_spec`` only resolves where a top-level package lives; it
    does not import it, so the shared library is still loaded by ``ctypes`` first.
    This avoids hard-coding a user name and Python version into the path. Falls
    back to the devcontainer path if the extension file cannot be found.
    """
    spec = importlib.util.find_spec("py_nillion_client")
    if spec is not None and spec.submodule_search_locations:
        for package_dir in spec.submodule_search_locations:
            candidate = os.path.join(package_dir, "py_nillion_client.abi3.so")
            if os.path.isfile(candidate):
                return candidate
    return _DEFAULT_NILLION_LIB


if platform.system() == "Linux":
    # Force libgomp and py_nillion_client to be loaded before other libraries consuming
    # dynamic TLS (to avoid running out of STATIC_TLS)
    ctypes.cdll.LoadLibrary("libgomp.so.1")
    ctypes.cdll.LoadLibrary(_nillion_extension_path())

In [None]:
import os
import sys

# Make the parent directory importable so that `common.utils` (imported below) resolves.
# Guarded so that re-running this cell does not append the same entry again.
_parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)

# NOTE(review): import order intentionally left as-is — the first cell notes that
# shared-library load order matters on Linux.
import json
from common.utils import compute, store_secret_array
from nillion_python_helpers import (create_nillion_client,
                                    create_payments_config)
from py_nillion_client import NodeKey, UserKey
import py_nillion_client as nillion
from torchvision import transforms
import nada_numpy as na
from PIL import Image
from dotenv import load_dotenv
import numpy as np
import torch

from cosmpy.aerial.client import LedgerClient
from cosmpy.aerial.wallet import LocalWallet
from cosmpy.crypto.keypairs import PrivateKey

# Load the local Nillion devnet settings. `expanduser` resolves the home directory
# even when $HOME is unset. os.getenv("HOME") would return None in that case and
# build a bogus "None/.config/..." path, which load_dotenv does not report as an error.
home = os.path.expanduser("~")
load_dotenv(os.path.join(home, ".config", "nillion", "nillion-devnet.env"))

## Authenticate with Nillion

To connect to the Nillion network, we need to have a user key and a node key. These serve different purposes:

The `user_key` is the user's private key. The user key should never be shared publicly, as it unlocks access and permissions to secrets stored on the network.

The `node_key` is the private key of the node that runs locally and connects to the network.

In [None]:
# The relative paths used later in this notebook (e.g. data/tmp.json) only resolve
# when the kernel's working directory is the example folder, so check that first.
_expected_cwd_suffix = "examples/multi_layer_perceptron"
assert os.getcwd().endswith(_expected_cwd_suffix), (
    "Please run this script from the examples/multi_layer_perceptron directory otherwise, the rest of the tutorial may not work"
)

# Pick up any variables defined in a local .env file (python-dotenv's default lookup).
load_dotenv()

In [None]:
# Network settings are read from the environment populated by the dotenv files loaded above.
cluster_id = os.getenv("NILLION_CLUSTER_ID")
grpc_endpoint = os.getenv("NILLION_NILCHAIN_GRPC")
chain_id = os.getenv("NILLION_NILCHAIN_CHAIN_ID")

# Fail fast with an actionable message rather than passing None into the Nillion /
# nilchain helpers, where the failure would surface later and less clearly.
_missing_env_vars = [
    var_name
    for var_name, var_value in (
        ("NILLION_CLUSTER_ID", cluster_id),
        ("NILLION_NILCHAIN_GRPC", grpc_endpoint),
        ("NILLION_NILCHAIN_CHAIN_ID", chain_id),
    )
    if not var_value
]
if _missing_env_vars:
    raise RuntimeError(
        "Missing Nillion environment variables: "
        + ", ".join(_missing_env_vars)
        + " (is the local devnet configuration loaded?)"
    )

# Model-user identity: user and node keys are both derived from a fixed seed.
seed = "my_seed"
model_user_userkey = UserKey.from_seed(seed)
model_user_nodekey = NodeKey.from_seed(seed)
model_user_client = create_nillion_client(model_user_userkey, model_user_nodekey)
model_user_party_id = model_user_client.party_id
model_user_user_id = model_user_client.user_id

In [None]:
# Ledger client and wallet that are passed to the store/compute calls below.
payments_config = create_payments_config(chain_id, grpc_endpoint)
payments_client = LedgerClient(payments_config)

_private_key_hex = os.getenv("NILLION_NILCHAIN_PRIVATE_KEY_0")
if not _private_key_hex:
    # Without this check, bytes.fromhex(None) would raise an unhelpful TypeError.
    # The key value itself is never included in the message.
    raise RuntimeError(
        "NILLION_NILCHAIN_PRIVATE_KEY_0 is not set; cannot create the payments wallet"
    )
payments_wallet = LocalWallet(
    PrivateKey(bytes.fromhex(_private_key_hex)),
    prefix="nillion",
)

In [None]:
# Identifiers published by the model provider: the compiled program, the stored
# model secrets, and the provider's party ID.
with open("data/tmp.json", "r") as shared_vars_file:
    provider_variables = json.load(shared_vars_file)

program_id, model_store_id, model_provider_party_id = (
    provider_variables[field]
    for field in ("program_id", "model_store_id", "model_provider_party_id")
)

for label, shared_value in (
    ("Program ID: ", program_id),
    ("Model Store ID: ", model_store_id),
    ("Model Provider Party ID: ", model_provider_party_id),
):
    print(label, shared_value)

## Model user flow

### Read image

In [None]:
# Preprocess one CT scan into a single-channel 16x16 tensor. This is the input size
# for which MyNN's flattened conv/pool output matches fc1's 8 input features.
preprocess = transforms.Compose(
    [
        transforms.Grayscale(),        # one channel
        transforms.Resize((16, 16)),
        transforms.ToTensor(),         # PIL image -> (C, H, W) float tensor in [0, 1]
    ]
)

# Use a context manager so the image file handle is closed once the pixels have
# been converted. PIL opens images lazily and would otherwise keep the file open.
with Image.open("data/COVID-19_Lung_CT_Scans/COVID-19/COVID-19_0001.png") as ct_scan:
    test_image = preprocess(ct_scan)

In [None]:
# Prepend a batch axis, (channels, H, W) -> (1, channels, H, W), and copy into NumPy.
test_image_batch = np.array(test_image[None, ...])
test_image_batch.shape  # (B, channels, H, W)

### Send features to Nillion

In [None]:
permissions = nillion.Permissions.default_for_user(model_user_client.user_id)
permissions.add_compute_permissions({model_user_client.user_id: {program_id}})

images_store_id = await store_secret_array(
    model_user_client,
    payments_wallet,
    payments_client,
    cluster_id,
    test_image_batch,
    "my_input",
    na.SecretRational,
    1,
    permissions,
)

### Run inference & check result

In [None]:
# Bind the program's named parties to concrete party IDs: both provider and user
# supply inputs, and only the user receives the output.
compute_bindings = nillion.ProgramBindings(program_id)
for input_party_name, input_party_id in (
    ("Provider", model_provider_party_id),
    ("User", model_user_party_id),
):
    compute_bindings.add_input_party(input_party_name, input_party_id)
compute_bindings.add_output_party("User", model_user_party_id)

In [None]:
result = await compute(
    model_user_client,
    payments_wallet,
    payments_client,
    program_id,
    cluster_id,
    compute_bindings,
    [model_store_id, images_store_id],
    nillion.NadaValues({}),
    verbose=True,
)
result

In [None]:
# Convert every rational output returned by the computation with na.float_from_rational.
result_inference = dict(
    (output_name, na.float_from_rational(raw_output))
    for output_name, raw_output in result.items()
)
result_inference

### Compare result to what we would have gotten in plain-text inference

In [None]:
# Create custom torch Module
class MyNN(torch.nn.Module):
    """Small convolutional classifier used for the plain-text comparison.

    Architecture: conv (1 -> 2 channels) -> ReLU -> 2x2 average pool -> flatten ->
    linear (8 -> 2). The attribute names (``conv1``, ``fc1``) form the
    ``state_dict`` keys, so they must stay in sync with the weights saved in
    ``data/my_model.pt``.
    """

    def __init__(self) -> None:
        """Build the layers: one convolution, one average pool, one linear layer, and a ReLU."""
        super().__init__()
        # (B, 1, 16, 16) -> (B, 2, 4, 4): floor((16 + 2*1 - 3) / 4) + 1 = 4
        self.conv1 = torch.nn.Conv2d(
            in_channels=1, out_channels=2, kernel_size=3, stride=4, padding=1
        )
        # (B, 2, 4, 4) -> (B, 2, 2, 2)
        self.pool = torch.nn.AvgPool2d(kernel_size=2, stride=2)

        # 2 channels * 2 * 2 = 8 flattened features -> 2 outputs
        self.fc1 = torch.nn.Linear(in_features=8, out_features=2)

        self.relu = torch.nn.ReLU()
        self.flatten = torch.nn.Flatten()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the forward pass.

        Args:
            x: image batch of shape ``(B, 1, H, W)``. ``fc1`` expects 8 features,
               which holds for e.g. ``16 x 16`` inputs as prepared in this notebook.

        Returns:
            Unnormalized scores (logits) of shape ``(B, 2)``; softmax is applied
            by the caller.
        """
        x = self.relu(self.conv1(x))
        x = self.pool(x)
        x = self.flatten(x)
        x = self.fc1(x)
        return x

In [None]:
my_model = MyNN()
# weights_only=True restricts unpickling to tensors and plain containers (a state_dict),
# so a tampered checkpoint cannot run arbitrary code when loaded. map_location="cpu"
# lets a checkpoint saved from a GPU load on a CPU-only machine.
my_model.load_state_dict(
    torch.load("./data/my_model.pt", map_location="cpu", weights_only=True)
)

In [None]:
# Plain-text reference: run the same image through the local PyTorch model.
# Inference only, so skip building the autograd graph.
with torch.no_grad():
    plaintext_probabilities = torch.softmax(my_model(test_image.unsqueeze(0))[0], dim=0)
plaintext_probabilities

In [None]:
# Same softmax, applied to the two outputs returned by the Nillion computation.
nillion_logits = torch.Tensor(
    [result_inference[output_key] for output_key in ("output_0", "output_1")]
)
torch.softmax(nillion_logits, dim=0)