Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add UnivariateTimeTypeToContinuous transformer to builtins #245

Closed
wants to merge 15 commits into from
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/MLJModels.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module MLJModels
module MLJModels

import MLJModelInterface
import MLJModelInterface: MODEL_TRAITS
Expand All @@ -9,7 +9,7 @@ import MLJBase: @load
import MLJBase: Table, Continuous, Count, Finite, OrderedFactor, Multiclass

using Requires, Pkg, Pkg.TOML, OrderedCollections, Parameters
using Tables, CategoricalArrays, StatsBase, Statistics
using Tables, CategoricalArrays, StatsBase, Statistics, Dates
import Distributions

# for administrators to update Metadata.toml:
Expand Down
103 changes: 101 additions & 2 deletions src/builtins/Transformers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,106 @@ end
MLJBase.inverse_transform(transformer::UnivariateStandardizer, fitresult, w) =
[inverse_transform(transformer, fitresult, y) for y in w]

##########################################################################################
## CONTINUOUS TRANSFORM OF TIME TYPE FEATURES

"""
UnivariateTimeTypeToContinuous(zero_time=nothing, step=Hour(24))

Convert a `Date`, `DateTime`, and `Time` vector to `Float64` by assuming `0.0` corresponds
to the `zero_time` parameter and the time increment to reach `1.0` is given by the `step`
parameter. The type of `zero_time` should match the type of the column if provided. If not
provided, then `zero_time` is inferred as the minimum time found in the data when `fit` is
called.

"""
mutable struct UnivariateTimeTypeToContinuous <: Unsupervised
zero_time::Union{Nothing, TimeType}
step::Period
end

function UnivariateTimeTypeToContinuous(;
zero_time=nothing, step=Dates.Hour(24))
model = UnivariateTimeTypeToContinuous(zero_time, step)
MLJBase.clean!(model)
return return model
end
ablaom marked this conversation as resolved.
Show resolved Hide resolved

function MLJBase.clean!(model::UnivariateTimeTypeToContinuous)
# Step must be able to be added to zero_time if provided.
if model.zero_time !== nothing
try
tmp = model.zero_time + model.step
catch err
if err isa MethodError
# Cannot add time parts to dates nor date parts to times.
# If a mismatch is encountered. Conversion from date parts to time parts
# is possible, but not from time parts to date parts because we cannot
# represent fractional date parts.
if model.zero_time isa Dates.Date && model.step isa Dates.TimePeriod
# Convert zero_time to a DateTime to resolve conflict.
@warn "Cannot add TimePeriod step to Date zero_time. Converting zero_time to DateTime."
model.zero_time = convert(DateTime, model.zero_time)
elseif model.zero_time isa Dates.Time && model.step isa Dates.DatePeriod
# Convert step to Hour if possible. This will fail for
# isa(step, Month)
@warn "Cannot add DatePeriod step to Time zero_time. Converting step to Hour."
model.step = convert(Hour, model.step)
else
# Unable to resolve, rethrow original error.
throw(err)
end
else
throw(err)
end
end
end
ablaom marked this conversation as resolved.
Show resolved Hide resolved
end

function MLJBase.fit(model::UnivariateTimeTypeToContinuous, verbosity::Int, X)
if model.zero_time !== nothing
ablaom marked this conversation as resolved.
Show resolved Hide resolved
fitresult = model.zero_time
# Check zero_time is compatible with X
example = first(X)
try
X - fitresult
catch err
if err isa MethodError
@warn "$(typeof(fitresult)) zero_time is not compatible with $(eltype(X)) vector X. Attempting to convert zero_time."
fitresult = convert(eltype(X), fitresult)
else
throw(err)
end
end
else
min_dt = minimum(X)
fitresult = min_dt
end
cache = nothing
report = nothing
return fitresult, cache, report
end

function MLJBase.transform(model::UnivariateTimeTypeToContinuous, fitresult, X)
if typeof(fitresult) ≠ eltype(X)
# Cannot run if eltype in transform differs from zero_time from fit.
throw(ArgumentError("Different TimeType encountered during transform than expected from fit. Found $(eltype(X)), expected $(typeof(fitresult))"))
end
# Set the size of a single step.
ablaom marked this conversation as resolved.
Show resolved Hide resolved
next_time = fitresult + model.step
if next_time == fitresult
# Time type loops if model.step is a multiple of Hour(24), so calculate the
# number of multiples, then re-scale to Hour(12) and adjust delta to match original.
m = model.step / Dates.Hour(12)
delta = m * (
Float64(Dates.value(fitresult + Dates.Hour(12)) - Dates.value(fitresult)))
else
delta = Float64(Dates.value(fitresult + model.step) - Dates.value(fitresult))
end
return @. Float64(Dates.value(X - fitresult)) / delta
end


## STANDARDIZATION OF ORDINAL FEATURES OF TABULAR DATA

"""
Expand Down Expand Up @@ -862,7 +962,7 @@ the last class indicator column.
`Multiclass` or `OrderedFactor` column is the same in new data being
transformed as it is in the data used to fit the transformer.

### Example
### Example

```julia
X = (name=categorical(["Danesh", "Lee", "Mary", "John"]),
Expand Down Expand Up @@ -1036,4 +1136,3 @@ metadata_model(ContinuousEncoder,
weights = false,
descr = CONTINUOUS_ENCODER_DESCR,
path = "MLJModels.ContinuousEncoder")