<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Federated_Learning_for_Privacy_Preserving_AI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import copy

import torch
import torch.nn as nn
import torch.optim as optim

# Define simple neural network for federated learning
class SimpleNet(nn.Module):
    """Two-layer perceptron used as the model in the federated-learning demo.

    Architecture: ``Linear(input_dim -> 128) -> ReLU -> Linear(128 -> output_dim)``.
    ``forward`` returns raw scores (no softmax is applied).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        hidden_width = 128
        # The attribute names fc1/fc2 become the state_dict keys that federated
        # averaging matches on, so they must stay stable. Creation order
        # (fc1 first) also fixes the RNG draw order for weight initialization.
        self.fc1 = nn.Linear(input_dim, hidden_width)
        self.fc2 = nn.Linear(hidden_width, output_dim)

    def forward(self, x):
        """Map inputs ``x`` to ``output_dim`` raw scores."""
        hidden = torch.relu(self.fc1(x))
        scores = self.fc2(hidden)
        return scores

# Simulated federated learning function
def federated_learning(num_devices, data, model, num_rounds=5, *, lr=0.01, device_data=None):
    """Train ``model`` with simulated federated averaging.

    In each round, every simulated device loads the current global weights and
    runs one pass of SGD over its local batches. The global weights are then
    replaced by the element-wise mean of all device state_dicts.

    Args:
        num_devices: Number of simulated devices; must be >= 1.
        data: Iterable of ``(inputs, targets)`` batches that every device
            trains on when ``device_data`` is not given. It is iterated once
            per device per round, so it must be re-iterable (e.g. a list).
        model: The global model. It is updated in place and also returned.
        num_rounds: Number of communication rounds.
        lr: SGD learning rate for local training (default matches the
            previous hard-coded value).
        device_data: Optional sequence of length ``num_devices``; element
            ``i`` is the batch iterable for device ``i``. When given,
            ``data`` is ignored.

    Returns:
        ``model`` (the same object), holding the averaged weights.

    Raises:
        ValueError: If ``num_devices < 1`` or ``device_data`` does not have
            exactly ``num_devices`` entries.
    """
    if num_devices < 1:
        raise ValueError(f"num_devices must be >= 1, got {num_devices}")
    if device_data is None:
        # Every device sees the same batches. From identical starting weights
        # they then produce (near-)identical updates.
        device_data = [data] * num_devices
    elif len(device_data) != num_devices:
        raise ValueError(
            f"device_data has {len(device_data)} entries, expected {num_devices}"
        )

    global_model = model
    # Replicas are clones of the global model, so any architecture works
    # (previously they were hard-coded to SimpleNet(4, 3)).
    device_models = [copy.deepcopy(global_model) for _ in range(num_devices)]
    criterion = nn.CrossEntropyLoss()

    for _ in range(num_rounds):
        global_state = global_model.state_dict()

        # Local training: sync each device with the global weights, then SGD.
        for device_model, batches in zip(device_models, device_data):
            device_model.load_state_dict(global_state)
            optimizer = optim.SGD(device_model.parameters(), lr=lr)
            for inputs, targets in batches:
                optimizer.zero_grad()
                loss = criterion(device_model(inputs), targets)
                loss.backward()
                optimizer.step()

        # Aggregation: element-wise mean of every state_dict entry.
        device_states = [device_model.state_dict() for device_model in device_models]
        averaged = {}
        for key, reference in global_state.items():
            stacked = torch.stack([state[key] for state in device_states])
            if stacked.is_floating_point() or stacked.is_complex():
                averaged[key] = stacked.mean(dim=0)
            else:
                # torch.mean rejects integer tensors (e.g. BatchNorm's
                # num_batches_tracked), so average in float64 and cast back.
                averaged[key] = stacked.double().mean(dim=0).round().to(reference.dtype)
        global_model.load_state_dict(averaged)

    return global_model

# Example usage: one random batch of 32 samples with 4 features and labels
# drawn from 3 classes, shared by 5 simulated devices (default 5 rounds).
model = SimpleNet(input_dim=4, output_dim=3)
dummy_data = [
    (torch.randn(32, 4), torch.randint(0, 3, (32,))),
]
trained_model = federated_learning(
    num_devices=5,
    data=dummy_data,
    model=model,
)