-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: allow to load fitting models from external Python files
- Loading branch information
1 parent
66ec9ba
commit dd7f92d
Showing
8 changed files
with
186 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import importlib | ||
import pathlib | ||
import sys | ||
import warnings | ||
|
||
from .core import ModelImportError, NaniteFitModel | ||
|
||
#: currently available models | ||
models_available = {} | ||
|
||
|
||
def load_model_from_file(path, register=False): | ||
"""Import a fit model file and return the module | ||
This is intended for loading custom models or for model | ||
development. | ||
Parameters | ||
---------- | ||
path: str or Path | ||
pathname to a Python script conaining a fit model | ||
register: bool | ||
whether to register the model after import | ||
Returns | ||
------- | ||
model: NaniteFitModel | ||
nanite fit model object | ||
Raises | ||
------ | ||
ModelImportError | ||
If the model cannot be imported | ||
""" | ||
path = pathlib.Path(path) | ||
try: | ||
# insert the plugin directory to sys.path so we can import it | ||
sys.path.insert(-1, str(path.parent)) | ||
sys.dont_write_bytecode = True | ||
module = importlib.import_module(path.stem) | ||
except ModuleNotFoundError: | ||
raise ModelImportError(f"Could not import '{path}'!") | ||
finally: | ||
# undo our path insertion | ||
sys.path.pop(0) | ||
sys.dont_write_bytecode = False | ||
|
||
mod = NaniteFitModel(module) | ||
|
||
if register: | ||
register_model(module) | ||
|
||
return mod | ||
|
||
|
||
def register_model(module, *args): | ||
"""Register a fitting model | ||
Parameters | ||
---------- | ||
module: Python module or NaniteFitModel | ||
the model to register | ||
Returns | ||
------- | ||
model: NaniteFitModel | ||
the corresponding NaniteFitModel instance | ||
""" | ||
if args: | ||
warnings.warn("Please only pass the module to `register_model`!", | ||
DeprecationWarning) | ||
global models_available # this is not necessary, but clarifies things | ||
# add model | ||
if isinstance(module, NaniteFitModel): | ||
# we already have a fit model | ||
md = module | ||
else: | ||
md = NaniteFitModel(module) | ||
# the actual registration | ||
models_available[module.model_key] = md | ||
return md | ||
|
||
|
||
def deregister_model(model): | ||
"""Deregister a NaniteFitModel""" | ||
global models_available # this is not necessary, but clarifies things | ||
models_available.pop(model.model_key) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import lmfit | ||
import numpy as np | ||
|
||
|
||
def get_parameter_defaults(): | ||
"""Return the default model parameters""" | ||
# The order of the parameters must match the order | ||
# of ´parameter_names´ and ´parameter_keys´. | ||
params = lmfit.Parameters() | ||
params.add("E", value=3e3, min=0) | ||
params.add("R", value=10e-6, min=0, vary=False) | ||
params.add("nu", value=.5, min=0, max=0.5, vary=False) | ||
params.add("contact_point", value=0) | ||
params.add("baseline", value=0) | ||
return params | ||
|
||
|
||
def hertz_paraboloidal(delta, E, R, nu, contact_point=0, baseline=0): | ||
"""This is identical to the Hertz parabolic indenter model""" | ||
aa = 4/3 * E/(1-nu**2)*np.sqrt(R) | ||
root = contact_point-delta | ||
pos = root > 0 | ||
bb = np.zeros_like(delta) | ||
bb[pos] = (root[pos])**(3/2) | ||
return aa*bb + baseline | ||
|
||
|
||
model_doc = hertz_paraboloidal.__doc__ | ||
model_func = hertz_paraboloidal | ||
model_key = "hans_peter" | ||
model_name = "Hans Peter's model" | ||
parameter_keys = ["E", "R", "nu", "contact_point", "baseline"] | ||
parameter_names = ["Young's Modulus", "Tip Radius", | ||
"Poisson's Ratio", "Contact Point", "Force Baseline"] | ||
parameter_units = ["Pa", "m", "", "m", "N"] | ||
valid_axes_x = ["tip position"] | ||
valid_axes_y = ["force"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import pathlib | ||
|
||
import nanite | ||
|
||
|
||
data_dir = pathlib.Path(__file__).parent / "data" | ||
|
||
|
||
def test_load_model_from_file(): | ||
mpath = data_dir / "model_external_basic.py" | ||
md = nanite.model.load_model_from_file(mpath, register=True) | ||
assert md.model_key == "hans_peter" | ||
assert md.model_key in nanite.model.models_available | ||
nanite.model.deregister_model(md) | ||
assert md.model_key not in nanite.model.models_available | ||
|
||
|
||
def test_load_model_from_model(): | ||
mpath = data_dir / "model_external_basic.py" | ||
md = nanite.model.load_model_from_file(mpath, register=False) | ||
assert md.model_key == "hans_peter" | ||
assert md.model_key not in nanite.model.models_available | ||
md2 = nanite.model.register_model(md) | ||
assert md is md2 | ||
assert md.model_key in nanite.model.models_available | ||
nanite.model.deregister_model(md) | ||
assert md.model_key not in nanite.model.models_available |