diff --git a/src/AsyncWebSocket.cpp b/src/AsyncWebSocket.cpp index 16a51b4b..4f099789 100644 --- a/src/AsyncWebSocket.cpp +++ b/src/AsyncWebSocket.cpp @@ -301,7 +301,6 @@ AsyncWebSocketClient::~AsyncWebSocketClient() { _messageQueue.clear(); _controlQueue.clear(); } - _server->_handleEvent(this, WS_EVT_DISCONNECT, NULL, NULL, 0); } void AsyncWebSocketClient::_clearQueue() { @@ -358,7 +357,6 @@ void AsyncWebSocketClient::_onAck(size_t len, uint32_t time) { void AsyncWebSocketClient::_onPoll() { asyncsrv::unique_lock_type lock(_lock); - if (!_client) { return; } @@ -446,7 +444,6 @@ bool AsyncWebSocketClient::canSend() const { bool AsyncWebSocketClient::_queueControl(uint8_t opcode, const uint8_t *data, size_t len, bool mask) { asyncsrv::lock_guard_type lock(_lock); - if (!_client) { return false; } @@ -463,7 +460,6 @@ bool AsyncWebSocketClient::_queueControl(uint8_t opcode, const uint8_t *data, si bool AsyncWebSocketClient::_queueMessage(AsyncWebSocketSharedBuffer buffer, uint8_t opcode, bool mask) { asyncsrv::unique_lock_type lock(_lock); - if (!_client || !buffer || buffer->empty() || _status != WS_CONNECTED) { return false; } @@ -502,14 +498,16 @@ bool AsyncWebSocketClient::_queueMessage(AsyncWebSocketSharedBuffer buffer, uint } void AsyncWebSocketClient::close(uint16_t code, const char *message) { - if (_status != WS_CONNECTED) { - return; + { + asyncsrv::lock_guard_type lock(_lock); + if (_status != WS_CONNECTED) { + return; + } + _status = WS_DISCONNECTING; } async_ws_log_w("[%s][%" PRIu32 "] CLOSE", _server->url(), _clientId); - _status = WS_DISCONNECTING; - if (code) { uint8_t packetLen = 2; if (message != NULL) { @@ -538,6 +536,7 @@ void AsyncWebSocketClient::close(uint16_t code, const char *message) { } bool AsyncWebSocketClient::ping(const uint8_t *data, size_t len) { + asyncsrv::lock_guard_type lock(_lock); return _status == WS_CONNECTED && _queueControl(WS_PING, data, len); } @@ -567,9 +566,10 @@ void AsyncWebSocketClient::_onData(void *pbuf, size_t plen) { uint8_t *data = (uint8_t *)pbuf; while (plen > 0) { + const AwsClientStatus client_status = status(); async_ws_log_v( "[%s][%" PRIu32 "] DATA plen: %" PRIu32 ", _pstate: %" PRIu8 ", _status: %" PRIu8, _server->url(), _clientId, static_cast(plen), _pstate, - static_cast(_status) + static_cast(client_status) ); if (_pstate == STATE_FRAME_START) { @@ -688,10 +688,13 @@ void AsyncWebSocketClient::_onData(void *pbuf, size_t plen) { _server->_handleEvent(this, WS_EVT_ERROR, (void *)&reasonCode, (uint8_t *)reasonString, strlen(reasonString)); } } + asyncsrv::unique_lock_type lock(_lock); if (_status == WS_DISCONNECTING) { _status = WS_DISCONNECTED; if (_client) { - _client->close(); + auto *client = _client; + lock.unlock(); + client->close(); } } else { _status = WS_DISCONNECTING; @@ -735,9 +738,12 @@ void AsyncWebSocketClient::_onData(void *pbuf, size_t plen) { "[%s][%" PRIu32 "] DATA frame error: len: %u, index: %" PRIu64 ", total: %" PRIu64 "\n", _server->url(), _clientId, datalen, _pinfo.index, _pinfo.len ); - _status = WS_DISCONNECTING; - if (_client) { - _client->ackLater(); + { + asyncsrv::lock_guard_type lock(_lock); + _status = WS_DISCONNECTING; + if (_client) { + _client->ackLater(); + } } _queueControl(WS_DISCONNECT, data, datalen); break; @@ -952,7 +958,6 @@ bool AsyncWebSocketClient::binary(const __FlashStringHelper *data, size_t len) { IPAddress AsyncWebSocketClient::remoteIP() const { asyncsrv::lock_guard_type lock(_lock); - if (!_client) { return IPAddress((uint32_t)0U); } @@ -962,7 +967,6 @@ IPAddress AsyncWebSocketClient::remoteIP() const { uint16_t AsyncWebSocketClient::remotePort() const { asyncsrv::lock_guard_type lock(_lock); - if (!_client) { return 0; } @@ -991,14 +995,12 @@ AsyncWebSocketClient *AsyncWebSocket::_newClient(AsyncWebServerRequest *request) } void AsyncWebSocket::_handleDisconnect(AsyncWebSocketClient *client) { - asyncsrv::lock_guard_type lock(_lock); - const auto client_id = client->id(); - const auto iter = std::find_if(std::begin(_clients), std::end(_clients), [client_id](const AsyncWebSocketClient &c) { - return c.id() == client_id; - }); - if (iter != std::end(_clients)) { - _clients.erase(iter); - } + // Defer removal to cleanupClients(). Disconnect callbacks can fire while + // iterating _clients for broadcast sends, and erasing here invalidates the + // active iterator in the caller. However, emit the disconnect event now so + // applications observe the disconnect at the time it happens even though the + // client object remains in _clients until cleanup. + _handleEvent(client, WS_EVT_DISCONNECT, NULL, NULL, 0); } bool AsyncWebSocket::availableForWriteAll() { @@ -1055,17 +1057,30 @@ void AsyncWebSocket::closeAll(uint16_t code, const char *message) { } void AsyncWebSocket::cleanupClients(uint16_t maxClients) { - asyncsrv::lock_guard_type lock(_lock); - const size_t c = count(); - if (c > maxClients) { - async_ws_log_v("[%s] CLEANUP %" PRIu32 " (%u/%" PRIu16 ")", _url.c_str(), _clients.front().id(), c, maxClients); - _clients.front().close(); - } + std::list removed_clients; + { + asyncsrv::lock_guard_type lock(_lock); + const size_t connected = std::count_if(std::begin(_clients), std::end(_clients), [](const AsyncWebSocketClient &c) { + return c.status() == WS_CONNECTED; + }); + + if (connected > maxClients) { + const auto connected_iter = std::find_if(std::begin(_clients), std::end(_clients), [](const AsyncWebSocketClient &c) { + return c.status() == WS_CONNECTED; + }); + if (connected_iter != std::end(_clients)) { + async_ws_log_v("[%s] CLEANUP %" PRIu32 " (%u/%" PRIu16 ")", _url.c_str(), connected_iter->id(), connected, maxClients); + connected_iter->close(); + } + } - for (auto i = _clients.begin(); i != _clients.end(); ++i) { - if (i->shouldBeDeleted()) { - _clients.erase(i); - break; + for (auto iter = _clients.begin(); iter != _clients.end();) { + if (iter->shouldBeDeleted()) { + auto current = iter++; + removed_clients.splice(removed_clients.end(), _clients, current); + } else { + ++iter; + } } } } diff --git a/src/AsyncWebSocket.h b/src/AsyncWebSocket.h index afed25f4..1740a3ec 100644 --- a/src/AsyncWebSocket.h +++ b/src/AsyncWebSocket.h @@ -256,6 +256,7 @@ class AsyncWebSocketClient { return _clientId; } AwsClientStatus status() const { + asyncsrv::lock_guard_type lock(_lock); return _status; } AsyncClient *client() {