Skip to content
Browse files

Change implementation of symmetric matrix and add hermitian too

  • Loading branch information...
1 parent 66d98d0 commit 8590a11b3e72c875a06c10de0a3a2be9932103c8 @Shimuuar committed
Showing with 231 additions and 54 deletions.
  1. +61 −12 Data/Matrix/Symmetric.hs
  2. +153 −25 Data/Matrix/Symmetric/Mutable.hs
  3. +6 −6 Numeric/BLAS.hs
  4. +11 −11 Numeric/BLAS/Mutable/Unsafe.hs
View
73 Data/Matrix/Symmetric.hs
@@ -14,6 +14,17 @@
module Data.Matrix.Symmetric (
-- * Matrix data type
Symmetric
+ , Hermitian
+ -- ** Implementation
+ , SymmetricRaw
+ , IsSymmetric
+ , IsHermitian
+ -- * Complex numbers
+ , NumberType
+ , IsReal
+ , IsComplex
+ , castSymmetric
+ , Conjugate(..)
) where
import Control.Monad.Primitive
@@ -26,34 +37,72 @@ import Foreign.ForeignPtr
import Foreign.Storable
import qualified Data.Matrix.Symmetric.Mutable as M
+import Data.Matrix.Symmetric.Mutable
+ ( IsSymmetric, IsHermitian, NumberType, IsReal, IsComplex, Conjugate(..) )
import Data.Matrix.Generic
+import Unsafe.Coerce
+
----------------------------------------------------------------
-- Data type
----------------------------------------------------------------
--- | Immutable dense matrix
-data Symmetric a = Symmetric {-# UNPACK #-} !Int -- N of rows
- {-# UNPACK #-} !Int -- Leading dim size
- {-# UNPACK #-} !(ForeignPtr a)
+-- | Immutable symmetric or hermitian matrix
+data SymmetricRaw tag a = SymmetricRaw
+ {-# UNPACK #-} !Int -- N of rows
+ {-# UNPACK #-} !Int -- Leading dim size
+ {-# UNPACK #-} !(ForeignPtr a)
deriving ( Typeable )
-type instance G.Mutable Symmetric = M.MSymmetric
+type Symmetric = SymmetricRaw IsSymmetric
+
+type Hermitian = SymmetricRaw IsHermitian
+
+
+type instance G.Mutable (SymmetricRaw tag) = M.MSymmetricRaw tag
-instance NFData (Symmetric a)
-instance Storable a => IsMatrix Symmetric a where
- basicRows (Symmetric n _ _) = n
+
+instance NFData (SymmetricRaw tag a)
+
+instance Storable a => IsMatrix (SymmetricRaw IsSymmetric) a where
+ basicRows (SymmetricRaw n _ _) = n
{-# INLINE basicRows #-}
- basicCols (Symmetric n _ _) = n
+ basicCols (SymmetricRaw n _ _) = n
{-# INLINE basicCols #-}
- basicUnsafeIndex (Symmetric _ lda fp) (M.symmIndex -> (i,j))
+ basicUnsafeIndex (SymmetricRaw _ lda fp) (M.symmIndex -> (i,j))
= unsafeInlineIO $ withForeignPtr fp $ \p -> peekElemOff p (i + lda * j)
{-# INLINE basicUnsafeIndex #-}
- basicUnsafeThaw (Symmetric n lda fp) = return $! M.MSymmetric n lda fp
+ basicUnsafeThaw (SymmetricRaw n lda fp) = return $! M.MSymmetricRaw n lda fp
{-# INLINE basicUnsafeThaw #-}
- basicUnsafeFreeze (M.MSymmetric n lda fp) = return $! Symmetric n lda fp
+ basicUnsafeFreeze (M.MSymmetricRaw n lda fp) = return $! SymmetricRaw n lda fp
{-# INLINE basicUnsafeFreeze #-}
+
+
+instance (M.Conjugate a, Storable a) => IsMatrix (SymmetricRaw IsHermitian) a where
+ basicRows (SymmetricRaw n _ _) = n
+ {-# INLINE basicRows #-}
+ basicCols (SymmetricRaw n _ _) = n
+ {-# INLINE basicCols #-}
+ basicUnsafeIndex (SymmetricRaw _ lda fp) (M.symmIndex -> (i,j))
+ = unsafeInlineIO
+ $ withForeignPtr fp $ \p ->
+ case () of
+ _| i > j -> conjugateNum `fmap` peekElemOff p (j + i*lda)
+ | otherwise -> peekElemOff p (i + j*lda)
+ {-# INLINE basicUnsafeIndex #-}
+ basicUnsafeThaw (SymmetricRaw n lda fp) = return $! M.MSymmetricRaw n lda fp
+ {-# INLINE basicUnsafeThaw #-}
+ basicUnsafeFreeze (M.MSymmetricRaw n lda fp) = return $! SymmetricRaw n lda fp
+ {-# INLINE basicUnsafeFreeze #-}
+
+
+-- | Cast between symmetric and hermitian matrices is data parameter
+-- is real.
+castSymmetric :: (NumberType a ~ IsReal)
+ => SymmetricRaw tag a -> SymmetricRaw tag' a
+{-# INLINE castSymmetric #-}
+castSymmetric = unsafeCoerce
View
178 Data/Matrix/Symmetric/Mutable.hs
@@ -1,3 +1,5 @@
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE BangPatterns #-}
@@ -10,16 +12,30 @@
-- Maintainer : Aleksey Khudyakov <alexey.skladnoy@gmail.com>
-- Stability : experimental
--
--- Symmetric matrices
+-- Symmetric and hermitian matrices
module Data.Matrix.Symmetric.Mutable (
- MSymmetric(..)
+ -- * Data types
+ MSymmetric
+ , MHermitian
+ -- ** Implementation
+ , MSymmetricRaw(..)
+ , IsSymmetric
+ , IsHermitian
+ -- * Function
, new
, symmIndex
+ -- * Complex number
+ , NumberType
+ , IsReal
+ , IsComplex
+ , castSymmetric
+ , Conjugate(..)
) where
import Control.Monad
import Control.Monad.Primitive
+import Data.Complex (Complex,conjugate)
import Data.Typeable (Typeable)
import Data.Vector.Storable.Internal
import qualified Data.Vector.Generic.Mutable as M
@@ -33,59 +49,171 @@ import Data.Internal
import Data.Vector.Storable.Strided.Mutable
import Data.Matrix.Generic.Mutable
+import Unsafe.Coerce
--- | Mutable symmetric matrix. Storage takes n² elements and data is
--- stored in column major order
-data MSymmetric s a = MSymmetric {-# UNPACK #-} !Int -- Order of matrix
- {-# UNPACK #-} !Int -- Leading dimension size
- {-# UNPACK #-} !(ForeignPtr a)
+
+----------------------------------------------------------------
+-- Data types
+----------------------------------------------------------------
+
+-- | Symmetric/hermitian matrix. Whether it's symmetric of hermitian
+-- is determined by type tag. See 'IsSymmetric' and 'IsHermitian'.
+--
+-- Storage takes n² elements and data is stored in column major
+-- order. Fields are
+--
+-- * Order of matrix
+--
+-- * Leading dimension size
+--
+-- * Pointer to data
+data MSymmetricRaw tag s a = MSymmetricRaw
+ {-# UNPACK #-} !Int -- Order of matrix
+ {-# UNPACK #-} !Int -- Leading dimension size
+ {-# UNPACK #-} !(ForeignPtr a)
deriving (Typeable)
-instance Storable a => IsMMatrix MSymmetric a where
- basicRows (MSymmetric n _ _) = n
+-- | Type tag for symmetric matrices.
+data IsSymmetric
+
+-- | Type tag for hermitian matrices.
+data IsHermitian
+
+-- | Mutable symmetric matrix
+type MSymmetric = MSymmetricRaw IsSymmetric
+
+-- | Mutable hermitian matrix
+type MHermitian = MSymmetricRaw IsHermitian
+
+
+instance Storable a => IsMMatrix (MSymmetricRaw IsSymmetric) a where
+ basicRows (MSymmetricRaw n _ _) = n
{-# INLINE basicRows #-}
- basicCols (MSymmetric _ n _) = n
+ basicCols (MSymmetricRaw _ n _) = n
{-# INLINE basicCols #-}
basicIsIndexMutable _ _ = True
{-# INLINE basicIsIndexMutable #-}
- basicUnsafeRead (MSymmetric _ lda fp) (symmIndex -> (!i,!j))
+ basicUnsafeRead (MSymmetricRaw _ lda fp) (symmIndex -> (!i,!j))
= unsafePrimToPrim
$ withForeignPtr fp (`peekElemOff` (i + j*lda))
{-# INLINE basicUnsafeRead #-}
- basicUnsafeWrite (MSymmetric _ lda fp) (symmIndex -> (!i,!j)) x
+ basicUnsafeWrite (MSymmetricRaw _ lda fp) (symmIndex -> (!i,!j)) x
= unsafePrimToPrim
$ withForeignPtr fp $ \p -> pokeElemOff p (i + j*lda) x
{-# INLINE basicUnsafeWrite #-}
- basicCloneShape m
- = new (rows m)
+ basicCloneShape = new . rows
+ {-# INLINE basicCloneShape #-}
+ basicClone = cloneSym
+ {-# INLINE basicClone #-}
+
+instance (Conjugate a, Storable a) => IsMMatrix (MSymmetricRaw IsHermitian) a where
+ basicRows (MSymmetricRaw n _ _) = n
+ {-# INLINE basicRows #-}
+ basicCols (MSymmetricRaw _ n _) = n
+ {-# INLINE basicCols #-}
+ basicIsIndexMutable _ _ = True
+ {-# INLINE basicIsIndexMutable #-}
+ basicUnsafeRead (MSymmetricRaw _ lda fp) (!i,!j)
+ = unsafePrimToPrim
+ $ case () of
+ _| i > j -> conjugateNum `liftM` withForeignPtr fp (`peekElemOff` (j + i*lda))
+ | otherwise -> withForeignPtr fp (`peekElemOff` (i + j*lda))
+ {-# INLINE basicUnsafeRead #-}
+ basicUnsafeWrite (MSymmetricRaw _ lda fp) (!i,!j) x
+ = unsafePrimToPrim
+ $ case () of
+ _| i > j -> withForeignPtr fp $ \p -> pokeElemOff p (j + i*lda) (conjugateNum x)
+ | otherwise -> withForeignPtr fp $ \p -> pokeElemOff p (i + j*lda) x
+ {-# INLINE basicUnsafeWrite #-}
+ basicCloneShape = new . rows
{-# INLINE basicCloneShape #-}
- basicClone m = do
- q <- basicCloneShape m
- forM_ [0 .. cols m - 1] $ \i ->
- M.unsafeCopy (unsafeGetCol q i) (unsafeGetCol m i)
- return q
+ basicClone = cloneSym
+ {-# INLINE basicClone #-}
--- Choose index so upper part of matrix is accessed
+
+
+-- | Choose index so upper part of matrix is accessed
symmIndex :: (Int,Int) -> (Int,Int)
{-# INLINE symmIndex #-}
symmIndex (i,j)
| i > j = (j,i)
| otherwise = (i,j)
--- | Allocate new matrix
+-- | Allocate new matrix. It works for both symmetric and hermitian
+-- matrices.
new :: (PrimMonad m, Storable a)
=> Int -- ^ Matrix order
- -> m (MSymmetric (PrimState m) a)
+ -> m (MSymmetricRaw tag (PrimState m) a)
{-# INLINE new #-}
new n = do
fp <- unsafePrimToPrim $ mallocVector $ n * n
- return $ MSymmetric n n fp
+ return $ MSymmetricRaw n n fp
+
+
+
+----------------------------------------------------------------
+-- Real/Complex distinction
+----------------------------------------------------------------
+
+-- | Type tag for real numbers
+data IsReal
+
+-- | Type tag for complex numbers
+data IsComplex
+
+type family NumberType a :: *
+
+type instance NumberType Float = IsReal
+type instance NumberType Double = IsReal
+type instance NumberType (Complex a) = IsComplex
+
+-- | Cast between symmetric and hermitian matrices is data parameter
+-- is real.
+castSymmetric :: (NumberType a ~ IsReal)
+ => MSymmetricRaw tag s a -> MSymmetricRaw tag' s a
+{-# INLINE castSymmetric #-}
+castSymmetric = unsafeCoerce
+
+
+
+-- | Conjugate which works for both real (noop) and complex values.
+class Conjugate a where
+ conjugateNum :: a -> a
+
+instance Conjugate Float where
+ conjugateNum = id
+ {-# INLINE conjugateNum #-}
+
+instance Conjugate Double where
+ conjugateNum = id
+ {-# INLINE conjugateNum #-}
+
+instance RealFloat a => Conjugate (Complex a) where
+ conjugateNum = conjugate
+ {-# INLINE conjugateNum #-}
+
+
+
+
+----------------------------------------------------------------
+-- Helpers
+----------------------------------------------------------------
-- Get n'th column of matrix as mutable vector. Internal since part of
-- vector contain junk
-unsafeGetCol :: Storable a => MSymmetric s a -> Int -> MVector s a
+unsafeGetCol :: Storable a => MSymmetricRaw tag s a -> Int -> MVector s a
{-# INLINE unsafeGetCol #-}
-unsafeGetCol (MSymmetric n lda fp) i
+unsafeGetCol (MSymmetricRaw n lda fp) i
= MVector n 1 $ updPtr (`advancePtr` (i*lda)) fp
+
+cloneSym :: (Storable a, PrimMonad m)
+ => MSymmetricRaw tag (PrimState m) a
+ -> m (MSymmetricRaw tag (PrimState m) a)
+{-# INLINE cloneSym #-}
+cloneSym m@(MSymmetricRaw n _ _) = do
+ q <- new n
+ forM_ [0 .. n - 1] $ \i ->
+ M.unsafeCopy (unsafeGetCol q i) (unsafeGetCol m i)
+ return q
View
12 Numeric/BLAS.hs
@@ -254,12 +254,12 @@ instance (BLAS2 a, a ~ a') => Mul (Conjugated Matrix a) (S.Vector a') where
{-# INLINE (.*.) #-}
-instance (BLAS2 a, a ~ a') => Mul (Symmetric a) (S.Vector a') where
- type MulRes (Symmetric a)
- (S.Vector a')
- = S.Vector a
- m .*. v = eval $ MulMV () (Lit m) (Lit v)
- {-# INLINE (.*.) #-}
+-- instance (BLAS2 a, a ~ a') => Mul (Symmetric a) (S.Vector a') where
+-- type MulRes (Symmetric a)
+-- (S.Vector a')
+-- = S.Vector a
+-- m .*. v = eval $ MulMV () (Lit m) (Lit v)
+-- {-# INLINE (.*.) #-}
View
22 Numeric/BLAS/Mutable/Unsafe.hs
@@ -206,17 +206,17 @@ instance S.Storable a => MultTMV MMatrix a where
b py (blasStride y)
{-# INLINE unsafeMultTMV #-}
-instance S.Storable a => MultMV MSymmetric a where
- unsafeMultMV α (MSymmetric n lda fp) x β y
- = unsafePrimToPrim
- $ withForeignPtr fp $ \pa ->
- withForeignPtr (blasFPtr x) $ \px ->
- withForeignPtr (blasFPtr y) $ \py ->
- BLAS.hemv ColMajor Upper
- n α pa lda
- px (blasStride x)
- β py (blasStride y)
- {-# INLINE unsafeMultMV #-}
+-- instance S.Storable a => MultMV MSymmetric a where
+-- unsafeMultMV α (MSymmetric n lda fp) x β y
+-- = unsafePrimToPrim
+-- $ withForeignPtr fp $ \pa ->
+-- withForeignPtr (blasFPtr x) $ \px ->
+-- withForeignPtr (blasFPtr y) $ \py ->
+-- BLAS.hemv ColMajor Upper
+-- n α pa lda
+-- px (blasStride x)
+-- β py (blasStride y)
+-- {-# INLINE unsafeMultMV #-}

0 comments on commit 8590a11

Please sign in to comment.
Something went wrong with that request. Please try again.