# K-means Inference Using FHE

Expected RAM usage: less than 1 GB  
Expected runtime: less than 30 seconds.

## Introduction
K-means clustering is one of the simplest and most popular unsupervised learning approaches for unlabeled data. Unsupervised algorithms draw inferences from datasets using only input vectors, without referring to known, or labelled, outcomes. K-means inference, which is what this notebook performs, assigns each input vector to its nearest centroid, given centroids that were computed beforehand. Here we perform this inference in a fully encrypted fashion: both the centroids and the input samples are encrypted.
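To make "inference" concrete: given fixed centroids, each input vector is assigned to the centroid closest to it. The following is a minimal plaintext sketch in NumPy, assuming squared Euclidean distance (the usual K-means metric). The function name `nearest_centroid` is our own illustration, not part of HElayers.

```python
import numpy as np

def nearest_centroid(samples, centroids):
    """Return, for each row of `samples`, the index of the closest centroid.

    Plaintext reference only; the notebook computes the equivalent result
    on encrypted data.
    """
    # Pairwise differences: shape (num_samples, num_centroids, dims)
    diffs = samples[:, np.newaxis, :] - centroids[np.newaxis, :, :]
    # Squared Euclidean distance to each centroid: shape (num_samples, num_centroids)
    sq_dists = np.sum(diffs ** 2, axis=2)
    # The nearest centroid is the one with the smallest distance
    return np.argmin(sq_dists, axis=1)

centroids = np.array([[0.0, 0.0], [5.0, 5.0]])
samples = np.array([[0.2, -0.1], [4.8, 5.3], [1.0, 1.0]])
print(nearest_centroid(samples, centroids))  # [0 1 0]
```

Squared distance is used because it has the same minimizer as the true distance and avoids a square root.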

## Use case
One potential FHE use case for K-means is secure anomaly detection, which can be applied to supply chains in a multitude of industries ranging from automotive to energy to defense. Customer demands for improved service drive the need for speed in every process. Further complicating the issue is the growing complexity of global supply chains, which has greatly increased transaction flows and the volume and variety of data. These drivers have in turn increased the need for more efficient, timely, and automated processing to manage the volume and contain costs. Rising costs put pressure on logistics providers to ensure accuracy and minimize the effort to track, analyze, and report at all levels. For supply chain leaders, knowing the true cost-to-serve is the basis for many supplier sourcing decisions.

With FHE, third-party logistics (3PL) providers can securely detect anomalies in shipment cost, volume, weight, etc. in seconds, and provide visibility for quick analysis while preserving the privacy of the shipment contents. For example, you would generally expect a shipment with a high weight or volume to also have a high cost. A package that is extremely light but very expensive is therefore potentially anomalous, and may need to be monitored to keep costs contained and improve service.
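One common way to turn K-means into an anomaly detector is to flag samples that lie far from every centroid. The plaintext sketch below illustrates that idea; it is an assumption layered on top of this notebook, which itself only returns the nearest-centroid index. The helper `flag_anomalies`, the `threshold` value, and the two normalized features `[weight, cost]` are all hypothetical.

```python
import numpy as np

def flag_anomalies(samples, centroids, threshold):
    """Flag samples whose distance to the nearest centroid exceeds `threshold`.

    `threshold` is a hypothetical tuning parameter; in practice it would be
    chosen from historical (non-sensitive) data.
    """
    diffs = samples[:, np.newaxis, :] - centroids[np.newaxis, :, :]
    dists = np.sqrt(np.sum(diffs ** 2, axis=2))
    # Distance from each sample to its closest centroid
    nearest_dist = dists.min(axis=1)
    return nearest_dist > threshold

# Hypothetical normalized features per shipment: [weight, cost]
centroids = np.array([[0.1, 0.1],    # light and cheap
                      [0.9, 0.9]])   # heavy and expensive
shipments = np.array([[0.12, 0.08],  # light and cheap: normal
                      [0.88, 0.95],  # heavy and expensive: normal
                      [0.05, 0.95]]) # very light but very expensive
print(flag_anomalies(shipments, centroids, threshold=0.3).tolist())  # [False, False, True]
```

The light-but-expensive shipment is about 0.85 away from both clusters, so it is flagged, matching the intuition in the paragraph above.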

We want to use FHE here because the tracked metrics may contain sensitive information, such as the prices that companies pay to vendors, the source and destination of a shipment, or current shipment details like volume and weight. The scenario assumes that current shipment information is sensitive, while historical information is not.

## Step 1. Client side preparations
### 1.1. Imports and some setup

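```python
import numpy as np
import h5py
import os

# Helper module that accompanies this notebook
import utils

# Check that enough memory is available for the demo
utils.verify_memory()

# Print arrays compactly: summarize long arrays, at most 3 decimal places
np.set_printoptions(threshold=6, floatmode='maxprec', precision=3)

import pyhelayers
print("misc. init ready")
```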

### 1.2. Define the number of features and the centroids

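Each centroid represents a cluster, and each data sample is labelled with the index of its nearest centroid. The number of features (`dims`) and the number of centroids (`numCentroids`) are the standard K-means parameters. For this demo we generate synthetic centroids instead of training a model: centroid `i` is placed near the point (i, i, ..., i), with Gaussian noise of standard deviation 0.1 added to each coordinate.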

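```python
dims = 4           # number of features per sample
numCentroids = 6   # number of clusters

# Centroid i is drawn around (i, i, ..., i) with Gaussian noise of std 0.1
centroids = np.random.randn(numCentroids, dims) * 0.1 + np.arange(numCentroids)[:, np.newaxis]
print(centroids)
```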

### 1.3. Write the centroid coordinates to a CSV file

```python
data_dir = os.path.join('data', 'kmeans')
os.makedirs(data_dir, exist_ok=True)

# Write one centroid per line, with its coordinates separated by commas
with open(os.path.join(data_dir, 'model.csv'), "w") as f:
    for centroid in centroids:
        f.write(",".join(str(v) for v in centroid) + "\n")
print("wrote csv file")
```
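Before handing the file to HElayers, it can help to confirm that it round-trips. The sketch below assumes the layout written above (one centroid per row, comma-separated coordinates, no header) and uses a temporary directory instead of `data/kmeans`. The helper `load_centroids_csv` is hypothetical; HElayers' own CSV parser is not shown here.

```python
import os
import tempfile

import numpy as np

def load_centroids_csv(path):
    """Read a centroids CSV: one centroid per row, comma-separated coordinates."""
    # ndmin=2 keeps a 2-D shape even if the file holds a single centroid
    return np.loadtxt(path, delimiter=",", ndmin=2)

centroids = np.array([[0.1, -0.2, 0.05], [1.02, 0.97, 1.1]])
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.csv")
    # Same writing scheme as the cell above
    with open(path, "w") as f:
        for row in centroids:
            f.write(",".join(str(v) for v in row) + "\n")
    loaded = load_centroids_csv(path)

print(loaded.shape)                       # (2, 3)
print(np.array_equal(loaded, centroids))  # True
```

The comparison is exact because `str()` on a float produces the shortest string that parses back to the same value, so no precision is lost in the CSV.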

### 1.4. Load the model

We now load the model into HElayers. It is still in plaintext (not yet encrypted).

```python
hyper_params = pyhelayers.PlainModelHyperParams()
plain = pyhelayers.KMeansPlain()
# Load the centroids written in step 1.3
plain.init_from_files(hyper_params, [os.path.join(data_dir, "model.csv")])
print("loaded plain model")
```

### 1.5. Compile the plain model

Now we take the plain model and run a process called compilation. Internally, this runs an optimizer that finds the best HE parameters for this model. Besides choosing the parameters for you, the optimizer also reports estimates of the prediction time on a single core, the precision, the memory usage, the encryption and decryption times, and more.

The inputs to the compilation process are our run requirements. In this demo:
* We optimize for the `DefaultContext` (SEAL).
* We ask the optimizer to target a batch size of 8192, i.e., the number of samples provided to the inference model in each call.

This step doesn't yet encrypt the model, but prepares a 'profile' object that we can later use.
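If you have more samples than the batch size, they would be encrypted and sent in several batches. A minimal sketch of the splitting step follows; `split_into_batches` is a hypothetical helper, and we assume each chunk would then be encrypted separately as in step 1.9. Whether a partial final batch needs padding depends on HElayers and is not covered here.

```python
import numpy as np

def split_into_batches(samples, batch_size):
    """Split `samples` (shape [n, dims]) into consecutive chunks of at most batch_size rows."""
    return [samples[start:start + batch_size]
            for start in range(0, len(samples), batch_size)]

# Example: 20,000 samples with the 8192-sample batch size requested above
samples = np.zeros((20000, 4))
batches = split_into_batches(samples, 8192)
print([len(b) for b in batches])  # [8192, 8192, 3616]
```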

```python
# Run requirements: which HE backend to use and the target batch size
he_run_req = pyhelayers.HeRunRequirements()
he_run_req.set_he_context_options([pyhelayers.DefaultContext()])
he_run_req.optimize_for_batch_size(8192)

# Compile: the optimizer chooses HE parameters and returns a profile
profile = pyhelayers.HeModel.compile(plain, he_run_req)

batch_size = profile.get_optimal_batch_size()
print(profile.to_string())

print("He profile ready")
print("Batch size: ", batch_size)
```

### 1.6. Initialize the context

Here we initialize the FHE library based on the parameters chosen for us in the profile object. This client context holds the secret key.

```python
client_context = pyhelayers.HeModel.create_context(profile)
print('Crypto-library ready')
```

### 1.7. Encrypt the K-means centroids in preparation for finding the nearest cluster of each sample

Now we can encrypt our model, again using parameters chosen for us in the profile.

```python
client_kmeans = pyhelayers.KMeans(client_context)
print('\rencrypting . . .\r', flush=True)
# Encode and encrypt the centroids using the parameters from the profile
client_kmeans.encode_encrypt(plain, profile)
print('encrypted KMeans ready')
```

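### 1.8. Generate labelled test samples

We draw a random label (a centroid index) for each test sample and place the sample near the corresponding centroid, using the same construction as for the centroids. The label is therefore the expected prediction for that sample. The number of samples equals the batch size, so they all fit in a single batch.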

```python
test_size = batch_size

# Ground-truth label (centroid index) for each sample
labels = np.random.randint(0, numCentroids, size=test_size)

# Sample j is drawn around (labels[j], ..., labels[j]) with Gaussian noise of std 0.1
test_data = np.random.randn(test_size, dims) * 0.1 + labels[:, np.newaxis]

print(test_data)
print('labels', labels)
```
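The final check in step 3 assumes that each sample's label is also its nearest centroid. That assumption can be verified in plaintext before encrypting anything. The sketch below regenerates data with the same construction, using NumPy's `default_rng` with a fixed seed (our choice for reproducibility; the notebook itself uses the global `np.random` state).

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed so the check is reproducible
dims, num_centroids, test_size = 4, 6, 8192

# Same construction as above: centroid i and its samples sit near (i, ..., i)
centroids = rng.standard_normal((num_centroids, dims)) * 0.1 + np.arange(num_centroids)[:, np.newaxis]
labels = rng.integers(0, num_centroids, size=test_size)
test_data = rng.standard_normal((test_size, dims)) * 0.1 + labels[:, np.newaxis]

# Plaintext nearest-centroid assignment (squared Euclidean distance)
sq_dists = ((test_data[:, np.newaxis, :] - centroids[np.newaxis, :, :]) ** 2).sum(axis=2)
plain_labels = sq_dists.argmin(axis=1)

print("plaintext agreement with labels:", (plain_labels == labels).mean())
```

Neighbouring centroids are about 2 units apart in 4 dimensions while the noise is 0.1 per coordinate, so the generated labels and the true nearest centroids agree almost surely.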

### 1.9. Encrypt the data samples

To encrypt the data, we first create an IO processor (iop for short).
The iop is a lightweight object that knows the model's metadata. It can encrypt input data for the model and, later, decrypt the output that the server sends back.

```python
iop = client_kmeans.create_io_processor()
client_samples = pyhelayers.EncryptedData(client_context)
# Encode and encrypt the whole batch of test samples
iop.encode_encrypt_inputs_for_predict(client_samples, [test_data])
print('Batch encrypted')
```

### 1.10. Save and send
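We serialize the encrypted model, the context, and the encrypted samples to buffers, ready to be sent to the server. The context is saved without the secret key, so the server can compute on the encrypted data but cannot decrypt it.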

```python
kmeans_buffer = client_kmeans.save_to_buffer()
samples_buffer = client_samples.save_to_buffer()
context_buffer = client_context.save_to_buffer()  # with no secret key
print('Context, model, and samples saved')
```

## Step 2. Server side
### 2.1. Load data
We first load all the data sent from the client.

```python
server_context = pyhelayers.load_he_context(context_buffer)
server_kmeans = pyhelayers.load_he_model(server_context, kmeans_buffer)
server_samples = pyhelayers.load_encrypted_data(server_context, samples_buffer)
print('server ready')
```

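### 2.2. Run prediction

With both the inputs and the centroids encrypted, the server computes the distance between each input and each centroid and determines the nearest centroid for each input, without decrypting anything.

The encrypted predictions are then saved to a buffer and sent back to the client.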

```python
utils.start_timer()

# Run encrypted inference on the encrypted samples
server_predictions = pyhelayers.EncryptedData(server_context)
server_kmeans.predict(server_predictions, server_samples)

duration = utils.end_timer('predict')
utils.report_duration('predict per sample', duration / test_size)

predictions_buffer = server_predictions.save_to_buffer()
print('predictions saved')
```
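On ciphertexts, additions and multiplications are supported directly, while square roots and comparisons are expensive. The plaintext NumPy sketch below shows one standard way to make the distance step HE-friendly: expand the squared distance as ||x||² − 2x·c + ||c||² and drop ||x||², which is the same for every centroid. We are not claiming this is how HElayers implements `KMeans.predict`. The final argmin, which under HE generally requires approximating comparisons with polynomials, is left in plaintext here.

```python
import numpy as np

rng = np.random.default_rng(1)
centroids = rng.standard_normal((6, 4))
samples = rng.standard_normal((1000, 4))

# Direct form: squared Euclidean distance to every centroid
direct = ((samples[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)

# Expanded form: ||x||^2 - 2 x.c + ||c||^2, using only additions and multiplications
x_sq = (samples ** 2).sum(axis=1, keepdims=True)  # shape (n, 1), same for every centroid
c_sq = (centroids ** 2).sum(axis=1)               # shape (k,), depends only on the model
expanded = x_sq - 2 * samples @ centroids.T + c_sq

# ||x||^2 does not change which centroid is closest, so it can be dropped
reduced = c_sq - 2 * samples @ centroids.T

print(np.allclose(direct, expanded))  # True
print(np.array_equal(direct.argmin(axis=1), reduced.argmin(axis=1)))
```

Dropping ||x||² saves one squaring per input, and ||c||² can be precomputed once by whoever owns the model.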

## Step 3. Assess results on the client side

We first load the predictions and decrypt them, again using the iop.

Then we compare them with the ground-truth labels.

```python
client_predictions = pyhelayers.load_encrypted_data(client_context, predictions_buffer)
print('predictions loaded')

# Decrypt the results (requires the secret key, which only the client holds)
plain_predictions = iop.decrypt_decode_output(client_predictions)

print('HE predictions:', plain_predictions)
print('True labels:', labels)
allOk = (plain_predictions == labels).all()
if allOk:
    print('All predictions match')
else:
    raise Exception("mismatching labels. Demo failed")
```