Skip to content

Commit

Permalink
Add GivenTensorInt64Fill on gpu
Browse files Browse the repository at this point in the history
Summary: Before we fix it properly with 'type' argument.

Reviewed By: bddppq

Differential Revision: D6103973

fbshipit-source-id: 8c00a93c373dd0ad0bbfe59944495f6574223ab6
  • Loading branch information
Dmytro Dzhulgakov authored and facebook-github-bot committed Oct 20, 2017
1 parent a78008f commit 40210af
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 1 deletion.
1 change: 1 addition & 0 deletions caffe2/operators/given_tensor_fill_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ NO_GRADIENT(GivenTensorDoubleFill);
NO_GRADIENT(GivenTensorBoolFill);
NO_GRADIENT(GivenTensorIntFill);
NO_GRADIENT(GivenTensorInt64Fill);
NO_GRADIENT(GivenTensorStringFill);

OPERATOR_SCHEMA(GivenTensorFill)
.NumInputs(0, 1)
Expand Down
3 changes: 3 additions & 0 deletions caffe2/operators/given_tensor_fill_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ REGISTER_CUDA_OPERATOR(
GivenTensorDoubleFill,
GivenTensorFillOp<double, CUDAContext>);
REGISTER_CUDA_OPERATOR(GivenTensorIntFill, GivenTensorFillOp<int, CUDAContext>);
REGISTER_CUDA_OPERATOR(
GivenTensorInt64Fill,
GivenTensorFillOp<int64_t, CUDAContext>);
REGISTER_CUDA_OPERATOR(
GivenTensorBoolFill,
GivenTensorFillOp<bool, CUDAContext>);
Expand Down
3 changes: 2 additions & 1 deletion caffe2/python/operator_test/given_tensor_fill_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ class TestGivenTensorFillOps(hu.HypothesisTestCase):
t=st.sampled_from([
(core.DataType.FLOAT, np.float32, "GivenTensorFill"),
(core.DataType.INT32, np.int32, "GivenTensorIntFill"),
(core.DataType.INT64, np.int64, "GivenTensorInt64Fill"),
(core.DataType.BOOL, np.bool_, "GivenTensorBoolFill"),
(core.DataType.DOUBLE, np.double, "GivenTensorDoubleFill"),
]),
**hu.gcs_cpu_only)
**hu.gcs)
def test_given_tensor_fill(self, X, t, gc, dc):
X = X.astype(t[1])
print('X: ', str(X))
Expand Down

0 comments on commit 40210af

Please sign in to comment.