<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Transfer_Learning_with_Federated_Aggregation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class SimpleNN(nn.Module):
    """Two-layer fully connected network: 10 inputs -> 5 hidden -> 1 output."""

    def __init__(self):
        super().__init__()
        # Layers are created in this order so that state-dict keys and
        # seeded initialisation match across instances.
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        """Apply ``fc1``, a ReLU, then ``fc2`` and return the result."""
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)

def federated_transfer(global_model, client_data, epochs=1, lr=0.01):
    """Run one federated-averaging round and return the updated global model.

    Every client starts from the same snapshot of ``global_model``, trains a
    private copy with SGD on an MSE loss, and reports its parameter delta.
    The deltas are averaged across clients and applied to ``global_model``
    in a single step, so the result does not depend on the order in which
    clients are visited.

    Args:
        global_model: ``nn.Module`` holding the shared weights; it is
            updated in place.
        client_data: Mapping from client id to an iterable of
            ``(data, target)`` tensor pairs.
        epochs: Number of passes each client makes over its own data.
        lr: Learning rate of each client's SGD optimizer.

    Returns:
        The same ``global_model`` object with the averaged update applied.
        It is returned unchanged when ``client_data`` is empty.
    """
    # Function-scope import so the file's import header needs no change.
    import copy

    criterion = nn.MSELoss()
    # Looked up once: the global model is not modified until every client
    # has reported, so this is a stable snapshot for all delta computations.
    global_state = global_model.state_dict()

    client_updates = []
    for batches in client_data.values():
        # deepcopy clones architecture and weights, so any nn.Module works,
        # not only one specific model class.
        local_model = copy.deepcopy(global_model)
        optimizer = optim.SGD(local_model.parameters(), lr=lr)

        local_model.train()
        for _ in range(epochs):
            for data, target in batches:
                optimizer.zero_grad()
                loss = criterion(local_model(data), target)
                loss.backward()
                optimizer.step()

        # The client's contribution is its change relative to the snapshot.
        local_state = local_model.state_dict()
        client_updates.append(
            {key: local_state[key] - global_state[key] for key in global_state}
        )

    return aggregate_updates(global_model, client_updates)


def aggregate_updates(global_model, local_update):
    """Apply the mean of one or more client updates to ``global_model``.

    Args:
        global_model: ``nn.Module`` to update in place.
        local_update: Either a single update (a dict mapping each state-dict
            key to a delta tensor) or a sequence of such dicts, one per
            client. A single dict is treated as a one-client round.

    Returns:
        The same ``global_model`` object. It is returned unchanged when an
        empty sequence of updates is given.

    Raises:
        KeyError: If an update lacks one of the model's state-dict keys.
    """
    if isinstance(local_update, dict):
        updates = [local_update]
    else:
        updates = list(local_update)
    if not updates:
        return global_model

    # Average over the number of clients, not the number of parameter keys.
    num_clients = len(updates)
    new_state = {}
    for key, current in global_model.state_dict().items():
        mean_delta = sum(update[key] for update in updates) / num_clients
        # Cast back so integer buffers keep their dtype after the division.
        new_state[key] = (current + mean_delta).to(current.dtype)

    global_model.load_state_dict(new_state)
    return global_model

# Example usage: two clients, each holding one random 10-feature sample
# paired with a fixed scalar target.
global_model = SimpleNN()
client_data = {
    name: [(torch.randn(10), torch.tensor([target]))]
    for name, target in (('client1', 1.0), ('client2', 2.0))
}

updated_global_model = federated_transfer(global_model, client_data)