Find file
Fetching contributors…
Cannot retrieve contributors at this time
535 lines (504 sloc) 18.2 KB
#lang scheme
(require srfi/1
"typed-vector.ss")
;;; nd-array : an n-dimensional view onto a typed vector.
;;;   shape  : (listof exact-nonnegative-integer?)  extent of each dimension
;;;   type   : vtype?                               element type
;;;   order  : (symbols 'row 'column)               layout the strides describe
;;;   rank   : exact-nonnegative-integer?           number of dimensions
;;;   size   : exact-nonnegative-integer?           element count
;;;   stride : (listof exact-nonnegative-integer?)  step in data per dimension
;;;   offset : exact-nonnegative-integer?           data index of the first element
;;;   data   : typed-vector?                        backing storage (may be shared)
(define-struct nd-array
  (shape type order rank size stride offset data)
  #:mutable
  #:property prop:custom-write
  (lambda (self out mode)
    ;; Compact printed form: #<array SHAPE TYPE-SPECIFIER>
    (fprintf out "#<array ~a ~a>"
             (nd-array-shape self)
             (vtype-specifier (nd-array-type self)))))
;;; Export-facing aliases for the nd-array predicate and field accessors.
(define-values (array?
                array-shape
                array-type
                array-order
                array-rank
                array-size
                array-stride
                array-offset
                array-data)
  (values nd-array?
          nd-array-shape
          nd-array-type
          nd-array-order
          nd-array-rank
          nd-array-size
          nd-array-stride
          nd-array-offset
          nd-array-data))
;;; (array-dimension array i) -> exact-nonnegative-integer?
;;;   array : array?
;;;   i : exact-nonnegative-integer?
;;; Returns the extent of dimension I, or 0 when I is not below the rank.
(define (array-dimension array i)
  (cond ((>= i (array-rank array)) 0)
        (else (list-ref (array-shape array) i))))
;;; (describe-array array) -> void?
;;;   array : array?
;;; Prints every header field of ARRAY, plus its underlying data vector as a
;;; list, to the current output port. Intended for debugging.
(define (describe-array array)
  (define (line fmt value)
    (printf fmt value))
  (line "~a is an array object~n" array)
  (line " shape = ~a~n" (array-shape array))
  (line " type = ~a~n" (vtype-specifier (array-type array)))
  (line " order = ~a~n" (array-order array))
  (line " rank = ~a~n" (array-rank array))
  (line " size = ~a~n" (array-size array))
  (line " stride = ~a~n" (array-stride array))
  (line " offset = ~a~n" (array-offset array))
  (line " contents = ~a~n" (typed-vector->list (array-data array))))
;;; (compute-addressing shape order element-size)
;;;     -> exact-nonnegative-integer?             ; rank
;;;        exact-nonnegative-integer?             ; size
;;;        (listof exact-nonnegative-integer?)    ; stride
;;;   shape : (listof exact-nonnegative-integer?)
;;;   order : (symbols 'row 'column)
;;;   element-size : exact-nonnegative-integer?
;;; Computes contiguous addressing for SHAPE:
;;;   - rank: the number of dimensions;
;;;   - size: ELEMENT-SIZE times the product of the extents;
;;;   - stride: one entry per dimension.
;;; In 'row order the last index varies fastest; in any other order the first
;;; index does. A dimension of extent 1 gets stride 0.
(define (compute-addressing shape order element-size)
  (define row? (eq? order 'row))
  ;; Visit dimensions from the fastest- to the slowest-varying one, growing the
  ;; size of the contiguous block spanned so far.
  (define fastest-first (if row? (reverse shape) shape))
  (define-values (total strides)
    (for/fold ((block element-size)
               (acc '()))
              ((dim (in-list fastest-first)))
      (values (* block dim)
              (cons (if (= dim 1) 0 block) acc))))
  ;; ACC was built by consing, so it lists the slowest dimension first.
  (values (length shape)
          total
          (if row? strides (reverse strides))))
;;; (make-array shape
;;;             [#:type vtype-or-symbol]
;;;             [#:order order]
;;;             [#:fill fill]) -> array?
;;;   shape : (listof exact-nonnegative-integer?)
;;;   vtype-or-symbol : (or/c vtype? symbol?) = object
;;;   order : (symbols 'row 'column) = 'row
;;;   fill : any/c = (void)
;;; Allocates a new contiguous array with the given SHAPE and element type.
;;; When FILL is not void it is passed to make-typed-vector as the initial
;;; value; otherwise make-typed-vector is called without a fill argument.
(define (make-array shape
                    #:type (vtype-or-symbol object)
                    #:order (order 'row)
                    #:fill (fill (void)))
  (define-values (rank size stride)
    (compute-addressing shape order 1))
  (define type (vtype-or-symbol->vtype vtype-or-symbol))
  (define data
    (if (void? fill)
        (make-typed-vector type size)
        (make-typed-vector type size fill)))
  (make-nd-array shape type order rank size stride 0 data))
;;; (generate-indices shape order)
;;;     -> (listof (listof exact-nonnegative-integer?))
;;;   shape : (listof exact-nonnegative-integer?)
;;;   order : (symbols 'row 'column)
;;; Lists every full index of SHAPE in storage order. In 'row order the first
;;; index varies slowest; in any other order it varies fastest.
(define (generate-indices shape order)
  (if (null? shape)
      (list '())
      (let ((tails (generate-indices (cdr shape) order))
            (heads (for/list ((i (in-range (car shape)))) i)))
        (apply append
               (if (eq? order 'row)
                   (map (lambda (h)
                          (map (lambda (t) (cons h t)) tails))
                        heads)
                   (map (lambda (t)
                          (map (lambda (h) (cons h t)) heads))
                        tails))))))
;;; (build-array shape proc
;;;              [#:type vtype-or-symbol]
;;;              [#:order order]) -> array?
;;;   shape : (listof exact-nonnegative-integer?)
;;;   proc : procedure?
;;;   vtype-or-symbol : (or/c vtype? symbol?) = object
;;;   order : (symbols 'row 'column) = 'row
;;; Creates a contiguous array whose element at index (i0 i1 ...) is
;;; (proc i0 i1 ...). PROC is called once per element, in storage order.
(define (build-array shape proc
                     #:type (vtype-or-symbol object)
                     #:order (order 'row))
  (define-values (rank size stride)
    (compute-addressing shape order 1))
  (define type (vtype-or-symbol->vtype vtype-or-symbol))
  (define data (make-typed-vector type size))
  (let fill! ((indices (generate-indices shape order))
              (k 0))
    (unless (null? indices)
      (typed-vector-set! data k (apply proc (car indices)))
      (fill! (cdr indices) (+ k 1))))
  (make-nd-array shape type order rank size stride 0 data))
;;; (arange size
;;;         [#:type vtype-or-symbol]
;;;         [#:order order]) -> array?
;;;   size : exact-nonnegative-integer?
;;;   vtype-or-symbol : (or/c vtype? symbol?) = object
;;;   order : (symbols 'row 'column) = 'row
;;; Creates a one-dimensional array of SIZE elements holding 0, 1, ..., SIZE-1.
(define (arange size
                #:type (vtype-or-symbol object)
                #:order (order 'row))
  (define type (vtype-or-symbol->vtype vtype-or-symbol))
  (define data (make-typed-vector type size))
  (do ((i 0 (+ i 1)))
      ((= i size))
    (typed-vector-set! data i i))
  (make-nd-array (list size) type order 1 size (list 1) 0 data))
;;; (transpose array) -> array?
;;;   array : array?
;;; Returns the transpose of ARRAY as a view that shares its data (no copy).
;;; Both the shape and the strides are reversed, and the offset is kept, so the
;;; result is correct even when ARRAY is itself a strided or offset view (e.g. a
;;; slice produced by array-ref). The order tag is flipped: reversing the axes
;;; of a contiguous row-order array gives a contiguous column-order one, and
;;; vice versa.
(define (transpose array)
  (let ((shape (reverse (array-shape array))))
    (make-nd-array shape
                   (array-type array)
                   (if (eq? (array-order array) 'row) 'column 'row)
                   (length shape)
                   (apply * shape)
                   ;; Use the array's real strides, not freshly computed
                   ;; contiguous ones, so views stay addressed correctly.
                   (reverse (array-stride array))
                   (array-offset array)
                   (array-data array))))
;;; (reshape array shape) -> array?
;;;   array : array?
;;;   shape : (listof exact-nonnegative-integer?)
;;; Returns a new header that views ARRAY's data with the given SHAPE, laid out
;;; in ARRAY's order. No data is copied, so:
;;;   - ARRAY's offset is preserved, which lets contiguous sub-arrays (e.g. one
;;;     row selected with array-ref) reshape correctly;
;;;   - an error is raised if the element counts differ;
;;;   - an error is raised if ARRAY's strides are not the contiguous strides
;;;     for its shape and order (e.g. a stepped slice or a column selection).
(define (reshape array shape)
  (let ((type (array-type array))
        (order (array-order array))
        (data (array-data array)))
    (let-values (((rank size stride)
                  (compute-addressing shape order 1))
                 ((old-rank old-size contiguous-stride)
                  (compute-addressing (array-shape array) order 1)))
      (unless (= size (nd-array-size array))
        (error 'reshape
               "new shape size error, ~a" shape))
      ;; Reinterpreting the data in place is only valid when the existing
      ;; strides are the contiguous ones. A dimension of extent <= 1 never
      ;; advances, so its stride does not matter, and an empty array has
      ;; nothing to reinterpret.
      (unless (or (zero? size)
                  (for/and ((dim (in-list (array-shape array)))
                            (actual (in-list (array-stride array)))
                            (expected (in-list contiguous-stride)))
                    (or (<= dim 1) (= actual expected))))
        (error 'reshape
               "array is not contiguous, cannot reshape to ~a" shape))
      (make-nd-array shape
                     type
                     order
                     rank
                     size
                     stride
                     (array-offset array)
                     data))))
;;; (ref->index ref stride offset) -> exact-nonnegative-integer?
;;;   ref : (listof exact-nonnegative-integer?)
;;;   stride : (listof exact-nonnegative-integer?)
;;;   offset : exact-nonnegative-integer?
;;; Returns OFFSET plus the sum of each index in REF times the matching stride,
;;; i.e. the data position of the element REF addresses.
(define (ref->index ref stride offset)
  (let loop ((indices ref)
             (steps stride)
             (position offset))
    (if (null? indices)
        position
        (loop (cdr indices)
              (cdr steps)
              (+ position (* (car indices) (car steps)))))))
;;; (ref->addressing ref shape stride offset)
;;;     -> exact-nonnegative-integer?             ; rank of the result
;;;        exact-nonnegative-integer?             ; size (element count)
;;;        (listof exact-nonnegative-integer?)    ; shape
;;;        (listof exact-nonnegative-integer?)    ; stride
;;;        exact-nonnegative-integer?             ; offset into the data
;;;   ref : list?
;;;   shape : (listof exact-nonnegative-integer?)
;;;   stride : (listof exact-nonnegative-integer?)
;;;   offset : exact-nonnegative-integer?
;;; Resolves the reference REF against an array with the given SHAPE, STRIDE
;;; and OFFSET. Each entry of REF addresses the next dimension and is one of:
;;;   - an exact index i with 0 <= i < dim, which removes the dimension;
;;;   - '*, which keeps the whole dimension;
;;;   - a slice (start [stop [step]]) with stop inclusive. Any part may be '*
;;;     or omitted (defaults 0, dim-1, 1). It keeps start, start+step, ... <= stop.
;;; Dimensions not covered by REF are kept whole, as if referenced by '*.
;;; A result of rank 0 denotes the single element at the returned offset.
(define (ref->addressing ref shape stride offset)
  (cond
    ((null? ref)
     ;; Trailing unreferenced dimensions are kept whole. With no dimensions
     ;; left this yields (values 0 1 '() '() offset), i.e. one element.
     (values (length shape) (apply * shape) shape stride offset))
    ((null? shape)
     (error 'ref->addressing "too many indices: ~a" ref))
    (else
     (let ((index (car ref))
           (dim (car shape))
           (mult (car stride)))
       (cond ((exact-nonnegative-integer? index)
              (unless (< index dim)
                (error 'ref->addressing
                       "index ~a out of range for dimension of size ~a"
                       index dim))
              (ref->addressing (cdr ref) (cdr shape) (cdr stride)
                               (+ (* index mult) offset)))
             ((eq? index '*)
              (let-values
                  (((new-rank new-size new-shape new-stride new-offset)
                    (ref->addressing
                     (cdr ref) (cdr shape) (cdr stride) offset)))
                (values (+ new-rank 1)
                        (* dim new-size)
                        (cons dim new-shape)
                        (cons mult new-stride)
                        new-offset)))
             ((and (pair? index) (list? index))
              (let* ((start-spec (car index))
                     (stop-spec (if (> (length index) 1) (cadr index) '*))
                     (step-spec (if (> (length index) 2) (caddr index) '*))
                     (start (if (eq? start-spec '*) 0 start-spec))
                     (stop (if (eq? stop-spec '*) (- dim 1) stop-spec))
                     (step (if (eq? step-spec '*) 1 step-spec)))
                ;; Reject slices that would read outside the dimension or
                ;; produce a negative extent.
                (unless (and (exact-nonnegative-integer? start)
                             (exact-integer? stop)
                             (exact-positive-integer? step)
                             (< stop dim)
                             (<= start (+ stop 1)))
                  (error 'ref->addressing
                         "invalid slice ~a for dimension of size ~a"
                         index dim))
                (let-values
                    (((new-rank new-size new-shape new-stride new-offset)
                      (ref->addressing
                       (cdr ref) (cdr shape) (cdr stride)
                       (+ (* mult start) offset))))
                  (let ((new-dim (ceiling (/ (+ (- stop start) 1) step)))
                        (new-mult (* step mult)))
                    (values (+ new-rank 1)
                            (* new-dim new-size)
                            (cons new-dim new-shape)
                            (cons new-mult new-stride)
                            new-offset)))))
             (else
              (error "unknown reference index" index)))))))
;;; (array-ref array . ref) -> (or/c array? any/c)
;;;   array : array?
;;;   ref : list?  each entry an exact index, '*, or a slice list, as accepted
;;;                by ref->addressing
;;; Resolves REF against ARRAY. A rank-0 result is returned as the element
;;; itself; otherwise a view sharing ARRAY's data is returned.
(define (array-ref array . ref)
  (define-values (rank size shape stride offset)
    (ref->addressing ref
                     (array-shape array)
                     (array-stride array)
                     (array-offset array)))
  (cond ((zero? rank)
         (typed-vector-ref (array-data array) offset))
        (else
         (make-nd-array shape
                        (array-type array)
                        (array-order array)
                        rank
                        size
                        stride
                        offset
                        (array-data array)))))
;;; (array-ref* array ref) -> (or/c array? any/c)
;;;   array : array?
;;;   ref : list?
;;; Same as array-ref, but takes the reference as a single list argument.
(define (array-ref* array ref)
  (apply array-ref (cons array ref)))
;;; (complex-array? array) -> boolean?
;;;   array : array?
;;; True when the specifier of ARRAY's element type begins with #\c
;;; (cu8 ... cf64).
(define (complex-array? array)
  (char=? (string-ref (symbol->string
                       (vtype-specifier (array-type array))) 0) #\c))

;;; (complex-array-part who array part) -> array?
;;;   who : symbol?  name reported in errors
;;;   array : array?  a complex array
;;;   part : (or/c 0 1)  0 selects the real part, 1 the imaginary part
;;; Returns a view of one component of ARRAY over typed-vector-base of its data.
;;; NOTE(review): assumes the base vector interleaves (re, im) pairs, as the
;;; former element-size-2 / offset-0-or-1 layout implied; confirm against
;;; typed-vector.ss. Strides and offset are doubled so strided or offset views
;;; of ARRAY are honored. The element count is unchanged.
(define (complex-array-part who array part)
  (let ((type (array-type array)))
    (make-nd-array (array-shape array)
                   (cond ((eq? type cf32) f32)
                         ((eq? type cf64) f64)
                         (else
                          ;; No component type is known for other complex
                          ;; vtypes; fail loudly instead of storing #<void>.
                          (error who "unsupported complex type ~a"
                                 (vtype-specifier type))))
                   (array-order array)
                   (array-rank array)
                   (array-size array)
                   (map (lambda (s) (* 2 s)) (array-stride array))
                   (+ (* 2 (array-offset array)) part)
                   (typed-vector-base (array-data array)))))

;;; (real array) -> array?
;;;   array : array?
;;; For a complex ARRAY returns a view of its real parts; any other ARRAY is
;;; returned unchanged.
(define (real array)
  (if (complex-array? array)
      (complex-array-part 'real array 0)
      array))

;;; (imag array) -> array?
;;;   array : array?
;;; For a complex ARRAY returns a view of its imaginary parts.
;;; NOTE(review): a non-complex ARRAY is returned unchanged, not as zeros.
(define (imag array)
  (if (complex-array? array)
      (complex-array-part 'imag array 1)
      array))
;;; (generate-list array ref) -> (or/c list? any/c)
;;;   array : array?
;;;   ref : (listof exact-nonnegative-integer?)  an index prefix
;;; Recursive worker for array->list:
;;;   - when REF is a full index, returns the element it addresses;
;;;   - otherwise returns a list with one entry per index of the next
;;;     dimension, each built from the extended prefix.
(define (generate-list array ref)
  (define shape (array-shape array))
  (define depth (length ref))
  (cond ((= depth (length shape))
         (array-ref* array ref))
        (else
         (for/list ((i (in-range (list-ref shape depth))))
           (generate-list array (append ref (list i)))))))
;;; (array->list array) -> list?
;;;   array : array?
;;; Converts ARRAY to nested lists, one level of nesting per dimension.
;;; NOTE(review): a rank-0 array yields its single element, not a list.
(define (array->list array)
  (generate-list array null))
;;; (generate-print array ref) -> void?
;;;   array : array?
;;;   ref : (listof exact-nonnegative-integer?)  an index prefix
;;; Recursive worker for print-array. Prints the sub-array of ARRAY selected by
;;; the prefix REF as bracketed, nested rows. At a full index it prints the
;;; single element, separated by one space from the element before it.
(define (generate-print array ref)
  (let ((shape (array-shape array)))
    (if (= (length ref) (length shape))
        ;; Leaf: no leading space before the first element of a row. A rank-0
        ;; array arrives here with REF = '(), which has no last index.
        (printf (if (or (null? ref) (= (last ref) 0)) "~a" " ~a")
                (array-ref* array ref))
        (begin
          ;; Break the line before every non-first block that still has two
          ;; or more dimensions below it.
          (when (and (not (null? ref))
                     (> (- (length shape) (length ref)) 1)
                     (> (last ref) 0))
            (printf "~n"))
          (printf "[")
          (for ((i (in-range (list-ref shape (length ref)))))
            (generate-print array (append ref (list i))))
          (printf "]")
          ;; Break the line after every block except the last at its level.
          (unless (or (null? ref)
                      (= (last ref) (- (list-ref shape (- (length ref) 1)) 1)))
            (printf "~n"))))))
;;; (print-array array) -> void?
;;;   array : array?
;;; Prints ARRAY to the current output port as bracketed, nested rows, then
;;; a newline.
(define (print-array array)
  (begin
    (generate-print array null)
    (printf "~n")))
;;; (estimate-shape item) -> (listof exact-nonnegative-integer?)
;;;   item : any/c
;;; Guesses the shape of a nested list by following first elements. At each
;;; level it records the list's length and stops at the first non-pair. The
;;; guess is not checked for raggedness here; see valid?.
(define (estimate-shape item)
  (let walk ((node item)
             (dims '()))
    (if (pair? node)
        (walk (car node) (cons (length node) dims))
        (reverse dims))))
;;; (valid? shape item) -> boolean?
;;;   shape : (listof exact-nonnegative-integer?)
;;;   item : any/c
;;; True when ITEM is a nested list whose lengths match SHAPE exactly at every
;;; level, and whose leaves (at depth (length shape)) are not pairs.
(define (valid? shape item)
  (cond ((null? shape) (not (pair? item)))
        ((not (pair? item)) #f)
        (else
         (and (= (length item) (car shape))
              (andmap (lambda (sub) (valid? (cdr shape) sub)) item)
              #t))))
;;; (generate-data-list item order) -> list?
;;;   item : any/c  a regular nested list (see valid?) or a scalar
;;;   order : (symbols 'row 'column)
;;; Flattens ITEM into a list of its leaves in storage order:
;;;   - in 'row order the last index varies fastest;
;;;   - in any other order the first index varies fastest.
;;; A non-pair ITEM (including '()) is one leaf.
(define (generate-data-list item order)
  (cond ((not (pair? item))
         (list item))
        ((eq? order 'row)
         ;; Concatenate the flattened sub-items in order.
         (apply append
                (map (lambda (sub) (generate-data-list sub order)) item)))
        (else
         ;; Column order: flatten each sub-item in column order, then
         ;; interleave them, so the leading index varies fastest at every
         ;; rank. (apply map list ...) transposes the equal-length lists,
         ;; which valid? guarantees.
         (apply append
                (apply map list
                       (map (lambda (sub) (generate-data-list sub order))
                            item))))))
;;; (list->array lst
;;;              [#:type vtype-or-symbol]
;;;              [#:order order]) -> array?
;;;   lst : list?
;;;   vtype-or-symbol : (or/c vtype? symbol?) = object
;;;   order : (symbols 'row 'column) = 'row
;;; Builds a contiguous array from a regular nested list. The shape is taken
;;; from the nesting (estimate-shape). Raises an error when LST is ragged.
(define (list->array lst
                     #:type (vtype-or-symbol object)
                     #:order (order 'row))
  (define shape (estimate-shape lst))
  (if (not (valid? shape lst))
      (error 'list->array
             "list does not conform to shape ~a" shape)
      (let*-values (((rank size stride)
                     (compute-addressing shape order 1))
                    ((type)
                     (vtype-or-symbol->vtype vtype-or-symbol))
                    ((data)
                     (list->typed-vector type
                                         (generate-data-list lst order))))
        (make-nd-array shape type order rank size stride 0 data))))
;;; Module Contracts
;;; Element types are re-exported unchanged from typed-vector.ss.
(provide object
         u8 u16 u32 u64 s8 s16 s32 s64 f32 f64
         cu8 cu16 cu32 cu64 cs8 cs16 cs32 cs64 cf32 cf64)
(provide/contract
 (vtype?
  (-> any/c boolean?))
 (array?
  (-> any/c boolean?))
 (array-shape
  (-> array? (listof exact-nonnegative-integer?)))
 (array-type
  (-> array? vtype?))
 (array-order
  (-> array? (symbols 'row 'column)))
 (array-rank
  (-> array? exact-nonnegative-integer?))
 (array-size
  (-> array? exact-nonnegative-integer?))
 (array-dimension
  (-> array? exact-nonnegative-integer? exact-nonnegative-integer?))
 (describe-array
  (-> array? void?))
 (make-array
  (->* ((listof exact-nonnegative-integer?))
       (#:type (or/c vtype? symbol?)
        #:order (symbols 'row 'column)
        #:fill any/c)
       array?))
 (build-array
  (->* ((listof exact-nonnegative-integer?)
        procedure?)
       (#:type (or/c vtype? symbol?)
        #:order (symbols 'row 'column))
       array?))
 (arange
  (->* (exact-nonnegative-integer?)
       (#:type (or/c vtype? symbol?)
        #:order (symbols 'row 'column))
       array?))
 (transpose
  (-> array? array?))
 (reshape
  (-> array? (listof exact-nonnegative-integer?) array?))
 ;; References may be exact indices, '*, or slice lists (see ref->addressing),
 ;; so the rest contract matches array-ref*'s list? rather than restricting
 ;; callers to integer indices.
 (array-ref
  (->* (array?) () #:rest list? any/c))
 (array-ref*
  (-> array? list? any/c))
 (real
  (-> array? array?))
 (imag
  (-> array? array?))
 (array->list
  (-> array? list?))
 (print-array
  (-> array? void?))
 (list->array
  (->* (list?)
       (#:type (or/c vtype? symbol?)
        #:order (symbols 'row 'column))
       array?)))