Skip to content

Commit

Permalink
Merge pull request #355 from alan-turing-institute/warning-removal
Browse files Browse the repository at this point in the history
Warning removal
  • Loading branch information
ablaom committed Nov 22, 2019
2 parents b35c4a2 + f226d1a commit 75e5a25
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 10 deletions.
23 changes: 15 additions & 8 deletions src/networks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,9 @@ struct Node{T<:Union{NodalMachine, Nothing}} <: AbstractNode
end

origins_ = unique(vcat([origins(arg) for arg in args]...))
length(origins_) == 1 ||
@warn "A node referencing multiple origins when called " *
"has been defined:\n$(origins_). "
# length(origins_) == 1 ||
# @warn "A node referencing multiple origins when called " *
# "has been defined:\n$(origins_). "

# initialize the list of upstream nodes:
nodes_ = AbstractNode[]
Expand Down Expand Up @@ -287,17 +287,24 @@ Node(operation, machine::M, args...) where M <: Union{NodalMachine,Nothing} =
Node(operation, args::AbstractNode...) = Node(operation, nothing, args...)

# make nodes callable:
(y::Node)(; rows=:) = (y.operation)(y.machine, [arg(rows=rows) for arg in y.args]...)
(y::Node)(; rows=:) =
(y.operation)(y.machine, [arg(rows=rows) for arg in y.args]...)
function (y::Node)(Xnew)
length(y.origins) == 1 ||
error("Nodes with multiple origins are not callable on new data. "*
"Use origins(node) to inspect. ")
error("Node $y has multiple origins and cannot be called "*
"on new data. ")
return (y.operation)(y.machine, [arg(Xnew) for arg in y.args]...)
end

# and for the special case of static operations:
(y::Node{Nothing})(; rows=:) = (y.operation)([arg(rows=rows) for arg in y.args]...)
(y::Node{Nothing})(Xnew) = (y.operation)([arg(Xnew) for arg in y.args]...)
(y::Node{Nothing})(; rows=:) =
(y.operation)([arg(rows=rows) for arg in y.args]...)
function (y::Node{Nothing})(Xnew)
length(y.origins) == 1 ||
error("Node $y has multiple origins and cannot be called "*
"on new data. ")
return (y.operation)([arg(Xnew) for arg in y.args]...)
end

"""
fit!(N::Node; rows=nothing, verbosity::Int=1, force::Bool=false)
Expand Down
8 changes: 6 additions & 2 deletions test/composites.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,14 +227,18 @@ ex = Meta.parse("Comp() <= z")
MLJ.from_network_preprocess(TestComposites,ex))

X2s = source(nothing)
z = @test_logs (:warn, r"^A node ref") vcat(Xs, X2s)
# z = @test_logs (:warn, r"^A node ref") vcat(Xs, X2s)
z = vcat(Xs, X2s)
ex = Meta.parse("Comp() <= z")
@test_throws(ArgumentError,
MLJ.from_network_preprocess(TestComposites, ex))


y2s = source(nothing, kind=:target)
z = @test_logs (:warn, r"^A node ref") vcat(ys, y2s, Xs)
# z = @test_logs (:warn, r"^A node ref") vcat(ys, y2s, Xs)
z = vcat(ys, y2s, Xs)
@test_throws Exception z(Xs())

ex = Meta.parse("Comp() <= z")
@test_throws(ArgumentError,
MLJ.from_network_preprocess(TestComposites, ex))
Expand Down

0 comments on commit 75e5a25

Please sign in to comment.