Skip to content

Commit

Permalink
add doc
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Dec 17, 2014
1 parent 1efaecf commit 3f2d81a
Showing 1 changed file with 40 additions and 12 deletions.
52 changes: 40 additions & 12 deletions mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
Expand Up @@ -25,7 +25,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas}
import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
import org.netlib.util.intW

import org.apache.spark.{HashPartitioner, Partitioner}
import org.apache.spark.{Logging, HashPartitioner, Partitioner}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.rdd.RDD
Expand Down Expand Up @@ -71,6 +71,12 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
val ratingCol = new Param[String](this, "ratingCol", "column name for ratings", Some("rating"))
def getRatingCol: String = get(ratingCol)

/**
* Validates and transforms the input schema.
* @param schema input schema
* @param paramMap extra params
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
assert(schema(map(userCol)).dataType == IntegerType)
Expand All @@ -85,6 +91,9 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
}
}

/**
* Model fitted by ALS.
*/
class ALSModel private[ml] (
override val parent: ALS,
override val fittingParamMap: ParamMap,
Expand Down Expand Up @@ -127,10 +136,13 @@ class ALSModel private[ml] (
}

/** Companion object of [[ALSModel]]; in the portion shown it only declares the `Factor` helper. */
private object ALSModel {

/**
 * Case class to convert factors to SchemaRDDs.
 * Each instance pairs an integer `id` with its `features`, a sequence of floats.
 */
private case class Factor(id: Int, features: Seq[Float])
}

/**
* Alternating least squares (ALS).
*/
class ALS extends Estimator[ALSModel] with ALSParams {

import ALS.Rating
Expand All @@ -154,6 +166,9 @@ class ALS extends Estimator[ALSModel] with ALSParams {
this
}

setMaxIter(20)
setRegParam(1.0)

override def fit(dataset: SchemaRDD, paramMap: ParamMap): ALSModel = {
import dataset.sqlContext._
val map = this.paramMap ++ paramMap
Expand All @@ -176,10 +191,12 @@ class ALS extends Estimator[ALSModel] with ALSParams {
}
}

object ALS {
private object ALS extends Logging {

/** Rating class for better code readability. */
private case class Rating(user: Int, product: Int, rating: Float)

/** Cholesky solver for least square problems. */
private class CholeskySolver(val k: Int) {

val upper = "U"
Expand Down Expand Up @@ -207,6 +224,7 @@ object ALS {
}
}

/** Representing a normal equation (ALS' subproblem). */
private class NormalEquation(val k: Int) extends Serializable {

val triK = k * (k + 1) / 2
Expand Down Expand Up @@ -256,6 +274,9 @@ object ALS {
}
}

/**
* Implementation of the ALS algorithm.
*/
private def train(
ratings: RDD[Rating],
rank: Int = 10,
Expand Down Expand Up @@ -424,6 +445,9 @@ object ALS {
}
}

/**
* Blockifies raw ratings.
*/
private def blockifyRatings(
ratings: RDD[Rating],
srcPart: Partitioner,
Expand Down Expand Up @@ -457,6 +481,10 @@ object ALS {
}.setName("blockRatings")
}

/**
* Builder for blocks of (srcId, dstEncodedIndex, rating) tuples.
* @param encoder encoder for dst indices
*/
private class UncompressedBlockBuilder(encoder: LocalIndexEncoder) {

val srcIds = mutable.ArrayBuilder.make[Int]
Expand Down Expand Up @@ -486,6 +514,9 @@ object ALS {
}
}

/**
* A block of (srcId, dstEncodedIndex, rating) tuples.
*/
private class UncompressedBlock(
val srcIds: Array[Int],
val dstEncodedIndices: Array[Int],
Expand Down Expand Up @@ -531,17 +562,13 @@ object ALS {
InBlock(uniqueSrcIds, dstPtrs, dstEncodedIndices, ratings)
}

private def timSort(): Unit = {
val sorter = new Sorter(new UncompressedBlockSort)
sorter.sort(this, 0, size, Ordering[IntWrapper])
}

private def sort(): Unit = {
val sz = size
println("size: " + sz)
logDebug(s"Sorting uncompressed block of size $sz.")
val start = System.nanoTime()
timSort()
println("sort uncompressed time: " + (System.nanoTime() - start) / 1e9)
val sorter = new Sorter(new UncompressedBlockSort)
sorter.sort(this, 0, size, Ordering[IntWrapper])
logDebug("Sorting took " + (System.nanoTime() - start) / 1e9 + " seconds.")
}
}

Expand Down Expand Up @@ -643,7 +670,8 @@ object ALS {
dstIdToLocalIndex.update(sortedDstIds(i), i)
i += 1
}
println("convert to local indices time: " + (System.nanoTime() - start) / 1e9)
// Report the elapsed time in seconds; note the space before "seconds" so the
// number and the unit are not fused together in the log output.
logDebug(
s"Converting to local indices took ${(System.nanoTime() - start) / 1e9} seconds.")
val dstLocalIndices = dstIds.map(dstIdToLocalIndex.apply)
(srcBlockId, (dstBlockId, srcIds, dstLocalIndices, ratings))
}.groupByKey(new HashPartitioner(srcPart.numPartitions))
Expand Down

0 comments on commit 3f2d81a

Please sign in to comment.