Skip to content

Commit

Permalink
Merge pull request #814 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 0.20.15 release
  • Loading branch information
ablaom authored Jul 25, 2022
2 parents cd57ac3 + 3dc0ebc commit 2b90d4a
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
version = "0.20.14"
version = "0.20.15"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down
9 changes: 5 additions & 4 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
## predict(::Machine, ) and transform(::Machine, )

_err_rows_not_allowed() =
throw(ArgumentError("Calling `transform(mach, rows=...)` when "*
throw(ArgumentError("Calling `transform(mach, rows=...)` or "*
"`predict(mach, rows=...)` when "*
"`mach.model isa Static` is not allowed, as no data "*
"is bound to `mach` in this case. Specify a explicit "*
"data or node, as in `transform(mach, X)`, or "*
Expand Down Expand Up @@ -80,14 +81,14 @@ for operation in OPERATIONS
)
return get!(ret, $quoted_operation, mach)
end

# special case of Static models (no training arguments):
$operation(mach::Machine{<:Static}; rows=:) = _err_rows_not_allowed()
end
eval(ex)

end

# special case of Static models (no training arguments):
transform(mach::Machine{<:Static}; rows=:) = _err_rows_not_allowed()

inverse_transform(mach::Machine; rows=:) =
throw(ArgumentError("`inverse_transform(mach)` and "*
"`inverse_transform(mach, rows=...)` are "*
Expand Down
6 changes: 5 additions & 1 deletion test/composition/models/static_transformers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ end
MLJBase.transform(transf::PlainTransformer, verbosity, X) =
selectcols(X, transf.ftr)

@testset "nodal machine constructor for static transformers" begin
@testset "machine constructor for static transformers" begin
X = (x1=rand(3), x2=[1, 2, 3]);
mach = machine(PlainTransformer(:x2))
@test transform(mach, X) == [1, 2, 3]
Expand All @@ -31,10 +31,14 @@ MLJBase.reporting_operations(::Type{<:YourTransformer}) = (:transform,)
MLJBase.transform(transf::YourTransformer, verbosity, X) =
(selectcols(X, transf.ftr), (; nrows=nrows(X)))

MLJBase.predict(transf::YourTransformer, verbosity, X) =
collect(1:nrows(X)) |> reverse

@testset "nodal machine constructor for static transformers" begin
X = (x1=rand(3), x2=[1, 2, 3]);
mach = machine(YourTransformer(:x2))
@test transform(mach, X) == [1, 2, 3]
@test predict(mach, X) == [3, 2, 1]
@test report(mach).nrows == 3
transform(mach, (x2=["a", "b"],))
@test report(mach).nrows == 2
Expand Down

0 comments on commit 2b90d4a

Please sign in to comment.