
AutoModel.from_pretrained(...) (with explicit device_map unset) fails under with torch.device("meta") with PyTorch 2.6.0 and 2.7.0 #38066


Open
vadimkantorov opened this issue May 10, 2025 · 4 comments

@vadimkantorov

vadimkantorov commented May 10, 2025

# from torch.nn.attention.flex_attention import BlockMask, flex_attention
from transformers import AutoModel
import torch

with torch.device('meta'):
    AutoModel.from_pretrained('Qwen/Qwen2.5-0.5B', trust_remote_code=True)

I found this code in the wild in https://github.com/Open-Reasoner-Zero/Open-Reasoner-Zero/blob/f6d1ec77ce2ce18f3d925a1014c9e4d6b4ad3072/orz/ppo/actors.py#L745-L746 (linked issue Open-Reasoner-Zero/Open-Reasoner-Zero#71).

It fails with:

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
[<ipython-input-1-00ba4c43be18>](https://localhost:8080/#) in <cell line: 0>()
      4 
      5 with torch.device('meta'):
----> 6     AutoModel.from_pretrained('Qwen/Qwen2.5-0.5B', trust_remote_code=True)

/usr/local/lib/python3.11/dist-packages/transformers/models/auto/auto_factory.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    569             if model_class.config_class == config.sub_configs.get("text_config", None):
    570                 config = config.get_text_config()
--> 571             return model_class.from_pretrained(
    572                 pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
    573             )

/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in _wrapper(*args, **kwargs)
    277         old_dtype = torch.get_default_dtype()
    278         try:
--> 279             return func(*args, **kwargs)
    280         finally:
    281             torch.set_default_dtype(old_dtype)

/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs)
   4397                 offload_index,
   4398                 error_msgs,
-> 4399             ) = cls._load_pretrained_model(
   4400                 model,
   4401                 state_dict,

/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in _load_pretrained_model(cls, model, state_dict, checkpoint_files, pretrained_model_name_or_path, ignore_mismatched_sizes, sharded_metadata, device_map, disk_offload_folder, offload_state_dict, dtype, hf_quantizer, keep_in_fp32_regex, device_mesh, key_mapping, weights_only)
   4831             # Skip it with fsdp on ranks other than 0
   4832             elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
-> 4833                 disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
   4834                     model_to_load,
   4835                     state_dict,

/usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
    114     def decorate_context(*args, **kwargs):
    115         with ctx_factory():
--> 116             return func(*args, **kwargs)
    117 
    118     return decorate_context

/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py in _load_state_dict_into_meta_model(model, state_dict, shard_file, expected_keys, reverse_renaming_mapping, device_map, disk_offload_folder, disk_offload_index, cpu_offload_folder, cpu_offload_index, hf_quantizer, is_safetensors, keep_in_fp32_regex, unexpected_keys, device_mesh)
    822                     param_device = "cpu" if is_local_dist_rank_0() else "meta"
    823 
--> 824                 _load_parameter_into_model(model, param_name, param.to(param_device))
    825 
    826             else:

/usr/local/lib/python3.11/dist-packages/torch/utils/_device.py in __torch_function__(self, func, types, args, kwargs)
    102         if func in _device_constructors() and kwargs.get('device') is None:
    103             kwargs['device'] = self.device
--> 104         return func(*args, **kwargs)
    105 
    106 # NB: This is directly called from C++ in torch/csrc/Device.cpp

NotImplementedError: Cannot copy out of meta tensor; no data!

Also, unless the first line is uncommented, it fails on PyTorch 2.6.0 as well, with RuntimeError: Tensor.item() cannot be called on meta tensors.
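
For context, both failures come down to the same limitation: meta tensors carry only shape and dtype metadata, with no storage behind them, so any operation that needs actual values has to fail. A minimal sketch:

import torch

# Meta tensors have shapes and dtypes but no underlying storage.
t = torch.empty((2, 2), device="meta")
print(t.shape, t.dtype)  # fine: metadata-only queries work

try:
    t.to("cpu")  # same error as in the traceback above
except NotImplementedError as e:
    print(e)  # Cannot copy out of meta tensor; no data!

try:
    t.sum().item()  # the error seen on 2.6.0
except RuntimeError as e:
    print(e)  # Tensor.item() cannot be called on meta tensors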

@vadimkantorov vadimkantorov changed the title AutoModel.from_pretrained(...) fails under with torch.device("meta") with PyTorch 2.7.0 AutoModel.from_pretrained(...) fails under with torch.device("meta") with PyTorch 2.6.0 and 2.7.0 May 10, 2025
@vadimkantorov
Author

vadimkantorov commented May 10, 2025

Probably in this code:

if is_fsdp_enabled():
    param_device = "cpu" if is_local_dist_rank_0() else "meta"

there "cpu" should be replaced with "meta" even if it's rank0, if the default device type is changed to meta (because of pytorch/pytorch#148874, this can be checked as torch.empty(()).device == torch.device("meta"))

In all of modeling_utils.py there are 3-4 places that need this fix to avoid failing under with torch.device("meta"): the context manager forces the loaded params to actually be of device type "meta", which makes the copy to param_device = "cpu" impossible.
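
A minimal sketch of the check (the helper name is mine, nothing like it exists in transformers):

import torch

def default_device_is_meta() -> bool:
    # A freshly constructed tensor inherits the ambient default device,
    # so this detects an enclosing `with torch.device("meta"):` block.
    return torch.empty(()).device == torch.device("meta")

with torch.device("meta"):
    assert default_device_is_meta()

The selection above would then become something like param_device = "meta" if default_device_is_meta() else ("cpu" if is_local_dist_rank_0() else "meta").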


I think the code I found in the wild does this as a smoke test, e.g. to force the model weights to download and to check that the model code actually runs.
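
For the structural half of such a smoke test, one workaround sketch that sidesteps the failing copy entirely is to build the model skeleton from its config on the meta device; note this only downloads the config, not the weights, and assumes the model's init code is itself meta-friendly:

import torch
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("Qwen/Qwen2.5-0.5B", trust_remote_code=True)
with torch.device("meta"):
    # No checkpoint tensors are loaded, so nothing needs to be copied off meta.
    model = AutoModel.from_config(config, trust_remote_code=True)

print(model.num_parameters())  # shapes are known even though no data is allocated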

@Rocketknight1
Member

Rocketknight1 commented May 12, 2025

Hi, I'm not sure with torch.device("meta") works cleanly with from_pretrained()! Try device_map="meta" in the call instead, but note that this doesn't actually load any weights, since tensors on the meta device store no data, which somewhat defeats the purpose of from_pretrained.
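
A sketch of what that call would look like:

from transformers import AutoModel

# device_map="meta" is passed explicitly instead of relying on the
# ambient `with torch.device("meta"):` context.
model = AutoModel.from_pretrained(
    "Qwen/Qwen2.5-0.5B",
    trust_remote_code=True,
    device_map="meta",  # parameters end up on meta: shapes only, no weight data
)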

@vadimkantorov
Author

vadimkantorov commented May 12, 2025

It actually worked after my fixes to ~three lines. The only problems came from the use of the default "cpu" device. I can submit a draft PR if you'd like to take a look.

I think it's a valid way to run smoke tests / to get HF to download everything and check that the sizes in the checkpoint match the meta tensors in the model, etc. And supporting the colloquial PyTorch with torch.device("meta"): (in addition to device_map="meta") would be good too, given that it's a small fix.

@Rocketknight1
Member

Sure - if you're willing to open a PR for the fix we can take a look!
