forked from tensorflow/tfjs
-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Category Encoding Preprocessing Layer Co-authored-by: David Kim (@koyykdy) <dok098@ucsd.edu> Brian Zheng (@Brianzheng123) <brianzheng345@gmail.com>
- Loading branch information
1 parent
af54b76
commit 296bbb9
Showing
5 changed files
with
448 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
114 changes: 114 additions & 0 deletions
114
tfjs-layers/src/layers/preprocessing/category_encoding.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
/** | ||
* @license | ||
* Copyright 2022 CodeSmith LLC | ||
* | ||
* Use of this source code is governed by an MIT-style | ||
* license that can be found in the LICENSE file or at | ||
* https://opensource.org/licenses/MIT. | ||
* ============================================================================= | ||
*/ | ||
|
||
import { LayerArgs, Layer } from '../../engine/topology'; | ||
import { serialization, Tensor, tidy, Tensor1D, Tensor2D} from '@tensorflow/tfjs-core'; | ||
import { greater, greaterEqual, max, min} from '@tensorflow/tfjs-core'; | ||
import { Shape } from '../../keras_format/common'; | ||
import { getExactlyOneShape, getExactlyOneTensor } from '../../utils/types_utils'; | ||
import { Kwargs } from '../../types'; | ||
import { ValueError } from '../../errors'; | ||
import * as K from '../../backend/tfjs_backend'; | ||
import * as utils from './preprocessing_utils'; | ||
|
||
/**
 * Constructor arguments for the `CategoryEncoding` preprocessing layer.
 */
export declare interface CategoryEncodingArgs extends LayerArgs {
  /**
   * Total number of tokens the layer supports. It is used as the encoding
   * depth; `call` rejects inputs unless every value `v` satisfies
   * `0 <= v < numTokens`.
   */
  numTokens: number;

  /**
   * Encoding mode. The layer is exercised with `utils.oneHot`,
   * `utils.multiHot` and `utils.count` from `./preprocessing_utils`.
   * When omitted (or falsy) the layer falls back to `utils.multiHot`.
   */
  outputMode?: string;
}
|
||
/**
 * Preprocessing layer that encodes integer category indices.
 *
 * Inputs are cast to `int32`, validated against the configured number of
 * tokens and then handed to `utils.encodeCategoricalInputs`, which produces
 * the encoding selected by `outputMode` (`utils.oneHot`, `utils.multiHot` or
 * `utils.count`).
 */
export class CategoryEncoding extends Layer {
  /** @nocollapse */
  static className = 'CategoryEncoding';
  private readonly numTokens: number;
  private readonly outputMode: string;

  /**
   * @param args Layer configuration. `outputMode` defaults to
   *     `utils.multiHot` when it is missing or falsy.
   */
  constructor(args: CategoryEncodingArgs) {
    super(args);
    this.numTokens = args.numTokens;

    if (args.outputMode) {
      this.outputMode = args.outputMode;
    } else {
      this.outputMode = utils.multiHot;
    }
  }

  /**
   * Returns the serializable configuration of this layer.
   *
   * Keys from the base layer config are merged on top of `numTokens` and
   * `outputMode`.
   */
  getConfig(): serialization.ConfigDict {
    const config: serialization.ConfigDict = {
      'numTokens': this.numTokens,
      'outputMode': this.outputMode,
    };

    const baseConfig = super.getConfig();
    Object.assign(config, baseConfig);
    return config;
  }

  /**
   * Computes the output shape for a given input shape.
   *
   * - A missing or empty (rank 0) shape yields `[numTokens]`.
   * - In `oneHot` mode, unless the last dimension is exactly 1, a new
   *   trailing dimension of size `numTokens` is appended.
   * - Otherwise the last dimension is replaced by `numTokens`.
   *
   * The incoming shape array is never mutated; a new array is returned.
   *
   * @param inputShape Shape (or single-element list of shapes) of the input.
   * @returns The shape of the encoded output.
   */
  computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {
    inputShape = getExactlyOneShape(inputShape);

    if (inputShape == null || inputShape.length === 0) {
      return [this.numTokens];
    }

    // JavaScript arrays do not support negative indices, so the last
    // dimension has to be addressed explicitly via `length - 1`.
    const lastDim = inputShape[inputShape.length - 1];

    if (this.outputMode === utils.oneHot && lastDim !== 1) {
      return inputShape.concat([this.numTokens]);
    }

    return inputShape.slice(0, -1).concat([this.numTokens]);
  }

  /**
   * Encodes `inputs` according to `outputMode`.
   *
   * @param inputs Tensor (or single-element list) of category indices. It is
   *     cast to `int32` if it has another dtype.
   * @param kwargs Optional call arguments. `countWeights` may hold a rank 1
   *     or rank 2 tensor of weights and is only accepted when `outputMode`
   *     is `utils.count`.
   * @returns The encoded tensor produced by `utils.encodeCategoricalInputs`.
   * @throws ValueError if `countWeights` is supplied for a mode other than
   *     `count`, or if any input value lies outside `[0, numTokens)`.
   */
  call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor[]|Tensor {
    return tidy(() => {
      let x = getExactlyOneTensor(inputs);
      if (x.dtype !== 'int32') {
        x = K.cast(x, 'int32');
      }

      let countWeights: Tensor1D|Tensor2D;

      if (kwargs != null && typeof kwargs['countWeights'] !== 'undefined') {
        if (this.outputMode !== utils.count) {
          throw new ValueError(
              `countWeights is not used when outputMode !== count.\n` +
              `Received countWeights=${kwargs['countWeights']}`);
        }

        const countWeightsRanked = getExactlyOneTensor(kwargs['countWeights']);

        // Only rank 1 and rank 2 weights are forwarded; any other rank
        // leaves `countWeights` undefined.
        if (countWeightsRanked.rank === 1) {
          countWeights = countWeightsRanked as Tensor1D;
        } else if (countWeightsRanked.rank === 2) {
          countWeights = countWeightsRanked as Tensor2D;
        }
      }

      const depth = this.numTokens;
      const maxValue = max(x);
      const minValue = min(x);

      // Both checks read a value back from the backend synchronously.
      const maxBelowDepth = greater(depth, maxValue).bufferSync().get(0);
      const minNonNegative = greaterEqual(minValue, 0).bufferSync().get(0);

      if (!(maxBelowDepth && minNonNegative)) {
        // NOTE: the condition actually enforced is 0 <= values < numTokens.
        // The message text is kept unchanged because callers and tests match
        // on it verbatim.
        throw new ValueError(
            `Input values must be between 0 < values <= numTokens`);
      }

      return utils.encodeCategoricalInputs(
          x, this.outputMode, depth, countWeights);
    });
  }
}

serialization.registerClass(CategoryEncoding);
141 changes: 141 additions & 0 deletions
141
tfjs-layers/src/layers/preprocessing/category_encoding_test.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
import { describeMathCPUAndGPU, expectTensorsClose} from '../../utils/test_utils'; | ||
import { Tensor, tensor} from '@tensorflow/tfjs-core'; | ||
import { CategoryEncoding } from './category_encoding'; | ||
import * as utils from './preprocessing_utils'; | ||
|
||
/**
 * Output tests for the CategoryEncoding layer: one case per output mode
 * (count, oneHot, multiHot) and input rank (0, 1, 2), followed by the
 * error cases for unsupported rank and out-of-range input values.
 */
describeMathCPUAndGPU('Layer Output', () => {

  it('Calculates correct output for Count outputMode rank 0', () => {
    const categoryData = tensor(0);
    const expectedOutput = tensor([1, 0, 0, 0]);
    const numTokens = 4;
    const encodingLayer = new CategoryEncoding({numTokens,
                                                outputMode: utils.count});
    const computedOutput = encodingLayer.apply(categoryData) as Tensor;

    expectTensorsClose(computedOutput, expectedOutput);
  });

  it('Calculates correct output for Count outputMode rank 1 (weights)', () => {
    const categoryData = tensor([1, 2, 3, 3, 0]);
    const weightData = tensor([1, 2, 3, 1, 7]);
    const numTokens = 6;
    const expectedOutput = tensor([7, 1, 2, 4, 0, 0]);
    const encodingLayer = new CategoryEncoding({numTokens,
                                                outputMode: utils.count});

    const computedOutput = encodingLayer.apply(
        categoryData, {countWeights: weightData}) as Tensor;

    expectTensorsClose(computedOutput, expectedOutput);
  });

  it('Calculates correct output for Count outputMode rank 2', () => {
    const categoryData = tensor([[1, 2, 3, 1], [0, 3, 1, 0]]);
    const expectedOutput = tensor([[0, 2, 1, 1, 0, 0], [2, 1, 0, 1, 0, 0]]);
    const numTokens = 6;
    const encodingLayer = new CategoryEncoding({numTokens,
                                                outputMode: utils.count});
    const computedOutput = encodingLayer.apply(categoryData) as Tensor;
    expectTensorsClose(computedOutput, expectedOutput);
  });

  it('Calculates correct output for oneHot outputMode rank 0', () => {
    const categoryData = tensor(3);
    const expectedOutput = tensor([0, 0, 0, 1]);
    const numTokens = 4;
    const encodingLayer = new CategoryEncoding({numTokens,
                                                outputMode: utils.oneHot});
    const computedOutput = encodingLayer.apply(categoryData) as Tensor;
    expectTensorsClose(computedOutput, expectedOutput);
  });

  it('Calculates correct output and shape for oneHot outputMode rank 1', () => {
    const categoryData = tensor([3, 2, 0, 1]);
    const expectedOutput = tensor([[0, 0, 0, 1],
                                   [0, 0, 1, 0],
                                   [1, 0, 0, 0],
                                   [0, 1, 0, 0]]);
    const numTokens = 4;
    const encodingLayer = new CategoryEncoding({numTokens,
                                                outputMode: utils.oneHot});
    const computedOutput = encodingLayer.apply(categoryData) as Tensor;
    expectTensorsClose(computedOutput, expectedOutput);
  });

  it('Calculates correct output and shape for oneHot outputMode rank 2', () => {
    const categoryData = tensor([[3], [2], [0], [1]]);
    const expectedOutput = tensor([[0, 0, 0, 1],
                                   [0, 0, 1, 0],
                                   [1, 0, 0, 0],
                                   [0, 1, 0, 0]]);
    const numTokens = 4;
    const encodingLayer = new CategoryEncoding({numTokens,
                                                outputMode: utils.oneHot});
    const computedOutput = encodingLayer.apply(categoryData) as Tensor;
    expectTensorsClose(computedOutput, expectedOutput);
  });

  it('Calculates correct output for multiHot outputMode rank 0', () => {
    const categoryData = tensor(3);
    const expectedOutput = tensor([0, 0, 0, 1, 0, 0]);
    const numTokens = 6;
    const encodingLayer = new CategoryEncoding({numTokens,
                                                outputMode: utils.multiHot});
    const computedOutput = encodingLayer.apply(categoryData) as Tensor;
    expectTensorsClose(computedOutput, expectedOutput);
  });

  it('Calculates correct output for multiHot outputMode rank 1', () => {
    const categoryData = tensor([3, 2, 0, 1]);
    const expectedOutput = tensor([1, 1, 1, 1, 0, 0]);
    const numTokens = 6;
    const encodingLayer = new CategoryEncoding({numTokens,
                                                outputMode: utils.multiHot});
    const computedOutput = encodingLayer.apply(categoryData) as Tensor;
    expectTensorsClose(computedOutput, expectedOutput);
  });

  it('Calculates correct output for multiHot outputMode rank 2', () => {
    const categoryData = tensor([[0, 1], [0, 0], [1, 2], [3, 1]]);
    const expectedOutput = tensor([[1, 1, 0, 0],
                                   [1, 0, 0, 0],
                                   [0, 1, 1, 0],
                                   [0, 1, 0, 1]]);
    const numTokens = 4;
    const encodingLayer = new CategoryEncoding({numTokens,
                                                outputMode: utils.multiHot});
    const computedOutput = encodingLayer.apply(categoryData) as Tensor;
    expectTensorsClose(computedOutput, expectedOutput);
  });

  it('Raises Value Error if input Tensor has Rank > 2', () => {
    const categoryData = tensor([[[1], [2]], [[3], [4]]]);
    const numTokens = 6;
    const encodingLayer = new CategoryEncoding({numTokens,
                                                outputMode: utils.multiHot});
    // The expected message spans three lines; line breaks are written as
    // explicit `\n` so that source indentation cannot leak into the string.
    expect(() => encodingLayer.apply(categoryData))
        .toThrowError(
            `When outputMode is not 'int', maximum output rank is 2\n` +
            `Received outputMode ${utils.multiHot} and input shape ` +
            `${categoryData.shape}\n` +
            `which would result in output rank ${categoryData.rank}.`);
  });

  it('Raises Value Error if max input value !<= numTokens', () => {
    const categoryData = tensor([7, 2, 0, 1]);
    const numTokens = 3;
    const encodingLayer = new CategoryEncoding({numTokens,
                                                outputMode: utils.multiHot});
    expect(() => encodingLayer.apply(categoryData))
        .toThrowError(`Input values must be between 0 < values <= numTokens`);
  });

  it('Raises Value Error if min input value < 0', () => {
    // Every value except -1 is a valid token (< numTokens), so this case
    // can only fail through the lower-bound check and not the max check.
    const categoryData = tensor([2, 0, -1, 1]);
    const numTokens = 3;
    const encodingLayer = new CategoryEncoding({numTokens,
                                                outputMode: utils.multiHot});
    expect(() => encodingLayer.apply(categoryData))
        .toThrowError(`Input values must be between 0 < values <= numTokens`);
  });
});
Oops, something went wrong.