# In [None]:
import tenseal as ts
import torch
from model import SimpleNN


def encrypted_inference(model, X_samples, scaler):
    """Encrypt each sample with CKKS, decrypt it again, and classify it with ``model``.

    Each sample is encrypted into a TenSEAL CKKS vector and immediately
    decrypted; the model then runs on the decrypted *plaintext* tensor. No
    computation is performed on ciphertexts, so this demonstrates the CKKS
    encrypt/decrypt round trip rather than true homomorphic inference.

    Args:
        model: A torch module called on a ``(1, n_features)`` tensor whose
            output is reduced with ``argmax(dim=1)`` — presumably a
            ``SimpleNN`` returning class scores; TODO confirm.
        X_samples: Iterable of 1-D feature vectors (list, NumPy array or
            torch tensor).
        scaler: Currently unused. NOTE(review): if the model was trained on
            scaled features, samples presumably need ``scaler.transform``
            before encryption — confirm against the caller.

    Returns:
        list[int]: The predicted class index for each sample, in input
        order. Each prediction is also printed, as before.
    """
    context = ts.context(
        ts.SCHEME_TYPE.CKKS,
        poly_modulus_degree=8192,
        coeff_mod_bit_sizes=[60, 40, 40, 60],
    )
    # Not needed by the encrypt/decrypt round trip below; presumably kept so
    # the context is ready for ciphertext rotations (e.g. dot products).
    context.generate_galois_keys()
    context.global_scale = 2**40

    model.eval()
    predictions = []
    # Pure inference: don't build an autograd graph for every sample.
    with torch.no_grad():
        for i, sample in enumerate(X_samples):
            # Normalise to a flat list of Python floats so the encryption call
            # does not depend on the sample's container type.
            values = torch.as_tensor(sample, dtype=torch.float64).flatten().tolist()
            encrypted_sample = ts.ckks_vector(context, values)
            # CKKS is an approximate scheme: decrypted values carry small noise.
            decrypted = torch.tensor(encrypted_sample.decrypt())
            output = model(decrypted.unsqueeze(0))
            pred = torch.argmax(output, dim=1).item()
            predictions.append(pred)
            print(f"Sample {i+1} ➜ Predicted: {pred}")
    return predictions
