<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Federated_Learning_(FL).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import copy

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Define neural network
class SimpleNN(nn.Module):
    """Small fully connected regressor: Linear -> ReLU -> Linear.

    The default sizes (2 -> 10 -> 1) reproduce the original architecture, so
    the state-dict keys (``fc.0.weight``, ``fc.0.bias``, ``fc.2.weight``,
    ``fc.2.bias``) and their shapes are unchanged for existing callers.

    Args:
        input_dim: Number of input features per sample.
        hidden_dim: Width of the hidden layer.
        output_dim: Number of outputs per sample.
    """

    def __init__(self, input_dim: int = 2, hidden_dim: int = 10, output_dim: int = 1):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        """Map inputs of shape ``(..., input_dim)`` to ``(..., output_dim)``."""
        return self.fc(x)

# Generate synthetic non-IID client data
def generate_non_iid_client_data(num_clients, num_samples_per_client, num_features=2, *, generator=None):
    """Create one synthetic regression dataset per client with shifted inputs.

    Client ``i`` draws its features from a standard normal shifted by ``+i``,
    so the clients' input distributions differ (non-IID). Each target is the
    sum of the sample's features plus standard-normal noise.

    Args:
        num_clients: Number of client datasets to create.
        num_samples_per_client: Number of samples in each client's dataset.
        num_features: Feature dimension. The default of 2 matches the input
            layer of ``SimpleNN``.
        generator: Optional ``torch.Generator`` for reproducible sampling.
            ``None`` draws from PyTorch's global RNG in exactly the same call
            order as before.

    Returns:
        A list of ``num_clients`` ``TensorDataset`` objects yielding ``(x, y)``
        pairs. ``x`` has shape ``(num_samples_per_client, num_features)`` and
        ``y`` has shape ``(num_samples_per_client, 1)``.
    """
    datasets = []
    for client_idx in range(num_clients):
        # Shift each client's features by its index to make the distributions non-IID.
        x = torch.randn(num_samples_per_client, num_features, generator=generator) + client_idx
        y = x.sum(dim=1, keepdim=True) + torch.randn(num_samples_per_client, 1, generator=generator)
        datasets.append(TensorDataset(x, y))
    return datasets

# Federated learning setup: FedAvg over synthetic non-IID clients.


def _train_locally(global_model, client_data, *, epochs, lr, batch_size, loss_fn):
    """Run local SGD for one client on a private copy of ``global_model``.

    The global model itself is never modified.

    Args:
        global_model: Model whose current weights seed the local copy.
        client_data: The client's dataset of ``(x, y)`` pairs.
        epochs: Number of local passes over ``client_data``.
        lr: SGD learning rate.
        batch_size: Mini-batch size.
        loss_fn: Criterion; assumed to return a batch *mean*, as
            ``nn.MSELoss()`` does by default.

    Returns:
        A tuple ``(state_dict, final_epoch_loss, num_samples)``.
        ``final_epoch_loss`` is the exact sample-weighted mean loss over the
        last local epoch, or NaN if no batch was seen.
    """
    local_model = copy.deepcopy(global_model)
    optimizer = optim.SGD(local_model.parameters(), lr=lr)
    dataloader = DataLoader(client_data, batch_size=batch_size, shuffle=True)

    final_epoch_loss = float("nan")
    for _ in range(epochs):
        loss_sum, seen = 0.0, 0
        for x_batch, y_batch in dataloader:
            optimizer.zero_grad()
            loss = loss_fn(local_model(x_batch), y_batch)
            loss.backward()
            optimizer.step()
            # loss is a batch mean, so scale it by batch size. This keeps the
            # epoch mean exact even when the last batch is smaller.
            loss_sum += loss.item() * x_batch.size(0)
            seen += x_batch.size(0)
        if seen:
            final_epoch_loss = loss_sum / seen
    return local_model.state_dict(), final_epoch_loss, len(client_data)


def _federated_average(client_states, client_sizes):
    """Average client state dicts, weighting each client by its sample count (FedAvg).

    With equal-sized clients this reduces to the plain element-wise mean.

    Raises:
        ValueError: If there are no client states, or every client is empty.
    """
    if not client_states:
        raise ValueError("no client states to aggregate")
    total = sum(client_sizes)
    if total == 0:
        raise ValueError("cannot aggregate: every client dataset is empty")

    averaged = {}
    for key, reference in client_states[0].items():
        if not torch.is_floating_point(reference):
            # Non-float entries cannot be weighted-averaged, so keep the first
            # client's copy. SimpleNN has none.
            averaged[key] = reference.clone()
            continue
        averaged[key] = sum(
            state[key] * (size / total) for state, size in zip(client_states, client_sizes)
        )
    return averaged


num_clients = 5
num_samples_per_client = 100
num_rounds = 10          # communication rounds
local_epochs = 5         # local passes over each client's data per round
local_lr = 0.01          # client SGD learning rate
local_batch_size = 10

torch.manual_seed(0)  # reproducible data, weight init and shuffling across runs
client_datasets = generate_non_iid_client_data(num_clients, num_samples_per_client)
global_model = SimpleNN()
loss_fn = nn.MSELoss()  # stateless criterion, shared by every client and batch

global_losses = []
for round_idx in range(num_rounds):  # named to avoid shadowing the builtin round()
    client_models = []
    client_sizes = []
    round_loss_sum = 0.0
    round_samples = 0

    # Local training on each client, starting from the current global weights.
    for client_data in client_datasets:
        state, final_epoch_loss, n_samples = _train_locally(
            global_model,
            client_data,
            epochs=local_epochs,
            lr=local_lr,
            batch_size=local_batch_size,
            loss_fn=loss_fn,
        )
        client_models.append(state)
        client_sizes.append(n_samples)
        if n_samples:
            round_loss_sum += final_epoch_loss * n_samples
            round_samples += n_samples

    # Federated averaging (sample-weighted).
    global_model.load_state_dict(_federated_average(client_models, client_sizes))

    # Sample-weighted mean of each client's final-epoch training loss.
    global_losses.append(round_loss_sum / round_samples if round_samples else float("nan"))
    print(f"Round {round_idx + 1}, Average Loss: {global_losses[-1]:.4f}")