In [1]:
import torch
import torch.nn as nn

from federated_learning.model import LeNet_Small_Quant
from federated_learning.client import Worker
import federated_learning
from network import ZKFLChain

import numpy as np

import warnings
warnings.simplefilter("ignore")

In [2]:
net = ZKFLChain(num_clients=3, global_rounds=3, local_rounds=3, frac_malicous=0.0, dataset='cifar10', model='lenet')

Files already downloaded and verified
Files already downloaded and verified


In [3]:
workers = []
for i in range(net.num_clients):
    worker = Worker(index=i+1,
                        X_train=net.X_train[i],
                        y_train=net.y_train[i],
                        X_test=None,
                        y_test=None,
                        model=LeNet_Small_Quant(),
                        malicious=False)
    workers.append(worker)

In [4]:
local_params = []
for worker in workers:
    worker.set_optimizer(optimizer=torch.optim.Adam(worker.model.parameters(), lr=0.001))   
    worker.train_step_dp(model=worker.model, K=3, B=128, norm=1.2, eps=50.0, delta=1e-5)
    local_params.append(worker.get_params())    

In [5]:
X_test, y_test = net.X_test, net.y_test
for w in workers:
    _, acc = w.evaluate(w.model, x=X_test, y=y_test, B=128)
    print(f"Worker {w.index} accuracy: {acc}")

Worker 1 accuracy: 0.3169501582278481
Worker 2 accuracy: 0.26582278481012656
Worker 3 accuracy: 0.234375


In [8]:
workers2 = []
local_params2 = []
for i in range(net.num_clients):
    worker = Worker(index=i+1,
                        X_train=net.X_train[i],
                        y_train=net.y_train[i],
                        X_test=None,
                        y_test=None,
                        model=LeNet_Small_Quant(),
                        malicious=False)
    workers2.append(worker)
for w in workers2:
    w.set_optimizer(optimizer=torch.optim.Adam(w.model.parameters(), lr=0.001))
    w.train_step_dp(model=w.model, K=5, B=128, eps=1.2, delta=1.0)
    local_params2.append(w.get_params())

In [9]:
for w in workers2:
    _, acc = w.evaluate(w.model, x=X_test, y=y_test, B=128)
    print(f"Worker {w.index} accuracy: {acc}")

Worker 1 accuracy: 0.19284018987341772
Worker 2 accuracy: 0.25860363924050633
Worker 3 accuracy: 0.23833069620253164
