/
Generic.hs
191 lines (157 loc) · 6.01 KB
/
Generic.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
-- |
-- Module : Data.Matrix.Generic
-- Copyright : Copyright (c) 2012 Aleksey Khudyakov <alexey.skladnoy@gmail.com>
-- License : BSD3
-- Maintainer : Aleksey Khudyakov <alexey.skladnoy@gmail.com>
-- Stability : experimental
--
-- Interface for generic immutable matrices. Matrix transposition and
-- conjugate transposition are represented by newtype wrappers.
module Data.Matrix.Generic (
-- * Type class
IsMatrix(..)
, Mutable
-- * Accessors
, rows
, cols
, shape
, (@!)
, unsafeIndex
-- * Conversions to/from mutable
, freeze
, thaw
, unsafeFreeze
, unsafeThaw
-- * Conversion to string
, showMatrixWith
-- * Newtype wrappers
, Transposed(..)
, Conjugated(..)
) where
import Control.Monad (liftM)
import Control.Monad.Primitive
import Data.List (intercalate)
import Data.Complex (Complex,conjugate)
import qualified Data.Matrix.Generic.Mutable as M
import Data.Vector.Generic (Mutable)
----------------------------------------------------------------
-- Type class
----------------------------------------------------------------
-- | Basic API for immutable matrices. Since there are many ways to lay
-- out a matrix in memory, there are few operations which work for all
-- of them.
--
-- Methods of this type class shouldn't be used directly; use the
-- wrapper functions in this module ('rows', 'cols', 'unsafeIndex',
-- 'unsafeFreeze', 'unsafeThaw') instead.
--
-- The superclass constraint requires that the mutable counterpart
-- @'Mutable' mat@ of every instance is itself an instance of
-- 'M.IsMMatrix' for the same element type.
class M.IsMMatrix (Mutable mat) a => IsMatrix mat a where
  -- | Number of rows
  basicRows :: mat a -> Int
  -- | Number of columns
  basicCols :: mat a -> Int
  -- | Read element from matrix at @(row,column)@. No range checking is
  -- expected here; the checked accessor '(@!)' validates indices before
  -- delegating to this method.
  basicUnsafeIndex :: mat a -> (Int,Int) -> a
  -- | Convert immutable matrix to mutable. Immutable matrix may not
  -- be used after operation.
  basicUnsafeThaw :: PrimMonad m => mat a -> m (Mutable mat (PrimState m) a)
  -- | Convert mutable matrix to immutable. Mutable matrix may not be
  -- modified after operation.
  basicUnsafeFreeze :: PrimMonad m => Mutable mat (PrimState m) a -> m (mat a)
----------------------------------------------------------------
-- Accessors
----------------------------------------------------------------
-- | Number of rows in the matrix; a thin wrapper over the 'basicRows'
-- class method.
rows :: IsMatrix mat a => mat a -> Int
{-# INLINE rows #-}
rows mtx = basicRows mtx
-- | Number of columns in the matrix; a thin wrapper over the
-- 'basicCols' class method.
cols :: IsMatrix mat a => mat a -> Int
{-# INLINE cols #-}
cols mtx = basicCols mtx
-- | Shape of the matrix as a @(rows, columns)@ pair.
shape :: IsMatrix mat a => mat a -> (Int,Int)
{-# INLINE shape #-}
shape mtx = (nRows, nCols)
  where
    nRows = rows mtx
    nCols = cols mtx
-- | Read the element at @(row,column)@ without range checking. The
-- caller is responsible for keeping indices in bounds; use '(@!)' for a
-- checked lookup.
unsafeIndex :: IsMatrix mat a
            => mat a      -- ^ Matrix
            -> (Int,Int)  -- ^ (row,column)
            -> a
{-# INLINE unsafeIndex #-}
unsafeIndex mtx ix = basicUnsafeIndex mtx ix
-- | Indexing operator with range checking.
--
-- Calls 'error' with @"ROW"@ when the row index lies outside
-- @[0, rows m)@, and with @"COL"@ when the column index lies outside
-- @[0, cols m)@. Otherwise delegates to 'unsafeIndex'.
(@!) :: IsMatrix mat a
     => mat a      -- ^ Matrix
     -> (Int,Int)  -- ^ (row,column)
     -> a
{-# INLINE (@!) #-}
m @! a@(i,j)
  | i < 0 || i >= rows m = error "ROW"
  -- The lower bound is checked with @< 0@ for both axes: negative
  -- indices are out of range, index 0 is valid.
  | j < 0 || j >= cols m = error "COL"
  | otherwise            = unsafeIndex m a
-- | Convert mutable matrix to immutable.
--
-- The argument is first copied with 'M.clone', and the copy (not the
-- argument) is frozen via 'unsafeFreeze'. This presumably leaves the
-- argument free to be modified afterwards, assuming 'M.clone' yields an
-- independent copy.
freeze :: (PrimMonad m, IsMatrix mat a) => Mutable mat (PrimState m) a -> m (mat a)
{-# INLINE freeze #-}
freeze mmat = M.clone mmat >>= unsafeFreeze
-- | Convert immutable matrix to mutable.
--
-- The matrix is viewed as mutable via 'unsafeThaw', and that view is
-- only read by 'M.clone'. The clone is what gets returned, so writes to
-- the result do not go through the original's storage (assuming
-- 'M.clone' allocates a fresh matrix).
thaw :: (PrimMonad m, IsMatrix mat a) => mat a -> m (Mutable mat (PrimState m) a)
{-# INLINE thaw #-}
thaw imat = do
  view <- unsafeThaw imat
  M.clone view
-- | Convert mutable matrix to immutable without copying; delegates to
-- 'basicUnsafeFreeze'. The mutable matrix must not be modified after
-- this operation.
unsafeFreeze :: (PrimMonad m, IsMatrix mat a) => Mutable mat (PrimState m) a -> m (mat a)
{-# INLINE unsafeFreeze #-}
unsafeFreeze mmat = basicUnsafeFreeze mmat
-- | Convert immutable matrix to mutable without copying; delegates to
-- 'basicUnsafeThaw'. The immutable matrix must not be used after this
-- operation.
unsafeThaw :: (PrimMonad m, IsMatrix mat a) => mat a -> m (Mutable mat (PrimState m) a)
{-# INLINE unsafeThaw #-}
unsafeThaw imat = basicUnsafeThaw imat
-- | Generic function for rendering a matrix as text; mostly useful for
-- debugging.
--
-- The first line is @"<rows> >< <cols>"@. It is followed by one line
-- per matrix row, with elements rendered by the supplied function and
-- separated by tabs. Every line, including the last, is terminated by a
-- newline ('unlines').
showMatrixWith :: IsMatrix m a => (a -> String) -> m a -> String
showMatrixWith f mtx = unlines (header : map renderRow [0 .. nRows - 1])
  where
    nRows  = rows mtx
    nCols  = cols mtx
    header = show nRows ++ " >< " ++ show nCols
    renderRow i =
      intercalate "\t" (map (\j -> f (unsafeIndex mtx (i,j))) [0 .. nCols - 1])
----------------------------------------------------------------
-- Newtype wrappers
----------------------------------------------------------------
-- | Transposed matrix or vector. Being a newtype, this wrapper is used
-- to select different instances for multiplication.
--
-- Wrapping copies no data; the 'IsMatrix' instance for 'Transposed'
-- swaps the row and column indices on every access.
newtype Transposed mat a = Transposed { unTranspose :: mat a }
-- The mutable counterpart of a transposed matrix is the transposed
-- mutable wrapper over the underlying matrix's mutable type.
type instance Mutable (Transposed mat) = M.TransposedM (Mutable mat)
-- | A transposed matrix reads through to the wrapped matrix with row and
-- column swapped. Thawing and freezing just rewrap the result of the
-- underlying operation.
instance IsMatrix mat a => IsMatrix (Transposed mat) a where
  -- Dimensions trade places.
  basicRows = cols . unTranspose
  {-# INLINE basicRows #-}
  basicCols = rows . unTranspose
  {-# INLINE basicCols #-}
  -- Element (i,j) of the transpose is element (j,i) of the original.
  basicUnsafeIndex t (i,j) = unsafeIndex (unTranspose t) (j,i)
  {-# INLINE basicUnsafeIndex #-}
  basicUnsafeThaw t = do
    mm <- unsafeThaw (unTranspose t)
    return (M.TransposedM mm)
  {-# INLINE basicUnsafeThaw #-}
  basicUnsafeFreeze (M.TransposedM mm) = do
    im <- unsafeFreeze mm
    return (Transposed im)
  {-# INLINE basicUnsafeFreeze #-}
-- | Conjugate-transposed matrix or vector. Being a newtype, this
-- wrapper is used to select different instances for multiplication.
newtype Conjugated mat a = Conjugated { unConjugate :: mat a }
-- The mutable counterpart of a conjugate-transposed matrix is the
-- conjugated mutable wrapper over the underlying matrix's mutable type.
type instance Mutable (Conjugated mat) = M.ConjugatedM (Mutable mat)
-- | Conjugate transpose of a complex matrix. Element @(i,j)@ is the
-- complex conjugate of element @(j,i)@ of the wrapped matrix.
--
-- Thawing and freezing only rewrap the underlying result; no data is
-- conjugated eagerly. NOTE(review): this assumes 'M.ConjugatedM'
-- applies the same conjugate-on-access rule; confirm in
-- "Data.Matrix.Generic.Mutable".
instance (IsMatrix mat (Complex a), RealFloat a) => IsMatrix (Conjugated mat) (Complex a) where
  -- Dimensions trade places, as for a plain transpose.
  basicRows (Conjugated m) = cols m
  {-# INLINE basicRows #-}
  basicCols (Conjugated m) = rows m
  {-# INLINE basicCols #-}
  -- Every element is conjugated, including the diagonal. A
  -- triangle-dependent rule would return unconjugated values for half
  -- of a general matrix.
  basicUnsafeIndex (Conjugated m) (i,j) = conjugate $! unsafeIndex m (j,i)
  {-# INLINE basicUnsafeIndex #-}
  basicUnsafeThaw (Conjugated m) = M.ConjugatedM `liftM` unsafeThaw m
  {-# INLINE basicUnsafeThaw #-}
  basicUnsafeFreeze (M.ConjugatedM m) = Conjugated `liftM` unsafeFreeze m
  {-# INLINE basicUnsafeFreeze #-}