From a5b6477edc4db6d4f72c2b49bfef18b3fb29486b Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Mon, 1 Jul 2019 10:18:40 +0200 Subject: [PATCH] Respect the minimum polling interval for scheduler updates (#3096) The `aiida.engine.processes.calcjobs.manager.JobsList` container was introduced for two purposes, related to managing running calculation jobs in high-throughput mode: * It bundles the scheduler update calls for all active calculation jobs for given authentication info (the comination of a computer and user) * It ensures that consecutive scheduler update calls are separated at least by a time that is given by the `get_minimum_job_poll_interval` as defined by the computer of the authentication info. However, the final requirement was not being respected. The problem was twofold. The internal attribute `_last_updated` that records the timestamp of the last update was not being updated after the scheduler was queried for a status update. The use of the `RefObjectStore` to create and delete these `JobsLists` instances, however, was an even bigger problem. The reference store would delete the `JobsLists` instance as soon as no-one held a reference to it any more, since they are created through the `request_job_info_request` context manager of the `JobManager` that each runner has. As soon as no requests were active, the object was deleted and with it the record of the last time the scheduler was queried. The next time a request comes in a new `JobsList` would be created that straight away call the scheduler as it has no recollection the last time it was called for the given authinfo, if at all. This would result in the minimum poll interval essentially never being respected. Fundamentally, the problem lies in the fact that the data that needs to be persistent, the "last updated" timestamp, was being stored in the container `JobsList` that by implementation was made non-persistent. The solution is to simply make the `JobsList` instances live for as long as the python interpreter is alive. Each daemon runner creates a single `JobManager` instance on start up and this will now only create a new `JobsList` once for each authinfo and will keep returning the same instance for the rest of its lifetime. Co-Authored-By: Leopold Talirz --- aiida/backends/tests/__init__.py | 1 + aiida/backends/tests/engine/test_manager.py | 81 +++++++++ aiida/backends/tests/engine/test_utils.py | 68 +------- aiida/engine/daemon/execmanager.py | 2 + aiida/engine/processes/calcjobs/manager.py | 182 ++++++++++++-------- aiida/engine/processes/calcjobs/tasks.py | 12 +- aiida/engine/utils.py | 80 +-------- 7 files changed, 203 insertions(+), 223 deletions(-) create mode 100644 aiida/backends/tests/engine/test_manager.py diff --git a/aiida/backends/tests/__init__.py b/aiida/backends/tests/__init__.py index 22d7fd2299..e1468d932a 100644 --- a/aiida/backends/tests/__init__.py +++ b/aiida/backends/tests/__init__.py @@ -91,6 +91,7 @@ 'engine.daemon': ['aiida.backends.tests.engine.test_daemon'], 'engine.futures': ['aiida.backends.tests.engine.test_futures'], 'engine.launch': ['aiida.backends.tests.engine.test_launch'], + 'engine.manager': ['aiida.backends.tests.engine.test_manager'], 'engine.persistence': ['aiida.backends.tests.engine.test_persistence'], 'engine.ports': ['aiida.backends.tests.engine.test_ports'], 'engine.process': ['aiida.backends.tests.engine.test_process'], diff --git a/aiida/backends/tests/engine/test_manager.py b/aiida/backends/tests/engine/test_manager.py new file mode 100644 index 0000000000..e57b40e2aa --- /dev/null +++ b/aiida/backends/tests/engine/test_manager.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida_core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Tests for the classes in `aiida.engine.processes.calcjobs.manager`.""" +from __future__ import division +from __future__ import print_function +from __future__ import absolute_import + +import time + +import tornado + +from aiida.orm import AuthInfo, User +from aiida.backends.testbase import AiidaTestCase +from aiida.engine.processes.calcjobs.manager import JobManager, JobsList +from aiida.engine.transports import TransportQueue + + +class TestJobManager(AiidaTestCase): + """Test the `aiida.engine.processes.calcjobs.manager.JobManager` class.""" + + def setUp(self): + super(TestJobManager, self).setUp() + self.loop = tornado.ioloop.IOLoop() + self.transport_queue = TransportQueue(self.loop) + self.user = User.objects.get_default() + self.auth_info = AuthInfo(self.computer, self.user).store() + self.manager = JobManager(self.transport_queue) + + def tearDown(self): + super(TestJobManager, self).tearDown() + AuthInfo.objects.delete(self.auth_info.pk) + + def test_get_jobs_list(self): + """Test the `JobManager.get_jobs_list` method.""" + jobs_list = self.manager.get_jobs_list(self.auth_info) + self.assertIsInstance(jobs_list, JobsList) + + # Calling the method again, should return the exact same instance of `JobsList` + self.assertEqual(self.manager.get_jobs_list(self.auth_info), jobs_list) + + def test_request_job_info_update(self): + """Test the `JobManager.request_job_info_update` method.""" + with self.manager.request_job_info_update(self.auth_info, job_id=1) as request: + self.assertIsInstance(request, tornado.concurrent.Future) + + +class TestJobsList(AiidaTestCase): + """Test the `aiida.engine.processes.calcjobs.manager.JobsList` class.""" + + def setUp(self): + super(TestJobsList, self).setUp() + self.loop = tornado.ioloop.IOLoop() + self.transport_queue = TransportQueue(self.loop) + self.user = User.objects.get_default() + self.auth_info = AuthInfo(self.computer, self.user).store() + self.jobs_list = JobsList(self.auth_info, self.transport_queue) + + def tearDown(self): + super(TestJobsList, self).tearDown() + AuthInfo.objects.delete(self.auth_info.pk) + + def test_get_minimum_update_interval(self): + """Test the `JobsList.get_minimum_update_interval` method.""" + minimum_poll_interval = self.auth_info.computer.get_minimum_job_poll_interval() + self.assertEqual(self.jobs_list.get_minimum_update_interval(), minimum_poll_interval) + + def test_last_updated(self): + """Test the `JobsList.last_updated` method.""" + jobs_list = JobsList(self.auth_info, self.transport_queue) + self.assertEqual(jobs_list.last_updated, None) + + last_updated = time.time() + jobs_list = JobsList(self.auth_info, self.transport_queue, last_updated=last_updated) + self.assertEqual(jobs_list.last_updated, last_updated) diff --git a/aiida/backends/tests/engine/test_utils.py b/aiida/backends/tests/engine/test_utils.py index 8d950ee47d..1a9268acb0 100644 --- a/aiida/backends/tests/engine/test_utils.py +++ b/aiida/backends/tests/engine/test_utils.py @@ -11,13 +11,12 @@ from __future__ import absolute_import from __future__ import print_function -import unittest from tornado.ioloop import IOLoop from tornado.gen import coroutine from aiida import orm from aiida.backends.testbase import AiidaTestCase -from aiida.engine.utils import exponential_backoff_retry, RefObjectStore +from aiida.engine.utils import exponential_backoff_retry ITERATION = 0 MAX_ITERATIONS = 3 @@ -65,68 +64,3 @@ def coro(): max_attempts = MAX_ITERATIONS - 1 with self.assertRaises(RuntimeError): loop.run_sync(lambda: exponential_backoff_retry(coro, initial_interval=0.1, max_attempts=max_attempts)) - - -class RefObjectsStore(unittest.TestCase): - - def test_simple(self): - """ Test the reference counting works """ - IDENTIFIER = 'a' - OBJECT = 'my string' - obj_store = RefObjectStore() - - with obj_store.get(IDENTIFIER, lambda: OBJECT) as obj: - # Make sure we got back the same object - self.assertIs(OBJECT, obj) - - # Now check that the reference has the correct information - ref = obj_store._objects['a'] - self.assertEqual(OBJECT, ref._obj) - self.assertEqual(1, ref.count) - - # Now request the object again - with obj_store.get(IDENTIFIER) as obj2: - # ...and check the reference has had it's count upped - self.assertEqual(OBJECT, obj2) - self.assertEqual(2, ref.count) - - # Now it should have been reduced - self.assertEqual(1, ref.count) - - # Finally the store should be empty (there are no more references) - self.assertEqual(0, len(obj_store._objects)) - - def test_get_no_constructor(self): - """ - Test that trying to get an object that does exists and providing - no means to construct it fails - """ - obj_store = RefObjectStore() - with self.assertRaises(ValueError): - with obj_store.get('a'): - pass - - def test_construct(self): - """ Test that construction only gets called when used """ - IDENTIFIER = 'a' - OBJECT = 'my string' - - # Use a list for a single number so we can get references to it - times_constructed = [ - 0, - ] - - def construct(): - times_constructed[0] += 1 - return OBJECT - - obj_store = RefObjectStore() - with obj_store.get(IDENTIFIER, construct): - self.assertEqual(1, times_constructed[0]) - with obj_store.get(IDENTIFIER, construct): - self.assertEqual(1, times_constructed[0]) - - # Now the object should be removed and so another call to get - # should create - with obj_store.get(IDENTIFIER, construct): - self.assertEqual(2, times_constructed[0]) diff --git a/aiida/engine/daemon/execmanager.py b/aiida/engine/daemon/execmanager.py index 143220cacb..dccf58a886 100644 --- a/aiida/engine/daemon/execmanager.py +++ b/aiida/engine/daemon/execmanager.py @@ -230,6 +230,7 @@ def submit_calculation(calculation, transport, calc_info, script_filename): :param transport: an already opened transport to use to submit the calculation. :param calc_info: the calculation info datastructure returned by `CalcJobNode._presubmit` :param script_filename: the job launch script returned by `CalcJobNode._presubmit` + :return: the job id as returned by the scheduler `submit_from_script` call """ scheduler = calculation.computer.get_scheduler() scheduler.set_transport(transport) @@ -237,6 +238,7 @@ def submit_calculation(calculation, transport, calc_info, script_filename): workdir = calculation.get_remote_workdir() job_id = scheduler.submit_from_script(workdir, script_filename) calculation.set_job_id(job_id) + return job_id def retrieve_calculation(calculation, transport, retrieved_temporary_folder): diff --git a/aiida/engine/processes/calcjobs/manager.py b/aiida/engine/processes/calcjobs/manager.py index 3e2a3bef78..c3faf33e33 100644 --- a/aiida/engine/processes/calcjobs/manager.py +++ b/aiida/engine/processes/calcjobs/manager.py @@ -7,61 +7,82 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -Module containing utilities and classes relating to job calculations running -on systems that require transport. -""" +"""Module containing utilities and classes relating to job calculations running on systems that require transport.""" from __future__ import division from __future__ import print_function from __future__ import absolute_import import contextlib -from functools import partial import time from six import iteritems, itervalues from tornado import concurrent, gen from aiida import schedulers -from aiida.common import exceptions -from ...utils import RefObjectStore +from aiida.common import exceptions, lang +from aiida.common.log import AIIDA_LOGGER __all__ = ('JobsList', 'JobManager') class JobsList(object): # pylint: disable=useless-object-inheritance - """ - A list of submitted jobs on a machine connected to by transport based on the - authorisation information. + """Manager of calculation jobs submitted with a specific ``AuthInfo``, i.e. computer configured for a specific user. + + This container of active calculation jobs is used to update their status periodically in batches, ensuring that + even when a lot of jobs are running, the scheduler update command is not triggered for each job individually. + + In addition, the :py:class:`~aiida.orm.computers.Computer` for which the :py:class:`~aiida.orm.authinfos.AuthInfo` + is configured, can define a minimum polling interval. This class will guarantee that the time between update calls + to the scheduler is larger or equal to that minimum interval. + + Note that since each instance operates on a specific authinfo, the guarantees of batching scheduler update calls + and the limiting of number of calls per unit time, through the minimum polling interval, is only applicable for jobs + launched with that particular authinfo. If multiple authinfo instances with the same computer, have active jobs + these limitations are not respected between them, since there is no communication between ``JobsList`` instances. + See the :py:class:`~aiida.engine.processes.calcjobs.manager.JobManager` for example usage. """ - def __init__(self, authinfo, transport_queue): - """ + def __init__(self, authinfo, transport_queue, last_updated=None): + """Construct an instance for the given authinfo and transport queue. + :param authinfo: The authinfo used to check the jobs list :type authinfo: :class:`aiida.orm.AuthInfo` :param transport_queue: A transport queue :type: :class:`aiida.engine.transports.TransportQueue` + :param last_updated: initialize the last updated timestamp + :type: float """ + lang.type_check(last_updated, float, allow_none=True) + self._authinfo = authinfo self._transport_queue = transport_queue self._loop = transport_queue.loop() + self._logger = AIIDA_LOGGER.getChild('calcjobs') self._jobs_cache = {} - self._last_updated = None # type: float self._job_update_requests = {} # Mapping: {job_id: Future} + self._last_updated = last_updated self._update_handle = None - def get_minimum_update_interval(self): + @property + def logger(self): + """Return the logger configured for this instance. + + :return: the logger """ - Get the minimum interval that can be expected between updates of the list - :return: The minimum interval + return self._logger + + def get_minimum_update_interval(self): + """Get the minimum interval that should be respected between updates of the list. + + :return: the minimum interval :rtype: float """ return self._authinfo.computer.get_minimum_job_poll_interval() - def get_last_updated(self): - """ - Get the timestamp of when the list was last updated as produced by `time.time()` + @property + def last_updated(self): + """Get the timestamp of when the list was last updated as produced by `time.time()` :return: The last update point :rtype: float @@ -70,10 +91,9 @@ def get_last_updated(self): @gen.coroutine def _get_jobs_from_scheduler(self): - """ - Get the current jobs list from the scheduler + """Get the current jobs list from the scheduler. - :return: A dictionary of {job_id: job info} + :return: a mapping of job ids to :py:class:`~aiida.schedulers.datastructures.JobInfo` instances :rtype: dict """ with self._transport_queue.request_transport(self._authinfo) as request: @@ -89,7 +109,11 @@ def _get_jobs_from_scheduler(self): kwargs['jobs'] = self._get_jobs_with_scheduler() scheduler_response = scheduler.get_jobs(**kwargs) + + # Update the last update time and clear the jobs cache + self._last_updated = time.time() jobs_cache = {} + self.logger.info('AuthInfo<{}>: successfully retrieved status of active jobs'.format(self._authinfo.pk)) for job_id, job_info in iteritems(scheduler_response): # If the job is done then get detailed job information @@ -107,12 +131,10 @@ def _get_jobs_from_scheduler(self): @gen.coroutine def _update_job_info(self): - """ - Update all of the job information objects for a given authinfo, that is to say for - all the jobs on a particular machine for a particular user. + """Update all of the job information objects. - This will set the futures for all pending update requests where the corresponding job - has a new status compared to the last update. + This will set the futures for all pending update requests where the corresponding job has a new status compared + to the last update. """ try: if not self._update_requests_outstanding(): @@ -135,16 +157,16 @@ def _update_job_info(self): @contextlib.contextmanager def request_job_info_update(self, job_id): - """ - Request job info about a job when it next changes it's job state. If the job is not - found in the jobs list at the update the future will resolve to None. + """Request job info about a job when the job next changes state. + + If the job is not found in the jobs list at the update, the future will resolve to `None`. - :param job_id: The job identifier - :return: A future that will resolve to a JobInfo object when the job changes state + :param job_id: job identifier + :return: future that will resolve to a `JobInfo` object when the job changes state """ # Get or create the future request = self._job_update_requests.setdefault(job_id, concurrent.Future()) - assert not request.done(), "The future should be no be in the done state" + assert not request.done(), 'Expected pending job info future, found in done state.' try: self._ensure_updating() @@ -153,14 +175,14 @@ def request_job_info_update(self, job_id): pass def _ensure_updating(self): - """ - Ensure that we are updating the job list from the remote resource. + """Ensure that we are updating the job list from the remote resource. + This will automatically stop if there are no outstanding requests. """ @gen.coroutine def updating(): - """ Do the actual update, stop if not requests left """ + """Do the actual update, stop if not requests left.""" yield self._update_job_info() # Any outstanding requests? if self._update_requests_outstanding(): @@ -174,9 +196,10 @@ def updating(): @staticmethod def _has_job_state_changed(old, new): - """ - :type old: :class:`aiida.schedulers.JobInfo` - :type new: :class:`aiida.schedulers.JobInfo` + """Return whether the states `old` and `new` are different. + + :type old: :class:`aiida.schedulers.JobInfo` or `None` + :type new: :class:`aiida.schedulers.JobInfo` or `None` :rtype: bool """ if old is None and new is None: @@ -189,31 +212,31 @@ def _has_job_state_changed(old, new): return old.job_state != new.job_state or old.job_substate != new.job_substate def _get_next_update_delay(self): - """ - Calculate when we are next allowed to call the scheduler get jobs command - based on when we last called it, how long has elapsed and the minimum given - update interval. + """Calculate when we are next allowed to poll the scheduler. - :return: The delay (in seconds) for when it's safe to call the get jobs command + This delay is calculated as the minimum polling interval defined by the authentication info for this instance, + minus time elapsed since the last update. + + :return: delay (in seconds) after which the scheduler may be polled again :rtype: float """ - if self._last_updated is None: + if self.last_updated is None: # Never updated, so do it straight away return 0. - # Make sure to actually 'get' it here, so that if the user changed it - # between times we use the current value - minimum_interval = self._authinfo.computer.get_minimum_job_poll_interval() - elapsed = time.time() - self._last_updated + # Make sure to actually 'get' the minimum interval here, in case the user changed since last time + minimum_interval = self.get_minimum_update_interval() + elapsed = time.time() - self.last_updated + + delay = max(minimum_interval - elapsed, 0.) - return max(minimum_interval - elapsed, 0.) + return delay def _update_requests_outstanding(self): return any(not request.done() for request in itervalues(self._job_update_requests)) def _get_jobs_with_scheduler(self): - """ - Get all the jobs that are currently with scheduler for this authinfo + """Get all the jobs that are currently with scheduler. :return: the list of jobs with the scheduler :rtype: list @@ -221,31 +244,50 @@ def _get_jobs_with_scheduler(self): return [str(job_id) for job_id, _ in self._job_update_requests.items()] -class JobManager(object): # pylint: disable=useless-object-inheritance - """ - A manager for jobs on a (usually) remote resource such as a supercomputer +class JobManager(object): + """A manager for :py:class:`~aiida.engine.processes.calcjobs.calcjob.CalcJob` submitted to ``Computer`` instances. + + When a calculation job is submitted to a :py:class:`~aiida.orm.computers.Computer`, it actually uses a specific + :py:class:`~aiida.orm.authinfos.AuthInfo`, which is a computer configured for a :py:class:`~aiida.orm.users.User`. + The ``JobManager`` maintains a mapping of :py:class:`~aiida.engine.processes.calcjobs.manager.JobsList` instances + for each authinfo that has active calculation jobs. These jobslist instances are then responsible for bundling + scheduler updates for all the jobs they maintain (i.e. that all share the same authinfo) and update their status. + + As long as a :py:class:`~aiida.engine.runners.Runner` will create a single ``JobManager`` instance and use that for + its lifetime, the guarantees made by the ``JobsList`` about respecting the minimum polling interval of the scheduler + will be maintained. Note, however, that since each ``Runner`` will create its own job manager, these guarantees + only hold per runner. """ + # pylint: disable=useless-object-inheritance + def __init__(self, transport_queue): self._transport_queue = transport_queue - self._job_lists = RefObjectStore() + self._job_lists = {} + + def get_jobs_list(self, authinfo): + """Get or create a new `JobLists` instance for the given authinfo. + + :param authinfo: the `AuthInfo` + :return: a `JobsList` instance + """ + if authinfo.id not in self._job_lists: + self._job_lists[authinfo.id] = JobsList(authinfo, self._transport_queue) + + return self._job_lists[authinfo.id] @contextlib.contextmanager def request_job_info_update(self, authinfo, job_id): - """ - Get a future that will resolve to information about a given job. This is a context - manager so that if the user leaves the context the request is automatically cancelled. + """Get a future that will resolve to information about a given job. + + This is a context manager so that if the user leaves the context the request is automatically cancelled. - :return: A tuple containing the JobInfo object and detailed job info. Both can be None. + :return: A tuple containing the `JobInfo` object and detailed job info. Both can be None. :rtype: :class:`tornado.concurrent.Future` """ - # Define a way to create a JobsList if needed - create = partial(JobsList, authinfo, self._transport_queue) - - with self._job_lists.get(authinfo.id, create) as job_list: - with job_list.request_job_info_update(job_id) as request: - try: - yield request - finally: - if not request.done(): - request.cancel() + with self.get_jobs_list(authinfo).request_job_info_update(job_id) as request: + try: + yield request + finally: + if not request.done(): + request.cancel() diff --git a/aiida/engine/processes/calcjobs/tasks.py b/aiida/engine/processes/calcjobs/tasks.py index 28d01cacb0..883afa9b7d 100644 --- a/aiida/engine/processes/calcjobs/tasks.py +++ b/aiida/engine/processes/calcjobs/tasks.py @@ -73,11 +73,10 @@ def task_upload_job(node, transport_queue, calc_info, script_filename, cancellab def do_upload(): with transport_queue.request_transport(authinfo) as request: transport = yield cancellable.with_interrupt(request) - - logger.info('uploading calculation<{}>'.format(node.pk)) raise Return(execmanager.upload_calculation(node, transport, calc_info, script_filename)) try: + logger.info('uploading calculation<{}>'.format(node.pk)) result = yield exponential_backoff_retry( do_upload, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=plumpy.CancelledError) except plumpy.CancelledError: @@ -124,11 +123,10 @@ def task_submit_job(node, transport_queue, calc_info, script_filename, cancellab def do_submit(): with transport_queue.request_transport(authinfo) as request: transport = yield cancellable.with_interrupt(request) - - logger.info('submitting CalcJob<{}>'.format(node.pk)) raise Return(execmanager.submit_calculation(node, transport, calc_info, script_filename)) try: + logger.info('submitting CalcJob<{}>'.format(node.pk)) result = yield exponential_backoff_retry( do_submit, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=plumpy.Interruption) except plumpy.Interruption: @@ -188,6 +186,7 @@ def do_update(): raise Return(job_done) try: + logger.info('updating CalcJob<{}>'.format(node.pk)) job_done = yield exponential_backoff_retry( do_update, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=plumpy.Interruption) except plumpy.Interruption: @@ -233,11 +232,10 @@ def task_retrieve_job(node, transport_queue, retrieved_temporary_folder, cancell def do_retrieve(): with transport_queue.request_transport(authinfo) as request: transport = yield cancellable.with_interrupt(request) - - logger.info('retrieving CalcJob<{}>'.format(node.pk)) raise Return(execmanager.retrieve_calculation(node, transport, retrieved_temporary_folder)) try: + logger.info('retrieving CalcJob<{}>'.format(node.pk)) result = yield exponential_backoff_retry( do_retrieve, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=plumpy.Interruption) except plumpy.Interruption: @@ -281,10 +279,10 @@ def task_kill_job(node, transport_queue, cancellable): def do_kill(): with transport_queue.request_transport(authinfo) as request: transport = yield cancellable.with_interrupt(request) - logger.info('killing CalcJob<{}>'.format(node.pk)) raise Return(execmanager.kill_calculation(node, transport)) try: + logger.info('killing CalcJob<{}>'.format(node.pk)) result = yield exponential_backoff_retry(do_kill, initial_interval, max_attempts, logger=node.logger) except plumpy.Interruption: raise diff --git a/aiida/engine/utils.py b/aiida/engine/utils.py index 6665526764..9f2eb57f8c 100644 --- a/aiida/engine/utils.py +++ b/aiida/engine/utils.py @@ -20,7 +20,7 @@ import tornado.ioloop from tornado import concurrent, gen -__all__ = ('RefObjectStore', 'interruptable_task', 'InterruptableFuture') +__all__ = ('interruptable_task', 'InterruptableFuture') LOGGER = logging.getLogger(__name__) PROCESS_STATE_CHANGE_KEY = 'process|state_change|{}' @@ -306,81 +306,3 @@ def get_process_state_change_timestamp(process_type=None): return None return max(timestamps) - - -class RefObjectStore(object): # pylint: disable=useless-object-inheritance - """ - An object store that has a reference count based on a context manager. - Basic usage:: - - store = RefObjectStore() - with store.get('Martin', lambda: 'martin.uhrin@epfl.ch') as email: - with store.get('Martin') as email2: - email is email2 # True - - The use case for this store is when you have an object can be used by - multiple parts of the code simultaneously (nested or async code) and - where there should be one instance that exists for the lifetime of these - contexts. Once noone is using the object, it should be removed from the - store (and therefore eventually garbage collected). - """ - - class Reference(object): # pylint: disable=useless-object-inheritance - """A reference to store the context reference count and the object itself.""" - - def __init__(self, obj): - self._count = 0 - self._obj = obj - - @property - def count(self): - """ - Get the reference count for the object - - :return: The reference count - :rtype: int - """ - return self._count - - @contextlib.contextmanager - def get(self): - """ - Get the object itself, which will up the reference count for the duration of the context. - - :return: The object - """ - self._count += 1 - try: - yield self._obj - finally: - self._count -= 1 - - def __init__(self): - self._objects = {} - - @contextlib.contextmanager - def get(self, identifier, constructor=None): - """ - Get or create an object. The internal reference count will be upped for - the duration of the context. If the reference count drops to 0 the object - will be automatically removed from the list. - - :param identifier: The key identifying the object - :param constructor: An optional constructor that is called with no arguments - if the object doesn't already exist in the store - :return: The object corresponding to the identifier - """ - try: - ref = self._objects[identifier] - except KeyError: - if constructor is None: - raise ValueError("Object not found and no constructor given") - ref = self.Reference(constructor()) - self._objects[identifier] = ref - - try: - with ref.get() as obj: - yield obj - finally: - if ref.count == 0: - self._objects.pop(identifier)