Skip to content

Commit

Permalink
Add class script wrapper (#2033)
Browse files Browse the repository at this point in the history
* Add class code wrapper for scala 3 scripts

* Move wrapping scripts to right before building

* Add warning when @main is used in scripts wrapped in class

* NIT Move mainClassObject def to CodeWrapper object

* Add new types to enforce wrapping scripts before building project
  • Loading branch information
MaciejG604 committed May 8, 2023
1 parent 9dee77c commit 77237f0
Show file tree
Hide file tree
Showing 27 changed files with 651 additions and 159 deletions.
10 changes: 6 additions & 4 deletions modules/build/src/main/scala/scala/build/Build.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import scala.build.errors.*
import scala.build.input.VirtualScript.VirtualScriptNameRegex
import scala.build.input.*
import scala.build.internal.resource.ResourceMapper
import scala.build.internal.{Constants, CustomCodeWrapper, MainClass, Util}
import scala.build.internal.{Constants, MainClass, Util}
import scala.build.options.ScalaVersionUtil.asVersion
import scala.build.options.*
import scala.build.options.validation.ValidationException
Expand Down Expand Up @@ -227,7 +227,6 @@ object Build {
CrossSources.forInputs(
inputs,
Sources.defaultPreprocessors(
options.scriptOptions.codeWrapper.getOrElse(CustomCodeWrapper),
options.archiveCache,
options.internal.javaClassNameVersionOpt,
() => options.javaHome().value.javaCommand
Expand Down Expand Up @@ -266,8 +265,11 @@ object Build {
overrideOptions: BuildOptions
): Either[BuildException, NonCrossBuilds] = either {

val baseOptions = overrideOptions.orElse(sharedOptions)
val scopedSources = value(crossSources.scopedSources(baseOptions))
val baseOptions = overrideOptions.orElse(sharedOptions)

val wrappedScriptsSources = crossSources.withWrappedScripts(baseOptions)

val scopedSources = value(wrappedScriptsSources.scopedSources(baseOptions))

val mainSources = scopedSources.sources(Scope.Main, baseOptions)
val mainOptions = mainSources.buildOptions
Expand Down
112 changes: 103 additions & 9 deletions modules/build/src/main/scala/scala/build/CrossSources.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,94 @@ import scala.build.testrunner.DynamicTestRunner.globPattern
import scala.util.Try
import scala.util.chaining.*

final case class CrossSources(
/** CrossSources with unwrapped scripts, use [[withWrappedScripts]] to wrap them and obtain an
* instance of CrossSources
*
* See [[CrossSources]] for more information
*
* @param paths
* paths and realtive paths to sources on disk, wrapped in their build requirements
* @param inMemory
* in memory sources (e.g. snippets) wrapped in their build requirements
* @param defaultMainClass
* @param resourceDirs
* @param buildOptions
* build options from sources
* @param unwrappedScripts
* in memory script sources, their code must be wrapped before compiling
*/
sealed class UnwrappedCrossSources(
paths: Seq[WithBuildRequirements[(os.Path, os.RelPath)]],
inMemory: Seq[WithBuildRequirements[Sources.InMemory]],
defaultMainClass: Option[String],
resourceDirs: Seq[WithBuildRequirements[os.Path]],
buildOptions: Seq[WithBuildRequirements[BuildOptions]]
buildOptions: Seq[WithBuildRequirements[BuildOptions]],
unwrappedScripts: Seq[WithBuildRequirements[Sources.UnwrappedScript]]
) {

/** For all unwrapped script sources contained in this object wrap them according to provided
* BuildOptions
*
* @param buildOptions
* options used to choose the script wrapper
* @return
* CrossSources with all the scripts wrapped
*/
def withWrappedScripts(buildOptions: BuildOptions): CrossSources = {
val codeWrapper = ScriptPreprocessor.getScriptWrapper(buildOptions)

val wrappedScripts = unwrappedScripts.map { unwrapppedWithRequirements =>
unwrapppedWithRequirements.map(_.wrap(codeWrapper))
}

CrossSources(
paths,
inMemory ++ wrappedScripts,
defaultMainClass,
resourceDirs,
this.buildOptions
)
}

def sharedOptions(baseOptions: BuildOptions): BuildOptions =
buildOptions
.filter(_.requirements.isEmpty)
.map(_.value)
.foldLeft(baseOptions)(_ orElse _)

private def needsScalaVersion =
protected def needsScalaVersion =
paths.exists(_.needsScalaVersion) ||
inMemory.exists(_.needsScalaVersion) ||
resourceDirs.exists(_.needsScalaVersion) ||
buildOptions.exists(_.needsScalaVersion)
}

/** Information gathered from preprocessing command inputs - sources and build options from using
* directives
*
* @param paths
* paths and realtive paths to sources on disk, wrapped in their build requirements
* @param inMemory
* in memory sources (e.g. snippets and wrapped scripts) wrapped in their build requirements
* @param defaultMainClass
* @param resourceDirs
* @param buildOptions
* build options from sources
*/
final case class CrossSources(
paths: Seq[WithBuildRequirements[(os.Path, os.RelPath)]],
inMemory: Seq[WithBuildRequirements[Sources.InMemory]],
defaultMainClass: Option[String],
resourceDirs: Seq[WithBuildRequirements[os.Path]],
buildOptions: Seq[WithBuildRequirements[BuildOptions]]
) extends UnwrappedCrossSources(
paths,
inMemory,
defaultMainClass,
resourceDirs,
buildOptions,
Nil
) {
def scopedSources(baseOptions: BuildOptions): Either[BuildException, ScopedSources] = either {

val sharedOptions0 = sharedOptions(baseOptions)
Expand Down Expand Up @@ -114,7 +182,6 @@ final case class CrossSources(
crossSources0.buildOptions.map(_.scopedValue(defaultScope))
)
}

}

object CrossSources {
Expand All @@ -141,7 +208,7 @@ object CrossSources {
suppressWarningOptions: SuppressWarningOptions,
exclude: Seq[Positioned[String]] = Nil,
maybeRecoverOnError: BuildException => Option[BuildException] = e => Some(e)
)(using ScalaCliInvokeData): Either[BuildException, (CrossSources, Inputs)] = either {
)(using ScalaCliInvokeData): Either[BuildException, (UnwrappedCrossSources, Inputs)] = either {

def preprocessSources(elems: Seq[SingleElement])
: Either[BuildException, Seq[PreprocessedSource]] =
Expand Down Expand Up @@ -262,6 +329,16 @@ object CrossSources {
Sources.InMemory(m.originalPath, m.relPath, m.code, m.ignoreLen)
) -> m.directivesPositions
}
val unwrappedScriptsWithDirectivePositions
: Seq[(WithBuildRequirements[Sources.UnwrappedScript], Option[DirectivesPositions])] =
preprocessedSources.collect {
case m: PreprocessedSource.UnwrappedScript =>
val baseReqs0 = baseReqs(m.scopePath)
WithBuildRequirements(
m.requirements.fold(baseReqs0)(_ orElse baseReqs0),
Sources.UnwrappedScript(m.originalPath, m.relPath, m.wrapScriptFun)
) -> m.directivesPositions
}

val resourceDirs: Seq[WithBuildRequirements[os.Path]] = allInputs.elements.collect {
case r: ResourceDirectory =>
Expand All @@ -271,14 +348,20 @@ object CrossSources {
)

lazy val allPathsWithDirectivesByScope: Map[Scope, Seq[(os.Path, DirectivesPositions)]] =
(pathsWithDirectivePositions ++ inMemoryWithDirectivePositions)
(pathsWithDirectivePositions ++
inMemoryWithDirectivePositions ++
unwrappedScriptsWithDirectivePositions)
.flatMap { (withBuildRequirements, directivesPositions) =>
val scope = withBuildRequirements.scopedValue(Scope.Main).scope
val path: os.Path = withBuildRequirements.value match
case im: Sources.InMemory =>
im.originalPath match
case Right((_, p: os.Path)) => p
case _ => inputs.workspace / im.generatedRelPath
case us: Sources.UnwrappedScript =>
us.originalPath match
case Right((_, p: os.Path)) => p
case _ => inputs.workspace / us.generatedRelPath
case (p: os.Path, _) => p
directivesPositions.map((path, scope, _))
}
Expand Down Expand Up @@ -306,9 +389,20 @@ object CrossSources {
}
}

val paths = pathsWithDirectivePositions.map(_._1)
val inMemory = inMemoryWithDirectivePositions.map(_._1)
(CrossSources(paths, inMemory, defaultMainClassOpt, resourceDirs, buildOptions), allInputs)
val paths = pathsWithDirectivePositions.map(_._1)
val inMemory = inMemoryWithDirectivePositions.map(_._1)
val unwrappedScripts = unwrappedScriptsWithDirectivePositions.map(_._1)
(
UnwrappedCrossSources(
paths,
inMemory,
defaultMainClassOpt,
resourceDirs,
buildOptions,
unwrappedScripts
),
allInputs
)
}

private def resolveInputsFromSources(sources: Seq[Positioned[os.Path]], enableMarkdown: Boolean) =
Expand Down
14 changes: 12 additions & 2 deletions modules/build/src/main/scala/scala/build/Sources.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,17 @@ object Sources {
topWrapperLen: Int
)

final case class UnwrappedScript(
originalPath: Either[String, (os.SubPath, os.Path)],
generatedRelPath: os.RelPath,
wrapScriptFun: CodeWrapper => (String, Int)
) {
def wrap(wrapper: CodeWrapper): InMemory = {
val (content, topWrapperLen) = wrapScriptFun(wrapper)
InMemory(originalPath, generatedRelPath, content, topWrapperLen)
}
}

/** The default preprocessor list.
*
* @param codeWrapper
Expand All @@ -86,13 +97,12 @@ object Sources {
* @return
*/
def defaultPreprocessors(
codeWrapper: CodeWrapper,
archiveCache: ArchiveCache[Task],
javaClassNameVersionOpt: Option[String],
javaCommand: () => String
): Seq[Preprocessor] =
Seq(
ScriptPreprocessor(codeWrapper),
ScriptPreprocessor,
MarkdownPreprocessor,
JavaPreprocessor(archiveCache, javaClassNameVersionOpt, javaCommand),
ScalaPreprocessor,
Expand Down
13 changes: 12 additions & 1 deletion modules/build/src/main/scala/scala/build/bsp/BspClient.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package scala.build.bsp

import ch.epfl.scala.{bsp4j => b}
import ch.epfl.scala.bsp4j as b

import java.lang.Boolean as JBoolean
import java.net.URI
Expand All @@ -10,6 +10,7 @@ import java.util.concurrent.{ConcurrentHashMap, ExecutorService}
import scala.build.Position.File
import scala.build.bsp.protocol.TextEdit
import scala.build.errors.{BuildException, CompositeBuildException, Diagnostic, Severity}
import scala.build.internal.util.WarningMessages
import scala.build.postprocessing.LineConversion
import scala.build.{BloopBuildClient, GeneratedSource, Logger}
import scala.jdk.CollectionConverters.*
Expand Down Expand Up @@ -48,6 +49,16 @@ class BspClient(
val diag0 = diag.duplicate()
diag0.getRange.getStart.setLine(startLine)
diag0.getRange.getEnd.setLine(endLine)

if (
diag0.getMessage.contains(
"cannot be a main method since it cannot be accessed statically"
)
)
diag0.setMessage(
WarningMessages.mainAnnotationNotSupported( /* annotationIgnored */ false)
)

diag0
}
updatedDiagOpt.getOrElse(diag)
Expand Down
16 changes: 10 additions & 6 deletions modules/build/src/main/scala/scala/build/bsp/BspImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import scala.build.errors.{
ParsingInputsException
}
import scala.build.input.{Inputs, ScalaCliInvokeData}
import scala.build.internal.{Constants, CustomCodeWrapper}
import scala.build.internal.Constants
import scala.build.options.{BuildOptions, Scope}
import scala.collection.mutable.ListBuffer
import scala.concurrent.duration.DurationInt
Expand Down Expand Up @@ -101,7 +101,6 @@ final class BspImpl(
CrossSources.forInputs(
inputs = inputs,
preprocessors = Sources.defaultPreprocessors(
buildOptions.scriptOptions.codeWrapper.getOrElse(CustomCodeWrapper),
buildOptions.archiveCache,
buildOptions.internal.javaClassNameVersionOpt,
() => buildOptions.javaHome().value.javaCommand
Expand All @@ -113,16 +112,21 @@ final class BspImpl(
).left.map((_, Scope.Main))
}

val wrappedScriptsSources = crossSources.withWrappedScripts(buildOptions)

if (verbosity >= 3)
pprint.err.log(crossSources)
pprint.err.log(wrappedScriptsSources)

val scopedSources = value(crossSources.scopedSources(buildOptions).left.map((_, Scope.Main)))
val scopedSources =
value(wrappedScriptsSources.scopedSources(buildOptions).left.map((_, Scope.Main)))

if (verbosity >= 3)
pprint.err.log(scopedSources)

val sourcesMain = scopedSources.sources(Scope.Main, crossSources.sharedOptions(buildOptions))
val sourcesTest = scopedSources.sources(Scope.Test, crossSources.sharedOptions(buildOptions))
val sourcesMain =
scopedSources.sources(Scope.Main, wrappedScriptsSources.sharedOptions(buildOptions))
val sourcesTest =
scopedSources.sources(Scope.Test, wrappedScriptsSources.sharedOptions(buildOptions))

if (verbosity >= 3)
pprint.err.log(sourcesMain)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package scala.build.internal

/** Script code wrapper that solves problem of deadlocks when using threads. The code is placed in a
* class instance constructor, the created object is kept in 'mainObjectCode'.script to support
* running interconnected scripts using Scala CLI <br> <br> Incompatible with Scala 2 - it uses
* Scala 3 feature 'export'<br> Incompatible with native JS members - the wrapper is a class
*/
case object ClassCodeWrapper extends CodeWrapper {
private val userCodeNestingLevel = 1
def apply(
code: String,
pkgName: Seq[Name],
indexedWrapperName: Name,
extraCode: String,
scriptPath: String
) = {
val name = CodeWrapper.mainClassObject(indexedWrapperName).backticked
val wrapperClassName = Name(indexedWrapperName.raw ++ "$_").backticked
val mainObjectCode =
AmmUtil.normalizeNewlines(s"""|object $name {
| private var args$$opt0 = Option.empty[Array[String]]
| def args$$set(args: Array[String]): Unit = {
| args$$opt0 = Some(args)
| }
| def args$$opt: Option[Array[String]] = args$$opt0
| def args$$: Array[String] = args$$opt.getOrElse {
| sys.error("No arguments passed to this script")
| }
|
| lazy val script = new $wrapperClassName
|
| def main(args: Array[String]): Unit = {
| args$$set(args)
| script.hashCode() // hashCode to clear scalac warning about pure expression in statement position
| }
|}
|
|export $name.script as ${indexedWrapperName.backticked}
|""".stripMargin)

val packageDirective =
if (pkgName.isEmpty) "" else s"package ${AmmUtil.encodeScalaSourcePath(pkgName)}" + "\n"

// indentation is important in the generated code, so we don't want scalafmt to touch that
// format: off
val top = AmmUtil.normalizeNewlines(s"""
$packageDirective


final class $wrapperClassName {
def args = $name.args$$
def scriptPath = \"\"\"$scriptPath\"\"\"
""")
val bottom = AmmUtil.normalizeNewlines(s"""
$extraCode
}

$mainObjectCode
""")
// format: on

(top, bottom, userCodeNestingLevel)
}
}

0 comments on commit 77237f0

Please sign in to comment.