<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Federated_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

class SimpleNN(nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear.

    Args:
        input_dim: Number of features in each input sample.
        output_dim: Number of values produced per sample. The outputs are
            raw scores; no softmax or other normalisation is applied.
        hidden_dim: Width of the hidden layer. Defaults to 128, the value
            that was previously hard-coded.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=128):
        super().__init__()
        # The attribute names become the state_dict keys ("fc1.weight",
        # "fc2.bias", ...), so they must stay stable for checkpoints and
        # for any code that matches parameters by key.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Apply ``fc1``, a ReLU, then ``fc2`` to ``x`` and return the result."""
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)

def local_training(model, train_loader, epochs, lr):
    """Fit ``model`` in place with SGD and a cross-entropy objective.

    Args:
        model: Callable module whose ``parameters()`` are optimised.
        train_loader: Iterable yielding ``(data, target)`` batches; it is
            iterated once per epoch.
        epochs: Number of full passes over ``train_loader``.
        lr: Learning rate handed to ``optim.SGD``.

    Returns:
        None. The model's parameters are updated in place.
    """
    loss_fn = nn.CrossEntropyLoss()
    sgd = optim.SGD(model.parameters(), lr=lr)

    def _step(inputs, labels):
        # One optimisation step: clear stale gradients, backpropagate the
        # batch loss, then apply the SGD update.
        sgd.zero_grad()
        loss_fn(model(inputs), labels).backward()
        sgd.step()

    for _ in range(epochs):
        for inputs, labels in train_loader:
            _step(inputs, labels)

def federated_averaging(models):
    """Replace every model's state with the element-wise mean over all models.

    Each ``state_dict`` entry is averaged across ``models`` with equal
    weight, and the averaged state is then loaded into every model, so all
    models hold identical values afterwards.

    Args:
        models: Non-empty sequence of modules. Every key in the first
            model's ``state_dict()`` must be present in the others, with
            tensors of the same shape.

    Returns:
        None. The models are updated in place.

    Raises:
        IndexError: If ``models`` is empty.
    """
    if len(models) == 0:
        raise IndexError("federated_averaging() requires at least one model")

    # Snapshot each state dict once instead of rebuilding it for every key.
    state_dicts = [model.state_dict() for model in models]

    averaged_state = {}
    for key, reference in state_dicts[0].items():
        # torch.stack allocates a new tensor, so no model's own parameters
        # are modified until the complete average has been computed.
        stacked = torch.stack([state[key] for state in state_dicts])
        if reference.is_floating_point() or reference.is_complex():
            averaged_state[key] = stacked.mean(dim=0)
        else:
            # Integer/bool entries (e.g. BatchNorm's ``num_batches_tracked``)
            # cannot be mean-reduced or true-divided in their own dtype, so
            # average in float64 and cast back (truncating toward zero).
            averaged_state[key] = (
                stacked.to(torch.float64).mean(dim=0).to(reference.dtype)
            )

    # load_state_dict copies the values into each model's existing tensors,
    # so the models do not end up sharing storage with one another.
    for model in models:
        model.load_state_dict(averaged_state)

# Example usage: two simulated clients, each with its own model and its own
# randomly generated dataset (100 samples, 10 features, binary labels).
model1, model2 = SimpleNN(10, 2), SimpleNN(10, 2)

train_data1 = DataLoader(
    TensorDataset(torch.randn(100, 10), torch.randint(0, 2, (100,))),
    batch_size=32,
)
train_data2 = DataLoader(
    TensorDataset(torch.randn(100, 10), torch.randint(0, 2, (100,))),
    batch_size=32,
)

# Each client trains locally on its own data; afterwards the weights of both
# models are averaged and written back to both.
for client_model, client_loader in ((model1, train_data1), (model2, train_data2)):
    local_training(client_model, client_loader, epochs=5, lr=0.01)
federated_averaging([model1, model2])