Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SSHD-968 #112

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ public abstract class AbstractSession extends SessionHelper {
protected final Object encodeLock = new Object();
protected final Object decodeLock = new Object();
protected final Object requestLock = new Object();
// There are some cases where we need to know the sequence number
protected final AtomicLong requestSeqo = new AtomicLong(-1);

/*
* Rekeying
Expand Down Expand Up @@ -980,6 +982,7 @@ public Buffer request(String request, Buffer buffer, long timeout, TimeUnit unit
}

result = requestResult.getAndSet(null);
requestSeqo.set(-1);
}
} catch (InterruptedException e) {
throw (InterruptedIOException) new InterruptedIOException(
Expand Down Expand Up @@ -1069,6 +1072,11 @@ protected Buffer encode(Buffer buffer) throws IOException {
// Check that the packet has some free space for the header
int curPos = buffer.rpos();
int cmd = buffer.rawByte(curPos) & 0xFF; // usually the 1st byte is an SSH opcode

if (cmd == SshConstants.SSH_MSG_GLOBAL_REQUEST) {
requestSeqo.set(seqo);
}

if (curPos < SshConstants.SSH_PACKET_HEADER_LEN) {
log.warn("encode({}) command={}[{}] performance cost: available buffer packet header length ({}) below min. required ({})",
this, cmd, SshConstants.getCommandMessageName(cmd),
Expand Down Expand Up @@ -1792,13 +1800,40 @@ protected void requestSuccess(Buffer buffer) throws Exception {
* @throws Exception If failed to handle the message
*/
protected void requestFailure(Buffer buffer) throws Exception {
signalRequestFailure();
}

/**
* Marks the current pending global request result as failed
*
* @throws Exception If failed to signal
*/
protected void signalRequestFailure() throws Exception {
synchronized (requestResult) {
requestResult.set(GenericUtils.NULL);
resetIdleTimeout();
requestResult.notifyAll();
}
}

@Override
protected void doHandleUnimplemented(Buffer buffer) throws Exception {
// Some servers do respond to requests with the SSH_MSG_UNIMPLEMENTED
// message instead of SSH_MSG_REQUEST_FAILURE, so deal with it
if (requestSeqo.get() >= 0) {
synchronized (requestResult) {
int rpos = buffer.rpos();
long seq = buffer.getUInt();
if (requestSeqo.get() == seq) {
signalRequestFailure();
return;
}
buffer.rpos(rpos);
}
}
super.doHandleUnimplemented(buffer);
}

@Override
public void addSessionListener(SessionListener listener) {
SessionListener.validateListener(listener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,10 @@ protected void handleUnimplemented(Buffer buffer) throws Exception {
}
resetIdleTimeout();

doHandleUnimplemented(buffer);
}

protected void doHandleUnimplemented(Buffer buffer) throws Exception {
ReservedSessionMessagesHandler handler = resolveReservedSessionMessagesHandler();
handler.handleUnimplementedMessage(this, SshConstants.SSH_MSG_UNIMPLEMENTED, buffer);
}
Expand Down
33 changes: 33 additions & 0 deletions sshd-core/src/test/java/org/apache/sshd/KeepAliveTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import java.io.ByteArrayOutputStream;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumSet;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
Expand All @@ -32,6 +33,9 @@
import org.apache.sshd.common.FactoryManager;
import org.apache.sshd.common.PropertyResolverUtils;
import org.apache.sshd.common.channel.Channel;
import org.apache.sshd.common.session.ConnectionService;
import org.apache.sshd.common.session.helpers.AbstractConnectionServiceRequestHandler;
import org.apache.sshd.common.util.buffer.Buffer;
import org.apache.sshd.server.SshServer;
import org.apache.sshd.server.channel.ChannelSession;
import org.apache.sshd.server.command.Command;
Expand Down Expand Up @@ -107,6 +111,35 @@ public void tearDown() {
client, ClientFactoryManager.HEARTBEAT_INTERVAL, ClientFactoryManager.DEFAULT_HEARTBEAT_INTERVAL);
}

@Test
public void testSshd968() throws Exception {
sshd.setGlobalRequestHandlers(Collections.singletonList(new AbstractConnectionServiceRequestHandler() {
@Override
public Result process(ConnectionService connectionService, String request, boolean wantReply, Buffer buffer) throws Exception {
connectionService.process(255, buffer);
return Result.Replied;
}
}));

PropertyResolverUtils.updateProperty(client, ClientFactoryManager.HEARTBEAT_INTERVAL, HEARTBEAT);
PropertyResolverUtils.updateProperty(client, ClientFactoryManager.HEARTBEAT_REPLY_WAIT, 5000L);
try (ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, port)
.verify(7L, TimeUnit.SECONDS)
.getSession()) {
session.addPasswordIdentity(getCurrentTestName());
session.auth().verify(5L, TimeUnit.SECONDS);

try (ClientChannel channel = session.createChannel(Channel.CHANNEL_SHELL)) {
long waitStart = System.currentTimeMillis();
Collection<ClientChannelEvent> result =
channel.waitFor(EnumSet.of(ClientChannelEvent.CLOSED), WAIT);
long waitEnd = System.currentTimeMillis();
assertTrue("Wrong channel state after wait of " + (waitEnd - waitStart) + " ms: " + result,
result.contains(ClientChannelEvent.TIMEOUT));
}
}
}

@Test
public void testIdleClient() throws Exception {
try (ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, port)
Expand Down