Added fast_binarize_weights_gpu()
AlexeyAB committed Nov 5, 2018
Commit 25f65f6 (1 parent: 2c5e383)
Showing 4 changed files with 99 additions and 10 deletions.
63 changes: 58 additions & 5 deletions src/convolutional_kernels.cu
@@ -61,29 +61,80 @@ void binarize_input_gpu(float *input, int n, int size, float *binary)
check_error(cudaPeekAtLastError());
}


__global__ void binarize_weights_kernel(float *weights, int n, int size, float *binary)
{
int f = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if (f >= n) return;
int i = 0;
float mean = 0;
for(i = 0; i < size; ++i){
for (i = 0; i < size; ++i) {
mean += fabs(weights[f*size + i]);
}
mean = mean / size;
for(i = 0; i < size; ++i){
for (i = 0; i < size; ++i) {
binary[f*size + i] = (weights[f*size + i] > 0) ? mean : -mean;
//binary[f*size + i] = weights[f*size + i];
}
}

void binarize_weights_gpu(float *weights, int n, int size, float *binary)
{
binarize_weights_kernel<<<cuda_gridsize(n), BLOCK>>>(weights, n, size, binary);
binarize_weights_kernel << <cuda_gridsize(n), BLOCK >> >(weights, n, size, binary);
check_error(cudaPeekAtLastError());
}

#define WARP_SIZE 32

__global__ void set_zero_kernel(float *src, int size)
{
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < size) src[i] = 0;
}

// butterfly all-reduce: after log2(WARP_SIZE) xor-shuffles, every lane holds the warp-wide sum
__inline__ __device__
float warpAllReduceSum(float val) {
for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2)
val += __shfl_xor(val, mask);
return val;
}

// per-filter mean of |w|: each warp sums 32 contiguous |weights| and lane 0 atomically adds its
// partial mean into mean_arr_gpu[f]; only valid if (size % 32 == 0), so a warp never straddles a filter
__global__ void reduce_kernel(float *weights, int n, int size, float *mean_arr_gpu)
{
int i = blockIdx.x * blockDim.x + threadIdx.x;
int f = i / size;
if (f >= n) return;
float warp_mean = warpAllReduceSum(fabs(weights[i]));
if(i % 32 == 0)
atomicAdd(&mean_arr_gpu[f], warp_mean / size);
}

__global__ void binarize_weights_mean_kernel(float *weights, int n, int size, float *binary, float *mean_arr_gpu)
{
int i = blockIdx.x * blockDim.x + threadIdx.x;
int f = i / size;
if (f >= n) return;
float mean = mean_arr_gpu[f];
binary[i] = (weights[i] > 0) ? mean : -mean;
}

// fast path: one thread per weight (zero the means, warp-reduce them, then binarize);
// falls back to the one-thread-per-filter kernel when size is not a multiple of 32
void fast_binarize_weights_gpu(float *weights, int n, int size, float *binary, float *mean_arr_gpu)
{
if (size % 32 == 0) {
size_t gridsize = n * size;
const int num_blocks = gridsize / BLOCK + 1;

set_zero_kernel << <(n/BLOCK + 1), BLOCK >> > (mean_arr_gpu, n);
reduce_kernel << <num_blocks, BLOCK >> > (weights, n, size, mean_arr_gpu);
binarize_weights_mean_kernel << <num_blocks, BLOCK >> > (weights, n, size, binary, mean_arr_gpu);
check_error(cudaPeekAtLastError());
}
else {
binarize_weights_gpu(weights, n, size, binary);
}
}


__global__ void cuda_f32_to_f16(float* input_f32, size_t size, half *output_f16)
{
int idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -128,7 +179,9 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)

if(l.xnor){
if (!l.align_bit_weights_gpu || state.train) {
binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu);
//binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu);

fast_binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu, l.mean_arr_gpu);
}
//swap_binary(&l);
//binarize_gpu(state.input, l.c*l.h*l.w*l.batch, l.binary_input_gpu);
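For reference, both GPU paths compute the same XNOR-Net-style binarization: every weight of a filter is replaced by +mean or -mean, where mean is the average of |w| over that filter's c*size*size weights. The original kernel uses one thread per filter and loops over its weights serially; the fast path uses one thread per weight, builds the per-filter means with a warp all-reduce plus an atomicAdd into mean_arr_gpu, and falls back to the old kernel whenever the filter size is not a multiple of 32. A plain-C sketch of the computation (a hypothetical helper, not part of the commit):

#include <math.h>

// CPU reference for what binarize_weights_gpu / fast_binarize_weights_gpu compute:
// n filters, `size` weights per filter; each weight becomes +mean(|w|) or -mean(|w|)
void binarize_weights_ref(const float *weights, int n, int size, float *binary)
{
    for (int f = 0; f < n; ++f) {
        float mean = 0;
        for (int i = 0; i < size; ++i)
            mean += fabsf(weights[f*size + i]);
        mean /= size;
        for (int i = 0; i < size; ++i)
            binary[f*size + i] = (weights[f*size + i] > 0) ? mean : -mean;
    }
}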
8 changes: 6 additions & 2 deletions src/convolutional_layer.c
@@ -314,6 +314,8 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
int align = 32;// 8;
int src_align = l.out_h*l.out_w;
l.bit_align = src_align + (align - src_align % align);

l.mean_arr = calloc(l.n, sizeof(float));
}

if(batch_normalize){
@@ -369,6 +371,7 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
}
if(xnor){
l.binary_weights_gpu = cuda_make_array(l.weights, c*n*size*size);
l.mean_arr_gpu = cuda_make_array(0, l.n);
l.binary_input_gpu = cuda_make_array(0, l.inputs*l.batch);
}

@@ -628,7 +631,7 @@ void binary_align_weights(convolutional_layer *l)
}
float_to_bit(align_weights, l->align_bit_weights, align_weights_size);

l->mean_arr = calloc(l->n, sizeof(float));
//l->mean_arr = calloc(l->n, sizeof(float));
get_mean_array(align_weights, align_weights_size, l->n, l->mean_arr);

#ifdef GPU
@@ -646,7 +649,8 @@ void binary_align_weights(convolutional_layer *l)
status = cudaMemcpy(l->binary_weights_gpu, l->binary_weights, m*k*sizeof(float), cudaMemcpyHostToDevice);
check_error(status);

l->mean_arr_gpu = cuda_make_array(l->mean_arr, l->n);
//l->mean_arr_gpu = cuda_make_array(l->mean_arr, l->n);
cuda_push_array(l->mean_arr_gpu, l->mean_arr, l->n);
cudaDeviceSynchronize();
#endif // GPU

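The net effect of the convolutional_layer.c changes is a lifetime change: mean_arr and mean_arr_gpu are now allocated once in make_convolutional_layer(), so the forward pass can fill the per-filter means on the GPU during training, and binary_align_weights() only pushes into the existing device buffer instead of allocating a new one. A condensed sketch of the resulting flow, using the darknet helpers that appear in this diff (cuda_make_array, get_mean_array, cuda_push_array):

/* at layer creation (make_convolutional_layer) */
l.mean_arr     = calloc(l.n, sizeof(float));   /* one mean per filter, host side */
l.mean_arr_gpu = cuda_make_array(0, l.n);      /* device buffer, filled later    */

/* at weight alignment (binary_align_weights): host means -> existing device buffer */
get_mean_array(align_weights, align_weights_size, l->n, l->mean_arr);
cuda_push_array(l->mean_arr_gpu, l->mean_arr, l->n);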
32 changes: 32 additions & 0 deletions src/im2col_kernels.cu
@@ -1123,6 +1123,38 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int
int count = 0;
k = 0;

#ifdef NOT_USED
// 32 thread X 256 bit = 8192 bit
for (; k < (K - 8192); k += 8192) { // l.size*l.size*l.c - one filter size [27 - 9216]
ulonglong4 c_bit256;

//int64_t A_cur_index = (i*lda + k) / 8;
int64_t A_cur_index = (local_i*lda + k) / 8;
int64_t B_cur_index = (j*ldb + k) / 8;
if (i >= M) A_cur_index = 0;

#pragma unroll
for (int t = 0; t < WARP_SIZE; ++t) {
const int lane_id = threadIdx.x % WARP_SIZE;

const int64_t A_i = __shfl(A_cur_index, t) + 32 * lane_id;
const int64_t B_i = __shfl(B_cur_index, t) + 32 * lane_id;

{
//ulonglong4 a_bit256 = *((ulonglong4 *)(A + A_i)); // weights
ulonglong4 a_bit256 = *((ulonglong4 *)(A_s + A_i)); // weights
ulonglong4 b_bit256 = *((ulonglong4 *)(B + B_i)); // input
c_bit256 = xnor_int256(a_bit256, b_bit256);
int tmp_count = __popcll(c_bit256.w) + __popcll(c_bit256.x) +
__popcll(c_bit256.y) + __popcll(c_bit256.z);

int sum_count = warpAllReduceSum(tmp_count);
if (lane_id == t) count += sum_count;
}
}
}
#endif

//#ifdef NOT_USED
// 32 thread X 64 bit = 2048 bit
for (; k < (K - 2048); k += 2048) { // l.size*l.size*l.c - one filter size [27 - 9216]
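The block disabled above is a wider (256-bit) variant of the binary GEMM inner loop used elsewhere in this kernel: weights and inputs are packed one bit per value, XNOR marks the positions where their signs agree, and a population count recovers the +-1 dot product as 2*matches - K before it is rescaled by the filter's mean. A minimal C sketch of that idea, using a hypothetical 64-bit helper instead of the kernel's ulonglong4 arithmetic:

#include <stdint.h>

/* one 64-bit chunk of a binary dot product: a and b each hold 64 packed sign bits */
static inline int xnor_popcount64(uint64_t a, uint64_t b)
{
    uint64_t matches = ~(a ^ b);            /* XNOR: 1 where the signs agree           */
    return __builtin_popcountll(matches);   /* agreeing positions (GCC/Clang builtin)  */
}

/* summed over all K packed bits: dot = 2*matches - K, then scaled by the per-filter mean */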
6 changes: 3 additions & 3 deletions src/network.c
@@ -866,10 +866,10 @@ void calculate_binary_weights(network net)
//if (l->size*l->size*l->c >= 2048) l->lda_align = 512;

binary_align_weights(l);
}

if (net.layers[j].use_bin_output) {
l->activation = LINEAR;
if (net.layers[j].use_bin_output) {
l->activation = LINEAR;
}
}
}
}
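The network.c hunk is a brace fix: the use_bin_output check now sits inside the block that calls binary_align_weights(), so only the layers that were just binary-aligned have their activation switched to LINEAR. A sketch of the corrected control flow, assuming the enclosing condition is the layer's xnor check (the surrounding lines are truncated in this hunk):

/* calculate_binary_weights(), sketch only */
if (l->xnor) {                                 /* assumed enclosing condition */
    binary_align_weights(l);
    if (net.layers[j].use_bin_output) {
        l->activation = LINEAR;                /* now applied only to aligned xnor layers */
    }
}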

2 comments on commit 25f65f6

@niroussel

By using __shfl_xor, __shfl and __shfl_down, you've dropped support for Fermi GPUs; maybe you should note that in your README.

@AlexeyAB (Owner, Author)


@niroussel Yes, it is here: https://github.com/AlexeyAB/darknet#requires

Yolo requires CC >= 3.0, while Fermi is CC 2.x.

Requires:
• Linux GCC >= 4.9 or Windows MS Visual Studio 2015 (v140): https://go.microsoft.com/fwlink/?LinkId=532606&clcid=0x409 (or offline ISO image)
• CUDA 9.1: https://developer.nvidia.com/cuda-91-download-archive
• OpenCV 3.3.0: https://sourceforge.net/projects/opencvlibrary/files/opencv-win/3.3.0/opencv-3.3.0-vc14.exe/download
  or OpenCV 2.4.13: https://sourceforge.net/projects/opencvlibrary/files/opencv-win/2.4.13/opencv-2.4.13.2-vc14.exe/download
  (OpenCV allows showing image or video detection in a window and storing the result to a file specified on the command line with -out_filename res.avi)
• GPU with CC >= 3.0: https://en.wikipedia.org/wiki/CUDA#GPUs_supported

Also, cuDNN doesn't support Fermi either.
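For reference, the warp shuffle intrinsics used in this commit (__shfl, __shfl_xor, __shfl_down) require CC >= 3.0 (Kepler or newer), and starting with CUDA 9.0 they are deprecated in favour of the *_sync variants that take an explicit lane mask. A minimal sketch of the same butterfly reduction against the newer API (warpAllReduceSumSync is a hypothetical name, not in the repository):

__inline__ __device__
float warpAllReduceSumSync(float val) {
    // same butterfly pattern as warpAllReduceSum, with a full-warp participation mask
    for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2)
        val += __shfl_xor_sync(0xffffffff, val, mask);
    return val;
}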
