<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Federated_Learning_(FL).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# Define simple neural network
class SimpleNN(nn.Module):
    """Single fully connected layer mapping flattened images to class logits.

    Each input sample is flattened to ``in_features`` values and projected to
    ``num_classes`` unnormalised scores by one ``nn.Linear`` layer named ``fc``
    (so state_dict keys are ``fc.weight`` / ``fc.bias``).

    Args:
        in_features: Number of values per sample after flattening; the default
            28 * 28 matches a single-channel 28x28 image such as MNIST.
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, in_features=28 * 28, num_classes=10):
        super().__init__()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for the batch ``x``.

        Every dimension after the first is flattened, so e.g. a
        ``(batch, 1, 28, 28)`` tensor is accepted with the default sizes.
        """
        # flatten(1) keeps the batch dimension and, unlike view(), also handles
        # non-contiguous inputs and an empty (size-0) batch.
        return self.fc(torch.flatten(x, start_dim=1))

# Client update function
def client_update(client_model, optimizer, train_loader, epoch, device=None):
    """Train ``client_model`` locally and return its ``state_dict``.

    Runs ``epoch`` full passes over ``train_loader`` with cross-entropy loss,
    calling ``optimizer.step()`` once per batch.

    Args:
        client_model: The client's ``nn.Module``; it is switched to train mode.
        optimizer: Optimizer already bound to ``client_model``'s parameters.
        train_loader: Iterable of ``(data, target)`` batches.
        epoch: Number of local passes over ``train_loader`` (0 trains nothing).
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter, so no module-level ``device`` is needed.

    Returns:
        ``client_model.state_dict()``. Its tensors share storage with the live
        model, so further training of the same model object changes them;
        clone the tensors if a snapshot must survive later updates.
    """
    if device is None:
        device = next(client_model.parameters()).device
    # The loss has no parameters or state here; build it once, not per batch.
    criterion = nn.CrossEntropyLoss()
    client_model.train()
    for _ in range(epoch):
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = criterion(client_model(data), target)
            loss.backward()
            optimizer.step()
    return client_model.state_dict()

# Server aggregation function
def server_aggregate(global_model, client_models, weights=None):
    """Average client ``state_dict``s into ``global_model`` and return it.

    Args:
        global_model: Model whose parameters/buffers are overwritten via
            ``load_state_dict``.
        client_models: Sequence of client ``state_dict``s (despite the name,
            these are dicts such as those returned by ``client_update``), each
            with the same keys as ``global_model.state_dict()``.
        weights: Optional per-client weights (e.g. local sample counts) for a
            weighted average. ``None`` (default) gives a plain mean.

    Returns:
        ``global_model`` itself, after loading the aggregated state.

    Raises:
        ValueError: if ``client_models`` is empty, ``weights`` has a different
            length than ``client_models``, or the weights do not sum to a
            positive value.
    """
    if not client_models:
        raise ValueError("server_aggregate: need at least one client state_dict")
    if weights is not None:
        if len(weights) != len(client_models):
            raise ValueError("server_aggregate: len(weights) must equal len(client_models)")
        total_weight = float(sum(weights))
        if total_weight <= 0:
            raise ValueError("server_aggregate: weights must sum to a positive value")

    global_state_dict = global_model.state_dict()
    for key in global_state_dict:
        tensors = [client[key] for client in client_models]
        if not (tensors[0].is_floating_point() or tensors[0].is_complex()):
            # Integer/bool buffers (e.g. BatchNorm's num_batches_tracked) cannot
            # go through torch.mean; keep the first client's value instead.
            global_state_dict[key] = tensors[0]
        elif weights is None:
            global_state_dict[key] = torch.mean(torch.stack(tensors), dim=0)
        else:
            stacked = torch.stack(tensors)
            coeffs = torch.as_tensor(weights, dtype=stacked.dtype, device=stacked.device)
            # Reshape to (num_clients, 1, 1, ...) so it broadcasts over each tensor.
            coeffs = coeffs.view(-1, *([1] * (stacked.dim() - 1)))
            global_state_dict[key] = (stacked * coeffs).sum(dim=0) / total_weight
    global_model.load_state_dict(global_state_dict)
    return global_model

# Evaluation function
def evaluate(global_model, test_loader, device=None):
    """Compute, print and return the top-1 accuracy of ``global_model``.

    Args:
        global_model: Model to evaluate; it is switched to eval mode.
        test_loader: Iterable of ``(data, target)`` batches.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter, so no module-level ``device`` is needed.

    Returns:
        Accuracy as a percentage, ``100 * correct / total``.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    if device is None:
        device = next(global_model.parameters()).device
    global_model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            pred = global_model(data).argmax(dim=1)
            # view_as keeps the comparison elementwise even if target arrives
            # with an extra trailing dimension.
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)
    if total == 0:
        raise ValueError("evaluate: test_loader yielded no samples")
    accuracy = 100.0 * correct / total
    print(f"Global Model Accuracy: {accuracy:.2f}%")
    return accuracy

# Initialize data loaders: split the MNIST training set into `num_clients`
# disjoint random shards (one per simulated client) and load the test set once.
num_clients = 5
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# random_split requires the lengths to sum to len(dataset). Spread any
# remainder over the first shards so this holds for every num_clients, not
# only when len(dataset) is divisible by it.
shard_size, remainder = divmod(len(dataset), num_clients)
client_sizes = [shard_size + (1 if i < remainder else 0) for i in range(num_clients)]
client_datasets = random_split(dataset, client_sizes)
train_loaders = [DataLoader(ds, batch_size=32, shuffle=True) for ds in client_datasets]

# Pass download=True explicitly rather than relying on the train-set call
# above having fetched the test files; it does not re-download existing files.
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

# Initialize models and optimizers: one global model plus a private
# model/optimizer pair for each simulated client, all on the same device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
global_model = SimpleNN().to(device)
client_models = [SimpleNN().to(device) for _client_idx in range(num_clients)]
client_optimizers = [
    optim.SGD(params=local_model.parameters(), lr=0.01)
    for local_model in client_models
]

# Simulated federated learning loop: every round broadcasts the global weights,
# trains each client locally for one epoch, averages the client weights on the
# server, and reports test accuracy.
num_rounds = 10
# `round_idx`, not `round`: rebinding the builtin `round` at notebook scope
# would break any later `round(x, n)` call in the same session.
for round_idx in range(num_rounds):
    print(f"Round {round_idx + 1}")

    # Distribute the global model to clients (load_state_dict copies the values,
    # so one snapshot can be shared by all clients).
    global_state = global_model.state_dict()
    for client_model in client_models:
        client_model.load_state_dict(global_state)

    # Each client performs local training
    client_updates = [
        client_update(client_model, optimizer, train_loader, epoch=1)
        for client_model, optimizer, train_loader in zip(client_models, client_optimizers, train_loaders)
    ]

    # Aggregate updates at the server
    global_model = server_aggregate(global_model, client_updates)

    # Evaluate global model
    evaluate(global_model, test_loader)