From cdbcd10185fffba4fd9f7eb0adef47811808d347 Mon Sep 17 00:00:00 2001 From: unknown Date: Mon, 28 Jan 2019 15:24:01 +0800 Subject: [PATCH] [FLINK-11421] Providing more compilation options for code-generated operators (changes for stream jobs) --- .../flink/api/common/CompilationOption.java | 41 +++ .../flink/api/common/ExecutionConfig.java | 19 ++ .../flink/configuration/CoreOptions.java | 6 + .../apache/flink/table/codegen/Compiler.scala | 13 +- .../runtime/CRowOutputProcessRunner.scala | 11 +- .../table/runtime/CRowProcessRunner.scala | 11 +- .../flink/table/runtime/FlatMapRunner.scala | 8 +- .../table/runtime/stream/sql/SortITCase.scala | 39 ++- .../api/functions/CodeGenFunction.java | 37 +++ .../streaming/api/graph/StreamGraph.java | 8 + .../streaming/runtime/tasks/StreamTask.java | 24 ++ .../flink/streaming/util/JCACompiler.java | 272 ++++++++++++++++++ 12 files changed, 474 insertions(+), 15 deletions(-) create mode 100644 flink-core/src/main/java/org/apache/flink/api/common/CompilationOption.java create mode 100644 flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/CodeGenFunction.java create mode 100644 flink-streaming-java/src/main/java/org/apache/flink/streaming/util/JCACompiler.java diff --git a/flink-core/src/main/java/org/apache/flink/api/common/CompilationOption.java b/flink-core/src/main/java/org/apache/flink/api/common/CompilationOption.java new file mode 100644 index 0000000000000..2bd897a2da1b6 --- /dev/null +++ b/flink-core/src/main/java/org/apache/flink/api/common/CompilationOption.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.common; + +/** + * The option for the compiler used for compiling generated Java code. + */ +public enum CompilationOption { + + /** + * Compiling Java code by Janino. + * The compilation is fast, but the generated binary code is of low quality. + */ + FAST, + + /** + * Compiling Java code by Java Compiler API (JCA) + * The compilation is slow, but the generated binary code is of high quality. 
+ */ + SLOW; + + public static boolean inTest = false; + + public static CompilationOption currentOption = FAST; +} diff --git a/flink-core/src/main/java/org/apache/flink/api/common/ExecutionConfig.java b/flink-core/src/main/java/org/apache/flink/api/common/ExecutionConfig.java index d288810c094d2..5cedd15454ce7 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/ExecutionConfig.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/ExecutionConfig.java @@ -162,6 +162,9 @@ public class ExecutionConfig implements Serializable, Archiveable CODEGEN_COMPILATION_OPTION = ConfigOptions + .key("codegen.compilation.option") + .defaultValue("fast") + .withDescription("A string indicating the option used for compiling generated code. 'fast' means compiling with Janino, " + + "whereas 'slow' means compiling with Java Compiler API (JCA)."); + public static String[] getParentFirstLoaderPatterns(Configuration config) { String base = config.getString(ALWAYS_PARENT_FIRST_LOADER_PATTERNS); String append = config.getString(ALWAYS_PARENT_FIRST_LOADER_PATTERNS_ADDITIONAL); diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/Compiler.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/Compiler.scala index 4fcfab0e5073d..cba48bec7fd84 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/Compiler.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/Compiler.scala @@ -18,7 +18,9 @@ package org.apache.flink.table.codegen -import org.apache.flink.api.common.InvalidProgramException +import org.apache.flink.api.common.{CompilationOption, InvalidProgramException} +import org.apache.flink.api.common.functions.RuntimeContext +import org.apache.flink.streaming.util.JCACompiler import org.codehaus.commons.compiler.CompileException import org.codehaus.janino.SimpleCompiler @@ -38,4 +40,13 @@ trait Compiler[T] { } compiler.getClassLoader.loadClass(name).asInstanceOf[Class[T]] } + + @throws(classOf[CompileException]) + def compile(ctx: RuntimeContext, name: String, code: String): Class[T] = { + if (ctx.getExecutionConfig.getCompileOption == CompilationOption.FAST) { + compile(ctx.getUserCodeClassLoader, name, code) + } else { + JCACompiler.getInstance.getCodeClass(name, code).asInstanceOf[Class[T]] + } + } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowOutputProcessRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowOutputProcessRunner.scala index ebef577ff8358..e43e98dcd2c79 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowOutputProcessRunner.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowOutputProcessRunner.scala @@ -22,7 +22,7 @@ import org.apache.flink.api.common.functions.util.FunctionUtils import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.typeutils.ResultTypeQueryable import org.apache.flink.configuration.Configuration -import org.apache.flink.streaming.api.functions.ProcessFunction +import org.apache.flink.streaming.api.functions.{CodeGenFunction, ProcessFunction} import org.apache.flink.streaming.api.operators.TimestampedCollector import org.apache.flink.table.codegen.Compiler import org.apache.flink.table.runtime.types.CRow @@ -40,14 +40,15 @@ class CRowOutputProcessRunner( extends ProcessFunction[Any, CRow] with ResultTypeQueryable[CRow] with 
Compiler[ProcessFunction[Any, Row]] - with Logging { + with Logging + with CodeGenFunction{ private var function: ProcessFunction[Any, Row] = _ private var cRowWrapper: CRowWrappingCollector = _ override def open(parameters: Configuration): Unit = { LOG.debug(s"Compiling ProcessFunction: $name \n\n Code:\n$code") - val clazz = compile(getRuntimeContext.getUserCodeClassLoader, name, code) + val clazz = compile(getRuntimeContext, name, code) LOG.debug("Instantiating ProcessFunction.") function = clazz.newInstance() FunctionUtils.setFunctionRuntimeContext(function, getRuntimeContext) @@ -75,4 +76,8 @@ class CRowOutputProcessRunner( override def close(): Unit = { FunctionUtils.closeFunction(function) } + + override def getName: String = name + + override def getCode: String = code } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowProcessRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowProcessRunner.scala index a7f3d7287baff..d2bb4cbac3f2c 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowProcessRunner.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CRowProcessRunner.scala @@ -22,7 +22,7 @@ import org.apache.flink.api.common.functions.util.FunctionUtils import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.typeutils.ResultTypeQueryable import org.apache.flink.configuration.Configuration -import org.apache.flink.streaming.api.functions.ProcessFunction +import org.apache.flink.streaming.api.functions.{CodeGenFunction, ProcessFunction} import org.apache.flink.table.codegen.Compiler import org.apache.flink.table.runtime.types.CRow import org.apache.flink.table.util.Logging @@ -39,14 +39,15 @@ class CRowProcessRunner( extends ProcessFunction[CRow, CRow] with ResultTypeQueryable[CRow] with Compiler[ProcessFunction[Row, Row]] - with Logging { + with Logging + with CodeGenFunction { private var function: ProcessFunction[Row, Row] = _ private var cRowWrapper: CRowWrappingCollector = _ override def open(parameters: Configuration): Unit = { LOG.debug(s"Compiling ProcessFunction: $name \n\n Code:\n$code") - val clazz = compile(getRuntimeContext.getUserCodeClassLoader, name, code) + val clazz = compile(getRuntimeContext, name, code) LOG.debug("Instantiating ProcessFunction.") function = clazz.newInstance() FunctionUtils.setFunctionRuntimeContext(function, getRuntimeContext) @@ -74,6 +75,10 @@ class CRowProcessRunner( override def close(): Unit = { FunctionUtils.closeFunction(function) } + + override def getName: String = name + + override def getCode: String = code } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/FlatMapRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/FlatMapRunner.scala index 6c1f80489851e..c62ece1640743 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/FlatMapRunner.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/FlatMapRunner.scala @@ -23,6 +23,7 @@ import org.apache.flink.api.common.functions.{FlatMapFunction, RichFlatMapFuncti import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.typeutils.ResultTypeQueryable import org.apache.flink.configuration.Configuration +import org.apache.flink.streaming.api.functions.CodeGenFunction import org.apache.flink.table.codegen.Compiler import 
org.apache.flink.table.util.Logging import org.apache.flink.types.Row @@ -35,7 +36,8 @@ class FlatMapRunner( extends RichFlatMapFunction[Row, Row] with ResultTypeQueryable[Row] with Compiler[FlatMapFunction[Row, Row]] - with Logging { + with Logging + with CodeGenFunction { private var function: FlatMapFunction[Row, Row] = _ @@ -56,4 +58,8 @@ class FlatMapRunner( override def close(): Unit = { FunctionUtils.closeFunction(function) } + + override def getName: String = name + + override def getCode: String = code } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SortITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SortITCase.scala index 3b08b6442d654..6a979c6354c4b 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SortITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SortITCase.scala @@ -18,8 +18,10 @@ package org.apache.flink.table.runtime.stream.sql +import org.apache.flink.api.common.CompilationOption import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.scala._ +import org.apache.flink.configuration.{ConfigOptions, CoreOptions} import org.apache.flink.streaming.api.TimeCharacteristic import org.apache.flink.streaming.api.functions.sink.RichSinkFunction import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment @@ -32,10 +34,25 @@ import org.apache.flink.table.utils.MemoryTableSourceSinkUtil import org.apache.flink.types.Row import org.junit.Assert._ import org.junit._ +import org.junit.runner.RunWith +import org.junit.runners.Parameterized import scala.collection.mutable -class SortITCase extends StreamingWithStateTestBase { +@RunWith(classOf[Parameterized]) +class SortITCase (compilationOption: CompilationOption) extends StreamingWithStateTestBase { + + @Before + def before(): Unit = { + CompilationOption.inTest = true + CompilationOption.currentOption = compilationOption + } + + @After + def after(): Unit = { + CompilationOption.inTest = false + CompilationOption.currentOption = CompilationOption.FAST + } @Test def testEventTimeOrderBy(): Unit = { @@ -72,7 +89,7 @@ class SortITCase extends StreamingWithStateTestBase { Right(14000L), Left((15000L, (8L, 8, "Hello World"))), Right(17000L), - Left((20000L, (20L, 20, "Hello World"))), + Left((20000L, (20L, 20, "Hello World"))), Right(19000L)) val env = StreamExecutionEnvironment.getExecutionEnvironment @@ -80,19 +97,20 @@ class SortITCase extends StreamingWithStateTestBase { env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) val tEnv = TableEnvironment.getTableEnvironment(env) StreamITCase.clear + SortITCase.clear val t1 = env.addSource(new EventTimeSourceFunction[(Long, Int, String)](data)) .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime) - + tEnv.registerTable("T1", t1) val sqlQuery = "SELECT b FROM T1 ORDER BY rowtime, b ASC " - - + + val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row] result.addSink(new StringRowSelectorSink(0)).setParallelism(1) env.execute() - + val expected = mutable.MutableList( "1", "15", "16", "1", "2", "2", "3", @@ -101,7 +119,7 @@ class SortITCase extends StreamingWithStateTestBase { "5", "-1", "6", "6", "65", "67", "18", "7", "9", - "7", "17", "77", + "7", "17", "77", "18", "8", "20") @@ -150,4 +168,11 @@ object SortITCase { } var testResults: mutable.MutableList[String] = mutable.MutableList.empty[String] + + def clear = testResults.clear() + + 
@Parameterized.Parameters(name = "compilationOption = {0}") + def compilationConfig(): Array[CompilationOption] = { + Array(CompilationOption.FAST, CompilationOption.SLOW) + } } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/CodeGenFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/CodeGenFunction.java new file mode 100644 index 0000000000000..a47cc8940d44e --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/CodeGenFunction.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.functions; + +/** + * A function based on generated Java code. + */ +public interface CodeGenFunction { + + /** + * Gets the class name of generated code. + * @return + */ + String getName(); + + /** + * Gets the generated code. + * @return + */ + String getCode(); +} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java index f4950ecea7d86..3bd1e304b086a 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java @@ -18,6 +18,7 @@ package org.apache.flink.streaming.api.graph; import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.CompilationOption; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.io.InputFormat; @@ -28,6 +29,8 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.InputTypeConfigurable; import org.apache.flink.api.java.typeutils.MissingTypeInfo; +import org.apache.flink.configuration.CoreOptions; +import org.apache.flink.configuration.GlobalConfiguration; import org.apache.flink.optimizer.plan.StreamingPlan; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; @@ -103,6 +106,11 @@ public StreamGraph(StreamExecutionEnvironment environment) { this.executionConfig = environment.getConfig(); this.checkpointConfig = environment.getCheckpointConfig(); + // get compilation option + CompilationOption compilationOption = CompilationOption.inTest ? CompilationOption.currentOption : + CompilationOption.valueOf(GlobalConfiguration.loadConfiguration().getString(CoreOptions.CODEGEN_COMPILATION_OPTION).toUpperCase()); + this.executionConfig.setCompilationOption(compilationOption); + // create an empty new stream graph. 
 		clear();
 	}
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index a38886e0e5b1a..a20400b9d87f1 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -19,6 +19,7 @@
 
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.CompilationOption;
 import org.apache.flink.api.common.accumulators.Accumulator;
 import org.apache.flink.configuration.TaskManagerOptions;
 import org.apache.flink.core.fs.CloseableRegistry;
@@ -42,8 +43,10 @@
 import org.apache.flink.runtime.state.TaskStateManager;
 import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory;
 import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.functions.CodeGenFunction;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.graph.StreamEdge;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
 import org.apache.flink.streaming.api.operators.OperatorSnapshotFinalizer;
 import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
 import org.apache.flink.streaming.api.operators.StreamOperator;
@@ -54,6 +57,7 @@
 import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.streamstatus.StreamStatusMaintainer;
+import org.apache.flink.streaming.util.JCACompiler;
 import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.util.Preconditions;
 
@@ -64,6 +68,7 @@
 
 import java.io.Closeable;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -71,6 +76,7 @@
 import java.util.concurrent.Executors;
 import java.util.concurrent.ThreadFactory;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.Collectors;
 
 /**
  * Base class for all streaming tasks. A task is the unit of local processing that is deployed
@@ -412,6 +418,20 @@ public final boolean isCanceled() {
 		return canceled;
 	}
 
+	/**
+	 * Compile all generated code in one batch, to save time.
+	 */
+	private void compileInBatch() {
+		List<CodeGenFunction> codeGenFunctions = Arrays.stream(operatorChain.getAllOperators()).filter(operator -> operator instanceof AbstractUdfStreamOperator).
+			filter(operator -> ((AbstractUdfStreamOperator) operator).getUserFunction() instanceof CodeGenFunction).
+			map(operator -> (CodeGenFunction) ((AbstractUdfStreamOperator) operator).getUserFunction()).collect(Collectors.toList());
+
+		List<String> names = codeGenFunctions.stream().map(function -> function.getName()).collect(Collectors.toList());
+		List<String> sources = codeGenFunctions.stream().map(function -> function.getCode()).collect(Collectors.toList());
+
+		JCACompiler.getInstance().compileSourceInBatch(names, sources);
+	}
+
 	/**
 	 * Execute {@link StreamOperator#open()} of each operator in the chain of this
 	 * {@link StreamTask}. Opening happens from tail to head operator in the chain, contrary
 	 * (see {@link #closeAllOperators()}.
 	 */
 	private void openAllOperators() throws Exception {
+		if (this.getExecutionConfig().getCompileOption() == CompilationOption.SLOW) {
+			compileInBatch();
+		}
+
 		for (StreamOperator<?> operator : operatorChain.getAllOperators()) {
 			if (operator != null) {
 				operator.open();
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/util/JCACompiler.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/util/JCACompiler.java
new file mode 100644
index 0000000000000..67d4f4932b3ce
--- /dev/null
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/util/JCACompiler.java
@@ -0,0 +1,272 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.util;
+
+import org.apache.commons.lang3.tuple.ImmutablePair;
+
+import javax.tools.JavaCompiler;
+import javax.tools.JavaFileObject;
+import javax.tools.StandardJavaFileManager;
+import javax.tools.ToolProvider;
+
+import java.io.File;
+import java.net.URL;
+import java.net.URLClassLoader;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * Compiler based on the Java Compiler API (JCA).
+ */
+public class JCACompiler {
+
+	private static JCACompiler instance = new JCACompiler();
+
+	/**
+	 * The root path of the source/binary code.
+	 */
+	private static final String codeRoot = "codegen";
+
+	/**
+	 * A cache containing all compiled classes.
+	 * The structure is: class name -> list of (code, class).
+	 */
+	private Map<String, List<ImmutablePair<String, Class>>> clazzCache = new HashMap<>();
+
+	private JCACompiler() {
+	}
+
+	public static JCACompiler getInstance() {
+		return instance;
+	}
+
+	/**
+	 * Retrieve the class file path, given the source file path.
+	 *
+	 * @param srcPath the path of the Java source file
+	 * @return the path of the corresponding class file
+	 */
+	private String getClassFilePath(String srcPath) {
+		File srcFile = new File(srcPath);
+		String srcName = srcFile.getName();
+		int idx = srcName.lastIndexOf('.');
+		if (idx == -1) {
+			throw new IllegalArgumentException(srcPath + " is not a valid java source file path");
+		}
+		String className = srcName.substring(0, idx) + ".class";
+		return new File(srcFile.getParentFile(), className).getAbsolutePath();
+	}
+
+	/**
+	 * Get the class name, given either the source file path or the class file path.
+	 *
+	 * @param filePath the path of the source or class file
+	 * @return the class name
+	 */
+	private String getClassName(String filePath) {
+		File file = new File(filePath);
+		String name = file.getName();
+		int idx = name.lastIndexOf(".");
+		if (idx == -1) {
+			throw new IllegalArgumentException(filePath + " is not a valid java source/class file path");
+		}
+		return name.substring(0, idx);
+	}
+
+	/**
+	 * Load the class, given its class file path.
+	 *
+	 * @param classPath the path of the class file
+	 * @return the loaded class
+	 */
+	private Class loadClass(String classPath) {
+		try {
+			File parentDir = new File(classPath).getParentFile();
+			URL[] urls = new URL[]{ parentDir.toURI().toURL() };
+			ClassLoader cl = new URLClassLoader(urls);
+			return cl.loadClass(getClassName(classPath));
+		} catch (Exception e) {
+			String msg = "Failed to load class " + getClassName(classPath) + " from path: " + new File(classPath).getAbsolutePath();
+			throw new RuntimeException(msg, e);
+		}
+	}
+
+	/**
+	 * Compile source files, given their paths.
+	 *
+	 * @param paths the paths of the source files to compile
+	 */
+	public void compile(String ...paths) {
+		List<String> pathList = Arrays.stream(paths).map(p -> new File(p).getAbsolutePath()).collect(Collectors.toList());
+		String pathStr = String.join(", ", pathList);
+		try {
+			JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
+			StandardJavaFileManager fileManager = compiler.getStandardFileManager(null, null, null);
+			Iterable<? extends JavaFileObject> compilationUnits = fileManager.getJavaFileObjectsFromStrings(pathList);
+			compiler.getTask(null, fileManager, null, null, null, compilationUnits).call();
+			fileManager.close();
+		} catch (Exception e) {
+			throw new RuntimeException("Failed to compile generated code from: " + pathStr + ", due to " + e.getMessage(), e);
+		}
+	}
+
+	/**
+	 * Write the source file to disk, given its class name and source code.
+	 * If the source file already exists, it will be overwritten.
+	 *
+	 * @param name the class name
+	 * @param code the source code
+	 * @return the path of the source file on disk.
+	 */
+	private String writeSource(String name, String code) {
+		File srcFile = new File(codeRoot, name + ".java");
+		try {
+			org.apache.commons.io.FileUtils.writeStringToFile(srcFile, code);
+		} catch (Throwable e) {
+			throw new RuntimeException("Failed to write source file for " + name + ", reason: " + e.getMessage(), e);
+		}
+		return srcFile.getAbsolutePath();
+	}
+
+	/**
+	 * Try to find a class in the class cache, given its name and code.
+	 *
+	 * @param name the class name
+	 * @param code the generated source code
+	 * @return the cached class, or null if it is not cached
+	 */
+	private Class findClassInCache(String name, String code) {
+		List<ImmutablePair<String, Class>> clazzList = clazzCache.get(name);
+		if (clazzList == null) {
+			return null;
+		}
+		for (ImmutablePair<String, Class> pair : clazzList) {
+			if (pair.left.equals(code)) {
+				return pair.right;
+			}
+		}
+		return null;
+	}
+
+	/**
+	 * Check if a batch of source files has already been compiled and inserted into the cache.
+	 * @param names the class names
+	 * @param sources the generated source code, aligned with the names
+	 * @return true if every (name, code) pair is already in the cache
+	 */
+	private boolean allCompiled(List<String> names, List<String> sources) {
+		if (names.size() != sources.size()) {
+			throw new IllegalArgumentException("Source file names and code are not of equal size.");
+		}
+
+		for (int i = 0; i < names.size(); i++) {
+			if (findClassInCache(names.get(i), sources.get(i)) == null) {
+				return false;
+			}
+		}
+
+		return true;
+	}
+
+	/**
+	 * Compile a number of source files in one batch.
+	 * This will be much faster than compiling the files individually.
+	 * @param names the class names
+	 * @param sources the generated source code, aligned with the names
+	 */
+	public void compileSourceInBatch(List<String> names, List<String> sources) {
+		if (allCompiled(names, sources)) {
+			return;
+		}
+
+		synchronized (clazzCache) {
+			if (!allCompiled(names, sources)) {
+				// write source files to disk
+				String[] sourcePaths = new String[names.size()];
+				for (int i = 0; i < names.size(); i++) {
+					sourcePaths[i] = writeSource(names.get(i), sources.get(i));
+				}
+
+				// compile sources
+				compile(sourcePaths);
+
+				// load classes and insert them into cache
+				for (int i = 0; i < names.size(); i++) {
+					Class clazz = loadClass(getClassFilePath(sourcePaths[i]));
+					List<ImmutablePair<String, Class>> clazzList = clazzCache.get(names.get(i));
+					if (clazzList == null) {
+						clazzList = new ArrayList<>();
+						clazzCache.put(names.get(i), clazzList);
+					}
+					clazzList.add(new ImmutablePair<>(sources.get(i), clazz));
+				}
+			}
+		}
+	}
+
+	/**
+	 * Given the class name and the code, get the class.
+	 *
+	 * @param name the class name
+	 * @param code the generated source code
+	 * @return the compiled class
+	 */
+	public Class getCodeClass(String name, String code) {
+		Class ret = findClassInCache(name, code);
+		if (ret != null) {
+			// The class is in the cache.
+			return ret;
+		}
+
+		// get the list of all classes with the same name.
+		List<ImmutablePair<String, Class>> clazzList = clazzCache.get(name);
+		if (clazzList == null) {
+			synchronized (clazzCache) {
+				clazzList = clazzCache.get(name);
+				if (clazzList == null) {
+					clazzList = new ArrayList<>();
+					clazzCache.put(name, clazzList);
+				}
+			}
+		}
+
+		synchronized (clazzList) {
+			ret = findClassInCache(name, code);
+			if (ret == null) {
+				// write source file
+				String srcPath = writeSource(name, code);
+
+				// compile source
+				compile(srcPath);
+
+				// load class
+				ret = loadClass(getClassFilePath(srcPath));
+
+				// insert class to cache
+				clazzList.add(new ImmutablePair<>(code, ret));
+			}
+		}
+
+		return ret;
+	}
+}
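
Note for reviewers (not part of the patch): the sketch below shows one way to exercise the JCA path added above directly. The class name GeneratedGreeter and its source string are made up for illustration; the example assumes it runs on a JDK (so that ToolProvider.getSystemJavaCompiler() is available) and that the process may write to the "codegen" directory used by JCACompiler. In a job, the same path is selected by setting codegen.compilation.option: slow in flink-conf.yaml (the default, fast, keeps Janino), which StreamGraph reads via GlobalConfiguration and stores in the ExecutionConfig.

// Standalone sketch; names below are illustrative only.
import org.apache.flink.streaming.util.JCACompiler;

public class JCACompilerExample {

	public static void main(String[] args) throws Exception {
		String name = "GeneratedGreeter";
		String code =
			"public class GeneratedGreeter {\n" +
			"    public String greet() { return \"hello from JCA-compiled code\"; }\n" +
			"}\n";

		// The first call writes GeneratedGreeter.java under the "codegen" directory,
		// invokes the Java Compiler API and loads the resulting class; a second call
		// with the same (name, code) pair is answered from the in-memory cache.
		Class clazz = JCACompiler.getInstance().getCodeClass(name, code);
		Object greeter = clazz.newInstance();
		System.out.println(clazz.getMethod("greet").invoke(greeter));
	}
}

Because the cache is keyed by class name and stores a list of (code, class) pairs, two operators that generate different code under the same class name do not clash.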