diff --git a/native/kotlin/api/kotlin/src/integrationTest/kotlin/MediaEndpointTest.kt b/native/kotlin/api/kotlin/src/integrationTest/kotlin/MediaEndpointTest.kt index e9acf0062..adf9eb06d 100644 --- a/native/kotlin/api/kotlin/src/integrationTest/kotlin/MediaEndpointTest.kt +++ b/native/kotlin/api/kotlin/src/integrationTest/kotlin/MediaEndpointTest.kt @@ -74,6 +74,70 @@ class MediaEndpointTest { restoreTestServer() } + @Test + fun testCreateMediaRequestWithProgressReporting() = runTest { + val progressUpdates = mutableListOf() + var uploadStarted = false + + val uploadListener = object : WpRequestExecutor.UploadListener { + override fun onProgressUpdate(uploadedBytes: Long, totalBytes: Long) { + progressUpdates.add(ProgressUpdate(uploadedBytes, totalBytes)) + } + + override fun onUploadStarted(cancellableUpload: WpRequestExecutor.CancellableUpload) { + uploadStarted = true + } + } + + val authProvider = WpAuthenticationProvider.staticWithUsernameAndPassword( + username = TestCredentials.INSTANCE.adminUsername, + password = TestCredentials.INSTANCE.adminPassword + ) + val requestExecutor = WpRequestExecutor( + fileResolver = FileResolverMock(), + uploadListener = uploadListener + ) + val clientWithProgress = WpApiClient( + wpOrgSiteApiRootUrl = TestCredentials.INSTANCE.apiRootUrl, + authProvider = authProvider, + requestExecutor = requestExecutor + ) + + val title = "Testing media upload with progress from Kotlin" + val response = clientWithProgress.request { requestBuilder -> + requestBuilder.media().create( + params = MediaCreateParams(title = title, filePath = "test_media.jpg") + ) + }.assertSuccessAndRetrieveData().data + + // Verify upload was successful + assertEquals(title, response.title.rendered) + + // Verify progress reporting worked + assert(uploadStarted) { "Upload should have started" } + assert(progressUpdates.isNotEmpty()) { "Should have received progress updates" } + + // Verify final progress shows completion + val finalProgress = progressUpdates.last() + assertEquals( + finalProgress.uploadedBytes, + finalProgress.totalBytes, + "Final progress should show upload complete" + ) + + // Verify progress never decreases. Note: The /media endpoint only supports + // single files, so this validates basic progress but not multi-file scenarios. + var previousBytes = 0L + progressUpdates.forEach { update -> + assert(update.uploadedBytes >= previousBytes) { + "Progress decreased from $previousBytes to ${update.uploadedBytes}" + } + previousBytes = update.uploadedBytes + } + + restoreTestServer() + } + fun mediaApiClient(): WpApiClient { val testCredentials = TestCredentials.INSTANCE val authProvider = WpAuthenticationProvider.staticWithUsernameAndPassword( @@ -89,6 +153,8 @@ class MediaEndpointTest { ) } + data class ProgressUpdate(val uploadedBytes: Long, val totalBytes: Long) + class FileResolverMock: FileResolver { // in order to properly resolve the file from the test assets, we need to do it in the following way override fun getFile(path: String): File? = diff --git a/native/kotlin/api/kotlin/src/main/kotlin/rs/wordpress/api/kotlin/WpRequestExecutor.kt b/native/kotlin/api/kotlin/src/main/kotlin/rs/wordpress/api/kotlin/WpRequestExecutor.kt index 6b52971e0..4e4d748a0 100644 --- a/native/kotlin/api/kotlin/src/main/kotlin/rs/wordpress/api/kotlin/WpRequestExecutor.kt +++ b/native/kotlin/api/kotlin/src/main/kotlin/rs/wordpress/api/kotlin/WpRequestExecutor.kt @@ -10,7 +10,6 @@ import okhttp3.MediaType.Companion.toMediaType import okhttp3.MultipartBody import okhttp3.OkHttp import okhttp3.Request -import okhttp3.RequestBody import okhttp3.RequestBody.Companion.asRequestBody import okhttp3.RequestBody.Companion.toRequestBody import uniffi.wp_api.InvalidSslErrorReason @@ -51,93 +50,57 @@ class WpRequestExecutor( wpNetworkRequestBody } ) - request.headerMap().toMap().forEach { (key, values) -> - values.forEach { value -> - requestBuilder.addHeader(key, value) - } - } - requestBuilder.addHeader( - USER_AGENT_HEADER_NAME, - uniffi.wp_api.defaultUserAgent("kotlin-okhttp/${OkHttp.VERSION}") - ) + addRequestHeaders(requestBuilder, request.headerMap()) val urlRequest = requestBuilder.build() - try { - httpClient.getClient().newCall(urlRequest).execute().use { response -> - return@withContext WpNetworkResponse( - body = response.body?.bytes() ?: ByteArray(0), - statusCode = response.code.toUShort(), - responseHeaderMap = WpNetworkHeaderMap.fromMultiMap(response.headers.toMultimap()), - requestUrl = request.url(), - requestHeaderMap = request.headerMap() - ) - } - } catch (e: SSLPeerUnverifiedException) { - throw requestExecutionFailedWith( - RequestExecutionErrorReason.invalidSSLError(e, urlRequest.url) - ) - } catch (e: UnknownHostException) { - throw requestExecutionFailedWith(RequestExecutionErrorReason.unknownHost(e)) - } catch (e: NoRouteToHostException) { - throw requestExecutionFailedWith(RequestExecutionErrorReason.noRouteToHost(e)) - } + executeRequestSafely(urlRequest, request.url(), request.headerMap()) } override suspend fun upload(request: WpMultipartFormRequest): WpNetworkResponse = withContext(dispatcher) { + val multipartBody = buildMultipartBody(request) + val bodyWithProgress = wrapWithProgressTracking(multipartBody) val requestBuilder = Request.Builder().url(request.url()) - val multipartBodyBuilder = MultipartBody.Builder() - .setType(MultipartBody.FORM) - request.fields().forEach { (k, v) -> - multipartBodyBuilder.addFormDataPart(k, v) - } - request.files().forEach { (name, fileInfo) -> - val file = fileResolver.getFile(fileInfo.filePath) - if (file == null || !file.canBeUploaded()) { - throw RequestExecutionException.MediaFileNotFound(filePath = fileInfo.filePath) - } - val mimeType = fileInfo.mimeType ?: "application/octet-stream" - val requestBody = getRequestBody(file, mimeType, uploadListener) - val filename = fileInfo.fileName ?: file.name - multipartBodyBuilder.addFormDataPart( - name = name, - filename = filename, - body = requestBody - ) - } - requestBuilder.method( - method = request.method().toString(), - body = multipartBodyBuilder.build() - ) - request.headerMap().toMap().forEach { (key, values) -> - values.forEach { value -> - requestBuilder.addHeader(key, value) - } - } + requestBuilder.method(request.method().toString(), bodyWithProgress) - val call = httpClient.getClient().newCall(requestBuilder.build()) - uploadListener?.onUploadStarted(CancellableCall(call)) - call.execute().use { response -> - return@withContext WpNetworkResponse( - body = response.body?.bytes() ?: ByteArray(0), - statusCode = response.code.toUShort(), - responseHeaderMap = WpNetworkHeaderMap.fromMultiMap(response.headers.toMultimap()), - requestUrl = request.url(), - requestHeaderMap = request.headerMap() - ) + addRequestHeaders(requestBuilder, request.headerMap()) + val urlRequest = requestBuilder.build() + + executeRequestSafely(urlRequest, request.url(), request.headerMap(), notifyUploadListener = true) + } + + private fun buildMultipartBody(request: WpMultipartFormRequest): MultipartBody { + val multipartBodyBuilder = MultipartBody.Builder().setType(MultipartBody.FORM) + + request.fields().forEach { (k, v) -> + multipartBodyBuilder.addFormDataPart(k, v) + } + + request.files().forEach { (name, fileInfo) -> + val file = fileResolver.getFile(fileInfo.filePath) + if (file == null || !file.canBeUploaded()) { + throw RequestExecutionException.MediaFileNotFound(filePath = fileInfo.filePath) } + val mimeType = fileInfo.mimeType ?: "application/octet-stream" + val filename = fileInfo.fileName ?: file.name + val requestBody = file.asRequestBody(mimeType.toMediaType()) + multipartBodyBuilder.addFormDataPart( + name = name, + filename = filename, + body = requestBody + ) } - private fun getRequestBody( - file: File, - mimeType: String, - uploadListener: UploadListener? - ): RequestBody { - val fileRequestBody = file.asRequestBody(mimeType.toMediaType()) + return multipartBodyBuilder.build() + } + + private fun wrapWithProgressTracking(multipartBody: MultipartBody): okhttp3.RequestBody { + // Wrap the entire multipart body for progress tracking + // This ensures progress is cumulative across all files, not per-file return if (uploadListener != null) { ProgressRequestBody( - delegate = fileRequestBody, + delegate = multipartBody, progressListener = object : ProgressRequestBody.ProgressListener { override fun onProgress(bytesWritten: Long, contentLength: Long) { uploadListener.onProgressUpdate(bytesWritten, contentLength) @@ -145,7 +108,55 @@ class WpRequestExecutor( } ) } else { - fileRequestBody + multipartBody + } + } + + private fun addRequestHeaders(requestBuilder: Request.Builder, headerMap: WpNetworkHeaderMap) { + headerMap.toMap().forEach { (key, values) -> + values.forEach { value -> + requestBuilder.addHeader(key, value) + } + } + // Use header() instead of addHeader() to ensure User-Agent cannot be overridden + requestBuilder.header( + USER_AGENT_HEADER_NAME, + uniffi.wp_api.defaultUserAgent("kotlin-okhttp/${OkHttp.VERSION}") + ) + } + + @Suppress("ThrowsCount") + private fun executeRequestSafely( + urlRequest: Request, + requestUrl: String, + requestHeaderMap: WpNetworkHeaderMap, + notifyUploadListener: Boolean = false + ): WpNetworkResponse { + try { + val call = httpClient.getClient().newCall(urlRequest) + + // Notify upload listener if this is an upload request + if (notifyUploadListener) { + uploadListener?.onUploadStarted(CancellableCall(call)) + } + + return call.execute().use { response -> + WpNetworkResponse( + body = response.body?.bytes() ?: ByteArray(0), + statusCode = response.code.toUShort(), + responseHeaderMap = WpNetworkHeaderMap.fromMultiMap(response.headers.toMultimap()), + requestUrl = requestUrl, + requestHeaderMap = requestHeaderMap + ) + } + } catch (e: SSLPeerUnverifiedException) { + throw requestExecutionFailedWith( + RequestExecutionErrorReason.invalidSSLError(e, urlRequest.url) + ) + } catch (e: UnknownHostException) { + throw requestExecutionFailedWith(RequestExecutionErrorReason.unknownHost(e)) + } catch (e: NoRouteToHostException) { + throw requestExecutionFailedWith(RequestExecutionErrorReason.noRouteToHost(e)) } } @@ -214,7 +225,7 @@ private fun RequestExecutionErrorReason.Companion.noRouteToHost(e: NoRouteToHost reason = e.localizedMessage ) -@Suppress("UNUSED_PARAMETER") +@Suppress("UNUSED_PARAMETER", "TooGenericExceptionCaught", "SwallowedException") private fun RequestExecutionErrorReason.Companion.invalidSSLError( e: SSLPeerUnverifiedException, // To avoid `SwallowedException` from Detekt requestUrl: HttpUrl @@ -225,18 +236,35 @@ private fun RequestExecutionErrorReason.Companion.invalidSSLError( // // We spin up a new connection that'll accept any certificate. The connection will then // contain all the details we need for the error. - val newConnection = requestUrl.toUrl().openConnection() as HttpsURLConnection - newConnection.setHostnameVerifier { _, _ -> return@setHostnameVerifier true } - newConnection.connect() - - // Certificate is parsed by the Rust shared implementation. - val certificates = newConnection.serverCertificates.map { parseCertificate(it.encoded) } - return RequestExecutionErrorReason.InvalidSslError( - reason = InvalidSslErrorReason.CertificateNotValidForName( - hostname = requestUrl.host, - presentedHostnames = listOfNotNull(certificates.first()?.commonName()) + return try { + val newConnection = requestUrl.toUrl().openConnection() as HttpsURLConnection + newConnection.setHostnameVerifier { _, _ -> true } + newConnection.connect() + + try { + // Certificate is parsed by the Rust shared implementation. + val certificates = newConnection.serverCertificates.map { parseCertificate(it.encoded) } + RequestExecutionErrorReason.InvalidSslError( + reason = InvalidSslErrorReason.CertificateNotValidForName( + hostname = requestUrl.host, + presentedHostnames = listOfNotNull(certificates.first()?.commonName()) + ) + ) + } finally { + newConnection.disconnect() + } + } catch (ex: Exception) { + // Fallback if certificate inspection fails due to network issues, cast failures, etc. + // We intentionally catch Exception here as we want to return a valid error response + // even if certificate inspection fails. The original SSL error (e parameter) is + // preserved in the calling context. This is a best-effort attempt to get cert details. + RequestExecutionErrorReason.InvalidSslError( + reason = InvalidSslErrorReason.CertificateNotValidForName( + hostname = requestUrl.host, + presentedHostnames = emptyList() + ) ) - ) + } } private fun requestExecutionFailedWith(reason: RequestExecutionErrorReason) =