Skip to content

Commit

Permalink
Hide Array<T> constructor. Use createStridedArray instead.
Browse files Browse the repository at this point in the history
  • Loading branch information
umar456 committed Nov 21, 2018
1 parent 3c03c16 commit db43f3f
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 28 deletions.
34 changes: 21 additions & 13 deletions src/api/c/internal.cpp
Expand Up @@ -18,7 +18,15 @@
#include <common/err_common.hpp>
#include <cstring>

using namespace detail;
using af::dim4;
using detail::cdouble;
using detail::cfloat;
using detail::createStridedArray;
using detail::intl;
using detail::uchar;
using detail::uint;
using detail::uintl;
using detail::ushort;

af_err af_create_strided_array(af_array *arr,
const void *data,
Expand Down Expand Up @@ -54,18 +62,18 @@ af_err af_create_strided_array(af_array *arr,
AF_CHECK(af_init());

switch (ty) {
case f32: res = getHandle(Array<float >(dims, strides, offset, (float *)data, isdev)); break;
case f64: res = getHandle(Array<double >(dims, strides, offset, (double *)data, isdev)); break;
case c32: res = getHandle(Array<cfloat >(dims, strides, offset, (cfloat *)data, isdev)); break;
case c64: res = getHandle(Array<cdouble>(dims, strides, offset, (cdouble *)data, isdev)); break;
case u32: res = getHandle(Array<uint >(dims, strides, offset, (uint *)data, isdev)); break;
case s32: res = getHandle(Array<int >(dims, strides, offset, (int *)data, isdev)); break;
case u64: res = getHandle(Array<uintl >(dims, strides, offset, (uintl *)data, isdev)); break;
case s64: res = getHandle(Array<intl >(dims, strides, offset, (intl *)data, isdev)); break;
case u16: res = getHandle(Array<ushort >(dims, strides, offset, (ushort *)data, isdev)); break;
case s16: res = getHandle(Array<short >(dims, strides, offset, (short *)data, isdev)); break;
case b8 : res = getHandle(Array<char >(dims, strides, offset, (char *)data, isdev)); break;
case u8 : res = getHandle(Array<uchar >(dims, strides, offset, (uchar *)data, isdev)); break;
case f32: res = getHandle(createStridedArray<float >(dims, strides, offset, (float *)data, isdev)); break;
case f64: res = getHandle(createStridedArray<double >(dims, strides, offset, (double *)data, isdev)); break;
case c32: res = getHandle(createStridedArray<cfloat >(dims, strides, offset, (cfloat *)data, isdev)); break;
case c64: res = getHandle(createStridedArray<cdouble>(dims, strides, offset, (cdouble *)data, isdev)); break;
case u32: res = getHandle(createStridedArray<uint >(dims, strides, offset, (uint *)data, isdev)); break;
case s32: res = getHandle(createStridedArray<int >(dims, strides, offset, (int *)data, isdev)); break;
case u64: res = getHandle(createStridedArray<uintl >(dims, strides, offset, (uintl *)data, isdev)); break;
case s64: res = getHandle(createStridedArray<intl >(dims, strides, offset, (intl *)data, isdev)); break;
case u16: res = getHandle(createStridedArray<ushort >(dims, strides, offset, (ushort *)data, isdev)); break;
case s16: res = getHandle(createStridedArray<short >(dims, strides, offset, (short *)data, isdev)); break;
case b8 : res = getHandle(createStridedArray<char >(dims, strides, offset, (char *)data, isdev)); break;
case u8 : res = getHandle(createStridedArray<uchar >(dims, strides, offset, (uchar *)data, isdev)); break;
default: TYPE_ERROR(6, ty);
}

Expand Down
14 changes: 11 additions & 3 deletions src/backend/cpu/Array.hpp
Expand Up @@ -58,6 +58,12 @@ namespace cpu
template<typename T>
Array<T> createDeviceDataArray(const af::dim4 &size, const void *data);

template<typename T>
Array<T> createStridedArray(af::dim4 dims, af::dim4 strides, dim_t offset,
const T * const in_data, bool is_device) {
return Array<T>(dims, strides, offset, in_data, is_device);
}

/// Copies data to an existing Array object from a host pointer
template<typename T>
void writeHostDataArray(Array<T> &arr, const T * const data, const size_t bytes);
Expand Down Expand Up @@ -107,7 +113,6 @@ namespace cpu
class Array
{
ArrayInfo info; // Must be the first element of Array<T>
//TODO: Generator based array

//data if parent. empty if child
std::shared_ptr<T> data;
Expand All @@ -123,11 +128,11 @@ namespace cpu
explicit Array(dim4 dims, const T * const in_data, bool is_device, bool copy_device=false);
Array(const Array<T>& parnt, const dim4 &dims, const dim_t &offset, const dim4 &stride);
explicit Array(af::dim4 dims, jit::Node_ptr n);

public:
Array(af::dim4 dims, af::dim4 strides, dim_t offset,
const T * const in_data, bool is_device = false);

public:

void resetInfo(const af::dim4& dims) { info.resetInfo(dims); }
void resetDims(const af::dim4& dims) { info.resetDims(dims); }
void modDims(const af::dim4 &newDims) { info.modDims(newDims); }
Expand Down Expand Up @@ -238,6 +243,9 @@ namespace cpu
friend Array<T> createValueArray<T>(const af::dim4 &size, const T& value);
friend Array<T> createHostDataArray<T>(const af::dim4 &size, const T * const data);
friend Array<T> createDeviceDataArray<T>(const af::dim4 &size, const void *data);
friend Array<T> createStridedArray<T>(af::dim4 dims, af::dim4 strides, dim_t offset,
const T * const in_data, bool is_device);


friend Array<T> *initArray<T>();
friend Array<T> createEmptyArray<T>(const af::dim4 &size);
Expand Down
12 changes: 7 additions & 5 deletions src/backend/cpu/types.hpp
Expand Up @@ -12,9 +12,11 @@

namespace cpu
{
typedef std::complex<float> cfloat;
typedef std::complex<double> cdouble;
typedef unsigned int uint;
typedef unsigned char uchar;
typedef unsigned short ushort;
using cdouble = std::complex<double>;
using cfloat = std::complex<float>;
using intl = long long;
using uint = unsigned int;
using uchar = unsigned char;
using uintl = unsigned long long;
using ushort = unsigned short;
}
8 changes: 8 additions & 0 deletions src/backend/cuda/Array.hpp
Expand Up @@ -48,6 +48,12 @@ namespace cuda
template<typename T>
Array<T> createDeviceDataArray(const af::dim4 &size, const void *data);

template<typename T>
Array<T> createStridedArray(af::dim4 dims, af::dim4 strides, dim_t offset,
const T * const in_data, bool is_device) {
return Array<T>(dims, strides, offset, in_data, is_device);
}

/// Copies data to an existing Array object from a host pointer
template<typename T>
void writeHostDataArray(Array<T> &arr, const T * const data, const size_t bytes);
Expand Down Expand Up @@ -231,6 +237,8 @@ namespace cuda
friend Array<T> createValueArray<T>(const af::dim4 &size, const T& value);
friend Array<T> createHostDataArray<T>(const af::dim4 &size, const T * const data);
friend Array<T> createDeviceDataArray<T>(const af::dim4 &size, const void *data);
friend Array<T> createStridedArray<T>(af::dim4 dims, af::dim4 strides, dim_t offset,
const T * const in_data, bool is_device);

friend Array<T> *initArray<T>();
friend Array<T> createEmptyArray<T>(const af::dim4 &size);
Expand Down
4 changes: 2 additions & 2 deletions src/backend/cuda/types.hpp
Expand Up @@ -17,8 +17,8 @@ using cdouble = cuDoubleComplex;
using cfloat = cuFloatComplex;
using uchar = unsigned char;
using uint = unsigned int;
// using intl = long long ; // defined in af/defines.h
// using uintl = unsigned long long; // defined in af/defines.h
using intl = long long ;
using uintl = unsigned long long;
using ushort = unsigned short;

namespace {
Expand Down
8 changes: 8 additions & 0 deletions src/backend/opencl/Array.hpp
Expand Up @@ -48,6 +48,12 @@ namespace opencl
template<typename T>
Array<T> createDeviceDataArray(const af::dim4 &size, const void *data, bool copy = false);

template<typename T>
Array<T> createStridedArray(af::dim4 dims, af::dim4 strides, dim_t offset,
const T * const in_data, bool is_device) {
return Array<T>(dims, strides, offset, in_data, is_device);
}

/// Copies data to an existing Array object from a host pointer
template<typename T>
void writeHostDataArray(Array<T> &arr, const T * const data, const size_t bytes);
Expand Down Expand Up @@ -276,6 +282,8 @@ namespace opencl
friend Array<T> createValueArray<T>(const af::dim4 &size, const T& value);
friend Array<T> createHostDataArray<T>(const af::dim4 &size, const T * const data);
friend Array<T> createDeviceDataArray<T>(const af::dim4 &size, const void *data, bool copy);
friend Array<T> createStridedArray<T>(af::dim4 dims, af::dim4 strides, dim_t offset,
const T * const in_data, bool is_device);

friend Array<T> *initArray<T>();
friend Array<T> createEmptyArray<T>(const af::dim4 &size);
Expand Down
13 changes: 8 additions & 5 deletions src/backend/opencl/types.hpp
Expand Up @@ -25,11 +25,14 @@

namespace opencl
{
typedef cl_float2 cfloat;
typedef cl_double2 cdouble;
typedef cl_uchar uchar;
typedef cl_uint uint;
typedef cl_ushort ushort;

using cdouble = cl_double2;
using cfloat = cl_float2;
using intl = long long;
using uchar = cl_uchar;
using uint = cl_uint;
using uintl = unsigned long long;
using ushort = cl_ushort;

template<typename T>
struct ToNumStr
Expand Down

0 comments on commit db43f3f

Please sign in to comment.