# Multi Party FHE
Expected RAM usage: 7.3 GB
Expected runtime: 2 minutes

## Introduction

This example demonstrates the use of FHE in a multi-party setting. In this setting, a set of parties wish to compute some function over their secret data without revealing that data to any of the other parties. A regular public-key setting is not secure here, since it requires the parties to trust the holder of the secret key (whether it is one of the parties or a "trusted" third party). In the multi-party FHE setting, no single party holds the secret key. Instead, each party has its own secret key (such a party is therefore also called a "key-owner" below). The public keys (which include the encryption key and the evaluation keys) are generated jointly in an initialization protocol (a.k.a. InitProtocol) run between the parties. To decrypt a ciphertext, every key-owner must give its consent by taking part in a decryption protocol (a.k.a. DecryptProtocol).
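
To make this n-out-of-n structure concrete, here is a deliberately toy sketch. It is not the lattice-based scheme that pyhelayers/OpenFHE uses: it has no noise, works on scalars modulo a prime, and is trivially breakable. It only shows the shape of the protocols: each key-owner keeps a private share s_i, the public key is assembled from values derived from the shares, anyone can encrypt, and decryption needs a contribution from every share. All function names are hypothetical.

```python
import secrets

# Toy, noise-free, INSECURE scheme over integers modulo a prime.
# It only illustrates the structure of multi-party key generation and decryption.
Q = 2**61 - 1  # Mersenne prime used as the toy modulus


def keygen_share():
    """Each key-owner samples its own secret share locally and never sends it."""
    return secrets.randbelow(Q)


def joint_public_key(a, shares):
    """Combine the published values b_i = -a * s_i into one public key (a, b).

    The implicit joint secret is s = sum(s_i); no party ever computes it.
    """
    b = sum((-a * s_i) % Q for s_i in shares) % Q
    return a, b


def encrypt(pk, m):
    """Anyone holding the public key can encrypt an integer 0 <= m < Q."""
    a, b = pk
    r = secrets.randbelow(Q)
    return (b * r + m) % Q, (a * r) % Q


def partial_decrypt(ct, s_i):
    """A key-owner's contribution to decryption (its 'consent')."""
    return (ct[1] * s_i) % Q


def combine(ct, partials):
    """c0 + c1 * s = m, where c1 * s is the sum of all partial decryptions."""
    return (ct[0] + sum(partials)) % Q


# Three key-owners; the common random value a would be agreed on during init.
shares = [keygen_share() for _ in range(3)]
pk = joint_public_key(secrets.randbelow(Q), shares)
ct = encrypt(pk, 42)
print(combine(ct, [partial_decrypt(ct, s) for s in shares]))  # 42
```

In this toy, dropping any one partial decryption leaves a value that is, with overwhelming probability, unrelated to the plaintext.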

In the following example we consider the case of two data owners, Alice and Bob, and a server. The security model requires that the server does not collude with any of the data owners. The example demonstrates how an encrypted linear regression model can be trained on the encrypted data of multiple data owners. After training is complete, the model is decrypted and shared between the data owners.

The input of each data owner (Alice and Bob) is a dataset of 100,000 samples with 1 feature, containing fabricated data.
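
For reference, the plaintext relationship that the encrypted training should recover can be checked without any encryption. The sketch below is a pure-Python stand-in for the data generation done later in generate_encrypt_and_save_inputs: it uses the standard random module instead of NumPy, and the seed is arbitrary. It draws x from a normal distribution with mean 10 and standard deviation 1.1, sets y = 50 - 0.4 x + error with error drawn from N(0, 1) (the parameters set in Step 1.2), and fits a least-squares line to both parties' data pooled together.

```python
import random

# Parameters mirroring Step 1.2 of this notebook.
MEAN_X, STD_X, STD_X_NOISE = 10.0, 1.0, 0.1
PHI0, PHI1 = 50.0, -0.4


def fabricate(n, rng):
    """Return (xs, ys) with y = PHI0 + PHI1 * x + error, error ~ N(0, 1)."""
    xs = [rng.gauss(MEAN_X, STD_X + STD_X_NOISE) for _ in range(n)]
    ys = [PHI0 + PHI1 * x + rng.gauss(0.0, 1.0) for x in xs]
    return xs, ys


def ols_fit(xs, ys):
    """Closed-form least squares for one feature; returns (intercept, slope)."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    sxx = sum((x - mean_x) ** 2 for x in xs)
    slope = sxy / sxx
    return mean_y - slope * mean_x, slope


rng = random.Random(0)  # arbitrary seed for reproducibility
alice_x, alice_y = fabricate(100_000, rng)
bob_x, bob_y = fabricate(100_000, rng)
intercept, slope = ols_fit(alice_x + bob_x, alice_y + bob_y)
print(f"intercept = {intercept:.3f}, slope = {slope:.4f}")
```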

## Step 1. Framework Setup

### 1.1. Start with some imports

We import general pyhelayers classes (e.g., HeConfigRequirement; see details in the basic example notebooks) and multi-party specific classes:
* MultiPartyConfig - Contains the configurations of the multi-party setting. It can be set directly from the python code or loaded from a json file (as we do in this example).
* InitProtocol - Each protocol in the multi-party setting is represented as a protocol object. Each participant initializes its own protocol object and uses it to run the protocol and to get information about its status. The InitProtocol object is used to run the initialization protocol, in which the participants generate the public keys.
* DecryptProtocol - The DecryptProtocol object is used to run the decryption protocol, in which the participants decrypt a CTile / CTileTensor / EncryptedData object.
* ProtocolMessage - During the run of the protocols, the protocol objects generate messages to send to other participants. These messages are represented as ProtocolMessage objects. These objects can be serialized and saved to memory or to a file. In addition, the ProtocolMessage object provides information about the message (e.g., whom it is destined for, in which round it should be received, etc.).
* ProtocolMessageVector - When running a protocol, each participant collects at each round all the messages that are relevant to it. To feed these messages to the protocol object, the participants wrap their messages in a ProtocolMessageVector object.

We also specify the path of two directories:
* MULTI_PARTY_PATH - This directory contains the pre-prepared multi-party configuration files of Alice, Bob and the server.
* COMMUNICATION_DIR_PATH - This is a temporary directory that will be used for saving and loading protocol messages and encrypted data.

In [None]:
import numpy as np
import os
import json
import shutil
import threading
from pyhelayers import OpenFheCkksContext, MockupContext, HeConfigRequirement, RotationSetType
from pyhelayers import PlainModelHyperParams, PlainModel, HeModel, HeRunRequirements, EncryptedData, LogisticRegressionPlain, LRActivation, LRDistribution, ModelIoEncoder
from pyhelayers import MultiPartyConfig, InitProtocol, DecryptProtocol, ProtocolMessage, ProtocolMessageVector
from utils import get_data_sets_dir, get_temp_output_data_dir, create_clean_dir

MULTI_PARTY_PATH = os.path.join(get_data_sets_dir(), 'multi_party_fhe')
COMMUNICATION_DIR_PATH = os.path.join(get_temp_output_data_dir(), 'multi_party_fhe')
create_clean_dir(get_temp_output_data_dir())

### 1.2. Set shared parameters

Here we set the shared model hyper-parameters to be used by the parties. See the other AI-related notebooks for information about the different parameters.

In [None]:
# Whether to use mockup HE context in the example or to use a secure HE context.
use_mockup = False

req = HeConfigRequirement.insecure(
    num_slots = 2 ** 15,
    multiplication_depth = 18,
    fractional_part_precision = 52,
    integer_part_precision = 8) if use_mockup else HeConfigRequirement(
        num_slots = 2 ** 15,
        multiplication_depth = 18,
        fractional_part_precision = 52,
        integer_part_precision = 8)
req.public_functions.rotate = RotationSetType.CUSTOM_ROTATIONS
req.public_functions.set_rotation_steps([1, 16, 256, 4096])

mean_x = 10.0
std_x = 1.0
std_x_noise_for_fabricated_data = 0.1
phi0 = 50.0
phi1 = -0.4
num_samples_each_party = 100000

hyper_params = PlainModelHyperParams()
hyper_params.number_of_features = 1
hyper_params.logistic_regression_activation = LRActivation.NONE
hyper_params.linear_regression_distribution_x = LRDistribution.LR_NORMAL_DISTRIBUTION
hyper_params.linear_regression_mean_x = mean_x
hyper_params.linear_regression_std_x = std_x
hyper_params.inverse_approximation_precision = 2
hyper_params.trainable = True

# This should equal the number of data owners. In our case there are two data owners (Alice and Bob).
hyper_params.fit_hyper_params.number_of_iterations = 2

# This should equal the maximal number of samples for each party. In our case there are at most num_samples_each_party samples for each party.
hyper_params.fit_hyper_params.fit_batch_size = num_samples_each_party
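
The two fit settings above tie the training schedule to the data layout: one batch per data owner, and one iteration per batch. For ordinary least squares, splitting the data into per-party batches loses nothing, because the closed-form solution depends only on running sums. Whether HeModel.fit exploits this internally is not described in this notebook, so treat the sketch below as an illustration of the math, not of the pyhelayers implementation. The Stats class is hypothetical plain Python; it accumulates the sums one batch at a time and then solves for the intercept and slope.

```python
from dataclasses import dataclass


@dataclass
class Stats:
    """Hypothetical running sums for one-feature least squares."""
    n: int = 0
    sx: float = 0.0
    sy: float = 0.0
    sxx: float = 0.0
    sxy: float = 0.0

    def update(self, xs, ys):
        """Fold one batch (e.g., one data owner's samples) into the sums."""
        for x, y in zip(xs, ys):
            self.n += 1
            self.sx += x
            self.sy += y
            self.sxx += x * x
            self.sxy += x * y
        return self

    def solve(self):
        """Return (intercept, slope) from the accumulated sums."""
        slope = (self.n * self.sxy - self.sx * self.sy) / (self.n * self.sxx - self.sx ** 2)
        intercept = (self.sy - slope * self.sx) / self.n
        return intercept, slope


# Two "parties" whose points lie exactly on y = 50 - 0.4 * x.
alice_x, alice_y = [1.0, 2.0, 3.0], [49.6, 49.2, 48.8]
bob_x, bob_y = [4.0, 5.0], [48.4, 48.0]

per_batch = Stats().update(alice_x, alice_y).update(bob_x, bob_y).solve()
pooled = Stats().update(alice_x + bob_x, alice_y + bob_y).solve()
print(per_batch, pooled)
```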

### 1.3. Define helper functions

We define four helper functions to be used by the participants.

* setup_participant gets the HeContext object, the name of the participant (a string) and the base HeConfigRequirement object. It loads into the HeConfigRequirement object the multi-party config of the current participant and initializes the HeContext object.
* setup_he_model creates an empty linear regression model for the parties to train.
* generate_encrypt_and_save_inputs generates fabricated input for Alice and Bob to encrypt and send to the server.
* read_messages_and_execute_round gets the HeContext object, the Protocol object (either InitProtocol or DecryptProtocol in this example) and a dictionary for collecting exceptions. It traverses all files in the communication directory, skips files whose names show they are irrelevant to the participant, loads each remaining file into a ProtocolMessage object, and checks, using the is_input_message_valid_for_current_round method, whether the message is relevant to the participant. All relevant messages are collected into a ProtocolMessageVector object, which is fed into the protocol object using the execute_next_round method. This method also generates the output messages of the round and returns True if the round was executed successfully. Finally, the helper saves all output messages to the communication directory, using the get_metadata_as_string method to name each file (this method outputs the metadata of the ProtocolMessage as a human-readable string). Any exception is stored in the dictionary under the participant's ID, because an exception raised inside a thread would not reach the main thread.
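
The file-name pre-filter inside read_messages_and_execute_round can be isolated as a pure function, which makes it easy to test without an HeContext. The sketch below mirrors those substring checks; the example file name is hypothetical, since the exact output of get_metadata_as_string is not shown in this notebook.

```python
def is_candidate_file(filename, current_round, participant_id, is_aggregator, is_key_owner):
    """Return True if a message file may be relevant to this participant.

    Mirrors the substring checks in read_messages_and_execute_round; loaded
    messages must still pass is_input_message_valid_for_current_round.
    """
    if f"round_{current_round}" not in filename:
        return False  # message belongs to another round
    if f"source_id_{participant_id}" in filename:
        return False  # our own outgoing message
    if "dest_role_AGGREGATOR" in filename and not is_aggregator:
        return False  # addressed to the aggregator only
    if "dest_role_KEY-OWNER" in filename and not is_key_owner:
        return False  # addressed to key-owners only
    return True


# Hypothetical file name; the real format comes from get_metadata_as_string(True).
name = "round_1_source_id_2_dest_role_AGGREGATOR_0"
print(is_candidate_file(name, 1, 3, is_aggregator=True, is_key_owner=True))  # True
```

Because these are substring checks, `round_1` also matches `round_10`; such a message passes the pre-filter but is then rejected by is_input_message_valid_for_current_round after loading.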

In [None]:
def setup_participant(he, name, req):
    # Read the participant's multi-party configuration file (closed automatically)
    with open(os.path.join(MULTI_PARTY_PATH, 'multi_party_config_' + name + '.json')) as f:
        data = json.load(f)
    mp_config = MultiPartyConfig()
    mp_config.participant_id = data['participant_id']
    mp_config.set_key_owners_ids(data['key_owners_ids'])
    mp_config.initiator_id = data['initiator_id']
    mp_config.aggregator_id = data['aggregator_id']
    req.set_multi_party_config(mp_config)

    # Initialize context
    he.init(req)

def setup_he_model(he, hyper_params):
    # Unlike all other demos, in this demo we had to initialize an HeContext before initializing the HeModel.
    # This is because the HeContext must be generated using a multi-party initialization protocol. This requires us to use a lower-level API than the simpler API shown in most other demos.
    lrp = PlainModel.create(hyper_params)
    he_run_req = HeRunRequirements()
    
    if use_mockup:
        he_run_req.set_he_context_options([OpenFheCkksContext()])
    else:
        he_run_req.set_he_context_options([he])

    he_run_req.set_explicit_he_config_requirement(he.get_he_config_requirement())
    profile = HeModel.compile(lrp, he_run_req)
    lr = lrp.get_empty_he_model(he)
    lr.encode_encrypt(lrp, profile)
    return lr

def generate_encrypt_and_save_inputs(he, name, lr):
    ioe = ModelIoEncoder(lr)
    encrypted_inputs = EncryptedData(he)

    # Create fabricated data
    error = np.random.normal(0, 1, num_samples_each_party)
    x = np.random.normal(mean_x, std_x + std_x_noise_for_fabricated_data, num_samples_each_party) 
    y = phi0 + phi1 * x + error

    ioe.encode_encrypt(encrypted_inputs, [np.expand_dims(x, axis=1), np.expand_dims(y, axis=1)])
    encrypted_inputs.save_to_file(os.path.join(COMMUNICATION_DIR_PATH, name + '_encrypted_data'))

def read_messages_and_execute_round(he, protocol, exceptions):
    # Get the participant's config before the try block, so that it is
    # available in the except clause even if an early step fails
    mp_config = he.get_he_config_requirement().get_multi_party_config()
    try:
        # Start with empty input and output message vectors
        input_messages = ProtocolMessageVector()
        output_messages = ProtocolMessageVector()

        # Read messages from the directory. We pre-filter by the metadata encoded in each file name,
        # then load the remaining messages and validate them against the protocol object.
        for file in os.listdir(os.fsencode(COMMUNICATION_DIR_PATH)):
            filename = os.fsdecode(file)

            # Skip irrelevant messages
            if "round_" + str(protocol.get_current_round()) not in filename \
                or "source_id_" + str(mp_config.participant_id) in filename \
                or ("dest_role_AGGREGATOR" in filename and not mp_config.is_aggregator()) \
                or ("dest_role_KEY-OWNER" in filename and not mp_config.is_key_owner()):
                continue

            message = ProtocolMessage(he)
            message.load_from_file(os.path.join(COMMUNICATION_DIR_PATH, filename))
            if protocol.is_input_message_valid_for_current_round(message):
                input_messages.append(message)

        # Execute round
        result = protocol.execute_next_round(output_messages, input_messages)
        assert result is True

        # Upload messages to directory
        for i, message in enumerate(output_messages):
            message.save_to_file(os.path.join(COMMUNICATION_DIR_PATH, message.get_metadata_as_string(True) + '_' + str(i)))
    except Exception as e:
        exceptions[mp_config.participant_id] = e


### 1.4. Participants setup

The setup phase includes loading the multi-party configuration objects from files and initializing the HeContext objects of the participants.
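
Since every participant loads its own configuration file, the files must agree on the shared fields. The sketch below checks this with the standard json module. Only the key names come from setup_participant; the ID values, and the choice of the server as initiator and aggregator and of all three participants as key-owners, are hypothetical, because the actual files in MULTI_PARTY_PATH are not shown here.

```python
import json

# Hypothetical file contents; only the key names match what setup_participant reads.
CONFIGS = {
    "alice": '{"participant_id": 1, "key_owners_ids": [1, 2, 3], "initiator_id": 3, "aggregator_id": 3}',
    "bob": '{"participant_id": 2, "key_owners_ids": [1, 2, 3], "initiator_id": 3, "aggregator_id": 3}',
    "server": '{"participant_id": 3, "key_owners_ids": [1, 2, 3], "initiator_id": 3, "aggregator_id": 3}',
}
SHARED_FIELDS = ("key_owners_ids", "initiator_id", "aggregator_id")


def check_configs(raw_configs):
    """Parse each participant's JSON config and check that they are consistent."""
    parsed = {name: json.loads(text) for name, text in raw_configs.items()}
    ids = [cfg["participant_id"] for cfg in parsed.values()]
    if len(set(ids)) != len(ids):
        raise ValueError("participant_id values must be unique")
    reference = next(iter(parsed.values()))
    for name, cfg in parsed.items():
        for field in SHARED_FIELDS:
            if cfg[field] != reference[field]:
                raise ValueError(f"{name} disagrees on {field}")
    for field in ("initiator_id", "aggregator_id"):
        if reference[field] not in ids:
            raise ValueError(f"{field} is not a known participant")
    return parsed


configs = check_configs(CONFIGS)
print({name: cfg["participant_id"] for name, cfg in configs.items()})
```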

In [None]:
print('*** Starting multi-party FHE demo ***')

# Alice setup
he_alice = MockupContext() if use_mockup else OpenFheCkksContext()
setup_participant(he_alice, 'alice', req)

# Bob setup
he_bob = MockupContext() if use_mockup else OpenFheCkksContext()
setup_participant(he_bob, 'bob', req)

# Server setup
he_server = MockupContext() if use_mockup else OpenFheCkksContext()
setup_participant(he_server, 'server', req)

## Step 2. Initialization protocol

In the initialization protocol, the participants generate the public keys. We first initialize a protocol object for each participant and then run the protocol round by round using the read_messages_and_execute_round helper function described above. The needs_another_round method returns True if the protocol needs another round. In the example below we call it on Alice's protocol object, but in real-world use each participant calls it on its own object.
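
The round loop used here (and again in the decryption step) follows one pattern: run every participant's round concurrently, wait for all of them so that every message produced in one round is on disk before the next round starts reading, and stop on the first recorded failure. The sketch below captures that pattern with a hypothetical ToyProtocol that only counts rounds; it is not a pyhelayers class. Like the notebook, it assumes all participants need the same number of rounds and therefore polls a single participant's needs_another_round.

```python
import threading


class ToyProtocol:
    """Hypothetical stand-in for InitProtocol/DecryptProtocol that only counts rounds."""

    def __init__(self, num_rounds):
        self.num_rounds = num_rounds
        self.current_round = 0

    def needs_another_round(self):
        return self.current_round < self.num_rounds

    def execute_next_round(self):
        self.current_round += 1
        return True


def run_round(name, protocol, exceptions):
    # Record failures instead of raising: an exception raised inside a thread
    # would not propagate to the thread that called join().
    try:
        if not protocol.execute_next_round():
            raise RuntimeError(f"round failed for {name}")
    except Exception as e:
        exceptions[name] = e


def drive(protocols):
    """Run all participants concurrently, one round at a time; return the round count."""
    first = next(iter(protocols.values()))
    rounds = 0
    while first.needs_another_round():
        exceptions = {}
        threads = [threading.Thread(target=run_round, args=(name, p, exceptions))
                   for name, p in protocols.items()]
        for t in threads:
            t.start()
        for t in threads:
            t.join()  # every participant finishes round r before round r + 1 starts
        if exceptions:
            raise RuntimeError(exceptions)
        rounds += 1
    return rounds


protocols = {name: ToyProtocol(3) for name in ("alice", "bob", "server")}
print(drive(protocols))  # 3
```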

In [None]:
print('*** Initialization protocol ***')

create_clean_dir(COMMUNICATION_DIR_PATH)

# Alice side
init_protocol_alice = InitProtocol(he_alice)

# Bob side
init_protocol_bob = InitProtocol(he_bob)

# Server side
init_protocol_server = InitProtocol(he_server)

exceptions = {}
while init_protocol_alice.needs_another_round():
    
    th_alice = threading.Thread(target=read_messages_and_execute_round, args=(he_alice, init_protocol_alice, exceptions))
    th_bob = threading.Thread(target=read_messages_and_execute_round, args=(he_bob, init_protocol_bob, exceptions))
    th_server = threading.Thread(target=read_messages_and_execute_round, args=(he_server, init_protocol_server, exceptions))
 
    th_alice.start()
    th_bob.start()
    th_server.start()

    th_alice.join()
    th_bob.join()
    th_server.join()

    if len(exceptions) > 0:
        raise Exception(exceptions)

## Step 3. Homomorphic computation

At this point all parties have the public keys, and they can encrypt and perform homomorphic computation. Alice and Bob will each initialize an empty HE model and use it to encrypt their inputs and send them to the server.

In [None]:
print('*** Encrypt inputs ***')

create_clean_dir(COMMUNICATION_DIR_PATH)

# Alice side
lr_alice = setup_he_model(he_alice, hyper_params)
generate_encrypt_and_save_inputs(he_alice, "alice", lr_alice)

# Bob side
lr_bob = setup_he_model(he_bob, hyper_params)
generate_encrypt_and_save_inputs(he_bob, "bob", lr_bob)

# The server initializes an empty HE model, loads Alice's and Bob's encrypted inputs and trains the model.

print('*** Train encrypted model ***')

# Server side
ed0 = EncryptedData(he_server)
ed1 = EncryptedData(he_server)
ed0.load_from_file(os.path.join(COMMUNICATION_DIR_PATH, 'alice_encrypted_data'))
ed1.load_from_file(os.path.join(COMMUNICATION_DIR_PATH, 'bob_encrypted_data'))

# This merges the two EncryptedData elements into one element.
ed0.add_encrypted_data(ed1)

lr_server = setup_he_model(he_server, hyper_params)
lr_server.fit(ed0)

# The server extracts the encrypted internals of the trained model.
encrypted_model_internals = lr_server.get_encrypted_internals()

# Now, Alice and Bob wish to decrypt the model. Alice will be the
# plaintext-aggregator (i.e., the one who gets the decrypted model first),
# and she will share the result with Bob.

## Step 4. Decrypt protocol

There are two important roles in the decryption protocol:
* Plaintext aggregator - this is the participant that will get the decrypted result. The plaintext aggregator in this example is Alice, so each participant uses the set_plaintext_aggregator_id method to set the plaintext aggregator ID to Alice's ID (we assume that the IDs are public and known in advance by all participants). After the protocol completes, the plaintext aggregator retrieves the decrypted data with the method that matches the input type: get_output_vector_double for a CTile, get_output_double_tensor for a CTileTensor, and get_output_vector_double_tensor for EncryptedData. This example decrypts the encrypted model internals and uses get_output_vector_double_tensor.
* Ciphertext holder - this is the participant that loads the ciphertext to decrypt into the DecryptProtocol object, using the set_input method. In this example it is the server, which loads the encrypted internals of the trained model. The decryption protocol accepts input of type CTile, CTileTensor or EncryptedData.

NOTE: in this example, the plaintext aggregator is a key-owner, i.e., it holds one of the secret keys. Therefore, it is safe for all other participants to publish their messages in a shared communication directory. When the plaintext aggregator is not a key-owner, all participants MUST send their messages directly to it over a secure channel. Otherwise, the ciphertext could be decrypted by anyone who has access to the communication directory.

In [None]:
print('*** Decryption protocol ***')

create_clean_dir(COMMUNICATION_DIR_PATH)

# The IDs are known by each of the participants
alice_id = he_alice.get_he_config_requirement().get_multi_party_config().participant_id

# Alice side
decrypt_protocol_alice = DecryptProtocol(he_alice)
decrypt_protocol_alice.set_plaintext_aggregator_id(alice_id)

# Bob side
decrypt_protocol_bob = DecryptProtocol(he_bob)
decrypt_protocol_bob.set_plaintext_aggregator_id(alice_id)

# Server side
decrypt_protocol_server = DecryptProtocol(he_server)
decrypt_protocol_server.set_plaintext_aggregator_id(alice_id)

# The server also needs to load the ciphertext
decrypt_protocol_server.set_input(encrypted_model_internals)

exceptions = {}
while decrypt_protocol_alice.needs_another_round():
    th_alice = threading.Thread(target=read_messages_and_execute_round, args=(he_alice, decrypt_protocol_alice, exceptions))
    th_bob = threading.Thread(target=read_messages_and_execute_round, args=(he_bob, decrypt_protocol_bob, exceptions))
    th_server = threading.Thread(target=read_messages_and_execute_round, args=(he_server, decrypt_protocol_server, exceptions))
 
    th_alice.start()
    th_bob.start()
    th_server.start()

    th_alice.join()
    th_bob.join()
    th_server.join()

    if len(exceptions) > 0:
        raise Exception(exceptions)

# Alice gets the output
decoded_model_internals = decrypt_protocol_alice.get_output_vector_double_tensor()

In [None]:
# Alice uses the decoded model internals to build a plain model.
trained_plain_model = lr_alice.get_plain_model_from_decoded_internals(decoded_model_internals)
trained_plain_model.__class__ = LogisticRegressionPlain

# Check result
trained_weights = trained_plain_model.get_weights()[0]
trained_bias = trained_plain_model.get_bias()[0]
np.testing.assert_almost_equal(phi0, trained_weights, 1)
np.testing.assert_almost_equal(phi1, trained_bias, 1)
print('*** Checking results... OK ***')

shutil.rmtree(COMMUNICATION_DIR_PATH)