In [1]:
import torch
import torch.nn as nn

from federated_learning.model import LeNet_Small_Quant
from federated_learning.client import Worker
import federated_learning

import numpy as np

import warnings
warnings.simplefilter("ignore")

In [2]:
(X_train, y_train), (X_test, y_test) = federated_learning.load_cifar10(num_users=2, n_class=10, n_samples=100, rate_unbalance=1.0)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
worker1 = Worker(index=1, X_train=X_train[0], y_train=y_train[0], X_test=X_test[:1000], y_test=y_test[:1000], model=LeNet_Small_Quant())
worker2 = Worker(index=2, X_train=X_train[1], y_train=y_train[1], X_test=X_test[:1000], y_test=y_test[:1000], model=LeNet_Small_Quant())

In [4]:
worker1.model.set_optimizer(torch.optim.Adam(worker1.model.parameters(), lr=0.001))
worker1.train_step_dp(model=worker1.model, K=30, B=64, eps=1.0, delta=1.0)
worker1.evaluate(worker1.model, worker1.X_test, worker1.y_test, B=64)

Worker 1 local epoch 10: loss 1.8428709134459496
Worker 1 local epoch 20: loss 1.7087349891662598
Worker 1 local epoch 30: loss 1.8336205780506134


(4.518858656287193, 0.24960937537252903)

In [5]:
worker2.model.set_optimizer(torch.optim.Adam(worker2.model.parameters(), lr=0.001))
worker2.train_step_dp(model=worker2.model, K=30, B=64, eps=1.0, delta=1.0)
worker2.evaluate(worker2.model, worker2.X_test, worker2.y_test, B=64)

Worker 2 local epoch 10: loss 1.7892474830150604
Worker 2 local epoch 20: loss 1.8785505145788193
Worker 2 local epoch 30: loss 1.969981499016285


(8.07858082652092, 0.16835937555879354)

In [6]:
global_model = LeNet_Small_Quant()
agg = federated_learning.FedAvg(global_model=global_model, beta=0.0, lr=0.1)
local_params = [w.model.get_params() for w in [worker1, worker2]]
global_model = agg.aggregate(local_params)

In [8]:
worker1.evaluate(global_model, worker1.X_test, worker1.y_test, B=64)

(2.5865061581134796, 0.10234375018626451)