Skip to content

Commit

Permalink
chore: Upgrade net to handle types
Browse files Browse the repository at this point in the history
  • Loading branch information
robertleeplummerjr committed Sep 21, 2021
1 parent f31dced commit 76c3ce4
Show file tree
Hide file tree
Showing 18 changed files with 270 additions and 283 deletions.
21 changes: 0 additions & 21 deletions .babelrc

This file was deleted.

3 changes: 2 additions & 1 deletion .eslintrc.json → .eslintrc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
"no-underscore-dangle": "off",
"prettier/prettier": "error",
"semi": "off",
"standard/no-callback-literal": "off"
"standard/no-callback-literal": "off",
"no-implied-eval": "off"
}
}
6 changes: 3 additions & 3 deletions src/cross-validate.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { LSTMTimeStep } from './recurrent/lstm-time-step';

describe('CrossValidate', () => {
describe('.train()', () => {
class FakeNN extends NeuralNetwork {
class FakeNN extends NeuralNetwork<number[], number[]> {
constructor(
options: Partial<
INeuralNetworkOptions & INeuralNetworkTrainOptions
Expand Down Expand Up @@ -195,7 +195,7 @@ describe('CrossValidate', () => {
});
});
describe('.fromJSON()', () => {
class FakeNN extends NeuralNetwork {}
class FakeNN extends NeuralNetwork<number[], number[]> {}
it("creates a new instance of constructor from argument's sets.error", () => {
const cv = new CrossValidate(FakeNN);
const options = { inputSize: 1, hiddenLayers: [10], outputSize: 1 };
Expand Down Expand Up @@ -241,7 +241,7 @@ describe('CrossValidate', () => {
});
});
describe('.toNeuralNetwork()', () => {
class FakeNN extends NeuralNetwork {}
class FakeNN extends NeuralNetwork<number[], number[]> {}
it('creates a new instance of constructor from top .json sets.error', () => {
const cv = new CrossValidate(FakeNN);
const details = {
Expand Down
4 changes: 4 additions & 0 deletions src/feed-forward.unit.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ describe('FeedForward Class: Unit', () => {
{
filterHeight: 3,
filterWidth: 3,
filterCount: 1,
padding: 2,
stride: 2,
},
Expand All @@ -100,6 +101,7 @@ describe('FeedForward Class: Unit', () => {
padding: 2,
filterWidth: 3,
filterHeight: 3,
filterCount: 1,
stride: 3,
},
inputLayer
Expand Down Expand Up @@ -171,6 +173,7 @@ describe('FeedForward Class: Unit', () => {
{
filterWidth: 3, // TODO: setting height, width should behave same
filterHeight: 3,
filterCount: 3,
padding: 2,
stride: 3,
},
Expand All @@ -187,6 +190,7 @@ describe('FeedForward Class: Unit', () => {
{
filterWidth: 3,
filterHeight: 3,
filterCount: 16,
padding: 2,
stride: 2,
},
Expand Down
5 changes: 4 additions & 1 deletion src/layer/target.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ describe('Target Layer', () => {
});

test('uses compare2D when width > 1', () => {
const target = new Target({}, mockLayer({ height: 10, width: 10 }));
const target = new Target(
{ height: 10, width: 10 },
mockLayer({ height: 10, width: 10 })
);
target.setupKernels();
expect(makeKernel).toHaveBeenCalledWith(compare2D, {
output: [10, 10],
Expand Down
4 changes: 0 additions & 4 deletions src/likely.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ import { NeuralNetwork } from './neural-network';

/**
* Return 0 or 1 for '#'
* @param character
* @returns {number}
*/
function integer(character: string): number {
if (character === '#') return 1;
Expand All @@ -13,8 +11,6 @@ function integer(character: string): number {

/**
* Turn the # into 1s and . into 0s. for whole string
* @param string
* @returns {Array}
*/
function character(string: string): number[] {
return string.trim().split('').map(integer);
Expand Down
24 changes: 10 additions & 14 deletions src/likely.ts
Original file line number Diff line number Diff line change
@@ -1,28 +1,24 @@
import { INumberHash } from './lookup';
import { NeuralNetwork } from './neural-network';
import { INeuralNetworkData, NeuralNetwork } from './neural-network';

/**
*
* @param {*} input
* @param {brain.NeuralNetwork} net
* @returns {*}
*/
export function likely<T extends number[] | Float32Array | INumberHash>(
input: T,
net: NeuralNetwork
): T | null {
export function likely<
InputType extends INeuralNetworkData,
OutputType extends INeuralNetworkData
>(
input: InputType,
net: NeuralNetwork<InputType, OutputType>
): OutputType | null {
if (!net) {
throw new TypeError(
`Required parameter 'net' is of type ${typeof net}. Must be of type 'brain.NeuralNetwork'`
);
}

const output = net.run<T>(input);
const output = net.run(input);
let maxProp = null;
let maxValue = -1;

Object.entries(output).forEach(([key, value]) => {
if (value > maxValue) {
if (typeof value !== 'undefined' && value > maxValue) {
maxProp = key;
maxValue = value;
}
Expand Down
2 changes: 1 addition & 1 deletion src/lookup.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export interface INumberArray {
[index: number]: number;
}

export type InputOutputValue = INumberArray | INumberHash;
export type InputOutputValue = INumberArray | Partial<INumberHash>;

export interface ITrainingDatum {
input: InputOutputValue | InputOutputValue[] | KernelOutput;
Expand Down
11 changes: 6 additions & 5 deletions src/neural-network-gpu.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { Texture } from 'gpu.js';

import { NeuralNetwork } from './neural-network';
import { NeuralNetworkGPU } from './neural-network-gpu';
import { Texture } from 'gpu.js';

describe('NeuralNetworkGPU', () => {
const xorTrainingData = [
Expand Down Expand Up @@ -33,11 +34,11 @@ describe('NeuralNetworkGPU', () => {
});

it('can serialize from NeuralNetworkGPU & deserialize to NeuralNetwork', () => {
const net = new NeuralNetworkGPU();
const net = new NeuralNetworkGPU<number[], number[]>();
net.train(xorTrainingData, { iterations: 1 });
const target = xorTrainingData.map((datum) => net.run(datum.input));
const json = net.toJSON();
const net2 = new NeuralNetwork();
const net2 = new NeuralNetwork<number[], number[]>();
net2.fromJSON(json);
for (let i = 0; i < xorTrainingData.length; i++) {
// there is a wee bit of loss going from GPU to CPU
Expand All @@ -49,11 +50,11 @@ describe('NeuralNetworkGPU', () => {
});

it('can serialize from NeuralNetwork & deserialize to NeuralNetworkGPU', () => {
const net = new NeuralNetwork();
const net = new NeuralNetwork<number[], number[]>();
net.train(xorTrainingData, { iterations: 1 });
const target = xorTrainingData.map((datum) => net.run(datum.input));
const json = net.toJSON();
const net2 = new NeuralNetworkGPU();
const net2 = new NeuralNetworkGPU<number[], number[]>();
net2.fromJSON(json);
for (let i = 0; i < xorTrainingData.length; i++) {
// there is a wee bit of loss going from CPU to GPU
Expand Down
Loading

0 comments on commit 76c3ce4

Please sign in to comment.