Skip to content

Commit

Permalink
Resume unit testing.
Browse files Browse the repository at this point in the history
  • Loading branch information
qingqing01 committed Nov 13, 2017
1 parent 884ce5d commit e9082bb
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 37 deletions.
2 changes: 0 additions & 2 deletions paddle/operators/cross_entropy_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ template <typename T>
__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
const int64_t* label, const int N,
const int D) {
// TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
// CUDA_1D_KERNEL_LOOP(i, N) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
int idx = i * D + label[i];
Expand Down
6 changes: 3 additions & 3 deletions paddle/operators/math/math_function.cu
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ void axpy<platform::GPUPlace, float>(const platform::DeviceContext& context,
PADDLE_ENFORCE(platform::dynload::cublasSaxpy(
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.cublas_handle(),
n, alpha, x, 1, y, 1));
n, &alpha, x, 1, y, 1));
}

template <>
Expand All @@ -250,7 +250,7 @@ void axpy<platform::GPUPlace, double>(const platform::DeviceContext& context,
PADDLE_ENFORCE(platform::dynload::cublasDaxpy(
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.cublas_handle(),
n, alpha, x, 1, y, 1));
n, &alpha, x, 1, y, 1));
}

template struct SetConstant<platform::GPUPlace, float>;
Expand All @@ -270,7 +270,7 @@ DEFINE_GPU_TRANS(6);

struct TensorSetConstantGPU {
TensorSetConstantGPU(const platform::DeviceContext& context,
framework::Tensor* tensor, float value)
framework::Tensor* tensor, float value)
: context_(context), tensor_(tensor), value_(value) {}

template <typename T>
Expand Down
2 changes: 0 additions & 2 deletions paddle/operators/sequence_conv_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,8 @@ class SequenceConvKernel : public framework::OpKernel<T> {
padding_trainable, context_start, context_length,
context_stride, up_pad, down_pad);

context.device_context().Finish();
math::matmul<Place, T>(context.device_context(), col, false, filter, false,
static_cast<T>(1.0), out, static_cast<T>(0.0));
context.device_context().Finish();
}
};

Expand Down
3 changes: 1 addition & 2 deletions python/paddle/v2/framework/tests/test_lstm_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ def test_check_grad(self):
['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=5e-4)


"""
class TestLstmOpHasInitial(TestLstmOp):
def set_argument(self):
self.lod = [[0, 2, 5, 7]]
Expand Down Expand Up @@ -281,7 +280,7 @@ def set_argument(self):
self.has_initial_state = False
self.is_reverse = True
self.use_peepholes = False
"""


if __name__ == '__main__':
unittest.main()
57 changes: 29 additions & 28 deletions python/paddle/v2/framework/tests/test_seq_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def test_check_grad_padding_data(self):
max_relative_error=0.05,
no_grad_set=set(['X', 'Filter']))

def not_test_check_grad_Filter(self):
def test_check_grad_Filter(self):
self.check_grad(
['Filter'],
'Out',
Expand Down Expand Up @@ -165,33 +165,34 @@ def init_test_case(self):
self.output_represention = 8 # output feature size


#class TestSeqProjectCase1(TestSeqProject):
# def init_test_case(self):
# self.input_row = 11
# self.context_start = -1
# self.context_length = 3
# self.padding_trainable = True
# self.context_stride = 1
#
# self.input_size = [self.input_row, 23]
# self.lod = [[0, 4, 5, 8, self.input_row]]
# self.output_represention = 8 # output feature size
#
#
#class TestSeqProjectCase2(TestSeqProject):
# def init_test_case(self):
# self.input_row = 25
# self.context_start = 2
# self.context_length = 3
# self.padding_trainable = True
# self.context_stride = 1
#
# self.input_size = [self.input_row, 23]
# idx = range(self.input_size[0])
# del idx[0]
# self.lod = [[0] + np.sort(random.sample(idx, 8)).tolist() +
# [self.input_size[0]]]
# self.output_represention = 8 # output feature size
class TestSeqProjectCase1(TestSeqProject):
def init_test_case(self):
self.input_row = 11
self.context_start = -1
self.context_length = 3
self.padding_trainable = True
self.context_stride = 1

self.input_size = [self.input_row, 23]
self.lod = [[0, 4, 5, 8, self.input_row]]
self.output_represention = 8 # output feature size


class TestSeqProjectCase2(TestSeqProject):
def init_test_case(self):
self.input_row = 25
self.context_start = 2
self.context_length = 3
self.padding_trainable = True
self.context_stride = 1

self.input_size = [self.input_row, 23]
idx = range(self.input_size[0])
del idx[0]
self.lod = [[0] + np.sort(random.sample(idx, 8)).tolist() +
[self.input_size[0]]]
self.output_represention = 8 # output feature size


if __name__ == '__main__':
unittest.main()

0 comments on commit e9082bb

Please sign in to comment.