In [1]:
import pyopencl as cl
import pyopencl.array as cl_array
import numpy as np
import numpy.linalg as la

In [2]:
# Load pyopencl's IPython extension. It provides the %%cl_kernel cell magic used by
# the kernel cells below. That magic builds the cell's OpenCL source against the
# notebook's `ctx` and binds each kernel to a notebook variable of the same name,
# which is why ReluD3 can be called directly later.
%load_ext pyopencl.ipython_ext

In [3]:
# Shorthand for the buffer flags (READ_ONLY, WRITE_ONLY, COPY_HOST_PTR, ...).
mf = cl.mem_flags

# Pick a device and create a context. pyopencl's create_some_context honours the
# PYOPENCL_CTX environment variable.
# Keep the name `ctx`: the %%cl_kernel magic looks the context up by that name.
ctx = cl.create_some_context()

# In-order command queue on that context. Kernel launches and copies run in
# submission order.
queue = cl.CommandQueue(ctx)

In [4]:
# Use a fixed seed so "Restart & Run All" regenerates the same input tensor and
# therefore the same outputs. The original call used the unseeded global RNG.
SEED = 0
rng = np.random.default_rng(SEED)

# A (channel, height, width) float32 tensor drawn uniformly from [-1, 1), so roughly
# half the entries are negative and exercise the ReLU clamp.
a = (2.0 * rng.random((2, 3, 4), dtype=np.float32) - 1.0).astype(np.float32, copy=False)

In [5]:
def _read_only_copy(host_data):
    """Return a read-only device buffer initialised with a copy of ``host_data``."""
    return cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=host_data)


# Tensor storage: the input is uploaded, and the output is left uninitialised
# at the same byte size.
input_gpu = _read_only_copy(a)
output_gpu = cl.Buffer(ctx, mf.WRITE_ONLY, a.nbytes)

# ReluD3 takes its extents as `__global int *`, so each dimension of `a` goes to
# the device as a one-element int32 buffer (channel, height, width, in that order).
channel_gpu, height_gpu, width_gpu = (
    _read_only_copy(np.int32(extent)) for extent in a.shape[:3]
)

In [6]:
%%cl_kernel -o "-cl-fast-relaxed-math"

// Element-wise ReLU over a flat float array: oft[k] = max(0, ift[k]).
// Each work-item along dimension 0 handles one element. There is no bounds check,
// so the global size must not exceed the number of elements in ift/oft.
__kernel void ReluD1(__global const float *ift, __global float *oft)
{
    const int k = get_global_id(0);
    const float x = ift[k];
    oft[k] = max(0.0f, x);
}

In [7]:
%%cl_kernel -o "-cl-fast-relaxed-math"

// Element-wise ReLU over a dense row-major (channel, height, width) float tensor:
// oft[i] = max(0, ift[i]).
// Launch as a 3-D NDRange: global id 0 is the channel, 1 the row, 2 the column.
// The extents arrive as one-element int buffers (channel, height, width).
__kernel void ReluD3(__global const float *ift, __global float *oft,
                     __global const int *channel, __global const int *height,
                     __global const int *width)
{
    const int c = *channel;
    const int h = *height;
    const int w = *width;
    const size_t posc = get_global_id(0);
    const size_t posh = get_global_id(1);
    const size_t posw = get_global_id(2);

    // The global size may be rounded up past the tensor extents (e.g. to a
    // multiple of the work-group size). Such work-items must not touch memory.
    if (posc >= (size_t)c || posh >= (size_t)h || posw >= (size_t)w)
        return;

    // Row-major flattening, equal to posc*h*w + posh*w + posw. It is computed in
    // size_t so tensors with more than INT_MAX elements do not overflow.
    const size_t i = (posc * (size_t)h + posh) * (size_t)w + posw;
    oft[i] = max(0.0f, ift[i]);
}

In [8]:
# One work-item per element of `a`. Passing None as the local size lets the
# runtime choose the work-group size.
relu_event = ReluD3(
    queue, a.shape, None,
    input_gpu, output_gpu,
    channel_gpu, height_gpu, width_gpu,
)
relu_event

<pyopencl._cl.Event at 0x21198a61708>

In [9]:
# Host array with the same shape, dtype and layout as `a` to receive the kernel output.
result = np.empty_like(a)

# The queue is in-order, so this copy runs after the ReLU kernel. enqueue_copy
# blocks by default, so `result` is filled when the call returns.
copy_event = cl.enqueue_copy(queue, result, output_gpu)
copy_event

<pyopencl._cl.NannyEvent at 0x21198d13048>

In [10]:
# Verify the device output against a NumPy reference ReLU instead of comparing by
# eye. Clamping to zero is exact, so an exact comparison is appropriate.
np.testing.assert_array_equal(result, np.maximum(a, 0))
result

array([[[0.56101716, 0.3585919 , 0.78431845, 0.7043483 ],
        [0.09047246, 0.        , 0.        , 0.        ],
        [0.        , 0.49617136, 0.98590934, 0.        ]],

       [[0.        , 0.        , 0.1869905 , 0.05707383],
        [0.8965759 , 0.38249564, 0.        , 0.11432385],
        [0.60342634, 0.        , 0.7510967 , 0.55352986]]], dtype=float32)

In [11]:
# Show the original input for side-by-side comparison with `result`: the negative
# entries here are the ones the kernel set to 0.
a

array([[[ 0.56101716,  0.3585919 ,  0.78431845,  0.7043483 ],
        [ 0.09047246, -0.2773316 , -0.15283114, -0.869961  ],
        [-0.72097254,  0.49617136,  0.98590934, -0.49990082]],

       [[-0.13152266, -0.7272972 ,  0.1869905 ,  0.05707383],
        [ 0.8965759 ,  0.38249564, -0.30758226,  0.11432385],
        [ 0.60342634, -0.80447567,  0.7510967 ,  0.55352986]]],
      dtype=float32)