Skip to content

Commit

Permalink
Corrected af_random_engine_set_type
Browse files Browse the repository at this point in the history
- Correct usage of getRandomEngine
- Remove redundant release of Mersenne arrays
- Added test for setDefaultRandomEngine
  • Loading branch information
Kumar Aatish committed Jan 24, 2017
1 parent e59cdbc commit 38f92ec
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 21 deletions.
39 changes: 18 additions & 21 deletions src/api/c/random.cpp
Expand Up @@ -179,29 +179,26 @@ af_err af_random_engine_set_type(af_random_engine *engine, const af_random_engin
try {
AF_CHECK(af_init());
validateRandomType(rtype);
RandomEngine e = *(getRandomEngine(engine));
if (rtype != e.type) {
RandomEngine *e = getRandomEngine(*engine);
if (rtype != e->type) {
if (rtype == AF_RANDOM_ENGINE_MERSENNE_GP11213) {
bool empty;
AF_CHECK(af_is_empty(&empty, e.state));
if (empty) {
AF_CHECK(af_release_array(e.pos));
AF_CHECK(af_release_array(e.sh1));
AF_CHECK(af_release_array(e.sh2));
AF_CHECK(af_release_array(e.recursion_table));
AF_CHECK(af_release_array(e.temper_table));
AF_CHECK(af_release_array(e.state));
AF_CHECK(af_create_array(&e.pos, pos, 1, &MaxBlocks, u32));
AF_CHECK(af_create_array(&e.sh1, sh1, 1, &MaxBlocks, u32));
AF_CHECK(af_create_array(&e.sh2, sh2, 1, &MaxBlocks, u32));
e.mask = mask;
AF_CHECK(af_create_array(&e.recursion_table, recursion_tbl, 1, &TableLength, u32));
AF_CHECK(af_create_array(&e.temper_table, temper_tbl, 1, &TableLength, u32));
AF_CHECK(af_create_handle(&e.state, 1, &MtStateLength, u32));
initMersenneState(getWritableArray<uint>(e.state), *e.seed, getArray<uint>(e.recursion_table));
}
AF_CHECK(af_create_array(&e->pos, pos, 1, &MaxBlocks, u32));
AF_CHECK(af_create_array(&e->sh1, sh1, 1, &MaxBlocks, u32));
AF_CHECK(af_create_array(&e->sh2, sh2, 1, &MaxBlocks, u32));
e->mask = mask;
AF_CHECK(af_create_array(&e->recursion_table, recursion_tbl, 1, &TableLength, u32));
AF_CHECK(af_create_array(&e->temper_table, temper_tbl, 1, &TableLength, u32));
AF_CHECK(af_create_handle(&e->state, 1, &MtStateLength, u32));
initMersenneState(getWritableArray<uint>(e->state), *(e->seed), getArray<uint>(e->recursion_table));
} else if (e->type == AF_RANDOM_ENGINE_MERSENNE_GP11213) {
AF_CHECK(af_release_array(e->pos));
AF_CHECK(af_release_array(e->sh1));
AF_CHECK(af_release_array(e->sh2));
AF_CHECK(af_release_array(e->recursion_table));
AF_CHECK(af_release_array(e->temper_table));
AF_CHECK(af_release_array(e->state));
}
e.type = rtype;
e->type = rtype;
}
} CATCHALL;
return AF_SUCCESS;
Expand Down
9 changes: 9 additions & 0 deletions test/random.cpp
Expand Up @@ -188,6 +188,15 @@ TEST(Random, CPP)
af::dim4 dims(1, 2, 3, 1);
af::array out1 = af::randu(dims);
af::array out2 = af::randn(dims);
af::setDefaultRandomEngineType(AF_RANDOM_ENGINE_PHILOX);
af::array out3 = af::randu(dims);
af::array out4 = af::randn(dims);
af::setDefaultRandomEngineType(AF_RANDOM_ENGINE_THREEFRY);
af::array out5 = af::randu(dims);
af::array out6 = af::randn(dims);
af::setDefaultRandomEngineType(AF_RANDOM_ENGINE_MERSENNE);
af::array out7 = af::randu(dims);
af::array out8 = af::randn(dims);
af::sync();
}

Expand Down

0 comments on commit 38f92ec

Please sign in to comment.