<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Federated_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

class SimpleModel(nn.Module):
    """A single fully connected layer mapping ``input_dim`` features to ``output_dim`` outputs.

    No activation is applied, so ``forward`` returns the raw affine outputs.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # The only submodule: one affine layer with weight and bias.
        self.fc = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Return ``self.fc(x)``, the un-activated output of the linear layer."""
        scores = self.fc(x)
        return scores

def federated_training(models, data_loaders, global_model, num_clients, epochs=5, lr=0.01):
    """Train ``global_model`` by averaging locally trained client weights each round.

    Each of the ``epochs`` communication rounds:

    1. copies the current global weights into every participating client model,
    2. runs one pass over that client's data loader with plain SGD and
       cross-entropy loss,
    3. replaces the global weights with the element-wise, unweighted mean of the
       clients' resulting ``state_dict`` entries.

    Args:
        models: Indexable collection of client models; entries ``0 .. num_clients-1``
            are used. Each must share ``global_model``'s ``state_dict`` keys.
        data_loaders: Indexable collection of iterables yielding ``(data, target)``
            batches, aligned index-for-index with ``models``.
        global_model: Model whose weights are broadcast to, and aggregated from,
            the clients. Updated in place.
        num_clients: Number of clients taking part in every round.
        epochs: Number of communication rounds (one local pass per client each).
        lr: Learning rate of each client's local SGD optimizer.

    Returns:
        ``global_model`` (the same object), holding the aggregated weights.

    Raises:
        ValueError: If ``num_clients`` is < 1 or exceeds the number of models or
            data loaders supplied.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")
    if num_clients > len(models) or num_clients > len(data_loaders):
        raise ValueError(
            f"num_clients={num_clients} exceeds the {len(models)} models / "
            f"{len(data_loaders)} data loaders provided"
        )

    criterion = nn.CrossEntropyLoss()  # stateless, so one instance serves every batch

    for _ in range(epochs):
        # Detached copy of this round's global weights. It stays valid even if a
        # client model happens to be the global model object itself.
        global_state = {k: v.detach().clone() for k, v in global_model.state_dict().items()}

        client_states = []
        for client_id in range(num_clients):
            model = models[client_id]
            data_loader = data_loaders[client_id]

            # Broadcast: every client starts the round from the global weights.
            # Without this, clients train independently and the average is never
            # used for further training.
            model.load_state_dict(global_state)
            optimizer = optim.SGD(model.parameters(), lr=lr)

            model.train()
            for data, target in data_loader:
                optimizer.zero_grad()
                loss = criterion(model(data), target)
                loss.backward()
                optimizer.step()

            # state_dict() tensors alias the live parameters. Clone them so later
            # training (e.g. the same object reused for another client) cannot
            # alter a snapshot already collected.
            client_states.append({k: v.detach().clone() for k, v in model.state_dict().items()})

        averaged = {}
        for key, reference in global_state.items():
            stacked = torch.stack([state[key] for state in client_states])
            if torch.is_floating_point(reference) or torch.is_complex(reference):
                averaged[key] = stacked.mean(dim=0)
            else:
                # torch.mean rejects integer/bool tensors (e.g. BatchNorm's
                # num_batches_tracked). Average in float64, round, and cast back.
                averaged[key] = stacked.double().mean(dim=0).round().to(reference.dtype)
        global_model.load_state_dict(averaged)

    return global_model

# Example usage: `num_clients` clients, each with its own model and a private
# dataset of 100 random samples (`input_dim` features, integer class labels in
# [0, output_dim)), followed by federated training with the default settings.
input_dim, output_dim, num_clients = 10, 3, 5

models = [SimpleModel(input_dim, output_dim) for _ in range(num_clients)]
data_loaders = [
    DataLoader(
        TensorDataset(
            torch.randn(100, input_dim),  # features
            torch.randint(0, output_dim, (100,)),  # class labels
        ),
        batch_size=10,
    )
    for _ in range(num_clients)
]
global_model = SimpleModel(input_dim, output_dim)

federated_training(models, data_loaders, global_model, num_clients)
print("Federated Learning completed")