Skip to content

Commit

Permalink
Added back ability to use etiher 'expected' or 'output' keys in datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
DanRuta committed Apr 17, 2018
1 parent f4208ee commit aa5ac5d
Show file tree
Hide file tree
Showing 14 changed files with 72 additions and 56 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# Upcoming
---
#### Global
- Added back ability to use etiher 'expected' or 'output' keys in data sets

# 3.4.0 - Bug fixes and improvements
---
#### Global
Expand Down
16 changes: 8 additions & 8 deletions dev/js-WebAssembly/Network.js
Original file line number Diff line number Diff line change
Expand Up @@ -367,13 +367,13 @@ class Network {
}

if (this.state != "initialised") {
this.initLayers(data[0].input.length, data[0].expected.length)
this.initLayers(data[0].input.length, (data[0].expected || data[0].output).length)
}

const startTime = Date.now()

const dimension = this.layers[0].size
const itemSize = dimension + data[0].expected.length
const itemSize = dimension + (data[0].expected || data[0].output).length
const itemsCount = itemSize * data.length

if (log) {
Expand Down Expand Up @@ -558,8 +558,8 @@ class Network {
loadData (data, typedArray, itemSize, reject) {
for (let di=0; di<data.length; di++) {

if (!data[di].hasOwnProperty("input") || !data[di].hasOwnProperty("expected")) {
return void reject("Data set must be a list of objects with keys: 'input' and 'expected'")
if (!data[di].hasOwnProperty("input") || (!data[di].hasOwnProperty("expected") && !data[di].hasOwnProperty("output"))) {
return void reject("Data set must be a list of objects with keys: 'input' and 'expected' (or 'output')")
}

let index = itemSize * di
Expand All @@ -582,8 +582,8 @@ class Network {
}
}

for (let ei=0; ei<data[di].expected.length; ei++) {
typedArray[index] = data[di].expected[ei]
for (let ei=0; ei<(data[di].expected || data[di].output).length; ei++) {
typedArray[index] = (data[di].expected || data[di].output)[ei]
index++
}
}
Expand All @@ -602,7 +602,7 @@ class Network {

const startTime = Date.now()
const dimension = data[0].input.length
const itemSize = dimension + data[0].expected.length
const itemSize = dimension + (data[0].expected || data[0].output).length
const itemsCount = itemSize * data.length
const typedArray = new Float32Array(itemsCount)

Expand Down Expand Up @@ -743,7 +743,7 @@ class Network {
}

static get version () {
return "3.3.4"
return "3.4.1"
}
}

Expand Down
14 changes: 7 additions & 7 deletions dev/js/Network.js
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ class Network {
}

if (this.state != "initialised") {
this.initLayers.bind(this, dataSet[0].input.length, dataSet[0].expected.length)()
this.initLayers.bind(this, dataSet[0].input.length, (dataSet[0].expected || dataSet[0].output).length)()
}

this.layers.forEach(layer => layer.state = "training")
Expand Down Expand Up @@ -335,16 +335,16 @@ class Network {

const doIteration = async () => {

if (!dataSet[iterationIndex].hasOwnProperty("input") || !dataSet[iterationIndex].hasOwnProperty("expected")) {
return void reject("Data set must be a list of objects with keys: 'input' and 'expected'")
if (!dataSet[iterationIndex].hasOwnProperty("input") || (!dataSet[iterationIndex].hasOwnProperty("expected") && !dataSet[iterationIndex].hasOwnProperty("output"))) {
return void reject("Data set must be a list of objects with keys: 'input' and 'expected' (or 'output')")
}

let trainingError
let validationError

const input = dataSet[iterationIndex].input
const output = this.forward(input)
const target = dataSet[iterationIndex].expected
const target = dataSet[iterationIndex].expected || dataSet[iterationIndex].output

let classification = -Infinity
const errors = []
Expand Down Expand Up @@ -447,7 +447,7 @@ class Network {
const validateItem = (item) => {

const output = this.forward(data[validationIndex].input)
const target = data[validationIndex].expected
const target = data[validationIndex].expected || data[validationIndex].output

let classification = -Infinity
for (let i=0; i<output.length; i++) {
Expand Down Expand Up @@ -539,7 +539,7 @@ class Network {

const input = testSet[iterationIndex].input
const output = this.forward(input)
const target = testSet[iterationIndex].expected
const target = testSet[iterationIndex].expected || testSet[iterationIndex].output
const elapsed = Date.now() - startTime

let classification = -Infinity
Expand Down Expand Up @@ -671,7 +671,7 @@ class Network {
}

static get version () {
return "3.3.4"
return "3.4.1"
}
}

Expand Down
14 changes: 7 additions & 7 deletions dist/jsNetJS.concat.js

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion dist/jsNetJS.concat.js.map

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion dist/jsNetJS.min.js

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion dist/jsNetJS.min.js.map

Large diffs are not rendered by default.

16 changes: 8 additions & 8 deletions dist/jsNetWebAssembly.concat.js

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion dist/jsNetWebAssembly.concat.js.map

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion dist/jsNetWebAssembly.min.js

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion dist/jsNetWebAssembly.min.js.map

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "jsnet",
"version": "3.4.0",
"version": "3.4.1",
"description": "Javascript based deep learning framework for basic and convolutional neural networks.",
"scripts": {
"test": "npm run js-tests && npm run wa-tests",
Expand Down
20 changes: 10 additions & 10 deletions test/js-test.js
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ describe("Loading", () => {
})

it("Statically returns the Network version when accessing via .version", () => {
expect(Network.version).to.equal("3.3.4")
expect(Network.version).to.equal("3.4.1")
})
})

Expand Down Expand Up @@ -1093,16 +1093,16 @@ describe("Network", () => {
return expect(net.train()).to.be.rejectedWith("No data provided")
})

it("Rejects the promise if some data does not have the key 'input' and 'expected'", () => {
return expect(net.train(badTestData)).to.be.rejectedWith("Data set must be a list of objects with keys: 'input' and 'expected")
it("Rejects the promise if some data does not have the key 'input' and 'expected'/'output'", () => {
return expect(net.train(badTestData)).to.be.rejectedWith("Data set must be a list of objects with keys: 'input' and 'expected' (or 'output')")
})

it("Resolves the promise when you give it data", () => {
return expect(net.train(testData)).to.be.fulfilled
return expect(net.train(testDataWithMixedExpectedOutput)).to.be.fulfilled
})

it("Does not accept 'output' as an alternative name for expected values", () => {
return expect(net.train(testDataWithOutput)).to.not.be.fulfilled
it("Accepts 'output' as an alternative name for expected values", () => {
return expect(net.train(testDataWithOutput)).to.be.fulfilled
})

it("Does one iteration when not passing any config data", () => {
Expand Down Expand Up @@ -1191,15 +1191,15 @@ describe("Network", () => {
})
})

it("Calls the initLayers function with the length of the first input and length of first expected", () => {
it("Calls the initLayers function with the length of the first input and length of first expected, when using output key in the data", () => {
const network = new Network({updateFn: null})
network.trainingConfusionMatrix = [[0,0],[0,0]]
network.testConfusionMatrix = [[0,0],[0,0]]
network.validationConfusionMatrix = [[0,0],[0,0]]
sinon.stub(network, "forward").callsFake(() => [1,1])
sinon.spy(network, "initLayers")

return network.train(testData).then(() => {
return network.train(testDataWithOutput).then(() => {
expect(network.initLayers).to.have.been.calledWith(2, 2)
network.initLayers.restore()
})
Expand Down Expand Up @@ -1307,7 +1307,7 @@ describe("Network", () => {

it("Runs validation when validation data is given", () => {
sinon.spy(net, "validate")
return net.train(testData, {epochs: 10, validation: {data: testData, interval: 2}}).then(() => {
return net.train(testDataWithOutput, {epochs: 10, validation: {data: testDataWithOutput, interval: 2}}).then(() => {
expect(net.validate).to.be.called
net.validate.restore()
})
Expand Down Expand Up @@ -1632,7 +1632,7 @@ describe("Network", () => {
})

it("Resolves with a number, indicating error", () => {
return net.test(testData).then((result) => {
return net.test(testDataOutput).then((result) => {
expect(typeof result).to.equal("number")
})
})
Expand Down
29 changes: 20 additions & 9 deletions test/wa-test.js
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ describe("Loading", () => {
})

it("Statically returns the Network version when accessing via .version", () => {
expect(Network.version).to.equal("3.3.4")
expect(Network.version).to.equal("3.4.1")
})
})

Expand Down Expand Up @@ -918,16 +918,16 @@ describe("Network", () => {
return expect(net.train()).to.be.rejectedWith("No data provided")
})

it("Rejects the promise if some data does not have the key 'input' and 'expected'", () => {
return expect(net.train(badTestData)).to.be.rejectedWith("Data set must be a list of objects with keys: 'input' and 'expected'")
it("Rejects the promise if some data does not have the key 'input' and 'expected'/'output'", () => {
return expect(net.train(badTestData)).to.be.rejectedWith("Data set must be a list of objects with keys: 'input' and 'expected' (or 'output')")
})

it("Resolves the promise when you give it data", () => {
return expect(net.train(testData)).to.be.fulfilled
})

it("Does not accept 'output' as an alternative name for expected values", () => {
return expect(net.train(testDataWithOutput)).to.not.be.fulfilled
it("Accepts 'output' as an alternative name for expected values", () => {
return expect(net.train(testDataWithOutput)).to.be.fulfilled
})

it("CCalls the Module's set_miniBatchSize function with the given miniBatchSize value", () => {
Expand Down Expand Up @@ -975,6 +975,15 @@ describe("Network", () => {
})
})

it("Calls the initLayers function when the net state is not 'initialised' (When data uses 'output' keys)", () => {
const network = new Network({Module: fakeModule})
sinon.spy(network, "initLayers")

return network.train(testDataWithOutput).then(() => {
expect(network.initLayers).to.have.been.called
})
})

it("CCalls the WASM Module's loadTrainingData function", () => {
sinon.stub(fakeModule, "ccall")
const network = new Network({Module: fakeModule})
Expand Down Expand Up @@ -1003,7 +1012,7 @@ describe("Network", () => {
const network = new Network({Module: fakeModule})
const stub = sinon.stub(fakeModule, "ccall").callsFake(() => 0)

return network.train(testData, {epochs: 2, callback: cb, validation: {data: testData}}).then(() => {
return network.train(testData, {epochs: 2, callback: cb, validation: {data: testDataWithOutput}}).then(() => {
expect(counter).to.equal(8)
stub.restore()
})
Expand Down Expand Up @@ -1321,8 +1330,10 @@ describe("Network", () => {
})
})

it("Does not accept test data with output key instead of expected", () => {
return expect(net.test(testDataOutput)).to.not.be.fulfilled
it("Accepts test data with output key instead of expected", () => {
return net.test(testDataOutput).then(() => {
expect(fakeModule.ccall).to.be.called
})
})

it("Logs to the console twice", () => {
Expand Down Expand Up @@ -1513,7 +1524,7 @@ describe("Network", () => {
const data = {inputs: [[[1,2],[3,4]]], expected: [5]}
const typedArray = new Float32Array(5)
net.loadData([data], typedArray, 5, stub)
expect(stub).to.be.calledWith("Data set must be a list of objects with keys: 'input' and 'expected'")
expect(stub).to.be.calledWith("Data set must be a list of objects with keys: 'input' and 'expected' (or 'output')")
})
})

Expand Down

0 comments on commit aa5ac5d

Please sign in to comment.