Skip to content

Commit

Permalink
PyTorch Python fork fix (#291)
Browse files Browse the repository at this point in the history
* PyTorch Python fork fix

- fixes issue where forking process in PyTorch causes omnitrace/__main__.py to fail due to missing script argument

* Update source/python/omnitrace/__main__.py

Remove debugging "print" LOC
  • Loading branch information
jrmadsen committed Jun 22, 2023
1 parent 693f753 commit a85f141
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions source/python/omnitrace/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
import traceback

PY3 = sys.version_info[0] == 3

_OMNITRACE_PYTHON_SCRIPT_FILE = None

# Python 3.x compatibility utils: execfile
try:
Expand Down Expand Up @@ -270,27 +270,27 @@ def run(prof, cmd):
prof.runctx(code, globs, None)


def main():
def main(main_args=sys.argv):
"""Main function"""

opts = None
argv = None
if "--" in sys.argv:
_idx = sys.argv.index("--")
_argv = sys.argv[(_idx + 1) :]
opts = parse_args(sys.argv[1:_idx])
if "--" in main_args:
_idx = main_args.index("--")
_argv = main_args[(_idx + 1) :]
opts = parse_args(main_args[1:_idx])
argv = _argv
else:
if "-h" in sys.argv or "--help" in sys.argv:
if "-h" in main_args or "--help" in main_args:
opts = parse_args()
else:
argv = sys.argv[1:]
argv = main_args[1:]
opts = parse_args([])
if len(argv) == 0 or not os.path.isfile(argv[0]):
raise RuntimeError(
"Could not determine input script. Use '--' before "
"Could not determine input script in '{}'. Use '--' before "
"the script and its arguments to ensure correct parsing. \nE.g. "
"python -m omnitrace -- ./script.py"
"python -m omnitrace -- ./script.py".format(" ".join(argv))
)

if len(argv) > 1:
Expand Down Expand Up @@ -337,7 +337,7 @@ def main():

print("[omnitrace]> profiling: {}".format(argv))

sys.argv[:] = argv
main_args[:] = argv
if opts.setup is not None:
# Run some setup code outside of the profiler. This is good for large
# imports.
Expand All @@ -351,12 +351,14 @@ def main():

from . import Profiler, FakeProfiler

script_file = find_script(sys.argv[0])
script_file = find_script(main_args[0])
__file__ = script_file
__name__ = "__main__"
# Make sure the script's directory is on sys.path
sys.path.insert(0, os.path.dirname(script_file))

_OMNITRACE_PYTHON_SCRIPT_FILE = script_file

prof = Profiler()
fake = FakeProfiler()

Expand Down Expand Up @@ -394,7 +396,12 @@ def main():


if __name__ == "__main__":
main()
args = sys.argv
if "--" not in args and _OMNITRACE_PYTHON_SCRIPT_FILE is not None:
args = [args[0]] + ["--", _OMNITRACE_PYTHON_SCRIPT_FILE] + args[1:]
os.environ["OMNITRACE_USE_PID"] = "ON"

main(args)
from .libpyomnitrace import finalize

finalize()

0 comments on commit a85f141

Please sign in to comment.