# PPCM-X: Encrypted Inference Demo

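This notebook demonstrates the end-to-end pipeline for privacy-preserving CNN inference on MNIST using homomorphic encryption (HE). It covers polynomial activations, model construction, HE benchmarks, encrypted inference, and a plaintext-vs-encrypted comparison.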

**Based on:** "Enhancing Privacy in Deep Neural Networks: Techniques and Applications" (Raj et al., IEEE INDIACOM 2025)

**Novel Extension:** Adaptive Polynomial Activations (APA), which select the polynomial degree used in each activation layer (see Section 3).

In [None]:
import sys
import time
from itertools import islice
from pathlib import Path

sys.path.insert(0, '..')

import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# PPCM-X modules
from src.data_loader import get_data_loaders, get_sample_batch, get_dataset_info
from src.model_plain import get_model, count_parameters, PPCM_CNN, PPCM_X_CNN
from src.activations_poly import PolyReLU, AdaptivePolyActivation, PolynomialCoefficients
from src.he_utils import HEContext, benchmark_he_operations
from src.model_encrypted import EncryptedPPCM, HybridEncryptedModel

# Fix the RNG seeds so that model initialisation, torch.randn inputs and any
# torch-based data shuffling are reproducible under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

print("PPCM-X modules loaded successfully!")

## 1. Polynomial Activation Approximations

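Arithmetic HE schemes can evaluate only additions and multiplications, so non-polynomial activations such as ReLU must be replaced by polynomials. The plots below compare degree-2, 3, and 4 polynomial approximations against the true ReLU on $[-3, 3]$, with the approximation MSE shown in each title.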

In [None]:
# Compare polynomial approximations of degree 2, 3 and 4 to the true ReLU on a
# fixed grid over [-3, 3]. The MSE in each panel title is the value returned by
# PolyReLU.approximation_error on that same grid.
x_grid = torch.linspace(-3, 3, 200)
true_relu = torch.relu(x_grid)

fig, axes = plt.subplots(1, 3, figsize=(12, 4))

for ax, degree in zip(axes, [2, 3, 4]):
    poly_relu = PolyReLU(degree=degree)
    # Plotting only: no gradients needed.
    with torch.no_grad():
        approx = poly_relu(x_grid)
        mse = poly_relu.approximation_error(x_grid).item()

    ax.plot(x_grid.numpy(), true_relu.numpy(), 'b-', label='True ReLU', linewidth=2)
    ax.plot(x_grid.numpy(), approx.numpy(), 'r--', label=f'Poly (deg={degree})', linewidth=2)
    ax.set_title(f'Degree {degree} (MSE: {mse:.4f})')
    ax.set_xlabel('x')
    ax.set_ylabel('f(x)')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_ylim(-1, 4)

fig.tight_layout()
# Create the output directory if needed so savefig does not fail on a fresh checkout.
poly_plot_path = Path('../experiments/metrics_plots/polynomial_approximations.png')
poly_plot_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(poly_plot_path, dpi=150)
plt.show()

## 2. Model Architecture Comparison

In [None]:
# Compare the three model variants: parameter counts plus a forward-pass shape check.
models = {
    'PPCM (Base)': get_model('ppcm', 'mnist'),
    'PPCM-X (Fixed Poly)': get_model('ppcm_x', 'mnist', adaptive_activation=False, poly_degree=3),
    'PPCM-X (Adaptive)': get_model('ppcm_x', 'mnist', adaptive_activation=True),
}

print("Model Comparison:")
print("="*50)
for name, model in models.items():
    params = count_parameters(model)
    print(f"{name}: {params:,} parameters")

# Test forward pass on one random MNIST-sized image, shape (N, C, H, W) = (1, 1, 28, 28).
dummy_input = torch.randn(1, 1, 28, 28)
for name, model in models.items():
    # eval() switches off training-only layer behaviour (dropout, batch-norm batch
    # statistics) if the models contain any; with a batch of one this matters.
    model.eval()
    with torch.no_grad():
        out = model(dummy_input)
    print(f"{name} output shape: {out.shape}")

## 3. Adaptive Activation Analysis

In [None]:
# Inspect which polynomial degree each adaptive activation layer selects
# when fed a real batch of MNIST training images.
model = get_model('ppcm_x', 'mnist', adaptive_activation=True)

# Pull a single batch of 32 images from the training loader; the
# validation/test loaders returned alongside it are not needed here.
train_loader, _, _ = get_data_loaders('mnist', batch_size=32)
batch_iter = iter(train_loader)
x_batch, _ = next(batch_iter)

# stats is iterated as (layer name -> selected degree) pairs below.
stats = model.get_activation_stats(x_batch)
print("Adaptive Activation Degrees Selected:")
for layer_name, selected_degree in stats.items():
    print(f"  {layer_name}: degree {selected_degree}")

## 4. HE Context and Benchmarks

In [None]:
# Time the basic HE primitives under each parameter preset.
# vector_size=784 = 28 * 28 pixels, i.e. one flattened MNIST image.
presets = ['fast', 'balanced', 'accurate']
benchmark_results = {}

# (display label, result key) pairs, printed in this order for every preset.
timed_ops = [
    ('Encrypt', 'encrypt_ms'),
    ('Decrypt', 'decrypt_ms'),
    ('Add', 'add_ms'),
    ('Multiply', 'mul_ms'),
    ('Poly Eval', 'poly_eval_ms'),
]

for preset in presets:
    print(f"\nBenchmarking {preset} preset...")
    context = HEContext(preset=preset)
    results = benchmark_he_operations(context, vector_size=784, num_iterations=5)
    benchmark_results[preset] = results

    for op_label, op_key in timed_ops:
        print(f"  {op_label}: {results[op_key]:.2f} ms")

In [None]:
# Grouped bar chart: one group per HE operation, one bar per preset.
operations = ['encrypt_ms', 'decrypt_ms', 'add_ms', 'mul_ms', 'poly_eval_ms']
op_labels = ['Encrypt', 'Decrypt', 'Add', 'Multiply', 'Poly Eval']

fig, ax = plt.subplots(figsize=(10, 6))

group_positions = np.arange(len(operations))
# 0.25 for up to three presets; narrower bars if more presets are added so groups don't overlap.
width = min(0.25, 0.8 / len(presets))

for i, preset in enumerate(presets):
    values = [benchmark_results[preset][op] for op in operations]
    ax.bar(group_positions + i*width, values, width, label=preset.capitalize())

ax.set_ylabel('Time (ms)')
ax.set_title('HE Operation Benchmarks by Preset')
# Centre each tick under its group of bars for any number of presets
# (with three presets this is the middle bar, i.e. an offset of one width).
ax.set_xticks(group_positions + width * (len(presets) - 1) / 2)
ax.set_xticklabels(op_labels)
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

fig.tight_layout()
# Create the output directory if needed so savefig does not fail on a fresh checkout.
benchmark_plot_path = Path('../experiments/metrics_plots/he_benchmarks.png')
benchmark_plot_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(benchmark_plot_path, dpi=150)
plt.show()

## 5. Encrypted Inference Demo

In [None]:
# Wrap a PPCM-X model for encrypted inference and run it on a single test image.
# NOTE(review): no trained weights are loaded anywhere in this notebook. Unless
# get_model loads a checkpoint internally, the predictions below come from a
# randomly initialised network and demonstrate the pipeline, not accuracy.
model = get_model('ppcm_x', 'mnist', adaptive_activation=True)
model.eval()  # inference mode; this plaintext model is also reused in later cells
encrypted_model = EncryptedPPCM(model, he_preset='fast')

# Get test sample
x, y = get_sample_batch('mnist', batch_size=1)

print(f"Input shape: {x.shape}")
print(f"True label: {y.item()}")

# Run encrypted inference. perf_counter is a monotonic, high-resolution clock;
# time.time() can jump if the system clock is adjusted mid-measurement.
start = time.perf_counter()
enc_output = encrypted_model(x)
inference_time = time.perf_counter() - start

prediction = enc_output.argmax(dim=-1).item()
print(f"\nEncrypted prediction: {prediction}")
print(f"Inference time: {inference_time*1000:.2f} ms")

# Show per-layer timing breakdown (presumably in seconds, since values are
# scaled by 1000 for display in ms).
timing = encrypted_model.get_timing_breakdown()
print("\nLayer timing breakdown:")
for layer, t in timing.items():
    print(f"  {layer}: {t*1000:.2f} ms")

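## 6. Plaintext vs Encrypted Comparison

Each test image is run through both the plaintext model and its encrypted counterpart. The *prediction match rate* is the fraction of images for which both predict the same class. The *output MSE* measures the numerical error introduced by running the same model under encryption.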

In [None]:
# Compare plaintext and encrypted outputs on the first N test images.
N_COMPARE_SAMPLES = 20  # encrypted inference is slow; keep this small for a demo

hybrid = HybridEncryptedModel(model)

# batch_size=1 so each comparison covers exactly one image (indexing [0] below).
_, _, test_loader = get_data_loaders('mnist', batch_size=1)

comparisons = []
for x, y in tqdm(islice(test_loader, N_COMPARE_SAMPLES), total=N_COMPARE_SAMPLES, desc='Comparing'):
    result = hybrid.compare_outputs(x)
    comparisons.append({
        'label': y.item(),
        'plain_pred': result['plain_pred'][0],
        'enc_pred': result['encrypted_pred'][0],
        'mse': result['mse'],
        'match': result['predictions_match'],
    })

# Summary. Fail with a clear message instead of a ZeroDivisionError if the loader was empty.
if not comparisons:
    raise RuntimeError("test_loader yielded no samples; nothing to compare")
match_rate = sum(c['match'] for c in comparisons) / len(comparisons)
avg_mse = np.mean([c['mse'] for c in comparisons])

print(f"Prediction match rate: {match_rate*100:.1f}%")
print(f"Average output MSE: {avg_mse:.6f}")

In [None]:
# Histogram of per-sample output MSE between plaintext and encrypted inference.
mse_values = [c['mse'] for c in comparisons]
mean_mse = np.mean(mse_values)

fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(mse_values, bins=20, edgecolor='black', alpha=0.7)
ax.set_xlabel('Output MSE (Plaintext vs Encrypted)')
ax.set_ylabel('Count')
ax.set_title('Distribution of Output Differences')
ax.axvline(mean_mse, color='r', linestyle='--', label=f'Mean: {mean_mse:.6f}')
ax.legend()
ax.grid(True, alpha=0.3)
fig.tight_layout()
# Create the output directory if needed so savefig does not fail on a fresh checkout.
mse_plot_path = Path('../experiments/metrics_plots/mse_distribution.png')
mse_plot_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(mse_plot_path, dpi=150)
plt.show()

## 7. Visualize Sample Predictions

In [None]:
# Show up to ten test images with their true label and both model predictions.
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
# Hide every panel up front so any unused ones (short loader) stay blank.
for ax in axes.flat:
    ax.axis('off')

# zip stops once all 10 axes are filled, so no extra image is pulled from the loader.
for ax, (x, y) in zip(axes.flat, test_loader):
    # Get predictions (inference only, no gradients needed)
    with torch.no_grad():
        plain_pred = model(x).argmax().item()
        enc_pred = encrypted_model(x).argmax().item()

    ax.imshow(x.squeeze().numpy(), cmap='gray')
    ax.set_title(f'True: {y.item()}\nPlain: {plain_pred}, Enc: {enc_pred}')

fig.tight_layout()
# Create the output directory if needed so savefig does not fail on a fresh checkout.
predictions_plot_path = Path('../experiments/metrics_plots/sample_predictions.png')
predictions_plot_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(predictions_plot_path, dpi=150)
plt.show()

## 8. Summary

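This demo showed:
1. Polynomial approximations of ReLU for HE compatibility
2. The PPCM-X model variants, including the per-layer degrees chosen by the adaptive activations
3. HE operation benchmarks across the `fast`, `balanced`, and `accurate` parameter presets
4. The encrypted inference pipeline with a per-layer timing breakdown
5. A comparison between plaintext and encrypted outputs (prediction match rate and output MSE)
6. Side-by-side plaintext and encrypted predictions on sample test images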

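The PPCM-X framework enables privacy-preserving inference. The prediction match rate and output MSE from Section 6 quantify how closely encrypted outputs track the plaintext model; task accuracy itself is not measured in this notebook.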