Skip to content

Commit

Permalink
fix(predictions): Ignore unknown events from liveness websocket (#2736)
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerjroach committed May 14, 2024
1 parent da814b0 commit cf285af
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import java.nio.ByteBuffer
import java.util.Arrays
import java.util.Date
import java.util.zip.CRC32
import kotlinx.serialization.decodeFromString
import kotlinx.serialization.json.Json
import okio.ByteString
import okio.ByteString.Companion.encodeUtf8
Expand Down Expand Up @@ -151,22 +150,20 @@ internal object LivenessEventStream {
val payloadLength = eventData.size - payloadStartPosition - 4
val payloadString =
eventData.substring(payloadStartPosition, payloadStartPosition + payloadLength).utf8()
val jsonString = when {
return when {
":event-type" in headers.keys -> {
"{\"${headers[":event-type"]}\":$payloadString}"
val jsonString = "{\"${headers[":event-type"]}\":$payloadString}"
json.decodeFromString<LivenessResponseStream.Event>(jsonString)
}
":exception-type" in headers.keys -> {
"{\"${headers[":exception-type"]}\":$payloadString}"
val jsonString = "{\"${headers[":exception-type"]}\":$payloadString}"
json.decodeFromString<LivenessResponseStream.Exception>(jsonString)
}
else -> {
""
LOG.error("Error deserializing liveness response.")
null
}
}
if (jsonString.isEmpty()) {
LOG.error("Error deserializing liveness response.")
return null
}
return json.decodeFromString<LivenessResponseStream>(jsonString)
}

private fun ByteArray.toUInt32(): UInt {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,19 +143,26 @@ internal class LivenessWebSocket(
override fun onMessage(webSocket: WebSocket, bytes: ByteString) {
LOG.debug("WebSocket onMessage bytes")
try {
val livenessResponseStream = LivenessEventStream.decode(bytes, json)
livenessResponseStream?.let { livenessResponse ->
if (livenessResponse.serverSessionInformationEvent != null) {
onSessionInformationReceived.accept(
livenessResponse.serverSessionInformationEvent.sessionInformation
)
} else if (livenessResponse.disconnectionEvent != null) {
this@LivenessWebSocket.webSocket?.close(
NORMAL_SOCKET_CLOSURE_STATUS_CODE,
"Liveness flow completed."
)
} else {
handleWebSocketError(livenessResponse)
when (val response = LivenessEventStream.decode(bytes, json)) {
is LivenessResponseStream.Event -> {
if (response.serverSessionInformationEvent != null) {
onSessionInformationReceived.accept(
response.serverSessionInformationEvent.sessionInformation
)
} else if (response.disconnectionEvent != null) {
this@LivenessWebSocket.webSocket?.close(
NORMAL_SOCKET_CLOSURE_STATUS_CODE,
"Liveness flow completed."
)
} else {
LOG.debug("WebSocket received unknown event-type: message from server")
}
}
is LivenessResponseStream.Exception -> {
handleWebSocketError(response)
}
else -> {
LOG.debug("WebSocket unable to decode message from server")
}
}
} catch (e: Exception) {
Expand Down Expand Up @@ -284,7 +291,7 @@ internal class LivenessWebSocket(
)
}

private fun handleWebSocketError(livenessResponse: LivenessResponseStream) {
private fun handleWebSocketError(livenessResponse: LivenessResponseStream.Exception) {
webSocketError = if (livenessResponse.validationException != null) {
PredictionsException(
"An error occurred during the face liveness flow.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,25 @@ import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

@Serializable
internal data class LivenessResponseStream(
@SerialName("ServerSessionInformationEvent") val serverSessionInformationEvent:
ServerSessionInformationEvent? = null,
@SerialName("DisconnectionEvent") val disconnectionEvent: DisconnectionEvent? = null,
@SerialName("ValidationException") val validationException: ValidationException? = null,
@SerialName("InternalServerException") val internalServerException: InternalServerException? = null,
@SerialName("ThrottlingException") val throttlingException: ThrottlingException? = null,
@SerialName("ServiceQuotaExceededException") val serviceQuotaExceededException:
ServiceQuotaExceededException? = null,
@SerialName("ServiceUnavailableException") val serviceUnavailableException: ServiceUnavailableException? = null,
@SerialName("SessionNotFoundException") val sessionNotFoundException: SessionNotFoundException? = null,
@SerialName("AccessDeniedException") val accessDeniedException: AccessDeniedException? = null,
@SerialName("InvalidSignatureException") val invalidSignatureException: InvalidSignatureException? = null,
@SerialName("UnrecognizedClientException") val unrecognizedClientException: UnrecognizedClientException? = null
)
internal sealed class LivenessResponseStream {
@Serializable
internal data class Event(
@SerialName("ServerSessionInformationEvent") val serverSessionInformationEvent:
ServerSessionInformationEvent? = null,
@SerialName("DisconnectionEvent") val disconnectionEvent: DisconnectionEvent? = null
) : LivenessResponseStream()

@Serializable
internal data class Exception(
@SerialName("ValidationException") val validationException: ValidationException? = null,
@SerialName("InternalServerException") val internalServerException: InternalServerException? = null,
@SerialName("ThrottlingException") val throttlingException: ThrottlingException? = null,
@SerialName("ServiceQuotaExceededException") val serviceQuotaExceededException:
ServiceQuotaExceededException? = null,
@SerialName("ServiceUnavailableException") val serviceUnavailableException: ServiceUnavailableException? = null,
@SerialName("SessionNotFoundException") val sessionNotFoundException: SessionNotFoundException? = null,
@SerialName("AccessDeniedException") val accessDeniedException: AccessDeniedException? = null,
@SerialName("InvalidSignatureException") val invalidSignatureException: InvalidSignatureException? = null,
@SerialName("UnrecognizedClientException") val unrecognizedClientException: UnrecognizedClientException? = null
) : LivenessResponseStream()
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ import org.junit.Test

internal class LivenessEventStreamTest {

private val json = Json { encodeDefaults = true }
private val json = Json {
encodeDefaults = true
ignoreUnknownKeys = true
}

@Test
fun `test basic model with string header`() {
Expand Down Expand Up @@ -128,7 +131,7 @@ internal class LivenessEventStreamTest {
":content-type" to "application/json",
":message-type" to "event"
)
val expectedResponse = LivenessResponseStream(serverSessionInformationEvent = event)
val expectedResponse = LivenessResponseStream.Event(serverSessionInformationEvent = event)

val data = json.encodeToString(event)
val encoded = LivenessEventStream.encode(data.toByteArray(), headers)
Expand All @@ -140,11 +143,11 @@ internal class LivenessEventStreamTest {
fun `test decoding DisconnectionEvent`() {
val event = DisconnectionEvent(timestampMillis = System.currentTimeMillis())
val headers = mapOf(
":exception-type" to "DisconnectionEvent",
":event-type" to "DisconnectionEvent",
":content-type" to "application/json",
":message-type" to "event"
)
val expectedResponse = LivenessResponseStream(disconnectionEvent = event)
val expectedResponse = LivenessResponseStream.Event(disconnectionEvent = event)

val data = json.encodeToString(event)
val encoded = LivenessEventStream.encode(data.toByteArray(), headers)
Expand All @@ -161,7 +164,7 @@ internal class LivenessEventStreamTest {
":content-type" to "application/json",
":message-type" to "event"
)
val expectedResponse = LivenessResponseStream(validationException = event)
val expectedResponse = LivenessResponseStream.Exception(validationException = event)

val data = json.encodeToString(event)
val encoded = LivenessEventStream.encode(data.toByteArray(), headers)
Expand All @@ -178,7 +181,7 @@ internal class LivenessEventStreamTest {
":content-type" to "application/json",
":message-type" to "event"
)
val expectedResponse = LivenessResponseStream(internalServerException = event)
val expectedResponse = LivenessResponseStream.Exception(internalServerException = event)

val data = json.encodeToString(event)
val encoded = LivenessEventStream.encode(data.toByteArray(), headers)
Expand All @@ -195,7 +198,7 @@ internal class LivenessEventStreamTest {
":content-type" to "application/json",
":message-type" to "event"
)
val expectedResponse = LivenessResponseStream(throttlingException = event)
val expectedResponse = LivenessResponseStream.Exception(throttlingException = event)

val data = json.encodeToString(event)
val encoded = LivenessEventStream.encode(data.toByteArray(), headers)
Expand All @@ -212,7 +215,7 @@ internal class LivenessEventStreamTest {
":content-type" to "application/json",
":message-type" to "event"
)
val expectedResponse = LivenessResponseStream(serviceQuotaExceededException = event)
val expectedResponse = LivenessResponseStream.Exception(serviceQuotaExceededException = event)

val data = json.encodeToString(event)
val encoded = LivenessEventStream.encode(data.toByteArray(), headers)
Expand All @@ -229,7 +232,7 @@ internal class LivenessEventStreamTest {
":content-type" to "application/json",
":message-type" to "event"
)
val expectedResponse = LivenessResponseStream(serviceUnavailableException = event)
val expectedResponse = LivenessResponseStream.Exception(serviceUnavailableException = event)

val data = json.encodeToString(event)
val encoded = LivenessEventStream.encode(data.toByteArray(), headers)
Expand All @@ -246,7 +249,7 @@ internal class LivenessEventStreamTest {
":content-type" to "application/json",
":message-type" to "event"
)
val expectedResponse = LivenessResponseStream(sessionNotFoundException = event)
val expectedResponse = LivenessResponseStream.Exception(sessionNotFoundException = event)

val data = json.encodeToString(event)
val encoded = LivenessEventStream.encode(data.toByteArray(), headers)
Expand All @@ -263,7 +266,41 @@ internal class LivenessEventStreamTest {
":content-type" to "application/json",
":message-type" to "event"
)
val expectedResponse = LivenessResponseStream(accessDeniedException = event)
val expectedResponse = LivenessResponseStream.Exception(accessDeniedException = event)

val data = json.encodeToString(event)
val encoded = LivenessEventStream.encode(data.toByteArray(), headers)
val decoded = LivenessEventStream.decode(encoded.array().toByteString(), json)

assertEquals(expectedResponse, decoded)
}

@Test
fun `test decoding unknown event`() {
val event = InternalServerException("error")
val headers = mapOf(
":event-type" to "Unknown",
":content-type" to "application/json",
":message-type" to "event"
)
val expectedResponse = LivenessResponseStream.Event() // empty response

val data = json.encodeToString(event)
val encoded = LivenessEventStream.encode(data.toByteArray(), headers)
val decoded = LivenessEventStream.decode(encoded.array().toByteString(), json)

assertEquals(expectedResponse, decoded)
}

@Test
fun `test decoding unknown exception`() {
val event = InternalServerException("error")
val headers = mapOf(
":exception-type" to "UnknownException",
":content-type" to "application/json",
":message-type" to "event"
)
val expectedResponse = LivenessResponseStream.Exception() // empty response

val data = json.encodeToString(event)
val encoded = LivenessEventStream.encode(data.toByteArray(), headers)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.resetMain
import kotlinx.coroutines.test.setMain
import kotlinx.serialization.Serializable
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import mockwebserver3.MockResponse
Expand Down Expand Up @@ -248,6 +249,50 @@ internal class LivenessWebSocketTest {
verify { onSessionInformationReceived.accept(event.sessionInformation) }
}

@Test
fun `unknown event-type ignored`() {
val webSocket = mockk<WebSocket>(relaxed = true)
val livenessWebSocket = createLivenessWebSocket()
livenessWebSocket.webSocket = webSocket
val event = UnknownEvent()
val headers = mapOf(
":event-type" to "UnknownEvent",
":content-type" to "application/json",
":message-type" to "event"
)

val data = json.encodeToString(event)
val encodedByteString = LivenessEventStream.encode(data.toByteArray(), headers).array().toByteString()

livenessWebSocket.webSocketListener.onMessage(mockk(), encodedByteString)

verify(exactly = 0) { onSessionInformationReceived.accept(any()) }
verify(exactly = 0) { onErrorReceived.accept(any()) }
verify(exactly = 0) { webSocket.close(any(), any()) }
}

@Test
fun `unknown exception-type closes websocket`() {
val webSocket = mockk<WebSocket>(relaxed = true)
val livenessWebSocket = createLivenessWebSocket()
livenessWebSocket.webSocket = webSocket
val event = UnknownEvent()
val headers = mapOf(
":exception-type" to "UnknownException",
":content-type" to "application/json",
":message-type" to "event"
)

val data = json.encodeToString(event)
val encodedByteString = LivenessEventStream.encode(data.toByteArray(), headers).array().toByteString()

livenessWebSocket.webSocketListener.onMessage(mockk(), encodedByteString)

verify(exactly = 0) { onSessionInformationReceived.accept(any()) }
verify(exactly = 0) { onErrorReceived.accept(any()) }
verify(exactly = 1) { webSocket.close(any(), any()) }
}

@Test
fun `disconnect event stops websocket`() {
val livenessWebSocket = createLivenessWebSocket()
Expand Down Expand Up @@ -450,6 +495,9 @@ class LatchingWebSocketResponseListener(
}
}

@Serializable
internal data class UnknownEvent(val name: String = "")

class ServerWebSocketListener : WebSocketListener() {
override fun onClosed(webSocket: WebSocket, code: Int, reason: String) {}
override fun onClosing(webSocket: WebSocket, code: Int, reason: String) {}
Expand Down

0 comments on commit cf285af

Please sign in to comment.