Skip to content

Commit

Permalink
Invalid subcommand arg (#1811)
Browse files Browse the repository at this point in the history
* Add better input error descriptions

* Update documentation

* Make invoke data an implicit parameter for inputs functions

* Add shell checking for whether shebang is valid

* NIT Refactor implicit to given/using, simplify validate args if clauses
  • Loading branch information
MaciejG604 committed Feb 14, 2023
1 parent c200f61 commit 816f91b
Show file tree
Hide file tree
Showing 15 changed files with 302 additions and 112 deletions.
164 changes: 101 additions & 63 deletions modules/build/src/main/scala/scala/build/input/Inputs.scala
Expand Up @@ -262,62 +262,96 @@ object Inputs {
download: String => Either[String, Array[Byte]],
stdinOpt: => Option[Array[Byte]],
acceptFds: Boolean,
enableMarkdown: Boolean,
isRunWithShebang: Boolean
): Seq[Either[String, Seq[Element]]] = args.zipWithIndex.map {
case (arg, idx) =>
lazy val path = os.Path(arg, cwd)
lazy val dir = path / os.up
lazy val subPath = path.subRelativeTo(dir)
lazy val stdinOpt0 = stdinOpt
lazy val content = os.read.bytes(path)
if (arg == "-.scala" || arg == "_" || arg == "_.scala") && stdinOpt0.nonEmpty then
Right(Seq(VirtualScalaFile(stdinOpt0.get, "<stdin>-scala-file")))
else if (arg == "-.java" || arg == "_.java") && stdinOpt0.nonEmpty then
Right(Seq(VirtualJavaFile(stdinOpt0.get, "<stdin>-java-file")))
else if (arg == "-" || arg == "-.sc" || arg == "_.sc") && stdinOpt0.nonEmpty then
Right(Seq(VirtualScript(stdinOpt0.get, "stdin", os.sub / "stdin.sc")))
else if (arg == "-.md" || arg == "_.md") && stdinOpt0.nonEmpty then
Right(Seq(VirtualMarkdownFile(stdinOpt0.get, "<stdin>-markdown-file", os.sub / "stdin.md")))
else if arg.endsWith(".zip") && os.exists(path) then
Right(resolveZipArchive(content, enableMarkdown))
else if arg.contains("://") then {
val isGithubGist = githubGistsArchiveRegex.findFirstMatchIn(arg).nonEmpty
val url = if isGithubGist then s"$arg/download" else arg
download(url).map { urlContent =>
if isGithubGist then resolveZipArchive(urlContent, enableMarkdown)
else List(Virtual(url, urlContent))
enableMarkdown: Boolean
)(using programInvokeData: ScalaCliInvokeData): Seq[Either[String, Seq[Element]]] =
args.zipWithIndex.map {
case (arg, idx) =>
lazy val path = os.Path(arg, cwd)
lazy val dir = path / os.up
lazy val subPath = path.subRelativeTo(dir)
lazy val stdinOpt0 = stdinOpt
lazy val content = os.read.bytes(path)
lazy val fullProgramCall = programInvokeData.progName +
s"${
if programInvokeData.subCommand == SubCommand.Default then ""
else s" ${programInvokeData.subCommandName}"
}"
val unrecognizedSourceError =
s"$arg: unrecognized source type (expected .scala or .sc extension, or a directory)"

if (arg == "-.scala" || arg == "_" || arg == "_.scala") && stdinOpt0.nonEmpty then
Right(Seq(VirtualScalaFile(stdinOpt0.get, "<stdin>-scala-file")))
else if (arg == "-.java" || arg == "_.java") && stdinOpt0.nonEmpty then
Right(Seq(VirtualJavaFile(stdinOpt0.get, "<stdin>-java-file")))
else if (arg == "-" || arg == "-.sc" || arg == "_.sc") && stdinOpt0.nonEmpty then
Right(Seq(VirtualScript(stdinOpt0.get, "stdin", os.sub / "stdin.sc")))
else if (arg == "-.md" || arg == "_.md") && stdinOpt0.nonEmpty then
Right(Seq(VirtualMarkdownFile(
stdinOpt0.get,
"<stdin>-markdown-file",
os.sub / "stdin.md"
)))
else if arg.endsWith(".zip") && os.exists(path) then
Right(resolveZipArchive(content, enableMarkdown))
else if arg.contains("://") then {
val isGithubGist = githubGistsArchiveRegex.findFirstMatchIn(arg).nonEmpty
val url = if isGithubGist then s"$arg/download" else arg
download(url).map { urlContent =>
if isGithubGist then resolveZipArchive(urlContent, enableMarkdown)
else List(Virtual(url, urlContent))
}
}
}
else if path.last == Constants.projectFileName then Right(Seq(ProjectScalaFile(dir, subPath)))
else if arg.endsWith(".sc") then Right(Seq(Script(dir, subPath)))
else if arg.endsWith(".scala") then Right(Seq(SourceScalaFile(dir, subPath)))
else if arg.endsWith(".java") then Right(Seq(JavaFile(dir, subPath)))
else if arg.endsWith(".jar") then Right(Seq(JarFile(dir, subPath)))
else if arg.endsWith(".c") || arg.endsWith(".h") then Right(Seq(CFile(dir, subPath)))
else if arg.endsWith(".md") then Right(Seq(MarkdownFile(dir, subPath)))
else if os.isDir(path) then Right(Seq(Directory(path)))
else if acceptFds && arg.startsWith("/dev/fd/") then
Right(Seq(VirtualScript(content, arg, os.sub / s"input-${idx + 1}.sc")))
else if isRunWithShebang && os.exists(path) then
if isShebangScript(String(content)) then Right(Seq(Script(dir, subPath)))
else
Left(s"""$arg does not contain shebang header
|possible fixes:
| Add '#!/usr/bin/env scala-cli shebang' to the top of the file
| Add extension to the file's name e.q. '.sc'
|""".stripMargin)
else {
val msg =
if os.exists(path) then
if isShebangScript(String(content)) then
s"$arg scripts with no file extension should be run with 'scala-cli shebang'"
else if path.last == Constants.projectFileName then
Right(Seq(ProjectScalaFile(dir, subPath)))
else if arg.endsWith(".sc") then Right(Seq(Script(dir, subPath)))
else if arg.endsWith(".scala") then Right(Seq(SourceScalaFile(dir, subPath)))
else if arg.endsWith(".java") then Right(Seq(JavaFile(dir, subPath)))
else if arg.endsWith(".jar") then Right(Seq(JarFile(dir, subPath)))
else if arg.endsWith(".c") || arg.endsWith(".h") then Right(Seq(CFile(dir, subPath)))
else if arg.endsWith(".md") then Right(Seq(MarkdownFile(dir, subPath)))
else if os.isDir(path) then Right(Seq(Directory(path)))
else if acceptFds && arg.startsWith("/dev/fd/") then
Right(Seq(VirtualScript(content, arg, os.sub / s"input-${idx + 1}.sc")))
else if programInvokeData.subCommand == SubCommand.Shebang && os.exists(path) then
if isShebangScript(String(content)) then Right(Seq(Script(dir, subPath)))
else
Left(if programInvokeData.isShebangCapableShell then
s"""$unrecognizedSourceError,
|to use a script with no file extensions add shebang header pointing to
|'$fullProgramCall' to the top of the file
|""".stripMargin
else unrecognizedSourceError)
else {
val msg =
if os.exists(path) then
programInvokeData match {
case ScalaCliInvokeData(progName, _, _, true)
if isShebangScript(String(content)) =>
s"""$arg: scripts with no file extension should be run with
|'$progName shebang'
|""".stripMargin
case ScalaCliInvokeData(progName, _, _, true) =>
s"""$unrecognizedSourceError,
|if it's meant to be a script add a shebang header pointing to
|'$progName shebang' in the top line
|and run the source with '$progName shebang'
|""".stripMargin
case _ => unrecognizedSourceError
}
else if programInvokeData.subCommand == SubCommand.Default && idx == 0 && arg.forall(
_.isLetterOrDigit
)
then
s"""$arg is not a ${programInvokeData.progName} sub-command and it is not a valid path to an input file or directory
|Try '${programInvokeData.progName} --help' to see the list of available sub-commands and options
|""".stripMargin
else
s"$arg: unrecognized source type (expected .scala or .sc extension, or a directory)"
else s"$arg: not found"
Left(msg)
}
}
s"""$arg: file not found
|Try '$fullProgramCall --help' for usage information
|""".stripMargin
Left(msg)
}
}

private def forNonEmptyArgs(
args: Seq[String],
Expand All @@ -333,11 +367,17 @@ object Inputs {
forcedWorkspace: Option[os.Path],
enableMarkdown: Boolean,
allowRestrictedFeatures: Boolean,
extraClasspathWasPassed: Boolean,
isRunWithShebang: Boolean
): Either[BuildException, Inputs] = {
extraClasspathWasPassed: Boolean
)(using ScalaCliInvokeData): Either[BuildException, Inputs] = {
val validatedArgs: Seq[Either[String, Seq[Element]]] =
validateArgs(args, cwd, download, stdinOpt, acceptFds, enableMarkdown, isRunWithShebang)
validateArgs(
args,
cwd,
download,
stdinOpt,
acceptFds,
enableMarkdown
)
val validatedSnippets: Seq[Either[String, Seq[Element]]] =
validateSnippets(scriptSnippetList, scalaSnippetList, javaSnippetList, markdownSnippetList)
val validatedArgsAndSnippets = validatedArgs ++ validatedSnippets
Expand Down Expand Up @@ -378,9 +418,8 @@ object Inputs {
forcedWorkspace: Option[os.Path] = None,
enableMarkdown: Boolean = false,
allowRestrictedFeatures: Boolean,
extraClasspathWasPassed: Boolean,
isRunWithShebang: Boolean
): Either[BuildException, Inputs] =
extraClasspathWasPassed: Boolean
)(using ScalaCliInvokeData): Either[BuildException, Inputs] =
if (
args.isEmpty && scriptSnippetList.isEmpty && scalaSnippetList.isEmpty && javaSnippetList.isEmpty &&
markdownSnippetList.isEmpty && !extraClasspathWasPassed
Expand All @@ -403,8 +442,7 @@ object Inputs {
forcedWorkspace,
enableMarkdown,
allowRestrictedFeatures,
extraClasspathWasPassed,
isRunWithShebang
extraClasspathWasPassed
)

def default(): Option[Inputs] = None
Expand Down
@@ -0,0 +1,25 @@
package scala.build.input

/** Stores information about how the program has been evoked
*
* @param progName
* the actual Scala CLI program name which was run
* @param subCommandName
* the name of the sub-command that was invoked by user
* @param subCommand
* the type of the sub-command that was invoked by user
* @param isShebangCapableShell
* does the host shell support shebang headers
*/

case class ScalaCliInvokeData(
progName: String,
subCommandName: String,
subCommand: SubCommand,
isShebangCapableShell: Boolean
)

enum SubCommand:
case Default extends SubCommand
case Shebang extends SubCommand
case Other extends SubCommand
Expand Up @@ -5,7 +5,7 @@ import scala.build.blooprifle.BloopRifleConfig
import scala.build.{Build, BuildThreads, Directories}
import scala.build.compiler.{BloopCompilerMaker, SimpleScalaCompilerMaker}
import scala.build.errors.BuildException
import scala.build.input.Inputs
import scala.build.input.{Inputs, ScalaCliInvokeData, SubCommand}
import scala.build.internal.Util
import scala.build.options.BuildOptions
import scala.util.control.NonFatal
Expand Down Expand Up @@ -44,9 +44,8 @@ final case class TestInputs(
tmpDir,
forcedWorkspace = forcedWorkspaceOpt.map(_.resolveFrom(tmpDir)),
allowRestrictedFeatures = true,
extraClasspathWasPassed = false,
isRunWithShebang = false
)
extraClasspathWasPassed = false
)(using ScalaCliInvokeData("", "", SubCommand.Other, false))
res match {
case Left(err) => throw new Exception(err)
case Right(inputs) => f(tmpDir, inputs)
Expand Down
12 changes: 12 additions & 0 deletions modules/cli/src/main/scala/scala/cli/commands/ScalaCommand.scala
Expand Up @@ -14,13 +14,15 @@ import scala.annotation.tailrec
import scala.build.EitherCps.{either, value}
import scala.build.compiler.SimpleScalaCompiler
import scala.build.errors.BuildException
import scala.build.input.{ScalaCliInvokeData, SubCommand}
import scala.build.internal.{Constants, Runner}
import scala.build.options.{BuildOptions, ScalacOpt, Scope}
import scala.build.{Artifacts, Logger, Positioned, ReplArtifacts}
import scala.cli.commands.default.LegacyScalaOptions
import scala.cli.commands.shared.{HasLoggingOptions, ScalaCliHelp, ScalacOptions, SharedOptions}
import scala.cli.commands.util.CommandHelpers
import scala.cli.commands.util.ScalacOptionsUtil.*
import scala.cli.internal.ProcUtil
import scala.cli.{CurrentParams, ScalaCli}
import scala.util.{Properties, Try}

Expand Down Expand Up @@ -69,6 +71,16 @@ abstract class ScalaCommand[T <: HasLoggingOptions](implicit myParser: Parser[T]
protected def actualFullCommand: String =
if actualCommandName.nonEmpty then s"$progName $actualCommandName" else progName

protected def invokeData: ScalaCliInvokeData =
ScalaCliInvokeData(
progName,
actualCommandName,
SubCommand.Other,
ProcUtil.isShebangCapableShell
)

given ScalaCliInvokeData = invokeData

override def error(message: Error): Nothing = {
System.err.println(
s"""${message.message}
Expand Down
Expand Up @@ -21,8 +21,7 @@ object Clean extends ScalaCommand[CleanOptions] {
defaultInputs = () => Inputs.default(),
forcedWorkspace = options.workspace.forcedWorkspaceOpt,
allowRestrictedFeatures = ScalaCli.allowRestrictedFeatures,
extraClasspathWasPassed = false,
isRunWithShebang = false
extraClasspathWasPassed = false
) match {
case Left(message) =>
System.err.println(message)
Expand Down
Expand Up @@ -4,6 +4,7 @@ import caseapp.core.help.{Help, HelpCompanion, RuntimeCommandsHelp}
import caseapp.core.{Error, RemainingArgs}

import scala.build.Logger
import scala.build.input.{Inputs, ScalaCliInvokeData, SubCommand}
import scala.build.internal.Constants
import scala.build.options.BuildOptions
import scala.cli.CurrentParams
Expand Down Expand Up @@ -41,6 +42,9 @@ class Default(

private[cli] var rawArgs = Array.empty[String]

override def invokeData: ScalaCliInvokeData =
super.invokeData.copy(subCommand = SubCommand.Default)

override def runCommand(options: DefaultOptions, args: RemainingArgs, logger: Logger): Unit =
// can't fully re-parse and redirect to Version because of --cli-version and --scala-version clashing
if options.version then Version.runCommand(VersionOptions(options.shared.logging), args, logger)
Expand All @@ -55,5 +59,13 @@ class Default(
}.parse(options.legacyScala.filterNonDeprecatedArgs(rawArgs, progName, logger)) match
case Left(e) => error(e)
case Right((replOptions: ReplOptions, _)) => Repl.runCommand(replOptions, args, logger)
case Right((runOptions: RunOptions, _)) => Run.runCommand(runOptions, args, logger)
case Right((runOptions: RunOptions, _)) =>
Run.runCommand(
runOptions,
args.remaining,
args.unparsed,
() => Inputs.default(),
logger,
invokeData
)
}
12 changes: 7 additions & 5 deletions modules/cli/src/main/scala/scala/cli/commands/run/Run.scala
Expand Up @@ -10,7 +10,7 @@ import java.util.concurrent.CompletableFuture
import scala.build.EitherCps.{either, value}
import scala.build.*
import scala.build.errors.BuildException
import scala.build.input.Inputs
import scala.build.input.{Inputs, ScalaCliInvokeData}
import scala.build.internal.{Constants, Runner, ScalaJsLinkerConfig}
import scala.build.options.{BuildOptions, JavaOpt, Platform, ScalacOpt}
import scala.cli.CurrentParams
Expand Down Expand Up @@ -55,7 +55,8 @@ object Run extends ScalaCommand[RunOptions] with BuildCommandHelpers {
args.remaining,
args.unparsed,
() => Inputs.default(),
logger
logger,
invokeData
)

override def buildOptions(options: RunOptions): Some[BuildOptions] = Some {
Expand Down Expand Up @@ -107,14 +108,15 @@ object Run extends ScalaCommand[RunOptions] with BuildCommandHelpers {
programArgs: Seq[String],
defaultInputs: () => Option[Inputs],
logger: Logger,
isRunWithShebang: Boolean = false
invokeData: ScalaCliInvokeData
): Unit = {
val initialBuildOptions = buildOptionsOrExit(options)

val inputs = options.shared.inputs(
inputArgs,
defaultInputs = defaultInputs,
isRunWithShebang
defaultInputs
)(
using invokeData
).orExit(logger)
CurrentParams.workspaceOpt = Some(inputs.workspace)
val threads = BuildThreads.create()
Expand Down

0 comments on commit 816f91b

Please sign in to comment.