Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SSHD-1303] AbstractClientChannel: use null stream for redirected stderr #255

Merged
merged 1 commit into from
Oct 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,25 @@
* @author <a href="mailto:dev@mina.apache.org">Apache MINA SSHD Project</a>
*/
public interface IoReadFuture extends SshFuture<IoReadFuture>, VerifiableFuture<IoReadFuture> {

/**
* Retrieves the buffer data was read into.
*
* @return the buffer, {@code null} if {@link #isDone()} {@code == false}
*/
Buffer getBuffer();

/**
* Retrieves the number of bytes read.
*
* @return The number of bytes read, or -1 if the source of the read has been exhausted (is at EOF), or zero if the
* read is not done yet ({@link #isDone()} {@code == false})
*/
int getRead();

/**
* Returns the cause of the read failure.
* Returns the cause of the read failure. An {@link java.io.EOFException} indicates that nothing was read because
* the source of the read is exhausted.
*
* @return {@code null} if the read operation is not finished yet, or if the read attempt is successful (use
* {@link #isDone()} to distinguish between the two).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.apache.sshd.client.channel;

import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
Expand Down Expand Up @@ -45,8 +46,12 @@
import org.apache.sshd.common.channel.RequestHandler;
import org.apache.sshd.common.channel.Window;
import org.apache.sshd.common.channel.exception.SshChannelOpenException;
import org.apache.sshd.common.future.CloseFuture;
import org.apache.sshd.common.future.DefaultCloseFuture;
import org.apache.sshd.common.future.SshFutureListener;
import org.apache.sshd.common.io.IoInputStream;
import org.apache.sshd.common.io.IoOutputStream;
import org.apache.sshd.common.io.IoReadFuture;
import org.apache.sshd.common.session.Session;
import org.apache.sshd.common.util.EventNotifier;
import org.apache.sshd.common.util.ExceptionUtils;
Expand All @@ -61,6 +66,15 @@
* @author <a href="mailto:dev@mina.apache.org">Apache MINA SSHD Project</a>
*/
public abstract class AbstractClientChannel extends AbstractChannel implements ClientChannel {

private static final InputStream NULL_INPUT_STREAM = new InputStream() {

@Override
public int read() throws IOException {
return -1;
}
};

protected final AtomicBoolean opened = new AtomicBoolean();

protected Streaming streaming;
Expand Down Expand Up @@ -134,6 +148,9 @@ public IoInputStream getAsyncOut() {

@Override
public IoInputStream getAsyncErr() {
if (asyncErr == asyncOut) {
return NullIoInputStream.INSTANCE;
}
return asyncErr;
}

Expand Down Expand Up @@ -167,6 +184,9 @@ public void setOut(OutputStream out) {

@Override
public InputStream getInvertedErr() {
if (invertedErr == invertedOut) {
return NULL_INPUT_STREAM;
}
return invertedErr;
}

Expand Down Expand Up @@ -474,4 +494,48 @@ public Integer getExitStatus() {
public String getExitSignal() {
return exitSignalHolder.get();
}

private enum NullIoInputStream implements IoInputStream {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recommend making this a public class in same package as IoInputStream or a member of it. Might come useful in the future.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then let's delay doing so to that future until a second concrete use case for this comes up. YAGNI. Which basically means "implement features when they're actually needed, not when you think they might be needed."


INSTANCE;

private final CloseFuture closing = new DefaultCloseFuture("", null);

NullIoInputStream() {
closing.setClosed();
}

@Override
public CloseFuture close(boolean immediately) {
return closing;
}

@Override
public void addCloseFutureListener(SshFutureListener<CloseFuture> listener) {
closing.addListener(listener);
}

@Override
public void removeCloseFutureListener(SshFutureListener<CloseFuture> listener) {
closing.removeListener(listener);
}

@Override
public boolean isClosed() {
return true;
}

@Override
public boolean isClosing() {
return true;
}

@Override
public IoReadFuture read(Buffer buffer) {
ChannelAsyncInputStream.IoReadFutureImpl future = new ChannelAsyncInputStream.IoReadFutureImpl("", buffer);
future.setValue(new EOFException("Closed"));
return future;
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
*/
package org.apache.sshd.common.channel;

import java.io.EOFException;
import java.io.IOException;
import java.util.Objects;

import org.apache.sshd.common.RuntimeSshException;
import org.apache.sshd.common.SshException;
import org.apache.sshd.common.future.CloseFuture;
import org.apache.sshd.common.future.DefaultVerifiableSshFuture;
import org.apache.sshd.common.io.IoInputStream;
Expand Down Expand Up @@ -69,12 +69,12 @@ public IoReadFuture read(Buffer buf) {
throw new ReadPendingException("Previous pending read not handled");
}
if (buffer.available() > 0) {
Buffer fb = future.getBuffer();
Buffer fb = future.buffer;
int nbRead = fb.putBuffer(buffer, false);
buffer.compact();
future.setValue(nbRead);
} else {
future.setValue(new IOException("Closed"));
future.setValue(new EOFException("Closed"));
}
}
} else {
Expand All @@ -94,7 +94,7 @@ protected void preClose() {
synchronized (buffer) {
if (buffer.available() == 0) {
if (pending != null) {
pending.setValue(new SshException("Closed"));
pending.setValue(new EOFException("Closed"));
}
}
}
Expand Down Expand Up @@ -153,7 +153,8 @@ public String toString() {
}

public static class IoReadFutureImpl extends DefaultVerifiableSshFuture<IoReadFuture> implements IoReadFuture {
private final Buffer buffer;

final Buffer buffer;

public IoReadFutureImpl(Object id, Buffer buffer) {
super(id, null);
Expand All @@ -162,7 +163,7 @@ public IoReadFutureImpl(Object id, Buffer buffer) {

@Override
public Buffer getBuffer() {
return buffer;
return isDone() ? buffer : null;
}

@Override
Expand All @@ -180,14 +181,18 @@ public IoReadFuture verify(long timeoutMillis) throws IOException {
@Override
public int getRead() {
Object v = getValue();
if (v instanceof RuntimeException) {
if (v == null) {
return 0;
} else if (v instanceof Number) {
return ((Number) v).intValue();
} else if (v instanceof EOFException) {
return -1;
} else if (v instanceof RuntimeException) {
throw (RuntimeException) v;
} else if (v instanceof Error) {
throw (Error) v;
} else if (v instanceof Throwable) {
throw new RuntimeSshException("Error reading from channel.", (Throwable) v);
} else if (v instanceof Number) {
return ((Number) v).intValue();
} else {
throw formatExceptionMessage(
IllegalStateException::new,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.rmi.RemoteException;
Expand All @@ -40,6 +41,7 @@
import org.apache.sshd.common.session.Session;
import org.apache.sshd.common.session.SessionListener;
import org.apache.sshd.common.util.GenericUtils;
import org.apache.sshd.common.util.io.IoUtils;
import org.apache.sshd.core.CoreModuleProperties;
import org.apache.sshd.server.SshServer;
import org.apache.sshd.server.auth.keyboard.KeyboardInteractiveAuthenticator;
Expand Down Expand Up @@ -369,4 +371,67 @@ private void writeResponse(OutputStream out, String rsp) throws IOException {

assertTrue("Unexpected response remainders: " + values, values.isEmpty());
}

@Test // SSHD-1303
public void testRedirectCommandErrorStreamIsEmpty() throws Exception {
String expectedCommand = getCurrentTestName() + "-CMD";
String expectedStdout = getCurrentTestName() + "-STDOUT";
String expectedStderr = getCurrentTestName() + "-STDERR";
sshd.setCommandFactory((session, command) -> new CommandExecutionHelper(command) {
private boolean cmdProcessed;

@Override
protected boolean handleCommandLine(String command) throws Exception {
assertEquals("Mismatched incoming command", expectedCommand, command);
assertFalse("Duplicated command call", cmdProcessed);
writeResponse(getOutputStream(), expectedStdout);
writeResponse(getErrorStream(), expectedStderr);
cmdProcessed = true;
return false;
}

private void writeResponse(OutputStream out, String rsp) throws IOException {
out.write(rsp.getBytes(StandardCharsets.US_ASCII));
out.write((byte) '\n');
out.flush();
}
});

String response;
try (ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, port).verify(CONNECT_TIMEOUT)
.getSession()) {
session.addPasswordIdentity(getCurrentTestName());
session.auth().verify(AUTH_TIMEOUT);

try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
// NOTE !!! The LF is only because we are using a buffered reader on the server end to read the command
try (ClientChannel channel = session.createExecChannel(expectedCommand + '\n')) {
channel.setRedirectErrorStream(true);

channel.open().verify(OPEN_TIMEOUT);
try (InputStream stderr = channel.getInvertedErr()) {
assertEquals(-1, stderr.read());
}
try (InputStream stdout = channel.getInvertedOut()) {
IoUtils.copy(stdout, baos, 32); // Use a small buffer on purpose
}
}
byte[] bytes = baos.toByteArray();
response = new String(bytes, StandardCharsets.US_ASCII);
}
}

String[] lines = GenericUtils.split(response, '\n');
assertEquals("Mismatched response lines count", 2, lines.length);

Collection<String> values = new ArrayList<>(Arrays.asList(lines));
// We don't rely on the order the strings were written
for (String expected : new String[] { expectedStdout, expectedStderr }) {
if (!values.remove(expected)) {
fail(expected + " not in response=" + values);
}
}

assertTrue("Unexpected response remainders: " + values, values.isEmpty());
}
}