From 2110941757398dcfaab4cef5ac018159c9a7c101 Mon Sep 17 00:00:00 2001 From: Alex Archambault Date: Mon, 1 May 2023 18:34:02 +0200 Subject: [PATCH] Send completion metadata to frontends --- build.sc | 3 +- .../scala/almond/echo/EchoInterpreter.scala | 1 + .../integration/KernelTestsDefinitions.scala | 7 + .../ScalaInterpreterCompletions.scala | 324 +++++++++++++++++ .../ScalaInterpreterCompletions.scala | 191 ++++++++++ .../main/scala/almond/ScalaInterpreter.scala | 77 +++- .../scala/almond/ScalaInterpreterTests.scala | 14 +- .../main/scala/almond/integration/Tests.scala | 333 ++++++++++++++++++ .../scala/almond/interpreter/Completion.scala | 10 +- .../almond/interpreter/TestInterpreter.scala | 2 +- .../scala/almond/testkit/ClientStreams.scala | 12 +- .../src/main/scala/almond/testkit/Dsl.scala | 44 +++ 12 files changed, 1005 insertions(+), 13 deletions(-) create mode 100644 modules/scala/scala-interpreter/src/main/scala-2/almond/internals/ScalaInterpreterCompletions.scala create mode 100644 modules/scala/scala-interpreter/src/main/scala-3/almond/internals/ScalaInterpreterCompletions.scala diff --git a/build.sc b/build.sc index 1268cb714..895d7b90f 100644 --- a/build.sc +++ b/build.sc @@ -565,7 +565,8 @@ class TestDefinitions(val crossScalaVersion: String) extends CrossSbtModule with shared.`test-kit`() ) def ivyDeps = Agg( - Deps.coursierApi + Deps.coursierApi, + Deps.upickle ) } diff --git a/modules/echo/src/main/scala/almond/echo/EchoInterpreter.scala b/modules/echo/src/main/scala/almond/echo/EchoInterpreter.scala index a29faab8a..60bf559fc 100644 --- a/modules/echo/src/main/scala/almond/echo/EchoInterpreter.scala +++ b/modules/echo/src/main/scala/almond/echo/EchoInterpreter.scala @@ -72,6 +72,7 @@ final class EchoInterpreter extends Interpreter { pos, pos, Seq("sent"), + None, RawJson(code.drop("meta:".length).getBytes(StandardCharsets.UTF_8)) ) else diff --git a/modules/scala/integration/src/main/scala/almond/integration/KernelTestsDefinitions.scala b/modules/scala/integration/src/main/scala/almond/integration/KernelTestsDefinitions.scala index 288659e92..d1b9b403c 100644 --- a/modules/scala/integration/src/main/scala/almond/integration/KernelTestsDefinitions.scala +++ b/modules/scala/integration/src/main/scala/almond/integration/KernelTestsDefinitions.scala @@ -138,4 +138,11 @@ abstract class KernelTestsDefinitions extends AlmondFunSuite { } } + test("completion") { + kernelLauncher.withKernel { implicit runner => + implicit val sessionId: SessionId = SessionId() + almond.integration.Tests.completion(kernelLauncher.defaultScalaVersion) + } + } + } diff --git a/modules/scala/scala-interpreter/src/main/scala-2/almond/internals/ScalaInterpreterCompletions.scala b/modules/scala/scala-interpreter/src/main/scala-2/almond/internals/ScalaInterpreterCompletions.scala new file mode 100644 index 000000000..ffc09515a --- /dev/null +++ b/modules/scala/scala-interpreter/src/main/scala-2/almond/internals/ScalaInterpreterCompletions.scala @@ -0,0 +1,324 @@ +package almond.internals + +import almond.logger.LoggerContext +import ammonite.util.Name +import ammonite.util.Util.newLine + +import java.io.{OutputStream, PrintWriter} + +import scala.reflect.internal.util.{BatchSourceFile, OffsetPosition, Position} +import scala.tools.nsc +import scala.tools.nsc.interactive.Response +import scala.util.Try + +object ScalaInterpreterCompletions { + + // Based on https://github.com/com-lihaoyi/Ammonite/blob/7699698ecedd4e01912df373f250b6a4e20c9a33/amm/compiler/src/main/scala-2/ammonite/compiler/Pressy.scala#L312-L313 + // and around. + // Most changes revolve around keeping "kind" info ("method", "class", "value", etc.) + + def complete( + compilerManager: ammonite.compiler.iface.CompilerLifecycleManager, + dependencyCompleteOpt: Option[String => (Int, Seq[String])], + snippetIndex: Int, + previousImports: String, + snippet: String, + logCtx: LoggerContext + ): (Int, Seq[String], Seq[String], Seq[(String, String)]) = { + + val log = logCtx(getClass) + + val prefix = previousImports + newLine + "object AutocompleteWrapper{" + newLine + val suffix = newLine + "}" + val allCode = prefix + snippet + suffix + val index = snippetIndex + prefix.length + + val compilerManager0 = compilerManager match { + case m: ammonite.compiler.CompilerLifecycleManager => m + case _ => ??? + } + val pressy = compilerManager0.pressy.compiler + val currentFile = new BatchSourceFile( + ammonite.compiler.Compiler.makeFile(allCode.getBytes, name = "Current.sc"), + allCode + ) + + val r = new Response[Unit] + pressy.askReload(List(currentFile), r) + r.get.fold(x => x, e => throw e) + + val run = Try(new Run(pressy, currentFile, dependencyCompleteOpt, allCode, index, logCtx)) + + val (i, all): (Int, Seq[(String, Option[String], String)]) = + run.map(_.prefixed).toEither match { + case Right((i0, all0)) => i0 -> all0 + case Left(ex) => + log.info("Ignoring exception during completion", ex) + (0, Seq.empty) + } + + val allNames = all.collect { case (name, None, _) => name }.sorted.distinct + + val allNameWithTypes = all.collect { case (name, None, tpe) => (name, tpe) }.sorted.distinct + + val signatures = all.collect { case (name, Some(defn), _) => defn }.sorted.distinct + + log.info(s"signatures=$signatures") + + val isPartialIvyImport = allNames.isEmpty && + snippet.split("\\s+|`").startsWith(Seq("import", "$ivy.")) && + snippet.count(_ == '`') == 1 + + if (isPartialIvyImport) + complete( + compilerManager, + dependencyCompleteOpt, + snippetIndex, + previousImports, + snippet + "`", + logCtx + ) + else (i - prefix.length, allNames, signatures, allNameWithTypes) + } + + class Run( + val pressy: nsc.interactive.Global, + currentFile: BatchSourceFile, + dependencyCompleteOpt: Option[String => (Int, Seq[String])], + allCode: String, + index: Int, + logCtx: LoggerContext + ) { + + val log = logCtx(getClass) + + val blacklistedPackages = Set("shaded") + + /** Dumb things that turn up in the autocomplete that nobody needs or wants + */ + def blacklisted(s: pressy.Symbol) = { + val blacklist = Set( + "scala.Predef.any2stringadd.+", + "scala.Any.##", + "java.lang.Object.##", + "scala.", + "scala.", + "scala.", + "scala.", + "scala.Predef.StringFormat.formatted", + "scala.Predef.Ensuring.ensuring", + "scala.Predef.ArrowAssoc.->", + "scala.Predef.ArrowAssoc.→", + "java.lang.Object.synchronized", + "java.lang.Object.ne", + "java.lang.Object.eq", + "java.lang.Object.wait", + "java.lang.Object.notifyAll", + "java.lang.Object.notify" + ) + + blacklist(s.fullNameAsName('.').decoded) || + s.isImplicit || + // Cache objects, which you should probably never need to + // access directly, and apart from that have annoyingly long names + "cache[a-f0-9]{32}".r.findPrefixMatchOf(s.name.decoded).isDefined || + s.isDeprecated || + s.decodedName == "" || + s.decodedName.contains('$') + } + + private val memberToString = { + // Some hackery here to get at the protected CodePrinter.printedName which was the only + // mildly reusable prior art I could locate. Other related bits: + // - When constructing Trees, Scala captures the back-quoted nature into the AST as + // Ident.isBackquoted. Code-completion is inherently "pre-Tree", at least for the + // symbols being considered for use as completions, so not clear how this would be + // leveragable without big changes inside nsc. + // - There's a public-but-incomplete implementation of rule-based backquoting in + // Printers.quotedName + val nullOutputStream = new OutputStream() { def write(b: Int): Unit = {} } + val backQuoter = new pressy.CodePrinter( + new PrintWriter(nullOutputStream), + printRootPkg = false + ) { + def apply(decodedName: pressy.Name): String = printedName(decodedName, decoded = true) + } + + (member: pressy.Member) => { + import pressy._ + // nsc returns certain members w/ a suffix (LOCAL_SUFFIX_STRING, " "). + // See usages of symNameDropLocal in nsc's PresentationCompilerCompleter. + // Several people have asked that Scala mask this implementation detail: + // https://github.com/scala/bug/issues/5736 + val decodedName = member.sym.name.dropLocal.decodedName + backQuoter(decodedName) + } + } + + val r = new Response[pressy.Tree] + pressy.askTypeAt(new OffsetPosition(currentFile, index), r) + val tree = r.get.fold(x => x, e => throw e) + + /** Search for terms to autocomplete not just from the local scope, but from any packages and + * package objects accessible from the local scope + */ + def deepCompletion(name: String): List[(String, String)] = { + def rec(t: pressy.Symbol): Seq[pressy.Symbol] = + if (blacklistedPackages(t.nameString)) + Nil + else { + val children = + if (t.hasPackageFlag || t.isPackageObject) + pressy.ask(() => t.typeSignature.members.filter(_ != t).flatMap(rec)) + else Nil + + t +: children.toSeq + } + + for { + member <- pressy.RootClass.typeSignature.members.toList + sym <- rec(member) + // sketchy name munging because I don't know how to do this properly + // Note lack of back-quoting support. + strippedName = sym.nameString.stripPrefix("package$").stripSuffix("$") + if strippedName.startsWith(name) + (pref, _) = sym.fullNameString.splitAt(sym.fullNameString.lastIndexOf('.') + 1) + out = pref + strippedName + if out != "" + } yield (out, sym.kindString) + } + def handleTypeCompletion( + position: Int, + decoded: String, + offset: Int + ): (Int, List[(String, Option[String], String)]) = { + + val r = ask(position, pressy.askTypeCompletion) + val prefix = if (decoded == "") "" else decoded + (position + offset, handleCompletion(r, prefix)) + } + + def handleCompletion( + r: List[pressy.Member], + prefix: String + ): List[(String, Option[String], String)] = pressy.ask { () => + log.info( + s"response=$newLine${r.map(m => " " + Try(m.toString).getOrElse("???") + newLine).mkString}" + ) + r.filter(_.sym.name.decoded.startsWith(prefix)) + .filter(m => !blacklisted(m.sym)) + .map { x => + ( + memberToString(x), + if (x.sym.name.decoded != prefix) None + else Some(x.sym.defString), + x.sym.kindString + ) + } + } + + def prefixed: (Int, Seq[(String, Option[String], String)]) = tree match { + case t @ pressy.Select(qualifier, name) => + val dotOffset = if (qualifier.pos.point == t.pos.point) 0 else 1 + + // In scala 2.10.x if we call pos.end on a scala.reflect.internal.util.Position + // that is not a range, a java.lang.UnsupportedOperationException is thrown. + // We check here if Position is a range before calling .end on it. + // This is not needed for scala 2.11.x. + if (qualifier.pos.isRange) + handleTypeCompletion(qualifier.pos.end, name.decoded, dotOffset) + else + // not prefixed + (0, Seq.empty) + + case t @ pressy.Import(expr, selectors) => + // If the selectors haven't been defined yet... + if (selectors.head.name.toString == "") + if (expr.tpe.toString == "") + // If the expr is badly typed, try to scope complete it + if (expr.isInstanceOf[pressy.Ident]) { + val exprName = expr.asInstanceOf[pressy.Ident].name.decoded + val pos = + // Without the first case, things like `import ` are + // returned a wrong position. + if (exprName == "") expr.pos.point - 1 + else expr.pos.point + pos -> handleCompletion( + ask(expr.pos.point, pressy.askScopeCompletion), + // if it doesn't have a name at all, accept anything + if (exprName == "") "" else exprName + ) + } + else (expr.pos.point, Seq.empty) + else + // If the expr is well typed, type complete + // the next thing + handleTypeCompletion(expr.pos.end, "", 1) + else { + val isImportIvy = expr.isInstanceOf[pressy.Ident] && + expr.asInstanceOf[pressy.Ident].name.decoded == "$ivy" + val selector = selectors + .filter(s => Math.max(s.namePos, s.renamePos) <= index) + .lastOption + .getOrElse(selectors.last) + + if (isImportIvy) { + def forceOpenedBacktick(s: String): String = { + val res = Name(s).backticked + if (res.startsWith("`")) res.stripSuffix("`") + else "`" + res + } + dependencyCompleteOpt match { + case None => (0, Seq.empty[(String, Option[String], String)]) + case Some(complete) => + val input = selector.name.decoded + val (pos, completions) = complete(input) + val input0 = input.take(pos) + ( + selector.namePos, + completions.map(s => (forceOpenedBacktick(input0 + s), None, "dependency")) + ) + } + } + else + // just use typeCompletion + handleTypeCompletion(selector.namePos, selector.name.decoded, 0) + } + case t @ pressy.Ident(name) => + lazy val shallow = handleCompletion( + ask(index, pressy.askScopeCompletion), + name.decoded + ) + lazy val deep = deepCompletion(name.decoded).distinct.map(t => (t._1, None, t._2)) + + val res = + if (shallow.length > 0) shallow + else if (deep.length == 1) deep + else deep :+ (("", None, "")) + + (t.pos.start, res) + + case t => + val comps = ask(index, pressy.askScopeCompletion) + + index -> pressy.ask(() => + comps.filter(m => !blacklisted(m.sym)) + .map(s => (memberToString(s), None, s.sym.kindString)) + ) + } + def ask(index: Int, query: (Position, Response[List[pressy.Member]]) => Unit) = { + val position = new OffsetPosition(currentFile, index) + // if a match can't be found awaitResponse throws an Exception. + val result = Try { + ammonite.compiler.Compiler.awaitResponse[List[pressy.Member]]( + query(position, _) + ) + } + result.toEither match { + case Right(scopes) => scopes.filter(_.accessible) + case Left(error) => List.empty[pressy.Member] + } + } + } + +} diff --git a/modules/scala/scala-interpreter/src/main/scala-3/almond/internals/ScalaInterpreterCompletions.scala b/modules/scala/scala-interpreter/src/main/scala-3/almond/internals/ScalaInterpreterCompletions.scala new file mode 100644 index 000000000..7a5abcc5a --- /dev/null +++ b/modules/scala/scala-interpreter/src/main/scala-3/almond/internals/ScalaInterpreterCompletions.scala @@ -0,0 +1,191 @@ +package almond.internals + +import almond.logger.LoggerContext +import ammonite.compiler.Compiler +import dotty.tools.dotc.{CompilationUnit, Compiler => DottyCompiler, Run, ScalacCommand} +import dotty.tools.dotc.ast.{tpd, untpd} +import dotty.tools.dotc.core.Contexts._ +import dotty.tools.dotc.core.Symbols.{defn, Symbol} +import dotty.tools.dotc.core.{Flags, MacroClassLoader, Mode} +import dotty.tools.dotc.interactive.Completion +import dotty.tools.dotc.util.Spans.Span +import dotty.tools.dotc.util.{Property, SourceFile, SourcePosition} + +import java.nio.charset.StandardCharsets + +object ScalaInterpreterCompletions { + + extension (completion: Completion) { + def finalLabel = completion.label.replace(".package$.", ".") + } + + private def newLine = System.lineSeparator() + + def complete( + compilerManager: ammonite.compiler.iface.CompilerLifecycleManager, + dependencyCompleteOpt: Option[String => (Int, Seq[String])], + snippetIndex: Int, + previousImports: String, + snippet: String, + logCtx: LoggerContext + ): (Int, Seq[String], Seq[String], Seq[(String, String)]) = { + + val compilerManager0 = compilerManager match { + case m: ammonite.compiler.CompilerLifecycleManager => m + case _ => ??? + } + + val compiler = compilerManager0.compiler.compiler + val initialCtx = compilerManager0.compiler.initialCtx + + val offset = snippetIndex + + val prefix = previousImports + newLine + + "object AutocompleteWrapper{ val expr: _root_.scala.Unit = {" + newLine + val suffix = newLine + "()}}" + val allCode = prefix + snippet + suffix + val index = offset + prefix.length + + // Originally based on + // https://github.com/lampepfl/dotty/blob/3.0.0-M1/ + // compiler/src/dotty/tools/repl/ReplDriver.scala/#L179-L191 + + val (tree, ctx0) = + tryTypeCheck(compiler, initialCtx, allCode.getBytes("UTF-8"), "") + val ctx = ctx0.fresh + val file = SourceFile.virtual("", allCode, maybeIncomplete = true) + val unit = CompilationUnit(file)(using ctx) + unit.tpdTree = { + given Context = ctx + import tpd._ + tree match { + case PackageDef(_, p) => + p.collectFirst { + case TypeDef(_, tmpl: Template) => + tmpl.body + .collectFirst { case dd: ValDef if dd.name.show == "expr" => dd } + .getOrElse(???) + }.getOrElse(???) + case _ => ??? + } + } + val ctx1 = ctx.fresh.setCompilationUnit(unit) + val srcPos = SourcePosition(file, Span(index)) + val (start, completions) = dotty.ammonite.compiler.AmmCompletion.completions( + srcPos, + dependencyCompleteOpt = dependencyCompleteOpt, + enableDeep = false + )(using ctx1) + + val blacklistedPackages = Set("shaded") + + def deepCompletion(name: String): List[String] = { + given Context = ctx1 + def rec(t: Symbol): Seq[Symbol] = + if (blacklistedPackages(t.name.toString)) + Nil + else { + val children = + if (t.is(Flags.Package) || t.is(Flags.PackageVal) || t.is(Flags.PackageClass)) + t.denot.info.allMembers.map(_.symbol).filter(_ != t).flatMap(rec) + else Nil + + t +: children.toSeq + } + + for { + member <- defn.RootClass.denot.info.allMembers.map(_.symbol).toList + sym <- rec(member) + // Scala 2 comment: sketchy name munging because I don't know how to do this properly + // Note lack of back-quoting support. + strippedName = sym.name.toString.stripPrefix("package$").stripSuffix("$") + if strippedName.startsWith(name) + (pref, _) = sym.fullName.toString.splitAt(sym.fullName.toString.lastIndexOf('.') + 1) + out = pref + strippedName + if out != "" + } yield out + } + + def blacklisted(s: Symbol) = { + given Context = ctx1 + val blacklist = Set( + "scala.Predef.any2stringadd.+", + "scala.Any.##", + "java.lang.Object.##", + "scala.", + "scala.", + "scala.", + "scala.", + "scala.Predef.StringFormat.formatted", + "scala.Predef.Ensuring.ensuring", + "scala.Predef.ArrowAssoc.->", + "scala.Predef.ArrowAssoc.→", + "java.lang.Object.synchronized", + "java.lang.Object.ne", + "java.lang.Object.eq", + "java.lang.Object.wait", + "java.lang.Object.notifyAll", + "java.lang.Object.notify", + "java.lang.Object.clone", + "java.lang.Object.finalize" + ) + + blacklist(s.showFullName) || + s.isOneOf(Flags.GivenOrImplicit) || + // Cache objects, which you should probably never need to + // access directly, and apart from that have annoyingly long names + "cache[a-f0-9]{32}".r.findPrefixMatchOf(s.name.decode.toString).isDefined || + // s.isDeprecated || + s.name.decode.toString == "" || + s.name.decode.toString.contains('$') + } + + val filteredCompletions = completions.filter { c => + c.symbols.isEmpty || c.symbols.exists(!blacklisted(_)) + } + val signatures = { + given Context = ctx1 + for { + c <- filteredCompletions + s <- c.symbols + isMethod = s.denot.is(Flags.Method) + if isMethod + } yield s"def ${s.name}${s.denot.info.widenTermRefExpr.show}" + } + val completionsWithTypes = { + given Context = ctx1 + for { + c <- filteredCompletions + s <- c.symbols + } yield (c.finalLabel, s.denot.kindString) + } + (start - prefix.length, filteredCompletions.map(_.finalLabel), signatures, completionsWithTypes) + } + + private def tryTypeCheck( + compiler: DottyCompiler, + initialCtx: Context, + src: Array[Byte], + fileName: String + ) = + val sourceFile = SourceFile.virtual(fileName, new String(src, StandardCharsets.UTF_8)) + + val reporter0 = Compiler.newStoreReporter() + val run = new Run( + compiler, + initialCtx.fresh + .addMode(Mode.ReadPositions | Mode.Interactive) + .setReporter(reporter0) + .setSetting(initialCtx.settings.YstopAfter, List("typer")) + ) + implicit val ctx: Context = run.runContext.withSource(sourceFile) + + val unit = + new CompilationUnit(ctx.source): + override def isSuspendable: Boolean = false + ctx + .run + .compileUnits(unit :: Nil, ctx) + + (unit.tpdTree, ctx) +} diff --git a/modules/scala/scala-interpreter/src/main/scala/almond/ScalaInterpreter.scala b/modules/scala/scala-interpreter/src/main/scala/almond/ScalaInterpreter.scala index 21c1982f6..38aa7b508 100644 --- a/modules/scala/scala-interpreter/src/main/scala/almond/ScalaInterpreter.scala +++ b/modules/scala/scala-interpreter/src/main/scala/almond/ScalaInterpreter.scala @@ -7,7 +7,7 @@ import almond.interpreter.api.{CommHandler, OutputHandler} import almond.interpreter.input.InputManager import almond.interpreter.util.AsyncInterpreterOps import almond.logger.LoggerContext -import almond.protocol.KernelInfo +import almond.protocol.{KernelInfo, RawJson} import almond.toree.{CellMagicHook, LineMagicHook} import ammonite.compiler.Parsers import ammonite.repl.{ReplApiImpl => _, _} @@ -16,6 +16,8 @@ import ammonite.util.{Frame => _, _} import coursier.cache.shaded.dirs.{GetWinDirs, ProjectDirectories} import fastparse.Parsed +import java.nio.charset.StandardCharsets + import scala.util.control.NonFatal /** Holds bits of state for the interpreter, and implements [[almond.interpreter.Interpreter]]. */ @@ -196,20 +198,87 @@ final class ScalaInterpreter( override def complete(code: String, pos: Int): Completion = { - val (newPos, completions0, _) = ammInterp.compilerManager.complete( + val (newPos, completions0, _, completionsWithTypes) = ScalaInterpreterCompletions.complete( + ammInterp.compilerManager, + Some(ammInterp.dependencyComplete), pos, (ammInterp.predefImports ++ frames0().head.imports).toString(), - code + code, + logCtx ) val completions = completions0 .filter(!_.contains("$")) .filter(_.nonEmpty) + val metadata = + if (java.lang.Boolean.getBoolean("almond.completion.demo") && code.startsWith("// Demo")) + // Types from https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#completionItemKind + RawJson( + """{ + | "_jupyter_types_experimental": [ + | { + | "text": "AField", + | "type": "Field" + | }, + | { + | "text": "AMethod", + | "type": "Method" + | }, + | { + | "text": "AConstructor", + | "type": "Constructor" + | }, + | { + | "text": "AVariable", + | "type": "Variable" + | }, + | { + | "text": "AClass", + | "type": "Class" + | }, + | { + | "text": "AnInterface", + | "type": "Interface" + | }, + | { + | "text": "AModule", + | "type": "Module" + | }, + | { + | "text": "AProperty", + | "type": "Property" + | } + | ] + |} + |""".stripMargin.getBytes(StandardCharsets.UTF_8) + ) + else { + val elems = completionsWithTypes.map { + case (compl, tpe) => + val tpe0 = tpe match { + case "value" => "Field" + case _ => tpe.capitalize + } + ujson.Obj( + "text" -> ujson.Str(compl), + "type" -> ujson.Str(tpe0) + ) + } + if (elems.isEmpty) + RawJson.emptyObj + else { + val json = ujson.Obj("_jupyter_types_experimental" -> ujson.Arr(elems: _*)).render() + RawJson(json.getBytes(StandardCharsets.UTF_8)) + } + } + Completion( if (completions.isEmpty) pos else newPos, pos, - completions.map(_.trim).distinct + completions.map(_.trim).distinct, + None, + metadata = metadata ) } diff --git a/modules/scala/scala-interpreter/src/test/scala/almond/ScalaInterpreterTests.scala b/modules/scala/scala-interpreter/src/test/scala/almond/ScalaInterpreterTests.scala index 25d49983e..a74cfdacc 100644 --- a/modules/scala/scala-interpreter/src/test/scala/almond/ScalaInterpreterTests.scala +++ b/modules/scala/scala-interpreter/src/test/scala/almond/ScalaInterpreterTests.scala @@ -4,6 +4,7 @@ import java.nio.file.{Path, Paths} import almond.interpreter.api.DisplayData import almond.interpreter.{Completion, ExecuteResult, Interpreter} +import almond.protocol.RawJson import almond.testkit.TestLogging.logCtx import almond.TestUtil._ import almond.amm.AmmInterpreter @@ -150,6 +151,13 @@ object ScalaInterpreterTests extends TestSuite { } } + private implicit class TestCompletionOps(private val compl: Completion) extends AnyVal { + def clearMetadata: Completion = + compl.copy( + metadata = RawJson.emptyObj + ) + } + val tests = Tests { test("execute") { @@ -200,7 +208,7 @@ object ScalaInterpreterTests extends TestSuite { test { val code = "repl.la" val expectedRes = Completion(5, 7, Seq("lastException")) - val res = interpreter.complete(code, code.length) + val res = interpreter.complete(code, code.length).clearMetadata assert(res == expectedRes) } @@ -208,7 +216,7 @@ object ScalaInterpreterTests extends TestSuite { val code = "Lis" val expectedRes = Completion(0, 3, Seq("List")) val alternativeExpectedRes = Completion(0, 3, Seq("scala.List")) - val res0 = interpreter.complete(code, code.length) + val res0 = interpreter.complete(code, code.length).clearMetadata val res = res0.copy( completions = res0.completions.filter(expectedRes.completions.toSet) ) @@ -241,7 +249,7 @@ object ScalaInterpreterTests extends TestSuite { "scala.collection.mutable.HashMap" ) ++ extraCompletions ) - val res0 = interpreter.complete(code, code.length) + val res0 = interpreter.complete(code, code.length).clearMetadata val res = res0.copy( completions = res0.completions.filter(expectedRes.completions.toSet) ) diff --git a/modules/scala/test-definitions/src/main/scala/almond/integration/Tests.scala b/modules/scala/test-definitions/src/main/scala/almond/integration/Tests.scala index 2b07d4aff..04bdab2cf 100644 --- a/modules/scala/test-definitions/src/main/scala/almond/integration/Tests.scala +++ b/modules/scala/test-definitions/src/main/scala/almond/integration/Tests.scala @@ -825,4 +825,337 @@ object Tests { ) } + def completion(scalaVersion: String)(implicit + sessionId: SessionId, + runner: Runner + ): Unit = + runner.withSession() { implicit session => + execute( + "val l = 1 :: 2 :: Nil", + "l: List[Int] = List(1, 2)" + ) + val res = complete( + "l.#" + ) + + val expectedMatches = Seq( + "!=", + "++", + "++:", + "+:", + ":+", + "::", + ":::", + "==", + "WithFilter", + "addString", + "aggregate", + "andThen", + "apply", + "applyOrElse", + "asInstanceOf", + "canEqual", + "collect", + "collectFirst", + "combinations", + "companion", + "compose", + "contains", + "containsSlice", + "copyToArray", + "copyToBuffer", + "corresponds", + "count", + "diff", + "distinct", + "drop", + "dropRight", + "dropWhile", + "endsWith", + "equals", + "exists", + "filter", + "filterNot", + "find", + "flatMap", + "flatten", + "fold", + "foldLeft", + "foldRight", + "forall", + "foreach", + "genericBuilder", + "getClass", + "groupBy", + "grouped", + "hasDefiniteSize", + "hashCode", + "head", + "headOption", + "indexOf", + "indexOfSlice", + "indexWhere", + "indices", + "init", + "inits", + "intersect", + "isDefinedAt", + "isEmpty", + "isInstanceOf", + "isTraversableAgain", + "iterator", + "last", + "lastIndexOf", + "lastIndexOfSlice", + "lastIndexWhere", + "lastOption", + "length", + "lengthCompare", + "lift", + "map", + "mapConserve", + "max", + "maxBy", + "min", + "minBy", + "mkString", + "nonEmpty", + "orElse", + "padTo", + "par", + "partition", + "patch", + "permutations", + "prefixLength", + "product", + "productArity", + "productElement", + "productIterator", + "productPrefix", + "reduce", + "reduceLeft", + "reduceLeftOption", + "reduceOption", + "reduceRight", + "reduceRightOption", + "repr", + "reverse", + "reverseIterator", + "reverseMap", + "reverse_:::", + "runWith", + "sameElements", + "scan", + "scanLeft", + "scanRight", + "segmentLength", + "seq", + "size", + "slice", + "sliding", + "sortBy", + "sortWith", + "sorted", + "span", + "splitAt", + "startsWith", + "stringPrefix", + "sum", + "tail", + "tails", + "take", + "takeRight", + "takeWhile", + "to", + "toArray", + "toBuffer", + "toIndexedSeq", + "toIterable", + "toIterator", + "toList", + "toMap", + "toParArray", + "toSeq", + "toSet", + "toStream", + "toString", + "toTraversable", + "toVector", + "transpose", + "union", + "unzip", + "unzip3", + "updated", + "view", + "withFilter", + "zip", + "zipAll", + "zipWithIndex" + ) + val matches = res.flatMap(_.matches) + expect(matches == expectedMatches) + + val expectedTypes = Seq( + ("!=", "Method"), + ("++", "Method"), + ("++:", "Method"), + ("+:", "Method"), + (":+", "Method"), + ("::", "Method"), + (":::", "Method"), + ("==", "Method"), + ("WithFilter", "Class"), + ("addString", "Method"), + ("aggregate", "Method"), + ("andThen", "Method"), + ("apply", "Method"), + ("applyOrElse", "Method"), + ("asInstanceOf", "Method"), + ("canEqual", "Method"), + ("collect", "Method"), + ("collectFirst", "Method"), + ("combinations", "Method"), + ("companion", "Method"), + ("compose", "Method"), + ("contains", "Method"), + ("containsSlice", "Method"), + ("copyToArray", "Method"), + ("copyToBuffer", "Method"), + ("corresponds", "Method"), + ("count", "Method"), + ("diff", "Method"), + ("distinct", "Method"), + ("drop", "Method"), + ("dropRight", "Method"), + ("dropWhile", "Method"), + ("endsWith", "Method"), + ("equals", "Method"), + ("exists", "Method"), + ("filter", "Method"), + ("filterNot", "Method"), + ("find", "Method"), + ("flatMap", "Method"), + ("flatten", "Method"), + ("fold", "Method"), + ("foldLeft", "Method"), + ("foldRight", "Method"), + ("forall", "Method"), + ("foreach", "Method"), + ("genericBuilder", "Method"), + ("getClass", "Method"), + ("groupBy", "Method"), + ("grouped", "Method"), + ("hasDefiniteSize", "Method"), + ("hashCode", "Method"), + ("head", "Method"), + ("headOption", "Method"), + ("indexOf", "Method"), + ("indexOfSlice", "Method"), + ("indexWhere", "Method"), + ("indices", "Method"), + ("init", "Method"), + ("inits", "Method"), + ("intersect", "Method"), + ("isDefinedAt", "Method"), + ("isEmpty", "Method"), + ("isInstanceOf", "Method"), + ("isTraversableAgain", "Method"), + ("iterator", "Method"), + ("last", "Method"), + ("lastIndexOf", "Method"), + ("lastIndexOfSlice", "Method"), + ("lastIndexWhere", "Method"), + ("lastOption", "Method"), + ("length", "Method"), + ("lengthCompare", "Method"), + ("lift", "Method"), + ("map", "Method"), + ("mapConserve", "Method"), + ("max", "Method"), + ("maxBy", "Method"), + ("min", "Method"), + ("minBy", "Method"), + ("mkString", "Method"), + ("nonEmpty", "Method"), + ("orElse", "Method"), + ("padTo", "Method"), + ("par", "Method"), + ("partition", "Method"), + ("patch", "Method"), + ("permutations", "Method"), + ("prefixLength", "Method"), + ("product", "Method"), + ("productArity", "Method"), + ("productElement", "Method"), + ("productIterator", "Method"), + ("productPrefix", "Method"), + ("reduce", "Method"), + ("reduceLeft", "Method"), + ("reduceLeftOption", "Method"), + ("reduceOption", "Method"), + ("reduceRight", "Method"), + ("reduceRightOption", "Method"), + ("repr", "Method"), + ("reverse", "Method"), + ("reverseIterator", "Method"), + ("reverseMap", "Method"), + ("reverse_:::", "Method"), + ("runWith", "Method"), + ("sameElements", "Method"), + ("scan", "Method"), + ("scanLeft", "Method"), + ("scanRight", "Method"), + ("segmentLength", "Method"), + ("seq", "Method"), + ("size", "Method"), + ("slice", "Method"), + ("sliding", "Method"), + ("sortBy", "Method"), + ("sortWith", "Method"), + ("sorted", "Method"), + ("span", "Method"), + ("splitAt", "Method"), + ("startsWith", "Method"), + ("stringPrefix", "Method"), + ("sum", "Method"), + ("tail", "Method"), + ("tails", "Method"), + ("take", "Method"), + ("takeRight", "Method"), + ("takeWhile", "Method"), + ("to", "Method"), + ("toArray", "Method"), + ("toBuffer", "Method"), + ("toIndexedSeq", "Method"), + ("toIterable", "Method"), + ("toIterator", "Method"), + ("toList", "Method"), + ("toMap", "Method"), + ("toParArray", "Method"), + ("toSeq", "Method"), + ("toSet", "Method"), + ("toStream", "Method"), + ("toString", "Method"), + ("toTraversable", "Method"), + ("toVector", "Method"), + ("transpose", "Method"), + ("union", "Method"), + ("unzip", "Method"), + ("unzip3", "Method"), + ("updated", "Method"), + ("view", "Method"), + ("withFilter", "Method"), + ("zip", "Method"), + ("zipAll", "Method"), + ("zipWithIndex", "Method") + ) + + val metadata = ujson.read(res.head.metadata.value) + val types = metadata.obj("_jupyter_types_experimental") + val types0 = types.arr.map { entry => + val entry0 = entry.obj + (entry0("text").str, entry0("type").str) + } + expect(types0 == expectedTypes) + } } diff --git a/modules/shared/interpreter/src/main/scala/almond/interpreter/Completion.scala b/modules/shared/interpreter/src/main/scala/almond/interpreter/Completion.scala index 826d52e11..510b93d66 100644 --- a/modules/shared/interpreter/src/main/scala/almond/interpreter/Completion.scala +++ b/modules/shared/interpreter/src/main/scala/almond/interpreter/Completion.scala @@ -16,12 +16,16 @@ final case class Completion( from: Int, until: Int, completions: Seq[String], + completionWithTypes: Option[Seq[(String, String)]], metadata: RawJson -) +) { + def withCompletionWithTypes(completionWithTypes: Seq[(String, String)]): Completion = + copy(completionWithTypes = Some(completionWithTypes)) +} object Completion { def apply(from: Int, until: Int, completions: Seq[String]): Completion = - Completion(from, until, completions, RawJson.emptyObj) + Completion(from, until, completions, None, RawJson.emptyObj) def empty(pos: Int): Completion = - Completion(pos, pos, Nil, RawJson.emptyObj) + Completion(pos, pos, Nil, None, RawJson.emptyObj) } diff --git a/modules/shared/interpreter/src/test/scala/almond/interpreter/TestInterpreter.scala b/modules/shared/interpreter/src/test/scala/almond/interpreter/TestInterpreter.scala index fb2f63d48..dc975f304 100644 --- a/modules/shared/interpreter/src/test/scala/almond/interpreter/TestInterpreter.scala +++ b/modules/shared/interpreter/src/test/scala/almond/interpreter/TestInterpreter.scala @@ -84,7 +84,7 @@ final class TestInterpreter extends Interpreter { ) } else if (code.startsWith("meta:")) { - val c = Completion(pos, pos, Seq("sent"), RawJson(code.drop("meta:".length).bytes)) + val c = Completion(pos, pos, Seq("sent"), None, RawJson(code.drop("meta:".length).bytes)) CancellableFuture(Future.successful(c), () => sys.error("should not happen")) } else diff --git a/modules/shared/test-kit/src/main/scala/almond/testkit/ClientStreams.scala b/modules/shared/test-kit/src/main/scala/almond/testkit/ClientStreams.scala index 382c1bf02..c4f314d4f 100644 --- a/modules/shared/test-kit/src/main/scala/almond/testkit/ClientStreams.scala +++ b/modules/shared/test-kit/src/main/scala/almond/testkit/ClientStreams.scala @@ -5,7 +5,7 @@ import almond.interpreter.Message import almond.interpreter.messagehandlers.MessageHandler import almond.protocol.Codecs.stringCodec import almond.protocol.Execute.DisplayData -import almond.protocol.{Execute, Inspect, MessageType, RawJson} +import almond.protocol.{Complete, Execute, Inspect, MessageType, RawJson} import cats.effect.IO import cats.effect.std.Queue import cats.effect.unsafe.IORuntime @@ -265,6 +265,16 @@ final case class ClientStreams( .flatten .toVector + def completeReplies: Seq[Complete.Reply] = + generatedMessages + .iterator + .collect { + case Left((Channel.Requests, m)) if m.header.msg_type == Complete.replyType.messageType => + m.decodeAs[Complete.Reply].toOption.map(_.content).toSeq + } + .flatten + .toVector + } object ClientStreams { diff --git a/modules/shared/test-kit/src/main/scala/almond/testkit/Dsl.scala b/modules/shared/test-kit/src/main/scala/almond/testkit/Dsl.scala index d88c9066b..cd75f7eaa 100644 --- a/modules/shared/test-kit/src/main/scala/almond/testkit/Dsl.scala +++ b/modules/shared/test-kit/src/main/scala/almond/testkit/Dsl.scala @@ -361,4 +361,48 @@ object Dsl { Inspect.Request(code, pos, if (detailed) 1 else 0) ).on(Channel.Requests) + def complete( + code: String, + pos: Int = -1 + )(implicit + sessionId: SessionId, + session: Session + ): Seq[Complete.Reply] = { + + val (code0, pos0) = + if (pos >= 0) (code, pos) + else { + val cursor = "#" + val idx = code.indexOf(cursor) + assert(idx >= 0, "Expected a # character in code to complete, at the cursor position") + (code.take(idx) + code.drop(idx + cursor.length), idx) + } + + val input = Stream( + completeMessage(code0, pos0) + ) + + val streams = ClientStreams.create(input, stopWhen(Complete.replyType.messageType)) + + session.run(streams) + + streams.completeReplies + } + + private def completeMessage( + code: String, + pos: Int, + msgId: String = UUID.randomUUID().toString + )(implicit sessionId: SessionId) = + Message( + Header( + msgId, + "test", + sessionId.sessionId, + Complete.requestType.messageType, + Some(Protocol.versionStr) + ), + Complete.Request(code, pos) + ).on(Channel.Requests) + }