In [1]:
import ctypes
import numpy as np

In [2]:
def init_kernel_bias(num_inp_channels, kernel_size, num_kernels,mean=0,std=0.01):
    """Sample Gaussian-initialised convolution kernels and their biases.

    Args:
        num_inp_channels: depth of the input that each kernel spans.
        kernel_size: side length of the (square) kernel.
        num_kernels: number of kernels.
        mean: mean of the normal distribution (default 0).
        std: standard deviation of the normal distribution (default 0.01).

    Returns:
        (weights, bias): float32, Fortran-ordered arrays of shape
        [num_inp_channels, kernel_size, kernel_size, num_kernels] and
        [1, num_kernels] respectively.
    """
    # Kernels are drawn before biases, so a seeded global RNG always yields
    # the same pair of arrays.
    kernel_draw = np.random.randn(num_inp_channels, kernel_size, kernel_size, num_kernels)
    bias_draw = np.random.randn(1, num_kernels)
    # No fan-in scaling (division by sqrt(num_inp_channels)) is applied.
    sampled = (mean + std*kernel_draw, mean + std*bias_draw)
    weights, bias = (np.asfortranarray(arr.astype(np.float32)) for arr in sampled)
    return weights, bias

In [3]:
# First-layer parameters: w0[d,ksz,ksz,num_ker] and b0[1,num_ker].
w0,b0=init_kernel_bias(
    num_inp_channels=32,
    kernel_size=3,
    num_kernels=64,
)

In [4]:
# Synthetic input inp[batches,row,col,d]: every element holds its own C-order
# flat index, which makes gathered values easy to trace back to a position.
inp=np.reshape(np.arange(130*32*32*32).astype(np.float32),(130,32,32,32))

In [5]:
#inp[batches,row,col,d],w0(d,ksz,ksz,num_ker),b0[1,num_ker],stride[row,col]
padding=0
stride=[1,1]
output=[]
ipp=np.transpose(inp,(0,3,1,2))  #ipp[batches,d,row,col]
ksz,num_ker=w0.shape[1],w0.shape[3]
# A falsy padding (0) selects the half-kernel padding below.
# Padding must be mirrored in backprop; even kernel sizes are not supported yet.
if not padding:
    padding=(ksz-1)//2
batches,d,row,col=ipp.shape
# Output size is computed from the unpadded extent ...
out_row=(row-ksz+2*padding)//stride[0]+1
out_col=(col-ksz+2*padding)//stride[1]+1
# ... after which row/col are grown to the padded extent.
row+=2*padding
col+=2*padding

In [6]:
%%time
# Zero-padded copy of ipp: padded[batches,d,row,col], where row/col already
# include 2*padding.
padded=np.zeros((batches,d,row,col),dtype=np.float32)
# Explicit end indices are used instead of `padding:-padding`: when padding is 0
# (e.g. ksz of 1 or 2) the slice `0:-0` is empty and the assignment would fail.
padded[:,:,padding:row-padding,padding:col-padding]=ipp

CPU times: user 12.8 ms, sys: 3.32 ms, total: 16.1 ms
Wall time: 14.7 ms


In [7]:
# %%timeit
# Flat (C-order) gather indices into ONE padded image img[d,row,col].
# Element (ch,r,c) sits at ch*row*col + r*col + c, so moving down one image row
# advances by `col` elements. (The row stride was previously written as `row`,
# which is only correct when the padded image is square.)
# window: offsets of one receptive field, ordered (channel, kernel row, kernel col).
window=(np.arange(ksz)[:,None]*col+np.arange(ksz)).ravel()+np.arange(d)[:,None]*row*col
# slider: offset of the top-left corner of every candidate window position.
slider=(np.arange(out_row*stride[0])[:,None]*col+np.arange(out_col*stride[1]))
# ind[out_row*out_col, d*ksz*ksz]: one row of gather indices per output position.
ind = window.ravel()+slider[::stride[0],::stride[1]].ravel()[:,None]
# bind= np.arange(batches)[:,None]*d*row*col+ind.ravel()
# kern[d*ksz*ksz, num_ker]: kernels flattened with Fortran index order.
kern = w0.reshape(-1,num_ker,order='F')
# output=(np.dot(np.take(padded, bind).reshape(-1,d*ksz*ksz), kern)).reshape(batches,out_row,out_col,num_ker)

In [8]:
%%time
# NumPy reference result: checker[batches, out_row*out_col, d*ksz*ksz].
checker=np.empty((batches,)+ind.shape,dtype=np.float32)#,order='F')
for b in range(batches):
    # padded[b] is img[d,row,col]; gather every window of this image in one take.
    # windows(out_row*out_col, ksz*ksz*d) . kernels(d*ksz*ksz,num_ker)
    checker[b]=np.take(padded[b],ind)

CPU times: user 74.7 ms, sys: 10.1 ms, total: 84.8 ms
Wall time: 83.5 ms


In [9]:
%%time
# Uninitialised Fortran-ordered buffer, viewed as 2-D
# (batches*out_row*out_col, d*ksz*ksz); it is filled later by the C routine.
coled=np.reshape(
    np.empty((batches,)+ind.shape,dtype=np.float32,order='F'),
    (-1,d*ksz*ksz),
    order='A',
)

CPU times: user 32 µs, sys: 3 µs, total: 35 µs
Wall time: 38.4 µs


In [10]:
# Expected: (batches*out_row*out_col, d*ksz*ksz).
coled.shape

(133120, 288)

In [11]:
# Load the native gather helper.
# NOTE(review): "libctake.so" is given without a path, so it must be findable by
# the dynamic loader; its source is not part of this notebook.
ctake=ctypes.CDLL("libctake.so")

In [12]:
%%timeit
# NOTE(review): no argtypes are declared for `take`; the parameter meanings below
# are read off the values passed here, not from the C source - confirm.
ctake.take(
    ctypes.c_void_p(padded.ctypes.data),   # address of padded's buffer
    ctypes.c_void_p(ind.ctypes.data),      # address of ind's buffer (NumPy default integer dtype - C side must read that width; confirm)
    ctypes.c_void_p(coled.ctypes.data),    # address of coled's buffer (destination, presumably)
    ctypes.c_int(batches),                 # number of images
    ctypes.c_int(padded[0].size),          # elements in one padded image
    ctypes.c_int(ind.size),                # number of gather indices per image
    ctypes.c_int(coled.shape[0]),          # rows of coled
    ctypes.c_int(coled.shape[1]),          # columns of coled
    ord('F'),                              # integer code of 'F' (presumably an order flag)
    ctypes.c_int(4),                       # NOTE(review): meaning of 4 not visible here - confirm
)

203 ms ± 7.79 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [13]:
# Shapes side by side: padded[batches,d,row,col], coled[batches*windows, d*ksz*ksz],
# ind[windows, d*ksz*ksz] with windows = out_row*out_col.
padded.shape,coled.shape,ind.shape              #make coled=(131072, 288)

((130, 32, 34, 34), (133120, 288), (1024, 288))

In [43]:
# NOTE(review): this cell reads `a`, which in the visible notebook is only assigned
# in the NEXT cell (execution count 42 vs 43 here). On a fresh kernel run top to
# bottom this line raises NameError; move it below the cell that defines `a`.
%time a=np.asfortranarray(a)

CPU times: user 11 µs, sys: 0 ns, total: 11 µs
Wall time: 13.4 µs


In [42]:
%%time
# Fortran-index-order reshape of the C-ordered checker: the first (batch) index
# varies fastest, so consecutive rows of `a` come from consecutive batches.
a=np.reshape(checker,(-1,d*ksz*ksz),order='F')

CPU times: user 585 ms, sys: 16.6 ms, total: 601 ms
Wall time: 601 ms


In [44]:
# Integer display for readability. Consecutive rows differ by 32*32*32 = 32768
# (one whole input image), i.e. rows are interleaved across batches.
a.astype(int)

array([[      0,       0,       0, ...,       0,    1055,    1087],
       [      0,       0,       0, ...,       0,   33823,   33855],
       [      0,       0,       0, ...,       0,   66591,   66623],
       ...,
       [4193216, 4193248,       0, ...,       0,       0,       0],
       [4225984, 4226016,       0, ...,       0,       0,       0],
       [4258752, 4258784,       0, ...,       0,       0,       0]])

In [52]:
# Reshaping to the array's own shape leaves the element arrangement unchanged,
# so this shows the same values as a.astype(int).
a.reshape(a.shape,order='F').astype(int)

array([[      0,       0,       0, ...,       0,    1055,    1087],
       [      0,       0,       0, ...,       0,   33823,   33855],
       [      0,       0,       0, ...,       0,   66591,   66623],
       ...,
       [4193216, 4193248,       0, ...,       0,       0,       0],
       [4225984, 4226016,       0, ...,       0,       0,       0],
       [4258752, 4258784,       0, ...,       0,       0,       0]])

In [53]:
# Unlike `a`, consecutive rows here are neighbouring windows of the SAME image
# (each row is shifted by one column: ...1055, 1087 -> ...1055, 1087, 1119).
coled.astype(int)

array([[      0,       0,       0, ...,       0,    1055,    1087],
       [      0,       0,       0, ...,    1055,    1087,    1119],
       [      0,       0,       0, ...,    1087,    1119,    1151],
       ...,
       [4258688, 4258720, 4258752, ...,       0,       0,       0],
       [4258720, 4258752, 4258784, ...,       0,       0,       0],
       [4258752, 4258784,       0, ...,       0,       0,       0]])

In [34]:
# Element-wise comparison of the two layouts. With the `a` and `coled` displayed
# in the neighbouring cells the rows are ordered differently (batch-interleaved
# vs. batch-major), so they are not equal.
# NOTE(review): the recorded result has execution count 34, lower than the cell
# that assigns `a` (42), so it may have been computed against an older `a`.
(a==coled).all()

False

In [38]:
%%time
# C-order reshape that splits coled's leading axis into (batches, windows):
# cold[b] is the window matrix of image b, cold[b,p,k] == coled[b*windows+p, k].
cold=np.reshape(coled,(batches,)+ind.shape)

CPU times: user 10 µs, sys: 0 ns, total: 10 µs
Wall time: 12.4 µs


In [15]:
%%time
# Print the index of every batch whose C-library result differs from the NumPy
# reference; no output means all batches match.
for i in range(batches):
    if np.array_equal(cold[i],checker[i]):
        continue
    print(i,end=',')

CPU times: user 150 ms, sys: 3.1 ms, total: 153 ms
Wall time: 155 ms


In [16]:
%%time
# Same shape display as the earlier `coled.shape` cell; the timing above it only
# measures the attribute lookup.
coled.shape

CPU times: user 6 µs, sys: 0 ns, total: 6 µs
Wall time: 9.3 µs


(133120, 288)

In [17]:
# Fortran-order ravel walks down column 0 first: the first window element
# (channel 0, kernel row 0, kernel col 0) for successive window positions.
# The leading zeros are padding cells.
coled.ravel(order='F')[:100].astype(int)

array([   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,   32,   64,   96,  128,  160,  192,  224,  256,  288,  320,
        352,  384,  416,  448,  480,  512,  544,  576,  608,  640,  672,
        704,  736,  768,  800,  832,  864,  896,  928,  960,    0, 1024,
       1056, 1088, 1120, 1152, 1184, 1216, 1248, 1280, 1312, 1344, 1376,
       1408, 1440, 1472, 1504, 1536, 1568, 1600, 1632, 1664, 1696, 1728,
       1760, 1792, 1824, 1856, 1888, 1920, 1952, 1984,    0, 2048, 2080,
       2112])

In [18]:
# Same view of the NumPy reference: C-order reshape to coled's 2-D shape, then the
# first 100 entries of column 0 via an F-order ravel. Should match the coled output.
# NOTE(review): the ~0.3 s is presumably the F-order ravel copying the whole array
# just to show 100 values - confirm.
%time checker.reshape(coled.shape).ravel(order='F')[:100].astype(int)

CPU times: user 312 ms, sys: 13.3 ms, total: 325 ms
Wall time: 324 ms


array([   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,   32,   64,   96,  128,  160,  192,  224,  256,  288,  320,
        352,  384,  416,  448,  480,  512,  544,  576,  608,  640,  672,
        704,  736,  768,  800,  832,  864,  896,  928,  960,    0, 1024,
       1056, 1088, 1120, 1152, 1184, 1216, 1248, 1280, 1312, 1344, 1376,
       1408, 1440, 1472, 1504, 1536, 1568, 1600, 1632, 1664, 1696, 1728,
       1760, 1792, 1824, 1856, 1888, 1920, 1952, 1984,    0, 2048, 2080,
       2112])

In [19]:
# order='K' reads elements in the order they occur in memory; the values match
# the F-order ravel of coled shown earlier.
cold.ravel(order='K')[:100]

array([   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
          0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
          0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
          0.,    0.,    0.,    0.,    0.,    0.,    0.,   32.,   64.,
         96.,  128.,  160.,  192.,  224.,  256.,  288.,  320.,  352.,
        384.,  416.,  448.,  480.,  512.,  544.,  576.,  608.,  640.,
        672.,  704.,  736.,  768.,  800.,  832.,  864.,  896.,  928.,
        960.,    0., 1024., 1056., 1088., 1120., 1152., 1184., 1216.,
       1248., 1280., 1312., 1344., 1376., 1408., 1440., 1472., 1504.,
       1536., 1568., 1600., 1632., 1664., 1696., 1728., 1760., 1792.,
       1824., 1856., 1888., 1920., 1952., 1984.,    0., 2048., 2080.,
       2112.], dtype=float32)

In [23]:
# F-order ravel of the 3-D checker varies the batch index fastest, so this lists
# checker[0..99, 0, 0]: the top-left padding cell of each image, hence all zeros.
checker.ravel(order='F')[:100].astype(int)

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [21]:
# Compare both results in logical C (row-major) index order, independent of how
# each array is laid out in memory; True means coled[b*windows+p, k] == checker[b,p,k].
(coled.ravel(order='C')==checker.ravel(order='C')).all()

True

In [22]:
# Trace one input value: 11872 == (11*32 + 19)*32 + 0, i.e. inp[0,11,19,0].
# It appears in ksz*ksz = 9 windows, at channel-0 columns 8..0 as the window
# slides across it.
np.where(coled==11872)

(array([338, 339, 340, 370, 371, 372, 402, 403, 404]),
 array([8, 7, 6, 5, 4, 3, 2, 1, 0]))