# Milestone 3 - Client-Side Encryption

This notebook demonstrates the **client-side pipeline** for the ZK-LLMS project:

- Generate **synthetic embeddings** shaped like TinyLlama outputs (e.g. `(seq_len=10, dim=2048)`).
- Create a **CKKS context** with TenSEAL (same parameters as the client).
- **Encrypt** embeddings and build the same JSON **payload** the client sends.
- **Simulate** the server-side `/api/infer` handler:
  - Deserialize and decrypt one encrypted vector.
  - Generate a dummy encrypted result.
- **Decrypt** the server's encrypted response on the client.
- Measure **latency** for each step.

This notebook uses **no external model downloads** and **no network calls**, so it can be run safely in any public environment.

### Dependencies

In [1]:
# If running on Colab or a fresh environment, you may need:
# !pip install tenseal numpy

import time
import base64
from typing import List, Dict, Any

import numpy as np
import tenseal as ts


### Configuration & Helpers

In [2]:
# CKKS parameters — aligned with client_config.yaml / server_config.yaml
ckks_config = {
    "poly_modulus_degree": 8192,
    "coeff_mod_bit_sizes": [60, 40, 40, 60],
    "global_scale": 2**40,
}

# Synthetic embedding shape (like TinyLlama last hidden state)
SEQ_LEN = 10       # number of tokens
HIDDEN_DIM = 2048  # embedding dimension


def create_ckks_context(config: Dict[str, Any]) -> ts.Context:
    """Create a TenSEAL CKKS context using the given configuration."""
    context = ts.context(
        ts.SCHEME_TYPE.CKKS,
        poly_modulus_degree=config["poly_modulus_degree"],
        coeff_mod_bit_sizes=config["coeff_mod_bit_sizes"],
    )
    context.global_scale = config["global_scale"]
    context.generate_galois_keys()
    # NOTE: the context keeps its secret key because client and server share
    # one notebook here. A real server should only ever receive a public context.
    return context


### Generate Synthetic Embeddings

In [3]:
np.random.seed(42)  # for reproducibility

start = time.perf_counter()
embeddings = np.random.normal(loc=0.0, scale=0.02, size=(SEQ_LEN, HIDDEN_DIM))
elapsed_ms = (time.perf_counter() - start) * 1000

print(f"Synthetic embeddings shape: {embeddings.shape}")
print(f"Generation time: {elapsed_ms:.2f} ms")


Synthetic embeddings shape: (10, 2048)
Generation time: 0.94 ms
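
The `start = time.perf_counter()` / `elapsed_ms = ...` pattern used above recurs in every step of this notebook. A minimal sketch of one way to factor it out, assuming a hypothetical helper named `timed` that is not part of the original client code:

```python
import time
from contextlib import contextmanager
from typing import Dict, Iterator


@contextmanager
def timed(label: str, results: Dict[str, float]) -> Iterator[None]:
    """Record the wall-clock duration of the enclosed block in milliseconds.

    `timed` and `results` are hypothetical names used only for illustration.
    """
    start = time.perf_counter()
    try:
        yield
    finally:
        # Store the elapsed time even if the block raises.
        results[label] = (time.perf_counter() - start) * 1000


timings: Dict[str, float] = {}
with timed("sum_of_squares", timings):
    # Stand-in workload; in the notebook this would be one pipeline step.
    total = sum(i * i for i in range(100_000))

print(f"sum_of_squares: {timings['sum_of_squares']:.2f} ms")
```

Collecting the timings in one dictionary would also make it easy to print a per-step latency table at the end of the notebook.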


### Create CKKS Context

In [4]:
start = time.perf_counter()
context = create_ckks_context(ckks_config)
elapsed_ms = (time.perf_counter() - start) * 1000

print("CKKS context created.")
print(f"poly_modulus_degree: {ckks_config['poly_modulus_degree']}")
print(f"coeff_mod_bit_sizes: {ckks_config['coeff_mod_bit_sizes']}")
print(f"Context creation time: {elapsed_ms:.2f} ms")


CKKS context created.
poly_modulus_degree: 8192
coeff_mod_bit_sizes: [60, 40, 40, 60]
Context creation time: 112.87 ms
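
The parameters printed above determine how many values fit in one ciphertext and how much computation the ciphertext can absorb. The sketch below derives those quantities with plain arithmetic. It assumes the usual CKKS conventions: `poly_modulus_degree / 2` slots, and a coefficient-modulus layout in which the first and last primes are reserved and each middle prime allows one rescale. The bit-count limits are the 128-bit security bounds tabulated by Microsoft SEAL, the backend TenSEAL wraps. `summarize_ckks_params` is a hypothetical helper, not part of the project code.

```python
from typing import Any, Dict

# Upper bounds on the total coefficient-modulus bit count for 128-bit
# security, per polynomial modulus degree (Microsoft SEAL defaults).
MAX_COEFF_BITS_128 = {
    1024: 27, 2048: 54, 4096: 109, 8192: 218, 16384: 438, 32768: 881,
}


def summarize_ckks_params(config: Dict[str, Any], vector_len: int) -> Dict[str, Any]:
    """Derive basic capacity figures from a CKKS configuration dict."""
    degree = config["poly_modulus_degree"]
    bit_sizes = config["coeff_mod_bit_sizes"]
    total_bits = sum(bit_sizes)
    slots = degree // 2  # CKKS packs N/2 values per ciphertext
    return {
        "slots": slots,
        "fits_in_one_ciphertext": vector_len <= slots,
        "total_coeff_bits": total_bits,
        "within_128_bit_bound": total_bits <= MAX_COEFF_BITS_128[degree],
        # Middle primes only: the first and last entries are not rescale levels.
        "rescale_levels": max(len(bit_sizes) - 2, 0),
    }


ckks_config = {
    "poly_modulus_degree": 8192,
    "coeff_mod_bit_sizes": [60, 40, 40, 60],
    "global_scale": 2**40,
}
summary = summarize_ckks_params(ckks_config, vector_len=2048)
for key, value in summary.items():
    print(f"{key}: {value}")
```

Under these assumptions, a 2048-dimensional embedding row occupies half of the 4096 available slots, and the 200-bit modulus stays under the 218-bit bound for this degree.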


### Encrypt Embeddings & Serialize
This mirrors the client's `encrypt_embeddings()` and `serialize_encrypted_vectors()` logic.

In [5]:
def encrypt_embeddings(emb: np.ndarray, context: ts.Context) -> List[ts.CKKSVector]:
    """
    Encrypt each token embedding row as a separate CKKS vector.
    emb shape: (seq_len, hidden_dim)
    """
    encrypted = []
    for i in range(emb.shape[0]):
        vec = ts.ckks_vector(context, emb[i].tolist())
        encrypted.append(vec)
    return encrypted


def serialize_encrypted_vectors(vectors: List[ts.CKKSVector]) -> List[str]:
    """
    Serialize CKKS vectors and encode them as base64 strings.
    """
    serialized = []
    for v in vectors:
        raw = v.serialize()
        b64 = base64.b64encode(raw).decode("utf-8")
        serialized.append(b64)
    return serialized


start = time.perf_counter()
encrypted_vectors = encrypt_embeddings(embeddings, context)
serialized_vectors = serialize_encrypted_vectors(encrypted_vectors)
elapsed_ms = (time.perf_counter() - start) * 1000

print(f"Encrypted {len(encrypted_vectors)} token embeddings.")
print(f"Encryption + serialization time: {elapsed_ms:.2f} ms")
print(f"Example serialized ciphertext length: {len(serialized_vectors[0])} characters")


Encrypted 10 token embeddings.
Encryption + serialization time: 46.04 ms
Example serialized ciphertext length: 445884 characters
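
The reported length counts base64 characters, not ciphertext bytes. Padded base64 emits 4 characters for every 3 input bytes, so the raw size can be recovered from the character count. A small sketch of that arithmetic, using hypothetical helper names and a stand-in byte string rather than a real ciphertext:

```python
import base64
import math


def base64_length(raw_len: int) -> int:
    """Number of characters in padded standard base64 for `raw_len` bytes."""
    return 4 * math.ceil(raw_len / 3)


def max_raw_length(b64_len: int) -> int:
    """Largest byte count that encodes to `b64_len` padded base64 characters."""
    return (b64_len // 4) * 3


# Check the formula against the standard library on stand-in data.
sample = bytes(range(256)) * 4  # 1024 arbitrary bytes, not a ciphertext
encoded = base64.b64encode(sample).decode("utf-8")
print(len(encoded), base64_length(len(sample)))

# Apply it to the character count reported by the cell above.
reported_chars = 445884
print(f"Raw ciphertext size: at most {max_raw_length(reported_chars)} bytes")
```

The encoding therefore adds roughly one third on top of the serialized ciphertext, which matters when the payload is sent as JSON.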


### Build Client Payload
This mirrors the client's `build_payload()` function, which attaches metadata to the ciphertexts.

In [6]:
def build_payload(serialized_vecs: List[str],
                  emb_shape: tuple,
                  ckks_cfg: Dict[str, Any],
                  cid: str) -> Dict[str, Any]:
    """
    Build JSON-ready payload, similar to the client implementation.
    """
    payload = {
        "encrypted_embeddings": serialized_vecs,
        "metadata": {
            "cid": cid,
            "embedding_shape": list(emb_shape),
            "ckks": {
                "poly_modulus_degree": ckks_cfg["poly_modulus_degree"],
                "coeff_mod_bit_sizes": ckks_cfg["coeff_mod_bit_sizes"],
                "global_scale": ckks_cfg["global_scale"],
            },
        },
    }
    return payload


import json
import uuid

cid = str(uuid.uuid4())
start = time.perf_counter()
payload = build_payload(serialized_vectors, embeddings.shape, ckks_config, cid)
payload_bytes = len(json.dumps(payload).encode("utf-8"))
elapsed_ms = (time.perf_counter() - start) * 1000

print(f"Payload built with CID: {cid}")
print(f"Payload size: {payload_bytes} bytes")
print(f"Build time: {elapsed_ms:.2f} ms")


Payload built with CID: c14942de-3d41-48f7-a5f1-2d7b98db2317
Payload size: 4457845 bytes
Build time: 11.07 ms
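
To put the payload size in context, it helps to compare it with the plaintext embeddings it carries. The sketch below computes the expansion factor from the figures printed in this notebook. It assumes the plaintext would otherwise be sent as raw `float64` values (NumPy's default for `np.random.normal`) or as `float32`. `expansion_factor` is a hypothetical helper.

```python
SEQ_LEN = 10
HIDDEN_DIM = 2048
PAYLOAD_BYTES = 4_457_845  # payload size reported by the cell above


def expansion_factor(payload_bytes: int, n_values: int, bytes_per_value: int) -> float:
    """Ratio of encrypted payload size to the raw plaintext size."""
    return payload_bytes / (n_values * bytes_per_value)


n_values = SEQ_LEN * HIDDEN_DIM
factor_f64 = expansion_factor(PAYLOAD_BYTES, n_values, bytes_per_value=8)
factor_f32 = expansion_factor(PAYLOAD_BYTES, n_values, bytes_per_value=4)

print(f"Plaintext values: {n_values}")
print(f"Expansion vs float64: {factor_f64:.1f}x")
print(f"Expansion vs float32: {factor_f32:.1f}x")
```

Each ciphertext uses only 2048 of its 4096 slots in this configuration, so packing two token rows per ciphertext is one possible way to reduce the payload. That is a design option, not something the current client does.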


### Mock Server /api/infer Logic
This simulates the FastAPI `infer()` handler locally, without starting a server.

In [7]:
def mock_server_infer(payload: Dict[str, Any], context: ts.Context) -> Dict[str, Any]:
    """
    Simulated server-side handler for /api/infer.

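    - Deserializes the first encrypted embedding.
    - Decrypts it for debugging only; this works because the demo context
      holds the secret key, which a production server would not have.
    - Prints diagnostics.
    - Returns a dummy encrypted result in place of real inference output.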
    """
    start = time.perf_counter()

    encrypted_embeddings = payload.get("encrypted_embeddings", [])
    metadata = payload.get("metadata", {})

    # Deserialize & decrypt just the first embedding for demonstration
    if not encrypted_embeddings:
        raise ValueError("No encrypted embeddings in payload")

    first_b64 = encrypted_embeddings[0]
    enc_bytes = base64.b64decode(first_b64)
    enc_vec = ts.ckks_vector_from(context, enc_bytes)
    decrypted = enc_vec.decrypt()

    print(f"[SERVER] Received {len(encrypted_embeddings)} encrypted embeddings.")
    print(f"[SERVER] Example decrypted slice (first 5 values): {decrypted[:5]}")

    # Dummy encrypted result, standing in for the real server's inference output
    dummy_result = [0.1, 0.2, 0.3]
    enc_result = ts.ckks_vector(context, dummy_result)
    result_b64 = base64.b64encode(enc_result.serialize()).decode("utf-8")

    elapsed_ms = (time.perf_counter() - start) * 1000
    print(f"[SERVER] Inference (mock) time: {elapsed_ms:.2f} ms")

    return {"encrypted_result": result_b64}
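
The mock handler reads `metadata` but does not check it. A real handler would probably validate the payload before deserializing multi-megabyte ciphertexts. The sketch below shows one possible structural check. The field names follow `build_payload()` above, but `validate_payload` itself is a hypothetical helper and not part of the original server code.

```python
from typing import Any, Dict, List


def validate_payload(payload: Dict[str, Any]) -> List[str]:
    """Return a list of problems; an empty list means the payload looks well-formed."""
    problems: List[str] = []

    vectors = payload.get("encrypted_embeddings")
    if not isinstance(vectors, list) or not vectors:
        problems.append("encrypted_embeddings must be a non-empty list")
    elif not all(isinstance(v, str) for v in vectors):
        problems.append("every encrypted embedding must be a base64 string")

    metadata = payload.get("metadata")
    if not isinstance(metadata, dict):
        problems.append("metadata must be an object")
        return problems

    shape = metadata.get("embedding_shape")
    if not (isinstance(shape, list) and len(shape) == 2):
        problems.append("embedding_shape must be [seq_len, hidden_dim]")
    elif isinstance(vectors, list) and len(vectors) != shape[0]:
        problems.append("number of ciphertexts does not match seq_len")

    ckks = metadata.get("ckks")
    if not isinstance(ckks, dict):
        problems.append("ckks must be an object")
    else:
        for key in ("poly_modulus_degree", "coeff_mod_bit_sizes", "global_scale"):
            if key not in ckks:
                problems.append(f"ckks.{key} is missing")
    return problems


# Tiny stand-in payloads; the strings are placeholders, not real ciphertexts.
good = {
    "encrypted_embeddings": ["AAAA", "BBBB"],
    "metadata": {
        "cid": "demo",
        "embedding_shape": [2, 2048],
        "ckks": {
            "poly_modulus_degree": 8192,
            "coeff_mod_bit_sizes": [60, 40, 40, 60],
            "global_scale": 2**40,
        },
    },
}
bad = {"encrypted_embeddings": ["AAAA"], "metadata": {"embedding_shape": [2, 2048], "ckks": {}}}

print(validate_payload(good))
for problem in validate_payload(bad):
    print("-", problem)
```

Returning a list of problems rather than raising on the first one lets the handler report every issue in a single error response.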


### Simulate Client → Server → Client Roundtrip

In [8]:
# Simulate the "HTTP POST" by calling mock_server_infer directly
start_total = time.perf_counter()

start = time.perf_counter()
server_response = mock_server_infer(payload, context)
server_time_ms = (time.perf_counter() - start) * 1000

# Client decrypts server's encrypted_result
enc_result_b64 = server_response["encrypted_result"]
enc_bytes = base64.b64decode(enc_result_b64)
enc_vec = ts.ckks_vector_from(context, enc_bytes)
decrypted_result = enc_vec.decrypt()
roundtrip_time_ms = (time.perf_counter() - start_total) * 1000

print(f"\n[CLIENT] Decrypted server result: {decrypted_result}")
print(f"[CLIENT] Mock server time: {server_time_ms:.2f} ms")
print(f"[CLIENT] End-to-end (client→server→client) time: {roundtrip_time_ms:.2f} ms")


[SERVER] Received 10 encrypted embeddings.
[SERVER] Example decrypted slice (first 5 values): [0.009934283129673415, -0.0027652855974043228, 0.01295377051659382, 0.030460597502334052, -0.004683065634977897]
[SERVER] Inference (mock) time: 9.57 ms

[CLIENT] Decrypted server result: [0.10000000046121574, 0.19999999664931548, 0.3000000005207714]
[CLIENT] Mock server time: 9.76 ms
[CLIENT] End-to-end (client→server→client) time: 12.45 ms


### Summary & Notes

In this notebook, we demonstrated:

- Generating **synthetic embeddings** shaped like TinyLlama outputs (`(10, 2048)`).
- Creating a **CKKS context** with the same parameters used by the client and server.
- Encrypting each token embedding row as a **CKKS vector** and serializing it to base64.
- Building a **JSON payload** with:
  - `encrypted_embeddings`: list of ciphertexts
  - `metadata`: `cid`, embedding shape, CKKS parameters
- Simulating the **server-side `/api/infer`**:
  - Deserializing and decrypting an embedding.
  - Returning a dummy encrypted result.
- Performing **client-side decryption** of the server's response.
- Measuring **latency** for each step: embedding generation, context creation, encryption and serialization, payload building, and the mock round trip.

This notebook mirrors the **client-side logic** from the Phase 2 implementation
but removes external dependencies (TinyLlama model, network calls),
making it suitable as a **public, reproducible milestone deliverable**.
