-
Notifications
You must be signed in to change notification settings - Fork 28.1k
/
GaussianMixture.scala
721 lines (626 loc) · 26.1 KB
/
GaussianMixture.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.clustering
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.impl.Utils.{unpackUpperTriangular, EPSILON}
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.stat.distribution.MultivariateGaussian
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.DatasetUtils._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatrix,
Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.ArrayImplicits._
/**
 * Parameters shared by the [[GaussianMixture]] estimator and the fitted
 * [[GaussianMixtureModel]].
 */
private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter with HasFeaturesCol
  with HasSeed with HasPredictionCol with HasWeightCol with HasProbabilityCol with HasTol
  with HasAggregationDepth {

  /**
   * Number of independent Gaussians in the mixture model. Must be greater than 1. Default: 2.
   *
   * @group param
   */
  @Since("2.0.0")
  final val k = new IntParam(
    this,
    "k",
    "Number of independent Gaussians in the mixture model. Must be > 1.",
    ParamValidators.gt(1))

  /** @group getParam */
  @Since("2.0.0")
  def getK: Int = $(k)

  // Defaults must be registered after `k` is initialized, hence their position here.
  setDefault(k -> 2, maxIter -> 100, tol -> 0.01)

  /**
   * Checks that the features column is vector-compatible and derives the output schema by
   * appending the prediction column (integer) followed by the probability column (vector).
   *
   * @param schema input schema
   * @return the input schema extended with the prediction and probability columns
   */
  protected def validateAndTransformSchema(schema: StructType): StructType = {
    SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol)
    SchemaUtils.appendColumn(
      SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType),
      $(probabilityCol),
      new VectorUDT)
  }
}
/**
 * A fitted multivariate Gaussian Mixture Model (GMM) made of k Gaussian components, where a
 * point is drawn from component i with probability weights(i).
 *
 * @param weights Mixing weight of each Gaussian component. Together they form a multinomial
 *                distribution over the k components, so weights(i) is the weight of
 *                component i and all weights sum to 1.
 * @param gaussians Array of `MultivariateGaussian`; gaussians(i) is the multivariate
 *                  Gaussian (Normal) distribution of component i.
 */
@Since("2.0.0")
class GaussianMixtureModel private[ml] (
    @Since("2.0.0") override val uid: String,
    @Since("2.0.0") val weights: Array[Double],
    @Since("2.0.0") val gaussians: Array[MultivariateGaussian])
  extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable
  with HasTrainingSummary[GaussianMixtureSummary] {

  /** Number of features, read from the size of the first component's mean vector. */
  @Since("3.0.0")
  lazy val numFeatures: Int = gaussians.head.mean.size

  /** @group setParam */
  @Since("2.1.0")
  def setFeaturesCol(value: String): this.type = set(featuresCol, value)

  /** @group setParam */
  @Since("2.1.0")
  def setPredictionCol(value: String): this.type = set(predictionCol, value)

  /** @group setParam */
  @Since("2.1.0")
  def setProbabilityCol(value: String): this.type = set(probabilityCol, value)

  /**
   * Creates a copy sharing the same weights and gaussians, carrying over param values
   * (plus `extra`), the training summary and the parent estimator.
   */
  @Since("2.0.0")
  override def copy(extra: ParamMap): GaussianMixtureModel = {
    copyValues(new GaussianMixtureModel(uid, weights, gaussians), extra)
      .setSummary(trainingSummary)
      .setParent(this.parent)
  }

  /**
   * Appends the probability column and/or the prediction column to `dataset`, depending on
   * which of the two column names are non-empty. Logs a warning when neither is set.
   */
  @Since("2.0.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    val outputSchema = transformSchema(dataset.schema, logging = true)
    val vectorCol = columnToVector(dataset, $(featuresCol))
    val emitProbability = $(probabilityCol).nonEmpty
    val emitPrediction = $(predictionCol).nonEmpty

    val withProbability: Dataset[_] =
      if (emitProbability) {
        val probUDF = udf((vector: Vector) => predictProbability(vector))
        dataset.withColumn($(probabilityCol), probUDF(vectorCol),
          outputSchema($(probabilityCol)).metadata)
      } else {
        dataset
      }

    val withPrediction: Dataset[_] =
      if (emitPrediction) {
        // When the probability column was just produced, the prediction is simply its argmax,
        // which avoids evaluating the Gaussians a second time.
        val predictedCluster =
          if (emitProbability) {
            val predUDF = udf((vector: Vector) => vector.argmax)
            predUDF(col($(probabilityCol)))
          } else {
            val predUDF = udf((vector: Vector) => predict(vector))
            predUDF(vectorCol)
          }
        withProbability.withColumn($(predictionCol), predictedCluster,
          outputSchema($(predictionCol)).metadata)
      } else {
        withProbability
      }

    if (!emitProbability && !emitPrediction) {
      this.logWarning(s"$uid: GaussianMixtureModel.transform() does nothing" +
        " because no output columns were set.")
    }
    withPrediction.toDF()
  }

  /**
   * Derives the output schema and, for each output column that is enabled, records the
   * number of components (`weights.length`) in that column's metadata.
   */
  @Since("2.0.0")
  override def transformSchema(schema: StructType): StructType = {
    val base = validateAndTransformSchema(schema)
    val withPredictionMeta =
      if ($(predictionCol).nonEmpty) {
        SchemaUtils.updateNumValues(base, $(predictionCol), weights.length)
      } else {
        base
      }
    if ($(probabilityCol).nonEmpty) {
      SchemaUtils.updateAttributeGroupSize(withPredictionMeta, $(probabilityCol), weights.length)
    } else {
      withPredictionMeta
    }
  }

  /** Index of the component with the highest membership probability for `features`. */
  @Since("3.0.0")
  def predict(features: Vector): Int = predictProbability(features).argmax

  /** Membership probability of `features` for each of the k components, as a dense vector. */
  @Since("3.0.0")
  def predictProbability(features: Vector): Vector = {
    Vectors.dense(GaussianMixtureModel.computeProbabilities(features, gaussians, weights))
  }

  /**
   * Retrieve Gaussian distributions as a DataFrame.
   * Each row represents a Gaussian Distribution.
   * Two columns are defined: mean and cov.
   * Schema:
   * {{{
   *  root
   *   |-- mean: vector (nullable = true)
   *   |-- cov: matrix (nullable = true)
   * }}}
   */
  @Since("2.0.0")
  def gaussiansDF: DataFrame = {
    val rows = gaussians
      .map(g => (OldVectors.fromML(g.mean), OldMatrices.fromML(g.cov)))
      .toImmutableArraySeq
    val spark = SparkSession.builder().getOrCreate()
    spark.createDataFrame(rows).toDF("mean", "cov")
  }

  /**
   * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance.
   *
   * For [[GaussianMixtureModel]], this does NOT currently save the training [[summary]].
   * An option to save [[summary]] may be added in the future.
   */
  @Since("2.0.0")
  override def write: MLWriter = new GaussianMixtureModel.GaussianMixtureModelWriter(this)

  @Since("3.0.0")
  override def toString: String = {
    s"GaussianMixtureModel: uid=$uid, k=${weights.length}, numFeatures=$numFeatures"
  }

  /**
   * Gets summary of model on training set. An exception is
   * thrown if `hasSummary` is false.
   */
  @Since("2.0.0")
  override def summary: GaussianMixtureSummary = super.summary
}
@Since("2.0.0")
object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {

  @Since("2.0.0")
  override def read: MLReader[GaussianMixtureModel] = new GaussianMixtureModelReader

  @Since("2.0.0")
  override def load(path: String): GaussianMixtureModel = super.load(path)

  /** [[MLWriter]] instance for [[GaussianMixtureModel]] */
  private[GaussianMixtureModel] class GaussianMixtureModelWriter(
      instance: GaussianMixtureModel) extends MLWriter {

    // Field names define the column names of the persisted parquet data.
    private case class Data(weights: Array[Double], mus: Array[OldVector], sigmas: Array[OldMatrix])

    /** Writes metadata/params and then a single-row parquet file holding weights and gaussians. */
    override protected def saveImpl(path: String): Unit = {
      // Save metadata and Params
      DefaultParamsWriter.saveMetadata(instance, path, sc)

      // Save model data: weights and gaussians, converted to the old mllib linalg types
      val (mus, sigmas) = instance.gaussians.map { g =>
        (OldVectors.fromML(g.mean), OldMatrices.fromML(g.cov))
      }.unzip
      val payload = Data(instance.weights, mus, sigmas)
      val dataPath = new Path(path, "data").toString
      sparkSession.createDataFrame(Seq(payload)).repartition(1).write.parquet(dataPath)
    }
  }

  private class GaussianMixtureModelReader extends MLReader[GaussianMixtureModel] {

    /** Checked against metadata when loading model */
    private val className = classOf[GaussianMixtureModel].getName

    /** Reads metadata and the persisted weights/means/covariances and rebuilds the model. */
    override def load(path: String): GaussianMixtureModel = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

      val dataPath = new Path(path, "data").toString
      val row = sparkSession.read.parquet(dataPath).select("weights", "mus", "sigmas").head()
      val storedWeights = row.getSeq[Double](0).toArray
      val storedMeans = row.getSeq[OldVector](1).toArray
      val storedCovs = row.getSeq[OldMatrix](2).toArray
      require(storedMeans.length == storedCovs.length, "Length of Mu and Sigma array must match")
      require(storedMeans.length == storedWeights.length,
        "Length of weight and Gaussian array must match")

      // Lengths are equal at this point, so index-wise pairing covers every component.
      val components = Array.tabulate(storedMeans.length) { i =>
        new MultivariateGaussian(storedMeans(i).asML, storedCovs(i).asML)
      }
      val model = new GaussianMixtureModel(metadata.uid, storedWeights, components)

      metadata.getAndSetParams(model)
      model
    }
  }

  /**
   * Compute the probability (partial assignment) for each cluster for the given data point.
   *
   * Each unnormalized value is `EPSILON + weights(i) * dists(i).pdf(features)`; the values are
   * then divided by their sum.
   *
   * @param features  Data point
   * @param dists     Gaussians for model
   * @param weights   Weights for each Gaussian
   * @return          Probability (partial assignment) for each of the k clusters
   */
  private[clustering] def computeProbabilities(
      features: Vector,
      dists: Array[MultivariateGaussian],
      weights: Array[Double]): Array[Double] = {
    val numComponents = weights.length
    val probs = Array.tabulate(numComponents) { c =>
      EPSILON + weights(c) * dists(c).pdf(features)
    }
    // Left-to-right sum starting from 0.0, i.e. the same accumulation order as a running total.
    val total = probs.foldLeft(0.0)(_ + _)

    var c = 0
    while (c < numComponents) {
      probs(c) /= total
      c += 1
    }
    probs
  }
}
/**
 * Gaussian Mixture clustering.
 *
 * This class performs expectation maximization for multivariate Gaussian
 * Mixture Models (GMMs).  A GMM represents a composite distribution of
 * independent Gaussian distributions with associated "mixing" weights
 * specifying each's contribution to the composite.
 *
 * Given a set of sample points, this class will maximize the log-likelihood
 * for a mixture of k Gaussians, iterating until the log-likelihood changes by
 * less than `tol`, or until it has reached the max number of iterations (`maxIter`).
 * While this process is generally guaranteed to converge, it is not guaranteed
 * to find a global optimum.
 *
 * @note This algorithm is limited in its number of features since it requires storing a covariance
 * matrix which has size quadratic in the number of features. Even when the number of features does
 * not exceed this limit, this algorithm may perform poorly on high-dimensional data.
 * This is due to high-dimensional data (a) making it difficult to cluster at all (based
 * on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions.
 */
@Since("2.0.0")
class GaussianMixture @Since("2.0.0") (
    @Since("2.0.0") override val uid: String)
  extends Estimator[GaussianMixtureModel] with GaussianMixtureParams with DefaultParamsWritable {

  @Since("2.0.0")
  override def copy(extra: ParamMap): GaussianMixture = defaultCopy(extra)

  @Since("2.0.0")
  def this() = this(Identifiable.randomUID("GaussianMixture"))

  /** @group setParam */
  @Since("2.0.0")
  def setFeaturesCol(value: String): this.type = set(featuresCol, value)

  /** @group setParam */
  @Since("2.0.0")
  def setPredictionCol(value: String): this.type = set(predictionCol, value)

  /** @group setParam */
  @Since("2.0.0")
  def setProbabilityCol(value: String): this.type = set(probabilityCol, value)

  /** @group setParam */
  @Since("3.0.0")
  def setWeightCol(value: String): this.type = set(weightCol, value)

  /** @group setParam */
  @Since("2.0.0")
  def setK(value: Int): this.type = set(k, value)

  /** @group setParam */
  @Since("2.0.0")
  def setMaxIter(value: Int): this.type = set(maxIter, value)

  /** @group setParam */
  @Since("2.0.0")
  def setTol(value: Double): this.type = set(tol, value)

  /** @group setParam */
  @Since("2.0.0")
  def setSeed(value: Long): this.type = set(seed, value)

  /** @group expertSetParam */
  @Since("3.0.0")
  def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)

  /**
   * Number of samples per cluster to use when initializing Gaussians.
   */
  private val numSamples = 5

  /**
   * Fits a Gaussian mixture to `dataset`: validates the schema, extracts (features, weight)
   * instances, initializes the components from random samples, runs EM via `trainImpl`, and
   * wraps the result in a [[GaussianMixtureModel]] with an attached training summary.
   */
  @Since("2.0.0")
  override def fit(dataset: Dataset[_]): GaussianMixtureModel = instrumented { instr =>
    transformSchema(dataset.schema, logging = true)

    val spark = dataset.sparkSession
    import spark.implicits._

    val numFeatures = getNumFeatures(dataset, $(featuresCol))
    // MAX_NUM_FEATURES is itself an admissible size (its square is still below Int.MaxValue);
    // only strictly larger feature counts are rejected, matching the error message.
    require(numFeatures <= GaussianMixture.MAX_NUM_FEATURES, s"GaussianMixture cannot handle more " +
      s"than ${GaussianMixture.MAX_NUM_FEATURES} features because the size of the covariance" +
      s" matrix is quadratic in the number of features.")

    instr.logPipelineStage(this)
    instr.logDataset(dataset)
    instr.logParams(this, featuresCol, predictionCol, probabilityCol, weightCol, k, maxIter,
      seed, tol, aggregationDepth)
    instr.logNumFeatures(numFeatures)

    val instances = dataset.select(
      checkNonNanVectors(columnToVector(dataset, $(featuresCol))),
      checkNonNegativeWeights(get(weightCol))
    ).as[(Vector, Double)].rdd
      .setName("training instances")

    // Only cache the instances ourselves when the caller has not already persisted the dataset.
    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
    if (handlePersistence) { instances.persist(StorageLevel.MEMORY_AND_DISK) }

    // TODO: SPARK-15785 Support users supplied initial GMM.
    val (weights, gaussians, logLikelihood, iteration) = try {
      val (initWeights, initGaussians) = initRandom(instances, $(k), numFeatures)
      // trainImpl updates initWeights / initGaussians in place.
      val (finalLogLikelihood, numIterations) =
        trainImpl(instances, initWeights, initGaussians, numFeatures, instr)
      (initWeights, initGaussians, finalLogLikelihood, numIterations)
    } finally {
      // Release the cached instances even when initialization or training fails.
      if (handlePersistence) { instances.unpersist() }
    }

    val gaussianDists = gaussians.map { case (mean, covVec) =>
      val cov = GaussianMixture.unpackUpperTriangularMatrix(numFeatures, covVec.values)
      new MultivariateGaussian(mean, cov)
    }

    val model = copyValues(new GaussianMixtureModel(uid, weights, gaussianDists))
      .setParent(this)
    val summary = new GaussianMixtureSummary(model.transform(dataset),
      $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood, iteration)
    instr.logNamedValue("logLikelihood", logLikelihood)
    instr.logNamedValue("clusterSizes", summary.clusterSizes)
    model.setSummary(Some(summary))
  }

  /**
   * Runs the EM iterations, updating `weights` and `gaussians` in place.
   *
   * Iteration stops once `maxIter` iterations have run or the absolute change in
   * log-likelihood between two consecutive iterations is no larger than `tol`.
   *
   * @param instances   Training instances as (features, weight) pairs.
   * @param weights     Mixing weights; overwritten with the fitted values.
   * @param gaussians   (mean, packed upper-triangular covariance) per component; overwritten
   *                    with the fitted values.
   * @param numFeatures Number of features of each instance.
   * @param instr       Instrumentation used for logging.
   * @return The last computed log-likelihood and the number of iterations performed.
   */
  private def trainImpl(
      instances: RDD[(Vector, Double)],
      weights: Array[Double],
      gaussians: Array[(DenseVector, DenseVector)],
      numFeatures: Int,
      instr: Instrumentation): (Double, Int) = {
    val sc = instances.sparkContext

    var logLikelihood = Double.MinValue
    var logLikelihoodPrev = 0.0

    var iteration = 0
    while (iteration < $(maxIter) && math.abs(logLikelihood - logLikelihoodPrev) > $(tol)) {
      // Dataset-level statistics are only collected (and logged) during the first pass.
      val weightSumAccum = if (iteration == 0) sc.doubleAccumulator else null
      val numExamplesAccum = if (iteration == 0) sc.longAccumulator else null
      val logLikelihoodAccum = sc.doubleAccumulator
      val bcWeights = sc.broadcast(weights)
      val bcGaussians = sc.broadcast(gaussians)

      // aggregate the cluster contribution for all sample points,
      // and then compute the new distributions
      instances.mapPartitions { iter =>
        if (iter.nonEmpty) {
          val agg = new ExpectationAggregator(numFeatures, bcWeights, bcGaussians)
          while (iter.hasNext) { agg.add(iter.next()) }
          // sum of weights in this partition
          val ws = agg.weights.sum
          if (iteration == 0) {
            weightSumAccum.add(ws)
            // Count rows explicitly: weightSumAccum receives one value per partition, so its
            // `count` would be the number of non-empty partitions rather than of examples.
            numExamplesAccum.add(agg.count)
          }
          logLikelihoodAccum.add(agg.logLikelihood)
          Iterator.tabulate(bcWeights.value.length) { i =>
            (i, (agg.means(i), agg.covs(i), agg.weights(i), ws))
          }
        } else Iterator.empty
      }.reduceByKey(GaussianMixture.mergeWeightsMeans).mapValues { case (mean, cov, w, ws) =>
        // Create new distributions based on the partial assignments
        // (often referred to as the "M" step in literature)
        GaussianMixture.updateWeightsAndGaussians(mean, cov, w, ws)
      }.collect().foreach { case (i, (weight, gaussian)) =>
        weights(i) = weight
        gaussians(i) = gaussian
      }

      bcWeights.destroy()
      bcGaussians.destroy()

      if (iteration == 0) {
        instr.logNumExamples(numExamplesAccum.sum)
        instr.logSumOfWeights(weightSumAccum.value)
      }

      logLikelihoodPrev = logLikelihood   // current becomes previous
      logLikelihood = logLikelihoodAccum.value   // this is the freshly computed log-likelihood
      instr.logNamedValue(s"logLikelihood@iter$iteration", logLikelihood)
      iteration += 1
    }

    (logLikelihood, iteration)
  }

  @Since("2.0.0")
  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema)
  }

  /**
   * Initialize weights and corresponding gaussian distributions at random.
   *
   * `numClusters * numSamples` instances are sampled with replacement; each consecutive slice
   * of `numSamples` samples seeds one component: its weight is the slice's share of the total
   * sampled weight, its mean is the weighted mean of the slice, and its covariance is diagonal
   * with the weighted (biased) per-feature variances of the slice.
   *
   * @param instances The training instances.
   * @param numClusters The number of clusters.
   * @param numFeatures The number of features of training instance.
   * @return The initialized weights and corresponding gaussian distributions. Note the
   *         covariance matrix of multivariate gaussian distribution is symmetric and
   *         we only save the upper triangular part as a dense vector (column major).
   */
  private def initRandom(
      instances: RDD[(Vector, Double)],
      numClusters: Int,
      numFeatures: Int): (Array[Double], Array[(DenseVector, DenseVector)]) = {
    val (samples, sampleWeights) = instances
      .takeSample(withReplacement = true, numClusters * numSamples, $(seed))
      .unzip
    val weights = new Array[Double](numClusters)
    val weightSum = sampleWeights.sum

    val gaussians = Array.tabulate(numClusters) { i =>
      val start = i * numSamples
      val end = start + numSamples
      val sampleSlice = samples.slice(start, end)
      val weightSlice = sampleWeights.slice(start, end)
      // NOTE(review): if every sampled weight in this slice is zero, localWeightSum is 0 and
      // the divisions below produce non-finite means/variances.
      val localWeightSum = weightSlice.sum
      weights(i) = localWeightSum / weightSum

      val mean = {
        val v = new DenseVector(new Array[Double](numFeatures))
        var j = 0
        while (j < numSamples) {
          BLAS.axpy(weightSlice(j), sampleSlice(j), v)
          j += 1
        }
        BLAS.scal(1.0 / localWeightSum, v)
        v
      }

      /*
         Construct matrix where diagonal entries are element-wise
         variance of input vectors (computes biased variance).
         Since the covariance matrix of multivariate gaussian distribution is symmetric,
         only the upper triangular part of the matrix (column major) will be saved as
         a dense vector in order to reduce the shuffled data size.
       */
      val cov = {
        val ss = new DenseVector(new Array[Double](numFeatures)).asBreeze
        var j = 0
        while (j < numSamples) {
          val v = sampleSlice(j).asBreeze - mean.asBreeze
          ss += (v * v) * weightSlice(j)
          j += 1
        }
        val diagVec = Vectors.fromBreeze(ss)
        BLAS.scal(1.0 / localWeightSum, diagVec)
        val covVec = new DenseVector(Array.ofDim[Double](numFeatures * (numFeatures + 1) / 2))
        // Position of diagonal entry (i, i) in the packed column-major upper triangle.
        diagVec.foreach { (i, v) => covVec.values(i + i * (i + 1) / 2) = v }
        covVec
      }

      (mean, cov)
    }

    (weights, gaussians)
  }
}
@Since("2.0.0")
object GaussianMixture extends DefaultParamsReadable[GaussianMixture] {

  /** Limit number of features such that numFeatures^2^ < Int.MaxValue */
  private[clustering] val MAX_NUM_FEATURES = math.sqrt(Int.MaxValue.toDouble).toInt

  @Since("2.0.0")
  override def load(path: String): GaussianMixture = super.load(path)

  /**
   * Expands the packed upper triangular part of a symmetric matrix (an array of
   * n * (n + 1) / 2 values, column major) into the full n by n symmetric matrix.
   *
   * @param n The order of the n by n matrix.
   * @param triangularValues The upper triangular part of the matrix packed in an array
   *                         (column major).
   * @return A dense matrix which represents the symmetric matrix in column major.
   */
  private[clustering] def unpackUpperTriangularMatrix(
      n: Int,
      triangularValues: Array[Double]): DenseMatrix = {
    new DenseMatrix(n, n, unpackUpperTriangular(n, triangularValues))
  }

  /**
   * Combines two partial aggregates of one component, each a tuple of
   * (mean sum, packed covariance sum, component weight, total weight).
   * The vectors of `a` are updated in place and returned; `b` is left untouched.
   */
  private def mergeWeightsMeans(
      a: (DenseVector, DenseVector, Double, Double),
      b: (DenseVector, DenseVector, Double, Double)): (DenseVector, DenseVector, Double, Double) =
  {
    val (accMean, accCov, accWeight, accTotal) = a
    val (otherMean, otherCov, otherWeight, otherTotal) = b
    // update the weights, means and covariances for i-th distributions
    BLAS.axpy(1.0, otherMean, accMean)
    BLAS.axpy(1.0, otherCov, accCov)
    (accMean, accCov, accWeight + otherWeight, accTotal + otherTotal)
  }

  /**
   * Update the weight, mean and covariance of gaussian distribution.
   *
   * `mean` and `cov` are modified in place: the mean is scaled by `1 / weight`, then the
   * covariance receives a rank-one update with factor `-weight` built from that mean and is
   * scaled by `1 / weight` as well.
   *
   * @param mean The mean of the gaussian distribution.
   * @param cov The covariance matrix of the gaussian distribution. Note we only
   *            save the upper triangular part as a dense vector (column major).
   * @param weight The weight of the gaussian distribution.
   * @param sumWeights The sum of weights of all clusters.
   * @return The updated weight, mean and covariance.
   */
  private[clustering] def updateWeightsAndGaussians(
      mean: DenseVector,
      cov: DenseVector,
      weight: Double,
      sumWeights: Double): (Double, (DenseVector, DenseVector)) = {
    val invWeight = 1.0 / weight
    // The order matters: the rank-one correction must use the already normalized mean.
    BLAS.scal(invWeight, mean)
    BLAS.spr(-weight, mean, cov)
    BLAS.scal(invWeight, cov)
    (weight / sumWeights, (mean, cov))
  }
}
/**
* ExpectationAggregator accumulates the partial expectation ("E" step) results over the
* training instances passed to `add`: per-component weight sums, weighted feature sums,
* weighted packed outer-product sums, and the weighted log-likelihood.
*
* @param numFeatures The number of features.
* @param bcWeights The broadcast weights for each Gaussian distribution in the mixture.
* @param bcGaussians The broadcast array of Multivariate Gaussian (Normal) Distribution
* in the mixture. Note only upper triangular part of the covariance
* matrix of each distribution is stored as dense vector (column major)
* in order to reduce shuffled data size.
*/
private class ExpectationAggregator(
numFeatures: Int,
bcWeights: Broadcast[Array[Double]],
bcGaussians: Broadcast[Array[(DenseVector, DenseVector)]]) extends Serializable {
// Number of mixture components, taken from the broadcast weights.
private val k = bcWeights.value.length
// Number of instances added so far.
private var totalCnt = 0L
// Running sum of weight * log(probSum) over the added instances.
private var newLogLikelihood = 0.0
// Length of a packed upper-triangular numFeatures x numFeatures matrix.
private val covSize = numFeatures * (numFeatures + 1) / 2
// Accumulators are created lazily on first use; the vector-valued ones are @transient.
private lazy val newWeights = Array.ofDim[Double](k)
@transient private lazy val newMeans = Array.fill(k)(Vectors.zeros(numFeatures).toDense)
@transient private lazy val newCovs = Array.fill(k)(Vectors.zeros(covSize).toDense)
// Full MultivariateGaussian objects rebuilt from the broadcast (mean, packed covariance)
// pairs on first access.
@transient private lazy val gaussians = {
bcGaussians.value.map { case (mean, covVec) =>
val cov = GaussianMixture.unpackUpperTriangularMatrix(numFeatures, covVec.values)
new MultivariateGaussian(mean, cov)
}
}
/** Number of instances added so far. */
def count: Long = totalCnt
/** Accumulated weighted log-likelihood of the added instances. */
def logLikelihood: Double = newLogLikelihood
/** Accumulated partial-assignment weight per component. */
def weights: Array[Double] = newWeights
/** Accumulated weighted feature sums per component (not yet divided by the weight). */
def means: Array[DenseVector] = newMeans
/** Accumulated packed upper-triangular outer-product sums per component. */
def covs: Array[DenseVector] = newCovs
/**
* Add a new training instance to this ExpectationAggregator, update the weights,
* means and covariances for each distributions, and update the log likelihood.
*
* @param instance The instance of data point to be added, as a (features, weight) pair.
* @return This ExpectationAggregator object.
*/
def add(instance: (Vector, Double)): this.type = {
val (vector: Vector, weight: Double) = instance
val localWeights = bcWeights.value
val localGaussians = gaussians
// Unnormalized membership of the instance in each component; EPSILON is added to each term.
val prob = new Array[Double](k)
var probSum = 0.0
var i = 0
while (i < k) {
val p = EPSILON + localWeights(i) * localGaussians(i).pdf(vector)
prob(i) = p
probSum += p
i += 1
}
// The instance contributes its log mixture density, scaled by the instance weight.
newLogLikelihood += math.log(probSum) * weight
val localNewWeights = newWeights
val localNewMeans = newMeans
val localNewCovs = newCovs
i = 0
while (i < k) {
// Normalized membership of this instance in component i, scaled by the instance weight.
val w = prob(i) / probSum * weight
localNewWeights(i) += w
BLAS.axpy(w, vector, localNewMeans(i))
BLAS.spr(w, vector, localNewCovs(i))
i += 1
}
totalCnt += 1
this
}
}
/**
* Summary of GaussianMixture.
*
* Extends `ClusteringSummary` with the probability column name, the log-likelihood and a
* view of the predicted probabilities.
*
* @param predictions `DataFrame` produced by `GaussianMixtureModel.transform()`.
* @param predictionCol Name for column of predicted clusters in `predictions`.
* @param probabilityCol Name for column of predicted probability of each cluster
* in `predictions`.
* @param featuresCol Name for column of features in `predictions`.
* @param k Number of clusters.
* @param logLikelihood Total log-likelihood for this model on the given data.
* @param numIter Number of iterations.
*/
@Since("2.0.0")
class GaussianMixtureSummary private[clustering] (
predictions: DataFrame,
predictionCol: String,
@Since("2.0.0") val probabilityCol: String,
featuresCol: String,
k: Int,
@Since("2.2.0") val logLikelihood: Double,
numIter: Int)
extends ClusteringSummary(predictions, predictionCol, featuresCol, k, numIter) {
/**
* Probability of each cluster: `predictions` restricted to the `probabilityCol` column.
* Evaluated lazily on first access and marked `@transient`.
*/
@Since("2.0.0")
@transient lazy val probability: DataFrame = predictions.select(probabilityCol)
}