Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assorted C-string -> std::string conversions in network #9371

Merged
merged 9 commits into from Jun 15, 2021
2 changes: 1 addition & 1 deletion src/game/game_text.cpp
Expand Up @@ -231,7 +231,7 @@ GameStrings *LoadTranslations()
basename.erase(e + 1);

std::string filename = basename + "lang" PATHSEP "english.txt";
if (!FioCheckFileExists(filename.c_str() , GAME_DIR)) return nullptr;
if (!FioCheckFileExists(filename, GAME_DIR)) return nullptr;

auto ls = ReadRawLanguageStrings(filename);
if (!ls.IsValid()) return nullptr;
Expand Down
47 changes: 17 additions & 30 deletions src/network/core/address.cpp
Expand Up @@ -19,15 +19,15 @@
* IPv4 dotted representation is given.
* @return the hostname
*/
const char *NetworkAddress::GetHostname()
const std::string &NetworkAddress::GetHostname()
{
if (this->hostname.empty() && this->address.ss_family != AF_UNSPEC) {
assert(this->address_length != 0);
char buffer[NETWORK_HOSTNAME_LENGTH];
getnameinfo((struct sockaddr *)&this->address, this->address_length, buffer, sizeof(buffer), nullptr, 0, NI_NUMERICHOST);
this->hostname = buffer;
}
return this->hostname.c_str();
return this->hostname;
}

/**
Expand Down Expand Up @@ -71,26 +71,17 @@ void NetworkAddress::SetPort(uint16 port)
}

/**
* Get the address as a string, e.g. 127.0.0.1:12345.
* @param buffer the buffer to write to
* @param last the last element in the buffer
* @param with_family whether to add the family (e.g. IPvX).
* Helper to get the formatting string of an address for a given family.
* @param family The family to get the address format for.
* @param with_family Whether to add the familty to the address (e.g. IPv4).
* @return The format string for the address.
*/
void NetworkAddress::GetAddressAsString(char *buffer, const char *last, bool with_family)
static const char *GetAddressFormatString(uint16 family, bool with_family)
{
if (this->GetAddress()->ss_family == AF_INET6) buffer = strecpy(buffer, "[", last);
buffer = strecpy(buffer, this->GetHostname(), last);
if (this->GetAddress()->ss_family == AF_INET6) buffer = strecpy(buffer, "]", last);
buffer += seprintf(buffer, last, ":%d", this->GetPort());

if (with_family) {
char family;
switch (this->address.ss_family) {
case AF_INET: family = '4'; break;
case AF_INET6: family = '6'; break;
default: family = '?'; break;
}
seprintf(buffer, last, " (IPv%c)", family);
switch (family) {
case AF_INET: return with_family ? "{}:{} (IPv4)" : "{}:{}";
case AF_INET6: return with_family ? "[{}]:{} (IPv6)" : "[{}]:{}";
default: return with_family ? "{}:{} (IPv?)" : "{}:{}";
}
}

Expand All @@ -101,10 +92,7 @@ void NetworkAddress::GetAddressAsString(char *buffer, const char *last, bool wit
*/
std::string NetworkAddress::GetAddressAsString(bool with_family)
{
/* 7 extra are for with_family, which adds " (IPvX)". */
char buf[NETWORK_HOSTNAME_PORT_LENGTH + 7];
this->GetAddressAsString(buf, lastof(buf), with_family);
return buf;
return fmt::format(GetAddressFormatString(this->GetAddress()->ss_family, with_family), this->GetHostname(), this->GetPort());
}

/**
Expand Down Expand Up @@ -155,7 +143,7 @@ bool NetworkAddress::IsFamily(int family)
* @note netmask without /n assumes all bits need to match.
* @return true if this IP is within the netmask.
*/
bool NetworkAddress::IsInNetmask(const char *netmask)
bool NetworkAddress::IsInNetmask(const std::string &netmask)
{
/* Resolve it if we didn't do it already */
if (!this->IsResolved()) this->GetAddress();
Expand All @@ -165,16 +153,15 @@ bool NetworkAddress::IsInNetmask(const char *netmask)
NetworkAddress mask_address;

/* Check for CIDR separator */
const char *chr_cidr = strchr(netmask, '/');
if (chr_cidr != nullptr) {
int tmp_cidr = atoi(chr_cidr + 1);
auto cidr_separator_location = netmask.find('/');
if (cidr_separator_location != std::string::npos) {
int tmp_cidr = atoi(netmask.substr(cidr_separator_location + 1).c_str());

/* Invalid CIDR, treat as single host */
if (tmp_cidr > 0 && tmp_cidr < cidr) cidr = tmp_cidr;

/* Remove the / so that NetworkAddress works on the IP portion */
std::string ip_str(netmask, chr_cidr - netmask);
mask_address = NetworkAddress(ip_str.c_str(), 0, this->address.ss_family);
mask_address = NetworkAddress(netmask.substr(0, cidr_separator_location), 0, this->address.ss_family);
} else {
mask_address = NetworkAddress(netmask, 0, this->address.ss_family);
}
Expand Down
5 changes: 2 additions & 3 deletions src/network/core/address.h
Expand Up @@ -88,8 +88,7 @@ class NetworkAddress {
this->SetPort(port);
}

const char *GetHostname();
void GetAddressAsString(char *buffer, const char *last, bool with_family = true);
const std::string &GetHostname();
std::string GetAddressAsString(bool with_family = true);
const sockaddr_storage *GetAddress();

Expand Down Expand Up @@ -117,7 +116,7 @@ class NetworkAddress {
}

bool IsFamily(int family);
bool IsInNetmask(const char *netmask);
bool IsInNetmask(const std::string &netmask);

/**
* Compare the address of this class with the address of another.
Expand Down
24 changes: 13 additions & 11 deletions src/network/core/game_info.cpp
Expand Up @@ -38,7 +38,7 @@ NetworkServerGameInfo _network_game_info; ///< Information about our game.
* Get the network version string used by this build.
* The returned string is guaranteed to be at most NETWORK_REVISON_LENGTH bytes including '\0' terminator.
*/
const char *GetNetworkRevisionString()
std::string_view GetNetworkRevisionString()
{
static std::string network_revision;

Expand All @@ -65,36 +65,38 @@ const char *GetNetworkRevisionString()
Debug(net, 3, "Network revision name: {}", network_revision);
}

return network_revision.c_str();
return network_revision;
}

/**
* Extract the git hash from the revision string.
* @param revstr The revision string (formatted as DATE-BRANCH-GITHASH).
* @param revision_string The revision string (formatted as DATE-BRANCH-GITHASH).
* @return The git has part of the revision.
*/
static const char *ExtractNetworkRevisionHash(const char *revstr)
static std::string_view ExtractNetworkRevisionHash(std::string_view revision_string)
{
return strrchr(revstr, '-');
size_t index = revision_string.find_last_of('-');
if (index == std::string::npos) return {};
return revision_string.substr(index);
}

/**
* Checks whether the given version string is compatible with our version.
* First tries to match the full string, if that fails, attempts to compare just git hashes.
* @param other the version string to compare to
*/
bool IsNetworkCompatibleVersion(const char *other)
bool IsNetworkCompatibleVersion(std::string_view other)
{
if (strncmp(GetNetworkRevisionString(), other, NETWORK_REVISION_LENGTH - 1) == 0) return true;
if (GetNetworkRevisionString() == other) return true;

/* If this version is tagged, then the revision string must be a complete match,
* since there is no git hash suffix in it.
* This is needed to avoid situations like "1.9.0-beta1" comparing equal to "2.0.0-beta1". */
if (_openttd_revision_tagged) return false;

const char *hash1 = ExtractNetworkRevisionHash(GetNetworkRevisionString());
const char *hash2 = ExtractNetworkRevisionHash(other);
return hash1 != nullptr && hash2 != nullptr && strncmp(hash1, hash2, GITHASH_SUFFIX_LEN) == 0;
std::string_view hash1 = ExtractNetworkRevisionHash(GetNetworkRevisionString());
std::string_view hash2 = ExtractNetworkRevisionHash(other);
return hash1 == hash2;
}

/**
Expand All @@ -103,7 +105,7 @@ bool IsNetworkCompatibleVersion(const char *other)
void CheckGameCompatibility(NetworkGameInfo &ngi)
{
/* Check if we are allowed on this server based on the revision-check. */
ngi.version_compatible = IsNetworkCompatibleVersion(ngi.server_revision.c_str());
ngi.version_compatible = IsNetworkCompatibleVersion(ngi.server_revision);
ngi.compatible = ngi.version_compatible;

/* Check if we have all the GRFs on the client-system too. */
Expand Down
4 changes: 2 additions & 2 deletions src/network/core/game_info.h
Expand Up @@ -89,8 +89,8 @@ struct NetworkGameInfo : NetworkServerGameInfo {

extern NetworkServerGameInfo _network_game_info;

const char *GetNetworkRevisionString();
bool IsNetworkCompatibleVersion(const char *other);
std::string_view GetNetworkRevisionString();
bool IsNetworkCompatibleVersion(std::string_view other);
void CheckGameCompatibility(NetworkGameInfo &ngi);

void FillStaticNetworkServerGameInfo();
Expand Down
4 changes: 2 additions & 2 deletions src/network/core/os_abstraction.cpp
Expand Up @@ -76,7 +76,7 @@ bool NetworkError::IsConnectInProgress() const
* Get the string representation of the error message.
* @return The string representation that will get overwritten by next calls.
*/
const char *NetworkError::AsString() const
const std::string &NetworkError::AsString() const
{
if (this->message.empty()) {
#if defined(_WIN32)
Expand All @@ -97,7 +97,7 @@ const char *NetworkError::AsString() const
this->message.assign(strerror(this->error));
#endif
}
return this->message.c_str();
return this->message;
}

/**
Expand Down
2 changes: 1 addition & 1 deletion src/network/core/os_abstraction.h
Expand Up @@ -29,7 +29,7 @@ class NetworkError {
bool WouldBlock() const;
bool IsConnectionReset() const;
bool IsConnectInProgress() const;
const char *AsString() const;
const std::string &AsString() const;

static NetworkError GetLast();
};
Expand Down
2 changes: 1 addition & 1 deletion src/network/core/tcp_connect.cpp
Expand Up @@ -183,7 +183,7 @@ void TCPConnecter::Resolve()
auto start = std::chrono::steady_clock::now();

addrinfo *ai;
int error = getaddrinfo(address.GetHostname(), port_name, &hints, &ai);
int error = getaddrinfo(address.GetHostname().c_str(), port_name, &hints, &ai);

auto end = std::chrono::steady_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::seconds>(end - start);
Expand Down
15 changes: 6 additions & 9 deletions src/network/core/tcp_http.cpp
Expand Up @@ -32,7 +32,7 @@ static std::vector<NetworkHTTPSocketHandler *> _http_connections;
* @param depth the depth (redirect recursion) of the queries
*/
NetworkHTTPSocketHandler::NetworkHTTPSocketHandler(SOCKET s,
HTTPCallback *callback, const char *host, const char *url,
HTTPCallback *callback, const std::string &host, const char *url,
const char *data, int depth) :
NetworkSocketHandler(),
recv_pos(0),
Expand All @@ -42,19 +42,16 @@ NetworkHTTPSocketHandler::NetworkHTTPSocketHandler(SOCKET s,
redirect_depth(depth),
sock(s)
{
size_t bufferSize = strlen(url) + strlen(host) + strlen(GetNetworkRevisionString()) + (data == nullptr ? 0 : strlen(data)) + 128;
char *buffer = AllocaM(char, bufferSize);

Debug(net, 5, "[tcp/http] Requesting {}{}", host, url);
std::string request;
if (data != nullptr) {
seprintf(buffer, buffer + bufferSize - 1, "POST %s HTTP/1.0\r\nHost: %s\r\nUser-Agent: OpenTTD/%s\r\nContent-Type: text/plain\r\nContent-Length: %d\r\n\r\n%s\r\n", url, host, GetNetworkRevisionString(), (int)strlen(data), data);
request = fmt::format("POST {} HTTP/1.0\r\nHost: {}\r\nUser-Agent: OpenTTD/{}\r\nContent-Type: text/plain\r\nContent-Length: {}\r\n\r\n{}\r\n", url, host, GetNetworkRevisionString(), strlen(data), data);
} else {
seprintf(buffer, buffer + bufferSize - 1, "GET %s HTTP/1.0\r\nHost: %s\r\nUser-Agent: OpenTTD/%s\r\n\r\n", url, host, GetNetworkRevisionString());
request = fmt::format("GET {} HTTP/1.0\r\nHost: {}\r\nUser-Agent: OpenTTD/{}\r\n\r\n", url, host, GetNetworkRevisionString());
TrueBrain marked this conversation as resolved.
Show resolved Hide resolved
}

ssize_t size = strlen(buffer);
ssize_t res = send(this->sock, (const char*)buffer, size, 0);
if (res != size) {
ssize_t res = send(this->sock, request.data(), (int)request.size(), 0);
if (res != (ssize_t)request.size()) {
/* Sending all data failed. Socket can't handle this little bit
* of information? Just fall back to the old system! */
this->callback->OnFailure();
Expand Down
4 changes: 2 additions & 2 deletions src/network/core/tcp_http.h
Expand Up @@ -61,7 +61,7 @@ class NetworkHTTPSocketHandler : public NetworkSocketHandler {
void CloseSocket();

NetworkHTTPSocketHandler(SOCKET sock, HTTPCallback *callback,
const char *host, const char *url, const char *data, int depth);
const std::string &host, const char *url, const char *data, int depth);

~NetworkHTTPSocketHandler();

Expand Down Expand Up @@ -112,7 +112,7 @@ class NetworkHTTPContentConnecter : TCPConnecter {

void OnConnect(SOCKET s) override
{
new NetworkHTTPSocketHandler(s, this->callback, this->hostname.c_str(), this->url, this->data, this->depth);
new NetworkHTTPSocketHandler(s, this->callback, this->hostname, this->url, this->data, this->depth);
/* We've relinquished control of data now. */
this->data = nullptr;
}
Expand Down
2 changes: 1 addition & 1 deletion src/network/core/tcp_listen.h
Expand Up @@ -56,7 +56,7 @@ class TCPListenHandler {
/* Check if the client is banned */
bool banned = false;
for (const auto &entry : _network_ban_list) {
banned = address.IsInNetmask(entry.c_str());
banned = address.IsInNetmask(entry);
if (banned) {
Packet p(Tban_packet);
p.PrepareToSend();
Expand Down
4 changes: 2 additions & 2 deletions src/network/network_content_gui.cpp
Expand Up @@ -131,7 +131,7 @@ void BaseNetworkContentDownloadStatusWindow::DrawWidget(const Rect &r, int widge
StringID str;
if (this->downloaded_bytes == this->total_bytes) {
str = STR_CONTENT_DOWNLOAD_COMPLETE;
} else if (!StrEmpty(this->name)) {
} else if (!this->name.empty()) {
SetDParamStr(0, this->name);
SetDParam(1, this->downloaded_files);
SetDParam(2, this->total_files);
Expand All @@ -147,7 +147,7 @@ void BaseNetworkContentDownloadStatusWindow::DrawWidget(const Rect &r, int widge
void BaseNetworkContentDownloadStatusWindow::OnDownloadProgress(const ContentInfo *ci, int bytes)
{
if (ci->id != this->cur_id) {
strecpy(this->name, ci->filename.c_str(), lastof(this->name));
this->name = ci->filename;
this->cur_id = ci->id;
this->downloaded_files++;
}
Expand Down
4 changes: 2 additions & 2 deletions src/network/network_content_gui.h
Expand Up @@ -22,8 +22,8 @@ class BaseNetworkContentDownloadStatusWindow : public Window, ContentCallback {
uint total_files; ///< Number of files to download
uint downloaded_files; ///< Number of files downloaded

uint32 cur_id; ///< The current ID of the downloaded file
char name[48]; ///< The current name of the downloaded file
uint32 cur_id; ///< The current ID of the downloaded file
std::string name; ///< The current name of the downloaded file

public:
/**
Expand Down