diff --git a/packages/tfjs-node-helpers-example/src/app/app.ts b/packages/tfjs-node-helpers-example/src/app/app.ts index 41400fb..bf58bc8 100644 --- a/packages/tfjs-node-helpers-example/src/app/app.ts +++ b/packages/tfjs-node-helpers-example/src/app/app.ts @@ -6,6 +6,8 @@ import { GenderFeatureExtractor } from './feature-extractors/gender'; import { OwnsTheCarFeatureExtractor } from './feature-extractors/owns-the-car'; import { join } from 'node:path'; import { TrainingDataService } from './services/training-data'; +import { AgeMinMaxFeatureNormalizer } from './feature-normalizers/age'; +import { AnnualSalaryMinMaxFeatureNormalizer } from './feature-normalizers/annual-salary'; export async function startApplication(): Promise { await train(); @@ -23,7 +25,11 @@ async function train(): Promise { new AnnualSalaryFeatureExtractor(), new GenderFeatureExtractor() ], - outputFeatureExtractor: new OwnsTheCarFeatureExtractor() + outputFeatureExtractor: new OwnsTheCarFeatureExtractor(), + inputFeatureNormalizers: [ + new AgeMinMaxFeatureNormalizer(), + new AnnualSalaryMinMaxFeatureNormalizer() + ] }); const trainingDataService = new TrainingDataService({ diff --git a/packages/tfjs-node-helpers-example/src/app/feature-extractors/age.ts b/packages/tfjs-node-helpers-example/src/app/feature-extractors/age.ts index 8c9bbbd..34cbfd7 100644 --- a/packages/tfjs-node-helpers-example/src/app/feature-extractors/age.ts +++ b/packages/tfjs-node-helpers-example/src/app/feature-extractors/age.ts @@ -6,13 +6,10 @@ export class AgeFeatureExtractor extends FeatureExtractor { - const minAge = 18; - const maxAge = 63; - return new Feature({ type: this.featureType, label: `${item.age} years`, - value: (item.age - minAge) / (maxAge - minAge) + value: item.age }); } } diff --git a/packages/tfjs-node-helpers-example/src/app/feature-extractors/annual-salary.ts b/packages/tfjs-node-helpers-example/src/app/feature-extractors/annual-salary.ts index 338d6ee..2b88434 100644 --- a/packages/tfjs-node-helpers-example/src/app/feature-extractors/annual-salary.ts +++ b/packages/tfjs-node-helpers-example/src/app/feature-extractors/annual-salary.ts @@ -6,13 +6,10 @@ export class AnnualSalaryFeatureExtractor extends FeatureExtractor { - const minAnnualSalary = 15000; - const maxAnnualSalary = 152500; - return new Feature({ type: this.featureType, label: item.annual_salary.toString(), - value: (item.annual_salary - minAnnualSalary) / (maxAnnualSalary - minAnnualSalary) + value: item.annual_salary }); } } diff --git a/packages/tfjs-node-helpers-example/src/app/feature-normalizers/age.ts b/packages/tfjs-node-helpers-example/src/app/feature-normalizers/age.ts new file mode 100644 index 0000000..7f86205 --- /dev/null +++ b/packages/tfjs-node-helpers-example/src/app/feature-normalizers/age.ts @@ -0,0 +1,10 @@ +import { MinMaxFeatureNormalizer } from '@ronas-it/tfjs-node-helpers'; +import { FeatureType } from '../enums/feature-type'; + +export class AgeMinMaxFeatureNormalizer extends MinMaxFeatureNormalizer { + public featureType = FeatureType.AGE; + + constructor() { + super({ min: 18, max: 63 }); + } +} diff --git a/packages/tfjs-node-helpers-example/src/app/feature-normalizers/annual-salary.ts b/packages/tfjs-node-helpers-example/src/app/feature-normalizers/annual-salary.ts new file mode 100644 index 0000000..537623e --- /dev/null +++ b/packages/tfjs-node-helpers-example/src/app/feature-normalizers/annual-salary.ts @@ -0,0 +1,10 @@ +import { MinMaxFeatureNormalizer } from '@ronas-it/tfjs-node-helpers'; +import { FeatureType } from '../enums/feature-type'; + +export class AnnualSalaryMinMaxFeatureNormalizer extends MinMaxFeatureNormalizer { + public featureType = FeatureType.ANNUAL_SALARY; + + constructor() { + super({ min: 15000, max: 152500 }); + } +} diff --git a/packages/tfjs-node-helpers-example/src/app/services/training-data.ts b/packages/tfjs-node-helpers-example/src/app/services/training-data.ts index f63a719..2def15d 100644 --- a/packages/tfjs-node-helpers-example/src/app/services/training-data.ts +++ b/packages/tfjs-node-helpers-example/src/app/services/training-data.ts @@ -1,9 +1,16 @@ -import { extractFeatures, Sample, splitSamplesIntoTrainingValidationTestForBinaryClassification } from '@ronas-it/tfjs-node-helpers'; +import { + extractFeatures, + normalizeFeatures, + Sample, + splitSamplesIntoTrainingValidationTestForBinaryClassification +} from '@ronas-it/tfjs-node-helpers'; import { AgeFeatureExtractor } from '../feature-extractors/age'; import { AnnualSalaryFeatureExtractor } from '../feature-extractors/annual-salary'; import { GenderFeatureExtractor } from '../feature-extractors/gender'; import { OwnsTheCarFeatureExtractor } from '../feature-extractors/owns-the-car'; import dataset from '../../assets/data.json'; +import { AgeMinMaxFeatureNormalizer } from '../feature-normalizers/age'; +import { AnnualSalaryMinMaxFeatureNormalizer } from '../feature-normalizers/annual-salary'; export class TrainingDataService { private simulatedDelayMs: number; @@ -16,7 +23,7 @@ export class TrainingDataService { } public async initialize(): Promise { - const samples = await extractFeatures({ + const extracts = await extractFeatures({ data: dataset, inputFeatureExtractors: [ new AgeFeatureExtractor(), @@ -26,6 +33,14 @@ export class TrainingDataService { outputFeatureExtractor: new OwnsTheCarFeatureExtractor() }); + const samples = await normalizeFeatures({ + extracts, + inputFeatureNormalizers: [ + new AgeMinMaxFeatureNormalizer(), + new AnnualSalaryMinMaxFeatureNormalizer() + ] + }); + const { trainingSamples, validationSamples, testingSamples } = splitSamplesIntoTrainingValidationTestForBinaryClassification(samples); this.trainingSamples = trainingSamples; diff --git a/packages/tfjs-node-helpers/src/classification/binary-classification-trainer.ts b/packages/tfjs-node-helpers/src/classification/binary-classification-trainer.ts index 1762b81..0de2d7e 100644 --- a/packages/tfjs-node-helpers/src/classification/binary-classification-trainer.ts +++ b/packages/tfjs-node-helpers/src/classification/binary-classification-trainer.ts @@ -19,6 +19,7 @@ import { prepareDatasetsForBinaryClassification } from '../feature-engineering/p import { ConfusionMatrix } from '../testing/confusion-matrix'; import { Metrics } from '../testing/metrics'; import { binarize } from '../utils/binarize'; +import { FeatureNormalizer } from '../feature-engineering/feature-normalizer'; export type BinaryClassificationTrainerOptions = { batchSize?: number; @@ -26,6 +27,7 @@ export type BinaryClassificationTrainerOptions = { patience?: number; inputFeatureExtractors?: Array>; outputFeatureExtractor?: FeatureExtractor; + inputFeatureNormalizers?: Array>; model?: LayersModel; hiddenLayers?: Array; optimizer?: string | Optimizer; @@ -39,6 +41,7 @@ export class BinaryClassificationTrainer { protected tensorBoardLogsDirectory?: string; protected inputFeatureExtractors?: Array>; protected outputFeatureExtractor?: FeatureExtractor; + protected inputFeatureNormalizers?: Array>; protected model!: LayersModel; protected static DEFAULT_BATCH_SIZE: number = 32; @@ -52,6 +55,7 @@ export class BinaryClassificationTrainer { this.tensorBoardLogsDirectory = options.tensorBoardLogsDirectory; this.inputFeatureExtractors = options.inputFeatureExtractors; this.outputFeatureExtractor = options.outputFeatureExtractor; + this.inputFeatureNormalizers = options.inputFeatureNormalizers; this.initializeModel(options); } @@ -63,7 +67,7 @@ export class BinaryClassificationTrainer { testingDataset, printTestingResults }: { - data?: Array, + data?: Array; trainingDataset?: data.Dataset; validationDataset?: data.Dataset; testingDataset?: data.Dataset; @@ -90,14 +94,19 @@ export class BinaryClassificationTrainer { validationDataset === undefined || testingDataset === undefined ) { - if (this.inputFeatureExtractors === undefined || this.outputFeatureExtractor === undefined) { - throw new Error('trainingDataset, validationDataset and testingDataset are required when inputFeatureExtractors and outputFeatureExtractor are not provided!'); + if ( + this.inputFeatureExtractors === undefined || + this.outputFeatureExtractor === undefined || + this.inputFeatureNormalizers === undefined + ) { + throw new Error('trainingDataset, validationDataset and testingDataset are required when inputFeatureExtractors, outputFeatureExtractor and inputFeatureNormalizers are not provided!'); } const datasets = await prepareDatasetsForBinaryClassification({ data: data as Array, inputFeatureExtractors: this.inputFeatureExtractors, outputFeatureExtractor: this.outputFeatureExtractor, + inputFeatureNormalizers: this.inputFeatureNormalizers, batchSize: this.batchSize }); diff --git a/packages/tfjs-node-helpers/src/feature-engineering/extract-features.ts b/packages/tfjs-node-helpers/src/feature-engineering/extract-features.ts index 4efb9a6..b8b64e8 100644 --- a/packages/tfjs-node-helpers/src/feature-engineering/extract-features.ts +++ b/packages/tfjs-node-helpers/src/feature-engineering/extract-features.ts @@ -1,5 +1,10 @@ -import { Sample } from '../training/sample'; import { FeatureExtractor } from './feature-extractor'; +import { Feature } from './feature'; + +export type DataItemExtract = { + inputFeatures: Array>; + outputFeature: Feature; +}; export const extractFeatures = async ({ data, @@ -9,8 +14,8 @@ export const extractFeatures = async ({ data: Array; inputFeatureExtractors: Array>; outputFeatureExtractor: FeatureExtractor; -}): Promise> => { - const samples = []; +}): Promise>> => { + const extracts = []; for (const dataItem of data) { const [inputFeatures, outputFeature] = await Promise.all([ @@ -22,11 +27,8 @@ export const extractFeatures = async ({ outputFeatureExtractor.extract(dataItem) ]); - const input = inputFeatures.map((feature) => feature.value); - const output = [outputFeature.value]; - - samples.push({ input, output }); + extracts.push({ inputFeatures, outputFeature }); } - return samples; + return extracts; } diff --git a/packages/tfjs-node-helpers/src/feature-engineering/feature-normalizer.ts b/packages/tfjs-node-helpers/src/feature-engineering/feature-normalizer.ts new file mode 100644 index 0000000..1784f3c --- /dev/null +++ b/packages/tfjs-node-helpers/src/feature-engineering/feature-normalizer.ts @@ -0,0 +1,7 @@ +import { Feature } from './feature'; + +export abstract class FeatureNormalizer { + public abstract featureType: T; + + public abstract normalize(feature: Feature): Feature | Promise>; +} diff --git a/packages/tfjs-node-helpers/src/feature-engineering/min-max-feature-normalizer.ts b/packages/tfjs-node-helpers/src/feature-engineering/min-max-feature-normalizer.ts new file mode 100644 index 0000000..35452e3 --- /dev/null +++ b/packages/tfjs-node-helpers/src/feature-engineering/min-max-feature-normalizer.ts @@ -0,0 +1,21 @@ +import { FeatureNormalizer } from './feature-normalizer'; +import { Feature } from './feature'; + +export abstract class MinMaxFeatureNormalizer extends FeatureNormalizer { + private min: number; + private max: number; + + constructor({ min, max }: { min: number; max: number }) { + super(); + + this.min = min; + this.max = max; + } + + public normalize(feature: Feature): Feature | Promise> { + return new Feature({ + ...feature, + value: (feature.value - this.min) / (this.max - this.min) + }); + } +} diff --git a/packages/tfjs-node-helpers/src/feature-engineering/normalize-features.ts b/packages/tfjs-node-helpers/src/feature-engineering/normalize-features.ts new file mode 100644 index 0000000..3be5a3a --- /dev/null +++ b/packages/tfjs-node-helpers/src/feature-engineering/normalize-features.ts @@ -0,0 +1,30 @@ +import { FeatureNormalizer } from './feature-normalizer'; +import { Sample } from '../training/sample'; +import { DataItemExtract } from './extract-features'; + +export const normalizeFeatures = async ({ + extracts, + inputFeatureNormalizers +}: { + extracts: Array>; + inputFeatureNormalizers: Array>; +}): Promise> => { + const samples = []; + + for (const extractItem of extracts) { + const inputNormalizedFeatures = await Promise.all( + extractItem.inputFeatures.map((feature) => { + const desiredNormalizer = inputFeatureNormalizers.find((normalizer) => normalizer.featureType === feature.type); + + return (desiredNormalizer !== undefined) ? desiredNormalizer.normalize(feature) : feature; + }) + ); + + const input = inputNormalizedFeatures.map((feature) => feature.value); + const output = [extractItem.outputFeature.value]; + + samples.push({ input, output }); + } + + return samples; +}; diff --git a/packages/tfjs-node-helpers/src/feature-engineering/prepare-datasets-for-binary-classification.ts b/packages/tfjs-node-helpers/src/feature-engineering/prepare-datasets-for-binary-classification.ts index 1511d10..51cf7c0 100644 --- a/packages/tfjs-node-helpers/src/feature-engineering/prepare-datasets-for-binary-classification.ts +++ b/packages/tfjs-node-helpers/src/feature-engineering/prepare-datasets-for-binary-classification.ts @@ -3,11 +3,14 @@ import { splitSamplesIntoTrainingValidationTestForBinaryClassification } from '. import { makeDataset } from '../utils/make-dataset'; import { extractFeatures } from './extract-features'; import { FeatureExtractor } from './feature-extractor'; +import { FeatureNormalizer } from './feature-normalizer'; +import { normalizeFeatures } from './normalize-features'; export const prepareDatasetsForBinaryClassification = async ({ data, inputFeatureExtractors, outputFeatureExtractor, + inputFeatureNormalizers, batchSize, trainingPercentage, validationPercentage, @@ -16,6 +19,7 @@ export const prepareDatasetsForBinaryClassification = async ({ data: Array; inputFeatureExtractors: Array>; outputFeatureExtractor: FeatureExtractor; + inputFeatureNormalizers: Array>; batchSize: number; trainingPercentage?: number; validationPercentage?: number; @@ -25,12 +29,17 @@ export const prepareDatasetsForBinaryClassification = async ({ validationDataset: data.Dataset; testingDataset: data.Dataset; }> => { - const samples = await extractFeatures({ + const extracts = await extractFeatures({ data, inputFeatureExtractors, outputFeatureExtractor }); + const samples = await normalizeFeatures({ + extracts, + inputFeatureNormalizers + }); + const { trainingSamples, validationSamples, testingSamples } = splitSamplesIntoTrainingValidationTestForBinaryClassification( samples, trainingPercentage, diff --git a/packages/tfjs-node-helpers/src/index.ts b/packages/tfjs-node-helpers/src/index.ts index b3d62d9..7867d8b 100644 --- a/packages/tfjs-node-helpers/src/index.ts +++ b/packages/tfjs-node-helpers/src/index.ts @@ -3,7 +3,10 @@ export * from './classification/binary-classifier'; export * from './data-splitting/training-validation-test-for-binary-classification'; export * from './feature-engineering/extract-features'; export * from './feature-engineering/feature-extractor'; +export * from './feature-engineering/feature-normalizer'; export * from './feature-engineering/feature'; +export * from './feature-engineering/min-max-feature-normalizer'; +export * from './feature-engineering/normalize-features'; export * from './feature-engineering/prepare-datasets-for-binary-classification'; export * from './training/sample'; export * from './utils/make-chunked-dataset';