Skip to content

Commit

Permalink
Updated Jim Studt's PR
Browse files Browse the repository at this point in the history
  • Loading branch information
Joannis committed Nov 19, 2022
1 parent c905128 commit f5e22a8
Show file tree
Hide file tree
Showing 19 changed files with 83 additions and 6 deletions.
12 changes: 11 additions & 1 deletion Sources/NIOSSH/Child Channels/ChildChannelOptions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ public struct SSHChildChannelOptions {

/// See: ``SSHChildChannelOptions/Types/PeerMaximumMessageLengthOption``.
public static let peerMaximumMessageLength: SSHChildChannelOptions.Types.PeerMaximumMessageLengthOption = .init()

/// - seealso: `UsernameOption`.
public static let username: SSHChildChannelOptions.Types.UsernameOption = .init()
}

extension SSHChildChannelOptions {
Expand Down Expand Up @@ -61,7 +64,14 @@ extension SSHChildChannelOptions.Types {
/// ``SSHChildChannelOptions/Types/PeerMaximumMessageLengthOption`` allows users to query the maximum packet size value reported by the remote peer for a given channel.
public struct PeerMaximumMessageLengthOption: ChannelOption, Sendable {
public typealias Value = UInt32


public init() {}
}

/// `UsernameOption` allows users to query the authenticated username of the channel.
public struct UsernameOption: ChannelOption {
public typealias Value = String?

public init() {}
}
}
5 changes: 5 additions & 0 deletions Sources/NIOSSH/Child Channels/SSHChannelMultiplexer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ extension SSHChannelMultiplexer {
self.erroredChannels.append(channelID)
}
}

// The username which the server accepted in authorization
var username: String? { delegate?.username }
}

// MARK: Calls from SSH handlers.
Expand Down Expand Up @@ -218,6 +221,8 @@ extension SSHChannelMultiplexer {
protocol SSHMultiplexerDelegate {
var channel: Channel? { get }

var username: String? { get }

func writeFromChildChannel(_: SSHMessage, _: EventLoopPromise<Void>?)

func flushFromChildChannel()
Expand Down
2 changes: 2 additions & 0 deletions Sources/NIOSSH/Child Channels/SSHChildChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ extension SSHChildChannel: Channel, ChannelCore {
return self.type! as! Option.Value
case _ as SSHChildChannelOptions.Types.PeerMaximumMessageLengthOption:
return self.peerMaxMessageSize as! Option.Value
case _ as SSHChildChannelOptions.Types.UsernameOption:
return multiplexer.username as! Option.Value
case _ as ChannelOptions.Types.AutoReadOption:
return self.autoRead as! Option.Value
case _ as ChannelOptions.Types.AllowRemoteHalfClosureOption:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
protocol AcceptsUserAuthMessages {
var userAuthStateMachine: UserAuthenticationStateMachine { get set }

var connectionAttributes: SSHConnectionStateMachine.Attributes? { get }

var role: SSHConnectionRole { get }
}

Expand Down Expand Up @@ -60,14 +62,15 @@ extension AcceptsUserAuthMessages {

mutating func receiveUserAuthRequest(_ message: SSHMessage.UserAuthRequestMessage) throws -> SSHConnectionStateMachine.StateMachineInboundProcessResult {
let result = try self.userAuthStateMachine.receiveUserAuthRequest(message)

if let future = result {
var banner: SSHServerConfiguration.UserAuthBanner?
if case .server(let config) = role {
banner = config.banner
}

return .possibleFutureMessage(future.map { Self.transform($0, banner: banner) })
let connectionAttributes = self.connectionAttributes
return .possibleFutureMessage(future.map { Self.transform($0, connectionAttributes: connectionAttributes, username: message.username, banner: banner) })
} else {
return .noMessage
}
Expand Down Expand Up @@ -96,9 +99,11 @@ extension AcceptsUserAuthMessages {
return .event(NIOUserAuthBannerEvent(message: message.message, languageTag: message.languageTag))
}

private static func transform(_ result: NIOSSHUserAuthenticationResponseMessage, banner: SSHServerConfiguration.UserAuthBanner? = nil) -> SSHMultiMessage {
private static func transform(_ result: NIOSSHUserAuthenticationResponseMessage, connectionAttributes: SSHConnectionStateMachine.Attributes?, username: String, banner: SSHServerConfiguration.UserAuthBanner? = nil) -> SSHMultiMessage {
switch result {
case .success:
connectionAttributes?.username = username

if let banner = banner {
// Send banner bundled with auth success to avoid leaking any information to unauthenticated clients.
// Note that this is by no means the only option according to RFC 4252
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,21 @@ struct SSHConnectionStateMachine {
case sentDisconnect(SSHConnectionRole)
}

class Attributes {
var username: String? = nil
}

/// The state of this state machine.
private var state: State

/// Attributes of the connection which can be changed by messages handlers
private let attributes: Attributes

var username: String? { attributes.username }

init(role: SSHConnectionRole, protectionSchemes: [NIOSSHTransportProtection.Type] = Constants.bundledTransportProtectionSchemes) {
self.state = .idle(IdleState(role: role, protectionSchemes: protectionSchemes))
self.attributes = .init()
self.state = .idle(IdleState(role: role, protectionSchemes: protectionSchemes, attributes: attributes))
}

func start() -> SSHMultiMessage? {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,16 @@ extension SSHConnectionStateMachine {

internal var sessionIdentifier: ByteBuffer

internal weak var connectionAttributes: SSHConnectionStateMachine.Attributes?

init(_ previous: UserAuthenticationState) {
self.role = previous.role
self.serializer = previous.serializer
self.parser = previous.parser
self.remoteVersion = previous.remoteVersion
self.protectionSchemes = previous.protectionSchemes
self.sessionIdentifier = previous.sessionIdentifier
self.connectionAttributes = previous.connectionAttributes
}

init(_ previous: RekeyingReceivedNewKeysState) {
Expand All @@ -47,6 +50,7 @@ extension SSHConnectionStateMachine {
self.remoteVersion = previous.remoteVersion
self.protectionSchemes = previous.protectionSchemes
self.sessionIdentifier = previous.sessionIdentifier
self.connectionAttributes = previous.connectionAttributes
}

init(_ previous: RekeyingSentNewKeysState) {
Expand All @@ -56,6 +60,7 @@ extension SSHConnectionStateMachine {
self.remoteVersion = previous.remoteVersion
self.protectionSchemes = previous.protectionSchemes
self.sessionIdentifier = previous.sessionIdentifier
self.connectionAttributes = previous.connectionAttributes
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@ extension SSHConnectionStateMachine {

internal var protectionSchemes: [NIOSSHTransportProtection.Type]

init(role: SSHConnectionRole, protectionSchemes: [NIOSSHTransportProtection.Type]) {
internal weak var connectionAttributes: SSHConnectionStateMachine.Attributes?

init(role: SSHConnectionRole, protectionSchemes: [NIOSSHTransportProtection.Type], attributes: SSHConnectionStateMachine.Attributes) {
self.role = role
self.serializer = SSHPacketSerializer()
self.protectionSchemes = protectionSchemes
self.connectionAttributes = attributes
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,16 @@ extension SSHConnectionStateMachine {
/// The backing state machine.
var keyExchangeStateMachine: SSHKeyExchangeStateMachine

weak var connectionAttributes: SSHConnectionStateMachine.Attributes?

init(sentVersionState state: SentVersionState, allocator: ByteBufferAllocator, loop: EventLoop, remoteVersion: String) {
self.role = state.role
self.parser = state.parser
self.serializer = state.serializer
self.remoteVersion = remoteVersion
self.protectionSchemes = state.protectionSchemes
self.keyExchangeStateMachine = SSHKeyExchangeStateMachine(allocator: allocator, loop: loop, role: state.role, remoteVersion: remoteVersion, protectionSchemes: state.protectionSchemes, previousSessionIdentifier: nil)
self.connectionAttributes = state.connectionAttributes
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ extension SSHConnectionStateMachine {

internal var sessionIdentifier: ByteBuffer

internal weak var connectionAttributes: SSHConnectionStateMachine.Attributes?

init(_ previous: ActiveState, allocator: ByteBufferAllocator, loop: EventLoop) {
self.role = previous.role
self.serializer = previous.serializer
Expand All @@ -42,6 +44,7 @@ extension SSHConnectionStateMachine {
self.protectionSchemes = previous.protectionSchemes
self.sessionIdentifier = previous.sessionIdentifier
self.keyExchangeStateMachine = SSHKeyExchangeStateMachine(allocator: allocator, loop: loop, role: previous.role, remoteVersion: previous.remoteVersion, protectionSchemes: previous.protectionSchemes, previousSessionIdentifier: self.sessionIdentifier)
self.connectionAttributes = previous.connectionAttributes
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ extension SSHConnectionStateMachine {
/// The user auth state machine that drives user authentication.
var userAuthStateMachine: UserAuthenticationStateMachine

weak var connectionAttributes: SSHConnectionStateMachine.Attributes?

init(keyExchangeState state: KeyExchangeState,
loop: EventLoop) {
self.role = state.role
Expand All @@ -53,6 +55,7 @@ extension SSHConnectionStateMachine {
self.userAuthStateMachine = UserAuthenticationStateMachine(role: self.role,
loop: loop,
sessionID: self.sessionIdentifier)
self.connectionAttributes = state.connectionAttributes
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ extension SSHConnectionStateMachine {
/// The backing state machine.
var keyExchangeStateMachine: SSHKeyExchangeStateMachine

weak var connectionAttributes: SSHConnectionStateMachine.Attributes?

init(_ previousState: RekeyingState) {
self.role = previousState.role
self.parser = previousState.parser
Expand All @@ -44,6 +46,7 @@ extension SSHConnectionStateMachine {
self.protectionSchemes = previousState.protectionSchemes
self.sessionIdentifier = previousState.sessionIdentifier
self.keyExchangeStateMachine = previousState.keyExchangeStateMachine
self.connectionAttributes = previousState.connectionAttributes
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ extension SSHConnectionStateMachine {
/// The backing state machine.
var keyExchangeStateMachine: SSHKeyExchangeStateMachine

weak var connectionAttributes: SSHConnectionStateMachine.Attributes?

init(_ previousState: RekeyingState) {
self.role = previousState.role
self.parser = previousState.parser
Expand All @@ -44,6 +46,7 @@ extension SSHConnectionStateMachine {
self.protectionSchemes = previousState.protectionSchemes
self.sessionIdentifier = previousState.sessionIdentifier
self.keyExchangeStateMachine = previousState.keyExchangeStateMachine
self.connectionAttributes = previousState.connectionAttributes
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ extension SSHConnectionStateMachine {
/// The backing state machine.
var keyExchangeStateMachine: SSHKeyExchangeStateMachine

weak var connectionAttributes: SSHConnectionStateMachine.Attributes?

init(_ previousState: ReceivedKexInitWhenActiveState) {
self.role = previousState.role
self.parser = previousState.parser
Expand All @@ -43,6 +45,7 @@ extension SSHConnectionStateMachine {
self.protectionSchemes = previousState.protectionSchemes
self.sessionIdentifier = previousState.sessionIdentifier
self.keyExchangeStateMachine = previousState.keyExchangeStateMachine
self.connectionAttributes = previousState.connectionAttributes
}

init(_ previousState: SentKexInitWhenActiveState) {
Expand All @@ -53,6 +56,7 @@ extension SSHConnectionStateMachine {
self.protectionSchemes = previousState.protectionSchemes
self.sessionIdentifier = previousState.sessionIdentitifier
self.keyExchangeStateMachine = previousState.keyExchangeStateMachine
self.connectionAttributes = previousState.connectionAttributes
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ extension SSHConnectionStateMachine {

internal var keyExchangeStateMachine: SSHKeyExchangeStateMachine

internal weak var connectionAttributes: SSHConnectionStateMachine.Attributes?

init(_ previous: ActiveState, allocator: ByteBufferAllocator, loop: EventLoop) {
self.role = previous.role
self.serializer = previous.serializer
Expand All @@ -42,6 +44,7 @@ extension SSHConnectionStateMachine {
self.protectionSchemes = previous.protectionSchemes
self.sessionIdentitifier = previous.sessionIdentifier
self.keyExchangeStateMachine = SSHKeyExchangeStateMachine(allocator: allocator, loop: loop, role: self.role, remoteVersion: self.remoteVersion, protectionSchemes: self.protectionSchemes, previousSessionIdentifier: previous.sessionIdentifier)
self.connectionAttributes = previous.connectionAttributes
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ extension SSHConnectionStateMachine {
/// The user auth state machine that drives user authentication.
var userAuthStateMachine: UserAuthenticationStateMachine

weak var connectionAttributes: SSHConnectionStateMachine.Attributes?

init(keyExchangeState state: KeyExchangeState,
loop: EventLoop) {
self.role = state.role
Expand All @@ -53,6 +55,7 @@ extension SSHConnectionStateMachine {
self.userAuthStateMachine = UserAuthenticationStateMachine(role: self.role,
loop: loop,
sessionID: self.sessionIdentifier)
self.connectionAttributes = state.connectionAttributes
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ extension SSHConnectionStateMachine {

var protectionSchemes: [NIOSSHTransportProtection.Type]

weak var connectionAttributes: SSHConnectionStateMachine.Attributes?

private let allocator: ByteBufferAllocator

init(idleState state: IdleState, allocator: ByteBufferAllocator) {
Expand All @@ -37,6 +39,7 @@ extension SSHConnectionStateMachine {

self.parser = SSHPacketParser(allocator: allocator)
self.allocator = allocator
self.connectionAttributes = state.connectionAttributes
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ extension SSHConnectionStateMachine {
/// The backing state machine.
var userAuthStateMachine: UserAuthenticationStateMachine

weak var connectionAttributes: SSHConnectionStateMachine.Attributes?

init(sentNewKeysState state: SentNewKeysState) {
self.role = state.role
self.parser = state.parser
Expand All @@ -43,6 +45,7 @@ extension SSHConnectionStateMachine {
self.remoteVersion = state.remoteVersion
self.protectionSchemes = state.protectionSchemes
self.sessionIdentifier = state.sessionIdentifier
self.connectionAttributes = state.connectionAttributes
}

init(receivedNewKeysState state: ReceivedNewKeysState) {
Expand All @@ -53,6 +56,7 @@ extension SSHConnectionStateMachine {
self.remoteVersion = state.remoteVersion
self.protectionSchemes = state.protectionSchemes
self.sessionIdentifier = state.sessionIdentifier
self.connectionAttributes = state.connectionAttributes
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions Sources/NIOSSH/NIOSSHHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ public final class NIOSSHHandler {

private var pendingGlobalRequestResponses: CircularBuffer<PendingGlobalRequestResponse?>

// The authenticated username, if there was one.
var username: String? { stateMachine.username }

/// Construct a new ``NIOSSHHandler``.
///
/// - parameters:
Expand Down
2 changes: 2 additions & 0 deletions Tests/NIOSSHTests/ChildChannelMultiplexerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import XCTest
/// This reduces the testing surface area somewhat, which greatly helps us to test the
/// implementation of the multiplexer and child channels.
final class DummyDelegate: SSHMultiplexerDelegate {
var username : String? = "dummy"

var _channel: EmbeddedChannel = EmbeddedChannel()

var writes: MarkedCircularBuffer<(SSHMessage, EventLoopPromise<Void>?)> = MarkedCircularBuffer(initialCapacity: 8)
Expand Down

0 comments on commit f5e22a8

Please sign in to comment.