Skip to content

Commit

Permalink
Core/Networking: Added new AsyncRead method to Socket class allowing …
Browse files Browse the repository at this point in the history
…to pass a custom completion handler and refactor world socket initialization string handling
  • Loading branch information
Shauren committed Mar 10, 2016
1 parent 52bb648 commit f123c39
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 122 deletions.
272 changes: 152 additions & 120 deletions src/server/game/Server/WorldSocket.cpp
Expand Up @@ -57,18 +57,18 @@ std::string const WorldSocket::ServerConnectionInitialize("WORLD OF WARCRAFT CON
std::string const WorldSocket::ClientConnectionInitialize("WORLD OF WARCRAFT CONNECTION - CLIENT TO SERVER");
uint32 const WorldSocket::MinSizeForCompression = 0x400;

uint32 const SizeOfClientHeader[2][2] =
uint32 const SizeOfClientHeader[2] =
{
{ 2, 0 },
{ 6, 4 }
6, 4
};

uint32 const SizeOfServerHeader[2] = { sizeof(uint16) + sizeof(uint32), sizeof(uint32) };

WorldSocket::WorldSocket(tcp::socket&& socket) : Socket(std::move(socket)),
_type(CONNECTION_TYPE_REALM), _authSeed(rand32()), _OverSpeedPings(0),
_worldSession(nullptr), _authed(false), _compressionStream(nullptr), _initialized(false)
_worldSession(nullptr), _authed(false), _compressionStream(nullptr)
{
_headerBuffer.Resize(SizeOfClientHeader[0][0]);
_headerBuffer.Resize(SizeOfClientHeader[0]);
}

WorldSocket::~WorldSocket()
Expand Down Expand Up @@ -116,7 +116,9 @@ void WorldSocket::CheckIpCallback(PreparedQueryResult result)
}
}

AsyncRead();
_packetBuffer.Resize(2 + ClientConnectionInitialize.length() + 1);

AsyncReadWithCallback(&WorldSocket::InitializeHandler);

MessageBuffer initializer;
ServerPktHeader header;
Expand All @@ -128,6 +130,65 @@ void WorldSocket::CheckIpCallback(PreparedQueryResult result)
QueuePacket(std::move(initializer));
}

void WorldSocket::InitializeHandler(boost::system::error_code error, std::size_t transferedBytes)
{
if (error)
{
CloseSocket();
return;
}

GetReadBuffer().WriteCompleted(transferedBytes);

MessageBuffer& packet = GetReadBuffer();
if (packet.GetActiveSize() > 0)
{
if (_packetBuffer.GetRemainingSpace() > 0)
{
// need to receive the header
std::size_t readHeaderSize = std::min(packet.GetActiveSize(), _packetBuffer.GetRemainingSpace());
_packetBuffer.Write(packet.GetReadPointer(), readHeaderSize);
packet.ReadCompleted(readHeaderSize);

if (_packetBuffer.GetRemainingSpace() > 0)
{
// Couldn't receive the whole header this time.
ASSERT(packet.GetActiveSize() == 0);
AsyncReadWithCallback(&WorldSocket::InitializeHandler);
return;
}

std::string initializer(reinterpret_cast<char const*>(_packetBuffer.GetReadPointer() + 2), std::min(_packetBuffer.GetActiveSize() - 2, ClientConnectionInitialize.length()));
if (initializer != ClientConnectionInitialize)
{
CloseSocket();
return;
}

_compressionStream = new z_stream();
_compressionStream->zalloc = (alloc_func)NULL;
_compressionStream->zfree = (free_func)NULL;
_compressionStream->opaque = (voidpf)NULL;
_compressionStream->avail_in = 0;
_compressionStream->next_in = NULL;
int32 z_res = deflateInit2(_compressionStream, sWorld->getIntConfig(CONFIG_COMPRESSION), Z_DEFLATED, -15, 8, Z_DEFAULT_STRATEGY);
if (z_res != Z_OK)
{
CloseSocket();
TC_LOG_ERROR("network", "Can't initialize packet compression (zlib: deflateInit) Error code: %i (%s)", z_res, zError(z_res));
return;
}

_packetBuffer.Reset();
HandleSendAuthSession();
AsyncRead();
return;
}
}

AsyncReadWithCallback(&WorldSocket::InitializeHandler);
}

bool WorldSocket::Update()
{
EncryptablePacket* queued;
Expand Down Expand Up @@ -266,9 +327,7 @@ void WorldSocket::ExtractOpcodeAndSize(ClientPktHeader const* header, uint32& op
else
{
opcode = header->Setup.Command;
size = header->Setup.Size;
if (_initialized)
size -= 4;
size = header->Setup.Size - 4;
}
}

Expand All @@ -281,7 +340,7 @@ void WorldSocket::SetWorldSession(WorldSession* session)

bool WorldSocket::ReadHeaderHandler()
{
ASSERT(_headerBuffer.GetActiveSize() == SizeOfClientHeader[_initialized][_authCrypt.IsInitialized()], "Header size " SZFMTD " different than expected %u", _headerBuffer.GetActiveSize(), SizeOfClientHeader[_initialized][_authCrypt.IsInitialized()]);
ASSERT(_headerBuffer.GetActiveSize() == SizeOfClientHeader[_authCrypt.IsInitialized()], "Header size " SZFMTD " different than expected %u", _headerBuffer.GetActiveSize(), SizeOfClientHeader[_authCrypt.IsInitialized()]);

_authCrypt.DecryptRecv(_headerBuffer.GetReadPointer(), _headerBuffer.GetActiveSize());

Expand All @@ -291,7 +350,7 @@ bool WorldSocket::ReadHeaderHandler()

ExtractOpcodeAndSize(header, opcode, size);

if (!ClientPktHeader::IsValidSize(size) || (_initialized && !ClientPktHeader::IsValidOpcode(opcode)))
if (!ClientPktHeader::IsValidSize(size) || !ClientPktHeader::IsValidOpcode(opcode))
{
TC_LOG_ERROR("network", "WorldSocket::ReadHeaderHandler(): client %s sent malformed packet (size: %u, cmd: %u)",
GetRemoteIpAddress().to_string().c_str(), size, opcode);
Expand All @@ -304,133 +363,106 @@ bool WorldSocket::ReadHeaderHandler()

WorldSocket::ReadDataHandlerResult WorldSocket::ReadDataHandler()
{
if (_initialized)
{
ClientPktHeader* header = reinterpret_cast<ClientPktHeader*>(_headerBuffer.GetReadPointer());
uint32 cmd;
uint32 size;
ClientPktHeader* header = reinterpret_cast<ClientPktHeader*>(_headerBuffer.GetReadPointer());
uint32 cmd;
uint32 size;

ExtractOpcodeAndSize(header, cmd, size);
ExtractOpcodeAndSize(header, cmd, size);

OpcodeClient opcode = static_cast<OpcodeClient>(cmd);
OpcodeClient opcode = static_cast<OpcodeClient>(cmd);

WorldPacket packet(opcode, std::move(_packetBuffer), GetConnectionType());
WorldPacket packet(opcode, std::move(_packetBuffer), GetConnectionType());

if (sPacketLog->CanLogPacket())
sPacketLog->LogPacket(packet, CLIENT_TO_SERVER, GetRemoteIpAddress(), GetRemotePort(), GetConnectionType());
if (sPacketLog->CanLogPacket())
sPacketLog->LogPacket(packet, CLIENT_TO_SERVER, GetRemoteIpAddress(), GetRemotePort(), GetConnectionType());

std::unique_lock<std::mutex> sessionGuard(_worldSessionLock, std::defer_lock);
std::unique_lock<std::mutex> sessionGuard(_worldSessionLock, std::defer_lock);

switch (opcode)
switch (opcode)
{
case CMSG_PING:
LogOpcodeText(opcode, sessionGuard);
return HandlePing(packet) ? ReadDataHandlerResult::Ok : ReadDataHandlerResult::Error;
case CMSG_AUTH_SESSION:
{
case CMSG_PING:
LogOpcodeText(opcode, sessionGuard);
return HandlePing(packet) ? ReadDataHandlerResult::Ok : ReadDataHandlerResult::Error;
case CMSG_AUTH_SESSION:
LogOpcodeText(opcode, sessionGuard);
if (_authed)
{
LogOpcodeText(opcode, sessionGuard);
if (_authed)
{
// locking just to safely log offending user is probably overkill but we are disconnecting him anyway
if (sessionGuard.try_lock())
TC_LOG_ERROR("network", "WorldSocket::ProcessIncoming: received duplicate CMSG_AUTH_SESSION from %s", _worldSession->GetPlayerInfo().c_str());
return ReadDataHandlerResult::Error;
}

std::shared_ptr<WorldPackets::Auth::AuthSession> authSession = std::make_shared<WorldPackets::Auth::AuthSession>(std::move(packet));
authSession->Read();
HandleAuthSession(authSession);
return ReadDataHandlerResult::WaitingForQuery;
// locking just to safely log offending user is probably overkill but we are disconnecting him anyway
if (sessionGuard.try_lock())
TC_LOG_ERROR("network", "WorldSocket::ProcessIncoming: received duplicate CMSG_AUTH_SESSION from %s", _worldSession->GetPlayerInfo().c_str());
return ReadDataHandlerResult::Error;
}
case CMSG_AUTH_CONTINUED_SESSION:
{
LogOpcodeText(opcode, sessionGuard);
if (_authed)
{
// locking just to safely log offending user is probably overkill but we are disconnecting him anyway
if (sessionGuard.try_lock())
TC_LOG_ERROR("network", "WorldSocket::ProcessIncoming: received duplicate CMSG_AUTH_CONTINUED_SESSION from %s", _worldSession->GetPlayerInfo().c_str());
return ReadDataHandlerResult::Error;
}

std::shared_ptr<WorldPackets::Auth::AuthContinuedSession> authSession = std::make_shared<WorldPackets::Auth::AuthContinuedSession>(std::move(packet));
authSession->Read();
HandleAuthContinuedSession(authSession);
return ReadDataHandlerResult::WaitingForQuery;
}
case CMSG_KEEP_ALIVE:
LogOpcodeText(opcode, sessionGuard);
break;
case CMSG_LOG_DISCONNECT:
LogOpcodeText(opcode, sessionGuard);
packet.rfinish(); // contains uint32 disconnectReason;
break;
case CMSG_ENABLE_NAGLE:
LogOpcodeText(opcode, sessionGuard);
SetNoDelay(false);
break;
case CMSG_CONNECT_TO_FAILED:
std::shared_ptr<WorldPackets::Auth::AuthSession> authSession = std::make_shared<WorldPackets::Auth::AuthSession>(std::move(packet));
authSession->Read();
HandleAuthSession(authSession);
return ReadDataHandlerResult::WaitingForQuery;
}
case CMSG_AUTH_CONTINUED_SESSION:
{
LogOpcodeText(opcode, sessionGuard);
if (_authed)
{
sessionGuard.lock();

LogOpcodeText(opcode, sessionGuard);
WorldPackets::Auth::ConnectToFailed connectToFailed(std::move(packet));
connectToFailed.Read();
HandleConnectToFailed(connectToFailed);
break;
// locking just to safely log offending user is probably overkill but we are disconnecting him anyway
if (sessionGuard.try_lock())
TC_LOG_ERROR("network", "WorldSocket::ProcessIncoming: received duplicate CMSG_AUTH_CONTINUED_SESSION from %s", _worldSession->GetPlayerInfo().c_str());
return ReadDataHandlerResult::Error;
}
default:
{
sessionGuard.lock();

LogOpcodeText(opcode, sessionGuard);
std::shared_ptr<WorldPackets::Auth::AuthContinuedSession> authSession = std::make_shared<WorldPackets::Auth::AuthContinuedSession>(std::move(packet));
authSession->Read();
HandleAuthContinuedSession(authSession);
return ReadDataHandlerResult::WaitingForQuery;
}
case CMSG_KEEP_ALIVE:
LogOpcodeText(opcode, sessionGuard);
break;
case CMSG_LOG_DISCONNECT:
LogOpcodeText(opcode, sessionGuard);
packet.rfinish(); // contains uint32 disconnectReason;
break;
case CMSG_ENABLE_NAGLE:
LogOpcodeText(opcode, sessionGuard);
SetNoDelay(false);
break;
case CMSG_CONNECT_TO_FAILED:
{
sessionGuard.lock();

if (!_worldSession)
{
TC_LOG_ERROR("network.opcode", "ProcessIncoming: Client not authed opcode = %u", uint32(opcode));
return ReadDataHandlerResult::Error;
}
LogOpcodeText(opcode, sessionGuard);
WorldPackets::Auth::ConnectToFailed connectToFailed(std::move(packet));
connectToFailed.Read();
HandleConnectToFailed(connectToFailed);
break;
}
default:
{
sessionGuard.lock();

OpcodeHandler const* handler = opcodeTable[opcode];
if (!handler)
{
TC_LOG_ERROR("network.opcode", "No defined handler for opcode %s sent by %s", GetOpcodeNameForLogging(static_cast<OpcodeClient>(packet.GetOpcode())).c_str(), _worldSession->GetPlayerInfo().c_str());
break;
}
LogOpcodeText(opcode, sessionGuard);

// Our Idle timer will reset on any non PING opcodes.
// Catches people idling on the login screen and any lingering ingame connections.
_worldSession->ResetTimeOutTime();
if (!_worldSession)
{
TC_LOG_ERROR("network.opcode", "ProcessIncoming: Client not authed opcode = %u", uint32(opcode));
return ReadDataHandlerResult::Error;
}

// Copy the packet to the heap before enqueuing
_worldSession->QueuePacket(new WorldPacket(std::move(packet)));
OpcodeHandler const* handler = opcodeTable[opcode];
if (!handler)
{
TC_LOG_ERROR("network.opcode", "No defined handler for opcode %s sent by %s", GetOpcodeNameForLogging(static_cast<OpcodeClient>(packet.GetOpcode())).c_str(), _worldSession->GetPlayerInfo().c_str());
break;
}
}
}
else
{
std::string initializer(reinterpret_cast<char const*>(_packetBuffer.GetReadPointer()), std::min(_packetBuffer.GetActiveSize(), ClientConnectionInitialize.length()));
if (initializer != ClientConnectionInitialize)
return ReadDataHandlerResult::Error;

_compressionStream = new z_stream();
_compressionStream->zalloc = (alloc_func)NULL;
_compressionStream->zfree = (free_func)NULL;
_compressionStream->opaque = (voidpf)NULL;
_compressionStream->avail_in = 0;
_compressionStream->next_in = NULL;
int32 z_res = deflateInit2(_compressionStream, sWorld->getIntConfig(CONFIG_COMPRESSION), Z_DEFLATED, -15, 8, Z_DEFAULT_STRATEGY);
if (z_res != Z_OK)
{
TC_LOG_ERROR("network", "Can't initialize packet compression (zlib: deflateInit) Error code: %i (%s)", z_res, zError(z_res));
return ReadDataHandlerResult::Error;
}

_initialized = true;
_headerBuffer.Resize(SizeOfClientHeader[1][0]);
_packetBuffer.Reset();
HandleSendAuthSession();
// Our Idle timer will reset on any non PING opcodes.
// Catches people idling on the login screen and any lingering ingame connections.
_worldSession->ResetTimeOutTime();

// Copy the packet to the heap before enqueuing
_worldSession->QueuePacket(new WorldPacket(std::move(packet)));
break;
}
}

return ReadDataHandlerResult::Ok;
Expand Down Expand Up @@ -610,7 +642,7 @@ struct AccountInfo
void WorldSocket::HandleAuthSession(std::shared_ptr<WorldPackets::Auth::AuthSession> authSession)
{
// Client switches packet headers after sending CMSG_AUTH_SESSION
_headerBuffer.Resize(SizeOfClientHeader[1][1]);
_headerBuffer.Resize(SizeOfClientHeader[1]);

// Get the account information from the auth database
PreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_ACCOUNT_INFO_BY_NAME);
Expand Down Expand Up @@ -811,7 +843,7 @@ void WorldSocket::HandleAuthContinuedSession(std::shared_ptr<WorldPackets::Auth:
}

// Client switches packet headers after sending CMSG_AUTH_CONTINUED_SESSION
_headerBuffer.Resize(SizeOfClientHeader[1][1]);
_headerBuffer.Resize(SizeOfClientHeader[1]);

uint32 accountId = uint32(key.Fields.AccountId);
PreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_ACCOUNT_INFO_CONTINUED_SESSION);
Expand Down
3 changes: 1 addition & 2 deletions src/server/game/Server/WorldSocket.h
Expand Up @@ -107,6 +107,7 @@ class WorldSocket : public Socket<WorldSocket>
ReadDataHandlerResult ReadDataHandler();
private:
void CheckIpCallback(PreparedQueryResult result);
void InitializeHandler(boost::system::error_code error, std::size_t transferedBytes);

/// writes network.opcode log
/// accessing WorldSession is not threadsafe, only do it when holding _worldSessionLock
Expand Down Expand Up @@ -148,8 +149,6 @@ class WorldSocket : public Socket<WorldSocket>

z_stream_s* _compressionStream;

bool _initialized;

PreparedQueryResultFuture _queryFuture;
std::function<void(PreparedQueryResult&&)> _queryCallback;
std::string _ipCountry;
Expand Down
11 changes: 11 additions & 0 deletions src/server/shared/Networking/Socket.h
Expand Up @@ -90,6 +90,17 @@ class Socket : public std::enable_shared_from_this<T>
std::bind(&Socket<T>::ReadHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
}

void AsyncReadWithCallback(void (T::*callback)(boost::system::error_code, std::size_t))
{
if (!IsOpen())
return;

_readBuffer.Normalize();
_readBuffer.EnsureFreeSpace();
_socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()),
std::bind(callback, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
}

void QueuePacket(MessageBuffer&& buffer)
{
_writeQueue.push(std::move(buffer));
Expand Down

4 comments on commit f123c39

@killradio
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @Shauren can you plz take a look at this issue #16100 ?

@AwkwardDev
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Love spread. :)

@letor
Copy link

@letor letor commented on f123c39 Mar 11, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this fix ::QueuePacket crash?

@Shauren
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not related

Please sign in to comment.