Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-20393][Webu UI] Strengthen Spark to prevent XSS vulnerabilities #17686

Closed
wants to merge 10 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ import org.apache.spark.ui.{UIUtils, WebUIPage}
private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {

def render(request: HttpServletRequest): Seq[Node] = {
// stripXSS is called first to remove suspicious characters used in XSS attacks
val requestedIncomplete =
Option(request.getParameter("showIncomplete")).getOrElse("false").toBoolean
Option(UIUtils.stripXSS(request.getParameter("showIncomplete"))).getOrElse("false").toBoolean

val allAppsSize = parent.getApplicationList().count(_.completed != requestedIncomplete)
val eventLogsUnderProcessCount = parent.getEventLogsUnderProcess()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")

/** Executor details for a particular application */
def render(request: HttpServletRequest): Seq[Node] = {
val appId = request.getParameter("appId")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val appId = UIUtils.stripXSS(request.getParameter("appId"))
val state = master.askSync[MasterStateResponse](RequestMasterState)
val app = state.activeApps.find(_.id == appId)
.getOrElse(state.completedApps.find(_.id == appId).orNull)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,10 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
private def handleKillRequest(request: HttpServletRequest, action: String => Unit): Unit = {
if (parent.killEnabled &&
parent.master.securityMgr.checkModifyPermissions(request.getRemoteUser)) {
val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean
val id = Option(request.getParameter("id"))
// stripXSS is called first to remove suspicious characters used in XSS attacks
val killFlag =
Option(UIUtils.stripXSS(request.getParameter("terminate"))).getOrElse("false").toBoolean
val id = Option(UIUtils.stripXSS(request.getParameter("id")))
if (id.isDefined && killFlag) {
action(id.get)
}
Expand Down
30 changes: 18 additions & 12 deletions core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,16 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with
private val supportedLogTypes = Set("stderr", "stdout")
private val defaultBytes = 100 * 1024

// stripXSS is called first to remove suspicious characters used in XSS attacks
def renderLog(request: HttpServletRequest): String = {
val appId = Option(request.getParameter("appId"))
val executorId = Option(request.getParameter("executorId"))
val driverId = Option(request.getParameter("driverId"))
val logType = request.getParameter("logType")
val offset = Option(request.getParameter("offset")).map(_.toLong)
val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes)
val appId = Option(UIUtils.stripXSS(request.getParameter("appId")))
val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId")))
val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId")))
val logType = UIUtils.stripXSS(request.getParameter("logType"))
val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong)
val byteLength =
Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt)
.getOrElse(defaultBytes)

val logDir = (appId, executorId, driverId) match {
case (Some(a), Some(e), None) =>
Expand All @@ -55,13 +58,16 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with
pre + logText
}

// stripXSS is called first to remove suspicious characters used in XSS attacks
def render(request: HttpServletRequest): Seq[Node] = {
val appId = Option(request.getParameter("appId"))
val executorId = Option(request.getParameter("executorId"))
val driverId = Option(request.getParameter("driverId"))
val logType = request.getParameter("logType")
val offset = Option(request.getParameter("offset")).map(_.toLong)
val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes)
val appId = Option(UIUtils.stripXSS(request.getParameter("appId")))
val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId")))
val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId")))
val logType = UIUtils.stripXSS(request.getParameter("logType"))
val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong)
val byteLength =
Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt)
.getOrElse(defaultBytes)

val (logDir, params, pageName) = (appId, executorId, driverId) match {
case (Some(a), Some(e), None) =>
Expand Down
21 changes: 21 additions & 0 deletions core/src/main/scala/org/apache/spark/ui/UIUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import scala.util.control.NonFatal
import scala.xml._
import scala.xml.transform.{RewriteRule, RuleTransformer}

import org.apache.commons.lang3.StringEscapeUtils

import org.apache.spark.internal.Logging
import org.apache.spark.ui.scope.RDDOperationGraph

Expand All @@ -34,6 +36,8 @@ private[spark] object UIUtils extends Logging {
val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped"
val TABLE_CLASS_STRIPED_SORTABLE = TABLE_CLASS_STRIPED + " sortable"

private val NEWLINE_AND_SINGLE_QUOTE_REGEX = raw"(?i)(\r\n|\n|\r|%0D%0A|%0A|%0D|'|%27)".r

// SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use.
private val dateFormat = new ThreadLocal[SimpleDateFormat]() {
override def initialValue(): SimpleDateFormat =
Expand Down Expand Up @@ -527,4 +531,21 @@ private[spark] object UIUtils extends Logging {
origHref
}
}

/**
* Remove suspicious characters of user input to prevent Cross-Site scripting (XSS) attacks
*
* For more information about XSS testing:
* https://www.owasp.org/index.php/XSS_Filter_Evasion_Cheat_Sheet and
* https://www.owasp.org/index.php/Testing_for_Reflected_Cross_site_scripting_(OTG-INPVAL-001)
*/
def stripXSS(requestParameter: String): String = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you would, add a couple brief tests of this to UIUtilsSuite

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll work on some.

if (requestParameter == null) {
null
} else {
// Remove new lines and single quotes, followed by escaping HTML version 4.0
StringEscapeUtils.escapeHtml4(
NEWLINE_AND_SINGLE_QUOTE_REGEX.replaceAllIn(requestParameter, ""))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage

private val sc = parent.sc

// stripXSS is called first to remove suspicious characters used in XSS attacks
def render(request: HttpServletRequest): Seq[Node] = {
val executorId = Option(request.getParameter("executorId")).map { executorId =>
val executorId =
Option(UIUtils.stripXSS(request.getParameter("executorId"))).map { executorId =>
UIUtils.decodeURLParameter(executorId)
}.getOrElse {
throw new IllegalArgumentException(s"Missing executorId parameter")
Expand Down
14 changes: 8 additions & 6 deletions core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -220,18 +220,20 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
jobTag: String,
jobs: Seq[JobUIData],
killEnabled: Boolean): Seq[Node] = {
val allParameters = request.getParameterMap.asScala.toMap
// stripXSS is called to remove suspicious characters used in XSS attacks
val allParameters = request.getParameterMap.asScala.toMap.mapValues(_.map(UIUtils.stripXSS))
val parameterOtherTable = allParameters.filterNot(_._1.startsWith(jobTag))
.map(para => para._1 + "=" + para._2(0))

val someJobHasJobGroup = jobs.exists(_.jobGroup.isDefined)
val jobIdTitle = if (someJobHasJobGroup) "Job Id (Job Group)" else "Job Id"

val parameterJobPage = request.getParameter(jobTag + ".page")
val parameterJobSortColumn = request.getParameter(jobTag + ".sort")
val parameterJobSortDesc = request.getParameter(jobTag + ".desc")
val parameterJobPageSize = request.getParameter(jobTag + ".pageSize")
val parameterJobPrevPageSize = request.getParameter(jobTag + ".prevPageSize")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val parameterJobPage = UIUtils.stripXSS(request.getParameter(jobTag + ".page"))
val parameterJobSortColumn = UIUtils.stripXSS(request.getParameter(jobTag + ".sort"))
val parameterJobSortDesc = UIUtils.stripXSS(request.getParameter(jobTag + ".desc"))
val parameterJobPageSize = UIUtils.stripXSS(request.getParameter(jobTag + ".pageSize"))
val parameterJobPrevPageSize = UIUtils.stripXSS(request.getParameter(jobTag + ".prevPageSize"))

val jobPage = Option(parameterJobPage).map(_.toInt).getOrElse(1)
val jobSortColumn = Option(parameterJobSortColumn).map { sortColumn =>
Expand Down
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") {
val listener = parent.jobProgresslistener

listener.synchronized {
val parameterId = request.getParameter("id")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")

val jobId = parameterId.toInt
Expand Down
5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.ui.jobs
import javax.servlet.http.HttpServletRequest

import org.apache.spark.scheduler.SchedulingMode
import org.apache.spark.ui.{SparkUI, SparkUITab}
import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils}

/** Web UI showing progress status of all jobs in the given SparkContext. */
private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") {
Expand All @@ -40,7 +40,8 @@ private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") {

def handleKillRequest(request: HttpServletRequest): Unit = {
if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) {
val jobId = Option(request.getParameter("id")).map(_.toInt)
// stripXSS is called first to remove suspicious characters used in XSS attacks
val jobId = Option(UIUtils.stripXSS(request.getParameter("id"))).map(_.toInt)
jobId.foreach { id =>
if (jobProgresslistener.activeJobs.contains(id)) {
sc.foreach(_.cancelJob(id))
Expand Down
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") {

def render(request: HttpServletRequest): Seq[Node] = {
listener.synchronized {
val poolName = Option(request.getParameter("poolname")).map { poolname =>
// stripXSS is called first to remove suspicious characters used in XSS attacks
val poolName = Option(UIUtils.stripXSS(request.getParameter("poolname"))).map { poolname =>
UIUtils.decodeURLParameter(poolname)
}.getOrElse {
throw new IllegalArgumentException(s"Missing poolname parameter")
Expand Down
15 changes: 8 additions & 7 deletions core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,18 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {

def render(request: HttpServletRequest): Seq[Node] = {
progressListener.synchronized {
val parameterId = request.getParameter("id")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")

val parameterAttempt = request.getParameter("attempt")
val parameterAttempt = UIUtils.stripXSS(request.getParameter("attempt"))
require(parameterAttempt != null && parameterAttempt.nonEmpty, "Missing attempt parameter")

val parameterTaskPage = request.getParameter("task.page")
val parameterTaskSortColumn = request.getParameter("task.sort")
val parameterTaskSortDesc = request.getParameter("task.desc")
val parameterTaskPageSize = request.getParameter("task.pageSize")
val parameterTaskPrevPageSize = request.getParameter("task.prevPageSize")
val parameterTaskPage = UIUtils.stripXSS(request.getParameter("task.page"))
val parameterTaskSortColumn = UIUtils.stripXSS(request.getParameter("task.sort"))
val parameterTaskSortDesc = UIUtils.stripXSS(request.getParameter("task.desc"))
val parameterTaskPageSize = UIUtils.stripXSS(request.getParameter("task.pageSize"))
val parameterTaskPrevPageSize = UIUtils.stripXSS(request.getParameter("task.prevPageSize"))

val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1)
val taskSortColumn = Option(parameterTaskSortColumn).map { sortColumn =>
Expand Down
15 changes: 8 additions & 7 deletions core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,17 @@ private[ui] class StageTableBase(
isFairScheduler: Boolean,
killEnabled: Boolean,
isFailedStage: Boolean) {
val allParameters = request.getParameterMap().asScala.toMap
// stripXSS is called to remove suspicious characters used in XSS attacks
val allParameters = request.getParameterMap.asScala.toMap.mapValues(_.map(UIUtils.stripXSS))
val parameterOtherTable = allParameters.filterNot(_._1.startsWith(stageTag))
.map(para => para._1 + "=" + para._2(0))

val parameterStagePage = request.getParameter(stageTag + ".page")
val parameterStageSortColumn = request.getParameter(stageTag + ".sort")
val parameterStageSortDesc = request.getParameter(stageTag + ".desc")
val parameterStagePageSize = request.getParameter(stageTag + ".pageSize")
val parameterStagePrevPageSize = request.getParameter(stageTag + ".prevPageSize")
val parameterStagePage = UIUtils.stripXSS(request.getParameter(stageTag + ".page"))
val parameterStageSortColumn = UIUtils.stripXSS(request.getParameter(stageTag + ".sort"))
val parameterStageSortDesc = UIUtils.stripXSS(request.getParameter(stageTag + ".desc"))
val parameterStagePageSize = UIUtils.stripXSS(request.getParameter(stageTag + ".pageSize"))
val parameterStagePrevPageSize =
UIUtils.stripXSS(request.getParameter(stageTag + ".prevPageSize"))

val stagePage = Option(parameterStagePage).map(_.toInt).getOrElse(1)
val stageSortColumn = Option(parameterStageSortColumn).map { sortColumn =>
Expand Down Expand Up @@ -512,4 +514,3 @@ private[ui] class StageDataSource(
}
}
}

5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.ui.jobs
import javax.servlet.http.HttpServletRequest

import org.apache.spark.scheduler.SchedulingMode
import org.apache.spark.ui.{SparkUI, SparkUITab}
import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils}

/** Web UI showing progress status of all stages in the given SparkContext. */
private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages") {
Expand All @@ -39,7 +39,8 @@ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages"

def handleKillRequest(request: HttpServletRequest): Unit = {
if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) {
val stageId = Option(request.getParameter("id")).map(_.toInt)
// stripXSS is called first to remove suspicious characters used in XSS attacks
val stageId = Option(UIUtils.stripXSS(request.getParameter("id"))).map(_.toInt)
stageId.foreach { id =>
if (progressListener.activeStages.contains(id)) {
sc.foreach(_.cancelStage(id, "killed via the Web UI"))
Expand Down
13 changes: 7 additions & 6 deletions core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") {
private val listener = parent.listener

def render(request: HttpServletRequest): Seq[Node] = {
val parameterId = request.getParameter("id")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")

val parameterBlockPage = request.getParameter("block.page")
val parameterBlockSortColumn = request.getParameter("block.sort")
val parameterBlockSortDesc = request.getParameter("block.desc")
val parameterBlockPageSize = request.getParameter("block.pageSize")
val parameterBlockPrevPageSize = request.getParameter("block.prevPageSize")
val parameterBlockPage = UIUtils.stripXSS(request.getParameter("block.page"))
val parameterBlockSortColumn = UIUtils.stripXSS(request.getParameter("block.sort"))
val parameterBlockSortDesc = UIUtils.stripXSS(request.getParameter("block.desc"))
val parameterBlockPageSize = UIUtils.stripXSS(request.getParameter("block.pageSize"))
val parameterBlockPrevPageSize = UIUtils.stripXSS(request.getParameter("block.prevPageSize"))

val blockPage = Option(parameterBlockPage).map(_.toInt).getOrElse(1)
val blockSortColumn = Option(parameterBlockSortColumn).getOrElse("Block Name")
Expand Down
39 changes: 39 additions & 0 deletions core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,45 @@ class UIUtilsSuite extends SparkFunSuite {
assert(decoded2 === decodeURLParameter(decoded2))
}

test("SPARK-20393: Prevent newline characters in parameters.") {
val encoding = "Encoding:base64%0d%0a%0d%0aPGh0bWw%2bjcmlwdD48L2h0bWw%2b"
val stripEncoding = "Encoding:base64PGh0bWw%2bjcmlwdD48L2h0bWw%2b"

assert(stripEncoding === stripXSS(encoding))
}

test("SPARK-20393: Prevent script from parameters running on page.") {
val scriptAlert = """>"'><script>alert(401)<%2Fscript>"""
val stripScriptAlert = "&gt;&quot;&gt;&lt;script&gt;alert(401)&lt;%2Fscript&gt;"

assert(stripScriptAlert === stripXSS(scriptAlert))
}

test("SPARK-20393: Prevent javascript from parameters running on page.") {
val javascriptAlert =
"""app-20161208133404-0002<iframe+src%3Djavascript%3Aalert(1705)>"""
val stripJavascriptAlert =
"app-20161208133404-0002&lt;iframe+src%3Djavascript%3Aalert(1705)&gt;"

assert(stripJavascriptAlert === stripXSS(javascriptAlert))
}

test("SPARK-20393: Prevent links from parameters on page.") {
val link =
"""stdout'"><iframe+id%3D1131+src%3Dhttp%3A%2F%2Fdemo.test.net%2Fphishing.html>"""
val stripLink =
"stdout&quot;&gt;&lt;iframe+id%3D1131+src%3Dhttp%3A%2F%2Fdemo.test.net%2Fphishing.html&gt;"

assert(stripLink === stripXSS(link))
}

test("SPARK-20393: Prevent popups from parameters on page.") {
val popup = """stdout'%2Balert(60)%2B'"""
val stripPopup = "stdout%2Balert(60)%2B"

assert(stripPopup === stripXSS(popup))
}

private def verify(
desc: String,
expected: Node,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ import org.apache.spark.ui.{UIUtils, WebUIPage}
private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") {

override def render(request: HttpServletRequest): Seq[Node] = {
val driverId = request.getParameter("id")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val driverId = UIUtils.stripXSS(request.getParameter("id"))
require(driverId != null && driverId.nonEmpty, "Missing id parameter")

val state = parent.scheduler.getDriverState(driverId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging
private val listener = parent.listener

override def render(request: HttpServletRequest): Seq[Node] = listener.synchronized {
val parameterExecutionId = request.getParameter("id")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val parameterExecutionId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterExecutionId != null && parameterExecutionId.nonEmpty,
"Missing execution id parameter")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab)

/** Render the page */
def render(request: HttpServletRequest): Seq[Node] = {
val parameterId = request.getParameter("id")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")

val content =
Expand Down Expand Up @@ -197,4 +198,3 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab)
UIUtils.listingTable(headers, generateDataRow, data, fixedWidth = true)
}
}