<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Federated_Averaging_Algorithm_(FedAvg).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, TensorDataset

# Example dataset: a synthetic binary-classification problem for demonstration.
data = torch.randn(1000, 10)  # 1000 samples, 10 features
labels = torch.randint(0, 2, (1000,))  # binary class labels (0 or 1)

# Wrap features and labels so that indexing yields (features, label) pairs.
dataset = TensorDataset(data, labels)

# Split the dataset into one disjoint local shard per simulated device.
num_devices = 5
# random_split requires the lengths to sum exactly to len(dataset); spread any
# remainder over the first shards so this works for any num_devices, not only
# exact divisors of the dataset size.
base_size, remainder = divmod(len(dataset), num_devices)
split_sizes = [base_size + (1 if i < remainder else 0) for i in range(num_devices)]
local_data_splits = random_split(dataset, split_sizes)

# Define the model architecture
class SimpleModel(nn.Module):
    """Single-layer linear classifier mapping 10 input features to 2 class scores."""

    def __init__(self):
        super().__init__()
        # One fully connected layer with no activation, so the outputs are raw
        # logits; train_local_model pairs them with nn.CrossEntropyLoss.
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        """Return the class logits produced by the linear layer for ``x``."""
        logits = self.fc(x)
        return logits

# Local training function
def train_local_model(data, model, epochs=5, lr=0.01, batch_size=32):
    """Train ``model`` in place on one client's local data using plain SGD.

    Args:
        data: A map-style dataset yielding ``(inputs, labels)`` pairs, e.g. a
            ``TensorDataset`` or a ``Subset`` produced by ``random_split``.
        model: The ``nn.Module`` to train; its parameters are updated in place.
        epochs: Number of full passes over ``data``.
        lr: SGD learning rate.
        batch_size: Mini-batch size used by the ``DataLoader``.

    Returns:
        A deep copy of ``model.state_dict()`` taken after training. Because it
        is a copy, callers may aggregate or modify the returned tensors in
        place without altering the model's parameters.
    """
    import copy

    model.train()
    optimizer = optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    data_loader = DataLoader(data, batch_size=batch_size, shuffle=True)

    for _ in range(epochs):
        for inputs, targets in data_loader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()

    # state_dict() returns tensors that share storage with the live parameters;
    # copy them so in-place aggregation (e.g. ``+=``) cannot rewrite the model.
    return copy.deepcopy(model.state_dict())

# Function to average weights
def average_weights(local_models, client_sizes=None):
    """Average a list of model state_dicts key by key (FedAvg aggregation).

    Args:
        local_models: Non-empty sequence of state_dicts sharing the keys of the
            first one. The inputs are not modified.
        client_sizes: Optional per-client weights (e.g. local sample counts),
            one per entry of ``local_models``. When ``None`` every client is
            weighted equally, giving the plain arithmetic mean.

    Returns:
        A new state_dict of the same mapping type as ``local_models[0]``
        holding the (weighted) average of every tensor.

    Raises:
        ValueError: if ``local_models`` is empty, ``client_sizes`` has the
            wrong length, or the weights do not sum to a positive value.
    """
    import copy

    if not local_models:
        raise ValueError("average_weights() requires at least one state_dict")
    if client_sizes is None:
        client_sizes = [1] * len(local_models)
    elif len(client_sizes) != len(local_models):
        raise ValueError("client_sizes must have exactly one entry per local model")
    total = sum(client_sizes)
    if total <= 0:
        raise ValueError("client_sizes must sum to a positive value")

    # Shallow copy keeps the mapping type (and any state_dict metadata) of the
    # first input; values are reassigned below, so the inputs stay untouched.
    global_weights = copy.copy(local_models[0])
    for key in global_weights.keys():
        weighted_sum = sum(size * sd[key] for size, sd in zip(client_sizes, local_models))
        global_weights[key] = torch.div(weighted_sum, total)
    return global_weights

# ``initial_model`` is the model class (not an instance): the server and every
# client build the same architecture from it.
initial_model = SimpleModel
# Number of communication rounds (local training followed by server averaging).
num_rounds = 1

# The server owns the global model; clients start each round from its weights.
global_model = initial_model()

for _round in range(num_rounds):
    # Train local models and collect their weights.
    local_models = []
    for client_data in local_data_splits:  # distinct name: don't clobber ``data``
        model = initial_model()
        # All clients must start from the same (current global) weights.
        model.load_state_dict(global_model.state_dict())
        local_models.append(train_local_model(client_data, model))

    # Server computes the average of the client models and adopts it.
    global_weights = average_weights(local_models)
    global_model.load_state_dict(global_weights)

print("Federated learning process completed!")