Skip to content

Commit

Permalink
Merge pull request #35 from licode/expressModel
Browse files Browse the repository at this point in the history
Express model
  • Loading branch information
Eric Dill committed Dec 9, 2014
2 parents 4bff183 + fcbc464 commit bfe7524
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 8 deletions.
31 changes: 30 additions & 1 deletion vttools/tests/test_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from skxray.testing.decorators import known_fail_if
from vttools.to_wrap.fitting import (gaussian_model, lorentzian_model,
lorentzian2_model, quadratic_model,
fit_engine)
fit_engine, fit_engine_list, expression_model)
from nose.tools import (assert_equal, assert_true, raises)


Expand Down Expand Up @@ -114,6 +114,7 @@ def test_lorentzian2_fit():
fitted_val = (out['area'], out['center'], out['sigma'])
assert_array_almost_equal(true_val, fitted_val)


def test_quadratic_fit():
x = np.arange(-1, 1, .01)
a = 1
Expand All @@ -133,3 +134,31 @@ def test_quadratic_fit():
out = result.values
fitted_val = (out['a'], out['b'], out['c'])
assert_array_almost_equal(true_val, fitted_val)


def test_fit_engine_list():
a = 1
b = 2
c = 3
m = quadratic_model('',
a, 'free', [-1, 1],
b, 'free', [-1, 1],
c, 'free', [-1, 1])
x = np.arange(-1, 1, 0.01)
y = x**2 + 1

datav = [(x, y), (x, y+2)]
out = fit_engine_list(m, datav)
assert_equal(len(out), 2)


def test_expression_model():

inputv = 'exp(-a*x)'

x = np.arange(-1, 1, 0.01)
y = np.exp(-x)

mod = expression_model(inputv)
out = mod.fit(y, x=x, a=0.1)
assert_equal(1, out.values['a'])
27 changes: 20 additions & 7 deletions vttools/to_wrap/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
########################################################################
import sys
from skxray.fitting.api import (QuadraticModel, GaussianModel,
LorentzianModel, Lorentzian2Model)
LorentzianModel, Lorentzian2Model,
ExpressionModel)

import logging
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -150,6 +151,24 @@ def fit_engine_list(g, data):
return result_list


def expression_model(model_exp):
"""
This function creates a Model from a user-supplied expression.
Parameters
----------
model_exp : str
expression of the model
Returns
-------
mod : array_like
object of fitting results
"""
mod = ExpressionModel(model_exp)
return mod


def set_range(model_name,
parameter_name, parameter_value,
parameter_vary, parameter_range):
Expand Down Expand Up @@ -265,9 +284,3 @@ def inner(prefix, amplitude, amplitude_vary, amplitude_range,
func_name.amplitude_vary = ['fixed', 'free', 'bounded']
func_name.center_vary = ['fixed', 'free', 'bounded']
func_name.sigma_vary = ['fixed', 'free', 'bounded']


function_list = [fit_engine, fit_engine_list, quadratic_model]

for func_name in function_list:
setattr(mod, func_name.__name__, func_name)
3 changes: 3 additions & 0 deletions vttools/vtmods/import_lists/modules.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ autowrap_func:
- func_name: fit_engine_list
module_path: vttools.to_wrap.fitting
namespace: fitting
- func_name: expression_model
module_path: vttools.to_wrap.fitting
namespace: fitting
# CALIBRATION
- func_name: refine_center
module_path: skxray.calibration
Expand Down

0 comments on commit bfe7524

Please sign in to comment.