-
Notifications
You must be signed in to change notification settings - Fork 60
/
Copy pathdiag.go
168 lines (159 loc) · 3.79 KB
/
diag.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
// Copyright (c) Harri Rautila, 2012,2013
// This file is part of github.com/hrautila/matops package. It is free software,
// distributed under the terms of GNU Lesser General Public License Version 3, or
// any later version. See the COPYING file included in this archive.
package matops
import (
"errors"
"github.com/henrylee2cn/algorithm/matrix"
)
/*
 * MultDiag scales C in place by a diagonal matrix:
 *
 *   C = C*diag(D)    when flags & (LEFT|RIGHT) == RIGHT
 *   C = diag(D)*C    when flags & (LEFT|RIGHT) == LEFT
 *
 * Arguments
 *  C      M-by-N matrix for RIGHT, N-by-M matrix for LEFT. Overwritten with
 *         the result; nothing is returned.
 *
 *  D      the diagonal, given as an N element column vector, an N element
 *         row vector, or an N-by-N matrix of which only the main diagonal
 *         is read.
 *
 *  flags  Indicator bits; exactly one of LEFT or RIGHT should be set. If
 *         neither or both are set, C is left unchanged.
 */
func MultDiag(C, D *matrix.FloatMatrix, flags Flags) {
	var col, row, diagView matrix.FloatMatrix
	side := flags & (LEFT | RIGHT)

	if D.Cols() == 1 {
		// Diagonal stored as a column vector.
		switch side {
		case LEFT:
			// Row i must be multiplied by D[i]; do it column by column with
			// an element-wise product against the whole D vector.
			for j := 0; j < C.Cols(); j++ {
				C.SubMatrix(&col, 0, j, C.Rows(), 1)
				col.Mul(D)
			}
		case RIGHT:
			// Column j is multiplied by D[j].
			for j := 0; j < C.Cols(); j++ {
				C.SubMatrix(&col, 0, j, C.Rows(), 1)
				col.Scale(D.GetAt(j, 0))
			}
		}
		return
	}

	// Diagonal stored as a row vector, or as the main diagonal of a square
	// matrix. For the matrix case take a 1-by-N view of D with step
	// LeadingIndex()+1 (assumes the trailing SubMatrix argument is the element
	// step — TODO confirm), which lands on D[k,k] for each k.
	diag := D
	if D.Rows() != 1 {
		D.SubMatrix(&diagView, 0, 0, 1, D.Cols(), D.LeadingIndex()+1)
		diag = &diagView
	}

	switch side {
	case LEFT:
		// Row i is multiplied by diag[i].
		for i := 0; i < C.Rows(); i++ {
			C.SubMatrix(&row, i, 0, 1, C.Cols())
			row.Scale(diag.GetAt(0, i))
		}
	case RIGHT:
		// Column j is multiplied by diag[j].
		for j := 0; j < C.Cols(); j++ {
			C.SubMatrix(&col, 0, j, C.Rows(), 1)
			col.Scale(diag.GetAt(0, j))
		}
	}
}
/*
 * SolveDiag solves a diagonal system in place:
 *
 *   B = B*diag(D)^-1    when flags & (LEFT|RIGHT) == RIGHT
 *   B = diag(D)^-1*B    when flags & (LEFT|RIGHT) == LEFT
 *
 * Arguments:
 *  B      M-by-N matrix if flags&RIGHT == true or N-by-M matrix if
 *         flags&LEFT == true. Overwritten with the solution X; nothing is
 *         returned.
 *
 *  D      N element column or row vector or N-by-N matrix; only the
 *         diagonal entries are read.
 *
 *  flags  Indicator bits; exactly one of LEFT or RIGHT should be set. If
 *         neither or both are set, B is left unchanged.
 *
 * Zero diagonal entries are not checked for; dividing by them presumably
 * yields Inf/NaN entries in B.
 */
func SolveDiag(B, D *matrix.FloatMatrix, flags Flags) {
	var c, d0 matrix.FloatMatrix
	if D.Cols() == 1 {
		// diagonal is a column vector
		switch flags & (LEFT | RIGHT) {
		case LEFT:
			// scale rows: divide each column of B element-wise by the D-vector,
			// so row i ends up divided by D[i]
			for k := 0; k < B.Cols(); k++ {
				B.SubMatrix(&c, 0, k, B.Rows(), 1)
				c.Div(D)
			}
		case RIGHT:
			// scale columns: column k is divided by D[k]
			for k := 0; k < B.Cols(); k++ {
				B.SubMatrix(&c, 0, k, B.Rows(), 1)
				c.Scale(1.0 / D.GetAt(k, 0))
			}
		}
	} else {
		// diagonal is a row vector, or the main diagonal of a square matrix
		var d *matrix.FloatMatrix
		if D.Rows() == 1 {
			d = D
		} else {
			// 1-by-N view of D with step LeadingIndex()+1, which steps along
			// the main diagonal (assumes the trailing SubMatrix argument is
			// the element step — TODO confirm)
			D.SubMatrix(&d0, 0, 0, 1, D.Cols(), D.LeadingIndex()+1)
			d = &d0
		}
		switch flags & (LEFT | RIGHT) {
		case LEFT:
			// scale rows: row k is divided by d[k]
			for k := 0; k < B.Rows(); k++ {
				B.SubMatrix(&c, k, 0, 1, B.Cols())
				c.Scale(1.0 / d.GetAt(0, k))
			}
		case RIGHT:
			// scale columns: column k is divided by d[k]
			for k := 0; k < B.Cols(); k++ {
				B.SubMatrix(&c, 0, k, B.Rows(), 1)
				c.Scale(1.0 / d.GetAt(0, k))
			}
		}
	}
}
/*
 * MVUpdateDiag adds the diagonal of a scaled outer product to diag(D):
 *
 *   diag(D) = diag(D) + alpha * diag(x * y.T)
 *
 * i.e. D[k] += alpha * x[k] * y[k] for k = 0..N-1. Only the diagonal of D
 * is written; off-diagonal entries are never touched.
 *
 * Arguments:
 *  D      N element column or row vector, or N-by-N matrix. Updated in place.
 *
 *  x, y   vectors with at least N elements; only the first N are used.
 *
 *  alpha  scalar
 *
 * Returns an error if x or y is not a vector, if D is neither square nor a
 * vector, or if x or y has fewer than N elements. On error D is unchanged.
 */
func MVUpdateDiag(D, x, y *matrix.FloatMatrix, alpha float64) error {
	var d *matrix.FloatMatrix
	var dvec matrix.FloatMatrix
	if !isVector(x) || !isVector(y) {
		return errors.New("x, y not vectors")
	}
	switch {
	case D.Rows() > 0 && D.Cols() == D.Rows():
		// NOTE(review): updating D through dvec relies on Diag returning a
		// view that shares storage with D — TODO confirm against the matrix
		// package.
		D.Diag(&dvec)
		d = &dvec
	case isVector(D):
		d = D
	default:
		return errors.New("D not a diagonal")
	}
	N := d.NumElements()
	// Validate lengths before writing anything: otherwise a short x or y
	// would presumably fail inside the loop with D already partially updated.
	if x.NumElements() < N || y.NumElements() < N {
		return errors.New("x, y shorter than diagonal of D")
	}
	for k := 0; k < N; k++ {
		val := d.GetIndex(k)
		val += x.GetIndex(k) * y.GetIndex(k) * alpha
		d.SetIndex(k, val)
	}
	return nil
}
// Local Variables:
// tab-width: 4
// indent-tabs-mode: nil
// End: