Skip to content

Commit

Permalink
Rotate conv-kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexeyAB committed Dec 14, 2019
1 parent a08c872 commit 5e03556
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 16 deletions.
3 changes: 3 additions & 0 deletions include/darknet.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,10 @@ struct layer {
int truth;
float smooth;
float dot;
int deform;
int sway;
int rotate;
int stretch;
float angle;
float jitter;
float saturation;
Expand Down
3 changes: 2 additions & 1 deletion src/blas.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,9 @@ void backward_sam_gpu(float *in_w_h_c_delta, int size, int channel_size,

void sam_gpu(float *in_w_h_c, int size, int channel_size, float *scales_c, float *out);

void rotate_weights_gpu(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int size, int angle, int reverse);
void smooth_rotate_weights_gpu(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int size, int angle, int reverse);
void sway_and_flip_weights_gpu(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int size, int angle, int reverse);
void rotate_weights_gpu(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int size, int reverse);
void reduce_and_expand_array_gpu(const float *src_gpu, float *dst_gpu, int size, int groups);
void expand_array_gpu(const float *src_gpu, float *dst_gpu, int size, int groups);

Expand Down
90 changes: 87 additions & 3 deletions src/blas_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1240,7 +1240,7 @@ extern "C" void backward_sam_gpu(float *in_w_h_c_delta, int size, int channel_si
}


__global__ void rotate_weights_kernel(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int kernel_size, int angle, int reverse)
__global__ void smooth_rotate_weights_kernel(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int kernel_size, int angle, int reverse)
{
const int index = blockIdx.x*blockDim.x + threadIdx.x;
const int kernel_area = kernel_size * kernel_size;
Expand Down Expand Up @@ -1296,12 +1296,12 @@ __global__ void rotate_weights_kernel(const float *src_weight_gpu, float *weigh
}


extern "C" void rotate_weights_gpu(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int size, int angle, int reverse)
extern "C" void smooth_rotate_weights_gpu(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int size, int angle, int reverse)
{
const int kernel_area = size*size;
const int block_size = BLOCK;
const int num_blocks = get_number_of_blocks(nweights / kernel_area, block_size);
rotate_weights_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> > (src_weight_gpu, weight_deform_gpu, nweights, n, size, angle, reverse);
smooth_rotate_weights_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> > (src_weight_gpu, weight_deform_gpu, nweights, n, size, angle, reverse);

CHECK_CUDA(cudaPeekAtLastError());
}
Expand Down Expand Up @@ -1396,6 +1396,90 @@ extern "C" void sway_and_flip_weights_gpu(const float *src_weight_gpu, float *we
}







__global__ void rotate_weights_kernel(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int kernel_size, int reverse)
{
const int index = blockIdx.x*blockDim.x + threadIdx.x;
const int kernel_area = kernel_size * kernel_size;
const int i = index * kernel_area;

const int stage_step = (nweights / kernel_area) / 4; // 4 stages
const int stage_id = index / stage_step;

// nweights = (c / groups) * n * size * size;
// kernel_area = size*size

if (i < nweights)
{
// if(reverse)

if (stage_id == 0) {
// simple copy
for (int x = 0; x < kernel_size; ++x) {
for (int y = 0; y < kernel_size; ++y) {
const int src_i = x + y*kernel_size + i;
const int dst_i = x + y*kernel_size + i;
if (reverse) weight_deform_gpu[src_i] = src_weight_gpu[dst_i];
else weight_deform_gpu[dst_i] = src_weight_gpu[src_i];
}
}
}
else if (stage_id == 1)
{
// 90 degree clockwise rotation - 1
for (int x = 0; x < kernel_size; ++x) {
for (int y = 0; y < kernel_size; ++y) {
const int src_i = x + y*kernel_size + i;
const int dst_i = (kernel_size - 1 - y) + x*kernel_size + i;
if (reverse) weight_deform_gpu[src_i] = src_weight_gpu[dst_i];
else weight_deform_gpu[dst_i] = src_weight_gpu[src_i];
}
}
}
else if (stage_id == 2)
{
// 180 degree clockwise rotation - 2
for (int x = 0; x < kernel_size; ++x) {
for (int y = 0; y < kernel_size; ++y) {
const int src_i = x + y*kernel_size + i;
const int dst_i = (kernel_size - 1 - x) + (kernel_size - 1 - y)*kernel_size + i;
if (reverse) weight_deform_gpu[src_i] = src_weight_gpu[dst_i];
else weight_deform_gpu[dst_i] = src_weight_gpu[src_i];
}
}
}
else if (stage_id == 3)
{
// 270 degree clockwise rotation - 3
for (int x = 0; x < kernel_size; ++x) {
for (int y = 0; y < kernel_size; ++y) {
const int src_i = x + y*kernel_size + i;
const int dst_i = y + (kernel_size - 1 - x)*kernel_size + i;
if (reverse) weight_deform_gpu[src_i] = src_weight_gpu[dst_i];
else weight_deform_gpu[dst_i] = src_weight_gpu[src_i];
}
}
}
}
}


extern "C" void rotate_weights_gpu(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int size, int reverse)
{
const int kernel_area = size*size;
const int block_size = BLOCK;
const int num_blocks = get_number_of_blocks(nweights / kernel_area, block_size);
rotate_weights_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> > (src_weight_gpu, weight_deform_gpu, nweights, n, size, reverse);

CHECK_CUDA(cudaPeekAtLastError());
}



__global__ void reduce_and_expand_array_kernel(const float *src_gpu, float *dst_gpu, int current_size, int groups)
{
const int index = blockIdx.x*blockDim.x + threadIdx.x;
Expand Down
16 changes: 9 additions & 7 deletions src/convolutional_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1192,20 +1192,21 @@ void update_convolutional_layer_gpu(layer l, int batch, float learning_rate_init
/*
for (int angle = 0; angle < 360; angle++) {
printf(" angle = %d \n", angle);
sway_and_flip_weights_gpu(l.weights_gpu, l.weight_deform_gpu, l.nweights, l.n, l.size, angle, 0);
smooth_rotate_weights_kernel(l.weights_gpu, l.weight_deform_gpu, l.nweights, l.n, l.size, angle, 0);
cuda_pull_array(l.weight_deform_gpu, l.weights, l.nweights);
visualize_convolutional_layer(l, "weights", NULL);
wait_key_cv(10);
}
*/

if (l.sway) {
if (l.deform) {

//for (l.angle = 0; l.angle < 360; l.angle++)
//{

sway_and_flip_weights_gpu(l.weight_updates_gpu, l.weight_deform_gpu, l.nweights, l.n, l.size, l.angle, 1);
if (l.rotate) rotate_weights_gpu(l.weight_updates_gpu, l.weight_deform_gpu, l.nweights, l.n, l.size, 1);
else if (l.sway) sway_and_flip_weights_gpu(l.weight_updates_gpu, l.weight_deform_gpu, l.nweights, l.n, l.size, l.angle, 1);

//simple_copy_ongpu(l.nweights, l.weight_updates_gpu, l.weight_deform_gpu);

Expand Down Expand Up @@ -1254,16 +1255,17 @@ void update_convolutional_layer_gpu(layer l, int batch, float learning_rate_init
}
}

if (l.sway) {
//for (l.angle = 0; l.angle < 360; l.angle += 4)
if (l.deform) {
//for (l.angle = 0; l.angle < 50; l.angle += 0.1)
//{
expand_array_gpu(l.weights_gpu, l.weight_deform_gpu, l.nweights, 4);

//simple_copy_ongpu(l.nweights, l.weight_deform_gpu, l.weights_gpu);

sway_and_flip_weights_gpu(l.weight_deform_gpu, l.weights_gpu, l.nweights, l.n, l.size, l.angle, 0);
if (l.rotate) rotate_weights_gpu(l.weight_deform_gpu, l.weights_gpu, l.nweights, l.n, l.size, 0);
else if (l.sway) sway_and_flip_weights_gpu(l.weight_deform_gpu, l.weights_gpu, l.nweights, l.n, l.size, l.angle, 0);

//printf(" angle = %f \n", l.angle);
//printf(" angle = %f, reverse = %d \n", l.angle, 0);
//cuda_pull_array(l.weights_gpu, l.weights, l.nweights);
//visualize_convolutional_layer(l, "weights", NULL);
//wait_key_cv(10);
Expand Down
6 changes: 3 additions & 3 deletions src/convolutional_layer.c
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ void free_convolutional_batchnorm(convolutional_layer *l)
}
}

convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride_x, int stride_y, int dilation, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index, int antialiasing, convolutional_layer *share_layer, int assisted_excitation, int sway, int train)
convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride_x, int stride_y, int dilation, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index, int antialiasing, convolutional_layer *share_layer, int assisted_excitation, int deform, int train)
{
int total_batch = batch*steps;
int i;
Expand All @@ -388,7 +388,7 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
stride_x = stride_y = l.stride = l.stride_x = l.stride_y = 1; // use stride=1 in host-layer
}

l.sway = sway;
l.deform = deform;
l.assisted_excitation = assisted_excitation;
l.share_layer = share_layer;
l.index = index;
Expand Down Expand Up @@ -543,7 +543,7 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
l.activation_input_gpu = cuda_make_array(l.activation_input, total_batch*l.outputs);
}

if (l.sway) l.weight_deform_gpu = cuda_make_array(NULL, l.nweights);
if (l.deform) l.weight_deform_gpu = cuda_make_array(NULL, l.nweights);

if (adam) {
l.m_gpu = cuda_make_array(l.m, l.nweights);
Expand Down
2 changes: 1 addition & 1 deletion src/convolutional_layer.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ void cuda_convert_f32_to_f16(float* input_f32, size_t size, float *output_f16);
void free_convolutional_batchnorm(convolutional_layer *l);

size_t get_convolutional_workspace_size(layer l);
convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride_x, int stride_y, int dilation, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index, int antialiasing, convolutional_layer *share_layer, int assisted_excitation, int sway, int train);
convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride_x, int stride_y, int dilation, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index, int antialiasing, convolutional_layer *share_layer, int assisted_excitation, int deform, int train);
void denormalize_convolutional_layer(convolutional_layer l);
void set_specified_workspace_limit(convolutional_layer *l, size_t workspace_size_limit);
void resize_convolutional_layer(convolutional_layer *layer, int w, int h);
Expand Down
12 changes: 11 additions & 1 deletion src/parser.c
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,20 @@ convolutional_layer parse_convolutional(list *options, size_params params)
int xnor = option_find_int_quiet(options, "xnor", 0);
int use_bin_output = option_find_int_quiet(options, "bin_output", 0);
int sway = option_find_int_quiet(options, "sway", 0);
int rotate = option_find_int_quiet(options, "rotate", 0);
int stretch = option_find_int_quiet(options, "stretch", 0);
if ((sway + rotate + stretch) > 1) {
printf(" Error: should be used only 1 param: sway=1, rotate=1 or stretch=1 in the [convolutional] layer \n");
exit(0);
}
int deform = sway || rotate || stretch;

convolutional_layer layer = make_convolutional_layer(batch,1,h,w,c,n,groups,size,stride_x,stride_y,dilation,padding,activation, batch_normalize, binary, xnor, params.net.adam, use_bin_output, params.index, antialiasing, share_layer, assisted_excitation, sway, params.train);
convolutional_layer layer = make_convolutional_layer(batch,1,h,w,c,n,groups,size,stride_x,stride_y,dilation,padding,activation, batch_normalize, binary, xnor, params.net.adam, use_bin_output, params.index, antialiasing, share_layer, assisted_excitation, deform, params.train);
layer.flipped = option_find_int_quiet(options, "flipped", 0);
layer.dot = option_find_float_quiet(options, "dot", 0);
layer.sway = sway;
layer.rotate = rotate;
layer.stretch = stretch;
layer.angle = option_find_float_quiet(options, "angle", 15);

if(params.net.adam){
Expand Down

0 comments on commit 5e03556

Please sign in to comment.