In [None]:
# load needed modules
import time
from math import ceil

import matplotlib.pyplot as plt
import numba
import numpy as np
from numba import cuda


In [None]:
# Report the properties of the currently selected CUDA device.
gpu = numba.cuda.get_current_device()
print("name = %s" % gpu.name)

# (output format, device attribute name) pairs, listed in print order.
gpu_property_table = (
    ("maxThreadsPerBlock = %s", "MAX_THREADS_PER_BLOCK"),
    ("maxBlockDimX = %s", "MAX_BLOCK_DIM_X"),
    ("maxBlockDimY = %s", "MAX_BLOCK_DIM_Y"),
    ("maxBlockDimZ = %s", "MAX_BLOCK_DIM_Z"),
    ("maxGridDimX = %s", "MAX_GRID_DIM_X"),
    ("maxGridDimY = %s", "MAX_GRID_DIM_Y"),
    ("maxGridDimZ = %s", "MAX_GRID_DIM_Z"),
    ("maxSharedMemoryPerBlock = %s", "MAX_SHARED_MEMORY_PER_BLOCK"),
    ("asyncEngineCount = %s", "ASYNC_ENGINE_COUNT"),
    ("canMapHostMemory = %s", "CAN_MAP_HOST_MEMORY"),
    ("multiProcessorCount = %s", "MULTIPROCESSOR_COUNT"),
    ("warpSize = %s", "WARP_SIZE"),
    ("unifiedAddressing = %s", "UNIFIED_ADDRESSING"),
    ("pciBusID = %s", "PCI_BUS_ID"),
    ("pciDeviceID = %s", "PCI_DEVICE_ID"),
)
for property_format, attribute_name in gpu_property_table:
    print(property_format % str(getattr(gpu, attribute_name)))

### Define kernel

In [None]:
# ================================= median ================================== #

@cuda.jit
def gpu_median_zero_padding(input_data, output_data, stencil_z, stencil_y, stencil_x, Nz, Ny, Nx):
    """Median-filter a 3-D volume, treating samples outside the volume as zero.

    For every voxel ``(z, y, x)`` the kernel gathers the
    ``(2*stencil_z+1) x (2*stencil_y+1) x (2*stencil_x+1)`` window centred on
    it, substitutes ``0`` for every window position that falls outside the
    volume, and writes the median of the window to ``output_data[z, y, x]``.

    Parameters
    ----------
    input_data : 3-D array indexed as ``[z, y, x]``; only read.
    output_data : 3-D array indexed as ``[z, y, x]``; receives the medians.
    stencil_z, stencil_y, stencil_x : int
        Half-widths of the filter window along each axis.
    Nz, Ny, Nx : int
        Extents of the volume along each axis.

    Notes
    -----
    * The window is staged in a per-thread ``float32`` buffer, so values of a
      wider input dtype are rounded to ``float32`` before the median is taken.
    * The buffer holds ``10 * 8 * 10 = 800`` values. If the window is larger
      than that the kernel returns without writing anything rather than
      overrunning the buffer.
    * Every loop advances by ``blockDim * gridDim`` along its axis, so any
      launch configuration covers the whole volume.
    """
    # Per-thread scratch buffer for one filter window. The shape has to be a
    # compile-time constant:
    # https://stackoverflow.com/questions/48642481/what-is-the-correct-usage-of-cuda-local-array-in-numba
    sort_array = cuda.local.array(10 * 8 * 10, numba.float32)

    # Number of samples in the full (padded) window.
    n = (2 * stencil_x + 1) * (2 * stencil_y + 1) * (2 * stencil_z + 1)

    # Keep this limit in sync with the buffer size above.
    if n > 10 * 8 * 10:
        return

    # n is a product of three odd numbers and therefore odd, so the median is
    # the single middle element of the sorted window.
    half = n // 2

    x_init = cuda.threadIdx.x + cuda.blockIdx.x * cuda.blockDim.x
    y_init = cuda.threadIdx.y + cuda.blockIdx.y * cuda.blockDim.y
    z_init = cuda.threadIdx.z + cuda.blockIdx.z * cuda.blockDim.z

    x_stride = cuda.blockDim.x * cuda.gridDim.x
    y_stride = cuda.blockDim.y * cuda.gridDim.y
    z_stride = cuda.blockDim.z * cuda.gridDim.z

    # grid stride:
    for z_global in range(z_init, Nz, z_stride):
        z_min = max(z_global - stencil_z, 0)
        z_max = min(z_global + stencil_z + 1, Nz)
        for y_global in range(y_init, Ny, y_stride):
            y_min = max(y_global - stencil_y, 0)
            y_max = min(y_global + stencil_y + 1, Ny)
            for x_global in range(x_init, Nx, x_stride):
                x_min = max(x_global - stencil_x, 0)
                x_max = min(x_global + stencil_x + 1, Nx)

                # Copy the in-volume part of the window into the scratch buffer.
                m = 0
                for i in range(z_min, z_max):
                    for j in range(y_min, y_max):
                        for k in range(x_min, x_max):
                            sort_array[m] = input_data[i, j, k]
                            m += 1

                # Zero padding: the window positions outside the volume.
                for i in range(m, n):
                    sort_array[i] = 0.0

                # Partial selection sort: once positions 0..half hold their
                # final values the median is known, so stop there.
                for i in range(half + 1):
                    min_index = i
                    for j in range(i + 1, n):
                        if sort_array[j] < sort_array[min_index]:
                            min_index = j

                    tmp = sort_array[i]
                    sort_array[i] = sort_array[min_index]
                    sort_array[min_index] = tmp

                output_data[z_global, y_global, x_global] = sort_array[half]
                

        
def lauch_kernel(input_data, output_data, stencil_):
    """Run the zero-padded 3-D median filter kernel on ``input_data``.

    Parameters
    ----------
    input_data : 3-D array with shape ``(Nz, Ny, Nx)``.
    output_data : array of the same shape; passed to the kernel as the
        destination for the filtered values.
    stencil_ : sequence whose first three entries are the window half-widths
        ``(stencil_z, stencil_y, stencil_x)``.

    Raises
    ------
    ValueError
        If ``output_data`` does not have the shape of ``input_data``, if a
        half-width is negative, or if the filter window holds more samples
        than the kernel's per-thread buffer.
    """
    # The kernel indexes the arrays as [z, y, x], so the shape unpacks in
    # that order.
    Nz, Ny, Nx = np.shape(input_data)
    if np.shape(output_data) != (Nz, Ny, Nx):
        raise ValueError("output_data has shape %s, expected %s"
                         % (str(np.shape(output_data)), str((Nz, Ny, Nx))))

    stencil_z, stencil_y, stencil_x = stencil_[0], stencil_[1], stencil_[2]
    if stencil_z < 0 or stencil_y < 0 or stencil_x < 0:
        raise ValueError("stencil half-widths must be non-negative")

    # Must match the size of the local array allocated inside the kernel.
    max_window = 10 * 8 * 10
    window = (2 * stencil_z + 1) * (2 * stencil_y + 1) * (2 * stencil_x + 1)
    if window > max_window:
        raise ValueError("filter window has %d samples, the kernel buffer holds %d"
                         % (window, max_window))

    # define threads and blocks
    threads = (16, 8, 4)
    blocks = (8, 8, 8)
    print("==================")
    print("threads: " + str(threads))
    print("blocks: " + str(blocks))
    print("==================")

    # call CUDA kernel; the launch configuration is [blocks per grid, threads per block]
    gpu_median_zero_padding[blocks, threads](input_data, output_data,
                                             stencil_z, stencil_y, stencil_x,
                                             Nz, Ny, Nx)
    


### Generate dummy data

In [None]:
# Generate reproducible dummy data for the test: a sine profile along y
# (amplitude 0.5) with unit-variance Gaussian noise added to every voxel.

SEED = 42  # fixed seed so that every run filters the same volume
rng = np.random.default_rng(SEED)

Nz, Ny, Nx = 128, 512, 128

sine_profile = 0.5 * np.sin(np.linspace(0, 20, Ny))
noise = rng.normal(scale=1.0, size=(Nz, Ny, Nx))

# Broadcasting repeats the y-profile over the z and x axes.
real_data = sine_profile[np.newaxis, :, np.newaxis] + noise

In [None]:
# Visualize the data before filtering: the middle z-slice of the volume.
label_fontsize = 18
tick_fontsize = 14
middle_slice = real_data[Nz // 2]

plt.pcolormesh(middle_slice, vmin=-2, vmax=2, cmap=plt.cm.jet)

cb = plt.colorbar()
for tick_label in cb.ax.get_yticklabels():
    tick_label.set_fontsize(tick_fontsize)

plt.xlabel("x-index", fontsize=label_fontsize)
plt.ylabel("y-index", fontsize=label_fontsize)
plt.xticks(fontsize=tick_fontsize)
plt.yticks(fontsize=tick_fontsize)

plt.tight_layout()
plt.show()

In [None]:
# Result buffer for the filter: a separate array with the shape and dtype of
# the input, initialised with the input values.
output_data_ = np.array(real_data, copy=True)

### Run Kernel

In [None]:
# Half-widths of the filter window, ordered (z, y, x).
stencil_t = np.array([1, 2, 1])

# The first call includes the JIT compilation of the kernel, the second one
# reuses the compiled kernel. perf_counter() is the monotonic,
# high-resolution clock intended for measuring intervals.
for run_label in ("with compilation", "without compilation"):
    start = time.perf_counter()
    lauch_kernel(real_data, output_data_, stencil_t)
    end = time.perf_counter()
    print("Elapsed (gpu naive %s) = %s" % (run_label, end - start))

In [None]:
# Visualize the data after filtering: the middle z-slice of the result.
label_fontsize = 18
tick_fontsize = 14
filtered_slice = output_data_[Nz // 2]

plt.pcolormesh(filtered_slice, vmin=-2, vmax=2, cmap=plt.cm.jet)

cb = plt.colorbar()
for tick_label in cb.ax.get_yticklabels():
    tick_label.set_fontsize(tick_fontsize)

plt.xlabel("x-index", fontsize=label_fontsize)
plt.ylabel("y-index", fontsize=label_fontsize)
plt.xticks(fontsize=tick_fontsize)
plt.yticks(fontsize=tick_fontsize)

plt.tight_layout()
plt.show()