Permalink
Cannot retrieve contributors at this time
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
99 lines (84 sloc)
2.24 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<?php | |
namespace Rubix\Engine; | |
use Rubix\Engine\Preprocessors\Preprocessor; | |
use RuntimeException; | |
class Pipeline implements Estimator
{
    /**
     * Estimator that is trained on, and predicts from, the preprocessed data.
     *
     * @var \Rubix\Engine\Estimator
     */
    protected $estimator;

    /**
     * Preprocessors applied to sample data before it reaches the estimator.
     * They run in insertion order during training and prediction alike.
     *
     * @var \Rubix\Engine\Preprocessors\Preprocessor[]
     */
    protected $preprocessors = [];

    /**
     * Wrap an estimator in a pipeline with an optional chain of preprocessors.
     *
     * @param \Rubix\Engine\Estimator $estimator
     * @param array $preprocessors  preprocessors to register, in application order
     * @return void
     */
    public function __construct(Estimator $estimator, array $preprocessors = [])
    {
        // Register through addPreprocessor() so each entry is type-checked
        // against the Preprocessor interface.
        foreach ($preprocessors as $step) {
            $this->addPreprocessor($step);
        }

        $this->estimator = $estimator;
    }

    /**
     * Append a preprocessor to the end of the chain.
     *
     * @param \Rubix\Engine\Preprocessors\Preprocessor $preprocessor
     * @return self  this pipeline, to allow chained calls
     */
    public function addPreprocessor(Preprocessor $preprocessor) : self
    {
        array_push($this->preprocessors, $preprocessor);

        return $this;
    }

    /**
     * Expose the wrapped estimator.
     *
     * @return \Rubix\Engine\Estimator
     */
    public function estimator() : Estimator
    {
        return $this->estimator;
    }

    /**
     * Fit every preprocessor on the dataset and transform it, one step after
     * another. Then train the wrapped estimator on the result.
     *
     * NOTE(review): Dataset::transform() appears to modify $data in place,
     * so the caller's dataset comes back preprocessed — confirm.
     * NOTE(review): this method throws nothing itself; any exception comes
     * from the preprocessors or the estimator.
     *
     * @param \Rubix\Engine\Dataset $data
     * @return void
     */
    public function train(Dataset $data) : void
    {
        foreach ($this->preprocessors as $step) {
            // Each step is fitted on the data as left by the steps before it
            // (presumably, given the in-place transform noted above).
            $step->fit($data);

            $data->transform($step);
        }

        $this->estimator->train($data);
    }

    /**
     * Push a single sample through every preprocessor, then predict with the
     * wrapped estimator.
     *
     * @param array $sample
     * @return \Rubix\Engine\Prediction
     */
    public function predict(array $sample) : Prediction
    {
        // Preprocessors are handed a list of samples, so wrap the lone sample
        // in a one-element batch.
        $batch = [$sample];

        foreach ($this->preprocessors as $step) {
            // NOTE(review): relies on transform() receiving $batch by
            // reference — confirm against the Preprocessor interface.
            $step->transform($batch);
        }

        return $this->estimator->predict($batch[0]);
    }
}