diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala
index 8e534828e7778..7efab73726ef8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala
@@ -17,12 +17,15 @@
 
 package org.apache.spark.deploy.rest
 
+import java.util.EnumSet
+
 import scala.io.Source
 
 import com.fasterxml.jackson.core.JsonProcessingException
+import jakarta.servlet.DispatcherType
 import jakarta.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
 import org.eclipse.jetty.server.{HttpConnectionFactory, Server, ServerConnector}
-import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder}
+import org.eclipse.jetty.servlet.{FilterHolder, ServletContextHandler, ServletHolder}
 import org.eclipse.jetty.util.thread.{QueuedThreadPool, ScheduledExecutorScheduler}
 import org.json4s._
 import org.json4s.jackson.JsonMethods._
@@ -30,6 +33,7 @@ import org.json4s.jackson.JsonMethods._
 import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf}
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys._
+import org.apache.spark.internal.config.MASTER_REST_SERVER_FILTERS
 import org.apache.spark.util.Utils
 
 /**
@@ -111,12 +115,26 @@ private[spark] abstract class RestSubmissionServer(
     contextToServlet.foreach { case (prefix, servlet) =>
       mainHandler.addServlet(new ServletHolder(servlet), prefix)
     }
+    addFilters(mainHandler)
     server.setHandler(mainHandler)
     server.start()
     val boundPort = connector.getLocalPort
     (server, boundPort)
   }
 
+  /**
+   * Add filters, if any, to the given ServletContextHandler.
+   */
+  private def addFilters(handler: ServletContextHandler): Unit = {
+    masterConf.get(MASTER_REST_SERVER_FILTERS).foreach { filter =>
+      val params = masterConf.getAllWithPrefix(s"spark.$filter.param.").toMap
+      val holder = new FilterHolder()
+      holder.setClassName(filter)
+      params.foreach { case (k, v) => holder.setInitParameter(k, v) }
+      handler.addFilter(holder, "/*", EnumSet.allOf(classOf[DispatcherType]))
+    }
+  }
+
   def stop(): Unit = {
     _server.foreach(_.stop())
   }
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 6de500024816f..c40645065bb5d 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -1964,6 +1964,13 @@ package object config {
     .intConf
     .createWithDefault(6066)
 
+  private[spark] val MASTER_REST_SERVER_FILTERS = ConfigBuilder("spark.master.rest.filters")
+    .doc("Comma-separated list of filter class names to apply to the Spark Master REST API.")
+    .version("4.0.0")
+    .stringConf
+    .toSequence
+    .createWithDefault(Nil)
+
   private[spark] val MASTER_UI_PORT = ConfigBuilder("spark.master.ui.port")
     .version("1.1.0")
     .intConf
diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
index 15ef2fa7e6d00..1d17ae45f3eea 100644
--- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.deploy.rest
 import java.io.DataOutputStream
 import java.net.{HttpURLConnection, URL}
 import java.nio.charset.StandardCharsets
+import java.util.Base64
 
 import scala.collection.mutable
 
@@ -32,6 +33,7 @@ import org.apache.spark.deploy.{SparkSubmit, SparkSubmitArguments}
 import org.apache.spark.deploy.DeployMessages._
 import org.apache.spark.deploy.master.DriverState._
 import org.apache.spark.deploy.master.RecoveryState
+import org.apache.spark.internal.config.MASTER_REST_SERVER_FILTERS
 import org.apache.spark.rpc._
 import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.Utils
@@ -481,6 +483,55 @@ class StandaloneRestSubmitSuite extends SparkFunSuite {
     assert(desc.command.javaOpts.exists(_.startsWith("--add-opens")))
   }
 
+  test("SPARK-49103: `spark.master.rest.filters` loads filters successfully") {
+    val conf = new SparkConf()
+    val localhost = Utils.localHostName()
+    val securityManager = new SecurityManager(conf)
+    rpcEnv = Some(RpcEnv.create("rest-with-filter", localhost, 0, conf, securityManager))
+    val fakeMasterRef = rpcEnv.get.setupEndpoint("fake-master", new DummyMaster(rpcEnv.get))
+
+    // Causes an exception in order to verify that the new configuration loads filters successfully
+    conf.set(MASTER_REST_SERVER_FILTERS.key, "org.apache.spark.ui.JWSFilter")
+    server = Some(new StandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077"))
+    val m = intercept[IllegalArgumentException] {
+      server.get.start()
+    }.getMessage()
+    assert(m.contains("Decode argument cannot be null"))
+  }
+
+  private val TEST_KEY = Base64.getUrlEncoder.encodeToString(
+    "Visit https://spark.apache.org to download Apache Spark.".getBytes())
+
+  test("SPARK-49103: REST server starts successfully with `spark.master.rest.filters`") {
+    val conf = new SparkConf()
+    val localhost = Utils.localHostName()
+    val securityManager = new SecurityManager(conf)
+    rpcEnv = Some(RpcEnv.create("rest-with-filter", localhost, 0, conf, securityManager))
+    val fakeMasterRef = rpcEnv.get.setupEndpoint("fake-master", new DummyMaster(rpcEnv.get))
+    conf.set(MASTER_REST_SERVER_FILTERS.key, "org.apache.spark.ui.JWSFilter")
+    conf.set("spark.org.apache.spark.ui.JWSFilter.param.key", TEST_KEY)
+    server = Some(new StandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077"))
+    server.get.start()
+  }
+
+  test("SPARK-49103: JWSFilter successfully protects REST API via configurations") {
+    val conf = new SparkConf()
+    val localhost = Utils.localHostName()
+    val securityManager = new SecurityManager(conf)
+    rpcEnv = Some(RpcEnv.create("rest-with-filter", localhost, 0, conf, securityManager))
+    val fakeMasterRef = rpcEnv.get.setupEndpoint("fake-master", new DummyMaster(rpcEnv.get))
+    conf.set(MASTER_REST_SERVER_FILTERS.key, "org.apache.spark.ui.JWSFilter")
+    conf.set("spark.org.apache.spark.ui.JWSFilter.param.key", TEST_KEY)
+    server = Some(new StandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077"))
+    val port = server.get.start()
+    val masterUrl = s"spark://$localhost:$port"
+    val json = constructSubmitRequest(masterUrl).toJson
+    val httpUrl = masterUrl.replace("spark://", "http://")
+    val submitRequestPath = s"$httpUrl/${RestSubmissionServer.PROTOCOL_VERSION}/submissions/create"
+    val conn = sendHttpRequest(submitRequestPath, "POST", json)
+    assert(conn.getResponseCode === HttpServletResponse.SC_FORBIDDEN)
+  }
+
   /* --------------------- *
    |    Helper methods     |
    * --------------------- */
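
Note on usage: `addFilters` above lets Jetty instantiate each class listed in `spark.master.rest.filters` by name, and passes every configuration key under `spark.<filter class>.param.` to that filter as an init parameter. As a rough sketch of what a pluggable filter looks like, here is a minimal `jakarta.servlet.Filter`; the class name `com.example.TokenFilter`, its `token` init parameter, and the `X-Token` header are hypothetical illustrations, not part of this change (the tests above exercise the real `org.apache.spark.ui.JWSFilter`).

```scala
package com.example

import jakarta.servlet.{Filter, FilterChain, FilterConfig, ServletRequest, ServletResponse}
import jakarta.servlet.http.{HttpServletRequest, HttpServletResponse}

// Hypothetical filter. It would be wired into the Master REST server with:
//   spark.master.rest.filters=com.example.TokenFilter
//   spark.com.example.TokenFilter.param.token=<secret>
class TokenFilter extends Filter {
  private var token: String = _

  // Init parameters are the keys that RestSubmissionServer.addFilters collects
  // from the `spark.com.example.TokenFilter.param.` prefix.
  override def init(config: FilterConfig): Unit = {
    token = config.getInitParameter("token")
    // Like JWSFilter's missing-key error in the first test, failing here aborts server startup.
    require(token != null, "Init parameter 'token' is required")
  }

  override def doFilter(req: ServletRequest, res: ServletResponse, chain: FilterChain): Unit = {
    val sent = req.asInstanceOf[HttpServletRequest].getHeader("X-Token")
    if (token == sent) {
      chain.doFilter(req, res) // authorized: continue to the REST servlets
    } else {
      // 403, matching the SC_FORBIDDEN assertion in the last test
      res.asInstanceOf[HttpServletResponse].sendError(HttpServletResponse.SC_FORBIDDEN)
    }
  }

  override def destroy(): Unit = {}
}
```

Note that the init-parameter prefix embeds the fully qualified filter class name, which is why the tests set `spark.org.apache.spark.ui.JWSFilter.param.key` rather than a shorter key.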