# Federated Learning with Differential Privacy
Using PyTorch, Opacus, and Flower

In [None]:
!pip install torch torchvision opacus flwr

In [None]:
import copy

import flwr as fl
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from opacus import PrivacyEngine

class CNN(nn.Module):
    """Small convolutional classifier producing 10 class logits.

    Layout: 3x3 conv (1 -> 32 channels) -> ReLU -> 2x2 max-pool -> flatten
    -> linear head. The head's input width 5408 = 32 * 13 * 13 corresponds
    to a single-channel 28x28 input (28 - 3 + 1 = 26, pooled down to 13).
    """

    def __init__(self):
        super().__init__()
        feature_layers = [
            nn.Conv2d(1, 32, 3, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        classifier_layers = [nn.Flatten(), nn.Linear(5408, 10)]
        # One Sequential named ``net`` keeps the state_dict keys (net.0.*,
        # net.4.*) and their order stable; weights are exchanged with the
        # server as a list in that key order.
        self.net = nn.Sequential(*feature_layers, *classifier_layers)

    def forward(self, x):
        """Map a batch of images ``x`` to unnormalised class scores."""
        logits = self.net(x)
        return logits

def load_data(*, batch_size=64, root="./data"):
    """Load the MNIST train/test splits as DataLoaders.

    The datasets are downloaded into ``root`` on first use. Images are only
    converted to tensors (no normalisation is applied).

    Args:
        batch_size: Batch size for both loaders (default 64, as before). If the
            training loader is later passed through Opacus' ``make_private``
            with its default Poisson sampling, this becomes the *expected*
            batch size rather than a fixed one.
        root: Directory torchvision downloads the dataset into / reads it from.

    Returns:
        ``(train_loader, test_loader)``; only the training loader is shuffled.
    """
    transform = transforms.Compose([transforms.ToTensor()])
    # Named *_set so the module-level ``train`` function is not shadowed.
    train_set = torchvision.datasets.MNIST(
        root=root, train=True, download=True, transform=transform
    )
    test_set = torchvision.datasets.MNIST(
        root=root, train=False, download=True, transform=transform
    )
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

def train(
    model,
    train_loader,
    device,
    *,
    epochs=1,
    lr=1e-3,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
):
    """Train ``model`` locally with DP-SGD (Opacus) and return its new weights.

    Opacus' ``make_private`` wraps the module it is given in place and
    attaches per-sample-gradient hooks to it; wrapping the same module again
    (as happens on the next federated round) is rejected. Training therefore
    runs on a deep copy, and the trained weights are copied back into
    ``model`` so the caller still sees an updated, un-hooked model.

    Args:
        model: Plain ``nn.Module`` already placed on ``device``. Updated in
            place with the trained weights.
        train_loader: DataLoader yielding ``(data, target)`` batches.
        device: Device each batch is moved to.
        epochs: Number of local passes over the (private) loader.
        lr: Adam learning rate.
        noise_multiplier: Opacus noise multiplier (noise std relative to
            ``max_grad_norm``).
        max_grad_norm: Per-sample gradient clipping bound.

    Returns:
        ``model.state_dict()`` after training, keyed by the model's own
        parameter names (no ``_module.`` prefix from the Opacus wrapper).
    """
    model.train()
    # Work on a copy so Opacus hooks never accumulate on the caller's model.
    local_model = copy.deepcopy(model)
    optimizer = optim.Adam(local_model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    # NOTE(review): a fresh PrivacyEngine per call means its accountant only
    # covers this call; the cumulative privacy budget across federated
    # rounds is not tracked here.
    privacy_engine = PrivacyEngine()
    dp_model, dp_optimizer, dp_loader = privacy_engine.make_private(
        module=local_model,
        optimizer=optimizer,
        data_loader=train_loader,
        noise_multiplier=noise_multiplier,
        max_grad_norm=max_grad_norm,
    )
    for _ in range(epochs):
        for data, target in dp_loader:
            data, target = data.to(device), target.to(device)
            dp_optimizer.zero_grad()
            loss = criterion(dp_model(data), target)
            loss.backward()
            dp_optimizer.step()
    # The Opacus wrapper trains ``local_model``'s parameters in place, so its
    # state_dict already holds the trained weights under the plain key names.
    model.load_state_dict(local_model.state_dict())
    return model.state_dict()

class FLClient(fl.client.NumPyClient):
    """Flower NumPy client that trains the CNN locally with DP-SGD.

    Weights travel to and from the server as a list of NumPy arrays in the
    model's ``state_dict`` key order.
    """

    def __init__(self):
        self.device = "cpu"
        self.model = CNN().to(self.device)
        # NOTE(review): every client loads the full MNIST splits, so clients
        # do not hold disjoint data partitions — confirm this is intended.
        self.train_loader, self.test_loader = load_data()

    def _set_parameters(self, parameters):
        """Load a list of NumPy arrays (``state_dict`` key order) into the model."""
        keys = self.model.state_dict().keys()
        # load_state_dict only accepts tensors, so convert each NumPy array.
        state_dict = {key: torch.tensor(value) for key, value in zip(keys, parameters)}
        self.model.load_state_dict(state_dict, strict=True)

    def get_parameters(self, config):
        """Return the current model weights as a list of NumPy arrays."""
        return [val.cpu().numpy() for val in self.model.state_dict().values()]

    def fit(self, parameters, config):
        """Load the global weights, train locally, and return the new weights.

        Returns:
            ``(weights, num_train_examples, {})``.
        """
        self._set_parameters(parameters)
        updated_params = train(self.model, self.train_loader, device=self.device)
        return (
            [val.cpu().numpy() for val in updated_params.values()],
            len(self.train_loader.dataset),
            {},
        )

    def evaluate(self, parameters, config):
        """Evaluate the received global weights on the local test split.

        Returns:
            ``(average_loss, num_test_examples, {"accuracy": accuracy})``.
        """
        self._set_parameters(parameters)
        self.model.eval()
        criterion = nn.CrossEntropyLoss(reduction="sum")
        total_loss = 0.0
        correct = 0
        with torch.no_grad():
            for data, target in self.test_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                total_loss += criterion(output, target).item()
                correct += (output.argmax(dim=1) == target).sum().item()
        num_examples = len(self.test_loader.dataset)
        # Avoid dividing by zero on an empty test set.
        denominator = max(num_examples, 1)
        return total_loss / denominator, num_examples, {"accuracy": correct / denominator}

# Connect one client to the Flower server. The call blocks until the
# federated run ends, so the client(s) and the server (next cell) must run
# in separate processes, with at least two clients to satisfy the server's
# min_available_clients=2.
# NOTE(review): recent Flower releases deprecate start_numpy_client in favour
# of fl.client.start_client(..., client=FLClient().to_client()) — confirm the
# installed flwr version before switching.
if __name__ == "__main__":
    fl.client.start_numpy_client(server_address="localhost:8080", client=FLClient())

In [None]:
import flwr as fl

def fit_config(rnd):
    """Build the config dict Flower sends to clients with each fit request.

    Clients receive the current round number under the ``"rnd"`` key.
    """
    return {"rnd": rnd}


# FedAvg that waits for two clients and uses all of them for fitting and
# for evaluation in every round.
strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,
    min_fit_clients=2,
    min_available_clients=2,
    min_evaluate_clients=2,
    on_fit_config_fn=fit_config,
)

if __name__ == "__main__":
    server_config = fl.server.ServerConfig(num_rounds=3)
    fl.server.start_server(
        server_address="localhost:8080",
        config=server_config,
        strategy=strategy,
    )