In [1]:
print("Checking if Python is working")

Checking if Python is working


In [16]:
import numpy as np
from Pyfhel import Pyfhel
print("1. Import Pyfhel class, and numpy for the inputs to encrypt.")

HE = Pyfhel()           # Creating empty Pyfhel object
HE.contextGen(scheme='bfv', n=2**14, t_bits=20)  # Generate context for 'bfv'/'ckks' scheme
                        # The n defines the number of plaintext slots.
                        #  There are many configurable parameters on this step
                        #  More info in Demo_2, Demo_3, and Pyfhel.contextGen()
HE.keyGen()             # Key Generation: generates a pair of public/secret keys

print("2. Context and key setup")
print(HE)

integer1 = np.array([127], dtype=np.int64)
integer2 = np.array([-2], dtype=np.int64)
ctxt1 = HE.encryptInt(integer1) # Encryption makes use of the public key
ctxt2 = HE.encryptInt(integer2) # For integers, encryptInt function is used.
print("3. Integer Encryption, ")
print("    int ",integer1,'-> ctxt1 ', type(ctxt1))
print("    int ",integer2,'-> ctxt2 ', type(ctxt2))

print(ctxt1)
print(ctxt2)


ctxtSum = ctxt1 + ctxt2         # ctxt1 += ctxt2 for inplace operation
ctxtSub = ctxt1 - ctxt2         # ctxt1 -= ctxt2 for inplace operation
ctxtMul = ctxt1 * ctxt2         # ctxt1 *= ctxt2 for inplace operation
print("4. Operating with encrypted integers")
print(f"Sum: {ctxtSum}")
print(f"Sub: {ctxtSub}")
print(f"Mult:{ctxtMul}")


resSum = HE.decryptInt(ctxtSum) # Decryption must use the corresponding function
                                #  decryptInt.
resSub = HE.decryptInt(ctxtSub)
resMul = HE.decryptInt(ctxtMul)
print("#. Decrypting result:")
print("addition: decrypt(ctxt1 + ctxt2) =  ", resSum)
print(" substraction:   decrypt(ctxt1 - ctxt2) =  ", resSub)
print("Multiplication: decrypt(ctxt1 - ctxt2) =  ", resMul)


1. Import Pyfhel class, and numpy for the inputs to encrypt.
2. Context and key setup
<bfv Pyfhel obj at 0x77aeea762790, [pk:Y, sk:Y, rtk:-, rlk:-, contx(n=16384, t=786433, sec=128, qi=[], scale=1.0, )]>
3. Integer Encryption, 
    int  [127] -> ctxt1  <class 'Pyfhel.PyCtxt.PyCtxt'>
    int  [-2] -> ctxt2  <class 'Pyfhel.PyCtxt.PyCtxt'>
<Pyfhel Ciphertext at 0x77aeea7918f0, scheme=bfv, size=2/2, noiseBudget=361>
<Pyfhel Ciphertext at 0x77aeea791940, scheme=bfv, size=2/2, noiseBudget=361>
4. Operating with encrypted integers
Sum: <Pyfhel Ciphertext at 0x77aeea760cc0, scheme=bfv, size=2/2, noiseBudget=360>
Sub: <Pyfhel Ciphertext at 0x77aeea763f60, scheme=bfv, size=2/2, noiseBudget=360>
Mult:<Pyfhel Ciphertext at 0x77aeea793420, scheme=bfv, size=3/3, noiseBudget=328>
#. Decrypting result:
addition: decrypt(ctxt1 + ctxt2) =   [125   0   0 ...   0   0   0]
 substraction:   decrypt(ctxt1 - ctxt2) =   [129   0   0 ...   0   0   0]
Multiplication: decrypt(ctxt1 - ctxt2) =   [-254    0    0 ..

In [18]:
from Pyfhel import Pyfhel
import numpy as np

HE = Pyfhel()
HE.contextGen(scheme="BFV", n=8192, t=65537, sec=128)
HE.keyGen()

plaintext1 = 123
plaintext2 = 456

ciphertext1 = HE.encryptInt(np.array([plaintext1], dtype=np.int64))
ciphertext2 = HE.encryptInt(np.array([plaintext2], dtype=np.int64))

print("Encrypted ciphertexts created successfully.")
ciphertext_sum = ciphertext1 + ciphertext2
ciphertext_sub = ciphertext1 - ciphertext2
ciphertext_mul = ciphertext1 * ciphertext2
print("Homomorphic operations performed successfully.")

Encrypted ciphertexts created successfully.
Homomorphic operations performed successfully.


In [19]:
print("Decrypting results...")
resSum = HE.decryptInt(ciphertext_sum)[0]
resSub = HE.decryptInt(ciphertext_sub)[0]
resMul = HE.decryptInt(ciphertext_mul)[0]
print("Decryption completed.")

Decrypting results...
Decryption completed.


In [17]:
import os, sys, time, pickle, math, random
from pathlib import Path
from typing import List, Dict, Tuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import transforms, datasets
import phe
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from PIL import Image

# Homomorphic Encryption (Paillier)
from phe import paillier

SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# ---- Configure your dataset root here ----
# Example Windows path from your screenshot (adjust exactly to your folder):
DATA_DIR = r"/home/amint/HEP/Brain Tumor MRI/Training"
# If running on Linux/Mac, set e.g.:
# DATA_DIR = "/path/to/Brain Tumor MRI/Training"

assert os.path.exists(DATA_DIR), f"DATA_DIR not found: {DATA_DIR}"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


device(type='cuda')

In [18]:

IMG_SIZE = 128
BATCH_SIZE = 32

tfm = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])

full_ds = datasets.ImageFolder(DATA_DIR, transform=tfm)
classes = full_ds.classes
n_classes = len(classes)

idxs = list(range(len(full_ds)))
labels = [full_ds[i][1] for i in idxs]

# 70/15/15 split
X_tr, X_tmp, y_tr, y_tmp = train_test_split(idxs, labels, test_size=0.30, stratify=labels, random_state=SEED)
X_val, X_te, y_val, y_te = train_test_split(X_tmp, y_tmp, test_size=0.50, stratify=y_tmp, random_state=SEED)

def split_into_hospitals(X_tr, y_tr, k=4):
    # round-robin per class for a balanced-ish split
    per_cls = {}
    for i, y in zip(X_tr, y_tr):
        per_cls.setdefault(y, []).append(i)
    for v in per_cls.values():
        random.shuffle(v)
    shards = [[] for _ in range(k)]
    i = 0
    for _, lst in per_cls.items():
        for idx in lst:
            shards[i % k].append(idx); i += 1
    return shards

shards = split_into_hospitals(X_tr, y_tr, k=4)

def make_loader(indices, shuffle, bs=BATCH_SIZE):
    return DataLoader(Subset(full_ds, indices), batch_size=bs, shuffle=shuffle, num_workers=2, pin_memory=True)

client_loaders = [make_loader(shards[i], True) for i in range(4)]
val_loader = make_loader(X_val, False)
test_loader = make_loader(X_te, False)

[len(s) for s in shards], len(val_loader.dataset), len(test_loader.dataset), classes


([1000, 1000, 999, 999],
 857,
 857,
 ['glioma', 'meningioma', 'notumor', 'pituitary'])

In [20]:
print(X_tr, X_tmp, y_tr, y_tmp)
print(X_val, X_te, y_val, y_te)

[568, 2052, 1861, 4371, 2338, 3733, 2032, 3282, 2596, 5249, 535, 5291, 2911, 4377, 3808, 89, 3008, 250, 2560, 4198, 1778, 2108, 1876, 1158, 375, 415, 2585, 29, 583, 1974, 4771, 5225, 3878, 1919, 2475, 4335, 5499, 4168, 5175, 2843, 3685, 3396, 4730, 2679, 5328, 179, 4138, 4846, 1737, 1766, 428, 284, 497, 5701, 1987, 4153, 5678, 881, 4097, 3471, 5630, 610, 3150, 521, 2759, 3276, 3190, 22, 2996, 1137, 4300, 1975, 321, 1168, 953, 893, 3539, 2871, 5633, 4966, 4416, 2500, 4072, 4721, 1089, 5540, 1572, 3723, 1978, 2706, 1647, 2715, 2408, 3936, 3457, 2883, 346, 1920, 1751, 1702, 2646, 5148, 3259, 5103, 331, 5056, 3378, 3182, 1509, 4366, 367, 3472, 3611, 3239, 5573, 3797, 5346, 2630, 3130, 444, 1854, 2447, 211, 4757, 3510, 4784, 43, 4293, 109, 2618, 290, 3406, 1314, 1872, 3783, 915, 1663, 4878, 5031, 5644, 449, 4203, 2174, 4866, 3266, 4610, 3939, 657, 2184, 5622, 1461, 4675, 2119, 431, 1459, 937, 213, 1481, 108, 1041, 2453, 368, 3614, 5258, 4057, 2920, 4042, 3653, 3018, 3708, 5181, 1730, 4488, 

In [5]:
class TinyCNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3,16,3,padding=1), nn.ReLU(), nn.MaxPool2d(2),   # 128->64
            nn.Conv2d(16,32,3,padding=1), nn.ReLU(), nn.MaxPool2d(2),  # 64->32
            nn.Conv2d(32,64,3,padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # -> (64,1,1)
        )
        self.classifier = nn.Sequential(nn.Flatten(), nn.Linear(64, num_classes))
    def forward(self, x): return self.classifier(self.features(x))

def train_one_epoch(m, loader, opt, crit):
    m.train(); tot=0.0; cor=0; n=0
    for xb, yb in loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        opt.zero_grad(); lg = m(xb); loss = crit(lg, yb); loss.backward(); opt.step()
        tot += loss.item()*xb.size(0); cor += (lg.argmax(1)==yb).sum().item(); n += xb.size(0)
    return tot/n, cor/n

@torch.no_grad()
def evaluate(m, loader, crit=None):
    m.eval(); tot=0.0; cor=0; n=0
    for xb, yb in loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        lg = m(xb)
        if crit is not None: tot += crit(lg, yb).item()*xb.size(0)
        cor += (lg.argmax(1)==yb).sum().item(); n += xb.size(0)
    return (tot/n if crit is not None else None), cor/n

def get_params(m): return [v.detach().cpu().numpy() for _, v in m.state_dict().items()]

def set_params(m, nds):
    sd = m.state_dict(); new_sd = {}
    for (k, _), arr in zip(sd.items(), nds): new_sd[k] = torch.tensor(arr)
    m.load_state_dict(new_sd, strict=True)

def w2v(nds): return np.concatenate([w.reshape(-1) for w in nds]).astype(np.float32)

def v2w(vec, template):
    out, p = [], 0
    for w in template:
        sz = w.size; out.append(vec[p:p+sz].reshape(w.shape).astype(np.float32)); p += sz
    return out


In [6]:
# !pip install Pyfhel

In [7]:
# from Pyfhel import Pyfhel, PyCtxt
import Pyfhel
from Pyfhel import PyCtxt
          # Creating empty Pyfhel object

In [8]:
def he_setup(scheme="CKKS"):
    he = Pyfhel()
    if scheme.upper() == "CKKS":
        # 128-bit-ish security; tune as needed
        he.contextGen(scheme="CKKS", n=2**14, scale=2**30, qi_sizes=[60,40,40,60])
        he.keyGen(); he.relinKeyGen(); he.rotateKeyGen()
        slots = (2**14)//2   # CKKS packs n/2 complex slots
    elif scheme.upper() == "BFV":
        he.contextGen(scheme="BFV", n=2**14, t_bits=20)  # t_bits controls plaintext modulus
        he.keyGen(); he.relinKeyGen(); he.rotateKeyGen()
        slots = 2**14        # BFV batching slots
    else:
        raise ValueError("scheme must be 'CKKS' or 'BFV'")
    return he, slots

def he_encrypt_update(he, vec: np.ndarray, scheme="CKKS", scale=1e6, slots=4096):
    """Pack vector into ciphertext chunks. BFV quantizes to int."""
    cts = []; byte_count = 0
    if scheme.upper() == "CKKS":
        for i in range(0, len(vec), slots):
            chunk = vec[i:i+slots].astype(np.float64)
            c = he.encryptFrac(chunk)          # encode+encrypt
            cts.append(c); byte_count += len(c.to_bytes())
    else:
        ints = np.round(vec*scale).astype(np.int64)
        for i in range(0, len(ints), slots):
            c = he.encryptInt(ints[i:i+slots]) # encode (batch)+encrypt
            cts.append(c); byte_count += len(c.to_bytes())
    return cts, byte_count

def he_sum_ciphertexts(ct_lists: List[List[PyCtxt]]):
    """Element-wise sum across clients (assumes equal chunking)."""
    agg = [ct_lists[0][j] for j in range(len(ct_lists[0]))]
    for k in range(1, len(ct_lists)):
        for j in range(len(agg)): agg[j] = agg[j] + ct_lists[k][j]
    return agg

def he_decrypt_sum(he, agg_cts: List[PyCtxt], scheme="CKKS", scale=1e6, total_len=None):
    outs = []
    if scheme.upper() == "CKKS":
        for c in agg_cts: outs.append(he.decryptFrac(c))
        vec = np.concatenate(outs).astype(np.float64)
    else:
        for c in agg_cts: outs.append(he.decryptInt(c).astype(np.int64))
        vec = (np.concatenate(outs).astype(np.float64))/scale
    if total_len is not None and len(vec) > total_len: vec = vec[:total_len]
    return vec.astype(np.float32)


In [9]:
# Device + quick sanity check (run once)
import time, numpy as np, pandas as pd
import torch
import Pyfhel
from Pyfhel import Pyfhel, PyCtxt  # ensure the CLASS is imported (not the module)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using", DEVICE)

# Optional: check that earlier cells were run
needed = ["client_loaders","test_loader","TinyCNN","evaluate","train_one_epoch",
          "get_params","set_params","w2v","v2w","he_setup","he_encrypt_update",
          "he_sum","he_decrypt","n_classes"]
missing = [n for n in needed if n not in globals()]
print("Missing definitions:", missing)  # should print []


Using cuda
Missing definitions: []


In [10]:
# === Homomorphic Encryption helpers (CKKS/BFV) — run once ===
from Pyfhel import Pyfhel, PyCtxt
import numpy as np

def he_setup(scheme="CKKS", n=2**14, scale=2**30, t_bits=20, qi_sizes=[60,40,40,60]):
    """Create a Pyfhel context + keys. Returns (he, slots)."""
    he = Pyfhel()
    if scheme.upper() == "CKKS":
        he.contextGen(scheme="CKKS", n=n, scale=scale, qi_sizes=qi_sizes)
        he.keyGen()  # public + secret
        slots = n // 2            # CKKS packs n/2 complex slots
    elif scheme.upper() == "BFV":
        he.contextGen(scheme="BFV", n=n, t_bits=t_bits)
        he.keyGen()
        slots = n                 # BFV batching slots ~ n
    else:
        raise ValueError("scheme must be 'CKKS' or 'BFV'")
    return he, slots

def he_encrypt_update(he, vec: np.ndarray, scheme="CKKS", scale=1e6, slots=4096):
    """Chunk-pack and encrypt a model-update vector. Returns (ciphertexts, total_bytes)."""
    cts = []; bytes_ = 0
    if scheme.upper() == "CKKS":
        for i in range(0, len(vec), slots):
            chunk = vec[i:i+slots].astype(np.float64)
            c = he.encryptFrac(chunk)         # encode + encrypt
            cts.append(c); bytes_ += len(c.to_bytes())
    else:  # BFV: quantize to ints
        ints = np.round(vec * scale).astype(np.int64)
        for i in range(0, len(ints), slots):
            c = he.encryptInt(ints[i:i+slots])
            cts.append(c); bytes_ += len(c.to_bytes())
    return cts, bytes_

def he_sum(ct_lists):
    """Element-wise homomorphic sum across clients’ ciphertext lists."""
    assert len(ct_lists) > 0, "he_sum: empty ciphertext list"
    agg = [ct_lists[0][j] for j in range(len(ct_lists[0]))]
    for k in range(1, len(ct_lists)):
        for j in range(len(agg)):
            agg[j] = agg[j] + ct_lists[k][j]
    return agg

def he_decrypt(he, agg_cts, scheme="CKKS", scale=1e6, total_len=None):
    """Decrypt aggregated ciphertexts back to a float vector."""
    outs = []
    if scheme.upper() == "CKKS":
        for c in agg_cts:
            outs.append(he.decryptFrac(c))
        vec = np.concatenate(outs).astype(np.float64)
    else:  # BFV → ints → dequantize
        ints = []
        for c in agg_cts:
            ints.append(he.decryptInt(c).astype(np.int64))
        vec = (np.concatenate(ints).astype(np.float64)) / scale
    if total_len is not None and len(vec) > total_len:
        vec = vec[:total_len]
    return vec.astype(np.float32)


In [12]:
from Pyfhel import Pyfhel, PyCtxt

# CKKS setup: deeper qi_sizes to avoid the warning; still addition-only
def he_setup_ckks(n=2**14, scale=2**30, qi_sizes=[60, 40, 40, 40, 60]):
    he = Pyfhel()
    he.contextGen(scheme="CKKS", n=n, scale=scale, qi_sizes=qi_sizes)
    he.keyGen()  # public + secret
    slots = n // 2
    return he, slots

# BFV setup (no rescaling in BFV)
def he_setup_bfv(n=2**14, t_bits=20):
    he = Pyfhel()
    he.contextGen(scheme="BFV", n=n, t_bits=t_bits)
    he.keyGen()
    slots = n
    return he, slots

# Choose based on HE_SCHEME
if HE_SCHEME.upper() == "CKKS":
    he_srv, slots = he_setup_ckks()
else:
    he_srv, slots = he_setup_bfv()

# Export ONLY public materials (no relin/rotate needed for sum)
pub_ctx = he_srv.to_bytes_context()
pub_pk  = he_srv.to_bytes_public_key()

def make_client_he_from_server():
    he = Pyfhel()
    he.from_bytes_context(pub_ctx)
    he.from_bytes_public_key(pub_pk)
    return he


  he.contextGen(scheme="CKKS", n=n, scale=scale, qi_sizes=qi_sizes)


In [15]:
# import numpy as np, pandas as pd, time, torch, torch.nn as nn
# import Pyfhel
# from Pyfhel import Pyfhel
# # from pyfhel import Pyfhel, PyCtxt

# HE_SCHEME = "CKKS"   # or "BFV"
# BFV_SCALE = 1e6
# ROUNDS = 5; LOCAL_EPOCHS = 1; LR = 1e-3

# # ---------------------------
# # 1) Server HE context + keys
# # ---------------------------
# he_srv, slots = he_setup(HE_SCHEME)  # has BOTH public+secret keys

# # Export a "public-only" bundle for clients (no secret key)
# pub_ctx  = he_srv.to_bytes_context()
# pub_pk   = he_srv.to_bytes_public_key()
# relin_k  = he_srv.to_bytes_relin_key()
# rot_k    = he_srv.to_bytes_rotate_key()

# def make_client_he_from_server():
#     """Client loads ONLY public materials; can encrypt but cannot decrypt."""
#     he = Pyfhel()
#     he.from_bytes_context(pub_ctx)
#     he.from_bytes_public_key(pub_pk)
#     # relin/rotate not strictly needed for sum, but keep for completeness
#     if relin_k is not None:
#         he.from_bytes_relin_key(relin_k)
#     if rot_k is not None:
#         he.from_bytes_rotate_key(rot_k)
#     return he

# # ---------------------------
# # 2) Global model state
# # ---------------------------
# g_model = TinyCNN(n_classes).to(DEVICE)
# template = get_params(g_model)
# g_vec = w2v(template)
# crit = nn.CrossEntropyLoss()

# rows = []

# for rnd in range(1, ROUNDS+1):
#     enc_updates = []
#     client_times = []
#     uplink = 0

#     # ---------------------------
#     # 3) Four hospitals (clients)
#     # ---------------------------
#     for ci in range(4):
#         # Init client model from current global
#         m = TinyCNN(n_classes).to(DEVICE)
#         set_params(m, v2w(g_vec, template))
#         opt = torch.optim.Adam(m.parameters(), lr=LR)

#         t0 = time.perf_counter()
#         for _ in range(LOCAL_EPOCHS):
#             train_one_epoch(m, client_loaders[ci], opt, crit)
#         t1 = time.perf_counter()

#         # Compute local update
#         l_vec = w2v(get_params(m))
#         upd = (l_vec - g_vec).astype(np.float32)

#         # Encrypt update with SERVER'S PUBLIC KEY
#         he_cli = make_client_he_from_server()
#         cts, bytes_ = he_encrypt_update(
#             he_cli, upd, scheme=HE_SCHEME, scale=BFV_SCALE, slots=min(slots, 8192)
#         )
#         enc_updates.append(cts)
#         client_times.append(t1 - t0)
#         uplink += bytes_

#     # ---------------------------
#     # 4) Server: aggregate + decrypt
#     # ---------------------------
#     s0 = time.perf_counter()
#     agg_cts = he_sum(enc_updates)  # homomorphic sum (same chunking across clients)
#     upd_sum = he_decrypt(
#         he_srv, agg_cts, scheme=HE_SCHEME, scale=BFV_SCALE, total_len=len(g_vec)
#     )
#     upd_avg = upd_sum / 4.0
#     g_vec = (g_vec + upd_avg).astype(np.float32)
#     s1 = time.perf_counter()

#     # ---------------------------
#     # 5) Evaluate and account comm
#     # ---------------------------
#     set_params(g_model, v2w(g_vec, template))
#     loss, acc = evaluate(g_model, test_loader, crit)

#     # Approx downlink: broadcast plaintext global weights to 4 clients
#     down_per = sum(arr.nbytes for arr in get_params(g_model))
#     down_bytes = down_per * 4

#     row = {
#         "round": rnd,
#         "accuracy": float(acc),
#         "client_latency_avg_s": float(np.mean(client_times)),
#         "server_latency_s": float(s1 - s0),
#         "uplink_bytes": float(uplink),
#         "downlink_bytes": float(down_bytes),
#         "total_comm_bytes": float(uplink + down_bytes),
#         "scheme": HE_SCHEME,
#     }
#     print(
#         f"[Round {rnd}] acc={acc:.3f} | client={row['client_latency_avg_s']:.2f}s "
#         f"| server={row['server_latency_s']:.2f}s | uplink={uplink/1e6:.2f} MB | "
#         f"down={down_bytes/1e6:.2f} MB | total={(row['total_comm_bytes']/1e6):.2f} MB"
#     )
#     rows.append(row)

# df = pd.DataFrame(rows)
# df.to_csv(f"pure_python_he_metrics_{HE_SCHEME.lower()}.csv", index=False)
# df


In [None]:
# ---- FIXED federated loop (CKKS/BFV) with one server keypair ----
import time, numpy as np, pandas as pd, torch, torch.nn as nn
from Pyfhel import Pyfhel, PyCtxt   # <-- make sure this is the CLASS, not the module

# # ----- user-tunable constants -----
# NUM_CLIENTS   = 3
# LOCAL_EPOCHS  = 10
# ROUNDS        = 10
# SCHEME        = "CKKS"   # or "BFV"
# BFV_SCALE     = 1e6      # used only if SCHEME == "BFV"

HE_SCHEME = "CKKS"   # or "BFV"
BFV_SCALE = 1e6
ROUNDS = 5; LOCAL_EPOCHS = 1; LR = 1e-3

# ---------------------------
# HE helpers (addition-only)
# ---------------------------
def he_setup_ckks(n=2**14, scale=2**20, qi_sizes=[60,40,40,40,60]):
    he = Pyfhel()
    he.contextGen(scheme="CKKS", n=n, scale=scale, qi_sizes=qi_sizes)
    he.keyGen()                 # public + secret
    return he, n//2             # slots

def he_setup_bfv(n=2**14, t_bits=20):
    he = Pyfhel()
    he.contextGen(scheme="BFV", n=n, t_bits=t_bits)
    he.keyGen()
    return he, n

def he_encrypt_update(he, vec: np.ndarray, scheme="CKKS", scale=1e6, slots=4096):
    cts=[]; bytes_=0
    if scheme.upper()=="CKKS":
        for i in range(0, len(vec), slots):
            c = he.encryptFrac(vec[i:i+slots].astype(np.float64))
            cts.append(c); bytes_ += len(c.to_bytes())
    else:  # BFV: quantize to ints
        ints = np.round(vec*scale).astype(np.int64)
        for i in range(0, len(ints), slots):
            c = he.encryptInt(ints[i:i+slots])
            cts.append(c); bytes_ += len(c.to_bytes())
    return cts, bytes_

def he_sum(ct_lists):
    assert len(ct_lists)>0, "he_sum: empty list"
    agg = [ct_lists[0][j] for j in range(len(ct_lists[0]))]
    for k in range(1, len(ct_lists)):
        for j in range(len(agg)):
            agg[j] = agg[j] + ct_lists[k][j]
    return agg

def he_decrypt(he, agg_cts, scheme="CKKS", scale=1e6, total_len=None):
    outs=[]
    if scheme.upper()=="CKKS":
        for c in agg_cts: outs.append(he.decryptFrac(c))
        vec = np.concatenate(outs).astype(np.float64)
    else:
        ints=[]
        for c in agg_cts: ints.append(he.decryptInt(c).astype(np.int64))
        vec = (np.concatenate(ints).astype(np.float64))/scale
    if total_len is not None and len(vec)>total_len:
        vec = vec[:total_len]
    return vec.astype(np.float32)

# ---------------------------
# 1) Server HE context + PUBLIC bundle
# ---------------------------
if HE_SCHEME.upper()=="CKKS":
    he_srv, slots = he_setup_ckks()
else:
    he_srv, slots = he_setup_bfv()

pub_ctx = he_srv.to_bytes_context()
pub_pk  = he_srv.to_bytes_public_key()
# NOTE: no relin/rotate keys needed for addition-only; do NOT call to_bytes_relin_key/rotate_key

def make_client_he_from_server():
    he = Pyfhel()
    he.from_bytes_context(pub_ctx)
    he.from_bytes_public_key(pub_pk)
    return he

# ---------------------------
# 2) Global model state
# ---------------------------
g_model = TinyCNN(n_classes).to(DEVICE)
template = get_params(g_model)
g_vec = w2v(template)
crit = nn.CrossEntropyLoss()

rows = []

# ---------------------------
# 3) Train for ROUNDS with 4 clients
# ---------------------------
for rnd in range(1, ROUNDS+1):
    enc_updates=[]; client_times=[]; uplink=0

    for ci in range(4):
        # Start client from global
        m = TinyCNN(n_classes).to(DEVICE)
        set_params(m, v2w(g_vec, template))
        opt = torch.optim.Adam(m.parameters(), lr=LR)

        t0 = time.perf_counter()
        for _ in range(LOCAL_EPOCHS):
            train_one_epoch(m, client_loaders[ci], opt, crit)
        t1 = time.perf_counter()

        # Local update
        l_vec = w2v(get_params(m))
        upd = (l_vec - g_vec).astype(np.float32)

        # Encrypt with server PUBLIC key
        he_cli = make_client_he_from_server()
        cts, bytes_ = he_encrypt_update(he_cli, upd, scheme=HE_SCHEME, scale=BFV_SCALE, slots=min(slots, 8192))
        enc_updates.append(cts)
        client_times.append(t1 - t0)
        uplink += bytes_

    # Server: homomorphic sum -> decrypt -> average
    s0 = time.perf_counter()
    agg_cts = he_sum(enc_updates)
    upd_sum = he_decrypt(he_srv, agg_cts, scheme=HE_SCHEME, scale=BFV_SCALE, total_len=len(g_vec))
    upd_avg = upd_sum / 4.0
    g_vec = (g_vec + upd_avg).astype(np.float32)
    s1 = time.perf_counter()

    # Evaluate and comm accounting
    set_params(g_model, v2w(g_vec, template))
    loss, acc = evaluate(g_model, test_loader, crit)

    down_per = sum(arr.nbytes for arr in get_params(g_model))  # broadcast plaintext weights
    down_bytes = down_per * 4

    row = {
        "round": rnd,
        "accuracy": float(acc),
        "client_latency_avg_s": float(np.mean(client_times)),
        "server_latency_s": float(s1 - s0),
        "uplink_bytes": float(uplink),
        "downlink_bytes": float(down_bytes),
        "total_comm_bytes": float(uplink + down_bytes),
        "scheme": HE_SCHEME,
    }
    print(f"[Round {rnd}] acc={acc:.3f} | client={row['client_latency_avg_s']:.2f}s | server={row['server_latency_s']:.2f}s "
          f"| uplink={uplink/1e6:.2f} MB | down={down_bytes/1e6:.2f} MB | total={(row['total_comm_bytes']/1e6):.2f} MB")
    rows.append(row)

df = pd.DataFrame(rows)
df.to_csv(f"pure_python_he_metrics_{HE_SCHEME.lower()}.csv", index=False)
df


  he.contextGen(scheme="CKKS", n=n, scale=scale, qi_sizes=qi_sizes)


[Round 1] acc=0.413 | client=0.90s | server=0.01s | uplink=12.58 MB | down=0.38 MB | total=12.97 MB
[Round 2] acc=0.529 | client=0.79s | server=0.01s | uplink=12.58 MB | down=0.38 MB | total=12.97 MB
[Round 3] acc=0.609 | client=0.79s | server=0.01s | uplink=12.58 MB | down=0.38 MB | total=12.97 MB
[Round 4] acc=0.639 | client=0.81s | server=0.01s | uplink=12.58 MB | down=0.38 MB | total=12.97 MB
[Round 5] acc=0.611 | client=0.80s | server=0.01s | uplink=12.58 MB | down=0.38 MB | total=12.97 MB


Unnamed: 0,round,accuracy,client_latency_avg_s,server_latency_s,uplink_bytes,downlink_bytes,total_comm_bytes,scheme
0,1,0.413069,0.899206,0.009559,12584268.0,381504.0,12965772.0,CKKS
1,2,0.528588,0.793544,0.010029,12584268.0,381504.0,12965772.0,CKKS
2,3,0.609102,0.791728,0.00952,12584268.0,381504.0,12965772.0,CKKS
3,4,0.63944,0.810174,0.009441,12584268.0,381504.0,12965772.0,CKKS
4,5,0.611435,0.802996,0.008572,12584268.0,381504.0,12965772.0,CKKS
