In [2]:
%pip install torch torchvision syft
%pip install pysyft

Defaulting to user installation because normal site-packages is not writeable
[0mNote: you may need to restart the kernel to use updated packages.
Defaulting to user installation because normal site-packages is not writeable
[0mNote: you may need to restart the kernel to use updated packages.


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import syft as sy
import copy


In [4]:
class SimpleNet(nn.Module):
    """Two-layer fully connected classifier: flatten -> Linear -> ReLU -> Linear.

    The defaults match MNIST (28x28 images flattened to 784 features, 128 hidden
    units, 10 output classes). The layer attribute names ``fc1``/``fc2`` are kept
    stable because the federated-averaging code combines models key-by-key via
    ``state_dict``.

    Args:
        in_features: Number of input features after flattening (default 28 * 28).
        hidden: Width of the hidden layer (default 128).
        num_classes: Number of output scores (default 10).
    """

    def __init__(self, in_features=28 * 28, hidden=128, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden)
        self.fc2 = nn.Linear(hidden, num_classes)

    def forward(self, x):
        """Return unnormalized class scores of shape (batch, num_classes).

        The input is flattened to (-1, in_features). ``reshape`` is used instead
        of ``view`` so non-contiguous inputs (e.g. transposed tensors) are
        accepted instead of raising.
        """
        x = x.reshape(-1, self.fc1.in_features)
        x = F.relu(self.fc1(x))
        return self.fc2(x)


In [5]:
# Convert images to tensors; this transform is reused by the test loader later on.
transform = transforms.Compose([transforms.ToTensor()])

BATCH_SIZE = 64
SEED = 0  # seeds the per-client shuffle order so Restart & Run All is reproducible

# Load the MNIST training split once; both clients draw from it.
full_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Split into two disjoint, contiguous halves to simulate decentralized data.
# The split point is derived from the dataset length (instead of hard-coding
# 30000/60000) so every sample is assigned to exactly one client.
split_point = len(full_dataset) // 2
client1_data = Subset(full_dataset, range(0, split_point))
client2_data = Subset(full_dataset, range(split_point, len(full_dataset)))

# A dedicated seeded generator per client makes each shuffle order deterministic.
client1_loader = DataLoader(
    client1_data, batch_size=BATCH_SIZE, shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
client2_loader = DataLoader(
    client2_data, batch_size=BATCH_SIZE, shuffle=True,
    generator=torch.Generator().manual_seed(SEED + 1),
)


In [6]:
def local_train(model, loader, epochs=1, lr=0.01):
    """Train ``model`` in place with plain SGD and cross-entropy loss on one client's data.

    Args:
        model: The ``nn.Module`` to train; its parameters are updated in place.
        loader: Iterable of ``(data, target)`` batches (e.g. a ``DataLoader``);
            ``target`` holds class indices as expected by ``nn.CrossEntropyLoss``.
        epochs: Number of full passes over ``loader``.
        lr: SGD learning rate.

    Returns:
        A deep copy of ``model.state_dict()`` taken after training. A snapshot is
        returned because ``state_dict()`` references the live parameter tensors,
        so any later change to ``model`` would otherwise silently alter the
        returned weights.
    """
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for _ in range(epochs):
        for data, target in loader:
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()
    return copy.deepcopy(model.state_dict())


In [7]:
def average_weights(weights_list, weights=None):
    """Aggregate client state dicts by (optionally weighted) element-wise averaging.

    With ``weights=None`` every state dict counts equally (plain mean). Passing
    per-client ``weights`` -- e.g. each client's number of training samples, as
    in FedAvg -- computes ``sum(w_i * sd_i[key]) / sum(w_i)`` for every key.

    Args:
        weights_list: Non-empty sequence of state dicts sharing the same keys.
        weights: Optional sequence of non-negative numbers, one per state dict.

    Returns:
        A new mapping of the same type as ``weights_list[0]`` holding the
        averaged values; the input state dicts are not modified.

    Raises:
        ValueError: If ``weights_list`` is empty, or ``weights`` has the wrong
            length or a non-positive sum.
        KeyError: If a later state dict lacks a key present in the first.
    """
    n = len(weights_list)
    if n == 0:
        raise ValueError("average_weights() requires at least one state dict")
    if weights is None:
        # Integer 1s keep the unweighted path numerically identical to sum / n.
        weights = [1] * n
    elif len(weights) != n:
        raise ValueError(f"expected {n} weights, got {len(weights)}")
    total = sum(weights)
    if total <= 0:
        raise ValueError("weights must have a positive sum")

    # Shallow copy keeps the container type of the first input; every value is
    # replaced below by a freshly computed tensor, so no input tensor is aliased.
    averaged = copy.copy(weights_list[0])
    for key in averaged:
        acc = weights_list[0][key] * weights[0]
        for w, state in zip(weights[1:], weights_list[1:]):
            acc = acc + state[key] * w
        averaged[key] = acc / total
    return averaged


In [8]:
# Seed before building the model so the initial global weights are reproducible.
torch.manual_seed(0)
global_model = SimpleNet()

# One DataLoader per simulated client; extend this list to simulate more clients.
client_loaders = [client1_loader, client2_loader]

rounds = 5
for r in range(rounds):
    print(f"\n📦 Round {r + 1}")

    client_weights = []
    for loader in client_loaders:
        # Each client trains an independent copy of the current global model,
        # so the global model itself is untouched until aggregation.
        local_model = copy.deepcopy(global_model)
        client_weights.append(local_train(local_model, loader))

    # Federated averaging: replace the global weights with the unweighted mean
    # of the clients' trained weights.
    global_model.load_state_dict(average_weights(client_weights))
    print(f"✅ Round {r + 1} complete")
print("\n🎉 Federated Learning Training Complete!")


📦 Round 1
✅ Round 1 complete

📦 Round 2
✅ Round 2 complete

📦 Round 3
✅ Round 3 complete

📦 Round 4
✅ Round 4 complete

📦 Round 5
✅ Round 5 complete

🎉 Federated Learning Training Complete!


In [9]:
# Held-out MNIST test split, used to score the aggregated global model.
test_loader = DataLoader(
    dataset=datasets.MNIST(
        root='./data',
        train=False,
        transform=transform,
    ),
    batch_size=64,
)

def evaluate(model, loader):
    """Print and return the classification accuracy of ``model`` on ``loader``.

    Puts ``model`` into eval mode and runs it without gradient tracking.

    Args:
        model: Classifier whose output has class scores along dim 1.
        loader: Iterable of ``(data, target)`` batches with class-index targets.

    Returns:
        Fraction of correctly classified samples, in [0, 1].

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            pred = model(data).argmax(dim=1)
            correct += (pred == target).sum().item()
            # Count the samples actually evaluated instead of len(loader.dataset),
            # which is wrong with samplers/drop_last and missing on plain iterables.
            total += target.size(0)
    if total == 0:
        raise ValueError("evaluate() received a loader with no samples")
    accuracy = correct / total
    print(f"✅ Global model accuracy: {accuracy:.4f}")
    return accuracy

# Score the final aggregated model on the held-out test split.
evaluate(model=global_model, loader=test_loader)


✅ Global model accuracy: 0.9002
