From 7bf9e76accdf6e9e92f8e27f9c9341e47ea8489d Mon Sep 17 00:00:00 2001
From: Elias Kahl
Date: Sat, 13 May 2023 21:43:06 +0000
Subject: [PATCH] refactor proc tree iter and ps parsing

Fixes #55
Tries to also address #21, #35
---
 src/shellingham/posix/__init__.py | 28 ++++++----------
 src/shellingham/posix/proc.py     | 36 +++++++++------------
 src/shellingham/posix/ps.py       | 46 +++++++++++++++++---------
 3 files changed, 55 insertions(+), 55 deletions(-)

diff --git a/src/shellingham/posix/__init__.py b/src/shellingham/posix/__init__.py
index 3dfde8c..15f3fb5 100644
--- a/src/shellingham/posix/__init__.py
+++ b/src/shellingham/posix/__init__.py
@@ -4,8 +4,12 @@
 from .._core import SHELL_NAMES, ShellDetectionFailure
 from . import proc, ps
 
+QEMU_BIN_REGEX = re.compile(
+    r"qemu-(alpha|armeb|arm|m68k|cris|i386|x86_64|microblaze|mips|mipsel|mips64|mips64el|mipsn32|mipsn32el|nios2|ppc64|ppc|sh4eb|sh4|sparc|sparc32plus|sparc64)"
+)
 
-def _get_process_mapping():
+
+def _get_process_parents(pid, max_depth=10):
     """Select a way to obtain process information from the system.
 
     * `/proc` is used if supported.
@@ -13,25 +17,13 @@ def _get_process_mapping():
     """
     for impl in (proc, ps):
         try:
-            mapping = impl.get_process_mapping()
+            mapping = impl.get_process_parents(pid, max_depth)
         except EnvironmentError:
             continue
         return mapping
     raise ShellDetectionFailure("compatible proc fs or ps utility is required")
 
 
-def _iter_process_args(mapping, pid, max_depth):
-    """Traverse up the tree and yield each process's argument list."""
-    for _ in range(max_depth):
-        try:
-            proc = mapping[pid]
-        except KeyError:  # We've reached the root process. Give up.
-            break
-        if proc.args:  # Presumably the process should always have a name?
-            yield proc.args
-        pid = proc.ppid  # Go up one level.
-
-
 def _get_login_shell(proc_cmd):
     """Form shell information from SHELL environ if possible."""
     login_shell = os.environ.get("SHELL", "")
@@ -71,8 +63,8 @@ def _get_shell(cmd, *args):
     if cmd.startswith("-"):  # Login shell! Let's use this.
         return _get_login_shell(cmd)
     name = os.path.basename(cmd).lower()
-    if name == "rosetta" or name.contains("qemu-"):
-        # Running (probably in docker) with rosetta or qemu, first arg is real command
+    if name == "rosetta" or QEMU_BIN_REGEX.fullmatch(name):
+        # Running (probably in docker) with rosetta or qemu, first arg is actual command
         cmd = args[0]
         args = args[1:]
         name = os.path.basename(cmd).lower()
@@ -87,8 +79,8 @@ def _get_shell(cmd, *args):
 def get_shell(pid=None, max_depth=10):
     """Get the shell that the supplied pid or os.getpid() is running in."""
     pid = str(pid or os.getpid())
-    mapping = _get_process_mapping()
-    for proc_args in _iter_process_args(mapping, pid, max_depth):
+    processes = _get_process_parents(pid, max_depth)
+    for proc_args, _, _ in processes:
         shell = _get_shell(*proc_args)
         if shell:
             return shell
diff --git a/src/shellingham/posix/proc.py b/src/shellingham/posix/proc.py
index 4405731..05f7d31 100644
--- a/src/shellingham/posix/proc.py
+++ b/src/shellingham/posix/proc.py
@@ -9,11 +9,9 @@
 # NetBSD: https://man.netbsd.org/NetBSD-9.3-STABLE/mount_procfs.8
 # DragonFlyBSD: https://www.dragonflybsd.org/cgi/web-man?command=procfs
 BSD_STAT_PPID = 2
-BSD_STAT_TTY = 5
 
 # See https://docs.kernel.org/filesystems/proc.html
 LINUX_STAT_PPID = 3
-LINUX_STAT_TTY = 6
 
 STAT_PATTERN = re.compile(r"\(.+\)|\S+")
 
@@ -41,14 +39,14 @@ def _use_bsd_stat_format():
     return False
 
 
-def _get_stat(pid, name):
+def _get_ppid(pid, name):
     path = os.path.join("/proc", str(pid), name)
     with io.open(path, encoding="ascii", errors="replace") as f:
         parts = STAT_PATTERN.findall(f.read())
-    # We only care about TTY and PPID -- both are numbers.
+    # We only care about the PPID field; the TTY is no longer needed.
     if _use_bsd_stat_format():
-        return parts[BSD_STAT_TTY], parts[BSD_STAT_PPID]
-    return parts[LINUX_STAT_TTY], parts[LINUX_STAT_PPID]
+        return parts[BSD_STAT_PPID]
+    return parts[LINUX_STAT_PPID]
 
 
 def _get_cmdline(pid):
@@ -66,21 +64,17 @@ class ProcFormatError(EnvironmentError):
     pass
 
 
-def get_process_mapping():
+def get_process_parents(pid, max_depth=10):
     """Try to look up the process tree via the /proc interface."""
     stat_name = detect_proc()
-    self_tty = _get_stat(os.getpid(), stat_name)[0]
-    processes = {}
-    for pid in os.listdir("/proc"):
-        if not pid.isdigit():
-            continue
-        try:
-            tty, ppid = _get_stat(pid, stat_name)
-            if tty != self_tty:
-                continue
-            args = _get_cmdline(pid)
-            processes[pid] = Process(args=args, pid=pid, ppid=ppid)
-        except IOError:
-            # Process has disappeared - just ignore it.
-            continue
+    processes = []
+    # Walk from `pid` towards the root, visiting at most `max_depth` processes.
+    for _ in range(max_depth):
+        ppid = _get_ppid(pid, stat_name)
+        args = _get_cmdline(pid)
+        processes.append(Process(args=args, pid=pid, ppid=ppid))
+        if ppid == "0":  # No parent left to visit.
+            break
+        pid = ppid
+
     return processes
diff --git a/src/shellingham/posix/ps.py b/src/shellingham/posix/ps.py
index 3de6d25..a249b57 100644
--- a/src/shellingham/posix/ps.py
+++ b/src/shellingham/posix/ps.py
@@ -8,11 +8,11 @@
 class PsNotAvailable(EnvironmentError):
     pass
 
 
-def get_process_mapping():
-    """Try to look up the process tree via the output of `ps`."""
+def _get_stats(pid):
+    """Parse the `ps` output for `pid` into an (args, pid, ppid) triple."""
     try:
-        cmd = ["ps", "-ww", "-o", "pid=", "-o", "ppid=", "-o", "args="]
+        cmd = ["ps", "wwl", "-P", pid]
         output = subprocess.check_output(cmd)
     except OSError as e:  # Python 2-compatible FileNotFoundError.
         if e.errno != errno.ENOENT:
@@ -27,17 +27,31 @@ def get_process_mapping():
     if not isinstance(output, str):
         encoding = sys.getfilesystemencoding() or sys.getdefaultencoding()
         output = output.decode(encoding)
-    processes = {}
-    for line in output.split("\n"):
-        try:
-            pid, ppid, args = line.strip().split(None, 2)
-            # XXX: This is not right, but we are really out of options.
-            # ps does not offer a sane way to decode the argument display,
-            # and this is "Good Enough" for obtaining shell names. Hopefully
-            # people don't name their shell with a space, or have something
-            # like "/usr/bin/xonsh is uber". (sarugaku/shellingham#14)
-            args = tuple(a.strip() for a in args.split(" "))
-        except ValueError:
-            continue
-        processes[pid] = Process(args=args, pid=pid, ppid=ppid)
+
+    header, row = output.split("\n")[:2]
+    header = header.split()
+    row = row.split()
+
+    pid_index = header.index("PID")
+    ppid_index = header.index("PPID")
+
+    try:
+        cmd_index = header.index("COMMAND")
+    except ValueError:
+        # https://github.com/sarugaku/shellingham/pull/23#issuecomment-474005491
+        cmd_index = header.index("CMD")
+
+    return row[cmd_index:], row[pid_index], row[ppid_index]
+
+
+def get_process_parents(pid, max_depth=10):
+    """Try to look up the process tree via the output of `ps`."""
+    processes = []
+    for _ in range(max_depth):
+        if pid == "0":
+            break
+        cmd, pid, ppid = _get_stats(pid)
+        processes.append(Process(args=cmd, pid=pid, ppid=ppid))
+        pid = ppid
+
     return processes