-
Notifications
You must be signed in to change notification settings - Fork 405
/
_tools.py
321 lines (246 loc) · 10.2 KB
/
_tools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
# Copyright (c) 2009,2016,2019 MetPy Developers.
# Distributed under the terms of the BSD 3-Clause License.
# SPDX-License-Identifier: BSD-3-Clause
"""A collection of general purpose tools for reading files."""
import bz2
from collections import namedtuple
import gzip
import logging
from struct import Struct
import zlib
# Module-level logger; bits_to_code uses it to warn about unsupported bit sizes.
log = logging.getLogger(__name__)
def open_as_needed(filename, mode='rb'):
    """Return a file-object given either a filename or an object.

    Handles opening with the right class based on the file extension.

    Parameters
    ----------
    filename : str, os.PathLike, or file-like object
        Path of the file to open. Objects that already have a ``read`` method are
        returned unchanged.
    mode : str, optional
        Mode passed to the opener. Defaults to ``'rb'``.

    Returns
    -------
    file-like object
        A `bz2.BZ2File` for names ending in ``.bz2``, a `gzip.GzipFile` for names
        ending in ``.gz``, otherwise the result of the builtin `open`.

    """
    import os

    # Already an open file-like object: hand it back untouched.
    if hasattr(filename, 'read'):
        return filename

    # Accept path-like objects (e.g. pathlib.Path), which have no ``endswith``.
    filename = os.fspath(filename)

    if filename.endswith('.bz2'):
        return bz2.BZ2File(filename, mode)
    elif filename.endswith('.gz'):
        return gzip.GzipFile(filename, mode)

    # The builtin open() rejects an ``errors`` argument in any binary mode (not just
    # 'rb'), so only pass it for text modes, where it lets undecodable bytes round-trip.
    kwargs = {} if 'b' in mode else {'errors': 'surrogateescape'}
    return open(filename, mode, **kwargs)
class NamedStruct(Struct):
    """Parse bytes using :class:`Struct` but provide named fields.

    Each item of ``info`` is ``(name, fmt)`` or ``(name, fmt, converter)``. Items with an
    empty name are left out of the resulting namedtuple, items with an empty format add
    nothing to the struct format, and trailing fields that receive no parsed value are
    filled with ``None``.
    """

    def __init__(self, info, prefmt='', tuple_name=None):
        """Initialize the NamedStruct.

        Parameters
        ----------
        info : iterable of tuples
            ``(name, fmt)`` or ``(name, fmt, converter)`` items describing each field.
        prefmt : str, optional
            Prefix for the struct format (e.g. a byte-order character).
        tuple_name : str, optional
            Name of the generated namedtuple type. Defaults to ``'NamedStruct'``.

        """
        if tuple_name is None:
            tuple_name = 'NamedStruct'

        info = list(info)
        # Index explicitly rather than ``names, fmts = zip(*info)``: that unpacking fails
        # when every item carries a converter (three columns) or when ``info`` is empty.
        names = [item[0] for item in info]
        fmts = [item[1] for item in info]

        # Map position in the unpacked values -> converter. Unnamed items are assumed to
        # be pad bytes that yield no value, so later positions shift down for each one.
        self.converters = {}
        conv_off = 0
        for ind, item in enumerate(info):
            if len(item) > 2:
                self.converters[ind - conv_off] = item[-1]
            elif not item[0]:  # Skip items with no name
                conv_off += 1

        self._tuple = namedtuple(tuple_name, ' '.join(n for n in names if n))
        super().__init__(prefmt + ''.join(f for f in fmts if f))

    def _create(self, items):
        # Struct.unpack returns an immutable tuple; always work on a list so converters
        # can be applied and missing trailing fields padded with None.
        items = list(items)
        for ind, conv in self.converters.items():
            items[ind] = conv(items[ind])
        num_fields = len(self._tuple._fields)
        if len(items) < num_fields:
            items.extend([None] * (num_fields - len(items)))
        return self.make_tuple(*items)

    def make_tuple(self, *args, **kwargs):
        """Construct the underlying tuple from values."""
        return self._tuple(*args, **kwargs)

    def unpack(self, s):
        """Parse bytes and return a namedtuple."""
        return self._create(super().unpack(s))

    def unpack_from(self, buff, offset=0):
        """Read bytes from a buffer and return as a namedtuple."""
        return self._create(super().unpack_from(buff, offset))

    def unpack_file(self, fobj):
        """Unpack the next bytes from a file object."""
        return self.unpack(fobj.read(self.size))
# This works around cases with more than 255 fields, where NamedStruct can't be
# used because CPython before 3.7 limited function calls to 255 arguments.
class DictStruct(Struct):
    """Parse bytes using :class:`Struct` but provide named fields using dictionary access."""

    def __init__(self, info, prefmt=''):
        """Initialize the DictStruct.

        Parameters
        ----------
        info : iterable of tuples
            ``(name, fmt, ...)`` items describing each field; only the first two elements
            of each item are used (converters are not applied). Items with an empty name
            get no key in the result, and items with an empty format add nothing to the
            struct format.
        prefmt : str, optional
            Prefix for the struct format (e.g. a byte-order character).

        """
        info = list(info)
        # Index explicitly rather than ``zip(*info)``, which fails if every item has
        # three elements. Remove empty names.
        self._names = [item[0] for item in info if item[0]]
        super().__init__(prefmt + ''.join(item[1] for item in info if item[1]))

    def _create(self, items):
        return dict(zip(self._names, items))

    def unpack(self, s):
        """Parse bytes and return a dictionary keyed by field name."""
        return self._create(super().unpack(s))

    def unpack_from(self, buff, offset=0):
        """Read bytes from a buffer and return a dictionary keyed by field name."""
        return self._create(super().unpack_from(buff, offset))
class Enum:
    """Map values to specific strings."""

    def __init__(self, *args, **kwargs):
        """Initialize the mapping.

        Positional names are numbered from 0 in the order given. Each keyword maps its
        value back to the keyword's name, overriding a positional name with that value.
        """
        mapping = dict(enumerate(args))
        mapping.update({value: name for name, value in kwargs.items()})
        self.val_map = mapping

    def __call__(self, val):
        """Map an integer to the string representation, or an 'Unknown' placeholder."""
        try:
            return self.val_map[val]
        except KeyError:
            return 'Unknown ({})'.format(val)
class Bits:
    """Breaks an integer into a specified number of True/False bits."""

    def __init__(self, num_bits):
        """Store how many of the low-order bits to report."""
        self._bits = range(num_bits)

    def __call__(self, val):
        """Convert the integer to the list of True/False values.

        Entry ``i`` of the result is bit ``i`` of ``val`` (least significant bit first).
        """
        return [bool(val >> shift & 1) for shift in self._bits]
class BitField:
    """Convert an integer to a string for each bit."""

    def __init__(self, *names):
        """Store the bit names, least significant bit first."""
        self._names = names

    def __call__(self, val):
        """Return a list with a string for each True bit in the integer.

        Returns None for a falsy ``val``, the bare name when exactly one named bit is set,
        and otherwise a (possibly empty) list of names.
        """
        if not val:
            return None

        set_names = [name for pos, name in enumerate(self._names) if (val >> pos) & 0x1]

        # Unwrap a lone match; empty or multiple matches come back as the list
        if len(set_names) == 1:
            return set_names[0]
        return set_names
class Array:
    """Use a Struct as a callable to unpack a bunch of bytes as a list."""

    def __init__(self, fmt):
        """Precompile the struct format ``fmt`` used for unpacking."""
        self._struct = Struct(fmt)

    def __call__(self, buf):
        """Unpack ``buf`` with the stored format and return the values as a list."""
        values = self._struct.unpack(buf)
        return [*values]
class IOBuffer(object):
    """Holds bytes from a buffer to simplify parsing and random access."""

    def __init__(self, source):
        """Initialize the IOBuffer with the source data.

        ``source`` is copied into a private `bytearray`; reading starts at offset 0.
        """
        self._data = bytearray(source)
        self._offset = 0
        self.clear_marks()

    @classmethod
    def fromfile(cls, fobj):
        """Initialize the IOBuffer with the contents of the file object."""
        return cls(fobj.read())

    def set_mark(self):
        """Mark the current location and return its id so that the buffer can return later."""
        self._bookmarks.append(self._offset)
        return len(self._bookmarks) - 1

    def jump_to(self, mark, offset=0):
        """Jump to a previously set mark, plus an optional relative ``offset``."""
        self._offset = self._bookmarks[mark] + offset

    def offset_from(self, mark):
        """Calculate the current offset relative to a marked location."""
        return self._offset - self._bookmarks[mark]

    def clear_marks(self):
        """Clear all marked locations."""
        self._bookmarks = []

    def splice(self, mark, newdata):
        """Replace the data after the marked location with the specified data.

        The read position is left at the mark.
        """
        self.jump_to(mark)
        self._data = self._data[:self._offset] + bytearray(newdata)

    def read_struct(self, struct_class):
        """Parse and return a structure from the current buffer offset.

        ``struct_class`` needs ``unpack_from(buffer, offset)`` and ``size`` (as
        `struct.Struct` provides); the offset advances by ``size``.
        """
        # memoryview avoids copying the remaining data before unpacking
        struct = struct_class.unpack_from(memoryview(self._data), self._offset)
        self.skip(struct_class.size)
        return struct

    def read_func(self, func, num_bytes=None):
        """Parse data from the current buffer offset using a function.

        ``func`` receives the next ``num_bytes`` bytes (all remaining bytes if None).
        """
        # only advance if func succeeds
        res = func(self.get_next(num_bytes))
        self.skip(num_bytes)
        return res

    def read_ascii(self, num_bytes=None):
        """Return the specified bytes as ascii-formatted text."""
        return self.read(num_bytes).decode('ascii')

    def read_binary(self, num, item_type='B'):
        """Parse the current buffer offset as the specified code.

        Any ``item_type`` containing ``'B'`` returns the raw bytes; otherwise ``num`` items
        are unpacked (honoring a leading byte-order character) and returned as a list.
        """
        if 'B' in item_type:
            return self.read(num)

        if item_type[0] in ('@', '=', '<', '>', '!'):
            order = item_type[0]
            item_type = item_type[1:]
        else:
            order = '@'

        return list(self.read_struct(Struct(order + '{:d}'.format(int(num)) + item_type)))

    def read_int(self, code):
        """Parse the current buffer offset as the specified integer code."""
        return self.read_struct(Struct(code))[0]

    def read(self, num_bytes=None):
        """Read and return the specified bytes from the buffer."""
        res = self.get_next(num_bytes)
        self.skip(len(res))
        return res

    def get_next(self, num_bytes=None):
        """Get the next bytes in the buffer without modifying the offset."""
        if num_bytes is None:
            return self._data[self._offset:]
        else:
            return self._data[self._offset:self._offset + num_bytes]

    def skip(self, num_bytes):
        """Jump ahead the specified bytes in the buffer (to the end if None)."""
        if num_bytes is None:
            self._offset = len(self._data)
        else:
            self._offset += num_bytes

    def check_remains(self, num_bytes):
        """Check that exactly the number of bytes specified remains in the buffer."""
        # Slice a memoryview rather than the bytearray: same length semantics, no copy.
        return len(memoryview(self._data)[self._offset:]) == num_bytes

    def truncate(self, num_bytes):
        """Remove the specified number of bytes from the end of the buffer."""
        # ``data[:-0]`` is ``data[:0]``, which would wipe the whole buffer; removing
        # zero bytes must leave the data unchanged.
        if num_bytes:
            self._data = self._data[:-num_bytes]

    def at_end(self):
        """Return whether the buffer has reached the end of data."""
        return self._offset >= len(self._data)

    def __getitem__(self, item):
        """Return the data at the specified location."""
        return self._data[item]

    def __str__(self):
        """Return a string representation of the IOBuffer."""
        return 'Size: {} Offset: {}'.format(len(self._data), self._offset)

    def __len__(self):
        """Return the amount of data in the buffer."""
        return len(self._data)
def zlib_decompress_all_frames(data):
    """Decompress all frames of zlib-compressed bytes.

    Repeatedly decompresses consecutive zlib streams ("frames") from `data` until all data
    are consumed. If decompression fails (e.g. the remaining bytes are not zlib data),
    those remaining bytes are appended unchanged and no further frames are attempted.

    Parameters
    ----------
    data : bytearray or bytes
        Binary data compressed using zlib.

    Returns
    -------
    bytearray
        All decompressed bytes, followed by any trailing bytes that could not be
        decompressed.

    """
    frames = bytearray()
    data = bytes(data)
    while data:
        decomp = zlib.decompressobj()
        try:
            frames.extend(decomp.decompress(data))
        except zlib.error:
            # Not (valid) zlib data: keep the raw bytes and stop looking for frames.
            frames.extend(data)
            break
        # Bytes past the end of this frame; empty once everything has been consumed.
        data = decomp.unused_data
    return frames
def bits_to_code(val):
    """Convert the number of bits to the proper code for unpacking.

    8 bits maps to ``'B'`` and 16 bits to ``'H'``; any other size logs a warning and
    falls back to ``'B'``.
    """
    for size, code in ((8, 'B'), (16, 'H')):
        if val == size:
            return code
    log.warning('Unsupported bit size: %s. Returning "B"', val)
    return 'B'