Added strided elewise ops (logistic, tanh).
TimDettmers committed Nov 25, 2015
1 parent 2d03d14 commit 5eec5fe
Showing 2 changed files with 289 additions and 0 deletions.
157 changes: 157 additions & 0 deletions brainstorm/handlers/pycuda_handler.py
@@ -143,6 +143,49 @@ def avgpool2d_backward_batch(self, inputs, window, outputs, padding,
                              in_deltas,
                              block=(NUM_CUDA_THREADS, 1, 1),
                              grid=(get_blocks(inputs.size), 1))

    def slice_copy_strided(self, inputs, outputs, slice_shape_inputs,
                           slice_shape_outputs):
        # Each slice descriptor is a (start, length, segments, stride)
        # quadruple, so the descriptor count is len(...) // 4.
        _slice_copy_impl(inputs, outputs,
                         slice_shape_inputs, slice_shape_outputs,
                         np.int32(len(slice_shape_inputs) // 4),
                         block=(NUM_CUDA_THREADS, 1, 1),
                         grid=(get_blocks(inputs.size), 1))


    def handle_shape(self, shape):
        """Cast every dimension to int32 and left-pad the shape with
        ones until it is exactly five-dimensional."""
        shape = [np.int32(dim) for dim in shape]
        while len(shape) < 5:
            shape = [np.int32(1)] + shape
        return shape
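
For illustration, the padding behaves as in this minimal sketch (assuming a handler instance `h`; the values are hypothetical and not part of the commit):

    h = PyCudaHandler()
    padded = h.handle_shape((4, 5, 3))
    assert [int(d) for d in padded] == [1, 1, 4, 5, 3]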

    def strided_elementwise_inplace(self, inputs, idx, func):
        # `idx` selects which contiguous slice along the (padded) last
        # dimension the function is applied to; all other slices are
        # left untouched.
        shape = self.handle_shape(inputs.shape)

        if func not in strided_inp_funcs:
            raise ValueError("Strided function not supported. "
                             "Supported functions are: {0}"
                             .format(list(strided_inp_funcs.keys())))

        strided_inp_funcs[func](inputs, np.int32(idx),
                                shape[0], shape[1], shape[2], shape[3],
                                shape[4],
                                block=(NUM_CUDA_THREADS, 1, 1),
                                grid=(get_blocks(inputs.size), 1))


    def strided_elementwise(self, inputs, outputs, stride, func):
        # Non-destructive variant: copies every slice of `inputs` into
        # `outputs` and applies `func` only to the slice selected by
        # `stride` (the slice index along the padded last dimension).
        shape = self.handle_shape(inputs.shape)

        if func not in strided_funcs:
            raise ValueError("Strided function not supported. "
                             "Supported functions are: {0}"
                             .format(list(strided_funcs.keys())))

        strided_funcs[func](inputs, outputs, np.int32(stride),
                            shape[0], shape[1], shape[2], shape[3],
                            shape[4],
                            block=(NUM_CUDA_THREADS, 1, 1),
                            grid=(get_blocks(inputs.size), 1))
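
For reference, a minimal NumPy sketch of what the in-place variant computes on the GPU; the helper name `ref_strided_elementwise_inplace` is ours, and the data is assumed to consist of `n_slices` equally sized, block-contiguous slices:

    import numpy as np

    def ref_strided_elementwise_inplace(flat, n_slices, idx, func):
        # Apply func to the idx-th of n_slices equal contiguous blocks.
        block = flat.size // n_slices
        flat[idx * block:(idx + 1) * block] = \
            func(flat[idx * block:(idx + 1) * block])
        return flat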

    def avgpool2d_forward_batch(self, inputs, window, outputs, padding,
                                stride):
@@ -877,3 +920,117 @@ def tanh_deriv(self, x, y, dy, dx):
"""
_mod_avepool_bwd_fp32 = SourceModule(__avepool_bwd_fp32_kernel)
_avepool_bwd_fp32_impl = _mod_avepool_bwd_fp32.get_function("ave_pool_bwd")


__strided_elewise_inp = """
__global__ void strided_elementwise_inp_{0}(float *in,
    int matrixIndex, int dim1, int dim2, int dim3, int dim4, int dim5)
{{
    // Applies {0} in place to the matrixIndex-th contiguous slice of
    // size dim1*dim2*dim3*dim4; dim5 is the number of slices and is
    // passed only for a uniform launch signature.
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int idx = 0;
    for(int i = index; i < dim1*dim2*dim3*dim4; i += blockDim.x * gridDim.x)
    {{
        idx = (dim1*dim2*dim3*dim4*matrixIndex) + i;
        in[idx] = {1};
    }}
}}
"""
__strided_logistic_inp = __strided_elewise_inp.format(
    "logistic", "1.f/(1.f + __expf(-in[idx]))")  # standard logistic 1/(1+e^-x)
_mod_strided_elewise_kernel_logistic = SourceModule(__strided_logistic_inp)
_strided_elewise_inp_logistic = _mod_strided_elewise_kernel_logistic.get_function(
    "strided_elementwise_inp_logistic")

__strided_tanh_inp = __strided_elewise_inp.format("tanh", "tanhf(in[idx])")
_mod_strided_ele_kernel_tanh_inp = SourceModule(__strided_tanh_inp)
_strided_elewise_inp_tanh = _mod_strided_ele_kernel_tanh_inp.get_function(
    "strided_elementwise_inp_tanh")

strided_inp_funcs = {
    'logistic': _strided_elewise_inp_logistic,
    'tanh': _strided_elewise_inp_tanh,
}


__strided_elewise = """
__global__ void strided_elementwise_{0}(float *in, float *out,
    int matrixIndex, int dim1, int dim2, int dim3, int dim4, int dim5)
{{
    // Copies all dim5 contiguous slices from in to out, applying {0}
    // only to the slice selected by matrixIndex.
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int idx = 0;
    for(int d5 = 0; d5 < dim5; d5++)
    {{
        for(int i = index; i < dim1*dim2*dim3*dim4; i += blockDim.x * gridDim.x)
        {{
            idx = (dim1*dim2*dim3*dim4*d5) + i;
            if(d5 == matrixIndex){{ out[idx] = {1}; }}
            else{{ out[idx] = in[idx]; }}
        }}
    }}
}}
"""
__strided_logistic = __strided_elewise.format(
    "logistic", "1.f/(1.f + __expf(-in[idx]))")
_mod_strided_elewise_kernel_logistic = SourceModule(__strided_logistic)
_strided_elewise_logistic = _mod_strided_elewise_kernel_logistic.get_function(
    "strided_elementwise_logistic")

__strided_tanh = __strided_elewise.format("tanh", "tanhf(in[idx])")
_mod_strided_ele_kernel_tanh = SourceModule(__strided_tanh)
_strided_elewise_tanh = _mod_strided_ele_kernel_tanh.get_function(
    "strided_elementwise_tanh")

strided_funcs = {
    'logistic': _strided_elewise_logistic,
    'tanh': _strided_elewise_tanh,
}
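
Registering another activation would follow the same template pattern; a hypothetical sketch (a ReLU kernel is not part of this commit):

    __strided_relu = __strided_elewise.format("relu", "fmaxf(in[idx], 0.0f)")
    _mod_strided_relu = SourceModule(__strided_relu)
    strided_funcs['relu'] = _mod_strided_relu.get_function(
        "strided_elementwise_relu")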

__slice_copy_kernel = """
__global__ void slice_copy(float *in, float *out,
    float *from_shape, float *to_shape, int shapes)
{
    // Each descriptor is a (start, length, segments, stride) quadruple
    // stored as floats; `shapes` is the number of descriptors.
    int in_start = 0;
    int in_length = 0;   // segment length
    int in_segments = 0;
    int in_stride = 0;
    int in_current_segment = 0;
    int in_slice_idx = 0;
    int out_start = 0;
    int out_length = 0;
    int out_segments = 0;
    int out_stride = 0;
    int out_current_segment = 0;
    int out_slice_idx = 0;
    for(int shape = 0; shape < shapes; shape++)
    {
        in_start = (int)from_shape[shape*4];
        in_length = (int)from_shape[shape*4+1];
        in_segments = (int)from_shape[shape*4+2];
        in_stride = (int)from_shape[shape*4+3];
        out_start = (int)to_shape[shape*4];
        out_length = (int)to_shape[shape*4+1];
        out_segments = (int)to_shape[shape*4+2];
        out_stride = (int)to_shape[shape*4+3];
        for(int idx = (blockIdx.x * blockDim.x) + threadIdx.x;
            idx < in_length*in_segments; idx += blockDim.x * gridDim.x)
        {
            in_current_segment = idx / in_length;
            out_current_segment = idx / out_length;
            in_slice_idx = in_start + idx + (in_current_segment*in_stride);
            out_slice_idx = out_start + idx + (out_current_segment*out_stride);
            out[out_slice_idx] = in[in_slice_idx];
        }
    }
}
"""
_mod_slice_copy_kernel = SourceModule(__slice_copy_kernel)
_slice_copy_impl = _mod_slice_copy_kernel.get_function("slice_copy")
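
For reference, a minimal NumPy sketch of the copy pattern the kernel implements (the helper name `ref_slice_copy` is ours), assuming flat buffers and (start, length, segments, stride) descriptors:

    import numpy as np

    def ref_slice_copy(src, dst, from_desc, to_desc):
        # Reference for a single (start, length, segments, stride) pair.
        in_start, in_len, in_seg, in_stride = from_desc
        out_start, out_len, _, out_stride = to_desc
        for idx in range(in_len * in_seg):
            i = in_start + idx + (idx // in_len) * in_stride
            o = out_start + idx + (idx // out_len) * out_stride
            dst[o] = src[i]
        return dst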
132 changes: 132 additions & 0 deletions brainstorm/tests/test_handler_operations.py
@@ -9,11 +9,14 @@
from brainstorm.handlers import NumpyHandler
from brainstorm.optional import has_pycuda

# Guard the import so this module still collects on machines without CUDA.
if has_pycuda:
    import pycuda

# np.random.seed(1234)
dtype = np.float32
NO_CON = set()



def _conv2d_forward_batch(inputs, weights, bias, outputs, padding, stride):
"""
Loop-based implementation of 2D convolution to check against.
@@ -167,3 +170,132 @@ def test_conv2d_forward_batch_pycuda():
print("Expected:\n", true_outputs)
print("Obtained:\n", outputs)
assert passed


@pytest.mark.skipif(has_pycuda is False, reason='requires PyCUDA+scikit-cuda')
def test_strided_elementwise():
    from brainstorm.handlers import PyCudaHandler
    _h = PyCudaHandler()
    rdm = np.random.RandomState(1345)

    def get_rdm_array(shape, dims):
        if dims == 2:
            return rdm.randn(shape[0], shape[1])
        elif dims == 3:
            return rdm.randn(shape[0], shape[1], shape[2])
        else:
            return rdm.randn(shape[0], shape[1], shape[2], shape[3])

    for dims in range(2, 5):
        for i in range(10):
            shape = rdm.randint(1, 17, dims)
            a1 = np.float32(get_rdm_array(shape, dims))
            a2 = np.float32(get_rdm_array(shape, dims))
            a3 = np.float32(get_rdm_array(shape, dims))
            # vstack makes the three arrays block-contiguous; the reshape
            # only adds a trailing "slice" dimension for handle_shape.
            a = np.vstack([a1, a2, a3])
            original_shape = a.shape
            a = a.reshape([original_shape[0] // 3] +
                          list(original_shape[1:]) + [3])
            A = _h.create_from_numpy(a)

            idx = rdm.randint(0, 2)
            func = ['logistic', 'tanh'][idx]

            _h.strided_elementwise_inplace(A, 1, func)
            outputs = _h.get_numpy_copy(A).reshape(original_shape)

            # Only the middle slice should have been transformed.
            c1 = a1
            c2 = 1. / (1. + np.exp(-a2)) if idx == 0 else np.tanh(a2)
            c3 = a3
            c = np.vstack([c1, c2, c3])

            passed = np.allclose(outputs, c)
            assert passed

@pytest.mark.skipif(has_pycuda is False, reason='requires PyCUDA+scikit-cuda')
def test_strided_elementwise_inplace():
    from brainstorm.handlers import PyCudaHandler
    _h = PyCudaHandler()
    rdm = np.random.RandomState(1345)

    def get_rdm_array(shape, dims):
        if dims == 2:
            return rdm.randn(shape[0], shape[1])
        elif dims == 3:
            return rdm.randn(shape[0], shape[1], shape[2])
        else:
            return rdm.randn(shape[0], shape[1], shape[2], shape[3])

    for dims in range(2, 5):
        for i in range(10):
            shape = rdm.randint(1, 17, dims)
            a1 = np.float32(get_rdm_array(shape, dims))
            a2 = np.float32(get_rdm_array(shape, dims))
            a3 = np.float32(get_rdm_array(shape, dims))
            a = np.vstack([a1, a2, a3])
            original_shape = a.shape
            a = a.reshape([original_shape[0] // 3] +
                          list(original_shape[1:]) + [3])
            A = _h.create_from_numpy(a)

            # Apply two different functions to two different slices.
            _h.strided_elementwise_inplace(A, 1, 'logistic')
            _h.strided_elementwise_inplace(A, 0, 'tanh')
            outputs = _h.get_numpy_copy(A).reshape(original_shape)

            c1 = np.tanh(a1)
            c2 = 1. / (1. + np.exp(-a2))
            c3 = a3
            c = np.vstack([c1, c2, c3])

            passed = np.allclose(outputs, c)
            assert passed



'''
@pytest.mark.skipif(has_pycuda is False, reason='requires PyCUDA+scikit-cuda')
def test_slice_copy_stride():
    from brainstorm.handlers import PyCudaHandler
    _h = PyCudaHandler()
    # 2-dim test
    a = np.float32(np.random.rand(10, 10))
    start = 4
    length = 2
    segments = 3
    stride = 1
    slices = [start, length, segments, stride]
    data = []
    for seg in range(segments):
        row = np.int32(start / a.shape[1])
        offset = start - (row * a.shape[0])
        data += a[row, offset + (length * seg) + (seg * stride):
                  offset + (length * seg) + (seg * stride) + length].tolist()
    s = np.array(data, dtype=np.float32)
    A = _h.create_from_numpy(a)
    S = _h.create_from_numpy(np.zeros_like(s, dtype=np.float32))
    slices_A = _h.create_from_numpy(np.array(slices, dtype=np.float32))
    slices_B = _h.create_from_numpy(
        np.array([0, length * segments, 1, 0], dtype=np.float32))
    _h.slice_copy_strided(A, S, slices_A, slices_B)
    outputs = _h.get_numpy_copy(S)
    passed = np.allclose(outputs, s)
    assert passed
    # 3-dim test
    a = np.float32(np.random.rand(10, 10, 10))
    start = 50
    length = 6
    segments = 4
    stride = 5
    slices = [start, length, segments, stride]
    data = []
    for seg in range(segments):
        row = np.int32(start / (a.shape[1] * a.shape[2]))
        col = np.int32(start / (a.shape[1]))
        offset = start - (row * (a.shape[1] * a.shape[2])) - (col * a.shape[1])
        data += a[row, col, offset + (length * seg) + (seg * stride):
                  offset + (length * seg) + (seg * stride) + length].tolist()
    s = np.array(data, dtype=np.float32)
    A = _h.create_from_numpy(a)
    S = _h.create_from_numpy(np.zeros_like(s, dtype=np.float32))
    slices_A = _h.create_from_numpy(np.array(slices, dtype=np.float32))
    slices_B = _h.create_from_numpy(
        np.array([0, length * segments, 1, 0], dtype=np.float32))
    _h.slice_copy_strided(A, S, slices_A, slices_B)
    outputs = _h.get_numpy_copy(S)
    passed = np.allclose(outputs, s)
    assert passed
'''
