In a fit of callous disregard for good Python style, I wanted a way to declare a dictionary and reference it inside its own definition.

http://stackoverflow.com/a/15988902

In [None]:
# %load -s ParamDict tinyws.py
class ParamDict(dict):
    """A ``dict`` whose callable values are resolved on lookup.

    When ``d[key]`` finds a callable, it is called with the dict itself as
    its only argument and the result is returned, so an entry can be defined
    in terms of other entries of the same literal. The callable is
    re-evaluated on every lookup. Only ``__getitem__`` is overridden, so
    ``get()``, ``values()`` and ``items()`` still return the raw callables.
    From: http://stackoverflow.com/a/15988902"""
    def __getitem__(self, key):
        stored = dict.__getitem__(self, key)
        if not callable(stored):
            return stored
        return stored(self)


In [None]:
# %load -s masking tinyws.py
def masking(has_mask, key, payload, payload_len):
    """XOR-mask (or unmask) WebSocket payload data.

    ``payload`` may be bytes-like or an int. An int is first converted to
    ``payload_len`` big-endian bytes. When ``has_mask`` is truthy, each byte
    is XORed with the 4-byte big-endian form of ``key`` (cycled every 4
    bytes) and a ``bytearray`` is returned. Otherwise the (possibly
    converted) payload is returned unchanged.
    """
    if isinstance(payload, int):
        payload = payload.to_bytes(payload_len, byteorder="big")
    if not has_mask:
        return payload
    key_bytes = key.to_bytes(4, byteorder="big")
    return bytearray(byte ^ key_bytes[idx % 4]
                     for idx, byte in enumerate(payload))


## Test (Extract) Subset of Bits From an Int

Say you have a number `n` and want the bits from `i` to `j` (e.g. `i=5`, `j=10`). Note that `i=0` refers to the last (least significant) bit.

This statement:

    int value = n & (((1 << (j-i)) - 1) << i );

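will give you the result. The selected bits stay in their original positions; shift right by `i` to read them as a small integer.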

The left part is obvious: you have a value, and you will put a bitmask on it.

The value of the mask is `((1 << (j-i)) - 1) << i`. It says:

   * Take a `1` bit (value: `0000000000000001`)
   * Shift it left `j-i` times (value: `2^(10-5) = 2^5 = 32 = 0000000000100000`)
   * Subtract `1` (value: `31 = 0000000000011111`) - the lowest `j-i` bits are now all set
   * Shift it left `i` times (value: `31*32=992 = 0000001111100000`)

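So you have the bitmask for bits `5` - `10` (more precisely, from `5` to `9`, since the 10th is not included). Note that `test_bits` below counts bit offsets from the most significant bit instead, and converts them to this form internally.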

Reference: http://stackoverflow.com/a/16001970

In [None]:
# %load -s test_bits tinyws.py
def test_bits(data, bits, offset):
    """Test bits in an int
       Python's expanding & shrinking ints truncates the most significant bits
       if they are zero, so we calculate what may be missing before testing."""
    bit_len = data.bit_length() +  (bits - data.bit_length())
    if isinstance(offset, range):
        offset_start = bit_len - offset.start
        offset = bit_len - offset.stop
        mask = ((1 << (offset_start - offset)) - 1) << offset
    else:
        offset = (bit_len - 1) - offset
        mask = 1 << offset
    return (data & mask) >> offset


In [None]:
# %load -s bitlist2bytes tinyws.py
def bitlist2bytes(bitlist, byteorder="little"):
    """Pack a bit field list (MSB first within each byte) into a bytearray.

    Bits are consumed in groups of 8 starting at index 0. Each group is read
    as a binary number, so a trailing group shorter than 8 bits becomes a
    byte equal to its own unpadded value. With ``byteorder`` "big" or
    "network", the resulting bytes are emitted in reverse order (a byte
    swap). Any other value keeps them in list order.
    """
    group_starts = range(0, len(bitlist), 8)
    if byteorder in ["big", "network"]:
        group_starts = reversed(group_starts)  # byte swap
    return bytearray(
        int("".join("%d" % bit for bit in bitlist[start:start + 8]), 2)
        for start in group_starts)


In [None]:
# %load -s bytes2bitlist tinyws.py
def bytes2bitlist(data):
    """Expand a sequence of byte values into a flat list of 0/1 ints.

    Each byte contributes 8 entries, most significant bit first.
    """
    return [(byte >> shift) & 1 for byte in data for shift in range(7, -1, -1)]


In [None]:
# %load -s bitlist2int tinyws.py
def bitlist2int(bits, byteorder="little"):
    """Convert a bit field list (MSB first) to an integer.

    With the default ``byteorder="little"``, the list is read as one binary
    number, most significant bit first. With "big" or "network", the 8-bit
    groups (starting at index 0) are taken in reverse order first, i.e. a
    byte swap that mirrors ``bitlist2bytes``. Bit order inside each group is
    kept, and a trailing partial group is kept rather than dropped.

    Args:
        bits: Sequence of 0/1 ints.
        byteorder: "little" (default), "big" or "network".

    Returns:
        The integer value; 0 for an empty list.
    """
    if byteorder in ["big", "network"]:  # Byte swap: reverse the 8-bit groups
        starts = reversed(range(0, len(bits), 8))
        ordered = [bit for start in starts for bit in bits[start:start + 8]]
    else:
        ordered = bits
    digits = "".join(["%d" % bit for bit in ordered])
    # int("", 2) raises, so an empty bit list is defined as 0.
    return int(digits, 2) if digits else 0


In [8]:
b = b'\x81\x86\xbb\xa6\xf1\x97\xdd\xc9\x9e\xf5\xda\xd4'
#b = b'\x81\xfe\x00~w\x91\xdbJG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1'
#b'\x81\xfe\x00~w\x91\xdbJG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1\xebzG\xa1'
data = int.from_bytes(b, byteorder="big")
bits = len(b) * 8
print(len(bin(data)))
print(bin(data))
#print(bytes2bitlist(b))
print(test_bits(data, bits, range(16, 32)))

98
0b100000011000011010111011101001101111000110010111110111011100100110011110111101011101101011010100
48038


In [9]:
# Bit 8 (counting from the most significant bit) is the MASK flag of the header.
bin(test_bits(data, bits, 8))

'0b1'

In [17]:
# A range reaching past the 96 bits of `b` is rejected with ValueError.
bin(test_bits(data, bits, range(0,128)))

ValueError: negative shift count

In [None]:
# %load -s parse_frame tinyws.py
def parse_frame(data_bytes, get_payload_len=False):
    """Parse raw frame bytes into a lazily-evaluated frame dict.

    Returns a ``ParamDict``. The fixed header fields (fin, rsv1-3, opcode,
    mask, payload_len) are extracted eagerly. The remaining fields are
    lambdas that ``ParamDict.__getitem__`` evaluates on every ``frame[key]``
    lookup, so they can depend on other fields of the same frame.

    Bit offsets passed to ``test_bits`` count from the most significant bit
    of ``data_bytes`` (offset 0 is the first bit of the first byte), so bytes
    trailing the frame do not shift any field. All ``*_off`` values and
    ``frame_len`` are measured in bits, not bytes.

    Args:
        data_bytes: The raw frame (bytes or bytearray).
        get_payload_len: If true, 'payload_data' evaluates to ``b''``
            instead of extracting and unmasking the payload.

    NOTE(review): ``test_bits`` raises ValueError when a requested field
    extends past the end of ``data_bytes``. Fewer than 2 bytes therefore fail
    here, while a truncated payload/key only fails on first access of the
    affected field. Confirm this is the intended way to detect incomplete
    frames.
    """
    # Whole frame as one big-endian int; `bits` keeps its original width so
    # leading zero bits are not lost.
    data = int.from_bytes(data_bytes, byteorder="big")
    bits = len(data_bytes) * 8
    frame = ParamDict(
        {'fin':test_bits(data, bits, 0),
         'rsv1':test_bits(data, bits, 1),
         'rsv2':test_bits(data, bits, 2),
         'rsv3':test_bits(data, bits, 3),
         'opcode':test_bits(data, bits, range(4, 8)),
         'mask':test_bits(data, bits, 8),
         # 7-bit length; 126 / 127 mean a 16 / 64-bit extended length follows.
         'payload_len':test_bits(data, bits, range(9, 16)),
         'ext_payload_len':lambda self: (
             test_bits(data, bits, range(16, 32)) if self['payload_len'] == 126 else
             test_bits(data, bits, range(16, 80)) if self['payload_len'] == 127 else
             0),  # len <= 125
         # 32-bit key placed right after the (optional) extended length;
         # 0 when the MASK bit is clear.
         'masking_key':lambda self: (
             0 if not self['mask'] else
             test_bits(data, bits, range(16, 48))  if self['payload_len'] <= 125 else
             test_bits(data, bits, range(32, 64))  if self['payload_len'] == 126 else
             test_bits(data, bits, range(80, 112))),  # len == 127
         # Bit offset just past the extended length field (where the masking
         # key starts, if present).
         'ext_payload_off':lambda self: (
             32 if self['payload_len'] == 126 else
             80 if self['payload_len'] == 127 else
             16), # len <= 125
         # Bit offset just past the masking key (adds 32 bits only if masked).
         'masking_key_off':lambda self: (self['ext_payload_off'] +
                                         (32 * self['mask'])),
         # The payload starts immediately after the masking key.
         'payload_data_off':lambda self: self['masking_key_off'],
         # Payload bytes passed through masking(): a bytearray when masked,
         # bytes otherwise. The length comes from payload_len (<= 125) or
         # from ext_payload_len.
         'payload_data':lambda self: (
             b'' if get_payload_len else
             masking(self['mask'], self['masking_key'],
                     test_bits(data, bits,
                               range(self['payload_data_off'],
                                     self['payload_data_off'] +
                                     self['payload_len']*8)),
                     self['payload_len']) if self['payload_len'] <= 125 else
             masking(self['mask'], self['masking_key'],
                     test_bits(data, bits,
                               range(self['payload_data_off'],
                                     self['payload_data_off'] +
                                     self['ext_payload_len']*8)),
                     self['ext_payload_len'])),
         # Total frame length in bits (header + optional key + payload).
         'frame_len':(lambda self:
                      self['payload_data_off'] + self['payload_len']*8
                      if self['payload_len'] <= 125 else
                      self['payload_data_off'] + self['ext_payload_len']*8)
        })

    return frame


In [12]:
# Parse the 12-byte sample frame `b` into a ParamDict of frame fields.
frame = parse_frame(b)

In [13]:
# %load -s assemble_raw_frame tinyws.py
def assemble_raw_frame(frame):
    """Bit-pack a dict-like frame (as produced by ``parse_frame``) into bytes.

    The output is always an *unmasked* frame. The MASK bit is written as 0,
    no masking key is emitted, and ``frame['payload_data']`` is copied
    verbatim (``parse_frame`` has already unmasked it).
    NOTE(review): presumably intentional for frames this side sends; confirm
    if masked output is ever needed.

    Args:
        frame: Mapping providing 'fin', 'rsv1', 'rsv2', 'rsv3', 'opcode',
            'payload_len', 'ext_payload_len' and 'payload_data'.

    Returns:
        A bytearray holding the encoded frame.
    """
    payload_len = frame['payload_len']
    ext_payload_len = frame['ext_payload_len']

    # First byte: FIN, RSV1-3, then the low 4 bits of the opcode.
    raw_frame = [frame['fin'], frame['rsv1'], frame['rsv2'], frame['rsv3']]
    raw_frame += bytes2bitlist(bytes([frame['opcode']]))[4:]
    # Second byte: MASK bit (always 0 here), then the 7-bit payload length.
    raw_frame += [0]
    raw_frame += bytes2bitlist(bytes([payload_len]))[1:]
    # Extended length: 16 bits for marker 126, 64 bits for 127, else absent.
    if payload_len == 126:
        raw_frame += bytes2bitlist(ext_payload_len.to_bytes(2, byteorder="big"))
    elif payload_len == 127:
        raw_frame += bytes2bitlist(ext_payload_len.to_bytes(8, byteorder="big"))

    raw_frame += bytes2bitlist(frame['payload_data'])
    return bitlist2bytes(raw_frame)


In [14]:
def _show(parsed):
    """Print each field of a parsed frame as ``name value``, one per line."""
    for name in parsed:
        print(name, parsed[name])


# Dump the parsed frame, re-encode it, parse the result and dump it again.
# The second dump shows what assemble_raw_frame emits (MASK cleared).
_show(frame)
raw_frame = assemble_raw_frame(frame)
frame = parse_frame(raw_frame)
print()
_show(frame)

masking_key_off 48
frame_len 96
payload_len 6
rsv2 0
masking_key 3148280215
payload_data_off 48
ext_payload_off 16
ext_payload_len 0
mask 1
rsv3 0
opcode 1
fin 1
rsv1 0
payload_data bytearray(b'foobar')

masking_key_off 16
frame_len 64
payload_len 6
rsv2 0
masking_key 0
payload_data_off 16
ext_payload_off 16
ext_payload_len 0
mask 0
rsv3 0
opcode 1
fin 1
rsv1 0
payload_data b'foobar'


In [15]:
# Edge case: an empty input produces an empty bit list.
bytes2bitlist([])

[]