From 163acfa960287a583bc26facc74e67d7cc1d2d45 Mon Sep 17 00:00:00 2001
From: Lee moon soo
Date: Thu, 10 Sep 2015 19:06:57 -0700
Subject: [PATCH 1/3] Set classloader for gatewayserver with classes from
 dependency loader

---
 .../zeppelin/spark/PySparkInterpreter.java    | 58 +++++++++++++++++++
 1 file changed, 58 insertions(+)

diff --git a/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java b/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
index 0e58729d885..f2c59c12482 100644
--- a/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
+++ b/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
@@ -25,7 +25,10 @@
 import java.io.OutputStreamWriter;
 import java.io.PipedInputStream;
 import java.io.PipedOutputStream;
+import java.net.MalformedURLException;
 import java.net.ServerSocket;
+import java.net.URL;
+import java.net.URLClassLoader;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
@@ -51,6 +54,7 @@
 import org.apache.zeppelin.interpreter.InterpreterResult.Code;
 import org.apache.zeppelin.interpreter.LazyOpenInterpreter;
 import org.apache.zeppelin.interpreter.WrappedInterpreter;
+import org.apache.zeppelin.spark.dep.DependencyContext;

 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -125,6 +129,43 @@ private void createPythonScript() {

   @Override
   public void open() {
+    DepInterpreter depInterpreter = getDepInterpreter();
+
+    // load libraries from Dependency Interpreter
+    URL [] urls = new URL[0];
+
+    if (depInterpreter != null) {
+      DependencyContext depc = depInterpreter.getDependencyContext();
+      if (depc != null) {
+        List<File> files = depc.getFiles();
+        List<URL> urlList = new LinkedList<URL>();
+        if (files != null) {
+          for (File f : files) {
+            try {
+              urlList.add(f.toURI().toURL());
+            } catch (MalformedURLException e) {
+              logger.error("Error", e);
+            }
+          }
+
+          urls = urlList.toArray(urls);
+        }
+      }
+    }
+
+    ClassLoader oldCl = Thread.currentThread().getContextClassLoader();
+    try {
+      URLClassLoader newCl = new URLClassLoader(urls, oldCl);
+      Thread.currentThread().setContextClassLoader(newCl);
+      createGatewayServerAndStartScript();
+    } catch (Exception e) {
+      throw new InterpreterException(e);
+    } finally {
+      Thread.currentThread().setContextClassLoader(oldCl);
+    }
+  }
+
+  private void createGatewayServerAndStartScript() {
     // create python script
     createPythonScript();

@@ -400,6 +441,23 @@ public SQLContext getSQLContext() {
     }
   }

+  private DepInterpreter getDepInterpreter() {
+    InterpreterGroup intpGroup = getInterpreterGroup();
+    if (intpGroup == null) return null;
+    synchronized (intpGroup) {
+      for (Interpreter intp : intpGroup) {
+        if (intp.getClassName().equals(DepInterpreter.class.getName())) {
+          Interpreter p = intp;
+          while (p instanceof WrappedInterpreter) {
+            p = ((WrappedInterpreter) p).getInnerInterpreter();
+          }
+          return (DepInterpreter) p;
+        }
+      }
+    }
+    return null;
+  }
+

   @Override
   public void onProcessComplete(int exitValue) {
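
For reference, a minimal standalone sketch (not part of the patch) of the classloader swap [PATCH 1/3] applies around GatewayServer creation: build a URLClassLoader over the jars pulled in by the dependency loader, install it as the thread context classloader while the gateway starts, and always restore the previous loader. The jar path below is hypothetical.

    // Sketch only: mirrors the swap pattern shown in PySparkInterpreter.open() above.
    import java.io.File;
    import java.net.URL;
    import java.net.URLClassLoader;
    import java.util.ArrayList;
    import java.util.List;

    public class ContextClassLoaderSwapSketch {
      public static void main(String[] args) throws Exception {
        // Jars fetched by the dependency loader (hypothetical path, for illustration).
        List<File> depJars = new ArrayList<>();
        depJars.add(new File("/tmp/spark-csv_2.11-1.2.0.jar"));

        // Same File -> URL conversion the patch performs.
        List<URL> urls = new ArrayList<>();
        for (File f : depJars) {
          urls.add(f.toURI().toURL());
        }

        ClassLoader oldCl = Thread.currentThread().getContextClassLoader();
        try {
          URLClassLoader newCl = new URLClassLoader(urls.toArray(new URL[0]), oldCl);
          Thread.currentThread().setContextClassLoader(newCl);
          // Work started here (e.g. the Py4J GatewayServer) can resolve classes
          // through newCl, so the dependency-loader jars become visible to it.
          System.out.println(Thread.currentThread().getContextClassLoader());
        } finally {
          // Always restore, since the interpreter thread is reused afterwards.
          Thread.currentThread().setContextClassLoader(oldCl);
        }
      }
    }
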
From 1e8f52ab3e55a2f9b673a8d84c3c722ccd7addd4 Mon Sep 17 00:00:00 2001
From: Lee moon soo
Date: Thu, 10 Sep 2015 19:32:54 -0700
Subject: [PATCH 2/3] Add test

---
 .../rest/ZeppelinSparkClusterTest.java       | 45 ++++++++++++++++++-
 1 file changed, 44 insertions(+), 1 deletion(-)

diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/rest/ZeppelinSparkClusterTest.java b/zeppelin-server/src/test/java/org/apache/zeppelin/rest/ZeppelinSparkClusterTest.java
index aa2476a5d5d..d5006eee957 100644
--- a/zeppelin-server/src/test/java/org/apache/zeppelin/rest/ZeppelinSparkClusterTest.java
+++ b/zeppelin-server/src/test/java/org/apache/zeppelin/rest/ZeppelinSparkClusterTest.java
@@ -18,8 +18,12 @@

 import static org.junit.Assert.assertEquals;

+import java.io.File;
 import java.io.IOException;
+import java.util.List;

+import org.apache.commons.io.FileUtils;
+import org.apache.zeppelin.interpreter.InterpreterSetting;
 import org.apache.zeppelin.notebook.Note;
 import org.apache.zeppelin.notebook.Paragraph;
 import org.apache.zeppelin.scheduler.Job.Status;
@@ -75,7 +79,6 @@ public void basicRDDTransformationAndActionTest() throws IOException {
   public void pySparkTest() throws IOException {
     // create new note
     Note note = ZeppelinServer.notebook.createNote();
-
     int sparkVersion = getSparkVersionNumber(note);

     if (isPyspark() && sparkVersion >= 12) {   // pyspark supported from 1.2.1
@@ -129,6 +132,46 @@ public void zRunTest() throws IOException {
     ZeppelinServer.notebook.removeNote(note.id());
   }

+  @Test
+  public void pySparkDepLoaderTest() throws IOException {
+    // create new note
+    Note note = ZeppelinServer.notebook.createNote();
+
+    if (isPyspark() && getSparkVersionNumber(note) >= 14) {
+      // restart spark interpreter
+      List<InterpreterSetting> settings =
+          ZeppelinServer.notebook.getBindedInterpreterSettings(note.id());
+
+      for (InterpreterSetting setting : settings) {
+        if (setting.getGroup().equals("spark")) {
+          ZeppelinServer.notebook.getInterpreterFactory().restart(setting.id());
+          break;
+        }
+      }
+
+      // load dep
+      Paragraph p0 = note.addParagraph();
+      p0.setText("%dep z.load(\"com.databricks:spark-csv_2.11:1.2.0\")");
+      note.run(p0.getId());
+      waitForFinish(p0);
+
+      // write test csv file
+      File tmpFile = File.createTempFile("test", "csv");
+      FileUtils.write(tmpFile, "a,b\n1,2");
+
+      // load data using libraries from dep loader
+      Paragraph p1 = note.addParagraph();
+      p1.setText("%pyspark\n" +
+          "from pyspark.sql import SQLContext\n" +
+          "print(sqlContext.read.format('com.databricks.spark.csv')" +
+          ".load('"+ tmpFile.getAbsolutePath() +"').count())");
+      note.run(p1.getId());
+
+      waitForFinish(p1);
+      assertEquals("2\n", p1.getResult().message());
+    }
+  }
+
   /**
    * Get spark version number as a numerical value.
    * eg. 1.1.x => 11, 1.2.x => 12, 1.3.x => 13 ...

From 0de89fec39e5cf583ec89b6f0154ad00b2cf183e Mon Sep 17 00:00:00 2001
From: Lee moon soo
Date: Mon, 14 Sep 2015 14:04:40 +0900
Subject: [PATCH 3/3] Add logging

---
 .../main/java/org/apache/zeppelin/spark/PySparkInterpreter.java | 1 +
 1 file changed, 1 insertion(+)

diff --git a/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java b/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
index f2c59c12482..d0e5fecc2b2 100644
--- a/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
+++ b/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
@@ -159,6 +159,7 @@ public void open() {
       Thread.currentThread().setContextClassLoader(newCl);
       createGatewayServerAndStartScript();
     } catch (Exception e) {
+      logger.error("Error", e);
       throw new InterpreterException(e);
     } finally {
       Thread.currentThread().setContextClassLoader(oldCl);
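
As a complement to the test added in [PATCH 2/3], a minimal sketch (not part of the patches) of what the classloader swap buys: a lookup that goes through the thread context classloader can only find classes from the dependency-loader jars while the URLClassLoader installed in open() is the current context classloader. The class name below is hypothetical and assumes a %dep-loaded spark-csv jar.

    // Sketch only: checks visibility of a dep-loaded class via the context classloader.
    public class ContextLookupSketch {
      public static void main(String[] args) {
        String className = "com.databricks.spark.csv.DefaultSource";  // hypothetical
        try {
          Class<?> c = Class.forName(className, true,
              Thread.currentThread().getContextClassLoader());
          System.out.println("resolved: " + c.getName());
        } catch (ClassNotFoundException e) {
          System.out.println("not visible to the context classloader: " + className);
        }
      }
    }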