Navigation Menu

Skip to content

Commit

Permalink
fix tensor array gather not be able to handle indices longer than ava…
Browse files Browse the repository at this point in the history
…ilable tensors (tensorflow#3157)

BUG
  • Loading branch information
pyu10055 committed Apr 28, 2020
1 parent ed978b8 commit 36f2863
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
6 changes: 4 additions & 2 deletions tfjs-converter/src/executor/tensor_array.ts
Expand Up @@ -67,9 +67,9 @@ export class TensorArray {
throw new Error(`TensorArray ${this.name} has already been closed.`);
}

if (index < 0 || index >= this.tensors.length) {
if (index < 0 || index >= this.size()) {
throw new Error(`Tried to read from index ${index}, but array size is: ${
this.tensors.length}`);
this.size()}`);
}

const tensorWithState = this.tensors[index];
Expand Down Expand Up @@ -182,6 +182,8 @@ export class TensorArray {
for (let i = 0; i < this.size(); i++) {
indices.push(i);
}
} else {
indices = indices.slice(0, this.size());
}

if (indices.length === 0) {
Expand Down
5 changes: 5 additions & 0 deletions tfjs-converter/src/executor/tensor_array_test.ts
Expand Up @@ -145,6 +145,11 @@ describe('TensorArray', () => {
expect(gathered.shape).toEqual([2, 1, 1]);
test_util.expectArraysClose(await gathered.data(), [2, 1]);
});
it('should return when indices longer than available tensors', async () => {
const gathered = tensorArray.gather([1, 0, 2, 3]);
expect(gathered.shape).toEqual([2, 1, 1]);
test_util.expectArraysClose(await gathered.data(), [2, 1]);
});
it('should fail if dtype is not matched', () => {
expect(() => tensorArray.gather([0, 1], 'float32')).toThrow();
});
Expand Down

0 comments on commit 36f2863

Please sign in to comment.