Permalink
Browse files

Fix matrix vetor multiplication for symmetric/hermitian matrices

*hemv works for hermitian and not for symmetric matrices.
  • Loading branch information...
1 parent 94cc0f8 commit d8887e9a25840876f13cd9729c9d3e3b73df36b1 @Shimuuar committed Aug 28, 2012
Showing with 46 additions and 21 deletions.
  1. +15 −7 Numeric/BLAS.hs
  2. +9 −2 Numeric/BLAS/Expression.hs
  3. +22 −12 Numeric/BLAS/Mutable/Unsafe.hs
View
@@ -47,7 +47,8 @@ import qualified Data.Vector.Storable as S
import qualified Data.Vector.Storable.Strided as V
-- Concrete matrices
import Data.Matrix.Dense (Matrix)
-import Data.Matrix.Symmetric (Symmetric)
+import Data.Matrix.Symmetric
+ (SymmetricRaw,IsSymmetric,IsHermitian,Conjugate(..),NumberType,IsReal)
import qualified Numeric.BLAS.Mutable as M
@@ -254,12 +255,19 @@ instance (BLAS2 a, a ~ a') => Mul (Conjugated Matrix a) (S.Vector a') where
{-# INLINE (.*.) #-}
--- instance (BLAS2 a, a ~ a') => Mul (Symmetric a) (S.Vector a') where
--- type MulRes (Symmetric a)
--- (S.Vector a')
--- = S.Vector a
--- m .*. v = eval $ MulMV () (Lit m) (Lit v)
--- {-# INLINE (.*.) #-}
+instance (BLAS2 a, Conjugate a, a ~ a') => Mul (SymmetricRaw IsHermitian a) (S.Vector a') where
+ type MulRes (SymmetricRaw IsHermitian a)
+ (S.Vector a')
+ = S.Vector a
+ m .*. v = eval $ MulMV () (Lit m) (Lit v)
+ {-# INLINE (.*.) #-}
+
+instance (BLAS2 a, NumberType a ~ IsReal, a ~ a') => Mul (SymmetricRaw IsSymmetric a) (S.Vector a') where
+ type MulRes (SymmetricRaw IsSymmetric a)
+ (S.Vector a')
+ = S.Vector a
+ m .*. v = eval $ MulMV () (Lit m) (Lit v)
+ {-# INLINE (.*.) #-}
View
@@ -460,10 +460,17 @@ instance Storable a => Freeze MatD.Matrix a where
unsafeFreeze = Mat.unsafeFreeze
unsafeThaw = Mat.unsafeThaw
-instance Storable a => Clonable MMatS.MSymmetric a where
+instance (Storable a) => Clonable (MMatS.MSymmetricRaw MMatS.IsSymmetric) a where
cloneShape = MMat.cloneShape
clone = MMat.clone
-instance Storable a => Freeze MatS.Symmetric a where
+instance Storable a => Freeze (MatS.SymmetricRaw MMatS.IsSymmetric) a where
+ unsafeFreeze = Mat.unsafeFreeze
+ unsafeThaw = Mat.unsafeThaw
+
+instance (Storable a, MMatS.Conjugate a) => Clonable (MMatS.MSymmetricRaw MMatS.IsHermitian) a where
+ cloneShape = MMat.cloneShape
+ clone = MMat.clone
+instance (Storable a, MMatS.Conjugate a) => Freeze (MatS.SymmetricRaw MMatS.IsHermitian) a where
unsafeFreeze = Mat.unsafeFreeze
unsafeThaw = Mat.unsafeThaw
@@ -1,3 +1,4 @@
+{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
-- |
@@ -66,7 +67,8 @@ import qualified Data.Vector.Storable.Strided.Mutable as V
import qualified Data.Matrix.Generic.Mutable as M
-- Concrete matrices
import Data.Matrix.Dense.Mutable (MMatrix(..))
-import Data.Matrix.Symmetric.Mutable (MSymmetric(..))
+import Data.Matrix.Symmetric.Mutable
+ (MSymmetricRaw(..),IsSymmetric,IsHermitian,Conjugate(..), NumberType, IsReal)
@@ -206,18 +208,26 @@ instance S.Storable a => MultTMV MMatrix a where
b py (blasStride y)
{-# INLINE unsafeMultTMV #-}
--- instance S.Storable a => MultMV MSymmetric a where
--- unsafeMultMV α (MSymmetric n lda fp) x β y
--- = unsafePrimToPrim
--- $ withForeignPtr fp $ \pa ->
--- withForeignPtr (blasFPtr x) $ \px ->
--- withForeignPtr (blasFPtr y) $ \py ->
--- BLAS.hemv ColMajor Upper
--- n α pa lda
--- px (blasStride x)
--- β py (blasStride y)
--- {-# INLINE unsafeMultMV #-}
+instance (S.Storable a, NumberType a ~ IsReal) => MultMV (MSymmetricRaw IsSymmetric) a where
+ unsafeMultMV = multHermtianMV
+ {-# INLINE unsafeMultMV #-}
+instance (S.Storable a, Conjugate a) => MultMV (MSymmetricRaw IsHermitian) a where
+ unsafeMultMV = multHermtianMV
+ {-# INLINE unsafeMultMV #-}
+-- Worker for symmetric/hermitian matrix multiplication
+multHermtianMV :: (PrimMonad m, BLAS2 a, MVectorBLAS v)
+ => a -> MSymmetricRaw tag s a -> v s a -> a -> v s a -> m ()
+{-# INLINE multHermtianMV #-}
+multHermtianMV α (MSymmetricRaw n lda fp) x β y
+ = unsafePrimToPrim
+ $ withForeignPtr fp $ \pa ->
+ withForeignPtr (blasFPtr x) $ \px ->
+ withForeignPtr (blasFPtr y) $ \py ->
+ BLAS.hemv ColMajor Upper
+ n α pa lda
+ px (blasStride x)
+ β py (blasStride y)
-- | Compute vector-vector product:

0 comments on commit d8887e9

Please sign in to comment.