Commit

build: make python test loop easier
mhamilton723 committed Jul 2, 2020
1 parent 65a13bc commit 0319650
Showing 4 changed files with 39 additions and 16 deletions.
31 changes: 17 additions & 14 deletions build.sbt
@@ -60,8 +60,12 @@ cleanCondaEnvTask := {
     new File(".")) ! s.log
 }
 
+def isWindows: Boolean = {
+  sys.props("os.name").toLowerCase.contains("windows")
+}
+
 def osPrefix: Seq[String] = {
-  if (sys.props("os.name").toLowerCase.contains("windows")) {
+  if (isWindows) {
     Seq("cmd", "/C")
   } else {
     Seq()
@@ -98,6 +102,15 @@ generatePythonDoc := {
 
 }
 
+val pythonizedVersion = settingKey[String]("Pythonized version")
+pythonizedVersion := {
+  if (version.value.contains("-")){
+    version.value.split("-".head).head + ".dev1"
+  }else{
+    version.value
+  }
+}
+
 def uploadToBlob(source: String, dest: String,
                  container: String, log: ManagedLogger,
                  accountName: String="mmlspark"): Int = {
@@ -161,14 +174,6 @@ publishR := {
   singleUploadToBlob(rPackage.toString,rPackage.getName, "rrr", s.log)
 }
 
-def pythonizeVersion(v: String): String = {
-  if (v.contains("-")){
-    v.split("-".head).head + ".dev1"
-  }else{
-    v
-  }
-}
-
 packagePythonTask := {
   val s = streams.value
   (run in IntegrationTest2).toTask("").value
@@ -180,8 +185,7 @@ packagePythonTask := {
   Process(
     activateCondaEnv ++
       Seq(s"python", "setup.py", "bdist_wheel", "--universal", "-d", s"${pythonPackageDir.absolutePath}"),
-    pythonSrcDir,
-    "MML_PY_VERSION" -> pythonizeVersion(version.value)) ! s.log
+    pythonSrcDir) ! s.log
 }
 
 val installPipPackageTask = TaskKey[Unit]("installPipPackage", "install python sdk")
@@ -192,7 +196,7 @@ installPipPackageTask := {
   packagePythonTask.value
   Process(
     activateCondaEnv ++ Seq("pip", "install",
-      s"mmlspark-${pythonizeVersion(version.value)}-py2.py3-none-any.whl"),
+      s"mmlspark-${pythonizedVersion.value}-py2.py3-none-any.whl"),
     pythonPackageDir) ! s.log
 }
 
@@ -211,7 +215,6 @@ testPythonTask := {
       "mmlsparktest"
     ),
     new File("target/scala-2.11/generated/test/python/"),
-    "MML_VERSION" -> version.value
   ) ! s.log
 }

@@ -319,7 +322,7 @@ val settings = Seq(
   logBuffered in Test := false,
   buildInfoKeys := Seq[BuildInfoKey](
     name, version, scalaVersion, sbtVersion,
-    baseDirectory, datasetDir),
+    baseDirectory, datasetDir, pythonizedVersion),
   parallelExecution in Test := false,
   test in assembly := {},
   assemblyMergeStrategy in assembly := {
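Two things are worth noting about the build.sbt changes above. First, the version logic moves from the removed pythonizeVersion helper into a pythonizedVersion setting, which lets sbt-buildinfo export it (see the buildInfoKeys hunk). The mapping itself is unchanged: Maven-style suffixes such as -rc1 are not valid PEP 440 versions, so everything after the first "-" collapses to a ".dev1" tag. A minimal Python sketch of that mapping (example versions are illustrative, not from the commit):

    def pythonize_version(v: str) -> str:
        # Anything after the first "-" (e.g. "-rc1") is not PEP 440;
        # replace it with a ".dev1" pre-release tag.
        if "-" in v:
            return v.split("-")[0] + ".dev1"
        return v

    print(pythonize_version("1.0.0-rc1"))  # 1.0.0.dev1
    print(pythonize_version("1.0.0"))      # 1.0.0

Second, the next hunk is in the Scala Config object that generates the Python package, which now stamps both version strings into the generated mmlspark/__init__.py: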
@@ -59,6 +59,9 @@ object Config {
         |CNTK library, images, and text.
         |"\""
         |
+        |__version__ = "${BuildInfo.pythonizedVersion}"
+        |__spark_package_version__ = "${BuildInfo.version}"
+        |
         |$importString
         |""".stripMargin
 }
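With pythonizedVersion added to buildInfoKeys, both BuildInfo.pythonizedVersion and BuildInfo.version are available at code-generation time, so the generated __init__.py carries two identifiers. Roughly what the generated header looks like (the values here are illustrative; the real ones are interpolated at build time):

    __version__ = "1.0.0.dev1"               # PEP 440 wheel version (pythonizedVersion)
    __spark_package_version__ = "1.0.0-rc1"  # raw sbt version, used as the Maven coordinate

This is what lets setup.py and the tests below drop the MML_PY_VERSION and MML_VERSION environment variables.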
18 changes: 17 additions & 1 deletion src/main/python/setup.py
@@ -3,10 +3,26 @@
 
 import os
 from setuptools import setup, find_packages
+import codecs
+import os.path
 
+def read(rel_path):
+    here = os.path.abspath(os.path.dirname(__file__))
+    with codecs.open(os.path.join(here, rel_path), 'r') as fp:
+        return fp.read()
+
+def get_version(rel_path):
+    for line in read(rel_path).splitlines():
+        if line.startswith('__version__'):
+            delim = '"' if '"' in line else "'"
+            return line.split(delim)[1]
+    else:
+        raise RuntimeError("Unable to find version string.")
+
+
 setup(
     name="mmlspark",
-    version=os.environ["MML_PY_VERSION"],
+    version=get_version("mmlspark/__init__.py"),
     description="Microsoft ML for Spark",
     long_description="Microsoft ML for Apache Spark contains Microsoft's open source " +
                      "contributions to the Apache Spark ecosystem",
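setup.py now follows the standard single-source-of-version pattern: parse __version__ out of the package's own __init__.py instead of reading an environment variable, so "python setup.py bdist_wheel" works in a bare shell. The delimiter sniffing in get_version handles either quote style; a quick standalone check (the sample line is invented for illustration):

    line = "__version__ = '1.0.0.dev1'"
    delim = '"' if '"' in line else "'"   # pick whichever quote the line uses
    print(line.split(delim)[1])           # 1.0.0.dev1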
3 changes: 2 additions & 1 deletion src/test/python/mmlsparktest/spark.py
@@ -3,11 +3,12 @@
 
 from pyspark.sql import SparkSession, SQLContext
 import os
+import mmlspark
 
 spark = SparkSession.builder \
     .master("local[*]") \
     .appName("PysparkTests") \
-    .config("spark.jars.packages", "com.microsoft.ml.spark:mmlspark_2.11:" + os.environ["MML_VERSION"]) \
+    .config("spark.jars.packages", "com.microsoft.ml.spark:mmlspark_2.11:" + mmlspark.__spark_package_version__) \
     .config("spark.executor.heartbeatInterval", "60s") \
     .config("spark.sql.shuffle.partitions", 10) \
     .config("spark.sql.crossJoin.enabled", "true") \
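The test bootstrap uses the same mechanism: instead of requiring MML_VERSION in the environment, it imports the installed mmlspark wheel and builds the spark.jars.packages coordinate from it. A sketch of the resolution, assuming the wheel is installed (version value illustrative):

    import mmlspark
    coord = "com.microsoft.ml.spark:mmlspark_2.11:" + mmlspark.__spark_package_version__
    print(coord)  # e.g. com.microsoft.ml.spark:mmlspark_2.11:1.0.0-rc1

So the full loop becomes: sbt packages and pip-installs the wheel, and pytest just imports it, with no exported variables needed. That is the "easier test loop" of the commit title.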
