From 31df1a92f99ca3b4c7f0868ae29d6a874c3729b9 Mon Sep 17 00:00:00 2001 From: AngusG Date: Fri, 21 Jul 2017 16:13:39 -0400 Subject: [PATCH] tweak benchmarking scripts --- smoke_test.py | 79 ++++++++++++++++++++++++++---------- src/concatenate_kernel.cu.cc | 4 ++ src/gemm_op.cc | 16 +++++++- src/xnor_gemm_kernel.cu.cc | 2 + 4 files changed, 79 insertions(+), 22 deletions(-) diff --git a/smoke_test.py b/smoke_test.py index 98d83e4..51ae086 100644 --- a/smoke_test.py +++ b/smoke_test.py @@ -1,38 +1,75 @@ import time +import numpy as np import tensorflow as tf N = 8192 gemm_module = tf.load_op_library('./libs/gemm_op.so') -A = tf.placeholder(tf.float32, [N, N]) -B = tf.placeholder(tf.float32, [N, N]) +sess = tf.InteractiveSession() + +a = tf.cast( + 2 * (tf.random_normal(shape=[N, N], seed=1).eval() > 0) - 1, tf.float32) + +#b = tf.cast( +# 2 * (tf.random_normal(shape=[N, N], seed=2).eval() > 0) - 1, tf.float32) + +N_RUNS=5 +xnor_timings = np.zeros(N_RUNS) +base_timings = np.zeros(N_RUNS) + +for i in range(N_RUNS): + start_time = time.time() + gemm_module.gemm(a, a).eval() + xnor_timings[i] = time.time() - start_time + print("xnor_gemm %d took %f" % (i, xnor_timings[i])) +print("Avg XNOR kernel execution time over %d runs: %f +/- %f") % ((N_RUNS, xnor_timings.mean(), xnor_timings.std())) + +for i in range(N_RUNS): + start_time = time.time() + print(tf.matmul(a, a).eval()) + base_timings[i] = time.time() - start_time + print("matmul %d took %f" % (i,base_timings[i])) +print("Avg MatMul execution time over %d runs: %f +/- %f") % ((N_RUNS, base_timings.mean(), base_timings.std())) + + +''' +#A = tf.placeholder(tf.float32, [N, N]) +#B = tf.placeholder(tf.float32, [N, N]) + +# For benchmarking on GPU w/only 4GB memory a = 2 * tf.cast(tf.random_normal(shape=[N, N], seed=1) > 0, tf.float32) - 1 -b = 2 * tf.cast(tf.random_normal(shape=[N, N], seed=2) > 0, tf.float32) - 1 -xnor_gemm = gemm_module.gemm(A, B) -matmul = tf.matmul(a, b) +N_RUNS = 5 +xnor_timings = np.zeros(N_RUNS) +base_timings = np.zeros(N_RUNS) with tf.Session() as sess: a_f32 = sess.run(a) - b_f32 = sess.run(b) + #b_f32 = sess.run(b) - ########### benchmark xnor ############ - start_time = time.time() - xnor_gemm_result = sess.run(xnor_gemm, feed_dict={A: a_f32, B: b_f32}) - xnor_gemm_time = time.time() - start_time + for i in range(N_RUNS): + ########### benchmark xnor ############ + start_time = time.time() + #xnor_gemm_result = sess.run(xnor_gemm, feed_dict={A: a_f32, B: b_f32}) + xnor_gemm_result = sess.run(gemm_module.gemm(a_f32, a_f32)) + xnor_timings[i] = time.time() - start_time - print("xnor_gemm took %f" % xnor_time) - print(xnor_gemm_result) - ####################################### + print("xnor_gemm %d took %f" % (i, xnor_timings[i])) + print(xnor_gemm_result) + ####################################### + print("Avg XNOR kernel execution time over %d runs: %f +/- %f" % (N_RUNS, xnor_timings.mean(), xnor_timings.std())) + for i in range(N_RUNS): + ########### benchmark matmul ########## + start_time = time.time() + #matmul_result = sess.run(matmul, feed_dict={A: a_f32, B: b_f32}) + matmul_result = sess.run(tf.matmul(a_f32, a_f32)) + base_timings[i] = time.time() - start_time - ########### benchmark matmul ########## - start_time = time.time() - matmul_result = sess.run(matmul, feed_dict={A: a_f32, B: b_f32}) - matmul_time = time.time() - start_time - - print("matmul took %f" % tf_time) - print(matmul_result) - ####################################### + print("matmul %d took %f" % (i, base_timings[i])) + print(matmul_result) + ####################################### + print("Avg MatMul execution time over %d runs: %f +/- %f" % (N_RUNS, base_timings.mean(), base_timings.std())) +''' \ No newline at end of file diff --git a/src/concatenate_kernel.cu.cc b/src/concatenate_kernel.cu.cc index 9ce3202..53be65b 100644 --- a/src/concatenate_kernel.cu.cc +++ b/src/concatenate_kernel.cu.cc @@ -83,7 +83,9 @@ __global__ void deconcatenate_rows_kernel(int *a, float *b, int size) template struct ConcatenateRowsFunctor { void operator()(const GPUDevice& d, const float* fA, int* Aconc, const int N) { +#ifdef DEBUG printf("\n\nConcatenateRowsFunctor\n\n"); +#endif int block = BLOCK_SIZE * 4, grid = N * N / (block * 32) + 1; concatenate_rows_kernel <<>>(fA, Aconc, N * N / 32); @@ -93,7 +95,9 @@ struct ConcatenateRowsFunctor { template struct ConcatenateColsFunctor { void operator()(const GPUDevice& d, const float* fB, int* Bconc, const int N) { +#ifdef DEBUG printf("\n\nConcatenateColsFunctor\n\n"); +#endif int block = BLOCK_SIZE * 4; int grid = N / block + 1; concatenate_cols_kernel diff --git a/src/gemm_op.cc b/src/gemm_op.cc index b864f4a..365c199 100644 --- a/src/gemm_op.cc +++ b/src/gemm_op.cc @@ -1,4 +1,5 @@ // gemm_op.cc +//#define DEBUG #define EIGEN_USE_THREADS #include @@ -152,15 +153,18 @@ class XnorGemmOp : public OpKernel { Status allocate_temp(DataType type, const TensorShape& shape, Tensor* out_temp); */ - + #ifdef DEBUG printf("\n\nXnorGemmOp -- allocated output\n\n"); + #endif Tensor Aconc;// = nullptr; Tensor Bconc;// = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_INT32, out_shape, &Aconc)); OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_INT32, out_shape, &Bconc)); + #ifdef DEBUG printf("\n\nXnorGemmOp -- allocated temp\n\n"); + #endif if (out->NumElements() == 0) { // If a has shape [0, x] or b has shape [x, 0], the output shape @@ -187,7 +191,9 @@ class XnorGemmOp : public OpKernel { const int32 k = a.dim_size(dim_pair[0].first); const int32 n = b.dim_size(1 - dim_pair[0].second); + #ifdef DEBUG printf("\n\nXnorGemmOp -- created m,n,k\n\n"); + #endif auto a_flat = a.flat().data(); auto b_flat = b.flat().data(); @@ -195,7 +201,9 @@ class XnorGemmOp : public OpKernel { auto Bconc_flat = Bconc.flat().data(); auto c_flat = out->flat().data(); + #ifdef DEBUG printf("\n\nXnorGemmOp -- created a_flat, Aconc_flat\n\n"); + #endif #if 1 ConcatenateRowsFunctor()( @@ -203,8 +211,10 @@ class XnorGemmOp : public OpKernel { a_flat, Aconc_flat, m); + #ifdef DEBUG printf("\n\nXnorGemmOp -- ran ConcatenateRowsFunctor\n\n"); #endif + #endif #if 1 ConcatenateColsFunctor()( @@ -212,8 +222,10 @@ class XnorGemmOp : public OpKernel { b_flat, Bconc_flat, m); + #ifdef DEBUG printf("\n\nXnorGemmOp -- ran ConcatenateColsFunctor\n\n"); #endif + #endif #if 1 XnorGemmFunctor()( @@ -224,8 +236,10 @@ class XnorGemmOp : public OpKernel { m, n, k); + #ifdef DEBUG printf("\n\nXnorGemmOp -- ran XnorGemmFunctor\n\n"); #endif + #endif #if 0 /* For testing base kernel */ XnorGemmFunctor()( diff --git a/src/xnor_gemm_kernel.cu.cc b/src/xnor_gemm_kernel.cu.cc index 5c4daef..0573157 100644 --- a/src/xnor_gemm_kernel.cu.cc +++ b/src/xnor_gemm_kernel.cu.cc @@ -93,7 +93,9 @@ struct XnorGemmFunctor { // // See core/util/cuda_kernel_helper.h for example of computing // block count and thread_per_block count. +#ifdef DEBUG printf("\n\nInt32 input -- using XnorGemmFunctor\n\n"); +#endif /* int block_count = BLOCK_SIZE; int thread_per_block = 512;