Skip to content

Commit

Permalink
Implement OpenSSH strict key exchange extension (hierynomus#917)
Browse files Browse the repository at this point in the history
(cherry picked from commit a262f51)
  • Loading branch information
hpoettker authored and vladimirlagunov committed Jan 2, 2024
1 parent 830a39d commit c339078
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Copyright (C)2009 - SSHJ Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.hierynomus.sshj.transport.kex;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

import ch.qos.logback.classic.Logger;
import ch.qos.logback.classic.spi.ILoggingEvent;
import ch.qos.logback.core.read.ListAppender;
import com.hierynomus.sshj.SshdContainer;
import net.schmizz.sshj.SSHClient;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.slf4j.LoggerFactory;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertTrue;

@Testcontainers
class StrictKeyExchangeTest {

@Container
private static final SshdContainer sshd = new SshdContainer();

private final List<Logger> watchedLoggers = new ArrayList<>();
private final ListAppender<ILoggingEvent> logWatcher = new ListAppender<>();

@BeforeEach
void setUpLogWatcher() {
logWatcher.start();
setUpLogger("net.schmizz.sshj.transport.Decoder");
setUpLogger("net.schmizz.sshj.transport.Encoder");
setUpLogger("net.schmizz.sshj.transport.KeyExchanger");
}

@AfterEach
void tearDown() {
watchedLoggers.forEach(Logger::detachAndStopAllAppenders);
}

private void setUpLogger(String className) {
Logger logger = ((Logger) LoggerFactory.getLogger(className));
logger.addAppender(logWatcher);
watchedLoggers.add(logger);
}

@Test
void strictKeyExchange() throws Throwable {
try (SSHClient client = sshd.getConnectedClient()) {
client.authPublickey("sshj", "src/itest/resources/keyfiles/id_rsa_opensshv1");
assertTrue(client.isAuthenticated());
}
List<String> keyExchangerLogs = getLogs("KeyExchanger");
assertThat(keyExchangerLogs).containsSequence(
"Initiating key exchange",
"Sending SSH_MSG_KEXINIT",
"Received SSH_MSG_KEXINIT",
"Enabling strict key exchange extension"
);
List<String> decoderLogs = getLogs("Decoder").stream()
.map(log -> log.split(":")[0])
.collect(Collectors.toList());
assertThat(decoderLogs).containsExactly(
"Received packet #0",
"Received packet #1",
"Received packet #2",
"Received packet #0",
"Received packet #1",
"Received packet #2",
"Received packet #3"
);
List<String> encoderLogs = getLogs("Encoder").stream()
.map(log -> log.split(":")[0])
.collect(Collectors.toList());
assertThat(encoderLogs).containsExactly(
"Encoding packet #0",
"Encoding packet #1",
"Encoding packet #2",
"Encoding packet #0",
"Encoding packet #1",
"Encoding packet #2",
"Encoding packet #3"
);
}

private List<String> getLogs(String className) {
return logWatcher.list.stream()
.filter(event -> event.getLoggerName().endsWith(className))
.map(ILoggingEvent::getFormattedMessage)
.collect(Collectors.toList());
}

}
8 changes: 8 additions & 0 deletions src/main/java/net/schmizz/sshj/transport/Converter.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ long getSequenceNumber() {
return seq;
}

void resetSequenceNumber() {
seq = -1;
}

boolean isSequenceNumberAtMax() {
return seq == 0xffffffffL;
}

void setAlgorithms(Cipher cipher, MAC mac, Compression compression) {
this.cipher = cipher;
this.mac = mac;
Expand Down
34 changes: 33 additions & 1 deletion src/main/java/net/schmizz/sshj/transport/KeyExchanger.java
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ private static enum Expected {

private final AtomicBoolean kexOngoing = new AtomicBoolean();

private final AtomicBoolean initialKex = new AtomicBoolean(true);

private final AtomicBoolean strictKex = new AtomicBoolean();

/** What we are expecting from the next packet */
private Expected expected = Expected.KEXINIT;

Expand Down Expand Up @@ -123,6 +127,14 @@ boolean isKexOngoing() {
return kexOngoing.get();
}

boolean isStrictKex() {
return strictKex.get();
}

boolean isInitialKex() {
return initialKex.get();
}

/**
* Starts key exchange by sending a {@code SSH_MSG_KEXINIT} packet. Key exchange needs to be done once mandatorily
* after initializing the {@link Transport} for it to be usable and may be initiated at any later point e.g. if
Expand Down Expand Up @@ -171,7 +183,7 @@ private void sendKexInit()
throws TransportException {
log.debug("Sending SSH_MSG_KEXINIT");
List<String> knownHostAlgs = findKnownHostAlgs(transport.getRemoteHost(), transport.getRemotePort());
clientProposal = new Proposal(transport.getConfig(), knownHostAlgs);
clientProposal = new Proposal(transport.getConfig(), knownHostAlgs, initialKex.get());
transport.write(clientProposal.getPacket());
kexInitSent.set();
}
Expand All @@ -190,6 +202,9 @@ private void sendNewKeys()
throws TransportException {
log.debug("Sending SSH_MSG_NEWKEYS");
transport.write(new SSHPacket(Message.NEWKEYS));
if (strictKex.get()) {
transport.getEncoder().resetSequenceNumber();
}
}

/**
Expand Down Expand Up @@ -222,6 +237,10 @@ private synchronized void verifyHost(PublicKey key)

private void setKexDone() {
kexOngoing.set(false);
initialKex.set(false);
if (strictKex.get()) {
transport.getDecoder().resetSequenceNumber();
}
kexInitSent.clear();
done.set();
}
Expand All @@ -230,6 +249,7 @@ private void gotKexInit(SSHPacket buf)
throws TransportException {
buf.rpos(buf.rpos() - 1);
final Proposal serverProposal = new Proposal(buf);
gotStrictKexInfo(serverProposal);
negotiatedAlgs = clientProposal.negotiate(serverProposal);
log.debug("Negotiated algorithms: {}", negotiatedAlgs);
for(AlgorithmsVerifier v: algorithmVerifiers) {
Expand All @@ -253,6 +273,18 @@ private void gotKexInit(SSHPacket buf)
}
}

private void gotStrictKexInfo(Proposal serverProposal) throws TransportException {
if (initialKex.get() && serverProposal.isStrictKeyExchangeSupportedByServer()) {
strictKex.set(true);
log.debug("Enabling strict key exchange extension");
if (transport.getDecoder().getSequenceNumber() != 0) {
throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED,
"SSH_MSG_KEXINIT was not first package during strict key exchange"
);
}
}
}

/**
* Private method used while putting new keys into use that will resize the key used to initialize the cipher to the
* needed length.
Expand Down
9 changes: 8 additions & 1 deletion src/main/java/net/schmizz/sshj/transport/Proposal.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,11 @@ class Proposal {
private final List<String> s2cComp;
private final SSHPacket packet;

public Proposal(Config config, List<String> knownHostAlgs) {
public Proposal(Config config, List<String> knownHostAlgs, boolean initialKex) {
kex = Factory.Named.Util.getNames(config.getKeyExchangeFactories());
if (initialKex) {
kex.add("kex-strict-c-v00@openssh.com");
}
sig = filterKnownHostKeyAlgorithms(Factory.Named.Util.getNames(config.getKeyAlgorithms()), knownHostAlgs);
c2sCipher = s2cCipher = Factory.Named.Util.getNames(config.getCipherFactories());
c2sMAC = s2cMAC = Factory.Named.Util.getNames(config.getMACFactories());
Expand Down Expand Up @@ -91,6 +94,10 @@ public List<String> getKeyExchangeAlgorithms() {
return kex;
}

public boolean isStrictKeyExchangeSupportedByServer() {
return kex.contains("kex-strict-s-v00@openssh.com");
}

public List<String> getHostKeyAlgorithms() {
return sig;
}
Expand Down
19 changes: 17 additions & 2 deletions src/main/java/net/schmizz/sshj/transport/TransportImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ public long write(SSHPacket payload)
assert m != Message.KEXINIT;
kexer.waitForDone();
}
} else if (encoder.getSequenceNumber() == 0) // We get here every 2**32th packet
} else if (encoder.isSequenceNumberAtMax()) // We get here every 2**32th packet
kexer.startKex(true);

final long seq = encoder.encode(payload);
Expand Down Expand Up @@ -489,9 +489,20 @@ public void handle(Message msg, SSHPacket buf)

log.trace("Received packet {}", msg);

if (kexer.isInitialKex()) {
if (decoder.isSequenceNumberAtMax()) {
throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED,
"Sequence number of decoder is about to wrap during initial key exchange");
}
if (kexer.isStrictKex() && !isKexerPacket(msg) && msg != Message.DISCONNECT) {
throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED,
"Unexpected packet type during initial strict key exchange");
}
}

if (msg.geq(50)) { // not a transport layer packet
service.handle(msg, buf);
} else if (msg.in(20, 21) || msg.in(30, 49)) { // kex packet
} else if (isKexerPacket(msg)) {
kexer.handle(msg, buf);
} else {
switch (msg) {
Expand Down Expand Up @@ -523,6 +534,10 @@ public void handle(Message msg, SSHPacket buf)
}
}

private static boolean isKexerPacket(Message msg) {
return msg.in(20, 21) || msg.in(30, 49);
}

private void gotDebug(SSHPacket buf)
throws TransportException {
try {
Expand Down

0 comments on commit c339078

Please sign in to comment.