Skip to content

Commit

Permalink
Only use 1 lock
Browse files Browse the repository at this point in the history
  • Loading branch information
JeffLIrion committed Jul 5, 2021
1 parent 9d58e3f commit 6d78057
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 67 deletions.
59 changes: 14 additions & 45 deletions adb_shell/adb_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,45 +93,34 @@
class _AdbIOManager(object):
"""A class for handling all ADB I/O.
Notes
-----
The only places where the ``self._store_lock`` and ``self._transport_lock`` locks are held at the same time are :meth:`_AdbIOManager.close` and
:meth:`_AdbIOManager.connect`. In both places, the ``self._transport_lock`` is acquired first. Therefore, there is no potential for deadlock.
Parameters
----------
transport : BaseTransport
A transport for communicating with the device; must be an instance of a subclass of :class:`~adb_shell.transport.base_transport.BaseTransport`
Attributes
----------
_lock : Lock
A lock for protecting the packet_store and the transport
_packet_store : _AdbPacketStore
A store for holding packets that correspond to different ADB streams
_store_lock : Lock
A lock for protecting ``self._packet_store`` (this lock is never held for long)
_transport : BaseTransport
A transport for communicating with the device; must be an instance of a subclass of :class:`~adb_shell.transport.base_transport.BaseTransport`
_transport_lock : Lock
A lock for protecting ``self._transport``
"""

def __init__(self, transport):
self._lock = Lock()
self._packet_store = _AdbPacketStore()
self._transport = transport

self._store_lock = Lock()
self._transport_lock = Lock()

def close(self):
"""Close the connection via the provided transport's ``close()`` method and clear the packet store.
"""
with self._transport_lock:
with self._lock:
self._transport.close()

with self._store_lock:
self._packet_store.clear_all()
self._packet_store.clear_all()

def connect(self, banner, rsa_keys, auth_timeout_s, auth_callback, adb_info):
"""Establish an ADB connection to the device.
Expand Down Expand Up @@ -180,12 +169,10 @@ def connect(self, banner, rsa_keys, auth_timeout_s, auth_callback, adb_info):
Invalid auth response from the device
"""
with self._transport_lock:
with self._lock:
# 0. Close the connection and clear the store
self._transport.close()

with self._store_lock:
self._packet_store.clear_all()
self._packet_store.clear_all()

# 1. Use the transport to establish a connection
self._transport.connect(adb_info.transport_timeout_s)
Expand Down Expand Up @@ -281,10 +268,9 @@ def read(self, expected_cmds, adb_info, allow_zeros=False):
start = time.time()

while True:
# Should both locks be held here?
with self._lock:
# Read packets from the store until we find a match or there are no more entries

# Read packets from the store until we find a match or there are no more entries
with self._store_lock:
# Recall that `arg0` from the device corresponds to `adb_info.remote_id` and `arg1` from the device corresponds to `adb_info.local_id`
arg0_arg1 = self._packet_store.find(adb_info.remote_id, adb_info.local_id) if not allow_zeros else self._packet_store.find_allow_zeros(adb_info.remote_id, adb_info.local_id)
while arg0_arg1:
Expand All @@ -294,43 +280,26 @@ def read(self, expected_cmds, adb_info, allow_zeros=False):

arg0_arg1 = self._packet_store.find(adb_info.remote_id, adb_info.local_id) if not allow_zeros else self._packet_store.find_allow_zeros(adb_info.remote_id, adb_info.local_id)

# Read from the device
with self._transport_lock:
# Read from the device
cmd, arg0, arg1, data = self._read_packet_from_device(adb_info)

if not adb_info.args_match(arg0, arg1, allow_zeros):
# The packet is not a match -> put it in the store
with self._store_lock:
self._packet_store.put(arg0, arg1, cmd, data)
self._packet_store.put(arg0, arg1, cmd, data)

else:
# The packet is a match for this `(adb_info.local_id, adb_info.remote_id)` pair
if cmd == constants.CLSE:
# Clear the entry in the store
with self._store_lock:
self._packet_store.clear(arg0, arg1)
self._packet_store.clear(arg0, arg1)

# If `cmd` is a match, then we are done
if cmd in expected_cmds:
return cmd, arg0, arg1, data

# Check if time is up
if time.time() - start > adb_info.read_timeout_s:
break

# Try one last time to read packets from the store before throwing a timeout exception
with self._store_lock:
# Recall that `arg0` from the device corresponds to `adb_info.remote_id` and `arg1` from the device corresponds to `adb_info.local_id`
arg0_arg1 = self._packet_store.find(adb_info.remote_id, adb_info.local_id) if not allow_zeros else self._packet_store.find_allow_zeros(adb_info.remote_id, adb_info.local_id)
while arg0_arg1:
cmd, arg0, arg1, data = self._packet_store.get(arg0_arg1[0], arg0_arg1[1])
if cmd in expected_cmds:
return cmd, arg0, arg1, data

arg0_arg1 = self._packet_store.find(adb_info.remote_id, adb_info.local_id) if not allow_zeros else self._packet_store.find_allow_zeros(adb_info.remote_id, adb_info.local_id)

# Timeout
raise exceptions.AdbTimeoutError("Never got one of the expected responses: {} (transport_timeout_s = {}, read_timeout_s = {})".format(expected_cmds, adb_info.transport_timeout_s, adb_info.read_timeout_s))
raise exceptions.AdbTimeoutError("Never got one of the expected responses: {} (transport_timeout_s = {}, read_timeout_s = {})".format(expected_cmds, adb_info.transport_timeout_s, adb_info.read_timeout_s))

def send(self, msg, adb_info):
"""Send a message to the device.
Expand All @@ -347,7 +316,7 @@ def send(self, msg, adb_info):
Info and settings for this ADB transaction
"""
with self._transport_lock:
with self._lock:
self._send(msg, adb_info)

def _read_expected_packet_from_device(self, expected_cmds, adb_info):
Expand Down
22 changes: 0 additions & 22 deletions tests/test_adb_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,28 +344,6 @@ def test_shell_multiple_streams(self):
self.assertEqual(self.device.shell('TEST1'), 'PASS1')
self.assertEqual(self.device.shell('TEST2'), 'PASS2')

def test_shell_multiple_streams2(self):
self.assertTrue(self.device.connect())

def fake_read_packet_from_device(*args, **kwargs):
# Mimic the scenario that this stream's packets get read by another stream after checking the store and while waiting to acquire the transport lock
self.device._io_manager._packet_store.put(arg0=1, arg1=1, cmd=constants.WRTE, data=b'\x00')
self.device._io_manager._packet_store.put(arg0=1, arg1=1, cmd=constants.OKAY, data=b'\x00')
self.device._io_manager._packet_store.put(arg0=2, arg1=2, cmd=constants.OKAY, data=b'\x00')
self.device._io_manager._packet_store.put(arg0=1, arg1=1, cmd=constants.OKAY, data=b'\x00')
self.device._io_manager._packet_store.put(arg0=2, arg1=2, cmd=constants.WRTE, data=b'PASS2')
self.device._io_manager._packet_store.put(arg0=1, arg1=1, cmd=constants.WRTE, data=b"PASS1")
self.device._io_manager._packet_store.put(arg0=1, arg1=1, cmd=constants.CLSE, data=b"")
self.device._io_manager._packet_store.put(arg0=2, arg1=2, cmd=constants.CLSE, data=b"")

return constants.OKAY, 2, 2, b"\x00"

with patch.object(self.device._io_manager, "_read_packet_from_device", fake_read_packet_from_device):
# Use a negative timeout in order to only allow one attempt to read a packet from the device
# (All subsequent packets will be retrieved from the store)
self.assertEqual(self.device.shell('TEST1', read_timeout_s=-1), 'PASS1')
self.assertEqual(self.device.shell('TEST2', read_timeout_s=-1), 'PASS2')

# ======================================================================= #
# #
# `shell` error tests #
Expand Down

0 comments on commit 6d78057

Please sign in to comment.