diff --git a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java index 4468276c7..0553e279f 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java @@ -175,6 +175,8 @@ public abstract class AbstractSession extends SessionHelper { protected final Object encodeLock = new Object(); protected final Object decodeLock = new Object(); protected final Object requestLock = new Object(); + // There are some cases where we need to know the sequence number + protected final AtomicLong requestSeqo = new AtomicLong(-1); /* * Rekeying @@ -980,6 +982,7 @@ public Buffer request(String request, Buffer buffer, long timeout, TimeUnit unit } result = requestResult.getAndSet(null); + requestSeqo.set(-1); } } catch (InterruptedException e) { throw (InterruptedIOException) new InterruptedIOException( @@ -1069,6 +1072,11 @@ protected Buffer encode(Buffer buffer) throws IOException { // Check that the packet has some free space for the header int curPos = buffer.rpos(); int cmd = buffer.rawByte(curPos) & 0xFF; // usually the 1st byte is an SSH opcode + + if (cmd == SshConstants.SSH_MSG_GLOBAL_REQUEST) { + requestSeqo.set(seqo); + } + if (curPos < SshConstants.SSH_PACKET_HEADER_LEN) { log.warn("encode({}) command={}[{}] performance cost: available buffer packet header length ({}) below min. required ({})", this, cmd, SshConstants.getCommandMessageName(cmd), @@ -1792,6 +1800,15 @@ protected void requestSuccess(Buffer buffer) throws Exception { * @throws Exception If failed to handle the message */ protected void requestFailure(Buffer buffer) throws Exception { + signalRequestFailure(); + } + + /** + * Marks the current pending global request result as failed + * + * @throws Exception If failed to signal + */ + protected void signalRequestFailure() throws Exception { synchronized (requestResult) { requestResult.set(GenericUtils.NULL); resetIdleTimeout(); @@ -1799,6 +1816,24 @@ protected void requestFailure(Buffer buffer) throws Exception { } } + @Override + protected void doHandleUnimplemented(Buffer buffer) throws Exception { + // Some servers do respond to requests with the SSH_MSG_UNIMPLEMENTED + // message instead of SSH_MSG_REQUEST_FAILURE, so deal with it + if (requestSeqo.get() >= 0) { + synchronized (requestResult) { + int rpos = buffer.rpos(); + long seq = buffer.getUInt(); + if (requestSeqo.get() == seq) { + signalRequestFailure(); + return; + } + buffer.rpos(rpos); + } + } + super.doHandleUnimplemented(buffer); + } + @Override public void addSessionListener(SessionListener listener) { SessionListener.validateListener(listener); diff --git a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/SessionHelper.java b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/SessionHelper.java index 6003372c6..388da96a0 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/SessionHelper.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/SessionHelper.java @@ -431,6 +431,10 @@ protected void handleUnimplemented(Buffer buffer) throws Exception { } resetIdleTimeout(); + doHandleUnimplemented(buffer); + } + + protected void doHandleUnimplemented(Buffer buffer) throws Exception { ReservedSessionMessagesHandler handler = resolveReservedSessionMessagesHandler(); handler.handleUnimplementedMessage(this, SshConstants.SSH_MSG_UNIMPLEMENTED, buffer); } diff --git a/sshd-core/src/test/java/org/apache/sshd/KeepAliveTest.java b/sshd-core/src/test/java/org/apache/sshd/KeepAliveTest.java index efba1aed3..457f3790d 100644 --- a/sshd-core/src/test/java/org/apache/sshd/KeepAliveTest.java +++ b/sshd-core/src/test/java/org/apache/sshd/KeepAliveTest.java @@ -20,6 +20,7 @@ import java.io.ByteArrayOutputStream; import java.util.Collection; +import java.util.Collections; import java.util.EnumSet; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -32,6 +33,9 @@ import org.apache.sshd.common.FactoryManager; import org.apache.sshd.common.PropertyResolverUtils; import org.apache.sshd.common.channel.Channel; +import org.apache.sshd.common.session.ConnectionService; +import org.apache.sshd.common.session.helpers.AbstractConnectionServiceRequestHandler; +import org.apache.sshd.common.util.buffer.Buffer; import org.apache.sshd.server.SshServer; import org.apache.sshd.server.channel.ChannelSession; import org.apache.sshd.server.command.Command; @@ -107,6 +111,35 @@ public void tearDown() { client, ClientFactoryManager.HEARTBEAT_INTERVAL, ClientFactoryManager.DEFAULT_HEARTBEAT_INTERVAL); } + @Test + public void testSshd968() throws Exception { + sshd.setGlobalRequestHandlers(Collections.singletonList(new AbstractConnectionServiceRequestHandler() { + @Override + public Result process(ConnectionService connectionService, String request, boolean wantReply, Buffer buffer) throws Exception { + connectionService.process(255, buffer); + return Result.Replied; + } + })); + + PropertyResolverUtils.updateProperty(client, ClientFactoryManager.HEARTBEAT_INTERVAL, HEARTBEAT); + PropertyResolverUtils.updateProperty(client, ClientFactoryManager.HEARTBEAT_REPLY_WAIT, 5000L); + try (ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, port) + .verify(7L, TimeUnit.SECONDS) + .getSession()) { + session.addPasswordIdentity(getCurrentTestName()); + session.auth().verify(5L, TimeUnit.SECONDS); + + try (ClientChannel channel = session.createChannel(Channel.CHANNEL_SHELL)) { + long waitStart = System.currentTimeMillis(); + Collection result = + channel.waitFor(EnumSet.of(ClientChannelEvent.CLOSED), WAIT); + long waitEnd = System.currentTimeMillis(); + assertTrue("Wrong channel state after wait of " + (waitEnd - waitStart) + " ms: " + result, + result.contains(ClientChannelEvent.TIMEOUT)); + } + } + } + @Test public void testIdleClient() throws Exception { try (ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, port)