Skip to content

Commit

Permalink
Fixes client mode version parsing (#153)
Browse files Browse the repository at this point in the history
Motivation:
Server may send additional lines of data before version, but since we
use version as part of key exchange, we need to filter out those lines,
otherwise it will fail key exchange.

Modifications:
 - Strip out lines of data before version in client role
 - Update tests
  • Loading branch information
artemredkin committed Jul 21, 2023
1 parent d7279ea commit ded5e5c
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ extension SSHConnectionStateMachine {
self.serializer = state.serializer
self.protectionSchemes = state.protectionSchemes

self.parser = SSHPacketParser(allocator: allocator)
self.parser = SSHPacketParser(isServer: self.role.isServer, allocator: allocator)
self.allocator = allocator
}

Expand Down
33 changes: 24 additions & 9 deletions Sources/NIOSSH/SSHPacketParser.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ struct SSHPacketParser {
case encryptedWaitingForBytes(UInt32, NIOSSHTransportProtection)
}

private let isServer: Bool
private var buffer: ByteBuffer
private var state: State
private(set) var sequenceNumber: UInt32
Expand All @@ -32,7 +33,8 @@ struct SSHPacketParser {
self.buffer.readerIndex
}

init(allocator: ByteBufferAllocator) {
init(isServer: Bool, allocator: ByteBufferAllocator) {
self.isServer = isServer
self.buffer = allocator.buffer(capacity: 0)
self.state = .initialized
self.sequenceNumber = 0
Expand Down Expand Up @@ -121,18 +123,31 @@ struct SSHPacketParser {
let carriageReturn = UInt8(ascii: "\r")
let lineFeed = UInt8(ascii: "\n")

// Search for version line, which starts with "SSH-". Lines without this prefix may come before the version line.
var slice = self.buffer.readableBytesView
while let lfIndex = slice.firstIndex(of: lineFeed), lfIndex < slice.endIndex {
if slice.starts(with: "SSH-".utf8) {
// Return all data upto the last LF we found, excluding the last [CR]LF.
slice = self.buffer.readableBytesView
// Per RFC 4253 §4.2:
// The server MAY send other lines of data before sending the version string.
// This means that server does not expect any lines before version so we will return all data before first line feed
if self.isServer {
// Looking for a string ending with \r\n
let slice = self.buffer.readableBytesView
if let lfIndex = slice.firstIndex(of: lineFeed), lfIndex < slice.endIndex {
let versionEndIndex = slice[lfIndex.advanced(by: -1)] == carriageReturn ? lfIndex.advanced(by: -1) : lfIndex
let version = String(decoding: slice[slice.startIndex ..< versionEndIndex], as: UTF8.self)
self.buffer.moveReaderIndex(forwardBy: slice.startIndex.distance(to: lfIndex).advanced(by: 1))
return version
} else {
slice = slice[slice.index(after: lfIndex)...]
}
} else {
// Search for version line, which starts with "SSH-". Lines without this prefix may come before the version line.
var slice = self.buffer.readableBytesView
let startIndex = slice.startIndex
while let lfIndex = slice.firstIndex(of: lineFeed), lfIndex < slice.endIndex {
if slice.starts(with: "SSH-".utf8) {
let versionEndIndex = slice[lfIndex.advanced(by: -1)] == carriageReturn ? lfIndex.advanced(by: -1) : lfIndex
let version = String(decoding: slice[slice.startIndex ..< versionEndIndex], as: UTF8.self)
self.buffer.moveReaderIndex(forwardBy: startIndex.distance(to: lfIndex).advanced(by: 1))
return version
} else {
slice = slice[slice.index(after: lfIndex)...]
}
}
}
return nil
Expand Down
2 changes: 1 addition & 1 deletion Sources/NIOSSHClient/ExecHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ final class ExampleExecHandler: ChannelDuplexHandler {
DispatchQueue(label: "pipe bootstrap").async {
bootstrap.channelOption(ChannelOptions.allowRemoteHalfClosure, value: true).channelInitializer { channel in
channel.pipeline.addHandler(theirs)
}.withPipes(inputDescriptor: 0, outputDescriptor: 1).whenComplete { result in
}.takingOwnershipOfDescriptors(input: 0, output: 1).whenComplete { result in
switch result {
case .success:
// We need to exec a thing.
Expand Down
2 changes: 1 addition & 1 deletion Sources/NIOSSHServer/ExecHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ final class ExampleExecHandler: ChannelDuplexHandler {
.channelOption(ChannelOptions.allowRemoteHalfClosure, value: true)
.channelInitializer { pipeChannel in
pipeChannel.pipeline.addHandler(theirs)
}.withPipes(inputDescriptor: dup(outPipe.fileHandleForReading.fileDescriptor), outputDescriptor: dup(inPipe.fileHandleForWriting.fileDescriptor)).wait()
}.takingOwnershipOfDescriptors(input: dup(outPipe.fileHandleForReading.fileDescriptor), output: dup(inPipe.fileHandleForWriting.fileDescriptor)).wait()

// Ok, great, we've sorted stdout and stdin. For stderr we need a different strategy: we just park a thread for this.
DispatchQueue(label: "stderrorwhatever").async {
Expand Down
2 changes: 1 addition & 1 deletion Tests/NIOSSHTests/SSHEncryptedTrafficTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ final class SSHEncryptedTrafficTests: XCTestCase {

override func setUp() {
self.serializer = SSHPacketSerializer()
self.parser = SSHPacketParser(allocator: .init())
self.parser = SSHPacketParser(isServer: false, allocator: .init())

self.assertPacketRoundTrips(.version("SSH-2.0-SwiftSSH_1.0"))
}
Expand Down
16 changes: 8 additions & 8 deletions Tests/NIOSSHTests/SSHPackerSerializerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ final class SSHPacketSerializerTests: XCTestCase {
let message = SSHMessage.disconnect(.init(reason: 42, description: "description", tag: "tag"))
let allocator = ByteBufferAllocator()
var serializer = SSHPacketSerializer()
var parser = SSHPacketParser(allocator: allocator)
var parser = SSHPacketParser(isServer: false, allocator: allocator)

self.runVersionHandshake(serializer: &serializer, parser: &parser)

Expand All @@ -74,7 +74,7 @@ final class SSHPacketSerializerTests: XCTestCase {
let message = SSHMessage.serviceRequest(.init(service: "ssh-userauth"))
let allocator = ByteBufferAllocator()
var serializer = SSHPacketSerializer()
var parser = SSHPacketParser(allocator: allocator)
var parser = SSHPacketParser(isServer: false, allocator: allocator)

self.runVersionHandshake(serializer: &serializer, parser: &parser)

Expand All @@ -97,7 +97,7 @@ final class SSHPacketSerializerTests: XCTestCase {
let message = SSHMessage.serviceAccept(.init(service: "ssh-userauth"))
let allocator = ByteBufferAllocator()
var serializer = SSHPacketSerializer()
var parser = SSHPacketParser(allocator: allocator)
var parser = SSHPacketParser(isServer: false, allocator: allocator)

self.runVersionHandshake(serializer: &serializer, parser: &parser)

Expand Down Expand Up @@ -133,7 +133,7 @@ final class SSHPacketSerializerTests: XCTestCase {
))
let allocator = ByteBufferAllocator()
var serializer = SSHPacketSerializer()
var parser = SSHPacketParser(allocator: allocator)
var parser = SSHPacketParser(isServer: false, allocator: allocator)

self.runVersionHandshake(serializer: &serializer, parser: &parser)

Expand Down Expand Up @@ -165,7 +165,7 @@ final class SSHPacketSerializerTests: XCTestCase {
let message = SSHMessage.keyExchangeInit(.init(publicKey: ByteBuffer.of(bytes: [42])))
let allocator = ByteBufferAllocator()
var serializer = SSHPacketSerializer()
var parser = SSHPacketParser(allocator: allocator)
var parser = SSHPacketParser(isServer: false, allocator: allocator)

self.runVersionHandshake(serializer: &serializer, parser: &parser)

Expand Down Expand Up @@ -193,7 +193,7 @@ final class SSHPacketSerializerTests: XCTestCase {
))
let allocator = ByteBufferAllocator()
var serializer = SSHPacketSerializer()
var parser = SSHPacketParser(allocator: allocator)
var parser = SSHPacketParser(isServer: false, allocator: allocator)

self.runVersionHandshake(serializer: &serializer, parser: &parser)

Expand Down Expand Up @@ -226,7 +226,7 @@ final class SSHPacketSerializerTests: XCTestCase {
let message = SSHMessage.newKeys
let allocator = ByteBufferAllocator()
var serializer = SSHPacketSerializer()
var parser = SSHPacketParser(allocator: allocator)
var parser = SSHPacketParser(isServer: false, allocator: allocator)

self.runVersionHandshake(serializer: &serializer, parser: &parser)

Expand All @@ -247,7 +247,7 @@ final class SSHPacketSerializerTests: XCTestCase {
let message = SSHMessage.newKeys
let allocator = ByteBufferAllocator()
var serializer = SSHPacketSerializer()
var parser = SSHPacketParser(allocator: allocator)
var parser = SSHPacketParser(isServer: false, allocator: allocator)

self.runVersionHandshake(serializer: &serializer, parser: &parser)

Expand Down
64 changes: 51 additions & 13 deletions Tests/NIOSSHTests/SSHPacketParserTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ final class SSHPacketParserTests: XCTestCase {
}

func testReadVersion() throws {
var parser = SSHPacketParser(allocator: ByteBufferAllocator())
var parser = SSHPacketParser(isServer: false, allocator: ByteBufferAllocator())

var part1 = ByteBuffer.of(string: "SSH-2.0-")
parser.append(bytes: &part1)
Expand All @@ -57,8 +57,8 @@ final class SSHPacketParserTests: XCTestCase {
}
}

func testReadVersionWithExtraLines() throws {
var parser = SSHPacketParser(allocator: ByteBufferAllocator())
func testReadVersionWithExtraLinesOnClient() throws {
var parser = SSHPacketParser(isServer: false, allocator: ByteBufferAllocator())

var part1 = ByteBuffer.of(string: "xxxx\r\nyyyy\r\nSSH-2.0-")
parser.append(bytes: &part1)
Expand All @@ -70,14 +70,33 @@ final class SSHPacketParserTests: XCTestCase {

switch try parser.nextPacket() {
case .version(let string):
XCTAssertEqual(string, "xxxx\r\nyyyy\r\nSSH-2.0-OpenSSH_7.9")
XCTAssertEqual(string, "SSH-2.0-OpenSSH_7.9")
default:
XCTFail("Expecting .version")
}
}

func testReadVersionWithExtraLinesOnServer() throws {
var parser = SSHPacketParser(isServer: true, allocator: ByteBufferAllocator())

var part1 = ByteBuffer.of(string: "xx")
parser.append(bytes: &part1)

XCTAssertNil(try parser.nextPacket())

var part2 = ByteBuffer.of(string: "xx\r\nyyyy\r\nSSH-2.0-OpenSSH_7.9\r\n")
parser.append(bytes: &part2)

switch try parser.nextPacket() {
case .version(let string):
XCTAssertEqual(string, "xxxx")
default:
XCTFail("Expecting .version")
}
}

func testReadVersionWithoutCarriageReturn() throws {
var parser = SSHPacketParser(allocator: ByteBufferAllocator())
var parser = SSHPacketParser(isServer: false, allocator: ByteBufferAllocator())

var part1 = ByteBuffer.of(string: "SSH-2.0-")
parser.append(bytes: &part1)
Expand All @@ -95,8 +114,8 @@ final class SSHPacketParserTests: XCTestCase {
}
}

func testReadVersionWithExtraLinesWithoutCarriageReturn() throws {
var parser = SSHPacketParser(allocator: ByteBufferAllocator())
func testReadVersionWithExtraLinesWithoutCarriageReturnOnClient() throws {
var parser = SSHPacketParser(isServer: false, allocator: ByteBufferAllocator())

var part1 = ByteBuffer.of(string: "xxxx\nyyyy\nSSH-2.0-")
parser.append(bytes: &part1)
Expand All @@ -108,14 +127,33 @@ final class SSHPacketParserTests: XCTestCase {

switch try parser.nextPacket() {
case .version(let string):
XCTAssertEqual(string, "xxxx\nyyyy\nSSH-2.0-OpenSSH_7.4")
XCTAssertEqual(string, "SSH-2.0-OpenSSH_7.4")
default:
XCTFail("Expecting .version")
}
}

func testReadVersionWithExtraLinesWithoutCarriageReturnOnServer() throws {
var parser = SSHPacketParser(isServer: true, allocator: ByteBufferAllocator())

var part1 = ByteBuffer.of(string: "xx")
parser.append(bytes: &part1)

XCTAssertNil(try parser.nextPacket())

var part2 = ByteBuffer.of(string: "xx\nyyyy\nSSH-2.0-OpenSSH_7.4\n")
parser.append(bytes: &part2)

switch try parser.nextPacket() {
case .version(let string):
XCTAssertEqual(string, "xxxx")
default:
XCTFail("Expecting .version")
}
}

func testBinaryInParts() throws {
var parser = SSHPacketParser(allocator: ByteBufferAllocator())
var parser = SSHPacketParser(isServer: false, allocator: ByteBufferAllocator())
self.feedVersion(to: &parser)

var part1 = ByteBuffer.of(bytes: [0, 0, 0])
Expand Down Expand Up @@ -148,7 +186,7 @@ final class SSHPacketParserTests: XCTestCase {
}

func testBinaryFull() throws {
var parser = SSHPacketParser(allocator: ByteBufferAllocator())
var parser = SSHPacketParser(isServer: false, allocator: ByteBufferAllocator())
self.feedVersion(to: &parser)

var part1 = ByteBuffer.of(bytes: [0, 0, 0, 28, 10, 5, 0, 0, 0, 12, 115, 115, 104, 45, 117, 115, 101, 114, 97, 117, 116, 104, 42, 111, 216, 12, 226, 248, 144, 175, 157, 207])
Expand All @@ -164,7 +202,7 @@ final class SSHPacketParserTests: XCTestCase {
}

func testBinaryTwoMessages() throws {
var parser = SSHPacketParser(allocator: ByteBufferAllocator())
var parser = SSHPacketParser(isServer: false, allocator: ByteBufferAllocator())
self.feedVersion(to: &parser)

var part = ByteBuffer.of(bytes: [0, 0, 0, 28, 10, 5, 0, 0, 0, 12, 115, 115, 104, 45, 117, 115, 101, 114, 97, 117, 116, 104, 42, 111, 216, 12, 226, 248, 144, 175, 157, 207, 0, 0, 0, 28, 10, 5, 0, 0, 0, 12, 115, 115, 104, 45, 117, 115, 101, 114, 97, 117, 116, 104, 42, 111, 216, 12, 226, 248, 144, 175, 157, 207])
Expand All @@ -187,7 +225,7 @@ final class SSHPacketParserTests: XCTestCase {
}

func testWeReclaimStorage() throws {
var parser = SSHPacketParser(allocator: ByteBufferAllocator())
var parser = SSHPacketParser(isServer: false, allocator: ByteBufferAllocator())
self.feedVersion(to: &parser)
XCTAssertNoThrow(try parser.nextPacket())

Expand All @@ -213,7 +251,7 @@ final class SSHPacketParserTests: XCTestCase {

func testSequencePreservedBetweenPlainAndCypher() throws {
let allocator = ByteBufferAllocator()
var parser = SSHPacketParser(allocator: allocator)
var parser = SSHPacketParser(isServer: false, allocator: allocator)
self.feedVersion(to: &parser)

var part = ByteBuffer(bytes: [0, 0, 0, 12, 10, 21, 41, 114, 125, 250, 3, 79, 3, 217, 166, 136])
Expand Down

0 comments on commit ded5e5c

Please sign in to comment.