forked from kokjo/pycoin
-
Notifications
You must be signed in to change notification settings - Fork 0
/
bserialize.py
135 lines (122 loc) · 4.11 KB
/
bserialize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""Contains the binary serialization logic, we import it "as bs".
Each object is provided with a binary serialization class, in most cases this
is the object's class.
Each binary-serialization class has two methods tobinary() and frombinary().
The tobinary() method takes an object as its first argument and returns it as a
bytes object. The frombinary() method takes a bytestring as its first
argument and returns (obj, rest of bytestring).
Throws ProtocolViolation when fed junk.
"""
from utils import ProtocolViolation
import struct
from socket import inet_ntoa, inet_aton
class Entity(object):
    """Base class for records serialized as a fixed sequence of fields.

    Subclasses define ``bfields``: an iterable of ``(attribute_name,
    serializer)`` pairs, where each serializer provides ``tobinary(value)``
    and ``frombinary(bdata) -> (value, remaining_data)``.
    """

    def __init__(self, **kwargs):
        """Set each keyword argument as an attribute of the new instance."""
        for k, v in kwargs.items():
            setattr(self, k, v)

    def tobinary(self):
        """Serialize every field in ``bfields`` order and concatenate the pieces.

        Joined once at the end instead of repeated ``+=`` (quadratic for
        many fields).
        """
        return b"".join(serializer.tobinary(getattr(self, field))
                        for field, serializer in self.bfields)

    @classmethod
    def frombinary(cls, bdata, rest=True):
        """Deserialize an instance from the front of ``bdata``.

        Fields are decoded in ``bfields`` order onto an instance created with
        ``cls.__new__`` (``__init__`` is not called).

        :param bdata: serialized input.
        :param rest: if true, return ``(instance, remaining_data)``; if false,
            require that the fields consume all of ``bdata`` and return only
            the instance.
        :raises ProtocolViolation: if ``rest`` is false and data is left
            over. Errors raised by the field serializers propagate unchanged.
        """
        obj = cls.__new__(cls)
        for field, serializer in obj.bfields:
            value, bdata = serializer.frombinary(bdata)
            setattr(obj, field, value)
        if rest:
            return obj, bdata
        # Truthiness rather than ``!= ""``: on Python 3 bytes never equal a str,
        # so the old comparison rejected even perfectly consumed input.
        if bdata:
            raise ProtocolViolation()
        return obj
def structfmt(fmt):
    """Build a serializer class for one value encoded with ``struct``.

    ``fmt`` is a struct format describing a single value, for example
    ``structfmt("<I")`` for a little-endian 4-byte unsigned integer.

    The returned class has static ``tobinary(obj)`` and ``frombinary(bdata)``
    methods. ``frombinary`` consumes exactly ``struct.calcsize(fmt)`` bytes,
    returns ``(value, remaining_data)``, and raises ProtocolViolation if the
    input is too short.
    """
    packer = struct.Struct(fmt)
    width = packer.size

    class _():
        @staticmethod
        def tobinary(obj):
            return packer.pack(obj)

        @staticmethod
        def frombinary(bdata):
            head = bdata[:width]
            try:
                value = packer.unpack(head)[0]
            except struct.error:
                raise ProtocolViolation()
            return value, bdata[width:]

    return _
class Str():
    """NUL-terminated ASCII string."""

    @staticmethod
    def tobinary(obj):
        """Encode ``obj`` as ASCII followed by a single NUL byte.

        :raises ValueError: if ``obj`` is not ASCII (UnicodeEncodeError is a
            ValueError subclass) or contains an embedded NUL. An embedded NUL
            would make the decoder stop early and misread the rest of the
            stream.
        """
        encoded = obj.encode("ascii")
        if b"\0" in encoded:
            raise ValueError("Str cannot contain an embedded NUL byte")
        # b"\0" rather than '\0': bytes + str is a TypeError on Python 3
        # (the two are identical on Python 2).
        return encoded + b"\0"

    @staticmethod
    def frombinary(bdata):
        """Decode up to the first NUL; return ``(text, data_after_the_NUL)``.

        :raises ProtocolViolation: if there is no NUL terminator or the
            text is not valid ASCII.
        """
        text, sep, remainder = bdata.partition(b"\0")
        if not sep:
            raise ProtocolViolation
        try:
            return text.decode("ascii"), remainder
        except UnicodeDecodeError:
            raise ProtocolViolation
# Fixed-width 32-byte field, copied verbatim with no length prefix.
Hash = structfmt(fmt="<32s")
class VarInt():
    """Variable-length unsigned integer occupying 1, 3, 5 or 9 bytes.

    Values up to 0xfc are stored as a single byte. Larger values are stored
    as a marker byte (0xfd, 0xfe or 0xff) followed by a little-endian uint16,
    uint32 or uint64 respectively.
    """

    @staticmethod
    def frombinary(bdata):
        """Decode one VarInt from the front of ``bdata``.

        Returns ``(value, remaining_data)``.

        :raises ProtocolViolation: if ``bdata`` is empty or shorter than the
            width announced by its marker byte.
        """
        try:
            # Decode the marker with struct instead of ord(bdata[0]). Indexing
            # bytes gives an int on Python 3 (where ord() fails) but a
            # 1-char str on Python 2; struct handles both.
            marker = struct.unpack("<B", bdata[:1])[0]
            if marker <= 0xfc:
                return marker, bdata[1:]
            if marker == 0xfd:
                return struct.unpack("<xH", bdata[:3])[0], bdata[3:]
            if marker == 0xfe:
                return struct.unpack("<xI", bdata[:5])[0], bdata[5:]
            return struct.unpack("<xQ", bdata[:9])[0], bdata[9:]
        except struct.error:
            raise ProtocolViolation

    @staticmethod
    def tobinary(int):
        """Encode ``int`` in the shortest form that can hold it.

        The parameter name shadows the builtin but is kept so keyword callers
        keep working.

        :raises struct.error: for negative values or values >= 2**64.
        """
        # Inclusive upper bounds: 0xffff fits the 2-byte form and 0xffffffff
        # fits the 4-byte form. A strict '<' would push each one a size up.
        if int <= 0xfc:
            return struct.pack("<B", int)
        elif int <= 0xffff:
            return struct.pack("<BH", 0xfd, int)
        elif int <= 0xffffffff:
            return struct.pack("<BI", 0xfe, int)
        else:
            return struct.pack("<BQ", 0xff, int)
def VarList(ty):
    """Build a serializer class for a VarInt-count-prefixed list of ``ty`` items.

    ``ty`` is any serializer with static ``tobinary(value)`` and
    ``frombinary(bdata) -> (value, remaining_data)``.
    """
    class _():
        @staticmethod
        def frombinary(bdata):
            """Decode the item count, then that many ``ty`` items.

            Returns ``(list_of_items, remaining_data)``.
            """
            num, bdata = VarInt.frombinary(bdata)
            retval = []
            # Count lazily rather than range(num). num comes from the input and
            # can be as large as 2**64 - 1; Python 2's range() would try to
            # build that whole list first. Presumably a truncated payload makes
            # ty.frombinary raise once the data runs out; that depends on ty.
            while len(retval) < num:
                item, bdata = ty.frombinary(bdata)
                retval.append(item)
            return retval, bdata

        @staticmethod
        def tobinary(obj):
            """Encode ``len(obj)`` as a VarInt followed by each item in order."""
            # b"" rather than "": joining bytes with a str separator is a
            # TypeError on Python 3 (identical on Python 2).
            return VarInt.tobinary(len(obj)) + b"".join(ty.tobinary(x) for x in obj)

    return _
class VarBytes():
    """Raw byte string prefixed with its length encoded as a VarInt."""

    @staticmethod
    def tobinary(obj):
        """Return the VarInt length of ``obj`` followed by ``obj`` itself."""
        prefix = VarInt.tobinary(len(obj))
        return prefix + obj

    @staticmethod
    def frombinary(bdata):
        """Split off one length-prefixed payload.

        Returns ``(payload, remaining_data)``.

        :raises ProtocolViolation: if fewer bytes remain than the prefix
            announces.
        """
        length, remainder = VarInt.frombinary(bdata)
        if length > len(remainder):
            raise ProtocolViolation
        payload = remainder[:length]
        return payload, remainder[length:]
class IPv4Inv6():
    """IPv4 address stored in a 16-byte IPv6-style field.

    On the wire the field is 10 zero bytes, then 0xFFFF, then the 4-byte
    IPv4 address in network byte order. In Python the value is a
    dotted-quad string such as "10.0.0.1".
    """

    @staticmethod
    def frombinary(bdata):
        """Decode the first 16 bytes; return ``(dotted_quad, remaining_data)``.

        The 12-byte prefix is skipped without being checked.
        NOTE(review): a non-mapped IPv6 address is therefore decoded as its
        last 4 bytes. Confirm that callers expect this.

        :raises ProtocolViolation: if fewer than 16 bytes are available.
        """
        # Slice directly instead of building a new structfmt("!12x4s") class on
        # every call. Behavior is the same: require 16 bytes, keep bytes 12..15.
        if len(bdata) < 16:
            raise ProtocolViolation
        return inet_ntoa(bdata[12:16]), bdata[16:]

    @staticmethod
    def tobinary(object):
        """Encode a dotted-quad string as ``00 * 10 + FF FF + 4 address bytes``.

        The parameter name shadows the builtin but is kept for keyword
        callers. Invalid addresses raise whatever ``socket.inet_aton`` raises.
        """
        return struct.pack("!10xH4s", 0xffff, inet_aton(object))