/
regressor.jl
167 lines (138 loc) · 6.38 KB
/
regressor.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# Deterministic neural-network regressor for a single `Continuous` target.
# Mutable so MLJ can tune hyperparameters in place between calls to `update`.
mutable struct NeuralNetworkRegressor{B<:Builder,O,L} <: MLJModelInterface.Deterministic
    builder::B      # strategy producing the Flux chain, via `fit(builder, n_in, n_out)`
    optimiser::O    # mutable struct from Flux/src/optimise/optimisers.jl
    loss::L         # can be called as in `loss(yhat, y)`
    epochs::Int     # number of epochs
    batch_size::Int # size of a batch
    lambda::Float64 # regularization strength
    alpha::Float64  # regularizaton mix (0 for all l2, 1 for all l1)
    optimiser_changes_trigger_retraining::Bool # if true, a changed optimiser forces cold restart
end
# Keyword constructor supplying the documented hyperparameter defaults.
function NeuralNetworkRegressor(;
    builder::B=Linear(),
    optimiser::O=Flux.Optimise.ADAM(),
    loss::L=Flux.mse,
    epochs=10,
    batch_size=1,
    lambda=0,
    alpha=0,
    optimiser_changes_trigger_retraining=false,
) where {B,O,L}
    return NeuralNetworkRegressor{B,O,L}(
        builder,
        optimiser,
        loss,
        epochs,
        batch_size,
        lambda,
        alpha,
        optimiser_changes_trigger_retraining,
    )
end
# Deterministic neural-network regressor for multiple `Continuous` targets,
# presented as a table. Field meanings match `NeuralNetworkRegressor`.
mutable struct MultitargetNeuralNetworkRegressor{B<:Builder,O,L} <: MLJModelInterface.Deterministic
    builder::B      # strategy producing the Flux chain, via `fit(builder, n_in, n_out)`
    optimiser::O    # mutable struct from Flux/src/optimise/optimisers.jl
    loss::L         # can be called as in `loss(yhat, y)`
    epochs::Int     # number of epochs
    batch_size::Int # size of a batch
    lambda::Float64 # regularization strength
    alpha::Float64  # regularizaton mix (0 for all l2, 1 for all l1)
    optimiser_changes_trigger_retraining::Bool # if true, a changed optimiser forces cold restart
end
# Keyword constructor supplying the documented hyperparameter defaults.
function MultitargetNeuralNetworkRegressor(;
    builder::B=Linear(),
    optimiser::O=Flux.Optimise.ADAM(),
    loss::L=Flux.mse,
    epochs=10,
    batch_size=1,
    lambda=0,
    alpha=0,
    optimiser_changes_trigger_retraining=false,
) where {B,O,L}
    return MultitargetNeuralNetworkRegressor{B,O,L}(
        builder,
        optimiser,
        loss,
        epochs,
        batch_size,
        lambda,
        alpha,
        optimiser_changes_trigger_retraining,
    )
end
# Convenience union: `fit`, `update` and `predict` below are shared by both
# the single-target and multitarget regressor variants.
const Regressor =
    Union{NeuralNetworkRegressor, MultitargetNeuralNetworkRegressor}
"""
    MLJModelInterface.fit(model::Regressor, verbosity::Int, X, y)

Train a fresh Flux chain for `model` on the feature table `X` and target `y`
(a vector, or a table in the multitarget case). All features are assumed
continuous (no categoricals). Returns `(fitresult, cache, report)` per the
MLJ model API; `report` carries the per-epoch training losses.
"""
function MLJModelInterface.fit(model::Regressor, verbosity::Int, X, y)
    # (assumes no categorical features)
    n_input = length(Tables.schema(X).names)
    data = collate(model, X, y)

    target_is_multivariate = Tables.istable(y)
    # dummy name for the univariate case — never consumed downstream
    target_column_names =
        target_is_multivariate ? Tables.schema(y).names : [""]
    n_output = length(target_column_names)

    chain = fit(model.builder, n_input, n_output)

    # train on a copy so the user's optimiser is not mutated
    optimiser = deepcopy(model.optimiser)
    chain, history = fit!(chain, optimiser, model.loss, model.epochs,
                          model.lambda, model.alpha, verbosity, data)

    cache = (deepcopy(model), data, history, n_input, n_output)
    fitresult = (chain, target_is_multivariate, target_column_names)
    report = (training_losses=[l.data for l in history],)
    return fitresult, cache, report
end
"""
    MLJModelInterface.update(model::Regressor, verbosity, old_fitresult, old_cache, X, y)

Warm-restart training when only `epochs` (and possibly `optimiser`, unless
`optimiser_changes_trigger_retraining` is set) have changed and the epoch
count did not decrease; otherwise rebuild the chain and retrain from scratch.
"""
function MLJModelInterface.update(model::Regressor, verbosity::Int,
                                  old_fitresult, old_cache, X, y)
    old_model, data, old_history, n_input, n_output = old_cache
    old_chain, target_is_multivariate, target_column_names = old_fitresult

    # an optimiser change forces a cold restart only if the flag says so
    optimiser_forces_retrain = model.optimiser_changes_trigger_retraining &&
        model.optimiser != old_model.optimiser

    keep_chain = !optimiser_forces_retrain &&
        model.epochs >= old_model.epochs &&
        MLJModelInterface.is_same_except(model, old_model, :optimiser, :epochs)

    if keep_chain
        chain = old_chain
        epochs = model.epochs - old_model.epochs # only the extra epochs
    else
        chain = fit(model.builder, n_input, n_output)
        data = collate(model, X, y)
        epochs = model.epochs
    end

    optimiser = deepcopy(model.optimiser)
    chain, history = fit!(chain, optimiser, model.loss, epochs,
                          model.lambda, model.alpha, verbosity, data)

    # a warm restart continues the old loss history
    keep_chain && (history = vcat(old_history, history))

    fitresult = (chain, target_is_multivariate, target_column_names)
    cache = (deepcopy(model), data, history, n_input, n_output)
    report = (training_losses=[l.data for l in history],)
    return fitresult, cache, report
end
"""
    MLJModelInterface.predict(model::Regressor, fitresult, Xnew_)

Apply the trained chain row-by-row to the feature table `Xnew_`. Returns a
table with the training target column names in the multitarget case, and a
plain vector otherwise. `.data` unwraps tracked Flux values.
"""
function MLJModelInterface.predict(model::Regressor, fitresult, Xnew_)
    chain, target_is_multivariate, target_column_names = fitresult
    Xmat = MLJModelInterface.matrix(Xnew_)
    nrows = size(Xmat, 1)
    if target_is_multivariate
        rows = [map(t -> t.data, chain(values.(Xmat[i, :]))) for i in 1:nrows]
        return MLJModelInterface.table(reduce(hcat, r for r in rows)',
                                       names=target_column_names)
    end
    return [chain(values.(Xmat[i, :]))[1].data for i in 1:nrows]
end
# Register MLJ trait metadata (scitypes, load path, docstring) for the
# single-target regressor.
MLJModelInterface.metadata_model(NeuralNetworkRegressor,
    input=Table(Continuous),
    target=AbstractVector{<:Continuous},
    path="MLJFlux.NeuralNetworkRegressor",
    descr="A neural network model for making "*
          "deterministic predictions of a "*
          "`Continuous` target, given a table of "*
          "`Continuous` features. ")
# Register MLJ trait metadata for the multitarget regressor.
# Fix: `path` previously said "MLJFlux.NeuralNetworkRegressor" (copy-paste
# from the single-target registration), which would make MLJ's model registry
# load the wrong model type. It must name this model.
MLJModelInterface.metadata_model(MultitargetNeuralNetworkRegressor,
    input=Table(Continuous),
    target=Table(Continuous),
    path="MLJFlux.MultitargetNeuralNetworkRegressor",
    descr = "A neural network model for making "*
            "deterministic predictions of a "*
            "`Continuous` multi-target, presented "*
            "as a table, given a table of "*
            "`Continuous` features. ")