/
decomp_lu.py
161 lines (129 loc) · 4.44 KB
/
decomp_lu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""LU decomposition functions."""
from warnings import warn
from numpy import asarray, asarray_chkfinite
# Local imports
from misc import _datacopied
from lapack import get_lapack_funcs
from flinalg import get_flinalg_funcs
__all__ = ['lu', 'lu_solve', 'lu_factor']
def lu_factor(a, overwrite_a=False):
    """Compute the pivoted LU factorization of a square matrix.

    The factorization has the form::

        A = P L U

    with P a permutation matrix, L lower triangular with a unit
    diagonal, and U upper triangular.

    Parameters
    ----------
    a : array, shape (M, M)
        Square matrix to factorize.
    overwrite_a : boolean
        Allow the data in `a` to be overwritten (may increase performance).

    Returns
    -------
    lu : array, shape (M, M)
        Packed factors: U occupies the upper triangle and L the strictly
        lower triangle.  The unit diagonal of L is not stored.
    piv : array, shape (M,)
        Pivot indices encoding the permutation matrix P:
        row i of the matrix was interchanged with row piv[i].

    Raises
    ------
    ValueError
        If `a` is not a square 2-D array, or if the underlying LAPACK
        routine reports an illegal argument.

    See also
    --------
    lu_solve : solve an equation system using the LU factorization of a matrix

    Notes
    -----
    This is a wrapper to the ``*GETRF`` routines from LAPACK.  An exactly
    zero pivot (singular matrix) is reported with a ``RuntimeWarning``
    rather than an exception.

    """
    matrix = asarray(a)
    is_square = matrix.ndim == 2 and matrix.shape[0] == matrix.shape[1]
    if not is_square:
        raise ValueError('expected square matrix')

    # If asarray() already produced a private copy, LAPACK may work in place
    # on it without touching the caller's data.
    may_overwrite = overwrite_a or (_datacopied(matrix, a))

    (factorize,) = get_lapack_funcs(('getrf',), (matrix,))
    factors, pivots, status = factorize(matrix, overwrite_a=may_overwrite)

    if status < 0:
        raise ValueError(
            'illegal value in %d-th argument of internal getrf (lu_factor)'
            % -status)
    if status > 0:
        warn("Diagonal number %d is exactly zero. Singular matrix." % status,
             RuntimeWarning)
    return factors, pivots
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False):
    """Solve an equation system, a x = b, given the LU factorization of a.

    Parameters
    ----------
    lu_and_piv : (lu, piv)
        Factorization of the coefficient matrix a, as given by lu_factor.
        Any two-element sequence is accepted.
    b : array
        Right-hand side
    trans : {0, 1, 2}
        Type of system to solve:

        =====  =========
        trans  system
        =====  =========
        0      a x   = b
        1      a^T x = b
        2      a^H x = b
        =====  =========
    overwrite_b : boolean
        Allow the data in `b` to be overwritten (may increase performance).

    Returns
    -------
    x : array
        Solution to the system

    Raises
    ------
    ValueError
        If `b` contains non-finite values, if the leading dimensions of
        `lu` and `b` differ, or if the underlying LAPACK routine reports an
        illegal argument.

    See also
    --------
    lu_factor : LU factorize a matrix

    """
    # Tuple parameters in the signature were removed in Python 3 (PEP 3113);
    # unpacking here keeps positional calls working on both major versions.
    lu, piv = lu_and_piv

    b1 = asarray_chkfinite(b)
    # A private copy made by asarray_chkfinite() can safely be overwritten.
    overwrite_b = overwrite_b or _datacopied(b1, b)
    if lu.shape[0] != b1.shape[0]:
        raise ValueError("incompatible dimensions.")

    getrs, = get_lapack_funcs(('getrs',), (lu, b1))
    x, info = getrs(lu, piv, b1, trans=trans, overwrite_b=overwrite_b)
    if info == 0:
        return x
    # The routine invoked is getrs, so name it in the message.
    raise ValueError('illegal value in %d-th argument of '
                     'internal getrs (lu_solve)' % -info)
def lu(a, permute_l=False, overwrite_a=False):
    """Compute the pivoted LU decomposition of a matrix.

    The decomposition has the form::

        A = P L U

    with P a permutation matrix, L lower triangular with a unit
    diagonal, and U upper triangular.

    Parameters
    ----------
    a : array, shape (M, N)
        Array to decompose
    permute_l : boolean
        Perform the multiplication P*L  (Default: do not permute)
    overwrite_a : boolean
        Allow the data in `a` to be overwritten (may improve performance)

    Returns
    -------
    (If permute_l == False)
    p : array, shape (M, M)
        Permutation matrix
    l : array, shape (M, K)
        Lower triangular or trapezoidal matrix with unit diagonal.
        K = min(M, N)
    u : array, shape (K, N)
        Upper triangular or trapezoidal matrix

    (If permute_l == True)
    pl : array, shape (M, K)
        Permuted L matrix.
        K = min(M, N)
    u : array, shape (K, N)
        Upper triangular or trapezoidal matrix

    Raises
    ------
    ValueError
        If `a` contains non-finite values, is not 2-D, or if the internal
        routine reports an illegal argument.

    Notes
    -----
    This is a LU factorization routine written for Scipy.

    """
    matrix = asarray_chkfinite(a)
    if matrix.ndim != 2:
        raise ValueError('expected matrix')

    # A private copy made by asarray_chkfinite() can safely be overwritten.
    may_overwrite = overwrite_a or (_datacopied(matrix, a))

    (decompose,) = get_flinalg_funcs(('lu',), (matrix,))
    perm, lower, upper, status = decompose(
        matrix, permute_l=permute_l, overwrite_a=may_overwrite)

    if status < 0:
        raise ValueError(
            'illegal value in %d-th argument of internal lu.getrf' % -status)

    # With permute_l set, the first factor already carries the permutation.
    return (lower, upper) if permute_l else (perm, lower, upper)