This repository has been archived by the owner on Nov 21, 2022. It is now read-only.

sparseml integration #197

Merged
merged 74 commits on Sep 30, 2021
Changes from 25 commits
74 commits
a5b499f
initial commit for sparseml callback
mathemusician Sep 14, 2021
c1864de
typos
mathemusician Sep 14, 2021
0ae5818
typo
mathemusician Sep 14, 2021
5c6ce32
path for models
mathemusician Sep 14, 2021
b13aa51
move sparseml as callback
mathemusician Sep 14, 2021
4ceedd6
simplified change to one file for sparseml callback
mathemusician Sep 14, 2021
8d94e63
small bug with import sparseml
mathemusician Sep 14, 2021
1b5fc82
test change
mathemusician Sep 14, 2021
40bf8e3
override callback with sparseml
mathemusician Sep 15, 2021
6e5010b
revert to not overriding callback
mathemusician Sep 15, 2021
2532c8e
added default
mathemusician Sep 15, 2021
487bbe6
added self to defaults
mathemusician Sep 15, 2021
54e7934
still trying to get hydra to play nice
mathemusician Sep 15, 2021
276706b
finally got sparseml to work
mathemusician Sep 15, 2021
061ceba
added sparseml to plugins
mathemusician Sep 15, 2021
2f88b52
sparseml can now appropriately handle large inputs
mathemusician Sep 16, 2021
db4097d
wandb logging
mathemusician Sep 16, 2021
8cf2dd6
logging is a lot more complicated than I thought it would be
mathemusician Sep 16, 2021
a50ab24
finally got loggin to work
mathemusician Sep 16, 2021
7db8793
typo
mathemusician Sep 16, 2021
31244ff
reformatted with black, add export to small model
mathemusician Sep 16, 2021
1ca7488
allow saving of full model
mathemusician Sep 16, 2021
d3d77f1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 17, 2021
e35884f
get rid of automatic addition of sparseml from deepspeed and sharded …
mathemusician Sep 17, 2021
2bdf62e
Merge branch 'master' of https://github.com/mathemusician/lightning-t…
mathemusician Sep 17, 2021
57d4c15
wandb to wab, LightningBoltsSparseMLCallback to TransformerSparseMLCa…
mathemusician Sep 20, 2021
c38be9b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 20, 2021
c0ca012
get rid of wandblogger
mathemusician Sep 20, 2021
f49debf
Merge branch 'master' of https://github.com/mathemusician/lightning-t…
mathemusician Sep 20, 2021
e303dba
compute metrics only if labels are non -1
mathemusician Sep 20, 2021
afe2198
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 20, 2021
97b88c2
get rid of dependency on environment variables
mathemusician Sep 20, 2021
916e998
Merge branch 'master' of https://github.com/mathemusician/lightning-t…
mathemusician Sep 20, 2021
3863bef
add dependency to sparseml
mathemusician Sep 21, 2021
622b5ca
add logger to sparseml trainer
mathemusician Sep 21, 2021
b4bb8cc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 22, 2021
c227247
Fix imports
Sep 22, 2021
aeb3b6e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 22, 2021
0bc39ca
Fix
Sep 22, 2021
73a2331
Indent
Sep 22, 2021
5de4dd1
Merge branch 'master' into sparseml
Sep 22, 2021
461a6ef
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 22, 2021
f9080b3
Add requirement for bolts
Sep 22, 2021
8eee3e3
Update torch
Sep 22, 2021
b162d4a
add license
mathemusician Sep 23, 2021
f8d53cd
made WABlogger test
mathemusician Sep 23, 2021
7be9464
delete unnecessary folders that were made during testing phase
mathemusician Sep 23, 2021
6b8a398
added callback unit test
mathemusician Sep 23, 2021
f18648a
handling of pure tensors from datamodule into model during onnx expor…
mathemusician Sep 23, 2021
d3fa2f5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2021
6805e55
typo
mathemusician Sep 23, 2021
60a9ca2
Merge branch 'master' of https://github.com/mathemusician/lightning-t…
mathemusician Sep 23, 2021
cff9a44
PEP8 type checking to isinstance
mathemusician Sep 23, 2021
d872b0e
get rid of unused import, comparison using unittest.assertequals
mathemusician Sep 23, 2021
220cab6
skip callbacks test if the wrong torch version is downloaded
mathemusician Sep 23, 2021
93b2014
correct way to check for instance of collections.ordereddict
mathemusician Sep 23, 2021
275398c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2021
372dd5c
skip test if wandb does not exist
mathemusician Sep 23, 2021
d334d61
Merge branch 'master' of https://github.com/mathemusician/lightning-t…
mathemusician Sep 23, 2021
0efcb29
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2021
e05235b
skip loggers test if wandb exists
mathemusician Sep 23, 2021
9abde72
Merge branch 'master' of https://github.com/mathemusician/lightning-t…
mathemusician Sep 23, 2021
635bfd9
PEP8
mathemusician Sep 23, 2021
3a8595b
Add WANDB to available list and test install
Sep 27, 2021
88708cf
Add version
Sep 28, 2021
dcc4f49
Skip tests for windows
Sep 29, 2021
f1af3a0
Revert "Skip tests for windows"
Sep 29, 2021
31c900e
Skip for windows
Sep 29, 2021
dd3d8dd
fixed skipifs
mathemusician Sep 30, 2021
556d4b0
sparseml documentation
mathemusician Sep 30, 2021
f10bc34
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2021
1f0fe55
add sparseml documentation to rst tree
mathemusician Sep 30, 2021
92244f1
Merge branch 'master' of https://github.com/mathemusician/lightning-t…
mathemusician Sep 30, 2021
2788bea
Merge branch 'master' into sparseml
Sep 30, 2021
3 changes: 3 additions & 0 deletions conf/trainer/callbacks/sparseml.yaml
@@ -0,0 +1,3 @@
_target_: lightning_transformers.core.callback.LightningBoltsSparseMLCallback
output_dir: ${env:MODELS_PATH}
recipe_path: ${env:RECIPE_PATH}
1 change: 1 addition & 0 deletions conf/trainer/logger/sparsewandb.yaml
@@ -0,0 +1 @@
_target_: lightning_transformers.core.loggers.WANDBLogger
3 changes: 3 additions & 0 deletions conf/trainer/sparseml.yaml
@@ -0,0 +1,3 @@
defaults:
- default # inherit from default trainer conf
- callbacks: sparseml
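For context, conf/trainer/callbacks/sparseml.yaml above resolves to an ordinary callback instantiation. Below is a minimal sketch of the equivalent plain Python, assuming the MODELS_PATH and RECIPE_PATH environment variables point at an output directory and a SparseML recipe file (the example values in the comments are placeholders):

import os

from lightning_transformers.core.callback import LightningBoltsSparseMLCallback

# Equivalent of conf/trainer/callbacks/sparseml.yaml: the ${env:...} resolver
# reads the two environment variables when the config is composed.
callback = LightningBoltsSparseMLCallback(
    output_dir=os.environ["MODELS_PATH"],    # e.g. "./models" (placeholder)
    recipe_path=os.environ["RECIPE_PATH"],   # e.g. "./recipes/pruning.yaml" (placeholder)
)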
112 changes: 112 additions & 0 deletions lightning_transformers/core/callback.py
@@ -11,11 +11,123 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import inspect
import os
import time
from typing import Any, Dict, List, Optional, Union

import numpy
import onnxruntime
import torch
from pl_bolts.callbacks import SparseMLCallback
from pytorch_lightning import Callback
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from sparseml.pytorch.utils import ModuleExporter
from sparseml.pytorch.utils.logger import WANDBLogger
from torch import Tensor


class LightningBoltsSparseMLCallback(SparseMLCallback):

def __init__(self, output_dir, recipe_path):
self.output_dir = output_dir
super().__init__(recipe_path=recipe_path)

def on_init_end(self, trainer: "pl.Trainer") -> None:
if isinstance(trainer.logger, WANDBLogger):
trainer.logger.__init__(init_kwargs={"project": "lightning-transformers"})

def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
optimizer = trainer.optimizers

if len(optimizer) > 1:
raise MisconfigurationException("SparseML only supports training with one optimizer.")
optimizer = optimizer[0]

loggers = trainer.logger

if not isinstance(loggers, list):
loggers = [loggers]

self.manager.initialize(pl_module, epoch=0.0, logger=loggers)
self.manager.initialize_loggers(loggers)

optimizer = self.manager.modify(
pl_module, optimizer, steps_per_epoch=self._num_training_steps_per_epoch(trainer), epoch=0
)

trainer.optimizers = [optimizer]

@staticmethod
def export_to_sparse_onnx(
model: "LightningModule", output_dir: str, sample_batch: Optional[Tensor] = None, **kwargs
) -> None:
"""Exports the model to ONNX format."""
with model._prevent_trainer_and_dataloaders_deepcopy():
exporter = ModuleExporter(model.model, output_dir=output_dir)
sample_batch = sample_batch if sample_batch is not None else model.example_input_array
if sample_batch is None:
raise MisconfigurationException(
"To export the model, a sample batch must be passed via "
"``SparseMLCallback.export_to_sparse_onnx(model, output_dir, sample_batch=sample_batch)`` "
"or an ``example_input_array`` property within the LightningModule"
)

# the following is adapted from @natuan and @spacemanidol
sess = None
num_samples = 0

sample_inputs = os.path.join(output_dir, "sample-inputs")
sample_outputs = os.path.join(output_dir, "sample-outputs")
os.makedirs(sample_inputs, exist_ok=True)
os.makedirs(sample_outputs, exist_ok=True)

if sess is None:
forward_args_spec = inspect.getfullargspec(exporter._module.__class__.forward)
one_sample_input = collections.OrderedDict([(f, sample_batch[f][0].long().reshape(1, -1))
for f in forward_args_spec.args if f in sample_batch])

try:
exporter.export_onnx(sample_batch=one_sample_input, convert_qat=True, **kwargs)
exporter.export_onnx(
sample_batch=one_sample_input,
name="small_model.onnx",
convert_qat=True,
export_params=False,
**kwargs,
)
onnx_file = os.path.join(output_dir, "model.onnx")

except Exception:
raise RuntimeError("Error exporting ONNX models and/or inputs/outputs")

sess = onnxruntime.InferenceSession(onnx_file)

# add additional files for testing since this feature is very new
input_names = list(sample_batch.keys())
output_names = [o.name for o in sess.get_outputs()]
for input_vals in zip(*sample_batch.values()):
input_feed = {k: v.long().numpy() for k, v in zip(input_names, input_vals)}
output_vals = sess.run(output_names, {k: input_feed[k].reshape(1, -1) for k in input_feed})
output_dict = {name: numpy.squeeze(val) for name, val in zip(output_names, output_vals)}
file_idx = f"{num_samples}".zfill(4)
numpy.savez(f"{sample_inputs}/inp-{file_idx}.npz", **input_feed)
numpy.savez(f"{sample_outputs}/out-{file_idx}.npz", **output_dict)
num_samples += 1

def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
sample_batch = next(iter(trainer.train_dataloader))
# if asked for output names, bert's ModelOutput gives two names
# but when run, the model only gives one output
# workaround is just to force onnx to realize there is only one output
output_names = ["logits"]
self.export_to_sparse_onnx(
output_dir=self.output_dir, model=pl_module, sample_batch=sample_batch, output_names=output_names
)


class CUDACallback(Callback):
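As a rough usage sketch of the LightningBoltsSparseMLCallback above (the task module, recipe path, and output directory below are placeholders, not part of this PR): the callback attaches to a Lightning Trainer like any other callback; on_fit_start wraps the optimizer with the SparseML manager, and teardown exports the trained model to ONNX along with sample inputs and outputs.

import pytorch_lightning as pl

from lightning_transformers.core.callback import LightningBoltsSparseMLCallback

# MyTextClassificationTask stands in for any LightningModule-based task;
# define or import your own task here.
model = MyTextClassificationTask()

trainer = pl.Trainer(
    max_epochs=3,
    callbacks=[
        LightningBoltsSparseMLCallback(
            output_dir="models/sparse",          # placeholder output directory
            recipe_path="recipes/pruning.yaml",  # placeholder SparseML recipe
        )
    ],
)
trainer.fit(model)
# teardown() then writes model.onnx and small_model.onnx plus sample-inputs/
# and sample-outputs/ under output_dir for verification.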
92 changes: 92 additions & 0 deletions lightning_transformers/core/loggers.py
@@ -0,0 +1,92 @@
import time
from typing import Dict, Optional, Union

from pytorch_lightning.loggers import WandbLogger
from sparseml.pytorch.utils.logger import LambdaLogger


class WANDBLogger(WandbLogger):
"""
Modifier logger that handles outputting values to Weights and Biases.

:param init_kwargs: the args to call into wandb.init with;
ex: wandb.init(**init_kwargs). If not supplied, then init will not be called
:param name: name given to the logger, used for identification;
defaults to wandb
:param enabled: True to log, False otherwise
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.enabled = True

def _lambda_func(
self,
tag: Optional[str],
value: Optional[float],
values: Optional[Dict[str, float]],
step: Optional[int],
wall_time: Optional[float],
) -> bool:
params = {}

if value is not None:
params[tag] = value

if values:
if tag:
values = {f"{tag}/{key}": val for key, val in values.items()}
params.update(values)

try:
self.log_metrics(params, step=step)
except Exception as e:
print(params, e)

return True

def log_scalar(
self,
tag: str,
value: float,
step: Union[None, int] = None,
wall_time: Union[None, float] = None,
):
"""
:param tag: identifying tag to log the value with
:param value: value to save
:param step: global step for when the value was taken
:param wall_time: global wall time for when the value was taken,
defaults to time.time()
:return: True if logged, False otherwise.
"""
if not self.enabled:
return False

if not wall_time:
wall_time = time.time()

return self._lambda_func(tag, value, None, step, wall_time)

def log_scalars(
self,
tag: str,
values: Dict[str, float],
step: Union[None, int] = None,
wall_time: Union[None, float] = None,
):
"""
:param tag: identifying tag to log the values with
:param values: values to save
:param step: global step for when the values were taken
:param wall_time: global wall time for when the values were taken,
defaults to time.time()
:return: True if logged, False otherwise.
"""
if not self.enabled:
return False

if not wall_time:
wall_time = time.time()

return self._lambda_func(tag, None, values, step, wall_time)
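
A short usage sketch for the WANDBLogger above; the project name, tags, and metric values are illustrative only.

from lightning_transformers.core.loggers import WANDBLogger

# Keyword arguments pass straight through to pytorch_lightning's WandbLogger;
# "lightning-transformers" is an illustrative project name.
logger = WANDBLogger(project="lightning-transformers")

# Log a single scalar under a tag; wall_time defaults to time.time().
logger.log_scalar(tag="train/loss", value=0.42, step=100)

# Log several related scalars at once; keys are prefixed with the tag.
logger.log_scalars(tag="eval", values={"accuracy": 0.91, "f1": 0.89}, step=100)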