-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathSparseVector.h
192 lines (173 loc) · 5.54 KB
/
SparseVector.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
#ifndef CH3_SPARSEVECTOR_H
#define CH3_SPARSEVECTOR_H
#include <cmath>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <vector>
#include "../head/ST.h"
using std::vector;
using std::ostream;
using std::stringstream;
/**
 * The {@code SparseVector} class represents a <em>d</em>-dimensional mathematical vector.
 * Vectors are mutable: their values can be changed after they are created.
 * It includes methods for addition, dot product, scalar product, and Euclidean norm.
 * <p>
 * The implementation is a symbol table of indices and values for which the vector
 * coordinates are nonzero. This makes it efficient when most of the vector coordinates
 * are zero.
 * <p>
 * For additional documentation,
 * see <a href="https://algs4.cs.princeton.edu/35applications">Section 3.5</a> of
 * <i>Algorithms, 4th Edition</i> by Robert Sedgewick and Kevin Wayne.
 * See also {@link Vector} for an immutable (dense) vector data type.
 *
 * @author Robert Sedgewick
 * @author Kevin Wayne
 */
class SparseVector {
public:
    /**
     * Initializes a d-dimensional zero vector.
     * @param d the dimension of the vector
     */
    SparseVector(int d) : d(d) {}

    /**
     * Sets the ith coordinate of this vector to the specified value.
     * Assigning 0.0 removes the index from the symbol table, so only nonzero
     * coordinates are ever stored.
     *
     * @param i the index
     * @param value the new value
     * @throws std::runtime_error unless i is between 0 and d-1
     */
    void put(int i, double value) {
        if (i < 0 || i >= d) throw std::runtime_error("Illegal index");
        if (value == 0.0) st.delete_op(i);
        else st.put(i, value);
    }

    /**
     * Returns the ith coordinate of this vector.
     *
     * @param i the index
     * @return the value of the ith coordinate of this vector (0.0 if not stored)
     * @throws std::runtime_error unless i is between 0 and d-1
     */
    double get(int i) {
        if (i < 0 || i >= d) throw std::runtime_error("Illegal index");
        // NOTE(review): the original flagged this lookup with "TODO: maybe a bug"
        // without saying why — presumably a concern about ST::get's return
        // semantics; verify against ST.h.
        if (st.contains(i)) return st.get(i);
        return 0.0;
    }

    /**
     * Returns the number of nonzero entries in this vector.
     *
     * @return the number of nonzero entries in this vector
     */
    int nnz() {
        return st.size();
    }

    /**
     * Returns the dimension of this vector.
     *
     * @return the dimension of this vector
     */
    int dimension() const {
        return d;
    }

    /**
     * Returns the inner product of this vector with the specified vector.
     *
     * @param that the other vector
     * @return the dot product between this vector and that vector
     * @throws std::runtime_error if the lengths of the two vectors are not equal
     */
    double dot(SparseVector &that) {
        if (d != that.d) throw std::runtime_error("Vector lengths disagree");
        // Only indices stored in BOTH vectors contribute, so walk the vector
        // with fewer nonzeros and probe the other one.
        SparseVector &sparser = (st.size() <= that.st.size()) ? *this : that;
        SparseVector &denser = (&sparser == this) ? that : *this;
        double sum = 0.0;
        for (const auto &p : sparser.st) {
            int i = p.first;
            if (denser.st.contains(i)) sum += sparser.get(i) * denser.get(i);
        }
        return sum;
    }

    /**
     * Returns the inner product of this vector with the specified array.
     *
     * @param that the array; its size must equal the dimension of this vector
     * @return the dot product between this vector and that array
     * @throws std::runtime_error if the dimensions of the vector and the array are not equal
     */
    double dot(const std::vector<double> &that) {
        // Without this check, that[i] is an out-of-bounds read (UB) whenever
        // that.size() < d and a stored index lies beyond the array's end.
        if (d < 0 || that.size() != static_cast<std::vector<double>::size_type>(d))
            throw std::runtime_error("Vector lengths disagree");
        double sum = 0.0;
        for (const auto &p : st) {
            int i = p.first;
            sum += that[i] * get(i);
        }
        return sum;
    }

    /**
     * Returns the magnitude of this vector.
     * This is also known as the L2 norm or the Euclidean norm.
     *
     * @return the magnitude of this vector
     */
    double magnitude() {
        return std::sqrt(this->dot(*this));
    }

    /**
     * Returns the scalar-vector product of this vector with the specified scalar.
     *
     * @param alpha the scalar
     * @return the scalar-vector product of this vector with the specified scalar
     */
    SparseVector scale(double alpha) {
        SparseVector c(d);
        for (const auto &p : st) {
            int i = p.first;
            c.put(i, alpha * get(i)); // put() drops the entry if the product is 0.0
        }
        return c;
    }

    /**
     * Returns the sum of this vector and the specified vector.
     *
     * @param that the vector to add to this vector
     * @return the sum of this vector and that vector
     * @throws std::runtime_error if the dimensions of the two vectors are not equal
     */
    SparseVector plus(SparseVector &that) {
        if (d != that.d) throw std::runtime_error("Vector lengths disagree");
        SparseVector c(d);
        for (const auto &p : st) {      // c = this
            int i = p.first;
            c.put(i, get(i));
        }
        for (const auto &p : that.st) { // c = c + that; cancelling entries are removed by put()
            int i = p.first;
            c.put(i, that.get(i) + c.get(i));
        }
        return c;
    }

    // TODO: change to const SparseVector---due to map issue
    friend std::ostream &operator<<(std::ostream &stream, SparseVector &sparse);

private:
    int d;              // dimension
    ST<int, double> st; // the vector, represented by index-value pairs
};
/**
 * Writes the stored (nonzero) coordinates of a sparse vector as "(i, value)" pairs.
 *
 * Declared {@code inline} because it is defined in a header: without it, including
 * this header from more than one translation unit yields a multiple-definition
 * link error (ODR violation).
 *
 * The text is first assembled in a local stringstream with default formatting, then
 * written to {@code stream} in a single insertion.
 *
 * @param stream the output stream
 * @param sparse the vector to print
 * @return {@code stream}
 */
inline ostream &operator<<(ostream &stream, SparseVector &sparse) {
    stringstream ss;
    for (const auto &p : sparse.st) {
        auto i = p.first;
        ss << "(" << i << ", " << sparse.st.get(i) << ")";
    }
    stream << ss.str();
    return stream;
}
#endif //CH3_SPARSEVECTOR_H