Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix empty array handling in lu, rank, and qr. Other minor refactoring #2838

Merged
merged 4 commits into from Apr 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/api/c/lu.cpp
Expand Up @@ -49,6 +49,9 @@ af_err af_lu(af_array *lower, af_array *upper, af_array *pivot,

af_dtype type = i_info.getType();

ARG_ASSERT(0, lower != nullptr);
9prady9 marked this conversation as resolved.
Show resolved Hide resolved
ARG_ASSERT(1, upper != nullptr);
ARG_ASSERT(2, pivot != nullptr);
ARG_ASSERT(3, i_info.isFloating()); // Only floating and complex types

if (i_info.ndims() == 0) {
Expand Down Expand Up @@ -81,21 +84,21 @@ af_err af_lu_inplace(af_array *pivot, af_array in, const bool is_lapack_piv) {
}

ARG_ASSERT(1, i_info.isFloating()); // Only floating and complex types
ARG_ASSERT(0, pivot != nullptr);

if (i_info.ndims() == 0) {
return af_create_handle(pivot, 0, nullptr, type);
}

af_array out;

switch (type) {
case f32: out = lu_inplace<float>(in, is_lapack_piv); break;
case f64: out = lu_inplace<double>(in, is_lapack_piv); break;
case c32: out = lu_inplace<cfloat>(in, is_lapack_piv); break;
case c64: out = lu_inplace<cdouble>(in, is_lapack_piv); break;
default: TYPE_ERROR(1, type);
}
if (pivot != NULL) std::swap(*pivot, out);
std::swap(*pivot, out);
}
CATCHALL;

Expand Down
24 changes: 10 additions & 14 deletions src/api/c/median.cpp
Expand Up @@ -41,14 +41,12 @@ static double median(const af_array& in) {
} else if (nElems == 2) {
T result[2];
AF_CHECK(af_get_data_ptr((void*)&result, in));
if (input.isFloating()) {
return division(result[0] + result[1], 2.0);
} else {
return division((float)result[0] + (float)result[1], 2.0);
}
return division(
(static_cast<double>(result[0]) + static_cast<double>(result[1])),
2.0);
}

double mid = (nElems + 1) / 2;
double mid = static_cast<double>(nElems + 1) / 2.0;
af_seq mdSpan[1] = {af_make_seq(mid - 1, mid, 1)};

Array<T> sortedArr = sort<T>(input, 0, true);
Expand All @@ -68,11 +66,9 @@ static double median(const af_array& in) {
if (nElems % 2 == 1) {
result = resPtr[0];
} else {
if (input.isFloating()) {
result = division(resPtr[0] + resPtr[1], 2);
} else {
result = division((float)resPtr[0] + (float)resPtr[1], 2);
}
result = division(
static_cast<double>(resPtr[0]) + static_cast<double>(resPtr[1]),
2.0);
}

return result;
Expand All @@ -90,9 +86,9 @@ static af_array median(const af_array& in, const dim_t dim) {

Array<T> sortedIn = sort<T>(input, dim, true);

int dimLength = input.dims()[dim];
double mid = (dimLength + 1) / 2;
af_array left = 0;
size_t dimLength = input.dims()[dim];
umar456 marked this conversation as resolved.
Show resolved Hide resolved
double mid = static_cast<double>(dimLength + 1) / 2.0;
af_array left = 0;

af_seq slices[4] = {af_span, af_span, af_span, af_span};
slices[dim] = af_make_seq(mid - 1.0, mid - 1.0, 1.0);
Expand Down
7 changes: 5 additions & 2 deletions src/api/c/qr.cpp
Expand Up @@ -55,6 +55,9 @@ af_err af_qr(af_array *q, af_array *r, af_array *tau, const af_array in) {
return AF_SUCCESS;
}

ARG_ASSERT(0, q != nullptr);
ARG_ASSERT(1, r != nullptr);
ARG_ASSERT(2, tau != nullptr);
ARG_ASSERT(3, i_info.isFloating()); // Only floating and complex types

switch (type) {
Expand All @@ -81,21 +84,21 @@ af_err af_qr_inplace(af_array *tau, af_array in) {
af_dtype type = i_info.getType();

ARG_ASSERT(1, i_info.isFloating()); // Only floating and complex types
ARG_ASSERT(0, tau != nullptr);

if (i_info.ndims() == 0) {
return af_create_handle(tau, 0, nullptr, type);
}

af_array out;

switch (type) {
case f32: out = qr_inplace<float>(in); break;
case f64: out = qr_inplace<double>(in); break;
case c32: out = qr_inplace<cfloat>(in); break;
case c64: out = qr_inplace<cdouble>(in); break;
default: TYPE_ERROR(1, type);
}
if (tau != NULL) std::swap(*tau, out);
std::swap(*tau, out);
}
CATCHALL;

Expand Down
22 changes: 10 additions & 12 deletions src/api/c/rank.cpp
Expand Up @@ -56,19 +56,17 @@ af_err af_rank(uint* out, const af_array in, const double tol) {
af_dtype type = i_info.getType();

ARG_ASSERT(1, i_info.isFloating()); // Only floating and complex types
ARG_ASSERT(0, out != nullptr);

uint output;
if (i_info.ndims() == 0) {
output = 0;
return AF_SUCCESS;
}

switch (type) {
case f32: output = rank<float>(in, tol); break;
case f64: output = rank<double>(in, tol); break;
case c32: output = rank<cfloat>(in, tol); break;
case c64: output = rank<cdouble>(in, tol); break;
default: TYPE_ERROR(1, type);
uint output = 0;
if (i_info.ndims() != 0) {
switch (type) {
case f32: output = rank<float>(in, tol); break;
case f64: output = rank<double>(in, tol); break;
case c32: output = rank<cfloat>(in, tol); break;
case c64: output = rank<cdouble>(in, tol); break;
default: TYPE_ERROR(1, type);
}
}
std::swap(*out, output);
}
Expand Down
58 changes: 31 additions & 27 deletions src/backend/common/DefaultMemoryManager.cpp
Expand Up @@ -166,13 +166,15 @@ void *DefaultMemoryManager::alloc(bool user_lock, const unsigned ndims,
}

lock_guard_t lock(this->memory_mutex);
free_iter iter = current.free_map.find(alloc_bytes);
umar456 marked this conversation as resolved.
Show resolved Hide resolved
auto free_buffer_iter = current.free_map.find(alloc_bytes);
vector<void *> &free_buffer_vector = free_buffer_iter->second;

if (iter != current.free_map.end() && !iter->second.empty()) {
if (free_buffer_iter != current.free_map.end() &&
!free_buffer_vector.empty()) {
// Delete existing buffer info and underlying event
// Set to existing in from free map
ptr = iter->second.back();
iter->second.pop_back();
ptr = free_buffer_vector.back();
free_buffer_vector.pop_back();
current.locked_map[ptr] = info;
current.lock_bytes += alloc_bytes;
current.lock_buffers++;
Expand Down Expand Up @@ -206,9 +208,9 @@ void *DefaultMemoryManager::alloc(bool user_lock, const unsigned ndims,
size_t DefaultMemoryManager::allocated(void *ptr) {
if (!ptr) return 0;
memory_info &current = this->getCurrentMemoryInfo();
locked_iter iter = current.locked_map.find((void *)ptr);
umar456 marked this conversation as resolved.
Show resolved Hide resolved
if (iter == current.locked_map.end()) return 0;
return (iter->second).bytes;
auto locked_iter = current.locked_map.find(ptr);
if (locked_iter == current.locked_map.end()) { return 0; }
return (locked_iter->second).bytes;
}

void DefaultMemoryManager::unlock(void *ptr, bool user_unlock) {
Expand All @@ -221,39 +223,43 @@ void DefaultMemoryManager::unlock(void *ptr, bool user_unlock) {
lock_guard_t lock(this->memory_mutex);
memory_info &current = this->getCurrentMemoryInfo();

locked_iter iter = current.locked_map.find((void *)ptr);
auto locked_buffer_iter = current.locked_map.find(ptr);
locked_info &locked_buffer_info = locked_buffer_iter->second;
void *locked_buffer_ptr = locked_buffer_iter->first;

// Pointer not found in locked map
if (iter == current.locked_map.end()) {
if (locked_buffer_iter == current.locked_map.end()) {
// Probably came from user, just free it
freed_ptr.reset(ptr);
return;
}

if (user_unlock) {
(iter->second).user_lock = false;
locked_buffer_info.user_lock = false;
} else {
(iter->second).manager_lock = false;
locked_buffer_info.manager_lock = false;
}

// Return early if either one is locked
if ((iter->second).user_lock || (iter->second).manager_lock) { return; }
if (locked_buffer_info.user_lock || locked_buffer_info.manager_lock) {
return;
}

size_t bytes = iter->second.bytes;
current.lock_bytes -= iter->second.bytes;
size_t bytes = locked_buffer_info.bytes;
current.lock_bytes -= locked_buffer_info.bytes;
current.lock_buffers--;

if (this->debug_mode) {
// Just free memory in debug mode
if ((iter->second).bytes > 0) {
freed_ptr.reset(iter->first);
if (locked_buffer_info.bytes > 0) {
freed_ptr.reset(locked_buffer_ptr);
current.total_buffers--;
current.total_bytes -= iter->second.bytes;
current.total_bytes -= locked_buffer_info.bytes;
}
} else {
current.free_map[bytes].emplace_back(ptr);
}
current.locked_map.erase(iter);
current.locked_map.erase(locked_buffer_iter);
}
}

Expand All @@ -262,6 +268,7 @@ void DefaultMemoryManager::signalMemoryCleanup() {
}

void DefaultMemoryManager::printInfo(const char *msg, const int device) {
UNUSED(device);
const memory_info &current = this->getCurrentMemoryInfo();

printf("%s\n", msg);
Expand Down Expand Up @@ -325,9 +332,9 @@ void DefaultMemoryManager::userLock(const void *ptr) {

lock_guard_t lock(this->memory_mutex);

locked_iter iter = current.locked_map.find(const_cast<void *>(ptr));
if (iter != current.locked_map.end()) {
iter->second.user_lock = true;
auto locked_iter = current.locked_map.find(const_cast<void *>(ptr));
if (locked_iter != current.locked_map.end()) {
locked_iter->second.user_lock = true;
} else {
locked_info info = {false, true, 100}; // This number is not relevant

Expand All @@ -342,12 +349,9 @@ void DefaultMemoryManager::userUnlock(const void *ptr) {
bool DefaultMemoryManager::isUserLocked(const void *ptr) {
memory_info &current = this->getCurrentMemoryInfo();
lock_guard_t lock(this->memory_mutex);
locked_iter iter = current.locked_map.find(const_cast<void *>(ptr));
if (iter != current.locked_map.end()) {
return iter->second.user_lock;
} else {
return false;
}
auto locked_iter = current.locked_map.find(const_cast<void *>(ptr));
if (locked_iter == current.locked_map.end()) { return false; }
return locked_iter->second.user_lock;
}

size_t DefaultMemoryManager::getMemStepSize() {
Expand Down
7 changes: 2 additions & 5 deletions src/backend/common/DefaultMemoryManager.hpp
Expand Up @@ -35,11 +35,8 @@ class DefaultMemoryManager final : public common::memory::MemoryManagerBase {
size_t bytes;
};

using locked_t = typename std::unordered_map<void *, locked_info>;
using locked_iter = typename locked_t::iterator;

using free_t = std::unordered_map<size_t, std::vector<void *>>;
using free_iter = typename free_t::iterator;
using locked_t = typename std::unordered_map<void *, locked_info>;
using free_t = std::unordered_map<size_t, std::vector<void *>>;

struct memory_info {
locked_t locked_map;
Expand Down
43 changes: 43 additions & 0 deletions test/lu_dense.cpp
Expand Up @@ -235,3 +235,46 @@ TYPED_TEST(LU, RectangularLarge1) {
TYPED_TEST(LU, RectangularMultipleOfTwoLarge1) {
luTester<TypeParam>(512, 1024, eps<TypeParam>());
}

TEST(LU, NullLowerOutput) {
if (noLAPACKTests()) return;
dim4 dims(3, 3);
af_array in = 0;
ASSERT_SUCCESS(af_randu(&in, dims.ndims(), dims.get(), f32));

af_array upper, pivot;
ASSERT_EQ(AF_ERR_ARG, af_lu(NULL, &upper, &pivot, in));
ASSERT_SUCCESS(af_release_array(in));
}

TEST(LU, NullUpperOutput) {
if (noLAPACKTests()) return;
dim4 dims(3, 3);
af_array in = 0;
ASSERT_SUCCESS(af_randu(&in, dims.ndims(), dims.get(), f32));

af_array lower, pivot;
ASSERT_EQ(AF_ERR_ARG, af_lu(&lower, NULL, &pivot, in));
ASSERT_SUCCESS(af_release_array(in));
}

TEST(LU, NullPivotOutput) {
if (noLAPACKTests()) return;
dim4 dims(3, 3);
af_array in = 0;
ASSERT_SUCCESS(af_randu(&in, dims.ndims(), dims.get(), f32));

af_array lower, upper;
ASSERT_EQ(AF_ERR_ARG, af_lu(&lower, &upper, NULL, in));
ASSERT_SUCCESS(af_release_array(in));
}

TEST(LU, InPlaceNullOutput) {
if (noLAPACKTests()) return;
dim4 dims(3, 3);
af_array in = 0;
ASSERT_SUCCESS(af_randu(&in, dims.ndims(), dims.get(), f32));

ASSERT_EQ(AF_ERR_ARG, af_lu_inplace(NULL, in, true));
ASSERT_SUCCESS(af_release_array(in));
}
15 changes: 15 additions & 0 deletions test/median.cpp
Expand Up @@ -150,3 +150,18 @@ MEDIAN(float, uchar)
MEDIAN(float, short)
MEDIAN(float, ushort)
MEDIAN(double, double)

TEST(Median, OneElement) {
af::array in = randu(1, f32);

af::array out = median(in);
ASSERT_ARRAYS_EQ(in, out);
}

TEST(Median, TwoElements) {
af::array in = randu(2, f32);

af::array out = median(in);
af::array gold = mean(in);
ASSERT_ARRAYS_EQ(gold, out);
}
umar456 marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion test/memory.cpp
Expand Up @@ -710,7 +710,7 @@ af_err unlock_fn(af_memory_manager manager, void *ptr, int userLock) {

af_err user_unlock_fn(af_memory_manager manager, void *ptr) {
auto *payload = getMemoryManagerPayload<E2ETestPayload>(manager);
af_err err = unlock_fn(manager, ptr, /* user */ 1);
af_err err = unlock_fn(manager, ptr, /* user */ 1);
payload->lockedBytes -= payload->table[ptr];
return err;
}
Expand Down
10 changes: 10 additions & 0 deletions test/qr_dense.cpp
Expand Up @@ -179,3 +179,13 @@ TYPED_TEST(QR, RectangularLarge1) {
TYPED_TEST(QR, RectangularMultipleOfTwoLarge1) {
qrTester<TypeParam>(512, 1024, eps<TypeParam>());
}

TEST(QR, InPlaceNullOutput) {
if (noLAPACKTests()) return;
dim4 dims(3, 3);
af_array in = 0;
ASSERT_SUCCESS(af_randu(&in, dims.ndims(), dims.get(), f32));

ASSERT_EQ(AF_ERR_ARG, af_qr_inplace(NULL, in));
ASSERT_SUCCESS(af_release_array(in));
}
10 changes: 10 additions & 0 deletions test/rank_dense.cpp
Expand Up @@ -112,3 +112,13 @@ void detTest() {
}

TYPED_TEST(Det, Small) { detTest<TypeParam>(); }

TEST(Rank, NullOutput) {
if (noLAPACKTests()) return;
dim4 dims(3, 3);
af_array in = 0;
af_randu(&in, dims.ndims(), dims.get(), f32);

ASSERT_EQ(AF_ERR_ARG, af_rank(NULL, in, 1e-6));
ASSERT_SUCCESS(af_release_array(in));
}