Permalink
Browse files

P2P: parse network datastream into header/data components

Prevents peer from sending garbage attack.
  • Loading branch information...
Tranz5 committed Jul 22, 2014
1 parent bd6357c commit ec2a4acdb31ae7b7f3ae41a7e021a8f035771c99
Showing with 177 additions and 55 deletions.
  1. +29 −38 src/main.cpp
  2. +83 −13 src/net.cpp
  3. +65 −4 src/net.h
@@ -3211,7 +3211,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv)
else if (strCommand == "verack")
{
pfrom->vRecv.SetVersion(min(pfrom->nVersion, PROTOCOL_VERSION));
pfrom->SetRecvVersion(min(pfrom->nVersion, PROTOCOL_VERSION));
}
@@ -3730,13 +3730,13 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv)
return true;
}
// requires LOCK(cs_vRecvMsg)
bool ProcessMessages(CNode* pfrom)
{
CDataStream& vRecv = pfrom->vRecv;
if (vRecv.empty())
if (pfrom->vRecvMsg.empty())
return true;
// if (fDebug)
// printf("ProcessMessages(%u bytes)\n", vRecv.size());
// printf("ProcessMessages(%zu messages)\n", pfrom->vRecvMsg.size());
//
// Message format
@@ -3747,32 +3747,34 @@ bool ProcessMessages(CNode* pfrom)
// (x) data
//
while (true)
unsigned int nMsgPos = 0;
for (; nMsgPos < pfrom->vRecvMsg.size(); nMsgPos++)
{
// Don't bother if send buffer is too full to respond anyway
if (pfrom->vSend.size() >= SendBufferSize())
break;
// Scan for message start
CDataStream::iterator pstart = search(vRecv.begin(), vRecv.end(), BEGIN(pchMessageStart), END(pchMessageStart));
int nHeaderSize = vRecv.GetSerializeSize(CMessageHeader());
if (vRecv.end() - pstart < nHeaderSize)
{
if ((int)vRecv.size() > nHeaderSize)
{
printf("\n\nPROCESSMESSAGE MESSAGESTART NOT FOUND\n\n");
vRecv.erase(vRecv.begin(), vRecv.end() - nHeaderSize);
}
// get next message; end, if an incomplete message is found
CNetMessage& msg = pfrom->vRecvMsg[nMsgPos];
//if (fDebug)
// printf("ProcessMessages(message %u msgsz, %zu bytes, complete:%s)\n",
// msg.hdr.nMessageSize, msg.vRecv.size(),
// msg.complete() ? "Y" : "N");
if (!msg.complete())
break;
// Scan for message start
if (memcmp(msg.hdr.pchMessageStart, pchMessageStart, sizeof(pchMessageStart)) != 0) {
printf("\n\nPROCESSMESSAGE: INVALID MESSAGESTART\n\n");
return false;
}
if (pstart - vRecv.begin() > 0)
printf("\n\nPROCESSMESSAGE SKIPPED %"PRIpdd" BYTES\n\n", pstart - vRecv.begin());
vRecv.erase(vRecv.begin(), pstart);
// Read header
vector<char> vHeaderSave(vRecv.begin(), vRecv.begin() + nHeaderSize);
CMessageHeader hdr;
vRecv >> hdr;
CMessageHeader& hdr = msg.hdr;
if (!hdr.IsValid())
{
printf("\n\nPROCESSMESSAGE: ERRORS IN HEADER %s\n\n\n", hdr.GetCommand().c_str());
@@ -3782,19 +3784,9 @@ bool ProcessMessages(CNode* pfrom)
// Message size
unsigned int nMessageSize = hdr.nMessageSize;
if (nMessageSize > MAX_SIZE)
{
printf("ProcessMessages(%s, %u bytes) : nMessageSize > MAX_SIZE\n", strCommand.c_str(), nMessageSize);
continue;
}
if (nMessageSize > vRecv.size())
{
// Rewind and wait for rest of message
vRecv.insert(vRecv.begin(), vHeaderSave.begin(), vHeaderSave.end());
break;
}
// Checksum
CDataStream& vRecv = msg.vRecv;
uint256 hash = Hash(vRecv.begin(), vRecv.begin() + nMessageSize);
unsigned int nChecksum = 0;
memcpy(&nChecksum, &hash, sizeof(nChecksum));
@@ -3805,17 +3797,13 @@ bool ProcessMessages(CNode* pfrom)
continue;
}
// Copy message to its own buffer
CDataStream vMsg(vRecv.begin(), vRecv.begin() + nMessageSize, vRecv.nType, vRecv.nVersion);
vRecv.ignore(nMessageSize);
// Process message
bool fRet = false;
try
{
{
LOCK(cs_main);
fRet = ProcessMessage(pfrom, strCommand, vMsg);
fRet = ProcessMessage(pfrom, strCommand, vRecv);
}
if (fShutdown)
return true;
@@ -3847,7 +3835,10 @@ bool ProcessMessages(CNode* pfrom)
printf("ProcessMessage(%s, %u bytes) FAILED\n", strCommand.c_str(), nMessageSize);
}
vRecv.Compact();
// remove processed messages; one incomplete message may remain
if (nMsgPos > 0)
pfrom->vRecvMsg.erase(pfrom->vRecvMsg.begin(),
pfrom->vRecvMsg.begin() + nMsgPos);
return true;
}
@@ -534,7 +534,7 @@ void CNode::CloseSocketDisconnect()
printf("disconnecting node %s\n", addrName.c_str());
closesocket(hSocket);
hSocket = INVALID_SOCKET;
vRecv.clear();
vRecvMsg.clear();
}
}
@@ -628,6 +628,78 @@ void CNode::copyStats(CNodeStats &stats)
}
#undef X
// requires LOCK(cs_vRecvMsg)
bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes)
{
while (nBytes > 0) {
// get current incomplete message, or create a new one
if (vRecvMsg.size() == 0 ||
vRecvMsg.back().complete())
vRecvMsg.push_back(CNetMessage(SER_NETWORK, nRecvVersion));
CNetMessage& msg = vRecvMsg.back();
// absorb network data
int handled;
if (!msg.in_data)
handled = msg.readHeader(pch, nBytes);
else
handled = msg.readData(pch, nBytes);
if (handled < 0)
return false;
pch += handled;
nBytes -= handled;
}
return true;
}
int CNetMessage::readHeader(const char *pch, unsigned int nBytes)
{
// copy data to temporary parsing buffer
unsigned int nRemaining = 24 - nHdrPos;
unsigned int nCopy = std::min(nRemaining, nBytes);
memcpy(&hdrbuf[nHdrPos], pch, nCopy);
nHdrPos += nCopy;
// if header incomplete, exit
if (nHdrPos < 24)
return nCopy;
// deserialize to CMessageHeader
try {
hdrbuf >> hdr;
}
catch (std::exception &e) {
return -1;
}
// reject messages larger than MAX_SIZE
if (hdr.nMessageSize > MAX_SIZE)
return -1;
// switch state to reading message data
in_data = true;
vRecv.resize(hdr.nMessageSize);
return nCopy;
}
int CNetMessage::readData(const char *pch, unsigned int nBytes)
{
unsigned int nRemaining = hdr.nMessageSize - nDataPos;
unsigned int nCopy = std::min(nRemaining, nBytes);
memcpy(&vRecv[nDataPos], pch, nCopy);
nDataPos += nCopy;
return nCopy;
}
@@ -676,7 +748,7 @@ void ThreadSocketHandler2(void* parg)
BOOST_FOREACH(CNode* pnode, vNodesCopy)
{
if (pnode->fDisconnect ||
(pnode->GetRefCount() <= 0 && pnode->vRecv.empty() && pnode->vSend.empty()))
(pnode->GetRefCount() <= 0 && pnode->vRecvMsg.empty() && pnode->vSend.empty()))
{
// remove from vNodes
vNodes.erase(remove(vNodes.begin(), vNodes.end(), pnode), vNodes.end());
@@ -707,7 +779,7 @@ void ThreadSocketHandler2(void* parg)
TRY_LOCK(pnode->cs_vSend, lockSend);
if (lockSend)
{
TRY_LOCK(pnode->cs_vRecv, lockRecv);
TRY_LOCK(pnode->cs_vRecvMsg, lockRecv);
if (lockRecv)
{
TRY_LOCK(pnode->cs_mapRequests, lockReq);
@@ -872,15 +944,12 @@ void ThreadSocketHandler2(void* parg)
continue;
if (FD_ISSET(pnode->hSocket, &fdsetRecv) || FD_ISSET(pnode->hSocket, &fdsetError))
{
TRY_LOCK(pnode->cs_vRecv, lockRecv);
TRY_LOCK(pnode->cs_vRecvMsg, lockRecv);
if (lockRecv)
{
CDataStream& vRecv = pnode->vRecv;
unsigned int nPos = vRecv.size();
if (nPos > ReceiveBufferSize()) {
if (pnode->GetTotalRecvSize() > ReceiveFloodSize()) {
if (!pnode->fDisconnect)
printf("socket recv flood control disconnect (%"PRIszu" bytes)\n", vRecv.size());
printf("socket recv flood control disconnect (%u bytes)\n", pnode->GetTotalRecvSize());
pnode->CloseSocketDisconnect();
}
else {
@@ -889,8 +958,8 @@ void ThreadSocketHandler2(void* parg)
int nBytes = recv(pnode->hSocket, pchBuf, sizeof(pchBuf), MSG_DONTWAIT);
if (nBytes > 0)
{
vRecv.resize(nPos + nBytes);
memcpy(&vRecv[nPos], pchBuf, nBytes);
if (!pnode->ReceiveMsgBytes(pchBuf, nBytes))
pnode->CloseSocketDisconnect();
pnode->nLastRecv = GetTime();
pnode->nRecvBytes += nBytes;
@@ -1619,9 +1688,10 @@ void ThreadMessageHandler2(void* parg)
{
// Receive messages
{
TRY_LOCK(pnode->cs_vRecv, lockRecv);
TRY_LOCK(pnode->cs_vRecvMsg, lockRecv);
if (lockRecv)
ProcessMessages(pnode);
if (!ProcessMessages(pnode))
pnode->CloseSocketDisconnect();
}
if (fShutdown)
return;
@@ -26,7 +26,7 @@ extern int nBestHeight;
inline unsigned int ReceiveBufferSize() { return 1000*GetArg("-maxreceivebuffer", 5*1000); }
inline unsigned int ReceiveFloodSize() { return 1000*GetArg("-maxreceivebuffer", 5*1000); }
inline unsigned int SendBufferSize() { return 1000*GetArg("-maxsendbuffer", 1*1000); }
void AddOneShot(std::string strDest);
@@ -154,6 +154,43 @@ class CNodeStats
class CNetMessage {
public:
bool in_data; // parsing header (false) or data (true)
CDataStream hdrbuf; // partially received header
CMessageHeader hdr; // complete header
unsigned int nHdrPos;
CDataStream vRecv; // received message data
unsigned int nDataPos;
CNetMessage(int nTypeIn, int nVersionIn) : hdrbuf(nTypeIn, nVersionIn), vRecv(nTypeIn, nVersionIn) {
hdrbuf.resize(24);
in_data = false;
nHdrPos = 0;
nDataPos = 0;
}
bool complete() const
{
if (!in_data)
return false;
return (hdr.nMessageSize == nDataPos);
}
void SetVersion(int nVersionIn)
{
hdrbuf.SetVersion(nVersionIn);
vRecv.SetVersion(nVersionIn);
}
int readHeader(const char *pch, unsigned int nBytes);
int readData(const char *pch, unsigned int nBytes);
};
/** Information about a peer */
@@ -164,9 +201,12 @@ class CNode
uint64 nServices;
SOCKET hSocket;
CDataStream vSend;
CDataStream vRecv;
CCriticalSection cs_vSend;
CCriticalSection cs_vRecv;
std::vector<CNetMessage> vRecvMsg;
CCriticalSection cs_vRecvMsg;
int nRecvVersion;
int64 nLastSend;
int64 nLastRecv;
int64 nLastSendEmpty;
@@ -218,10 +258,11 @@ class CNode
CCriticalSection cs_inventory;
std::multimap<int64, CInv> mapAskFor;
CNode(SOCKET hSocketIn, CAddress addrIn, std::string addrNameIn = "", bool fInboundIn=false) : vSend(SER_NETWORK, MIN_PROTO_VERSION), vRecv(SER_NETWORK, MIN_PROTO_VERSION)
CNode(SOCKET hSocketIn, CAddress addrIn, std::string addrNameIn = "", bool fInboundIn=false) : vSend(SER_NETWORK, MIN_PROTO_VERSION)
{
nServices = 0;
hSocket = hSocketIn;
nRecvVersion = MIN_PROTO_VERSION;
nLastSend = 0;
nLastRecv = 0;
nLastSendEmpty = GetTime();
@@ -277,6 +318,26 @@ class CNode
return nRefCount;
}
// requires LOCK(cs_vRecvMsg)
unsigned int GetTotalRecvSize()
{
unsigned int total = 0;
for (unsigned int i = 0; i < vRecvMsg.size(); i++)
total += vRecvMsg[i].vRecv.size();
return total;
}
// requires LOCK(cs_vRecvMsg)
bool ReceiveMsgBytes(const char *pch, unsigned int nBytes);
// requires LOCK(cs_vRecvMsg)
void SetRecvVersion(int nVersionIn)
{
nRecvVersion = nVersionIn;
for (unsigned int i = 0; i < vRecvMsg.size(); i++)
vRecvMsg[i].SetVersion(nVersionIn);
}
CNode* AddRef()
{
nRefCount++;

0 comments on commit ec2a4ac

Please sign in to comment.