# Private Set Intersection for Vertical Federated Learning using FHE
Expected RAM usage: 38 GB  
Expected runtime: 10-20 minutes  

## Introduction
This example demonstrates a Private Set Intersection (PSI) protocol between three parties that uses a fourth party, called the Aggregator, as a building block for Vertical Federated Learning (VFL). Each of the three parties holds its own database of private samples. Each sample consists of a unique identifier (UID) and a number of features, each of which is a real number in the range [0,1].
 
The objective is for each party to end the protocol holding a CTileTensor, encrypted under the aggregator's key, that contains only the samples in the intersection (i.e., samples that appear in the dataset of every party), without disclosing any information about the other parties' samples. In particular, a party should learn neither whether a specific sample is in the intersection nor the size of the intersection. Each party can then use the CTileTensor in the federated learning algorithm.

## Step 1. Framework Setup

### 1.1. Start with some imports

In [None]:
import pyhelayers
import numpy as np
import pandas as pd
import random
import utils 

### 1.2. Parties setup
In our framework, the parties need to have a 128-bit shared secret that is hidden from the aggregator. The method by which this secret is generated and distributed between the parties is out of the scope of this demo. This secret will be used to mask the intermediate calculations from the aggregator.

Furthermore, each party should have a public unique ID.

In [None]:
# A 128-bit secret that is shared between all the parties and is hidden from the aggregator.
parties_shared_secret = [12345,67890]

num_parties = 3
list_parties = ['party-{}'.format(i+1) for i in range(num_parties)]

# Each party should have a public unique ID
dict_party_rts_id = {v: i+1 for i, v in enumerate(list_parties)}
print(dict_party_rts_id)
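
The hard-coded secret above is only a placeholder. As a minimal sketch of how a fresh 128-bit value could be produced, the snippet below draws 128 random bits with Python's `secrets` module and splits them into two 64-bit words. The function name `generate_shared_secret` is hypothetical, and the two-word layout is an assumption inferred from the two-element list used in this demo; it is not confirmed here that `RtsPsiManager` expects exactly this layout. Distributing the secret to the parties remains out of scope.

```python
import secrets

def generate_shared_secret():
    """Return a fresh 128-bit secret as two 64-bit unsigned integers.

    Hypothetical helper: the two-word layout is an assumption based on
    the two-element list used in this demo.
    """
    value = secrets.randbits(128)
    mask = (1 << 64) - 1
    # High word first, low word second (an arbitrary convention for this sketch).
    return [value >> 64, value & mask]

example_secret = generate_shared_secret()
```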

### 1.3. Load dataset
This demo uses three fabricated datasets of 5 samples and 2 features each, with each sample indexed by a unique ID (unsigned int). The datasets share common samples (i.e., some IDs appear in more than one dataset) but have different features. This setting corresponds to the Vertical Federated Learning (VFL) model. The combination of the 3 datasets can be viewed as a single virtual dataset of 10 samples and 6 features.

In [None]:
num_samples = 10
num_features = 6
samples_uids = random.sample(range(2**32), num_samples)
samples_data = np.random.rand(num_samples, num_features)
data = pd.DataFrame(samples_data, index=samples_uids, columns=['feature-{}'.format(i+1) for i in range(num_features)])
data.index.rename('UIDs', inplace=True)
print('samples-{}, features-{}'.format(num_samples, num_features))
data

### 1.4. Split the dataset vertically and assign it to the parties
The dataset is partitioned vertically between the parties, so that each party gets a different pair of features. Also, to make the PSI non-trivial, each party receives a random 5 of the 10 samples in the dataset.

In [None]:
num_sample_each_party = 5

dict_party_data = {}
for party_id in range(num_parties):
    party_features = num_features//num_parties
    if party_id == num_parties-1:
        party_features = num_features - (num_features//num_parties) * (num_parties-1)

    start_idx = (num_features//num_parties) * party_id
    p_data = data.iloc[:, start_idx: start_idx + party_features]

    p_data = p_data.sample(n=num_sample_each_party)
                
    dict_party_data['party-{}'.format(party_id + 1)] = {
        'uids': p_data.index.values.tolist(),
        'data': p_data
    }

for p, v in dict_party_data.items():
    print('{}-dataset'.format(p))
    print(v['data'], '\n')
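
For reference, the quantity the PSI protocol computes under encryption is simply the set of UIDs common to all parties. The following plaintext sketch, using made-up toy UIDs and a hypothetical helper `plaintext_intersection`, shows that target. Because each party holds a random half of the samples, the intersection in an actual run may be small or even empty.

```python
def plaintext_intersection(uid_lists):
    """Return the sorted UIDs that appear in every list of `uid_lists`."""
    common = set(uid_lists[0])
    for uids in uid_lists[1:]:
        common &= set(uids)
    return sorted(common)

# Toy UIDs for three parties (illustrative values only).
toy_uids = [
    [11, 22, 33, 44, 55],
    [22, 33, 55, 66, 77],
    [33, 55, 88, 99, 11],
]
toy_common = plaintext_intersection(toy_uids)
print(toy_common)  # [33, 55]
```

In the real protocol no party ever sees this list in the clear; it is shown here only to make the goal concrete.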

### 1.5. Aggregator setup
The aggregator creates a Helayers context to be used in the Private Set Intersection protocol and in the Federated Learning algorithm. The two tasks can use contexts of different types, but they can also share a single context, as this demo does.

Then, it communicates with the parties and sends them the context object (the means by which it does so are out of the scope of this demo; here the parties' code simply uses the same context object).

Finally, the aggregator initializes an AggregatorPsiManager object to be used later.

In [None]:
he_context = pyhelayers.HeaanContext()
requirements = pyhelayers.HeConfigRequirement(
    num_slots = 2 ** 15,
    multiplication_depth = 9,
    fractional_part_precision = 48,
    integer_part_precision = 12,
    security_level = 128)
requirements.bootstrappable = True

he_context.init(requirements)
he_context.set_automatic_bootstrapping(True)

print(he_context.print_signature())

aggregator = pyhelayers.AggregatorPsiManager(he_context, he_context)

## Step 2. Run the PSI protocol

### 2.1. Create encrypted hash tables
Each party initializes an RtsPsiManager object with the Helayers context the aggregator sent, with its UIDs and data, and with the shared secret the parties agreed on.

Then, each party runs the first step of the PSI protocol, which inserts its UIDs into an encrypted hash table.

Finally, each party sends its serialized encrypted hash table to the other parties, and sends the aggregator a mapping from the original dataset indices to the hash table indices.

In [None]:
dict_party_psi = {}
dict_party_hash = {}
dict_party_mapping = {}


for party in list_parties:
    psi_manager = pyhelayers.RtsPsiManager(
        he_context, he_context, 
        dict_party_rts_id[party], 
        dict_party_data[party]['uids'], 
        dict_party_data[party]['data'], 
        parties_shared_secret)

    with utils.elapsed_timer('{} generate its encrypted hash table'.format(party), 1) as timer:
        dict_party_hash[party] = psi_manager.insert_to_hash().save_to_buffer()

    dict_party_mapping[party] = psi_manager.get_uids_mapping()

    dict_party_psi[party] = psi_manager
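
The hashing scheme used internally by `RtsPsiManager` is not described in this demo. As a rough plaintext toy of the general idea only, the sketch below places UIDs into a fixed-size table using a keyed hash and records, for each original sample index, the slot it landed in. The function name `toy_insert_to_hash`, the use of HMAC-SHA256, the table size, and the linear-probing collision handling are all assumptions of this sketch, not the library's method.

```python
import hashlib
import hmac

def toy_insert_to_hash(uids, secret_key, table_size):
    """Place each UID in a table slot chosen by a keyed hash.

    Returns (table, mapping), where mapping[i] is the slot of uids[i].
    Collisions are resolved by linear probing (a choice of this sketch).
    """
    if len(uids) > table_size:
        raise ValueError("table_size must be at least len(uids)")
    table = [None] * table_size
    mapping = []
    for uid in uids:
        digest = hmac.new(secret_key, str(uid).encode(), hashlib.sha256).digest()
        slot = int.from_bytes(digest[:8], "big") % table_size
        while table[slot] is not None:
            slot = (slot + 1) % table_size
        table[slot] = uid
        mapping.append(slot)
    return table, mapping

toy_key = b"toy-shared-secret"  # stands in for the parties' shared secret
toy_table, toy_mapping = toy_insert_to_hash([11, 22, 33, 55], toy_key, 8)
```

The mapping plays the role of the original-index-to-hash-index mapping that each party sends to the aggregator.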

### 2.2. Generate indicator vectors
Each party receives the serialized hash tables from the other parties and deserializes them. It then generates, for each of them, an encrypted indicator vector that indicates which samples are in the intersection and that will be used later in the protocol.

Finally, it sends the serialized result back to the party that sent the original hash table.

In [None]:
dict_party_indicators = {}

# Each party generates an indicator vector for each of the other parties
for party in list_parties:
    indicators = {}
    psi_manager = dict_party_psi[party]
    for other_party in list_parties:
        if other_party != party:
            rts_id = dict_party_rts_id[other_party]
            hash_table = pyhelayers.CTileTensor(he_context)
            hash_table.load_from_buffer(dict_party_hash[other_party])

            with utils.elapsed_timer('{} generate the indicator vector for {}'.format(party, other_party), 1) as timer:
                indicators[other_party] = psi_manager.generate_indicator_vector(rts_id, hash_table).save_to_buffer()
            
    dict_party_indicators[party] = indicators
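
Conceptually, an indicator vector has one entry per hash table slot: 1 where the UID stored in that slot is also held by the party generating the vector, and 0 elsewhere. The plaintext sketch below illustrates those semantics with toy values. The function name `toy_indicator_vector` is hypothetical, and the real `generate_indicator_vector` works on encrypted data, so it cannot inspect the UIDs the way this sketch does.

```python
import numpy as np

def toy_indicator_vector(other_hash_table, my_uids):
    """Return 1.0 for slots whose UID is also in `my_uids`, else 0.0.

    Empty slots (None) always yield 0.0.
    """
    my_set = set(my_uids)
    return np.array(
        [1.0 if uid is not None and uid in my_set else 0.0
         for uid in other_hash_table]
    )

# Toy hash table received from another party, and this party's own UIDs.
toy_other_table = [None, 33, 11, None, 55, 22]
toy_indicator = toy_indicator_vector(toy_other_table, [22, 33, 77])
print(toy_indicator)  # [0. 1. 0. 0. 0. 1.]
```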

### 2.3. The aggregator rearranges the indicator vectors
The aggregator uses the mapping sent by each party to rearrange the indicator vectors generated for that party, so that the order of the indicators matches the original order of the party's samples.

In [None]:
dict_party_rearranged_indicators = {}

for party in list_parties:
    mapping_party = dict_party_mapping[party]
    rearranged_indicators = {}
    for other_party in list_parties:
        if other_party != party:
            indicator = pyhelayers.CTileTensor(he_context)
            indicator.load_from_buffer(dict_party_indicators[other_party][party])

            with utils.elapsed_timer('aggregator rearrange indicator vector for {}'.format(party), 1) as timer:
                rearranged_indicators[other_party] = aggregator.rearrange_indicator_vector(indicator, mapping_party).save_to_buffer()

    dict_party_rearranged_indicators[party] = rearranged_indicators
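
In plaintext terms, the rearrangement is a gather: entry `i` of the output is the indicator stored at the hash table slot that sample `i` was mapped to. A minimal NumPy sketch with toy values follows. The name `toy_rearrange` is hypothetical, and the aggregator's real `rearrange_indicator_vector` performs the equivalent reordering on an encrypted CTileTensor without seeing the indicator values.

```python
import numpy as np

def toy_rearrange(indicator_by_slot, mapping):
    """Reorder slot-indexed indicators into original sample order."""
    return np.asarray(indicator_by_slot)[np.asarray(mapping)]

# Indicators indexed by hash table slot (toy values).
toy_indicator_by_slot = [0.0, 1.0, 0.0, 0.0, 0.0, 1.0]
# mapping[i] is the hash table slot of the party's i-th sample.
toy_sample_to_slot = [2, 5, 1, 4]

toy_rearranged = toy_rearrange(toy_indicator_by_slot, toy_sample_to_slot)
print(toy_rearranged)  # [0. 1. 1. 0.]
```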


### 2.4. Parties compact local datasets
Finally, each party receives the encrypted rearranged indicator vectors from the aggregator. It uses them to privately sort its data, such that the first rows are the records that are in the intersection and the remaining rows are encryptions of 0s (note that the relative order of the samples in the intersection is the same as their relative order in the DoubleTensor given to the constructor). The resulting CTileTensor is then used in the learning algorithm.

In [None]:
dict_party_result = {}

for party in list_parties:
    psi_manager = dict_party_psi[party]
    lst_rearranged_indicator = []
    lst_rts_id = []
    for other_party in list_parties:
        if other_party != party:
            rearranged_indicator = pyhelayers.CTileTensor(he_context)
            rearranged_indicator.load_from_buffer(dict_party_rearranged_indicators[party][other_party])
            lst_rearranged_indicator.append(rearranged_indicator)
            lst_rts_id.append(dict_party_rts_id[other_party])
    indicator_party = pyhelayers.CTileTensorVector(lst_rearranged_indicator)
    final_indicator_party = psi_manager.multiply_indicator_vectors(lst_rts_id, indicator_party)

    with utils.elapsed_timer('{} local dataset compaction'.format(party), 1) as timer:
        dict_party_result[party] = psi_manager.compaction(final_indicator_party)
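
The effect of the two calls above can be pictured in plaintext: multiplying the per-party indicator vectors elementwise leaves a 1 only for samples held by every party, and compaction moves those rows to the top while preserving their relative order and zeroing the rest. The sketch below uses toy values and a hypothetical helper `toy_compaction`. It branches on the indicator values, which the encrypted implementation cannot do, so it illustrates the result only, not the method.

```python
import numpy as np

def toy_compaction(data, indicators):
    """Move rows selected by all indicators to the top; zero-fill the rest."""
    # Elementwise product: 1 only where every indicator is 1.
    final = np.prod(np.vstack(indicators), axis=0)
    keep = final > 0.5
    result = np.zeros_like(data, dtype=float)
    kept_rows = data[keep]  # boolean indexing preserves row order
    result[:len(kept_rows)] = kept_rows
    return result

toy_data = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]])
toy_indicators = [np.array([1.0, 1.0, 0.0, 1.0]),
                  np.array([0.0, 1.0, 1.0, 1.0])]
toy_result = toy_compaction(toy_data, toy_indicators)
print(toy_result)
```

Here rows 1 and 3 are the only ones flagged by both indicators, so they end up as the first two rows of the result.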

## Final results

In [None]:
from functools import reduce

# Inner-join the parties' datasets on the UID index to get the expected intersection
expected = reduce(
    lambda left, right: pd.merge(left, right, on=['UIDs'], how='inner'),
    [dict_party_data['party-{}'.format(party_id + 1)]['data'] for party_id in range(num_parties)])

tt_encoder = pyhelayers.TTEncoder(he_context)

for party in list_parties:
    psi_res = tt_encoder.decrypt_decode_double(dict_party_result[party])
    expected_res = expected[dict_party_data[party]['data'].columns].reset_index(drop=True).reindex(range(num_sample_each_party)).fillna(0)

    print('{}-psi-result'.format(party))
    print(pd.DataFrame(psi_res).to_string(header=False, index=False))
    print('\n{}-expected-result'.format(party))
    print(expected_res.to_string(header=False, index=False))
    print('\n')

    assert np.abs(np.subtract(psi_res, expected_res.to_numpy())).max() < 1e-4

In [None]:
print("RAM usage:", utils.get_used_ram(), "MB")