Skip to content

Commit

Permalink
[MXNET-319] Javadoc fix (apache#11239)
Browse files Browse the repository at this point in the history
* Create Interface for Symbol and NDArray APIs, enable JavaDoc jar building for Scala Package.
  • Loading branch information
lanking520 authored and nswamy committed Jul 3, 2018
1 parent 43e3703 commit b20682f
Show file tree
Hide file tree
Showing 11 changed files with 131 additions and 23 deletions.
2 changes: 1 addition & 1 deletion ci/docker/runtime_functions.sh
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,7 @@ unittest_ubuntu_cpu_scala() {

unittest_ubuntu_gpu_scala() {
set -ex
make scalapkg USE_OPENCV=1 USE_BLAS=openblas USE_CUDA=1 USE_CUDA_PATH=/usr/local/cuda USE_CUDNN=1 USE_DIST_KVSTORE=1
make scalapkg USE_OPENCV=1 USE_BLAS=openblas USE_CUDA=1 USE_CUDA_PATH=/usr/local/cuda USE_CUDNN=1 USE_DIST_KVSTORE=1 SCALA_ON_GPU=1
make scalatest USE_OPENCV=1 USE_BLAS=openblas USE_CUDA=1 USE_CUDA_PATH=/usr/local/cuda USE_CUDNN=1 SCALA_TEST_ON_GPU=1 USE_DIST_KVSTORE=1
}

Expand Down
20 changes: 20 additions & 0 deletions scala-package/core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,26 @@
<groupId>org.scalastyle</groupId>
<artifactId>scalastyle-maven-plugin</artifactId>
</plugin>
<plugin>
<groupId>org.scalastyle</groupId>
<artifactId>scalastyle-maven-plugin</artifactId>
</plugin>
<plugin>
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
<version>3.3.2</version>
<configuration>
</configuration>
<executions>
<execution>
<phase>package</phase>
<id>attach-javadocs</id>
<goals>
<goal>doc-jar</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
<dependencies>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.ref.WeakReference

/**
* NDArray API of mxnet
*/
* NDArray Object extends from NDArrayBase for abstract function signatures
* Main code will be generated during compile time through Macros
*/
@AddNDArrayFunctions(false)
object NDArray {
object NDArray extends NDArrayBase {
implicit def getFirstResult(ret: NDArrayFuncReturn): NDArray = ret(0)
private val logger = LoggerFactory.getLogger(classOf[NDArray])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -822,8 +822,12 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD
}
}

/**
* Symbol Object extends from SymbolBase for abstract function signatures
* Main code will be generated during compile time through Macros
*/
@AddSymbolFunctions(false)
object Symbol {
object Symbol extends SymbolBase {
private type SymbolCreateNamedFunc = Map[String, Any] => Symbol
private val logger = LoggerFactory.getLogger(classOf[Symbol])
private val functions: Map[String, SymbolFunction] = initSymbolModule()
Expand Down
16 changes: 16 additions & 0 deletions scala-package/infer/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,22 @@
<groupId>org.scalastyle</groupId>
<artifactId>scalastyle-maven-plugin</artifactId>
</plugin>
<plugin>
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
<version>3.3.2</version>
<configuration>
</configuration>
<executions>
<execution>
<phase>package</phase>
<id>attach-javadocs</id>
<goals>
<goal>doc-jar</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
<dependencies>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@ package org.apache.mxnet

import org.apache.mxnet.init.Base._
import org.apache.mxnet.utils.CToScalaUtils
import java.io._
import java.security.MessageDigest

import scala.collection.mutable.ListBuffer
import scala.io.Source

/**
* This object will generate the Scala documentation of the new Scala API
Expand All @@ -35,15 +38,25 @@ private[mxnet] object APIDocGenerator{

def main(args: Array[String]) : Unit = {
val FILE_PATH = args(0)
absClassGen(FILE_PATH, true)
absClassGen(FILE_PATH, false)
val hashCollector = ListBuffer[String]()
hashCollector += absClassGen(FILE_PATH, true)
hashCollector += absClassGen(FILE_PATH, false)
hashCollector += nonTypeSafeClassGen(FILE_PATH, true)
hashCollector += nonTypeSafeClassGen(FILE_PATH, false)
val finalHash = hashCollector.mkString("\n")
}

def absClassGen(FILE_PATH : String, isSymbol : Boolean) : Unit = {
def MD5Generator(input : String) : String = {
val md = MessageDigest.getInstance("MD5")
md.update(input.getBytes("UTF-8"))
val digest = md.digest()
org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest)
}

def absClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
// scalastyle:off
val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
// TODO: Add Filter to the same location in case of refactor
val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_")).map(absClassFunction => {
val absFuncs = absClassFunctions.map(absClassFunction => {
val scalaDoc = generateAPIDocFromBackend(absClassFunction)
val defBody = generateAPISignature(absClassFunction, isSymbol)
s"$scalaDoc\n$defBody"
Expand All @@ -55,16 +68,44 @@ private[mxnet] object APIDocGenerator{
val imports = "import org.apache.mxnet.annotation.Experimental"
val absClassDef = s"abstract class $packageName"
val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}"
val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
pw.write(finalStr)
pw.close()
MD5Generator(finalStr)
}

def nonTypeSafeClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
// scalastyle:off
val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
val absFuncs = absClassFunctions.map(absClassFunction => {
val scalaDoc = generateAPIDocFromBackend(absClassFunction, false)
if (isSymbol) {
val defBody = s"def ${absClassFunction.name}(name : String = null, attr : Map[String, String] = null)(args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null): org.apache.mxnet.Symbol"
s"$scalaDoc\n$defBody"
} else {
val defBodyWithKwargs = s"def ${absClassFunction.name}(kwargs: Map[String, Any] = null)(args: Any*) : org.apache.mxnet.NDArrayFuncReturn"
val defBody = s"def ${absClassFunction.name}(args: Any*) : org.apache.mxnet.NDArrayFuncReturn"
s"$scalaDoc\n$defBodyWithKwargs\n$scalaDoc\n$defBody"
}
})
val packageName = if (isSymbol) "SymbolBase" else "NDArrayBase"
val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n"
val scalaStyle = "// scalastyle:off"
val packageDef = "package org.apache.mxnet"
val imports = "import org.apache.mxnet.annotation.Experimental"
val absClassDef = s"abstract class $packageName"
val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}"
import java.io._
val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
pw.write(finalStr)
pw.close()
MD5Generator(finalStr)
}

// Generate ScalaDoc type
def generateAPIDocFromBackend(func : absClassFunction) : String = {
def generateAPIDocFromBackend(func : absClassFunction, withParam : Boolean = true) : String = {
val desc = func.desc.split("\n").map({ currStr =>
s" * $currStr"
s" * $currStr<br>"
})
val params = func.listOfArgs.map({ absClassArg =>
val currArgName = absClassArg.argName match {
Expand All @@ -75,7 +116,11 @@ private[mxnet] object APIDocGenerator{
s" * @param $currArgName\t\t${absClassArg.argDesc}"
})
val returnType = s" * @return ${func.returnType}"
s" /**\n${desc.mkString("\n")}\n${params.mkString("\n")}\n$returnType\n */"
if (withParam) {
s" /**\n${desc.mkString("\n")}\n${params.mkString("\n")}\n$returnType\n */"
} else {
s" /**\n${desc.mkString("\n")}\n$returnType\n */"
}
}

def generateAPISignature(func : absClassFunction, isSymbol : Boolean) : String = {
Expand Down Expand Up @@ -112,11 +157,12 @@ private[mxnet] object APIDocGenerator{
val returnType = if (isSymbol) "Symbol" else "NDArray"
_LIB.mxListAllOpNames(opNames)
// TODO: Add '_linalg_', '_sparse_', '_image_' support
// TODO: Add Filter to the same location in case of refactor
opNames.map(opName => {
val opHandle = new RefLong
_LIB.nnGetOpHandle(opName, opHandle)
makeAtomicSymbolFunction(opHandle.value, opName, "org.apache.mxnet." + returnType)
}).toList
}).toList.filterNot(_.name.startsWith("_"))
}

// Create an atomic symbol function by handle and function name.
Expand All @@ -136,7 +182,7 @@ private[mxnet] object APIDocGenerator{
val realName = if (aliasName == name.value) "" else s"(a.k.a., ${name.value})"

val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) =>
val typeAndOption = CToScalaUtils.argumentCleaner(argType, returnType)
val typeAndOption = CToScalaUtils.argumentCleaner(argName, argType, returnType)
new absClassArg(argName, typeAndOption._1, argDesc, typeAndOption._2)
}
new absClassFunction(aliasName, desc.value, argList.toList, returnType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,8 @@ private[mxnet] object NDArrayMacro {
}
// scalastyle:on println
val argList = argNames zip argTypes map { case (argName, argType) =>
val typeAndOption = CToScalaUtils.argumentCleaner(argType, "org.apache.mxnet.NDArray")
val typeAndOption =
CToScalaUtils.argumentCleaner(argName, argType, "org.apache.mxnet.NDArray")
new NDArrayArg(argName, typeAndOption._1, typeAndOption._2)
}
new NDArrayFunction(aliasName, argList.toList)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,8 @@ private[mxnet] object SymbolImplMacros {
}
// scalastyle:on println
val argList = argNames zip argTypes map { case (argName, argType) =>
val typeAndOption = CToScalaUtils.argumentCleaner(argType, "org.apache.mxnet.Symbol")
val typeAndOption =
CToScalaUtils.argumentCleaner(argName, argType, "org.apache.mxnet.Symbol")
new SymbolArg(argName, typeAndOption._1, typeAndOption._2)
}
new SymbolFunction(aliasName, argList.toList)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ private[mxnet] object CToScalaUtils {


// Convert C++ Types to Scala Types
def typeConversion(in : String, argType : String = "", returnType : String) : String = {
def typeConversion(in : String, argType : String = "",
argName : String, returnType : String) : String = {
in match {
case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape"
case "Symbol" | "NDArray" | "NDArray-or-Symbol" => returnType
Expand All @@ -35,7 +36,7 @@ private[mxnet] object CToScalaUtils {
case "boolean" | "booleanorNone" => "Boolean"
case "tupleof<float>" | "tupleof<double>" | "tupleof<>" | "ptr" | "" => "Any"
case default => throw new IllegalArgumentException(
s"Invalid type for args: $default, $argType")
s"Invalid type for args: $default\nString argType: $argType\nargName: $argName")
}
}

Expand All @@ -47,10 +48,12 @@ private[mxnet] object CToScalaUtils {
* The three field shown above do not usually come at the same time
* This function used the above format to determine if the argument is
* optional, what is it Scala type and possibly pass in a default value
* @param argName The name of the argument
* @param argType Raw arguement Type description
* @return (Scala_Type, isOptional)
*/
def argumentCleaner(argType : String, returnType : String) : (String, Boolean) = {
def argumentCleaner(argName: String,
argType : String, returnType : String) : (String, Boolean) = {
val spaceRemoved = argType.replaceAll("\\s+", "")
var commaRemoved : Array[String] = new Array[String](0)
// Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'}
Expand All @@ -66,9 +69,9 @@ private[mxnet] object CToScalaUtils {
// arg: Type, optional, default = Null
require(commaRemoved(1).equals("optional"))
require(commaRemoved(2).startsWith("default="))
(typeConversion(commaRemoved(0), argType, returnType), true)
(typeConversion(commaRemoved(0), argType, argName, returnType), true)
} else if (commaRemoved.length == 2 || commaRemoved.length == 1) {
val tempType = typeConversion(commaRemoved(0), argType, returnType)
val tempType = typeConversion(commaRemoved(0), argType, argName, returnType)
val tempOptional = tempType.equals("org.apache.mxnet.Symbol")
(tempType, tempOptional)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class MacrosSuite extends FunSuite with BeforeAndAfterAll {
)

for (idx <- input.indices) {
val result = CToScalaUtils.argumentCleaner(input(idx), "org.apache.mxnet.Symbol")
val result = CToScalaUtils.argumentCleaner("Sample", input(idx), "org.apache.mxnet.Symbol")
assert(result._1 === output(idx)._1 && result._2 === output(idx)._2)
}
}
Expand Down
16 changes: 16 additions & 0 deletions scala-package/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,22 @@
</execution>
</executions>
</plugin>
<plugin>
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
<version>3.3.2</version>
<configuration>
</configuration>
<executions>
<execution>
<phase>package</phase>
<id>attach-javadocs</id>
<goals>
<goal>doc-jar</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
<dependencies>
Expand Down

0 comments on commit b20682f

Please sign in to comment.