
Add FillOp #3505

Closed
wants to merge 1 commit into from

Conversation

reyoung
Collaborator

@reyoung reyoung commented Aug 15, 2017

* Fill Op will fill a tensor with a specific shape and data each time `Run` is invoked, unless `run_once` is True.
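The intended semantics can be sketched in plain C++ (a hypothetical stand-in for illustration, not the actual Paddle kernel; `Filler` and its members are invented names):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical model of the FillOp behavior described above.
struct Filler {
  std::vector<int> shape;    // attribute "shape"
  std::vector<float> data;   // attribute "data"
  bool run_once = false;     // attribute "run_once"
  bool has_run = false;

  void Run(std::vector<float>* out) {
    if (run_once && has_run) return;  // skip the refill after the first Run
    std::size_t numel = 1;
    for (int d : shape) numel *= static_cast<std::size_t>(d);
    assert(numel == data.size() && "shape must match data length");
    *out = data;  // fill the output tensor with the attribute data
    has_run = true;
  }
};
```

With `run_once` set, a second `Run` leaves the output untouched.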
platform::GPUPlace src_place,
                   const void* src, size_t num) {
  platform::SetDeviceId(src_place.device);
  platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
Member


Maybe we do not need a synchronous Copy here; Copy works on a specific CUDA stream too. If we really want to synchronize the copy:

Copy(dst_place, dst, src_place, src, num, stream_);
cudaStreamSynchronize(stream_);

For now we only have the default stream (I am fixing that in #3497), so you can pass 0 as the CUDA stream.

Collaborator Author


It is very strange that invoking some of the copy methods in memory.h triggers a link error at compile time.

This is hard to debug for a developer who is not familiar with C++, templates, and memory.{h/cc}.

So we should implement Copy correctly in memory.{h/cc}. Whether to add a stream is the developer's choice.

.SetDefault(false)
.InEnum({true, false});
AddAttr<std::vector<int>>("shape", "The shape of the tensor to fill");
AddAttr<std::vector<T>>("data", "The data to fill the tensor with");
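A rough model of how the `SetDefault`/`InEnum` chain above might behave (the `Attr` type here is a hypothetical stand-in, not Paddle's actual attribute machinery):

```cpp
#include <cassert>
#include <set>

// Hypothetical attribute holder: a default value plus an enum
// constraint, in the spirit of SetDefault(false).InEnum({true, false}).
template <typename T>
struct Attr {
  T value{};
  Attr& SetDefault(T v) {
    value = v;  // record the default
    return *this;
  }
  Attr& InEnum(const std::set<T>& allowed) {
    // Reject values outside the allowed set at registration time.
    assert(allowed.count(value) && "attribute value not in enum");
    return *this;
  }
};
```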
Member


Please have a look at #2917.
There are mainly two ways to load data: the first is loading from a vector or a numpy array; the second is data generated by Paddle itself.
Will we have another method like FeedVariable (caffe2 has FeedBlob)?

Collaborator Author


The fill_op is part of the topology, and it does not conflict with FeedVariable.

Consider the minus operator's gradient: it is a combination of operators, namely

  • An Identity or Copy operator.
  • A Fill operator to fill a scalar with -1, and a Scale operator.
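The combined gradient described above can be sketched with plain vectors (`Identity`, `Fill`, and `Scale` here are hypothetical stand-ins for the operators, not Paddle code): for out = a - b, da is the identity of dout, and db is dout scaled by the filled scalar -1.

```cpp
#include <vector>

// Hypothetical stand-ins for the three operators in the comment.
std::vector<float> Identity(const std::vector<float>& x) { return x; }
std::vector<float> Fill(float v) { return {v}; }  // FillOp: a scalar, here -1
std::vector<float> Scale(const std::vector<float>& x, float s) {
  std::vector<float> y = x;
  for (auto& e : y) e *= s;  // elementwise scale
  return y;
}

// Gradient of out = a - b: da = dout, db = -dout,
// expressed as the composition Identity / Fill / Scale.
void MinusGrad(const std::vector<float>& dout,
               std::vector<float>* da, std::vector<float>* db) {
  *da = Identity(dout);
  *db = Scale(dout, Fill(-1.0f)[0]);
}
```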

Collaborator Author


This is not about loading data; it is about designing the topology.

namespace paddle {
namespace operators {
template <typename T>
class FillOpKernelBase : public framework::OpKernel {
Member

@QiJune QiJune Aug 16, 2017


Maybe the base class FillOpKernelBase is a little too complex; just implementing the data fill directly in FillOpGPUKernel and FillOpCPUKernel would be fine.

Collaborator Author


There are common lines of code shared between the CPU and GPU kernels. Making a base class lets that code be shared.
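One way the shared base class might look, sketched without CUDA so it stays self-contained (the `CopyToDevice` hook and the in-memory "GPU" buffer are illustrative assumptions, not the PR's actual code; real kernels would dispatch to device memory):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// The base class owns the common work (shape/data validation), and each
// device kernel overrides only the device-specific final copy.
template <typename T>
class FillOpKernelBase {
 public:
  virtual ~FillOpKernelBase() = default;
  void Compute(const std::vector<int>& shape, const std::vector<T>& data,
               std::vector<T>* out) {
    std::size_t numel = 1;
    for (int d : shape) numel *= static_cast<std::size_t>(d);
    assert(numel == data.size() && "shape must match data length");
    CopyToDevice(data, out);  // only this step differs per device
  }

 protected:
  virtual void CopyToDevice(const std::vector<T>& src,
                            std::vector<T>* dst) = 0;
};

template <typename T>
class FillOpCPUKernel : public FillOpKernelBase<T> {
 protected:
  void CopyToDevice(const std::vector<T>& src,
                    std::vector<T>* dst) override {
    *dst = src;  // plain host-to-host copy
  }
};

template <typename T>
class FillOpGPUKernel : public FillOpKernelBase<T> {
 protected:
  void CopyToDevice(const std::vector<T>& src,
                    std::vector<T>* dst) override {
    *dst = src;  // real code would use memory::Copy / GpuMemcpy here
  }
};
```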

@luotao1
Contributor

luotao1 commented Feb 1, 2019

Closing, since fill_op has already been implemented.

@luotao1 luotao1 closed this Feb 1, 2019