Add tests that run benchmarking/profiling #133

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 20 additions & 88 deletions project/algorithms/callbacks/samples_per_second.py
@@ -1,5 +1,5 @@
 import time
-from typing import Any, Generic, Literal
+from typing import Generic, Literal
 
 import lightning
 import optree
@@ -20,6 +20,8 @@
 
 
 class MeasureSamplesPerSecondCallback(lightning.Callback, Generic[BatchType]):
+    """Callback that measures the number of samples processed per second during train/val/test."""
+
     def __init__(self, num_optimizers: int | None = None):
         super().__init__()
         self.last_step_times: dict[Literal["train", "val", "test"], float] = {}
@@ -28,17 +30,14 @@ def __init__(self, num_optimizers: int | None = None):
 
     @override
     def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
-        super().on_train_epoch_start(trainer, pl_module)
         self.on_shared_epoch_start(trainer, pl_module, phase="train")
 
     @override
     def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
-        super().on_validation_epoch_start(trainer, pl_module)
         self.on_shared_epoch_start(trainer, pl_module, phase="val")
 
     @override
     def on_test_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
-        super().on_test_epoch_start(trainer, pl_module)
         self.on_shared_epoch_start(trainer, pl_module, phase="test")
 
     def on_shared_epoch_start(
@@ -59,94 +58,43 @@ def on_shared_epoch_start(
     @override
     def on_train_batch_end(
         self,
-        trainer: Trainer,
+        trainer: Trainer | None,
         pl_module: LightningModule,
-        outputs: STEP_OUTPUT,
+        outputs: STEP_OUTPUT | None,
         batch: BatchType,
-        batch_idx: int,
+        batch_idx: int | None,
     ) -> None:
-        super().on_train_batch_end(
-            trainer=trainer,
-            pl_module=pl_module,
-            outputs=outputs,
-            batch=batch,
-            batch_idx=batch_idx,
-        )
-        self.on_shared_batch_end(
-            trainer=trainer,
-            pl_module=pl_module,
-            outputs=outputs,
-            batch=batch,
-            batch_index=batch_idx,
-            phase="train",
-        )
+        self.on_shared_batch_end(pl_module=pl_module, batch=batch, phase="train")
 
     @override
     def on_validation_batch_end(
         self,
-        trainer: Trainer,
+        trainer: Trainer | None,
         pl_module: LightningModule,
-        outputs: STEP_OUTPUT,
+        outputs: STEP_OUTPUT | None,
         batch: BatchType,
-        batch_idx: int,
-        dataloader_idx: int = 0,
+        batch_idx: int | None,
+        dataloader_idx: int | None = 0,
     ) -> None:
-        super().on_validation_batch_end(
-            trainer=trainer,
-            pl_module=pl_module,
-            outputs=outputs,  # type: ignore
-            batch=batch,
-            batch_idx=batch_idx,
-            dataloader_idx=dataloader_idx,
-        )
-        self.on_shared_batch_end(
-            trainer=trainer,
-            pl_module=pl_module,
-            outputs=outputs,
-            batch=batch,
-            batch_index=batch_idx,
-            phase="val",
-            dataloader_idx=dataloader_idx,
-        )
+        self.on_shared_batch_end(pl_module=pl_module, batch=batch, phase="val")
 
     @override
     def on_test_batch_end(
         self,
-        trainer: Trainer,
+        trainer: Trainer | None,
         pl_module: LightningModule,
-        outputs: STEP_OUTPUT,
+        outputs: STEP_OUTPUT | None,
         batch: BatchType,
-        batch_idx: int,
-        dataloader_idx: int = 0,
+        batch_idx: int | None = None,
+        dataloader_idx: int | None = 0,
     ) -> None:
-        super().on_test_batch_end(
-            trainer=trainer,
-            pl_module=pl_module,
-            outputs=outputs,  # type: ignore
-            batch=batch,
-            batch_idx=batch_idx,
-            dataloader_idx=dataloader_idx,
-        )
-        self.on_shared_batch_end(
-            trainer=trainer,
-            pl_module=pl_module,
-            outputs=outputs,
-            batch=batch,
-            batch_index=batch_idx,
-            dataloader_idx=dataloader_idx,
-            phase="test",
-        )
+        self.on_shared_batch_end(pl_module=pl_module, batch=batch, phase="test")
 
     def on_shared_batch_end(
-        self,
-        trainer: Trainer,
-        pl_module: LightningModule,
-        outputs: STEP_OUTPUT,
-        batch: BatchType,
-        batch_index: int,
-        phase: Literal["train", "val", "test"],
-        dataloader_idx: int | None = None,
+        self, pl_module: LightningModule, batch: BatchType, phase: Literal["train", "val", "test"]
     ):
         # Note: Not using CUDA events here, since we just want a rough throughput estimate,
         # and we assume that there's at least one synchronize call at each step.
         now = time.perf_counter()
         if phase in self.last_step_times:
             elapsed = now - self.last_step_times[phase]
@@ -165,22 +113,6 @@ def on_shared_batch_end(
             # todo: support other kinds of batches
         self.last_step_times[phase] = now
 
-    def log(
-        self,
-        name: str,
-        value: Any,
-        module: LightningModule | Any,
-        trainer: Trainer | Any,
-        **kwargs,
-    ):
-        # Used to possibly customize how the values are logged (e.g. for non-LightningModules).
-        # By default, uses the LightningModule.log method.
-        return module.log(
-            name,
-            value,
-            **kwargs,
-        )
-
     def get_num_samples(self, batch: BatchType) -> int:
         if isinstance(batch, Tensor):
             return batch.shape[0]
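
A note for reviewers trying out the slimmed-down callback: a minimal sketch of how it would be attached, assuming the default `num_optimizers=None` suffices (the logged metric names come from the parts of `on_shared_batch_end` not shown in this hunk):

import lightning

from project.algorithms.callbacks.samples_per_second import MeasureSamplesPerSecondCallback

# Throughput is estimated from time.perf_counter() deltas between consecutive
# batch ends within each phase (train/val/test), as the comment above notes.
trainer = lightning.Trainer(max_epochs=1, callbacks=[MeasureSamplesPerSecondCallback()])
# trainer.fit(some_lightning_module, datamodule=some_datamodule)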
49 changes: 36 additions & 13 deletions project/algorithms/image_classifier.py
@@ -19,8 +19,9 @@
 from torch import Tensor
 from torch.nn import functional as F
 from torch.optim.optimizer import Optimizer
+from typing_extensions import override
 
-from project.algorithms.callbacks.classification_metrics import ClassificationMetricsCallback
+from project.algorithms.callbacks.samples_per_second import MeasureSamplesPerSecondCallback
 from project.datamodules.image_classification.image_classification import (
     ImageClassificationDataModule,
 )
@@ -35,8 +36,8 @@ class ImageClassifier(LightningModule):
     def __init__(
         self,
         datamodule: ImageClassificationDataModule,
-        network: HydraConfigFor[torch.nn.Module],
-        optimizer: HydraConfigFor[functools.partial[Optimizer]],
+        network: torch.nn.Module | HydraConfigFor[torch.nn.Module],
+        optimizer: functools.partial[Optimizer] | HydraConfigFor[functools.partial[Optimizer]],
         init_seed: int = 42,
     ):
         """Create a new instance of the algorithm.
@@ -46,10 +47,12 @@ def __init__(
                 See the lightning docs for [LightningDataModule][lightning.pytorch.core.datamodule.LightningDataModule]
                 for more info.
             network:
-                The config of the network to instantiate and train.
-            optimizer: The config for the Optimizer. Instantiating this will return a function \
-                (a [functools.partial][]) that will create the Optimizer given the hyper-parameters.
-            init_seed: The seed to use when initializing the weights of the network.
+                The network to instantiate and train, or a Hydra config that returns a network \
+                when instantiated.
+            optimizer: A function that returns an optimizer given parameters, or a Hydra config \
+                that creates such a function when instantiated.
+            init_seed: The seed to set while instantiating the network from its config. This only \
+                has an effect if the network is a Hydra config, and not an already instantiated network.
         """
         super().__init__()
         self.datamodule = datamodule
@@ -58,10 +61,20 @@ def __init__(
         self.init_seed = init_seed
 
         # Save hyper-parameters.
-        self.save_hyperparameters(ignore=["datamodule"])
+        self.save_hyperparameters(
+            ignore=["datamodule"]
+            # Ignore those if they are already instantiated objects, otherwise lightning will try
+            # to serialize them to yaml, which will be very slow and may fail.
+            + (["network"] if isinstance(network, torch.nn.Module) else [])
+            + (["optimizer"] if isinstance(optimizer, functools.partial) else [])
+        )
         # Used by Pytorch-Lightning to compute the input/output shapes of the network.
 
-        self.network: torch.nn.Module | None = None
+        self.network: torch.nn.Module | None = (
+            network if isinstance(network, torch.nn.Module) else None
+        )
+        self.logits_pinned: torch.Tensor | None = None  # type: Tensor | None
+        self.labels_pinned: torch.Tensor | None = None  # type: Tensor | None
 
     def configure_model(self):
         # Save this for PyTorch-Lightning to infer the input/output shapes of the network.
@@ -84,12 +97,15 @@ def forward(self, input: Tensor) -> Tensor:
         logits = self.network(input)
         return logits
 
+    @override
     def training_step(self, batch: tuple[Tensor, Tensor], batch_index: int):
         return self.shared_step(batch, batch_index=batch_index, phase="train")
 
+    @override
     def validation_step(self, batch: tuple[Tensor, Tensor], batch_index: int):
         return self.shared_step(batch, batch_index=batch_index, phase="val")
 
+    @override
     def test_step(self, batch: tuple[Tensor, Tensor], batch_index: int):
         return self.shared_step(batch, batch_index=batch_index, phase="test")
 
@@ -102,8 +118,9 @@ def shared_step(
         x, y = batch
         logits: torch.Tensor = self(x)
         loss = F.cross_entropy(logits, y, reduction="mean")
-        self.log(f"{phase}/loss", loss.detach().mean())
+        loss_mean = loss.detach().mean()
         acc = logits.detach().argmax(-1).eq(y).float().mean()
+        self.log(f"{phase}/loss", loss_mean)
         self.log(f"{phase}/accuracy", acc)
         return {"loss": loss, "logits": logits, "y": y}
 
@@ -112,8 +129,11 @@ def configure_optimizers(self):
 
         See [`lightning.pytorch.core.LightningModule.configure_optimizers`][] for more information.
         """
-        # Instantiate the optimizer config into a functools.partial object.
-        optimizer_partial = hydra_zen.instantiate(self.optimizer_config)
+        if isinstance(self.optimizer_config, functools.partial):
+            optimizer_partial = self.optimizer_config
+        else:
+            # Instantiate the optimizer config into a functools.partial object.
+            optimizer_partial = hydra_zen.instantiate(self.optimizer_config)
         # Call the functools.partial object, passing the parameters as an argument.
         optimizer = optimizer_partial(self.parameters())
         # This then returns the optimizer.
@@ -122,5 +142,8 @@ def configure_optimizers(self):
     def configure_callbacks(self) -> Sequence[Callback] | Callback:
         """Creates callbacks to be used by default during training."""
         return [
-            ClassificationMetricsCallback.attach_to(self, num_classes=self.datamodule.num_classes)
+            MeasureSamplesPerSecondCallback(),
+            # Uncomment to log top_k accuracy metrics.
+            # Note that with small models, this may cause some slowdown during training.
+            # ClassificationMetricsCallback.attach_to(self, num_classes=self.datamodule.num_classes)
         ]
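
Since `network` and `optimizer` now accept either an instantiated object or a Hydra config, a quick sketch of the two construction paths (untested; `resnet18` and `Adam` are stand-ins for whatever the configs would normally produce, and `datamodule` is assumed to be an existing `ImageClassificationDataModule`):

import functools

import torchvision
from torch.optim import Adam

from project.algorithms.image_classifier import ImageClassifier

# Path 1 (new): pass already-instantiated objects. save_hyperparameters() now
# skips them, and configure_optimizers() uses the functools.partial as-is.
algo = ImageClassifier(
    datamodule=datamodule,
    network=torchvision.models.resnet18(num_classes=10),
    optimizer=functools.partial(Adam, lr=1e-3),
)

# Path 2 (unchanged): pass Hydra configs; the network is then built in
# configure_model() and the optimizer partial via hydra_zen.instantiate().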
69 changes: 69 additions & 0 deletions project/algorithms/image_classifier_test.py
@@ -1,7 +1,15 @@
 """Example showing how the test suite can be used to add tests for a new algorithm."""
 
+import logging
+
+import lightning
+import lightning.pytorch
+import lightning.pytorch.loggers
+import lightning.pytorch.profilers
 import pytest
 import torch
+import wandb
+from pytest_benchmark.fixture import BenchmarkFixture
 
 from project.algorithms.lightning_module_tests import LightningModuleTests
 from project.configs import Config
@@ -16,6 +24,8 @@
 
 from .image_classifier import ImageClassifier
 
+logger = logging.getLogger(__name__)
+
 experiment_commands_to_test.extend(
     [
         "experiment=example trainer.fast_dev_run=True",
@@ -99,3 +109,62 @@ class TestImageClassifier(LightningModuleTests[ImageClassifier]):
 
     Take a look at the `LightningModuleTests` class if you want to see the actual test code.
     """
+
+    @pytest.mark.slow
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(),
+        reason="Needs a GPU to run this test quickly.",
+    )
+    def test_benchmark_fit_speed(
+        self,
+        algorithm: ImageClassifier,
+        datamodule: ImageClassificationDataModule,
+        tmp_path_factory: pytest.TempPathFactory,
+        benchmark: BenchmarkFixture,
+    ):
+        """Runs a few training steps a few times to compare wall-clock time between revisions.
+
+        This uses [`pytest-benchmark`](https://pytest-benchmark.readthedocs.io/en/latest/index.html)
+        to measure the time it takes to run a few training steps.
+        """
+        # NOTE: Here we run this test with all the datamodules and networks that are parametrized
+        # on the class. If you wanted to run this test outside of this repo or with a specific
+        # datamodule or network, you could simply do this directly:
+        # from torch.optim import Adam  # type: ignore
+        # datamodule = CIFAR10DataModule(data_dir=DATA_DIR, batch_size=64)
+        # algo = ImageClassifier(
+        #     datamodule=datamodule,
+        #     network=torchvision.models.resnet18(weights=None, num_classes=datamodule.num_classes),
+        #     optimizer=functools.partial(Adam, lr=1e-3, weight_decay=1e-4),
+        # ).cuda()
+
+        if datamodule is not None:
+            # Do the data preparation ahead of time.
+            datamodule.prepare_data()
+
+        def run_some_training_steps() -> float:
+            run_dir = tmp_path_factory.mktemp("benchmark_training_speed")
+            trainer = lightning.Trainer(
+                max_epochs=2,
+                limit_train_batches=10,
+                limit_val_batches=2,
+                num_sanity_val_steps=0,
+                log_every_n_steps=2,  # Benchmark with or without logging?
+                logger=[
+                    # lightning.pytorch.loggers.TensorBoardLogger(run_dir),
+                    lightning.pytorch.loggers.WandbLogger(save_dir=run_dir, mode="offline"),
+                ],
+                devices=1,
+                accelerator="auto",
+                default_root_dir=run_dir,
+            )
+            logger.info(f"Trainer log dir: {trainer.log_dir}")
+            trainer.fit(algorithm, datamodule=algorithm.datamodule)
+            wandb.finish()  # just to make sure that the logging happens the same way in all runs.
+            train_metrics = trainer.logged_metrics
+            assert isinstance(train_metrics, dict)
+            train_acc = train_metrics["train/accuracy"]
+            assert isinstance(train_acc, torch.Tensor)
+            return train_acc.item()
+
+        benchmark(run_some_training_steps)
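
To reproduce the numbers locally: something along the lines of `pytest project/algorithms/image_classifier_test.py -k test_benchmark_fit_speed --benchmark-autosave` should run it (plus whatever flag this repo uses to opt into `slow`-marked tests, and a CUDA device so the `skipif` doesn't trigger), and `pytest-benchmark compare` then diffs the autosaved runs across revisions. `--benchmark-autosave` and the `pytest-benchmark` CLI are standard pytest-benchmark features.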
11 changes: 10 additions & 1 deletion project/algorithms/lightning_module_tests.py
@@ -14,6 +14,8 @@
 from typing import Any, Generic, Literal, TypeVar, overload
 
 import lightning
+import lightning.pytorch
+import lightning.pytorch.profilers
 import pytest
 import torch
 from lightning import LightningModule
@@ -243,10 +245,17 @@ def do_one_step_of_training(
         accelerator=accelerator,
         callbacks=callbacks,
         devices=devices,
-        fast_dev_run=True,
+        fast_dev_run=3,
         enable_checkpointing=False,
         deterministic=True,
         default_root_dir=tmp_path,
+        # todo: include pytorch profiler here?
+        profiler=lightning.pytorch.profilers.PyTorchProfiler(
+            profile_memory=True,
+            record_shapes=True,
+            record_module_names=True,
+            schedule=torch.profiler.schedule(wait=0, warmup=1, active=2),
+        ),
     )
     trainer.fit(algorithm, datamodule=datamodule)
     return callbacks
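
One detail worth calling out: `fast_dev_run=3` and the profiler schedule line up. With `wait=0, warmup=1, active=2`, step 0 warms the profiler up and steps 1-2 are recorded, so the three fast-dev-run batches cover exactly one profiling cycle. A quick sketch of the schedule semantics as I read the `torch.profiler` docs (not code from this PR):

import torch

sched = torch.profiler.schedule(wait=0, warmup=1, active=2)
for step in range(3):
    print(step, sched(step))
# Expected (annotations mine):
# 0 ProfilerAction.WARMUP           <- profiler on, results discarded
# 1 ProfilerAction.RECORD
# 2 ProfilerAction.RECORD_AND_SAVE  <- last active step flushes the trace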