Commit 3a58e2f

Remove use of chained ops. (tensorflow#3058)

lina128 committed Apr 27, 2020
1 parent acfa7f7 commit 3a58e2f

Showing 7 changed files with 58 additions and 40 deletions.
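All seven files apply the same mechanical change: chained Tensor methods such as `dy.mul(...)` and `res.reshape(...)` become standalone op calls such as `mul(dy, ...)` and `reshape(res, ...)`, presumably in service of tree-shakeable modular builds where unused ops can be dropped from a bundle. A minimal sketch of the two styles, written against the public tfjs-core API (`tf.*`) rather than the internal relative imports the diff itself uses; the values are illustrative:

```ts
import * as tf from '@tensorflow/tfjs-core';

const dy = tf.tensor1d([1, 2, 3, 4]);
const a = tf.tensor1d([10, 20, 30, 40]);

// Chained style: every op is a method on Tensor, so importing Tensor
// keeps the whole op surface in the bundle.
const chained = dy.mul(a).reshape([2, 2]);

// Functional style adopted by this commit: each op is imported and
// called on its own, so unused ops can be dropped at bundle time.
const functional = tf.reshape(tf.mul(dy, a), [2, 2]);

chained.print();     // [[10, 40], [90, 160]]
functional.print();  // same values
```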
6 changes: 4 additions & 2 deletions tfjs-core/src/gradients/Div_grad.ts
@@ -17,6 +17,8 @@

 import {Div} from '../kernel_names';
 import {GradConfig} from '../kernel_registry';
+import {reshape} from '../ops/array_ops';
+import {mul} from '../ops/binary_ops';
 import * as broadcast_util from '../ops/broadcast_util';
 import {div} from '../ops/div';
 import {sum} from '../ops/reduction_ops';
@@ -40,10 +42,10 @@ export const divGradConfig: GradConfig = {
       return res;
     };
     const derB = () => {
-      let res = dy.mul(a.toFloat());
+      let res = mul(dy, a.toFloat());
       const reduceAxes = broadcast_util.getReductionAxes(b.shape, outShape);
       if (reduceAxes.length > 0) {
-        res = sum(res, reduceAxes).reshape(b.shape);
+        res = reshape(sum(res, reduceAxes), b.shape);
       }
       const tmp = square(b);
       return neg(div(res, tmp.toFloat()));
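For reference, `derB` above implements d(a/b)/db = -a/b², with any broadcast axes summed out and the result reshaped back to b's shape. A quick check through the public API (a sketch using `tf.grad` rather than this grad config directly; values illustrative):

```ts
import * as tf from '@tensorflow/tfjs-core';

// d(a/b)/db = -a / b^2, summed over the axis along which b broadcast.
const a = tf.tensor1d([8, 8]);
const b = tf.scalar(2);
const dDivDb = tf.grad((b: tf.Tensor) => tf.div(a, b));
dDivDb(b).print();  // -4: two elements, each contributing -8 / 2^2 = -2
```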
21 changes: 10 additions & 11 deletions tfjs-core/src/gradients/FusedBatchNorm_grad.ts
@@ -16,34 +16,32 @@
  */
 import {FusedBatchNorm, FusedBatchNormAttrs} from '../kernel_names';
 import {GradConfig, NamedAttrMap} from '../kernel_registry';
-import {xAs4D} from '../ops/batchnorm_util';
+import {add} from '../ops/add';
+import {reshape} from '../ops/array_ops';
+import {mul} from '../ops/binary_ops';
 import {getReductionAxes} from '../ops/broadcast_util';
-import {add, mul, reshape, sub} from '../ops/ops';
 import {sum} from '../ops/reduction_ops';
+import {sub} from '../ops/sub';
 import {scalar} from '../ops/tensor_ops';
 import {tile} from '../ops/tile';
 import {rsqrt} from '../ops/unary_ops';
-import {Tensor, Tensor4D} from '../tensor';
+import {Tensor} from '../tensor';
 import {Rank, ShapeMap} from '../types';

 export const fusedBatchNormGradConfig: GradConfig = {
   kernelName: FusedBatchNorm,
   inputsToSave: ['x', 'mean', 'variance', 'scale'],
   gradFunc: <R extends Rank>(
       dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => {
-    const batchNormalizationAttrs: FusedBatchNormAttrs =
-        attrs as {} as FusedBatchNormAttrs;
-    const {varianceEpsilon} = batchNormalizationAttrs;
+    const {varianceEpsilon} = attrs as {} as FusedBatchNormAttrs;
     const [x, mean, variance, scale] = saved;

-    const x4D: Tensor4D = xAs4D(x);
-
     const scaleValue = scale == null ? scalar(1) : scale;
-    const reductionAxes = getReductionAxes(mean.shape, x4D.shape);
+    const reductionAxes = getReductionAxes(mean.shape, x.shape);
     const tileShape: number[] = [];
     if (mean.rank === 1) {
-      for (let i = 0; i < x4D.shape.length - 1; ++i) {
-        tileShape.push(x4D.shape[i]);
+      for (let i = 0; i < x.shape.length - 1; ++i) {
+        tileShape.push(x.shape[i]);
       }
       tileShape.push(1);
     }
@@ -100,6 +98,7 @@ export const fusedBatchNormGradConfig: GradConfig = {
       }
       return reshape(offsetDer, mean.shape as ShapeMap[R]);
     };
+
     return {
       x: derX,
       mean: derMean,
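The derivative closures wired up by this config (derX, derMean, the offsetDer reduction shown above, and so on) are what the engine reaches when batchNorm is differentiated. A small end-to-end check through the public API; this is a sketch, and the exact values depend on the default varianceEpsilon:

```ts
import * as tf from '@tensorflow/tfjs-core';

const x = tf.tensor2d([[1, 2], [3, 4]]);  // shape [2, 2]
const mean = tf.tensor1d([2, 3]);
const variance = tf.tensor1d([1, 1]);

// Gradients of batchNorm with respect to x, mean, and variance.
const g = tf.grads(
    (x: tf.Tensor, mean: tf.Tensor, variance: tf.Tensor) =>
        tf.batchNorm(x, mean, variance));
const [dx, dmean, dvariance] = g([x, mean, variance]);
dx.print();     // ~[[1, 1], [1, 1]]: dy * rsqrt(variance + epsilon)
dmean.print();  // ~[-2, -2]: the batch axis is summed out per channel
```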
3 changes: 2 additions & 1 deletion tfjs-core/src/gradients/Square_grad.ts
@@ -17,13 +17,14 @@

 import {Square} from '../kernel_names';
 import {GradConfig} from '../kernel_registry';
+import {mul} from '../ops/binary_ops';
 import {Tensor} from '../tensor';

 export const squareGradConfig: GradConfig = {
   kernelName: Square,
   inputsToSave: ['x'],
   gradFunc: (dy: Tensor, saved: Tensor[]) => {
     const [x] = saved;
-    return {x: () => dy.mul(x.toFloat().mul(2))};
+    return {x: () => mul(dy, mul(x.toFloat(), 2))};
   }
 };
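The rewritten gradient is just the chain rule for y = x²: dy · 2x. A quick check against the public API, with dy defaulting to ones:

```ts
import * as tf from '@tensorflow/tfjs-core';

// d(x^2)/dx = 2x, scaled by the incoming gradient dy (ones by default).
const x = tf.tensor1d([1, 2, 3]);
const dsquare = tf.grad((x: tf.Tensor) => tf.square(x));
dsquare(x).print();  // [2, 4, 6]
```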
32 changes: 20 additions & 12 deletions tfjs-core/src/gradients/Tile_grad.ts
@@ -17,6 +17,8 @@

 import {Tile, TileAttrs} from '../kernel_names';
 import {GradConfig, NamedAttrMap} from '../kernel_registry';
+import {add} from '../ops/add';
+import {slice} from '../ops/slice';
 import {zerosLike} from '../ops/tensor_ops';
 import {Tensor} from '../tensor';

@@ -33,22 +35,25 @@ export const tileGradConfig: GradConfig = {
       // slicing.
       if (x.rank === 1) {
         for (let i = 0; i < reps[0]; ++i) {
-          xGrad = xGrad.add(dy.slice([i * x.shape[0]], [x.shape[0]]));
+          xGrad = add(xGrad, slice(dy, [i * x.shape[0]], [x.shape[0]]));
         }
       } else if (x.rank === 2) {
         for (let i = 0; i < reps[0]; ++i) {
           for (let j = 0; j < reps[1]; ++j) {
-            xGrad = xGrad.add(dy.slice(
-                [i * x.shape[0], j * x.shape[1]], [x.shape[0], x.shape[1]]));
+            xGrad = add(xGrad, slice(dy, [i * x.shape[0], j * x.shape[1]], [
+              x.shape[0], x.shape[1]
+            ]));
           }
         }
       } else if (x.rank === 3) {
         for (let i = 0; i < reps[0]; ++i) {
           for (let j = 0; j < reps[1]; ++j) {
             for (let k = 0; k < reps[2]; ++k) {
-              xGrad = xGrad.add(dy.slice(
-                  [i * x.shape[0], j * x.shape[1], k * x.shape[2]],
-                  [x.shape[0], x.shape[1], x.shape[2]]));
+              xGrad =
+                  add(xGrad,
+                      slice(
+                          dy, [i * x.shape[0], j * x.shape[1], k * x.shape[2]],
+                          [x.shape[0], x.shape[1], x.shape[2]]));
             }
           }
         }
@@ -57,12 +62,15 @@
           for (let j = 0; j < reps[1]; ++j) {
             for (let k = 0; k < reps[2]; ++k) {
               for (let l = 0; l < reps[3]; ++l) {
-                xGrad = xGrad.add(dy.slice(
-                    [
-                      i * x.shape[0], j * x.shape[1], k * x.shape[2],
-                      l * x.shape[3]
-                    ],
-                    [x.shape[0], x.shape[1], x.shape[2], x.shape[3]]));
+                xGrad =
+                    add(xGrad,
+                        slice(
+                            dy,
+                            [
+                              i * x.shape[0], j * x.shape[1], k * x.shape[2],
+                              l * x.shape[3]
+                            ],
+                            [x.shape[0], x.shape[1], x.shape[2], x.shape[3]]));
               }
             }
           }
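The loops above implement the tile gradient: each element of x appears in one output position per repetition, so its gradient is the sum of the matching dy slices, one per tile offset. A small sanity check through the public API:

```ts
import * as tf from '@tensorflow/tfjs-core';

// x of shape [2] tiled with reps [2] yields shape [4]; the gradient
// wrt x sums the two dy slices, dy[0:2] + dy[2:4].
const x = tf.tensor1d([5, 7]);
const dtile = tf.grad((x: tf.Tensor) => tf.tile(x, [2]));
dtile(x).print();  // [2, 2]: each element occurs twice, dy defaults to ones
```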
22 changes: 13 additions & 9 deletions tfjs-core/src/ops/batchnorm.ts
@@ -24,6 +24,7 @@ import {convertToTensor} from '../tensor_util_env';
 import {Rank, TensorLike} from '../types';
 import * as util from '../util';

+import {reshape} from './array_ops';
 import {warnDeprecation, xAs4D} from './batchnorm_util';
 import {op} from './operation';

@@ -99,28 +100,31 @@ function batchNorm_<R extends Rank>(
       () => 'Batch normalization gradient requires mean and scale to have ' +
           'equal ranks.');

+  const x4D: Tensor4D = xAs4D($x);
+
   const forward: ForwardFunc<Tensor> = (backend, save) => {
-    const x4D: Tensor4D = xAs4D($x);
+    save([x4D, $mean, $variance, $scale]);

-    const res = backend.batchNormalization(
+    return backend.batchNormalization(
         x4D, as1DOr4D($mean), as1DOr4D($variance), varianceEpsilon,
         as1DOr4D($scale), as1DOr4D($offset));
-
-    save([$x, $mean, $variance, $scale]);
-
-    return res;
   };

-  const inputs: FusedBatchNormInputs =
-      {x: $x, scale: $scale, offset: $offset, mean: $mean, variance: $variance};
+  const inputs: FusedBatchNormInputs = {
+    x: x4D,
+    scale: $scale,
+    offset: $offset,
+    mean: $mean,
+    variance: $variance
+  };

   const attrs: FusedBatchNormAttrs = {varianceEpsilon};

   const res = ENGINE.runKernelFunc(
       forward, inputs as {} as NamedTensorMap, null /* gradient */,
       FusedBatchNorm, attrs as {} as NamedAttrMap);

-  return res.reshape($x.shape);
+  return reshape(res, $x.shape);
 }

 function as1DOr4D(x: Tensor): Tensor4D|Tensor1D {
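The final `reshape(res, $x.shape)` exists because the kernel always runs on a 4D view of x (`xAs4D`), and the result must be folded back to the caller's rank. The observable behavior, sketched with the public API and illustrative values:

```ts
import * as tf from '@tensorflow/tfjs-core';

// batchNorm computes (x - mean) / sqrt(variance + epsilon); the input
// is reshaped to 4D internally and back to its original shape.
const x = tf.tensor1d([0, 10, 20]);
const mean = tf.tensor1d([10, 10, 10]);
const variance = tf.tensor1d([100, 100, 100]);
const y = tf.batchNorm(x, mean, variance);
y.print();             // ~[-1, 0, 1]
console.log(y.shape);  // [3]: same rank as the input
```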
6 changes: 4 additions & 2 deletions tfjs-core/src/ops/broadcast_to.ts
@@ -24,6 +24,8 @@ import {NamedTensorMap} from '../tensor_types';
 import {convertToTensor} from '../tensor_util_env';
 import {Rank, ShapeMap, TensorLike} from '../types';

+import {reshape} from './array_ops';
+import {clone} from './clone';
 import {op} from './operation';

 /**
@@ -58,7 +60,7 @@ function broadcastTo_<R extends Rank>(
     while (newShape.length < shape.length) {
       newShape.unshift(1);
     }
-    input = input.reshape(newShape);
+    input = reshape(input, newShape);
   }

   const inputShape = input.shape;
@@ -74,7 +76,7 @@
   const axes = reps.map((n, i) => n > 1 ? i : -1).filter(i => i >= 0);

   if (axes.length === 0) {
-    return input.clone() as Tensor<R>;
+    return clone(input) as Tensor<R>;
   }

   const forward = (backend: KernelBackend) => backend.tile(input, reps);
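broadcastTo_ pads the input shape with leading 1s, then tiles every axis where the target is larger; `axes` records which dimensions were actually replicated (the gradient later sums over exactly those axes). For example, via the public API:

```ts
import * as tf from '@tensorflow/tfjs-core';

// Broadcasting [3] to [2, 3]: newShape becomes [1, 3], reps [2, 1],
// and axes [0], the only dimension that was replicated.
const input = tf.tensor1d([1, 2, 3]);
tf.broadcastTo(input, [2, 3]).print();  // [[1, 2, 3], [1, 2, 3]]
```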
8 changes: 5 additions & 3 deletions tfjs-core/src/ops/one_hot.ts
@@ -23,6 +23,7 @@ import {NamedTensorMap} from '../tensor_types';
 import {convertToTensor} from '../tensor_util_env';
 import {TensorLike} from '../types';

+import {reshape} from './array_ops';
 import {op} from './operation';

 /**
@@ -55,16 +56,17 @@ function oneHot_(

   const forward: ForwardFunc<Tensor> = (backend, save) => {
     save([$indices]);
-    return backend.oneHot($indices as Tensor1D, depth, onValue, offValue);
+    return reshape(
+        backend.oneHot($indices as Tensor1D, depth, onValue, offValue),
+        outShape);
   };

   const inputs: OneHotInputs = {indices: $indices};
   const attrs: OneHotAttrs = {depth, onValue, offValue};

-  const result = ENGINE.runKernelFunc(
+  return ENGINE.runKernelFunc(
       forward, inputs as unknown as NamedTensorMap, null /* grad */, OneHot,
       attrs as unknown as NamedAttrMap);
-  return result.reshape(outShape);
 }

 export const oneHot = op({oneHot_});
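Here `outShape` is the indices shape with `[depth]` appended; the backend works on flattened indices (note the `$indices as Tensor1D` cast), and the reshape back to `outShape` is what this diff moves inside the forward function. The shape behavior, via the public API:

```ts
import * as tf from '@tensorflow/tfjs-core';

// oneHot appends a depth axis: [2, 2] int32 indices -> [2, 2, 3] output.
const indices = tf.tensor2d([[0, 2], [1, 1]], [2, 2], 'int32');
const encoded = tf.oneHot(indices, 3);
console.log(encoded.shape);  // [2, 2, 3]
encoded.print();
```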
