Skip to content

Commit

Permalink
[SPARK-36956][MLLIB] model prediction in .mllib avoid conversion to b…
Browse files Browse the repository at this point in the history
…reeze vector

### What changes were proposed in this pull request?
model prediction in .mllib avoid conversion to breeze vector

### Why are the changes needed?
avoid unnecessary conversion

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
existing suites

Closes #34221 from zhengruifeng/mllib_model_avoid_breeze_conversion.

Authored-by: Ruifeng Zheng <ruifengz@foxmail.com>
Signed-off-by: Huaxin Gao <huaxin_gao@apple.com>
  • Loading branch information
zhengruifeng authored and huaxingao committed Oct 25, 2021
1 parent f44f9b4 commit d5563f3
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.mllib.classification
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
import org.apache.spark.mllib.classification.impl.GLMClassificationModel
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.linalg.{BLAS, Vector}
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.regression._
Expand Down Expand Up @@ -72,7 +72,7 @@ class SVMModel @Since("1.1.0") (
dataMatrix: Vector,
weightMatrix: Vector,
intercept: Double) = {
val margin = weightMatrix.asBreeze.dot(dataMatrix.asBreeze) + intercept
val margin = BLAS.dot(weightMatrix, dataMatrix) + intercept
threshold match {
case Some(t) => if (margin > t) 1.0 else 0.0
case None => margin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.mllib.regression

import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.linalg.{BLAS, Vector}
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.regression.impl.GLMRegressionModel
Expand All @@ -43,7 +43,7 @@ class LassoModel @Since("1.1.0") (
dataMatrix: Vector,
weightMatrix: Vector,
intercept: Double): Double = {
weightMatrix.asBreeze.dot(dataMatrix.asBreeze) + intercept
BLAS.dot(weightMatrix, dataMatrix) + intercept
}

@Since("1.3.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.mllib.regression

import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.linalg.{BLAS, Vector}
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.regression.impl.GLMRegressionModel
Expand All @@ -43,7 +43,7 @@ class LinearRegressionModel @Since("1.1.0") (
dataMatrix: Vector,
weightMatrix: Vector,
intercept: Double): Double = {
weightMatrix.asBreeze.dot(dataMatrix.asBreeze) + intercept
BLAS.dot(weightMatrix, dataMatrix) + intercept
}

@Since("1.3.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.mllib.regression

import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.linalg.{BLAS, Vector}
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.regression.impl.GLMRegressionModel
Expand All @@ -43,7 +43,7 @@ class RidgeRegressionModel @Since("1.1.0") (
dataMatrix: Vector,
weightMatrix: Vector,
intercept: Double): Double = {
weightMatrix.asBreeze.dot(dataMatrix.asBreeze) + intercept
BLAS.dot(weightMatrix, dataMatrix) + intercept
}

@Since("1.3.0")
Expand Down

0 comments on commit d5563f3

Please sign in to comment.