Get python codegen to work
mhamilton723 committed Jul 5, 2019
1 parent 90089fa commit 987c7c4
Showing 44 changed files with 426 additions and 317 deletions.
54 changes: 52 additions & 2 deletions build.sbt
@@ -1,4 +1,7 @@
import scala.sys.process.Process

name := "mmlspark"
organization := "com.microsoft.ml.spark"
version := "0.17.1"
scalaVersion := "2.11.12"

@@ -22,10 +25,57 @@ libraryDependencies ++= Seq(

lazy val IntegrationTest2 = config("it").extend(Test)

val settings = Seq((scalastyleConfig in Test) := baseDirectory.value / "scalastyle-test-config.xml") ++
  inConfig(IntegrationTest2)(Defaults.testSettings)
lazy val CodeGen = config("codegen").extend(Test)

val settings = Seq(
  (scalastyleConfig in Test) := baseDirectory.value / "scalastyle-test-config.xml",
  buildInfoKeys := Seq[BuildInfoKey](name, version, scalaVersion, sbtVersion, baseDirectory),
  buildInfoPackage := "com.microsoft.ml.spark.build") ++
  inConfig(IntegrationTest2)(Defaults.testSettings) ++
  inConfig(CodeGen)(Defaults.testSettings)

lazy val mmlspark = (project in file("."))
  .configs(IntegrationTest2)
  .configs(CodeGen)
  .enablePlugins(BuildInfoPlugin)
  .enablePlugins(ScalaUnidocPlugin)
  .settings(settings: _*)

def join(folders: String*): File = {
  folders.tail.foldLeft(new File(folders.head)) { case (f, s) => new File(f, s) }
}

val packagePythonTask = TaskKey[Unit]("packagePython", "Package python sdk")
val genDir = join("target", "scala-2.11", "generated")
val pythonSrcDir = join(genDir.toString, "src", "python")
val pythonPackageDir = join(genDir.toString, "package", "python")
val pythonTestDir = join(genDir.toString, "test", "python")

packagePythonTask := {
  val s: TaskStreams = streams.value
  (run in CodeGen).toTask("").value
  Process(
    s"python setup.py bdist_wheel --universal -d ${pythonPackageDir.absolutePath}",
    pythonSrcDir,
    "MML_VERSION" -> version.value) ! s.log
}

val installPipPackageTask = TaskKey[Unit]("installPipPackage", "install python sdk")

installPipPackageTask := {
  val s: TaskStreams = streams.value
  packagePythonTask.value
  Process(
    Seq("python", "-m", "wheel", "install", s"mmlspark-${version.value}-py2.py3-none-any.whl", "--force"),
    pythonPackageDir) ! s.log
}

val testPythonTask = TaskKey[Unit]("testPython", "test python sdk")

testPythonTask := {
  val s: TaskStreams = streams.value
  installPipPackageTask.value
  Process(
    Seq("python", "-m", "unittest", "discover"),
    join(pythonTestDir.toString, "mmlspark")) ! s.log
}
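The three tasks chain together: testPython runs installPipPackage, which runs packagePython, which in turn runs the codegen configuration before building the wheel. Each step shells out through scala.sys.process with a working directory and an extra environment variable. The following standalone sketch illustrates that Process pattern; the command and paths are illustrative stand-ins, not part of this commit.

import java.io.File
import scala.sys.process._

object ProcessSketch extends App {
  // Run an external command in a given working directory with an extra environment
  // variable, sending output to the current process -- the same overload packagePython uses.
  val exitCode = Process(
    Seq("python", "--version"),   // stand-in command; the build runs setup.py / wheel / unittest
    new File("."),                // working directory (pythonSrcDir, pythonPackageDir, ... in the build)
    "MML_VERSION" -> "0.17.1"     // extra environment variable, as packagePython passes MML_VERSION
  ).!
  println(s"python exited with $exitCode")
}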
File renamed without changes.
1 change: 1 addition & 0 deletions project/plugins.sbt
@@ -1,2 +1,3 @@
addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "1.0.0")
addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.4.2")
addSbtPlugin("com.eed3si9n" % "sbt-buildinfo" % "0.9.0")
82 changes: 82 additions & 0 deletions src/codegen/scala/com/microsoft/ml/spark/codegen/CodeGen.scala
@@ -0,0 +1,82 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.codegen

import java.io.File

import com.microsoft.ml.spark.codegen.Config._
import com.microsoft.ml.spark.codegen.DocGen._
import com.microsoft.ml.spark.codegen.WrapperClassDoc._
import com.microsoft.ml.spark.core.env.FileUtilities._
import org.apache.commons.io.FileUtils
import org.apache.commons.io.FilenameUtils._

object CodeGen {

  def generateArtifacts(): Unit = {
    println(
      s"""|Running code generation with config:
          |  topDir: $topDir
          |  packageDir: $packageDir
          |  pySrcDir: $pySrcDir
          |  pyTestDir: $pyTestDir
          |  pyDocDir: $pyDocDir
          |  rSrcDir: $rSrcDir
          |  tmpDocDir: $tmpDocDir""".stripMargin)

    println("Creating temp folders")
    if (generatedDir.exists()) FileUtils.forceDelete(generatedDir)

    println("Generating python APIs")
    PySparkWrapperGenerator()
    println("Generating R APIs")
    SparklyRWrapperGenerator(version)
    println("Generating .rst files for the Python APIs documentation")
    genRstFiles()

    def toDir(f: File): File = new File(f, File.separator)

    //writeFile(new File(pySrcDir, "__init__.py"), packageHelp(""))
    FileUtils.copyDirectoryToDirectory(toDir(pySrcOverrideDir), toDir(pySrcDir))
    FileUtils.copyDirectoryToDirectory(toDir(pyTestOverrideDir), toDir(pyTestDir))
    makeInitFiles()

    // build init file
    // package python+r zip files
    // zipFolder(pyDir, pyZipFile)
    rPackageDir.mkdirs()
    zipFolder(rSrcDir, new File(rPackageDir, s"mmlspark-$version.zip"))

    //FileUtils.forceDelete(rDir)
    // leave the python source files, so they will be included in the super-jar
    // FileUtils.forceDelete(pyDir)
    // delete the text files with the Python Class descriptions - truly temporary
    // FileUtils.forceDelete(tmpDocDir)
  }

  private def allTopLevelFiles(dir: File, pred: (File => Boolean) = null): Array[File] = {
    def loop(dir: File): Array[File] = {
      val (dirs, files) = dir.listFiles.sorted.partition(_.isDirectory)
      if (pred == null) files else files.filter(pred)
    }
    loop(dir)
  }

  private def makeInitFiles(packageFolder: String = ""): Unit = {
    val dir = new File(new File(pySrcDir, "mmlspark"), packageFolder)
    val packageString = if (packageFolder != "") packageFolder.replace("/", ".") else ""
    val importStrings =
      allTopLevelFiles(dir, f => "^[a-zA-Z]\\w*[.]py$".r.findFirstIn(f.getName).isDefined)
        .map(f => s"from mmlspark$packageString.${getBaseName(f.getName)} import *\n").mkString("")
    writeFile(new File(dir, "__init__.py"), packageHelp(importStrings))
    dir.listFiles().filter(_.isDirectory).foreach(f =>
      makeInitFiles(packageFolder + "/" + f.getName)
    )
  }

  def main(args: Array[String]): Unit = {
    generateArtifacts()
  }

}
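makeInitFiles above re-exports every generated module from each package's __init__.py, recursing into subfolders. A small standalone sketch, using hypothetical file names, of the filter it applies:

object InitFilterSketch extends App {
  // Only modules whose names start with a letter are re-exported; dunder and
  // underscore-prefixed files are left out of the generated __init__.py.
  val pattern = "^[a-zA-Z]\\w*[.]py$".r
  val names = Seq("ValueIndexer.py", "Utils.py", "_InternalHelper.py", "__init__.py", "notes.txt")
  println(names.filter(n => pattern.findFirstIn(n).isDefined))
  // List(ValueIndexer.py, Utils.py) -> each becomes "from mmlspark.<subpackage>.<Module> import *"
}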
48 changes: 48 additions & 0 deletions src/codegen/scala/com/microsoft/ml/spark/codegen/Config.scala
@@ -0,0 +1,48 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.codegen

import com.microsoft.ml.spark.core.env.FileUtilities._
import com.microsoft.ml.spark.build.BuildInfo

object Config {
  val debugMode = sys.env.getOrElse("DEBUGMODE", "").trim.toLowerCase == "true"

  val topDir = BuildInfo.baseDirectory
  val version = BuildInfo.version
  val targetDir = new File(topDir, "target/scala-2.11")
  val scalaSrcDir = "src/main/scala"

  val generatedDir = new File(targetDir, "generated")
  val packageDir = new File(generatedDir, "package")
  val srcDir = new File(generatedDir, "src")
  val testDir = new File(generatedDir, "test")
  val docDir = new File(generatedDir, "doc")

  // Python codegen constants
  val pySrcDir = new File(srcDir, "python")
  val pyPackageDir = new File(packageDir, "python")
  val pyTestDir = new File(testDir, "python")
  val pyDocDir = new File(docDir, "python")
  val pySrcOverrideDir = new File(topDir, "src/main/python")
  val pyTestOverrideDir = new File(topDir, "src/test/python")
  val tmpDocDir = new File(pyDocDir, "tmpDoc")

  // R codegen constants
  val rSrcDir = new File(srcDir, "R")
  val sparklyRNamespacePath = new File(rSrcDir, "NAMESPACE")
  val rPackageDir = new File(packageDir, "R")
  val rTestDir = new File(testDir, "R")
  val rSrcOverrideDir = new File(topDir, "src/main/R")
  //val rPackageFile = new File(rPackageDir, s"mmlspark-$mmlVer.zip")

  val internalPrefix = "_"
  val scopeDepth = " " * 4

  val copyrightLines =
    s"""|# Copyright (C) Microsoft Corporation. All rights reserved.
        |# Licensed under the MIT License. See LICENSE in project root for information.
        |""".stripMargin

}
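These Config values mirror the genDir layout wired up in build.sbt: everything the generator emits lands under target/scala-2.11/generated. A standalone sketch (not part of the build) that prints that layout, with the relative paths taken from the vals above:

import java.io.File

object GeneratedLayoutSketch extends App {
  val generatedDir = new File("target/scala-2.11/generated")
  Seq(
    "pySrcDir"     -> "src/python",     // generated + overridden python sources
    "pyTestDir"    -> "test/python",    // generated python tests (unittest discover root)
    "pyPackageDir" -> "package/python", // wheel output consumed by installPipPackage
    "pyDocDir"     -> "doc/python",     // .rst files for the python API docs
    "rSrcDir"      -> "src/R",          // generated R sources
    "rPackageDir"  -> "package/R"       // zipped R package
  ).foreach { case (name, rel) => println(f"$name%-13s -> ${new File(generatedDir, rel)}") }
}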
src/codegen/scala/com/microsoft/ml/spark/codegen/DocGen.scala
@@ -41,12 +41,13 @@ object DocGen {
    // Generate a modules.rst file that lists all the .py files to be included in API documentation
    // Find the files to use: Must start with upper case letter, end in .py
    val pattern = "^[A-Z]\\w*[.]py$".r
    val moduleString = allFiles(pyDir, f => pattern.findFirstIn(f.getName).isDefined)
    val moduleString = allFiles(pySrcDir, f => pattern.findFirstIn(f.getName).isDefined)
      .map(f => s" ${getBaseName(f.getName)}\n").mkString("")
    pyDocDir.mkdirs()
    writeFile(new File(pyDocDir, "modules.rst"), rstFileLines(moduleString))

    // Generate .rst file for each PySpark wrapper - for documentation generation
    allFiles(pyDir, f => pattern.findFirstIn(f.getName).isDefined)
    allFiles(pySrcDir, f => pattern.findFirstIn(f.getName).isDefined)
      .foreach { x => writeFile(new File(pyDocDir, getBaseName(x.getName) + ".rst"),
        contentsString(getBaseName(x.getName)))
      }
src/codegen/scala/com/microsoft/ml/spark/codegen/PySparkWrapper.scala
@@ -3,6 +3,8 @@

package com.microsoft.ml.spark.codegen

import java.io.File

import scala.collection.mutable.ListBuffer
import org.apache.commons.lang3.StringUtils
import org.apache.spark.ml.{Estimator, Transformer}
@@ -23,8 +25,8 @@ abstract class PySparkParamsWrapper(entryPoint: Params,

  private val additionalImports = Map(
    ("complexTypes",
      s"from ${pyDir.getName}.TypeConversionUtils import generateTypeConverter, complexTypeConverter"),
    ("utils", s"from ${pyDir.getName}.Utils import *")
      s"from mmlspark.core.schema.TypeConversionUtils import generateTypeConverter, complexTypeConverter"),
    ("utils", s"from mmlspark.core.schema.Utils import *")
  )

  val importClassString = ""
@@ -385,7 +387,13 @@
}

  def writeWrapperToFile(dir: File): Unit = {
    writeFile(new File(dir, entryPointName + ".py"), pysparkWrapperBuilder())
    val packageDir = entryPointQualifiedName
      .replace("com.microsoft.ml.spark", "")
      .split(".".toCharArray.head).dropRight(1)
      .foldLeft(dir) { case (base, folder) => new File(base, folder) }
    packageDir.mkdirs()
    new File(packageDir, "__init__.py").createNewFile()
    writeFile(new File(packageDir, entryPointName + ".py"), pysparkWrapperBuilder())
  }

}
src/codegen/scala/com/microsoft/ml/spark/codegen/PySparkWrapperTest.scala
@@ -3,6 +3,8 @@

package com.microsoft.ml.spark.codegen

import java.io.File

import org.apache.commons.lang3.StringUtils
import org.apache.spark.ml.{Estimator, Transformer}
import org.apache.spark.ml.PipelineStage
@@ -22,8 +24,17 @@ abstract class PySparkWrapperParamsTest(entryPoint: Params,
  // general classes are imported from the mmlspark package directly;
  // internal classes have to be imported from their packages
  private def importClass(entryPointName: String): String = {
    if (entryPointName startsWith internalPrefix) s"from mmlspark.$entryPointName import $entryPointName"
    else s"from mmlspark import $entryPointName"
    val packageString = if (subPackages.isEmpty) {
      "mmlspark"
    } else {
      "mmlspark." + subPackages.mkString(".")
    }

    if (entryPointName startsWith internalPrefix) {
      s"from $packageString.$entryPointName import $entryPointName"
    } else {
      s"from $packageString import $entryPointName"
    }
  }

protected def classTemplate(classParams: String, paramGettersAndSetters: String) =
@@ -32,15 +43,22 @@
|import numpy as np
|import pyspark.ml, pyspark.ml.feature
|from pyspark import SparkContext
|from pyspark.sql import SQLContext
|from pyspark.sql import SQLContext, SparkSession
|from pyspark.ml.classification import LogisticRegression
|from pyspark.ml.regression import LinearRegression
|${importClass(entryPointName)}
|from pyspark.ml.feature import Tokenizer
|from mmlspark import TrainClassifier
|from mmlspark import ValueIndexer
|from mmlspark.train import TrainClassifier
|from mmlspark.featurize import ValueIndexer
|
|print("HEREEEEEEEE")
|spark = SparkSession.builder \\
| .master("local[*]") \\
| .appName("$entryPointName") \\
| .config("spark.jars.packages", "com.microsoft.ml.spark:mmlspark:$version") \\
| .getOrCreate()
|
|sc = SparkContext()
|sc = spark.sparkContext
|
|class ${entryPointName}Test(unittest.TestCase):
| def test_placeholder(self):
@@ -151,7 +169,8 @@
| self.assertNotEqual(bestModel, None)
|""".stripMargin

  // These com.microsoft.ml.spark.core.serialize.params need custom handling. For now, just skip them so we have tests that pass.
  // These com.microsoft.ml.spark.core.serialize.params need custom handling.
  // For now, just skip them so we have tests that pass.
  private lazy val skippedParams = Set[String]("models", "model", "cntkModel", "stage")
  protected def isSkippedParam(paramName: String): Boolean = skippedParams.contains(paramName)
  protected def isModel(paramName: String): Boolean = paramName.toLowerCase() == "model"
@@ -250,8 +269,15 @@ abstract class PySparkWrapperParamsTest(entryPoint: Params,
copyrightLines + getPysparkWrapperTestBase
}

  private val subPackages = entryPointQualifiedName
    .replace("com.microsoft.ml.spark.", "")
    .split(".".toCharArray.head).dropRight(1)

  def writeWrapperToFile(dir: File): Unit = {
    writeFile(new File(dir, entryPointName + "_tests.py"), pysparkWrapperTestBuilder())
    val packageDir = subPackages.foldLeft(dir) { case (base, folder) => new File(base, folder) }
    packageDir.mkdirs()
    new File(packageDir, "__init__.py").createNewFile()
    writeFile(new File(packageDir, "test_" + entryPointName + ".py"), pysparkWrapperTestBuilder())
}

}
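Both wrapper generators now derive an output folder from the wrapped class's package, so the generated python files land in matching subpackages (mmlspark.featurize, mmlspark.train, and so on). A standalone sketch of that derivation, using ValueIndexer (which the generated tests import from mmlspark.featurize) as the example and an assumed output root:

import java.io.File

object WrapperPathSketch extends App {
  val entryPointQualifiedName = "com.microsoft.ml.spark.featurize.ValueIndexer"
  val entryPointName = "ValueIndexer"

  // Same steps as writeWrapperToFile: strip the root package, split on '.', drop the class name.
  val subPackages = entryPointQualifiedName
    .replace("com.microsoft.ml.spark.", "")
    .split(".".toCharArray.head)
    .dropRight(1)                                   // Array("featurize")

  val testRoot = new File("target/scala-2.11/generated/test/python/mmlspark") // assumed test root
  val packageDir = subPackages.foldLeft(testRoot) { case (base, folder) => new File(base, folder) }

  println(new File(packageDir, "test_" + entryPointName + ".py"))
  // -> target/scala-2.11/generated/test/python/mmlspark/featurize/test_ValueIndexer.py
}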