/
SocketHandler.hpp
127 lines (113 loc) · 3.94 KB
/
SocketHandler.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#ifndef __ET_SOCKET_HANDLER__
#define __ET_SOCKET_HANDLER__
#include "Headers.hpp"
#include "Packet.hpp"
namespace et {
class SocketHandler {
public:
virtual ~SocketHandler() {}
virtual bool hasData(int fd) = 0;
virtual ssize_t read(int fd, void* buf, size_t count) = 0;
virtual ssize_t write(int fd, const void* buf, size_t count) = 0;
void readAll(int fd, void* buf, size_t count, bool timeout);
int writeAllOrReturn(int fd, const void* buf, size_t count);
void writeAllOrThrow(int fd, const void* buf, size_t count, bool timeout);
template <typename T>
inline T readProto(int fd, bool timeout) {
T t;
int64_t length;
readAll(fd, &length, sizeof(int64_t), timeout);
if (length < 0 || length > 128 * 1024 * 1024) {
// If the message is <= 0 or too big, assume this is a bad packet and
// throw
string s = string("Invalid size (<0 or >128 MB): ") + to_string(length);
throw std::runtime_error(s.c_str());
}
if (length == 0) {
return t;
}
string s(length, '\0');
readAll(fd, &s[0], length, timeout);
if (!t.ParseFromString(s)) {
throw std::runtime_error("Invalid proto");
}
return t;
}
template <typename T>
inline void writeProto(int fd, const T& t, bool timeout) {
string s;
if (!t.SerializeToString(&s)) {
STFATAL << "Serialization of " << t.GetTypeName() << " failed!";
}
int64_t length = s.length();
if (length < 0 || length > 128 * 1024 * 1024) {
STFATAL << "Invalid proto length: " << length << " For proto "
<< t.GetTypeName();
}
writeAllOrThrow(fd, &length, sizeof(int64_t), timeout);
if (length > 0) {
writeAllOrThrow(fd, &s[0], length, timeout);
}
}
inline bool readPacket(int fd, Packet* packet) {
int64_t length;
readAll(fd, (char*)&length, sizeof(int64_t), false);
if (length < 0 || length > 128 * 1024 * 1024) {
// If the message is < 0 or too big, assume this is a bad packet and throw
string s("Invalid size (<0 or >128 MB): ");
s += std::to_string(length);
throw std::runtime_error(s.c_str());
}
if (length == 0) {
return false;
}
string s(length, '\0');
readAll(fd, &s[0], length, false);
*packet = Packet(s);
return true;
}
inline void writePacket(int fd, const Packet& packet) {
string s = packet.serialize();
int64_t length = s.length();
if (length < 0 || length > 128 * 1024 * 1024) {
STFATAL << "Invalid message length: " << length;
}
writeAllOrThrow(fd, (const char*)&length, sizeof(int64_t), false);
if (length) {
writeAllOrThrow(fd, &s[0], length, false);
}
}
inline void writeB64(int fd, const char* buf, size_t count) {
size_t encodedLength = Base64::EncodedLength(count);
string s(encodedLength, '\0');
if (!Base64::Encode(buf, count, &s[0], s.length())) {
throw runtime_error("b64 decode failed");
}
writeAllOrThrow(fd, &s[0], s.length(), false);
}
inline void readB64(int fd, char* buf, size_t count) {
size_t encodedLength = Base64::EncodedLength(count);
string s(encodedLength, '\0');
readAll(fd, &s[0], s.length(), false);
if (!Base64::Decode((const char*)&s[0], s.length(), buf, count)) {
throw runtime_error("b64 decode failed");
}
}
inline void readB64EncodedLength(int fd, string* out, size_t encodedLength) {
string s(encodedLength, '\0');
readAll(fd, &s[0], s.length(), false);
if (!Base64::Decode(s, out)) {
throw runtime_error("b64 decode failed");
}
}
virtual int connect(const SocketEndpoint& endpoint) = 0;
virtual set<int> listen(const SocketEndpoint& endpoint) = 0;
virtual set<int> getEndpointFds(const SocketEndpoint& endpoint) = 0;
virtual int accept(int fd) = 0;
virtual void stopListening(const SocketEndpoint& endpoint) = 0;
virtual void close(int fd) = 0;
virtual vector<int> getActiveSockets() = 0;
protected:
};
} // namespace et
#endif // __ET_SOCKET_HANDLER__