# Milestone 4: Privacy-Preserving Split Inference Demo

This notebook demonstrates the core building blocks of ZK-LLM-Turbo's split inference:
1. CKKS public context (show that the server can't decrypt)
2. Homomorphic matrix multiplication
3. Non-linear operations (RMSNorm, SiLU, softmax): our NumPy implementations compared against PyTorch
4. Timing breakdown of the encrypted operations (encrypt, serialize, HE matmul, decrypt) for a single token
5. Accuracy comparison: encrypted vs plaintext

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import tenseal as ts
import time
import torch

## 1. CKKS Public Context Demo

The client creates a CKKS context with a secret key, then serializes it *without* the secret key.
The server receives the public context and can perform computations but **cannot decrypt**.

In [None]:
from client.encryption.ckks_context import create_ckks_context, serialize_public_context

# --- Client: build the full context (the one that holds the secret key) ---
client_ctx = create_ckks_context(config_path="../client/config/client_config.yaml")

# --- Client -> server: serialize the public-only form of the context ---
public_bytes = serialize_public_context(client_ctx)
public_kb = len(public_bytes) / 1024
print(f"Public context size: {public_kb:.1f} KB")

# --- Server: rebuild a context from the bytes it received ---
server_ctx = ts.context_from(public_bytes)

# --- Client: encrypt a small vector and serialize it for transport ---
secret_data = [3.14, 2.71, 1.41, 1.73]
enc_vec = ts.ckks_vector(client_ctx, secret_data)
wire_bytes = enc_vec.serialize()

# --- Server: load the ciphertext against the server-side context ---
server_vec = ts.ckks_vector_from(server_ctx, wire_bytes)

# The decrypt attempt on the server is expected to raise; reaching the
# `else` branch means the server was able to decrypt.
try:
    server_vec.decrypt()
except Exception as e:
    print(f"Server cannot decrypt (expected): {type(e).__name__}")
else:
    print("ERROR: Server decrypted! This should not happen.")

# --- Client: the same ciphertext bytes decrypt fine under the full context ---
client_result = ts.ckks_vector_from(client_ctx, wire_bytes)
recovered = client_result.decrypt()[:4]
print(f"Client decrypts: {recovered}")
print(f"Original data:   {secret_data}")

## 2. HE Matrix Multiplication Demo

Server computes `Enc(x) @ W` homomorphically without seeing `x`.

In [None]:
# Small projection demo: 64-dim input → 32-dim output.
# The client encrypts x, the server multiplies the ciphertext by the
# plaintext weight matrix W, and the client decrypts and compares the
# result against the plaintext product x @ W.
dim_in, dim_out = 64, 32
x = np.random.randn(dim_in).astype(np.float32) * 0.1
W = np.random.randn(dim_in, dim_out).astype(np.float32) * 0.1

# Plaintext reference result
expected = x @ W

# Client encrypts x
enc_x = ts.ckks_vector(client_ctx, x.tolist())

# Server loads the ciphertext against its own context
server_enc = ts.ckks_vector_from(server_ctx, enc_x.serialize())

# Convert the weights to nested lists BEFORE starting the timer so the
# NumPy → list conversion is not counted as homomorphic-matmul time.
W_rows = W.tolist()

# Server computes enc_x @ W (only this call is timed)
t0 = time.perf_counter()
enc_result = server_enc.mm(W_rows)
he_time = (time.perf_counter() - t0) * 1000  # seconds → milliseconds

# Client decrypts; only the first dim_out entries are compared
actual = np.array(ts.ckks_vector_from(client_ctx, enc_result.serialize()).decrypt()[:dim_out])

error = np.abs(actual - expected)
print(f"HE matmul ({dim_in}→{dim_out}): {he_time:.1f} ms")
print(f"Max error: {error.max():.6f}")
print(f"Mean error: {error.mean():.6f}")

In [None]:
# Full-scale: 2048-dim input → 256-dim output (like K/V projection).
# Same flow as the small demo: encrypt on the client, multiply by the
# plaintext weights on the server, decrypt on the client, compare.
dim_in, dim_out = 2048, 256
x = np.random.randn(dim_in).astype(np.float32) * 0.01
W = np.random.randn(dim_in, dim_out).astype(np.float32) * 0.01

# Plaintext reference result
expected = x @ W

# Client encrypts; server loads the ciphertext against its own context
enc_x = ts.ckks_vector(client_ctx, x.tolist())
server_enc = ts.ckks_vector_from(server_ctx, enc_x.serialize())

# Convert the weight matrix to nested lists outside the timed region so
# the NumPy → list conversion does not inflate the reported HE matmul time.
W_rows = W.tolist()

t0 = time.perf_counter()
enc_result = server_enc.mm(W_rows)
he_time = (time.perf_counter() - t0) * 1000  # seconds → milliseconds

# Client decrypts; only the first dim_out entries are compared
actual = np.array(ts.ckks_vector_from(client_ctx, enc_result.serialize()).decrypt()[:dim_out])
error = np.abs(actual - expected)

print(f"HE matmul ({dim_in}→{dim_out}): {he_time:.1f} ms")
print(f"Max error: {error.max():.6f}")
print(f"Mean error: {error.mean():.6f}")

## 3. Non-Linear Operations: Our NumPy vs PyTorch

In [None]:
from client.inference.nonlinear_ops import rms_norm, silu, softmax

# ---------------------------------------------------------------
# RMSNorm: project implementation vs. a hand-written PyTorch reference
# ---------------------------------------------------------------
x_np = np.random.randn(2048).astype(np.float32)
w_np = np.random.randn(2048).astype(np.float32)
our_norm = rms_norm(x_np, w_np, eps=1e-5)

# Reference: x * rsqrt(mean(x^2) + eps) * w, using the same eps as above
x_t, w_t = torch.tensor(x_np), torch.tensor(w_np)
mean_sq = x_t.pow(2).mean(-1, keepdim=True)
inv_rms = torch.rsqrt(mean_sq + 1e-5)
pt_norm = ((x_t * inv_rms) * w_t).numpy()

norm_gap = np.abs(our_norm - pt_norm).max()
print(f"RMSNorm max diff: {norm_gap:.2e}")

# ---------------------------------------------------------------
# SiLU: project implementation vs. torch.nn.functional.silu
# ---------------------------------------------------------------
x_np = np.random.randn(1000).astype(np.float32)
our_silu = silu(x_np)
pt_silu = torch.nn.functional.silu(torch.tensor(x_np)).numpy()

silu_gap = np.abs(our_silu - pt_silu).max()
print(f"SiLU max diff: {silu_gap:.2e}")

# ---------------------------------------------------------------
# Softmax: project implementation vs. torch softmax over the last axis
# ---------------------------------------------------------------
x_np = np.random.randn(32, 32).astype(np.float32)
our_sm = softmax(x_np)
pt_sm = torch.nn.functional.softmax(torch.tensor(x_np), dim=-1).numpy()

softmax_gap = np.abs(our_sm - pt_sm).max()
print(f"Softmax max diff: {softmax_gap:.2e}")

## 4. Timing Breakdown: Encrypted Operations

Measure the cost of each component in a single round-trip.

In [None]:
# Per-operation timing for one token: encrypt, serialize, two HE matmuls
# (a 2048→2048 "Q" projection and a 2048→256 "K" projection), and decrypt.
# All times are reported in milliseconds.
hidden_dim = 2048
kv_dim = 256
ffn_dim = 5632  # defined for reference; not used in this cell

x = np.random.randn(hidden_dim).astype(np.float32) * 0.01
W_q = np.random.randn(hidden_dim, hidden_dim).astype(np.float32) * 0.005
W_k = np.random.randn(hidden_dim, kv_dim).astype(np.float32) * 0.005

# Convert the weights to nested lists up front so the NumPy → list
# conversion is not counted inside the HE matmul timings below.
W_q_rows = W_q.tolist()
W_k_rows = W_k.tolist()

# Encryption time
t0 = time.perf_counter()
enc_x = ts.ckks_vector(client_ctx, x.tolist())
encrypt_ms = (time.perf_counter() - t0) * 1000

# Serialize time
t0 = time.perf_counter()
enc_bytes = enc_x.serialize()
serialize_ms = (time.perf_counter() - t0) * 1000

# Matmul Q (2048→2048): only the mm call is timed
t0 = time.perf_counter()
enc_q = enc_x.mm(W_q_rows)
mm_q_ms = (time.perf_counter() - t0) * 1000

# Matmul K (2048→256): only the mm call is timed
enc_x2 = ts.ckks_vector(client_ctx, x.tolist())  # fresh for fair comparison
t0 = time.perf_counter()
enc_k = enc_x2.mm(W_k_rows)
mm_k_ms = (time.perf_counter() - t0) * 1000

# Decrypt time (decrypts the Q-projection result)
t0 = time.perf_counter()
dec = enc_q.decrypt()
decrypt_ms = (time.perf_counter() - t0) * 1000

print("=== Timing Breakdown (single token) ===")
print(f"Encrypt (2048-dim):     {encrypt_ms:8.1f} ms")
print(f"Serialize:              {serialize_ms:8.1f} ms")
print(f"HE matmul 2048→2048:    {mm_q_ms:8.1f} ms")
print(f"HE matmul 2048→256:     {mm_k_ms:8.1f} ms")
print(f"Decrypt (2048-dim):     {decrypt_ms:8.1f} ms")
print(f"Ciphertext size:        {len(enc_bytes) / 1024:.1f} KB")

## 5. Accuracy: Encrypted vs Plaintext Layer Output

Compare an encrypted Q projection (plaintext RMSNorm → encrypt → HE matmul → decrypt) against the same projection computed entirely in plaintext.

In [None]:
# Pipeline under test: RMSNorm → encrypt → Q projection → decrypt.
# The same normalized input goes through a plaintext matmul and an
# encrypted matmul, and the two outputs are then compared.
hidden_dim = 2048
x = np.random.randn(hidden_dim).astype(np.float32) * 0.01
norm_w = np.ones(hidden_dim, dtype=np.float32)
W_q = np.random.randn(hidden_dim, hidden_dim).astype(np.float32) * 0.005

# Reference path: everything in plaintext
x_normed = rms_norm(x, norm_w, 1e-5)
q_plain = x_normed @ W_q

# Encrypted path: encrypt the normalized vector, multiply, decrypt
enc_normed = ts.ckks_vector(client_ctx, x_normed.tolist())
enc_q = enc_normed.mm(W_q.tolist())
decrypted = enc_q.decrypt()
q_enc = np.array(decrypted[:hidden_dim], dtype=np.float32)

# Element-wise error metrics
abs_err = np.abs(q_plain - q_enc)
# The 1e-10 term guards against division by zero; reference entries that
# are close to zero can still dominate this mean.
rel_err = abs_err / (np.abs(q_plain) + 1e-10)

print(f"Encrypted vs Plaintext Q-projection ({hidden_dim}→{hidden_dim}):")
print(f"  Max error:  {abs_err.max():.6f}")
print(f"  Mean error: {abs_err.mean():.6f}")
print(f"  Relative error: {rel_err.mean():.4%}")

# Cosine similarity between the plaintext and decrypted outputs
inner = np.dot(q_plain, q_enc)
norm_product = np.linalg.norm(q_plain) * np.linalg.norm(q_enc)
cos_sim = inner / norm_product
print(f"  Cosine similarity: {cos_sim:.6f}")

## Summary

- **Privacy**: Server has public context only → cannot decrypt any ciphertexts
- **Correctness**: HE matmul matches plaintext within CKKS tolerance (~0.01 for small dims)
- **Non-linear ops**: Our NumPy implementations match PyTorch within float32 epsilon
- **Performance**: Single HE matmul at 2048 dims takes ~X ms per token (X is machine-dependent; fill it in from the Section 4 timing breakdown output)
- **Architecture**: 4 round-trips per layer, configurable number of encrypted layers