-
Notifications
You must be signed in to change notification settings - Fork 792
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
7db5eb3
commit 1889037
Showing
7 changed files
with
221 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
lightning_logs | ||
MNIST |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import flwr as fl | ||
import mnist | ||
import pytorch_lightning as pl | ||
from collections import OrderedDict | ||
|
||
|
||
import torch | ||
from torch import nn | ||
from torch.nn import functional as F | ||
from torch.utils.data import DataLoader, random_split | ||
from torchvision import transforms | ||
from torchvision.datasets import MNIST | ||
|
||
|
||
class FlowerClient(fl.client.NumPyClient): | ||
def __init__(self, model, train_loader, val_loader, test_loader): | ||
self.model = model | ||
self.train_loader = train_loader | ||
self.val_loader = val_loader | ||
self.test_loader = test_loader | ||
|
||
def get_parameters(self): | ||
encoder_params = _get_parameters(self.model.encoder) | ||
decoder_params = _get_parameters(self.model.decoder) | ||
return encoder_params + decoder_params | ||
|
||
def set_parameters(self, parameters): | ||
_set_parameters(self.model.encoder, parameters[:4]) | ||
_set_parameters(self.model.decoder, parameters[4:]) | ||
|
||
def fit(self, parameters, config): | ||
self.set_parameters(parameters) | ||
|
||
trainer = pl.Trainer(max_epochs=1, progress_bar_refresh_rate=0) | ||
trainer.fit(self.model, self.train_loader, self.val_loader) | ||
|
||
return self.get_parameters(), 55000, {} | ||
|
||
def evaluate(self, parameters, config): | ||
self.set_parameters(parameters) | ||
|
||
trainer = pl.Trainer(progress_bar_refresh_rate=0) | ||
results = trainer.test(self.model, self.test_loader) | ||
loss = results[0]["test_loss"] | ||
|
||
return loss, 10000, {"loss": loss} | ||
|
||
|
||
def _get_parameters(model): | ||
return [val.cpu().numpy() for _, val in model.state_dict().items()] | ||
|
||
|
||
def _set_parameters(model, parameters): | ||
params_dict = zip(model.state_dict().keys(), parameters) | ||
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict}) | ||
model.load_state_dict(state_dict, strict=True) | ||
|
||
|
||
def main() -> None: | ||
# Model and data | ||
model = mnist.LitAutoEncoder() | ||
train_loader, val_loader, test_loader = mnist.load_data() | ||
|
||
# Flower client | ||
client = FlowerClient(model, train_loader, val_loader, test_loader) | ||
fl.client.start_numpy_client("[::]:8080", client) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
"""Adapted from the PyTorch Lightning quickstart example. | ||
Source: https://pytorchlightning.ai/ (2021/02/04) | ||
""" | ||
|
||
|
||
import torch | ||
from torch import nn | ||
from torch.nn import functional as F | ||
from torch.utils.data import DataLoader, random_split | ||
from torchvision import transforms | ||
from torchvision.datasets import MNIST | ||
import pytorch_lightning as pl | ||
|
||
|
||
class LitAutoEncoder(pl.LightningModule): | ||
def __init__(self): | ||
super().__init__() | ||
self.encoder = nn.Sequential( | ||
nn.Linear(28 * 28, 64), | ||
nn.ReLU(), | ||
nn.Linear(64, 3), | ||
) | ||
self.decoder = nn.Sequential( | ||
nn.Linear(3, 64), | ||
nn.ReLU(), | ||
nn.Linear(64, 28 * 28), | ||
) | ||
|
||
def forward(self, x): | ||
embedding = self.encoder(x) | ||
return embedding | ||
|
||
def configure_optimizers(self): | ||
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) | ||
return optimizer | ||
|
||
def training_step(self, train_batch, batch_idx): | ||
x, y = train_batch | ||
x = x.view(x.size(0), -1) | ||
z = self.encoder(x) | ||
x_hat = self.decoder(z) | ||
loss = F.mse_loss(x_hat, x) | ||
self.log("train_loss", loss) | ||
return loss | ||
|
||
def validation_step(self, batch, batch_idx): | ||
self._evaluate(batch, "val") | ||
|
||
def test_step(self, batch, batch_idx): | ||
self._evaluate(batch, "test") | ||
|
||
def _evaluate(self, batch, stage=None): | ||
x, y = batch | ||
x = x.view(x.size(0), -1) | ||
z = self.encoder(x) | ||
x_hat = self.decoder(z) | ||
loss = F.mse_loss(x_hat, x) | ||
if stage: | ||
self.log(f"{stage}_loss", loss, prog_bar=True) | ||
|
||
|
||
def load_data(): | ||
# Training / validation set | ||
trainset = MNIST("", train=True, download=True, transform=transforms.ToTensor()) | ||
mnist_train, mnist_val = random_split(trainset, [55000, 5000]) | ||
train_loader = DataLoader(mnist_train, batch_size=32, shuffle=True, num_workers=16) | ||
val_loader = DataLoader(mnist_val, batch_size=32, shuffle=False, num_workers=16) | ||
|
||
# Test set | ||
testset = MNIST("", train=False, download=True, transform=transforms.ToTensor()) | ||
test_loader = DataLoader(testset, batch_size=32, shuffle=False, num_workers=16) | ||
|
||
return train_loader, val_loader, test_loader | ||
|
||
|
||
def main() -> None: | ||
"""Centralized training.""" | ||
|
||
# Load data | ||
train_loader, val_loader, test_loader = load_data() | ||
|
||
# Load model | ||
model = LitAutoEncoder() | ||
|
||
# Train | ||
trainer = pl.Trainer(max_epochs=5, progress_bar_refresh_rate=0) | ||
trainer.fit(model, train_loader, val_loader) | ||
|
||
# Test | ||
trainer.test(model, test_loader) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
[build-system] | ||
requires = [ | ||
"poetry==1.1.4", | ||
] | ||
build-backend = "poetry.masonry.api" | ||
|
||
[tool.poetry] | ||
name = "quickstart_pytorch_lightning" | ||
version = "0.1.0" | ||
description = "Federated Learning Quickstart with Flower and PyTorch Lightning" | ||
authors = ["The Flower Authors <enquiries@flower.dev>"] | ||
|
||
[tool.poetry.dependencies] | ||
python = "^3.7" | ||
flwr = "^0.16.0" # For development: { path = "../../", develop = true } | ||
pytorch-lightning = "^1.4.7" | ||
torchvision = "^0.10.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
#!/bin/bash | ||
|
||
python server.py & | ||
sleep 3 # Sleep for 3s to give the server enough time to start | ||
|
||
for i in `seq 0 1`; do | ||
echo "Starting client $i" | ||
python client.py & | ||
done | ||
|
||
# This will allow you to use CTRL+C to stop all background processes | ||
trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM | ||
# Wait for all background processes to complete | ||
wait |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import flwr as fl | ||
|
||
|
||
def main() -> None: | ||
# Define strategy | ||
strategy = fl.server.strategy.FedAvg( | ||
fraction_fit=0.5, | ||
fraction_eval=0.5, | ||
) | ||
|
||
# Start Flower server for three rounds of federated learning | ||
fl.server.start_server( | ||
server_address="[::]:8080", | ||
config={"num_rounds": 10}, | ||
strategy=strategy, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |