Gemma Scope model device defaults to CPU #160

@kmaherx

Description

When running Delphi on a Gemma Scope SAE with the following CLI command:

python -m delphi google/gemma-2-2b google/gemma-scope-2b-pt-res --explainer_model Qwen/Qwen3-4B-Instruct-2507 --n_tokens 10_000_000 --max_latents 10 --explainer_provider offline --hookpoints layer_13/width_16k/average_l0_84 --pipeline_num_proc 10 --explainer_model_max_len 5600 --max_memory 0.9 --name gemma-2b --hf_token <my_hf_token>

I ran into the following error:

Caching latents:   0%|                                                                                                                                                             | 0/1220 [00:00<?, ?it/s]
Traceback (most recent call last):                                                                                                                                                                          
  File "<frozen runpy>", line 198, in _run_module_as_main                                                                                                                                                   
  File "<frozen runpy>", line 88, in _run_code                                                                                                                                                              
  File "/workspace/delphi_max-memory_vllm_fixes/delphi/__main__.py", line 476, in <module>                                                                                                                  
    asyncio.run(run(args.run_cfg))                                                                                                                                                                          
  File "/usr/lib/python3.12/asyncio/runners.py", line 194, in run                                                                                                                                           
    return runner.run(main)                                                                                                                                                                                 
           ^^^^^^^^^^^^^^^^                                                                                                                                                                                 
  File "/usr/lib/python3.12/asyncio/runners.py", line 118, in run                                                                                                                                           
    return self._loop.run_until_complete(task)                                                                                                                                                              
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                              
  File "/usr/lib/python3.12/asyncio/base_events.py", line 687, in run_until_complete                                                                                                                        
    return future.result()                                                                                                                                                                                  
           ^^^^^^^^^^^^^^^                                                                                                                                                                                  
  File "/workspace/delphi_max-memory_vllm_fixes/delphi/__main__.py", line 413, in run                                                                                                                       
    populate_cache(                                                                                                                                                                                         
  File "/workspace/delphi_max-memory_vllm_fixes/delphi/__main__.py", line 340, in populate_cache                                                                                                            
    cache.run(cache_cfg.n_tokens, tokens)                                                                                                                                                                   
  File "/workspace/delphi_max-memory_vllm_fixes/delphi/latents/cache.py", line 281, in run                                                                                                                  
    sae_latents = self.hookpoint_to_sparse_encode[hookpoint](                                                                                                                                               
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                               
  File "/workspace/delphi_max-memory_vllm_fixes/delphi/sparse_coders/custom/gemmascope.py", line 63, in _forward                                                                                            
    encoded = sae.encode(x)                                                                                                                                                                                 
              ^^^^^^^^^^^^^                                                                                                                                                                                 
  File "/workspace/delphi_max-memory_vllm_fixes/delphi/sparse_coders/custom/gemmascope.py", line 83, in encode                                                                                              
    pre_acts = input_acts @ self.W_enc + self.b_enc 
               ~~~~~~~~~~~^~~~~~~~~~~~
RuntimeError: Expected all tensors to be on the same device, but got mat2 is on cpu, different from other tensors on cuda:0 (when checking argument in method wrapper_CUDA_mm)

It seemed like the SAE weights were somehow defaulting to the CPU, so I took a look at the delphi/sparse_coders/custom/gemmascope.py script to see where the model is assigned a device. Device assignment is handled by a conditional on lines 107-108:

if device == "cuda":
    model.cuda()

Yet, according to the type hints for load_gemma_autoencoders on line 16, device can be either a string (as the conditional above expects) or a torch.device object:

device: str | torch.device = torch.device("cuda"),

Because device defaults to a torch.device object, and a torch.device never compares equal to a plain string, the check on line 107 evaluates to False, so the model is never moved to the GPU. That produces the original error.
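
For reference, a quick check of that comparison behavior (this reflects standard PyTorch semantics: a torch.device never compares equal to a str):

import torch

device = torch.device("cuda")  # the default from load_gemma_autoencoders

print(device == "cuda")        # False -- torch.device != str
print(str(device) == "cuda")   # True
print(device.type == "cuda")   # True -- also covers "cuda:0", "cuda:1", ...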

I've modified the conditional and created a PR. Let me know what you think!
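
For illustration, one device-agnostic way to write it (just a sketch; the actual change is in the PR and may differ):

# Normalize to torch.device so both str and torch.device inputs work,
# then move the model; .to() also handles "cuda:1", "mps", etc.
device = torch.device(device)
model.to(device)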
