Skip to content

Commit

Permalink
improve plots recipes
Browse files Browse the repository at this point in the history
  • Loading branch information
Evizero committed Dec 2, 2016
1 parent 7dd53a2 commit cc99fe8
Showing 1 changed file with 37 additions and 8 deletions.
45 changes: 37 additions & 8 deletions src/supervised/io.jl
Expand Up @@ -2,12 +2,13 @@
Base.print(io::IO, loss::SupervisedLoss, args...) = print(io, typeof(loss).name.name, args...)
Base.print(io::IO, loss::L1DistLoss, args...) = print(io, "L1DistLoss", args...)
Base.print(io::IO, loss::L2DistLoss, args...) = print(io, "L2DistLoss", args...)
Base.print{P}(io::IO, loss::LPDistLoss{P}, args...) = print(io, typeof(loss).name.name, " with P=$(P)", args...)
Base.print(io::IO, loss::L1EpsilonInsLoss, args...) = print(io, typeof(loss).name.name, " with ɛ=$(loss.ε)", args...)
Base.print(io::IO, loss::L2EpsilonInsLoss, args...) = print(io, typeof(loss).name.name, " with ɛ=$(loss.ε)", args...)
Base.print(io::IO, loss::QuantileLoss, args...) = print(io, typeof(loss).name.name, " with τ=$(loss.τ)", args...)
Base.print(io::IO, loss::SmoothedL1HingeLoss, args...) = print(io, typeof(loss).name.name, " with γ = $(loss.gamma)", args...)
Base.print(io::IO, loss::PeriodicLoss, args...) = print(io, typeof(loss).name.name, " with circumf=$(round(loss.k / 2π,1))", args...)
Base.print{P}(io::IO, loss::LPDistLoss{P}, args...) = print(io, typeof(loss).name.name, " with P = $(P)", args...)
Base.print(io::IO, loss::L1EpsilonInsLoss, args...) = print(io, typeof(loss).name.name, " with \$\\varepsilon\$ = $(loss.ε)", args...)
Base.print(io::IO, loss::L2EpsilonInsLoss, args...) = print(io, typeof(loss).name.name, " with \$\\varepsilon\$ = $(loss.ε)", args...)
Base.print(io::IO, loss::QuantileLoss, args...) = print(io, typeof(loss).name.name, " with \$\\tau\$ = $(loss.τ)", args...)
Base.print(io::IO, loss::SmoothedL1HingeLoss, args...) = print(io, typeof(loss).name.name, " with \$\\gamma\$ = $(loss.gamma)", args...)
Base.print(io::IO, loss::DWDMarginLoss, args...) = print(io, typeof(loss).name.name, " with q = $(loss.q)", args...)
Base.print(io::IO, loss::PeriodicLoss, args...) = print(io, typeof(loss).name.name, " with circumf = $(round(loss.k / 2π,1))", args...)
Base.print(io::IO, loss::ScaledLoss, args...) = print(io, typeof(loss).name.name, " $(loss.factor) * [ $(loss.loss) ]", args...)

# -------------------------------------------------------------
Expand All @@ -16,18 +17,46 @@ Base.print(io::IO, loss::ScaledLoss, args...) = print(io, typeof(loss).name.name
_loss_xguide(loss::MarginLoss) = "y ⋅ h(x)"
_loss_xguide(loss::DistanceLoss) = "h(x) - y"

@recipe function plot(loss::SupervisedLoss, xmin = -2, xmax = 2)
@recipe function plot(drv::Deriv, xmin, xmax)
xguide --> _loss_xguide(drv.loss)
yguide --> "L'(y, h(x))"
label --> string(drv.loss)
deriv_fun(drv.loss), xmin, xmax
end

@recipe function plot(drv::Deriv, rng = -2:0.05:2)
xguide --> _loss_xguide(drv.loss)
yguide --> "L'(y, h(x))"
label --> string(drv.loss)
deriv_fun(drv.loss), rng
end

@recipe function plot(loss::SupervisedLoss, xmin, xmax)
xguide --> _loss_xguide(loss)
yguide --> "L(y, h(x))"
label --> string(loss)
value_fun(loss), xmin, xmax
end

@recipe function plot{T<:SupervisedLoss}(losses::AbstractVector{T}, xmin = -2, xmax = 2)
@recipe function plot(loss::SupervisedLoss, rng = -2:0.05:2)
xguide --> _loss_xguide(loss)
yguide --> "L(y, h(x))"
label --> string(loss)
value_fun(loss), rng
end

@recipe function plot{T<:SupervisedLoss}(losses::AbstractVector{T}, xmin, xmax)
for loss in losses
@series begin
loss, xmin, xmax
end
end
end
@recipe function plot{T<:SupervisedLoss}(losses::AbstractVector{T}, rng = -2:0.05:2)
for loss in losses
@series begin
loss, rng
end
end
end

0 comments on commit cc99fe8

Please sign in to comment.