Skip to content

Commit

Permalink
added unpooling gradient class (#2169)
Browse files Browse the repository at this point in the history
  • Loading branch information
Okai Addy committed Feb 14, 2018
1 parent fcd53a7 commit f6c63ff
Showing 1 changed file with 82 additions and 7 deletions.
89 changes: 82 additions & 7 deletions tensorflow/core/kernels/unpooling_op.cc
Expand Up @@ -19,10 +19,10 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
namespace tensorflow {

template <typename Device, typename T>
struct LaunchMaxUnpool;
struct LaunchUnpool;

template <typename T>
struct LaunchMaxUnpool<CPUDevice,T>
struct LaunchUnpool<CPUDevice,T>
{
typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> EigenMatrixMap;

Expand Down Expand Up @@ -64,16 +64,66 @@ struct LaunchMaxUnpool<CPUDevice,T>
Shard(workerThreads.num_threads, workerThreads.workers, batchSize, shardCost, shard);

if (!status) {
context->SetStatus(errors::Internal("Failed launching MaxUnpool on CPU"));
context->SetStatus(errors::Internal("Failed launching Unpool on CPU"));
}
}
};

template <typename Device, typename T>
struct MaxUnpoolOp : public OpKernel
struct LaunchUnpoolGradient;

template <typename T>
struct LaunchUnpoolGradient<CPUDevice,T>
{
typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> EigenMatrixMap;

static void launch(tensorflow::OpKernelContext* context, const tensorflow::Tensor& unpooledGradient, const tensorflow::Tensor& indices, tensorflow::Tensor* pooledGradient)
{
bool status = true;

const tensorflow::DeviceBase::CpuWorkerThreads& workerThreads = *(context->device()->tensorflow_cpu_worker_threads());

auto shard = [&unpooledGradient, &indices, &pooledGradient](tensorflow::int64 start, tensorflow::int64 limit)
{
const tensorflow::int64 batchSize = tensorflow::GetTensorDim(unpooledGradient.shape(), tensorflow::FORMAT_NHWC, 'N');
const tensorflow::int64 numPooledPointsPerBatch = pooledGradient->NumElements()/batchSize;

{
auto pooledGradientFlat = pooledGradient->flat<T>();
auto unpooledGradientFlat = unpooledGradient.flat<T>();
auto indicesFlat = indices.flat<tensorflow::int64>();

const tensorflow::int64 pooledStart = start*numPooledPointsPerBatch;
const tensorflow::int64 pooledEnd = limit*numPooledPointsPerBatch;
EigenMatrixMap pooledGradientShard(pooledGradientFlat.data()+pooledStart, 1, pooledEnd-pooledStart);
pooledGradientShard.setConstant(T(0));

for (tensorflow::int64 batch=start; batch<limit; batch++) {
for (tensorflow::int64 batchPooledIndex=0; batchPooledIndex<numPooledPointsPerBatch; batchPooledIndex++) {
const tensorflow::int64 pooledIndex = batch*numPooledPointsPerBatch + batchPooledIndex;
CHECK(pooledIndex<batchSize*numPooledPointsPerBatch) << "pooled index out of range: " << pooledIndex << ">=" << batchSize*numPooledPointsPerBatch;
const tensorflow::int64 unpooledIndex = indicesFlat(pooledIndex);
pooledGradientFlat(pooledIndex) += unpooledGradientFlat(unpooledIndex);
}
}
}
};

const int batchSize = tensorflow::GetTensorDim(unpooledGradient.shape(), tensorflow::FORMAT_NHWC, 'N');
const tensorflow::int64 shardCost = unpooledGradient.shape().num_elements();
tensorflow::Shard(workerThreads.num_threads, workerThreads.workers, batchSize, shardCost, shard);

if (!status) {
context->SetStatus(tensorflow::errors::Internal("Failed launching Unpool on CPU"));
}
}
};

template <typename Device, typename T>
struct UnpoolOp : public OpKernel
{
public:
explicit MaxUnpoolOp(OpKernelConstruction* context) : OpKernel(context)
explicit UnpoolOp(OpKernelConstruction* context) : OpKernel(context)
{}

void Compute(OpKernelContext* context) override
Expand Down Expand Up @@ -113,12 +163,37 @@ struct MaxUnpoolOp : public OpKernel
Tensor* unpooledData = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, unpoolShape, &unpooledData));

LaunchMaxUnpool<Device,T>::launch(context, pooledData, indices, unpooledData);
LaunchUnpool<Device,T>::launch(context, pooledData, indices, unpooledData);
}
private:
std::vector<int32> m_unpoolShape;
};

REGISTER_KERNEL_BUILDER(Name("Unpool").Device(tensorflow::DEVICE_CPU), MaxUnpoolOp<CPUDevice, float>)
// Kernel computing the gradient of the Unpool op.
//
// Input 0 is the gradient with respect to the unpooled output, input 1 holds
// the indices. Output 0 is allocated with the same shape as the indices and is
// filled by LaunchUnpoolGradient for the given Device and element type T.
template <typename Device, typename T>
struct UnpoolGradientOp : public tensorflow::OpKernel
{
public:
    explicit UnpoolGradientOp(tensorflow::OpKernelConstruction* context)
        : tensorflow::OpKernel(context)
    {
    }

    void Compute(tensorflow::OpKernelContext* context) override
    {
        const tensorflow::Tensor& gradientIn = context->input(0);
        const tensorflow::Tensor& gatherIndices = context->input(1);

        // Only proceed while the context is still in a good state.
        if (context->status().ok()) {
            // The pooled gradient has one element per index.
            tensorflow::TensorShape outputShape = gatherIndices.shape();
            tensorflow::Tensor* gradientOut = nullptr;
            OP_REQUIRES_OK(context, context->allocate_output(0, outputShape, &gradientOut));

            LaunchUnpoolGradient<Device,T>::launch(context, gradientIn, gatherIndices, gradientOut);
        }
    }
};

// Kernel registrations: both ops are registered for the CPU device only and
// are instantiated for float.
// "Unpool" is implemented by UnpoolOp.
REGISTER_KERNEL_BUILDER(Name("Unpool").Device(tensorflow::DEVICE_CPU), UnpoolOp<CPUDevice, float>)
// "UnpoolGradient" is implemented by UnpoolGradientOp.
REGISTER_KERNEL_BUILDER(Name("UnpoolGradient").Device(tensorflow::DEVICE_CPU), UnpoolGradientOp<CPUDevice, float>)

}

0 comments on commit f6c63ff

Please sign in to comment.