diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
index b3b72b06ef..f1d23b2266 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
@@ -357,7 +357,16 @@ object FlinkEngine extends DistributedEngine {
     implicit val typeInformation = generateTypeInformation[K]
 
     val sample = DataSetUtils(drmX.dataset).sample(replacement, fraction)
-    new CheckpointedFlinkDrm[K](sample)
+
+    val res = if (kTag != ClassTag.Int) {
+      new CheckpointedFlinkDrm[K](sample)
+    }
+    else {
+      blas.rekeySeqInts(new RowsFlinkDrm[K](sample, ncol = drmX.ncol), computeMap = false)._1
+        .asInstanceOf[DrmLike[K]]
+    }
+
+    res
   }
 
   def drmSampleKRows[K](drmX: DrmLike[K], numSamples:Int, replacement: Boolean = false): Matrix = {
@@ -365,15 +374,17 @@ object FlinkEngine extends DistributedEngine {
     implicit val typeInformation = generateTypeInformation[K]
 
     val sample = DataSetUtils(drmX.dataset).sampleWithSize(replacement, numSamples)
+    val sampleArray = sample.collect().toArray
+    val isSparse = sampleArray.exists { case (_, vec) ⇒ !vec.isDense }
 
-    val res = if (kTag != ClassTag.Int) {
-      new CheckpointedFlinkDrm[K](sample)
-    }
-    else {
-      blas.rekeySeqInts(new RowsFlinkDrm[K](sample, ncol = drmX.ncol), computeMap = false)._1
-    }
+    val vectors = sampleArray.map(_._2)
+    val labels = sampleArray.view.zipWithIndex
+      .map { case ((key, _), idx) ⇒ key.toString → (idx: Integer) }.toMap
+
+    val mx: Matrix = if (isSparse) sparse(vectors: _*) else dense(vectors)
+    mx.setRowLabelBindings(labels)
 
-    res.collect
+    mx
   }
 
   /** Engine-specific all reduce tensor operation. */
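
A minimal usage sketch (not part of the patch) of the changed drmSampleKRows contract: the method now materializes the sample as an in-core Matrix whose row label bindings map the original row keys (as strings) to the sampled row indices, rather than collecting a rekeyed DRM. The surrounding setup (a Flink-backed DistributedContext in implicit scope, drmParallelize, dense) is assumed from the Mahout Samsara DSL and is not shown in the diff.

import org.apache.mahout.math.Matrix
import org.apache.mahout.math.scalabindings._
import org.apache.mahout.math.drm._
import org.apache.mahout.flinkbindings.FlinkEngine

// Assumes a Flink-backed DistributedContext is already available; creating it is omitted here.
def sampleSketch()(implicit dc: DistributedContext): Unit = {
  // Distribute a small in-core matrix as a DRM keyed by Int row indices.
  val drmA = drmParallelize(dense((1, 2, 3), (4, 5, 6), (7, 8, 9)), numPartitions = 2)

  // Draw 2 rows without replacement; the result is now an in-core Matrix.
  val mx: Matrix = FlinkEngine.drmSampleKRows(drmA, numSamples = 2, replacement = false)

  // Row label bindings record which original keys the sampled rows came from,
  // e.g. {"0" -> 0, "2" -> 1} if rows 0 and 2 were drawn.
  println(mx.getRowLabelBindings)
}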