forked from tensorflow/tfjs
-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* needs linting * started category encoding layer * fixed kernal mistake * CategoryEncoding layer definition and export finished. Needs unit testing * CategoryEncoding layer definition and export finished. Needs unit testing * unit tests for category encoding * more unit testing * passing all unit tests for category encoding * done with category encoding, working on linting * passing linter
- Loading branch information
1 parent
aa5338f
commit 9f29a29
Showing
7 changed files
with
446 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
114 changes: 114 additions & 0 deletions
114
tfjs-layers/src/layers/preprocessing/category_encoding.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
/** | ||
* @license | ||
* Copyright 2022 CodeSmith LLC | ||
* | ||
* Use of this source code is governed by an MIT-style | ||
* license that can be found in the LICENSE file or at | ||
* https://opensource.org/licenses/MIT. | ||
* ============================================================================= | ||
*/ | ||
|
||
import { LayerArgs, Layer } from '../../engine/topology'; | ||
import { serialization, Tensor, tidy, Tensor1D, Tensor2D} from '@tensorflow/tfjs-core'; | ||
import { greater, greaterEqual, max, min} from '@tensorflow/tfjs-core'; | ||
import { Shape } from '../../keras_format/common'; | ||
import { getExactlyOneShape, getExactlyOneTensor } from '../../utils/types_utils'; | ||
import { Kwargs } from '../../types'; | ||
import { ValueError } from '../../errors'; | ||
import * as K from '../../backend/tfjs_backend'; | ||
import * as utils from './preprocessing_utils'; | ||
|
||
export declare interface CategoryEncodingArgs extends LayerArgs { | ||
numTokens: number; | ||
outputMode?: string; | ||
} | ||
|
||
export class CategoryEncoding extends Layer { | ||
/** @nocollapse */ | ||
static className = 'CategoryEncoding'; | ||
private readonly numTokens: number; | ||
private readonly outputMode: string; | ||
|
||
constructor(args: CategoryEncodingArgs) { | ||
super(args); | ||
this.numTokens = args.numTokens; | ||
|
||
if(args.outputMode) { | ||
this.outputMode = args.outputMode; | ||
} else { | ||
this.outputMode = utils.multiHot; | ||
} | ||
} | ||
|
||
getConfig(): serialization.ConfigDict { | ||
const config: serialization.ConfigDict = { | ||
'numTokens': this.numTokens, | ||
'outputMode': this.outputMode, | ||
}; | ||
|
||
const baseConfig = super.getConfig(); | ||
Object.assign(config, baseConfig); | ||
return config; | ||
} | ||
|
||
computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] { | ||
inputShape = getExactlyOneShape(inputShape); | ||
|
||
if(inputShape == null) { | ||
return [this.numTokens]; | ||
} | ||
|
||
if(this.outputMode === utils.oneHot && inputShape[-1] !== 1) { | ||
inputShape.push(this.numTokens); | ||
return inputShape; | ||
} | ||
|
||
inputShape[-1] = this.numTokens; | ||
return inputShape; | ||
} | ||
|
||
call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor[]|Tensor { | ||
return tidy(() => { | ||
|
||
inputs = getExactlyOneTensor(inputs); | ||
if(inputs.dtype !== 'int32') { | ||
inputs = K.cast(inputs, 'int32'); | ||
} | ||
|
||
let countWeights; | ||
|
||
if((typeof kwargs['countWeights']) !== 'undefined') { | ||
|
||
if(this.outputMode !== utils.count) { | ||
throw new ValueError( | ||
`countWeights is not used when outputMode !== count. | ||
Received countWeights=${kwargs['countWeights']}`); | ||
} | ||
const countWeightsRanked = getExactlyOneTensor(kwargs['countWeights']); | ||
|
||
if(countWeightsRanked.rank === 1) { | ||
countWeights = countWeightsRanked as Tensor1D; | ||
} if(countWeightsRanked.rank === 2) { | ||
countWeights = countWeightsRanked as Tensor2D; | ||
} | ||
} | ||
|
||
const depth = this.numTokens; | ||
const maxValue = max(inputs); | ||
const minValue = min(inputs); | ||
|
||
const greaterEqualMax = greater(depth, maxValue).bufferSync().get(0); | ||
const greaterMin = greaterEqual(minValue, 0).bufferSync().get(0); | ||
|
||
if(!(greaterEqualMax && greaterMin)) { | ||
throw new ValueError( | ||
`Input values must be between 0 < values <= numTokens`); | ||
} | ||
|
||
return utils.encodeCategoricalInputs(inputs, | ||
this.outputMode, depth, countWeights); | ||
}); | ||
} | ||
} | ||
|
||
serialization.registerClass(CategoryEncoding); |
141 changes: 141 additions & 0 deletions
141
tfjs-layers/src/layers/preprocessing/category_encoding_test.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
import { describeMathCPUAndGPU, expectTensorsClose} from '../../utils/test_utils'; | ||
import { Tensor, tensor} from '@tensorflow/tfjs-core'; | ||
import { CategoryEncoding } from './category_encoding'; | ||
import * as utils from './preprocessing_utils'; | ||
|
||
describeMathCPUAndGPU('Layer Output', () => { | ||
|
||
it('Calculates correct output for Count outputMode rank 0', () => { | ||
const categoryData = tensor(0); | ||
const expectedOutput = tensor([1,0,0,0]); | ||
const numTokens = 4; | ||
const encodingLayer = new CategoryEncoding({numTokens, | ||
outputMode: utils.count}); | ||
const computedOutput = encodingLayer. | ||
apply(categoryData) as Tensor; | ||
|
||
expectTensorsClose(computedOutput, expectedOutput); | ||
}); | ||
|
||
it('Calculates correct output for Count outputMode rank 1 (weights)', () => { | ||
const categoryData = tensor([1, 2, 3, 3, 0]); | ||
const weightData = tensor([1, 2, 3, 1, 7]); | ||
const numTokens = 6; | ||
const expectedOutput = tensor([7, 1, 2, 4, 0, 0]); | ||
const encodingLayer = new CategoryEncoding({numTokens, | ||
outputMode: utils.count}); | ||
|
||
const computedOutput = encodingLayer.apply(categoryData, | ||
{countWeights: weightData}) as Tensor; | ||
|
||
expectTensorsClose(computedOutput, expectedOutput); | ||
}); | ||
|
||
it('Calculates correct output for Count outputMode rank 2', () => { | ||
const categoryData = tensor([[1, 2, 3, 1], [0, 3, 1, 0]]); | ||
const expectedOutput = tensor([[0, 2, 1, 1, 0, 0], [2, 1, 0, 1, 0, 0]]); | ||
const numTokens = 6; | ||
const encodingLayer = new CategoryEncoding({numTokens, | ||
outputMode: utils.count}); | ||
const computedOutput = encodingLayer.apply(categoryData) as Tensor; | ||
expectTensorsClose(computedOutput, expectedOutput); | ||
}); | ||
|
||
it('Calculates correct output for oneHot outputMode rank 0', () => { | ||
const categoryData = tensor(3); | ||
const expectedOutput = tensor([0, 0, 0, 1]); | ||
const numTokens = 4; | ||
const encodingLayer = new CategoryEncoding({numTokens, | ||
outputMode: utils.oneHot}); | ||
const computedOutput = encodingLayer.apply(categoryData) as Tensor; | ||
expectTensorsClose(computedOutput, expectedOutput); | ||
}); | ||
|
||
it('Calculates correct output and shape for oneHot outputMode rank 1', () => { | ||
const categoryData = tensor([3, 2, 0, 1]); | ||
const expectedOutput = tensor([[0, 0, 0, 1], | ||
[0, 0, 1, 0], | ||
[1, 0, 0, 0], | ||
[0, 1, 0, 0]]); | ||
const numTokens = 4; | ||
const encodingLayer = new CategoryEncoding({numTokens, | ||
outputMode: utils.oneHot}); | ||
const computedOutput = encodingLayer.apply(categoryData) as Tensor; | ||
expectTensorsClose(computedOutput, expectedOutput); | ||
}); | ||
|
||
it('Calculates correct output and shape for oneHot outputMode rank 2', () => { | ||
const categoryData = tensor([[3], [2], [0], [1]]); | ||
const expectedOutput = tensor([[0, 0, 0, 1], | ||
[0, 0, 1, 0], | ||
[1, 0, 0, 0], | ||
[0, 1, 0, 0]]); | ||
const numTokens = 4; | ||
const encodingLayer = new CategoryEncoding({numTokens, | ||
outputMode: utils.oneHot}); | ||
const computedOutput = encodingLayer.apply(categoryData) as Tensor; | ||
expectTensorsClose(computedOutput, expectedOutput); | ||
}); | ||
|
||
it('Calculates correct output for multiHot outputMode rank 0', () => { | ||
const categoryData = tensor(3); | ||
const expectedOutput = tensor([0, 0, 0, 1, 0, 0]); | ||
const numTokens = 6; | ||
const encodingLayer = new CategoryEncoding({numTokens, | ||
outputMode: utils.oneHot}); | ||
const computedOutput = encodingLayer.apply(categoryData) as Tensor; | ||
expectTensorsClose(computedOutput, expectedOutput); | ||
}); | ||
|
||
it('Calculates correct output for multiHot outputMode rank 1', () => { | ||
const categoryData = tensor([3, 2, 0, 1]); | ||
const expectedOutput = tensor([1, 1, 1, 1, 0, 0]); | ||
const numTokens = 6; | ||
const encodingLayer = new CategoryEncoding({numTokens, | ||
outputMode: utils.multiHot}); | ||
const computedOutput = encodingLayer.apply(categoryData) as Tensor; | ||
expectTensorsClose(computedOutput, expectedOutput); | ||
}); | ||
|
||
it('Calculates correct output for multiHot outputMode rank 2', () => { | ||
const categoryData = tensor([[0, 1], [0, 0], [1, 2], [3, 1]]); | ||
const expectedOutput = tensor([[1, 1, 0, 0], | ||
[1, 0, 0, 0], | ||
[0, 1, 1, 0], | ||
[0, 1, 0, 1]]); | ||
const numTokens = 4; | ||
const encodingLayer = new CategoryEncoding({numTokens, | ||
outputMode: utils.multiHot}); | ||
const computedOutput = encodingLayer.apply(categoryData) as Tensor; | ||
expectTensorsClose(computedOutput, expectedOutput); | ||
}); | ||
|
||
it('Raises Value Error if input Tensor has Rank > 2', () =>{ | ||
const categoryData = tensor([[[1], [2]], [[3], [4]]]); | ||
const numTokens = 6; | ||
const encodingLayer = new CategoryEncoding({numTokens, | ||
outputMode: utils.multiHot}); | ||
expect(() => encodingLayer.apply(categoryData)) | ||
.toThrowError(`When outputMode is not 'int', maximum output rank is 2 | ||
Received outputMode ${utils.multiHot} and input shape ${categoryData.shape} | ||
which would result in output rank ${categoryData.rank}.`); | ||
}); | ||
|
||
it('Raises Value Error if max input value !<= numTokens', () => { | ||
const categoryData = tensor([7, 2, 0, 1]); | ||
const numTokens = 3; | ||
const encodingLayer = new CategoryEncoding({numTokens, | ||
outputMode: utils.multiHot}); | ||
expect(() => encodingLayer.apply(categoryData)) | ||
.toThrowError(`Input values must be between 0 < values <= numTokens`); | ||
}); | ||
|
||
it('Raises Value Error if min input value < 0', () => { | ||
const categoryData = tensor([7, 2, -1, 1]); | ||
const numTokens = 3; | ||
const encodingLayer = new CategoryEncoding({numTokens, | ||
outputMode: utils.multiHot}); | ||
expect(() => encodingLayer.apply(categoryData)) | ||
.toThrowError(`Input values must be between 0 < values <= numTokens`); | ||
}); | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.