When running Delphi on a Gemma Scope SAE with the following CLI command:
python -m delphi google/gemma-2-2b google/gemma-scope-2b-pt-res --explainer_model Qwen/Qwen3-4B-Instruct-2507 --n_tokens 10_000_000 --max_latents 10 --explainer_provider offline --hookpoints layer_13/width_16k/average_l0_84 --pipeline_num_proc 10 --explainer_model_max_len 5600 --max_memory 0.9 --name gemma-2b --hf_token <my_hf_token>
I ran into the following error:
Caching latents: 0%| | 0/1220 [00:00<?, ?it/s]
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/workspace/delphi_max-memory_vllm_fixes/delphi/__main__.py", line 476, in <module>
asyncio.run(run(args.run_cfg))
File "/usr/lib/python3.12/asyncio/runners.py", line 194, in run
return runner.run(main)
^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/asyncio/runners.py", line 118, in run
return self._loop.run_until_complete(task)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/asyncio/base_events.py", line 687, in run_until_complete
return future.result()
^^^^^^^^^^^^^^^
File "/workspace/delphi_max-memory_vllm_fixes/delphi/__main__.py", line 413, in run
populate_cache(
File "/workspace/delphi_max-memory_vllm_fixes/delphi/__main__.py", line 340, in populate_cache
cache.run(cache_cfg.n_tokens, tokens)
File "/workspace/delphi_max-memory_vllm_fixes/delphi/latents/cache.py", line 281, in run
sae_latents = self.hookpoint_to_sparse_encode[hookpoint](
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/delphi_max-memory_vllm_fixes/delphi/sparse_coders/custom/gemmascope.py", line 63, in _forward
encoded = sae.encode(x)
^^^^^^^^^^^^^
File "/workspace/delphi_max-memory_vllm_fixes/delphi/sparse_coders/custom/gemmascope.py", line 83, in encode
pre_acts = input_acts @ self.W_enc + self.b_enc
~~~~~~~~~~~^~~~~~~~~~~~
RuntimeError: Expected all tensors to be on the same device, but got mat2 is on cpu, different from other tensors on cuda:0 (when checking argument in method wrapper_CUDA_mm)
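For context, the failing line reduces to a matmul between a CUDA tensor and weights that are still on the CPU. A minimal standalone reproduction (hypothetical shapes, not the actual Delphi code) looks like:

import torch

x = torch.randn(4, 2304, device="cuda")  # residual-stream activations on the GPU
W_enc = torch.randn(2304, 16384)         # SAE encoder weights left on the CPU
pre_acts = x @ W_enc                     # raises the same RuntimeError as above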
It seemed like the SAE weights were somehow defaulting to the CPU, so I took a look at delphi/sparse_coders/custom/gemmascope.py to see where the model is assigned a device. Device assignment is handled by a conditional on lines 107-108:

if device == "cuda":
    model.cuda()

Yet, according to the load_gemma_autoencoders type hint on line 16, device can be either a string (as the conditional above expects) or a torch.device object:

device: str | torch.device = torch.device("cuda"),

Because device defaults to a torch.device object, the check on line 107 evaluates to False, so the model is never moved to the GPU, which yields the error above.
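One way to make the check robust to both types (a hypothetical sketch, not necessarily what the PR does) is to normalize the argument with torch.device before comparing:

import torch

def move_to_device(model, device):
    # Hypothetical helper: torch.device() accepts either a string or an
    # existing torch.device, so the .type check works for both forms.
    device = torch.device(device)
    if device.type == "cuda":
        model = model.cuda()
    return model

Alternatively, model.to(device) accepts either form directly and avoids the branch altogether.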
I've modified the conditional and created a PR. Let me know what you think!