<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Federated_Learning_(FL).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

class SimpleNN(nn.Module):
    """Small fully connected network: Linear -> ReLU -> Linear.

    Args:
        input_dim: number of input features per sample.
        output_dim: number of output values per sample.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        hidden_width = 64  # size of the single hidden layer
        # Attribute names fc1/fc2 are part of the state_dict keys; keep them.
        self.fc1 = nn.Linear(input_dim, hidden_width)
        self.fc2 = nn.Linear(hidden_width, output_dim)

    def forward(self, x):
        """Return the un-activated output of the second linear layer."""
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)

def train_model(model, data_loader, optimizer, criterion):
    """Make one pass over ``data_loader``, taking one optimizer step per batch.

    Args:
        model: module to train; it is switched to training mode first.
        data_loader: iterable yielding ``(inputs, targets)`` pairs.
        optimizer: optimizer constructed over ``model``'s parameters.
        criterion: loss callable invoked as ``criterion(predictions, targets)``.

    Returns:
        None; ``model`` is updated in place.
    """
    model.train()
    for batch in data_loader:
        inputs, targets = batch
        optimizer.zero_grad()  # clear gradients left over from the previous step
        predictions = model(inputs)
        batch_loss = criterion(predictions, targets)
        batch_loss.backward()
        optimizer.step()

def federated_learning(models, global_model, data_loaders, epochs, lr=0.01, criterion=None):
    """Simulate federated averaging over in-process client models.

    The global model's parameters are first broadcast to every local model.
    Then, for each round, every local model is trained for one pass over its
    own data loader, the local states are averaged element-wise (unweighted),
    and the averaged state is loaded into ``global_model`` and broadcast back
    to every local model.

    Args:
        models: local (client) models sharing ``global_model``'s architecture.
        global_model: server model that receives the averaged parameters.
        data_loaders: one loader per local model, in the same order as ``models``.
        epochs: number of federated rounds to run.
        lr: learning rate for each client's SGD optimizer (default 0.01).
        criterion: loss function; ``nn.MSELoss()`` when None.

    Raises:
        ValueError: if ``models`` is empty, or its length differs from
            ``data_loaders`` (``zip`` would otherwise silently skip clients).
    """
    models = list(models)
    data_loaders = list(data_loaders)
    if not models:
        raise ValueError("federated_learning needs at least one local model")
    if len(models) != len(data_loaders):
        raise ValueError(
            f"got {len(models)} models but {len(data_loaders)} data loaders; "
            "each local model needs exactly one data loader"
        )
    if criterion is None:
        criterion = nn.MSELoss()

    # Start every client from the global parameters. Averaging independently
    # initialised networks in the first round would not give a meaningful model.
    initial_state = global_model.state_dict()
    for model in models:
        model.load_state_dict(initial_state)

    for epoch in range(epochs):
        for model, data_loader in zip(models, data_loaders):
            # Plain SGD (no momentum) keeps no state between steps, so a fresh
            # optimizer per round loses nothing.
            optimizer = optim.SGD(model.parameters(), lr=lr)
            train_model(model, data_loader, optimizer, criterion)

        # Snapshot each client's state once rather than once per parameter key.
        local_states = [model.state_dict() for model in models]
        # Reuse the global state_dict object so its metadata is preserved.
        averaged_state = global_model.state_dict()
        for key in averaged_state:
            client_values = [state[key] for state in local_states]
            first = client_values[0]
            if first.is_floating_point() or first.is_complex():
                averaged_state[key] = torch.stack(client_values).mean(dim=0)
            else:
                # torch.mean is undefined for integer tensors (e.g. a BatchNorm
                # num_batches_tracked counter); keep the first client's value.
                averaged_state[key] = first.clone()

        global_model.load_state_dict(averaged_state)  # update the server model
        for model in models:
            model.load_state_dict(averaged_state)  # broadcast back to clients

        print(f'Epoch {epoch + 1}/{epochs} complete.')

# --- Example usage -----------------------------------------------------------
# Three simulated clients, each holding its own copy of the network, plus one
# server-side model that receives the averaged parameters.
input_dim, output_dim = 10, 1
models = [SimpleNN(input_dim, output_dim) for _client in range(3)]
global_model = SimpleNN(input_dim, output_dim)

# Each client gets its own synthetic dataset of 100 random samples (random
# features and random targets), served in mini-batches of 10.
data_loaders = [
    DataLoader(
        TensorDataset(
            torch.randn(100, input_dim),   # features
            torch.randn(100, output_dim),  # targets
        ),
        batch_size=10,
    )
    for _client in range(3)
]

# Five rounds of local training followed by parameter averaging.
federated_learning(models, global_model, data_loaders, epochs=5)