Merge branch 'ayush-1506-mvlr'
ablaom committed May 15, 2019
2 parents 8f0d65d + 6d6e0ca commit 504eb62
Showing 8 changed files with 97 additions and 17 deletions.
12 changes: 6 additions & 6 deletions docs/src/simple_user_defined_models.md
@@ -42,28 +42,28 @@ Here's a quick-and-dirty implementation of a ridge regressor with no intercept:
import MLJBase
using LinearAlgebra

mutable struct MyRegressor <: MLJBase.Deterministic
mutable struct SimpleRidgeRegressor <: MLJBase.Deterministic
lambda::Float64
end
MyRegressor(; lambda=0.1) = MyRegressor(lambda)
SimpleRidgeRegressor(; lambda=0.1) = SimpleRidgeRegressor(lambda)

# fit returns coefficients minimizing a penalized rms loss function:
function MLJBase.fit(model::MyRegressor, X, y)
function MLJBase.fit(model::SimpleRidgeRegressor, X, y)
x = MLJBase.matrix(X) # convert table to matrix
fitresult = (x'x + model.lambda*I)\(x'y) # the coefficients
return fitresult
end

# predict uses coefficients to make new prediction:
MLJBase.predict(model::MyRegressor, fitresult, Xnew) = MLJBase.matrix(Xnew)*fitresult
MLJBase.predict(model::SimpleRidgeRegressor, fitresult, Xnew) = MLJBase.matrix(Xnew)*fitresult
````
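
For reference, `fit` above returns the closed-form solution of the penalized least-squares problem. A quick sanity check against a plain matrix computation might look like this (a sketch, not part of the original docs; it assumes only the code above plus `MLJBase.table`):

````julia
using MLJBase, LinearAlgebra

A = randn(50, 3)                        # raw feature matrix
y = A*[1.0, -2.0, 0.5] + 0.1*randn(50)  # synthetic target
X = MLJBase.table(A)                    # wrap as a table, the form fit expects

model = SimpleRidgeRegressor(lambda=0.5)
coefs = MLJBase.fit(model, X, y)        # coefficients from the model

# should agree with the closed-form ridge solution:
coefs ≈ (A'A + model.lambda*I) \ (A'y)
````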

After loading this code, all MLJ's basic meta-algorithms can be applied to `MyRegressor`:
After loading this code, all MLJ's basic meta-algorithms can be applied to `SimpleRidgeRegressor`:

````julia
julia> using MLJ
julia> task = load_boston()
julia> model = MyRegressor(lambda=1.0)
julia> model = SimpleRidgeRegressor(lambda=1.0)
julia> regressor = machine(model, task)
julia> evaluate!(regressor, resampling=CV(), measure=rms) |> mean
7.434221318358656
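julia> # (hypothetical continuation, not in the original docs) lambda can also be
julia> # tuned — another of MLJ's meta-algorithms — using the same API that appears
julia> # in test/tuning.jl later in this commit:
julia> r = range(model, :lambda, lower=0.01, upper=1.0)
julia> tuned_model = TunedModel(model=model, nested_ranges=(lambda=r,))
julia> tuned = machine(tuned_model, task)
julia> fit!(tuned)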
4 changes: 2 additions & 2 deletions src/MLJ.jl
@@ -99,8 +99,8 @@ include("tasks.jl") # enhancements to task interface defined in MLJBase
include("builtins/Transformers.jl")
include("builtins/Constant.jl")
include("builtins/KNN.jl")
include("builtins/LocalMultivariateStats.jl")

include("builtins/LocalMultivariateStats.jl")
include("builtins/ridge.jl")

## GET THE EXTERNAL MODEL METADATA AND MERGE WITH MLJ MODEL METADATA

57 changes: 57 additions & 0 deletions src/builtins/ridge.jl
@@ -0,0 +1,57 @@
# Defines a simple deterministic regressor for MLJ testing purposes
# only. MLJ users should use RidgeRegressor from MultivariateStats.

import MLJBase
using LinearAlgebra

export SimpleRidgeRegressor

mutable struct SimpleRidgeRegressor <: MLJBase.Deterministic
lambda::Float64
end

function SimpleRidgeRegressor(; lambda=0.0)
simpleridgemodel = SimpleRidgeRegressor(lambda)
message = MLJBase.clean!(simpleridgemodel)
isempty(message) || @warn message
return simpleridgemodel
end

function MLJBase.clean!(model::SimpleRidgeRegressor)
warning = ""
if model.lambda < 0
warning *= "Need lambda ≥ 0. Resetting lambda=0. "
model.lambda = 0
end
return warning
end

function MLJBase.fitted_params(::SimpleRidgeRegressor, fitresult)
return (coefficients=fitresult,)
end

function MLJBase.fit(model::SimpleRidgeRegressor, verbosity::Int, X, y)
x = MLJBase.matrix(X)
fitresult = (x'x + model.lambda*I)\(x'y)
cache = nothing
report = NamedTuple()
return fitresult, cache, report
end


function MLJBase.predict(model::SimpleRidgeRegressor, fitresult, Xnew)
x = MLJBase.matrix(Xnew)
return x*fitresult
end

# hide this model from the lists generated by calls to models():
MLJBase.is_wrapper(::Type{<:SimpleRidgeRegressor}) = true


# metadata:
MLJBase.load_path(::Type{<:SimpleRidgeRegressor}) = "MLJ.SimpleRidgeRegressor"
MLJBase.package_name(::Type{<:SimpleRidgeRegressor}) = "MLJ"
MLJBase.package_uuid(::Type{<:SimpleRidgeRegressor}) = ""
MLJBase.is_pure_julia(::Type{<:SimpleRidgeRegressor}) = true
MLJBase.input_scitype_union(::Type{<:SimpleRidgeRegressor}) = MLJBase.Continuous
MLJBase.target_scitype_union(::Type{<:SimpleRidgeRegressor}) = MLJBase.Continuous
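
A minimal usage sketch of the new builtin (illustrative only — the toy data and `lambda=0.5` here are not part of the commit; the calls mirror those in test/ridge.jl below):

````julia
using MLJ

X = MLJ.table(randn(30, 3))   # any table of continuous features
y = randn(30)

model = MLJ.SimpleRidgeRegressor(lambda=0.5)
fitresult, cache, report = MLJ.fit(model, 0, X, y)   # verbosity = 0
yhat = MLJ.predict(model, fitresult, X)              # 30 predictions
````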
10 changes: 5 additions & 5 deletions test/composites.jl
@@ -10,7 +10,7 @@ train, test = partition(eachindex(yin), 0.7);
Xtrain = Xin[train,:];
ytrain = yin[train];

ridge_model = RidgeRegressor(lambda=0.1)
ridge_model = SimpleRidgeRegressor(lambda=0.1)
selector_model = FeatureSelector()

composite = MLJ.SimpleDeterministicCompositeModel(model=ridge_model, transformer=selector_model)
@@ -25,21 +25,21 @@ selector_old = deepcopy(selector)

# this should trigger no retraining:
fitresult, cache, report = MLJ.update(composite, 3, fitresult, cache, Xtrain, ytrain);
@test ridge.fitresult.coefficients == ridge_old.fitresult.coefficients
@test ridge.fitresult == ridge_old.fitresult
@test selector.fitresult == selector_old.fitresult

# this should trigger retraining of selector and ridge:
selector_model.features = [:Crim, :Rm]
fitresult, cache, report = MLJ.update(composite, 2, fitresult, cache, Xtrain, ytrain)
@test ridge.fitresult.bias != ridge_old.fitresult.bias
@test ridge.fitresult != ridge_old.fitresult
@test selector.fitresult != selector_old.fitresult
ridge_old = deepcopy(ridge)
selector_old = deepcopy(selector)

# this should trigger retraining of ridge only:
ridge_model.lambda = 1.0
fitresult, cache, report = MLJ.update(composite, 2, fitresult, cache, Xtrain, ytrain)
@test ridge.fitresult.bias != ridge_old.fitresult.bias
@test ridge.fitresult != ridge_old.fitresult
@test selector.fitresult == selector_old.fitresult

predict(composite, fitresult, Xin[test,:]);
@@ -84,7 +84,7 @@ end

X, y = datanow()

ridge = RidgeRegressor(lambda=0.1)
ridge = SimpleRidgeRegressor(lambda=0.1)
model = WrappedRidge(ridge)
mach = machine(model, X, y)
fit!(mach)
2 changes: 1 addition & 1 deletion test/resampling.jl
@@ -44,7 +44,7 @@ end

# holdout:
X, y = datanow()
ridge_model = RidgeRegressor(lambda=20.0)
ridge_model = SimpleRidgeRegressor(lambda=20.0)
resampler = Resampler(resampling=holdout, model=ridge_model)
resampling_machine = machine(resampler, X, y)
fit!(resampling_machine)
19 changes: 19 additions & 0 deletions test/ridge.jl
@@ -0,0 +1,19 @@
module TestSimpleRidgeRegressor

using Test
using MLJ

# Hand-calculated values: the coefficients (parameters) should be [1, -1, 2]
ip = [1 0 3 4 5; 2 1 3 -3 4; 0 1 -3 2 1]
op = [-1, 1, -6, 11, 3]

ip = MLJ.table(ip')

model = MLJ.SimpleRidgeRegressor(lambda = 0.0)

fitresult, cache, report = MLJ.fit(model, 0, ip, op);

@test MLJ.predict(model, fitresult, Float64[-1,0,2]')[1] ≈ 3

end
true
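
As a quick arithmetic check of the hard-coded expectation above (not part of the test file): with coefficients [1, -1, 2], the rows of ip' reproduce op exactly, and the probe point [-1, 0, 2] maps to 3:

````julia
A = [1 0 3 4 5; 2 1 3 -3 4; 0 1 -3 2 1]'   # the 5×3 design matrix used in the test
A * [1, -1, 2]        # == [-1, 1, -6, 11, 3] == op, so lambda = 0 recovers [1, -1, 2]
[-1 0 2] * [1, -1, 2] # == [3], matching the value asserted by the @test above
````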
4 changes: 4 additions & 0 deletions test/runtests.jl
@@ -20,6 +20,10 @@ end
@test include("KNN.jl")
end

@testset "ridge" begin
@test include("ridge.jl")
end

@testset "Constant" begin
@test include("Constant.jl")
end
6 changes: 3 additions & 3 deletions test/tuning.jl
@@ -15,7 +15,7 @@ y = 2*x1 .+ 5*x2 .- 3*x3 .+ 0.2*rand(100);

sel = FeatureSelector()
stand = UnivariateStandardizer()
ridge = RidgeRegressor()
ridge = SimpleRidgeRegressor()
composite = MLJ.SimpleDeterministicCompositeModel(transformer=sel, model=ridge)

features_ = range(sel, :features, values=[[:x1], [:x1, :x2], [:x2, :x3], [:x1, :x2, :x3]])
@@ -55,7 +55,7 @@ e = rms(y, predict(tuned, X))
r = e/tuned.report.best_measurement
@test r < 10 && r > 0.1

ridge = RidgeRegressor()
ridge = SimpleRidgeRegressor()
tuned_model = TunedModel(model=ridge,
nested_ranges=(lambda = range(ridge, :lambda, lower=0.01, upper=1.0),))
tuned = machine(tuned_model, X, y)
@@ -65,7 +65,7 @@ fit!(tuned)
## LEARNING CURVE

X, y = datanow()
atom = RidgeRegressor()
atom = SimpleRidgeRegressor()
ensemble = EnsembleModel(atom=atom)
mach = machine(ensemble, X, y)
r_lambda = range(atom, :lambda, lower=0.1, upper=100, scale=:log10)