This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[WIP][MXNET-107] Fused LSTM implementation for CPU #10104

Merged
merged 39 commits into from May 14, 2018

Conversation


@chenchu-zs chenchu-zs commented Mar 14, 2018

Description

In this PR, a fused LSTM operator for CPU is implemented. Support for other RNN variants is in progress and will be submitted in separate PRs.

Feature changes

New features

  • Fused LSTM implementation, including both forward and backward computation.
  • Shares the same frontend interface as the current sym.RNN operator (see the usage sketch below).
  • Shares the same algorithm and input layout as the current sym.RNN operator.
  • Refactors the code and registers it with the NNVM interfaces.
  • Supports both FP32 and FP64 inputs; more types, such as int8, will be supported later.
  • Provides more extensible APIs for other RNN variants (vanilla RNN/GRU).
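
Usage sketch (not code from the PR itself): the fused kernels are reached through the existing sym.RNN frontend with mode='lstm'. The shapes below follow the DS2 sizes used in the benchmark, and the flat parameter size is an assumption based on the standard 4-gate LSTM packing.

```python
import mxnet as mx

T, N, I, H = 300, 20, 800, 800                     # seq_len, batch, input, hidden
num_params = 4 * H * (I + H + 2)                   # assumed single-layer, unidirectional packing

lstm = mx.sym.RNN(data=mx.sym.Variable('data'),
                  parameters=mx.sym.Variable('params'),
                  state=mx.sym.Variable('state'),
                  state_cell=mx.sym.Variable('state_cell'),
                  state_size=H, num_layers=1, mode='lstm',
                  state_outputs=True, name='lstm')

# Bind on CPU: with this PR the forward pass runs through the fused kernels.
exe = lstm.simple_bind(ctx=mx.cpu(), data=(T, N, I), params=(num_params,),
                       state=(1, N, H), state_cell=(1, N, H))
exe.forward(is_train=False)
print(exe.outputs[0].shape)                        # (T, N, H)
```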

Unit-test changes

  • Create new tests test_lstm and test_lstm_bidirectional in tests/python/unittest/test_operator.py.
  • Check consistency with the original LSTMCell implementation.

Performance

We tested the performance of sym.RNN and rnn.LSTMCell on a local Skylake-8180 machine (2 sockets, 56 cores), using MKL as the BLAS library. The test input sizes come from the DS2 default parameters (seq_length = 300, batch_size = 20, input_size = 800, hidden_size = 800). A rough timing sketch follows the tables below.

single-layer measurement:

| API | Inference time (fwd, sec) | Training time (fwd + bwd, sec) |
| --- | --- | --- |
| rnn.LSTMCell | 0.106902 | 0.273108 |
| #9977 | 0.050126 | --- |
| this PR | 0.050668 | 0.130266 |
| speedup | 2.1x | 2.1x |

multi-layer measurement (num_layers = 5):

| API | Inference time (fwd, sec) | Training time (fwd + bwd, sec) |
| --- | --- | --- |
| rnn.LSTMCell | 0.532034 | 1.546486 |
| sym.RNN (#9977) | 0.18641 | --- |
| sym.RNN (this PR) | 0.190032 | 0.619439 |
| rnn.LSTMCell (cuda) | 0.231355 | 0.785780 |
| sym.RNN (cudnn) | 0.060647 | 0.161115 |
| speedup: #10104 / LSTMCell | 285.41% | 249.66% |
| speedup: #10104 / LSTMCell (cuda) | 124.09% | 126.85% |
| speedup: #10104 / sym.RNN (cudnn) | 32.53% | 26.01% |
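
For reference, a rough way to reproduce this kind of forward-time measurement; the PR does not include the benchmarking script, so the warm-up, iteration count, and initialization below are assumptions.

```python
import time
import mxnet as mx

T, N, I, H = 300, 20, 800, 800                     # DS2 default sizes from above
lstm = mx.sym.RNN(data=mx.sym.Variable('data'), state_size=H, num_layers=1,
                  mode='lstm', name='lstm')
exe = lstm.simple_bind(ctx=mx.cpu(), data=(T, N, I))
for arr in exe.arg_arrays:
    arr[:] = 0.01                                  # initialize data and parameters

exe.forward(is_train=False)[0].wait_to_read()      # warm-up
runs = 10
start = time.time()
for _ in range(runs):
    exe.forward(is_train=False)[0].wait_to_read()
print('fwd time per iteration: %.6f sec' % ((time.time() - start) / runs))
```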

Opens

  • Fix cuDNN registration in this PR
  • Add multi-layer and bidirectional support for LSTM
  • Support Gluon interfaces
  • Fix NNVM registration
  • Other RNN variants (will be added in other PRs)
  • Add dropout support (in other PRs)

Checklist

Essentials

  • Passed code style checking (make lint)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on the test set, and a reference to the original paper if applicable
  • To my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change


TaoLv commented Mar 14, 2018

@szha @Jerryzcn @eric-haibin-lin @pengzhao-intel Could you help review this PR? We need cooperation to refactor the cuDNN registration.

size_t size = 0;
switch (mode) {
case rnn_enum::kRnnRelu:
break;
Member

Need error message for unimplemented modes.

size = (seq_length + 1) * batch_size * hidden_size * 4 + batch_size * hidden_size;
break;
case rnn_enum::kGru:
break;
Member

Add a default statement for code robustness.

w_ptr, y_ptr, hy_ptr, cy_ptr);
break;
case rnn_enum::kGru:
break;
Member

Also needs an error message for unimplemented modes and a default statement for the switch-case.

@@ -19,40 +19,214 @@

/*!
* Copyright (c) 2015 by Contributors
* \file rnn.cc
* \file rnn.cc
Member

remove this change

* \brief
* \author Sebastian Bodenstein
* \author Sebastian Bodenstein, Shu Zhang(shu.zhang@intel.com)
Member

remove whitespace

}
}
static inline int NumVisibleOutputs(const NodeAttrs& attrs) {
const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
Member

fix indents in this function.

MXNET_REGISTER_OP_PROPERTY(RNN, RNNProp)
.describe("Applies a recurrent layer to input.")
inline static bool RNNStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
Member

fix indent. Align with the first parameter.


inline static bool BackwardRNNStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
Member

fix indent. Align with the first parameter.


TaoLv commented Mar 14, 2018

It seems the collapse clause of omp parallel for is not supported on Windows.

wh = mx.random.uniform(-1, 1, (4 * H, H), ctx=xpu,dtype=type1)
bx = mx.nd.zeros((4 * H,), ctx=xpu, dtype=type1)
bh = mx.nd.zeros((4 * H,), ctx=xpu, dtype=type1)
x1.attach_grad()
Member

why do you need to manually attach grad??

Contributor Author

attach_grad is used to create gradient buffers for these NDArrays here. Do you mean this can be implemented in another way, or do you have any suggestion about this piece of code?
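
For context, a minimal illustration (assumed, not code from this PR) of what attach_grad does here: it allocates the gradient buffer that autograd.backward later writes into.

```python
import mxnet as mx
from mxnet import autograd

x = mx.nd.random.uniform(-1, 1, (4, 4))
x.attach_grad()                  # allocate a gradient buffer for x
with autograd.record():
    y = (x * x).sum()
y.backward()                     # the gradient is written into x.grad (here 2 * x)
print(x.grad)
```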

Contributor

In case we use a stateful OP, what's your opinion @eric-haibin-lin?

@@ -1540,6 +1548,7 @@ def check_rnn_layer_w_rand_inputs(layer):
for g, c in zip(gs, cs):
assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6)

@unittest.skip("test fails intermittently. temporarily disabled till it gets fixed.")
Member

Why is it failing?

Member

If USE_CUDNN=1, I think this test will run into the cudnn implementation, which has been disabled temporarily. We will re-enable this test case after we add cudnn back. In fact, the GPU build is currently failing; we are working on it.

};

NNVM_REGISTER_OP(RNN)
.describe(R"code(Applies a recurrent layer to input
Member

Please provide more detailed descriptions

Member

okay, will do.

DType* reserve_space_ptr = out_data[out_expected - 1].dptr<DType>();

// allocate temp space
size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
Member

nit: const

Tensor<cpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
.get_space_typed<cpu, 1, DType>(Shape1(workspace_size), s);

int direction = param_.bidirectional ? 2 : 1;
Member

nit: const

@Jerryzcn
Contributor

I used

for (int64_t ji = 0; ji < length; ++ji) {
      int64_t j = ji / h_channel;  // batch dim
      int64_t i = ji % h_channel;

to replace collapse


TaoLv commented Mar 15, 2018

@Jerryzcn Good suggestion. I will take a try.


TaoLv commented Mar 15, 2018

BTW, is there any existing JIRA issue for the RNN implementation? Do I need to create a JIRA issue for this PR? @eric-haibin-lin @Jerryzcn @szha

@eric-haibin-lin
Member

pls create one

@chenchu-zs chenchu-zs changed the title [WIP] Fused RNN implementation for CPU [WIP][MXNET-107] Fused RNN implementation for CPU Mar 15, 2018
@Jerryzcn
Contributor

Should we create a separate branch for the CPU RNN? Once all the changes are checked in, we merge the RNN branch into master. This way master won't break people's code.

@piiswrong
Contributor

@Jerryzcn good idea. @eric-haibin-lin please open a branch

@szha szha changed the base branch from master to cpu_fused_rnn March 16, 2018 07:14

szha commented Mar 16, 2018

Happened to be around on GitHub. I created the branch cpu_fused_rnn and updated the PR base.


TaoLv commented Mar 16, 2018

Thanks @szha, will keep working on this.

@pengzhao-intel
Contributor

@sherry-zhang Good job!
Please update this PR's description for the multi-layer and bidirectional functionality 👍

@chenchu-zs chenchu-zs force-pushed the rnn_refactor branch 2 times, most recently from 1ee9ee3 to dde8e23 Compare March 22, 2018 02:40

TaoLv commented Mar 22, 2018

@marcoabreu I am working on branch cpu_fused_rnn, but CI fails in the sanity check. I suspect the CI environment has been adjusted for the master branch, so this branch cannot work properly. Could you help take a look? Thanks.

The pylint check passes on my local server but fails in the sanity check:

Makefile:479: recipe for target 'pylint' failed
make: *** [pylint] Error 22
build.py: 2018-03-22 04:55:51,746 Running of command in container failed: docker run --rm -v /home/jenkins_slave/workspace/sanity:/work/mxnet -v /home/jenkins_slave/workspace/sanity/build:/work/build -u 1001:1001 mxnet/build.ubuntu_cpu /work/runtime_functions.sh sanity_check

build.py: 2018-03-22 04:55:51,746 You can try to get into the container by using the following command: docker run --rm -v /home/jenkins_slave/workspace/sanity:/work/mxnet -v /home/jenkins_slave/workspace/sanity/build:/work/build -u 1001:1001 -ti --entrypoint bash mxnet/build.ubuntu_cpu /work/runtime_functions.sh sanity_check

Traceback (most recent call last):
  File "ci/build.py", line 179, in <module>
    sys.exit(main())
  File "ci/build.py", line 159, in main
    container_run(platform, docker_binary, command)
  File "ci/build.py", line 110, in container_run
    raise subprocess.CalledProcessError(ret, cmd)
subprocess.CalledProcessError: Command 'docker run --rm -v /home/jenkins_slave/workspace/sanity:/work/mxnet -v /home/jenkins_slave/workspace/sanity/build:/work/build -u 1001:1001 mxnet/build.ubuntu_cpu /work/runtime_functions.sh sanity_check' returned non-zero exit status 2

script returned exit code 1

@piiswrong
Contributor

Looks like this is not going to make it into 1.2.
Can we revert the LSTM forward part that's already merged into master so that we don't ship a half-baked feature?


szha commented Mar 22, 2018

@piiswrong the merged RNN feature supports an inference-only LSTM that is compatible with the cudnn implementation. The Gluon LSTM layer now supports inference-only forwarding with this feature, and the rest of the use cases are still on the old code paths, thanks to @Jerryzcn. The merged PR does what it sets out to do better than what previously existed, so it's more than half-baked.


szha commented Mar 22, 2018

cc @zhiheng-huang as his team will likely be impacted by the decision of reverting Jerry's PR.


marcoabreu commented Mar 22, 2018

For now, please fork the master branch in your own repository and let collaborators make PRs towards your repository. At the same time, create a PR from your fork towards the master branch to have constant feedback every time a commit to your fork is being made.


TaoLv commented Mar 22, 2018

@szha @piiswrong May I have your opinions? I don't have permission to create/delete branches and redirect this PR to the master branch. I can rebase the code onto master if needed.

@marcoabreu marcoabreu changed the base branch from cpu_fused_rnn to master March 22, 2018 14:06
@marcoabreu
Contributor

I have changed the base branch as requested. We are currently having an internal discussion about whether we support feature branches in the official repository; until then, it is better to work against master to ensure your PR always receives the latest updates. I have also retriggered CI.


TaoLv commented Mar 22, 2018

Thanks, @marcoabreu! Really understand your concern. Will keep working on this PR.

@marcoabreu
Contributor

Thanks a lot! Please excuse the inconvenience - in case of further problems, feel free to ping me again.


TaoLv commented May 8, 2018

I find it difficult to change the existing Gluon LSTM layer from a normal Block to a HybridBlock without changing its APIs.
(1) I need to concatenate the existing i2h_weight, h2h_weight, i2h_bias and h2h_bias to feed them into the fused operator, which I think is time consuming (see the sketch below). link
(2) I cannot create begin_state if it is not provided in the hybrid_forward function, since I cannot get the shape and batch size inside a HybridBlock. link
Maybe I missed something. Any cues about it? @szha @piiswrong
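
For reference, a sketch of the concatenation mentioned in point (1). The flat layout (weights first, then biases) and the helper name fuse_lstm_params are assumptions following the cuDNN-style packing that sym.RNN expects, not code from this PR.

```python
import mxnet as mx

def fuse_lstm_params(i2h_weight, h2h_weight, i2h_bias, h2h_bias):
    """Pack separate per-gate LSTM parameters into one flat vector for sym.RNN."""
    return mx.nd.concat(i2h_weight.reshape((-1,)), h2h_weight.reshape((-1,)),
                        i2h_bias, h2h_bias, dim=0)

I, H = 8, 16
fused = fuse_lstm_params(mx.nd.zeros((4 * H, I)), mx.nd.zeros((4 * H, H)),
                         mx.nd.zeros((4 * H,)), mx.nd.zeros((4 * H,)))
assert fused.shape == (4 * H * (I + H + 2),)       # single layer, unidirectional
```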

@piiswrong
Contributor

@TaoLv I'll look into this later. I think you can do it similarly to RNNCell.
Let's merge the backend LSTM implementation first. Are there enough tests?

out = exe.forward(is_train=False)
out[0].wait_to_read()
assert False # should not reach here
except mx.base.MXNetError as err:
Contributor

Excellent approach! This will ensure we don't forget to re-enable the test when we introduce dropout. Great job.

Member
@TaoLv TaoLv May 9, 2018

Yes. Also to ensure the failure happens at the proper position and the correct error message is presented, following @reminisce 's idea in PR 10844.
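
A hedged sketch of that test pattern (shapes, dropout value, and the test name are illustrative, not the exact test code): request an option the operator does not support yet and assert that execution raises MXNetError.

```python
import mxnet as mx

def test_lstm_dropout_not_supported():
    T, N, I, H = 3, 2, 4, 4
    sym = mx.sym.RNN(data=mx.sym.Variable('data'), state_size=H, num_layers=2,
                     mode='lstm', p=0.5, name='lstm')   # dropout requested
    try:
        exe = sym.simple_bind(ctx=mx.cpu(), data=(T, N, I))
        exe.forward(is_train=False)
        exe.outputs[0].wait_to_read()
        assert False                                    # should not reach here
    except mx.base.MXNetError:
        pass                                            # expected until dropout is implemented
```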


TaoLv commented May 9, 2018

@piiswrong I added a test for dropout. I think this LSTM operator is good to merge. Dropout support and the hybrid RNN layer are WIP and will be submitted in another PR. I will also rebase #10311 accordingly.
If you have any ideas or a design conception for the hybrid RNN layer, please let me know.

@szha szha added this to In progress in gluon.rnn improvements May 9, 2018
@szha szha changed the title [WIP][MXNET-107] Fused RNN implementation for CPU [WIP][MXNET-107] Fused LSTM implementation for CPU May 9, 2018
@@ -0,0 +1,454 @@
/*
Contributor

we don't use hpp. Please rename to .h

Member

fixed

for (int i = 0; i < T; ++i) {
int t = bid ? T - 1 - i : i;
linalg_gemm(i ? h : hx, wh, yh_flat, alpha, beta, false, true);
#pragma omp parallel for
Contributor

Member

fixed

}
}
}
memcpy(y_ptr, rs + y_offset, T * N * H * D * sizeof(DType));
Contributor

why is this copy needed?

Member

One copy is for the forward output and the other is kept for reuse in the backward computation.

for (int i = 0; i < T; ++i) {
int t = bid ? T - 1 - i : i;
linalg_gemm(i ? h : hx, wh, yh_flat, alpha, beta, false, true);
#pragma omp parallel for
Contributor

same

Member

fixed

const Tensor<cpu, 2, DType>& dcnext = i ? dc : dcx;
const Tensor<cpu, 2, DType>& hnext = i ? htmp : hx;
const Tensor<cpu, 2, DType>& cnext = i ? c[i - 1] : cx;
#pragma omp parallel for
Contributor

same

Member

fixed

const int row = T * N;
const int col = H * 4;
for (int i = 0; i < row; ++i) {
#pragma omp parallel for
Contributor

same

Contributor

omp usage may not be efficient here. The operations in this loop are very simple, while col is usually less than a few thousand.

Member

You are right. I will remove this omp temporarily and look for a better optimization for this piece of code.

const DType beta1 = 1.0;
const int cell_size = N * H;
if (dhy_ptr != NULL) {
memcpy(dh.dptr_, dhy_ptr, cell_size * sizeof(DType));
Contributor

why are these copies needed?

data = mx.sym.Variable('data')

Y1, _ = cell1.unroll(T, data, layout='NTC', merge_outputs=True)
mod1 = mx.mod.Module(Y1, label_names=None, context=mx.cpu())
Contributor

use default_context() here and remove the corresponding tests in test_operator_gpu. These tests will automatically be run again in test_operator_gpu with default_context() = gpu()

Member

Fixed. Also, I changed the name of test_lstm to test_lstm_sym since it would conflict with the one in /unittest/test_gluon_rnn.py after being imported into test_operator_gpu.py.
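
For reference, a small sketch of the reviewer's point (the test body is illustrative): writing the test against default_context() lets test_operator_gpu.py re-run it on GPU simply by importing it.

```python
import mxnet as mx
from mxnet.test_utils import default_context

def test_lstm_sym():
    ctx = default_context()          # cpu() here; gpu() when re-imported by test_operator_gpu.py
    sym = mx.sym.RNN(data=mx.sym.Variable('data'), state_size=4, num_layers=1,
                     mode='lstm', name='lstm')
    exe = sym.simple_bind(ctx=ctx, data=(3, 2, 4))
    exe.forward(is_train=False)
    exe.outputs[0].wait_to_read()
```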

@pengzhao-intel
Contributor

@eric-haibin-lin @piiswrong @szha @Jerryzcn the comments are addressed. Please help take another look.
After this PR is merged, we can rebase the GRU PR and add dropout to LSTM/GRU soon.

DType ft = ifgo[i][j][k][1];
DType gt = ifgo[i][j][k][2];
DType ot = ifgo[i][j][k][3];
dh[j][k] += dy[t][j][k + offset];
Contributor

dh and dc are never read before they are overwritten. Why do you need the copy at line 341?

}
}
}
memcpy(y_ptr, rs + y_offset, T * N * H * D * sizeof(DType));
Contributor

why not write to y_ptr directly for the last layer?

}
Tensor<cpu, 2, DType> dyh(difgo[t].dptr_, Shape2(N, H * 4));
linalg_gemm(dyh, wh, dhnext, alpha, beta0, false, false);
linalg_gemm(dyh, hnext, dwh, alpha, beta1, true, false);
Contributor

dwh is overwritten. why do you need to set it to 0 with memset at 328?

@piiswrong
Contributor

I can merge this first. But I think the memset and memcpy statements are superfluous. We should get rid of them later.


@piiswrong piiswrong merged commit 275378a into apache:master May 14, 2018
gluon.rnn improvements automation moved this from In progress to Done May 14, 2018
@szha szha mentioned this pull request May 15, 2018
jinhuang415 pushed a commit to jinhuang415/incubator-mxnet that referenced this pull request May 29, 2018
* register RNN fused-API with nnvm, finish single-layer && undirection LSTM forward function

* fix coding style and lint complains

* add single-layer && undirectional LSTM backward function

* make interface universal for other RNN mode

* share intermediate result between forward and backward in a trick way

* add comments for important parameters

* modify testcase

* Fix coding style and error message

* fix openmp collapse error

* fix const

* remove rnn.cu and skip related testcases temporarily for building on GPU

* support multi-layer and bidirectional for lstm inference

* remove some testcaseS in test_gluon_rnn.py to build on GPU

* remove testcase between fp32 and fp64 temporarily

* retrigger ci

* fix some logs

* use a better way to share memory

* fix cudnn registration

* fix invariant calculations and enable some gpu testcases

* add thread local cache for cudnn rnn op

* add thread local cache for rnn op

* fix bugs

* remove some testcases to check segmentfault

* remove cudnn registeration to check segmentfault

* support multi-layer for LSTM Training

* modify lstm testcase

* add bidirectional support for lstm

* fix gluon and coding style

* fix bugs

* remove nnvm registration

* enable gpu testcases

* add detailed descriptions

* add dropout check

* fix workspace size

* dropout is not supported, add unit test for it

* fix review comments
rahul003 pushed a commit to rahul003/mxnet that referenced this pull request Jun 4, 2018
zheng-da pushed a commit to zheng-da/incubator-mxnet that referenced this pull request Jun 28, 2018
@Roshrini Roshrini mentioned this pull request Aug 8, 2018