diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 72ac56a656..7f77c2d7f9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -97,4 +97,4 @@ repos: name: validate library versions language: script entry: utils/validate_versions.py - files: ".*/setup.py" + files: ".*/__init__.py" diff --git a/utils/validate_versions.py b/utils/validate_versions.py index c1e6d8acd3..77787d60e5 100755 --- a/utils/validate_versions.py +++ b/utils/validate_versions.py @@ -3,6 +3,7 @@ import os import sys from typing import Dict +import argparse VERSION_LINE_START = "__version__ = " @@ -37,7 +38,25 @@ def check_versions() -> bool: return True +def set_version(new_version: str) -> None: + new_contents = f'{VERSION_LINE_START}"{new_version}"\n' + for directory in DIRECTORIES: + path = os.path.join(directory, "__init__.py") + print(f"Setting {path} to version {new_version}") + with open(path, "w") as f: + f.write(new_contents) + + if __name__ == "__main__": - ok = check_versions() - return_code = 0 if ok else 1 - sys.exit(return_code) + parser = argparse.ArgumentParser() + parser.add_argument("--new-version", default=None) + # unused, but allows precommit to pass filenames + parser.add_argument("files", nargs="*") + args = parser.parse_args() + if args.new_version: + print(f"Updating to verison {args.new_version}") + set_version(args.new_version) + else: + ok = check_versions() + return_code = 0 if ok else 1 + sys.exit(return_code)