Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Codechange: encapsulate network error handling (and fix #9142) #9146

Merged
merged 3 commits into from May 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions src/network/core/CMakeLists.txt
Expand Up @@ -8,6 +8,7 @@ add_files(
game_info.h
host.cpp
host.h
os_abstraction.cpp
os_abstraction.h
packet.cpp
packet.h
Expand Down
24 changes: 12 additions & 12 deletions src/network/core/address.cpp
Expand Up @@ -317,7 +317,7 @@ static SOCKET ConnectLoopProc(addrinfo *runp)

SOCKET sock = socket(runp->ai_family, runp->ai_socktype, runp->ai_protocol);
if (sock == INVALID_SOCKET) {
DEBUG(net, 1, "[%s] could not create %s socket: %s", type, family, NetworkGetLastErrorString());
DEBUG(net, 1, "[%s] could not create %s socket: %s", type, family, NetworkError::GetLast().AsString());
return INVALID_SOCKET;
}

Expand All @@ -326,8 +326,8 @@ static SOCKET ConnectLoopProc(addrinfo *runp)
if (!SetNonBlocking(sock)) DEBUG(net, 0, "[%s] setting non-blocking mode failed", type);

int err = connect(sock, runp->ai_addr, (int)runp->ai_addrlen);
if (err != 0 && NetworkGetLastError() != EINPROGRESS) {
DEBUG(net, 1, "[%s] could not connect to %s over %s: %s", type, address.c_str(), family, NetworkGetLastErrorString());
if (err != 0 && !NetworkError::GetLast().IsConnectInProgress()) {
DEBUG(net, 1, "[%s] could not connect to %s over %s: %s", type, address.c_str(), family, NetworkError::GetLast().AsString());
closesocket(sock);
return INVALID_SOCKET;
}
Expand All @@ -343,7 +343,7 @@ static SOCKET ConnectLoopProc(addrinfo *runp)
tv.tv_sec = DEFAULT_CONNECT_TIMEOUT_SECONDS;
int n = select(FD_SETSIZE, NULL, &write_fd, NULL, &tv);
if (n < 0) {
DEBUG(net, 1, "[%s] could not connect to %s: %s", type, address.c_str(), NetworkGetLastErrorString());
DEBUG(net, 1, "[%s] could not connect to %s: %s", type, address.c_str(), NetworkError::GetLast().AsString());
closesocket(sock);
return INVALID_SOCKET;
}
Expand All @@ -356,9 +356,9 @@ static SOCKET ConnectLoopProc(addrinfo *runp)
}

/* Retrieve last error, if any, on the socket. */
err = GetSocketError(sock);
if (err != 0) {
DEBUG(net, 1, "[%s] could not connect to %s: %s", type, address.c_str(), NetworkGetErrorString(err));
NetworkError socket_error = GetSocketError(sock);
if (socket_error.HasError()) {
DEBUG(net, 1, "[%s] could not connect to %s: %s", type, address.c_str(), socket_error.AsString());
closesocket(sock);
return INVALID_SOCKET;
}
Expand Down Expand Up @@ -393,7 +393,7 @@ static SOCKET ListenLoopProc(addrinfo *runp)

SOCKET sock = socket(runp->ai_family, runp->ai_socktype, runp->ai_protocol);
if (sock == INVALID_SOCKET) {
DEBUG(net, 0, "[%s] could not create %s socket on port %s: %s", type, family, address.c_str(), NetworkGetLastErrorString());
DEBUG(net, 0, "[%s] could not create %s socket on port %s: %s", type, family, address.c_str(), NetworkError::GetLast().AsString());
return INVALID_SOCKET;
}

Expand All @@ -404,24 +404,24 @@ static SOCKET ListenLoopProc(addrinfo *runp)
int on = 1;
/* The (const char*) cast is needed for windows!! */
if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (const char*)&on, sizeof(on)) == -1) {
DEBUG(net, 3, "[%s] could not set reusable %s sockets for port %s: %s", type, family, address.c_str(), NetworkGetLastErrorString());
DEBUG(net, 3, "[%s] could not set reusable %s sockets for port %s: %s", type, family, address.c_str(), NetworkError::GetLast().AsString());
}

#ifndef __OS2__
if (runp->ai_family == AF_INET6 &&
setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, (const char*)&on, sizeof(on)) == -1) {
DEBUG(net, 3, "[%s] could not disable IPv4 over IPv6 on port %s: %s", type, address.c_str(), NetworkGetLastErrorString());
DEBUG(net, 3, "[%s] could not disable IPv4 over IPv6 on port %s: %s", type, address.c_str(), NetworkError::GetLast().AsString());
}
#endif

if (bind(sock, runp->ai_addr, (int)runp->ai_addrlen) != 0) {
DEBUG(net, 1, "[%s] could not bind on %s port %s: %s", type, family, address.c_str(), NetworkGetLastErrorString());
DEBUG(net, 1, "[%s] could not bind on %s port %s: %s", type, family, address.c_str(), NetworkError::GetLast().AsString());
closesocket(sock);
return INVALID_SOCKET;
}

if (runp->ai_socktype != SOCK_DGRAM && listen(sock, 1) != 0) {
DEBUG(net, 1, "[%s] could not listen at %s port %s: %s", type, family, address.c_str(), NetworkGetLastErrorString());
DEBUG(net, 1, "[%s] could not listen at %s port %s: %s", type, family, address.c_str(), NetworkError::GetLast().AsString());
closesocket(sock);
return INVALID_SOCKET;
}
Expand Down
18 changes: 0 additions & 18 deletions src/network/core/core.cpp
Expand Up @@ -13,7 +13,6 @@
#include "../../debug.h"
#include "os_abstraction.h"
#include "packet.h"
#include "../../string_func.h"

#include "../../safeguards.h"

Expand Down Expand Up @@ -48,20 +47,3 @@ void NetworkCoreShutdown()
WSACleanup();
#endif
}

#if defined(_WIN32)
/**
* Return the string representation of the given error from the OS's network functions.
* @param error The error number (from \c NetworkGetLastError()).
* @return The error message, potentially an empty string but never \c nullptr.
*/
const char *NetworkGetErrorString(int error)
{
static char buffer[512];
if (FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, error,
MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), buffer, sizeof(buffer), NULL) == 0) {
seprintf(buffer, lastof(buffer), "Unknown error %d", error);
}
return buffer;
}
#endif /* defined(_WIN32) */
174 changes: 174 additions & 0 deletions src/network/core/os_abstraction.cpp
@@ -0,0 +1,174 @@
/*
* This file is part of OpenTTD.
* OpenTTD is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, version 2.
* OpenTTD is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with OpenTTD. If not, see <http://www.gnu.org/licenses/>.
*/

/**
* @file os_abstraction.cpp OS specific implementations of functions of the OS abstraction layer for network stuff.
*
* The general idea is to have simple abstracting functions for things that
* require different implementations for different environments.
* In here the functions, and their documentation, are defined only once
* and the implementation contains the #ifdefs to change the implementation.
* Since Windows is usually different that is usually the first case, after
* that the behaviour is usually Unix/BSD-like with occasional variation.
*/

#include "stdafx.h"
#include "os_abstraction.h"
#include "../../string_func.h"
#include <mutex>

#include "../../safeguards.h"

/**
* Construct the network error with the given error code.
* @param error The error code.
*/
NetworkError::NetworkError(int error) : error(error)
{
}

/**
* Check whether this error describes that the operation would block.
* @return True iff the operation would block.
*/
bool NetworkError::WouldBlock() const
{
#if defined(_WIN32)
return this->error == WSAEWOULDBLOCK;
#else
/* Usually EWOULDBLOCK and EAGAIN are the same, but sometimes they are not
* and the POSIX.1 specification states that either should be checked. */
return this->error == EWOULDBLOCK || this->error == EAGAIN;
#endif
}

/**
* Check whether this error describes a connection reset.
* @return True iff the connection is reset.
*/
bool NetworkError::IsConnectionReset() const
{
#if defined(_WIN32)
return this->error == WSAECONNRESET;
#else
return this->error == ECONNRESET;
#endif
}

/**
* Check whether this error describes a connect is in progress.
* @return True iff the connect is already in progress.
*/
bool NetworkError::IsConnectInProgress() const
{
#if defined(_WIN32)
return this->error == WSAEWOULDBLOCK;
#else
return this->error == EINPROGRESS;
#endif
}

/**
* Get the string representation of the error message.
* @return The string representation that will get overwritten by next calls.
*/
const char *NetworkError::AsString() const
{
if (this->message.empty()) {
#if defined(_WIN32)
char buffer[512];
if (FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, this->error,
MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), buffer, sizeof(buffer), NULL) == 0) {
seprintf(buffer, lastof(buffer), "Unknown error %d", this->error);
}
this->message.assign(buffer);
#else
/* Make strerror thread safe by locking access to it. There is a thread safe strerror_r, however
* the non-POSIX variant is available due to defining _GNU_SOURCE meaning it is not portable.
* The problem with the non-POSIX variant is that it does not necessarily fill the buffer with
* the error message but can also return a pointer to a static bit of memory, whereas the POSIX
* variant always fills the buffer. This makes the behaviour too erratic to work with. */
static std::mutex mutex;
std::lock_guard<std::mutex> guard(mutex);
this->message.assign(strerror(this->error));
#endif
}
return this->message.c_str();
}

/**
* Check whether an error was actually set.
* @return True iff an error was set.
*/
bool NetworkError::HasError() const
{
return this->error != 0;
}

/**
* Get the last network error.
* @return The network error.
*/
/* static */ NetworkError NetworkError::GetLast()
{
#if defined(_WIN32)
return NetworkError(WSAGetLastError());
#elif defined(__OS2__)
return NetworkError(sock_errno());
#else
return NetworkError(errno);
#endif
}


/**
* Try to set the socket into non-blocking mode.
* @param d The socket to set the non-blocking more for.
* @return True if setting the non-blocking mode succeeded, otherwise false.
*/
bool SetNonBlocking(SOCKET d)
{
#if defined(_WIN32)
u_long nonblocking = 1;
return ioctlsocket(d, FIONBIO, &nonblocking) == 0;
#elif defined __EMSCRIPTEN__
return true;
#else
int nonblocking = 1;
return ioctl(d, FIONBIO, &nonblocking) == 0;
#endif
}

/**
* Try to set the socket to not delay sending.
* @param d The socket to disable the delaying for.
* @return True if disabling the delaying succeeded, otherwise false.
*/
bool SetNoDelay(SOCKET d)
{
#ifdef __EMSCRIPTEN__
return true;
#else
int flags = 1;
/* The (const char*) cast is needed for windows */
return setsockopt(d, IPPROTO_TCP, TCP_NODELAY, (const char *)&flags, sizeof(flags)) == 0;
#endif
}

/**
* Get the error from a socket, if any.
* @param d The socket to get the error from.
* @return The errno on the socket.
*/
NetworkError GetSocketError(SOCKET d)
{
int err;
socklen_t len = sizeof(err);
getsockopt(d, SOL_SOCKET, SO_ERROR, (char *)&err, &len);

return NetworkError(err);
}