Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,70 @@ class MediaEndpointTest {
restoreTestServer()
}

@Test
fun testCreateMediaRequestWithProgressReporting() = runTest {
val progressUpdates = mutableListOf<ProgressUpdate>()
var uploadStarted = false

val uploadListener = object : WpRequestExecutor.UploadListener {
override fun onProgressUpdate(uploadedBytes: Long, totalBytes: Long) {
progressUpdates.add(ProgressUpdate(uploadedBytes, totalBytes))
}

override fun onUploadStarted(cancellableUpload: WpRequestExecutor.CancellableUpload) {
uploadStarted = true
}
}

val authProvider = WpAuthenticationProvider.staticWithUsernameAndPassword(
username = TestCredentials.INSTANCE.adminUsername,
password = TestCredentials.INSTANCE.adminPassword
)
val requestExecutor = WpRequestExecutor(
fileResolver = FileResolverMock(),
uploadListener = uploadListener
)
val clientWithProgress = WpApiClient(
wpOrgSiteApiRootUrl = TestCredentials.INSTANCE.apiRootUrl,
authProvider = authProvider,
requestExecutor = requestExecutor
)

val title = "Testing media upload with progress from Kotlin"
val response = clientWithProgress.request { requestBuilder ->
requestBuilder.media().create(
params = MediaCreateParams(title = title, filePath = "test_media.jpg")
)
}.assertSuccessAndRetrieveData().data

// Verify upload was successful
assertEquals(title, response.title.rendered)

// Verify progress reporting worked
assert(uploadStarted) { "Upload should have started" }
assert(progressUpdates.isNotEmpty()) { "Should have received progress updates" }

// Verify final progress shows completion
val finalProgress = progressUpdates.last()
assertEquals(
finalProgress.uploadedBytes,
finalProgress.totalBytes,
"Final progress should show upload complete"
)

// Verify progress never decreases. Note: The /media endpoint only supports
// single files, so this validates basic progress but not multi-file scenarios.
var previousBytes = 0L
progressUpdates.forEach { update ->
assert(update.uploadedBytes >= previousBytes) {
"Progress decreased from $previousBytes to ${update.uploadedBytes}"
}
previousBytes = update.uploadedBytes
}

restoreTestServer()
}

fun mediaApiClient(): WpApiClient {
val testCredentials = TestCredentials.INSTANCE
val authProvider = WpAuthenticationProvider.staticWithUsernameAndPassword(
Expand All @@ -89,6 +153,8 @@ class MediaEndpointTest {
)
}

data class ProgressUpdate(val uploadedBytes: Long, val totalBytes: Long)

class FileResolverMock: FileResolver {
// in order to properly resolve the file from the test assets, we need to do it in the following way
override fun getFile(path: String): File? =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import okhttp3.MediaType.Companion.toMediaType
import okhttp3.MultipartBody
import okhttp3.OkHttp
import okhttp3.Request
import okhttp3.RequestBody
import okhttp3.RequestBody.Companion.asRequestBody
import okhttp3.RequestBody.Companion.toRequestBody
import uniffi.wp_api.InvalidSslErrorReason
Expand Down Expand Up @@ -51,101 +50,113 @@ class WpRequestExecutor(
wpNetworkRequestBody
}
)
request.headerMap().toMap().forEach { (key, values) ->
values.forEach { value ->
requestBuilder.addHeader(key, value)
}
}
requestBuilder.addHeader(
USER_AGENT_HEADER_NAME,
uniffi.wp_api.defaultUserAgent("kotlin-okhttp/${OkHttp.VERSION}")
)

addRequestHeaders(requestBuilder, request.headerMap())
val urlRequest = requestBuilder.build()

try {
httpClient.getClient().newCall(urlRequest).execute().use { response ->
return@withContext WpNetworkResponse(
body = response.body?.bytes() ?: ByteArray(0),
statusCode = response.code.toUShort(),
responseHeaderMap = WpNetworkHeaderMap.fromMultiMap(response.headers.toMultimap()),
requestUrl = request.url(),
requestHeaderMap = request.headerMap()
)
}
} catch (e: SSLPeerUnverifiedException) {
throw requestExecutionFailedWith(
RequestExecutionErrorReason.invalidSSLError(e, urlRequest.url)
)
} catch (e: UnknownHostException) {
throw requestExecutionFailedWith(RequestExecutionErrorReason.unknownHost(e))
} catch (e: NoRouteToHostException) {
throw requestExecutionFailedWith(RequestExecutionErrorReason.noRouteToHost(e))
}
executeRequestSafely(urlRequest, request.url(), request.headerMap())
}

override suspend fun upload(request: WpMultipartFormRequest): WpNetworkResponse =
withContext(dispatcher) {
val multipartBody = buildMultipartBody(request)
val bodyWithProgress = wrapWithProgressTracking(multipartBody)
val requestBuilder = Request.Builder().url(request.url())
val multipartBodyBuilder = MultipartBody.Builder()
.setType(MultipartBody.FORM)
request.fields().forEach { (k, v) ->
multipartBodyBuilder.addFormDataPart(k, v)
}
request.files().forEach { (name, fileInfo) ->
val file = fileResolver.getFile(fileInfo.filePath)
if (file == null || !file.canBeUploaded()) {
throw RequestExecutionException.MediaFileNotFound(filePath = fileInfo.filePath)
}
val mimeType = fileInfo.mimeType ?: "application/octet-stream"
val requestBody = getRequestBody(file, mimeType, uploadListener)
val filename = fileInfo.fileName ?: file.name
multipartBodyBuilder.addFormDataPart(
name = name,
filename = filename,
body = requestBody
)
}
requestBuilder.method(
method = request.method().toString(),
body = multipartBodyBuilder.build()
)
request.headerMap().toMap().forEach { (key, values) ->
values.forEach { value ->
requestBuilder.addHeader(key, value)
}
}
requestBuilder.method(request.method().toString(), bodyWithProgress)

val call = httpClient.getClient().newCall(requestBuilder.build())
uploadListener?.onUploadStarted(CancellableCall(call))
call.execute().use { response ->
return@withContext WpNetworkResponse(
body = response.body?.bytes() ?: ByteArray(0),
statusCode = response.code.toUShort(),
responseHeaderMap = WpNetworkHeaderMap.fromMultiMap(response.headers.toMultimap()),
requestUrl = request.url(),
requestHeaderMap = request.headerMap()
)
addRequestHeaders(requestBuilder, request.headerMap())
val urlRequest = requestBuilder.build()

executeRequestSafely(urlRequest, request.url(), request.headerMap(), notifyUploadListener = true)
}

private fun buildMultipartBody(request: WpMultipartFormRequest): MultipartBody {
val multipartBodyBuilder = MultipartBody.Builder().setType(MultipartBody.FORM)

request.fields().forEach { (k, v) ->
multipartBodyBuilder.addFormDataPart(k, v)
}

request.files().forEach { (name, fileInfo) ->
val file = fileResolver.getFile(fileInfo.filePath)
if (file == null || !file.canBeUploaded()) {
throw RequestExecutionException.MediaFileNotFound(filePath = fileInfo.filePath)
}
val mimeType = fileInfo.mimeType ?: "application/octet-stream"
val filename = fileInfo.fileName ?: file.name
val requestBody = file.asRequestBody(mimeType.toMediaType())
multipartBodyBuilder.addFormDataPart(
name = name,
filename = filename,
body = requestBody
)
}

private fun getRequestBody(
file: File,
mimeType: String,
uploadListener: UploadListener?
): RequestBody {
val fileRequestBody = file.asRequestBody(mimeType.toMediaType())
return multipartBodyBuilder.build()
}

private fun wrapWithProgressTracking(multipartBody: MultipartBody): okhttp3.RequestBody {
// Wrap the entire multipart body for progress tracking
// This ensures progress is cumulative across all files, not per-file
return if (uploadListener != null) {
ProgressRequestBody(
delegate = fileRequestBody,
delegate = multipartBody,
progressListener = object : ProgressRequestBody.ProgressListener {
override fun onProgress(bytesWritten: Long, contentLength: Long) {
uploadListener.onProgressUpdate(bytesWritten, contentLength)
}
}
)
} else {
fileRequestBody
multipartBody
}
}

private fun addRequestHeaders(requestBuilder: Request.Builder, headerMap: WpNetworkHeaderMap) {
headerMap.toMap().forEach { (key, values) ->
values.forEach { value ->
requestBuilder.addHeader(key, value)
}
}
// Use header() instead of addHeader() to ensure User-Agent cannot be overridden
requestBuilder.header(
USER_AGENT_HEADER_NAME,
uniffi.wp_api.defaultUserAgent("kotlin-okhttp/${OkHttp.VERSION}")
)
}

@Suppress("ThrowsCount")
private fun executeRequestSafely(
urlRequest: Request,
requestUrl: String,
requestHeaderMap: WpNetworkHeaderMap,
notifyUploadListener: Boolean = false
): WpNetworkResponse {
try {
val call = httpClient.getClient().newCall(urlRequest)

// Notify upload listener if this is an upload request
if (notifyUploadListener) {
uploadListener?.onUploadStarted(CancellableCall(call))
}

return call.execute().use { response ->
WpNetworkResponse(
body = response.body?.bytes() ?: ByteArray(0),
statusCode = response.code.toUShort(),
responseHeaderMap = WpNetworkHeaderMap.fromMultiMap(response.headers.toMultimap()),
requestUrl = requestUrl,
requestHeaderMap = requestHeaderMap
)
}
} catch (e: SSLPeerUnverifiedException) {
throw requestExecutionFailedWith(
RequestExecutionErrorReason.invalidSSLError(e, urlRequest.url)
)
} catch (e: UnknownHostException) {
throw requestExecutionFailedWith(RequestExecutionErrorReason.unknownHost(e))
} catch (e: NoRouteToHostException) {
throw requestExecutionFailedWith(RequestExecutionErrorReason.noRouteToHost(e))
}
}

Expand Down Expand Up @@ -214,7 +225,7 @@ private fun RequestExecutionErrorReason.Companion.noRouteToHost(e: NoRouteToHost
reason = e.localizedMessage
)

@Suppress("UNUSED_PARAMETER")
@Suppress("UNUSED_PARAMETER", "TooGenericExceptionCaught", "SwallowedException")
private fun RequestExecutionErrorReason.Companion.invalidSSLError(
e: SSLPeerUnverifiedException, // To avoid `SwallowedException` from Detekt
requestUrl: HttpUrl
Expand All @@ -225,18 +236,35 @@ private fun RequestExecutionErrorReason.Companion.invalidSSLError(
//
// We spin up a new connection that'll accept any certificate. The connection will then
// contain all the details we need for the error.
val newConnection = requestUrl.toUrl().openConnection() as HttpsURLConnection
newConnection.setHostnameVerifier { _, _ -> return@setHostnameVerifier true }
newConnection.connect()

// Certificate is parsed by the Rust shared implementation.
val certificates = newConnection.serverCertificates.map { parseCertificate(it.encoded) }
return RequestExecutionErrorReason.InvalidSslError(
reason = InvalidSslErrorReason.CertificateNotValidForName(
hostname = requestUrl.host,
presentedHostnames = listOfNotNull(certificates.first()?.commonName())
return try {
val newConnection = requestUrl.toUrl().openConnection() as HttpsURLConnection
newConnection.setHostnameVerifier { _, _ -> true }
newConnection.connect()

try {
// Certificate is parsed by the Rust shared implementation.
val certificates = newConnection.serverCertificates.map { parseCertificate(it.encoded) }
RequestExecutionErrorReason.InvalidSslError(
reason = InvalidSslErrorReason.CertificateNotValidForName(
hostname = requestUrl.host,
presentedHostnames = listOfNotNull(certificates.first()?.commonName())
)
)
} finally {
newConnection.disconnect()
}
} catch (ex: Exception) {
// Fallback if certificate inspection fails due to network issues, cast failures, etc.
// We intentionally catch Exception here as we want to return a valid error response
// even if certificate inspection fails. The original SSL error (e parameter) is
// preserved in the calling context. This is a best-effort attempt to get cert details.
RequestExecutionErrorReason.InvalidSslError(
reason = InvalidSslErrorReason.CertificateNotValidForName(
hostname = requestUrl.host,
presentedHostnames = emptyList()
)
)
)
}
}

private fun requestExecutionFailedWith(reason: RequestExecutionErrorReason) =
Expand Down