From 6bfb94502429871345b4550fdb2446d8f2d350fa Mon Sep 17 00:00:00 2001 From: Artem Redkin Date: Wed, 13 Oct 2021 11:52:06 +0100 Subject: [PATCH] Adds support for custom transport protection algorithms Motivation: In some cases AES GCM might not be supported by a server/client, and clients might require custom algorithms to be implemented in their projects. Modifications: - Makes NIOSSHTransportProtection, NIOSSHSessionKeys and ExpectedKeySizes public - Makes transport algorithms list configurable - Adds sequence number of incoming and outgoing messages to parser and serializer - Tests --- Sources/NIOSSH/ByteBuffer+SSH.swift | 2 +- .../SSHConnectionStateMachine.swift | 6 +- .../Key Exchange/SSHKeyExchangeResult.swift | 39 ++++++--- Sources/NIOSSH/NIOSSHHandler.swift | 5 +- Sources/NIOSSH/Role.swift | 14 ++++ Sources/NIOSSH/SSHClientConfiguration.swift | 11 +++ Sources/NIOSSH/SSHPacketParser.swift | 8 +- Sources/NIOSSH/SSHPacketSerializer.swift | 5 +- Sources/NIOSSH/SSHServerConfiguration.swift | 20 ++++- .../NIOSSH/TransportProtection/AESGCM.swift | 4 +- .../SSHTransportProtection.swift | 6 +- Tests/NIOSSHTests/AESGCMTests.swift | 20 ++--- .../SSHKeyExchangeStateMachineTests.swift | 8 +- .../SSHPackerSerializerTests.swift | 47 +++++++++++ Tests/NIOSSHTests/SSHPacketParserTests.swift | 79 +++++++++++++++++++ Tests/NIOSSHTests/Utilities.swift | 4 +- Tests/NIOSSHTests/UtilitiesTests.swift | 4 +- 17 files changed, 234 insertions(+), 48 deletions(-) diff --git a/Sources/NIOSSH/ByteBuffer+SSH.swift b/Sources/NIOSSH/ByteBuffer+SSH.swift index def36c5..0349803 100644 --- a/Sources/NIOSSH/ByteBuffer+SSH.swift +++ b/Sources/NIOSSH/ByteBuffer+SSH.swift @@ -169,7 +169,7 @@ extension ByteBuffer { /// Writes a given number of SSH-acceptable padding bytes to this buffer. @discardableResult - mutating func writeSSHPaddingBytes(count: Int) -> Int { + public mutating func writeSSHPaddingBytes(count: Int) -> Int { // Annoyingly, the system random number generator can only give bytes to us 8 bytes at a time. precondition(count >= 0, "Cannot write negative number of padding bytes: \(count)") diff --git a/Sources/NIOSSH/Connection State Machine/SSHConnectionStateMachine.swift b/Sources/NIOSSH/Connection State Machine/SSHConnectionStateMachine.swift index 56d473b..327cb27 100644 --- a/Sources/NIOSSH/Connection State Machine/SSHConnectionStateMachine.swift +++ b/Sources/NIOSSH/Connection State Machine/SSHConnectionStateMachine.swift @@ -60,11 +60,7 @@ struct SSHConnectionStateMachine { /// The state of this state machine. private var state: State - private static let defaultTransportProtectionSchemes: [NIOSSHTransportProtection.Type] = [ - AES256GCMOpenSSHTransportProtection.self, AES128GCMOpenSSHTransportProtection.self, - ] - - init(role: SSHConnectionRole, protectionSchemes: [NIOSSHTransportProtection.Type] = Self.defaultTransportProtectionSchemes) { + init(role: SSHConnectionRole, protectionSchemes: [NIOSSHTransportProtection.Type] = SSHConnectionRole.bundledTransportProtectionSchemes) { self.state = .idle(IdleState(role: role, protectionSchemes: protectionSchemes)) } diff --git a/Sources/NIOSSH/Key Exchange/SSHKeyExchangeResult.swift b/Sources/NIOSSH/Key Exchange/SSHKeyExchangeResult.swift index 8c9e394..b12f14c 100644 --- a/Sources/NIOSSH/Key Exchange/SSHKeyExchangeResult.swift +++ b/Sources/NIOSSH/Key Exchange/SSHKeyExchangeResult.swift @@ -46,18 +46,27 @@ extension KeyExchangeResult: Equatable {} /// Of these types, the encryption keys and the MAC keys are intended to be secret, and so /// we store them in the `SymmetricKey` types. The IVs do not need to be secret, and so are /// stored in regular heap buffers. -struct NIOSSHSessionKeys { - var initialInboundIV: [UInt8] +public struct NIOSSHSessionKeys { + public var initialInboundIV: [UInt8] - var initialOutboundIV: [UInt8] + public var initialOutboundIV: [UInt8] - var inboundEncryptionKey: SymmetricKey + public var inboundEncryptionKey: SymmetricKey - var outboundEncryptionKey: SymmetricKey + public var outboundEncryptionKey: SymmetricKey - var inboundMACKey: SymmetricKey + public var inboundMACKey: SymmetricKey - var outboundMACKey: SymmetricKey + public var outboundMACKey: SymmetricKey + + public init(initialInboundIV: [UInt8], initialOutboundIV: [UInt8], inboundEncryptionKey: SymmetricKey, outboundEncryptionKey: SymmetricKey, inboundMACKey: SymmetricKey, outboundMACKey: SymmetricKey) { + self.initialInboundIV = initialInboundIV + self.initialOutboundIV = initialOutboundIV + self.inboundEncryptionKey = inboundEncryptionKey + self.outboundEncryptionKey = outboundEncryptionKey + self.inboundMACKey = inboundMACKey + self.outboundMACKey = outboundMACKey + } } extension NIOSSHSessionKeys: Equatable {} @@ -68,10 +77,18 @@ extension NIOSSHSessionKeys: Equatable {} /// hash function invocations. The output of these hash functions is truncated to an appropriate /// length as needed, which means we need to ensure the code doing the calculation knows how /// to truncate appropriately. -struct ExpectedKeySizes { - var ivSize: Int +public struct ExpectedKeySizes { + public var ivSize: Int + + public var encryptionKeySize: Int - var encryptionKeySize: Int + public var macKeySize: Int - var macKeySize: Int + public init(ivSize: Int, encryptionKeySize: Int, macKeySize: Int) { + self.ivSize = ivSize + self.encryptionKeySize = encryptionKeySize + self.macKeySize = macKeySize + } } + +extension ExpectedKeySizes: Hashable {} diff --git a/Sources/NIOSSH/NIOSSHHandler.swift b/Sources/NIOSSH/NIOSSHHandler.swift index 0428851..eece566 100644 --- a/Sources/NIOSSH/NIOSSHHandler.swift +++ b/Sources/NIOSSH/NIOSSHHandler.swift @@ -57,8 +57,9 @@ public final class NIOSSHHandler { private var pendingGlobalRequestResponses: CircularBuffer - public init(role: SSHConnectionRole, allocator: ByteBufferAllocator, inboundChildChannelInitializer: ((Channel, SSHChannelType) -> EventLoopFuture)?) { - self.stateMachine = SSHConnectionStateMachine(role: role) + public init(role: SSHConnectionRole, allocator: ByteBufferAllocator, + inboundChildChannelInitializer: ((Channel, SSHChannelType) -> EventLoopFuture)?) { + self.stateMachine = SSHConnectionStateMachine(role: role, protectionSchemes: role.transportProtectionSchemes) self.pendingWrite = false self.outboundFrameBuffer = allocator.buffer(capacity: 1024) self.pendingChannelInitializations = CircularBuffer(initialCapacity: 4) diff --git a/Sources/NIOSSH/Role.swift b/Sources/NIOSSH/Role.swift index 2af2818..1bfdd54 100644 --- a/Sources/NIOSSH/Role.swift +++ b/Sources/NIOSSH/Role.swift @@ -14,6 +14,10 @@ /// The role of a given party in an SSH connection. public enum SSHConnectionRole { + public static let bundledTransportProtectionSchemes: [NIOSSHTransportProtection.Type] = [ + AES256GCMOpenSSHTransportProtection.self, AES128GCMOpenSSHTransportProtection.self, + ] + case client(SSHClientConfiguration) case server(SSHServerConfiguration) @@ -34,4 +38,14 @@ public enum SSHConnectionRole { return true } } + + internal var transportProtectionSchemes: [NIOSSHTransportProtection.Type] { + switch self { + case .client(let configuration): + return configuration.transportProtectionSchemes + case .server(let configuration): + return configuration.transportProtectionSchemes + } + } } + diff --git a/Sources/NIOSSH/SSHClientConfiguration.swift b/Sources/NIOSSH/SSHClientConfiguration.swift index 7ecbbc4..a68d580 100644 --- a/Sources/NIOSSH/SSHClientConfiguration.swift +++ b/Sources/NIOSSH/SSHClientConfiguration.swift @@ -23,11 +23,22 @@ public struct SSHClientConfiguration { /// The global request delegate to be used with this client. public var globalRequestDelegate: GlobalRequestDelegate + /// Supported data encryption algorithms + public var transportProtectionSchemes: [NIOSSHTransportProtection.Type] + public init(userAuthDelegate: NIOSSHClientUserAuthenticationDelegate, serverAuthDelegate: NIOSSHClientServerAuthenticationDelegate, globalRequestDelegate: GlobalRequestDelegate? = nil) { + self.init(userAuthDelegate: userAuthDelegate, serverAuthDelegate: serverAuthDelegate, globalRequestDelegate: globalRequestDelegate, transportProtectionSchemes: SSHConnectionRole.bundledTransportProtectionSchemes) + } + + public init(userAuthDelegate: NIOSSHClientUserAuthenticationDelegate, + serverAuthDelegate: NIOSSHClientServerAuthenticationDelegate, + globalRequestDelegate: GlobalRequestDelegate? = nil, + transportProtectionSchemes: [NIOSSHTransportProtection.Type]) { self.userAuthDelegate = userAuthDelegate self.serverAuthDelegate = serverAuthDelegate self.globalRequestDelegate = globalRequestDelegate ?? DefaultGlobalRequestDelegate() + self.transportProtectionSchemes = transportProtectionSchemes } } diff --git a/Sources/NIOSSH/SSHPacketParser.swift b/Sources/NIOSSH/SSHPacketParser.swift index 1d70c94..7e739ee 100644 --- a/Sources/NIOSSH/SSHPacketParser.swift +++ b/Sources/NIOSSH/SSHPacketParser.swift @@ -25,6 +25,7 @@ struct SSHPacketParser { private var buffer: ByteBuffer private var state: State + private(set) var sequenceNumber: UInt32 /// Testing only: the number of bytes we can discard from this buffer. internal var _discardableBytes: Int { @@ -34,6 +35,7 @@ struct SSHPacketParser { init(allocator: ByteBufferAllocator) { self.buffer = allocator.buffer(capacity: 0) self.state = .initialized + self.sequenceNumber = 0 } mutating func append(bytes: inout ByteBuffer) { @@ -73,6 +75,7 @@ struct SSHPacketParser { if let length = self.buffer.getInteger(at: self.buffer.readerIndex, as: UInt32.self) { if let message = try self.parsePlaintext(length: length) { self.state = .cleartextWaitingForLength + self.sequenceNumber &+= 1 return message } self.state = .cleartextWaitingForBytes(length) @@ -82,6 +85,7 @@ struct SSHPacketParser { case .cleartextWaitingForBytes(let length): if let message = try self.parsePlaintext(length: length) { self.state = .cleartextWaitingForLength + self.sequenceNumber &+= 1 return message } return nil @@ -92,6 +96,7 @@ struct SSHPacketParser { if let message = try self.parseCiphertext(length: length, protection: protection) { self.state = .encryptedWaitingForLength(protection) + self.sequenceNumber &+= 1 return message } self.state = .encryptedWaitingForBytes(length, protection) @@ -99,6 +104,7 @@ struct SSHPacketParser { case .encryptedWaitingForBytes(let length, let protection): if let message = try self.parseCiphertext(length: length, protection: protection) { self.state = .encryptedWaitingForLength(protection) + self.sequenceNumber &+= 1 return message } return nil @@ -169,7 +175,7 @@ struct SSHPacketParser { return nil } - var content = try protection.decryptAndVerifyRemainingPacket(&buffer) + var content = try protection.decryptAndVerifyRemainingPacket(&buffer, sequenceNumber: self.sequenceNumber) guard let message = try content.readSSHMessage(), content.readableBytes == 0, buffer.readableBytes == 0 else { // Throw this error if the content wasn't exactly the right length for the message. throw NIOSSHError.invalidPacketFormat diff --git a/Sources/NIOSSH/SSHPacketSerializer.swift b/Sources/NIOSSH/SSHPacketSerializer.swift index 1ab96b7..9cf7351 100644 --- a/Sources/NIOSSH/SSHPacketSerializer.swift +++ b/Sources/NIOSSH/SSHPacketSerializer.swift @@ -22,6 +22,7 @@ struct SSHPacketSerializer { } private var state: State = .initialized + private(set) var sequenceNumber: UInt32 = 0 /// Encryption schemes can be added to a packet serializer whenever encryption is negotiated. mutating func addEncryption(_ protection: NIOSSHTransportProtection) { @@ -75,9 +76,11 @@ struct SSHPacketSerializer { buffer.setInteger(UInt8(paddingLength), at: index + 4) /// random padding buffer.writeSSHPaddingBytes(count: paddingLength) + self.sequenceNumber &+= 1 case .encrypted(let protection): let payload = NIOSSHEncryptablePayload(message: message) - try protection.encryptPacket(payload, to: &buffer) + try protection.encryptPacket(payload, sequenceNumber: self.sequenceNumber, to: &buffer) + self.sequenceNumber &+= 1 } } } diff --git a/Sources/NIOSSH/SSHServerConfiguration.swift b/Sources/NIOSSH/SSHServerConfiguration.swift index fcd8ed0..20a6dbb 100644 --- a/Sources/NIOSSH/SSHServerConfiguration.swift +++ b/Sources/NIOSSH/SSHServerConfiguration.swift @@ -26,16 +26,28 @@ public struct SSHServerConfiguration { /// The ssh banner to display to clients upon authentication public var banner: UserAuthBanner? + /// Supported data encryption algorithms + public var transportProtectionSchemes: [NIOSSHTransportProtection.Type] + public init(hostKeys: [NIOSSHPrivateKey], userAuthDelegate: NIOSSHServerUserAuthenticationDelegate, globalRequestDelegate: GlobalRequestDelegate? = nil, banner: UserAuthBanner? = nil) { - self.hostKeys = hostKeys - self.userAuthDelegate = userAuthDelegate - self.globalRequestDelegate = globalRequestDelegate ?? DefaultGlobalRequestDelegate() - self.banner = banner + self.init(hostKeys: hostKeys, userAuthDelegate: userAuthDelegate, globalRequestDelegate: globalRequestDelegate, banner: banner, transportProtectionSchemes: SSHConnectionRole.bundledTransportProtectionSchemes) } public init(hostKeys: [NIOSSHPrivateKey], userAuthDelegate: NIOSSHServerUserAuthenticationDelegate, globalRequestDelegate: GlobalRequestDelegate? = nil) { self.init(hostKeys: hostKeys, userAuthDelegate: userAuthDelegate, globalRequestDelegate: globalRequestDelegate, banner: nil) } + + public init(hostKeys: [NIOSSHPrivateKey], + userAuthDelegate: NIOSSHServerUserAuthenticationDelegate, + globalRequestDelegate: GlobalRequestDelegate? = nil, + banner: UserAuthBanner? = nil, + transportProtectionSchemes: [NIOSSHTransportProtection.Type]) { + self.hostKeys = hostKeys + self.userAuthDelegate = userAuthDelegate + self.globalRequestDelegate = globalRequestDelegate ?? DefaultGlobalRequestDelegate() + self.banner = banner + self.transportProtectionSchemes = transportProtectionSchemes + } } // MARK: - UserAuthBanner diff --git a/Sources/NIOSSH/TransportProtection/AESGCM.swift b/Sources/NIOSSH/TransportProtection/AESGCM.swift index 3cc1503..b100f9b 100644 --- a/Sources/NIOSSH/TransportProtection/AESGCM.swift +++ b/Sources/NIOSSH/TransportProtection/AESGCM.swift @@ -79,7 +79,7 @@ extension AESGCMTransportProtection: NIOSSHTransportProtection { // unencrypted! } - func decryptAndVerifyRemainingPacket(_ source: inout ByteBuffer) throws -> ByteBuffer { + func decryptAndVerifyRemainingPacket(_ source: inout ByteBuffer, sequenceNumber _: UInt32) throws -> ByteBuffer { var plaintext: Data // Establish a nested scope here to avoid the byte buffer views causing an accidental CoW. @@ -117,7 +117,7 @@ extension AESGCMTransportProtection: NIOSSHTransportProtection { return source.readSlice(length: plaintext.count)! } - func encryptPacket(_ packet: NIOSSHEncryptablePayload, to outboundBuffer: inout ByteBuffer) throws { + func encryptPacket(_ packet: NIOSSHEncryptablePayload, sequenceNumber _: UInt32, to outboundBuffer: inout ByteBuffer) throws { // Keep track of where the length is going to be written. let packetLengthIndex = outboundBuffer.writerIndex let packetLengthLength = MemoryLayout.size diff --git a/Sources/NIOSSH/TransportProtection/SSHTransportProtection.swift b/Sources/NIOSSH/TransportProtection/SSHTransportProtection.swift index 5b7ba97..9ed6d20 100644 --- a/Sources/NIOSSH/TransportProtection/SSHTransportProtection.swift +++ b/Sources/NIOSSH/TransportProtection/SSHTransportProtection.swift @@ -44,7 +44,7 @@ import NIOCore /// Implementers of this protocol **must not** expose unauthenticated plaintext, except for the length field. This /// is required by the SSH protocol, and swift-nio-ssh does its best to treat the length field as fundamentally /// untrusted information. -protocol NIOSSHTransportProtection: AnyObject { +public protocol NIOSSHTransportProtection: AnyObject { /// The name of the cipher portion of this transport protection scheme as negotiated on the wire. static var cipherName: String { get } @@ -87,10 +87,10 @@ protocol NIOSSHTransportProtection: AnyObject { /// length, the padding, or the MAC), and update source to indicate the consumed bytes. /// It must also perform any integrity checking that /// is required and throw if the integrity check fails. - func decryptAndVerifyRemainingPacket(_ source: inout ByteBuffer) throws -> ByteBuffer + func decryptAndVerifyRemainingPacket(_ source: inout ByteBuffer, sequenceNumber: UInt32) throws -> ByteBuffer /// Encrypt an entire outbound packet - func encryptPacket(_ packet: NIOSSHEncryptablePayload, to outboundBuffer: inout ByteBuffer) throws + func encryptPacket(_ packet: NIOSSHEncryptablePayload, sequenceNumber: UInt32, to outboundBuffer: inout ByteBuffer) throws } extension NIOSSHTransportProtection { diff --git a/Tests/NIOSSHTests/AESGCMTests.swift b/Tests/NIOSSHTests/AESGCMTests.swift index 2b46fbc..7d0178a 100644 --- a/Tests/NIOSSHTests/AESGCMTests.swift +++ b/Tests/NIOSSHTests/AESGCMTests.swift @@ -42,7 +42,7 @@ final class AESGCMTests: XCTestCase { let initialKeys = self.generateKeys(keySize: .bits128) let aes128Encryptor = try assertNoThrowWithValue(AES128GCMOpenSSHTransportProtection(initialKeys: initialKeys)) - XCTAssertNoThrow(try aes128Encryptor.encryptPacket(NIOSSHEncryptablePayload(message: .newKeys), to: &self.buffer)) + XCTAssertNoThrow(try aes128Encryptor.encryptPacket(NIOSSHEncryptablePayload(message: .newKeys), sequenceNumber: 0, to: &self.buffer)) // The newKeys message is very straightforward: a single byte. Because of that, we expect that we will need // 14 padding bytes: one byte for the padding length, then 14 more to get out to one block size. Thus, the total @@ -59,7 +59,7 @@ final class AESGCMTests: XCTestCase { XCTAssertEqual(bufferCopy, self.buffer) /// After decryption the plaintext should be a newKeys message. - var plaintext = try assertNoThrowWithValue(aes128Decryptor.decryptAndVerifyRemainingPacket(&bufferCopy)) + var plaintext = try assertNoThrowWithValue(aes128Decryptor.decryptAndVerifyRemainingPacket(&bufferCopy, sequenceNumber: 0)) XCTAssertEqual(bufferCopy.readableBytes, 0) XCTAssertNotEqual(plaintext, self.buffer) XCTAssertEqual(plaintext.readableBytes, 1) @@ -77,7 +77,7 @@ final class AESGCMTests: XCTestCase { let initialKeys = self.generateKeys(keySize: .bits256) let aes256Encryptor = try assertNoThrowWithValue(AES256GCMOpenSSHTransportProtection(initialKeys: initialKeys)) - XCTAssertNoThrow(try aes256Encryptor.encryptPacket(NIOSSHEncryptablePayload(message: .newKeys), to: &self.buffer)) + XCTAssertNoThrow(try aes256Encryptor.encryptPacket(NIOSSHEncryptablePayload(message: .newKeys), sequenceNumber: 0, to: &self.buffer)) // The newKeys message is very straightforward: a single byte. Because of that, we expect that we will need // 14 padding bytes: one byte for the padding length, then 14 more to get out to one block size. Thus, the total @@ -94,7 +94,7 @@ final class AESGCMTests: XCTestCase { XCTAssertEqual(bufferCopy, self.buffer) /// After decryption the plaintext should be a newKeys message. - var plaintext = try assertNoThrowWithValue(aes256Decryptor.decryptAndVerifyRemainingPacket(&bufferCopy)) + var plaintext = try assertNoThrowWithValue(aes256Decryptor.decryptAndVerifyRemainingPacket(&bufferCopy, sequenceNumber: 0)) XCTAssertEqual(bufferCopy.readableBytes, 0) XCTAssertNotEqual(plaintext, self.buffer) XCTAssertEqual(plaintext.readableBytes, 1) @@ -300,7 +300,7 @@ final class AESGCMTests: XCTestCase { buffer.clear() buffer.writeRepeatingByte(42, count: ciphertextSize) - XCTAssertThrowsError(try aes128.decryptAndVerifyRemainingPacket(&buffer)) { error in + XCTAssertThrowsError(try aes128.decryptAndVerifyRemainingPacket(&buffer, sequenceNumber: 0)) { error in XCTAssertEqual((error as? NIOSSHError)?.type, .invalidEncryptedPacketLength) } } @@ -320,7 +320,7 @@ final class AESGCMTests: XCTestCase { buffer.clear() buffer.writeRepeatingByte(42, count: ciphertextSize) - XCTAssertThrowsError(try aes256.decryptAndVerifyRemainingPacket(&buffer)) { error in + XCTAssertThrowsError(try aes256.decryptAndVerifyRemainingPacket(&buffer, sequenceNumber: 0)) { error in XCTAssertEqual((error as? NIOSSHError)?.type, .invalidEncryptedPacketLength) } } @@ -350,7 +350,7 @@ final class AESGCMTests: XCTestCase { // We can now attempt to decrypt this packet. let aes128 = try assertNoThrowWithValue(AES128GCMOpenSSHTransportProtection(initialKeys: keys)) - XCTAssertThrowsError(try aes128.decryptAndVerifyRemainingPacket(&buffer)) { error in + XCTAssertThrowsError(try aes128.decryptAndVerifyRemainingPacket(&buffer, sequenceNumber: 0)) { error in XCTAssertEqual((error as? NIOSSHError)?.type, .excessPadding) } } @@ -379,7 +379,7 @@ final class AESGCMTests: XCTestCase { // We can now attempt to decrypt this packet. let aes256 = try assertNoThrowWithValue(AES256GCMOpenSSHTransportProtection(initialKeys: keys)) - XCTAssertThrowsError(try aes256.decryptAndVerifyRemainingPacket(&buffer)) { error in + XCTAssertThrowsError(try aes256.decryptAndVerifyRemainingPacket(&buffer, sequenceNumber: 0)) { error in XCTAssertEqual((error as? NIOSSHError)?.type, .excessPadding) } } @@ -408,7 +408,7 @@ final class AESGCMTests: XCTestCase { // We can now attempt to decrypt this packet. let aes128 = try assertNoThrowWithValue(AES128GCMOpenSSHTransportProtection(initialKeys: keys)) - XCTAssertThrowsError(try aes128.decryptAndVerifyRemainingPacket(&buffer)) { error in + XCTAssertThrowsError(try aes128.decryptAndVerifyRemainingPacket(&buffer, sequenceNumber: 0)) { error in XCTAssertEqual((error as? NIOSSHError)?.type, .insufficientPadding) } } @@ -437,7 +437,7 @@ final class AESGCMTests: XCTestCase { // We can now attempt to decrypt this packet. let aes256 = try assertNoThrowWithValue(AES256GCMOpenSSHTransportProtection(initialKeys: keys)) - XCTAssertThrowsError(try aes256.decryptAndVerifyRemainingPacket(&buffer)) { error in + XCTAssertThrowsError(try aes256.decryptAndVerifyRemainingPacket(&buffer, sequenceNumber: 0)) { error in XCTAssertEqual((error as? NIOSSHError)?.type, .insufficientPadding) } } diff --git a/Tests/NIOSSHTests/SSHKeyExchangeStateMachineTests.swift b/Tests/NIOSSHTests/SSHKeyExchangeStateMachineTests.swift index dbc4ab7..2a1a595 100644 --- a/Tests/NIOSSHTests/SSHKeyExchangeStateMachineTests.swift +++ b/Tests/NIOSSHTests/SSHKeyExchangeStateMachineTests.swift @@ -144,9 +144,9 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { var buffer = ByteBufferAllocator().buffer(capacity: 1024) do { - try client.encryptPacket(.init(message: message), to: &buffer) + try client.encryptPacket(.init(message: message), sequenceNumber: 0, to: &buffer) try server.decryptFirstBlock(&buffer) - var messageBuffer = try server.decryptAndVerifyRemainingPacket(&buffer) + var messageBuffer = try server.decryptAndVerifyRemainingPacket(&buffer, sequenceNumber: 0) let decrypted = try messageBuffer.readSSHMessage() XCTAssertEqual(message, decrypted) XCTAssertEqual(0, buffer.readableBytes) @@ -158,9 +158,9 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { buffer.clear() do { - try server.encryptPacket(.init(message: message), to: &buffer) + try server.encryptPacket(.init(message: message), sequenceNumber: 0, to: &buffer) try client.decryptFirstBlock(&buffer) - var messageBuffer = try client.decryptAndVerifyRemainingPacket(&buffer) + var messageBuffer = try client.decryptAndVerifyRemainingPacket(&buffer, sequenceNumber: 0) let decrypted = try messageBuffer.readSSHMessage() XCTAssertEqual(message, decrypted) XCTAssertEqual(0, buffer.readableBytes) diff --git a/Tests/NIOSSHTests/SSHPackerSerializerTests.swift b/Tests/NIOSSHTests/SSHPackerSerializerTests.swift index 312f014..6332973 100644 --- a/Tests/NIOSSHTests/SSHPackerSerializerTests.swift +++ b/Tests/NIOSSHTests/SSHPackerSerializerTests.swift @@ -44,6 +44,7 @@ final class SSHPacketSerializerTests: XCTestCase { XCTAssertNoThrow(try serializer.serialize(message: message, to: &buffer)) XCTAssertEqual("SSH-2.0-SwiftSSH_1.0\r\n", buffer.readString(length: buffer.readableBytes)) + XCTAssertEqual(0, serializer.sequenceNumber) } func testDisconnectMessage() throws { @@ -56,6 +57,7 @@ final class SSHPacketSerializerTests: XCTestCase { var buffer = allocator.buffer(capacity: 20) XCTAssertNoThrow(try serializer.serialize(message: message, to: &buffer)) + XCTAssertEqual(1, serializer.sequenceNumber) parser.append(bytes: &buffer) switch try parser.nextPacket() { @@ -80,6 +82,7 @@ final class SSHPacketSerializerTests: XCTestCase { XCTAssertNoThrow(try serializer.serialize(message: message, to: &buffer)) XCTAssertEqual([0, 0, 0, 28, 10, 5, 0, 0, 0, 12, 115, 115, 104, 45, 117, 115, 101, 114, 97, 117, 116, 104], buffer.getBytes(at: 0, length: 22)) + XCTAssertEqual(1, serializer.sequenceNumber) parser.append(bytes: &buffer) switch try parser.nextPacket() { @@ -102,6 +105,7 @@ final class SSHPacketSerializerTests: XCTestCase { XCTAssertNoThrow(try serializer.serialize(message: message, to: &buffer)) XCTAssertEqual([0, 0, 0, 28, 10, 6, 0, 0, 0, 12, 115, 115, 104, 45, 117, 115, 101, 114, 97, 117, 116, 104], buffer.getBytes(at: 0, length: 22)) + XCTAssertEqual(1, serializer.sequenceNumber) parser.append(bytes: &buffer) switch try parser.nextPacket() { @@ -135,6 +139,7 @@ final class SSHPacketSerializerTests: XCTestCase { var buffer = allocator.buffer(capacity: 20) XCTAssertNoThrow(try serializer.serialize(message: message, to: &buffer)) + XCTAssertEqual(1, serializer.sequenceNumber) parser.append(bytes: &buffer) switch try parser.nextPacket() { @@ -166,6 +171,7 @@ final class SSHPacketSerializerTests: XCTestCase { var buffer = allocator.buffer(capacity: 20) XCTAssertNoThrow(try serializer.serialize(message: message, to: &buffer)) + XCTAssertEqual(1, serializer.sequenceNumber) parser.append(bytes: &buffer) switch try parser.nextPacket() { @@ -193,6 +199,7 @@ final class SSHPacketSerializerTests: XCTestCase { var buffer = allocator.buffer(capacity: 20) XCTAssertNoThrow(try serializer.serialize(message: message, to: &buffer)) + XCTAssertEqual(1, serializer.sequenceNumber) parser.append(bytes: &buffer) switch try parser.nextPacket() { @@ -225,6 +232,7 @@ final class SSHPacketSerializerTests: XCTestCase { var buffer = allocator.buffer(capacity: 5) XCTAssertNoThrow(try serializer.serialize(message: message, to: &buffer)) + XCTAssertEqual(1, serializer.sequenceNumber) parser.append(bytes: &buffer) switch try parser.nextPacket() { @@ -234,4 +242,43 @@ final class SSHPacketSerializerTests: XCTestCase { XCTFail("Expecting .newKeys") } } + + func testSequencePreservedBetweenPlainAndCypher() { + let message = SSHMessage.newKeys + let allocator = ByteBufferAllocator() + var serializer = SSHPacketSerializer() + var parser = SSHPacketParser(allocator: allocator) + + self.runVersionHandshake(serializer: &serializer, parser: &parser) + + var buffer = allocator.buffer(capacity: 5) + XCTAssertNoThrow(try serializer.serialize(message: message, to: &buffer)) + XCTAssertEqual(1, serializer.sequenceNumber) + + buffer = allocator.buffer(capacity: 5) + XCTAssertNoThrow(try serializer.serialize(message: message, to: &buffer)) + XCTAssertEqual(2, serializer.sequenceNumber) + + let inboundEncryptionKey = SymmetricKey(size: .bits128) + let outboundEncryptionKey = SymmetricKey(size: .bits128) + let inboundMACKey = SymmetricKey(size: .bits128) + let outboundMACKey = SymmetricKey(size: .bits128) + let protection = TestTransportProtection(initialKeys: .init( + initialInboundIV: [], + initialOutboundIV: [], + inboundEncryptionKey: inboundEncryptionKey, + outboundEncryptionKey: outboundEncryptionKey, + inboundMACKey: inboundMACKey, + outboundMACKey: outboundMACKey + )) + + serializer.addEncryption(protection) + buffer = allocator.buffer(capacity: 5) + XCTAssertNoThrow(try serializer.serialize(message: message, to: &buffer)) + XCTAssertEqual(3, serializer.sequenceNumber) + + buffer = allocator.buffer(capacity: 5) + XCTAssertNoThrow(try serializer.serialize(message: message, to: &buffer)) + XCTAssertEqual(4, serializer.sequenceNumber) + } } diff --git a/Tests/NIOSSHTests/SSHPacketParserTests.swift b/Tests/NIOSSHTests/SSHPacketParserTests.swift index 0a7192c..dd9116e 100644 --- a/Tests/NIOSSHTests/SSHPacketParserTests.swift +++ b/Tests/NIOSSHTests/SSHPacketParserTests.swift @@ -12,6 +12,7 @@ // //===----------------------------------------------------------------------===// +import Crypto import NIOCore @testable import NIOSSH import XCTest @@ -29,6 +30,7 @@ final class SSHPacketParserTests: XCTestCase { switch packet { case .some(.version(let string)): + XCTAssertEqual(0, parser.sequenceNumber) XCTAssertEqual(string, "SSH-2.0-OpenSSH_7.9", file: file, line: line) default: XCTFail("Expecting .version", file: file, line: line) @@ -48,6 +50,7 @@ final class SSHPacketParserTests: XCTestCase { switch try parser.nextPacket() { case .version(let string): + XCTAssertEqual(0, parser.sequenceNumber) XCTAssertEqual(string, "SSH-2.0-OpenSSH_7.9") default: XCTFail("Expecting .version") @@ -119,20 +122,25 @@ final class SSHPacketParserTests: XCTestCase { parser.append(bytes: &part1) XCTAssertNil(try parser.nextPacket()) + XCTAssertEqual(0, parser.sequenceNumber) var part2 = ByteBuffer.of(bytes: [28]) parser.append(bytes: &part2) XCTAssertNil(try parser.nextPacket()) + XCTAssertEqual(0, parser.sequenceNumber) + var part3 = ByteBuffer.of(bytes: [10, 5, 0, 0, 0, 12, 115, 115, 104, 45, 117, 115, 101, 114, 97]) parser.append(bytes: &part3) XCTAssertNil(try parser.nextPacket()) + XCTAssertEqual(0, parser.sequenceNumber) var part4 = ByteBuffer.of(bytes: [117, 116, 104, 42, 111, 216, 12, 226, 248, 144, 175, 157, 207]) parser.append(bytes: &part4) switch try parser.nextPacket() { case .serviceRequest(let message): + XCTAssertEqual(1, parser.sequenceNumber) XCTAssertEqual(message.service, "ssh-userauth") default: XCTFail("Expecting .serviceRequest") @@ -148,6 +156,7 @@ final class SSHPacketParserTests: XCTestCase { switch try parser.nextPacket() { case .serviceRequest(let message): + XCTAssertEqual(1, parser.sequenceNumber) XCTAssertEqual(message.service, "ssh-userauth") default: XCTFail("Expecting .serviceRequest") @@ -163,12 +172,14 @@ final class SSHPacketParserTests: XCTestCase { switch try parser.nextPacket() { case .serviceRequest(let message): + XCTAssertEqual(1, parser.sequenceNumber) XCTAssertEqual(message.service, "ssh-userauth") default: XCTFail("Expecting .serviceRequest") } switch try parser.nextPacket() { case .serviceRequest(let message): + XCTAssertEqual(2, parser.sequenceNumber) XCTAssertEqual(message.service, "ssh-userauth") default: XCTFail("Expecting .serviceRequest") @@ -199,6 +210,74 @@ final class SSHPacketParserTests: XCTestCase { // Now we should have cleared up. XCTAssertEqual(parser._discardableBytes, 0) } + + func testSequencePreservedBetweenPlainAndCypher() throws { + let allocator = ByteBufferAllocator() + var parser = SSHPacketParser(allocator: allocator) + self.feedVersion(to: &parser) + + var part = ByteBuffer(bytes: [0, 0, 0, 12, 10, 21, 41, 114, 125, 250, 3, 79, 3, 217, 166, 136]) + parser.append(bytes: &part) + + switch try parser.nextPacket() { + case .newKeys: + XCTAssertEqual(1, parser.sequenceNumber) + default: + XCTFail("Expecting .newKeys") + } + + part = ByteBuffer(bytes: [0, 0, 0, 12, 10, 21, 41, 114, 125, 250, 3, 79, 3, 217, 166, 136]) + parser.append(bytes: &part) + + switch try parser.nextPacket() { + case .newKeys: + XCTAssertEqual(2, parser.sequenceNumber) + default: + XCTFail("Expecting .newKeys") + } + + let inboundEncryptionKey = SymmetricKey(size: .bits128) + let outboundEncryptionKey = inboundEncryptionKey + let inboundMACKey = SymmetricKey(size: .bits128) + let outboundMACKey = inboundMACKey + let protection = TestTransportProtection(initialKeys: .init( + initialInboundIV: [], + initialOutboundIV: [], + inboundEncryptionKey: inboundEncryptionKey, + outboundEncryptionKey: outboundEncryptionKey, + inboundMACKey: inboundMACKey, + outboundMACKey: outboundMACKey + )) + parser.addEncryption(protection) + + part = allocator.buffer(capacity: 1024) + XCTAssertNoThrow(try protection.encryptPacket(NIOSSHEncryptablePayload(message: .newKeys), sequenceNumber: 2, to: &part)) + var subpart = part.readSlice(length: 2)! + parser.append(bytes: &subpart) + + XCTAssertNil(try parser.nextPacket()) + XCTAssertEqual(2, parser.sequenceNumber) + + parser.append(bytes: &part) + + switch try parser.nextPacket() { + case .newKeys: + XCTAssertEqual(3, parser.sequenceNumber) + default: + XCTFail("Expecting .newKeys") + } + + part = allocator.buffer(capacity: 1024) + XCTAssertNoThrow(try protection.encryptPacket(NIOSSHEncryptablePayload(message: .newKeys), sequenceNumber: 2, to: &part)) + parser.append(bytes: &part) + + switch try parser.nextPacket() { + case .newKeys: + XCTAssertEqual(4, parser.sequenceNumber) + default: + XCTFail("Expecting .newKeys") + } + } } extension ByteBuffer { diff --git a/Tests/NIOSSHTests/Utilities.swift b/Tests/NIOSSHTests/Utilities.swift index a2c9299..9d0edea 100644 --- a/Tests/NIOSSHTests/Utilities.swift +++ b/Tests/NIOSSHTests/Utilities.swift @@ -176,7 +176,7 @@ class TestTransportProtection: NIOSSHTransportProtection { source.setBytes(plaintext.readableBytesView, at: index) } - func decryptAndVerifyRemainingPacket(_ source: inout ByteBuffer) throws -> ByteBuffer { + func decryptAndVerifyRemainingPacket(_ source: inout ByteBuffer, sequenceNumber: UInt32) throws -> ByteBuffer { defer { self.lastFirstBlock = nil } @@ -211,7 +211,7 @@ class TestTransportProtection: NIOSSHTransportProtection { return plaintext.readSlice(length: plaintext.readableBytes - Int(paddingLength))! } - func encryptPacket(_ packet: NIOSSHEncryptablePayload, to outboundBuffer: inout ByteBuffer) throws { + func encryptPacket(_ packet: NIOSSHEncryptablePayload, sequenceNumber: UInt32, to outboundBuffer: inout ByteBuffer) throws { let packetLengthIndex = outboundBuffer.writerIndex let packetLengthLength = MemoryLayout.size let packetPaddingIndex = outboundBuffer.writerIndex + packetLengthLength diff --git a/Tests/NIOSSHTests/UtilitiesTests.swift b/Tests/NIOSSHTests/UtilitiesTests.swift index 96a4c33..3b085d8 100644 --- a/Tests/NIOSSHTests/UtilitiesTests.swift +++ b/Tests/NIOSSHTests/UtilitiesTests.swift @@ -52,9 +52,9 @@ final class UtilitiesTests: XCTestCase { let message = SSHMessage.channelRequest(.init(recipientChannel: 1, type: .exec("uname"), wantReply: false)) let allocator = ByteBufferAllocator() var buffer = allocator.buffer(capacity: 1024) - XCTAssertNoThrow(try client.encryptPacket(.init(message: message), to: &buffer)) + XCTAssertNoThrow(try client.encryptPacket(.init(message: message), sequenceNumber: 0, to: &buffer)) XCTAssertNoThrow(try server.decryptFirstBlock(&buffer)) - var decoded = try server.decryptAndVerifyRemainingPacket(&buffer) + var decoded = try server.decryptAndVerifyRemainingPacket(&buffer, sequenceNumber: 0) XCTAssertEqual(message, try decoded.readSSHMessage()) } }