Permalink
Browse files

added SignalException and caused some signals to throw an exception, c…

…loses #91
  • Loading branch information...
1 parent bdb5335 commit 6f4be63d117da098a0b91bae411a7465cdbb0cce Andrew Moffat committed Jan 30, 2013
Showing with 73 additions and 15 deletions.
  1. +3 −0 CHANGELOG.md
  2. +42 −9 sh.py
  3. +28 −6 test.py
View
@@ -3,6 +3,9 @@
## 1.08 -
+* Added SignalException class and made all commands that end terminate by
+ a signal defined in SIGNALS_THAT_SHOULD_THROW_EXCEPTION raise it. [#91](https://github.com/amoffat/sh/issues/91)
+
* Bugfix where CommandNotFound was not being raised if Command was created
by instantiation. [#113](https://github.com/amoffat/sh/issues/113)
View
51 sh.py
@@ -120,23 +120,39 @@ def __init__(self, full_cmd, stdout, stderr):
msg = "\n\n RAN: %r\n\n STDOUT:\n%s\n\n STDERR:\n%s" %\
(full_cmd, tstdout.decode(DEFAULT_ENCODING), tstderr.decode(DEFAULT_ENCODING))
super(ErrorReturnCode, self).__init__(msg)
+
+
+class SignalException(ErrorReturnCode): pass
+
+SIGNALS_THAT_SHOULD_THROW_EXCEPTION = (
+ signal.SIGKILL,
+ signal.SIGSEGV,
+ signal.SIGTERM,
+ signal.SIGINT,
+ signal.SIGQUIT
+)
# we subclass AttributeError because:
# https://github.com/ipython/ipython/issues/2577
# https://github.com/amoffat/sh/issues/97#issuecomment-10610629
class CommandNotFound(AttributeError): pass
-rc_exc_regex = re.compile("ErrorReturnCode_(\d+)")
+rc_exc_regex = re.compile("(ErrorReturnCode|SignalException)_(\d+)")
rc_exc_cache = {}
def get_rc_exc(rc):
rc = int(rc)
try: return rc_exc_cache[rc]
except KeyError: pass
- name = "ErrorReturnCode_%d" % rc
- exc = type(name, (ErrorReturnCode,), {})
+ if rc > 0:
+ name = "ErrorReturnCode_%d" % rc
+ exc = type(name, (ErrorReturnCode,), {})
+ else:
+ name = "SignalException_%d" % abs(rc)
+ exc = type(name, (SignalException,), {})
+
rc_exc_cache[rc] = exc
return exc
@@ -222,6 +238,13 @@ def __init__(self, cmd, call_args, stdin, stdout, stderr):
self.cmd = cmd
self.ran = " ".join(cmd)
self.process = None
+
+ # this flag is for whether or not we've handled the exit code (like
+ # by raising an exception). this is necessary because .wait() is called
+ # from multiple places, and wait() triggers the exit code to be
+ # processed. but we don't want to raise multiple exceptions, only
+ # one (if any at all)
+ self._handled_exit_code = False
self.should_wait = True
spawn_process = True
@@ -275,11 +298,18 @@ def wait(self):
# here we determine if we had an exception, or an error code that we weren't
# expecting to see. if we did, we create and raise an exception
def _handle_exit_code(self, code):
- if code not in self.call_args["ok_code"] and code >= 0: raise get_rc_exc(code)(
- " ".join(self.cmd),
- self.process.stdout,
- self.process.stderr
- )
+ if self._handled_exit_code: return
+ self._handled_exit_code = True
+
+ if code not in self.call_args["ok_code"] and \
+ (code > 0 or -code in SIGNALS_THAT_SHOULD_THROW_EXCEPTION):
+ raise get_rc_exc(code)(
+ " ".join(self.cmd),
+ self.process.stdout,
+ self.process.stderr
+ )
+
+
@property
def stdout(self):
@@ -1532,7 +1562,10 @@ def __getitem__(self, k):
try: return rc_exc_cache[k]
except KeyError:
m = rc_exc_regex.match(k)
- if m: return get_rc_exc(int(m.group(1)))
+ if m:
+ exit_code = int(m.group(2))
+ if m.group(1) == "SignalException": exit_code = -exit_code
+ return get_rc_exc(exit_code)
# is it a builtin?
try: return getattr(self["__builtins__"], k)
View
34 test.py
@@ -825,8 +825,11 @@ def agg(line, stdin, process):
process.terminate()
return True
- p = python(py.name, _out=agg, u=True)
- p.wait()
+ try:
+ p = python(py.name, _out=agg, u=True)
+ p.wait()
+ except sh.SignalException_15:
+ pass
self.assertEqual(p.process.exit_code, -signal.SIGTERM)
self.assertTrue("4" not in p)
@@ -836,6 +839,7 @@ def agg(line, stdin, process):
def test_stdout_callback_kill(self):
import signal
+ import sh
py = create_tmp_test("""
import sys
@@ -855,8 +859,11 @@ def agg(line, stdin, process):
process.kill()
return True
- p = python(py.name, _out=agg, u=True)
- p.wait()
+ try:
+ p = python(py.name, _out=agg, u=True)
+ p.wait()
+ except sh.SignalException_9:
+ pass
self.assertEqual(p.process.exit_code, -signal.SIGKILL)
self.assertTrue("4" not in p)
@@ -957,7 +964,7 @@ def test_piped_generator(self):
import time
for letter in "andrew":
- time.sleep(0.5)
+ time.sleep(0.6)
print(letter)
""")
@@ -1189,7 +1196,8 @@ def test_timeout(self):
sleep_for = 3
timeout = 1
started = time()
- sh.sleep(sleep_for, _timeout=timeout).wait()
+ try: sh.sleep(sleep_for, _timeout=timeout).wait()
+ except sh.SignalException_9: pass
elapsed = time() - started
self.assertTrue(abs(elapsed - timeout) < 0.1)
@@ -1356,6 +1364,20 @@ def test_shared_secial_args(self):
out2.close()
+ def test_signal_exception(self):
+ from sh import SignalException, get_rc_exc
+
+ def throw_terminate_signal():
+ py = create_tmp_test("""
+import time
+while True: time.sleep(1)
+""")
+ to_kill = python(py.name, _bg=True)
+ to_kill.terminate()
+ to_kill.wait()
+
+ self.assertRaises(get_rc_exc(-15), throw_terminate_signal)
+
if __name__ == "__main__":

0 comments on commit 6f4be63

Please sign in to comment.