Add UnivariateTimeTypeToContinuous transformer to builtins #245

Closed
wants to merge 15 commits into from
Changes from 3 commits
4 changes: 2 additions & 2 deletions src/MLJModels.jl
@@ -1,4 +1,4 @@
module MLJModels
module MLJModels

import MLJModelInterface
import MLJModelInterface: MODEL_TRAITS
@@ -9,7 +9,7 @@ import MLJBase: @load
import MLJBase: Table, Continuous, Count, Finite, OrderedFactor, Multiclass

using Requires, Pkg, Pkg.TOML, OrderedCollections, Parameters
using Tables, CategoricalArrays, StatsBase, Statistics
using Tables, CategoricalArrays, StatsBase, Statistics, Dates
import Distributions

# for administrators to update Metadata.toml:
53 changes: 51 additions & 2 deletions src/builtins/Transformers.jl
@@ -326,6 +326,56 @@ end
MLJBase.inverse_transform(transformer::UnivariateStandardizer, fitresult, w) =
[inverse_transform(transformer, fitresult, y) for y in w]

##########################################################################################
## CONTINUOUS TRANSFORM OF TIME TYPE FEATURES

"""
    UnivariateTimeTypeToContinuous(zero_time=nothing, step=Hour(24))

Convert `Date`, `DateTime`, and `Time` vectors to `Float64` by assuming `0.0` corresponds
Review comment (Member): Minor: I suggest "Convert a Date, DateTime, and Time vector" to prevent confusion with the multivariate case.

to the `zero_time` parameter and the time increment to reach `1.0` is given by the `step`
parameter. The type of `zero_time` should match the type of the column if provided. If not
provided, then `zero_time` is inferred as the minimum time found in the data when `fit` is
called.

"""
mutable struct UnivariateTimeTypeToContinuous <: Unsupervised
    zero_time::Union{Nothing, TimeType}
    step::Period
    function UnivariateTimeTypeToContinuous(;
        zero_time=nothing, step=Dates.Hour(24))
        return new(zero_time, step)
    end
end

function MLJBase.fit(model::UnivariateTimeTypeToContinuous, verbosity::Int, X)
    if model.zero_time !== nothing
        fitresult = model.zero_time
    else
        min_dt = minimum(X)
        fitresult = min_dt
    end
    cache = nothing
    report = nothing
    return fitresult, cache, report
end

function MLJBase.transform(model::UnivariateTimeTypeToContinuous, fitresult, X)
    # Set the size of a single step.
    next_time = fitresult + model.step
    if next_time == fitresult
        # Time type loops if model.step is a multiple of Hour(24), so calculate the
        # number of multiples, then re-scale to Hour(12) and adjust delta to match original.
        m = model.step / Dates.Hour(12)
        delta = m * (
            Float64(Dates.value(fitresult + Dates.Hour(12)) - Dates.value(fitresult)))
    else
        delta = Float64(Dates.value(fitresult + model.step) - Dates.value(fitresult))
    end
    return @. Float64(Dates.value(X - fitresult)) / delta
end
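
For readers following the arithmetic in `transform`, here is a minimal standalone sketch of the non-wrapping branch, using only the `Dates` standard library. The helper name `to_continuous` and the sample data are hypothetical and not part of this PR; the sketch assumes `DateTime` input and takes `zero_time` as the minimum of the data, as `fit` does when no `zero_time` is supplied.

```julia
using Dates

# Hypothetical helper (not in the PR): mirrors the `else` branch of `transform`.
# `delta` is the raw-unit size of one step (milliseconds for `DateTime`).
function to_continuous(times::AbstractVector{DateTime}, zero_time::DateTime, step::Period)
    delta = Float64(Dates.value(zero_time + step) - Dates.value(zero_time))
    return [Float64(Dates.value(t - zero_time)) / delta for t in times]
end

# Illustrative data only.
times = [DateTime(2020, 1, 1), DateTime(2020, 1, 1, 12), DateTime(2020, 1, 3)]
zero_time = minimum(times)            # what `fit` returns when `zero_time` is `nothing`
result = to_continuous(times, zero_time, Hour(24))

# `Time` values wrap around at 24 hours, which is the case the
# `next_time == fitresult` check in `transform` is guarding against.
wraps = Time(8) + Hour(24) == Time(8)
```

With a 24-hour step, a timestamp half a day after `zero_time` maps to `0.5` and one two days after maps to `2.0`.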


## STANDARDIZATION OF ORDINAL FEATURES OF TABULAR DATA

"""
@@ -862,7 +912,7 @@ the last class indicator column.
`Multiclass` or `OrderedFactor` column is the same in new data being
transformed as it is in the data used to fit the transformer.

### Example
### Example

```julia
X = (name=categorical(["Danesh", "Lee", "Mary", "John"]),
@@ -1036,4 +1086,3 @@ metadata_model(ContinuousEncoder,
weights = false,
descr = CONTINUOUS_ENCODER_DESCR,
path = "MLJModels.ContinuousEncoder")