diff --git a/launcher/src/main/java/org/apache/spark/launcher/FilteredObjectInputStream.java b/launcher/src/main/java/org/apache/spark/launcher/FilteredObjectInputStream.java new file mode 100644 index 0000000000000..4d254a0c4c9fe --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/FilteredObjectInputStream.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.InputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectStreamClass; +import java.util.Arrays; +import java.util.List; + +/** + * An object input stream that only allows classes used by the launcher protocol to be in the + * serialized stream. See SPARK-20922. + */ +class FilteredObjectInputStream extends ObjectInputStream { + + private static final List ALLOWED_PACKAGES = Arrays.asList( + "org.apache.spark.launcher.", + "java.lang."); + + FilteredObjectInputStream(InputStream is) throws IOException { + super(is); + } + + @Override + protected Class resolveClass(ObjectStreamClass desc) + throws IOException, ClassNotFoundException { + + boolean isValid = ALLOWED_PACKAGES.stream().anyMatch(p -> desc.getName().startsWith(p)); + if (!isValid) { + throw new IllegalArgumentException( + String.format("Unexpected class in stream: %s", desc.getName())); + } + return super.resolveClass(desc); + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java index eec264909bbb6..b4a8719e26053 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java @@ -20,7 +20,6 @@ import java.io.Closeable; import java.io.EOFException; import java.io.IOException; -import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.net.Socket; import java.util.logging.Level; @@ -53,7 +52,7 @@ abstract class LauncherConnection implements Closeable, Runnable { @Override public void run() { try { - ObjectInputStream in = new ObjectInputStream(socket.getInputStream()); + FilteredObjectInputStream in = new FilteredObjectInputStream(socket.getInputStream()); while (!closed) { Message msg = (Message) in.readObject(); handle(msg); diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index 12f1a0ce2d1b4..03c2934e2692e 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -19,8 +19,11 @@ import java.io.Closeable; import java.io.IOException; +import java.io.ObjectInputStream; import java.net.InetAddress; import java.net.Socket; +import java.util.Arrays; +import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.Semaphore; @@ -120,31 +123,7 @@ public void testTimeout() throws Exception { Socket s = new Socket(InetAddress.getLoopbackAddress(), LauncherServer.getServerInstance().getPort()); client = new TestClient(s); - - // Try a few times since the client-side socket may not reflect the server-side close - // immediately. - boolean helloSent = false; - int maxTries = 10; - for (int i = 0; i < maxTries; i++) { - try { - if (!helloSent) { - client.send(new Hello(handle.getSecret(), "1.4.0")); - helloSent = true; - } else { - client.send(new SetAppId("appId")); - } - fail("Expected exception caused by connection timeout."); - } catch (IllegalStateException | IOException e) { - // Expected. - break; - } catch (AssertionError e) { - if (i < maxTries - 1) { - Thread.sleep(100); - } else { - throw new AssertionError("Test failed after " + maxTries + " attempts.", e); - } - } - } + waitForError(client, handle.getSecret()); } finally { SparkLauncher.launcherConfig.remove(SparkLauncher.CHILD_CONNECTION_TIMEOUT); kill(handle); @@ -183,6 +162,25 @@ public void infoChanged(SparkAppHandle handle) { } } + @Test + public void testStreamFiltering() throws Exception { + ChildProcAppHandle handle = LauncherServer.newAppHandle(); + TestClient client = null; + try { + Socket s = new Socket(InetAddress.getLoopbackAddress(), + LauncherServer.getServerInstance().getPort()); + + client = new TestClient(s); + client.send(new EvilPayload()); + waitForError(client, handle.getSecret()); + assertEquals(0, EvilPayload.EVIL_BIT); + } finally { + kill(handle); + close(client); + client.clientThread.join(); + } + } + private void kill(SparkAppHandle handle) { if (handle != null) { handle.kill(); @@ -199,6 +197,35 @@ private void close(Closeable c) { } } + /** + * Try a few times to get a client-side error, since the client-side socket may not reflect the + * server-side close immediately. + */ + private void waitForError(TestClient client, String secret) throws Exception { + boolean helloSent = false; + int maxTries = 10; + for (int i = 0; i < maxTries; i++) { + try { + if (!helloSent) { + client.send(new Hello(secret, "1.4.0")); + helloSent = true; + } else { + client.send(new SetAppId("appId")); + } + fail("Expected error but message went through."); + } catch (IllegalStateException | IOException e) { + // Expected. + break; + } catch (AssertionError e) { + if (i < maxTries - 1) { + Thread.sleep(100); + } else { + throw new AssertionError("Test failed after " + maxTries + " attempts.", e); + } + } + } + } + private static class TestClient extends LauncherConnection { final BlockingQueue inbound; @@ -220,4 +247,19 @@ protected void handle(Message msg) throws IOException { } + private static class EvilPayload extends LauncherProtocol.Message { + + static int EVIL_BIT = 0; + + // This field should cause the launcher server to throw an error and not deserialize the + // message. + private List notAllowedField = Arrays.asList("disallowed"); + + private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException { + stream.defaultReadObject(); + EVIL_BIT = 1; + } + + } + }