<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Example_Federated_Averaging.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Define a simple neural network
class SimpleNet(nn.Module):
    """Small fully connected regressor: 2 inputs -> 40 hidden units -> 1 output."""

    def __init__(self):
        super().__init__()
        # The attribute names double as state-dict key prefixes
        # ("fc1.weight", "fc2.bias", ...), so they must stay as they are.
        self.fc1 = nn.Linear(in_features=2, out_features=40)
        self.fc2 = nn.Linear(in_features=40, out_features=1)

    def forward(self, x):
        """Run the hidden layer with a ReLU, then the linear output layer."""
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)

# Function to train model locally on client's data
def train_model_locally(model, data, epochs=5, lr=0.01, batch_size=32):
    """Train ``model`` in place on one client's dataset and return it.

    Uses plain SGD with a mean-squared-error loss, reshuffling the data
    every epoch.

    Args:
        model: Module to optimise. It is switched to training mode and its
            parameters are updated in place.
        data: Dataset yielding ``(inputs, targets)`` pairs.
        epochs: Number of full passes over ``data``.
        lr: SGD learning rate.
        batch_size: Number of samples per optimisation step.

    Returns:
        The same ``model`` object that was passed in, after training.
    """
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    dataloader = DataLoader(data, batch_size=batch_size, shuffle=True)

    for _ in range(epochs):
        for inputs, targets in dataloader:
            # Gradients accumulate across backward() calls, so clear them
            # before every step.
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()

    return model

# Function to average weights
def average_weights(weights):
    """Return the element-wise mean of several state dicts.

    The inputs are left untouched: sums are accumulated in cloned tensors,
    so neither the mappings in ``weights`` nor the tensors they hold (which,
    for a module's state dict, share storage with its parameters) are
    modified.

    Args:
        weights: Non-empty sequence of mappings from parameter name to
            tensor. Every mapping must provide the keys of the first one.

    Returns:
        A shallow copy of the first mapping in which every value has been
        replaced by a new tensor holding the mean over all mappings.

    Raises:
        ValueError: If ``weights`` is empty.
    """
    import copy  # local import keeps this function self-contained

    if not weights:
        raise ValueError("average_weights() needs at least one state dict")

    first = weights[0]
    num_clients = len(weights)

    # Copy the container (keeping its type and any attributes attached to
    # it) so that rebinding keys below never touches the caller's mapping.
    averaged = copy.copy(first)
    for key in first:
        # Clone before accumulating: an in-place add on the original tensor
        # would overwrite the first client's parameters.
        total = first[key].clone()
        for state in weights[1:]:
            total += state[key]
        averaged[key] = total / num_clients
    return averaged

# Create and initialize the global model
global_model = SimpleNet()

# Simulate data for clients: each one holds 100 random samples with
# 2 input features and 1 target value.
client_data = [
    TensorDataset(torch.randn(100, 2), torch.randn(100, 1)),  # Client 1
    TensorDataset(torch.randn(100, 2), torch.randn(100, 1)),  # Client 2
    # Add more clients as needed
]

# Federated learning loop
num_rounds = 10
# Named round_idx rather than "round" so the builtin round() stays usable.
for round_idx in range(num_rounds):
    client_weights = []

    # Train on each client's data, starting every client from a fresh copy
    # of the current global weights.
    for data in client_data:
        local_model = SimpleNet()
        local_model.load_state_dict(global_model.state_dict())
        model_client = train_model_locally(local_model, data)
        client_weights.append(model_client.state_dict())

    # Average the client weights and install the result as the new global model
    average_weights_dict = average_weights(client_weights)
    global_model.load_state_dict(average_weights_dict)

    print(f"Round {round_idx + 1}/{num_rounds} completed.")

print("Federated learning completed!")