Skip to content

Commit

Permalink
[SPARK-19220][UI] Make redirection to HTTPS apply to all URIs. (branc…
Browse files Browse the repository at this point in the history
…h-2.0)

The redirect handler was installed only for the root of the server;
any other context ended up being served directly through the HTTP
port. Since every sub page (e.g. application UIs in the history
server) is a separate servlet context, this meant that everything
but the root was accessible via HTTP still.

The change adds separate names to each connector, and binds contexts
to specific connectors so that content is only served through the
HTTPS connector when it's enabled. In that case, the only thing that
binds to the HTTP connector is the redirect handler.

Tested with new unit tests and by checking a live history server.

(cherry picked from commit 59502bb)
  • Loading branch information
Marcelo Vanzin committed Jan 27, 2017
1 parent 48a8dc8 commit ddd7727
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 53 deletions.
38 changes: 36 additions & 2 deletions core/src/main/scala/org/apache/spark/TestUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,22 @@
package org.apache.spark

import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream}
import java.net.{URI, URL}
import java.net.{HttpURLConnection, URI, URL}
import java.nio.charset.StandardCharsets
import java.nio.file.Paths
import java.security.SecureRandom
import java.security.cert.X509Certificate
import java.util.Arrays
import java.util.concurrent.{CountDownLatch, TimeUnit}
import java.util.jar.{JarEntry, JarOutputStream}
import javax.net.ssl._
import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider}

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import com.google.common.io.{ByteStreams, Files}
import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider}

import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
Expand Down Expand Up @@ -182,6 +185,37 @@ private[spark] object TestUtils {
assert(spillListener.numSpilledStages == 0, s"expected $identifier to not spill, but did")
}

/**
* Returns the response code from an HTTP(S) URL.
*/
def httpResponseCode(url: URL, method: String = "GET"): Int = {
val connection = url.openConnection().asInstanceOf[HttpURLConnection]
connection.setRequestMethod(method)

// Disable cert and host name validation for HTTPS tests.
if (connection.isInstanceOf[HttpsURLConnection]) {
val sslCtx = SSLContext.getInstance("SSL")
val trustManager = new X509TrustManager {
override def getAcceptedIssuers(): Array[X509Certificate] = null
override def checkClientTrusted(x509Certificates: Array[X509Certificate], s: String) {}
override def checkServerTrusted(x509Certificates: Array[X509Certificate], s: String) {}
}
val verifier = new HostnameVerifier() {
override def verify(hostname: String, session: SSLSession): Boolean = true
}
sslCtx.init(null, Array(trustManager), new SecureRandom())
connection.asInstanceOf[HttpsURLConnection].setSSLSocketFactory(sslCtx.getSocketFactory())
connection.asInstanceOf[HttpsURLConnection].setHostnameVerifier(verifier)
}

try {
connection.connect()
connection.getResponseCode()
} finally {
connection.disconnect()
}
}

}


Expand Down
87 changes: 61 additions & 26 deletions core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ import org.apache.spark.util.Utils
*/
private[spark] object JettyUtils extends Logging {

val SPARK_CONNECTOR_NAME = "Spark"
val REDIRECT_CONNECTOR_NAME = "HttpsRedirect"

// Base type for a function that returns something based on an HTTP request. Allows for
// implicit conversion from many types of functions to jetty Handlers.
type Responder[T] = HttpServletRequest => T
Expand Down Expand Up @@ -231,25 +234,28 @@ private[spark] object JettyUtils extends Logging {
conf: SparkConf,
serverName: String = ""): ServerInfo = {

val collection = new ContextHandlerCollection
addFilters(handlers, conf)

val gzipHandlers = handlers.map { h =>
h.setVirtualHosts(Array("@" + SPARK_CONNECTOR_NAME))

val gzipHandler = new GzipHandler
gzipHandler.setHandler(h)
gzipHandler
}

// Bind to the given port, or throw a java.net.BindException if the port is occupied
def connect(currentPort: Int): (Server, Int) = {
def connect(currentPort: Int): ((Server, Option[Int]), Int) = {
val pool = new QueuedThreadPool
if (serverName.nonEmpty) {
pool.setName(serverName)
}
pool.setDaemon(true)

val server = new Server(pool)
val connectors = new ArrayBuffer[ServerConnector]
val connectors = new ArrayBuffer[ServerConnector]()
val collection = new ContextHandlerCollection

// Create a connector on port currentPort to listen for HTTP requests
val httpConnector = new ServerConnector(
server,
Expand All @@ -263,26 +269,33 @@ private[spark] object JettyUtils extends Logging {
httpConnector.setPort(currentPort)
connectors += httpConnector

sslOptions.createJettySslContextFactory().foreach { factory =>
// If the new port wraps around, do not try a privileged port.
val securePort =
if (currentPort != 0) {
(currentPort + 400 - 1024) % (65536 - 1024) + 1024
} else {
0
}
val scheme = "https"
// Create a connector on port securePort to listen for HTTPS requests
val connector = new ServerConnector(server, factory)
connector.setPort(securePort)

connectors += connector

// redirect the HTTP requests to HTTPS port
collection.addHandler(createRedirectHttpsHandler(securePort, scheme))
val httpsConnector = sslOptions.createJettySslContextFactory() match {
case Some(factory) =>
// If the new port wraps around, do not try a privileged port.
val securePort =
if (currentPort != 0) {
(currentPort + 400 - 1024) % (65536 - 1024) + 1024
} else {
0
}
val scheme = "https"
// Create a connector on port securePort to listen for HTTPS requests
val connector = new ServerConnector(server, factory)
connector.setPort(securePort)
connector.setName(SPARK_CONNECTOR_NAME)
connectors += connector

// redirect the HTTP requests to HTTPS port
httpConnector.setName(REDIRECT_CONNECTOR_NAME)
collection.addHandler(createRedirectHttpsHandler(securePort, scheme))
Some(connector)

case None =>
// No SSL, so the HTTP connector becomes the official one where all contexts bind.
httpConnector.setName(SPARK_CONNECTOR_NAME)
None
}

gzipHandlers.foreach(collection.addHandler)
// As each acceptor and each selector will use one thread, the number of threads should at
// least be the number of acceptors and selectors plus 1. (See SPARK-13776)
var minThreads = 1
Expand All @@ -294,17 +307,20 @@ private[spark] object JettyUtils extends Logging {
// The number of selectors always equals to the number of acceptors
minThreads += connector.getAcceptors * 2
}
server.setConnectors(connectors.toArray)
pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads))

val errorHandler = new ErrorHandler()
errorHandler.setShowStacks(true)
errorHandler.setServer(server)
server.addBean(errorHandler)

gzipHandlers.foreach(collection.addHandler)
server.setHandler(collection)

server.setConnectors(connectors.toArray)
try {
server.start()
(server, httpConnector.getLocalPort)
((server, httpsConnector.map(_.getLocalPort())), httpConnector.getLocalPort)
} catch {
case e: Exception =>
server.stop()
Expand All @@ -313,13 +329,16 @@ private[spark] object JettyUtils extends Logging {
}
}

val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, conf, serverName)
ServerInfo(server, boundPort, collection)
val ((server, securePort), boundPort) = Utils.startServiceOnPort(port, connect, conf,
serverName)
ServerInfo(server, boundPort, securePort,
server.getHandler().asInstanceOf[ContextHandlerCollection])
}

private def createRedirectHttpsHandler(securePort: Int, scheme: String): ContextHandler = {
val redirectHandler: ContextHandler = new ContextHandler
redirectHandler.setContextPath("/")
redirectHandler.setVirtualHosts(Array("@" + REDIRECT_CONNECTOR_NAME))
redirectHandler.setHandler(new AbstractHandler {
override def handle(
target: String,
Expand Down Expand Up @@ -357,7 +376,23 @@ private[spark] object JettyUtils extends Logging {
private[spark] case class ServerInfo(
server: Server,
boundPort: Int,
rootHandler: ContextHandlerCollection) {
securePort: Option[Int],
private val rootHandler: ContextHandlerCollection) {

def addHandler(handler: ContextHandler): Unit = {
handler.setVirtualHosts(Array("@" + JettyUtils.SPARK_CONNECTOR_NAME))
rootHandler.addHandler(handler)
if (!handler.isStarted()) {
handler.start()
}
}

def removeHandler(handler: ContextHandler): Unit = {
rootHandler.removeHandler(handler)
if (handler.isStarted) {
handler.stop()
}
}

def stop(): Unit = {
server.stop()
Expand Down
14 changes: 2 additions & 12 deletions core/src/main/scala/org/apache/spark/ui/WebUI.scala
Original file line number Diff line number Diff line change
Expand Up @@ -90,23 +90,13 @@ private[spark] abstract class WebUI(
/** Attach a handler to this UI. */
def attachHandler(handler: ServletContextHandler) {
handlers += handler
serverInfo.foreach { info =>
info.rootHandler.addHandler(handler)
if (!handler.isStarted) {
handler.start()
}
}
serverInfo.foreach(_.addHandler(handler))
}

/** Detach a handler from this UI. */
def detachHandler(handler: ServletContextHandler) {
handlers -= handler
serverInfo.foreach { info =>
info.rootHandler.removeHandler(handler)
if (handler.isStarted) {
handler.stop()
}
}
serverInfo.foreach(_.removeHandler(handler))
}

/**
Expand Down
13 changes: 2 additions & 11 deletions core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -452,23 +452,14 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B
}

test("kill stage POST/GET response is correct") {
def getResponseCode(url: URL, method: String): Int = {
val connection = url.openConnection().asInstanceOf[HttpURLConnection]
connection.setRequestMethod(method)
connection.connect()
val code = connection.getResponseCode()
connection.disconnect()
code
}

withSpark(newSparkContext(killEnabled = true)) { sc =>
sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync()
eventually(timeout(5 seconds), interval(50 milliseconds)) {
val url = new URL(
sc.ui.get.appUIAddress.stripSuffix("/") + "/stages/stage/kill/?id=0&terminate=true")
// SPARK-6846: should be POST only but YARN AM doesn't proxy POST
getResponseCode(url, "GET") should be (200)
getResponseCode(url, "POST") should be (200)
TestUtils.httpResponseCode(url, "GET") should be (200)
TestUtils.httpResponseCode(url, "POST") should be (200)
}
}
}
Expand Down
55 changes: 53 additions & 2 deletions core/src/test/scala/org/apache/spark/ui/UISuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@
package org.apache.spark.ui

import java.net.{BindException, ServerSocket}
import java.net.URL
import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}

import scala.io.Source

import org.eclipse.jetty.servlet.ServletContextHandler
import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder}
import org.mockito.Mockito.{mock, when}
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._

Expand Down Expand Up @@ -163,6 +166,7 @@ class UISuite extends SparkFunSuite {
val boundPort = serverInfo.boundPort
assert(server.getState === "STARTED")
assert(boundPort != 0)
assert(serverInfo.securePort.isDefined)
intercept[BindException] {
socket = new ServerSocket(boundPort)
}
Expand Down Expand Up @@ -190,8 +194,55 @@ class UISuite extends SparkFunSuite {
}
}

test("http -> https redirect applies to all URIs") {
var serverInfo: ServerInfo = null
try {
val servlet = new HttpServlet() {
override def doGet(req: HttpServletRequest, res: HttpServletResponse): Unit = {
res.sendError(HttpServletResponse.SC_OK)
}
}

def newContext(path: String): ServletContextHandler = {
val ctx = new ServletContextHandler()
ctx.setContextPath(path)
ctx.addServlet(new ServletHolder(servlet), "/root")
ctx
}

val (conf, sslOptions) = sslEnabledConf()
serverInfo = JettyUtils.startJettyServer("0.0.0.0", 0, sslOptions,
Seq[ServletContextHandler](newContext("/"), newContext("/test1")),
conf)
assert(serverInfo.server.getState === "STARTED")

val testContext = newContext("/test2")
serverInfo.addHandler(testContext)
testContext.start()

val httpPort = serverInfo.boundPort

val tests = Seq(
("http", serverInfo.boundPort, HttpServletResponse.SC_FOUND),
("https", serverInfo.securePort.get, HttpServletResponse.SC_OK))

tests.foreach { case (scheme, port, expected) =>
val urls = Seq(
s"$scheme://localhost:$port/root",
s"$scheme://localhost:$port/test1/root",
s"$scheme://localhost:$port/test2/root")
urls.foreach { url =>
val rc = TestUtils.httpResponseCode(new URL(url))
assert(rc === expected, s"Unexpected status $rc for $url")
}
}
} finally {
stopServer(serverInfo)
}
}

def stopServer(info: ServerInfo): Unit = {
if (info != null && info.server != null) info.server.stop
if (info != null) info.stop()
}

def closeSocket(socket: ServerSocket): Unit = {
Expand Down

0 comments on commit ddd7727

Please sign in to comment.