-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.php
89 lines (65 loc) · 2.5 KB
/
train.php
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
<?php
include __DIR__ . '/vendor/autoload.php';
use Rubix\ML\Loggers\Screen;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Datasets\Unlabeled;
use Rubix\ML\Extractors\CSV;
use Rubix\ML\Extractors\ColumnPicker;
use Rubix\ML\Transformers\LambdaFunction;
use Rubix\ML\PersistentModel;
use Rubix\ML\Persisters\Filesystem;
use Rubix\ML\Serializers\RBX;
use Rubix\ML\Transformers\NumericStringConverter;
use Rubix\ML\Transformers\MissingDataImputer;
use Rubix\ML\Transformers\MinMaxNormalizer;
use Rubix\ML\Transformers\OneHotEncoder;
use Rubix\ML\Classifiers\RandomForest;
use Rubix\ML\Classifiers\ClassificationTree;
use Rubix\ML\CrossValidation\Metrics\Accuracy;
// Remove PHP's memory ceiling for this run ('-1' disables the limit).
ini_set('memory_limit', '-1');

// Shared helpers: RBX serializer for persisting objects, console logger for progress.
$serializer = new RBX();
$logger = new Screen();

$logger->info('Loading data into memory');

// Columns pulled from the CSV; 'Survived' is listed last and is the
// value later mapped to a class label.
$columns = [
    'Pclass', 'Age', 'Fare', 'SibSp', 'Parch', 'Sex', 'Embarked', 'Survived',
];

// 'train.csv' is resolved relative to the current working directory.
// The second CSV argument is `true` — presumably the "file has a header row"
// flag of the Rubix ML CSV extractor; confirm against the installed version.
$extractor = new ColumnPicker(new CSV('train.csv', true), $columns);

$logger->info('Processing features');
/**
 * Replace missing cells of a single sample with a placeholder value, in place.
 *
 * A cell counts as missing only when it is null or the empty string. Zero
 * values (0, 0.0, '0') are legitimate data and are left untouched; the
 * previous `empty()` check wrongly treated them as missing.
 *
 * @param array $sample Feature values of one row, modified by reference.
 * @param mixed $offset Row offset supplied by the caller (unused).
 * @param array $types  Per-column type objects exposing isContinuous() and
 *                      isCategorical(), indexed like $sample.
 */
$toPlaceholder = function (&$sample, $offset, $types) {
    // Iterate by value and write back through the key, so no reference
    // variable is left dangling after the loop.
    foreach ($sample as $column => $value) {
        if ($value !== null && $value !== '') {
            continue;
        }

        if ($types[$column]->isContinuous()) {
            // NAN marks a missing continuous value.
            $sample[$column] = NAN;
        } elseif ($types[$column]->isCategorical()) {
            // '?' marks a missing categorical value.
            $sample[$column] = '?';
        }
    }
};
/**
 * Map the raw outcome flag onto a human-readable class name.
 *
 * The comparison is intentionally loose (==) so that both the integer 0 and
 * the string '0' map to 'Dead'; every other value maps to 'Survived'.
 *
 * @param mixed $label Raw label value.
 * @return string 'Dead' or 'Survived'.
 */
$transformLabel = function ($label) {
    if ($label == 0) {
        return 'Dead';
    }

    return 'Survived';
};
// Build the labeled dataset from the extractor, convert numeric strings in
// the samples, then rewrite the labels through $transformLabel.
$dataset = Labeled::fromIterator($extractor);
$dataset = $dataset->apply(new NumericStringConverter());
$dataset = $dataset->transformLabels($transformLabel);

// These transformers are kept in variables because they are persisted below
// after being applied to the dataset.
$imputer = new MissingDataImputer();
$minMaxNormalizer = new MinMaxNormalizer();
$oneHotEncoder = new OneHotEncoder();

// The column types are captured after numeric conversion and handed to the
// placeholder callback as its third argument.
$placeholderFiller = new LambdaFunction($toPlaceholder, $dataset->types());

// Order matters: mark missing values, impute them, scale, then one-hot encode.
$steps = [$placeholderFiller, $imputer, $minMaxNormalizer, $oneHotEncoder];

foreach ($steps as $step) {
    $dataset->apply($step);
}

// Persist each applied transformer to its own file in the working directory.
$artifacts = [
    'imputer.rbx' => $imputer,
    'minmax.rbx' => $minMaxNormalizer,
    'onehot.rbx' => $oneHotEncoder,
];

foreach ($artifacts as $file => $fitted) {
    $serializer->serialize($fitted)->saveTo(new Filesystem($file));
}
$logger->info('Training and validating model');

// Random forest built on ClassificationTree(10) base learners; the remaining
// arguments (500, 0.8, false) are passed through unchanged — presumably the
// number of trees, the per-tree sample ratio and the class-balancing flag
// (confirm against the installed Rubix ML version).
$estimator = new RandomForest(new ClassificationTree(10), 500, 0.8, false);

// Hold out 20% of the rows (stratified by label) for validation. Scoring on
// the rows the forest was trained on, as before, reports training accuracy
// and overstates how well the model generalises.
// NOTE(review): the transformers applied earlier were fitted on the full
// dataset, so a small amount of leakage into the hold-out remains.
[$training, $testing] = $dataset->stratifiedSplit(0.8);

$estimator->train($training);

$metric = new Accuracy();
$predictions = $estimator->predict($testing);
$score = $metric->score($predictions, $testing->labels());

$logger->info("Accuracy is $score");

// readline() returns false on EOF; cast to string so that counts as "no".
$answer = strtolower(trim((string) readline('Save this model? (y|[n]): ')));

if ($answer === 'y') {
    // Refit on every available row before persisting, so the saved model
    // is trained on the full dataset as it was previously.
    $estimator->train($dataset);

    $estimator = new PersistentModel($estimator, new Filesystem('model.rbx'));
    $estimator->save();

    $logger->info('Model saved as model.rbx');
}
?>