From cc99fe8aa483b686d16d204f3e6281fb0dd86056 Mon Sep 17 00:00:00 2001 From: Christof Stocker Date: Fri, 2 Dec 2016 03:36:10 +0100 Subject: [PATCH] improve plots recipes --- src/supervised/io.jl | 45 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 37 insertions(+), 8 deletions(-) diff --git a/src/supervised/io.jl b/src/supervised/io.jl index 476dd92..8292fb4 100644 --- a/src/supervised/io.jl +++ b/src/supervised/io.jl @@ -2,12 +2,13 @@ Base.print(io::IO, loss::SupervisedLoss, args...) = print(io, typeof(loss).name.name, args...) Base.print(io::IO, loss::L1DistLoss, args...) = print(io, "L1DistLoss", args...) Base.print(io::IO, loss::L2DistLoss, args...) = print(io, "L2DistLoss", args...) -Base.print{P}(io::IO, loss::LPDistLoss{P}, args...) = print(io, typeof(loss).name.name, " with P=$(P)", args...) -Base.print(io::IO, loss::L1EpsilonInsLoss, args...) = print(io, typeof(loss).name.name, " with ɛ=$(loss.ε)", args...) -Base.print(io::IO, loss::L2EpsilonInsLoss, args...) = print(io, typeof(loss).name.name, " with ɛ=$(loss.ε)", args...) -Base.print(io::IO, loss::QuantileLoss, args...) = print(io, typeof(loss).name.name, " with τ=$(loss.τ)", args...) -Base.print(io::IO, loss::SmoothedL1HingeLoss, args...) = print(io, typeof(loss).name.name, " with γ = $(loss.gamma)", args...) -Base.print(io::IO, loss::PeriodicLoss, args...) = print(io, typeof(loss).name.name, " with circumf=$(round(loss.k / 2π,1))", args...) +Base.print{P}(io::IO, loss::LPDistLoss{P}, args...) = print(io, typeof(loss).name.name, " with P = $(P)", args...) +Base.print(io::IO, loss::L1EpsilonInsLoss, args...) = print(io, typeof(loss).name.name, " with \$\\varepsilon\$ = $(loss.ε)", args...) +Base.print(io::IO, loss::L2EpsilonInsLoss, args...) = print(io, typeof(loss).name.name, " with \$\\varepsilon\$ = $(loss.ε)", args...) +Base.print(io::IO, loss::QuantileLoss, args...) = print(io, typeof(loss).name.name, " with \$\\tau\$ = $(loss.τ)", args...) +Base.print(io::IO, loss::SmoothedL1HingeLoss, args...) = print(io, typeof(loss).name.name, " with \$\\gamma\$ = $(loss.gamma)", args...) +Base.print(io::IO, loss::DWDMarginLoss, args...) = print(io, typeof(loss).name.name, " with q = $(loss.q)", args...) +Base.print(io::IO, loss::PeriodicLoss, args...) = print(io, typeof(loss).name.name, " with circumf = $(round(loss.k / 2π,1))", args...) Base.print(io::IO, loss::ScaledLoss, args...) = print(io, typeof(loss).name.name, " $(loss.factor) * [ $(loss.loss) ]", args...) # ------------------------------------------------------------- @@ -16,18 +17,46 @@ Base.print(io::IO, loss::ScaledLoss, args...) = print(io, typeof(loss).name.name _loss_xguide(loss::MarginLoss) = "y ⋅ h(x)" _loss_xguide(loss::DistanceLoss) = "h(x) - y" -@recipe function plot(loss::SupervisedLoss, xmin = -2, xmax = 2) +@recipe function plot(drv::Deriv, xmin, xmax) + xguide --> _loss_xguide(drv.loss) + yguide --> "L'(y, h(x))" + label --> string(drv.loss) + deriv_fun(drv.loss), xmin, xmax +end + +@recipe function plot(drv::Deriv, rng = -2:0.05:2) + xguide --> _loss_xguide(drv.loss) + yguide --> "L'(y, h(x))" + label --> string(drv.loss) + deriv_fun(drv.loss), rng +end + +@recipe function plot(loss::SupervisedLoss, xmin, xmax) xguide --> _loss_xguide(loss) yguide --> "L(y, h(x))" label --> string(loss) value_fun(loss), xmin, xmax end -@recipe function plot{T<:SupervisedLoss}(losses::AbstractVector{T}, xmin = -2, xmax = 2) +@recipe function plot(loss::SupervisedLoss, rng = -2:0.05:2) + xguide --> _loss_xguide(loss) + yguide --> "L(y, h(x))" + label --> string(loss) + value_fun(loss), rng +end + +@recipe function plot{T<:SupervisedLoss}(losses::AbstractVector{T}, xmin, xmax) for loss in losses @series begin loss, xmin, xmax end end end +@recipe function plot{T<:SupervisedLoss}(losses::AbstractVector{T}, rng = -2:0.05:2) + for loss in losses + @series begin + loss, rng + end + end +end