Skip to content

Commit

Permalink
Implement trait HasUnit following SimpleTraits.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
singularitti committed Sep 30, 2019
1 parent 44de5db commit e3dfd4a
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions src/NonlinearFitting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ using ..Collections

export lsqfit

abstract type UnitTrait end
struct NoUnit <: UnitTrait end
struct HasUnit <: UnitTrait end
# This idea is borrowed from [SimpleTraits.jl](https://github.com/mauro3/SimpleTraits.jl/blob/master/src/SimpleTraits.jl).
abstract type Trait end
abstract type Not{T<:Trait} <: Trait end
struct HasUnit <: Trait end

_traitfn(T::Type{<:Number}) = NoUnit
_traitfn(T::Type{<:AbstractQuantity}) = HasUnit
_unit_trait(T::Type{<:Real}) = Not{HasUnit}
_unit_trait(T::Type{<:AbstractQuantity}) = HasUnit

"""
lsqfit(form, eos, xdata, ydata; debug = false, kwargs...)
Expand All @@ -48,29 +49,29 @@ function lsqfit(
kwargs...,
)
T = eltype(eos)
return lsqfit(form, eos, xdata, ydata, _traitfn(T), kwargs...)
return lsqfit(_unit_trait(T), form, eos, xdata, ydata, kwargs...)
end # function lsqfit
function lsqfit(
::Type{Not{HasUnit}},
form::EquationForm,
eos::EquationOfState,
xdata::AbstractVector,
ydata::AbstractVector,
trait::Type{NoUnit};
ydata::AbstractVector;
debug = false,
kwargs...,
)
T = promote_type(eltype(eos), eltype(xdata), eltype(ydata), Float64)
E = typeof(eos).name.wrapper
model(x, p) = map(apply(form, E(p...)), x)
fitted = curve_fit(model, T.(xdata), T.(ydata), T.(Collections.fieldvalues(eos)); kwargs...)
fitted = curve_fit(model, T.(xdata), T.(ydata), T.(Collections.fieldvalues(eos)), kwargs...)
return debug ? fitted : E(fitted.param...)
end # function lsqfit
function lsqfit(
::Type{HasUnit},
form::EquationForm,
eos::EquationOfState,
xdata::AbstractVector,
ydata::AbstractVector,
trait::Type{HasUnit};
ydata::AbstractVector;
kwargs...,
)
E = typeof(eos).name.wrapper
Expand Down

0 comments on commit e3dfd4a

Please sign in to comment.