New fabric parity tests #16899
Status: Merged

Commits (98):
64ba17e  experimental (awaelchli)
98cd00d  wip (awaelchli)
0974d89  wip (awaelchli)
ba307f6  wip (awaelchli)
ed453d0  wip (awaelchli)
f6273db  fix (awaelchli)
164c994  update (awaelchli)
5fd9f5c  update (awaelchli)
c2ec0d7  update (awaelchli)
52c0f3f  update (awaelchli)
49313f0  update (awaelchli)
c08e16c  update (awaelchli)
d2f6184  update (awaelchli)
0cf71fb  update (awaelchli)
c713106  update (awaelchli)
b04d381  update (awaelchli)
2b47e9c  update (awaelchli)
bdc3055  update (awaelchli)
8747031  update (awaelchli)
14bb8d9  update (awaelchli)
2445026  update (awaelchli)
da23916  update (awaelchli)
0ea1496  update (awaelchli)
ba84ba7  update (awaelchli)
caa7c03  update (awaelchli)
2b57493  update (awaelchli)
c14e2c4  update (awaelchli)
43b17e9  refactor (awaelchli)
436a5e6  debug (awaelchli)
826be20  Revert "debug" (awaelchli)
c9d5f19  Revert "refactor" (awaelchli)
ddbb113  update (awaelchli)
e8b79a5  update (awaelchli)
c1dff21  update (awaelchli)
ad369b8  update (awaelchli)
0b35ef8  update (awaelchli)
68b4888  update (awaelchli)
41e7a25  update (awaelchli)
929d604  update (awaelchli)
71c77ca  update (awaelchli)
727515f  update (awaelchli)
1771443  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
bc37136  delete (awaelchli)
3d8ad31  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
76b1676  update (awaelchli)
130880f  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
11d5099  benchmark (awaelchli)
667b174  Merge remote-tracking branch 'origin/fabric/framework-overhead' into … (awaelchli)
20c6672  update (awaelchli)
c5fa2bc  tuning (awaelchli)
2d85b0d  run on gpu (awaelchli)
72faa64  memory (awaelchli)
0de9ba2  tolerance (awaelchli)
905c5d6  memory (awaelchli)
719088b  refactor (awaelchli)
0bbe2cd  refactor (awaelchli)
6f41053  safer check (awaelchli)
1f6e987  reset peak (awaelchli)
33d7c01  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
8ade03a  Merge branch 'master' into fabric/framework-overhead (awaelchli)
5462c24  empty cache (awaelchli)
d7d3739  Merge remote-tracking branch 'origin/fabric/framework-overhead' into … (awaelchli)
cbf24c1  Update tests/tests_fabric/parity/test_parity_simple.py (awaelchli)
42facb7  Update tests/tests_fabric/parity/test_parity_simple.py (awaelchli)
d6e5227  cuda (awaelchli)
80d5919  Experiment with tracking mode by @carmocca (awaelchli)
a703991  Revert "Experiment with tracking mode by @carmocca" (awaelchli)
af7a7e4  move assertions top (awaelchli)
3162108  reset cuda memory stats before test (awaelchli)
24c9f1f  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
b59f51d  assertions across all devices (awaelchli)
94e2777  Merge remote-tracking branch 'origin/fabric/framework-overhead' into … (awaelchli)
d42ab34  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
726d462  slow cpu (awaelchli)
b7a82b5  add requirement (awaelchli)
1697904  tolerance (awaelchli)
0c75321  bf16 skip windows (awaelchli)
3a8b048  Merge branch 'master' into fabric/framework-overhead (awaelchli)
603ec15  Merge branch 'master' into fabric/framework-overhead (awaelchli)
e5c836a  parity on cpu (awaelchli)
ea67af7  Merge branch 'master' into fabric/framework-overhead (awaelchli)
8201fed  Merge branch 'master' into fabric/framework-overhead (Borda)
a52061f  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
1dd94b6  Merge branch 'master' into fabric/framework-overhead (awaelchli)
5970e8c  update (awaelchli)
c7f865b  Update tests/tests_fabric/parity/test_parity_ddp.py (awaelchli)
1b82ed5  Update tests/tests_fabric/conftest.py (awaelchli)
6221b28  Update tests/tests_fabric/parity/test_parity_ddp.py (awaelchli)
f38c95d  parametrize backend (awaelchli)
c6a45c8  use equality (awaelchli)
62a28de  add barrier (awaelchli)
f2ce109  comment about reusing processes (awaelchli)
f0edeba  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
71342dc  Merge remote-tracking branch 'origin/fabric/framework-overhead' into … (awaelchli)
14bfc37  Merge branch 'master' into fabric/framework-overhead (awaelchli)
0a42768  Merge branch 'master' into fabric/framework-overhead (awaelchli)
4964d75  use the utility to clear cuda cache (awaelchli)
be6eda7  guard (awaelchli)
One new file in the diff is empty (most likely the package `__init__.py` for `tests_fabric/parity`, judging by the imports in the test file below).
tests/tests_fabric/parity/models.py (new file, @@ -0,0 +1,83 @@; the path is inferred from the import `tests_fabric.parity.models` in the test file below):
```python
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.utils.data import DataLoader, TensorDataset


class ParityModel(ABC, nn.Module):
    """Defines the interface for a model in a Fabric-PyTorch parity test."""

    # Benchmarking parameters that should be model-specific
    batch_size = 1
    num_steps = 1

    @abstractmethod
    def get_optimizer(self, *args, **kwargs) -> Optimizer:
        pass

    @abstractmethod
    def get_dataloader(self, *args, **kwargs) -> DataLoader:
        pass

    @abstractmethod
    def get_loss_function(self) -> Callable:
        pass


class ConvNet(ParityModel):
    batch_size = 4
    num_steps = 1000

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def get_optimizer(self):
        return torch.optim.SGD(self.parameters(), lr=0.0001)

    def get_dataloader(self):
        # multiply * 8 just in case world size is larger than 1
        dataset_size = self.num_steps * self.batch_size * 8
        inputs = torch.rand(dataset_size, 3, 32, 32)
        labels = torch.randint(0, 10, (dataset_size,))
        dataset = TensorDataset(inputs, labels)
        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=2,
        )
        return dataloader

    def get_loss_function(self):
        return F.cross_entropy
```
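For orientation, `ParityModel` is the contract both benchmark loops in this PR rely on: a subclass bundles its network, optimizer, dataloader, and loss function so that the Fabric run and the raw-PyTorch run start from identical ingredients. A minimal hypothetical subclass (the `TinyMLP` name, sizes, and hyperparameters are illustrative only and not part of this PR; it reuses the imports from models.py above) could look like this:

```python
# Hypothetical example, not part of this PR: a minimal ParityModel subclass.
class TinyMLP(ParityModel):
    batch_size = 8
    num_steps = 100

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def get_optimizer(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

    def get_dataloader(self):
        # Generate enough random samples to cover all benchmark steps.
        inputs = torch.rand(self.num_steps * self.batch_size, 32)
        labels = torch.randint(0, 2, (self.num_steps * self.batch_size,))
        return DataLoader(TensorDataset(inputs, labels), batch_size=self.batch_size)

    def get_loss_function(self):
        return F.cross_entropy
```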
tests/tests_fabric/parity/test_parity_ddp.py (new file, @@ -0,0 +1,169 @@):
```python
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from copy import deepcopy

import pytest
import torch
import torch.distributed
import torch.nn.functional
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from lightning.fabric.fabric import Fabric
from tests_fabric.helpers.runif import RunIf
from tests_fabric.parity.models import ConvNet
from tests_fabric.parity.utils import (
    cuda_reset,
    is_cuda_memory_close,
    is_state_dict_equal,
    is_timing_close,
    make_deterministic,
)


def train_torch_ddp(
    rank,
    world_size,
    device=torch.device("cpu"),
    backend="nccl",
):
    make_deterministic()
    memory_stats = {}

    os.environ["LOCAL_RANK"] = str(rank)
    torch.distributed.init_process_group(backend, rank=rank, world_size=world_size)

    model = ConvNet().to(device)
    initial_state_dict = deepcopy(model.state_dict())

    ddp_model = DistributedDataParallel(model, device_ids=([rank] if device.type == "cuda" else None))

    dataloader = model.get_dataloader()
    sampler = DistributedSampler(dataloader.dataset, rank=rank, num_replicas=world_size, drop_last=False, shuffle=False)
    dataloader = DataLoader(dataloader.dataset, sampler=sampler, batch_size=model.batch_size)
    optimizer = model.get_optimizer()
    loss_fn = model.get_loss_function()

    memory_stats["start"] = torch.cuda.memory_stats()

    ddp_model.train()
    iteration_timings = []
    iterator = iter(dataloader)
    for _ in range(model.num_steps):
        t0 = time.perf_counter()

        inputs, labels = next(iterator)
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = ddp_model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        t1 = time.perf_counter()
        iteration_timings.append(t1 - t0)

    memory_stats["end"] = torch.cuda.memory_stats()

    # check that the model has changed
    assert not is_state_dict_equal(initial_state_dict, ddp_model.module.state_dict())

    return ddp_model.module.state_dict(), torch.tensor(iteration_timings), memory_stats


def train_fabric_ddp(fabric):
    make_deterministic()
    memory_stats = {}

    model = ConvNet()
    initial_state_dict = deepcopy(model.state_dict())

    optimizer = model.get_optimizer()
    model, optimizer = fabric.setup(model, optimizer)

    dataloader = model.get_dataloader()
    dataloader = fabric.setup_dataloaders(dataloader)
    loss_fn = model.get_loss_function()

    memory_stats["start"] = torch.cuda.memory_stats()

    model.train()
    iteration_timings = []
    iterator = iter(dataloader)
    for _ in range(model.num_steps):
        t0 = time.perf_counter()

        inputs, labels = next(iterator)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        fabric.backward(loss)
        optimizer.step()

        t1 = time.perf_counter()
        iteration_timings.append(t1 - t0)

    memory_stats["end"] = torch.cuda.memory_stats()

    # check that the model has changed
    assert not is_state_dict_equal(initial_state_dict, model.state_dict())

    return model.state_dict(), torch.tensor(iteration_timings), memory_stats


@pytest.mark.flaky(reruns=3)
@RunIf(standalone=True)
@pytest.mark.usefixtures("reset_deterministic_algorithm", "reset_cudnn_benchmark")
@pytest.mark.parametrize(
    "accelerator, devices, tolerance",
    [
        ("cpu", 2, 0.01),
        pytest.param("cuda", 2, 0.005, marks=RunIf(min_cuda_gpus=2)),
    ],
)
def test_parity_ddp(accelerator, devices, tolerance):
    cuda_reset()

    # Launch processes with Fabric and re-use them for the PyTorch training for convenience
    fabric = Fabric(accelerator=accelerator, strategy="ddp", devices=devices)
    fabric.launch()

    # Train with Fabric
    state_dict_fabric, timings_fabric, memory_fabric = train_fabric_ddp(fabric)

    fabric.barrier()
    cuda_reset()
    torch.distributed.destroy_process_group()

    # Train with raw PyTorch
    state_dict_torch, timings_torch, memory_torch = train_torch_ddp(
        rank=fabric.global_rank,
        world_size=fabric.world_size,
        device=fabric.device,
        backend=fabric.strategy._process_group_backend,
    )

    # Compare the final weights
    assert all(fabric.all_gather(is_state_dict_equal(state_dict_torch, state_dict_fabric)))

    # Compare the time per iteration
    assert all(fabric.all_gather(is_timing_close(timings_torch, timings_fabric, rtol=tolerance, atol=tolerance)))

    # Compare memory usage
    if accelerator == "cuda":
        assert all(fabric.all_gather(is_cuda_memory_close(memory_torch["start"], memory_fabric["start"])))
        assert all(fabric.all_gather(is_cuda_memory_close(memory_torch["end"], memory_fabric["end"])))
```
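The helpers imported from `tests_fabric.parity.utils` (`make_deterministic`, `is_state_dict_equal`, `is_timing_close`, `is_cuda_memory_close`, `cuda_reset`) are not shown in this excerpt. Purely as a hedged sketch of what such utilities could look like, assuming they compare exact weights, median step times, and CUDA allocator counters (the actual implementations in the PR may differ):

```python
# Hedged sketch only; the real tests_fabric/parity/utils.py in this PR may differ.
import torch


def make_deterministic(seed=1):
    # Fixed seed plus deterministic kernels, so both runs see identical math.
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)


def is_state_dict_equal(dict1, dict2):
    # Same keys and every tensor exactly equal (assumes tensors on the same device).
    return dict1.keys() == dict2.keys() and all(torch.equal(dict1[k], dict2[k]) for k in dict1)


def is_timing_close(timings_torch, timings_fabric, rtol=1e-3, atol=1e-3):
    # Compare median iteration time; the median is robust to warmup outliers.
    return torch.isclose(torch.median(timings_torch), torch.median(timings_fabric), rtol=rtol, atol=atol)


def is_cuda_memory_close(memory_stats_torch, memory_stats_fabric):
    # torch.cuda.memory_stats() returns a dict of counters; compare allocated bytes.
    return memory_stats_torch.get("allocated_bytes.all.current", 0) == memory_stats_fabric.get(
        "allocated_bytes.all.current", 0
    )


def cuda_reset():
    # Return the allocator to a clean state between the two training runs.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
```

Note how the test wraps each comparison in `fabric.all_gather(...)` and asserts `all(...)`: a parity violation on any single rank fails the test on every rank, rather than only on the rank that observed it.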