From 67bc264b969a09111046cfa36b8e7f732813c093 Mon Sep 17 00:00:00 2001 From: ORippler Date: Tue, 26 Apr 2022 18:25:33 +0200 Subject: [PATCH 1/6] Fuse_modules in a qat-respecting way --- pytorch_lightning/callbacks/quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/quantization.py b/pytorch_lightning/callbacks/quantization.py index 4b0b3f702cf6e..9f46f3a7f928e 100644 --- a/pytorch_lightning/callbacks/quantization.py +++ b/pytorch_lightning/callbacks/quantization.py @@ -252,7 +252,7 @@ def _prepare_model(self, model: torch.nn.Module) -> None: model.qconfig = self._qconfig if self._check_feasible_fuse(model): - torch.quantization.fuse_modules(model, self._modules_to_fuse, inplace=True) + torch.ao.quantization.fuse_modules_qat(model, self._modules_to_fuse, inplace=True) # Prepare the model for QAT. This inserts observers and fake_quants in # the model that will observe weight and activation tensors during calibration. From 923ae894485382ae43f1a617b83befb6aeb1b293 Mon Sep 17 00:00:00 2001 From: ORippler Date: Wed, 27 Apr 2022 08:57:52 +0200 Subject: [PATCH 2/6] Add compatibility for PyTorch <1.11 In older pytorch versions, `fuse_modules` used the `Module.training` flag to determine whether fusion should be QAT-compliant or not, refer https://github.com/pytorch/pytorch/releases/tag/v1.11.0 --- pytorch_lightning/callbacks/quantization.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/quantization.py b/pytorch_lightning/callbacks/quantization.py index 9f46f3a7f928e..8cc044f62b4a1 100644 --- a/pytorch_lightning/callbacks/quantization.py +++ b/pytorch_lightning/callbacks/quantization.py @@ -31,8 +31,10 @@ if _TORCH_GREATER_EQUAL_1_10: from torch.ao.quantization.qconfig import QConfig + from torch.ao.quantization import fuse_modules_qat as fuse_modules else: from torch.quantization import QConfig + from torch.quantization import fuse_modules def 
wrap_qat_forward_context( @@ -252,7 +254,7 @@ def _prepare_model(self, model: torch.nn.Module) -> None: model.qconfig = self._qconfig if self._check_feasible_fuse(model): - torch.ao.quantization.fuse_modules_qat(model, self._modules_to_fuse, inplace=True) + fuse_modules(model, self._modules_to_fuse, inplace=True) # Prepare the model for QAT. This inserts observers and fake_quants in # the model that will observe weight and activation tensors during calibration. From 7f16d6adc56729121deb3ad0ec24bf804050e172 Mon Sep 17 00:00:00 2001 From: ORippler Date: Wed, 27 Apr 2022 09:01:39 +0200 Subject: [PATCH 3/6] Add CHANGELOG for pull #12891 --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index daee5ae803144..316fba7ef47fe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -141,6 +141,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed `fuse_modules` to be qat-aware for `pytorch>1.10`. ([#12891](https://github.com/PyTorchLightning/pytorch-lightning/pull/12891)) + - Fixed `ImportError` when `torch.distributed` is not available. 
([#12794](https://github.com/PyTorchLightning/pytorch-lightning/pull/12794)) From 4ea51bed3cb819a02c919ff0634f58f7e089aa0a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Apr 2022 07:03:38 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/callbacks/quantization.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/quantization.py b/pytorch_lightning/callbacks/quantization.py index 8cc044f62b4a1..ba38e8a0d2c7e 100644 --- a/pytorch_lightning/callbacks/quantization.py +++ b/pytorch_lightning/callbacks/quantization.py @@ -30,11 +30,10 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException if _TORCH_GREATER_EQUAL_1_10: - from torch.ao.quantization.qconfig import QConfig from torch.ao.quantization import fuse_modules_qat as fuse_modules + from torch.ao.quantization.qconfig import QConfig else: - from torch.quantization import QConfig - from torch.quantization import fuse_modules + from torch.quantization import fuse_modules, QConfig def wrap_qat_forward_context( From c122790c7540e0721f7eeabcde3d2d0ffc9c06e1 Mon Sep 17 00:00:00 2001 From: ORippler Date: Wed, 27 Apr 2022 16:46:40 +0200 Subject: [PATCH 5/6] Fix conditional import of fuse_modules_qat `torch.ao.quantization.fuse_modules_qat` was actually added in torch 1.11. 
--- pytorch_lightning/callbacks/quantization.py | 10 +++++++--- pytorch_lightning/utilities/__init__.py | 1 + 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/quantization.py b/pytorch_lightning/callbacks/quantization.py index ba38e8a0d2c7e..2ae1262eb25d9 100644 --- a/pytorch_lightning/callbacks/quantization.py +++ b/pytorch_lightning/callbacks/quantization.py @@ -26,14 +26,18 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10 +from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, _TORCH_GREATER_EQUAL_1_11 from pytorch_lightning.utilities.exceptions import MisconfigurationException if _TORCH_GREATER_EQUAL_1_10: - from torch.ao.quantization import fuse_modules_qat as fuse_modules from torch.ao.quantization.qconfig import QConfig else: - from torch.quantization import fuse_modules, QConfig + from torch.quantization import QConfig + +if _TORCH_GREATER_EQUAL_1_11: + from torch.ao.quantization import fuse_modules_qat as fuse_modules +else: + from torch.quantization import fuse_modules def wrap_qat_forward_context( diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 289b7faa431e2..87947ac9a10f3 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -49,6 +49,7 @@ _RICH_AVAILABLE, _TORCH_GREATER_EQUAL_1_9, _TORCH_GREATER_EQUAL_1_10, + _TORCH_GREATER_EQUAL_1_11, _TORCH_QUANTIZE_AVAILABLE, _TORCHTEXT_AVAILABLE, _TORCHVISION_AVAILABLE, From 31907237d9cab73943c306e97ec56f0b84cb92c7 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 28 Apr 2022 04:35:24 +0900 Subject: [PATCH 6/6] Update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 316fba7ef47fe..2b0867639b07d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -141,7 +141,7 @@ The format is 
based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed -- Fixed `fuse_modules` to be qat-aware for `pytorch>1.10`. ([#12891](https://github.com/PyTorchLightning/pytorch-lightning/pull/12891)) +- Fixed `fuse_modules` to be qat-aware for `torch>=1.11`. ([#12891](https://github.com/PyTorchLightning/pytorch-lightning/pull/12891)) - Fixed `ImportError` when `torch.distributed` is not available. ([#12794](https://github.com/PyTorchLightning/pytorch-lightning/pull/12794))