In [176]:
import time
import numpy as np
from collections import Counter
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.datasets import fetch_openml
from pyope.ope import OPE, ValueRange
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
import os
from tqdm.auto import tqdm 
import matplotlib.pyplot as plt
import json
from umbral import encrypt, decrypt_reencrypted, reencrypt, generate_kfrags
from umbral.keys import SecretKey, PublicKey
from umbral.signing import Signer
from collections import Counter
from umbral.capsule_frag import VerifiedCapsuleFrag
from joblib import Parallel, delayed
from concurrent.futures import ThreadPoolExecutor, as_completed



### Key Initialization and Loading

This section loads all necessary cryptographic keys from a `keys.json` file:

- **OPE Keys**: Loaded for each decision tree, used to encrypt feature thresholds.
- **PRE Keys**: Loaded for a fixed number of groups, used to encrypt and re-encrypt class labels.
- **Voting Key**: A global key used for decrypting the final prediction after proxy re-encryption.

The OPE keys are used per-tree, while PRE keys are grouped across trees to enable multi-key encryption. The voting key is used in the final decryption phase after

In [177]:
# Load keys from JSON file
with open("keys.json", "r") as f:
    key_data = json.load(f)

in_range = ValueRange(0, 2550)
out_range = ValueRange(0, 2**32 - 1)

num_estimators = 1  # Number of trees in the RF
num_keys = 1    # Number of PRE keys/groups

if num_keys > num_estimators:
    raise ValueError("The number of keys cannot exceed the number of estimators.")

# Index used to choose the keys from the JSON file
start_idx = 0
end_idx_ope = start_idx + num_estimators
end_idx_pre = start_idx + num_keys  # Only take num_keys PRE keys

# Load OPE keys (each tree still uses a unique OPE key)
ope_keys = [OPE(key.encode(), in_range=in_range, out_range=out_range)
            for key in key_data["ope_keys"][start_idx:end_idx_ope]]

# Load PRE keys (only num_keys for grouped encryption)
pre_secret_keys = [
    SecretKey.from_bytes(bytes.fromhex(k))
    for k in key_data["pre_keys"]["private_keys"][start_idx:end_idx_pre]
]
pre_public_keys = [
    PublicKey.from_bytes(bytes.fromhex(k))
    for k in key_data["pre_keys"]["public_keys"][start_idx:end_idx_pre]
]

# Load voting key
vote_secret_key = SecretKey.from_bytes(bytes.fromhex(key_data["vote_key"]["private"]))
vote_public_key = PublicKey.from_bytes(bytes.fromhex(key_data["vote_key"]["public"]))


### Key Management and Tree Grouping

This section performs the following tasks:

- Groups decision trees evenly across multiple PRE keys.
- Generates Proxy Re-Encryption (PRE) key pairs for each group.
- Generates kfrags (key fragments) that allow re-encryption from each group key to a common voting key.

The grouping ensures that each PRE key encrypts approximately the same number of decision trees.

In [178]:
# Divide the total number of estimators into groups based on the number of PRE keys
def generate_tree_groups(num_estimators, num_keys):
    base = num_estimators // num_keys            # Minimum number of trees per group
    remainder = num_estimators % num_keys        # Distribute the remaining trees

    tree_groups = []
    start = 0
    for i in range(num_keys):
        group_size = base + (1 if i < remainder else 0)  # Distribute remainder across first few groups
        group = list(range(start, start + group_size))   # Assign tree indices to this group
        tree_groups.append(group)
        start += group_size
    return tree_groups

# Generate tree-to-key groups based on the number of estimators and keys
tree_groups = generate_tree_groups(num_estimators, num_keys)


# Generate a list of random PRE secret keys and their corresponding public keys
def generate_group_keys(num_keys):
    group_secret_keys = [SecretKey.random() for _ in range(num_keys)]
    group_public_keys = [sk.public_key() for sk in group_secret_keys]
    return group_secret_keys, group_public_keys


# Generate kfrags (key fragments) for each group to allow re-encryption to the voting key
def generate_kfrags_for_groups(group_secret_keys, vote_public_key):
    group_kfrags = []
    for sk in group_secret_keys:
        signer = Signer(sk)  # Signer is needed to verify kfrags
        kfrags = generate_kfrags(
            delegating_sk=sk,
            receiving_pk=vote_public_key,
            signer=signer,
            threshold=1,
            shares=1,
            sign_delegating_key=True,
            sign_receiving_key=True
        )
        group_kfrags.append(kfrags)
    return group_kfrags


### PRE Decryption Function

This function handles the decryption of a Proxy Re-Encrypted (PRE) ciphertext. It performs the following steps:

- Extracts the original `cfrags` from verified `kfrags`.
- Uses the PRE `decrypt_reencrypted` function with the voting secret key and delegation public key.
- Measures and returns the decryption time alongside the decrypted result.

In [179]:
# Decrypt a re-encrypted ciphertext using PRE and measure the decryption time
def decrypt_pre(ciphertext, capsule, kfrags, vote_secret_key, delegating_pk):
    start = time.perf_counter()  # Start timing the decryption

    # Extract raw capsule fragments from the verified capsule fragments
    cfrags = [vcfrag.cfrag for vcfrag in kfrags]

    # Perform decryption using the PRE library's re-encryption decryption function
    decrypted = decrypt_reencrypted(
        receiving_sk=vote_secret_key,   # Final secret key that receives the re-encrypted data
        delegating_pk=delegating_pk,    # Public key of the delegator
        capsule=capsule,                # Capsule containing encrypted metadata
        verified_cfrags=cfrags,         # Verified fragments used to reconstruct access
        ciphertext=ciphertext           # Encrypted label
    )

    end = time.perf_counter()  # End timing
    decryption_time = end - start

    return decrypted, decryption_time  # Return both result and time taken


### Grouped Majority Voting with PRE

This function performs secure majority voting using grouped encrypted predictions. For each group:
- It determines the local majority vote.
- Re-encrypts the vote using a kfrag.
- Decrypts the re-encrypted vote using the voting secret key.

The final result is computed by majority voting over all group-level decrypted predictions.

In [180]:
# Perform grouped majority voting using PRE-decrypted labels
def grouped_majority_voting_with_pre(grouped_votes, grouped_capsules, group_keys, group_kfrags, vote_secret_key):
    final_votes = []

    for i in range(len(grouped_votes)):
        # Get the most common vote in the current group
        group_vote = Counter(grouped_votes[i]).most_common(1)[0][0]

        # Retrieve the capsule and kfrag for the group
        capsule = grouped_capsules[i][0]
        kfrag = group_kfrags[i][0]

        # Re-encrypt the capsule using the provided kfrag
        cfrag = reencrypt(capsule, kfrag)
        verified_cfrag = VerifiedCapsuleFrag(cfrag)

        # Decrypt the re-encrypted ciphertext
        decrypted, _ = decrypt_pre(
            ciphertext=group_vote,
            capsule=capsule,
            kfrags=[verified_cfrag],
            vote_secret_key=vote_secret_key,
            delegating_pk=group_keys[i]
        )

        # Append the decoded decrypted vote
        final_votes.append(decrypted.decode())

    # Return the overall majority vote from all groups
    return Counter(final_votes).most_common(1)[0][0]


### Dataset Loading and Timing

This section loads the MNIST dataset using `fetch_openml` and measures the time required to complete the operation. The dataset will later be used for training and testing the Random Forest classifier.


In [181]:
start_total_time = time.perf_counter()  # Start total execution time

# Load Dataset Used for Testing and Training
start_dataset_load_time = time.perf_counter()

# Fetch MNIST dataset from OpenML and convert to NumPy arrays
mnist = fetch_openml("mnist_784", version=1, as_frame=False)
X, y = mnist.data.astype("float32"), mnist.target.astype("int")

end_dataset_load_time = time.perf_counter()

# Calculate dataset load time
dataset_load_time = end_dataset_load_time - start_dataset_load_time
print(f"Dataset Loading Time: {dataset_load_time:.4f} seconds")

Dataset Loading Time: 4.6908 seconds


### Dataset Scaling

This section normalizes the dataset to ensure all pixel values fall within the required encryption range. 
The raw MNIST images are first scaled to the [0, 1] range, then rescaled to the [0, 2550] integer range to match the input range expected by the OPE encryption scheme.


In [182]:
# Normalize pixel values to [0, 1]
X = X / 255.0

if X.max() <= 1:
    # Rescale dataset from original range to [0, 255]
    X = (X - X.min()) / (X.max() - X.min()) * 255
    X = (X * 10).astype(int)  # Scale to 0–2550
else:
    X = X.astype(int)

### Dataset Splitting

The dataset is split into training and testing sets. The total number of samples in each set is recorded. 
These values are used to control the scope of encryption and evaluation throughout the experiment.

In [183]:
# Split dataset
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

#Get the number of samples to be encrypted
num_samples_training = len(X_train)
num_samples_testing = len(X_test)

print(f"Number of training samples: {num_samples_training}")
print(f"Number of testing samples: {num_samples_testing}")

Number of training samples: 56000
Number of testing samples: 14000


### Grouped Dataset Encryption with OPE

This function encrypts the input dataset `X_data` using Order-Preserving Encryption (OPE) on a per-group basis. Each group uses a shared OPE key, and encryption is parallelized across chunks of rows to improve performance.

If cached encrypted data exists and matches the expected shape, it is loaded from disk. Otherwise, the data is encrypted in chunks, stored, and reused for all trees in the corresponding group.


In [184]:
# Encrypt a single chunk of rows using OPE
def encrypt_chunk(chunk, ope):
    return [[ope.encrypt(int(val)) for val in row] for row in chunk]

# Encrypt dataset grouped by shared OPE keys across decision tree groups
def load_or_encrypt_dataset_grouped(X_data, tree_groups, ope_keys, key_data, chunk_size=500):
    encrypted_versions = []
    encryption_times = []

    for group_idx, tree_indices in enumerate(tree_groups):
        key_name = key_data["ope_keys"][group_idx]
        ope = ope_keys[group_idx]
        file_path = f"Encrypted_Dataset/X_test_encrypted_{key_name}.npy"

        # Use cached encrypted dataset if available and matches input size
        if os.path.exists(file_path):
            encrypted_X = np.load(file_path)
            if encrypted_X.shape[0] != X_data.shape[0]:
                print(f"Mismatch shape in cached file for key: {key_name}. Re-encrypting...")
                os.remove(file_path)
            else:
                print(f"[Group {group_idx + 1}] Loaded cached data for key: {key_name[:8]}...")
                encryption_times.append(0)
                for _ in tree_indices:
                    encrypted_versions.append(encrypted_X)
                continue

        print(f"[Group {group_idx + 1}] Encrypting test data with key: {key_name[:8]}...")

        start = time.perf_counter()

        # Split dataset into smaller chunks for parallel encryption
        num_samples = X_data.shape[0]
        chunks = [X_data[i:i + chunk_size] for i in range(0, num_samples, chunk_size)]

        encrypted_chunks = []
        for chunk in tqdm(chunks, desc=f"Encrypting (Group {group_idx + 1})", unit="chunk"):
            encrypted = Parallel(n_jobs=-1)(
                delayed(encrypt_chunk)([row], ope) for row in chunk
            )
            # Flatten nested lists to a single list of rows
            encrypted_chunks.extend([item[0] for item in encrypted])

        encrypted_X = np.array(encrypted_chunks)

        end = time.perf_counter()

        # Save encrypted data for future reuse
        os.makedirs("Encrypted_Dataset", exist_ok=True)
        np.save(file_path, encrypted_X)
        print(f"[Group {group_idx + 1}] Saved encrypted data to: {file_path}")
        encryption_times.append(end - start)

        # Share the same encrypted dataset with all trees in the current group
        for _ in tree_indices:
            encrypted_versions.append(encrypted_X)

    return encrypted_versions, encryption_times


### Encrypt Test Dataset Using Grouped OPE

This block performs encryption of the test dataset using the grouped OPE approach. It measures the total encryption time and prints the shape of the encrypted dataset. If cached versions are used, the encryption time is reported as zero.


In [185]:
# Display the shape of the test dataset
print(f"X_test size: {X_test.shape}")

# Start timing the encryption process
start_test_data_encryption_time = time.perf_counter()

# Encrypt or load encrypted test dataset for each group of trees
X_test_encrypted_per_tree, encryption_times = load_or_encrypt_dataset_grouped(
    X_test, tree_groups, ope_keys, key_data, chunk_size=50
)

end_test_data_encryption_time = time.perf_counter()

# Determine encryption time: use total duration if loaded from cache, or sum of all group encryption times
test_data_encryption_time = (
    end_test_data_encryption_time - start_test_data_encryption_time
    if all(t == 0 for t in encryption_times)
    else sum(encryption_times)
)

# Print final encryption time
print(f"\nTotal Dataset Encryption Time (Grouped OPE + Chunked): {test_data_encryption_time:.4f} seconds")


X_test size: (14000, 784)
[Group 1] Loaded cached data for key: 5f5e8a3f...

Total Dataset Encryption Time (Grouped OPE + Chunked): 0.0413 seconds


### Random Forest Training

A Random Forest classifier is trained on the plaintext training data to construct the model used for encrypted inference.
The number of decision trees and maximum depth are defined in the initialization. 

The training time is measured to compare against the full encryption and classification pipeline later.

In [186]:
start_training_time = time.perf_counter()

clf_ope = RandomForestClassifier(n_estimators=num_estimators, max_depth=20, random_state=42, min_samples_split=2)
clf_ope.fit(X_train, y_train)

end_training_time = time.perf_counter()
training_time = end_training_time - start_training_time

print(f"Random Forest Training Time: {training_time:.4f} seconds")

Random Forest Training Time: 0.4559 seconds


### Label Encryption with Proxy Re-Encryption (PRE)

This section encrypts the predicted class labels at the leaf nodes of each decision tree using Proxy Re-Encryption (PRE). It consists of two parts:

- `encrypt_single_tree`: Encrypts the labels of a single decision tree using a shared group PRE public key.
- `encrypt_tree_labels_with_pre`: Runs the label encryption in parallel across all trees in the Random Forest using multithreading.

The result includes encrypted labels and capsules for each tree, along with the total encryption time.


In [187]:
# Encrypt the leaf node labels of a single decision tree using a given PRE key
def encrypt_single_tree(tree_idx, feature, value, group_idx, pre_key_bytes, max_labels_per_tree=None):
    tree_values = {}
    tree_capsules = {}

    # Convert the PRE public key from bytes
    pre_key = PublicKey.from_bytes(bytes.fromhex(pre_key_bytes))

    # Identify leaf nodes in the decision tree
    leaf_nodes = [i for i in range(len(feature)) if feature[i] == -2]
    if max_labels_per_tree is not None:
        leaf_nodes = leaf_nodes[:max_labels_per_tree]  # Limit encryption to a subset if specified

    # Encrypt the predicted label at each leaf node
    for node in leaf_nodes:
        label = str(np.argmax(value[node][0]))
        capsule, ciphertext = encrypt(pre_key, label.encode())
        tree_values[node] = ciphertext
        tree_capsules[node] = capsule

    return tree_values, tree_capsules


# Encrypt the labels of all trees using PRE, in parallel
def encrypt_tree_labels_with_pre(clf, pre_public_keys, tree_groups, max_labels_per_tree=None):
    start_label_encryption_time = time.perf_counter()
    print("Encrypting leaf labels using PRE (Parallel)...")

    jobs = []

    # Prepare encryption jobs for each tree based on its group
    for idx, tree in enumerate(clf.estimators_):
        group_idx = next(i for i, group in enumerate(tree_groups) if idx in group)
        pre_key_hex = bytes(pre_public_keys[group_idx]).hex()

        feature = tree.tree_.feature.tolist()
        value = tree.tree_.value.tolist()

        jobs.append((idx, feature, value, group_idx, pre_key_hex))

    # Encrypt labels in parallel using multithreading
    results = Parallel(n_jobs=-1, backend="threading")(
        delayed(encrypt_single_tree)(idx, feature, value, group_idx, pre_key_hex, max_labels_per_tree)
        for idx, feature, value, group_idx, pre_key_hex in jobs
    )

    encrypted_leaf_values, leaf_capsules = zip(*results)

    end_label_encryption_time = time.perf_counter()
    label_encryption_time = end_label_encryption_time - start_label_encryption_time

    print(f"PRE Label Encryption completed for {len(clf.estimators_)} trees.")
    print(f"Total PRE Encryption Time (Parallel): {label_encryption_time:.4f} seconds")

    return list(encrypted_leaf_values), list(leaf_capsules), label_encryption_time


### PRE Label Encryption with ThreadPoolExecutor

This block performs parallel label encryption for each decision tree using `ThreadPoolExecutor`. The encrypted labels are computed using the `encrypt_single_tree` function, which uses a group's PRE public key. This approach is a multithreaded alternative to joblib and is especially useful when working with shared memory operations.


In [188]:
# Prepare encryption jobs: one per tree with associated group and PRE key
jobs = []
for idx, tree in enumerate(clf_ope.estimators_):
    group_idx = next(i for i, group in enumerate(tree_groups) if idx in group)
    pre_key_hex = bytes(pre_public_keys[group_idx]).hex()

    feature = tree.tree_.feature.tolist()
    value = tree.tree_.value.tolist()

    jobs.append((idx, feature, value, group_idx, pre_key_hex))

# Start timing the encryption process
start_label_encryption_time = time.perf_counter()

results = []

# Run encryption tasks using ThreadPoolExecutor
with ThreadPoolExecutor() as executor:
    futures = [
        executor.submit(encrypt_single_tree, idx, feature, value, group_idx, pre_key_hex)
        for idx, feature, value, group_idx, pre_key_hex in jobs
    ]

    # Collect results as they complete
    for future in tqdm(as_completed(futures), total=len(futures), desc="Encrypting Trees with PRE", unit="tree"):
        results.append(future.result())

end_label_encryption_time = time.perf_counter()
label_encryption_time = end_label_encryption_time - start_label_encryption_time

# Report encryption summary
print(f"PRE Label Encryption completed for {len(results)} trees.")
print(f"Total PRE Encryption Time (Threaded): {label_encryption_time:.4f} seconds")


Encrypting Trees with PRE:   0%|          | 0/1 [00:00<?, ?tree/s]

PRE Label Encryption completed for 1 trees.
Total PRE Encryption Time (Threaded): 8.5723 seconds


### Threshold Encryption with OPE (Parallel)

This section encrypts the split thresholds of each decision tree using Order-Preserving Encryption (OPE). It:

- Checks if encrypted thresholds are already cached for each tree.
- If not cached, encrypts and stores them.
- Executes the process in parallel across all decision trees using `joblib`.

The encrypted thresholds are saved per-tree to avoid redundant encryption on future runs.


In [189]:
# Make sure the output directory exists
os.makedirs("Encrypted_Thresholds", exist_ok=True)

start_threshold_encryption_time = time.perf_counter()

# Encrypt thresholds for a single tree and save to file
def encrypt_thresholds_for_tree(idx, threshold_list, ope, key_name):
    file_path = f"Encrypted_Thresholds/encrypted_thresholds_{key_name}.npy"

    # If cached file exists, load it
    if os.path.exists(file_path):
        thresholds = np.load(file_path, allow_pickle=True).tolist()
        print(f"Loaded cached thresholds for key {key_name[:8]}...")
        return thresholds
    else:
        print(f"Encrypting thresholds for key {key_name[:8]}...")

        # Encrypt each threshold (skip value -2, which marks a leaf)
        tree_thresholds = [
            ope.encrypt(int(th)) if th != -2 else None
            for th in threshold_list
        ]

        # Save encrypted thresholds
        np.save(file_path, tree_thresholds)
        print(f"Saved thresholds to {file_path}")
        return tree_thresholds

# Prepare data for parallel execution: (tree index, thresholds, OPE key, key name)
tree_threshold_data = [
    (idx, tree.tree_.threshold.tolist(), ope_keys[idx], key_data["ope_keys"][idx])
    for idx, tree in enumerate(clf_ope.estimators_)
]

# Encrypt all thresholds in parallel
encrypted_thresholds = Parallel(n_jobs=-1)(
    delayed(encrypt_thresholds_for_tree)(idx, threshold_list, ope, key_name)
    for idx, threshold_list, ope, key_name in tree_threshold_data
)

end_threshold_encryption_time = time.perf_counter()
threshold_encryption_time = end_threshold_encryption_time - start_threshold_encryption_time

print(f"\nEncrypted Thresholds for {len(encrypted_thresholds)} trees")
print(f"Threshold Encryption Time (Parallel): {threshold_encryption_time:.4f} seconds")



Encrypted Thresholds for 1 trees
Threshold Encryption Time (Parallel): 0.9041 seconds


### Dataset Encryption with OPE Per Tree

This block encrypts the input dataset using Order-Preserving Encryption (OPE), generating one encrypted version of the dataset for each OPE key (typically one per decision tree).

- `encrypt_image`: Encrypts a single image (row of pixel values).
- `encrypt_dataset_with_ope`: Iterates over all OPE keys and applies encryption to the entire dataset, storing each result for later use in tree-specific classification.


In [190]:
# Encrypt a single image using an OPE key
def encrypt_image(image, ope_key):
    return [ope_key.encrypt(int(pixel)) for pixel in image]

# Encrypt the entire dataset once for each OPE key (e.g., one per decision tree)
def encrypt_dataset_with_ope(X, ope_keys):
    encrypted_versions = []

    for idx, ope_key in enumerate(ope_keys):
        print(f"\nEncrypting dataset for Tree {idx + 1} using OPE Key {idx + 1}")
        start_tree_time = time.time()

        # Encrypt each image in the dataset using the current OPE key
        encrypted_X = [
            encrypt_image(image, ope_key)
            for image in tqdm(X, desc=f"Encrypting for Tree {idx + 1}")
        ]

        total_time = time.time() - start_tree_time
        print(f"Done: {len(encrypted_X)} images encrypted for Tree {idx + 1} in {total_time:.2f} sec")

        encrypted_versions.append(np.array(encrypted_X))

    return encrypted_versions


### Grouped Majority Voting with PRE

This function performs majority voting across encrypted prediction groups, with added error handling:

- It identifies the majority vote in each group.
- Re-encrypts and decrypts the vote using PRE.
- Catches and logs any errors during re-encryption, decryption, or decoding.

If decryption or decoding fails, a fallback label of `'0'` is used to ensure continuity in the final vote.


In [191]:
# Perform grouped majority voting using PRE with error handling for decoding and decryption
def grouped_majority_voting_with_pre(grouped_votes, grouped_capsules, group_keys, group_kfrags, vote_secret_key):
    final_votes = []

    for i in range(len(grouped_votes)):
        try:
            # Get the majority vote within the current group
            group_vote = Counter(grouped_votes[i]).most_common(1)[0][0]
            capsule = grouped_capsules[i][0]
            kfrag = group_kfrags[i][0]

            # Re-encrypt the capsule and verify
            cfrag = reencrypt(capsule, kfrag)
            verified_cfrag = VerifiedCapsuleFrag(cfrag)

            # Decrypt the re-encrypted ciphertext
            decrypted, _ = decrypt_pre(
                ciphertext=group_vote,
                capsule=capsule,
                kfrags=[verified_cfrag],
                vote_secret_key=vote_secret_key,
                delegating_pk=group_keys[i]
            )

            # Try decoding the decrypted label; fallback to "0" if decoding fails
            try:
                label = decrypted.decode()
            except Exception as e:
                print(f"[WARNING] Decode failed for group {i}, using fallback label '0'. Error: {e}")
                label = "0"

            final_votes.append(label)

        except Exception as e:
            # Catch any error during voting for this group and fallback
            print(f"[ERROR] Failed voting for group {i}: {e}")
            final_votes.append("0")

    # Perform overall majority voting from all group results
    return Counter(final_votes).most_common(1)[0][0]


### Secure Classification with Encrypted Random Forest

This function performs secure classification using a multi-key encrypted Random Forest model. It:

- Traverses each tree to find the leaf node for the input sample.
- Retrieves the encrypted label and associated capsule.
- Groups votes by encryption key for grouped decryption.
- Performs secure grouped voting using Proxy Re-Encryption (PRE).
- Measures and returns traversal time, label access time, voting time, and decryption time (if tracked).

This function is designed for performance evaluation under full encryption.


In [196]:
# Perform secure classification using PRE-encrypted labels and grouped voting
def secure_classify(model, encrypted_X_per_tree, encrypted_thresholds, encrypted_leaf_values,
                    leaf_capsules, tree_public_keys, kfrags_list, vote_secret_key, tree_groups):
    
    traversal_time = 0
    label_access_time = 0
    voting_time = 0

    encrypted_votes = []
    capsules = []

    # Traverse all trees and collect encrypted labels and capsules
    for tree_idx, tree in enumerate(model.estimators_):
        start_traversal = time.perf_counter()
        node = 0
        tree_thresholds = encrypted_thresholds[tree_idx]
        encrypted_X = encrypted_X_per_tree[tree_idx]

        # Traverse the tree using encrypted thresholds
        while tree.tree_.feature[node] != -2:
            feature_idx = tree.tree_.feature[node]
            encrypted_threshold = tree_thresholds[node]

            if encrypted_X[feature_idx] < encrypted_threshold:
                node = tree.tree_.children_left[node]
            else:
                node = tree.tree_.children_right[node]

        traversal_time += time.perf_counter() - start_traversal

        # Retrieve the encrypted label and capsule for the reached leaf node
        start_label = time.perf_counter()
        encrypted_label = encrypted_leaf_values[tree_idx][node]
        capsule = leaf_capsules[tree_idx][node]
        label_access_time += time.perf_counter() - start_label

        encrypted_votes.append(encrypted_label)
        capsules.append(capsule)

    # Group votes and capsules by encryption group (based on PRE keys)
    grouped_votes = []
    grouped_capsules = []
    group_keys = []
    group_kfrags = []

    for group_idx, group in enumerate(tree_groups):
        group_votes = [encrypted_votes[i] for i in group]
        group_caps = [capsules[i] for i in group]

        grouped_votes.append(group_votes)
        grouped_capsules.append(group_caps)
        group_keys.append(tree_public_keys[group_idx])
        group_kfrags.append(kfrags_list[group_idx])

    # Perform grouped majority voting and measure its time
    start_voting = time.perf_counter()
    result = grouped_majority_voting_with_pre(
        grouped_votes,
        grouped_capsules,
        group_keys,
        group_kfrags,
        vote_secret_key
    )
    voting_time += time.perf_counter() - start_voting


    return result, traversal_time, label_access_time, voting_time


### Secure Classification for a Single Encrypted Sample

This function performs secure classification on a single encrypted input sample using the multi-key Random Forest model.

Steps involved:
- Retrieves the encrypted sample for each tree.
- Passes the sample through the `secure_classify` function for tree traversal, label decryption, and grouped voting.
- Measures the total time taken and returns timing details for:
  - Tree traversal
  - Label access
  - Grouped voting
  - Total execution


In [197]:
# Secure classification for a single encrypted sample
def secure_classify_sample(sample_idx, model, X_encrypted_per_tree, encrypted_thresholds,
                           encrypted_leaf_values, leaf_capsules,
                           tree_public_keys, kfrags_list, vote_secret_key, tree_groups):
    # Get the encrypted sample for each tree
    sample_per_tree = [X_encrypted_per_tree[tree_idx][sample_idx] for tree_idx in range(len(model.estimators_))]

    start = time.perf_counter()

    # Perform secure classification
    pred, t_traversal, t_label, t_vote = secure_classify(
        model=model,
        encrypted_X_per_tree=sample_per_tree,
        encrypted_thresholds=encrypted_thresholds,
        encrypted_leaf_values=encrypted_leaf_values,
        leaf_capsules=leaf_capsules,
        tree_public_keys=tree_public_keys,
        kfrags_list=kfrags_list,
        vote_secret_key=vote_secret_key,
        tree_groups=tree_groups
    )

    end = time.perf_counter()
    total_time = end - start

    return pred, t_traversal, t_label, t_vote, total_time


### Secure Classification for Full Dataset (Parallel)

This function performs secure classification for an entire encrypted dataset using a multi-key Random Forest model with Proxy Re-Encryption.

- It classifies each sample in parallel using `secure_classify_sample`.
- Aggregates timing metrics for:
  - Tree traversal
  - Label access
  - Grouped voting
  - Total time
- Returns both the predicted labels and a dictionary of performance metrics.


In [198]:
# Securely classify an entire encrypted dataset in parallel
def secure_classify_dataset(model, X_encrypted_per_tree, encrypted_thresholds, encrypted_leaf_values,
                             leaf_capsules, tree_public_keys, kfrags_list, vote_secret_key, tree_groups):
    num_samples = len(X_encrypted_per_tree[0])
    print("Starting parallel classification...")

    # Classify each sample in parallel
    results = Parallel(n_jobs=-1, prefer="threads")(
        delayed(secure_classify_sample)(
            sample_idx, model, X_encrypted_per_tree, encrypted_thresholds,
            encrypted_leaf_values, leaf_capsules,
            tree_public_keys, kfrags_list, vote_secret_key, tree_groups
        ) for sample_idx in tqdm(range(num_samples))
    )

    # Initialize accumulators for timing and predictions
    predictions = []
    total_traversal = total_label_access = total_voting = total_time = 0

    # Aggregate results from all samples
    for pred, t_traversal, t_label, t_vote, t_total in results:
        predictions.append(pred)
        total_traversal += t_traversal
        total_label_access += t_label
        total_voting += t_vote
        total_time += t_total

    overall_time = total_time

    # Return predictions and detailed timing information
    timings = {
        "samples": num_samples,
        "classification_time": total_time,
        "overall_time": total_time,
        "avg_per_sample": total_time / num_samples,
        "traversal_time": total_traversal,
        "label_access_time": total_label_access,
        "voting_time": total_voting
    }

    return np.array(predictions), timings


### Grouped kfrag Generation and Secure Classification Execution

This section performs the final secure classification using Proxy Re-Encryption (PRE):

- Generates one `kfrag` set per encryption group (not per tree).
- Builds a simplified list of PRE public keys for each group.
- Runs secure classification across the encrypted test set using `secure_classify_dataset`.
- Prints predictions, true labels, and a detailed breakdown of timing metrics, including:
  - Total classification time
  - Tree traversal time
  - Label access time
  - Voting time
  - Average time per sample


In [199]:
# Generate kfrags for each group (not per tree)
signers = [Signer(sk) for sk in pre_secret_keys]

kfrags_list = [
    generate_kfrags(
        delegating_sk=pre_secret_keys[group_idx],
        receiving_pk=vote_public_key,
        signer=signers[group_idx],
        threshold=1,
        shares=1,
        sign_delegating_key=True,
        sign_receiving_key=True
    )
    for group_idx in range(len(tree_groups))
]

# Optionally prepare the list of PRE public keys used per group
tree_public_keys_points = [pre_public_keys[group_idx] for group_idx in range(len(tree_groups))]

start_time = time.time()
print("Performing Secure Classification with PRE (Grouped)...")

# Run secure classification for the entire encrypted test dataset
y_pred_encrypted, timing_report = secure_classify_dataset(
    model=clf_ope,
    X_encrypted_per_tree=X_test_encrypted_per_tree,
    encrypted_thresholds=encrypted_thresholds,
    encrypted_leaf_values=encrypted_leaf_values,
    leaf_capsules=leaf_capsules,
    tree_public_keys=tree_public_keys_points,
    kfrags_list=kfrags_list,
    vote_secret_key=vote_secret_key,
    tree_groups=tree_groups
)

# Display predictions and true labels
print("Predictions:", y_pred_encrypted)
print("True Labels:", y_test[:num_samples_testing])

# Display total classification time
classification_time = time.time() - start_time
print(f"Secure Classification Time: {classification_time:.4f} seconds")

# Display detailed classification timing report
print("\n===== Classification Timing Summary =====")
for key, value in timing_report.items():
    if isinstance(value, float):
        print(f"{key.replace('_', ' ').capitalize()}: {value:.4f} seconds")
    else:
        print(f"{key}: {value}")


Performing Secure Classification with PRE (Grouped)...
Starting parallel classification...


  0%|          | 0/14000 [00:00<?, ?it/s]

Predictions: ['8' '4' '3' ... '2' '7' '1']
True Labels: [8 4 8 ... 2 7 1]
Secure Classification Time: 55.9697 seconds

===== Classification Timing Summary =====
samples: 14000
Classification time: 893.5567 seconds
Overall time: 893.5567 seconds
Avg per sample: 0.0638 seconds
Traversal time: 1.1407 seconds
Label access time: 0.0425 seconds
Voting time: 892.1942 seconds


### Final Accuracy Evaluation

This block verifies that the number of predictions matches the number of ground truth labels. If they match, it converts the predictions to integers and calculates the final accuracy score on the encrypted dataset.

In [None]:
# Sanity check: ensure predictions match ground truth size
print("Length of predictions:", len(y_pred_encrypted))
print("Length of ground truth:", len(y_test))

# Use only the number of samples classified (in case test set was sliced earlier)
y_true = y_test[:len(y_pred_encrypted)]

if len(y_pred_encrypted) == len(y_true):
    # Convert predictions to integers
    y_pred_encrypted = [int(p) for p in y_pred_encrypted]

    # Calculate and print accuracy
    secure_accuracy = accuracy_score(y_true, y_pred_encrypted)
    print(f"Secure Random Forest Accuracy on Encrypted Dataset: {secure_accuracy:.2f}")
else:
    print("Error: Prediction length does not match test labels. Cannot compute accuracy.")


Length of predictions: 14000
Length of ground truth: 14000
Secure Random Forest Accuracy on Encrypted Dataset: 0.82


In [202]:
import multiprocessing

# Recompute actual total from all components to ensure percentages are meaningful
effective_total_time = (
    dataset_load_time +
    test_data_encryption_time +
    label_encryption_time +
    training_time +
    threshold_encryption_time +
    classification_time
)

# Compute percentages
dataset_load_percentage = (dataset_load_time / effective_total_time) * 100
test_data_encryption_percentage = (test_data_encryption_time / effective_total_time) * 100
label_encryption_percentage = (label_encryption_time / effective_total_time) * 100
rf_training_percentage = (training_time / effective_total_time) * 100
threshold_encryption_percentage = (threshold_encryption_time / effective_total_time) * 100
classification_percentage = (classification_time / effective_total_time) * 100

# Combine encryption and classification times
total_throughput_time = test_data_encryption_time + classification_time
encryption_percentage_throughput = (test_data_encryption_time / total_throughput_time) * 100
classification_percentage_throughput = (classification_time / total_throughput_time) * 100

# Full secure pipeline time
total_secure_pipeline_time = (
    test_data_encryption_time +
    label_encryption_time +
    threshold_encryption_time +
    classification_time
)
throughput_secure_pipeline = len(X_test) / total_secure_pipeline_time

# Start output
print("\n===== Optimized Multi-Key Framework (Chunked OPE + Threaded PRE) =====")
print(f"Available CPU Cores: {multiprocessing.cpu_count()}")

# Timing summary
print("\n===== Execution Time Summary =====")
print(f"Total Execution Time: {effective_total_time:.4f} seconds")
print(f"Dataset Load Time: {dataset_load_time:.4f} seconds ({dataset_load_percentage:.2f}%)")
print(f"Test Data Encryption Time: {test_data_encryption_time:.4f} seconds ({test_data_encryption_percentage:.2f}%)")
print(f"Label Encryption Time (PRE): {label_encryption_time:.4f} seconds ({label_encryption_percentage:.2f}%)")
print(f"Random Forest Training Time: {training_time:.4f} seconds ({rf_training_percentage:.2f}%)")
print(f"Threshold Encryption Time: {threshold_encryption_time:.4f} seconds ({threshold_encryption_percentage:.2f}%)")
print(f"Secure Classification Time: {classification_time:.4f} seconds ({classification_percentage:.2f}%)")

# Classification internals (if recorded)
if "timing_report" in locals():
    print("\n----- Secure Classification Breakdown -----")
    print(f"Tree Traversal Time: {timing_report['traversal_time']:.4f} seconds")
    print(f"Label Access Time: {timing_report['label_access_time']:.4f} seconds")
    print(f"Voting Time: {timing_report['voting_time']:.4f} seconds")
    print(f"Average Time per Sample: {timing_report['avg_per_sample']:.4f} seconds")

# Accuracy & test/train data
print("\n===== Secure Classification Results =====")
print(f"Secure Random Forest Accuracy on Encrypted MNIST: {secure_accuracy:.4f}")
print(f"Number of Decision Trees: {num_estimators}")
print(f"Training Set Size: {len(X_train)} samples")
print(f"Testing Set Size: {len(X_test)} samples")

# Throughput (Inference Only)
print("\n===== Throughput (Inference Only) =====")
print(f"Total Throughput Time (Test Data Encryption + Classification): {total_throughput_time:.4f} seconds")
print(f"Throughput: {len(X_test) / total_throughput_time:.2f} samples/second")
print(f"Encryption % of Throughput Time: {encryption_percentage_throughput:.2f}%")
print(f"Classification % of Throughput Time: {classification_percentage_throughput:.2f}%")

# Full Secure Pipeline Throughput
print("\n===== Full Secure Pipeline Throughput =====")
print(f"Total Time (Test Data + Label + Threshold + Classification): {total_secure_pipeline_time:.4f} seconds")
print(f"Throughput (Secure Pipeline): {throughput_secure_pipeline:.2f} samples/second")

# Key assignment summary
print("\n===== Key Management Summary =====")
print(f"Number of OPE Keys Used: {len(ope_keys)}")
print(f"Number of PRE Public Keys Used: {len(pre_public_keys)}")
print(f"Key Assignment: {len(tree_groups)} OPE/PRE key groups used across {num_estimators} trees")

print("\nTree Group Structure:")
for i, group in enumerate(tree_groups):
    print(f"  Group {i + 1}: {len(group)} trees")



===== Optimized Multi-Key Framework (Chunked OPE + Threaded PRE) =====
Available CPU Cores: 16

===== Execution Time Summary =====
Total Execution Time: 70.6341 seconds
Dataset Load Time: 4.6908 seconds (6.64%)
Test Data Encryption Time: 0.0413 seconds (0.06%)
Label Encryption Time (PRE): 8.5723 seconds (12.14%)
Random Forest Training Time: 0.4559 seconds (0.65%)
Threshold Encryption Time: 0.9041 seconds (1.28%)
Secure Classification Time: 55.9697 seconds (79.24%)

----- Secure Classification Breakdown -----
Tree Traversal Time: 1.1407 seconds
Label Access Time: 0.0425 seconds
Voting Time: 892.1942 seconds
Average Time per Sample: 0.0638 seconds

===== Secure Classification Results =====
Secure Random Forest Accuracy on Encrypted MNIST: 0.8161
Number of Decision Trees: 1
Training Set Size: 56000 samples
Testing Set Size: 14000 samples

===== Throughput (Inference Only) =====
Total Throughput Time (Test Data Encryption + Classification): 56.0111 seconds
Throughput: 249.95 samples/secon