From b82ff55cb9ffa1b347afa9eb5bb749308400e3ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 3 Oct 2023 19:03:30 +0200 Subject: [PATCH 1/4] Forbid init_module on-device instantiation with bnb ignored modules --- src/lightning/fabric/plugins/precision/bitsandbytes.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/lightning/fabric/plugins/precision/bitsandbytes.py b/src/lightning/fabric/plugins/precision/bitsandbytes.py index 2ffa3fe460724..5a46b2e493771 100644 --- a/src/lightning/fabric/plugins/precision/bitsandbytes.py +++ b/src/lightning/fabric/plugins/precision/bitsandbytes.py @@ -118,6 +118,13 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module: def init_context(self) -> ContextManager: dtype_ctx = _DtypeContextManager(self.dtype) if self.ignore_modules: + if torch.tensor(0).device.type != "cpu": + raise RuntimeError( + "Instantiating your model under the `init_module` context manager is not supported when used with" + f" `BitsandbytesPrecision(..., ignore_modules={self.ignore_modules})` and non-CPU device, as this" + " would initialize the layers on-device, defeating the purpose of quantization. You can remove" + " `ignore_modules` or remove the `init_module` context manager." + ) # cannot patch the Linear class if the user wants to skip some submodules return dtype_ctx stack = ExitStack() From 14a1403d5be4ea391e76236661ebc882ed3403f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 3 Oct 2023 19:06:14 +0200 Subject: [PATCH 2/4] Test --- tests/tests_fabric/plugins/precision/test_bitsandbytes.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py index 1c3c817b70163..e6073ce3c0d9c 100644 --- a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py +++ b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py @@ -134,10 +134,11 @@ def __init__(self): assert model.l.weight.dtype == expected fabric = Fabric(devices=1, plugins=BitsandbytesPrecision(*args, ignore_modules={"foo"})) - with fabric.init_module(): - model = MyModel() + with pytest.raises(RuntimeError, match="not supported.*non-CPU device"), fabric.init_module(): + pass + model = MyModel() # When ignore_modules is set, we only quantize on `setup` - assert model.l.weight.device.type == "cuda" + assert model.l.weight.device.type == "cpu" assert model.l.weight.dtype == args[1] # this quantizes now model = fabric.setup(model) From d9fb589999765cb12476584bf66dd8782a728a13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 3 Oct 2023 19:10:29 +0200 Subject: [PATCH 3/4] Test --- tests/tests_fabric/plugins/precision/test_bitsandbytes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py index e6073ce3c0d9c..e77332b07ffed 100644 --- a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py +++ b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py @@ -139,7 +139,7 @@ def __init__(self): model = MyModel() # When ignore_modules is set, we only quantize on `setup` assert model.l.weight.device.type == "cpu" - assert model.l.weight.dtype == args[1] + assert model.l.weight.dtype == torch.float32 # this quantizes now model = fabric.setup(model) assert model.l.weight.device.type == "cuda" From 66f8d706947bf270dc483984d80bb72d399c5803 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 4 Oct 2023 19:33:00 +0200 Subject: [PATCH 4/4] Update --- .../fabric/plugins/precision/bitsandbytes.py | 16 +++++++--------- .../plugins/precision/test_bitsandbytes.py | 2 +- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/lightning/fabric/plugins/precision/bitsandbytes.py b/src/lightning/fabric/plugins/precision/bitsandbytes.py index 7cdb1e4942c58..ff638eaae0cee 100644 --- a/src/lightning/fabric/plugins/precision/bitsandbytes.py +++ b/src/lightning/fabric/plugins/precision/bitsandbytes.py @@ -116,17 +116,15 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module: return module def init_context(self) -> ContextManager: - dtype_ctx = _DtypeContextManager(self.dtype) if self.ignore_modules: - if torch.tensor(0).device.type != "cpu": - raise RuntimeError( - "Instantiating your model under the `init_module` context manager is not supported when used with" - f" `BitsandbytesPrecision(..., ignore_modules={self.ignore_modules})` and non-CPU device, as this" - " would initialize the layers on-device, defeating the purpose of quantization. You can remove" - " `ignore_modules` or remove the `init_module` context manager." - ) # cannot patch the Linear class if the user wants to skip some submodules - return dtype_ctx + raise RuntimeError( + "Instantiating your model under the `init_module` context manager is not supported when used with" + f" `BitsandbytesPrecision(..., ignore_modules={self.ignore_modules})` as this" + " may initialize the layers on-device, defeating the purpose of quantization. You can remove" + " `ignore_modules` or remove the `init_module` context manager." + ) + dtype_ctx = _DtypeContextManager(self.dtype) stack = ExitStack() stack.enter_context(dtype_ctx) # TODO: this could also support replacing `Embedding` and `Conv1D` diff --git a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py index e77332b07ffed..a6d32c1d1708f 100644 --- a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py +++ b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py @@ -134,7 +134,7 @@ def __init__(self): assert model.l.weight.dtype == expected fabric = Fabric(devices=1, plugins=BitsandbytesPrecision(*args, ignore_modules={"foo"})) - with pytest.raises(RuntimeError, match="not supported.*non-CPU device"), fabric.init_module(): + with pytest.raises(RuntimeError, match="not supported"), fabric.init_module(): pass model = MyModel() # When ignore_modules is set, we only quantize on `setup`