# Blockchain Federated Learning

In [None]:
import hashlib
import torch
import torch.nn as nn
import torch.optim as optim
import random

## Blockchain Implementation

In [None]:
class Block:
    """A single link in the hash chain.

    Attributes:
        index: Sequence number of the block (assigned by the caller).
        data: Payload recorded in the block; it is converted with ``str()``
            when hashing.
        previous_hash: Hex digest of the preceding block.
        hash: SHA-256 hex digest of this block's fields, computed once at
            construction time.
    """

    def __init__(self, index, data, previous_hash):
        self.index = index
        self.data = data  # Model parameters or metadata
        self.previous_hash = previous_hash
        self.hash = self.compute_hash()

    def compute_hash(self):
        """Return the SHA-256 hex digest of ``(index, data, previous_hash)``.

        Each field is length-prefixed before concatenation so the encoding is
        unambiguous. With plain concatenation, different blocks could hash
        identically, e.g. ``(1, "2x", h)`` and ``(12, "x", h)``.
        """
        fields = (str(self.index), str(self.data), str(self.previous_hash))
        block_string = "".join(f"{len(field)}:{field}" for field in fields)
        return hashlib.sha256(block_string.encode()).hexdigest()


class Blockchain:
    """An append-only list of ``Block`` objects linked by their hashes.

    A genesis block (index 0, previous hash ``"0"``) is created on construction.
    """

    def __init__(self):
        self.chain = []
        self.create_genesis_block()

    def create_genesis_block(self):
        """Append the fixed first block of the chain."""
        genesis_block = Block(0, "Genesis Block", "0")
        self.chain.append(genesis_block)

    def add_block(self, data):
        """Append a block holding ``data``, linked to the current tip.

        Returns:
            The newly appended ``Block``.
        """
        previous_block = self.chain[-1]
        new_block = Block(len(self.chain), data, previous_block.hash)
        self.chain.append(new_block)
        return new_block

    def is_chain_valid(self):
        """Return True if every block is intact and correctly linked.

        A block is intact when its stored hash equals a fresh recomputation.
        A block is linked when its ``previous_hash`` equals the stored hash of
        the block before it.

        The genesis block (position 0) is included in the hash check.
        Otherwise, editing its data would go unnoticed: its successor still
        points at the unchanged stored hash.
        """
        for position, current in enumerate(self.chain):
            if current.hash != current.compute_hash():
                return False
            if position > 0 and current.previous_hash != self.chain[position - 1].hash:
                return False
        return True

    def print_chain(self):
        """Print one line per block with its index, data and hash."""
        for block in self.chain:
            print(f"Block {block.index}: {block.data}, Hash: {block.hash}")

## Federated Learning Implementation

In [None]:
class SimpleNN(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear, returning raw logits.

    Args:
        input_size: Number of input features.
        hidden_size: Width of the hidden layer.
        output_size: Number of output units.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Attribute names become the state_dict keys ("fc1.weight", ...).
        # Peers' states are averaged key-by-key, so keep them stable.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)


class Peer:
    """A federated-learning participant that trains its own model locally.

    Args:
        peer_id: Identifier for this peer.
        model: The ``nn.Module`` trained by this peer; it is updated in place.
        dataset: Iterable of ``(features, label)`` pairs. Each pair drives one
            optimizer step.
        lr: Learning rate for the plain SGD optimizer.
    """

    def __init__(self, peer_id, model, dataset, lr=0.01):
        self.peer_id, self.model, self.dataset = peer_id, model, dataset
        self.optimizer = optim.SGD(model.parameters(), lr=lr)
        self.criterion = nn.CrossEntropyLoss()

    def train_on_local_data(self, epochs=1):
        """Run ``epochs`` passes over the local dataset, one SGD step per pair."""
        net, opt, loss_fn = self.model, self.optimizer, self.criterion
        net.train()
        for _ in range(epochs):
            for features, label in self.dataset:
                opt.zero_grad()
                loss = loss_fn(net(features.float()), label.long())
                loss.backward()
                opt.step()

    def get_model_state(self):
        """Return a copy of the model's state dict whose tensors are independent clones."""
        snapshot = {}
        for name, tensor in self.model.state_dict().items():
            snapshot[name] = tensor.clone()
        return snapshot

    def load_model_state(self, state_dict):
        """Load ``state_dict`` into this peer's model."""
        self.model.load_state_dict(state_dict)


# Generate synthetic data for peers
def generate_synthetic_data(num_samples, input_size, num_classes):
    """Build a list of ``num_samples`` random ``(features, label)`` pairs.

    Each feature vector has shape ``(input_size,)``, drawn from ``torch.rand``.
    Each label is an integer in ``[0, num_classes)`` from ``torch.randint``.
    Labels are drawn independently of the features.
    """
    features = torch.rand(num_samples, input_size)
    labels = torch.randint(0, num_classes, (num_samples,))
    return list(zip(features, labels))


## Blockchain Federated Learning Loop

In [None]:
def _fingerprint_state(state_dict):
    """Return a SHA-256 hex digest that commits exactly to a model state dict.

    Every entry's name, shape, dtype and full-precision values are hashed.
    Each field is length-prefixed so the encoding is unambiguous.
    ``str(state_dict)`` is deliberately not used: tensor printing rounds values
    and summarizes large tensors, so different weights could share a digest.
    """
    digest = hashlib.sha256()
    for name, tensor in state_dict.items():
        values = tensor.detach().cpu().flatten().tolist()
        for field in (name, str(tuple(tensor.shape)), str(tensor.dtype), repr(values)):
            digest.update(f"{len(field)}:{field}".encode())
    return digest.hexdigest()


def _average_states(states, keys):
    """Return the element-wise mean of ``states`` for each key in ``keys``."""
    count = len(states)
    return {key: sum(state[key] for state in states) / count for key in keys}


def blockchain_federated_learning(num_peers=5, input_size=10, hidden_size=20, output_size=3, epochs=5, rounds=10,
                                  samples_per_peer=100):
    """Simulate federated averaging, logging each peer's update on a hash chain.

    Each round runs four steps:

    1. Every peer trains locally.
    2. A fingerprint of each peer's weights is appended to the blockchain.
    3. The weights are averaged into a global model.
    4. The global weights are sent back to every peer.

    Args:
        num_peers: Number of participating peers; must be at least 1.
        input_size: Input width of every ``SimpleNN``.
        hidden_size: Hidden-layer width of every ``SimpleNN``.
        output_size: Output width of every ``SimpleNN``. Also the number of
            synthetic label classes.
        epochs: Local passes over each peer's data per round.
        rounds: Number of communication rounds.
        samples_per_peer: Size of each peer's synthetic dataset.

    Returns:
        ``(global_model, blockchain)`` after the final round.

    Raises:
        ValueError: If ``num_peers`` is less than 1 (nothing to average).
    """
    if num_peers < 1:
        raise ValueError(f"num_peers must be at least 1, got {num_peers}")

    # Initialize blockchain
    blockchain = Blockchain()

    # Create peers
    peers = []
    for i in range(num_peers):
        model = SimpleNN(input_size, hidden_size, output_size)
        dataset = generate_synthetic_data(samples_per_peer, input_size, output_size)
        peers.append(Peer(peer_id=i, model=model, dataset=dataset))

    # Centralized global model for aggregation
    global_model = SimpleNN(input_size, hidden_size, output_size)

    # FedAvg assumes every client starts from the same weights. Averaging
    # independently initialised networks after round 1 would mix unrelated
    # hidden units, so broadcast the global initial state first.
    initial_state = global_model.state_dict()
    for peer in peers:
        peer.load_model_state(initial_state)

    for round_num in range(rounds):
        print(f"Round {round_num + 1}/{rounds}")

        # Step 1: Local Training
        for peer in peers:
            peer.train_on_local_data(epochs=epochs)

        # Step 2: Blockchain Recording
        peer_updates = []
        for peer in peers:
            model_state = peer.get_model_state()
            peer_updates.append(model_state)
            blockchain.add_block(f"Peer {peer.peer_id} updates: {_fingerprint_state(model_state)}")

        # Step 3: Aggregation (unweighted mean of the peer states)
        global_state = _average_states(peer_updates, global_model.state_dict().keys())
        global_model.load_state_dict(global_state)

        # Step 4: Share Global Model
        shared_state = global_model.state_dict()
        for peer in peers:
            peer.load_model_state(shared_state)

    # Final Blockchain Validation and Display
    print("Blockchain Valid:", blockchain.is_chain_valid())
    blockchain.print_chain()

    # Final Global Model Evaluation (e.g., on a test set)
    print("Training complete. The final global model is ready for evaluation.")
    return global_model, blockchain


if __name__ == "__main__":
    # Script / notebook entry point: run the simulation with its default settings.
    blockchain_federated_learning()
