Add the sparseml callback #724

Merged: 8 commits, Sep 5, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added Torch ORT Callback ([#720](https://github.com/PyTorchLightning/lightning-bolts/pull/720))


- Added SparseML Callback ([#724](https://github.com/PyTorchLightning/lightning-bolts/pull/724))


### Changed

- Changed the default values `pin_memory=False`, `shuffle=False` and `num_workers=16` to `pin_memory=True`, `shuffle=True` and `num_workers=0` of datamodules ([#701](https://github.com/PyTorchLightning/lightning-bolts/pull/701))
55 changes: 55 additions & 0 deletions docs/source/callbacks/sparseml.rst
@@ -0,0 +1,55 @@
=================
SparseML Callback
=================

`SparseML <https://docs.neuralmagic.com/sparseml/>`__ allows you to leverage sparsity to improve inference times substantially.

SparseML requires you to fine-tune your model with the ``SparseMLCallback`` + a SparseML Recipe. By training with the ``SparseMLCallback``, you can leverage the `DeepSparse <https://github.com/neuralmagic/deepsparse>`__ engine to exploit the introduced sparsity, resulting in large performance improvements.

.. warning::

    The SparseML callback requires the model to be ONNX exportable. This can be tricky when the model requires dynamic sequence lengths such as RNNs.

To leverage SparseML & DeepSparse, follow the steps below:

1. Choose your Sparse Recipe
----------------------------

To choose a recipe, have a look at `recipes <https://docs.neuralmagic.com/sparseml/source/recipes.html>`__ and `Sparse Zoo <https://docs.neuralmagic.com/sparsezoo/>`__.

It may be easier to infer a recipe via the UI dashboard using `Sparsify <https://github.com/neuralmagic/sparsify>`__, which allows you to tweak and configure a recipe.
This requires importing an ONNX model, which you can export from your ``LightningModule`` by calling ``model.to_onnx(output_path)``.
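
A minimal sketch of that export (``MyModel`` stands in for your own ``LightningModule``, and the output path is arbitrary; ``to_onnx`` falls back to ``example_input_array`` when no ``input_sample`` is given):

.. code-block:: python

    model = MyModel()
    # Write an ONNX file that can then be imported into Sparsify.
    model.to_onnx("model.onnx")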

2. Train with SparseMLCallback
------------------------------

.. code-block:: python

    from pytorch_lightning import LightningModule, Trainer
    from pl_bolts.callbacks import SparseMLCallback


    model = MyModel()

    trainer = Trainer(
        callbacks=SparseMLCallback(recipe_path='recipe.yaml')
    )
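
With the callback attached, training proceeds as usual; on fit start the callback wraps the optimizer with the recipe's schedule. A minimal sketch (assuming ``MyModel`` defines its own ``train_dataloader``):

.. code-block:: python

    # The SparseML modifiers from ``recipe.yaml`` are applied during fitting.
    trainer.fit(model)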

3. Export to ONNX!
------------------

Using the helper function, we handle any quantization/pruning internally and export the model into ONNX format.
Note that this assumes you have either implemented the ``example_input_array`` property on the model or that you provide a sample batch, as shown below.

.. code-block:: python

    import torch

    # export the onnx model, using the `model.example_input_array`
    SparseMLCallback.export_to_sparse_onnx(model, 'onnx_export/')

    # export the onnx model, providing a sample batch
    SparseMLCallback.export_to_sparse_onnx(model, 'onnx_export/', sample_batch=torch.randn(1, 128, 128, dtype=torch.float32))


Once your model has been exported, you can import this into either `Sparsify <https://github.com/neuralmagic/sparsify>`__ or `DeepSparse <https://github.com/neuralmagic/deepsparse>`__.
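
A rough sketch of running the exported model with DeepSparse (``compile_model`` is the DeepSparse Python entry point; the export path and input shape below are assumptions based on the example above, so check the files your export actually produced):

.. code-block:: python

    import numpy as np
    from deepsparse import compile_model

    # Compile the sparse ONNX export into a DeepSparse engine.
    engine = compile_model("onnx_export/model.onnx", batch_size=1)
    # DeepSparse engines take a list of numpy arrays as input.
    outputs = engine.run([np.random.randn(1, 128, 128).astype(np.float32)])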
2 changes: 2 additions & 0 deletions pl_bolts/callbacks/__init__.py
@@ -2,6 +2,7 @@
from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate
from pl_bolts.callbacks.data_monitor import ModuleDataMonitor, TrainingDataMonitor
from pl_bolts.callbacks.printing import PrintTableMetricsCallback
from pl_bolts.callbacks.sparseml import SparseMLCallback
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
from pl_bolts.callbacks.torch_ort import ORTCallback
from pl_bolts.callbacks.variational import LatentDimInterpolator
@@ -20,4 +21,5 @@
"ConfusedLogitCallback",
"TensorboardGenerativeModelImageSampler",
"ORTCallback",
"SparseMLCallback",
]
90 changes: 90 additions & 0 deletions pl_bolts/callbacks/sparseml.py
@@ -0,0 +1,90 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import torch
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from pl_bolts.utils import _SPARSEML_AVAILABLE

if _SPARSEML_AVAILABLE:
    from sparseml.pytorch.optim import ScheduledModifierManager
    from sparseml.pytorch.utils import ModuleExporter


class SparseMLCallback(Callback):
    """Enables SparseML aware training. Requires a recipe to run during training.

    Args:
        recipe_path: Path to a SparseML compatible yaml recipe.
            More information at https://docs.neuralmagic.com/sparseml/source/recipes.html
    """

    def __init__(self, recipe_path):
        if not _SPARSEML_AVAILABLE:
            raise MisconfigurationException("SparseML has not been installed, install with pip install sparseml")
        self.manager = ScheduledModifierManager.from_yaml(recipe_path)

    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        optimizer = trainer.optimizers

        if len(optimizer) > 1:
            raise MisconfigurationException("SparseML only supports training with one optimizer.")
        optimizer = optimizer[0]
        optimizer = self.manager.modify(
            pl_module, optimizer, steps_per_epoch=self._num_training_steps_per_epoch(trainer), epoch=0
        )
        trainer.optimizers = [optimizer]

    def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self.manager.finalize(pl_module)

    def _num_training_steps_per_epoch(self, trainer: Trainer) -> int:
        """Total training steps inferred from the datamodule and devices."""
        if isinstance(trainer.limit_train_batches, int) and trainer.limit_train_batches != 0:
            dataset_size = trainer.limit_train_batches
        elif isinstance(trainer.limit_train_batches, float):
            # limit_train_batches is a percentage of batches
            dataset_size = len(trainer.datamodule.train_dataloader())
            dataset_size = int(dataset_size * trainer.limit_train_batches)
        else:
            dataset_size = len(trainer.datamodule.train_dataloader())

        num_devices = max(1, trainer.num_gpus, trainer.num_processes)
        if trainer.tpu_cores:
            num_devices = max(num_devices, trainer.tpu_cores)

        effective_batch_size = trainer.accumulate_grad_batches * num_devices
        max_estimated_steps = dataset_size // effective_batch_size

        if trainer.max_steps and trainer.max_steps < max_estimated_steps:
            return trainer.max_steps
        return max_estimated_steps

    @staticmethod
    def export_to_sparse_onnx(
        model: LightningModule, output_dir: str, sample_batch: Optional[torch.Tensor] = None
    ) -> None:
        """Exports the model to ONNX format."""
        with model._prevent_trainer_and_dataloaders_deepcopy():
            exporter = ModuleExporter(model, output_dir=output_dir)
            sample_batch = sample_batch if sample_batch is not None else model.example_input_array
            if sample_batch is None:
                raise MisconfigurationException(
                    "To export the model, a sample batch must be passed via "
                    "``SparseMLCallback.export_to_sparse_onnx(model, output_dir, sample_batch=sample_batch)`` "
                    "or an ``example_input_array`` property within the LightningModule"
                )
            exporter.export_onnx(sample_batch=sample_batch)
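
Not part of the diff above, just a worked example of the ``_num_training_steps_per_epoch`` arithmetic with made-up numbers:

    dataset_size = 1000              # batches yielded by the train dataloader
    accumulate_grad_batches = 4
    num_devices = 2                  # e.g. two GPUs
    effective_batch_size = accumulate_grad_batches * num_devices  # 8
    steps_per_epoch = dataset_size // effective_batch_size        # 125 optimizer steps per epoch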
1 change: 1 addition & 0 deletions pl_bolts/utils/__init__.py
@@ -40,5 +40,6 @@ def _compare_version(package: str, op, version) -> bool:
_TORCHVISION_LESS_THAN_0_9_1: bool = _compare_version("torchvision", operator.lt, "0.9.1")
_PL_GREATER_EQUAL_1_4 = _compare_version("pytorch_lightning", operator.ge, "1.4.0")
_TORCH_ORT_AVAILABLE = _module_available("torch_ort")
_SPARSEML_AVAILABLE = _module_available("sparseml")

__all__ = ["BatchGradientVerification"]
1 change: 1 addition & 0 deletions requirements/test.txt
@@ -12,3 +12,4 @@ mypy>=0.790

atari-py==0.2.6 # needed for RL
scikit-learn>=0.23
sparseml
104 changes: 104 additions & 0 deletions tests/callbacks/test_sparseml.py
@@ -0,0 +1,104 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pathlib import Path

import pytest
import torch
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from pl_bolts.callbacks import SparseMLCallback
from pl_bolts.utils import _SPARSEML_AVAILABLE
from tests.helpers.boring_model import BoringModel

if _SPARSEML_AVAILABLE:
    from sparseml.pytorch.optim import RecipeManagerStepWrapper


@pytest.fixture
def recipe():
return """
version: 0.1.0
modifiers:
- !EpochRangeModifier
start_epoch: 0.0
end_epoch: 1.0

- !LearningRateModifier
start_epoch: 0
end_epoch: -1.0
update_frequency: -1.0
init_lr: 0.005
lr_class: MultiStepLR
lr_kwargs: {'milestones': [43, 60], 'gamma': 0.1}

- !GMPruningModifier
start_epoch: 0
end_epoch: 40
update_frequency: 1.0
init_sparsity: 0.05
final_sparsity: 0.85
mask_type: unstructured
params: __ALL__
"""


@pytest.mark.skipif(not _SPARSEML_AVAILABLE, reason="SparseML isn't installed.")
def test_train_sparse_ml_callback(tmpdir, recipe):
    class TestCallback(Callback):
        def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
            assert isinstance(trainer.optimizers[0], RecipeManagerStepWrapper)

    recipe_path = Path(tmpdir) / "recipe.yaml"
    recipe_path.write_text(recipe)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        callbacks=[SparseMLCallback(recipe_path=str(recipe_path)), TestCallback()],
    )
    trainer.fit(model)

    sample_batch = torch.randn(1, 32)
    output_dir = Path(tmpdir) / "model_export/"
    SparseMLCallback.export_to_sparse_onnx(model, output_dir, sample_batch=sample_batch)
    assert os.path.exists(output_dir)


@pytest.mark.skipif(not _SPARSEML_AVAILABLE, reason="SparseML isn't installed.")
def test_fail_if_no_example_input_array_or_sample_batch(tmpdir, recipe):
    model = BoringModel()
    with pytest.raises(MisconfigurationException, match="To export the model, a sample batch must be passed"):
        output_dir = Path(tmpdir) / "model_export/"
        SparseMLCallback.export_to_sparse_onnx(model, output_dir)


@pytest.mark.skipif(not _SPARSEML_AVAILABLE, reason="SparseML isn't installed.")
def test_fail_if_multiple_optimizers(tmpdir, recipe):
    recipe_path = Path(tmpdir) / "recipe.yaml"
    recipe_path.write_text(recipe)

    class TestModel(BoringModel):
        def configure_optimizers(self):
            return [torch.optim.Adam(self.parameters()), torch.optim.Adam(self.parameters())], []

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir, fast_dev_run=True, callbacks=[SparseMLCallback(recipe_path=str(recipe_path))]
    )
    with pytest.raises(MisconfigurationException, match="SparseML only supports training with one optimizer."):
        trainer.fit(model)