Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,9 @@ repos:
entry: clang-format -i
args: ["-style=file"]
files: ^transformer_engine.*\.(c|cc|cxx|cpp|cu|cuh|h|hpp)$

- repo: https://github.com/netromdk/vermin
rev: c75aca72f4e85c6e47252139e8695f1c8b5f9ae3
hooks:
- id: vermin
args: ['-t=3.10', '--violations']
19 changes: 19 additions & 0 deletions build_tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,31 @@
import shutil
import subprocess
import sys
import platform
from pathlib import Path
from importlib.metadata import version as get_version
from subprocess import CalledProcessError
from typing import List, Optional, Tuple, Union


# Needs to stay consistent with .pre-commit-config.yaml config.
def min_python_version() -> Tuple[int]:
"""Minimum supported Python version."""
return (3, 10, 0)


def min_python_version_str() -> str:
"""String representing minimum supported Python version."""
return ".".join(map(str, min_python_version()))


if sys.version_info < min_python_version():
raise RuntimeError(
f"Transformer Engine requires Python {min_python_version_str()} or newer, "
f"but found Python {platform.python_version()}."
)


@functools.lru_cache(maxsize=None)
def debug_build_enabled() -> bool:
"""Whether to build with a debug configuration"""
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
cuda_version,
get_frameworks,
remove_dups,
min_python_version_str,
)

frameworks = get_frameworks()
Expand Down Expand Up @@ -190,7 +191,7 @@ def setup_requirements() -> Tuple[List[str], List[str]]:
long_description_content_type="text/x-rst",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
python_requires=">=3.8",
python_requires=f">={min_python_version_str()}",
classifiers=["Programming Language :: Python :: 3"],
install_requires=install_requires,
license_files=("LICENSE",),
Expand Down
3 changes: 2 additions & 1 deletion transformer_engine/jax/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@


from build_tools.build_ext import get_build_ext
from build_tools.utils import copy_common_headers
from build_tools.utils import copy_common_headers, min_python_version_str
from build_tools.te_version import te_version
from build_tools.jax import setup_jax_extension, install_requirements, test_requirements

Expand Down Expand Up @@ -100,6 +100,7 @@
description="Transformer acceleration library - Jax Lib",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension},
python_requires=f">={min_python_version_str()}",
install_requires=install_requirements(),
tests_require=test_requirements(),
)
Expand Down
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@


from build_tools.build_ext import get_build_ext
from build_tools.utils import copy_common_headers
from build_tools.utils import copy_common_headers, min_python_version_str
from build_tools.te_version import te_version
from build_tools.pytorch import (
setup_pytorch_extension,
Expand Down Expand Up @@ -152,6 +152,7 @@ def run(self):
description="Transformer acceleration library - Torch Lib",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": CachedWheelsCommand},
python_requires=f">={min_python_version_str()}",
install_requires=install_requirements(),
tests_require=test_requirements(),
)
Expand Down
Loading