Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions browser.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* license: MIT (http://opensource.org/licenses/MIT)
* author: Heather Arthur <fayearthur@gmail.com>
* homepage: https://github.com/brainjs/brain.js#readme
* version: 1.4.3
* version: 1.4.4
*
* acorn:
* license: MIT (http://opensource.org/licenses/MIT)
Expand Down Expand Up @@ -4236,13 +4236,11 @@ var RNN = function () {
}, {
key: 'mapModel',
value: function mapModel() {
var _this = this;

var model = this.model;
var hiddenLayers = model.hiddenLayers;
var allMatrices = model.allMatrices;
this.initialLayerInputs = this.hiddenLayers.map(function (size) {
return new _matrix2.default(_this.hiddenLayers[0], 1);
return new _matrix2.default(size, 1);
});

this.createInputMatrix();
Expand Down Expand Up @@ -4572,8 +4570,6 @@ var RNN = function () {
}, {
key: 'fromJSON',
value: function fromJSON(json) {
var _this2 = this;

var defaults = this.constructor.defaults;
var options = json.options;
this.model = null;
Expand Down Expand Up @@ -4619,7 +4615,7 @@ var RNN = function () {
equationConnections: []
};
this.initialLayerInputs = this.hiddenLayers.map(function (size) {
return new _matrix2.default(_this2.hiddenLayers[0], 1);
return new _matrix2.default(size, 1);
});
this.bindEquation();
}
Expand Down
14 changes: 7 additions & 7 deletions browser.min.js

Large diffs are not rendered by default.

8 changes: 2 additions & 6 deletions dist/recurrent/rnn.js

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion dist/recurrent/rnn.js.map

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "brain.js",
"description": "Neural network library",
"version": "1.4.3",
"version": "1.4.4",
"author": "Heather Arthur <fayearthur@gmail.com>",
"repository": {
"type": "git",
Expand Down
4 changes: 2 additions & 2 deletions src/recurrent/rnn.js
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ export default class RNN {
let model = this.model;
let hiddenLayers = model.hiddenLayers;
let allMatrices = model.allMatrices;
this.initialLayerInputs = this.hiddenLayers.map((size) => new Matrix(this.hiddenLayers[0], 1));
this.initialLayerInputs = this.hiddenLayers.map((size) => new Matrix(size, 1));

this.createInputMatrix();
if (!model.input) throw new Error('net.model.input not set');
Expand Down Expand Up @@ -506,7 +506,7 @@ export default class RNN {
equations: [],
equationConnections: [],
};
this.initialLayerInputs = this.hiddenLayers.map((size) => new Matrix(this.hiddenLayers[0], 1));
this.initialLayerInputs = this.hiddenLayers.map((size) => new Matrix(size, 1));
this.bindEquation();
}

Expand Down
37 changes: 34 additions & 3 deletions test/recurrent/rnn.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,28 @@ describe('rnn', () => {
net.initialize();
assert.notEqual(net.model, null);
});
it('can setup different size hiddenLayers', () => {
const inputSize = 2;
const hiddenLayers = [5,4,3];
const networkOptions = {
learningRate: 0.001,
decayRate: 0.75,
inputSize: inputSize,
hiddenLayers,
outputSize: inputSize
};

const net = new RNN(networkOptions);
net.initialize();
net.bindEquation();
assert.equal(net.model.hiddenLayers.length, 3);
assert.equal(net.model.hiddenLayers[0].weight.columns, inputSize);
assert.equal(net.model.hiddenLayers[0].weight.rows, hiddenLayers[0]);
assert.equal(net.model.hiddenLayers[1].weight.columns, hiddenLayers[0]);
assert.equal(net.model.hiddenLayers[1].weight.rows, hiddenLayers[1]);
assert.equal(net.model.hiddenLayers[2].weight.columns, hiddenLayers[1]);
assert.equal(net.model.hiddenLayers[2].weight.rows, hiddenLayers[2]);
});
});
describe('basic operations', () => {
it('starts with zeros in input.deltas', () => {
Expand Down Expand Up @@ -354,9 +376,12 @@ describe('rnn', () => {

describe('.fromJSON', () => {
it('can import model from json', () => {
let dataFormatter = new DataFormatter('abcdef'.split(''));
let jsonString = JSON.stringify(new RNN({
inputSize: 6, //<- length
const inputSize = 6;
const hiddenLayers = [10, 20];
const dataFormatter = new DataFormatter('abcdef'.split(''));
const jsonString = JSON.stringify(new RNN({
inputSize, //<- length
hiddenLayers,
inputRange: dataFormatter.characters.length,
outputSize: dataFormatter.characters.length //<- length
}).toJSON(), null, 2);
Expand All @@ -368,6 +393,12 @@ describe('rnn', () => {
assert.equal(clone.inputSize, 6);
assert.equal(clone.inputRange, dataFormatter.characters.length);
assert.equal(clone.outputSize, dataFormatter.characters.length);

assert.equal(clone.model.hiddenLayers.length, 2);
assert.equal(clone.model.hiddenLayers[0].weight.columns, inputSize);
assert.equal(clone.model.hiddenLayers[0].weight.rows, hiddenLayers[0]);
assert.equal(clone.model.hiddenLayers[1].weight.columns, hiddenLayers[0]);
assert.equal(clone.model.hiddenLayers[1].weight.rows, hiddenLayers[1]);
});

it('can import model from json using .fromJSON()', () => {
Expand Down