diff --git a/AmplifyPlugins/Predictions/AWSPredictionsPlugin/Service/Predictions/AWSPredictionsService+Transcribe.swift b/AmplifyPlugins/Predictions/AWSPredictionsPlugin/Service/Predictions/AWSPredictionsService+Transcribe.swift index 8c4915918a..5c913e7b55 100644 --- a/AmplifyPlugins/Predictions/AWSPredictionsPlugin/Service/Predictions/AWSPredictionsService+Transcribe.swift +++ b/AmplifyPlugins/Predictions/AWSPredictionsPlugin/Service/Predictions/AWSPredictionsService+Transcribe.swift @@ -31,7 +31,17 @@ extension AWSPredictionsService: AWSTranscribeStreamingServiceBehavior { request.mediaSampleRateHertz = 8_000 transcribeClientDelegate.connectionStatusCallback = { status, error in - if status == .connected { + if status == .closed && error != nil { + guard error != nil else { + return + } + let nsError = error as NSError? + let predictionsError = PredictionsErrorHelper.mapPredictionsServiceError(nsError!) + if case .network = predictionsError { + onEvent(.failed(predictionsError)) + return + } + } else if status == .connected { let headers = [ ":content-type": "audio/wav", ":message-type": "event", diff --git a/AmplifyPlugins/Predictions/AWSPredictionsPlugin/Support/Internal/NativeWebSocketProvider.swift b/AmplifyPlugins/Predictions/AWSPredictionsPlugin/Support/Internal/NativeWebSocketProvider.swift index 5a7ab45647..dd42b1bdd6 100644 --- a/AmplifyPlugins/Predictions/AWSPredictionsPlugin/Support/Internal/NativeWebSocketProvider.swift +++ b/AmplifyPlugins/Predictions/AWSPredictionsPlugin/Support/Internal/NativeWebSocketProvider.swift @@ -76,6 +76,11 @@ class NativeWebSocketProvider: NSObject, AWSTranscribeStreamingWebSocketProvider self.callbackQueue.async { self.clientDelegate.connectionStatusDidChange(status, withError: error) } + + if error is URLError { + self.webSocketTask.cancel() + return + } case .success(let message): switch message { case .data(let data): diff --git a/AmplifyPlugins/Predictions/AWSPredictionsPlugin/Support/Utils/PredictionsErrorHelper.swift b/AmplifyPlugins/Predictions/AWSPredictionsPlugin/Support/Utils/PredictionsErrorHelper.swift index 8794fea049..1b23739d00 100644 --- a/AmplifyPlugins/Predictions/AWSPredictionsPlugin/Support/Utils/PredictionsErrorHelper.swift +++ b/AmplifyPlugins/Predictions/AWSPredictionsPlugin/Support/Utils/PredictionsErrorHelper.swift @@ -51,7 +51,8 @@ class PredictionsErrorHelper { ) } } - // swiftlint:disable cyclomatic_complexity + + // swiftlint:disable cyclomatic_complexity static func mapPredictionsServiceError(_ error: NSError) -> PredictionsError { let defaultError = PredictionsErrorHelper.getDefaultError(error) @@ -92,11 +93,33 @@ class PredictionsErrorHelper { return defaultError } return AWSTranscribeStreamingErrorMessage.map(errorType) ?? defaultError + case NSURLErrorDomain: + guard let urlError = error as? URLError else { + return defaultError + } + return mapUrlError(urlError) default: return defaultError } } + static func mapUrlError(_ urlError: URLError) -> PredictionsError { + + switch urlError.code { + case .cannotFindHost: + let errorDescription = "The host name for a URL couldn’t be resolved." + let recoverySuggestion = "Please check if you are reaching the correct host." + return PredictionsError.network(errorDescription, recoverySuggestion, urlError) + case .notConnectedToInternet: + // swiftlint:disable:next line_length + let errorDescription = "A network resource was requested, but an internet connection hasn’t been established and can’t be established automatically." + let recoverySuggestion = "Please check your network connectivity status." + return PredictionsError.network(errorDescription, recoverySuggestion, urlError) + default: + return PredictionsError.network(urlError.localizedDescription, "", urlError) + } + } + static func getDefaultError(_ error: NSError) -> PredictionsError { let errorMessage = """ Domain: [\(error.domain) diff --git a/AmplifyPlugins/Predictions/AWSPredictionsPluginTests/Mocks/Service/MockTranscribeStreamingBehavior.swift b/AmplifyPlugins/Predictions/AWSPredictionsPluginTests/Mocks/Service/MockTranscribeStreamingBehavior.swift index 3b25d32000..30525024df 100644 --- a/AmplifyPlugins/Predictions/AWSPredictionsPluginTests/Mocks/Service/MockTranscribeStreamingBehavior.swift +++ b/AmplifyPlugins/Predictions/AWSPredictionsPluginTests/Mocks/Service/MockTranscribeStreamingBehavior.swift @@ -5,6 +5,7 @@ // SPDX-License-Identifier: Apache-2.0 // +import XCTest import Amplify import AWSCore import AWSTranscribeStreaming @@ -14,13 +15,22 @@ class MockTranscribeBehavior: AWSTranscribeStreamingBehavior { var delegate: AWSTranscribeStreamingClientDelegate? var callbackQueue: DispatchQueue? + var connectionResult: AWSTranscribeStreamingClientConnectionStatus? var transcriptionResult: AWSTranscribeStreamingTranscriptResultStream? var error: Error? + var sendEndFrameExpection: XCTestExpectation? + func getTranscribeStreaming() -> AWSTranscribeStreaming { return AWSTranscribeStreaming() } + public func setConnectionResult(result: AWSTranscribeStreamingClientConnectionStatus, + error: Error? = nil) { + connectionResult = result + self.error = error + } + public func setError(error: Error) { transcriptionResult = nil self.error = error @@ -32,10 +42,18 @@ class MockTranscribeBehavior: AWSTranscribeStreamingBehavior { } func startTranscriptionWSS(request: AWSTranscribeStreamingStartStreamTranscriptionRequest) { - delegate?.didReceiveEvent(transcriptionResult, decodingError: error) + if connectionResult != nil && transcriptionResult != nil { + delegate?.connectionStatusDidChange(connectionResult!, withError: error) + delegate?.didReceiveEvent(transcriptionResult, decodingError: error) + } else if connectionResult != nil && transcriptionResult == nil { + delegate?.connectionStatusDidChange(connectionResult!, withError: error) + } else { + delegate?.didReceiveEvent(transcriptionResult, decodingError: error) + } } - func setDelegate(delegate: AWSTranscribeStreamingClientDelegate, callbackQueue: DispatchQueue) { + func setDelegate(delegate: AWSTranscribeStreamingClientDelegate, + callbackQueue: DispatchQueue) { self.delegate = delegate self.callbackQueue = callbackQueue } @@ -45,7 +63,9 @@ class MockTranscribeBehavior: AWSTranscribeStreamingBehavior { } func sendEndFrame() { - + if let sendEndFrameExpection = sendEndFrameExpection { + sendEndFrameExpection.fulfill() + } } func endTranscription() { diff --git a/AmplifyPlugins/Predictions/AWSPredictionsPluginTests/Service/PredictionsTest/PredictionsServiceTranscribeTests.swift b/AmplifyPlugins/Predictions/AWSPredictionsPluginTests/Service/PredictionsTest/PredictionsServiceTranscribeTests.swift index 5911a2ccae..7f1a2843c9 100644 --- a/AmplifyPlugins/Predictions/AWSPredictionsPluginTests/Service/PredictionsTest/PredictionsServiceTranscribeTests.swift +++ b/AmplifyPlugins/Predictions/AWSPredictionsPluginTests/Service/PredictionsTest/PredictionsServiceTranscribeTests.swift @@ -79,7 +79,11 @@ class PredictionsServiceTranscribeTests: XCTestCase { /// func testTranscribeService() { let mockResponse = createMockTranscribeResponse() + + mockTranscribe.setConnectionResult(result: AWSTranscribeStreamingClientConnectionStatus.connected, error: nil) + mockTranscribe.sendEndFrameExpection = expectation(description: "Sent end frame") mockTranscribe.setResult(result: mockResponse) + let expectedTranscription = "This is a test" let resultReceived = expectation(description: "Transcription result should be returned") @@ -126,6 +130,36 @@ class PredictionsServiceTranscribeTests: XCTestCase { waitForExpectations(timeout: 1) } + /// Test whether error is correctly propogated + /// + /// - Given: Predictions service with transcribe behavior + /// - When: + /// - I invoke an invalid request with Unreachable host + /// - Then: + /// - I should get back a connection error + /// + func testTranscribeServiceWithCannotFindHostError() { + let urlError = URLError(.cannotFindHost) + mockTranscribe.setConnectionResult(result: AWSTranscribeStreamingClientConnectionStatus.closed, error: urlError) + + let errorReceived = expectation(description: "Error should be returned") + + predictionsService.transcribe(speechToText: audioFile, language: .usEnglish) { event in + switch event { + case .completed(let result): + XCTFail("Should not produce result: \(result)") + case .failed(let error): + guard case .network = error else { + XCTFail("Should produce an network error instead of \(error)") + return + } + errorReceived.fulfill() + } + } + + waitForExpectations(timeout: 1) + } + /// Test if language from configuration is picked up /// /// - Given: Predictions service with transcribe behavior. And language is set in config