Skip to content
Merged
25 changes: 14 additions & 11 deletions numba_dppy/dppy_host_fn_call_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _init_llvm_types_and_constants(self):
self.byte_ptr_t = lc.Type.pointer(self.byte_t)
self.byte_ptr_ptr_t = lc.Type.pointer(self.byte_ptr_t)
self.intp_t = self.context.get_value_type(types.intp)
self.long_t = self.context.get_value_type(types.int64)
self.int64_t = self.context.get_value_type(types.int64)
self.int32_t = self.context.get_value_type(types.int32)
self.int32_ptr_t = lc.Type.pointer(self.int32_t)
self.uintp_t = self.context.get_value_type(types.uintp)
Expand Down Expand Up @@ -113,23 +113,26 @@ def allocate_kenrel_arg_array(self, num_kernel_args):


def resolve_and_return_dpctl_type(self, ty):
"""This function looks up the dpctl defined enum values from DPCTLKernelArgType.
"""

val = None
if ty == types.int32 or isinstance(ty, types.scalars.IntegerLiteral):
val = self.context.get_constant(types.int32, 4)
val = self.context.get_constant(types.int32, 9) # DPCTL_LONG_LONG
elif ty == types.uint32:
val = self.context.get_constant(types.int32, 5)
val = self.context.get_constant(types.int32, 10) # DPCTL_UNSIGNED_LONG_LONG
elif ty == types.boolean:
val = self.context.get_constant(types.int32, 5)
val = self.context.get_constant(types.int32, 5) # DPCTL_UNSIGNED_INT
elif ty == types.int64:
val = self.context.get_constant(types.int32, 7)
val = self.context.get_constant(types.int32, 9) # DPCTL_LONG_LONG
elif ty == types.uint64:
val = self.context.get_constant(types.int32, 8)
val = self.context.get_constant(types.int32, 11) # DPCTL_SIZE_T
elif ty == types.float32:
val = self.context.get_constant(types.int32, 12)
val = self.context.get_constant(types.int32, 12) # DPCTL_FLOAT
elif ty == types.float64:
val = self.context.get_constant(types.int32, 13)
val = self.context.get_constant(types.int32, 13) # DPCTL_DOUBLE
elif ty == types.voidptr:
val = self.context.get_constant(types.int32, 15)
val = self.context.get_constant(types.int32, 15) # DPCTL_VOID_PTR
else:
raise NotImplementedError

Expand All @@ -151,12 +154,12 @@ def process_kernel_arg(self, var, llvm_arg, arg_type, gu_sig, val_type, index, m
if llvm_arg is None:
raise NotImplementedError(arg_type, var)

storage = cgutils.alloca_once(self.builder, self.long_t)
storage = cgutils.alloca_once(self.builder, self.int64_t)
self.builder.store(self.context.get_constant(types.int64, 0), storage)
ty = self.resolve_and_return_dpctl_type(types.int64)
self.form_kernel_arg_and_arg_ty(self.builder.bitcast(storage, self.void_ptr_t), ty)

storage = cgutils.alloca_once(self.builder, self.long_t)
storage = cgutils.alloca_once(self.builder, self.int64_t)
self.builder.store(self.context.get_constant(types.int64, 0), storage)
ty = self.resolve_and_return_dpctl_type(types.int64)
self.form_kernel_arg_and_arg_ty(self.builder.bitcast(storage, self.void_ptr_t), ty)
Expand Down
16 changes: 11 additions & 5 deletions numba_dppy/examples/pa_examples/test1-2d.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
from numba import njit, gdb
import numpy as np
import dpctl

@njit(parallel={'offload':True})

@njit
def f1(a, b):
c = a + b
return c


N = 1000
print("N", N)

a = np.ones((N,N), dtype=np.float32)
b = np.ones((N,N), dtype=np.float32)
a = np.ones((N, N), dtype=np.float32)
b = np.ones((N, N), dtype=np.float32)

print("a:", a, hex(a.ctypes.data))
print("b:", b, hex(b.ctypes.data))
c = f1(a,b)

with dpctl.device_context("opencl:gpu:0"):
c = f1(a, b)

print("BIG RESULT c:", c, hex(c.ctypes.data))
for i in range(N):
for j in range(N):
if c[i,j] != 2.0:
if c[i, j] != 2.0:
print("First index not equal to 2.0 was", i, j)
break
16 changes: 11 additions & 5 deletions numba_dppy/examples/pa_examples/test1-3d.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,30 @@
from numba import njit, gdb
import numpy as np
import dpctl

@njit(parallel={'offload':True})

@njit
def f1(a, b):
c = a + b
return c


N = 10
print("N", N)

a = np.ones((N,N,N), dtype=np.float32)
b = np.ones((N,N,N), dtype=np.float32)
a = np.ones((N, N, N), dtype=np.float32)
b = np.ones((N, N, N), dtype=np.float32)

print("a:", a, hex(a.ctypes.data))
print("b:", b, hex(b.ctypes.data))
c = f1(a,b)

with dpctl.device_context("opencl:gpu:0"):
c = f1(a, b)

print("BIG RESULT c:", c, hex(c.ctypes.data))
for i in range(N):
for j in range(N):
for k in range(N):
if c[i,j,k] != 2.0:
if c[i, j, k] != 2.0:
print("First index not equal to 2.0 was", i, j, k)
break
16 changes: 11 additions & 5 deletions numba_dppy/examples/pa_examples/test1-4d.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,31 @@
from numba import njit, gdb
import numpy as np
import dpctl

@njit(parallel={'offload':True})

@njit
def f1(a, b):
c = a + b
return c


N = 10
print("N", N)

a = np.ones((N,N,N,N), dtype=np.float32)
b = np.ones((N,N,N,N), dtype=np.float32)
a = np.ones((N, N, N, N), dtype=np.float32)
b = np.ones((N, N, N, N), dtype=np.float32)

print("a:", a, hex(a.ctypes.data))
print("b:", b, hex(b.ctypes.data))
c = f1(a,b)

with dpctl.device_context("opencl:gpu:0"):
c = f1(a, b)

print("BIG RESULT c:", c, hex(c.ctypes.data))
for i in range(N):
for j in range(N):
for k in range(N):
for l in range(N):
if c[i,j,k,l] != 2.0:
if c[i, j, k, l] != 2.0:
print("First index not equal to 2.0 was", i, j, k, l)
break
16 changes: 11 additions & 5 deletions numba_dppy/examples/pa_examples/test1-5d.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,32 @@
from numba import njit, gdb
import numpy as np
import dpctl

@njit(parallel={'offload':True})

@njit
def f1(a, b):
c = a + b
return c


N = 5
print("N", N)

a = np.ones((N,N,N,N,N), dtype=np.float32)
b = np.ones((N,N,N,N,N), dtype=np.float32)
a = np.ones((N, N, N, N, N), dtype=np.float32)
b = np.ones((N, N, N, N, N), dtype=np.float32)

print("a:", a, hex(a.ctypes.data))
print("b:", b, hex(b.ctypes.data))
c = f1(a,b)

with dpctl.device_context("opencl:gpu:0"):
c = f1(a, b)

print("BIG RESULT c:", c, hex(c.ctypes.data))
for i in range(N):
for j in range(N):
for k in range(N):
for l in range(N):
for m in range(N):
if c[i,j,k,l,m] != 2.0:
if c[i, j, k, l, m] != 2.0:
print("First index not equal to 2.0 was", i, j, k, l, m)
break
8 changes: 6 additions & 2 deletions numba_dppy/examples/pa_examples/test1.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from numba import njit
import numpy as np
import dpctl


@njit(parallel={'offload':True})
@njit
def f1(a, b):
c = a + b
return c
Expand All @@ -19,7 +20,10 @@ def main():

print("a:", a, hex(a.ctypes.data))
print("b:", b, hex(b.ctypes.data))
c = f1(a,b)

with dpctl.device_context("opencl:gpu:0"):
c = f1(a, b)

print("RESULT c:", c, hex(c.ctypes.data))
for i in range(N):
if c[i] != 2.0:
Expand Down
Loading