Skip to content

Commit

Permalink
Rewrite a few more transformations to use canonic trees.
Browse files Browse the repository at this point in the history
This allows removing the cached-simplifier hack. As a side
effect, it uncovered a bug in the simplify pass.
  • Loading branch information
angavrilov committed Sep 4, 2009
1 parent a9259be commit 0ffc1c9
Show file tree
Hide file tree
Showing 11 changed files with 146 additions and 186 deletions.
9 changes: 2 additions & 7 deletions expr/ref-info.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,8 @@
(if written
(push indexes (cadr rentry))
(push indexes (car rentry)))))))
(do-collect tree)
(let ((res-list nil))
(maphash
#'(lambda (key info)
(push (cons key info) res-list))
res-tbl)
res-list))))
(do-collect (canonic-expr-unwrap tree))
(hash-table-alist res-tbl))))

(defun get-ref-root (expr)
(match expr
Expand Down
9 changes: 8 additions & 1 deletion expr/simplify.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,17 @@
,(type integer divv)))
`(+ (,cmd (* ,marg ,mulv) ,divv)
(,cmd (,op ,remv) ,divv)))

((when (= (mod mulv divv) 0)
`(,(as cmd (or 'mod 'floor 'ceiling)) ; no truncate & rem !
(+ ,remv (* ,marg ,(type integer mulv)))
,(type integer divv)))
`(+ (,cmd ,remv ,divv)
(,cmd (* ,marg ,mulv) ,divv)))
;; Likewise for aligned constants
((when (= (mod remv divv) 0)
`(,(as cmd (or 'mod 'floor 'ceiling)) ; no truncate & rem !
(,(as op (or '+ '-)) ,exv ,remv)
(,(as op (or '+ '-)) ,exv ,(type integer remv))
,(type integer divv)))
`(+ (,cmd ,exv ,divv)
(,cmd (,op ,remv) ,divv)))
Expand Down
3 changes: 1 addition & 2 deletions formula.el
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,7 @@ When Formula mode is enabled, code within {} is indented specially."

;; Import: misc-extensions
(put 'nlet 'common-lisp-indent-function
(or (get 'let 'common-lisp-indent-function)
'((&whole 4 &rest (&whole 1 1 2)) &body)))
'((&whole 4 &rest (&whole 1 1 1 1 1 1 1 1 1 &rest 1)) &body))

;; Import: cl-match
(put 'match 'common-lisp-indent-function
Expand Down
39 changes: 15 additions & 24 deletions gen/cuda-textures.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,21 @@
texture-refs) t)
rid))

(defun expand-aref-tex (expr old-expr)
(match expr
((when (@ cv-mval-set name)
`(aref (multivalue-data ,name ,@_) ,@idxvals))
(let* ((idx-cnt (length idxvals))
(stride-lst (aref-stride-list (second expr)
(1- idx-cnt) idx-cnt))
(ofs-lst (mapcar #'join* idxvals stride-lst))
(summands (sort-summands-by-level (cons 0.1 ofs-lst))))
(if (> idx-cnt 1)
`(texture-ref ,(register-ref name 2)
,(reduce #'join+ summands)
,(reduce #'join+
(sort-summands-by-level
(list 0.1 (car (last idxvals))))))
`(texture-ref-int ,(register-ref name 1)
,(car (last idxvals))))))))

(defun expand-rec (expr)
(simplify-rec-once
(cached-simplifier expand-aref-tex
`(aref ,@_)
(make-hash-table :test #'equal))
expr)))
(def-rewrite-pass expand-rec (:canonic t)
((when (@ cv-mval-set name)
`(aref (multivalue-data ,name ,@_) ,@idxvals))
(let* ((idx-cnt (length idxvals))
(stride-lst (aref-stride-list (second expr) (1- idx-cnt) idx-cnt))
(ofs-lst (mapcar #'join* idxvals stride-lst))
(summands (sort-summands-by-level (cons 0.1 ofs-lst))))
(if (> idx-cnt 1)
`(texture-ref ,(register-ref name 2)
,(reduce #'join+ summands)
,(reduce #'join+
(sort-summands-by-level
(list 0.1 (car (last idxvals))))))
`(texture-ref-int ,(register-ref name 1)
,(car (last idxvals))))))))

(defun use-textures (tex-set expr)
(with-context (cuda-texture-transform tex-set)
Expand Down
43 changes: 17 additions & 26 deletions gen/expand-aref.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
(x (list x))))

(defun sort-summands-by-level (exprs)
(let* ((ofs-items (get-summands
(flatten-exprs (make-canonic `(+ ,@exprs)))))
(let* ((ofs-items (pipeline `(+ ,@exprs)
make-canonic flatten-exprs get-summands))
(num-ofs-items (remove-if-not #'numberp ofs-items))
(var-ofs-items (remove-if #'numberp ofs-items))
(levels (sort (remove-duplicates
Expand All @@ -44,27 +44,18 @@
(`(arr-ptr (temporary ,@_)) (second expr))
(`(arr-dim (temporary ,_ ,dims ,@_) ,i ,_) (nth i dims)))

(defun expand-aref-1 (expr old-expr)
(match expr
(`(aref ,name ,@idxvals)
(let* ((idx-cnt (length idxvals))
(stride-lst (aref-stride-list name idx-cnt))
(ofs-lst (mapcar #'join* idxvals stride-lst))
(summands (sort-summands-by-level ofs-lst)))
`(ptr-deref ,(reduce #'(lambda (base ofs)
`(ptr+ ,base ,ofs))
summands
:initial-value `(arr-ptr ,name)))))
(`(tmp-ref ,name)
nil)
(`(tmp-ref ,name ,@idxvals)
(let ((rexpr (expand-aref-1 `(aref ,name ,@idxvals) old-expr)))
(simplify-index (eval-temporary-dims rexpr))))
(_ nil)))

(defun expand-aref (expr)
(simplify-rec-once
(cached-simplifier expand-aref-1
`(,(or 'aref 'tmp-ref) ,@_)
(make-hash-table :test #'equal))
expr))
(def-rewrite-pass expand-aref (:canonic t)
(`(aref ,name ,@idxvals)
(let* ((idx-cnt (length idxvals))
(stride-lst (aref-stride-list name idx-cnt))
(ofs-lst (mapcar #'join* idxvals stride-lst))
(summands (sort-summands-by-level ofs-lst)))
`(ptr-deref ,(reduce #'(lambda (base ofs)
`(ptr+ ,base ,ofs))
summands
:initial-value `(arr-ptr ,name)))))
(`(tmp-ref ,name)
nil)
(`(tmp-ref ,name ,@idxvals)
(let ((rexpr (expand-aref-1 `(aref ,name ,@idxvals) old-expr)))
(simplify-index (eval-temporary-dims rexpr)))))
56 changes: 22 additions & 34 deletions gen/localize-temp.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -2,49 +2,37 @@

(in-package fast-compute)

(def-rewrite-pass replace-temp-refs ((replacement-tbl) :canonic t)
((when (gethash temp replacement-tbl) `(tmp-ref ,temp ,@_))
`(tmp-ref ,(gethash temp replacement-tbl))))

(defun localize-temps (expr ref-list range-list)
(let* ((inner-index (second (car (last range-list))))
(cluster-level
(convert 'set
(mapcar #'ranging-loop-level
(remove-if-not #'(lambda (rg)
(and
(eql inner-index
(get (second rg) 'band-master))
(get (second rg) 'is-cluster)))
range-list))))
(let* ((inner-index (second (lastcar range-list)))
(cluster-range (remove-if-not #'(lambda (rg)
(and (eql inner-index
(get (second rg) 'band-master))
(get (second rg) 'is-cluster)))
range-list))
(cluster-level (convert 'set (mapcar #'ranging-loop-level cluster-range)))
(local-temps (remove-if-not
#'(lambda (ref-entry)
(ifmatch
`(temporary ,_ ,dims 0 ,@_)
(first ref-entry)
(ifmatch `(temporary ,_ ,dims 0 ,@_) (first ref-entry)
(and (= (length dims) 1)
(null
(set-difference
(third ref-entry) (second ref-entry)
:test #'equal))
(null (set-difference (third ref-entry) (second ref-entry)
:test #'equal))
(every #'(lambda (x)
(and
(= (length x) 1)
(equal?
(set-difference
(get-loop-levels (first x))
cluster-level)
(set 0))))
(append
(third ref-entry) (second ref-entry))))))
(and (= (length x) 1)
(equal? (set-difference (get-loop-levels (first x))
cluster-level)
(set 0))))
(append (third ref-entry)
(second ref-entry))))))
ref-list))
(replacement-tbl (make-hash-table :test #'equal)))
(if (null local-temps)
expr
(progn
(dolist (ref-entry local-temps)
(setf (gethash (first ref-entry) replacement-tbl)
`(temporary ,(second (first ref-entry))
nil 0 :local)))
(simplify-rec-once #'(lambda (form old-form)
(match form
((when (gethash temp replacement-tbl)
`(tmp-ref ,temp ,@_))
`(tmp-ref ,(gethash temp replacement-tbl)))))
expr)))))
`(temporary ,(second (first ref-entry)) nil 0 :local)))
(replace-temp-refs expr replacement-tbl)))))
45 changes: 25 additions & 20 deletions gen/tgt-cuda.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -404,27 +404,32 @@
(make-compute-loops name idxspec expr with
where carrying precompute
:force-cluster t)
(let* ((nomacro-expr (expand-macros loop-expr))
(nolet-expr (expand-let nomacro-expr))
(noiref-expr (simplify-iref nolet-expr))
;; Apply optimizations
(let* ((noiref-expr (pipeline loop-expr
expand-macros expand-let make-canonic
simplify-iref))
;; A table of all array references
(ref-list (collect-arefs noiref-expr))
;; (opt-expr (optimize-tree noiref-expr))
(ltemp-expr (localize-temps noiref-expr ref-list range-list)))
(multiple-value-bind (tex-expr tex-list)
(use-textures (convert 'set textures) ltemp-expr)
(let ((c-levels (remove nil (get-check-level-set))))
(unless (null c-levels)
(error "Safety checks not supported by CUDA:~% ~A"
(mapcan #'get-checks-for-level c-levels))))
;; Convert temporary arrays to registers where possible
(ltemp-expr (localize-temps noiref-expr ref-list range-list)))
(nlet (;; Use textures where requested
(tex-expr tex-list (use-textures (convert 'set textures) ltemp-expr))
;; Apply final transformations
((res-expr (pipeline tex-expr
expand-aref optimize-tree
(code-motion _ :pull-symbols t)))
;; Inner check levels
(c-levels (remove nil (get-check-level-set)))))
;; Only top-level safety checks
(unless (null c-levels)
(error "Safety checks not supported by CUDA:~% ~A"
(mapcan #'get-checks-for-level c-levels)))
;; Generate the kernel call
(wrap-compute-sync-data :cuda-device ref-list
`(let ((*current-compute* ',original))
,(insert-checks nil)
,(compile-expr-cuda
(cond-list (t :name kernel-name)
(tex-list :textures tex-list)
(max-registers
:max-registers max-registers))
*loop-cluster-size* spill-to-shared
range-list
(code-motion (expand-aref (optimize-tree tex-expr))
:pull-symbols t))))))))))
,(compile-expr-cuda (cond-list (t :name kernel-name)
(tex-list :textures tex-list)
(max-registers :max-registers max-registers))
*loop-cluster-size* spill-to-shared
range-list res-expr)))))))))
18 changes: 11 additions & 7 deletions gen/tgt-inline-c.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,19 @@
(*consistency-checks* (make-hash-table :test #'equal)))
(multiple-value-bind (loop-expr loop-list range-list)
(make-compute-loops name idxspec expr with where carrying precompute)
(let* ((nomacro-expr (expand-macros loop-expr))
(nolet-expr (expand-let nomacro-expr))
(noiref-expr (simplify-iref nolet-expr))
;; Apply optimizations
(let* ((noiref-expr (pipeline loop-expr
expand-macros expand-let make-canonic
simplify-iref))
;; A table of all array references
(ref-list (collect-arefs noiref-expr))
;; (opt-expr (optimize-tree noiref-expr))
(noaref-expr (expand-aref noiref-expr))
(check-expr (insert-checks noaref-expr)))
;; Apply final transformations
(res-expr (pipeline noiref-expr
expand-aref canonic-expr-unwrap
insert-checks)))
;; Generate the computation code
(wrap-compute-sync-data :host ref-list
(wrap-compute-parallel parallel range-list check-expr
(wrap-compute-parallel parallel range-list res-expr
#'(lambda (code)
`(let ((*current-compute* ',original))
,(compile-expr-generic
Expand Down
26 changes: 18 additions & 8 deletions gen/tgt-lisp.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,29 @@
(*consistency-checks* (make-hash-table :test #'equal)))
(multiple-value-bind (loop-expr loop-list range-list)
(make-compute-loops name idxspec expr with where carrying precompute)
(let* ((nolet-expr (expand-let loop-expr))
(noiref-expr (simplify-iref nolet-expr))
;; Apply optimizations
(let* ((noiref-expr (pipeline loop-expr
expand-let make-canonic
simplify-iref))
;; A table of all array references
(ref-list (collect-arefs noiref-expr))
(check-expr (insert-checks noiref-expr))
(motion-expr (code-motion check-expr))
(annot-expr (annotate-types motion-expr)))
;; Apply final transformations
(res-expr (pipeline noiref-expr
canonic-expr-unwrap
insert-checks code-motion
annotate-types)))
;; Generate the computation code
(wrap-compute-sync-data :host ref-list
(wrap-compute-parallel parallel range-list
`(let ((*current-compute* ',original)
(*current-compute-body* ',motion-expr))
(*current-compute-body* ',res-expr))
(declare (optimize (safety 1) (debug 1)))
,annot-expr)))))))
,res-expr)))))))

(defmacro calc (exprs)
(annotate-types (code-motion (simplify-iref (expand-let (macroexpand-1 `(letv ,exprs)))))))
(pipeline `(letv ,exprs)
macroexpand-1 expand-let make-canonic
simplify-iref
canonic-expr-unwrap
code-motion annotate-types))

59 changes: 25 additions & 34 deletions logic/iref.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -117,37 +117,28 @@
(dolist (arg rest) (check-index-alignment arg iref-expr aref-expr))
nil)))

(defun simplify-iref-1 (expr old-expr)
(match expr
(`(iref ,name ,@idxvals)
(multiple-value-bind (rexpr mv-p checks)
(expand-iref name idxvals :verbose-p *consistency-checks*)
(unless mv-p
(error "Not a multivalue reference: ~A" expr))
(when (eql mv-p :macro)
(return-from simplify-iref-1
(simplify-iref rexpr)))
(check-index-alignment rexpr expr rexpr)
(when *consistency-checks*
;; Remember bound consistency checks
(dolist (check checks)
(incf-nil (gethash check *consistency-checks*)))
;; Create dimension consistency checks
(do ((dims (get name 'mv-dimensions) (cdr dims))
(rank (length (get name 'mv-dimensions)))
(idx 0 (1+ idx)))
((null dims) nil)
(incf-nil
(gethash `(<= ,(car dims)
(arr-dim (multivalue-data ,name t)
,idx ,rank))
*consistency-checks*))))
;; Return the expression
rexpr))
(_ nil)))

(defun simplify-iref (expr)
(simplify-rec-once (cached-simplifier simplify-iref-1
`(iref ,@_)
(make-hash-table :test #'equal))
expr))
(def-rewrite-pass simplify-iref (:canonic t)
(`(iref ,name ,@idxvals)
(multiple-value-bind (rexpr mv-p checks)
(expand-iref name idxvals :verbose-p *consistency-checks*)
(unless mv-p
(error "Not a multivalue reference: ~A" expr))
(if (eql mv-p :macro)
(simplify-iref (make-canonic rexpr))
(prog1
rexpr
;; Verify before returning
(check-index-alignment rexpr expr rexpr)
(when *consistency-checks*
;; Remember bound consistency checks
(dolist (check checks)
(incf-nil (gethash check *consistency-checks*)))
;; Create dimension consistency checks
(do ((dims (get name 'mv-dimensions) (cdr dims))
(rank (length (get name 'mv-dimensions)))
(idx 0 (1+ idx)))
((null dims) nil)
(incf-nil (gethash `(<= ,(car dims)
(arr-dim (multivalue-data ,name t)
,idx ,rank))
*consistency-checks*)))))))))
Loading

0 comments on commit 0ffc1c9

Please sign in to comment.