diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml
index f3b971f..454e3aa 100644
--- a/.github/workflows/build_and_test.yml
+++ b/.github/workflows/build_and_test.yml
@@ -88,3 +88,21 @@ jobs:
           ./start-connect-server.sh
           cd ../..
           swift test --no-parallel
+
+  integration-test-token:
+    runs-on: macos-15
+    env:
+      SPARK_CONNECT_AUTHENTICATE_TOKEN: ${{ github.run_id }}-${{ github.run_attempt }}
+    steps:
+      - uses: actions/checkout@v4
+      - uses: swift-actions/setup-swift@v2.3.0
+        with:
+          swift-version: "6.1"
+      - name: Test
+        run: |
+          curl -LO https://dist.apache.org/repos/dist/dev/spark/v4.0.0-rc4-bin/spark-4.0.0-bin-hadoop3.tgz
+          tar xvfz spark-4.0.0-bin-hadoop3.tgz
+          cd spark-4.0.0-bin-hadoop3/sbin
+          ./start-connect-server.sh
+          cd ../..
+          swift test --no-parallel
diff --git a/Sources/SparkConnect/BearerTokenInterceptor.swift b/Sources/SparkConnect/BearerTokenInterceptor.swift
new file mode 100644
index 0000000..07040f5
--- /dev/null
+++ b/Sources/SparkConnect/BearerTokenInterceptor.swift
@@ -0,0 +1,44 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+//
+
+import Foundation
+import GRPCCore
+
+struct BearerTokenInterceptor: ClientInterceptor {
+  let token: String
+
+  init(token: String) {
+    self.token = token
+  }
+
+  func intercept<Input: Sendable, Output: Sendable>(
+    request: StreamingClientRequest<Input>,
+    context: ClientContext,
+    next: (
+      _ request: StreamingClientRequest<Input>,
+      _ context: ClientContext
+    ) async throws -> StreamingClientResponse<Output>
+  ) async throws -> StreamingClientResponse<Output> {
+    var request = request
+    request.metadata.addString("Bearer \(self.token)", forKey: "Authorization")
+
+    // Forward the request to the next interceptor.
+    return try await next(request, context)
+  }
+}
diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift
index 42f4be9..e0506cc 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -117,7 +117,8 @@ public actor DataFrame: Sendable {
       transport: .http2NIOPosix(
         target: .dns(host: spark.client.host, port: spark.client.port),
         transportSecurity: .plaintext
-      )
+      ),
+      interceptors: spark.client.getIntercepters()
     ) { client in
       return try await f(client)
     }
diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift
index 4854159..26f34e6 100644
--- a/Sources/SparkConnect/SparkConnectClient.swift
+++ b/Sources/SparkConnect/SparkConnectClient.swift
@@ -28,6 +28,8 @@ public actor SparkConnectClient {
   let url: URL
   let host: String
   let port: Int
+  let token: String?
+  var intercepters: [ClientInterceptor] = []
   let userContext: UserContext
   var sessionID: String? = nil
   var tags = Set<String>()
@@ -36,10 +38,14 @@ public actor SparkConnectClient {
   /// - Parameters:
   ///   - remote: A string to connect `Spark Connect` server.
   ///   - user: A string for the user ID of this connection.
-  init(remote: String, user: String) {
+  init(remote: String, user: String, token: String? = nil) {
     self.url = URL(string: remote)!
     self.host = url.host() ?? "localhost"
     self.port = self.url.port ?? 15002
+    self.token = token ?? ProcessInfo.processInfo.environment["SPARK_CONNECT_AUTHENTICATE_TOKEN"]
+    if let token = self.token {
+      self.intercepters.append(BearerTokenInterceptor(token: token))
+    }
     self.userContext = user.toUserContext
   }
 
@@ -75,12 +81,17 @@ public actor SparkConnectClient {
       transport: .http2NIOPosix(
         target: .dns(host: self.host, port: self.port),
         transportSecurity: .plaintext
-      )
+      ),
+      interceptors: self.intercepters
     ) { client in
       return try await f(client)
     }
   }
 
+  public func getIntercepters() -> [ClientInterceptor] {
+    return self.intercepters
+  }
+
   /// Create a ``ConfigRequest`` instance for `Set` operation.
   /// - Parameter map: A map of key-value string pairs.
   /// - Returns: A ``ConfigRequest`` instance.
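
For reference, a minimal sketch (not part of the diff above) of the metadata mutation that BearerTokenInterceptor performs on every outgoing RPC. It only assumes grpc-swift v2's GRPCCore module, which provides the Metadata type and the addString(_:forKey:) call already used in the interceptor; the token string here is a placeholder.

import GRPCCore

// Reproduce the header that BearerTokenInterceptor attaches before
// forwarding the request to the next interceptor in the chain.
let token = "my-secret-token"  // placeholder; normally comes from
                               // SPARK_CONNECT_AUTHENTICATE_TOKEN
var metadata = Metadata()
metadata.addString("Bearer \(token)", forKey: "Authorization")

// The request metadata now carries the bearer token under the
// "Authorization" key, which the Spark Connect server compares against
// its own SPARK_CONNECT_AUTHENTICATE_TOKEN setting.
print(metadata)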