diff --git a/LICENSE-binary b/LICENSE-binary index b6971798e5577..6a8dc8ca3dd96 100644 --- a/LICENSE-binary +++ b/LICENSE-binary @@ -267,6 +267,7 @@ io.fabric8:kubernetes-model-scheduling io.fabric8:kubernetes-model-storageclass io.fabric8:zjsonpatch io.github.java-diff-utils:java-diff-utils +io.jsonwebtoken:jjwt-api io.netty:netty-all io.netty:netty-buffer io.netty:netty-codec diff --git a/core/pom.xml b/core/pom.xml index adb1b3034b427..ff7fa04c8dbdd 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -118,6 +118,23 @@ org.apache.zookeeper zookeeper + + io.jsonwebtoken + jjwt-api + 0.12.6 + + + io.jsonwebtoken + jjwt-impl + 0.12.6 + test + + + io.jsonwebtoken + jjwt-jackson + 0.12.6 + test + diff --git a/core/src/main/scala/org/apache/spark/ui/JWSFilter.scala b/core/src/main/scala/org/apache/spark/ui/JWSFilter.scala new file mode 100644 index 0000000000000..e942bce366aa1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/JWSFilter.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import javax.crypto.SecretKey + +import io.jsonwebtoken.{JwtException, Jwts} +import io.jsonwebtoken.io.Decoders +import io.jsonwebtoken.security.Keys +import jakarta.servlet.{Filter, FilterChain, FilterConfig, ServletRequest, ServletResponse} +import jakarta.servlet.http.{HttpServletRequest, HttpServletResponse} + +/** + * A servlet filter that requires JWS, a cryptographically signed JSON Web Token, in the header. + * + * Like the other UI filters, the following configurations are required to use this filter. + * {{{ + * - spark.ui.filters=org.apache.spark.ui.JWSFilter + * - spark.org.apache.spark.ui.JWSFilter.param.key=BASE64URL-ENCODED-YOUR-PROVIDED-KEY + * }}} + * The HTTP request should have {@code Authorization: Bearer } header. + * {{{ + * - is a string with three fields, '
..'. + * -
is supposed to be a base64url-encoded string of '{"alg":"HS256","typ":"JWT"}'. + * - is a base64url-encoded string of fully-user-defined content. + * - is a signature based on '
.' and a user-provided key parameter. + * }}} + */ +private class JWSFilter extends Filter { + private val AUTHORIZATION = "Authorization" + + private var key: SecretKey = null + + /** + * Load and validate the configurtions: + * - IllegalArgumentException will happen if the user didn't provide this argument + * - WeakKeyException will happen if the user-provided value is insufficient + */ + override def init(config: FilterConfig): Unit = { + key = Keys.hmacShaKeyFor(Decoders.BASE64URL.decode(config.getInitParameter("key"))); + } + + override def doFilter(req: ServletRequest, res: ServletResponse, chain: FilterChain): Unit = { + val hreq = req.asInstanceOf[HttpServletRequest] + val hres = res.asInstanceOf[HttpServletResponse] + hres.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + + try { + val header = hreq.getHeader(AUTHORIZATION) + header match { + case null => + hres.sendError(HttpServletResponse.SC_FORBIDDEN, s"${AUTHORIZATION} header is missing.") + case s"Bearer $token" => + val claims = Jwts.parser().verifyWith(key).build().parseSignedClaims(token) + chain.doFilter(req, res) + case _ => + hres.sendError(HttpServletResponse.SC_FORBIDDEN, s"Malformed ${AUTHORIZATION} header.") + } + } catch { + case e: JwtException => + // We intentionally don't expose the detail of JwtException here + hres.sendError(HttpServletResponse.SC_FORBIDDEN, "JWT Validate Fail") + } + } +} diff --git a/core/src/test/scala/org/apache/spark/ui/JWSFilterSuite.scala b/core/src/test/scala/org/apache/spark/ui/JWSFilterSuite.scala new file mode 100644 index 0000000000000..a094a5e48e992 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ui/JWSFilterSuite.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import java.util.{Base64, HashMap => JHashMap} + +import scala.jdk.CollectionConverters._ + +import jakarta.servlet.{FilterChain, FilterConfig, ServletContext} +import jakarta.servlet.http.{HttpServletRequest, HttpServletResponse} +import org.mockito.ArgumentMatchers.{any, eq => meq} +import org.mockito.Mockito.{mock, times, verify, when} + +import org.apache.spark._ + +class JWSFilterSuite extends SparkFunSuite { + // {"alg":"HS256","typ":"JWT"} => eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9, {} => e30 + private val TOKEN = + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.4EKWlOkobpaAPR0J4BE0cPQ-ZD1tRQKLZp1vtE7upPw" + + private val TEST_KEY = Base64.getUrlEncoder.encodeToString( + "Visit https://spark.apache.org to download Apache Spark.".getBytes()) + + test("Should fail when a parameter is missing") { + val filter = new JWSFilter() + val params = new JHashMap[String, String] + val m = intercept[IllegalArgumentException] { + filter.init(new DummyFilterConfig(params)) + }.getMessage() + assert(m.contains("Decode argument cannot be null")) + } + + test("Succeed to initialize") { + val filter = new JWSFilter() + val params = new JHashMap[String, String] + params.put("key", TEST_KEY) + filter.init(new DummyFilterConfig(params)) + } + + test("Should response with SC_FORBIDDEN when it cannot verify JWS") { + val req = mockRequest() + val res = mock(classOf[HttpServletResponse]) + val chain = mock(classOf[FilterChain]) + + val filter = new JWSFilter() + val params = new JHashMap[String, String] + params.put("key", TEST_KEY) + val conf = new DummyFilterConfig(params) + filter.init(conf) + + // 'Authorization' header is missing + filter.doFilter(req, res, chain) + verify(res).sendError(meq(HttpServletResponse.SC_FORBIDDEN), + meq("Authorization header is missing.")) + verify(chain, times(0)).doFilter(any(), any()) + + // The value of Authorization field is not 'Bearer ' style. + when(req.getHeader("Authorization")).thenReturn("Invalid") + filter.doFilter(req, res, chain) + verify(res).sendError(meq(HttpServletResponse.SC_FORBIDDEN), + meq("Malformed Authorization header.")) + verify(chain, times(0)).doFilter(any(), any()) + } + + test("Should succeed on valid JWS") { + val req = mockRequest() + val res = mock(classOf[HttpServletResponse]) + val chain = mock(classOf[FilterChain]) + + val filter = new JWSFilter() + val params = new JHashMap[String, String] + params.put("key", TEST_KEY) + val conf = new DummyFilterConfig(params) + filter.init(conf) + + when(req.getHeader("Authorization")).thenReturn(s"Bearer $TOKEN") + filter.doFilter(req, res, chain) + verify(chain, times(1)).doFilter(any(), any()) + } + + private def mockRequest(params: Map[String, Array[String]] = Map()): HttpServletRequest = { + val req = mock(classOf[HttpServletRequest]) + when(req.getParameterMap()).thenReturn(params.asJava) + req + } + + class DummyFilterConfig (val map: java.util.Map[String, String]) extends FilterConfig { + override def getFilterName: String = "dummy" + + override def getInitParameter(arg0: String): String = map.get(arg0) + + override def getInitParameterNames: java.util.Enumeration[String] = + java.util.Collections.enumeration(map.keySet) + + override def getServletContext: ServletContext = null + } +} diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 53e2086a62a4d..492733b6b9950 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -138,6 +138,7 @@ jersey-server/3.0.12//jersey-server-3.0.12.jar jettison/1.5.4//jettison-1.5.4.jar jetty-util-ajax/11.0.21//jetty-util-ajax-11.0.21.jar jetty-util/11.0.21//jetty-util-11.0.21.jar +jjwt-api/0.12.6//jjwt-api-0.12.6.jar jline/2.14.6//jline-2.14.6.jar jline/3.25.1//jline-3.25.1.jar jna/5.14.0//jna-5.14.0.jar