Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Categoryencoding #46

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 32 additions & 29 deletions tfjs-layers/src/layers/preprocessing/category_encoding.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import { Kwargs } from '../../types';
import { ValueError } from '../../errors';
import * as K from '../../backend/tfjs_backend';
import * as utils from './preprocessing_utils';
import { OutputMode } from './preprocessing_utils';

export declare interface CategoryEncodingArgs extends LayerArgs {
numTokens: number;
Expand Down Expand Up @@ -58,7 +59,7 @@ export class CategoryEncoding extends Layer {
return [this.numTokens];
}

if(this.outputMode === utils.oneHot && inputShape[-1] !== 1) {
if(this.outputMode === 'oneHot' && inputShape[inputShape.length - 1] !== 1){
inputShape.push(this.numTokens);
return inputShape;
}
Expand All @@ -70,43 +71,45 @@ export class CategoryEncoding extends Layer {
call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor[]|Tensor {
return tidy(() => {

inputs = getExactlyOneTensor(inputs);
if(inputs.dtype !== 'int32') {
inputs = K.cast(inputs, 'int32');
}
inputs = getExactlyOneTensor(inputs);
if(inputs.dtype !== 'int32') {
inputs = K.cast(inputs, 'int32');
}

let countWeights: Tensor1D | Tensor2D;
let countWeights: Tensor1D | Tensor2D;

if((typeof kwargs['countWeights']) !== 'undefined') {
if((typeof kwargs['countWeights']) !== 'undefined') {

if(this.outputMode !== utils.count) {
throw new ValueError(
`countWeights is not used when outputMode !== count.
Received countWeights=${kwargs['countWeights']}`);
if(this.outputMode !== 'count') {
throw new ValueError(
`countWeights is not used when outputMode !== count.
Received countWeights=${kwargs['countWeights']}`);
}
const countWeightsArg = getExactlyOneTensor(kwargs['countWeights']);

if(countWeightsArg.rank === 1) {
countWeights = countWeightsArg as Tensor1D;
} if(countWeightsArg.rank === 2) {
countWeights = countWeightsArg as Tensor2D;
}
}
const countWeightsRanked = getExactlyOneTensor(kwargs['countWeights']);

if(countWeightsRanked.rank === 1) {
countWeights = countWeightsRanked as Tensor1D;
} if(countWeightsRanked.rank === 2) {
countWeights = countWeightsRanked as Tensor2D;
}
}
const maxValue = max(inputs);
const minValue = min(inputs);
const greaterEqualMax = greater(this.numTokens, maxValue)
.bufferSync().get(0);

const depth = this.numTokens;
const maxValue = max(inputs);
const minValue = min(inputs);
const greaterMin = greaterEqual(minValue, 0).bufferSync().get(0);

const greaterEqualMax = greater(depth, maxValue).bufferSync().get(0);
const greaterMin = greaterEqual(minValue, 0).bufferSync().get(0);
if(!(greaterEqualMax && greaterMin)) {

if(!(greaterEqualMax && greaterMin)) {
throw new ValueError(
`Input values must be between 0 < values <= numTokens with numTokens=${this.numTokens}`);
}
throw new ValueError(
`Input values must be between 0 < values <= numTokens
with numTokens=${this.numTokens}`);
}

return utils.encodeCategoricalInputs(inputs,
this.outputMode, depth, countWeights);
return utils.encodeCategoricalInputs(inputs,
this.outputMode, this.numTokens, countWeights);
});
}
}
Expand Down
43 changes: 27 additions & 16 deletions tfjs-layers/src/layers/preprocessing/category_encoding_test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
/**
* @license
* Copyright 2022 CodeSmith LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/

import { describeMathCPUAndGPU, expectTensorsClose} from '../../utils/test_utils';
import { Tensor, tensor} from '@tensorflow/tfjs-core';
import { CategoryEncoding } from './category_encoding';
import * as utils from './preprocessing_utils';

describeMathCPUAndGPU('Category Encoding', () => {

Expand All @@ -10,7 +19,7 @@ describeMathCPUAndGPU('Category Encoding', () => {
const expectedOutput = tensor([1,0,0,0]);
const numTokens = 4;
const encodingLayer = new CategoryEncoding({numTokens,
outputMode: utils.count});
outputMode: 'count'});
const computedOutput = encodingLayer.
apply(categoryData) as Tensor;

Expand All @@ -23,7 +32,7 @@ describeMathCPUAndGPU('Category Encoding', () => {
const numTokens = 6;
const expectedOutput = tensor([7, 1, 2, 4, 0, 0]);
const encodingLayer = new CategoryEncoding({numTokens,
outputMode: utils.count});
outputMode: 'count'});

const computedOutput = encodingLayer.apply(categoryData,
{countWeights: weightData}) as Tensor;
Expand All @@ -36,7 +45,7 @@ describeMathCPUAndGPU('Category Encoding', () => {
const expectedOutput = tensor([[0, 2, 1, 1, 0, 0], [2, 1, 0, 1, 0, 0]]);
const numTokens = 6;
const encodingLayer = new CategoryEncoding({numTokens,
outputMode: utils.count});
outputMode: 'count'});
const computedOutput = encodingLayer.apply(categoryData) as Tensor;
expectTensorsClose(computedOutput, expectedOutput);
});
Expand All @@ -46,7 +55,7 @@ describeMathCPUAndGPU('Category Encoding', () => {
const expectedOutput = tensor([0, 0, 0, 1]);
const numTokens = 4;
const encodingLayer = new CategoryEncoding({numTokens,
outputMode: utils.oneHot});
outputMode: 'oneHot'});
const computedOutput = encodingLayer.apply(categoryData) as Tensor;
expectTensorsClose(computedOutput, expectedOutput);
});
Expand All @@ -59,7 +68,7 @@ describeMathCPUAndGPU('Category Encoding', () => {
[0, 1, 0, 0]]);
const numTokens = 4;
const encodingLayer = new CategoryEncoding({numTokens,
outputMode: utils.oneHot});
outputMode: 'oneHot'});
const computedOutput = encodingLayer.apply(categoryData) as Tensor;
expectTensorsClose(computedOutput, expectedOutput);
});
Expand All @@ -72,7 +81,7 @@ describeMathCPUAndGPU('Category Encoding', () => {
[0, 1, 0, 0]]);
const numTokens = 4;
const encodingLayer = new CategoryEncoding({numTokens,
outputMode: utils.oneHot});
outputMode: 'oneHot'});
const computedOutput = encodingLayer.apply(categoryData) as Tensor;
expectTensorsClose(computedOutput, expectedOutput);
});
Expand All @@ -82,7 +91,7 @@ describeMathCPUAndGPU('Category Encoding', () => {
const expectedOutput = tensor([0, 0, 0, 1, 0, 0]);
const numTokens = 6;
const encodingLayer = new CategoryEncoding({numTokens,
outputMode: utils.multiHot});
outputMode: 'multiHot'});
const computedOutput = encodingLayer.apply(categoryData) as Tensor;
expectTensorsClose(computedOutput, expectedOutput);
});
Expand All @@ -92,7 +101,7 @@ describeMathCPUAndGPU('Category Encoding', () => {
const expectedOutput = tensor([1, 1, 1, 1, 0, 0]);
const numTokens = 6;
const encodingLayer = new CategoryEncoding({numTokens,
outputMode: utils.multiHot});
outputMode: 'multiHot'});
const computedOutput = encodingLayer.apply(categoryData) as Tensor;
expectTensorsClose(computedOutput, expectedOutput);
});
Expand All @@ -105,7 +114,7 @@ describeMathCPUAndGPU('Category Encoding', () => {
[0, 1, 0, 1]]);
const numTokens = 4;
const encodingLayer = new CategoryEncoding({numTokens,
outputMode: utils.multiHot});
outputMode: 'multiHot'});
const computedOutput = encodingLayer.apply(categoryData) as Tensor;
expectTensorsClose(computedOutput, expectedOutput);
});
Expand All @@ -114,28 +123,30 @@ describeMathCPUAndGPU('Category Encoding', () => {
const categoryData = tensor([[[1], [2]], [[3], [4]]]);
const numTokens = 6;
const encodingLayer = new CategoryEncoding({numTokens,
outputMode: utils.multiHot});
outputMode: 'multiHot'});
expect(() => encodingLayer.apply(categoryData))
.toThrowError(`When outputMode is not 'int', maximum output rank is 2
Received outputMode ${utils.multiHot} and input shape ${categoryData.shape}
Received outputMode ${'multiHot'} and input shape ${categoryData.shape}
which would result in output rank ${categoryData.rank}.`);
});

it('Raises Value Error if max input value !<= numTokens', () => {
const categoryData = tensor([7, 2, 0, 1]);
const numTokens = 3;
const encodingLayer = new CategoryEncoding({numTokens,
outputMode: utils.multiHot});
outputMode: 'multiHot'});
expect(() => encodingLayer.apply(categoryData))
.toThrowError(`Input values must be between 0 < values <= numTokens`);
.toThrowError(`Input values must be between 0 < values <= numTokens
with numTokens=${numTokens}`);
});

it('Raises Value Error if min input value < 0', () => {
const categoryData = tensor([7, 2, -1, 1]);
const numTokens = 3;
const encodingLayer = new CategoryEncoding({numTokens,
outputMode: utils.multiHot});
outputMode: 'multiHot'});
expect(() => encodingLayer.apply(categoryData))
.toThrowError(`Input values must be between 0 < values <= numTokens`);
.toThrowError(`Input values must be between 0 < values <= numTokens
with numTokens=${numTokens}`);
});
});
60 changes: 35 additions & 25 deletions tfjs-layers/src/layers/preprocessing/preprocessing_utils.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
/**
* @license
* Copyright 2022 CodeSmith LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/

import { Tensor, denseBincount, Tensor1D, Tensor2D, TensorLike, mul} from '@tensorflow/tfjs-core';
import { getExactlyOneTensor } from '../../utils/types_utils';
import { expandDims} from '@tensorflow/tfjs-core';
Expand All @@ -7,68 +17,68 @@ import * as K from '../../backend/tfjs_backend';
export type OutputMode = 'int' | 'oneHot' | 'multiHot' | 'count' | 'tfIdf';

export function encodeCategoricalInputs(inputs: Tensor|Tensor[],
outputMode: string,
outputMode: OutputMode,
depth: number,
weights?: Tensor1D|Tensor2D|TensorLike):
Tensor|Tensor[] {

let input = getExactlyOneTensor(inputs);

if(inputs.dtype !== 'int32') {
inputs = K.cast(inputs, 'int32');
if(input.dtype !== 'int32') {
input = K.cast(input, 'int32');
}

if(outputMode === int) {
return inputs;
if(outputMode === 'int') {
return input;
}

const originalShape = inputs.shape;
const originalShape = input.shape;

if(inputs.rank === 0) {
inputs = expandDims(inputs, -1);
if(input.rank === 0) {
input = expandDims(input, -1);
}

if(outputMode === oneHot) {
if(inputs.shape[inputs.shape.length - 1] !== 1) {
inputs = expandDims(inputs, -1);
if(outputMode === 'oneHot') {
if(input.shape[input.shape.length - 1] !== 1) {
input = expandDims(input, -1);
}
}

if(inputs.rank > 2) {
if(input.rank > 2) {
throw new ValueError(`When outputMode is not 'int', maximum output rank is 2
Received outputMode ${outputMode} and input shape ${originalShape}
which would result in output rank ${inputs.rank}.`);
which would result in output rank ${input.rank}.`);
}

const binaryOutput = [multiHot, oneHot].includes(outputMode);
const binaryOutput = ['multiHot', 'oneHot'].includes(outputMode);

let denseBincountInput: Tensor1D | Tensor2D;

if(inputs.rank === 1) {
denseBincountInput = inputs as Tensor1D;
if(input.rank === 1) {
denseBincountInput = input as Tensor1D;
}

if(inputs.rank === 2) {
denseBincountInput = inputs as Tensor2D;
if(input.rank === 2) {
denseBincountInput = input as Tensor2D;
}

let binCounts: Tensor1D | Tensor2D;

if ((typeof weights) !== 'undefined' && outputMode === count) {
if ((typeof weights) !== 'undefined' && outputMode === 'count') {
binCounts = denseBincount(denseBincountInput, weights, depth, binaryOutput);
} else {
binCounts = denseBincount(denseBincountInput, [], depth, binaryOutput);
}

if(outputMode !== tfIdf) {
if(outputMode !== 'tfIdf') {
return binCounts;
}

if(weights === null || weights === undefined) {
throw new ValueError(
`When outputMode is 'tfIdf', weights must be provided.`
);
} else {
if (weights) {
return mul(binCounts, weights);
} else {
throw new ValueError(
`When outputMode is 'tfIdf', weights must be provided.`
);
}
}
Loading