diff --git a/adb_shell/adb_device.py b/adb_shell/adb_device.py index 4072b37..570829d 100644 --- a/adb_shell/adb_device.py +++ b/adb_shell/adb_device.py @@ -257,7 +257,7 @@ def connect(self, rsa_keys=None, transport_timeout_s=None, auth_timeout_s=consta # Services # # # # ======================================================================= # - def _service(self, service, command, transport_timeout_s=None, read_timeout_s=constants.DEFAULT_READ_TIMEOUT_S, decode=True): + def _service(self, service, command, transport_timeout_s=None, read_timeout_s=constants.DEFAULT_READ_TIMEOUT_S, timeout_s=None, decode=True): """Send an ADB command to the device. Parameters @@ -271,6 +271,8 @@ def _service(self, service, command, transport_timeout_s=None, read_timeout_s=co and :meth:`BaseTransport.bulk_write() ` read_timeout_s : float The total time in seconds to wait for a ``b'CLSE'`` or ``b'OKAY'`` command in :meth:`AdbDevice._read` + timeout_s : float, None + The total time in seconds to wait for the ADB command to finish decode : bool Whether to decode the output to utf8 before returning @@ -280,7 +282,7 @@ def _service(self, service, command, transport_timeout_s=None, read_timeout_s=co The output of the ADB command as a string if ``decode`` is True, otherwise as bytes. """ - adb_info = _AdbTransactionInfo(None, None, transport_timeout_s, read_timeout_s) + adb_info = _AdbTransactionInfo(None, None, transport_timeout_s, read_timeout_s, timeout_s) if decode: return b''.join(self._streaming_command(service, command, adb_info)).decode('utf8') return b''.join(self._streaming_command(service, command, adb_info)) @@ -317,7 +319,7 @@ def _streaming_service(self, service, command, transport_timeout_s=None, read_ti for line in stream: yield line - def root(self, transport_timeout_s=None, read_timeout_s=constants.DEFAULT_READ_TIMEOUT_S): + def root(self, transport_timeout_s=None, read_timeout_s=constants.DEFAULT_READ_TIMEOUT_S, timeout_s=None): """Gain root access. The device must be rooted in order for this to work. @@ -334,9 +336,9 @@ def root(self, transport_timeout_s=None, read_timeout_s=constants.DEFAULT_READ_T if not self.available: raise exceptions.AdbConnectionError("ADB command not sent because a connection to the device has not been established. (Did you call `AdbDevice.connect()`?)") - self._service(b'root', b'', transport_timeout_s, read_timeout_s, False) + self._service(b'root', b'', transport_timeout_s, read_timeout_s, timeout_s, False) - def shell(self, command, transport_timeout_s=None, read_timeout_s=constants.DEFAULT_READ_TIMEOUT_S, decode=True): + def shell(self, command, transport_timeout_s=None, read_timeout_s=constants.DEFAULT_READ_TIMEOUT_S, timeout_s=None, decode=True): """Send an ADB shell command to the device. Parameters @@ -348,6 +350,8 @@ def shell(self, command, transport_timeout_s=None, read_timeout_s=constants.DEFA and :meth:`BaseTransport.bulk_write() ` read_timeout_s : float The total time in seconds to wait for a ``b'CLSE'`` or ``b'OKAY'`` command in :meth:`AdbDevice._read` + timeout_s : float, None + The total time in seconds to wait for the ADB command to finish decode : bool Whether to decode the output to utf8 before returning @@ -360,7 +364,7 @@ def shell(self, command, transport_timeout_s=None, read_timeout_s=constants.DEFA if not self.available: raise exceptions.AdbConnectionError("ADB command not sent because a connection to the device has not been established. (Did you call `AdbDevice.connect()`?)") - return self._service(b'shell', command.encode('utf8'), transport_timeout_s, read_timeout_s, decode) + return self._service(b'shell', command.encode('utf8'), transport_timeout_s, read_timeout_s, timeout_s, decode) def streaming_shell(self, command, transport_timeout_s=None, read_timeout_s=constants.DEFAULT_READ_TIMEOUT_S, decode=True): """Send an ADB shell command to the device, yielding each line of output. @@ -869,6 +873,8 @@ def _read_until_close(self, adb_info): The data that was read by :meth:`AdbDevice._read_until` """ + start = time.time() + while True: cmd, data = self._read_until([constants.CLSE, constants.WRTE], adb_info) @@ -879,6 +885,10 @@ def _read_until_close(self, adb_info): yield data + # Make sure the ADB command has not timed out + if adb_info.timeout_s is not None and time.time() - start > adb_info.timeout_s: + raise exceptions.AdbTimeoutError("The command did not complete within {} seconds".format(adb_info.timeout_s)) + def _send(self, msg, adb_info): """Send a message to the device. diff --git a/adb_shell/exceptions.py b/adb_shell/exceptions.py index 655caf0..3f112d5 100644 --- a/adb_shell/exceptions.py +++ b/adb_shell/exceptions.py @@ -38,6 +38,12 @@ class AdbConnectionError(Exception): """ +class AdbTimeoutError(Exception): + """ADB command did not complete within the specified time. + + """ + + class DeviceAuthError(Exception): """Device authentication failed. diff --git a/adb_shell/hidden_helpers.py b/adb_shell/hidden_helpers.py index ca67f4c..f6395a1 100644 --- a/adb_shell/hidden_helpers.py +++ b/adb_shell/hidden_helpers.py @@ -81,6 +81,15 @@ def _open(name, mode='r'): class _AdbTransactionInfo(object): # pylint: disable=too-few-public-methods """A class for storing info and settings used during a single ADB "transaction." + Note that if ``timeout_s`` is not ``None``, then: + + :: + + self.transport_timeout_s <= self.read_timeout_s <= self.timeout_s + + If ``timeout_s`` is ``None``, the first inequality still applies. + + Parameters ---------- local_id : int @@ -94,6 +103,8 @@ class _AdbTransactionInfo(object): # pylint: disable=too-few-public-methods :meth:`BaseTransportAsync.bulk_write() ` read_timeout_s : float The total time in seconds to wait for a command in ``expected_cmds`` in :meth:`AdbDevice._read` and :meth:`AdbDeviceAsync._read` + timeout_s : float, None + The total time in seconds to wait for the ADB command to finish Attributes ---------- @@ -103,6 +114,8 @@ class _AdbTransactionInfo(object): # pylint: disable=too-few-public-methods The total time in seconds to wait for a command in ``expected_cmds`` in :meth:`AdbDevice._read` and :meth:`AdbDeviceAsync._read` remote_id : int The ID for the recipient + timeout_s : float, None + The total time in seconds to wait for the ADB command to finish transport_timeout_s : float, None Timeout in seconds for sending and receiving packets, or ``None``; see :meth:`BaseTransport.bulk_read() `, :meth:`BaseTransport.bulk_write() `, @@ -110,11 +123,12 @@ class _AdbTransactionInfo(object): # pylint: disable=too-few-public-methods :meth:`BaseTransportAsync.bulk_write() ` """ - def __init__(self, local_id, remote_id, transport_timeout_s=None, read_timeout_s=constants.DEFAULT_READ_TIMEOUT_S): + def __init__(self, local_id, remote_id, transport_timeout_s=None, read_timeout_s=constants.DEFAULT_READ_TIMEOUT_S, timeout_s=None): self.local_id = local_id self.remote_id = remote_id - self.transport_timeout_s = transport_timeout_s - self.read_timeout_s = read_timeout_s + self.timeout_s = timeout_s + self.read_timeout_s = read_timeout_s if self.timeout_s is None else min(read_timeout_s, self.timeout_s) + self.transport_timeout_s = self.read_timeout_s if transport_timeout_s is None else min(transport_timeout_s, self.read_timeout_s) class _FileSyncTransactionInfo(object): # pylint: disable=too-few-public-methods diff --git a/tests/test_adb_device.py b/tests/test_adb_device.py index 2a681d5..c9356b6 100644 --- a/tests/test_adb_device.py +++ b/tests/test_adb_device.py @@ -1,6 +1,7 @@ import logging from io import BytesIO import sys +import time import unittest from mock import mock_open, patch @@ -421,6 +422,26 @@ def test_issue29(self): self.device.shell('Android TV update command') self.device.shell('Android TV update command') + def test_shell_error_timeout(self): + self.assertTrue(self.device.connect()) + + # Provide the `bulk_read` return values + self.device._transport._bulk_read = join_messages(AdbMessage(command=constants.OKAY, arg0=1, arg1=1, data=b'\x00'), + AdbMessage(command=constants.WRTE, arg0=1, arg1=1, data=b'PA'), + AdbMessage(command=constants.WRTE, arg0=1, arg1=1, data=b'SS'), + AdbMessage(command=constants.CLSE, arg0=1, arg1=1, data=b'')) + + def fake_read_until(*args, **kwargs): + time.sleep(0.2) + return b'WRTE', b'PA' + + with patch('adb_shell.adb_device.AdbDevice._read_until', fake_read_until): + with self.assertRaises(exceptions.AdbTimeoutError): + self.device.shell('TEST', timeout_s=0.5) + + # Clear the `_bulk_read` buffer so that `self.tearDown()` passes + self.device._transport._bulk_read = b'' + # ======================================================================= # # # # `streaming_shell` tests #