# Install and Import Dependencies


In [None]:
!pip install tenseal syft pennylane
!pip install protobuf==3.20.3

In [None]:
import os

os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
import syft as sy
import pickle
import time
from collections import OrderedDict
from typing import List, Tuple, Dict, Optional, Callable, Union, cast
import tenseal as ts
from io import BytesIO
import numpy as np
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import torch
from torch import nn
import torch.nn.functional as F
import syft as sy
from logging import WARNING
import pennylane as qml
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, confusion_matrix
import seaborn as sn
import pandas as pd
from functools import reduce

# Utility Functions


In [None]:
def choice_device(device):
    """Resolve a requested device name to one that torch can actually use.

    Args:
        device: Requested device. The string "cpu" forces the CPU; any other
            value asks for the best accelerator available.

    Returns:
        "cuda:0" when CUDA is available, otherwise "mps" when the MPS backend
        is both available and built, otherwise "cpu".
    """
    if device == "cpu":
        # An explicit CPU request needs no backend probing at all.
        return "cpu"
    if torch.cuda.is_available():
        return "cuda:0"
    # torch.backends.mps does not exist on older torch builds; treat a missing
    # attribute as "no MPS" instead of raising AttributeError.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available() and mps.is_built():
        return "mps"
    return "cpu"


def classes_string(name_dataset):
    """Return the tuple of class names for a known dataset.

    Args:
        name_dataset: Dataset identifier; "cifar" and "MRI" are recognised.

    Returns:
        The class-name tuple for the dataset, or an empty tuple (after
        printing a warning) when the dataset is not recognised.
    """
    known_datasets = (
        (
            "cifar",
            (
                "plane",
                "car",
                "bird",
                "cat",
                "deer",
                "dog",
                "frog",
                "horse",
                "ship",
                "truck",
            ),
        ),
        ("MRI", ("glioma", "meningioma", "notumor", "pituitary")),
    )
    for dataset_key, labels in known_datasets:
        if name_dataset == dataset_key:
            return labels
    print("Warning: unspecified dataset")
    return ()


def save_matrix(y_true, y_pred, path, classes):
    """Render a normalised confusion matrix as a heatmap and save it.

    Args:
        y_true: Iterable of integer class indices (ground truth).
        y_pred: Iterable of integer class indices (predictions).
        path: Destination passed to ``plt.savefig``.
        classes: Sequence of class names, indexed by the integer labels.
    """
    # Translate integer labels to names so the matrix axes are readable.
    true_names = [classes[idx] for idx in y_true]
    pred_names = [classes[idx] for idx in y_pred]

    # normalize="all" is forwarded to scikit-learn's confusion_matrix.
    matrix = confusion_matrix(true_names, pred_names, labels=classes, normalize="all")
    axis_labels = list(classes)
    frame = pd.DataFrame(np.round(matrix, 2), index=axis_labels, columns=list(classes))

    plt.figure(figsize=(12, 7))
    sn.heatmap(frame, annot=True)
    plt.xlabel("Predicted label", fontsize=13)
    plt.ylabel("True label", fontsize=13)
    plt.title("Confusion Matrix", fontsize=15)
    plt.savefig(path)
    plt.close()


def save_roc(targets, y_proba, path, nbr_classes):
    """Plot one-vs-rest ROC curves (per class, micro and macro) and save them.

    Args:
        targets: Sequence of integer class indices.
        y_proba: 2-D array indexed as ``y_proba[:, class]`` holding the
            per-class scores for each sample.
        path: Destination passed to ``plt.savefig``.
        nbr_classes: Number of classes (columns read from ``y_proba``).
    """
    class_ids = range(nbr_classes)

    # One-hot encode the targets so each class gets a binary problem.
    y_true = np.zeros(shape=(len(targets), nbr_classes))
    for row, label in enumerate(targets):
        y_true[row, label] = 1

    fpr, tpr, roc_auc = {}, {}, {}

    # Per-class curves.
    for cls in class_ids:
        fpr[cls], tpr[cls], _ = roc_curve(y_true[:, cls], y_proba[:, cls])
        roc_auc[cls] = auc(fpr[cls], tpr[cls])

    # Micro average: pool every (sample, class) decision together.
    fpr["micro"], tpr["micro"], _ = roc_curve(y_true.ravel(), y_proba.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    # Macro average: interpolate each class curve on the union of FPR points,
    # then average the interpolated TPR values.
    all_fpr = np.unique(np.concatenate([fpr[cls] for cls in class_ids]))
    mean_tpr = np.zeros_like(all_fpr)
    for cls in class_ids:
        mean_tpr += np.interp(all_fpr, fpr[cls], tpr[cls])
    mean_tpr /= nbr_classes
    fpr["macro"] = all_fpr
    tpr["macro"] = mean_tpr
    roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

    plt.figure()
    for average, colour in (("micro", "deeppink"), ("macro", "navy")):
        plt.plot(
            fpr[average],
            tpr[average],
            label=f"{average}-average ROC curve (area = {roc_auc[average]:.2f})",
            color=colour,
            linestyle=":",
            linewidth=4,
        )

    line_width = 2
    for cls in class_ids:
        plt.plot(
            fpr[cls],
            tpr[cls],
            lw=line_width,
            label=f"ROC curve of class {cls} (area = {roc_auc[cls]:.2f})",
        )

    # Diagonal reference line.
    plt.plot([0, 1], [0, 1], "k--", lw=line_width, label="Worst case")
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("Receiver operating characteristic (ROC) Curve OvR")
    plt.legend(loc="lower right")
    plt.savefig(path)
    plt.close()


def save_graphs(path_save, local_epoch, results, end_file=""):
    """Save the accuracy and loss curves of one training run.

    Args:
        path_save: Directory that receives the figures; created if missing.
        local_epoch: Number of epochs, used to build the x axis
            ``0 .. local_epoch - 1``.
        results: Mapping with the keys "train_acc", "val_acc", "train_loss"
            and "val_loss", each holding one value per epoch.
        end_file: Suffix appended to each figure's base file name.
    """
    os.makedirs(path_save, exist_ok=True)
    print("Saving graphs in ", path_save)
    epochs = list(range(local_epoch))
    curve_specs = (
        # (train key, val key, y label, curve labels, title, file stem)
        (
            "train_acc",
            "val_acc",
            "Accuracy (%)",
            ["Training accuracy", "Validation accuracy"],
            "Accuracy curves",
            "Accuracy_curves",
        ),
        (
            "train_loss",
            "val_loss",
            "Loss",
            ["Training loss", "Validation loss"],
            "Loss curves",
            "Loss_curves",
        ),
    )
    for train_key, val_key, y_label, curve_labels, title, stem in curve_specs:
        plot_graph(
            [epochs, epochs],
            [results[train_key], results[val_key]],
            "Epochs",
            y_label,
            curve_labels,
            title,
            # os.path.join keeps the figure inside path_save even when the
            # caller omits the trailing separator; plain concatenation wrote
            # e.g. "results/FLAccuracy_curves" next to the directory instead.
            os.path.join(path_save, stem + end_file),
        )


def plot_graph(
    list_xplot, list_yplot, x_label, y_label, curve_labels, title, path=None
):
    """Draw several curves on one figure and optionally save it.

    Args:
        list_xplot: Sequence of x-value sequences, one per curve.
        list_yplot: Sequence of y-value sequences, one per curve.
        x_label: Label of the x axis.
        y_label: Label of the y axis.
        curve_labels: Legend labels; their count decides how many curves
            are drawn.
        title: Figure title.
        path: Destination for ``plt.savefig``; nothing is saved when falsy.
    """
    line_width = 2
    plt.figure()
    for idx, curve_label in enumerate(curve_labels):
        plt.plot(list_xplot[idx], list_yplot[idx], lw=line_width, label=curve_label)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.title(title)
    # A legend is only meaningful when at least one curve is labelled.
    if curve_labels:
        plt.legend(loc="lower right")
    if path:
        plt.savefig(path)
    plt.close()


def get_parameters2(net, context_client=None) -> List[np.ndarray]:
    """Extract the model parameters, optionally passing them through `crypte`.

    Args:
        net: Model exposing ``state_dict()``.
        context_client: When truthy, the state dict is wrapped layer by layer
            with ``crypte`` using this context.

    Returns:
        Without a context: one NumPy array per state-dict entry, in
        state-dict order. With a context: the raw weight held by each layer
        object returned by ``crypte``.
    """
    state = net.state_dict()
    if not context_client:
        return [tensor.cpu().numpy() for tensor in state.values()]
    wrapped_layers = crypte(state, context_client)
    return [wrapped.get_weight() for wrapped in wrapped_layers]


def set_parameters(net, parameters: List[np.ndarray], context_client=None):
    """Load a list of parameters into `net`, decrypting encrypted layers.

    Args:
        net: Model whose state dict is replaced (``strict=True``).
        parameters: One entry per state-dict key, in state-dict order.
        context_client: When truthy, its secret key is used to decrypt
            entries that ``deserialized_layer`` turns into ``CryptedLayer``.
    """
    reference_state = net.state_dict()  # supplies key order and target shapes
    named_params = zip(reference_state.keys(), parameters)
    new_state_dict = OrderedDict()

    if not context_client:
        for key, value in named_params:
            new_state_dict[key] = torch.Tensor(value)
    else:
        secret_key = context_client.secret_key()
        layers = {
            key: deserialized_layer(key, value, context_client)
            for key, value in named_params
        }
        for key, layer in layers.items():
            if not isinstance(layer, CryptedLayer):
                # Plain layers already carry usable weights.
                new_state_dict[key] = torch.Tensor(layer.get_weight())
                continue
            # Reshape the decrypted values to the shape the model expects.
            target_shape = reference_state[key].shape
            plain_values = np.array(layer.decrypt(secret_key)).reshape(target_shape)
            new_state_dict[key] = torch.Tensor(plain_values)

    net.load_state_dict(new_state_dict, strict=True)
    print("Updated model parameters")

# Security-related classes and functions


In [None]:
class Layer:
    """A named, unencrypted parameter tensor with element-wise arithmetic.

    The arithmetic operators accept either another ``Layer`` (its weights are
    used) or a raw operand, and return a new ``Layer`` carrying the left
    operand's name.
    """

    def __init__(self, name_layer, weight):
        self.name = name_layer
        self.weight_array = weight

    @staticmethod
    def _unwrap(other):
        """Return the raw weights of a Layer operand, or the operand itself."""
        return other.get_weight() if isinstance(other, Layer) else other

    def get_name(self):
        return self.name

    def get_weight(self):
        return self.weight_array

    def __add__(self, other):
        return Layer(self.name, self.weight_array + self._unwrap(other))

    def __sub__(self, other):
        return Layer(self.name, self.weight_array - self._unwrap(other))

    def __mul__(self, other):
        return Layer(self.name, self.weight_array * self._unwrap(other))

    def __truediv__(self, other):
        # Division is expressed as multiplication by the reciprocal.
        return Layer(self.name, self.weight_array * (1 / self._unwrap(other)))

    def __len__(self):
        """Return the total number of elements (product of the shape)."""
        count = 1
        for dim in self.weight_array.shape:
            count *= dim
        return count

    def shape(self):
        return self.weight_array.shape

    def sum(self, axis=0):
        """Return a new layer named ``sum_<name>`` summed along `axis`."""
        return Layer(f"sum_{self.name}", self.weight_array.sum(axis=axis))

    def mean(self, axis=0):
        """Return a new layer named ``mean_<name>`` averaged along `axis`."""
        weights = self.weight_array.sum(axis=axis) * (1 / self.weight_array.shape[axis])
        # Previously labelled "sum_<name>", which mislabelled the result.
        return Layer(f"mean_{self.name}", weights)

    def decrypt(self, sk=None):
        """Return the weights as nested lists; `sk` is unused for plain layers."""
        return self.weight_array.tolist()

    def serialize(self):
        """Return ``{name: weights}``."""
        return {self.name: self.weight_array}


class CryptedLayer(Layer):
    """A named parameter tensor held as a TenSEAL CKKS ciphertext.

    Args:
        name_layer: Name of the layer.
        weight: Either an existing ``CKKSTensor`` / serialized ``bytes``
            (stored as-is) or a torch tensor, which is encrypted with
            `contexte`.
        contexte: TenSEAL context used for encryption; only needed when
            `weight` is not already encrypted.
    """

    def __init__(self, name_layer, weight, contexte=None):
        super().__init__(name_layer, weight)
        if not isinstance(weight, (ts.tensors.CKKSTensor, bytes)):
            # Plain torch tensor: move it to host memory and encrypt it.
            self.weight_array = ts.ckks_tensor(contexte, weight.cpu().detach().numpy())

    @staticmethod
    def _operand(other):
        """Return the raw weights of any Layer operand, or the operand itself.

        Plain ``Layer`` operands are unwrapped too, so the underlying tensor
        operation never receives a wrapper object.
        """
        return other.get_weight() if isinstance(other, Layer) else other

    def __add__(self, other):
        return CryptedLayer(self.name, self.weight_array + self._operand(other))

    def __sub__(self, other):
        return CryptedLayer(self.name, self.weight_array - self._operand(other))

    def __mul__(self, other):
        return CryptedLayer(self.name, self.weight_array * self._operand(other))

    def __truediv__(self, other):
        """Divide by a plaintext operand via multiplication by its reciprocal.

        Raises:
            TypeError: If the reciprocal of the operand cannot be computed
                (for example when dividing by an encrypted value).
        """
        divisor = self._operand(other)
        try:
            reciprocal = 1 / divisor
        except TypeError as err:
            # The old fallback built a layer from an empty list, which then
            # failed in __init__ with an unrelated AttributeError.
            raise TypeError(
                "division operator not supported by SEAL for this operand; "
                "divide by a plaintext scalar or array instead"
            ) from err
        return CryptedLayer(self.name, self.weight_array * reciprocal)

    def shape(self):
        return self.weight_array.shape

    def sum(self, axis=0):
        """Return a new encrypted layer named ``sum_<name>`` summed along `axis`."""
        return CryptedLayer(f"sum_{self.name}", self.weight_array.sum(axis=axis))

    def mean(self, axis=0):
        """Return a new encrypted layer named ``mean_<name>`` averaged along `axis`."""
        weights = self.weight_array.sum(axis=axis) * (1 / self.weight_array.shape[axis])
        # Previously labelled "sum_<name>", which mislabelled the result.
        return CryptedLayer(f"mean_{self.name}", weights)

    def decrypt(self, sk=None):
        """Decrypt to nested lists, using `sk` when one is supplied."""
        if sk is not None:
            return self.weight_array.decrypt(sk).tolist()
        return self.weight_array.decrypt().tolist()

    def serialize(self):
        """Return ``{name: serialized ciphertext}``."""
        return {self.name: self.weight_array.serialize()}


def context():
    """Create a fresh TenSEAL CKKS context with Galois keys and a 2**40 scale.

    Returns:
        The configured TenSEAL context.
    """
    ckks_params = {
        "poly_modulus_degree": 8192,
        "coeff_mod_bit_sizes": [60, 40, 40, 60],
    }
    ckks_context = ts.context(ts.SCHEME_TYPE.CKKS, **ckks_params)
    ckks_context.generate_galois_keys()
    ckks_context.global_scale = 2**40
    return ckks_context


def crypte(client_w, context_c):
    """Wrap every entry of a state dict, encrypting only the quantum weights.

    Args:
        client_w: Mapping of layer name to weight tensor.
        context_c: TenSEAL context used to encrypt the "qnn.weights" entry.

    Returns:
        A list with one ``CryptedLayer`` (for "qnn.weights") or ``Layer``
        (for everything else) per entry, in mapping order.
    """
    return [
        CryptedLayer(layer_name, weights, context_c)
        if layer_name == "qnn.weights"  # only the quantum layer is encrypted
        else Layer(layer_name, weights)
        for layer_name, weights in client_w.items()
    ]


def read_query(file_path):
    """Load a pickled query and split off its "contexte" entry.

    Args:
        file_path: Path of the pickle file.

    Returns:
        ``(query, contexte)`` where `query` is the unpickled mapping without
        its "contexte" key, or ``(None, None)`` when the file does not exist.
    """
    if not os.path.exists(file_path):
        print(f"File {file_path} does not exist")
        return None, None
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # use this on files produced by this project.
    with open(file_path, "rb") as handle:
        payload = pickle.load(handle)
    contexte = payload.pop("contexte")
    return payload, contexte


def write_query(file_path, client_query):
    """Pickle `client_query` into `file_path`, overwriting any existing file."""
    with open(file_path, "wb") as handle:
        pickle.dump(client_query, handle)


def deserialized_layer(name_layer, weight_array, ctx):
    """Wrap a received weight in the matching layer class.

    Args:
        name_layer: Name of the layer.
        weight_array: Serialized CKKS bytes, a ``CKKSTensor``, or plain weights.
        ctx: TenSEAL context used to rebuild a tensor from bytes.

    Returns:
        A ``CryptedLayer`` for bytes / ``CKKSTensor`` input, otherwise a
        plain ``Layer``.
    """
    if isinstance(weight_array, bytes):
        # Serialized ciphertext: rebuild the tensor before wrapping it.
        rebuilt = ts.ckks_tensor_from(ctx, weight_array)
        return CryptedLayer(name_layer, rebuilt, ctx)
    if isinstance(weight_array, ts.tensors.CKKSTensor):
        return CryptedLayer(name_layer, weight_array, ctx)
    return Layer(name_layer, weight_array)


def serialize_ndarray(ndarray):
    """Serialize one parameter to bytes.

    Args:
        ndarray: A ``CKKSTensor``, a torch tensor, or an array-like accepted
            by ``np.save``.

    Returns:
        The tensor's own serialization for a ``CKKSTensor``; otherwise the
        ``np.save`` byte representation (pickling disabled).
    """
    if isinstance(ndarray, ts.tensors.CKKSTensor):
        return ndarray.serialize()
    if isinstance(ndarray, torch.Tensor):
        # Bring the tensor to host memory and drop autograd tracking first.
        ndarray = ndarray.cpu().detach().numpy()
    buffer = BytesIO()
    np.save(buffer, ndarray, allow_pickle=False)
    return buffer.getvalue()


def deserialize_ndarray(tensor, context):
    """Rebuild one parameter from its serialized bytes.

    The payload is first tried as a CKKS tensor; if that fails it is read as
    a NumPy ``.npy`` buffer (pickling disabled).

    Args:
        tensor: Serialized bytes of one parameter.
        context: TenSEAL context used to rebuild an encrypted tensor.

    Returns:
        A ``CKKSTensor`` or a NumPy array.
    """
    try:
        return ts.ckks_tensor_from(context, tensor)
    except Exception:
        # `Exception` rather than a bare except, so KeyboardInterrupt and
        # SystemExit are no longer swallowed by the fallback.
        return np.load(BytesIO(tensor), allow_pickle=False)


def serialize_parameters(parameters):
    """Serialize every parameter with `serialize_ndarray`, keeping the order."""
    serialized = []
    for param in parameters:
        serialized.append(serialize_ndarray(param))
    return serialized


def deserialize_parameters(serialized_params, context):
    """Deserialize every entry with `deserialize_ndarray`, keeping the order."""
    restored = []
    for blob in serialized_params:
        restored.append(deserialize_ndarray(blob, context))
    return restored


def aggregate_serialized(results, context):
    """Average client updates, weighted by their number of training examples.

    Args:
        results: List of ``(serialized_params, num_examples)`` pairs, one per
            client; `serialized_params` is a list of serialized layers.
        context: TenSEAL context used to rebuild encrypted layers.

    Returns:
        The aggregated parameters, serialized again with
        ``serialize_parameters`` so they can be handed back to the clients.

    Raises:
        ValueError: If `results` is empty or the total example count is zero.
    """
    if not results:
        raise ValueError("aggregate_serialized() needs at least one client update")
    num_examples_total = sum(num_examples for _, num_examples in results)
    if not num_examples_total:
        raise ValueError("aggregate_serialized() needs a non-zero example count")

    weights_results = [
        (deserialize_parameters(serialized_params, context), num_examples)
        for serialized_params, num_examples in results
    ]
    example_counts = [num_examples for _, num_examples in weights_results]

    aggregated_params = []
    for layer_idx in range(len(weights_results[0][0])):
        layer_updates = [weights[layer_idx] for weights, _ in weights_results]
        weighted = [
            layer * num_examples
            for layer, num_examples in zip(layer_updates, example_counts)
        ]
        # Accumulate from the first term so no `0 + tensor` is ever evaluated.
        weighted_sum = weighted[0]
        for term in weighted[1:]:
            weighted_sum = weighted_sum + term

        if isinstance(layer_updates[0], ts.tensors.CKKSTensor):
            # Encrypted layers are scaled by the plaintext reciprocal
            # instead of being divided.
            aggregated_layer = weighted_sum * (1 / num_examples_total)
        else:
            aggregated_layer = weighted_sum / num_examples_total
        aggregated_params.append(aggregated_layer)

    # The function used to fall off the end and return None, which broke the
    # caller's subsequent deserialize step.
    return serialize_parameters(aggregated_params)

# Data setup


In [None]:
# Per-dataset keyword arguments for transforms.Normalize (channel mean / std).
NORMALIZE_DICT = {
    "cifar": {"mean": (0.5, 0.5, 0.5), "std": (0.5, 0.5, 0.5)},
    "MRI": {"mean": (0.485, 0.456, 0.406), "std": (0.229, 0.224, 0.225)},
}


def split_data_client(dataset, num_clients, seed):
    """Randomly partition `dataset` into `num_clients` reproducible shards.

    Args:
        dataset: Sized, indexable dataset.
        num_clients: Number of shards to produce.
        seed: Seed of the generator driving the split.

    Returns:
        The list of subsets from ``random_split``; every shard has
        ``len(dataset) // num_clients`` items except the last, which also
        takes the remainder.
    """
    total = len(dataset)
    share = total // num_clients
    lengths = [share for _ in range(num_clients - 1)]
    # The last client absorbs whatever the integer division left over.
    lengths.append(total - sum(lengths))
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, lengths, generator)


def load_datasets(
    num_clients: int,
    batch_size: int,
    resize: int,
    seed: int,
    num_workers: int,
    splitter=10,
    dataset="cifar",
    data_path="./data/",
    data_path_val="",
):
    """Build per-client train/validation loaders and one shared test loader.

    Args:
        num_clients: Number of client shards.
        batch_size: Batch size of every loader.
        resize: Side length for ``transforms.Resize``; only applied when the
            dataset is not "cifar" and the value is not None.
        seed: Seed for the client split and the train/validation split.
        num_workers: Worker count forwarded to every ``DataLoader``.
        splitter: Percentage of each client shard held out for validation
            when no separate validation folder is given.
        dataset: Key into ``NORMALIZE_DICT``; "cifar" downloads CIFAR-10,
            anything else is read with ``ImageFolder`` from
            ``<data_path><dataset>/Training`` and ``.../Testing``.
        data_path: Root path of the data (concatenated with `dataset`).
        data_path_val: Optional ``ImageFolder`` root of a separate
            validation set; when falsy the validation data is carved out of
            each client's training shard.

    Returns:
        ``(trainloaders, valloaders, testloader)`` where the first two are
        lists with one loader per client.
    """
    list_transforms = [
        transforms.ToTensor(),
        transforms.Normalize(**NORMALIZE_DICT[dataset]),
    ]
    if dataset != "cifar" and resize is not None:
        list_transforms = [transforms.Resize((resize, resize))] + list_transforms
    transformer = transforms.Compose(list_transforms)

    if dataset == "cifar":
        trainset = datasets.CIFAR10(
            data_path + dataset, train=True, download=True, transform=transformer
        )
        testset = datasets.CIFAR10(
            data_path + dataset, train=False, download=True, transform=transformer
        )
    else:
        trainset = datasets.ImageFolder(
            data_path + dataset + "/Training", transform=transformer
        )
        testset = datasets.ImageFolder(
            data_path + dataset + "/Testing", transform=transformer
        )

    # `num_workers` used to be accepted but silently ignored; it is now
    # forwarded to every loader.
    loader_kwargs = {"batch_size": batch_size, "num_workers": num_workers}

    datasets_train = split_data_client(trainset, num_clients, seed)
    datasets_val = None
    if data_path_val:
        valset = datasets.ImageFolder(data_path_val, transform=transformer)
        datasets_val = split_data_client(valset, num_clients, seed)

    trainloaders = []
    valloaders = []
    for i in range(num_clients):
        if datasets_val is not None:
            ds_train, ds_val = datasets_train[i], datasets_val[i]
        else:
            # Hold out `splitter` percent of this client's shard for validation.
            len_val = int(len(datasets_train[i]) * splitter / 100)
            len_train = len(datasets_train[i]) - len_val
            ds_train, ds_val = random_split(
                datasets_train[i],
                [len_train, len_val],
                torch.Generator().manual_seed(seed),
            )
        trainloaders.append(DataLoader(ds_train, shuffle=True, **loader_kwargs))
        valloaders.append(DataLoader(ds_val, **loader_kwargs))

    testloader = DataLoader(testset, **loader_kwargs)
    return trainloaders, valloaders, testloader

# Training and testing functions


In [None]:
def test(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    loss_fn: Union[torch.nn.Module, Tuple],
    device: torch.device,
):
    """Evaluate `model` on `dataloader`.

    Args:
        model: Model to evaluate; switched to eval mode.
        dataloader: Iterable of ``(images, labels)`` batches supporting ``len``.
        loss_fn: Callable ``loss_fn(output, labels)`` returning a scalar tensor.
        device: Device the batches are moved to.

    Returns:
        ``(loss, accuracy, y_pred, y_true, y_proba)`` where `loss` is the mean
        of the per-batch losses, `accuracy` is the percentage of correctly
        classified samples, `y_pred` / `y_true` are flat lists of labels and
        `y_proba` is an array of softmax scores with one row per sample.

    Raises:
        ZeroDivisionError: If `dataloader` yields no batch.
    """
    model.eval()
    total_loss = 0.0
    n_correct = 0
    y_pred = []
    y_true = []
    y_proba = []
    with torch.inference_mode():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            output = model(images)
            y_proba.extend(torch.softmax(output, dim=1).detach().cpu().numpy())
            total_loss += loss_fn(output, labels).item()
            labels = labels.data.cpu().numpy()
            y_true.extend(labels)
            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
            y_pred.extend(preds)
            n_correct += int((preds == labels).sum())
    y_proba = np.array(y_proba)
    test_loss = total_loss / len(dataloader)
    # Count over samples rather than averaging per-batch accuracies, which
    # over-weighted a smaller final batch.
    test_acc = n_correct / len(y_true)
    return test_loss, test_acc * 100, y_pred, y_true, y_proba


# The helper's name matches pytest's default "test*" pattern; opt out so it is
# never collected as a test case if this code is imported by a test module.
test.__test__ = False


def train_step(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    loss_fn: Union[torch.nn.Module, Tuple],
    optimizer: torch.optim.Optimizer,
    device: torch.device,
) -> Tuple[float, float]:
    """Run one training epoch over `dataloader`.

    Args:
        model: Model to train; switched to train mode.
        dataloader: Iterable of ``(images, labels)`` batches supporting ``len``.
        loss_fn: Callable ``loss_fn(output, labels)`` returning a scalar tensor.
        optimizer: Optimizer stepping the model parameters.
        device: Device the batches are moved to.

    Returns:
        ``(loss, accuracy)``: the mean of the per-batch losses and the
        percentage of correctly classified samples, both measured with the
        weights in effect when each batch was processed.

    Raises:
        ZeroDivisionError: If `dataloader` yields no batch.
    """
    model.train()
    total_loss = 0.0
    n_correct = 0
    n_seen = 0
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        output = model(images)
        loss = loss_fn(output, labels)
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        # argmax of the logits equals argmax of their softmax, so the extra
        # softmax pass is unnecessary.
        n_correct += (torch.argmax(output, dim=1) == labels).sum().item()
        n_seen += len(output)
    train_loss = total_loss / len(dataloader)
    # Count over samples rather than averaging per-batch accuracies, which
    # over-weighted a smaller final batch.
    train_acc = n_correct / n_seen
    return train_loss, train_acc * 100


def train(
    model: torch.nn.Module,
    train_dataloader: torch.utils.data.DataLoader,
    test_dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    loss_fn: Union[torch.nn.Module, Tuple],
    epochs: int,
    device: torch.device,
) -> Dict[str, List]:
    """Train for `epochs` epochs, evaluating after each one.

    Args:
        model: Model to train.
        train_dataloader: Batches used by ``train_step``.
        test_dataloader: Batches used by ``test`` after every epoch.
        optimizer: Optimizer stepping the model parameters.
        loss_fn: Loss callable shared by training and evaluation.
        epochs: Number of epochs to run.
        device: Device the batches are moved to.

    Returns:
        Dict with the keys "train_loss", "train_acc", "val_loss" and
        "val_acc", each a list holding one value per epoch.
    """
    metric_names = ("train_loss", "train_acc", "val_loss", "val_acc")
    results = {name: [] for name in metric_names}
    for epoch in range(epochs):
        train_loss, train_acc = train_step(
            model, train_dataloader, loss_fn, optimizer, device
        )
        # Only the first two values returned by test() are needed here.
        val_loss, val_acc = test(model, test_dataloader, loss_fn, device)[:2]
        print(
            f"\tTrain Epoch: {epoch + 1} \tTrain_loss: {train_loss:.4f} | Train_acc: {train_acc:.4f} % | "
            f"Validation_loss: {val_loss:.4f} | Validation_acc: {val_acc:.4f} %"
        )
        epoch_values = (train_loss, train_acc, val_loss, val_acc)
        for name, value in zip(metric_names, epoch_values):
            results[name].append(value)
    return results

# Main experiment setup


In [None]:
# Experiment configuration. Several names below are not read anywhere in the
# visible code (he, yaml_path, matrix_path, roc_path, min_*_clients, frac_*);
# NOTE(review): confirm whether later cells still need them.
he = True
data_path = "data/"  # root folder handed to load_datasets
dataset = "cifar"  # must be a key of NORMALIZE_DICT
yaml_path = "./results/FL/results.yml"
seed = 0  # seeds the client split and the train/validation split
num_workers = 0
max_epochs = 10  # default number of local epochs per client and round
batch_size = 32
splitter = 10  # percentage of each client shard held out for validation
device = "gpu"  # any value other than "cpu" lets choice_device pick an accelerator
number_clients = 10
save_results = "results/FL/"  # folder for the per-client training curves
matrix_path = "confusion_matrix.png"
roc_path = "roc.png"
model_save = "cifar_FHE.pt"  # checkpoint file written after the simulation
min_fit_clients = 10
min_avail_clients = 10
min_eval_clients = 10
rounds = 20  # number of federated rounds
frac_fit = 1.0
frac_eval = 0.5
lr = 1e-3  # learning rate of each client's Adam optimizer
path_public_key = "server_key.pkl"  # pickle holding the server-side context

# Resolve the device actually available and the class names of the dataset.
DEVICE = torch.device(choice_device(device))
CLASSES = classes_string(dataset)

# Homomorphic encryption setup: the client context (with secret key) lives in
# secret.pkl, the context given to the server lives in `path_public_key`.
secret_path = "secret.pkl"
public_path = path_public_key
if os.path.exists(secret_path):
    # NOTE(review): pickle.load executes arbitrary code from the file; this is
    # only acceptable because secret.pkl is generated locally by this script.
    with open(secret_path, "rb") as f:
        query = pickle.load(f)
    context_client = ts.context_from(query["contexte"])
    # The server file can be missing even though the secret exists (deleted,
    # or never copied over); rebuild it from the client context in that case.
    refresh_public = not os.path.exists(public_path)
else:
    context_client = context()
    with open(secret_path, "wb") as f:
        pickle.dump({"contexte": context_client.serialize(save_secret_key=True)}, f)
    # A freshly generated key pair always replaces any stale server file.
    refresh_public = True
if refresh_public:
    # Serialized without save_secret_key=True — presumably this omits the
    # secret key; confirm against the TenSEAL documentation.
    with open(public_path, "wb") as f:
        pickle.dump({"contexte": context_client.serialize()}, f)
context_server = ts.context_from(read_query(public_path)[1])

# Load datasets
trainloaders, valloaders, testloader = load_datasets(
    num_clients=number_clients,
    batch_size=batch_size,
    # `resize` is a side length, not a flag: True would become
    # Resize((True, True)) for any non-CIFAR dataset. 32 matches the input
    # size the Net classifier head (256 * 4 * 4 features) is built for;
    # load_datasets ignores the value for "cifar".
    resize=32,
    seed=seed,
    num_workers=num_workers,
    splitter=splitter,
    dataset=dataset,
    data_path=data_path,
    data_path_val=None,
)

# Define the model architecture
n_qubits = 6  # number of wires; also the width of the layer feeding the circuit
n_layers = 6  # number of BasicEntanglerLayers repetitions
# Shape of the trainable circuit weights handed to qml.qnn.TorchLayer.
weight_shapes = {"weights": (n_layers, n_qubits)}

# Simulator device on which the quantum circuit is evaluated.
dev = qml.device("default.qubit", wires=n_qubits)


@qml.qnode(dev, interface="torch")
def quantum_net(inputs, weights):
    """Variational circuit: angle-embed `inputs`, entangle, measure Pauli-Z.

    Args:
        inputs: Features embedded on the ``n_qubits`` wires.
        weights: Trainable parameters of the entangling layers.

    Returns:
        One Pauli-Z expectation value per wire.
    """
    wires = range(n_qubits)
    qml.AngleEmbedding(inputs, wires=wires)
    qml.BasicEntanglerLayers(weights, wires=wires)
    return [qml.expval(qml.PauliZ(wire)) for wire in wires]


class Net(nn.Module):
    """CNN feature extractor followed by a quantum layer and a linear head.

    The classical stack reduces the image to ``n_qubits`` features, which are
    passed through the ``quantum_net`` circuit and then mapped to
    ``num_classes`` logits.
    """

    def __init__(self, num_classes=10) -> None:
        super().__init__()

        def conv3x3(in_channels, out_channels):
            # 3x3 convolution that preserves the spatial size.
            return nn.Conv2d(
                in_channels, out_channels, kernel_size=3, stride=1, padding=1
            )

        # The size comments below assume 32 x 32 input images.
        feature_layers = [
            conv3x3(3, 32),
            nn.ReLU(),
            conv3x3(32, 64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # output: 64 x 16 x 16
            conv3x3(64, 128),
            nn.ReLU(),
            conv3x3(128, 128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # output: 128 x 8 x 8
            conv3x3(128, 256),
            nn.ReLU(),
            conv3x3(256, 256),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # output: 256 x 4 x 4
        ]
        head_layers = [
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, n_qubits),
        ]
        # A single Sequential keeps the layer indices (and therefore the
        # state-dict keys) in one flat numbering.
        self.network = nn.Sequential(*feature_layers, *head_layers)
        self.qnn = qml.qnn.TorchLayer(quantum_net, weight_shapes)
        self.fc4 = nn.Linear(n_qubits, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the class logits for a batch of images."""
        features = self.network(x)
        quantum_out = self.qnn(features)
        return self.fc4(quantum_out)

# Main Experiment


In [None]:
# Build the global model and serialize its initial parameters; passing a
# context makes get_parameters2 route the state dict through `crypte`.
global_model = Net(num_classes=len(CLASSES)).to(DEVICE)
initial_params = get_parameters2(global_model, context_server)
global_serialized_params = serialize_parameters(initial_params)


# Local training routine executed for each simulated client
def client_train(cid, serialized_global_params, local_epochs=max_epochs, lr=lr):
    """Train a fresh local model from the global parameters for one client.

    Args:
        cid: Client id; converted with ``int`` to index the loader lists.
        serialized_global_params: Serialized global parameters to start from.
        local_epochs: Number of local epochs.
        lr: Learning rate of the Adam optimizer.

    Returns:
        ``(serialized_updated_params, num_examples)``: the serialized local
        parameters after training and the size of the client's training set.
    """
    client_index = int(cid)
    trainloader = trainloaders[client_index]
    valloader = valloaders[client_index]

    # Start from the current global parameters.
    local_model = Net(num_classes=len(CLASSES)).to(DEVICE)
    global_params = deserialize_parameters(serialized_global_params, context_client)
    set_parameters(local_model, global_params, context_client)

    optimizer = torch.optim.Adam(local_model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()
    results = train(
        local_model,
        trainloader,
        valloader,
        optimizer,
        criterion,
        epochs=local_epochs,
        device=DEVICE,
    )
    if save_results:
        save_graphs(save_results, local_epochs, results, f"_Client {cid}")

    # Hand back the locally trained parameters in serialized form.
    local_params = get_parameters2(local_model, context_client)
    return serialize_parameters(local_params), len(trainloader.dataset)


# Federated learning simulation: every round trains all clients one after the
# other, aggregates their updates, then evaluates the global model on the test
# set. NOTE(review): no syft API is called in this loop.
print(f"Training on {DEVICE}")
start_simulation = time.time()

# Per-round test metrics of the global model
test_accuracies = []
test_losses = []

for round_num in range(rounds):
    client_updates = []
    # Train each client from the current global parameters
    for cid in range(number_clients):
        print(f"[Client {cid}, round {round_num + 1}] training")
        serialized_updated_params, num_examples = client_train(
            str(cid), global_serialized_params
        )
        client_updates.append((serialized_updated_params, num_examples))

    # Aggregate the updates with the server context.
    # NOTE(review): this relies on aggregate_serialized returning the
    # serialized parameter list; verify that it does.
    global_serialized_params = aggregate_serialized(client_updates, context_server)

    # Load the aggregate into the global model (using the client context)
    # and evaluate it on the test set
    aggregated_params = deserialize_parameters(global_serialized_params, context_client)
    set_parameters(global_model, aggregated_params, context_client)
    test_loss, test_acc, _, _, _ = test(
        global_model, testloader, torch.nn.CrossEntropyLoss(), DEVICE
    )
    test_accuracies.append(test_acc)
    test_losses.append(test_loss)
    print(
        f"Round {round_num + 1} Test Accuracy: {test_acc:.2f}%, Test Loss: {test_loss:.4f}"
    )

    # global_serialized_params is passed to the clients as-is in the next
    # round; nothing is re-serialized here.
    print(f"Round {round_num + 1} completed")

print(
    f"Federated learning completed. Simulation Time = {time.time() - start_simulation} seconds"
)

# Save the final model


In [None]:
if save_results:
    # Make sure the results folder exists.
    os.makedirs(save_results, exist_ok=True)
    # NOTE(review): the checkpoint goes to `model_save` relative to the working
    # directory, not inside `save_results` — confirm this is intended.
    torch.save({"model_state_dict": global_model.state_dict()}, model_save)