Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 62 additions & 50 deletions repl/src/main/resources/fake_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io
import json
import logging
import signal
import sys
import traceback
import base64
Expand All @@ -49,6 +50,14 @@

TOP_FRAME_REGEX = re.compile(r'\s*File "<stdin>".*in <module>')

class StatementCancellationException(Exception):
pass

def signal_handler(signal, frame):
raise StatementCancellationException

signal.signal(signal.SIGUSR1, signal_handler)

def execute_reply(status, content):
return {
'msg_type': 'execute_reply',
Expand Down Expand Up @@ -648,58 +657,61 @@ def main():
sys_stdout.flush()

while True:
line = sys_stdin.readline()

if line == '':
break
elif line == '\n':
continue

try:
msg = json.loads(line)
except ValueError:
LOG.error('failed to parse message', exc_info=True)
continue

try:
msg_type = msg['msg_type']
except KeyError:
LOG.error('missing message type', exc_info=True)
continue

try:
content = msg['content']
except KeyError:
LOG.error('missing content', exc_info=True)
continue

if not isinstance(content, dict):
LOG.error('content is not a dictionary')
continue

try:
handler = msg_type_router[msg_type]
except KeyError:
LOG.error('unknown message type: %s', msg_type)
line = sys_stdin.readline()

if line == '':
break
elif line == '\n':
continue

try:
msg = json.loads(line)
except ValueError:
LOG.error('failed to parse message', exc_info=True)
continue

try:
msg_type = msg['msg_type']
except KeyError:
LOG.error('missing message type', exc_info=True)
continue

try:
content = msg['content']
except KeyError:
LOG.error('missing content', exc_info=True)
continue

if not isinstance(content, dict):
LOG.error('content is not a dictionary')
continue

try:
handler = msg_type_router[msg_type]
except KeyError:
LOG.error('unknown message type: %s', msg_type)
continue

response = handler(content)

try:
response = json.dumps(response)
except ValueError:
response = json.dumps({
'msg_type': 'inspect_reply',
'content': {
'status': 'error',
'ename': 'ValueError',
'evalue': 'cannot json-ify %s' % response,
'traceback': [],
}
})

print(response, file=sys_stdout)
sys_stdout.flush()
except StatementCancellationException:
continue

response = handler(content)

try:
response = json.dumps(response)
except ValueError:
response = json.dumps({
'msg_type': 'inspect_reply',
'content': {
'status': 'error',
'ename': 'ValueError',
'evalue': 'cannot json-ify %s' % response,
'traceback': [],
}
})

print(response, file=sys_stdout)
sys_stdout.flush()
finally:
if os.environ.get("LIVY_TEST") != "true" and 'sc' in global_dict:
gateway.shutdown_callback_server()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ abstract class AbstractSparkInterpreter extends Interpreter with Logging {

protected def conf: SparkConf

private var currentThread: Thread = _

protected def postStart(): Unit = {
entries = new SparkEntries(conf)

Expand Down Expand Up @@ -106,10 +108,14 @@ abstract class AbstractSparkInterpreter extends Interpreter with Logging {
override protected[repl] def execute(code: String): Interpreter.ExecuteResponse =
restoreContextClassLoader {
require(isStarted())

executeLines(code.trim.split("\n").toList, Interpreter.ExecuteSuccess(JObject(
(TEXT_PLAIN, JString(""))
)))
try {
currentThread = Thread.currentThread()
executeLines(code.trim.split("\n").toList, Interpreter.ExecuteSuccess(JObject(
(TEXT_PLAIN, JString(""))
)))
} finally {
currentThread = null
}
}

override protected[repl] def complete(code: String, cursor: Int): Array[String] = {
Expand Down Expand Up @@ -349,4 +355,10 @@ abstract class AbstractSparkInterpreter extends Interpreter with Logging {

output
}

override def cancel(): Unit = {
if (currentThread != null) {
currentThread.interrupt()

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

interrupt cannot terminate the running task. This can resolve the problem of Thread.sleep(xxx), but it won't help for long running task like while(true): xxx.

}
}
}
3 changes: 3 additions & 0 deletions repl/src/main/scala/org/apache/livy/repl/Interpreter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,7 @@ trait Interpreter {

/** Shut down the interpreter. */
def close(): Unit

/** Cancel the executions */
def cancel(): Unit
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ abstract class ProcessInterpreter(process: Process)
}
}

override def cancel(): Unit = {}

protected def sendExecuteRequest(request: String): Interpreter.ExecuteResponse

protected def sendShutdownRequest(): Unit = {}
Expand Down
14 changes: 14 additions & 0 deletions repl/src/main/scala/org/apache/livy/repl/PythonInterpreter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,20 @@ private class PythonInterpreter(
}
}

override def cancel(): Unit = {
val pythonPid = getPythonPid()
info("Sending SIGUSR1 to " + pythonPid)
Runtime.getRuntime().exec("kill -SIGUSR1 " + pythonPid)
}

def getPythonPid(): Int = {
// This implementation is specific to Unix type systems
val field = process.getClass().getDeclaredField("pid")
field.setAccessible(true)
val pid = field.get(process).asInstanceOf[Int]
pid
}

private def sendRequest(request: Map[String, Any]): Option[JValue] = {
stdin.println(write(request))
stdin.flush()
Expand Down
2 changes: 2 additions & 0 deletions repl/src/main/scala/org/apache/livy/repl/SQLInterpreter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,6 @@ class SQLInterpreter(
}

override def close(): Unit = { }

override def cancel(): Unit = {}
}
7 changes: 6 additions & 1 deletion repl/src/main/scala/org/apache/livy/repl/Session.scala
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ class Session(
}

val statementId = newStatementId.getAndIncrement()
val statement = new Statement(statementId, code, StatementState.Waiting, null)
val statement = new Statement(statementId, code, StatementState.Waiting, null, tpe.name)
_statements.synchronized { _statements(statementId) = statement }

Future {
Expand Down Expand Up @@ -213,6 +213,11 @@ class Session(
statement.compareAndTransit(StatementState.Cancelling, StatementState.Cancelled)
} else {
sc.cancelJobGroup(statementId.toString)
val intpOpt = interpreter(Kind(statement.kind))
if (!intpOpt.isEmpty) {
val intp = intpOpt.get
intp.cancel()
}
if (statement.state.get() == StatementState.Cancelling) {
Thread.sleep(livyConf.getTimeAsMs(RSCConf.Entry.JOB_CANCEL_TRIGGER_INTERVAL))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ class SparkSessionSpec extends BaseSessionSpec(Spark) {
eventually(timeout(30 seconds), interval(100 millis)) {
assert(session.statements(stmtId).state.get() == StatementState.Cancelled)
session.statements(stmtId).output should include (
"Job 0 cancelled part of cancelled job group 0")
"java.lang.InterruptedException")
}
}

Expand All @@ -232,7 +232,7 @@ class SparkSessionSpec extends BaseSessionSpec(Spark) {
eventually(timeout(30 seconds), interval(100 millis)) {
assert(session.statements(stmtId1).state.get() == StatementState.Cancelled)
session.statements(stmtId1).output should include (
"Job 0 cancelled part of cancelled job group 0")
"java.lang.InterruptedException")
}
}

Expand Down
6 changes: 6 additions & 0 deletions rsc/src/main/java/org/apache/livy/rsc/driver/Statement.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public class Statement {
public double progress;
public long started = 0;
public long completed = 0;
public transient String kind;

public Statement(Integer id, String code, StatementState state, String output) {
this.id = id;
Expand All @@ -39,6 +40,11 @@ public Statement(Integer id, String code, StatementState state, String output) {
this.progress = 0.0;
}

public Statement(Integer id, String code, StatementState state, String output, String kind) {
this(id, code, state, output);
this.kind = kind;
}

public Statement() {
this(null, null, null, null);
}
Expand Down