-
Notifications
You must be signed in to change notification settings - Fork 7
/
arraymath.jl
74 lines (61 loc) · 2.69 KB
/
arraymath.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import Base: conj, copy, real, imag
import LinearAlgebra: transpose, transpose!, adjoint!, adjoint
# IMatrix
# conj, transpose, adjoint and copy of an identity matrix are again the identity,
# with the same size `N` and element type `T`.
for func in (:conj, :transpose, :adjoint, :copy)
    @eval ($func)(M::IMatrix{N,T}) where {N,T} = IMatrix{N,T}()
end
# `real` also yields the identity, but with the real counterpart of `T` as element
# type (e.g. ComplexF64 -> Float64). This matches `real` on PermMatrix (which applies
# `real` to its values) and `real` on Base arrays; for real `T`, `real(T) === T`.
real(M::IMatrix{N,T}) where {N,T} = IMatrix{N,real(T)}()
# Single-argument in-place adjoint!/transpose! on an identity matrix: the argument
# is returned unchanged (no data is touched).
for fname in [:adjoint!, :transpose!]
    @eval $fname(x::IMatrix) = x
end
# The imaginary part of an N×N identity is the N×N zero matrix, stored as a Diagonal.
# Use the real counterpart of `T` so the result has a real element type, as `imag`
# does on Base arrays and on PermMatrix values; for real `T` this is just `T`.
imag(M::IMatrix{N,T}) where {N,T} = Diagonal(zeros(real(T), N))
# PermMatrix
# Element-wise conj/real/imag act only on the stored values; the result reuses the
# same permutation vector `M.perm` (it is not copied).
for f in [:conj, :real, :imag]
    @eval function $f(M::PermMatrix)
        return PermMatrix(M.perm, $f(M.vals))
    end
end
# Copy a PermMatrix: both the permutation and the value vector are copied, so the
# result does not alias `M`'s arrays.
function copy(M::PermMatrix)
    perm = copy(M.perm)
    vals = copy(M.vals)
    return PermMatrix(perm, vals)
end
# Transpose of a PermMatrix. `fast_invperm` presumably returns the inverse
# permutation (TODO confirm). The values are gathered through that permutation so
# that value i follows its entry to the transposed position. This assumes a
# PermMatrix stores value i at (i, perm[i]) — NOTE(review): confirm.
function transpose(M::PermMatrix)
    inv_perm = fast_invperm(M.perm)
    transposed_vals = M.vals[inv_perm]
    return PermMatrix(inv_perm, transposed_vals)
end
# Adjoint of a real PermMatrix is its transpose; for complex values the transposed
# values are additionally conjugated.
function adjoint(S::PermMatrix{<:Real})
    return transpose(S)
end
function adjoint(S::PermMatrix{<:Complex})
    return conj(transpose(S))
end
# scalar
import Base: *, /, ==, +, -, ≈
# Scaling an N×N identity by a scalar gives a Diagonal filled with that scalar,
# converted to the promoted type of `T` and the scalar's type.
*(A::IMatrix{N,T}, B::Number) where {N,T} =
    Diagonal(fill(promote_type(T, eltype(B))(B), N))
*(B::Number, A::IMatrix{N,T}) where {N,T} =
    Diagonal(fill(promote_type(T, eltype(B))(B), N))
# Dividing an N×N identity by a scalar gives a Diagonal filled with `one(T) / B`.
# The element type is whatever that division produces (as in Base, Int / Int ->
# Float64). The previous form, `promote_type(T, typeof(B))(1 / B)`, threw
# `InexactError` for integer identities such as `IMatrix{2,Int}() / 2`.
/(A::IMatrix{N,T}, B::Number) where {N,T} = Diagonal(fill(one(T) / B, N))
# Scalar scaling of a PermMatrix scales only its stored values; the result reuses
# the same permutation vector `A.perm`.
function *(A::PermMatrix, B::Number)
    return PermMatrix(A.perm, A.vals * B)
end
# Left scalar multiplication is forwarded to the right-multiplication method.
function *(B::Number, A::PermMatrix)
    return A * B
end
function /(A::PermMatrix, B::Number)
    return PermMatrix(A.perm, A.vals / B)
end
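# NOTE: disabled stubs kept for reference only. They use `dv`/`ev` fields, whereas the
# PermMatrix fields used in this file are `perm` and `vals`. PermMatrix ± PermMatrix
# is instead handled by the SparseMatrixCSC fallback methods below.
#+(A::PermMatrix, B::PermMatrix) = PermMatrix(A.dv+B.dv, A.ev+B.ev)
#-(A::PermMatrix, B::PermMatrix) = PermMatrix(A.dv-B.dv, A.ev-B.ev)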
# Structured square-matrix types handled by the operator methods below.
const IDP = Union{Diagonal,PermMatrix,IMatrix}

# For +, -, == and ≈: any IDP paired with a SparseMatrixCSC, and any pairing with a
# PermMatrix on either side, converts the structured operand(s) to SparseMatrixCSC
# and defers to the sparse implementation.
for op in (:+, :-, :(==), :≈)
    @eval begin
        $op(x::IDP, y::SparseMatrixCSC) = $op(SparseMatrixCSC(x), y)
        $op(x::SparseMatrixCSC, y::IDP) = $op(x, SparseMatrixCSC(y))
        $op(x::PermMatrix, y::IDP) = $op(SparseMatrixCSC(x), SparseMatrixCSC(y))
        $op(x::IDP, y::PermMatrix) = $op(SparseMatrixCSC(x), SparseMatrixCSC(y))
        # resolves the ambiguity between the two methods above
        $op(x::PermMatrix, y::PermMatrix) = $op(SparseMatrixCSC(x), SparseMatrixCSC(y))
    end
end

# Diagonal ± IMatrix (either order) stays a Diagonal: combine the diagonal vectors.
for op in (:+, :-)
    @eval begin
        $op(x::Diagonal, y::IMatrix) = Diagonal($op(x.diag, diag(y)))
        $op(x::IMatrix, y::Diagonal) = Diagonal($op(diag(x), y.diag))
    end
end

# == and ≈ between a Diagonal and an IMatrix compare the diagonal vectors; between
# two IMatrix values they compare only the size parameters Na and Nb.
for op in (:(==), :≈)
    @eval begin
        $op(x::IMatrix, y::Diagonal) = $op(diag(x), y.diag)
        $op(x::Diagonal, y::IMatrix) = $op(x.diag, diag(y))
        $op(::IMatrix{Na}, ::IMatrix{Nb}) where {Na,Nb} = $op(Na, Nb)
    end
end
# IMatrix + IMatrix of equal size Na is an Na×Na Diagonal filled with 2.
# NOTE: the element type is promoted at least to Int so that 2 is representable
# even for Bool identities.
# Throws DimensionMismatch (with the offending sizes) when Na != Nb.
function +(d1::IMatrix{Na,Ta}, d2::IMatrix{Nb,Tb}) where {Na,Nb,Ta,Tb}
    Na == Nb || throw(DimensionMismatch(
        "dimensions must match: a has size ($Na, $Na), b has size ($Nb, $Nb)"))
    return Diagonal(fill(promote_type(Ta, Tb, Int)(2), Na))
end
# IMatrix - IMatrix of equal size Na is the Na×Na sparse zero matrix.
# Throws DimensionMismatch (with the offending sizes) when Na != Nb.
function -(d1::IMatrix{Na,Ta}, d2::IMatrix{Nb,Tb}) where {Na,Ta,Nb,Tb}
    Na == Nb || throw(DimensionMismatch(
        "dimensions must match: a has size ($Na, $Na), b has size ($Nb, $Nb)"))
    return spzeros(promote_type(Ta, Tb), Na, Na)
end