Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,23 @@

package org.apache.spark.deploy.rest

import java.util.EnumSet

import scala.io.Source

import com.fasterxml.jackson.core.JsonProcessingException
import jakarta.servlet.DispatcherType
import jakarta.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
import org.eclipse.jetty.server.{HttpConnectionFactory, Server, ServerConnector}
import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder}
import org.eclipse.jetty.servlet.{FilterHolder, ServletContextHandler, ServletHolder}
import org.eclipse.jetty.util.thread.{QueuedThreadPool, ScheduledExecutorScheduler}
import org.json4s._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf}
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys._
import org.apache.spark.internal.config.MASTER_REST_SERVER_FILTERS
import org.apache.spark.util.Utils

/**
Expand Down Expand Up @@ -111,12 +115,26 @@ private[spark] abstract class RestSubmissionServer(
contextToServlet.foreach { case (prefix, servlet) =>
mainHandler.addServlet(new ServletHolder(servlet), prefix)
}
addFilters(mainHandler)
server.setHandler(mainHandler)
server.start()
val boundPort = connector.getLocalPort
(server, boundPort)
}

/**
* Add filters, if any, to the given ServletContextHandlers.
*/
private def addFilters(handler: ServletContextHandler): Unit = {
masterConf.get(MASTER_REST_SERVER_FILTERS).foreach { filter =>
val params = masterConf.getAllWithPrefix(s"spark.$filter.param.").toMap
val holder = new FilterHolder()
holder.setClassName(filter)
params.foreach { case (k, v) => holder.setInitParameter(k, v) }
handler.addFilter(holder, "/*", EnumSet.allOf(classOf[DispatcherType]))
}
}

def stop(): Unit = {
_server.foreach(_.stop())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1964,6 +1964,13 @@ package object config {
.intConf
.createWithDefault(6066)

private[spark] val MASTER_REST_SERVER_FILTERS = ConfigBuilder("spark.master.rest.filters")
.doc("Comma separated list of filter class names to apply to the Spark Master REST API.")
.version("4.0.0")
.stringConf
.toSequence
.createWithDefault(Nil)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have any user-facing documentation for this config?


private[spark] val MASTER_UI_PORT = ConfigBuilder("spark.master.ui.port")
.version("1.1.0")
.intConf
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.deploy.rest
import java.io.DataOutputStream
import java.net.{HttpURLConnection, URL}
import java.nio.charset.StandardCharsets
import java.util.Base64

import scala.collection.mutable

Expand All @@ -32,6 +33,7 @@ import org.apache.spark.deploy.{SparkSubmit, SparkSubmitArguments}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.DriverState._
import org.apache.spark.deploy.master.RecoveryState
import org.apache.spark.internal.config.MASTER_REST_SERVER_FILTERS
import org.apache.spark.rpc._
import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.Utils
Expand Down Expand Up @@ -481,6 +483,55 @@ class StandaloneRestSubmitSuite extends SparkFunSuite {
assert(desc.command.javaOpts.exists(_.startsWith("--add-opens")))
}

test("SPARK-49103: `spark.master.rest.filters` loads filters successfully") {
val conf = new SparkConf()
val localhost = Utils.localHostName()
val securityManager = new SecurityManager(conf)
rpcEnv = Some(RpcEnv.create("rest-with-filter", localhost, 0, conf, securityManager))
val fakeMasterRef = rpcEnv.get.setupEndpoint("fake-master", new DummyMaster(rpcEnv.get))

// Causes exceptions in order to verify new configuration loads filters successfully
conf.set(MASTER_REST_SERVER_FILTERS.key, "org.apache.spark.ui.JWSFilter")
server = Some(new StandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077"))
val m = intercept[IllegalArgumentException] {
server.get.start()
}.getMessage()
assert(m.contains("Decode argument cannot be null"))
}

private val TEST_KEY = Base64.getUrlEncoder.encodeToString(
"Visit https://spark.apache.org to download Apache Spark.".getBytes())

test("SPARK-49103: REST server stars successfully with `spark.master.rest.filters`") {
val conf = new SparkConf()
val localhost = Utils.localHostName()
val securityManager = new SecurityManager(conf)
rpcEnv = Some(RpcEnv.create("rest-with-filter", localhost, 0, conf, securityManager))
val fakeMasterRef = rpcEnv.get.setupEndpoint("fake-master", new DummyMaster(rpcEnv.get))
conf.set(MASTER_REST_SERVER_FILTERS.key, "org.apache.spark.ui.JWSFilter")
conf.set("spark.org.apache.spark.ui.JWSFilter.param.key", TEST_KEY)
server = Some(new StandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077"))
server.get.start()
}

test("SPARK-49103: JWSFilter successfully protects REST API via configurations") {
val conf = new SparkConf()
val localhost = Utils.localHostName()
val securityManager = new SecurityManager(conf)
rpcEnv = Some(RpcEnv.create("rest-with-filter", localhost, 0, conf, securityManager))
val fakeMasterRef = rpcEnv.get.setupEndpoint("fake-master", new DummyMaster(rpcEnv.get))
conf.set(MASTER_REST_SERVER_FILTERS.key, "org.apache.spark.ui.JWSFilter")
conf.set("spark.org.apache.spark.ui.JWSFilter.param.key", TEST_KEY)
server = Some(new StandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077"))
val port = server.get.start()
val masterUrl = s"spark://$localhost:$port"
val json = constructSubmitRequest(masterUrl).toJson
val httpUrl = masterUrl.replace("spark://", "http://")
val submitRequestPath = s"$httpUrl/${RestSubmissionServer.PROTOCOL_VERSION}/submissions/create"
val conn = sendHttpRequest(submitRequestPath, "POST", json)
assert(conn.getResponseCode === HttpServletResponse.SC_FORBIDDEN)
}

/* --------------------- *
| Helper methods |
* --------------------- */
Expand Down