Skip to content

Commit

Permalink
Fix matrix vetor multiplication for symmetric/hermitian matrices
Browse files Browse the repository at this point in the history
*hemv works for hermitian and not for symmetric matrices.
  • Loading branch information
Shimuuar committed Aug 28, 2012
1 parent 94cc0f8 commit d8887e9
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 21 deletions.
22 changes: 15 additions & 7 deletions Numeric/BLAS.hs
Expand Up @@ -47,7 +47,8 @@ import qualified Data.Vector.Storable as S
import qualified Data.Vector.Storable.Strided as V
-- Concrete matrices
import Data.Matrix.Dense (Matrix)
import Data.Matrix.Symmetric (Symmetric)
import Data.Matrix.Symmetric
(SymmetricRaw,IsSymmetric,IsHermitian,Conjugate(..),NumberType,IsReal)

import qualified Numeric.BLAS.Mutable as M

Expand Down Expand Up @@ -254,12 +255,19 @@ instance (BLAS2 a, a ~ a') => Mul (Conjugated Matrix a) (S.Vector a') where
{-# INLINE (.*.) #-}


-- instance (BLAS2 a, a ~ a') => Mul (Symmetric a) (S.Vector a') where
-- type MulRes (Symmetric a)
-- (S.Vector a')
-- = S.Vector a
-- m .*. v = eval $ MulMV () (Lit m) (Lit v)
-- {-# INLINE (.*.) #-}
instance (BLAS2 a, Conjugate a, a ~ a') => Mul (SymmetricRaw IsHermitian a) (S.Vector a') where
type MulRes (SymmetricRaw IsHermitian a)
(S.Vector a')
= S.Vector a
m .*. v = eval $ MulMV () (Lit m) (Lit v)
{-# INLINE (.*.) #-}

instance (BLAS2 a, NumberType a ~ IsReal, a ~ a') => Mul (SymmetricRaw IsSymmetric a) (S.Vector a') where
type MulRes (SymmetricRaw IsSymmetric a)
(S.Vector a')
= S.Vector a
m .*. v = eval $ MulMV () (Lit m) (Lit v)
{-# INLINE (.*.) #-}



Expand Down
11 changes: 9 additions & 2 deletions Numeric/BLAS/Expression.hs
Expand Up @@ -460,10 +460,17 @@ instance Storable a => Freeze MatD.Matrix a where
unsafeFreeze = Mat.unsafeFreeze
unsafeThaw = Mat.unsafeThaw

instance Storable a => Clonable MMatS.MSymmetric a where
instance (Storable a) => Clonable (MMatS.MSymmetricRaw MMatS.IsSymmetric) a where
cloneShape = MMat.cloneShape
clone = MMat.clone
instance Storable a => Freeze MatS.Symmetric a where
instance Storable a => Freeze (MatS.SymmetricRaw MMatS.IsSymmetric) a where
unsafeFreeze = Mat.unsafeFreeze
unsafeThaw = Mat.unsafeThaw

instance (Storable a, MMatS.Conjugate a) => Clonable (MMatS.MSymmetricRaw MMatS.IsHermitian) a where
cloneShape = MMat.cloneShape
clone = MMat.clone
instance (Storable a, MMatS.Conjugate a) => Freeze (MatS.SymmetricRaw MMatS.IsHermitian) a where
unsafeFreeze = Mat.unsafeFreeze
unsafeThaw = Mat.unsafeThaw

Expand Down
34 changes: 22 additions & 12 deletions Numeric/BLAS/Mutable/Unsafe.hs
@@ -1,3 +1,4 @@
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
-- |
Expand Down Expand Up @@ -66,7 +67,8 @@ import qualified Data.Vector.Storable.Strided.Mutable as V
import qualified Data.Matrix.Generic.Mutable as M
-- Concrete matrices
import Data.Matrix.Dense.Mutable (MMatrix(..))
import Data.Matrix.Symmetric.Mutable (MSymmetric(..))
import Data.Matrix.Symmetric.Mutable
(MSymmetricRaw(..),IsSymmetric,IsHermitian,Conjugate(..), NumberType, IsReal)



Expand Down Expand Up @@ -206,18 +208,26 @@ instance S.Storable a => MultTMV MMatrix a where
b py (blasStride y)
{-# INLINE unsafeMultTMV #-}

-- instance S.Storable a => MultMV MSymmetric a where
-- unsafeMultMV α (MSymmetric n lda fp) x β y
-- = unsafePrimToPrim
-- $ withForeignPtr fp $ \pa ->
-- withForeignPtr (blasFPtr x) $ \px ->
-- withForeignPtr (blasFPtr y) $ \py ->
-- BLAS.hemv ColMajor Upper
-- n α pa lda
-- px (blasStride x)
-- β py (blasStride y)
-- {-# INLINE unsafeMultMV #-}
instance (S.Storable a, NumberType a ~ IsReal) => MultMV (MSymmetricRaw IsSymmetric) a where
unsafeMultMV = multHermtianMV
{-# INLINE unsafeMultMV #-}
instance (S.Storable a, Conjugate a) => MultMV (MSymmetricRaw IsHermitian) a where
unsafeMultMV = multHermtianMV
{-# INLINE unsafeMultMV #-}

-- Worker for symmetric/hermitian matrix multiplication
multHermtianMV :: (PrimMonad m, BLAS2 a, MVectorBLAS v)
=> a -> MSymmetricRaw tag s a -> v s a -> a -> v s a -> m ()
{-# INLINE multHermtianMV #-}
multHermtianMV α (MSymmetricRaw n lda fp) x β y
= unsafePrimToPrim
$ withForeignPtr fp $ \pa ->
withForeignPtr (blasFPtr x) $ \px ->
withForeignPtr (blasFPtr y) $ \py ->
BLAS.hemv ColMajor Upper
n α pa lda
px (blasStride x)
β py (blasStride y)


-- | Compute vector-vector product:
Expand Down

0 comments on commit d8887e9

Please sign in to comment.