Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-319] Javadoc fix #11143

Closed
wants to merge 21 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions 3rdparty/nnvm
Submodule nnvm added at 2bc514
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 8f80df to 620726
20 changes: 20 additions & 0 deletions scala-package/core/pom.xml
Expand Up @@ -65,6 +65,26 @@
<groupId>org.scalastyle</groupId>
<artifactId>scalastyle-maven-plugin</artifactId>
</plugin>
<plugin>
<groupId>org.scalastyle</groupId>
<artifactId>scalastyle-maven-plugin</artifactId>
</plugin>
<plugin>
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
<version>3.3.2</version>
<configuration>
</configuration>
<executions>
<execution>
<phase>package</phase>
<id>attach-javadocs</id>
<goals>
<goal>doc-jar</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
<dependencies>
Expand Down
Expand Up @@ -31,7 +31,7 @@ import scala.ref.WeakReference
* NDArray API of mxnet
*/
@AddNDArrayFunctions(false)
object NDArray {
object NDArray extends NDArrayBase {
implicit def getFirstResult(ret: NDArrayFuncReturn): NDArray = ret(0)
private val logger = LoggerFactory.getLogger(classOf[NDArray])

Expand Down Expand Up @@ -64,15 +64,16 @@ object NDArray {
val function = functions(funcName)
val ndArgs = ArrayBuffer.empty[NDArray]
val posArgs = ArrayBuffer.empty[String]
args.foreach {
case arr: NDArray =>
ndArgs.append(arr)
case arrFunRet: NDArrayFuncReturn =>
arrFunRet.arr.foreach(ndArgs.append(_))
case arg =>
posArgs.append(arg.toString)
}

args.foreach {
case arr: NDArray =>
ndArgs.append(arr)
case arrFunRet: NDArrayFuncReturn =>
arrFunRet.arr.foreach(ndArgs.append(_))
case arg =>
posArgs.append(arg.toString)
}

require(posArgs.length <= function.arguments.length,
s"len(posArgs) = ${posArgs.length}, should be less or equal to len(arguments) " +
s"= ${function.arguments.length}")
Expand All @@ -81,6 +82,7 @@ object NDArray {
++ function.arguments.slice(0, posArgs.length).zip(posArgs) - "out"
).map { case (k, v) => k -> v.toString }


val (oriOutputs, outputVars) =
if (kwargs != null && kwargs.contains("out")) {
val output = kwargs("out")
Expand Down Expand Up @@ -537,6 +539,10 @@ object NDArray {
new NDArray(handleRef.value)
}

private def _crop_assign(kwargs: Map[String, Any] = null)(args: Any*) : NDArrayFuncReturn = {
genericNDArrayFunctionInvoke("_crop_assign", args, kwargs)
}

// TODO: imdecode
}

Expand Down
Expand Up @@ -823,7 +823,7 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD
}

@AddSymbolFunctions(false)
object Symbol {
object Symbol extends SymbolBase {
private type SymbolCreateNamedFunc = Map[String, Any] => Symbol
private val logger = LoggerFactory.getLogger(classOf[Symbol])
private val functions: Map[String, SymbolFunction] = initSymbolModule()
Expand Down
@@ -0,0 +1,25 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet.annotation

import java.lang.annotation.{ElementType, Retention, Target, _}

@Retention(RetentionPolicy.RUNTIME)
@Target(Array(ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER,
ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE))
class Experimental {}
Expand Up @@ -37,6 +37,8 @@ private[mxnet] object APIDocGenerator{
val FILE_PATH = args(0)
absClassGen(FILE_PATH, true)
absClassGen(FILE_PATH, false)
oldAPIClassGen(FILE_PATH, true)
oldAPIClassGen(FILE_PATH, false)
}

def absClassGen(FILE_PATH : String, isSymbol : Boolean) : Unit = {
Expand All @@ -52,6 +54,33 @@ private[mxnet] object APIDocGenerator{
val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n"
val scalaStyle = "// scalastyle:off"
val packageDef = "package org.apache.mxnet"
val imports = "import org.apache.mxnet.annotation.Experimental"
val absClassDef = s"abstract class $packageName"
val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}"
import java.io._
val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
pw.write(finalStr)
pw.close()
}

def oldAPIClassGen(FILE_PATH : String, isSymbol : Boolean) : Unit = {
// scalastyle:off
val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_")).map(absClassFunction => {
val scalaDoc = generateAPIDocFromBackend(absClassFunction, false)
if (isSymbol) {
val defBody = s"def ${absClassFunction.name}(name : String = null, attr : Map[String, String] = null)(args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null): org.apache.mxnet.Symbol"
s"$scalaDoc\n$defBody"
} else {
val defBodyWithKwargs = s"def ${absClassFunction.name}(kwargs: Map[String, Any] = null)(args: Any*) : org.apache.mxnet.NDArrayFuncReturn"
val defBody = s"def ${absClassFunction.name}(args: Any*) : org.apache.mxnet.NDArrayFuncReturn"
s"$scalaDoc\n$defBodyWithKwargs\n$scalaDoc\n$defBody"
}
})
val packageName = if (isSymbol) "SymbolBase" else "NDArrayBase"
val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n"
val scalaStyle = "// scalastyle:off"
val packageDef = "package org.apache.mxnet"
val absClassDef = s"abstract class $packageName"
val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$absClassDef {\n${absFuncs.mkString("\n")}\n}"
import java.io._
Expand All @@ -61,7 +90,7 @@ private[mxnet] object APIDocGenerator{
}

// Generate ScalaDoc type
def generateAPIDocFromBackend(func : absClassFunction) : String = {
def generateAPIDocFromBackend(func : absClassFunction, withParam : Boolean = true) : String = {
val desc = func.desc.split("\n").map({ currStr =>
s" * $currStr"
})
Expand All @@ -74,7 +103,11 @@ private[mxnet] object APIDocGenerator{
s" * @param $currArgName\t\t${absClassArg.argDesc}"
})
val returnType = s" * @return ${func.returnType}"
s" /**\n${desc.mkString("\n")}\n${params.mkString("\n")}\n$returnType\n */"
if (withParam) {
s" /**\n${desc.mkString("\n")}\n${params.mkString("\n")}\n$returnType\n */"
} else {
s" /**\n${desc.mkString("\n")}\n$returnType\n */"
}
}

def generateAPISignature(func : absClassFunction, isSymbol : Boolean) : String = {
Expand All @@ -97,9 +130,11 @@ private[mxnet] object APIDocGenerator{
argDef += "name : String = null"
argDef += "attr : Map[String, String] = null"
} else {
argDef += "out : Option[NDArray] = None"
returnType = "org.apache.mxnet.NDArrayFuncReturn"
}
s"def ${func.name} (${argDef.mkString(", ")}) : ${returnType}"
val experimentalTag = "@Experimental"
s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) : $returnType"
}


Expand Down
Expand Up @@ -21,7 +21,7 @@ import org.apache.mxnet.init.Base._
import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils}

import scala.annotation.StaticAnnotation
import scala.collection.mutable.ListBuffer
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.language.experimental.macros
import scala.reflect.macros.blackbox

Expand Down Expand Up @@ -57,14 +57,13 @@ private[mxnet] object NDArrayMacro {

val newNDArrayFunctions = {
if (isContrib) ndarrayFunctions.filter(_.name.startsWith("_contrib_"))
else ndarrayFunctions.filter(!_.name.startsWith("_contrib_"))
else ndarrayFunctions.filterNot(_.name.startsWith("_"))
}

val functionDefs = newNDArrayFunctions flatMap { NDArrayfunction =>
val funcName = NDArrayfunction.name
val termName = TermName(funcName)
if (!NDArrayfunction.name.startsWith("_") || NDArrayfunction.name.startsWith("_contrib_")) {
Seq(
Seq(
// scalastyle:off
// (yizhi) We are investigating a way to make these functions type-safe
// and waiting to see the new approach is stable enough.
Expand All @@ -75,16 +74,7 @@ private[mxnet] object NDArrayMacro {
q"def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}".asInstanceOf[DefDef]
// scalastyle:on
)
} else {
// Default private
Seq(
// scalastyle:off
q"private def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}".asInstanceOf[DefDef],
q"private def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}".asInstanceOf[DefDef]
// scalastyle:on
)
}
}

structGeneration(c)(functionDefs, annottees : _*)
}
Expand All @@ -109,6 +99,7 @@ private[mxnet] object NDArrayMacro {
// Construct Implementation field
var impl = ListBuffer[String]()
impl += "val map = scala.collection.mutable.Map[String, Any]()"
impl += "val args = scala.collection.mutable.ArrayBuffer.empty[NDArray]"
ndarrayfunction.listOfArgs.foreach({ ndarrayarg =>
// var is a special word used to define variable in Scala,
// need to changed to something else in order to make it work
Expand All @@ -123,14 +114,32 @@ private[mxnet] object NDArrayMacro {
else {
argDef += s"${currArgName} : ${ndarrayarg.argType}"
}
var base = "map(\"" + ndarrayarg.argName + "\") = " + currArgName
if (ndarrayarg.isOptional) {
base = "if (!" + currArgName + ".isEmpty)" + base + ".get"
}
impl += base
// NDArray arg implementation
val returnType = "org.apache.mxnet.NDArray"

// TODO: Currently we do not add place holder for NDArray
// Example: an NDArray operator like the following format
// nd.foo(arg1: NDArray(required), arg2: NDArray(Optional), arg3: NDArray(Optional)
// If we place nd.foo(arg1, arg3 = arg3), do we need to add place holder for arg2?
// What it should be?
val base =
if (ndarrayarg.argType.equals(returnType)) {
s"args += $currArgName"
} else if (ndarrayarg.argType.equals(s"Array[$returnType]")){
s"args ++= $currArgName"
} else {
"map(\"" + ndarrayarg.argName + "\") = " + currArgName
}
impl.append(
if (ndarrayarg.isOptional) s"if (!$currArgName.isEmpty) $base.get"
else base
)
})
// add default out parameter
argDef += "out : Option[NDArray] = None"
impl += "if (!out.isEmpty) map(\"out\") = out.get"
// scalastyle:off
impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", null, map.toMap)"
impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", args.toSeq, map.toMap)"
// scalastyle:on
// Combine and build the function string
val returnType = "org.apache.mxnet.NDArrayFuncReturn"
Expand Down
Expand Up @@ -41,7 +41,7 @@ private[mxnet] object SymbolImplMacros {
impl(c)(annottees: _*)
}
def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
newAPIImpl(c)(annottees: _*)
typedAPIImpl(c)(annottees: _*)
}
// scalastyle:on havetype

Expand Down Expand Up @@ -82,7 +82,7 @@ private[mxnet] object SymbolImplMacros {
/**
* Implementation for Dynamic typed API Symbol.api.<functioname>
*/
private def newAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : c.Expr[Any] = {
private def typedAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : c.Expr[Any] = {
import c.universe._

val isContrib: Boolean = c.prefix.tree match {
Expand All @@ -104,6 +104,7 @@ private[mxnet] object SymbolImplMacros {
// Construct Implementation field
var impl = ListBuffer[String]()
impl += "val map = scala.collection.mutable.Map[String, Any]()"
impl += "var args = Seq[org.apache.mxnet.Symbol]()"
symbolfunction.listOfArgs.foreach({ symbolarg =>
// var is a special word used to define variable in Scala,
// need to changed to something else in order to make it work
Expand All @@ -118,17 +119,28 @@ private[mxnet] object SymbolImplMacros {
else {
argDef += s"${currArgName} : ${symbolarg.argType}"
}
var base = "map(\"" + symbolarg.argName + "\") = " + currArgName
if (symbolarg.isOptional) {
base = "if (!" + currArgName + ".isEmpty)" + base + ".get"
// Symbol arg implementation
val returnType = "org.apache.mxnet.Symbol"
val base =
if (symbolarg.argType.equals(s"Array[$returnType]")) {
if (symbolarg.isOptional) s"if (!$currArgName.isEmpty) args = $currArgName.get.toSeq"
else s"args = $currArgName.toSeq"
} else {
if (symbolarg.isOptional) {
// scalastyle:off
s"if (!$currArgName.isEmpty) map(" + "\"" + symbolarg.argName + "\"" + s") = $currArgName.get"
// scalastyle:on
}
else "map(\"" + symbolarg.argName + "\"" + s") = $currArgName"
}

impl += base
})
argDef += "name : String = null"
argDef += "attr : Map[String, String] = null"
// scalastyle:off
// TODO: Seq() here allows user to place Symbols rather than normal arguments to run, need to fix if old API deprecated
impl += "org.apache.mxnet.Symbol.createSymbolGeneral(\"" + symbolfunction.name + "\", name, attr, Seq(), map.toMap)"
impl += "org.apache.mxnet.Symbol.createSymbolGeneral(\"" + symbolfunction.name + "\", name, attr, args, map.toMap)"
// scalastyle:on
// Combine and build the function string
val returnType = "org.apache.mxnet.Symbol"
Expand Down