In [2]:
import pyopencl as cl
import pyopencl.array as cl_array
import numpy as np
import numpy.linalg as la
import math
import torch
import torch.nn as nn
import torch.functional as F

In [3]:
%load_ext pyopencl.ipython_ext

In [4]:
ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx)
mf = cl.mem_flags

In [5]:
class LeNet(nn.Module):
    def __init__(self, num_classes):
        super(LeNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(6),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1 ),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.fclayer = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120,84),
            nn.ReLU(),
            nn.Linear(84,num_classes)
        ) 
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x=x.view(-1, 16*5*5)
        x=self.fclayer(x)
        return x

model = LeNet(10)
model.load_state_dict(torch.load('model.pth'))
layer1_Conv2d = model.layer1[0]

In [6]:
layer1_Conv2d.weight.shape,layer1_Conv2d.bias.shape

(torch.Size([6, 1, 5, 5]), torch.Size([6]))

In [7]:
def conv2d(input_numpy, kernel_weight_numpy, kernel_bias_numpy, padding = 0):
    B, Ci, Hi, Wi = input_numpy.shape
    input_pad_numpy = torch.zeros(B, Ci, Hi+2*padding, Wi+2*padding)
    if padding > 0:
        input_pad_numpy[:, :, padding:-padding, padding:-padding] = input_numpy
    else:
        input_pad_numpy = input_numpy
    B, Ci, Hi, Wi = input_pad_numpy.shape
    Co, Ci, Hf, Wf = kernel_weight_numpy.shape
    Ho, Wo = Hi - Hf + 1, Wi - Wf + 1
    # conv2d weight 7 loop
    out = np.zeros((B,Co,Ho,Wo))
    for b in range(B):
        for i in range(Ho):
            for j in range(Wo):
                for k in range(Co):
                    for l in range(Hf):
                        for m in range(Wf):
                            for n in range(Ci):
                                out[b,k,i,j] += input_pad_numpy[b,n,i+l,j+m]*kernel_weight_numpy[k,n,l,m]
    for b in range(B):
        for i in range(Ho):
            for j in range(Wo):
                for k in range(Co):
                    out[b,k,i,j] += kernel_bias_numpy[k]
    return out

In [17]:
input_ = torch.randn(1,1,28,28)

input = 2.0 * (np.random.rand(4,4,4).astype(np.float32) - 0.5)
weight = layer1_BatchNorm2d.weight.detach().numpy()
bias = layer1_BatchNorm2d.bias.detach().numpy()

In [18]:
input_gpu = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf = input)

mean_gpu = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf = mean)
std_gpu = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf = std)

eps_gpu = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf = np.int32(eps))
weight_gpu = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf = weight)
bias_gpu = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf = bias)

channel_gpu = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf = np.int32(input.shape[0]))
height_gpu = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf = np.int32(input.shape[1]))
width_gpu = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf = np.int32(input.shape[2]))

output_gpu = cl.Buffer(ctx, mf.WRITE_ONLY, input.nbytes)

In [67]:
%%cl_kernel -o "-cl-fast-relaxed-math"

__kernel void BatchNorm2D(__global const float *ift, 
        __global float *mean, __global float *std,
        __global float *eps, __global float *weight, __global float *bias,
        __global int *channel, __global int *height, __global int *width,
        __global float *oft)
{
    int c = *channel, h = *height, w = *width;
    int posc = get_global_id(0), posh = get_global_id(1), posw = get_global_id(2);
    int i = posc*(w*h) + (posh*w+posw);
    float e = *eps;
    float res1 = ift[i] - mean[posc];
    float res2 = sqrt(std[posc] + e);
    float res3 = res1 / res2;
    oft[i] = res3 * weight[posc] + bias[posc];
}

In [68]:
BatchNorm2D(queue, input.shape, None, input_gpu, mean_gpu, std_gpu, eps_gpu, weight_gpu, bias_gpu, channel_gpu, height_gpu, width_gpu, output_gpu)

<pyopencl._cl.Event at 0x2c6a0bce3a8>

In [69]:
result = np.empty_like(input)
cl.enqueue_copy(queue, result, output_gpu)

<pyopencl._cl.NannyEvent at 0x2c6a0bce8e8>

In [70]:
result

array([[[ 0.12930298, -0.54422164, -0.5404496 ,  0.56641155],
        [ 0.32734716, -0.04603638, -0.27638122,  0.12862907],
        [ 0.20284243, -0.24211188, -0.14477941, -0.3558718 ],
        [ 0.48549154, -0.09393339, -0.05353037, -0.36649898]],

       [[ 1.0770097 , -0.81170666,  0.5653544 , -1.1986618 ],
        [-0.01809002, -0.12676686,  0.05970845, -0.7814994 ],
        [ 1.0996418 , -0.02657328, -1.268249  , -1.7871226 ],
        [ 1.3546574 ,  0.05242227, -0.02483744, -0.40442872]],

       [[ 0.2779464 , -1.314091  , -0.18237348, -1.2827958 ],
        [-1.0967008 , -0.7878097 ,  0.36083177, -0.4661757 ],
        [-0.12390527,  0.06764246, -0.55066997, -1.2130908 ],
        [-0.5317401 ,  0.2721272 ,  0.1377413 ,  0.6034245 ]],

       [[-0.53256226, -0.6765078 ,  0.5455077 , -1.0082785 ],
        [ 0.945174  , -0.27320042, -0.51388526,  1.3549521 ],
        [ 0.37321118,  0.32849908,  0.76314795, -0.9714877 ],
        [-0.8785274 ,  0.19135866,  1.6811198 , -0.9413318 ]]],


In [71]:
batchnorm2d(input,eps,weight,bias)

array([[[ 0.12929948, -0.5442121 , -0.54044013,  0.56639959],
        [ 0.32733981, -0.04603649, -0.27637688,  0.12862558],
        [ 0.2028375 , -0.24210819, -0.14477761, -0.35586592],
        [ 0.48548119, -0.09393256, -0.05353033, -0.36649286]],

       [[ 1.07698299, -0.81169194,  0.56533896, -1.19863864],
        [-0.01809269, -0.12676716,  0.05970407, -0.78148535],
        [ 1.09961462, -0.02657576, -1.26822428, -1.78708647],
        [ 1.35462465,  0.05241806, -0.02483996, -0.40442291]],

       [[ 0.27793692, -1.31407694, -0.18237617, -1.28278217],
        [-1.09669   , -0.78780346,  0.36082103, -0.46617419],
        [-0.12390883,  0.06763607, -0.55066722, -1.21307823],
        [-0.53173765,  0.27211785,  0.13773388,  0.60341025]],

       [[-0.53255353, -0.67649676,  0.54549953, -1.00826225],
        [ 0.94515944, -0.27319575, -0.51387681,  1.35493121],
        [ 0.37320569,  0.32849431,  0.76313634, -0.97147209],
        [-0.87851317,  0.19135603,  1.68109375, -0.94131665]]])