From 64ba17e89fbce450bc604d6ca06df770af37a70c Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 27 Feb 2023 16:12:41 +0100 Subject: [PATCH 01/86] experimental --- benchmark/train.py | 89 +++++++++++++++++++++++++++++++++++++ benchmark/train_fabric.py | 93 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 182 insertions(+) create mode 100644 benchmark/train.py create mode 100644 benchmark/train_fabric.py diff --git a/benchmark/train.py b/benchmark/train.py new file mode 100644 index 0000000000000..a74352289c7f1 --- /dev/null +++ b/benchmark/train.py @@ -0,0 +1,89 @@ +# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html +import time + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import torchvision +import torchvision.transforms as transforms + + +def main(): + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] + ) + + batch_size = 4 + + trainset = torchvision.datasets.CIFAR10( + root="~/data", train=True, download=True, transform=transform + ) + trainloader = torch.utils.data.DataLoader( + trainset, batch_size=batch_size, shuffle=True, num_workers=2 + ) + + class Net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = torch.flatten(x, 1) # flatten all dimensions except batch + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + + net = Net() + net.to(DEVICE) + + criterion = nn.CrossEntropyLoss() + optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) + + iteration_timings = [] + iterator = iter(trainloader) + while True: + t0 = time.perf_counter() + try: + data = next(iterator) + except StopIteration: + break + + # get the inputs; data is a list of [inputs, labels] + inputs, labels = data + inputs, labels = inputs.to(DEVICE), labels.to(DEVICE) + + # zero the parameter gradients + optimizer.zero_grad() + + # forward + backward + optimize + outputs = net(inputs) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + t1 = time.perf_counter() + iteration_timings.append(t1 - t0) + + """ + median tensor(0.0013) + mean tensor(0.0015) + std tensor(0.0018) + """ + print("median", torch.median(torch.tensor(iteration_timings))) + print("mean", torch.mean(torch.tensor(iteration_timings))) + print("std", torch.std(torch.tensor(iteration_timings))) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/benchmark/train_fabric.py b/benchmark/train_fabric.py new file mode 100644 index 0000000000000..332bc68180f9f --- /dev/null +++ b/benchmark/train_fabric.py @@ -0,0 +1,93 @@ +# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html +import time + +import lightning as L +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import torchvision +import torchvision.transforms as transforms + + +def main(): + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] + ) + + batch_size = 4 + + trainset = torchvision.datasets.CIFAR10( + root="~/data", train=True, download=True, transform=transform + ) + trainloader = torch.utils.data.DataLoader( + trainset, batch_size=batch_size, shuffle=True, num_workers=2 + ) + + class Net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = torch.flatten(x, 1) # flatten all dimensions except batch + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + net = Net() + + criterion = nn.CrossEntropyLoss() + optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) + + fabric = L.Fabric(accelerator="cpu") + + setup_t0 = time.perf_counter() + net, optimizer = fabric.setup(net, optimizer) + trainloader = fabric.setup_dataloaders(trainloader) + setup_t1 = time.perf_counter() + print(f"setup time: {setup_t1-setup_t0} sec") + + iteration_timings = [] + iterator = iter(trainloader) + while True: + t0 = time.perf_counter() + try: + data = next(iterator) + except StopIteration: + break + # get the inputs; data is a list of [inputs, labels] + inputs, labels = data + + # zero the parameter gradients + optimizer.zero_grad() + + # forward + backward + optimize + outputs = net(inputs) + loss = criterion(outputs, labels) + fabric.backward(loss) + optimizer.step() + + t1 = time.perf_counter() + iteration_timings.append(t1 - t0) + + """ + median tensor(0.0014) + mean tensor(0.0020) + std tensor(0.0502) + """ + print("median", torch.median(torch.tensor(iteration_timings))) + print("mean", torch.mean(torch.tensor(iteration_timings))) + print("std", torch.std(torch.tensor(iteration_timings))) + + +if __name__ == "__main__": + main() From 98cd00d61eae75696aa4a264b4a60f00cfe79cba Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 27 Feb 2023 23:56:03 +0100 Subject: [PATCH 02/86] wip --- benchmark/train.py | 89 ------------- benchmark/train_fabric.py | 147 +++++++++++++--------- tests/tests_fabric/parity/train_fabric.py | 120 ++++++++++++++++++ tests/tests_fabric/parity/utils.py | 22 ++++ tests/tests_fabric/test_parity.py | 22 +--- 5 files changed, 232 insertions(+), 168 deletions(-) delete mode 100644 benchmark/train.py create mode 100644 tests/tests_fabric/parity/train_fabric.py create mode 100644 tests/tests_fabric/parity/utils.py diff --git a/benchmark/train.py b/benchmark/train.py deleted file mode 100644 index a74352289c7f1..0000000000000 --- a/benchmark/train.py +++ /dev/null @@ -1,89 +0,0 @@ -# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html -import time - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -import torchvision -import torchvision.transforms as transforms - - -def main(): - transform = transforms.Compose( - [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] - ) - - batch_size = 4 - - trainset = torchvision.datasets.CIFAR10( - root="~/data", train=True, download=True, transform=transform - ) - trainloader = torch.utils.data.DataLoader( - trainset, batch_size=batch_size, shuffle=True, num_workers=2 - ) - - class Net(nn.Module): - def __init__(self): - super().__init__() - self.conv1 = nn.Conv2d(3, 6, 5) - self.pool = nn.MaxPool2d(2, 2) - self.conv2 = nn.Conv2d(6, 16, 5) - self.fc1 = nn.Linear(16 * 5 * 5, 120) - self.fc2 = nn.Linear(120, 84) - self.fc3 = nn.Linear(84, 10) - - def forward(self, x): - x = self.pool(F.relu(self.conv1(x))) - x = self.pool(F.relu(self.conv2(x))) - x = torch.flatten(x, 1) # flatten all dimensions except batch - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - x = self.fc3(x) - return x - - DEVICE = "cuda" if torch.cuda.is_available() else "cpu" - - net = Net() - net.to(DEVICE) - - criterion = nn.CrossEntropyLoss() - optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) - - iteration_timings = [] - iterator = iter(trainloader) - while True: - t0 = time.perf_counter() - try: - data = next(iterator) - except StopIteration: - break - - # get the inputs; data is a list of [inputs, labels] - inputs, labels = data - inputs, labels = inputs.to(DEVICE), labels.to(DEVICE) - - # zero the parameter gradients - optimizer.zero_grad() - - # forward + backward + optimize - outputs = net(inputs) - loss = criterion(outputs, labels) - loss.backward() - optimizer.step() - - t1 = time.perf_counter() - iteration_timings.append(t1 - t0) - - """ - median tensor(0.0013) - mean tensor(0.0015) - std tensor(0.0018) - """ - print("median", torch.median(torch.tensor(iteration_timings))) - print("mean", torch.mean(torch.tensor(iteration_timings))) - print("std", torch.std(torch.tensor(iteration_timings))) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/benchmark/train_fabric.py b/benchmark/train_fabric.py index 332bc68180f9f..984919e2ffd14 100644 --- a/benchmark/train_fabric.py +++ b/benchmark/train_fabric.py @@ -1,4 +1,3 @@ -# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html import time import lightning as L @@ -6,71 +5,94 @@ import torch.nn as nn import torch.nn.functional as F import torch.optim as optim -import torchvision -import torchvision.transforms as transforms - - -def main(): - transform = transforms.Compose( - [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] +from torch.utils.data import TensorDataset, DataLoader + + +class Net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = torch.flatten(x, 1) # flatten all dimensions except batch + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +def get_dataloader(dataset_size=100, batch_size=4): + inputs = torch.rand(dataset_size, 3, 32, 32) + labels = torch.randint(0, 10, (dataset_size, )) + dataset = TensorDataset(inputs, labels) + dataloader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=True, + num_workers=2, ) + return dataloader - batch_size = 4 - trainset = torchvision.datasets.CIFAR10( - root="~/data", train=True, download=True, transform=transform - ) - trainloader = torch.utils.data.DataLoader( - trainset, batch_size=batch_size, shuffle=True, num_workers=2 - ) +def make_deterministic(): + torch.use_deterministic_algorithms(True) + torch.manual_seed(1) + torch.cuda.manual_seed(1) - class Net(nn.Module): - def __init__(self): - super().__init__() - self.conv1 = nn.Conv2d(3, 6, 5) - self.pool = nn.MaxPool2d(2, 2) - self.conv2 = nn.Conv2d(6, 16, 5) - self.fc1 = nn.Linear(16 * 5 * 5, 120) - self.fc2 = nn.Linear(120, 84) - self.fc3 = nn.Linear(84, 10) - - def forward(self, x): - x = self.pool(F.relu(self.conv1(x))) - x = self.pool(F.relu(self.conv2(x))) - x = torch.flatten(x, 1) # flatten all dimensions except batch - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - x = self.fc3(x) - return x - - net = Net() +def train_torch(steps=100, batch_size=4): + make_deterministic() + device = "cuda" if torch.cuda.is_available() else "cpu" + net = Net().to(device) + dataloader = get_dataloader(dataset_size=(steps * batch_size), batch_size=batch_size) criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) + iteration_timings = [] + iterator = iter(dataloader) + for _ in range(steps): + t0 = time.perf_counter() + + inputs, labels = next(iterator) + inputs, labels = inputs.to(device), labels.to(device) + optimizer.zero_grad() + outputs = net(inputs) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + t1 = time.perf_counter() + iteration_timings.append(t1 - t0) + + return dict(iteration_timings=torch.tensor(iteration_timings)) + + +def train_fabric(steps=100, batch_size=4): + make_deterministic() fabric = L.Fabric(accelerator="cpu") - setup_t0 = time.perf_counter() + net = Net() + dataloader = get_dataloader(dataset_size=(steps * batch_size), batch_size=batch_size) + criterion = nn.CrossEntropyLoss() + optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) + net, optimizer = fabric.setup(net, optimizer) - trainloader = fabric.setup_dataloaders(trainloader) - setup_t1 = time.perf_counter() - print(f"setup time: {setup_t1-setup_t0} sec") + dataloader = fabric.setup_dataloaders(dataloader) iteration_timings = [] - iterator = iter(trainloader) - while True: + iterator = iter(dataloader) + for _ in range(steps): t0 = time.perf_counter() - try: - data = next(iterator) - except StopIteration: - break - # get the inputs; data is a list of [inputs, labels] - inputs, labels = data - - # zero the parameter gradients - optimizer.zero_grad() - # forward + backward + optimize + inputs, labels = next(iterator) + optimizer.zero_grad() outputs = net(inputs) loss = criterion(outputs, labels) fabric.backward(loss) @@ -79,15 +101,20 @@ def forward(self, x): t1 = time.perf_counter() iteration_timings.append(t1 - t0) - """ - median tensor(0.0014) - mean tensor(0.0020) - std tensor(0.0502) - """ - print("median", torch.median(torch.tensor(iteration_timings))) - print("mean", torch.mean(torch.tensor(iteration_timings))) - print("std", torch.std(torch.tensor(iteration_timings))) + return dict(iteration_timings=torch.tensor(iteration_timings)) + + +def compare(): + outputs_torch = train_torch(steps=2000) + outputs_fabric = train_fabric(steps=2000) + + # 3.5009579733014107e-06 + # 3.5009579733014107e-06 + median = torch.median(outputs_fabric["iteration_timings"]) - torch.median(outputs_torch["iteration_timings"]) + mean = torch.mean(outputs_fabric["iteration_timings"]) - torch.mean(outputs_torch["iteration_timings"]) + print("median", median.abs().item()) + print("mean", mean.abs().item()) if __name__ == "__main__": - main() + compare() diff --git a/tests/tests_fabric/parity/train_fabric.py b/tests/tests_fabric/parity/train_fabric.py new file mode 100644 index 0000000000000..984919e2ffd14 --- /dev/null +++ b/tests/tests_fabric/parity/train_fabric.py @@ -0,0 +1,120 @@ +import time + +import lightning as L +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import TensorDataset, DataLoader + + +class Net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = torch.flatten(x, 1) # flatten all dimensions except batch + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +def get_dataloader(dataset_size=100, batch_size=4): + inputs = torch.rand(dataset_size, 3, 32, 32) + labels = torch.randint(0, 10, (dataset_size, )) + dataset = TensorDataset(inputs, labels) + dataloader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=True, + num_workers=2, + ) + return dataloader + + +def make_deterministic(): + torch.use_deterministic_algorithms(True) + torch.manual_seed(1) + torch.cuda.manual_seed(1) + + +def train_torch(steps=100, batch_size=4): + make_deterministic() + device = "cuda" if torch.cuda.is_available() else "cpu" + net = Net().to(device) + dataloader = get_dataloader(dataset_size=(steps * batch_size), batch_size=batch_size) + criterion = nn.CrossEntropyLoss() + optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) + + iteration_timings = [] + iterator = iter(dataloader) + for _ in range(steps): + t0 = time.perf_counter() + + inputs, labels = next(iterator) + inputs, labels = inputs.to(device), labels.to(device) + optimizer.zero_grad() + outputs = net(inputs) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + t1 = time.perf_counter() + iteration_timings.append(t1 - t0) + + return dict(iteration_timings=torch.tensor(iteration_timings)) + + +def train_fabric(steps=100, batch_size=4): + make_deterministic() + fabric = L.Fabric(accelerator="cpu") + + net = Net() + dataloader = get_dataloader(dataset_size=(steps * batch_size), batch_size=batch_size) + criterion = nn.CrossEntropyLoss() + optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) + + net, optimizer = fabric.setup(net, optimizer) + dataloader = fabric.setup_dataloaders(dataloader) + + iteration_timings = [] + iterator = iter(dataloader) + for _ in range(steps): + t0 = time.perf_counter() + + inputs, labels = next(iterator) + optimizer.zero_grad() + outputs = net(inputs) + loss = criterion(outputs, labels) + fabric.backward(loss) + optimizer.step() + + t1 = time.perf_counter() + iteration_timings.append(t1 - t0) + + return dict(iteration_timings=torch.tensor(iteration_timings)) + + +def compare(): + outputs_torch = train_torch(steps=2000) + outputs_fabric = train_fabric(steps=2000) + + # 3.5009579733014107e-06 + # 3.5009579733014107e-06 + median = torch.median(outputs_fabric["iteration_timings"]) - torch.median(outputs_torch["iteration_timings"]) + mean = torch.mean(outputs_fabric["iteration_timings"]) - torch.mean(outputs_torch["iteration_timings"]) + print("median", median.abs().item()) + print("mean", mean.abs().item()) + + +if __name__ == "__main__": + compare() diff --git a/tests/tests_fabric/parity/utils.py b/tests/tests_fabric/parity/utils.py new file mode 100644 index 0000000000000..b37f43afb90cf --- /dev/null +++ b/tests/tests_fabric/parity/utils.py @@ -0,0 +1,22 @@ +from contextlib import contextmanager +from typing import Generator + +import torch +from torch import nn + + +def configure_optimizers(module: nn.Module): + return torch.optim.SGD(module.parameters(), lr=0.0001) + + +@contextmanager +def precision_context(precision, accelerator) -> Generator[None, None, None]: + if precision == 32: + yield + return + if accelerator == "gpu": + with torch.cuda.amp.autocast(): + yield + elif accelerator == "cpu": + with torch.cpu.amp.autocast(): + yield diff --git a/tests/tests_fabric/test_parity.py b/tests/tests_fabric/test_parity.py index a31419d6a0c2e..69b50e5dc840d 100644 --- a/tests/tests_fabric/test_parity.py +++ b/tests/tests_fabric/test_parity.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from contextlib import contextmanager from copy import deepcopy from functools import partial -from typing import Callable, Generator +from typing import Callable import pytest import torch @@ -36,6 +35,8 @@ from lightning.fabric.utilities.apply_func import move_data_to_device from lightning.fabric.utilities.cloud_io import _atomic_save +from tests_fabric.parity.utils import configure_optimizers, precision_context + class BoringModel(nn.Module): def __init__(self): @@ -47,10 +48,6 @@ def forward(self, x): return torch.nn.functional.mse_loss(x, torch.ones_like(x)) -def configure_optimizers(module: nn.Module): - return torch.optim.SGD(module.parameters(), lr=0.0001) - - def main( move_to_device: Callable, model: nn.Module, @@ -93,19 +90,6 @@ def run(self, model: nn.Module, train_dataloader: DataLoader, num_epochs: int = return checkpoint_path -@contextmanager -def precision_context(precision, accelerator) -> Generator[None, None, None]: - if precision == 32: - yield - return - if accelerator == "gpu": - with torch.cuda.amp.autocast(): - yield - elif accelerator == "cpu": - with torch.cpu.amp.autocast(): - yield - - @pytest.mark.parametrize( "precision, accelerator", [ From 0974d8933f3e41b1852d0fc68b6d197e3cb2dd04 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 28 Feb 2023 00:18:43 +0100 Subject: [PATCH 03/86] wip --- tests/tests_fabric/parity/__init__.py | 0 tests/tests_fabric/parity/models.py | 75 ++++++++++++++ tests/tests_fabric/parity/test_parity.py | 93 +++++++++++++++++ tests/tests_fabric/parity/train_fabric.py | 120 ---------------------- tests/tests_fabric/parity/utils.py | 23 ++++- 5 files changed, 189 insertions(+), 122 deletions(-) create mode 100644 tests/tests_fabric/parity/__init__.py create mode 100644 tests/tests_fabric/parity/models.py create mode 100644 tests/tests_fabric/parity/test_parity.py delete mode 100644 tests/tests_fabric/parity/train_fabric.py diff --git a/tests/tests_fabric/parity/__init__.py b/tests/tests_fabric/parity/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/tests_fabric/parity/models.py b/tests/tests_fabric/parity/models.py new file mode 100644 index 0000000000000..d15b4aa0a7b7e --- /dev/null +++ b/tests/tests_fabric/parity/models.py @@ -0,0 +1,75 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod +from typing import Callable + +import torch.nn as nn +from torch.optim import Optimizer +import torch +import torch.nn.functional as F +from torch.utils.data import TensorDataset, DataLoader + + +class ParityModel(ABC, nn.Module): + """Defines the interface for a model in a Fabric-PyTorch parity test.""" + + @abstractmethod + def get_optimizer(self, *args, **kwargs) -> Optimizer: + pass + + @abstractmethod + def get_dataloader(self, *args, **kwargs) -> DataLoader: + pass + + @abstractmethod + def get_loss_function(self) -> Callable: + pass + + +class ConvNet(ParityModel): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = torch.flatten(x, 1) # flatten all dimensions except batch + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + def get_optimizer(self): + return torch.optim.SGD(self.parameters(), lr=0.0001) + + def get_dataloader(self, dataset_size=100, batch_size=4): + inputs = torch.rand(dataset_size, 3, 32, 32) + labels = torch.randint(0, 10, (dataset_size, )) + dataset = TensorDataset(inputs, labels) + dataloader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=True, + num_workers=2, + ) + return dataloader + + def get_loss_function(self): + return F.cross_entropy diff --git a/tests/tests_fabric/parity/test_parity.py b/tests/tests_fabric/parity/test_parity.py new file mode 100644 index 0000000000000..350473b2340a4 --- /dev/null +++ b/tests/tests_fabric/parity/test_parity.py @@ -0,0 +1,93 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time + +import lightning as L +import torch +import torch.nn as nn +from tests_fabric.parity.utils import make_deterministic +from tests_fabric.parity.models import ConvNet + + +def train_torch(steps=100, batch_size=4): + make_deterministic() + device = "cuda" if torch.cuda.is_available() else "cpu" + model = ConvNet().to(device) + dataloader = model.get_dataloader(dataset_size=(steps * batch_size), batch_size=batch_size) + loss_fn = model.get_loss_function() + optimizer = model.get_optimizer() + + iteration_timings = [] + iterator = iter(dataloader) + for _ in range(steps): + t0 = time.perf_counter() + + inputs, labels = next(iterator) + inputs, labels = inputs.to(device), labels.to(device) + optimizer.zero_grad() + outputs = model(inputs) + loss = loss_fn(outputs, labels) + loss.backward() + optimizer.step() + + t1 = time.perf_counter() + iteration_timings.append(t1 - t0) + + return dict(iteration_timings=torch.tensor(iteration_timings)) + + +def train_fabric(steps=100, batch_size=4): + make_deterministic() + fabric = L.Fabric(accelerator="cpu") + + model = ConvNet() + dataloader = model.get_dataloader(dataset_size=(steps * batch_size), batch_size=batch_size) + loss_fn = model.get_loss_function() + optimizer = model.get_optimizer() + + model, optimizer = fabric.setup(model, optimizer) + dataloader = fabric.setup_dataloaders(dataloader) + + iteration_timings = [] + iterator = iter(dataloader) + for _ in range(steps): + t0 = time.perf_counter() + + inputs, labels = next(iterator) + optimizer.zero_grad() + outputs = model(inputs) + loss = loss_fn(outputs, labels) + fabric.backward(loss) + optimizer.step() + + t1 = time.perf_counter() + iteration_timings.append(t1 - t0) + + return dict(iteration_timings=torch.tensor(iteration_timings)) + + +def test_compare(): + outputs_torch = train_torch(steps=2000) + outputs_fabric = train_fabric(steps=2000) + + # 3.5009579733014107e-06 + # 3.5009579733014107e-06 + median = torch.median(outputs_fabric["iteration_timings"]) - torch.median(outputs_torch["iteration_timings"]) + mean = torch.mean(outputs_fabric["iteration_timings"]) - torch.mean(outputs_torch["iteration_timings"]) + print("median", median.abs().item()) + print("mean", mean.abs().item()) + + +if __name__ == "__main__": + compare() diff --git a/tests/tests_fabric/parity/train_fabric.py b/tests/tests_fabric/parity/train_fabric.py deleted file mode 100644 index 984919e2ffd14..0000000000000 --- a/tests/tests_fabric/parity/train_fabric.py +++ /dev/null @@ -1,120 +0,0 @@ -import time - -import lightning as L -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -from torch.utils.data import TensorDataset, DataLoader - - -class Net(nn.Module): - def __init__(self): - super().__init__() - self.conv1 = nn.Conv2d(3, 6, 5) - self.pool = nn.MaxPool2d(2, 2) - self.conv2 = nn.Conv2d(6, 16, 5) - self.fc1 = nn.Linear(16 * 5 * 5, 120) - self.fc2 = nn.Linear(120, 84) - self.fc3 = nn.Linear(84, 10) - - def forward(self, x): - x = self.pool(F.relu(self.conv1(x))) - x = self.pool(F.relu(self.conv2(x))) - x = torch.flatten(x, 1) # flatten all dimensions except batch - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - x = self.fc3(x) - return x - - -def get_dataloader(dataset_size=100, batch_size=4): - inputs = torch.rand(dataset_size, 3, 32, 32) - labels = torch.randint(0, 10, (dataset_size, )) - dataset = TensorDataset(inputs, labels) - dataloader = DataLoader( - dataset, - batch_size=batch_size, - shuffle=True, - num_workers=2, - ) - return dataloader - - -def make_deterministic(): - torch.use_deterministic_algorithms(True) - torch.manual_seed(1) - torch.cuda.manual_seed(1) - - -def train_torch(steps=100, batch_size=4): - make_deterministic() - device = "cuda" if torch.cuda.is_available() else "cpu" - net = Net().to(device) - dataloader = get_dataloader(dataset_size=(steps * batch_size), batch_size=batch_size) - criterion = nn.CrossEntropyLoss() - optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) - - iteration_timings = [] - iterator = iter(dataloader) - for _ in range(steps): - t0 = time.perf_counter() - - inputs, labels = next(iterator) - inputs, labels = inputs.to(device), labels.to(device) - optimizer.zero_grad() - outputs = net(inputs) - loss = criterion(outputs, labels) - loss.backward() - optimizer.step() - - t1 = time.perf_counter() - iteration_timings.append(t1 - t0) - - return dict(iteration_timings=torch.tensor(iteration_timings)) - - -def train_fabric(steps=100, batch_size=4): - make_deterministic() - fabric = L.Fabric(accelerator="cpu") - - net = Net() - dataloader = get_dataloader(dataset_size=(steps * batch_size), batch_size=batch_size) - criterion = nn.CrossEntropyLoss() - optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) - - net, optimizer = fabric.setup(net, optimizer) - dataloader = fabric.setup_dataloaders(dataloader) - - iteration_timings = [] - iterator = iter(dataloader) - for _ in range(steps): - t0 = time.perf_counter() - - inputs, labels = next(iterator) - optimizer.zero_grad() - outputs = net(inputs) - loss = criterion(outputs, labels) - fabric.backward(loss) - optimizer.step() - - t1 = time.perf_counter() - iteration_timings.append(t1 - t0) - - return dict(iteration_timings=torch.tensor(iteration_timings)) - - -def compare(): - outputs_torch = train_torch(steps=2000) - outputs_fabric = train_fabric(steps=2000) - - # 3.5009579733014107e-06 - # 3.5009579733014107e-06 - median = torch.median(outputs_fabric["iteration_timings"]) - torch.median(outputs_torch["iteration_timings"]) - mean = torch.mean(outputs_fabric["iteration_timings"]) - torch.mean(outputs_torch["iteration_timings"]) - print("median", median.abs().item()) - print("mean", mean.abs().item()) - - -if __name__ == "__main__": - compare() diff --git a/tests/tests_fabric/parity/utils.py b/tests/tests_fabric/parity/utils.py index b37f43afb90cf..cb1e34524bf9c 100644 --- a/tests/tests_fabric/parity/utils.py +++ b/tests/tests_fabric/parity/utils.py @@ -1,3 +1,16 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from contextlib import contextmanager from typing import Generator @@ -5,8 +18,14 @@ from torch import nn -def configure_optimizers(module: nn.Module): - return torch.optim.SGD(module.parameters(), lr=0.0001) +def make_deterministic(): + torch.use_deterministic_algorithms(True) + torch.manual_seed(1) + torch.cuda.manual_seed(1) + + +# def configure_optimizers(module: nn.Module): +# return torch.optim.SGD(module.parameters(), lr=0.0001) @contextmanager From ba307f67575216217f7eecbf321725166d9d07ad Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 28 Feb 2023 00:35:41 +0100 Subject: [PATCH 04/86] wip --- benchmark/train_fabric.py | 120 ------------------ .../{test_parity.py => test_timings.py} | 35 ++--- 2 files changed, 13 insertions(+), 142 deletions(-) delete mode 100644 benchmark/train_fabric.py rename tests/tests_fabric/parity/{test_parity.py => test_timings.py} (66%) diff --git a/benchmark/train_fabric.py b/benchmark/train_fabric.py deleted file mode 100644 index 984919e2ffd14..0000000000000 --- a/benchmark/train_fabric.py +++ /dev/null @@ -1,120 +0,0 @@ -import time - -import lightning as L -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -from torch.utils.data import TensorDataset, DataLoader - - -class Net(nn.Module): - def __init__(self): - super().__init__() - self.conv1 = nn.Conv2d(3, 6, 5) - self.pool = nn.MaxPool2d(2, 2) - self.conv2 = nn.Conv2d(6, 16, 5) - self.fc1 = nn.Linear(16 * 5 * 5, 120) - self.fc2 = nn.Linear(120, 84) - self.fc3 = nn.Linear(84, 10) - - def forward(self, x): - x = self.pool(F.relu(self.conv1(x))) - x = self.pool(F.relu(self.conv2(x))) - x = torch.flatten(x, 1) # flatten all dimensions except batch - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - x = self.fc3(x) - return x - - -def get_dataloader(dataset_size=100, batch_size=4): - inputs = torch.rand(dataset_size, 3, 32, 32) - labels = torch.randint(0, 10, (dataset_size, )) - dataset = TensorDataset(inputs, labels) - dataloader = DataLoader( - dataset, - batch_size=batch_size, - shuffle=True, - num_workers=2, - ) - return dataloader - - -def make_deterministic(): - torch.use_deterministic_algorithms(True) - torch.manual_seed(1) - torch.cuda.manual_seed(1) - - -def train_torch(steps=100, batch_size=4): - make_deterministic() - device = "cuda" if torch.cuda.is_available() else "cpu" - net = Net().to(device) - dataloader = get_dataloader(dataset_size=(steps * batch_size), batch_size=batch_size) - criterion = nn.CrossEntropyLoss() - optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) - - iteration_timings = [] - iterator = iter(dataloader) - for _ in range(steps): - t0 = time.perf_counter() - - inputs, labels = next(iterator) - inputs, labels = inputs.to(device), labels.to(device) - optimizer.zero_grad() - outputs = net(inputs) - loss = criterion(outputs, labels) - loss.backward() - optimizer.step() - - t1 = time.perf_counter() - iteration_timings.append(t1 - t0) - - return dict(iteration_timings=torch.tensor(iteration_timings)) - - -def train_fabric(steps=100, batch_size=4): - make_deterministic() - fabric = L.Fabric(accelerator="cpu") - - net = Net() - dataloader = get_dataloader(dataset_size=(steps * batch_size), batch_size=batch_size) - criterion = nn.CrossEntropyLoss() - optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) - - net, optimizer = fabric.setup(net, optimizer) - dataloader = fabric.setup_dataloaders(dataloader) - - iteration_timings = [] - iterator = iter(dataloader) - for _ in range(steps): - t0 = time.perf_counter() - - inputs, labels = next(iterator) - optimizer.zero_grad() - outputs = net(inputs) - loss = criterion(outputs, labels) - fabric.backward(loss) - optimizer.step() - - t1 = time.perf_counter() - iteration_timings.append(t1 - t0) - - return dict(iteration_timings=torch.tensor(iteration_timings)) - - -def compare(): - outputs_torch = train_torch(steps=2000) - outputs_fabric = train_fabric(steps=2000) - - # 3.5009579733014107e-06 - # 3.5009579733014107e-06 - median = torch.median(outputs_fabric["iteration_timings"]) - torch.median(outputs_torch["iteration_timings"]) - mean = torch.mean(outputs_fabric["iteration_timings"]) - torch.mean(outputs_torch["iteration_timings"]) - print("median", median.abs().item()) - print("mean", mean.abs().item()) - - -if __name__ == "__main__": - compare() diff --git a/tests/tests_fabric/parity/test_parity.py b/tests/tests_fabric/parity/test_timings.py similarity index 66% rename from tests/tests_fabric/parity/test_parity.py rename to tests/tests_fabric/parity/test_timings.py index 350473b2340a4..db339af0b77f5 100644 --- a/tests/tests_fabric/parity/test_parity.py +++ b/tests/tests_fabric/parity/test_timings.py @@ -15,22 +15,21 @@ import lightning as L import torch -import torch.nn as nn from tests_fabric.parity.utils import make_deterministic from tests_fabric.parity.models import ConvNet -def train_torch(steps=100, batch_size=4): +def train_torch(num_steps=100, batch_size=4): make_deterministic() device = "cuda" if torch.cuda.is_available() else "cpu" model = ConvNet().to(device) - dataloader = model.get_dataloader(dataset_size=(steps * batch_size), batch_size=batch_size) + dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) loss_fn = model.get_loss_function() optimizer = model.get_optimizer() iteration_timings = [] iterator = iter(dataloader) - for _ in range(steps): + for _ in range(num_steps): t0 = time.perf_counter() inputs, labels = next(iterator) @@ -44,15 +43,15 @@ def train_torch(steps=100, batch_size=4): t1 = time.perf_counter() iteration_timings.append(t1 - t0) - return dict(iteration_timings=torch.tensor(iteration_timings)) + return torch.tensor(iteration_timings) -def train_fabric(steps=100, batch_size=4): +def train_fabric(num_steps=100, batch_size=4): make_deterministic() fabric = L.Fabric(accelerator="cpu") model = ConvNet() - dataloader = model.get_dataloader(dataset_size=(steps * batch_size), batch_size=batch_size) + dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) loss_fn = model.get_loss_function() optimizer = model.get_optimizer() @@ -61,7 +60,7 @@ def train_fabric(steps=100, batch_size=4): iteration_timings = [] iterator = iter(dataloader) - for _ in range(steps): + for _ in range(num_steps): t0 = time.perf_counter() inputs, labels = next(iterator) @@ -74,20 +73,12 @@ def train_fabric(steps=100, batch_size=4): t1 = time.perf_counter() iteration_timings.append(t1 - t0) - return dict(iteration_timings=torch.tensor(iteration_timings)) + return torch.tensor(iteration_timings) -def test_compare(): - outputs_torch = train_torch(steps=2000) - outputs_fabric = train_fabric(steps=2000) +def test_parity_cpu(): + timings_torch = train_torch(num_steps=2000) + timings_fabric = train_fabric(num_steps=2000) - # 3.5009579733014107e-06 - # 3.5009579733014107e-06 - median = torch.median(outputs_fabric["iteration_timings"]) - torch.median(outputs_torch["iteration_timings"]) - mean = torch.mean(outputs_fabric["iteration_timings"]) - torch.mean(outputs_torch["iteration_timings"]) - print("median", median.abs().item()) - print("mean", mean.abs().item()) - - -if __name__ == "__main__": - compare() + # The median is more robust to outliers than the mean + assert torch.isclose(torch.median(timings_torch), torch.median(timings_fabric), rtol=1e-4, atol=1e-4) From ed453d005a297e0326eef5adf6b544a43a3c19ac Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 28 Feb 2023 04:03:53 +0100 Subject: [PATCH 05/86] wip --- tests/tests_fabric/parity/models.py | 22 +- tests/tests_fabric/parity/test_correctness.py | 169 ++++++++++++++ tests/tests_fabric/parity/test_timings.py | 14 +- tests/tests_fabric/parity/utils.py | 10 +- tests/tests_fabric/test_parity.py | 210 ------------------ 5 files changed, 205 insertions(+), 220 deletions(-) create mode 100644 tests/tests_fabric/parity/test_correctness.py delete mode 100644 tests/tests_fabric/test_parity.py diff --git a/tests/tests_fabric/parity/models.py b/tests/tests_fabric/parity/models.py index d15b4aa0a7b7e..2968471f26817 100644 --- a/tests/tests_fabric/parity/models.py +++ b/tests/tests_fabric/parity/models.py @@ -19,6 +19,7 @@ import torch import torch.nn.functional as F from torch.utils.data import TensorDataset, DataLoader +from tests_fabric.helpers.models import RandomDataset class ParityModel(ABC, nn.Module): @@ -37,6 +38,25 @@ def get_loss_function(self) -> Callable: pass +class BoringModel(ParityModel): + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(32, 2, bias=False) + + def forward(self, x): + x = self.layer(x) + return torch.nn.functional.mse_loss(x, torch.ones_like(x)) + + def get_optimizer(self): + return torch.optim.SGD(self.parameters(), lr=0.1) + + def get_dataloader(self, *args, **kwargs) -> DataLoader: + return DataLoader(RandomDataset(32, 4), shuffle=True) + + def get_loss_function(self) -> Callable: + pass + + class ConvNet(ParityModel): def __init__(self): super().__init__() @@ -57,7 +77,7 @@ def forward(self, x): return x def get_optimizer(self): - return torch.optim.SGD(self.parameters(), lr=0.0001) + return torch.optim.SGD(module.parameters(), lr=0.0001) def get_dataloader(self, dataset_size=100, batch_size=4): inputs = torch.rand(dataset_size, 3, 32, 32) diff --git a/tests/tests_fabric/parity/test_correctness.py b/tests/tests_fabric/parity/test_correctness.py new file mode 100644 index 0000000000000..c5eb2927a6282 --- /dev/null +++ b/tests/tests_fabric/parity/test_correctness.py @@ -0,0 +1,169 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from copy import deepcopy +from functools import partial +from typing import Callable + +import pytest +import torch +import torch.distributed +import torch.multiprocessing as mp +import torch.nn.functional +from lightning_utilities.core.apply_func import apply_to_collection +from tests_fabric.helpers.runif import RunIf +from torch import nn, Tensor +from torch.nn.parallel.distributed import DistributedDataParallel +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from lightning.fabric.fabric import Fabric +from lightning.fabric.plugins.environments.lightning import find_free_network_port +from lightning.fabric.strategies.ddp import DDPStrategy +from lightning.fabric.utilities.apply_func import move_data_to_device +from lightning.fabric.utilities.cloud_io import _atomic_save + +from tests_fabric.parity.utils import precision_context, is_state_dict_equal, make_deterministic +from tests_fabric.parity.models import BoringModel + + + +def train_torch( + move_to_device: Callable, + num_epochs: int = 1, + checkpoint_dir = ".", +): + make_deterministic() + model = BoringModel() + model = move_to_device(model) + train_dataloader = model.get_dataloader() + optimizer = model.get_optimizer() + + for _ in range(num_epochs): + model.train() + for batch in train_dataloader: + batch = move_to_device(batch) + optimizer.zero_grad() + loss = model(batch) + loss.backward() + optimizer.step() + + _atomic_save(model.state_dict(), os.path.join(checkpoint_dir, "torch_model.pt")) + + +def train_torch_ddp( + rank, + world_size, + device = torch.device("cpu"), + num_epochs = 1, + checkpoint_dir = ".", +): + make_deterministic() + + os.environ["LOCAL_RANK"] = str(rank) + if torch.distributed.is_available() and not torch.distributed.is_initialized(): + torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) + + model = BoringModel() + ddp_model = DistributedDataParallel(model.to(device), device_ids=([rank] if device.type == "cuda" else None)) + + train_dataloader = model.get_dataloader() + sampler = DistributedSampler(train_dataloader.dataset, rank=rank, num_replicas=world_size, seed=1, drop_last=False, shuffle=True) + train_dataloader = DataLoader(train_dataloader.dataset, sampler=sampler) + optimizer = model.get_optimizer() + + for epoch in range(num_epochs): + sampler.set_epoch(epoch) + ddp_model.train() + for batch in train_dataloader: + batch = batch.to(device) + optimizer.zero_grad() + loss = ddp_model(batch) + loss.backward() + optimizer.step() + + if rank == 0: + _atomic_save(ddp_model.module.state_dict(), os.path.join(checkpoint_dir, "torch_model.pt")) + + +class FabricRunner(Fabric): + def run(self, num_epochs: int = 1, checkpoint_dir = "."): + make_deterministic() + model = BoringModel() + initial_state_dict = deepcopy(model.state_dict()) + optimizer = model.get_optimizer() + model, optimizer = self.setup(model, optimizer) + train_dataloader = self.setup_dataloaders(model.get_dataloader()) + + model.train() + for _ in range(num_epochs): + for batch in train_dataloader: + batch = self.to_device(batch) + optimizer.zero_grad() + loss = model(batch) + self.backward(loss) + optimizer.step() + + # check that the model has changed + assert not is_state_dict_equal(initial_state_dict, model.state_dict()) + + if self.global_rank == 0: + _atomic_save(model.state_dict(), os.path.join(checkpoint_dir, "fabric_model.pt")) + + +@pytest.mark.parametrize( + "precision, accelerator", + [ + (32, "cpu"), + pytest.param(32, "gpu", marks=RunIf(min_cuda_gpus=1)), + pytest.param(16, "gpu", marks=RunIf(min_cuda_gpus=1)), + pytest.param("bf16", "gpu", marks=RunIf(min_cuda_gpus=1, bf16_cuda=True)), + pytest.param(32, "mps", marks=RunIf(mps=True)), + ], +) +def test_boring_fabric_model_single_device(precision, accelerator, tmpdir): + fabric = FabricRunner(precision=precision, accelerator=accelerator) + fabric.run(checkpoint_dir=tmpdir) + + with precision_context(precision, accelerator): + train_torch(fabric.to_device, checkpoint_dir=tmpdir) + + fabric_state_dict = torch.load(os.path.join(tmpdir, "fabric_model.pt")) + torch_state_dict = torch.load(os.path.join(tmpdir, "torch_model.pt")) + assert is_state_dict_equal(torch_state_dict, fabric_state_dict) + + + +# @RunIf(min_cuda_gpus=2, standalone=True) +@pytest.mark.parametrize( + "precision, strategy, devices, accelerator", + [ + (32, "ddp", 2, "cpu"), + # (32, "ddp", 2, "gpu"), + ], +) +def test_boring_fabric_model_ddp(precision, strategy, devices, accelerator, tmpdir): + fabric = FabricRunner(precision=precision, strategy=strategy, devices=devices, accelerator=accelerator) + fabric.run(checkpoint_dir=tmpdir) + + with precision_context(precision, accelerator): + train_torch_ddp(rank=fabric.global_rank, world_size=fabric.world_size, device=fabric.device, checkpoint_dir=tmpdir) + + tmpdir = fabric.broadcast(tmpdir) + + fabric_state_dict = torch.load(os.path.join(tmpdir, "fabric_model.pt")) + torch_state_dict = torch.load(os.path.join(tmpdir, "torch_model.pt")) + assert is_state_dict_equal(torch_state_dict, fabric_state_dict) + # for w_pure, w_fabric in zip(pure_model_state_dict.values(), fabric_model_state_dict.values()): + # torch.testing.assert_close(w_pure.cpu(), w_fabric.cpu()) diff --git a/tests/tests_fabric/parity/test_timings.py b/tests/tests_fabric/parity/test_timings.py index db339af0b77f5..3de4a729fe9b4 100644 --- a/tests/tests_fabric/parity/test_timings.py +++ b/tests/tests_fabric/parity/test_timings.py @@ -13,15 +13,15 @@ # limitations under the License. import time -import lightning as L +from lightning.fabric import Fabric import torch from tests_fabric.parity.utils import make_deterministic from tests_fabric.parity.models import ConvNet -def train_torch(num_steps=100, batch_size=4): +def train_torch(rank=0, accelerator="cpu", devices=1, num_steps=100, batch_size=4): make_deterministic() - device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device("cuda" if accelerator == "cuda" else "cpu", rank) model = ConvNet().to(device) dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) loss_fn = model.get_loss_function() @@ -48,7 +48,8 @@ def train_torch(num_steps=100, batch_size=4): def train_fabric(num_steps=100, batch_size=4): make_deterministic() - fabric = L.Fabric(accelerator="cpu") + fabric = Fabric(accelerator="cpu") + fabric.launch() model = ConvNet() dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) @@ -76,6 +77,11 @@ def train_fabric(num_steps=100, batch_size=4): return torch.tensor(iteration_timings) +def launch_fabric(): + fabric = Fabric() + fabric.launch(train_fabric, **kwargs) + + def test_parity_cpu(): timings_torch = train_torch(num_steps=2000) timings_fabric = train_fabric(num_steps=2000) diff --git a/tests/tests_fabric/parity/utils.py b/tests/tests_fabric/parity/utils.py index cb1e34524bf9c..3c61de816a816 100644 --- a/tests/tests_fabric/parity/utils.py +++ b/tests/tests_fabric/parity/utils.py @@ -15,7 +15,6 @@ from typing import Generator import torch -from torch import nn def make_deterministic(): @@ -24,10 +23,6 @@ def make_deterministic(): torch.cuda.manual_seed(1) -# def configure_optimizers(module: nn.Module): -# return torch.optim.SGD(module.parameters(), lr=0.0001) - - @contextmanager def precision_context(precision, accelerator) -> Generator[None, None, None]: if precision == 32: @@ -39,3 +34,8 @@ def precision_context(precision, accelerator) -> Generator[None, None, None]: elif accelerator == "cpu": with torch.cpu.amp.autocast(): yield + + +def is_state_dict_equal(state0, state1): + # TODO: This should be torch.equal, but MPS does not yet support this operation (torch 1.12) + return all(torch.allclose(w0.cpu(), w1.cpu()) for w0, w1 in zip(state0.values(), state1.values())) diff --git a/tests/tests_fabric/test_parity.py b/tests/tests_fabric/test_parity.py deleted file mode 100644 index 69b50e5dc840d..0000000000000 --- a/tests/tests_fabric/test_parity.py +++ /dev/null @@ -1,210 +0,0 @@ -# Copyright The Lightning AI team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -from copy import deepcopy -from functools import partial -from typing import Callable - -import pytest -import torch -import torch.distributed -import torch.multiprocessing as mp -import torch.nn.functional -from lightning_utilities.core.apply_func import apply_to_collection -from tests_fabric.helpers.models import RandomDataset -from tests_fabric.helpers.runif import RunIf -from torch import nn, Tensor -from torch.nn.parallel.distributed import DistributedDataParallel -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler - -from lightning.fabric.fabric import Fabric -from lightning.fabric.plugins.environments.lightning import find_free_network_port -from lightning.fabric.strategies.ddp import DDPStrategy -from lightning.fabric.utilities.apply_func import move_data_to_device -from lightning.fabric.utilities.cloud_io import _atomic_save - -from tests_fabric.parity.utils import configure_optimizers, precision_context - - -class BoringModel(nn.Module): - def __init__(self): - super().__init__() - self.layer = torch.nn.Linear(32, 2, bias=False) - - def forward(self, x): - x = self.layer(x) - return torch.nn.functional.mse_loss(x, torch.ones_like(x)) - - -def main( - move_to_device: Callable, - model: nn.Module, - train_dataloader: DataLoader, - num_epochs: int = 10, -): - model = move_to_device(model) - optimizer = configure_optimizers(model) - - for _ in range(num_epochs): - model.train() - for batch in train_dataloader: - batch = move_to_device(batch) - optimizer.zero_grad() - loss = model(batch) - loss.backward() - optimizer.step() - - return model.state_dict() - - -class FabricRunner(Fabric): - def run(self, model: nn.Module, train_dataloader: DataLoader, num_epochs: int = 10, tmpdir: str = None): - optimizer = configure_optimizers(model) - model, optimizer = self.setup(model, optimizer) - train_dataloader = self.setup_dataloaders(train_dataloader) - - model.train() - for _ in range(num_epochs): - for batch in train_dataloader: - batch = self.to_device(batch) - optimizer.zero_grad() - loss = model(batch) - self.backward(loss) - optimizer.step() - - if isinstance(self._strategy, DDPStrategy) and tmpdir and self.global_rank == 0: - checkpoint_path = os.path.join(tmpdir, "model.pt") - _atomic_save(model.state_dict(), checkpoint_path) - return checkpoint_path - - -@pytest.mark.parametrize( - "precision, accelerator", - [ - (32, "cpu"), - pytest.param(32, "gpu", marks=RunIf(min_cuda_gpus=1)), - pytest.param(16, "gpu", marks=RunIf(min_cuda_gpus=1)), - pytest.param("bf16", "gpu", marks=RunIf(min_cuda_gpus=1, bf16_cuda=True)), - pytest.param(32, "mps", marks=RunIf(mps=True)), - ], -) -def test_boring_fabric_model_single_device(precision, accelerator): - Fabric.seed_everything(42) - train_dataloader = DataLoader(RandomDataset(32, 8)) - model = BoringModel() - num_epochs = 1 - state_dict = deepcopy(model.state_dict()) - - fabric = FabricRunner(precision=precision, accelerator=accelerator) - fabric.run(model, train_dataloader, num_epochs=num_epochs) - fabric_state_dict = model.state_dict() - - with precision_context(precision, accelerator): - model.load_state_dict(state_dict) - pure_state_dict = main(fabric.to_device, model, train_dataloader, num_epochs=num_epochs) - - state_dict = apply_to_collection(state_dict, Tensor, fabric.to_device) - for w_pure, w_fabric in zip(state_dict.values(), fabric_state_dict.values()): - # TODO: This should be torch.equal, but MPS does not yet support this operation (torch 1.12) - assert not torch.allclose(w_pure, w_fabric) - - for w_pure, w_fabric in zip(pure_state_dict.values(), fabric_state_dict.values()): - # TODO: This should be torch.equal, but MPS does not yet support this operation (torch 1.12) - assert torch.allclose(w_pure, w_fabric) - - -def run(rank, model, train_dataloader, num_epochs, precision, accelerator, tmpdir): - os.environ["LOCAL_RANK"] = str(rank) - if torch.distributed.is_available() and not torch.distributed.is_initialized(): - torch.distributed.init_process_group("gloo", rank=rank, world_size=2) - - to_device = partial(move_data_to_device, device=torch.device("cuda", rank)) - model = DistributedDataParallel( - to_device(model), - device_ids=[rank], - ) - train_dataloader = DataLoader( - train_dataloader.dataset, - sampler=DistributedSampler(train_dataloader.dataset, rank=rank, num_replicas=2, seed=42, drop_last=False), - ) - with precision_context(precision, accelerator): - main(to_device, model, train_dataloader, num_epochs=num_epochs) - - if rank == 0: - _atomic_save(model.state_dict(), os.path.join(tmpdir, "model_spawn.pt")) - - -@pytest.mark.skip(reason="Skipping as it takes 80 seconds.") -@RunIf(min_cuda_gpus=2) -@pytest.mark.parametrize( - "precision, strategy, devices, accelerator", - [ - (32, "ddp_spawn", 2, "gpu"), - ], -) -def test_boring_fabric_model_ddp_spawn(precision, strategy, devices, accelerator, tmpdir): - Fabric.seed_everything(42) - train_dataloader = DataLoader(RandomDataset(32, 8)) - model = BoringModel() - num_epochs = 1 - state_dict = deepcopy(model.state_dict()) - - fabric = FabricRunner(precision=precision, strategy=strategy, devices=devices, accelerator=accelerator) - checkpoint_path = fabric.run(model, train_dataloader, num_epochs=num_epochs, tmpdir=tmpdir) - spawn_model_state_dict = torch.load(checkpoint_path) - - for w_pure, w_fabric in zip(state_dict.values(), spawn_model_state_dict.values()): - assert not torch.equal(w_pure.cpu(), w_fabric.cpu()) - - model.load_state_dict(state_dict) - os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = str(find_free_network_port()) - mp.spawn(run, args=(model, train_dataloader, num_epochs, precision, accelerator, tmpdir), nprocs=2) - spawn_pure_model_state_dict = torch.load(os.path.join(tmpdir, "model_spawn.pt")) - - for w_pure, w_fabric in zip(spawn_pure_model_state_dict.values(), spawn_model_state_dict.values()): - assert torch.equal(w_pure.cpu(), w_fabric.cpu()) - - -@RunIf(min_cuda_gpus=2, standalone=True) -@pytest.mark.parametrize( - "precision, strategy, devices, accelerator", - [ - (32, "ddp", 2, "gpu"), - ], -) -def test_boring_fabric_model_ddp(precision, strategy, devices, accelerator, tmpdir): - Fabric.seed_everything(42) - train_dataloader = DataLoader(RandomDataset(32, 4), shuffle=True) - model = BoringModel() - num_epochs = 1 - state_dict = deepcopy(model.state_dict()) - - fabric = FabricRunner(precision=precision, strategy=strategy, devices=devices, accelerator=accelerator) - fabric.run(model, train_dataloader, num_epochs=num_epochs, tmpdir=tmpdir) - - fabric_model_state_dict = model.state_dict() - - for w_pure, w_fabric in zip(state_dict.values(), fabric_model_state_dict.values()): - assert not torch.allclose(w_pure.cpu(), w_fabric.cpu()) - - Fabric.seed_everything(42) - train_dataloader = DataLoader(RandomDataset(32, 4), shuffle=True) - model = BoringModel() - run(fabric.global_rank, model, train_dataloader, num_epochs, precision, accelerator, tmpdir) - pure_model_state_dict = model.state_dict() - - for w_pure, w_fabric in zip(pure_model_state_dict.values(), fabric_model_state_dict.values()): - torch.testing.assert_close(w_pure.cpu(), w_fabric.cpu()) From f6273db5f62dd2fd3a043b345eae1ca10e7dc5c4 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 28 Feb 2023 04:27:06 +0100 Subject: [PATCH 06/86] fix --- tests/tests_fabric/parity/models.py | 4 +- tests/tests_fabric/parity/test_correctness.py | 39 ++++++++++++------- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/tests/tests_fabric/parity/models.py b/tests/tests_fabric/parity/models.py index 2968471f26817..72424c10aec53 100644 --- a/tests/tests_fabric/parity/models.py +++ b/tests/tests_fabric/parity/models.py @@ -51,7 +51,7 @@ def get_optimizer(self): return torch.optim.SGD(self.parameters(), lr=0.1) def get_dataloader(self, *args, **kwargs) -> DataLoader: - return DataLoader(RandomDataset(32, 4), shuffle=True) + return DataLoader(RandomDataset(32, 4)) def get_loss_function(self) -> Callable: pass @@ -81,7 +81,7 @@ def get_optimizer(self): def get_dataloader(self, dataset_size=100, batch_size=4): inputs = torch.rand(dataset_size, 3, 32, 32) - labels = torch.randint(0, 10, (dataset_size, )) + labels = torch.randint(0, 10, (dataset_size,)) dataset = TensorDataset(inputs, labels) dataloader = DataLoader( dataset, diff --git a/tests/tests_fabric/parity/test_correctness.py b/tests/tests_fabric/parity/test_correctness.py index c5eb2927a6282..cb1ad11c7c11e 100644 --- a/tests/tests_fabric/parity/test_correctness.py +++ b/tests/tests_fabric/parity/test_correctness.py @@ -38,11 +38,10 @@ from tests_fabric.parity.models import BoringModel - def train_torch( move_to_device: Callable, num_epochs: int = 1, - checkpoint_dir = ".", + checkpoint_dir=".", ): make_deterministic() model = BoringModel() @@ -50,8 +49,8 @@ def train_torch( train_dataloader = model.get_dataloader() optimizer = model.get_optimizer() + model.train() for _ in range(num_epochs): - model.train() for batch in train_dataloader: batch = move_to_device(batch) optimizer.zero_grad() @@ -65,9 +64,9 @@ def train_torch( def train_torch_ddp( rank, world_size, - device = torch.device("cpu"), - num_epochs = 1, - checkpoint_dir = ".", + device=torch.device("cpu"), + num_epochs=1, + checkpoint_dir=".", ): make_deterministic() @@ -76,40 +75,53 @@ def train_torch_ddp( torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) model = BoringModel() + initial_state_dict = deepcopy(model.state_dict()) + ddp_model = DistributedDataParallel(model.to(device), device_ids=([rank] if device.type == "cuda" else None)) train_dataloader = model.get_dataloader() - sampler = DistributedSampler(train_dataloader.dataset, rank=rank, num_replicas=world_size, seed=1, drop_last=False, shuffle=True) + sampler = DistributedSampler( + train_dataloader.dataset, rank=rank, num_replicas=world_size, seed=1, drop_last=False, shuffle=False + ) train_dataloader = DataLoader(train_dataloader.dataset, sampler=sampler) optimizer = model.get_optimizer() + ddp_model.train() for epoch in range(num_epochs): sampler.set_epoch(epoch) - ddp_model.train() for batch in train_dataloader: batch = batch.to(device) + print("torch", batch) optimizer.zero_grad() loss = ddp_model(batch) loss.backward() optimizer.step() + # check that the model has changed + assert not is_state_dict_equal(initial_state_dict, ddp_model.module.state_dict()) + if rank == 0: _atomic_save(ddp_model.module.state_dict(), os.path.join(checkpoint_dir, "torch_model.pt")) class FabricRunner(Fabric): - def run(self, num_epochs: int = 1, checkpoint_dir = "."): + def run(self, num_epochs=1, checkpoint_dir="."): make_deterministic() + model = BoringModel() initial_state_dict = deepcopy(model.state_dict()) + optimizer = model.get_optimizer() model, optimizer = self.setup(model, optimizer) - train_dataloader = self.setup_dataloaders(model.get_dataloader()) + + dataloader = model.get_dataloader() + train_dataloader = self.setup_dataloaders(dataloader) model.train() for _ in range(num_epochs): for batch in train_dataloader: batch = self.to_device(batch) + print("fabric", batch) optimizer.zero_grad() loss = model(batch) self.backward(loss) @@ -144,7 +156,6 @@ def test_boring_fabric_model_single_device(precision, accelerator, tmpdir): assert is_state_dict_equal(torch_state_dict, fabric_state_dict) - # @RunIf(min_cuda_gpus=2, standalone=True) @pytest.mark.parametrize( "precision, strategy, devices, accelerator", @@ -158,12 +169,12 @@ def test_boring_fabric_model_ddp(precision, strategy, devices, accelerator, tmpd fabric.run(checkpoint_dir=tmpdir) with precision_context(precision, accelerator): - train_torch_ddp(rank=fabric.global_rank, world_size=fabric.world_size, device=fabric.device, checkpoint_dir=tmpdir) + train_torch_ddp( + rank=fabric.global_rank, world_size=fabric.world_size, device=fabric.device, checkpoint_dir=tmpdir + ) tmpdir = fabric.broadcast(tmpdir) fabric_state_dict = torch.load(os.path.join(tmpdir, "fabric_model.pt")) torch_state_dict = torch.load(os.path.join(tmpdir, "torch_model.pt")) assert is_state_dict_equal(torch_state_dict, fabric_state_dict) - # for w_pure, w_fabric in zip(pure_model_state_dict.values(), fabric_model_state_dict.values()): - # torch.testing.assert_close(w_pure.cpu(), w_fabric.cpu()) From 164c994f42900a3d61af689df5bf9a423abc75a6 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 28 Feb 2023 04:30:07 +0100 Subject: [PATCH 07/86] update --- tests/tests_fabric/parity/test_correctness.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_fabric/parity/test_correctness.py b/tests/tests_fabric/parity/test_correctness.py index cb1ad11c7c11e..fa6774c8ac2c7 100644 --- a/tests/tests_fabric/parity/test_correctness.py +++ b/tests/tests_fabric/parity/test_correctness.py @@ -156,12 +156,12 @@ def test_boring_fabric_model_single_device(precision, accelerator, tmpdir): assert is_state_dict_equal(torch_state_dict, fabric_state_dict) -# @RunIf(min_cuda_gpus=2, standalone=True) +@RunIf(standalone=True) @pytest.mark.parametrize( "precision, strategy, devices, accelerator", [ (32, "ddp", 2, "cpu"), - # (32, "ddp", 2, "gpu"), + pytest.param(32, "ddp", 2, "gpu", marks=RunIf(min_cuda_gpus=2)), ], ) def test_boring_fabric_model_ddp(precision, strategy, devices, accelerator, tmpdir): From 5fd9f5c06538ae700b60c8ea77cd5cc1be34bd84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 04:33:57 +0100 Subject: [PATCH 08/86] update --- tests/tests_fabric/parity/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tests_fabric/parity/utils.py b/tests/tests_fabric/parity/utils.py index 3c61de816a816..922132b4f1826 100644 --- a/tests/tests_fabric/parity/utils.py +++ b/tests/tests_fabric/parity/utils.py @@ -18,6 +18,7 @@ def make_deterministic(): + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" torch.use_deterministic_algorithms(True) torch.manual_seed(1) torch.cuda.manual_seed(1) From c2ec0d7d88009a67b8ebedadaef29a9b529d4f90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 04:34:09 +0100 Subject: [PATCH 09/86] update --- tests/tests_fabric/parity/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tests_fabric/parity/utils.py b/tests/tests_fabric/parity/utils.py index 922132b4f1826..75e232365909c 100644 --- a/tests/tests_fabric/parity/utils.py +++ b/tests/tests_fabric/parity/utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os from contextlib import contextmanager from typing import Generator From 52c0f3ff0e64bc7284170e58cbe5794099f99113 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 04:46:02 +0100 Subject: [PATCH 10/86] update --- tests/tests_fabric/parity/test_correctness.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/tests_fabric/parity/test_correctness.py b/tests/tests_fabric/parity/test_correctness.py index fa6774c8ac2c7..1620b8de89df7 100644 --- a/tests/tests_fabric/parity/test_correctness.py +++ b/tests/tests_fabric/parity/test_correctness.py @@ -27,6 +27,7 @@ from torch.nn.parallel.distributed import DistributedDataParallel from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from unittest import mock from lightning.fabric.fabric import Fabric from lightning.fabric.plugins.environments.lightning import find_free_network_port @@ -40,7 +41,8 @@ def train_torch( move_to_device: Callable, - num_epochs: int = 1, + precision_context, + num_epochs=1, checkpoint_dir=".", ): make_deterministic() @@ -54,7 +56,8 @@ def train_torch( for batch in train_dataloader: batch = move_to_device(batch) optimizer.zero_grad() - loss = model(batch) + with precision_context(): + loss = model(batch) loss.backward() optimizer.step() @@ -144,12 +147,13 @@ def run(self, num_epochs=1, checkpoint_dir="."): pytest.param(32, "mps", marks=RunIf(mps=True)), ], ) +@mock.patch.dict(os.environ, {}, clear=True) def test_boring_fabric_model_single_device(precision, accelerator, tmpdir): fabric = FabricRunner(precision=precision, accelerator=accelerator) fabric.run(checkpoint_dir=tmpdir) - with precision_context(precision, accelerator): - train_torch(fabric.to_device, checkpoint_dir=tmpdir) + precision_ctx = partial(precision_context, precision=precision, accelerator=accelerator) + train_torch(fabric.to_device, precision_context=precision_ctx, checkpoint_dir=tmpdir) fabric_state_dict = torch.load(os.path.join(tmpdir, "fabric_model.pt")) torch_state_dict = torch.load(os.path.join(tmpdir, "torch_model.pt")) @@ -164,6 +168,7 @@ def test_boring_fabric_model_single_device(precision, accelerator, tmpdir): pytest.param(32, "ddp", 2, "gpu", marks=RunIf(min_cuda_gpus=2)), ], ) +@mock.patch.dict(os.environ, {}, clear=True) def test_boring_fabric_model_ddp(precision, strategy, devices, accelerator, tmpdir): fabric = FabricRunner(precision=precision, strategy=strategy, devices=devices, accelerator=accelerator) fabric.run(checkpoint_dir=tmpdir) From 49313f08a3c43e463ac7e2ac7eef49543aa309dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 05:44:18 +0100 Subject: [PATCH 11/86] update --- tests/tests_fabric/parity/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_fabric/parity/utils.py b/tests/tests_fabric/parity/utils.py index 75e232365909c..96e253a8dfa8f 100644 --- a/tests/tests_fabric/parity/utils.py +++ b/tests/tests_fabric/parity/utils.py @@ -19,8 +19,8 @@ def make_deterministic(): - os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" - torch.use_deterministic_algorithms(True) + # os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + # torch.use_deterministic_algorithms(True) torch.manual_seed(1) torch.cuda.manual_seed(1) From c08e16c506d8f4600ac9074311441fe8991fa4c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 05:47:54 +0100 Subject: [PATCH 12/86] update --- tests/tests_fabric/parity/test_correctness.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/tests_fabric/parity/test_correctness.py b/tests/tests_fabric/parity/test_correctness.py index 1620b8de89df7..3e9403dd20902 100644 --- a/tests/tests_fabric/parity/test_correctness.py +++ b/tests/tests_fabric/parity/test_correctness.py @@ -56,8 +56,8 @@ def train_torch( for batch in train_dataloader: batch = move_to_device(batch) optimizer.zero_grad() - with precision_context(): - loss = model(batch) + # with precision_context(): + loss = model(batch) loss.backward() optimizer.step() @@ -124,7 +124,6 @@ def run(self, num_epochs=1, checkpoint_dir="."): for _ in range(num_epochs): for batch in train_dataloader: batch = self.to_device(batch) - print("fabric", batch) optimizer.zero_grad() loss = model(batch) self.backward(loss) From d2f6184208ae567df55a84332316d22d1bff8fbf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 05:53:17 +0100 Subject: [PATCH 13/86] update --- tests/tests_fabric/parity/test_correctness.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_fabric/parity/test_correctness.py b/tests/tests_fabric/parity/test_correctness.py index 3e9403dd20902..b491cc84c65c6 100644 --- a/tests/tests_fabric/parity/test_correctness.py +++ b/tests/tests_fabric/parity/test_correctness.py @@ -148,7 +148,7 @@ def run(self, num_epochs=1, checkpoint_dir="."): ) @mock.patch.dict(os.environ, {}, clear=True) def test_boring_fabric_model_single_device(precision, accelerator, tmpdir): - fabric = FabricRunner(precision=precision, accelerator=accelerator) + fabric = FabricRunner(precision=precision, accelerator=accelerator, devices=1) fabric.run(checkpoint_dir=tmpdir) precision_ctx = partial(precision_context, precision=precision, accelerator=accelerator) From 0cf71fbd385b1f52b077692db3a13c440ffbd6d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 05:53:54 +0100 Subject: [PATCH 14/86] update --- tests/tests_fabric/parity/test_correctness.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_fabric/parity/test_correctness.py b/tests/tests_fabric/parity/test_correctness.py index b491cc84c65c6..6e27d7fc01606 100644 --- a/tests/tests_fabric/parity/test_correctness.py +++ b/tests/tests_fabric/parity/test_correctness.py @@ -56,8 +56,8 @@ def train_torch( for batch in train_dataloader: batch = move_to_device(batch) optimizer.zero_grad() - # with precision_context(): - loss = model(batch) + with precision_context(): + loss = model(batch) loss.backward() optimizer.step() From c713106e534061b4c16d1469b8c6b50d251040c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 05:54:43 +0100 Subject: [PATCH 15/86] update --- tests/tests_fabric/parity/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_fabric/parity/utils.py b/tests/tests_fabric/parity/utils.py index 96e253a8dfa8f..75e232365909c 100644 --- a/tests/tests_fabric/parity/utils.py +++ b/tests/tests_fabric/parity/utils.py @@ -19,8 +19,8 @@ def make_deterministic(): - # os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" - # torch.use_deterministic_algorithms(True) + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + torch.use_deterministic_algorithms(True) torch.manual_seed(1) torch.cuda.manual_seed(1) From b04d381fa98cf639e200283dd282403e37dead9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 06:00:23 +0100 Subject: [PATCH 16/86] update --- tests/tests_fabric/parity/test_correctness.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/tests_fabric/parity/test_correctness.py b/tests/tests_fabric/parity/test_correctness.py index 6e27d7fc01606..569e108fe87b6 100644 --- a/tests/tests_fabric/parity/test_correctness.py +++ b/tests/tests_fabric/parity/test_correctness.py @@ -48,18 +48,18 @@ def train_torch( make_deterministic() model = BoringModel() model = move_to_device(model) - train_dataloader = model.get_dataloader() - optimizer = model.get_optimizer() + with precision_context(): + train_dataloader = model.get_dataloader() + optimizer = model.get_optimizer() - model.train() - for _ in range(num_epochs): - for batch in train_dataloader: - batch = move_to_device(batch) - optimizer.zero_grad() - with precision_context(): - loss = model(batch) - loss.backward() - optimizer.step() + model.train() + for _ in range(num_epochs): + for batch in train_dataloader: + batch = move_to_device(batch) + optimizer.zero_grad() + loss = model(batch) + loss.backward() + optimizer.step() _atomic_save(model.state_dict(), os.path.join(checkpoint_dir, "torch_model.pt")) From 2b47e9cf2aabc12dddb0e3dbb3d9dd01636761d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 06:00:45 +0100 Subject: [PATCH 17/86] update --- tests/tests_fabric/parity/test_correctness.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_fabric/parity/test_correctness.py b/tests/tests_fabric/parity/test_correctness.py index 569e108fe87b6..591103acf795b 100644 --- a/tests/tests_fabric/parity/test_correctness.py +++ b/tests/tests_fabric/parity/test_correctness.py @@ -57,7 +57,7 @@ def train_torch( for batch in train_dataloader: batch = move_to_device(batch) optimizer.zero_grad() - loss = model(batch) + loss = model(batch) loss.backward() optimizer.step() From bdc30551445987c16f88482f8b67454478fd936d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 06:04:45 +0100 Subject: [PATCH 18/86] update --- tests/tests_fabric/parity/test_correctness.py | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/tests/tests_fabric/parity/test_correctness.py b/tests/tests_fabric/parity/test_correctness.py index 591103acf795b..d5616ba309a13 100644 --- a/tests/tests_fabric/parity/test_correctness.py +++ b/tests/tests_fabric/parity/test_correctness.py @@ -48,18 +48,22 @@ def train_torch( make_deterministic() model = BoringModel() model = move_to_device(model) - with precision_context(): - train_dataloader = model.get_dataloader() - optimizer = model.get_optimizer() + train_dataloader = model.get_dataloader() + optimizer = model.get_optimizer() - model.train() - for _ in range(num_epochs): - for batch in train_dataloader: - batch = move_to_device(batch) - optimizer.zero_grad() + model.train() + for _ in range(num_epochs): + for batch in train_dataloader: + batch = move_to_device(batch) + optimizer.zero_grad() + + precision_to_type = {"bf16-mixed": torch.bfloat16, "16-mixed": torch.float16} + dst_type = precision_to_type["bf16-mixed"] + batch = batch.to(dst_type) + with precision_context(): loss = model(batch) - loss.backward() - optimizer.step() + loss.backward() + optimizer.step() _atomic_save(model.state_dict(), os.path.join(checkpoint_dir, "torch_model.pt")) From 8747031e8a6d8ee34b2647950fb9083e59b9050f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 06:07:33 +0100 Subject: [PATCH 19/86] update --- tests/tests_fabric/parity/test_correctness.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/tests_fabric/parity/test_correctness.py b/tests/tests_fabric/parity/test_correctness.py index d5616ba309a13..d04611b45eafa 100644 --- a/tests/tests_fabric/parity/test_correctness.py +++ b/tests/tests_fabric/parity/test_correctness.py @@ -62,6 +62,8 @@ def train_torch( batch = batch.to(dst_type) with precision_context(): loss = model(batch) + + loss = loss.to(torch.get_default_dtype()) if torch.is_floating_point(loss) else tensor loss.backward() optimizer.step() From 14bb8d9a045fc8d762f22ce1706953e27c4fa687 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 06:08:43 +0100 Subject: [PATCH 20/86] update --- tests/tests_fabric/parity/test_correctness.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/tests_fabric/parity/test_correctness.py b/tests/tests_fabric/parity/test_correctness.py index d04611b45eafa..1744312490c5d 100644 --- a/tests/tests_fabric/parity/test_correctness.py +++ b/tests/tests_fabric/parity/test_correctness.py @@ -66,6 +66,7 @@ def train_torch( loss = loss.to(torch.get_default_dtype()) if torch.is_floating_point(loss) else tensor loss.backward() optimizer.step() + break _atomic_save(model.state_dict(), os.path.join(checkpoint_dir, "torch_model.pt")) @@ -134,6 +135,7 @@ def run(self, num_epochs=1, checkpoint_dir="."): loss = model(batch) self.backward(loss) optimizer.step() + break # check that the model has changed assert not is_state_dict_equal(initial_state_dict, model.state_dict()) From 24450263bc413bcc4c1976c288071cb5adf3d28e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 06:11:00 +0100 Subject: [PATCH 21/86] update --- tests/tests_fabric/parity/test_correctness.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/tests_fabric/parity/test_correctness.py b/tests/tests_fabric/parity/test_correctness.py index 1744312490c5d..ae508e41c2a66 100644 --- a/tests/tests_fabric/parity/test_correctness.py +++ b/tests/tests_fabric/parity/test_correctness.py @@ -55,6 +55,7 @@ def train_torch( for _ in range(num_epochs): for batch in train_dataloader: batch = move_to_device(batch) + print("torch", batch) optimizer.zero_grad() precision_to_type = {"bf16-mixed": torch.bfloat16, "16-mixed": torch.float16} @@ -131,6 +132,7 @@ def run(self, num_epochs=1, checkpoint_dir="."): for _ in range(num_epochs): for batch in train_dataloader: batch = self.to_device(batch) + print("fabric", batch) optimizer.zero_grad() loss = model(batch) self.backward(loss) From da239160ed3f8d179faa64aca5704baeb6985c0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 06:12:28 +0100 Subject: [PATCH 22/86] update --- tests/tests_fabric/parity/test_correctness.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_fabric/parity/test_correctness.py b/tests/tests_fabric/parity/test_correctness.py index ae508e41c2a66..10216e352f6bc 100644 --- a/tests/tests_fabric/parity/test_correctness.py +++ b/tests/tests_fabric/parity/test_correctness.py @@ -162,7 +162,7 @@ def test_boring_fabric_model_single_device(precision, accelerator, tmpdir): fabric.run(checkpoint_dir=tmpdir) precision_ctx = partial(precision_context, precision=precision, accelerator=accelerator) - train_torch(fabric.to_device, precision_context=precision_ctx, checkpoint_dir=tmpdir) + train_torch(fabric.to_device, precision_context=fabric.autocast, checkpoint_dir=tmpdir) fabric_state_dict = torch.load(os.path.join(tmpdir, "fabric_model.pt")) torch_state_dict = torch.load(os.path.join(tmpdir, "torch_model.pt")) From 0ea1496effb987a1b41b9f968410978b0650bab4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 06:13:59 +0100 Subject: [PATCH 23/86] update --- tests/tests_fabric/parity/test_correctness.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/tests_fabric/parity/test_correctness.py b/tests/tests_fabric/parity/test_correctness.py index 10216e352f6bc..1f956ca484b39 100644 --- a/tests/tests_fabric/parity/test_correctness.py +++ b/tests/tests_fabric/parity/test_correctness.py @@ -55,7 +55,6 @@ def train_torch( for _ in range(num_epochs): for batch in train_dataloader: batch = move_to_device(batch) - print("torch", batch) optimizer.zero_grad() precision_to_type = {"bf16-mixed": torch.bfloat16, "16-mixed": torch.float16} @@ -67,7 +66,6 @@ def train_torch( loss = loss.to(torch.get_default_dtype()) if torch.is_floating_point(loss) else tensor loss.backward() optimizer.step() - break _atomic_save(model.state_dict(), os.path.join(checkpoint_dir, "torch_model.pt")) @@ -102,7 +100,6 @@ def train_torch_ddp( sampler.set_epoch(epoch) for batch in train_dataloader: batch = batch.to(device) - print("torch", batch) optimizer.zero_grad() loss = ddp_model(batch) loss.backward() @@ -132,12 +129,10 @@ def run(self, num_epochs=1, checkpoint_dir="."): for _ in range(num_epochs): for batch in train_dataloader: batch = self.to_device(batch) - print("fabric", batch) optimizer.zero_grad() loss = model(batch) self.backward(loss) optimizer.step() - break # check that the model has changed assert not is_state_dict_equal(initial_state_dict, model.state_dict()) From ba84ba7f6d09abaa5cfebe03b2f8ff637e1564ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 06:14:34 +0100 Subject: [PATCH 24/86] update --- tests/tests_fabric/parity/test_correctness.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/tests_fabric/parity/test_correctness.py b/tests/tests_fabric/parity/test_correctness.py index 1f956ca484b39..37356e13f52e9 100644 --- a/tests/tests_fabric/parity/test_correctness.py +++ b/tests/tests_fabric/parity/test_correctness.py @@ -57,13 +57,13 @@ def train_torch( batch = move_to_device(batch) optimizer.zero_grad() - precision_to_type = {"bf16-mixed": torch.bfloat16, "16-mixed": torch.float16} - dst_type = precision_to_type["bf16-mixed"] - batch = batch.to(dst_type) + # precision_to_type = {"bf16-mixed": torch.bfloat16, "16-mixed": torch.float16} + # dst_type = precision_to_type["bf16-mixed"] + # batch = batch.to(dst_type) with precision_context(): loss = model(batch) - loss = loss.to(torch.get_default_dtype()) if torch.is_floating_point(loss) else tensor + # loss = loss.to(torch.get_default_dtype()) if torch.is_floating_point(loss) else tensor loss.backward() optimizer.step() From caa7c03aa02cc295cefa6a6639cc166aff1dcad9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 06:15:57 +0100 Subject: [PATCH 25/86] update --- tests/tests_fabric/parity/test_correctness.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/tests/tests_fabric/parity/test_correctness.py b/tests/tests_fabric/parity/test_correctness.py index 37356e13f52e9..5f165c460c7ac 100644 --- a/tests/tests_fabric/parity/test_correctness.py +++ b/tests/tests_fabric/parity/test_correctness.py @@ -56,14 +56,8 @@ def train_torch( for batch in train_dataloader: batch = move_to_device(batch) optimizer.zero_grad() - - # precision_to_type = {"bf16-mixed": torch.bfloat16, "16-mixed": torch.float16} - # dst_type = precision_to_type["bf16-mixed"] - # batch = batch.to(dst_type) with precision_context(): loss = model(batch) - - # loss = loss.to(torch.get_default_dtype()) if torch.is_floating_point(loss) else tensor loss.backward() optimizer.step() @@ -146,7 +140,7 @@ def run(self, num_epochs=1, checkpoint_dir="."): [ (32, "cpu"), pytest.param(32, "gpu", marks=RunIf(min_cuda_gpus=1)), - pytest.param(16, "gpu", marks=RunIf(min_cuda_gpus=1)), + # pytest.param(16, "gpu", marks=RunIf(min_cuda_gpus=1)), # TODO: requires GradScaler pytest.param("bf16", "gpu", marks=RunIf(min_cuda_gpus=1, bf16_cuda=True)), pytest.param(32, "mps", marks=RunIf(mps=True)), ], From 2b574938ecfcee6a3638d56f3d3748e4714c40df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 06:38:24 +0100 Subject: [PATCH 26/86] update --- tests/tests_fabric/parity/models.py | 39 +++++---- tests/tests_fabric/parity/test_correctness.py | 81 ++++++++++--------- 2 files changed, 64 insertions(+), 56 deletions(-) diff --git a/tests/tests_fabric/parity/models.py b/tests/tests_fabric/parity/models.py index 72424c10aec53..746a9b354f6c4 100644 --- a/tests/tests_fabric/parity/models.py +++ b/tests/tests_fabric/parity/models.py @@ -37,24 +37,24 @@ def get_dataloader(self, *args, **kwargs) -> DataLoader: def get_loss_function(self) -> Callable: pass - -class BoringModel(ParityModel): - def __init__(self): - super().__init__() - self.layer = torch.nn.Linear(32, 2, bias=False) - - def forward(self, x): - x = self.layer(x) - return torch.nn.functional.mse_loss(x, torch.ones_like(x)) - - def get_optimizer(self): - return torch.optim.SGD(self.parameters(), lr=0.1) - - def get_dataloader(self, *args, **kwargs) -> DataLoader: - return DataLoader(RandomDataset(32, 4)) - - def get_loss_function(self) -> Callable: - pass +# +# class BoringModel(ParityModel): +# def __init__(self): +# super().__init__() +# self.layer = torch.nn.Linear(32, 2, bias=False) +# +# def forward(self, x): +# x = self.layer(x) +# return torch.nn.functional.mse_loss(x, torch.ones_like(x)) +# +# def get_optimizer(self): +# return torch.optim.SGD(self.parameters(), lr=0.1) +# +# def get_dataloader(self, *args, **kwargs) -> DataLoader: +# return DataLoader(RandomDataset(32, 4)) +# +# def get_loss_function(self) -> Callable: +# pass class ConvNet(ParityModel): @@ -77,7 +77,7 @@ def forward(self, x): return x def get_optimizer(self): - return torch.optim.SGD(module.parameters(), lr=0.0001) + return torch.optim.SGD(self.parameters(), lr=0.0001) def get_dataloader(self, dataset_size=100, batch_size=4): inputs = torch.rand(dataset_size, 3, 32, 32) @@ -86,7 +86,6 @@ def get_dataloader(self, dataset_size=100, batch_size=4): dataloader = DataLoader( dataset, batch_size=batch_size, - shuffle=True, num_workers=2, ) return dataloader diff --git a/tests/tests_fabric/parity/test_correctness.py b/tests/tests_fabric/parity/test_correctness.py index 5f165c460c7ac..2897f6a3815cb 100644 --- a/tests/tests_fabric/parity/test_correctness.py +++ b/tests/tests_fabric/parity/test_correctness.py @@ -36,30 +36,34 @@ from lightning.fabric.utilities.cloud_io import _atomic_save from tests_fabric.parity.utils import precision_context, is_state_dict_equal, make_deterministic -from tests_fabric.parity.models import BoringModel +from tests_fabric.parity.models import ConvNet def train_torch( move_to_device: Callable, precision_context, - num_epochs=1, + num_steps=1, + batch_size=4, checkpoint_dir=".", ): make_deterministic() - model = BoringModel() + model = ConvNet() model = move_to_device(model) - train_dataloader = model.get_dataloader() + dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) optimizer = model.get_optimizer() + loss_fn = model.get_loss_function() model.train() - for _ in range(num_epochs): - for batch in train_dataloader: - batch = move_to_device(batch) - optimizer.zero_grad() - with precision_context(): - loss = model(batch) - loss.backward() - optimizer.step() + iterator = iter(dataloader) + for _ in range(num_steps): + inputs, labels = next(iterator) + inputs, labels = move_to_device(inputs), move_to_device(labels) + optimizer.zero_grad() + with precision_context(): + outputs = model(inputs) + loss = loss_fn(outputs, labels) + loss.backward() + optimizer.step() _atomic_save(model.state_dict(), os.path.join(checkpoint_dir, "torch_model.pt")) @@ -68,7 +72,8 @@ def train_torch_ddp( rank, world_size, device=torch.device("cpu"), - num_epochs=1, + num_steps=1, + batch_size=4, checkpoint_dir=".", ): make_deterministic() @@ -77,27 +82,29 @@ def train_torch_ddp( if torch.distributed.is_available() and not torch.distributed.is_initialized(): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) - model = BoringModel() + model = ConvNet().to(device) initial_state_dict = deepcopy(model.state_dict()) ddp_model = DistributedDataParallel(model.to(device), device_ids=([rank] if device.type == "cuda" else None)) - train_dataloader = model.get_dataloader() + dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) sampler = DistributedSampler( - train_dataloader.dataset, rank=rank, num_replicas=world_size, seed=1, drop_last=False, shuffle=False + dataloader.dataset, rank=rank, num_replicas=world_size, seed=1, drop_last=False, shuffle=False ) - train_dataloader = DataLoader(train_dataloader.dataset, sampler=sampler) + dataloader = DataLoader(dataloader.dataset, sampler=sampler) optimizer = model.get_optimizer() + loss_fn = model.get_loss_function() ddp_model.train() - for epoch in range(num_epochs): - sampler.set_epoch(epoch) - for batch in train_dataloader: - batch = batch.to(device) - optimizer.zero_grad() - loss = ddp_model(batch) - loss.backward() - optimizer.step() + iterator = iter(dataloader) + for _ in range(num_steps): + inputs, labels = next(iterator) + inputs, labels = move_to_device(inputs), move_to_device(labels) + optimizer.zero_grad() + outputs = ddp_model(inputs) + loss = loss_fn(outputs, labels) + loss.backward() + optimizer.step() # check that the model has changed assert not is_state_dict_equal(initial_state_dict, ddp_model.module.state_dict()) @@ -107,26 +114,28 @@ def train_torch_ddp( class FabricRunner(Fabric): - def run(self, num_epochs=1, checkpoint_dir="."): + def run(self, num_steps=1, batch_size=4, checkpoint_dir="."): make_deterministic() - model = BoringModel() + model = ConvNet() initial_state_dict = deepcopy(model.state_dict()) optimizer = model.get_optimizer() model, optimizer = self.setup(model, optimizer) - dataloader = model.get_dataloader() - train_dataloader = self.setup_dataloaders(dataloader) + dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) + dataloader = self.setup_dataloaders(dataloader) + loss_fn = model.get_loss_function() model.train() - for _ in range(num_epochs): - for batch in train_dataloader: - batch = self.to_device(batch) - optimizer.zero_grad() - loss = model(batch) - self.backward(loss) - optimizer.step() + iterator = iter(dataloader) + for _ in range(num_steps): + inputs, labels = next(iterator) + optimizer.zero_grad() + outputs = model(inputs) + loss = loss_fn(outputs, labels) + self.backward(loss) + optimizer.step() # check that the model has changed assert not is_state_dict_equal(initial_state_dict, model.state_dict()) From c14e2c4f0723aa3f927e55919a90b3fc18e60a72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 06:39:18 +0100 Subject: [PATCH 27/86] update --- tests/tests_fabric/parity/test_correctness.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_fabric/parity/test_correctness.py b/tests/tests_fabric/parity/test_correctness.py index 2897f6a3815cb..5e67345e430c4 100644 --- a/tests/tests_fabric/parity/test_correctness.py +++ b/tests/tests_fabric/parity/test_correctness.py @@ -89,7 +89,7 @@ def train_torch_ddp( dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) sampler = DistributedSampler( - dataloader.dataset, rank=rank, num_replicas=world_size, seed=1, drop_last=False, shuffle=False + dataloader.dataset, rank=rank, num_replicas=world_size, drop_last=False, shuffle=False ) dataloader = DataLoader(dataloader.dataset, sampler=sampler) optimizer = model.get_optimizer() From 43b17e9efdb08e030c136061c02b6c9e1311ff80 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 28 Feb 2023 07:31:05 +0100 Subject: [PATCH 28/86] refactor --- tests/tests_fabric/parity/models.py | 20 --- ...test_correctness.py => test_parity_ddp.py} | 108 +++++--------- .../tests_fabric/parity/test_parity_simple.py | 141 ++++++++++++++++++ tests/tests_fabric/parity/test_timings.py | 90 ----------- 4 files changed, 178 insertions(+), 181 deletions(-) rename tests/tests_fabric/parity/{test_correctness.py => test_parity_ddp.py} (56%) create mode 100644 tests/tests_fabric/parity/test_parity_simple.py delete mode 100644 tests/tests_fabric/parity/test_timings.py diff --git a/tests/tests_fabric/parity/models.py b/tests/tests_fabric/parity/models.py index 746a9b354f6c4..16e42928173b0 100644 --- a/tests/tests_fabric/parity/models.py +++ b/tests/tests_fabric/parity/models.py @@ -19,7 +19,6 @@ import torch import torch.nn.functional as F from torch.utils.data import TensorDataset, DataLoader -from tests_fabric.helpers.models import RandomDataset class ParityModel(ABC, nn.Module): @@ -37,25 +36,6 @@ def get_dataloader(self, *args, **kwargs) -> DataLoader: def get_loss_function(self) -> Callable: pass -# -# class BoringModel(ParityModel): -# def __init__(self): -# super().__init__() -# self.layer = torch.nn.Linear(32, 2, bias=False) -# -# def forward(self, x): -# x = self.layer(x) -# return torch.nn.functional.mse_loss(x, torch.ones_like(x)) -# -# def get_optimizer(self): -# return torch.optim.SGD(self.parameters(), lr=0.1) -# -# def get_dataloader(self, *args, **kwargs) -> DataLoader: -# return DataLoader(RandomDataset(32, 4)) -# -# def get_loss_function(self) -> Callable: -# pass - class ConvNet(ParityModel): def __init__(self): diff --git a/tests/tests_fabric/parity/test_correctness.py b/tests/tests_fabric/parity/test_parity_ddp.py similarity index 56% rename from tests/tests_fabric/parity/test_correctness.py rename to tests/tests_fabric/parity/test_parity_ddp.py index 5e67345e430c4..b47effa70ffa4 100644 --- a/tests/tests_fabric/parity/test_correctness.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import time from copy import deepcopy from functools import partial from typing import Callable @@ -21,7 +22,6 @@ import torch.distributed import torch.multiprocessing as mp import torch.nn.functional -from lightning_utilities.core.apply_func import apply_to_collection from tests_fabric.helpers.runif import RunIf from torch import nn, Tensor from torch.nn.parallel.distributed import DistributedDataParallel @@ -30,49 +30,19 @@ from unittest import mock from lightning.fabric.fabric import Fabric -from lightning.fabric.plugins.environments.lightning import find_free_network_port -from lightning.fabric.strategies.ddp import DDPStrategy -from lightning.fabric.utilities.apply_func import move_data_to_device from lightning.fabric.utilities.cloud_io import _atomic_save from tests_fabric.parity.utils import precision_context, is_state_dict_equal, make_deterministic from tests_fabric.parity.models import ConvNet - -def train_torch( - move_to_device: Callable, - precision_context, - num_steps=1, - batch_size=4, - checkpoint_dir=".", -): - make_deterministic() - model = ConvNet() - model = move_to_device(model) - dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) - optimizer = model.get_optimizer() - loss_fn = model.get_loss_function() - - model.train() - iterator = iter(dataloader) - for _ in range(num_steps): - inputs, labels = next(iterator) - inputs, labels = move_to_device(inputs), move_to_device(labels) - optimizer.zero_grad() - with precision_context(): - outputs = model(inputs) - loss = loss_fn(outputs, labels) - loss.backward() - optimizer.step() - - _atomic_save(model.state_dict(), os.path.join(checkpoint_dir, "torch_model.pt")) +NUM_STEPS_DEFAULT = 2000 def train_torch_ddp( rank, world_size, device=torch.device("cpu"), - num_steps=1, + num_steps=NUM_STEPS_DEFAULT, batch_size=4, checkpoint_dir=".", ): @@ -82,39 +52,47 @@ def train_torch_ddp( if torch.distributed.is_available() and not torch.distributed.is_initialized(): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) - model = ConvNet().to(device) + model = ConvNet() initial_state_dict = deepcopy(model.state_dict()) ddp_model = DistributedDataParallel(model.to(device), device_ids=([rank] if device.type == "cuda" else None)) dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) sampler = DistributedSampler( - dataloader.dataset, rank=rank, num_replicas=world_size, drop_last=False, shuffle=False + dataloader.dataset, rank=rank, num_replicas=world_size, seed=1, drop_last=False, shuffle=False ) dataloader = DataLoader(dataloader.dataset, sampler=sampler) optimizer = model.get_optimizer() loss_fn = model.get_loss_function() + iteration_timings = [] + ddp_model.train() iterator = iter(dataloader) for _ in range(num_steps): + t0 = time.perf_counter() + inputs, labels = next(iterator) - inputs, labels = move_to_device(inputs), move_to_device(labels) + inputs, labels = inputs.to(device), labels.to(device) optimizer.zero_grad() outputs = ddp_model(inputs) loss = loss_fn(outputs, labels) loss.backward() optimizer.step() + t1 = time.perf_counter() + iteration_timings.append(t1 - t0) + # check that the model has changed assert not is_state_dict_equal(initial_state_dict, ddp_model.module.state_dict()) if rank == 0: - _atomic_save(ddp_model.module.state_dict(), os.path.join(checkpoint_dir, "torch_model.pt")) + state = dict(state_dict=ddp_model.module.state_dict(), iteration_timings=torch.tensor(iteration_timings)) + _atomic_save(state, os.path.join(checkpoint_dir, "torch_model.pt")) class FabricRunner(Fabric): - def run(self, num_steps=1, batch_size=4, checkpoint_dir="."): + def run(self, num_steps=NUM_STEPS_DEFAULT, batch_size=4, checkpoint_dir="."): make_deterministic() model = ConvNet() @@ -127,9 +105,13 @@ def run(self, num_steps=1, batch_size=4, checkpoint_dir="."): dataloader = self.setup_dataloaders(dataloader) loss_fn = model.get_loss_function() + iteration_timings = [] + model.train() iterator = iter(dataloader) for _ in range(num_steps): + t0 = time.perf_counter() + inputs, labels = next(iterator) optimizer.zero_grad() outputs = model(inputs) @@ -137,34 +119,15 @@ def run(self, num_steps=1, batch_size=4, checkpoint_dir="."): self.backward(loss) optimizer.step() + t1 = time.perf_counter() + iteration_timings.append(t1 - t0) + # check that the model has changed assert not is_state_dict_equal(initial_state_dict, model.state_dict()) if self.global_rank == 0: - _atomic_save(model.state_dict(), os.path.join(checkpoint_dir, "fabric_model.pt")) - - -@pytest.mark.parametrize( - "precision, accelerator", - [ - (32, "cpu"), - pytest.param(32, "gpu", marks=RunIf(min_cuda_gpus=1)), - # pytest.param(16, "gpu", marks=RunIf(min_cuda_gpus=1)), # TODO: requires GradScaler - pytest.param("bf16", "gpu", marks=RunIf(min_cuda_gpus=1, bf16_cuda=True)), - pytest.param(32, "mps", marks=RunIf(mps=True)), - ], -) -@mock.patch.dict(os.environ, {}, clear=True) -def test_boring_fabric_model_single_device(precision, accelerator, tmpdir): - fabric = FabricRunner(precision=precision, accelerator=accelerator, devices=1) - fabric.run(checkpoint_dir=tmpdir) - - precision_ctx = partial(precision_context, precision=precision, accelerator=accelerator) - train_torch(fabric.to_device, precision_context=fabric.autocast, checkpoint_dir=tmpdir) - - fabric_state_dict = torch.load(os.path.join(tmpdir, "fabric_model.pt")) - torch_state_dict = torch.load(os.path.join(tmpdir, "torch_model.pt")) - assert is_state_dict_equal(torch_state_dict, fabric_state_dict) + state = dict(state_dict=model.state_dict(), iteration_timings=torch.tensor(iteration_timings)) + _atomic_save(state, os.path.join(checkpoint_dir, "fabric_model.pt")) @RunIf(standalone=True) @@ -175,18 +138,21 @@ def test_boring_fabric_model_single_device(precision, accelerator, tmpdir): pytest.param(32, "ddp", 2, "gpu", marks=RunIf(min_cuda_gpus=2)), ], ) -@mock.patch.dict(os.environ, {}, clear=True) -def test_boring_fabric_model_ddp(precision, strategy, devices, accelerator, tmpdir): +def test_parity_ddp(precision, strategy, devices, accelerator, tmpdir): fabric = FabricRunner(precision=precision, strategy=strategy, devices=devices, accelerator=accelerator) fabric.run(checkpoint_dir=tmpdir) - with precision_context(precision, accelerator): - train_torch_ddp( - rank=fabric.global_rank, world_size=fabric.world_size, device=fabric.device, checkpoint_dir=tmpdir - ) + train_torch_ddp( + rank=fabric.global_rank, world_size=fabric.world_size, device=fabric.device, checkpoint_dir=tmpdir + ) tmpdir = fabric.broadcast(tmpdir) - fabric_state_dict = torch.load(os.path.join(tmpdir, "fabric_model.pt")) - torch_state_dict = torch.load(os.path.join(tmpdir, "torch_model.pt")) - assert is_state_dict_equal(torch_state_dict, fabric_state_dict) + fabric_results = torch.load(os.path.join(tmpdir, "fabric_model.pt")) + torch_results = torch.load(os.path.join(tmpdir, "torch_model.pt")) + assert is_state_dict_equal(fabric_results["state_dict"], torch_results["state_dict"]) + + timings_fabric = fabric_results["iteration_timings"] + timings_torch = torch_results["iteration_timings"] + # The median is more robust to outliers than the mean + assert torch.isclose(torch.median(timings_torch), torch.median(timings_fabric), rtol=1e-4, atol=1e-4) diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py new file mode 100644 index 0000000000000..4c9bdf9682085 --- /dev/null +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -0,0 +1,141 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import time +from copy import deepcopy +from functools import partial +from typing import Callable + +import pytest +import torch +import torch.distributed +import torch.multiprocessing as mp +import torch.nn.functional +from tests_fabric.helpers.runif import RunIf +from torch import nn, Tensor +from torch.nn.parallel.distributed import DistributedDataParallel +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from unittest import mock + +from lightning.fabric.fabric import Fabric +from lightning.fabric.utilities.cloud_io import _atomic_save + +from tests_fabric.parity.utils import precision_context, is_state_dict_equal, make_deterministic +from tests_fabric.parity.models import ConvNet + +NUM_STEPS_DEFAULT = 2000 + + +def train_torch( + move_to_device: Callable, + precision_context, + num_steps=NUM_STEPS_DEFAULT, + batch_size=4, + checkpoint_dir=".", +): + make_deterministic() + model = ConvNet() + model = move_to_device(model) + dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) + optimizer = model.get_optimizer() + loss_fn = model.get_loss_function() + + iteration_timings = [] + + model.train() + iterator = iter(dataloader) + for _ in range(num_steps): + t0 = time.perf_counter() + + inputs, labels = next(iterator) + inputs, labels = move_to_device(inputs), move_to_device(labels) + optimizer.zero_grad() + with precision_context(): + outputs = model(inputs) + loss = loss_fn(outputs, labels) + loss.backward() + optimizer.step() + + t1 = time.perf_counter() + iteration_timings.append(t1 - t0) + + state = dict(state_dict=model.state_dict(), iteration_timings=torch.tensor(iteration_timings)) + _atomic_save(state, os.path.join(checkpoint_dir, "torch_model.pt")) + + +class FabricRunner(Fabric): + def run(self, num_steps=NUM_STEPS_DEFAULT, batch_size=4, checkpoint_dir="."): + make_deterministic() + + model = ConvNet() + initial_state_dict = deepcopy(model.state_dict()) + + optimizer = model.get_optimizer() + model, optimizer = self.setup(model, optimizer) + + dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) + dataloader = self.setup_dataloaders(dataloader) + loss_fn = model.get_loss_function() + + iteration_timings = [] + + model.train() + iterator = iter(dataloader) + for _ in range(num_steps): + t0 = time.perf_counter() + + inputs, labels = next(iterator) + optimizer.zero_grad() + outputs = model(inputs) + loss = loss_fn(outputs, labels) + self.backward(loss) + optimizer.step() + + t1 = time.perf_counter() + iteration_timings.append(t1 - t0) + + # check that the model has changed + assert not is_state_dict_equal(initial_state_dict, model.state_dict()) + + if self.global_rank == 0: + state = dict(state_dict=model.state_dict(), iteration_timings=torch.tensor(iteration_timings)) + _atomic_save(state, os.path.join(checkpoint_dir, "fabric_model.pt")) + + +@pytest.mark.parametrize( + "precision, accelerator", + [ + (32, "cpu"), + pytest.param(32, "gpu", marks=RunIf(min_cuda_gpus=1)), + # pytest.param(16, "gpu", marks=RunIf(min_cuda_gpus=1)), # TODO: requires GradScaler + pytest.param("bf16", "gpu", marks=RunIf(min_cuda_gpus=1, bf16_cuda=True)), + pytest.param(32, "mps", marks=RunIf(mps=True)), + ], +) +@mock.patch.dict(os.environ, {}, clear=True) +def test_parity_single_device(precision, accelerator, tmpdir): + fabric = FabricRunner(precision=precision, accelerator=accelerator, devices=1) + fabric.run(checkpoint_dir=tmpdir) + + train_torch(fabric.to_device, precision_context=fabric.autocast, checkpoint_dir=tmpdir) + + fabric_results = torch.load(os.path.join(tmpdir, "fabric_model.pt")) + torch_results = torch.load(os.path.join(tmpdir, "torch_model.pt")) + assert is_state_dict_equal(fabric_results["state_dict"], torch_results["state_dict"]) + + timings_fabric = fabric_results["iteration_timings"] + timings_torch = torch_results["iteration_timings"] + # The median is more robust to outliers than the mean + assert torch.isclose(torch.median(timings_torch), torch.median(timings_fabric), rtol=1e-4, atol=1e-4) diff --git a/tests/tests_fabric/parity/test_timings.py b/tests/tests_fabric/parity/test_timings.py deleted file mode 100644 index 3de4a729fe9b4..0000000000000 --- a/tests/tests_fabric/parity/test_timings.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright The Lightning AI team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import time - -from lightning.fabric import Fabric -import torch -from tests_fabric.parity.utils import make_deterministic -from tests_fabric.parity.models import ConvNet - - -def train_torch(rank=0, accelerator="cpu", devices=1, num_steps=100, batch_size=4): - make_deterministic() - device = torch.device("cuda" if accelerator == "cuda" else "cpu", rank) - model = ConvNet().to(device) - dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) - loss_fn = model.get_loss_function() - optimizer = model.get_optimizer() - - iteration_timings = [] - iterator = iter(dataloader) - for _ in range(num_steps): - t0 = time.perf_counter() - - inputs, labels = next(iterator) - inputs, labels = inputs.to(device), labels.to(device) - optimizer.zero_grad() - outputs = model(inputs) - loss = loss_fn(outputs, labels) - loss.backward() - optimizer.step() - - t1 = time.perf_counter() - iteration_timings.append(t1 - t0) - - return torch.tensor(iteration_timings) - - -def train_fabric(num_steps=100, batch_size=4): - make_deterministic() - fabric = Fabric(accelerator="cpu") - fabric.launch() - - model = ConvNet() - dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) - loss_fn = model.get_loss_function() - optimizer = model.get_optimizer() - - model, optimizer = fabric.setup(model, optimizer) - dataloader = fabric.setup_dataloaders(dataloader) - - iteration_timings = [] - iterator = iter(dataloader) - for _ in range(num_steps): - t0 = time.perf_counter() - - inputs, labels = next(iterator) - optimizer.zero_grad() - outputs = model(inputs) - loss = loss_fn(outputs, labels) - fabric.backward(loss) - optimizer.step() - - t1 = time.perf_counter() - iteration_timings.append(t1 - t0) - - return torch.tensor(iteration_timings) - - -def launch_fabric(): - fabric = Fabric() - fabric.launch(train_fabric, **kwargs) - - -def test_parity_cpu(): - timings_torch = train_torch(num_steps=2000) - timings_fabric = train_fabric(num_steps=2000) - - # The median is more robust to outliers than the mean - assert torch.isclose(torch.median(timings_torch), torch.median(timings_fabric), rtol=1e-4, atol=1e-4) From 436a5e6854aada820d9f5404c78c2582f3410469 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 28 Feb 2023 12:11:48 +0100 Subject: [PATCH 29/86] debug --- tests/tests_fabric/conftest.py | 2 ++ tests/tests_fabric/parity/test_parity_simple.py | 7 ------- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py index 56d87ed960f1e..fb1f09274c3cf 100644 --- a/tests/tests_fabric/conftest.py +++ b/tests/tests_fabric/conftest.py @@ -55,6 +55,7 @@ def restore_env_variables(): "POPLAR_ENGINE_OPTIONS", # set by IPUStrategy "CUDA_MODULE_LOADING", # leaked since PyTorch 1.13 "CRC32C_SW_MODE", # set by tensorboardX + "CUBLAS_WORKSPACE_CONFIG", # handled by `reset_deterministic_algorithm` fixture below } leaked_vars.difference_update(allowlist) assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}" @@ -72,6 +73,7 @@ def teardown_process_group(): def reset_deterministic_algorithm(): """Ensures that torch determinism settings are reset before the next test runs.""" yield + os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None) torch.use_deterministic_algorithms(False) diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index 4c9bdf9682085..c5481851ecbc6 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -14,19 +14,13 @@ import os import time from copy import deepcopy -from functools import partial from typing import Callable import pytest import torch import torch.distributed -import torch.multiprocessing as mp import torch.nn.functional from tests_fabric.helpers.runif import RunIf -from torch import nn, Tensor -from torch.nn.parallel.distributed import DistributedDataParallel -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler from unittest import mock from lightning.fabric.fabric import Fabric @@ -124,7 +118,6 @@ def run(self, num_steps=NUM_STEPS_DEFAULT, batch_size=4, checkpoint_dir="."): pytest.param(32, "mps", marks=RunIf(mps=True)), ], ) -@mock.patch.dict(os.environ, {}, clear=True) def test_parity_single_device(precision, accelerator, tmpdir): fabric = FabricRunner(precision=precision, accelerator=accelerator, devices=1) fabric.run(checkpoint_dir=tmpdir) From 826be2037b54a753be1b86bb837e58a7f504f6d7 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 28 Feb 2023 12:11:48 +0100 Subject: [PATCH 30/86] Revert "debug" This reverts commit 436a5e6854aada820d9f5404c78c2582f3410469. --- tests/tests_fabric/conftest.py | 2 -- tests/tests_fabric/parity/test_parity_simple.py | 7 +++++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py index fb1f09274c3cf..56d87ed960f1e 100644 --- a/tests/tests_fabric/conftest.py +++ b/tests/tests_fabric/conftest.py @@ -55,7 +55,6 @@ def restore_env_variables(): "POPLAR_ENGINE_OPTIONS", # set by IPUStrategy "CUDA_MODULE_LOADING", # leaked since PyTorch 1.13 "CRC32C_SW_MODE", # set by tensorboardX - "CUBLAS_WORKSPACE_CONFIG", # handled by `reset_deterministic_algorithm` fixture below } leaked_vars.difference_update(allowlist) assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}" @@ -73,7 +72,6 @@ def teardown_process_group(): def reset_deterministic_algorithm(): """Ensures that torch determinism settings are reset before the next test runs.""" yield - os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None) torch.use_deterministic_algorithms(False) diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index c5481851ecbc6..4c9bdf9682085 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -14,13 +14,19 @@ import os import time from copy import deepcopy +from functools import partial from typing import Callable import pytest import torch import torch.distributed +import torch.multiprocessing as mp import torch.nn.functional from tests_fabric.helpers.runif import RunIf +from torch import nn, Tensor +from torch.nn.parallel.distributed import DistributedDataParallel +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler from unittest import mock from lightning.fabric.fabric import Fabric @@ -118,6 +124,7 @@ def run(self, num_steps=NUM_STEPS_DEFAULT, batch_size=4, checkpoint_dir="."): pytest.param(32, "mps", marks=RunIf(mps=True)), ], ) +@mock.patch.dict(os.environ, {}, clear=True) def test_parity_single_device(precision, accelerator, tmpdir): fabric = FabricRunner(precision=precision, accelerator=accelerator, devices=1) fabric.run(checkpoint_dir=tmpdir) From c9d5f19fdf5352e9b15a46ef4823d3341945ccb7 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 28 Feb 2023 07:31:05 +0100 Subject: [PATCH 31/86] Revert "refactor" This reverts commit 43b17e9efdb08e030c136061c02b6c9e1311ff80. --- tests/tests_fabric/parity/models.py | 20 +++ ...test_parity_ddp.py => test_correctness.py} | 108 +++++++++----- .../tests_fabric/parity/test_parity_simple.py | 141 ------------------ tests/tests_fabric/parity/test_timings.py | 90 +++++++++++ 4 files changed, 181 insertions(+), 178 deletions(-) rename tests/tests_fabric/parity/{test_parity_ddp.py => test_correctness.py} (56%) delete mode 100644 tests/tests_fabric/parity/test_parity_simple.py create mode 100644 tests/tests_fabric/parity/test_timings.py diff --git a/tests/tests_fabric/parity/models.py b/tests/tests_fabric/parity/models.py index 16e42928173b0..746a9b354f6c4 100644 --- a/tests/tests_fabric/parity/models.py +++ b/tests/tests_fabric/parity/models.py @@ -19,6 +19,7 @@ import torch import torch.nn.functional as F from torch.utils.data import TensorDataset, DataLoader +from tests_fabric.helpers.models import RandomDataset class ParityModel(ABC, nn.Module): @@ -36,6 +37,25 @@ def get_dataloader(self, *args, **kwargs) -> DataLoader: def get_loss_function(self) -> Callable: pass +# +# class BoringModel(ParityModel): +# def __init__(self): +# super().__init__() +# self.layer = torch.nn.Linear(32, 2, bias=False) +# +# def forward(self, x): +# x = self.layer(x) +# return torch.nn.functional.mse_loss(x, torch.ones_like(x)) +# +# def get_optimizer(self): +# return torch.optim.SGD(self.parameters(), lr=0.1) +# +# def get_dataloader(self, *args, **kwargs) -> DataLoader: +# return DataLoader(RandomDataset(32, 4)) +# +# def get_loss_function(self) -> Callable: +# pass + class ConvNet(ParityModel): def __init__(self): diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_correctness.py similarity index 56% rename from tests/tests_fabric/parity/test_parity_ddp.py rename to tests/tests_fabric/parity/test_correctness.py index b47effa70ffa4..5e67345e430c4 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_correctness.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -import time from copy import deepcopy from functools import partial from typing import Callable @@ -22,6 +21,7 @@ import torch.distributed import torch.multiprocessing as mp import torch.nn.functional +from lightning_utilities.core.apply_func import apply_to_collection from tests_fabric.helpers.runif import RunIf from torch import nn, Tensor from torch.nn.parallel.distributed import DistributedDataParallel @@ -30,19 +30,49 @@ from unittest import mock from lightning.fabric.fabric import Fabric +from lightning.fabric.plugins.environments.lightning import find_free_network_port +from lightning.fabric.strategies.ddp import DDPStrategy +from lightning.fabric.utilities.apply_func import move_data_to_device from lightning.fabric.utilities.cloud_io import _atomic_save from tests_fabric.parity.utils import precision_context, is_state_dict_equal, make_deterministic from tests_fabric.parity.models import ConvNet -NUM_STEPS_DEFAULT = 2000 + +def train_torch( + move_to_device: Callable, + precision_context, + num_steps=1, + batch_size=4, + checkpoint_dir=".", +): + make_deterministic() + model = ConvNet() + model = move_to_device(model) + dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) + optimizer = model.get_optimizer() + loss_fn = model.get_loss_function() + + model.train() + iterator = iter(dataloader) + for _ in range(num_steps): + inputs, labels = next(iterator) + inputs, labels = move_to_device(inputs), move_to_device(labels) + optimizer.zero_grad() + with precision_context(): + outputs = model(inputs) + loss = loss_fn(outputs, labels) + loss.backward() + optimizer.step() + + _atomic_save(model.state_dict(), os.path.join(checkpoint_dir, "torch_model.pt")) def train_torch_ddp( rank, world_size, device=torch.device("cpu"), - num_steps=NUM_STEPS_DEFAULT, + num_steps=1, batch_size=4, checkpoint_dir=".", ): @@ -52,47 +82,39 @@ def train_torch_ddp( if torch.distributed.is_available() and not torch.distributed.is_initialized(): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) - model = ConvNet() + model = ConvNet().to(device) initial_state_dict = deepcopy(model.state_dict()) ddp_model = DistributedDataParallel(model.to(device), device_ids=([rank] if device.type == "cuda" else None)) dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) sampler = DistributedSampler( - dataloader.dataset, rank=rank, num_replicas=world_size, seed=1, drop_last=False, shuffle=False + dataloader.dataset, rank=rank, num_replicas=world_size, drop_last=False, shuffle=False ) dataloader = DataLoader(dataloader.dataset, sampler=sampler) optimizer = model.get_optimizer() loss_fn = model.get_loss_function() - iteration_timings = [] - ddp_model.train() iterator = iter(dataloader) for _ in range(num_steps): - t0 = time.perf_counter() - inputs, labels = next(iterator) - inputs, labels = inputs.to(device), labels.to(device) + inputs, labels = move_to_device(inputs), move_to_device(labels) optimizer.zero_grad() outputs = ddp_model(inputs) loss = loss_fn(outputs, labels) loss.backward() optimizer.step() - t1 = time.perf_counter() - iteration_timings.append(t1 - t0) - # check that the model has changed assert not is_state_dict_equal(initial_state_dict, ddp_model.module.state_dict()) if rank == 0: - state = dict(state_dict=ddp_model.module.state_dict(), iteration_timings=torch.tensor(iteration_timings)) - _atomic_save(state, os.path.join(checkpoint_dir, "torch_model.pt")) + _atomic_save(ddp_model.module.state_dict(), os.path.join(checkpoint_dir, "torch_model.pt")) class FabricRunner(Fabric): - def run(self, num_steps=NUM_STEPS_DEFAULT, batch_size=4, checkpoint_dir="."): + def run(self, num_steps=1, batch_size=4, checkpoint_dir="."): make_deterministic() model = ConvNet() @@ -105,13 +127,9 @@ def run(self, num_steps=NUM_STEPS_DEFAULT, batch_size=4, checkpoint_dir="."): dataloader = self.setup_dataloaders(dataloader) loss_fn = model.get_loss_function() - iteration_timings = [] - model.train() iterator = iter(dataloader) for _ in range(num_steps): - t0 = time.perf_counter() - inputs, labels = next(iterator) optimizer.zero_grad() outputs = model(inputs) @@ -119,15 +137,34 @@ def run(self, num_steps=NUM_STEPS_DEFAULT, batch_size=4, checkpoint_dir="."): self.backward(loss) optimizer.step() - t1 = time.perf_counter() - iteration_timings.append(t1 - t0) - # check that the model has changed assert not is_state_dict_equal(initial_state_dict, model.state_dict()) if self.global_rank == 0: - state = dict(state_dict=model.state_dict(), iteration_timings=torch.tensor(iteration_timings)) - _atomic_save(state, os.path.join(checkpoint_dir, "fabric_model.pt")) + _atomic_save(model.state_dict(), os.path.join(checkpoint_dir, "fabric_model.pt")) + + +@pytest.mark.parametrize( + "precision, accelerator", + [ + (32, "cpu"), + pytest.param(32, "gpu", marks=RunIf(min_cuda_gpus=1)), + # pytest.param(16, "gpu", marks=RunIf(min_cuda_gpus=1)), # TODO: requires GradScaler + pytest.param("bf16", "gpu", marks=RunIf(min_cuda_gpus=1, bf16_cuda=True)), + pytest.param(32, "mps", marks=RunIf(mps=True)), + ], +) +@mock.patch.dict(os.environ, {}, clear=True) +def test_boring_fabric_model_single_device(precision, accelerator, tmpdir): + fabric = FabricRunner(precision=precision, accelerator=accelerator, devices=1) + fabric.run(checkpoint_dir=tmpdir) + + precision_ctx = partial(precision_context, precision=precision, accelerator=accelerator) + train_torch(fabric.to_device, precision_context=fabric.autocast, checkpoint_dir=tmpdir) + + fabric_state_dict = torch.load(os.path.join(tmpdir, "fabric_model.pt")) + torch_state_dict = torch.load(os.path.join(tmpdir, "torch_model.pt")) + assert is_state_dict_equal(torch_state_dict, fabric_state_dict) @RunIf(standalone=True) @@ -138,21 +175,18 @@ def run(self, num_steps=NUM_STEPS_DEFAULT, batch_size=4, checkpoint_dir="."): pytest.param(32, "ddp", 2, "gpu", marks=RunIf(min_cuda_gpus=2)), ], ) -def test_parity_ddp(precision, strategy, devices, accelerator, tmpdir): +@mock.patch.dict(os.environ, {}, clear=True) +def test_boring_fabric_model_ddp(precision, strategy, devices, accelerator, tmpdir): fabric = FabricRunner(precision=precision, strategy=strategy, devices=devices, accelerator=accelerator) fabric.run(checkpoint_dir=tmpdir) - train_torch_ddp( - rank=fabric.global_rank, world_size=fabric.world_size, device=fabric.device, checkpoint_dir=tmpdir - ) + with precision_context(precision, accelerator): + train_torch_ddp( + rank=fabric.global_rank, world_size=fabric.world_size, device=fabric.device, checkpoint_dir=tmpdir + ) tmpdir = fabric.broadcast(tmpdir) - fabric_results = torch.load(os.path.join(tmpdir, "fabric_model.pt")) - torch_results = torch.load(os.path.join(tmpdir, "torch_model.pt")) - assert is_state_dict_equal(fabric_results["state_dict"], torch_results["state_dict"]) - - timings_fabric = fabric_results["iteration_timings"] - timings_torch = torch_results["iteration_timings"] - # The median is more robust to outliers than the mean - assert torch.isclose(torch.median(timings_torch), torch.median(timings_fabric), rtol=1e-4, atol=1e-4) + fabric_state_dict = torch.load(os.path.join(tmpdir, "fabric_model.pt")) + torch_state_dict = torch.load(os.path.join(tmpdir, "torch_model.pt")) + assert is_state_dict_equal(torch_state_dict, fabric_state_dict) diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py deleted file mode 100644 index 4c9bdf9682085..0000000000000 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright The Lightning AI team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -import time -from copy import deepcopy -from functools import partial -from typing import Callable - -import pytest -import torch -import torch.distributed -import torch.multiprocessing as mp -import torch.nn.functional -from tests_fabric.helpers.runif import RunIf -from torch import nn, Tensor -from torch.nn.parallel.distributed import DistributedDataParallel -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler -from unittest import mock - -from lightning.fabric.fabric import Fabric -from lightning.fabric.utilities.cloud_io import _atomic_save - -from tests_fabric.parity.utils import precision_context, is_state_dict_equal, make_deterministic -from tests_fabric.parity.models import ConvNet - -NUM_STEPS_DEFAULT = 2000 - - -def train_torch( - move_to_device: Callable, - precision_context, - num_steps=NUM_STEPS_DEFAULT, - batch_size=4, - checkpoint_dir=".", -): - make_deterministic() - model = ConvNet() - model = move_to_device(model) - dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) - optimizer = model.get_optimizer() - loss_fn = model.get_loss_function() - - iteration_timings = [] - - model.train() - iterator = iter(dataloader) - for _ in range(num_steps): - t0 = time.perf_counter() - - inputs, labels = next(iterator) - inputs, labels = move_to_device(inputs), move_to_device(labels) - optimizer.zero_grad() - with precision_context(): - outputs = model(inputs) - loss = loss_fn(outputs, labels) - loss.backward() - optimizer.step() - - t1 = time.perf_counter() - iteration_timings.append(t1 - t0) - - state = dict(state_dict=model.state_dict(), iteration_timings=torch.tensor(iteration_timings)) - _atomic_save(state, os.path.join(checkpoint_dir, "torch_model.pt")) - - -class FabricRunner(Fabric): - def run(self, num_steps=NUM_STEPS_DEFAULT, batch_size=4, checkpoint_dir="."): - make_deterministic() - - model = ConvNet() - initial_state_dict = deepcopy(model.state_dict()) - - optimizer = model.get_optimizer() - model, optimizer = self.setup(model, optimizer) - - dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) - dataloader = self.setup_dataloaders(dataloader) - loss_fn = model.get_loss_function() - - iteration_timings = [] - - model.train() - iterator = iter(dataloader) - for _ in range(num_steps): - t0 = time.perf_counter() - - inputs, labels = next(iterator) - optimizer.zero_grad() - outputs = model(inputs) - loss = loss_fn(outputs, labels) - self.backward(loss) - optimizer.step() - - t1 = time.perf_counter() - iteration_timings.append(t1 - t0) - - # check that the model has changed - assert not is_state_dict_equal(initial_state_dict, model.state_dict()) - - if self.global_rank == 0: - state = dict(state_dict=model.state_dict(), iteration_timings=torch.tensor(iteration_timings)) - _atomic_save(state, os.path.join(checkpoint_dir, "fabric_model.pt")) - - -@pytest.mark.parametrize( - "precision, accelerator", - [ - (32, "cpu"), - pytest.param(32, "gpu", marks=RunIf(min_cuda_gpus=1)), - # pytest.param(16, "gpu", marks=RunIf(min_cuda_gpus=1)), # TODO: requires GradScaler - pytest.param("bf16", "gpu", marks=RunIf(min_cuda_gpus=1, bf16_cuda=True)), - pytest.param(32, "mps", marks=RunIf(mps=True)), - ], -) -@mock.patch.dict(os.environ, {}, clear=True) -def test_parity_single_device(precision, accelerator, tmpdir): - fabric = FabricRunner(precision=precision, accelerator=accelerator, devices=1) - fabric.run(checkpoint_dir=tmpdir) - - train_torch(fabric.to_device, precision_context=fabric.autocast, checkpoint_dir=tmpdir) - - fabric_results = torch.load(os.path.join(tmpdir, "fabric_model.pt")) - torch_results = torch.load(os.path.join(tmpdir, "torch_model.pt")) - assert is_state_dict_equal(fabric_results["state_dict"], torch_results["state_dict"]) - - timings_fabric = fabric_results["iteration_timings"] - timings_torch = torch_results["iteration_timings"] - # The median is more robust to outliers than the mean - assert torch.isclose(torch.median(timings_torch), torch.median(timings_fabric), rtol=1e-4, atol=1e-4) diff --git a/tests/tests_fabric/parity/test_timings.py b/tests/tests_fabric/parity/test_timings.py new file mode 100644 index 0000000000000..3de4a729fe9b4 --- /dev/null +++ b/tests/tests_fabric/parity/test_timings.py @@ -0,0 +1,90 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time + +from lightning.fabric import Fabric +import torch +from tests_fabric.parity.utils import make_deterministic +from tests_fabric.parity.models import ConvNet + + +def train_torch(rank=0, accelerator="cpu", devices=1, num_steps=100, batch_size=4): + make_deterministic() + device = torch.device("cuda" if accelerator == "cuda" else "cpu", rank) + model = ConvNet().to(device) + dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) + loss_fn = model.get_loss_function() + optimizer = model.get_optimizer() + + iteration_timings = [] + iterator = iter(dataloader) + for _ in range(num_steps): + t0 = time.perf_counter() + + inputs, labels = next(iterator) + inputs, labels = inputs.to(device), labels.to(device) + optimizer.zero_grad() + outputs = model(inputs) + loss = loss_fn(outputs, labels) + loss.backward() + optimizer.step() + + t1 = time.perf_counter() + iteration_timings.append(t1 - t0) + + return torch.tensor(iteration_timings) + + +def train_fabric(num_steps=100, batch_size=4): + make_deterministic() + fabric = Fabric(accelerator="cpu") + fabric.launch() + + model = ConvNet() + dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) + loss_fn = model.get_loss_function() + optimizer = model.get_optimizer() + + model, optimizer = fabric.setup(model, optimizer) + dataloader = fabric.setup_dataloaders(dataloader) + + iteration_timings = [] + iterator = iter(dataloader) + for _ in range(num_steps): + t0 = time.perf_counter() + + inputs, labels = next(iterator) + optimizer.zero_grad() + outputs = model(inputs) + loss = loss_fn(outputs, labels) + fabric.backward(loss) + optimizer.step() + + t1 = time.perf_counter() + iteration_timings.append(t1 - t0) + + return torch.tensor(iteration_timings) + + +def launch_fabric(): + fabric = Fabric() + fabric.launch(train_fabric, **kwargs) + + +def test_parity_cpu(): + timings_torch = train_torch(num_steps=2000) + timings_fabric = train_fabric(num_steps=2000) + + # The median is more robust to outliers than the mean + assert torch.isclose(torch.median(timings_torch), torch.median(timings_fabric), rtol=1e-4, atol=1e-4) From ddbb113a7363bf8be3b372d5077450e01cfe9063 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 12:19:31 +0100 Subject: [PATCH 32/86] update --- .../tests_fabric/parity/test_parity_simple.py | 122 ++++++++++++++++++ 1 file changed, 122 insertions(+) create mode 100644 tests/tests_fabric/parity/test_parity_simple.py diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py new file mode 100644 index 0000000000000..fdf78b4d34dd0 --- /dev/null +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -0,0 +1,122 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from copy import deepcopy +from functools import partial +from typing import Callable + +import pytest +import torch +import torch.distributed +import torch.multiprocessing as mp +import torch.nn.functional +from lightning_utilities.core.apply_func import apply_to_collection +from tests_fabric.helpers.runif import RunIf +from torch import nn, Tensor +from torch.nn.parallel.distributed import DistributedDataParallel +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from unittest import mock + +from lightning.fabric.fabric import Fabric +from lightning.fabric.plugins.environments.lightning import find_free_network_port +from lightning.fabric.strategies.ddp import DDPStrategy +from lightning.fabric.utilities.apply_func import move_data_to_device +from lightning.fabric.utilities.cloud_io import _atomic_save + +from tests_fabric.parity.utils import precision_context, is_state_dict_equal, make_deterministic +from tests_fabric.parity.models import ConvNet + + +def train_torch( + move_to_device: Callable, + precision_context, + num_steps=1, + batch_size=4, + checkpoint_dir=".", +): + make_deterministic() + model = ConvNet() + model = move_to_device(model) + dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) + optimizer = model.get_optimizer() + loss_fn = model.get_loss_function() + + model.train() + iterator = iter(dataloader) + for _ in range(num_steps): + inputs, labels = next(iterator) + inputs, labels = move_to_device(inputs), move_to_device(labels) + optimizer.zero_grad() + with precision_context(): + outputs = model(inputs) + loss = loss_fn(outputs, labels) + loss.backward() + optimizer.step() + + _atomic_save(model.state_dict(), os.path.join(checkpoint_dir, "torch_model.pt")) + + +class FabricRunner(Fabric): + def run(self, num_steps=1, batch_size=4, checkpoint_dir="."): + make_deterministic() + + model = ConvNet() + initial_state_dict = deepcopy(model.state_dict()) + + optimizer = model.get_optimizer() + model, optimizer = self.setup(model, optimizer) + + dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) + dataloader = self.setup_dataloaders(dataloader) + loss_fn = model.get_loss_function() + + model.train() + iterator = iter(dataloader) + for _ in range(num_steps): + inputs, labels = next(iterator) + optimizer.zero_grad() + outputs = model(inputs) + loss = loss_fn(outputs, labels) + self.backward(loss) + optimizer.step() + + # check that the model has changed + assert not is_state_dict_equal(initial_state_dict, model.state_dict()) + + if self.global_rank == 0: + _atomic_save(model.state_dict(), os.path.join(checkpoint_dir, "fabric_model.pt")) + + +@pytest.mark.parametrize( + "precision, accelerator", + [ + (32, "cpu"), + pytest.param(32, "gpu", marks=RunIf(min_cuda_gpus=1)), + # pytest.param(16, "gpu", marks=RunIf(min_cuda_gpus=1)), # TODO: requires GradScaler + pytest.param("bf16", "gpu", marks=RunIf(min_cuda_gpus=1, bf16_cuda=True)), + pytest.param(32, "mps", marks=RunIf(mps=True)), + ], +) +@mock.patch.dict(os.environ, {}, clear=True) +def test_boring_fabric_model_single_device(precision, accelerator, tmpdir): + fabric = FabricRunner(precision=precision, accelerator=accelerator, devices=1) + fabric.run(checkpoint_dir=tmpdir) + + precision_ctx = partial(precision_context, precision=precision, accelerator=accelerator) + train_torch(fabric.to_device, precision_context=fabric.autocast, checkpoint_dir=tmpdir) + + fabric_state_dict = torch.load(os.path.join(tmpdir, "fabric_model.pt")) + torch_state_dict = torch.load(os.path.join(tmpdir, "torch_model.pt")) + assert is_state_dict_equal(torch_state_dict, fabric_state_dict) From e8b79a5fd5e8dc40b6b0cde9515ab8b5bedccd92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 12:23:12 +0100 Subject: [PATCH 33/86] update --- tests/tests_fabric/conftest.py | 2 + ...test_correctness.py => test_parity_ddp.py} | 61 ++----------------- .../tests_fabric/parity/test_parity_simple.py | 12 ---- 3 files changed, 6 insertions(+), 69 deletions(-) rename tests/tests_fabric/parity/{test_correctness.py => test_parity_ddp.py} (69%) diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py index 56d87ed960f1e..c1de4566b18a4 100644 --- a/tests/tests_fabric/conftest.py +++ b/tests/tests_fabric/conftest.py @@ -55,6 +55,7 @@ def restore_env_variables(): "POPLAR_ENGINE_OPTIONS", # set by IPUStrategy "CUDA_MODULE_LOADING", # leaked since PyTorch 1.13 "CRC32C_SW_MODE", # set by tensorboardX + "CUBLAS_WORKSPACE_CONFIG", # handled by the `reset_deterministic_algorithm` fixture below } leaked_vars.difference_update(allowlist) assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}" @@ -72,6 +73,7 @@ def teardown_process_group(): def reset_deterministic_algorithm(): """Ensures that torch determinism settings are reset before the next test runs.""" yield + os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None) torch.use_deterministic_algorithms(False) diff --git a/tests/tests_fabric/parity/test_correctness.py b/tests/tests_fabric/parity/test_parity_ddp.py similarity index 69% rename from tests/tests_fabric/parity/test_correctness.py rename to tests/tests_fabric/parity/test_parity_ddp.py index 5e67345e430c4..a6649cf64a985 100644 --- a/tests/tests_fabric/parity/test_correctness.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -39,35 +39,6 @@ from tests_fabric.parity.models import ConvNet -def train_torch( - move_to_device: Callable, - precision_context, - num_steps=1, - batch_size=4, - checkpoint_dir=".", -): - make_deterministic() - model = ConvNet() - model = move_to_device(model) - dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) - optimizer = model.get_optimizer() - loss_fn = model.get_loss_function() - - model.train() - iterator = iter(dataloader) - for _ in range(num_steps): - inputs, labels = next(iterator) - inputs, labels = move_to_device(inputs), move_to_device(labels) - optimizer.zero_grad() - with precision_context(): - outputs = model(inputs) - loss = loss_fn(outputs, labels) - loss.backward() - optimizer.step() - - _atomic_save(model.state_dict(), os.path.join(checkpoint_dir, "torch_model.pt")) - - def train_torch_ddp( rank, world_size, @@ -99,7 +70,7 @@ def train_torch_ddp( iterator = iter(dataloader) for _ in range(num_steps): inputs, labels = next(iterator) - inputs, labels = move_to_device(inputs), move_to_device(labels) + inputs, labels = inputs.to(device), inputs.to(labels) optimizer.zero_grad() outputs = ddp_model(inputs) loss = loss_fn(outputs, labels) @@ -144,29 +115,6 @@ def run(self, num_steps=1, batch_size=4, checkpoint_dir="."): _atomic_save(model.state_dict(), os.path.join(checkpoint_dir, "fabric_model.pt")) -@pytest.mark.parametrize( - "precision, accelerator", - [ - (32, "cpu"), - pytest.param(32, "gpu", marks=RunIf(min_cuda_gpus=1)), - # pytest.param(16, "gpu", marks=RunIf(min_cuda_gpus=1)), # TODO: requires GradScaler - pytest.param("bf16", "gpu", marks=RunIf(min_cuda_gpus=1, bf16_cuda=True)), - pytest.param(32, "mps", marks=RunIf(mps=True)), - ], -) -@mock.patch.dict(os.environ, {}, clear=True) -def test_boring_fabric_model_single_device(precision, accelerator, tmpdir): - fabric = FabricRunner(precision=precision, accelerator=accelerator, devices=1) - fabric.run(checkpoint_dir=tmpdir) - - precision_ctx = partial(precision_context, precision=precision, accelerator=accelerator) - train_torch(fabric.to_device, precision_context=fabric.autocast, checkpoint_dir=tmpdir) - - fabric_state_dict = torch.load(os.path.join(tmpdir, "fabric_model.pt")) - torch_state_dict = torch.load(os.path.join(tmpdir, "torch_model.pt")) - assert is_state_dict_equal(torch_state_dict, fabric_state_dict) - - @RunIf(standalone=True) @pytest.mark.parametrize( "precision, strategy, devices, accelerator", @@ -180,10 +128,9 @@ def test_boring_fabric_model_ddp(precision, strategy, devices, accelerator, tmpd fabric = FabricRunner(precision=precision, strategy=strategy, devices=devices, accelerator=accelerator) fabric.run(checkpoint_dir=tmpdir) - with precision_context(precision, accelerator): - train_torch_ddp( - rank=fabric.global_rank, world_size=fabric.world_size, device=fabric.device, checkpoint_dir=tmpdir - ) + train_torch_ddp( + rank=fabric.global_rank, world_size=fabric.world_size, device=fabric.device, checkpoint_dir=tmpdir + ) tmpdir = fabric.broadcast(tmpdir) diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index fdf78b4d34dd0..2273a054c4bc7 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -13,26 +13,16 @@ # limitations under the License. import os from copy import deepcopy -from functools import partial from typing import Callable import pytest import torch import torch.distributed -import torch.multiprocessing as mp import torch.nn.functional -from lightning_utilities.core.apply_func import apply_to_collection from tests_fabric.helpers.runif import RunIf -from torch import nn, Tensor -from torch.nn.parallel.distributed import DistributedDataParallel -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler from unittest import mock from lightning.fabric.fabric import Fabric -from lightning.fabric.plugins.environments.lightning import find_free_network_port -from lightning.fabric.strategies.ddp import DDPStrategy -from lightning.fabric.utilities.apply_func import move_data_to_device from lightning.fabric.utilities.cloud_io import _atomic_save from tests_fabric.parity.utils import precision_context, is_state_dict_equal, make_deterministic @@ -109,12 +99,10 @@ def run(self, num_steps=1, batch_size=4, checkpoint_dir="."): pytest.param(32, "mps", marks=RunIf(mps=True)), ], ) -@mock.patch.dict(os.environ, {}, clear=True) def test_boring_fabric_model_single_device(precision, accelerator, tmpdir): fabric = FabricRunner(precision=precision, accelerator=accelerator, devices=1) fabric.run(checkpoint_dir=tmpdir) - precision_ctx = partial(precision_context, precision=precision, accelerator=accelerator) train_torch(fabric.to_device, precision_context=fabric.autocast, checkpoint_dir=tmpdir) fabric_state_dict = torch.load(os.path.join(tmpdir, "fabric_model.pt")) From c1dff2123bf5921762ecbd04993b4af97990f1f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 12:24:32 +0100 Subject: [PATCH 34/86] update --- tests/tests_fabric/parity/test_parity_ddp.py | 6 ++++-- tests/tests_fabric/parity/test_parity_simple.py | 7 ++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index a6649cf64a985..383bfc057db86 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -38,12 +38,14 @@ from tests_fabric.parity.utils import precision_context, is_state_dict_equal, make_deterministic from tests_fabric.parity.models import ConvNet +NUM_STEPS_DEFAULT = 100 + def train_torch_ddp( rank, world_size, device=torch.device("cpu"), - num_steps=1, + num_steps=NUM_STEPS_DEFAULT, batch_size=4, checkpoint_dir=".", ): @@ -85,7 +87,7 @@ def train_torch_ddp( class FabricRunner(Fabric): - def run(self, num_steps=1, batch_size=4, checkpoint_dir="."): + def run(self, num_steps=NUM_STEPS_DEFAULT, batch_size=4, checkpoint_dir="."): make_deterministic() model = ConvNet() diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index 2273a054c4bc7..9181b6fc62832 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -20,7 +20,6 @@ import torch.distributed import torch.nn.functional from tests_fabric.helpers.runif import RunIf -from unittest import mock from lightning.fabric.fabric import Fabric from lightning.fabric.utilities.cloud_io import _atomic_save @@ -28,11 +27,13 @@ from tests_fabric.parity.utils import precision_context, is_state_dict_equal, make_deterministic from tests_fabric.parity.models import ConvNet +NUM_STEPS_DEFAULT = 100 + def train_torch( move_to_device: Callable, precision_context, - num_steps=1, + num_steps=NUM_STEPS_DEFAULT, batch_size=4, checkpoint_dir=".", ): @@ -59,7 +60,7 @@ def train_torch( class FabricRunner(Fabric): - def run(self, num_steps=1, batch_size=4, checkpoint_dir="."): + def run(self, num_steps=NUM_STEPS_DEFAULT, batch_size=4, checkpoint_dir="."): make_deterministic() model = ConvNet() From ad369b842a0551b42c6368c98f95996a2f9ce938 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 12:39:01 +0100 Subject: [PATCH 35/86] update --- tests/tests_fabric/parity/test_parity_simple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index 9181b6fc62832..10707007b4965 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -27,7 +27,7 @@ from tests_fabric.parity.utils import precision_context, is_state_dict_equal, make_deterministic from tests_fabric.parity.models import ConvNet -NUM_STEPS_DEFAULT = 100 +NUM_STEPS_DEFAULT = 32 def train_torch( From 0b35ef81453ba5b50555bf2e4f419e3cee20c556 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 12:39:51 +0100 Subject: [PATCH 36/86] update --- tests/tests_fabric/parity/test_parity_simple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index 10707007b4965..f44b35337132c 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -27,7 +27,7 @@ from tests_fabric.parity.utils import precision_context, is_state_dict_equal, make_deterministic from tests_fabric.parity.models import ConvNet -NUM_STEPS_DEFAULT = 32 +NUM_STEPS_DEFAULT = 2 def train_torch( From 68b488874416316e1df6d013173858e6babb70fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 12:57:56 +0100 Subject: [PATCH 37/86] update --- tests/tests_fabric/conftest.py | 2 +- tests/tests_fabric/parity/test_parity_ddp.py | 66 ++++++++----------- .../tests_fabric/parity/test_parity_simple.py | 61 ++++++++--------- 3 files changed, 57 insertions(+), 72 deletions(-) diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py index c1de4566b18a4..42a364ea4f9a3 100644 --- a/tests/tests_fabric/conftest.py +++ b/tests/tests_fabric/conftest.py @@ -55,7 +55,7 @@ def restore_env_variables(): "POPLAR_ENGINE_OPTIONS", # set by IPUStrategy "CUDA_MODULE_LOADING", # leaked since PyTorch 1.13 "CRC32C_SW_MODE", # set by tensorboardX - "CUBLAS_WORKSPACE_CONFIG", # handled by the `reset_deterministic_algorithm` fixture below + "CUBLAS_WORKSPACE_CONFIG", # handled by the `reset_deterministic_algorithm` fixture below } leaked_vars.difference_update(allowlist) assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}" diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index 383bfc057db86..eb445338a1a8f 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -47,7 +47,6 @@ def train_torch_ddp( device=torch.device("cpu"), num_steps=NUM_STEPS_DEFAULT, batch_size=4, - checkpoint_dir=".", ): make_deterministic() @@ -60,7 +59,7 @@ def train_torch_ddp( ddp_model = DistributedDataParallel(model.to(device), device_ids=([rank] if device.type == "cuda" else None)) - dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) + dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size * world_size), batch_size=batch_size) sampler = DistributedSampler( dataloader.dataset, rank=rank, num_replicas=world_size, drop_last=False, shuffle=False ) @@ -82,39 +81,36 @@ def train_torch_ddp( # check that the model has changed assert not is_state_dict_equal(initial_state_dict, ddp_model.module.state_dict()) - if rank == 0: - _atomic_save(ddp_model.module.state_dict(), os.path.join(checkpoint_dir, "torch_model.pt")) + return ddp_model.module.state_dict() -class FabricRunner(Fabric): - def run(self, num_steps=NUM_STEPS_DEFAULT, batch_size=4, checkpoint_dir="."): - make_deterministic() +def train_fabric_ddp(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): + make_deterministic() - model = ConvNet() - initial_state_dict = deepcopy(model.state_dict()) + model = ConvNet() + initial_state_dict = deepcopy(model.state_dict()) - optimizer = model.get_optimizer() - model, optimizer = self.setup(model, optimizer) + optimizer = model.get_optimizer() + model, optimizer = fabric.setup(model, optimizer) - dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) - dataloader = self.setup_dataloaders(dataloader) - loss_fn = model.get_loss_function() + dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size * fabric.world_size), batch_size=batch_size) + dataloader = fabric.setup_dataloaders(dataloader) + loss_fn = model.get_loss_function() - model.train() - iterator = iter(dataloader) - for _ in range(num_steps): - inputs, labels = next(iterator) - optimizer.zero_grad() - outputs = model(inputs) - loss = loss_fn(outputs, labels) - self.backward(loss) - optimizer.step() + model.train() + iterator = iter(dataloader) + for _ in range(num_steps): + inputs, labels = next(iterator) + optimizer.zero_grad() + outputs = model(inputs) + loss = loss_fn(outputs, labels) + fabric.backward(loss) + optimizer.step() - # check that the model has changed - assert not is_state_dict_equal(initial_state_dict, model.state_dict()) + # check that the model has changed + assert not is_state_dict_equal(initial_state_dict, model.state_dict()) - if self.global_rank == 0: - _atomic_save(model.state_dict(), os.path.join(checkpoint_dir, "fabric_model.pt")) + return model.state_dict() @RunIf(standalone=True) @@ -125,17 +121,13 @@ def run(self, num_steps=NUM_STEPS_DEFAULT, batch_size=4, checkpoint_dir="."): pytest.param(32, "ddp", 2, "gpu", marks=RunIf(min_cuda_gpus=2)), ], ) -@mock.patch.dict(os.environ, {}, clear=True) -def test_boring_fabric_model_ddp(precision, strategy, devices, accelerator, tmpdir): - fabric = FabricRunner(precision=precision, strategy=strategy, devices=devices, accelerator=accelerator) - fabric.run(checkpoint_dir=tmpdir) +def test_parity_ddp(precision, strategy, devices, accelerator, tmpdir): + fabric = Fabric(precision=precision, strategy=strategy, devices=devices, accelerator=accelerator) + fabric.launch() - train_torch_ddp( - rank=fabric.global_rank, world_size=fabric.world_size, device=fabric.device, checkpoint_dir=tmpdir + fabric_state_dict = train_fabric_ddp(fabric) + torch_state_dict = train_torch_ddp( + rank=fabric.global_rank, world_size=fabric.world_size, device=fabric.device ) - tmpdir = fabric.broadcast(tmpdir) - - fabric_state_dict = torch.load(os.path.join(tmpdir, "fabric_model.pt")) - torch_state_dict = torch.load(os.path.join(tmpdir, "torch_model.pt")) assert is_state_dict_equal(torch_state_dict, fabric_state_dict) diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index f44b35337132c..0897f73d7843c 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -22,12 +22,10 @@ from tests_fabric.helpers.runif import RunIf from lightning.fabric.fabric import Fabric -from lightning.fabric.utilities.cloud_io import _atomic_save - from tests_fabric.parity.utils import precision_context, is_state_dict_equal, make_deterministic from tests_fabric.parity.models import ConvNet -NUM_STEPS_DEFAULT = 2 +NUM_STEPS_DEFAULT = 100 def train_torch( @@ -35,7 +33,6 @@ def train_torch( precision_context, num_steps=NUM_STEPS_DEFAULT, batch_size=4, - checkpoint_dir=".", ): make_deterministic() model = ConvNet() @@ -56,38 +53,36 @@ def train_torch( loss.backward() optimizer.step() - _atomic_save(model.state_dict(), os.path.join(checkpoint_dir, "torch_model.pt")) + return model.state_dict() -class FabricRunner(Fabric): - def run(self, num_steps=NUM_STEPS_DEFAULT, batch_size=4, checkpoint_dir="."): - make_deterministic() +def train_fabric(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): + make_deterministic() - model = ConvNet() - initial_state_dict = deepcopy(model.state_dict()) + model = ConvNet() + initial_state_dict = deepcopy(model.state_dict()) - optimizer = model.get_optimizer() - model, optimizer = self.setup(model, optimizer) + optimizer = model.get_optimizer() + model, optimizer = fabric.setup(model, optimizer) - dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) - dataloader = self.setup_dataloaders(dataloader) - loss_fn = model.get_loss_function() + dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) + dataloader = fabric.setup_dataloaders(dataloader) + loss_fn = model.get_loss_function() - model.train() - iterator = iter(dataloader) - for _ in range(num_steps): - inputs, labels = next(iterator) - optimizer.zero_grad() - outputs = model(inputs) - loss = loss_fn(outputs, labels) - self.backward(loss) - optimizer.step() + model.train() + iterator = iter(dataloader) + for _ in range(num_steps): + inputs, labels = next(iterator) + optimizer.zero_grad() + outputs = model(inputs) + loss = loss_fn(outputs, labels) + fabric.backward(loss) + optimizer.step() - # check that the model has changed - assert not is_state_dict_equal(initial_state_dict, model.state_dict()) + # check that the model has changed + assert not is_state_dict_equal(initial_state_dict, model.state_dict()) - if self.global_rank == 0: - _atomic_save(model.state_dict(), os.path.join(checkpoint_dir, "fabric_model.pt")) + return model.state_dict() @pytest.mark.parametrize( @@ -100,12 +95,10 @@ def run(self, num_steps=NUM_STEPS_DEFAULT, batch_size=4, checkpoint_dir="."): pytest.param(32, "mps", marks=RunIf(mps=True)), ], ) -def test_boring_fabric_model_single_device(precision, accelerator, tmpdir): - fabric = FabricRunner(precision=precision, accelerator=accelerator, devices=1) - fabric.run(checkpoint_dir=tmpdir) +def test_parity_single_device(precision, accelerator, tmpdir): + fabric = Fabric(precision=precision, accelerator=accelerator, devices=1) - train_torch(fabric.to_device, precision_context=fabric.autocast, checkpoint_dir=tmpdir) + fabric_state_dict = train_fabric(fabric) + torch_state_dict = train_torch(fabric.to_device, precision_context=fabric.autocast) - fabric_state_dict = torch.load(os.path.join(tmpdir, "fabric_model.pt")) - torch_state_dict = torch.load(os.path.join(tmpdir, "torch_model.pt")) assert is_state_dict_equal(torch_state_dict, fabric_state_dict) From 41e7a25455fba8e5c7153f3fb060673a0e9f4319 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 13:01:02 +0100 Subject: [PATCH 38/86] update --- tests/tests_fabric/parity/test_parity_ddp.py | 1 - .../tests_fabric/parity/test_parity_simple.py | 24 +++++++++++++++---- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index eb445338a1a8f..1b7758a5b0979 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -33,7 +33,6 @@ from lightning.fabric.plugins.environments.lightning import find_free_network_port from lightning.fabric.strategies.ddp import DDPStrategy from lightning.fabric.utilities.apply_func import move_data_to_device -from lightning.fabric.utilities.cloud_io import _atomic_save from tests_fabric.parity.utils import precision_context, is_state_dict_equal, make_deterministic from tests_fabric.parity.models import ConvNet diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index 0897f73d7843c..1a5545d3b23ce 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import time from copy import deepcopy from typing import Callable @@ -42,8 +43,11 @@ def train_torch( loss_fn = model.get_loss_function() model.train() + iteration_timings = [] iterator = iter(dataloader) for _ in range(num_steps): + t0 = time.perf_counter() + inputs, labels = next(iterator) inputs, labels = move_to_device(inputs), move_to_device(labels) optimizer.zero_grad() @@ -53,7 +57,10 @@ def train_torch( loss.backward() optimizer.step() - return model.state_dict() + t1 = time.perf_counter() + iteration_timings.append(t1 - t0) + + return model.state_dict(), torch.tensor(iteration_timings) def train_fabric(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): @@ -70,8 +77,11 @@ def train_fabric(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): loss_fn = model.get_loss_function() model.train() + iteration_timings = [] iterator = iter(dataloader) for _ in range(num_steps): + t0 = time.perf_counter() + inputs, labels = next(iterator) optimizer.zero_grad() outputs = model(inputs) @@ -79,10 +89,13 @@ def train_fabric(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): fabric.backward(loss) optimizer.step() + t1 = time.perf_counter() + iteration_timings.append(t1 - t0) + # check that the model has changed assert not is_state_dict_equal(initial_state_dict, model.state_dict()) - return model.state_dict() + return model.state_dict(), torch.tensor(iteration_timings) @pytest.mark.parametrize( @@ -98,7 +111,10 @@ def train_fabric(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): def test_parity_single_device(precision, accelerator, tmpdir): fabric = Fabric(precision=precision, accelerator=accelerator, devices=1) - fabric_state_dict = train_fabric(fabric) - torch_state_dict = train_torch(fabric.to_device, precision_context=fabric.autocast) + fabric_state_dict, timings_fabric = train_fabric(fabric) + torch_state_dict, timings_torch = train_torch(fabric.to_device, precision_context=fabric.autocast) assert is_state_dict_equal(torch_state_dict, fabric_state_dict) + + # The median is more robust to outliers than the mean + assert torch.isclose(torch.median(timings_torch), torch.median(timings_fabric), rtol=1e-4, atol=1e-4) From 929d6045ffdb24066bf60551a1e68326175a505d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 13:07:57 +0100 Subject: [PATCH 39/86] update --- tests/tests_fabric/parity/test_parity_ddp.py | 33 +++++-- .../tests_fabric/parity/test_parity_simple.py | 9 +- tests/tests_fabric/parity/test_timings.py | 90 ------------------- 3 files changed, 35 insertions(+), 97 deletions(-) delete mode 100644 tests/tests_fabric/parity/test_timings.py diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index 1b7758a5b0979..823ce9511c7bf 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import time from copy import deepcopy from functools import partial from typing import Callable @@ -67,8 +68,11 @@ def train_torch_ddp( loss_fn = model.get_loss_function() ddp_model.train() + iteration_timings = [] iterator = iter(dataloader) for _ in range(num_steps): + t0 = time.perf_counter() + inputs, labels = next(iterator) inputs, labels = inputs.to(device), inputs.to(labels) optimizer.zero_grad() @@ -77,10 +81,13 @@ def train_torch_ddp( loss.backward() optimizer.step() + t1 = time.perf_counter() + iteration_timings.append(t1 - t0) + # check that the model has changed assert not is_state_dict_equal(initial_state_dict, ddp_model.module.state_dict()) - return ddp_model.module.state_dict() + return ddp_model.module.state_dict(), torch.tensor(iteration_timings) def train_fabric_ddp(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): @@ -97,8 +104,11 @@ def train_fabric_ddp(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): loss_fn = model.get_loss_function() model.train() + iteration_timings = [] iterator = iter(dataloader) for _ in range(num_steps): + t0 = time.perf_counter() + inputs, labels = next(iterator) optimizer.zero_grad() outputs = model(inputs) @@ -106,13 +116,17 @@ def train_fabric_ddp(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): fabric.backward(loss) optimizer.step() + t1 = time.perf_counter() + iteration_timings.append(t1 - t0) + # check that the model has changed assert not is_state_dict_equal(initial_state_dict, model.state_dict()) - return model.state_dict() + return model.state_dict(), torch.tensor(iteration_timings) @RunIf(standalone=True) +# @pytest.mark.flaky(reruns=3) @pytest.mark.parametrize( "precision, strategy, devices, accelerator", [ @@ -121,12 +135,21 @@ def train_fabric_ddp(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): ], ) def test_parity_ddp(precision, strategy, devices, accelerator, tmpdir): + # Train with Fabric fabric = Fabric(precision=precision, strategy=strategy, devices=devices, accelerator=accelerator) fabric.launch() + fabric_state_dict, timings_fabric = train_fabric_ddp(fabric) - fabric_state_dict = train_fabric_ddp(fabric) - torch_state_dict = train_torch_ddp( - rank=fabric.global_rank, world_size=fabric.world_size, device=fabric.device + # Train with raw PyTorch + torch_state_dict, timings_torch = train_torch_ddp( + rank=fabric.global_rank, + world_size=fabric.world_size, + device=fabric.device, ) + # Compare the final weights assert is_state_dict_equal(torch_state_dict, fabric_state_dict) + + # Compare the time per iteration + # The median is more robust to outliers than the mean + assert torch.isclose(torch.median(timings_torch), torch.median(timings_fabric), rtol=1e-4, atol=1e-4) diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index 1a5545d3b23ce..2218e19e777d8 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -26,7 +26,7 @@ from tests_fabric.parity.utils import precision_context, is_state_dict_equal, make_deterministic from tests_fabric.parity.models import ConvNet -NUM_STEPS_DEFAULT = 100 +NUM_STEPS_DEFAULT = 1000 def train_torch( @@ -98,6 +98,7 @@ def train_fabric(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): return model.state_dict(), torch.tensor(iteration_timings) +@pytest.mark.flaky(reruns=3) @pytest.mark.parametrize( "precision, accelerator", [ @@ -109,12 +110,16 @@ def train_fabric(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): ], ) def test_parity_single_device(precision, accelerator, tmpdir): + # Train with Fabric fabric = Fabric(precision=precision, accelerator=accelerator, devices=1) - fabric_state_dict, timings_fabric = train_fabric(fabric) + + # Train with raw PyTorch torch_state_dict, timings_torch = train_torch(fabric.to_device, precision_context=fabric.autocast) + # Compare the final weights assert is_state_dict_equal(torch_state_dict, fabric_state_dict) + # Compare the time per iteration # The median is more robust to outliers than the mean assert torch.isclose(torch.median(timings_torch), torch.median(timings_fabric), rtol=1e-4, atol=1e-4) diff --git a/tests/tests_fabric/parity/test_timings.py b/tests/tests_fabric/parity/test_timings.py deleted file mode 100644 index 3de4a729fe9b4..0000000000000 --- a/tests/tests_fabric/parity/test_timings.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright The Lightning AI team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import time - -from lightning.fabric import Fabric -import torch -from tests_fabric.parity.utils import make_deterministic -from tests_fabric.parity.models import ConvNet - - -def train_torch(rank=0, accelerator="cpu", devices=1, num_steps=100, batch_size=4): - make_deterministic() - device = torch.device("cuda" if accelerator == "cuda" else "cpu", rank) - model = ConvNet().to(device) - dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) - loss_fn = model.get_loss_function() - optimizer = model.get_optimizer() - - iteration_timings = [] - iterator = iter(dataloader) - for _ in range(num_steps): - t0 = time.perf_counter() - - inputs, labels = next(iterator) - inputs, labels = inputs.to(device), labels.to(device) - optimizer.zero_grad() - outputs = model(inputs) - loss = loss_fn(outputs, labels) - loss.backward() - optimizer.step() - - t1 = time.perf_counter() - iteration_timings.append(t1 - t0) - - return torch.tensor(iteration_timings) - - -def train_fabric(num_steps=100, batch_size=4): - make_deterministic() - fabric = Fabric(accelerator="cpu") - fabric.launch() - - model = ConvNet() - dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) - loss_fn = model.get_loss_function() - optimizer = model.get_optimizer() - - model, optimizer = fabric.setup(model, optimizer) - dataloader = fabric.setup_dataloaders(dataloader) - - iteration_timings = [] - iterator = iter(dataloader) - for _ in range(num_steps): - t0 = time.perf_counter() - - inputs, labels = next(iterator) - optimizer.zero_grad() - outputs = model(inputs) - loss = loss_fn(outputs, labels) - fabric.backward(loss) - optimizer.step() - - t1 = time.perf_counter() - iteration_timings.append(t1 - t0) - - return torch.tensor(iteration_timings) - - -def launch_fabric(): - fabric = Fabric() - fabric.launch(train_fabric, **kwargs) - - -def test_parity_cpu(): - timings_torch = train_torch(num_steps=2000) - timings_fabric = train_fabric(num_steps=2000) - - # The median is more robust to outliers than the mean - assert torch.isclose(torch.median(timings_torch), torch.median(timings_fabric), rtol=1e-4, atol=1e-4) From 71c77ca50d588a1024110a5f4159567d9b9818e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 13:08:18 +0100 Subject: [PATCH 40/86] update --- tests/tests_fabric/parity/models.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/tests/tests_fabric/parity/models.py b/tests/tests_fabric/parity/models.py index 746a9b354f6c4..e1b9030753cfb 100644 --- a/tests/tests_fabric/parity/models.py +++ b/tests/tests_fabric/parity/models.py @@ -37,25 +37,6 @@ def get_dataloader(self, *args, **kwargs) -> DataLoader: def get_loss_function(self) -> Callable: pass -# -# class BoringModel(ParityModel): -# def __init__(self): -# super().__init__() -# self.layer = torch.nn.Linear(32, 2, bias=False) -# -# def forward(self, x): -# x = self.layer(x) -# return torch.nn.functional.mse_loss(x, torch.ones_like(x)) -# -# def get_optimizer(self): -# return torch.optim.SGD(self.parameters(), lr=0.1) -# -# def get_dataloader(self, *args, **kwargs) -> DataLoader: -# return DataLoader(RandomDataset(32, 4)) -# -# def get_loss_function(self) -> Callable: -# pass - class ConvNet(ParityModel): def __init__(self): From 727515f0ab006ea40c1258dce26c4cc9233eab9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 28 Feb 2023 13:10:31 +0100 Subject: [PATCH 41/86] update --- tests/tests_fabric/parity/models.py | 1 - tests/tests_fabric/parity/test_parity_ddp.py | 2 +- tests/tests_fabric/parity/test_parity_simple.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/tests_fabric/parity/models.py b/tests/tests_fabric/parity/models.py index e1b9030753cfb..16e42928173b0 100644 --- a/tests/tests_fabric/parity/models.py +++ b/tests/tests_fabric/parity/models.py @@ -19,7 +19,6 @@ import torch import torch.nn.functional as F from torch.utils.data import TensorDataset, DataLoader -from tests_fabric.helpers.models import RandomDataset class ParityModel(ABC, nn.Module): diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index 823ce9511c7bf..a504965b01b72 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -134,7 +134,7 @@ def train_fabric_ddp(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): pytest.param(32, "ddp", 2, "gpu", marks=RunIf(min_cuda_gpus=2)), ], ) -def test_parity_ddp(precision, strategy, devices, accelerator, tmpdir): +def test_parity_ddp(precision, strategy, devices, accelerator): # Train with Fabric fabric = Fabric(precision=precision, strategy=strategy, devices=devices, accelerator=accelerator) fabric.launch() diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index 2218e19e777d8..eb4e8efc22fd0 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -109,7 +109,7 @@ def train_fabric(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): pytest.param(32, "mps", marks=RunIf(mps=True)), ], ) -def test_parity_single_device(precision, accelerator, tmpdir): +def test_parity_single_device(precision, accelerator): # Train with Fabric fabric = Fabric(precision=precision, accelerator=accelerator, devices=1) fabric_state_dict, timings_fabric = train_fabric(fabric) From 1771443ec77869c43e9f224244d6d7d396a18f2e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 28 Feb 2023 12:12:12 +0000 Subject: [PATCH 42/86] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_fabric/parity/models.py | 6 +++--- tests/tests_fabric/parity/test_parity_ddp.py | 18 +++--------------- .../tests_fabric/parity/test_parity_simple.py | 5 ++--- 3 files changed, 8 insertions(+), 21 deletions(-) diff --git a/tests/tests_fabric/parity/models.py b/tests/tests_fabric/parity/models.py index 16e42928173b0..e24c697ab2582 100644 --- a/tests/tests_fabric/parity/models.py +++ b/tests/tests_fabric/parity/models.py @@ -14,11 +14,11 @@ from abc import ABC, abstractmethod from typing import Callable -import torch.nn as nn -from torch.optim import Optimizer import torch +import torch.nn as nn import torch.nn.functional as F -from torch.utils.data import TensorDataset, DataLoader +from torch.optim import Optimizer +from torch.utils.data import DataLoader, TensorDataset class ParityModel(ABC, nn.Module): diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index a504965b01b72..5aee8086729ca 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -14,29 +14,19 @@ import os import time from copy import deepcopy -from functools import partial -from typing import Callable import pytest import torch import torch.distributed -import torch.multiprocessing as mp import torch.nn.functional -from lightning_utilities.core.apply_func import apply_to_collection from tests_fabric.helpers.runif import RunIf -from torch import nn, Tensor +from tests_fabric.parity.models import ConvNet +from tests_fabric.parity.utils import is_state_dict_equal, make_deterministic from torch.nn.parallel.distributed import DistributedDataParallel from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from unittest import mock from lightning.fabric.fabric import Fabric -from lightning.fabric.plugins.environments.lightning import find_free_network_port -from lightning.fabric.strategies.ddp import DDPStrategy -from lightning.fabric.utilities.apply_func import move_data_to_device - -from tests_fabric.parity.utils import precision_context, is_state_dict_equal, make_deterministic -from tests_fabric.parity.models import ConvNet NUM_STEPS_DEFAULT = 100 @@ -60,9 +50,7 @@ def train_torch_ddp( ddp_model = DistributedDataParallel(model.to(device), device_ids=([rank] if device.type == "cuda" else None)) dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size * world_size), batch_size=batch_size) - sampler = DistributedSampler( - dataloader.dataset, rank=rank, num_replicas=world_size, drop_last=False, shuffle=False - ) + sampler = DistributedSampler(dataloader.dataset, rank=rank, num_replicas=world_size, drop_last=False, shuffle=False) dataloader = DataLoader(dataloader.dataset, sampler=sampler) optimizer = model.get_optimizer() loss_fn = model.get_loss_function() diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index eb4e8efc22fd0..da26928041888 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os import time from copy import deepcopy from typing import Callable @@ -21,10 +20,10 @@ import torch.distributed import torch.nn.functional from tests_fabric.helpers.runif import RunIf +from tests_fabric.parity.models import ConvNet +from tests_fabric.parity.utils import is_state_dict_equal, make_deterministic from lightning.fabric.fabric import Fabric -from tests_fabric.parity.utils import precision_context, is_state_dict_equal, make_deterministic -from tests_fabric.parity.models import ConvNet NUM_STEPS_DEFAULT = 1000 From bc37136363c8df9b5f5359a960204c3d2bc9fb57 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 28 Feb 2023 13:16:08 +0100 Subject: [PATCH 43/86] delete --- tests/tests_fabric/parity/utils.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/tests/tests_fabric/parity/utils.py b/tests/tests_fabric/parity/utils.py index 75e232365909c..dc06bd1e8ee0b 100644 --- a/tests/tests_fabric/parity/utils.py +++ b/tests/tests_fabric/parity/utils.py @@ -12,9 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from contextlib import contextmanager -from typing import Generator - import torch @@ -25,19 +22,6 @@ def make_deterministic(): torch.cuda.manual_seed(1) -@contextmanager -def precision_context(precision, accelerator) -> Generator[None, None, None]: - if precision == 32: - yield - return - if accelerator == "gpu": - with torch.cuda.amp.autocast(): - yield - elif accelerator == "cpu": - with torch.cpu.amp.autocast(): - yield - - def is_state_dict_equal(state0, state1): # TODO: This should be torch.equal, but MPS does not yet support this operation (torch 1.12) return all(torch.allclose(w0.cpu(), w1.cpu()) for w0, w1 in zip(state0.values(), state1.values())) From 3d8ad31d832a03dfded6afa21f220c06e6e948fa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Mar 2023 03:59:20 +0000 Subject: [PATCH 44/86] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_fabric/parity/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tests_fabric/parity/utils.py b/tests/tests_fabric/parity/utils.py index dc06bd1e8ee0b..cd1383881d797 100644 --- a/tests/tests_fabric/parity/utils.py +++ b/tests/tests_fabric/parity/utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os + import torch From 76b167600b9a8be40798088da278d7db8a198fc1 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 1 Mar 2023 05:19:26 +0100 Subject: [PATCH 45/86] update --- tests/tests_fabric/parity/test_parity_ddp.py | 3 ++- tests/tests_fabric/parity/test_parity_simple.py | 16 ++++++++++++---- tests/tests_fabric/parity/utils.py | 10 ++++++++++ 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index 5aee8086729ca..c70f643af4452 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -139,5 +139,6 @@ def test_parity_ddp(precision, strategy, devices, accelerator): assert is_state_dict_equal(torch_state_dict, fabric_state_dict) # Compare the time per iteration + # Drop measurements of the first iterations, as they may be slower than others # The median is more robust to outliers than the mean - assert torch.isclose(torch.median(timings_torch), torch.median(timings_fabric), rtol=1e-4, atol=1e-4) + assert torch.isclose(torch.median(timings_torch[3:]), torch.median(timings_fabric[3:]), rtol=1e-4, atol=1e-4) diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index da26928041888..932e557b9f4eb 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -24,6 +24,7 @@ from tests_fabric.parity.utils import is_state_dict_equal, make_deterministic from lightning.fabric.fabric import Fabric +from tests_fabric.parity.utils import get_model_input_dtype NUM_STEPS_DEFAULT = 1000 @@ -31,6 +32,7 @@ def train_torch( move_to_device: Callable, precision_context, + input_dtype=torch.float32, num_steps=NUM_STEPS_DEFAULT, batch_size=4, ): @@ -51,8 +53,8 @@ def train_torch( inputs, labels = move_to_device(inputs), move_to_device(labels) optimizer.zero_grad() with precision_context(): - outputs = model(inputs) - loss = loss_fn(outputs, labels) + outputs = model(inputs.to(input_dtype)) + loss = loss_fn(outputs.float(), labels) loss.backward() optimizer.step() @@ -104,21 +106,27 @@ def train_fabric(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): (32, "cpu"), pytest.param(32, "gpu", marks=RunIf(min_cuda_gpus=1)), # pytest.param(16, "gpu", marks=RunIf(min_cuda_gpus=1)), # TODO: requires GradScaler + pytest.param("bf16", "cpu"), pytest.param("bf16", "gpu", marks=RunIf(min_cuda_gpus=1, bf16_cuda=True)), pytest.param(32, "mps", marks=RunIf(mps=True)), ], ) def test_parity_single_device(precision, accelerator): + input_dtype = get_model_input_dtype(precision) + # Train with Fabric fabric = Fabric(precision=precision, accelerator=accelerator, devices=1) fabric_state_dict, timings_fabric = train_fabric(fabric) # Train with raw PyTorch - torch_state_dict, timings_torch = train_torch(fabric.to_device, precision_context=fabric.autocast) + torch_state_dict, timings_torch = train_torch( + fabric.to_device, precision_context=fabric.autocast, input_dtype=input_dtype + ) # Compare the final weights assert is_state_dict_equal(torch_state_dict, fabric_state_dict) # Compare the time per iteration + # Drop measurements of the first iterations, as they may be slower than others # The median is more robust to outliers than the mean - assert torch.isclose(torch.median(timings_torch), torch.median(timings_fabric), rtol=1e-4, atol=1e-4) + assert torch.isclose(torch.median(timings_torch[3:]), torch.median(timings_fabric[3:]), rtol=1e-4, atol=1e-4) diff --git a/tests/tests_fabric/parity/utils.py b/tests/tests_fabric/parity/utils.py index cd1383881d797..c23de92a8c408 100644 --- a/tests/tests_fabric/parity/utils.py +++ b/tests/tests_fabric/parity/utils.py @@ -23,6 +23,16 @@ def make_deterministic(): torch.cuda.manual_seed(1) +def get_model_input_dtype(precision): + if precision in ("16-mixed", "16", 16): + return torch.float16 + elif precision in ("bf16-mixed", "bf16"): + return torch.bfloat16 + elif precision in ("64-true", "64", 64): + return torch.double + return torch.float32 + + def is_state_dict_equal(state0, state1): # TODO: This should be torch.equal, but MPS does not yet support this operation (torch 1.12) return all(torch.allclose(w0.cpu(), w1.cpu()) for w0, w1 in zip(state0.values(), state1.values())) From 130880f327a270aae18b7640e61a5adce5c2b292 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Mar 2023 04:20:26 +0000 Subject: [PATCH 46/86] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_fabric/parity/test_parity_simple.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index 932e557b9f4eb..4121a8aee4bb8 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -21,10 +21,9 @@ import torch.nn.functional from tests_fabric.helpers.runif import RunIf from tests_fabric.parity.models import ConvNet -from tests_fabric.parity.utils import is_state_dict_equal, make_deterministic +from tests_fabric.parity.utils import get_model_input_dtype, is_state_dict_equal, make_deterministic from lightning.fabric.fabric import Fabric -from tests_fabric.parity.utils import get_model_input_dtype NUM_STEPS_DEFAULT = 1000 From 11d5099a1cba5032b5aeb83ff2bdd1b57e4fb9a8 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 1 Mar 2023 05:26:56 +0100 Subject: [PATCH 47/86] benchmark --- tests/tests_fabric/parity/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tests_fabric/parity/utils.py b/tests/tests_fabric/parity/utils.py index c23de92a8c408..a79e6e7c90706 100644 --- a/tests/tests_fabric/parity/utils.py +++ b/tests/tests_fabric/parity/utils.py @@ -19,6 +19,7 @@ def make_deterministic(): os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" torch.use_deterministic_algorithms(True) + torch.backends.cudnn.benchmark = False torch.manual_seed(1) torch.cuda.manual_seed(1) From 20c66721ace88cfc250a8a04e0fcd0d6c418114d Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 1 Mar 2023 05:51:17 +0100 Subject: [PATCH 48/86] update --- tests/tests_fabric/conftest.py | 8 ++++++++ tests/tests_fabric/parity/test_parity_ddp.py | 17 +++++++++-------- tests/tests_fabric/parity/test_parity_simple.py | 1 + 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py index 42a364ea4f9a3..f7dc91126250b 100644 --- a/tests/tests_fabric/conftest.py +++ b/tests/tests_fabric/conftest.py @@ -77,6 +77,14 @@ def reset_deterministic_algorithm(): torch.use_deterministic_algorithms(False) +@pytest.fixture +def reset_cudnn_benchmark(): + """Ensures that the `torch.backends.cudnn.benchmark` setting gets reset before the next test runs.""" + benchmark = torch.backends.cudnn.benchmark + yield + torch.backends.cudnn.benchmark = benchmark + + def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None: monkeypatch.setattr(lightning.fabric.accelerators.tpu, "_XLA_AVAILABLE", value) monkeypatch.setattr(lightning.fabric.plugins.environments.xla, "_XLA_AVAILABLE", value) diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index c70f643af4452..78179a0251fd3 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -28,7 +28,7 @@ from lightning.fabric.fabric import Fabric -NUM_STEPS_DEFAULT = 100 +NUM_STEPS_DEFAULT = 1000 def train_torch_ddp( @@ -51,7 +51,7 @@ def train_torch_ddp( dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size * world_size), batch_size=batch_size) sampler = DistributedSampler(dataloader.dataset, rank=rank, num_replicas=world_size, drop_last=False, shuffle=False) - dataloader = DataLoader(dataloader.dataset, sampler=sampler) + dataloader = DataLoader(dataloader.dataset, sampler=sampler, batch_size=batch_size) optimizer = model.get_optimizer() loss_fn = model.get_loss_function() @@ -62,7 +62,7 @@ def train_torch_ddp( t0 = time.perf_counter() inputs, labels = next(iterator) - inputs, labels = inputs.to(device), inputs.to(labels) + inputs, labels = inputs.to(device), labels.to(device) optimizer.zero_grad() outputs = ddp_model(inputs) loss = loss_fn(outputs, labels) @@ -115,16 +115,17 @@ def train_fabric_ddp(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): @RunIf(standalone=True) # @pytest.mark.flaky(reruns=3) +@pytest.mark.usefixtures("reset_deterministic_algorithm", "reset_cudnn_benchmark") @pytest.mark.parametrize( - "precision, strategy, devices, accelerator", + "accelerator, devices", [ - (32, "ddp", 2, "cpu"), - pytest.param(32, "ddp", 2, "gpu", marks=RunIf(min_cuda_gpus=2)), + ("cpu", 2), + # pytest.param("gpu", 2, marks=RunIf(min_cuda_gpus=2)), ], ) -def test_parity_ddp(precision, strategy, devices, accelerator): +def test_parity_ddp(accelerator, devices): # Train with Fabric - fabric = Fabric(precision=precision, strategy=strategy, devices=devices, accelerator=accelerator) + fabric = Fabric(accelerator=accelerator, strategy="ddp", devices=devices) fabric.launch() fabric_state_dict, timings_fabric = train_fabric_ddp(fabric) diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index 4121a8aee4bb8..a9c909292d21b 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -99,6 +99,7 @@ def train_fabric(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): @pytest.mark.flaky(reruns=3) +@pytest.mark.usefixtures("reset_deterministic_algorithm", "reset_cudnn_benchmark") @pytest.mark.parametrize( "precision, accelerator", [ From c5fa2bcc6d43713c4bf09709280a85965762f1ea Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 1 Mar 2023 06:10:23 +0100 Subject: [PATCH 49/86] tuning --- tests/tests_fabric/parity/test_parity_ddp.py | 3 ++- tests/tests_fabric/parity/test_parity_simple.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index 78179a0251fd3..9f62bbd501da7 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -142,4 +142,5 @@ def test_parity_ddp(accelerator, devices): # Compare the time per iteration # Drop measurements of the first iterations, as they may be slower than others # The median is more robust to outliers than the mean - assert torch.isclose(torch.median(timings_torch[3:]), torch.median(timings_fabric[3:]), rtol=1e-4, atol=1e-4) + # Given relative and absolute tolerances, we want to satisfy: |torch – fabric| < RTOL * |torch| + ATOL + assert torch.isclose(torch.median(timings_torch[3:]), torch.median(timings_fabric[3:]), rtol=1e-3, atol=1e-4) diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index a9c909292d21b..3fed45a0db149 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -129,4 +129,5 @@ def test_parity_single_device(precision, accelerator): # Compare the time per iteration # Drop measurements of the first iterations, as they may be slower than others # The median is more robust to outliers than the mean + # Given relative and absolute tolerances, we want to satisfy: |torch – fabric| < RTOL * |torch| + ATOL assert torch.isclose(torch.median(timings_torch[3:]), torch.median(timings_fabric[3:]), rtol=1e-4, atol=1e-4) From 2d85b0d635a1c5dd1a3fc9dfa121bf3b64fc5f02 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 1 Mar 2023 06:24:44 +0100 Subject: [PATCH 50/86] run on gpu --- tests/tests_fabric/parity/test_parity_ddp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index 9f62bbd501da7..9b062ebd2ecc8 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -120,7 +120,7 @@ def train_fabric_ddp(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): "accelerator, devices", [ ("cpu", 2), - # pytest.param("gpu", 2, marks=RunIf(min_cuda_gpus=2)), + pytest.param("gpu", 2, marks=RunIf(min_cuda_gpus=2)), ], ) def test_parity_ddp(accelerator, devices): From 72faa64150c5d0c04b977e66a3f8faf6695c0c0b Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 1 Mar 2023 15:16:33 +0100 Subject: [PATCH 51/86] memory --- tests/tests_fabric/parity/test_parity_ddp.py | 6 ++--- .../tests_fabric/parity/test_parity_simple.py | 26 +++++++++++++++---- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index 9b062ebd2ecc8..9002cfa0b0d5e 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -127,17 +127,17 @@ def test_parity_ddp(accelerator, devices): # Train with Fabric fabric = Fabric(accelerator=accelerator, strategy="ddp", devices=devices) fabric.launch() - fabric_state_dict, timings_fabric = train_fabric_ddp(fabric) + state_dict_fabric, timings_fabric = train_fabric_ddp(fabric) # Train with raw PyTorch - torch_state_dict, timings_torch = train_torch_ddp( + state_dict_torch, timings_torch = train_torch_ddp( rank=fabric.global_rank, world_size=fabric.world_size, device=fabric.device, ) # Compare the final weights - assert is_state_dict_equal(torch_state_dict, fabric_state_dict) + assert is_state_dict_equal(state_dict_torch, state_dict_fabric) # Compare the time per iteration # Drop measurements of the first iterations, as they may be slower than others diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index 3fed45a0db149..7e3c72dbe3d08 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -36,12 +36,16 @@ def train_torch( batch_size=4, ): make_deterministic() + memory_stats = {} + model = ConvNet() model = move_to_device(model) dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) optimizer = model.get_optimizer() loss_fn = model.get_loss_function() + memory_stats["start"] = torch.cuda.memory_stats() + model.train() iteration_timings = [] iterator = iter(dataloader) @@ -60,11 +64,14 @@ def train_torch( t1 = time.perf_counter() iteration_timings.append(t1 - t0) - return model.state_dict(), torch.tensor(iteration_timings) + memory_stats["end"] = torch.cuda.memory_stats() + + return model.state_dict(), torch.tensor(iteration_timings), memory_stats def train_fabric(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): make_deterministic() + memory_stats = {} model = ConvNet() initial_state_dict = deepcopy(model.state_dict()) @@ -76,6 +83,8 @@ def train_fabric(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): dataloader = fabric.setup_dataloaders(dataloader) loss_fn = model.get_loss_function() + memory_stats["start"] = torch.cuda.memory_stats() + model.train() iteration_timings = [] iterator = iter(dataloader) @@ -92,10 +101,12 @@ def train_fabric(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): t1 = time.perf_counter() iteration_timings.append(t1 - t0) + memory_stats["end"] = torch.cuda.memory_stats() + # check that the model has changed assert not is_state_dict_equal(initial_state_dict, model.state_dict()) - return model.state_dict(), torch.tensor(iteration_timings) + return model.state_dict(), torch.tensor(iteration_timings), memory_stats @pytest.mark.flaky(reruns=3) @@ -116,18 +127,23 @@ def test_parity_single_device(precision, accelerator): # Train with Fabric fabric = Fabric(precision=precision, accelerator=accelerator, devices=1) - fabric_state_dict, timings_fabric = train_fabric(fabric) + state_dict_fabric, timings_fabric, memory_fabric = train_fabric(fabric) # Train with raw PyTorch - torch_state_dict, timings_torch = train_torch( + state_dict_torch, timings_torch, memory_torch = train_torch( fabric.to_device, precision_context=fabric.autocast, input_dtype=input_dtype ) # Compare the final weights - assert is_state_dict_equal(torch_state_dict, fabric_state_dict) + assert is_state_dict_equal(state_dict_torch, state_dict_fabric) # Compare the time per iteration # Drop measurements of the first iterations, as they may be slower than others # The median is more robust to outliers than the mean # Given relative and absolute tolerances, we want to satisfy: |torch – fabric| < RTOL * |torch| + ATOL assert torch.isclose(torch.median(timings_torch[3:]), torch.median(timings_fabric[3:]), rtol=1e-4, atol=1e-4) + + # Compare peak CUDA memory usage + if memory_torch["start"]: + assert memory_torch["start"]["allocated_bytes.all.peak"] == memory_fabric["start"]["allocated_bytes.all.peak"] + assert memory_torch["end"]["allocated_bytes.all.peak"] == memory_fabric["end"]["allocated_bytes.all.peak"] From 0de9ba279ecd94fb12ac657f5625eed0a180dc7a Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 1 Mar 2023 15:25:18 +0100 Subject: [PATCH 52/86] tolerance --- tests/tests_fabric/parity/test_parity_ddp.py | 2 +- tests/tests_fabric/parity/test_parity_simple.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index 9002cfa0b0d5e..40bd83cd2ec31 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -143,4 +143,4 @@ def test_parity_ddp(accelerator, devices): # Drop measurements of the first iterations, as they may be slower than others # The median is more robust to outliers than the mean # Given relative and absolute tolerances, we want to satisfy: |torch – fabric| < RTOL * |torch| + ATOL - assert torch.isclose(torch.median(timings_torch[3:]), torch.median(timings_fabric[3:]), rtol=1e-3, atol=1e-4) + assert torch.isclose(torch.median(timings_torch[3:]), torch.median(timings_fabric[3:]), rtol=1e-3, atol=1e-3) diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index 7e3c72dbe3d08..7f44070feb2bc 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -141,7 +141,7 @@ def test_parity_single_device(precision, accelerator): # Drop measurements of the first iterations, as they may be slower than others # The median is more robust to outliers than the mean # Given relative and absolute tolerances, we want to satisfy: |torch – fabric| < RTOL * |torch| + ATOL - assert torch.isclose(torch.median(timings_torch[3:]), torch.median(timings_fabric[3:]), rtol=1e-4, atol=1e-4) + assert torch.isclose(torch.median(timings_torch[3:]), torch.median(timings_fabric[3:]), rtol=1e-3, atol=1e-3) # Compare peak CUDA memory usage if memory_torch["start"]: From 905c5d6ee78bb7faae46e24115cb7a90188e1f41 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 1 Mar 2023 15:25:53 +0100 Subject: [PATCH 53/86] memory --- tests/tests_fabric/parity/test_parity_simple.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index 7f44070feb2bc..449c8a60f5b25 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -145,5 +145,5 @@ def test_parity_single_device(precision, accelerator): # Compare peak CUDA memory usage if memory_torch["start"]: - assert memory_torch["start"]["allocated_bytes.all.peak"] == memory_fabric["start"]["allocated_bytes.all.peak"] - assert memory_torch["end"]["allocated_bytes.all.peak"] == memory_fabric["end"]["allocated_bytes.all.peak"] + assert memory_torch["start"]["allocated_bytes.all.peak"] >= memory_fabric["start"]["allocated_bytes.all.peak"] + assert memory_torch["end"]["allocated_bytes.all.peak"] >= memory_fabric["end"]["allocated_bytes.all.peak"] From 719088b25f4bdde80c9c1eb3cef3b7f198974773 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 1 Mar 2023 15:42:08 +0100 Subject: [PATCH 54/86] refactor --- tests/tests_fabric/parity/test_parity_ddp.py | 29 +++++++++++++------ .../tests_fabric/parity/test_parity_simple.py | 16 ++++------ tests/tests_fabric/parity/utils.py | 12 ++++++++ 3 files changed, 38 insertions(+), 19 deletions(-) diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index 40bd83cd2ec31..779a825f43021 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -21,7 +21,7 @@ import torch.nn.functional from tests_fabric.helpers.runif import RunIf from tests_fabric.parity.models import ConvNet -from tests_fabric.parity.utils import is_state_dict_equal, make_deterministic +from tests_fabric.parity.utils import is_state_dict_equal, make_deterministic, is_timing_close, is_memory_close from torch.nn.parallel.distributed import DistributedDataParallel from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler @@ -39,6 +39,7 @@ def train_torch_ddp( batch_size=4, ): make_deterministic() + memory_stats = {} os.environ["LOCAL_RANK"] = str(rank) if torch.distributed.is_available() and not torch.distributed.is_initialized(): @@ -55,6 +56,8 @@ def train_torch_ddp( optimizer = model.get_optimizer() loss_fn = model.get_loss_function() + memory_stats["start"] = torch.cuda.memory_stats() + ddp_model.train() iteration_timings = [] iterator = iter(dataloader) @@ -72,14 +75,17 @@ def train_torch_ddp( t1 = time.perf_counter() iteration_timings.append(t1 - t0) + memory_stats["end"] = torch.cuda.memory_stats() + # check that the model has changed assert not is_state_dict_equal(initial_state_dict, ddp_model.module.state_dict()) - return ddp_model.module.state_dict(), torch.tensor(iteration_timings) + return ddp_model.module.state_dict(), torch.tensor(iteration_timings), memory_stats def train_fabric_ddp(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): make_deterministic() + memory_stats = {} model = ConvNet() initial_state_dict = deepcopy(model.state_dict()) @@ -91,6 +97,8 @@ def train_fabric_ddp(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): dataloader = fabric.setup_dataloaders(dataloader) loss_fn = model.get_loss_function() + memory_stats["start"] = torch.cuda.memory_stats() + model.train() iteration_timings = [] iterator = iter(dataloader) @@ -107,10 +115,12 @@ def train_fabric_ddp(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): t1 = time.perf_counter() iteration_timings.append(t1 - t0) + memory_stats["end"] = torch.cuda.memory_stats() + # check that the model has changed assert not is_state_dict_equal(initial_state_dict, model.state_dict()) - return model.state_dict(), torch.tensor(iteration_timings) + return model.state_dict(), torch.tensor(iteration_timings), memory_stats @RunIf(standalone=True) @@ -127,10 +137,10 @@ def test_parity_ddp(accelerator, devices): # Train with Fabric fabric = Fabric(accelerator=accelerator, strategy="ddp", devices=devices) fabric.launch() - state_dict_fabric, timings_fabric = train_fabric_ddp(fabric) + state_dict_fabric, timings_fabric, memory_fabric = train_fabric_ddp(fabric) # Train with raw PyTorch - state_dict_torch, timings_torch = train_torch_ddp( + state_dict_torch, timings_torch, memory_torch = train_torch_ddp( rank=fabric.global_rank, world_size=fabric.world_size, device=fabric.device, @@ -140,7 +150,8 @@ def test_parity_ddp(accelerator, devices): assert is_state_dict_equal(state_dict_torch, state_dict_fabric) # Compare the time per iteration - # Drop measurements of the first iterations, as they may be slower than others - # The median is more robust to outliers than the mean - # Given relative and absolute tolerances, we want to satisfy: |torch – fabric| < RTOL * |torch| + ATOL - assert torch.isclose(torch.median(timings_torch[3:]), torch.median(timings_fabric[3:]), rtol=1e-3, atol=1e-3) + assert is_timing_close(timings_torch, timings_fabric, rtol=1e-3, atol=1e-3) + + # Compare memory usage + assert is_memory_close(memory_torch["start"], memory_fabric["start"]) + assert is_memory_close(memory_torch["end"], memory_fabric["end"]) diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index 449c8a60f5b25..bab1a3f31c22b 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -21,7 +21,7 @@ import torch.nn.functional from tests_fabric.helpers.runif import RunIf from tests_fabric.parity.models import ConvNet -from tests_fabric.parity.utils import get_model_input_dtype, is_state_dict_equal, make_deterministic +from tests_fabric.parity.utils import get_model_input_dtype, is_state_dict_equal, make_deterministic, is_timing_close, is_memory_close from lightning.fabric.fabric import Fabric @@ -138,12 +138,8 @@ def test_parity_single_device(precision, accelerator): assert is_state_dict_equal(state_dict_torch, state_dict_fabric) # Compare the time per iteration - # Drop measurements of the first iterations, as they may be slower than others - # The median is more robust to outliers than the mean - # Given relative and absolute tolerances, we want to satisfy: |torch – fabric| < RTOL * |torch| + ATOL - assert torch.isclose(torch.median(timings_torch[3:]), torch.median(timings_fabric[3:]), rtol=1e-3, atol=1e-3) - - # Compare peak CUDA memory usage - if memory_torch["start"]: - assert memory_torch["start"]["allocated_bytes.all.peak"] >= memory_fabric["start"]["allocated_bytes.all.peak"] - assert memory_torch["end"]["allocated_bytes.all.peak"] >= memory_fabric["end"]["allocated_bytes.all.peak"] + assert is_timing_close(timings_torch, timings_fabric, rtol=1e-3, atol=1e-3) + + # Compare memory usage + assert is_memory_close(memory_torch["start"], memory_fabric["start"]) + assert is_memory_close(memory_torch["end"], memory_fabric["end"]) diff --git a/tests/tests_fabric/parity/utils.py b/tests/tests_fabric/parity/utils.py index a79e6e7c90706..f9191a8a99d72 100644 --- a/tests/tests_fabric/parity/utils.py +++ b/tests/tests_fabric/parity/utils.py @@ -37,3 +37,15 @@ def get_model_input_dtype(precision): def is_state_dict_equal(state0, state1): # TODO: This should be torch.equal, but MPS does not yet support this operation (torch 1.12) return all(torch.allclose(w0.cpu(), w1.cpu()) for w0, w1 in zip(state0.values(), state1.values())) + + +def is_timing_close(timings_torch, timings_fabric, rtol=1e-3, atol=1e-3): + # Drop measurements of the first iterations, as they may be slower than others + # The median is more robust to outliers than the mean + # Given relative and absolute tolerances, we want to satisfy: |torch – fabric| < RTOL * |torch| + ATOL + return torch.isclose(torch.median(timings_torch[3:]), torch.median(timings_fabric[3:]), rtol=rtol, atol=atol) + + +def is_memory_close(memory_stats_torch, memory_stats_fabric): + # We require Fabric's peak memory usage to be smaller or equal to that of PyTorch + return memory_stats_torch.get("allocated_bytes.all.peak", 0) >= memory_stats_fabric.get("allocated_bytes.all.peak", 0) From 0bbe2cdb43f7b75aaa66de901b76ca778b89c56f Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 1 Mar 2023 15:50:06 +0100 Subject: [PATCH 55/86] refactor --- tests/tests_fabric/parity/models.py | 13 +++++++++++-- tests/tests_fabric/parity/test_parity_ddp.py | 16 ++++++---------- tests/tests_fabric/parity/test_parity_simple.py | 14 +++++--------- 3 files changed, 22 insertions(+), 21 deletions(-) diff --git a/tests/tests_fabric/parity/models.py b/tests/tests_fabric/parity/models.py index e24c697ab2582..16616b255acc2 100644 --- a/tests/tests_fabric/parity/models.py +++ b/tests/tests_fabric/parity/models.py @@ -24,6 +24,10 @@ class ParityModel(ABC, nn.Module): """Defines the interface for a model in a Fabric-PyTorch parity test.""" + # Benchmarking parameters that should be model-specific + batch_size = 1 + num_steps = 1 + @abstractmethod def get_optimizer(self, *args, **kwargs) -> Optimizer: pass @@ -38,6 +42,9 @@ def get_loss_function(self) -> Callable: class ConvNet(ParityModel): + batch_size = 4 + num_steps = 1000 + def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 6, 5) @@ -59,13 +66,15 @@ def forward(self, x): def get_optimizer(self): return torch.optim.SGD(self.parameters(), lr=0.0001) - def get_dataloader(self, dataset_size=100, batch_size=4): + def get_dataloader(self): + # multiply * 8 just in case world size is larger than 1 + dataset_size = self.num_steps * self.batch_size * 8 inputs = torch.rand(dataset_size, 3, 32, 32) labels = torch.randint(0, 10, (dataset_size,)) dataset = TensorDataset(inputs, labels) dataloader = DataLoader( dataset, - batch_size=batch_size, + batch_size=self.batch_size, num_workers=2, ) return dataloader diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index 779a825f43021..d2c7fcedcf64e 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -28,15 +28,11 @@ from lightning.fabric.fabric import Fabric -NUM_STEPS_DEFAULT = 1000 - def train_torch_ddp( rank, world_size, device=torch.device("cpu"), - num_steps=NUM_STEPS_DEFAULT, - batch_size=4, ): make_deterministic() memory_stats = {} @@ -50,9 +46,9 @@ def train_torch_ddp( ddp_model = DistributedDataParallel(model.to(device), device_ids=([rank] if device.type == "cuda" else None)) - dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size * world_size), batch_size=batch_size) + dataloader = model.get_dataloader() sampler = DistributedSampler(dataloader.dataset, rank=rank, num_replicas=world_size, drop_last=False, shuffle=False) - dataloader = DataLoader(dataloader.dataset, sampler=sampler, batch_size=batch_size) + dataloader = DataLoader(dataloader.dataset, sampler=sampler, batch_size=model.batch_size) optimizer = model.get_optimizer() loss_fn = model.get_loss_function() @@ -61,7 +57,7 @@ def train_torch_ddp( ddp_model.train() iteration_timings = [] iterator = iter(dataloader) - for _ in range(num_steps): + for _ in range(model.num_steps): t0 = time.perf_counter() inputs, labels = next(iterator) @@ -83,7 +79,7 @@ def train_torch_ddp( return ddp_model.module.state_dict(), torch.tensor(iteration_timings), memory_stats -def train_fabric_ddp(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): +def train_fabric_ddp(fabric): make_deterministic() memory_stats = {} @@ -93,7 +89,7 @@ def train_fabric_ddp(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): optimizer = model.get_optimizer() model, optimizer = fabric.setup(model, optimizer) - dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size * fabric.world_size), batch_size=batch_size) + dataloader = model.get_dataloader() dataloader = fabric.setup_dataloaders(dataloader) loss_fn = model.get_loss_function() @@ -102,7 +98,7 @@ def train_fabric_ddp(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): model.train() iteration_timings = [] iterator = iter(dataloader) - for _ in range(num_steps): + for _ in range(model.num_steps): t0 = time.perf_counter() inputs, labels = next(iterator) diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index bab1a3f31c22b..318af46dd2f8b 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -25,22 +25,18 @@ from lightning.fabric.fabric import Fabric -NUM_STEPS_DEFAULT = 1000 - def train_torch( move_to_device: Callable, precision_context, input_dtype=torch.float32, - num_steps=NUM_STEPS_DEFAULT, - batch_size=4, ): make_deterministic() memory_stats = {} model = ConvNet() model = move_to_device(model) - dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) + dataloader = model.get_dataloader() optimizer = model.get_optimizer() loss_fn = model.get_loss_function() @@ -49,7 +45,7 @@ def train_torch( model.train() iteration_timings = [] iterator = iter(dataloader) - for _ in range(num_steps): + for _ in range(model.num_steps): t0 = time.perf_counter() inputs, labels = next(iterator) @@ -69,7 +65,7 @@ def train_torch( return model.state_dict(), torch.tensor(iteration_timings), memory_stats -def train_fabric(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): +def train_fabric(fabric): make_deterministic() memory_stats = {} @@ -79,7 +75,7 @@ def train_fabric(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): optimizer = model.get_optimizer() model, optimizer = fabric.setup(model, optimizer) - dataloader = model.get_dataloader(dataset_size=(num_steps * batch_size), batch_size=batch_size) + dataloader = model.get_dataloader() dataloader = fabric.setup_dataloaders(dataloader) loss_fn = model.get_loss_function() @@ -88,7 +84,7 @@ def train_fabric(fabric, num_steps=NUM_STEPS_DEFAULT, batch_size=4): model.train() iteration_timings = [] iterator = iter(dataloader) - for _ in range(num_steps): + for _ in range(model.num_steps): t0 = time.perf_counter() inputs, labels = next(iterator) From 6f41053b1458a40f0333203d34306aafd134516e Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 1 Mar 2023 15:53:32 +0100 Subject: [PATCH 56/86] safer check --- tests/tests_fabric/parity/test_parity_ddp.py | 7 ++++--- tests/tests_fabric/parity/test_parity_simple.py | 7 ++++--- tests/tests_fabric/parity/utils.py | 4 ++-- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index d2c7fcedcf64e..9d5f943093639 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -21,7 +21,7 @@ import torch.nn.functional from tests_fabric.helpers.runif import RunIf from tests_fabric.parity.models import ConvNet -from tests_fabric.parity.utils import is_state_dict_equal, make_deterministic, is_timing_close, is_memory_close +from tests_fabric.parity.utils import is_state_dict_equal, make_deterministic, is_timing_close, is_cuda_memory_close from torch.nn.parallel.distributed import DistributedDataParallel from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler @@ -149,5 +149,6 @@ def test_parity_ddp(accelerator, devices): assert is_timing_close(timings_torch, timings_fabric, rtol=1e-3, atol=1e-3) # Compare memory usage - assert is_memory_close(memory_torch["start"], memory_fabric["start"]) - assert is_memory_close(memory_torch["end"], memory_fabric["end"]) + if accelerator == "gpu": + assert is_cuda_memory_close(memory_torch["start"], memory_fabric["start"]) + assert is_cuda_memory_close(memory_torch["end"], memory_fabric["end"]) diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index 318af46dd2f8b..53a5cb5902cce 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -21,7 +21,7 @@ import torch.nn.functional from tests_fabric.helpers.runif import RunIf from tests_fabric.parity.models import ConvNet -from tests_fabric.parity.utils import get_model_input_dtype, is_state_dict_equal, make_deterministic, is_timing_close, is_memory_close +from tests_fabric.parity.utils import get_model_input_dtype, is_state_dict_equal, make_deterministic, is_timing_close, is_cuda_memory_close from lightning.fabric.fabric import Fabric @@ -137,5 +137,6 @@ def test_parity_single_device(precision, accelerator): assert is_timing_close(timings_torch, timings_fabric, rtol=1e-3, atol=1e-3) # Compare memory usage - assert is_memory_close(memory_torch["start"], memory_fabric["start"]) - assert is_memory_close(memory_torch["end"], memory_fabric["end"]) + if accelerator == "gpu": + assert is_cuda_memory_close(memory_torch["start"], memory_fabric["start"]) + assert is_cuda_memory_close(memory_torch["end"], memory_fabric["end"]) diff --git a/tests/tests_fabric/parity/utils.py b/tests/tests_fabric/parity/utils.py index f9191a8a99d72..bbf70647d6f09 100644 --- a/tests/tests_fabric/parity/utils.py +++ b/tests/tests_fabric/parity/utils.py @@ -46,6 +46,6 @@ def is_timing_close(timings_torch, timings_fabric, rtol=1e-3, atol=1e-3): return torch.isclose(torch.median(timings_torch[3:]), torch.median(timings_fabric[3:]), rtol=rtol, atol=atol) -def is_memory_close(memory_stats_torch, memory_stats_fabric): +def is_cuda_memory_close(memory_stats_torch, memory_stats_fabric): # We require Fabric's peak memory usage to be smaller or equal to that of PyTorch - return memory_stats_torch.get("allocated_bytes.all.peak", 0) >= memory_stats_fabric.get("allocated_bytes.all.peak", 0) + return memory_stats_torch["allocated_bytes.all.peak"] >= memory_stats_fabric["allocated_bytes.all.peak"] From 1f6e98764d1ecfabfaa574a38a01335652a9f9a3 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 1 Mar 2023 15:54:55 +0100 Subject: [PATCH 57/86] reset peak --- tests/tests_fabric/parity/test_parity_ddp.py | 2 ++ tests/tests_fabric/parity/test_parity_simple.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index 9d5f943093639..64d572b7d8d97 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -135,6 +135,8 @@ def test_parity_ddp(accelerator, devices): fabric.launch() state_dict_fabric, timings_fabric, memory_fabric = train_fabric_ddp(fabric) + torch.cuda.reset_peak_memory_stats() + # Train with raw PyTorch state_dict_torch, timings_torch, memory_torch = train_torch_ddp( rank=fabric.global_rank, diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index 53a5cb5902cce..ee8439d70ca19 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -125,6 +125,8 @@ def test_parity_single_device(precision, accelerator): fabric = Fabric(precision=precision, accelerator=accelerator, devices=1) state_dict_fabric, timings_fabric, memory_fabric = train_fabric(fabric) + torch.cuda.reset_peak_memory_stats() + # Train with raw PyTorch state_dict_torch, timings_torch, memory_torch = train_torch( fabric.to_device, precision_context=fabric.autocast, input_dtype=input_dtype From 33d7c01eab1b01b000405b2b3d5916d99ed41b71 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Mar 2023 14:56:09 +0000 Subject: [PATCH 58/86] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_fabric/parity/test_parity_ddp.py | 2 +- tests/tests_fabric/parity/test_parity_simple.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index 64d572b7d8d97..3c2ce8437783f 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -21,7 +21,7 @@ import torch.nn.functional from tests_fabric.helpers.runif import RunIf from tests_fabric.parity.models import ConvNet -from tests_fabric.parity.utils import is_state_dict_equal, make_deterministic, is_timing_close, is_cuda_memory_close +from tests_fabric.parity.utils import is_cuda_memory_close, is_state_dict_equal, is_timing_close, make_deterministic from torch.nn.parallel.distributed import DistributedDataParallel from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index ee8439d70ca19..c5d9a4f20f195 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -21,7 +21,13 @@ import torch.nn.functional from tests_fabric.helpers.runif import RunIf from tests_fabric.parity.models import ConvNet -from tests_fabric.parity.utils import get_model_input_dtype, is_state_dict_equal, make_deterministic, is_timing_close, is_cuda_memory_close +from tests_fabric.parity.utils import ( + get_model_input_dtype, + is_cuda_memory_close, + is_state_dict_equal, + is_timing_close, + make_deterministic, +) from lightning.fabric.fabric import Fabric From 5462c24d606adfaae43dab79eeb93b7067635a56 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 1 Mar 2023 16:12:35 +0100 Subject: [PATCH 59/86] empty cache --- tests/tests_fabric/parity/test_parity_ddp.py | 1 + tests/tests_fabric/parity/test_parity_simple.py | 1 + tests/tests_fabric/parity/utils.py | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index 64d572b7d8d97..ba778c937119c 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -135,6 +135,7 @@ def test_parity_ddp(accelerator, devices): fabric.launch() state_dict_fabric, timings_fabric, memory_fabric = train_fabric_ddp(fabric) + torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() # Train with raw PyTorch diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index ee8439d70ca19..131b63a67fca5 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -125,6 +125,7 @@ def test_parity_single_device(precision, accelerator): fabric = Fabric(precision=precision, accelerator=accelerator, devices=1) state_dict_fabric, timings_fabric, memory_fabric = train_fabric(fabric) + torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() # Train with raw PyTorch diff --git a/tests/tests_fabric/parity/utils.py b/tests/tests_fabric/parity/utils.py index bbf70647d6f09..3672ae1adf6b6 100644 --- a/tests/tests_fabric/parity/utils.py +++ b/tests/tests_fabric/parity/utils.py @@ -42,7 +42,7 @@ def is_state_dict_equal(state0, state1): def is_timing_close(timings_torch, timings_fabric, rtol=1e-3, atol=1e-3): # Drop measurements of the first iterations, as they may be slower than others # The median is more robust to outliers than the mean - # Given relative and absolute tolerances, we want to satisfy: |torch – fabric| < RTOL * |torch| + ATOL + # Given relative and absolute tolerances, we want to satisfy: |torch – fabric| < RTOL * torch + ATOL return torch.isclose(torch.median(timings_torch[3:]), torch.median(timings_fabric[3:]), rtol=rtol, atol=atol) From cbf24c102dc8000f141316a9b117d29ecd678bb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Mar 2023 10:16:51 -0500 Subject: [PATCH 60/86] Update tests/tests_fabric/parity/test_parity_simple.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- tests/tests_fabric/parity/test_parity_simple.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index e92a311eb7151..8b637f5f30b26 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -132,7 +132,8 @@ def test_parity_single_device(precision, accelerator): state_dict_fabric, timings_fabric, memory_fabric = train_fabric(fabric) torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() + if accelerator == "gpu": + torch.cuda.reset_peak_memory_stats() # Train with raw PyTorch state_dict_torch, timings_torch, memory_torch = train_torch( From 42facb7666916945afc45722b282a0c58457fdf0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Mar 2023 10:16:58 -0500 Subject: [PATCH 61/86] Update tests/tests_fabric/parity/test_parity_simple.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- tests/tests_fabric/parity/test_parity_simple.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index 8b637f5f30b26..8be73ad419aa5 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -117,10 +117,10 @@ def train_fabric(fabric): "precision, accelerator", [ (32, "cpu"), - pytest.param(32, "gpu", marks=RunIf(min_cuda_gpus=1)), - # pytest.param(16, "gpu", marks=RunIf(min_cuda_gpus=1)), # TODO: requires GradScaler + pytest.param(32, "cuda", marks=RunIf(min_cuda_gpus=1)), + # pytest.param(16, "cuda", marks=RunIf(min_cuda_gpus=1)), # TODO: requires GradScaler pytest.param("bf16", "cpu"), - pytest.param("bf16", "gpu", marks=RunIf(min_cuda_gpus=1, bf16_cuda=True)), + pytest.param("bf16", "cuda", marks=RunIf(min_cuda_gpus=1, bf16_cuda=True)), pytest.param(32, "mps", marks=RunIf(mps=True)), ], ) From d6e52275119c032494093a6566e046010510b6d4 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 1 Mar 2023 16:19:19 +0100 Subject: [PATCH 62/86] cuda --- tests/tests_fabric/parity/test_parity_ddp.py | 10 +++++----- tests/tests_fabric/parity/test_parity_simple.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index fae8b59d0473b..830fd278c9961 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -120,13 +120,12 @@ def train_fabric_ddp(fabric): @RunIf(standalone=True) -# @pytest.mark.flaky(reruns=3) @pytest.mark.usefixtures("reset_deterministic_algorithm", "reset_cudnn_benchmark") @pytest.mark.parametrize( "accelerator, devices", [ ("cpu", 2), - pytest.param("gpu", 2, marks=RunIf(min_cuda_gpus=2)), + pytest.param("cuda", 2, marks=RunIf(min_cuda_gpus=2)), ], ) def test_parity_ddp(accelerator, devices): @@ -135,8 +134,9 @@ def test_parity_ddp(accelerator, devices): fabric.launch() state_dict_fabric, timings_fabric, memory_fabric = train_fabric_ddp(fabric) - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() + if accelerator == "cuda": + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() # Train with raw PyTorch state_dict_torch, timings_torch, memory_torch = train_torch_ddp( @@ -152,6 +152,6 @@ def test_parity_ddp(accelerator, devices): assert is_timing_close(timings_torch, timings_fabric, rtol=1e-3, atol=1e-3) # Compare memory usage - if accelerator == "gpu": + if accelerator == "cuda": assert is_cuda_memory_close(memory_torch["start"], memory_fabric["start"]) assert is_cuda_memory_close(memory_torch["end"], memory_fabric["end"]) diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index 8be73ad419aa5..92d4510efee76 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -131,8 +131,8 @@ def test_parity_single_device(precision, accelerator): fabric = Fabric(precision=precision, accelerator=accelerator, devices=1) state_dict_fabric, timings_fabric, memory_fabric = train_fabric(fabric) - torch.cuda.empty_cache() - if accelerator == "gpu": + if accelerator == "cuda": + torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() # Train with raw PyTorch @@ -147,6 +147,6 @@ def test_parity_single_device(precision, accelerator): assert is_timing_close(timings_torch, timings_fabric, rtol=1e-3, atol=1e-3) # Compare memory usage - if accelerator == "gpu": + if accelerator == "cuda": assert is_cuda_memory_close(memory_torch["start"], memory_fabric["start"]) assert is_cuda_memory_close(memory_torch["end"], memory_fabric["end"]) From 80d5919130611e0abec786c6375351589f1437f5 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 1 Mar 2023 16:57:54 +0100 Subject: [PATCH 63/86] Experiment with tracking mode by @carmocca MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí commit for future reference --- tests/tests_fabric/parity/models.py | 29 +++++ .../tests_fabric/parity/test_parity_simple.py | 1 - .../parity/test_parity_torch_calls.py | 112 ++++++++++++++++++ tests/tests_fabric/parity/utils.py | 47 +++++--- 4 files changed, 171 insertions(+), 18 deletions(-) create mode 100644 tests/tests_fabric/parity/test_parity_torch_calls.py diff --git a/tests/tests_fabric/parity/models.py b/tests/tests_fabric/parity/models.py index 16616b255acc2..46c9aeaf555a9 100644 --- a/tests/tests_fabric/parity/models.py +++ b/tests/tests_fabric/parity/models.py @@ -41,6 +41,35 @@ def get_loss_function(self) -> Callable: pass +class TinyModel(ParityModel): + batch_size = 2 + num_steps = 3 + + def __init__(self): + super().__init__() + self.layer = nn.Linear(10, 2) + + def forward(self, x): + return self.layer(x) + + def get_optimizer(self): + return torch.optim.SGD(self.parameters(), lr=0.0001) + + def get_dataloader(self): + inputs = torch.rand(32, 10) + labels = torch.rand(32, 2) + dataset = TensorDataset(inputs, labels) + dataloader = DataLoader( + dataset, + batch_size=self.batch_size, + num_workers=2, + ) + return dataloader + + def get_loss_function(self): + return F.mse_loss + + class ConvNet(ParityModel): batch_size = 4 num_steps = 1000 diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index 92d4510efee76..022dc606afd86 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -67,7 +67,6 @@ def train_torch( iteration_timings.append(t1 - t0) memory_stats["end"] = torch.cuda.memory_stats() - return model.state_dict(), torch.tensor(iteration_timings), memory_stats diff --git a/tests/tests_fabric/parity/test_parity_torch_calls.py b/tests/tests_fabric/parity/test_parity_torch_calls.py new file mode 100644 index 0000000000000..9a403c3320607 --- /dev/null +++ b/tests/tests_fabric/parity/test_parity_torch_calls.py @@ -0,0 +1,112 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time +from copy import deepcopy +from typing import Callable + +import pytest +import torch +import torch.distributed +import torch.nn.functional +from tests_fabric.helpers.runif import RunIf +from tests_fabric.parity.models import TinyModel +from tests_fabric.parity.utils import ( + get_model_input_dtype, + make_deterministic, + TrackingMode, +) + +from lightning.fabric.fabric import Fabric + + +def train_torch( + move_to_device: Callable, + precision_context, + input_dtype=torch.float32, +): + make_deterministic() + + model = TinyModel() + model = move_to_device(model) + dataloader = model.get_dataloader() + optimizer = model.get_optimizer() + loss_fn = model.get_loss_function() + + model.train() + with TrackingMode() as tracked_calls: + iterator = iter(dataloader) + for _ in range(model.num_steps): + inputs, labels = next(iterator) + inputs, labels = move_to_device(inputs), move_to_device(labels) + optimizer.zero_grad() + with precision_context(): + outputs = model(inputs.to(input_dtype)) + loss = loss_fn(outputs.float(), labels) + loss.backward() + optimizer.step() + + return tracked_calls.calls + + +def train_fabric(fabric): + make_deterministic() + + model = TinyModel() + optimizer = model.get_optimizer() + model, optimizer = fabric.setup(model, optimizer) + + dataloader = model.get_dataloader() + dataloader = fabric.setup_dataloaders(dataloader) + loss_fn = model.get_loss_function() + + model.train() + with TrackingMode() as tracked_calls: + iterator = iter(dataloader) + for _ in range(model.num_steps): + inputs, labels = next(iterator) + optimizer.zero_grad() + outputs = model(inputs) + loss = loss_fn(outputs, labels) + fabric.backward(loss) + optimizer.step() + + return tracked_calls.calls + + +@pytest.mark.usefixtures("reset_deterministic_algorithm", "reset_cudnn_benchmark") +@pytest.mark.parametrize( + "precision, accelerator", + [ + (32, "cpu"), + # pytest.param(32, "cuda", marks=RunIf(min_cuda_gpus=1)), + # pytest.param(16, "cuda", marks=RunIf(min_cuda_gpus=1)), # TODO: requires GradScaler + pytest.param("bf16", "cpu"), + # pytest.param("bf16", "cuda", marks=RunIf(min_cuda_gpus=1, bf16_cuda=True)), + pytest.param(32, "mps", marks=RunIf(mps=True)), + ], +) +def test_parity_torch_calls(precision, accelerator): + input_dtype = get_model_input_dtype(precision) + + # Train with Fabric + fabric = Fabric(precision=precision, accelerator=accelerator, devices=1) + calls_fabric = train_fabric(fabric) + + # Train with raw PyTorch + calls_torch = train_torch( + fabric.to_device, precision_context=fabric.autocast, input_dtype=input_dtype + ) + + # Compare the calls made to ATen + assert calls_torch == calls_fabric diff --git a/tests/tests_fabric/parity/utils.py b/tests/tests_fabric/parity/utils.py index 3672ae1adf6b6..0db9ebc2e7ca8 100644 --- a/tests/tests_fabric/parity/utils.py +++ b/tests/tests_fabric/parity/utils.py @@ -14,6 +14,36 @@ import os import torch +from torch.utils._python_dispatch import TorchDispatchMode + + +def is_state_dict_equal(state0, state1): + # TODO: This should be torch.equal, but MPS does not yet support this operation (torch 1.12) + return all(torch.allclose(w0.cpu(), w1.cpu()) for w0, w1 in zip(state0.values(), state1.values())) + + +def is_timing_close(timings_torch, timings_fabric, rtol=1e-3, atol=1e-3): + # Drop measurements of the first iterations, as they may be slower than others + # The median is more robust to outliers than the mean + # Given relative and absolute tolerances, we want to satisfy: |torch – fabric| < RTOL * torch + ATOL + return torch.isclose(torch.median(timings_torch[3:]), torch.median(timings_fabric[3:]), rtol=rtol, atol=atol) + + +def is_cuda_memory_close(memory_stats_torch, memory_stats_fabric): + # We require Fabric's peak memory usage to be smaller or equal to that of PyTorch + return memory_stats_torch["allocated_bytes.all.peak"] >= memory_stats_fabric["allocated_bytes.all.peak"] + + +class TrackingMode(TorchDispatchMode): + """Tracks the calls made on all tensor operations.""" + + def __init__(self): + super().__init__() + self.calls = [] + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + self.calls.append(f"{func.__module__}.{func.__name__}") + return func(*args, **kwargs) def make_deterministic(): @@ -32,20 +62,3 @@ def get_model_input_dtype(precision): elif precision in ("64-true", "64", 64): return torch.double return torch.float32 - - -def is_state_dict_equal(state0, state1): - # TODO: This should be torch.equal, but MPS does not yet support this operation (torch 1.12) - return all(torch.allclose(w0.cpu(), w1.cpu()) for w0, w1 in zip(state0.values(), state1.values())) - - -def is_timing_close(timings_torch, timings_fabric, rtol=1e-3, atol=1e-3): - # Drop measurements of the first iterations, as they may be slower than others - # The median is more robust to outliers than the mean - # Given relative and absolute tolerances, we want to satisfy: |torch – fabric| < RTOL * torch + ATOL - return torch.isclose(torch.median(timings_torch[3:]), torch.median(timings_fabric[3:]), rtol=rtol, atol=atol) - - -def is_cuda_memory_close(memory_stats_torch, memory_stats_fabric): - # We require Fabric's peak memory usage to be smaller or equal to that of PyTorch - return memory_stats_torch["allocated_bytes.all.peak"] >= memory_stats_fabric["allocated_bytes.all.peak"] From a703991766892b88d5e2d3cb3dc7ac3d3d3d5d94 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 1 Mar 2023 16:57:54 +0100 Subject: [PATCH 64/86] Revert "Experiment with tracking mode by @carmocca" This reverts commit 80d5919130611e0abec786c6375351589f1437f5. --- tests/tests_fabric/parity/models.py | 29 ----- .../tests_fabric/parity/test_parity_simple.py | 1 + .../parity/test_parity_torch_calls.py | 112 ------------------ tests/tests_fabric/parity/utils.py | 47 +++----- 4 files changed, 18 insertions(+), 171 deletions(-) delete mode 100644 tests/tests_fabric/parity/test_parity_torch_calls.py diff --git a/tests/tests_fabric/parity/models.py b/tests/tests_fabric/parity/models.py index 46c9aeaf555a9..16616b255acc2 100644 --- a/tests/tests_fabric/parity/models.py +++ b/tests/tests_fabric/parity/models.py @@ -41,35 +41,6 @@ def get_loss_function(self) -> Callable: pass -class TinyModel(ParityModel): - batch_size = 2 - num_steps = 3 - - def __init__(self): - super().__init__() - self.layer = nn.Linear(10, 2) - - def forward(self, x): - return self.layer(x) - - def get_optimizer(self): - return torch.optim.SGD(self.parameters(), lr=0.0001) - - def get_dataloader(self): - inputs = torch.rand(32, 10) - labels = torch.rand(32, 2) - dataset = TensorDataset(inputs, labels) - dataloader = DataLoader( - dataset, - batch_size=self.batch_size, - num_workers=2, - ) - return dataloader - - def get_loss_function(self): - return F.mse_loss - - class ConvNet(ParityModel): batch_size = 4 num_steps = 1000 diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index 022dc606afd86..92d4510efee76 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -67,6 +67,7 @@ def train_torch( iteration_timings.append(t1 - t0) memory_stats["end"] = torch.cuda.memory_stats() + return model.state_dict(), torch.tensor(iteration_timings), memory_stats diff --git a/tests/tests_fabric/parity/test_parity_torch_calls.py b/tests/tests_fabric/parity/test_parity_torch_calls.py deleted file mode 100644 index 9a403c3320607..0000000000000 --- a/tests/tests_fabric/parity/test_parity_torch_calls.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright The Lightning AI team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import time -from copy import deepcopy -from typing import Callable - -import pytest -import torch -import torch.distributed -import torch.nn.functional -from tests_fabric.helpers.runif import RunIf -from tests_fabric.parity.models import TinyModel -from tests_fabric.parity.utils import ( - get_model_input_dtype, - make_deterministic, - TrackingMode, -) - -from lightning.fabric.fabric import Fabric - - -def train_torch( - move_to_device: Callable, - precision_context, - input_dtype=torch.float32, -): - make_deterministic() - - model = TinyModel() - model = move_to_device(model) - dataloader = model.get_dataloader() - optimizer = model.get_optimizer() - loss_fn = model.get_loss_function() - - model.train() - with TrackingMode() as tracked_calls: - iterator = iter(dataloader) - for _ in range(model.num_steps): - inputs, labels = next(iterator) - inputs, labels = move_to_device(inputs), move_to_device(labels) - optimizer.zero_grad() - with precision_context(): - outputs = model(inputs.to(input_dtype)) - loss = loss_fn(outputs.float(), labels) - loss.backward() - optimizer.step() - - return tracked_calls.calls - - -def train_fabric(fabric): - make_deterministic() - - model = TinyModel() - optimizer = model.get_optimizer() - model, optimizer = fabric.setup(model, optimizer) - - dataloader = model.get_dataloader() - dataloader = fabric.setup_dataloaders(dataloader) - loss_fn = model.get_loss_function() - - model.train() - with TrackingMode() as tracked_calls: - iterator = iter(dataloader) - for _ in range(model.num_steps): - inputs, labels = next(iterator) - optimizer.zero_grad() - outputs = model(inputs) - loss = loss_fn(outputs, labels) - fabric.backward(loss) - optimizer.step() - - return tracked_calls.calls - - -@pytest.mark.usefixtures("reset_deterministic_algorithm", "reset_cudnn_benchmark") -@pytest.mark.parametrize( - "precision, accelerator", - [ - (32, "cpu"), - # pytest.param(32, "cuda", marks=RunIf(min_cuda_gpus=1)), - # pytest.param(16, "cuda", marks=RunIf(min_cuda_gpus=1)), # TODO: requires GradScaler - pytest.param("bf16", "cpu"), - # pytest.param("bf16", "cuda", marks=RunIf(min_cuda_gpus=1, bf16_cuda=True)), - pytest.param(32, "mps", marks=RunIf(mps=True)), - ], -) -def test_parity_torch_calls(precision, accelerator): - input_dtype = get_model_input_dtype(precision) - - # Train with Fabric - fabric = Fabric(precision=precision, accelerator=accelerator, devices=1) - calls_fabric = train_fabric(fabric) - - # Train with raw PyTorch - calls_torch = train_torch( - fabric.to_device, precision_context=fabric.autocast, input_dtype=input_dtype - ) - - # Compare the calls made to ATen - assert calls_torch == calls_fabric diff --git a/tests/tests_fabric/parity/utils.py b/tests/tests_fabric/parity/utils.py index 0db9ebc2e7ca8..3672ae1adf6b6 100644 --- a/tests/tests_fabric/parity/utils.py +++ b/tests/tests_fabric/parity/utils.py @@ -14,36 +14,6 @@ import os import torch -from torch.utils._python_dispatch import TorchDispatchMode - - -def is_state_dict_equal(state0, state1): - # TODO: This should be torch.equal, but MPS does not yet support this operation (torch 1.12) - return all(torch.allclose(w0.cpu(), w1.cpu()) for w0, w1 in zip(state0.values(), state1.values())) - - -def is_timing_close(timings_torch, timings_fabric, rtol=1e-3, atol=1e-3): - # Drop measurements of the first iterations, as they may be slower than others - # The median is more robust to outliers than the mean - # Given relative and absolute tolerances, we want to satisfy: |torch – fabric| < RTOL * torch + ATOL - return torch.isclose(torch.median(timings_torch[3:]), torch.median(timings_fabric[3:]), rtol=rtol, atol=atol) - - -def is_cuda_memory_close(memory_stats_torch, memory_stats_fabric): - # We require Fabric's peak memory usage to be smaller or equal to that of PyTorch - return memory_stats_torch["allocated_bytes.all.peak"] >= memory_stats_fabric["allocated_bytes.all.peak"] - - -class TrackingMode(TorchDispatchMode): - """Tracks the calls made on all tensor operations.""" - - def __init__(self): - super().__init__() - self.calls = [] - - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - self.calls.append(f"{func.__module__}.{func.__name__}") - return func(*args, **kwargs) def make_deterministic(): @@ -62,3 +32,20 @@ def get_model_input_dtype(precision): elif precision in ("64-true", "64", 64): return torch.double return torch.float32 + + +def is_state_dict_equal(state0, state1): + # TODO: This should be torch.equal, but MPS does not yet support this operation (torch 1.12) + return all(torch.allclose(w0.cpu(), w1.cpu()) for w0, w1 in zip(state0.values(), state1.values())) + + +def is_timing_close(timings_torch, timings_fabric, rtol=1e-3, atol=1e-3): + # Drop measurements of the first iterations, as they may be slower than others + # The median is more robust to outliers than the mean + # Given relative and absolute tolerances, we want to satisfy: |torch – fabric| < RTOL * torch + ATOL + return torch.isclose(torch.median(timings_torch[3:]), torch.median(timings_fabric[3:]), rtol=rtol, atol=atol) + + +def is_cuda_memory_close(memory_stats_torch, memory_stats_fabric): + # We require Fabric's peak memory usage to be smaller or equal to that of PyTorch + return memory_stats_torch["allocated_bytes.all.peak"] >= memory_stats_fabric["allocated_bytes.all.peak"] From af7a7e4ad01635ab2ce486101fd36e114cfb6c44 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 1 Mar 2023 16:58:46 +0100 Subject: [PATCH 65/86] move assertions top --- tests/tests_fabric/parity/utils.py | 34 +++++++++++++++--------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/tests/tests_fabric/parity/utils.py b/tests/tests_fabric/parity/utils.py index 3672ae1adf6b6..3c310c4142b83 100644 --- a/tests/tests_fabric/parity/utils.py +++ b/tests/tests_fabric/parity/utils.py @@ -16,6 +16,23 @@ import torch +def is_state_dict_equal(state0, state1): + # TODO: This should be torch.equal, but MPS does not yet support this operation (torch 1.12) + return all(torch.allclose(w0.cpu(), w1.cpu()) for w0, w1 in zip(state0.values(), state1.values())) + + +def is_timing_close(timings_torch, timings_fabric, rtol=1e-3, atol=1e-3): + # Drop measurements of the first iterations, as they may be slower than others + # The median is more robust to outliers than the mean + # Given relative and absolute tolerances, we want to satisfy: |torch – fabric| < RTOL * torch + ATOL + return torch.isclose(torch.median(timings_torch[3:]), torch.median(timings_fabric[3:]), rtol=rtol, atol=atol) + + +def is_cuda_memory_close(memory_stats_torch, memory_stats_fabric): + # We require Fabric's peak memory usage to be smaller or equal to that of PyTorch + return memory_stats_torch["allocated_bytes.all.peak"] >= memory_stats_fabric["allocated_bytes.all.peak"] + + def make_deterministic(): os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" torch.use_deterministic_algorithms(True) @@ -32,20 +49,3 @@ def get_model_input_dtype(precision): elif precision in ("64-true", "64", 64): return torch.double return torch.float32 - - -def is_state_dict_equal(state0, state1): - # TODO: This should be torch.equal, but MPS does not yet support this operation (torch 1.12) - return all(torch.allclose(w0.cpu(), w1.cpu()) for w0, w1 in zip(state0.values(), state1.values())) - - -def is_timing_close(timings_torch, timings_fabric, rtol=1e-3, atol=1e-3): - # Drop measurements of the first iterations, as they may be slower than others - # The median is more robust to outliers than the mean - # Given relative and absolute tolerances, we want to satisfy: |torch – fabric| < RTOL * torch + ATOL - return torch.isclose(torch.median(timings_torch[3:]), torch.median(timings_fabric[3:]), rtol=rtol, atol=atol) - - -def is_cuda_memory_close(memory_stats_torch, memory_stats_fabric): - # We require Fabric's peak memory usage to be smaller or equal to that of PyTorch - return memory_stats_torch["allocated_bytes.all.peak"] >= memory_stats_fabric["allocated_bytes.all.peak"] From 3162108fadca0b444f00daa075f31edc2682ce1b Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 1 Mar 2023 22:46:45 +0100 Subject: [PATCH 66/86] reset cuda memory stats before test --- tests/tests_fabric/parity/test_parity_ddp.py | 8 ++++---- tests/tests_fabric/parity/test_parity_simple.py | 7 ++++--- tests/tests_fabric/parity/utils.py | 6 ++++++ 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index 830fd278c9961..4321aaf443e28 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -21,7 +21,7 @@ import torch.nn.functional from tests_fabric.helpers.runif import RunIf from tests_fabric.parity.models import ConvNet -from tests_fabric.parity.utils import is_cuda_memory_close, is_state_dict_equal, is_timing_close, make_deterministic +from tests_fabric.parity.utils import is_cuda_memory_close, is_state_dict_equal, is_timing_close, make_deterministic, cuda_reset from torch.nn.parallel.distributed import DistributedDataParallel from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler @@ -129,14 +129,14 @@ def train_fabric_ddp(fabric): ], ) def test_parity_ddp(accelerator, devices): + cuda_reset() + # Train with Fabric fabric = Fabric(accelerator=accelerator, strategy="ddp", devices=devices) fabric.launch() state_dict_fabric, timings_fabric, memory_fabric = train_fabric_ddp(fabric) - if accelerator == "cuda": - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() + cuda_reset() # Train with raw PyTorch state_dict_torch, timings_torch, memory_torch = train_torch_ddp( diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index 92d4510efee76..5ed3a19ee2b86 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -27,6 +27,7 @@ is_state_dict_equal, is_timing_close, make_deterministic, + cuda_reset, ) from lightning.fabric.fabric import Fabric @@ -127,13 +128,13 @@ def train_fabric(fabric): def test_parity_single_device(precision, accelerator): input_dtype = get_model_input_dtype(precision) + cuda_reset() + # Train with Fabric fabric = Fabric(precision=precision, accelerator=accelerator, devices=1) state_dict_fabric, timings_fabric, memory_fabric = train_fabric(fabric) - if accelerator == "cuda": - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() + cuda_reset() # Train with raw PyTorch state_dict_torch, timings_torch, memory_torch = train_torch( diff --git a/tests/tests_fabric/parity/utils.py b/tests/tests_fabric/parity/utils.py index 3c310c4142b83..83975e9c87743 100644 --- a/tests/tests_fabric/parity/utils.py +++ b/tests/tests_fabric/parity/utils.py @@ -49,3 +49,9 @@ def get_model_input_dtype(precision): elif precision in ("64-true", "64", 64): return torch.double return torch.float32 + + +def cuda_reset(): + torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() From 24c9f1fb0e209029adfe26bfe3f74ee336150f23 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Mar 2023 21:47:45 +0000 Subject: [PATCH 67/86] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_fabric/parity/test_parity_ddp.py | 8 +++++++- tests/tests_fabric/parity/test_parity_simple.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index 4321aaf443e28..4d106ae637327 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -21,7 +21,13 @@ import torch.nn.functional from tests_fabric.helpers.runif import RunIf from tests_fabric.parity.models import ConvNet -from tests_fabric.parity.utils import is_cuda_memory_close, is_state_dict_equal, is_timing_close, make_deterministic, cuda_reset +from tests_fabric.parity.utils import ( + cuda_reset, + is_cuda_memory_close, + is_state_dict_equal, + is_timing_close, + make_deterministic, +) from torch.nn.parallel.distributed import DistributedDataParallel from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index 5ed3a19ee2b86..c4593d4aaec9b 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -22,12 +22,12 @@ from tests_fabric.helpers.runif import RunIf from tests_fabric.parity.models import ConvNet from tests_fabric.parity.utils import ( + cuda_reset, get_model_input_dtype, is_cuda_memory_close, is_state_dict_equal, is_timing_close, make_deterministic, - cuda_reset, ) from lightning.fabric.fabric import Fabric From b59f51daf63a23b4eca7a3d5a22403859a63e792 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 1 Mar 2023 23:09:12 +0100 Subject: [PATCH 68/86] assertions across all devices --- tests/tests_fabric/parity/test_parity_ddp.py | 10 ++++++---- tests/tests_fabric/parity/utils.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index 4321aaf443e28..cba9e8cf1c1be 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -119,6 +119,7 @@ def train_fabric_ddp(fabric): return model.state_dict(), torch.tensor(iteration_timings), memory_stats +@pytest.mark.flaky(reruns=3) @RunIf(standalone=True) @pytest.mark.usefixtures("reset_deterministic_algorithm", "reset_cudnn_benchmark") @pytest.mark.parametrize( @@ -146,12 +147,13 @@ def test_parity_ddp(accelerator, devices): ) # Compare the final weights - assert is_state_dict_equal(state_dict_torch, state_dict_fabric) + assert all(fabric.all_gather(is_state_dict_equal(state_dict_torch, state_dict_fabric))) # Compare the time per iteration - assert is_timing_close(timings_torch, timings_fabric, rtol=1e-3, atol=1e-3) + assert all(fabric.all_gather(is_timing_close(timings_torch, timings_fabric, rtol=1e-3, atol=1e-3))) # Compare memory usage if accelerator == "cuda": - assert is_cuda_memory_close(memory_torch["start"], memory_fabric["start"]) - assert is_cuda_memory_close(memory_torch["end"], memory_fabric["end"]) + assert all(fabric.all_gather(is_cuda_memory_close(memory_torch["start"], memory_fabric["start"]))) + assert all(fabric.all_gather(is_cuda_memory_close(memory_torch["end"], memory_fabric["end"]))) + diff --git a/tests/tests_fabric/parity/utils.py b/tests/tests_fabric/parity/utils.py index 83975e9c87743..2dc48ded812ce 100644 --- a/tests/tests_fabric/parity/utils.py +++ b/tests/tests_fabric/parity/utils.py @@ -25,7 +25,7 @@ def is_timing_close(timings_torch, timings_fabric, rtol=1e-3, atol=1e-3): # Drop measurements of the first iterations, as they may be slower than others # The median is more robust to outliers than the mean # Given relative and absolute tolerances, we want to satisfy: |torch – fabric| < RTOL * torch + ATOL - return torch.isclose(torch.median(timings_torch[3:]), torch.median(timings_fabric[3:]), rtol=rtol, atol=atol) + return bool(torch.isclose(torch.median(timings_torch[3:]), torch.median(timings_fabric[3:]), rtol=rtol, atol=atol)) def is_cuda_memory_close(memory_stats_torch, memory_stats_fabric): From d42ab340043e73498884cb614b319589959d907a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Mar 2023 22:11:51 +0000 Subject: [PATCH 69/86] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_fabric/parity/test_parity_ddp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index 4cf7bfb53abfa..d1ff83b3e2db5 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -162,4 +162,3 @@ def test_parity_ddp(accelerator, devices): if accelerator == "cuda": assert all(fabric.all_gather(is_cuda_memory_close(memory_torch["start"], memory_fabric["start"]))) assert all(fabric.all_gather(is_cuda_memory_close(memory_torch["end"], memory_fabric["end"]))) - From 726d46288ad0236ffd66df075b300b35fb0f12e4 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 1 Mar 2023 23:41:28 +0100 Subject: [PATCH 70/86] slow cpu --- tests/tests_fabric/parity/test_parity_ddp.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index d1ff83b3e2db5..49952133af583 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -126,16 +126,16 @@ def train_fabric_ddp(fabric): @pytest.mark.flaky(reruns=3) -@RunIf(standalone=True) +# @RunIf(standalone=True) @pytest.mark.usefixtures("reset_deterministic_algorithm", "reset_cudnn_benchmark") @pytest.mark.parametrize( - "accelerator, devices", + "accelerator, devices, tolerance", [ - ("cpu", 2), - pytest.param("cuda", 2, marks=RunIf(min_cuda_gpus=2)), + ("cpu", 2, 0.005), + pytest.param("cuda", 2, 0.001, marks=RunIf(min_cuda_gpus=2)), ], ) -def test_parity_ddp(accelerator, devices): +def test_parity_ddp(accelerator, devices, tolerance): cuda_reset() # Train with Fabric @@ -156,7 +156,7 @@ def test_parity_ddp(accelerator, devices): assert all(fabric.all_gather(is_state_dict_equal(state_dict_torch, state_dict_fabric))) # Compare the time per iteration - assert all(fabric.all_gather(is_timing_close(timings_torch, timings_fabric, rtol=1e-3, atol=1e-3))) + assert all(fabric.all_gather(is_timing_close(timings_torch, timings_fabric, rtol=tolerance, atol=tolerance))) # Compare memory usage if accelerator == "cuda": From b7a82b5036800138da62ab3b7e7823a313279be3 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 2 Mar 2023 00:16:33 +0100 Subject: [PATCH 71/86] add requirement standalone --- requirements/fabric/test.txt | 1 + tests/tests_fabric/parity/test_parity_ddp.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements/fabric/test.txt b/requirements/fabric/test.txt index 26395bbdb958f..09bf112d03fcb 100644 --- a/requirements/fabric/test.txt +++ b/requirements/fabric/test.txt @@ -2,6 +2,7 @@ coverage==6.5.0 codecov==2.1.12 pytest==7.2.0 pytest-cov==4.0.0 +pytest-rerunfailures==10.3 pre-commit==2.20.0 click==8.1.3 tensorboardX>=2.2, <=2.5.1 # min version is set by torch.onnx missing attribute diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index 49952133af583..4d52e9ea1c1c3 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -126,7 +126,7 @@ def train_fabric_ddp(fabric): @pytest.mark.flaky(reruns=3) -# @RunIf(standalone=True) +@RunIf(standalone=True) @pytest.mark.usefixtures("reset_deterministic_algorithm", "reset_cudnn_benchmark") @pytest.mark.parametrize( "accelerator, devices, tolerance", From 1697904855db5aff60f66f472e4819d57a901d45 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 2 Mar 2023 00:17:53 +0100 Subject: [PATCH 72/86] tolerance --- tests/tests_fabric/parity/test_parity_ddp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index 4d52e9ea1c1c3..10bad8f77a7be 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -132,7 +132,7 @@ def train_fabric_ddp(fabric): "accelerator, devices, tolerance", [ ("cpu", 2, 0.005), - pytest.param("cuda", 2, 0.001, marks=RunIf(min_cuda_gpus=2)), + pytest.param("cuda", 2, 0.005, marks=RunIf(min_cuda_gpus=2)), ], ) def test_parity_ddp(accelerator, devices, tolerance): From 0c75321233c6a9ca19dae356851569c74a01ce3f Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 2 Mar 2023 00:30:31 +0100 Subject: [PATCH 73/86] bf16 skip windows --- tests/tests_fabric/parity/test_parity_simple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index c4593d4aaec9b..d89eecefe8fb1 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -120,7 +120,7 @@ def train_fabric(fabric): (32, "cpu"), pytest.param(32, "cuda", marks=RunIf(min_cuda_gpus=1)), # pytest.param(16, "cuda", marks=RunIf(min_cuda_gpus=1)), # TODO: requires GradScaler - pytest.param("bf16", "cpu"), + pytest.param("bf16", "cpu", marks=RunIf(skip_windows=True)), pytest.param("bf16", "cuda", marks=RunIf(min_cuda_gpus=1, bf16_cuda=True)), pytest.param(32, "mps", marks=RunIf(mps=True)), ], From e5c836ae94f5f31652a5648568e8845b93d17a86 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 3 Mar 2023 10:57:19 +0100 Subject: [PATCH 74/86] parity on cpu --- tests/tests_fabric/parity/test_parity_ddp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index 10bad8f77a7be..acc9371bd5055 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -131,7 +131,7 @@ def train_fabric_ddp(fabric): @pytest.mark.parametrize( "accelerator, devices, tolerance", [ - ("cpu", 2, 0.005), + ("cpu", 2, 0.01), pytest.param("cuda", 2, 0.005, marks=RunIf(min_cuda_gpus=2)), ], ) From a52061f8ea7c07c7a1b2f627d4e6db03a123b54f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Mar 2023 17:17:55 +0000 Subject: [PATCH 75/86] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_fabric/parity/test_parity_ddp.py | 10 +++++----- tests/tests_fabric/parity/test_parity_simple.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index acc9371bd5055..fc89a2159c70d 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -19,6 +19,11 @@ import torch import torch.distributed import torch.nn.functional +from torch.nn.parallel.distributed import DistributedDataParallel +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from lightning.fabric.fabric import Fabric from tests_fabric.helpers.runif import RunIf from tests_fabric.parity.models import ConvNet from tests_fabric.parity.utils import ( @@ -28,11 +33,6 @@ is_timing_close, make_deterministic, ) -from torch.nn.parallel.distributed import DistributedDataParallel -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler - -from lightning.fabric.fabric import Fabric def train_torch_ddp( diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py index d89eecefe8fb1..1e2d0ac6d52dd 100644 --- a/tests/tests_fabric/parity/test_parity_simple.py +++ b/tests/tests_fabric/parity/test_parity_simple.py @@ -19,6 +19,8 @@ import torch import torch.distributed import torch.nn.functional + +from lightning.fabric.fabric import Fabric from tests_fabric.helpers.runif import RunIf from tests_fabric.parity.models import ConvNet from tests_fabric.parity.utils import ( @@ -30,8 +32,6 @@ make_deterministic, ) -from lightning.fabric.fabric import Fabric - def train_torch( move_to_device: Callable, From 5970e8c874871d779e5a65bc56d625b6da104f0b Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 5 Mar 2023 00:46:29 +0100 Subject: [PATCH 76/86] update --- tests/tests_fabric/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py index f7dc91126250b..9b9c98de4dfeb 100644 --- a/tests/tests_fabric/conftest.py +++ b/tests/tests_fabric/conftest.py @@ -55,7 +55,6 @@ def restore_env_variables(): "POPLAR_ENGINE_OPTIONS", # set by IPUStrategy "CUDA_MODULE_LOADING", # leaked since PyTorch 1.13 "CRC32C_SW_MODE", # set by tensorboardX - "CUBLAS_WORKSPACE_CONFIG", # handled by the `reset_deterministic_algorithm` fixture below } leaked_vars.difference_update(allowlist) assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}" From c7f865b73dea6f24b48dea01251940df68620885 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 6 Mar 2023 08:59:51 -0500 Subject: [PATCH 77/86] Update tests/tests_fabric/parity/test_parity_ddp.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- tests/tests_fabric/parity/test_parity_ddp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index fc89a2159c70d..1b89b7f124057 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -44,7 +44,7 @@ def train_torch_ddp( memory_stats = {} os.environ["LOCAL_RANK"] = str(rank) - if torch.distributed.is_available() and not torch.distributed.is_initialized(): + if not torch.distributed.is_initialized(): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) model = ConvNet().to(device) From 1b82ed5197c9bd3fdcacb0a6ed2c87a934fc23c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 6 Mar 2023 09:00:35 -0500 Subject: [PATCH 78/86] Update tests/tests_fabric/conftest.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- tests/tests_fabric/conftest.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py index 9b9c98de4dfeb..bafc5a8e84fc4 100644 --- a/tests/tests_fabric/conftest.py +++ b/tests/tests_fabric/conftest.py @@ -79,9 +79,8 @@ def reset_deterministic_algorithm(): @pytest.fixture def reset_cudnn_benchmark(): """Ensures that the `torch.backends.cudnn.benchmark` setting gets reset before the next test runs.""" - benchmark = torch.backends.cudnn.benchmark yield - torch.backends.cudnn.benchmark = benchmark + torch.backends.cudnn.benchmark = False def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None: From 6221b28f066cd490245c3ec2d64f3439b92a6812 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 6 Mar 2023 09:04:11 -0500 Subject: [PATCH 79/86] Update tests/tests_fabric/parity/test_parity_ddp.py --- tests/tests_fabric/parity/test_parity_ddp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index 1b89b7f124057..6427491a5fbd9 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -50,7 +50,7 @@ def train_torch_ddp( model = ConvNet().to(device) initial_state_dict = deepcopy(model.state_dict()) - ddp_model = DistributedDataParallel(model.to(device), device_ids=([rank] if device.type == "cuda" else None)) + ddp_model = DistributedDataParallel(model, device_ids=([rank] if device.type == "cuda" else None)) dataloader = model.get_dataloader() sampler = DistributedSampler(dataloader.dataset, rank=rank, num_replicas=world_size, drop_last=False, shuffle=False) From f38c95d1df1ce73532bb6ec261ce7d284ed91ba6 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 6 Mar 2023 15:13:03 +0100 Subject: [PATCH 80/86] parametrize backend --- tests/tests_fabric/parity/test_parity_ddp.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index 6427491a5fbd9..49326e6310316 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -39,13 +39,13 @@ def train_torch_ddp( rank, world_size, device=torch.device("cpu"), + backend="nccl", ): make_deterministic() memory_stats = {} os.environ["LOCAL_RANK"] = str(rank) - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) + torch.distributed.init_process_group(backend, rank=rank, world_size=world_size) model = ConvNet().to(device) initial_state_dict = deepcopy(model.state_dict()) @@ -144,12 +144,14 @@ def test_parity_ddp(accelerator, devices, tolerance): state_dict_fabric, timings_fabric, memory_fabric = train_fabric_ddp(fabric) cuda_reset() + torch.distributed.destroy_process_group() # Train with raw PyTorch state_dict_torch, timings_torch, memory_torch = train_torch_ddp( rank=fabric.global_rank, world_size=fabric.world_size, device=fabric.device, + backend=fabric.strategy._process_group_backend, ) # Compare the final weights From c6a45c87407ece91a6d8762154cd2f5e1ee34d30 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 6 Mar 2023 15:16:02 +0100 Subject: [PATCH 81/86] use equality --- tests/tests_fabric/parity/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/tests_fabric/parity/utils.py b/tests/tests_fabric/parity/utils.py index 2dc48ded812ce..07b52b14c7360 100644 --- a/tests/tests_fabric/parity/utils.py +++ b/tests/tests_fabric/parity/utils.py @@ -14,11 +14,12 @@ import os import torch +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12 def is_state_dict_equal(state0, state1): - # TODO: This should be torch.equal, but MPS does not yet support this operation (torch 1.12) - return all(torch.allclose(w0.cpu(), w1.cpu()) for w0, w1 in zip(state0.values(), state1.values())) + eq_fn = torch.equal if _TORCH_GREATER_EQUAL_1_12 else torch.allclose + return all(eq_fn(w0.cpu(), w1.cpu()) for w0, w1 in zip(state0.values(), state1.values())) def is_timing_close(timings_torch, timings_fabric, rtol=1e-3, atol=1e-3): From 62a28ded9c054741e050135723a6f6fca12cd7ef Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 6 Mar 2023 15:19:00 +0100 Subject: [PATCH 82/86] add barrier --- tests/tests_fabric/parity/test_parity_ddp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index 49326e6310316..dbc1b0153a234 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -143,6 +143,7 @@ def test_parity_ddp(accelerator, devices, tolerance): fabric.launch() state_dict_fabric, timings_fabric, memory_fabric = train_fabric_ddp(fabric) + fabric.barrier() cuda_reset() torch.distributed.destroy_process_group() From f2ce10990991fa540fbbd5d1058591cfb023faa6 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 6 Mar 2023 15:21:42 +0100 Subject: [PATCH 83/86] comment about reusing processes --- tests/tests_fabric/parity/test_parity_ddp.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py index dbc1b0153a234..73933742ca069 100644 --- a/tests/tests_fabric/parity/test_parity_ddp.py +++ b/tests/tests_fabric/parity/test_parity_ddp.py @@ -138,9 +138,11 @@ def train_fabric_ddp(fabric): def test_parity_ddp(accelerator, devices, tolerance): cuda_reset() - # Train with Fabric + # Launch processes with Fabric and re-use them for the PyTorch training for convenience fabric = Fabric(accelerator=accelerator, strategy="ddp", devices=devices) fabric.launch() + + # Train with Fabric state_dict_fabric, timings_fabric, memory_fabric = train_fabric_ddp(fabric) fabric.barrier() From f0edebaac74797fe683d5f59ea1686e574ebc17e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Mar 2023 14:21:51 +0000 Subject: [PATCH 84/86] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_fabric/parity/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tests_fabric/parity/utils.py b/tests/tests_fabric/parity/utils.py index 07b52b14c7360..9f45032d2da57 100644 --- a/tests/tests_fabric/parity/utils.py +++ b/tests/tests_fabric/parity/utils.py @@ -14,6 +14,7 @@ import os import torch + from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12 From 4964d75cc6082746ce1a692261734c46fd87e8b7 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 6 Mar 2023 18:17:52 +0100 Subject: [PATCH 85/86] use the utility to clear cuda cache --- tests/tests_fabric/parity/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/tests_fabric/parity/utils.py b/tests/tests_fabric/parity/utils.py index 9f45032d2da57..4726b45e75c73 100644 --- a/tests/tests_fabric/parity/utils.py +++ b/tests/tests_fabric/parity/utils.py @@ -15,6 +15,7 @@ import torch +from lightning.fabric.accelerators.cuda import _clear_cuda_memory from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12 @@ -54,6 +55,6 @@ def get_model_input_dtype(precision): def cuda_reset(): - torch.cuda.empty_cache() + _clear_cuda_memory() if torch.cuda.is_available(): torch.cuda.reset_peak_memory_stats() From be6eda7663306787c257ad6e8ab04a0fb75dc7ca Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 6 Mar 2023 18:37:16 +0100 Subject: [PATCH 86/86] guard --- tests/tests_fabric/parity/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_fabric/parity/utils.py b/tests/tests_fabric/parity/utils.py index 4726b45e75c73..0248c036f76f0 100644 --- a/tests/tests_fabric/parity/utils.py +++ b/tests/tests_fabric/parity/utils.py @@ -55,6 +55,6 @@ def get_model_input_dtype(precision): def cuda_reset(): - _clear_cuda_memory() if torch.cuda.is_available(): + _clear_cuda_memory() torch.cuda.reset_peak_memory_stats()