diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index 216e886180ca..93e839933c7a 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -643,7 +643,7 @@ unittest_ubuntu_cpu_scala() { unittest_ubuntu_gpu_scala() { set -ex - make scalapkg USE_OPENCV=1 USE_BLAS=openblas USE_CUDA=1 USE_CUDA_PATH=/usr/local/cuda USE_CUDNN=1 USE_DIST_KVSTORE=1 + make scalapkg USE_OPENCV=1 USE_BLAS=openblas USE_CUDA=1 USE_CUDA_PATH=/usr/local/cuda USE_CUDNN=1 USE_DIST_KVSTORE=1 SCALA_ON_GPU=1 make scalatest USE_OPENCV=1 USE_BLAS=openblas USE_CUDA=1 USE_CUDA_PATH=/usr/local/cuda USE_CUDNN=1 SCALA_TEST_ON_GPU=1 USE_DIST_KVSTORE=1 } diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml index 361bfab5d611..3b1b051f60b1 100644 --- a/scala-package/core/pom.xml +++ b/scala-package/core/pom.xml @@ -65,6 +65,26 @@ org.scalastyle scalastyle-maven-plugin + + org.scalastyle + scalastyle-maven-plugin + + + net.alchim31.maven + scala-maven-plugin + 3.3.2 + + + + + package + attach-javadocs + + doc-jar + + + + diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 49f4d35136f8..c2de6ea43f2c 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -28,10 +28,11 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.ref.WeakReference /** - * NDArray API of mxnet - */ + * NDArray Object extends from NDArrayBase for abstract function signatures + * Main code will be generated during compile time through Macros + */ @AddNDArrayFunctions(false) -object NDArray { +object NDArray extends NDArrayBase { implicit def getFirstResult(ret: NDArrayFuncReturn): NDArray = ret(0) private val logger = LoggerFactory.getLogger(classOf[NDArray]) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala index a17fe57dde65..194d3681523f 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala @@ -822,8 +822,12 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD } } +/** + * Symbol Object extends from SymbolBase for abstract function signatures + * Main code will be generated during compile time through Macros + */ @AddSymbolFunctions(false) -object Symbol { +object Symbol extends SymbolBase { private type SymbolCreateNamedFunc = Map[String, Any] => Symbol private val logger = LoggerFactory.getLogger(classOf[Symbol]) private val functions: Map[String, SymbolFunction] = initSymbolModule() diff --git a/scala-package/infer/pom.xml b/scala-package/infer/pom.xml index 13d3cc1387e0..208d19ee9ce8 100644 --- a/scala-package/infer/pom.xml +++ b/scala-package/infer/pom.xml @@ -65,6 +65,22 @@ org.scalastyle scalastyle-maven-plugin + + net.alchim31.maven + scala-maven-plugin + 3.3.2 + + + + + package + attach-javadocs + + doc-jar + + + + diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index 3bbc7fd6a90b..9a8ec645f272 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -19,8 +19,11 @@ package org.apache.mxnet import org.apache.mxnet.init.Base._ import org.apache.mxnet.utils.CToScalaUtils +import java.io._ +import java.security.MessageDigest import scala.collection.mutable.ListBuffer +import scala.io.Source /** * This object will generate the Scala documentation of the new Scala API @@ -35,15 +38,25 @@ private[mxnet] object APIDocGenerator{ def main(args: Array[String]) : Unit = { val FILE_PATH = args(0) - absClassGen(FILE_PATH, true) - absClassGen(FILE_PATH, false) + val hashCollector = ListBuffer[String]() + hashCollector += absClassGen(FILE_PATH, true) + hashCollector += absClassGen(FILE_PATH, false) + hashCollector += nonTypeSafeClassGen(FILE_PATH, true) + hashCollector += nonTypeSafeClassGen(FILE_PATH, false) + val finalHash = hashCollector.mkString("\n") } - def absClassGen(FILE_PATH : String, isSymbol : Boolean) : Unit = { + def MD5Generator(input : String) : String = { + val md = MessageDigest.getInstance("MD5") + md.update(input.getBytes("UTF-8")) + val digest = md.digest() + org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest) + } + + def absClassGen(FILE_PATH : String, isSymbol : Boolean) : String = { // scalastyle:off val absClassFunctions = getSymbolNDArrayMethods(isSymbol) - // TODO: Add Filter to the same location in case of refactor - val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_")).map(absClassFunction => { + val absFuncs = absClassFunctions.map(absClassFunction => { val scalaDoc = generateAPIDocFromBackend(absClassFunction) val defBody = generateAPISignature(absClassFunction, isSymbol) s"$scalaDoc\n$defBody" @@ -55,16 +68,44 @@ private[mxnet] object APIDocGenerator{ val imports = "import org.apache.mxnet.annotation.Experimental" val absClassDef = s"abstract class $packageName" val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}" + val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala")) + pw.write(finalStr) + pw.close() + MD5Generator(finalStr) + } + + def nonTypeSafeClassGen(FILE_PATH : String, isSymbol : Boolean) : String = { + // scalastyle:off + val absClassFunctions = getSymbolNDArrayMethods(isSymbol) + val absFuncs = absClassFunctions.map(absClassFunction => { + val scalaDoc = generateAPIDocFromBackend(absClassFunction, false) + if (isSymbol) { + val defBody = s"def ${absClassFunction.name}(name : String = null, attr : Map[String, String] = null)(args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null): org.apache.mxnet.Symbol" + s"$scalaDoc\n$defBody" + } else { + val defBodyWithKwargs = s"def ${absClassFunction.name}(kwargs: Map[String, Any] = null)(args: Any*) : org.apache.mxnet.NDArrayFuncReturn" + val defBody = s"def ${absClassFunction.name}(args: Any*) : org.apache.mxnet.NDArrayFuncReturn" + s"$scalaDoc\n$defBodyWithKwargs\n$scalaDoc\n$defBody" + } + }) + val packageName = if (isSymbol) "SymbolBase" else "NDArrayBase" + val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n" + val scalaStyle = "// scalastyle:off" + val packageDef = "package org.apache.mxnet" + val imports = "import org.apache.mxnet.annotation.Experimental" + val absClassDef = s"abstract class $packageName" + val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}" import java.io._ val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala")) pw.write(finalStr) pw.close() + MD5Generator(finalStr) } // Generate ScalaDoc type - def generateAPIDocFromBackend(func : absClassFunction) : String = { + def generateAPIDocFromBackend(func : absClassFunction, withParam : Boolean = true) : String = { val desc = func.desc.split("\n").map({ currStr => - s" * $currStr" + s" * $currStr
" }) val params = func.listOfArgs.map({ absClassArg => val currArgName = absClassArg.argName match { @@ -75,7 +116,11 @@ private[mxnet] object APIDocGenerator{ s" * @param $currArgName\t\t${absClassArg.argDesc}" }) val returnType = s" * @return ${func.returnType}" - s" /**\n${desc.mkString("\n")}\n${params.mkString("\n")}\n$returnType\n */" + if (withParam) { + s" /**\n${desc.mkString("\n")}\n${params.mkString("\n")}\n$returnType\n */" + } else { + s" /**\n${desc.mkString("\n")}\n$returnType\n */" + } } def generateAPISignature(func : absClassFunction, isSymbol : Boolean) : String = { @@ -112,11 +157,12 @@ private[mxnet] object APIDocGenerator{ val returnType = if (isSymbol) "Symbol" else "NDArray" _LIB.mxListAllOpNames(opNames) // TODO: Add '_linalg_', '_sparse_', '_image_' support + // TODO: Add Filter to the same location in case of refactor opNames.map(opName => { val opHandle = new RefLong _LIB.nnGetOpHandle(opName, opHandle) makeAtomicSymbolFunction(opHandle.value, opName, "org.apache.mxnet." + returnType) - }).toList + }).toList.filterNot(_.name.startsWith("_")) } // Create an atomic symbol function by handle and function name. @@ -136,7 +182,7 @@ private[mxnet] object APIDocGenerator{ val realName = if (aliasName == name.value) "" else s"(a.k.a., ${name.value})" val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) => - val typeAndOption = CToScalaUtils.argumentCleaner(argType, returnType) + val typeAndOption = CToScalaUtils.argumentCleaner(argName, argType, returnType) new absClassArg(argName, typeAndOption._1, argDesc, typeAndOption._2) } new absClassFunction(aliasName, desc.value, argList.toList, returnType) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala index 082c64a609c3..644bc5c4489d 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala @@ -225,7 +225,8 @@ private[mxnet] object NDArrayMacro { } // scalastyle:on println val argList = argNames zip argTypes map { case (argName, argType) => - val typeAndOption = CToScalaUtils.argumentCleaner(argType, "org.apache.mxnet.NDArray") + val typeAndOption = + CToScalaUtils.argumentCleaner(argName, argType, "org.apache.mxnet.NDArray") new NDArrayArg(argName, typeAndOption._1, typeAndOption._2) } new NDArrayFunction(aliasName, argList.toList) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index 81430c2ab263..3e790ef4126b 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -230,7 +230,8 @@ private[mxnet] object SymbolImplMacros { } // scalastyle:on println val argList = argNames zip argTypes map { case (argName, argType) => - val typeAndOption = CToScalaUtils.argumentCleaner(argType, "org.apache.mxnet.Symbol") + val typeAndOption = + CToScalaUtils.argumentCleaner(argName, argType, "org.apache.mxnet.Symbol") new SymbolArg(argName, typeAndOption._1, typeAndOption._2) } new SymbolFunction(aliasName, argList.toList) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala index ca50a741012b..b07e6f97eee5 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala @@ -21,7 +21,8 @@ private[mxnet] object CToScalaUtils { // Convert C++ Types to Scala Types - def typeConversion(in : String, argType : String = "", returnType : String) : String = { + def typeConversion(in : String, argType : String = "", + argName : String, returnType : String) : String = { in match { case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape" case "Symbol" | "NDArray" | "NDArray-or-Symbol" => returnType @@ -35,7 +36,7 @@ private[mxnet] object CToScalaUtils { case "boolean" | "booleanorNone" => "Boolean" case "tupleof" | "tupleof" | "tupleof<>" | "ptr" | "" => "Any" case default => throw new IllegalArgumentException( - s"Invalid type for args: $default, $argType") + s"Invalid type for args: $default\nString argType: $argType\nargName: $argName") } } @@ -47,10 +48,12 @@ private[mxnet] object CToScalaUtils { * The three field shown above do not usually come at the same time * This function used the above format to determine if the argument is * optional, what is it Scala type and possibly pass in a default value + * @param argName The name of the argument * @param argType Raw arguement Type description * @return (Scala_Type, isOptional) */ - def argumentCleaner(argType : String, returnType : String) : (String, Boolean) = { + def argumentCleaner(argName: String, + argType : String, returnType : String) : (String, Boolean) = { val spaceRemoved = argType.replaceAll("\\s+", "") var commaRemoved : Array[String] = new Array[String](0) // Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'} @@ -66,9 +69,9 @@ private[mxnet] object CToScalaUtils { // arg: Type, optional, default = Null require(commaRemoved(1).equals("optional")) require(commaRemoved(2).startsWith("default=")) - (typeConversion(commaRemoved(0), argType, returnType), true) + (typeConversion(commaRemoved(0), argType, argName, returnType), true) } else if (commaRemoved.length == 2 || commaRemoved.length == 1) { - val tempType = typeConversion(commaRemoved(0), argType, returnType) + val tempType = typeConversion(commaRemoved(0), argType, argName, returnType) val tempOptional = tempType.equals("org.apache.mxnet.Symbol") (tempType, tempOptional) } else { diff --git a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala index 5883a00c3315..c3a7c58c1afc 100644 --- a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala +++ b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala @@ -43,7 +43,7 @@ class MacrosSuite extends FunSuite with BeforeAndAfterAll { ) for (idx <- input.indices) { - val result = CToScalaUtils.argumentCleaner(input(idx), "org.apache.mxnet.Symbol") + val result = CToScalaUtils.argumentCleaner("Sample", input(idx), "org.apache.mxnet.Symbol") assert(result._1 === output(idx)._1 && result._2 === output(idx)._2) } } diff --git a/scala-package/pom.xml b/scala-package/pom.xml index cd5dba85dfd5..043aaae5e9e3 100644 --- a/scala-package/pom.xml +++ b/scala-package/pom.xml @@ -340,6 +340,22 @@ + + net.alchim31.maven + scala-maven-plugin + 3.3.2 + + + + + package + attach-javadocs + + doc-jar + + + +