# Comparision of Swarm Learning and Federated Learning

## Import Library

In [13]:
import hashlib
import json
from datetime import datetime
import time
import torch
import threading
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import random

## Define Blockchain

In [14]:
class Block:
    def __init__(self, round_num, leader_id, model_hash, previous_hash=""):
        self.round_num = round_num
        self.leader_id = leader_id
        self.model_hash = model_hash
        self.previous_hash = previous_hash
        self.timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.hash = self.calculate_hash()

    def calculate_hash(self):
        block_string = f"{self.round_num}{self.leader_id}{self.model_hash}{self.previous_hash}{self.timestamp}"
        return hashlib.sha256(block_string.encode()).hexdigest()

class Blockchain:
    def __init__(self):
        self.chain = [self.create_genesis_block()]

    def create_genesis_block(self):
        return Block(0, "Genesis", "0", "0")

    def get_latest_block(self):
        return self.chain[-1]

    def add_block(self, new_block):
        new_block.previous_hash = self.get_latest_block().hash
        new_block.hash = new_block.calculate_hash()
        self.chain.append(new_block)

    def is_chain_valid(self):
        for i in range(1, len(self.chain)):
            current = self.chain[i]
            previous = self.chain[i - 1]
            if current.hash != current.calculate_hash() or current.previous_hash != previous.hash:
                return False
        return True

    def display_chain(self, show_sample_params=False):
        for block in self.chain:
            print("\n-----------------")
            print(f"Round: {block.round_num}")
            print(f"Leader ID: {block.leader_id}")
            print(f"Model Parameter Hash: {block.model_hash}")
            print(f"Previous Hash: {block.previous_hash}")
            print(f"Timestamp: {block.timestamp}")
            print(f"Block Hash: {block.hash}")

## Create Node

In [15]:
class Node:
    def __init__(self, node_id, model, local_data, nodes, blockchain=None):
        self.node_id = node_id
        self.model = model
        self.local_data = local_data
        self.nodes = nodes
        self.is_leader = False
        self.active = True
        self.blockchain = blockchain

    def start_election(self):
        leader_node = random.choice(self.nodes)
        leader_node.become_leader()
        return leader_node

    def become_leader(self):
        print(f"Node {self.node_id} is now the leader.")
        self.is_leader = True
        for node in self.nodes:
            node.is_leader = node.node_id == self.node_id

    def gather_parameters(self):
        return train_local_model(self.local_data, self.model, epochs=1)

    def verify_and_add_block(self, round_num, model_params):
        model_params_serializable = {k: v.tolist() for k, v in model_params.items()}
        model_hash = hashlib.sha256(json.dumps(model_params_serializable, sort_keys=True).encode()).hexdigest()
        new_block = Block(round_num, self.node_id, model_hash, self.blockchain.get_latest_block().hash)
        self.blockchain.add_block(new_block)

## CNN Model

In [16]:
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(64 * 24 * 24, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [17]:
def train_local_model(local_data, model, epochs=1):
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    for epoch in range(epochs):
        for data, target in local_data:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
    return model.state_dict()

In [18]:
def evaluate_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    accuracy = 100 * correct / total
    return accuracy

In [19]:
def aggregate_parameters(local_params_list):
    global_params = {}
    for key in local_params_list[0].keys():
        global_params[key] = sum([params[key] for params in local_params_list]) / len(local_params_list)
    return global_params

## Prepare data

In [20]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

def split_data(dataset, num_nodes=3):
    indices = list(range(len(dataset)))
    random.shuffle(indices)
    split_size = len(indices) // num_nodes
    node_data_loaders = []
    for i in range(num_nodes):
        node_indices = indices[i * split_size: (i + 1) * split_size]
        subset = Subset(dataset, node_indices)
        loader = DataLoader(subset, batch_size=32, shuffle=True)
        node_data_loaders.append(loader)
    return node_data_loaders

In [21]:
num_nodes = 3
node_data_loaders = split_data(mnist_dataset, num_nodes=num_nodes)
blockchain = Blockchain()

## Swarming Learning

In [22]:
def swarm_learning_with_blockchain(swarm_nodes, blockchain, test_loader, num_rounds=5, epochs_per_round=3):
    start_time = time.time()
    accuracies = []
    for round_num in range(num_rounds):
        print(f"\n========== Swarm Round {round_num + 1}/{num_rounds} ==========")
        local_models = []
        leader = swarm_nodes[0].start_election()

        for node in swarm_nodes:
            print(f"Node {node.node_id} is training...")
            local_params = node.gather_parameters()
            local_models.append(local_params)

        if leader.is_leader and leader.active:
            global_params = aggregate_parameters(local_models)
            leader.verify_and_add_block(round_num, global_params)
            for node in swarm_nodes:
                node.model.load_state_dict(global_params)
            accuracy = evaluate_model(leader.model, test_loader)
            accuracies.append(accuracy)
            print(f"Round {round_num + 1} - Swarm Model Accuracy: {accuracy:.2f}%")

    end_time = time.time()
    total_time = end_time - start_time
    avg_accuracy = sum(accuracies) / len(accuracies)
    return total_time, avg_accuracy


In [23]:
swarm_nodes = []
for i, data_loader in enumerate(node_data_loaders):
    node = Node(i, CNNModel(), data_loader, [], blockchain)
    swarm_nodes.append(node)
for node in swarm_nodes:
    node.nodes = swarm_nodes

In [24]:
print("\nRunning Swarm Learning...")
swarm_time, swarm_accuracy = swarm_learning_with_blockchain(swarm_nodes, blockchain, test_loader, num_rounds=10)


Running Swarm Learning...

Node 2 is now the leader.
Node 0 is training...
Node 1 is training...
Node 2 is training...
Round 1 - Swarm Model Accuracy: 74.79%

Node 0 is now the leader.
Node 0 is training...
Node 1 is training...
Node 2 is training...
Round 2 - Swarm Model Accuracy: 94.69%

Node 1 is now the leader.
Node 0 is training...
Node 1 is training...
Node 2 is training...
Round 3 - Swarm Model Accuracy: 95.69%

Node 0 is now the leader.
Node 0 is training...
Node 1 is training...
Node 2 is training...
Round 4 - Swarm Model Accuracy: 96.61%

Node 0 is now the leader.
Node 0 is training...
Node 1 is training...
Node 2 is training...
Round 5 - Swarm Model Accuracy: 96.95%

Node 2 is now the leader.
Node 0 is training...
Node 1 is training...
Node 2 is training...
Round 6 - Swarm Model Accuracy: 97.42%

Node 0 is now the leader.
Node 0 is training...
Node 1 is training...
Node 2 is training...
Round 7 - Swarm Model Accuracy: 97.61%

Node 2 is now the leader.
Node 0 is training...


## Federated Learning

In [25]:
def federated_learning(nodes, test_loader, num_rounds=5, epochs_per_round=3):
    start_time = time.time()
    accuracies = []
    global_model = CNNModel()

    for round_num in range(num_rounds):
        print(f"\n========== Federated Round {round_num + 1}/{num_rounds} ==========")
        local_params_list = []

        # Mỗi node huấn luyện cục bộ
        for node in nodes:
            print(f"Node {node.node_id} is training...")
            local_model = CNNModel()
            local_model.load_state_dict(global_model.state_dict())
            local_params = train_local_model(node.local_data, local_model, epochs=epochs_per_round)
            local_params_list.append(local_params)

        # Tổng hợp tham số tại server trung tâm
        global_params = aggregate_parameters(local_params_list)
        global_model.load_state_dict(global_params)

        # Đánh giá mô hình toàn cục
        accuracy = evaluate_model(global_model, test_loader)
        accuracies.append(accuracy)
        print(f"Round {round_num + 1} - Federated Model Accuracy: {accuracy:.2f}%")

    end_time = time.time()
    total_time = end_time - start_time
    avg_accuracy = sum(accuracies) / len(accuracies)
    return total_time, avg_accuracy


In [26]:
federated_nodes = [Node(i, CNNModel(), data_loader, []) for i, data_loader in enumerate(node_data_loaders)]

print("\nRunning Federated Learning...")
federated_time, federated_accuracy = federated_learning(federated_nodes, test_loader, num_rounds=10)


Running Federated Learning...

Node 0 is training...
Node 1 is training...
Node 2 is training...
Round 1 - Federated Model Accuracy: 96.37%

Node 0 is training...
Node 1 is training...
Node 2 is training...
Round 2 - Federated Model Accuracy: 97.80%

Node 0 is training...
Node 1 is training...
Node 2 is training...
Round 3 - Federated Model Accuracy: 98.23%

Node 0 is training...
Node 1 is training...
Node 2 is training...
Round 4 - Federated Model Accuracy: 98.36%

Node 0 is training...
Node 1 is training...
Node 2 is training...
Round 5 - Federated Model Accuracy: 98.45%

Node 0 is training...
Node 1 is training...
Node 2 is training...
Round 6 - Federated Model Accuracy: 98.53%

Node 0 is training...
Node 1 is training...
Node 2 is training...
Round 7 - Federated Model Accuracy: 98.54%

Node 0 is training...
Node 1 is training...
Node 2 is training...
Round 8 - Federated Model Accuracy: 98.59%

Node 0 is training...
Node 1 is training...
Node 2 is training...
Round 9 - Federated Mo

## Comparision 

In [27]:
print("\n========== Comparison ==========")
print(f"Swarm Learning - Total Time: {swarm_time:.2f}s, Average Accuracy: {swarm_accuracy:.2f}%")
print(f"Federated Learning - Total Time: {federated_time:.2f}s, Average Accuracy: {federated_accuracy:.2f}%")
print(f"Time Difference (Swarm - Federated): {swarm_time - federated_time:.2f}s")
print(f"Accuracy Difference (Swarm - Federated): {swarm_accuracy - federated_accuracy:.2f}%")


Swarm Learning - Total Time: 324.29s, Average Accuracy: 94.71%
Federated Learning - Total Time: 799.31s, Average Accuracy: 98.23%
Time Difference (Swarm - Federated): -475.01s
Accuracy Difference (Swarm - Federated): -3.52%
