From ad5d5e8a2e31e61d087133ea5d85c0d341b65b59 Mon Sep 17 00:00:00 2001 From: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com> Date: Wed, 16 Feb 2022 13:28:00 -0600 Subject: [PATCH 1/3] cherry-picked 0200b3af085570010f47797879eae050d946dcc3 commit --- tools/amd_build/build_amd.py | 4 + torch/utils/cpp_extension.py | 32 +++-- torch/utils/hipify/hipify_python.py | 206 ++++++++++++++-------------- 3 files changed, 132 insertions(+), 110 deletions(-) mode change 100644 => 100755 torch/utils/hipify/hipify_python.py diff --git a/tools/amd_build/build_amd.py b/tools/amd_build/build_amd.py index 38698631c03c..785f63085c2e 100755 --- a/tools/amd_build/build_amd.py +++ b/tools/amd_build/build_amd.py @@ -89,6 +89,8 @@ "tools/autograd/templates/python_variable_methods.cpp", ] +includes = [os.path.join(proj_dir, include) for include in includes] + for new_dir in args.extra_include_dir: abs_new_dir = os.path.join(proj_dir, new_dir) if os.path.exists(abs_new_dir): @@ -112,6 +114,8 @@ "torch/include/*", ] +ignores = [os.path.join(proj_dir, ignore) for ignore in ignores] + # Check if the compiler is hip-clang. def is_hip_clang() -> bool: try: diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index 00e6d5d45e29..64a174c2b3b2 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -16,8 +16,9 @@ from .file_baton import FileBaton from ._cpp_extension_versioner import ExtensionVersioner from .hipify import hipify_python -from .hipify.hipify_python import get_hip_file_path, GeneratedFileCleaner -from typing import List, Optional, Union +from .hipify.hipify_python import GeneratedFileCleaner, get_hip_file_path +from typing import List, Optional, Union, Tuple +from torch.torch_version import TorchVersion from setuptools.command.build_ext import build_ext from pkg_resources import packaging # type: ignore[attr-defined] @@ -959,16 +960,19 @@ def CUDAExtension(name, sources, *args, **kwargs): hipify_result = hipify_python.hipify( project_directory=build_dir, output_directory=build_dir, - includes=[os.path.join(os.path.relpath(include_dir, build_dir), '*') for include_dir in include_dirs] if include_dirs else ['*'], + header_include_dirs=include_dirs, + includes=[os.path.join(build_dir, '*')], # limit scope to build_dir only extra_files=[os.path.abspath(s) for s in sources], show_detailed=True, is_pytorch_extension=True, + hipify_extra_files_only=True, # don't hipify everything in includes path ) hipified_sources = set() for source in sources: s_abs = os.path.abspath(source) - hipified_sources.add(hipify_result[s_abs]["hipified_path"] if s_abs in hipify_result else s_abs) + hipified_sources.add(hipify_result[s_abs]["hipified_path"] if (s_abs in hipify_result and + hipify_result[s_abs]["hipified_path"] is not None) else s_abs) sources = list(hipified_sources) @@ -1345,15 +1349,25 @@ def _jit_compile(name, try: with GeneratedFileCleaner(keep_intermediates=keep_intermediates) as clean_ctx: if IS_HIP_EXTENSION and (with_cuda or with_cudnn): - hipify_python.hipify( + hipify_result = hipify_python.hipify( project_directory=build_directory, output_directory=build_directory, - includes=os.path.join(build_directory, '*'), + header_include_dirs=extra_include_paths, extra_files=[os.path.abspath(s) for s in sources], + ignores=[os.path.join(ROCM_HOME, '*'), os.path.join(_TORCH_PATH, '*')], # no need to hipify ROCm or PyTorch headers show_detailed=verbose, + show_progress=verbose, is_pytorch_extension=True, clean_ctx=clean_ctx ) + + hipified_sources = set() + for source in sources: + s_abs = os.path.abspath(source) + hipified_sources.add(hipify_result[s_abs]["hipified_path"] if s_abs in hipify_result else s_abs) + + sources = list(hipified_sources) + _write_ninja_file_and_build_library( name=name, sources=sources, @@ -1849,10 +1863,6 @@ def _write_ninja_file_to_build_library(path, cuda_flags = ['-DWITH_HIP'] + cflags + COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS cuda_flags += extra_cuda_cflags cuda_flags += _get_rocm_arch_flags(cuda_flags) - sources = [s if not _is_cuda_file(s) else - os.path.abspath(os.path.join( - path, get_hip_file_path(os.path.relpath(s, path), is_pytorch_extension=True))) - for s in sources] elif with_cuda: cuda_flags = common_cflags + COMMON_NVCC_FLAGS + _get_cuda_arch_flags() if IS_WINDOWS: @@ -1963,6 +1973,8 @@ def sanitize_flags(flags): nvcc = _join_cuda_home('bin', 'nvcc') config.append(f'nvcc = {nvcc}') + if IS_HIP_EXTENSION: + post_cflags = COMMON_HIP_FLAGS + post_cflags flags = [f'cflags = {" ".join(cflags)}'] flags.append(f'post_cflags = {" ".join(post_cflags)}') if with_cuda: diff --git a/torch/utils/hipify/hipify_python.py b/torch/utils/hipify/hipify_python.py old mode 100644 new mode 100755 index ab541d07375e..25014dd7b4c7 --- a/torch/utils/hipify/hipify_python.py +++ b/torch/utils/hipify/hipify_python.py @@ -117,15 +117,16 @@ def match_extensions(filename: str, extensions: Iterable) -> bool: """Helper method to see if filename ends with certain extension""" return any(filename.endswith(e) for e in extensions) +def _fnmatch(filepath, patterns): + return any(fnmatch.fnmatch(filepath, pattern) for pattern in patterns) + def matched_files_iter( root_path: str, - includes: Iterable = ('*',), + includes: Iterable = (), ignores: Iterable = (), extensions: Iterable = (), out_of_place_only: bool = False, is_pytorch_extension: bool = False) -> Iterator[str]: - def _fnmatch(filepath, patterns): - return any(fnmatch.fnmatch(filepath, pattern) for pattern in patterns) exact_matches = set(includes) @@ -145,7 +146,8 @@ def _fnmatch(filepath, patterns): if "third_party" in dirs: dirs.remove("third_party") for filename in filenames: - filepath = os.path.join(rel_dirpath, filename) + filepath = os.path.join(abs_dirpath, filename) + rel_filepath = os.path.join(rel_dirpath, filename) # We respect extensions, UNLESS you wrote the entire # filename verbatim, in which case we always accept it if ( @@ -154,9 +156,9 @@ def _fnmatch(filepath, patterns): and (match_extensions(filepath, extensions) or filepath in exact_matches) ): if not is_pytorch_extension: # for pytorch extensions, consider all files - if not is_pytorch_file(filepath) and not is_caffe2_gpu_file(filepath): + if not is_pytorch_file(rel_filepath) and not is_caffe2_gpu_file(rel_filepath): continue - if out_of_place_only and not is_out_of_place(filepath): + if out_of_place_only and not is_out_of_place(rel_filepath): continue yield filepath @@ -165,59 +167,23 @@ def preprocess_file_and_save_result( output_directory: str, filepath: str, all_files: Iterable, - includes: Iterable, + header_include_dirs: Iterable, stats: Dict[str, List], hip_clang_launch: bool, is_pytorch_extension: bool, clean_ctx: GeneratedFileCleaner, show_progress: bool) -> None: - result = preprocessor(output_directory, filepath, all_files, includes, stats, + result = preprocessor(output_directory, filepath, all_files, header_include_dirs, stats, hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress) fin_path = os.path.abspath(os.path.join(output_directory, filepath)) # Show what happened - if show_progress: + if show_progress and "ignored" not in result["status"]: print( fin_path, "->", - result["hipified_path"], result["status"]) - - if result["hipified_path"] is not None: - HIPIFY_FINAL_RESULT[fin_path] = result - - -def preprocess( - output_directory: str, - all_files: Iterable, - includes: Iterable, - show_detailed: bool = False, - show_progress: bool = True, - hip_clang_launch: bool = False, - is_pytorch_extension: bool = False, - clean_ctx: Optional[GeneratedFileCleaner] = None) -> HipifyFinalResult: - """ - Call preprocessor on selected files. - - Arguments) - show_detailed - Show a detailed summary of the transpilation process. - """ - - if clean_ctx is None: - clean_ctx = GeneratedFileCleaner(keep_intermediates=True) - - # Preprocessing statistics. - stats: Dict[str, List] = {"unsupported_calls": [], "kernel_launches": []} - - for filepath in all_files: - preprocess_file_and_save_result(output_directory, filepath, all_files, includes, stats, - hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress) + result["hipified_path"], result["status"], flush=True) - print(bcolors.OKGREEN + "Successfully preprocessed all matching files." + bcolors.ENDC, file=sys.stderr) - - # Show detailed summary - if show_detailed: - compute_stats(stats) - - return HIPIFY_FINAL_RESULT + HIPIFY_FINAL_RESULT[fin_path] = result def compute_stats(stats): @@ -544,16 +510,17 @@ def replace_extern_shared(input_string): return output_string -def get_hip_file_path(filepath, is_pytorch_extension=False): +def get_hip_file_path(rel_filepath, is_pytorch_extension=False): """ Returns the new name of the hipified file """ # At the moment, some PyTorch source files are HIPified in place. The predicate # is_out_of_place tells us if this is the case or not. - if not is_pytorch_extension and not is_out_of_place(filepath): - return filepath + assert(not os.path.isabs(rel_filepath)) + if not is_pytorch_extension and not is_out_of_place(rel_filepath): + return rel_filepath - dirpath, filename = os.path.split(filepath) + dirpath, filename = os.path.split(rel_filepath) root, ext = os.path.splitext(filename) # Here's the plan: @@ -597,6 +564,7 @@ def get_hip_file_path(filepath, is_pytorch_extension=False): orig_dirpath = dirpath dirpath = dirpath.replace('cuda', 'hip') + dirpath = dirpath.replace('CUDA', 'HIP') dirpath = dirpath.replace('THC', 'THH') root = root.replace('cuda', 'hip') @@ -614,36 +582,39 @@ def get_hip_file_path(filepath, is_pytorch_extension=False): return os.path.join(dirpath, root + ext) -def is_out_of_place(filepath): - if filepath.startswith("torch/"): +def is_out_of_place(rel_filepath): + assert(not os.path.isabs(rel_filepath)) + if rel_filepath.startswith("torch/"): return False - if filepath.startswith("tools/autograd/templates/"): + if rel_filepath.startswith("tools/autograd/templates/"): return False return True # Keep this synchronized with includes/ignores in build_amd.py -def is_pytorch_file(filepath): - if filepath.startswith("aten/"): - if filepath.startswith("aten/src/ATen/core/"): +def is_pytorch_file(rel_filepath): + assert(not os.path.isabs(rel_filepath)) + if rel_filepath.startswith("aten/"): + if rel_filepath.startswith("aten/src/ATen/core/"): return False return True - if filepath.startswith("torch/"): + if rel_filepath.startswith("torch/"): return True - if filepath.startswith("tools/autograd/templates/"): + if rel_filepath.startswith("tools/autograd/templates/"): return True return False -def is_cusparse_file(filepath): - if is_pytorch_file(filepath): - return "sparse" in filepath.lower() +def is_cusparse_file(rel_filepath): + if is_pytorch_file(rel_filepath): + return "sparse" in rel_filepath.lower() return False -def is_caffe2_gpu_file(filepath): - if filepath.startswith("c10/cuda"): +def is_caffe2_gpu_file(rel_filepath): + assert(not os.path.isabs(rel_filepath)) + if rel_filepath.startswith("c10/cuda"): return True - filename = os.path.basename(filepath) + filename = os.path.basename(rel_filepath) _, ext = os.path.splitext(filename) return ('gpu' in filename or ext in ['.cu', '.cuh']) and ('cudnn' not in filename) @@ -752,31 +723,36 @@ def pattern(self): Returns a dict with the following keys: "hipified_path" : absolute path of hipified source file "status" : "ok" if hipified file was written out - "skipped" if an identical hipified file already existed - "ignored" if the source file was a hipified file itself + "skipped" if an identical hipified file already existed or hipified file couldn't be written out + "ignored" if the source file was a hipified file itself or not meant to be hipified """ def preprocessor( output_directory: str, filepath: str, all_files: Iterable, - includes: Iterable, + header_include_dirs: Iterable, stats: Dict[str, List], hip_clang_launch: bool, is_pytorch_extension: bool, clean_ctx: GeneratedFileCleaner, show_progress: bool) -> HipifyResult: """ Executes the CUDA -> HIP conversion on the specified file. """ + if filepath not in all_files: + return {"hipified_path": None, "status": "[ignored, not to be hipified]"} + fin_path = os.path.abspath(os.path.join(output_directory, filepath)) + rel_filepath = os.path.relpath(filepath, output_directory) with open(fin_path, 'r', encoding='utf-8') as fin: if fin.readline() == HIPIFY_C_BREADCRUMB: - return {"hipified_path": None, "status": "ignored"} + return {"hipified_path": None, "status": "[ignored, input is hipified output]"} fin.seek(0) output_source = fin.read() orig_output_source = output_source - fout_path = os.path.abspath(os.path.join(output_directory, get_hip_file_path(filepath, is_pytorch_extension))) + # get_hip_file_path needs a relative path to work correctly + fout_path = os.path.abspath(os.path.join(output_directory, get_hip_file_path(rel_filepath, is_pytorch_extension))) if not os.path.exists(os.path.dirname(fout_path)): clean_ctx.makedirs(os.path.dirname(fout_path)) @@ -791,9 +767,9 @@ def pt_sparse_repl(m): if is_pytorch_extension: output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source) else: - if is_cusparse_file(filepath): + if is_cusparse_file(rel_filepath): output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_sparse_repl, output_source) - elif is_pytorch_file(filepath): + elif is_pytorch_file(rel_filepath): output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source) else: def c2_repl(m): @@ -827,8 +803,8 @@ def repl(m): header_filepath = header_path_to_check # If not found, look in include dirs one by one and first match wins if header_filepath is None: - for include in includes: - header_dir_to_check = os.path.join(output_directory, os.path.dirname(include)) + for header_include_dir in header_include_dirs: + header_dir_to_check = os.path.join(output_directory, header_include_dir) header_path_to_check = os.path.abspath(os.path.join(header_dir_to_check, f)) if os.path.exists(header_path_to_check): header_dir = header_dir_to_check @@ -839,12 +815,12 @@ def repl(m): # Hipify header file first if needed if header_filepath not in HIPIFY_FINAL_RESULT: preprocess_file_and_save_result(output_directory, - os.path.relpath(header_filepath, output_directory), - all_files, includes, stats, hip_clang_launch, is_pytorch_extension, - clean_ctx, show_progress) - value = HIPIFY_FINAL_RESULT[header_filepath]["hipified_path"] - assert value is not None - return templ.format(os.path.relpath(value, header_dir)) + header_filepath, + all_files, header_include_dirs, stats, hip_clang_launch, + is_pytorch_extension, clean_ctx, show_progress) + hipified_header_filepath = HIPIFY_FINAL_RESULT[header_filepath]["hipified_path"] + return templ.format(os.path.relpath(hipified_header_filepath if hipified_header_filepath is not None + else header_filepath, header_dir)) return m.group(0) return repl @@ -878,7 +854,7 @@ def repl(m): and orig_output_source == output_source and os.path.dirname(fin_path) == os.path.dirname(fout_path) ): - return {"hipified_path": fin_path, "status": "ok"} + return {"hipified_path": fin_path, "status": "[skipped, no changes]"} # Add hipify breadcrumb for C-style files to avoid re-hipification if fin_path != fout_path and match_extensions(fin_path, (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".hpp")): @@ -892,13 +868,13 @@ def repl(m): try: with clean_ctx.open(fout_path, 'w', encoding='utf-8') as fout: fout.write(output_source) - return {"hipified_path": fout_path, "status": "ok"} + return {"hipified_path": fout_path, "status": "[ok]"} except PermissionError as e: print(f"{bcolors.WARNING}Failed to save {fout_path} with \"{e.strerror}\", leaving {fin_path} unchanged.{bcolors.ENDC}", file=sys.stderr) - return {"hipified_path": fin_path, "status": "skipped"} + return {"hipified_path": fin_path, "status": "[skipped, no permissions]"} else: - return {"hipified_path": fout_path, "status": "skipped"} + return {"hipified_path": fout_path, "status": "[skipped, already hipified]"} def file_specific_replacement(filepath, search_string, replace_string, strict=False): with openf(filepath, "r+") as f: @@ -993,14 +969,17 @@ def hipify( project_directory: str, show_detailed: bool = False, extensions: Iterable = (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".in", ".hpp"), + header_extensions: Iterable = (".cuh", ".h", ".hpp"), output_directory: str = "", - includes: Iterable = (), + header_include_dirs: Iterable = (), + includes: Iterable = ('*',), extra_files: Iterable = (), out_of_place_only: bool = False, ignores: Iterable = (), show_progress: bool = True, hip_clang_launch: bool = False, is_pytorch_extension: bool = False, + hipify_extra_files_only: bool = False, clean_ctx: Optional[GeneratedFileCleaner] = None ) -> HipifyFinalResult: if project_directory == "": @@ -1016,6 +995,10 @@ def hipify( project_directory.rstrip("/") output_directory = project_directory + "_amd" + if project_directory != output_directory: + includes = [include.replace(project_directory, output_directory) for include in includes] + ignores = [ignore.replace(project_directory, output_directory) for ignore in ignores] + # Copy from project directory to output directory if not done already. if not os.path.exists(output_directory): shutil.copytree(project_directory, output_directory) @@ -1025,19 +1008,42 @@ def hipify( out_of_place_only=out_of_place_only, is_pytorch_extension=is_pytorch_extension)) all_files_set = set(all_files) - # Convert extra_files to relative paths since all_files has all relative paths for f in extra_files: - f_rel = os.path.relpath(f, output_directory) - if f_rel not in all_files_set: - all_files.append(f_rel) - - # Start Preprocessor - return preprocess( - output_directory, - all_files, - includes, - show_detailed=show_detailed, - show_progress=show_progress, - hip_clang_launch=hip_clang_launch, - is_pytorch_extension=is_pytorch_extension, - clean_ctx=clean_ctx) + if not os.path.isabs(f): + f = os.path.join(output_directory, f) + if f not in all_files_set: + all_files.append(f) + + # List all files in header_include_paths to ensure they are hipified + from pathlib import Path + for header_include_dir in header_include_dirs: + if os.path.isabs(header_include_dir): + header_include_dir_path = Path(header_include_dir) + else: + header_include_dir_path = Path(os.path.join(output_directory, header_include_dir)) + for path in header_include_dir_path.rglob('*'): + if ( + path.is_file() + and _fnmatch(str(path), includes) + and (not _fnmatch(str(path), ignores)) + and match_extensions(path.name, header_extensions) + ): + all_files.append(str(path)) + + if clean_ctx is None: + clean_ctx = GeneratedFileCleaner(keep_intermediates=True) + + # Preprocessing statistics. + stats: Dict[str, List] = {"unsupported_calls": [], "kernel_launches": []} + + for filepath in (all_files if not hipify_extra_files_only else extra_files): + preprocess_file_and_save_result(output_directory, filepath, all_files, header_include_dirs, + stats, hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress) + + print(bcolors.OKGREEN + "Successfully preprocessed all matching files." + bcolors.ENDC, file=sys.stderr) + + # Show detailed summary + if show_detailed: + compute_stats(stats) + + return HIPIFY_FINAL_RESULT From c1192d9ab24ba0fd7dd4d92b9bc8118ae82f1c5d Mon Sep 17 00:00:00 2001 From: Jithun Nair Date: Sat, 19 Feb 2022 13:52:29 +0000 Subject: [PATCH 2/3] Hipify bug fix for header_include_paths being passed in as None from JIT path --- torch/utils/cpp_extension.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index 64a174c2b3b2..955494361a9b 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -1352,7 +1352,7 @@ def _jit_compile(name, hipify_result = hipify_python.hipify( project_directory=build_directory, output_directory=build_directory, - header_include_dirs=extra_include_paths, + header_include_dirs=(extra_include_paths if extra_include_paths is not None else []), extra_files=[os.path.abspath(s) for s in sources], ignores=[os.path.join(ROCM_HOME, '*'), os.path.join(_TORCH_PATH, '*')], # no need to hipify ROCm or PyTorch headers show_detailed=verbose, From 90454d3275796404a650af5563ffe0a3d2e65f9e Mon Sep 17 00:00:00 2001 From: rraminen Date: Thu, 21 Apr 2022 16:21:36 -0400 Subject: [PATCH 3/3] Fixed lint errors --- torch/utils/cpp_extension.py | 4 ++-- torch/utils/hipify/hipify_python.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index 955494361a9b..043d8834b09a 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -16,7 +16,7 @@ from .file_baton import FileBaton from ._cpp_extension_versioner import ExtensionVersioner from .hipify import hipify_python -from .hipify.hipify_python import GeneratedFileCleaner, get_hip_file_path +from .hipify.hipify_python import GeneratedFileCleaner from typing import List, Optional, Union, Tuple from torch.torch_version import TorchVersion @@ -1354,7 +1354,7 @@ def _jit_compile(name, output_directory=build_directory, header_include_dirs=(extra_include_paths if extra_include_paths is not None else []), extra_files=[os.path.abspath(s) for s in sources], - ignores=[os.path.join(ROCM_HOME, '*'), os.path.join(_TORCH_PATH, '*')], # no need to hipify ROCm or PyTorch headers + ignores=[_join_rocm_home('*'), os.path.join(_TORCH_PATH, '*')], # no need to hipify ROCm or PyTorch headers show_detailed=verbose, show_progress=verbose, is_pytorch_extension=True, diff --git a/torch/utils/hipify/hipify_python.py b/torch/utils/hipify/hipify_python.py index 25014dd7b4c7..19834696827a 100755 --- a/torch/utils/hipify/hipify_python.py +++ b/torch/utils/hipify/hipify_python.py @@ -178,7 +178,7 @@ def preprocess_file_and_save_result( fin_path = os.path.abspath(os.path.join(output_directory, filepath)) # Show what happened - if show_progress and "ignored" not in result["status"]: + if show_progress and "ignored" not in str(result["status"]): print( fin_path, "->", result["hipified_path"], result["status"], flush=True) @@ -819,7 +819,7 @@ def repl(m): all_files, header_include_dirs, stats, hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress) hipified_header_filepath = HIPIFY_FINAL_RESULT[header_filepath]["hipified_path"] - return templ.format(os.path.relpath(hipified_header_filepath if hipified_header_filepath is not None + return templ.format(os.path.relpath(hipified_header_filepath if hipified_header_filepath is not None else header_filepath, header_dir)) return m.group(0)