 # Tutorial 5: Neural Network Inference
Our library supports neural network inference based on secret sharing. This tutorial walks through neural network inference under secure two-party computation. As in Tutorial_2, we simulate the parties with multiple threads, and we simulate the trusted third party that provides auxiliary parameters with local files. The model is shared before prediction, and the data is shared during prediction. For examples of actual usage, see `./debug/application/neural_network/2pc/neural_network_server.py` and `./debug/application/neural_network/2pc/neural_network_client.py`.
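
As background, the idea of additive secret sharing can be sketched in a few lines of plain Python. This is only an illustration, not the library's implementation: the 64-bit ring and the helper names `share` and `reconstruct` are assumptions made for this sketch.

```python
import secrets

RING_BITS = 64              # assumed ring size, chosen only for this sketch
MODULUS = 1 << RING_BITS


def share(value, modulus=MODULUS):
    """Split an integer into two additive shares modulo `modulus`."""
    share_0 = secrets.randbelow(modulus)    # uniformly random mask
    share_1 = (value - share_0) % modulus   # complement so the shares sum to value
    return share_0, share_1


def reconstruct(share_0, share_1, modulus=MODULUS):
    """Recombine two additive shares into the original value."""
    return (share_0 + share_1) % modulus


x0, x1 = share(42)
y0, y1 = share(100)

# Addition needs no communication: each party adds the shares it holds.
z0 = (x0 + y0) % MODULUS
z1 = (x1 + y1) % MODULUS
total = reconstruct(z0, z1)
```

Each share on its own is a uniformly random ring element, so neither party learns the secret from the share it holds. Multiplication, unlike addition, requires interaction between the parties.
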
 In this tutorial, we use AlexNet as an example. First, train the model using `data/neural_network/AlexNet/Alexnet_MNIST_train.py`. 

In [1]:
import torch

# Train AlexNet on MNIST by running the training script in this session
with open('../data/neural_network/AlexNet/Alexnet_MNIST_train.py') as f:
    exec(f.read())

Next, import the following packages:

In [2]:
from data.neural_network.AlexNet.Alexnet import AlexNet
from application.neural_network.model.model_converter import load_secure_model_from_file
from config.base_configs import *
from application.neural_network.party.neural_network_party import NeuralNetworkCS

The server acts as the model provider and the client as the data provider. The triples for matrix multiplication must be generated in advance and distributed to both parties. As in Tutorial_2, we simulate this process on the server side.
The model provider also needs `share_and_save_model` to share the model, and the data provider needs `share_data` to share the data:

In [3]:
from application.neural_network.model.model_converter import share_and_save_model
from application.neural_network.model.model_converter import share_data

Now, we can create our two parties.

In [4]:
import threading
from config.network_configs import *


server_server_address = (SERVER_IP, SERVER_SERVER_PORT)
server_client_address = (SERVER_IP, SERVER_CLIENT_PORT)

client_server_address = (CLIENT_IP, CLIENT_SERVER_PORT)
client_client_address = (CLIENT_IP, CLIENT_CLIENT_PORT)

# set Server
server = NeuralNetworkCS(type='server')

def set_server():
    # CS connect
    server.connect(server_server_address, server_client_address, client_server_address, client_client_address)

# set Client
client = NeuralNetworkCS(type='client')

def set_client():
    # CS connect
    client.connect(client_server_address, client_client_address, server_server_address, server_client_address)

server_thread = threading.Thread(target=set_server)
client_thread = threading.Thread(target=set_client)

server_thread.start()
client_thread.start()
client_thread.join()
server_thread.join()

TCPServer waiting for connection ......
successfully connect to server 127.0.0.1:30000
TCPServer waiting for connection ......
successfully connect to server 127.0.0.1:20000
TCPServer successfully connected by :('127.0.0.1', 20001)
TCPServer successfully connected by :('127.0.0.1', 30001)
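
The start/join pattern used above can be illustrated without any networking. In the minimal sketch below, two threads stand in for the parties and exchange one message through in-memory queues. The function names and the queue-based channel are assumptions made for illustration; the library itself connects over TCP, as the log shows.

```python
import queue
import threading


def party(name, inbox, outbox, results):
    """Send a greeting to the peer, then block until the peer's greeting arrives."""
    outbox.put(f"hello from {name}")
    results[name] = inbox.get(timeout=5)


def run_two_parties():
    to_server, to_client = queue.Queue(), queue.Queue()
    results = {}
    server = threading.Thread(
        target=party, args=("server", to_server, to_client, results))
    client = threading.Thread(
        target=party, args=("client", to_client, to_server, results))
    # Start both threads before joining either one: each party waits for its peer.
    server.start()
    client.start()
    client.join()
    server.join()
    return results


handshake = run_two_parties()
```

Joining the first thread before starting the second would make the first party wait for a peer that never starts, which is why both `start()` calls come first.
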


The model provider supplies the model and shares it. We can choose whether or not to save the shared parameters locally; here they are saved. For an example of the other option, see the C/S example in `debug/application/neural_network/2pc`.

In [5]:
net = AlexNet()
net.load_state_dict(torch.load('./data/neural_network/AlexNet/AlexNet_MNIST.pkl'))

share_and_save_model(model=net, save_path=model_file_path)  # share model
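
Model parameters are floating-point numbers, while secret sharing works on integers in a ring. One common approach, shown here only as a hedged sketch, is to encode each value in fixed point before splitting it into shares. The 16-bit scale, the 64-bit ring, and the helper names are assumptions made for this sketch and are not taken from `share_and_save_model`.

```python
import secrets

MODULUS = 1 << 64   # assumed ring size
SCALE = 1 << 16     # assumed fixed-point precision: 16 fractional bits


def encode(x):
    """Map a float to a ring element; negatives wrap around the modulus."""
    return round(x * SCALE) % MODULUS


def decode(v):
    """Map a ring element back to a float, reading the upper half as negative."""
    if v >= MODULUS // 2:
        v -= MODULUS
    return v / SCALE


def share_weights(weights):
    """Return two lists of additive shares, one list per party."""
    shares_0, shares_1 = [], []
    for w in weights:
        mask = secrets.randbelow(MODULUS)
        shares_0.append(mask)
        shares_1.append((encode(w) - mask) % MODULUS)
    return shares_0, shares_1


weights = [0.5, -1.25, 3.0]   # hypothetical layer parameters
s0, s1 = share_weights(weights)
recovered = [decode((a + b) % MODULUS) for a, b in zip(s0, s1)]
```

The values in this example are exactly representable with 16 fractional bits, so they round-trip without error. In general, fixed-point encoding introduces a small rounding error.
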

The data provider supplies the input data. Here we use an image from the MNIST dataset as an example.

In [6]:
data = "./data/img/image.png"

Because neural network inference involves matrix multiplications, we first simulate one prediction to generate the required matrix Beaver triples in advance.

In [7]:
def server_dummy_model():
    server.dummy_model(net)

def client_dummy_model():
    client.dummy_model(data)

server_thread = threading.Thread(target=server_dummy_model)
client_thread = threading.Thread(target=client_dummy_model)

server_thread.start()
client_thread.start()
client_thread.join()
server_thread.join()
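
To show what a matrix Beaver triple is used for, here is a minimal pure-Python sketch of the textbook protocol for multiplying two secret-shared matrices. It is not the library's implementation: the 32-bit ring, the list-of-lists matrices, and all helper names are assumptions made for illustration.

```python
import secrets

MODULUS = 1 << 32   # assumed ring size for this sketch


def rand_matrix(rows, cols):
    return [[secrets.randbelow(MODULUS) for _ in range(cols)] for _ in range(rows)]


def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) % MODULUS
             for j in range(len(b[0]))] for i in range(len(a))]


def add(a, b):
    return [[(x + y) % MODULUS for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def sub(a, b):
    return [[(x - y) % MODULUS for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def share(m):
    """Split a matrix into two additive shares."""
    mask = rand_matrix(len(m), len(m[0]))
    return mask, sub(m, mask)


def make_triple(n, k, m):
    """Trusted dealer: random A (n x k), B (k x m), and C = A @ B, all shared."""
    a, b = rand_matrix(n, k), rand_matrix(k, m)
    return share(a), share(b), share(matmul(a, b))


def beaver_matmul(x_sh, y_sh, triple):
    (a0, a1), (b0, b1), (c0, c1) = triple
    # Both parties open E = X - A and F = Y - B; the random A and B hide X and Y.
    e = add(sub(x_sh[0], a0), sub(x_sh[1], a1))
    f = add(sub(y_sh[0], b0), sub(y_sh[1], b1))
    # X @ Y = C + E @ B + A @ F + E @ F; party 0 adds the public E @ F term.
    z0 = add(add(c0, matmul(e, b0)), add(matmul(a0, f), matmul(e, f)))
    z1 = add(c1, add(matmul(e, b1), matmul(a1, f)))
    return z0, z1


x = [[1, 2, 3], [4, 5, 6]]         # 2 x 3
y = [[7, 8], [9, 10], [11, 12]]    # 3 x 2
z0, z1 = beaver_matmul(share(x), share(y), make_triple(2, 3, 2))
product = add(z0, z1)
```

The triple must have the same shapes as the operands it masks. That is consistent with simulating one prediction first, since a dry run reveals the shape of every matrix multiplication in the network.
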

The steps above complete the preparation. Before inference starts, the data provider shares its data. The two parties then load their respective model shares and run inference.

In [8]:
def server_predict():
    data_share = server.receive() # receive data share
    net = load_secure_model_from_file(net=AlexNet(), path=model_file_path, party=server)
    server.inference(net, data_share)

    # close party after inference
    server.close()


def client_predict():
    data_shares = share_data(data) # share data
    data_share = data_shares[1]
    client.send(data_shares[0]) # send shares to other party
    net = load_secure_model_from_file(net=AlexNet(), path=model_file_path, party=client)
    res = client.inference(net, data_share)

    _, predicted = torch.max(res, 1)
    # predicted_result
    print('predicted result: ', predicted)

    # close party after inference
    client.close()

server_thread = threading.Thread(target=server_predict)
client_thread = threading.Thread(target=client_predict)

server_thread.start()
client_thread.start()
client_thread.join()
server_thread.join()

predicted result:  tensor([0], device='cuda:0')
Communication costs:
	send rounds: 88		send bytes: 3.078125 KB.
	recv rounds: 112		recv bytes: 1.1753768920898438 MB.
Communication costs:
	send rounds: 112		send bytes: 1.1753768920898438 MB.
	recv rounds: 88		recv bytes: 3.078125 KB.


The prediction result is shown above. The core statements our library uses for neural network prediction are `server.inference` and `client.inference`. If you wish to perform additional operations on the prediction results, you can process them according to your specific requirements.
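
As one example of such post-processing, the sketch below turns a vector of class scores into probabilities and reports the top candidates. It uses plain Python lists so that it runs on its own. The scores are made up, and the helpers `softmax` and `top_k` are illustrative names, not part of the library.

```python
import math


def softmax(logits):
    """Convert raw scores to probabilities, subtracting the max for stability."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, k=3):
    """Return the k most likely (class index, probability) pairs."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in ranked[:k]]


# Made-up scores for the ten MNIST digit classes
logits = [9.1, -2.0, 0.3, 1.2, -0.5, 0.0, 2.4, -1.1, 0.7, 0.2]
best = top_k(logits, k=3)
```

Assuming `res` is a tensor of shape `(1, 10)` as in the cell above, `res.squeeze(0).tolist()` would produce a list suitable for these helpers.
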
In [data/neural_network/AlexNet/](https://github.com/XidianNSS/NssMPClib/tree/main/data/neural_network/AlexNet) and [data/neural_network/ResNet/](https://github.com/XidianNSS/NssMPClib/tree/main/data/neural_network/ResNet), we provide the training code for AlexNet and ResNet50. You can use them to train models according to your specific requirements and perform inference using trained models.