diff --git a/CHANGELOG.md b/CHANGELOG.md index 48274f3ab945b..0d11cf4bdaab1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -143,6 +143,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed `fuse_modules` to be qat-aware for `torch>=1.11` ([#12891](https://github.com/PyTorchLightning/pytorch-lightning/pull/12891)) + + - Use only a single instance of `rich.console.Console` throughout codebase ([#12886](https://github.com/PyTorchLightning/pytorch-lightning/pull/12886)) diff --git a/pytorch_lightning/callbacks/quantization.py b/pytorch_lightning/callbacks/quantization.py index 4b0b3f702cf6e..2ae1262eb25d9 100644 --- a/pytorch_lightning/callbacks/quantization.py +++ b/pytorch_lightning/callbacks/quantization.py @@ -26,7 +26,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10 +from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, _TORCH_GREATER_EQUAL_1_11 from pytorch_lightning.utilities.exceptions import MisconfigurationException if _TORCH_GREATER_EQUAL_1_10: @@ -34,6 +34,11 @@ else: from torch.quantization import QConfig +if _TORCH_GREATER_EQUAL_1_11: + from torch.ao.quantization import fuse_modules_qat as fuse_modules +else: + from torch.quantization import fuse_modules + def wrap_qat_forward_context( quant_cb, model: "pl.LightningModule", func: Callable, trigger_condition: Optional[Union[Callable, int]] = None @@ -252,7 +257,7 @@ def _prepare_model(self, model: torch.nn.Module) -> None: model.qconfig = self._qconfig if self._check_feasible_fuse(model): - torch.quantization.fuse_modules(model, self._modules_to_fuse, inplace=True) + fuse_modules(model, self._modules_to_fuse, inplace=True) # Prepare the model for QAT. This inserts observers and fake_quants in # the model that will observe weight and activation tensors during calibration. diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 289b7faa431e2..87947ac9a10f3 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -49,6 +49,7 @@ _RICH_AVAILABLE, _TORCH_GREATER_EQUAL_1_9, _TORCH_GREATER_EQUAL_1_10, + _TORCH_GREATER_EQUAL_1_11, _TORCH_QUANTIZE_AVAILABLE, _TORCHTEXT_AVAILABLE, _TORCHVISION_AVAILABLE,