Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 26 additions & 8 deletions dpnp/backend/kernels/dpnp_krnl_random.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,27 +108,45 @@ INP_DLLEXPORT void dpnp_rng_beta_c(void* result, const _DataType a, const _DataT
template <typename _DataType>
void dpnp_rng_binomial_c(void* result, const int ntrial, const double p, const size_t size)
{
if (result == nullptr)
{
return;
}
if (!size)
{
return;
}
_DataType* result1 = reinterpret_cast<_DataType*>(result);

if (dpnp_queue_is_cpu_c())
if (ntrial == 0 || p == 0)
{
mkl_rng::binomial<_DataType> distribution(ntrial, p);
// perform generation
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
event_out.wait();
dpnp_zeros_c<_DataType>(result, size);
}
else if (p == 1)
{
_DataType* fill_value = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(sizeof(_DataType)));
fill_value[0] = static_cast<_DataType>(ntrial);
dpnp_initval_c<_DataType>(result, fill_value, size);
dpnp_memory_free_c(fill_value);
}
else
{
int errcode = viRngBinomial(VSL_RNG_METHOD_BINOMIAL_BTPE, get_rng_stream(), size, result1, ntrial, p);
if (errcode != VSL_STATUS_OK)
if (dpnp_queue_is_cpu_c())
{
throw std::runtime_error("DPNP RNG Error: dpnp_rng_binomial_c() failed.");
mkl_rng::binomial<_DataType> distribution(ntrial, p);
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
event_out.wait();
}
else
{
int errcode = viRngBinomial(VSL_RNG_METHOD_BINOMIAL_BTPE, get_rng_stream(), size, result1, ntrial, p);
if (errcode != VSL_STATUS_OK)
{
throw std::runtime_error("DPNP RNG Error: dpnp_rng_binomial_c() failed.");
}
}
}
return;
}

template <typename _DataType>
Expand Down
27 changes: 10 additions & 17 deletions dpnp/random/dpnp_algo_random.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -177,26 +177,19 @@ cpdef dparray dpnp_rng_binomial(int ntrial, double p, size):
cdef DPNPFuncData kernel_data
cdef fptr_dpnp_rng_binomial_c_1out_t func

if ntrial == 0 or p == 0.0:
result = dparray(size, dtype=dtype)
result.fill(0.0)
elif p == 1.0:
result = dparray(size, dtype=dtype)
result.fill(ntrial)
else:
# convert string type names (dparray.dtype) to C enum DPNPFuncType
param1_type = dpnp_dtype_to_DPNPFuncType(dtype)
# convert string type names (dparray.dtype) to C enum DPNPFuncType
param1_type = dpnp_dtype_to_DPNPFuncType(dtype)

# get the FPTR data structure
kernel_data = get_dpnp_function_ptr(DPNP_FN_RNG_BINOMIAL, param1_type, param1_type)
# get the FPTR data structure
kernel_data = get_dpnp_function_ptr(DPNP_FN_RNG_BINOMIAL, param1_type, param1_type)

result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
# ceate result array with type given by FPTR data
result = dparray(size, dtype=result_type)
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
# ceate result array with type given by FPTR data
result = dparray(size, dtype=result_type)

func = <fptr_dpnp_rng_binomial_c_1out_t > kernel_data.ptr
# call FPTR function
func(result.get_data(), ntrial, p, result.size)
func = <fptr_dpnp_rng_binomial_c_1out_t > kernel_data.ptr
# call FPTR function
func(result.get_data(), ntrial, p, result.size)

return result

Expand Down