In [139]:
from multiset import Multiset
import compressors.rANS as rANS
from dataclasses import dataclass
import numpy as np
from typing import Tuple, Any, List
from core.data_encoder_decoder import DataDecoder, DataEncoder
from utils.bitarray_utils import BitArray, get_bit_width, uint_to_bitarray, bitarray_to_uint
from core.data_block import DataBlock
from core.prob_dist import Frequencies, get_avg_neg_log_prob
from utils.test_utils import get_random_data_block, try_lossless_compression
from utils.misc_utils import cache

In [149]:
class rBBCMultiSetEncoder(rANS.rANSEncoder):
    
    def expand_state_local(self, state: int) -> int:
        # remap the state into the acceptable range
        while state < self.params.L:
            state = (state << self.params.NUM_BITS_OUT) 
        return state

    def encode_symbol(self, s, state: int, M_in: Multiset, in_bits: BitArray) -> Tuple[int, BitArray, Multiset]:
        """Encodes the next symbol, returns some bits and  the updated state

        Args:
            s (Any): next symbol to be encoded
            state (int): the rANS state
            M_in (Multiset): the Multiset

        Returns:
            state (int), symbol_bitarray (BitArray), M_out (Multiset):
        """
        # output bits to the stream so that the state is in the acceptable range
        # [L, H] *after*the `rans_base_encode_step`
        
        # begin additional steps in bits back coding
        
        freqs_local = Frequencies(M_in)
        rParams_local = rANS.rANSParams(freqs_local)
        decoder_local = rANS.rANSDecoder(rParams_local)
        state_start = state
        s, state = decoder_local.rans_base_decode_step(state)
        state = self.expand_state(state,in_bits)
        M_out = M_in - Multiset(s)
        
        
        # end additional steps in bits back coding
        
        state, out_bits = self.shrink_state(state, s)
        
        # NOTE: we are prepending bits for pedagogy. In practice, it might be faster to assign a larger memory chunk and then fill it from the back
        # see: https://github.com/rygorous/ryg_rans/blob/c9d162d996fd600315af9ae8eb89d832576cb32d/main.cpp#L176 for example
        symbol_bitarray = out_bits 

        # core encoding step
        state = self.rans_base_encode_step(s, state)
        return state, symbol_bitarray, M_out

    def encode_block(self, data_block: DataBlock):
        # initialize the output
        encoded_bitarray = BitArray("")
        
        M = Multiset(data_block.data_list)

        # initialize the state 
        state = self.params.INITIAL_STATE
        
        symbol_bitarray = uint_to_bitarray(0,self.params.NUM_BITS_OUT)

        # update the state
        for s in data_block.data_list:
            state, symbol_bitarray, M = self.encode_symbol(s, state, M)
            encoded_bitarray = symbol_bitarray + encoded_bitarray
#             print(len(encoded_bitarray))

        # Finally, pre-pend binary representation of the final state
        encoded_bitarray = uint_to_bitarray(state, self.params.NUM_STATE_BITS) + encoded_bitarray

        # add the data_block size at the beginning
        # NOTE: rANS decoding needs a way to indicate where to stop the decoding
        # One way is to add a character at the end which signals EOF. This requires us to
        # change the probabilities of the other symbols. Another way is to just signal the size of the
        # block. These two approaches add a bit of overhead.. the approach we use is much more transparent
        
        encoded_bitarray = (
            uint_to_bitarray(data_block.size, self.params.DATA_BLOCK_SIZE_BITS) + encoded_bitarray
        )
        # encoded_bitarray = bit(data_block.size) || bit(state) || encoded_bitarray
        return encoded_bitarray
    
class rBBCMultiSetDecoder(rANS.rANSDecoder):
    def shrink_state_local(self, state: int) -> Tuple[int]:
        """stream out the lower bits of the state, until the state is below params.max_shrunk_state[next_symbol]"""

        # output bits to the stream to bring the state in the range for the next encoding
        while state > self.params.H:
#            assert state%2==0
            state = state >> self.params.NUM_BITS_OUT

        return state
    
    def expand_state(self, state: int, encoded_bitarray: BitArray) -> Tuple[int, int]:
        # remap the state into the acceptable range
        num_bits = 0
        while state < self.params.L:
#             print('num: {:d} state: {:d}'.format(num_bits,state))
            state_remainder = bitarray_to_uint(
                encoded_bitarray[num_bits : num_bits + self.params.NUM_BITS_OUT]
            )
            num_bits += self.params.NUM_BITS_OUT
            state = (state << self.params.NUM_BITS_OUT) + state_remainder
        return state, num_bits
    
    def decode_symbol(self, state: int, encoded_bitarray: BitArray, M_out: Multiset):
        # base rANS decoding step
        s, state = self.rans_base_decode_step(state)

        # remap the state into the acceptable range
        state, num_bits_used_by_expand_state = self.expand_state(state, encoded_bitarray)
        
        # additional steps in bits back coding
        M_in = M_out + Multiset(s)
        freqs_local = Frequencies(M_in)
        rParams_local = rANS.rANSParams(freqs_local)
        encoder_local = rANS.rANSEncoder(rParams_local)
        state = encoder_local.rans_base_encode_step(s, state)
        state = self.shrink_state_local(state)
        
        return s, state, M_in, num_bits_used_by_expand_state

    def decode_block(self, encoded_bitarray: BitArray):
        # get data block size
        data_block_size_bitarray = encoded_bitarray[: self.params.DATA_BLOCK_SIZE_BITS]
        input_data_block_size = bitarray_to_uint(data_block_size_bitarray)
        num_bits_consumed = self.params.DATA_BLOCK_SIZE_BITS
        print(input_data_block_size)
        
        # get the final state
        state = bitarray_to_uint(
            encoded_bitarray[num_bits_consumed : num_bits_consumed + self.params.NUM_STATE_BITS]
        )
        print('last_state:{}'.format(state))
        num_bits_consumed += self.params.NUM_STATE_BITS
#         print(state)
#         print(encoded_bitarray)
        M = Multiset()

        # perform the decoding
        decoded_data_list = []
        for _ in range(input_data_block_size):
            s, state, M, num_symbol_bits = self.decode_symbol(
                state, encoded_bitarray[num_bits_consumed:], M
            )

            # rANS decoder decodes symbols in the reverse direction,
            # so we add newly decoded symbol at the beginning
            decoded_data_list = [s] + decoded_data_list
            num_bits_consumed += num_symbol_bits

        # Finally, as a sanity check, ensure that the end state should be equal to the initial state
#         assert state == self.params.INITIAL_STATE

        return DataBlock(decoded_data_list), num_bits_consumed


In [150]:
freqs = Frequencies({"A": 1, "B": 1, "C": 2})
rParams = rANS.rANSParams(freqs)
DATA_SIZE = 100
SEED = 0
prob_dist = freqs.get_prob_dist()
data_block = get_random_data_block(prob_dist, DATA_SIZE, seed=SEED)
avg_log_prob = get_avg_neg_log_prob(prob_dist, data_block)

# create encoder decoder
# encoder = rANS.rANSEncoder(rParams)
# decoder = rANS.rANSDecoder(rParams)

encoder = rBBCMultiSetEncoder(rParams)
decoder = rBBCMultiSetDecoder(rParams)

# test encode
encoded_bitarray = encoder.encode_block(data_block)

# if True, add some random bits to the encoder output
encoded_bitarray_extra = BitArray(encoded_bitarray)  # make a copy
# if add_extra_bits_to_encoder_output:
#     num_extra_bits = int(np.random.randint(100))
#     encoded_bitarray_extra += get_random_bitarray(num_extra_bits)

In [151]:
# test decode
decoded_block, num_bits_consumed = decoder.decode_block(encoded_bitarray_extra)
# assert num_bits_consumed == len(encoded_bitarray), "Decoder did not consume all bits"

# # test lossless coding
# is_lossless, encode_len, _ = try_lossless_compression(
#     data_block, encoder, decoder, add_extra_bits_to_encoder_output=True
# )
# assert is_lossless

100
last_state:264072


In [152]:
len(decoded_block.data_list)

100

In [137]:
len(data_block.data_list)

100

In [138]:
for i in range(100):
    print('{} {}'.format(decoded_block.data_list[i],data_block.data_list[i]))

C C
A B
C A
A A
A C
C C
C C
C C
C C
A C
C C
C A
C C
B A
C C
C A
A C
C C
C B
C B
C A
B A
A C
C C
C C
A B
C C
B C
C C
C C
B C
C B
C A
C C
C C
C B
C B
C C
B C
C B
B C
B B
C C
C B
C B
B C
A A
C C
C A
A C
C C
C A
A C
A A
A B
C A
A B
A C
C A
A A
C B
A A
B A
C C
C B
C C
A A
C C
B B
B A
A C
A C
C B
A C
C B
C B
C C
C C
C C
C B
B C
B B
C C
C C
B B
B C
C C
C C
A A
C C
A C
A C
C A
B C
C C
A C
C A
A C
C C
A C
