Skip to content

Commit

Permalink
add and use a blas_suffix macro for all blas and lapack symbols
Browse files Browse the repository at this point in the history
  • Loading branch information
tkelman committed Oct 20, 2014
1 parent 7f596d4 commit 555bc3f
Show file tree
Hide file tree
Showing 3 changed files with 247 additions and 239 deletions.
182 changes: 91 additions & 91 deletions base/linalg/blas.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module BLAS

import ..axpy!
import Base.copy!
import Base: copy!, @blas_suffix

export
# Level 1
Expand Down Expand Up @@ -57,10 +57,10 @@ import ..LinAlg: BlasReal, BlasComplex, BlasFloat, BlasChar, BlasInt, blas_int,

# Level 1
## copy
for (fname, elty) in ((:dcopy_,:Float64),
(:scopy_,:Float32),
(:zcopy_,:Complex128),
(:ccopy_,:Complex64))
for (fname, elty) in ((@blas_suffix(:dcopy_),:Float64),
(@blas_suffix(:scopy_),:Float32),
(@blas_suffix(:zcopy_),:Complex128),
(@blas_suffix(:ccopy_),:Complex64))
@eval begin
# SUBROUTINE DCOPY(N,DX,INCX,DY,INCY)
function blascopy!(n::Integer, DX::Union(Ptr{$elty},StridedArray{$elty}), incx::Integer, DY::Union(Ptr{$elty},StridedArray{$elty}), incy::Integer)
Expand All @@ -73,10 +73,10 @@ for (fname, elty) in ((:dcopy_,:Float64),
end

## scal
for (fname, elty) in ((:dscal_,:Float64),
(:sscal_,:Float32),
(:zscal_,:Complex128),
(:cscal_,:Complex64))
for (fname, elty) in ((@blas_suffix(:dscal_),:Float64),
(@blas_suffix(:sscal_),:Float32),
(@blas_suffix(:zscal_),:Complex128),
(@blas_suffix(:cscal_),:Complex64))
@eval begin
# SUBROUTINE DSCAL(N,DA,DX,INCX)
function scal!(n::Integer, DA::$elty, DX::Union(Ptr{$elty},StridedArray{$elty}), incx::Integer)
Expand All @@ -89,8 +89,8 @@ for (fname, elty) in ((:dscal_,:Float64),
end
scal(n, DA, DX, incx) = scal!(n, DA, copy(DX), incx)
# In case DX is complex, and DA is real, use dscal/sscal to save flops
for (fname, elty, celty) in ((:sscal_, :Float32, :Complex64),
(:dscal_, :Float64, :Complex128))
for (fname, elty, celty) in ((@blas_suffix(:sscal_), :Float32, :Complex64),
(@blas_suffix(:dscal_), :Float64, :Complex128))
@eval begin
function scal!(n::Integer, DA::$elty, DX::Union(Ptr{$celty},StridedArray{$celty}), incx::Integer)
ccall(($(string(fname)),libblas), Void,
Expand All @@ -102,8 +102,8 @@ for (fname, elty, celty) in ((:sscal_, :Float32, :Complex64),
end

## dot
for (fname, elty) in ((:ddot_,:Float64),
(:sdot_,:Float32))
for (fname, elty) in ((@blas_suffix(:ddot_),:Float64),
(@blas_suffix(:sdot_),:Float32))
@eval begin
# DOUBLE PRECISION FUNCTION DDOT(N,DX,INCX,DY,INCY)
# * .. Scalar Arguments ..
Expand All @@ -118,8 +118,8 @@ for (fname, elty) in ((:ddot_,:Float64),
end
end
end
for (fname, elty) in ((:cblas_zdotc_sub,:Complex128),
(:cblas_cdotc_sub,:Complex64))
for (fname, elty) in ((@blas_suffix(:cblas_zdotc_sub),:Complex128),
(@blas_suffix(:cblas_cdotc_sub),:Complex64))
@eval begin
# DOUBLE PRECISION FUNCTION DDOT(N,DX,INCX,DY,INCY)
# * .. Scalar Arguments ..
Expand All @@ -136,8 +136,8 @@ for (fname, elty) in ((:cblas_zdotc_sub,:Complex128),
end
end
end
for (fname, elty) in ((:cblas_zdotu_sub,:Complex128),
(:cblas_cdotu_sub,:Complex64))
for (fname, elty) in ((@blas_suffix(:cblas_zdotu_sub),:Complex128),
(@blas_suffix(:cblas_cdotu_sub),:Complex64))
@eval begin
# DOUBLE PRECISION FUNCTION DDOT(N,DX,INCX,DY,INCY)
# * .. Scalar Arguments ..
Expand Down Expand Up @@ -171,10 +171,10 @@ function dotu{T<:BlasComplex}(DX::StridedArray{T}, DY::StridedArray{T})
end

## nrm2
for (fname, elty, ret_type) in ((:dnrm2_,:Float64,:Float64),
(:snrm2_,:Float32,:Float32),
(:dznrm2_,:Complex128,:Float64),
(:scnrm2_,:Complex64,:Float32))
for (fname, elty, ret_type) in ((@blas_suffix(:dnrm2_),:Float64,:Float64),
(@blas_suffix(:snrm2_),:Float32,:Float32),
(@blas_suffix(:dznrm2_),:Complex128,:Float64),
(@blas_suffix(:scnrm2_),:Complex64,:Float32))
@eval begin
# SUBROUTINE DNRM2(N,X,INCX)
function nrm2(n::Integer, X::Union(Ptr{$elty},StridedVector{$elty}), incx::Integer)
Expand All @@ -188,10 +188,10 @@ nrm2(x::StridedVector) = nrm2(length(x), x, stride(x,1))
nrm2(x::Array) = nrm2(length(x), pointer(x), 1)

## asum
for (fname, elty, ret_type) in ((:dasum_,:Float64,:Float64),
(:sasum_,:Float32,:Float32),
(:dzasum_,:Complex128,:Float64),
(:scasum_,:Complex64,:Float32))
for (fname, elty, ret_type) in ((@blas_suffix(:dasum_),:Float64,:Float64),
(@blas_suffix(:sasum_),:Float32,:Float32),
(@blas_suffix(:dzasum_),:Complex128,:Float64),
(@blas_suffix(:scasum_),:Complex64,:Float32))
@eval begin
# SUBROUTINE ASUM(N, X, INCX)
function asum(n::Integer, X::Union(Ptr{$elty},StridedVector{$elty}), incx::Integer)
Expand All @@ -205,10 +205,10 @@ asum(x::StridedVector) = asum(length(x), x, stride(x,1))
asum(x::Array) = asum(length(x), pointer(x), 1)

## axpy
for (fname, elty) in ((:daxpy_,:Float64),
(:saxpy_,:Float32),
(:zaxpy_,:Complex128),
(:caxpy_,:Complex64))
for (fname, elty) in ((@blas_suffix(:daxpy_),:Float64),
(@blas_suffix(:saxpy_),:Float32),
(@blas_suffix(:zaxpy_),:Complex128),
(@blas_suffix(:caxpy_),:Complex64))
@eval begin
# SUBROUTINE DAXPY(N,DA,DX,INCX,DY,INCY)
# DY <- DA*DX + DY
Expand Down Expand Up @@ -243,10 +243,10 @@ function axpy!{T<:BlasFloat,Ta<:Number,Ti<:Integer}(alpha::Ta, x::Array{T}, rx::
end

## iamax
for (fname, elty) in ((:idamax_,:Float64),
(:isamax_,:Float32),
(:izamax_,:Complex128),
(:icamax_,:Complex64))
for (fname, elty) in ((@blas_suffix(:idamax_),:Float64),
(@blas_suffix(:isamax_),:Float32),
(@blas_suffix(:izamax_),:Complex128),
(@blas_suffix(:icamax_),:Complex64))
@eval begin
function iamax(n::BlasInt, dx::Union(StridedVector{$elty}, Ptr{$elty}), incx::BlasInt)
ccall(($(string(fname)), libblas),BlasInt,
Expand All @@ -260,10 +260,10 @@ iamax(dx::StridedVector) = iamax(length(dx), dx, 1)
# Level 2
## mv
### gemv
for (fname, elty) in ((:dgemv_,:Float64),
(:sgemv_,:Float32),
(:zgemv_,:Complex128),
(:cgemv_,:Complex64))
for (fname, elty) in ((@blas_suffix(:dgemv_),:Float64),
(@blas_suffix(:sgemv_),:Float32),
(@blas_suffix(:zgemv_),:Complex128),
(@blas_suffix(:cgemv_),:Complex64))
@eval begin
#SUBROUTINE DGEMV(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
#* .. Scalar Arguments ..
Expand Down Expand Up @@ -294,10 +294,10 @@ for (fname, elty) in ((:dgemv_,:Float64),
end

### (GB) general banded matrix-vector multiplication
for (fname, elty) in ((:dgbmv_,:Float64),
(:sgbmv_,:Float32),
(:zgbmv_,:Complex128),
(:cgbmv_,:Complex64))
for (fname, elty) in ((@blas_suffix(:dgbmv_),:Float64),
(@blas_suffix(:sgbmv_),:Float32),
(@blas_suffix(:zgbmv_),:Complex128),
(@blas_suffix(:cgbmv_),:Complex64))
@eval begin
# SUBROUTINE DGBMV(TRANS,M,N,KL,KU,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
# * .. Scalar Arguments ..
Expand Down Expand Up @@ -329,10 +329,10 @@ for (fname, elty) in ((:dgbmv_,:Float64),
end

### symv
for (fname, elty) in ((:dsymv_,:Float64),
(:ssymv_,:Float32),
(:zsymv_,:Complex128),
(:csymv_,:Complex64))
for (fname, elty) in ((@blas_suffix(:dsymv_),:Float64),
(@blas_suffix(:ssymv_),:Float32),
(@blas_suffix(:zsymv_),:Complex128),
(@blas_suffix(:csymv_),:Complex64))
# Note that the complex symv are not BLAS but auiliary functions in LAPACK
@eval begin
# SUBROUTINE DSYMV(UPLO,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
Expand Down Expand Up @@ -365,8 +365,8 @@ for (fname, elty) in ((:dsymv_,:Float64),
end

### hemv
for (fname, elty) in ((:zhemv_,:Complex128),
(:chemv_,:Complex64))
for (fname, elty) in ((@blas_suffix(:zhemv_),:Complex128),
(@blas_suffix(:chemv_),:Complex64))
@eval begin
function hemv!(uplo::Char, α::$elty, A::StridedMatrix{$elty}, x::StridedVector{$elty}, β::$elty, y::StridedVector{$elty})
n = size(A, 2)
Expand Down Expand Up @@ -394,10 +394,10 @@ for (fname, elty) in ((:zhemv_,:Complex128),
end

### sbmv, (SB) symmetric banded matrix-vector multiplication
for (fname, elty) in ((:dsbmv_,:Float64),
(:ssbmv_,:Float32),
(:zsbmv_,:Complex128),
(:csbmv_,:Complex64))
for (fname, elty) in ((@blas_suffix(:dsbmv_),:Float64),
(@blas_suffix(:ssbmv_),:Float32),
(@blas_suffix(:zsbmv_),:Complex128),
(@blas_suffix(:csbmv_),:Complex64))
@eval begin
# SUBROUTINE DSBMV(UPLO,N,K,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
# * .. Scalar Arguments ..
Expand Down Expand Up @@ -427,10 +427,10 @@ for (fname, elty) in ((:dsbmv_,:Float64),
end

### trmv, Triangular matrix-vector multiplication
for (fname, elty) in ((:dtrmv_,:Float64),
(:strmv_,:Float32),
(:ztrmv_,:Complex128),
(:ctrmv_,:Complex64))
for (fname, elty) in ((@blas_suffix(:dtrmv_),:Float64),
(@blas_suffix(:strmv_),:Float32),
(@blas_suffix(:ztrmv_),:Complex128),
(@blas_suffix(:ctrmv_),:Complex64))
@eval begin
# SUBROUTINE DTRMV(UPLO,TRANS,DIAG,N,A,LDA,X,INCX)
# * .. Scalar Arguments ..
Expand All @@ -456,10 +456,10 @@ for (fname, elty) in ((:dtrmv_,:Float64),
end
end
### trsv, Triangular matrix-vector solve
for (fname, elty) in ((:dtrsv_,:Float64),
(:strsv_,:Float32),
(:ztrsv_,:Complex128),
(:ctrsv_,:Complex64))
for (fname, elty) in ((@blas_suffix(:dtrsv_),:Float64),
(@blas_suffix(:strsv_),:Float32),
(@blas_suffix(:ztrsv_),:Complex128),
(@blas_suffix(:ctrsv_),:Complex64))
@eval begin
# SUBROUTINE DTRSV(UPLO,TRANS,DIAG,N,A,LDA,X,INCX)
# .. Scalar Arguments ..
Expand All @@ -484,10 +484,10 @@ for (fname, elty) in ((:dtrsv_,:Float64),
end

### ger
for (fname, elty) in ((:dger_,:Float64),
(:sger_,:Float32),
(:zgerc_,:Complex128),
(:cgerc_,:Complex64))
for (fname, elty) in ((@blas_suffix(:dger_),:Float64),
(@blas_suffix(:sger_),:Float32),
(@blas_suffix(:zgerc_),:Complex128),
(@blas_suffix(:cgerc_),:Complex64))
@eval begin
function ger!::$elty, x::StridedVector{$elty}, y::StridedVector{$elty}, A::StridedMatrix{$elty})
m, n = size(A)
Expand All @@ -506,10 +506,10 @@ for (fname, elty) in ((:dger_,:Float64),
end

### syr
for (fname, elty) in ((:dsyr_,:Float64),
(:ssyr_,:Float32),
(:zsyr_,:Complex128),
(:csyr_,:Complex64))
for (fname, elty) in ((@blas_suffix(:dsyr_),:Float64),
(@blas_suffix(:ssyr_),:Float32),
(@blas_suffix(:zsyr_),:Complex128),
(@blas_suffix(:csyr_),:Complex64))
@eval begin
function syr!(uplo::Char, α::$elty, x::StridedVector{$elty}, A::StridedMatrix{$elty})
n = chksquare(A)
Expand All @@ -525,8 +525,8 @@ for (fname, elty) in ((:dsyr_,:Float64),
end

### her
for (fname, elty) in ((:zher_,:Complex128),
(:cher_,:Complex64))
for (fname, elty) in ((@blas_suffix(:zher_),:Complex128),
(@blas_suffix(:cher_),:Complex64))
@eval begin
function her!(uplo::Char, α::$elty, x::StridedVector{$elty}, A::StridedMatrix{$elty})
n = chksquare(A)
Expand All @@ -544,10 +544,10 @@ end
# Level 3
## (GE) general matrix-matrix multiplication
for (gemm, elty) in
((:dgemm_,:Float64),
(:sgemm_,:Float32),
(:zgemm_,:Complex128),
(:cgemm_,:Complex64))
((@blas_suffix(:dgemm_),:Float64),
(@blas_suffix(:sgemm_),:Float32),
(@blas_suffix(:zgemm_),:Complex128),
(@blas_suffix(:cgemm_),:Complex64))
@eval begin
# SUBROUTINE DGEMM(TRANSA,TRANSB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
# * .. Scalar Arguments ..
Expand Down Expand Up @@ -587,10 +587,10 @@ for (gemm, elty) in
end

## (SY) symmetric matrix-matrix and matrix-vector multiplication
for (mfname, elty) in ((:dsymm_,:Float64),
(:ssymm_,:Float32),
(:zsymm_,:Complex128),
(:csymm_,:Complex64))
for (mfname, elty) in ((@blas_suffix(:dsymm_),:Float64),
(@blas_suffix(:ssymm_),:Float32),
(@blas_suffix(:zsymm_),:Complex128),
(@blas_suffix(:csymm_),:Complex64))
@eval begin
# SUBROUTINE DSYMM(SIDE,UPLO,M,N,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
# .. Scalar Arguments ..
Expand Down Expand Up @@ -622,10 +622,10 @@ for (mfname, elty) in ((:dsymm_,:Float64),
end

## syrk
for (fname, elty) in ((:dsyrk_,:Float64),
(:ssyrk_,:Float32),
(:zsyrk_,:Complex128),
(:csyrk_,:Complex64))
for (fname, elty) in ((@blas_suffix(:dsyrk_),:Float64),
(@blas_suffix(:ssyrk_),:Float32),
(@blas_suffix(:zsyrk_),:Complex128),
(@blas_suffix(:csyrk_),:Complex64))
@eval begin
# SUBROUTINE DSYRK(UPLO,TRANS,N,K,ALPHA,A,LDA,BETA,C,LDC)
# * .. Scalar Arguments ..
Expand Down Expand Up @@ -659,7 +659,7 @@ function syrk(uplo::BlasChar, trans::BlasChar, alpha::Number, A::StridedVecOrMat
end
syrk(uplo::BlasChar, trans::BlasChar, A::StridedVecOrMat) = syrk(uplo, trans, one(eltype(A)), A)

for (fname, elty) in ((:zherk_,:Complex128), (:cherk_,:Complex64))
for (fname, elty) in ((@blas_suffix(:zherk_),:Complex128), (@blas_suffix(:cherk_),:Complex64))
@eval begin
# SUBROUTINE CHERK(UPLO,TRANS,N,K,ALPHA,A,LDA,BETA,C,LDC)
# * .. Scalar Arguments ..
Expand Down Expand Up @@ -692,10 +692,10 @@ for (fname, elty) in ((:zherk_,:Complex128), (:cherk_,:Complex64))
end

## syr2k
for (fname, elty) in ((:dsyr2k_,:Float64),
(:ssyr2k_,:Float32),
(:zsyr2k_,:Complex128),
(:csyr2k_,:Complex64))
for (fname, elty) in ((@blas_suffix(:dsyr2k_),:Float64),
(@blas_suffix(:ssyr2k_),:Float32),
(@blas_suffix(:zsyr2k_),:Complex128),
(@blas_suffix(:csyr2k_),:Complex64))
@eval begin
# SUBROUTINE DSYR2K(UPLO,TRANS,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
#
Expand Down Expand Up @@ -731,7 +731,7 @@ function syr2k(uplo::BlasChar, trans::BlasChar, alpha::Number, A::StridedVecOrMa
end
syr2k(uplo::BlasChar, trans::BlasChar, A::StridedVecOrMat, B::StridedVecOrMat) = syr2k(uplo, trans, one(eltype(A)), A, B)

for (fname, elty1, elty2) in ((:zher2k_,:Complex128,:Float64), (:cher2k_,:Complex64,:Float32))
for (fname, elty1, elty2) in ((@blas_suffix(:zher2k_),:Complex128,:Float64), (@blas_suffix(:cher2k_),:Complex64,:Float32))
@eval begin
# SUBROUTINE CHER2K(UPLO,TRANS,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
#
Expand Down Expand Up @@ -768,10 +768,10 @@ end

## (TR) Triangular matrix and vector multiplication and solution
for (mmname, smname, elty) in
((:dtrmm_,:dtrsm_,:Float64),
(:strmm_,:strsm_,:Float32),
(:ztrmm_,:ztrsm_,:Complex128),
(:ctrmm_,:ctrsm_,:Complex64))
((@blas_suffix(:dtrmm_),@blas_suffix(:dtrsm_),:Float64),
(@blas_suffix(:strmm_),@blas_suffix(:strsm_),:Float32),
(@blas_suffix(:ztrmm_),@blas_suffix(:ztrsm_),:Complex128),
(@blas_suffix(:ctrmm_),@blas_suffix(:ctrsm_),:Complex64))
@eval begin
# SUBROUTINE DTRMM(SIDE,UPLO,TRANSA,DIAG,M,N,ALPHA,A,LDA,B,LDB)
# * .. Scalar Arguments ..
Expand Down
Loading

0 comments on commit 555bc3f

Please sign in to comment.