/
NormalEquations.scala
156 lines (126 loc) · 4.2 KB
/
NormalEquations.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
package edu.berkeley.cs.amplab.mlmatrix
import java.util.concurrent.ThreadLocalRandom
import breeze.linalg._
import edu.berkeley.cs.amplab.mlmatrix.util.Utils
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import org.apache.spark.SparkContext
/**
 * Least-squares solver based on the normal equations.
 *
 * Every partition contributes its local `A^T * A` and `A^T * b` products. The per-partition
 * products are summed with `Utils.treeReduce`, and the regularized system
 * `(A^T A + lambda * I) x = A^T b` is then solved on the reduced matrices with Breeze's
 * `\` operator.
 */
class NormalEquations extends RowPartitionedSolver with Logging with Serializable {

  /**
   * Solves several L2-regularized least-squares problems that share the same `A`.
   *
   * For index `i` the right-hand side is `b - residuals(i)` (computed partition by partition)
   * and the regularization weight is `lambdas(i)`.
   *
   * @param A         row-partitioned design matrix
   * @param b         row-partitioned right-hand side; its RDD is zipped with `A`'s RDD
   * @param residuals RDD zipped with the `(A, b)` partitions; each element holds one matrix per
   *                  problem, which is subtracted from the matching block of `b`
   * @param lambdas   regularization weights, paired positionally with the reduced `A^T b`
   *                  matrices. NOTE: the pairing uses `zip`, so if the lengths differ the extra
   *                  entries of the longer side are silently ignored.
   * @return one solution matrix per `(lambda, right-hand side)` pair
   */
  def solveManyLeastSquaresWithL2(
      A: RowPartitionedMatrix,
      b: RowPartitionedMatrix,
      residuals: RDD[Array[DenseMatrix[Double]]],
      lambdas: Array[Double]): Seq[DenseMatrix[Double]] = {
    val Abs = A.rdd.zip(b.rdd).map { x =>
      (x._1.mat, x._2.mat)
    }

    val ATA_ATb = Abs.zip(residuals).map { case ((aPart, bPart), res) =>
      val AtA = aPart.t * aPart
      val AtBs = new Array[DenseMatrix[Double]](res.length)
      // Scratch buffer holding (bPart - res(i)); it is reset to zero after every use so that a
      // single allocation serves all right-hand sides of this partition.
      val tmp = new DenseMatrix[Double](bPart.rows, bPart.cols)
      var i = 0
      while (i < res.length) {
        tmp :+= bPart
        tmp :-= res(i)
        AtBs(i) = aPart.t * tmp
        java.util.Arrays.fill(tmp.data, 0.0)
        i = i + 1
      }
      (AtA, AtBs)
    }

    val depth = treeReduceDepth(A.rdd.context, ATA_ATb.partitions.length)
    val reduced = Utils.treeReduce(ATA_ATb, reduceNormalMany, depth=depth)
    val ATA = reduced._1

    // Local solve
    val xs = lambdas.zip(reduced._2).map { case (lambda, atb) =>
      val gamma = DenseMatrix.eye[Double](ATA.rows)
      gamma :*= lambda
      (ATA + gamma) \ atb
    }
    xs
  }

  /**
   * Combines two partial results of [[solveManyLeastSquaresWithL2]].
   *
   * The sums are accumulated into `a` in place (both its `A^T A` matrix and every entry of its
   * `A^T b` array) and `a` itself is returned. The loop runs over `a._2.length` entries, so
   * `b._2` must hold at least that many matrices.
   */
  private def reduceNormalMany(
      a: (DenseMatrix[Double], Array[DenseMatrix[Double]]),
      b: (DenseMatrix[Double], Array[DenseMatrix[Double]])):
      (DenseMatrix[Double], Array[DenseMatrix[Double]]) = {
    a._1 :+= b._1
    var i = 0
    while (i < a._2.length) {
      a._2(i) :+= b._2(i)
      i = i + 1
    }
    a
  }

  /**
   * Solves one least-squares problem for several L2 regularization weights.
   *
   * `A^T A` and `A^T b` are computed and reduced once; the reduced matrices are then reused
   * for every entry of `lambdas`.
   *
   * @param A       row-partitioned design matrix
   * @param b       row-partitioned right-hand side; its RDD is zipped with `A`'s RDD
   * @param lambdas regularization weights
   * @return one solution matrix per entry of `lambdas`, in the same order
   */
  def solveLeastSquaresWithManyL2(
      A: RowPartitionedMatrix,
      b: RowPartitionedMatrix,
      lambdas: Array[Double]) : Seq[DenseMatrix[Double]] = {
    val Ab = A.rdd.zip(b.rdd).map(x => (x._1.mat, x._2.mat))
    val ATA_ATb = Ab.map { case (aPart, bPart) =>
      (aPart.t * aPart, aPart.t * bPart)
    }

    val depth = treeReduceDepth(A.rdd.context, ATA_ATb.partitions.length)
    val reduced = Utils.treeReduce(ATA_ATb, reduceNormal, depth=depth)
    val ATA = reduced._1
    val ATb = reduced._2

    val xs = lambdas.map { lambda =>
      val gamma = DenseMatrix.eye[Double](ATA.rows)
      gamma :*= lambda
      (ATA + gamma) \ ATb
    }
    xs
  }

  /**
   * Combines two partial results of [[solveLeastSquaresWithManyL2]].
   *
   * Both matrices of `b` are added into the matching matrices of `a` in place, and `a` itself
   * is returned.
   */
  private def reduceNormal(
      a: (DenseMatrix[Double], DenseMatrix[Double]),
      b: (DenseMatrix[Double], DenseMatrix[Double])): (DenseMatrix[Double], DenseMatrix[Double]) = {
    a._1 :+= b._1
    a._2 :+= b._2
    a
  }

  /**
   * Depth handed to `Utils.treeReduce`:
   * `ceil(log(max(numPartitions, 2)) / log(treeBranchingFactor))`.
   *
   * The branching factor is read from the `spark.mlmatrix.treeBranchingFactor` setting
   * (default 2).
   *
   * @param sc            context whose configuration supplies the branching factor
   * @param numPartitions number of partitions of the RDD that is about to be reduced
   */
  private def treeReduceDepth(sc: SparkContext, numPartitions: Int): Int = {
    // treeBranchingFactor has to be greater than or equal to 2.
    val treeBranchingFactor = math.max(
      sc.getConf.getInt("spark.mlmatrix.treeBranchingFactor", 2),
      2)
    math.ceil(
      math.log(math.max(numPartitions, 2.0)) /
        math.log(treeBranchingFactor)).toInt
  }
}
/**
 * Command-line driver that times a normal-equations solve on randomly generated data.
 */
object NormalEquations {

  /**
   * Builds a random `numRows x numCols` matrix `A` in `numParts` partitions and a random
   * right-hand side with `numClasses` columns, runs `solveLeastSquares` on them and prints
   * the elapsed wall-clock time in milliseconds.
   *
   * @param args `<master> <numRows> <numCols> <numParts> <numClasses>`
   */
  def main(args: Array[String]): Unit = {
    if (args.length < 5) {
      println("Usage: NormalEquations <master> <numRows> <numCols> <numParts> <numClasses>")
      // A usage error is a failure, so report it with a non-zero exit status.
      System.exit(1)
    }

    val sparkMaster = args(0)
    val numRows = args(1).toInt
    val numCols = args(2).toInt
    val numParts = args(3).toInt
    val numClasses = args(4).toInt

    val conf = new SparkConf()
      .setMaster(sparkMaster)
      .setAppName("NormalEquations")
      .setJars(SparkContext.jarOfClass(this.getClass).toSeq)
    val sc = new SparkContext(conf)

    // Stop the context even when matrix generation or the solve throws.
    val elapsedMs = try {
      val A = RowPartitionedMatrix.createRandom(sc, numRows, numCols, numParts, cache=true)
      val b = A.mapPartitions(
        part => DenseMatrix.rand(part.rows, numClasses)).cache()

      val begin = System.nanoTime()
      new NormalEquations().solveLeastSquares(A, b)
      val end = System.nanoTime()
      (end - begin) / 1e6
    } finally {
      sc.stop()
    }

    println("Normal equations took " + elapsedMs + " ms")
  }
}