PyTorch Lightning example #617

Merged · 9 commits · Sep 23, 2021
4 changes: 3 additions & 1 deletion README.md
@@ -34,6 +34,7 @@ Meet the Flower community on [flower.dev](https://flower.dev)!
* [Installation](https://flower.dev/docs/installation.html)
* [Quickstart (TensorFlow)](https://flower.dev/docs/quickstart_tensorflow.html)
* [Quickstart (PyTorch)](https://flower.dev/docs/quickstart_pytorch.html)
* [Quickstart (PyTorch Lightning [code example])](https://github.com/adap/flower/tree/main/examples/quickstart_pytorch_lightning)
* [Quickstart (MXNet)](https://flower.dev/docs/example-mxnet-walk-through.html)
* [Quickstart (scikit-learn [code example])](https://github.com/adap/flower/tree/main/examples/sklearn-logreg-mnist)

@@ -49,7 +50,9 @@ Quickstart examples:

* [Quickstart (TensorFlow)](https://github.com/adap/flower/tree/main/examples/quickstart_tensorflow)
* [Quickstart (PyTorch)](https://github.com/adap/flower/tree/main/examples/quickstart_pytorch)
* [Quickstart (PyTorch Lightning)](https://github.com/adap/flower/tree/main/examples/quickstart_pytorch_lightning)
* [Quickstart (MXNet)](https://github.com/adap/flower/tree/main/examples/quickstart_mxnet)
* [Quickstart (scikit-learn)](https://github.com/adap/flower/tree/main/examples/sklearn-logreg-mnist)

Other [examples](https://github.com/adap/flower/tree/main/examples):

@@ -58,7 +61,6 @@ Other [examples](https://github.com/adap/flower/tree/main/examples):
* [MXNet: From Centralized to Federated](https://github.com/adap/flower/tree/main/examples/mxnet_from_centralized_to_federated)
* [Advanced Flower with TensorFlow/Keras](https://github.com/adap/flower/tree/main/examples/advanced_tensorflow)
* [Single-Machine Simulation of Federated Learning Systems](https://github.com/adap/flower/tree/main/examples/simulation)
* [Federated learning example with scikit-learn](https://github.com/adap/flower/tree/main/examples/sklearn-logreg-mnist)

## Flower Baselines / Datasets

2 changes: 2 additions & 0 deletions examples/quickstart_pytorch_lightning/.gitignore
@@ -0,0 +1,2 @@
lightning_logs
MNIST
70 changes: 70 additions & 0 deletions examples/quickstart_pytorch_lightning/client.py
@@ -0,0 +1,70 @@
import flwr as fl
import mnist
import pytorch_lightning as pl
from collections import OrderedDict

import torch

class FlowerClient(fl.client.NumPyClient):
    def __init__(self, model, train_loader, val_loader, test_loader):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader

    def get_parameters(self):
        # Return the model weights as a flat list of NumPy arrays:
        # encoder parameters first, then decoder parameters
        encoder_params = _get_parameters(self.model.encoder)
        decoder_params = _get_parameters(self.model.decoder)
        return encoder_params + decoder_params

    def set_parameters(self, parameters):
        # The encoder has two Linear layers, i.e. four arrays (a weight and
        # a bias each); the remaining arrays belong to the decoder
        _set_parameters(self.model.encoder, parameters[:4])
        _set_parameters(self.model.decoder, parameters[4:])

    def fit(self, parameters, config):
        self.set_parameters(parameters)

        # Train for one local epoch per federated round
        trainer = pl.Trainer(max_epochs=1, progress_bar_refresh_rate=0)
        trainer.fit(self.model, self.train_loader, self.val_loader)

        # 55,000 is the size of the local training split
        return self.get_parameters(), 55000, {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)

        trainer = pl.Trainer(progress_bar_refresh_rate=0)
        results = trainer.test(self.model, self.test_loader)
        loss = results[0]["test_loss"]

        # The MNIST test set holds 10,000 examples
        return loss, 10000, {"loss": loss}


def _get_parameters(model):
    return [val.cpu().numpy() for _, val in model.state_dict().items()]


def _set_parameters(model, parameters):
    params_dict = zip(model.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    model.load_state_dict(state_dict, strict=True)


def main() -> None:
    # Model and data
    model = mnist.LitAutoEncoder()
    train_loader, val_loader, test_loader = mnist.load_data()

    # Flower client
    client = FlowerClient(model, train_loader, val_loader, test_loader)
    fl.client.start_numpy_client("[::]:8080", client)


if __name__ == "__main__":
    main()
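Note: with `fl.client.NumPyClient`, `fit` returns `(parameters, num_examples, metrics)` and `evaluate` returns `(loss, num_examples, metrics)`; the example counts are what the server-side strategy uses to weight each client's contribution (see the aggregation sketch after `server.py` below).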
95 changes: 95 additions & 0 deletions examples/quickstart_pytorch_lightning/mnist.py
@@ -0,0 +1,95 @@
"""Adapted from the PyTorch Lightning quickstart example.

Source: https://pytorchlightning.ai/ (2021/02/04)
"""


import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
import pytorch_lightning as pl


class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Compress a flattened 28x28 MNIST image into a 3-dimensional latent code
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 64),
            nn.ReLU(),
            nn.Linear(64, 3),
        )
        # Reconstruct the image from the latent code
        self.decoder = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 28 * 28),
        )

    def forward(self, x):
        embedding = self.encoder(x)
        return embedding

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        self._evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self._evaluate(batch, "test")

    def _evaluate(self, batch, stage=None):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)


def load_data():
    # Training / validation set: hold out 5,000 of the 60,000 images for validation
    trainset = MNIST("", train=True, download=True, transform=transforms.ToTensor())
    mnist_train, mnist_val = random_split(trainset, [55000, 5000])
    train_loader = DataLoader(mnist_train, batch_size=32, shuffle=True, num_workers=16)
    val_loader = DataLoader(mnist_val, batch_size=32, shuffle=False, num_workers=16)

    # Test set
    testset = MNIST("", train=False, download=True, transform=transforms.ToTensor())
    test_loader = DataLoader(testset, batch_size=32, shuffle=False, num_workers=16)

    return train_loader, val_loader, test_loader


def main() -> None:
    """Centralized training."""

    # Load data
    train_loader, val_loader, test_loader = load_data()

    # Load model
    model = LitAutoEncoder()

    # Train
    trainer = pl.Trainer(max_epochs=5, progress_bar_refresh_rate=0)
    trainer.fit(model, train_loader, val_loader)

    # Test
    trainer.test(model, test_loader)


if __name__ == "__main__":
    main()
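Since `mnist.py` keeps its own `main()`, the model can be sanity-checked without any federation by running `python mnist.py`, which trains the autoencoder centrally for five epochs and evaluates it on the test set.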
17 changes: 17 additions & 0 deletions examples/quickstart_pytorch_lightning/pyproject.toml
@@ -0,0 +1,17 @@
[build-system]
requires = [
    "poetry==1.1.4",
]
build-backend = "poetry.masonry.api"

[tool.poetry]
name = "quickstart_pytorch_lightning"
version = "0.1.0"
description = "Federated Learning Quickstart with Flower and PyTorch Lightning"
authors = ["The Flower Authors <enquiries@flower.dev>"]

[tool.poetry.dependencies]
python = "^3.7"
flwr = "^0.16.0" # For development: { path = "../../", develop = true }
pytorch-lightning = "^1.4.7"
torchvision = "^0.10.0"
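A typical workflow, assuming Poetry is installed: `poetry install` in this directory pulls the pinned dependencies, after which the server, clients, or `run.sh` can be started inside the Poetry environment.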
14 changes: 14 additions & 0 deletions examples/quickstart_pytorch_lightning/run.sh
@@ -0,0 +1,14 @@
#!/bin/bash

python server.py &
sleep 3 # Sleep for 3s to give the server enough time to start

# Start two clients (IDs 0 and 1) in the background
for i in $(seq 0 1); do
    echo "Starting client $i"
    python client.py &
done

# This will allow you to use CTRL+C to stop all background processes
trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM
# Wait for all background processes to complete
wait
20 changes: 20 additions & 0 deletions examples/quickstart_pytorch_lightning/server.py
@@ -0,0 +1,20 @@
import flwr as fl


def main() -> None:
    # Define strategy: in each round, sample 50% of the available clients
    # for training and 50% for evaluation
    strategy = fl.server.strategy.FedAvg(
        fraction_fit=0.5,
        fraction_eval=0.5,
    )

    # Start Flower server for ten rounds of federated learning
    fl.server.start_server(
        server_address="[::]:8080",
        config={"num_rounds": 10},
        strategy=strategy,
    )


if __name__ == "__main__":
    main()
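For context, the `num_examples` values returned by each client's `fit` are what FedAvg uses to weight the aggregated update. A minimal sketch of that weighted average (illustrative only, with a hypothetical helper name; the real implementation lives in `fl.server.strategy.FedAvg`):

def fedavg_aggregate(results):
    """Aggregate client updates, where `results` is a list of
    (weights, num_examples) tuples and `weights` is the list of
    NumPy arrays a client returned from fit()."""
    total_examples = sum(num_examples for _, num_examples in results)
    num_arrays = len(results[0][0])
    # Example-weighted average of every parameter array across clients
    return [
        sum(weights[i] * num_examples for weights, num_examples in results)
        / total_examples
        for i in range(num_arrays)
    ]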