Skip to content

Commit

Permalink
added SignalException and caused some signals to throw an exception, c…
Browse files Browse the repository at this point in the history
…loses #91
  • Loading branch information
Andrew Moffat committed Jan 30, 2013
1 parent bdb5335 commit 6f4be63
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 15 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@


## 1.08 - ## 1.08 -


* Added SignalException class and made all commands that end terminate by
a signal defined in SIGNALS_THAT_SHOULD_THROW_EXCEPTION raise it. [#91](https://github.com/amoffat/sh/issues/91)

* Bugfix where CommandNotFound was not being raised if Command was created * Bugfix where CommandNotFound was not being raised if Command was created
by instantiation. [#113](https://github.com/amoffat/sh/issues/113) by instantiation. [#113](https://github.com/amoffat/sh/issues/113)


Expand Down
51 changes: 42 additions & 9 deletions sh.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -120,23 +120,39 @@ def __init__(self, full_cmd, stdout, stderr):
msg = "\n\n RAN: %r\n\n STDOUT:\n%s\n\n STDERR:\n%s" %\ msg = "\n\n RAN: %r\n\n STDOUT:\n%s\n\n STDERR:\n%s" %\
(full_cmd, tstdout.decode(DEFAULT_ENCODING), tstderr.decode(DEFAULT_ENCODING)) (full_cmd, tstdout.decode(DEFAULT_ENCODING), tstderr.decode(DEFAULT_ENCODING))
super(ErrorReturnCode, self).__init__(msg) super(ErrorReturnCode, self).__init__(msg)


class SignalException(ErrorReturnCode): pass

SIGNALS_THAT_SHOULD_THROW_EXCEPTION = (
signal.SIGKILL,
signal.SIGSEGV,
signal.SIGTERM,
signal.SIGINT,
signal.SIGQUIT
)




# we subclass AttributeError because: # we subclass AttributeError because:
# https://github.com/ipython/ipython/issues/2577 # https://github.com/ipython/ipython/issues/2577
# https://github.com/amoffat/sh/issues/97#issuecomment-10610629 # https://github.com/amoffat/sh/issues/97#issuecomment-10610629
class CommandNotFound(AttributeError): pass class CommandNotFound(AttributeError): pass


rc_exc_regex = re.compile("ErrorReturnCode_(\d+)") rc_exc_regex = re.compile("(ErrorReturnCode|SignalException)_(\d+)")
rc_exc_cache = {} rc_exc_cache = {}


def get_rc_exc(rc): def get_rc_exc(rc):
rc = int(rc) rc = int(rc)
try: return rc_exc_cache[rc] try: return rc_exc_cache[rc]
except KeyError: pass except KeyError: pass


name = "ErrorReturnCode_%d" % rc if rc > 0:
exc = type(name, (ErrorReturnCode,), {}) name = "ErrorReturnCode_%d" % rc
exc = type(name, (ErrorReturnCode,), {})
else:
name = "SignalException_%d" % abs(rc)
exc = type(name, (SignalException,), {})

rc_exc_cache[rc] = exc rc_exc_cache[rc] = exc
return exc return exc


Expand Down Expand Up @@ -222,6 +238,13 @@ def __init__(self, cmd, call_args, stdin, stdout, stderr):
self.cmd = cmd self.cmd = cmd
self.ran = " ".join(cmd) self.ran = " ".join(cmd)
self.process = None self.process = None

# this flag is for whether or not we've handled the exit code (like
# by raising an exception). this is necessary because .wait() is called
# from multiple places, and wait() triggers the exit code to be
# processed. but we don't want to raise multiple exceptions, only
# one (if any at all)
self._handled_exit_code = False


self.should_wait = True self.should_wait = True
spawn_process = True spawn_process = True
Expand Down Expand Up @@ -275,11 +298,18 @@ def wait(self):
# here we determine if we had an exception, or an error code that we weren't # here we determine if we had an exception, or an error code that we weren't
# expecting to see. if we did, we create and raise an exception # expecting to see. if we did, we create and raise an exception
def _handle_exit_code(self, code): def _handle_exit_code(self, code):
if code not in self.call_args["ok_code"] and code >= 0: raise get_rc_exc(code)( if self._handled_exit_code: return
" ".join(self.cmd), self._handled_exit_code = True
self.process.stdout,
self.process.stderr if code not in self.call_args["ok_code"] and \
) (code > 0 or -code in SIGNALS_THAT_SHOULD_THROW_EXCEPTION):
raise get_rc_exc(code)(
" ".join(self.cmd),
self.process.stdout,
self.process.stderr
)




@property @property
def stdout(self): def stdout(self):
Expand Down Expand Up @@ -1532,7 +1562,10 @@ def __getitem__(self, k):
try: return rc_exc_cache[k] try: return rc_exc_cache[k]
except KeyError: except KeyError:
m = rc_exc_regex.match(k) m = rc_exc_regex.match(k)
if m: return get_rc_exc(int(m.group(1))) if m:
exit_code = int(m.group(2))
if m.group(1) == "SignalException": exit_code = -exit_code
return get_rc_exc(exit_code)


# is it a builtin? # is it a builtin?
try: return getattr(self["__builtins__"], k) try: return getattr(self["__builtins__"], k)
Expand Down
34 changes: 28 additions & 6 deletions test.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -825,8 +825,11 @@ def agg(line, stdin, process):
process.terminate() process.terminate()
return True return True


p = python(py.name, _out=agg, u=True) try:
p.wait() p = python(py.name, _out=agg, u=True)
p.wait()
except sh.SignalException_15:
pass


self.assertEqual(p.process.exit_code, -signal.SIGTERM) self.assertEqual(p.process.exit_code, -signal.SIGTERM)
self.assertTrue("4" not in p) self.assertTrue("4" not in p)
Expand All @@ -836,6 +839,7 @@ def agg(line, stdin, process):


def test_stdout_callback_kill(self): def test_stdout_callback_kill(self):
import signal import signal
import sh


py = create_tmp_test(""" py = create_tmp_test("""
import sys import sys
Expand All @@ -855,8 +859,11 @@ def agg(line, stdin, process):
process.kill() process.kill()
return True return True


p = python(py.name, _out=agg, u=True) try:
p.wait() p = python(py.name, _out=agg, u=True)
p.wait()
except sh.SignalException_9:
pass


self.assertEqual(p.process.exit_code, -signal.SIGKILL) self.assertEqual(p.process.exit_code, -signal.SIGKILL)
self.assertTrue("4" not in p) self.assertTrue("4" not in p)
Expand Down Expand Up @@ -957,7 +964,7 @@ def test_piped_generator(self):
import time import time
for letter in "andrew": for letter in "andrew":
time.sleep(0.5) time.sleep(0.6)
print(letter) print(letter)
""") """)


Expand Down Expand Up @@ -1189,7 +1196,8 @@ def test_timeout(self):
sleep_for = 3 sleep_for = 3
timeout = 1 timeout = 1
started = time() started = time()
sh.sleep(sleep_for, _timeout=timeout).wait() try: sh.sleep(sleep_for, _timeout=timeout).wait()
except sh.SignalException_9: pass
elapsed = time() - started elapsed = time() - started
self.assertTrue(abs(elapsed - timeout) < 0.1) self.assertTrue(abs(elapsed - timeout) < 0.1)


Expand Down Expand Up @@ -1356,6 +1364,20 @@ def test_shared_secial_args(self):
out2.close() out2.close()




def test_signal_exception(self):
from sh import SignalException, get_rc_exc

def throw_terminate_signal():
py = create_tmp_test("""
import time
while True: time.sleep(1)
""")
to_kill = python(py.name, _bg=True)
to_kill.terminate()
to_kill.wait()

self.assertRaises(get_rc_exc(-15), throw_terminate_signal)





if __name__ == "__main__": if __name__ == "__main__":
Expand Down

0 comments on commit 6f4be63

Please sign in to comment.