-
-
Notifications
You must be signed in to change notification settings - Fork 182
/
GridSearch.php
358 lines (305 loc) · 9.04 KB
/
GridSearch.php
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
<?php
namespace Rubix\ML;
use Rubix\ML\Helpers\Params;
use Rubix\ML\Backends\Serial;
use Rubix\ML\Datasets\Dataset;
use Rubix\ML\Traits\LoggerAware;
use Rubix\ML\CrossValidation\KFold;
use Rubix\ML\Traits\Multiprocessing;
use Rubix\ML\Traits\AutotrackRevisions;
use Rubix\ML\CrossValidation\Validator;
use Rubix\ML\Backends\Tasks\CrossValidate;
use Rubix\ML\CrossValidation\Metrics\RMSE;
use Rubix\ML\CrossValidation\Metrics\FBeta;
use Rubix\ML\CrossValidation\Metrics\Metric;
use Rubix\ML\Specifications\DatasetIsLabeled;
use Rubix\ML\CrossValidation\Metrics\Accuracy;
use Rubix\ML\CrossValidation\Metrics\VMeasure;
use Rubix\ML\Specifications\DatasetIsNotEmpty;
use Rubix\ML\Specifications\SpecificationChain;
use Rubix\ML\Specifications\LabelsAreCompatibleWithLearner;
use Rubix\ML\Specifications\EstimatorIsCompatibleWithMetric;
use Rubix\ML\Specifications\SamplesAreCompatibleWithEstimator;
use Rubix\ML\Exceptions\InvalidArgumentException;
/**
 * Grid Search
 *
 * Grid Search is an algorithm that optimizes hyper-parameter selection. From
 * the user's perspective, the process of training and predicting is the same,
 * however, under the hood, Grid Search cross validates one estimator per
 * combination of parameters and then trains the best scoring combination as
 * the base estimator.
 *
 * > **Note:** You can choose the hyper-parameters manually or you can generate
 * them randomly or in a grid using the Params helper.
 *
 * @category    Machine Learning
 * @package     Rubix/ML
 * @author      Andrew DalPino
 */
class GridSearch implements EstimatorWrapper, Learner, Parallel, Verbose, Persistable
{
    use AutotrackRevisions, Multiprocessing, LoggerAware;

    /**
     * The class name of the base estimator.
     *
     * @var string
     */
    protected string $class;

    /**
     * An array of lists containing the possible values for each of the base learner's constructor parameters.
     *
     * @var list<list<mixed>>
     */
    protected array $params;

    /**
     * The validation metric used to score the estimator.
     *
     * @var Metric
     */
    protected Metric $metric;

    /**
     * The validator used to test the estimator.
     *
     * @var Validator
     */
    protected Validator $validator;

    /**
     * The base estimator instance.
     *
     * @var Learner
     */
    protected Learner $base;

    /**
     * The validation scores obtained from the last search.
     *
     * NOTE(review): no method in this class body assigns this property - confirm
     * whether it is populated elsewhere (trait/subclass) or is vestigial.
     *
     * @var list<float>|null
     */
    protected ?array $scores = null;

    /**
     * Return an array of all possible combinations of parameters. i.e their Cartesian product.
     *
     * Each returned combination is keyed by the position of the parameter in
     * the input, so an empty input yields a single empty combination.
     *
     * @param list<list<mixed>> $params
     * @return list<list<mixed>>
     */
    protected static function combine(array $params) : array
    {
        $combinations = [[]];

        /** @var int<0,max> $i */
        foreach ($params as $i => $values) {
            $append = [];

            // Extend every partial combination built so far with each
            // candidate value of the current parameter.
            foreach ($combinations as $product) {
                foreach ($values as $value) {
                    $product[$i] = $value;

                    $append[] = $product;
                }
            }

            $combinations = $append;
        }

        return $combinations;
    }

    /**
     * @param class-string $class
     * @param array<mixed[]> $params
     * @param Metric|null $metric
     * @param Validator|null $validator
     * @throws InvalidArgumentException
     */
    public function __construct(
        string $class,
        array $params,
        ?Metric $metric = null,
        ?Validator $validator = null
    ) {
        if (!class_exists($class)) {
            throw new InvalidArgumentException("Class $class does not exist.");
        }

        // Verify the contract before instantiating so that an unrelated class
        // is rejected with the intended exception rather than whatever its
        // constructor happens to throw.
        if (!is_a($class, Learner::class, true)) {
            throw new InvalidArgumentException('Base class must'
                . ' implement the Learner Interface.');
        }

        $params = array_values($params);

        // An empty tuple means "pass null", duplicates are dropped so that the
        // same combination is not cross validated more than once.
        foreach ($params as &$tuple) {
            $tuple = empty($tuple) ? [null] : array_unique($tuple, SORT_REGULAR);
        }

        // Break the reference so the last tuple is not left aliased.
        unset($tuple);

        // The proxy is built from the first candidate of every (normalized)
        // tuple and is only used for introspection until training replaces it.
        /** @var Learner $proxy */
        $proxy = new $class(...array_map(
            static fn (array $tuple) => reset($tuple),
            $params
        ));

        if ($metric) {
            EstimatorIsCompatibleWithMetric::with($proxy, $metric)->check();
        } else {
            // Loose comparison is intentional here: the cases are objects
            // returned from factory calls, not scalar constants.
            switch ($proxy->type()) {
                case EstimatorType::classifier():
                case EstimatorType::anomalyDetector():
                    $metric = new FBeta();

                    break;

                case EstimatorType::regressor():
                    $metric = new RMSE();

                    break;

                case EstimatorType::clusterer():
                    $metric = new VMeasure();

                    break;

                default:
                    $metric = new Accuracy();
            }
        }

        $this->class = $class;
        $this->params = $params;
        $this->metric = $metric;
        $this->validator = $validator ?? new KFold(3);
        $this->base = $proxy;
        $this->backend = new Serial();
    }

    /**
     * Return the estimator type.
     *
     * @internal
     *
     * @return EstimatorType
     */
    public function type() : EstimatorType
    {
        return $this->base->type();
    }

    /**
     * Return the data types that the estimator is compatible with.
     *
     * Until the base estimator has been trained every data type is reported,
     * afterwards the answer is delegated to the base estimator.
     *
     * @internal
     *
     * @return list<\Rubix\ML\DataType>
     */
    public function compatibility() : array
    {
        return $this->trained()
            ? $this->base->compatibility()
            : DataType::all();
    }

    /**
     * Return the settings of the hyper-parameters in an associative array.
     *
     * @internal
     *
     * @return mixed[]
     */
    public function params() : array
    {
        return [
            'class' => $this->class,
            'params' => $this->params,
            'metric' => $this->metric,
            'validator' => $this->validator,
        ];
    }

    /**
     * Has the learner been trained?
     *
     * @return bool
     */
    public function trained() : bool
    {
        return $this->base->trained();
    }

    /**
     * Return the base learner instance.
     *
     * @return Estimator
     */
    public function base() : Estimator
    {
        return $this->base;
    }

    /**
     * Cross validate one estimator per combination of parameters given by the
     * grid, then train the highest scoring combination on the whole dataset
     * and assign it as the base estimator of this instance.
     *
     * @param Datasets\Labeled $dataset
     */
    public function train(Dataset $dataset) : void
    {
        SpecificationChain::with([
            new DatasetIsLabeled($dataset),
            new DatasetIsNotEmpty($dataset),
            new SamplesAreCompatibleWithEstimator($dataset, $this),
            new LabelsAreCompatibleWithLearner($dataset, $this),
        ])->check();

        if ($this->logger) {
            $this->logger->info("Training $this");
        }

        $combinations = self::combine($this->params);

        $this->backend->flush();

        foreach ($combinations as $params) {
            /** @var Learner $estimator */
            $estimator = new $this->class(...$params);

            $task = new CrossValidate(
                $estimator,
                $dataset,
                $this->validator,
                $this->metric
            );

            $this->backend->enqueue(
                $task,
                [$this, 'afterScore'],
                $estimator->params()
            );
        }

        $scores = $this->backend->process();

        // Sort the scores from highest to lowest and reorder the combinations
        // alongside them so the best combination ends up first.
        array_multisort($scores, SORT_DESC, $combinations);

        $best = reset($combinations) ?: [];

        /** @var Learner $estimator */
        $estimator = new $this->class(...array_values($best));

        if ($this->logger) {
            $this->logger->info('Training with best hyper-parameters');
        }

        $estimator->train($dataset);

        $this->base = $estimator;

        if ($this->logger) {
            $this->logger->info('Training complete');
        }
    }

    /**
     * Make a prediction on a given sample dataset.
     *
     * @param Dataset $dataset
     * @throws Exceptions\RuntimeException
     * @return mixed[]
     */
    public function predict(Dataset $dataset) : array
    {
        return $this->base->predict($dataset);
    }

    /**
     * The callback that executes after the cross validation task. Logs the
     * score together with the parameters that produced it when a logger is set.
     *
     * @internal
     *
     * @param float $score
     * @param mixed[] $params
     */
    public function afterScore(float $score, array $params) : void
    {
        if ($this->logger) {
            $this->logger->info("{$this->metric}: $score, "
                . 'params: [' . Params::stringify($params) . ']');
        }
    }

    /**
     * Allow methods to be called on the estimator from the wrapper.
     *
     * @param string $name
     * @param mixed[] $arguments
     * @return mixed
     */
    public function __call(string $name, array $arguments)
    {
        return $this->base->$name(...$arguments);
    }

    /**
     * Return the string representation of the object.
     *
     * @internal
     *
     * @return string
     */
    public function __toString() : string
    {
        return 'Grid Search (' . Params::stringify($this->params()) . ')';
    }
}