<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Federated_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

# Example model: simple neural network for classification
class FederatedNN(nn.Module):
    """Two-layer MLP classifier used both as the global model and as each client model.

    The defaults (784 -> 128 -> 10) fit flattened 28x28 MNIST images and the ten
    digit classes. Other sizes can be passed explicitly.

    Args:
        in_features: Size of each flattened input sample.
        hidden_features: Width of the single hidden (ReLU) layer.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features=784, hidden_features=128, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Return raw (unnormalized) class logits.

        ``x``'s last dimension must equal ``in_features``. No softmax is
        applied, because the caller uses ``nn.CrossEntropyLoss``, which expects
        logits.
        """
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Federated learning function (simplified FedAvg)
def federated_training(global_model, local_models, data_loaders, num_rounds=5, device='cpu',
                       lr=0.01, local_epochs=1, client_weights=None):
    """Train ``global_model`` with a simplified Federated Averaging (FedAvg) loop.

    Each round runs three steps:

    1. Every local model is reset to the current global weights.
    2. Each local model is trained with plain SGD on its own data loader.
    3. The global model is replaced by the element-wise average of all local
       ``state_dict`` entries.

    Args:
        global_model: Model whose weights are broadcast to the clients and
            aggregated back from them. It is updated in place and also returned.
        local_models: One model per client. Each must share ``global_model``'s
            architecture (same ``state_dict`` keys and shapes).
        data_loaders: Indexable sequence of loaders; ``data_loaders[i]`` feeds
            ``local_models[i]``. Each batch is ``(inputs, labels)``. Inputs are
            flattened to ``(batch, -1)`` before the forward pass. Loaders beyond
            ``len(local_models)`` are ignored.
        num_rounds: Number of communication rounds.
        device: Device that the global model, the local models and each batch
            are moved to.
        lr: SGD learning rate used for local training.
        local_epochs: Number of passes over each client's loader per round.
        client_weights: Optional per-client aggregation weights, e.g. local
            dataset sizes as in standard FedAvg. They are normalized to sum to 1.
            ``None`` (the default) gives every client equal weight.

    Returns:
        ``global_model`` after the final aggregation.

    Raises:
        ValueError: If ``local_models`` is empty, if there are fewer
            ``data_loaders`` than ``local_models``, or if ``client_weights`` is
            not one non-negative weight per client with a positive sum.
    """
    if not local_models:
        raise ValueError("federated_training requires at least one local model")
    if len(data_loaders) < len(local_models):
        raise ValueError(
            f"expected a data loader for each of the {len(local_models)} local models, "
            f"got {len(data_loaders)}"
        )

    weights = None
    if client_weights is not None:
        weights = torch.as_tensor(client_weights, dtype=torch.float64)
        if (weights.shape != (len(local_models),)
                or bool((weights < 0).any())
                or weights.sum().item() <= 0):
            raise ValueError(
                "client_weights must hold one non-negative weight per local model "
                "with a positive sum"
            )
        weights = weights / weights.sum()

    global_model.to(device)
    for local_model in local_models:
        local_model.to(device)

    # The loss module holds no state, so one instance serves every client and round.
    criterion = nn.CrossEntropyLoss()

    # `round_idx` rather than `round`, so the builtin is not shadowed.
    for round_idx in range(num_rounds):
        print(f"Starting Round {round_idx + 1}...")

        # Broadcast: every client starts this round from the current global weights.
        global_state = global_model.state_dict()
        for local_model, loader in zip(local_models, data_loaders):
            local_model.load_state_dict(global_state)
            _train_local(local_model, loader, criterion, lr, local_epochs, device)

        # Aggregate (FedAvg). Each client's state_dict is collected once, not once per key.
        with torch.no_grad():
            local_states = [local_model.state_dict() for local_model in local_models]
            # Updating the global model's own state_dict keeps its metadata intact for loading.
            new_state = global_model.state_dict()
            new_state.update(_average_state_dicts(local_states, weights))
            global_model.load_state_dict(new_state)

        print(f"Round {round_idx + 1} completed.")
    return global_model


def _train_local(model, data_loader, criterion, lr, local_epochs, device):
    """Run ``local_epochs`` passes of plain SGD (fresh optimizer) over ``data_loader``."""
    optimizer = optim.SGD(model.parameters(), lr=lr)
    model.train()
    for _ in range(local_epochs):
        for images, labels in data_loader:
            # Flatten everything after the batch dimension. For MNIST this maps
            # (N, 1, 28, 28) to (N, 784), and it also works for other input widths.
            images = images.flatten(start_dim=1).to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()


def _average_state_dicts(state_dicts, weights=None):
    """Return a dict holding the (optionally weighted) element-wise average of each entry.

    ``weights`` is either ``None`` (plain mean) or a 1-D float64 tensor with one
    normalized weight per state dict.

    Floating-point and complex entries are averaged directly. ``torch.mean``
    rejects integer and bool entries (e.g. BatchNorm's ``num_batches_tracked``).
    Those are averaged in float64, rounded half-to-even, and cast back to their
    original dtype.
    """
    averaged = {}
    for key in state_dicts[0]:
        stacked = torch.stack([state[key] for state in state_dicts])
        averagable = stacked.is_floating_point() or stacked.is_complex()
        values = stacked if averagable else stacked.to(torch.float64)
        if weights is None:
            mean = values.mean(dim=0)
        else:
            # Reshape to (num_clients, 1, 1, ...) so each weight scales one client's whole tensor.
            w = weights.to(values.device).view(-1, *([1] * (values.dim() - 1)))
            mean = (values * w).sum(dim=0)
        if not averagable:
            mean = mean.round()
        averaged[key] = mean.to(stacked.dtype)
    return averaged

# Dataset and DataLoaders for each simulated client.
# Single source of truth for the client count: the split sizes and the number
# of local models are both derived from it, so they cannot drift apart.
NUM_CLIENTS = 3

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
full_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Partition the dataset into NUM_CLIENTS disjoint shards; the last shard absorbs
# any remainder. The seeded generator makes the partition reproducible across runs.
shard_size = len(full_dataset) // NUM_CLIENTS
split_sizes = [shard_size] * (NUM_CLIENTS - 1) + [len(full_dataset) - shard_size * (NUM_CLIENTS - 1)]
client_datasets = random_split(full_dataset, split_sizes, generator=torch.Generator().manual_seed(0))
data_loaders = [DataLoader(dataset, batch_size=64, shuffle=True) for dataset in client_datasets]

# Example setup: one global model and one local model per client.
global_model = FederatedNN()
local_models = [FederatedNN() for _ in data_loaders]

# Federated training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
federated_model = federated_training(global_model, local_models, data_loaders, num_rounds=5, device=device)