diff --git a/python/pom.xml b/python/pom.xml index c14d4b1da84..289755df803 100644 --- a/python/pom.xml +++ b/python/pom.xml @@ -40,8 +40,6 @@ **/PythonInterpreterPandasSqlTest.java, **/PythonInterpreterMatplotlibTest.java - https://pypi.python.org/packages - /64/5c/01e13b68e8caafece40d549f232c9b5677ad1016071a48d04cc3895acaa3 1.4.0 2.4.1 @@ -137,35 +135,12 @@ - - org.codehaus.mojo - wagon-maven-plugin - 1.0 - - - package - download-single - - ${pypi.repo.url}${python.py4j.repo.folder} - py4j-${python.py4j.version}.zip - ${project.build.directory}/../../interpreter/python/py4j-${python.py4j.version}.zip - - - - - maven-antrun-plugin 1.7 package - - - - - run diff --git a/python/src/main/java/org/apache/zeppelin/python/IPythonInterpreter.java b/python/src/main/java/org/apache/zeppelin/python/IPythonInterpreter.java index 4fe50ee31d6..2daa986a840 100644 --- a/python/src/main/java/org/apache/zeppelin/python/IPythonInterpreter.java +++ b/python/src/main/java/org/apache/zeppelin/python/IPythonInterpreter.java @@ -243,16 +243,21 @@ private void setupJVMGateway(int jvmGatewayPort) throws IOException { private void launchIPythonKernel(int ipythonPort) throws IOException, URISyntaxException { // copy the python scripts to a temp directory, then launch ipython kernel in that folder - File tmpPythonScriptFolder = Files.createTempDirectory("zeppelin_ipython").toFile(); + File pythonWorkDir = Files.createTempDirectory("zeppelin_ipython").toFile(); String[] ipythonScripts = {"ipython_server.py", "ipython_pb2.py", "ipython_pb2_grpc.py"}; for (String ipythonScript : ipythonScripts) { URL url = getClass().getClassLoader().getResource("grpc/python" + "/" + ipythonScript); - FileUtils.copyURLToFile(url, new File(tmpPythonScriptFolder, ipythonScript)); + FileUtils.copyURLToFile(url, new File(pythonWorkDir, ipythonScript)); } + //TODO(zjffdu) don't do hard code on py4j here + File py4jDestFile = new File(pythonWorkDir, "py4j-src-0.9.2.zip"); + FileUtils.copyURLToFile(getClass().getClassLoader().getResource( + "python/py4j-src-0.9.2.zip"), py4jDestFile); + CommandLine cmd = CommandLine.parse(pythonExecutable); - cmd.addArgument(tmpPythonScriptFolder.getAbsolutePath() + "/ipython_server.py"); + cmd.addArgument(pythonWorkDir.getAbsolutePath() + "/ipython_server.py"); cmd.addArgument(ipythonPort + ""); DefaultExecutor executor = new DefaultExecutor(); ProcessLogOutputStream processOutput = new ProcessLogOutputStream(LOGGER); @@ -261,20 +266,12 @@ private void launchIPythonKernel(int ipythonPort) executor.setWatchdog(watchDog); if (useBuiltinPy4j) { - String py4jLibPath = null; - if (System.getenv("ZEPPELIN_HOME") != null) { - py4jLibPath = System.getenv("ZEPPELIN_HOME") + File.separator - + PythonInterpreter.ZEPPELIN_PY4JPATH; - } else { - Path workingPath = Paths.get("..").toAbsolutePath(); - py4jLibPath = workingPath + File.separator + PythonInterpreter.ZEPPELIN_PY4JPATH; - } if (additionalPythonPath != null) { // put the py4j at the end, because additionalPythonPath may already contain py4j. // e.g. PySparkInterpreter - additionalPythonPath = additionalPythonPath + ":" + py4jLibPath; + additionalPythonPath = additionalPythonPath + ":" + py4jDestFile.getAbsolutePath(); } else { - additionalPythonPath = py4jLibPath; + additionalPythonPath = py4jDestFile.getAbsolutePath(); } } @@ -326,7 +323,7 @@ protected Map setupIPythonEnv() throws IOException { @Override public void close() throws InterpreterException { if (watchDog != null) { - LOGGER.debug("Kill IPython Process"); + LOGGER.info("Kill IPython Process"); ipythonClient.stop(StopRequest.newBuilder().build()); watchDog.destroyProcess(); gatewayServer.shutdown(); diff --git a/python/src/main/java/org/apache/zeppelin/python/PythonCondaInterpreter.java b/python/src/main/java/org/apache/zeppelin/python/PythonCondaInterpreter.java index 887beb8ce7a..8d3e972bae8 100644 --- a/python/src/main/java/org/apache/zeppelin/python/PythonCondaInterpreter.java +++ b/python/src/main/java/org/apache/zeppelin/python/PythonCondaInterpreter.java @@ -31,9 +31,10 @@ /** * Conda support + * TODO(zjffdu) Add removing conda env */ public class PythonCondaInterpreter extends Interpreter { - Logger logger = LoggerFactory.getLogger(PythonCondaInterpreter.class); + private static Logger logger = LoggerFactory.getLogger(PythonCondaInterpreter.class); public static final String ZEPPELIN_PYTHON = "zeppelin.python"; public static final String CONDA_PYTHON_PATH = "/bin/python"; public static final String DEFAULT_ZEPPELIN_PYTHON = "python"; @@ -145,33 +146,22 @@ private void changePythonEnvironment(String envName) } } setCurrentCondaEnvName(envName); - python.setPythonCommand(binPath); + python.setPythonExec(binPath); } private void restartPythonProcess() throws InterpreterException { - PythonInterpreter python = getPythonInterpreter(); + logger.debug("Restarting PythonInterpreter"); + Interpreter python = + getInterpreterInTheSameSessionByClassName(PythonInterpreter.class.getName()); python.close(); python.open(); } protected PythonInterpreter getPythonInterpreter() throws InterpreterException { - LazyOpenInterpreter lazy = null; PythonInterpreter python = null; Interpreter p = getInterpreterInTheSameSessionByClassName(PythonInterpreter.class.getName()); - - while (p instanceof WrappedInterpreter) { - if (p instanceof LazyOpenInterpreter) { - lazy = (LazyOpenInterpreter) p; - } - p = ((WrappedInterpreter) p).getInnerInterpreter(); - } - python = (PythonInterpreter) p; - - if (lazy != null) { - lazy.open(); - } - return python; + return (PythonInterpreter) ((LazyOpenInterpreter)p).getInnerInterpreter(); } public static String runCondaCommandForTextOutput(String title, List commands) @@ -392,27 +382,50 @@ public Scheduler getScheduler() { public static String runCommand(List commands) throws IOException, InterruptedException { + logger.info("Starting shell commands: " + StringUtils.join(commands, " ")); + Process process = Runtime.getRuntime().exec(commands.toArray(new String[0])); + StreamGobbler errorGobbler = new StreamGobbler(process.getErrorStream()); + StreamGobbler outputGobbler = new StreamGobbler(process.getInputStream()); + errorGobbler.start(); + outputGobbler.start(); + if (process.waitFor() != 0) { + throw new IOException("Fail to run shell commands: " + StringUtils.join(commands, " ")); + } + logger.info("Complete shell commands: " + StringUtils.join(commands, " ")); + return outputGobbler.getOutput(); + } - StringBuilder sb = new StringBuilder(); + private static class StreamGobbler extends Thread { + InputStream is; + StringBuilder output = new StringBuilder(); - ProcessBuilder builder = new ProcessBuilder(commands); - builder.redirectErrorStream(true); - Process process = builder.start(); - InputStream stdout = process.getInputStream(); - BufferedReader br = new BufferedReader(new InputStreamReader(stdout)); - String line; - while ((line = br.readLine()) != null) { - sb.append(line); - sb.append("\n"); + // reads everything from is until empty. + StreamGobbler(InputStream is) { + this.is = is; } - int r = process.waitFor(); // Let the process finish. - if (r != 0) { - throw new RuntimeException("Failed to execute `" + - StringUtils.join(commands, " ") + "` exited with " + r); + public void run() { + try { + InputStreamReader isr = new InputStreamReader(is); + BufferedReader br = new BufferedReader(isr); + String line = null; + long startTime = System.currentTimeMillis(); + while ( (line = br.readLine()) != null) { + output.append(line + "\n"); + // logging per 5 seconds + if ((System.currentTimeMillis() - startTime) > 5000) { + logger.info(line); + startTime = System.currentTimeMillis(); + } + } + } catch (IOException ioe) { + ioe.printStackTrace(); + } } - return sb.toString(); + public String getOutput() { + return output.toString(); + } } public static String runCommand(String ... command) diff --git a/python/src/main/java/org/apache/zeppelin/python/PythonDockerInterpreter.java b/python/src/main/java/org/apache/zeppelin/python/PythonDockerInterpreter.java index 22f6c2ee994..b528efa7af7 100644 --- a/python/src/main/java/org/apache/zeppelin/python/PythonDockerInterpreter.java +++ b/python/src/main/java/org/apache/zeppelin/python/PythonDockerInterpreter.java @@ -58,7 +58,7 @@ public void close() { @Override public InterpreterResult interpret(String st, InterpreterContext context) throws InterpreterException { - File pythonScript = new File(getPythonInterpreter().getScriptPath()); + File pythonWorkDir = getPythonInterpreter().getPythonWorkDir(); InterpreterOutput out = context.out; Matcher activateMatcher = activatePattern.matcher(st); @@ -73,26 +73,23 @@ public InterpreterResult interpret(String st, InterpreterContext context) pull(out, image); // mount pythonscript dir - String mountPythonScript = "-v " + - pythonScript.getParentFile().getAbsolutePath() + - ":/_zeppelin_tmp "; + String mountPythonScript = "-v " + pythonWorkDir.getAbsolutePath() + + ":/_python_workdir "; // mount zeppelin dir - String mountPy4j = "-v " + - zeppelinHome.getAbsolutePath() + + String mountPy4j = "-v " + zeppelinHome.getAbsolutePath() + ":/_zeppelin "; // set PYTHONPATH - String pythonPath = ":/_zeppelin/" + PythonInterpreter.ZEPPELIN_PY4JPATH + ":" + - ":/_zeppelin/" + PythonInterpreter.ZEPPELIN_PYTHON_LIBS; + String pythonPath = ".:/_python_workdir/py4j-src-0.9.2.zip:/_python_workdir"; setPythonCommand("docker run -i --rm " + mountPythonScript + mountPy4j + "-e PYTHONPATH=\"" + pythonPath + "\" " + image + " " + - getPythonInterpreter().getPythonBindPath() + " " + - "/_zeppelin_tmp/" + pythonScript.getName()); + getPythonInterpreter().getPythonExec() + " " + + "/_python_workdir/zeppelin_python.py"); restartPythonProcess(); out.clear(); return new InterpreterResult(InterpreterResult.Code.SUCCESS, "\"" + image + "\" activated"); @@ -108,7 +105,7 @@ public InterpreterResult interpret(String st, InterpreterContext context) public void setPythonCommand(String cmd) throws InterpreterException { PythonInterpreter python = getPythonInterpreter(); - python.setPythonCommand(cmd); + python.setPythonExec(cmd); } private void printUsage(InterpreterOutput out) { diff --git a/python/src/main/java/org/apache/zeppelin/python/PythonInterpreter.java b/python/src/main/java/org/apache/zeppelin/python/PythonInterpreter.java index 178f79a25ec..95cfc824053 100644 --- a/python/src/main/java/org/apache/zeppelin/python/PythonInterpreter.java +++ b/python/src/main/java/org/apache/zeppelin/python/PythonInterpreter.java @@ -1,41 +1,24 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.zeppelin.python; -import java.io.BufferedReader; -import java.io.BufferedWriter; -import java.io.ByteArrayOutputStream; -import java.io.File; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.InputStreamReader; -import java.io.OutputStreamWriter; -import java.io.PipedInputStream; -import java.io.PipedOutputStream; -import java.net.*; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.util.Collection; -import java.util.List; -import java.util.Map; -import java.util.Properties; -import java.util.regex.Pattern; - +import com.google.common.io.Files; +import com.google.gson.Gson; import org.apache.commons.exec.CommandLine; import org.apache.commons.exec.DefaultExecutor; import org.apache.commons.exec.ExecuteException; @@ -45,239 +28,233 @@ import org.apache.commons.exec.environment.EnvironmentUtils; import org.apache.commons.io.IOUtils; import org.apache.commons.lang.StringUtils; -import org.apache.zeppelin.display.GUI; -import org.apache.zeppelin.interpreter.*; -import org.apache.zeppelin.interpreter.InterpreterResult.Code; +import org.apache.zeppelin.interpreter.BaseZeppelinContext; +import org.apache.zeppelin.interpreter.Interpreter; +import org.apache.zeppelin.interpreter.InterpreterContext; +import org.apache.zeppelin.interpreter.InterpreterException; +import org.apache.zeppelin.interpreter.InterpreterGroup; import org.apache.zeppelin.interpreter.InterpreterHookRegistry.HookType; +import org.apache.zeppelin.interpreter.InterpreterResult; +import org.apache.zeppelin.interpreter.InterpreterResult.Code; +import org.apache.zeppelin.interpreter.InterpreterResultMessage; +import org.apache.zeppelin.interpreter.InvalidHookException; +import org.apache.zeppelin.interpreter.LazyOpenInterpreter; +import org.apache.zeppelin.interpreter.WrappedInterpreter; +import org.apache.zeppelin.interpreter.remote.RemoteInterpreterUtils; import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion; import org.apache.zeppelin.interpreter.util.InterpreterOutputStream; -import org.apache.zeppelin.scheduler.Job; -import org.apache.zeppelin.scheduler.Scheduler; -import org.apache.zeppelin.scheduler.SchedulerFactory; import org.slf4j.Logger; import org.slf4j.LoggerFactory; - import py4j.GatewayServer; +import java.io.BufferedReader; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.net.Inet4Address; +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.atomic.AtomicBoolean; + /** - * Python interpreter for Zeppelin. + * Interpreter for Python, it is the first implementation of interpreter for Python, so with less + * features compared to IPythonInterpreter, but requires less prerequisites than + * IPythonInterpreter, only python installation is required. */ public class PythonInterpreter extends Interpreter implements ExecuteResultHandler { - private static final Logger LOG = LoggerFactory.getLogger(PythonInterpreter.class); - public static final String ZEPPELIN_PYTHON = "python/zeppelin_python.py"; - public static final String ZEPPELIN_CONTEXT = "python/zeppelin_context.py"; - public static final String ZEPPELIN_PY4JPATH = "interpreter/python/py4j-0.9.2/src"; - public static final String ZEPPELIN_PYTHON_LIBS = "interpreter/lib/python"; - public static final String DEFAULT_ZEPPELIN_PYTHON = "python"; - public static final String MAX_RESULT = "zeppelin.python.maxResult"; - - private PythonZeppelinContext zeppelinContext; - private InterpreterContext context; - private Pattern errorInLastLine = Pattern.compile(".*(Error|Exception): .*$"); - private String pythonPath; - private int maxResult; - private String py4jLibPath; - private String pythonLibPath; - - private String pythonCommand; + private static final Logger LOGGER = LoggerFactory.getLogger(PythonInterpreter.class); + private static final int MAX_TIMEOUT_SEC = 10; private GatewayServer gatewayServer; private DefaultExecutor executor; - private int port; - private InterpreterOutputStream outputStream; - private BufferedWriter ins; - private PipedInputStream in; - private ByteArrayOutputStream input; - private String scriptPath; - boolean pythonscriptRunning = false; - private static final int MAX_TIMEOUT_SEC = 10; + private File pythonWorkDir; + protected boolean useBuiltinPy4j = true; - private long pythonPid = 0; + // used to forward output from python process to InterpreterOutput + private InterpreterOutputStream outputStream; + private AtomicBoolean pythonScriptRunning = new AtomicBoolean(false); + private AtomicBoolean pythonScriptInitialized = new AtomicBoolean(false); + private long pythonPid = -1; private IPythonInterpreter iPythonInterpreter; - - Integer statementSetNotifier = new Integer(0); + private BaseZeppelinContext zeppelinContext; + private String condaPythonExec; // set by PythonCondaInterpreter public PythonInterpreter(Properties property) { super(property); - try { - File scriptFile = File.createTempFile("zeppelin_python-", ".py", new File("/tmp")); - scriptPath = scriptFile.getAbsolutePath(); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - private String workingDir() { - URL myURL = getClass().getProtectionDomain().getCodeSource().getLocation(); - java.net.URI myURI = null; - try { - myURI = myURL.toURI(); - } catch (URISyntaxException e1) - {} - String path = java.nio.file.Paths.get(myURI).toFile().toString(); - return path; } - private void createPythonScript() throws InterpreterException { - File out = new File(scriptPath); - - if (out.exists() && out.isDirectory()) { - throw new InterpreterException("Can't create python script " + out.getAbsolutePath()); + @Override + public void open() throws InterpreterException { + // try IPythonInterpreter first + iPythonInterpreter = getIPythonInterpreter(); + if (getProperty("zeppelin.python.useIPython", "true").equals("true") && + StringUtils.isEmpty( + iPythonInterpreter.checkIPythonPrerequisite(getPythonExec()))) { + try { + iPythonInterpreter.open(); + LOGGER.info("IPython is available, Use IPythonInterpreter to replace PythonInterpreter"); + return; + } catch (Exception e) { + iPythonInterpreter = null; + LOGGER.warn("Fail to open IPythonInterpreter", e); + } } - copyFile(out, ZEPPELIN_PYTHON); - // copy zeppelin_context.py as well - File zOut = new File(out.getParent() + "/zeppelin_context.py"); - copyFile(zOut, ZEPPELIN_CONTEXT); - - logger.info("File {} , {} created", scriptPath, zOut.getAbsolutePath()); - } - - public String getScriptPath() { - return scriptPath; - } + // reset iPythonInterpreter to null as it is not available + iPythonInterpreter = null; + LOGGER.info("IPython is not available, use the native PythonInterpreter"); + // Add matplotlib display hook + InterpreterGroup intpGroup = getInterpreterGroup(); + if (intpGroup != null && intpGroup.getInterpreterHookRegistry() != null) { + try { + // just for unit test I believe (zjffdu) + registerHook(HookType.POST_EXEC_DEV.getName(), "__zeppelin__._displayhook()"); + } catch (InvalidHookException e) { + throw new InterpreterException(e); + } + } - private void copyFile(File out, String sourceFile) throws InterpreterException { - ClassLoader classLoader = getClass().getClassLoader(); try { - FileOutputStream outStream = new FileOutputStream(out); - IOUtils.copy( - classLoader.getResourceAsStream(sourceFile), - outStream); - outStream.close(); + createGatewayServerAndStartScript(); } catch (IOException e) { - throw new InterpreterException(e); + LOGGER.error("Fail to open PythonInterpreter", e); + throw new InterpreterException("Fail to open PythonInterpreter", e); } } - private void createGatewayServerAndStartScript() - throws UnknownHostException, InterpreterException { - createPythonScript(); - if (System.getenv("ZEPPELIN_HOME") != null) { - py4jLibPath = System.getenv("ZEPPELIN_HOME") + File.separator + ZEPPELIN_PY4JPATH; - pythonLibPath = System.getenv("ZEPPELIN_HOME") + File.separator + ZEPPELIN_PYTHON_LIBS; - } else { - Path workingPath = Paths.get("..").toAbsolutePath(); - py4jLibPath = workingPath + File.separator + ZEPPELIN_PY4JPATH; - pythonLibPath = workingPath + File.separator + ZEPPELIN_PYTHON_LIBS; - } - - port = findRandomOpenPortOnAllLocalInterfaces(); + // start gateway sever and start python process + private void createGatewayServerAndStartScript() throws IOException { + // start gateway server in JVM side + int port = RemoteInterpreterUtils.findRandomAvailablePortOnAllLocalInterfaces(); + // use the FQDN as the server address instead of 127.0.0.1 so that python process in docker + // container can also connect to this gateway server. + String serverAddress = getLocalIP(); gatewayServer = new GatewayServer(this, port, GatewayServer.DEFAULT_PYTHON_PORT, - InetAddress.getByName("0.0.0.0"), - InetAddress.getByName("0.0.0.0"), + InetAddress.getByName(serverAddress), + InetAddress.getByName(serverAddress), GatewayServer.DEFAULT_CONNECT_TIMEOUT, GatewayServer.DEFAULT_READ_TIMEOUT, - (List) null); - + (List) null);; gatewayServer.start(); + LOGGER.info("Starting GatewayServer at " + serverAddress + ":" + port); - // Run python shell - String pythonCmd = getPythonCommand(); - CommandLine cmd = CommandLine.parse(pythonCmd); - - if (!pythonCmd.endsWith(".py")) { - // PythonDockerInterpreter set pythoncmd with script - cmd.addArgument(getScriptPath(), false); + // launch python process to connect to the gateway server in JVM side + createPythonScript(); + String pythonExec = getPythonExec(); + CommandLine cmd = CommandLine.parse(pythonExec); + if (!pythonExec.endsWith(".py")) { + // PythonDockerInterpreter set pythonExec with script + cmd.addArgument(pythonWorkDir + "/zeppelin_python.py", false); } + cmd.addArgument(serverAddress, false); cmd.addArgument(Integer.toString(port), false); - cmd.addArgument(getLocalIp(), false); executor = new DefaultExecutor(); - outputStream = new InterpreterOutputStream(LOG); - PipedOutputStream ps = new PipedOutputStream(); - in = null; - try { - in = new PipedInputStream(ps); - } catch (IOException e1) { - throw new InterpreterException(e1); - } - ins = new BufferedWriter(new OutputStreamWriter(ps)); - input = new ByteArrayOutputStream(); - - PumpStreamHandler streamHandler = new PumpStreamHandler(outputStream, outputStream, in); + outputStream = new InterpreterOutputStream(LOGGER); + PumpStreamHandler streamHandler = new PumpStreamHandler(outputStream); executor.setStreamHandler(streamHandler); executor.setWatchdog(new ExecuteWatchdog(ExecuteWatchdog.INFINITE_TIMEOUT)); - try { - Map env = EnvironmentUtils.getProcEnvironment(); - if (!env.containsKey("PYTHONPATH")) { - env.put("PYTHONPATH", py4jLibPath + File.pathSeparator + pythonLibPath); - } else { - env.put("PYTHONPATH", env.get("PYTHONPATH") + File.pathSeparator + - py4jLibPath + File.pathSeparator + pythonLibPath); - } + Map env = setupPythonEnv(); + LOGGER.info("Launching Python Process Command: " + cmd.getExecutable() + + " " + StringUtils.join(cmd.getArguments(), " ")); + executor.execute(cmd, env, this); + pythonScriptRunning.set(true); + } - logger.info("cmd = {}", cmd.toString()); - executor.execute(cmd, env, this); - pythonscriptRunning = true; - } catch (IOException e) { - throw new InterpreterException(e); + private void createPythonScript() throws IOException { + // set java.io.tmpdir to /tmp on MacOS, because docker can not share the /var folder which will + // cause PythonDockerInterpreter fails. + // https://stackoverflow.com/questions/45122459/docker-mounts-denied-the-paths-are-not-shared- + // from-os-x-and-are-not-known + if (System.getProperty("os.name", "").contains("Mac")) { + System.setProperty("java.io.tmpdir", "/tmp"); + } + this.pythonWorkDir = Files.createTempDir(); + this.pythonWorkDir.deleteOnExit(); + LOGGER.info("Create Python working dir: " + pythonWorkDir.getAbsolutePath()); + copyResourceToPythonWorkDir("python/zeppelin_python.py", "zeppelin_python.py"); + copyResourceToPythonWorkDir("python/zeppelin_context.py", "zeppelin_context.py"); + copyResourceToPythonWorkDir("python/backend_zinline.py", "backend_zinline.py"); + copyResourceToPythonWorkDir("python/mpl_config.py", "mpl_config.py"); + copyResourceToPythonWorkDir("python/py4j-src-0.9.2.zip", "py4j-src-0.9.2.zip"); + } + + protected boolean useIPython() { + return this.iPythonInterpreter != null; + } + + private String getLocalIP() { + // zeppelin.python.gatewayserver_address is only for unit test on travis. + // Because the FQDN would fail unit test on travis ci. + String gatewayserver_address = + properties.getProperty("zeppelin.python.gatewayserver_address"); + if (gatewayserver_address != null) { + return gatewayserver_address; } try { - input.write("import sys, getopt\n".getBytes()); - ins.flush(); - } catch (IOException e) { - throw new InterpreterException(e); + return Inet4Address.getLocalHost().getHostAddress(); + } catch (UnknownHostException e) { + LOGGER.warn("can't get local IP", e); } + // fall back to loopback addreess + return "127.0.0.1"; } - @Override - public void open() throws InterpreterException { - // try IPythonInterpreter first. If it is not available, we will fallback to the original - // python interpreter implementation. - iPythonInterpreter = getIPythonInterpreter(); - this.zeppelinContext = new PythonZeppelinContext( - getInterpreterGroup().getInterpreterHookRegistry(), - Integer.parseInt(getProperty("zeppelin.python.maxResult", "1000"))); - if (getProperty("zeppelin.python.useIPython", "true").equals("true") && - StringUtils.isEmpty(iPythonInterpreter.checkIPythonPrerequisite(getPythonBindPath()))) { - try { - iPythonInterpreter.open(); - LOG.info("IPython is available, Use IPythonInterpreter to replace PythonInterpreter"); - return; - } catch (Exception e) { - iPythonInterpreter = null; - LOG.warn("Fail to open IPythonInterpreter", e); + private void copyResourceToPythonWorkDir(String srcResourceName, + String dstFileName) throws IOException { + FileOutputStream out = null; + try { + out = new FileOutputStream(pythonWorkDir.getAbsoluteFile() + "/" + dstFileName); + IOUtils.copy( + getClass().getClassLoader().getResourceAsStream(srcResourceName), + out); + } finally { + if (out != null) { + out.close(); } } + } - // reset iPythonInterpreter to null as it is not available - iPythonInterpreter = null; - LOG.info("IPython is not available, use the native PythonInterpreter"); - // Add matplotlib display hook - InterpreterGroup intpGroup = getInterpreterGroup(); - if (intpGroup != null && intpGroup.getInterpreterHookRegistry() != null) { - try { - registerHook(HookType.POST_EXEC_DEV.getName(), "__zeppelin__._displayhook()"); - } catch (InvalidHookException e) { - throw new InterpreterException(e); - } - } - // Add matplotlib display hook - try { - createGatewayServerAndStartScript(); - } catch (UnknownHostException e) { - throw new InterpreterException(e); + protected Map setupPythonEnv() throws IOException { + Map env = EnvironmentUtils.getProcEnvironment(); + appendToPythonPath(env, pythonWorkDir.getAbsolutePath()); + if (useBuiltinPy4j) { + appendToPythonPath(env, pythonWorkDir.getAbsolutePath() + "/py4j-src-0.9.2.zip"); } + LOGGER.info("PYTHONPATH: " + env.get("PYTHONPATH")); + return env; } - private IPythonInterpreter getIPythonInterpreter() { - LazyOpenInterpreter lazy = null; - IPythonInterpreter ipython = null; - Interpreter p = getInterpreterInTheSameSessionByClassName(IPythonInterpreter.class.getName()); + private void appendToPythonPath(Map env, String path) { + if (!env.containsKey("PYTHONPATH")) { + env.put("PYTHONPATH", path); + } else { + env.put("PYTHONPATH", env.get("PYTHONPATH") + ":" + path); + } + } - while (p instanceof WrappedInterpreter) { - if (p instanceof LazyOpenInterpreter) { - lazy = (LazyOpenInterpreter) p; - } - p = ((WrappedInterpreter) p).getInnerInterpreter(); + // Run python script + // Choose python in the order of + // condaPythonExec > zeppelin.python + protected String getPythonExec() { + if (condaPythonExec != null) { + return condaPythonExec; + } else { + return getProperty("zeppelin.python", "python"); } - ipython = (IPythonInterpreter) p; - return ipython; + } + + public File getPythonWorkDir() { + return pythonWorkDir; } @Override @@ -286,54 +263,58 @@ public void close() throws InterpreterException { iPythonInterpreter.close(); return; } - pythonscriptRunning = false; - pythonScriptInitialized = false; - - try { - ins.flush(); - ins.close(); - input.flush(); - input.close(); - } catch (IOException e) { - e.printStackTrace(); - } + pythonScriptRunning.set(false); + pythonScriptInitialized.set(false); executor.getWatchdog().destroyProcess(); - new File(scriptPath).delete(); gatewayServer.shutdown(); - // wait until getStatements stop - synchronized (statementSetNotifier) { - try { - statementSetNotifier.wait(1500); - } catch (InterruptedException e) { - } - statementSetNotifier.notify(); - } + // reset these 2 monitors otherwise when you restart PythonInterpreter it would fails to execute + // python code as these 2 objects are in incorrect state. + statementSetNotifier = new Integer(0); + statementFinishedNotifier = new Integer(0); + } + + private PythonInterpretRequest pythonInterpretRequest = null; + private Integer statementSetNotifier = new Integer(0); + private Integer statementFinishedNotifier = new Integer(0); + private String statementOutput = null; + private boolean statementError = false; + + public void setPythonExec(String pythonExec) { + LOGGER.info("Set Python Command : {}", pythonExec); + this.condaPythonExec = pythonExec; } - PythonInterpretRequest pythonInterpretRequest = null; /** - * Result class of python interpreter + * Request send to Python Daemon */ public class PythonInterpretRequest { public String statements; + public boolean isForCompletion; - public PythonInterpretRequest(String statements) { + public PythonInterpretRequest(String statements, boolean isForCompletion) { this.statements = statements; + this.isForCompletion = isForCompletion; } public String statements() { return statements; } + + public boolean isForCompletion() { + return isForCompletion; + } } + // called by Python Process public PythonInterpretRequest getStatements() { synchronized (statementSetNotifier) { - while (pythonInterpretRequest == null && pythonscriptRunning && pythonScriptInitialized) { + while (pythonInterpretRequest == null) { try { statementSetNotifier.wait(1000); } catch (InterruptedException e) { + e.printStackTrace(); } } PythonInterpretRequest req = pythonInterpretRequest; @@ -342,65 +323,78 @@ public PythonInterpretRequest getStatements() { } } - String statementOutput = null; - boolean statementError = false; - Integer statementFinishedNotifier = new Integer(0); - + // called by Python Process public void setStatementsFinished(String out, boolean error) { synchronized (statementFinishedNotifier) { + LOGGER.debug("Setting python statement output: " + out + ", error: " + error); statementOutput = out; statementError = error; statementFinishedNotifier.notify(); } } - boolean pythonScriptInitialized = false; - Integer pythonScriptInitializeNotifier = new Integer(0); - + // called by Python Process public void onPythonScriptInitialized(long pid) { pythonPid = pid; - synchronized (pythonScriptInitializeNotifier) { - pythonScriptInitialized = true; - pythonScriptInitializeNotifier.notifyAll(); + synchronized (pythonScriptInitialized) { + LOGGER.debug("onPythonScriptInitialized is called"); + pythonScriptInitialized.set(true); + pythonScriptInitialized.notifyAll(); } } + // called by Python Process public void appendOutput(String message) throws IOException { + LOGGER.debug("Output from python process: " + message); outputStream.getInterpreterOutput().write(message); } - @Override - public InterpreterResult interpret(String cmd, InterpreterContext contextInterpreter) - throws InterpreterException { - if (iPythonInterpreter != null) { - return iPythonInterpreter.interpret(cmd, contextInterpreter); - } + // used by subclass such as PySparkInterpreter to set JobGroup before executing spark code + protected void preCallPython(InterpreterContext context) { - if (cmd == null || cmd.isEmpty()) { - return new InterpreterResult(Code.SUCCESS, ""); + } + + // blocking call. Send python code to python process and get response + protected void callPython(PythonInterpretRequest request) { + synchronized (statementSetNotifier) { + this.pythonInterpretRequest = request; + statementOutput = null; + statementSetNotifier.notify(); } - this.context = contextInterpreter; + synchronized (statementFinishedNotifier) { + while (statementOutput == null) { + try { + statementFinishedNotifier.wait(1000); + } catch (InterruptedException e) { + } + } + } + } - zeppelinContext.setGui(context.getGui()); - zeppelinContext.setNoteGui(context.getNoteGui()); - zeppelinContext.setInterpreterContext(context); + @Override + public InterpreterResult interpret(String st, InterpreterContext context) + throws InterpreterException { + if (iPythonInterpreter != null) { + return iPythonInterpreter.interpret(st, context); + } - if (!pythonscriptRunning) { - return new InterpreterResult(Code.ERROR, "python process not running" - + outputStream.toString()); + if (!pythonScriptRunning.get()) { + return new InterpreterResult(Code.ERROR, "python process not running " + + outputStream.toString()); } outputStream.setInterpreterOutput(context.out); - synchronized (pythonScriptInitializeNotifier) { + synchronized (pythonScriptInitialized) { long startTime = System.currentTimeMillis(); - while (pythonScriptInitialized == false - && pythonscriptRunning - && System.currentTimeMillis() - startTime < MAX_TIMEOUT_SEC * 1000) { + while (!pythonScriptInitialized.get() + && System.currentTimeMillis() - startTime < MAX_TIMEOUT_SEC * 1000) { try { - pythonScriptInitializeNotifier.wait(1000); + LOGGER.info("Wait for PythonScript initialized"); + pythonScriptInitialized.wait(100); } catch (InterruptedException e) { + e.printStackTrace(); } } } @@ -413,59 +407,40 @@ public InterpreterResult interpret(String cmd, InterpreterContext contextInterpr throw new InterpreterException(e); } - if (pythonscriptRunning == false) { - // python script failed to initialize and terminated - errorMessage.add(new InterpreterResultMessage( - InterpreterResult.Type.TEXT, "failed to start python")); - return new InterpreterResult(Code.ERROR, errorMessage); - } - if (pythonScriptInitialized == false) { + if (!pythonScriptInitialized.get()) { // timeout. didn't get initialized message errorMessage.add(new InterpreterResultMessage( - InterpreterResult.Type.TEXT, "python is not responding")); + InterpreterResult.Type.TEXT, "Failed to initialize Python")); return new InterpreterResult(Code.ERROR, errorMessage); } - pythonInterpretRequest = new PythonInterpretRequest(cmd); - statementOutput = null; - - synchronized (statementSetNotifier) { - statementSetNotifier.notify(); - } + BaseZeppelinContext z = getZeppelinContext(); + z.setInterpreterContext(context); + z.setGui(context.getGui()); + z.setNoteGui(context.getNoteGui()); + InterpreterContext.set(context); - synchronized (statementFinishedNotifier) { - while (statementOutput == null) { - try { - statementFinishedNotifier.wait(1000); - } catch (InterruptedException e) { - } - } - } + preCallPython(context); + callPython(new PythonInterpretRequest(st, false)); if (statementError) { return new InterpreterResult(Code.ERROR, statementOutput); } else { - try { context.out.flush(); } catch (IOException e) { throw new InterpreterException(e); } - return new InterpreterResult(Code.SUCCESS); } } - public InterpreterContext getCurrentInterpreterContext() { - return context; - } - public void interrupt() throws IOException, InterpreterException { if (pythonPid > -1) { - logger.info("Sending SIGINT signal to PID : " + pythonPid); + LOGGER.info("Sending SIGINT signal to PID : " + pythonPid); Runtime.getRuntime().exec("kill -SIGINT " + pythonPid); } else { - logger.warn("Non UNIX/Linux system, close the interpreter"); + LOGGER.warn("Non UNIX/Linux system, close the interpreter"); close(); } } @@ -474,11 +449,12 @@ public void interrupt() throws IOException, InterpreterException { public void cancel(InterpreterContext context) throws InterpreterException { if (iPythonInterpreter != null) { iPythonInterpreter.cancel(context); + return; } try { interrupt(); } catch (IOException e) { - e.printStackTrace(); + LOGGER.error("Error", e); } } @@ -495,114 +471,162 @@ public int getProgress(InterpreterContext context) throws InterpreterException { return 0; } - @Override - public Scheduler getScheduler() { - if (iPythonInterpreter != null) { - return iPythonInterpreter.getScheduler(); - } - return SchedulerFactory.singleton().createOrGetFIFOScheduler( - PythonInterpreter.class.getName() + this.hashCode()); - } @Override public List completion(String buf, int cursor, - InterpreterContext interpreterContext) { + InterpreterContext interpreterContext) + throws InterpreterException { if (iPythonInterpreter != null) { return iPythonInterpreter.completion(buf, cursor, interpreterContext); } - return null; - } + if (buf.length() < cursor) { + cursor = buf.length(); + } + String completionString = getCompletionTargetString(buf, cursor); + String completionCommand = "__zeppelin_completion__.getCompletion('" + completionString + "')"; + LOGGER.debug("completionCommand: " + completionCommand); - public void setPythonCommand(String cmd) { - logger.info("Set Python Command : {}", cmd); - pythonCommand = cmd; - } + pythonInterpretRequest = new PythonInterpretRequest(completionCommand, true); + statementOutput = null; - private String getPythonCommand() { - if (pythonCommand == null) { - return getPythonBindPath(); - } else { - return pythonCommand; + synchronized (statementSetNotifier) { + statementSetNotifier.notify(); } - } - public String getPythonBindPath() { - String path = getProperty("zeppelin.python"); - if (path == null) { - return DEFAULT_ZEPPELIN_PYTHON; - } else { - return path; + String[] completionList = null; + synchronized (statementFinishedNotifier) { + long startTime = System.currentTimeMillis(); + while (statementOutput == null + && pythonScriptRunning.get()) { + try { + if (System.currentTimeMillis() - startTime > MAX_TIMEOUT_SEC * 1000) { + LOGGER.error("Python completion didn't have response for {}sec.", MAX_TIMEOUT_SEC); + break; + } + statementFinishedNotifier.wait(1000); + } catch (InterruptedException e) { + // not working + LOGGER.info("wait drop"); + return new LinkedList<>(); + } + } + if (statementError) { + return new LinkedList<>(); + } + Gson gson = new Gson(); + completionList = gson.fromJson(statementOutput, String[].class); + } + //end code for completion + if (completionList == null) { + return new LinkedList<>(); } - } - private Job getRunningJob(String paragraphId) { - Job foundJob = null; - Collection jobsRunning = getScheduler().getJobsRunning(); - for (Job job : jobsRunning) { - if (job.getId().equals(paragraphId)) { - foundJob = job; - break; - } + List results = new LinkedList<>(); + for (String name: completionList) { + results.add(new InterpreterCompletion(name, name, StringUtils.EMPTY)); } - return foundJob; + return results; } - void bootStrapInterpreter(String file) throws IOException { - BufferedReader bootstrapReader = new BufferedReader( - new InputStreamReader( - PythonInterpreter.class.getResourceAsStream(file))); - String line = null; - String bootstrapCode = ""; + private String getCompletionTargetString(String text, int cursor) { + String[] completionSeqCharaters = {" ", "\n", "\t"}; + int completionEndPosition = cursor; + int completionStartPosition = cursor; + int indexOfReverseSeqPostion = cursor; - while ((line = bootstrapReader.readLine()) != null) { - bootstrapCode += line + "\n"; + String resultCompletionText = ""; + String completionScriptText = ""; + try { + completionScriptText = text.substring(0, cursor); } + catch (Exception e) { + LOGGER.error(e.toString()); + return null; + } + completionEndPosition = completionScriptText.length(); + + String tempReverseCompletionText = new StringBuilder(completionScriptText).reverse().toString(); + + for (String seqCharacter : completionSeqCharaters) { + indexOfReverseSeqPostion = tempReverseCompletionText.indexOf(seqCharacter); + + if (indexOfReverseSeqPostion < completionStartPosition && indexOfReverseSeqPostion > 0) { + completionStartPosition = indexOfReverseSeqPostion; + } - try { - interpret(bootstrapCode, context); - } catch (InterpreterException e) { - throw new IOException(e); } - } - public PythonZeppelinContext getZeppelinContext() { - return zeppelinContext; + if (completionStartPosition == completionEndPosition) { + completionStartPosition = 0; + } + else + { + completionStartPosition = completionEndPosition - completionStartPosition; + } + resultCompletionText = completionScriptText.substring( + completionStartPosition , completionEndPosition); + + return resultCompletionText; } - String getLocalIp() { - try { - return Inet4Address.getLocalHost().getHostAddress(); - } catch (UnknownHostException e) { - logger.error("can't get local IP", e); + protected IPythonInterpreter getIPythonInterpreter() { + LazyOpenInterpreter lazy = null; + IPythonInterpreter iPython = null; + Interpreter p = getInterpreterInTheSameSessionByClassName(IPythonInterpreter.class.getName()); + + while (p instanceof WrappedInterpreter) { + if (p instanceof LazyOpenInterpreter) { + lazy = (LazyOpenInterpreter) p; + } + p = ((WrappedInterpreter) p).getInnerInterpreter(); } - // fall back to loopback addreess - return "127.0.0.1"; + iPython = (IPythonInterpreter) p; + return iPython; } - private int findRandomOpenPortOnAllLocalInterfaces() { - Integer port = -1; - try (ServerSocket socket = new ServerSocket(0);) { - port = socket.getLocalPort(); - socket.close(); - } catch (IOException e) { - LOG.error("Can't find an open port", e); + protected BaseZeppelinContext createZeppelinContext() { + return new PythonZeppelinContext( + getInterpreterGroup().getInterpreterHookRegistry(), + Integer.parseInt(getProperty("zeppelin.python.maxResult", "1000"))); + } + + public BaseZeppelinContext getZeppelinContext() { + if (zeppelinContext == null) { + zeppelinContext = createZeppelinContext(); } - return port; + return zeppelinContext; } - public int getMaxResult() { - return maxResult; + protected void bootstrapInterpreter(String resourceName) throws IOException { + LOGGER.info("Bootstrap interpreter via " + resourceName); + String bootstrapCode = + IOUtils.toString(getClass().getClassLoader().getResourceAsStream(resourceName)); + try { + InterpreterResult result = interpret(bootstrapCode, InterpreterContext.get()); + if (result.code() != Code.SUCCESS) { + throw new IOException("Fail to run bootstrap script: " + resourceName); + } + } catch (InterpreterException e) { + throw new IOException(e); + } } @Override public void onProcessComplete(int exitValue) { - pythonscriptRunning = false; - logger.info("python process terminated. exit code " + exitValue); + LOGGER.info("python process terminated. exit code " + exitValue); + pythonScriptRunning.set(false); + pythonScriptInitialized.set(false); } @Override public void onProcessFailed(ExecuteException e) { - pythonscriptRunning = false; - logger.error("python process failed", e); + LOGGER.error("python process failed", e); + pythonScriptRunning.set(false); + pythonScriptInitialized.set(false); + } + + // Called by Python Process, used for debugging purpose + public void logPythonOutput(String message) { + LOGGER.debug("Python Process Output: " + message); } } diff --git a/python/src/main/java/org/apache/zeppelin/python/PythonInterpreterPandasSql.java b/python/src/main/java/org/apache/zeppelin/python/PythonInterpreterPandasSql.java index 54984c3ce7f..db65960805d 100644 --- a/python/src/main/java/org/apache/zeppelin/python/PythonInterpreterPandasSql.java +++ b/python/src/main/java/org/apache/zeppelin/python/PythonInterpreterPandasSql.java @@ -70,7 +70,7 @@ public void open() throws InterpreterException { LOG.info("Bootstrap {} interpreter with {}", this.toString(), SQL_BOOTSTRAP_FILE_PY); PythonInterpreter python = getPythonInterpreter(); - python.bootStrapInterpreter(SQL_BOOTSTRAP_FILE_PY); + python.bootstrapInterpreter(SQL_BOOTSTRAP_FILE_PY); } catch (IOException e) { LOG.error("Can't execute " + SQL_BOOTSTRAP_FILE_PY + " to import SQL dependencies", e); } diff --git a/interpreter/lib/python/backend_zinline.py b/python/src/main/resources/python/backend_zinline.py similarity index 100% rename from interpreter/lib/python/backend_zinline.py rename to python/src/main/resources/python/backend_zinline.py diff --git a/interpreter/lib/python/mpl_config.py b/python/src/main/resources/python/mpl_config.py similarity index 100% rename from interpreter/lib/python/mpl_config.py rename to python/src/main/resources/python/mpl_config.py diff --git a/python/src/main/resources/python/py4j-src-0.9.2.zip b/python/src/main/resources/python/py4j-src-0.9.2.zip new file mode 100644 index 00000000000..8ceb15c5264 Binary files /dev/null and b/python/src/main/resources/python/py4j-src-0.9.2.zip differ diff --git a/python/src/main/resources/python/zeppelin_python.py b/python/src/main/resources/python/zeppelin_python.py index 0b2d5338918..19fa2201a40 100644 --- a/python/src/main/resources/python/zeppelin_python.py +++ b/python/src/main/resources/python/zeppelin_python.py @@ -15,24 +15,12 @@ # limitations under the License. # -import os, sys, getopt, traceback, json, re +import os, sys, traceback, json, re from py4j.java_gateway import java_import, JavaGateway, GatewayClient -from py4j.protocol import Py4JJavaError, Py4JNetworkError -import warnings -import ast -import traceback -import warnings -import signal -import base64 - -from io import BytesIO -try: - from StringIO import StringIO -except ImportError: - from io import StringIO +from py4j.protocol import Py4JJavaError -# for back compatibility +import ast class Logger(object): def __init__(self): @@ -47,46 +35,79 @@ def reset(self): def flush(self): pass -def handler_stop_signals(sig, frame): - sys.exit("Got signal : " + str(sig)) +class PythonCompletion: + def __init__(self, interpreter, userNameSpace): + self.interpreter = interpreter + self.userNameSpace = userNameSpace -signal.signal(signal.SIGINT, handler_stop_signals) + def getObjectCompletion(self, text_value): + completions = [completion for completion in list(self.userNameSpace.keys()) if completion.startswith(text_value)] + builtinCompletions = [completion for completion in dir(__builtins__) if completion.startswith(text_value)] + return completions + builtinCompletions -host = "127.0.0.1" -if len(sys.argv) >= 3: - host = sys.argv[2] + def getMethodCompletion(self, objName, methodName): + execResult = locals() + try: + exec("{} = dir({})".format("objectDefList", objName), _zcUserQueryNameSpace, execResult) + except: + self.interpreter.logPythonOutput("Fail to run dir on " + objName) + self.interpreter.logPythonOutput(traceback.format_exc()) + return None + else: + objectDefList = execResult['objectDefList'] + return [completion for completion in execResult['objectDefList'] if completion.startswith(methodName)] + + def getCompletion(self, text_value): + if text_value == None: + return None + + dotPos = text_value.find(".") + if dotPos == -1: + objName = text_value + completionList = self.getObjectCompletion(objName) + else: + objName = text_value[:dotPos] + methodName = text_value[dotPos + 1:] + completionList = self.getMethodCompletion(objName, methodName) + + if completionList is None or len(completionList) <= 0: + self.interpreter.setStatementsFinished("", False) + else: + result = json.dumps(list(filter(lambda x : not re.match("^__.*", x), list(completionList)))) + self.interpreter.setStatementsFinished(result, False) + +host = sys.argv[1] +port = int(sys.argv[2]) + +client = GatewayClient(address=host, port=port) +gateway = JavaGateway(client, auto_convert = True) +intp = gateway.entry_point +# redirect stdout/stderr to java side so that PythonInterpreter can capture the python execution result +output = Logger() +sys.stdout = output +sys.stderr = output _zcUserQueryNameSpace = {} -client = GatewayClient(address=host, port=int(sys.argv[1])) - -gateway = JavaGateway(client) - -intp = gateway.entry_point -intp.onPythonScriptInitialized(os.getpid()) -java_import(gateway.jvm, "org.apache.zeppelin.display.Input") +completion = PythonCompletion(intp, _zcUserQueryNameSpace) +_zcUserQueryNameSpace["__zeppelin_completion__"] = completion +_zcUserQueryNameSpace["gateway"] = gateway from zeppelin_context import PyZeppelinContext +if intp.getZeppelinContext(): + z = __zeppelin__ = PyZeppelinContext(intp.getZeppelinContext(), gateway) + __zeppelin__._setup_matplotlib() + _zcUserQueryNameSpace["z"] = z + _zcUserQueryNameSpace["__zeppelin__"] = __zeppelin__ -z = __zeppelin__ = PyZeppelinContext(intp.getZeppelinContext(), gateway) -__zeppelin__._setup_matplotlib() - -_zcUserQueryNameSpace["__zeppelin__"] = __zeppelin__ -_zcUserQueryNameSpace["z"] = z - -output = Logger() -sys.stdout = output -#sys.stderr = output +intp.onPythonScriptInitialized(os.getpid()) while True : req = intp.getStatements() - if req == None: - break - try: stmts = req.statements().split("\n") - final_code = [] + isForCompletion = req.isForCompletion() # Get post-execute hooks try: @@ -98,35 +119,23 @@ def handler_stop_signals(sig, frame): user_hook = __zeppelin__.getHook('post_exec') except: user_hook = None - - nhooks = 0 - for hook in (global_hook, user_hook): - if hook: - nhooks += 1 - for s in stmts: - if s == None: - continue - - # skip comment - s_stripped = s.strip() - if len(s_stripped) == 0 or s_stripped.startswith("#"): - continue - - final_code.append(s) + nhooks = 0 + if not isForCompletion: + for hook in (global_hook, user_hook): + if hook: + nhooks += 1 - if final_code: + if stmts: # use exec mode to compile the statements except the last statement, # so that the last statement's evaluation will be printed to stdout - code = compile('\n'.join(final_code), '', 'exec', ast.PyCF_ONLY_AST, 1) - + code = compile('\n'.join(stmts), '', 'exec', ast.PyCF_ONLY_AST, 1) to_run_hooks = [] if (nhooks > 0): to_run_hooks = code.body[-nhooks:] to_run_exec, to_run_single = (code.body[:-(nhooks + 1)], [code.body[-(nhooks + 1)]]) - try: for node in to_run_exec: mod = ast.Module([node]) @@ -142,19 +151,37 @@ def handler_stop_signals(sig, frame): mod = ast.Module([node]) code = compile(mod, '', 'exec') exec(code, _zcUserQueryNameSpace) + + if not isForCompletion: + # only call it when it is not for code completion. code completion will call it in + # PythonCompletion.getCompletion + intp.setStatementsFinished("", False) + except Py4JJavaError: + # raise it to outside try except + raise except: - raise Exception(traceback.format_exc()) + if not isForCompletion: + # extract which line incur error from error message. e.g. + # Traceback (most recent call last): + # File "", line 1, in + # ZeroDivisionError: integer division or modulo by zero + exception = traceback.format_exc() + m = re.search("File \"\", line (\d+).*", exception) + if m: + line_no = int(m.group(1)) + intp.setStatementsFinished( + "Fail to execute line {}: {}\n".format(line_no, stmts[line_no - 1]) + exception, True) + else: + intp.setStatementsFinished(exception, True) + else: + intp.setStatementsFinished("", False) - intp.setStatementsFinished("", False) except Py4JJavaError: excInnerError = traceback.format_exc() # format_tb() does not return the inner exception innerErrorStart = excInnerError.find("Py4JJavaError:") if innerErrorStart > -1: - excInnerError = excInnerError[innerErrorStart:] + excInnerError = excInnerError[innerErrorStart:] intp.setStatementsFinished(excInnerError + str(sys.exc_info()), True) - except Py4JNetworkError: - # lost connection from gateway server. exit - sys.exit(1) except: intp.setStatementsFinished(traceback.format_exc(), True) diff --git a/python/src/test/java/org/apache/zeppelin/python/BasePythonInterpreterTest.java b/python/src/test/java/org/apache/zeppelin/python/BasePythonInterpreterTest.java new file mode 100644 index 00000000000..9bedd53f726 --- /dev/null +++ b/python/src/test/java/org/apache/zeppelin/python/BasePythonInterpreterTest.java @@ -0,0 +1,331 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zeppelin.python; + +import org.apache.zeppelin.display.GUI; +import org.apache.zeppelin.display.ui.CheckBox; +import org.apache.zeppelin.display.ui.Select; +import org.apache.zeppelin.display.ui.TextBox; +import org.apache.zeppelin.interpreter.Interpreter; +import org.apache.zeppelin.interpreter.InterpreterContext; +import org.apache.zeppelin.interpreter.InterpreterException; +import org.apache.zeppelin.interpreter.InterpreterGroup; +import org.apache.zeppelin.interpreter.InterpreterOutput; +import org.apache.zeppelin.interpreter.InterpreterResult; +import org.apache.zeppelin.interpreter.InterpreterResultMessage; +import org.apache.zeppelin.interpreter.remote.RemoteEventClient; +import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion; +import org.apache.zeppelin.user.AuthenticationInfo; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; + +import static junit.framework.TestCase.assertTrue; +import static org.junit.Assert.assertEquals; + +public abstract class BasePythonInterpreterTest { + + protected InterpreterGroup intpGroup; + protected Interpreter interpreter; + + @Before + public abstract void setUp() throws InterpreterException; + + @After + public abstract void tearDown() throws InterpreterException; + + + @Test + public void testPythonBasics() throws InterpreterException, InterruptedException, IOException { + + InterpreterContext context = getInterpreterContext(); + InterpreterResult result = interpreter.interpret("import sys\nprint(sys.version[0])", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + Thread.sleep(100); + List interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + + // single output without print + context = getInterpreterContext(); + result = interpreter.interpret("'hello world'", context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + assertEquals("'hello world'", interpreterResultMessages.get(0).getData().trim()); + + // unicode + context = getInterpreterContext(); + result = interpreter.interpret("print(u'你好')", context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + assertEquals("你好\n", interpreterResultMessages.get(0).getData()); + + // only the last statement is printed + context = getInterpreterContext(); + result = interpreter.interpret("'hello world'\n'hello world2'", context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + assertEquals("'hello world2'", interpreterResultMessages.get(0).getData().trim()); + + // single output + context = getInterpreterContext(); + result = interpreter.interpret("print('hello world')", context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + assertEquals("hello world\n", interpreterResultMessages.get(0).getData()); + + // multiple output + context = getInterpreterContext(); + result = interpreter.interpret("print('hello world')\nprint('hello world2')", context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + assertEquals("hello world\nhello world2\n", interpreterResultMessages.get(0).getData()); + + // assignment + context = getInterpreterContext(); + result = interpreter.interpret("abc=1",context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(0, interpreterResultMessages.size()); + + // if block + context = getInterpreterContext(); + result = interpreter.interpret("if abc > 0:\n\tprint('True')\nelse:\n\tprint('False')", context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + assertEquals("True\n", interpreterResultMessages.get(0).getData()); + + // for loop + context = getInterpreterContext(); + result = interpreter.interpret("for i in range(3):\n\tprint(i)", context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + assertEquals("0\n1\n2\n", interpreterResultMessages.get(0).getData()); + + // syntax error + context = getInterpreterContext(); + result = interpreter.interpret("print(unknown)", context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.ERROR, result.code()); + if (interpreter instanceof IPythonInterpreter) { + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + assertTrue(interpreterResultMessages.get(0).getData().contains("name 'unknown' is not defined")); + } else if (interpreter instanceof PythonInterpreter) { + assertTrue(result.message().get(0).getData().contains("name 'unknown' is not defined")); + } + + // raise runtime exception + context = getInterpreterContext(); + result = interpreter.interpret("1/0", context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.ERROR, result.code()); + if (interpreter instanceof IPythonInterpreter) { + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + assertTrue(interpreterResultMessages.get(0).getData().contains("ZeroDivisionError")); + } else if (interpreter instanceof PythonInterpreter) { + assertTrue(result.message().get(0).getData().contains("ZeroDivisionError")); + } + + // ZEPPELIN-1133 + context = getInterpreterContext(); + result = interpreter.interpret( + "from __future__ import print_function\n" + + "def greet(name):\n" + + " print('Hello', name)\n" + + "greet('Jack')", context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + assertEquals("Hello Jack\n",interpreterResultMessages.get(0).getData()); + + // ZEPPELIN-1114 + context = getInterpreterContext(); + result = interpreter.interpret("print('there is no Error: ok')", context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + assertEquals("there is no Error: ok\n", interpreterResultMessages.get(0).getData()); + } + + @Test + public void testCodeCompletion() throws InterpreterException, IOException, InterruptedException { + // there's no completion for 'a.' because it is not recognized by compiler for now. + InterpreterContext context = getInterpreterContext(); + String st = "a='hello'\na."; + List completions = interpreter.completion(st, st.length(), context); + assertEquals(0, completions.size()); + + // define `a` first + context = getInterpreterContext(); + st = "a='hello'"; + InterpreterResult result = interpreter.interpret(st, context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + + // now we can get the completion for `a.` + context = getInterpreterContext(); + st = "a."; + completions = interpreter.completion(st, st.length(), context); + // it is different for python2 and python3 and may even different for different minor version + // so only verify it is larger than 20 + assertTrue(completions.size() > 20); + + context = getInterpreterContext(); + st = "a.co"; + completions = interpreter.completion(st, st.length(), context); + assertEquals(1, completions.size()); + assertEquals("count", completions.get(0).getValue()); + + // cursor is in the middle of code + context = getInterpreterContext(); + st = "a.co\b='hello"; + completions = interpreter.completion(st, 4, context); + assertEquals(1, completions.size()); + assertEquals("count", completions.get(0).getValue()); + } + + @Test + public void testZeppelinContext() throws InterpreterException, InterruptedException, IOException { + // TextBox + InterpreterContext context = getInterpreterContext(); + InterpreterResult result = interpreter.interpret("z.input(name='text_1', defaultValue='value_1')", context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + List interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertTrue(interpreterResultMessages.get(0).getData().contains("'value_1'")); + assertEquals(1, context.getGui().getForms().size()); + assertTrue(context.getGui().getForms().get("text_1") instanceof TextBox); + TextBox textbox = (TextBox) context.getGui().getForms().get("text_1"); + assertEquals("text_1", textbox.getName()); + assertEquals("value_1", textbox.getDefaultValue()); + + // Select + context = getInterpreterContext(); + result = interpreter.interpret("z.select(name='select_1', options=[('value_1', 'name_1'), ('value_2', 'name_2')])", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + assertEquals(1, context.getGui().getForms().size()); + assertTrue(context.getGui().getForms().get("select_1") instanceof Select); + Select select = (Select) context.getGui().getForms().get("select_1"); + assertEquals("select_1", select.getName()); + assertEquals(2, select.getOptions().length); + assertEquals("name_1", select.getOptions()[0].getDisplayName()); + assertEquals("value_1", select.getOptions()[0].getValue()); + + // CheckBox + context = getInterpreterContext(); + result = interpreter.interpret("z.checkbox(name='checkbox_1', options=[('value_1', 'name_1'), ('value_2', 'name_2')])", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + assertEquals(1, context.getGui().getForms().size()); + assertTrue(context.getGui().getForms().get("checkbox_1") instanceof CheckBox); + CheckBox checkbox = (CheckBox) context.getGui().getForms().get("checkbox_1"); + assertEquals("checkbox_1", checkbox.getName()); + assertEquals(2, checkbox.getOptions().length); + assertEquals("name_1", checkbox.getOptions()[0].getDisplayName()); + assertEquals("value_1", checkbox.getOptions()[0].getValue()); + + // Pandas DataFrame + context = getInterpreterContext(); + result = interpreter.interpret("import pandas as pd\ndf = pd.DataFrame({'id':[1,2,3], 'name':['a','b','c']})\nz.show(df)", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + assertEquals(InterpreterResult.Type.TABLE, interpreterResultMessages.get(0).getType()); + assertEquals("id\tname\n1\ta\n2\tb\n3\tc\n", interpreterResultMessages.get(0).getData()); + + context = getInterpreterContext(); + result = interpreter.interpret("import pandas as pd\ndf = pd.DataFrame({'id':[1,2,3,4], 'name':['a','b','c', 'd']})\nz.show(df)", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(2, interpreterResultMessages.size()); + assertEquals(InterpreterResult.Type.TABLE, interpreterResultMessages.get(0).getType()); + assertEquals("id\tname\n1\ta\n2\tb\n3\tc\n", interpreterResultMessages.get(0).getData()); + assertEquals(InterpreterResult.Type.HTML, interpreterResultMessages.get(1).getType()); + assertEquals("Results are limited by 3.\n", interpreterResultMessages.get(1).getData()); + + // z.show(matplotlib) + context = getInterpreterContext(); + result = interpreter.interpret("import matplotlib.pyplot as plt\ndata=[1,1,2,3,4]\nplt.figure()\nplt.plot(data)\nz.show(plt)", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + assertEquals(InterpreterResult.Type.HTML, interpreterResultMessages.get(0).getType()); + + // clear output + context = getInterpreterContext(); + result = interpreter.interpret("import time\nprint(\"Hello\")\ntime.sleep(0.5)\nz.getInterpreterContext().out().clear()\nprint(\"world\")\n", context); + assertEquals("%text world\n", context.out.getCurrentOutput().toString()); + } + + @Test + public void testRedefinitionZeppelinContext() throws InterpreterException { + String redefinitionCode = "z = 1\n"; + String restoreCode = "z = __zeppelin__\n"; + String validCode = "z.input(\"test\")\n"; + + assertEquals(InterpreterResult.Code.SUCCESS, interpreter.interpret(validCode, getInterpreterContext()).code()); + assertEquals(InterpreterResult.Code.SUCCESS, interpreter.interpret(redefinitionCode, getInterpreterContext()).code()); + assertEquals(InterpreterResult.Code.ERROR, interpreter.interpret(validCode, getInterpreterContext()).code()); + assertEquals(InterpreterResult.Code.SUCCESS, interpreter.interpret(restoreCode, getInterpreterContext()).code()); + assertEquals(InterpreterResult.Code.SUCCESS, interpreter.interpret(validCode, getInterpreterContext()).code()); + } + + protected InterpreterContext getInterpreterContext() { + return new InterpreterContext( + "noteId", + "paragraphId", + "replName", + "paragraphTitle", + "paragraphText", + new AuthenticationInfo(), + new HashMap(), + new GUI(), + new GUI(), + null, + null, + null, + new InterpreterOutput(null)); + } + + protected InterpreterContext getInterpreterContext(RemoteEventClient mockRemoteEventClient) { + InterpreterContext context = getInterpreterContext(); + context.setClient(mockRemoteEventClient); + return context; + } +} diff --git a/python/src/test/java/org/apache/zeppelin/python/IPythonInterpreterTest.java b/python/src/test/java/org/apache/zeppelin/python/IPythonInterpreterTest.java index f016f091374..9e01d062e06 100644 --- a/python/src/test/java/org/apache/zeppelin/python/IPythonInterpreterTest.java +++ b/python/src/test/java/org/apache/zeppelin/python/IPythonInterpreterTest.java @@ -17,288 +17,64 @@ package org.apache.zeppelin.python; -import org.apache.zeppelin.display.GUI; -import org.apache.zeppelin.display.ui.CheckBox; -import org.apache.zeppelin.display.ui.Select; -import org.apache.zeppelin.display.ui.TextBox; import org.apache.zeppelin.interpreter.Interpreter; import org.apache.zeppelin.interpreter.InterpreterContext; import org.apache.zeppelin.interpreter.InterpreterException; import org.apache.zeppelin.interpreter.InterpreterGroup; -import org.apache.zeppelin.interpreter.InterpreterOutput; -import org.apache.zeppelin.interpreter.InterpreterOutputListener; import org.apache.zeppelin.interpreter.InterpreterResult; import org.apache.zeppelin.interpreter.InterpreterResultMessage; -import org.apache.zeppelin.interpreter.InterpreterResultMessageOutput; -import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion; -import org.apache.zeppelin.user.AuthenticationInfo; -import org.junit.After; -import org.junit.Before; +import org.apache.zeppelin.interpreter.LazyOpenInterpreter; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import java.io.IOException; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.Properties; -import java.util.concurrent.CopyOnWriteArrayList; import static junit.framework.TestCase.assertTrue; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.mockito.Mockito.mock; -public class IPythonInterpreterTest { +public class IPythonInterpreterTest extends BasePythonInterpreterTest { - private static final Logger LOGGER = LoggerFactory.getLogger(IPythonInterpreterTest.class); - private IPythonInterpreter interpreter; - public void startInterpreter(Properties properties) throws InterpreterException { - interpreter = new IPythonInterpreter(properties); - InterpreterGroup mockInterpreterGroup = mock(InterpreterGroup.class); - interpreter.setInterpreterGroup(mockInterpreterGroup); - interpreter.open(); - } - - @After - public void close() throws InterpreterException { - interpreter.close(); - } - - - @Test - public void testIPython() throws IOException, InterruptedException, InterpreterException { + protected Properties initIntpProperties() { Properties properties = new Properties(); properties.setProperty("zeppelin.python.maxResult", "3"); - startInterpreter(properties); - testInterpreter(interpreter); + properties.setProperty("zeppelin.python.gatewayserver_address", "127.0.0.1"); + return properties; } - @Test - public void testGrpcFrameSize() throws InterpreterException, IOException { - Properties properties = new Properties(); - properties.setProperty("zeppelin.ipython.grpc.message_size", "200"); - startInterpreter(properties); - - // to make this test can run under both python2 and python3 - InterpreterResult result = interpreter.interpret("from __future__ import print_function", getInterpreterContext()); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - - InterpreterContext context = getInterpreterContext(); - result = interpreter.interpret("print('1'*300)", context); - assertEquals(InterpreterResult.Code.ERROR, result.code()); - List interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(1, interpreterResultMessages.size()); - assertTrue(interpreterResultMessages.get(0).getData().contains("Frame size 304 exceeds maximum: 200")); - - // next call continue work - result = interpreter.interpret("print(1)", context); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + protected void startInterpreter(Properties properties) throws InterpreterException { + interpreter = new LazyOpenInterpreter(new IPythonInterpreter(properties)); + intpGroup = new InterpreterGroup(); + intpGroup.put("session_1", new ArrayList()); + intpGroup.get("session_1").add(interpreter); + interpreter.setInterpreterGroup(intpGroup); - close(); + interpreter.open(); + } - // increase framesize to make it work - properties.setProperty("zeppelin.ipython.grpc.message_size", "500"); + @Override + public void setUp() throws InterpreterException { + Properties properties = initIntpProperties(); startInterpreter(properties); - // to make this test can run under both python2 and python3 - result = interpreter.interpret("from __future__ import print_function", getInterpreterContext()); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - - context = getInterpreterContext(); - result = interpreter.interpret("print('1'*300)", context); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); } - public static void testInterpreter(final Interpreter interpreter) throws IOException, InterruptedException, InterpreterException { - // to make this test can run under both python2 and python3 - InterpreterResult result = interpreter.interpret("from __future__ import print_function", getInterpreterContext()); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - - - InterpreterContext context = getInterpreterContext(); - result = interpreter.interpret("import sys\nprint(sys.version[0])", context); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - Thread.sleep(100); - List interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(1, interpreterResultMessages.size()); - boolean isPython2 = interpreterResultMessages.get(0).getData().equals("2\n"); - - // single output without print - context = getInterpreterContext(); - result = interpreter.interpret("'hello world'", context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(1, interpreterResultMessages.size()); - assertEquals("'hello world'", interpreterResultMessages.get(0).getData()); - - // unicode - context = getInterpreterContext(); - result = interpreter.interpret("print(u'你好')", context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(1, interpreterResultMessages.size()); - assertEquals("你好\n", interpreterResultMessages.get(0).getData()); - - // only the last statement is printed - context = getInterpreterContext(); - result = interpreter.interpret("'hello world'\n'hello world2'", context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(1, interpreterResultMessages.size()); - assertEquals("'hello world2'", interpreterResultMessages.get(0).getData()); - - // single output - context = getInterpreterContext(); - result = interpreter.interpret("print('hello world')", context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(1, interpreterResultMessages.size()); - assertEquals("hello world\n", interpreterResultMessages.get(0).getData()); - - // multiple output - context = getInterpreterContext(); - result = interpreter.interpret("print('hello world')\nprint('hello world2')", context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(1, interpreterResultMessages.size()); - assertEquals("hello world\nhello world2\n", interpreterResultMessages.get(0).getData()); - - // assignment - context = getInterpreterContext(); - result = interpreter.interpret("abc=1",context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(0, interpreterResultMessages.size()); - - // if block - context = getInterpreterContext(); - result = interpreter.interpret("if abc > 0:\n\tprint('True')\nelse:\n\tprint('False')", context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(1, interpreterResultMessages.size()); - assertEquals("True\n", interpreterResultMessages.get(0).getData()); - - // for loop - context = getInterpreterContext(); - result = interpreter.interpret("for i in range(3):\n\tprint(i)", context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(1, interpreterResultMessages.size()); - assertEquals("0\n1\n2\n", interpreterResultMessages.get(0).getData()); - - // syntax error - context = getInterpreterContext(); - result = interpreter.interpret("print(unknown)", context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.ERROR, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(1, interpreterResultMessages.size()); - assertTrue(interpreterResultMessages.get(0).getData().contains("name 'unknown' is not defined")); - - // raise runtime exception - context = getInterpreterContext(); - result = interpreter.interpret("1/0", context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.ERROR, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(1, interpreterResultMessages.size()); - assertTrue(interpreterResultMessages.get(0).getData().contains("ZeroDivisionError")); - - // ZEPPELIN-1133 - context = getInterpreterContext(); - result = interpreter.interpret("def greet(name):\n" + - " print('Hello', name)\n" + - "greet('Jack')", context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(1, interpreterResultMessages.size()); - assertEquals("Hello Jack\n",interpreterResultMessages.get(0).getData()); - - // ZEPPELIN-1114 - context = getInterpreterContext(); - result = interpreter.interpret("print('there is no Error: ok')", context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(1, interpreterResultMessages.size()); - assertEquals("there is no Error: ok\n", interpreterResultMessages.get(0).getData()); - - // completion - context = getInterpreterContext(); - List completions = interpreter.completion("ab", 2, context); - assertEquals(2, completions.size()); - assertEquals("abc", completions.get(0).getValue()); - assertEquals("abs", completions.get(1).getValue()); - - context = getInterpreterContext(); - interpreter.interpret("import sys", context); - completions = interpreter.completion("sys.", 4, context); - assertFalse(completions.isEmpty()); - - context = getInterpreterContext(); - completions = interpreter.completion("sys.std", 7, context); - for (InterpreterCompletion completion : completions) { - System.out.println(completion.getValue()); - } - assertEquals(3, completions.size()); - assertEquals("stderr", completions.get(0).getValue()); - assertEquals("stdin", completions.get(1).getValue()); - assertEquals("stdout", completions.get(2).getValue()); - - // there's no completion for 'a.' because it is not recognized by compiler for now. - context = getInterpreterContext(); - String st = "a='hello'\na."; - completions = interpreter.completion(st, st.length(), context); - assertEquals(0, completions.size()); - - // define `a` first - context = getInterpreterContext(); - st = "a='hello'"; - result = interpreter.interpret(st, context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(0, interpreterResultMessages.size()); - - // now we can get the completion for `a.` - context = getInterpreterContext(); - st = "a."; - completions = interpreter.completion(st, st.length(), context); - // it is different for python2 and python3 and may even different for different minor version - // so only verify it is larger than 20 - assertTrue(completions.size() > 20); - - context = getInterpreterContext(); - st = "a.co"; - completions = interpreter.completion(st, st.length(), context); - assertEquals(1, completions.size()); - assertEquals("count", completions.get(0).getValue()); - - // cursor is in the middle of code - context = getInterpreterContext(); - st = "a.co\b='hello"; - completions = interpreter.completion(st, 4, context); - assertEquals(1, completions.size()); - assertEquals("count", completions.get(0).getValue()); + @Override + public void tearDown() throws InterpreterException { + intpGroup.close(); + } + @Test + public void testIPythonAdvancedFeatures() throws InterpreterException, InterruptedException, IOException { // ipython help - context = getInterpreterContext(); - result = interpreter.interpret("range?", context); + InterpreterContext context = getInterpreterContext(); + InterpreterResult result = interpreter.interpret("range?", context); Thread.sleep(100); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); + List interpreterResultMessages = context.out.toInterpreterResultMessage(); assertTrue(interpreterResultMessages.get(0).getData().contains("range(stop)")); // timeit @@ -331,13 +107,16 @@ public void run() { assertEquals(InterpreterResult.Code.ERROR, result.code()); interpreterResultMessages = context2.out.toInterpreterResultMessage(); assertTrue(interpreterResultMessages.get(0).getData().contains("KeyboardInterrupt")); + } + @Test + public void testIPythonPlotting() throws InterpreterException, InterruptedException, IOException { // matplotlib - context = getInterpreterContext(); - result = interpreter.interpret("%matplotlib inline\nimport matplotlib.pyplot as plt\ndata=[1,1,2,3,4]\nplt.figure()\nplt.plot(data)", context); + InterpreterContext context = getInterpreterContext(); + InterpreterResult result = interpreter.interpret("%matplotlib inline\nimport matplotlib.pyplot as plt\ndata=[1,1,2,3,4]\nplt.figure()\nplt.plot(data)", context); Thread.sleep(100); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); + List interpreterResultMessages = context.out.toInterpreterResultMessage(); // the order of IMAGE and TEXT is not determined // check there must be one IMAGE output boolean hasImageOutput = false; @@ -411,94 +190,44 @@ public void run() { } } assertTrue("No Image Output", hasImageOutput); + } - // ZeppelinContext + @Test + public void testGrpcFrameSize() throws InterpreterException, IOException { + tearDown(); - // TextBox - context = getInterpreterContext(); - result = interpreter.interpret("z.input(name='text_1', defaultValue='value_1')", context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertTrue(interpreterResultMessages.get(0).getData().contains("'value_1'")); - assertEquals(1, context.getGui().getForms().size()); - assertTrue(context.getGui().getForms().get("text_1") instanceof TextBox); - TextBox textbox = (TextBox) context.getGui().getForms().get("text_1"); - assertEquals("text_1", textbox.getName()); - assertEquals("value_1", textbox.getDefaultValue()); + Properties properties = initIntpProperties(); + properties.setProperty("zeppelin.ipython.grpc.message_size", "3000"); - // Select - context = getInterpreterContext(); - result = interpreter.interpret("z.select(name='select_1', options=[('value_1', 'name_1'), ('value_2', 'name_2')])", context); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - assertEquals(1, context.getGui().getForms().size()); - assertTrue(context.getGui().getForms().get("select_1") instanceof Select); - Select select = (Select) context.getGui().getForms().get("select_1"); - assertEquals("select_1", select.getName()); - assertEquals(2, select.getOptions().length); - assertEquals("name_1", select.getOptions()[0].getDisplayName()); - assertEquals("value_1", select.getOptions()[0].getValue()); + startInterpreter(properties); - // CheckBox - context = getInterpreterContext(); - result = interpreter.interpret("z.checkbox(name='checkbox_1', options=[('value_1', 'name_1'), ('value_2', 'name_2')])", context); + // to make this test can run under both python2 and python3 + InterpreterResult result = interpreter.interpret("from __future__ import print_function", getInterpreterContext()); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - assertEquals(1, context.getGui().getForms().size()); - assertTrue(context.getGui().getForms().get("checkbox_1") instanceof CheckBox); - CheckBox checkbox = (CheckBox) context.getGui().getForms().get("checkbox_1"); - assertEquals("checkbox_1", checkbox.getName()); - assertEquals(2, checkbox.getOptions().length); - assertEquals("name_1", checkbox.getOptions()[0].getDisplayName()); - assertEquals("value_1", checkbox.getOptions()[0].getValue()); - // Pandas DataFrame - context = getInterpreterContext(); - result = interpreter.interpret("import pandas as pd\ndf = pd.DataFrame({'id':[1,2,3], 'name':['a','b','c']})\nz.show(df)", context); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); + InterpreterContext context = getInterpreterContext(); + result = interpreter.interpret("print('1'*3000)", context); + assertEquals(InterpreterResult.Code.ERROR, result.code()); + List interpreterResultMessages = context.out.toInterpreterResultMessage(); assertEquals(1, interpreterResultMessages.size()); - assertEquals(InterpreterResult.Type.TABLE, interpreterResultMessages.get(0).getType()); - assertEquals("id\tname\n1\ta\n2\tb\n3\tc\n", interpreterResultMessages.get(0).getData()); + assertTrue(interpreterResultMessages.get(0).getData().contains("exceeds maximum: 3000")); - context = getInterpreterContext(); - result = interpreter.interpret("import pandas as pd\ndf = pd.DataFrame({'id':[1,2,3,4], 'name':['a','b','c', 'd']})\nz.show(df)", context); + // next call continue work + result = interpreter.interpret("print(1)", context); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(2, interpreterResultMessages.size()); - assertEquals(InterpreterResult.Type.TABLE, interpreterResultMessages.get(0).getType()); - assertEquals("id\tname\n1\ta\n2\tb\n3\tc\n", interpreterResultMessages.get(0).getData()); - assertEquals(InterpreterResult.Type.HTML, interpreterResultMessages.get(1).getType()); - assertEquals("Results are limited by 3.\n", interpreterResultMessages.get(1).getData()); - // z.show(matplotlib) - context = getInterpreterContext(); - result = interpreter.interpret("import matplotlib.pyplot as plt\ndata=[1,1,2,3,4]\nplt.figure()\nplt.plot(data)\nz.show(plt)", context); + tearDown(); + + // increase framesize to make it work + properties.setProperty("zeppelin.ipython.grpc.message_size", "5000"); + startInterpreter(properties); + // to make this test can run under both python2 and python3 + result = interpreter.interpret("from __future__ import print_function", getInterpreterContext()); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(2, interpreterResultMessages.size()); - assertEquals(InterpreterResult.Type.HTML, interpreterResultMessages.get(0).getType()); - assertEquals(InterpreterResult.Type.IMG, interpreterResultMessages.get(1).getType()); - // clear output context = getInterpreterContext(); - result = interpreter.interpret("import time\nprint(\"Hello\")\ntime.sleep(0.5)\nz.getInterpreterContext().out().clear()\nprint(\"world\")\n", context); - assertEquals("%text world\n", context.out.getCurrentOutput().toString()); + result = interpreter.interpret("print('1'*3000)", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); } - private static InterpreterContext getInterpreterContext() { - return new InterpreterContext( - "noteId", - "paragraphId", - "replName", - "paragraphTitle", - "paragraphText", - new AuthenticationInfo(), - new HashMap(), - new GUI(), - new GUI(), - null, - null, - null, - new InterpreterOutput(null)); - } } diff --git a/python/src/test/java/org/apache/zeppelin/python/PythonCondaInterpreterTest.java b/python/src/test/java/org/apache/zeppelin/python/PythonCondaInterpreterTest.java index c750352a81c..f1be1b94a63 100644 --- a/python/src/test/java/org/apache/zeppelin/python/PythonCondaInterpreterTest.java +++ b/python/src/test/java/org/apache/zeppelin/python/PythonCondaInterpreterTest.java @@ -39,7 +39,9 @@ public class PythonCondaInterpreterTest { @Before public void setUp() throws InterpreterException { conda = spy(new PythonCondaInterpreter(new Properties())); + when(conda.getClassName()).thenReturn(PythonCondaInterpreter.class.getName()); python = mock(PythonInterpreter.class); + when(python.getClassName()).thenReturn(PythonInterpreter.class.getName()); InterpreterGroup group = new InterpreterGroup(); group.put("note", Arrays.asList(python, conda)); @@ -79,7 +81,7 @@ public void testActivateEnv() throws IOException, InterruptedException, Interpre conda.interpret("activate " + envname, context); verify(python, times(1)).open(); verify(python, times(1)).close(); - verify(python).setPythonCommand("/path1/bin/python"); + verify(python).setPythonExec("/path1/bin/python"); assertTrue(envname.equals(conda.getCurrentCondaEnvName())); } @@ -89,7 +91,7 @@ public void testDeactivate() throws InterpreterException { conda.interpret("deactivate", context); verify(python, times(1)).open(); verify(python, times(1)).close(); - verify(python).setPythonCommand("python"); + verify(python).setPythonExec("python"); assertTrue(conda.getCurrentCondaEnvName().isEmpty()); } diff --git a/python/src/test/java/org/apache/zeppelin/python/PythonDockerInterpreterTest.java b/python/src/test/java/org/apache/zeppelin/python/PythonDockerInterpreterTest.java index 56346304530..17f6cc1c9a8 100644 --- a/python/src/test/java/org/apache/zeppelin/python/PythonDockerInterpreterTest.java +++ b/python/src/test/java/org/apache/zeppelin/python/PythonDockerInterpreterTest.java @@ -17,24 +17,27 @@ package org.apache.zeppelin.python; import org.apache.zeppelin.display.GUI; -import org.apache.zeppelin.interpreter.*; +import org.apache.zeppelin.interpreter.InterpreterContext; +import org.apache.zeppelin.interpreter.InterpreterException; +import org.apache.zeppelin.interpreter.InterpreterGroup; +import org.apache.zeppelin.interpreter.InterpreterOutput; import org.apache.zeppelin.user.AuthenticationInfo; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; -import java.io.IOException; -import java.net.Inet4Address; -import java.net.UnknownHostException; +import java.io.File; import java.util.Arrays; import java.util.HashMap; import java.util.Properties; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyString; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; public class PythonDockerInterpreterTest { private PythonDockerInterpreter docker; @@ -52,7 +55,7 @@ public void setUp() throws InterpreterException { doReturn(true).when(docker).pull(any(InterpreterOutput.class), anyString()); doReturn(python).when(docker).getPythonInterpreter(); - doReturn("/scriptpath/zeppelin_python.py").when(python).getScriptPath(); + doReturn(new File("/scriptpath")).when(python).getPythonWorkDir(); docker.open(); } @@ -64,7 +67,7 @@ public void testActivateEnv() throws InterpreterException { verify(python, times(1)).open(); verify(python, times(1)).close(); verify(docker, times(1)).pull(any(InterpreterOutput.class), anyString()); - verify(python).setPythonCommand(Mockito.matches("docker run -i --rm -v.*")); + verify(python).setPythonExec(Mockito.matches("docker run -i --rm -v.*")); } @Test @@ -73,7 +76,7 @@ public void testDeactivate() throws InterpreterException { docker.interpret("deactivate", context); verify(python, times(1)).open(); verify(python, times(1)).close(); - verify(python).setPythonCommand(null); + verify(python).setPythonExec(null); } private InterpreterContext getInterpreterContext() { diff --git a/python/src/test/java/org/apache/zeppelin/python/PythonInterpreterTest.java b/python/src/test/java/org/apache/zeppelin/python/PythonInterpreterTest.java index c0beccbd9da..e750ddefe73 100644 --- a/python/src/test/java/org/apache/zeppelin/python/PythonInterpreterTest.java +++ b/python/src/test/java/org/apache/zeppelin/python/PythonInterpreterTest.java @@ -17,130 +17,91 @@ package org.apache.zeppelin.python; -import static org.apache.zeppelin.python.PythonInterpreter.DEFAULT_ZEPPELIN_PYTHON; -import static org.apache.zeppelin.python.PythonInterpreter.MAX_RESULT; -import static org.apache.zeppelin.python.PythonInterpreter.ZEPPELIN_PYTHON; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -import java.io.File; -import java.io.IOException; -import java.net.URISyntaxException; -import java.net.URL; -import java.util.HashMap; -import java.util.LinkedList; -import java.util.Map; -import java.util.Properties; - -import org.apache.commons.exec.environment.EnvironmentUtils; -import org.apache.zeppelin.display.AngularObjectRegistry; -import org.apache.zeppelin.display.GUI; import org.apache.zeppelin.interpreter.Interpreter; import org.apache.zeppelin.interpreter.InterpreterContext; -import org.apache.zeppelin.interpreter.InterpreterContextRunner; import org.apache.zeppelin.interpreter.InterpreterException; import org.apache.zeppelin.interpreter.InterpreterGroup; -import org.apache.zeppelin.interpreter.InterpreterOutput; -import org.apache.zeppelin.interpreter.InterpreterOutputListener; import org.apache.zeppelin.interpreter.InterpreterResult; -import org.apache.zeppelin.interpreter.InterpreterResultMessageOutput; -import org.apache.zeppelin.resource.LocalResourcePool; -import org.apache.zeppelin.user.AuthenticationInfo; -import org.junit.After; -import org.junit.Before; +import org.apache.zeppelin.interpreter.LazyOpenInterpreter; import org.junit.Test; -public class PythonInterpreterTest implements InterpreterOutputListener { - PythonInterpreter pythonInterpreter = null; - String cmdHistory; - private InterpreterContext context; - InterpreterOutput out; - - public static Properties getPythonTestProperties() { - Properties p = new Properties(); - p.setProperty(ZEPPELIN_PYTHON, DEFAULT_ZEPPELIN_PYTHON); - p.setProperty(MAX_RESULT, "1000"); - p.setProperty("zeppelin.python.useIPython", "false"); - return p; - } - - @Before - public void beforeTest() throws IOException, InterpreterException { - cmdHistory = ""; - - // python interpreter - pythonInterpreter = new PythonInterpreter(getPythonTestProperties()); - - // create interpreter group - InterpreterGroup group = new InterpreterGroup(); - group.put("note", new LinkedList()); - group.get("note").add(pythonInterpreter); - pythonInterpreter.setInterpreterGroup(group); - - out = new InterpreterOutput(this); +import java.io.IOException; +import java.util.LinkedList; +import java.util.Properties; +import java.util.regex.Matcher; +import java.util.regex.Pattern; - context = new InterpreterContext("note", "id", null, "title", "text", - new AuthenticationInfo(), - new HashMap(), - new GUI(), - new GUI(), - new AngularObjectRegistry(group.getId(), null), - new LocalResourcePool("id"), - new LinkedList(), - out); - InterpreterContext.set(context); - pythonInterpreter.open(); - } +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; - @After - public void afterTest() throws IOException, InterpreterException { - pythonInterpreter.close(); - } +public class PythonInterpreterTest extends BasePythonInterpreterTest { - @Test - public void testInterpret() throws InterruptedException, IOException, InterpreterException { - InterpreterResult result = pythonInterpreter.interpret("print (\"hi\")", context); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - } + @Override + public void setUp() throws InterpreterException { - @Test - public void testInterpretInvalidSyntax() throws IOException, InterpreterException { - InterpreterResult result = pythonInterpreter.interpret("for x in range(0,3): print (\"hi\")\n", context); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - assertTrue(new String(out.getOutputAt(0).toByteArray()).contains("hi\nhi\nhi")); - } + intpGroup = new InterpreterGroup(); - @Test - public void testRedefinitionZeppelinContext() throws InterpreterException { - String pyRedefinitionCode = "z = 1\n"; - String pyRestoreCode = "z = __zeppelin__\n"; - String pyValidCode = "z.input(\"test\")\n"; + Properties properties = new Properties(); + properties.setProperty("zeppelin.python.maxResult", "3"); + properties.setProperty("zeppelin.python.useIPython", "false"); + properties.setProperty("zeppelin.python.gatewayserver_address", "127.0.0.1"); - assertEquals(InterpreterResult.Code.SUCCESS, pythonInterpreter.interpret(pyValidCode, context).code()); - assertEquals(InterpreterResult.Code.SUCCESS, pythonInterpreter.interpret(pyRedefinitionCode, context).code()); - assertEquals(InterpreterResult.Code.ERROR, pythonInterpreter.interpret(pyValidCode, context).code()); - assertEquals(InterpreterResult.Code.SUCCESS, pythonInterpreter.interpret(pyRestoreCode, context).code()); - assertEquals(InterpreterResult.Code.SUCCESS, pythonInterpreter.interpret(pyValidCode, context).code()); - } + interpreter = new LazyOpenInterpreter(new PythonInterpreter(properties)); + intpGroup.put("note", new LinkedList()); + intpGroup.get("note").add(interpreter); + interpreter.setInterpreterGroup(intpGroup); - @Test - public void testOutputClear() throws InterpreterException { - InterpreterResult result = pythonInterpreter.interpret("print(\"Hello\")\nz.getInterpreterContext().out().clear()\nprint(\"world\")\n", context); - assertEquals("%text world\n", out.getCurrentOutput().toString()); + InterpreterContext.set(getInterpreterContext()); + interpreter.open(); } @Override - public void onUpdateAll(InterpreterOutput out) { - + public void tearDown() throws InterpreterException { + intpGroup.close(); } @Override - public void onAppend(int index, InterpreterResultMessageOutput out, byte[] line) { - + public void testCodeCompletion() throws InterpreterException, IOException, InterruptedException { + super.testCodeCompletion(); + + //TODO(zjffdu) PythonInterpreter doesn't support this kind of code completion for now. + // completion + // InterpreterContext context = getInterpreterContext(); + // List completions = interpreter.completion("ab", 2, context); + // assertEquals(2, completions.size()); + // assertEquals("abc", completions.get(0).getValue()); + // assertEquals("abs", completions.get(1).getValue()); } - @Override - public void onUpdate(int index, InterpreterResultMessageOutput out) { + private class infinityPythonJob implements Runnable { + @Override + public void run() { + String code = "import time\nwhile True:\n time.sleep(1)" ; + InterpreterResult ret = null; + try { + ret = interpreter.interpret(code, getInterpreterContext()); + } catch (InterpreterException e) { + e.printStackTrace(); + } + assertNotNull(ret); + Pattern expectedMessage = Pattern.compile("KeyboardInterrupt"); + Matcher m = expectedMessage.matcher(ret.message().toString()); + assertTrue(m.find()); + } + } + @Test + public void testCancelIntp() throws InterruptedException, InterpreterException { + assertEquals(InterpreterResult.Code.SUCCESS, + interpreter.interpret("a = 1\n", getInterpreterContext()).code()); + Thread t = new Thread(new infinityPythonJob()); + t.start(); + Thread.sleep(5000); + interpreter.cancel(getInterpreterContext()); + assertTrue(t.isAlive()); + t.join(2000); + assertFalse(t.isAlive()); } } diff --git a/python/src/test/resources/log4j.properties b/python/src/test/resources/log4j.properties index 035c7a3a6fd..8993ff2854d 100644 --- a/python/src/test/resources/log4j.properties +++ b/python/src/test/resources/log4j.properties @@ -15,18 +15,13 @@ # limitations under the License. # +# Root logger option +log4j.rootLogger=INFO, stdout + # Direct log messages to stdout log4j.appender.stdout=org.apache.log4j.ConsoleAppender -log4j.appender.stdout.Target=System.out log4j.appender.stdout.layout=org.apache.log4j.PatternLayout -log4j.appender.stdout.layout.ConversionPattern=%d{ABSOLUTE} %5p %c:%L - %m%n -#log4j.appender.stdout.layout.ConversionPattern= -#%5p [%t] (%F:%L) - %m%n -#%-4r [%t] %-5p %c %x - %m%n -# +log4j.appender.stdout.layout.ConversionPattern=%5p [%d] ({%t} %F[%M]:%L) - %m%n -# Root logger option -log4j.rootLogger=INFO, stdout -log4j.logger.org.apache.zeppelin.python.IPythonInterpreter=DEBUG -log4j.logger.org.apache.zeppelin.python.IPythonClient=DEBUG -log4j.logger.org.apache.zeppelin.python=DEBUG \ No newline at end of file + +log4j.logger.org.apache.zeppelin.python=DEBUG diff --git a/spark/interpreter/pom.xml b/spark/interpreter/pom.xml index c89cfa6ecbe..5330b1cad1c 100644 --- a/spark/interpreter/pom.xml +++ b/spark/interpreter/pom.xml @@ -441,14 +441,14 @@ 1 false - -Xmx1024m -XX:MaxPermSize=256m + -Xmx1536m -XX:MaxPermSize=256m **/SparkRInterpreterTest.java ${pyspark.test.exclude} ${tests.to.exclude} - ${project.build.directory}/../../../interpreter/spark/pyspark/pyspark.zip:${project.build.directory}/../../../interpreter/lib/python/:${project.build.directory}/../../../interpreter/spark/pyspark/py4j-${py4j.version}-src.zip:. + ${project.build.directory}/../../../interpreter/spark/pyspark/pyspark.zip:${project.build.directory}/../../../interpreter/spark/pyspark/py4j-${py4j.version}-src.zip ${basedir}/../../ diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/IPySparkInterpreter.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/IPySparkInterpreter.java index 3691156e3bf..3896cba53b0 100644 --- a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/IPySparkInterpreter.java +++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/IPySparkInterpreter.java @@ -27,6 +27,7 @@ import org.apache.zeppelin.interpreter.LazyOpenInterpreter; import org.apache.zeppelin.interpreter.WrappedInterpreter; import org.apache.zeppelin.python.IPythonInterpreter; +import org.apache.zeppelin.python.PythonInterpreter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -49,8 +50,8 @@ public IPySparkInterpreter(Properties property) { @Override public void open() throws InterpreterException { - setProperty("zeppelin.python", - PySparkInterpreter.getPythonExec(getProperties())); + PySparkInterpreter pySparkInterpreter = getPySparkInterpreter(); + setProperty("zeppelin.python", pySparkInterpreter.getPythonExec()); sparkInterpreter = getSparkInterpreter(); SparkConf conf = sparkInterpreter.getSparkContext().getConf(); // only set PYTHONPATH in embedded, local or yarn-client mode. @@ -94,6 +95,16 @@ private SparkInterpreter getSparkInterpreter() throws InterpreterException { return spark; } + private PySparkInterpreter getPySparkInterpreter() throws InterpreterException { + PySparkInterpreter pySpark = null; + Interpreter p = getInterpreterInTheSameSessionByClassName(PySparkInterpreter.class.getName()); + while (p instanceof WrappedInterpreter) { + p = ((WrappedInterpreter) p).getInnerInterpreter(); + } + pySpark = (PySparkInterpreter) p; + return pySpark; + } + @Override public BaseZeppelinContext buildZeppelinContext() { return sparkInterpreter.getZeppelinContext(); @@ -117,6 +128,7 @@ public void cancel(InterpreterContext context) throws InterpreterException { @Override public void close() throws InterpreterException { + LOGGER.info("Close IPySparkInterpreter"); super.close(); if (sparkInterpreter != null) { sparkInterpreter.close(); diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/NewSparkInterpreter.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/NewSparkInterpreter.java index c8efa7a7d9f..9b629f9e85c 100644 --- a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/NewSparkInterpreter.java +++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/NewSparkInterpreter.java @@ -56,7 +56,7 @@ */ public class NewSparkInterpreter extends AbstractSparkInterpreter { - private static final Logger LOGGER = LoggerFactory.getLogger(SparkInterpreter.class); + private static final Logger LOGGER = LoggerFactory.getLogger(NewSparkInterpreter.class); private BaseSparkScalaInterpreter innerInterpreter; private Map innerInterpreterClassMap = new HashMap<>(); @@ -177,7 +177,10 @@ private String mergeProperty(String originalValue, String appendedValue) { @Override public void close() { LOGGER.info("Close SparkInterpreter"); - innerInterpreter.close(); + if (innerInterpreter != null) { + innerInterpreter.close(); + innerInterpreter = null; + } } @Override diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java index 809e8832de0..beebd425d80 100644 --- a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java +++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java @@ -30,6 +30,7 @@ import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.SQLContext; +import org.apache.zeppelin.interpreter.BaseZeppelinContext; import org.apache.zeppelin.interpreter.Interpreter; import org.apache.zeppelin.interpreter.InterpreterContext; import org.apache.zeppelin.interpreter.InterpreterException; @@ -44,6 +45,8 @@ import org.apache.zeppelin.interpreter.remote.RemoteInterpreterUtils; import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion; import org.apache.zeppelin.interpreter.util.InterpreterOutputStream; +import org.apache.zeppelin.python.IPythonInterpreter; +import org.apache.zeppelin.python.PythonInterpreter; import org.apache.zeppelin.spark.dep.SparkDependencyContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -68,56 +71,23 @@ * features compared to IPySparkInterpreter, but requires less prerequisites than * IPySparkInterpreter, only python is required. */ -public class PySparkInterpreter extends Interpreter implements ExecuteResultHandler { - private static final Logger LOGGER = LoggerFactory.getLogger(PySparkInterpreter.class); - private static final int MAX_TIMEOUT_SEC = 10; +public class PySparkInterpreter extends PythonInterpreter { + + private static Logger LOGGER = LoggerFactory.getLogger(PySparkInterpreter.class); - private GatewayServer gatewayServer; - private DefaultExecutor executor; - // used to forward output from python process to InterpreterOutput - private InterpreterOutputStream outputStream; - private String scriptPath; - private boolean pythonscriptRunning = false; - private long pythonPid = -1; - private IPySparkInterpreter iPySparkInterpreter; private SparkInterpreter sparkInterpreter; public PySparkInterpreter(Properties property) { super(property); + this.useBuiltinPy4j = false; } @Override public void open() throws InterpreterException { - // try IPySparkInterpreter first - iPySparkInterpreter = getIPySparkInterpreter(); - if (getProperty("zeppelin.pyspark.useIPython", "true").equals("true") && - StringUtils.isEmpty( - iPySparkInterpreter.checkIPythonPrerequisite(getPythonExec(getProperties())))) { - try { - iPySparkInterpreter.open(); - LOGGER.info("IPython is available, Use IPySparkInterpreter to replace PySparkInterpreter"); - return; - } catch (Exception e) { - iPySparkInterpreter = null; - LOGGER.warn("Fail to open IPySparkInterpreter", e); - } - } + setProperty("zeppelin.python.useIPython", getProperty("zeppelin.pyspark.useIPython", "true")); - // reset iPySparkInterpreter to null as it is not available - iPySparkInterpreter = null; - LOGGER.info("IPython is not available, use the native PySparkInterpreter\n"); - // Add matplotlib display hook - InterpreterGroup intpGroup = getInterpreterGroup(); - if (intpGroup != null && intpGroup.getInterpreterHookRegistry() != null) { - try { - // just for unit test I believe (zjffdu) - registerHook(HookType.POST_EXEC_DEV.getName(), "__zeppelin__._displayhook()"); - } catch (InvalidHookException e) { - throw new InterpreterException(e); - } - } + // create SparkInterpreter in JVM side TODO(zjffdu) move to SparkInterpreter DepInterpreter depInterpreter = getDepInterpreter(); - // load libraries from Dependency Interpreter URL [] urls = new URL[0]; List urlList = new LinkedList<>(); @@ -159,474 +129,81 @@ public void open() throws InterpreterException { ClassLoader oldCl = Thread.currentThread().getContextClassLoader(); try { URLClassLoader newCl = new URLClassLoader(urls, oldCl); - LOGGER.info("urls:" + urls); - for (URL url : urls) { - LOGGER.info("url:" + url); - } Thread.currentThread().setContextClassLoader(newCl); + // create Python Process and JVM gateway + super.open(); // must create spark interpreter after ClassLoader is set, otherwise the additional jars // can not be loaded by spark repl. this.sparkInterpreter = getSparkInterpreter(); - createGatewayServerAndStartScript(); - } catch (IOException e) { - LOGGER.error("Fail to open PySparkInterpreter", e); - throw new InterpreterException("Fail to open PySparkInterpreter", e); } finally { Thread.currentThread().setContextClassLoader(oldCl); } - } - - private void createGatewayServerAndStartScript() throws IOException { - // start gateway server in JVM side - int port = RemoteInterpreterUtils.findRandomAvailablePortOnAllLocalInterfaces(); - gatewayServer = new GatewayServer(this, port); - gatewayServer.start(); - - // launch python process to connect to the gateway server in JVM side - createPythonScript(); - String pythonExec = getPythonExec(getProperties()); - LOGGER.info("PythonExec: " + pythonExec); - CommandLine cmd = CommandLine.parse(pythonExec); - cmd.addArgument(scriptPath, false); - cmd.addArgument(Integer.toString(port), false); - cmd.addArgument(Integer.toString(sparkInterpreter.getSparkVersion().toNumber()), false); - executor = new DefaultExecutor(); - outputStream = new InterpreterOutputStream(LOGGER); - PumpStreamHandler streamHandler = new PumpStreamHandler(outputStream); - executor.setStreamHandler(streamHandler); - executor.setWatchdog(new ExecuteWatchdog(ExecuteWatchdog.INFINITE_TIMEOUT)); - - Map env = setupPySparkEnv(); - executor.execute(cmd, env, this); - pythonscriptRunning = true; - } - - private void createPythonScript() throws IOException { - FileOutputStream pysparkScriptOutput = null; - FileOutputStream zeppelinContextOutput = null; - try { - // copy zeppelin_pyspark.py - File scriptFile = File.createTempFile("zeppelin_pyspark-", ".py"); - this.scriptPath = scriptFile.getAbsolutePath(); - pysparkScriptOutput = new FileOutputStream(scriptFile); - IOUtils.copy( - getClass().getClassLoader().getResourceAsStream("python/zeppelin_pyspark.py"), - pysparkScriptOutput); - - // copy zeppelin_context.py to the same folder of zeppelin_pyspark.py - zeppelinContextOutput = new FileOutputStream(scriptFile.getParent() + "/zeppelin_context.py"); - IOUtils.copy( - getClass().getClassLoader().getResourceAsStream("python/zeppelin_context.py"), - zeppelinContextOutput); - LOGGER.info("PySpark script {} {} is created", - scriptPath, scriptFile.getParent() + "/zeppelin_context.py"); - } finally { - if (pysparkScriptOutput != null) { - try { - pysparkScriptOutput.close(); - } catch (IOException e) { - // ignore - } - } - if (zeppelinContextOutput != null) { - try { - zeppelinContextOutput.close(); - } catch (IOException e) { - // ignore - } - } - } - } - - private Map setupPySparkEnv() throws IOException { - Map env = EnvironmentUtils.getProcEnvironment(); - // only set PYTHONPATH in local or yarn-client mode. - // yarn-cluster will setup PYTHONPATH automatically. - SparkConf conf = null; - try { - conf = getSparkConf(); - } catch (InterpreterException e) { - throw new IOException(e); - } - if (!conf.get("spark.submit.deployMode", "client").equals("cluster")) { - if (!env.containsKey("PYTHONPATH")) { - env.put("PYTHONPATH", PythonUtils.sparkPythonPath()); - } else { - env.put("PYTHONPATH", PythonUtils.sparkPythonPath() + ":" + env.get("PYTHONPATH")); - } - } - // get additional class paths when using SPARK_SUBMIT and not using YARN-CLIENT - // also, add all packages to PYTHONPATH since there might be transitive dependencies - if (SparkInterpreter.useSparkSubmit() && - !sparkInterpreter.isYarnMode()) { - String sparkSubmitJars = conf.get("spark.jars").replace(",", ":"); - if (!StringUtils.isEmpty(sparkSubmitJars)) { - env.put("PYTHONPATH", env.get("PYTHONPATH") + ":" + sparkSubmitJars); + if (!useIPython()) { + // Initialize Spark in Python Process + try { + bootstrapInterpreter("python/zeppelin_pyspark.py"); + } catch (IOException e) { + throw new InterpreterException("Fail to bootstrap pyspark", e); } } - - // set PYSPARK_PYTHON - if (conf.contains("spark.pyspark.python")) { - env.put("PYSPARK_PYTHON", conf.get("spark.pyspark.python")); - } - LOGGER.info("PYTHONPATH: " + env.get("PYTHONPATH")); - return env; - } - - // Run python shell - // Choose python in the order of - // PYSPARK_DRIVER_PYTHON > PYSPARK_PYTHON > zeppelin.pyspark.python - public static String getPythonExec(Properties properties) { - String pythonExec = properties.getProperty("zeppelin.pyspark.python", "python"); - if (System.getenv("PYSPARK_PYTHON") != null) { - pythonExec = System.getenv("PYSPARK_PYTHON"); - } - if (System.getenv("PYSPARK_DRIVER_PYTHON") != null) { - pythonExec = System.getenv("PYSPARK_DRIVER_PYTHON"); - } - return pythonExec; } @Override public void close() throws InterpreterException { - if (iPySparkInterpreter != null) { - iPySparkInterpreter.close(); - return; + super.close(); + if (sparkInterpreter != null) { + sparkInterpreter.close(); } - executor.getWatchdog().destroyProcess(); - gatewayServer.shutdown(); } - private PythonInterpretRequest pythonInterpretRequest = null; - private Integer statementSetNotifier = new Integer(0); - private String statementOutput = null; - private boolean statementError = false; - private Integer statementFinishedNotifier = new Integer(0); - - /** - * Request send to Python Daemon - */ - public class PythonInterpretRequest { - public String statements; - public String jobGroup; - public String jobDescription; - public boolean isForCompletion; - - public PythonInterpretRequest(String statements, String jobGroup, - String jobDescription, boolean isForCompletion) { - this.statements = statements; - this.jobGroup = jobGroup; - this.jobDescription = jobDescription; - this.isForCompletion = isForCompletion; - } - - public String statements() { - return statements; - } - - public String jobGroup() { - return jobGroup; - } - - public String jobDescription() { - return jobDescription; - } - - public boolean isForCompletion() { - return isForCompletion; - } - } - - // called by Python Process - public PythonInterpretRequest getStatements() { - synchronized (statementSetNotifier) { - while (pythonInterpretRequest == null) { - try { - statementSetNotifier.wait(1000); - } catch (InterruptedException e) { - } - } - PythonInterpretRequest req = pythonInterpretRequest; - pythonInterpretRequest = null; - return req; - } - } - - // called by Python Process - public void setStatementsFinished(String out, boolean error) { - synchronized (statementFinishedNotifier) { - LOGGER.debug("Setting python statement output: " + out + ", error: " + error); - statementOutput = out; - statementError = error; - statementFinishedNotifier.notify(); - } - } - - private boolean pythonScriptInitialized = false; - private Integer pythonScriptInitializeNotifier = new Integer(0); - - // called by Python Process - public void onPythonScriptInitialized(long pid) { - pythonPid = pid; - synchronized (pythonScriptInitializeNotifier) { - LOGGER.debug("onPythonScriptInitialized is called"); - pythonScriptInitialized = true; - pythonScriptInitializeNotifier.notifyAll(); - } - } - - // called by Python Process - public void appendOutput(String message) throws IOException { - LOGGER.debug("Output from python process: " + message); - outputStream.getInterpreterOutput().write(message); + @Override + protected BaseZeppelinContext createZeppelinContext() { + return sparkInterpreter.getZeppelinContext(); } @Override public InterpreterResult interpret(String st, InterpreterContext context) throws InterpreterException { - if (iPySparkInterpreter != null) { - return iPySparkInterpreter.interpret(st, context); - } - - if (sparkInterpreter.isUnsupportedSparkVersion()) { - return new InterpreterResult(Code.ERROR, "Spark " - + sparkInterpreter.getSparkVersion().toString() + " is not supported"); - } sparkInterpreter.populateSparkWebUrl(context); + return super.interpret(st, context); + } - if (!pythonscriptRunning) { - return new InterpreterResult(Code.ERROR, "python process not running " - + outputStream.toString()); - } - - outputStream.setInterpreterOutput(context.out); - - synchronized (pythonScriptInitializeNotifier) { - long startTime = System.currentTimeMillis(); - while (pythonScriptInitialized == false - && pythonscriptRunning - && System.currentTimeMillis() - startTime < MAX_TIMEOUT_SEC * 1000) { - try { - LOGGER.info("Wait for PythonScript running"); - pythonScriptInitializeNotifier.wait(1000); - } catch (InterruptedException e) { - e.printStackTrace(); - } - } - } - - List errorMessage; - try { - context.out.flush(); - errorMessage = context.out.toInterpreterResultMessage(); - } catch (IOException e) { - throw new InterpreterException(e); - } - - - if (pythonscriptRunning == false) { - // python script failed to initialize and terminated - errorMessage.add(new InterpreterResultMessage( - InterpreterResult.Type.TEXT, "Failed to start PySpark")); - return new InterpreterResult(Code.ERROR, errorMessage); - } - if (pythonScriptInitialized == false) { - // timeout. didn't get initialized message - errorMessage.add(new InterpreterResultMessage( - InterpreterResult.Type.TEXT, "Failed to initialize PySpark")); - return new InterpreterResult(Code.ERROR, errorMessage); - } - - //TODO(zjffdu) remove this as PySpark is supported starting from spark 1.2s - if (!sparkInterpreter.getSparkVersion().isPysparkSupported()) { - errorMessage.add(new InterpreterResultMessage( - InterpreterResult.Type.TEXT, - "pyspark " + sparkInterpreter.getSparkContext().version() + " is not supported")); - return new InterpreterResult(Code.ERROR, errorMessage); - } - + @Override + protected void preCallPython(InterpreterContext context) { String jobGroup = Utils.buildJobGroupId(context); String jobDesc = "Started by: " + Utils.getUserName(context.getAuthenticationInfo()); - - SparkZeppelinContext z = sparkInterpreter.getZeppelinContext(); - z.setInterpreterContext(context); - z.setGui(context.getGui()); - z.setNoteGui(context.getNoteGui()); - InterpreterContext.set(context); - - pythonInterpretRequest = new PythonInterpretRequest(st, jobGroup, jobDesc, false); - statementOutput = null; - - synchronized (statementSetNotifier) { - statementSetNotifier.notify(); - } - - synchronized (statementFinishedNotifier) { - while (statementOutput == null) { - try { - statementFinishedNotifier.wait(1000); - } catch (InterruptedException e) { - } - } - } - - if (statementError) { - return new InterpreterResult(Code.ERROR, statementOutput); - } else { - try { - context.out.flush(); - } catch (IOException e) { - throw new InterpreterException(e); - } - return new InterpreterResult(Code.SUCCESS); - } - } - - public void interrupt() throws IOException, InterpreterException { - if (pythonPid > -1) { - LOGGER.info("Sending SIGINT signal to PID : " + pythonPid); - Runtime.getRuntime().exec("kill -SIGINT " + pythonPid); - } else { - LOGGER.warn("Non UNIX/Linux system, close the interpreter"); - close(); - } + callPython(new PythonInterpretRequest( + String.format("if 'sc' in locals():\n\tsc.setJobGroup('%s', '%s')", jobGroup, jobDesc), + false)); } + // Run python shell + // Choose python in the order of + // PYSPARK_DRIVER_PYTHON > PYSPARK_PYTHON > zeppelin.pyspark.python @Override - public void cancel(InterpreterContext context) throws InterpreterException { - if (iPySparkInterpreter != null) { - iPySparkInterpreter.cancel(context); - return; - } - SparkInterpreter sparkInterpreter = getSparkInterpreter(); - sparkInterpreter.cancel(context); - try { - interrupt(); - } catch (IOException e) { - LOGGER.error("Error", e); + protected String getPythonExec() { + String pythonExec = getProperty("zeppelin.pyspark.python", "python"); + if (System.getenv("PYSPARK_PYTHON") != null) { + pythonExec = System.getenv("PYSPARK_PYTHON"); } - } - - @Override - public FormType getFormType() { - return FormType.NATIVE; - } - - @Override - public int getProgress(InterpreterContext context) throws InterpreterException { - if (iPySparkInterpreter != null) { - return iPySparkInterpreter.getProgress(context); + if (System.getenv("PYSPARK_DRIVER_PYTHON") != null) { + pythonExec = System.getenv("PYSPARK_DRIVER_PYTHON"); } - SparkInterpreter sparkInterpreter = getSparkInterpreter(); - return sparkInterpreter.getProgress(context); + return pythonExec; } - @Override - public List completion(String buf, int cursor, - InterpreterContext interpreterContext) - throws InterpreterException { - if (iPySparkInterpreter != null) { - return iPySparkInterpreter.completion(buf, cursor, interpreterContext); - } - if (buf.length() < cursor) { - cursor = buf.length(); - } - String completionString = getCompletionTargetString(buf, cursor); - String completionCommand = "completion.getCompletion('" + completionString + "')"; - LOGGER.debug("completionCommand: " + completionCommand); - - //start code for completion - if (sparkInterpreter.isUnsupportedSparkVersion() || pythonscriptRunning == false) { - return new LinkedList<>(); - } - - pythonInterpretRequest = new PythonInterpretRequest(completionCommand, "", "", true); - statementOutput = null; - - synchronized (statementSetNotifier) { - statementSetNotifier.notify(); - } - - String[] completionList = null; - synchronized (statementFinishedNotifier) { - long startTime = System.currentTimeMillis(); - while (statementOutput == null - && pythonscriptRunning) { - try { - if (System.currentTimeMillis() - startTime > MAX_TIMEOUT_SEC * 1000) { - LOGGER.error("pyspark completion didn't have response for {}sec.", MAX_TIMEOUT_SEC); - break; - } - statementFinishedNotifier.wait(1000); - } catch (InterruptedException e) { - // not working - LOGGER.info("wait drop"); - return new LinkedList<>(); - } - } - if (statementError) { - return new LinkedList<>(); - } - Gson gson = new Gson(); - completionList = gson.fromJson(statementOutput, String[].class); - } - //end code for completion - if (completionList == null) { - return new LinkedList<>(); - } - - List results = new LinkedList<>(); - for (String name: completionList) { - results.add(new InterpreterCompletion(name, name, StringUtils.EMPTY)); - LOGGER.debug("completion: " + name); - } - return results; - } - - private String getCompletionTargetString(String text, int cursor) { - String[] completionSeqCharaters = {" ", "\n", "\t"}; - int completionEndPosition = cursor; - int completionStartPosition = cursor; - int indexOfReverseSeqPostion = cursor; - - String resultCompletionText = ""; - String completionScriptText = ""; - try { - completionScriptText = text.substring(0, cursor); - } - catch (Exception e) { - LOGGER.error(e.toString()); - return null; - } - completionEndPosition = completionScriptText.length(); - - String tempReverseCompletionText = new StringBuilder(completionScriptText).reverse().toString(); - - for (String seqCharacter : completionSeqCharaters) { - indexOfReverseSeqPostion = tempReverseCompletionText.indexOf(seqCharacter); - - if (indexOfReverseSeqPostion < completionStartPosition && indexOfReverseSeqPostion > 0) { - completionStartPosition = indexOfReverseSeqPostion; - } - - } - - if (completionStartPosition == completionEndPosition) { - completionStartPosition = 0; - } - else - { - completionStartPosition = completionEndPosition - completionStartPosition; + protected IPythonInterpreter getIPythonInterpreter() { + IPySparkInterpreter iPython = null; + Interpreter p = getInterpreterInTheSameSessionByClassName(IPySparkInterpreter.class.getName()); + while (p instanceof WrappedInterpreter) { + p = ((WrappedInterpreter) p).getInnerInterpreter(); } - resultCompletionText = completionScriptText.substring( - completionStartPosition , completionEndPosition); - - return resultCompletionText; + iPython = (IPySparkInterpreter) p; + return iPython; } - private SparkInterpreter getSparkInterpreter() throws InterpreterException { LazyOpenInterpreter lazy = null; SparkInterpreter spark = null; @@ -646,63 +223,45 @@ private SparkInterpreter getSparkInterpreter() throws InterpreterException { return spark; } - private IPySparkInterpreter getIPySparkInterpreter() { - LazyOpenInterpreter lazy = null; - IPySparkInterpreter iPySpark = null; - Interpreter p = getInterpreterInTheSameSessionByClassName(IPySparkInterpreter.class.getName()); - - while (p instanceof WrappedInterpreter) { - if (p instanceof LazyOpenInterpreter) { - lazy = (LazyOpenInterpreter) p; - } - p = ((WrappedInterpreter) p).getInnerInterpreter(); - } - iPySpark = (IPySparkInterpreter) p; - return iPySpark; - } - public SparkZeppelinContext getZeppelinContext() throws InterpreterException { - SparkInterpreter sparkIntp = getSparkInterpreter(); - if (sparkIntp != null) { - return getSparkInterpreter().getZeppelinContext(); + public SparkZeppelinContext getZeppelinContext() { + if (sparkInterpreter != null) { + return sparkInterpreter.getZeppelinContext(); } else { return null; } } - public JavaSparkContext getJavaSparkContext() throws InterpreterException { - SparkInterpreter intp = getSparkInterpreter(); - if (intp == null) { + public JavaSparkContext getJavaSparkContext() { + if (sparkInterpreter == null) { return null; } else { - return new JavaSparkContext(intp.getSparkContext()); + return new JavaSparkContext(sparkInterpreter.getSparkContext()); } } - public Object getSparkSession() throws InterpreterException { - SparkInterpreter intp = getSparkInterpreter(); - if (intp == null) { + public Object getSparkSession() { + if (sparkInterpreter == null) { return null; } else { - return intp.getSparkSession(); + return sparkInterpreter.getSparkSession(); } } - public SparkConf getSparkConf() throws InterpreterException { + public SparkConf getSparkConf() { JavaSparkContext sc = getJavaSparkContext(); if (sc == null) { return null; } else { - return getJavaSparkContext().getConf(); + return sc.getConf(); } } - public SQLContext getSQLContext() throws InterpreterException { - SparkInterpreter intp = getSparkInterpreter(); - if (intp == null) { + public SQLContext getSQLContext() { + if (sparkInterpreter == null) { return null; } else { - return intp.getSQLContext(); + return sparkInterpreter.getSQLContext(); } } @@ -718,21 +277,7 @@ private DepInterpreter getDepInterpreter() { return (DepInterpreter) p; } - - @Override - public void onProcessComplete(int exitValue) { - pythonscriptRunning = false; - LOGGER.info("python process terminated. exit code " + exitValue); - } - - @Override - public void onProcessFailed(ExecuteException e) { - pythonscriptRunning = false; - LOGGER.error("python process failed", e); - } - - // Called by Python Process, used for debugging purpose - public void logPythonOutput(String message) { - LOGGER.debug("Python Process Output: " + message); + public boolean isSpark2() { + return sparkInterpreter.getSparkVersion().newerThanEquals(SparkVersion.SPARK_2_0_0); } } diff --git a/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py b/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py index 1352318cd44..8fcca9b41fa 100644 --- a/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py +++ b/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py @@ -15,150 +15,43 @@ # limitations under the License. # -import os, sys, getopt, traceback, json, re - -from py4j.java_gateway import java_import, JavaGateway, GatewayClient -from py4j.protocol import Py4JJavaError +from py4j.java_gateway import java_import from pyspark.conf import SparkConf from pyspark.context import SparkContext -import ast -import warnings # for back compatibility from pyspark.sql import SQLContext, HiveContext, Row -class Logger(object): - def __init__(self): - pass - - def write(self, message): - intp.appendOutput(message) - - def reset(self): - pass - - def flush(self): - pass - - -class SparkVersion(object): - SPARK_1_4_0 = 10400 - SPARK_1_3_0 = 10300 - SPARK_2_0_0 = 20000 - - def __init__(self, versionNumber): - self.version = versionNumber - - def isAutoConvertEnabled(self): - return self.version >= self.SPARK_1_4_0 - - def isImportAllPackageUnderSparkSql(self): - return self.version >= self.SPARK_1_3_0 - - def isSpark2(self): - return self.version >= self.SPARK_2_0_0 - -class PySparkCompletion: - def __init__(self, interpreterObject): - self.interpreterObject = interpreterObject - - def getGlobalCompletion(self, text_value): - completions = [completion for completion in list(globals().keys()) if completion.startswith(text_value)] - return completions - - def getMethodCompletion(self, objName, methodName): - execResult = locals() - try: - exec("{} = dir({})".format("objectDefList", objName), globals(), execResult) - except: - return None - else: - objectDefList = execResult['objectDefList'] - return [completion for completion in execResult['objectDefList'] if completion.startswith(methodName)] - - def getCompletion(self, text_value): - if text_value == None: - return None - - dotPos = text_value.find(".") - if dotPos == -1: - objName = text_value - completionList = self.getGlobalCompletion(objName) - else: - objName = text_value[:dotPos] - methodName = text_value[dotPos + 1:] - completionList = self.getMethodCompletion(objName, methodName) - - if len(completionList) <= 0: - self.interpreterObject.setStatementsFinished("", False) - else: - result = json.dumps(list(filter(lambda x : not re.match("^__.*", x), list(completionList)))) - self.interpreterObject.setStatementsFinished(result, False) - -client = GatewayClient(port=int(sys.argv[1])) -sparkVersion = SparkVersion(int(sys.argv[2])) -if sparkVersion.isSpark2(): +intp = gateway.entry_point +isSpark2 = intp.isSpark2() +if isSpark2: from pyspark.sql import SparkSession -else: - from pyspark.sql import SchemaRDD - -if sparkVersion.isAutoConvertEnabled(): - gateway = JavaGateway(client, auto_convert = True) -else: - gateway = JavaGateway(client) +jsc = intp.getJavaSparkContext() java_import(gateway.jvm, "org.apache.spark.SparkEnv") java_import(gateway.jvm, "org.apache.spark.SparkConf") java_import(gateway.jvm, "org.apache.spark.api.java.*") java_import(gateway.jvm, "org.apache.spark.api.python.*") java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*") -intp = gateway.entry_point -output = Logger() -sys.stdout = output -sys.stderr = output - -jsc = intp.getJavaSparkContext() - -if sparkVersion.isImportAllPackageUnderSparkSql(): - java_import(gateway.jvm, "org.apache.spark.sql.*") - java_import(gateway.jvm, "org.apache.spark.sql.hive.*") -else: - java_import(gateway.jvm, "org.apache.spark.sql.SQLContext") - java_import(gateway.jvm, "org.apache.spark.sql.hive.HiveContext") - java_import(gateway.jvm, "org.apache.spark.sql.hive.LocalHiveContext") - java_import(gateway.jvm, "org.apache.spark.sql.hive.TestHiveContext") - +java_import(gateway.jvm, "org.apache.spark.sql.*") +java_import(gateway.jvm, "org.apache.spark.sql.hive.*") java_import(gateway.jvm, "scala.Tuple2") -_zcUserQueryNameSpace = {} - jconf = intp.getSparkConf() conf = SparkConf(_jvm = gateway.jvm, _jconf = jconf) sc = _zsc_ = SparkContext(jsc=jsc, gateway=gateway, conf=conf) -_zcUserQueryNameSpace["_zsc_"] = _zsc_ -_zcUserQueryNameSpace["sc"] = sc -if sparkVersion.isSpark2(): + +if isSpark2: spark = __zSpark__ = SparkSession(sc, intp.getSparkSession()) sqlc = __zSqlc__ = __zSpark__._wrapped - _zcUserQueryNameSpace["sqlc"] = sqlc - _zcUserQueryNameSpace["__zSqlc__"] = __zSqlc__ - _zcUserQueryNameSpace["spark"] = spark - _zcUserQueryNameSpace["__zSpark__"] = __zSpark__ + else: sqlc = __zSqlc__ = SQLContext(sparkContext=sc, sqlContext=intp.getSQLContext()) - _zcUserQueryNameSpace["sqlc"] = sqlc - _zcUserQueryNameSpace["__zSqlc__"] = sqlc sqlContext = __zSqlc__ -_zcUserQueryNameSpace["sqlContext"] = sqlContext - -completion = __zeppelin_completion__ = PySparkCompletion(intp) -_zcUserQueryNameSpace["completion"] = completion -_zcUserQueryNameSpace["__zeppelin_completion__"] = __zeppelin_completion__ - from zeppelin_context import PyZeppelinContext @@ -176,92 +69,4 @@ def show(self, obj): super(PySparkZeppelinContext, self).show(obj) z = __zeppelin__ = PySparkZeppelinContext(intp.getZeppelinContext(), gateway) - __zeppelin__._setup_matplotlib() -_zcUserQueryNameSpace["z"] = z -_zcUserQueryNameSpace["__zeppelin__"] = __zeppelin__ - -intp.onPythonScriptInitialized(os.getpid()) - -while True : - req = intp.getStatements() - try: - stmts = req.statements().split("\n") - jobGroup = req.jobGroup() - jobDesc = req.jobDescription() - isForCompletion = req.isForCompletion() - - # Get post-execute hooks - try: - global_hook = intp.getHook('post_exec_dev') - except: - global_hook = None - - try: - user_hook = __zeppelin__.getHook('post_exec') - except: - user_hook = None - - nhooks = 0 - if not isForCompletion: - for hook in (global_hook, user_hook): - if hook: - nhooks += 1 - - if stmts: - # use exec mode to compile the statements except the last statement, - # so that the last statement's evaluation will be printed to stdout - sc.setJobGroup(jobGroup, jobDesc) - code = compile('\n'.join(stmts), '', 'exec', ast.PyCF_ONLY_AST, 1) - to_run_hooks = [] - if (nhooks > 0): - to_run_hooks = code.body[-nhooks:] - - to_run_exec, to_run_single = (code.body[:-(nhooks + 1)], - [code.body[-(nhooks + 1)]]) - try: - for node in to_run_exec: - mod = ast.Module([node]) - code = compile(mod, '', 'exec') - exec(code, _zcUserQueryNameSpace) - - for node in to_run_single: - mod = ast.Interactive([node]) - code = compile(mod, '', 'single') - exec(code, _zcUserQueryNameSpace) - - for node in to_run_hooks: - mod = ast.Module([node]) - code = compile(mod, '', 'exec') - exec(code, _zcUserQueryNameSpace) - - if not isForCompletion: - # only call it when it is not for code completion. code completion will call it in - # PySparkCompletion.getCompletion - intp.setStatementsFinished("", False) - except Py4JJavaError: - # raise it to outside try except - raise - except: - if not isForCompletion: - exception = traceback.format_exc() - m = re.search("File \"\", line (\d+).*", exception) - if m: - line_no = int(m.group(1)) - intp.setStatementsFinished( - "Fail to execute line {}: {}\n".format(line_no, stmts[line_no - 1]) + exception, True) - else: - intp.setStatementsFinished(exception, True) - else: - intp.setStatementsFinished("", False) - - except Py4JJavaError: - excInnerError = traceback.format_exc() # format_tb() does not return the inner exception - innerErrorStart = excInnerError.find("Py4JJavaError:") - if innerErrorStart > -1: - excInnerError = excInnerError[innerErrorStart:] - intp.setStatementsFinished(excInnerError + str(sys.exc_info()), True) - except: - intp.setStatementsFinished(traceback.format_exc(), True) - - output.reset() diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/IPySparkInterpreterTest.java b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/IPySparkInterpreterTest.java index 2cc11ace88a..ece52353ff2 100644 --- a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/IPySparkInterpreterTest.java +++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/IPySparkInterpreterTest.java @@ -27,18 +27,16 @@ import org.apache.zeppelin.interpreter.InterpreterOutput; import org.apache.zeppelin.interpreter.InterpreterResult; import org.apache.zeppelin.interpreter.InterpreterResultMessage; +import org.apache.zeppelin.interpreter.LazyOpenInterpreter; import org.apache.zeppelin.interpreter.remote.RemoteEventClient; import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion; import org.apache.zeppelin.python.IPythonInterpreterTest; import org.apache.zeppelin.user.AuthenticationInfo; -import org.junit.After; -import org.junit.Before; import org.junit.Test; import java.io.IOException; -import java.net.URL; +import java.util.ArrayList; import java.util.HashMap; -import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Properties; @@ -46,65 +44,72 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; -import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; import static org.mockito.Mockito.verify; -public class IPySparkInterpreterTest { +public class IPySparkInterpreterTest extends IPythonInterpreterTest { - private IPySparkInterpreter iPySparkInterpreter; private InterpreterGroup intpGroup; private RemoteEventClient mockRemoteEventClient = mock(RemoteEventClient.class); - @Before - public void setup() throws InterpreterException { + @Override + protected Properties initIntpProperties() { Properties p = new Properties(); p.setProperty("spark.master", "local[4]"); p.setProperty("master", "local[4]"); p.setProperty("spark.submit.deployMode", "client"); p.setProperty("spark.app.name", "Zeppelin Test"); - p.setProperty("zeppelin.spark.useHiveContext", "true"); + p.setProperty("zeppelin.spark.useHiveContext", "false"); p.setProperty("zeppelin.spark.maxResult", "3"); p.setProperty("zeppelin.spark.importImplicit", "true"); + p.setProperty("zeppelin.spark.useNew", "true"); p.setProperty("zeppelin.pyspark.python", "python"); p.setProperty("zeppelin.dep.localrepo", Files.createTempDir().getAbsolutePath()); + p.setProperty("zeppelin.python.gatewayserver_address", "127.0.0.1"); + return p; + } + @Override + protected void startInterpreter(Properties properties) throws InterpreterException { intpGroup = new InterpreterGroup(); - intpGroup.put("session_1", new LinkedList()); + intpGroup.put("session_1", new ArrayList()); - SparkInterpreter sparkInterpreter = new SparkInterpreter(p); + LazyOpenInterpreter sparkInterpreter = new LazyOpenInterpreter( + new SparkInterpreter(properties)); intpGroup.get("session_1").add(sparkInterpreter); sparkInterpreter.setInterpreterGroup(intpGroup); - sparkInterpreter.open(); - sparkInterpreter.getZeppelinContext().setEventClient(mockRemoteEventClient); - iPySparkInterpreter = new IPySparkInterpreter(p); - intpGroup.get("session_1").add(iPySparkInterpreter); - iPySparkInterpreter.setInterpreterGroup(intpGroup); - iPySparkInterpreter.open(); - sparkInterpreter.getZeppelinContext().setEventClient(mockRemoteEventClient); + LazyOpenInterpreter pySparkInterpreter = + new LazyOpenInterpreter(new PySparkInterpreter(properties)); + intpGroup.get("session_1").add(pySparkInterpreter); + pySparkInterpreter.setInterpreterGroup(intpGroup); + + interpreter = new LazyOpenInterpreter(new IPySparkInterpreter(properties)); + intpGroup.get("session_1").add(interpreter); + interpreter.setInterpreterGroup(intpGroup); + + interpreter.open(); } - @After + @Override public void tearDown() throws InterpreterException { - if (iPySparkInterpreter != null) { - iPySparkInterpreter.close(); - } + intpGroup.close(); + interpreter = null; + intpGroup = null; } @Test - public void testBasics() throws InterruptedException, IOException, InterpreterException { - // all the ipython test should pass too. - IPythonInterpreterTest.testInterpreter(iPySparkInterpreter); - testPySpark(iPySparkInterpreter, mockRemoteEventClient); - + public void testIPySpark() throws InterruptedException, InterpreterException, IOException { + testPySpark(interpreter, mockRemoteEventClient); } public static void testPySpark(final Interpreter interpreter, RemoteEventClient mockRemoteEventClient) throws InterpreterException, IOException, InterruptedException { + reset(mockRemoteEventClient); // rdd - InterpreterContext context = getInterpreterContext(mockRemoteEventClient); + InterpreterContext context = createInterpreterContext(mockRemoteEventClient); InterpreterResult result = interpreter.interpret("sc.version", context); Thread.sleep(100); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); @@ -112,17 +117,17 @@ public static void testPySpark(final Interpreter interpreter, RemoteEventClient // spark url is sent verify(mockRemoteEventClient).onMetaInfosReceived(any(Map.class)); - context = getInterpreterContext(mockRemoteEventClient); + context = createInterpreterContext(mockRemoteEventClient); result = interpreter.interpret("sc.range(1,10).sum()", context); Thread.sleep(100); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); List interpreterResultMessages = context.out.toInterpreterResultMessage(); assertEquals("45", interpreterResultMessages.get(0).getData().trim()); // spark job url is sent - verify(mockRemoteEventClient).onParaInfosReceived(any(String.class), any(String.class), any(Map.class)); +// verify(mockRemoteEventClient).onParaInfosReceived(any(String.class), any(String.class), any(Map.class)); // spark sql - context = getInterpreterContext(mockRemoteEventClient); + context = createInterpreterContext(mockRemoteEventClient); if (!isSpark2(sparkVersion)) { result = interpreter.interpret("df = sqlContext.createDataFrame([(1,'a'),(2,'b')])\ndf.show()", context); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); @@ -135,7 +140,7 @@ public static void testPySpark(final Interpreter interpreter, RemoteEventClient "| 2| b|\n" + "+---+---+", interpreterResultMessages.get(0).getData().trim()); - context = getInterpreterContext(mockRemoteEventClient); + context = createInterpreterContext(mockRemoteEventClient); result = interpreter.interpret("z.show(df)", context); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); interpreterResultMessages = context.out.toInterpreterResultMessage(); @@ -155,7 +160,7 @@ public static void testPySpark(final Interpreter interpreter, RemoteEventClient "| 2| b|\n" + "+---+---+", interpreterResultMessages.get(0).getData().trim()); - context = getInterpreterContext(mockRemoteEventClient); + context = createInterpreterContext(mockRemoteEventClient); result = interpreter.interpret("z.show(df)", context); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); interpreterResultMessages = context.out.toInterpreterResultMessage(); @@ -166,7 +171,7 @@ public static void testPySpark(final Interpreter interpreter, RemoteEventClient } // cancel if (interpreter instanceof IPySparkInterpreter) { - final InterpreterContext context2 = getInterpreterContext(mockRemoteEventClient); + final InterpreterContext context2 = createInterpreterContext(mockRemoteEventClient); Thread thread = new Thread() { @Override @@ -196,24 +201,24 @@ public void run() { } // completions - List completions = interpreter.completion("sc.ran", 6, getInterpreterContext(mockRemoteEventClient)); + List completions = interpreter.completion("sc.ran", 6, createInterpreterContext(mockRemoteEventClient)); assertEquals(1, completions.size()); assertEquals("range", completions.get(0).getValue()); - completions = interpreter.completion("sc.", 3, getInterpreterContext(mockRemoteEventClient)); + completions = interpreter.completion("sc.", 3, createInterpreterContext(mockRemoteEventClient)); assertTrue(completions.size() > 0); completions.contains(new InterpreterCompletion("range", "range", "")); - completions = interpreter.completion("1+1\nsc.", 7, getInterpreterContext(mockRemoteEventClient)); + completions = interpreter.completion("1+1\nsc.", 7, createInterpreterContext(mockRemoteEventClient)); assertTrue(completions.size() > 0); completions.contains(new InterpreterCompletion("range", "range", "")); - completions = interpreter.completion("s", 1, getInterpreterContext(mockRemoteEventClient)); + completions = interpreter.completion("s", 1, createInterpreterContext(mockRemoteEventClient)); assertTrue(completions.size() > 0); completions.contains(new InterpreterCompletion("sc", "sc", "")); // pyspark streaming - context = getInterpreterContext(mockRemoteEventClient); + context = createInterpreterContext(mockRemoteEventClient); result = interpreter.interpret( "from pyspark.streaming import StreamingContext\n" + "import time\n" + @@ -239,7 +244,7 @@ private static boolean isSpark2(String sparkVersion) { return sparkVersion.startsWith("'2.") || sparkVersion.startsWith("u'2."); } - private static InterpreterContext getInterpreterContext(RemoteEventClient mockRemoteEventClient) { + private static InterpreterContext createInterpreterContext(RemoteEventClient mockRemoteEventClient) { InterpreterContext context = new InterpreterContext( "noteId", "paragraphId", diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/OldSparkInterpreterTest.java b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/OldSparkInterpreterTest.java index 068ff50c3d8..3a986535c95 100644 --- a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/OldSparkInterpreterTest.java +++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/OldSparkInterpreterTest.java @@ -127,7 +127,7 @@ public void onMetaInfosReceived(Map infos) { new LocalResourcePool("id"), new LinkedList(), new InterpreterOutput(null)) { - + @Override public RemoteEventClientWrapper getClient() { return remoteEventClientWrapper; @@ -192,7 +192,7 @@ public void testNextLineCompanionObject() throws InterpreterException { public void testEndWithComment() throws InterpreterException { assertEquals(InterpreterResult.Code.SUCCESS, repl.interpret("val c=1\n//comment", context).code()); } - + @Test public void testCreateDataFrame() throws InterpreterException { if (getSparkVersionNumber(repl) >= 13) { diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java index e228c7ed778..446f183c4d2 100644 --- a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java +++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java @@ -17,154 +17,73 @@ package org.apache.zeppelin.spark; -import org.apache.zeppelin.display.AngularObjectRegistry; -import org.apache.zeppelin.display.GUI; -import org.apache.zeppelin.interpreter.*; +import com.google.common.io.Files; +import org.apache.zeppelin.interpreter.Interpreter; +import org.apache.zeppelin.interpreter.InterpreterContext; +import org.apache.zeppelin.interpreter.InterpreterException; +import org.apache.zeppelin.interpreter.InterpreterGroup; +import org.apache.zeppelin.interpreter.LazyOpenInterpreter; import org.apache.zeppelin.interpreter.remote.RemoteEventClient; -import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion; -import org.apache.zeppelin.resource.LocalResourcePool; -import org.apache.zeppelin.user.AuthenticationInfo; -import org.junit.*; -import org.junit.rules.TemporaryFolder; -import org.junit.runners.MethodSorters; +import org.apache.zeppelin.python.PythonInterpreterTest; +import org.junit.Test; import java.io.IOException; -import java.util.HashMap; import java.util.LinkedList; -import java.util.List; -import java.util.Map; import java.util.Properties; -import java.util.regex.Matcher; -import java.util.regex.Pattern; -import static org.junit.Assert.*; -import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -@FixMethodOrder(MethodSorters.NAME_ASCENDING) -public class PySparkInterpreterTest { +public class PySparkInterpreterTest extends PythonInterpreterTest { - @ClassRule - public static TemporaryFolder tmpDir = new TemporaryFolder(); - - static SparkInterpreter sparkInterpreter; - static PySparkInterpreter pySparkInterpreter; - static InterpreterGroup intpGroup; - static InterpreterContext context; private RemoteEventClient mockRemoteEventClient = mock(RemoteEventClient.class); - private static Properties getPySparkTestProperties() throws IOException { - Properties p = new Properties(); - p.setProperty("spark.master", "local"); - p.setProperty("spark.app.name", "Zeppelin Test"); - p.setProperty("zeppelin.spark.useHiveContext", "true"); - p.setProperty("zeppelin.spark.maxResult", "1000"); - p.setProperty("zeppelin.spark.importImplicit", "true"); - p.setProperty("zeppelin.pyspark.python", "python"); - p.setProperty("zeppelin.dep.localrepo", tmpDir.newFolder().getAbsolutePath()); - p.setProperty("zeppelin.pyspark.useIPython", "false"); - p.setProperty("zeppelin.spark.test", "true"); - return p; - } - - /** - * Get spark version number as a numerical value. - * eg. 1.1.x => 11, 1.2.x => 12, 1.3.x => 13 ... - */ - public static int getSparkVersionNumber() { - if (sparkInterpreter == null) { - return 0; - } - - String[] split = sparkInterpreter.getSparkContext().version().split("\\."); - int version = Integer.parseInt(split[0]) * 10 + Integer.parseInt(split[1]); - return version; - } - - @BeforeClass - public static void setUp() throws Exception { + @Override + public void setUp() throws InterpreterException { + Properties properties = new Properties(); + properties.setProperty("spark.master", "local"); + properties.setProperty("spark.app.name", "Zeppelin Test"); + properties.setProperty("zeppelin.spark.useHiveContext", "false"); + properties.setProperty("zeppelin.spark.maxResult", "3"); + properties.setProperty("zeppelin.spark.importImplicit", "true"); + properties.setProperty("zeppelin.pyspark.python", "python"); + properties.setProperty("zeppelin.dep.localrepo", Files.createTempDir().getAbsolutePath()); + properties.setProperty("zeppelin.pyspark.useIPython", "false"); + properties.setProperty("zeppelin.spark.useNew", "true"); + properties.setProperty("zeppelin.spark.test", "true"); + properties.setProperty("zeppelin.python.gatewayserver_address", "127.0.0.1"); + + InterpreterContext.set(getInterpreterContext(mockRemoteEventClient)); + // create interpreter group intpGroup = new InterpreterGroup(); intpGroup.put("note", new LinkedList()); - context = new InterpreterContext("note", "id", null, "title", "text", - new AuthenticationInfo(), - new HashMap(), - new GUI(), - new GUI(), - new AngularObjectRegistry(intpGroup.getId(), null), - new LocalResourcePool("id"), - new LinkedList(), - new InterpreterOutput(null)); - InterpreterContext.set(context); - - sparkInterpreter = new SparkInterpreter(getPySparkTestProperties()); + LazyOpenInterpreter sparkInterpreter = + new LazyOpenInterpreter(new SparkInterpreter(properties)); intpGroup.get("note").add(sparkInterpreter); sparkInterpreter.setInterpreterGroup(intpGroup); - sparkInterpreter.open(); - - pySparkInterpreter = new PySparkInterpreter(getPySparkTestProperties()); - intpGroup.get("note").add(pySparkInterpreter); - pySparkInterpreter.setInterpreterGroup(intpGroup); - pySparkInterpreter.open(); - } - - @AfterClass - public static void tearDown() throws InterpreterException { - pySparkInterpreter.close(); - sparkInterpreter.close(); - } - @Test - public void testBasicIntp() throws InterpreterException, InterruptedException, IOException { - IPySparkInterpreterTest.testPySpark(pySparkInterpreter, mockRemoteEventClient); - } + LazyOpenInterpreter iPySparkInterpreter = + new LazyOpenInterpreter(new IPySparkInterpreter(properties)); + intpGroup.get("note").add(iPySparkInterpreter); + iPySparkInterpreter.setInterpreterGroup(intpGroup); - @Test - public void testRedefinitionZeppelinContext() throws InterpreterException { - if (getSparkVersionNumber() > 11) { - String redefinitionCode = "z = 1\n"; - String restoreCode = "z = __zeppelin__\n"; - String validCode = "z.input(\"test\")\n"; + interpreter = new LazyOpenInterpreter(new PySparkInterpreter(properties)); + intpGroup.get("note").add(interpreter); + interpreter.setInterpreterGroup(intpGroup); - assertEquals(InterpreterResult.Code.SUCCESS, pySparkInterpreter.interpret(validCode, context).code()); - assertEquals(InterpreterResult.Code.SUCCESS, pySparkInterpreter.interpret(redefinitionCode, context).code()); - assertEquals(InterpreterResult.Code.ERROR, pySparkInterpreter.interpret(validCode, context).code()); - assertEquals(InterpreterResult.Code.SUCCESS, pySparkInterpreter.interpret(restoreCode, context).code()); - assertEquals(InterpreterResult.Code.SUCCESS, pySparkInterpreter.interpret(validCode, context).code()); - } + interpreter.open(); } - private class infinityPythonJob implements Runnable { - @Override - public void run() { - String code = "import time\nwhile True:\n time.sleep(1)" ; - InterpreterResult ret = null; - try { - ret = pySparkInterpreter.interpret(code, context); - } catch (InterpreterException e) { - e.printStackTrace(); - } - assertNotNull(ret); - Pattern expectedMessage = Pattern.compile("KeyboardInterrupt"); - Matcher m = expectedMessage.matcher(ret.message().toString()); - assertTrue(m.find()); - } + @Override + public void tearDown() throws InterpreterException { + intpGroup.close(); + intpGroup = null; + interpreter = null; } @Test - public void testCancelIntp() throws InterruptedException, InterpreterException { - if (getSparkVersionNumber() > 11) { - assertEquals(InterpreterResult.Code.SUCCESS, - pySparkInterpreter.interpret("a = 1\n", context).code()); - - Thread t = new Thread(new infinityPythonJob()); - t.start(); - Thread.sleep(5000); - pySparkInterpreter.cancel(context); - assertTrue(t.isAlive()); - t.join(2000); - assertFalse(t.isAlive()); - } + public void testPySpark() throws InterruptedException, InterpreterException, IOException { + IPySparkInterpreterTest.testPySpark(interpreter, mockRemoteEventClient); } + } diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkRInterpreterTest.java b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkRInterpreterTest.java index 53f29c327c7..8eaf1e4b901 100644 --- a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkRInterpreterTest.java +++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkRInterpreterTest.java @@ -26,6 +26,8 @@ import org.apache.zeppelin.interpreter.LazyOpenInterpreter; import org.apache.zeppelin.interpreter.remote.RemoteEventClient; import org.apache.zeppelin.user.AuthenticationInfo; +import org.junit.After; +import org.junit.Before; import org.junit.Test; import java.io.IOException; @@ -47,8 +49,8 @@ public class SparkRInterpreterTest { private SparkInterpreter sparkInterpreter; private RemoteEventClient mockRemoteEventClient = mock(RemoteEventClient.class); - @Test - public void testSparkRInterpreter() throws InterpreterException, InterruptedException { + @Before + public void setUp() throws InterpreterException { Properties properties = new Properties(); properties.setProperty("spark.master", "local"); properties.setProperty("spark.app.name", "test"); @@ -69,6 +71,16 @@ public void testSparkRInterpreter() throws InterpreterException, InterruptedExce sparkRInterpreter.open(); sparkInterpreter.getZeppelinContext().setEventClient(mockRemoteEventClient); + } + + @After + public void tearDown() throws InterpreterException { + sparkInterpreter.close(); + } + + @Test + public void testSparkRInterpreter() throws InterpreterException, InterruptedException { + InterpreterResult result = sparkRInterpreter.interpret("1+1", getInterpreterContext()); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); diff --git a/spark/interpreter/src/test/resources/log4j.properties b/spark/interpreter/src/test/resources/log4j.properties index 0dc7c89701f..edd13e463ba 100644 --- a/spark/interpreter/src/test/resources/log4j.properties +++ b/spark/interpreter/src/test/resources/log4j.properties @@ -43,9 +43,9 @@ log4j.logger.DataNucleus.Datastore=ERROR # Log all JDBC parameters log4j.logger.org.hibernate.type=ALL -log4j.logger.org.apache.zeppelin.interpreter=DEBUG -log4j.logger.org.apache.zeppelin.spark=DEBUG +log4j.logger.org.apache.zeppelin.interpreter=WARN +log4j.logger.org.apache.zeppelin.spark=INFO log4j.logger.org.apache.zeppelin.python=DEBUG -log4j.logger.org.apache.spark.repl.Main=INFO +log4j.logger.org.apache.spark.repl.Main=WARN diff --git a/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/InterpreterGroup.java b/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/InterpreterGroup.java index 9f889013bba..4cf4b315968 100644 --- a/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/InterpreterGroup.java +++ b/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/InterpreterGroup.java @@ -161,4 +161,17 @@ public boolean equals(Object o) { public int hashCode() { return id != null ? id.hashCode() : 0; } + + public void close() { + for (List session : sessions.values()) { + for (Interpreter interpreter : session) { + try { + interpreter.close(); + } catch (InterpreterException e) { + LOGGER.warn("Fail to close interpreter: " + interpreter.getClassName(), e); + } + } + } + sessions.clear(); + } } diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/rest/ZeppelinSparkClusterTest.java b/zeppelin-server/src/test/java/org/apache/zeppelin/rest/ZeppelinSparkClusterTest.java index 6198c7ba2e3..671091570fa 100644 --- a/zeppelin-server/src/test/java/org/apache/zeppelin/rest/ZeppelinSparkClusterTest.java +++ b/zeppelin-server/src/test/java/org/apache/zeppelin/rest/ZeppelinSparkClusterTest.java @@ -31,9 +31,11 @@ import java.io.File; import java.io.IOException; import java.util.Arrays; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Set; import org.apache.zeppelin.conf.ZeppelinConfiguration; import org.apache.zeppelin.display.AngularObject; @@ -54,8 +56,16 @@ */ @RunWith(value = Parameterized.class) public class ZeppelinSparkClusterTest extends AbstractTestRestApi { + private static final Logger LOGGER = LoggerFactory.getLogger(ZeppelinSparkClusterTest.class); + //This is for only run setupSparkInterpreter one time for each spark version, otherwise + //each test method will run setupSparkInterpreter which will cost a long time and may cause travis + //ci timeout. + //TODO(zjffdu) remove this after we upgrade it to junit 4.13 (ZEPPELIN-3341) + private static Set verifiedSparkVersions = new HashSet<>(); + + private String sparkVersion; private AuthenticationInfo anonymous = new AuthenticationInfo("anonymous"); @@ -63,8 +73,11 @@ public ZeppelinSparkClusterTest(String sparkVersion) throws Exception { this.sparkVersion = sparkVersion; LOGGER.info("Testing SparkVersion: " + sparkVersion); String sparkHome = SparkDownloadUtils.downloadSpark(sparkVersion); - setupSparkInterpreter(sparkHome); - verifySparkVersionNumber(); + if (!verifiedSparkVersions.contains(sparkVersion)) { + verifiedSparkVersions.add(sparkVersion); + setupSparkInterpreter(sparkHome); + verifySparkVersionNumber(); + } } @Parameterized.Parameters diff --git a/zeppelin-zengine/src/main/java/org/apache/zeppelin/interpreter/InterpreterSetting.java b/zeppelin-zengine/src/main/java/org/apache/zeppelin/interpreter/InterpreterSetting.java index 21de851bab7..04a87fdef37 100644 --- a/zeppelin-zengine/src/main/java/org/apache/zeppelin/interpreter/InterpreterSetting.java +++ b/zeppelin-zengine/src/main/java/org/apache/zeppelin/interpreter/InterpreterSetting.java @@ -520,7 +520,8 @@ public Properties getJavaProperties() { Map iProperties = (Map) properties; for (Map.Entry entry : iProperties.entrySet()) { if (entry.getValue().getValue() != null) { - jProperties.setProperty(entry.getKey().trim(), entry.getValue().getValue().toString().trim()); + jProperties.setProperty(entry.getKey().trim(), + entry.getValue().getValue().toString().trim()); } }