diff --git a/test/test_enforced_encryption.cpp b/test/test_enforced_encryption.cpp index b7095028b..bac8d147e 100644 --- a/test/test_enforced_encryption.cpp +++ b/test/test_enforced_encryption.cpp @@ -345,6 +345,11 @@ class TestEnforcedEncryption ASSERT_EQ(SetPassword(PEER_CALLER, test.password[PEER_CALLER]), SRT_SUCCESS); ASSERT_EQ(SetPassword(PEER_LISTENER, test.password[PEER_LISTENER]), SRT_SUCCESS); + // Determine the subcase for the KLUDGE (check the behavior of the decryption failure) + const bool case_pw_failure = test.password[PEER_CALLER] != test.password[PEER_LISTENER]; + const bool case_both_relaxed = !test.enforcedenc[PEER_LISTENER] && !test.enforcedenc[PEER_CALLER]; + const bool case_sender_enc = test.password[PEER_CALLER] != ""; + const TResult &expect = test.expected_result; // Start testing @@ -358,6 +363,8 @@ class TestEnforcedEncryption ASSERT_NE(srt_bind(m_listener_socket, psa, sizeof sa), SRT_ERROR); ASSERT_NE(srt_listen(m_listener_socket, 4), SRT_ERROR); + SRTSOCKET accepted_socket = -1; + auto accepting_thread = std::thread([&] { const int epoll_event = WaitOnEpoll(expect); @@ -366,7 +373,6 @@ class TestEnforcedEncryption // otherwise SRT_INVALID_SOCKET after the listening socket is closed. sockaddr_in client_address; int length = sizeof(sockaddr_in); - SRTSOCKET accepted_socket = -1; if (epoll_event == SRT_EPOLL_IN) { accepted_socket = srt_accept(m_listener_socket, (sockaddr*)&client_address, &length); @@ -484,6 +490,53 @@ class TestEnforcedEncryption EXPECT_EQ(srt_getsockstate(m_listener_socket), SRTS_LISTENING); EXPECT_EQ(GetKMState(m_listener_socket), SRT_KM_S_UNSECURED); + if (!is_blocking && case_both_relaxed && case_pw_failure && case_sender_enc) + { + // Additionally check decryption failure does not trigger read-readiness (see issue #2503). + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + EXPECT_FALSE(accepting_thread.joinable()); + + int const epollRead = srt_epoll_create(); + int events = SRT_EPOLL_IN | SRT_EPOLL_ERR; + srt_epoll_add_usock(epollRead, accepted_socket, &events); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + { + int const epollWrite = srt_epoll_create(); + events = SRT_EPOLL_OUT | SRT_EPOLL_ERR; + srt_epoll_add_usock(epollWrite, m_caller_socket, &events); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + SRTSOCKET srtSocket = SRT_INVALID_SOCK; + int socketNum = 1; + const int epoll_res_w = srt_epoll_wait(epollWrite, + nullptr, nullptr, // read + &srtSocket, &socketNum, // write + 500, + nullptr, nullptr, nullptr, nullptr); // R/W system sockets + std::cout << "W: " << epoll_res_w << std::endl; + + char buffer[1316] = {1, 2, 3, 4}; + ASSERT_NE(srt_sendmsg2(m_caller_socket, buffer, sizeof buffer, nullptr), SRT_ERROR); + std::this_thread::sleep_for(std::chrono::seconds(1)); + } + + SRTSOCKET srtSocket = SRT_INVALID_SOCK; + int socketNum = 1; + int epoll_res_r = srt_epoll_wait(epollRead, &srtSocket, &socketNum, nullptr, nullptr, 500, nullptr, nullptr, nullptr, nullptr); + std::cout << "R: " << epoll_res_r << std::endl; + EXPECT_LE(epoll_res_r, 0) << "It's wrongly reported, so let's take a look..."; + char buffer[1316] = {}; + EXPECT_EQ(srt_recvmsg2(accepted_socket, buffer, sizeof buffer, nullptr), -1); + + epoll_res_r = srt_epoll_wait(epollRead, &srtSocket, &socketNum, nullptr, nullptr, 500, nullptr, nullptr, nullptr, nullptr); + EXPECT_LE(epoll_res_r, 0) << "Another?!"; + //// ! /KLUDGE ! + + srt_epoll_release(epollRead); + } + if (is_blocking) { // srt_accept() has no timeout, so we have to close the socket and wait for the thread to exit.