<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Federated_Learning_(FL).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Define a simple neural network
class SimpleNN(nn.Module):
    """Minimal model: one fully connected layer from ``input_dim`` to ``output_dim``."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # The attribute name ``fc`` determines the state-dict keys
        # ("fc.weight", "fc.bias") that the federated code averages.
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Apply the linear layer and return its raw, unnormalized output."""
        logits = self.fc(x)
        return logits

# Client-side update
def client_update(client_model, optimizer, train_loader, device, epoch=1):
    """Train ``client_model`` locally and return a snapshot of its weights.

    Args:
        client_model: Model trained in place (left in ``train()`` mode).
        optimizer: Optimizer already bound to ``client_model``'s parameters.
        train_loader: Iterable of ``(data, target)`` batches.
        device: Device each batch is moved to before the forward pass.
        epoch: Number of full passes over ``train_loader``.

    Returns:
        A dict mapping each state-dict key to a detached *copy* of the trained
        tensor. ``state_dict()`` returns tensors that share storage with the
        live parameters. Without copying, any later training of the same model
        object would silently change states already handed to the server.
    """
    client_model.train()
    # The loss module holds no state here, so build it once rather than per batch.
    criterion = nn.CrossEntropyLoss()
    for _ in range(epoch):
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = criterion(client_model(data), target)
            loss.backward()
            optimizer.step()
    return {
        key: tensor.detach().clone()
        for key, tensor in client_model.state_dict().items()
    }

# Server aggregation
def server_aggregate(global_model, client_states, weights=None):
    """Average client state dicts into ``global_model`` (FedAvg aggregation).

    Args:
        global_model: Model whose parameters and buffers are overwritten in place.
        client_states: Non-empty sequence of state dicts with the same keys as
            ``global_model.state_dict()``.
        weights: Optional per-client non-negative weights, for example each
            client's local dataset size. ``None`` (the default) gives every
            client equal weight, which is a plain element-wise mean.

    Returns:
        ``global_model`` itself, to allow chaining and reassignment.

    Raises:
        ValueError: if ``client_states`` is empty, or if ``weights`` has the
            wrong length, contains negative values, or sums to zero.
    """
    if not client_states:
        raise ValueError("server_aggregate needs at least one client state")

    coeffs = None
    if weights is not None:
        if len(weights) != len(client_states):
            raise ValueError(
                f"got {len(weights)} weights for {len(client_states)} client states"
            )
        coeffs = torch.as_tensor(weights, dtype=torch.float32)
        if bool((coeffs < 0).any()) or float(coeffs.sum()) <= 0.0:
            raise ValueError("weights must be non-negative and sum to a positive value")
        coeffs = coeffs / coeffs.sum()

    averaged_state = {}
    for key, current in global_model.state_dict().items():
        stacked = torch.stack([state[key].float() for state in client_states], 0)
        if coeffs is None:
            averaged = stacked.mean(0)
        else:
            # Reshape the per-client coefficients so they broadcast over the tensor dims.
            shape = (-1,) + (1,) * (stacked.dim() - 1)
            averaged = (stacked * coeffs.to(stacked.device).view(shape)).sum(0)
        # Cast back to the model's own dtype. For integer buffers this truncates,
        # the same as load_state_dict's copy would.
        averaged_state[key] = averaged.to(current.dtype)

    global_model.load_state_dict(averaged_state)
    return global_model

# Evaluate global model
def evaluate(global_model, test_loader, device):
    """Return the classification accuracy of ``global_model`` on ``test_loader``.

    The model is switched to ``eval()`` mode and run without gradient tracking.
    Each output row's prediction is its argmax over dim 1, compared with the
    corresponding target. Result is ``correct / total`` as a float in [0, 1].
    """
    global_model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            predicted = global_model(inputs).argmax(dim=1)
            n_correct += (predicted == labels.view_as(predicted)).sum().item()
            n_seen += labels.size(0)
    return n_correct / n_seen

# Example usage
if __name__ == "__main__":
    # Simulated datasets for 3 clients (200 random samples each).
    x_train = torch.randn(600, 784)
    y_train = torch.randint(0, 10, (600,))
    client_datasets = [
        TensorDataset(x_train[i:i + 200], y_train[i:i + 200]) for i in range(0, 600, 200)
    ]
    clients = [DataLoader(dataset, batch_size=32, shuffle=True) for dataset in client_datasets]

    # Test dataset
    x_test = torch.randn(100, 784)
    y_test = torch.randint(0, 10, (100,))
    test_loader = DataLoader(TensorDataset(x_test, y_test), batch_size=32, shuffle=False)

    # Initialize global model and one local model + optimizer per client.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    global_model = SimpleNN(input_dim=784, output_dim=10).to(device)
    client_models = [SimpleNN(input_dim=784, output_dim=10).to(device) for _ in clients]
    client_optimizers = [optim.SGD(model.parameters(), lr=0.01) for model in client_models]

    # Federated learning rounds
    num_rounds = 5
    for round_idx in range(num_rounds):
        # FedAvg: every round starts from the current global weights. Without
        # this broadcast each client keeps training its own diverging copy, and
        # the aggregated model is never fed back into local training.
        # load_state_dict copies values in place, so the optimizers stay bound
        # to the same parameter tensors.
        global_state = global_model.state_dict()
        for model in client_models:
            model.load_state_dict(global_state)

        client_states = [
            client_update(model, optimizer, loader, device)
            for model, optimizer, loader in zip(client_models, client_optimizers, clients)
        ]
        global_model = server_aggregate(global_model, client_states)

        # Evaluate global model
        accuracy = evaluate(global_model, test_loader, device)
        print(f"Round {round_idx + 1}, Accuracy: {accuracy:.4f}")