Refactor KernelMatrix API to improve performance
This changes the memory management approach and unpersists broadcast
variables, whose memory footprint can become significant for large kernels.
It also avoids running a separate Spark job to compute the diagonal block for
the training data.
shivaram committed Jun 15, 2016
1 parent b623115 commit efe7374
Showing 4 changed files with 90 additions and 47 deletions.
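
A minimal sketch of how a caller is expected to drive the refactored KernelMatrix
API, mirroring the block-coordinate loop in KernelRidgeRegression below. The
solveBlockwise name, nTrain, blockSize, and the solve step are illustrative
placeholders and not part of this commit; apply, diagBlock, and unpersist are the
methods added or changed here.

import breeze.linalg.DenseMatrix
import org.apache.spark.rdd.RDD
import nodes.learning.KernelMatrix

def solveBlockwise(trainKernelMat: KernelMatrix, nTrain: Int, blockSize: Int): Unit = {
  val numBlocks = math.ceil(nTrain.toDouble / blockSize).toInt
  (0 until numBlocks).foreach { block =>
    val blockIdxs = (block * blockSize) until math.min((block + 1) * blockSize, nTrain)

    // n x blockIdxs.size column block; returned cached, the caller manages its lifecycle
    val kernelBlock: RDD[DenseMatrix[Double]] = trainKernelMat(blockIdxs)

    // blockIdxs.size x blockIdxs.size diagonal block as a local matrix; for training
    // data this is now computed locally instead of running a separate job
    val kernelBlockBlock: DenseMatrix[Double] = trainKernelMat.diagBlock(blockIdxs)

    // ... update the model for this block using kernelBlock and kernelBlockBlock ...

    // Release the cached column block at the end of the block iteration
    trainKernelMat.unpersist(blockIdxs)
  }
}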
18 changes: 9 additions & 9 deletions src/main/scala/nodes/learning/KernelBlockLinearMapper.scala
@@ -61,20 +61,20 @@ class KernelBlockLinearMapper[T: ClassTag](
pred :+ (testKernelBB * modelBlockBC.value)
}

predictionsNew.cache()
predictionsNew.count()
predictions.unpersist(true)

testKernelMat.unpersist(blockIdxs.toSeq)
modelBlockBC.unpersist(true)

// If we are checkpointing, update our cache
if (in.context.getCheckpointDir.isDefined &&
block % blocksBeforeCheckpoint == (blocksBeforeCheckpoint - 1)) {
predictionsNew = MatrixUtils.truncateLineage(predictionsNew, true)
predictionsNew.count

predictions.unpersist(true)
predictions = predictionsNew
} else {
predictions = predictionsNew
predictionsNew = MatrixUtils.truncateLineage(predictionsNew, false)
}
predictions = predictionsNew
}
// TODO: Do we need to cache and count the predictions if we want to unpersist
// the broadcast variables here?
predictions.flatMap(x => MatrixUtils.matrixToRowArray(x))
}
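
The prediction loop above rotates its cached RDDs: the new per-block predictions
are cached and materialized before the previous ones are unpersisted, so at most
one copy of the predictions stays cached at a time. A standalone sketch of that
pattern, assuming an illustrative rotate helper that is not part of the patch:

import breeze.linalg.DenseMatrix
import org.apache.spark.rdd.RDD

def rotate(
    predictions: RDD[DenseMatrix[Double]],
    update: RDD[DenseMatrix[Double]] => RDD[DenseMatrix[Double]]): RDD[DenseMatrix[Double]] = {
  val predictionsNew = update(predictions)
  predictionsNew.cache()
  // Force evaluation while the previous RDD is still available, then drop it.
  predictionsNew.count()
  predictions.unpersist(true)
  predictionsNew
}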

54 changes: 37 additions & 17 deletions src/main/scala/nodes/learning/KernelGenerator.scala
@@ -67,7 +67,11 @@ trait KernelTransformer[T] {
def apply(data: T): DenseVector[Double]

// Internal function used to lazily populate the kernel matrix.
private[learning] def computeKernel(data: RDD[T], idxs: Seq[Int]): RDD[DenseVector[Double]]
// NOTE: This function returns a *cached* RDD and the caller is responsible
// for managing its life cycle.
private[learning] def computeKernel(
data: RDD[T],
idxs: Seq[Int]): (RDD[DenseVector[Double]], DenseMatrix[Double])
}

class GaussianKernelTransformer(
@@ -110,7 +114,7 @@ class GaussianKernelTransformer(
def computeKernel(
data: RDD[DenseVector[Double]],
blockIdxs: Seq[Int])
: RDD[DenseVector[Double]] = {
: (RDD[DenseVector[Double]], DenseMatrix[Double]) = {

// Dot product of rows of X with each other
val dataDotProd = if (data.id == trainData.id) {
@@ -122,44 +126,60 @@
// Extract a b x d block of training data
val blockIdxSet = blockIdxs.toSet

val blockDataArray = trainData.zipWithIndex.filter { case (vec, idx) =>
val trainBlockArray = trainData.zipWithIndex.filter { case (vec, idx) =>
blockIdxSet.contains(idx.toInt)
}.map(x => x._1).collect()

val blockData = MatrixUtils.rowsToMatrix(blockDataArray)
assert(blockData.rows == blockIdxs.length)
val blockDataBC = data.context.broadcast(blockData)
val trainBlock = MatrixUtils.rowsToMatrix(trainBlockArray)
assert(trainBlock.rows == blockIdxs.length)
val trainBlockBC = data.context.broadcast(trainBlock)

// <xi,xj> for i in [nTest], j in blockIdxs
val blockXXT = data.mapPartitions { itr =>
val bd = blockDataBC.value
val bd = trainBlockBC.value
val vecMat = MatrixUtils.rowsToMatrix(itr)
Iterator.single(vecMat*bd.t)
Iterator.single(vecMat * bd.t)
}

val trainBlockDotProd = DenseVector(trainDotProd.zipWithIndex.filter { case (vec, idx) =>
blockIdxSet.contains(idx.toInt)
}.map(x => x._1).collect())
val trainBlockDotProdBC = data.context.broadcast(trainBlockDotProd)

val kBlock = blockXXT.zipPartitions(dataDotProd) { case (iterXXT, iterTestDotProds) =>
val kBlock = blockXXT.zipPartitions(dataDotProd) { case (iterXXT, iterDataDotProds) =>
val xxt = iterXXT.next()
assert(iterXXT.isEmpty)
iterTestDotProds.zipWithIndex.map { case (testDotProd, idx) =>
iterDataDotProds.zipWithIndex.map { case (dataDotProdVal, idx) =>
val term1 = xxt(idx, ::).t * (-2.0)
val term2 = DenseVector.fill(xxt.cols){testDotProd}
val term2 = DenseVector.fill(xxt.cols)(dataDotProdVal)
val term3 = trainBlockDotProdBC.value
val term4 = (term1 + term2 + term3) * (-gamma)
exp(term4)
}
}

// TODO: We can't unpersist these broadcast variables until we have cached the
// kernel block matrices. Should we introduce a new cleanup method here?
kBlock.cache()
kBlock.count

trainBlockBC.unpersist(true)
trainBlockDotProdBC.unpersist(true)

val diagBlock = if (data.id == trainData.id) {
// For training data, use the locally available block to compute the diagonal block
val kBlockBlock = trainBlock * trainBlock.t
kBlockBlock :*= (-2.0)
kBlockBlock(::, *) :+= trainBlockDotProd
kBlockBlock(*, ::) :+= trainBlockDotProd
kBlockBlock *= -gamma
exp.inPlace(kBlockBlock)
kBlockBlock
} else {
// For test data, extract the diagonal block from the cached kernel block
MatrixUtils.rowsToMatrix(kBlock.zipWithIndex.filter { case (vec, idx) =>
blockIdxSet.contains(idx.toInt)
}.map(x => x._1).collect())
}

// blockDataBC.unpersist()
// trainBlockDotProdBC.unpersist()
kBlock
(kBlock, diagBlock)
}
}
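
The locally computed training diagonal block above uses the expansion
||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2<x_i, x_j>, the same identity as the
distributed column-block computation. A standalone Breeze sketch of that local
computation, with an illustrative gaussianKernelBlock helper that is not part of
the patch:

import breeze.linalg._
import breeze.numerics.exp

def gaussianKernelBlock(trainBlock: DenseMatrix[Double], gamma: Double): DenseMatrix[Double] = {
  // Squared norms ||x_i||^2 of the rows (data points) in this block
  val rowNormsSq = DenseVector.tabulate(trainBlock.rows) { i =>
    val row = trainBlock(i, ::).t
    row dot row
  }

  // Start from -2 <x_i, x_j> for all pairs of points in the block
  val kBlockBlock = trainBlock * trainBlock.t
  kBlockBlock :*= -2.0
  // Add the row and column squared norms to obtain ||x_i - x_j||^2
  kBlockBlock(::, *) :+= rowNormsSq
  kBlockBlock(*, ::) :+= rowNormsSq
  // Scale by -gamma and exponentiate in place
  kBlockBlock *= -gamma
  exp.inPlace(kBlockBlock)
  kBlockBlock
}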
49 changes: 42 additions & 7 deletions src/main/scala/nodes/learning/KernelMatrix.scala
@@ -18,11 +18,28 @@ trait KernelMatrix {

/**
* Extract specified columns from the kernel matrix.
* NOTE: This returns a *cached* RDD and unpersist should
* be called at the end of a block.
*
* @param colIdxs the column indexes to extract
* @return A sub-matrix of size n x colIdxs.size as an RDD.
*/
def apply(colIdxs: Seq[Int]): RDD[DenseMatrix[Double]]

/**
* Extract a diagonal block from the kernel matrix.
*
* @param idxs the row and column indexes to extract
* @return A local matrix of size idxs.size x idxs.size
*/
def diagBlock(idxs: Seq[Int]): DenseMatrix[Double]

/**
* Clean up resources associated with a kernel block.
*
* @param colIdxs column indexes corresponding to the block.
*/
def unpersist(colIdxs: Seq[Int]): Unit
}

/**
@@ -36,18 +53,36 @@ class BlockKernelMatrix[T: ClassTag](
val cacheKernel: Boolean)
extends KernelMatrix {

val cache = HashMap.empty[Seq[Int], RDD[DenseMatrix[Double]]]
val colBlockCache = HashMap.empty[Seq[Int], RDD[DenseVector[Double]]]
val diagBlockCache = HashMap.empty[Seq[Int], DenseMatrix[Double]]

def apply(colIdxs: Seq[Int]): RDD[DenseMatrix[Double]] = {
if (cache.contains(colIdxs)) {
cache(colIdxs)
if (colBlockCache.contains(colIdxs)) {
MatrixUtils.rowsToMatrix(colBlockCache(colIdxs))
} else {
val kBlock = MatrixUtils.rowsToMatrix(kernelGen.computeKernel(data, colIdxs))
val (kBlock, diagBlock) = kernelGen.computeKernel(data, colIdxs)
if (cacheKernel) {
colBlockCache += (colIdxs -> kBlock)
diagBlockCache += (colIdxs -> diagBlock)
}
MatrixUtils.rowsToMatrix(kBlock)
}
}

def unpersist(colIdxs: Seq[Int]): Unit = {
if (colBlockCache.contains(colIdxs) && !cacheKernel) {
colBlockCache(colIdxs).unpersist(true)
}
}

def diagBlock(idxs: Seq[Int]): DenseMatrix[Double] = {
if (!diagBlockCache.contains(idxs)) {
val (kBlock, diagBlock) = kernelGen.computeKernel(data, idxs)
if (cacheKernel) {
kBlock.cache()
cache += (colIdxs -> kBlock)
colBlockCache += (idxs -> kBlock)
diagBlockCache += (idxs -> diagBlock)
}
kBlock
}
diagBlockCache(idxs)
}
}
16 changes: 2 additions & 14 deletions src/main/scala/nodes/learning/KernelRidgeRegression.scala
@@ -145,16 +145,7 @@ object KernelRidgeRegression extends Logging {
val blockIdxsBC = labelsMat.context.broadcast(blockIdxsSeq)

val kernelBlockMat = trainKernelMat(blockIdxsSeq)

// If the kernel block is not already cached, cache it.
// This helps us compute the kernel only once per block.
val kernelCached = kernelBlockMat.getStorageLevel.useMemory
if (!kernelCached) {
kernelBlockMat.cache()
}

val kernelBlockRPM = RowPartitionedMatrix.fromMatrix(kernelBlockMat)
val kernelBlockBlockMat = kernelBlockRPM(blockIdxs, ::).collect()
val kernelBlockBlockMat = trainKernelMat.diagBlock(blockIdxsSeq)

val kernelGenEnd = System.nanoTime

@@ -229,10 +220,7 @@
s"localSolve: ${(localSolveEnd - collectEnd)/1e9} " +
s"modelUpdate: ${(updateEnd - localSolveEnd)/1e9}")

// If we cached the kernel in KRR then unpersist it
if (!kernelCached) {
kernelBlockMat.unpersist()
}
trainKernelMat.unpersist(blockIdxsSeq)

wBlockBCs.map { case (wBlockOldBC, wBlockNewBC) =>
wBlockOldBC.unpersist(true)
