diff --git a/modules/scala/almond-spark/src/main/scala/org/apache/spark/sql/almondinternals/SendLog.scala b/modules/scala/almond-spark/src/main/scala/org/apache/spark/sql/almondinternals/SendLog.scala index 6bb194e09..72784b024 100644 --- a/modules/scala/almond-spark/src/main/scala/org/apache/spark/sql/almondinternals/SendLog.scala +++ b/modules/scala/almond-spark/src/main/scala/org/apache/spark/sql/almondinternals/SendLog.scala @@ -81,12 +81,13 @@ final class SendLog( try { withExponentialBackOff { commHandler.commOpen( - commTarget, - commId, - Json.obj( + targetName = commTarget, + id = commId, + data = Json.obj( "file_name" -> Json.jString(fileName0), "prefix" -> prefix.fold(Json.jNull)(Json.jString) - ).nospaces + ).nospaces, + metadata = "{}" ) } @@ -122,7 +123,7 @@ final class SendLog( lines.clear() withExponentialBackOff { - commHandler.commMessage(commId, res) + commHandler.commMessage(commId, res, "{}") } } } @@ -133,7 +134,7 @@ final class SendLog( if (r != null) r.close() // no re-attempt hereā€¦ - commHandler.commClose(commId, "{}") + commHandler.commClose(commId, "{}", "{}") } } } diff --git a/modules/shared/interpreter-api/src/main/scala/almond/interpreter/api/CommHandler.scala b/modules/shared/interpreter-api/src/main/scala/almond/interpreter/api/CommHandler.scala index 4efad4550..87688fb5a 100644 --- a/modules/shared/interpreter-api/src/main/scala/almond/interpreter/api/CommHandler.scala +++ b/modules/shared/interpreter-api/src/main/scala/almond/interpreter/api/CommHandler.scala @@ -25,10 +25,15 @@ abstract class CommHandler extends OutputHandler.UpdateHelpers { def registerCommTarget(name: String, target: CommTarget): Unit def unregisterCommTarget(name: String): Unit + def registerCommId(id: String, target: CommTarget): Unit + def unregisterCommId(id: String): Unit - def commOpen(targetName: String, id: String, data: String): Unit - def commMessage(id: String, data: String): Unit - def commClose(id: String, data: String): Unit + @throws(classOf[IllegalArgumentException]) + def commOpen(targetName: String, id: String, data: String, metadata: String): Unit + @throws(classOf[IllegalArgumentException]) + def commMessage(id: String, data: String, metadata: String): Unit + @throws(classOf[IllegalArgumentException]) + def commClose(id: String, data: String, metadata: String): Unit final def receiver( @@ -50,14 +55,23 @@ abstract class CommHandler extends OutputHandler.UpdateHelpers { final def sender( targetName: String, id: String = UUID.randomUUID().toString, - data: String = "{}" + data: String = "{}", + metadata: String = "{}", + onMessage: (String, String) => Unit = (_, _) => (), + onClose: (String, String) => Unit = (_, _) => () ): Comm = { - commOpen(targetName, id, data) + commOpen(targetName, id, data, metadata) + val t = CommTarget( + onMessage = (id, data) => onMessage(id, data), + onOpen = (_, _) => (), // ignore since we open the comm from the kernel + onClose = (id, data) => onClose(id, data) + ) + registerCommId(id, t) new Comm { - def message(data: String) = - commMessage(id, data) - def close(data: String) = - commClose(id, data) + def message(data: String, metadata: String = "{}") = + commMessage(id, data, metadata) + def close(data: String, metadata: String = "{}") = + commClose(id, data, metadata) } } } @@ -65,8 +79,8 @@ abstract class CommHandler extends OutputHandler.UpdateHelpers { object CommHandler { abstract class Comm { - def message(data: String): Unit - def close(data: String): Unit + def message(data: String, metadata: String): Unit + def close(data: String, metadata: String): Unit } } diff --git a/modules/shared/interpreter/src/main/scala/almond/interpreter/comm/DefaultCommHandler.scala b/modules/shared/interpreter/src/main/scala/almond/interpreter/comm/DefaultCommHandler.scala index b9113a9bf..bf8368f63 100644 --- a/modules/shared/interpreter/src/main/scala/almond/interpreter/comm/DefaultCommHandler.scala +++ b/modules/shared/interpreter/src/main/scala/almond/interpreter/comm/DefaultCommHandler.scala @@ -5,7 +5,7 @@ import almond.interpreter.api.{CommHandler, CommTarget, DisplayData} import almond.interpreter.util.DisplayDataOps._ import almond.interpreter.Message import almond.protocol._ -import argonaut.{EncodeJson, JsonObject} +import argonaut.{EncodeJson, Json, JsonObject} import argonaut.Parse.{parse => parseJson} import cats.effect.IO import fs2.concurrent.Queue @@ -21,8 +21,8 @@ final class DefaultCommHandler( private val message: Message[_] = Message( - Header("", "username", "", "", Some(Protocol.versionStr)), // FIXME Hardcoded user / session id - () + header = Header("", "username", "", "", Some(Protocol.versionStr)), // FIXME Hardcoded user / session id + content = () ) @@ -34,32 +34,31 @@ final class DefaultCommHandler( def registerCommTarget(name: String, target: IOCommTarget): Unit = commTargetManager.addTarget(name, target) + def registerCommId(id: String, target: CommTarget): Unit = + commTargetManager.addId(IOCommTarget.fromCommTarget(target, commEc), id) + def unregisterCommId(id: String): Unit = + commTargetManager.removeId(id) - private def publish[T: EncodeJson](messageType: MessageType[T], content: T): Unit = + + private def publish[T: EncodeJson](messageType: MessageType[T], content: T, metadata: Map[String, Json]): Unit = message - .publish(messageType, content) + .publish(messageType, content, metadata) .enqueueOn(Channel.Publish, queue) .unsafeRunSync() - private def parseJsonObj(s: String): Option[JsonObject] = + private def parseJsonObj(s: String): JsonObject = parseJson(s) - .right - .toOption - .flatMap(_.obj) - - // TODO Throw an exception if bad data is passed + .right.flatMap(_.obj.toRight("Not a JSON object")) + .fold(left => throw new IllegalArgumentException(left), identity) - def commOpen(targetName: String, id: String, data: String): Unit = - for (obj <- parseJsonObj(data)) - publish(Comm.openType, Comm.Open(id, targetName, obj)) + def commOpen(targetName: String, id: String, data: String, metadata: String): Unit = + publish(Comm.openType, Comm.Open(id, targetName, parseJsonObj(data)), parseJsonObj(metadata).toMap) - def commMessage(id: String, data: String): Unit = - for (obj <- parseJsonObj(data)) - publish(Comm.messageType, Comm.Message(id, obj)) + def commMessage(id: String, data: String, metadata: String): Unit = + publish(Comm.messageType, Comm.Message(id, parseJsonObj(data)), parseJsonObj(metadata).toMap) - def commClose(id: String, data: String): Unit = - for (obj <- parseJsonObj(data)) - publish(Comm.closeType, Comm.Close(id, obj)) + def commClose(id: String, data: String, metadata: String): Unit = + publish(Comm.closeType, Comm.Close(id, parseJsonObj(data)), parseJsonObj(metadata).toMap) def updateDisplay(data: DisplayData): Unit = { @@ -72,6 +71,6 @@ final class DefaultCommHandler( Execute.DisplayData.Transient(data.idOpt) ) - publish(Execute.updateDisplayDataType, content) + publish(Execute.updateDisplayDataType, content, Map.empty) } } diff --git a/modules/shared/interpreter/src/test/scala/almond/interpreter/TestInterpreter.scala b/modules/shared/interpreter/src/test/scala/almond/interpreter/TestInterpreter.scala index 48ab6345b..01909874c 100644 --- a/modules/shared/interpreter/src/test/scala/almond/interpreter/TestInterpreter.scala +++ b/modules/shared/interpreter/src/test/scala/almond/interpreter/TestInterpreter.scala @@ -34,7 +34,7 @@ final class TestInterpreter extends Interpreter { ExecuteResult.Error("comm not available") case Some(h) => val target = code.stripPrefix("comm-open:") - h.commOpen(target, target, "{}") + h.commOpen(target, target, "{}", "{}") count += 1 ExecuteResult.Success() } @@ -44,7 +44,7 @@ final class TestInterpreter extends Interpreter { ExecuteResult.Error("comm not available") case Some(h) => val target = code.stripPrefix("comm-message:") - h.commMessage(target, """{"a": "b"}""") + h.commMessage(target, """{"a": "b"}""", "{}") count += 1 ExecuteResult.Success() } @@ -54,7 +54,7 @@ final class TestInterpreter extends Interpreter { ExecuteResult.Error("comm not available") case Some(h) => val target = code.stripPrefix("comm-close:") - h.commClose(target, "{}") + h.commClose(target, "{}", "{}") count += 1 ExecuteResult.Success() } diff --git a/project/Mima.scala b/project/Mima.scala index 11c82bbb0..78f467b5d 100644 --- a/project/Mima.scala +++ b/project/Mima.scala @@ -9,14 +9,14 @@ object Mima { .replace("-RC", "-") .forall(c => c == '.' || c == '-' || c.isDigit) - def binaryCompatibilityVersions(contains: String): Set[String] = - Seq("git", "tag", "--merged", "HEAD^", "--contains", contains) + def binaryCompatibilityVersions(): Set[String] = + Seq("git", "tag", "--merged", "HEAD^", "--contains", "v0.7.0") .!! .linesIterator .map(_.trim) .filter(_.startsWith("v")) .map(_.stripPrefix("v")) - .filter(_ != "0.3.1") // Mima enabled right after it + .filter(_ != "0.7.0") // Preserving compatibility right after it .filter(stable) .toSet diff --git a/project/Settings.scala b/project/Settings.scala index 61851a736..cfec5e0e5 100644 --- a/project/Settings.scala +++ b/project/Settings.scala @@ -212,11 +212,8 @@ object Settings { lazy val mima = Seq( MimaPlugin.autoImport.mimaPreviousArtifacts := { val sv = scalaVersion.value - val contains = - if (sv.startsWith("2.13.")) "4e9441b9" - else "v0.3.1" - Mima.binaryCompatibilityVersions(contains).map { ver => + Mima.binaryCompatibilityVersions().map { ver => (organization.value % moduleName.value % ver).cross(crossVersion.value) } }