Skip to content

Commit

Permalink
PoC: Add netlink socket address support
Browse files Browse the repository at this point in the history
  • Loading branch information
Austinpayne committed Aug 30, 2021
1 parent d8348ad commit b043722
Show file tree
Hide file tree
Showing 11 changed files with 125 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import PackageDescription

var targets: [PackageDescription.Target] = [
.target(name: "NIOCore",
dependencies: ["NIOConcurrencyHelpers", "CNIOLinux"]),
dependencies: ["NIOConcurrencyHelpers", "CNIOLinux", "CNIODarwin"]),
.target(name: "_NIODataStructures"),
.target(name: "NIOEmbedded", dependencies: ["NIOCore", "_NIODataStructures"]),
.target(name: "NIOPosix",
Expand Down
11 changes: 11 additions & 0 deletions Sources/CNIODarwin/include/CNIODarwin.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,16 @@ void *CNIODarwin_CMSG_DATA_MUTABLE(struct cmsghdr *);
size_t CNIODarwin_CMSG_LEN(size_t);
size_t CNIODarwin_CMSG_SPACE(size_t);

// netlink shims
typedef struct {
unsigned short nl_family; /* AF_NETLINK */
unsigned short nl_pad; /* zero */
unsigned int nl_pid; /* port ID */
unsigned int nl_groups; /* multicast groups mask */
} sockaddr_nl;

#define AF_NETLINK 16
#define PF_NETLINK AF_NETLINK

#endif // __APPLE__
#endif // C_NIO_DARWIN_H
1 change: 1 addition & 0 deletions Sources/CNIOLinux/include/CNIOLinux.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <errno.h>
#include <pthread.h>
#include <netinet/ip.h>
#include <linux/netlink.h>
#include "liburing_nio.h"

// Some explanation is required here.
Expand Down
9 changes: 9 additions & 0 deletions Sources/NIOCore/BSDSocketAPI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ private let sysInet_ntop: @convention(c) (CInt, UnsafeRawPointer?, UnsafeMutable
private let sysInet_pton: @convention(c) (CInt, UnsafePointer<CChar>?, UnsafeMutableRawPointer?) -> CInt = inet_pton
#elseif os(macOS) || os(iOS) || os(watchOS) || os(tvOS)
import Darwin
import CNIODarwin

private let sysInet_ntop: @convention(c) (CInt, UnsafeRawPointer?, UnsafeMutablePointer<CChar>?, socklen_t) -> UnsafePointer<CChar>? = inet_ntop
private let sysInet_pton: @convention(c) (CInt, UnsafePointer<CChar>?, UnsafeMutableRawPointer?) -> CInt = inet_pton
Expand Down Expand Up @@ -170,6 +171,10 @@ extension NIOBSDSocket.AddressFamily {
/// Unix local to host address.
public static let unix: NIOBSDSocket.AddressFamily =
NIOBSDSocket.AddressFamily(rawValue: AF_UNIX)

/// Address for NETLINK protocol.
public static let netlink: NIOBSDSocket.AddressFamily =
NIOBSDSocket.AddressFamily(rawValue: AF_NETLINK)
}

// Protocol Family
Expand All @@ -185,6 +190,10 @@ extension NIOBSDSocket.ProtocolFamily {
/// UNIX local to the host.
public static let unix: NIOBSDSocket.ProtocolFamily =
NIOBSDSocket.ProtocolFamily(rawValue: PF_UNIX)

/// NETLINK protocol.
public static let netlink: NIOBSDSocket.ProtocolFamily =
NIOBSDSocket.ProtocolFamily(rawValue: PF_NETLINK)
}

#if !os(Windows)
Expand Down
66 changes: 65 additions & 1 deletion Sources/NIOCore/SocketAddresses.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import struct WinSDK.in_addr_t
import typealias WinSDK.u_short
#elseif os(macOS) || os(iOS) || os(tvOS) || os(watchOS)
import Darwin
import CNIODarwin
#elseif os(Linux) || os(FreeBSD) || os(Android)
import Glibc
import CNIOLinux
Expand Down Expand Up @@ -99,6 +100,18 @@ public enum SocketAddress: CustomStringConvertible {
}
}

/// A single Netlink socket address for `SocketAddress`.
public struct NetlinkAddress {
private let _storage: Box<sockaddr_nl>

/// The libc socket address for a Netlink Socket.
public var address: sockaddr_nl { return _storage.value }

fileprivate init(address: sockaddr_nl) {
self._storage = Box(address)
}
}

/// An IPv4 `SocketAddress`.
case v4(IPv4Address)

Expand All @@ -108,6 +121,9 @@ public enum SocketAddress: CustomStringConvertible {
/// An UNIX Domain `SocketAddress`.
case unixDomainSocket(UnixSocketAddress)

/// A Netlink `SocketAddress`.
case netlinkSocket(NetlinkAddress)

/// A human-readable description of this `SocketAddress`. Mostly useful for logging.
public var description: String {
let addressString: String
Expand Down Expand Up @@ -135,6 +151,8 @@ public enum SocketAddress: CustomStringConvertible {
host = nil
type = "UDS"
return "[\(type)]\(self.pathname ?? "")"
case .netlinkSocket(let addr):
return "[NLS]\(addr.address.nl_pid):\(addr.address.nl_groups)"
}

return "[\(type)]\(host.map { "\($0)/\(addressString):" } ?? "\(addressString):")\(port)"
Expand All @@ -154,6 +172,8 @@ public enum SocketAddress: CustomStringConvertible {
return .inet6
case .unixDomainSocket:
return .unix
case .netlinkSocket:
return .netlink
}
}

Expand All @@ -170,6 +190,8 @@ public enum SocketAddress: CustomStringConvertible {
return try! descriptionForAddress(family: .inet6, bytes: &mutAddr, length: Int(INET6_ADDRSTRLEN))
case .unixDomainSocket(_):
return nil
case .netlinkSocket:
return nil
}
}

Expand All @@ -188,6 +210,8 @@ public enum SocketAddress: CustomStringConvertible {
return Int(in_port_t(bigEndian: addr.address.sin6_port))
case .unixDomainSocket:
return nil
case .netlinkSocket:
return nil
}
}
set {
Expand All @@ -202,6 +226,8 @@ public enum SocketAddress: CustomStringConvertible {
self = .v6(.init(address: mutAddr, host: addr.host))
case .unixDomainSocket:
precondition(newValue == nil, "attempting to set a non-nil value to a unix socket is not valid")
case .netlinkSocket:
precondition(newValue == nil, "attempting to set a non-nil value to a netlink socket is not valid")
}
}
}
Expand All @@ -222,6 +248,8 @@ public enum SocketAddress: CustomStringConvertible {
return String(cString: charPtr)
}
return pathname
case .netlinkSocket:
return nil
}
}

Expand All @@ -238,6 +266,9 @@ public enum SocketAddress: CustomStringConvertible {
case .unixDomainSocket(let addr):
var address = addr.address
return try address.withSockAddr({ try body($0, $1) })
case .netlinkSocket(let addr):
var address = addr.address
return try address.withSockAddr({ try body($0, $1) })
}
}

Expand Down Expand Up @@ -285,6 +316,14 @@ public enum SocketAddress: CustomStringConvertible {
self = .unixDomainSocket(.init(address: addr))
}

/// Creates a new Netlink Socket `SocketAddress`.
///
/// - parameters:
/// - addr: the `sockaddr_nl` that holds the socket path.
public init(_ addr: sockaddr_nl) {
self = .netlinkSocket(.init(address: addr))
}

/// Creates a new UDS `SocketAddress`.
///
/// - parameters:
Expand Down Expand Up @@ -498,7 +537,14 @@ extension SocketAddress: Equatable {
return strncmp(typedSunpath1, typedSunpath2, bufferSize) == 0
}
}
case (.v4, _), (.v6, _), (.unixDomainSocket, _):
case (.netlinkSocket(let addr1), .netlinkSocket(let addr2)):
guard addr1.address.nl_family == addr2.address.nl_family,
addr1.address.nl_pid == addr2.address.nl_pid,
addr1.address.nl_groups == addr2.address.nl_groups else {
return false
}
return true
case (.v4, _), (.v6, _), (.unixDomainSocket, _), (.netlinkSocket, _):
return false
}
}
Expand Down Expand Up @@ -539,6 +585,11 @@ extension SocketAddress: Hashable {
withUnsafeBytes(of: v6Addr.address.sin6_addr) {
hasher.combine(bytes: $0)
}
case .netlinkSocket(let nls):
hasher.combine(3)
hasher.combine(nls.address.nl_family)
hasher.combine(nls.address.nl_pid)
hasher.combine(nls.address.nl_groups)
}
}
}
Expand All @@ -565,6 +616,10 @@ extension SocketAddress {
// so we can just ask for equality on the top byte.
var v6WireAddress = v6Addr.address.sin6_addr
return withUnsafeBytes(of: &v6WireAddress) { $0[0] == 0xff }
case .netlinkSocket:
// NETLINK multicast (nls.address.nl_groups != 0) is not the same
// as IP multicast
return false
}
}
}
Expand Down Expand Up @@ -629,6 +684,15 @@ extension sockaddr_un: SockAddrProtocol {
}
}

extension sockaddr_nl: SockAddrProtocol {
mutating func withSockAddr<R>(_ body: (UnsafePointer<sockaddr>, Int) throws -> R) rethrows -> R {
var me = self
return try withUnsafeBytes(of: &me) { p in
try body(p.baseAddress!.assumingMemoryBound(to: sockaddr.self), p.count)
}
}
}

extension sockaddr_storage: SockAddrProtocol {
mutating func withSockAddr<R>(_ body: (UnsafePointer<sockaddr>, Int) throws -> R) rethrows -> R {
var me = self
Expand Down
2 changes: 2 additions & 0 deletions Sources/NIOMulticastChat/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ let datagramChannel = try datagramBootstrap
return provider.setIPv6MulticastIF(CUnsignedInt(targetDevice.interfaceIndex)).map { channel }
case .some(.unixDomainSocket):
preconditionFailure("Should not be possible to create a multicast socket on a unix domain socket")
case .some(.netlinkSocket):
preconditionFailure("Should not be possible to create a multicast socket on a netlink socket")
case .none:
preconditionFailure("Should not be possible to create a multicast socket on an interface without an address")
}
Expand Down
22 changes: 22 additions & 0 deletions Sources/NIOPosix/BaseSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ import struct WinSDK.FILE_DISPOSITION_INFO
import struct WinSDK.socklen_t

import CNIOWindows
#elseif os(macOS) || os(iOS) || os(watchOS) || os(tvOS)
import CNIODarwin
#elseif os(Linux) || os(FreeBSD) || os(Android)
import CNIOLinux
#endif

protocol Registration {
Expand Down Expand Up @@ -115,6 +119,18 @@ extension sockaddr_storage {
}
}

/// Converts the `socketaddr_storage` to a `sockaddr_nl`.
///
/// This will crash if `ss_family` != AF_NETLINK!
mutating func convert() -> sockaddr_nl {
precondition(self.ss_family == NIOBSDSocket.AddressFamily.netlink.rawValue)
return withUnsafePointer(to: &self) {
$0.withMemoryRebound(to: sockaddr_nl.self, capacity: 1) {
$0.pointee
}
}
}

/// Converts the `socketaddr_storage` to a `SocketAddress`.
mutating func convert() -> SocketAddress {
switch NIOBSDSocket.AddressFamily(rawValue: CInt(self.ss_family)) {
Expand All @@ -126,6 +142,8 @@ extension sockaddr_storage {
return SocketAddress(sockAddr)
case .unix:
return SocketAddress(self.convert() as sockaddr_un)
case .netlink:
return SocketAddress(self.convert() as sockaddr_nl)
default:
fatalError("unknown sockaddr family \(self.ss_family)")
}
Expand All @@ -149,6 +167,10 @@ extension UnsafeMutablePointer where Pointee == sockaddr {
return self.withMemoryRebound(to: sockaddr_un.self, capacity: 1) {
SocketAddress($0.pointee)
}
case .netlink:
return self.withMemoryRebound(to: sockaddr_nl.self, capacity: 1) {
SocketAddress($0.pointee)
}
default:
return nil
}
Expand Down
2 changes: 2 additions & 0 deletions Sources/NIOPosix/PendingDatagramWritesManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ private struct PendingDatagramWrite {
}
case .unixDomainSocket:
fatalError("UDS with datagrams is currently not supported")
case .netlinkSocket:
fatalError("NLS with datagrams is currently not supported")
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions Sources/NIOPosix/Socket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ typealias IOVector = iovec
return try self.connectSocket(addr: addr.address)
case .unixDomainSocket(let addr):
return try self.connectSocket(addr: addr.address)
case .netlinkSocket(let addr):
return try self.connectSocket(addr: addr.address)
}
}

Expand Down
9 changes: 8 additions & 1 deletion Sources/NIOPosix/SocketChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,8 @@ extension DatagramChannel: MulticastChannel {
// Ok, we now have reason to believe this will actually work. We need to pass this on to the socket.
do {
switch (group, device?.address) {
case (.netlinkSocket, _):
preconditionFailure("Should not be reachable, NETLINK sockets are never multicast addresses")
case (.unixDomainSocket, _):
preconditionFailure("Should not be reachable, UNIX sockets are never multicast addresses")
case (.v4(let groupAddress), .some(.v4(let interfaceAddress))):
Expand All @@ -913,7 +915,12 @@ extension DatagramChannel: MulticastChannel {
// IPv6 binding with no specific interface requested.
let multicastRequest = ipv6_mreq(ipv6mr_multiaddr: groupAddress.address.sin6_addr, ipv6mr_interface: 0)
try self.socket.setOption(level: .ipv6, name: operation.optionName(level: .ipv6), value: multicastRequest)
case (.v4, .some(.v6)), (.v6, .some(.v4)), (.v4, .some(.unixDomainSocket)), (.v6, .some(.unixDomainSocket)):
case (.v4, .some(.v6)),
(.v6, .some(.v4)),
(.v4, .some(.unixDomainSocket)),
(.v6, .some(.unixDomainSocket)),
(.v4, .some(.netlinkSocket)),
(.v6, .some(.netlinkSocket)):
// Mismatched group and interface address: this is an error.
throw ChannelError.badInterfaceAddressFamily
}
Expand Down
2 changes: 2 additions & 0 deletions Tests/NIOPosixTests/TestUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,8 @@ func resolverDebugInformation(eventLoop: EventLoop, host: String, previouslyRece
case .v6(let sa):
var addr = sa.address
return addr.addressDescription()
case .netlinkSocket:
return "nls"
}
}
let res = GetaddrinfoResolver(loop: eventLoop, aiSocktype: .stream, aiProtocol: CInt(IPPROTO_TCP))
Expand Down

0 comments on commit b043722

Please sign in to comment.