Skip to content

Commit

Permalink
MAHOUT-1848: drmSampleKRows in FlinkEngine should generate a dense or sparse matrix, this closes #233
Browse files Browse the repository at this point in the history
  • Loading branch information
smarthi committed May 3, 2016
1 parent 6ac833b commit 6ab5a8d
Showing 1 changed file with 19 additions and 8 deletions.
Expand Up @@ -357,23 +357,34 @@ object FlinkEngine extends DistributedEngine {
implicit val typeInformation = generateTypeInformation[K]

val sample = DataSetUtils(drmX.dataset).sample(replacement, fraction)
new CheckpointedFlinkDrm[K](sample)

val res = if (kTag != ClassTag.Int) {
new CheckpointedFlinkDrm[K](sample)
}
else {
blas.rekeySeqInts(new RowsFlinkDrm[K](sample, ncol = drmX.ncol), computeMap = false)._1
.asInstanceOf[DrmLike[K]]
}

res
}

/**
 * Draws a fixed-size sample of rows from `drmX` and returns it as an in-core matrix.
 *
 * The sampled rows are collected locally exactly once. The result is a sparse matrix
 * if at least one sampled row vector reports `isDense == false`, and a dense matrix
 * otherwise. Each row's original key is attached as a row label (`key.toString`) bound
 * to the row's position in the collected sample.
 *
 * @param drmX        distributed matrix to sample from
 * @param numSamples  number of rows requested, forwarded to `sampleWithSize`
 * @param replacement forwarded to `sampleWithSize` as its replacement flag
 * @return in-core matrix holding the sampled rows, with row label bindings set
 */
def drmSampleKRows[K](drmX: DrmLike[K], numSamples:Int, replacement: Boolean = false): Matrix = {
  implicit val kTag: ClassTag[K] = drmX.keyClassTag
  implicit val typeInformation = generateTypeInformation[K]

  val sample = DataSetUtils(drmX.dataset).sampleWithSize(replacement, numSamples)

  // Single collection of the sample; everything below works on this local array.
  val sampleArray = sample.collect().toArray
  val isSparse = sampleArray.exists { case (_, vec) => !vec.isDense }

  val vectors = sampleArray.map(_._2)

  // Row label -> row index in the collected sample (boxed Integer for the label map).
  val labels = sampleArray.view.zipWithIndex
    .map { case ((key, _), idx) => key.toString -> (idx: Integer) }.toMap

  val mx: Matrix = if (isSparse) sparse(vectors: _*) else dense(vectors)
  mx.setRowLabelBindings(labels)

  mx
}

/** Engine-specific all reduce tensor operation. */
Expand Down

0 comments on commit 6ab5a8d

Please sign in to comment.