In [136]:
import time
import numpy as np
from collections import Counter
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.datasets import fetch_openml
from pyope.ope import OPE, ValueRange
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import json
from umbral import encrypt, decrypt_reencrypted, reencrypt, generate_kfrags
from umbral.keys import SecretKey, PublicKey
from umbral.signing import Signer
from collections import Counter
from umbral.capsule_frag import VerifiedCapsuleFrag

### Load Encryption Keys from File

This block loads all necessary cryptographic keys from a `keys.json` file:

- **OPE Keys**: One key per tree, used for encrypting feature thresholds.
- **PRE Keys**: A unique public/private key pair per tree, used for encrypting leaf labels.
- **Voting Key**: A global key pair used to decrypt re-encrypted labels during majority voting.

In [None]:
# Load keys from JSON file
with open("keys.json", "r") as f:
    key_data = json.load(f)

# Define OPE encryption ranges
in_range = ValueRange(0, 2550)
out_range = ValueRange(0, 2**32 - 1)

# Number of decision trees (and thus keys)
num_estimators = 15

# Define index range for slicing keys
start_idx = 0
end_idx = start_idx + num_estimators

# Load OPE keys for each decision tree
ope_keys = [OPE(key.encode(), in_range=in_range, out_range=out_range) 
            for key in key_data["ope_keys"][start_idx:end_idx]]

# Load PRE secret and public keys for each tree
pre_secret_keys = [
    SecretKey.from_bytes(bytes.fromhex(k)) 
    for k in key_data["pre_keys"]["private_keys"][start_idx:end_idx]
]
pre_public_keys = [
    PublicKey.from_bytes(bytes.fromhex(k)) 
    for k in key_data["pre_keys"]["public_keys"][start_idx:end_idx]
]

# Load voting key pair used for final label decryption
vote_secret_key = SecretKey.from_bytes(bytes.fromhex(key_data["vote_key"]["private"]))
vote_public_key = PublicKey.from_bytes(bytes.fromhex(key_data["vote_key"]["public"]))


### Dataset Loading

This section loads the MNIST dataset using `fetch_openml`. The dataset contains 70,000 grayscale images of handwritten digits (28x28 pixels), each labeled from 0 to 9.

The loading time is measured for performance evaluation.

In [138]:
start_total_time = time.perf_counter()  # Start total execution time

#Load Dataset Used for Testing and Training
start_dataset_load_time = time.perf_counter()

mnist = fetch_openml("mnist_784", version=1, as_frame=False)
X, y = mnist.data.astype("float32"), mnist.target.astype("int")

end_dataset_load_time = time.perf_counter()

dataset_load_time = end_dataset_load_time - start_dataset_load_time
print(f"Dataset Loading Time: {dataset_load_time:.4f} seconds")


Dataset Loading Time: 4.4111 seconds


### Dataset Scaling

This section normalizes the dataset to ensure all pixel values fall within the required encryption range. 
The raw MNIST images are first scaled to the [0, 1] range, then rescaled to the [0, 2550] integer range to match the input range expected by the OPE encryption scheme.


In [139]:
# Normalize pixel values to [0, 1]
X = X / 255.0

if X.max() <= 1:
    # Rescale dataset from original range to [0, 255]
    X = (X - X.min()) / (X.max() - X.min()) * 255
    X = (X * 10).astype(int)  # Scale to 0–2550
else:
    X = X.astype(int)

### Dataset Splitting

The dataset is split into training and testing sets. The total number of samples in each set is recorded. 
These values are used to control the scope of encryption and evaluation throughout the experiment.

In [140]:
# Split dataset
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

#Get the number of samples to be encrypted
num_samples_training = len(X_train)
num_samples_testing = len(X_test)

X_test = X_test
X_train = X_train

print(f"Number of training samples: {num_samples_training}")
print(f"Number of testing samples: {num_samples_testing}")

Number of training samples: 56000
Number of testing samples: 14000


### Encrypt Dataset Using Order-Preserving Encryption (OPE)

This block defines a function to encrypt the input dataset using OPE:

- Each tree receives its own version of the dataset, encrypted using a unique OPE key.
- Pixel values are scaled (e.g., by 10×) to fit the expected OPE range.
- Returns a list of encrypted datasets (one for each decision tree).


In [None]:
# Define scaling factor for input pixel values
scale_factor = 10
max_pixel_value = 255 * scale_factor  # Maximum value after scaling (used in OPE in_range)

# Function to encrypt the dataset using OPE (one encrypted version per tree)
def encrypt_dataset_with_ope(X, ope_keys):
    encrypted_versions = []

    for idx, ope in enumerate(ope_keys):
        print(f"Encrypting dataset with OPE key {idx + 1}")

        # Encrypt each pixel in each row using the current OPE key
        encrypted_X = np.array([
            [ope.encrypt(int(val)) for val in row]
            for row in tqdm(X, desc=f"Encrypting for Tree {idx + 1}")
        ])

        encrypted_versions.append(encrypted_X)

    # Return a list of encrypted datasets (one per tree)
    return encrypted_versions


### Encrypt Test Dataset Using OPE (with Caching)

This block performs encryption of the test dataset using the per-tree OPE keys:

- Checks if encrypted test data already exists in cache.
- If cached, loads the file; otherwise, encrypts and saves it.
- Measures encryption time for each tree's dataset.
- Outputs total encryption time for reporting or performance analysis.


In [None]:
# Ensure the folder for encrypted test data exists
os.makedirs("Encrypted_Dataset", exist_ok=True)

# Start timing the entire test data encryption phase
start_test_data_encryption_time = time.perf_counter()

print(f"X_test size: {X_test.shape}")

X_test_encrypted_per_tree = []  # Will hold encrypted versions for each tree
encryption_times = []           # Track individual encryption times

# Encrypt or load encrypted test data for each tree
for idx, ope in enumerate(ope_keys):
    key_name = key_data["ope_keys"][idx]  # Raw key string for naming
    file_path = f"Encrypted_Dataset/X_test_encrypted_{key_name}.npy"

    if os.path.exists(file_path):
        print(f"Loaded cached encrypted data for tree with key {key_name}")
        encrypted_X = np.load(file_path)
        encryption_times.append(0)
    else:
        print(f"Encrypting test data for tree {idx + 1} (key: {key_name[:8]}...)")
        start = time.perf_counter()
        
        # Encrypt each test sample using the current OPE key
        encrypted_X = np.array([
            [ope.encrypt(int(val)) for val in row]
            for row in tqdm(X_test, desc=f"Encrypting for Tree {idx + 1}")
        ])
        
        end = time.perf_counter()
        np.save(file_path, encrypted_X)
        print(f"Saved encrypted test data to {file_path}")
        encryption_times.append(end - start)

    X_test_encrypted_per_tree.append(encrypted_X)

end_test_data_encryption_time = time.perf_counter()

# Determine total encryption time
if all(t == 0 for t in encryption_times):
    # All data was loaded from cache
    test_data_encryption_time = end_test_data_encryption_time - start_test_data_encryption_time
else:
    # Sum actual encryption durations
    test_data_encryption_time = sum(encryption_times)

print(f"Total Dataset Encryption Time: {test_data_encryption_time:.4f} seconds")


X_test size: (14000, 784)
Loaded cached encrypted data for tree with key 5f5e8a3fe2ba8ef9979983ebad66c7d7
Loaded cached encrypted data for tree with key 456f782dbde337bd96fb673941f9335a
Loaded cached encrypted data for tree with key 216770f64070844fddf2eadd8aa7b100
Loaded cached encrypted data for tree with key 6ce2504f1c23915e0062045b639b2c7f
Loaded cached encrypted data for tree with key c456d38cab06f85268dfa40620da8eee
Loaded cached encrypted data for tree with key e5af9c5870aeb77a3dd360340d768569
Loaded cached encrypted data for tree with key f161e54ba25699c70be43cb547b1b198
Loaded cached encrypted data for tree with key b29671a1bab218eb7d364a3831cc2b82
Loaded cached encrypted data for tree with key fcc93a3653c266ef4896d2c82e45310a
Loaded cached encrypted data for tree with key faff2ba978176f936b83c6c040388a62
Loaded cached encrypted data for tree with key eb760519812f42f551f02574521fb9e2
Loaded cached encrypted data for tree with key 50b2d950f6c5f4c092b1d6a14dde223a
Loaded cached 

### Train Random Forest Classifier

This block trains a Random Forest classifier on the plaintext (unencrypted) training dataset.

- Uses `n_estimators` to match the number of OPE/PRE key pairs.
- Sets a fixed `max_depth` and `random_state` for reproducibility.
- Measures the training time for benchmarking against the encrypted inference time.


In [None]:
# Measure training time for the Random Forest model
start_training_time = time.perf_counter()

# Train a Random Forest model with specified number of trees
clf_ope = RandomForestClassifier(
    n_estimators=num_estimators,
    max_depth=20,
    random_state=42,
    min_samples_split=2
)
clf_ope.fit(X_train, y_train)

end_training_time = time.perf_counter()
training_time = end_training_time - start_training_time

print(f"Random Forest Training Time: {training_time:.4f} seconds")


Random Forest Training Time: 5.6395 seconds


### Encrypt Leaf Labels with Proxy Re-Encryption (PRE)

This function encrypts the predicted class labels found at each leaf node of a decision tree using the corresponding tree's PRE public key.

- Only the leaf nodes are encrypted.
- The function supports optional limiting via `max_labels_per_tree`.
- Returns encrypted labels, their associated capsules, and the total encryption time.


In [None]:
# Encrypt leaf node labels for all decision trees using PRE
def encrypt_tree_labels_with_pre(clf, pre_public_keys, max_labels_per_tree=None):
    start_label_encryption_time = time.perf_counter()

    encrypted_leaf_values = []  # Holds encrypted class labels for each tree
    leaf_capsules = []          # Holds encryption capsules for each tree

    print("Encrypting leaf labels using PRE...")
    for idx, tree in enumerate(tqdm(clf.estimators_, desc="Encrypting Trees")):
        tree_values = {}
        tree_capsules = {}

        # Identify all leaf nodes in the tree (feature index -2 indicates a leaf)
        leaf_nodes = [node for node in range(tree.tree_.node_count) if tree.tree_.feature[node] == -2]

        # Optional: limit number of encrypted labels per tree
        if max_labels_per_tree is not None:
            leaf_nodes = leaf_nodes[:max_labels_per_tree]

        # Encrypt the predicted class label at each leaf node
        for node in tqdm(leaf_nodes, leave=False, desc=f"Tree {idx+1}"):
            label = str(np.argmax(tree.tree_.value[node][0]))
            capsule, ciphertext = encrypt(pre_public_keys[idx], label.encode())

            tree_values[node] = ciphertext
            tree_capsules[node] = capsule

        encrypted_leaf_values.append(tree_values)
        leaf_capsules.append(tree_capsules)

    end_label_encryption_time = time.perf_counter()
    label_encryption_time = end_label_encryption_time - start_label_encryption_time

    print(f"\nPRE Label Encryption done for {len(clf.estimators_)} trees.")
    print(f"Total PRE Encryption Time: {label_encryption_time:.4f} seconds")
    return encrypted_leaf_values, leaf_capsules, label_encryption_time


### Executing Leaf Label Encryption with PRE

The `encrypt_tree_labels_with_pre` function is called to encrypt all leaf node class labels across the trained Random Forest.
Each label is encrypted using the public key assigned to its corresponding tree. 

The encrypted labels (`encrypted_leaf_values`) and associated PRE capsules (`leaf_capsules`) are stored for use during secure classification and majority voting.

In [None]:
# Encrypt the leaf labels of all decision trees using PRE public keys
encrypted_leaf_values, leaf_capsules, label_encryption_time = encrypt_tree_labels_with_pre(
    clf_ope,
    pre_public_keys
)


Encrypting leaf labels using PRE...


Encrypting Trees: 100%|██████████| 15/15 [02:06<00:00,  8.42s/it]



PRE Label Encryption done for 15 trees.
Total PRE Encryption Time: 126.3462 seconds


### Encrypt Tree Thresholds Using OPE (with Caching)

This block encrypts all internal split thresholds for each decision tree using their respective OPE keys:

- If encrypted thresholds already exist, they are loaded from disk.
- If not, thresholds are encrypted and cached.
- A value of `-2` indicates a leaf node and is skipped (set to `None`).
- The total encryption time is printed for performance analysis.


In [None]:
# Ensure output directory exists
os.makedirs("Encrypted_Thresholds", exist_ok=True)

start_threshold_encryption_time = time.perf_counter()
encrypted_thresholds = []

# Encrypt thresholds for each decision tree using its OPE key
for idx, (tree, ope) in enumerate(zip(clf_ope.estimators_, ope_keys)):
    key_name = key_data["ope_keys"][idx]  # Key string used for caching file name
    file_path = f"Encrypted_Thresholds/encrypted_thresholds_{key_name}.npy"

    if os.path.exists(file_path):
        # Load from cache if already encrypted
        thresholds = np.load(file_path, allow_pickle=True)
        encrypted_thresholds.append(thresholds.tolist())
        print(f"Loaded cached thresholds for key {key_name}")
    else:
        print(f"Encrypting thresholds for key {key_name}")
        tree_thresholds = []

        for th in tree.tree_.threshold:
            if th != -2:
                tree_thresholds.append(ope.encrypt(int(th)))  # Encrypt valid threshold
            else:
                tree_thresholds.append(None)  # Skip leaf indicators

        encrypted_thresholds.append(tree_thresholds)
        np.save(file_path, tree_thresholds)
        print(f"Saved thresholds to {file_path}")

end_threshold_encryption_time = time.perf_counter()
threshold_encryption_time = end_threshold_encryption_time - start_threshold_encryption_time

print(f"Encrypted Thresholds for {len(encrypted_thresholds)} trees")
print(f"Threshold Encryption Time: {threshold_encryption_time:.4f} seconds")


Loaded cached thresholds for key 5f5e8a3fe2ba8ef9979983ebad66c7d7
Loaded cached thresholds for key 456f782dbde337bd96fb673941f9335a
Loaded cached thresholds for key 216770f64070844fddf2eadd8aa7b100
Loaded cached thresholds for key 6ce2504f1c23915e0062045b639b2c7f
Loaded cached thresholds for key c456d38cab06f85268dfa40620da8eee
Loaded cached thresholds for key e5af9c5870aeb77a3dd360340d768569
Loaded cached thresholds for key f161e54ba25699c70be43cb547b1b198
Loaded cached thresholds for key b29671a1bab218eb7d364a3831cc2b82
Loaded cached thresholds for key fcc93a3653c266ef4896d2c82e45310a
Loaded cached thresholds for key faff2ba978176f936b83c6c040388a62
Loaded cached thresholds for key eb760519812f42f551f02574521fb9e2
Loaded cached thresholds for key 50b2d950f6c5f4c092b1d6a14dde223a
Loaded cached thresholds for key 04060b42032e8fe29ad1b90964489c16
Loaded cached thresholds for key c58b1e17778a8f52c36710885394ac9e
Loaded cached thresholds for key 4c3a9c38aba1fcc0c847487c7789092c
Encrypted 

### Encrypt Full Dataset Per Tree Using OPE

This section defines two functions:

- `encrypt_image`: Encrypts a single image using a given OPE key.
- `encrypt_dataset_with_ope`: Encrypts the full dataset for each tree using its specific OPE key.

For each image, the encryption time is measured and printed. This allows for detailed per-image performance analysis.


In [None]:
# Function to encrypt a single image using a given OPE key
def encrypt_image(image, ope_key):
    return [ope_key.encrypt(int(pixel)) for pixel in image]

# Function to encrypt the full dataset separately for each OPE key
def encrypt_dataset_with_ope(X, ope_keys):
    encrypted_versions = []

    for idx, ope_key in enumerate(ope_keys):
        print(f"\nEncrypting dataset with OPE key {idx + 1}")
        encrypted_X = []

        for i, image in enumerate(tqdm(X, desc=f"Encrypting for Tree {idx + 1}"), start=1):
            # Encrypt a single image and time the process
            start_time = time.time()
            encrypted_image = encrypt_image(image, ope_key)
            encryption_time = time.time() - start_time

            print(f"{i}: Image Encryption Time: {encryption_time:.4f} sec")
            encrypted_X.append(encrypted_image)

        print(f"Number of encrypted images for tree {idx + 1}: {len(encrypted_X)}")
        encrypted_versions.append(np.array(encrypted_X))

    # Return a list of encrypted datasets (one per tree)
    return encrypted_versions


### PRE Decryption Function

This function decrypts a Proxy Re-Encrypted (PRE) ciphertext:

- Accepts a re-encrypted ciphertext, capsule, verified kfrags, and the voting secret key.
- Uses the NuCypher `decrypt_reencrypted` method.
- Measures and returns both the decrypted value and the time taken for the decryption.

In [None]:
# Decrypt a re-encrypted ciphertext using Proxy Re-Encryption (PRE)
def decrypt_pre(ciphertext, capsule, kfrags, vote_secret_key, delegating_pk):
    """
    Decrypts a PRE re-encrypted ciphertext and returns both the result and time taken.
    """
    start = time.perf_counter()

    # Extract raw capsule fragments from verified fragments
    cfrags = [vcfrag.cfrag for vcfrag in kfrags]

    # Perform re-encryption decryption
    decrypted = decrypt_reencrypted(
        receiving_sk=vote_secret_key,    # Final recipient's secret key
        delegating_pk=delegating_pk,     # Original public key that encrypted the label
        capsule=capsule,                 # Capsule associated with the ciphertext
        verified_cfrags=cfrags,          # Verified fragments enabling decryption
        ciphertext=ciphertext            # Encrypted label
    )

    end = time.perf_counter()
    decryption_time = end - start

    return decrypted, decryption_time


### Secure Majority Voting Using PRE

This function performs secure majority voting over encrypted class labels:

- Each encrypted vote is re-encrypted using a corresponding `kfrag`.
- The result is decrypted using the voting secret key and the delegating public key.
- The most common decrypted label is selected as the final classification.
- The function also tracks the total time spent on PRE decryption across all votes.


In [None]:
# Perform secure majority voting using PRE-decrypted labels
def majority_voting(encrypted_votes, capsules, tree_public_keys, kfrags_list, vote_secret_key):
    """
    Perform secure majority voting using PRE-decrypted labels.
    Returns the most common label and total decryption time.
    """
    decrypted_votes = []
    total_decryption_time = 0

    for i in range(len(encrypted_votes)):
        ct = encrypted_votes[i]
        cap = capsules[i]
        kfrag = kfrags_list[i][0]  # Use the first key fragment for this tree

        # Re-encrypt the capsule using PRE
        cfrag = reencrypt(cap, kfrag)
        verified_cfrag = VerifiedCapsuleFrag(cfrag)

        # Decrypt using the voting key and the tree's public key
        decrypted, dec_time = decrypt_pre(
            ct, cap, [verified_cfrag], vote_secret_key, tree_public_keys[i]
        )

        # Decode and store the decrypted label
        decrypted_votes.append(decrypted.decode())
        total_decryption_time += dec_time

    # Perform majority voting over the decrypted labels
    majority_label = Counter(decrypted_votes).most_common(1)[0][0]
    return majority_label, total_decryption_time


### Secure Classification of a Single Sample Using PRE

This function performs secure classification for a single encrypted input sample:

- Each decision tree is traversed using OPE-encrypted features and thresholds.
- The corresponding encrypted label at the reached leaf node is retrieved.
- Proxy Re-Encryption (PRE) is used to securely decrypt all leaf labels for voting.
- Measures time taken for traversal, label access, and majority voting.


In [None]:
# Secure classification for one encrypted sample using all trees
def secure_classify(model, encrypted_X_per_tree, encrypted_thresholds, encrypted_leaf_values,
                    leaf_capsules, tree_public_keys, kfrags_list, vote_secret_key):
    """
    Secure classification with PRE-based leaf encryption and voting.
    Timed version: measures traversal, label access, and voting time.
    """
    encrypted_votes = []  # Encrypted class labels from each tree
    capsules = []          # Capsules for corresponding encrypted labels

    traversal_time = 0
    label_access_time = 0
    voting_time = 0

    for tree_idx, tree in enumerate(model.estimators_):
        # Traverse the decision tree using encrypted feature thresholds
        start_traversal = time.perf_counter()
        node = 0
        tree_thresholds = encrypted_thresholds[tree_idx]
        encrypted_X = encrypted_X_per_tree[tree_idx]

        while tree.tree_.feature[node] != -2:
            feature_idx = tree.tree_.feature[node]
            encrypted_threshold = tree_thresholds[node]

            # Compare encrypted feature with encrypted threshold
            if encrypted_X[feature_idx] < encrypted_threshold:
                node = tree.tree_.children_left[node]
            else:
                node = tree.tree_.children_right[node]

        traversal_time += time.perf_counter() - start_traversal

        # Retrieve encrypted label and capsule from leaf node
        start_label = time.perf_counter()
        encrypted_label = encrypted_leaf_values[tree_idx][node]
        capsule = leaf_capsules[tree_idx][node]
        label_access_time += time.perf_counter() - start_label

        encrypted_votes.append(encrypted_label)
        capsules.append(capsule)

    # Perform secure majority voting using PRE
    start_voting = time.perf_counter()
    result, decryption_time = majority_voting(
        encrypted_votes,
        capsules,
        tree_public_keys,
        kfrags_list,
        vote_secret_key
    )
    voting_time += time.perf_counter() - start_voting

    return result, traversal_time, label_access_time, voting_time, decryption_time


### Secure Classification Over Full Encrypted Dataset

This function runs secure classification over the entire encrypted test set using PRE:

- Iterates through each sample and classifies it using `secure_classify`.
- Aggregates timing metrics across:
  - Tree traversal
  - Label access
  - Voting
  - PRE decryption
- Returns both the predictions and a detailed timing report.


In [None]:
# Classify the entire encrypted dataset securely using PRE voting
def secure_classify_dataset(model, X_encrypted_per_tree, encrypted_thresholds, encrypted_leaf_values,
                             leaf_capsules, tree_public_keys, kfrags_list, vote_secret_key):
    """
    Classify an encrypted dataset using PRE-encrypted labels and secure voting.
    Returns predictions and detailed timing information.
    """
    num_samples = len(X_encrypted_per_tree[0])
    predictions = []

    total_traversal = 0
    total_label_access = 0
    total_voting = 0
    total_decryption = 0
    total_time = 0

    start_total = time.perf_counter()

    for sample_idx in tqdm(range(num_samples), desc="Classifying Encrypted Test Samples"):
        # Extract the encrypted sample across all trees
        sample_per_tree = [X_encrypted_per_tree[tree_idx][sample_idx] for tree_idx in range(len(model.estimators_))]

        start = time.perf_counter()

        # Securely classify the current encrypted sample
        pred, t_traversal, t_label, t_vote, t_decrypt = secure_classify(
            model=model,
            encrypted_X_per_tree=sample_per_tree,
            encrypted_thresholds=encrypted_thresholds,
            encrypted_leaf_values=encrypted_leaf_values,
            leaf_capsules=leaf_capsules,
            tree_public_keys=tree_public_keys,
            kfrags_list=kfrags_list,
            vote_secret_key=vote_secret_key
        )

        end = time.perf_counter()

        # Store prediction and accumulate timing stats
        predictions.append(pred)
        total_traversal += t_traversal
        total_label_access += t_label
        total_voting += t_vote
        total_decryption += t_decrypt
        total_time += end - start

    end_total = time.perf_counter()
    overall_time = end_total - start_total

    # Prepare timing report
    timings = {
        "samples": num_samples,
        "classification_time": total_time,
        "overall_time": overall_time,
        "avg_per_sample": total_time / num_samples,
        "traversal_time": total_traversal,
        "label_access_time": total_label_access,
        "voting_time": total_voting,
        "decryption_time": total_decryption
    }

    return np.array(predictions), timings


### Generate kfrags and Perform Secure Classification

This section prepares and executes the full secure classification pipeline:

- **kfrags** are generated for each tree using its PRE private key and the global voting key.
- **Public keys** are listed for re-encryption reference.
- `secure_classify_dataset` is then called to classify the encrypted test set.
- The total classification time and predictions are printed.


In [152]:

# Generate kfrags for each tree
signers = [Signer(sk) for sk in pre_secret_keys]

kfrags_list = [
    generate_kfrags(delegating_sk=pre_secret_keys[i], 
                    receiving_pk=vote_public_key, 
                    signer=signers[i], 
                    threshold=1, 
                    shares=1, 
                    sign_delegating_key=True, 
                    sign_receiving_key=True)
    for i in range(len(pre_secret_keys))
]

start_time = time.time()
print("Performing Secure Classification with PRE...")

# Ensure the correct public key is passed for re-encryption
tree_public_keys_points = [key for key in pre_public_keys]

y_pred_encrypted, timing_report = secure_classify_dataset(
    model=clf_ope,
    X_encrypted_per_tree=X_test_encrypted_per_tree,
    encrypted_thresholds=encrypted_thresholds,
    encrypted_leaf_values=encrypted_leaf_values,
    leaf_capsules=leaf_capsules,
    tree_public_keys=tree_public_keys_points,  
    kfrags_list=kfrags_list,  
    vote_secret_key=vote_secret_key 
)


print("Predictions:", y_pred_encrypted)
print("True Labels:", y_test[:num_samples_testing])

classification_time = time.time() - start_time
print(f"Secure Classification Time: {classification_time:.4f} seconds")


Performing Secure Classification with PRE...


Classifying Encrypted Test Samples: 100%|██████████| 14000/14000 [22:33<00:00, 10.34it/s]

Predictions: ['8' '4' '8' ... '2' '7' '1']
True Labels: [8 4 8 ... 2 7 1]
Secure Classification Time: 1353.6973 seconds





### Accuracy Evaluation

This section evaluates the performance of the secure Random Forest classifier.

- A sanity check ensures that the number of predicted labels matches the number of ground truth labels.
- Predictions are converted to integers (as decrypted labels are initially strings).
- Accuracy is computed using `sklearn`’s `accuracy_score` to quantify how well the encrypted classification pipeline performs compared to plaintext ground truth.

This metric validates the correctness and practicality of the privacy-preserving classification approach.

In [153]:
# Sanity check: make sure prediction length matches test labels
print("Length of predictions:", len(y_pred_encrypted))
print("Length of ground truth:", len(y_test[:num_samples_testing]))

if len(y_pred_encrypted) == len(y_test[:num_samples_testing]):
    # Convert to integers for comparison and accuracy calculation
    y_pred_encrypted = [int(p) for p in y_pred_encrypted]

    # Calculate accuracy
    secure_accuracy = accuracy_score(y_test[:num_samples_testing], y_pred_encrypted)
    print(f"Secure Random Forest Accuracy on Encrypted Dataset: {(secure_accuracy):.2f}")
else:
    print("Error: Prediction length does not match test labels. Cannot compute accuracy.")


Length of predictions: 14000
Length of ground truth: 14000
Secure Random Forest Accuracy on Encrypted Dataset: 0.95


### Execution Time Analysis and Throughput Reporting

This block generates a full performance report for the secure classification pipeline:

- Recalculates the total execution time using all major pipeline components.
- Computes percentage contributions from each stage (loading, encryption, training, classification).
- Breaks down secure classification timing using metrics from `timing_report`.
- Computes throughput in terms of samples per second for:
  - Test data encryption + classification
  - Full secure pipeline (label + threshold encryption included)
- Prints final accuracy and key usage summary for secure inference.

This block is ideal for performance evaluation, reporting, and comparative benchmarking.

In [154]:
# Recompute actual total from all components to ensure percentages are meaningful
effective_total_time = (
    dataset_load_time +
    test_data_encryption_time +
    label_encryption_time +
    training_time +
    threshold_encryption_time +
    classification_time
)

# Compute percentages for each stage
dataset_load_percentage = (dataset_load_time / effective_total_time) * 100
test_data_encryption_percentage = (test_data_encryption_time / effective_total_time) * 100
label_encryption_percentage = (label_encryption_time / effective_total_time) * 100
rf_training_percentage = (training_time / effective_total_time) * 100
threshold_encryption_percentage = (threshold_encryption_time / effective_total_time) * 100
classification_percentage = (classification_time / effective_total_time) * 100

# Compute throughput-specific percentages
total_throughput_time = test_data_encryption_time + classification_time
encryption_percentage_throughput = (test_data_encryption_time / total_throughput_time) * 100
classification_percentage_throughput = (classification_time / total_throughput_time) * 100

# Compute full secure pipeline time and throughput
total_secure_pipeline_time = (
    test_data_encryption_time +
    label_encryption_time +
    threshold_encryption_time +
    classification_time
)
throughput_secure_pipeline = len(X_test) / total_secure_pipeline_time

# Print timing breakdown
print("\n===== Execution Time Summary =====")
print(f"Total Execution Time: {effective_total_time:.4f} seconds")
print(f"Dataset Load Time: {dataset_load_time:.4f} seconds ({dataset_load_percentage:.2f}%)")
print(f"Test Data Encryption Time: {test_data_encryption_time:.4f} seconds ({test_data_encryption_percentage:.2f}%)")
print(f"Label Encryption Time: {label_encryption_time:.4f} seconds ({label_encryption_percentage:.2f}%)")
print(f"Random Forest Training Time: {training_time:.4f} seconds ({rf_training_percentage:.2f}%)")
print(f"Threshold Encryption Time: {threshold_encryption_time:.4f} seconds ({threshold_encryption_percentage:.2f}%)")
print(f"Secure Classification Time: {classification_time:.4f} seconds ({classification_percentage:.2f}%)")

# Breakdown of secure classification if available
if "timing_report" in locals():
    print("\n----- Secure Classification Breakdown -----")
    print(f"Tree Traversal Time: {timing_report['traversal_time']:.4f} seconds")
    print(f"Label Access Time: {timing_report['label_access_time']:.4f} seconds")
    print(f"Voting Time: {timing_report['voting_time']:.4f} seconds")
    print(f"  - PRE Decryption Time: {timing_report['decryption_time']:.4f} seconds")
    print(f"Average Time per Sample: {timing_report['avg_per_sample']:.4f} seconds")

# Print final results
print("\n===== Secure Classification Results =====")
print(f"Secure Random Forest Accuracy on Encrypted MNIST: {secure_accuracy:.4f}")
print(f"Number of Decision Trees (num_estimators): {num_estimators}")
print(f"Number of Images Used for Training: {len(X_train)}")
print(f"Number of Images Used for Testing: {len(X_test)}")

# Throughput metrics for inference-only pipeline
print("\n===== Throughput (Inference Only) =====")
print(f"Total Throughput Time (Encryption + Classification): {total_throughput_time:.4f} seconds")
print(f"Throughput: {len(X_test) / total_throughput_time:.2f} samples/second")
print(f"Percentage of Test Data Encryption Time vs Throughput: {encryption_percentage_throughput:.2f}%")
print(f"Percentage of Classification Time vs Throughput: {classification_percentage_throughput:.2f}%")

# Throughput for full secure pipeline (label + threshold + classification)
print("\n===== Full Secure Pipeline Throughput =====")
print(f"Total Time (Test Data + Label + Threshold + Classification): {total_secure_pipeline_time:.4f} seconds")
print(f"Throughput (Secure Pipeline): {throughput_secure_pipeline:.2f} samples/second")

# Key usage summary
print("\n===== Key Management Summary =====")
print(f"Number of Decision Trees: {num_estimators}")
print(f"Number of OPE Keys Used: {len(ope_keys)}")
print(f"Number of PRE Public Keys Used: {len(pre_public_keys)}")
print(f"Key Assignment Strategy: One unique OPE and PRE key per tree (Multi-Key Encryption)")

# Re-print timing breakdown for completeness
print("\n===== Classification Timing Summary =====")
for key, value in timing_report.items():
    print(f"{key.replace('_', ' ').capitalize()}: {value:.4f} seconds" if isinstance(value, float) else f"{key}: {value}")



===== Execution Time Summary =====
Total Execution Time: 1491.0250 seconds
Dataset Load Time: 4.4111 seconds (0.30%)
Test Data Encryption Time: 0.8931 seconds (0.06%)
Label Encryption Time: 126.3462 seconds (8.47%)
Random Forest Training Time: 5.6395 seconds (0.38%)
Threshold Encryption Time: 0.0378 seconds (0.00%)
Secure Classification Time: 1353.6973 seconds (90.79%)

----- Secure Classification Breakdown -----
Tree Traversal Time: 10.8025 seconds
Label Access Time: 0.5482 seconds
Voting Time: 1332.9901 seconds
  - PRE Decryption Time: 748.6327 seconds
Average Time per Sample: 0.0960 seconds

===== Secure Classification Results =====
Secure Random Forest Accuracy on Encrypted MNIST: 0.9549
Number of Decision Trees (num_estimators): 15
Number of Images Used for Training: 56000
Number of Images Used for Testing: 14000

===== Throughput (Inference Only) =====
Total Throughput Time (Encryption + Classification): 1354.5904 seconds
Throughput: 10.34 samples/second
Percentage of Test Data 