Skip to content
This repository
Browse code

Merge branch 'master' of github.com:chriswailes/nmatrix

  • Loading branch information...
commit d667ff88246b5b1adecc98ee9c76692bde5efaf0 2 parents 58f6baa + 764b97b
John Woods authored August 17, 2012
2  ext/nmatrix/nmatrix.cpp
@@ -1017,7 +1017,7 @@ static VALUE nm_xslice(int argc, VALUE* argv, void* (*slice_func)(STORAGE*, SLIC
1017 1017
       fprintf(stderr, "\n");
1018 1018
       */
1019 1019
 
1020  
-      DENSE_STORAGE* s = NM_DENSE_STORAGE(self);
  1020
+      //DENSE_STORAGE* s = NM_DENSE_STORAGE(self);
1021 1021
 
1022 1022
       if (NM_DTYPE(self) == RUBYOBJ)  result = *reinterpret_cast<VALUE*>( ttable[NM_STYPE(self)](NM_STORAGE(self), slice) );
1023 1023
       else                            result = rubyobj_from_cval( ttable[NM_STYPE(self)](NM_STORAGE(self), slice), NM_DTYPE(self) ).rval;
216  ext/nmatrix/storage/list.cpp
@@ -43,6 +43,7 @@
43 43
 #include "common.h"
44 44
 #include "list.h"
45 45
 
  46
+#include "util/math.h"
46 47
 #include "util/sl_list.h"
47 48
 
48 49
 /*
@@ -64,7 +65,10 @@ template <typename LDType, typename RDType>
64 65
 static bool list_storage_eqeq_template(const LIST_STORAGE* left, const LIST_STORAGE* right);
65 66
 
66 67
 template <typename LDType, typename RDType>
67  
-static void list_storage_ew_multiply_template(LIST* dest, const LIST* left, const LIST* right, size_t rank, const size_t* shape, size_t level);
  68
+static void* list_storage_ew_multiply_template(LIST* dest, const LIST* left, const void* l_default, const LIST* right, const void* r_default, const size_t* shape, size_t rank);
  69
+
  70
+template <typename LDType, typename RDType>
  71
+static void list_storage_ew_multiply_template_prime(LIST* dest, LDType d_default, const LIST* left, LDType l_default, const LIST* right, RDType r_default, const size_t* shape, size_t last_level, size_t level);
68 72
 
69 73
 /*
70 74
  * Functions
@@ -249,14 +253,40 @@ bool list_storage_eqeq(const STORAGE* left, const STORAGE* right) {
249 253
  * Documentation goes here.
250 254
  */
251 255
 STORAGE* list_storage_ew_multiply(const STORAGE* left, const STORAGE* right) {
252  
-	LR_DTYPE_TEMPLATE_TABLE(list_storage_ew_multiply_template, void, LIST*, const LIST*, const LIST*, size_t, const size_t*, size_t);
  256
+	LR_DTYPE_TEMPLATE_TABLE(list_storage_ew_multiply_template, void*, LIST*, const LIST*, const void*, const LIST*, const void*, const size_t*, size_t);
  257
+	
  258
+	dtype_t new_dtype = Upcast[left->dtype][right->dtype];
253 259
 	
254  
-	size_t* new_shape = (size_t*)calloc(left->rank, sizeof(size_t));
255  
-	memcpy(new_shape, left->shape, sizeof(size_t) * left->rank);
  260
+	const LIST_STORAGE* l = reinterpret_cast<const LIST_STORAGE*>(left),
  261
+										* r = reinterpret_cast<const LIST_STORAGE*>(right);
256 262
 	
257  
-	LIST_STORAGE* result = list_storage_create(left->dtype, new_shape, left->rank, NULL); 
  263
+	LIST_STORAGE* new_l = NULL;
258 264
 	
259  
-	ttable[left->dtype][right->dtype](result->rows, ((LIST_STORAGE*)left)->rows, ((LIST_STORAGE*)right)->rows, result->rank, result->shape, 0);
  265
+	// Allocate a new shape array for the resulting matrix.
  266
+	size_t* new_shape = (size_t*)calloc(l->rank, sizeof(size_t));
  267
+	memcpy(new_shape, left->shape, sizeof(size_t) * l->rank);
  268
+	
  269
+	// Create the result matrix.
  270
+	LIST_STORAGE* result = list_storage_create(new_dtype, new_shape, left->rank, NULL); 
  271
+	
  272
+	/*
  273
+	 * Call the templated elementwise multiplication function and set the default
  274
+	 * value for the resulting matrix.
  275
+	 */
  276
+	if (new_dtype != left->dtype) {
  277
+		// Upcast the left-hand side if necessary.
  278
+		new_l = reinterpret_cast<LIST_STORAGE*>(list_storage_cast_copy(l, new_dtype));
  279
+		
  280
+		result->default_val =
  281
+			ttable[left->dtype][right->dtype](result->rows, new_l->rows, new_l->default_val, r->rows, r->default_val, result->shape, result->rank);
  282
+		
  283
+		// Delete the temporary left-hand side matrix.
  284
+		list_storage_delete(reinterpret_cast<STORAGE*>(new_l));
  285
+			
  286
+	} else {
  287
+		result->default_val =
  288
+			ttable[left->dtype][right->dtype](result->rows, l->rows, l->default_val, r->rows, r->default_val, result->shape, result->rank);
  289
+	}
260 290
 	
261 291
 	return result;
262 292
 }
@@ -453,46 +483,168 @@ bool list_storage_eqeq_template(const LIST_STORAGE* left, const LIST_STORAGE* ri
453 483
  * Documentation goes here.
454 484
  */
455 485
 template <typename LDType, typename RDType>
456  
-static void list_storage_ew_multiply_template(LIST* dest, const LIST* left, const LIST* right, size_t rank, const size_t* shape, size_t level) {
457  
-	unsigned int index;
  486
+static void* list_storage_ew_multiply_template(LIST* dest, const LIST* left, const void* l_default, const LIST* right, const void* r_default, const size_t* shape, size_t rank) {
  487
+	
  488
+	/*
  489
+	 * Allocate space for, and calculate, the default value for the destination
  490
+	 * matrix.
  491
+	 */
  492
+	LDType* d_default_mem = ALLOC(LDType);
  493
+	*d_default_mem = *reinterpret_cast<const LDType*>(l_default) * *reinterpret_cast<const RDType*>(r_default);
  494
+	
  495
+	// Now that setup is done call the actual elementwise multiplication function.
  496
+	list_storage_ew_multiply_template_prime<LDType, RDType>(dest, *reinterpret_cast<const LDType*>(d_default_mem),
  497
+		left, *reinterpret_cast<const LDType*>(l_default), right, *reinterpret_cast<const RDType*>(r_default), shape, rank - 1, 0);
  498
+	
  499
+	// Return a pointer to the destination matrix's default value.
  500
+	return d_default_mem;
  501
+}
  502
+
  503
+/*
  504
+ * Documentation goes here.
  505
+ */
  506
+template <typename LDType, typename RDType>
  507
+static void list_storage_ew_multiply_template_prime(LIST* dest, LDType d_default, const LIST* left, LDType l_default, const LIST* right, RDType r_default, const size_t* shape, size_t last_level, size_t level) {
  508
+	
  509
+	static LIST EMPTY_LIST = {NULL};
  510
+	
  511
+	size_t index;
  512
+	
  513
+	LDType tmp_result;
458 514
 	
459  
-	LDType* new_val;
460  
-	LIST* new_level;
  515
+	LIST* new_level = NULL;
461 516
 	
462 517
 	NODE* l_node		= left->first,
463 518
 			* r_node		= right->first,
464 519
 			* dest_node	= NULL;
465 520
 	
466  
-	if (rank == (level + 1)) {
467  
-		for (index = 0; index < shape[level]; ++index) {
468  
-			new_val = ALLOC(LDType);
469  
-			*new_val = *reinterpret_cast<LDType*>(l_node->val) * *reinterpret_cast<RDType*>(r_node->val);
  521
+	for (index = 0; index < shape[level]; ++index) {
  522
+		if (l_node == NULL and r_node == NULL) {
  523
+			/*
  524
+			 * Both source lists are now empty.  Because the default value of the
  525
+			 * destination is already set appropriately we can now return.
  526
+			 */
470 527
 			
471  
-			if (index == 0) {
472  
-				dest_node = list_insert(dest, false, index, new_val);
  528
+			return;
  529
+			
  530
+		} else {
  531
+			// At least one list still has entries.
  532
+			
  533
+			if (l_node == NULL and (l_default == 0 and d_default == 0)) {
  534
+				/* 
  535
+				 * The left hand list has run out of elements.  We don't need to add new
  536
+				 * values to the destination if l_default and d_default are both 0.
  537
+				 */
473 538
 				
474  
-			} else {
475  
-				dest_node = list_insert_after(dest_node, index, new_val);
  539
+				return;
  540
+			
  541
+			} else if (r_node == NULL and (r_default == 0 and d_default == 0)) {
  542
+				/*
  543
+				 * The right hand list has run out of elements.  We don't need to add new
  544
+				 * values to the destination if r_default and d_default are both 0.
  545
+				 */
  546
+				
  547
+				return;
476 548
 			}
477 549
 			
478  
-			l_node = l_node->next;
479  
-			r_node = r_node->next;
480  
-		}
481  
-		
482  
-	} else {
483  
-		for (index = 0; index < shape[level]; ++index) {
484  
-			new_level = list_create();
485  
-			list_storage_ew_multiply_template<LDType, RDType>(new_level, reinterpret_cast<LIST*>(l_node->val), reinterpret_cast<LIST*>(r_node->val), rank, shape, level + 1);
  550
+			// We need to continue processing the lists.
  551
+			
  552
+			if (l_node == NULL and r_node->key == index) {
  553
+				/*
  554
+				 * One source list is empty, but the index has caught up to the key of
  555
+				 * the other list.
  556
+				 */
  557
+				
  558
+				if (level == last_level) {
  559
+					tmp_result = l_default * *reinterpret_cast<RDType*>(r_node->val);
  560
+					
  561
+					if (tmp_result != d_default) {
  562
+						dest_node = list_insert_val_helper<LDType>(dest, dest_node, index, tmp_result);
  563
+					}
  564
+					
  565
+				} else {
  566
+					new_level = list_create();
  567
+					dest_node = list_insert_ptr_helper(dest, dest_node, index, new_level);
  568
+				
  569
+					list_storage_ew_multiply_template_prime<LDType, RDType>(new_level, d_default,
  570
+						&EMPTY_LIST, l_default,
  571
+						reinterpret_cast<LIST*>(r_node->val), r_default,
  572
+						shape, last_level, level + 1);
  573
+				}
  574
+				
  575
+				r_node = r_node->next;
  576
+				
  577
+			} else if (r_node == NULL and l_node->key == index) {
  578
+				/*
  579
+				 * One source list is empty, but the index has caught up to the key of
  580
+				 * the other list.
  581
+				 */
  582
+				
  583
+				if (level == last_level) {
  584
+					tmp_result = *reinterpret_cast<LDType*>(l_node->val) * r_default;
  585
+					
  586
+					if (tmp_result != d_default) {
  587
+						dest_node = list_insert_val_helper<LDType>(dest, dest_node, index, tmp_result);
  588
+					}
  589
+					
  590
+				} else {
  591
+					new_level = list_create();
  592
+					dest_node = list_insert_ptr_helper(dest, dest_node, index, new_level);
  593
+				
  594
+					list_storage_ew_multiply_template_prime<LDType, RDType>(new_level, d_default,
  595
+						reinterpret_cast<LIST*>(r_node->val), l_default,
  596
+						&EMPTY_LIST, r_default,
  597
+						shape, last_level, level + 1);
  598
+				}
  599
+				
  600
+				l_node = l_node->next;
  601
+				
  602
+			} else if (l_node != NULL and r_node != NULL and index == NM_MIN(l_node->key, r_node->key)) {
  603
+				/*
  604
+				 * Neither list is empty and our index has caught up to one of the
  605
+				 * source lists.
  606
+				 */
  607
+				
  608
+				if (l_node->key == r_node->key) {
  609
+					
  610
+					if (level == last_level) {
  611
+						tmp_result = *reinterpret_cast<LDType*>(l_node->val) * *reinterpret_cast<RDType*>(r_node->val);
  612
+						
  613
+						if (tmp_result != d_default) {
  614
+							dest_node = list_insert_val_helper<LDType>(dest, dest_node, index, tmp_result);
  615
+						}
  616
+						
  617
+					} else {
  618
+						new_level = list_create();
  619
+						dest_node = list_insert_ptr_helper(dest, dest_node, index, new_level);
  620
+					
  621
+						list_storage_ew_multiply_template_prime<LDType, RDType>(new_level, d_default,
  622
+							reinterpret_cast<LIST*>(l_node->val), l_default,
  623
+							reinterpret_cast<LIST*>(r_node->val), r_default,
  624
+							shape, last_level, level + 1);
  625
+					}
  626
+				
  627
+					l_node = l_node->next;
  628
+					r_node = r_node->next;
  629
+			
  630
+				} else if (l_node->key < r_node->key) {
  631
+					// Advance the left node knowing that the default value is OK.
486 632
 			
487  
-			if (index == 0) {
488  
-				dest_node = list_insert(dest, false, index, new_level);
  633
+					l_node = l_node->next;
  634
+					 
  635
+				} else /* if (l_node->key > r_node->key) */ {
  636
+					// Advance the right node knowing that the default value is OK.
  637
+			
  638
+					r_node = r_node->next;
  639
+				}
489 640
 				
490 641
 			} else {
491  
-				dest_node = list_insert_after(dest_node, index, new_level);
  642
+				/*
  643
+				 * Our index needs to catch up but the default value is OK.  This
  644
+				 * conditional is here only for documentation and should be optimized
  645
+				 * out.
  646
+				 */
492 647
 			}
493  
-			
494  
-			l_node = l_node->next;
495  
-			r_node = r_node->next;
496 648
 		}
497 649
 	}
498 650
 }
24  ext/nmatrix/util/sl_list.h
@@ -91,6 +91,28 @@ NODE* list_insert(LIST* list, bool replace, size_t key, void* val);
91 91
 NODE* list_insert_after(NODE* node, size_t key, void* val);
92 92
 void* list_remove(LIST* list, size_t key);
93 93
 
  94
+template <typename Type>
  95
+inline NODE* list_insert_val_helper(LIST* list, NODE* node, size_t key, Type val) {
  96
+	Type* val_mem = ALLOC(Type);
  97
+	*val_mem = val;
  98
+	
  99
+	if (node == NULL) {
  100
+		return list_insert(list, false, key, val_mem);
  101
+		
  102
+	} else {
  103
+		return list_insert_after(node, key, val_mem);
  104
+	}
  105
+}
  106
+
  107
+inline NODE* list_insert_ptr_helper(LIST* list, NODE* node, size_t key, void* ptr) {
  108
+	if (node == NULL) {
  109
+		return list_insert(list, false, key, ptr);
  110
+		
  111
+	} else {
  112
+		return list_insert_after(node, key, ptr);
  113
+	}
  114
+}
  115
+
94 116
 ///////////
95 117
 // Tests //
96 118
 ///////////
@@ -214,8 +236,6 @@ bool list_eqeq_list_template(const LIST* left, const LIST* right, const LDType*
214 236
   return true;
215 237
 }
216 238
 
217  
-
218  
-
219 239
 /////////////
220 240
 // Utility //
221 241
 /////////////

0 notes on commit d667ff8

Please sign in to comment.
Something went wrong with that request. Please try again.