Skip to content

Commit

Permalink
Enable mnist tests (#20)
Browse files Browse the repository at this point in the history
* Enable mnist tests

Signed-off-by: Jerome <janand@habana.ai>

* Update right path

Signed-off-by: Jerome <janand@habana.ai>

* Enable fast device run

Signed-off-by: Jerome <janand@habana.ai>

---------

Signed-off-by: Jerome <janand@habana.ai>
  • Loading branch information
jerome-habana committed Apr 27, 2023
1 parent 86f5557 commit 219518f
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 31 deletions.
18 changes: 9 additions & 9 deletions .azure/hpu-tests-pl.yml
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,15 @@ jobs:
workingDirectory: tests/
displayName: 'HPU precision test'
# ToDo
#- bash: pip install ".[examples]"
# displayName: 'Install extra for examples'
#
#- bash: |
# export PYTHONPATH="${PYTHONPATH}:$(pwd)"
# python mnist_sample.py
# workingDirectory: examples/
# displayName: 'Testing: HPU examples'
- bash: pip install ".[examples]"
displayName: 'Install extra for examples'

- bash: |
export PYTHONPATH="${PYTHONPATH}:$(pwd)"
python mnist_sample.py
workingDirectory: examples/pytorch/
displayName: 'Testing: HPU examples'
- bash: |
python -m pytest -sv test_pytorch/test_profiler.py --forked --hpus 1 --junitxml=hpu1_profiler_test-results.xml
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
### Fixed

- Fixed mnist example test ([#20](https://github.com/Lightning-AI/lightning-Habana/pull/20))
-
### Removed

### Deprecated
31 changes: 9 additions & 22 deletions examples/pytorch/mnist_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,18 @@
# limitations under the License.

import torch
from jsonargparse import lazy_instance
from lightning_utilities import module_available
from torch.nn import functional as F # noqa: N812

if module_available("lightning"):
from lightning.pytorch import LightningModule
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule
elif module_available("pytorch_lightning"):
from pytorch_lightning import LightningModule
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.demos.mnist_datamodule import MNISTDataModule

from lightning_habana.pytorch.plugins.precision import HPUPrecisionPlugin
from lightning_habana.pytorch.accelerator import HPUAccelerator
from lightning_habana.pytorch.strategies import SingleHPUStrategy


class LitClassifier(LightningModule):
Expand Down Expand Up @@ -62,20 +60,9 @@ def configure_optimizers(self):


if __name__ == "__main__":
cli = LightningCLI(
LitClassifier,
MNISTDataModule,
trainer_defaults={
"accelerator": "hpu",
"devices": 1,
"max_epochs": 1,
"plugins": lazy_instance(HPUPrecisionPlugin, precision="bf16-mixed"),
},
run=False,
save_config_kwargs={"overwrite": True},
)
dm = MNISTDataModule(batch_size=32)
model = LitClassifier()
trainer = Trainer(fast_dev_run=True, accelerator=HPUAccelerator(), devices=1, strategy=SingleHPUStrategy())

# Run the model ⚡
cli.trainer.fit(cli.model, datamodule=cli.datamodule)
cli.trainer.validate(cli.model, datamodule=cli.datamodule)
cli.trainer.test(cli.model, datamodule=cli.datamodule)
trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)

0 comments on commit 219518f

Please sign in to comment.