
[minor] Protect against broken/missing torchvision installations and do not hard-fail at timm/torchvision import (many text models don't need any timm/torchvision as hard dependencies) #38065

Closed
@vadimkantorov

Description


Also, importing transformers fails when torchvision is not installed at all. I think this should not be an error, especially when working with text models only; timm and friends should not be imported at all in that case...
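
A rough sketch of that second failure mode, for reference (assumptions: timm stays installed, as on the stock Colab image, and that timm version imports torchvision unconditionally, as the traceback below shows):

```python
# Hypothetical repro of the "no torchvision at all" case:
#   pip uninstall torchvision -y      (timm stays installed)
from transformers import AutoModel

# Traversing the auto mappings lazily imports
# transformers.models.timm_wrapper.configuration_timm_wrapper -> timm ->
# `from torchvision.ops.misc import FrozenBatchNorm2d`, which now raises
# ModuleNotFoundError and surfaces as the same "Failed to import ..." RuntimeError.
AutoModel.from_pretrained("Qwen/Qwen2.5-0.5B", trust_remote_code=True)
```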

System Info

Google Colab, which ships PyTorch 2.6.0 at the moment

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Go to Colab (which currently still ships PyTorch 2.6.0), uninstall torch, and install a new torch build without reinstalling torchvision (the old torchvision stays behind and no longer matches torch):

!pip uninstall torch -y
!pip install torch --index-url https://download.pytorch.org/whl/cpu

`from transformers import AutoModel; AutoModel.from_pretrained('Qwen/Qwen2.5-0.5B', trust_remote_code=True)`

This currently fails (even though nothing vision-related is being loaded) whenever the environment has a broken torchvision installation, which can happen when torchvision isn't actually used and its version no longer matches the installed torch:

RuntimeError: Failed to import transformers.models.timm_wrapper.configuration_timm_wrapper because of the following error (look up to see its traceback):
operator torchvision::nms does not exist
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/usr/local/lib/python3.11/dist-packages/transformers/utils/import_utils.py in _get_module(self, module_name)
   1966         try:
-> 1967             return importlib.import_module("." + module_name, self.__name__)
   1968         except Exception as e:

24 frames
/usr/lib/python3.11/importlib/__init__.py in import_module(name, package)
    125             level += 1
--> 126     return _bootstrap._gcd_import(name[level:], package, level)
    127 

/usr/lib/python3.11/importlib/_bootstrap.py in _gcd_import(name, package, level)

/usr/lib/python3.11/importlib/_bootstrap.py in _find_and_load(name, import_)

/usr/lib/python3.11/importlib/_bootstrap.py in _find_and_load_unlocked(name, import_)

/usr/lib/python3.11/importlib/_bootstrap.py in _load_unlocked(spec)

/usr/lib/python3.11/importlib/_bootstrap_external.py in exec_module(self, module)

/usr/lib/python3.11/importlib/_bootstrap.py in _call_with_frames_removed(f, *args, **kwds)

/usr/local/lib/python3.11/dist-packages/transformers/models/timm_wrapper/configuration_timm_wrapper.py in <module>
     24 if is_timm_available():
---> 25     from timm.data import ImageNetInfo, infer_imagenet_subset
     26 

/usr/local/lib/python3.11/dist-packages/timm/__init__.py in <module>
      1 from .version import __version__ as __version__
----> 2 from .layers import (
      3     is_scriptable as is_scriptable,

/usr/local/lib/python3.11/dist-packages/timm/layers/__init__.py in <module>
      7 from .blur_pool import BlurPool2d, create_aa
----> 8 from .classifier import create_classifier, ClassifierHead, NormMlpClassifierHead, ClNormMlpClassifierHead
      9 from .cond_conv2d import CondConv2d, get_condconv_initializer

/usr/local/lib/python3.11/dist-packages/timm/layers/classifier.py in <module>
     14 from .create_act import get_act_layer
---> 15 from .create_norm import get_norm_layer
     16 

/usr/local/lib/python3.11/dist-packages/timm/layers/create_norm.py in <module>
     13 from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d
---> 14 from torchvision.ops.misc import FrozenBatchNorm2d
     15 

/usr/local/lib/python3.11/dist-packages/torchvision/__init__.py in <module>
      9 from .extension import _HAS_OPS  # usort:skip
---> 10 from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils  # usort:skip
     11 

/usr/local/lib/python3.11/dist-packages/torchvision/_meta_registrations.py in <module>
    162 
--> 163 @torch.library.register_fake("torchvision::nms")
    164 def meta_nms(dets, scores, iou_threshold):

/usr/local/lib/python3.11/dist-packages/torch/library.py in register(func)
   1022             use_lib = lib
-> 1023         use_lib._register_fake(op_name, func, _stacklevel=stacklevel + 1)
   1024         return func

/usr/local/lib/python3.11/dist-packages/torch/library.py in _register_fake(self, op_name, fn, _stacklevel)
    213 
--> 214         handle = entry.fake_impl.register(func_to_register, source)
    215         self._registration_handles.append(handle)

/usr/local/lib/python3.11/dist-packages/torch/_library/fake_impl.py in register(self, func, source)
     30             )
---> 31         if torch._C._dispatch_has_kernel_for_dispatch_key(self.qualname, "Meta"):
     32             raise RuntimeError(

RuntimeError: operator torchvision::nms does not exist

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
<ipython-input-6-8843fc8bd8f0> in <cell line: 0>()
      4 with torch.device('meta'):
      5     #from torch.nn.attention.flex_attention import BlockMask, flex_attention
----> 6     AutoModel.from_pretrained('Qwen/Qwen2.5-0.5B', trust_remote_code=True)

/usr/local/lib/python3.11/dist-packages/transformers/models/auto/auto_factory.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    546 
    547         has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
--> 548         has_local_code = type(config) in cls._model_mapping.keys()
    549         trust_remote_code = resolve_trust_remote_code(
    550             trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code

/usr/local/lib/python3.11/dist-packages/transformers/models/auto/auto_factory.py in keys(self)
    785 
    786     def keys(self):
--> 787         mapping_keys = [
    788             self._load_attr_from_module(key, name)
    789             for key, name in self._config_mapping.items()

/usr/local/lib/python3.11/dist-packages/transformers/models/auto/auto_factory.py in <listcomp>(.0)
    786     def keys(self):
    787         mapping_keys = [
--> 788             self._load_attr_from_module(key, name)
    789             for key, name in self._config_mapping.items()
    790             if key in self._model_mapping.keys()

/usr/local/lib/python3.11/dist-packages/transformers/models/auto/auto_factory.py in _load_attr_from_module(self, model_type, attr)
    782         if module_name not in self._modules:
    783             self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
--> 784         return getattribute_from_module(self._modules[module_name], attr)
    785 
    786     def keys(self):

/usr/local/lib/python3.11/dist-packages/transformers/models/auto/auto_factory.py in getattribute_from_module(module, attr)
    698     if isinstance(attr, tuple):
    699         return tuple(getattribute_from_module(module, a) for a in attr)
--> 700     if hasattr(module, attr):
    701         return getattr(module, attr)
    702     # Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the

/usr/local/lib/python3.11/dist-packages/transformers/utils/import_utils.py in __getattr__(self, name)
   1953             value = Placeholder
   1954         elif name in self._class_to_module.keys():
-> 1955             module = self._get_module(self._class_to_module[name])
   1956             value = getattr(module, name)
   1957         elif name in self._modules:

/usr/local/lib/python3.11/dist-packages/transformers/utils/import_utils.py in _get_module(self, module_name)
   1967             return importlib.import_module("." + module_name, self.__name__)
   1968         except Exception as e:
-> 1969             raise RuntimeError(
   1970                 f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its"
   1971                 f" traceback):\n{e}"

RuntimeError: Failed to import transformers.models.timm_wrapper.configuration_timm_wrapper because of the following error (look up to see its traceback):
operator torchvision::nms does not exist
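
For reference, the underlying breakage is independent of transformers and reproduces with a bare import (a sketch, assuming the stale torchvision wheel from the original Colab image is still installed alongside the new torch):

```python
import torch        # the freshly installed CPU wheel
import torchvision  # the stale wheel built against torch 2.6.0
# -> RuntimeError: operator torchvision::nms does not exist
#    (raised from torchvision._meta_registrations while registering fake ops)
```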

Expected behavior

No hard failure at import time; at most, fail when these torchvision ops/modules are actually needed at runtime.
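
One possible shape of such a guard, sketched against the `is_timm_available()` check that configuration_timm_wrapper.py already uses (an illustration of the requested behavior, not the actual patch):

```python
# Sketch only: tolerate a missing/broken timm+torchvision stack at import time
# and surface the problem later, when the timm-specific symbols are actually used.
from transformers.utils import is_timm_available

ImageNetInfo = infer_imagenet_subset = None
if is_timm_available():
    try:
        from timm.data import ImageNetInfo, infer_imagenet_subset
    except (ImportError, RuntimeError):
        # timm is installed but unusable, e.g. because torchvision does not
        # match torch and `operator torchvision::nms does not exist`.
        # Treat timm as unavailable instead of breaking `import transformers`;
        # the timm wrapper config can raise a clear error if it is ever used.
        pass
```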
