RuntimeError when converting and saving Flax ViT model to PyTorch #37999

Open · nobodyPerfecZ opened this issue May 7, 2025 · 3 comments · May be fixed by #38114

nobodyPerfecZ commented May 7, 2025

System Info

Env:

  • transformers version: 4.51.3
  • Platform: Linux-6.8.0-59-generic-x86_64-with-glibc2.39
  • Python version: 3.10.16
  • Huggingface_hub version: 0.30.2
  • Safetensors version: 0.5.3
  • Accelerate version: not found
  • Accelerate config: not found
  • DeepSpeed version: not found
  • PyTorch version (GPU?): 2.7.0+cu128 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): 0.10.5 (True)
  • Jax version: 0.5.1
  • JaxLib version: 0.5.1
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: Yes
  • GPU type: NVIDIA RTX 3090

Who can help?

@gante @Rocketknight1

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Example script to reproduce the bug:

from transformers import FlaxViTForImageClassification, ViTForImageClassification

# Simulate fine-tuning by loading the base model and saving it
flax_model = FlaxViTForImageClassification.from_pretrained(
    pretrained_model_name_or_path="google/vit-base-patch16-224",
    num_labels=5,
    id2label={
        0: "bicycle",
        1: "bus",
        2: "car",
        3: "crosswalk",
        4: "hydrant",
    },
    label2id={
        "bicycle": 0,
        "bus": 1,
        "car": 2,
        "crosswalk": 3,
        "hydrant": 4,
    },
    ignore_mismatched_sizes=True,
)
flax_model.save_pretrained("./test-vit-finetuned-patch16-224-recaptchav2")

# Load the fine-tuned model and convert it to PyTorch
pt_model = ViTForImageClassification.from_pretrained(
    pretrained_model_name_or_path="./test-vit-finetuned-patch16-224-recaptchav2",
    from_flax=True,
)
pt_model.save_pretrained("./test-vit-finetuned-patch16-224-recaptchav2") # RuntimeError

The complete traceback is shown below:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 30
     25 # Load the fine-tuned model and convert it to PyTorch
     26 pt_model = ViTForImageClassification.from_pretrained(
     27     pretrained_model_name_or_path="./test-vit-finetuned-patch16-224-recaptchav2",
     28     from_flax=True,
     29 )
---> 30 pt_model.save_pretrained("./test-vit-finetuned-patch16-224-recaptchav2") # RuntimeError

File ~/miniconda3/envs/recaptchav2-solver/lib/python3.10/site-packages/transformers/modeling_utils.py:3486, in PreTrainedModel.save_pretrained(self, save_directory, is_main_process, state_dict, save_function, push_to_hub, max_shard_size, safe_serialization, variant, token, save_peft_format, **kwargs)
   3483         error_names.append(set(shared_names))
   3485     if len(error_names) > 0:
-> 3486         raise RuntimeError(
   3487             f"The weights trying to be saved contained shared tensors {error_names} that are mismatching the transformers base configuration. Try saving using `safe_serialization=False` or remove this tensor sharing.",
   3488         )
   3490 # Shard the model if it is too big.
   3491 if not _hf_peft_config_loaded:

RuntimeError: The weights trying to be saved contained shared tensors [{'vit.encoder.layer.6.output.dense.bias', 'vit.encoder.layer.10.layernorm_after.weight', 'vit.encoder.layer.7.attention.attention.query.bias', 'vit.encoder.layer.7.attention.attention.key.bias', 'vit.encoder.layer.7.layernorm_after.bias', 'vit.encoder.layer.0.output.dense.bias', 'vit.encoder.layer.8.attention.attention.query.bias', 'vit.encoder.layer.1.attention.output.dense.bias', 'vit.encoder.layer.3.output.dense.bias', 'vit.encoder.layer.10.layernorm_before.weight', 'vit.encoder.layer.3.layernorm_after.bias', 'vit.encoder.layer.0.layernorm_after.bias', 'vit.encoder.layer.2.attention.attention.key.bias', 'vit.encoder.layer.0.layernorm_before.weight', 'vit.encoder.layer.3.attention.attention.value.bias', 'vit.encoder.layer.11.attention.attention.key.bias', 'vit.encoder.layer.5.output.dense.bias', 'vit.layernorm.weight', 'vit.encoder.layer.8.output.dense.bias', 'vit.encoder.layer.7.attention.output.dense.bias', 'vit.encoder.layer.6.attention.attention.value.bias', 'vit.encoder.layer.5.layernorm_before.bias', 'vit.encoder.layer.6.layernorm_before.bias', 'vit.encoder.layer.4.layernorm_after.bias', 'vit.encoder.layer.11.layernorm_after.weight', 'vit.encoder.layer.4.attention.attention.query.bias', 'vit.encoder.layer.9.attention.attention.query.bias', 'vit.encoder.layer.5.attention.output.dense.bias', 'vit.encoder.layer.3.attention.attention.key.bias', 'vit.encoder.layer.8.attention.attention.key.bias', 'vit.encoder.layer.8.attention.output.dense.bias', 'vit.encoder.layer.11.layernorm_after.bias', 'vit.encoder.layer.11.layernorm_before.weight', 'vit.encoder.layer.7.output.dense.bias', 'vit.encoder.layer.3.layernorm_before.weight', 'vit.encoder.layer.4.layernorm_before.bias', 'vit.encoder.layer.1.attention.attention.query.bias', 'vit.encoder.layer.2.layernorm_before.weight', 'vit.encoder.layer.0.attention.output.dense.bias', 'vit.encoder.layer.11.attention.output.dense.bias', 'vit.encoder.layer.0.layernorm_before.bias', 'vit.encoder.layer.10.attention.output.dense.bias', 'vit.encoder.layer.11.attention.attention.query.bias', 'vit.encoder.layer.8.attention.attention.value.bias', 'vit.encoder.layer.6.attention.output.dense.bias', 'vit.layernorm.bias', 'vit.encoder.layer.1.layernorm_after.weight', 'vit.encoder.layer.10.attention.attention.query.bias', 'vit.encoder.layer.11.layernorm_before.bias', 'vit.encoder.layer.1.layernorm_before.bias', 'vit.encoder.layer.4.output.dense.bias', 'vit.embeddings.cls_token', 'vit.encoder.layer.6.attention.attention.query.bias', 'vit.encoder.layer.3.attention.attention.query.bias', 'vit.encoder.layer.9.output.dense.bias', 'vit.encoder.layer.9.attention.attention.key.bias', 'vit.embeddings.patch_embeddings.projection.bias', 'vit.encoder.layer.5.layernorm_after.weight', 'vit.encoder.layer.11.output.dense.bias', 'vit.encoder.layer.8.layernorm_after.weight', 'vit.encoder.layer.5.attention.attention.value.bias', 'vit.encoder.layer.2.attention.output.dense.bias', 'vit.encoder.layer.6.layernorm_after.bias', 'vit.encoder.layer.5.layernorm_after.bias', 'vit.encoder.layer.1.layernorm_before.weight', 'vit.encoder.layer.2.layernorm_before.bias', 'vit.encoder.layer.10.attention.attention.value.bias', 'vit.encoder.layer.9.attention.output.dense.bias', 'vit.encoder.layer.9.attention.attention.value.bias', 'vit.encoder.layer.9.layernorm_before.weight', 'vit.encoder.layer.2.output.dense.bias', 'vit.encoder.layer.0.attention.attention.query.bias', 'vit.encoder.layer.9.layernorm_after.weight', 
'vit.encoder.layer.10.output.dense.bias', 'vit.encoder.layer.7.layernorm_after.weight', 'vit.encoder.layer.8.layernorm_after.bias', 'vit.encoder.layer.5.attention.attention.key.bias', 'vit.encoder.layer.7.attention.attention.value.bias', 'vit.encoder.layer.4.layernorm_after.weight', 'vit.encoder.layer.3.attention.output.dense.bias', 'vit.encoder.layer.6.layernorm_after.weight', 'vit.encoder.layer.5.attention.attention.query.bias', 'vit.encoder.layer.5.layernorm_before.weight', 'vit.encoder.layer.0.layernorm_after.weight', 'vit.encoder.layer.10.layernorm_after.bias', 'vit.encoder.layer.3.layernorm_before.bias', 'vit.encoder.layer.4.layernorm_before.weight', 'vit.encoder.layer.11.attention.attention.value.bias', 'vit.encoder.layer.0.attention.attention.key.bias', 'vit.encoder.layer.4.attention.output.dense.bias', 'vit.encoder.layer.10.attention.attention.key.bias', 'vit.encoder.layer.4.attention.attention.key.bias', 'vit.encoder.layer.1.layernorm_after.bias', 'vit.encoder.layer.7.layernorm_before.bias', 'vit.encoder.layer.4.attention.attention.value.bias', 'vit.encoder.layer.8.layernorm_before.weight', 'vit.encoder.layer.1.attention.attention.key.bias', 'vit.encoder.layer.2.attention.attention.value.bias', 'vit.encoder.layer.0.attention.attention.value.bias', 'vit.encoder.layer.10.layernorm_before.bias', 'vit.encoder.layer.9.layernorm_after.bias', 'vit.encoder.layer.2.layernorm_after.bias', 'vit.encoder.layer.1.attention.attention.value.bias', 'vit.encoder.layer.6.layernorm_before.weight', 'vit.encoder.layer.9.layernorm_before.bias', 'vit.encoder.layer.1.output.dense.bias', 'vit.encoder.layer.3.layernorm_after.weight', 'vit.encoder.layer.8.layernorm_before.bias', 'vit.encoder.layer.2.attention.attention.query.bias', 'vit.encoder.layer.2.layernorm_after.weight', 'vit.encoder.layer.7.layernorm_before.weight', 'vit.encoder.layer.6.attention.attention.key.bias'}, {'vit.encoder.layer.5.attention.attention.value.weight', 'vit.encoder.layer.6.attention.attention.query.weight', 'vit.encoder.layer.10.attention.attention.key.weight', 'vit.encoder.layer.1.attention.attention.value.weight', 'vit.encoder.layer.4.attention.attention.query.weight', 'vit.encoder.layer.11.attention.attention.value.weight', 'vit.encoder.layer.11.attention.output.dense.weight', 'vit.encoder.layer.8.attention.attention.query.weight', 'vit.encoder.layer.4.attention.attention.key.weight', 'vit.encoder.layer.10.attention.attention.query.weight', 'vit.encoder.layer.3.attention.output.dense.weight', 'vit.encoder.layer.7.attention.attention.value.weight', 'vit.encoder.layer.6.attention.attention.key.weight', 'vit.encoder.layer.3.attention.attention.key.weight', 'vit.encoder.layer.0.attention.attention.key.weight', 'vit.encoder.layer.5.attention.attention.key.weight', 'vit.encoder.layer.9.attention.output.dense.weight', 'vit.encoder.layer.4.attention.attention.value.weight', 'vit.encoder.layer.11.attention.attention.key.weight', 'vit.encoder.layer.5.attention.output.dense.weight', 'vit.encoder.layer.7.attention.attention.query.weight', 'vit.encoder.layer.10.attention.output.dense.weight', 'vit.encoder.layer.3.attention.attention.query.weight', 'vit.encoder.layer.8.attention.attention.key.weight', 'vit.encoder.layer.4.attention.output.dense.weight', 'vit.encoder.layer.8.attention.output.dense.weight', 'vit.encoder.layer.5.attention.attention.query.weight', 'vit.encoder.layer.0.attention.attention.value.weight', 'vit.encoder.layer.6.attention.output.dense.weight', 'vit.encoder.layer.9.attention.attention.query.weight', 
'vit.encoder.layer.0.attention.output.dense.weight', 'vit.encoder.layer.9.attention.attention.key.weight', 'vit.encoder.layer.7.attention.attention.key.weight', 'vit.encoder.layer.2.attention.output.dense.weight', 'vit.encoder.layer.1.attention.output.dense.weight', 'vit.encoder.layer.6.attention.attention.value.weight', 'vit.encoder.layer.2.attention.attention.key.weight', 'vit.encoder.layer.2.attention.attention.value.weight', 'vit.encoder.layer.7.attention.output.dense.weight', 'vit.encoder.layer.3.attention.attention.value.weight', 'vit.encoder.layer.1.attention.attention.key.weight', 'vit.encoder.layer.8.attention.attention.value.weight', 'vit.encoder.layer.9.attention.attention.value.weight', 'vit.encoder.layer.10.attention.attention.value.weight', 'vit.encoder.layer.2.attention.attention.query.weight', 'vit.encoder.layer.11.attention.attention.query.weight', 'vit.encoder.layer.1.attention.attention.query.weight', 'vit.embeddings.patch_embeddings.projection.weight', 'vit.encoder.layer.0.attention.attention.query.weight'}, {'vit.encoder.layer.11.output.dense.weight', 'vit.encoder.layer.7.intermediate.dense.weight', 'vit.encoder.layer.7.output.dense.weight', 'vit.encoder.layer.11.intermediate.dense.weight', 'vit.encoder.layer.0.output.dense.weight', 'vit.encoder.layer.4.output.dense.weight', 'vit.encoder.layer.3.intermediate.dense.weight', 'vit.encoder.layer.0.intermediate.dense.weight', 'vit.encoder.layer.2.intermediate.dense.weight', 'vit.encoder.layer.10.intermediate.dense.weight', 'vit.encoder.layer.9.output.dense.weight', 'vit.encoder.layer.8.output.dense.weight', 'vit.encoder.layer.1.output.dense.weight', 'vit.encoder.layer.4.intermediate.dense.weight', 'vit.encoder.layer.9.intermediate.dense.weight', 'vit.encoder.layer.2.output.dense.weight', 'vit.encoder.layer.1.intermediate.dense.weight', 'vit.encoder.layer.5.output.dense.weight', 'vit.encoder.layer.6.output.dense.weight', 'vit.encoder.layer.5.intermediate.dense.weight', 'vit.encoder.layer.3.output.dense.weight', 'vit.encoder.layer.10.output.dense.weight', 'vit.encoder.layer.8.intermediate.dense.weight', 'vit.encoder.layer.6.intermediate.dense.weight'}, {'vit.encoder.layer.10.intermediate.dense.bias', 'vit.encoder.layer.3.intermediate.dense.bias', 'vit.encoder.layer.11.intermediate.dense.bias', 'vit.encoder.layer.0.intermediate.dense.bias', 'vit.encoder.layer.2.intermediate.dense.bias', 'vit.encoder.layer.8.intermediate.dense.bias', 'vit.encoder.layer.1.intermediate.dense.bias', 'vit.encoder.layer.6.intermediate.dense.bias', 'vit.encoder.layer.4.intermediate.dense.bias', 'vit.encoder.layer.9.intermediate.dense.bias', 'vit.encoder.layer.7.intermediate.dense.bias', 'vit.encoder.layer.5.intermediate.dense.bias'}] that are mismatching the transformers base configuration. Try saving using `safe_serialization=False` or remove this tensor sharing.
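
For reference, a quick way to see which parameters actually end up sharing storage after the Flax→PyTorch conversion is to group them by storage pointer (a minimal diagnostic sketch, assuming the local directory created by the script above; the grouping only roughly mirrors the check inside save_pretrained):

from collections import defaultdict

from transformers import ViTForImageClassification

pt_model = ViTForImageClassification.from_pretrained(
    pretrained_model_name_or_path="./test-vit-finetuned-patch16-224-recaptchav2",
    from_flax=True,
)

# Group parameter names by the data pointer of their underlying storage;
# any group with more than one entry is a set of tensors sharing memory,
# which is what the safetensors-based save path rejects.
groups = defaultdict(list)
for name, tensor in pt_model.state_dict().items():
    groups[tensor.untyped_storage().data_ptr()].append(name)

for names in groups.values():
    if len(names) > 1:
        print(names)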

Expected behavior

Given a Flax ViT model, the converted PyTorch ViT model should save and load successfully without manual intervention.

xmarva (Contributor) commented May 9, 2025

@nobodyPerfecZ Hey, just wondering: does passing safe_serialization=False to pt_model.save_pretrained resolve the issue in your configuration? I can reproduce the error without this parameter, but with it the model converts and saves successfully.

nobodyPerfecZ (Author) commented May 9, 2025

I tried passing safe_serialization=False to pt_model.save_pretrained. The model is saved, but when I reload it without setting from_flax=True, I encounter another error.

Code:

from transformers import FlaxViTForImageClassification, ViTForImageClassification

# Simulate fine-tuning by loading the base model and saving it
flax_model = FlaxViTForImageClassification.from_pretrained(
    pretrained_model_name_or_path="google/vit-base-patch16-224",
    num_labels=5,
    id2label={
        0: "bicycle",
        1: "bus",
        2: "car",
        3: "crosswalk",
        4: "hydrant",
    },
    label2id={
        "bicycle": 0,
        "bus": 1,
        "car": 2,
        "crosswalk": 3,
        "hydrant": 4,
    },
    ignore_mismatched_sizes=True,
)
flax_model.save_pretrained("./test-vit-finetuned-patch16-224-recaptchav2")

# Load the fine-tuned model and convert it to PyTorch
pt_model = ViTForImageClassification.from_pretrained(
    pretrained_model_name_or_path="./test-vit-finetuned-patch16-224-recaptchav2",
    from_flax=True,
)
pt_model.save_pretrained("./test-vit-finetuned-patch16-224-recaptchav2", safe_serialization=False)

# Loading the model again
pt_model = ViTForImageClassification.from_pretrained(
    pretrained_model_name_or_path="./test-vit-finetuned-patch16-224-recaptchav2",
) # NotImplementedError

Traceback:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[1], line 33
     30 pt_model.save_pretrained("./test-vit-finetuned-patch16-224-recaptchav2", safe_serialization=False)
     32 # Loading the model
---> 33 pt_model = ViTForImageClassification.from_pretrained(
     34     pretrained_model_name_or_path="./test-vit-finetuned-patch16-224-recaptchav2",
     35 )

File ~/miniconda3/envs/recaptchav2-solver/lib/python3.10/site-packages/transformers/modeling_utils.py:279, in restore_default_torch_dtype.<locals>._wrapper(*args, **kwargs)
    277 old_dtype = torch.get_default_dtype()
    278 try:
--> 279     return func(*args, **kwargs)
    280 finally:
    281     torch.set_default_dtype(old_dtype)

File ~/miniconda3/envs/recaptchav2-solver/lib/python3.10/site-packages/transformers/modeling_utils.py:4399, in PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs)
   4389     if dtype_orig is not None:
   4390         torch.set_default_dtype(dtype_orig)
   4392     (
   4393         model,
   4394         missing_keys,
   4395         unexpected_keys,
   4396         mismatched_keys,
   4397         offload_index,
   4398         error_msgs,
-> 4399     ) = cls._load_pretrained_model(
   4400         model,
   4401         state_dict,
   4402         checkpoint_files,
   4403         pretrained_model_name_or_path,
   4404         ignore_mismatched_sizes=ignore_mismatched_sizes,
   4405         sharded_metadata=sharded_metadata,
   4406         device_map=device_map,
   4407         disk_offload_folder=offload_folder,
   4408         offload_state_dict=offload_state_dict,
   4409         dtype=torch_dtype,
   4410         hf_quantizer=hf_quantizer,
   4411         keep_in_fp32_regex=keep_in_fp32_regex,
   4412         device_mesh=device_mesh,
   4413         key_mapping=key_mapping,
   4414         weights_only=weights_only,
   4415     )
   4417 # make sure token embedding weights are still tied if needed
   4418 model.tie_weights()

File ~/miniconda3/envs/recaptchav2-solver/lib/python3.10/site-packages/transformers/modeling_utils.py:4833, in PreTrainedModel._load_pretrained_model(cls, model, state_dict, checkpoint_files, pretrained_model_name_or_path, ignore_mismatched_sizes, sharded_metadata, device_map, disk_offload_folder, offload_state_dict, dtype, hf_quantizer, keep_in_fp32_regex, device_mesh, key_mapping, weights_only)
   4831 # Skip it with fsdp on ranks other than 0
   4832 elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
-> 4833     disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
   4834         model_to_load,
   4835         state_dict,
   4836         shard_file,
   4837         expected_keys,
   4838         reverse_key_renaming_mapping,
   4839         device_map=device_map,
   4840         disk_offload_folder=disk_offload_folder,
   4841         disk_offload_index=disk_offload_index,
   4842         cpu_offload_folder=cpu_offload_folder,
   4843         cpu_offload_index=cpu_offload_index,
   4844         hf_quantizer=hf_quantizer,
   4845         is_safetensors=is_offloaded_safetensors,
   4846         keep_in_fp32_regex=keep_in_fp32_regex,
   4847         unexpected_keys=unexpected_keys,
   4848         device_mesh=device_mesh,
   4849     )
   4851 # force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop
   4852 del state_dict

File ~/miniconda3/envs/recaptchav2-solver/lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/miniconda3/envs/recaptchav2-solver/lib/python3.10/site-packages/transformers/modeling_utils.py:765, in _load_state_dict_into_meta_model(model, state_dict, shard_file, expected_keys, reverse_renaming_mapping, device_map, disk_offload_folder, disk_offload_index, cpu_offload_folder, cpu_offload_index, hf_quantizer, is_safetensors, keep_in_fp32_regex, unexpected_keys, device_mesh)
    763     param = file_pointer.get_slice(serialized_param_name)
    764 else:
--> 765     param = empty_param.to(tensor_device)  # It is actually not empty!
    767 to_contiguous, casting_dtype = _infer_parameter_dtype(
    768     model,
    769     param_name,
   (...)
    772     hf_quantizer,
    773 )
    775 if device_mesh is not None:  # In this case, the param is already on the correct device!

NotImplementedError: Cannot copy out of meta tensor; no data!
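
As an additional check (a sketch only; it assumes save_pretrained(..., safe_serialization=False) wrote the default pytorch_model.bin into that directory), the saved state dict can be inspected directly to see whether the tensor sharing survived serialization:

from collections import defaultdict

import torch

# Load the non-safetensors checkpoint written by
# save_pretrained(..., safe_serialization=False).
state_dict = torch.load(
    "./test-vit-finetuned-patch16-224-recaptchav2/pytorch_model.bin",
    map_location="cpu",
    weights_only=True,
)

# Count how many underlying storage buffers are shared by more than one
# saved tensor; torch pickling preserves such sharing across save/load.
groups = defaultdict(list)
for name, tensor in state_dict.items():
    if torch.is_tensor(tensor):
        groups[tensor.untyped_storage().data_ptr()].append(name)

shared = [names for names in groups.values() if len(names) > 1]
print(f"{len(shared)} storage buffers are shared by more than one tensor")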

MutugiD commented May 12, 2025

My workaround would be to patch modeling_utils.py so that from_pretrained(..., from_flax=True) defaults to low_cpu_mem_usage=False (forcing real tensors), and to add a preprocessing step in save_pretrained() that clones every parameter (p.detach().cpu().clone()) before serialization. That should most likely eliminate the meta-tensor placeholders and the shared-tensor errors; safe_serialization=False only seems like a temporary fix.

Opening a PR and testing if it will be solid enough.
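
In the meantime, a user-side version of the cloning step could look roughly like this (a sketch against the directory used above, not the actual patch to modeling_utils.py):

import torch
from transformers import ViTForImageClassification

# Convert the Flax checkpoint to PyTorch as in the reproduction script.
pt_model = ViTForImageClassification.from_pretrained(
    pretrained_model_name_or_path="./test-vit-finetuned-patch16-224-recaptchav2",
    from_flax=True,
)

# Give every parameter its own contiguous CPU copy so that no two tensors
# share storage before serialization; saving with the default
# safe_serialization=True should then no longer trip the shared-tensor check.
with torch.no_grad():
    for param in pt_model.parameters():
        param.data = param.data.detach().cpu().clone().contiguous()

pt_model.save_pretrained("./test-vit-finetuned-patch16-224-recaptchav2")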
