diff --git a/src/main/scala/nodes/learning/KernelBlockLinearMapper.scala b/src/main/scala/nodes/learning/KernelBlockLinearMapper.scala
index 05fd71d7..6f4d57fe 100644
--- a/src/main/scala/nodes/learning/KernelBlockLinearMapper.scala
+++ b/src/main/scala/nodes/learning/KernelBlockLinearMapper.scala
@@ -61,20 +61,20 @@ class KernelBlockLinearMapper[T: ClassTag](
         pred :+ (testKernelBB * modelBlockBC.value)
       }
 
+      predictionsNew.cache()
+      predictionsNew.count()
+      predictions.unpersist(true)
+
+      testKernelMat.unpersist(blockIdxs.toSeq)
+      modelBlockBC.unpersist(true)
+
       // If we are checkpointing update our cache
       if (in.context.getCheckpointDir.isDefined &&
           block % blocksBeforeCheckpoint == (blocksBeforeCheckpoint - 1)) {
-        predictionsNew = MatrixUtils.truncateLineage(predictionsNew, true)
-        predictionsNew.count
-
-        predictions.unpersist(true)
-        predictions = predictionsNew
-      } else {
-        predictions = predictionsNew
+        predictionsNew = MatrixUtils.truncateLineage(predictionsNew, false)
       }
+      predictions = predictionsNew
     }
 
-    // TODO: We need to cache, count the predictions if we want to unpersist
-    // the broadcast variables here ?
     predictions.flatMap(x => MatrixUtils.matrixToRowArray(x))
   }
diff --git a/src/main/scala/nodes/learning/KernelGenerator.scala b/src/main/scala/nodes/learning/KernelGenerator.scala
index 7ad3a934..b1f34ff4 100644
--- a/src/main/scala/nodes/learning/KernelGenerator.scala
+++ b/src/main/scala/nodes/learning/KernelGenerator.scala
@@ -67,7 +67,11 @@ trait KernelTransformer[T] {
   def apply(data: T): DenseVector[Double]
 
   // Internal function used to lazily populate the kernel matrix.
-  private[learning] def computeKernel(data: RDD[T], idxs: Seq[Int]): RDD[DenseVector[Double]]
+  // NOTE: This function returns a *cached* RDD and the caller is responsible
+  // for managing its life cycle.
+  private[learning] def computeKernel(
+      data: RDD[T],
+      idxs: Seq[Int]): (RDD[DenseVector[Double]], DenseMatrix[Double])
 }
 
 class GaussianKernelTransformer(
@@ -110,7 +114,7 @@ class GaussianKernelTransformer(
   def computeKernel(
       data: RDD[DenseVector[Double]],
       blockIdxs: Seq[Int])
-    : RDD[DenseVector[Double]] = {
+    : (RDD[DenseVector[Double]], DenseMatrix[Double]) = {
 
     // Dot product of rows of X with each other
     val dataDotProd = if (data.id == trainData.id) {
@@ -122,19 +126,19 @@
 
     // Extract a b x d block of training data
     val blockIdxSet = blockIdxs.toSet
-    val blockDataArray = trainData.zipWithIndex.filter { case (vec, idx) =>
+    val trainBlockArray = trainData.zipWithIndex.filter { case (vec, idx) =>
       blockIdxSet.contains(idx.toInt)
     }.map(x => x._1).collect()
 
-    val blockData = MatrixUtils.rowsToMatrix(blockDataArray)
-    assert(blockData.rows == blockIdxs.length)
-    val blockDataBC = data.context.broadcast(blockData)
+    val trainBlock = MatrixUtils.rowsToMatrix(trainBlockArray)
+    assert(trainBlock.rows == blockIdxs.length)
+    val trainBlockBC = data.context.broadcast(trainBlock)
 
     // for i in [nTest], j in blockIdxs
     val blockXXT = data.mapPartitions { itr =>
-      val bd = blockDataBC.value
+      val bd = trainBlockBC.value
       val vecMat = MatrixUtils.rowsToMatrix(itr)
-      Iterator.single(vecMat*bd.t)
+      Iterator.single(vecMat * bd.t)
     }
 
     val trainBlockDotProd = DenseVector(trainDotProd.zipWithIndex.filter { case (vec, idx) =>
@@ -142,24 +146,40 @@
     }.map(x => x._1).collect())
     val trainBlockDotProdBC = data.context.broadcast(trainBlockDotProd)
 
-    val kBlock = blockXXT.zipPartitions(dataDotProd) { case (iterXXT, iterTestDotProds) =>
+    val kBlock = blockXXT.zipPartitions(dataDotProd) { case (iterXXT, iterDataDotProds) =>
       val xxt = iterXXT.next()
       assert(iterXXT.isEmpty)
-      iterTestDotProds.zipWithIndex.map { case (testDotProd, idx) =>
+      iterDataDotProds.zipWithIndex.map { case (dataDotProdVal, idx) =>
         val term1 = xxt(idx, ::).t * (-2.0)
-        val term2 = DenseVector.fill(xxt.cols){testDotProd}
+        val term2 = DenseVector.fill(xxt.cols)(dataDotProdVal)
         val term3 = trainBlockDotProdBC.value
         val term4 = (term1 + term2 + term3) * (-gamma)
         exp(term4)
       }
     }
 
-    // TODO: We can't unpersist these broadcast variables until we have cached the
-    // kernel block matrices ?
-    // We could introduce a new clean up method here ?
+    kBlock.cache()
+    kBlock.count
+
+    trainBlockBC.unpersist(true)
+    trainBlockDotProdBC.unpersist(true)
+
+    val diagBlock = if (data.id == trainData.id) {
+      // For train data use locally available data to compute diagonal block
+      val kBlockBlock = trainBlock * trainBlock.t
+      kBlockBlock :*= (-2.0)
+      kBlockBlock(::, *) :+= trainBlockDotProd
+      kBlockBlock(*, ::) :+= trainBlockDotProd
+      kBlockBlock *= -gamma
+      exp.inPlace(kBlockBlock)
+      kBlockBlock
+    } else {
+      // For test data extract the diagonal block from the cached block
+      MatrixUtils.rowsToMatrix(kBlock.zipWithIndex.filter { case (vec, idx) =>
+        blockIdxSet.contains(idx.toInt)
+      }.map(x => x._1).collect())
+    }
 
-    // blockDataBC.unpersist()
-    // trainBlockDotProdBC.unpersist()
-    kBlock
+    (kBlock, diagBlock)
   }
 }
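For reviewers, the contract of the new computeKernel signature can be summarized with a short sketch (illustrative only, not part of the patch; kernelGen, testData and blockIdxs stand for values that are in scope inside BlockKernelMatrix):

    // computeKernel caches and materializes the column block itself and also
    // returns the local b x b diagonal block; the caller owns the cleanup.
    val (kBlock, diagBlock) = kernelGen.computeKernel(testData, blockIdxs)

    // The cached block can be reused without recomputation, e.g. converted
    // into per-partition matrices:
    val kBlockMat = MatrixUtils.rowsToMatrix(kBlock)

    // ... use kBlockMat and diagBlock ...

    // The caller must release the cached block once it is no longer needed.
    kBlock.unpersist(true)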
diff --git a/src/main/scala/nodes/learning/KernelMatrix.scala b/src/main/scala/nodes/learning/KernelMatrix.scala
index 7372efcc..8775b4df 100644
--- a/src/main/scala/nodes/learning/KernelMatrix.scala
+++ b/src/main/scala/nodes/learning/KernelMatrix.scala
@@ -18,11 +18,28 @@ trait KernelMatrix {
 
   /**
    * Extract specified columns from the kernel matrix.
+   * NOTE: This returns a *cached* RDD and unpersist should
+   * be called at the end of a block.
    *
    * @param colIdxs the column indexes to extract
   * @return A sub-matrix of size n x idxs.size as an RDD.
   */
  def apply(colIdxs: Seq[Int]): RDD[DenseMatrix[Double]]
+
+  /**
+   * Extract a diagonal block from the kernel matrix.
+   *
+   * @param idxs the column, row indexes to extract
+   * @return A local matrix of size idxs.size x idxs.size
+   */
+  def diagBlock(idxs: Seq[Int]): DenseMatrix[Double]
+
+  /**
+   * Clean up resources associated with a kernel block.
+   *
+   * @param colIdxs column indexes corresponding to the block.
+   */
+  def unpersist(colIdxs: Seq[Int]): Unit
 }
 
 /**
@@ -36,18 +53,36 @@ class BlockKernelMatrix[T: ClassTag](
     val cacheKernel: Boolean)
   extends KernelMatrix {
 
-  val cache = HashMap.empty[Seq[Int], RDD[DenseMatrix[Double]]]
+  val colBlockCache = HashMap.empty[Seq[Int], RDD[DenseVector[Double]]]
+  val diagBlockCache = HashMap.empty[Seq[Int], DenseMatrix[Double]]
 
   def apply(colIdxs: Seq[Int]): RDD[DenseMatrix[Double]] = {
-    if (cache.contains(colIdxs)) {
-      cache(colIdxs)
+    if (colBlockCache.contains(colIdxs)) {
+      MatrixUtils.rowsToMatrix(colBlockCache(colIdxs))
     } else {
-      val kBlock = MatrixUtils.rowsToMatrix(kernelGen.computeKernel(data, colIdxs))
+      val (kBlock, diagBlock) = kernelGen.computeKernel(data, colIdxs)
+      if (cacheKernel) {
+        colBlockCache += (colIdxs -> kBlock)
+        diagBlockCache += (colIdxs -> diagBlock)
+      }
+      MatrixUtils.rowsToMatrix(kBlock)
+    }
+  }
+
+  def unpersist(colIdxs: Seq[Int]): Unit = {
+    if (colBlockCache.contains(colIdxs) && !cacheKernel) {
+      colBlockCache(colIdxs).unpersist(true)
+    }
+  }
+
+  def diagBlock(idxs: Seq[Int]): DenseMatrix[Double] = {
+    if (!diagBlockCache.contains(idxs)) {
+      val (kBlock, diagBlock) = kernelGen.computeKernel(data, idxs)
       if (cacheKernel) {
-        kBlock.cache()
-        cache += (colIdxs -> kBlock)
+        colBlockCache += (idxs -> kBlock)
+        diagBlockCache += (idxs -> diagBlock)
       }
-      kBlock
     }
+    diagBlockCache(idxs)
   }
 }
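The per-block calling pattern this interface is meant to support looks roughly as follows (a sketch using the same names that KernelRidgeRegression uses below, not code from the patch):

    // One iteration over a block of training columns:
    val kernelBlockMat = trainKernelMat(blockIdxsSeq)                 // n x b column block backed by a cached RDD
    val kernelBlockBlockMat = trainKernelMat.diagBlock(blockIdxsSeq)  // local b x b diagonal block

    // ... solve the local system and update the model for this block ...

    // Release the block; BlockKernelMatrix only unpersists here when cacheKernel is false.
    trainKernelMat.unpersist(blockIdxsSeq)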
diff --git a/src/main/scala/nodes/learning/KernelRidgeRegression.scala b/src/main/scala/nodes/learning/KernelRidgeRegression.scala
index 5a1f5854..56f54c10 100644
--- a/src/main/scala/nodes/learning/KernelRidgeRegression.scala
+++ b/src/main/scala/nodes/learning/KernelRidgeRegression.scala
@@ -145,16 +145,7 @@ object KernelRidgeRegression extends Logging {
       val blockIdxsBC = labelsMat.context.broadcast(blockIdxsSeq)
 
       val kernelBlockMat = trainKernelMat(blockIdxsSeq)
-
-      // If the kernel block is not already cached, cache it.
-      // This helps us compute the kernel only once per block.
-      val kernelCached = kernelBlockMat.getStorageLevel.useMemory
-      if (!kernelCached) {
-        kernelBlockMat.cache()
-      }
-
-      val kernelBlockRPM = RowPartitionedMatrix.fromMatrix(kernelBlockMat)
-      val kernelBlockBlockMat = kernelBlockRPM(blockIdxs, ::).collect()
+      val kernelBlockBlockMat = trainKernelMat.diagBlock(blockIdxsSeq)
 
       val kernelGenEnd = System.nanoTime
 
@@ -229,10 +220,7 @@ object KernelRidgeRegression extends Logging {
         s"localSolve: ${(localSolveEnd - collectEnd)/1e9} " +
         s"modelUpdate: ${(updateEnd - localSolveEnd)/1e9}")
 
-      // If we cached the kernel in KRR then unpersist it
-      if (!kernelCached) {
-        kernelBlockMat.unpersist()
-      }
+      trainKernelMat.unpersist(blockIdxsSeq)
 
       wBlockBCs.map { case (wBlockOldBC, wBlockNewBC) =>
         wBlockOldBC.unpersist(true)
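All of the call sites above follow the same cleanup ordering: cache and materialize the derived RDD first, and only then unpersist the RDDs and broadcast variables it was computed from. A generic sketch of that idiom, with a hypothetical helper that is not part of the patch:

    import org.apache.spark.broadcast.Broadcast
    import org.apache.spark.rdd.RDD

    // Force `derived` to be computed and cached while its inputs are still
    // available, then release the previous RDD and the broadcast variable.
    def swapCached[U](prev: RDD[U], derived: RDD[U], bc: Broadcast[_]): RDD[U] = {
      derived.cache()
      derived.count()        // materialize before any unpersist call
      prev.unpersist(true)
      bc.unpersist(true)
      derived
    }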