-
Notifications
You must be signed in to change notification settings - Fork 13
/
generic.jl
179 lines (145 loc) · 5.08 KB
/
generic.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
export Cost,
NoLoss, NoPenalty,
AtomicLoss, AtomicPenalty,
ScaledLoss, ScaledPenalty,
CompositeLoss, CompositePenalty
# Root of the type hierarchy in this file: `Loss`, `Penalty` and
# `ObjectiveFunction` are all declared as subtypes of `Cost`, which lets the
# generic `*`, `-` and `/` methods at the end of the file accept any of them.
abstract type Cost end
# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = =
# Loss: (x, y) -> L(x, y)
# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = =
# Supertype of every loss.
abstract type Loss <: Cost end
# Placeholder for "no loss": `getscale` reports 0.0 for it, adding it to another
# loss returns that loss unchanged, and scaling it yields `NoLoss()` again.
struct NoLoss <: Loss end
# Supertype of elementary losses; `getscale` reports 1.0 for them and
# multiplying one by a real number wraps it in a `ScaledLoss`.
abstract type AtomicLoss <: Loss end
# An atomic loss multiplied by a constant factor.
#
# Fields:
#   loss  -- the wrapped atomic loss
#   scale -- multiplicative factor applied to the value of `loss`
#
# The bound on the type parameter is written inside the braces. A trailing
# `where AL <: AtomicLoss` after the supertype attaches to the supertype
# expression instead of the parameter, which leaves the parameter unconstrained.
mutable struct ScaledLoss{AL<:AtomicLoss} <: Loss
    loss::AL
    scale::Float64
end
# A sum of scaled losses.
#
# Fields:
#   losses -- the scaled terms; calling the composite adds up their values
mutable struct CompositeLoss <: Loss
    losses::Vector{ScaledLoss}
end
# Evaluating a scaled loss: value of the wrapped loss times the scale.
function (sl::ScaledLoss)(x::AVR, y::AVR)
    return sl.scale * sl.loss(x, y)
end

# Evaluating a composite loss: sum of the values of its scaled terms.
function (cl::CompositeLoss)(x::AVR, y::AVR)
    return sum(term(x, y) for term in cl.losses)
end
# Scale factor attached to a loss:
#   NoLoss     -> 0.0
#   AtomicLoss -> 1.0 (an unscaled loss counts once)
#   ScaledLoss -> its stored `scale`
function getscale(n::NoLoss)
    return 0.0
end

function getscale(l::AtomicLoss)
    return 1.0
end

function getscale(l::ScaledLoss)
    return l.scale
end
# Convenient extension for classification
# Atomic losses carrying a type parameter `c` (presumably the number of
# classes -- confirm against the concrete subtypes). `getc` reads `c` back, and
# `c == 0` makes the two-argument `getc` use `maximum(y)` instead.
abstract type MultiClassLoss{c} <: AtomicLoss end
# Generic fallbacks: anything without a more specific `getc` method reports 0,
# with or without a second argument.
function getc(m)
    return 0
end

function getc(m, y)
    return 0
end
# Read the type parameter back from a multi-class loss.
function getc(m::MultiClassLoss{c}) where {c}
    return c
end

# Parameter equal to 0: take the value from the data as `maximum(y)`.
function getc(m::MultiClassLoss{0}, y)
    return maximum(y)
end

# Any other parameter value: return it and ignore `y`.
function getc(m::MultiClassLoss{c}, y) where {c}
    return c
end
# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = =
# Penalty: θ -> P(θ)
# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = =
# Supertype of every penalty.
abstract type Penalty <: Cost end
# Placeholder for "no penalty": `getscale` reports 0.0 for it, adding it to
# another penalty returns that penalty unchanged, and scaling it yields
# `NoPenalty()` again.
struct NoPenalty <: Penalty end
# Supertype of elementary penalties; `getscale` reports 1.0 for them and
# multiplying one by a real number wraps it in a `ScaledPenalty`.
abstract type AtomicPenalty <: Penalty end
# An atomic penalty multiplied by a constant factor.
#
# Fields:
#   penalty -- the wrapped atomic penalty
#   scale   -- multiplicative factor applied to the value of `penalty`
#
# The bound on the type parameter is written inside the braces. A trailing
# `where AP <: AtomicPenalty` after the supertype attaches to the supertype
# expression instead of the parameter, which leaves the parameter unconstrained.
mutable struct ScaledPenalty{AP<:AtomicPenalty} <: Penalty
    penalty::AP
    scale::Float64
end
# A sum of scaled penalties.
#
# Fields:
#   penalties -- the scaled terms; calling the composite adds up their values
mutable struct CompositePenalty <: Penalty
    penalties::Vector{ScaledPenalty}
end
# Evaluating a scaled penalty: value of the wrapped penalty times the scale.
function (sp::ScaledPenalty)(θ::AVR)
    return sp.scale * sp.penalty(θ)
end

# Evaluating a composite penalty: sum of the values of its scaled terms.
function (cl::CompositePenalty)(θ::AVR)
    return sum(term(θ) for term in cl.penalties)
end
# Scale factor attached to a penalty:
#   NoPenalty     -> 0.0
#   AtomicPenalty -> 1.0 (an unscaled penalty counts once)
#   ScaledPenalty -> its stored `scale`
function getscale(n::NoPenalty)
    return 0.0
end

function getscale(p::AtomicPenalty)
    return 1.0
end

function getscale(p::ScaledPenalty)
    return p.scale
end
# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = =
# Objective function: (x, y, θ) -> L(x, y) + P(θ)
# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = =
# A loss paired with a penalty.
#
# Fields:
#   loss    -- the loss part, evaluated on the first two call arguments
#   penalty -- the penalty part, evaluated on the third call argument
mutable struct ObjectiveFunction{L<:Loss,P<:Penalty} <: Cost
    loss::L
    penalty::P
end

# Evaluating an objective: loss of (y, ŷ) plus penalty of θ.
function (J::ObjectiveFunction)(y, ŷ, θ)
    loss_value = J.loss(y, ŷ)
    penalty_value = J.penalty(θ)
    return loss_value + penalty_value
end
# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = =
# Composition of Loss & Penalty functions
# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = =
# Short aliases that keep the operator signatures below compact.
# Losses and penalties mirror each other: A = Atomic, S = Scaled,
# C = Composite, N = No; OF is the objective function.
const AL = AtomicLoss
const AP = AtomicPenalty
const SL = ScaledLoss
const CL = CompositeLoss
const SP = ScaledPenalty
const CP = CompositePenalty
const NL = NoLoss
const NP = NoPenalty
const OF = ObjectiveFunction
# Promote an atomic term to its scaled wrapper with a scale of 1.0, so that the
# `+` methods only need to handle scaled and composite terms.
function scale1(a::AL)
    return ScaledLoss(a, 1.0)
end

function scale1(a::AP)
    return ScaledPenalty(a, 1.0)
end
# Combinations with NoLoss (NL)
# Scaling NoLoss gives NoLoss; adding NoLoss to a loss returns that loss as is.
*(::NL, ::Real) = NoLoss()
+(::NL, l::Loss) = l
+(l::Loss, ::NL) = l
# Disambiguation: `NoLoss() + NoLoss()` matches both `+` methods above equally
# well, so without this method the call fails with an ambiguity MethodError
# (reachable e.g. through `OF + OF` or `-` when both sides carry a NoLoss).
+(::NL, ::NL) = NoLoss()
# Combinations with NoPenalty (NP)
# Scaling NoPenalty gives NoPenalty; adding NoPenalty to a penalty returns that
# penalty as is.
*(::NP, ::Real) = NoPenalty()
+(::NP, p::Penalty) = p
+(p::Penalty, ::NP) = p
# Disambiguation: `NoPenalty() + NoPenalty()` matches both `+` methods above
# equally well, so without this method the call fails with an ambiguity
# MethodError (reachable e.g. through `OF + OF` when both sides carry a
# NoPenalty).
+(::NP, ::NP) = NoPenalty()
# Combinations with AtomicLoss (AL)
# An atomic loss is first promoted to a unit-scale ScaledLoss; the actual
# merging is done by the scaled / composite `+` methods.
function +(a::AL, b::AL)
    return scale1(a) + scale1(b)
end

function +(a::AL, b::Union{SL,CL})
    return scale1(a) + b
end

# Mirror of the method above with the arguments swapped.
function +(b::Union{SL,CL}, a::AL)
    return a + b
end

# Scaling an atomic loss wraps it; the factor is converted with `float`.
function *(a::AL, c::Real)
    return ScaledLoss(a, float(c))
end
# Combinations with AtomicPenalty (AP)
# An atomic penalty is first promoted to a unit-scale ScaledPenalty; the actual
# merging is done by the scaled / composite `+` methods.
function +(a::AP, b::AP)
    return scale1(a) + scale1(b)
end

function +(a::AP, b::Union{SP,CP})
    return scale1(a) + b
end

# Mirror of the method above with the arguments swapped.
function +(b::Union{SP,CP}, a::AP)
    return a + b
end

# Scaling an atomic penalty wraps it; the factor is converted with `float`.
function *(a::AP, c::Real)
    return ScaledPenalty(a, float(c))
end
# Combinations with Scaled Losses and Combined Losses
# Same wrapped type on both sides: keep `a`'s loss object and add the scales.
function +(a::SL{T}, b::SL{T}) where {T}
    return ScaledLoss(a.loss, a.scale + b.scale)
end

# Different wrapped types: the two terms become a two-element composite.
function +(a::SL{T1}, b::SL{T2}) where {T1,T2}
    return CL([a, b])
end
# Add two composite losses.
#
# Each term of `b` whose concrete type already occurs among the terms of `a` is
# folded into the first such term of `a`; the remaining terms of `b` are
# appended after `a`'s terms in their original order. Neither input is mutated.
function +(a::CL, b::CL)
    merged = copy(a.losses)
    # Matching only looks at the types present in `a`, so terms of `b` are never
    # merged with each other.
    kinds = map(typeof, a.losses)
    for term in b.losses
        slot = findfirst(==(typeof(term)), kinds)
        if slot === nothing
            push!(merged, term)
        else
            merged[slot] = merged[slot] + term # will be SL{T} + SL{T}
        end
    end
    return CL(merged)
end
# Scaled + composite: wrap the scaled term in a one-element composite and reuse
# the composite + composite merge.
function +(a::SL, c::CL)
    return CL([a]) + c
end

# Mirror of the method above with the arguments swapped.
function +(c::CL, a::SL)
    return a + c
end

# Scaling a scaled loss multiplies its stored scale.
function *(a::SL, c::Real)
    return SL(a.loss, c * a.scale)
end

# Scaling a composite loss scales every term.
function *(a::CL, c::Real)
    return CL(a.losses .* c)
end
# Combinations with Scaled Penalties and Combined Penalties
# Same wrapped type on both sides: keep `a`'s penalty object and add the scales.
function +(a::SP{T}, b::SP{T}) where {T}
    return ScaledPenalty(a.penalty, a.scale + b.scale)
end

# Different wrapped types: the two terms become a two-element composite.
function +(a::SP{T1}, b::SP{T2}) where {T1,T2}
    return CP([a, b])
end
# Add two composite penalties.
#
# Each term of `b` whose concrete type already occurs among the terms of `a` is
# folded into the first such term of `a`; the remaining terms of `b` are
# appended after `a`'s terms in their original order. Neither input is mutated.
function +(a::CP, b::CP)
    merged = copy(a.penalties)
    # Matching only looks at the types present in `a`, so terms of `b` are never
    # merged with each other.
    kinds = map(typeof, a.penalties)
    for term in b.penalties
        slot = findfirst(==(typeof(term)), kinds)
        if slot === nothing
            push!(merged, term)
        else
            merged[slot] = merged[slot] + term # will be SP{T} + SP{T}
        end
    end
    return CP(merged)
end
# Scaled + composite: wrap the scaled term in a one-element composite and reuse
# the composite + composite merge.
function +(a::SP, c::CP)
    return CP([a]) + c
end

# Mirror of the method above with the arguments swapped.
function +(c::CP, a::SP)
    return a + c
end

# Scaling a scaled penalty multiplies its stored scale.
function *(a::SP, c::Real)
    return SP(a.penalty, c * a.scale)
end

# Scaling a composite penalty scales every term.
function *(a::CP, c::Real)
    return CP(a.penalties .* c)
end
# higher combinations ==> OF
# A loss plus a penalty forms an objective function, in either argument order.
function +(l::Loss, p::Penalty)
    return OF(l, p)
end

function +(p::Penalty, l::Loss)
    return l + p
end

# Adding a loss to an objective only touches its loss part.
function +(o::OF, l::Loss)
    return OF(o.loss + l, o.penalty)
end

function +(l::Loss, o::OF)
    return o + l
end

# Adding a penalty to an objective only touches its penalty part.
function +(o::OF, p::Penalty)
    return OF(o.loss, o.penalty + p)
end

function +(p::Penalty, o::OF)
    return o + p
end

# Two objectives are added part by part.
function +(a::OF, b::OF)
    return OF(a.loss + b.loss, a.penalty + b.penalty)
end

# Scaling an objective scales both of its parts.
function *(o::OF, a::Real)
    return OF(a * o.loss, a * o.penalty)
end
# Symmetric relations
# `number * cost` forwards to the `cost * number` methods.
function *(a::Real, c::Cost)
    return c * a
end

# - and / operations with Objective Functions (just use + and *)
# Subtraction adds the right-hand side scaled by -1.
function -(a::Cost, b::Cost)
    return a + (-1 * b)
end

# Division by a number multiplies by its reciprocal.
function /(a::Cost, c::Real)
    return a * (1 / c)
end