In [70]:
import tenseal as ts
def get_context():
    context = ts.context(
        ts.SCHEME_TYPE.CKKS,
        poly_modulus_degree=8192,
        coeff_mod_bit_sizes=[60, 40, 60]
    )
    context.generate_galois_keys()
    context.global_scale = 2**40
    return context

context = get_context()

In [71]:
import flwr as fl
import numpy as np

# byteparam is of Paramter with tensor of size 1 holding a bytes object, e.g., Parameters(tensors=[bytes_obj], tensor_type="numpy.ndarray")
# def get_params_encrypted(model, context):
#     import tenseal as ts
#     import numpy as np

#     params = []
#     for _, val in model.state_dict().items():
#         np_val = val.cpu().numpy().flatten().astype(np.float32)
#         ckks_vector = ts.ckks_vector(context, np_val)
#         serialized = ckks_vector.serialize()
#         params.append(serialized)  # List of bytes

#     # This will be passed to Flower return and become part of fl.common.Parameters
#     return params

def encrypt_params_np_array(np_array, context):
    import tenseal as ts

    ckks_vector = ts.ckks_vector(context, np_array.astype(np.float32))
    serialized = ckks_vector.serialize()
    print(serialized)
    return [serialized]  # return list of 1 encrypted byte object

def byteparam_to_ndarrays(byteparam: fl.common.Parameters) -> fl.common.NDArrays:
    ndarrays = fl.common.parameters_to_ndarrays(byteparam)
    return np.frombuffer(ndarrays[0], dtype=np.float32)
    

def ndarrays_to_byteparam(ndarrays: fl.common.NDArrays) -> fl.common.Parameters:
    serialized_ndarrays = bytes(np.array(ndarrays, dtype=np.float32))
    return fl.common.ndarrays_to_parameters([serialized_ndarrays])

In [72]:
import flwr as fl
import numpy as np
from flwr.common import Parameters


# Load the shared CKKS context with the secret key
class FlowerClient(fl.client.NumPyClient):

    def __init__(self, context):
        super().__init__()
        self.context = context  # Save encryption context

    def fit(self, parameters, config):
        # 1. Convert byte to numpy
        param_np = np.frombuffer(parameters[0], dtype=np.float32)

        # 2. Simulate local update
        local_params = param_np + np.array([0.01] * 10)
        print('Local params (before encryption):', local_params)

        # 3. Encrypt
        encrypted_params = encrypt_params_np_array(local_params, self.context)

        # 4. Return encrypted params
        return encrypted_params, 10, {}

def client_fn(ctx: fl.common.Context) -> fl.client.Client:
    context = get_context()  # You already have this defined
    return FlowerClient(context).to_client()

In [73]:
import flwr as fl


class BytesStrategy(fl.server.strategy.FedAvg):
    def __init__(self, context, **kwargs):
        super().__init__(**kwargs)
        self.context = context  
    def aggregate_fit(self, server_round, results, failures):
        
        if failures:
            print(f"[Round {server_round}] {len(failures)} client(s) failed:")
            for i, failure in enumerate(failures):
                print(f"  Failure {i+1}: {repr(failure)}")
        

        all_arrs = []
        encrypted_vecs =[]
        for client, fit_res in results:
            # Deserialize encrypted CKKS vector from byte tensor
            serialized = fit_res.parameters.tensors[0]
            ckks_vec = ts.ckks_vector_from(self.context, serialized)
            encrypted_vecs.append(ckks_vec)

        avg_vec = np.average(np.array(all_arrs), axis=0)+np.array([0.05]*10)
        print('New global params', avg_vec)

        return ndarrays_to_byteparam(avg_vec), {}
    

# Workflow:
# 1) Set initial params to [.2, .2, ...] (10 elements), which is converted into Parameters with 1 tensor being the byte representation of [.2, .2, ...]
# 2) Each client receives the byte representation of params in "fit" function. Assuming that there's only one tensor, it accesses parameters[0] 
# to get this byte representation and convert it to NDArray (numpy array), adds each element with 0.01, converts the resulting numpy array into a bytes object
# return an array of this bytes object
# 3) Server receives Parameters from each client in aggregate_fit function. Each Parameters object should hold one tensor storing a bytes object representing a numpy array.
# It then reconstructs the original arrays and taking the average across all arrays and adds the average vector with 0.05 on each element, finally convert the resulting vector
# into a Parameters object where each object holds one tensor storing a bytes object.
init_params = [.2]*10
init_params_bytes = bytes(np.array(init_params, dtype=np.float32))
init_params = fl.common.ndarrays_to_parameters([init_params_bytes])
context = get_context()
strategy = BytesStrategy(
    context = context,
    fraction_fit=1,
    fraction_evaluate=1,
    initial_parameters = init_params   
)


def server_fn(ctx: fl.common.Context) -> fl.server.ServerAppComponents:
    # Configure the server for 5 rounds of training
    config = fl.server.ServerConfig(num_rounds=5)

    return fl.server.ServerAppComponents(strategy=strategy, config=config)


# Create the ServerApp
server = fl.server.ServerApp(server_fn=server_fn)
client = fl.client.ClientApp(client_fn=client_fn)

backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 0.0}}

fl.simulation.run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=3,
    backend_config=backend_config,
)


[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=5, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[91mERROR [0m:     ServerApp thread raised an exception: failed to parse CKKS stream
[91mERROR [0m:     Traceback (most recent call last):
  File "/Users/creamy/Research/.venv/lib/python3.9/site-packages/flwr/simulation/run_simulation.py", line 268, in server_th_with_start_checks
    updated_context = _run(
  File "/Users/creamy/Research/.venv/lib/python3.9/site-packages/flwr/server/run_serverapp.py", line 62, in run
    server_app(grid=grid, context=co

[36m(ClientAppActor pid=7194)[0m Local params (before encryption): [0.21 0.21 0.21 0.21 0.21 0.21 0.21 0.21 0.21 0.21]
[36m(ClientAppActor pid=7194)[0m b'\n\x01\n\x12\x81\xaf\x0e^\xa1\x10\x04\x01\x02\x00\x00\x81\x97\x03\x00\x00\x00\x00\x00(\xb5/\xfd\xa0a\x00\x04\x00\x1c\\\x0en\xfd\x9f\xdbr-\x10\xa0\xbcf\x962h\xfa#\xca\x9b\xde\x8cCcY\n\xf8s\x12\x08U+N\xcd<\xec\x83Kux7Rc\xec\x18\xba\xbb\xbb\xbb\xbb\xbb\xbb{\xa7f\x84\xa2\x84\x1ea_XE\xb4\xa6H\x19\xf4\xa5BQ(M\n,8\xd26\x97\x0e!\x8f\xbc\xa0D\x8d\xf3\x1b \xd5\xd4DM\x9a\xa8D\xe7\xa0\xa5-\xdaUK\x02k\x03\xc3\x99\x9cU\xd0\x93\x1b\x1ca^\xaaU\xb8BLd~\x08\xcc\xac5\x1eTaT\x04F\x93\xa8\xe2&\xd5\xdf\xa9\x18\xafT\x07\x91K\x98\xbe\xa0m\r\x89\xa9R\t+\x03Z\x9d\xa2\xc5\xc9\xa2\xceR\x82\xd5ItQ\x014`\x0eP\r\xc9\x90\x01`\xd3T\xa3\xda\x83\xf3G\x92L\xc6!\xd3l\x1a\x05\x06,A\xaa\x9a\x11\x1a\x8dJx\xa2\x98\xbe\x12\xbc\x04\x92\xdb\xa3\xc7e\x9d\xc7\x19.b\x15LT7\nmI\xea\x98\xb8\x99\x92\xf5\x90\x14\x89O\xe4\x15\xb5G\x19\x06\x9eb\x1f\x12y\x13tgd\x06H\xb4a\xbb\xc6\xa3\



RuntimeError: Exception in ServerApp thread

In [None]:
import tenseal as ts
import numpy as np

# === Dummy context
def get_context():
    context = ts.context(
        ts.SCHEME_TYPE.CKKS,
        poly_modulus_degree=8192,
        coeff_mod_bit_sizes=[60, 40, 60]
    )
    context.generate_galois_keys()
    context.global_scale = 2**40
    return context

context = get_context()

# === Simulated model weight
dummy_weights = np.array([0.2]*10, dtype=np.float32)
print(dummy_weights)

# === Encrypt like get_params_encrypted would
ckks_vector = ts.ckks_vector(context, dummy_weights)
serialized = ckks_vector.serialize()

# === Print test
print("Serialized CKKS Vector (first 50 bytes):", serialized[:50])
print("Type:", type(serialized))

# === Optionally test decryption
vec = ts.ckks_vector_from(context, serialized)
print("Decrypted values:", vec.decrypt())

[0.2 0.2 0.2 0.2 0.2 0.2 0.2 0.2 0.2 0.2]
Serialized CKKS Vector (first 50 bytes): b'\n\x01\n\x12\xde\xae\x0e^\xa1\x10\x04\x01\x02\x00\x00^\x97\x03\x00\x00\x00\x00\x00(\xb5/\xfd\xa0a\x00\x04\x00\xf4[\x0en\xfd_\xdar.\x10\xa0\xbcf1\xce\xc4\xc1\x02'
Type: <class 'bytes'>
Decrypted values: [0.2000000030608814, 0.2000000027310915, 0.20000000213226707, 0.20000000307115887, 0.20000000301534662, 0.20000000400282722, 0.2000000007019671, 0.2000000029790933, 0.20000000305733795, 0.2000000036828271]
