Skip to content

Commit

Permalink
Made caching a stateful client property.
Browse files Browse the repository at this point in the history
* Added `caching()` method to `Client`. It enables or disables caching
of previously loaded models in subsequent calls to `load()`.
* Made caching opt-in, i.e. disabled at start-up.
* Renamed `Model.path()` to `Model.file()`.
* Renamed `Client.path()` to `Client.files()`.
* Added test coverage for caching behavior.
* Added "rich comparison" method for `Model` objects.
  • Loading branch information
john-hen committed Feb 24, 2021
1 parent fc0710b commit 2e903f1
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 23 deletions.
54 changes: 36 additions & 18 deletions mph/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,21 +181,42 @@ def __init__(self, cores=None, version=None, port=None, host='localhost'):
self.port = port
self.java = java

def load(self, file, reload=False):
"""Returns the model loaded from the given `file`."""
file = Path(file)

# Check if model is already loaded
if file in self.paths() and not reload:
logger.info('Found model in memory, returning model from memory')
return self.models()[self.paths().index(file)]

# Deactivate caching of previously loaded models by default.
self._caching = False

def load(self, file):
"""Loads a model from the given `file` and returns it."""
file = Path(file).resolve()
if self.caching() and file in self.files():
logger.info('Returning previously loaded model from cache.')
return self.models()[self.files().index(file)]
tag = self.java.uniquetag('model')
logger.info(f'Loading model "{file.name}".')
model = Model(self.java.load(tag, str(file)))
logger.info('Finished loading model.')
return model

def caching(self, state=None):
"""
Enables or disables caching of previously loaded models.
Caching means that the `load()` method will check if a model
has been previously loaded from the same file-system path and,
if so, return the in-memory model object instead of reloading
it from disk. By default (at start-up) caching is disabled.
Pass `True` to enable caching, `False` to disable it. If no
argument is passed, the current state is returned.
"""
if state is None:
return self._caching
elif state in (True, False):
self._caching = state
else:
error = 'Caching state can only be set to either True or False.'
logger.error(error)
raise ValueError(error)

def create(self, name):
"""
Creates and returns a new, empty model with the given `name`.
Expand All @@ -217,20 +238,17 @@ def create(self, name):
return model

def models(self):
"""Returns all model objects currently held in memory."""
"""Returns all models currently held in memory."""
return [Model(self.java.model(tag)) for tag in self.java.tags()]

def paths(self):
"""
Returns the file paths of all models in memory (abspath). Models without
files will return empty path.
"""
return [Path(str(self.java.model(tag).getFilePath())) for tag in self.java.tags()]

def names(self):
"""Names all models that are currently held in memory."""
"""Returns the names of all loaded models."""
return [model.name() for model in self.models()]

def files(self):
"""Returns the file-system paths of all loaded models."""
return [model.file() for model in self.models()]

def remove(self, model):
"""Removes the given `model` from memory."""
name = model.name()
Expand Down
9 changes: 6 additions & 3 deletions mph/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ class is not intended to be instantiated directly. Rather, the
def __init__(self, java):
self.java = java

def __eq__(self, other):
return self.java.tag() == other.java.tag()

####################################
# Inspection #
####################################
Expand All @@ -65,9 +68,9 @@ def name(self):
name = name.rsplit('.', maxsplit=1)[0]
return name

def path(self):
"""Returns the abspath of the model's mpf file"""
return Path(str(self.java.getFilePath()))
def file(self):
"""Returns the absolute path to the file the model was loaded from."""
return Path(str(self.java.getFilePath())).resolve()

def parameters(self):
"""
Expand Down
23 changes: 21 additions & 2 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,33 @@ def test_load():
assert model


def test_caching():
assert not client.caching()
copy = client.load(file)
assert model != copy
client.remove(copy)
client.caching(True)
assert client.caching()
copy = client.load(file)
assert model == copy
client.caching(False)
assert not client.caching()


def test_create():
name = 'test'
client.create(name)
assert name in client.names()


def test_list():
def test_names():
assert model.name() in client.names()


def test_files():
assert file.resolve() in client.files()


def test_remove():
name = model.name()
client.remove(model)
Expand Down Expand Up @@ -95,8 +112,10 @@ def test_clear():
test_start()
test_cores()
test_load()
test_caching()
test_create()
test_list()
test_names()
test_files()
test_remove()
test_clear()
finally:
Expand Down

0 comments on commit 2e903f1

Please sign in to comment.