From d72df895ab0f79281912d8621d84d9cb8723be4c Mon Sep 17 00:00:00 2001 From: smarthi Date: Mon, 21 Mar 2016 22:54:51 -0400 Subject: [PATCH] MAHOUT-1816: Implement newRowCardinality in CheckpointedFlinkDrm --- .../flinkbindings/drm/CheckpointedFlinkDrm.scala | 10 ++++++---- .../sparkbindings/drm/CheckpointedDrmSpark.scala | 4 ++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala index 6f1ba9fe89..a6b267bc5e 100644 --- a/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala +++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala @@ -192,15 +192,17 @@ class CheckpointedFlinkDrm[K: ClassTag](val ds: DrmDataSet[K], (x: K) => new Text(x.asInstanceOf[String]) } else if (keyTag.runtimeClass == classOf[Long]) { (x: K) => new LongWritable(x.asInstanceOf[Long]) - // WritableTypeInfo will reject the base Writable class -// } else if (classOf[Writable].isAssignableFrom(keyTag.runtimeClass)) { -// (x: K) => x.asInstanceOf[Writable] } else { throw new IllegalArgumentException("Do not know how to convert class tag %s to Writable.".format(keyTag)) } } - def newRowCardinality(n: Int): CheckpointedDrm[K] = ??? + def newRowCardinality(n: Int): CheckpointedDrm[K] = { + assert(n > -1) + assert(n >= nrow) + new CheckpointedFlinkDrm(ds = ds, _nrow = n, _ncol = _ncol, cacheHint = cacheHint, + partitioningTag = partitioningTag, _canHaveMissingRows = _canHaveMissingRows) + } override val context: DistributedContext = ds.getExecutionEnvironment diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/CheckpointedDrmSpark.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/CheckpointedDrmSpark.scala index 71755c5b39..ff150a11e0 100644 --- a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/CheckpointedDrmSpark.scala +++ b/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/CheckpointedDrmSpark.scala @@ -155,7 +155,7 @@ class CheckpointedDrmSpark[K: ClassTag]( /** * Dump matrix as computed Mahout's DRM into specified (HD)FS path * - * @param path + * @param path output path to dump Matrix to */ def dfsWrite(path: String) = { val ktag = implicitly[ClassTag[K]] @@ -201,7 +201,7 @@ class CheckpointedDrmSpark[K: ClassTag]( rddInput.isBlockified match { case true ⇒ rddInput.asBlockified(throw new AssertionError("not reached")) .map(_._2.ncol).reduce(max) - case false ⇒ cache().rddInput.asRowWise().map(_._2.length).fold(-1)(max(_, _)) + case false ⇒ cache().rddInput.asRowWise().map(_._2.length).fold(-1)(max) } }