PyTorch Lightning example (#617)

danieljanes committed Sep 23, 2021
1 parent 7db5eb3 commit 1889037

Showing 7 changed files with 221 additions and 1 deletion.
4 changes: 3 additions & 1 deletion README.md
@@ -34,6 +34,7 @@ Meet the Flower community on [flower.dev](https://flower.dev)!
 * [Installation](https://flower.dev/docs/installation.html)
 * [Quickstart (TensorFlow)](https://flower.dev/docs/quickstart_tensorflow.html)
 * [Quickstart (PyTorch)](https://flower.dev/docs/quickstart_pytorch.html)
+* [Quickstart (PyTorch Lightning [code example])](https://github.com/adap/flower/tree/main/examples/quickstart_pytorch_lightning)
 * [Quickstart (MXNet)](https://flower.dev/docs/example-mxnet-walk-through.html)
 * [Quickstart (scikit-learn [code example])](https://github.com/adap/flower/tree/main/examples/sklearn-logreg-mnist)

@@ -49,7 +50,9 @@ Quickstart examples:

 * [Quickstart (TensorFlow)](https://github.com/adap/flower/tree/main/examples/quickstart_tensorflow)
 * [Quickstart (PyTorch)](https://github.com/adap/flower/tree/main/examples/quickstart_pytorch)
+* [Quickstart (PyTorch Lightning)](https://github.com/adap/flower/tree/main/examples/quickstart_pytorch_lightning)
 * [Quickstart (MXNet)](https://github.com/adap/flower/tree/main/examples/quickstart_mxnet)
+* [Quickstart (scikit-learn)](https://github.com/adap/flower/tree/main/examples/sklearn-logreg-mnist)

 Other [examples](https://github.com/adap/flower/tree/main/examples):

@@ -58,7 +61,6 @@ Other [examples](https://github.com/adap/flower/tree/main/examples):
 * [MXNet: From Centralized to Federated](https://github.com/adap/flower/tree/main/examples/mxnet_from_centralized_to_federated)
 * [Advanced Flower with TensorFlow/Keras](https://github.com/adap/flower/tree/main/examples/advanced_tensorflow)
 * [Single-Machine Simulation of Federated Learning Systems](https://github.com/adap/flower/tree/main/examples/simulation)
-* [Federated learning example with scikit-learn](https://github.com/adap/flower/tree/main/examples/sklearn-logreg-mnist)

 ## Flower Baselines / Datasets

2 changes: 2 additions & 0 deletions examples/quickstart_pytorch_lightning/.gitignore
@@ -0,0 +1,2 @@
lightning_logs
MNIST
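
Both entries are run artifacts: lightning_logs is PyTorch Lightning's default output directory for logs and checkpoints, and MNIST is where torchvision stores the downloaded dataset, since the data loaders below pass an empty root path.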
70 changes: 70 additions & 0 deletions examples/quickstart_pytorch_lightning/client.py
@@ -0,0 +1,70 @@
from collections import OrderedDict

import flwr as fl
import pytorch_lightning as pl
import torch

import mnist


class FlowerClient(fl.client.NumPyClient):
def __init__(self, model, train_loader, val_loader, test_loader):
self.model = model
self.train_loader = train_loader
self.val_loader = val_loader
self.test_loader = test_loader

def get_parameters(self):
encoder_params = _get_parameters(self.model.encoder)
decoder_params = _get_parameters(self.model.decoder)
return encoder_params + decoder_params

    def set_parameters(self, parameters):
        # The flat list holds four encoder tensors followed by four decoder
        # tensors (two nn.Linear layers each, weight and bias per layer)
        _set_parameters(self.model.encoder, parameters[:4])
        _set_parameters(self.model.decoder, parameters[4:])

def fit(self, parameters, config):
self.set_parameters(parameters)

trainer = pl.Trainer(max_epochs=1, progress_bar_refresh_rate=0)
trainer.fit(self.model, self.train_loader, self.val_loader)

        return self.get_parameters(), 55000, {}  # 55000 = size of the training split

def evaluate(self, parameters, config):
self.set_parameters(parameters)

trainer = pl.Trainer(progress_bar_refresh_rate=0)
results = trainer.test(self.model, self.test_loader)
loss = results[0]["test_loss"]

        return loss, 10000, {"loss": loss}  # 10000 = size of the MNIST test set


def _get_parameters(model):
return [val.cpu().numpy() for _, val in model.state_dict().items()]


def _set_parameters(model, parameters):
params_dict = zip(model.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
model.load_state_dict(state_dict, strict=True)


def main() -> None:
# Model and data
model = mnist.LitAutoEncoder()
train_loader, val_loader, test_loader = mnist.load_data()

# Flower client
client = FlowerClient(model, train_loader, val_loader, test_loader)
fl.client.start_numpy_client("[::]:8080", client)


if __name__ == "__main__":
main()
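
A quick check of the encoder/decoder split used by get_parameters and set_parameters: every nn.Linear contributes a weight and a bias tensor, so each two-layer nn.Sequential yields four tensors. A minimal sanity check (a sketch, assuming the mnist module above is importable):

import mnist

model = mnist.LitAutoEncoder()
assert len(model.encoder.state_dict()) == 4  # 2 layers x (weight, bias)
assert len(model.decoder.state_dict()) == 4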
95 changes: 95 additions & 0 deletions examples/quickstart_pytorch_lightning/mnist.py
@@ -0,0 +1,95 @@
"""Adapted from the PyTorch Lightning quickstart example.
Source: https://pytorchlightning.ai/ (2021/02/04)
"""


import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
import pytorch_lightning as pl


class LitAutoEncoder(pl.LightningModule):
def __init__(self):
super().__init__()
self.encoder = nn.Sequential(
nn.Linear(28 * 28, 64),
nn.ReLU(),
nn.Linear(64, 3),
)
self.decoder = nn.Sequential(
nn.Linear(3, 64),
nn.ReLU(),
nn.Linear(64, 28 * 28),
)

def forward(self, x):
embedding = self.encoder(x)
return embedding

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer

def training_step(self, train_batch, batch_idx):
x, y = train_batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
loss = F.mse_loss(x_hat, x)
self.log("train_loss", loss)
return loss

def validation_step(self, batch, batch_idx):
self._evaluate(batch, "val")

def test_step(self, batch, batch_idx):
self._evaluate(batch, "test")

def _evaluate(self, batch, stage=None):
x, y = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
loss = F.mse_loss(x_hat, x)
if stage:
self.log(f"{stage}_loss", loss, prog_bar=True)


def load_data():
    # Training / validation set (downloads into ./MNIST; see .gitignore)
    trainset = MNIST("", train=True, download=True, transform=transforms.ToTensor())
mnist_train, mnist_val = random_split(trainset, [55000, 5000])
train_loader = DataLoader(mnist_train, batch_size=32, shuffle=True, num_workers=16)
val_loader = DataLoader(mnist_val, batch_size=32, shuffle=False, num_workers=16)

# Test set
testset = MNIST("", train=False, download=True, transform=transforms.ToTensor())
test_loader = DataLoader(testset, batch_size=32, shuffle=False, num_workers=16)

return train_loader, val_loader, test_loader


def main() -> None:
"""Centralized training."""

# Load data
train_loader, val_loader, test_loader = load_data()

# Load model
model = LitAutoEncoder()

# Train
trainer = pl.Trainer(max_epochs=5, progress_bar_refresh_rate=0)
trainer.fit(model, train_loader, val_loader)

# Test
trainer.test(model, test_loader)


if __name__ == "__main__":
main()
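
The autoencoder flattens each 28x28 image into a 784-dimensional vector and compresses it to a 3-dimensional embedding before reconstruction. A small shape check (illustrative only, assuming this file is importable as mnist):

import torch

from mnist import LitAutoEncoder

model = LitAutoEncoder()
x = torch.rand(8, 28 * 28)  # a batch of eight flattened images
z = model(x)  # forward() returns the embedding
assert z.shape == (8, 3)
assert model.decoder(z).shape == (8, 28 * 28)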
17 changes: 17 additions & 0 deletions examples/quickstart_pytorch_lightning/pyproject.toml
@@ -0,0 +1,17 @@
[build-system]
requires = [
"poetry==1.1.4",
]
build-backend = "poetry.masonry.api"

[tool.poetry]
name = "quickstart_pytorch_lightning"
version = "0.1.0"
description = "Federated Learning Quickstart with Flower and PyTorch Lightning"
authors = ["The Flower Authors <enquiries@flower.dev>"]

[tool.poetry.dependencies]
python = "^3.7"
flwr = "^0.16.0" # For development: { path = "../../", develop = true }
pytorch-lightning = "^1.4.7"
torchvision = "^0.10.0"
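
With these pinned dependencies, the example can be set up with Poetry (assuming Poetry is installed): run poetry install once, then poetry run python server.py in one terminal and poetry run python client.py in two others — or simply use run.sh below.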
14 changes: 14 additions & 0 deletions examples/quickstart_pytorch_lightning/run.sh
@@ -0,0 +1,14 @@
#!/bin/bash

python server.py &
sleep 3 # Sleep for 3s to give the server enough time to start

for i in $(seq 0 1); do
echo "Starting client $i"
python client.py &
done

# This will allow you to use CTRL+C to stop all background processes
trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM
# Wait for all background processes to complete
wait
20 changes: 20 additions & 0 deletions examples/quickstart_pytorch_lightning/server.py
@@ -0,0 +1,20 @@
import flwr as fl


def main() -> None:
# Define strategy
strategy = fl.server.strategy.FedAvg(
fraction_fit=0.5,
fraction_eval=0.5,
)

    # Start Flower server for ten rounds of federated learning
fl.server.start_server(
server_address="[::]:8080",
config={"num_rounds": 10},
strategy=strategy,
)


if __name__ == "__main__":
main()
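
With fraction_fit=0.5 and fraction_eval=0.5, the FedAvg strategy samples half of the connected clients for training and evaluation in each of the 10 rounds, subject to its minimum-client defaults; with the two clients launched by run.sh, both will typically participate in every round.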
