Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
astrofrog committed Apr 29, 2016
1 parent 46286c0 commit 1693687
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 13 deletions.
10 changes: 7 additions & 3 deletions astropy/modeling/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,15 @@ def unit_support(func):

print("ICI")

def wrapper(self, model, x, y, *args, **kwargs):
def wrapper(self, model, x, y, z=None, *args, **kwargs):
print("HERE")
if isinstance(x, u.Quantity) or isinstance(y, u.Quantity):
model_new = func(self, model, x.value, y.value, *args, **kwargs)
return model_new.with_units_from_data(x, y)
if z is None:
model_new = func(self, model, x.value, y.value, *args, **kwargs)
return model_new.with_units_from_data(x, y)
else:
model_new = func(self, model, x.value, y.value, z=z.value, *args, **kwargs)
return model_new.with_units_from_data(x, y, z)
else:
func(self, model, x.value, y, *args, **kwargs)

Expand Down
21 changes: 11 additions & 10 deletions astropy/modeling/polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def __init__(self, degree, n_models=None, model_set_axis=None,
n_models=n_models, model_set_axis=model_set_axis, name=name,
meta=meta, **params)

def with_units_from_data(self, x, y):
def with_units_from_data(self, x, y, z=None):
"""
Return an instance of the model which has units for which the parameter
values are compatible with the data units.
Expand All @@ -111,9 +111,14 @@ def with_units_from_data(self, x, y):
for n in range(self._order):
name = 'c{0}'.format(n)
params[name] = quantity_with_unit(getattr(self, name),
y.unit / x.unit ** (n))
y.unit / x.unit ** n)
else:
raise NotImplementedError()
for i in range(self._degree + 1):
for j in range(self._degree + 1):
if i + j < self._degree + 1:
name = 'c{0}_{1}'.format(i, j)
params[name] = quantity_with_unit(getattr(self, name),
z.unit / x.unit ** i / y.unit ** j)

return self.__class__(**params)

Expand Down Expand Up @@ -160,13 +165,9 @@ def _generate_coeff_names(self, ndim):
for n in range(self._order):
names.append('c{0}'.format(n))
else:
for i in range(self.degree + 1):
names.append('c{0}_{1}'.format(i, 0))
for i in range(1, self.degree + 1):
names.append('c{0}_{1}'.format(0, i))
for i in range(1, self.degree):
for j in range(1, self.degree):
if i + j < self.degree + 1:
for i in range(self._degree + 1):
for j in range(self._degree + 1):
if i + j < self._degree + 1:
names.append('c{0}_{1}'.format(i, j))
return tuple(names)

Expand Down

0 comments on commit 1693687

Please sign in to comment.