diff --git a/src/lightning/fabric/plugins/precision/bitsandbytes.py b/src/lightning/fabric/plugins/precision/bitsandbytes.py
index 490dee8b205d4..ff638eaae0cee 100644
--- a/src/lightning/fabric/plugins/precision/bitsandbytes.py
+++ b/src/lightning/fabric/plugins/precision/bitsandbytes.py
@@ -116,10 +116,15 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
         return module
 
     def init_context(self) -> ContextManager:
-        dtype_ctx = _DtypeContextManager(self.dtype)
         if self.ignore_modules:
             # cannot patch the Linear class if the user wants to skip some submodules
-            return dtype_ctx
+            raise RuntimeError(
+                "Instantiating your model under the `init_module` context manager is not supported when used with"
+                f" `BitsandbytesPrecision(..., ignore_modules={self.ignore_modules})` as this"
+                " may initialize the layers on-device, defeating the purpose of quantization. You can remove"
+                " `ignore_modules` or remove the `init_module` context manager."
+            )
+        dtype_ctx = _DtypeContextManager(self.dtype)
         stack = ExitStack()
         stack.enter_context(dtype_ctx)
         # TODO: this could also support replacing `Embedding` and `Conv1D`
diff --git a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py
index 1c3c817b70163..a6d32c1d1708f 100644
--- a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py
+++ b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py
@@ -134,11 +134,12 @@ def __init__(self):
     assert model.l.weight.dtype == expected
 
     fabric = Fabric(devices=1, plugins=BitsandbytesPrecision(*args, ignore_modules={"foo"}))
-    with fabric.init_module():
-        model = MyModel()
+    with pytest.raises(RuntimeError, match="not supported"), fabric.init_module():
+        pass
+    model = MyModel()
     # When ignore_modules is set, we only quantize on `setup`
-    assert model.l.weight.device.type == "cuda"
-    assert model.l.weight.dtype == args[1]
+    assert model.l.weight.device.type == "cpu"
+    assert model.l.weight.dtype == torch.float32
     # this quantizes now
     model = fabric.setup(model)
     assert model.l.weight.device.type == "cuda"