Permalink
Browse files

Assist the use of the MADD instruction in some cases.

Implement MADD-aware treeify of sums when requested via
:cuda-flags (:treeify-madd t); also force MADD in two
special cases of ordered inner loop code.
  • Loading branch information...
1 parent 409b9f2 commit ff13f4f89eeeba7eadaa1e1e421cb532b85b6270 @angavrilov committed Sep 20, 2009
Showing with 83 additions and 13 deletions.
  1. +43 −5 expr/opt-treeify.lisp
  2. +5 −4 gen/expand-aref.lisp
  3. +22 −3 gen/splice-carried.lisp
  4. +3 −1 gen/tgt-cuda.lisp
  5. +4 −0 gen/tgt-generic-c.lisp
  6. +6 −0 util-misc.lisp
View
@@ -2,6 +2,10 @@
(in-package fast-compute)
+(use-std-readtable)
+
;; Switch read by the treeify rewrite pass: when true, `(+ ...)` forms are
;; built with treeify-madd instead of treeify+. Defaults to NIL; the CUDA
;; target rebinds it from the :treeify-madd entry of cuda-flags.
+(defvar *treeify-madd* nil)
+
(defun split-parts (lst)
(let* ((len (length lst))
(half (floor len 2)))
@@ -23,6 +27,40 @@
`(+ ,(treeify+ a)
,(treeify+ b)))))))
;; Predicate: yields T when X matches the pattern `(* ,@_), i.e. a list
;; headed by the symbol * with any arguments. There is no fall-through
;; clause, so the result for a non-matching X is whatever MATCH returns by
;; default (presumably NIL -- confirm against the MATCH macro).
+(defun has-mul-p (x)
+ (match x
+ (`(* ,@_) t)))
+
;; Pair multiplication terms of BAG with other terms so that each pair forms
;; a two-argument `(+ mul other) expression (the shape a multiply-add
;; instruction can cover, per the commit description).
;;
;; BAG is a collection of canonic expressions (it is produced by
;; list-to-canonic-bag in treeify-madd). Returns a list obtained through
;; canonic-bag-to-list.
;;
;; NOTE(review): filter/image/bag-difference/size/empty?/union/adjoinf look
;; like FSet bag operations, and #f(...) like a lambda shorthand from the
;; readtable enabled by (use-std-readtable) -- confirm.
+(defun group-madd (bag)
;; mul-exs: members of BAG accepted by has-mul-p (tested through canonic-in).
;; muls / nmuls: the multiplications and the remaining members, each wrapped
;; as (1 . expr). The car looks like a nesting-depth counter -- TODO confirm.
+ (let* ((mul-exs (filter #f(canonic-in has-mul-p _) bag))
+ (muls (image #f(cons 1 _) mul-exs))
+ (nmuls (image #f(cons 1 _) (bag-difference bag mul-exs))))
;; Continue while a multiplication is left and there is at least one other
;; entry (in either collection) to pair it with.
+ (while (and (not (empty? muls))
+ (> (+ (size muls) (size nmuls)) 1))
;; mulv is taken from muls first; the partner comes from nmuls when it is
;; non-empty, otherwise a second entry is taken from muls. The loop
;; condition guarantees that second entry exists.
+ (let ((mulv (removef-least muls))
+ (addv (if (empty? nmuls)
+ (removef-least muls)
+ (removef-least nmuls))))
;; The combined sum goes into nmuls (not muls) with its counter set to one
;; more than the larger counter of its two parts, so it can later serve as
;; the addend of another multiplication.
+ (adjoinf nmuls
+ (cons (1+ (max (car mulv) (car addv)))
+ (make-canonic `(+ ,(cdr mulv)
+ ,(cdr addv)))))))
;; Drop the counters and return whatever is left in both collections.
+ (canonic-bag-to-list (image #'cdr (union muls nmuls)))))
+
;; Interleave two lists: the result alternates elements of ARR1 and ARR2,
;; starting with ARR1. As soon as either list is exhausted, the remainder of
;; the other one is used directly as the tail of the result (that tail is
;; the original list structure, not a copy).
;; The recursive call sits inside LIST*, so recursion depth grows with the
;; length of the shorter list.
+(defun merge-shuffle (arr1 arr2)
+ (cond ((null arr1) arr2)
+ ((null arr2) arr1)
+ (t (list* (car arr1) (car arr2)
+ (merge-shuffle (cdr arr1) (cdr arr2))))))
+
;; MADD-aware replacement for calling treeify+ directly on ARGS (the
;; arguments of a `(+ ...) form).
;;
;; Steps, as written below:
;;  1. split-list with has-minus-p divides ARGS into MINUS and PLUS.
;;  2. PLUS becomes the bag POSITIVE; MINUS is passed through toggle-minus
;;     element-wise and becomes the bag NEGATIVE.
;;  3. group-madd is applied to each bag separately.
;;  4. The grouped negative terms get toggle-minus applied again, are
;;     interleaved with the grouped positive terms by merge-shuffle, and the
;;     combined list is handed to treeify+.
;;
;; NOTE(review): nlet, split-list, has-minus-p and toggle-minus are defined
;; outside this view; toggle-minus presumably adds/removes a unary minus
;; wrapper -- confirm.
+(defun treeify-madd (args)
+ (nlet ((minus plus (split-list #'has-minus-p args))
+ ((positive (list-to-canonic-bag plus))
+ (negative (list-to-canonic-bag (mapcar #'toggle-minus minus)))))
+ (treeify+ (merge-shuffle (group-madd positive)
+ (mapcar #'toggle-minus
+ (group-madd negative))))))
+
(defun treeify* (args)
(labels ((is-div (x) (match x (`(/ ,_) t)))
(do-tree (args)
@@ -40,11 +78,13 @@
`(/ ,(do-tree muls) ,(do-tree divs)))
(do-tree muls)))))
;; Rewrite pass "treeify". Clauses, in order:
;;   `(+ ,x)     -> x (a one-argument sum is replaced by its argument)
;;   `(- ,_)     -> expr (presumably the matched form itself, i.e. left
;;                  unchanged -- confirm how def-rewrite-pass binds EXPR)
;;   `(+ ,@args) -> treeify-madd when *treeify-madd* is true, else treeify+
;;   `(* ,@args) -> treeify*
;; This diff changes the pass options from () to (:canonic t) and adds the
;; *treeify-madd* branch for sums.
-(def-rewrite-pass treeify ()
+(def-rewrite-pass treeify (:canonic t)
(`(+ ,x) x)
(`(- ,_) expr)
(`(+ ,@args)
- (treeify+ args))
+ (if *treeify-madd*
+ (treeify-madd args)
+ (treeify+ args)))
(`(* ,@args)
(treeify* args)))
@@ -57,6 +97,4 @@
;; Run EXPR through the expression-optimization pipeline, starting from
;; (make-canonic expr) and applying flatten-exprs, pull-minus and
;; pull-factors. This diff removes optimize-ifsign and expand-ifsign from
;; the pipeline and moves canonic-expr-unwrap from before treeify to after
;; it, so treeify now receives the canonic form.
(defun optimize-tree (expr)
(pipeline (make-canonic expr)
flatten-exprs pull-minus pull-factors
- optimize-ifsign expand-ifsign
- canonic-expr-unwrap
- treeify))
+ treeify canonic-expr-unwrap))
@@ -33,10 +33,11 @@
(mapcar #'min-loop-level var-ofs-items))
#'level>))
(ofs-groups (mapcar #'(lambda (lvl)
- (treeify
- `(+ ,@(remove lvl var-ofs-items
- :test-not #'eql
- :key #'min-loop-level))))
+ (pipeline
+ `(+ ,@(remove lvl var-ofs-items
+ :test-not #'eql
+ :key #'min-loop-level))
+ make-canonic treeify canonic-expr-unwrap))
levels)))
(nconc ofs-groups num-ofs-items)))
@@ -2,6 +2,8 @@
(in-package fast-compute)
+(use-std-readtable)
+
(defun mark-carried-low (expr low-table crefs carried-p)
(multiple-value-bind (cur-carried found)
(gethash expr low-table)
@@ -104,10 +106,27 @@
(`(setf ,lhs ,rhs)
`(setf ,lhs
,(wrap-item rhs cur-high cur-low)))
+
+ ;; (a*b - c) -> (a*b + (-c)) if c is split
+ ((when (and cur-low cur-high
+ (gethash emul high-table)
+ (not (gethash (unwrap-factored rt) high-table)))
+ `(- ,(as emul `(* ,_ ,_)) ,rt))
+ `(+ ,emul
+ ,(wrap-let (copy-tags (unwrap-factored rt) `(- ,rt))
+ nil)))
+
+ ;; 1/(c - a*b) -> -1/(a*b+(-c)) if c is split
+ ((when (and cur-low cur-high
+ (gethash emul high-table)
+ (not (gethash (unwrap-factored rt) high-table)))
+ `(/ (- ,rt ,(as emul `(* ,_ ,_)))))
+ `(/ (- (+ ,emul
+ ,(wrap-let (copy-tags (unwrap-factored rt) `(- ,rt))
+ nil)))))
+
((type list _)
- (mapcar-save-old #'(lambda (item)
- (wrap-item item cur-high cur-low))
- expr))
+ (mapcar-save-old #f(wrap-item _ cur-high cur-low) expr))
(_
expr)))))
expr)))
View
@@ -374,7 +374,8 @@
(block-size 128)
(max-registers nil)
(textures nil)
- (spill-to-shared t))
+ (spill-to-shared t)
+ (treeify-madd *treeify-madd*))
cuda-flags
(let* ((*current-compute* original)
(*simplify-cache* (make-hash-table))
@@ -383,6 +384,7 @@
(*canonify-cache* (make-canonify-cache))
(*consistency-checks* (make-hash-table :test #'equal))
(*loop-cluster-size* block-size)
+ (*treeify-madd* treeify-madd)
(*align-cluster* 16))
(multiple-value-bind (loop-expr loop-list range-list)
(make-compute-loops name idxspec expr with
@@ -56,6 +56,10 @@
(recurse v)))
((when (eql form-type 'float)
+ (or `(- (/ ,x)) `(/ (- ,x))))
+ (code "-1.0f/(" x ")"))
+
+ ((when (eql form-type 'float)
`(/ ,x))
(code "1.0f/(" x ")"))
View
@@ -12,6 +12,12 @@
(apply #'concatenate 'string
(mapcar #'unsymbol items))))
;; Pop-style macro: expands into code that reads (least BAG), removes that
;; value from the place BAG with removef, and returns the value.
;; The BAG form appears twice in the expansion, so it is evaluated twice;
;; pass a plain variable or another side-effect-free place.
;; NOTE(review): least/removef are presumably the FSet operations -- confirm.
+(defmacro removef-least (bag)
+ (with-gensyms (lv)
+ `(let ((,lv (least ,bag)))
+ (removef ,bag ,lv)
+ ,lv)))
+
;; Return EXPR unchanged when it is an integer; otherwise return NIL.
(defun force-integer (expr)
(if (integerp expr) expr nil))

0 comments on commit ff13f4f

Please sign in to comment.