# Consensus-Driven FedAvg Implementation

# 1. Local Model Training
## Each CLient trains a local model on it's private data.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.relu(self.fc1(x))
        return self.fc2

def train_client(model, data_loader, epochs=1, lr=0.01):
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)

    for epoch in range(epochs):
        for inputs, labels in data_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    return {name: param.data.clone().tolist() for name, param in model.state_dict().items()}


# 2. Peer-to-Peer Communication
## Clients exchange their model updates via a peer-to-peer communication layer.
### Example Communication Layer:

In [2]:
from flask import Flask, request, jsonify
from threading import Thread

app = Flask(__name__)
peer_updates = []

@app.route('/send_update', methods=['POST'])
def receive_update():
    update = request.json
    peer_updates.append(update)
    return jsonify({"message": "Update received"}), 200

@app.route('/get_updates', methods=['GET'])
def get_updates():
    return jsonify(peer_updates), 200

def run_peer_server(port):
    app.run(port=port)

# Start a peer server
server_thread = Thread(target=run_peer_server, args=(5000,))
server_thread.start()


# 3. Consensus Mechanism
## A consensus protocol ensures that all nodes agree on the set of updates to use for aggregation. For simplicity, we’ll use a basic voting mechanism where the majority determines valid updates.
### Voting Consensus Example:

In [3]:
import random

def consensus_voting(peer_updates, threshold=0.6):
    # Count votes for each update
    votes = {}
    for update in peer_updates:
        update_hash = str(update)  # Simplified unique identifier
        votes[update_hash] = votes.get(update_hash, 0) + 1

    # Select updates with sufficient votes
    valid_updates = [update for update, count in votes.items() if count / len(peer_updates) >= threshold]
    return valid_updates


# 4. Federated Averaging (FedAvg)
## Aggregate the selected updates into a global model.




In [4]:
def federated_averaging(valid_updates):
    aggregated_model = {}
    num_updates = len(valid_updates)

    # Average updates
    for key in valid_updates[0]:
        aggregated_model[key] = sum([torch.tensor(update[key]) for update in valid_updates]) / num_updates

    return aggregated_model


# 5. CDFA Workflow
## Integrate all components into a full pipeline

In [5]:
def consensus_driven_federated_averaging(num_rounds=5, num_clients=10, epochs=1):
    # Initialize global model
    global_model = SimpleNN()
    clients_data = []

    # Simulate data for clients
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
    clients_data = random_split(dataset, [6000] * num_clients)

    for round_num in range(num_rounds):
        print(f"Round {round_num + 1}/{num_rounds}")
        peer_updates = []

        # Simulate clients training locally
        for client_id in range(num_clients):
            client_loader = DataLoader(clients_data[client_id], batch_size=32, shuffle=True)
            local_model = SimpleNN()
            local_model.load_state_dict(global_model.state_dict())
            client_update = train_client(local_model, client_loader, epochs=epochs)
            peer_updates.append(client_update)

        # Consensus mechanism to validate updates
        valid_updates = consensus_voting(peer_updates)

        # Aggregate valid updates
        aggregated_params = federated_averaging(valid_updates)

        # Update global model
        global_model.load_state_dict(aggregated_params)

    print("Consensus-Driven Federated Learning Complete")
    return global_model


# 6. Evaluate the Global Model
## Evaluate the aggregated model using a test dataset.

In [6]:
def evaluate_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Global Model Accuracy: {accuracy:.2f}%")


# 7. Run the Full Workflow


In [7]:
if __name__ == "__main__":
    global_model = consensus_driven_federated_averaging()

    # Load test data
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # Evaluate the global model
    evaluate_model(global_model, test_loader)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [05:26<00:00, 30.4kB/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 119kB/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:01<00:00, 1.07MB/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 1.52MB/s]


Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw

Round 1/5


TypeError: cross_entropy_loss(): argument 'input' (position 1) must be Tensor, not Linear