Skip to content

Commit

Permalink
Use a context manager to acquire and release the ADB lock (#135)
Browse files Browse the repository at this point in the history
* Use a context manager to acquire and release the ADB lock

* Add another check

* Add test to make sure the lock gets released

* Test that the lock is not released if it is not acquired
  • Loading branch information
JeffLIrion committed Dec 14, 2019
1 parent 106415c commit f0c2653
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 50 deletions.
93 changes: 51 additions & 42 deletions androidtv/adb_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""


from contextlib import contextmanager
import logging
import sys
import threading
Expand All @@ -25,6 +26,30 @@
FileNotFoundError = IOError # pylint: disable=redefined-builtin


@contextmanager
def _acquire(lock):
"""Handle acquisition and release of a ``threading.Lock`` object with ``LOCK_KWARGS`` keyword arguments.
Parameters
----------
lock : threading.Lock
The lock that we will try to acquire
Yields
------
acquired : bool
Whether or not the lock was acquired
"""
try:
acquired = lock.acquire(**LOCK_KWARGS)
yield acquired

finally:
if acquired:
lock.release()


class ADBPython(object):
"""A manager for ADB connections that uses a Python implementation of the ADB protocol.
Expand Down Expand Up @@ -84,9 +109,8 @@ def connect(self, always_log_errors=True, auth_timeout_s=DEFAULT_AUTH_TIMEOUT_S)
Whether or not the connection was successfully established and the device is available
"""
if self._adb_lock.acquire(**LOCK_KWARGS): # pylint: disable=unexpected-keyword-arg
# Make sure that we release the lock
try:
with _acquire(self._adb_lock) as acquired:
if acquired:
# Catch exceptions
try:
# Connect with authentication
Expand Down Expand Up @@ -135,9 +159,6 @@ def connect(self, always_log_errors=True, auth_timeout_s=DEFAULT_AUTH_TIMEOUT_S)
self._available = False
return False

finally:
self._adb_lock.release()

# Lock could not be acquired
_LOGGER.warning("Couldn't connect to %s:%d because adb-shell lock not acquired.", self.host, self.port)
self.close()
Expand All @@ -159,12 +180,11 @@ def pull(self, local_path, device_path):
_LOGGER.debug("ADB command not sent to %s:%d because adb-shell connection is not established: pull(%s, %s)", self.host, self.port, local_path, device_path)
return

if self._adb_lock.acquire(**LOCK_KWARGS): # pylint: disable=unexpected-keyword-arg
_LOGGER.debug("Sending command to %s:%d via adb-shell: pull(%s, %s)", self.host, self.port, local_path, device_path)
try:
with _acquire(self._adb_lock) as acquired:
if acquired:
_LOGGER.debug("Sending command to %s:%d via adb-shell: pull(%s, %s)", self.host, self.port, local_path, device_path)
self._adb.pull(device_path, local_path)
finally:
self._adb_lock.release()
return

# Lock could not be acquired
_LOGGER.warning("ADB command not sent to %s:%d because adb-shell lock not acquired: pull(%s, %s)", self.host, self.port, local_path, device_path)
Expand All @@ -185,12 +205,11 @@ def push(self, local_path, device_path):
_LOGGER.debug("ADB command not sent to %s:%d because adb-shell connection is not established: push(%s, %s)", self.host, self.port, local_path, device_path)
return

if self._adb_lock.acquire(**LOCK_KWARGS): # pylint: disable=unexpected-keyword-arg
_LOGGER.debug("Sending command to %s:%d via adb-shell: push(%s, %s)", self.host, self.port, local_path, device_path)
try:
with _acquire(self._adb_lock) as acquired:
if acquired:
_LOGGER.debug("Sending command to %s:%d via adb-shell: push(%s, %s)", self.host, self.port, local_path, device_path)
self._adb.push(local_path, device_path)
finally:
self._adb_lock.release()
return

# Lock could not be acquired
_LOGGER.warning("ADB command not sent to %s:%d because adb-shell lock not acquired: push(%s, %s)", self.host, self.port, local_path, device_path)
Expand All @@ -214,12 +233,10 @@ def shell(self, cmd):
_LOGGER.debug("ADB command not sent to %s:%d because adb-shell connection is not established: %s", self.host, self.port, cmd)
return None

if self._adb_lock.acquire(**LOCK_KWARGS): # pylint: disable=unexpected-keyword-arg
_LOGGER.debug("Sending command to %s:%d via adb-shell: %s", self.host, self.port, cmd)
try:
with _acquire(self._adb_lock) as acquired:
if acquired:
_LOGGER.debug("Sending command to %s:%d via adb-shell: %s", self.host, self.port, cmd)
return self._adb.shell(cmd)
finally:
self._adb_lock.release()

# Lock could not be acquired
_LOGGER.warning("ADB command not sent to %s:%d because adb-shell lock not acquired: %s", self.host, self.port, cmd)
Expand Down Expand Up @@ -326,9 +343,8 @@ def connect(self, always_log_errors=True):
Whether or not the connection was successfully established and the device is available
"""
if self._adb_lock.acquire(**LOCK_KWARGS): # pylint: disable=unexpected-keyword-arg
# Make sure that we release the lock
try:
with _acquire(self._adb_lock) as acquired:
if acquired:
# Catch exceptions
try:
self._adb_client = Client(host=self.adb_server_ip, port=self.adb_server_port)
Expand Down Expand Up @@ -357,9 +373,6 @@ def connect(self, always_log_errors=True):
self._available = False
return False

finally:
self._adb_lock.release()

# Lock could not be acquired
_LOGGER.warning("Couldn't connect to %s:%d via ADB server %s:%d because pure-python-adb lock not acquired.", self.host, self.port, self.adb_server_ip, self.adb_server_port)
self.close()
Expand All @@ -381,12 +394,11 @@ def pull(self, local_path, device_path):
_LOGGER.debug("ADB command not sent to %s:%d via ADB server %s:%d because pure-python-adb connection is not established: pull(%s, %s)", self.host, self.port, self.adb_server_ip, self.adb_server_port, local_path, device_path)
return

if self._adb_lock.acquire(**LOCK_KWARGS): # pylint: disable=unexpected-keyword-arg
_LOGGER.debug("Sending command to %s:%d via ADB server %s:%d: pull(%s, %s)", self.host, self.port, self.adb_server_ip, self.adb_server_port, local_path, device_path)
try:
with _acquire(self._adb_lock) as acquired:
if acquired:
_LOGGER.debug("Sending command to %s:%d via ADB server %s:%d: pull(%s, %s)", self.host, self.port, self.adb_server_ip, self.adb_server_port, local_path, device_path)
self._adb_device.pull(device_path, local_path)
finally:
self._adb_lock.release()
return

# Lock could not be acquired
_LOGGER.warning("ADB command not sent to %s:%d via ADB server %s:%d: pull(%s, %s)", self.host, self.port, self.adb_server_ip, self.adb_server_port, local_path, device_path)
Expand All @@ -407,12 +419,11 @@ def push(self, local_path, device_path):
_LOGGER.debug("ADB command not sent to %s:%d via ADB server %s:%d because pure-python-adb connection is not established: push(%s, %s)", self.host, self.port, self.adb_server_ip, self.adb_server_port, local_path, device_path)
return

if self._adb_lock.acquire(**LOCK_KWARGS): # pylint: disable=unexpected-keyword-arg
_LOGGER.debug("Sending command to %s:%d via ADB server %s:%d: push(%s, %s)", self.host, self.port, self.adb_server_ip, self.adb_server_port, local_path, device_path)
try:
with _acquire(self._adb_lock) as acquired:
if acquired:
_LOGGER.debug("Sending command to %s:%d via ADB server %s:%d: push(%s, %s)", self.host, self.port, self.adb_server_ip, self.adb_server_port, local_path, device_path)
self._adb_device.push(local_path, device_path)
finally:
self._adb_lock.release()
return

# Lock could not be acquired
_LOGGER.warning("ADB command not sent to %s:%d via ADB server %s:%d: push(%s, %s)", self.host, self.port, self.adb_server_ip, self.adb_server_port, local_path, device_path)
Expand All @@ -436,12 +447,10 @@ def shell(self, cmd):
_LOGGER.debug("ADB command not sent to %s:%d via ADB server %s:%d because pure-python-adb connection is not established: %s", self.host, self.port, self.adb_server_ip, self.adb_server_port, cmd)
return None

if self._adb_lock.acquire(**LOCK_KWARGS): # pylint: disable=unexpected-keyword-arg
_LOGGER.debug("Sending command to %s:%d via ADB server %s:%d: %s", self.host, self.port, self.adb_server_ip, self.adb_server_port, cmd)
try:
with _acquire(self._adb_lock) as acquired:
if acquired:
_LOGGER.debug("Sending command to %s:%d via ADB server %s:%d: %s", self.host, self.port, self.adb_server_ip, self.adb_server_port, cmd)
return self._adb_device.shell(cmd)
finally:
self._adb_lock.release()

# Lock could not be acquired
_LOGGER.warning("ADB command not sent to %s:%d via ADB server %s:%d because pure-python-adb lock not acquired: %s", self.host, self.port, self.adb_server_ip, self.adb_server_port, cmd)
Expand Down
53 changes: 45 additions & 8 deletions tests/test_adb_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,20 @@ def open_priv_pub(infile):
pass


class LockedLock(object):
@staticmethod
def acquire(*args, **kwargs):
return False
class FakeLock(object):
def __init__(self, *args, **kwargs):
self._acquired = True

def acquire(self, *args, **kwargs):
return self._acquired

def release(self, *args, **kwargs):
self._acquired = True


class LockedLock(FakeLock):
def __init__(self, *args, **kwargs):
self._acquired = False


def return_empty_list(*args, **kwargs):
Expand Down Expand Up @@ -104,7 +114,7 @@ def test_connect_fail_lock(self):
"""
with patchers.patch_connect(True)[self.PATCH_KEY]:
with patch.object(self.adb, '_adb_lock', LockedLock):
with patch.object(self.adb, '_adb_lock', LockedLock()):
self.assertFalse(self.adb.connect())
self.assertFalse(self.adb.available)
self.assertFalse(self.adb._available)
Expand All @@ -119,8 +129,9 @@ def test_adb_shell_fail(self):

with patchers.patch_connect(True)[self.PATCH_KEY], patchers.patch_shell("TEST")[self.PATCH_KEY]:
self.assertTrue(self.adb.connect())
with patch.object(self.adb, '_adb_lock', LockedLock):
with patch.object(self.adb, '_adb_lock', LockedLock()):
self.assertIsNone(self.adb.shell("TEST"))
self.assertIsNone(self.adb.shell("TEST2"))

def test_adb_shell_success(self):
"""Test when an ADB shell command is successfully sent.
Expand All @@ -130,6 +141,32 @@ def test_adb_shell_success(self):
self.assertTrue(self.adb.connect())
self.assertEqual(self.adb.shell("TEST"), "TEST")

def test_adb_shell_fail_lock_released(self):
"""Test that the ADB lock gets released when an exception is raised.
"""
with patchers.patch_connect(True)[self.PATCH_KEY], patchers.patch_shell("TEST")[self.PATCH_KEY]:
self.assertTrue(self.adb.connect())

with patchers.patch_shell("TEST", error=True)[self.PATCH_KEY], patch.object(self.adb, '_adb_lock', FakeLock()):
with patch('{}.FakeLock.release'.format(__name__)) as release:
with self.assertRaises(Exception):
self.adb.shell("TEST")
assert release.called

def test_adb_shell_lock_not_acquired_not_released(self):
"""Test that the lock does not get released if it is not acquired.
"""
with patchers.patch_connect(True)[self.PATCH_KEY], patchers.patch_shell("TEST")[self.PATCH_KEY]:
self.assertTrue(self.adb.connect())
self.assertEqual(self.adb.shell("TEST"), "TEST")

with patchers.patch_shell("TEST")[self.PATCH_KEY], patch.object(self.adb, '_adb_lock', LockedLock()):
with patch('{}.LockedLock.release'.format(__name__)) as release:
self.assertIsNone(self.adb.shell("TEST"))
release.assert_not_called()

def test_adb_push_fail(self):
"""Test when an ADB push command is not executed because the device is unavailable.
Expand All @@ -143,7 +180,7 @@ def test_adb_push_fail(self):
with patchers.patch_connect(True)[self.PATCH_KEY]:
with patchers.PATCH_PUSH[self.PATCH_KEY] as patch_push:
self.assertTrue(self.adb.connect())
with patch.object(self.adb, '_adb_lock', LockedLock):
with patch.object(self.adb, '_adb_lock', LockedLock()):
self.adb.push("TEST_LOCAL_PATH", "TEST_DEVICE_PATH")
patch_push.assert_not_called()

Expand All @@ -170,7 +207,7 @@ def test_adb_pull_fail(self):
with patchers.patch_connect(True)[self.PATCH_KEY]:
with patchers.PATCH_PULL[self.PATCH_KEY] as patch_pull:
self.assertTrue(self.adb.connect())
with patch.object(self.adb, '_adb_lock', LockedLock):
with patch.object(self.adb, '_adb_lock', LockedLock()):
self.adb.pull("TEST_LOCAL_PATH", "TEST_DEVICE_PATH")
patch_pull.assert_not_called()

Expand Down

0 comments on commit f0c2653

Please sign in to comment.