diff --git a/.travis.yml b/.travis.yml index 24dec98..7002f6b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,11 +1,11 @@ language: python python: - - "2.7" + - "3.6" - "3.7" + - "3.8" install: - pip install . - pip install flake8 pylint coveralls cryptography - - python --version 2>&1 | grep -q "Python 2" && pip install mock || true script: - flake8 adb_shell/ && pylint adb_shell/ && coverage run --source adb_shell setup.py test && coverage report -m after_success: diff --git a/README.rst b/README.rst index 9f50ca1..edc8f24 100644 --- a/README.rst +++ b/README.rst @@ -32,15 +32,15 @@ Example Usage # Connect (no authentication necessary) device1 = AdbDeviceTcp('192.168.0.111', 5555, default_timeout_s=9.) - device1.connect(auth_timeout_s=0.1) + await device1.connect(auth_timeout_s=0.1) # Connect (authentication required) with open('path/to/adbkey') as f: priv = f.read() signer = PythonRSASigner('', priv) device2 = AdbDeviceTcp('192.168.0.222', 5555, default_timeout_s=9.) - device2.connect(rsa_keys=[signer], auth_timeout_s=0.1) + await device2.connect(rsa_keys=[signer], auth_timeout_s=0.1) # Send a shell command - response1 = device1.shell('echo TEST1') - response2 = device2.shell('echo TEST2') + response1 = await device1.shell('echo TEST1') + response2 = await device2.shell('echo TEST2') diff --git a/adb_shell/adb_device.py b/adb_shell/adb_device.py index 6f8c3fd..771b367 100644 --- a/adb_shell/adb_device.py +++ b/adb_shell/adb_device.py @@ -237,6 +237,7 @@ def __init__(self, handle, banner=None): self._banner = banner else: try: + # TODO: make this async / don't do I/O self._banner = bytearray(socket.gethostname(), 'utf-8') except: # noqa pylint: disable=bare-except self._banner = bytearray('unknown', 'utf-8') @@ -260,14 +261,14 @@ def available(self): """ return self._available - def close(self): + async def close(self): """Close the connection via the provided handle's ``close()`` method. """ self._available = False - self._handle.close() + await self._handle.close() - def connect(self, rsa_keys=None, timeout_s=None, auth_timeout_s=constants.DEFAULT_AUTH_TIMEOUT_S, total_timeout_s=constants.DEFAULT_TOTAL_TIMEOUT_S, auth_callback=None): + async def connect(self, rsa_keys=None, timeout_s=None, auth_timeout_s=constants.DEFAULT_AUTH_TIMEOUT_S, total_timeout_s=constants.DEFAULT_TOTAL_TIMEOUT_S, auth_callback=None): """Establish an ADB connection to the device. 1. Use the handle to establish a connection @@ -314,16 +315,16 @@ def connect(self, rsa_keys=None, timeout_s=None, auth_timeout_s=constants.DEFAUL """ # 1. Use the handle to establish a connection - self._handle.close() - self._handle.connect(timeout_s) + await self._handle.close() + await self._handle.connect(timeout_s) # 2. Send a ``b'CNXN'`` message msg = AdbMessage(constants.CNXN, constants.VERSION, constants.MAX_ADB_DATA, b'host::%s\0' % self._banner) adb_info = _AdbTransactionInfo(None, None, timeout_s, total_timeout_s) - self._send(msg, adb_info) + await self._send(msg, adb_info) # 3. Unpack the ``cmd``, ``arg0``, ``arg1``, and ``banner`` fields from the response - cmd, arg0, arg1, banner = self._read([constants.AUTH, constants.CNXN], adb_info) + cmd, arg0, arg1, banner = await self._read([constants.AUTH, constants.CNXN], adb_info) # 4. If ``cmd`` is not ``b'AUTH'``, then authentication is not necesary and so we are done if cmd != constants.AUTH: @@ -332,23 +333,23 @@ def connect(self, rsa_keys=None, timeout_s=None, auth_timeout_s=constants.DEFAUL # 5. If no ``rsa_keys`` are provided, raise an exception if not rsa_keys: - self._handle.close() + await self._handle.close() raise exceptions.DeviceAuthError('Device authentication required, no keys available.') # 6. Loop through our keys, signing the last ``banner`` that we received for rsa_key in rsa_keys: # 6.1. If the last ``arg0`` was not :const:`adb_shell.constants.AUTH_TOKEN`, raise an exception if arg0 != constants.AUTH_TOKEN: - self._handle.close() + await self._handle.close() raise exceptions.InvalidResponseError('Unknown AUTH response: %s %s %s' % (arg0, arg1, banner)) # 6.2. Sign the last ``banner`` and send it in an ``b'AUTH'`` message signed_token = rsa_key.Sign(banner) msg = AdbMessage(constants.AUTH, constants.AUTH_SIGNATURE, 0, signed_token) - self._send(msg, adb_info) + await self._send(msg, adb_info) # 6.3. Unpack the ``cmd``, ``arg0``, and ``banner`` fields from the response via :func:`adb_shell.adb_message.unpack` - cmd, arg0, _, banner = self._read([constants.CNXN, constants.AUTH], adb_info) + cmd, arg0, _, banner = await self._read([constants.CNXN, constants.AUTH], adb_info) # 6.4. If ``cmd`` is ``b'CNXN'``, return ``banner`` if cmd == constants.CNXN: @@ -364,10 +365,10 @@ def connect(self, rsa_keys=None, timeout_s=None, auth_timeout_s=constants.DEFAUL auth_callback(self) msg = AdbMessage(constants.AUTH, constants.AUTH_RSAPUBLICKEY, 0, pubkey + b'\0') - self._send(msg, adb_info) + await self._send(msg, adb_info) adb_info.timeout_s = auth_timeout_s - cmd, arg0, _, banner = self._read([constants.CNXN], adb_info) + cmd, arg0, _, banner = await self._read([constants.CNXN], adb_info) self._available = True return True # return banner @@ -376,7 +377,7 @@ def connect(self, rsa_keys=None, timeout_s=None, auth_timeout_s=constants.DEFAUL # Services # # # # ======================================================================= # - def _service(self, service, command, timeout_s=None, total_timeout_s=constants.DEFAULT_TOTAL_TIMEOUT_S, decode=True): + async def _service(self, service, command, timeout_s=None, total_timeout_s=constants.DEFAULT_TOTAL_TIMEOUT_S, decode=True): """Send an ADB command to the device. Parameters @@ -401,10 +402,10 @@ def _service(self, service, command, timeout_s=None, total_timeout_s=constants.D """ adb_info = _AdbTransactionInfo(None, None, timeout_s, total_timeout_s) if decode: - return b''.join(self._streaming_command(service, command, adb_info)).decode('utf8') - return b''.join(self._streaming_command(service, command, adb_info)) + return b''.join([x async for x in self._streaming_command(service, command, adb_info)]).decode('utf8') + return b''.join([x async for x in self._streaming_command(service, command, adb_info)]) - def _streaming_service(self, service, command, timeout_s=None, total_timeout_s=constants.DEFAULT_TOTAL_TIMEOUT_S, decode=True): + async def _streaming_service(self, service, command, timeout_s=None, total_timeout_s=constants.DEFAULT_TOTAL_TIMEOUT_S, decode=True): """Send an ADB command to the device, yielding each line of output. Parameters @@ -430,13 +431,13 @@ def _streaming_service(self, service, command, timeout_s=None, total_timeout_s=c adb_info = _AdbTransactionInfo(None, None, timeout_s, total_timeout_s) stream = self._streaming_command(service, command, adb_info) if decode: - for line in (stream_line.decode('utf8') for stream_line in stream): + async for line in (stream_line.decode('utf8') async for stream_line in stream): yield line else: - for line in stream: + async for line in stream: yield line - def shell(self, command, timeout_s=None, total_timeout_s=constants.DEFAULT_TOTAL_TIMEOUT_S, decode=True): + async def shell(self, command, timeout_s=None, total_timeout_s=constants.DEFAULT_TOTAL_TIMEOUT_S, decode=True): """Send an ADB shell command to the device. Parameters @@ -460,9 +461,9 @@ def shell(self, command, timeout_s=None, total_timeout_s=constants.DEFAULT_TOTAL if not self.available: raise exceptions.AdbConnectionError("ADB command not sent because a connection to the device has not been established. (Did you call `AdbDevice.connect()`?)") - return self._service(b'shell', command.encode('utf8'), timeout_s, total_timeout_s, decode) + return await self._service(b'shell', command.encode('utf8'), timeout_s, total_timeout_s, decode) - def streaming_shell(self, command, timeout_s=None, total_timeout_s=constants.DEFAULT_TOTAL_TIMEOUT_S, decode=True): + async def streaming_shell(self, command, timeout_s=None, total_timeout_s=constants.DEFAULT_TOTAL_TIMEOUT_S, decode=True): """Send an ADB shell command to the device, yielding each line of output. Parameters @@ -486,7 +487,7 @@ def streaming_shell(self, command, timeout_s=None, total_timeout_s=constants.DEF if not self.available: raise exceptions.AdbConnectionError("ADB command not sent because a connection to the device has not been established. (Did you call `AdbDevice.connect()`?)") - for line in self._streaming_service(b'shell', command.encode('utf8'), timeout_s, total_timeout_s, decode): + async for line in self._streaming_service(b'shell', command.encode('utf8'), timeout_s, total_timeout_s, decode): yield line # ======================================================================= # @@ -494,7 +495,7 @@ def streaming_shell(self, command, timeout_s=None, total_timeout_s=constants.DEF # FileSync # # # # ======================================================================= # - def list(self, device_path, timeout_s=None, total_timeout_s=constants.DEFAULT_TOTAL_TIMEOUT_S): + async def list(self, device_path, timeout_s=None, total_timeout_s=constants.DEFAULT_TOTAL_TIMEOUT_S): """Return a directory listing of the given path. Parameters @@ -517,23 +518,23 @@ def list(self, device_path, timeout_s=None, total_timeout_s=constants.DEFAULT_TO adb_info = _AdbTransactionInfo(None, None, timeout_s, total_timeout_s) filesync_info = _FileSyncTransactionInfo(constants.FILESYNC_LIST_FORMAT) - self._open(b'sync:', adb_info) + await self._open(b'sync:', adb_info) - self._filesync_send(constants.LIST, adb_info, filesync_info, data=device_path) + await self._filesync_send(constants.LIST, adb_info, filesync_info, data=device_path) files = [] - for cmd_id, header, filename in self._filesync_read_until([constants.DENT], [constants.DONE], adb_info, filesync_info): + async for cmd_id, header, filename in self._filesync_read_until([constants.DENT], [constants.DONE], adb_info, filesync_info): if cmd_id == constants.DONE: break mode, size, mtime = header files.append(DeviceFile(filename, mode, size, mtime)) - self._close(adb_info) + await self._close(adb_info) return files - def pull(self, device_filename, dest_file=None, progress_callback=None, timeout_s=None, total_timeout_s=constants.DEFAULT_TOTAL_TIMEOUT_S): + async def pull(self, device_filename, dest_file=None, progress_callback=None, timeout_s=None, total_timeout_s=constants.DEFAULT_TOTAL_TIMEOUT_S): """Pull a file from the device. Parameters @@ -574,9 +575,9 @@ def pull(self, device_filename, dest_file=None, progress_callback=None, timeout_ filesync_info = _FileSyncTransactionInfo(constants.FILESYNC_PULL_FORMAT) with _open(dest_file, 'wb') as dest: - self._open(b'sync:', adb_info) - self._pull(device_filename, dest, progress_callback, adb_info, filesync_info) - self._close(adb_info) + await self._open(b'sync:', adb_info) + await self._pull(device_filename, dest, progress_callback, adb_info, filesync_info) + await self._close(adb_info) if isinstance(dest, io.BytesIO): return dest.getvalue() @@ -587,7 +588,7 @@ def pull(self, device_filename, dest_file=None, progress_callback=None, timeout_ # We don't know what the path is, so we just assume it exists. return True - def _pull(self, filename, dest, progress_callback, adb_info, filesync_info): + async def _pull(self, filename, dest, progress_callback, adb_info, filesync_info): """Pull a file from the device into the file-like ``dest_file``. Parameters @@ -605,12 +606,12 @@ def _pull(self, filename, dest, progress_callback, adb_info, filesync_info): """ if progress_callback: - total_bytes = self.stat(filename)[1] + total_bytes = await self.stat(filename)[1] progress = self._handle_progress(lambda current: progress_callback(filename, current, total_bytes)) next(progress) - self._filesync_send(constants.RECV, adb_info, filesync_info, data=filename) - for cmd_id, _, data in self._filesync_read_until([constants.DATA], [constants.DONE], adb_info, filesync_info): + await self._filesync_send(constants.RECV, adb_info, filesync_info, data=filename) + async for cmd_id, _, data in self._filesync_read_until([constants.DATA], [constants.DONE], adb_info, filesync_info): if cmd_id == constants.DONE: break @@ -618,7 +619,7 @@ def _pull(self, filename, dest, progress_callback, adb_info, filesync_info): if progress_callback: progress.send(len(data)) - def push(self, source_file, device_filename, st_mode=constants.DEFAULT_PUSH_MODE, mtime=0, progress_callback=None, timeout_s=None, total_timeout_s=constants.DEFAULT_TOTAL_TIMEOUT_S): + async def push(self, source_file, device_filename, st_mode=constants.DEFAULT_PUSH_MODE, mtime=0, progress_callback=None, timeout_s=None, total_timeout_s=constants.DEFAULT_TOTAL_TIMEOUT_S): """Push a file or directory to the device. Parameters @@ -645,21 +646,21 @@ def push(self, source_file, device_filename, st_mode=constants.DEFAULT_PUSH_MODE if isinstance(source_file, str): if os.path.isdir(source_file): - self.shell("mkdir " + device_filename, timeout_s, total_timeout_s) + await self.shell("mkdir " + device_filename, timeout_s, total_timeout_s) for f in os.listdir(source_file): - self.push(os.path.join(source_file, f), device_filename + '/' + f, progress_callback=progress_callback) + await self.push(os.path.join(source_file, f), device_filename + '/' + f, progress_callback=progress_callback) return adb_info = _AdbTransactionInfo(None, None, timeout_s, total_timeout_s) filesync_info = _FileSyncTransactionInfo(constants.FILESYNC_PUSH_FORMAT) with _open(source_file, 'rb') as source: - self._open(b'sync:', adb_info) - self._push(source, device_filename, st_mode, mtime, progress_callback, adb_info, filesync_info) + await self._open(b'sync:', adb_info) + await self._push(source, device_filename, st_mode, mtime, progress_callback, adb_info, filesync_info) - self._close(adb_info) + await self._close(adb_info) - def _push(self, datafile, filename, st_mode, mtime, progress_callback, adb_info, filesync_info): + async def _push(self, datafile, filename, st_mode, mtime, progress_callback, adb_info, filesync_info): """Push a file-like object to the device. Parameters @@ -685,7 +686,7 @@ def _push(self, datafile, filename, st_mode, mtime, progress_callback, adb_info, """ fileinfo = ('{},{}'.format(filename, int(st_mode))).encode('utf-8') - self._filesync_send(constants.SEND, adb_info, filesync_info, data=fileinfo) + await self._filesync_send(constants.SEND, adb_info, filesync_info, data=fileinfo) if progress_callback: total_bytes = os.fstat(datafile.fileno()).st_size if isinstance(datafile, FILE_TYPES) else -1 @@ -695,7 +696,7 @@ def _push(self, datafile, filename, st_mode, mtime, progress_callback, adb_info, while True: data = datafile.read(constants.MAX_PUSH_DATA) if data: - self._filesync_send(constants.DATA, adb_info, filesync_info, data=data) + await self._filesync_send(constants.DATA, adb_info, filesync_info, data=data) if progress_callback: progress.send(len(data)) @@ -706,14 +707,14 @@ def _push(self, datafile, filename, st_mode, mtime, progress_callback, adb_info, mtime = int(time.time()) # DONE doesn't send data, but it hides the last bit of data in the size field. - self._filesync_send(constants.DONE, adb_info, filesync_info, size=mtime) - for cmd_id, _, data in self._filesync_read_until([], [constants.OKAY, constants.FAIL], adb_info, filesync_info): + await self._filesync_send(constants.DONE, adb_info, filesync_info, size=mtime) + async for cmd_id, _, data in self._filesync_read_until([], [constants.OKAY, constants.FAIL], adb_info, filesync_info): if cmd_id == constants.OKAY: return raise exceptions.PushFailedError(data) - def stat(self, device_filename, timeout_s=None, total_timeout_s=constants.DEFAULT_TOTAL_TIMEOUT_S): + async def stat(self, device_filename, timeout_s=None, total_timeout_s=constants.DEFAULT_TOTAL_TIMEOUT_S): """Get a file's ``stat()`` information. Parameters @@ -739,12 +740,12 @@ def stat(self, device_filename, timeout_s=None, total_timeout_s=constants.DEFAUL raise exceptions.AdbConnectionError("ADB command not sent because a connection to the device has not been established. (Did you call `AdbDevice.connect()`?)") adb_info = _AdbTransactionInfo(None, None, timeout_s, total_timeout_s) - self._open(b'sync:', adb_info) + await self._open(b'sync:', adb_info) filesync_info = _FileSyncTransactionInfo(constants.FILESYNC_STAT_FORMAT) - self._filesync_send(constants.STAT, adb_info, filesync_info, data=device_filename) - _, (mode, size, mtime), _ = self._filesync_read([constants.STAT], adb_info, filesync_info, read_data=False) - self._close(adb_info) + await self._filesync_send(constants.STAT, adb_info, filesync_info, data=device_filename) + _, (mode, size, mtime), _ = await self._filesync_read([constants.STAT], adb_info, filesync_info, read_data=False) + await self._close(adb_info) return mode, size, mtime @@ -753,7 +754,7 @@ def stat(self, device_filename, timeout_s=None, total_timeout_s=constants.DEFAUL # Hidden Methods # # # # ======================================================================= # - def _close(self, adb_info): + async def _close(self, adb_info): """Send a ``b'CLSE'`` message. .. warning:: @@ -768,10 +769,10 @@ def _close(self, adb_info): """ msg = AdbMessage(constants.CLSE, adb_info.local_id, adb_info.remote_id) - self._send(msg, adb_info) - self._read_until([constants.CLSE], adb_info) + await self._send(msg, adb_info) + await self._read_until([constants.CLSE], adb_info) - def _okay(self, adb_info): + async def _okay(self, adb_info): """Send an ``b'OKAY'`` mesage. Parameters @@ -781,9 +782,9 @@ def _okay(self, adb_info): """ msg = AdbMessage(constants.OKAY, adb_info.local_id, adb_info.remote_id) - self._send(msg, adb_info) + await self._send(msg, adb_info) - def _open(self, destination, adb_info): + async def _open(self, destination, adb_info): """Opens a new connection to the device via an ``b'OPEN'`` message. 1. :meth:`~AdbDevice._send` an ``b'OPEN'`` command to the device that specifies the ``local_id`` @@ -810,13 +811,13 @@ def _open(self, destination, adb_info): """ adb_info.local_id = 1 msg = AdbMessage(constants.OPEN, adb_info.local_id, 0, destination + b'\0') - self._send(msg, adb_info) - _, adb_info.remote_id, their_local_id, _ = self._read([constants.OKAY], adb_info) + await self._send(msg, adb_info) + _, adb_info.remote_id, their_local_id, _ = await self._read([constants.OKAY], adb_info) if adb_info.local_id != their_local_id: raise exceptions.InvalidResponseError('Expected the local_id to be {}, got {}'.format(adb_info.local_id, their_local_id)) - def _read(self, expected_cmds, adb_info): + async def _read(self, expected_cmds, adb_info): """Receive a response from the device. 1. Read a message from the device and unpack the ``cmd``, ``arg0``, ``arg1``, ``data_length``, and ``data_checksum`` fields @@ -856,7 +857,7 @@ def _read(self, expected_cmds, adb_info): start = time.time() while True: - msg = self._handle.bulk_read(constants.MESSAGE_SIZE, adb_info.timeout_s) + msg = await self._handle.bulk_read(constants.MESSAGE_SIZE, adb_info.timeout_s) _LOGGER.debug("bulk_read(%d): %s", constants.MESSAGE_SIZE, repr(msg)) cmd, arg0, arg1, data_length, data_checksum = unpack(msg) command = constants.WIRE_TO_ID.get(cmd) @@ -873,7 +874,7 @@ def _read(self, expected_cmds, adb_info): if data_length > 0: data = bytearray() while data_length > 0: - temp = self._handle.bulk_read(data_length, adb_info.timeout_s) + temp = await self._handle.bulk_read(data_length, adb_info.timeout_s) _LOGGER.debug("bulk_read(%d): %s", data_length, repr(temp)) data += temp @@ -888,7 +889,7 @@ def _read(self, expected_cmds, adb_info): return command, arg0, arg1, bytes(data) - def _read_until(self, expected_cmds, adb_info): + async def _read_until(self, expected_cmds, adb_info): """Read a packet, acknowledging any write packets. 1. Read data via :meth:`AdbDevice._read` @@ -923,7 +924,7 @@ def _read_until(self, expected_cmds, adb_info): start = time.time() while True: - cmd, remote_id2, local_id2, data = self._read(expected_cmds, adb_info) + cmd, remote_id2, local_id2, data = await self._read(expected_cmds, adb_info) if local_id2 not in (0, adb_info.local_id): raise exceptions.InterleavedDataError("We don't support multiple streams...") @@ -941,11 +942,11 @@ def _read_until(self, expected_cmds, adb_info): # Acknowledge write packets if cmd == constants.WRTE: - self._okay(adb_info) + await self._okay(adb_info) return cmd, data - def _read_until_close(self, adb_info): + async def _read_until_close(self, adb_info): """Yield packets until a ``b'CLSE'`` packet is received. 1. Read the ``cmd`` and ``data`` fields from a ``b'CLSE'`` or ``b'WRTE'`` packet via :meth:`AdbDevice._read_until` @@ -970,16 +971,16 @@ def _read_until_close(self, adb_info): """ while True: - cmd, data = self._read_until([constants.CLSE, constants.WRTE], adb_info) + cmd, data = await self._read_until([constants.CLSE, constants.WRTE], adb_info) if cmd == constants.CLSE: msg = AdbMessage(constants.CLSE, adb_info.local_id, adb_info.remote_id) - self._send(msg, adb_info) + await self._send(msg, adb_info) break yield data - def _send(self, msg, adb_info): + async def _send(self, msg, adb_info): """Send a message to the device. 1. Send the message header (:meth:`adb_shell.adb_message.AdbMessage.pack `) @@ -995,11 +996,11 @@ def _send(self, msg, adb_info): """ _LOGGER.debug("bulk_write: %s", repr(msg.pack())) - self._handle.bulk_write(msg.pack(), adb_info.timeout_s) + await self._handle.bulk_write(msg.pack(), adb_info.timeout_s) _LOGGER.debug("bulk_write: %s", repr(msg.data)) - self._handle.bulk_write(msg.data, adb_info.timeout_s) + await self._handle.bulk_write(msg.data, adb_info.timeout_s) - def _streaming_command(self, service, command, adb_info): + async def _streaming_command(self, service, command, adb_info): """One complete set of USB packets for a single command. 1. :meth:`~AdbDevice._open` a new connection to the device, where the ``destination`` parameter is ``service:command`` @@ -1026,12 +1027,12 @@ def _streaming_command(self, service, command, adb_info): The responses from the service. """ - self._open(b'%s:%s' % (service, command), adb_info) + await self._open(b'%s:%s' % (service, command), adb_info) - for data in self._read_until_close(adb_info): + async for data in self._read_until_close(adb_info): yield data - def _write(self, data, adb_info): + async def _write(self, data, adb_info): """Write a packet and expect an Ack. Parameters @@ -1043,17 +1044,17 @@ def _write(self, data, adb_info): """ msg = AdbMessage(constants.WRTE, adb_info.local_id, adb_info.remote_id, data) - self._send(msg, adb_info) + await self._send(msg, adb_info) # Expect an ack in response. - self._read_until([constants.OKAY], adb_info) + await self._read_until([constants.OKAY], adb_info) # ======================================================================= # # # # FileSync Hidden Methods # # # # ======================================================================= # - def _filesync_flush(self, adb_info, filesync_info): + async def _filesync_flush(self, adb_info, filesync_info): """Write the data in the buffer up to ``filesync_info.send_idx``, then set ``filesync_info.send_idx`` to 0. Parameters @@ -1064,10 +1065,10 @@ def _filesync_flush(self, adb_info, filesync_info): Data and storage for this FileSync transaction """ - self._write(filesync_info.send_buffer[:filesync_info.send_idx], adb_info) + await self._write(filesync_info.send_buffer[:filesync_info.send_idx], adb_info) filesync_info.send_idx = 0 - def _filesync_read(self, expected_ids, adb_info, filesync_info, read_data=True): + async def _filesync_read(self, expected_ids, adb_info, filesync_info, read_data=True): """Read ADB messages and return FileSync packets. Parameters @@ -1099,10 +1100,10 @@ def _filesync_read(self, expected_ids, adb_info, filesync_info, read_data=True): """ if filesync_info.send_idx: - self._filesync_flush(adb_info, filesync_info) + await self._filesync_flush(adb_info, filesync_info) # Read one filesync packet off the recv buffer. - header_data = self._filesync_read_buffered(filesync_info.recv_message_size, adb_info, filesync_info) + header_data = await self._filesync_read_buffered(filesync_info.recv_message_size, adb_info, filesync_info) header = struct.unpack(filesync_info.recv_message_format, header_data) # Header is (ID, ...). command_id = constants.FILESYNC_WIRE_TO_ID[header[0]] @@ -1122,11 +1123,11 @@ def _filesync_read(self, expected_ids, adb_info, filesync_info, read_data=True): # Header is (ID, ..., size). size = header[-1] - data = self._filesync_read_buffered(size, adb_info, filesync_info) + data = await self._filesync_read_buffered(size, adb_info, filesync_info) return command_id, header[1:-1], data - def _filesync_read_buffered(self, size, adb_info, filesync_info): + async def _filesync_read_buffered(self, size, adb_info, filesync_info): """Read ``size`` bytes of data from ``self.recv_buffer``. Parameters @@ -1146,14 +1147,14 @@ def _filesync_read_buffered(self, size, adb_info, filesync_info): """ # Ensure recv buffer has enough data. while len(filesync_info.recv_buffer) < size: - _, data = self._read_until([constants.WRTE], adb_info) + _, data = await self._read_until([constants.WRTE], adb_info) filesync_info.recv_buffer += data result = filesync_info.recv_buffer[:size] filesync_info.recv_buffer = filesync_info.recv_buffer[size:] return result - def _filesync_read_until(self, expected_ids, finish_ids, adb_info, filesync_info): + async def _filesync_read_until(self, expected_ids, finish_ids, adb_info, filesync_info): """Useful wrapper around :meth:`AdbDevice._filesync_read`. Parameters @@ -1178,7 +1179,7 @@ def _filesync_read_until(self, expected_ids, finish_ids, adb_info, filesync_info """ while True: - cmd_id, header, data = self._filesync_read(expected_ids + finish_ids, adb_info, filesync_info) + cmd_id, header, data = await self._filesync_read(expected_ids + finish_ids, adb_info, filesync_info) yield cmd_id, header, data # These lines are not reachable because whenever this method is called and `cmd_id` is in `finish_ids`, the code @@ -1186,7 +1187,7 @@ def _filesync_read_until(self, expected_ids, finish_ids, adb_info, filesync_info if cmd_id in finish_ids: # pragma: no cover break - def _filesync_send(self, command_id, adb_info, filesync_info, data=b'', size=0): + async def _filesync_send(self, command_id, adb_info, filesync_info, data=b'', size=0): """Send/buffer FileSync packets. Packets are buffered and only flushed when this connection is read from. All @@ -1212,7 +1213,7 @@ def _filesync_send(self, command_id, adb_info, filesync_info, data=b'', size=0): size = len(data) if not filesync_info.can_add_to_send_buffer(len(data)): - self._filesync_flush(adb_info, filesync_info) + await self._filesync_flush(adb_info, filesync_info) buf = struct.pack(b'<2I', constants.FILESYNC_ID_TO_WIRE[command_id], size) + data filesync_info.send_buffer[filesync_info.send_idx:filesync_info.send_idx + len(buf)] = buf diff --git a/adb_shell/handle/base_handle.py b/adb_shell/handle/base_handle.py index 57f3eb5..642d08f 100644 --- a/adb_shell/handle/base_handle.py +++ b/adb_shell/handle/base_handle.py @@ -32,13 +32,13 @@ class BaseHandle(ABC): """ @abstractmethod - def close(self): + async def close(self): """Close the connection. """ @abstractmethod - def connect(self, timeout_s=None): + async def connect(self, timeout_s=None): """Create a connection to the device. Parameters @@ -49,7 +49,7 @@ def connect(self, timeout_s=None): """ @abstractmethod - def bulk_read(self, numbytes, timeout_s=None): + async def bulk_read(self, numbytes, timeout_s=None): """Read data from the device. Parameters @@ -67,7 +67,7 @@ def bulk_read(self, numbytes, timeout_s=None): """ @abstractmethod - def bulk_write(self, data, timeout_s=None): + async def bulk_write(self, data, timeout_s=None): """Send data to the device. Parameters diff --git a/adb_shell/handle/tcp_handle.py b/adb_shell/handle/tcp_handle.py index 60080e1..c7bb946 100644 --- a/adb_shell/handle/tcp_handle.py +++ b/adb_shell/handle/tcp_handle.py @@ -30,8 +30,7 @@ """ -import select -import socket +import asyncio from .base_handle import BaseHandle from ..exceptions import TcpTimeoutException @@ -51,14 +50,16 @@ class TcpHandle(BaseHandle): Attributes ---------- - _connection : socket.socket, None - A socket connection to the device _default_timeout_s : float, None Default timeout in seconds for TCP packets, or ``None`` _host : str The address of the device; may be an IP address or a host name _port : int The device port to which we are connecting (default is 5555) + _reader : StreamReader, None + TODO + _writer : StreamWriter, None + TODO """ def __init__(self, host, port=5555, default_timeout_s=None): @@ -66,22 +67,24 @@ def __init__(self, host, port=5555, default_timeout_s=None): self._port = port self._default_timeout_s = default_timeout_s - self._connection = None + self._reader = None + self._writer = None - def close(self): + async def close(self): """Close the socket connection. """ - if self._connection: + if self._writer: try: - self._connection.shutdown(socket.SHUT_RDWR) + self._writer.close() + await self._writer.wait_closed() except OSError: pass - self._connection.close() - self._connection = None + self._reader = None + self._writer = None - def connect(self, timeout_s=None): + async def connect(self, timeout_s=None): """Create a socket connection to the device. Parameters @@ -91,13 +94,14 @@ def connect(self, timeout_s=None): """ timeout = self._default_timeout_s if timeout_s is None else timeout_s - self._connection = socket.create_connection((self._host, self._port), timeout=timeout) - if timeout: - # Put the socket in non-blocking mode - # https://docs.python.org/3/library/socket.html#socket.socket.settimeout - self._connection.setblocking(0) - def bulk_read(self, numbytes, timeout_s=None): + try: + self._reader, self._writer = await asyncio.wait_for(asyncio.open_connection(self._host, self._port), timeout) + except asyncio.TimeoutError: + msg = 'Connecting to {}:{} timed out ({} seconds)'.format(self._host, self._port, timeout) + raise TcpTimeoutException(msg) + + async def bulk_read(self, numbytes, timeout_s=None): """Receive data from the socket. Parameters @@ -119,14 +123,14 @@ def bulk_read(self, numbytes, timeout_s=None): """ timeout = self._default_timeout_s if timeout_s is None else timeout_s - readable, _, _ = select.select([self._connection], [], [], timeout) - if readable: - return self._connection.recv(numbytes) - msg = 'Reading from {}:{} timed out ({} seconds)'.format(self._host, self._port, timeout) - raise TcpTimeoutException(msg) + try: + return await asyncio.wait_for(self._reader.read(numbytes), timeout) + except asyncio.TimeoutError: + msg = 'Reading from {}:{} timed out ({} seconds)'.format(self._host, self._port, timeout) + raise TcpTimeoutException(msg) - def bulk_write(self, data, timeout_s=None): + async def bulk_write(self, data, timeout_s=None): """Send data to the socket. Parameters @@ -148,9 +152,11 @@ def bulk_write(self, data, timeout_s=None): """ timeout = self._default_timeout_s if timeout_s is None else timeout_s - _, writeable, _ = select.select([], [self._connection], [], timeout) - if writeable: - return self._connection.send(data) - msg = 'Sending data to {}:{} timed out after {} seconds. No data was sent.'.format(self._host, self._port, timeout) - raise TcpTimeoutException(msg) + try: + self._writer.write(data) + await asyncio.wait_for(self._writer.drain(), timeout) + return len(data) + except asyncio.TimeoutError: + msg = 'Sending data to {}:{} timed out after {} seconds. No data was sent.'.format(self._host, self._port, timeout) + raise TcpTimeoutException(msg) diff --git a/setup.py b/setup.py index 2f3588f..4577a87 100644 --- a/setup.py +++ b/setup.py @@ -15,9 +15,9 @@ packages=['adb_shell', 'adb_shell.auth', 'adb_shell.handle'], install_requires=['cryptography', 'pyasn1', 'rsa'], tests_require=['pycryptodome'], + python_requires='>=3.6', classifiers=['Operating System :: OS Independent', 'License :: OSI Approved :: Apache Software License', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 2'], + 'Programming Language :: Python :: 3'], test_suite='tests' ) diff --git a/tests/async_wrapper.py b/tests/async_wrapper.py new file mode 100644 index 0000000..049feca --- /dev/null +++ b/tests/async_wrapper.py @@ -0,0 +1,12 @@ +import asyncio + + +def _await(coro): + return asyncio.get_event_loop().run_until_complete(coro) + + +def awaiter(func): + def sync_func(*args, **kwargs): + return _await(func(*args, **kwargs)) + + return sync_func diff --git a/tests/keygen_stub.py b/tests/keygen_stub.py index 3a8c16f..db5c3fa 100644 --- a/tests/keygen_stub.py +++ b/tests/keygen_stub.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from mock import patch +from unittest.mock import patch class FileReadWrite(object): diff --git a/tests/patchers.py b/tests/patchers.py index c41dfd6..06bc735 100644 --- a/tests/patchers.py +++ b/tests/patchers.py @@ -1,4 +1,4 @@ -from mock import patch +from unittest.mock import patch from adb_shell import constants from adb_shell.adb_message import AdbMessage, unpack @@ -46,19 +46,21 @@ def __init__(self, *args, **kwargs): self._bulk_read = b'' self._bulk_write = b'' - def close(self): - self._connection = None + async def close(self): + self._reader = None + self._writer = None - def connect(self, auth_timeout_s=None): - self._connection = True + async def connect(self, auth_timeout_s=None): + self._reader = True + self._writer = True - def bulk_read(self, numbytes, timeout_s=None): + async def bulk_read(self, numbytes, timeout_s=None): num = min(numbytes, constants.MAX_ADB_DATA) ret = self._bulk_read[:num] self._bulk_read = self._bulk_read[num:] return ret - def bulk_write(self, data, timeout_s=None): + async def bulk_write(self, data, timeout_s=None): self._bulk_write += data return len(data) diff --git a/tests/test_adb_device.py b/tests/test_adb_device.py index d7b7fcc..780fe19 100644 --- a/tests/test_adb_device.py +++ b/tests/test_adb_device.py @@ -1,9 +1,9 @@ +import asyncio import logging from io import BytesIO import sys import unittest - -from mock import mock_open, patch +from unittest.mock import mock_open, patch from adb_shell import constants, exceptions from adb_shell.adb_device import AdbDevice, AdbDeviceTcp, DeviceFile @@ -12,6 +12,7 @@ from adb_shell.auth.sign_pythonrsa import PythonRSASigner from . import patchers +from .async_wrapper import awaiter from .filesync_helpers import FileSyncMessage, FileSyncListMessage, FileSyncStatMessage from .keygen_stub import open_priv_pub @@ -46,41 +47,45 @@ def setUp(self): def tearDown(self): self.assertFalse(self.device._handle._bulk_read) - def test_adb_connection_error(self): + @awaiter + async def test_adb_connection_error(self): with self.assertRaises(exceptions.AdbConnectionError): - self.device.shell('FAIL') + await self.device.shell('FAIL') with self.assertRaises(exceptions.AdbConnectionError): - ''.join(self.device.streaming_shell('FAIL')) + async_generator = self.device.streaming_shell('FAIL') + await async_generator.__anext__() with self.assertRaises(exceptions.AdbConnectionError): - self.device.list('FAIL') + await self.device.list('FAIL') with self.assertRaises(exceptions.AdbConnectionError): - self.device.push('FAIL', 'FAIL') + await self.device.push('FAIL', 'FAIL') with self.assertRaises(exceptions.AdbConnectionError): - self.device.pull('FAIL', 'FAIL') + await self.device.pull('FAIL', 'FAIL') with self.assertRaises(exceptions.AdbConnectionError): - self.device.stat('FAIL') + await self.device.stat('FAIL') self.device._handle._bulk_read = b'' - def test_init_tcp(self): + @awaiter + async def test_init_tcp(self): with patchers.PATCH_TCP_HANDLE: tcp_device = AdbDeviceTcp('host') tcp_device._handle._bulk_read = self.device._handle._bulk_read # Make sure that the `connect()` method works - self.assertTrue(tcp_device.connect()) + self.assertTrue(await tcp_device.connect()) self.assertTrue(tcp_device.available) # Clear the `_bulk_read` buffer so that `self.tearDown()` passes self.device._handle._bulk_read = b'' - def test_init_banner(self): + @awaiter + async def test_init_banner(self): device_with_banner = AdbDevice(handle=patchers.FakeTcpHandle('host', 5555), banner='banner') self.assertEqual(device_with_banner._banner, b'banner') @@ -97,21 +102,24 @@ def test_init_banner(self): # Clear the `_bulk_read` buffer so that `self.tearDown()` passes self.device._handle._bulk_read = b'' - def test_init_invalid_handle(self): + @awaiter + async def test_init_invalid_handle(self): with self.assertRaises(exceptions.InvalidHandleError): device = AdbDevice(handle=123) # Clear the `_bulk_read` buffer so that `self.tearDown()` passes self.device._handle._bulk_read = b'' - def test_available(self): + @awaiter + async def test_available(self): self.assertFalse(self.device.available) # Clear the `_bulk_read` buffer so that `self.tearDown()` passes self.device._handle._bulk_read = b'' - def test_close(self): - self.assertFalse(self.device.close()) + @awaiter + async def test_close(self): + self.assertFalse(await self.device.close()) self.assertFalse(self.device.available) # Clear the `_bulk_read` buffer so that `self.tearDown()` passes @@ -122,18 +130,21 @@ def test_close(self): # `connect` tests # # # # ======================================================================= # - def test_connect(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_connect(self): + self.assertTrue(await self.device.connect()) self.assertTrue(self.device.available) - def test_connect_no_keys(self): + @awaiter + async def test_connect_no_keys(self): self.device._handle._bulk_read = b''.join(patchers.BULK_READ_LIST_WITH_AUTH[:2]) with self.assertRaises(exceptions.DeviceAuthError): - self.device.connect() + await self.device.connect() self.assertFalse(self.device.available) - def test_connect_with_key_invalid_response(self): + @awaiter + async def test_connect_with_key_invalid_response(self): with patch('adb_shell.auth.sign_pythonrsa.open', open_priv_pub), patch('adb_shell.auth.keygen.open', open_priv_pub): keygen('tests/adbkey') signer = PythonRSASigner.FromRSAKeyPath('tests/adbkey') @@ -141,20 +152,22 @@ def test_connect_with_key_invalid_response(self): self.device._handle._bulk_read = b''.join(patchers.BULK_READ_LIST_WITH_AUTH_INVALID) with self.assertRaises(exceptions.InvalidResponseError): - self.device.connect([signer]) + await self.device.connect([signer]) self.assertFalse(self.device.available) - def test_connect_with_key(self): + @awaiter + async def test_connect_with_key(self): with patch('adb_shell.auth.sign_pythonrsa.open', open_priv_pub), patch('adb_shell.auth.keygen.open', open_priv_pub): keygen('tests/adbkey') signer = PythonRSASigner.FromRSAKeyPath('tests/adbkey') self.device._handle._bulk_read = b''.join(patchers.BULK_READ_LIST_WITH_AUTH) - self.assertTrue(self.device.connect([signer])) + self.assertTrue(await self.device.connect([signer])) - def test_connect_with_new_key(self): + @awaiter + async def test_connect_with_new_key(self): with patch('adb_shell.auth.sign_pythonrsa.open', open_priv_pub), patch('adb_shell.auth.keygen.open', open_priv_pub): keygen('tests/adbkey') signer = PythonRSASigner.FromRSAKeyPath('tests/adbkey') @@ -162,9 +175,10 @@ def test_connect_with_new_key(self): self.device._handle._bulk_read = b''.join(patchers.BULK_READ_LIST_WITH_AUTH_NEW_KEY) - self.assertTrue(self.device.connect([signer])) + self.assertTrue(await self.device.connect([signer])) - def test_connect_with_new_key_and_callback(self): + @awaiter + async def test_connect_with_new_key_and_callback(self): with patch('adb_shell.auth.sign_pythonrsa.open', open_priv_pub), patch('adb_shell.auth.keygen.open', open_priv_pub): keygen('tests/adbkey') signer = PythonRSASigner.FromRSAKeyPath('tests/adbkey') @@ -176,7 +190,7 @@ def auth_callback(device): self.device._handle._bulk_read = b''.join(patchers.BULK_READ_LIST_WITH_AUTH_NEW_KEY) - self.assertTrue(self.device.connect([signer], auth_callback=auth_callback)) + self.assertTrue(await self.device.connect([signer], auth_callback=auth_callback)) self.assertTrue(self._callback_invoked) @@ -185,17 +199,19 @@ def auth_callback(device): # `shell` tests # # # # ======================================================================= # - def test_shell_no_return(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_shell_no_return(self): + self.assertTrue(await self.device.connect()) # Provide the `bulk_read` return values self.device._handle._bulk_read = join_messages(AdbMessage(command=constants.OKAY, arg0=1, arg1=1, data=b'\x00'), AdbMessage(command=constants.CLSE, arg0=1, arg1=1, data=b'')) - self.assertEqual(self.device.shell('TEST'), '') + self.assertEqual(await self.device.shell('TEST'), '') - def test_shell_return_pass(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_shell_return_pass(self): + self.assertTrue(await self.device.connect()) # Provide the `bulk_read` return values self.device._handle._bulk_read = join_messages(AdbMessage(command=constants.OKAY, arg0=1, arg1=1, data=b'\x00'), @@ -203,10 +219,11 @@ def test_shell_return_pass(self): AdbMessage(command=constants.WRTE, arg0=1, arg1=1, data=b'SS'), AdbMessage(command=constants.CLSE, arg0=1, arg1=1, data=b'')) - self.assertEqual(self.device.shell('TEST'), 'PASS') + self.assertEqual(await self.device.shell('TEST'), 'PASS') - def test_shell_dont_decode(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_shell_dont_decode(self): + self.assertTrue(await self.device.connect()) # Provide the `bulk_read` return values self.device._handle._bulk_read = join_messages(AdbMessage(command=constants.OKAY, arg0=1, arg1=1, data=b'\x00'), @@ -214,31 +231,34 @@ def test_shell_dont_decode(self): AdbMessage(command=constants.WRTE, arg0=1, arg1=1, data=b'SS'), AdbMessage(command=constants.CLSE, arg0=1, arg1=1, data=b'')) - self.assertEqual(self.device.shell('TEST', decode=False), b'PASS') + self.assertEqual(await self.device.shell('TEST', decode=False), b'PASS') - def test_shell_data_length_exceeds_max(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_shell_data_length_exceeds_max(self): + self.assertTrue(await self.device.connect()) # Provide the `bulk_read` return values self.device._handle._bulk_read = join_messages(AdbMessage(command=constants.OKAY, arg0=1, arg1=1, data=b'\x00'), AdbMessage(command=constants.WRTE, arg0=1, arg1=1, data=b'0'*(constants.MAX_ADB_DATA+1)), AdbMessage(command=constants.CLSE, arg0=1, arg1=1, data=b'')) - self.device.shell('TEST') + await self.device.shell('TEST') self.assertTrue(True) - def test_shell_multibytes_sequence_exceeds_max(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_shell_multibytes_sequence_exceeds_max(self): + self.assertTrue(await self.device.connect()) # Provide the `bulk_read` return values self.device._handle._bulk_read = join_messages(AdbMessage(command=constants.OKAY, arg0=1, arg1=1, data=b'\x00'), AdbMessage(command=constants.WRTE, arg0=1, arg1=1, data=b'0'*(constants.MAX_ADB_DATA-1) + b'\xe3\x81\x82'), AdbMessage(command=constants.CLSE, arg0=1, arg1=1, data=b'')) - self.assertEqual(u'0'*(constants.MAX_ADB_DATA-1) + u'\u3042', self.device.shell('TEST')) + self.assertEqual(await self.device.shell('TEST'), u'0'*(constants.MAX_ADB_DATA-1) + u'\u3042') - def test_shell_with_multibytes_sequence_over_two_messages(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_shell_with_multibytes_sequence_over_two_messages(self): + self.assertTrue(await self.device.connect()) # Provide the `bulk_read` return values self.device._handle._bulk_read = join_messages(AdbMessage(command=constants.OKAY, arg0=1, arg1=1, data=b'\x00'), @@ -246,11 +266,12 @@ def test_shell_with_multibytes_sequence_over_two_messages(self): AdbMessage(command=constants.WRTE, arg0=1, arg1=1, data=b'\x81\x82'), AdbMessage(command=constants.CLSE, arg0=1, arg1=1, data=b'')) - self.assertEqual(u'\u3042', self.device.shell('TEST')) + self.assertEqual(await self.device.shell('TEST'), u'\u3042') - def test_shell_multiple_clse(self): + @awaiter + async def test_shell_multiple_clse(self): # https://github.com/JeffLIrion/adb_shell/issues/15#issuecomment-536795938 - self.assertTrue(self.device.connect()) + self.assertTrue(await self.device.connect()) # Provide the `bulk_read` return values msg1 = AdbMessage(command=constants.OKAY, arg0=1, arg1=1, data=b'') @@ -268,53 +289,58 @@ def test_shell_multiple_clse(self): msg2.data, msg3.pack()]) - self.device.shell("dumpsys power | grep 'Display Power' | grep -q 'state=ON' && echo -e '1\\c' && dumpsys power | grep mWakefulness | grep -q Awake && echo -e '1\\c' && dumpsys audio | grep paused | grep -qv 'Buffer Queue' && echo -e '1\\c' || (dumpsys audio | grep started | grep -qv 'Buffer Queue' && echo '2\\c' || echo '0\\c') && dumpsys power | grep Locks | grep 'size=' && CURRENT_APP=$(dumpsys window windows | grep mCurrentFocus) && CURRENT_APP=${CURRENT_APP#*{* * } && CURRENT_APP=${CURRENT_APP%%/*} && echo $CURRENT_APP && (dumpsys media_session | grep -A 100 'Sessions Stack' | grep -A 100 $CURRENT_APP | grep -m 1 'state=PlaybackState {' || echo) && dumpsys audio | grep '\\- STREAM_MUSIC:' -A 12") - self.assertEqual(self.device.shell('TEST'), 'PASS') + await self.device.shell("dumpsys power | grep 'Display Power' | grep -q 'state=ON' && echo -e '1\\c' && dumpsys power | grep mWakefulness | grep -q Awake && echo -e '1\\c' && dumpsys audio | grep paused | grep -qv 'Buffer Queue' && echo -e '1\\c' || (dumpsys audio | grep started | grep -qv 'Buffer Queue' && echo '2\\c' || echo '0\\c') && dumpsys power | grep Locks | grep 'size=' && CURRENT_APP=$(dumpsys window windows | grep mCurrentFocus) && CURRENT_APP=${CURRENT_APP#*{* * } && CURRENT_APP=${CURRENT_APP%%/*} && echo $CURRENT_APP && (dumpsys media_session | grep -A 100 'Sessions Stack' | grep -A 100 $CURRENT_APP | grep -m 1 'state=PlaybackState {' || echo) && dumpsys audio | grep '\\- STREAM_MUSIC:' -A 12") + self.assertEqual(await self.device.shell('TEST'), 'PASS') # ======================================================================= # # # # `shell` error tests # # # # ======================================================================= # - def test_shell_error_local_id(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_shell_error_local_id(self): + self.assertTrue(await self.device.connect()) # Provide the `bulk_read` return values self.device._handle._bulk_read = join_messages(AdbMessage(command=constants.OKAY, arg0=1, arg1=1234, data=b'\x00')) with self.assertRaises(exceptions.InvalidResponseError): - self.device.shell('TEST') + await self.device.shell('TEST') - def test_shell_error_unknown_command(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_shell_error_unknown_command(self): + self.assertTrue(await self.device.connect()) # Provide the `bulk_read` return values self.device._handle._bulk_read = join_messages(AdbMessageForTesting(command=constants.FAIL, arg0=1, arg1=1, data=b'')) with self.assertRaises(exceptions.InvalidCommandError): - self.assertEqual(self.device.shell('TEST'), '') + self.assertEqual(await self.device.shell('TEST'), '') - def test_shell_error_timeout(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_shell_error_timeout(self): + self.assertTrue(await self.device.connect()) # Provide the `bulk_read` return values self.device._handle._bulk_read = join_messages(AdbMessage(command=constants.WRTE, arg0=1, arg1=1, data=b'')) with self.assertRaises(exceptions.InvalidCommandError): - self.device.shell('TEST', total_timeout_s=-1) + await self.device.shell('TEST', total_timeout_s=-1) - def test_shell_error_timeout_multiple_clse(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_shell_error_timeout_multiple_clse(self): + self.assertTrue(await self.device.connect()) # Provide the `bulk_read` return values self.device._handle._bulk_read = join_messages(AdbMessage(command=constants.OKAY, arg0=1, arg1=1, data=b''), AdbMessage(command=constants.CLSE, arg0=2, arg1=1, data=b'')) with self.assertRaises(exceptions.InvalidCommandError): - self.device.shell('TEST', total_timeout_s=-1) + await self.device.shell('TEST', total_timeout_s=-1) - def test_shell_error_checksum(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_shell_error_checksum(self): + self.assertTrue(await self.device.connect()) # Provide the `bulk_read` return values msg1 = AdbMessage(command=constants.OKAY, arg0=1, arg1=1, data=b'\x00') @@ -322,30 +348,33 @@ def test_shell_error_checksum(self): self.device._handle._bulk_read = b''.join([msg1.pack(), msg1.data, msg2.pack(), msg2.data[:-1] + b'0']) with self.assertRaises(exceptions.InvalidChecksumError): - self.device.shell('TEST') + await self.device.shell('TEST') - def test_shell_error_local_id2(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_shell_error_local_id2(self): + self.assertTrue(await self.device.connect()) # Provide the `bulk_read` return values self.device._handle._bulk_read = join_messages(AdbMessage(command=constants.OKAY, arg0=1, arg1=1, data=b'\x00'), AdbMessage(command=constants.WRTE, arg0=1, arg1=2, data=b'PASS')) with self.assertRaises(exceptions.InterleavedDataError): - self.device.shell('TEST') - self.device.shell('TEST') + await self.device.shell('TEST') + await self.device.shell('TEST') - def test_shell_error_remote_id2(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_shell_error_remote_id2(self): + self.assertTrue(await self.device.connect()) # Provide the `bulk_read` return values self.device._handle._bulk_read = join_messages(AdbMessage(command=constants.OKAY, arg0=1, arg1=1, data=b'\x00'), AdbMessage(command=constants.WRTE, arg0=2, arg1=1, data=b'PASS')) with self.assertRaises(exceptions.InvalidResponseError): - self.device.shell('TEST') + await self.device.shell('TEST') - def test_issue29(self): + @awaiter + async def test_issue29(self): # https://github.com/JeffLIrion/adb_shell/issues/29 with patch('adb_shell.auth.sign_pythonrsa.open', open_priv_pub), patch('adb_shell.auth.keygen.open', open_priv_pub): keygen('tests/adbkey') @@ -404,27 +433,28 @@ def test_issue29(self): msg1.data, msg2.pack()]) - self.assertTrue(self.device.connect([signer])) + self.assertTrue(await self.device.connect([signer])) - self.device.shell('Android TV update command') + await self.device.shell('Android TV update command') - self.assertTrue(self.device.connect([signer])) - self.device.shell('Android TV update command') - self.device.shell('Android TV update command') - self.assertTrue(self.device.connect([signer])) - self.device.shell('Android TV update command') - self.device.shell('Android TV update command') - self.assertTrue(self.device.connect([signer])) - self.device.shell('Android TV update command') - self.device.shell('Android TV update command') + self.assertTrue(await self.device.connect([signer])) + await self.device.shell('Android TV update command') + await self.device.shell('Android TV update command') + self.assertTrue(await self.device.connect([signer])) + await self.device.shell('Android TV update command') + await self.device.shell('Android TV update command') + self.assertTrue(await self.device.connect([signer])) + await self.device.shell('Android TV update command') + await self.device.shell('Android TV update command') # ======================================================================= # # # # `streaming_shell` tests # # # # ======================================================================= # - def test_streaming_shell_decode(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_streaming_shell_decode(self): + self.assertTrue(await self.device.connect()) # Provide the `bulk_read` return values self.device._handle._bulk_read = join_messages( @@ -433,12 +463,13 @@ def test_streaming_shell_decode(self): AdbMessage(command=constants.WRTE, arg0=1, arg1=1, data=b'123'), ) - generator = self.device.streaming_shell('TEST', decode=True) - self.assertEqual('ABC', next(generator)) - self.assertEqual('123', next(generator)) + async_generator = self.device.streaming_shell('TEST', decode=True) + self.assertEqual(await async_generator.__anext__(), 'ABC') + self.assertEqual(await async_generator.__anext__(), '123') - def test_streaming_shell_dont_decode(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_streaming_shell_dont_decode(self): + self.assertTrue(await self.device.connect()) # Provide the `bulk_read` return values self.device._handle._bulk_read = join_messages( @@ -447,9 +478,9 @@ def test_streaming_shell_dont_decode(self): AdbMessage(command=constants.WRTE, arg0=1, arg1=1, data=b'123'), ) - generator = self.device.streaming_shell('TEST', decode=False) - self.assertEqual(b'ABC', next(generator)) - self.assertEqual(b'123', next(generator)) + async_generator = self.device.streaming_shell('TEST', decode=False) + self.assertEqual(await async_generator.__anext__(), b'ABC') + self.assertEqual(await async_generator.__anext__(), b'123') # ======================================================================= # @@ -457,8 +488,9 @@ def test_streaming_shell_dont_decode(self): # `filesync` tests # # # # ======================================================================= # - def test_list(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_list(self): + self.assertTrue(await self.device.connect()) self.device._handle._bulk_write = b'' # Provide the `bulk_read` return values @@ -478,11 +510,11 @@ def test_list(self): expected_result = [DeviceFile(filename=bytearray(b'file1'), mode=1, size=2, mtime=3), DeviceFile(filename=bytearray(b'file2'), mode=4, size=5, mtime=6)] - self.assertEqual(expected_result, self.device.list('/dir')) - self.assertEqual(expected_bulk_write, self.device._handle._bulk_write) + self.assertEqual(await self.device.list('/dir'), expected_result) + self.assertEqual(self.device._handle._bulk_write, expected_bulk_write) - def _test_push(self, mtime): - self.assertTrue(self.device.connect()) + async def _test_push(self, mtime): + self.assertTrue(await self.device.connect()) self.device._handle._bulk_write = b'' filedata = b'Ohayou sekai.\nGood morning world!' @@ -502,19 +534,22 @@ def _test_push(self, mtime): AdbMessage(command=constants.CLSE, arg0=1, arg1=1, data=b'')) with patch('time.time', return_value=mtime): - self.device.push(BytesIO(filedata), '/data', mtime=mtime) - self.assertEqual(expected_bulk_write, self.device._handle._bulk_write) + await self.device.push(BytesIO(filedata), '/data', mtime=mtime) + self.assertEqual(self.device._handle._bulk_write, expected_bulk_write) return True - def test_push(self): - self.assertTrue(self._test_push(100)) + @awaiter + async def test_push(self): + self.assertTrue(await self._test_push(100)) - def test_push_mtime0(self): - self.assertTrue(self._test_push(0)) + @awaiter + async def test_push_mtime0(self): + self.assertTrue(await self._test_push(0)) - def test_push_file(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_push_file(self): + self.assertTrue(await self.device.connect()) self.device._handle._bulk_write = b'' mtime = 100 @@ -535,11 +570,12 @@ def test_push_file(self): AdbMessage(command=constants.CLSE, arg0=1, arg1=1, data=b'')) with patch('adb_shell.adb_device.open', mock_open(read_data=filedata)): - self.device.push('TEST_FILE', '/data', mtime=mtime) - self.assertEqual(expected_bulk_write, self.device._handle._bulk_write) + await self.device.push('TEST_FILE', '/data', mtime=mtime) + self.assertEqual(self.device._handle._bulk_write, expected_bulk_write) - def test_push_fail(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_push_fail(self): + self.assertTrue(await self.device.connect()) self.device._handle._bulk_write = b'' mtime = 100 @@ -551,10 +587,11 @@ def test_push_fail(self): AdbMessage(command=constants.WRTE, arg0=1, arg1=1, data=join_messages(FileSyncMessage(constants.FAIL, data=b'')))) with self.assertRaises(exceptions.PushFailedError), patch('adb_shell.adb_device.open', mock_open(read_data=filedata)): - self.device.push('TEST_FILE', '/data', mtime=mtime) + await self.device.push('TEST_FILE', '/data', mtime=mtime) - def test_push_big_file(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_push_big_file(self): + self.assertTrue(await self.device.connect()) self.device._handle._bulk_write = b'' mtime = 100 @@ -580,11 +617,12 @@ def test_push_big_file(self): AdbMessage(command=constants.OKAY, arg0=1, arg1=1), AdbMessage(command=constants.CLSE, arg0=1, arg1=1, data=b'')) - self.device.push(BytesIO(filedata), '/data', mtime=mtime) - self.assertEqual(expected_bulk_write, self.device._handle._bulk_write) + await self.device.push(BytesIO(filedata), '/data', mtime=mtime) + self.assertEqual(self.device._handle._bulk_write, expected_bulk_write) - def test_push_dir(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_push_dir(self): + self.assertTrue(await self.device.connect()) mtime = 100 filedata = b'Ohayou sekai.\nGood morning world!' @@ -605,10 +643,11 @@ def test_push_dir(self): #TODO with patch('adb_shell.adb_device.open', mock_open(read_data=filedata)), patch('os.path.isdir', lambda x: x == 'TEST_DIR/'), patch('os.listdir', return_value=['TEST_FILE1', 'TEST_FILE2']): - self.device.push('TEST_DIR/', '/data', mtime=mtime) + await self.device.push('TEST_DIR/', '/data', mtime=mtime) - def test_pull(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_pull(self): + self.assertTrue(await self.device.connect()) self.device._handle._bulk_write = b'' filedata = b'Ohayou sekai.\nGood morning world!' @@ -626,11 +665,12 @@ def test_pull(self): AdbMessage(command=constants.OKAY, arg0=1, arg1=1), AdbMessage(command=constants.CLSE, arg0=1, arg1=1, data=b'')) - self.assertEqual(filedata, self.device.pull('/data')) - self.assertEqual(expected_bulk_write, self.device._handle._bulk_write) + self.assertEqual(await self.device.pull('/data'), filedata) + self.assertEqual(self.device._handle._bulk_write, expected_bulk_write) - def test_pull_file(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_pull_file(self): + self.assertTrue(await self.device.connect()) self.device._handle._bulk_write = b'' filedata = b'Ohayou sekai.\nGood morning world!' @@ -649,11 +689,12 @@ def test_pull_file(self): AdbMessage(command=constants.CLSE, arg0=1, arg1=1, data=b'')) with patch('adb_shell.adb_device.open', mock_open()), patch('os.path.exists', return_value=True): - self.assertTrue(self.device.pull('/data', 'TEST_FILE')) - self.assertEqual(expected_bulk_write, self.device._handle._bulk_write) + self.assertTrue(await self.device.pull('/data', 'TEST_FILE')) + self.assertEqual(self.device._handle._bulk_write, expected_bulk_write) - def test_pull_file_return_true(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_pull_file_return_true(self): + self.assertTrue(await self.device.connect()) self.device._handle._bulk_write = b'' filedata = b'Ohayou sekai.\nGood morning world!' @@ -672,11 +713,12 @@ def test_pull_file_return_true(self): AdbMessage(command=constants.CLSE, arg0=1, arg1=1, data=b'')) with patch('adb_shell.adb_device.open', mock_open()), patch('adb_shell.adb_device.hasattr', return_value=False): - self.assertTrue(self.device.pull('/data', 'TEST_FILE')) - self.assertEqual(expected_bulk_write, self.device._handle._bulk_write) + self.assertTrue(await self.device.pull('/data', 'TEST_FILE')) + self.assertEqual(self.device._handle._bulk_write, expected_bulk_write) - def test_pull_big_file(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_pull_big_file(self): + self.assertTrue(await self.device.connect()) self.device._handle._bulk_write = b'' filedata = b'0' * int(1.5 * constants.MAX_ADB_DATA) @@ -696,11 +738,12 @@ def test_pull_big_file(self): AdbMessage(command=constants.CLSE, arg0=1, arg1=1, data=b'')) with patch('adb_shell.adb_device.open', mock_open()), patch('os.path.exists', return_value=True): - self.assertTrue(self.device.pull('/data', 'TEST_FILE')) - self.assertEqual(expected_bulk_write, self.device._handle._bulk_write) + self.assertTrue(await self.device.pull('/data', 'TEST_FILE')) + self.assertEqual(self.device._handle._bulk_write, expected_bulk_write) - def test_stat(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_stat(self): + self.assertTrue(await self.device.connect()) self.device._handle._bulk_write = b'' # Provide the `bulk_read` return values @@ -717,16 +760,17 @@ def test_stat(self): AdbMessage(command=constants.OKAY, arg0=1, arg1=1, data=b''), AdbMessage(command=constants.CLSE, arg0=1, arg1=1, data=b'')) - self.assertEqual((1, 2, 3), self.device.stat('/data')) - self.assertEqual(expected_bulk_write, self.device._handle._bulk_write) + self.assertEqual(await self.device.stat('/data'), (1, 2, 3)) + self.assertEqual(self.device._handle._bulk_write, expected_bulk_write) # ======================================================================= # # # # `filesync` hidden methods tests # # # # ======================================================================= # - def test_filesync_read_adb_command_failure_exceptions(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_filesync_read_adb_command_failure_exceptions(self): + self.assertTrue(await self.device.connect()) self.device._handle._bulk_write = b'' # Provide the `bulk_read` return values @@ -736,10 +780,11 @@ def test_filesync_read_adb_command_failure_exceptions(self): FileSyncStatMessage(constants.DONE, 0, 0, 0)))) with self.assertRaises(exceptions.AdbCommandFailureException): - self.device.stat('/data') + await self.device.stat('/data') - def test_filesync_read_invalid_response_error(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_filesync_read_invalid_response_error(self): + self.assertTrue(await self.device.connect()) self.device._handle._bulk_write = b'' # Provide the `bulk_read` return values @@ -749,16 +794,17 @@ def test_filesync_read_invalid_response_error(self): FileSyncStatMessage(constants.DONE, 0, 0, 0)))) with self.assertRaises(exceptions.InvalidResponseError): - self.device.stat('/data') + await self.device.stat('/data') # ======================================================================= # # # # `filesync` error tests # # # # ======================================================================= # - def test_pull_value_error(self): - self.assertTrue(self.device.connect()) + @awaiter + async def test_pull_value_error(self): + self.assertTrue(await self.device.connect()) self.device._handle._bulk_write = b'' with self.assertRaises(ValueError): - self.device.pull('device_filename', 123) + await self.device.pull('device_filename', 123) diff --git a/tests/test_adb_message.py b/tests/test_adb_message.py index 68dc21c..6184fdc 100644 --- a/tests/test_adb_message.py +++ b/tests/test_adb_message.py @@ -1,7 +1,7 @@ import os import unittest -from mock import patch +from unittest.mock import patch from adb_shell import constants from adb_shell.adb_device import AdbDevice diff --git a/tests/test_keygen.py b/tests/test_keygen.py index aacbf30..ddd1a1f 100644 --- a/tests/test_keygen.py +++ b/tests/test_keygen.py @@ -1,6 +1,6 @@ import unittest -from mock import patch +from unittest.mock import patch from adb_shell.auth.keygen import get_user_info diff --git a/tests/test_sign_cryptography.py b/tests/test_sign_cryptography.py index e1f9ac8..8e3a2ea 100644 --- a/tests/test_sign_cryptography.py +++ b/tests/test_sign_cryptography.py @@ -3,7 +3,7 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes -from mock import patch +from unittest.mock import patch from adb_shell.auth.keygen import keygen from adb_shell.auth.sign_cryptography import CryptographySigner diff --git a/tests/test_sign_pycryptodome.py b/tests/test_sign_pycryptodome.py index 54cb44e..eb2cb81 100644 --- a/tests/test_sign_pycryptodome.py +++ b/tests/test_sign_pycryptodome.py @@ -1,7 +1,7 @@ import os import unittest -from mock import patch +from unittest.mock import patch from adb_shell.auth.keygen import keygen from adb_shell.auth.sign_pycryptodome import PycryptodomeAuthSigner diff --git a/tests/test_sign_pythonrsa.py b/tests/test_sign_pythonrsa.py index 1e9d50e..b801a64 100644 --- a/tests/test_sign_pythonrsa.py +++ b/tests/test_sign_pythonrsa.py @@ -1,7 +1,7 @@ import os import unittest -from mock import patch +from unittest.mock import patch from adb_shell.auth.keygen import keygen from adb_shell.auth.sign_pythonrsa import PythonRSASigner diff --git a/tests/test_tcp_handle.py b/tests/test_tcp_handle.py index 395c13e..5d62901 100644 --- a/tests/test_tcp_handle.py +++ b/tests/test_tcp_handle.py @@ -1,64 +1,113 @@ +import asyncio import unittest - -from mock import patch +from unittest.mock import patch from adb_shell.exceptions import TcpTimeoutException from adb_shell.handle.tcp_handle import TcpHandle from . import patchers +from .async_wrapper import awaiter +try: + from unittest.mock import AsyncMock +except ImportError: + from unittest.mock import MagicMock -class TestTcpHandle(unittest.TestCase): - def setUp(self): - """Create a ``TcpHandle`` and connect to a TCP service. + class AsyncMock(MagicMock): + async def __call__(self, *args, **kwargs): + return super(AsyncMock, self).__call__(*args, **kwargs) - """ - self.handle = TcpHandle('host', 5555) - with patchers.PATCH_CREATE_CONNECTION: - self.handle.connect() - def tearDown(self): - """Close the socket connection.""" - self.handle.close() +class FakeStreamWriter: + def close(self): + pass - def test_connect_with_timeout(self): - """TODO + async def wait_closed(self): + pass - """ - self.handle.close() - with patchers.PATCH_CREATE_CONNECTION: - self.handle.connect(timeout_s=1) - self.assertTrue(True) + def write(self, data): + pass - def test_bulk_read(self): - """TODO + async def drain(self): + pass - """ - # Provide the `recv` return values - self.handle._connection._recv = b'TEST1TEST2' - with patchers.PATCH_SELECT_SUCCESS: - self.assertEqual(self.handle.bulk_read(5), b'TEST1') - self.assertEqual(self.handle.bulk_read(5), b'TEST2') +class FakeStreamReader: + async def read(self, numbytes): + return b'TEST' - with patchers.PATCH_SELECT_FAIL: - with self.assertRaises(TcpTimeoutException): - self.handle.bulk_read(4) - def test_close_oserror(self): - """Test that an `OSError` exception is handled when closing the socket. - - """ - with patch('{}.patchers.FakeSocket.shutdown'.format(__name__), side_effect=OSError): - self.handle.close() - - def test_bulk_write(self): - """TODO +class TestTcpHandle(unittest.TestCase): + def setUp(self): + """Create a ``TcpHandle`` and connect to a TCP service. """ - with patchers.PATCH_SELECT_SUCCESS: - self.handle.bulk_write(b'TEST') + self.handle = TcpHandle('host', 5555) + #with patchers.PATCH_CREATE_CONNECTION: + # self.handle.connect() - with patchers.PATCH_SELECT_FAIL: - with self.assertRaises(TcpTimeoutException): - self.handle.bulk_write(b'FAIL') + '''def tearDown(self): + """Close the socket connection.""" + self.handle.close()''' + + @awaiter + async def test_close(self): + await self.handle.close() + + @awaiter + async def test_close2(self): + await self.handle.close() + + @awaiter + async def test_connect(self): + with patch('asyncio.open_connection', return_value=(True, True), new_callable=AsyncMock): + await self.handle.connect() + + @awaiter + async def test_connect_close(self): + with patch('asyncio.open_connection', return_value=(FakeStreamReader(), FakeStreamWriter()), new_callable=AsyncMock): + await self.handle.connect() + self.assertIsNotNone(self.handle._writer) + + await self.handle.close() + self.assertIsNone(self.handle._reader) + self.assertIsNone(self.handle._writer) + + @awaiter + async def test_connect_close_catch_oserror(self): + with patch('asyncio.open_connection', return_value=(FakeStreamReader(), FakeStreamWriter()), new_callable=AsyncMock): + await self.handle.connect() + self.assertIsNotNone(self.handle._writer) + + with patch('{}.FakeStreamWriter.close'.format(__name__), side_effect=OSError): + await self.handle.close() + self.assertIsNone(self.handle._reader) + self.assertIsNone(self.handle._writer) + + @awaiter + async def test_connect_with_timeout(self): + with self.assertRaises(TcpTimeoutException): + with patch('asyncio.open_connection', side_effect=asyncio.TimeoutError, new_callable=AsyncMock): + await self.handle.connect() + + @awaiter + async def test_bulk_read(self): + with patch('asyncio.open_connection', return_value=(FakeStreamReader(), FakeStreamWriter()), new_callable=AsyncMock): + await self.handle.connect() + + self.assertEqual(await self.handle.bulk_read(4), b'TEST') + + with self.assertRaises(TcpTimeoutException): + with patch('{}.FakeStreamReader.read'.format(__name__), side_effect=asyncio.TimeoutError): + await self.handle.bulk_read(4) + + @awaiter + async def test_bulk_write(self): + with patch('asyncio.open_connection', return_value=(FakeStreamReader(), FakeStreamWriter()), new_callable=AsyncMock): + await self.handle.connect() + + self.assertEqual(await self.handle.bulk_write(b'TEST'), 4) + + with self.assertRaises(TcpTimeoutException): + with patch('{}.FakeStreamWriter.write'.format(__name__), side_effect=asyncio.TimeoutError): + await self.handle.bulk_write(b'TEST')