Skip to content

Commit

Permalink
Adds support for custom transport protection algorithms
Browse files Browse the repository at this point in the history
Motivation:
In some cases AES GCM might not be supported by a server/client, and
clients might require custom algorithms to be implemented in their
projects.

Modifications:
 - Makes NIOSSHTransportProtection, NIOSSHSessionKeys and ExpectedKeySizes public
 - Makes transport algorithms list configurable
 - Adds sequence of incoming and outgoing messages to parser and
   serializer
 - Fixes connection state machine state update if message can not be read yet
 - Tests
  • Loading branch information
artemredkin committed Feb 7, 2022
1 parent 3da5629 commit 48ae0bb
Show file tree
Hide file tree
Showing 13 changed files with 78 additions and 45 deletions.
2 changes: 1 addition & 1 deletion Sources/NIOSSH/ByteBuffer+SSH.swift
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ extension ByteBuffer {

/// Writes a given number of SSH-acceptable padding bytes to this buffer.
@discardableResult
mutating func writeSSHPaddingBytes(count: Int) -> Int {
public mutating func writeSSHPaddingBytes(count: Int) -> Int {
// Annoyingly, the system random number generator can only give bytes to us 8 bytes at a time.
precondition(count >= 0, "Cannot write negative number of padding bytes: \(count)")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,7 @@ struct SSHConnectionStateMachine {
/// The state of this state machine.
private var state: State

private static let defaultTransportProtectionSchemes: [NIOSSHTransportProtection.Type] = [
AES256GCMOpenSSHTransportProtection.self, AES128GCMOpenSSHTransportProtection.self,
]

init(role: SSHConnectionRole, protectionSchemes: [NIOSSHTransportProtection.Type] = Self.defaultTransportProtectionSchemes) {
init(role: SSHConnectionRole, protectionSchemes: [NIOSSHTransportProtection.Type] = Constants.defaultTransportProtectionSchemes) {
self.state = .idle(IdleState(role: role, protectionSchemes: protectionSchemes))
}

Expand Down
6 changes: 5 additions & 1 deletion Sources/NIOSSH/Constants.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
//
//===----------------------------------------------------------------------===//

enum Constants {
public enum Constants {
static let version = "SSH-2.0-SwiftNIOSSH_1.0"

public static let defaultTransportProtectionSchemes: [NIOSSHTransportProtection.Type] = [
AES256GCMOpenSSHTransportProtection.self, AES128GCMOpenSSHTransportProtection.self,
]
}
28 changes: 17 additions & 11 deletions Sources/NIOSSH/Key Exchange/SSHKeyExchangeResult.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,18 @@ extension KeyExchangeResult: Equatable {}
/// Of these types, the encryption keys and the MAC keys are intended to be secret, and so
/// we store them in the `SymmetricKey` types. The IVs do not need to be secret, and so are
/// stored in regular heap buffers.
struct NIOSSHSessionKeys {
var initialInboundIV: [UInt8]
public struct NIOSSHSessionKeys {
public var initialInboundIV: [UInt8]

var initialOutboundIV: [UInt8]
public var initialOutboundIV: [UInt8]

var inboundEncryptionKey: SymmetricKey
public var inboundEncryptionKey: SymmetricKey

var outboundEncryptionKey: SymmetricKey
public var outboundEncryptionKey: SymmetricKey

var inboundMACKey: SymmetricKey
public var inboundMACKey: SymmetricKey

var outboundMACKey: SymmetricKey
public var outboundMACKey: SymmetricKey
}

extension NIOSSHSessionKeys: Equatable {}
Expand All @@ -68,10 +68,16 @@ extension NIOSSHSessionKeys: Equatable {}
/// hash function invocations. The output of these hash functions is truncated to an appropriate
/// length as needed, which means we need to ensure the code doing the calculation knows how
/// to truncate appropriately.
struct ExpectedKeySizes {
var ivSize: Int
public struct ExpectedKeySizes {
public var ivSize: Int

var encryptionKeySize: Int
public var encryptionKeySize: Int

var macKeySize: Int
public var macKeySize: Int

public init(ivSize: Int, encryptionKeySize: Int, macKeySize: Int) {
self.ivSize = ivSize
self.encryptionKeySize = encryptionKeySize
self.macKeySize = macKeySize
}
}
6 changes: 4 additions & 2 deletions Sources/NIOSSH/NIOSSHHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,10 @@ public final class NIOSSHHandler {

private var pendingGlobalRequestResponses: CircularBuffer<PendingGlobalRequestResponse?>

public init(role: SSHConnectionRole, allocator: ByteBufferAllocator, inboundChildChannelInitializer: ((Channel, SSHChannelType) -> EventLoopFuture<Void>)?) {
self.stateMachine = SSHConnectionStateMachine(role: role)
public init(role: SSHConnectionRole, allocator: ByteBufferAllocator,
inboundChildChannelInitializer: ((Channel, SSHChannelType) -> EventLoopFuture<Void>)?,
protectionSchemes: [NIOSSHTransportProtection.Type] = Constants.defaultTransportProtectionSchemes) {
self.stateMachine = SSHConnectionStateMachine(role: role, protectionSchemes: protectionSchemes)
self.pendingWrite = false
self.outboundFrameBuffer = allocator.buffer(capacity: 1024)
self.pendingChannelInitializations = CircularBuffer(initialCapacity: 4)
Expand Down
16 changes: 15 additions & 1 deletion Sources/NIOSSH/SSHPacketParser.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ struct SSHPacketParser {

private var buffer: ByteBuffer
private var state: State
private var sequence: UInt32

/// Testing only: the number of bytes we can discard from this buffer.
internal var _discardableBytes: Int {
Expand All @@ -34,6 +35,15 @@ struct SSHPacketParser {
init(allocator: ByteBufferAllocator) {
self.buffer = allocator.buffer(capacity: 0)
self.state = .initialized
self.sequence = 0
}

private mutating func increment() {
if self.sequence == UInt32.max {
self.sequence = 0
} else {
self.sequence += 1
}
}

mutating func append(bytes: inout ByteBuffer) {
Expand Down Expand Up @@ -73,6 +83,7 @@ struct SSHPacketParser {
if let length = self.buffer.getInteger(at: self.buffer.readerIndex, as: UInt32.self) {
if let message = try self.parsePlaintext(length: length) {
self.state = .cleartextWaitingForLength
self.increment()
return message
}
self.state = .cleartextWaitingForBytes(length)
Expand All @@ -82,6 +93,7 @@ struct SSHPacketParser {
case .cleartextWaitingForBytes(let length):
if let message = try self.parsePlaintext(length: length) {
self.state = .cleartextWaitingForLength
self.increment()
return message
}
return nil
Expand All @@ -92,13 +104,15 @@ struct SSHPacketParser {

if let message = try self.parseCiphertext(length: length, protection: protection) {
self.state = .encryptedWaitingForLength(protection)
self.increment()
return message
}
self.state = .encryptedWaitingForBytes(length, protection)
return nil
case .encryptedWaitingForBytes(let length, let protection):
if let message = try self.parseCiphertext(length: length, protection: protection) {
self.state = .encryptedWaitingForLength(protection)
self.increment()
return message
}
return nil
Expand Down Expand Up @@ -160,7 +174,7 @@ struct SSHPacketParser {
return nil
}

var content = try protection.decryptAndVerifyRemainingPacket(&buffer)
var content = try protection.decryptAndVerifyRemainingPacket(&buffer, sequence: self.sequence)
guard let message = try content.readSSHMessage(), content.readableBytes == 0, buffer.readableBytes == 0 else {
// Throw this error if the content wasn't exactly the right length for the message.
throw NIOSSHError.invalidPacketFormat
Expand Down
13 changes: 12 additions & 1 deletion Sources/NIOSSH/SSHPacketSerializer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ struct SSHPacketSerializer {
}

private var state: State = .initialized
private var sequence: UInt32 = 0

private mutating func increment() {
if self.sequence == UInt32.max {
self.sequence = 0
} else {
self.sequence += 1
}
}

/// Encryption schemes can be added to a packet serializer whenever encryption is negotiated.
mutating func addEncryption(_ protection: NIOSSHTransportProtection) {
Expand Down Expand Up @@ -75,9 +84,11 @@ struct SSHPacketSerializer {
buffer.setInteger(UInt8(paddingLength), at: index + 4)
/// random padding
buffer.writeSSHPaddingBytes(count: paddingLength)
self.increment()
case .encrypted(let protection):
let payload = NIOSSHEncryptablePayload(message: message)
try protection.encryptPacket(payload, to: &buffer)
try protection.encryptPacket(payload, sequence: sequence, to: &buffer)
self.increment()
}
}
}
4 changes: 2 additions & 2 deletions Sources/NIOSSH/TransportProtection/AESGCM.swift
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ extension AESGCMTransportProtection: NIOSSHTransportProtection {
// unencrypted!
}

func decryptAndVerifyRemainingPacket(_ source: inout ByteBuffer) throws -> ByteBuffer {
func decryptAndVerifyRemainingPacket(_ source: inout ByteBuffer, sequence: UInt32) throws -> ByteBuffer {
var plaintext: Data

// Establish a nested scope here to avoid the byte buffer views causing an accidental CoW.
Expand Down Expand Up @@ -117,7 +117,7 @@ extension AESGCMTransportProtection: NIOSSHTransportProtection {
return source.readSlice(length: plaintext.count)!
}

func encryptPacket(_ packet: NIOSSHEncryptablePayload, to outboundBuffer: inout ByteBuffer) throws {
func encryptPacket(_ packet: NIOSSHEncryptablePayload, sequence: UInt32, to outboundBuffer: inout ByteBuffer) throws {
// Keep track of where the length is going to be written.
let packetLengthIndex = outboundBuffer.writerIndex
let packetLengthLength = MemoryLayout<UInt32>.size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ import NIO
/// Implementers of this protocol **must not** expose unauthenticated plaintext, except for the length field. This
/// is required by the SSH protocol, and swift-nio-ssh does its best to treat the length field as fundamentally
/// untrusted information.
protocol NIOSSHTransportProtection: AnyObject {
public protocol NIOSSHTransportProtection: AnyObject {
/// The name of the cipher portion of this transport protection scheme as negotiated on the wire.
static var cipherName: String { get }

Expand Down Expand Up @@ -87,10 +87,10 @@ protocol NIOSSHTransportProtection: AnyObject {
/// length, the padding, or the MAC), and update source to indicate the consumed bytes.
/// It must also perform any integrity checking that
/// is required and throw if the integrity check fails.
func decryptAndVerifyRemainingPacket(_ source: inout ByteBuffer) throws -> ByteBuffer
func decryptAndVerifyRemainingPacket(_ source: inout ByteBuffer, sequence: UInt32) throws -> ByteBuffer

/// Encrypt an entire outbound packet
func encryptPacket(_ packet: NIOSSHEncryptablePayload, to outboundBuffer: inout ByteBuffer) throws
func encryptPacket(_ packet: NIOSSHEncryptablePayload, sequence: UInt32, to outboundBuffer: inout ByteBuffer) throws
}

extension NIOSSHTransportProtection {
Expand Down
20 changes: 10 additions & 10 deletions Tests/NIOSSHTests/AESGCMTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ final class AESGCMTests: XCTestCase {
let initialKeys = self.generateKeys(keySize: .bits128)

let aes128Encryptor = try assertNoThrowWithValue(AES128GCMOpenSSHTransportProtection(initialKeys: initialKeys))
XCTAssertNoThrow(try aes128Encryptor.encryptPacket(NIOSSHEncryptablePayload(message: .newKeys), to: &self.buffer))
XCTAssertNoThrow(try aes128Encryptor.encryptPacket(NIOSSHEncryptablePayload(message: .newKeys), sequence: 0, to: &self.buffer))

// The newKeys message is very straightforward: a single byte. Because of that, we expect that we will need
// 14 padding bytes: one byte for the padding length, then 14 more to get out to one block size. Thus, the total
Expand All @@ -59,7 +59,7 @@ final class AESGCMTests: XCTestCase {
XCTAssertEqual(bufferCopy, self.buffer)

/// After decryption the plaintext should be a newKeys message.
var plaintext = try assertNoThrowWithValue(aes128Decryptor.decryptAndVerifyRemainingPacket(&bufferCopy))
var plaintext = try assertNoThrowWithValue(aes128Decryptor.decryptAndVerifyRemainingPacket(&bufferCopy, sequence: 0))
XCTAssertEqual(bufferCopy.readableBytes, 0)
XCTAssertNotEqual(plaintext, self.buffer)
XCTAssertEqual(plaintext.readableBytes, 1)
Expand All @@ -77,7 +77,7 @@ final class AESGCMTests: XCTestCase {
let initialKeys = self.generateKeys(keySize: .bits256)

let aes256Encryptor = try assertNoThrowWithValue(AES256GCMOpenSSHTransportProtection(initialKeys: initialKeys))
XCTAssertNoThrow(try aes256Encryptor.encryptPacket(NIOSSHEncryptablePayload(message: .newKeys), to: &self.buffer))
XCTAssertNoThrow(try aes256Encryptor.encryptPacket(NIOSSHEncryptablePayload(message: .newKeys), sequence: 0, to: &self.buffer))

// The newKeys message is very straightforward: a single byte. Because of that, we expect that we will need
// 14 padding bytes: one byte for the padding length, then 14 more to get out to one block size. Thus, the total
Expand All @@ -94,7 +94,7 @@ final class AESGCMTests: XCTestCase {
XCTAssertEqual(bufferCopy, self.buffer)

/// After decryption the plaintext should be a newKeys message.
var plaintext = try assertNoThrowWithValue(aes256Decryptor.decryptAndVerifyRemainingPacket(&bufferCopy))
var plaintext = try assertNoThrowWithValue(aes256Decryptor.decryptAndVerifyRemainingPacket(&bufferCopy, sequence: 0))
XCTAssertEqual(bufferCopy.readableBytes, 0)
XCTAssertNotEqual(plaintext, self.buffer)
XCTAssertEqual(plaintext.readableBytes, 1)
Expand Down Expand Up @@ -300,7 +300,7 @@ final class AESGCMTests: XCTestCase {
buffer.clear()
buffer.writeRepeatingByte(42, count: ciphertextSize)

XCTAssertThrowsError(try aes128.decryptAndVerifyRemainingPacket(&buffer)) { error in
XCTAssertThrowsError(try aes128.decryptAndVerifyRemainingPacket(&buffer, sequence: 0)) { error in
XCTAssertEqual((error as? NIOSSHError)?.type, .invalidEncryptedPacketLength)
}
}
Expand All @@ -320,7 +320,7 @@ final class AESGCMTests: XCTestCase {
buffer.clear()
buffer.writeRepeatingByte(42, count: ciphertextSize)

XCTAssertThrowsError(try aes256.decryptAndVerifyRemainingPacket(&buffer)) { error in
XCTAssertThrowsError(try aes256.decryptAndVerifyRemainingPacket(&buffer, sequence: 0)) { error in
XCTAssertEqual((error as? NIOSSHError)?.type, .invalidEncryptedPacketLength)
}
}
Expand Down Expand Up @@ -350,7 +350,7 @@ final class AESGCMTests: XCTestCase {

// We can now attempt to decrypt this packet.
let aes128 = try assertNoThrowWithValue(AES128GCMOpenSSHTransportProtection(initialKeys: keys))
XCTAssertThrowsError(try aes128.decryptAndVerifyRemainingPacket(&buffer)) { error in
XCTAssertThrowsError(try aes128.decryptAndVerifyRemainingPacket(&buffer, sequence: 0)) { error in
XCTAssertEqual((error as? NIOSSHError)?.type, .excessPadding)
}
}
Expand Down Expand Up @@ -379,7 +379,7 @@ final class AESGCMTests: XCTestCase {

// We can now attempt to decrypt this packet.
let aes256 = try assertNoThrowWithValue(AES256GCMOpenSSHTransportProtection(initialKeys: keys))
XCTAssertThrowsError(try aes256.decryptAndVerifyRemainingPacket(&buffer)) { error in
XCTAssertThrowsError(try aes256.decryptAndVerifyRemainingPacket(&buffer, sequence: 0)) { error in
XCTAssertEqual((error as? NIOSSHError)?.type, .excessPadding)
}
}
Expand Down Expand Up @@ -408,7 +408,7 @@ final class AESGCMTests: XCTestCase {

// We can now attempt to decrypt this packet.
let aes128 = try assertNoThrowWithValue(AES128GCMOpenSSHTransportProtection(initialKeys: keys))
XCTAssertThrowsError(try aes128.decryptAndVerifyRemainingPacket(&buffer)) { error in
XCTAssertThrowsError(try aes128.decryptAndVerifyRemainingPacket(&buffer, sequence: 0)) { error in
XCTAssertEqual((error as? NIOSSHError)?.type, .insufficientPadding)
}
}
Expand Down Expand Up @@ -437,7 +437,7 @@ final class AESGCMTests: XCTestCase {

// We can now attempt to decrypt this packet.
let aes256 = try assertNoThrowWithValue(AES256GCMOpenSSHTransportProtection(initialKeys: keys))
XCTAssertThrowsError(try aes256.decryptAndVerifyRemainingPacket(&buffer)) { error in
XCTAssertThrowsError(try aes256.decryptAndVerifyRemainingPacket(&buffer, sequence: 0)) { error in
XCTAssertEqual((error as? NIOSSHError)?.type, .insufficientPadding)
}
}
Expand Down
8 changes: 4 additions & 4 deletions Tests/NIOSSHTests/SSHKeyExchangeStateMachineTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,9 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase {
var buffer = ByteBufferAllocator().buffer(capacity: 1024)

do {
try client.encryptPacket(.init(message: message), to: &buffer)
try client.encryptPacket(.init(message: message), sequence: 0, to: &buffer)
try server.decryptFirstBlock(&buffer)
var messageBuffer = try server.decryptAndVerifyRemainingPacket(&buffer)
var messageBuffer = try server.decryptAndVerifyRemainingPacket(&buffer, sequence: 0)
let decrypted = try messageBuffer.readSSHMessage()
XCTAssertEqual(message, decrypted)
XCTAssertEqual(0, buffer.readableBytes)
Expand All @@ -157,9 +157,9 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase {
buffer.clear()

do {
try server.encryptPacket(.init(message: message), to: &buffer)
try server.encryptPacket(.init(message: message), sequence: 0, to: &buffer)
try client.decryptFirstBlock(&buffer)
var messageBuffer = try client.decryptAndVerifyRemainingPacket(&buffer)
var messageBuffer = try client.decryptAndVerifyRemainingPacket(&buffer, sequence: 0)
let decrypted = try messageBuffer.readSSHMessage()
XCTAssertEqual(message, decrypted)
XCTAssertEqual(0, buffer.readableBytes)
Expand Down
4 changes: 2 additions & 2 deletions Tests/NIOSSHTests/Utilities.swift
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ class TestTransportProtection: NIOSSHTransportProtection {
source.setBytes(plaintext.readableBytesView, at: index)
}

func decryptAndVerifyRemainingPacket(_ source: inout ByteBuffer) throws -> ByteBuffer {
func decryptAndVerifyRemainingPacket(_ source: inout ByteBuffer, sequence: UInt32) throws -> ByteBuffer {
defer {
self.lastFirstBlock = nil
}
Expand Down Expand Up @@ -211,7 +211,7 @@ class TestTransportProtection: NIOSSHTransportProtection {
return plaintext.readSlice(length: plaintext.readableBytes - Int(paddingLength))!
}

func encryptPacket(_ packet: NIOSSHEncryptablePayload, to outboundBuffer: inout ByteBuffer) throws {
func encryptPacket(_ packet: NIOSSHEncryptablePayload, sequence: UInt32, to outboundBuffer: inout ByteBuffer) throws {
let packetLengthIndex = outboundBuffer.writerIndex
let packetLengthLength = MemoryLayout<UInt32>.size
let packetPaddingIndex = outboundBuffer.writerIndex + packetLengthLength
Expand Down
4 changes: 2 additions & 2 deletions Tests/NIOSSHTests/UtilitiesTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ final class UtilitiesTests: XCTestCase {
let message = SSHMessage.channelRequest(.init(recipientChannel: 1, type: .exec("uname"), wantReply: false))
let allocator = ByteBufferAllocator()
var buffer = allocator.buffer(capacity: 1024)
XCTAssertNoThrow(try client.encryptPacket(.init(message: message), to: &buffer))
XCTAssertNoThrow(try client.encryptPacket(.init(message: message), sequence: 0, to: &buffer))
XCTAssertNoThrow(try server.decryptFirstBlock(&buffer))
var decoded = try server.decryptAndVerifyRemainingPacket(&buffer)
var decoded = try server.decryptAndVerifyRemainingPacket(&buffer, sequence: 0)
XCTAssertEqual(message, try decoded.readSSHMessage())
}
}

0 comments on commit 48ae0bb

Please sign in to comment.