Skip to content

Commit

Permalink
Finished an implementation of list_ew_multiply that actually works!
Browse files Browse the repository at this point in the history
  • Loading branch information
chriswailes committed Aug 17, 2012
1 parent c7b593e commit 764b97b
Show file tree
Hide file tree
Showing 3 changed files with 207 additions and 35 deletions.
2 changes: 1 addition & 1 deletion ext/nmatrix/nmatrix.cpp
Expand Up @@ -1017,7 +1017,7 @@ static VALUE nm_xslice(int argc, VALUE* argv, void* (*slice_func)(STORAGE*, SLIC
fprintf(stderr, "\n");
*/

DENSE_STORAGE* s = NM_DENSE_STORAGE(self);
//DENSE_STORAGE* s = NM_DENSE_STORAGE(self);

if (NM_DTYPE(self) == RUBYOBJ) result = *reinterpret_cast<VALUE*>( ttable[NM_STYPE(self)](NM_STORAGE(self), slice) );
else result = rubyobj_from_cval( ttable[NM_STYPE(self)](NM_STORAGE(self), slice), NM_DTYPE(self) ).rval;
Expand Down
216 changes: 184 additions & 32 deletions ext/nmatrix/storage/list.cpp
Expand Up @@ -43,6 +43,7 @@
#include "common.h"
#include "list.h"

#include "util/math.h"
#include "util/sl_list.h"

/*
Expand All @@ -64,7 +65,10 @@ template <typename LDType, typename RDType>
static bool list_storage_eqeq_template(const LIST_STORAGE* left, const LIST_STORAGE* right);

template <typename LDType, typename RDType>
static void list_storage_ew_multiply_template(LIST* dest, const LIST* left, const LIST* right, size_t rank, const size_t* shape, size_t level);
static void* list_storage_ew_multiply_template(LIST* dest, const LIST* left, const void* l_default, const LIST* right, const void* r_default, const size_t* shape, size_t rank);

template <typename LDType, typename RDType>
static void list_storage_ew_multiply_template_prime(LIST* dest, LDType d_default, const LIST* left, LDType l_default, const LIST* right, RDType r_default, const size_t* shape, size_t last_level, size_t level);

/*
* Functions
Expand Down Expand Up @@ -249,14 +253,40 @@ bool list_storage_eqeq(const STORAGE* left, const STORAGE* right) {
* Documentation goes here.
*/
STORAGE* list_storage_ew_multiply(const STORAGE* left, const STORAGE* right) {
LR_DTYPE_TEMPLATE_TABLE(list_storage_ew_multiply_template, void, LIST*, const LIST*, const LIST*, size_t, const size_t*, size_t);
LR_DTYPE_TEMPLATE_TABLE(list_storage_ew_multiply_template, void*, LIST*, const LIST*, const void*, const LIST*, const void*, const size_t*, size_t);

dtype_t new_dtype = Upcast[left->dtype][right->dtype];

size_t* new_shape = (size_t*)calloc(left->rank, sizeof(size_t));
memcpy(new_shape, left->shape, sizeof(size_t) * left->rank);
const LIST_STORAGE* l = reinterpret_cast<const LIST_STORAGE*>(left),
* r = reinterpret_cast<const LIST_STORAGE*>(right);

LIST_STORAGE* result = list_storage_create(left->dtype, new_shape, left->rank, NULL);
LIST_STORAGE* new_l = NULL;

ttable[left->dtype][right->dtype](result->rows, ((LIST_STORAGE*)left)->rows, ((LIST_STORAGE*)right)->rows, result->rank, result->shape, 0);
// Allocate a new shape array for the resulting matrix.
size_t* new_shape = (size_t*)calloc(l->rank, sizeof(size_t));
memcpy(new_shape, left->shape, sizeof(size_t) * l->rank);

// Create the result matrix.
LIST_STORAGE* result = list_storage_create(new_dtype, new_shape, left->rank, NULL);

/*
* Call the templated elementwise multiplication function and set the default
* value for the resulting matrix.
*/
if (new_dtype != left->dtype) {
// Upcast the left-hand side if necessary.
new_l = reinterpret_cast<LIST_STORAGE*>(list_storage_cast_copy(l, new_dtype));

result->default_val =
ttable[left->dtype][right->dtype](result->rows, new_l->rows, new_l->default_val, r->rows, r->default_val, result->shape, result->rank);

// Delete the temporary left-hand side matrix.
list_storage_delete(reinterpret_cast<STORAGE*>(new_l));

} else {
result->default_val =
ttable[left->dtype][right->dtype](result->rows, l->rows, l->default_val, r->rows, r->default_val, result->shape, result->rank);
}

return result;
}
Expand Down Expand Up @@ -453,46 +483,168 @@ bool list_storage_eqeq_template(const LIST_STORAGE* left, const LIST_STORAGE* ri
* Documentation goes here.
*/
template <typename LDType, typename RDType>
static void list_storage_ew_multiply_template(LIST* dest, const LIST* left, const LIST* right, size_t rank, const size_t* shape, size_t level) {
unsigned int index;
static void* list_storage_ew_multiply_template(LIST* dest, const LIST* left, const void* l_default, const LIST* right, const void* r_default, const size_t* shape, size_t rank) {

/*
* Allocate space for, and calculate, the default value for the destination
* matrix.
*/
LDType* d_default_mem = ALLOC(LDType);
*d_default_mem = *reinterpret_cast<const LDType*>(l_default) * *reinterpret_cast<const RDType*>(r_default);

// Now that setup is done call the actual elementwise multiplication function.
list_storage_ew_multiply_template_prime<LDType, RDType>(dest, *reinterpret_cast<const LDType*>(d_default_mem),
left, *reinterpret_cast<const LDType*>(l_default), right, *reinterpret_cast<const RDType*>(r_default), shape, rank - 1, 0);

// Return a pointer to the destination matrix's default value.
return d_default_mem;
}

/*
* Documentation goes here.
*/
template <typename LDType, typename RDType>
static void list_storage_ew_multiply_template_prime(LIST* dest, LDType d_default, const LIST* left, LDType l_default, const LIST* right, RDType r_default, const size_t* shape, size_t last_level, size_t level) {

static LIST EMPTY_LIST = {NULL};

size_t index;

LDType tmp_result;

LDType* new_val;
LIST* new_level;
LIST* new_level = NULL;

NODE* l_node = left->first,
* r_node = right->first,
* dest_node = NULL;

if (rank == (level + 1)) {
for (index = 0; index < shape[level]; ++index) {
new_val = ALLOC(LDType);
*new_val = *reinterpret_cast<LDType*>(l_node->val) * *reinterpret_cast<RDType*>(r_node->val);
for (index = 0; index < shape[level]; ++index) {
if (l_node == NULL and r_node == NULL) {
/*
* Both source lists are now empty. Because the default value of the
* destination is already set appropriately we can now return.
*/

if (index == 0) {
dest_node = list_insert(dest, false, index, new_val);
return;

} else {
// At least one list still has entries.

if (l_node == NULL and (l_default == 0 and d_default == 0)) {
/*
* The left hand list has run out of elements. We don't need to add new
* values to the destination if l_default and d_default are both 0.
*/

} else {
dest_node = list_insert_after(dest_node, index, new_val);
return;

} else if (r_node == NULL and (r_default == 0 and d_default == 0)) {
/*
* The right hand list has run out of elements. We don't need to add new
* values to the destination if r_default and d_default are both 0.
*/

return;
}

l_node = l_node->next;
r_node = r_node->next;
}

} else {
for (index = 0; index < shape[level]; ++index) {
new_level = list_create();
list_storage_ew_multiply_template<LDType, RDType>(new_level, reinterpret_cast<LIST*>(l_node->val), reinterpret_cast<LIST*>(r_node->val), rank, shape, level + 1);
// We need to continue processing the lists.

if (l_node == NULL and r_node->key == index) {
/*
* One source list is empty, but the index has caught up to the key of
* the other list.
*/

if (level == last_level) {
tmp_result = l_default * *reinterpret_cast<RDType*>(r_node->val);

if (tmp_result != d_default) {
dest_node = list_insert_val_helper<LDType>(dest, dest_node, index, tmp_result);
}

} else {
new_level = list_create();
dest_node = list_insert_ptr_helper(dest, dest_node, index, new_level);

list_storage_ew_multiply_template_prime<LDType, RDType>(new_level, d_default,
&EMPTY_LIST, l_default,
reinterpret_cast<LIST*>(r_node->val), r_default,
shape, last_level, level + 1);
}

r_node = r_node->next;

} else if (r_node == NULL and l_node->key == index) {
/*
* One source list is empty, but the index has caught up to the key of
* the other list.
*/

if (level == last_level) {
tmp_result = *reinterpret_cast<LDType*>(l_node->val) * r_default;

if (tmp_result != d_default) {
dest_node = list_insert_val_helper<LDType>(dest, dest_node, index, tmp_result);
}

} else {
new_level = list_create();
dest_node = list_insert_ptr_helper(dest, dest_node, index, new_level);

list_storage_ew_multiply_template_prime<LDType, RDType>(new_level, d_default,
reinterpret_cast<LIST*>(r_node->val), l_default,
&EMPTY_LIST, r_default,
shape, last_level, level + 1);
}

l_node = l_node->next;

} else if (l_node != NULL and r_node != NULL and index == NM_MIN(l_node->key, r_node->key)) {
/*
* Neither list is empty and our index has caught up to one of the
* source lists.
*/

if (l_node->key == r_node->key) {

if (level == last_level) {
tmp_result = *reinterpret_cast<LDType*>(l_node->val) * *reinterpret_cast<RDType*>(r_node->val);

if (tmp_result != d_default) {
dest_node = list_insert_val_helper<LDType>(dest, dest_node, index, tmp_result);
}

} else {
new_level = list_create();
dest_node = list_insert_ptr_helper(dest, dest_node, index, new_level);

list_storage_ew_multiply_template_prime<LDType, RDType>(new_level, d_default,
reinterpret_cast<LIST*>(l_node->val), l_default,
reinterpret_cast<LIST*>(r_node->val), r_default,
shape, last_level, level + 1);
}

l_node = l_node->next;
r_node = r_node->next;

} else if (l_node->key < r_node->key) {
// Advance the left node knowing that the default value is OK.

if (index == 0) {
dest_node = list_insert(dest, false, index, new_level);
l_node = l_node->next;

} else /* if (l_node->key > r_node->key) */ {
// Advance the right node knowing that the default value is OK.

r_node = r_node->next;
}

} else {
dest_node = list_insert_after(dest_node, index, new_level);
/*
* Our index needs to catch up but the default value is OK. This
* conditional is here only for documentation and should be optimized
* out.
*/
}

l_node = l_node->next;
r_node = r_node->next;
}
}
}
Expand Down
24 changes: 22 additions & 2 deletions ext/nmatrix/util/sl_list.h
Expand Up @@ -91,6 +91,28 @@ NODE* list_insert(LIST* list, bool replace, size_t key, void* val);
NODE* list_insert_after(NODE* node, size_t key, void* val);
void* list_remove(LIST* list, size_t key);

template <typename Type>
inline NODE* list_insert_val_helper(LIST* list, NODE* node, size_t key, Type val) {
Type* val_mem = ALLOC(Type);
*val_mem = val;

if (node == NULL) {
return list_insert(list, false, key, val_mem);

} else {
return list_insert_after(node, key, val_mem);
}
}

inline NODE* list_insert_ptr_helper(LIST* list, NODE* node, size_t key, void* ptr) {
if (node == NULL) {
return list_insert(list, false, key, ptr);

} else {
return list_insert_after(node, key, ptr);
}
}

///////////
// Tests //
///////////
Expand Down Expand Up @@ -214,8 +236,6 @@ bool list_eqeq_list_template(const LIST* left, const LIST* right, const LDType*
return true;
}



/////////////
// Utility //
/////////////
Expand Down

0 comments on commit 764b97b

Please sign in to comment.