Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion packages/tfjs-node-helpers-example/src/app/app.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import { GenderFeatureExtractor } from './feature-extractors/gender';
import { OwnsTheCarFeatureExtractor } from './feature-extractors/owns-the-car';
import { join } from 'node:path';
import { TrainingDataService } from './services/training-data';
import { AgeMinMaxFeatureNormalizer } from './feature-normalizers/age';
import { AnnualSalaryMinMaxFeatureNormalizer } from './feature-normalizers/annual-salary';

export async function startApplication(): Promise<void> {
await train();
Expand All @@ -23,7 +25,11 @@ async function train(): Promise<void> {
new AnnualSalaryFeatureExtractor(),
new GenderFeatureExtractor()
],
outputFeatureExtractor: new OwnsTheCarFeatureExtractor()
outputFeatureExtractor: new OwnsTheCarFeatureExtractor(),
inputFeatureNormalizers: [
new AgeMinMaxFeatureNormalizer(),
new AnnualSalaryMinMaxFeatureNormalizer()
]
});

const trainingDataService = new TrainingDataService({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,10 @@ export class AgeFeatureExtractor extends FeatureExtractor<DatasetItem, FeatureTy
public featureType = FeatureType.AGE;

public extract(item: DatasetItem): Feature<FeatureType> {
const minAge = 18;
const maxAge = 63;

return new Feature({
type: this.featureType,
label: `${item.age} years`,
value: (item.age - minAge) / (maxAge - minAge)
value: item.age
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,10 @@ export class AnnualSalaryFeatureExtractor extends FeatureExtractor<DatasetItem,
public featureType = FeatureType.ANNUAL_SALARY;

public extract(item: DatasetItem): Feature<FeatureType> {
const minAnnualSalary = 15000;
const maxAnnualSalary = 152500;

return new Feature({
type: this.featureType,
label: item.annual_salary.toString(),
value: (item.annual_salary - minAnnualSalary) / (maxAnnualSalary - minAnnualSalary)
value: item.annual_salary
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import { MinMaxFeatureNormalizer } from '@ronas-it/tfjs-node-helpers';
import { FeatureType } from '../enums/feature-type';

export class AgeMinMaxFeatureNormalizer extends MinMaxFeatureNormalizer<FeatureType> {
public featureType = FeatureType.AGE;

constructor() {
super({ min: 18, max: 63 });
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import { MinMaxFeatureNormalizer } from '@ronas-it/tfjs-node-helpers';
import { FeatureType } from '../enums/feature-type';

export class AnnualSalaryMinMaxFeatureNormalizer extends MinMaxFeatureNormalizer<FeatureType> {
public featureType = FeatureType.ANNUAL_SALARY;

constructor() {
super({ min: 15000, max: 152500 });
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
import { extractFeatures, Sample, splitSamplesIntoTrainingValidationTestForBinaryClassification } from '@ronas-it/tfjs-node-helpers';
import {
extractFeatures,
normalizeFeatures,
Sample,
splitSamplesIntoTrainingValidationTestForBinaryClassification
} from '@ronas-it/tfjs-node-helpers';
import { AgeFeatureExtractor } from '../feature-extractors/age';
import { AnnualSalaryFeatureExtractor } from '../feature-extractors/annual-salary';
import { GenderFeatureExtractor } from '../feature-extractors/gender';
import { OwnsTheCarFeatureExtractor } from '../feature-extractors/owns-the-car';
import dataset from '../../assets/data.json';
import { AgeMinMaxFeatureNormalizer } from '../feature-normalizers/age';
import { AnnualSalaryMinMaxFeatureNormalizer } from '../feature-normalizers/annual-salary';

export class TrainingDataService {
private simulatedDelayMs: number;
Expand All @@ -16,7 +23,7 @@ export class TrainingDataService {
}

public async initialize(): Promise<void> {
const samples = await extractFeatures({
const extracts = await extractFeatures({
data: dataset,
inputFeatureExtractors: [
new AgeFeatureExtractor(),
Expand All @@ -26,6 +33,14 @@ export class TrainingDataService {
outputFeatureExtractor: new OwnsTheCarFeatureExtractor()
});

const samples = await normalizeFeatures({
extracts,
inputFeatureNormalizers: [
new AgeMinMaxFeatureNormalizer(),
new AnnualSalaryMinMaxFeatureNormalizer()
]
});

const { trainingSamples, validationSamples, testingSamples } = splitSamplesIntoTrainingValidationTestForBinaryClassification(samples);

this.trainingSamples = trainingSamples;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@ import { prepareDatasetsForBinaryClassification } from '../feature-engineering/p
import { ConfusionMatrix } from '../testing/confusion-matrix';
import { Metrics } from '../testing/metrics';
import { binarize } from '../utils/binarize';
import { FeatureNormalizer } from '../feature-engineering/feature-normalizer';

export type BinaryClassificationTrainerOptions = {
batchSize?: number;
epochs?: number;
patience?: number;
inputFeatureExtractors?: Array<FeatureExtractor<any, any>>;
outputFeatureExtractor?: FeatureExtractor<any, any>;
inputFeatureNormalizers?: Array<FeatureNormalizer<any>>;
model?: LayersModel;
hiddenLayers?: Array<layers.Layer>;
optimizer?: string | Optimizer;
Expand All @@ -39,6 +41,7 @@ export class BinaryClassificationTrainer {
protected tensorBoardLogsDirectory?: string;
protected inputFeatureExtractors?: Array<FeatureExtractor<any, any>>;
protected outputFeatureExtractor?: FeatureExtractor<any, any>;
protected inputFeatureNormalizers?: Array<FeatureNormalizer<any>>;
protected model!: LayersModel;

protected static DEFAULT_BATCH_SIZE: number = 32;
Expand All @@ -52,6 +55,7 @@ export class BinaryClassificationTrainer {
this.tensorBoardLogsDirectory = options.tensorBoardLogsDirectory;
this.inputFeatureExtractors = options.inputFeatureExtractors;
this.outputFeatureExtractor = options.outputFeatureExtractor;
this.inputFeatureNormalizers = options.inputFeatureNormalizers;

this.initializeModel(options);
}
Expand All @@ -63,7 +67,7 @@ export class BinaryClassificationTrainer {
testingDataset,
printTestingResults
}: {
data?: Array<any>,
data?: Array<any>;
trainingDataset?: data.Dataset<TensorContainer>;
validationDataset?: data.Dataset<TensorContainer>;
testingDataset?: data.Dataset<TensorContainer>;
Expand All @@ -90,14 +94,19 @@ export class BinaryClassificationTrainer {
validationDataset === undefined ||
testingDataset === undefined
) {
if (this.inputFeatureExtractors === undefined || this.outputFeatureExtractor === undefined) {
throw new Error('trainingDataset, validationDataset and testingDataset are required when inputFeatureExtractors and outputFeatureExtractor are not provided!');
if (
this.inputFeatureExtractors === undefined ||
this.outputFeatureExtractor === undefined ||
this.inputFeatureNormalizers === undefined
) {
throw new Error('trainingDataset, validationDataset and testingDataset are required when inputFeatureExtractors, outputFeatureExtractor and inputFeatureNormalizers are not provided!');
}

const datasets = await prepareDatasetsForBinaryClassification({
data: data as Array<any>,
inputFeatureExtractors: this.inputFeatureExtractors,
outputFeatureExtractor: this.outputFeatureExtractor,
inputFeatureNormalizers: this.inputFeatureNormalizers,
batchSize: this.batchSize
});

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import { Sample } from '../training/sample';
import { FeatureExtractor } from './feature-extractor';
import { Feature } from './feature';

export type DataItemExtract<T> = {
inputFeatures: Array<Feature<T>>;
outputFeature: Feature<T>;
};

export const extractFeatures = async <D, T>({
data,
Expand All @@ -9,8 +14,8 @@ export const extractFeatures = async <D, T>({
data: Array<D>;
inputFeatureExtractors: Array<FeatureExtractor<D, T>>;
outputFeatureExtractor: FeatureExtractor<D, T>;
}): Promise<Array<Sample>> => {
const samples = [];
}): Promise<Array<DataItemExtract<T>>> => {
const extracts = [];

for (const dataItem of data) {
const [inputFeatures, outputFeature] = await Promise.all([
Expand All @@ -22,11 +27,8 @@ export const extractFeatures = async <D, T>({
outputFeatureExtractor.extract(dataItem)
]);

const input = inputFeatures.map((feature) => feature.value);
const output = [outputFeature.value];

samples.push({ input, output });
extracts.push({ inputFeatures, outputFeature });
}

return samples;
return extracts;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import { Feature } from './feature';

export abstract class FeatureNormalizer<T> {
public abstract featureType: T;

public abstract normalize(feature: Feature<T>): Feature<T> | Promise<Feature<T>>;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import { FeatureNormalizer } from './feature-normalizer';
import { Feature } from './feature';

export abstract class MinMaxFeatureNormalizer<T> extends FeatureNormalizer<T> {
private min: number;
private max: number;

constructor({ min, max }: { min: number; max: number }) {
super();

this.min = min;
this.max = max;
}

public normalize(feature: Feature<T>): Feature<T> | Promise<Feature<T>> {
return new Feature({
...feature,
value: (feature.value - this.min) / (this.max - this.min)
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import { FeatureNormalizer } from './feature-normalizer';
import { Sample } from '../training/sample';
import { DataItemExtract } from './extract-features';

export const normalizeFeatures = async <T>({
extracts,
inputFeatureNormalizers
}: {
extracts: Array<DataItemExtract<T>>;
inputFeatureNormalizers: Array<FeatureNormalizer<T>>;
}): Promise<Array<Sample>> => {
const samples = [];

for (const extractItem of extracts) {
const inputNormalizedFeatures = await Promise.all(
extractItem.inputFeatures.map((feature) => {
const desiredNormalizer = inputFeatureNormalizers.find((normalizer) => normalizer.featureType === feature.type);

return (desiredNormalizer !== undefined) ? desiredNormalizer.normalize(feature) : feature;
})
);

const input = inputNormalizedFeatures.map((feature) => feature.value);
const output = [extractItem.outputFeature.value];

samples.push({ input, output });
}

return samples;
};
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ import { splitSamplesIntoTrainingValidationTestForBinaryClassification } from '.
import { makeDataset } from '../utils/make-dataset';
import { extractFeatures } from './extract-features';
import { FeatureExtractor } from './feature-extractor';
import { FeatureNormalizer } from './feature-normalizer';
import { normalizeFeatures } from './normalize-features';

export const prepareDatasetsForBinaryClassification = async <D, T>({
data,
inputFeatureExtractors,
outputFeatureExtractor,
inputFeatureNormalizers,
batchSize,
trainingPercentage,
validationPercentage,
Expand All @@ -16,6 +19,7 @@ export const prepareDatasetsForBinaryClassification = async <D, T>({
data: Array<D>;
inputFeatureExtractors: Array<FeatureExtractor<D, T>>;
outputFeatureExtractor: FeatureExtractor<D, T>;
inputFeatureNormalizers: Array<FeatureNormalizer<T>>;
batchSize: number;
trainingPercentage?: number;
validationPercentage?: number;
Expand All @@ -25,12 +29,17 @@ export const prepareDatasetsForBinaryClassification = async <D, T>({
validationDataset: data.Dataset<TensorContainer>;
testingDataset: data.Dataset<TensorContainer>;
}> => {
const samples = await extractFeatures({
const extracts = await extractFeatures({
data,
inputFeatureExtractors,
outputFeatureExtractor
});

const samples = await normalizeFeatures({
extracts,
inputFeatureNormalizers
});

const { trainingSamples, validationSamples, testingSamples } = splitSamplesIntoTrainingValidationTestForBinaryClassification(
samples,
trainingPercentage,
Expand Down
3 changes: 3 additions & 0 deletions packages/tfjs-node-helpers/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ export * from './classification/binary-classifier';
export * from './data-splitting/training-validation-test-for-binary-classification';
export * from './feature-engineering/extract-features';
export * from './feature-engineering/feature-extractor';
export * from './feature-engineering/feature-normalizer';
export * from './feature-engineering/feature';
export * from './feature-engineering/min-max-feature-normalizer';
export * from './feature-engineering/normalize-features';
export * from './feature-engineering/prepare-datasets-for-binary-classification';
export * from './training/sample';
export * from './utils/make-chunked-dataset';