diff --git a/pyproject.toml b/pyproject.toml index fed528d4a7..945b0cba32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,12 @@ [build-system] requires = ["setuptools"] build-backend = "setuptools.build_meta" + +[tool.mypy] +files = ["src/rez/"] +exclude = [ + '.*/rez/data/.*', + '.*/rez/vendor/.*', + '.*/rez/tests/.*', +] +disable_error_code = ["var-annotated", "import-not-found"] \ No newline at end of file diff --git a/src/rez/build_process.py b/src/rez/build_process.py index a7d3e280d0..bc515980cf 100644 --- a/src/rez/build_process.py +++ b/src/rez/build_process.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright Contributors to the Rez Project - +from __future__ import annotations from rez.packages import iter_packages from rez.exceptions import BuildProcessError, BuildContextResolveError, \ @@ -18,6 +18,12 @@ import getpass import os.path import sys +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from rez.build_system import BuildSystem + from rez.packages import Package, Variant + from rez.release_vcs import ReleaseVCS + from rez.developer_package import DeveloperPackage debug_print = config.debug_printer("package_release") @@ -29,9 +35,16 @@ def get_build_process_types(): return plugin_manager.get_plugins('build_process') -def create_build_process(process_type, working_dir, build_system, package=None, - vcs=None, ensure_latest=True, skip_repo_errors=False, - ignore_existing_tag=False, verbose=False, quiet=False): +def create_build_process(process_type: str, + working_dir: str, + build_system: BuildSystem, + package=None, + vcs: ReleaseVCS | None = None, + ensure_latest: bool = True, + skip_repo_errors: bool = False, + ignore_existing_tag: bool = False, + verbose: bool = False, + quiet: bool = False) -> BuildProcess: """Create a :class:`BuildProcess` instance. .. warning:: @@ -77,7 +90,8 @@ class BuildProcess(object): def name(cls): raise NotImplementedError - def __init__(self, working_dir, build_system, package=None, vcs=None, + def __init__(self, working_dir: str, build_system: BuildSystem, package=None, + vcs: ReleaseVCS | None = None, ensure_latest=True, skip_repo_errors=False, ignore_existing_tag=False, verbose=False, quiet=False): """Create a BuildProcess. @@ -119,14 +133,15 @@ def __init__(self, working_dir, build_system, package=None, vcs=None, self.package.config.build_directory) @property - def package(self): + def package(self) -> DeveloperPackage: return self.build_system.package @property - def working_dir(self): + def working_dir(self) -> str: return self.build_system.working_dir - def build(self, install_path=None, clean=False, install=False, variants=None): + def build(self, install_path: str | None = None, clean: bool = False, + install=False, variants: list[int] | None = None) -> int: """Perform the build process. Iterates over the package's variants, resolves the environment for @@ -149,7 +164,8 @@ def build(self, install_path=None, clean=False, install=False, variants=None): """ raise NotImplementedError - def release(self, release_message=None, variants=None): + def release(self, release_message: str | None = None, + variants: list[int] | None = None) -> int: """Perform the release process. Iterates over the package's variants, building and installing each into @@ -167,7 +183,7 @@ def release(self, release_message=None, variants=None): """ raise NotImplementedError - def get_changelog(self): + def get_changelog(self) -> str | None: """Get the changelog since last package release. Returns: @@ -215,7 +231,7 @@ def visit_variants(self, func, variants=None, **kwargs): return num_visited, results - def get_package_install_path(self, path): + def get_package_install_path(self, path: str) -> str: """Return the installation path for a package (where its payload goes). Args: @@ -230,7 +246,8 @@ def get_package_install_path(self, path): package_version=self.package.version ) - def create_build_context(self, variant, build_type, build_path): + def create_build_context(self, variant: Variant, build_type: BuildType, + build_path: str) -> tuple[ResolvedContext, str]: """Create a context to build the variant within.""" request = variant.get_requires(build_requires=True, private_build_requires=True) @@ -274,7 +291,7 @@ def create_build_context(self, variant, build_type, build_path): raise BuildContextResolveError(context) return context, rxt_filepath - def pre_release(self): + def pre_release(self) -> None: release_settings = self.package.config.plugins.release_vcs # test that the release path exists @@ -322,7 +339,7 @@ def pre_release(self): else: break - def post_release(self, release_message=None): + def post_release(self, release_message=None) -> None: tag_name = self.get_current_tag_name() if self.vcs is None: @@ -332,7 +349,7 @@ def post_release(self, release_message=None): with self.repo_operation(): self.vcs.create_release_tag(tag_name=tag_name, message=release_message) - def get_current_tag_name(self): + def get_current_tag_name(self) -> str: release_settings = self.package.config.plugins.release_vcs try: tag_name = self.package.format(release_settings.tag_name) @@ -342,7 +359,7 @@ def get_current_tag_name(self): tag_name = "unversioned" return tag_name - def run_hooks(self, hook_event, **kwargs): + def run_hooks(self, hook_event, **kwargs) -> None: hook_names = self.package.config.release_hooks or [] hooks = create_release_hooks(hook_names, self.working_dir) @@ -362,7 +379,7 @@ def run_hooks(self, hook_event, **kwargs): % (hook_event.label, hook.name(), e.__class__.__name__, str(e))) - def get_previous_release(self): + def get_previous_release(self) -> Package | None: release_path = self.package.config.release_packages_path it = iter_packages(self.package.name, paths=[release_path]) packages = sorted(it, key=lambda x: x.version, reverse=True) @@ -372,7 +389,7 @@ def get_previous_release(self): return package return None - def get_changelog(self): + def get_changelog(self) -> str | None: previous_package = self.get_previous_release() if previous_package: previous_revision = previous_package.revision @@ -380,10 +397,11 @@ def get_changelog(self): previous_revision = None changelog = None - with self.repo_operation(): - changelog = self.vcs.get_changelog( - previous_revision, - max_revisions=config.max_package_changelog_revisions) + if self.vcs: + with self.repo_operation(): + changelog = self.vcs.get_changelog( + previous_revision, + max_revisions=config.max_package_changelog_revisions) return changelog diff --git a/src/rez/build_system.py b/src/rez/build_system.py index 9f27b8e0b4..80aaee8e83 100644 --- a/src/rez/build_system.py +++ b/src/rez/build_system.py @@ -1,14 +1,28 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright Contributors to the Rez Project +from __future__ import annotations - +import argparse import os.path +from typing import TYPE_CHECKING, TypedDict + from rez.build_process import BuildType from rez.exceptions import BuildSystemError from rez.packages import get_developer_package from rez.rex_bindings import VariantBinding +if TYPE_CHECKING: + from rez.developer_package import DeveloperPackage + from rez.resolved_context import ResolvedContext + from rez.packages import Package, Variant + + +class BuildResult(TypedDict): + success: bool + extra_files: list[str] + build_env_script: str + def get_buildsys_types(): """Returns the available build system implementations - cmake, make etc.""" @@ -16,7 +30,8 @@ def get_buildsys_types(): return plugin_manager.get_plugins('build_system') -def get_valid_build_systems(working_dir, package=None): +def get_valid_build_systems(working_dir: str, + package: Package | None = None) -> list[type[BuildSystem]]: """Returns the build system classes that could build the source in given dir. Args: @@ -67,9 +82,9 @@ def get_valid_build_systems(working_dir, package=None): return clss -def create_build_system(working_dir, buildsys_type=None, package=None, opts=None, +def create_build_system(working_dir: str, buildsys_type=None, package=None, opts=None, write_build_scripts=False, verbose=False, - build_args=[], child_build_args=[]): + build_args=[], child_build_args=[]) -> BuildSystem: """Return a new build system that can build the source in working_dir.""" from rez.plugin_managers import plugin_manager @@ -104,11 +119,12 @@ class BuildSystem(object): """A build system, such as cmake, make, Scons etc. """ @classmethod - def name(cls): + def name(cls) -> str: """Return the name of the build system, eg 'make'.""" raise NotImplementedError - def __init__(self, working_dir, opts=None, package=None, + def __init__(self, working_dir: str, opts=None, + package: DeveloperPackage | None = None, write_build_scripts=False, verbose=False, build_args=[], child_build_args=[]): """Create a build system instance. @@ -143,12 +159,12 @@ def __init__(self, working_dir, opts=None, package=None, self.opts = opts @classmethod - def is_valid_root(cls, path): + def is_valid_root(cls, path: str) -> bool: """Return True if this build system can build the source in path.""" raise NotImplementedError @classmethod - def child_build_system(cls): + def child_build_system(cls) -> str | None: """Returns the child build system. Some build systems, such as cmake, don't build the source directly. @@ -163,7 +179,7 @@ def child_build_system(cls): return None @classmethod - def bind_cli(cls, parser, group): + def bind_cli(cls, parser: argparse.ArgumentParser, group): """Expose parameters to an argparse.ArgumentParser that are specific to this build system. @@ -174,8 +190,13 @@ def bind_cli(cls, parser, group): """ pass - def build(self, context, variant, build_path, install_path, install=False, - build_type=BuildType.local): + def build(self, + context: ResolvedContext, + variant: Variant, + build_path: str, + install_path: str, + install: bool = False, + build_type=BuildType.local) -> BuildResult: """Implement this method to perform the actual build. Args: diff --git a/src/rez/cli/build.py b/src/rez/cli/build.py index 208bab9a53..6ddf71cfa2 100644 --- a/src/rez/cli/build.py +++ b/src/rez/cli/build.py @@ -5,17 +5,23 @@ ''' Build a package from source. ''' +from __future__ import annotations + import os +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from rez.developer_package import DeveloperPackage # Cache the developer package loaded from cwd. This is so the package is only # loaded once, even though it's required once at arg parsing time (to determine # valid build system types), and once at command run time. # -_package = None +_package: DeveloperPackage | None = None -def get_current_developer_package(): +def get_current_developer_package() -> DeveloperPackage: from rez.packages import get_developer_package global _package diff --git a/src/rez/config.py b/src/rez/config.py index e8542d1ede..0f5f0956e3 100644 --- a/src/rez/config.py +++ b/src/rez/config.py @@ -23,6 +23,13 @@ import re import copy +from typing_extensions import Protocol + + +class Validatable(Protocol): + def validate(self, data): + pass + class _Deprecation(object): def __init__(self, removed_in, extra=None): @@ -54,7 +61,7 @@ class Setting(object): Note that lazy setting validation only happens on main configuration settings - plugin settings are validated on load only. """ - schema = Schema(object) + schema: Validatable = Schema(object) def __init__(self, config, key): self.config = config @@ -135,7 +142,7 @@ def _validate(self, data): class Str(Setting): - schema = Schema(str) + schema: Validatable = Schema(str) def _parse_env_var(self, value): return value @@ -153,7 +160,7 @@ class OptionalStr(Str): class StrList(Setting): - schema = Schema([str]) + schema: Validatable = Schema([str]) sep = ',' def _parse_env_var(self, value): @@ -184,8 +191,7 @@ def validate(self, data): class OptionalStrList(StrList): - schema = Or(And(None, Use(lambda x: [])), - [str]) + schema = Or(And(None, Use(lambda x: [])), [str]) class PathList(StrList): @@ -219,7 +225,7 @@ def _parse_env_var(self, value): class Bool(Setting): - schema = Schema(bool) + schema: Validatable = Schema(bool) true_words = frozenset(["1", "true", "t", "yes", "y", "on"]) false_words = frozenset(["0", "false", "f", "no", "n", "off"]) all_words = true_words | false_words @@ -255,7 +261,7 @@ def _parse_env_var(self, value): class Dict(Setting): - schema = Schema(dict) + schema: Validatable = Schema(dict) def _parse_env_var(self, value): items = value.split(",") diff --git a/src/rez/package_filter.py b/src/rez/package_filter.py index 70014e4009..c371cf2bfc 100644 --- a/src/rez/package_filter.py +++ b/src/rez/package_filter.py @@ -327,7 +327,7 @@ class Rule(object): """Base package filter rule""" #: Rule name - name = None + name: str def match(self, package): """Apply the rule to the package. diff --git a/src/rez/package_maker.py b/src/rez/package_maker.py index e547e0287f..80c20ffdba 100644 --- a/src/rez/package_maker.py +++ b/src/rez/package_maker.py @@ -92,7 +92,7 @@ class PackageMaker(AttrDictWrapper): """Utility class for creating packages.""" - def __init__(self, name, data=None, package_cls=None): + def __init__(self, name: str, data=None, package_cls=None): """Create a package maker. Args: @@ -106,7 +106,7 @@ def __init__(self, name, data=None, package_cls=None): self.installed_variants = [] self.skipped_variants = [] - def get_package(self): + def get_package(self) -> Package: """Create the analogous package. Returns: diff --git a/src/rez/package_order.py b/src/rez/package_order.py index e60662d77a..b93a402986 100644 --- a/src/rez/package_order.py +++ b/src/rez/package_order.py @@ -1,36 +1,46 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright Contributors to the Rez Project - +from __future__ import annotations from inspect import isclass from hashlib import sha1 -from typing import Dict, Iterable, List, Optional, Union +from typing import Any, Callable, Dict, Iterable, Self +from typing_extensions import Protocol from rez.config import config from rez.utils.data_utils import cached_class_property from rez.version import Version, VersionRange from rez.version._version import _Comparable, _ReversedComparable, _LowerBound, _UpperBound, _Bound -from rez.packages import iter_packages +from rez.packages import iter_packages, Package ALL_PACKAGES = "*" +class SupportsLessThan(Protocol): + def __lt__(self, __other: Any) -> bool: ... + + class FallbackComparable(_Comparable): """First tries to compare objects using the main_comparable, but if that fails, compares using the fallback_comparable object. """ - def __init__(self, main_comparable, fallback_comparable): + def __init__(self, main_comparable: SupportsLessThan, + fallback_comparable: SupportsLessThan): self.main_comparable = main_comparable self.fallback_comparable = fallback_comparable - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + if not isinstance(other, FallbackComparable): + return NotImplemented try: return self.main_comparable == other.main_comparable except Exception: return self.fallback_comparable == other.fallback_comparable - def __lt__(self, other): + def __lt__(self, other: object) -> bool: + if not isinstance(other, FallbackComparable): + return NotImplemented try: return self.main_comparable < other.main_comparable except Exception: @@ -44,9 +54,10 @@ class PackageOrder(object): """Package reorderer base class.""" #: Orderer name - name = None + name: str + _packages: list[str] - def __init__(self, packages: Optional[Iterable[str]] = None): + def __init__(self, packages: Iterable[str] | None = None): """ Args: packages: If not provided, PackageOrder applies to all packages. @@ -54,7 +65,7 @@ def __init__(self, packages: Optional[Iterable[str]] = None): self.packages = packages @property - def packages(self) -> List[str]: + def packages(self) -> list[str]: """Returns an iterable over the list of package family names that this order applies to @@ -64,7 +75,7 @@ def packages(self) -> List[str]: return self._packages @packages.setter - def packages(self, packages: Union[str, Iterable[str]]): + def packages(self, packages: str | Iterable[str] | None): if packages is None: # Apply to all packages self._packages = [ALL_PACKAGES] @@ -73,7 +84,8 @@ def packages(self, packages: Union[str, Iterable[str]]): else: self._packages = sorted(packages) - def reorder(self, iterable, key=None): + def reorder(self, iterable: Iterable[Package], + key: Callable[[Any], Package] | None = None) -> list[Package] | None: """Put packages into some order for consumption. You can safely assume that the packages referred to by `iterable` are @@ -101,7 +113,9 @@ def reorder(self, iterable, key=None): reverse=True) @staticmethod - def _get_package_name_from_iterable(iterable, key=None): + def _get_package_name_from_iterable(iterable: Iterable[Package], + key: Callable[[Any], Package] | None = None + ) -> str | None: """Utility method for getting a package from an iterable""" try: item = next(iter(iterable)) @@ -111,7 +125,7 @@ def _get_package_name_from_iterable(iterable, key=None): key = key or (lambda x: x) return key(item).name - def sort_key(self, package_name, version_like): + def sort_key(self, package_name: str, version_like) -> SupportsLessThan: """Returns a sort key usable for sorting packages within the same family Args: @@ -148,7 +162,7 @@ def sort_key(self, package_name, version_like): return 0 raise TypeError(version_like) - def sort_key_implementation(self, package_name, version): + def sort_key_implementation(self, package_name: str, version: Version) -> SupportsLessThan: """Returns a sort key usable for sorting these packages within the same family Args: @@ -170,10 +184,10 @@ def from_pod(cls, data): raise NotImplementedError @property - def sha1(self): + def sha1(self) -> str: return sha1(repr(self).encode('utf-8')).hexdigest() - def __str__(self): + def __str__(self) -> str: raise NotImplementedError def __eq__(self, other): @@ -182,7 +196,7 @@ def __eq__(self, other): def __ne__(self, other): return not self == other - def __repr__(self): + def __repr__(self) -> str: return "%s(%s)" % (self.__class__.__name__, str(self)) @@ -195,12 +209,12 @@ class NullPackageOrder(PackageOrder): """ name = "no_order" - def sort_key_implementation(self, package_name, version): + def sort_key_implementation(self, package_name: str, version: Version) -> SupportsLessThan: # python's sort will preserve the order of items that compare equal, so # to not change anything, we just return the same object for all... return 0 - def __str__(self): + def __str__(self) -> str: return "{}" def __eq__(self, other): @@ -233,7 +247,7 @@ def __init__(self, descending, packages=None): super().__init__(packages) self.descending = descending - def sort_key_implementation(self, package_name, version): + def sort_key_implementation(self, package_name: str, version: Version) -> SupportsLessThan: # Note that the name "descending" can be slightly confusing - it # indicates that the final ordering this Order gives should be # version descending (ie, the default) - however, the sort_key itself @@ -246,7 +260,7 @@ def sort_key_implementation(self, package_name, version): else: return _ReversedComparable(version) - def __str__(self): + def __str__(self) -> str: return str(self.descending) def __eq__(self, other): @@ -283,7 +297,7 @@ class PerFamilyOrder(PackageOrder): """ name = "per_family" - def __init__(self, order_dict, default_order=None): + def __init__(self, order_dict: dict[str, PackageOrder], default_order=None): """Create a reorderer. Args: @@ -296,7 +310,8 @@ def __init__(self, order_dict, default_order=None): self.order_dict = order_dict.copy() self.default_order = default_order - def reorder(self, iterable, key=None): + def reorder(self, iterable: Iterable[Package], + key: Callable[[Any], Package] | None = None) -> list[Package] | None: package_name = self._get_package_name_from_iterable(iterable, key) if package_name is None: return None @@ -309,7 +324,7 @@ def reorder(self, iterable, key=None): return orderer.reorder(iterable, key) - def sort_key_implementation(self, package_name, version): + def sort_key_implementation(self, package_name: str, version: Version) -> SupportsLessThan: orderer = self.order_dict.get(package_name) if orderer is None: if self.default_order is None: @@ -322,7 +337,7 @@ def sort_key_implementation(self, package_name, version): return orderer.sort_key_implementation(package_name, version) - def __str__(self): + def __str__(self) -> str: items = sorted((x[0], str(x[1])) for x in self.order_dict.items()) return str((items, str(self.default_order))) @@ -402,7 +417,7 @@ class VersionSplitPackageOrder(PackageOrder): """ name = "version_split" - def __init__(self, first_version, packages=None): + def __init__(self, first_version: Version, packages=None): """Create a reorderer. Args: @@ -411,7 +426,7 @@ def __init__(self, first_version, packages=None): super().__init__(packages) self.first_version = first_version - def sort_key_implementation(self, package_name, version): + def sort_key_implementation(self, package_name: str, version: Version) -> SupportsLessThan: priority_key = 1 if version <= self.first_version else 0 return priority_key, version @@ -490,7 +505,7 @@ class TimestampPackageOrder(PackageOrder): """ name = "soft_timestamp" - def __init__(self, timestamp, rank=0, packages=None): + def __init__(self, timestamp: int, rank: int = 0, packages=None): """Create a reorderer. Args: @@ -569,7 +584,7 @@ def _calc_sort_key(self, package_name, version): return is_before, _ReversedComparable(version) - def sort_key_implementation(self, package_name, version): + def sort_key_implementation(self, package_name: str, version: Version) -> SupportsLessThan: cache_key = (package_name, str(version)) result = self._cached_sort_key.get(cache_key) if result is None: @@ -578,7 +593,7 @@ def sort_key_implementation(self, package_name, version): return result - def __str__(self): + def __str__(self) -> str: return str((self.timestamp, self.rank)) def __eq__(self, other): @@ -635,12 +650,12 @@ def from_pod(cls, data): return flist @cached_class_property - def singleton(cls): + def singleton(cls) -> Self: """Filter list as configured by rezconfig.package_filter.""" return cls.from_pod(config.package_orderers) @staticmethod - def _to_orderer(orderer: Union[dict, PackageOrder]) -> PackageOrder: + def _to_orderer(orderer: dict | PackageOrder) -> PackageOrder: if isinstance(orderer, dict): orderer = from_pod(orderer) return orderer @@ -681,7 +696,7 @@ def insert(self, *args, **kwargs): self.dirty = True return super().insert(*args, **kwargs) - def get(self, key: str, default: Optional[PackageOrder] = None) -> PackageOrder: + def get(self, key: str, default: PackageOrder | None = None) -> PackageOrder | None: """ Get an orderer that sorts a package by name. """ @@ -698,7 +713,7 @@ def to_pod(orderer): return data -def from_pod(data): +def from_pod(data) -> PackageOrder: if isinstance(data, dict): cls_name = data["type"] data = data.copy() diff --git a/src/rez/package_repository.py b/src/rez/package_repository.py index 2c7be5ab21..df450c023a 100644 --- a/src/rez/package_repository.py +++ b/src/rez/package_repository.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright Contributors to the Rez Project - +from __future__ import annotations from rez.utils.resources import ResourcePool, ResourceHandle from rez.utils.data_utils import cached_property @@ -11,6 +11,14 @@ import threading import os.path import time +from typing import Hashable, Iterator, TYPE_CHECKING + +if TYPE_CHECKING: + from rez.package_resources import PackageFamilyResource, PackageRepositoryResource, PackageResource, VariantResource + from rez.packages import Package + from rez.utils.resources import Resource + from rez.version import Version + from rezplugins.package_repository.memory import MemoryPackageRepository def get_package_repository_types(): @@ -18,7 +26,7 @@ def get_package_repository_types(): return plugin_manager.get_plugins('package_repository') -def create_memory_package_repository(repository_data): +def create_memory_package_repository(repository_data: dict) -> PackageRepository: """Create a standalone in-memory package repository from the data given. See rezplugins/package_repository/memory.py for more details. @@ -29,7 +37,7 @@ def create_memory_package_repository(repository_data): Returns: `PackageRepository` object. """ - cls_ = plugin_manager.get_plugin_class("package_repository", "memory") + cls_: type[MemoryPackageRepository] = plugin_manager.get_plugin_class("package_repository", "memory") return cls_.create_repository(repository_data) @@ -69,11 +77,11 @@ class PackageRepository(object): remove = object() @classmethod - def name(cls): + def name(cls) -> str: """Return the name of the package repository type.""" raise NotImplementedError - def __init__(self, location, resource_pool): + def __init__(self, location: str, resource_pool: ResourcePool): """Create a package repository. Args: @@ -85,10 +93,10 @@ def __init__(self, location, resource_pool): self.location = location self.pool = resource_pool - def __str__(self): + def __str__(self) -> str: return "%s@%s" % (self.name(), self.location) - def register_resource(self, resource_class): + def register_resource(self, resource_class: type[Resource]) -> None: """Register a resource with the repository. Your derived repository class should call this method in its __init__ to @@ -96,12 +104,12 @@ def register_resource(self, resource_class): """ self.pool.register_resource(resource_class) - def clear_caches(self): + def clear_caches(self) -> None: """Clear any cached resources in the pool.""" self.pool.clear_caches() @cached_property - def uid(self): + def uid(self) -> tuple[str, str]: """Returns a unique identifier for this repository. This must be a persistent identifier, for example a filepath, or @@ -119,7 +127,7 @@ def __eq__(self, other): and other.uid == self.uid ) - def is_empty(self): + def is_empty(self) -> bool: """Determine if the repository contains any packages. Returns: @@ -131,7 +139,7 @@ def is_empty(self): return True - def get_package_family(self, name): + def get_package_family(self, name) -> PackageFamilyResource | None: """Get a package family. Args: @@ -142,7 +150,7 @@ def get_package_family(self, name): """ raise NotImplementedError - def iter_package_families(self): + def iter_package_families(self) -> Iterator[PackageFamilyResource]: """Iterate over the package families in the repository, in no particular order. @@ -151,7 +159,7 @@ def iter_package_families(self): """ raise NotImplementedError - def iter_packages(self, package_family_resource): + def iter_packages(self, package_family_resource) -> Iterator[PackageResource]: """Iterate over the packages within the given family, in no particular order. @@ -163,7 +171,7 @@ def iter_packages(self, package_family_resource): """ raise NotImplementedError - def iter_variants(self, package_resource): + def iter_variants(self, package_resource: PackageResource) -> Iterator[VariantResource]: """Iterate over the variants within the given package. Args: @@ -174,7 +182,7 @@ def iter_variants(self, package_resource): """ raise NotImplementedError - def get_package(self, name, version): + def get_package(self, name: str, version: Version) -> Package | None: """Get a package. Args: @@ -182,7 +190,7 @@ def get_package(self, name, version): version (`Version`): Package version. Returns: - `PackageResource` or None: Matching package, or None if not found. + `Package` or None: Matching package, or None if not found. """ fam = self.get_package_family(name) if fam is None: @@ -194,7 +202,7 @@ def get_package(self, name, version): return None - def get_package_from_uri(self, uri): + def get_package_from_uri(self, uri: str) -> PackageResource | None: """Get a package given its URI. Args: @@ -206,7 +214,7 @@ def get_package_from_uri(self, uri): """ return None - def get_variant_from_uri(self, uri): + def get_variant_from_uri(self, uri: str) -> VariantResource | None: """Get a variant given its URI. Args: @@ -218,7 +226,7 @@ def get_variant_from_uri(self, uri): """ return None - def ignore_package(self, pkg_name, pkg_version, allow_missing=False): + def ignore_package(self, pkg_name: str, pkg_version: Version, allow_missing=False) -> int: """Ignore the given package. Ignoring a package makes it invisible to further resolves. @@ -239,7 +247,7 @@ def ignore_package(self, pkg_name, pkg_version, allow_missing=False): """ raise NotImplementedError - def unignore_package(self, pkg_name, pkg_version): + def unignore_package(self, pkg_name: str, pkg_version: Version) -> int: """Unignore the given package. Args: @@ -254,7 +262,7 @@ def unignore_package(self, pkg_name, pkg_version): """ raise NotImplementedError - def remove_package(self, pkg_name, pkg_version): + def remove_package(self, pkg_name: str, pkg_version: Version) -> bool: """Remove a package. Note that this should work even if the specified package is currently @@ -269,7 +277,7 @@ def remove_package(self, pkg_name, pkg_version): """ raise NotImplementedError - def remove_package_family(self, pkg_name, force=False): + def remove_package_family(self, pkg_name: str, force: bool = False) -> bool: """Remove an empty package family. Args: @@ -281,7 +289,8 @@ def remove_package_family(self, pkg_name, force=False): """ raise NotImplementedError - def remove_ignored_since(self, days, dry_run=False, verbose=False): + def remove_ignored_since(self, days: int, dry_run: bool = False, + verbose: bool = False) -> int: """Remove packages ignored for >= specified number of days. Args: @@ -295,7 +304,7 @@ def remove_ignored_since(self, days, dry_run=False, verbose=False): """ raise NotImplementedError - def pre_variant_install(self, variant_resource): + def pre_variant_install(self, variant_resource: VariantResource): """Called before a variant is installed. If any directories are created on disk for the variant to install into, @@ -306,7 +315,7 @@ def pre_variant_install(self, variant_resource): """ pass - def on_variant_install_cancelled(self, variant_resource): + def on_variant_install_cancelled(self, variant_resource: VariantResource): """Called when a variant installation is cancelled. This is called after `pre_variant_install`, but before `install_variant`, @@ -321,7 +330,8 @@ def on_variant_install_cancelled(self, variant_resource): """ pass - def install_variant(self, variant_resource, dry_run=False, overrides=None): + def install_variant(self, variant_resource: VariantResource, + dry_run=False, overrides=None) -> VariantResource: """Install a variant into this repository. Use this function to install a variant from some other package repository @@ -343,7 +353,7 @@ def install_variant(self, variant_resource, dry_run=False, overrides=None): """ raise NotImplementedError - def get_equivalent_variant(self, variant_resource): + def get_equivalent_variant(self, variant_resource: VariantResource) -> VariantResource: """Find a variant in this repository that is equivalent to that given. A variant is equivalent to another if it belongs to a package of the @@ -362,7 +372,8 @@ def get_equivalent_variant(self, variant_resource): """ return self.install_variant(variant_resource, dry_run=True) - def get_parent_package_family(self, package_resource): + def get_parent_package_family(self, package_resource: VariantResource + ) -> PackageFamilyResource: """Get the parent package family of the given package. Args: @@ -373,7 +384,8 @@ def get_parent_package_family(self, package_resource): """ raise NotImplementedError - def get_parent_package(self, variant_resource): + def get_parent_package(self, variant_resource: PackageFamilyResource + ) -> PackageResource: """Get the parent package of the given variant. Args: @@ -384,7 +396,8 @@ def get_parent_package(self, variant_resource): """ raise NotImplementedError - def get_variant_state_handle(self, variant_resource): + def get_variant_state_handle(self, variant_resource: PackageResource + ) -> Hashable | None: """Get a value that indicates the state of the variant. This is used for resolve caching. For example, in the 'filesystem' @@ -400,7 +413,8 @@ def get_variant_state_handle(self, variant_resource): """ return None - def get_last_release_time(self, package_family_resource): + def get_last_release_time(self, package_family_resource: PackageFamilyResource + ) -> int: """Get the last time a package was added to the given family. This information is used to cache resolves via memcached. It can be left @@ -414,7 +428,7 @@ def get_last_release_time(self, package_family_resource): """ return 0 - def make_resource_handle(self, resource_key, **variables): + def make_resource_handle(self, resource_key, **variables) -> ResourceHandle: """Create a `ResourceHandle` Nearly all `ResourceHandle` creation should go through here, because it @@ -438,7 +452,7 @@ def make_resource_handle(self, resource_key, **variables): variables = resource_cls.normalize_variables(variables) return ResourceHandle(resource_key, variables) - def get_resource(self, resource_key, **variables): + def get_resource(self, resource_key, **variables) -> PackageRepositoryResource: """Get a resource. Attempts to get and return a cached version of the resource if @@ -454,7 +468,9 @@ def get_resource(self, resource_key, **variables): handle = self.make_resource_handle(resource_key, **variables) return self.get_resource_from_handle(handle, verify_repo=False) - def get_resource_from_handle(self, resource_handle, verify_repo=True): + def get_resource_from_handle(self, resource_handle: ResourceHandle, + verify_repo: bool = True + ) -> PackageRepositoryResource: """Get a resource. Args: @@ -484,7 +500,7 @@ def get_resource_from_handle(self, resource_handle, verify_repo=True): resource._repository = self return resource - def get_package_payload_path(self, package_name, package_version=None): + def get_package_payload_path(self, package_name: str, package_version=None) -> str: """Defines where a package's payload should be installed to. Args: @@ -496,7 +512,7 @@ def get_package_payload_path(self, package_name, package_version=None): """ raise NotImplementedError - def _uid(self): + def _uid(self) -> tuple[str, str]: """Unique identifier implementation. You may need to provide your own implementation. For example, consider @@ -517,7 +533,7 @@ class PackageRepositoryManager(object): Manages retrieval of resources (packages and variants) from `PackageRepository` instances, and caches these resources in a resource pool. """ - def __init__(self, resource_pool=None): + def __init__(self, resource_pool: ResourcePool | None = None): """Create a package repo manager. Args: @@ -532,9 +548,9 @@ def __init__(self, resource_pool=None): resource_pool = ResourcePool(cache_size=cache_size) self.pool = resource_pool - self.repositories = {} + self.repositories: dict[str, PackageRepository] = {} - def get_repository(self, path): + def get_repository(self, path: str) -> PackageRepository: """Get a package repository. Args: @@ -550,9 +566,10 @@ def get_repository(self, path): # normalise repo path parts = path.split('@', 1) if len(parts) == 1: - parts = ("filesystem", parts[0]) + repo_type, location = ("filesystem", parts[0]) + else: + repo_type, location = parts - repo_type, location = parts if repo_type == "filesystem": # choice of abspath here vs realpath is deliberate. Realpath gives # canonical path, which can be a problem if two studios are sharing @@ -573,7 +590,7 @@ def get_repository(self, path): return repository - def are_same(self, path_1, path_2): + def are_same(self, path_1, path_2) -> bool: """Test that `path_1` and `path_2` refer to the same repository. This is more reliable than testing that the strings match, since slightly @@ -590,8 +607,8 @@ def are_same(self, path_1, path_2): repo_2 = self.get_repository(path_2) return (repo_1.uid == repo_2.uid) - def get_resource(self, resource_key, repository_type, location, - **variables): + def get_resource(self, resource_key: str, repository_type: str, + location: str, **variables) -> PackageRepositoryResource: """Get a resource. Attempts to get and return a cached version of the resource if @@ -612,7 +629,8 @@ def get_resource(self, resource_key, repository_type, location, resource = repo.get_resource(**variables) return resource - def get_resource_from_handle(self, resource_handle): + def get_resource_from_handle(self, resource_handle: ResourceHandle + ) -> PackageRepositoryResource: """Get a resource. Args: @@ -632,12 +650,12 @@ def get_resource_from_handle(self, resource_handle): resource = repo.get_resource_from_handle(resource_handle) return resource - def clear_caches(self): + def clear_caches(self) -> None: """Clear all cached data.""" self.repositories.clear() self.pool.clear_caches() - def _get_repository(self, path, **repo_args): + def _get_repository(self, path: str, **repo_args) -> PackageRepository: repo_type, location = path.split('@', 1) cls = plugin_manager.get_plugin_class('package_repository', repo_type) repo = cls(location, self.pool, **repo_args) diff --git a/src/rez/package_resources.py b/src/rez/package_resources.py index f33c4270ad..8a0010f6c0 100644 --- a/src/rez/package_resources.py +++ b/src/rez/package_resources.py @@ -12,12 +12,17 @@ from rez.utils.formatting import PackageRequest from rez.exceptions import PackageMetadataError, ResourceError from rez.config import config, Config, create_config -from rez.version import Version +from rez.version import Requirement, Version from rez.vendor.schema.schema import Schema, SchemaError, Optional, Or, And, Use from textwrap import dedent import os.path +from abc import abstractmethod from hashlib import sha1 +from typing import Any, Iterable, Iterator, TYPE_CHECKING + +if TYPE_CHECKING: + from rez.packages import Package, Variant # package attributes created at release time @@ -76,7 +81,7 @@ def late_bound(schema): # requirements of all package-related resources # -base_resource_schema_dict = { +base_resource_schema_dict: dict[Schema, Any] = { Required("name"): str } @@ -269,7 +274,7 @@ class PackageRepositoryResource(Resource): """ schema_error = PackageMetadataError #: Type of package repository associated with this resource type. - repository_type = None + repository_type: str @classmethod def normalize_variables(cls, variables): @@ -284,18 +289,18 @@ def __init__(self, variables=None): super(PackageRepositoryResource, self).__init__(variables) @cached_property - def uri(self): + def uri(self) -> str: return self._uri() @property - def location(self): + def location(self) -> str | None: return self.get("location") @property - def name(self): + def name(self) -> str | None: return self.get("name") - def _uri(self): + def _uri(self) -> str: """Return a URI. Implement this function to return a short, readable string that @@ -310,7 +315,9 @@ class PackageFamilyResource(PackageRepositoryResource): A repository implementation's package family resource(s) must derive from this class. It must satisfy the schema `package_family_schema`. """ - pass + + def iter_packages(self) -> Iterator[Package]: + raise NotImplementedError class PackageResource(PackageRepositoryResource): @@ -330,7 +337,7 @@ def normalize_variables(cls, variables): return super(PackageResource, cls).normalize_variables(variables) @cached_property - def version(self): + def version(self) -> Version: ver_str = self.get("version", "") return Version(ver_str) @@ -345,17 +352,23 @@ class VariantResource(PackageResource): this case it is the 'None' variant (the value of `index` is None). This provides some internal consistency and simplifies the implementation. """ + + @property + @abstractmethod + def parent(self) -> PackageRepositoryResource: + raise NotImplementedError + @property def index(self): return self.get("index", None) @cached_property - def root(self): + def root(self) -> str: """Return the 'root' path of the variant.""" return self._root() @cached_property - def subpath(self): + def subpath(self) -> str: """Return the variant's 'subpath' The subpath is the relative path the variant's payload should be stored @@ -383,23 +396,28 @@ class PackageResourceHelper(PackageResource): """ variant_key = None + @property + @abstractmethod + def parent(self) -> PackageRepositoryResource: + raise NotImplementedError + @cached_property - def commands(self): + def commands(self) -> SourceCode: return self._convert_to_rex(self._commands) @cached_property - def pre_commands(self): + def pre_commands(self) -> SourceCode: return self._convert_to_rex(self._pre_commands) @cached_property - def post_commands(self): + def post_commands(self) -> SourceCode: return self._convert_to_rex(self._post_commands) - def iter_variants(self): + def iter_variants(self) -> Iterator[Variant]: num_variants = len(self.variants or []) if num_variants == 0: - indexes = [None] + indexes: Iterable[int | None] = [None] else: indexes = range(num_variants) @@ -412,7 +430,7 @@ def iter_variants(self): index=index) yield variant - def _convert_to_rex(self, commands): + def _convert_to_rex(self, commands) -> SourceCode: if isinstance(commands, list): from rez.utils.backcompat import convert_old_commands @@ -453,12 +471,12 @@ class VariantResourceHelper(VariantResource, metaclass=_Metas): # forward Package attributes onto ourself keys = schema_keys(package_schema) - set(["variants"]) - def _uri(self): + def _uri(self) -> str: index = self.index idxstr = '' if index is None else str(index) return "%s[%s]" % (self.parent.uri, idxstr) - def _subpath(self, ignore_shortlinks=False): + def _subpath(self, ignore_shortlinks=False) -> str | None: if self.index is None: return None @@ -488,7 +506,7 @@ def _subpath(self, ignore_shortlinks=False): subpath = os.path.join(*dirs) return subpath - def _root(self, ignore_shortlinks=False): + def _root(self, ignore_shortlinks: bool = False) -> str | None: if self.base is None: return None elif self.index is None: @@ -499,7 +517,7 @@ def _root(self, ignore_shortlinks=False): return root @cached_property - def variant_requires(self): + def variant_requires(self) -> list[Requirement]: index = self.index if index is None: return [] diff --git a/src/rez/packages.py b/src/rez/packages.py index dc816303bf..4a77fc316b 100644 --- a/src/rez/packages.py +++ b/src/rez/packages.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright Contributors to the Rez Project - +from __future__ import annotations from rez.package_repository import package_repository_manager from rez.package_resources import PackageFamilyResource, PackageResource, \ @@ -21,6 +21,13 @@ import os import sys +from typing import Iterator, TYPE_CHECKING + +if TYPE_CHECKING: + from rez.developer_package import DeveloperPackage + from rez.version import Requirement + from rez.package_repository import PackageRepository + # ------------------------------------------------------------------------------ # package-related classes @@ -36,7 +43,7 @@ def validated_data(self): return data @property - def repository(self): + def repository(self) -> PackageRepository: """The package repository this resource comes from. Returns: @@ -58,7 +65,7 @@ def __init__(self, resource): _check_class(resource, PackageFamilyResource) super(PackageFamily, self).__init__(resource) - def iter_packages(self): + def iter_packages(self) -> Iterator[Package]: """Iterate over the packages within this family, in no particular order. Returns: @@ -102,7 +109,7 @@ def config(self): return self.resource.config or config @cached_property - def is_local(self): + def is_local(self) -> bool: """Returns True if the package is in the local package repository""" local_repo = package_repository_manager.get_repository( self.config.local_packages_path) @@ -223,7 +230,7 @@ def arbitrary_keys(self): return set(self.data.keys()) - set(self.keys) @cached_property - def qualified_name(self): + def qualified_name(self) -> str: """Get the qualified name of the package. Returns: @@ -232,7 +239,7 @@ def qualified_name(self): o = VersionedObject.construct(self.name, self.version) return str(o) - def as_exact_requirement(self): + def as_exact_requirement(self) -> str: """Get the package, as an exact requirement string. Returns: @@ -242,7 +249,7 @@ def as_exact_requirement(self): return o.as_exact_requirement() @cached_property - def parent(self): + def parent(self) -> PackageFamily | None: """Get the parent package family. Returns: @@ -252,11 +259,11 @@ def parent(self): return PackageFamily(family) if family else None @cached_property - def num_variants(self): + def num_variants(self) -> int: return len(self.data.get("variants", [])) @property - def is_relocatable(self): + def is_relocatable(self) -> bool: """True if the package and its payload is safe to copy. """ if self.relocatable is not None: @@ -276,7 +283,7 @@ def is_relocatable(self): return config.default_relocatable @property - def is_cachable(self): + def is_cachable(self) -> bool: """True if the package and its payload is safe to cache locally. """ if self.cachable is not None: @@ -301,7 +308,7 @@ def is_cachable(self): return self.is_relocatable - def iter_variants(self): + def iter_variants(self) -> Iterator[Variant]: """Iterate over the variants within this package, in index order. Returns: @@ -310,7 +317,7 @@ def iter_variants(self): for variant in self.repository.iter_variants(self.resource): yield Variant(variant, context=self.context, parent=self) - def get_variant(self, index=None): + def get_variant(self, index=None) -> Variant | None: """Get the variant with the associated index. Returns: @@ -319,6 +326,7 @@ def get_variant(self, index=None): for variant in self.iter_variants(): if variant.index == index: return variant + return None class Variant(PackageBaseResourceWrapper): @@ -353,12 +361,12 @@ def arbitrary_keys(self): return self.parent.arbitrary_keys() @cached_property - def qualified_package_name(self): + def qualified_package_name(self) -> str: o = VersionedObject.construct(self.name, self.version) return str(o) @cached_property - def qualified_name(self): + def qualified_name(self) -> str: """Get the qualified name of the variant. Returns: @@ -368,7 +376,7 @@ def qualified_name(self): return "%s[%s]" % (self.qualified_package_name, idxstr) @cached_property - def parent(self): + def parent(self) -> Package: """Get the parent package. Returns: @@ -386,7 +394,7 @@ def parent(self): return self._parent @property - def variant_requires(self): + def variant_requires(self) -> list[Requirement]: """Get the subset of requirements specific to this variant. Returns: @@ -398,7 +406,7 @@ def variant_requires(self): return self.parent.variants[self.index] or [] @property - def requires(self): + def requires(self) -> list[Requirement]: """Get variant requirements. This is a concatenation of the package requirements and those of this @@ -411,7 +419,8 @@ def requires(self): (self.parent.requires or []) + self.variant_requires ) - def get_requires(self, build_requires=False, private_build_requires=False): + def get_requires(self, build_requires=False, private_build_requires=False + ) -> list[Requirement]: """Get the requirements of the variant. Args: @@ -431,7 +440,7 @@ def get_requires(self, build_requires=False, private_build_requires=False): return requires - def install(self, path, dry_run=False, overrides=None): + def install(self, path, dry_run=False, overrides=None) -> Variant: """Install this variant into another package repository. If the package already exists, this variant will be correctly merged @@ -518,7 +527,7 @@ def _repository_uids(self): # resource acquisition functions # ------------------------------------------------------------------------------ -def iter_package_families(paths=None): +def iter_package_families(paths: list[str] | None = None): """Iterate over package families, in no particular order. Note that multiple package families with the same name can be returned. @@ -538,7 +547,8 @@ def iter_package_families(paths=None): yield PackageFamily(resource) -def iter_packages(name, range_=None, paths=None): +def iter_packages(name: str, range_: VersionRange | str | None = None, + paths: list[str] | None = None) -> Iterator[Package]: """Iterate over `Package` instances, in no particular order. Packages of the same name and version earlier in the search path take @@ -574,7 +584,7 @@ def iter_packages(name, range_=None, paths=None): yield Package(package_resource) -def get_package(name, version, paths=None): +def get_package(name: str, version: Version | str, paths: list[str] | None = None) -> Package | None: """Get a package by searching a list of repositories. Args: @@ -598,7 +608,7 @@ def get_package(name, version, paths=None): return None -def get_package_family_from_repository(name, path): +def get_package_family_from_repository(name: str, path: str): """Get a package family from a repository. Args: @@ -616,7 +626,7 @@ def get_package_family_from_repository(name, path): return PackageFamily(family_resource) -def get_package_from_repository(name, version, path): +def get_package_from_repository(name: str, version, path: str): """Get a package from a repository. Args: @@ -656,7 +666,7 @@ def get_package_from_handle(package_handle): return package -def get_package_from_string(txt, paths=None): +def get_package_from_string(txt: str, paths: list[str] | None = None): """Get a package given a string. Args: @@ -671,7 +681,7 @@ def get_package_from_string(txt, paths=None): return get_package(o.name, o.version, paths=paths) -def get_developer_package(path, format=None): +def get_developer_package(path: str, format: str | None = None) -> DeveloperPackage: """Create a developer package. Args: @@ -685,7 +695,7 @@ def get_developer_package(path, format=None): return DeveloperPackage.from_path(path, format=format) -def create_package(name, data, package_cls=None): +def create_package(name: str, data, package_cls=None): """Create a package given package data. Args: @@ -700,7 +710,7 @@ def create_package(name, data, package_cls=None): return maker.get_package() -def get_variant(variant_handle, context=None): +def get_variant(variant_handle: ResourceHandle | dict, context=None): """Create a variant given its handle (or serialized dict equivalent) Args: @@ -721,7 +731,7 @@ def get_variant(variant_handle, context=None): return variant -def get_package_from_uri(uri, paths=None): +def get_package_from_uri(uri: str, paths: list[str] | None = None) -> Package | None: """Get a package given its URI. Args: @@ -768,7 +778,7 @@ def _find_in_path(path): return _find_in_path(path) -def get_variant_from_uri(uri, paths=None): +def get_variant_from_uri(uri: str, paths: list[str] | None = None) -> Variant | None: """Get a variant given its URI. Args: @@ -822,7 +832,7 @@ def _find_in_path(path): return _find_in_path(path) -def get_last_release_time(name, paths=None): +def get_last_release_time(name: str, paths: list[str] | None = None) -> int: """Returns the most recent time this package was released. Note that releasing a variant into an already-released package is also @@ -848,7 +858,7 @@ def get_last_release_time(name, paths=None): return max_time -def get_completions(prefix, paths=None, family_only=False): +def get_completions(prefix: str, paths: list[str] | None = None, family_only=False): """Get autocompletion options given a prefix string. Example: @@ -904,7 +914,7 @@ def get_completions(prefix, paths=None, family_only=False): return words -def get_latest_package(name, range_=None, paths=None, error=False): +def get_latest_package(name: str, range_=None, paths: list[str] | None = None, error=False): """Get the latest package for a given package name. Args: @@ -928,7 +938,7 @@ def get_latest_package(name, range_=None, paths=None, error=False): return None -def get_latest_package_from_string(txt, paths=None, error=False): +def get_latest_package_from_string(txt: str, paths: list[str] | None = None, error=False): """Get the latest package found within the given request string. Args: @@ -949,7 +959,8 @@ def get_latest_package_from_string(txt, paths=None, error=False): error=error) -def _get_families(name, paths=None): +def _get_families(name: str, paths: list[str] | None = None + ) -> list[tuple[PackageRepository, PackageFamilyResource]]: entries = [] for path in (paths or config.packages_path): repo = package_repository_manager.get_repository(path) diff --git a/src/rez/plugin_managers.py b/src/rez/plugin_managers.py index 64b1d8f090..efabd40435 100644 --- a/src/rez/plugin_managers.py +++ b/src/rez/plugin_managers.py @@ -85,7 +85,7 @@ class RezPluginType(object): 'type_name' must correspond with one of the source directories found under the 'plugins' directory. """ - type_name = None + type_name: str def __init__(self): if self.type_name is None: diff --git a/src/rez/release_vcs.py b/src/rez/release_vcs.py index dc232c4028..d854e7729b 100644 --- a/src/rez/release_vcs.py +++ b/src/rez/release_vcs.py @@ -70,7 +70,7 @@ def create_release_vcs(path, vcs_name=None): class ReleaseVCS(object): """A version control system (VCS) used to release Rez packages. """ - def __init__(self, pkg_root, vcs_root=None): + def __init__(self, pkg_root: str, vcs_root=None): if vcs_root is None: result = self.find_vcs_root(pkg_root) if not result: @@ -92,7 +92,7 @@ def name(cls): raise NotImplementedError @classmethod - def find_executable(cls, name): + def find_executable(cls, name: str): exe = which(name) if not exe: raise ReleaseVCSError("Couldn't find executable '%s' for VCS '%s'" @@ -100,7 +100,7 @@ def find_executable(cls, name): return exe @classmethod - def is_valid_root(cls, path): + def is_valid_root(cls, path: str): """Return True if the given path is a valid root directory for this version control system. @@ -118,7 +118,7 @@ def search_parents_for_root(cls): raise NotImplementedError @classmethod - def find_vcs_root(cls, path): + def find_vcs_root(cls, path: str): """Try to find a version control root directory of this type for the given path. @@ -141,7 +141,7 @@ def validate_repostate(self): """Ensure that the VCS working copy is up-to-date.""" raise NotImplementedError - def get_current_revision(self): + def get_current_revision(self) -> object: """Get the current revision, this can be any type (str, dict etc) appropriate to your VCS implementation. @@ -152,7 +152,7 @@ def get_current_revision(self): """ raise NotImplementedError - def get_changelog(self, previous_revision=None, max_revisions=None): + def get_changelog(self, previous_revision=None, max_revisions=None) -> str: """Get the changelog text since the given revision. If previous_revision is not an ancestor (for example, the last release @@ -169,7 +169,7 @@ def get_changelog(self, previous_revision=None, max_revisions=None): """ raise NotImplementedError - def tag_exists(self, tag_name): + def tag_exists(self, tag_name: str) -> bool: """Test if a tag exists in the repo. Args: @@ -180,7 +180,7 @@ def tag_exists(self, tag_name): """ raise NotImplementedError - def create_release_tag(self, tag_name, message=None): + def create_release_tag(self, tag_name: str, message=None): """Create a tag in the repo. Create a tag in the repository representing the release of the @@ -193,7 +193,7 @@ def create_release_tag(self, tag_name, message=None): raise NotImplementedError @classmethod - def export(cls, revision, path): + def export(cls, revision, path: str): """Export the repository to the given path at the given revision. Note: diff --git a/src/rez/resolved_context.py b/src/rez/resolved_context.py index 9ef3ae09db..68360b9500 100644 --- a/src/rez/resolved_context.py +++ b/src/rez/resolved_context.py @@ -123,6 +123,17 @@ def get_lock_request(name, version, patch_lock, weak=True): return PackageRequest(s) +def _on_success(fn): + @wraps(fn) + def _check(self, *nargs, **kwargs): + if self.status_ == ResolverStatus.solved: + return fn(self, *nargs, **kwargs) + else: + raise ResolvedContextError( + "Cannot perform operation in a failed context") + return _check + + class ResolvedContext(object): """A class that resolves, stores and spawns Rez environments. @@ -1038,16 +1049,6 @@ def print_resolve_diff(self, other, heading=None): print('\n'.join(columnise(rows))) - def _on_success(fn): - @wraps(fn) - def _check(self, *nargs, **kwargs): - if self.status_ == ResolverStatus.solved: - return fn(self, *nargs, **kwargs) - else: - raise ResolvedContextError( - "Cannot perform operation in a failed context") - return _check - @_on_success def get_dependency_graph(self, as_dot=False): """Generate the dependency graph. diff --git a/src/rez/resolver.py b/src/rez/resolver.py index 921871ce61..1d1a4fe610 100644 --- a/src/rez/resolver.py +++ b/src/rez/resolver.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright Contributors to the Rez Project +from __future__ import annotations - -from rez.solver import Solver, SolverStatus +from rez.solver import PackageVariant, Solver, SolverStatus from rez.package_repository import package_repository_manager from rez.packages import get_variant, get_last_release_time from rez.package_filter import PackageFilterList, TimestampRule @@ -109,7 +109,7 @@ def __init__(self, context, package_requests, package_paths, package_filter=None self._print = config.debug_printer("resolve_memcache") @pool_memcached_connections - def solve(self): + def solve(self) -> None: """Perform the solve. """ with log_duration(self._print, "memcache get (resolve) took %s"): @@ -137,7 +137,7 @@ def status(self): return self.status_ @property - def resolved_packages(self): + def resolved_packages(self) -> list[PackageVariant]: """Get the list of resolved packages. Returns: @@ -147,7 +147,7 @@ def resolved_packages(self): return self.resolved_packages_ @property - def resolved_ephemerals(self): + def resolved_ephemerals(self) -> list[Requirement]: """Get the list of resolved ewphemerals. Returns: @@ -391,7 +391,7 @@ def _memcache_key(self, timestamped=False): return str(tuple(t)) - def _solve(self): + def _solve(self) -> Solver: solver = Solver(package_requests=self.package_requests, package_paths=self.package_paths, context=self.context, diff --git a/src/rez/utils/__init__.py b/src/rez/utils/__init__.py index 58a0da7ecb..9d0a4ee802 100644 --- a/src/rez/utils/__init__.py +++ b/src/rez/utils/__init__.py @@ -4,18 +4,18 @@ import sys from contextlib import contextmanager - +from typing import NoReturn @contextmanager def with_noop(): yield -def reraise(exc, new_exc_cls): +def reraise(exc, new_exc_cls) -> NoReturn: traceback = sys.exc_info()[2] # TODO test this. - def reraise_(tp, value, tb=None): + def reraise_(tp, value, tb=None) -> NoReturn: try: if value is None: value = tp() diff --git a/src/rez/utils/data_utils.py b/src/rez/utils/data_utils.py index e5f4c009e9..5a42e32612 100644 --- a/src/rez/utils/data_utils.py +++ b/src/rez/utils/data_utils.py @@ -12,6 +12,7 @@ from rez.vendor.schema.schema import Schema, Optional from threading import Lock +from typing import TYPE_CHECKING class ModifyList(object): @@ -213,50 +214,53 @@ def get_dict_diff_str(d1, d2, title): return '\n'.join(lines) -class cached_property(object): - """Simple property caching descriptor. - - Example: - - >>> class Foo(object): - >>> @cached_property - >>> def bah(self): - >>> print('bah') - >>> return 1 - >>> - >>> f = Foo() - >>> f.bah - bah - 1 - >>> f.bah - 1 - """ - def __init__(self, func, name=None): - self.func = func - # Make sure that Sphinx autodoc can follow and get the docstring from our wrapped function. - functools.update_wrapper(self, func) - self.name = name or func.__name__ - - def __get__(self, instance, owner=None): - if instance is None: - return self - - result = self.func(instance) - try: - setattr(instance, self.name, result) - except AttributeError: - raise AttributeError("can't set attribute %r on %r" - % (self.name, instance)) - return result - - # This is to silence Sphinx that complains that cached_property is not a callable. - def __call__(self): - raise RuntimeError("@cached_property should not be called.") - - @classmethod - def uncache(cls, instance, name): - if hasattr(instance, name): - delattr(instance, name) +if TYPE_CHECKING: + cached_property = property +else: + class cached_property(object): + """Simple property caching descriptor. + + Example: + + >>> class Foo(object): + >>> @cached_property + >>> def bah(self): + >>> print('bah') + >>> return 1 + >>> + >>> f = Foo() + >>> f.bah + bah + 1 + >>> f.bah + 1 + """ + def __init__(self, func, name=None): + self.func = func + # Make sure that Sphinx autodoc can follow and get the docstring from our wrapped function. + functools.update_wrapper(self, func) + self.name = name or func.__name__ + + def __get__(self, instance, owner=None): + if instance is None: + return self + + result = self.func(instance) + try: + setattr(instance, self.name, result) + except AttributeError: + raise AttributeError("can't set attribute %r on %r" + % (self.name, instance)) + return result + + # This is to silence Sphinx that complains that cached_property is not a callable. + def __call__(self): + raise RuntimeError("@cached_property should not be called.") + + @classmethod + def uncache(cls, instance, name): + if hasattr(instance, name): + delattr(instance, name) class cached_class_property(object): diff --git a/src/rez/utils/platform_.py b/src/rez/utils/platform_.py index f0901a1caf..0ba51315e1 100644 --- a/src/rez/utils/platform_.py +++ b/src/rez/utils/platform_.py @@ -18,7 +18,7 @@ class Platform(object): """Abstraction of a platform. """ - name = None + name: str def __init__(self): pass diff --git a/src/rez/utils/resources.py b/src/rez/utils/resources.py index 372d1455f0..3f51d967be 100644 --- a/src/rez/utils/resources.py +++ b/src/rez/utils/resources.py @@ -33,6 +33,8 @@ See the 'pets' unit test in tests/test_resources.py for a complete example. """ +from __future__ import annotations + from functools import lru_cache from rez.utils.data_utils import cached_property, AttributeForwardMeta, \ @@ -41,6 +43,10 @@ from rez.exceptions import ResourceError from rez.utils.logging_ import print_debug +from typing import Self, TYPE_CHECKING +if TYPE_CHECKING: + from rez.vendor.schema.schema import Schema + class Resource(object, metaclass=LazyAttributeMeta): """Abstract base class for a data resource. @@ -69,11 +75,11 @@ class Resource(object, metaclass=LazyAttributeMeta): `validated_data` function, and test full validation using `validate_data`. """ #: Unique identifier of the resource type. - key = None + key: str = None #: Schema for the resource data. #: Must validate a dict. Can be None, in which case the resource does #: not load any data. - schema = None + schema: Schema | None = None #: The exception type to raise on key validation failure. schema_error = Exception @@ -87,7 +93,7 @@ def __init__(self, variables=None): self.variables = self.normalize_variables(variables or {}) @cached_property - def handle(self): + def handle(self) -> ResourceHandle: """Get the resource handle.""" return ResourceHandle(self.key, self.variables) @@ -105,10 +111,10 @@ def get(self, key, default=None): """Get the value of a resource variable.""" return self.variables.get(key, default) - def __str__(self): + def __str__(self) -> str: return "%s%r" % (self.key, self.variables) - def __repr__(self): + def __repr__(self) -> str: return "%s(%r)" % (self.__class__.__name__, self.variables) def __hash__(self): @@ -139,7 +145,7 @@ class ResourceHandle(object): A handle uniquely identifies a resource. A handle can be stored and used with a `ResourcePool` to retrieve the same resource at a later date. """ - def __init__(self, key, variables=None): + def __init__(self, key: str, variables=None): self.key = key self.variables = variables or {} @@ -154,7 +160,7 @@ def to_dict(self): return dict(key=self.key, variables=self.variables) @classmethod - def from_dict(cls, d): + def from_dict(cls, d) -> Self: """Return a `ResourceHandle` instance from a serialized dict This should ONLY be used with dicts created with ResourceHandle.to_dict; @@ -169,10 +175,10 @@ def _hashable_repr(self): tuple(sorted(self.variables.items())) ) - def __str__(self): + def __str__(self) -> str: return str(self.to_dict()) - def __repr__(self): + def __repr__(self) -> str: return "%s(%r, %r)" % (self.__class__.__name__, self.key, self.variables) def __eq__(self, other): @@ -198,7 +204,7 @@ def __init__(self, cache_size=None): cache = lru_cache(maxsize=cache_size) self.cached_get_resource = cache(self._get_resource) - def register_resource(self, resource_class): + def register_resource(self, resource_class: type[Resource]) -> None: resource_key = resource_class.key assert issubclass(resource_class, Resource) assert resource_key is not None @@ -216,13 +222,13 @@ def register_resource(self, resource_class): self.resource_classes[resource_key] = resource_class - def get_resource_from_handle(self, resource_handle): + def get_resource_from_handle(self, resource_handle: ResourceHandle): return self.cached_get_resource(resource_handle) - def clear_caches(self): + def clear_caches(self) -> None: self.cached_get_resource.cache_clear() - def get_resource_class(self, resource_key): + def get_resource_class(self, resource_key) -> type[Resource]: resource_class = self.resource_classes.get(resource_key) if resource_class is None: raise ResourceError("Error getting resource from pool: Unknown " diff --git a/src/rez/version/_requirement.py b/src/rez/version/_requirement.py index 9e72a1133c..cfe85c2a1f 100644 --- a/src/rez/version/_requirement.py +++ b/src/rez/version/_requirement.py @@ -65,7 +65,7 @@ def name(self): return self.name_ @property - def version(self): + def version(self) -> Version: """Version of the object. Returns: @@ -73,7 +73,7 @@ def version(self): """ return self.version_ - def as_exact_requirement(self): + def as_exact_requirement(self) -> str: """Get the versioned object, as an exact requirement string. Returns: diff --git a/src/rez/version/_version.py b/src/rez/version/_version.py index 42c2caa4a2..7c3da5b02a 100644 --- a/src/rez/version/_version.py +++ b/src/rez/version/_version.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright Contributors to the Rez Project - +from __future__ import annotations from rez.version._util import VersionError, ParseException, _Common, \ dedup @@ -8,7 +8,9 @@ import copy import string import re +from typing import cast, Callable, Generic, Iterable, TypeVar +T = TypeVar("T") re_token = re.compile(r"[a-zA-Z0-9_]+") @@ -272,16 +274,16 @@ class Version(_Comparable): The empty version ``''`` is the smallest possible version, and can be used to represent an unversioned resource. """ - inf = None + inf: Version - def __init__(self, ver_str='', make_token=AlphanumericVersionToken): + def __init__(self, ver_str: str | None = '', make_token=AlphanumericVersionToken): """ Args: ver_str (str): Version string. make_token (typing.Callable[[str], None]): Callable that creates a VersionToken subclass from a string. """ - self.tokens = [] + self.tokens: list[VersionToken] | None = [] self.seps = [] self._str = None self._hash = None @@ -304,7 +306,7 @@ def __init__(self, ver_str='', make_token=AlphanumericVersionToken): self.seps = seps[1:-1] - def copy(self): + def copy(self) -> Version: """ Returns a copy of the version. @@ -316,7 +318,7 @@ def copy(self): other.seps = self.seps[:] return other - def trim(self, len_): + def trim(self, len_: int) -> Version: """Return a copy of the version, possibly with less tokens. Args: @@ -331,7 +333,7 @@ def trim(self, len_): other.seps = self.seps[:len_ - 1] return other - def __next__(self): + def __next__(self) -> Version: """Return :meth:`next` version. Eg, ``next(1.2)`` is ``1.2_``""" if self.tokens: other = self.copy() @@ -341,11 +343,11 @@ def __next__(self): else: return Version.inf - def next(self): + def next(self) -> Version: return self.__next__() @property - def major(self): + def major(self) -> VersionToken: """Semantic versioning major version. Returns: @@ -354,7 +356,7 @@ def major(self): return self[0] @property - def minor(self): + def minor(self) -> VersionToken: """Semantic versioning minor version. Returns: @@ -363,7 +365,7 @@ def minor(self): return self[1] @property - def patch(self): + def patch(self) -> VersionToken: """Semantic versioning patch version. Returns: @@ -371,7 +373,7 @@ def patch(self): """ return self[2] - def as_tuple(self): + def as_tuple(self) -> tuple[str, ...]: """Convert to a tuple of strings. Example: @@ -384,23 +386,26 @@ def as_tuple(self): """ return tuple(map(str, self.tokens)) - def __len__(self): + def __len__(self) -> int: return len(self.tokens or []) - def __getitem__(self, index): + def __getitem__(self, index: int) -> VersionToken: try: return (self.tokens or [])[index] except IndexError: raise IndexError("version token index out of range") - def __bool__(self): + def __bool__(self) -> bool: """The empty version equates to False.""" return bool(self.tokens) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return isinstance(other, Version) and self.tokens == other.tokens - def __lt__(self, other): + def __lt__(self, other: object) -> bool: + if not isinstance(other, Version): + return NotImplemented + if self.tokens is None: return False elif other.tokens is None: @@ -414,7 +419,7 @@ def __hash__(self): else hash(tuple(map(str, self.tokens))) return self._hash - def __str__(self): + def __str__(self) -> str: if self._str is None: self._str = "[INF]" if self.tokens is None \ else ''.join(str(x) + y for x, y in zip(self.tokens, self.seps + [''])) @@ -427,9 +432,9 @@ def __str__(self): class _LowerBound(_Comparable): - min = None + min: _LowerBound - def __init__(self, version, inclusive): + def __init__(self, version: Version, inclusive: bool): self.version = version self.inclusive = inclusive @@ -452,7 +457,7 @@ def __lt__(self, other): def __hash__(self): return hash((self.version, self.inclusive)) - def contains_version(self, version): + def contains_version(self, version: Version) -> bool: return (version > self.version) \ or (self.inclusive and (version == self.version)) @@ -461,9 +466,9 @@ def contains_version(self, version): class _UpperBound(_Comparable): - inf = None + inf: _UpperBound - def __init__(self, version, inclusive): + def __init__(self, version: Version, inclusive: bool): self.version = version self.inclusive = inclusive if not version and not inclusive: @@ -485,7 +490,7 @@ def __lt__(self, other): def __hash__(self): return hash((self.version, self.inclusive)) - def contains_version(self, version): + def contains_version(self, version: Version) -> bool: return (version < self.version) \ or (self.inclusive and (version == self.version)) @@ -494,9 +499,11 @@ def contains_version(self, version): class _Bound(_Comparable): - any = None + any: _Bound - def __init__(self, lower=None, upper=None, invalid_bound_error=True): + def __init__(self, lower: _LowerBound | None = None, + upper: _UpperBound | None = None, + invalid_bound_error: bool = True): self.lower = lower or _LowerBound.min self.upper = upper or _UpperBound.inf @@ -509,7 +516,7 @@ def __init__(self, lower=None, upper=None, invalid_bound_error=True): ): raise VersionError("Invalid bound") - def __str__(self): + def __str__(self) -> str: if self.upper.version == Version.inf: return str(self.lower) elif self.lower.version == self.upper.version: @@ -534,26 +541,26 @@ def __lt__(self, other): def __hash__(self): return hash((self.lower, self.upper)) - def lower_bounded(self): + def lower_bounded(self) -> bool: return (self.lower != _LowerBound.min) - def upper_bounded(self): + def upper_bounded(self) -> bool: return (self.upper != _UpperBound.inf) - def contains_version(self, version): + def contains_version(self, version: Version) -> bool: return (self.version_containment(version) == 0) - def version_containment(self, version): + def version_containment(self, version: Version) -> int: if not self.lower.contains_version(version): return -1 if not self.upper.contains_version(version): return 1 return 0 - def contains_bound(self, bound): + def contains_bound(self, bound: _Bound) -> bool: return (self.lower <= bound.lower) and (self.upper >= bound.upper) - def intersects(self, other): + def intersects(self, other: _Bound) -> bool: lower = max(self.lower, other.lower) upper = min(self.upper, other.upper) @@ -561,7 +568,7 @@ def intersects(self, other): (lower.version == upper.version) and (lower.inclusive and upper.inclusive) ) - def intersection(self, other): + def intersection(self, other: _Bound) -> _Bound | None: lower = max(self.lower, other.lower) upper = min(self.upper, other.upper) @@ -576,6 +583,20 @@ def intersection(self, other): _Bound.any = _Bound() +def action(fn): + def fn_(self): + result = fn(self) + if self.debug: + label = fn.__name__.replace("_act_", "") + print("%-21s: %s" % (label, self._input_string)) + for key, value in self._groups.items(): + print(" %-17s= %s" % (key, value)) + print(" %-17s= %s" % ("bounds", self.bounds)) + return result + + return fn_ + + class _VersionRangeParser(object): debug = False # set to True to enable parser debugging @@ -659,7 +680,7 @@ class _VersionRangeParser(object): regex = re.compile(version_range_regex, re_flags) - def __init__(self, input_string, make_token, invalid_bound_error=True): + def __init__(self, input_string: str, make_token, invalid_bound_error=True): self.make_token = make_token self._groups = {} self._input_string = input_string @@ -721,18 +742,6 @@ def _is_upper_bound_exclusive(self, token): def _create_version_from_token(self, token): return Version(token, make_token=self.make_token) - def action(fn): - def fn_(self): - result = fn(self) - if self.debug: - label = fn.__name__.replace("_act_", "") - print("%-21s: %s" % (label, self._input_string)) - for key, value in self._groups.items(): - print(" %-17s= %s" % (key, value)) - print(" %-17s= %s" % ("bounds", self.bounds)) - return result - return fn_ - @action def _act_version(self): version = self._create_version_from_token(self._groups['version']) @@ -867,7 +876,8 @@ class VersionRange(_Comparable): valid version range syntax. For example, ``>`` is a valid range - read like ``>''``, it means ``any version greater than the empty version``. """ - def __init__(self, range_str='', make_token=AlphanumericVersionToken, + def __init__(self, range_str: str | None = '', + make_token=AlphanumericVersionToken, invalid_bound_error=True): """ Args: @@ -899,7 +909,7 @@ def __init__(self, range_str='', make_token=AlphanumericVersionToken, else: self.bounds.append(_Bound.any) - def is_any(self): + def is_any(self) -> bool: """ Returns: bool: True if this is the "any" range, ie the empty string range @@ -907,7 +917,7 @@ def is_any(self): """ return (len(self.bounds) == 1) and (self.bounds[0] == _Bound.any) - def lower_bounded(self): + def lower_bounded(self) -> bool: """ Returns: bool: True if the range has a lower bound (that is not the empty @@ -915,35 +925,35 @@ def lower_bounded(self): """ return self.bounds[0].lower_bounded() - def upper_bounded(self): + def upper_bounded(self) -> bool: """ Returns: bool: True if the range has an upper bound. """ return self.bounds[-1].upper_bounded() - def bounded(self): + def bounded(self) -> bool: """ Returns: bool: True if the range has a lower and upper bound. """ return (self.lower_bounded() and self.upper_bounded()) - def issuperset(self, range): + def issuperset(self, range) -> bool: """ Returns: bool: True if the VersionRange is contained within this range. """ return self._issuperset(self.bounds, range.bounds) - def issubset(self, range): + def issubset(self, range) -> bool: """ Returns: bool: True if we are contained within the version range. """ return range.issuperset(self) - def union(self, other): + def union(self, other: VersionRange | Iterable[VersionRange]) -> VersionRange: """OR together version ranges. Calculates the union of this range with one or more other ranges. @@ -965,7 +975,7 @@ def union(self, other): range.bounds = bounds return range - def intersection(self, other): + def intersection(self, other: VersionRange | Iterable[VersionRange]) -> VersionRange | None: """AND together version ranges. Calculates the intersection of this range with one or more other ranges. @@ -990,7 +1000,7 @@ def intersection(self, other): range.bounds = bounds return range - def inverse(self): + def inverse(self) -> VersionRange | None: """Calculate the inverse of the range. Returns: @@ -1093,7 +1103,7 @@ def from_version(cls, version, op=None): return range @classmethod - def from_versions(cls, versions): + def from_versions(cls, versions: Iterable[Version]) -> VersionRange: """Create a range from a list of versions. This method creates a range that contains only the given versions and @@ -1114,7 +1124,7 @@ def from_versions(cls, versions): range.bounds.append(bound) return range - def to_versions(self): + def to_versions(self) -> list[Version] | None: """Returns exact version ranges as Version objects, or None if there are no exact version ranges present. @@ -1129,7 +1139,7 @@ def to_versions(self): return versions or None - def contains_version(self, version): + def contains_version(self, version: Version): """Returns True if version is contained in this range. Returns: @@ -1149,7 +1159,9 @@ def contains_version(self, version): return False - def iter_intersect_test(self, iterable, key=None, descending=False): + def iter_intersect_test(self, iterable: Iterable[T], + key: Callable[[T], Version] | None = None, + descending: bool = False) -> _ContainsVersionIterator[T]: """Performs containment tests on a sorted list of versions. This is more optimal than performing separate containment tests on a @@ -1170,7 +1182,9 @@ def iter_intersect_test(self, iterable, key=None, descending=False): """ return _ContainsVersionIterator(self, iterable, key, descending) - def iter_intersecting(self, iterable, key=None, descending=False): + def iter_intersecting(self, iterable: Iterable[T], + key: Callable[[T], Version] | None = None, + descending: bool = False) -> _ContainsVersionIterator[T]: """Like :meth:iter_intersect_test`, but returns intersections only. Returns: @@ -1180,7 +1194,9 @@ def iter_intersecting(self, iterable, key=None, descending=False): self, iterable, key, descending, mode=_ContainsVersionIterator.MODE_INTERSECTING ) - def iter_non_intersecting(self, iterable, key=None, descending=False): + def iter_non_intersecting(self, iterable: Iterable[T], + key: Callable[[T], Version] | None = None, + descending: bool = False) -> _ContainsVersionIterator[T]: """Like :meth:`iter_intersect_test`, but returns non-intersections only. Returns: @@ -1190,7 +1206,7 @@ def iter_non_intersecting(self, iterable, key=None, descending=False): self, iterable, key, descending, mode=_ContainsVersionIterator.MODE_NON_INTERSECTING ) - def span(self): + def span(self) -> VersionRange: """Return a contiguous range that is a superset of this range. Returns: @@ -1275,7 +1291,7 @@ def __lt__(self, other): def __hash__(self): return hash(tuple(self.bounds)) - def _contains_version(self, version): + def _contains_version(self, version: Version) -> tuple[int, bool]: vbound = _Bound(_LowerBound(version, True)) i = bisect_left(self.bounds, vbound) if i and self.bounds[i - 1].contains_version(version): @@ -1364,7 +1380,7 @@ def _issuperset(cls, bounds1, bounds2): return True @classmethod - def _intersects(cls, bounds1, bounds2): + def _intersects(cls, bounds1, bounds2) -> bool: # sort so bounds1 is the shorter list bounds1, bounds2 = sorted((bounds1, bounds2), key=lambda x: len(x)) @@ -1388,23 +1404,27 @@ def _intersects(cls, bounds1, bounds2): return False -class _ContainsVersionIterator(object): +class _ContainsVersionIterator(Generic[T]): MODE_INTERSECTING = 0 MODE_NON_INTERSECTING = 2 MODE_ALL = 3 - def __init__(self, range_, iterable, key=None, descending=False, mode=MODE_ALL): + def __init__(self, range_: VersionRange, iterable: Iterable[T], + key: Callable[[T], Version] | None = None, + descending: bool = False, mode=MODE_ALL): self.mode = mode self.range_ = range_ - self.index = None + self.index: int | None = None self.nbounds = len(self.range_.bounds) self._constant = True if range_.is_any() else None self.fn = self._descending if descending else self._ascending self.it = iter(iterable) if key is None: - key = lambda x: x # noqa: E731 + # FIXME: this case seems to assume that iterable is Iterable[Version] + key = cast(Callable[[T], Version], lambda x: x) # noqa: E731 self.keyfunc = key + self.next_fn: Callable[[], tuple[bool, T]] | Callable[[], T] if mode == self.MODE_ALL: self.next_fn = self._next elif mode == self.MODE_INTERSECTING: @@ -1412,16 +1432,16 @@ def __init__(self, range_, iterable, key=None, descending=False, mode=MODE_ALL): else: self.next_fn = self._next_non_intersecting - def __iter__(self): + def __iter__(self) -> _ContainsVersionIterator[T]: return self - def __next__(self): + def __next__(self) -> T | tuple[bool, T]: return self.next_fn() - def next(self): + def next(self) -> T | tuple[bool, T]: return self.next_fn() - def _next(self): + def _next(self) -> tuple[bool, T]: value = next(self.it) if self._constant is not None: return self._constant, value @@ -1430,7 +1450,7 @@ def _next(self): intersects = self.fn(version) return intersects, value - def _next_intersecting(self): + def _next_intersecting(self) -> T: while True: value = next(self.it) @@ -1444,7 +1464,7 @@ def _next_intersecting(self): if intersects: return value - def _next_non_intersecting(self): + def _next_non_intersecting(self) -> T: while True: value = next(self.it) @@ -1465,7 +1485,7 @@ def _bound(self): else: return None - def _ascending(self, version): + def _ascending(self, version: Version) -> bool: if self.index is None: self.index, contains = self.range_._contains_version(version) bound = self._bound @@ -1501,7 +1521,7 @@ def _ascending(self, version): elif j == -1: return False - def _descending(self, version): + def _descending(self, version: Version) -> bool: if self.index is None: self.index, contains = self.range_._contains_version(version) bound = self._bound diff --git a/src/rezplugins/package_repository/filesystem.py b/src/rezplugins/package_repository/filesystem.py index d0b793b9d5..f24f49681b 100644 --- a/src/rezplugins/package_repository/filesystem.py +++ b/src/rezplugins/package_repository/filesystem.py @@ -5,6 +5,8 @@ """ Filesystem-based package repository """ +from __future__ import annotations + from contextlib import contextmanager from functools import lru_cache import os.path @@ -36,6 +38,11 @@ from rez.vendor.schema.schema import Schema, Optional, And, Use, Or from rez.version import Version, VersionRange +from typing import Iterator, TYPE_CHECKING + +if TYPE_CHECKING: + from rez.packages import Package, Variant + from rez.package_resources import PackageRepositoryResource, VariantResource debug_print = config.debug_printer("resources") @@ -88,14 +95,14 @@ class FileSystemPackageFamilyResource(PackageFamilyResource): key = "filesystem.family" repository_type = "filesystem" - def _uri(self): + def _uri(self) -> str: return self.path @cached_property - def path(self): + def path(self) -> str: return os.path.join(self.location, self.name) - def get_last_release_time(self): + def get_last_release_time(self) -> float: # this repository makes sure to update path mtime every time a # variant is added to the repository try: @@ -103,7 +110,7 @@ def get_last_release_time(self): except OSError: return 0 - def iter_packages(self): + def iter_packages(self) -> Iterator[Package]: # check for unversioned package if config.allow_unversioned_packages: filepath, _ = self._repository._get_file(self.path) @@ -137,11 +144,11 @@ class FileSystemPackageResource(PackageResourceHelper): repository_type = "filesystem" schema = package_pod_schema - def _uri(self): + def _uri(self) -> str: return self.filepath @cached_property - def parent(self): + def parent(self) -> PackageRepositoryResource: family = self._repository.get_resource( FileSystemPackageFamilyResource.key, location=self.location, @@ -149,7 +156,7 @@ def parent(self): return family @cached_property - def state_handle(self): + def state_handle(self) -> float | None: if self.filepath: return os.path.getmtime(self.filepath) return None @@ -173,7 +180,7 @@ def path(self): return path @cached_property - def filepath(self): + def filepath(self) -> str: return self._filepath_and_format[0] @cached_property @@ -268,7 +275,7 @@ class FileSystemVariantResource(VariantResourceHelper): repository_type = "filesystem" @cached_property - def parent(self): + def parent(self) -> PackageRepositoryResource: package = self._repository.get_resource( FileSystemPackageResource.key, location=self.location, @@ -350,12 +357,12 @@ class FileSystemCombinedPackageResource(PackageResourceHelper): repository_type = "filesystem" schema = package_pod_schema - def _uri(self): + def _uri(self) -> str: ver_str = self.get("version", "") return "%s<%s>" % (self.parent.filepath, ver_str) @cached_property - def parent(self): + def parent(self) -> PackageRepositoryResource: family = self._repository.get_resource( FileSystemCombinedPackageFamilyResource.key, location=self.location, @@ -368,7 +375,7 @@ def base(self): return None # combined resource types do not have 'base' @cached_property - def state_handle(self): + def state_handle(self) -> float: return os.path.getmtime(self.parent.filepath) def iter_variants(self): @@ -412,7 +419,7 @@ class FileSystemCombinedVariantResource(VariantResourceHelper): repository_type = "filesystem" @cached_property - def parent(self): + def parent(self) -> PackageRepositoryResource: package = self._repository.get_resource( FileSystemCombinedPackageResource.key, location=self.location, @@ -557,37 +564,40 @@ def _uid(self): t.append(int(st.st_ino)) return tuple(t) - def get_package_family(self, name): + def get_package_family(self, name: str) -> PackageFamilyResource: return self.get_family(name) @pool_memcached_connections - def iter_package_families(self): + def iter_package_families(self) -> Iterator[PackageFamilyResource]: for family in self.get_families(): yield family @pool_memcached_connections - def iter_packages(self, package_family_resource): + def iter_packages(self, package_family_resource: PackageFamilyResource + ) -> Iterator[Package]: for package in self.get_packages(package_family_resource): yield package - def iter_variants(self, package_resource): + def iter_variants(self, package_resource: PackageResourceHelper + ) -> Iterator[Variant]: for variant in self.get_variants(package_resource): yield variant - def get_parent_package_family(self, package_resource): + def get_parent_package_family(self, package_resource: PackageResourceHelper + ) -> PackageRepositoryResource: return package_resource.parent - def get_parent_package(self, variant_resource): + def get_parent_package(self, variant_resource: VariantResource) -> PackageRepositoryResource: return variant_resource.parent - def get_variant_state_handle(self, variant_resource): + def get_variant_state_handle(self, variant_resource: VariantResource): package_resource = variant_resource.parent return package_resource.state_handle - def get_last_release_time(self, package_family_resource): + def get_last_release_time(self, package_family_resource: PackageFamilyResource): return package_family_resource.get_last_release_time() - def get_package_from_uri(self, uri): + def get_package_from_uri(self, uri: str): """ Example URIs: - /svr/packages/mypkg/1.0.0/package.py @@ -621,7 +631,7 @@ def get_package_from_uri(self, uri): pkg_ver = Version(pkg_ver_str) return self.get_package(pkg_name, pkg_ver) - def get_variant_from_uri(self, uri): + def get_variant_from_uri(self, uri: str): """ Example URIs: - /svr/packages/mypkg/1.0.0/package.py[1] @@ -657,7 +667,7 @@ def get_variant_from_uri(self, uri): return None - def ignore_package(self, pkg_name, pkg_version, allow_missing=False): + def ignore_package(self, pkg_name: str, pkg_version, allow_missing=False): # find package, even if already ignored if not allow_missing: repo_copy = self._copy( @@ -688,7 +698,7 @@ def ignore_package(self, pkg_name, pkg_version, allow_missing=False): self._on_changed(pkg_name) return 1 - def unignore_package(self, pkg_name, pkg_version): + def unignore_package(self, pkg_name: str, pkg_version): # find and remove .ignore{ver} file if it exists ignore_file_was_removed = False filename = self.ignore_prefix + str(pkg_version) @@ -707,7 +717,7 @@ def unignore_package(self, pkg_name, pkg_version): else: return -1 - def remove_package(self, pkg_name, pkg_version): + def remove_package(self, pkg_name: str, pkg_version): # ignore it first, so a partially deleted pkg is not visible i = self.ignore_package(pkg_name, pkg_version) if i == -1: @@ -735,7 +745,7 @@ def remove_package(self, pkg_name, pkg_version): return True - def remove_package_family(self, pkg_name, force=False): + def remove_package_family(self, pkg_name: str, force=False): # get a non-cached copy and see if fam exists repo_copy = self._copy( disable_pkg_ignore=True, @@ -852,7 +862,7 @@ def file_lock_dir(self): return dirname - def pre_variant_install(self, variant_resource): + def pre_variant_install(self, variant_resource: VariantResourceHelper): if not variant_resource.version: return @@ -1003,7 +1013,7 @@ def _lock_package(self, package_name, package_version=None): except NotLocked: pass - def clear_caches(self): + def clear_caches(self) -> None: super(FileSystemPackageRepository, self).clear_caches() self.get_families.cache_clear() self.get_family.cache_clear() @@ -1018,7 +1028,7 @@ def clear_caches(self): # unfortunately we need to clear file cache across the board clear_file_caches() - def get_package_payload_path(self, package_name, package_version=None): + def get_package_payload_path(self, package_name: str, package_version=None) -> str: path = os.path.join(self.location, package_name) if package_version: @@ -1123,7 +1133,7 @@ def ignore_dir(name): def _is_valid_package_directory(self, path): return bool(self._get_file(path, "package")[0]) - def _get_families(self): + def _get_families(self) -> list[PackageFamilyResource]: families = [] for name, ext in self._get_family_dirs(): if ext is None: # is a directory @@ -1141,7 +1151,7 @@ def _get_families(self): return families - def _get_family(self, name): + def _get_family(self, name: str) -> PackageFamilyResource: is_valid_package_name(name, raise_error=True) if os.path.isdir(os.path.join(self.location, name)): # force case-sensitive match on pkg family dir, on case-insensitive platforms @@ -1171,13 +1181,13 @@ def _get_family(self, name): ) return None - def _get_packages(self, package_family_resource): + def _get_packages(self, package_family_resource: PackageFamilyResource) -> list[Package]: return [x for x in package_family_resource.iter_packages()] - def _get_variants(self, package_resource): + def _get_variants(self, package_resource: PackageResourceHelper) -> list[Variant]: return [x for x in package_resource.iter_variants()] - def _get_file(self, path, package_filename=None): + def _get_file(self, path, package_filename=None) -> tuple[str, FileFormat] | tuple[None, None]: if package_filename: package_filenames = [package_filename] else: @@ -1192,7 +1202,7 @@ def _get_file(self, path, package_filename=None): return filepath, format_ return None, None - def _create_family(self, name): + def _create_family(self, name: str): path = os.path.join(self.location, name) if not os.path.exists(path): os.makedirs(path) @@ -1200,7 +1210,7 @@ def _create_family(self, name): self._on_changed(name) return self.get_package_family(name) - def _create_variant(self, variant, dry_run=False, overrides=None): + def _create_variant(self, variant: Variant, dry_run=False, overrides=None) -> dict | None: # special case overrides variant_name = overrides.get("name") or variant.name variant_version = overrides.get("version") or variant.version @@ -1489,7 +1499,7 @@ def _remove_build_keys(obj): return new_variant - def _on_changed(self, pkg_name): + def _on_changed(self, pkg_name: str): """Called when a package is added/removed/changed. """ @@ -1504,7 +1514,7 @@ def _on_changed(self, pkg_name): # clear internal caches, otherwise change may not be visible self.clear_caches() - def _delete_stale_build_tagfiles(self, family_path): + def _delete_stale_build_tagfiles(self, family_path: str): now = time.time() for name in os.listdir(family_path): diff --git a/src/rezplugins/package_repository/memory.py b/src/rezplugins/package_repository/memory.py index 74d50731e9..e7a7a32544 100644 --- a/src/rezplugins/package_repository/memory.py +++ b/src/rezplugins/package_repository/memory.py @@ -5,6 +5,8 @@ """ In-memory package repository """ +from __future__ import annotations + from rez.package_repository import PackageRepository from rez.package_resources import PackageFamilyResource, VariantResourceHelper, \ PackageResourceHelper, package_pod_schema @@ -12,6 +14,12 @@ from rez.utils.resources import ResourcePool, cached_property from rez.version import VersionedObject +from typing import Iterator, TYPE_CHECKING + +if TYPE_CHECKING: + from rez.packages import Package + from rez.package_resources import PackageRepositoryResource + # This repository type is used when loading 'developer' packages (a package.yaml # or package.py in a developer's working directory), and when programmatically @@ -29,7 +37,7 @@ class MemoryPackageFamilyResource(PackageFamilyResource): def _uri(self): return "%s:%s" % (self.location, self.name) - def iter_packages(self): + def iter_packages(self) -> Iterator[Package]: data = self._repository.data.get(self.name, {}) # check for unversioned package @@ -62,11 +70,11 @@ def _uri(self): return "%s:%s" % (self.location, str(obj)) @property - def base(self): + def base(self) -> None: return None # memory types do not have 'base' @cached_property - def parent(self): + def parent(self) -> PackageRepositoryResource: family = self._repository.get_resource( MemoryPackageFamilyResource.key, location=self.location, @@ -86,11 +94,11 @@ class MemoryVariantResource(VariantResourceHelper): key = "memory.variant" repository_type = "memory" - def _root(self): + def _root(self, ignore_shortlinks: bool = False) -> str | None: return None # memory types do not have 'root' @cached_property - def parent(self): + def parent(self) -> PackageRepositoryResource: package = self._repository.get_resource( MemoryPackageResource.key, location=self.location, @@ -135,7 +143,7 @@ def name(cls): return "memory" @classmethod - def create_repository(cls, repository_data): + def create_repository(cls, repository_data) -> MemoryPackageRepository: """Create a standalone, in-memory repository. Using this function bypasses the `package_repository_manager` singleton.