In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from federated_learning.model.LeNet import LeNet_Small, LeNet_Small_Quant
import matplotlib.pyplot as plt
import os
from collections import defaultdict
from PIL import Image
from torchvision.utils import save_image
from federated_learning.attacks import ModelInversion
import federated_learning
import blockchain
from network import POFLNetWork
from copy import deepcopy

In [2]:
def plot(img):
    """Show a CIFAR-10 image with matplotlib.

    The input is divided by 255 and its axes are reordered with
    ``(1, 2, 0)`` so that the leading axis ends up last before the
    result is handed to ``plt.imshow``.

    Args:
        img: Three-axis image data. Presumably channel-first with values
            on a 0-255 scale -- confirm against the caller.

    Returns:
        None. The figure is rendered via ``plt.show()``.
    """
    scaled = img / 255
    # Move axis 0 to the end so imshow receives the reordered layout.
    plt.imshow(np.transpose(scaled, (1, 2, 0)))
    plt.show()

In [18]:
# Build the simulated network: 5 clients, 10 global rounds with 5 local
# rounds each, a malicious fraction of 0.2, on the 'cifar10' dataset with
# the 'lenet' model option.
network = POFLNetWork(
    num_clients=5,
    global_rounds=10,
    local_rounds=5,
    frac_malicous=0.2,  # keyword spelled exactly as the constructor is called in this notebook
    dataset='cifar10',
    model='lenet',
)
# Initialise the network and register the clients. The recorded output for
# this cell shows the dataset being verified and "Block 1 stored.".
# NOTE(review): clear_path=True presumably wipes previously stored state -- confirm.
network.init_network(clear_path=True)
network.add_clients()

Files already downloaded and verified
Files already downloaded and verified
Block 1 stored.


In [19]:
# Take the most recent block on the chain and rebuild a model object from
# the global parameters stored in it.
latest_block = network.blockchain.last_block
latest_model = federated_learning.numpy_dict_to_model(
    numpy_model=latest_block.global_params,
    model_struct=network.model,
)
# Keep a handle on the network's own model attribute as well.
network_model = network.model

In [20]:
# Run the configured number of global rounds. Each round performs the local
# training update, then the evaluation update, then checks the blockchain.
for epoch in range(1, network.global_rounds + 1):
    print(f"Global round {epoch}")
    network.local_train_update(num_epochs=network.local_rounds)
    network.eval_update()
    if network.blockchain.valid_chain():
        continue
    # The chain failed its validity check: report it and stop training.
    print("Invalid chain")
    break

Global round 1
Client 1 Epoch 1/5 Loss: 2.1440975666046143
Client 4 Epoch 1/5 Loss: 2.2867560386657715
Client 5 Epoch 1/5 Loss: 2.2628633975982666
Client 3 Epoch 1/5 Loss: 1.7075458765029907
Client 2 Epoch 1/5 Loss: 2.2111964225769043
local updates sent
Sart verification
transaction from worker 4 verified by worker 1
transaction from worker 4 verified by worker 4
transaction from worker 4 verified by worker 5
transaction from worker 4 verified by worker 3
transaction from worker 4 verified by worker 2
transaction from worker 4 verified by all workers
worker 4 is the leader
Local models aggregated.
Global model accuracy: 0.15
Block 2 stored.
2
Global round 2
Client 1 Epoch 1/5 Loss: 2.214967966079712
Client 4 Epoch 1/5 Loss: 2.321989059448242
Client 5 Epoch 1/5 Loss: 2.278322219848633
Client 3 Epoch 1/5 Loss: 2.0864694118499756
Client 2 Epoch 1/5 Loss: 2.332695722579956
local updates sent
Sart verification
transaction from worker 1 verified by worker 1
transaction from worker 1 verified

In [16]:
# network.eval_update()

local updates sent
Sart verification
transaction from worker 5 verified by worker 1
transaction from worker 5 verified by worker 5
transaction from worker 5 verified by worker 2
transaction from worker 5 verified by worker 3
transaction from worker 5 verified by worker 4
transaction from worker 5 verified by all workers
worker 5 is the leader
Local models aggregated.
Global model accuracy: 0.28
Block 7 stored.


In [7]:
def models_have_same_params(model1, model2):
    """Report whether two models carry identical state dicts.

    Args:
        model1: First model; must expose ``state_dict()``.
        model2: Second model; must expose ``state_dict()``.

    Returns:
        bool: ``True`` only when both state dicts have the same keys, in the
        same order, and every pair of corresponding tensors satisfies
        ``torch.equal``. ``False`` otherwise.
    """
    state_a = model1.state_dict()
    state_b = model2.state_dict()

    # Different key sets can never match, so bail out before comparing tensors.
    if state_a.keys() != state_b.keys():
        return False

    # Walk both dicts in lockstep: names must line up and tensors must be equal.
    return all(
        name_a == name_b and torch.equal(tensor_a, tensor_b)
        for (name_a, tensor_a), (name_b, tensor_b) in zip(state_a.items(), state_b.items())
    )

In [22]:
# Compare the models currently held by the first two workers; the cell's
# value is the boolean result of the comparison.
model1, model2 = network.workers[0].model, network.workers[1].model
models_have_same_params(model1, model2)

False

In [28]:
# Rebuild model objects from the global parameters stored in the last three
# blocks of the chain (newest first).
# NOTE(review): all three calls pass the same network.model as model_struct.
# If numpy_dict_to_model loads weights into that object in place, the three
# names would alias one module -- confirm against its implementation.
global_model = federated_learning.numpy_dict_to_model(
    numpy_model=network.blockchain.last_block.global_params,
    model_struct=network.model,
)
global_model2 = federated_learning.numpy_dict_to_model(
    numpy_model=network.blockchain.chain[-2].global_params,
    model_struct=network.model,
)
global_model3 = federated_learning.numpy_dict_to_model(
    numpy_model=network.blockchain.chain[-3].global_params,
    model_struct=network.model,
)

In [30]:
# Display the full state dict of the model rebuilt from the second-to-last block.
# NOTE(review): this prints every tensor in full; a per-parameter summary
# (e.g. shapes or norms) would keep the notebook output readable.
global_model2.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[1.3254e+01, 1.3300e+01, 1.3443e+01, 1.3421e+01, 1.3318e+01],
                        [1.3368e+01, 1.3276e+01, 1.3286e+01, 1.3332e+01, 1.3282e+01],
                        [1.3237e+01, 1.3190e+01, 1.3207e+01, 1.3289e+01, 1.3224e+01],
                        [1.3215e+01, 1.3221e+01, 1.3144e+01, 1.3164e+01, 1.3138e+01],
                        [1.3174e+01, 1.3148e+01, 1.3218e+01, 1.3278e+01, 1.3111e+01]],
              
                       [[1.2857e+01, 1.2765e+01, 1.2734e+01, 1.2821e+01, 1.2713e+01],
                        [1.2748e+01, 1.2806e+01, 1.2582e+01, 1.2632e+01, 1.2693e+01],
                        [1.2587e+01, 1.2600e+01, 1.2598e+01, 1.2653e+01, 1.2568e+01],
                        [1.2610e+01, 1.2553e+01, 1.2487e+01, 1.2564e+01, 1.2566e+01],
                        [1.2656e+01, 1.2547e+01, 1.2554e+01, 1.2509e+01, 1.2554e+01]],
              
                       [[1.1698e+01, 1.1700e+01, 1.1660e+01, 1.1527e+01, 1.166