Skip to content

Commit

Permalink
fix: Fix CrossValidate to have tests for when data too small
Browse files Browse the repository at this point in the history
Also upgrade to use newer es6 syntax for defining optional properties.
Upgrade examples to be more straightforward when using CrossValidate.
Fix values from being set to NaN when training with smaller data in CrossValidate.
  • Loading branch information
robertleeplummerjr committed Sep 28, 2018
1 parent f0a1a56 commit ca437f3
Show file tree
Hide file tree
Showing 14 changed files with 85 additions and 53 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,9 @@ With multiple networks you can train in parallel like this:
### Cross Validation
[Cross Validation](https://en.wikipedia.org/wiki/Cross-validation_(statistics)) can provide a less fragile way of training on larger data sets. The brain.js api provides Cross Validation in this example:
```js
const crossValidate = new CrossValidate(brain.NeuralNetwork, networkOptions);
const stats = crossValidate.train(data, trainingOptions, k); //note k (or KFolds) is optional
const crossValidate = new brain.CrossValidate(brain.NeuralNetwork, networkOptions);
crossValidate.train(data, trainingOptions, k); //note k (or KFolds) is optional
const json = crossValidate.toJSON(); // all stats in json as well as neural networks
const net = crossValidate.toNeuralNetwork();


Expand Down
2 changes: 1 addition & 1 deletion bower.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@
"node_modules",
"test"
],
"version": "1.4.1"
"version": "1.4.2"
}
15 changes: 10 additions & 5 deletions browser.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* license: MIT (http://opensource.org/licenses/MIT)
* author: Heather Arthur <fayearthur@gmail.com>
* homepage: https://github.com/brainjs/brain.js#readme
* version: 1.4.1
* version: 1.4.2
*
* acorn:
* license: MIT (http://opensource.org/licenses/MIT)
Expand Down Expand Up @@ -214,8 +214,13 @@ var CrossValidate = function () {

}, {
key: "train",
value: function train(data, trainOpts, k) {
k = k || 4;
value: function train(data) {
var trainOpts = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {};
var k = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : 4;

if (data.length <= k) {
throw new Error("Training set size is too small for " + data.length + " k folds of " + k);
}
var size = data.length / k;

if (data.constructor === Array) {
Expand Down Expand Up @@ -1946,8 +1951,8 @@ var NeuralNetwork = function () {
falseNeg: falseNeg,
falsePos: falsePos,
total: data.length,
precision: truePos / (truePos + falsePos),
recall: truePos / (truePos + falseNeg),
precision: truePos > 0 ? truePos / (truePos + falsePos) : 0,
recall: truePos > 0 ? truePos / (truePos + falseNeg) : 0,
accuracy: (trueNeg + truePos) / data.length
});
}
Expand Down
13 changes: 7 additions & 6 deletions browser.min.js

Large diffs are not rendered by default.

9 changes: 7 additions & 2 deletions dist/cross-validate.js

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion dist/cross-validate.js.map

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions dist/neural-network.js

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion dist/neural-network.js.map

Large diffs are not rendered by default.

16 changes: 2 additions & 14 deletions examples-typescript/cross-validate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,11 @@ const trainingData = [
{ input: [1, 1], output: [0] },
{ input: [1, 0], output: [1] },

// xor repeats
// repeat xor data to have enough to train with
{ input: [0, 1], output: [1] },
{ input: [0, 0], output: [0] },
{ input: [1, 1], output: [0] },
{ input: [1, 0], output: [1] },

// xor repeats
{ input: [0, 1], output: [1] },
{ input: [0, 0], output: [0] },
{ input: [1, 1], output: [0] },
{ input: [1, 0], output: [1] },

// xor repeats
{ input: [0, 1], output: [1] },
{ input: [0, 0], output: [0] },
{ input: [1, 1], output: [0] },
{ input: [1, 0], output: [1] },
{ input: [1, 0], output: [1] }
];

const netOptions = {
Expand Down
16 changes: 2 additions & 14 deletions examples/cross-validate.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,11 @@ const trainingData = [
{ input: [1, 1], output: [0] },
{ input: [1, 0], output: [1] },

// xor repeats
// repeat xor data to have enough to train with
{ input: [0, 1], output: [1] },
{ input: [0, 0], output: [0] },
{ input: [1, 1], output: [0] },
{ input: [1, 0], output: [1] },

// xor repeats
{ input: [0, 1], output: [1] },
{ input: [0, 0], output: [0] },
{ input: [1, 1], output: [0] },
{ input: [1, 0], output: [1] },

// xor repeats
{ input: [0, 1], output: [1] },
{ input: [0, 0], output: [0] },
{ input: [1, 1], output: [0] },
{ input: [1, 0], output: [1] },
{ input: [1, 0], output: [1] }
];

const netOptions = {
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "brain.js",
"description": "Neural network library",
"version": "1.4.1",
"version": "1.4.2",
"author": "Heather Arthur <fayearthur@gmail.com>",
"repository": {
"type": "git",
Expand Down
6 changes: 4 additions & 2 deletions src/cross-validate.js
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,10 @@ export default class CrossValidate {
* }
* }
*/
train(data, trainOpts, k) {
k = k || 4;
train(data, trainOpts = {}, k = 4) {
if (data.length <= k) {
throw new Error(`Training set size is too small for ${ data.length } k folds of ${ k }`);
}
let size = data.length / k;

if (data.constructor === Array) {
Expand Down
4 changes: 2 additions & 2 deletions src/neural-network.js
Original file line number Diff line number Diff line change
Expand Up @@ -805,8 +805,8 @@ export default class NeuralNetwork {
falseNeg: falseNeg,
falsePos: falsePos,
total: data.length,
precision: truePos / (truePos + falsePos),
recall: truePos / (truePos + falseNeg),
precision: truePos > 0 ? truePos / (truePos + falsePos) : 0,
recall: truePos > 0 ? truePos / (truePos + falseNeg) : 0,
accuracy: (trueNeg + truePos) / data.length
});
}
Expand Down
42 changes: 42 additions & 0 deletions test/base/cross-validation.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import assert from 'assert';
import brain from '../../src';
import CrossValidate from '../../src/cross-validate';

describe('CrossValidation', () => {
describe('simple xor example', () => {
it('throws exception when training set is too small', () => {
const xorTrainingData = [
{ input: [0, 1], output: [1] },
{ input: [0, 0], output: [0] },
{ input: [1, 1], output: [0] },
{ input: [1, 0], output: [1] }
];
const net = new CrossValidate(brain.NeuralNetwork);
assert.throws(() => {
net.train(xorTrainingData);
});
});
it('handles training and outputs values that are all numbers', () => {
const xorTrainingData = [
{ input: [0, 1], output: [1] },
{ input: [0, 0], output: [0] },
{ input: [1, 1], output: [0] },
{ input: [1, 0], output: [1] },

{ input: [0, 1], output: [1] },
{ input: [0, 0], output: [0] },
{ input: [1, 1], output: [0] },
{ input: [1, 0], output: [1] }
];
const net = new CrossValidate(brain.NeuralNetwork);
net.train(xorTrainingData);
const json = net.toJSON();
for (let p in json.avgs) {
assert(json.avgs[p] >= 0);
}
for (let p in json.stats) {
assert(json.stats[p] >= 0);
}
});
});
});

0 comments on commit ca437f3

Please sign in to comment.