In [1]:
import pyopencl as cl
import pyopencl.array as cl_array
import numpy as np
import numpy.linalg as la
import math
import torch
import torch.nn as nn
import torch.functional as F

In [2]:
# Registers the %%cl_kernel cell magic used further down to compile the OpenCL kernel.
%load_ext pyopencl.ipython_ext

In [3]:
# interactive=False: never prompt on stdin for a platform/device (that would hang
# Restart & Run All). Set PYOPENCL_CTX to choose a specific device instead.
ctx = cl.create_some_context(interactive=False)
queue = cl.CommandQueue(ctx)
mf = cl.mem_flags

In [4]:
def maxpool(input_numpy, size=2, stride=2):
    """Reference 2-D max pooling over a (C, H, W) array, computed on the CPU.

    Square windows of side ``size`` start every ``stride`` pixels. The output uses
    ceil mode: when H (or W) is not a multiple of ``stride``, an extra output row
    (column) is produced from the trailing, partially-filled window. Positions
    outside the input are padded with -inf, so they can never win the max, even
    when every real value in the window is negative.

    Args:
        input_numpy: array-like of shape (C, H, W). This is anything ``np.asarray``
            accepts, e.g. a NumPy array or a CPU torch tensor.
        size: side length of the square pooling window (>= 1).
        stride: step between consecutive window origins (>= 1).

    Returns:
        float64 np.ndarray of shape (C, ceil(H / stride), ceil(W / stride)).
        NaNs in a window propagate to the output.

    Raises:
        ValueError: if ``size`` or ``stride`` is smaller than 1.
    """
    if size < 1 or stride < 1:
        raise ValueError(f"size and stride must be >= 1, got size={size}, stride={stride}")
    src = np.asarray(input_numpy, dtype=np.float64)
    C, Hi, Wi = src.shape
    # Ceil division: a trailing partial window still yields an output element.
    Ho, Wo = -(-Hi // stride), -(-Wi // stride)
    # The padded extent must contain the last window [(Ho-1)*stride, (Ho-1)*stride+size).
    # It must also never be smaller than the input itself (the size < stride case).
    Hp = max(Hi, (Ho - 1) * stride + size)
    Wp = max(Wi, (Wo - 1) * stride + size)
    padded = np.full((C, Hp, Wp), -np.inf)
    padded[:, :Hi, :Wi] = src
    out = np.full((C, Ho, Wo), -np.inf)
    # For each offset (dy, dx) inside the window, one strided slice picks that element
    # from every window at once (exactly Ho x Wo of them). Fold it into the running max.
    for dy in range(size):
        for dx in range(size):
            np.maximum(
                out,
                padded[:, dy:dy + Ho * stride:stride, dx:dx + Wo * stride:stride],
                out=out,
            )
    return out

In [5]:
# Seed so the random input (a torch tensor, despite the name) is reproducible on re-run.
torch.manual_seed(0)
input_numpy = torch.randn(6, 28, 28)

size = 2
stride = 2
C, Hi, Wi = input_numpy.shape
# Ceil-mode output size: a trailing partial window still produces an output element.
row_remainder, col_remainder = Hi % stride, Wi % stride
Ho = Hi // stride + int(row_remainder != 0)
Wo = Wi // stride + int(col_remainder != 0)
# The padded extent must hold the last window [(Ho-1)*stride, (Ho-1)*stride+size).
# It must never be smaller than the input (otherwise the copy below fails when size < remainder).
Hp = max(Hi, (Ho - 1) * stride + size)
Wp = max(Wi, (Wo - 1) * stride + size)
# Pad with the most negative *finite* float32 so padding never wins the max.
# A finite value is required rather than -inf: the kernel is built with
# -cl-fast-relaxed-math, which assumes no infinities.
input_cpu = np.full((C, Hp, Wp), np.finfo(np.float32).min, dtype=np.float32)
input_cpu[:, :Hi, :Wi] = input_numpy

# From here on Hi/Wi describe the padded buffer the kernel indexes into.
C, Hi, Wi = input_cpu.shape

output_cpu = np.zeros((C, Ho, Wo), dtype=np.float32)

In [6]:
def _int_scalar_buffer(value):
    """Copy one host value, as an int32, into a new read-only device buffer."""
    return cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=np.int32(value))

# Padded input feature map, copied to the device.
input_gpu = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=input_cpu)

# The kernel receives every scalar parameter through a pointer, so each gets its own buffer.
size_gpu, stride_gpu = _int_scalar_buffer(size), _int_scalar_buffer(stride)
channel_gpu = _int_scalar_buffer(C)
input_height_gpu, input_width_gpu = _int_scalar_buffer(Hi), _int_scalar_buffer(Wi)
output_height_gpu, output_width_gpu = _int_scalar_buffer(Ho), _int_scalar_buffer(Wo)

# Device-side destination for the pooled result.
output_gpu = cl.Buffer(ctx, mf.WRITE_ONLY, output_cpu.nbytes)

In [7]:
%%cl_kernel -o "-cl-fast-relaxed-math"

// 2-D max pooling: each work-item (c, h, w) writes one output element.
// ift is indexed as a (C, Hi, Wi) array and oft as (C, Ho, Wo).
// The host must pad ift so that every window [start, start + size) lies inside it.
// -cl-fast-relaxed-math implies finite-math-only, so the padding value must be finite.
__kernel void MaxPool2D(__global const float *ift, 
                        __global int *size, __global int *stride,
                        __global int *channel, 
                        __global int *input_height, __global int *input_width, 
                        __global int *output_height, __global int *output_width, 
                        __global float *oft)
{
    const int sz = *size, sd = *stride;
    const int C = *channel, Hi = *input_height, Wi = *input_width;
    const int Ho = *output_height, Wo = *output_width;
    const int posc = get_global_id(0), posh = get_global_id(1), posw = get_global_id(2);

    // Ignore work-items outside the output (e.g. if the global size is ever rounded up).
    if (posc >= C || posh >= Ho || posw >= Wo) return;

    const int startX = posw * sd, startY = posh * sd;
    // Top-left element of this work-item's window.
    __global const float *win = ift + posc * (Hi * Wi) + startY * Wi + startX;

    // Private accumulator: avoids re-reading and re-writing oft[] in global memory per element.
    float acc = win[0];
    for (int y = 0; y < sz; y++) {
        for (int x = 0; x < sz; x++) {
            acc = max(acc, win[y * Wi + x]);
        }
    }
    oft[posc * (Ho * Wo) + posh * Wo + posw] = acc;
}

In [8]:
# One work-item per output element: global size (C, Ho, Wo); local size chosen by the runtime.
pool_args = (
    input_gpu,
    size_gpu, stride_gpu,
    channel_gpu,
    input_height_gpu, input_width_gpu,
    output_height_gpu, output_width_gpu,
    output_gpu,
)
MaxPool2D(queue, output_cpu.shape, None, *pool_args)

<pyopencl._cl.Event at 0x290baadc888>

In [9]:
# Copy the pooled result from the device buffer into the host array.
cl.enqueue_copy(queue, dest=output_cpu, src=output_gpu)

<pyopencl._cl.NannyEvent at 0x290baadc8e8>

In [10]:
# First output row of channel 0, as computed on the device.
output_cpu[0, 0]

array([ 2.1125185,  1.1281296,  1.2305954, -0.0323182,  0.422409 ,
        1.6023225,  1.677707 ,  1.041097 ,  1.7334143,  1.4964299,
        1.0928481,  2.2433877,  1.7289656,  0.5298232], dtype=float32)

In [11]:
# CPU reference with the same configured window/stride as the GPU run (no duplicated magic numbers).
np_res = maxpool(input_numpy, size=size, stride=stride)

In [12]:
# Check the whole GPU result against the CPU reference instead of eyeballing one row.
np.testing.assert_allclose(output_cpu, np_res)
np_res[0, 0, :]

array([ 2.11251855,  1.1281296 ,  1.23059535, -0.0323182 ,  0.422409  ,
        1.60232246,  1.67770696,  1.04109704,  1.73341429,  1.49642992,
        1.09284806,  2.2433877 ,  1.72896564,  0.52982318])