21 changes: 12 additions & 9 deletions README.md
@@ -129,15 +129,16 @@ var output = net.run({ r: 1, g: 0.4, b: 0 }); // { white: 0.81, black: 0.18 }

```javascript
net.train(data, {
iterations: 20000, // the maximum times to iterate the training data
errorThresh: 0.005, // the acceptable error percentage from training data
log: false, // true to use console.log, when a function is supplied it is used
logPeriod: 10, // iterations between logging out
learningRate: 0.3, // scales with delta to effect traiing rate
momentum: 0.1, // scales with next layer's change value
callback: null, // a periodic call back that can be triggered while training
callbackPeriod: 10, // the number of iterations through the training data between callback calls
timeout: Infinity // the max number of milliseconds to train for
// Default values --> expected validation
iterations: 20000, // the maximum times to iterate the training data --> number greater than 0
errorThresh: 0.005, // the acceptable error percentage from training data --> number between 0 and 1
log: false, // true to use console.log, when a function is supplied it is used --> either true or a function
logPeriod: 10, // iterations between logging out --> number greater than 0
learningRate: 0.3, // scales with delta to affect training rate --> number between 0 and 1
momentum: 0.1, // scales with next layer's change value --> number between 0 and 1
callback: null, // a periodic callback that can be triggered while training --> null or function
callbackPeriod: 10, // the number of iterations through the training data between callback calls --> number greater than 0
timeout: Infinity // the max number of milliseconds to train for --> number greater than 0
});
```
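Training stops when the error drops below `errorThresh`, when the `iterations` cap is reached, or when `timeout` elapses, whichever happens first. A minimal sketch of reading the result; the `data` shown is made up, and the `error`/`iterations` fields on the returned stats are assumed to match the usual brain.js training result object:

```javascript
// Hypothetical color-classification data in the same shape as the example above
const data = [
  { input: { r: 0.03, g: 0.7, b: 0.5 }, output: { black: 1 } },
  { input: { r: 0.16, g: 0.09, b: 0.2 }, output: { white: 1 } }
];

const stats = net.train(data, {
  iterations: 20000,  // stop after at most 20000 passes over the data...
  errorThresh: 0.005, // ...or once the error falls below 0.005...
  timeout: 10000      // ...or after 10 seconds, whichever comes first
});

console.log(stats.error, stats.iterations); // how training actually finished
```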

@@ -151,6 +152,8 @@ The momentum is similar to learning rate, expecting a value from `0` to `1` as well

Any of these training options can be passed into the constructor or into the `updateTrainingOptions(opts)` method, and they will be saved on the network and used every time you train. If you save your network to JSON, these training options are saved and restored as well (except for `callback` and `log`: the callback will be forgotten and `log` will be restored using `console.log`).
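A short sketch of both paths described above; the option values are arbitrary, `data` refers to the training set from the earlier sketch, and the save/restore step is shown only to illustrate that the saved options come back with the network:

```javascript
// Training options passed to the constructor are kept on the network
const net = new brain.NeuralNetwork({ learningRate: 0.6, momentum: 0.5 });

// ...or set/changed later via updateTrainingOptions
net.updateTrainingOptions({ errorThresh: 0.011 });

net.train(data); // `data` as in the earlier example

// The saved options come back with the network (except callback and log)
const restored = new brain.NeuralNetwork().fromJSON(net.toJSON());
```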

There is a boolean property called `invalidTrainOptsShouldThrow` that defaults to `true`. While it is `true`, passing a training option outside the normal range throws an error with a message about the offending option. When set to `false`, no error is thrown, but the same information is sent to `console.warn`.
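A sketch of the behavior this paragraph describes (the out-of-range value is arbitrary):

```javascript
const net = new brain.NeuralNetwork();

try {
  net.updateTrainingOptions({ learningRate: 50 }); // outside 0..1, so this throws by default
} catch (e) {
  console.error(e.message); // names the offending option and value
}

net.invalidTrainOptsShouldThrow = false;
net.updateTrainingOptions({ learningRate: 50 }); // now only warns via console.warn
```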

### Async Training
`trainAsync()` takes the same arguments as `train()` (data and options). Instead of returning the training results object directly, it returns a promise that resolves to the training results object.
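A sketch of both ways to consume the promise; `data` and the options are as in the earlier examples, and the resolved value is assumed to be the same results object that `train()` returns:

```javascript
net.trainAsync(data, { iterations: 20000, errorThresh: 0.005 })
  .then(res => console.log(`finished with error ${res.error} after ${res.iterations} iterations`))
  .catch(err => console.error(err));

// or with async/await
async function runTraining() {
  const res = await net.trainAsync(data, { errorThresh: 0.005 });
  console.log(res.error);
}
```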

22,589 changes: 11,367 additions & 11,222 deletions browser.js

Large diffs are not rendered by default.

122 changes: 62 additions & 60 deletions browser.min.js

Large diffs are not rendered by default.

49 changes: 48 additions & 1 deletion dist/neural-network.js

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion dist/neural-network.js.map

Large diffs are not rendered by default.

28 changes: 27 additions & 1 deletion src/neural-network.js
@@ -35,6 +35,31 @@ export default class NeuralNetwork {
};
}

/**
* Validates the supplied training options against their expected ranges
* @param {object} options
* @throws {Error} when an option is outside its expected range
* @private
*/
static _validateTrainingOptions(options) {
var validations = {
iterations: (val) => { return typeof val === 'number' && val > 0; },
errorThresh: (val) => { return typeof val === 'number' && val > 0 && val < 1; },
log: (val) => { return typeof val === 'function' || typeof val === 'boolean'; },
logPeriod: (val) => { return typeof val === 'number' && val > 0; },
learningRate: (val) => { return typeof val === 'number' && val > 0 && val < 1; },
momentum: (val) => { return typeof val === 'number' && val > 0 && val < 1; },
callback: (val) => { return typeof val === 'function' || val === null; },
callbackPeriod: (val) => { return typeof val === 'number' && val > 0; },
timeout: (val) => { return typeof val === 'number' && val > 0; }
};
Object.keys(NeuralNetwork.trainDefaults).forEach(key => {
if (validations.hasOwnProperty(key) && !validations[key](options[key])) {
throw new Error(`[${key}, ${options[key]}] is out of normal training range, your network will probably not train.`);
}
});
}

constructor(options = {}) {
Object.assign(this, this.constructor.defaults, options);
this.hiddenSizes = options.hiddenLayers;
@@ -293,7 +318,8 @@ export default class NeuralNetwork {
* activation: ['sigmoid', 'relu', 'leaky-relu', 'tanh']
*/
_updateTrainingOptions(opts) {
Object.keys(NeuralNetwork.trainDefaults).forEach(opt => this.trainOpts[opt] = opts[opt] || this.trainOpts[opt]);
Object.keys(NeuralNetwork.trainDefaults).forEach(opt => this.trainOpts[opt] = (opts.hasOwnProperty(opt)) ? opts[opt] : this.trainOpts[opt]);
NeuralNetwork._validateTrainingOptions(this.trainOpts);
this._setLogMethod(opts.log || this.trainOpts.log);
this.activation = opts.activation || this.activation;
}
92 changes: 91 additions & 1 deletion test/base/trainopts.js
@@ -166,4 +166,94 @@ describe('train() and trainAsync() use the same private methods', () => {
done()
});
});
});
});

describe('training options validation', () => {
it('iterations validation', () => {
let net = new brain.NeuralNetwork();
assert.throws(() => { net._updateTrainingOptions({ iterations: 'should be a string' }) });
assert.throws(() => { net._updateTrainingOptions({ iterations: () => {} }) });
assert.throws(() => { net._updateTrainingOptions({ iterations: false }) });
assert.throws(() => { net._updateTrainingOptions({ iterations: -1 }) });
assert.doesNotThrow(() => { net._updateTrainingOptions({ iterations: 5000 }) });
});

it('errorThresh validation', () => {
let net = new brain.NeuralNetwork();
assert.throws(() => { net._updateTrainingOptions({ errorThresh: 'no strings'}) });
assert.throws(() => { net._updateTrainingOptions({ errorThresh: () => {} }) });
assert.throws(() => { net._updateTrainingOptions({ errorThresh: 5}) });
assert.throws(() => { net._updateTrainingOptions({ errorThresh: -1}) });
assert.throws(() => { net._updateTrainingOptions({ errorThresh: false}) });
assert.doesNotThrow(() => { net._updateTrainingOptions({ errorThresh: 0.008}) });
});

it('log validation', () => {
let net = new brain.NeuralNetwork();
assert.throws(() => { net._updateTrainingOptions({ log: 'no strings' }) });
assert.throws(() => { net._updateTrainingOptions({ log: 4 }) });
assert.doesNotThrow(() => { net._updateTrainingOptions({ log: false }) });
assert.doesNotThrow(() => { net._updateTrainingOptions({ log: () => {} }) });
Contributor: is this a "valid" log function?

Contributor Author: yep, it would be user defined. (logging allows the user to either set true and it will console.log, or set their own method)

Contributor: I don't understand why the log value can be user defined; that's the point of callback. Buuuuuut, I was a bit too hasty with this and the other "valid callback" comment. This is fine with me.

Contributor Author: No worries. The main idea is that if the user has their own logging system (i.e. it writes to a file or whatever) then they can keep the log to do just that, without having to use up their callback for it.

Contributor: Meh, I don't think it would "use up" their callback. But this is a moot point; we can continue the discussion outside of this PR.
});

it('logPeriod validation', () => {
let net = new brain.NeuralNetwork();
assert.throws(() => { net._updateTrainingOptions({ logPeriod: 'no strings' }) });
assert.throws(() => { net._updateTrainingOptions({ logPeriod: -50 }) });
assert.throws(() => { net._updateTrainingOptions({ logPeriod: () => {} }) });
assert.throws(() => { net._updateTrainingOptions({ logPeriod: false }) });
assert.doesNotThrow(() => { net._updateTrainingOptions({ logPeriod: 40 }) });
});

it('learningRate validation', () => {
let net = new brain.NeuralNetwork();
assert.throws(() => { net._updateTrainingOptions({ learningRate: 'no strings' }) });
assert.throws(() => { net._updateTrainingOptions({ learningRate: -50 }) });
assert.throws(() => { net._updateTrainingOptions({ learningRate: 50 }) });
assert.throws(() => { net._updateTrainingOptions({ learningRate: () => {} }) });
assert.throws(() => { net._updateTrainingOptions({ learningRate: false }) });
assert.doesNotThrow(() => { net._updateTrainingOptions({ learningRate: 0.5 }) });
});

it('momentum validation', () => {
let net = new brain.NeuralNetwork();
assert.throws(() => { net._updateTrainingOptions({ momentum: 'no strings' }) });
assert.throws(() => { net._updateTrainingOptions({ momentum: -50 }) });
assert.throws(() => { net._updateTrainingOptions({ momentum: 50 }) });
assert.throws(() => { net._updateTrainingOptions({ momentum: () => {} }) });
assert.throws(() => { net._updateTrainingOptions({ momentum: false }) });
assert.doesNotThrow(() => { net._updateTrainingOptions({ momentum: 0.8 }) });
});

it('callback validation', () => {
let net = new brain.NeuralNetwork();
assert.throws(() => { net._updateTrainingOptions({ callback: 'no strings' }) });
assert.throws(() => { net._updateTrainingOptions({ callback: 4 }) });
assert.throws(() => { net._updateTrainingOptions({ callback: false }) });
assert.doesNotThrow(() => { net._updateTrainingOptions({ callback: null }) });
assert.doesNotThrow(() => { net._updateTrainingOptions({ callback: () => {} }) });
Contributor: Is this a "valid" callback?

Contributor Author (@freddyC, Feb 3, 2018): yep (it just does nothing with the passed-in object, I am not enforcing the params)
});

it('callbackPeriod validation', () => {
let net = new brain.NeuralNetwork();
assert.throws(() => { net._updateTrainingOptions({ callbackPeriod: 'no strings' }) });
assert.throws(() => { net._updateTrainingOptions({ callbackPeriod: -50 }) });
assert.throws(() => { net._updateTrainingOptions({ callbackPeriod: () => {} }) });
assert.throws(() => { net._updateTrainingOptions({ callbackPeriod: false }) });
assert.doesNotThrow(() => { net._updateTrainingOptions({ callbackPeriod: 40 }) });
});

it('timeout validation', () => {
let net = new brain.NeuralNetwork();
assert.throws(() => { net._updateTrainingOptions({ timeout: 'no strings' }) });
assert.throws(() => { net._updateTrainingOptions({ timeout: -50 }) });
assert.throws(() => { net._updateTrainingOptions({ timeout: () => {} }) });
assert.throws(() => { net._updateTrainingOptions({ timeout: false }) });
assert.doesNotThrow(() => { net._updateTrainingOptions({ timeout: 40 }) });
});

it('should handle unsupported options', () => {
let net = new brain.NeuralNetwork();
assert.doesNotThrow(() => { net._updateTrainingOptions({ fakeProperty: 'should be handled fine' }) });
});
});