Get e2e tests working
mhamilton723 committed Jul 5, 2019
1 parent 07316a8 commit 7c5e7b6
Showing 32 changed files with 197 additions and 158 deletions.
91 changes: 64 additions & 27 deletions build.sbt
@@ -5,23 +5,9 @@ import org.apache.commons.io.FileUtils

import scala.sys.process.Process

def getVersion(baseVersion: String): String = {
sys.env.get("MMLSPARK_RELEASE").map(_ =>baseVersion)
.orElse(sys.env.get("BUILD_NUMBER").map(bn => baseVersion + s"_$bn"))
.getOrElse(baseVersion + "-SNAPSHOT")
}

def getPythonVersion(baseVersion: String): String = {
sys.env.get("MMLSPARK_RELEASE").map(_ =>baseVersion)
.orElse(sys.env.get("BUILD_NUMBER").map(bn => baseVersion + s".dev$bn"))
.getOrElse(baseVersion + ".dev1")
}
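
For context, the two deleted helpers derived the published version from CI environment variables. A sketch of the removed behavior, using the baseVersion below and an assumed BUILD_NUMBER of 42:

// MMLSPARK_RELEASE set        -> getVersion == "0.17.1",          getPythonVersion == "0.17.1"
// BUILD_NUMBER=42, no release -> getVersion == "0.17.1_42",       getPythonVersion == "0.17.1.dev42"
// neither variable set        -> getVersion == "0.17.1-SNAPSHOT", getPythonVersion == "0.17.1.dev1"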

val baseVersion = "0.17.1"
val condaEnvName = "mmlspark"
name := "mmlspark"
organization := "com.microsoft.ml.spark"
version := getVersion(baseVersion)
scalaVersion := "2.11.12"

val sparkVersion = "2.4.0"
@@ -53,11 +39,11 @@ createCondaEnvTask := {
val s = streams.value
val hasEnv = Process("conda env list").lineStream.toList
.map(_.split("\\s+").head).contains(condaEnvName)
if (!hasEnv){
if (!hasEnv) {
Process(
"conda env create -f environment.yaml",
new File(".")) ! s.log
} else{
} else {
println("Found conda env " + condaEnvName)
}
}
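
The check above assumes "conda env list" prints one environment per line with the name as the first whitespace-delimited token; a minimal sketch, with an assumed sample line:

// assumed output line from "conda env list":
val line = "mmlspark              /home/user/miniconda3/envs/mmlspark"
line.split("\\s+").head // == "mmlspark"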
@@ -70,10 +56,18 @@ cleanCondaEnvTask := {
new File(".")) ! s.log
}

def osPrefix: Seq[String] = {
if (sys.props("os.name").toLowerCase.contains("windows")) {
Seq("cmd", "/C")
} else {
Seq()
}
}
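
Prepending osPrefix lets a single Process invocation work on both platforms: Windows commands are routed through the cmd shell, while Unix commands run unmodified. A minimal sketch with an arbitrary command:

import scala.sys.process.Process
// Windows: Seq("cmd", "/C", "conda", "--version"); Linux/macOS: Seq("conda", "--version")
val condaVersion = Process(osPrefix ++ Seq("conda", "--version")).!!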

def activateCondaEnv: Seq[String] = {
if(sys.props("os.name").toLowerCase.contains("windows")){
Seq("cmd", "/C", "activate", condaEnvName, "&&")
}else{
if (sys.props("os.name").toLowerCase.contains("windows")) {
osPrefix ++ Seq("activate", condaEnvName, "&&")
} else {
Seq()
//TODO figure out why this doesn't work
//Seq("/bin/bash", "-l", "-c", "source activate " + condaEnvName, "&&")
@@ -86,15 +80,27 @@ val pythonSrcDir = join(genDir.toString, "src", "python")
val pythonPackageDir = join(genDir.toString, "package", "python")
val pythonTestDir = join(genDir.toString, "test", "python")

def pythonizeVersion(v: String): String = {
if (v.contains("+")){
v.split("+".head).head + ".dev1"
}else{
v
}
}
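
pythonizeVersion strips Maven-style build metadata so the result is a valid PEP 440 dev version; note that "+".head is the literal character '+', so the split is on that character rather than on a regex. Worked examples with assumed inputs:

pythonizeVersion("0.17.1+5-abc1234") // == "0.17.1.dev1"
pythonizeVersion("0.17.1")           // == "0.17.1"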

packagePythonTask := {
val s = streams.value
(run in IntegrationTest2).toTask("").value
createCondaEnvTask.value
val destPyDir = join("target", "scala-2.11", "classes", "mmlspark")
if (destPyDir.exists()) FileUtils.forceDelete(destPyDir)
FileUtils.copyDirectory(join(pythonSrcDir.getAbsolutePath, "mmlspark"), destPyDir)

Process(
activateCondaEnv ++
Seq(s"python", "setup.py", "bdist_wheel", "--universal", "-d", s"${pythonPackageDir.absolutePath}"),
pythonSrcDir,
"MML_PY_VERSION" -> getPythonVersion(baseVersion)) ! s.log
"MML_PY_VERSION" -> pythonizeVersion(version.value)) ! s.log
}
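
Because the wheel is built with --universal, setuptools emits a single py2.py3 artifact with a predictable name, which installPipPackage below relies on. For example, assuming version 0.17.1:

// expected artifact in pythonPackageDir:
// mmlspark-0.17.1-py2.py3-none-any.whl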

val installPipPackageTask = TaskKey[Unit]("installPipPackage", "install python sdk")
@@ -105,7 +111,7 @@ installPipPackageTask := {
packagePythonTask.value
Process(
activateCondaEnv ++ Seq("pip", "install",
s"mmlspark-${getPythonVersion(baseVersion)}-py2.py3-none-any.whl"),
s"mmlspark-${pythonizeVersion(version.value)}-py2.py3-none-any.whl"),
pythonPackageDir) ! s.log
}

@@ -117,7 +123,7 @@ testPythonTask := {
Process(
activateCondaEnv ++ Seq("python", "tools2/run_all_tests.py"),
new File("."),
"MML_VERSION" -> getVersion(baseVersion)
"MML_VERSION" -> version.value
) ! s.log
}

@@ -147,10 +153,36 @@ setupTask := {
getDatasetsTask.value
}

val publishBlob = TaskKey[Unit]("publishBlob", "publish the library to mmlspark blob")
publishBlob := {
val s = streams.value
publishM2.value
val scalaVersionSuffix = scalaVersion.value.split(".".toCharArray.head).dropRight(1).mkString(".")
val nameAndScalaVersion = s"${name.value}_$scalaVersionSuffix"

val localPackageFolder = join(
Seq(new File(new URI(Resolver.mavenLocal.root)).getAbsolutePath)
++ organization.value.split(".".toCharArray.head)
++ Seq(nameAndScalaVersion, version.value): _*).toString

val blobMavenFolder = organization.value.replace(".", "/") +
s"/$nameAndScalaVersion/${version.value}"
val command = Seq("az", "storage", "blob", "upload-batch",
"--source", localPackageFolder,
"--destination", "maven",
"--destination-path", blobMavenFolder,
"--account-name", "mmlspark",
"--account-key", Secrets.storageKey)
println(command.mkString(" "))
Process(osPrefix ++ command) ! s.log
}
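
publishBlob mirrors the local Maven layout into the "maven" blob container. A sketch of how the pieces resolve, assuming version 0.17.1 and the scalaVersion 2.11.12 set above:

// scalaVersionSuffix:  "2.11.12".split('.').dropRight(1).mkString(".") == "2.11"
// nameAndScalaVersion: "mmlspark_2.11"
// blobMavenFolder:     "com/microsoft/ml/spark/mmlspark_2.11/0.17.1"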

val settings = Seq(
(scalastyleConfig in Test) := baseDirectory.value / "scalastyle-test-config.xml",
logBuffered in Test := false,
buildInfoKeys := Seq[BuildInfoKey](name, version, scalaVersion, sbtVersion, baseDirectory, datasetDir),
buildInfoKeys := Seq[BuildInfoKey](
name, version, scalaVersion, sbtVersion,
baseDirectory, datasetDir),
parallelExecution in Test := false,
buildInfoPackage := "com.microsoft.ml.spark.build") ++
inConfig(IntegrationTest2)(Defaults.testSettings)
@@ -180,20 +212,25 @@ credentials += Credentials("Sonatype Nexus Repository Manager",
pgpPassphrase := Some(Secrets.pgpPassword.toCharArray)
pgpSecretRing := {
val temp = File.createTempFile("secret", ".asc")
new PrintWriter(temp) { write(Secrets.pgpPrivate); close() }
new PrintWriter(temp) {
write(Secrets.pgpPrivate); close()
}
temp
}
pgpPublicRing := {
val temp = File.createTempFile("public", ".asc")
new PrintWriter(temp) { write(Secrets.pgpPublic); close() }
new PrintWriter(temp) {
write(Secrets.pgpPublic); close()
}
temp
}

licenses += ("MIT", url("https://github.com/Azure/mmlspark/blob/master/LICENSE"))
publishMavenStyle := true
publishTo := Some(
if (isSnapshot.value)
if (isSnapshot.value) {
Opts.resolver.sonatypeSnapshots
else
} else {
Opts.resolver.sonatypeStaging
}
)
11 changes: 5 additions & 6 deletions notebooks/samples/AzureSearchIndex - Met Artworks.ipynb
@@ -22,12 +22,7 @@
},
"outputs": [],
"source": [
"import numpy as np, pandas as pd, os, sys, time, json, requests\n",
"\n",
"from mmlspark import *\n",
"from pyspark.ml.classification import LogisticRegression\n",
"from pyspark.sql.functions import udf, col\n",
"from pyspark.sql.types import IntegerType, StringType, DoubleType, StructType, StructField, ArrayType\n",
"import os, sys, time, json, requests\n",
"from pyspark.ml import Transformer, Estimator, Pipeline\n",
"from pyspark.ml.feature import SQLTransformer\n",
"from pyspark.sql.functions import lit, udf, col, split"
@@ -80,6 +75,9 @@
},
"outputs": [],
"source": [
"from mmlspark.cognitive import DescribeImage\n",
"from mmlspark.stages import SelectColumns\n",
"\n",
"#define pipeline\n",
"describeImage = DescribeImage()\\\n",
" .setSubscriptionKey(VISION_API_KEY)\\\n",
@@ -191,6 +189,7 @@
},
"outputs": [],
"source": [
"from mmlspark.io.http import *\n",
"data_processed.writeToAzureSearch(options)"
]
},
9 changes: 3 additions & 6 deletions notebooks/samples/Classification - Adult Census.ipynb
@@ -18,10 +18,7 @@
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import mmlspark\n",
"\n",
"# help(mmlspark)"
"import pandas as pd"
]
},
{
@@ -65,7 +62,7 @@
"metadata": {},
"outputs": [],
"source": [
"from mmlspark import TrainClassifier\n",
"from mmlspark.train import TrainClassifier\n",
"from pyspark.ml.classification import LogisticRegression\n",
"model = TrainClassifier(model=LogisticRegression(), labelCol=\" income\", numFeatures=256).fit(train)\n",
"model.write().overwrite().save(\"adultCensusIncomeModel.mml\")"
@@ -84,7 +81,7 @@
"metadata": {},
"outputs": [],
"source": [
"from mmlspark import ComputeModelStatistics, TrainedClassifierModel\n",
"from mmlspark.train import ComputeModelStatistics, TrainedClassifierModel\n",
"predictionModel = TrainedClassifierModel.load(\"adultCensusIncomeModel.mml\")\n",
"prediction = predictionModel.transform(test)\n",
"metrics = ComputeModelStatistics().transform(prediction)\n",
@@ -42,13 +42,13 @@
"outputs": [],
"source": [
"import pandas as pd\n",
"import mmlspark\n",
"from pyspark.sql.types import IntegerType, StringType, StructType, StructField\n",
"import os, urllib\n",
"\n",
"dataFilePath = \"BookReviewsFromAmazon10K.tsv\"\n",
"textSchema = StructType([StructField(\"rating\", IntegerType(), False),\n",
" StructField(\"text\", StringType(), False)])\n",
"import os, urllib\n",
"\n",
"if not os.path.isfile(dataFilePath):\n",
" urllib.request.urlretrieve(\"https://mmlspark.azureedge.net/datasets/\" + dataFilePath, dataFilePath)\n",
"rawData = spark.createDataFrame(pd.read_csv(dataFilePath, sep=\"\\t\", header=None), textSchema)\n",
@@ -92,7 +92,7 @@
"metadata": {},
"outputs": [],
"source": [
"from mmlspark import UDFTransformer\n",
"from mmlspark.stages import UDFTransformer\n",
"wordLength = \"wordLength\"\n",
"wordCount = \"wordCount\"\n",
"wordLengthTransformer = UDFTransformer(inputCol=\"text\", outputCol=wordLength, udf=wordLengthUDF)\n",
@@ -208,7 +208,7 @@
"bestModel = models[metrics.index(bestMetric)]\n",
"\n",
"# Save model\n",
"bestModel.write().overwrite().save(\"SparkMLExperiment.mmls\")\n",
"bestModel.write().overwrite().save(\"SparkMLExperiment.mml\")\n",
"# Get AUC on the validation dataset\n",
"scoredVal = bestModel.transform(validation)\n",
"print(evaluator.evaluate(scoredVal))"
@@ -241,7 +241,8 @@
"metadata": {},
"outputs": [],
"source": [
"from mmlspark import TrainClassifier, FindBestModel, ComputeModelStatistics\n",
"from mmlspark.train import TrainClassifier, ComputeModelStatistics\n",
"from mmlspark.automl import FindBestModel\n",
"\n",
"# Prepare data for learning\n",
"train, test, validation = data.randomSplit([0.60, 0.20, 0.20], seed=123)\n",
Expand All @@ -257,7 +258,7 @@
"bestModel = FindBestModel(evaluationMetric=\"AUC\", models=lrmodels).fit(test)\n",
"\n",
"# Save model\n",
"bestModel.write().overwrite().save(\"MMLSExperiment.mmls\")\n",
"bestModel.write().overwrite().save(\"MMLSExperiment.mml\")\n",
"# Get AUC on the validation dataset\n",
"predictions = bestModel.transform(validation)\n",
"metrics = ComputeModelStatistics().transform(predictions)\n",
@@ -22,7 +22,7 @@
},
"outputs": [],
"source": [
"from mmlspark import *\n",
"from mmlspark.cognitive import *\n",
"from pyspark.ml import PipelineModel\n",
"from pyspark.sql.functions import col, udf\n",
"from pyspark.ml.feature import SQLTransformer\n",
@@ -115,6 +115,8 @@
},
"outputs": [],
"source": [
"from mmlspark.stages import UDFTransformer \n",
"\n",
"recognizeText = RecognizeText()\\\n",
" .setSubscriptionKey(VISION_API_KEY)\\\n",
" .setUrl(\"https://eastus.api.cognitive.microsoft.com/vision/v2.0/recognizeText\")\\\n",
@@ -175,6 +177,7 @@
"metadata": {},
"outputs": [],
"source": [
"from mmlspark.stages import SelectColumns\n",
"# Select the final coulmns\n",
"cleanupColumns = SelectColumns().setCols([\"url\", \"firstCeleb\", \"text\", \"sentimentScore\"])\n",
"\n",
@@ -31,14 +31,14 @@
"metadata": {},
"outputs": [],
"source": [
"from mmlspark import CNTKModel, ModelDownloader\n",
"from mmlspark.cntk import CNTKModel\n",
"from mmlspark.downloader import ModelDownloader\n",
"from pyspark.sql.functions import udf, col\n",
"from pyspark.sql.types import IntegerType, ArrayType, FloatType, StringType\n",
"from pyspark.sql import Row\n",
"\n",
"from os.path import abspath, join\n",
"import numpy as np\n",
"import pickle\n",
"from nltk.tokenize import sent_tokenize, word_tokenize\n",
"import os, tarfile, pickle\n",
"import urllib.request\n",
@@ -13,7 +13,8 @@
"metadata": {},
"outputs": [],
"source": [
"from mmlspark import CNTKModel, ModelDownloader\n",
"from mmlspark.cntk import CNTKModel\n",
"from mmlspark.downloader import ModelDownloader\n",
"from pyspark.sql.functions import udf\n",
"from pyspark.sql.types import IntegerType\n",
"from os.path import abspath"
@@ -104,7 +105,6 @@
"metadata": {},
"outputs": [],
"source": [
"import array\n",
"from pyspark.sql.functions import col\n",
"from pyspark.sql.types import *\n",
"\n",
