Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 29 additions & 74 deletions browser.js
Original file line number Diff line number Diff line change
Expand Up @@ -930,7 +930,7 @@ var NeuralNetwork = function () {
Object.assign(this, this.constructor.defaults, options);
this.hiddenSizes = options.hiddenLayers;
this.trainOpts = {};
this.updateTrainingOptions(Object.assign({}, this.constructor.trainDefaults, options));
this._updateTrainingOptions(Object.assign({}, this.constructor.trainDefaults, options));

this.sizes = null;
this.outputLayer = null;
Expand Down Expand Up @@ -1183,38 +1183,15 @@ var NeuralNetwork = function () {
*/

}, {
key: 'updateTrainingOptions',
value: function updateTrainingOptions(opts) {
if (opts.iterations) {
this.trainOpts.iterations = opts.iterations;
}
if (opts.errorThresh) {
this.trainOpts.errorThresh = opts.errorThresh;
}
if (opts.log) {
this._setLogMethod(opts.log);
}
if (opts.logPeriod) {
this.trainOpts.logPeriod = opts.logPeriod;
}
if (opts.learningRate) {
this.trainOpts.learningRate = opts.learningRate;
}
if (opts.momentum) {
this.trainOpts.momentum = opts.momentum;
}
if (opts.callback) {
this.trainOpts.callback = opts.callback;
}
if (opts.callbackPeriod) {
this.trainOpts.callbackPeriod = opts.callbackPeriod;
}
if (opts.timeout) {
this.trainOpts.timeout = opts.timeout;
}
if (opts.activation) {
this.activation = opts.activation;
}
key: '_updateTrainingOptions',
value: function _updateTrainingOptions(opts) {
var _this2 = this;

Object.keys(NeuralNetwork.trainDefaults).forEach(function (opt) {
return _this2.trainOpts[opt] = opts[opt] || _this2.trainOpts[opt];
});
this._setLogMethod(opts.log || this.trainOpts.log);
this.activation = opts.activation || this.activation;
}

/**
Expand All @@ -1226,35 +1203,13 @@ var NeuralNetwork = function () {
}, {
key: '_getTrainOptsJSON',
value: function _getTrainOptsJSON() {
var results = {};
if (this.trainOpts.iterations) {
results.iterations = this.trainOpts.iterations;
}
if (this.trainOpts.errorThresh) {
results.errorThresh = this.trainOpts.errorThresh;
}
if (this.trainOpts.logPeriod) {
results.logPeriod = this.trainOpts.logPeriod;
}
if (this.trainOpts.learningRate) {
results.learningRate = this.trainOpts.learningRate;
}
if (this.trainOpts.momentum) {
results.momentum = this.trainOpts.momentum;
}
if (this.trainOpts.callback) {
results.callback = this.trainOpts.callback;
}
if (this.trainOpts.callbackPeriod) {
results.callbackPeriod = this.trainOpts.callbackPeriod;
}
if (this.trainOpts.timeout) {
results.timeout = this.trainOpts.timeout;
}
if (this.trainOpts.log) {
results.log = true;
}
return results;
var _this3 = this;

return Object.keys(NeuralNetwork.trainDefaults).reduce(function (opts, opt) {
if (_this3.trainOpts[opt]) opts[opt] = _this3.trainOpts[opt];
if (opt === 'log') opts.log = typeof opts.log === 'function';
return opts;
}, {});
}

/**
Expand Down Expand Up @@ -1387,7 +1342,7 @@ var NeuralNetwork = function () {
}, {
key: 'trainAsync',
value: function trainAsync(data) {
var _this2 = this;
var _this4 = this;

var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {};

Expand All @@ -1403,10 +1358,10 @@ var NeuralNetwork = function () {

return new Promise(function (resolve, reject) {
try {
var thawedTrain = new _thaw2.default(new Array(_this2.trainOpts.iterations), {
var thawedTrain = new _thaw2.default(new Array(_this4.trainOpts.iterations), {
delay: true,
each: function each() {
return _this2._trainingTick(data, status, endTime) || thawedTrain.stop();
return _this4._trainingTick(data, status, endTime) || thawedTrain.stop();
},
done: function done() {
return resolve(status);
Expand Down Expand Up @@ -1584,7 +1539,7 @@ var NeuralNetwork = function () {
}, {
key: '_formatData',
value: function _formatData(data) {
var _this3 = this;
var _this5 = this;

if (!Array.isArray(data)) {
// turn stream datum into array
Expand All @@ -1601,7 +1556,7 @@ var NeuralNetwork = function () {
}));
}
data = data.map(function (datum) {
var array = _lookup2.default.toArray(_this3.inputLookup, datum.input);
var array = _lookup2.default.toArray(_this5.inputLookup, datum.input);
return Object.assign({}, datum, { input: array });
}, this);
}
Expand All @@ -1613,7 +1568,7 @@ var NeuralNetwork = function () {
}));
}
data = data.map(function (datum) {
var array = _lookup2.default.toArray(_this3.outputLookup, datum.output);
var array = _lookup2.default.toArray(_this5.outputLookup, datum.output);
return Object.assign({}, datum, { output: array });
}, this);
}
Expand All @@ -1634,7 +1589,7 @@ var NeuralNetwork = function () {
}, {
key: 'test',
value: function test(data) {
var _this4 = this;
var _this6 = this;

data = this._formatData(data);

Expand All @@ -1653,13 +1608,13 @@ var NeuralNetwork = function () {
var sum = 0;

var _loop = function _loop(i) {
var output = _this4.runInput(data[i].input);
var output = _this6.runInput(data[i].input);
var target = data[i].output;

var actual = void 0,
expected = void 0;
if (isBinary) {
actual = output[0] > _this4.binaryThresh ? 1 : 0;
actual = output[0] > _this6.binaryThresh ? 1 : 0;
expected = target[0];
} else {
actual = output.indexOf((0, _max2.default)(output));
Expand Down Expand Up @@ -1828,7 +1783,7 @@ var NeuralNetwork = function () {
}
}
}
this.updateTrainingOptions(json.trainOpts);
this._updateTrainingOptions(json.trainOpts);
this.setActivation();
return this;
}
Expand Down Expand Up @@ -1906,15 +1861,15 @@ var NeuralNetwork = function () {
}, {
key: 'isRunnable',
get: function get() {
var _this5 = this;
var _this7 = this;

if (!this.runInput) {
console.error('Activation function has not been initialized, did you run train()?');
return false;
}

var checkFns = ['sizes', 'outputLayer', 'biases', 'weights', 'outputs', 'deltas', 'changes', 'errors'].filter(function (c) {
return _this5[c] === null;
return _this7[c] === null;
});

if (checkFns.length > 0) {
Expand Down
Loading