Skip to content

Commit

Permalink
Core/Networking: Rewrite networking threading model
Browse files Browse the repository at this point in the history
Each network thread has its own io_service - this means that all operations on a given socket except queueing packets run from a single thread, removing the need for locking
Sending packets now writes to a lockfree intermediate queue directly, encryption is applied in network thread if it was required at the time of sending the packet

(cherry picked from commit 97a79af)
  • Loading branch information
Shauren committed Feb 20, 2016
1 parent d418406 commit b2e03a7
Show file tree
Hide file tree
Showing 12 changed files with 277 additions and 179 deletions.
83 changes: 83 additions & 0 deletions src/common/Threading/MPSCQueue.h
@@ -0,0 +1,83 @@
/*
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/

#ifndef MPSCQueue_h__
#define MPSCQueue_h__

#include <atomic>
#include <utility>

// C++ implementation of Dmitry Vyukov's lock free MPSC queue
// http://www.1024cores.net/home/lock-free-algorithms/queues/non-intrusive-mpsc-node-based-queue
template<typename T>
class MPSCQueue
{
public:
MPSCQueue() : _head(new Node()), _tail(_head.load(std::memory_order_relaxed))
{
Node* front = _head.load(std::memory_order_relaxed);
front->Next.store(nullptr, std::memory_order_relaxed);
}

~MPSCQueue()
{
T* output;
while (this->Dequeue(output))
;

Node* front = _head.load(std::memory_order_relaxed);
delete front;
}

void Enqueue(T* input)
{
Node* node = new Node(input);
Node* prevHead = _head.exchange(node, std::memory_order_acq_rel);
prevHead->Next.store(node, std::memory_order_release);
}

bool Dequeue(T*& result)
{
Node* tail = _tail.load(std::memory_order_relaxed);
Node* next = tail->Next.load(std::memory_order_acquire);
if (!next)
return false;

result = next->Data;
_tail.store(next, std::memory_order_release);
delete tail;
return true;
}

private:
struct Node
{
Node() = default;
explicit Node(T* data) : Data(data) { Next.store(nullptr, std::memory_order_relaxed); }

T* Data;
std::atomic<Node*> Next;
};

std::atomic<Node*> _head;
std::atomic<Node*> _tail;

MPSCQueue(MPSCQueue const&) = delete;
MPSCQueue& operator=(MPSCQueue const&) = delete;
};

#endif // MPSCQueue_h__
5 changes: 1 addition & 4 deletions src/server/authserver/Server/AuthSession.cpp
Expand Up @@ -274,10 +274,7 @@ void AuthSession::SendPacket(ByteBuffer& packet)
{
MessageBuffer buffer;
buffer.Write(packet.contents(), packet.size());

std::unique_lock<std::mutex> guard(_writeLock);

QueuePacket(std::move(buffer), guard);
QueuePacket(std::move(buffer));
}
}

Expand Down
15 changes: 6 additions & 9 deletions src/server/authserver/Server/AuthSocketMgr.h
Expand Up @@ -21,8 +21,6 @@
#include "SocketMgr.h"
#include "AuthSession.h"

void OnSocketAccept(tcp::socket&& sock);

class AuthSocketMgr : public SocketMgr<AuthSession>
{
typedef SocketMgr<AuthSession> BaseSocketMgr;
Expand All @@ -39,7 +37,7 @@ class AuthSocketMgr : public SocketMgr<AuthSession>
if (!BaseSocketMgr::StartNetwork(service, bindIp, port))
return false;

_acceptor->AsyncAcceptManaged(&OnSocketAccept);
_acceptor->AsyncAcceptWithCallback<&AuthSocketMgr::OnSocketAccept>();
return true;
}

Expand All @@ -48,14 +46,13 @@ class AuthSocketMgr : public SocketMgr<AuthSession>
{
return new NetworkThread<AuthSession>[1];
}

static void OnSocketAccept(tcp::socket&& sock, uint32 threadIndex)
{
Instance().OnSocketOpen(std::forward<tcp::socket>(sock), threadIndex);
}
};

#define sAuthSocketMgr AuthSocketMgr::Instance()

void OnSocketAccept(tcp::socket&& sock)
{
sAuthSocketMgr.OnSocketOpen(std::forward<tcp::socket>(sock));
}


#endif // AuthSocketMgr_h__
98 changes: 57 additions & 41 deletions src/server/game/Server/WorldSocket.cpp
Expand Up @@ -25,6 +25,17 @@

#include <memory>

class EncryptablePacket : public WorldPacket
{
public:
EncryptablePacket(WorldPacket const& packet, bool encrypt) : WorldPacket(packet), _encrypt(encrypt) { }

bool NeedsEncryption() const { return _encrypt; }

private:
bool _encrypt;
};

using boost::asio::ip::tcp;

WorldSocket::WorldSocket(tcp::socket&& socket)
Expand All @@ -40,11 +51,8 @@ void WorldSocket::Start()
stmt->setString(0, ip_address);
stmt->setUInt32(1, inet_addr(ip_address.c_str()));

{
std::lock_guard<std::mutex> guard(_queryLock);
_queryCallback = io_service().wrap(std::bind(&WorldSocket::CheckIpCallback, this, std::placeholders::_1));
_queryFuture = LoginDatabase.AsyncQuery(stmt);
}
_queryCallback = std::bind(&WorldSocket::CheckIpCallback, this, std::placeholders::_1);
_queryFuture = LoginDatabase.AsyncQuery(stmt);
}

void WorldSocket::CheckIpCallback(PreparedQueryResult result)
Expand Down Expand Up @@ -78,17 +86,50 @@ void WorldSocket::CheckIpCallback(PreparedQueryResult result)

bool WorldSocket::Update()
{
EncryptablePacket* queued;
MessageBuffer buffer;
while (_bufferQueue.Dequeue(queued))
{
ServerPktHeader header(queued->size() + 2, queued->GetOpcode());
if (queued->NeedsEncryption())
_authCrypt.EncryptSend(header.header, header.getHeaderLength());

if (buffer.GetRemainingSpace() < queued->size() + header.getHeaderLength())
{
QueuePacket(std::move(buffer));
buffer.Resize(4096);
}

if (buffer.GetRemainingSpace() >= queued->size() + header.getHeaderLength())
{
buffer.Write(header.header, header.getHeaderLength());
if (!queued->empty())
buffer.Write(queued->contents(), queued->size());
}
else // single packet larger than 4096 bytes
{
MessageBuffer packetBuffer(queued->size() + header.getHeaderLength());
packetBuffer.Write(header.header, header.getHeaderLength());
if (!queued->empty())
packetBuffer.Write(queued->contents(), queued->size());

QueuePacket(std::move(packetBuffer));
}

delete queued;
}

if (buffer.GetActiveSize() > 0)
QueuePacket(std::move(buffer));

if (!BaseSocket::Update())
return false;

if (_queryFuture.valid() && _queryFuture.wait_for(std::chrono::seconds(0)) == std::future_status::ready)
{
std::lock_guard<std::mutex> guard(_queryLock);
if (_queryFuture.valid() && _queryFuture.wait_for(std::chrono::seconds(0)) == std::future_status::ready)
{
auto callback = std::move(_queryCallback);
_queryCallback = nullptr;
callback(_queryFuture.get());
}
auto callback = _queryCallback;
_queryCallback = nullptr;
callback(_queryFuture.get());
}

return true;
Expand Down Expand Up @@ -351,29 +392,7 @@ void WorldSocket::SendPacket(WorldPacket const& packet)
if (sPacketLog->CanLogPacket())
sPacketLog->LogPacket(packet, SERVER_TO_CLIENT, GetRemoteIpAddress(), GetRemotePort());

ServerPktHeader header(packet.size() + 2, packet.GetOpcode());

std::unique_lock<std::mutex> guard(_writeLock);

_authCrypt.EncryptSend(header.header, header.getHeaderLength());

#ifndef TC_SOCKET_USE_IOCP
if (_writeQueue.empty() && _writeBuffer.GetRemainingSpace() >= header.getHeaderLength() + packet.size())
{
_writeBuffer.Write(header.header, header.getHeaderLength());
if (!packet.empty())
_writeBuffer.Write(packet.contents(), packet.size());
}
else
#endif
{
MessageBuffer buffer(header.getHeaderLength() + packet.size());
buffer.Write(header.header, header.getHeaderLength());
if (!packet.empty())
buffer.Write(packet.contents(), packet.size());

QueuePacket(std::move(buffer), guard);
}
_bufferQueue.Enqueue(new EncryptablePacket(packet, _authCrypt.IsInitialized()));
}

void WorldSocket::HandleAuthSession(WorldPacket& recvPacket)
Expand All @@ -398,11 +417,8 @@ void WorldSocket::HandleAuthSession(WorldPacket& recvPacket)
stmt->setInt32(0, int32(realmID));
stmt->setString(1, authSession->Account);

{
std::lock_guard<std::mutex> guard(_queryLock);
_queryCallback = io_service().wrap(std::bind(&WorldSocket::HandleAuthSessionCallback, this, authSession, std::placeholders::_1));
_queryFuture = LoginDatabase.AsyncQuery(stmt);
}
_queryCallback = std::bind(&WorldSocket::HandleAuthSessionCallback, this, authSession, std::placeholders::_1);
_queryFuture = LoginDatabase.AsyncQuery(stmt);
}

void WorldSocket::HandleAuthSessionCallback(std::shared_ptr<AuthSession> authSession, PreparedQueryResult result)
Expand Down Expand Up @@ -559,7 +575,7 @@ void WorldSocket::HandleAuthSessionCallback(std::shared_ptr<AuthSession> authSes
if (wardenActive)
_worldSession->InitWarden(&account.SessionKey, account.OS);

_queryCallback = io_service().wrap(std::bind(&WorldSocket::LoadSessionPermissionsCallback, this, std::placeholders::_1));
_queryCallback = std::bind(&WorldSocket::LoadSessionPermissionsCallback, this, std::placeholders::_1);
_queryFuture = _worldSession->LoadPermissionsAsync();
AsyncRead();
}
Expand Down
4 changes: 3 additions & 1 deletion src/server/game/Server/WorldSocket.h
Expand Up @@ -26,11 +26,13 @@
#include "Util.h"
#include "WorldPacket.h"
#include "WorldSession.h"
#include "MPSCQueue.h"
#include <chrono>
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/buffer.hpp>

using boost::asio::ip::tcp;
class EncryptablePacket;

#pragma pack(push, 1)

Expand Down Expand Up @@ -104,8 +106,8 @@ class WorldSocket : public Socket<WorldSocket>

MessageBuffer _headerBuffer;
MessageBuffer _packetBuffer;
MPSCQueue<EncryptablePacket> _bufferQueue;

std::mutex _queryLock;
PreparedQueryResultFuture _queryFuture;
std::function<void(PreparedQueryResult&&)> _queryCallback;
std::string _ipCountry;
Expand Down
12 changes: 7 additions & 5 deletions src/server/game/Server/WorldSocketMgr.cpp
Expand Up @@ -24,9 +24,9 @@

#include <boost/system/error_code.hpp>

static void OnSocketAccept(tcp::socket&& sock)
static void OnSocketAccept(tcp::socket&& sock, uint32 threadIndex)
{
sWorldSocketMgr.OnSocketOpen(std::forward<tcp::socket>(sock));
sWorldSocketMgr.OnSocketOpen(std::forward<tcp::socket>(sock), threadIndex);
}

class WorldSocketThread : public NetworkThread<WorldSocket>
Expand Down Expand Up @@ -67,7 +67,9 @@ bool WorldSocketMgr::StartNetwork(boost::asio::io_service& service, std::string

BaseSocketMgr::StartNetwork(service, bindIp, port);

_acceptor->AsyncAcceptManaged(&OnSocketAccept);
_acceptor->SetSocketFactory(std::bind(&BaseSocketMgr::GetSocketForAccept, this));

_acceptor->AsyncAcceptWithCallback<&OnSocketAccept>();

sScriptMgr->OnNetworkStart();
return true;
Expand All @@ -80,7 +82,7 @@ void WorldSocketMgr::StopNetwork()
sScriptMgr->OnNetworkStop();
}

void WorldSocketMgr::OnSocketOpen(tcp::socket&& sock)
void WorldSocketMgr::OnSocketOpen(tcp::socket&& sock, uint32 threadIndex)
{
// set some options here
if (_socketSendBufferSize >= 0)
Expand Down Expand Up @@ -108,7 +110,7 @@ void WorldSocketMgr::OnSocketOpen(tcp::socket&& sock)

//sock->m_OutBufferSize = static_cast<size_t> (m_SockOutUBuff);

BaseSocketMgr::OnSocketOpen(std::forward<tcp::socket>(sock));
BaseSocketMgr::OnSocketOpen(std::forward<tcp::socket>(sock), threadIndex);
}

NetworkThread<WorldSocket>* WorldSocketMgr::CreateThreads() const
Expand Down
2 changes: 1 addition & 1 deletion src/server/game/Server/WorldSocketMgr.h
Expand Up @@ -47,7 +47,7 @@ class WorldSocketMgr : public SocketMgr<WorldSocket>
/// Stops all network threads, It will wait for all running threads .
void StopNetwork() override;

void OnSocketOpen(tcp::socket&& sock) override;
void OnSocketOpen(tcp::socket&& sock, uint32 threadIndex) override;

protected:
WorldSocketMgr();
Expand Down

3 comments on commit b2e03a7

@Aokromes
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You have re-run cmake?

@Aokromes
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According shauren it's because you are trying to compile with clang on one environment with gcc 5 atm linux guys needs to fix that.

@Sar777
Copy link
Contributor

@Sar777 Sar777 commented on b2e03a7 Feb 20, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Runico, yes, and it 4.9+ problem, 4.8 fine :(

Please sign in to comment.